<a href="https://colab.research.google.com/github/asigalov61/Amazing-GPT2-Piano/blob/master/Super_GPT2_Piano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super GPT2 Piano

## All credit for this beautiful Colab implementation of char-based Music GPT2 goes to Andrej Karpathy and Edtky of GitHub, on whose repos/code it is based:

https://github.com/karpathy/minGPT

 https://github.com/edtky/mini-musical-neural-net

## A simple implementation of Music GPT2 trained on classical music.

The inputs and outputs here are simple text files, which we chop up into individual characters and then train GPT on. So you could say this is a char-transformer instead of a char-rnn. Doesn't quite roll off the tongue well. In this example we will feed it some music encoded as text, which we'll get it to predict at the character level.

In [None]:
#@title Clone minGPT repo and install all dependencies (run only once per session)
!git clone https://github.com/asigalov61/minGPT
 
!pip install pyknon
!pip install pretty_midi
!pip install pypianoroll
!pip install mir_eval
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2

In [None]:
#@title Import all modules and setup logging
%cd /content/minGPT
# make deterministic
from mingpt.utils import set_seed
set_seed(42)

import numpy as np
import torch
import torch.nn as nn
#from torch.nn import functional as F

# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

from torch import optim
import torch.nn.functional as F

import keras
from keras.utils import to_categorical

import time

import pretty_midi
from midi2audio import FluidSynth
from google.colab import output
from IPython.display import display, Javascript, HTML, Audio

dtype = torch.float
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Assume that we are on a CUDA machine, then this should print a CUDA device:
print('Available Device:', device)

!mkdir /content/midis

sample_freq_variable = 12 #@param {type:"number"}
note_range_variable = 62 #@param {type:"number"}
note_offset_variable = 33 #@param {type:"number"}
number_of_instruments = 2 #@param {type:"number"}
chamber_option = True #@param {type:"boolean"}

