# Mini Muse Maker (ver. 1.0)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

Credit for the GPT2-RGA code used in this Colab goes out to @Sashmark97 (https://github.com/Sashmark97/midigen) and @Damon Gwinn (https://github.com/gwinndr/MusicTransformer-Pytorch)

***

WARNING: This complete implementation is a functioning model of Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2022

***

# (Setup Environment)

In [None]:
#@title nvidia-smi gpu check
# Show the attached NVIDIA GPU(s) and driver status via an IPython shell escape.
!nvidia-smi

In [None]:
#@title Install all dependencies (run only once per session)

# Clone the Mini-Muse repo; the next cell imports TMIDIX and GPT2RGAX from this checkout.
!git clone https://github.com/asigalov61/Mini-Muse
# Python packages used by the notebook: torch (model), tqdm (progress bars), matplotlib (imported as plt).
!pip install torch
!pip install tqdm
!pip install matplotlib

In [None]:
#@title Import all needed modules

print('Loading needed modules. Please wait...')
import os
from tqdm import tqdm
import random
import secrets

# Directory that the dataset download and processing cells read from / write into.
if not os.path.exists('/content/Dataset'):
    os.makedirs('/content/Dataset')

print('Loading TMIDIX and GPT2RGAX modules...')
# TMIDIX and GPT2RGAX are imported from the cloned repo, so chdir there first.
os.chdir('/content/Mini-Muse')
import TMIDIX
# Star import: later cells use names such as torch, nn, Dataset, DataLoader, GPT,
# GPTConfig, max_seq, train and eval_model without importing them here --
# presumably all supplied by this module (verify in GPT2RGAX.py).
from GPT2RGAX import *

import matplotlib.pyplot as plt

# Restore the working directory for the remaining cells.
os.chdir('/content/')

# (FROM SCRATCH) Download and process MIDI dataset

In [None]:
#@title Download original LAKH MIDI Dataset (Recommended)

# Download the full Lakh MIDI Dataset archive into /content/Dataset,
# extract it in place, then delete the archive to free disk space.
%cd /content/Dataset/

!wget 'http://hog.ee.columbia.edu/craffel/lmd/lmd_full.tar.gz'
!tar -xvf 'lmd_full.tar.gz'
!rm 'lmd_full.tar.gz'

%cd /content/

# (PROCESS)

In [None]:
#@title Process MIDIs with TMIDIX MIDI processor

sorted_or_random_file_loading_order = False # Sorted order is NOT usually recommended
dataset_ratio = 1 # Change this if you need more data

# Walk /content/Dataset, load each MIDI with TMIDIX and keep the middle 256
# notes of every usable file as one training excerpt. Each note becomes
# [chan_vel, time + 128, dur + 256, ptc + 384] with
# chan_vel = channel * 10 + velocity // 16, so the four tokens fall in
# 0-127, 128-255, 256-383 and 384-511 respectively. Excerpts are collected in
# melody_chords_f and pickled at the end of the cell.

print('TMIDIX MIDI Processor')
print('Starting up...')
###########

files_count = 0

gfiles = []

melody_chords_f = []

# Instrument groups. A note's current MIDI patch is mapped to the index of the
# group containing it, and that index becomes the note's output channel (0-10).
# Group 9 ([-1], "fake drums") matches no real patch: drum notes simply keep
# MIDI channel 9. Patches outside every group fall back to channel 0.
patch_map = [[0, 1, 2, 3, 4, 5, 6, 7], # Piano
             [24, 25, 26, 27, 28, 29, 30], # Guitar
             [32, 33, 34, 35, 36, 37, 38, 39], # Bass
             [40, 41], # Violin
             [42, 43], # Cello
             [46], # Harp
             [56, 57, 58, 59, 60], # Trumpet
             [71, 72], # Clarinet
             [73, 74, 75], # Flute
             [-1], # Fake Drums
             [52, 53] # Choir
            ]

# patch -> group index. The groups are disjoint, so one dict lookup replaces
# scanning every group for every note; built once instead of once per file.
patch_to_group = {patch: group
                  for group, group_patches in enumerate(patch_map)
                  for patch in group_patches}

###########

print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

dataset_addr = "/content/Dataset"
filez = list()
for (dirpath, dirnames, filenames) in os.walk(dataset_addr):
    filez += [os.path.join(dirpath, file) for file in filenames]
print('=' * 70)

if not filez:
    print('Could not find any MIDI files. Please check Dataset dir...')
    print('=' * 70)

if sorted_or_random_file_loading_order:
    print('Sorting files...')
    filez.sort()
    print('Done!')
    print('=' * 70)
else:
    print('Randomizing file list...')
    random.shuffle(filez)
    print('Done!')
    print('=' * 70)

# Note counts per output channel: `stats` over every kept note,
# `middles_stats` over the notes of the saved excerpts only.
stats = [0] * 16
middles_stats = [0] * 16

print('Processing MIDI files. Please wait...')
for f in tqdm(filez[:int(len(filez) * dataset_ratio)]):
    try:
        # Context manager so the handle is closed even if parsing fails.
        with open(f, 'rb') as midi_file:
            score = TMIDIX.midi2ms_score(midi_file.read())

        # score[0] is skipped (presumably the ticks-per-beat header rather than
        # a track -- confirm against TMIDIX). Keep only notes and patch changes.
        events_matrix = [event
                         for track in score[1:]
                         for event in track
                         if event[0] == 'note' or event[0] == 'patch_change']

        # Files with 512 or fewer such events are skipped.
        if len(events_matrix) > 512:

            events_matrix.sort(key=lambda x: x[1])

            # Current patch per MIDI channel, updated while replaying events in time order.
            patches = [0] * 16

            events_matrix1 = []
            for event in events_matrix:
                if event[0] == 'patch_change':
                    patches[event[2]] = event[3]

                if event[0] == 'note':
                    # event[6] := the note's patch; then remap its channel (event[3]).
                    event.append(patches[event[3]])

                    if event[3] != 9: # Except the drums
                        group = patch_to_group.get(event[6])
                        if group is None:
                            event[3] = 0 # All other instruments/patches channel
                            event[5] = max(80, event[5])
                        else:
                            event[3] = group

                    if event[3] < 11: # We won't write chans 11-16 for now...
                        events_matrix1.append(event)
                        stats[event[3]] += 1

            # Sorting by start time, then by output channel...
            events_matrix1.sort(key=lambda x: (x[1], x[3]))

            # Recalculating timings: start times in 16-tick steps, durations in 32-tick steps.
            for e in events_matrix1:
                e[1] = int(e[1] / 16)
                e[2] = int(e[2] / 32)

            # Final processing: keep the middle 256 notes as one excerpt.
            if len(events_matrix1) > 512:
                melody_chords = []

                middle = len(events_matrix1) // 2
                excerpt = events_matrix1[middle - 128:middle + 128]

                pe = excerpt[0] # previous note; the first note gets a zero time delta
                for e in excerpt:

                    time = max(0, min(127, e[1] - pe[1])) # delta start time
                    dur = max(1, min(127, e[2]))
                    cha = max(0, min(10, e[3]))
                    ptc = max(1, min(127, e[4]))
                    vel = max(16, min(127, e[5]))

                    div_vel = int(vel / 16) # 1..7

                    chan_vel = (cha * 10) + div_vel

                    melody_chords.append([chan_vel, time+128, dur+256, ptc+384])

                    middles_stats[cha] += 1

                    pe = e

                melody_chords_f.append(melody_chords)

                files_count += 1

    except KeyboardInterrupt:
        print('Saving current progress and quitting...')
        break

    except Exception:
        # Unreadable or malformed MIDI: report it and move on to the next file.
        print('Bad MIDI:', f)
        continue

print('=' * 70)

print('Done!')
print('=' * 70)

print('Resulting Stats:')
print('=' * 70)
print('Total MIDI Excerpts:', files_count)
print('=' * 70)

# Output-channel order matches the patch_map groups (channel 9 = drums).
for instrument, count in zip(['Piano', 'Guitar', 'Bass', 'Violin', 'Cello', 'Harp',
                              'Trumpet', 'Clarinet', 'Flute', 'Drums', 'Choir'],
                             middles_stats):
    print(instrument + ':', count)

print('=' * 70)

TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/Mini_Muse_Processed_MIDIs')

# (PREP INTs)

In [None]:
#@title Process and prep INTs...


randomize_dataset = True

# Flatten the excerpts into one INTs list: each excerpt contributes a 0
# intro token followed by its 256 notes x 4 tokens (1025 INTs per excerpt).

print('=' * 70)
print('Prepping INTs dataset...')

if randomize_dataset:
    print('=' * 70)
    print('Randomizing the dataset...')
    random.shuffle(melody_chords_f)
    print('Done!')

print('=' * 70)
print('Processing the dataset...')

train_data1 = []

for m in tqdm(melody_chords_f):
    if len(m) != 256:
        # The processing cell always emits 256-note excerpts; skip anything else.
        print('Error')
        continue
    train_data1.append(0)
    for mm in m:
        train_data1.extend(mm)

print('Done!')
print('=' * 70)

print('Total INTs:', len(train_data1))
# min()/max() raise ValueError on an empty list, so only report them when there is data.
if train_data1:
    print('Minimum INT:', min(train_data1))
    print('Maximum INT:', max(train_data1))
    print('Unique INTs:', len(set(train_data1)))
    print('Intro/Zero INTs:', train_data1.count(0))
else:
    print('No INTs were produced. Please process the MIDI dataset first...')
print('=' * 70)

In [None]:
#@title Save INTs
# Pickle the flat INTs list via TMIDIX's pickle writer (the exact output file
# name/extension is decided inside TMIDIX -- not visible here).
TMIDIX.Tegridy_Any_Pickle_File_Writer(train_data1, '/content/Mini_Muse_INTs')

# Test the resulting INTs dataset...

In [None]:
#@title Test INTs
print('Sample INTs', train_data1[:15])

# Decode the first 16000 INTs back into note events and render them to MIDI
# to sanity-check the encoding produced by the processing cell.
out = train_data1[:16000]

if out:

    song = out
    song_f = []
    time = 0
    dur = 0
    vel = 0
    pitch = 0
    channel = 0

    son = []

    song1 = []

    # Split the stream into notes: a token <= 127 (chan_vel, or the 0 intro
    # token) starts a new group and the three following tokens > 127 complete
    # it. Only complete 4-token groups are kept.
    for s in song:
        if s > 127:
            son.append(s)
        else:
            if len(son) == 4:
                song1.append(son)
            son = [s]

    # The last group is never followed by a token <= 127, so flush it here
    # (previously the final note was always dropped).
    if len(son) == 4:
        song1.append(son)

    for s in song1:
        channel = s[0] // 10

        vel = (s[0] % 10) * 16

        time += (s[1] - 128) * 16 # delta time, back to absolute ticks

        dur = (s[2] - 256) * 32

        pitch = s[3] - 384

        song_f.append(['note', time, dur, channel, pitch, vel])

    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Mini Muse',
                                                        output_file_name = '/content/Mini-Muse-Music-Composition',
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 0, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')

# (LOAD INTs)

In [None]:
#@title Load processed INTs dataset

# Length of each training window. max_seq is not defined in this notebook --
# presumably provided by `from GPT2RGAX import *`.
SEQ_LEN = max_seq

BATCH_SIZE = 4 # Change this to your specs

# DO NOT FORGET TO ADJUST MODEL PARAMS IN GPT2RGAX module to your specs

print('=' * 50)
print('Loading training data...')

# NOTE(review): train and validation tensors are built from the same INTs (no
# held-out split), so validation metrics measure fit on the training data.
data_train, data_val = torch.LongTensor(train_data1), torch.LongTensor(train_data1)

class MusicSamplerDataset(Dataset):
    """Random fixed-length training windows drawn from one flat tensor of INTs.

    The tensor is viewed as consecutive chunks of ``seq_len + 1`` tokens.
    Each ``__getitem__`` call ignores ``index`` and returns a uniformly random
    chunk as ``(x, trg)``: ``x`` holds the chunk's first ``seq_len`` tokens and
    ``trg`` the same span shifted one token ahead (next-token targets).
    If ``seq_len + 1`` equals the excerpt length built by the prep cell
    (1 + 256 * 4 = 1025 INTs), every chunk is exactly one excerpt --
    presumably why starts are aligned; confirm against max_seq.

    Args:
        data: token-id tensor, indexed along dim 0 (assumed 1-D -- TODO confirm).
        seq_len: number of tokens in each returned ``x`` / ``trg``.

    Raises:
        ValueError: if ``data`` holds fewer than ``seq_len + 1`` tokens.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        if self.data.size(0) < self.seq_len + 1:
            raise ValueError(
                f'MusicSamplerDataset needs at least seq_len + 1 = {self.seq_len + 1} '
                f'INTs, got {self.data.size(0)}'
            )

    def __getitem__(self, index):
        # Number of complete (seq_len + 1)-token chunks. Sampling from all of
        # them keeps the last full chunk reachable (the previous formula,
        # (N - (L + 1)) // (L + 1), never picked it and crashed with one chunk).
        num_windows = self.data.size(0) // (self.seq_len + 1)
        rand = secrets.randbelow(num_windows) * (self.seq_len + 1)

        x = self.data[rand: rand + self.seq_len].long()
        trg = self.data[(rand + 1): (rand + 1) + self.seq_len].long()

        return x, trg

    def __len__(self):
        # Epoch length is measured in tokens (one random sample per token), not in chunks.
        return self.data.size(0)

# Both datasets wrap the full INTs tensor. Samples are drawn at random inside
# MusicSamplerDataset.__getitem__ (index is ignored), so the loaders are not shuffled.
train_dataset = MusicSamplerDataset(data_train, SEQ_LEN)
val_dataset   = MusicSamplerDataset(data_val, SEQ_LEN)
train_loader  = DataLoader(train_dataset, batch_size = BATCH_SIZE)
val_loader    = DataLoader(val_dataset, batch_size = BATCH_SIZE)
print('=' * 50)
print('Total INTs in the dataset', len(train_data1))
print('Total unique INTs in the dataset', len(set(train_data1)))
print('Max INT in the dataset', max(train_data1))
print('Min INT in the dataset', min(train_data1))
print('=' * 50)
print('Length of the dataset:',len(train_dataset))
# NOTE(review): this estimate is INTs // max_seq // BATCH_SIZE, but
# len(train_dataset) is the INT count itself, so the loader yields about
# len(train_data1) / BATCH_SIZE batches per epoch ("Train loader length"
# below) -- confirm which epoch size is intended.
print('Number of batched samples per epoch:', len(train_data1) // max_seq // BATCH_SIZE)
print('=' * 50)
print('Sample train dataset:', train_dataset[0])
print('Sample val dataset:', val_dataset[0])
print('=' * 50)
print('Train loader length:', len(train_loader))
print('Val loader length:', len(val_loader))
print('=' * 50)
print('Done! Enjoy! :)')
print('=' * 50)

# (TRAIN)

# Train the model

In [None]:
#@title Train

# Vocabulary size: the prep cell's INTs span 0..511
# (chan_vel < 128, time + 128, dur + 256, pitch + 384).
DIC_SIZE = 512

# DO NOT FORGET TO ADJUST MODEL PARAMS IN GPT2RGAX module to your specs

config = GPTConfig(DIC_SIZE, 
                   max_seq,
                   dim_feedforward=1024,
                   n_layer=16, 
                   n_head=16, 
                   n_embd=1024,
                   enable_rpr=True,
                   er_len=max_seq)

# DO NOT FORGET TO ADJUST MODEL PARAMS IN GPT2RGAX module to your specs

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = GPT(config)

# Multi-GPU wrapper. The wrapped model lives in model.module, so state_dict()
# keys saved later carry a 'module.' prefix.
model = nn.DataParallel(model)

model.to(device)

#=====

init_step = 0
lr = LR_DEFAULT_START
# NOTE(review): the warmup schedule uses d_model (from GPT2RGAX), not this
# config's n_embd=1024 -- confirm the two agree.
lr_stepper = LrStepTracker(d_model, SCHEDULER_WARMUP_STEPS, init_step)
# ignore_index=DIC_SIZE is one past the largest INT (511), so no target is ignored.
eval_loss_func = nn.CrossEntropyLoss(ignore_index=DIC_SIZE)
train_loss_func = eval_loss_func

opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)
# The per-step LR factor comes from lr_stepper.step.
lr_scheduler = LambdaLR(opt, lr_stepper.step)


#===

# Best validation metrics seen so far and the (1-based) epoch each was
# reached in; -1 means "not reached yet".
best_eval_acc = 0.0
best_eval_acc_epoch = -1
best_eval_loss = float("inf")
best_eval_loss_epoch = -1
best_acc_file = '/content/gpt2_rpr_acc.pth'
best_loss_file = '/content/gpt2_rpr_loss.pth'
# Per-epoch history: training loss, validation loss, validation accuracy.
loss_train = []
loss_val = []
acc_val = []

for epoch in range(epochs):
    epoch_number = epoch + 1
    new_best = False

    loss = train(
        epoch_number,
        model,
        train_loader,
        train_loss_func,
        opt,
        lr_scheduler,
        num_iters=-1,
        save_checkpoint_steps=4000,
    )
    loss_train.append(loss)

    eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)
    loss_val.append(eval_loss)
    acc_val.append(eval_acc)

    # Checkpoint independently on best accuracy and on best loss.
    if eval_acc > best_eval_acc:
        best_eval_acc = eval_acc
        best_eval_acc_epoch = epoch_number
        torch.save(model.state_dict(), best_acc_file)
        new_best = True

    if eval_loss < best_eval_loss:
        best_eval_loss = eval_loss
        best_eval_loss_epoch = epoch_number
        torch.save(model.state_dict(), best_loss_file)
        new_best = True

    if new_best:
        print("Best eval acc epoch:", best_eval_acc_epoch)
        print("Best eval acc:", best_eval_acc)
        print("")
        print("Best eval loss epoch:", best_eval_loss_epoch)
        print("Best eval loss:", best_eval_loss)

In [None]:
#@title Eval funct to eval separately if needed

# Evaluate the current model on the validation loader. Only the loss function
# is needed here: the training cell's optimizer and LR scheduler (opt,
# lr_scheduler, lr_stepper) are intentionally not rebuilt, so running this
# cell does not reset their state.

#=====

eval_loss_func = nn.CrossEntropyLoss(ignore_index=DIC_SIZE)

eval_loss, eval_acc = eval_model(model, val_loader, eval_loss_func, num_iters=-1)

print('Eval loss:', eval_loss)
print('Eval acc:', eval_acc)

# (SAVE)

In [None]:
#@title Save the model

print('Saving the model...')
full_path_to_model_checkpoint = "/content/Mini-Muse-Trained-Model.pth" #@param {type:"string"}
# model is the nn.DataParallel wrapper created in the training cell, so the
# saved keys carry a 'module.' prefix; strip it (or load into a DataParallel
# model) when restoring into a bare GPT.
torch.save(model.state_dict(), full_path_to_model_checkpoint)
print('Done!')

# Congrats! You did it! :)