In [1]:
import torch
from trainer import Trainer, TrainerConfig
from model import Samba, SambaConfig, VocabPart
import vocab

In [2]:
# Seed torch's global RNG before any model or trainer object is constructed.
torch.manual_seed(54384)

# One VocabPart per token field (pitch, velocity, advance, duration).
# NOTE(review): the second argument looks like a per-part weight/scale --
# confirm against model.VocabPart.
vocab_parts = [
    VocabPart(vocab.PITCH_COUNT, 1.0),
    VocabPart(vocab.VELOCITY_COUNT, 0.5),
    VocabPart(vocab.ADVANCE_COUNT, 1.0),
    VocabPart(vocab.DURATION_COUNT, 0.75),
]

# Architecture hyper-parameters for the Samba model.
model_config = SambaConfig(
    vocab_parts=vocab_parts,
    dropout_rate=0.0,
    embedded_size=768 // 2,  # evaluates to 384
    d_state=8,
    d_conv=4,
    head_count=6,
    layer_count=6,
    batch_size=32,
    device="cuda",
)
model = Samba(model_config)

# Optimisation schedule, data location and checkpoint settings.
trainer_config = TrainerConfig(
    iteration_count=6000,
    checkpoint_freq=500,
    max_learning_rate=3e-5,
    min_learning_rate=1e-7,
    weight_decay=0.0,
    grad_clip=1.0,
    grad_accum_steps=1,
    context_size=1024,
    data_folder='maestro-v3.0.0-midi',
    checkpoint_path="checkpoint.pth",
)

# Build the trainer and load a checkpoint, including the training state
# (load_training_state=True), before the training loop runs.
trainer = Trainer(model, trainer_config)
trainer.load_checkpoint(load_training_state=True)

In [3]:
from tqdm.notebook import tqdm

# Training loop.
#
# * Every `checkpoint_freq` iterations a validation step runs; when its loss
#   beats `trainer.best_validation_loss` (and this is not iteration 0) the new
#   best is recorded and a checkpoint is written.
# * After the loop finishes normally, one more checkpoint is written.
#
# The progress bar is used as a context manager so it is closed even when the
# cell is interrupted (e.g. KeyboardInterrupt); `initial=` starts the bar at
# the iteration the trainer is already on instead of a manual `update()`.
with tqdm(total=trainer.config.iteration_count,
          initial=trainer.iteration) as bar:
    while trainer.iteration < trainer.config.iteration_count:
        if trainer.iteration % trainer.config.checkpoint_freq == 0:
            loss = trainer.validate_step()

            # Iteration 0 is validated but never checkpointed.
            if trainer.iteration != 0 and loss < trainer.best_validation_loss:
                trainer.best_validation_loss = loss
                trainer.save_checkpoint()

        trainer.train_step()
        trainer.iteration += 1
        bar.update(1)

    # Reached only when the loop completes without an exception.
    trainer.save_checkpoint()

  0%|          | 0/6000 [00:00<?, ?it/s]

loss: {'train': 4.050461292266846, 'validate': 4.1725592613220215}
loss: {'train': 3.9428794384002686, 'validate': 4.090741157531738}
loss: {'train': 3.9370028972625732, 'validate': 4.055316925048828}
loss: {'train': 3.880692958831787, 'validate': 4.055731296539307}
loss: {'train': 3.907755136489868, 'validate': 4.02533483505249}
loss: {'train': 3.923783302307129, 'validate': 4.029364109039307}
loss: {'train': 3.8805956840515137, 'validate': 3.991508722305298}
loss: {'train': 3.8984646797180176, 'validate': 3.985015630722046}
loss: {'train': 3.894092082977295, 'validate': 3.991166591644287}


KeyboardInterrupt: 

In [4]:
def generate_sample(model: Samba, path: str, token_count: int = 1024 * 4):
    """Generate tokens from an all-zero seed and write them out as MIDI.

    Args:
        model: Model whose ``generate`` method produces the tokens. The seed
            tensor is allocated on ``model.config.device``.
        path: Output path handed to ``vocab.tokens_to_midi``.
        token_count: Forwarded to ``model.generate`` as ``new_token_count``.
    """
    # Seed context of shape (1, 1, 4), all zeros.
    # NOTE(review): presumably (batch, sequence, token fields) -- confirm.
    context = torch.zeros((1, 1, 4), dtype=torch.long,
                          device=model.config.device)
    tokens = model.generate(context, new_token_count=token_count)
    # Drop only the leading axis. A bare squeeze() would also collapse any
    # other size-1 axis, e.g. the sequence axis of a length-1 result.
    tokens = tokens.squeeze(0).cpu().numpy()
    vocab.tokens_to_midi(path, tokens)


def generate_sample_from_context(model: Samba, path: str, context_path: str, token_count: int = 1024 * 4):
    """Continue an existing MIDI file with the model and write the result as MIDI.

    Args:
        model: Model whose ``generate`` method produces the tokens. The context
            tensor is moved to ``model.config.device``.
        path: Output path handed to ``vocab.tokens_to_midi``.
        context_path: MIDI file converted to tokens via ``vocab.midi_to_tokens``
            and used as the generation context.
        token_count: Forwarded to ``model.generate`` as ``new_token_count``.
    """
    # NOTE(review): the meaning of the second argument (0) is not visible
    # here -- check vocab.midi_to_tokens.
    tokens = vocab.midi_to_tokens(context_path, 0)
    context = torch.from_numpy(tokens).to(
        dtype=torch.long, device=model.config.device)
    # Add a leading axis of size 1 (same result as reshape(1, *context.shape)).
    context = context.unsqueeze(0)
    tokens = model.generate(context, new_token_count=token_count)
    # Drop only the leading axis added above. A bare squeeze() would also
    # collapse any other size-1 axis, e.g. a length-1 sequence.
    tokens = tokens.squeeze(0).cpu().numpy()
    vocab.tokens_to_midi(path, tokens)

In [8]:

# Generate an unconditioned sample from the model and write it to
# "sample4.midi"; token_count is passed explicitly as 1024 * 4.
generate_sample(model, "sample4.midi", token_count=1024 * 4)