<a href="https://colab.research.google.com/github/asigalov61/Amazing-GPT2-Piano/blob/master/Amazing_GPT2_Piano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super GPT2 Piano

***

## GPT2-based Symbolic Music Artificial Intelligence Model Creator/Trainer

### Multi-Track, Multi-Instrumental, MIDI-TXT-MIDI

***

### Credit for char-based GPT2 implementation used in this colab goes out to Andrej Karpathy: https://github.com/karpathy/minGPT


#### Tegridy Code 2020

## A simple implementation of a Music GPT2 model trained on classical music.

The inputs and outputs here are simple text files, which we chop up into individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue well. In this example we will feed it some music (MIDI files converted to text), which we'll get it to predict at the character level.

In [None]:
#@title Clone minGPT repo and install all dependencies (run only once per session)
!git clone https://github.com/asigalov61/minGPT
 
!pip install pyknon
!pip install pretty_midi
!pip install pypianoroll
!pip install mir_eval
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2

!curl -L "https://pjb.com.au/midi/free/MIDI.py" > 'MIDI.py'

!mkdir Dataset

In [None]:
#@title Import all modules and setup logging
# Switch into the cloned repository before importing `mingpt`, so the package
# resolves from the checkout (presumably it is not pip-installed - verify if
# the setup cell changes).
%cd /content/minGPT
# make deterministic
# set_seed is provided by mingpt.utils; 42 is just a fixed, arbitrary seed value.
from mingpt.utils import set_seed
set_seed(42)

import numpy as np
import torch
import torch.nn as nn
#from torch.nn import functional as F

# Configure the root logger once for the whole notebook: INFO verbosity, with
# each record rendered as "<date time> - <level> - <logger name> -   <message>".
import logging
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
)

from torch import optim
import torch.nn.functional as F

import keras
from keras.utils import to_categorical

import time

import pretty_midi
from midi2audio import FluidSynth
from google.colab import output
from IPython.display import display, Javascript, HTML, Audio

dtype = torch.float
# Use the first CUDA device when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
# Show which device was picked so the user can confirm the runtime type.
print('Available Device:', device)



In [None]:
#@title Process MIDI to TXT
# All of these form values are read as globals by write_notes() below.

# True: single-character event tags and field prefixes ('O', ' D', ...).
# False: two-character ones ('No', ' Du', ...). The TXT-to-MIDI converter cell
# reads this same flag, so it must match the value used to build the dataset.
one_byte_encoding = True #@param {type:"boolean"}
# When True, write_notes stops reading further tracks of a file once the number
# of note_on/note_off events seen exceeds sample_length_in_notes (the check runs
# after each whole track, not after each note).
enable_sampling = False #@param {type:"boolean"}
sample_length_in_notes = 997 #@param {type:"slider", min:0, max:2000, step:1}
# When True, only note, control-change, end-track, time-signature and
# key-signature events are considered; aftertouch, patch, pitch-wheel, tempo and
# text-carrying events are skipped.
parse_only_basics = False #@param {type:"boolean"}
# Meant to gate the text-carrying events (track_name, text_event, lyric,
# instrument_name); only consulted when parse_only_basics is False.
parse_text_fields_for_nlp = True #@param {type:"boolean"}
# Meant to gate set_tempo events; only consulted when parse_only_basics is False.
allow_tempo_changes = True #@param {type:"boolean"}



# Gates control_change events (independent of parse_only_basics).
allow_control_change = True #@param {type:"boolean"}
# Dataset.txt is written relative to the working directory, and MIDI.py / Dataset
# are looked up there too - presumably /content is where the setup cell put them.
%cd /content/
# MIDI Dataset to txt converter 
import MIDI
import os
import numpy as np
import tqdm.auto

