In [30]:
import torch.nn as nn
import torch.nn.functional as F
import torch
import random
from tqdm import tqdm


REST_ID = 128
BOS_ID = 129
EOS_ID = 130

class DWA:
    def __init__(self, T=2.0):
        self.T = T  # temperature (smooths the adjustment)
        self.loss_hist = []  # store past two epochs

    def update(self, losses):
        self.loss_hist.append(losses)
        if len(self.loss_hist) > 2:
            self.loss_hist.pop(0)

    def get_weights(self):
        if len(self.loss_hist) < 2:
            return [1.0, 1.0, 1.0]  # equal at beginning
        
        w = []
        for i in range(3):
            r = self.loss_hist[1][i] / (self.loss_hist[0][i] + 1e-8)
            w.append(r)
        
        w = torch.exp(torch.tensor(w) / self.T)
        weights = (3 * w / w.sum()).tolist()  # 3 tasks here
        return weights

class MIDIVAE(nn.Module):
  def __init__(self, encoder, decoder,classifier,max_length=502,device = 'cuda',kld_weight = 1.0, classifier_weight = 1.0):
      super(MIDIVAE,self).__init__()

      self.device = device
      # untrained components
      self.encoder = encoder.to(self.device)
      self.decoder = decoder.to(self.device)  # Cross-Attention transformer to generate MIDI vectors
      self.max_length = max_length
      #workaround for cross-attention
      self.num_memory_tokens = (encoder.latent_dim // 16)
      self.memory_proj = nn.Sequential(
            nn.Linear(encoder.latent_dim, decoder.d_model * self.num_memory_tokens),
            nn.Unflatten(1, (self.num_memory_tokens, decoder.d_model))  # [B, num_tokens, d_model]
        ).to(self.device)
      # pre-trained classifier
      self.classifier = classifier

      #constants
      self.kld_weight = kld_weight
      self.classifier_weight = classifier_weight
      self.BOS_ID = BOS_ID
      self.EOS_ID = EOS_ID
      self.REST_ID = REST_ID

  
  def forward(self, x, label, teacher_forcing_ratio=0.5):
        z, mean, logvar = self.encoder(x, label)
        memory = self._prepare_memory(z.to(self.device)).to(self.device)

        tgt = torch.full((1, x.shape[0]), self.BOS_ID, device=self.device)  # [1, batch_size]
        generated_tokens = tgt
        #print(f'generated tok shape { generated_tokens.shape}')
        logits_seq = []

        # --- TRAINING PATH (teacher forcing) ------------------------------------
        if self.training:
            tgt_in  = torch.cat([tgt, x.T[:-1]], dim=0)   # [T, B] <- prepend <BOS>
            logits  = self.decoder(                       # [B, T, V] (batch_first)
                         tgt=tgt_in,                    # give decoder [B, T]
                         memory=memory, z=z, label=label,
                         teacher_forcing=True)
            recon_logits = logits                        # keep for CE loss
            recon_tokens = x.T                           # ground-truth tokens
            composer_pred = self.classifier(x)           # use GT for classifier
            return recon_logits, recon_tokens, mean, logvar, composer_pred
        
  def train_model(self, dataloader, optimizer, epochs=10,start_forcing_ratio=0.9,decay_factor=0.85):
        dwa = DWA() #dynamic balancing for losses
        self.train()
        for epoch in range(epochs):
            torch.cuda.empty_cache()
            forcing_ratio = start_forcing_ratio * (decay_factor ** epoch)
            epoch_losses = [0.0, 0.0, 0.0]
            total_loss = 0
            loss_weights = dwa.get_weights()
            for x, label in tqdm(dataloader,desc = f"Training Epoch {epoch + 1}"):
                x, label = x.to(self.device), label.to(self.device)

                optimizer.zero_grad()
                recon_logits, targets, mu, logvar, pred = self(x, label)
                
                # Loss calculations
                kld_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
                recon_loss = F.cross_entropy(recon_logits.reshape(-1, recon_logits.size(-1)),targets.reshape(-1))
                cls_loss = F.cross_entropy(pred, label)
                loss = loss_weights[0] * recon_loss + loss_weights[1] * kld_loss + loss_weights[2] * cls_loss
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
                loss.backward()
                optimizer.step()

                epoch_losses[0] += recon_loss.item()
                epoch_losses[1] += kld_loss.item()
                epoch_losses[2] += cls_loss.item()
      
                total_loss += loss.item()
                
            dwa.update(epoch_losses)
            print(f"Epoch {epoch+1} | Loss: {total_loss/len(dataloader):.4f} | KLD: {kld_loss/len(dataloader)} | CLS: {cls_loss/len(dataloader)}")

            # Checkpoint: Save the model to a temporary directory after each epoch 
            # in case the model doesn't finish training during the 12 hour SCC secession
            torch.save(model.state_dict(), f"/projectnb/ec523/projects/proj_MIDIgen/checkpoints/MIDIVAE_v2_CHECKPOINT_{epoch+1}.pth")
  
  def _prepare_memory(self, z):
        memory = self.memory_proj(z).to(self.device)  # [B, num_tokens, d_model]
        return memory.transpose(0, 1).to(self.device)

  def generate(self, z, label, temperature=1.0):
        self.eval()
        with torch.no_grad():
            memory = self._prepare_memory(z)
            print(memory.shape)
            return self.decoder.generate(
                memory=memory,
                z=z,
                label=label,
                max_len=self.max_length,
                temperature=temperature,
                bos_id=self.BOS_ID,
                eos_id=self.EOS_ID
            )


In [31]:

class Token_Embedding(nn.Module):
  def __init__(self,vocab_size,embedding_dim):
    super(Token_Embedding,self).__init__()
    self.embedding = nn.Embedding(vocab_size,embedding_dim)

  def forward(self,x):
    return self.embedding(x)

class Pos_Embedding(nn.Module):
  def __init__(self,max_len,embedding_dim):
    super(Pos_Embedding,self).__init__()
    self.pos_embedding = nn.Embedding(max_len,embedding_dim)

  def forward(self,x):
    seq_len = x.size(1)
    pos_ids = torch.arange(seq_len,device=x.device).unsqueeze(0)
    return self.pos_embedding(pos_ids)

class Transformer_Encoder(nn.Module):
    def __init__(self, embedding_dim, num_heads, num_layers, ff_dim):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=num_heads,
                dim_feedforward=ff_dim,
                batch_first=True  # ← Critical for your input shape
            ) for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class LatentSpace_Mean_Log(nn.Module):
  def __init__(self,embedding_dim,latent_dim):
    super(LatentSpace_Mean_Log,self).__init__()
    self.fc_mu = nn.Linear(embedding_dim,latent_dim)
    self.fc_logvar = nn.Linear(embedding_dim,latent_dim)

  def forward(self,x):
    mu = self.fc_mu(x)
    logvar = self.fc_logvar(x)

    return mu,logvar

