In [None]:
!pip install miditoolkit
!pip install torchtoolkit


import json
import logging
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Union

import pandas as pd
import torch
from miditok import REMI, TokenizerConfig
from miditoolkit import MidiFile
from pretty_midi import PrettyMIDI
from torch import LongTensor
from torch.utils.data import Dataset
from torchtoolkit.data import create_subsets
from tqdm import tqdm

# Configure the root logger at INFO so the progress/warning messages emitted
# by the loaders and preprocessing functions in this file are shown.
logging.basicConfig(level=logging.INFO)
# Module-level logger used by the dataset loaders and preprocess functions.
logger = logging.getLogger(__name__)

In [None]:
class AbstractMusicDataset(Dataset):
    """
    Base class for map-style music datasets; concrete loaders subclass it.

    Construction stores the configuration and then immediately calls
    ``load_data`` to populate ``self.data``. Items are preprocessed lazily,
    on access, through ``preprocess_fn``.

    Args:
        data_path (List[Path]) : Location of the dataset; stored as-is and interpreted by the subclass
        preprocess_fn (Callable): Called as ``preprocess_fn(sample, max_seq_len=..., pad_token=...)``
        max_seq_len (int) : Maximum sequence length handed to ``preprocess_fn``
        pad_token (int) : Padding token handed to ``preprocess_fn``
        **kwargs : Extra options, kept on ``self.kwargs`` and forwarded to ``load_data``
    """
    def __init__(self, data_path:List[Path],
                preprocess_fn:Callable, max_seq_len:int,
                pad_token:int, **kwargs):
        self.kwargs = kwargs
        self.pad_token = pad_token
        self.max_seq_len = max_seq_len
        self.preprocess_fn = preprocess_fn
        self.data_path = data_path
        # Loading runs last so load_data can read every attribute set above.
        loaded = self.load_data(**self.kwargs)
        self.data = loaded

    def load_data(self, **kwargs)->List[Any]:
        """Return the raw samples for ``self.data_path``; subclasses must override this"""
        raise NotImplementedError("Subclasses must implement this method")

    def preprocess(self, sample:Any)->Dict[str,Any]:
        """Run ``preprocess_fn`` on one raw sample with the stored length and pad settings"""
        options = {"max_seq_len": self.max_seq_len, "pad_token": self.pad_token}
        return self.preprocess_fn(sample, **options)

    def __len__(self)->int:
        """Number of raw samples that were loaded"""
        return len(self.data)

    def __getitem__(self, idx:int)->Dict[str,Any]:
        """Look up the raw sample at ``idx`` and return its preprocessed form"""
        raw = self.data[idx]
        return self.preprocess(raw)

In [None]:
class MIDIDataset(AbstractMusicDataset):
    """
    Dataset whose raw samples are MIDI files parsed with ``MidiFile``.

    All files are parsed when the dataset is constructed (the base class
    calls ``load_data`` from ``__init__``), so every parsed file is held in
    memory for the lifetime of the dataset.

    Args:
        data_path: A directory, a single MIDI file, or a list of either.
            Directories are searched non-recursively.
        preprocess_fn (Callable): Function to preprocess a single sample
        max_seq_len (int): Maximum sequence length of the data
        pad_token (int): Token used for padding sequences
    """
    # Compared against the lower-cased suffix, so ".MID"/".MIDI" also match.
    _MIDI_SUFFIXES = (".mid", ".midi")

    def __init__(self, data_path:Union[Path, List[Path]],preprocess_fn:Callable,
                 max_seq_len:int, pad_token:int):
        super().__init__(data_path, preprocess_fn,
                         max_seq_len, pad_token)

    def _candidate_files(self)->List[Path]:
        """Resolve ``self.data_path`` into a sorted list of MIDI file paths."""
        roots = self.data_path
        if isinstance(roots, (str, Path)):
            roots = [roots]
        found = []
        for root in roots:
            root = Path(root)
            if root.is_dir():
                found.extend(entry for entry in root.iterdir()
                             if entry.is_file()
                             and entry.suffix.lower() in self._MIDI_SUFFIXES)
            elif root.suffix.lower() in self._MIDI_SUFFIXES:
                # Explicit file path; existence is checked when it is parsed.
                found.append(root)
            else:
                logger.error(f"MIDI directory not found: {root}")
        # Sorted so sample order does not depend on filesystem listing order.
        return sorted(found)

    def load_data(self)->List[MidiFile]:
        """
        Load every MIDI file reachable from ``self.data_path``.

        Returns:
            The parsed files. A file that cannot be parsed is logged and
            skipped rather than discarding the whole dataset; an unusable
            ``data_path`` or an empty directory yields an empty list.
        """
        try:
            files = self._candidate_files()
        except (TypeError, OSError) as e:
            logger.error(f"Error loading MIDI files: {e}")
            return []
        if not files:
            logger.warning(f"No MIDI files found in: {self.data_path}")
            return []
        midis = []
        for file in files:
            # Parsing can fail in many ways on corrupt files; keep it
            # best-effort per file and report which one was skipped.
            try:
                midis.append(MidiFile(str(file)))
            except Exception as e:
                logger.error(f"Error loading MIDI file {file}: {e}")
        return midis

