# Using MIDI training data and the MidiTok REMI tokenizer to generate music with a Mistral model


## Setup Environment

In [None]:
To compile Symusic from source:

Install g++ 11 or higher, then run:

git clone --recursive https://github.com/Yikai-Liao/symusic
CXX=/usr/bin/g++-11 pip install ./symusic


In [None]:


import os
from copy import deepcopy
from pathlib import Path
from random import shuffle

from evaluate import load as load_metric
from miditok import REMI, TokenizerConfig
from miditok.data_augmentation import augment_dataset
from miditok.pytorch_data import DatasetMIDI, DataCollator
from miditok.utils import split_files_for_training
from torch import Tensor, argmax
from torch.backends.mps import is_available as mps_available
from torch.cuda import is_available as cuda_available, is_bf16_supported
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForCausalLM, MistralConfig, Trainer, TrainingArguments, GenerationConfig, AutoConfig
from transformers.trainer_utils import set_seed

## Setup Tokenizer

In [2]:
# Seed the run so that results are repeatable
set_seed(777)

# Beat-resolution table handed to the tokenizer through the "beat_res" option
BEAT_RES = {(0, 1): 12, (1, 2): 4, (2, 4): 2, (4, 8): 1}

# Every option the REMI tokenizer is configured with
TOKENIZER_PARAMS = dict(
    pitch_range=(21, 109),
    beat_res=BEAT_RES,
    num_velocities=24,
    special_tokens=["PAD", "BOS", "EOS"],
    use_chords=True,
    use_rests=True,
    use_tempos=True,
    use_time_signatures=True,
    use_programs=False,  # no multitrack here
    num_tempos=32,
    tempo_range=(50, 200),  # (min_tempo, max_tempo)
)
config = TokenizerConfig(**TOKENIZER_PARAMS)

# Build the (still untrained) tokenizer from that configuration
tokenizer = REMI(config)

# Load MIDI files and train the tokenizer on them

In [3]:
# Root directory that holds the datasets. Set the MIDI_DATA_DIR environment variable to
# point the notebook at another location; when it is unset, the original path is used.
root_data_dir = Path(os.environ.get("MIDI_DATA_DIR", '/home/wombat/Documents/projects/music/midiTok/data/'))


In [None]:

# Collect every MIDI file (.mid / .midi) found recursively under the listed dataset folders.
data_dirs = ["adl-piano-midi", "maestro-v3.0.0", "musicnet_midis", "clean_midi", "vg_music_database"]

midi_paths = []
for data_dir in data_dirs:
    # NOTE(review): 'Traning Data' is presumably the folder name as spelled on disk — keep it as is.
    dataset_root = (root_data_dir / 'Traning Data' / data_dir).resolve()
    # glob() yields files in a filesystem-dependent order. Sorting makes the collected list
    # identical on every run, so a seeded shuffle of it produces the same split each time.
    midi_paths.extend(sorted(dataset_root.glob("**/*.mid")))
    midi_paths.extend(sorted(dataset_root.glob("**/*.midi")))

print(f"Found {len(midi_paths)} MIDI files")

In [None]:
# Learn a 30,000-token vocabulary from the collected MIDI files, then write the trained
# tokenizer to "tokenizer.json" so that it can be reloaded without retraining.
tokenizer.train(files_paths=midi_paths, vocab_size=30000)
tokenizer.save("tokenizer.json")


In [3]:
# Rebuild the tokenizer from the "tokenizer.json" file written by the training cell
tokenizer = REMI(params=Path("tokenizer.json"))

## Prepare MIDIs for training

Here we split the files in three subsets: train, validation and test.
Then the MIDIs of each subset are split into smaller chunks of approximately the desired token sequence length for training, and data augmentation is performed on the training subset only.

In [9]:
# Folder that receives the chunked (and, for the training subset, augmented) MIDI files
root_save = root_data_dir / 'Pre_Training_Data_Music_small'

