In [1]:
import os
from pathlib import Path
from copy import deepcopy

from mido import MidiFile
from symusic import Score

from miditok import REMI, TokenizerConfig, TokSequence
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training

import miditoolkit
from miditoolkit import MidiFile

import pretty_midi

from torch.utils.data import DataLoader




In [2]:
def extract_and_save_rh_lh(midi_path: Path, rh_out_dir: Path, lh_out_dir: Path):
    """Split a two-track MIDI file into a right-hand and a left-hand file.

    Track 0 is written to ``rh_out_dir`` and track 1 to ``lh_out_dir``, both
    under the source file's own name. A file that does not contain exactly two
    instrument tracks is skipped with a message and nothing is written.

    Each output is written from the loaded ``PrettyMIDI`` object, with its
    instrument list narrowed to a single track. Writing from a fresh, empty
    ``PrettyMIDI()`` would fall back to that class's default tempo and
    resolution and carry no time/key signatures, which changes the bar/beat
    grid a beat-based tokenizer sees.

    Args:
        midi_path: Source MIDI file.
        rh_out_dir: Directory receiving the track-0 file (not created here).
        lh_out_dir: Directory receiving the track-1 file (not created here).
    """
    pm = pretty_midi.PrettyMIDI(str(midi_path))

    # Exactly two tracks are required: index 0 is treated as RH, index 1 as LH.
    if len(pm.instruments) != 2:
        print(f"Skipping {midi_path.name}, doesn't contain 2 tracks")
        return

    rh_instrument, lh_instrument = pm.instruments
    rh_out_path = rh_out_dir / midi_path.name
    lh_out_path = lh_out_dir / midi_path.name

    # `pm` is local to this call, so narrowing its instrument list in place is
    # safe; everything else on the object (tempo map, signatures, resolution)
    # is written unchanged into both files.
    for instrument, out_path in (
        (rh_instrument, rh_out_path),
        (lh_instrument, lh_out_path),
    ):
        pm.instruments = [instrument]
        pm.write(str(out_path))

    print(f"Saved RH: {rh_out_path.name} | LH: {lh_out_path.name}")


def preprocess_dataset(
    input_dir: Path,
    rh_out_dir: Path,
    lh_out_dir: Path
):
    """Split every ``*.mid`` file under ``input_dir`` into RH and LH files.

    Both output directories are created if missing. Files are discovered
    recursively and handed one by one to ``extract_and_save_rh_lh``. A failure
    on one file is reported and does not stop the run.

    Args:
        input_dir: Root directory searched recursively for ``*.mid`` files.
        rh_out_dir: Directory for the right-hand outputs.
        lh_out_dir: Directory for the left-hand outputs.
    """
    rh_out_dir.mkdir(parents=True, exist_ok=True)
    lh_out_dir.mkdir(parents=True, exist_ok=True)

    # Sorted so the processing order (and the log) is reproducible across runs.
    midi_files = sorted(input_dir.glob("**/*.mid"))
    print(f"Found {len(midi_files)} MIDI files.")

    # Outputs are written under the bare file name, so two inputs with the same
    # name in different sub-directories end up in the same output file. Surface
    # that instead of letting the later file silently replace the earlier one.
    seen_names = set()
    clashing_names = set()
    for midi_path in midi_files:
        if midi_path.name in seen_names:
            clashing_names.add(midi_path.name)
        seen_names.add(midi_path.name)
    if clashing_names:
        print(
            f"Warning: {len(clashing_names)} file name(s) occur more than once; "
            f"later files overwrite earlier outputs: {', '.join(sorted(clashing_names))}"
        )

    for midi_path in midi_files:
        try:
            extract_and_save_rh_lh(midi_path, rh_out_dir, lh_out_dir)
        except Exception as e:
            # Deliberately best-effort: one unreadable file must not abort the batch.
            print(f"Error processing {midi_path.name}: {e}")





In [None]:
# Resolve the source corpus and the two per-hand output folders to absolute
# paths, then split every MIDI file found under the corpus into RH/LH files.
base_dir, right_hand_dir, left_hand_dir = (
    Path(location).resolve()
    for location in ("../../data", "data_right_hand", "data_left_hand")
)

preprocess_dataset(base_dir, right_hand_dir, left_hand_dir)

In [3]:
from torch.utils.data import Dataset
from miditok import REMI
from pathlib import Path
import torch