In [None]:
#@title (OPTIONAL) Convert your own MIDIs to Notewise TXT DataSet (before running this cell, upload your MIDI DataSet to /content/midis folder)
import tqdm.auto
import argparse
import random
import os
import numpy as np
from math import floor
from pyknon.genmidi import Midi
from pyknon.music import NoteSeq, Note
import music21
from music21 import instrument, volume
from music21 import midi as midiModule
from pathlib import Path
import glob, sys
from music21 import converter, instrument
%cd /content
notes=[]
InstrumentID=0
folder = '/content/midis/*mid'
for file in tqdm.auto.tqdm(glob.glob(folder)):
    filename = file[-53:]
    print(filename)

    # fname = "../midi-files/mozart/sonat-3.mid"
    fname = filename

    mf=music21.midi.MidiFile()
    mf.open(fname)
    mf.read()
    mf.close()
    midi_stream=music21.midi.translate.midiFileToStream(mf)
    midi_stream



    sample_freq=sample_freq_variable
    note_range=note_range_variable
    note_offset=note_offset_variable
    chamber=chamber_option
    numInstruments=number_of_instruments

    s = midi_stream
    #print(s.duration.quarterLength)

    s[0].elements


    maxTimeStep = floor(s.duration.quarterLength * sample_freq)+1
    score_arr = np.zeros((maxTimeStep, numInstruments, note_range))

    #print(maxTimeStep, "\n", score_arr.shape)

    # define two types of filters because notes and chords have different structures for storing their data
    # chord have an extra layer because it consist of multiple notes

    noteFilter=music21.stream.filters.ClassFilter('Note')
    chordFilter=music21.stream.filters.ClassFilter('Chord')
      

    # pitch.midi-note_offset: pitch is the numerical representation of a note. 
    #                         note_offset is the the pitch relative to a zero mark. eg. B-=25, C=27, A=24

    # n.offset: the timestamps of each note, relative to the start of the score
    #           by multiplying with the sample_freq, you make all the timestamps integers

    # n.duration.quarterLength: the duration of that note as a float eg. quarter note = 0.25, half note = 0.5
    #                           multiply by sample_freq to represent duration in terms of timesteps

    notes = []
    instrumentID = 0
    parts = instrument.partitionByInstrument(s)
    for i in range(len(parts.parts)): 
      instru = parts.parts[i].getInstrument()
      

    for n in s.recurse().addFilter(noteFilter):
        if chamber:
          # assign_instrument where 0 means piano-like and 1 means violin-like, and -1 means neither
            notes.append((n.pitch.midi-note_offset, floor(n.offset*sample_freq), 
              floor(n.duration.quarterLength*sample_freq), 1))

          #if instru.instrumentName == 'Piano':
           #   instrumentID=0
          #if instru.instrumentName == 'Violin':
         #     instrumentID=1

        notes.append((n.pitch.midi-note_offset, floor(n.offset*sample_freq), 
              floor(n.duration.quarterLength*sample_freq), 0))
        
    #print(len(notes))
    notes[-5:]

    # do the same using a chord filter

    for c in s.recurse().addFilter(chordFilter):
        # unlike the noteFilter, this line of code is necessary as there are multiple notes in each chord
        # pitchesInChord is a list of notes at each chord eg. (<music21.pitch.Pitch D5>, <music21.pitch.Pitch F5>)
        pitchesInChord=c.pitches
        
        # do same as noteFilter and append all notes to the notes list
        for p in pitchesInChord:
            notes.append((p.midi-note_offset, floor(c.offset*sample_freq), 
                          floor(c.duration.quarterLength*sample_freq), 1))

        # do same as noteFilter and append all notes to the notes list
        for p in pitchesInChord:
            notes.append((p.midi-note_offset, floor(c.offset*sample_freq), 
                          floor(c.duration.quarterLength*sample_freq), 0))
    #print(len(notes))
    notes[-5:]

    # the variable/list "notes" is a collection of all the notes in the song, not ordered in any significant way

    for n in notes:
        
        # pitch is the first variable in n, previously obtained by n.midi-note_offset
        pitch=n[0]
        
        # do some calibration for notes that fall our of note range
        # i.e. less than 0 or more than note_range
        while pitch<0:
            pitch+=12
        while pitch>=note_range:
            pitch-=12
            
        # 3rd element refers to instrument type => if instrument is violin, use different pitch calibration
        if n[3]==1:      #Violin lowest note is v22
            while pitch<22:
                pitch+=12

        # start building the 3D-tensor of shape: (796, 1, 38)
        # score_arr[0] = timestep
        # score_arr[1] = type of instrument
        # score_arr[2] = pitch/note out of the range of note eg. 38
        
        # n[0] = pitch
        # n[1] = timestep
        # n[2] = duration
        # n[3] = instrument
        #print(n[3])
        score_arr[n[1], n[3], pitch]=1                  # Strike note
        score_arr[n[1]+1:n[1]+n[2], n[3], pitch]=2      # Continue holding note

    #print(score_arr.shape)
    # print first 5 timesteps
    score_arr[:5,0,]


    for timestep in score_arr:
        #print(list(reversed(range(len(timestep)))))
        break

    instr={}
    instr[0]="p"
    instr[1]="v"

    score_string_arr=[]

    # loop through all timesteps
    for timestep in score_arr:
        
        # selecting the instruments: i=0 means piano and i=1 means violin
        for i in list(reversed(range(len(timestep)))):   # List violin note first, then piano note
            
            # 
            score_string_arr.append(instr[i]+''.join([str(int(note)) for note in timestep[i]]))

    #print(type(score_string_arr), len(score_string_arr))
    score_string_arr[:5]

    modulated=[]
    # get the note range from the array
    note_range=len(score_string_arr[0])-1

    for i in range(0,12):
        for chord in score_string_arr:
            
            # minus the instrument letter eg. 'p'
            # add 6 zeros on each side of the string
            padded='000000'+chord[1:]+'000000'
            
            # add back the instrument letter eg. 'p'
            # append window of len=note_range back into 
            # eg. if we have "00012345000"
            # iteratively, we want to get "p00012", "p00123", "p01234", "p12345", "p23450", "p34500", "p45000",
            modulated.append(chord[0]+padded[i:i+note_range])

    # 796 * 12
    #print(len(modulated))
    modulated[:5]

    # input of this function is a modulated string
    long_string = modulated

    translated_list=[]

    # for every timestep of the string
    for j in range(len(long_string)):
        
        # chord at timestep j eg. 'p00000000000000000000000000000000000100'
        chord=long_string[j]
        next_chord=""
        
        # range is from next_timestep to max_timestep
        for k in range(j+1, len(long_string)):
            
            # checking if instrument of next chord is same as current chord
            if long_string[k][0]==chord[0]:
                
                # if same, set next chord as next element in modulation
                # otherwise, keep going until you find a chord with the same instrument
                # when you do, set it as the next chord
                next_chord=long_string[k]
                break
        
        # set prefix as the instrument
        # set chord and next_chord to be without the instrument prefix
        # next_chord is necessary to check when notes end
        prefix=chord[0]
        chord=chord[1:]
        next_chord=next_chord[1:]
        
        # checking for non-zero notes at one particular timestep
        # i is an integer indicating the index of each note the chord
        for i in range(len(chord)):
            
            if chord[i]=="0":
                continue
            
            # set note as 2 elements: instrument and index of note
            # examples: p22, p16, p4
            #p = music21.pitch.Pitch()
            #nt = music21.note.Note(p)
            #n.volume.velocity = 20
            #nt.volume.client == nt
            #V = nt.volume.velocity
            #print(V)
            #note=prefix+str(i)+' V' + str(V)
            note=prefix+str(i)                
            
            # if note in chord is 1, then append the note eg. p22 to the list
            if chord[i]=="1":
                translated_list.append(note)
            
            # If chord[i]=="2" do nothing - we're continuing to hold the note
            
            # unless next_chord[i] is back to "0" and it's time to end the note.
            if next_chord=="" or next_chord[i]=="0":      
                translated_list.append("end"+note)

        # wait indicates end of every timestep
        if prefix=="p":
            translated_list.append("wait")

    #print(len(translated_list))
    translated_list[:10]

    # this section transforms the list of notes into a string of notes

    # initialize i as zero and empty string
    i=0
    translated_string=""


    while i<len(translated_list):
        
        # stack all the repeated waits together using an integer to indicate the no. of waits
        # eg. "wait wait" => "wait2"
        wait_count=1
        if translated_list[i]=='wait':
            while wait_count<=sample_freq*2 and i+wait_count<len(translated_list) and translated_list[i+wait_count]=='wait':
                wait_count+=1
            translated_list[i]='wait'+str(wait_count)
            
        # add next note
        translated_string+=translated_list[i]+" "
        i+=wait_count

    translated_string[:100]
    len(translated_string)

    #print("chordwise encoding type and length:", type(modulated), len(modulated))
    #print("notewise encoding type and length:", type(translated_string), len(translated_string))

    # default settings: sample_freq=12, note_range=62

    chordwise_folder = "../"
    notewise_folder = "../"

    # export chordwise encoding
