In [146]:
! CUDA_LAUNCH_BLOCKING=1
import torch.nn as nn
import torch.nn.functional as F
import torch
import random
from tqdm import tqdm


REST_ID = 128
BOS_ID = 129
EOS_ID = 130

class MIDIVAE(nn.Module):
  def __init__(self, encoder, decoder,classifier,max_length=502,device = 'cuda',kld_weight = 1.0, classifier_weight = 1.0):
      super(MIDIVAE,self).__init__()

      self.device = device
      # untrained components
      self.encoder = encoder.to(self.device)
      self.decoder = decoder.to(self.device)  # Cross-Attention transformer to generate MIDI vectors
      self.max_length = max_length
      #workaround for cross-attention
      self.num_memory_tokens = (encoder.latent_dim // 16)
      self.memory_proj = nn.Sequential(
            nn.Linear(encoder.latent_dim, decoder.d_model * self.num_memory_tokens),
            nn.Unflatten(1, (self.num_memory_tokens, decoder.d_model))  # [B, num_tokens, d_model]
        ).to(self.device)
      # pre-trained classifier
      self.classifier = classifier

      #constants
      self.kld_weight = kld_weight
      self.classifier_weight = classifier_weight
      self.BOS_ID = BOS_ID
      self.EOS_ID = EOS_ID
      self.REST_ID = REST_ID

  def forward(self, x, label, teacher_forcing_ratio=0.5):
        z, mean, logvar = self.encoder(x, label)

        memory = self._prepare_memory(z.to(self.device)).to(self.device)

        tgt = torch.full((1, x.shape[0]), self.BOS_ID, device=self.device)  # [1, batch_size]
        generated_tokens = tgt
        #print(f'generated tok shape { generated_tokens.shape}')
        # logits_list = []
      
        for i in range(501): 
            if self.training and random.random() < teacher_forcing_ratio:
                # Pass the ground truth (x) as the target for the next token
                #print('------TEACHER FORCING--------')
                #print(x.shape)
                #print(x[:,:i+1].shape)
                next_token_logits = self.decoder(
                    tgt = x[:,:i+1].transpose(0,1),  # Use ground truth target token sequence
                    memory=memory,  # Latent memory
                    z=z,  # Latent representation from encoder
                    label=label,  # Conditioning label
                    teacher_forcing=True  # Enable teacher forcing
                ).transpose(0, 1)
            else:
                #print('------ FREESTYLE--------')
                #print(f'generated tok shape { generated_tokens.shape}')
                next_token_logits = self.decoder(
                    tgt=generated_tokens,  # Use generated tokens
                    memory=memory,  # Latent memory
                    z=z,  # Latent representation from encoder
                    label=label,  # Conditioning label
                    teacher_forcing=False  # Disable teacher forcing
                ).transpose(0, 1)  # Adjust shape for further processing
            #logits_list.append(next_token_logits)
            next_token_logits = next_token_logits[-1, :, :]  # [1, batch_size, vocab_size]
            probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.argmax(next_token_logits, dim=-1).unsqueeze(0)  # [1, batch_size]
            generated_tokens = torch.cat([generated_tokens, next_token], dim=0)
            # Force EOS token at max length (502 tokens)
            if i == 501:
                next_token.fill_(self.EOS_ID)  # Force EOS token at max_len
                print("Forcing <EOS> token at 502nd position.")
                generated_tokens = torch.cat([generated_tokens, next_token], dim=0)
                break
            if (next_token == self.EOS_ID).all():
                print("Forcing <EOS> token at max length.")
                break
            #print(f'generate seq shape {generated_tokens.shape},next_token shape{next_token.shape}')
        recon_midi = generated_tokens  # Final generated sequence of tokens
        composer_pred = self.classifier(recon_midi.transpose(0,1))

        return recon_midi,mean, logvar, composer_pred

  def train_model(self, dataloader, optimizer, epochs=10):
        self.train()
        for epoch in range(epochs):
            total_loss = 0
            for x, label in tqdm(dataloader,desc = f"Training Epoch {epoch + 1}"):
                x, label = x.to(self.device), label.to(self.device)

                optimizer.zero_grad()
                recon_x, mu, logvar, pred = self(x, label)
                #recon_x, token_logits,mu, logvar = self(x, label)
                
                # Loss calculations
                kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                #print(f'x data type {x.dtype} recon data type {recon_x.dtype}')
                #print(f'recon shape{recon_x.transpose(0,1).shape},x shape {x.shape}')
                recon_loss = F.mse_loss(
                    recon_x.transpose(0,1).to(torch.float32),  # [batch,seq_len]
                    x.to(torch.float32)  # [batch, seq_len]
                )
                
                cls_loss = F.cross_entropy(pred, label)

                loss = recon_loss + self.kld_weight*kld + self.classifier_weight*cls_loss
                #loss = recon_loss + self.kld_weight*kld
                loss.backward()
                optimizer.step()
                total_loss += loss.item()

            print(f"Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f}")

  def _prepare_memory(self, z):
        memory = self.memory_proj(z).to(self.device)  # [B, num_tokens, d_model]
        return memory.transpose(0, 1).to(self.device)

  def generate(self, z, label, temperature=1.0):
        self.eval()
        with torch.no_grad():
            memory = self._prepare_memory(z)
            return self.decoder.generate(
                memory=memory,
                z=z,
                label=label,
                max_len=self.max_length,
                temperature=temperature,
                bos_id=self.bos_id,
                eos_id=self.eos_id
            )


In [80]:

class Token_Embedding(nn.Module):
  def __init__(self,vocab_size,embedding_dim):
    super(Token_Embedding,self).__init__()
    self.embedding = nn.Embedding(vocab_size,embedding_dim)

  def forward(self,x):
    return self.embedding(x)

class Pos_Embedding(nn.Module):
  def __init__(self,max_len,embedding_dim):
    super(Pos_Embedding,self).__init__()
    self.pos_embedding = nn.Embedding(max_len,embedding_dim)

  def forward(self,x):
    seq_len = x.size(1)
    pos_ids = torch.arange(seq_len,device=x.device).unsqueeze(0)
    return self.pos_embedding(pos_ids)

class Transformer_Encoder(nn.Module):
    def __init__(self, embedding_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
                batch_first=True  # ← Critical for your input shape
            ) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class LatentSpace_Mean_Log(nn.Module):
  def __init__(self,embedding_dim,latent_dim):
    super(LatentSpace_Mean_Log,self).__init__()
    self.fc_mu = nn.Linear(embedding_dim,latent_dim)
    self.fc_logvar = nn.Linear(embedding_dim,latent_dim)

  def forward(self,x):
    mu = self.fc_mu(x)
    logvar = self.fc_logvar(x)

    return mu,logvar

In [81]:
class Variational_Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_len=502, latent_dim=64,
                 num_heads=8, num_layers=6, ff_dim=512, label_dim=0):
        super().__init__()
        assert max_len >= 502, "max_len must cover BOS+EOS+500 tokens"
        self.token_embedding = Token_Embedding(vocab_size, embedding_dim)
        self.pos_embedding = Pos_Embedding(max_len, embedding_dim)
        self.encoder = Transformer_Encoder(embedding_dim, num_heads, num_layers, ff_dim)
        self.latent_proj = LatentSpace_Mean_Log(embedding_dim, latent_dim)
        self.latent_dim = latent_dim
        # Optional label conditioning
        self.label_projection = nn.Linear(label_dim, embedding_dim) if label_dim > 0 else None

    def forward(self, x, label=None):
        # Input x: [batch_size, 502]
        #print(f"Input shape: {x.shape}")  # Should be [batch, 502]
        #print(f"Input device: {x.device}")
        tok_emb = self.token_embedding(x)  # [B, 502, D]
        pos_emb = self.pos_embedding(x)    # [B, 502, D]
        embeddings = tok_emb + pos_emb

        # Inject label info (if provided)
        if self.label_projection and label is not None:
            label_emb = self.label_projection(label).unsqueeze(1)  # [B, 1, D]
            embeddings += label_emb.expand(-1, x.size(1), -1)  # [B, 502, D]

        # Transformer process
        output = self.encoder(embeddings)  # [B, 502, D]

        # Pool and project to latent space
        pooled = output.mean(dim=1)  # [B, D]
        mu, logvar = self.latent_proj(pooled)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return z, mu, logvar


