In [None]:
import torch.nn as nn
import torch.nn.functional as F
import torch

class MIDIVAE(nn.Module):
  def __init__(self, encoder,vector_quantizer, decoder,classifier,device = 'cuda'):
      super(MIDIVAE,self).__init__()


      # untrained components
      self.encoder = encoder
      self.vector_quantizer = vector_quantizer  # Vector Quantized VAE to discretize the latent space
      self.decoder = decoder  # Cross-Attention transformer to generate MIDI vectors

      # pre-trained classifier
      self.classifer = classifier


      self.device = device

  def forward(self, x,label):
    #obtain latent space
    z, mean,logvariance = self.encoder(x)

    # quantize the latent space
    z, quantized, vq_loss = self.vector_quantizer(z)

    # concatenate x label with z
    z = torch.cat([z,label.to(self.device)])
    #feed through decoder
    recon_midi = self.decoder(z)

    #feed through classifier
    composer_pred = self.classifier(recon_midi)

    return recon_midi,mean,logvariance,composer_pred, vq_loss

  def train_model(self, dataloader, optimizer, epochs=10, device='cuda'):
    self.train()
    # load data into some kind of trainloader?

    # for loop with loading data for each batch
    for epoch in range(epochs):
        total_loss = 0
        for x, label, composer in dataloader:  # assume dataloader returns inputs, labels, composers
            x, label, composer = x.to(device), label.to(device), composer.to(device)

            optimizer.zero_grad()

            # call forward and get back the reconstructions, mean, logvariance, and composer prediction

            recon_x, mu, logvar, pred, vq_loss = self.forward(x, label)

            # do KLD with mean, logvariance and our prior N(0,I), backprop to encoder only
            kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / x.size(0)

            #! Connor - I made changes here                         <------
            # compute cross-entropy loss between the model’s next-token logits (dropping the last step) and the true next tokens (dropping the first)
            tgt = x[:, 1:].transpose(0, 1)
            logits = recon_x[:, :-1]
            recon_loss = F.cross_entropy(logits.reshape(-1, logits.size(-1)),tgt.reshape(-1)) 

            # do CrossEntropy with data labels and composer predictions, backprop to decoder only

            ce_loss = F.cross_entropy(pred, composer)

            loss = recon_loss + kld + ce_loss + vq_loss  # optionally weigh each

            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        print(f"Epoch {epoch+1} | Loss: {total_loss:.4f}")

  def generate(self,label):
    # put on test mode
    self.eval()
    # generate normal gaussian noise
    batch_size = label.size(0)
    z = torch.randn(batch_size, self.encoder.latent_proj.fc_mu.out_features).to(self.device)
    # concatenate noise with label
    z_cond = torch.cat([z, label.to(self.device)], dim=1)
    # feed through decoder

    recon = self.decoder(z_cond)
    return recon


Token Embedding + Positional Encoding
           ↓
  Transformer Encoder
           ↓
     Sequence Embeddings (B, T, D)
           ↓
    Mean Pool / CLS Token → (B, D)
           ↓
      Linear → mu (B, latent_dim)
      Linear → logvar (B, latent_dim)
           ↓
Reparameterization Trick: z = mu + eps * std

label → Embedding → (B, label_dim)
z     → (B, latent_dim)

→ concat [z, label_emb] → (B, latent_dim + label_dim)
→ Linear projection → (B, decoder_dim)  # matches decoder embedding dim
→ z_cond

[BOS, ..., tokens[:t-1]] → token embedding + pos encoding → (B, T', D)

z_cond → unsqueeze(1) → broadcast across T' → (B, T', D)

decoder_input = token_emb + z_cond