# Text encoding of each supported opus event, shared by both encodings.
#   event type -> (two-char tag, one-byte tag,
#                  two-char field prefixes, one-byte field prefixes, gate)
# The field prefixes label event[1], event[2], ... in order.  The tags are the
# ones the TXT-to-MIDI converter cell of this notebook looks for.
# gate: 'always'   - written unconditionally
#       'control'  - written when allow_control_change is set
#       'extended' - written unless parse_only_basics is set
#       'text'     - 'extended' and parse_text_fields_for_nlp
#       'tempo'    - 'extended' and allow_tempo_changes
_EVENT_TEXT_FORMATS = {
    'note_off':            ('Nf', 'F', ('Du', 'Ch', 'Nt', 'Ve'), ('D', 'C', 'N', 'V'), 'always'),
    'note_on':             ('No', 'O', ('Du', 'Ch', 'Nt', 'Ve'), ('D', 'C', 'N', 'V'), 'always'),
    'key_after_touch':     ('Kt', 'K', ('Du', 'Ch', 'Nt', 'Ve'), ('D', 'C', 'N', 'V'), 'extended'),
    'control_change':      ('Cc', 'C', ('Du', 'Ch', 'Co', 'Cv'), ('D', 'C', 'R', 'E'), 'control'),
    'patch_change':        ('Pc', 'A', ('Du', 'Ch', 'Pt'), ('D', 'C', 'P'), 'extended'),
    'channel_after_touch': ('Ca', 'H', ('Du', 'Ch', 'Ve'), ('D', 'C', 'V'), 'extended'),
    'pitch_wheel_change':  ('Pw', 'L', ('Du', 'Ch', 'Pw'), ('D', 'C', 'W'), 'extended'),
    'instrument_name':     ('In', 'I', ('Dt', 'Tx'), ('D', 'T'), 'text'),
    'end_track':           ('Et', 'B', ('Dt',), ('D',), 'always'),
    'set_tempo':           ('St', 'G', ('Dt', 'Tm'), ('D', 'J'), 'tempo'),
    'time_signature':      ('Ts', 'Q', ('Dt', 'nn', 'dd', 'cc', 'bb'), ('D', 'n', 'd', 'c', 'b'), 'always'),
    'key_signature':       ('Ks', 'X', ('Dt', 'sf', 'mi'), ('D', 's', 'm'), 'always'),
    'track_name':          ('Tn', 'M', ('Dt', 'Tx'), ('D', 'T'), 'text'),
    'text_event':          ('Te', 'V', ('Dt', 'Tx'), ('D', 'T'), 'text'),
    'lyric':               ('Ly', 'Y', ('Dt', 'Tx'), ('D', 'T'), 'text'),
}


def _event_is_enabled(gate):
    """Evaluate one gate name from _EVENT_TEXT_FORMATS against the form flags."""
    if gate == 'always':
        return True
    if gate == 'control':
        return allow_control_change
    # Every remaining gate is switched off by parse_only_basics.
    if parse_only_basics:
        return False
    if gate == 'text':
        return parse_text_fields_for_nlp
    if gate == 'tempo':
        return allow_tempo_changes
    return True  # 'extended'


def write_notes(file_address):
    """Append the text encoding of one MIDI file to 'Dataset.txt'.

    The file is parsed with MIDI.midi2opus and every supported event of every
    track is written as "<tag> <prefix><value> ... " (space separated, with a
    trailing space), preceded by one '[MIDI-TXT-MIDI Textual Music Dataset] '
    header per file.  Unsupported event types are ignored.

    Reads these notebook-level settings: one_byte_encoding, enable_sampling,
    sample_length_in_notes, parse_only_basics, parse_text_fields_for_nlp,
    allow_tempo_changes and allow_control_change.

    Args:
        file_address: path of the MIDI file to convert.

    Returns:
        None.  The output goes to 'Dataset.txt' in the current working
        directory, opened in append mode.

    Notes on behavior that differs from the earlier hand-unrolled version:
      * short events (e.g. instrument_name, end_track) no longer raise
        IndexError; the cause was an unused bookkeeping list that indexed
        fields those events do not have;
      * in the two-character encoding, pitch-wheel events get their own 'Pw'
        tag instead of sharing 'Pc' with patch changes, and key/channel
        aftertouch use 'Kt'/'Ca', the tags the converter cell expects;
      * in the two-character encoding, set_tempo and instrument_name now honor
        allow_tempo_changes / parse_text_fields_for_nlp, as they already did in
        the one-byte encoding.
    """
    # Parse first, so an unreadable MIDI file leaves the dataset untouched.
    with open(file_address, 'rb') as midi_file:
        score = MIDI.midi2opus(midi_file.read())

    note_count = 0           # note_on + note_off events seen so far
    wrote_any_event = False

    with open('Dataset.txt', 'a') as dataset_file:
        dataset_file.write('[MIDI-TXT-MIDI Textual Music Dataset] ')
        # score[0] is skipped (the original loop also started at index 1);
        # every later element is iterated as a list of events.
        for track in score[1:]:
            for event in track:
                spec = _EVENT_TEXT_FORMATS.get(event[0])
                if spec is None:
                    continue
                long_tag, short_tag, long_prefixes, short_prefixes, gate = spec
                if not _event_is_enabled(gate):
                    continue

                if one_byte_encoding:
                    tag, prefixes = short_tag, short_prefixes
                else:
                    tag, prefixes = long_tag, long_prefixes

                fields = [prefix + str(value)
                          for prefix, value in zip(prefixes, event[1:])]
                dataset_file.write(' '.join([tag] + fields) + ' ')
                wrote_any_event = True
                if event[0] in ('note_on', 'note_off'):
                    note_count += 1

            # Sampling is checked once per finished track, so a track is never cut short.
            if enable_sampling and wrote_any_event and note_count > sample_length_in_notes:
                break
       