#    f=open(chordwise_folder+fname+"_chordwise"+".txt","w+")
#    f.write(" ".join(modulated))
#    f.close()

    # export notewise encoding
    f=open(notewise_folder+fname+"_notewise"+".txt","w+")
    f.write(translated_string)
    f.close()

# Concatenate every per-file notewise encoding into one training text file,
# written to 'notewise_custom_dataset.txt' in the current working directory.
folder = '/content/midis/*notewise.txt'

with open('notewise_custom_dataset.txt', 'w') as outfile:
    # Merge ALL matching files in sorted order. The previous `[-53:]` slice
    # silently dropped every file beyond the last 53 returned by glob.
    for fname in sorted(glob.glob(folder)):
        with open(fname) as infile:
            for line in infile:
                outfile.write(line)

# A matching merge for '*chordwise.txt' files is not performed here.

In [None]:
#@title Setup functions and procedures
model_attention_span_in_tokens = 64 #@param {type:"slider", min:0, max:512, step:16}
import math
from torch.utils.data import Dataset

class CharDataset(Dataset):
    """Character-level dataset of sliding windows over one long string.

    Item ``idx`` is built from the ``block_size + 1`` characters starting at
    ``idx``: ``x`` holds the first ``block_size`` encoded characters and ``y``
    the same window shifted one character to the right, so the first i elements
    of ``x`` are asked to predict the i-th element of ``y``.

    Example with ``block_size == 4`` and the chunk "hello": ``x`` encodes
    "hell" and ``y`` encodes "ello", which trains four predictions at once:
    - given just "h", predict "e"
    - given "he", predict "l"
    - given "hel", predict "l"
    - given "hell", predict "o"

    Because the DataLoader batches examples, a batch X of shape (B, T) with
    targets Y (B, T) trains B*T predictions in a single forward/backward pass.
    At test time the batch dimension can still be parallelized, but the time
    dimension cannot: the network is run once per generated character and the
    result is fed back in. So training goes B*T at a time per forward pass,
    while sampling goes B at a time, T times.

    Attributes set by ``__init__``:
        stoi: character -> integer id (ids follow sorted character order)
        itos: integer id -> character
        block_size: context length, as passed in
        vocab_size: number of distinct characters in ``data``
        data: the raw text
    """

    def __init__(self, data, block_size):
        vocabulary = sorted(set(data))
        print('data has %d characters, %d unique.' % (len(data), len(vocabulary)))

        self.stoi = dict(zip(vocabulary, range(len(vocabulary))))
        self.itos = dict(enumerate(vocabulary))
        self.block_size = block_size
        self.vocab_size = len(vocabulary)
        self.data = data

    def __len__(self):
        # Number of start positions that leave room for a full window.
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # Take block_size + 1 characters and encode each one as its integer id.
        window = self.data[idx:idx + self.block_size + 1]
        encoded = list(map(self.stoi.__getitem__, window))
        # Inputs drop the last id, targets drop the first.
        x = torch.tensor(encoded[:-1], dtype=torch.long)
        y = torch.tensor(encoded[1:], dtype=torch.long)
        return x, y


