In [5]:
import os
from pathlib import Path
from copy import deepcopy

from mido import MidiFile
from symusic import Score

from miditok import REMI, TokenizerConfig, TokSequence
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training

import miditoolkit
from miditoolkit import MidiFile

import pretty_midi

from torch.utils.data import DataLoader




In [6]:
def extract_and_save_rh(midi_path: Path, rh_out_dir: Path):
    """Write the first instrument track of a MIDI file into ``rh_out_dir``.

    The file is loaded with ``pretty_midi``. Track 0 is treated as the right
    hand and saved under the source file's name inside ``rh_out_dir``. Files
    with fewer than two instrument tracks are reported and skipped.

    Args:
        midi_path: MIDI file to read.
        rh_out_dir: Directory that receives the right-hand-only file. It is
            created (including parents) if it does not exist yet.

    Returns:
        None. Progress is reported with ``print``.
    """
    pm = pretty_midi.PrettyMIDI(str(midi_path))

    # Both hands are expected as separate tracks; anything else is skipped.
    if len(pm.instruments) < 2:
        print(f"Skipping {midi_path.name}, less than 2 tracks")
        return

    # Track 0 is taken as the right hand. Track 1 (left hand) is not exported
    # by this function, so no container is built for it.
    # NOTE(review): the new PrettyMIDI() is created with default settings, so
    # presumably the source file's tempo / time-signature data is not carried
    # over to the written file -- confirm this is acceptable for tokenization.
    pm_rh = pretty_midi.PrettyMIDI()
    pm_rh.instruments.append(pm.instruments[0])

    # Make the function usable on its own, not only via preprocess_dataset().
    rh_out_dir.mkdir(parents=True, exist_ok=True)

    rh_out_path = rh_out_dir / midi_path.name
    pm_rh.write(str(rh_out_path))

    print(f"Saved RH: {rh_out_path.name}")


def preprocess_dataset(
    input_dir: Path,
    rh_out_dir: Path,
):
    """Extract the right-hand track of every ``.mid`` file below ``input_dir``.

    Files are discovered recursively and processed in sorted order so that
    repeated runs visit them (and log them) identically. A failure on one
    file is reported and does not stop the remaining files.

    Args:
        input_dir: Root directory searched recursively for ``*.mid`` files.
        rh_out_dir: Directory receiving the right-hand files; created if
            missing.

    Returns:
        None. Progress, warnings and a failure count are printed.
    """
    rh_out_dir.mkdir(parents=True, exist_ok=True)

    # glob() yields entries in filesystem order; sort for reproducibility.
    midi_files = sorted(input_dir.glob("**/*.mid"))
    print(f"Found {len(midi_files)} MIDI files.")

    # Outputs are written flat, keyed by file name only, so two inputs that
    # share a name in different sub-folders end up in the same output file.
    seen_names = set()
    clashing_names = set()
    for midi_path in midi_files:
        if midi_path.name in seen_names:
            clashing_names.add(midi_path.name)
        seen_names.add(midi_path.name)
    if clashing_names:
        print(
            f"Warning: {len(clashing_names)} file name(s) occur more than once "
            f"and will overwrite each other: {sorted(clashing_names)}"
        )

    failures = 0
    for midi_path in midi_files:
        try:
            extract_and_save_rh(midi_path, rh_out_dir)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort
            # the whole corpus, but it is counted and reported.
            failures += 1
            print(f"Error processing {midi_path.name}: {e}")

    if failures:
        print(f"{failures} of {len(midi_files)} file(s) could not be processed.")




# Both locations are given relative to the current working directory and are
# resolved to absolute paths before use.
base_dir = Path("../../data").resolve()  # corpus that is searched for MIDI files
right_hand_dir = Path("data_right_hand").resolve()  # receives right-hand-only copies

preprocess_dataset(input_dir=base_dir, rh_out_dir=right_hand_dir)


