<a href="https://colab.research.google.com/github/asigalov61/Amazing-GPT2-Piano/blob/master/Amazing_GPT2_Piano.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Amazing GPT2 Piano (w/MTM 3.5)

***

## GPT2-based Symbolic Music Artificial Intelligence Model Creator/Trainer

### Multi-Track, Multi-Instrumental, MIDI-TXT-MIDI

***

1) Credit for the char-based GPT2 implementation used in this Colab goes to Andrej Karpathy: https://github.com/karpathy/minGPT

2) Credit for the very nice arc diagram MIDI visualizer goes to J. Brent Runyan: https://github.com/j-brent/arc-diagrams

***

#### Project Los Angeles

#### Tegridy Code 2020

***

# Setup Environment, clone needed repos, and install all required dependencies

In [None]:
#@title Clone minGPT repo and install all dependencies (run only once per session)
!git clone https://github.com/asigalov61/minGPT
!git clone https://github.com/asigalov61/arc-diagrams
!pip install pyknon
!pip install pretty_midi
!pip install pypianoroll
!pip install mir_eval
!apt install fluidsynth #Pip does not work for some reason. Only apt works
!pip install midi2audio
!cp /usr/share/sounds/sf2/FluidR3_GM.sf2 /content/font.sf2

!curl -L "https://github.com/asigalov61/MIDI-TXT-MIDI/raw/master/MIDI.py" > 'MIDI.py'

!mkdir Dataset

In [None]:
#@title Import all modules and setup logging
# Switch into the minGPT clone before importing from it.
# NOTE(review): presumably `mingpt` is only importable because the working
# directory is the repo root — confirm if the clone location ever changes.
%cd /content/minGPT
# make deterministic: fix the seed before any other library is imported/used
from mingpt.utils import set_seed
set_seed(42)

import numpy as np
import torchvision
import torch
import torch.nn as nn
from torch import optim
import torch.nn.functional as F
from torch.utils.data import Dataset

import keras
from keras.utils import to_categorical

import time
import math

from mingpt.model import GPT, GPTConfig
from mingpt.trainer import Trainer, TrainerConfig
from mingpt.utils import sample

import tqdm.auto

# For plotting
import mido
import librosa
import pretty_midi
import pypianoroll
from pypianoroll import Multitrack, Track
import matplotlib
import matplotlib.pyplot as plt
import mir_eval.display
import librosa.display
%matplotlib inline

from mido import MidiFile


from midi2audio import FluidSynth

from google.colab import output, drive

from IPython.display import display, Javascript, HTML, Audio, Image

# set up logging: INFO level, timestamped "<time> - <level> - <logger> -   <message>" records
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

# The arc-diagram visualizer is imported from inside its clone directory.
# NOTE(review): presumably relies on the working directory being on the import path — confirm.
%cd /content/arc-diagrams/
from arc_diagram import plot_arc_diagram

dtype = torch.float
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Prints "cuda:0" on a GPU runtime and "cpu" otherwise.
print('Available Processing Device is:', device)
# Return to /content so later cells' relative paths resolve there.
%cd /content/

# Upload/download and process MIDI dataset

In [None]:
#@title (OPTION 1) Download Tegridy special Piano/Violin MIDI dataset
%cd /content/Dataset/
!rm *.mid 
!rm *.midi
!wget 'https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Tegridy-MIDI-Dataset-CC-BY-NC-SA.zip'
!unzip '/content/Dataset/Tegridy-MIDI-Dataset-CC-BY-NC-SA.zip'
!rm '/content/Dataset/Tegridy-MIDI-Dataset-CC-BY-NC-SA.zip'
%cd /content/

In [None]:
#@title (OPTION 2) Tiny Karaoke Precision MIDI subset (CC-BY-NC-SA)
%cd /content/Dataset/
!rm *.mid 
!rm *.midi
!wget 'https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Tiny-Karaoke-Precision-MIDI-Subset-CC-BY-NC-SA.zip'
!unzip '/content/Dataset/Tiny-Karaoke-Precision-MIDI-Subset-CC-BY-NC-SA.zip'
!rm '/content/Dataset/Tiny-Karaoke-Precision-MIDI-Subset-CC-BY-NC-SA.zip'
%cd /content/