# Context length of the model, taken from the slider at the top of this cell.
block_size = model_attention_span_in_tokens

In [None]:
#@title (OPTIONAL) Download ready-to-use Piano and Chamber Notewise DataSets
%cd /content/
!wget 'https://github.com/asigalov61/SuperPiano/raw/master/Super%20Chamber%20Piano%20Violin%20Notewise%20DataSet.zip'
!unzip '/content/Super Chamber Piano Violin Notewise DataSet.zip'
!rm '/content/Super Chamber Piano Violin Notewise DataSet.zip'

!wget 'https://github.com/asigalov61/SuperPiano/raw/master/Super%20Chamber%20Piano%20Only%20Notewise%20DataSet.zip'
!unzip '/content/Super Chamber Piano Only Notewise DataSet.zip'
!rm '/content/Super Chamber Piano Only Notewise DataSet.zip'

In [None]:
#@title Specify input text file with training data (do not worry, any text format is fine)
full_path_to_training_text_file = "/content/notewise_piano.txt" #@param {type:"string"}
# Read the whole training file into memory; the context manager closes the
# handle even if reading fails.
with open(full_path_to_training_text_file, 'r') as training_file:
    text = training_file.read()
# Character-level dataset: every window of block_size + 1 characters is one example.
train_dataset = CharDataset(text, block_size)

In [None]:
#@title Create GPT2 model
model_embed_size = 128 #@param {type:"slider", min:0, max:1024, step:64}
number_of_heads = 4 #@param {type:"slider", min:1, max:16, step:1}
number_of_layers = 4 #@param {type:"slider", min:1, max:16, step:1}
from mingpt.model import GPT, GPTConfig

# Vocabulary size and context length come from the dataset; width, head count
# and depth come from the sliders above.
mconf = GPTConfig(
    train_dataset.vocab_size,
    train_dataset.block_size,
    n_embd=model_embed_size,
    n_head=number_of_heads,
    n_layer=number_of_layers,
)
model = GPT(mconf)

In [None]:
#@title Setup all training parameters
number_of_training_epochs = 100 #@param {type:"slider", min:1, max:100, step:1}
training_batch_size = 3072 #@param {type:"slider", min:0, max:3072, step:16}
from mingpt.trainer import Trainer, TrainerConfig

# Build the trainer; training itself is started in the next cell.
#
# final_tokens is the token count at which the learning-rate decay reaches its
# end point. One epoch processes len(train_dataset) * block_size tokens, so the
# schedule is stretched over all selected epochs. It was previously hard-coded
# to 2 epochs, which finishes the decay after epoch 2 regardless of the slider.
# NOTE(review): assumes this minGPT fork counts tokens like upstream minGPT - confirm.
tconf = TrainerConfig(max_epochs=number_of_training_epochs, batch_size=training_batch_size, learning_rate=6e-4,
                      lr_decay=True, warmup_tokens=512*20,
                      final_tokens=number_of_training_epochs*len(train_dataset)*block_size,
                      num_workers=4)
