# MidiTok Full Workflow Example/Tutorial

***

Credit for the GPT2-RGA code used in this Colab goes to @Sashmark97 (https://github.com/Sashmark97/midigen) and @Damon Gwinn (https://github.com/gwinndr/MusicTransformer-Pytorch)

***

WARNING: This complete implementation is a functioning Artificial Intelligence model. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2021

***

# (Setup Environment)

In [None]:
#@title nvidia-smi gpu check
# Shell escape: runs the `nvidia-smi` command so its GPU report is printed in the cell output.
!nvidia-smi

In [None]:
#@title Install all dependencies (run only once per session)

# General-purpose packages used by the cells below.
!pip install torch
!pip install tqdm
!pip install matplotlib

# MIDI tokenizer library (source of the REMI / CPWord imports).
# NOTE(review): `miditoolkit` is imported later but not installed explicitly here;
# presumably it is pulled in as a dependency of miditok -- confirm.
!pip install miditok

# Fetch GPT2RGA.py into the working directory so `from GPT2RGA import *` can find it.
!wget 'https://raw.githubusercontent.com/asigalov61/tegridy-tools/main/tegridy-tools/GPT2RGA.py'

# Fetch the sample MIDI file that the tokenizing cell loads.
!wget 'https://github.com/asigalov61/Optimus-VIRTUOSO/raw/main/Samples/Relative-Global-Attention/Optimus-VIRTUOSO-RGA-Edition-Main-Sample.mid'

In [None]:
#@title Import all needed modules

print('Loading needed modules. Please wait...')
# Standard library
import os
from datetime import datetime
import secrets
import copy
# NOTE: `import tqdm` binds the tqdm *module*; the progress-bar callable itself
# lives at `tqdm.tqdm` / `tqdm.auto.tqdm`.
import tqdm
from tqdm import auto

# Star import of the GPT2RGA.py helper module fetched by the install cell.
# NOTE(review): the cells below use names that are never defined in this notebook
# (e.g. GPTConfig, GPT, EPianoDataset, DataLoader, torch, max_seq, epochs);
# presumably they all come from this star import -- confirm against GPT2RGA.py.
from GPT2RGA import *

# MIDI tokenization and MIDI file I/O
from miditok import REMI, CPWord, get_midi_programs
from miditoolkit import MidiFile

# Plotting (used for the training-loss graph)
import matplotlib.pyplot as plt

print('Done!')

In [None]:
print('Tokenizing source MIDI file...')

# Tokenizer parameters
pitch_range = range(21, 109)
beat_res = {(0, 4): 8, (4, 12): 4}
num_velocities = 32
# Alias: other cells refer to the velocity-bin count as `nb_velocities`.
# Defining both names keeps every cell runnable regardless of which one it uses.
nb_velocities = num_velocities
additional_tokens = {'Chord': True, 'Rest': True, 'Tempo': True,
                     'rest_range': (2, 8),  # (half, 8 beats)
                     'num_tempos': 32,  # num of tempo bins
                     'tempo_range': (40, 250),  # (min, max)
                     'Program': False}

# Creates the tokenizer (REMI encoding).
# The original passed the undefined name `nb_velocities` here, which raised NameError.
tokenizer = REMI(pitch_range, beat_res, num_velocities, additional_tokens)

# Alternative CP encoding:
# tokenizer = CPWord(pitch_range, beat_res, num_velocities, additional_tokens)

# Loads the sample MIDI downloaded by the install cell
midi = MidiFile('Optimus-VIRTUOSO-RGA-Edition-Main-Sample.mid')

# Converts the MIDI to tokens (the reverse conversion happens in the generation cells)
tokens = tokenizer.midi_to_tokens(midi)

print('Done!')

# (QUICK DEMO)

In [None]:
#@title Load processed INTs datasets
number_of_batches = 16 #@param {type:"slider", min:2, max:32, step:2}
n_workers = 6

separator = '=' * 50

print(separator)
print('Prepping INTs datasets...')

# Only the first token sequence returned by the tokenizer is used as training data.
train_data = []
train_data.extend(tokens[0])

# Validation (and the provisional test slice) reuse the first half of the training data.
half_point = int(len(train_data) * 0.5)
val_dataset = train_data[:half_point]
test_dataset = train_data[:half_point]

train_list, val_list, test_list = train_data, val_dataset, []
print(separator)

print('Processing INTs datasets...')
# The raw lists bound above are replaced by dataset objects from here on.
train_dataset = EPianoDataset(train_list, max_seq, random_seq)
val_dataset = EPianoDataset(val_list, max_seq)
test_dataset = EPianoDataset(test_list, max_seq)
print(separator)

print('Loading INTs datasets...')
batch_size = number_of_batches
train_loader = DataLoader(train_dataset, batch_size=batch_size, num_workers=n_workers, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, num_workers=n_workers)
test_loader = DataLoader(test_dataset, batch_size=batch_size, num_workers=n_workers)
print(separator)

print('Total INTs in the dataset', len(train_data))
print('Total unique INTs in the dataset', len(set(train_data)))
print('Max INT in the dataset', max(train_data))
print('Min INT in the dataset', min(train_data))
print(separator)

print('Checking datasets shapes...')
print(separator)

# Print the shapes of the first batch (if any) of every loader, in a fixed order.
loaders_to_check = (
    ('Train loader', train_loader),
    ('Validation loader', val_loader),
    ('Test loader', test_loader),
)
for loader_title, loader in loaders_to_check:
    print(loader_title)
    for x, tgt in loader:
        print(f'X shape: {x.shape}')
        print(f'Target shape: {tgt.shape}')
        break  # one batch is enough for a shape check
    print(separator)

print('Done! Enjoy! :)')
print(separator)

# (TRAIN)

# Train the model

In [None]:
#@title Train

print('MidiTok Model Trainer')

# --- Model -------------------------------------------------------------
config = GPTConfig(
    VOCAB_SIZE,
    max_seq,
    dim_feedforward=dim_feedforward,
    n_layer=6,
    n_head=8,
    n_embd=512,
    enable_rpr=True,
    er_len=max_seq,
)
model = GPT(config).to(get_device())

# --- Loss, optimizer and learning-rate schedule --------------------------
init_step = 0
lr = LR_DEFAULT_START
lr_stepper = LrStepTracker(d_model, SCHEDULER_WARMUP_STEPS, init_step)

# The same loss object is shared by training and evaluation.
eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)
train_loss_func = eval_loss_func

opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)
lr_scheduler = LambdaLR(opt, lr_stepper.step)

# --- Best-so-far bookkeeping ---------------------------------------------
best_eval_acc, best_eval_acc_epoch = 0.0, -1
best_eval_loss, best_eval_loss_epoch = float("inf"), -1
best_acc_file = 'gpt2_rpr_acc.pth'
best_loss_file = 'gpt2_rpr_loss.pth'
loss_train = []
loss_val = []
acc_val = []

# --- Training loop -------------------------------------------------------
for epoch in range(epochs):
    epoch_number = epoch + 1  # epochs are reported 1-based
    new_best = False

    loss = train(epoch_number, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)
    loss_train.append(loss)

    eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)
    loss_val.append(eval_loss)
    acc_val.append(eval_acc)

    # Checkpoint whenever validation accuracy improves...
    if eval_acc > best_eval_acc:
        best_eval_acc, best_eval_acc_epoch = eval_acc, epoch_number
        torch.save(model.state_dict(), best_acc_file)
        new_best = True

    # ...and, independently, whenever validation loss improves.
    if eval_loss < best_eval_loss:
        best_eval_loss, best_eval_loss_epoch = eval_loss, epoch_number
        torch.save(model.state_dict(), best_loss_file)
        new_best = True

    if not new_best:
        continue

    print("Best eval acc epoch:", best_eval_acc_epoch)
    print("Best eval acc:", best_eval_acc)
    print("")
    print("Best eval loss epoch:", best_eval_loss_epoch)
    print("Best eval loss:", best_eval_loss)

In [None]:
#@title Plot resulting training loss graph
# `loss_train` holds one sequence of losses per epoch; flatten it into a single series.
tr_loss_list = []
for epoch_losses in loss_train:
    tr_loss_list.extend(epoch_losses)

# Plot loss against its position in the flattened series (blue line), then save to disk.
step_indices = list(range(len(tr_loss_list)))
plt.plot(step_indices, tr_loss_list, 'b')
plt.savefig('MidiTok-Training-Loss.png')

# (SAVE/LOAD)

In [None]:
#@title Save the model

full_path_to_model_checkpoint = "MidiTok.pth" #@param {type:"string"}

print('Saving the model...')
# Only the weights (state dict) are written, so loading requires rebuilding the model first.
model_weights = model.state_dict()
torch.save(model_weights, full_path_to_model_checkpoint)
print('Done!')

In [None]:
#@title Load/Reload the model
full_path_to_model_checkpoint = "MidiTok.pth" #@param {type:"string"}

