In [19]:
!pip install miditoolkit
!pip install torchtoolkit
!pip install pretty_midi
!pip install miditok

import os
import json
import torch
import logging
import pandas as pd
import random
from pretty_midi import PrettyMIDI

from pathlib import Path
from typing import Callable, Any, Dict, List, Optional
from miditoolkit import MidiFile, Instrument, Note
from miditok import REMI, TokenizerConfig
from torchtoolkit.data import create_subsets
from tqdm import tqdm

import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch import LongTensor
from torch.nn.utils.rnn import pad_sequence
import torch.optim as optim

# Configure root logging at INFO level so progress/warning messages from the helpers below are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the dataset and preprocessing code in this notebook.
logger = logging.getLogger(__name__)



In [2]:
class AbstractMusicDataset(Dataset):
    """
    Base class for music datasets driven by a user-supplied preprocessing function.

    Subclasses implement ``load_data``; every item fetched through ``__getitem__``
    is passed through ``preprocess_fn`` together with ``max_seq_len`` and
    ``pad_token``.

    Args:
        data_path (List[Path]): Location(s) of the dataset.
        preprocess_fn (Callable): Called as
            ``preprocess_fn(sample, max_seq_len=..., pad_token=...)``.
        max_seq_len (int): Maximum sequence length forwarded to ``preprocess_fn``.
        pad_token (int): Padding token id forwarded to ``preprocess_fn``.
        **kwargs: Stored on ``self.kwargs`` and forwarded to ``load_data``.
    """
    def __init__(self, data_path:List[Path],
                 preprocess_fn:Callable, max_seq_len:int,
                 pad_token:int, **kwargs):
        self.kwargs = kwargs
        self.pad_token = pad_token
        self.max_seq_len = max_seq_len
        self.preprocess_fn = preprocess_fn
        self.data_path = data_path
        # Loaded last: a subclass's load_data may read any attribute set above.
        self.data = self.load_data(**kwargs)

    def load_data(self, **kwargs)->List[Any]:
        """Return the raw samples of the dataset; subclasses must override this."""
        raise NotImplementedError("Subclasses must implement this method")

    def preprocess(self, sample:Any)->Dict[str,Any]:
        """Apply ``preprocess_fn`` to a single raw sample."""
        transform = self.preprocess_fn
        return transform(sample, max_seq_len=self.max_seq_len, pad_token=self.pad_token)

    def __len__(self)->int:
        return len(self.data)

    def __getitem__(self, idx:int)->Dict[str,Any]:
        raw_sample = self.data[idx]
        return self.preprocess(raw_sample)

In [3]:
class MIDIDataset(AbstractMusicDataset):
    """
    Dataset of MIDI files read from one or more directories.

    Every regular file directly inside the directory (or inside each directory,
    when a list is given) whose suffix is ``.mid`` or ``.midi`` (any case) is
    parsed with ``MidiFile``. Files that fail to parse are skipped with a
    warning instead of discarding the whole dataset.

    Args:
        data_path (Path | List[Path]): Directory, or list of directories, to scan.
        preprocess_fn (Callable): Applied to each ``MidiFile`` on access.
        max_seq_len (int): Forwarded to ``preprocess_fn``.
        pad_token (int): Forwarded to ``preprocess_fn``.
    """
    # Lower-case suffixes accepted as MIDI files.
    MIDI_SUFFIXES = (".mid", ".midi")

    def __init__(self, data_path:List[Path],preprocess_fn:Callable,
                 max_seq_len:int, pad_token:int):
        super().__init__(data_path, preprocess_fn,
                         max_seq_len, pad_token)

    def load_data(self)->List[MidiFile]:
        """Parse every MIDI file found in ``self.data_path``; returns [] on directory errors."""
        roots = self.data_path if isinstance(self.data_path, (list, tuple)) else [self.data_path]
        try:
            # sorted() makes the sample order independent of filesystem listing order.
            files = sorted(
                path
                for root in roots
                for path in Path(root).iterdir()
                if path.is_file() and path.suffix.lower() in self.MIDI_SUFFIXES
            )
        except FileNotFoundError:
            logger.error(f"MIDI directory not found: {self.data_path}")
            return []
        except Exception as e:
            logger.error(f"Error loading MIDI files: {e}")
            return []
        if not files:
            logger.warning(f"No MIDI files found in: {self.data_path}")
            return []
        midis = []
        for file in files:
            try:
                midis.append(MidiFile(str(file)))
            except Exception as e:
                # A single corrupt file should not empty the whole dataset.
                logger.warning(f"Skipping unreadable MIDI file {file}: {e}")
        return midis

In [4]:
class CSVDataset(AbstractMusicDataset):
    """
    Dataset whose samples are the rows of a CSV file.

    Each row becomes a dict mapping column name to cell value
    (``DataFrame.to_dict(orient="records")``).

    Args:
        data_path (Path | List[Path]): Directory holding the CSV file; when a
            list is given, its first entry is used.
        preprocess_fn (Callable): Applied to each row dict on access.
        csv_filename (str): Name of the CSV file inside ``data_path``.
        max_seq_len (int): Forwarded to ``preprocess_fn``.
        pad_token (int): Forwarded to ``preprocess_fn``.
    """
    def __init__(self, data_path:List[Path],preprocess_fn:Callable,
                 csv_filename:str, max_seq_len:int, pad_token:int):
        # Must be set before super().__init__(), which calls load_data() immediately.
        self.csv_filename = csv_filename
        super().__init__(data_path, preprocess_fn, max_seq_len, pad_token,
                         csv_filename = csv_filename)

    def load_data(self, csv_filename:str)->List[Dict[str, Any]]:
        """Load the CSV file into a list of row dictionaries; returns [] on failure."""
        if isinstance(self.data_path, (list, tuple)):
            if not self.data_path:
                logger.error("CSVDataset: data_path is an empty list")
                return []
            root = self.data_path[0]
        else:
            root = self.data_path
        csv_path = Path(root) / csv_filename
        try:
            df = pd.read_csv(str(csv_path))
            return df.to_dict(orient="records")
        except FileNotFoundError:
            logger.error(f"CSV file not found: {csv_path}")
            return []
        except pd.errors.ParserError as e:
            logger.error(f"Error parsing CSV file: {e}")
            return []
        except Exception as e:
            logger.error(f"An unexpected error occured while loading or parsing CSV file: {e}")
            return []