In [3]:
class Variational_Encoder(nn.Module):
    def __init__(self, vocab_size, embedding_dim, max_len=502, latent_dim=64,
                 num_heads=8, num_layers=6, ff_dim=512, label_dim=34):
        super().__init__()
        assert max_len >= 502, "max_len must cover BOS+EOS+500 tokens"
        self.token_embedding = Token_Embedding(vocab_size, embedding_dim)
        self.pos_embedding = Pos_Embedding(max_len, embedding_dim)
        self.encoder = Transformer_Encoder(embedding_dim, num_heads, num_layers, ff_dim)
        self.latent_proj = LatentSpace_Mean_Log(embedding_dim, latent_dim)
        self.latent_dim = latent_dim
        # Optional label conditioning
        self.label_projection = nn.Embedding(label_dim, embedding_dim) if label_dim > 0 else None
        self.device = 'cuda'
        # initialization
        #token embedding
        torch.nn.init.xavier_uniform_(self.token_embedding.embedding.weight)
        #posembedding
        torch.nn.init.xavier_uniform_(self.pos_embedding.pos_embedding.weight)
        #latentspacemeanlog
        torch.nn.init.xavier_uniform_(self.latent_proj.fc_mu.weight)
        torch.nn.init.zeros_(self.latent_proj.fc_mu.bias)
        torch.nn.init.xavier_uniform_(self.latent_proj.fc_logvar.weight)
        torch.nn.init.zeros_(self.latent_proj.fc_logvar.bias)
        #labelprojection
        torch.nn.init.xavier_uniform_(self.label_projection.weight)

        for layer in self.encoder.layers:
            # Initialize the multi-head attention layers
            torch.nn.init.xavier_uniform_(layer.self_attn.in_proj_weight)  # Query, Key, Value weights
            torch.nn.init.xavier_uniform_(layer.self_attn.out_proj.weight)  # Output projection
    
            
            # Initialize feed-forward network weights
            torch.nn.init.xavier_uniform_(layer.linear1.weight)  # First linear layer of FFN
            torch.nn.init.xavier_uniform_(layer.linear2.weight)  # Second linear layer of FFN
            
            # Initialize layer normalization weights
            torch.nn.init.constant_(layer.norm1.weight, 1)  # Norm1 (self-attention layer)
            torch.nn.init.constant_(layer.norm2.weight, 1)  # Norm2 (feed-forward layer)
            
            # Initialize biases
            torch.nn.init.zeros_(layer.self_attn.in_proj_bias)  # Bias for attention weights
            torch.nn.init.zeros_(layer.self_attn.out_proj.bias)  # Bias for attention output
            torch.nn.init.zeros_(layer.linear1.bias)  # Bias for first linear layer
            torch.nn.init.zeros_(layer.linear2.bias)  # Bias for second linear layer
            torch.nn.init.zeros_(layer.norm1.bias)  # Bias for Norm1
            torch.nn.init.zeros_(layer.norm2.bias)  # Bias for Norm2
            
    def forward(self, x, label=None):
        # Input x: [batch_size, 502]
        #print(f"Input shape: {x.shape}")  # Should be [batch, 502]
        #print(f"Input device: {x.device}")
        tok_emb = self.token_embedding(x)  # [B, 502, D]
        pos_emb = self.pos_embedding(x)    # [B, 502, D]
        embeddings = tok_emb + pos_emb

        # Inject label info (if provided)
        if self.label_projection and label is not None:
            label_emb = self.label_projection(label.to(self.device)).unsqueeze(1)  # [B, 1, D]
            #shape of label_emb = [16,1,128]
            #shape of embedding = [16,502,128]
            #x.size(1) is 502
            temp = label_emb.expand(-1, x.size(1), -1)  # [B, 502, D]
            embeddings = embeddings + temp
        
        # Transformer process
        output = self.encoder(embeddings)  # [B, 502, D]

        # Pool and project to latent space
        pooled = output.mean(dim=1)  # [B, D]
        mu, logvar = self.latent_proj(pooled)
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        z = mu + eps * std

        return z, mu, logvar