In [None]:
class CSVDataset(AbstractMusicDataset):
    """
    Dataset whose raw samples are the rows of one CSV file.

    Args:
        data_path: Directory containing the CSV file (or a list whose first
            entry is that directory).
        preprocess_fn (Callable): Function to preprocess a single row dict
        csv_filename (str): File name of the CSV inside ``data_path``
        max_seq_len (int): Maximum sequence length of the data
        pad_token (int): Token used for padding sequences. Defaults to 0 so
            existing four-argument calls keep working.
    """
    def __init__(self, data_path:Union[Path, List[Path]],preprocess_fn:Callable,
                 csv_filename:str, max_seq_len:int, pad_token:int=0):
        # Assigned before super().__init__() because the base constructor
        # calls load_data() immediately.
        self.csv_filename = csv_filename
        super().__init__(data_path, preprocess_fn, max_seq_len, pad_token,
                        csv_filename = csv_filename)

    def load_data(self, csv_filename:str)->List[Dict[str, Any]]:
        """
        Load the CSV file into a list of dictionaries, one per row.

        Args:
            csv_filename: File name of the CSV inside ``self.data_path``.

        Returns:
            Row dictionaries keyed by column name, or an empty list when the
            file is missing, empty or cannot be parsed (the error is logged).
        """
        root = self.data_path
        if not isinstance(root, (str, Path)):
            root = root[0]
        csv_path = Path(root)/csv_filename
        try:
            df = pd.read_csv(csv_path)
        except FileNotFoundError:
            logger.error(f"CSV file not found: {csv_path}")
            return []
        except (pd.errors.ParserError, pd.errors.EmptyDataError) as e:
            logger.error(f"Error parsing CSV file: {e}")
            return []
        except Exception as e:
            logger.error(f"An unexpected error occured while loading or parsing CSV file: {e}")
            return []
        return df.to_dict(orient="records")

In [None]:
class JSONDataset(AbstractMusicDataset):
    """
    Dataset whose raw samples are the elements of one JSON array file.

    Args:
        data_path: Directory containing the JSON file (or a list whose first
            entry is that directory).
        preprocess_fn (Callable): Function to preprocess a single sample
        max_seq_len (int): Maximum sequence length of the data
        pad_token (int): Token used for padding sequences
        json_filename (str): File name of the JSON file inside ``data_path``
    """
    def __init__(self, data_path:Union[Path, List[Path]], preprocess_fn:Callable,
                max_seq_len:int, pad_token:int, json_filename:str):
        # Assigned before super().__init__() because the base constructor
        # calls load_data() immediately.
        self.json_filename = json_filename
        super().__init__(data_path, preprocess_fn, max_seq_len, pad_token,
                        json_filename = json_filename)

    def load_data(self, json_filename:str)->List[Dict[str, Any]]:
        """
        Load a JSON file containing the dataset.

        Args:
            json_filename: File name of the JSON file inside ``self.data_path``.

        Returns:
            The decoded top-level array, or an empty list when the file is
            missing, is not valid JSON, or does not hold an array (the error
            is logged).
        """
        root = self.data_path
        if not isinstance(root, (str, Path)):
            root = root[0]
        json_path = Path(root)/json_filename
        try:
            with open(json_path, "r", encoding="utf-8") as f:
                records = json.load(f)
        except FileNotFoundError:
            logger.error(f"JSON file not found: {json_path}")
            return []
        except json.JSONDecodeError as e:
            logger.error(f"Error decoding JSON: {e}")
            return []
        except Exception as e:
            logger.error(f"An unexpected error occured while loading or parsing JSON file: {e}")
            return []
        # Samples are fetched by integer index, so anything but an array
        # (e.g. a top-level object) cannot serve as the sample list.
        if not isinstance(records, list):
            logger.error(f"Expected a JSON array of samples in {json_path}, "
                         f"got {type(records).__name__}")
            return []
        return records