In [5]:
class JSONDataset(AbstractMusicDataset):
    """
    Dataset loaded from a single JSON file.

    The parsed JSON document is used directly as the sample container, so it is
    presumably a list of sample dicts (NOTE(review): a top-level object would be
    indexed by key, not position — confirm the file format).

    Args:
        data_path (Path | List[Path]): Directory holding the JSON file; when a
            list is given, its first entry is used.
        preprocess_fn (Callable): Applied to each sample on access.
        max_seq_len (int): Forwarded to ``preprocess_fn``.
        pad_token (int): Forwarded to ``preprocess_fn``.
        json_filename (str): Name of the JSON file inside ``data_path``.
    """
    def __init__(self, data_path:List[Path], preprocess_fn:Callable,
                 max_seq_len:int, pad_token:int, json_filename:str):
        # Must be set before super().__init__(), which calls load_data() immediately.
        self.json_filename = json_filename
        super().__init__(data_path, preprocess_fn, max_seq_len, pad_token,
                         json_filename = json_filename)

    def load_data(self, json_filename:str)->List[Dict[str, Any]]:
        """Load the JSON file containing the dataset; returns [] on failure."""
        if isinstance(self.data_path, (list, tuple)):
            if not self.data_path:
                logger.error("JSONDataset: data_path is an empty list")
                return []
            root = self.data_path[0]
        else:
            root = self.data_path
        json_path = Path(root) / json_filename
        try:
            with open(str(json_path), "r", encoding="utf-8") as f:
                return json.load(f)
        except FileNotFoundError:
            logger.error(f"JSON file not found: {json_path}")
            return []
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON: {e}")
            return []
        except Exception as e:
            logger.error(f"An unexpected error occured while loading or parsing JSON file: {e}")
            return []

In [6]:
def _midi_token_ids(encoded) -> List[int]:
    """
    Flatten a tokenizer result into a flat list of token ids.

    Accepts an object exposing ``.ids``, an iterable of such objects (their ids
    are concatenated in order), or an iterable of plain ids.
    """
    if hasattr(encoded, "ids"):
        return list(encoded.ids)
    ids = []
    for item in encoded:
        if hasattr(item, "ids"):
            ids.extend(item.ids)
        else:
            ids.append(item)
    return ids


def midi_preprocess(sample:PrettyMIDI, max_seq_len:int, pad_token:int,
                    sos_token:int, eos_token:int, tokenizer:REMI,
                    default_sample=None)->Dict[str, torch.Tensor]:
    """
    Preprocess a MIDI sample: tokenize, add SOS/EOS, truncate, and pad.

    Args:
        sample: MIDI object accepted by ``tokenizer.midi_to_tokens``.
        max_seq_len (int): Length of the returned tensors.
        pad_token (int): Padding token id.
        sos_token (int): Start-of-sequence token id, prepended.
        eos_token (int): End-of-sequence token id, appended after truncation.
        tokenizer: Tokenizer providing ``midi_to_tokens``.
        default_sample: Returned on failure when not None.

    Returns:
        Dict[str, torch.Tensor]: ``input_ids`` and ``labels`` (the same tensor).
        If preprocessing fails, ``default_sample`` is returned, or a sample
        consisting only of ``pad_token`` when ``default_sample`` is None.
    """
    try:
        ids = _midi_token_ids(tokenizer.midi_to_tokens(sample))
        # Leave room for SOS and EOS; max() guards against a negative slice bound.
        tokens = [sos_token] + ids[:max(max_seq_len - 2, 0)] + [eos_token]
        padded_tokens = torch.nn.functional.pad(
            torch.tensor(tokens, dtype=torch.long),
            (0, max_seq_len - len(tokens)),
            value=pad_token
        )
        return {"input_ids": padded_tokens, "labels": padded_tokens}
    except Exception as e:
        logger.error(f"MIDI preprocessing failed: {e}")
        if default_sample is not None:
            return default_sample
        # Same all-padding fallback as csv_preprocess/json_preprocess, so batches never contain None.
        return {"input_ids": torch.full((max_seq_len,), pad_token, dtype=torch.long),
                "labels": torch.full((max_seq_len,), pad_token, dtype=torch.long)}

def csv_preprocess(sample:Dict[str, Any],
                   max_seq_len:int,
                   pad_token:int,
                   sos_token:int,
                   eos_token:int,
                   default_sample:Dict[str, torch.Tensor]=None)->Dict[str, torch.Tensor]:
    """
    Preprocess one CSV row for model input.

    Args:
        sample: A dictionary with a 'notes' key holding a list of integer tokens,
            or a JSON-encoded string of such a list (as a CSV cell read by pandas
            would be, e.g. "[60, 62, 64]").
        max_seq_len: Length of the returned tensors.
        pad_token: The token used for padding.
        sos_token: The start-of-sequence token, prepended.
        eos_token: The end-of-sequence token, appended after truncation.
        default_sample: Optional sample returned for invalid data.

    Returns:
        A dictionary with 'input_ids' and 'labels' tensors (the same tensor).
        For invalid data, ``default_sample`` (if truthy) or an all-padding sample.
    """
    try:
        # Validate input format
        if not isinstance(sample, dict) or "notes" not in sample:
            raise ValueError("Invalid CSV sample format: missing 'notes' key.")
        tokens = sample["notes"]
        # pandas returns CSV cells as scalars, so a token list arrives serialised.
        if isinstance(tokens, str):
            try:
                tokens = json.loads(tokens)
            except json.JSONDecodeError as e:
                raise ValueError(f"Could not parse CSV 'notes' value: {e}") from e
        # Check token validity (a non-iterable value raises TypeError here too)
        if not all(isinstance(token, int) for token in tokens):
            raise TypeError("CSV tokens must be integers.")

        # Add special tokens and truncate; max() guards against a negative slice bound.
        tokens = [sos_token] + list(tokens)[:max(max_seq_len - 2, 0)] + [eos_token]
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                                (0, max_seq_len - len(tokens)),
                                                value=pad_token)
        # DEBUG rather than INFO: this runs once per sample and would flood the output.
        logger.debug(f"Processed CSV data: {len(tokens)} tokens (max {max_seq_len}).")
        return {
            "input_ids": padded_tokens,
            "labels": padded_tokens
        }
    except (ValueError, TypeError) as e:
        logger.error(f"Error preprocessing CSV: {e}")
        # Size must be a tuple: torch.full does not accept a bare int.
        return default_sample or {
            "input_ids": torch.full((max_seq_len,), pad_token, dtype=torch.long),
            "labels": torch.full((max_seq_len,), pad_token, dtype=torch.long)
        }

