In [None]:
import os
from pathlib import Path
from copy import deepcopy

from mido import MidiFile
from symusic import Score

from miditok import REMI, TokenizerConfig, TokSequence
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training

import miditoolkit
from miditoolkit import MidiFile

import pretty_midi

from torch.utils.data import DataLoader




In [30]:
def extract_and_save_rh_lh(midi_path: Path, rh_out_dir: Path, lh_out_dir: Path):
    pm = pretty_midi.PrettyMIDI(str(midi_path))

    # Check minimum tracks
    if len(pm.instruments) < 2:
        print(f"Skipping {midi_path.name}, less than 2 tracks")
        return

    # Extract right hand (track 0)
    pm_rh = pretty_midi.PrettyMIDI()
    pm_rh.instruments.append(pm.instruments[0])

    # Extract left hand (track 1)
    pm_lh = pretty_midi.PrettyMIDI()
    pm_lh.instruments.append(pm.instruments[1])

    # Save RH MIDI
    rh_out_path = rh_out_dir / midi_path.name
    pm_rh.write(str(rh_out_path))

    # Save LH MIDI
    lh_out_path = lh_out_dir / midi_path.name
    pm_lh.write(str(lh_out_path))

    print(f"Saved RH: {rh_out_path.name} | LH: {lh_out_path.name}")


def preprocess_dataset(
    input_dir: Path,
    rh_out_dir: Path,
    lh_out_dir: Path
):
    rh_out_dir.mkdir(parents=True, exist_ok=True)
    lh_out_dir.mkdir(parents=True, exist_ok=True)

    midi_files = list(input_dir.glob("**/*.mid"))
    print(f"Found {len(midi_files)} MIDI files.")

    for midi_path in midi_files:
        try:
            extract_and_save_rh_lh(midi_path, rh_out_dir, lh_out_dir)
        except Exception as e:
            print(f"Error processing {midi_path.name}: {e}")




base_dir = Path("../../data").resolve()
right_hand_dir = Path("data_right_hand").resolve()
left_hand_dir = Path("data_left_hand").resolve()

preprocess_dataset(base_dir, right_hand_dir, left_hand_dir)


Found 426 MIDI files.
Saved RH: qnraglin.mid | LH: qnraglin.mid
Saved RH: honeyrag.mid | LH: honeyrag.mid
Saved RH: redpeppr.mid | LH: redpeppr.mid
Saved RH: oldvarag.mid | LH: oldvarag.mid
Saved RH: standbyf.mid | LH: standbyf.mid
Saved RH: walknrbr.mid | LH: walknrbr.mid
Saved RH: chmpagne.mid | LH: chmpagne.mid
Saved RH: hothouse.mid | LH: hothouse.mid
Saved RH: poisoniv.mid | LH: poisoniv.mid
Saved RH: toboggan.mid | LH: toboggan.mid
Saved RH: adingysd.mid | LH: adingysd.mid
Saved RH: chestnut.mid | LH: chestnut.mid
Saved RH: fascintr.mid | LH: fascintr.mid
Saved RH: mooserag.mid | LH: mooserag.mid
Saved RH: binkswlz.mid | LH: binkswlz.mid
Saved RH: mmajstic.mid | LH: mmajstic.mid
Saved RH: atgrapes.mid | LH: atgrapes.mid
Saved RH: halleys.mid | LH: halleys.mid
Saved RH: armadill.mid | LH: armadill.mid
Saved RH: slpysdny.mid | LH: slpysdny.mid
Saved RH: broadway.mid | LH: broadway.mid
Saved RH: thirdavl.mid | LH: thirdavl.mid
Saved RH: qnoflove.mid | LH: qnoflove.mid
Saved RH: jung

In [21]:
from torch.utils.data import Dataset
from miditok import REMI
from pathlib import Path
import torch


