### MusicGen Model

In [2]:
import torch
from musicgen_inpaint import *

  from .autonotebook import tqdm as notebook_tqdm


In [17]:
preprocess_dataset()

[Preprocessing] Complete.


In [44]:
model = load_small_model()



In [None]:
tokenize_dataset(model)

In [45]:
import torchaudio
from IPython.display import Audio, display

left_path = "dataset/chunks/203 - Lift_chunk0014.wav"
right_path = "dataset/chunks/Radiohead - Creep_chunk0000.wav"

print("Left:", left_path)
print("Right:", right_path)

L, _ = torchaudio.load(left_path)
R, _ = torchaudio.load(right_path)

L = L.to("cuda")
R = R.to("cuda")

display(Audio(L.cpu().numpy(), rate=model.sample_rate))
display(Audio(R.cpu().numpy(), rate=model.sample_rate))

Left: dataset/chunks/203 - Lift_chunk0014.wav
Right: dataset/chunks/Radiohead - Creep_chunk0000.wav


In [46]:
output = inpaint_audio(model, L, R, left_sec=5, right_sec=5, mask_sec=2,)
Audio(output, rate=32000)

In [None]:
import os
from torch.utils.data import DataLoader

class TokenDataset(torch.utils.data.Dataset):
    def __init__(self, token_dir="dataset/tokenized"):
        self.files = [
            os.path.join(token_dir, f) 
            for f in os.listdir(token_dir) 
            if f.endswith(".pt")
        ]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        codes = torch.load(self.files[idx])
        if codes.dim() == 1:
            codes = codes.unsqueeze(0)
        return codes.long()

dataloader = DataLoader(
    TokenDataset(),
    batch_size=8,
    shuffle=True,
    drop_last=True
)

In [None]:
import torch

def random_mask(tokens, mask_id, fraction=0.15):
    B, K, T = tokens.shape
    mask_len = int(T * fraction)

    start = torch.randint(0, T - mask_len, (1,)).item()
    end = start + mask_len

    masked = tokens.clone()
    masked[:, :, start:end] = mask_id

    mask_region = torch.zeros_like(tokens, dtype=torch.bool)
    mask_region[:, :, start:end] = True

    return masked, mask_region


In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm import tqdm

os.makedirs("checkpoints", exist_ok=True)
    
lm = model.lm.to('cuda') 
optimizer = optim.AdamW(lm.parameters(), lr=1e-4) 
criterion = nn.CrossEntropyLoss()
mask_id = int(lm.special_token_id or 0)

for epoch in range(5):

    epoch_loss = 0.0
    progress = tqdm(dataloader, desc=f"Epoch {epoch}", leave=True)

    for codes in progress:
        codes = codes.to("cuda")

        masked_tokens, mask_region = random_mask(codes, mask_id)

        descriptions = [None] * codes.size(0)
        attributes, _ = model._prepare_tokens_and_attributes(descriptions, None)

        out = lm.compute_predictions(masked_tokens, conditions=attributes)
        logits = out.logits

        logits_flat = logits.permute(0, 2, 1, 3).reshape(-1, logits.size(-1))
        target_flat = codes.permute(0, 2, 1).reshape(-1)
        mask_flat = mask_region.permute(0, 2, 1).reshape(-1)

        loss = criterion(logits_flat[mask_flat], target_flat[mask_flat])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        avg_loss = epoch_loss / (progress.n + 1)

        progress.set_postfix({
            "loss": f"{loss.item():.4f}",
            "avg": f"{avg_loss:.4f}",
            "T": logits.shape[2]
        })

    print(f"[Epoch {epoch}] avg_loss={avg_loss:.4f}")
    
    save_path = f"checkpoints/lm_epoch_{epoch}.pt"
    torch.save(model.lm.state_dict(), save_path)
    print("Saved LM checkpoint:", save_path)


Epoch 0:   0%|          | 0/1653 [00:00<?, ?it/s]

Epoch 0: 100%|██████████| 1653/1653 [3:12:46<00:00,  7.00s/it, loss=5.9774, avg=nan, T=200]  


[Epoch 0] avg_loss=nan
Saved LM checkpoint: checkpoints/lm_epoch_0.pt


Epoch 1: 100%|██████████| 1653/1653 [3:12:20<00:00,  6.98s/it, loss=6.8059, avg=nan, T=200]  


[Epoch 1] avg_loss=nan
Saved LM checkpoint: checkpoints/lm_epoch_1.pt


Epoch 2: 100%|██████████| 1653/1653 [3:12:10<00:00,  6.98s/it, loss=6.7930, avg=nan, T=200]  


[Epoch 2] avg_loss=nan
Saved LM checkpoint: checkpoints/lm_epoch_2.pt


Epoch 3: 100%|██████████| 1653/1653 [3:23:21<00:00,  7.38s/it, loss=7.3029, avg=nan, T=200]  


[Epoch 3] avg_loss=nan
Saved LM checkpoint: checkpoints/lm_epoch_3.pt


Epoch 4: 100%|██████████| 1653/1653 [3:20:48<00:00,  7.29s/it, loss=7.3065, avg=nan, T=200]  


[Epoch 4] avg_loss=nan
Saved LM checkpoint: checkpoints/lm_epoch_4.pt


In [15]:
torch.save(model.lm.state_dict(), "lm_inpaint_final.pt")
print("Saved model.")

Saved model.


In [3]:
state = torch.load("checkpoints/lm_epoch_4.pt", map_location="cuda")

model = load_small_model()
model.lm.load_state_dict(state)

print("Model restored!")




Model restored!


In [9]:
import torchaudio
from IPython.display import Audio, display

left_path = "dataset/chunks/203 - Lift_chunk0014.wav"
right_path = "dataset/chunks/Radiohead - Creep_chunk0000.wav"
output = inpaint_audio(model, left_path, right_path, left_sec=5, right_sec=5, mask_sec=2,)
output = output[0]

sf.write("_temp.wav", output, 32000)

Audio(output, rate=32000)