In [None]:
""" pre-training:
- Token sequences: "BOS_None" <token sequence> "PAD_None" (the "PAD_None" token doubles as the EOS token).
"""
""" SFT:
- Prompt-answer pairs: "BOS_None" "MIDI_Prompt" <melody> "MIDI_Answer" <piano> "PAD_None" (followed by further "PAD_None" tokens as padding)
"""

In [None]:
"""
- Split the MIDI files into train/val lists and save them to disk.
- Train a tokenizer with BPE (the pre-training dataset itself is built in the next cell).
"""
from pathlib import Path
from random import shuffle
from miditok import TokenizerConfig, TSD
import pickle

# Local import: a private, seeded RNG keeps the split reproducible without
# touching the global `random` state used elsewhere in the notebook.
from random import Random

DATA_PATH = "datasets/aria-midi-v1-pruned-ext/data"
MODELS_DIR = Path("models")
VOCAB_SIZE = 20_000
SPLIT_SEED = 42  # fixed so a re-run reproduces the same train/val split
VAL_FRACTION = 0.02
num_files = 90_000 # Takes lots of memory

# rglob yields files in filesystem order, which is not guaranteed to be stable:
# sort first so that the seeded shuffle below is fully deterministic.
midi_paths = sorted(Path(DATA_PATH).rglob("*.mid"))
Random(SPLIT_SEED).shuffle(midi_paths)
num_files_valid = round(len(midi_paths) * VAL_FRACTION)
midi_paths_val = midi_paths[:num_files_valid]
midi_paths_train = midi_paths[num_files_valid:]

print(f"number of train and val MIDI files: {len(midi_paths_train)}, {len(midi_paths_val)}")
# Sanity-check before anything is written to disk, so a misconfigured run
# does not leave half-updated split files behind.
assert num_files <= len(midi_paths_train), (
    f"num_files={num_files} exceeds the {len(midi_paths_train)} available training files"
)

# The output directory may not exist on a fresh checkout.
MODELS_DIR.mkdir(parents=True, exist_ok=True)

# Persist the split so later cells tokenize exactly these files.
with open(MODELS_DIR / "train_paths.pkl", "wb") as f:
    pickle.dump(midi_paths_train, f)

with open(MODELS_DIR / "val_paths.pkl", "wb") as f:
    pickle.dump(midi_paths_val, f)

TOKENIZER_PARAMS = {
    "special_tokens": ["BOS_None", "PAD_None", "MASK_None", "MIDI_Prompt", "MIDI_Answer"], #These are referenced throughout the code
    "pitch_range": (21, 109),
    "use_velocities": False,
    "beat_res": {(0, 8): 4},
    "use_pitchdrum_tokens": False,
    "use_chords": False,
    "chord_tokens_with_root_note": False,
}
config = TokenizerConfig(**TOKENIZER_PARAMS)
tokenizer = TSD(config)
# Saved before training so the untrained base vocabulary is kept as well.
tokenizer.save("models/tokenizer_base.json")

# BPE is learned on training files only (validation files are held out),
# capped at `num_files` files to bound memory usage.
tokenizer.train(
    vocab_size=VOCAB_SIZE,
    model="BPE",
    files_paths=midi_paths_train[:num_files],
)

tokenizer.save("models/tokenizer_trained.json")

number of train and val MIDI files: 804525, 16419


In [1]:
from config import HyperParams
import pickle
from preprocess import create_dataset

# Reload the train/validation file lists that the tokenizer-training cell pickled.
with open('models/val_paths.pkl', 'rb') as f:
    val_paths = pickle.load(f)
with open('models/train_paths.pkl', 'rb') as f:
    train_paths = pickle.load(f)

tokenizer_path = "models/tokenizer_trained.json"

# Validation first (small, quick to verify), then the full training set.
for split_paths, split_dir in (
    (val_paths, "pre-training_dataset/val"),
    (train_paths, "pre-training_dataset/train"),
):
    create_dataset(split_paths, tokenizer_path, split_dir, HyperParams.block_size, num_workers=6)