In [None]:
#@title Process MIDI to TXT (MIDI-TXT-MIDI v.3.5)
# Form settings read as globals by write_notes() below:
#   encoding_type    - which converter branch runs: the score branch parses with
#                      MIDI.midi2score and writes only 'note' events; the two opus
#                      branches parse with MIDI.midi2opus and write either
#                      single-character or whole-word event tokens.
#   enable_sampling / sample_length_in_MIDI_events
#                    - when sampling is enabled, write_notes stops after the track
#                      in which the file's collected note count exceeds the limit.
#   advanced_events  - opus branches only: also write non-note channel, meta and
#                      sysex events.
#   allow_tempo_changes / allow_control_change
#                    - set_tempo / control_change events are written only when
#                      advanced_events is enabled as well.
#   karaoke          - opus branches only: write text, lyric, marker and cue-point events.
#   debug            - extra progress prints.
encoding_type = "opus-one-byte-encoding" #@param ["score-one-byte-encoding", "opus-one-byte-encoding", "opus-complete-words-encoding"]
enable_sampling = False #@param {type:"boolean"}
sample_length_in_MIDI_events = 1501 #@param {type:"slider", min:0, max:10000, step:1}
advanced_events = True #@param {type:"boolean"}
allow_tempo_changes = True #@param {type:"boolean"}
allow_control_change = True #@param {type:"boolean"}
karaoke = False #@param {type:"boolean"}
debug = False #@param {type:"boolean"}

# Dataset/ and Dataset.txt are addressed relative to /content.
%cd /content/

# MIDI dataset to TXT dataset converter.
# `MIDI` is the MIDI.py module downloaded into /content by the setup cell.
import MIDI
import os
import numpy as np
import tqdm.auto

# write_notes() opens Dataset.txt in append mode, so start from a clean file.
if os.path.exists("Dataset.txt"):
  os.remove("Dataset.txt")
  print('Removing old Dataset...')
else:
  print("Creating new Dataset file...")



# Header token written once at the start of every converted MIDI file.
_DATASET_HEADER = 'H d0 tMIDI-TXT-MIDI-Textual-Music-Dataset '

# Field layouts: each pair is (prefix letter written to the text, index into the event list).
_SCORE_NOTE_FIELDS = (('d', 1), ('D', 2), ('C', 3), ('n', 4), ('V', 5))
_NOTE_FIELDS = (('d', 1), ('c', 2), ('n', 3), ('v', 4))
_TEXT_FIELDS = (('d', 1), ('t', 2))
_RAW_FIELDS = (('d', 1), ('x', 2))
_DELTA_ONLY = (('d', 1),)

# Opus events that count towards the sampling limit.
_NOTE_EVENT_NAMES = ('note_on', 'note_off')

# One row per opus event type:
#   (event name, one-byte token, complete-word token, gate, fields)
# `gate` names the form switch(es) that must be on for the event to be written
# (see _active_gates); None means the event is always written.
# NOTE(review): in the one-byte encoding 'E' is shared by end_track and
# raw_meta_event, and the complete-words encoding writes control_change as 'C'
# rather than a full word. Both are kept exactly as before because the matching
# TXT-to-MIDI decoder is not visible from this cell — confirm before changing.
_OPUS_EVENT_TABLE = (
    ('note_off', 'F', 'NoteOff', None, _NOTE_FIELDS),
    ('note_on', 'N', 'NoteOn', None, _NOTE_FIELDS),
    ('key_after_touch', 'K', 'KeyAfterTouch', 'advanced', _NOTE_FIELDS),
    ('control_change', 'C', 'C', 'control', (('d', 1), ('c', 2), ('r', 3), ('l', 4))),
    ('patch_change', 'P', 'PatchChange', 'advanced', (('d', 1), ('c', 2), ('h', 3))),
    ('channel_after_touch', 'Z', 'ChannelAfterTouch', 'advanced', (('d', 1), ('c', 2), ('v', 3))),
    ('pitch_wheel_change', 'W', 'PitchWheelChange', 'advanced', (('d', 1), ('c', 2), ('p', 3))),
    ('text_event', 'T', 'TextEvent', 'karaoke', _TEXT_FIELDS),
    ('copyright_text_event', 'R', 'CopyrightTextEvent', 'karaoke', _TEXT_FIELDS),
    ('track_name', 'H', 'TrackName', None, _TEXT_FIELDS),
    ('instrument_name', 'I', 'InstrumentName', None, _TEXT_FIELDS),
    ('lyric', 'L', 'Lyric', 'karaoke', _TEXT_FIELDS),
    ('marker', 'M', 'Marker', 'karaoke', _TEXT_FIELDS),
    ('cue_point', 'U', 'CuePoint', 'karaoke', _TEXT_FIELDS),
    ('text_event_08', '+', 'TextEvent08', 'karaoke', _TEXT_FIELDS),
    ('text_event_09', '&', 'TextEvent09', 'karaoke', _TEXT_FIELDS),
    ('text_event_0a', '@', 'TextEvent0a', 'karaoke', _TEXT_FIELDS),
    ('text_event_0b', '#', 'TextEvent0b', 'karaoke', _TEXT_FIELDS),
    ('text_event_0c', '$', 'TextEvent0c', 'karaoke', _TEXT_FIELDS),
    ('text_event_0d', '%', 'TextEvent0d', 'karaoke', _TEXT_FIELDS),
    ('text_event_0e', '*', 'TextEvent0e', 'karaoke', _TEXT_FIELDS),
    ('text_event_0f', '=', 'TextEvent0f', 'karaoke', _TEXT_FIELDS),
    ('end_track', 'E', 'EndOfTrack', None, _DELTA_ONLY),
    ('set_tempo', 'S', 'SetTempo', 'tempo', (('d', 1), ('o', 2))),
    ('smpte_offset', 'Y', 'SMPTEOffset', 'advanced',
     (('d', 1), ('g', 2), ('n', 3), ('s', 4), ('f', 5), ('e', 6))),
    ('time_signature', 'B', 'TimeSignature', 'advanced',
     (('d', 1), ('u', 2), ('y', 3), ('i', 4), ('j', 5))),
    ('key_signature', 'A', 'KeySignature', 'advanced', (('d', 1), ('b', 2), ('q', 3))),
    # Spelled 'sequencer_specific' (not 'sequincer_specific') so the event can match.
    ('sequencer_specific', 'D', 'SequencerSpecific', 'advanced', _RAW_FIELDS),
    # 'z' is the meta command byte (event[2]); 'x' is the raw payload (event[3]).
    ('raw_meta_event', 'E', 'RawMetaEvent', 'advanced', (('d', 1), ('z', 2), ('x', 3))),
    ('sysex_f0', 'G', 'SysExF0', 'advanced', _RAW_FIELDS),
    ('sysex_f7', '!', 'SysExF7', 'advanced', _RAW_FIELDS),
    ('song_position', 'J', 'SongPosition', 'advanced', (('d', 1), ('a', 2))),
    ('song_select', 'O', 'SongSelect', 'advanced', (('d', 1), ('m', 2))),
    ('tune_request', 'X', 'TuneRequest', 'advanced', _DELTA_ONLY),
)


