# Full example / tutorial of training a Transformer model (GPT2) for symbolic music generation

## Setup Environment

In [None]:
#@title Install all dependencies (run only once per session)
# Notebook cell: lines starting with "!" are IPython/Colab shell commands, not Python code.

# Shows the GPU(s) visible to this session
!nvidia-smi

# Packages used below: MIDI tokenization (miditok), MIDI I/O (miditoolkit),
# model and training utilities (torch, torchtoolkit, transformers), progress bars (tqdm)
!pip install miditok
!pip install miditoolkit
!pip install torch
!pip install torchtoolkit
!pip install transformers
!pip install tqdm

# Downloads the JSB Chorales archive, extracts it, deletes the archive and
# renames the extracted directory to 'JSB' (the directory globbed for MIDI files later on)
!wget http://www-ens.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip
!unzip 'JSB Chorales.zip'
!rm 'JSB Chorales.zip'
!mv 'JSB Chorales' 'JSB'

from typing import List, Tuple, Callable
from functools import partial
from pathlib import Path
from copy import deepcopy
import json

from torch import LongTensor, Tensor, cat, stack, no_grad, no_grad
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.nn import Module, CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtoolkit.sampling import nucleus
from torchtoolkit.train import train, log_model_parameters, log_cuda_info, select_device
from torchtoolkit.data import create_subsets
from transformers import GPT2LMHeadModel, GPT2Config
from miditok import REMI
from miditoolkit import MidiFile
from tqdm import tqdm

## Convert MIDI files to tokens and load them

In [None]:
# Tokenization parameters
pitch_range = range(21, 109)  # MIDI pitches 21 to 108 included
# NOTE(review): presumably maps beat ranges to a nb of samples per beat — confirm in the miditok docs
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 32
additional_tokens = {
    'Chord': True,
    'Rest': True,
    'Tempo': True,
    'rest_range': (2, 8),  # (half, 8 beats)
    'nb_tempos': 32,  # nb of tempo bins
    'tempo_range': (40, 250),  # (min, max)
    'Program': False,
}

# REMI tokenizer, used to convert every MIDI file of the dataset into token files
tokenizer = REMI(pitch_range, beat_res, nb_velocities, additional_tokens)
midi_paths = list(Path('JSB').glob('**/*.mid'))
tokens_path = Path('JSB_tokens')  # output directory of the token files
tokenizer.tokenize_midi_dataset(midi_paths, tokens_path)

class MIDIDataset:
    r"""Dataset for generator training, made of fixed-size chunks of token sequences.

    Token files (``**/*.json``) are searched recursively in ``data_path``; if there are none,
    MIDI files (``**/*.mid``) are searched instead and tokenized with ``tokenizer``.
    Only the first track of each file is used. Each track is cut into consecutive,
    non-overlapping chunks of up to ``max_seq_len`` tokens; a chunk is kept only if more than
    ``min_seq_len`` tokens remain from its start to the end of the track (so a too-short
    trailing chunk is dropped). Chunks are right-padded with ``padding_token`` so that all
    samples have the same length.

    :param data_path: path containing the real data to load, ex: 'data/death_metal_dataset'.
    :param min_seq_len: minimum sequence length (in nb of tokens)
    :param max_seq_len: maximum sequence length (in nb of tokens), must be > 0
    :param padding_token: padding token, usually 0.
    :param tokenizer: tokenizer object, required when data_path contains MIDI files instead of
            token files. (default: None)
    :raises ValueError: if MIDI files have to be tokenized but no tokenizer was given.
    """

    def __init__(self, data_path: Path, min_seq_len: int, max_seq_len: int, padding_token: int,
                 tokenizer=None):
        data_path = Path(data_path)  # also accepts str paths (``.name`` is used below)
        samples = []
        as_midi = False
        files_paths = list(data_path.glob('**/*.json'))
        if len(files_paths) == 0:
            files_paths = list(data_path.glob('**/*.mid'))
            as_midi = True
            if len(files_paths) > 0 and tokenizer is None:
                raise ValueError(f'No token files found in {data_path}, a tokenizer is required '
                                 f'to tokenize its MIDI files')

        for file_path in tqdm(files_paths, desc=f'Preparing data {data_path.name}'):
            if as_midi:
                tokens = tokenizer.midi_to_tokens(MidiFile(file_path))[0]  # first track
            else:
                with open(file_path, encoding='utf-8') as json_file:
                    tokens = json.load(json_file)['tokens'][0]  # first track
            # Chunk starts 0, max_seq_len, 2*max_seq_len, ... while more than min_seq_len tokens remain
            for i in range(0, len(tokens) - min_seq_len, max_seq_len):
                samples.append(LongTensor(tokens[i:i + max_seq_len]))

        if len(samples) > 0:
            self.samples = pad_sequence(samples, batch_first=True, padding_value=padding_token)
        else:
            # pad_sequence cannot handle an empty list: keep an empty (0, max_seq_len) tensor instead
            self.samples = LongTensor([]).reshape(0, max_seq_len)

    def __getitem__(self, idx) -> LongTensor: return self.samples[idx]

    def __len__(self) -> int: return len(self.samples)

    def __repr__(self): return self.__str__()

    def __str__(self) -> str: return 'No data loaded' if len(self) == 0 else f'{len(self.samples)} samples'