Transformer Decoder → token_logits (B, T', vocab_size)


In [28]:

class Token_Embedding(nn.Module):
  def __init__(self,vocab_size,embedding_dim):
    super(Token_Embedding,self).__init__()
    self.embedding = nn.Embedding(vocab_size,embedding_dim)

  def forward(self,x):
    return self.embedding(x)

class Pos_Embedding(nn.Module):
  def __init__(self,max_len,embedding_dim):
    super(Pos_Embedding,self).__init__()
    self.pos_embedding = nn.Embedding(max_len,embedding_dim)

  def forward(self,x):
    seq_len = x.size(1)
    pos_ids = torch.arange(seq_len,device=x.device).unsqueeze(0)
    return self.pos_embedding(pos_ids)

class Transformer_Encoder(nn.Module):
  def __init__(self,embedding_dim,num_heads,num_layers,ff_dim):
    super(Transformer_Encoder,self).__init__()
    self.self_attention = nn.MultiheadAttention(embedding_dim,num_heads)
    self.layer_norm = nn.LayerNorm(embedding_dim)
    self.ffn = nn.Sequential(nn.Linear(embedding_dim,ff_dim),nn.ReLU(),nn.Linear(ff_dim,embedding_dim))

  def forward(self,x):
    print(x.shape)
    # multihead attention is dumb and doesnt like batch size first
    x = x.permute(1, 0, 2)
    output,_ = self.self_attention(x,x,x)
    # maybe add normalization

    ffn_output = self.ffn(x)
    x = self.layer_norm(x + ffn_output)
    #undo what we did
    x = x.permute(1, 0, 2)
    return x

class LatentSpace_Mean_Log(nn.Module):
  def __init__(self,embedding_dim,latent_dim):
    super(LatentSpace_Mean_Log,self).__init__()
    self.fc_mu = nn.Linear(embedding_dim,latent_dim)
    self.fc_logvar = nn.Linear(embedding_dim,latent_dim)

  def forward(self,x):
    mu = self.fc_mu(x)
    logvar = self.fc_logvar(x)

    return mu,logvar

In [30]:
class Variational_Encoder(nn.Module):
  def __init__(self, vocab_size, embedding_dim, max_len, latent_dim, num_heads, num_layers, ff_dim):
      super(Variational_Encoder,self).__init__()

      self.token_embedding = Token_Embedding(vocab_size,embedding_dim)
      self.pos_embedding = Pos_Embedding(max_len,embedding_dim)
      self.encoder = Transformer_Encoder(embedding_dim,num_heads,num_layers,ff_dim)
      self.latent_proj = LatentSpace_Mean_Log(embedding_dim,latent_dim)

  def forward(self, x):
    # create embeddings
    tok_embeddings = self.token_embedding(x)
    pos_embeddings = self.pos_embedding(x)
    embeddings = pos_embeddings + tok_embeddings

    # obtain output
    output = self.encoder(embeddings)

    # pool it up!!
    pooled_output = output.mean(dim=1)

    # project into latent space and reparameterize
    mean,logvariance = self.latent_proj(pooled_output)
    std = torch.exp(0.5 * logvariance)
    eps = torch.randn_like(std)
    z = mean + eps * std

    return z, mean,logvariance


In [None]:
class VectorQuantizer(nn.Module):
    """
    Vector Quantizer implements the VQ‐VAE codebook lookup and loss.
    Args:
      num_embeddings: number of codebook vectors (K)
      embedding_dim:   dimensionality of each vector (D)
      commitment_cost: hyperparameter that pushes the encoder to stay close to their codebook vectors
    """
    def __init__(self, num_embeddings, embedding_dim = 512, commitment_cost = 0.25):
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim   = embedding_dim
        self.commitment_cost = commitment_cost

        # Codebook: K × D, initialized uniformly
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)
    
    def forward(self, z):
        """
        z: continuous latent vectors from encoder, shape [B, D]
        returns:
          encoding_indices: [B] discrete tokens we are feeding to the transformer decoder
          quantized: [B, D] embedding vectors for each token
          vq_loss: scalar loss
        """
        # get batch-size and latent dimension  
        B, D = z.shape
        assert D == self.embedding_dim, f"Expected latent_dim={self.embedding_dim}, but got {D}"

        # prep for distance calculation
        z_flat = z.view(B, D) # ensure z is [B, D]
        emb_w   = self.embedding.weight # our codebook matrix, shape [K, D]

        # compute squared L2 distances: ||z - e_k ||^2
        z_sq = torch.sum(z_flat**2, dim=1, keepdim=True) # [B, 1], each ||z ||^2
        e_sq = torch.sum(emb_w**2, dim=1) # [K], each ||e_k||^2
        dist = z_sq + e_sq.unsqueeze(0) - 2 * (z_flat @ emb_w.t())  # [B, K], dot(z, e_k')

        # find nearest code for each z
        encoding_indices = torch.argmin(dist, dim=1) # [B]

        # convert to one-hot encoding
        encodings = F.one_hot(encoding_indices, self.num_embeddings).type(z.dtype)  # [B, K]

        # quantized vectors = one_hot encodings @ codebook
        quantized = encodings @ emb_w # [B, D]
        quantized = quantized.view_as(z)

        # compute VQ losses
        # codebook loss: move embeddings toward encoder outputs
        e_latent_loss = F.mse_loss(quantized.detach(), z)
        # commitment loss: move encoder outputs toward embeddings
        q_latent_loss = F.mse_loss(quantized, z.detach())
        vq_loss = e_latent_loss + self.commitment_cost * q_latent_loss

        # straight-through: allow gradients to flow to z
        quantized = z + (quantized - z).detach()

        return encoding_indices, quantized, vq_loss