In [37]:
# Split MIDI paths in train/valid/test sets: 15% validation, 15% test, the rest for training
total_num_files = len(midi_paths)
num_files_valid = round(total_num_files * 0.15)
num_files_test = round(total_num_files * 0.15)
# Shuffle a copy rather than `midi_paths` itself, so the collected list keeps its order
# and other cells (or a re-run of this one) still see it unmodified.
shuffled_paths = list(midi_paths)
shuffle(shuffled_paths)
midi_paths_valid = shuffled_paths[:num_files_valid]
midi_paths_test = shuffled_paths[num_files_valid:num_files_valid + num_files_test]
midi_paths_train = shuffled_paths[num_files_valid + num_files_test:]



# Chunk the MIDIs of every subset; data augmentation is applied to the training subset only
for files_paths, subset_name in (
    (midi_paths_train, "train"), (midi_paths_valid, "valid"), (midi_paths_test, "test")
):

    # Split the MIDIs into chunks of sizes approximately about 1024 tokens
    subset_chunks_dir = root_save / f"Maestro_{subset_name}"
    print(subset_chunks_dir)
    split_files_for_training(
        files_paths=files_paths,
        tokenizer=tokenizer,
        save_dir=subset_chunks_dir,
        max_seq_len=1024,
        num_overlap_bars=2,
    )

    if subset_name == 'train':
        print("Augmentation")
        # Perform data augmentation on the chunks that were just written for this subset
        augment_dataset(
            subset_chunks_dir,
            pitch_offsets=[-12, 12],
            velocity_offsets=[-4, 4],
            duration_offsets=[-0.5, 0.5],
        )


In [26]:
# Create the train / valid / test datasets from the chunked MIDI files
def _list_chunk_files(subset_name):
    """Return the `.mid` files, then the `.midi` files, found recursively in the subset's chunk folder."""
    subset_dir = root_save.joinpath(f"Maestro_{subset_name}")
    return [*subset_dir.glob("**/*.mid"), *subset_dir.glob("**/*.midi")]


midi_paths_train = _list_chunk_files("train")
midi_paths_valid = _list_chunk_files("valid")
midi_paths_test = _list_chunk_files("test")

# Options shared by the three datasets
kwargs_dataset = dict(
    max_seq_len=1024,
    tokenizer=tokenizer,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
)

dataset_train = DatasetMIDI(midi_paths_train, **kwargs_dataset)
dataset_valid = DatasetMIDI(midi_paths_valid, **kwargs_dataset)
dataset_test = DatasetMIDI(midi_paths_test, **kwargs_dataset)
# Number of chunk files per subset
print(len(midi_paths_train), len(midi_paths_valid), len(midi_paths_test))

# Preview files: load and split

In [None]:
# Hand-picked files used as prompts when previewing generations. They all sit under the
# same root folder, so only the part of each path below that root is spelled out.
_PREVIEW_SOURCE_ROOT = "/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Traning Data/"
_preview_relative_paths = [
    "adl-piano-midi/Pop/Dance Pop/Tiësto/Adagio For Strings.mid",
    "adl-piano-midi/Pop/Dance Pop/Robbie Williams/Angels.mid",
    "clean_midi/AC DC/Thunderstruck.mid",
    "clean_midi/AC DC/Highway To Hell.1.mid",
    "clean_midi/Aerosmith/I Don't Want to Miss a Thing.1.mid",
    "clean_midi/Aerosmith/I Don't Want to Miss a Thing.mid",
    "clean_midi/Alanis Morissette/Hand in My Pocket.mid",
    "clean_midi/Alanis Morissette/Ironic.mid",
    "clean_midi/All Saints/Pure Shoes.mid",
    "clean_midi/Bob Dylan/All Along the Watchtower.mid",
    "clean_midi/Bob Dylan/Hurricane.mid",
    "clean_midi/Brown James/Papa's Got a Brand New Bag.mid",
    "clean_midi/Creed/Higher.mid",
    "clean_midi/Creed/With Arms Wide Open (Strings version).1.mid",
    "clean_midi/Curtis Mayfield/Move On Up.mid",
    "clean_midi/Eagle-Eye Cherry/Save Tonight.mid",
    "clean_midi/Energy 52/Cafe del Mar.mid",
    "clean_midi/Faithless/Insomnia (D Donatis mix).mid",
    "clean_midi/Fatboy Slim/Right Here, Right Now.mid",
    "clean_midi/Linkin Park/One Step Closer.mid",
    "clean_midi/Live/Lightning Crashes.mid",
    "clean_midi/Midnight Oil/Beds Are Burning.mid",
    "clean_midi/Natalie Imbruglia/Torn.mid",
    "clean_midi/Radiohead/High and Dry.mid",
    "clean_midi/Radiohead/Creep.mid",
    "clean_midi/Radiohead/Paranoid Android.mid",
    "clean_midi/Soundgarden/Black Hole Sun.mid",
    "clean_midi/Tears for Fears/Everybody Wants To Rule The World.mid",
    "clean_midi/The Cranberries/Dreams.mid",
    "clean_midi/The Cranberries/Zombie.mid",
    "clean_midi/The Prodigy/Breath.mid",
    "clean_midi/The Prodigy/Smack My Bitch Up.mid",
    "clean_midi/Third Eye Blind/Jumper.mid",
    "clean_midi/Third Eye Blind/Semi-Charmed Life.mid",
    "clean_midi/Trance/Breathe.mid",
    "clean_midi/William Orbit/Barber's Adagio for Strings.mid",
]