def collate_gen(batch: List[LongTensor]) -> Tuple[LongTensor, LongTensor]:
    """Stack equal-length token sequences and split them into (input, target) pairs.

    The target is the input shifted by one token, for next-token prediction.

    :param batch: list of 1-dim token sequences, all of the same length T.
    :return: inputs of shape (N,T-1) (tokens 0..T-2) and targets of shape (N,T-1) (tokens 1..T-1).
    """
    sequences = stack(batch)  # (N,T)
    inputs = sequences[:, :-1]
    targets = sequences[:, 1:]
    return inputs, targets


# Loads the token sequences and splits them into training and validation subsets
# NOTE(review): presumably 30% of the samples go to the second (validation) subset — confirm in torchtoolkit
dataset = MIDIDataset(tokens_path, max_seq_len=512, min_seq_len=384, padding_token=tokenizer['PAD_None'])
subset_train, subset_valid = create_subsets(dataset, [0.3])
# Training batches are shuffled each epoch; otherwise chunks would always come in file order
dataloader_train = DataLoader(subset_train, batch_size=16, shuffle=True, collate_fn=collate_gen)
dataloader_valid = DataLoader(subset_valid, batch_size=16, collate_fn=collate_gen)

## Create model

We will use the [GPT2 implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/gpt2), which we subclass below to add a training step and a generation method.
Feel free to explore the documentation and source code to dig deeper.

In [None]:
class Transformer(GPT2LMHeadModel):
    r"""GPT2 language model with a training step and a cached, sliding-window generation loop.

    :param config: GPT2 configuration of the model.
    :param padding_token: id of the padding token.
    """

    def __init__(self, config: GPT2Config, padding_token: int):
        super().__init__(config)
        # Only the token embedding gets the padding idx: padding_idx excludes that row from gradient
        # updates. Setting it on the positional embedding (wpe) too would freeze the embedding of
        # *position* ``padding_token`` (e.g. position 0 when the padding id is 0).
        self.transformer.wte.padding_idx = padding_token

        # Buffer, so that it is moved along with the model by ``.to(device)``
        self.register_buffer('padding_token', LongTensor([padding_token]))

    def forward_train(self, x: LongTensor, target: LongTensor, criterion: Module):
        r"""Training step: forward pass and loss computation.

        :param x: input token ids, shape (N,T)
        :param target: target token ids, shape (N,T)
        :param criterion: loss function, called with logits of shape (N,C,T) and the targets
        :return: the logits (N,T,C), the loss, and None (no sampled tokens)
        """
        y = self.forward(x).logits  # (N,T,C)
        loss = criterion(y.transpose(2, 1), target)  # class dim moved to dim 1: (N,C,T)
        return y, loss, None  # no need for sampled

    @no_grad()
    def generate(self, x: LongTensor, nb_steps: int, max_seq_len: int, sampling_func: Callable = None) -> LongTensor:
        r"""Generate (extend) from the generator.

        Padding tokens are removed from a 1-dim input before generating. Past key/values are cached;
        once the cache holds more than ``max_seq_len`` positions its oldest entry is dropped at each
        step (sliding window), and when the next position would exceed the positional embedding
        capacity, generation restarts without cache from the last (prompt length) tokens.

        :param x: input tensor to extend, shape (N,T) or (T), on the same device as the model
        :param nb_steps: number of steps (inferences) to run
        :param max_seq_len: maximum sequence length during inference
        :param sampling_func: sampling function, expected to map the last logits (N,C) to token
                ids (N). (default: nucleus sampling with p=0.9)
        :return: the generated tensor (input followed by the nb_steps new tokens), (T') or (N,T')
        """
        assert max_seq_len <= (nb_pos := self.transformer.wpe.weight.shape[0]), \
            'The maximum sequence length must be <= to the nb of positions the model can handle'
        sampling_func = partial(nucleus, p=0.9) if sampling_func is None else sampling_func
        y = x.clone()
        if y.dim() == 1:
            y = y[x != self.padding_token].unsqueeze(0)  # (T) --> (N,T) with N=1, padding removed
        prompt_len = y.shape[-1]  # un-padded prompt length, reused when restarting without cache

        # past_key_val stores the past computations so that we do not recompute them again
        past_key_val, pos_ids = None, None  # (NLY,2,N,NH,T,DH) & (T'), T' for the non-past-kv part (often 1)
        offset = 0  # nb of entries dropped from the front of the cache since the last (re)start
        tokens = y.clone()  # (N,T)
        for _ in range(nb_steps):
            # Adds the prediction to the target sequence, updates past key values and y sequence
            output = self.forward(tokens, past_key_val, position_ids=pos_ids)
            logits, past_key_val = output.logits, output.past_key_values  # (N,T,C)
            tokens = sampling_func(logits[:, -1]).unsqueeze(1).to(x.device)  # (N,1)
            y = cat([y, tokens], dim=1)  # (N,T+1)

            # The next token's absolute position is (cache length + offset):
            # reset past_kv and offset before it exceeds the positional embedding capacity
            if past_key_val[0][0].shape[-2] + offset >= nb_pos:
                past_key_val, pos_ids, offset = None, None, 0
                tokens = y[..., -prompt_len:].clone()  # starting back with the prompt length

            # Drops the oldest cache entry once the cache is longer than max_seq_len.
            # ``>`` (not ``>=``) keeps pos_ids equal to the true absolute position of the next token:
            # after the first drop, the cache holds max_seq_len entries and 1 entry was dropped.
            if past_key_val is not None and past_key_val[0][0].shape[-2] > max_seq_len:
                offset += 1
                past_key_val = convert_past_key_values_to_tensor(past_key_val)[..., -max_seq_len:, :]
                pos_ids = LongTensor([past_key_val.shape[-2] + offset]).to(x.device)

        return y[0] if x.dim() == 1 else y  # (T) or (N,T)


