In [None]:
!pip install miditoolkit
!pip install torchtoolkit


import json
import torch
import logging
import pandas as pd
from pretty_midi import PrettyMIDI

from pathlib import Path
from typing import Callable, Any, Dict, List, Optional
from miditoolkit import MidiFile
from miditok import REMI, TokenizerConfig
from torchtoolkit.data import create_subsets
from tqdm import tqdm

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch import LongTensor
from torch.nn.utils.rnn import pad_sequence

# Notebook-wide logging: configure the root logger at INFO level, and create the
# module-level `logger` used by the dataset loaders and preprocessing functions below.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

In [None]:
class AbstractMusicDataset(Dataset):
    """
    Base class for the music datasets in this notebook.

    Subclasses implement ``load_data``. This class stores the configuration,
    loads every raw sample once at construction time, and runs
    ``preprocess_fn`` lazily on a sample each time it is indexed.

    Args:
        data_path (List[Path]) : Location of the dataset; how it is
            interpreted is decided by the subclass's ``load_data``.
        preprocess_fn (Callable): Called as
            ``preprocess_fn(sample, max_seq_len=..., pad_token=...)``.
            Any other required argument (e.g. ``sos_token`` / ``eos_token``
            of the preprocess functions in this notebook) must be bound
            beforehand, for example with ``functools.partial``.
        max_seq_len (int) : Forwarded to ``preprocess_fn``.
        pad_token (int) : Forwarded to ``preprocess_fn``.
        **kwargs: Kept on ``self.kwargs`` and passed to ``load_data``.
    """
    def __init__(self, data_path:List[Path],
                preprocess_fn:Callable, max_seq_len:int,
                pad_token:int, **kwargs):
        # Configuration is stored before load_data() runs so that subclasses
        # can read these attributes from inside load_data().
        self.data_path, self.preprocess_fn = data_path, preprocess_fn
        self.max_seq_len, self.pad_token = max_seq_len, pad_token
        self.kwargs = kwargs
        self.data = self.load_data(**kwargs)

    def load_data(self, **kwargs)->List[Any]:
        """Return the list of raw samples; every subclass must override this."""
        raise NotImplementedError("Subclasses must implement this method")

    def preprocess(self, sample:Any)->Dict[str,Any]:
        """Apply ``preprocess_fn`` to one raw sample using this dataset's settings."""
        options = {"max_seq_len": self.max_seq_len, "pad_token": self.pad_token}
        return self.preprocess_fn(sample, **options)

    def __len__(self)->int:
        """Number of raw samples that were loaded."""
        return len(self.data)

    def __getitem__(self, idx:int)->Dict[str,Any]:
        """Fetch raw sample ``idx`` and return its preprocessed form."""
        raw_sample = self.data[idx]
        return self.preprocess(raw_sample)

In [None]:
class MIDIDataset(AbstractMusicDataset):
    """
    Dataset of MIDI files found directly inside one or more directories.

    Every ``.mid`` / ``.midi`` file (extension matched case-insensitively,
    non-recursive) is parsed eagerly with ``miditoolkit.MidiFile``;
    ``preprocess_fn`` is applied lazily when a sample is indexed.

    Args:
        data_path: A directory, or a list/tuple of directories, holding MIDI files.
        preprocess_fn: Per-sample preprocessing function.
        max_seq_len: Maximum sequence length forwarded to ``preprocess_fn``.
        pad_token: Padding token forwarded to ``preprocess_fn``.
    """
    def __init__(self, data_path:List[Path],preprocess_fn:Callable,
                 max_seq_len:int, pad_token:int):
        super().__init__(data_path, preprocess_fn,
                         max_seq_len, pad_token)

    def load_data(self)->List[MidiFile]:
        """
        Load MIDI files from the dataset directory (or directories).

        Files that fail to parse are skipped with a warning instead of
        discarding the whole dataset.

        Returns:
            Parsed ``MidiFile`` objects, in sorted path order per directory;
            an empty list if no directory or no MIDI file is found.
        """
        midi_suffixes = {".mid", ".midi"}
        directories = (list(self.data_path)
                       if isinstance(self.data_path, (list, tuple))
                       else [self.data_path])

        files = []
        for directory in map(Path, directories):
            # Path.glob/iterdir give no useful signal for a missing directory,
            # so check explicitly to report it.
            if not directory.is_dir():
                logger.error(f"MIDI directory not found: {directory}")
                continue
            files.extend(sorted(p for p in directory.iterdir()
                                if p.is_file() and p.suffix.lower() in midi_suffixes))

        if not files:
            logger.warning(f"No MIDI files found in: {self.data_path}")
            return []

        midis = []
        for file in files:
            try:
                midis.append(MidiFile(str(file)))
            except Exception as e:  # corrupt/unreadable file: skip it, keep the rest
                logger.warning(f"Error loading MIDI file {file}: {e}")
        return midis

