In [1]:
import torch
from torch import tensor, nn, optim
import torch.nn.functional as F

In [2]:
n_embed = 32
block_size = 16
bs=2
vocab_size = 500

In [3]:
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training
from torch.utils.data import DataLoader
from pathlib import Path

# Creating a multitrack tokenizer, read the doc to explore all the parameters
config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(config)

# Train the tokenizer with Byte Pair Encoding (BPE)
files_paths = list(Path("/notebooks/classical-music-gen/midis_test").glob("**/*.midi"))
tokenizer.train(vocab_size=vocab_size, files_paths=files_paths)
tokenizer.save(Path("tokenizer", "tokenizer.json"))
# And pushing it to the Hugging Face hub (you can download it back with .from_pretrained)
tokenizer.push_to_hub("ABicGrill/miditok_tokenizer", private=True, token="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")

# Split MIDIs into smaller chunks for training
dataset_chunks_dir = Path("chunks")
split_files_for_training(
    files_paths=files_paths,
    tokenizer=tokenizer,
    save_dir=dataset_chunks_dir,
    max_seq_len=1024,
)

# Create a Dataset, a DataLoader and a collator to train a model
dataset = DatasetMIDI(
    files_paths=list(dataset_chunks_dir.glob("**/*.midi")),
    tokenizer=tokenizer,
    max_seq_len=block_size+1,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)
collator = DataCollator(tokenizer.pad_token_id, copy_inputs_as_labels=False)
dataloader = DataLoader(dataset, batch_size=bs, collate_fn=collator)

# Iterate over the dataloader to train a model
for batch in dataloader:
    print(batch)
    break

  super().__init__(tokenizer_config, params)







No files have been modified since last commit. Skipping to prevent empty commit.
Splitting music files (chunks): 100%|██████████| 5/5 [00:00<00:00,  9.67it/s]


{'input_ids': tensor([[313, 223, 265, 169, 254, 219, 170, 250, 350, 172, 242, 287, 174, 271,
          93, 112, 176],
        [285, 236, 217,  71, 355,  28, 217,  16, 210,  68, 215, 168, 273, 227,
         171, 302, 431]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       dtype=torch.int32)}


In [99]:
n_embed = 32
class MusicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embs = nn.Embedding(vocab_size, n_embed)
        self.pos_embs = nn.Embedding(block_size, n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, x, targets=None, train=False):
        xb = x[:,:len(x[0])-1] if train else x
        if train:
            targets = x[:,1:len(x[0])]
        token_embs = self.token_embs(xb) # B, T, C
        pos_embs = self.pos_embs(torch.arange(len(x[0])-1)) if train else self.pos_embs(torch.arange(len(x[0]))) # T, C

        embs = token_embs + pos_embs
        out= self.lm_head(embs)
        if targets is not None:
            B, T, C = out.shape
            out = out.view(B * T, C)
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(out, targets)
            return out, loss
            
        return out, None
    
    def generate(self, idx, max_tokens):
        for i in range(max_tokens):
            idx_cond = idx[:, -block_size:]
            out, loss = self(idx_cond)
            out = out[:,-1,:]
            probs = out.softmax(-1)
            preds = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, preds), dim=-1)
            
        return idx

In [100]:
model = MusicModel()

In [101]:
batch = next(iter(dataloader))
print(batch['input_ids'].shape)
model(batch['input_ids'], train=True)

torch.Size([2, 17])


(tensor([[ 0.4797, -0.4055, -0.9884,  ...,  1.4533,  0.6016, -1.1717],
         [ 0.4810,  0.1952, -0.8077,  ...,  0.3127, -0.0186, -1.0887],
         [ 0.4758,  0.0312,  0.2170,  ...,  1.2059, -0.4132, -0.2913],
         ...,
         [-0.4644, -0.5362, -1.4650,  ...,  0.6117, -0.4254,  0.7923],
         [-1.3241,  0.6971, -0.3349,  ..., -1.5338,  0.0239, -0.4660],
         [-0.5952, -0.7234,  1.2323,  ...,  0.7520,  1.1895,  0.1760]],
        grad_fn=<ViewBackward0>),
 tensor(6.3863, grad_fn=<NllLossBackward0>))

In [102]:
opt = optim.AdamW(model.parameters(), lr=0.002)
n_epochs = 2
for epoch in range(n_epochs):
    for batch in dataloader:
        out, loss = model(batch['input_ids'], train=True)
        
        loss.backward()
        with torch.no_grad():
            opt.step()
            opt.zero_grad()
            
        # print(loss.item())
        
    print(loss)

tensor(6.0237, grad_fn=<NllLossBackward0>)
tensor(5.4833, grad_fn=<NllLossBackward0>)


In [107]:
tokenizer(model.generate(torch.ones((1, 1), dtype=torch.long), max_tokens=23)[0]).dump_midi("bruh.mid")