# Full path strings, in the order listed above
testing_files = [_PREVIEW_SOURCE_ROOT + relative_path for relative_path in _preview_relative_paths]

# The same files as Path objects
preview_files_path = [Path(testing_file) for testing_file in testing_files]

# Folder that receives the chunks cut from the preview files
preview_dir = Path("/media/wombat/c6928dc9-ba03-411d-9483-8e28df5973b9/Music Data/Pre_Training_Data_Music_small/preview")

# Chunk the preview files with the same settings as the train/valid/test subsets
split_files_for_training(
    files_paths=preview_files_path,
    tokenizer=tokenizer,
    save_dir=preview_dir,
    num_overlap_bars=2,
    max_seq_len=1024,
)


In [36]:
# Every chunk found under the preview folder: ".mid" files first, then ".midi" files
midi_split_preview = [*preview_dir.glob("**/*.mid"), *preview_dir.glob("**/*.midi")]

# File names seen so far; the position of a name in this list is its label
file_name_lookup = []


def func_to_get_labels(p1, p2, p3):
    """
    Return the integer label associated with the file name of `p3`.

    The first time a name is seen it is appended to `file_name_lookup`, so labels are
    handed out as 0, 1, 2, ... in order of first appearance; a name that is already
    known gets its existing index back.

    :param p1: unused (presumably the loaded score passed by the dataset — confirm)
    :param p2: unused (presumably the token sequence passed by the dataset — confirm)
    :param p3: object exposing a `.name` attribute, e.g. a `pathlib.Path`
    :return: index of `p3.name` in `file_name_lookup`
    """
    name = p3.name
    try:
        return file_name_lookup.index(name)
    except ValueError:
        # Not seen yet: register it; its label is the slot it has just been given
        file_name_lookup.append(name)
        return len(file_name_lookup) - 1
    
# Same dataset options as for training, plus the labelling function that records the
# file name of every sample it is called for.
kwargs_dataset = dict(
    max_seq_len=1024,
    tokenizer=tokenizer,
    bos_token_id=tokenizer["BOS_None"],
    eos_token_id=tokenizer["EOS_None"],
    func_to_get_labels=func_to_get_labels,
)
dataset_preview = DatasetMIDI(midi_split_preview, **kwargs_dataset)

# Save and Load datasets

In [27]:
import torch

# Write each dataset object to "dataset_<subset>.pt" in the working directory
for _subset_name, _subset in (("train", dataset_train), ("valid", dataset_valid), ("test", dataset_test)):
    torch.save(_subset, f"dataset_{_subset_name}.pt")


In [14]:
import torch