In [None]:
class CSVDataset(AbstractMusicDataset):
    """
    Dataset backed by a single CSV file; each row becomes one sample dict.

    Args:
        data_path: Directory containing the CSV file, or a list/tuple of
            paths whose first entry is that directory.
        preprocess_fn: Per-sample preprocessing function.
        csv_filename: Name of the CSV file inside the data directory.
        max_seq_len: Maximum sequence length forwarded to ``preprocess_fn``.
        pad_token: Padding token forwarded to ``preprocess_fn``. Defaults to 0;
            NOTE(review): confirm this matches the tokenizer's PAD id.
    """
    def __init__(self, data_path:List[Path],preprocess_fn:Callable,
                 csv_filename:str, max_seq_len:int, pad_token:int=0):
        # Must be set before super().__init__(), which calls load_data() immediately.
        self.csv_filename = csv_filename
        super().__init__(data_path, preprocess_fn, max_seq_len, pad_token,
                         csv_filename=csv_filename)

    def load_data(self, csv_filename:str)->List[Dict[str, Any]]:
        """
        Load the CSV file into a list of row dictionaries.

        NOTE(review): values are whatever ``pandas.read_csv`` parses; a
        list-valued column such as ``notes`` presumably arrives as a string
        and must be decoded before ``csv_preprocess`` accepts it.

        Args:
            csv_filename: Name of the CSV file inside the data directory.
        Returns:
            One dict per row, or an empty list if the file is missing or
            cannot be parsed.
        """
        csv_path = None
        try:
            base_dir = (self.data_path[0] if isinstance(self.data_path, (list, tuple))
                        else self.data_path)
            csv_path = Path(base_dir) / csv_filename
            df = pd.read_csv(csv_path)
        except FileNotFoundError:
            logger.error(f"CSV file not found: {csv_path}")
            return []
        except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            logger.error(f"Error parsing CSV file: {e}")
            return []
        except Exception as e:
            logger.error(f"An unexpected error occured while loading or parsing CSV file: {e}")
            return []
        return df.to_dict(orient="records")