In [82]:
#!!! Adapted from HW5!!!

from torch import nn, Tensor
import torch
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Cross-Attention Transformer Decoder
class Transformer_Decoder(nn.Module):

    def __init__(self,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 d_model: int,
                 nhead: int,
                 d_hid: int,
                 nlayers: int,
                 latent_dim: int,
                 label_dim: int,
                 dropout: float = 0.5,
                 device = 'cuda'
    ):
        super().__init__()
        self.model_type = 'Transformer'
        self.device = device
        # discrete latent tokens
        # self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        # predicted MIDI tokens
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        self.memory_proj = nn.Linear(latent_dim, d_model)

        
        # Decoder (self-attention and cross-attention)
        dec_layer = TransformerDecoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(dec_layer, nlayers)
        self.d_model = d_model

        self.linear = nn.Linear(d_model, tgt_vocab_size)

        #weight modulation layers
        self.z_scale = nn.Linear(latent_dim, d_model)
        self.z_shift = nn.Linear(latent_dim, d_model)
        # label dim should be B x
        #label conditioning layers for cross attention
        self.label_projection = nn.Embedding(label_dim,d_model)
        self.label_attn = nn.MultiheadAttention(d_model,nhead,dropout=dropout)
        # Project z to scaling factors
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.memory_proj.weight.data.uniform_(-initrange, initrange)
        self.tgt_embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, tgt: Tensor,memory: Tensor,z : Tensor, label:Tensor,teacher_forcing=True):
        # Scale and Embed the memory sequence
        #print(f'Memory Shape{memory.shape}')
        memory = self.memory_proj(memory) * math.sqrt(self.d_model)
        # Add positional encoding
        memory = self.pos_encoder(memory)
        
        # teacher forcing so that during training we can train well but also
        # autoregressively predict during evaluation
        #print(f'tgt shape {tgt.shape}')
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoder(tgt_emb)
        #print(f'tgt emb{tgt_emb.shape}')
        shift = self.z_shift(z)
        scale = self.z_scale(z)
        #print(f'shift{shift.shape}')
        #print(f'scale{scale.shape}')
        #print(f'tgt shape{tgt_emb.shape} shift shape {shift.shape} scale shape {scale.shape}')
        tgt_emb = tgt_emb * scale.unsqueeze(0) + shift.unsqueeze(0)
    
        label_emb = self.label_projection(label).unsqueeze(0)
        label_emb = label_emb.expand(tgt_emb.shape[0], -1, -1)  # (batch_size,seq_len, d_model)
        tgt_emb = self.label_attn(query=tgt_emb, key=label_emb, value=label_emb)[0]
            
        # Create the mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(len(tgt_emb)).to(self.device)
        # Pass them through the transformer
        output = self.transformer_decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask)
        # Apply the linear layer
        output = self.linear(output)
        output = output.transpose(0, 1)
        
        return output

    def generate(self, memory, z, label, max_len=502, temperature=1.0, bos_id=129, eos_id=130):
        batch_size = z.size(0)
        device = self.device
    
        # 1. Start with [BOS] token for all sequences in batch
        tgt = torch.full((1, batch_size), bos_id, device=device)  # [1, batch_size]
        # 2. Project and encode memory
        memory = self.memory_proj(memory) * math.sqrt(self.d_model)
        memory = self.pos_encoder(memory)

        logits_list = []
        for i in range(max_len-1):
            # 3. Forward pass
            logits = self(
                tgt=tgt.transpose(0, 1),  # Transformer expects [seq_len, batch_size]
                memory=memory,
                z=z,
                label=label,
                teacher_forcing=False
            ) # Get logits for last token only [batch_size, vocab_size]
            logits_list.append(logits)
            logits = logits.transpose(0, 1)[-1,:,:]
            # 4. Sample next token
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples= 1).transpose(0, 1)  # [1, batch_size]
            # 5. Append next token
            tgt = torch.cat([tgt, next_token], dim=0)  # along seq_len (dim=0)
            if i == max_len - 1:
                next_token.fill_(eos_id)  # Force EOS token at max_len
                print("Forcing <EOS> token at 502nd position.")
                tgt = torch.cat([tgt, next_token], dim=0)
                break
            # 6. Early stopping if all samples predict EOS
            if (next_token == eos_id).all():
                break
        #print(f'tgt shape{tgt.shape}')
        return tgt,torch.cat(logits_list, dim=0)  # [seq_len-1, batch_size]