In [4]:
#!!! Adapted from HW5!!!

from torch import nn, Tensor
import torch
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Cross-Attention Transformer Decoder
class Transformer_Decoder(nn.Module):

    def __init__(self,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 d_model: int,
                 nhead: int,
                 d_hid: int,
                 nlayers: int,
                 latent_dim: int,
                 label_dim: int,
                 dropout: float = 0.5,
                 device = 'cuda'
    ):
        super().__init__()
        self.model_type = 'Transformer'
        self.device = device
        # discrete latent tokens
        # self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        # predicted MIDI tokens
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout)

        self.memory_proj = nn.Linear(latent_dim, d_model)

        
        # Decoder (self-attention and cross-attention)
        dec_layer = TransformerDecoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(dec_layer, nlayers)
        self.d_model = d_model

        self.linear = nn.Linear(d_model, tgt_vocab_size)

        #weight modulation layers
        self.z_scale = nn.Linear(latent_dim, d_model)
        self.z_shift = nn.Linear(latent_dim, d_model)
        # label dim should be B x
        #label conditioning layers for cross attention
        self.label_projection = nn.Embedding(label_dim,d_model)
        self.label_attn = nn.MultiheadAttention(d_model,nhead,dropout=dropout)
        # Project z to scaling factors
        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.tgt_embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)
        torch.nn.init.xavier_uniform_(self.tgt_embedding.weight)
        torch.nn.init.xavier_uniform_(self.label_projection.weight)
        torch.nn.init.xavier_uniform_(self.memory_proj.weight)
        torch.nn.init.xavier_uniform_(self.linear.weight)
        torch.nn.init.xavier_uniform_(self.z_scale.weight)
        torch.nn.init.xavier_uniform_(self.z_shift.weight)
        self.memory_proj.bias.data.zero_()
        self.linear.bias.data.zero_()
        self.z_scale.bias.data.zero_()
        self.z_shift.bias.data.zero_()
        for layer in self.transformer_decoder.layers:
            # Initialize multi-head attention weights
            torch.nn.init.xavier_uniform_(layer.self_attn.in_proj_weight)
            torch.nn.init.xavier_uniform_(layer.self_attn.out_proj.weight)
            torch.nn.init.xavier_uniform_(layer.multihead_attn.in_proj_weight)
            torch.nn.init.xavier_uniform_(layer.multihead_attn.out_proj.weight)
        
        # Initialize feed-forward network weights
        torch.nn.init.xavier_uniform_(layer.linear1.weight)
        torch.nn.init.xavier_uniform_(layer.linear2.weight)
        
        # Initialize layer norms
        torch.nn.init.constant_(layer.norm1.weight, 1)
        torch.nn.init.constant_(layer.norm2.weight, 1)

        torch.nn.init.xavier_uniform_(self.label_attn.in_proj_weight)
        torch.nn.init.xavier_uniform_(self.label_attn.out_proj.weight)

    def forward(self, tgt: Tensor,memory: Tensor,z : Tensor, label:Tensor,teacher_forcing=True):
        # Scale and Embed the memory sequence
        memory = self.memory_proj(memory) * math.sqrt(self.d_model)
        # Add positional encoding
        memory = self.pos_encoder(memory)
        # teacher forcing so that during training we can train well but also
        # autoregressively predict during evaluation
        #print(f'tgt shape {tgt.shape}')
        tgt_emb = self.tgt_embedding(tgt) * math.sqrt(self.d_model)
        tgt_emb = self.pos_encoder(tgt_emb)
        #print(f'tgt emb{tgt_emb.shape}')
        shift = self.z_shift(z)
        scale = self.z_scale(z)
        #print(f'shift{shift.shape}')
        #print(f'scale{scale.shape}')
        #print(f'tgt shape{tgt_emb.shape} shift shape {shift.shape} scale shape {scale.shape}')
        tgt_emb = tgt_emb * scale.unsqueeze(0) + shift.unsqueeze(0)
    
        label_emb = self.label_projection(label).unsqueeze(0)
        label_emb = label_emb.expand(tgt_emb.shape[0], -1, -1)  # (batch_size,seq_len, d_model)
        #print(f'tgt emb{tgt_emb.shape} | label_emb shape {label_emb.shape}')
        tgt_emb = self.label_attn(query=tgt_emb, key=label_emb, value=label_emb)[0]
            
        # Create the mask
        tgt_mask = nn.Transformer.generate_square_subsequent_mask(len(tgt_emb)).to(self.device)
        # Pass them through the transformer
        output = self.transformer_decoder(tgt=tgt_emb, memory=memory, tgt_mask=tgt_mask)
        # Apply the linear layer
        output = self.linear(output)
        output = output.transpose(0, 1)
        
        return output

    def generate(self, memory, z, label, max_len=502, temperature=1.0, bos_id=129, eos_id=130):
        batch_size = z.size(0)
        device = self.device
    
        # 1. Start with [BOS] token for all sequences in batch
        tgt = torch.full((1, batch_size), bos_id, device=device)  # [1, batch_size]
        # 2. Project and encode memory
        memory = self.memory_proj(memory) * math.sqrt(self.d_model)
        memory = self.pos_encoder(memory)

        #logits_list = []
        for i in range(max_len-1):
            # 3. Forward pass
            logits = self(
                tgt=tgt,  # Transformer expects [seq_len, batch_size]
                memory=memory,
                z=z,
                label=label,
                teacher_forcing=False
            ) # Get logits for last token only [batch_size, vocab_size]
            #logits_list.append(logits)
            logits = logits.transpose(0, 1)[-1,:,:]
            # 4. Sample next token
            probs = F.softmax(logits / temperature, dim=-1)
            next_token = torch.multinomial(probs, num_samples= 1).transpose(0, 1)  # [1, batch_size]
            # 5. Append next token
            tgt = torch.cat([tgt, next_token], dim=0)  # along seq_len (dim=0)
            if i == max_len - 1:
                next_token.fill_(eos_id)  # Force EOS token at max_len
                print("Forcing <EOS> token at 502nd position.")
                tgt = torch.cat([tgt, next_token], dim=0)
                break
            # 6. Early stopping if all samples predict EOS
            if (next_token == eos_id).all():
                break
        #print(f'tgt shape{tgt.shape}')
        return tgt#,torch.cat(logits_list, dim=0)  # [seq_len-1, batch_size]

