# Quantum Music (ver. 1.0)

## Tokenized Sparse Time Quantization Example

***

Powered by tegridy-tools TMIDIX Optimus Processors: https://github.com/asigalov61/tegridy-tools

***

Credit for the GPT2-RGA code used in this colab goes out to @Sashmark97 https://github.com/Sashmark97/midigen and @Damon Gwinn https://github.com/gwinndr/MusicTransformer-Pytorch

***

WARNING: This complete implementation is a functioning model of Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2021

***

# (Setup Environment)

In [None]:
#@title nvidia-smi gpu check
# Runs the nvidia-smi command-line tool in a shell and shows its output in the cell.
!nvidia-smi

In [None]:
#@title Install all dependencies (run only once per session)

# Clone tegridy-tools; the import cell changes into this checkout to import TMIDIX and GPT2RGAX.
!git clone https://github.com/asigalov61/tegridy-tools
# Install the Python packages this notebook imports.
!pip install torch
!pip install tqdm
!pip install matplotlib

In [None]:
#@title Import all needed modules

# Imports the standard-library and project modules used by the notebook and
# prepares the working directories under /notebooks.

print('Loading needed modules. Please wait...')
import os
from datetime import datetime
import random  # used by later cells (file shuffling, random primer)
import secrets
import copy
import tqdm
from tqdm import tqdm  # NOTE: rebinds the name `tqdm` to the progress-bar callable

# exist_ok=True makes this safe to re-run and avoids the check-then-create race.
os.makedirs('/notebooks/Dataset', exist_ok=True)

print('Loading TMIDIX module...')
# The working directory is switched to the cloned repository before importing,
# presumably so that TMIDIX and GPT2RGAX are found on the import path.
os.chdir('/notebooks/tegridy-tools/tegridy-tools')
import TMIDIX

os.chdir('/notebooks/tegridy-tools/tegridy-tools')
# NOTE(review): later cells use names that are not imported explicitly
# (torch, nn, Adam, DataLoader, GPT, max_seq, ...); they presumably come from
# this star import - confirm against GPT2RGAX.
from GPT2RGAX import *

import matplotlib.pyplot as plt

# Return to the notebook root for the remaining cells.
os.chdir('/notebooks/')

# (FROM SCRATCH) Download and process MIDI dataset

In [None]:
#@title Download Endless Violin Carousel MIDI dataset (Recommended)

#@markdown Piano Violin Duo

#@markdown Works best stand-alone/as-is for the optimal results
# Switch into the dataset directory so the download and extraction land there.
%cd /notebooks/Dataset/

# Download the archive, extract it (unzip -j discards the archive's internal
# directory structure), then delete the archive itself.
!wget 'https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Endless-Violin-Carousel-CC-BY-NC-SA.zip'
!unzip -j '/notebooks/Dataset/Endless-Violin-Carousel-CC-BY-NC-SA.zip'
!rm '/notebooks/Dataset/Endless-Violin-Carousel-CC-BY-NC-SA.zip'

# Return to the notebook root.
%cd /notebooks/

In [None]:
#@title Process MIDIs to special MIDI dataset with Tegridy MIDI Processor

#@markdown IMPORTANT NOTES:

#@markdown 1) Best results are achieved with the single-track, single-channel, single-instrument MIDI 0 files with plain English names (avoid special or sys/foreign chars)

#@markdown 2) MIDI Channel = -1 means all MIDI channels except the drums. MIDI Channel = 16 means all channels will be processed. Otherwise, only single indicated MIDI channel will be processed

