# Full example with the Hugging Face Transformers package

This notebook shows how to train a model (GPT2) and generate music with it, using the Hugging Face Transformers package.

## Setup Environment

In [None]:
#@title Install all dependencies (run only once per session)

# Prints the GPU(s) visible to this session
!nvidia-smi

# Python packages imported by the rest of the notebook
!pip install miditok
!pip install miditoolkit
!pip install torch
!pip install torchtoolkit
!pip install transformers
!pip install evaluate
!pip install tqdm

# Downloads the JSB Chorales archive, extracts it, deletes the archive and renames
# the extracted folder to 'JSB' (the folder globbed for MIDI files further down).
# NOTE(review): on a second run 'JSB' already exists, so `mv` presumably nests the
# new folder inside it — hence "run only once per session"; confirm.
!wget http://www-ens.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip
!unzip 'JSB Chorales.zip'
!rm 'JSB Chorales.zip'
!mv 'JSB Chorales' 'JSB'

from typing import Any, Callable, Dict, List, Tuple, Union
from functools import partial
from pathlib import Path
from copy import deepcopy
import json

from torch import LongTensor, Tensor, argmax, stack
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from torchtoolkit.train import select_device
from torchtoolkit.data import create_subsets
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainingArguments
from transformers.data.data_collator import DataCollatorMixin
from evaluate import load as load_metric
from miditok import REMI, MIDITokenizer
from miditoolkit import MidiFile
from tqdm import tqdm

## Convert MIDI files to tokens, and load them for training

In [None]:
# Tokenizer parameters
pitch_range = range(21, 109)
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 32
additional_tokens = {
    'Chord': True,
    'Rest': True,
    'Tempo': True,
    'rest_range': (2, 8),  # (half, 8 beats)
    'nb_tempos': 32,  # nb of tempo bins
    'tempo_range': (40, 250),  # (min, max)
    'Program': False,
}

# Input MIDI files and output directory for the token files
midi_paths = list(Path('JSB').glob('**/*.mid'))
tokens_path = Path('JSB_tokens')

# Creates the tokenizer (REMI encoding), then converts the MIDI files to tokens
tokenizer = REMI(pitch_range, beat_res, nb_velocities, additional_tokens, sos_eos=True)
tokenizer.tokenize_midi_dataset(midi_paths, tokens_path)

class MIDIDataset(Dataset):
    r"""Dataset for generator training.

    Every file is loaded as one token sequence (first track only), which is then cut
    into consecutive, non-overlapping samples of at most ``max_seq_len`` tokens. A
    chunk is only created while more than ``min_seq_len`` tokens remain in the
    sequence, so a too-short tail is dropped.

    :param files_paths: list of paths to files to load. Files with a ``.mid`` / ``.midi``
        extension (case-insensitive) are tokenized on the fly, any other file is read
        as a JSON file holding a ``tokens`` entry.
    :param min_seq_len: a sample is created only if more than this number of tokens remain.
    :param max_seq_len: maximum number of tokens per sample.
    :param tokenizer: tokenizer object, to use to load MIDIs instead of tokens.
        Required only when ``files_paths`` contains MIDI files. (default: None)
    :raises ValueError: if a MIDI file is given but no tokenizer was provided.
    """

    def __init__(self, files_paths: List[Path], min_seq_len: int, max_seq_len: int, tokenizer: MIDITokenizer = None):
        samples = []

        # Only peek at the first path for the progress-bar label when there is one,
        # so that an empty list yields an empty dataset instead of an IndexError.
        desc = f'Loading data: {files_paths[0].parent}' if len(files_paths) > 0 else 'Loading data'

        for file_path in tqdm(files_paths, desc=desc):
            # Path.suffix includes the leading dot, hence the dotted extensions here
            if file_path.suffix.lower() in (".mid", ".midi"):
                if tokenizer is None:
                    raise ValueError(f'A tokenizer is required to load the MIDI file {file_path}')
                midi = MidiFile(file_path)
                del midi.instruments[1:]  # removes all tracks except first one
                tokens = tokenizer.midi_to_tokens(midi)[0]
            else:
                with open(file_path) as json_file:
                    tokens = json.load(json_file)['tokens'][0]  # first track

            # Chunks start every max_seq_len tokens; the range stops as soon as
            # min_seq_len tokens or fewer remain (last sample would be too short).
            for i in range(0, len(tokens) - min_seq_len, max_seq_len):
                samples.append(LongTensor(tokens[i:i + max_seq_len]))

        self.samples = samples

    def __getitem__(self, idx) -> Dict[str, LongTensor]:
        # Labels are the inputs themselves (causal language modeling)
        return {"input_ids": self.samples[idx], "labels": self.samples[idx]}

    def __len__(self) -> int: return len(self.samples)

    def __repr__(self): return self.__str__()

    def __str__(self) -> str: return 'No data loaded' if len(self) == 0 else f'{len(self.samples)} samples'


