<a href="https://colab.research.google.com/github/asigalov61/Amazing-GPT2-Piano/blob/master/Super_Fine_Tuned_GPT2_Piano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Super Fine-Tuned GPT2 Piano

## Finetune SOTA GPT-2 Text-Generating Model to Generate Chamber Music

All thanks and credit for this Colab go out to [Max Woolf](http://minimaxir.com).

*Last updated: November 10th, 2019*

Retrain an advanced text-generating neural network on any text dataset **for free on a GPU using Colaboratory** with `gpt-2-simple`!

For more about `gpt-2-simple`, you can visit [this GitHub repository](https://github.com/minimaxir/gpt-2-simple). You can also read my [blog post](https://minimaxir.com/2019/09/howto-gpt2/) for more information on how to use this notebook!


To get started:

1. Copy this notebook to your Google Drive to keep it and save your changes. (File -> Save a Copy in Drive)
2. Make sure you're running the notebook in Google Chrome.
3. Run the cells below:


In [None]:
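#@title Install all dependencies and packages
%tensorflow_version 1.x
!pip install -q gpt-2-simple

# MIDI encoding/decoding, plotting and audio rendering
!pip install pyknon
!pip install pretty_midi
!pip install pypianoroll
!pip install mir_eval
!apt install -y fluidsynth # FluidSynth is a system package, so it is installed with apt rather than pip
!pip install midi2audio # Python wrapper that calls the fluidsynth command-line tool
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2 # General MIDI soundfont installed by apt along with fluidsynth

import gpt_2_simple as gpt2
from datetime import datetime
from google.colab import files

!nvidia-smi # show which GPU Colab attached to this session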

## Downloading GPT-2

If you're retraining a model on new text, you need to download the GPT-2 model first. 

There are four released sizes of GPT-2:

* `124M` (default): the "small" model, 500MB on disk.
* `355M`: the "medium" model, 1.5GB on disk.
* `774M`: the "large" model. This notebook finetunes it, but it needs considerably more GPU memory than the smaller models; if training runs out of memory, switch to `355M` or `124M`.
* `1558M`: the "extra large", true model. Will not work if a K80 GPU is attached to the notebook, and it cannot be finetuned in Colaboratory.

Larger models have more knowledge, but take longer to finetune and longer to generate text. You can specify which base model to use by changing `model_name` in the cells below; use the same value in both the download cell and the finetune cell.

The next cell downloads it from Google Cloud Storage and saves it in the Colaboratory VM at `models/<model_name>` (relative to the working directory, `/content` by default).

This model isn't permanently saved in the Colaboratory VM; you'll have to redownload it if you want to retrain it at a later time.
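Because the download does not survive a VM reset, it can help to check whether the model files are already present before downloading again. This is a minimal sketch with a hypothetical helper, `missing_model_files`, which is not part of `gpt-2-simple`; it assumes the `models/<model_name>/` layout described above and the seven file names of the public GPT-2 release.

```python
import os

# File names of the public GPT-2 release (assumed layout under models/<model_name>/)
GPT2_FILES = [
    "checkpoint",
    "encoder.json",
    "hparams.json",
    "model.ckpt.data-00000-of-00001",
    "model.ckpt.index",
    "model.ckpt.meta",
    "vocab.bpe",
]


def missing_model_files(model_name, model_dir="models"):
    """Hypothetical helper: list the GPT-2 files not yet present in model_dir/model_name."""
    folder = os.path.join(model_dir, model_name)
    return [name for name in GPT2_FILES
            if not os.path.isfile(os.path.join(folder, name))]
```

In the notebook you would then call `gpt2.download_gpt2(model_name="774M")` only when `missing_model_files("774M")` returns a non-empty list.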

In [None]:
gpt2.download_gpt2(model_name="774M")

## Mounting Google Drive

The best way to get the training text into the Colaboratory VM, and to get the trained model *out* of Colaboratory, is to route it through Google Drive *first*.

Running this cell (which will only work in Colaboratory) will mount your personal Google Drive in the VM, which later cells can use to get data in and out. (It will ask for an auth code; that auth is not saved anywhere.)

In [None]:
gpt2.mount_gdrive()

In [None]:
#@title (OPTIONAL) Download ready-to-use Piano and Chamber Notewise DataSets
%cd /content/
!wget 'https://github.com/asigalov61/SuperPiano/raw/master/Super%20Chamber%20Piano%20Violin%20Notewise%20DataSet.zip'
!unzip '/content/Super Chamber Piano Violin Notewise DataSet.zip'
!rm '/content/Super Chamber Piano Violin Notewise DataSet.zip'

!wget 'https://github.com/asigalov61/SuperPiano/raw/master/Super%20Chamber%20Piano%20Only%20Notewise%20DataSet.zip'
!unzip '/content/Super Chamber Piano Only Notewise DataSet.zip'
!rm '/content/Super Chamber Piano Only Notewise DataSet.zip'

In [None]:
file_name = "/content/notewise_chamber.txt"

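If you downloaded a dataset with the optional cell above, it is already in `/content`, so you can skip the next cell. Otherwise, if your text file is larger than 10MB, it is recommended to upload it to Google Drive first and then copy it from Google Drive to the Colaboratory VM. `gpt2.copy_file_from_gdrive` looks the file up relative to the root of your Google Drive, so in that case set `file_name` to the file's path inside `My Drive` (for example `"notewise_chamber.txt"`) rather than a `/content/...` path.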

In [None]:
gpt2.copy_file_from_gdrive(file_name)
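If you are not sure whether the dataset is already in the VM, a small check can fall back to Google Drive only when needed. This is a minimal sketch with a hypothetical helper, `ensure_dataset`, that does roughly what `gpt2.copy_file_from_gdrive` does (copy a file from the mounted Drive into the VM) but only if the local file is missing. The Drive mount path is an assumption: older Colab runtimes expose it as `/content/drive/My Drive`, newer ones as `/content/drive/MyDrive`.

```python
import os
import shutil

# Assumed location of the mounted Google Drive root (may be ".../MyDrive" on newer runtimes)
DRIVE_ROOT = "/content/drive/My Drive"


def ensure_dataset(local_path, drive_root=DRIVE_ROOT):
    """Hypothetical helper: return local_path, copying it from Google Drive first if missing.

    The file is looked up in Drive under its base name, e.g.
    /content/notewise_chamber.txt  <-  <drive_root>/notewise_chamber.txt
    """
    if os.path.isfile(local_path):
        return local_path  # already downloaded/unzipped into the VM
    source = os.path.join(drive_root, os.path.basename(local_path))
    if not os.path.isfile(source):
        raise FileNotFoundError(f"{local_path} not found locally or as {source}")
    shutil.copyfile(source, local_path)
    return local_path
```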

## Finetune GPT-2

The next cell starts the actual finetuning of GPT-2. It creates a persistent TensorFlow session that stores the training config, then runs training for the specified number of `steps` (to have the finetuning run indefinitely, set `steps = -1`).

The model checkpoints are saved in `checkpoint/run1` by default. In the cell below they are saved every 500 steps (set by `save_every`) and when the cell is stopped.

The training might time out after roughly 4 hours; make sure you end training and save the results so you don't lose them!

**IMPORTANT NOTE:** If you want to rerun this cell, **restart the VM first** (Runtime -> Restart Runtime). You will need to rerun imports but not recopy files.

Other optional-but-helpful parameters for `gpt2.finetune`:


* **`restore_from`**: Set to `fresh` to start training from the base GPT-2, or set to `latest` to restart training from an existing checkpoint.
* **`sample_every`**: Number of steps between printed example outputs.
* **`print_every`**: Number of steps between training progress reports.
* **`save_every`**: Number of steps between checkpoint saves.
* **`learning_rate`**: Learning rate for the training (default `1e-4`; you can lower it to `1e-5` if you have <1MB of input data).
* **`run_name`**: Subfolder within `checkpoint` to save the model in. This is useful if you want to work with multiple models (you will also need to specify `run_name` when loading the model).
* **`overwrite`**: Set to `True` if you want to continue finetuning an existing model (with `restore_from='latest'`) without creating duplicate copies.
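The options above combine differently for a fresh run and for continuing a run after a timeout. This sketch uses a hypothetical helper, `finetune_kwargs` (not part of `gpt-2-simple`), that only assembles keyword arguments for `gpt2.finetune` from the parameter names listed above; `print_every` and `sample_every` mirror the finetune cell below.

```python
def finetune_kwargs(dataset, model_name="774M", run_name="run1", steps=1000,
                    resume=False, save_every=500):
    """Hypothetical helper: build keyword arguments for gpt2.finetune.

    resume=False -> start from the base GPT-2 weights (restore_from='fresh').
    resume=True  -> continue from checkpoint/<run_name> (restore_from='latest')
                    and overwrite that run instead of creating a duplicate copy.
    """
    kwargs = {
        "dataset": dataset,
        "model_name": model_name,
        "run_name": run_name,
        "steps": steps,
        "restore_from": "latest" if resume else "fresh",
        "print_every": 10,
        "sample_every": 200,
        "save_every": save_every,
    }
    if resume:
        kwargs["overwrite"] = True  # only meaningful together with restore_from='latest'
    return kwargs


# Usage in the notebook (after restarting the runtime), not run here:
# sess = gpt2.start_tf_sess()
# gpt2.finetune(sess, **finetune_kwargs(file_name, resume=True))
```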

In [None]:
sess = gpt2.start_tf_sess()

gpt2.finetune(sess,
              dataset=file_name,
              model_name='774M',
              steps=1000,
              restore_from='fresh',
              run_name='run1',
              print_every=10,
              sample_every=200,
              save_every=500
              )

[360 | 410.12] loss=0.64 avg=0.77
[370 | 420.20] loss=0.73 avg=0.77
[380 | 430.29] loss=0.62 avg=0.76
[390 | 440.38] loss=0.71 avg=0.76
[400 | 450.47] loss=0.81 avg=0.76
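Each progress line has the form `[step | elapsed seconds] loss=... avg=...`, where `avg` is a smoothed average of the loss. To plot training progress, you could parse copied log lines with a small regular expression. This is a sketch that assumes exactly the format shown above; `parse_training_log` is a hypothetical helper.

```python
import re

# Matches progress lines such as "[360 | 410.12] loss=0.64 avg=0.77"
LOG_LINE = re.compile(
    r"\[(?P<step>\d+) \| (?P<elapsed>[\d.]+)\] loss=(?P<loss>[\d.]+) avg=(?P<avg>[\d.]+)"
)


def parse_training_log(text):
    """Return (step, elapsed_seconds, loss, avg_loss) tuples for every progress line in text."""
    rows = []
    for match in LOG_LINE.finditer(text):
        rows.append((
            int(match["step"]),
            float(match["elapsed"]),
            float(match["loss"]),
            float(match["avg"]),
        ))
    return rows


log = """[360 | 410.12] loss=0.64 avg=0.77
[370 | 420.20] loss=0.73 avg=0.77
[400 | 450.47] loss=0.81 avg=0.76"""

for step, elapsed, loss, avg in parse_training_log(log):
    print(f"step {step}: loss {loss:.2f} (avg {avg:.2f}) after {elapsed:.0f}s")
```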


After the model is trained, you can copy the checkpoint folder to your own Google Drive.

If you want to download it to your personal computer, it's strongly recommended you copy it to Google Drive first, then download it from there. The checkpoint folder is copied as a `.tar` archive (`checkpoint_run1.tar` for `run_name='run1'`); you can download it and extract it locally.

In [None]:
gpt2.copy_checkpoint_to_gdrive(run_name='run1')
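Once the archive is on your computer, Python's standard `tarfile` module can unpack it. This is a minimal sketch assuming, as with `gpt-2-simple`'s archives, that the tar contains the `checkpoint/<run_name>` folder; `extract_checkpoint` is a hypothetical helper that also refuses members that would land outside the destination folder.

```python
import os
import tarfile


def extract_checkpoint(archive_path, dest="."):
    """Unpack a checkpoint archive such as checkpoint_run1.tar into dest.

    The archive is assumed to store paths like checkpoint/run1/..., so extracting
    into dest recreates dest/checkpoint/run1. Returns the path of dest/checkpoint.
    """
    dest = os.path.abspath(dest)
    with tarfile.open(archive_path, "r") as tar:
        # Reject members that would be written outside dest (e.g. "../x" or absolute paths)
        for member in tar.getmembers():
            target = os.path.abspath(os.path.join(dest, member.name))
            if os.path.commonpath([dest, target]) != dest:
                raise ValueError(f"unsafe path in archive: {member.name}")
        if hasattr(tarfile, "data_filter"):  # newer Pythons support extraction filters
            tar.extractall(dest, filter="data")
        else:
            tar.extractall(dest)
    return os.path.join(dest, "checkpoint")
```

For example, `extract_checkpoint("checkpoint_run1.tar")` recreates `checkpoint/run1` in the current folder.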

You're done! Feel free to go to the **Generate Text From The Trained Model** section to generate text based on your retrained model.

## Load a Trained Model Checkpoint

Running the next cell will copy the `.tar` checkpoint archive from your Google Drive into the Colaboratory VM.

In [None]:
gpt2.copy_checkpoint_from_gdrive(run_name='run1')

The next cell loads the retrained model checkpoint and the metadata needed to generate text.

**IMPORTANT NOTE:** If you want to rerun this cell, **restart the VM first** (Runtime -> Restart Runtime). You will need to rerun imports but not recopy files.

In [None]:
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name='run1')

## Generate Text From The Trained Model

After you've trained the model or loaded a retrained model from checkpoint, you can now generate text. `generate` generates a single text from the loaded model.

In [None]:
gpt2.generate(sess, run_name='run1')

If you're creating an API based on your model and need to pass the generated text elsewhere, you can do `text = gpt2.generate(sess, return_as_list=True)[0]`

You can also pass in a `prefix` to the generate function to force the text to start with a given character sequence and generate text from there (good if you add an indicator when the text starts).

You can also generate multiple texts at a time by specifying `nsamples`. Passing a `batch_size` generates multiple samples in parallel, giving a massive speedup (in Colaboratory, set a maximum of 20 for `batch_size`).
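A minimal sketch of how these options fit together when generating many samples at once. `generate_kwargs` is a hypothetical helper, not part of `gpt-2-simple`; it assumes that `gpt-2-simple` expects `nsamples` to be a multiple of `batch_size` (samples are produced in `nsamples // batch_size` batches), and it applies the 20-sample Colaboratory limit mentioned above.

```python
def generate_kwargs(nsamples, batch_size, length=1023, temperature=0.7, prefix=None):
    """Hypothetical helper: validate sampling settings and build kwargs for gpt2.generate."""
    if batch_size < 1 or batch_size > 20:
        raise ValueError("batch_size should be between 1 and 20 in Colaboratory")
    if nsamples % batch_size != 0:
        raise ValueError(f"nsamples={nsamples} is not a multiple of batch_size={batch_size}")
    kwargs = {
        "nsamples": nsamples,
        "batch_size": batch_size,
        "length": length,
        "temperature": temperature,
        "return_as_list": True,  # get the texts back instead of printing them
    }
    if prefix is not None:
        kwargs["prefix"] = prefix
    return kwargs


# Usage in the notebook, not run here:
# texts = gpt2.generate(sess, run_name='run1', **generate_kwargs(20, 10))
```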

Other optional-but-helpful parameters for `gpt2.generate` and friends:

*  **`length`**: Number of tokens to generate (default 1023, the maximum)
* **`temperature`**: The higher the temperature, the crazier the text (default 0.7, recommended to keep between 0.7 and 1.0)
* **`top_k`**: Limits the generated guesses to the top *k* guesses (default 0 which disables the behavior; if the generated output is super crazy, you may want to set `top_k=40`)
* **`top_p`**: Nucleus sampling: limits the generated guesses to the smallest set of most likely tokens whose cumulative probability reaches `top_p` (`top_p=0.9` is a good starting point).
* **`truncate`**: Truncates the generated text at a given sequence, excluding that sequence (e.g. if `truncate='<|endoftext|>'`, the returned text will include everything before the first `<|endoftext|>`). It may be useful to combine this with a smaller `length` if the input texts are short.
* **`include_prefix`**: If using `truncate` and `include_prefix=False`, the specified `prefix` will not be included in the returned text.
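To make `temperature`, `top_k` and `top_p` concrete, here is a small NumPy illustration of how the three filters reshape one next-token distribution. It is a simplified sketch of the general technique, not `gpt-2-simple`'s TensorFlow code: temperature divides the logits, top-k keeps the *k* most likely tokens, and top-p then keeps the smallest set of most likely tokens whose cumulative probability reaches `top_p`.

```python
import numpy as np


def filter_next_token_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Illustrative temperature / top-k / top-p filtering of one next-token distribution.

    Returns a probability vector over the vocabulary that sums to 1.
    """
    # Temperature: dividing the logits by T < 1 sharpens the distribution, T > 1 flattens it
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())  # softmax, shifted for numerical stability
    probs /= probs.sum()

    # Top-k: keep only the k most likely tokens (0 disables the filter)
    if 0 < top_k < probs.size:
        kth_largest = np.sort(probs)[-top_k]
        probs = np.where(probs >= kth_largest, probs, 0.0)
        probs /= probs.sum()

    # Top-p (nucleus): keep the smallest set of most likely tokens whose
    # cumulative probability reaches top_p
    if top_p < 1.0:
        order = np.argsort(-probs)  # token ids, most likely first
        cumulative = np.cumsum(probs[order])
        cutoff = np.searchsorted(cumulative, top_p) + 1  # include the token that crosses top_p
        keep = np.zeros(probs.shape, dtype=bool)
        keep[order[:cutoff]] = True
        probs = np.where(keep, probs, 0.0)
        probs /= probs.sum()

    return probs


# Usage: draw one token id from the filtered distribution
rng = np.random.default_rng(0)
probs = filter_next_token_probs(np.log([0.5, 0.3, 0.15, 0.05]), temperature=0.7, top_p=0.9)
token_id = rng.choice(len(probs), p=probs)
```

Lower temperatures and smaller `top_k`/`top_p` values remove more of the unlikely tokens, which is why they make the output less "crazy".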

In [None]:
gen_file = 'output.txt'

gpt2.generate_to_file(sess,
                      destination_path=gen_file,
                      length=1023,
                      temperature=0.6,
                      prefix="p45 wait6 p54 wait3 p26 wait5 endp45 wait1 p33 wait5 p42 wait13 p47 wait5 endp54",
                      nsamples=1,
                      batch_size=1,
                      )
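The `prefix` above is written in the notewise format that the decoder cell below turns into MIDI: `pN`/`vN` start a piano/violin note, `endpN`/`endvN` end it, `waitN` advances time by N steps, and an `eoc` (end-of-chord) suffix advances time by one step. The sketch below is a simplified reading of those rules (hypothetical helper `notewise_events`); unlike the decoder, it leaves notes without an end token open (`end=None`) instead of giving them a default duration. It is handy for sanity-checking a prefix or a generated sample.

```python
def notewise_events(text):
    """Simplified reading of a notewise string into (note, start, end) tuples, in time steps."""
    time = 0
    open_notes = {}  # note token -> start step
    events = []
    for raw in text.split():
        token = raw[:-3] if raw.endswith("eoc") else raw
        if token.startswith("wait"):
            time += int(token[4:])
        elif token.startswith("end"):
            name = token[3:]
            if name in open_notes:
                events.append((name, open_notes.pop(name), time))
        elif token[:1] in ("p", "v"):
            if token in open_notes:  # re-striking a sounding note ends the previous one
                events.append((token, open_notes.pop(token), time))
            open_notes[token] = time
        if raw.endswith("eoc"):
            time += 1  # end of chord: advance time by one step
    # notes that never received an end token stay open
    events.extend((name, start, None) for name, start in open_notes.items())
    return events
```

With `note_offset_variable = 33`, the decoder maps `p45` to MIDI pitch 78, and one step lasts `time_coefficient / sample_freq_variable` quarter notes (1/12 with the defaults).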

In [None]:
#@title Plot, Graph, and Listen to the Output :)
graphs_length_inches = 18 #@param {type:"slider", min:0, max:20, step:1}
notes_graph_height = 6 #@param {type:"slider", min:0, max:20, step:1}
highest_displayed_pitch = 92 #@param {type:"slider", min:1, max:128, step:1}
lowest_displayed_pitch = 24 #@param {type:"slider", min:1, max:128, step:1}

# Generate and download the resulting MIDI file
time_coefficient = 1 #@param {type:"integer"}

import music21
from midi2audio import FluidSynth
from IPython.display import Audio

# Decoding parameters; they must match the settings used to encode the training data
# (default settings: sample_freq=12, note_range=62)
sample_freq_variable = 12 #@param {type:"number"}
note_range_variable = 62 #@param {type:"number"}
note_offset_variable = 33 #@param {type:"number"}
number_of_instruments = 2 #@param {type:"number"}
chamber_option = True #@param {type:"boolean"}

def decoder(filename):
    """Decode the notewise text file /content/<filename> into a piano + violin MIDI file."""

    filedir = '/content/'

    notetxt = filedir + filename

    with open(notetxt, 'r') as file:
        notestring = file.read()

    # split on any whitespace, so newlines in the generated text do not glue tokens together
    score_note = notestring.split()

    # decoding parameters (must match the encoding script)
    sample_freq = sample_freq_variable
    note_offset = note_offset_variable

    # variables and lists needed for decoding
    speed = time_coefficient / sample_freq  # length of one time step, in quarter notes
    piano_notes = []
    violin_notes = []
    time_offset = 0

    # start decoding here
    score = score_note

    # Expand "p_octave_N" shorthand tokens into two notes an octave apart
    # ("pN" and "p<N+12>"); an "eoc" (end-of-chord) suffix moves to the upper note.
    i = 0
    while i < len(score):
        if score[i][:9] == "p_octave_":
            add_wait = ""
            if score[i][-3:] == "eoc":
                add_wait = "eoc"
                score[i] = score[i][:-3]
            this_note = score[i][9:]
            score[i] = "p" + this_note
            score.insert(i + 1, "p" + str(int(this_note) + 12) + add_wait)
            i += 1
        i += 1

    # loop through every event in the score
    for i in range(len(score)):

        # skip blanks, "<eos>" and "<unk>" tokens
        if score[i] in ["", " ", "<eos>", "<unk>"]:
            continue

        # "end<note>" marks the end of a note; durations are resolved by the
        # look-ahead below, so only an "eoc" suffix matters here: it advances time by 1
        elif score[i][:3] == "end":
            if score[i][-3:] == "eoc":
                time_offset += 1
            continue

        # "waitN" advances the current time by N steps
        elif score[i][:4] == "wait":
            time_offset += int(score[i][4:])
            continue

        # anything else should be a note: "pN" (piano) or "vN" (violin),
        # optionally followed by "eoc"
        else:
            note_name = score[i][:-3] if score[i][-3:] == "eoc" else score[i]

            # Look ahead for the matching "end<note>" token (or the same note
            # struck again) and add up the waits in between to get the duration.
            duration = 1
            has_end = False
            for j in range(1, 200):
                if i + j == len(score):
                    break
                future = score[i + j]
                if future[:4] == "wait":
                    duration += int(future[4:])
                future_name = future[:-3] if future[-3:] == "eoc" else future
                # exact comparison, so that e.g. "endp4" does not match "endp45"
                if future_name == "end" + note_name or future_name == note_name:
                    has_end = True
                    break
                if future[-3:] == "eoc":
                    duration += 1

            # no end found within the look-ahead window: use a default duration
            if not has_end:
                duration = 12

            # a note with an "eoc" suffix closes a chord, so time advances after it
            add_wait = 1 if score[i][-3:] == "eoc" else 0
            score[i] = note_name

            try:
                new_note = music21.note.Note(int(score[i][1:]) + note_offset)
                new_note.duration = music21.duration.Duration(duration * speed)
                new_note.offset = time_offset * speed
                if score[i][0] == "v":
                    violin_notes.append(new_note)
                else:
                    piano_notes.append(new_note)
            except Exception:
                print("Unknown note: " + score[i])

            time_offset += add_wait

    # the note lists for each instrument are ready at this stage

    # create music21 instrument objects and put them at the start (index 0) of each note list
    piano = music21.instrument.fromString("Piano")
    violin = music21.instrument.fromString("Violin")
    piano_notes.insert(0, piano)
    violin_notes.insert(0, violin)

    # one stream per instrument, merged into a single stream of 2 instruments
    piano_stream = music21.stream.Stream(piano_notes)
    violin_stream = music21.stream.Stream(violin_notes)
    note_stream = music21.stream.Stream([piano_stream, violin_stream])

    midi_path = "/content/" + filename[:-4] + ".mid"
    note_stream.write('midi', fp=midi_path)
    print("Done! Decoded MIDI file saved to " + midi_path)

    
decoder('output.txt')
from google.colab import files
files.download('/content/output.mid')





import librosa
import librosa.display
import numpy as np
import pretty_midi
from pypianoroll import Multitrack
import matplotlib.pyplot as plt
%matplotlib inline


midi_data = pretty_midi.PrettyMIDI('/content/output.mid')

def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    # Use librosa's specshow function to display the piano roll of all instruments
    # for pitches start_pitch .. end_pitch - 1, sampled fs times per second
    librosa.display.specshow(pm.get_piano_roll(fs)[start_pitch:end_pitch],
                             hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                             fmin=pretty_midi.note_number_to_hz(start_pitch))


# Plot the output: per-track piano rolls drawn by pypianoroll
track = Multitrack('/content/output.mid', name='track')
fig, ax = track.plot()
fig.set_size_inches(graphs_length_inches, notes_graph_height)

# Combined piano roll restricted to the chosen pitch range
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
plot_piano_roll(midi_data, int(lowest_displayed_pitch), int(highest_displayed_pitch))
plt.show(block=False)


FluidSynth("/content/font.sf2", 16000).midi_to_audio('/content/output.mid', '/content/output.wav')
Audio('/content/output.wav', rate=16000)
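In `plot_piano_roll`, `pm.get_piano_roll(fs)` returns a 128 × frames matrix sampled `fs` times per second, and the plot keeps only rows `start_pitch:end_pitch`. A simplified NumPy version of that idea, assuming notes given as hypothetical `(pitch, start_seconds, end_seconds, velocity)` tuples, is sketched below; it is an illustration, not `pretty_midi`'s implementation (which also handles sustain pedal and multiple instruments).

```python
import numpy as np


def simple_piano_roll(notes, fs=100):
    """Build a (128, frames) velocity matrix from (pitch, start_s, end_s, velocity) tuples.

    Frame t covers time t / fs seconds; overlapping notes on the same pitch add up.
    """
    end_time = max((end for _, _, end, _ in notes), default=0.0)
    roll = np.zeros((128, int(np.ceil(end_time * fs))))
    for pitch, start, end, velocity in notes:
        roll[pitch, int(start * fs):int(end * fs)] += velocity
    return roll


example_notes = [(78, 0.0, 0.5, 100), (66, 0.25, 1.0, 80)]
example_roll = simple_piano_roll(example_notes, fs=4)
visible = example_roll[24:92]  # same window as the default lowest/highest displayed pitch
```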

# LICENSE

MIT License

Copyright (c) 2019 Max Woolf

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.