# Form parameters for the dataset processing code that follows in this cell.
# The dataset name is stamped into the TXT header; the output path is used as the
# base name for the '.txt' file and is handed to the TMIDIX pickle writer.
desired_dataset_name = "Quantum-Music-Dataset" #@param {type:"string"}
file_name_to_output_dataset_to = "/notebooks/Quantum-Music-Dataset" #@param {type:"string"}
desired_MIDI_channel_to_process = -1 #@param {type:"slider", min:-1, max:16, step:1}
# True sorts the discovered file list; False shuffles it.
sorted_or_random_file_loading_order = False #@param {type:"boolean"}
encode_velocities = True #@param {type:"boolean"}
encode_MIDI_channels = True #@param {type:"boolean"}
# A non-zero value adds a second processing pass per file with transpose_by set to this value.
add_transposed_dataset_by_this_many_pitches = 0 #@param {type:"slider", min:-12, max:12, step:1}
# True adds another pass per file with transpose_by=-12 and flip=True.
add_transposed_and_flipped_dataset = False #@param {type:"boolean"}
# The remaining values are forwarded as keyword arguments to
# TMIDIX.Optimus_MIDI_TXT_Processor; some are also written to the TXT header.
chordify_input_MIDIs = False #@param {type:"boolean"}
melody_conditioned_chords = False #@param {type:"boolean"}
melody_pitch_baseline = 60 #@param {type:"slider", min:0, max:127, step:1}
time_denominator = 1 #@param {type:"slider", min:1, max:50, step:1}
transform_to_pitch = 0 #@param {type:"slider", min:0, max:127, step:1}
perfect_timings = True #@param {type:"boolean"}
MuseNet_encoding = True #@param {type:"boolean"}
chars_encoding_offset = 0 #@param {type:"number"}

print('TMIDI Optimus MIDI Processor')
print('Starting up...')

# ---- Fresh state for one processing run ----

# Pitch statistics placeholders.
average_note_pitch, min_note, max_note = 0, 127, 0

# Counters (ints are immutable, so a chained assignment is safe here).
files_count = gfiles = chords_count = melody_count = 0

# Accumulators; every name is bound to its own, distinct empty list.
chords_list_f, melody_list_f, INTS_f, flist = [], [], [], []
chords_list, melody_chords = [], []
melody, chords = [], []

# Text buffers: the whole dataset text and the most recent per-file text.
TXT_String = TXT = ''

# Collects every file path below the dataset directory into `filez`, then
# either sorts or shuffles that list depending on the form setting.
import random  # local import: the shuffle below needs it and no earlier cell imports it explicitly

print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

dataset_addr = "/notebooks/Dataset/"
os.chdir(dataset_addr)

# os.walk recurses into sub-directories; every file found is kept (no extension filter).
filez = []
for (dirpath, dirnames, filenames) in os.walk(dataset_addr):
    filez.extend(os.path.join(dirpath, file) for file in filenames)
print('=' * 70)

if not filez:
  print('Could not find any MIDI files. Please check Dataset dir...')
  print('=' * 70)

if sorted_or_random_file_loading_order:
  print('Sorting files...')
  filez.sort()
  print('Done!')
  print('=' * 70)

else:
  random.shuffle(filez)

# Stamping the dataset info
print('Stamping the dataset info...')

# Timestamp with spaces, colons and dots all replaced by dashes.
created_on = str(datetime.now())
for separator in (' ', ':', '.'):
  created_on = created_on.replace(separator, '-')

# The legend names the encoded fields; the optional ones depend on the form switches.
legend = 'STA-DUR-PTC'
if encode_velocities:
  legend += '-VEL'
if encode_MIDI_channels:
  legend += '-CHA'

# One KEY=value line per entry, in the order they are written to the header.
header_fields = [
  ('DATASET', desired_dataset_name),
  ('CREATED_ON', created_on),
  ('CHARS_ENCODING_OFFSET', chars_encoding_offset),
  ('TIME_DENOMINATOR', time_denominator),
  ('TRANSFORM', transform_to_pitch),
  ('PERFECT_TIMINGS', perfect_timings),
  ('MUSENET_ENCODING', MuseNet_encoding),
  ('TRANSPOSED_BY', add_transposed_dataset_by_this_many_pitches),
  ('TRANSPOSED_AND_FLIPPED', add_transposed_and_flipped_dataset),
  ('LEGEND', legend),
]