# No test dataset is passed (third argument is None).
trainer = Trainer(model, train_dataset, None, tconf)


In [None]:
#@title (OPTION 1) Train the model
# Run the training loop of the `trainer` built in the previous cell, for the
# number of epochs configured there.
trainer.train()

epoch 1 iter 898: train loss 0.49843. lr 5.999570e-04:   1%|          | 899/83377 [07:14<11:01:30,  2.08it/s]

In [None]:
#@title Save/Re-Save the model from memory
# Serializes the whole in-memory `model` object (not only a state_dict) to
# 'trained-model.pth' in the current working directory; the load cell reads the
# same relative path.
torch.save(model, 'trained-model.pth')

In [None]:
#@title (OPTION 2) Load existing model/checkpoint
# SECURITY: the checkpoint is a pickled model object, and unpickling can run
# arbitrary code - only load files you created yourself or fully trust.
#
# map_location places the weights on the device detected in the setup cell, so
# a checkpoint saved on a GPU runtime also loads on a CPU-only runtime.
# weights_only=False is required on newer PyTorch releases, whose default
# refuses to unpickle a complete model object; older releases do not know the
# argument and raise TypeError, in which case it is simply omitted.
try:
    model = torch.load('trained-model.pth', map_location=device, weights_only=False)
except TypeError:
    model = torch.load('trained-model.pth', map_location=device)
# Switch to inference mode; as the last expression this also displays the model.
model.eval()

In [None]:
#@title Generate and Download Music with the Model (Notewise TXT2MIDI)
number_of_tokens_to_generate = 4096 #@param {type:"slider", min:0, max:8192, step:64}
creativity_temperature = 0.7 #@param {type:"slider", min:0.05, max:4, step:0.05}
top_k_prob = 3 #@param {type:"slider", min:0, max:50, step:1}
input_promt = "v" #@param {type:"string"}
# Sample notewise music text from the model, character by character.
from mingpt.utils import sample
import tqdm

context = input_promt

# Fail early with a readable message: the prompt is encoded through the training
# vocabulary, so an empty prompt or an unseen character cannot be used.
unknown_chars = sorted(set(context) - set(train_dataset.stoi))
if not context or unknown_chars:
    raise ValueError(
        "input_promt must be non-empty and use only characters from the training data; "
        "got %r (unknown characters: %r)" % (context, unknown_chars)
    )

# Put the prompt on the device the model actually lives on; this also works when
# the model was loaded from a checkpoint rather than trained in this session.
model_device = next(model.parameters()).device
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(model_device)

# The slider allows 0, which is treated here as "no top-k filtering".
# NOTE(review): assumes sample() disables filtering for top_k=None, as in upstream minGPT - confirm.
top_k = top_k_prob if top_k_prob > 0 else None
y = sample(model, x, number_of_tokens_to_generate, temperature=creativity_temperature, sample=True, top_k=top_k)[0]

# Map the sampled ids back to characters and store the text for the decoder.
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print('Done! Saving output.txt!')
with open("/content/output.txt", "w") as text_file:
    print(completion, file=text_file)



#@title Generate and Download resulting MIDI file
# Tempo scaling read by decoder(): each encoded time step is rendered as
# time_coefficient / sample_freq_variable quarter lengths.
time_coefficient = 4 #@param {type:"integer"}

import os
import dill as pickle
from pathlib import Path
import random
import numpy as np
import pandas as pd
from math import floor
from pyknon.genmidi import Midi
from pyknon.music import NoteSeq, Note
import music21
import random
import os, argparse

# default settings: sample_freq=12, note_range=62

def _strip_eoc(token):
    """Return (token without a trailing "eoc" marker, whether the marker was present)."""
    if token[-3:] == "eoc":
        return token[:-3], True
    return token, False


def _wait_steps(token):
    """Return the step count of a "wait<N>" token, or None when <N> is not an integer."""
    try:
        return int(token[4:])
    except ValueError:
        return None