Found 426 MIDI files.
Saved RH: acsrnade.mid
Saved RH: adingysd.mid
Saved RH: adlyn.mid
Saved RH: afrcnpas.mid
Saved RH: agitaton.mid
Saved RH: ailucky.mid
Saved RH: aladream.mid
Saved RH: alagazam.mid
Saved RH: alexrsrb.mid
Saved RH: amazon.mid
Saved RH: amrcnbty.mid
Saved RH: amweddng.mid
Saved RH: anoma.mid
Saved RH: antointe.mid
Saved RH: applejck.mid
Saved RH: applsass.mid
Saved RH: armadill.mid
Saved RH: ashyafrc.mid
Saved RH: atgrapes.mid
Saved RH: atlanta.mid
Saved RH: augustan.mid
Saved RH: bachelrb.mid
Saved RH: bantam.mid
Saved RH: barbwire.mid
Saved RH: bck2life.mid
Saved RH: beeswax.mid
Saved RH: belleofc.mid
Saved RH: benhur.mid
Saved RH: bethena.mid
Saved RH: bfootlou.mid
Saved RH: billikin.mid
Saved RH: binkswlz.mid
Saved RH: blackand.mid
Saved RH: blackcat.mid
Saved RH: blacksmk.mid
Saved RH: blazaway.mid
Saved RH: blckbawl.mid
Saved RH: blckblue.mid
Saved RH: blkdimnd.mid
Saved RH: blkwhtgb.mid
Saved RH: blugoose.mid
Saved RH: bnchblkb.mid
Saved RH: bohemia.mid
Saved 

In [7]:
from torch.utils.data import Dataset
from miditok import REMI  # or any other tokenizer you use
import torch


class RightHandUnconditionalMIDIDataset(Dataset):
    """Right-hand MIDI files tokenized on access for causal language modelling.

    Each item corresponds to one file: it is encoded with the tokenizer,
    wrapped in BOS/EOS, cut to ``max_seq_len`` and returned as a pair of
    ``input_ids`` / ``labels`` tensors with identical content. No padding is
    applied here.
    """

    def __init__(self, rh_paths, tokenizer: "REMI", max_seq_len=1024):
        """Store the file list and look up the special-token ids.

        Args:
            rh_paths: Iterable of MIDI file paths; kept in sorted order.
            tokenizer: Tokenizer providing ``encode(path)`` and id lookup by
                token string (``tokenizer["BOS_None"]`` etc.).
            max_seq_len: Maximum number of tokens per item, BOS/EOS included.
        """
        self.rh_paths = sorted(rh_paths)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        self.bos = tokenizer["BOS_None"]
        self.eos = tokenizer["EOS_None"]
        self.pad = tokenizer["PAD_None"]

    def __len__(self):
        """Return the number of files in the dataset."""
        return len(self.rh_paths)

    @staticmethod
    def _token_ids(encoded):
        """Return the integer ids of an ``encode()`` result as a plain list.

        Accepts an object exposing ``.ids``, a one-element sequence of such
        objects (one entry per track), or an already flat sequence of ids.

        Raises:
            ValueError: If the result holds more than one token sequence.
        """
        if hasattr(encoded, "ids"):
            return list(encoded.ids)

        items = list(encoded)
        if items and hasattr(items[0], "ids"):
            if len(items) != 1:
                raise ValueError(
                    f"Expected a single token sequence, got {len(items)}"
                )
            return list(items[0].ids)
        return items

    def __getitem__(self, idx):
        """Tokenize file ``idx`` and return ``{"input_ids", "labels"}`` tensors."""
        # encode() does not return a bare list of ints, so the ids have to be
        # pulled out before they can be concatenated with BOS/EOS.
        rh_tokens = self._token_ids(self.tokenizer.encode(self.rh_paths[idx]))

        # Add BOS and EOS
        tokens = [self.bos] + rh_tokens + [self.eos]

        # Truncate; for over-long pieces this drops the trailing EOS as well.
        if len(tokens) > self.max_seq_len:
            tokens = tokens[:self.max_seq_len]

        input_ids = tokens
        labels = tokens.copy()  # same as input for CLM
        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long)
        }