TXT_String += ''.join(key + '=' + str(value) + chr(10) for key, value in header_fields)

# Runs TMIDIX.Optimus_MIDI_TXT_Processor over every discovered file and
# accumulates its text, melody, chords and INTs output. Each file gets one
# plain pass plus the optional transposed / transposed-and-flipped passes.
print('Processing MIDI files. Please wait...')

# Keyword arguments shared by every processor call.
processor_kwargs = dict(
    chordify_TXT=chordify_input_MIDIs,
    output_MIDI_channels=encode_MIDI_channels,
    char_offset=chars_encoding_offset,
    dataset_MIDI_events_time_denominator=time_denominator,
    output_velocity=encode_velocities,
    MIDI_channel=desired_MIDI_channel_to_process,
    MIDI_patch=range(0, 127),
    melody_conditioned_encoding=melody_conditioned_chords,
    melody_pitch_baseline=melody_pitch_baseline,
    perfect_timings=perfect_timings,
    musenet_encoding=MuseNet_encoding,
    transform=transform_to_pitch,
)

# Extra keyword arguments for each pass; the first (empty) entry is the plain pass.
processing_passes = [{}]
if add_transposed_dataset_by_this_many_pitches != 0:
  processing_passes.append({'transpose_by': add_transposed_dataset_by_this_many_pitches})
if add_transposed_and_flipped_dataset:
  processing_passes.append({'transpose_by': -12, 'flip': True})

for f in tqdm(filez):
  try:
    fn = os.path.basename(f)
    fn1 = fn.split('.')[0]

    files_count += 1

    for pass_number, pass_kwargs in enumerate(processing_passes):
      TXT, melody, chords, bass_melody, karaokez, INTS, aux1, aux2 = TMIDIX.Optimus_MIDI_TXT_Processor(f, **processor_kwargs, **pass_kwargs)
      TXT_String += TXT
      melody_list_f += melody
      # Always append the song's chord list as ONE entry. The flipped pass used
      # to do `+=`, which flattened its events into the list of songs and broke
      # the later cell that iterates chords_list_f song by song.
      chords_list_f.append(chords)
      INTS_f.append(INTS)
      if pass_number == 0:
        # The file is recorded once, after its plain pass succeeded.
        flist.append([f, fn1])
      gfiles += 1

  except KeyboardInterrupt:
    print('Saving current progress and quitting...')
    break

  except Exception as ex:
    # Best effort: skip files that cannot be processed, but say why.
    print('Bad MIDI:', f, '|', repr(ex))
    continue

# Finishes the TXT dataset, prints summary statistics and writes both the TXT
# file and the pickle file. Nothing is written when no events were collected.
TXT_String += 'TOTAL_SONGS_IN_DATASET=' + str(gfiles)

if not chords_list_f or not melody_list_f:
  # Previously this case surfaced as an IndexError reported as "IO Error!".
  print('=' * 70)
  print('No MIDI events were collected, so there is nothing to save.')
  print('Please check that Dataset dir is not empty/check other IO code.')
  print('=' * 70)
  print('Shutting down...')
  print('=' * 70)