def decoder(filename):
    """Decode a notewise text file in /content/ into a two-instrument MIDI file.

    Reads '/content/<filename>', interprets its whitespace-separated tokens and
    writes '/content/<filename without extension>.mid'.

    Token grammar handled here:
      p<N> / v<N>   strike piano / violin note N (MIDI pitch N + note_offset)
      end<note>     end of a note; consumed by the look-ahead of the note it ends
      wait<N>       advance time by N steps
      p_octave_<N>  expanded into p<N> plus p<N+12>
      ...eoc        suffix that additionally advances time by one step
    Tokens that cannot be parsed are reported and skipped, so generated text
    that contains (or ends in) a malformed token still decodes.

    Reads the notebook globals sample_freq_variable, note_offset_variable and
    time_coefficient.

    Args:
        filename: name of the text file inside /content/, e.g. 'output.txt'.
    """
    filedir = '/content/'
    notetxt = filedir + filename

    with open(notetxt, 'r') as file:
        notestring = file.read()

    # split() instead of split(" "): also removes the trailing newline written
    # together with the generated text and never yields empty tokens.
    score = notestring.split()

    # Parameters shared with the encoding cell.
    sample_freq = sample_freq_variable
    note_offset = note_offset_variable

    # Quarter lengths per encoded time step.
    speed = time_coefficient / sample_freq
    piano_notes = []
    violin_notes = []
    time_offset = 0

    # Expand "p_octave_<N>[eoc]" into "p<N>" followed by "p<N+12>[eoc]".
    i = 0
    while i < len(score):
        if score[i][:9] == "p_octave_":
            token, had_eoc = _strip_eoc(score[i])
            this_note = token[9:]
            try:
                upper_note = int(this_note) + 12
            except ValueError:
                # Malformed: leave the token in place; the main loop reports it.
                i += 1
                continue
            score[i] = "p" + this_note
            score.insert(i + 1, "p" + str(upper_note) + ("eoc" if had_eoc else ""))
            i += 1
        i += 1

    # Walk through every event in the score.
    for i in range(len(score)):
        token = score[i]

        # Blank and special tokens carry no information.
        if token in ["", " ", "<eos>", "<unk>"]:
            continue

        # "end..." tokens only matter to the look-ahead below, except that an
        # "eoc" suffix still advances time by one step.
        if token[:3] == "end":
            if token[-3:] == "eoc":
                time_offset += 1
            continue

        # "wait<N>" advances time by N steps.
        if token[:4] == "wait":
            steps = _wait_steps(token)
            if steps is None:
                print("Skipping malformed wait token: " + token)
            else:
                time_offset += steps
            continue

        # Everything else is treated as a note-on token.
        base, had_eoc = _strip_eoc(token)

        # Look ahead (at most 199 tokens) for the matching end token or a
        # re-strike of the same note, summing the waits passed on the way.
        # Tokens are compared exactly after removing "eoc"; a prefix comparison
        # would let e.g. "p22" or "endp22" wrongly terminate "p2".
        duration = 1
        has_end = False
        for j in range(1, 200):
            if i + j == len(score):
                break
            ahead = score[i + j]
            if ahead[:4] == "wait":
                # A malformed wait contributes nothing instead of raising.
                duration += _wait_steps(ahead) or 0
            ahead_base, ahead_eoc = _strip_eoc(ahead)
            if ahead_base == "end" + base or ahead_base == base:
                has_end = True
                break
            if ahead_eoc:
                duration += 1

        # No end found within the window: fall back to a fixed 12 steps.
        if not has_end:
            duration = 12

        try:
            new_note = music21.note.Note(int(base[1:]) + note_offset)
            new_note.duration = music21.duration.Duration(duration * speed)
            new_note.offset = time_offset * speed
            if base[0] == "v":
                violin_notes.append(new_note)
            else:
                piano_notes.append(new_note)
        except Exception:
            # Best effort: report the unusable token and keep decoding.
            # (`Exception`, not a bare except, so Ctrl-C still interrupts.)
            print("Unknown note: " + base)

        if had_eoc:
            time_offset += 1

    # The note lists for both instruments are complete at this point.

    # Put an instrument object at the start (index 0) of each note list.
    piano = music21.instrument.fromString("Piano")
    violin = music21.instrument.fromString("Violin")
    piano_notes.insert(0, piano)
    violin_notes.insert(0, violin)

    # One stream per instrument, merged into a single two-part stream.
    piano_stream = music21.stream.Stream(piano_notes)
    violin_stream = music21.stream.Stream(violin_notes)
    note_stream = music21.stream.Stream([piano_stream, violin_stream])

    # splitext instead of filename[:-4], so the extension may have any length.
    note_stream.write('midi', fp="/content/" + os.path.splitext(filename)[0] + ".mid")
    print("Done! Decoded midi file saved to 'content/'")

    