Tokenizing MIDI files: 100%|██████████| 16419/16419 [00:53<00:00, 309.63it/s]


Merging all part files into one sequences.npy
Merged 46754 sequences into pre-training_dataset/val\sequences.npy


Tokenizing MIDI files: 100%|██████████| 804525/804525 [55:30<00:00, 241.56it/s]  


Merging all part files into one sequences.npy
Merged 2317153 sequences into pre-training_dataset/train\sequences.npy


In [1]:
import os
from typing import List
from pathlib import Path
from miditok import REMI
from tqdm import tqdm
from symusic import Score

def create_dataset(
    midi_paths: List[Path],
    out_path: str,
    window_size: int = 4,
    stride: int = 2,
    min_bar_tokens: int = 5,
    melody_track: int = 0,
    piano_track: int = 2,
):
    """Cut each MIDI file into overlapping melody/piano clips of a few bars.

    Every file is tokenized with a default ``REMI`` tokenizer, the melody and
    piano token sequences are split per bar, and each window of
    ``window_size`` consecutive bars in which every bar of both tracks is
    long enough is decoded back to a two-track MIDI file
    ``<out_path>/<parent folder of the source file>/bar_<start bar>.mid``.

    NOTE(review): an earlier cell imports a different ``create_dataset`` from
    ``preprocess`` (called there with a tokenizer path, block size and
    ``num_workers``); defining this function shadows that import in the kernel.

    NOTE: the output location depends only on the source file's parent folder
    name and the start bar, so two source files that share a parent folder
    name overwrite each other's clips.

    Args:
        midi_paths: MIDI files to process.
        out_path: Root directory for the generated clips (created if missing).
        window_size: Number of consecutive bars per clip.
        stride: Number of bars between the starts of consecutive windows.
        min_bar_tokens: A bar is usable only if it has strictly more than this
            many tokens in both the melody and the piano track.
        melody_track: Non-negative index of the melody track in the file.
        piano_track: Non-negative index of the piano track in the file.

    Raises:
        ValueError: If ``window_size`` or ``stride`` is not positive.
    """
    if window_size < 1 or stride < 1:
        raise ValueError("window_size and stride must be positive integers")

    tokenizer = REMI()
    os.makedirs(out_path, exist_ok=True)
    required_tracks = max(melody_track, piano_track) + 1

    for midi_path in tqdm(midi_paths):
        seq = tokenizer(midi_path)
        # A file without the expected tracks should not abort the whole run.
        if len(seq) < required_tracks:
            print(f"Skipping {midi_path}: expected at least {required_tracks} tracks, found {len(seq)}")
            continue

        melody_bars = seq[melody_track].split_per_bars()
        piano_bars = seq[piano_track].split_per_bars()
        n_bars = min(len(melody_bars), len(piano_bars))

        # Get the last folder name (e.g., "001" from "POP909/001/xxx.mid")
        parent_folder = midi_path.parent.name
        midi_out_dir = os.path.join(out_path, parent_folder)
        os.makedirs(midi_out_dir, exist_ok=True)

        # Windows overlap, so decide once per bar whether it is usable
        # instead of re-measuring the same bars for every window.
        bar_ok = [
            len(melody_bars[j]) > min_bar_tokens and len(piano_bars[j]) > min_bar_tokens
            for j in range(n_bars)
        ]

        for i in range(0, n_bars - window_size + 1, stride):
            if not all(bar_ok[i:i + window_size]):
                continue
            melody_ids = [tok for bar in melody_bars[i:i + window_size] for tok in bar.ids]
            piano_ids = [tok for bar in piano_bars[i:i + window_size] for tok in bar.ids]
            # Calling the tokenizer on id lists decodes them back into a score.
            out = tokenizer([melody_ids, piano_ids])
            out.dump_midi(os.path.join(midi_out_dir, f"bar_{i}.mid"))