def _pad_batch(examples: List[Dict[str, LongTensor]], pad_token: int) -> LongTensor:
    """Collate `examples` into a batch, using the information in `tokenizer` for padding if necessary."""

    length_of_first = examples[0]["input_ids"].size(0)

    # Check if padding is necessary.
    are_tensors_same_length = all(x["input_ids"].size(0) == length_of_first for x in examples)
    if are_tensors_same_length:
        return stack([e["input_ids"] for e in examples], dim=0).long()

    # Creating the full tensor and filling it with our data.
    return pad_sequence([e["input_ids"] for e in examples], batch_first=True, padding_value=pad_token).long()


class DataCollatorGen(DataCollatorMixin):
    def __init__(self, pad_token: int, return_tensors: str = "pt"):
        """Collator that only pads the sequences of a batch.
        The input ids are padded with the given pad token, whereas the labels
        are padded with -100.

        :param pad_token: id of the padding token
        :param return_tensors: stored on the instance; not read by ``__call__``
        """
        self.return_tensors = return_tensors
        self.pad_token = pad_token

    def __call__(self, batch: List[Dict[str, Any]], return_tensors=None) -> Dict[str, LongTensor]:
        input_ids = _pad_batch(batch, self.pad_token)
        labels = _pad_batch(batch, -100)  # also built from the "input_ids" entries
        return {"input_ids": input_ids, "labels": labels}  # will be shifted in GPT2LMHead forward


# Loads tokens and create data loaders for training
# MIDIDataset expects a list of file paths (not a directory) and takes no padding / SOS
# token argument, so the token files are first collected from `tokens_path`.
# Assumes the token files were saved with a .json extension — TODO confirm.
tokens_paths = list(tokens_path.glob('**/*.json'))
dataset = MIDIDataset(tokens_paths, max_seq_len=512, min_seq_len=384)
subset_train, subset_valid = create_subsets(dataset, [0.3])

## Create the model

We will use the [GPT2 implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/gpt2).
Feel free to explore the documentation and source code to dig deeper.

In [None]:
# Creates model
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=2048,
    n_embd=512,
    n_layer=8,
    n_head=8,
    n_inner=2048,
    resid_pdrop=.1,
    embd_pdrop=.1,
    attn_pdrop=.1,
    # Special token ids. The config field is `pad_token_id`; an unknown keyword such as
    # `padding_token_id` would leave the actual pad token id unset.
    pad_token_id=tokenizer['PAD_None'],
    bos_token_id=tokenizer['SOS_None'],
    eos_token_id=tokenizer['EOS_None'],
)
model = GPT2LMHeadModel(config)

## Train it

In [None]:
metrics = {metric: load_metric(metric) for metric in ["accuracy"]}

def compute_metrics(eval_pred):
    """Computes metrics for pretraining.
    Must use a preprocess_logits function that converts logits to predictions (argmax or sampling).

    The labels are shifted inside the model's forward pass, so the prediction made at
    position t is for the token at position t + 1. Predictions and labels are realigned
    accordingly before comparing them; padded positions (label -100) are ignored.

    :param eval_pred: EvalPrediction containing predictions and labels
    :return: metrics
    """
    predictions, labels = eval_pred
    # Realign: drop the last prediction (no target) and the first label (never predicted)
    predictions, labels = predictions[..., :-1], labels[..., 1:]
    not_pad_mask = labels != -100
    labels, predictions = labels[not_pad_mask], predictions[not_pad_mask]
    return metrics["accuracy"].compute(predictions=predictions.flatten(), references=labels.flatten())

