In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from mamba_ssm.models.mixer_seq_simple import MambaConfig, MambaLMHeadModel
import lightning as L
from lightning.pytorch.loggers import TensorBoardLogger
import torch.nn.functional as F
from miditok import REMI, TokenizerConfig
from miditok.pytorch_data import DatasetTok, DataCollator
from pathlib import Path
from symusic import Score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Vocabulary size passed to MambaConfig below; also the target size used by the
# (currently commented-out) tokenizer.learn_bpe call.
VOCAB_SIZE = 5_000

In [3]:
# Multitrack REMI tokenizer configuration; read the miditok docs to explore other parameters.
# Bound to `tokenizer_config` (not `config`) because a later cell binds `config`
# to the MambaConfig, and one name should not mean two different things.
tokenizer_config = TokenizerConfig(num_velocities=16, use_chords=True, use_programs=True)
tokenizer = REMI(tokenizer_config)
# To load previously saved tokenizer parameters instead:
# tokenizer = REMI(params='./tokenizer.json')

# MIDI files used for the dataset (and for the optional BPE training below).
# Path.glob yields entries in arbitrary order, so sort them to keep the dataset
# order identical from one run to the next.
midi_paths = sorted(Path("../data/maestro-v3.0.0/2018").glob("**/*.midi"))
# Optional: train the tokenizer with BPE and save it to load it back later.
# tokenizer.learn_bpe(vocab_size=VOCAB_SIZE, files_paths=midi_paths)
# tokenizer.save_params(Path("tokenizer.json"))

# Dataset and collator to be used with a PyTorch DataLoader to train a model.
dataset = DatasetTok(
    files_paths=midi_paths,
    min_seq_len=100,
    max_seq_len=1024,
    tokenizer=tokenizer,
)
collator = DataCollator(
    tokenizer["PAD_None"], tokenizer["BOS_None"], tokenizer["EOS_None"]
)

Loading data: ../data/maestro-v3.0.0/2018: 100%|██████████| 93/93 [00:20<00:00,  4.47it/s]


In [16]:
# Configuration of the Mamba language model: width 384, 2 layers,
# vocabulary of VOCAB_SIZE entries.
config = MambaConfig(
    d_model=384,
    n_layer=2,
    vocab_size=VOCAB_SIZE,
)

class MambaModel(nn.Module):
    """Thin wrapper around ``MambaLMHeadModel`` whose forward pass returns only the logits.

    Args:
        config: Configuration object handed to ``MambaLMHeadModel``.
        device: Device the wrapped model is moved to on construction.
            Defaults to ``'cuda'``, the value that was previously hard-coded.
    """

    def __init__(self, config, device='cuda') -> None:
        super().__init__()
        self.config = config
        self.model = MambaLMHeadModel(self.config).to(device)

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the ``logits`` attribute of its output."""
        return self.model(x).logits

In [17]:
# Build the network from the MambaConfig defined above.
mamba = MambaModel(config)
lr = 3e-4  # learning rate

# Last expression of the cell: display the module tree.
mamba

MambaModel(
  (model): MambaLMHeadModel(
    (backbone): MixerModel(
      (embedding): Embedding(5000, 384)
      (layers): ModuleList(
        (0-1): 2 x Block(
          (mixer): Mamba(
            (in_proj): Linear(in_features=384, out_features=1536, bias=False)
            (conv1d): Conv1d(768, 768, kernel_size=(4,), stride=(1,), padding=(3,), groups=768)
            (act): SiLU()
            (x_proj): Linear(in_features=768, out_features=56, bias=False)
            (dt_proj): Linear(in_features=24, out_features=768, bias=True)
            (out_proj): Linear(in_features=768, out_features=384, bias=False)
          )
          (norm): RMSNorm()
        )
      )
      (norm_f): RMSNorm()
    )
    (lm_head): Linear(in_features=384, out_features=5000, bias=False)
  )
)

In [18]:
class LitMamba(L.LightningModule):
    """Lightning wrapper that trains ``model`` with a next-token prediction loss.

    Args:
        model: Module mapping token ids to logits; ``training_step`` treats its
            output as BATCH x SEQ x VOCAB.
        lr: Learning rate for AdamW. Defaults to ``3e-4``, the value that was
            previously hard-coded.
        ignore_index: Target id excluded from the loss. Defaults to ``-100``,
            which is ``F.cross_entropy``'s own default, so behaviour is unchanged
            unless another value (e.g. the padding token id) is passed.
    """

    def __init__(self, model, lr=3e-4, ignore_index=-100):
        super().__init__()
        self.model = model
        self.lr = lr
        self.ignore_index = ignore_index

    def training_step(self, batch, batch_idx):
        """Compute the next-token cross-entropy loss for one batch and log it as ``'Loss'``."""
        tokens = batch['input_ids']
        # Shift by one: the input at position t is scored against the token at t + 1.
        x, y_true = tokens[:, :-1], tokens[:, 1:]

        y_pred = self.model(x)  # -> BATCH x SEQ x VOCAB

        # Flatten batch and sequence dimensions, as cross_entropy expects (N, C) vs (N,).
        y_pred = y_pred.reshape(-1, y_pred.shape[-1])
        y_true = y_true.reshape(-1)

        loss = F.cross_entropy(y_pred, y_true, ignore_index=self.ignore_index)

        # Log through Lightning instead of reaching into self.logger.experiment:
        # this does not assume a TensorBoard logger is attached to the Trainer.
        self.log('Loss', loss)

        return loss

    def configure_optimizers(self):
        """Return an AdamW optimizer over this module's own parameters."""
        # Use self.parameters() rather than the notebook-level `mamba` variable, so the
        # optimizer always updates the model that was actually passed to __init__.
        return torch.optim.AdamW(self.parameters(), lr=self.lr)