else:
  try:
    print('Task complete :)')
    print('==================================================')
    if add_transposed_dataset_by_this_many_pitches != 0:
      print('NOTE: Transposed dataset was added per users request.')
      print('==================================================')
    if add_transposed_and_flipped_dataset:
      print('NOTE: Flipped dataset was added per users request.')
      print('==================================================')
    print('Number of processed dataset MIDI files:', files_count)
    print('Number of MIDI chords recorded:', len(chords_list_f))
    print('First chord event:', chords_list_f[0], 'Last chord event:', chords_list_f[-1])
    print('Number of recorded melody events:', len(melody_list_f))
    print('First melody event:', melody_list_f[0], 'Last Melody event:', melody_list_f[-1])
    print('Total number of MIDI events recorded:', len(chords_list_f) + len(melody_list_f))
    print('==================================================')

    # Writing dataset to TXT file (the with-block closes the file on exit)
    with open(file_name_to_output_dataset_to + '.txt', 'wb') as txt_file:
      txt_file.write(TXT_String.encode('utf-8', 'replace'))

    # Dataset
    MusicDataset = [chords_list_f, melody_list_f, INTS_f]

    # Writing dataset to pickle file
    TMIDIX.Tegridy_Any_Pickle_File_Writer(MusicDataset, file_name_to_output_dataset_to)

  except Exception as ex:
    # Still best effort, but the underlying error is now shown.
    print('=' * 70)
    print('IO Error!')
    print('Error details:', repr(ex))
    print('Please check that Dataset dir is not empty/check other IO code.')
    print('=' * 70)
    print('Shutting down...')
    print('=' * 70)

In [None]:
# Converts every song's chord/event list into [time delta, duration, pitch]
# triplets, framed by marker triplets: Intro (-1), Outro (-2), End (-3).
# Event fields are read by position: i[1] and i[2] are divided by 10 and used
# as start time and duration, i[4] is used as pitch.
INTS_f1 = []


for chords_list in tqdm(chords_list_f):
    if not chords_list:
        # An empty song has no first event to measure deltas from; skip it.
        continue

    INTS_f1.append([-1, -1, -1]) # Intro

    # The Outro marker goes right after the event that is 50 from the end.
    # Songs with fewer than 50 events get no Outro marker.
    outro_position = len(chords_list) - 50

    pe = chords_list[0]
    # enumerate gives the true position. list.index() returned the FIRST equal
    # event (skipping or duplicating the marker on repeats) and was O(n) per event.
    for position, i in enumerate(chords_list):

        INTS_f1.append([int(abs(i[1]-pe[1])/ 10), int(i[2] / 10), i[4] ])

        if position == outro_position:
            INTS_f1.append([-2, -2, -2]) # Outro

        pe = i
    INTS_f1.append([-3, -3, -3]) # End

In [None]:
#@title Load processed INTs datasets
number_of_batches = 16 #@param {type:"slider", min:2, max:32, step:2}
n_workers = 6

print('=' * 50)
print('Prepping INTs datasets...')

# Token layout produced below:
#   time delta : the value itself when below 16, otherwise the pair [16, value - 16]
#   pitch      : 256 + pitch
#   duration   : 512 + duration - 4
#   markers    : 765 (Intro), 766 (Outro), 767 (End)
# Triplets with any value outside 0..255 that are not markers produce no tokens.
train_data1 = []
for event in INTS_f1:
  lowest, highest = min(event), max(event)

  if lowest >= 0 and highest < 256:
      delta = event[0]
      if delta >= 16:
        train_data1 += [16, delta - 16]
      else:
        train_data1.append(delta)

      train_data1 += [256 + event[2], 512 + event[1] - 4]

  elif lowest == -1 and highest == -1: # Intro
      train_data1.append(256+512-3)

  elif lowest == -2 and highest == -2: # Outro
      train_data1.append(256+512-2)

  elif lowest == -3 and highest == -3: # End
      train_data1.append(256+512-1)

# Only the first third of the token stream is kept.
train_data = train_data1[:len(train_data1) // 3]

# Validation and test slices are both the first 3% of train_data.
holdout_size = int(len(train_data) * 0.03)
val_dataset = train_data[:holdout_size]
test_dataset = train_data[:holdout_size]

train_list, val_list, test_list = train_data, val_dataset, []
separator_line = '=' * 50
print(separator_line)

# Wrap the token lists in EPianoDataset objects.
print('Processing INTs datasets...')
train_dataset = EPianoDataset(train_list, max_seq, random_seq)
val_dataset = EPianoDataset(val_list, max_seq)
test_dataset = EPianoDataset(test_list, max_seq)
print(separator_line)

# One DataLoader per dataset; only the training loader shuffles.
print('Loading INTs datasets...')
batch_size = number_of_batches
loader_options = dict(batch_size=batch_size, num_workers=n_workers)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_options)
val_loader = DataLoader(val_dataset, **loader_options)
test_loader = DataLoader(test_dataset, **loader_options)
print(separator_line)

