In [1]:
import torch
import torch.nn as nn
import pandas as pd
import ast
from typing import Optional
from dataclasses import dataclass
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence
from xformers.factory.model_factory import xFormer, xFormerConfig
from xformers.components.positional_embedding import (PositionEmbedding, PositionEmbeddingConfig, register_positional_embedding)
from transformers import AdamW, get_linear_schedule_with_warmup
from torch.utils.data import random_split
from torch.utils.tensorboard import SummaryWriter

A matching Triton is not available, some optimizations will not be enabled.
Error caught was: No module named 'triton'
Triton is not available, some optimizations will not be enabled.
Triton is not available, FusedMLP will not be enabled.
Either FairScale or torch distributed is not available, MixtureOfExperts will not be exposed. Please install them if you would like to use MoE


In [2]:
PAD_IDX = 128
BOS_IDX = 129
EOS_IDX = 130
PAD_VALUE = 0.0

NUM_ENCODER_LAYERS = 3
NUM_DECODER_LAYERS = 3
EMB_SIZE=128
MAX_LEN = 256
SRC_VOCAB_SIZE = 128+3 # 0-127 representing From C-1 to G9, 128 for PAD_IDX, 129 for BOS, 130 for EOS
TGT_VOCAB_SIZE = 128+3
NHEAD = 8
HIDDEN_LAYER_MULTIPLIER = 2
DROPOUT = 0.1

TRAIN_SPLIT = 0.9

DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # A GPU with memory >=8GB is capable of training
LEARNING_RATE = 0.0001
NUM_EPOCHS = 7
BATCH_SIZE = 128

LOAD_PRETRAINED = False
MODEL_SAVE_PATH = './Model/MidiGen.pth'

In [3]:
class MidiDataset(Dataset):
    def __init__(self, csv_path='./Data/MidiDataset.csv'):
        self.data = pd.read_csv(csv_path)

    def __len__(self):
        # return len(self.data) - 2
        return 30000
    
    def __getitem__(self, idx):
        """
        Returns two dictionaries with keys:
            'sentence', 'time_since_last_note', 'duration', 'velocity'
        """
        current_row = self.data.iloc[idx]
        next_row = self.data.iloc[idx+1]
        return (
            {
                'sentence': current_row['Sentence'],
                'time_since_last_note': current_row['TimeSinceLastNoteStart'],
                'duration': current_row['Duration'],
                'velocity': current_row['Velocity']
            },
            {
                'sentence': next_row['Sentence'],
                'time_since_last_note': next_row['TimeSinceLastNoteStart'],
                'duration': next_row['Duration'],
                'velocity': next_row['Velocity']
            }
        )


token_transform = lambda x: torch.cat((torch.tensor([BOS_IDX]), torch.tensor(ast.literal_eval(x)), torch.tensor([EOS_IDX])))
tensor_transform = lambda x: torch.cat((torch.tensor([PAD_VALUE]), torch.tensor(ast.literal_eval(x)), torch.tensor([PAD_VALUE])))

def collate_fn(batch):
    """
    Collate function to be used with a DataLoader.
    """

    sentences1 = [token_transform(item[0]['sentence']) for item in batch]
    time_since_last_note1 = [tensor_transform(item[0]['time_since_last_note']) for item in batch]
    duration1 = [tensor_transform(item[0]['duration']) for item in batch]
    velocity1 = [tensor_transform(item[0]['velocity'])/100 for item in batch]
    sentences2 = [token_transform(item[1]['sentence']) for item in batch]
    time_since_last_note2 = [tensor_transform(item[1]['time_since_last_note']) for item in batch]
    duration2 = [tensor_transform(item[1]['duration']) for item in batch]
    velocity2 = [tensor_transform(item[1]['velocity'])/100 for item in batch]

    # Pad them to the same length per batch
    sentences1 = pad_sequence(sentences1, batch_first=True, padding_value=PAD_IDX)
    time_since_last_note1 = pad_sequence(time_since_last_note1, batch_first=True, padding_value=PAD_VALUE).unsqueeze(-1)
    duration1 = pad_sequence(duration1, batch_first=True, padding_value=PAD_VALUE).unsqueeze(-1)
    velocity1 = pad_sequence(velocity1, batch_first=True, padding_value=PAD_VALUE).unsqueeze(-1)
    extra1 = torch.cat([time_since_last_note1, duration1, velocity1], dim=-1)

    sentences2 = pad_sequence(sentences2, batch_first=True, padding_value=PAD_IDX)
    time_since_last_note2 = pad_sequence(time_since_last_note2, batch_first=True, padding_value=PAD_VALUE).unsqueeze(-1)
    duration2 = pad_sequence(duration2, batch_first=True, padding_value=PAD_VALUE).unsqueeze(-1)
    velocity2 = pad_sequence(velocity2, batch_first=True, padding_value=PAD_VALUE).unsqueeze(-1)
    extra2 = torch.cat([time_since_last_note2, duration2, velocity2], dim=-1)

    return (sentences1, extra1), (sentences2, extra2)


