In [None]:
import torch
import torch.nn as nn
import pandas as pd
import ast
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence, pack_padded_sequence
from tqdm.notebook import tqdm

In [None]:
# --- Padding ---
PAD_IDX = 0        # padding index used for pitch sequences in collate_fn
PAD_VALUE = 0.0    # padding value used for the continuous per-note features
# NOTE(review): PAD_IDX = 0 is also a valid index of the 128-entry pitch embeddings,
# so padded steps are indistinguishable from real pitch-0 notes — consider a dedicated pad index.

# --- Shared embedding sizes ---
PITCH_EMB = 48     # pitch embedding width
EXTRA_DIM = 3      # continuous features per note: time since last note start, duration, velocity
EXTRA_EMB = 16     # embedding width of those continuous features
DROPOUT = 0.2      # dropout inside the discriminator's ResMLP blocks

# --- Generator ---
HIDDEN_SIZE_GEN = 256
NUM_LAYERS_ENCODER_GEN = 2
NUM_LAYERS_GENERATOR_GEN = 2

# --- Discriminator ---
HIDDEN_SIZE_DIS = 64
NUM_LAYERS_ENCODER_DIS = 1
HIDDEN_LAYERS_DIS = 1  # number of ResMLP blocks after the discriminator's encoder

# --- Reproducibility ---
# Seed before any model is built so weight init, generator noise and DataLoader
# shuffling are repeatable under Restart & Run All.
SEED = 42
torch.manual_seed(SEED)  # seeds torch's default generators on all devices

# --- Training ---
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
LEARNING_RATE_GEN = 1e-3
LEARNING_RATE_DIS = 1e-4
NUM_EPOCHS = 5
BATCH_SIZE = 128
GEN_SAVE_PATH = './Model/MidiGenLSTM.pth'
DIS_SAVE_PATH = './Model/MidiDisLSTM.pth'

In [None]:
class ResMLP(nn.Module):
    """Residual two-layer MLP followed by LayerNorm.

    Computes ``LayerNorm(Dropout(linear2(ReLU(linear1(x)))) + linear0(x))``.
    The skip path is a learned projection (``linear0``), so ``in_dim`` may differ
    from ``out_dim``. Only the last dimension of ``x`` is transformed.

    Args:
        in_dim: size of the input's last dimension.
        hidden_dim: width of the MLP branch's hidden layer.
        out_dim: size of the output's last dimension.
        dropout: dropout probability applied to the MLP branch only.
    """

    def __init__(self, in_dim, hidden_dim, out_dim, dropout=0.1):
        super().__init__()
        # Submodules are registered in this order on purpose: it fixes the RNG draws
        # used for weight initialisation and the state_dict key order of checkpoints.
        self.linear0 = nn.Linear(in_dim, out_dim)  # skip projection

        self.linear1 = nn.Linear(in_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, out_dim)
        self.dropout = nn.Dropout(dropout)

        self.norm = nn.LayerNorm(out_dim)

    def forward(self, x):
        skip = self.linear0(x)
        branch = self.dropout(self.linear2(self.relu(self.linear1(x))))
        return self.norm(branch + skip)