# Basic statistics of the training tokens.
print('Total INTs in the dataset', len(train_data))
print('Total unique INTs in the dataset', len(set(train_data)))
print('Max INT in the dataset', max(train_data))
print('Min INT in the dataset', min(train_data))
print(separator_line)

print('Checking datasets shapes...')
print(separator_line)

# Print the shapes of the first batch (if any) of each loader.
labelled_loaders = (
    ('Train loader', train_loader),
    ('Validation loader', val_loader),
    ('Test loader', test_loader),
)
for loader_label, loader in labelled_loaders:
    print(loader_label)
    for x, tgt in loader:
        print(f'X shape: {x.shape}')
        print(f'Target shape: {tgt.shape}')
        break
    print(separator_line)

print('Done! Enjoy! :)')
print(separator_line)

# Test the resulting INTs dataset...

In [None]:
# Bare expression: the notebook displays the training token list as the cell output.
train_data

In [None]:
# Sanity check: decode the first 10000 training tokens back into notes and
# write them out as a MIDI file through TMIDIX.
out = train_data[:10000]
if len(out) != 0:
  song = out
  song_f = []
  time = 0
  pitch = 0
  duration = 0
  for s in song:
    # Token ranges: [0, 256) time delta, [256, 512) pitch, [512, 764) duration.
    # The time range used to be `s <= 256`, which also counted token 256
    # (pitch 0) as a time delta; the generator cells use `s < 256`.
    if 0 <= s < 256:
        time += s
    elif 256 <= s < 512:
        pitch = s-256
    elif 512 <= s < 256+512-4:
        # NOTE(review): the tokeniser stores durations as 512 + duration - 4, but
        # the 4 is not added back here - confirm whether that is intended.
        duration = s-512
        # A duration token completes a note; pitch is reused in the last slot.
        song_f.append(['note', time*10, (duration*10), 0, pitch, pitch ])

  detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Quantum Music',
                                                        output_file_name = '/notebooks/Quantum-Music-Composition',
                                                        track_name='Project Los Angeles',
                                                        number_of_ticks_per_quarter=500)

  print('Done!')


# (TRAIN)

# Train the model

In [None]:
#@title Train

# ---- Model ----
config = GPTConfig(
    VOCAB_SIZE,
    max_seq,
    dim_feedforward=dim_feedforward,
    n_layer=6,
    n_head=8,
    n_embd=512,
    enable_rpr=True,
    er_len=max_seq,
)
model = GPT(config).to(get_device())

# ---- Loss, optimiser and learning-rate schedule ----
init_step = 0
lr = LR_DEFAULT_START
lr_stepper = LrStepTracker(d_model, SCHEDULER_WARMUP_STEPS, init_step)
eval_loss_func = nn.CrossEntropyLoss(ignore_index=TOKEN_PAD)
train_loss_func = eval_loss_func  # the same loss object serves training and evaluation

opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)
lr_scheduler = LambdaLR(opt, lr_stepper.step)

# ---- Best-so-far bookkeeping ----
best_eval_acc, best_eval_acc_epoch = 0.0, -1
best_eval_loss, best_eval_loss_epoch = float("inf"), -1
best_acc_file = '/notebooks/gpt2_rpr_acc.pth'
best_loss_file = '/notebooks/gpt2_rpr_loss.pth'
loss_train, loss_val, acc_val = [], [], []

