# Full example of music generation, with the Hugging Face GPT2 Transformer model

## Setup Environment

In [None]:
#@title Install all dependencies (run only once per session)

# Prints the status of the NVIDIA GPU attached to the runtime (fails harmlessly if there is none)
!nvidia-smi

# Python packages imported by the cells below
!pip install miditok
!pip install miditoolkit
!pip install torch
!pip install torchtoolkit
!pip install transformers
!pip install tqdm

# Downloads the JSB Chorales MIDI archive, extracts it, deletes the archive
# and renames the extracted folder to 'JSB' (the path read by the tokenization cell)
!wget http://www-ens.iro.umontreal.ca/~boulanni/JSB%20Chorales.zip
!unzip 'JSB Chorales.zip'
!rm 'JSB Chorales.zip'
!mv 'JSB Chorales' 'JSB'

from typing import List, Tuple, Callable
from functools import partial
from pathlib import Path
from copy import deepcopy
import json

from torch import LongTensor, cat, stack, full, flip
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from torch.nn import Module, CrossEntropyLoss
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torchtoolkit.train import train, log_model_parameters, log_cuda_info, select_device
from torchtoolkit.data import create_subsets
from transformers import GPT2LMHeadModel, GPT2Config
from miditok import REMI
from miditoolkit import MidiFile
from tqdm import tqdm

## Convert MIDI files to tokens, and load them for training

In [None]:
# Our tokenization parameters
pitch_range = range(21, 109)  # MIDI pitches 21 to 108 included
# NOTE(review): presumably {(first beat, last beat): samples per beat} — confirm against the miditok docs
beat_res = {(0, 4): 8, (4, 12): 4}
nb_velocities = 32  # nb of velocity bins
additional_tokens = {'Chord': True, 'Rest': True, 'Tempo': True,
                     'rest_range': (2, 8),  # (half, 8 beats)
                     'nb_tempos': 32,  # nb of tempo bins
                     'tempo_range': (40, 250),  # (min, max)
                     'Program': False}

# Creates the tokenizer, then converts every MIDI file found under 'JSB' to tokens.
# The results are written to tokens_path, which is the directory loaded by MIDIDataset below.
tokens_path = Path('JSB_tokens')
tokenizer = REMI(pitch_range, beat_res, nb_velocities, additional_tokens, sos_eos=True)  # REMI encoding
midi_paths = list(Path('JSB').glob('**/*.mid'))
tokenizer.tokenize_midi_dataset(midi_paths, tokens_path)