# Reload the dataset objects written by the saving cell.
# SECURITY NOTE: these files contain pickled Python objects (whole dataset instances),
# not plain tensors. Unpickling can execute arbitrary code, so only load files that you
# created yourself.
# `weights_only=False` is passed explicitly because recent PyTorch releases default to
# `weights_only=True`, which refuses to unpickle arbitrary objects such as these datasets.
dataset_train = torch.load("dataset_train.pt", weights_only=False)
dataset_valid = torch.load("dataset_valid.pt", weights_only=False)
dataset_test = torch.load("dataset_test.pt", weights_only=False)



In [None]:
# Sanity check: show the first sample of the training dataset
print(dataset_train[0])


## Model initialization

We will use the [Mistral implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/mistral).
Feel free to explore the documentation and source code to dig deeper.

**You may need to adjust the model's configuration, the training configuration and the maximum input sequence length (cell above) depending on your hardware.**

In [12]:
# Ids of the special tokens, read from the tokenizer's vocabulary
_special_token_ids = {
    f"{token.lower()}_token_id": tokenizer[f"{token}_None"] for token in ("PAD", "BOS", "EOS")
}

# Architecture of the model; its vocabulary size follows the tokenizer's
model_config = MistralConfig(
    vocab_size=len(tokenizer),
    # network size
    hidden_size=512,
    intermediate_size=2048,
    num_hidden_layers=8,
    # attention layout
    num_attention_heads=8,
    num_key_value_heads=4,
    sliding_window=256,
    max_position_embeddings=8192,
    **_special_token_ids,
)

# Instantiate the model from the configuration alone (no pretrained weights are loaded)
model = AutoModelForCausalLM.from_config(model_config)

## Model training

In [None]:
# Evaluation metrics, keyed by name; only accuracy is used
metrics = {"accuracy": load_metric("accuracy")}

def compute_metrics(eval_pred):
    """
    Compute the next-token accuracy for pretraining.

    Must use preprocess_logits function that converts logits to predictions (argmax or sampling).

    In a causal language model the prediction made at position ``t`` is the model's guess
    for the token at position ``t + 1``, while the labels used in this notebook are
    unshifted copies of the inputs. Predictions and labels are therefore shifted by one
    position against each other before being compared; comparing them position by
    position would score the model on the wrong targets.

    :param eval_pred: EvalPrediction containing predictions and labels, both indexed as
        (sequence, position)
    :return: metrics
    """
    predictions, labels = eval_pred
    # Align every prediction with the token it is supposed to predict: drop the last
    # prediction (it has no target) and the first label (nothing predicts it).
    predictions, labels = predictions[:, :-1], labels[:, 1:]
    # Positions labelled -100 are padding and must not count towards the accuracy
    not_pad_mask = labels != -100
    labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]
    return metrics["accuracy"].compute(predictions=predictions.flatten(), references=labels.flatten())

def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """
    Reduce the logits to predicted token ids before they are accumulated during evaluation.

    Keeping only the index of the highest logit along the last dimension, instead of the
    full logits, significantly reduces memory usage and keeps the training tractable.

    :param logits: logits produced by the model
    :param _: second argument supplied by the caller; ignored
    :return: indices of the largest logit along the last dimension (long dtype)
    """
    return logits.argmax(dim=-1)

# Create config for the Trainer
USE_CUDA = cuda_available()
print(USE_CUDA)
# Reduced precision is only enabled on CUDA: bf16 when the device supports it, fp16 otherwise
BF16 = BF16_EVAL = bool(USE_CUDA and is_bf16_supported())
FP16 = FP16_EVAL = bool(USE_CUDA and not BF16)
USE_MPS = not USE_CUDA and mps_available()
training_config = TrainingArguments(
    # NOTE(review): the leading arguments are passed by position — presumably output_dir,
    # overwrite_output_dir, do_train, do_eval, do_predict and the evaluation strategy;
    # verify against the installed transformers version.
    "runs", False, True, True, False, "steps",
    # batching
    per_device_train_batch_size=24, #76% @ 24 batch size
    per_device_eval_batch_size=24,
    gradient_accumulation_steps=3,
    eval_accumulation_steps=None,
    # optimisation
    optim="adamw_torch",
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    label_smoothing_factor=0.,
    gradient_checkpointing=True,
    # schedule
    max_steps=20000,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    # evaluation and checkpoints, both every 1000 steps
    eval_steps=1000,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=5,
    load_best_model_at_end=True,
    # logging
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    report_to=["tensorboard"],
    # hardware and precision
    no_cuda=not USE_CUDA,
    fp16=FP16,
    fp16_full_eval=FP16_EVAL,
    bf16=BF16,
    bf16_full_eval=BF16_EVAL,
    # reproducibility
    seed=444,
)