In [4]:
dataset = MidiDataset()
train_split = int(len(dataset) * TRAIN_SPLIT)
train_dataset, val_dataset = random_split(dataset, [train_split, len(dataset) - train_split])
train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, collate_fn=collate_fn)

NUM_TRAINING_STEPS = len(train_dataset) // BATCH_SIZE * (NUM_EPOCHS + 2) # +2 to bias the learning rate decay
NUM_WARMUP_STEPS = NUM_TRAINING_STEPS // NUM_EPOCHS # Use 1 epoch for warmup

In [5]:
@dataclass
class MidiEmbeddingConfig(PositionEmbeddingConfig):
    pitch_size: int
    dropout: float

@register_positional_embedding("midi", MidiEmbeddingConfig)
class MidiEmbedding(PositionEmbedding):
    def __init__(self, dim_model: int, seq_len: int, pitch_size: int, dropout: float = 0.0,*args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dim_model = dim_model
        self.seq_len = seq_len
        self.pitch_size = pitch_size
        self.dropout = torch.nn.Dropout(p=dropout)

        self.position_embeddings = nn.Embedding(seq_len, self.dim_model)
        self.pitch_embeddings = nn.Embedding(self.pitch_size, self.dim_model - 3)

        self.position_ids: Optional[torch.Tensor] = None

    def init_weights(self, gain: float = 1.0):
        torch.nn.init.normal_(self.position_embeddings.weight, std=0.02 * gain)
        torch.nn.init.normal_(self.pitch_embeddings.weight, std=0.02 * gain)

    def forward(self, x: torch.Tensor):
        sentence = x[0]
        extra = x[1]

        position_ids = torch.arange(sentence.shape[1], dtype=torch.long, device=sentence.device)[
            None, :
        ].repeat(sentence.shape[0], 1)

        pitch_token = self.pitch_embeddings(sentence)
        
        x = torch.cat([pitch_token, extra], dim=-1)
        pos = self.position_embeddings(position_ids)

        X = x + pos
        X = self.dropout(X)

        return X


class MidiTransformer(nn.Module):
    def __init__(self, model_config) -> None:
        super().__init__()
        self.model_config = xFormerConfig(model_config)
        self.transformer = xFormer.from_config(self.model_config)
        self.generator = nn.Linear(
            model_config[1]['dim_model'], model_config[1]['position_encoding_config']['pitch_size'])
        self.extra_generator = nn.Sequential(
            nn.Linear(model_config[1]['dim_model'], model_config[1]['dim_model']*2), 
            nn.Linear(model_config[1]['dim_model']*2, 3))  # [time_since_last_note, duration, velocity]
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()

    def forward(self, src, tgt, src_mask=None, tgt_mask=None):
        memory = self.encode(src, src_mask)
        out = self.decode(tgt, memory, tgt_mask)
        return self.softmax(self.generator(out)), self.relu(self.extra_generator(out))
    
    def encode(self, src, src_mask=None):
        encoders = self.transformer.encoders
        memory = src[:]
        if isinstance(encoders, torch.nn.ModuleList):
            for encoder in encoders:
                memory = encoder(memory, input_mask=src_mask)
        else:
            if self.transformer.rev_enc_pose_encoding:
                memory = self.transformer.rev_enc_pose_encoding(src)

            # Reversible Encoder
            x = torch.cat([memory, memory], dim=-1)

            # Apply the optional input masking
            if src_mask is not None:
                if x.dim() - src_mask.dim() > 1:
                    src_mask.unsqueeze(0)
                x += src_mask.unsqueeze(-1)

            x = encoders(x)
            memory = torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
        return memory
    
    def decode(self, tgt, memory, tgt_mask=None):
        for decoder in self.transformer.decoders:
            tgt = decoder(target=tgt, memory=memory, input_mask=tgt_mask)
        return tgt


model_config = [
    {
        "reversible": True,  # Reversible encoder can save a lot memory when training
        "block_type": "encoder",
        "num_layers": NUM_ENCODER_LAYERS,
        "dim_model": EMB_SIZE,
        "residual_norm_style": "pre",
        "position_encoding_config": {
            "name": "midi",  # The vocab type position encoding includes token embedding layer and position encoding layer
            "seq_len": MAX_LEN,
            "pitch_size": SRC_VOCAB_SIZE,
        },
        "multi_head_config": {
            "num_heads": NHEAD,
            "residual_dropout": 0,
            "attention": {
                "name": "linformer",
                "dropout": 0,
                "causal": False,
                "seq_len": MAX_LEN,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": DROPOUT,
            "activation": "relu",
            "hidden_layer_multiplier": HIDDEN_LAYER_MULTIPLIER, # Hidden layer dimension is HIDDEN_LAYER_MULTIPLIER times dim_model
        },
    },
    {
        "reversible": False,
        "block_type": "decoder",
        "num_layers": NUM_DECODER_LAYERS,
        "dim_model": EMB_SIZE,
        "residual_norm_style": "pre",
        "position_encoding_config": {
            "name": "midi",
            "seq_len": MAX_LEN,
            "pitch_size": TGT_VOCAB_SIZE,
        },
        "multi_head_config_masked": {
            "num_heads": NHEAD,
            "residual_dropout": 0,
            "attention": {
                "name": "nystrom",
                "dropout": 0,
                "causal": True,  # Causal attention is used to prevent the decoder from attending the future tokens in the target sequences
                "seq_len": MAX_LEN,
            },
        },
        "multi_head_config_cross": {
            "num_heads": NHEAD,
            "residual_dropout": 0,
            "attention": {
                "name": "favor",
                "dropout": 0,
                "causal": False,
                "seq_len": MAX_LEN,
            },
        },
        "feedforward_config": {
            "name": "MLP",
            "dropout": DROPOUT,
            "activation": "relu",
            "hidden_layer_multiplier": HIDDEN_LAYER_MULTIPLIER,
        },
    },
]


In [11]:
model = MidiTransformer(model_config=model_config)
print(f'Num Params: {sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.2f} M', model)
model = model.to(DEVICE)

cross_entropy = nn.CrossEntropyLoss(ignore_index=PAD_IDX)
huber = nn.HuberLoss()
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=NUM_WARMUP_STEPS, num_training_steps=NUM_TRAINING_STEPS)

Num Params: 1.42 M MidiTransformer(
  (transformer): xFormer(
    (rev_enc_pose_encoding): MidiEmbedding(
      (dropout): Dropout(p=0.0, inplace=False)
      (position_embeddings): Embedding(320, 64)
      (pitch_embeddings): Embedding(131, 61)
    )
    (encoders): ReversibleSequence(
      (blocks): ModuleList(
        (0-7): 8 x ReversibleBlock(
          (f): Deterministic(
            (net): PreNorm(
              (norm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
              (sublayer): MultiHeadDispatch(
                (attention): LinformerAttention(
                  (E): Linear(in_features=320, out_features=80, bias=False)
                  (F): Linear(in_features=320, out_features=80, bias=False)
                  (attn_drop): Dropout(p=0, inplace=False)
                )
                (in_proj_container): InputProjection(
                  (q_proj): Linear(in_features=64, out_features=64, bias=True)
                  (k_proj): Linear(in_features=64, out_feat

In [7]:
if LOAD_PRETRAINED:
    model.load_state_dict(torch.load(MODEL_SAVE_PATH))

In [8]:
training_step = 0
put_tuple_to_device = lambda x: tuple(x.to(DEVICE) for x in x)
create_mask = lambda x: x != PAD_IDX
writer = SummaryWriter()

def train_epoch(model, optimizer):
    model.train()
    for src, tgt in tqdm(train_dataloader, desc="Train"):
        src = put_tuple_to_device(src)
        tgt = put_tuple_to_device(tgt)
        tgt_input = (tgt[0][:, :-1], tgt[1][:, :-1])
        src_mask, tgt_mask = create_mask(src[0]), create_mask(tgt_input[0])

        optimizer.zero_grad()
        logits, extra = model(src, tgt_input, src_mask, tgt_mask)
        tgt_out = tgt[0][:, 1:]
        
        loss_c = cross_entropy(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss_h = huber(tgt[1][:, 1:].reshape(-1), extra.reshape(-1))
        loss = loss_c + loss_h
        
        global training_step

        writer.add_scalar("Loss-Train/CrossEntropy", loss_c.item(), training_step)
        writer.add_scalar("Loss-Train/Huber", loss_h.item(), training_step)
        writer.add_scalar("Loss-Train/Total", loss.item(), training_step)
        
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 0.7)
        optimizer.step()
        scheduler.step()

        writer.add_scalar("LearningRate", optimizer.param_groups[0]["lr"], training_step)
        training_step += 1

def eval_epoch(model):
    model.eval()
    total_losses = 0
    total_losses_c = 0
    total_losses_h = 0
    total_steps = 0
    with torch.no_grad():
        for src, tgt in tqdm(val_dataloader, desc="Validation"):
            src = put_tuple_to_device(src)
            tgt = put_tuple_to_device(tgt)
            tgt_input = (tgt[0][:, :-1], tgt[1][:, :-1])
            src_mask, tgt_mask = create_mask(src[0]), create_mask(tgt_input[0])

            logits, extra = model(src, tgt_input, src_mask, tgt_mask)
            tgt_out = tgt[0][:, 1:]
            loss_c = cross_entropy(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            loss_h = huber(tgt[1][:, 1:].reshape(-1), extra.reshape(-1))
            loss = loss_c + loss_h
            total_losses += loss.item()
            total_losses_c += loss_c.item()
            total_losses_h += loss_h.item()
            total_steps += 1

        global training_step
        writer.add_scalar("Loss-Val/CrossEntropy", total_losses_c/total_steps, training_step)
        writer.add_scalar("Loss-Val/Huber", total_losses_h/total_steps, training_step)
        writer.add_scalar("Loss-Val/Total", total_losses/total_steps, training_step)
    return total_losses / total_steps


In [9]:
from timeit import default_timer as timer
for epoch in range(NUM_EPOCHS):
    start_time = timer()
    print("-" * 64)
    print("Start Epoch {}/{}".format(epoch + 1, NUM_EPOCHS))
    train_epoch(model, optimizer)
    end_time = timer()
    writer.flush()
    eval_loss = eval_epoch(model)
    writer.flush()
    torch.save(model.state_dict(), MODEL_SAVE_PATH)
    print(f"End Epoch {epoch + 1}/{NUM_EPOCHS} in {end_time - start_time:.2f}s with eval loss {eval_loss:.2f}")
    print("-" * 64)

----------------------------------------------------------------
Start Epoch 1/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 1/7 in 86.88s with eval loss 4.87
----------------------------------------------------------------
----------------------------------------------------------------
Start Epoch 2/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 2/7 in 87.91s with eval loss 4.84
----------------------------------------------------------------
----------------------------------------------------------------
Start Epoch 3/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 3/7 in 76.59s with eval loss 4.82
----------------------------------------------------------------
----------------------------------------------------------------
Start Epoch 4/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 4/7 in 76.83s with eval loss 4.81
----------------------------------------------------------------
----------------------------------------------------------------
Start Epoch 5/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 5/7 in 76.77s with eval loss 4.81
----------------------------------------------------------------
----------------------------------------------------------------
Start Epoch 6/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 6/7 in 77.21s with eval loss 4.80
----------------------------------------------------------------
----------------------------------------------------------------
Start Epoch 7/7


Train:   0%|          | 0/211 [00:00<?, ?it/s]

Validation:   0%|          | 0/24 [00:00<?, ?it/s]

End Epoch 7/7 in 76.80s with eval loss 4.80
----------------------------------------------------------------


In [12]:
# def greedy_decode(model, src, src_mask=None, max_Len=MAX_LEN, start_symbol=BOS_IDX):
#     src = put_tuple_to_device(src)
#     memory = model.encode(src, src_mask)
#     print(f'Encoder first encode src to memory {memory.shape}')
#     ys_token = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
#     print(f'We init the temp token output with start symbol {start_symbol}| {ys_token.shape} | {ys_token}')
#     ys_extra = torch.ones(1, 1, 3).fill_(PAD_VALUE).type(torch.long).to(DEVICE)
#     print(f'We init the temp extra output with pad value {PAD_VALUE}| {ys_extra.shape} | {ys_extra}')
#     ys = (ys_token, ys_extra)
#     print(f'Then we combine them back to tuple {ys}')
#     for i in range(max_Len-1):
#         print('-'*64)
#         print(f'Decoder step {i}')
#         out = model.decode(ys, memory)
#         print(f'Decoder output {out.shape}, we then use the [:, -1, :] to get prob of next token\nout:{out}')
#         prob = model.generator(out[:, -1, :])
#         print(f'Generator output prob {prob.shape}\nprob:{prob}')
#         extra = model.extra_generator(out[:, -1, :]).unsqueeze(0)
#         print(f'Extra generator output {extra.shape}')
#         _, next_token = torch.max(prob, dim=1)
#         print(f'Next token {next_token}')
#         next_token = next_token.item()
#         ys_token = torch.cat([ys_token, torch.ones(1, 1).type_as(src[0].data).fill_(next_token)], dim=1)
#         print(f'We add next token to the temp token output to form {ys_token.shape}|{ys_token}')
#         ys_extra = torch.cat([ys_extra, extra], dim=1)
#         print(f'We add extra to the temp extra output to form {ys_extra.shape}|{ys_extra}')
#         ys = (ys_token, ys_extra)
#         print(f'Finally combine them and go to next loop {ys}')
#         if next_token == EOS_IDX:
#             break
#     return ys

def greedy_decode(model, src, src_mask=None, max_len=MAX_LEN, start_symbol=BOS_IDX):
    src = put_tuple_to_device(src)
    memory = model.encode(src, src_mask)
    ys_token = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    ys_extra = torch.ones(1, 1, 3).fill_(PAD_VALUE).type(torch.long).to(DEVICE)
    ys = (ys_token, ys_extra)
    for i in range(max_len-1):
        out = model.decode(ys, memory)
        prob = model.generator(out[:, -1, :])
        extra = model.extra_generator(out[:, -1, :]).unsqueeze(0)
        _, next_token = torch.max(prob, dim=1)
        next_token = next_token.item()
        ys_token = torch.cat([ys_token, torch.ones(1, 1).type_as(src[0].data).fill_(next_token)], dim=1)
        ys_extra = torch.cat([ys_extra, extra], dim=1)
        ys = (ys_token, ys_extra)
        if next_token == EOS_IDX:
            break
    return ys


def continue_writing(model, src):
    model.eval()
    with torch.no_grad():
        mask = create_mask(src[0]).to(DEVICE)
        result = greedy_decode(model, src, mask)
    return result

src, trg = next(iter(val_dataloader))
src_test = (src[0][0][:17].unsqueeze(0), src[1][0][:17].unsqueeze(0))
print(f'Input should be a tuple contains tensor of shape {src_test[0].shape, src_test[1].shape}')
print(f'Such as {src_test}')
continue_writing(model, src_test)

Input should be a tuple contains tensor of shape (torch.Size([1, 17]), torch.Size([1, 17, 3]))
Such as (tensor([[129,  69,  74,  57,  42,  74,  69,  50,  57,  67,  50,  62,  50,  66,
         130, 128, 128]]), tensor([[[0.0000, 0.0000, 0.0000],
         [0.2219, 0.0698, 0.7800],
         [0.2615, 0.3594, 0.8600],
         [0.0250, 0.0354, 0.5100],
         [0.2427, 0.5250, 0.5600],
         [0.1917, 1.4031, 0.8400],
         [0.0229, 0.6635, 0.5600],
         [0.2990, 0.0771, 0.4900],
         [0.1479, 0.2531, 0.5200],
         [0.2167, 0.4719, 0.6800],
         [0.0063, 0.1365, 0.4300],
         [0.2198, 0.0542, 0.4200],
         [0.1688, 0.0479, 0.4300],
         [0.0229, 0.2531, 0.7700],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000]]]))
Encoder first encode src to memory torch.Size([1, 17, 64])
We init the temp token output with start symbol 129| torch.Size([1, 1]) | tensor([[129]], device='cuda:0')
We init the temp extra out

(tensor([[129, 106, 106,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,  95,
           95,  95,  95,  95,  95,  95, 