class MIDIPairedConditionalDataset(Dataset):
    """Paired right-hand / left-hand dataset for RH-conditioned LH generation.

    Each item is the token sequence ``[BOS] RH [SEP] LH [EOS]`` together with
    labels in which the prompt part (BOS, RH tokens, SEP) is set to -100, the
    value PyTorch's cross-entropy loss ignores by default, so only the LH
    tokens and EOS contribute to the loss.

    Args:
        rh_paths: Right-hand MIDI paths.
        lh_paths: Left-hand MIDI paths; must have the same length as ``rh_paths``.
        tokenizer: Tokenizer providing ``encode(path)`` and id lookup by token
            string for ``BOS_None``, ``SEP_None``, ``EOS_None`` and ``PAD_None``.
        max_seq_len: Maximum length of the returned sequences.
    """

    def __init__(self, rh_paths, lh_paths, tokenizer: "REMI", max_seq_len=1024):
        assert len(rh_paths) == len(lh_paths), "Mismatch in dataset size"
        # The two lists are sorted independently, so item i of each is a true
        # pair only if matching files sort to the same position (e.g. the RH
        # and LH files share a file name).
        self.rh_paths = sorted(rh_paths)
        self.lh_paths = sorted(lh_paths)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        # All four special tokens must exist in the tokenizer's vocabulary;
        # a failed lookup here means it was built without one of them.
        self.bos = tokenizer["BOS_None"]
        self.sep = tokenizer["SEP_None"]
        self.eos = tokenizer["EOS_None"]
        self.pad = tokenizer["PAD_None"]

    def __len__(self):
        """Return the number of RH/LH pairs."""
        return len(self.rh_paths)

    def _encode_ids(self, path):
        """Tokenize ``path`` and return its token ids as a plain list.

        The tokenizer's ``encode`` result exposes its ids under ``.ids``;
        a result that is already a plain sequence of ids is accepted as-is.
        """
        encoded = self.tokenizer.encode(path)
        return list(getattr(encoded, "ids", encoded))

    def __getitem__(self, idx):
        """Return ``{"input_ids", "labels"}`` long tensors for pair ``idx``."""
        rh_ids = self._encode_ids(self.rh_paths[idx])
        lh_ids = self._encode_ids(self.lh_paths[idx])

        # Compose input: [BOS] RH [SEP] LH [EOS]
        input_ids = [self.bos] + rh_ids + [self.sep] + lh_ids + [self.eos]

        # Mask BOS + RH + SEP with -100 so only LH + EOS are learned.
        labels = [-100] * (len(rh_ids) + 2) + lh_ids + [self.eos]

        # Truncation keeps the head of the sequence, so a long RH prompt can
        # push part or all of the LH target (and EOS) out of the window.
        input_ids = input_ids[:self.max_seq_len]
        labels = labels[:self.max_seq_len]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long)
        }


In [4]:
import os
from miditok import REMI, TokenizerConfig
from pathlib import Path
import json
from tqdm import tqdm

# Load the tokenizer saved next to the notebook, or build and save a new one.
tokenizer = None
tokenizer_path = Path("tokenizer.json")

if tokenizer_path.exists():
    print('Pre-trained tokenizer exists. Loading...')
    # NOTE(review): a tokenizer.json saved from a config without "SEP" has no
    # SEP_None token; delete the file and re-run this cell to regenerate it.
    tokenizer = REMI.from_pretrained(tokenizer_path)
else:
    print('No tokenizer found. Training...')

    config = TokenizerConfig(
        pitch_range=(21, 109),
        beat_res={(0, 4): 4, (4, 12): 8, (12, 32): 16},
        num_velocities=64,
        # "SEP" is appended (keeping the ids of the first four tokens stable)
        # because MIDIPairedConditionalDataset looks up tokenizer["SEP_None"].
        special_tokens=["PAD", "BOS", "EOS", "MASK", "SEP"],
        encode_ids_split="no",
        use_velocities=True,
        use_note_duration_programs=[0],
        use_chords=True,
        use_rests=True,  # True if you want to model silences
        use_tempos=True,
        use_time_signatures=True,
        use_sustain_pedals=False,
        use_pitch_bends=False,
        use_programs=True,
        programs=[0],
        use_pitchdrum_tokens=True,
        remove_duplicated_notes=False,
        one_token_stream_for_programs=True,
    )

    tokenizer = REMI(config)
    midi_paths = list(Path("../../data").rglob("*.mid"))
    # Vocabulary training is currently disabled, so the tokenizer saved below
    # is the untrained one built from `config`; re-enable the next line to train.
    #tokenizer.train(vocab_size=30000, files_paths=midi_paths)
    tokenizer.save(tokenizer_path)

