In [None]:
import torch
import logging
from pretty_midi import PrettyMIDI

from pathlib import Path
from typing import Callable, Any, Dict, List

from torch.utils.data import Dataset

# Module-level logger, named after this module's ``__name__``.
logger = logging.getLogger(__name__)

In [None]:
class AbstractMusicDataset(Dataset):
    """
    Base ``Dataset`` for music data; concrete subclasses provide ``load_data``.

    The constructor stores its arguments as attributes and then immediately
    calls ``load_data()``, keeping the result in ``self.data``. Items are
    preprocessed lazily, one at a time, when they are indexed.

    Args:
        data_path (List[Path]) : Location(s) of the dataset, stored as given
        preprocess_fn (Callable): Called as
            ``preprocess_fn(sample, max_seq_len=..., pad_token=...)``
        max_seq_len (int) : Maximum sequence length forwarded to ``preprocess_fn``
        pad_token (int) : Padding token forwarded to ``preprocess_fn``
    """

    def __init__(
        self,
        data_path: List[Path],
        preprocess_fn: Callable,
        max_seq_len: int,
        pad_token: int,
    ):
        self.data_path = data_path
        self.preprocess_fn = preprocess_fn
        self.max_seq_len = max_seq_len
        self.pad_token = pad_token
        # All configuration is set before loading so that a subclass's
        # load_data() can read any of the attributes above.
        self.data = self.load_data()

    def load_data(self) -> List[Any]:
        """Return the list of raw samples; must be overridden by subclasses."""
        raise NotImplementedError("Subclasses must implement this method")

    def preprocess(self, sample: Any) -> Dict[str, Any]:
        """Run ``preprocess_fn`` on one raw sample with this dataset's settings."""
        transform = self.preprocess_fn
        options = {
            "max_seq_len": self.max_seq_len,
            "pad_token": self.pad_token,
        }
        return transform(sample, **options)

    def __len__(self) -> int:
        """Number of raw samples returned by ``load_data``."""
        return len(self.data)

    def __getitem__(self, idx: int) -> Dict[str, Any]:
        """Fetch the raw sample at ``idx`` and return its preprocessed form."""
        raw_sample = self.data[idx]
        return self.preprocess(raw_sample)

In [None]:
def _pad_only_sample(max_seq_len: int, pad_token: int) -> Dict[str, torch.Tensor]:
    """Build the all-padding fallback sample used when a file is unusable."""
    # torch.full requires a size *tuple*; ``(n)`` is just an int, ``(n,)`` is
    # a tuple. Two separate tensors are created so the entries do not alias.
    return {
        "input_ids": torch.full((max_seq_len,), pad_token, dtype=torch.long),
        "labels": torch.full((max_seq_len,), pad_token, dtype=torch.long),
    }


def midi_preprocess(sample: "PrettyMIDI", max_seq_len: int, pad_token: int,
                    sos_token: int, eos_token: int,
                    default_sample=None) -> Dict[str, torch.Tensor]:
    """
    Preprocess a MIDI sample: tokenize, truncate, and pad.

    Every note pitch is collected instrument by instrument (in the order
    the instruments and their notes are stored), wrapped in
    ``sos_token`` / ``eos_token``, cut to ``max_seq_len`` and right-padded
    with ``pad_token``.

    Args:
        sample (PrettyMIDI): MIDI file sample.
        max_seq_len (int): Length of the returned sequences.
        pad_token (int): Padding token.
        sos_token (int): Token placed at the start of the sequence.
        eos_token (int): Token placed after the last kept pitch.
        default_sample (dict, optional): Returned as-is (when truthy) instead
            of the all-padding sample if the file has no instruments or
            preprocessing fails.
    Returns:
        Dict[str, torch.Tensor]: ``"input_ids"`` and ``"labels"``, two
        independent ``torch.long`` tensors of shape ``(max_seq_len,)`` with
        equal contents (or ``default_sample`` on the fallback paths).
    """
    try:
        if not sample.instruments:
            logger.warning("MIDI file has no instruments.")
            return default_sample or _pad_only_sample(max_seq_len, pad_token)

        pitches = [
            note.pitch
            for instrument in sample.instruments
            for note in instrument.notes
        ]
        # Reserve two slots for SOS/EOS. The clamp keeps the slice from going
        # negative when max_seq_len < 2, and the final slice guarantees the
        # sequence never exceeds max_seq_len, so the pad amount is never < 0.
        kept = pitches[:max(max_seq_len - 2, 0)]
        tokens = ([sos_token] + kept + [eos_token])[:max_seq_len]

        input_ids = torch.nn.functional.pad(
            torch.tensor(tokens, dtype=torch.long),
            (0, max_seq_len - len(tokens)),
            value=pad_token,
        )
        logger.info("Processed MIDI file : %d tokens", len(tokens))
        # Clone so that in-place edits to labels (e.g. masking) cannot
        # silently modify input_ids.
        return {"input_ids": input_ids, "labels": input_ids.clone()}
    except Exception as e:
        # Deliberate best-effort: one bad file must not abort a whole epoch.
        logger.error("Error preprocessing MIDI: %s", e)
        return default_sample or _pad_only_sample(max_seq_len, pad_token)