# ---- Training loop: train, evaluate, checkpoint on improvement ----
for epoch in range(epochs):
    new_best = False

    loss = train(epoch + 1, model, train_loader, train_loss_func, opt, lr_scheduler, num_iters=-1)
    loss_train.append(loss)

    eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)
    loss_val.append(eval_loss)
    acc_val.append(eval_acc)

    # A separate checkpoint is kept for the best accuracy ...
    if eval_acc > best_eval_acc:
        best_eval_acc, best_eval_acc_epoch = eval_acc, epoch + 1
        torch.save(model.state_dict(), best_acc_file)
        new_best = True

    # ... and for the best (lowest) loss.
    if eval_loss < best_eval_loss:
        best_eval_loss, best_eval_loss_epoch = eval_loss, epoch + 1
        torch.save(model.state_dict(), best_loss_file)
        new_best = True

    if new_best:
        print("Best eval acc epoch:", best_eval_acc_epoch)
        print("Best eval acc:", best_eval_acc)
        print("")
        print("Best eval loss epoch:", best_eval_loss_epoch)
        print("Best eval loss:", best_eval_loss)

In [None]:
# Re-run evaluation of the current model on the validation loader
# (same call as inside the training loop).
eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)

In [None]:
# Bare expression: the notebook displays the training token list as the cell output.
train_data

In [None]:
#@title Plot resulting training loss graph

# Flatten the per-epoch loss sequences into one continuous series.
tr_loss_list = []
for epoch_losses in loss_train:
    tr_loss_list.extend(epoch_losses)

# Blue line over the step index, saved as a PNG.
step_axis = list(range(len(tr_loss_list)))
plt.plot(step_axis, tr_loss_list, 'b')
plt.savefig('/notebooks/training-loss.png')

# (SAVE/LOAD)

In [None]:
#@title Save the model

# Writes the model's state_dict (not the whole model object) to the chosen path.
print('Saving the model...')
full_path_to_model_checkpoint = "/notebooks/Quantum-Music-Trained-Model-6.pth" #@param {type:"string"}
torch.save(model.state_dict(), full_path_to_model_checkpoint)
print('Done!')

In [None]:
#@title Load/Reload the model
full_path_to_model_checkpoint = "/notebooks/Quantum-Music-Trained-Model-6.pth" #@param {type:"string"}

# Rebuilds the model with the same layer settings as the training cell and
# loads the saved state_dict into it.
print('Loading the model...')
# NOTE(review): the vocabulary size here is the literal 256+512+2 while the
# training cell passes VOCAB_SIZE; load_state_dict needs matching shapes -
# confirm that the two agree.
config = GPTConfig(256+512+2,
                   max_seq,
                   dim_feedforward=dim_feedforward,
                   n_layer=6,
                   n_head=8,
                   n_embd=512,
                   enable_rpr=True,
                   er_len=max_seq)

model = GPT(config).to(get_device())

# map_location remaps the stored tensors onto the current device, so a checkpoint
# saved on a GPU also loads in a CPU-only session.
# Assumes get_device() returns a torch.device or device string - confirm.
model.load_state_dict(torch.load(full_path_to_model_checkpoint, map_location=get_device()))
print('Done!')

# Custom MIDI option

In [None]:
# Encodes a user-supplied MIDI file into the token list `inputs`, which the
# generator cell can use as its primer.
data = TMIDIX.Optimus_MIDI_TXT_Processor('/notebooks/seed97-super.mid',
                                         dataset_MIDI_events_time_denominator=10,
                                         perfect_timings=True,
                                         musenet_encoding=True,
                                         char_offset=0,
                                         MIDI_channel=-1,
                                         MIDI_patch=range(0, 127)
                                        )

# Element 5 of the processor result is the one unpacked as INTS in the dataset cell.
SONG = data[5]
inputs = []
for i in SONG:
    # Keep only events whose values all fit the 0..255 token range. The lower
    # bound used to test max(i) again, so negative values slipped through.
    if max(i) < 256 and min(i) >= 0:
        # Time values of 16 or more are split into the pair [16, remainder].
        if i[0] < 16:
            inputs.extend([i[0]])
        else:
            inputs.extend([16, i[0]-16])

        # NOTE(review): the training tokeniser emits 512 + duration - 4 for
        # durations, while no 4 is subtracted here - confirm which is intended.
        inputs.extend([256+i[3], 512+i[1] ])