dataset_addr = "Dataset"
# os.listdir returns entries in arbitrary order, so sort them to get the same
# Dataset.txt on every run; sub-directories are skipped because write_notes can
# only open regular files.
# NOTE: write_notes appends, so re-running this cell adds the data a second time
# unless Dataset.txt is removed first.
files = sorted(
    name for name in os.listdir(dataset_addr)
    if os.path.isfile(os.path.join(dataset_addr, name))
)
for file in tqdm.auto.tqdm(files):
    path = os.path.join(dataset_addr, file)
    write_notes(path)

In [None]:
#@title Setup functions and procedures
model_attention_span_in_tokens = 256 #@param {type:"slider", min:0, max:512, step:16}
import math
from torch.utils.data import Dataset

class CharDataset(Dataset):
    """Character-level next-token dataset built from one long string.

    The vocabulary is the sorted set of distinct characters in `data`.
    `stoi` maps a character to its integer id and `itos` maps the id back.
    Item `idx` is the window of block_size + 1 characters starting at `idx`,
    returned as an (input, target) pair shifted by one position.
    """

    def __init__(self, data, block_size):
        vocabulary = sorted(set(data))
        print('data has %d characters, %d unique.' % (len(data), len(vocabulary)))

        self.stoi = {}
        self.itos = {}
        for index, symbol in enumerate(vocabulary):
            self.stoi[symbol] = index
            self.itos[index] = symbol
        self.block_size = block_size
        self.vocab_size = len(vocabulary)
        self.data = data

    def __len__(self):
        # One item per start position that still leaves room for a full window.
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        """Return (x, y): block_size input ids and the same ids shifted by one.

        The first i elements of x are asked to predict the i-th element of y,
        so one window trains block_size predictions at once.  With
        block_size 4 and the chunk "hello", x encodes "hell" and y encodes
        "ello":
        - given just "h", predict "e"
        - given "he", predict "l"
        - given "hel", predict "l"
        - given "hell", predict "o"

        Batching multiplies this again: for inputs X of shape (B, T), with B
        the batch size and T the block size, one forward/backward pass trains
        B*T predictions.  At sampling time only the batch dimension can be
        parallelized; each new character needs its own forward pass, so
        generating T characters costs T passes.
        """
        window = self.data[idx:idx + self.block_size + 1]
        # encode every character of the window to its integer id
        codes = [self.stoi[symbol] for symbol in window]
        inputs = torch.tensor(codes[:-1], dtype=torch.long)
        targets = torch.tensor(codes[1:], dtype=torch.long)
        return inputs, targets

        
block_size = model_attention_span_in_tokens # spatial extent of the model for its context

In [None]:
#@title Specify input text file with training data (do not worry, any text format is fine)
full_path_to_training_text_file = "/content/Dataset.txt" #@param {type:"string"}
# Read the whole corpus into memory; the context manager closes the file again
# instead of leaving the handle open for the rest of the session.
with open(full_path_to_training_text_file, 'r') as training_text_file:
    text = training_text_file.read()
# Every window of block_size + 1 consecutive characters is one training example.
train_dataset = CharDataset(text, block_size)