# Root folder for the token-id JSON files.
out_dir = Path("tokenized_json")
out_dir.mkdir(exist_ok=True)

# Per-hand MIDI folders that the RH/LH preprocessing step writes into.
right_midi_dir = Path("data_right_hand")
left_midi_dir = Path("data_left_hand")

# One JSON sub-folder per hand, created up front.
right_json_dir = out_dir / "right_hand"
right_json_dir.mkdir(parents=True, exist_ok=True)
left_json_dir = out_dir / "left_hand"
left_json_dir.mkdir(parents=True, exist_ok=True)

def tokenize_and_save(midi_path: Path, out_path: Path):
    """Tokenize one MIDI file and store its token ids as a JSON list.

    Uses the notebook-level ``tokenizer``. The file written is ``out_path``
    with its suffix replaced by ``.json`` (or appended when it has none).

    Args:
        midi_path: MIDI file to tokenize.
        out_path: Destination path; its suffix is swapped for ``.json``.
    """
    encoded = tokenizer.encode(midi_path)
    json_path = out_path.with_suffix(".json")
    with json_path.open("w") as handle:
        json.dump(encoded.ids, handle)


# Tokenize the right-hand tracks, then the left-hand tracks, into their JSON folders.
for hand_label, midi_dir, json_dir in (
    ("right-hand", right_midi_dir, right_json_dir),
    ("left-hand", left_midi_dir, left_json_dir),
):
    print(f"Tokenizing {hand_label} MIDIs...")
    for midi_file in tqdm(sorted(midi_dir.glob("*.mid"))):
        # Build the full "<stem>.json" name here. Passing the bare stem would
        # let tokenize_and_save's with_suffix(".json") cut a dotted stem at its
        # last dot ("Op.10 No.3" -> "Op.10 No.json"), so different pieces
        # would overwrite each other's output.
        out_path = json_dir / f"{midi_file.stem}.json"
        tokenize_and_save(midi_file, out_path)


  super().__init__(tokenizer_config, params)


Pre-trained tokenizer exists. Loading...
Tokenizing right-hand MIDIs...


100%|██████████| 425/425 [00:07<00:00, 58.64it/s]


Tokenizing left-hand MIDIs...


100%|██████████| 425/425 [00:06<00:00, 64.07it/s]


In [5]:
# Run the same RH/LH split on the MIDI files under symbolic_conditional/,
# writing into separate "generated_*" output folders.
generated_base_dir, generated_right_hand_dir, generated_left_hand_dir = (
    Path(location).resolve()
    for location in (
        "symbolic_conditional",
        "generated_data_right_hand",
        "generated_data_left_hand",
    )
)

preprocess_dataset(generated_base_dir, generated_right_hand_dir, generated_left_hand_dir)

Found 53 MIDI files.
Saved RH: symbolic_conditional_2.mid | LH: symbolic_conditional_2.mid
Saved RH: symbolic_conditional_3.mid | LH: symbolic_conditional_3.mid
Saved RH: symbolic_conditioned.mid | LH: symbolic_conditioned.mid
Saved RH: symbolic_conditioned_1.mid | LH: symbolic_conditioned_1.mid
Saved RH: symbolic_conditioned_10.mid | LH: symbolic_conditioned_10.mid
Saved RH: symbolic_conditioned_11.mid | LH: symbolic_conditioned_11.mid
Saved RH: symbolic_conditioned_12.mid | LH: symbolic_conditioned_12.mid
Saved RH: symbolic_conditioned_13.mid | LH: symbolic_conditioned_13.mid
Saved RH: symbolic_conditioned_14.mid | LH: symbolic_conditioned_14.mid
Saved RH: symbolic_conditioned_15.mid | LH: symbolic_conditioned_15.mid
Saved RH: symbolic_conditioned_16.mid | LH: symbolic_conditioned_16.mid
Saved RH: symbolic_conditioned_17.mid | LH: symbolic_conditioned_17.mid
Saved RH: symbolic_conditioned_18.mid | LH: symbolic_conditioned_18.mid
Saved RH: symbolic_conditioned_19.mid | LH: symbolic_co