class MIDIPairedConditionalDataset(Dataset):
    def __init__(self, rh_paths, lh_paths, tokenizer: REMI, max_seq_len=1024):
        assert len(rh_paths) == len(lh_paths), "Mismatch in dataset size"
        self.rh_paths = sorted(rh_paths)
        self.lh_paths = sorted(lh_paths)
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        self.bos = tokenizer["BOS_None"]
        self.sep = tokenizer["SEP_None"]
        self.eos = tokenizer["EOS_None"]
        self.pad = tokenizer["PAD_None"]

    def __len__(self):
        return len(self.rh_paths)

    def __getitem__(self, idx):
        rh_tokens = self.tokenizer.encode(self.rh_paths[idx])
        lh_tokens = self.tokenizer.encode(self.lh_paths[idx])

        # Compose input: [BOS] RH [SEP] LH [EOS]
        input_ids = [self.bos] + rh_tokens + [self.sep] + lh_tokens + [self.eos]

        # Mask RH + BOS + SEP with -100 in the labels
        labels = [-100] * (len(rh_tokens) + 2) + lh_tokens + [self.eos]

        # Truncate if too long
        if len(input_ids) > self.max_seq_len:
            input_ids = input_ids[:self.max_seq_len]
            labels = labels[:self.max_seq_len]

        return {
            "input_ids": torch.tensor(input_ids, dtype=torch.long),
            "labels": torch.tensor(labels, dtype=torch.long)
        }


In [31]:
import os
from miditok import REMI, TokenizerConfig
from pathlib import Path
import json
from tqdm import tqdm

# To initialize/load
tokenizer = None
tokenizer_path = Path("tokenizer.json")

# Load or train tokenizer
if tokenizer_path.exists():
    tokenizer = REMI(tokenizer_path)
else:
    config = TokenizerConfig(num_velocities=32, use_chords=True, use_programs=True)
    tokenizer = REMI(config)
    midi_paths = list(Path("../../data").rglob("*.mid"))
    tokenizer.train(vocab_size=30000, files_paths=midi_paths)
    tokenizer.save(tokenizer_path)

# Input and output paths
out_dir = Path("tokenized_json")
out_dir.mkdir(exist_ok=True)

right_midi_dir = Path("data_right_hand")
left_midi_dir = Path("data_left_hand")

(right_json_dir := out_dir / "right_hand").mkdir(parents=True, exist_ok=True)
(left_json_dir := out_dir / "left_hand").mkdir(parents=True, exist_ok=True)

def tokenize_and_save(midi_path: Path, out_path: Path):
    tokens = tokenizer.encode(midi_path)
    with open(out_path.with_suffix(".json"), "w") as f:
        json.dump(tokens.ids, f)

# Tokenize right-hand tracks
print("Tokenizing right-hand MIDIs...")
for midi_file in tqdm(sorted(right_midi_dir.glob("*.mid"))):
    out_path = right_json_dir / midi_file.stem
    tokenize_and_save(midi_file, out_path)

# Tokenize left-hand tracks
print("Tokenizing left-hand MIDIs...")
for midi_file in tqdm(sorted(left_midi_dir.glob("*.mid"))):
    out_path = left_json_dir / midi_file.stem
    tokenize_and_save(midi_file, out_path)


  super().__init__(tokenizer_config, params)





Tokenizing right-hand MIDIs...


100%|██████████| 425/425 [00:07<00:00, 60.54it/s]


Tokenizing left-hand MIDIs...


100%|██████████| 425/425 [00:06<00:00, 67.95it/s]


In [None]:
from miditok import REMI, TokSequence
from miditoolkit import MidiFile
from pathlib import Path
from copy import deepcopy

def combine_tracks_to_midi(
    right_hand_ids,
    left_hand_ids,
    tokenizer_path="tokenizer.json",
    output_path="combined_output.mid"
):
    # Load tokenizer
    tokenizer = REMI.from_params(Path(tokenizer_path))

    # Convert to TokSequence if not already
    if not isinstance(right_hand_ids, TokSequence):
        right_hand_ids = TokSequence(ids=right_hand_ids)
    if not isinstance(left_hand_ids, TokSequence):
        left_hand_ids = TokSequence(ids=left_hand_ids)

    # Decode both sequences
    right_midi = tokenizer.decode(right_hand_ids)
    left_midi = tokenizer.decode(left_hand_ids)

    # Extract individual instrument tracks
    right_track = right_midi.instruments[0]
    left_track = left_midi.instruments[0]

    # Assign separate programs (instruments) and names if needed
    right_track.name = "Right Hand"
    left_track.name = "Left Hand"
    right_track.program = 0   # Acoustic Grand Piano
    left_track.program = 0
    right_track.is_drum = False
    left_track.is_drum = False

    # Create new MIDI file and add tracks
    combined_midi = MidiFile()
    combined_midi.instruments.append(deepcopy(right_track))
    combined_midi.instruments.append(deepcopy(left_track))

    # Save to file
    combined_midi.dump(output_path)
    print(f"Combined MIDI saved to: {output_path}")

    return combined_midi