In [None]:
#@title Create GPT2 model
model_embed_size = 256 #@param {type:"slider", min:0, max:1024, step:64}
number_of_heads = 8 #@param {type:"slider", min:1, max:16, step:1}
number_of_layers = 8 #@param {type:"slider", min:1, max:16, step:1}
from mingpt.model import GPT, GPTConfig

# Architecture settings from the form above, forwarded to GPTConfig as keywords.
architecture = dict(
    n_layer=number_of_layers,
    n_head=number_of_heads,
    n_embd=model_embed_size,
)
# Vocabulary size and context length come from the dataset, so the model always
# matches the data it will be trained on.
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size, **architecture)
model = GPT(mconf)

In [None]:
#@title Setup all training parameters
number_of_training_epochs = 100 #@param {type:"slider", min:1, max:100, step:1}
training_batch_size = 128 #@param {type:"slider", min:0, max:3072, step:16}
from mingpt.trainer import Trainer, TrainerConfig

# Each dataset item supplies block_size target tokens, so this is the number of
# tokens the trainer sees in one epoch.
tokens_per_epoch = len(train_dataset) * block_size

# initialize a trainer instance (training itself is started in the next cell)
# final_tokens is tied to the chosen number of epochs.  It used to be hard-coded
# to 2 epochs' worth of tokens, a value that only fits a 2-epoch run.
# NOTE(review): per minGPT's trainer the learning-rate decay is driven by
# tokens processed relative to final_tokens - confirm against the cloned fork.
tconf = TrainerConfig(max_epochs=number_of_training_epochs, batch_size=training_batch_size, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20,
                      final_tokens=number_of_training_epochs*tokens_per_epoch,
                      num_workers=4)
# The third argument is the test dataset; None means no evaluation split is used.
trainer = Trainer(model, train_dataset, None, tconf)


In [None]:
#@title (OPTION 1) Train the model
# Starts training with the trainer built in the previous cell (model,
# train_dataset and tconf; no test dataset was passed).  Skip this cell and use
# OPTION 2 below to work with a previously saved model instead.
trainer.train()

In [None]:
#@title Save/Re-Save the model from memory
# Passing the module object itself to torch.save pickles the whole model (not
# just a state_dict), so loading it later needs the same mingpt code importable.
# The path is relative, i.e. the file lands in the current working directory.
torch.save(model, 'trained-model.pth')

In [None]:
#@title (OPTION 2) Load existing model/checkpoint
# NOTE: the checkpoint is a pickled model object, and unpickling can execute
# arbitrary code - only load files you created or trust.
# map_location=device lets a checkpoint saved on a GPU runtime load on a
# CPU-only one (and vice versa) instead of failing while restoring tensors.
try:
    # Newer PyTorch defaults to weights_only=True, which rejects whole-model pickles.
    model = torch.load('trained-model.pth', map_location=device, weights_only=False)
except TypeError:
    # Older PyTorch releases do not know the weights_only argument.
    model = torch.load('trained-model.pth', map_location=device)
# Switch to inference mode (as the last expression this also displays the model).
model.eval()

In [None]:
#@title Generate and Download Music with the Model (Notewise TXT2MIDI)
number_of_tokens_to_generate = 8192 #@param {type:"slider", min:0, max:8192, step:64}
creativity_temperature = 0.7 #@param {type:"slider", min:0.05, max:4, step:0.05}
top_k_prob = 5 #@param {type:"slider", min:0, max:50, step:1}
input_promt = "[MIDI" #@param {type:"string"}
# Sample character-level music text from the model, continuing the prompt.
from mingpt.utils import sample
import tqdm

context = input_promt
# Fail with a readable message instead of a bare KeyError / shape error further down.
if not context:
    raise ValueError('input_promt must contain at least one character')
unknown_chars = sorted(set(context) - set(train_dataset.stoi))
if unknown_chars:
    raise ValueError('input_promt contains characters that are not in the training vocabulary: %r' % (unknown_chars,))

# The slider allows 0, which is not a usable top-k; treat it as "top-k filtering off".
# NOTE(review): assumes sample() skips top-k filtering when top_k is None, as in
# upstream minGPT - confirm against the cloned fork.
top_k = top_k_prob if top_k_prob > 0 else None

# Encode the prompt as a (1, len(prompt)) batch of character ids on the trainer's device.
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, number_of_tokens_to_generate, temperature=creativity_temperature, sample=True, top_k=top_k)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print('Done! Saving output.txt!')
with open("/content/output.txt", "w") as text_file:
    print(completion, file=text_file)