In [None]:
def midi_preprocess(sample:PrettyMIDI, max_seq_len:int, pad_token:int,
                    sos_token:int, eos_token:int,
                    default_sample:Optional[Dict[str, torch.Tensor]]=None)->Dict[str, torch.Tensor]:
    """
    Preprocess a MIDI sample: collect note pitches, truncate, and pad.

    The token sequence is the ``pitch`` of every note, taken instrument by
    instrument in the order the notes are stored, wrapped in ``sos_token``
    and ``eos_token`` and right-padded with ``pad_token``.

    ``AbstractMusicDataset.preprocess`` only passes ``max_seq_len`` and
    ``pad_token``; bind ``sos_token``/``eos_token`` beforehand (for example
    with ``functools.partial``) when using this as a ``preprocess_fn``.

    Args:
        sample: MIDI object exposing ``instruments``, each with ``notes``
            that carry a ``pitch``.
        max_seq_len (int): Length of the returned tensors.
        pad_token (int): Padding token.
        sos_token (int): Start-of-sequence token placed first.
        eos_token (int): End-of-sequence token placed after the pitches.
        default_sample: Returned instead of an all-padding sample when the
            input has no instruments or preprocessing fails.
    Returns:
        Dict[str, torch.Tensor]: ``input_ids`` and ``labels``, two separate
        ``torch.long`` tensors of length ``max_seq_len`` with equal content.
    """
    def _fallback()->Dict[str, torch.Tensor]:
        # Both keys are always present so batches keep a uniform structure.
        if default_sample:
            return default_sample
        return {
            "input_ids":torch.full((max_seq_len,), pad_token, dtype=torch.long),
            "labels":torch.full((max_seq_len,), pad_token, dtype=torch.long),
        }

    try:
        if not sample.instruments:
            logger.warning("MIDI file has no instruments.")
            return _fallback()
        pitches = [note.pitch
                   for instrument in sample.instruments
                   for note in instrument.notes]
        # Reserve two slots for SOS/EOS; the clamps keep tiny max_seq_len
        # values from turning into negative slice bounds.
        body = pitches[:max(max_seq_len-2, 0)]
        tokens = ([sos_token] + body + [eos_token])[:max(max_seq_len, 0)]
        padded_tokens = torch.nn.functional.pad(
            torch.tensor(tokens, dtype=torch.long),
            (0, max_seq_len-len(tokens)),
            value=pad_token)
        logger.info(f"Processed MIDI file : {len(tokens)} tokens")
        # labels is a copy so in-place edits to one tensor (e.g. masking
        # padding in the labels) cannot silently change the other.
        return {"input_ids":padded_tokens,
                "labels":padded_tokens.clone()}
    except Exception as e:
        logger.error(f"Error preprocessing MIDI: {e}")
        return _fallback()

def csv_preprocess(sample:Dict[str, Any],
                  max_seq_len:int,
                  pad_token:int,
                  sos_token:int,
                  eos_token:int,
                  default_sample:Optional[Dict[str, torch.Tensor]]=None)->Dict[str, torch.Tensor]:
    """
    Preprocess CSV data for model input.

    ``AbstractMusicDataset.preprocess`` only passes ``max_seq_len`` and
    ``pad_token``; bind ``sos_token``/``eos_token`` beforehand (for example
    with ``functools.partial``) when using this as a ``preprocess_fn``.

    Args:
        sample: A dictionary whose 'notes' entry is either a sequence of
            integer tokens or a string encoding one, e.g. ``"[60, 62, 64]"``,
            ``"60,62,64"`` or ``"60 62 64"``.
        max_seq_len: The maximum sequence length.
        pad_token: The token used for padding.
        sos_token: The start-of-sequence token.
        eos_token: The end-of-sequence token.
        default_sample: Optional default sample for invalid data.

    Returns:
        A dictionary with 'input_ids' and 'labels': two separate
        ``torch.long`` tensors of length ``max_seq_len`` with equal content.
        Invalid input yields ``default_sample`` or an all-padding sample.
    """
    try:
        # For validating input format
        if not isinstance(sample, dict) or "notes" not in sample:
            raise ValueError("Invalid CSV sample format: missing 'notes' key.")
        tokens = sample["notes"]
        if isinstance(tokens, str):
            # A CSV cell cannot hold a list, so rows read from a file carry
            # the tokens as text; int() raises ValueError on bad fields.
            fields = tokens.strip("[]() \t\r\n").replace(",", " ").split()
            tokens = [int(field) for field in fields]
        else:
            tokens = list(tokens)
        # Check token validity
        if not all(isinstance(token, int) for token in tokens):
            raise TypeError("CSV tokens must be integers.")

        # Add special tokens and truncate; the clamps keep tiny max_seq_len
        # values from turning into negative slice bounds.
        body = tokens[:max(max_seq_len-2, 0)]
        tokens = ([sos_token] + body + [eos_token])[:max(max_seq_len, 0)]
        # Pad the sequence
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                               (0, max_seq_len-len(tokens)),
                                               value=pad_token)
        logger.info(f"Processed CSV data: {len(tokens)} tokens (max {max_seq_len}).")
        # labels is a copy so in-place edits to one tensor cannot change
        # the other.
        return {
            "input_ids": padded_tokens,
            "labels":padded_tokens.clone()
        }
    except (ValueError, TypeError) as e:
        logger.error(f"Error preprocessing CSV: {e}")
        if default_sample:
            return default_sample
        return {
            "input_ids":torch.full((max_seq_len,), pad_token, dtype=torch.long),
            "labels":torch.full((max_seq_len,), pad_token, dtype=torch.long)
        }