def json_preprocess(sample:Dict[str, Any],
                    max_seq_len:int,
                    pad_token:int,
                    sos_token:int,
                    default_sample:Dict[str, torch.Tensor]=None,
                    eos_token:Optional[int]=None)->Dict[str,torch.Tensor]:
    """
    Preprocess one JSON sample for model input.

    Args:
        sample: A dictionary containing 'sequence' as a key with a list of integer tokens.
        max_seq_len: Length of the returned tensors.
        pad_token: The token used for padding.
        sos_token: The start-of-sequence token, prepended.
        default_sample: Optional sample returned for invalid data.
        eos_token: Optional end-of-sequence token appended after truncation.
            When None, no EOS is added. It is placed last so that existing
            positional callers that pass ``default_sample`` keep working.

    Returns:
        A dictionary with 'input_ids' and 'labels' tensors (the same tensor).
        For invalid data, ``default_sample`` (if truthy) or an all-padding sample.
    """
    try:
        # Validate input format
        if not isinstance(sample, dict) or "sequence" not in sample:
            raise ValueError("Invalid JSON sample format: missing 'sequence' key.")
        tokens = sample["sequence"]

        # Check token validity
        if not all(isinstance(token, int) for token in tokens):
            raise TypeError("JSON tokens must be integers.")
        # Reserve one slot for SOS and, when given, one for EOS.
        suffix = [] if eos_token is None else [eos_token]
        budget = max(max_seq_len - 1 - len(suffix), 0)
        tokens = [sos_token] + list(tokens)[:budget] + suffix
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                                (0, max_seq_len - len(tokens)), value=pad_token)
        # DEBUG rather than INFO: this runs once per sample and would flood the output.
        logger.debug(f"Processed JSON data: {len(tokens)} tokens (max {max_seq_len}).")
        return {"input_ids": padded_tokens,
                "labels": padded_tokens}
    except (ValueError, TypeError) as e:
        logger.error(f"Error preprocessing JSON:{e}")
        # Size must be a tuple: torch.full does not accept a bare int.
        return default_sample or {"input_ids": torch.full((max_seq_len,), pad_token, dtype=torch.long),
                                  "labels": torch.full((max_seq_len,), pad_token, dtype=torch.long)}

### **Preprocessing for the MAESTRO v2 Dataset**

In [7]:
class MaestroDataset(Dataset):
    """
    Dataset of fixed-length token windows built from MAESTRO MIDI files or
    previously tokenised JSON files.

    Preprocessing runs when ``preprocess`` is true and ``output_dir`` is given:
    - Every MIDI input (``.mid``/``.midi``, any case) is tokenised with REMI.
    - The tokens are written to ``output_dir/<stem>_tokens.json`` as ``{"ids": [...]}``.
    - Those files, together with any ``.json`` inputs, are then loaded.

    Otherwise ``file_paths`` are loaded as-is and must be JSON files with an
    ``"ids"`` list.

    Each token list is cut into consecutive windows of at most ``max_seq``
    tokens. A window is emitted only while more than ``min_seq`` tokens remain
    from its start, so short tails are dropped.

    Args:
        file_paths (Path | list): Path(s) to MIDI or JSON files.
        min_seq (int): Windows are emitted only while more than this many tokens remain.
        max_seq (int): Maximum window length; must be positive.
        pad_token (int): Padding token id (stored on the instance).
        tokenizer_config (TokenizerConfig, optional): REMI configuration;
            ``TokenizerConfig()`` is used when None.
        preprocess (bool): Whether to tokenise MIDI inputs and save token files.
        output_dir (Path | str, optional): Directory for the token files.

    Attributes:
        samples (list): List of ``LongTensor`` windows.

    Raises:
        ValueError: If ``max_seq`` is not positive (windows would never advance).
    """
    # Lower-case suffixes treated as MIDI input during preprocessing.
    MIDI_SUFFIXES = (".mid", ".midi")

    def __init__(self, file_paths:List[Path],
                 min_seq:int,
                 max_seq:int,
                 pad_token:int,
                 tokenizer_config: TokenizerConfig=None,
                 preprocess:bool=True,
                 output_dir:Path=None):
        if max_seq <= 0:
            raise ValueError(f"max_seq must be positive, got {max_seq}")
        if not isinstance(file_paths, list):
            file_paths = [file_paths]
        file_paths = [Path(fp) for fp in file_paths]

        self.samples = []
        self.pad_token = pad_token
        if preprocess and output_dir is not None:
            written = self._preprocessing_(file_paths, tokenizer_config, output_dir)
            # Keep already-tokenised JSON inputs next to the files written now;
            # dict.fromkeys de-duplicates while preserving order.
            json_inputs = [fp for fp in file_paths if fp.suffix.lower() == ".json"]
            file_paths = list(dict.fromkeys(json_inputs + written))
        self.load_samples(file_paths, min_seq, max_seq)

    @staticmethod
    def _flatten_token_ids(encoded) -> List[int]:
        """Concatenate the ids of a tokenizer result: one sequence with ``.ids`` or an iterable of them."""
        if hasattr(encoded, "ids"):
            return list(encoded.ids)
        return [token for sequence in encoded for token in sequence.ids]

    def _preprocessing_(self, file_paths:List[Path],
                        tokenizer_config:TokenizerConfig,
                        output_dir:Path) -> List[Path]:
        """
        Tokenise the MIDI files among ``file_paths`` and save them as JSON.

        Returns:
            The token files written by this call. Files that fail are logged
            and skipped.

        Raises:
            Exception: Re-raised if ``output_dir`` cannot be created.
        """
        try:
            output_dir = Path(output_dir)
            output_dir.mkdir(parents=True, exist_ok=True)
        except Exception as e:
            logger.error(f"Failed to create output directory: {e}")
            raise

        written = []
        # Path.suffix includes the leading dot, e.g. ".midi".
        midi_paths = [p for p in file_paths if p.suffix.lower() in self.MIDI_SUFFIXES]
        if not midi_paths:
            return written
        # One tokenizer serves every file.
        tokenizer = REMI(tokenizer_config if tokenizer_config is not None else TokenizerConfig())
        for i in tqdm(midi_paths, desc="Preprocessing MIDI files"):
            try:
                midi = MidiFile(str(i))
                # Tokenise the whole file once and concatenate all returned sequences.
                tokens = self._flatten_token_ids(tokenizer.midi_to_tokens(midi))
                output_file = output_dir / f"{i.stem}_tokens.json"
                with open(output_file, "w") as f:
                    json.dump({"ids": tokens}, f)
                written.append(output_file)
            except Exception as e:
                logger.warning(f"Error processing {i}: {e}")
        return written

    def load_samples(self, file_paths:List[Path],
                     min_seq:int,
                     max_seq:int):
        """Load tokenized JSON files and append fixed-length windows to ``self.samples``."""
        for file_path in tqdm(file_paths, desc="Loading tokenized files"):
            try:
                with open(file_path, "r") as f:
                    tokens = json.load(f)["ids"]
                for start in range(0, len(tokens), max_seq):
                    # Stop once min_seq or fewer tokens remain from this start.
                    if start >= len(tokens) - min_seq:
                        break
                    self.samples.append(LongTensor(tokens[start:start + max_seq]))
            except Exception as e:
                logger.warning(f"Error loading {file_path}: {e}")

    def __getitem__(self, idx)->Dict[str, LongTensor]:
        return {"input_ids": self.samples[idx],
                "labels": self.samples[idx]}

    def __len__(self)->int:
        return len(self.samples)

    def __repr__(self):
        return self.__str__()

    # Kept so code that called the original misspelled method still works.
    __regr__ = __repr__

    def __str__(self)->str:
        return "No data loaded" if len(self) == 0 else f"{len(self.samples)} samples"