#@title Generate and Download resulting MIDI file

#@title Convert to MIDI from TXT
number_of_ticks_per_quarter = 425 #@param {type:"slider", min:1, max:1440, step:8}

import ast
import MIDI
import tqdm.auto

# Event tag -> (opus event name, field kinds).  One field token follows the tag
# per kind character: 'i' is an integer field, 't' a text payload.  Each field
# token starts with a label prefix (two characters here) that is stripped
# before parsing.
_TWO_CHAR_EVENTS = {
    'No': ('note_on', 'iiii'),
    'Nf': ('note_off', 'iiii'),
    'Cc': ('control_change', 'iiii'),
    'Kt': ('key_after_touch', 'iiii'),
    'Ka': ('key_after_touch', 'iiii'),      # alternate spelling found in some datasets
    'Pc': ('patch_change', 'iii'),
    'Ca': ('channel_after_touch', 'iii'),
    'Ct': ('channel_after_touch', 'iii'),   # alternate spelling found in some datasets
    'Pw': ('pitch_wheel_change', 'iii'),
    'In': ('instrument_name', 'it'),
    'Et': ('end_track', 'i'),
    'St': ('set_tempo', 'ii'),
    'Ts': ('time_signature', 'iiiii'),
    'Ks': ('key_signature', 'iii'),
    'Ly': ('lyric', 'it'),
    'Te': ('text_event', 'it'),
    'Tn': ('track_name', 'it'),
}

# Same table for the one-byte encoding (one-character tags and field prefixes).
_ONE_BYTE_EVENTS = {
    'O': ('note_on', 'iiii'),              # duration, channel, note, velocity
    'F': ('note_off', 'iiii'),
    'C': ('control_change', 'iiii'),       # duration, channel, controller, value
    'K': ('key_after_touch', 'iiii'),
    'A': ('patch_change', 'iii'),
    'H': ('channel_after_touch', 'iii'),
    'L': ('pitch_wheel_change', 'iii'),
    'I': ('instrument_name', 'it'),
    'B': ('end_track', 'i'),
    'G': ('set_tempo', 'ii'),
    'Q': ('time_signature', 'iiiii'),
    'X': ('key_signature', 'iii'),
    'Y': ('lyric', 'it'),
    'V': ('text_event', 'it'),
    'M': ('track_name', 'it'),
}


def _parse_text_payload(raw):
    """Turn the textual form of a text payload (e.g. "b'Piano'") back into bytes/str.

    Raises ValueError/SyntaxError when `raw` is not a complete string or bytes
    literal, which is the case for any text that contained a space (the token
    stream is split on spaces, so such a payload arrives truncated).
    """
    value = ast.literal_eval(raw)
    if not isinstance(value, (bytes, str)):
        raise ValueError('not a text payload: %r' % (raw,))
    return value


def _decode_event(tokens, position, event_table, prefix_len):
    """Build the opus event whose tag sits at tokens[position].

    Raises IndexError when the token stream ends before all fields are read and
    ValueError/SyntaxError/TypeError when a field cannot be parsed; the caller
    reports those and skips the event.
    """
    tag = tokens[position]
    name, kinds = event_table[tag]
    # Some two-char datasets store pitch-wheel events under the patch-change tag;
    # the label of the value field ('Pw...' rather than 'Pt...') tells them apart.
    if prefix_len == 2 and tag == 'Pc' and tokens[position + 3].startswith('Pw'):
        name = 'pitch_wheel_change'

    event = [name]
    for offset, kind in enumerate(kinds, start=1):
        payload = tokens[position + offset][prefix_len:]
        event.append(int(payload) if kind == 'i' else _parse_text_payload(payload))
    return event


with open('/content/output.txt', 'r') as file:
    notestring = file.read()

score = notestring.split(" ")

# Opus structure: ticks per quarter note, followed by a single track that
# receives every decoded event in the order it appears in the text.
song_score = [number_of_ticks_per_quarter, [['track_name', 0, b'Composed by Artificial Intelligence Model']]]

event_table = _ONE_BYTE_EVENTS if one_byte_encoding else _TWO_CHAR_EVENTS
prefix_len = 1 if one_byte_encoding else 2