class MIDIDataset:
    r"""Dataset of token sequences for generator training.

    Token files (``*.json``) are searched recursively under ``data_path``. If none is found,
    MIDI files (``*.mid``) are searched instead and tokenized on the fly with ``tokenizer``.
    Only the first track of each file is used. Each track is cut into consecutive,
    non-overlapping samples of at most ``max_seq_len`` tokens. Cutting stops as soon as
    ``min_seq_len`` tokens or fewer remain, so that remainder is discarded.

    :param data_path: path containing the data to load, ex: 'data/death_metal_dataset'.
    :param min_seq_len: minimum sequence length (in nb of tokens)
    :param max_seq_len: maximum sequence length (in nb of tokens), must be >= 1
    :param padding_token: padding token, usually 0.
    :param sos_token: "Start Of Sequence" token, to be placed at the beginning of each token sequence.
    :param tokenizer: tokenizer object, required when data_path only contains MIDI files. (default: None)
    :raises ValueError: if max_seq_len < 1, or if MIDI files have to be tokenized and no tokenizer is given.
    """

    def __init__(self, data_path: Path, min_seq_len: int, max_seq_len: int, padding_token: int, sos_token: int,
                 tokenizer=None):
        if max_seq_len < 1:
            # an empty slice would never advance the cursor in the loop below
            raise ValueError(f'max_seq_len must be >= 1, got {max_seq_len}')
        data_path = Path(data_path)  # also accepts a plain string
        self.pad_token = padding_token
        self.sos_token = sos_token

        files_paths = list(data_path.glob('**/*.json'))
        as_midi = not files_paths
        if as_midi:  # no token file, falls back to MIDI files
            files_paths = list(data_path.glob('**/*.mid'))
            if files_paths and tokenizer is None:
                raise ValueError(f'No token file found in {data_path}: a tokenizer is required to load MIDI files')

        samples = []
        for file_path in tqdm(files_paths, desc=f'Preparing data {data_path.name}'):
            if as_midi:
                tokens = tokenizer.midi_to_tokens(MidiFile(file_path))[0]  # first track
            else:
                with open(file_path) as json_file:
                    tokens = json.load(json_file)['tokens'][0]  # first track

            nb_tokens = len(tokens)
            start = 0
            # Stops when min_seq_len tokens or fewer remain: the last sample would be too short
            while start < nb_tokens and nb_tokens - start > min_seq_len:
                sample = tokens[start:start + max_seq_len]
                samples.append(LongTensor(sample))
                start += len(sample)

        self.samples = samples

    def collate_fn(self, batch: List[LongTensor]) -> Tuple[LongTensor, LongTensor]:
        r"""Collate function for training: right-pads the sequences, prepends the SOS token
        and returns the (input, target) pair, the target being the input shifted by one step.

        :param batch: list of token sequences, of shape (T) or (T,Z)
        :return: input and target tensors, each of shape (N,T) or (N,T,Z)
        """
        batch = pad_sequence(batch, batch_first=True, padding_value=self.pad_token)  # (N,T) or (N,T,Z)
        sos_shape = list(batch.shape)
        sos_shape[1] = 1  # (N,1) or (N,1,Z)
        sos = full(sos_shape, self.sos_token, dtype=batch.dtype)
        batch = cat([sos, batch], dim=1)  # adds sos token to every samples
        return batch[:, :-1], batch[:, 1:]

    def collate_fn_infer(self, batch: List[LongTensor]) -> LongTensor:
        r"""Collate function for inference: prepends the SOS token and pads the sequences to the LEFT.

        The sequences are padded to the left so that the last element along the time dimension
        is always the last of each seq, allowing to efficiently generate by batch.

        :param batch: list of token sequences, of shape (T) or (T,Z)
        :return: left-padded tensor of shape (N,T) or (N,T,Z)
        """
        sos_shape = (1,) if batch[0].dim() == 1 else (1, batch[0].shape[-1])  # (1) or (1,Z)
        # pad_sequence only pads to the right: sequences are reversed, padded, then reversed back
        reversed_seqs = [flip(cat([full(sos_shape, self.sos_token, dtype=seq.dtype), seq], dim=0), dims=(0, ))
                         for seq in batch]
        padded = pad_sequence(reversed_seqs, batch_first=True, padding_value=self.pad_token)  # (N,T) or (N,T,Z)
        return flip(padded, dims=(1, )).long()

    def __getitem__(self, idx) -> LongTensor: return self.samples[idx]

    def __len__(self) -> int: return len(self.samples)

    def __repr__(self): return self.__str__()

    def __str__(self) -> str: return 'No data loaded' if len(self) == 0 else f'{len(self.samples)} samples'


# Builds the dataset from the saved token files, splits it in two subsets
# and wraps each subset in a data loader using the training collate function
dataset = MIDIDataset(
    data_path=tokens_path,
    min_seq_len=384,
    max_seq_len=512,
    padding_token=tokenizer['PAD_None'],
    sos_token=tokenizer['SOS_None'],
)
subset_train, subset_valid = create_subsets(dataset, [0.3])
dataloader_train, dataloader_valid = (
    DataLoader(subset, batch_size=16, collate_fn=dataset.collate_fn)
    for subset in (subset_train, subset_valid)
)

## Create the model