In [8]:
def collate_fn(batch:List[Dict[str, LongTensor]],
               pad_token:int)->Dict[str,LongTensor]:
    """
    Right-pad a list of samples into batch tensors.

    Args:
        batch: Samples, each a dict with 1-D 'input_ids' and 'labels' tensors.
        pad_token: Value used to fill positions beyond each sequence's length.

    Returns:
        Dict with 'input_ids' and 'labels', each a (batch, longest_len) LongTensor.
    """
    def _stack(field: str) -> LongTensor:
        # Pad every sample's tensor for this field to the longest one in the batch.
        sequences = [sample[field] for sample in batch]
        return pad_sequence(sequences, batch_first=True, padding_value=pad_token).long()

    return {"input_ids": _stack("input_ids"), "labels": _stack("labels")}

# **Model Architecture**
### 1. Attention Mechanism

In [9]:
class RelativeAttention(nn.Module):
    """
    Multi-head attention with learned relative-position logits.

    Besides the content logits ``q_i · k_j``, every (query i, key j) pair
    receives a relative logit ``q_i · Er[(max_seq_len - 1) + (j - i)]``. That is
    one learned vector per head-dimension for each signed distance
    ``j - i`` in ``[-(max_seq_len - 1), max_seq_len - 1]``.

    Queries and keys may have different lengths (cross-attention), as long as
    both are at most ``max_seq_len``.

    Args:
        d_model: Dimension of the input/output; must be divisible by num_heads.
        num_heads: Number of attention heads.
        max_seq_len: Longest query/key length the relative table supports.

    Raises:
        ValueError: If ``d_model`` is not divisible by ``num_heads``.
    """
    def __init__(self, d_model:int, num_heads:int, max_seq_len:int):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_model = d_model
        self.num_heads = num_heads
        self.max_seq_len = max_seq_len
        self.head_dim = d_model//num_heads
        self.Wq = nn.Linear(d_model, d_model)
        self.Wk = nn.Linear(d_model, d_model)
        self.Wv = nn.Linear(d_model, d_model)
        # Row r encodes the signed relative distance r - (max_seq_len - 1).
        self.Er = nn.Parameter(torch.randn(max_seq_len*2-1, self.head_dim))
        self.Wo = nn.Linear(d_model, d_model) # Turn the context vector to desired output dimension

        self.logger = logging.getLogger(__name__)

    def forward(self, q:torch.Tensor, k:torch.Tensor, v:torch.Tensor,
                mask:Optional[torch.Tensor]=None)->torch.Tensor:
        """
        Forward pass of the relative attention mechanism.

        Args:
            q: Query tensor of shape (B, T_q, d_model).
            k: Key tensor of shape (B, T_k, d_model).
            v: Value tensor of shape (B, T_k, d_model).
            mask: Optional tensor broadcastable to (B, H, T_q, T_k) after slicing
                its last two dims to ``[:T_q, :T_k]``; positions equal to 0 are
                blocked.

        Returns:
            Context tensor of shape (B, T_q, d_model).

        Raises:
            ValueError: If T_q or T_k exceeds ``max_seq_len``.
        """
        B, T_q, _ = q.shape
        T_k = k.shape[1]
        if T_q > self.max_seq_len or T_k > self.max_seq_len:
            raise ValueError(f"Sequence lengths ({T_q}, {T_k}) exceed max_seq_len={self.max_seq_len}")
        H = self.num_heads
        q = self.Wq(q).view(B, T_q, H, self.head_dim).transpose(1, 2)  # (B,H,T_q,d_k)
        k = self.Wk(k).view(B, T_k, H, self.head_dim).transpose(1, 2)  # (B,H,T_k,d_k)
        v = self.Wv(v).view(B, T_k, H, self.head_dim).transpose(1, 2)  # (B,H,T_k,d_k)

        scores = torch.matmul(q, k.transpose(2, 3))       # (B,H,T_q,T_k) content logits
        scores = scores + self._relative_logits(q, T_k)   # (B,H,T_q,T_k) position logits

        if mask is not None:
            scores = scores.masked_fill(mask[..., :T_q, :T_k] == 0, float('-inf'))
        attention = F.softmax(scores / (self.head_dim ** .5), dim=-1)
        context = torch.matmul(attention, v).transpose(1, 2).contiguous().view(B, T_q, self.d_model)
        # Lazy %-formatting: the tensor is only stringified when DEBUG is enabled.
        self.logger.debug("Attention weights: %s", attention)
        return self.Wo(context)

    def _relative_logits(self, q:torch.Tensor, t_k:int)->torch.Tensor:
        """
        Compute ``S[b, h, i, j] = q[b, h, i] · Er[(max_seq_len - 1) + (j - i)]``.

        Args:
            q: Projected queries of shape (B, H, T_q, d_k).
            t_k: Number of keys.

        Returns:
            Tensor of shape (B, H, T_q, t_k).
        """
        t_q = q.shape[-2]
        # Only distances -(t_q-1) .. (t_k-1) occur; slice exactly those rows of Er.
        rel_table = self.Er[self.max_seq_len - t_q : self.max_seq_len - 1 + t_k]  # (t_q+t_k-1, d_k)
        q_er = torch.matmul(q, rel_table.transpose(0, 1))  # (B,H,t_q,t_q+t_k-1)
        # Column (j - i) + (t_q - 1) of q_er holds the logit for query i and key j.
        offsets = (torch.arange(t_k, device=q.device).unsqueeze(0)
                   - torch.arange(t_q, device=q.device).unsqueeze(1)
                   + (t_q - 1))  # (t_q, t_k)
        index = offsets.expand(*q_er.shape[:-1], t_k)
        return torch.gather(q_er, -1, index)

### 2. Encoder and Decoder Layers