In [None]:
class JSONDataset(AbstractMusicDataset):
    """
    Dataset backed by a single JSON file containing an array of samples.

    Args:
        data_path: Directory containing the JSON file, or a list/tuple of
            paths whose first entry is that directory.
        preprocess_fn: Per-sample preprocessing function.
        max_seq_len: Maximum sequence length forwarded to ``preprocess_fn``.
        pad_token: Padding token forwarded to ``preprocess_fn``.
        json_filename: Name of the JSON file inside the data directory.
    """
    def __init__(self, data_path:List[Path], preprocess_fn:Callable,
                max_seq_len:int, pad_token:int, json_filename:str):
        # Must be set before super().__init__(), which calls load_data() immediately.
        self.json_filename = json_filename
        super().__init__(data_path, preprocess_fn, max_seq_len, pad_token,
                        json_filename = json_filename)

    def load_data(self, json_filename:str)->List[Dict[str, Any]]:
        """
        Load a JSON file containing the dataset.

        Args:
            json_filename: Name of the JSON file inside the data directory.
        Returns:
            The decoded top-level JSON array, or an empty list if the file is
            missing, malformed, or its top level is not an array.
        """
        json_path = None
        try:
            base_dir = (self.data_path[0] if isinstance(self.data_path, (list, tuple))
                        else self.data_path)
            json_path = Path(base_dir) / json_filename
            with open(json_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError:
            logger.error(f"JSON file not found: {json_path}")
            return []
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON: {e}")
            return []
        except Exception as e:
            logger.error(f"An unexpected error occured while loading or parsing JSON file: {e}")
            return []
        # The dataset indexes samples positionally, so the top level must be a list.
        if not isinstance(data, list):
            logger.error(f"Expected a JSON array of samples in {json_path}, "
                         f"got {type(data).__name__}")
            return []
        return data

In [None]:
def midi_preprocess(sample:PrettyMIDI, max_seq_len:int, pad_token:int,
                   sos_token:int, eos_token:int
                    , default_sample=None)->Dict[str, torch.Tensor]:
    """
    Preprocess a MIDI sample: tokenize, truncate, and pad.

    Tokenization is deliberately simple: the pitch of every note is taken,
    instrument by instrument in stored order (notes of different instruments
    are not interleaved by time).

    Args:
        sample: A ``PrettyMIDI`` or ``miditoolkit.MidiFile`` object; anything
            exposing ``.instruments[*].notes[*].pitch``.
        max_seq_len (int): Length of the returned tensors.
        pad_token (int): Padding token.
        sos_token (int): Token prepended to the sequence.
        eos_token (int): Token appended after the (truncated) pitches.
        default_sample (dict, optional): Returned instead of an all-padding
            sample when the MIDI has no instruments or preprocessing fails.
    Returns:
        Dict[str, torch.Tensor]: ``{"input_ids", "labels"}``, each a 1-D
        ``torch.long`` tensor of length ``max_seq_len`` (both keys refer to
        the same tensor on the success path).
    """
    def _fallback() -> Dict[str, torch.Tensor]:
        # `or` keeps the original semantics: an empty/None default_sample is ignored.
        return default_sample or {
            "input_ids": torch.full((max_seq_len,), pad_token, dtype=torch.long),
            "labels": torch.full((max_seq_len,), pad_token, dtype=torch.long),
        }

    try:
        if not sample.instruments:
            logger.warning("MIDI file has no instruments.")
            return _fallback()
        pitches = [note.pitch
                   for instrument in sample.instruments
                   for note in instrument.notes]
        # Two slots are reserved for SOS and EOS.
        tokens = [sos_token] + pitches[:max_seq_len-2] + [eos_token]
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                               (0, max_seq_len-len(tokens)),
                                               value=pad_token)
        # Debug level: this runs once per __getitem__ call.
        logger.debug("Processed MIDI file : %d tokens", len(tokens))
        return {"input_ids":padded_tokens,
               "labels":padded_tokens}
    except Exception as e:
        logger.error(f"Error preprocessing MIDI: {e}")
        return _fallback()

def csv_preprocess(sample:Dict[str, Any],
                  max_seq_len:int,
                  pad_token:int,
                  sos_token:int,
                  eos_token:int,
                  default_sample:Dict[str, torch.Tensor]=None)->Dict[str, torch.Tensor]:
    """
    Preprocess CSV data for model input.

    Args:
        sample: A dictionary whose 'notes' value is an iterable of integer tokens.
        max_seq_len: Length of the returned tensors.
        pad_token: The token used for padding.
        sos_token: The start-of-sequence token.
        eos_token: The end-of-sequence token.
        default_sample: Optional sample returned for invalid data instead of
            an all-padding sample.

    Returns:
        A dictionary with 'input_ids' and 'labels' tensors (1-D, torch.long,
        length max_seq_len; the same tensor on the success path).
    """
    try:
        # Validate input format
        if not isinstance(sample, dict) or "notes" not in sample:
            raise ValueError("Invalid CSV sample format: missing 'notes' key.")
        # list() accepts any iterable (list, tuple, ...); a non-iterable value
        # raises TypeError here and falls through to the default sample.
        tokens = list(sample["notes"])
        if not all(isinstance(token, int) for token in tokens):
            raise TypeError("CSV tokens must be integers.")

        # Add special tokens and truncate (two slots reserved for SOS/EOS)
        tokens = [sos_token] + tokens[:max_seq_len-2] + [eos_token]
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                               (0, max_seq_len-len(tokens)),
                                               value=pad_token)
        # Debug level: this runs once per __getitem__ call.
        logger.debug("Processed CSV data: %d tokens (max %d).", len(tokens), max_seq_len)
        return {
            "input_ids": padded_tokens,
            "labels": padded_tokens,
        }
    except (ValueError, TypeError) as e:
        logger.error(f"Error preprocessing CSV: {e}")
        # The size must be a tuple: torch.full rejects a bare int.
        return default_sample or {
            "input_ids": torch.full((max_seq_len,), pad_token, dtype=torch.long),
            "labels": torch.full((max_seq_len,), pad_token, dtype=torch.long),
        }