def convert_past_key_values_to_tensor(past_kv: Tuple) -> Tensor:
    """Stack the past_key_values returned by a HF model into a single Tensor.

    :param past_kv: past key/values as a tuple (one item per layer) of (key, value) tuples,
            i.e. logically (NLY,2,N,NH,T,DH) with the first two dims as tuples
    :return: Tensor of shape (NLY,2,N,NH,T,DH)
    """
    per_layer = [stack(tuple(layer_kv)) for layer_kv in past_kv]  # NLY tensors of shape (2,N,NH,T,DH)
    return stack(per_layer)  # (NLY,2,N,NH,T,DH)


# Model creation: GPT2 configuration sized from the tokenizer vocabulary and special tokens
config = GPT2Config(
    vocab_size=len(tokenizer),
    n_positions=2048,  # size of the positional embedding
    n_embd=512,
    n_layer=8,
    n_head=8,
    n_inner=2048,
    embd_pdrop=.1,
    attn_pdrop=.1,
    resid_pdrop=.1,
    bos_token_id=tokenizer['SOS_None'],
    eos_token_id=tokenizer['EOS_None'],
)
model = Transformer(config, padding_token=tokenizer['PAD_None'])

## Train the model

In [None]:
save_path = Path('run')
device = select_device(True)
model = model.to(device)
# Padded target positions are excluded from the loss; otherwise the model would also be
# trained to predict padding after the end of the shorter (padded) sequences
criterion = CrossEntropyLoss(ignore_index=tokenizer['PAD_None'])
optimizer = Adam(params=model.parameters(), lr=1e-4, weight_decay=1e-2)
lr_scheduler = CosineAnnealingWarmRestarts(optimizer, 20, 2)

log_model_parameters(model)
if device.type == 'cuda':
    log_cuda_info()

# NOTE(review): the positional arguments (50000, 20, 10, ...) follow torchtoolkit's `train`
# signature (presumably nb of training steps and validation/logging intervals) — confirm there
train(model, criterion, optimizer, dataloader_train, dataloader_valid, 50000, 20, 10, None,
      'TRAINING MIDI GENERATOR', None, 10, None, lr_scheduler=lr_scheduler, device_=device,
      use_amp=True, saving_dir=save_path)

## Generate music

In [None]:
nb_inferences = 512  # extends samples by 512 tokens
gen_results_path = Path('gen_res')
gen_results_path.mkdir(parents=True, exist_ok=True)  # dump / save_tokens do not create it
model_device = next(model.parameters()).device  # prompts must be on the model's device
model.eval()  # disables dropout during generation

for i, sample in enumerate(subset_valid):
    seq = sample[sample != tokenizer['PAD_None']]  # trims the sample if padded

    # generates auto regressively, then brings the result back next to ``seq``
    res = model.generate(seq.to(model_device), nb_inferences, 512).cpu()
    continuation = res[len(seq):]  # just generated tokens
    tokens = [t.tolist() for t in (continuation, seq, res)]  # converts to lists
    # deepcopy: the token lists are saved again below, after the MIDI conversion
    midi = tokenizer.tokens_to_midi(deepcopy(tokens), time_division=384)
    midi.instruments[0].name = f'Continuation of original sample ({len(continuation)} tokens)'
    midi.instruments[1].name = f'Original sample ({len(seq)} tokens)'
    midi.instruments[2].name = 'Original sample and continuation'
    midi.dump(gen_results_path / f'{i}.mid')
    tokenizer.save_tokens(tokens, gen_results_path / f'{i}.json')