def _active_gates():
    """Resolve the notebook's form switches into a gate-name -> bool mapping."""
    return {
        None: True,
        'advanced': bool(advanced_events),
        'control': bool(advanced_events and allow_control_change),
        'tempo': bool(advanced_events and allow_tempo_changes),
        'karaoke': bool(karaoke),
    }


def _format_event(token, fields, event):
    """Render one event as '<token> <prefix><value> ... ' (trailing space included)."""
    return token + ''.join(' ' + prefix + str(event[index]) for prefix, index in fields) + ' '


def _write_score_tracks(score, out):
    """Write every 'note' event of a MIDI.py score to `out`; return the number of note-less tracks."""
    note_count = 0
    skipped_tracks = 0
    last_kind = None  # type of the most recent event seen, for the debug message
    for track in score[1:]:  # score[0] is not a track, so iteration starts at index 1
        for event in track:
            last_kind = event[0]
            if last_kind == 'note':
                note_count += 1
                out.write(_format_event('N', _SCORE_NOTE_FIELDS, event))
        if note_count == 0:
            skipped_tracks += 1
            if debug:
                print('Uknown Event: ', last_kind)
        elif enable_sampling and note_count > sample_length_in_MIDI_events:
            break
    return skipped_tracks


def _write_opus_tracks(opus, out, token_column):
    """Write the enabled events of a MIDI.py opus to `out`.

    `token_column` selects the token spelling from _OPUS_EVENT_TABLE:
    1 for the one-byte tokens, 2 for the complete-word tokens.
    """
    gates = _active_gates()
    rows_by_name = {row[0]: row for row in _OPUS_EVENT_TABLE}
    note_count = 0
    wrote_any = False
    last_kind = None  # stays None when no event has been seen yet (e.g. empty tracks)
    for track in opus[1:]:  # opus[0] is not a track, so iteration starts at index 1
        for event in track:
            last_kind = event[0]
            row = rows_by_name.get(last_kind)
            if row is None or not gates[row[3]]:
                continue
            wrote_any = True
            if last_kind in _NOTE_EVENT_NAMES:
                note_count += 1
            out.write(_format_event(row[token_column], row[4], event))
        if not wrote_any:
            print('Uknown Event: ', last_kind)
        if wrote_any and enable_sampling and note_count > sample_length_in_MIDI_events:
            break