def augment_dataset(midi_paths: List[Path], pitch_range=(-6, -5, -4, -3, -2, -1, 1, 2, 3, 4, 5, 6)):
    """Write pitch-shifted copies of each MIDI file next to the original.

    For every file and every offset ``n`` in ``pitch_range`` the score is
    shifted by ``n`` and saved in the same directory as
    ``<stem>_<n><suffix>``. Failures are reported and skipped so that one bad
    file or offset does not abort the whole augmentation run.

    Args:
        midi_paths: MIDI files to augment.
        pitch_range: Iterable of pitch offsets to apply; one new file is
            written per offset.
    """
    for midi_path in tqdm(midi_paths):
        try:
            score = Score(midi_path)
        except Exception as e:
            print(f"Error loading {midi_path}: {e}")
            continue

        for n in pitch_range:
            try:
                new_score = score.shift_pitch(n)
                # Insert pitch shift value before the .mid extension
                new_path = midi_path.with_name(f"{midi_path.stem}_{n}{midi_path.suffix}")
                new_score.dump_midi(new_path)
            except Exception as e:
                # Best effort: report the failure and carry on with the next offset.
                print(f"Error shifting {midi_path} by {n}: {e}")
                continue

In [None]:
"""build the sft dataset"""
from pathlib import Path
from random import shuffle

# Local imports: `re` for the clip-name filter below, and a private seeded RNG
# so the split is reproducible without touching the global `random` state.
import re
from random import Random

DATA_PATH = "datasets/POP909" # The create_dataset function processes tracks 0 and 2
SPLIT_SEED = 42  # fixed so a re-run reproduces the same train/val/test split
VAL_FRACTION = 0.05
TEST_FRACTION = 0.03

# One level deep only: <DATA_PATH>/<song folder>/*.mid.
# Sorted so that the seeded shuffle does not depend on directory-listing order.
midi_paths = sorted(
    f.resolve()
    for d in Path(DATA_PATH).iterdir() if d.is_dir()
    for f in d.glob("*.mid") if f.is_file()
)
# The output folders are reused across runs, so a split that changed between
# runs would leak the same songs into several of train/val/test.
Random(SPLIT_SEED).shuffle(midi_paths)
total_num_files = len(midi_paths)
print(f"number of MIDI files: {total_num_files}")