def preprocess_logits(logits: Tensor, _: Tensor) -> Tensor:
    """Reduces the logits to predicted token ids before they are accumulated during evaluation.

    Only the index of the highest logit along the last dimension is kept, which
    significantly lowers the memory usage and makes the training tractable.

    :param logits: logits of the model.
    :param _: labels, unused.
    :return: ids of the highest logits (long dtype).
    """
    return logits.argmax(dim=-1)

# Training hyper-parameters handed to the Trainer.
# NOTE(review): the six leading positional values presumably map to output_dir,
# overwrite_output_dir, do_train, do_eval, do_predict and the evaluation strategy —
# verify against the signature of the installed transformers version.
training_config = TrainingArguments(
    "runs", False, True, True, False, "steps",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=48,
    gradient_accumulation_steps=3,
    eval_accumulation_steps=None,
    eval_steps=1000,  # same interval as save_steps below
    learning_rate=1e-4,
    weight_decay=0.01,
    max_grad_norm=3.0,
    max_steps=100000,
    lr_scheduler_type="cosine_with_restarts",
    warmup_ratio=0.3,
    log_level="debug",
    logging_strategy="steps",
    logging_steps=20,
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=5,
    no_cuda=False,
    seed=444,
    fp16=True,  # NOTE(review): presumably requires a CUDA device — confirm before running on CPU
    load_best_model_at_end=True,
    label_smoothing_factor=0.,
    optim="adamw_torch",
    report_to=["tensorboard"],
    gradient_checkpointing=True,
)

# Bundles the model, the training arguments, the data and the metric hooks
trainer = Trainer(
    model=model,
    args=training_config,
    train_dataset=subset_train,
    eval_dataset=subset_valid,
    data_collator=DataCollatorGen(tokenizer["PAD_None"]),
    preprocess_logits_for_metrics=preprocess_logits,
    compute_metrics=compute_metrics,
    callbacks=None,
)

# Training, then saving of the model, the training metrics and the trainer state
train_result = trainer.train()
# NOTE(review): no tokenizer is given to the Trainer above, so presumably only the
# model is saved here — confirm.
trainer.save_model()
trainer.log_metrics("train", train_result.metrics)
trainer.save_metrics("train", train_result.metrics)
trainer.save_state()

## Generate music

In [None]:
def collate_gen(batch: List[Dict[str, Union[LongTensor, int]]]) -> LongTensor:
    """Pads the ``input_ids`` of the samples with the PAD token to build a batch of prompts."""
    return _pad_batch(batch, tokenizer["PAD_None"])  # (N,T)

nb_inferences = 512  # extends samples by 512 tokens
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
dataloader_test = DataLoader(subset_valid, batch_size=16, collate_fn=collate_gen)
device = select_device(True)
model.eval()
model = model.to(device)
count = 0  # running index of the generated files, across batches
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    # The model was moved to `device` above, so the prompts must always follow it,
    # whatever the device type (this is a no-op when they already live there).
    batch = batch.to(device)
    # NOTE(review): prompts shorter than the longest one of the batch are padded by
    # collate_gen, so their continuation is generated after PAD tokens — consider
    # left padding with an attention mask, or a batch size of 1.
    res = model.generate(
        batch,
        do_sample=True,
        num_beams=1,
        top_p=0.9,
        max_new_tokens=nb_inferences,  # was a hard-coded 600 that ignored nb_inferences
        pad_token_id=tokenizer["PAD_None"],  # explicit, so it does not depend on the model config
    )  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json)
    for prompt, continuation in zip(batch, res):
        generated = continuation[len(prompt):]
        tokens = [generated, prompt, continuation]  # list compr. as seqs of dif. lengths
        tokens = [seq.tolist() for seq in tokens]
        # deepcopy: the token lists are saved as json right after the MIDI conversion
        midi = tokenizer.tokens_to_midi(deepcopy(tokens), time_division=384)
        midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.instruments[2].name = 'Original sample and continuation'
        midi.dump(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        count += 1