# (GENERATE MUSIC)

In [None]:
#@title Generate and download a MIDI file

# Settings for the generation code that follows in this cell.
# Passed to model.generate as target_seq_length.
number_of_tokens_to_generate = 1024 #@param {type:"slider", min:8, max:1024, step:8}
# True primes the model with randomly drawn token values instead of real data.
use_random_primer = False #@param {type:"boolean"}
# When True (and use_random_primer is False) the primer is the last 512 tokens of
# `inputs`; otherwise it is a slice of train_data starting at a random index.
start_with_zero_token = False #@param {type:"boolean"}
number_of_ticks_per_quarter = 500 #@param {type:"slider", min:50, max:1000, step:50}
# NOTE(review): the next settings, up to and including chars_encoding_offset_used_for_dataset,
# do not appear to be read anywhere else in the visible notebook code - confirm before relying on them.
dataset_time_denominator = 10
melody_conditioned_encoding = False
encoding_has_MIDI_channels = False 
encoding_has_velocities = False
simulate_velocity = True #@param {type:"boolean"}
save_only_first_composition = True
chars_encoding_offset_used_for_dataset = 33

# Output file base name (without extension).
fname = '/notebooks/Quantum-Music-Composition'

print('Quantum Music Model Generator')

output_signature = 'Quantum Music'
song_name = 'RGA Composition'

# Generates a token sequence into `out`. `sequence` always holds exactly the
# primer that was fed to the model, so later code can tell primer from continuation.
import random  # local import: the random-primer branch needs it and no earlier cell imports it explicitly

model.eval()

if use_random_primer:
  sequence = [random.randint(10, 387) for i in range(64)]
  idx = secrets.randbelow(len(sequence))
  # Keep only the slice that is actually fed to the model, so len(sequence)
  # equals the primer length.
  sequence = sequence[idx:idx+120]
  rand_seq = model.generate(torch.Tensor(sequence), target_seq_length=number_of_tokens_to_generate)
  out = rand_seq[0].cpu().numpy().tolist()

else:
  out = []

  try:
    if start_with_zero_token:
      # Prime with the tail of the custom MIDI tokens.
      sequence = inputs[-512:]
    else:
      # Prime with a slice of the training tokens starting at a random index.
      idx = secrets.randbelow(len(train_data))
      sequence = train_data[idx:idx+512]

    rand_seq = model.generate(torch.Tensor(sequence), target_seq_length=number_of_tokens_to_generate, stop_token=256+512)
    out = rand_seq[0].cpu().numpy().tolist()

  except Exception as ex:
    # `out` stays empty so the following code reports an empty result;
    # the underlying error is now shown instead of being hidden.
    print('=' * 50)
    print('Error! Try random priming instead!')
    print('Error details:', repr(ex))
    print('Shutting down...')
    print('=' * 50)

# Decodes the generated tokens in `out` into note events, marks where the
# model's continuation starts, and writes the result as a MIDI file.
if len(out) != 0:
  song = out
  song_f = []
  time = 0
  pitch = 0
  duration = 0
  once = True
  primer_length = len(sequence)
  for position, s in enumerate(song):
    # Mark the first token after the primer. The position comes from enumerate:
    # song.index(s) returned the first occurrence of the token VALUE, which put
    # the marker in the wrong place for any repeated token.
    if once and position >= primer_length:
        song_f.append(['text_event', abs(time) * 10, 'Continuation Start Here'])
        once = False

    # Token ranges: [0, 256) time delta, [256, 512) pitch, [512, 764) duration.
    if s >= 0 and s < 256:
        time += s
    if s >= 256 and s < 512:
        pitch = s-256
    if s >= 512 and s < 256+512-4:
        duration = s-512
        # A duration token completes a note; pitch is reused in the last slot.
        song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])

  # Use the values chosen at the top of this cell rather than repeating
  # literals, so the ticks-per-quarter slider actually has an effect.
  detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = output_signature,
                                                        output_file_name = fname,
                                                        track_name='Project Los Angeles',
                                                        number_of_ticks_per_quarter=number_of_ticks_per_quarter)

  print('Done!')


  #print('Downloading your composition now...')
  #from google.colab import files
  #files.download(fname + '.mid')

  print('=' * 70)
  print('Detailed MIDI stats:')
  for key, value in detailed_stats.items():
        print('=' * 70)
        print(key, '|', value)

  print('=' * 70)