def json_preprocess(sample:Dict[str, Any],
                   max_seq_len:int,
                   pad_token:int,
                   sos_token:int,
                   eos_token:int,
                   default_sample:Optional[Dict[str, torch.Tensor]]=None)->Dict[str,torch.Tensor]:
    """
    Preprocess JSON data for model input.

    ``eos_token`` is an explicit parameter, in the same position as in
    ``csv_preprocess`` and ``midi_preprocess``; the body always used it but
    the signature previously did not declare it.

    ``AbstractMusicDataset.preprocess`` only passes ``max_seq_len`` and
    ``pad_token``; bind ``sos_token``/``eos_token`` beforehand (for example
    with ``functools.partial``) when using this as a ``preprocess_fn``.

    Args:
        sample: A dictionary containing 'sequence' as a key with a list of integer tokens.
        max_seq_len: The maximum sequence length.
        pad_token: The token used for padding.
        sos_token: The start-of-sequence token.
        eos_token: The end-of-sequence token.
        default_sample: Optional default sample for invalid data.

    Returns:
        A dictionary with 'input_ids' and 'labels': two separate
        ``torch.long`` tensors of length ``max_seq_len`` with equal content.
        Invalid input yields ``default_sample`` or an all-padding sample.
    """
    try:
        # Validate input format
        if not isinstance(sample, dict) or "sequence" not in sample:
            raise ValueError("Invalid JSON sample format: missing 'sequence' key.")
        tokens = list(sample["sequence"])

        # Check token validity
        if not all(isinstance(token,int) for token in tokens):
            raise TypeError("JSON tokens must be integers.")
        # Add special tokens and truncate; the clamps keep tiny max_seq_len
        # values from turning into negative slice bounds.
        body = tokens[:max(max_seq_len-2, 0)]
        tokens = ([sos_token] + body + [eos_token])[:max(max_seq_len, 0)]
        # Pad the sequence
        padded_tokens = torch.nn.functional.pad(torch.tensor(tokens, dtype=torch.long),
                                               (0, max_seq_len-len(tokens)), value=pad_token)
        logger.info(f"Processed JSON data: {len(tokens)} tokens (max {max_seq_len}).")
        # labels is a copy so in-place edits to one tensor cannot change
        # the other.
        return {"input_ids":padded_tokens,
                "labels":padded_tokens.clone()}
    except (ValueError, TypeError) as e:
        logger.error(f"Error preprocessing JSON:{e}")
        if default_sample:
            return default_sample
        return {"input_ids":torch.full((max_seq_len,), pad_token, dtype=torch.long),
                "labels":torch.full((max_seq_len,), pad_token, dtype=torch.long)}

### **Preprocessing for the specific MaestroV2 Dataset**