from google.colab import files

# Turn /content/output.txt into /content/output.mid, then offer it for download.
decoder('output.txt')
files.download('/content/output.mid')

In [None]:
#@title Plot, Graph, and Listen to the Output :)
# Figure size in inches (width, height) applied to both plots below.
graphs_length_inches = 18 #@param {type:"slider", min:0, max:20, step:1}
notes_graph_height = 6 #@param {type:"slider", min:0, max:20, step:1}
# Pitch window of the librosa piano-roll plot; the values are used as a slice
# [lowest:highest] of the piano-roll rows, so lowest must be below highest.
highest_displayed_pitch = 92 #@param {type:"slider", min:1, max:128, step:1}
lowest_displayed_pitch = 24 #@param {type:"slider", min:1, max:128, step:1}

import librosa
import numpy as np
import pretty_midi
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
#matplotlib.use('SVG')
# For plotting
import mir_eval.display
import librosa.display
%matplotlib inline


midi_data = pretty_midi.PrettyMIDI('/content/output.mid')

def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    """Draw rows start_pitch..end_pitch-1 of ``pm``'s piano roll with librosa.

    Args:
        pm: object providing ``get_piano_roll(fs)`` (here a PrettyMIDI instance).
        start_pitch: first piano-roll row to show; also sets the lowest
            frequency label of the y axis.
        end_pitch: row at which the shown range stops (exclusive).
        fs: value passed to ``get_piano_roll`` and used as ``sr`` of the plot.

    Returns:
        None; the plot is drawn on the current matplotlib axes.
    """
    visible_rows = pm.get_piano_roll(fs)[start_pitch:end_pitch]
    lowest_hz = pretty_midi.note_number_to_hz(start_pitch)
    librosa.display.specshow(visible_rows, hop_length=1, sr=fs,
                             x_axis='time', y_axis='cqt_note', fmin=lowest_hz)



# Plot 1: pypianoroll view of the generated MIDI file. track.plot() creates its
# own figure, so no plt.figure() is opened beforehand (doing so left an empty
# figure in the cell output).
# NOTE(review): Multitrack(<path>) and a plot() that returns (fig, ax) look like
# the pre-1.0 pypianoroll API - confirm the installed version supports them.
track = Multitrack('/content/output.mid', name='track')
fig, ax = track.plot()
fig.set_size_inches(graphs_length_inches, notes_graph_height)

# Plot 2: librosa piano roll limited to the selected pitch window.
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
plot_piano_roll(midi_data, int(lowest_displayed_pitch), int(highest_displayed_pitch))
plt.show(block=False)

# Render the MIDI file to WAV with the soundfont copied by the install cell and
# show an audio player (the Audio object must stay the last expression).
FluidSynth("/content/font.sf2", 16000).midi_to_audio('/content/output.mid', '/content/output.wav')
Audio('/content/output.wav', rate=16000)

In [None]:
#@title Plot Positional Embeddings
import torchvision

import matplotlib.pyplot as plt
%matplotlib inline
# visualize some of the learned positional embeddings, maybe they contain structure
plt.figure(figsize=(18, 1))  
ci = model.pos_emb.data[0, :, 0].cpu()
zci = torch.cat((torch.tensor([0.0]), ci)) # pre-cat a zero
plt.imshow(zci.view(1, block_size+1).numpy())
plt.axis('off')

#Congrats! :) You did it :)

## Mount Google Drive to save the model there (standard Google Drive connect code)

In [None]:
# Mount Google Drive at /content/drive. Nothing is copied by this cell.
# NOTE(review): presumably 'trained-model.pth' is meant to be copied into the
# mounted drive by hand afterwards - confirm.
from google.colab import drive
drive.mount('/content/drive')