class MidiGenLSTM(nn.Module):
    """LSTM encoder/generator producing a new MIDI phrase from an input phrase.

    The encoder LSTM summarises the input phrase into one memory vector (mean of its
    outputs over time). The generator LSTM is fed ``len_latent`` steps that blend that
    memory with Gaussian noise. Two heads map every generated step to a softmax
    distribution over 128 pitch classes and to ``extra_dim`` continuous features.

    All tensors are time-major (the LSTMs are built without ``batch_first``):
    ``pitch`` is (seq, batch) integer indices, ``extra`` is (seq, batch, extra_dim).
    """

    def __init__(self, pitch_emb = 48, extra_dim = 3, extra_emb = 16, hidden_size = 256, num_layers_encoder = 4, num_layers_generator = 4):
        super().__init__()
        self.pitch_emb = pitch_emb
        self.extra_emb = extra_emb
        self.hidden_size = hidden_size
        self.num_layers_encoder = num_layers_encoder
        self.num_layers_generator = num_layers_generator

        num_pitches = 128  # size of the pitch vocabulary (indices 0..127)

        # Submodules are created in this exact order: it fixes the RNG draws used for
        # weight initialisation and the state_dict layout of saved checkpoints.
        self.pitch_embedding = nn.Embedding(num_embeddings=num_pitches, embedding_dim=pitch_emb)
        self.extra_embedding = ResMLP(in_dim=extra_dim, hidden_dim=extra_emb, out_dim=extra_emb, dropout=0.0)
        self.encoder = nn.LSTM(input_size=pitch_emb+extra_emb, hidden_size=hidden_size, num_layers=num_layers_encoder)

        self.generator = nn.LSTM(input_size=hidden_size, hidden_size=hidden_size, num_layers=num_layers_generator)
        # Keep this Sequential layout: checkpoints store generator_pitch.0.* and generator_pitch.2.*
        self.generator_pitch = nn.Sequential(
            nn.Linear(in_features=hidden_size, out_features=hidden_size),
            nn.LeakyReLU(),
            nn.Linear(in_features=hidden_size, out_features=num_pitches),
            nn.Softmax(dim=-1),
        )
        self.generator_extra = ResMLP(in_dim=hidden_size, hidden_dim=hidden_size, out_dim=extra_dim)

    def encode(self, pitch:torch.Tensor, extra:torch.Tensor)->torch.Tensor:
        """Encode a phrase into a (batch, hidden_size) memory vector.

        NOTE(review): no mask is applied, so any padded time steps in the batch are
        included in the temporal mean.
        """
        pitch_emb = self.pitch_embedding(pitch)
        extra_emb = self.extra_embedding(extra)
        memory, _ = self.encoder(torch.cat((pitch_emb, extra_emb), dim=-1))
        return torch.mean(memory, dim=0)

    def generate(self, memory:torch.Tensor, len_latent=16, batch_size=8, memory_factor=0.5)->torch.Tensor:
        """Run the generator LSTM over a blend of ``memory`` and Gaussian noise.

        Args:
            memory: encoder summary that must broadcast against
                (len_latent, batch_size, hidden_size), e.g. the (batch, hidden_size)
                output of ``encode``.
            len_latent: number of time steps to generate.
            batch_size: batch size of the sampled noise; should match ``memory``.
            memory_factor: weight of ``memory`` versus noise, clamped to [0, 1]
                (1 -> memory only, 0 -> noise only).

        Returns:
            Generator LSTM outputs of shape (len_latent, batch_size, hidden_size).
        """
        latent = torch.randn(len_latent, batch_size, self.hidden_size, device=memory.device)
        memory_factor = max(0, min(1, memory_factor)) # Clamp memory factor to 0-1
        x = memory_factor * memory + (1-memory_factor) * latent
        out, _ = self.generator(x)
        return out

    def forward(self, pitch:torch.Tensor, extra:torch.Tensor, len_latent=16, memory_factor=0.5)->"tuple[torch.Tensor, torch.Tensor]":
        """Generate ``len_latent`` steps conditioned on the input phrase.

        Returns:
            ``(pitch_probs, extra)``: pitch_probs is (len_latent, batch, 128) softmax
            probabilities; extra is (len_latent, batch, extra_dim).
        """
        memory = self.encode(pitch, extra)
        y = self.generate(memory, len_latent, batch_size=pitch.shape[1], memory_factor=memory_factor)
        # Evaluation order (pitch head, then extra head) matches the original RNG usage.
        pitch_probs = self.generator_pitch(y)
        extra_out = self.generator_extra(y)
        return pitch_probs, extra_out
    