In [None]:
#!!! Adapted from HW5!!!

from torch import nn, Tensor
import torch
import math
from torch.nn import TransformerEncoder, TransformerEncoderLayer, TransformerDecoder, TransformerDecoderLayer

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

# Cross-Attention Transformer Decoder
class TransformerModel(nn.Module):

    def __init__(self, 
                 src_vocab_size: int, 
                 tgt_vocab_size: int, 
                 d_model: int, 
                 nhead: int, 
                 d_hid: int,
                 nlayers: int, 
                 dropout: float = 0.5,
    ):
        super().__init__()
        self.model_type = 'Transformer'

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)  # discrete latent tokens
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)  # predicted MIDI tokens

        self.pos_encoder = PositionalEncoding(d_model, dropout)

        # Decoder (self-attention and cross-attention)
        dec_layer = TransformerDecoderLayer(d_model, nhead, d_hid, dropout)
        self.transformer_decoder = TransformerDecoder(dec_layer, nlayers)

        self.d_model = d_model
        self.linear = nn.Linear(d_model, tgt_vocab_size)

        self.init_weights()

    def init_weights(self) -> None:
        initrange = 0.1
        self.src_embedding.weight.data.uniform_(-initrange, initrange)
        self.tgt_embedding.weight.data.uniform_(-initrange, initrange)
        self.linear.bias.data.zero_()
        self.linear.weight.data.uniform_(-initrange, initrange)

    def forward(self, 
                tgt: Tensor,
                memory: Tensor,
    ): 
        """
        Args:
          tgt: [tgt_seq_len, batch_size] - the target sequence to predict the next token of (discrete MIDI tokens)
          memory: [src_seq_len, batch_size] - the discrete VQ-VAE code indices
        Returns:
          [tgt_seq_len, batch_size, tgt_vocab_size] logits output for the next token
        """
        
        src = self.src_embedding(memory) * math.sqrt(self.d_model) # Scale and Embed the source sequence
        src = self.pos_encoder(src) # Add positional encoding

        tgt = self.tgt_embedding(tgt) * math.sqrt(self.d_model) # Scale and Embed the target sequence
        tgt = self.pos_encoder(tgt) # Add positional encoding


        tgt_mask = nn.Transformer.generate_square_subsequent_mask(len(tgt)).to(device)  # Create the mask
        output = self.transformer_decoder(tgt=tgt, memory=src, tgt_mask=tgt_mask)  # Pass them through the transformer
        output = self.linear(output)  # Apply the linear layer
        return output

torch.manual_seed(0)

In [31]:
# example of instantiation of it

vocab_size = 512        # Number of unique MIDI tokens
embedding_dim = 256     # Size of token embeddings
max_len = 512          # Max sequence length (you can adjust this)
latent_dim = 128       # Latent space dimension
num_heads = 8          # Number of attention heads
num_layers = 6         # Number of transformer layers
ff_dim = 512           # Feed-forward layer dimension

encoder = Variational_Encoder(vocab_size, embedding_dim, max_len, latent_dim, num_heads, num_layers, ff_dim)

# Example input (batch of MIDI token sequences)
x = torch.randint(0, vocab_size, (32, 100))  # Batch size 32, sequence length 100

z, mu, logvar = encoder(x)
print(f"z shape: {z.shape}, mu shape: {mu.shape}, logvar shape: {logvar.shape}")


torch.Size([32, 100, 256])
z shape: torch.Size([32, 128]), mu shape: torch.Size([32, 128]), logvar shape: torch.Size([32, 128])