We will use the [GPT2 implementation of Hugging Face](https://huggingface.co/docs/transformers/model_doc/gpt2).
Feel free to explore the documentation and source code to dig deeper.

In [None]:
class Transformer(GPT2LMHeadModel):
    r"""GPT2 language model with a training forward pass returning the loss.

    :param config: GPT2 configuration of the model
    :param padding_token: id of the padding token, set as padding index of the token embedding
    """

    def __init__(self, config: GPT2Config, padding_token: int):
        super().__init__(config)
        # Only the token embedding (wte) is indexed by token ids. The positional embedding (wpe)
        # is indexed by positions: giving it the padding token id as padding_idx would stop the
        # gradient of the embedding of that *position*, so it is left untouched.
        self.transformer.wte.padding_idx = padding_token

    def forward_train(self, x: LongTensor, target: LongTensor, criterion: Module):
        r"""Forward pass computing the training loss.

        :param x: input tokens, of shape (N,T)
        :param target: target tokens, of shape (N,T)
        :param criterion: loss function, called with logits of shape (N,C,T) and the target
        :return: the logits of shape (N,T,C), the loss, and None in place of sampled tokens
        """
        y = self.forward(x).logits  # (N,T,C)
        loss = criterion(y.transpose(2, 1), target)  # criterion expects the classes on dim 1
        return y, loss, None  # no need for sampled


# Creates model
# The special token ids are stored in the config so that `generate` can use them.
# Note the keyword is `pad_token_id`: an unknown keyword such as `padding_token_id` is accepted
# silently and stored as a plain attribute, leaving `pad_token_id` unset.
config = GPT2Config(vocab_size=len(tokenizer), n_positions=2048, n_embd=512, n_layer=8, n_head=8,
                    n_inner=2048, resid_pdrop=.1, embd_pdrop=.1, attn_pdrop=.1,
                    pad_token_id=tokenizer['PAD_None'], bos_token_id=tokenizer['SOS_None'],
                    eos_token_id=tokenizer['EOS_None'])
model = Transformer(config, padding_token=tokenizer['PAD_None'])

## Train it

In [None]:
save_path = Path('run')  # directory given to `train` as saving directory
device = select_device(True)
model = model.to(device)
# Batches are padded by the collate function: padded target positions are excluded from the loss
criterion = CrossEntropyLoss(ignore_index=tokenizer['PAD_None'])
optimizer = Adam(params=model.parameters(), lr=2e-5, weight_decay=1e-2)
lr_scheduler = CosineAnnealingWarmRestarts(optimizer, 20, 2)  # T_0=20, T_mult=2

log_model_parameters(model)
if device.type == 'cuda':
    log_cuda_info()

train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    dataloader_train=dataloader_train,
    dataloader_valid=dataloader_valid,
    nb_steps=50000,
    valid_intvl=20,
    nb_valid_steps=10,
    log_intvl=100,
    pbar_desc='MIDI generator training',
    lr_scheduler=lr_scheduler,
    use_amp=True,
    gradient_clip_norm=3,
    device=device,
    saving_dir=save_path
)

## Generate music

In [None]:
nb_inferences = 512  # extends samples by 512 tokens
(gen_results_path := Path('gen_res')).mkdir(parents=True, exist_ok=True)
dataloader_test = DataLoader(subset_valid, batch_size=16, collate_fn=dataset.collate_fn_infer)
pad_token = tokenizer['PAD_None']

model.eval()
count = 0  # index of the generated sample, used as file name
for batch in tqdm(dataloader_test, desc='Testing model / Generating results'):  # (N,T)
    batch = batch.to(device)  # the model has been moved to this device, whatever its type
    # Prompts are padded to the left: the mask and the padding id are given explicitly
    # so that padding positions are not attended to. Sampling is handled in generate method.
    attention_mask = (batch != pad_token).long()  # (N,T)
    res = model.generate(batch, attention_mask=attention_mask, pad_token_id=pad_token,
                         do_sample=True, num_beams=5, top_p=0.9, max_new_tokens=nb_inferences)  # (N,T)

    # Saves the generated music, as MIDI files and tokens (json)
    for prompt, continuation in zip(batch, res):
        generated = continuation[len(prompt):]  # the output starts with the prompt
        tokens = [generated, prompt, continuation]  # one track each, seqs of dif. lengths
        tokens = [seq.tolist() for seq in tokens]
        # deepcopy so that the tokens saved below are left untouched by the conversion
        midi = tokenizer.tokens_to_midi(deepcopy(tokens), time_division=384)
        midi.instruments[0].name = f'Continuation of original sample ({len(generated)} tokens)'
        midi.instruments[1].name = f'Original sample ({len(prompt)} tokens)'
        midi.instruments[2].name = 'Original sample and continuation'
        midi.dump(gen_results_path / f'{count}.mid')
        tokenizer.save_tokens(tokens, gen_results_path / f'{count}.json')

        count += 1