In [5]:
#from google.colab import drive
#drive.mount('/content/drive')

In [6]:
file_directory = '/content/drive/MyDrive/maestro_token_sequences.csv' # colab
file_directory = '/projectnb/ec523/projects/proj_MIDIgen/maestro_new_splits_no_augmentation.csv' #scc


In [7]:
import pandas as pd
import numpy as np
import io

def string_to_vector(s):
    numbers = [int(x) for x in s.split()]
    return numbers


def prepare_data(file, augmentations=False, collaborations=False,split='train'):
    # Read the data
    df = pd.read_csv(file)
    print(df.head())
    # only read train data
    if(split != 'all'):
      df = df[df['split'] == split].copy()

    # Convert sequences to vectors
    df['sequence_vector'] = df['sequence_vector'].apply(string_to_vector)
    # Apply filters
    if (not collaborations):
        df = df[~df['label'].str.contains('/', na=False)].copy()

    if (not augmentations):
        df = df[df['transposition amount'] == 0].copy()

    # Create labels (assuming composer_label_dict exists)
    clean_df = df.copy()  # Final cleaned version
    composer_list = sorted(list(set(clean_df['label'])))  # Convert to sorted list for consistent ordering
    num_composers = len(composer_list)
    print(f"Unique composers: {composer_list}")
    print(f"Total composers: {num_composers}")

    # Create proper label dictionary
    composer_label_dict = {composer: idx for idx, composer in enumerate(composer_list)}
    index_composer_dict = {idx: composer for composer, idx in composer_label_dict.items()}
    clean_df['label'] = clean_df['label'].map(composer_label_dict)

    # Select only the two columns we want
    data = clean_df[['sequence_vector', 'label']].copy()

    return data,composer_label_dict,index_composer_dict