def json_preprocess(sample:Dict[str, Any],
                   max_seq_len:int,
                   pad_token:int,
                   sos_token:int,
                   default_sample:Dict[str, torch.Tensor]=None,
                   eos_token:Optional[int]=None)->Dict[str,torch.Tensor]:
    """
    Preprocess JSON data for model input.

    Args:
        sample: A dictionary whose 'sequence' value is an iterable of integer tokens.
        max_seq_len: Length of the returned tensors.
        pad_token: The token used for padding.
        sos_token: The start-of-sequence token.
        default_sample: Optional sample returned for invalid data instead of
            an all-padding sample.
        eos_token: The end-of-sequence token. If None, no end token is
            appended and one more content token fits in the sequence.
            (Placed last to keep existing positional calls working.)

    Returns:
        A dictionary with 'input_ids' and 'labels' tensors (1-D, torch.long,
        length max_seq_len; the same tensor on the success path).
    """
    try:
        # Validate input format
        if not isinstance(sample, dict) or "sequence" not in sample:
            raise ValueError("Invalid JSON sample format: missing 'sequence' key.")
        # list() accepts any iterable; a non-iterable value raises TypeError here.
        tokens = list(sample["sequence"])

        # Check token validity
        if not all(isinstance(token, int) for token in tokens):
            raise TypeError("JSON tokens must be integers.")
        # Add special tokens and truncate, reserving a slot for each special token used
        end = [] if eos_token is None else [eos_token]
        tokens = [sos_token] + tokens[:max_seq_len - 1 - len(end)] + end
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                               (0, max_seq_len-len(tokens)), value=pad_token)
        # Debug level: this runs once per __getitem__ call.
        logger.debug("Processed JSON data: %d tokens (max %d).", len(tokens), max_seq_len)
        return {"input_ids":padded_tokens,
               "labels":padded_tokens}
    except (ValueError, TypeError) as e:
        logger.error(f"Error preprocessing JSON:{e}")
        # The size must be a tuple: torch.full rejects a bare int.
        return default_sample or {"input_ids":torch.full((max_seq_len,), pad_token, dtype=torch.long),
                                 "labels":torch.full((max_seq_len,), pad_token, dtype=torch.long)}

### **Preprocessing for the MAESTRO v2 Dataset**