def write_notes(file_address):
    """Append the textual encoding of one MIDI file to ``Dataset.txt``.

    The file is parsed with the ``MIDI`` module and every track after index 0
    is written as space-separated tokens, preceded by a fixed header token.
    Behaviour is controlled by the notebook's form globals: ``encoding_type``,
    ``enable_sampling``, ``sample_length_in_MIDI_events``, ``advanced_events``,
    ``allow_tempo_changes``, ``allow_control_change``, ``karaoke`` and ``debug``.
    An ``encoding_type`` that matches none of the three known values writes nothing.

    Differences from the previous copy-pasted implementation:
      * ``cue_point`` and ``tune_request`` events no longer raise ``IndexError``
        (an unused bookkeeping list indexed past the end of those events);
      * ``raw_meta_event`` writes its raw payload after ``x`` instead of
        repeating the command byte;
      * ``sequencer_specific`` events are matched (the name was misspelled);
      * an empty track no longer raises ``NameError`` in the "Uknown Event" print;
      * both files are closed even when parsing or writing fails;
      * the final debug line prints the file path instead of the file object.

    Args:
        file_address: path of the MIDI file to convert.

    Returns:
        None. Output is appended to ``Dataset.txt`` in the working directory.
    """
    with open(file_address, 'rb') as midi_file:
        if debug:
            print('Processing File:', file_address)
        midi_bytes = midi_file.read()

    if encoding_type == 'score-one-byte-encoding':
        # Score events: ['note', start_time, duration, channel, note, velocity]
        score = MIDI.midi2score(midi_bytes)
        with open('Dataset.txt', 'a') as out:
            out.write(_DATASET_HEADER)
            skipped = _write_score_tracks(score, out)
        if debug:
            print('File:', file_address, 'Number of skipped events: ', skipped)

    elif encoding_type == 'opus-one-byte-encoding':
        opus = MIDI.midi2opus(midi_bytes)
        with open('Dataset.txt', 'a') as out:
            out.write(_DATASET_HEADER)
            _write_opus_tracks(opus, out, 1)

    elif encoding_type == 'opus-complete-words-encoding':
        opus = MIDI.midi2opus(midi_bytes)
        with open('Dataset.txt', 'a') as out:
            out.write(_DATASET_HEADER)
            _write_opus_tracks(opus, out, 2)
       

dataset_addr = "Dataset"
# Sorted so Dataset.txt is built in the same order on every run. Sub-directories
# (e.g. ones created by unzip) are skipped because write_notes opens each path
# as a file and would raise on a directory.
files = sorted(
    name for name in os.listdir(dataset_addr)
    if os.path.isfile(os.path.join(dataset_addr, name))
)
for file in tqdm.auto.tqdm(files):
    path = os.path.join(dataset_addr, file)
    write_notes(path)

# Setup and Initialize the Model

In [None]:
#@title Setup functions and procedures
# Colab form slider (0-512, step 16).
# NOTE(review): presumably passed to CharDataset as block_size by a later cell — confirm.
model_attention_span_in_tokens = 256 #@param {type:"slider", min:0, max:512, step:16}

class CharDataset(Dataset):

    def __init__(self, data, block_size):
        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        
        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
    
    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        """
        arrange data and targets so that the first i elements of x
        will be asked to predict the i-th element of y. Notice that
        the eventual language model will actually make block_size
        individual predictions at the same time based on this data,
        so we are being clever and amortizing the cost of the forward
        pass of the network. So for example if block_size is 4, then
        we could e.g. sample a chunk of text "hello", the integers in
        x will correspond to "hell" and in y will be "ello". This will
        then actually "multitask" 4 separate examples at the same time
        in the language model:
        - given just "h", please predict "e" as next
        - given "he" please predict "l" next
        - given "hel" predict "l" next
        - given "hell" predict "o" next
        
        In addition, because the DataLoader will create batches of examples,
        every forward/backward pass during traning will simultaneously train
        a LOT of predictions, amortizing a lot of computation. In particular,
        for a batched input of integers X (B, T) where B is batch size and
        T is block_size and Y (B, T), the network will during training be
        simultaneously training to make B*T predictions, all at once! Of course,
        at test time we can paralellize across batch B, but unlike during training
        we cannot parallelize across the time dimension T - we have to run
        a forward pass of the network to recover the next single character of the 
        sequence along each batch dimension, and repeatedly always feed in a next
        character to get the next one.
        
        So yes there is a big asymmetry between train/test time of autoregressive
        models. During training we can go B*T at a time with every forward pass,
        but during test time we can only go B at a time, T times, with T forward 
        passes.
        """
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

        
block_size = model_attention_span_in_tokens # spatial extent of the model for its context