else:
  print('Models output is empty! Check the code...')
  print('Shutting down...')

In [None]:
# Bare expression: the notebook displays the number of tokens in `out`.
len(out)

In [None]:
# Bare expression: the notebook displays the last 64 tokens of `out`.
out[-64:]

In [None]:
#@title Auto-Regressive Generator

#@markdown NOTE: You must generate a seed composition first or it is not going to start

# Repeatedly primes the model with the tail of the previous output and keeps
# only the newly generated tokens of each cycle. All blocks are then decoded
# into one MIDI file.

number_of_cycles_to_run = 5 #@param {type:"slider", min:1, max:50, step:1}
number_of_prime_tokens = 128 #@param {type:"slider", min:64, max:256, step:64}

print('=' * 70)
print('Quantum Music Auto-Regressive Model Generator')
print('=' * 70)
print('Starting up...')
print('=' * 70)
print('Prime length:', len(out))
print('Prime tokens:', number_of_prime_tokens)
print('Prime input sequence', out[-8:])

if len(out) != 0:
  print('=' * 70)
  out_all = [out]
  blocks_done = 0
  total_tokens = len(out)
  for i in tqdm(range(number_of_cycles_to_run)):
      primer = out[-number_of_prime_tokens:]
      rand_seq1 = model.generate(torch.Tensor(primer), target_seq_length=1024, stop_token=256+512)
      out1 = rand_seq1[0].cpu().numpy().tolist()

      # Strip exactly the primer that was fed in. It can be shorter than
      # number_of_prime_tokens, in which case slicing by that number would
      # discard generated tokens.
      continuation = out1[len(primer):]
      if not continuation:
          # Nothing new was generated, so there is nothing to prime the next cycle with.
          print('Model produced no new tokens. Stopping early.')
          break

      out_all.append(continuation)
      out = continuation
      blocks_done += 1
      total_tokens += len(continuation)

      print(chr(10))
      print('=' * 70)
      print('Block number:', blocks_done)
      print('Composition length so far:', total_tokens, 'tokens')
      print('=' * 70)

  print('Done!')
  print('=' * 70)
  print('Total blocks:', blocks_done)
  print('Final composition length:', total_tokens, 'tokens')
  print('=' * 70)

  out2 = []
  for o in out_all:
    out2.extend(o)

  if len(out2) != 0:
      song = out2
      song_f = []
      time = 0
      pitch = 0
      duration = 0
      for s in song:
        # Token ranges: [0, 256) time delta, [256, 512) pitch, [512, 764) duration.
        if s >= 0 and s < 256:
            time += s
        if s >= 256 and s < 512:
            pitch = s-256
        if s >= 512 and s < 256+512-4:
            duration = s-512
            # A duration token completes a note; pitch is reused in the last slot.
            song_f.append(['note', (abs(time))*10, (duration*10), 0, pitch, pitch ])

      detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                            output_signature = 'Quantum Music',
                                                            output_file_name = '/notebooks/Quantum-Music-Composition',
                                                            track_name='Project Los Angeles',
                                                            number_of_ticks_per_quarter=500)

      print('Done!')

else:
  print('=' * 70)
  print('INPUT ERROR !!!')
  print('Prime sequence is empty...')
  print('Please generate prime sequence and retry')

print('=' * 70)

# Congrats! You did it! :)