In [None]:
class MaestroDataset(Dataset):
    """
    A dataset of fixed-length token chunks built from MAESTRO MIDI files.

    When ``preprocess`` is True and ``output_dir`` is given, every ``.mid`` /
    ``.midi`` path in ``file_paths`` is tokenized with miditok's REMI and
    written to ``output_dir/<stem>_tokens.json`` as ``{"ids": [...]}``. The
    samples are then loaded from all ``*.json`` files in ``output_dir``.
    Otherwise ``file_paths`` must already point to such token JSON files.

    Args:
        file_paths (list): Paths to MIDI files (when preprocessing) or to token JSON files.
        min_seq (int): A chunk is only started when more than ``min_seq`` tokens remain.
        max_seq (int): Maximum chunk length; must be >= 1.
        pad_token (int): The token used for padding sequences (stored on the instance).
        tokenizer_config (TokenizerConfig, optional): REMI configuration (default: None, i.e. ``REMI()`` defaults)
        preprocess (bool): Whether to tokenize the MIDI files and save them to ``output_dir``.
        output_dir (Path): Directory to save preprocessed token files.

    Attributes:
        samples (list): 1-D ``torch.long`` tensors, one per chunk.

    Raises:
        ValueError: If ``max_seq`` < 1 (a non-positive chunk length could never advance).
    """
    def __init__(self, file_paths:List[Path],
                min_seq:int,
                max_seq:int,
                pad_token:int,
                tokenizer_config: TokenizerConfig=None,
                preprocess:bool=True,
                output_dir:Path=None):
        if max_seq < 1:
            raise ValueError(f"max_seq must be >= 1, got {max_seq}")
        self.samples = []
        self.pad_token = pad_token
        # Preprocessing if needed
        if preprocess and output_dir is not None:
            output_dir = Path(output_dir)
            self._preprocessing_(file_paths, tokenizer_config, output_dir)
            # Builtin list (typing.List cannot be instantiated); sorted for a deterministic order.
            file_paths = sorted(output_dir.glob("*.json"))
        # Load preprocessed tokens
        self.load_samples(file_paths, min_seq, max_seq)

    def _preprocessing_(self, file_paths:List[Path],
                       tokenizer_config:TokenizerConfig,
                       output_dir:Path):
        """Tokenize each MIDI path with REMI and save ``{"ids": [...]}`` JSON files.

        Non-MIDI paths are skipped; per-file failures are logged and skipped.
        """
        output_dir.mkdir(parents=True, exist_ok=True)
        # One tokenizer for the whole run instead of one per file.
        tokenizer = REMI(tokenizer_config) if tokenizer_config is not None else REMI()
        midi_suffixes = {".mid", ".midi"}
        for midi_path in tqdm(file_paths, desc="Preprocessing MIDI files"):
            path = Path(midi_path)
            # Path.suffix includes the leading dot (e.g. ".midi").
            if path.suffix.lower() not in midi_suffixes:
                continue # Skip non-MIDI files
            try:
                midi = MidiFile(str(path))
                encoded = tokenizer.midi_to_tokens(midi)
                # Depending on the tokenizer configuration, miditok may return one
                # TokSequence per track or a single TokSequence; concatenate all ids.
                sequences = encoded if isinstance(encoded, (list, tuple)) else [encoded]
                tokens = [token for seq in sequences for token in seq.ids]
                # Save tokens to JSON
                output_file = output_dir/f"{path.stem}_tokens.json"
                with open(output_file, "w", encoding="utf-8") as f:
                    json.dump({"ids":tokens}, f)
            except Exception as e:
                logger.warning(f"Error processing {midi_path}: {e}")

    def load_samples(self, file_paths:List[Path],
                        min_seq:int,
                        max_seq:int):
        """Load tokenized JSON files and split each into consecutive chunks.

        Chunks are non-overlapping and at most ``max_seq`` tokens long. A
        chunk starts at offset ``i`` only while ``i < len(tokens) - min_seq``,
        so the final chunk may be shorter than ``max_seq``. Unreadable files
        are logged and skipped.
        """
        for file_path in tqdm(file_paths, desc="Loading tokenized files"):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    tokens = json.load(f)["ids"]
                # Chunk starts: 0, max_seq, 2*max_seq, ... strictly before both
                # len(tokens) and len(tokens) - min_seq.
                stop = min(len(tokens), len(tokens) - min_seq)
                for start in range(0, stop, max_seq):
                    self.samples.append(
                        torch.tensor(tokens[start:start + max_seq], dtype=torch.long))
            except Exception as e:
                logger.warning(f"Error loading {file_path}: {e}")

    def __getitem__(self, idx)->Dict[str, LongTensor]:
        return {"input_ids":self.samples[idx],
               "labels":self.samples[idx]}

    def __len__(self)->int:
        return len(self.samples)

    def __repr__(self)->str:
        return self.__str__()

    # Backward-compatible alias for the original misspelled method name.
    __regr__ = __repr__

    def __str__(self)->str:
        return "No data loaded" if len(self)==0 else f"{len(self.samples)} samples"