In [None]:
class MaestroDataset(Dataset):
    """
    A dataset for processing Maestro MIDI files.

    With ``preprocess=True`` and an ``output_dir``, every MIDI file in
    ``file_paths`` is tokenized with REMI and written to
    ``output_dir/<stem>_tokens.json``; the dataset is then built from all
    ``*.json`` files in ``output_dir``. Otherwise ``file_paths`` must
    already point at token JSON files of the form ``{"ids": [...]}``.

    Args:
        file_paths (list): List of paths to MIDI or JSON files.
        min_seq (int): Minimum sequence length; a trailing chunk with
            ``min_seq`` tokens or fewer is discarded.
        max_seq (int): Maximum sequence length; must be positive.
        pad_token (int): The token used for padding sequences. It is only
            stored on the instance; samples are returned unpadded.
        tokenizer_config (TokenizerConfig, optional): Configuration for the tokenizer (default: None)
        preprocess (bool): Whether to preprocess and save tokenized files.
        output_dir (Path): Directory to save preprocessed token files.

    Attributes:
        samples (list): List of tokenized sequences.
    """
    # Path.suffix includes the leading dot; compared lower-cased.
    _MIDI_SUFFIXES = frozenset({".mid", ".midi"})

    def __init__(self, file_paths:List[Path],
                min_seq:int,
                max_seq:int,
                pad_token:int,
                tokenizer_config: Optional[TokenizerConfig]=None,
                preprocess:bool=True,
                output_dir:Optional[Path]=None):
        self.samples = []
        self.pad_token = pad_token
        # Preprocessing if needed
        if preprocess and output_dir is not None:
            output_dir = Path(output_dir)
            self._preprocessing_(file_paths, tokenizer_config, output_dir)
            # Sorted so sample order does not depend on directory listing order.
            file_paths = sorted(output_dir.glob("*.json"))
        # Load preprocessed tokens
        self.load_samples(file_paths, min_seq, max_seq)

    def _preprocessing_(self, file_paths:List[Path],
                       tokenizer_config:Optional[TokenizerConfig],
                       output_dir:Path):
        """Tokenize each MIDI file and save its token ids as JSON in ``output_dir``."""
        output_dir.mkdir(parents=True, exist_ok=True)
        # One tokenizer serves every file; a bad config surfaces here once
        # instead of as one warning per file.
        tokenizer = REMI(tokenizer_config) if tokenizer_config is not None else REMI()
        for midi_path in tqdm(file_paths, desc="Preprocessing MIDI files"):
            midi_path = Path(midi_path)
            if midi_path.suffix.lower() not in self._MIDI_SUFFIXES:
                continue # Skip non-MIDI files
            try:
                midi = MidiFile(str(midi_path))
                encoded = tokenizer.midi_to_tokens(midi)
                # NOTE(review): depending on tokenizer configuration this is
                # a single token sequence or a list with one per track --
                # confirm against the installed miditok version.
                sequences = encoded if isinstance(encoded, (list, tuple)) else [encoded]
                tokens = [token for sequence in sequences for token in sequence.ids]
                # Save tokens to JSON
                output_file = output_dir/f"{midi_path.stem}_tokens.json"
                with open(output_file, "w", encoding="utf-8") as f:
                    json.dump({"ids":tokens}, f)
            except Exception as e:
                # Best-effort: one unreadable file must not abort the run.
                logger.warning(f"Error processing {midi_path}: {e}")

    def load_samples(self, file_paths:List[Path],
                        min_seq:int,
                        max_seq:int):
        """
        Load tokenized samples and create sequences.

        Each file's ``"ids"`` list is cut into consecutive chunks of
        ``max_seq`` tokens; cutting stops once ``min_seq`` tokens or fewer
        remain. Files that cannot be read are logged and skipped.

        Raises:
            ValueError: If ``max_seq`` is not positive (the chunking could
                never advance).
        """
        if max_seq <= 0:
            raise ValueError(f"max_seq must be positive, got {max_seq}")
        for file_path in tqdm(file_paths, desc="Loading tokenized files"):
            try:
                with open(file_path, "r", encoding="utf-8") as f:
                    tokens = json.load(f)["ids"]
                # Create fixed length sequences
                for start in range(0, len(tokens), max_seq):
                    if start >= len(tokens)-min_seq:
                        break # Remaining tail is too short
                    self.samples.append(LongTensor(tokens[start:start+max_seq]))
            except Exception as e:
                logger.warning(f"Error loading {file_path}: {e}")

    def __getitem__(self, idx)->Dict[str, LongTensor]:
        """Return the sequence at ``idx`` under both 'input_ids' and 'labels'."""
        return {"input_ids":self.samples[idx],
               "labels":self.samples[idx]}

    def __len__(self)->int:
        """Number of sequences created from the token files."""
        return len(self.samples)

    def __repr__(self):
        return self.__str__()

    def __str__(self)->str:
        return "No data loaded" if len(self)==0 else f"{len(self.samples)} samples"