In [5]:
#from google.colab import drive
#drive.mount('/content/drive')

In [6]:
file_directory = '/content/drive/MyDrive/maestro_token_sequences.csv' # colab
file_directory = '/projectnb/ec523/projects/proj_MIDIgen/maestro_token_sequences.csv' #scc

In [7]:
import pandas as pd
import numpy as np
import io

def string_to_vector(seq_str):
    return [int(x) for x in seq_str.split()]

def prepare_data(file, augmentations=False, collaborations=False,split='train'):
    # Read the data
    df = pd.read_csv(file)

    # only read train data
    if(split != 'all'):
      df = df[df['split'] == split].copy()

    # Convert sequences to vectors
    df['sequence_vector'] = df['sequence'].apply(string_to_vector)

    # Apply filters
    if (not collaborations):
        df = df[~df['composer'].str.contains('/', na=False)].copy()

    if (not augmentations):
        df = df[df['transposition amount'] == 0].copy()

    # Create labels (assuming composer_label_dict exists)
    clean_df = df.copy()  # Final cleaned version
    composer_list = sorted(list(set(clean_df['composer'])))  # Convert to sorted list for consistent ordering
    num_composers = len(composer_list)
    print(f"Unique composers: {composer_list}")
    print(f"Total composers: {num_composers}")

    # Create proper label dictionary
    composer_label_dict = {composer: idx for idx, composer in enumerate(composer_list)}
    index_composer_dict = {idx: composer for composer, idx in composer_label_dict.items()}
    clean_df['label'] = clean_df['composer'].map(composer_label_dict)

    # Select only the two columns we want
    data = clean_df[['sequence_vector', 'label']].copy()

    return data,composer_label_dict,index_composer_dict