In [None]:
def collate_fn(batch:List[Dict[str, LongTensor]],
              pad_token:int)->Dict[str,LongTensor]:
    """
    Right-pad a batch of variable-length samples to a common length.

    Meant to be used as a DataLoader ``collate_fn`` (bind ``pad_token``
    with ``functools.partial``).

    Args:
        batch: Samples, each a dict holding 1-D 'input_ids' and 'labels' tensors.
        pad_token: Value written into the padded positions.
    Returns:
        {'input_ids': ..., 'labels': ...}, each a (batch, longest) tensor cast to long.
    """
    def _pad_column(key: str) -> LongTensor:
        column = [sample[key] for sample in batch]
        return pad_sequence(column, batch_first=True, padding_value=pad_token).long()

    return {key: _pad_column(key) for key in ("input_ids", "labels")}

# **Model Architecture**
### 1. Attention Mechanism

In [None]:
class RelativeAttention(nn.Module):
    """
    Multi-head attention with learned relative-position logits.

    For query position i and key position j (per head) the logit is
        (q_i . k_j + q_i . Er[j - i + max_seq_len - 1]) / sqrt(head_dim),
    where ``Er`` holds one learned vector per relative distance in
    [-(max_seq_len - 1), max_seq_len - 1]. Queries and keys may have
    different lengths, so the module also works for cross-attention.

    Args:
         d_model: dimensional input/output of the model (must be divisible by num_heads)
         num_heads: number of attention heads
         max_seq_len: longest query/key length covered by the relative table;
             larger distances are clamped to the table edge.
    """
    def __init__(self, d_model:int, num_heads:int, max_seq_len:int):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_model = d_model
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.head_dim = d_model//num_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        # One row per relative distance -(max_seq_len-1) .. +(max_seq_len-1).
        self.Er = nn.Parameter(torch.randn(max_seq_len*2-1, self.head_dim))
        self.Wo = nn.Linear(d_model, d_model) # Turn the context vector to desired output dimension

        self.logger = logging.getLogger(__name__)

    def forward(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor,
               mask:Optional[torch.Tensor]=None)->torch.Tensor:
        """
        Forward pass of the relative attention mechanism.

        Args:
            q: Query tensor of shape (B, T_q, d_model).
            k: Key tensor of shape (B, T_k, d_model).
            v: Value tensor of shape (B, T_k, d_model).
            mask: Tensor broadcastable to (B, H, T_q, T_k) after slicing to
                [:, :, :T_q, :T_k]; positions where it equals 0 are not
                attended. None disables masking.

        Returns:
            Context vector of shape (B, T_q, d_model).
        """
        B, T_q, _ = q.shape
        T_k = k.shape[1]
        H = self.num_heads
        q = self.Wq(q).view(B, T_q, H, self.head_dim).transpose(1, 2) # (B,H,T_q,d_k)
        k = self.Wk(k).view(B, T_k, H, self.head_dim).transpose(1, 2) # (B,H,T_k,d_k)
        v = self.Wv(v).view(B, T_k, H, self.head_dim).transpose(1, 2) # (B,H,T_k,d_k)

        scores = torch.matmul(q, k.transpose(2, 3))         # (B,H,T_q,T_k)
        scores = scores + self._relative_logits(q, T_k)     # (B,H,T_q,T_k)
        scores = scores / (self.head_dim ** .5)

        if mask is not None:
            # A finite fill (instead of -inf) keeps fully-masked rows from
            # turning into NaN after softmax; masked weights still underflow to 0.
            scores = scores.masked_fill(mask[:, :, :T_q, :T_k] == 0,
                                        torch.finfo(scores.dtype).min)
        attention = F.softmax(scores, dim=-1)
        context = torch.matmul(attention, v).transpose(1, 2).contiguous().view(B, T_q, self.d_model) # (B,T_q,d_model)
        # Guard: formatting the whole attention tensor is expensive, so only do it when DEBUG is on.
        if self.logger.isEnabledFor(logging.DEBUG):
            self.logger.debug(f"Attention weights: {attention}")
        return self.Wo(context)

    def _relative_logits(self, q:torch.Tensor, key_len:int)->torch.Tensor:
        """
        Relative-position logits ``q_i . Er[j - i + max_seq_len - 1]``.

        Args:
            q: Projected queries of shape (B, H, T_q, head_dim).
            key_len: Number of key positions T_k.
        Returns:
            Tensor of shape (B, H, T_q, T_k).
        """
        B, H, T_q, _ = q.shape
        QEr = torch.matmul(q, self.Er.transpose(0, 1))  # (B,H,T_q,2*max_seq_len-1)
        query_pos = torch.arange(T_q, device=q.device).unsqueeze(1)     # (T_q,1)
        key_pos = torch.arange(key_len, device=q.device).unsqueeze(0)   # (1,T_k)
        rel_index = (key_pos - query_pos + self.max_seq_len - 1).clamp(0, 2 * self.max_seq_len - 2)
        return torch.gather(QEr, -1, rel_index.expand(B, H, T_q, key_len))