In [10]:
class MusicTransformerEncoderLayer(nn.Module):
    """
    One post-norm encoder block.

    The block applies relative self-attention and then a position-wise
    feed-forward network. Each sub-layer output goes through dropout, is added
    to its input (residual connection), and is normalised with LayerNorm.

    Args:
        d_model: Input/output feature dimension.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward network.
        dropout_rate: Dropout probability applied to both sub-layer outputs.
        max_seq_len: Longest sequence the relative attention supports.
    """
    def __init__(self, d_model:int, num_heads:int,
                 dff:int, dropout_rate:float,
                 max_seq_len:int):
        super().__init__()
        self.self_attn = RelativeAttention(d_model, num_heads, max_seq_len)
        expand = nn.Linear(d_model, dff)
        project = nn.Linear(dff, d_model)
        self.ffn = nn.Sequential(expand, nn.ReLU(), project)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x:torch.Tensor, mask:torch.Tensor)->torch.Tensor:
        """
        Run the block on ``x`` with the given attention mask.

        Args:
            x: Input tensor of shape (B, T, d_model).
            mask: Attention mask passed to ``RelativeAttention`` (0 = blocked).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        attended = self.dropout(self.self_attn(x, x, x, mask))
        hidden = self.norm1(x + attended)
        fed_forward = self.dropout(self.ffn(hidden))
        return self.norm2(hidden + fed_forward)

class MusicTransformerDecoderLayer(nn.Module):
    """
    One post-norm decoder block with three sub-layers:
    - masked relative self-attention;
    - relative cross-attention over the encoder output;
    - a feed-forward network.

    Each sub-layer output goes through dropout, is added to its input, and is
    normalised with LayerNorm.

    Args:
        d_model: Input/output feature dimension.
        num_heads: Number of attention heads.
        dff: Hidden width of the feed-forward network.
        dropout_rate: Dropout probability applied to every sub-layer output.
        max_seq_len: Longest sequence the relative attention supports.
    """
    def __init__(self, d_model:int, num_heads:int, dff:int,
                 dropout_rate:float, max_seq_len:int):
        super().__init__()
        self.self_attn = RelativeAttention(d_model, num_heads, max_seq_len)
        self.cross_attn = RelativeAttention(d_model, num_heads, max_seq_len)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, dff),
            nn.ReLU(),
            nn.Linear(dff, d_model),
        )
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout_rate)

    def forward(self, x, enc_output, tgt_mask, memory_mask):
        """
        Args:
            x: Decoder input of shape (B, T, d_model).
            enc_output: Encoder output used as keys/values for cross-attention.
            tgt_mask: Mask for decoder self-attention (0 = blocked).
            memory_mask: Mask for cross-attention over ``enc_output`` (0 = blocked).

        Returns:
            Tensor of shape (B, T, d_model).
        """
        self_context = self.self_attn(x, x, x, tgt_mask)
        x = self.norm1(x + self.dropout(self_context))
        cross_context = self.cross_attn(x, enc_output, enc_output, memory_mask)
        x = self.norm2(x + self.dropout(cross_context))
        return self.norm3(x + self.dropout(self.ffn(x)))

### 2.1 Music Transformer Model

In [11]:
class MusicTransformer(nn.Module):
    """
    Encoder-decoder Music Transformer.

    Tokens are embedded with a learned token embedding plus a learned absolute
    position embedding. They then pass through a stack of encoder layers and a
    stack of decoder layers, both using relative attention.

    Args:
        num_classes: The number of classes (vocabulary size).
        d_model: The input/output dimension of the model.
        num_layers: The number of encoder layers and of decoder layers.
        num_heads: The number of attention heads.
        dff: The dimension of the feed-forward network.
        dropout_rate: The dropout rate.
        max_seq_len: The maximum sequence length.
        pad_token: Token id treated as padding when building attention masks.
    """
    def __init__(self, num_classes:int, d_model:int,
                 num_layers:int, num_heads:int, dff:int,
                 dropout_rate:float, max_seq_len:int, pad_token:int):
        super().__init__()
        self.embedding = nn.Embedding(num_classes, d_model)
        self.pos_embedding = nn.Embedding(max_seq_len, d_model)
        self.encoder = nn.ModuleList([MusicTransformerEncoderLayer(d_model, num_heads,
                                                                   dff, dropout_rate,
                                                                   max_seq_len)
                                      for _ in range(num_layers)])
        self.decoder = nn.ModuleList([MusicTransformerDecoderLayer(d_model, num_heads,
                                                                   dff, dropout_rate,
                                                                   max_seq_len)
                                      for _ in range(num_layers)])
        self.fc = nn.Linear(d_model, num_classes)
        self.max_seq_len = max_seq_len
        self.pad_token = pad_token

        self._init_weights()

    def forward(self, src:torch.Tensor, tgt:torch.Tensor)->torch.Tensor:
        """
        Forward pass of the Music Transformer.

        Args:
            src: Encoder token ids of shape (B, T_src).
            tgt: Decoder token ids of shape (B, T_tgt).

        Returns:
            Logits of shape (B, T_tgt, num_classes).

        Raises:
            ValueError: If T_src or T_tgt exceeds ``max_seq_len``.
        """
        B_src, T_src = src.shape
        B_tgt, T_tgt = tgt.shape
        if max(T_src, T_tgt) > self.max_seq_len:
            raise ValueError(f"Sequence lengths ({T_src}, {T_tgt}) exceed max_seq_len={self.max_seq_len}")

        # Token embedding + learned absolute position embedding
        src_pos = torch.arange(T_src, device=src.device).unsqueeze(0).expand(B_src, T_src)
        src_emb = self.embedding(src) + self.pos_embedding(src_pos)
        tgt_pos = torch.arange(T_tgt, device=tgt.device).unsqueeze(0).expand(B_tgt, T_tgt)
        tgt_emb = self.embedding(tgt) + self.pos_embedding(tgt_pos)

        # Masks are built from token ids (not embeddings); True = may attend.
        src_padding_mask = (src != self.pad_token).unsqueeze(1).unsqueeze(2)  # (B,1,1,T_src)
        tgt_padding_mask = (tgt != self.pad_token).unsqueeze(1).unsqueeze(2)  # (B,1,1,T_tgt)
        tgt_causal_mask = torch.tril(torch.ones((T_tgt, T_tgt), dtype=torch.bool,
                                                device=tgt.device)).unsqueeze(0).unsqueeze(0)
        # Logical AND: a key must be both non-padding AND not in the future.
        tgt_mask = tgt_padding_mask & tgt_causal_mask  # (B,1,T_tgt,T_tgt)

        # Encode
        encoder_output = src_emb
        for layer in self.encoder:
            encoder_output = layer(encoder_output, src_padding_mask)
        # Decode
        decoder_output = tgt_emb
        for layer in self.decoder:
            decoder_output = layer(decoder_output, encoder_output, tgt_mask, src_padding_mask)
        return self.fc(decoder_output)

    def _init_weights(self):
        """Xavier-uniform Linear weights with zero bias; LayerNorm reset to identity."""
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.zeros_(module.bias)
            elif isinstance(module, nn.LayerNorm):
                nn.init.ones_(module.weight)
                nn.init.zeros_(module.bias)

In [12]:
class LabelSmoothingLoss(nn.Module):
    """
    Cross-entropy loss with label smoothing.

    The true class receives probability ``1 - smoothing``; every other class
    receives ``smoothing / (num_classes - 1)``. Positions whose target equals
    ``ignore_index`` contribute nothing, and the loss is averaged over the
    remaining positions only.

    Args:
        num_classes: Number of output classes.
        smoothing: Smoothing factor (alpha). Default is 0.1.
        ignore_index: Target value to ignore (e.g. the padding token; 0 is
            valid). Pass None to disable ignoring.
    """
    def __init__(self, num_classes:int, smoothing:float=0.1,
                 ignore_index:int=1):
        super().__init__()
        self.num_classes = num_classes
        self.smoothing = smoothing
        self.ignore_index = ignore_index

    def forward(self, pred, target):
        """
        Compute the label-smoothing loss.

        Args:
            pred: Logits of shape (..., num_classes), e.g. (B, T, C) or flattened (N, C).
            target: Class ids with shape ``pred.shape[:-1]``.

        Returns:
            Scalar loss averaged over non-ignored positions (0 if all are ignored).
        """
        log_probs = pred.log_softmax(dim=-1)
        with torch.no_grad():
            if self.ignore_index is None:
                ignore_mask = torch.zeros_like(target, dtype=torch.bool)
            else:
                ignore_mask = target == self.ignore_index
            # scatter_ needs in-range class ids; ignored rows are zeroed right after.
            safe_target = target.masked_fill(ignore_mask, 0)
            true_dist = torch.full_like(log_probs, self.smoothing / (self.num_classes - 1))
            true_dist.scatter_(-1, safe_target.unsqueeze(-1), 1.0 - self.smoothing)
            true_dist.masked_fill_(ignore_mask.unsqueeze(-1), 0.0)
        per_position = torch.sum(-true_dist * log_probs, dim=-1)
        num_valid = (~ignore_mask).sum().clamp(min=1)
        return per_position.sum() / num_valid

## **Utility Functions**

In [13]:
def build_music_transformer(num_classes:int, config:dict):
    """
    Construct a MusicTransformer from a hyper-parameter dictionary.

    Args:
        num_classes (int): Vocabulary size (number of output classes).
        config (dict): Must contain d_model, num_layers, num_heads, dff,
            dropout_rate, max_seq_len and pad_token.

    Returns:
        MusicTransformer: The configured model.

    Raises:
        KeyError: If a required hyper-parameter is missing from ``config``.
    """
    required = ("d_model", "num_layers", "num_heads", "dff",
                "dropout_rate", "max_seq_len", "pad_token")
    hyperparams = {name: config[name] for name in required}
    return MusicTransformer(num_classes=num_classes, **hyperparams)

def get_loss_function(loss_type:str, num_classes:int,
                      smoothing:float, pad_token:int):
    """
    Return the loss module named by ``loss_type``.

    Args:
        loss_type (str): "cross_entropy" (label-smoothed, ignoring ``pad_token``) or "mse".
        num_classes (int): Number of output classes (used by "cross_entropy").
        smoothing (float): Label-smoothing factor (used by "cross_entropy").
        pad_token (int): Target id ignored by "cross_entropy".

    Returns:
        nn.Module: The loss function.

    Raises:
        ValueError: For any other ``loss_type``.
    """
    if loss_type == "mse":
        return nn.MSELoss()
    if loss_type == "cross_entropy":
        return LabelSmoothingLoss(num_classes, smoothing, pad_token)
    raise ValueError(f"Unsupported loss type: {loss_type}")

def get_optimizer(optimizer_type:str,
                  model:nn.Module,
                  learning_rate:float):
    """
    Build an optimizer over all of ``model``'s parameters.

    Args:
        optimizer_type (str): "adam" or "adamw".
        model (nn.Module): The model to optimize.
        learning_rate (float): The learning rate.

    Returns:
        Optimizer: A configured optimizer instance.

    Raises:
        ValueError: For any other ``optimizer_type``.
    """
    supported = (("adam", optim.Adam), ("adamw", optim.AdamW))
    for name, optimizer_cls in supported:
        if optimizer_type == name:
            return optimizer_cls(model.parameters(), lr=learning_rate)
    raise ValueError(f"Unsupported optimizer: {optimizer_type}")

def get_scheduler(scheduler_type:str,
                  optimizer, **kwargs):
    """
    Get the learning rate scheduler.

    Args:
        scheduler_type (str): "cosine" (CosineAnnealingLR) or "step" (StepLR).
        optimizer (Optimizer): The optimizer instance.
        **kwargs: Scheduler parameters. Unknown keys are ignored.
            - cosine: ``T_max`` (default 10), ``eta_min`` (default 0.0).
            - step: ``step_size`` (default 10), ``gamma`` (default 0.1).

    Returns:
        Scheduler: A configured scheduler instance.

    Raises:
        ValueError: For any other ``scheduler_type``.
    """
    if scheduler_type == "cosine":
        return optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                    T_max=kwargs.get("T_max", 10),
                                                    eta_min=kwargs.get("eta_min", 0.0))
    elif scheduler_type == "step":
        # gamma was previously dropped, so a configured decay factor was silently ignored.
        return optim.lr_scheduler.StepLR(optimizer,
                                         step_size=kwargs.get("step_size", 10),
                                         gamma=kwargs.get("gamma", 0.1))
    else:
        raise ValueError(f"Unsupported scheduler: {scheduler_type}")

def evaluate_model(model:nn.Module,
                   dataloader:DataLoader,
                   criterion:nn.Module,
                   metrics:dict,
                   device:torch.device):
    """
    Evaluate the model with teacher forcing on a dataset.

    For each batch:
    - the model is fed ``input_ids[:, :-1]`` and compared against ``labels[:, 1:]``;
    - the loss is computed on flattened logits/targets;
    - each metric is called as ``metric_fn(outputs, shifted_targets)``.

    The model's previous train/eval mode is restored afterwards.

    Args:
        model (nn.Module): The model to evaluate; called as ``model(decoder_inputs)``.
        dataloader (DataLoader): Yields dicts with 'input_ids' and 'labels'.
        criterion (nn.Module): Loss taking (N, C) logits and (N,) targets.
        metrics (dict): Name -> metric function.
        device (torch.device): Device the batches are moved to.

    Returns:
        tuple: Per-batch average loss and a dict of per-batch average metrics.
        Both are NaN when the dataloader yields no batches.
    """
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_metrics = {name: 0.0 for name in metrics}
    num_batches = 0

    try:
        with torch.no_grad():
            for batch in dataloader:
                input_ids = batch["input_ids"].to(device)
                targets = batch["labels"].to(device)

                decoder_inputs = input_ids[:, :-1]
                decoder_targets = targets[:, 1:].contiguous()

                outputs = model(decoder_inputs)
                loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                                 decoder_targets.view(-1))
                total_loss += loss.item()

                for name, metric_fn in metrics.items():
                    # Compare against the shifted targets so shapes match the outputs.
                    total_metrics[name] += metric_fn(outputs, decoder_targets)
                num_batches += 1
    finally:
        model.train(was_training)

    if num_batches == 0:
        return float("nan"), {name: float("nan") for name in metrics}
    avg_loss = total_loss / num_batches
    avg_metrics = {name: value / num_batches for name, value in total_metrics.items()}
    return avg_loss, avg_metrics

def freeze_layers(model:nn.Module, freeze_embedding=True,
                  freeze_layers:Optional[List[int]]=None):
    """
    Freeze specified layers of the model by setting ``requires_grad = False``.

    Layer indices refer to ``model.layers`` when the model has that attribute.
    Otherwise they refer to the combined stack ``[*model.encoder, *model.decoder]``
    (the layout of ``MusicTransformer``): encoder layers come first, then
    decoder layers.

    Args:
        model (nn.Module): The model to modify; must have an ``embedding`` attribute
            when ``freeze_embedding`` is true.
        freeze_embedding (bool): Whether to freeze ``model.embedding``.
        freeze_layers (list[int] | None): Layer indices to freeze (None = none).

    Raises:
        IndexError: If an index is outside the layer stack.
    """
    if freeze_embedding:
        for param in model.embedding.parameters():
            param.requires_grad = False
    if not freeze_layers:
        return
    if hasattr(model, "layers"):
        stack = list(model.layers)
    else:
        stack = list(getattr(model, "encoder", [])) + list(getattr(model, "decoder", []))
    for idx in freeze_layers:
        for param in stack[idx].parameters():
            param.requires_grad = False


def accuracy_fn(predictions, targets, pad_token=None):
    """
    Fraction of positions where the arg-max prediction equals the target.

    Args:
        predictions: Scores of shape (..., num_classes).
        targets: Class ids with shape ``predictions.shape[:-1]``.
        pad_token: Target id excluded from the count; None counts every position,
            which also lets this be called as ``metric_fn(outputs, targets)``.

    Returns:
        float: Accuracy over counted positions, or 0.0 if none are counted.
    """
    pred_ids = predictions.argmax(dim=-1)
    correct = pred_ids == targets
    if pad_token is None:
        mask = torch.ones_like(targets, dtype=torch.bool)
    else:
        mask = targets != pad_token  # Ignore padding tokens
    total = mask.sum().item()
    if total == 0:
        return 0.0
    return (correct & mask).sum().item() / total

def save_generated_sequence(sequence, output_path, note_duration:int=480,
                            velocity:int=64, program:int=0):
    """
    Write a sequence of MIDI pitches to a file as consecutive, non-overlapping notes.

    Note ``k`` starts at tick ``k * note_duration`` and lasts ``note_duration``
    ticks. (Previously every note started at tick 0, which collapsed the whole
    sequence into a single chord.)

    Args:
        sequence: Iterable of pitches (ints, or values convertible with ``int()``).
        output_path: Destination file path.
        note_duration (int): Length of each note in ticks.
        velocity (int): Velocity of every note.
        program (int): Instrument program number of the single track.

    Raises:
        ValueError: If a pitch is outside the MIDI range [0, 127].
    """
    midi = MidiFile()
    track = Instrument(program=program)
    for step, pitch in enumerate(sequence):
        pitch = int(pitch)
        if not 0 <= pitch <= 127:
            raise ValueError(f"MIDI pitch must be in [0, 127], got {pitch} at position {step}")
        start = step * note_duration
        track.notes.append(Note(velocity=velocity, pitch=pitch,
                                start=start, end=start + note_duration))
    midi.instruments.append(track)
    midi.dump(str(output_path))

In [16]:
# Token JSON files produced by an earlier preprocessing run (searched recursively).
TOKENS_DIR = Path('/kaggle/working/preprocessedv2/')
token_paths = list(TOKENS_DIR.glob("**/*.json"))

# REMI tokenizer settings used when MIDI inputs need to be tokenised.
tokenizer_config = TokenizerConfig(
    num_velocities=32,
    use_chords=True,
    use_rests=True,
    use_tempos=True,
    use_time_signatures=True,
    beat_res={(0, 4): 8, (4, 12): 4},
)

dataset = MaestroDataset(
    file_paths=token_paths,
    min_seq=384,
    max_seq=512,
    pad_token=0,
    tokenizer_config=tokenizer_config,
    preprocess=True,
    output_dir=Path("/kaggle/working/"),
)
subset_train, subset_valid = create_subsets(dataset, split_ratio=[0.3])

Preprocessing MIDI files: 0it [00:00, ?it/s]
Loading tokenized files: 0it [00:00, ?it/s]


## **Training Loop**

In [None]:
from torch.cuda.amp import GradScaler, autocast

# Gradients are accumulated over this many mini-batches before each optimizer
# step (effective batch size = batch_size * accumulation_steps).
accumulation_steps = 4

# Load the run configuration and close the file right away instead of keeping
# it open for the whole training run.
with open("/kaggle/input/jsonconfigfile/config.json", "r") as f:
    config = json.load(f)
model_params = config["model_params"]
training_params = config["training_params"]
scheduler_params = config["scheduler"]
data_paths = config["data_paths"]

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Mixed precision is used only on CUDA. On CPU the scaler and autocast are
# explicitly disabled rather than warning on every batch.
use_amp = device.type == "cuda"
scaler = GradScaler(enabled=use_amp)

model_ = MusicTransformer(num_classes=model_params["num_classes"],
                          d_model=model_params["d_model"],
                          num_layers=model_params["num_layers"],
                          num_heads=model_params["num_heads"],
                          dff=model_params["dff"],
                          dropout_rate=model_params["dropout_rate"],
                          max_seq_len=model_params["max_seq_len"],
                          pad_token=model_params["pad_token"]).to(device)

criterion = get_loss_function(config["loss_function"], model_.embedding.num_embeddings,
                              smoothing=training_params["smoothing"],
                              pad_token=model_params["pad_token"])
optimizer = get_optimizer(config["optimizer"], model_,
                          learning_rate=training_params["learning_rate"])
scheduler = get_scheduler(scheduler_params["type"], optimizer,
                          step_size=scheduler_params["step_size"],
                          gamma=scheduler_params["gamma"])

if len(subset_train) == 0:
    raise ValueError("Training subset is empty; check that the dataset cell found token files.")
# Never request a batch larger than the training subset.
batch_size = min(training_params["batch_size"], len(subset_train))


def _collate(batch):
    """Pad a list of samples using the configured pad token."""
    return collate_fn(batch, model_params["pad_token"])


train_loader = DataLoader(subset_train, batch_size=batch_size,
                          shuffle=True, collate_fn=_collate)
val_loader = DataLoader(subset_valid, batch_size=batch_size,
                        shuffle=False, collate_fn=_collate)

metrics = {"accuracy": accuracy_fn}
num_epochs = training_params["num_epochs"]
# Generate a music sample every N epochs (defaults to every epoch).
generate_every = training_params.get("generate_every", 1)
checkpoint_dir = data_paths["checkpoint_dir"]
os.makedirs(checkpoint_dir, exist_ok=True)


def _optimizer_step():
    """Apply one AMP-aware optimizer step with gradient clipping, then reset gradients."""
    # Unscale first so the norm is compared to clip_value in true units, and
    # clip BEFORE stepping/zeroing so clipping actually affects the update.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model_.parameters(),
                                   max_norm=training_params["clip_value"])
    scaler.step(optimizer)
    scaler.update()
    optimizer.zero_grad()


def _generate_sample(length):
    """Greedily decode up to `length` tokens after a random start token; stops on pad."""
    max_len = model_params["max_seq_len"]
    model_.eval()
    with torch.no_grad():
        start_token = random.randint(1, model_params["num_classes"] - 1)
        generated = torch.tensor([[start_token]], device=device)
        for _ in range(length):
            # NOTE(review): assumes the model cannot attend over more than
            # max_seq_len positions, so only the most recent window is fed.
            context = generated[:, -max_len:]
            # NOTE(review): training calls model_(encoder, decoder) but
            # generation passes one tensor; confirm forward() supports both.
            logits = model_(context)
            next_token = logits[:, -1, :].argmax(dim=-1, keepdim=True)
            generated = torch.cat([generated, next_token], dim=1)
            if next_token.item() == model_params["pad_token"]:  # stop on pad token
                break
    return generated


best_val_loss = float("inf")
for epoch in range(num_epochs):
    model_.train()
    optimizer.zero_grad()
    epoch_loss = 0.0
    num_batches = len(train_loader)
    for i, inputs in enumerate(tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}")):
        input_ids = inputs["input_ids"].to(device)
        targets = inputs["labels"].to(device)

        # Shift for autoregressive modelling: position t predicts token t + 1.
        encoder_inputs = input_ids[:, :-1]
        decoder_inputs = input_ids[:, :-1]
        decoder_targets = targets[:, 1:]

        with autocast(enabled=use_amp):
            outputs = model_(encoder_inputs, decoder_inputs)
            loss = criterion(outputs.reshape(-1, outputs.size(-1)),
                             decoder_targets.reshape(-1)) / accumulation_steps

        scaler.scale(loss).backward()
        # Step after every `accumulation_steps` batches AND on the last batch,
        # so leftover gradients never spill over into the next epoch.
        if (i + 1) % accumulation_steps == 0 or (i + 1) == num_batches:
            _optimizer_step()
        epoch_loss += loss.item() * accumulation_steps

    train_loss = epoch_loss / num_batches
    print(f"Epoch {epoch+1} | Training Loss : {train_loss:.4f}")

    val_loss, val_metrics = evaluate_model(model=model_,
                                           dataloader=val_loader,
                                           criterion=criterion,
                                           metrics=metrics,
                                           device=device)
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model_.state_dict(), os.path.join(checkpoint_dir, "best_model.pth"))
    print(f"Epoch {epoch + 1} | Validating Loss: {val_loss:.4f} | Validating Metrics: {val_metrics}")

    if (epoch + 1) % generate_every == 0:
        generated_sequence = _generate_sample(training_params["generation_length"])
        print(f"Generated Sequence : {generated_sequence.tolist()}")
        # Index the batch row instead of squeeze() so a 1-token result stays a list.
        save_generated_sequence(generated_sequence[0].tolist(),
                                f"generated_epoch_{epoch+1}.mid")

    torch.save(model_.state_dict(), os.path.join(checkpoint_dir,
                                                 f"model_epoch{epoch + 1}.pth"))
    scheduler.step()

## **Sanity check: single-sample forward pass**

In [21]:
# Sanity check: build a tiny two-note MIDI file, push it through the dataset
# pipeline, and run one forward pass of a freshly initialised model.
# mido is imported as a module on purpose: `from mido import MidiFile` would
# shadow miditoolkit's MidiFile, which save_generated_sequence relies on.
import mido

midi_path = Path("test_midi.mid")
# The dataset is borrowed temporarily. Its original data is restored in
# `finally`, so subset_train / subset_valid keep pointing at the real samples.
original_dataset_data = dataset.data
try:
    with open("/kaggle/input/jsonconfigfile/config.json", "r") as f:
        config = json.load(f)
    model_params = config["model_params"]
    training_params = config["training_params"]
    scheduler_params = config["scheduler"]
    data_paths = config["data_paths"]

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    midi = mido.MidiFile()
    track = mido.MidiTrack()
    track.append(mido.Message('program_change', program=0, time=0))  # Set instrument to piano
    track.append(mido.Message('note_on', note=60, velocity=64, time=0))  # C4
    track.append(mido.Message('note_off', note=60, velocity=64, time=480))
    track.append(mido.Message('note_on', note=64, velocity=64, time=480))  # E4
    track.append(mido.Message('note_off', note=64, velocity=64, time=480))
    midi.tracks.append(track)
    midi.save(str(midi_path))

    dataset.data = [midi_path]
    sample = dataset[0]

    # The test uses this freshly built `model` throughout, not the trained `model_`.
    model = MusicTransformer(num_classes=model_params["num_classes"],
                             d_model=model_params["d_model"],
                             num_layers=model_params["num_layers"],
                             num_heads=model_params["num_heads"],
                             dff=model_params["dff"],
                             dropout_rate=model_params["dropout_rate"],
                             max_seq_len=model_params["max_seq_len"],
                             pad_token=model_params["pad_token"]).to(device)

    criterion = get_loss_function(config["loss_function"], model.embedding.num_embeddings,
                                  smoothing=training_params["smoothing"],
                                  pad_token=model_params["pad_token"])
    optimizer = get_optimizer(config["optimizer"], model,
                              learning_rate=training_params["learning_rate"])
    scheduler = get_scheduler(scheduler_params["type"], optimizer,
                              step_size=scheduler_params["step_size"],
                              gamma=scheduler_params["gamma"])

    # Single forward pass in training mode (no backward pass / optimizer step).
    model.train()
    # Move the sample to the model's device; otherwise this fails on GPU.
    input_ids = sample["input_ids"].unsqueeze(0).to(device)
    targets = sample["labels"].unsqueeze(0).to(device)
    decoder_inputs = input_ids[:, :-1]
    decoder_targets = targets[:, 1:]
    outputs = model(decoder_inputs)
    loss = criterion(outputs.reshape(-1, outputs.size(-1)), decoder_targets.reshape(-1))

    print("Output Shape:", outputs.shape)
    print("Loss:", loss.item())
except Exception as e:
    # logger.exception keeps the traceback, not just the message.
    logger.exception(f"Error during single sample test: {e}")
finally:
    dataset.data = original_dataset_data
    # Single clean-up point for the temporary MIDI file.
    if midi_path.exists():
        midi_path.unlink()