In [1]:
import torch
from torch import tensor, nn, optim
import torch.nn.functional as F

In [2]:
n_embed = 32
block_size = 16
bs=2
vocab_size = 500

In [3]:
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training
from torch.utils.data import DataLoader
from pathlib import Path

# Creating a multitrack tokenizer, read the doc to explore all the parameters
config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(config)

# Train the tokenizer with Byte Pair Encoding (BPE)
files_paths = list(Path("/notebooks/classical-music-gen/midis_test").glob("**/*.midi"))
tokenizer.train(vocab_size=vocab_size, files_paths=files_paths)
tokenizer.save(Path("tokenizer", "tokenizer.json"))
# And pushing it to the Hugging Face hub (you can download it back with .from_pretrained)
tokenizer.push_to_hub("ABicGrill/miditok_tokenizer", private=True, token="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")

# Split MIDIs into smaller chunks for training
dataset_chunks_dir = Path("chunks")
split_files_for_training(
    files_paths=files_paths,
    tokenizer=tokenizer,
    save_dir=dataset_chunks_dir,
    max_seq_len=1024,
)

# Create a Dataset, a DataLoader and a collator to train a model
dataset = DatasetMIDI(
    files_paths=list(dataset_chunks_dir.glob("**/*.midi")),
    tokenizer=tokenizer,
    max_seq_len=block_size+1,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)
collator = DataCollator(tokenizer.pad_token_id, copy_inputs_as_labels=False)
dataloader = DataLoader(dataset, batch_size=bs, collate_fn=collator)

# Iterate over the dataloader to train a model
for batch in dataloader:
    print(batch)
    break

  super().__init__(tokenizer_config, params)







No files have been modified since last commit. Skipping to prevent empty commit.
Splitting music files (chunks): 100%|██████████| 5/5 [00:00<00:00, 10.01it/s]

{'input_ids': tensor([[313, 223, 265, 169, 254, 219, 170, 250, 350, 172, 242, 287, 174, 271,
          93, 112, 176],
        [285, 236, 217,  71, 355,  28, 217,  16, 210,  68, 215, 168, 273, 227,
         171, 302, 431]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]],
       dtype=torch.int32)}





In [11]:
n_embed = 32
class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.head_size = head_size
        self.k = nn.Linear(n_embed, head_size)
        self.q = nn.Linear(n_embed, head_size)
        self.v = nn.Linear(n_embed, head_size)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        
    def forward(self, x):
        B, T, C = x.shape
        k = self.k(x) # B, T, head_size
        q = self.q(x) # B, T, head_size
        v = self.k(x) # B, T, head_size
        
        wei = k @ q.transpose(-2, -1) * self.head_size ** -0.5
        wei = wei.masked_fill(self.tril[:T, :T]==0, float('-inf'))
        wei = wei.softmax(-1)
        out = wei @ v
        return out
        
class MusicModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embs = nn.Embedding(vocab_size, n_embed)
        self.pos_embs = nn.Embedding(block_size, n_embed)
        self.head = Head(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)
        
    def forward(self, x, targets=None, train=False):
        xb = x[:,:len(x[0])-1] if train else x
        if train:
            targets = x[:,1:len(x[0])]
        token_embs = self.token_embs(xb) # B, T, C
        pos_embs = self.pos_embs(torch.arange(len(x[0])-1)) if train else self.pos_embs(torch.arange(len(x[0]))) # T, C

        x = token_embs + pos_embs
        x = self.head(x)
        out= self.lm_head(x)
        if targets is not None:
            B, T, C = out.shape
            out = out.view(B * T, C)
            targets = targets.reshape(B*T)
            loss = F.cross_entropy(out, targets)
            return out, loss
            
        return out, None
    
    def generate(self, idx, max_tokens):
        for i in range(max_tokens):
            idx_cond = idx[:, -block_size:]
            out, loss = self(idx_cond)
            out = out[:,-1,:]
            probs = out.softmax(-1)
            preds = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, preds), dim=-1)
            
        return idx

In [12]:
model = MusicModel()

In [13]:
batch = next(iter(dataloader))
print(batch['input_ids'].shape)
model(batch['input_ids'], train=True)

torch.Size([2, 17])


(tensor([[ 0.0551, -0.4080,  0.4622,  ..., -0.5808,  1.1027, -0.1700],
         [ 0.0690, -0.2885,  0.4294,  ..., -0.4132,  0.8041,  0.0054],
         [ 0.2931, -0.1585,  0.2932,  ...,  0.0552,  0.1912,  0.2538],
         ...,
         [ 0.2675,  0.1726,  0.0611,  ...,  0.0640,  0.0796, -0.1878],
         [ 0.4842, -0.2243, -0.0027,  ..., -0.0271, -0.2844,  0.1868],
         [ 0.1485, -0.0588,  0.1964,  ..., -0.1663,  0.0137,  0.0048]],
        grad_fn=<ViewBackward0>),
 tensor(6.2268, grad_fn=<NllLossBackward0>))

In [14]:
opt = optim.AdamW(model.parameters(), lr=0.002)
n_epochs = 2
for epoch in range(n_epochs):
    for batch in dataloader:
        out, loss = model(batch['input_ids'], train=True)
        
        loss.backward()
        with torch.no_grad():
            opt.step()
            opt.zero_grad()
            
        # print(loss.item())
        
    print(loss)

tensor(6.0977, grad_fn=<NllLossBackward0>)
tensor(5.6310, grad_fn=<NllLossBackward0>)


In [15]:
tokenizer(model.generate(torch.ones((1, 1), dtype=torch.long), max_tokens=3000)[0]).dump_midi("bruh.mid")