In [8]:
import pandas as pd
import numpy as np
import io

def string_to_vector(s):
    s = s.strip('[]')  # Removes [ and ] from the start and end
    numbers = [int(x.strip()) for x in s.split(',')]
    return numbers


def prepare_data(file, augmentations=False, collaborations=False,split='train'):
    # Read the data
    df = pd.read_csv(file)
    print(df.head())
    # only read train data
    if(split != 'all'):
      df = df[df['split'] == split].copy()

    # Convert sequences to vectors
    df['sequence_vector'] = df['sequence_vector'].apply(string_to_vector)
    # Apply filters

    # Create labels (assuming composer_label_dict exists)
    clean_df = df.copy()  # Final cleaned version
    composer_list = sorted(list(set(clean_df['label'])))  # Convert to sorted list for consistent ordering
    num_composers = len(composer_list)
    print(f"Unique composers: {composer_list}")
    print(f"Total composers: {num_composers}")

    # Create proper label dictionary
    composer_label_dict = {composer: idx for idx, composer in enumerate(composer_list)}
    index_composer_dict = {idx: composer for composer, idx in composer_label_dict.items()}
    clean_df['label'] = clean_df['label'].map(composer_label_dict)

    # Select only the two columns we want
    data = clean_df[['sequence_vector', 'label']].copy()

    return data,composer_label_dict,index_composer_dict