In [None]:
import os
from miditok import REMI, TokenizerConfig
from pathlib import Path
import json
from tqdm import tqdm

# Right-hand MIDI files; defined up front because both tokenizer creation and
# the tokenization step further down need it, whichever branch runs below.
right_midi_dir = Path("data_right_hand")
midi_paths = list(right_midi_dir.rglob("*.mid"))

# To initialize/load
tokenizer = None
tokenizer_path = Path("tokenizer.json")

# Load the saved tokenizer if there is one, otherwise build and save a new one.
# Everything that depends on `config` lives inside the else-branch, so an
# existing tokenizer is neither rebuilt nor overwritten.
if tokenizer_path.exists():
    print('Pre-trained tokenizer exists. Loading...')
    tokenizer = REMI.from_pretrained(tokenizer_path)
else:
    print('No tokenizer found. Training...')

    config = TokenizerConfig(
        pitch_range=(21, 109),
        beat_res={(0, 4): 4, (4, 12): 8, (12, 32): 16},
        num_velocities=64,
        special_tokens=["PAD", "BOS", "EOS", "MASK"],
        encode_ids_split="no",
        use_velocities=True,
        use_note_duration_programs=[0],
        use_chords=True,
        use_rests=True,  # True if you want to model silences
        use_tempos=True,
        use_time_signatures=True,
        use_sustain_pedals=False,
        use_pitch_bends=False,
        use_programs=True,
        programs=[0],
        use_pitchdrum_tokens=True,
        remove_duplicated_notes=False,
        one_token_stream_for_programs=True,
    )

    tokenizer = REMI(config)
    # Vocabulary training is currently switched off; the tokenizer is saved
    # with the vocabulary derived from `config` alone.
    #tokenizer.train(vocab_size=15000, files_paths=midi_paths)
    tokenizer.save(tokenizer_path)

# Output locations for the tokenized data: a root folder plus one sub-folder
# for the right-hand token files. Existing folders are reused.
out_dir = Path("tokenized_json")
out_dir.mkdir(exist_ok=True)

right_json_dir = out_dir / "right_hand"
right_json_dir.mkdir(parents=True, exist_ok=True)

def tokenize_and_save(midi_path: Path, out_path: Path):
    """Tokenize one MIDI file and store its token ids as a JSON list.

    Uses the module-level ``tokenizer``. The ids are written to ``out_path``
    with its suffix replaced by ``.json`` (note that ``with_suffix`` treats
    everything after the last dot of the file name as the suffix).

    Args:
        midi_path: MIDI file to encode.
        out_path: Target path; its suffix is replaced by ``.json``.

    Raises:
        ValueError: If the tokenizer returns more than one token sequence, in
            which case nothing is written.
    """
    encoded = tokenizer.encode(midi_path)

    # Depending on its configuration the tokenizer hands back either a single
    # sequence or a list with one sequence per track; accept both forms.
    if isinstance(encoded, (list, tuple)):
        if len(encoded) != 1:
            raise ValueError(
                f"Expected a single token sequence for {midi_path}, "
                f"got {len(encoded)}"
            )
        encoded = encoded[0]

    with open(out_path.with_suffix(".json"), "w") as f:
        json.dump(encoded.ids, f)


# Tokenize right-hand tracks
print("Tokenizing right-hand MIDIs...")
for midi_file in tqdm(sorted(right_midi_dir.glob("*.mid"))):
    # Spell out the ".json" name here: tokenize_and_save() replaces the path's
    # suffix, so passing a bare stem such as "foo.v2" would be cut to
    # "foo.json" and could collide with another file's output.
    out_path = right_json_dir / f"{midi_file.stem}.json"
    tokenize_and_save(midi_file, out_path)


No tokenizer found. Training...
Tokenizing right-hand MIDIs...


100%|██████████| 425/425 [00:18<00:00, 22.88it/s]