In [None]:
#@title Specify input text file with training data (do not worry, any text format is fine)
full_path_to_training_text_file = "/content/Dataset.txt" #@param {type:"string"}
# Read the whole training corpus; the context manager closes the file handle as soon
# as the text is in memory.
with open(full_path_to_training_text_file, 'r') as training_text_file:
    text = training_text_file.read()
# CharDataset cuts the text into windows of block_size + 1 characters.
train_dataset = CharDataset(text, block_size)

In [None]:
#@title Create GPT2 model
model_embed_size = 128 #@param {type:"slider", min:0, max:1024, step:64}
number_of_heads = 16 #@param {type:"slider", min:1, max:16, step:1}
number_of_layers = 4 #@param {type:"slider", min:1, max:16, step:1}
# NOTE(review): multi-head attention normally needs model_embed_size to be a multiple
# of number_of_heads, and the sliders do not enforce that -- confirm how minGPT reacts.


# Model configuration: vocabulary size and context length come from the training
# dataset, the layer / head / embedding sizes from the form values above.
mconf = GPTConfig(train_dataset.vocab_size, 
                  train_dataset.block_size,
                  n_layer=number_of_layers, 
                  n_head=number_of_heads, 
                  n_embd=model_embed_size)

# Build the GPT network from that configuration.
model = GPT(mconf)

In [None]:
#@title Setup all training parameters
number_of_training_epochs = 5 #@param {type:"slider", min:1, max:100, step:1}
training_batch_size = 160 #@param {type:"slider", min:0, max:160, step:4}
model_learning_rate = 6e-4 #@param {type:"number"}
# Build the trainer configuration and the trainer instance. Training itself is not
# started here; it is started by calling trainer.train() in a separate cell.

tconf = TrainerConfig(max_epochs=number_of_training_epochs, 
                      batch_size=training_batch_size, 
                      learning_rate=model_learning_rate,
                      num_workers=4)
# NOTE(review): the third positional argument is None -- presumably minGPT's slot for
# a test/validation dataset, i.e. no evaluation set is used; confirm.
trainer = Trainer(model, train_dataset, None, tconf)

# Train the Model or load the existing pre-trained model checkpoint

In [None]:
#@title (OPTION 1) Train the model
# Switch the working directory to /content/ before training starts.
%cd /content/
# Run the training loop of the trainer instance.
trainer.train()

In [None]:
#@title Plot Positional Embeddings

# visualize some of the learned positional embeddings, maybe they contain structure
plt.figure(figsize=(18, 1))
# First embedding channel at every position, copied to the CPU for plotting.
ci = model.pos_emb.data[0, :, 0].cpu()
zci = torch.cat((torch.tensor([0.0]), ci)) # pre-cat a zero
# Let view() infer the row length from the tensor itself instead of the notebook-level
# block_size, so the plot also works when the model in memory was loaded from a
# checkpoint with a different context length.
plt.imshow(zci.view(1, -1).numpy())
plt.axis('off')

In [None]:
#@title Save/Re-Save the model from memory
# Switch to /content/ so the relative checkpoint name below resolves there.
%cd /content/
# torch.save is given the model object itself (not a state_dict).
torch.save(model, 'trained-model.pth')

In [None]:
#@title (OPTION 2) Load existing model/checkpoint
full_path_to_model_checkpoint = "/content/trained-model.pth" #@param {type:"string"}
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code --
# only load checkpoints from a source you trust.
model = torch.load(full_path_to_model_checkpoint)
# Put the loaded model into evaluation mode.
model.eval()

# Generate, download, plot, and listen to the output

In [None]:
#@title Generate and download music TXT file with the loaded Model (MIDI-TXT-MIDI)
number_of_tokens_to_generate = 16384 #@param {type:"slider", min:0, max:32768, step:128}
creativity_temperature = 0.8 #@param {type:"slider", min:0.05, max:4, step:0.05}
top_k_prob = 3 #@param {type:"slider", min:0, max:50, step:1}
input_promt = "// C:\\Users\\asiga" #@param {type:"string"}

%cd /content/

model.to(device)

context = input_promt
x = torch.tensor([train_dataset.stoi[s] for s in context], dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, number_of_tokens_to_generate, temperature=creativity_temperature, sample=True, top_k=top_k_prob)[0]
completion = ''.join([train_dataset.itos[int(i)] for i in y])
print('Done! Saving output.txt!')
with open("/content/output.txt", "w") as text_file:
    print(completion, file=text_file)


from google.colab import files
files.download('/content/output.txt')

In [None]:
#@title Convert to MIDI from TXT (MIDI-TXT-MIDI v.3.5)
number_of_ticks_per_quarter = 433 #@param {type:"slider", min:1, max:1000, step:8}

import MIDI
import tqdm.auto