In [19]:
# Training loader: batches of 4 items assembled by `collator`, using 11 worker processes.
data_loader = DataLoader(
    dataset,
    batch_size=4,
    collate_fn=collator,
    num_workers=11,
)

# Lightning module wrapping the Mamba network.
m = LitMamba(mamba)

In [20]:
# TensorBoard logger rooted at "tb_logs" with experiment name "my_model".
logger = TensorBoardLogger(save_dir="tb_logs", name="my_model")

# Train for 5 epochs on a single device, reporting to the logger above.
trainer = L.Trainer(
    logger=logger,
    devices=1,
    max_epochs=5,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
# Run the training loop on the Lightning module with the training loader.
trainer.fit(m, train_dataloaders=data_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type       | Params
-------------------------------------
0 | model | MambaModel | 3.8 M 
-------------------------------------
3.8 M     Trainable params
0         Non-trainable params
3.8 M     Total params
15.395    Total estimated model params size (MB)


Epoch 4: 100%|██████████| 1001/1001 [00:30<00:00, 33.36it/s, v_num=6]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 1001/1001 [00:30<00:00, 33.21it/s, v_num=6]


In [22]:
from mamba_ssm.utils.generation import decode

In [69]:
# Fetch only the first batch; next(iter(...)) avoids the for/break pattern and
# does not leave a stray loop variable in the notebook namespace.
first_batch = next(iter(data_loader))
# Prompt: the first 80 tokens of the second sequence in the batch (batch dimension kept).
sample = first_batch['input_ids'][1:2, :80]

# Generate from the prompt with the trained model (max_length=1000).
test_music = decode(sample.to('cuda'), model=mamba.model.to('cuda'), max_length=1000)

In [70]:
# Display the tensor of token ids returned by decode.
test_music.sequences

tensor([[  1, 102, 124, 281,  37, 100, 129, 188, 281,  30, 100, 111, 194, 281,
          46,  99, 116, 281,  54, 103, 117, 201, 281,  40, 100, 111,   4, 187,
         281,  36,  97, 126, 190, 281,  45,  98, 115, 281,  54,  98, 140, 195,
         281,  28,  97, 110, 200, 281,  47, 100, 118, 201, 281,  45,  98, 113,
           4, 173, 281,  40,  98, 112, 176, 281,  48, 101, 125, 281,  52,  99,
         119, 177, 281,  36,  98, 120, 181, 281,  28,  96, 112, 185, 281,  40,
          98, 110, 189, 281,  40,  98, 110, 192, 281,  40,  98, 110, 195, 281,
          40,  98, 110, 197, 281,  40,  98, 110, 199, 281,  40,  98, 110, 201,
         281,  40,  98, 110, 204, 281,  40,  98, 110,   4, 175, 281,  40,  98,
         110, 177, 281,  40,  98, 110, 179, 281,  40,  98, 110, 181, 281,  40,
          98, 110, 183, 281,  40,  98, 110, 185, 281,  40,  98, 110, 187, 281,
          40,  98, 110, 189, 281,  40,  98, 110, 191, 281,  40,  98, 110, 193,
         281,  40,  98, 110, 195, 281,  40,  98, 110

In [71]:
# Flatten the generated ids into a plain Python list, convert them back to a
# score with the tokenizer, and write the result to disk.
generated_ids = test_music.sequences.reshape(-1).tolist()
generated_score = tokenizer.tokens_to_midi(generated_ids)
generated_score.dump_midi('xd.midi')