class MidiDiscriminator(nn.Module):
    """LSTM discriminator that scores a MIDI phrase with a probability in (0, 1).

    It embeds pitches and continuous features and encodes them with an LSTM. It then
    batch-normalises the hidden features and averages them over time. The pooled
    vector goes through ``hidden_layers`` ResMLP blocks and a linear+sigmoid head,
    giving an output of shape (batch, 1).
    Inputs are time-major: ``pitch`` (seq, batch), ``extra`` (seq, batch, extra_dim).
    """

    def __init__(self, pitch_emb = 48, extra_dim = 3, extra_emb = 16, hidden_size = 256, num_layers_encoder = 2, dropout=0.2, hidden_layers=2):
        super().__init__()
        # Creation order is kept stable so weight initialisation and state_dict keys do not change.
        self.pitch_embedding = nn.Embedding(num_embeddings=128, embedding_dim=pitch_emb)
        self.extra_embedding = ResMLP(in_dim=extra_dim, hidden_dim=extra_emb, out_dim=extra_emb, dropout=0.0)
        self.encoder = nn.LSTM(input_size=pitch_emb+extra_emb, hidden_size=hidden_size, num_layers=num_layers_encoder)
        self.bn1 = nn.BatchNorm1d(num_features=hidden_size)

        blocks = []
        for _ in range(hidden_layers):
            blocks.append(ResMLP(in_dim=hidden_size, hidden_dim=hidden_size, out_dim=hidden_size, dropout=dropout))
        self.res_mlp = nn.Sequential(*blocks)
        self.classifier = nn.Linear(hidden_size, 1)
        self.sigmoid = nn.Sigmoid()

    def forward(self, pitch:torch.Tensor, extra:torch.Tensor)->torch.Tensor:
        features = torch.cat((self.pitch_embedding(pitch), self.extra_embedding(extra)), dim=-1)
        encoded, _ = self.encoder(features)
        # BatchNorm1d normalises channels of an (N, C, L) tensor: go (seq, batch, hidden)
        # -> (batch, hidden, seq) for the norm, then back to (seq, batch, hidden).
        encoded = self.bn1(encoded.permute(1, 2, 0)).permute(2, 0, 1)
        pooled = encoded.mean(dim=0)
        logits = self.classifier(self.res_mlp(pooled))
        return self.sigmoid(logits)

In [None]:
def _print_model_summary(model: nn.Module, label: str) -> None:
    """Print a model's structure, its trainable-parameter total (millions) and per-tensor counts (thousands)."""
    print(model)
    trainable = [(name, param) for name, param in model.named_parameters() if param.requires_grad]
    print(f'{label} Num Params: {sum(param.numel() for _, param in trainable)/1e6:.2f} M')
    for name, param in trainable:
        print(f'{name}: {param.numel()/1e3:.1f}K')

# The generator is built before the discriminator so both draw their initial weights from the RNG in the same order as before.
generator = MidiGenLSTM(pitch_emb=PITCH_EMB, extra_dim=EXTRA_DIM, extra_emb=EXTRA_EMB, hidden_size=HIDDEN_SIZE_GEN, num_layers_encoder=NUM_LAYERS_ENCODER_GEN, num_layers_generator=NUM_LAYERS_GENERATOR_GEN)
generator.to(DEVICE)
_print_model_summary(generator, 'Generator')

print('-'*64)

discriminator = MidiDiscriminator(pitch_emb=PITCH_EMB, extra_dim=EXTRA_DIM, extra_emb=EXTRA_EMB, hidden_size=HIDDEN_SIZE_DIS, num_layers_encoder=NUM_LAYERS_ENCODER_DIS, dropout=DROPOUT, hidden_layers=HIDDEN_LAYERS_DIS)
discriminator.to(DEVICE)
_print_model_summary(discriminator, 'Discriminator')