### 2. Encoder and Decoder Layers

In [None]:
class MusicTransformerEncoderLayer(nn.Module):
    """
    Music Transformer encoder layer (post-norm).

    Self-attention, then a position-wise feed-forward network; each sublayer
    output goes through dropout, a residual connection and LayerNorm.

    Args:
        d_model: The input/output dimension of the model.
        num_heads: The number of attention heads.
        dff: The hidden dimension of the feed-forward network.
        dropout_rate: The dropout rate.
        max_seq_len: Maximum sequence length, forwarded to RelativeAttention.
    """
    def __init__(self, d_model:int, num_heads:int,
                dff:int, dropout_rate:float,
                max_seq_len:int):
        super().__init__()
        self.self_attn = RelativeAttention(d_model, num_heads, max_seq_len)
        # Feed-forward: project up to dff, ReLU, project back to d_model.
        self.ffn = nn.Sequential(nn.Linear(d_model, dff),
                                nn.ReLU(), nn.Linear(dff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x:torch.Tensor, mask:Optional[torch.Tensor]=None)->torch.Tensor:
        """
        Forward pass of the encoder layer.

        Args:
            x: Input tensor of shape (B, T, d_model).
            mask: Attention mask passed to RelativeAttention (positions where
                it equals 0 are masked), e.g. a (B, 1, 1, T) padding mask;
                None disables masking.

        Returns:
            Output tensor of shape (B, T, d_model).
        """
        attn_output = self.dropout(self.self_attn(x, x, x, mask))
        x = self.norm1(x + attn_output)
        ffn_output = self.dropout(self.ffn(x))
        return self.norm2(x + ffn_output)

class MusicTransformerDecoderLayer(nn.Module):
    """
    Music Transformer decoder layer (post-norm).

    The layer has three sublayers, in order: masked self-attention,
    attention over the encoder output, and a position-wise feed-forward
    network. Each sublayer output goes through dropout, a residual
    connection and its own LayerNorm.

    Args:
        d_model: The input/output dimension of the model.
        num_heads: The number of attention heads.
        dff: The hidden dimension of the feed-forward network.
        dropout_rate: The dropout rate.
        max_seq_len: Maximum sequence length, forwarded to RelativeAttention.
    """
    def __init__(self, d_model:int, num_heads:int, dff:int,
                dropout_rate:float, max_seq_len:int):
        super().__init__()
        self.self_attn = RelativeAttention(d_model, num_heads, max_seq_len)
        self.cross_attn = RelativeAttention(d_model, num_heads, max_seq_len)
        hidden_proj, output_proj = nn.Linear(d_model, dff), nn.Linear(dff, d_model)
        self.ffn = nn.Sequential(hidden_proj, nn.ReLU(), output_proj)
        # Registered in order norm1, norm2, norm3 (tuple unpacking assigns left to right).
        self.norm1, self.norm2, self.norm3 = (nn.LayerNorm(d_model) for _ in range(3))
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_output, tgt_mask, memory_mask):
        """
        Args:
            x: Decoder input of shape (B, T_tgt, d_model).
            enc_output: Encoder output used as keys/values for cross-attention.
            tgt_mask: Mask for the self-attention (0 = not attended).
            memory_mask: Mask for the cross-attention (0 = not attended).
        Returns:
            Tensor with the same shape as ``x``.
        """
        def _residual(inputs, sublayer_out, norm):
            return norm(inputs + self.dropout(sublayer_out))

        x = _residual(x, self.self_attn(x, x, x, tgt_mask), self.norm1)
        x = _residual(x, self.cross_attn(x, enc_output, enc_output, memory_mask), self.norm2)
        return _residual(x, self.ffn(x), self.norm3)