In [90]:
import torch
from torch.utils.data import Dataset,DataLoader
class MIDI_Dataset(Dataset):
  def __init__(self,sequences,labels):
    self.sequences = sequences
    self.labels = labels

  def __len__(self):
      return len(self.sequences)

  def __getitem__(self, idx):
      # Convert to tensor directly (no padding needed)
      seq = torch.tensor(self.sequences[idx], dtype=torch.long)

      if self.labels is not None:
          label = torch.tensor(self.labels.iloc[idx], dtype=torch.long)
          return seq, label
      return seq
def create_MIDI_Dataloaders(train_data, batch_size=16):

  # Create datasets
  train_dataset = MIDI_Dataset(
      sequences=train_data['sequence_vector'].tolist(),
      labels=train_data['label']
  )

  # Create dataloaders
  train_loader = DataLoader(
      train_dataset,
      batch_size=batch_size,
      shuffle=True,
      pin_memory=True,
      num_workers=4  # Parallel loading
  )

  return train_loader


In [152]:
# Prepare Data
data,composer_to_label_map,inv_map = prepare_data(file_directory,augmentations=False,split='validation')
train_data = create_MIDI_Dataloaders(data)

Unique composers: ['Alexander Scriabin', 'César Franck', 'Felix Mendelssohn', 'Franz Liszt', 'Franz Schubert', 'Frédéric Chopin', 'Johann Sebastian Bach', 'Johannes Brahms', 'Joseph Haydn', 'Ludwig van Beethoven', 'Mily Balakirev', 'Robert Schumann', 'Sergei Rachmaninoff', 'Wolfgang Amadeus Mozart']
Total composers: 14