num_files_valid = round(total_num_files * VAL_FRACTION)
num_files_test = round(total_num_files * TEST_FRACTION)
midi_paths_val = midi_paths[:num_files_valid]
midi_paths_test = midi_paths[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = midi_paths[num_files_valid + num_files_test:]

create_dataset(midi_paths_train, "sft_dataset/train")
create_dataset(midi_paths_val, "sft_dataset/val")
create_dataset(midi_paths_test, "sft_dataset/test")

# Only the training clips are augmented. create_dataset names its clips
# "bar_<i>.mid" and augment_dataset appends "_<shift>" to the stem, so keeping
# only un-suffixed stems stops a re-run from transposing already-transposed clips.
midi_paths_train = [
    p for p in Path("sft_dataset/train").rglob("*.mid")
    if re.fullmatch(r"bar_\d+", p.stem)
]
augment_dataset(midi_paths_train)

number of MIDI files: 909


100%|██████████| 837/837 [00:28<00:00, 29.34it/s]
100%|██████████| 45/45 [00:01<00:00, 29.29it/s]
100%|██████████| 27/27 [00:00<00:00, 28.14it/s]
100%|██████████| 18925/18925 [03:09<00:00, 99.77it/s] 


In [1]:
import torch
from tqdm import tqdm
from miditok import TSD
from typing import List
from pathlib import Path
import pickle

def tokenizer_dataset(midi_paths: List[Path], tokenizer: TSD, save_path: Path, max_seq_len: int = 1024):
    """Tokenize melody/piano MIDI clips into padded prompt-answer rows and pickle them.

    Each file is encoded with ``tokenizer``; the ids of its first track are
    used as the melody (prompt) and those of its second track as the piano
    part (answer). A row is laid out as
    ``BOS_None MIDI_Prompt <melody> MIDI_Answer <piano> PAD_None`` and then
    padded with ``PAD_None`` up to ``max_seq_len``. Rows that would exceed
    ``max_seq_len`` are dropped without a message; files that raise are
    reported and skipped.

    The result is written to ``save_path / "data.pkl"`` as a dict with keys
    ``input_ids``, ``prompt_lengths``, ``lengths`` and ``dataset_length``.

    Args:
        midi_paths: MIDI files to tokenize.
        tokenizer: Tokenizer providing the special tokens named above.
        save_path: Directory for ``data.pkl`` (created if missing).
        max_seq_len: Maximum row length, including the special tokens.
    """
    bos_id = tokenizer["BOS_None"]
    prompt_id = tokenizer["MIDI_Prompt"]
    answer_id = tokenizer["MIDI_Answer"]
    pad_id = tokenizer["PAD_None"]

    # Allocate for the worst case (every file kept); trimmed after the loop.
    capacity = len(midi_paths)
    input_ids = torch.full((capacity, max_seq_len), pad_id, dtype=torch.long)
    prompt_lengths = torch.empty(capacity, dtype=torch.long)
    lengths = torch.empty(capacity, dtype=torch.long)

    kept = 0  # next free row == number of sequences stored so far
    for midi_path in tqdm(midi_paths, desc="Tokenizing midi files..."):
        try:
            tracks = tokenizer(midi_path)
            melody_ids, piano_ids = tracks[0].ids, tracks[1].ids

            prompt = [bos_id, prompt_id] + melody_ids + [answer_id]
            full_sequence = prompt + piano_ids + [pad_id]

            seq_len = len(full_sequence)
            if seq_len > max_seq_len:
                # Does not fit in a row: drop the sample.
                continue

            prompt_lengths[kept] = len(prompt)
            lengths[kept] = seq_len
            input_ids[kept, :seq_len] = torch.tensor(full_sequence, dtype=torch.long)
            kept += 1
        except Exception as err:
            print(f"Error processing {midi_path}: {err}")

    # Keep only the rows that were actually filled.
    input_ids = input_ids[:kept]
    prompt_lengths = prompt_lengths[:kept]
    lengths = lengths[:kept]

    # Persist everything in a single pickle file.
    save_path.mkdir(parents=True, exist_ok=True)
    out_file = save_path / "data.pkl"
    payload = {
        "input_ids": input_ids,
        "prompt_lengths": prompt_lengths,
        "lengths": lengths,
        "dataset_length": kept,
    }
    with open(out_file, "wb") as handle:
        pickle.dump(payload, handle)

    print(f"Saved tokenized dataset with {kept} samples to {out_file}")


tokenizer = TSD(params="models/tokenizer_trained.json")

# Tokenize each SFT split in place: data.pkl is written next to the split's MIDI clips.
for split in ("val", "test", "train"):
    split_dir = Path("sft_dataset") / split
    midi_paths = list(split_dir.rglob("*.mid"))
    tokenizer_dataset(midi_paths, tokenizer, split_dir)

Tokenizing midi files...: 100%|██████████| 987/987 [00:01<00:00, 698.63it/s]


Saved tokenized dataset with 987 samples to sft_dataset\val\data.pkl


Tokenizing midi files...: 100%|██████████| 607/607 [00:03<00:00, 199.38it/s]


Saved tokenized dataset with 607 samples to sft_dataset\test\data.pkl


Tokenizing midi files...: 100%|██████████| 246025/246025 [16:01<00:00, 255.84it/s]


Saved tokenized dataset with 246025 samples to sft_dataset\train\data.pkl