### 2.1 Music Transformer Model

In [None]:
class MusicTransformer(nn.Module):
    """
    Encoder-decoder Music Transformer built from relative-attention layers.

    Args:
        num_classes: The number of classes (vocabulary size).
        d_model: The input/output dimension of the model.
        num_layers: The number of encoder layers (the decoder has as many).
        num_heads: The number of attention heads.
        dff: The dimension of the feed-forward network.
        dropout_rate: The dropout rate.
        max_seq_len: The maximum sequence length (size of the positional table).
        pad_token: Token id treated as padding when building attention masks.
    """
    def __init__(self, num_classes:int, d_model:int,
                num_layers:int, num_heads:int, dff:int,
                dropout_rate:float, max_seq_len:int, pad_token:int):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.encoder = nn.ModuleList([MusicTransformerEncoderLayer(d_model, num_heads,
                                                                  dff, dropout_rate,
                                                                  max_seq_len)
                                     for _ in range(num_layers)])
        self.decoder = nn.ModuleList([MusicTransformerDecoderLayer(d_model, num_heads,
                                                                  dff, dropout_rate,
                                                                  max_seq_len)
                                     for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, num_classes)
        self.max_seq_len = max_seq_len
        self.pad_token = pad_token

        self._init_weights()

    def forward(self, src:torch.Tensor, tgt:torch.Tensor)->torch.Tensor:
        """
        Forward pass of the Music Transformer.

        Args:
            src: Source token ids of shape (B, T_src).
            tgt: Target token ids of shape (B, T_tgt).

        Returns:
            Logits of shape (B, T_tgt, num_classes).

        Raises:
            ValueError: If T_src or T_tgt exceeds max_seq_len.
        """
        B, T_src = src.shape
        B, T_tgt = tgt.shape
        longest = max(T_src, T_tgt)
        if longest > self.max_seq_len:
            raise ValueError(f"Sequence length {longest} exceeds max_seq_len={self.max_seq_len}")

        # Masks are built from token ids, before embedding (True = may be attended).
        src_padding_mask = (src != self.pad_token).unsqueeze(1).unsqueeze(2)   # (B,1,1,T_src)
        tgt_padding_mask = (tgt != self.pad_token).unsqueeze(1).unsqueeze(2)   # (B,1,1,T_tgt)
        tgt_causal_mask = torch.tril(torch.ones((T_tgt, T_tgt),
                                                device=tgt.device)).bool().unsqueeze(0).unsqueeze(0)  # (1,1,T_tgt,T_tgt)
        # A target key is visible only if it is not padding AND not in the future.
        tgt_mask = tgt_padding_mask & tgt_causal_mask                         # (B,1,T_tgt,T_tgt)

        # Token embedding plus learned absolute positional embedding
        src_pos = torch.arange(T_src, device=src.device).unsqueeze(0).expand(B, T_src)
        src = self.embedding(src) + self.pos_embedding(src_pos)

        tgt_pos = torch.arange(T_tgt, device=tgt.device).unsqueeze(0).expand(B, T_tgt)
        tgt = self.embedding(tgt) + self.pos_embedding(tgt_pos)

        # Encode
        for layer in self.encoder:
            src = layer(src, src_padding_mask)
        # Decode
        for layer in self.decoder:
            tgt = layer(tgt, src, tgt_mask, src_padding_mask)
        return self.fc(tgt)

    def _init_weights(self):
        """Xavier-uniform weights and zero biases for every Linear; LayerNorm reset to weight=1, bias=0.

        Embeddings and the relative-position tables keep their default initialization.
        """
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)