# Created empty exactly as before; nothing in this cell fills them.
notes = []
velocities = []
timings = []
durations = []

with open('/content/output.txt', 'r') as file:
    notestring=file.read()

# The generated text is a stream of space-separated tokens: an event token followed
# by its field tokens. The first character of every field token is a one-letter tag
# that is stripped before the value is used.
score_note = notestring.split(" ")

score = score_note

z=1  # number used in the most recently auto-created track name

zero_marker = True  # armed: the next note event opens a fresh track first


def _event_spec(event, n_ints, text=False, echo=1, opens_track=False, rearms=False):
    """Describe how one event token is decoded.

    Args:
        event: event name appended to the current track.
        n_ints: number of integer fields read from the tokens after the event token.
        text: whether one more token follows the integers and is kept as text.
        echo: how many tokens (starting at the event token) are printed when
            decoding fails.
        opens_track: whether a new auto-named track is started first while the
            track marker is armed.
        rearms: whether the track marker is armed again after the event was added.

    Returns:
        The arguments packed into a tuple, in the order listed above.
    """
    return (event, n_ints, text, echo, opens_track, rearms)


# One row per event: (one-byte token, complete-word token, decoding spec).
# Both opus encodings share the field layout and only differ in the token spelling.
# NOTE(review): 'E' is listed twice for the one-byte encoding (end_track and
# raw_meta_event), exactly as in the previous if-chain, so both decoders run for it.
# The second one looks like it was meant to use another letter -- confirm against
# the encoder.
# NOTE(review): the encoder earlier in this file writes 'SysExF0' / 'SysExF7'
# (no underscore), which will not match the word tokens below -- confirm.
_OPUS_EVENTS = [
    ('F', 'NoteOff',            _event_spec('note_off', 4, echo=2)),
    # note_on fields were annotated as: duration, channel, note, velocity
    ('N', 'NoteOn',             _event_spec('note_on', 4, echo=2, opens_track=True)),
    ('K', 'KeyAfterTouch',      _event_spec('key_after_touch', 4)),
    # control_change: the last two fields are controller and control value
    ('C', 'ControlChange',      _event_spec('control_change', 4)),
    ('P', 'PatchChange',        _event_spec('patch_change', 3)),
    ('Z', 'ChannelAfterTouch',  _event_spec('channel_after_touch', 3)),
    ('W', 'PitchWheelChange',   _event_spec('pitch_wheel_change', 3)),
    ('T', 'TextEvent',          _event_spec('text_event', 1, text=True, echo=3, rearms=True)),
    ('R', 'CopyrightTextEvent', _event_spec('copyright_text_event', 1, text=True)),
    ('H', 'TrackName',          _event_spec('track_name', 1, text=True, rearms=True)),
    ('I', 'InstrumentName',     _event_spec('instrument_name', 1, text=True)),
    ('L', 'Lyric',              _event_spec('lyric', 1, text=True, rearms=True)),
    ('M', 'Marker',             _event_spec('marker', 1, text=True)),
    ('U', 'CuePoint',           _event_spec('cue_point', 1, text=True)),
    ('+', 'TextEvent08',        _event_spec('text_event_08', 1, text=True)),
    ('&', 'TextEvent09',        _event_spec('text_event_09', 1, text=True)),
    ('@', 'TextEvent0a',        _event_spec('text_event_0a', 1, text=True)),
    ('#', 'TextEvent0b',        _event_spec('text_event_0b', 1, text=True)),
    ('$', 'TextEvent0c',        _event_spec('text_event_0c', 1, text=True)),
    ('%', 'TextEvent0d',        _event_spec('text_event_0d', 1, text=True)),
    ('*', 'TextEvent0e',        _event_spec('text_event_0e', 1, text=True)),
    ('=', 'TextEvent0f',        _event_spec('text_event_0f', 1, text=True)),
    ('E', 'EndOfTrack',         _event_spec('end_track', 1)),
    ('S', 'SetTempo',           _event_spec('set_tempo', 2)),
    ('Y', 'SMPTEOffset',        _event_spec('smpte_offset', 6)),
    ('B', 'TimeSignature',      _event_spec('time_signature', 5)),
    ('A', 'KeySignature',       _event_spec('key_signature', 3)),
    ('D', 'SequencerSpecific',  _event_spec('sequencer_specific', 2)),
    ('E', 'RawMetaEvent',       _event_spec('raw_meta_event', 2)),
    ('G', 'SysEx_F0',           _event_spec('sysex_f0', 2)),
    ('!', 'SysEx_F7',           _event_spec('sysex_f7', 2)),
    ('J', 'SongPosition',       _event_spec('song_position', 2)),
    ('O', 'SongSelect',         _event_spec('song_select', 2)),
    # The encoder earlier in this file writes TuneRequest with a single delta-time
    # field, so one integer is read (reading two consumed the next event token and
    # the event was reported as unknown). The one-byte 'X' token is assumed to be
    # written the same way -- confirm.
    ('X', 'TuneRequest',        _event_spec('tune_request', 1)),
]