In [9]:
import torch
from torch.utils.data import Dataset,DataLoader
class MIDI_Dataset(Dataset):
  def __init__(self,sequences,labels):
    self.sequences = sequences
    self.labels = labels

  def __len__(self):
      return len(self.sequences)

  def __getitem__(self, idx):
      # Convert to tensor directly (no padding needed)
      seq = torch.tensor(self.sequences[idx], dtype=torch.long)

      if self.labels is not None:
          label = torch.tensor(self.labels.iloc[idx], dtype=torch.long)
          return seq, label
      return seq
def create_MIDI_Dataloaders(train_data, batch_size=32):

  # Create datasets
  train_dataset = MIDI_Dataset(
      sequences=train_data['sequence_vector'].tolist(),
      labels=train_data['label']
  )

  # Create dataloaders
  train_loader = DataLoader(
      train_dataset,
      batch_size=batch_size,
      shuffle=True,
      pin_memory=True,
      num_workers=4  # Parallel loading
  )

  return train_loader


In [10]:
# Prepare Data
data,composer_to_label_map,inv_map = prepare_data(file_directory,augmentations=False,split='train')
train_data = create_MIDI_Dataloaders(data)

                                     sequence_vector  label  split
0  [129, 48, 48, 48, 48, 48, 48, 48, 48, 48, 48, ...     20  train
1  [129, 64, 64, 64, 67, 60, 60, 62, 62, 62, 60, ...      9  train
2  [129, 31, 31, 31, 31, 31, 31, 31, 31, 31, 31, ...     23  train
3  [129, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, ...      2  train
4  [129, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, ...     32  train
Unique composers: [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33]
Total composers: 34


In [11]:
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from sklearn.metrics import accuracy_score
from tqdm import tqdm
import json
import sys

# Tried 256 hidden layers initiallty, but overfitted... :( (Also increased dropout from 0.3 to 0.5)
class MidiGRUClassifier(nn.Module):
    def __init__(self, vocab_size=131, embed_dim=128, hidden_size=128, num_layers=2, num_classes=34, bidirectional=True, dropout=0.5):
        super().__init__()
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_classes = num_classes
        self.bidirectional = bidirectional
        self.num_directions = 2 if bidirectional else 1  

        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.gru = nn.GRU(embed_dim, hidden_size, num_layers,
                          batch_first=True,
                          dropout=dropout if num_layers > 1 else 0,
                          bidirectional=bidirectional)
        self.norm = nn.LayerNorm(hidden_size * self.num_directions)
        self.fc = nn.Sequential(
            nn.Linear(hidden_size * self.num_directions, hidden_size),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_size, num_classes)
        )

    def forward(self, x):
        x = self.embedding(x)
        batch_size = x.size(0)
        h0 = torch.zeros(self.num_layers * self.num_directions, batch_size, self.hidden_size).to(x.device)
        out, _ = self.gru(x, h0)
        out = self.norm(out[:, -1, :])
        return self.fc(out)

In [12]:
import torch
import torch.optim as optim
# Set random seed for reproducibility
torch.manual_seed(0)
torch.cuda.memory._record_memory_history()

# Shared Parameters
vocab_size = 131        # Number of unique MIDI tokens
embedding_dim = 128     # Size of token embeddings
max_len = 502           # Max sequence length (500 tokens + EOS + BOS)
latent_dim = 128        # Latent space dimension
num_heads = 8           # Number of attention heads
num_layers = 6          # Number of transformer layers
ff_dim = 512            # Feed-forward layer dimension
dropout = 0.1           # Dropout rate
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Additional Decoder-specific Parameters
label_dim = 34          # Dimension for label embeddings (# of composers)
hidden_dim = 256        # Hidden dimension in decoder (d_hid in your code)

print(data.shape)
# Initialize Encoder
encoder = Variational_Encoder(
    vocab_size=vocab_size,
    embedding_dim=embedding_dim,
    max_len=max_len,
    latent_dim=latent_dim,
    num_heads=num_heads,
    num_layers=num_layers,
    ff_dim=ff_dim
).to(device)

# Initialize Decoder
decoder = Transformer_Decoder(
    src_vocab_size=vocab_size,    # Same as encoder vocab size
    tgt_vocab_size=vocab_size,    # Same unless you have different input/output vocabs
    d_model=embedding_dim,        # Should match encoder's embedding_dim
    nhead=num_heads,              # Same as encoder
    d_hid=hidden_dim,             # Decoder-specific hidden dim
    nlayers=num_layers,           # Same as encoder
    latent_dim=latent_dim,        # Same as encoder
    label_dim=label_dim,          # For composer/style conditioning
    device=device,
    dropout=dropout,
).to(device)

#load classifier
#vocab_size=131, embed_dim=128, hidden_size=128, num_layers=2, num_classes=len(label_map)
classifier = torch.load("midi_gru_classifier_v4.pth", map_location=device)

classifier.eval()
classifier.to(device)

# workout since classifier was trained with 
model = MIDIVAE(encoder,decoder,classifier,max_len)
optimizer = optim.Adam([{'params': encoder.parameters(), 'lr': 3e-4},{'params': decoder.parameters(), 'lr': 3e-4},{'params':model.memory_proj.parameters(),'lr':3e-4}],weight_decay=1e-5) 
                       


print("Encoder initialized with:")
print(f"- Vocab size: {vocab_size}")
print(f"- Embedding dim: {embedding_dim}")
print(f"- Latent dim: {latent_dim}")
print(f"- {num_layers} layers with {num_heads} attention heads each")

print("\nDecoder initialized with:")
print(f"- Same vocab size: {vocab_size}")
print(f"- Label embedding dim: {label_dim}")
print(f"- Hidden dim: {hidden_dim}")
print(f"- Using dropout: {dropout}")



(25192, 2)


  classifier = torch.load("midi_gru_classifier_v4.pth", map_location=device)


Encoder initialized with:
- Vocab size: 131
- Embedding dim: 128
- Latent dim: 128
- 6 layers with 8 attention heads each

Decoder initialized with:
- Same vocab size: 131
- Label embedding dim: 34
- Hidden dim: 256
- Using dropout: 0.1


In [13]:
#def train_model(self, dataloader, optimizer, epochs=10):
model.train_model(train_data,optimizer)


Training Epoch 1: 100%|██████████| 788/788 [01:42<00:00,  7.68it/s]


Epoch 1 | Loss: 5.9371 | KLD: 4.1826075403150753e-07 | CLS: 0.002972245682030916


Training Epoch 2: 100%|██████████| 788/788 [01:42<00:00,  7.68it/s]


Epoch 2 | Loss: 5.8665 | KLD: 3.519113249694783e-07 | CLS: 0.001901553594507277


Training Epoch 3: 100%|██████████| 788/788 [01:42<00:00,  7.67it/s]


Epoch 3 | Loss: 6.7193 | KLD: 3.450830661222426e-07 | CLS: 0.0024236852768808603


Training Epoch 4: 100%|██████████| 788/788 [01:43<00:00,  7.60it/s]


Epoch 4 | Loss: 5.8405 | KLD: 1.8499949305805785e-07 | CLS: 0.0021625205408781767


Training Epoch 5: 100%|██████████| 788/788 [01:43<00:00,  7.62it/s]


Epoch 5 | Loss: 6.2528 | KLD: 2.112688690658615e-07 | CLS: 0.0021148943342268467


Training Epoch 6: 100%|██████████| 788/788 [01:43<00:00,  7.60it/s]


Epoch 6 | Loss: 5.8849 | KLD: 1.58731623400854e-07 | CLS: 0.0023847969714552164


Training Epoch 7: 100%|██████████| 788/788 [01:44<00:00,  7.57it/s]


Epoch 7 | Loss: 6.0690 | KLD: 1.1856084114469922e-07 | CLS: 0.0017963640857487917


Training Epoch 8: 100%|██████████| 788/788 [01:45<00:00,  7.47it/s]


Epoch 8 | Loss: 5.9767 | KLD: 1.0025456731455051e-07 | CLS: 0.0021744626574218273


Training Epoch 9: 100%|██████████| 788/788 [01:44<00:00,  7.54it/s]


Epoch 9 | Loss: 6.0164 | KLD: 9.072479656424548e-08 | CLS: 0.002223146380856633


Training Epoch 10: 100%|██████████| 788/788 [01:44<00:00,  7.51it/s]


Epoch 10 | Loss: 5.9900 | KLD: 7.461472506520295e-08 | CLS: 0.0031632618047297


In [14]:
torch.save(model.state_dict(), '/projectnb/ec523/projects/proj_MIDIgen/MIDIVAE_v2.pth')

In [36]:
# Generate Songs Testing
model = MIDIVAE(encoder,decoder,classifier,max_len)
model.load_state_dict(torch.load('MIDIVAE_v2.pth',map_location=device))
z = torch.randn(1, 128).to('cuda') 
label = torch.LongTensor([12]).to('cuda')
generated_seq = model.generate(z,label)

  model.load_state_dict(torch.load('MIDIVAE_v2.pth',map_location=device))


torch.Size([8, 1, 128])


In [34]:
! pip install mido
! pip install pretty_midi

Defaulting to user installation because normal site-packages is not writeable
Defaulting to user installation because normal site-packages is not writeable


In [40]:
import pretty_midi

# Your tensor
tokens = generated_seq.cpu().numpy()

# Create a PrettyMIDI object
midi = pretty_midi.PrettyMIDI()
instrument = pretty_midi.Instrument(program=0)  # Acoustic Grand Piano

current_time = 0.0
note_duration = 0.5  # Seconds per note
rest_duration = 0.5

for token in tokens:
    token = int(token)        # make doubly sure it’s a Python int
    if token == 129:
        continue  # BOS
    if token == 130:
        break  # EOS
    if token == 128:
        # Rest: just advance time
        current_time += rest_duration
    elif 0 <= token <= 127:
        # Create a Note
        note = pretty_midi.Note(
            velocity=100,  # fixed velocity
            pitch=token,
            start=current_time,
            end=current_time + note_duration
        )
        instrument.notes.append(note)
        current_time += note_duration  # move forward

# Add instrument to MIDI and write it out
midi.instruments.append(instrument)
output_midi = "generated_sequence_pretty.mid"
midi.write(output_midi)

print(f"Saved to {output_midi}")

Saved to generated_sequence_pretty.mid


  token = int(token)        # make doubly sure it’s a Python int


In [None]:
print(generated_seq.transpose(0,1).cpu().numpy())

In [None]:
print(generated_seq2.transpose(0,1).cpu().numpy())

In [None]:
 print(generated_seq3.transpose(0,1).cpu().numpy())

In [None]:
# Latent Space Analysis