for i in tqdm.auto.tqdm(range(len(score))):
    token = score[i]

    # blanks, "<eos>"/"<unk>" markers and "@@"-prefixed tokens carry no event
    if token in ["", " ", "<eos>", "<unk>"] or token[:2] == "@@":
        continue

    # Field tokens and anything else that is not an event tag are consumed by
    # the event that precedes them, or ignored.
    if token not in event_table:
        continue

    try:
        # The event is fully built before it is appended, so a malformed event
        # never leaves a partial entry in the track.
        song_score[-1].append(_decode_event(score, i, event_table, prefix_len))
    except (IndexError, ValueError, SyntaxError, TypeError):
        print("Unknown note: " + token)


midi_data = MIDI.opus2midi(song_score)
# Write to the same absolute path that is downloaded below and read by the
# plotting cell, so the result does not depend on the current working directory.
with open('/content/output.mid', 'wb') as midi_file:
    midi_file.write(midi_data)

# The return value is not used here; as before, the call is kept unchanged.
MIDI.score2stats(song_score)

from google.colab import files
files.download('/content/output.mid')

In [None]:
#@title Plot, Graph, and Listen to the Output :)
graphs_length_inches = 18 #@param {type:"slider", min:0, max:20, step:1}
notes_graph_height = 6 #@param {type:"slider", min:0, max:20, step:1}
highest_displayed_pitch = 92 #@param {type:"slider", min:1, max:128, step:1}
lowest_displayed_pitch = 24 #@param {type:"slider", min:1, max:128, step:1}

import librosa
import numpy as np
import pretty_midi
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('SVG')
# For plotting
import mir_eval.display
import librosa.display
%matplotlib inline


midi_data = pretty_midi.PrettyMIDI('/content/output.mid')

def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    """Draw rows start_pitch..end_pitch-1 of a piano roll with librosa's specshow.

    Args:
        pm: object providing get_piano_roll(fs); the notebook passes a
            pretty_midi.PrettyMIDI instance.
        start_pitch: first piano-roll row to show; also converted to Hz for
            the lowest frequency on the y axis.
        end_pitch: row index at which the displayed range stops (exclusive).
        fs: value handed to get_piano_roll and used as specshow's `sr`.

    Returns:
        None; the plot is drawn by specshow.
    """
    piano_roll = pm.get_piano_roll(fs)
    visible_rows = piano_roll[start_pitch:end_pitch]
    lowest_frequency = pretty_midi.note_number_to_hz(start_pitch)
    librosa.display.specshow(visible_rows,
                             hop_length=1, sr=fs,
                             x_axis='time', y_axis='cqt_note',
                             fmin=lowest_frequency)



# Plot the output

# First figure: pypianoroll's own rendering of the generated file, resized to
# the dimensions chosen in the form.  No plt.figure() call is made beforehand:
# the size is set on the figure that track.plot() returns, and an extra call
# only produced an additional, empty figure.
track = Multitrack('/content/output.mid', name='track')
fig, ax = track.plot()
fig.set_size_inches(graphs_length_inches, notes_graph_height)

# Second figure: the pretty_midi piano roll restricted to the selected pitch range.
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
plot_piano_roll(midi_data, int(lowest_displayed_pitch), int(highest_displayed_pitch))
plt.show(block=False)


# Render the MIDI to a 16 kHz WAV with the sound font copied by the setup cell,
# then show an audio player as the cell's output.
FluidSynth("/content/font.sf2", 16000).midi_to_audio('/content/output.mid', '/content/output.wav')
Audio('/content/output.wav', rate=16000)

In [None]:
#@title Plot Positional Embeddings
import torchvision

import matplotlib.pyplot as plt
%matplotlib inline
# visualize some of the learned positional embeddings, maybe they contain structure
plt.figure(figsize=(18, 1))  
ci = model.pos_emb.data[0, :, 0].cpu()
zci = torch.cat((torch.tensor([0.0]), ci)) # pre-cat a zero
plt.imshow(zci.view(1, block_size+1).numpy())
plt.axis('off')

#Congrats! :) You did it :)

## Mount Google Drive to save the model (standard Google Drive connect code)

In [None]:
# Mount Google Drive at /content/drive.  This cell only mounts the drive; it does
# not itself copy 'trained-model.pth' (or any other file) there.
from google.colab import drive
drive.mount('/content/drive')