# encoding type -> {event token: [decoding specs, applied in table order]}
_EVENT_SPECS = {
    # 'note' fields: start time, duration, channel, note, velocity
    'score-one-byte-encoding': {'N': [_event_spec('note', 5, echo=2, opens_track=True)]},
    'opus-one-byte-encoding': {},
    'opus-complete-words-encoding': {},
}
for _short_token, _word_token, _spec in _OPUS_EVENTS:
    _EVENT_SPECS['opus-one-byte-encoding'].setdefault(_short_token, []).append(_spec)
    _EVENT_SPECS['opus-complete-words-encoding'].setdefault(_word_token, []).append(_spec)

# The first element is the ticks-per-quarter value; every further element is one track
# (a list of events). The first two tracks only carry descriptive track names.
song_score = [number_of_ticks_per_quarter,
              [['track_name', 0, b'Composed by Artificial Intelligence Model']],
              ]
if karaoke:
  song_score.append([['track_name', 0, b'M-T-M 3.x Karaoke Encoding']])
else:
  song_score.append([['track_name', 0, b'M-T-M 3.x Music Encoding']])

# An unrecognised encoding type decodes nothing, as before.
active_specs = _EVENT_SPECS.get(encoding_type, {})

for i in tqdm.auto.tqdm(range(len(score))):
    token = score[i]

    # skip blanks, "eos"/"unk" markers and anything starting with "@@"
    if token in ["", " ", "<eos>", "<unk>"] or token[:2] == "@@":
        continue

    # Field tokens and unknown tokens have no entry and are simply passed over.
    for event_name, n_ints, has_text, echo, opens_track, rearms in active_specs.get(token, ()):
        try:
            if opens_track and zero_marker:
                # Start a new track with a running number. The counter is really
                # incremented now; 'z++1' only evaluated z + 1, so every track used
                # to be called 'Track #2'.
                z += 1
                song_score.append([['track_name', 0, 'Track #' + str(z)]])
                zero_marker = False
            fields = [int(score[i + k][1:]) for k in range(1, n_ints + 1)]
            if has_text:
                fields.append(score[i + n_ints + 1][1:])
            song_score[-1].append([event_name] + fields)
            if rearms:
                zero_marker = True
        except (ValueError, IndexError):
            # ValueError: a field is not an integer. IndexError: the text ended in the
            # middle of an event. The slice keeps this message itself from running
            # past the end of the token list.
            print("Unknown event: " + ' '.join(score[i:i + echo]))

if encoding_type == 'score-one-byte-encoding':
  midi_data = MIDI.score2midi(song_score)
else:
  midi_data = MIDI.opus2midi(song_score)
if debug:
  print('Encoding Type: ', encoding_type)

# Write to the same absolute path that is downloaded below, so the result does not
# depend on the current working directory.
with open('/content/output.mid', 'wb') as midi_file:
    midi_file.write(midi_data)
print('Done!')

from google.colab import files
files.download('/content/output.mid')

MIDI.score2stats(song_score)

In [None]:
#@title Plot, Graph, and Listen to the Output :)
# Colab form fields. The two size values are the figure width and height in
# inches; the two pitch values bound the rows of the piano roll that
# plot_piano_roll slices out (lowest = slice start, highest = slice end), so
# lowest_displayed_pitch should be smaller than highest_displayed_pitch.
graphs_length_inches = 18 #@param {type:"slider", min:0, max:20, step:1}
notes_graph_height = 6 #@param {type:"slider", min:0, max:20, step:1}
highest_displayed_pitch = 92 #@param {type:"slider", min:1, max:128, step:1}
lowest_displayed_pitch = 24 #@param {type:"slider", min:1, max:128, step:1}

# Parse the generated MIDI file with pretty_midi for the piano-roll plot below.
midi_data = pretty_midi.PrettyMIDI('/content/output.mid')