print('Loading the model...')
# The architecture must match the one used when the checkpoint was saved,
# because only the state dict (weights) is stored in the file.
config = GPTConfig(VOCAB_SIZE, 
                   max_seq,
                   dim_feedforward=dim_feedforward,
                   n_layer=6, 
                   n_head=8, 
                   n_embd=512,
                   enable_rpr=True,
                   er_len=max_seq)

device = get_device()
model = GPT(config).to(device)

# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU session can still be loaded on a CPU-only session (and vice versa).
state_dict = torch.load(full_path_to_model_checkpoint, map_location=device)
model.load_state_dict(state_dict)
print('Done!')

# (Generate)

In [None]:
# Continuation routine draft code

print('Loading source continuation MIDI...')

# Creates a tokenizer with the same parameters as the training tokenizer.
# The velocity-bin count is defined in the tokenizing cell as `num_velocities`;
# the original referenced the undefined name `nb_velocities` here (NameError).
tokenizer1 = REMI(pitch_range, beat_res, num_velocities, additional_tokens)

# NOTE: 'seed.mid' is not downloaded by this notebook; place your own seed MIDI
# in the working directory before running this cell.
midi1 = MidiFile('seed.mid')

# Converts the seed MIDI to tokens
tokens1 = tokenizer1.midi_to_tokens(midi1)

print('Done!')

In [None]:
# Generate from the model

## Seed generator...

### If you get a file error on MIDI save, create the MidiTok-OUTPUT.mid file manually, then re-run


print('MidiTok Model Generator')

# Switch the model to evaluation mode before sampling.
model.eval()

# rand_seq = model.generate(torch.Tensor(tokens1[0][-64:]), target_seq_length=1024) # Continuation example
# Prime the generator with the single token 1 and request a 1024-token sequence.
primer = torch.Tensor([1])
rand_seq = model.generate(primer, target_seq_length=1024)

# Take the first generated sequence and turn it into a plain Python list.
first_sequence = rand_seq[0]
out = first_sequence.cpu().numpy().tolist()

# Convert the tokens back to MIDI, reusing the programs of the source MIDI, and save.
programs = get_midi_programs(midi)
converted_back_midi = tokenizer.tokens_to_midi([out], programs)
converted_back_midi.dump('MidiTok-OUTPUT.mid')

print('Done!')

In [None]:
#@title Auto-Regressive Generator

#@markdown NOTE: You must generate a seed composition first or it is not going to start

number_of_cycles_to_run = 5 #@param {type:"slider", min:1, max:50, step:1}
number_of_prime_tokens = 128 #@param {type:"slider", min:64, max:256, step:64}

# Imported under a private alias: the notebook-level name `tqdm` is bound by
# `import tqdm` to the *module*, which cannot be called as `tqdm(...)`.
from tqdm.auto import tqdm as _progress_bar

print('=' * 70)
print('MidiTok Auto-Regressive Model Generator')
print('=' * 70)
print('Starting up...')
print('=' * 70)
print('Prime length:', len(out))
print('Prime tokens:', number_of_prime_tokens)
print('Prime input sequence', out[-8:])

# Defined unconditionally so the export step below never hits a NameError
# when the seed is empty or no cycle runs.
out_all = [out]
blocks_done = 0

if out:
    print('=' * 70)
    for i in _progress_bar(range(number_of_cycles_to_run)):
        # Prime each cycle with the tail of the previously generated block.
        prime = out[-number_of_prime_tokens:]
        rand_seq1 = model.generate(torch.Tensor(prime), target_seq_length=1024)
        out1 = rand_seq1[0].cpu().numpy().tolist()

        # Keep only the newly generated part. Strip exactly as many tokens as
        # were fed in, which may be fewer than `number_of_prime_tokens` when the
        # seed is short.
        out = out1[len(prime):]
        out_all.append(out)
        blocks_done = i + 1

        print(chr(10))
        print('=' * 70)
        print('Block number:', blocks_done)
        print('Composition length so far:', sum(len(block) for block in out_all), 'tokens')
        print('=' * 70)

    print('=' * 70)
    print('Done!')
    print('Total blocks:', blocks_done)
    print('Final composition length:', sum(len(block) for block in out_all), 'tokens')
    print('=' * 70)

# Concatenate the seed and every generated block into one token sequence.
OUT = []

for o in out_all:
    OUT.extend(o)


if OUT:
    # Convert the full sequence back to MIDI, reusing the source MIDI's programs.
    converted_back_midi = tokenizer.tokens_to_midi([OUT], get_midi_programs(midi))
    converted_back_midi.dump('MidiTok-OUTPUT.mid')
else:
    print('Nothing to export: generate a seed composition first.')

# Congrats! You did it! :)