In [None]:
class MidiDataset(Dataset):
    """(input phrase, next phrase) pairs built from consecutive CSV rows.

    Item ``idx`` pairs row ``idx`` (input) with row ``idx + 1`` (target). Each row
    stores Python-literal lists in the columns 'Sentence' (pitch indices) and
    'TimeSinceLastNoteStart', 'Duration', 'Velocity' (continuous per-note features),
    which are parsed with ``ast.literal_eval``.

    Each item is ``(pitch_in, extra_in, pitch_tgt, extra_tgt)``. The pitch tensors are
    1-D long tensors, and the extra tensors are float32 with shape (notes, 3).
    """

    # Continuous feature columns, stacked in this order into the last dimension.
    EXTRA_COLUMNS = ('TimeSinceLastNoteStart', 'Duration', 'Velocity')

    def __init__(self, csv_path='./Data/MidiDataset.csv'):
        self.data = pd.read_csv(csv_path)

    def __len__(self):
        # Rows 0..n-2 each have a successor, giving n-1 (input, target) pairs.
        # Clamp at 0 so an empty or single-row CSV yields an empty dataset rather
        # than a negative length (which makes len() raise).
        return max(len(self.data) - 1, 0)

    def _row_to_tensors(self, row):
        """Parse one CSV row into (pitch indices, (notes, 3) float32 feature matrix)."""
        pitches = torch.tensor(ast.literal_eval(row['Sentence']), dtype=torch.long)
        extras = torch.tensor(
            [ast.literal_eval(row[column]) for column in self.EXTRA_COLUMNS],
            dtype=torch.float32,
        ).transpose(0, 1)  # (3, notes) -> (notes, 3)
        return pitches, extras

    def __getitem__(self, idx):
        # Positional access (iloc) works whatever index labels the CSV produced.
        pitch_in, extra_in = self._row_to_tensors(self.data.iloc[idx])
        pitch_tgt, extra_tgt = self._row_to_tensors(self.data.iloc[idx + 1])
        return pitch_in, extra_in, pitch_tgt, extra_tgt


def collate_fn(batch):
    """Pad a list of MidiDataset items into time-major batch tensors.

    Each item is ``(pitch_in, extra_in, pitch_tgt, extra_tgt)``. Pitch sequences are
    padded with ``PAD_IDX`` and feature sequences with ``PAD_VALUE``. ``pad_sequence``
    (with its default ``batch_first=False``) puts time first: (max_len, batch[, features]).
    """
    pitch_in, extra_in, pitch_tgt, extra_tgt = zip(*batch)

    def pad_pitch(sequences):
        return pad_sequence(sequences, padding_value=PAD_IDX)

    def pad_extra(sequences):
        return pad_sequence(sequences, padding_value=PAD_VALUE)

    return pad_pitch(pitch_in), pad_extra(extra_in), pad_pitch(pitch_tgt), pad_extra(extra_tgt)

# Items are (input phrase, next phrase) pairs; collate_fn pads each batch to its longest sequence.
dataset = MidiDataset()
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
)

In [None]:
# The discriminator ends in a sigmoid, so it outputs probabilities: plain BCELoss matches that output.
criterion = nn.BCELoss()
optimizerD = torch.optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE_DIS)
optimizerG = torch.optim.Adam(params=generator.parameters(), lr=LEARNING_RATE_GEN)