def plot_piano_roll(pm, start_pitch, end_pitch, fs=100):
    """Draw the piano roll of `pm` for rows start_pitch .. end_pitch - 1.

    pm          -- object providing get_piano_roll(fs); in this notebook a
                   pretty_midi.PrettyMIDI instance
    start_pitch -- first piano-roll row to display; its frequency is also used
                   as the lowest frequency (fmin) of the y axis
    end_pitch   -- row index at which the displayed slice stops (exclusive)
    fs          -- sampling rate passed to get_piano_roll and to specshow (sr)

    Returns whatever librosa.display.specshow returns, so that the caller's
    `ax2 = plot_piano_roll(...)` receives the drawn object instead of None.
    """
    # Use librosa's specshow function for displaying the piano roll
    return librosa.display.specshow(pm.get_piano_roll(fs)[start_pitch:end_pitch],
                                    hop_length=1, sr=fs, x_axis='time', y_axis='cqt_note',
                                    fmin=pretty_midi.note_number_to_hz(start_pitch))



# Plot the output

# First figure: pypianoroll's multitrack plot. track.plot() returns its own
# figure, so no figure is opened beforehand (doing so only leaves an empty
# figure in the cell output); the returned figure is resized instead.
track = Multitrack('/content/output.mid', name='track')
fig, ax = track.plot()
fig.set_size_inches(graphs_length_inches, notes_graph_height)

# Second figure: piano roll restricted to the pitch range chosen in the form.
plt.figure(figsize=[graphs_length_inches, notes_graph_height])
ax2 = plot_piano_roll(midi_data, int(lowest_displayed_pitch), int(highest_displayed_pitch))
plt.show(block=False)

# Render the MIDI file to WAV with the soundfont copied during setup, then end
# the cell with the Audio object so the player is displayed.
FluidSynth("/content/font.sf2", 16000).midi_to_audio('/content/output.mid', '/content/output.wav')
Audio('/content/output.wav', rate=16000)

## Congrats! :) You did it :)

In [None]:
#@title Make a nice Arc diagram of the output to show friends and family :)
# Form checkbox read by the plotting code further down in this cell.
multi_track_input = False #@param {type:"boolean"}
# Work inside the cloned arc-diagrams repo: the relative 'output.png' used
# below resolves against this directory.
%cd /content/arc-diagrams/

# MIDI file to visualise and the title passed to plot_arc_diagram.
midi_file = '/content/output.mid'
plot_title = "Amazing GPT2 Piano Output Arc Diagram"

# Alternative input (uncomment both lines to use it instead):
# midi_file = 'midis/fuer_elise.mid'
# plot_title = "Für Elise (Beethoven)"

def stringify_notes(midi_file, track_number):
    """Encode the note messages of one MIDI track as a single string.

    Every track of the file is encoded in message order: a 'note_on' message
    contributes '<note>n', a 'note_off' message contributes '<note>f', and
    all other message types are ignored.

    midi_file    -- path handed to MidiFile
    track_number -- index of the track whose encoded string is returned

    Raises KeyError when the file has no track with that index.
    """
    suffix_by_type = {'note_on': 'n', 'note_off': 'f'}
    encoded_tracks = {}
    for index, messages in enumerate(MidiFile(midi_file).tracks):
        encoded_tracks[index] = ''.join(
            str(message.note) + suffix_by_type[message.type]
            for message in messages
            if message.type in suffix_by_type
        )
    return encoded_tracks[track_number]

from IPython.display import display

def _show_arc_diagram(track_number, label):
  """Plot the arc diagram of one track, print its label and show the image.

  track_number -- index of the MIDI track passed to stringify_notes
  label        -- line printed once the diagram has been plotted

  NOTE(review): the image is loaded from 'output.png' on the assumption that
  plot_arc_diagram saves its result there -- confirm against arc-diagrams.
  """
  plot_arc_diagram(stringify_notes(midi_file, track_number), plot_title)
  if debug:
    print('Debug mode')
  print(label)
  # A bare Image(...) expression is rendered only when it is the last
  # statement of a cell, so it has to be displayed explicitly here.
  display(Image('output.png'))

if multi_track_input:
  try:
    _show_arc_diagram(1, 'Track 1 Arc Diagram')
  except Exception:
    # Track 1 is missing or could not be plotted: fall back to track 0.
    _show_arc_diagram(0, 'Track 0 Arc Diagram')

  try:
    _show_arc_diagram(2, 'Track 2 Arc Digram')
    _show_arc_diagram(3, 'Track 3 Arc Diagram')
  except Exception:
    print('Error in processing multiple tracks. Sorry.')
else:
  # Single-track input: plot track 0 so a diagram of the current output
  # exists before it is downloaded below.
  _show_arc_diagram(0, 'Track 0 Arc Diagram')

files.download('/content/arc-diagrams/output.png')

# Save the model and/or output to Google Drive (Standard GD connect code)

In [None]:
#@title Mount Google Drive
# Mount Google Drive at /content/drive.
# NOTE(review): `drive` is not imported in this cell; presumably
# `from google.colab import drive` runs in an earlier cell -- confirm.
drive.mount('/content/drive')