In [131]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import json
import sys

# Tried 256 hidden layers initiallty, but overfitted... :( (Also increased dropout from 0.3 to 0.5)
class MidiGRUClassifier(nn.Module):
    def __init__(self, vocab_size=131, embed_dim=128, hidden_size=128, num_layers=2, num_classes=34, bidirectional=True, dropout=0.5):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1  

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_size, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0,
                          bidirectional=bidirectional)
        self.norm = nn.LayerNorm(hidden_size * self.num_directions)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * self.num_directions, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        x = self.embedding(x)
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(x.device)
        out, _ = self.gru(x, h0)
        out = self.norm(out[:, -1, :])
        return self.fc(out)

In [153]:
import torch
import torch.optim as optim
# Set random seed for reproducibility
torch.manual_seed(0)

# Shared Parameters
vocab_size = 131        # Number of unique MIDI tokens
embedding_dim = 128     # Size of token embeddings
max_len = 502           # Max sequence length (500 tokens + EOS + BOS)
latent_dim = 128        # Latent space dimension
num_heads = 8           # Number of attention heads
num_layers = 6          # Number of transformer layers
ff_dim = 512            # Feed-forward layer dimension
dropout = 0.1           # Dropout rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Additional Decoder-specific Parameters
label_dim = 55          # Dimension for label embeddings (# of composers)
hidden_dim = 256        # Hidden dimension in decoder (d_hid in your code)

print(data.shape)
# Initialize Encoder
encoder = Variational_Encoder(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    max_len=max_len,
    latent_dim=latent_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    ff_dim=ff_dim
).to(device)

# Initialize Decoder
decoder = Transformer_Decoder(
    src_vocab_size=vocab_size,    # Same as encoder vocab size
    tgt_vocab_size=vocab_size,    # Same unless you have different input/output vocabs
    d_model=embedding_dim,        # Should match encoder's embedding_dim
    nhead=num_heads,              # Same as encoder
    d_hid=hidden_dim,             # Decoder-specific hidden dim
    nlayers=num_layers,           # Same as encoder
    latent_dim=latent_dim,        # Same as encoder
    label_dim=label_dim,          # For composer/style conditioning
    device=device,
    dropout=dropout,
).to(device)

#load classifier
#vocab_size=131, embed_dim=128, hidden_size=128, num_layers=2, num_classes=len(label_map)
classifier = torch.load("midi_gru_classifier_v2.pth", map_location=device)

classifier.eval()
classifier.to(device)

# workout since classifier was trained with 
optimizer = optim.Adam([{'params': encoder.parameters(), 'lr': 1e-4},{'params': decoder.parameters(), 'lr': 3e-4}],weight_decay=1e-5) 
                       
model = MIDIVAE(encoder,decoder,classifier,max_len)

print("Encoder initialized with:")
print(f"- Vocab size: {vocab_size}")
print(f"- Embedding dim: {embedding_dim}")
print(f"- Latent dim: {latent_dim}")
print(f"- {num_layers} layers with {num_heads} attention heads each")

print("\nDecoder initialized with:")
print(f"- Same vocab size: {vocab_size}")
print(f"- Label embedding dim: {label_dim}")
print(f"- Hidden dim: {hidden_dim}")
print(f"- Using dropout: {dropout}")



(4130, 2)
Encoder initialized with:
- Vocab size: 131
- Embedding dim: 128
- Latent dim: 128
- 6 layers with 8 attention heads each

Decoder initialized with:
- Same vocab size: 131
- Label embedding dim: 55
- Hidden dim: 256
- Using dropout: 0.1


  classifier = torch.load("midi_gru_classifier_v2.pth", map_location=device)


In [None]:
#def train_model(self, dataloader, optimizer, epochs=10):
model.train_model(train_data,optimizer)

Training Epoch 1:  96%|█████████▌| 248/259 [14:44<00:39,  3.56s/it]