def train_batch(sentence_input, extra_input, sentence_tgt, extra_tgt, optimizerD, optimizerG):
    """Run one adversarial step: update the discriminator, then the generator.

    Uses the notebook-level ``generator``, ``discriminator`` and ``criterion``. The
    ``*_input`` tensors condition the generator. The ``*_tgt`` tensors are the real
    next phrases shown to the discriminator. All tensors are time-major batches from
    ``collate_fn``.

    Returns:
        ``(errD, errG, D_x, D_G_z1, D_G_z2)`` as Python floats. These are the
        discriminator loss (real + fake) and the generator loss, followed by the mean
        discriminator output on real data. The last two are the mean output on
        generated data before and after the discriminator step.
    """
    batch_size = sentence_tgt.shape[1]
    real_target = torch.ones(batch_size, dtype=torch.float, device=DEVICE)
    fake_target = torch.zeros(batch_size, dtype=torch.float, device=DEVICE)

    # --- Discriminator on real phrases (target 1) ---
    discriminator.zero_grad()
    d_real = discriminator(sentence_tgt, extra_tgt).view(-1)
    loss_d_real = criterion(d_real, real_target)
    loss_d_real.backward()
    D_x = d_real.mean().item()

    # Weight of the encoded input vs. noise; sampled from N(0.5, 0.25), clamped to [0, 1] by the generator.
    memory_factor = torch.normal(0.5, 0.25, size=(1,)).item()
    pitch_probs, fake_extra = generator(sentence_input, extra_input, len_latent=sentence_tgt.shape[0], memory_factor=memory_factor)
    # NOTE(review): argmax is not differentiable, so errG never reaches generator.generator_pitch.
    # Only the continuous `extra` head (and the LSTMs behind it) learn adversarially. A
    # Gumbel-softmax / soft-embedding path would be needed to train the pitch head.
    fake_pitch = pitch_probs.argmax(dim=-1)

    # --- Discriminator on generated phrases (target 0); detached so G gets no gradient ---
    d_fake = discriminator(fake_pitch.detach(), fake_extra.detach()).view(-1)
    loss_d_fake = criterion(d_fake, fake_target)
    loss_d_fake.backward()
    D_G_z1 = d_fake.mean().item()
    errD = loss_d_real + loss_d_fake
    optimizerD.step()

    # --- Generator: make D score generated phrases as real (target 1) ---
    # errG.backward() also fills discriminator grads; they are cleared by
    # discriminator.zero_grad() at the start of the next call.
    generator.zero_grad()
    d_fake_for_g = discriminator(fake_pitch, fake_extra).view(-1)
    errG = criterion(d_fake_for_g, real_target)
    errG.backward()
    D_G_z2 = d_fake_for_g.mean().item()
    optimizerG.step()

    return errD.item(), errG.item(), D_x, D_G_z1, D_G_z2

def train(epochs, writer:SummaryWriter, gdt=0):
    """Train the GAN for ``epochs`` passes over ``dataloader``.

    Logs per-batch scalars (D_x, D_G_z1, D_G_z2, errD, errG) to ``writer`` at the
    global step ``gdt``, which is incremented once per batch. Both models'
    state_dicts are checkpointed to GEN_SAVE_PATH / DIS_SAVE_PATH after every epoch.

    Args:
        epochs: number of passes over the dataloader.
        writer: TensorBoard writer receiving the scalars.
        gdt: starting global step, so a resumed run continues the same curves.

    Returns:
        The global step after the last batch; pass it back as ``gdt`` to resume.
    """
    import os

    # torch.save does not create parent directories. Create them now so the first
    # checkpoint does not fail after a full epoch of training.
    for save_path in (GEN_SAVE_PATH, DIS_SAVE_PATH):
        os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

    discriminator.train()
    generator.train()
    for epoch in tqdm(range(epochs)):
        for sentence_input, extra_input, sentence_tgt, extra_tgt in tqdm(dataloader, desc=f'Epoch {epoch}'):
            sentence_input, extra_input, sentence_tgt, extra_tgt = sentence_input.to(DEVICE), extra_input.to(DEVICE), sentence_tgt.to(DEVICE), extra_tgt.to(DEVICE)
            errD, errG, D_x, D_G_z1, D_G_z2 = train_batch(sentence_input, extra_input, sentence_tgt, extra_tgt, optimizerD, optimizerG)
            gdt += 1
            writer.add_scalar('D_x', D_x, gdt)
            writer.add_scalar('D_G_z1', D_G_z1, gdt)
            writer.add_scalar('D_G_z2', D_G_z2, gdt)
            writer.add_scalar('errD', errD, gdt)
            writer.add_scalar('errG', errG, gdt)
        torch.save(generator.state_dict(), GEN_SAVE_PATH)
        torch.save(discriminator.state_dict(), DIS_SAVE_PATH)
        # Make this epoch's scalars visible in TensorBoard even if a later epoch crashes.
        writer.flush()
    return gdt

tensorboard = SummaryWriter()
try:
    train(NUM_EPOCHS, tensorboard)
finally:
    # Flush buffered events and release the event file even if training is interrupted.
    tensorboard.close()