# Batches are padded with the PAD token, and the inputs are copied to serve as labels
collator = DataCollator(tokenizer["PAD_None"], copy_inputs_as_labels=True)

# Trainer wiring: model, data, and the accuracy metric computed on argmax-ed logits
trainer = Trainer(
    model=model,
    args=training_config,
    train_dataset=dataset_train,
    eval_dataset=dataset_valid,
    data_collator=collator,
    preprocess_logits_for_metrics=preprocess_logits,
    compute_metrics=compute_metrics,
    callbacks=None,
)



In [32]:
# Training
train_result = trainer.train()
# Save the trained model. NOTE(review): the Trainer in this notebook is built without a
# tokenizer, so presumably no tokenizer is written here — the MidiTok tokenizer is
# stored separately in "tokenizer.json".
trainer.save_model()
# Log the training metrics and persist them, together with the trainer state
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

In [6]:

# Reload the trained model from the files in "./runs": first its configuration,
# then the weights stored in the safetensors file.
config = AutoConfig.from_pretrained("./runs/config.json")
model = AutoModelForCausalLM.from_pretrained(
    "./runs/model.safetensors",
    config=config,
    from_tf=False,
)

## Generate music

In [None]:
# Folder that receives the generated MIDI files and their token dumps
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
generation_config = GenerationConfig(
    max_new_tokens=200,  # extends samples by 200 tokens
    num_beams=1,         # no beam search
    do_sample=True,      # but sample instead
    temperature=0.9,
    top_k=15,
    top_p=0.95,
    epsilon_cutoff=3e-4,
    eta_cutoff=1e-3,
    pad_token_id=tokenizer.pad_token_id,
)

# Here the sequences are padded to the left, so that the last token along the time dimension
# is always the last token of each seq, allowing to efficiently generate by batch.
# A copy of the collator is reconfigured: the Trainer holds the original one, and changing
# it in place would silently alter any training or evaluation run after this cell.
gen_collator = deepcopy(collator)
gen_collator.pad_on_left = True
gen_collator.eos_token = None
dataloader_test = DataLoader(dataset_preview, batch_size=24, collate_fn=gen_collator)
model.eval()
count = 0  # running index of the sample across all batches
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    res = model.generate(
        inputs=batch["input_ids"].to(model.device),
        attention_mask=batch["attention_mask"].to(model.device),
        generation_config=generation_config)  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json)
    for prompt, continuation in zip(batch["input_ids"], res):
        # `continuation` starts with the prompt; what follows it is the newly generated part
        generated = continuation[len(prompt):]
        midi = tokenizer.decode([deepcopy(generated.tolist())])
        # Plain lists because the three sequences have different lengths
        tokens = [seq.tolist() for seq in (generated, prompt, continuation)]
        # Add the prompt and the prompt + continuation as extra tracks of the same file
        for tok_seq in tokens[1:]:
            _midi = tokenizer.decode([deepcopy(tok_seq)])
            midi.tracks.append(_midi.tracks[0])

        # NOTE(review): this lookup assumes `file_name_lookup` was filled, in dataset order,
        # by `func_to_get_labels` while the DataLoader loaded the samples — fragile; confirm.
        file_name = file_name_lookup[count]
        print(file_name)
        midi.tracks[0].name = f'Continuation of original sample ({len(generated)} tokens) Original file {file_name}'
        midi.tracks[1].name = f'Original sample ({len(prompt)} tokens)'
        if len(midi.tracks) > 2:
            midi.tracks[2].name = 'Original sample and continuation'
        midi.dump_midi(gen_results_path / f'{count}_{file_name}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}_{file_name}.json')

        count += 1

In [None]:
# Show the file names recorded by func_to_get_labels; a name's position is its label
print(file_name_lookup)