# Experimental Music Transformer Version 2 (ver. 0.1)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

WARNING: This complete implementation is a functioning model of the Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2023

***

# (SETUP ENVIRONMENT)

In [None]:
# Ask the NVIDIA driver which GPU(s) this runtime can see before doing any work.
!nvidia-smi

In [None]:
#@title Install all dependencies (run only once per session)

# Clone tegridy-tools into the working directory; later cells chdir into the
# clone before importing TMIDIX and x_transformers.
!git clone https://github.com/asigalov61/tegridy-tools

In [None]:
#@title Import all needed modules

print('Loading needed modules. Please wait...')
import os

import math
import statistics
import random

from collections import Counter

# NOTE: this binds the name `tqdm` to the progress-bar class itself, not to the module.
from tqdm import tqdm

# Make sure the dataset download directory exists.
if not os.path.exists('/content/Dataset'):
    os.makedirs('/content/Dataset')

print('Loading TMIDIX module...')
# TMIDIX is imported from inside the cloned repo, so switch into it first
# (presumably it is not installed as a package — TODO confirm).
os.chdir('/content/tegridy-tools/tegridy-tools')

import TMIDIX

from joblib import Parallel, delayed, parallel_config

print('Done!')

# Return to the notebook root for the remaining cells.
os.chdir('/content/')
print('Enjoy! :)')

# (DOWNLOAD AND UNZIP DATASETS)

In [None]:
# @title MIDI Dataset
%cd /content/Dataset
!wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Mono-Melodies/Piano-Violin/Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.001
!wget https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/master/Mono-Melodies/Piano-Violin/Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.002
!cat Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip* > Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip
!unzip Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip
!rm Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.001
!rm Mono-Melodies-Piano-Violin-CC-BY-NC-SA.zip.002
%cd /content/

# (FILE LIST)

# (PROCESS)

In [None]:
#@title TMIDIX MIDI Processor

print('=' * 70)
print('Loading TMIDIX MIDI Processor...')
print('=' * 70)

def TMIDIX_MIDI_Processor(midi_file):
    """Convert one MIDI file into a list of ``[time, dur, cha, ptc]`` notes.

    Args:
        midi_file: Path of the MIDI file to process.

    Returns:
        A list of ``[time, dur, cha, ptc]`` lists ordered by start time, where
        ``time`` is the start-time delta to the previous note and ``dur`` the
        note duration (both are score times divided by 16 and clamped to
        0..127), ``cha`` is the MIDI channel clamped to 0..15 with channel 3
        remapped to 1, and ``ptc`` is the pitch clamped to 1..127.

        ``None`` when the file is larger than 1,000,000 bytes, contains no
        notes, contains a negative start time or duration, or cannot be
        read/parsed at all (processing is best-effort: errors are swallowed so
        one bad file never aborts a batch).
    """

    try:

        # Filtering out GIANT4 MIDIs
        if os.path.getsize(midi_file) > 1000000:
            return None

        #=======================================================
        # START PROCESSING

        # Converting MIDI to ms score with MIDI.py module
        with open(midi_file, 'rb') as midi_handle:
            raw_midi = midi_handle.read()

        score = TMIDIX.midi2single_track_ms_score(raw_midi, recalculate_channels=False)

        # INSTRUMENTS CONVERSION CYCLE
        # Gather note and patch_change events from every track (score[0] is skipped).
        events_matrix = [event
                         for track in score[1:]
                         for event in track
                         if event[0] == 'note' or event[0] == 'patch_change']

        events_matrix.sort(key=lambda x: x[1])

        patches = [0] * 16
        events_matrix1 = []

        for event in events_matrix:
            if event[0] == 'patch_change':
                patches[event[2]] = event[3]

            elif event[0] == 'note':
                # Tag the note with the patch currently active on its channel (index 6).
                event.append(patches[event[3]])

                # Drop a note that repeats the previously kept note's
                # start time, channel and pitch.
                if events_matrix1:
                    last_kept = events_matrix1[-1]
                    if event[1] == last_kept[1] and [event[3], event[4]] == last_kept[3:5]:
                        continue

                events_matrix1.append(event)

        if not events_matrix1:
            return None

        # Reject scores with negative start times or durations.
        if min(e[1] for e in events_matrix1) < 0 or min(e[2] for e in events_matrix1) < 0:
            return None

        #=======================================================
        # PRE-PROCESSING

        # Adjusting timings
        for e in events_matrix1:
            e[1] = int(e[1] / 16)
            e[2] = int(e[2] / 16)

        # Sorting by start-time, then pitch (highest first), then patch.
        # One composite key gives the same order as three successive stable sorts.
        events_matrix1.sort(key=lambda x: (x[1], -x[4], x[6]))

        #=======================================================
        # MAIN PROCESSING CYCLE
        #=======================================================

        notes = []
        previous_start = events_matrix1[0][1]

        for e in events_matrix1:

            delta_time = max(0, min(127, e[1] - previous_start))
            dur = max(0, min(127, e[2]))
            cha = max(0, min(15, e[3]))
            ptc = max(1, min(127, e[4]))

            if cha == 3:
                cha = 1

            notes.append([delta_time, dur, cha, ptc])

            previous_start = e[1]

        return notes

    except Exception:
        # Best-effort by design: unreadable or malformed files are skipped.
        # Unlike a bare `except:`, KeyboardInterrupt/SystemExit still propagate.
        return None

print('Done!')
print('=' * 70)

In [None]:
#@title Save file list
###########

print('=' * 70)
print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

dataset_addr = "/content/Dataset"

# os.chdir(dataset_addr)
# Collect the full path of every file found anywhere under the dataset directory.
filez = [
    os.path.join(folder, name)
    for folder, _subfolders, names in os.walk(dataset_addr)
    for name in names
]
print('=' * 70)

if filez:
  # Shuffle in place so later processing does not follow directory order.
  print('Randomizing file list...')
  random.shuffle(filez)
  print('Done!')
  print('=' * 70)
  print('Total files:', len(filez))
  print('=' * 70)

else:
    print('Could not find any MIDI files. Please check Dataset dir...')
    print('=' * 70)

In [None]:
#@title Process MIDIs with TMIDIX MIDI processor

for banner_line in ('=' * 70, 'TMIDIX MIDI Processor', '=' * 70, 'Starting up...', '=' * 70):
  print(banner_line)

###########

melody_chords_f = []

print('Processing MIDI files. Please wait...')
print('=' * 70)

# Walk the file list in slices of 16 and hand each slice to 16 joblib threads.
for i in tqdm(range(0, len(filez), 16)):

  chunk = filez[i:i+16]

  with parallel_config(backend='threading', n_jobs=16, verbose = 0):

    output = Parallel()(delayed(TMIDIX_MIDI_Processor)(f) for f in chunk)

  # The processor returns None for files it skipped; keep only real results.
  melody_chords_f.extend(notes for notes in output if notes is not None)

print('Done!')
print('=' * 70)

In [None]:
# Peek at the second processed composition (a list of [time, dur, cha, ptc] notes).
melody_chords_f[1]

In [None]:
# Persist the processed compositions through the TMIDIX pickle writer so the
# MIDI processing step does not have to be repeated.
TMIDIX.Tegridy_Any_Pickle_File_Writer(melody_chords_f, '/content/Processed_MIDIs')

In [None]:
# Reload the processed compositions written by the previous cell.
melody_chords_f = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/Processed_MIDIs')

In [None]:
#@title Test INTs

# Round-trip check: turn one processed composition back into a MIDI file.
train_data1 = melody_chords_f[4]

#train_data1 = max(melody_chords_f, key = len)

print('Sample INTs', train_data1[:15])

out = train_data1

# MIDI patch per channel: everything stays on patch 0 except channel 1 (patch 40).
patches = [0] * 16
patches[1] = 40

song = out
song_f = []  # bound up front so an empty composition cannot leave it undefined

time = 0
vel = 90

for s in song:

    # Undo the /16 scaling applied by the processor; start times are stored
    # as deltas, so accumulate them back into absolute times.
    time += s[0] * 16
    dur = s[1] * 16
    channel = s[2]
    pitch = s[3]

    song_f.append(['note', time, dur, channel, pitch, vel ])

if song_f:

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Experimental Music Transformer',
                                                              output_file_name = '/content/Experimental-Music-Transformer-Composition',
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches
                                                              )

    print('Done!')

else:
    print('Selected composition is empty. Nothing to convert...')

In [None]:
# Number of notes in the first processed composition.
len(melody_chords_f[0])

# (TRAIN DATA)

In [None]:
# `tqdm` is bound to the progress-bar class by the first import cell and to the
# tqdm module by the training import cell, so import the class locally under an
# unambiguous name; this cell then works whichever of those cells ran last.
from tqdm import tqdm as progress_bar

train_data = []

for m in progress_bar(melody_chords_f):

    # An empty composition has no first note to build the header from.
    if not m:
        continue

    cha = m[0][2]

    # Header: marker token 1025, the first note's channel/pitch shifted into
    # the 640+ range, and a zero time token.
    dat = [1025, ((cha * 128) + m[0][3])+640, 0]

    for mm in m:

        cha = mm[2]

        # A zero start-time delta is not emitted as a token.
        if mm[0] != 0:
            dat.extend([mm[0], mm[1]+128, ((cha * 128) + mm[3])+256])
        else:
            dat.extend([mm[1]+128, ((cha * 128) + mm[3])+256])

    dat = dat[:1025]

    ids = [] # 0 - 256 and 640 - 1025, plus the 1024 placeholder for numeric slots
    nums = [] # 256 - 640 at numeric slots, -1 everywhere else
    masks = [] # True where ids holds the 1024 placeholder

    for d in dat:
      if 256 <= d < 640:
        ids.append(1024)
        nums.append(d)
        masks.append(True)

      # Tokens in 640..1025 (including the 1025 header marker) are plain ids
      # too; without this range the two leading header tokens were dropped.
      elif 0 <= d < 256 or 640 <= d <= 1025:
        ids.append(d)
        nums.append(-1)
        masks.append(False)

    # Right-pad all three lists to the fixed length of 1025 (1026 is the pad id).
    ids += [1026] * (1025 - len(ids))
    nums += [-1] * (1025 - len(nums))
    masks += [False] * (1025 - len(masks))

    train_data.append([ids, nums, masks])

# Total dict size 1027



In [None]:
# Entry count, plus a check that every ids/nums/masks list has the padded
# length of 1025. Comparing the entries with key=len cannot detect anything:
# each entry is a 3-item [ids, nums, masks] list, so that test is always True.
len(train_data), all(len(part) == 1025 for entry in train_data for part in entry)

In [None]:
# NOTE(review): each entry is a 3-item [ids, nums, masks] list, so [:8] returns
# the whole entry rather than the first 8 tokens — confirm this is what was meant.
train_data[555][:8]

In [None]:
# Shuffle the training entries in place before saving them.
random.shuffle(train_data)

In [None]:
# Persist the tokenised training entries through the TMIDIX pickle writer.
TMIDIX.Tegridy_Any_Pickle_File_Writer(train_data, '/content/INTs')

In [None]:
# Reload the tokenised training entries written by the previous cell.
train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/INTs')

In [None]:
# Lengths of the longest and the shortest of the three lists (ids, nums, masks)
# in the first entry; the two numbers match when the entry is uniformly padded.
len(max(train_data[0], key=len)), len(min(train_data[0], key=len))

# (TRAIN MODEL)

In [None]:
# Install the training-time Python packages.
# NOTE(review): a later cell chdirs into tegridy-tools' X-Transformer directory
# before importing x_transformers, so the bundled copy may be the one actually
# used — confirm which copy provides the xVal wrappers.
!pip install x-transformers
!pip install einops
!pip install torch-summary

In [None]:
# Display the installed PyTorch version.
import torch
torch.__version__

In [None]:
# Load modules and make data dir

print('Loading modules...')

import os
import pickle
import random
import secrets
import tqdm
import math
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

%cd /content/tegridy-tools/tegridy-tools/

import TMIDIX

%cd /content/tegridy-tools/tegridy-tools/X-Transformer

from x_transformers import (
    Decoder,
    XValTransformerWrapper,
    XValAutoregressiveWrapper
)

# Trade a little float32 matmul precision for speed.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn

# Back to the notebook root after importing from inside the cloned repo.
%cd /content/

# NOTE(review): this creates the directory '/content/INTS' (upper-case S), while
# the pickle cells read and write '/content/INTs' — confirm the directory is needed.
if not os.path.exists('/content/INTS'):
    os.makedirs('/content/INTS')

# `random` is already imported earlier in this cell; the repeat is harmless.
import random

print('Done')

In [None]:
# Load the tokenised training entries ([ids, nums, masks] per composition).
train_data = TMIDIX.Tegridy_Any_Pickle_File_Reader('/content/INTs')

In [None]:
# Number of complete groups of 8 entries in the training set.
len(train_data) // 8

In [None]:
# @title Setup and init the model

# constants

# NOTE(review): SEQ_LEN is only used to slice dataset rows in MusicDataset,
# while the model below is built with max_seq_len = 1024 — confirm the intent.
SEQ_LEN = 8192 # Models seq len
PAD_IDX = 1026 # Models pad index

BATCH_SIZE = 4
NUM_EPOCHS = 100
GRADIENT_ACCUMULATE_EVERY = 4  # batches accumulated per optimizer step


LEARNING_RATE = 2e-4

# Reporting / checkpoint cadences.
VALIDATE_EVERY  = 100
SAVE_EVERY = 500
GENERATE_EVERY  = 100
PRINT_STATS_EVERY = 20  # re-assigned to 200 at the top of the training cell

GENERATE_LENGTH = 32

# helpers

def cycle(loader):
    """Yield the items of `loader` over and over, restarting it each time it runs out."""
    while True:
        yield from loader

# instantiate the model

model = XValTransformerWrapper(
    num_tokens = 1027,  # ids 0..1026, which includes the numeric placeholder and PAD_IDX
    numerical_token_id = 1024,  # id that marks a slot carrying a numeric value
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 1024,
        depth = 8,
        heads = 8,

    )
)

# wrap it with the xval autoregressive wrapper

# Positions holding PAD_IDX are passed as ignore_index.
model = XValAutoregressiveWrapper(model, ignore_index=PAD_IDX)

model.cuda()

print('Done!')

# Print a summary of the wrapped model (torchsummary).
summary(model)

# Dataloader

class MusicDataset(Dataset):
    """Dataset over pre-tokenised entries of the form ``[ids, nums, masks]``."""

    def __init__(self, data, seq_len):
        super().__init__()
        self.seq_len = seq_len
        self.data = data

    def __len__(self):
        # Round the entry count down to a whole number of batches.
        whole_batches = len(self.data) // BATCH_SIZE
        return whole_batches * BATCH_SIZE

    def __getitem__(self, index):
        """Return the (ids, nums, masks) tensors of one entry, already on the GPU."""
        entry = self.data[index]
        limit = self.seq_len + 1  # at most seq_len + 1 items are kept from each list

        ids = torch.Tensor(entry[0][:limit]).long()
        nums = torch.Tensor(entry[1][:limit]).long()
        masks = torch.Tensor(entry[2][:limit]).bool()

        return ids.cuda(), nums.cuda(), masks.cuda()

# precision/optimizer/scaler

dtype = torch.float16

# Reusable float16 autocast context for CUDA.
ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)

optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Gradient scaler for float16 mixed-precision training.
scaler = torch.cuda.amp.GradScaler(enabled=True)

In [None]:
# Shuffle once up front; the loaders below are created without shuffle=True.
random.shuffle(train_data)

# NOTE(review): val_dataset wraps the same train_data as train_dataset, so there
# is no held-out validation split here — confirm this is intended.
train_dataset = MusicDataset(train_data, SEQ_LEN)
val_dataset   = MusicDataset(train_data, SEQ_LEN)
train_loader  = DataLoader(train_dataset, batch_size = BATCH_SIZE)
val_loader    = DataLoader(val_dataset, batch_size = BATCH_SIZE)

In [None]:
# @title Train the model
torch.cuda.empty_cache()

train_losses = []
val_losses = []

# No accuracy or validation metric is computed in this loop, so these stay
# empty; they are kept because the saved stats bundle and the plots expect them.
train_accs = []
val_accs = []

nsteps = 0  # batches processed so far, across all epochs

PRINT_STATS_EVERY = 200


def _latest(values):
    """Return the last entry of `values` rounded to 4 places, or NaN when it is empty."""
    return round(float(values[-1]), 4) if values else float('nan')


def _save_progress(stats_path):
    """Save the model weights under a step/loss/acc-stamped name and pickle the stat lists."""
    print('Saving model progress. Please wait...')

    name = ('model_checkpoint_' + str(nsteps) + '_steps_' + str(_latest(train_losses))
            + '_loss_' + str(_latest(train_accs)) + '_acc.pth')
    print(name)

    torch.save(model.state_dict(), '/content/' + name)

    TMIDIX.Tegridy_Any_Pickle_File_Writer([train_losses, train_accs, val_losses, val_accs], stats_path)

    print('Done!')


for epoch in range(NUM_EPOCHS):

    print('=' * 70)
    print('Epoch #', epoch)
    print('=' * 70)
    model.train()  # set the model to training mode
    total_loss = 0
    optimizer.zero_grad(set_to_none=True)  # start every epoch with clean gradients

    for batch_idx, batch in enumerate(tqdm.tqdm(train_loader)):
        ids, nums, masks = batch

        with torch.cuda.amp.autocast():
            loss = model(ids, nums, mask=masks)  # forward pass

        # Divide by the accumulation factor so the summed gradients match one
        # large batch, then backpropagate through the gradient scaler (float16
        # gradients can underflow without scaling). Without this backward pass
        # optimizer.step() has no gradients to apply and the model never learns.
        loss = loss / GRADIENT_ACCUMULATE_EVERY
        scaler.scale(loss).backward()

        # Undo the normalization for logging
        batch_loss = loss.item() * GRADIENT_ACCUMULATE_EVERY
        train_losses.append(batch_loss)
        total_loss += batch_loss

        if (batch_idx + 1) % GRADIENT_ACCUMULATE_EVERY == 0:
            # Gradients must be unscaled before clipping so that the 1.0
            # threshold applies to their true magnitudes.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        # Count every batch so the print/save cadences below are per step,
        # not per epoch.
        nsteps += 1

        if nsteps % PRINT_STATS_EVERY == 0:
            print(f'Training Loss: {total_loss / (batch_idx + 1)}')

        if nsteps % SAVE_EVERY == 0:
            _save_progress('/content/losses_accs')

#======================================================================================================

_save_progress('/content/losses_accuracies')

# Save training loss / training acc / validation loss / validation acc graphs

for series, png_path in (
    (train_losses, '/content/training_loss_graph.png'),
    (train_accs, '/content/training_acc_graph.png'),
    (val_losses, '/content/validation_loss_graph.png'),
    (val_accs, '/content/validation_acc_graph.png'),
):
    plt.plot(range(len(series)), series, 'b')
    plt.savefig(png_path)
    plt.close()
    print('Done!')

In [None]:
# Manual save: writes a checkpoint, the stat lists and the four graphs, e.g.
# after interrupting training by hand.

def _latest(values):
    """Return the last entry of `values` rounded to 4 places, or NaN when it is empty."""
    return round(float(values[-1]), 4) if values else float('nan')


print('Saving model progress. Please wait...')

# An accuracy list with no entries must not crash the save, so empty lists are
# rendered as "nan" in the checkpoint name.
checkpoint_name = ('model_checkpoint_' + str(nsteps) + '_steps_' + str(_latest(train_losses))
                   + '_loss_' + str(_latest(train_accs)) + '_acc.pth')
print(checkpoint_name)

fname = '/content/' + checkpoint_name

torch.save(model.state_dict(), fname)

print('Done!')

data = [train_losses, train_accs, val_losses, val_accs]

TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accuracies')

# Save training loss / training acc / validation loss / validation acc graphs

for series, png_path in (
    (train_losses, '/content/training_loss_graph.png'),
    (train_accs, '/content/training_acc_graph.png'),
    (val_losses, '/content/validation_loss_graph.png'),
    (val_accs, '/content/validation_acc_graph.png'),
):
    plt.plot(range(len(series)), series, 'b')
    plt.savefig(png_path)
    plt.close()
    print('Done!')

# EVAL

In [None]:
# Build the autocast context used for generation from a dtype given by name.
device_type = 'cuda'
dtype = 'float16'

_dtype_by_name = {
    'float32': torch.float32,
    'bfloat16': torch.bfloat16,
    'float16': torch.float16,
}
ptdtype = _dtype_by_name[dtype]

ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

In [None]:
model.eval()

# NOTE(review): train_data[2] is one [ids, nums, masks] entry, so [:900] slices
# that 3-item list rather than taking the first 900 tokens — confirm the prompt
# is built the way it was meant to be.
x = torch.tensor(train_data[2][:900], dtype=torch.long, device='cuda')[None, ...]
#x = torch.tensor([[1024]] * 1, dtype=torch.long, device='cuda')

# run generation

# NOTE(review): check these arguments against the generate() signature of the
# XValAutoregressiveWrapper in use — it is not visible from this notebook.
with ctx:
    out = model.generate(x,
                        1023,
                        temperature=1,
                        return_prime=False,
                        verbose=True)

y = out.tolist()

print('---------------')

In [None]:
# Dump the generated output as nested Python lists.
print(y)

In [None]:
#@title Test INTs

# Decode a generated token list back into notes and write it out as a MIDI file.

# NOTE(review): `out3` is not defined anywhere in the visible notebook; the
# trailing comment suggests y[0] was the intended source — confirm.
train_data1 = out3 # y[0]

#train_data1 = max(melody_chords_f, key = len)

print('Sample INTs', train_data1[:15])

out = train_data1

# MIDI patch per channel: everything stays on patch 0 except channel 3 (patch 40).
patches = [0] * 16
patches[3] = 40

if len(out) != 0:

    song = out
    song_f = []

    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    # NOTE(review): the ranges decoded below (time 0-255, duration 256-511,
    # pitch 512-639, channel 640-641, velocity 642-769) differ from the tokens
    # built in the TRAIN DATA cell (duration at +128, channel/pitch at +256) —
    # presumably carried over from a different encoding; confirm before use.
    for ss in tqdm.tqdm(song):

        if 0 <= ss < 256:

            time += (ss * 16)

        if 256 <= ss < 512:

            dur = (ss-256) * 16

        if 512 <= ss < 640:

            pitch = ss-512

        if 640 <= ss < 642:

            channel = ss-640

            if channel == 1:
                channel = 3

        # A note is only emitted when a velocity token arrives.
        if 642 <= ss < 770:
            vel = ss-642

            song_f.append(['note', time, dur, channel, pitch, vel ])

# NOTE: song_f is only bound inside the `if` above, so an empty `out` raises NameError here.
detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                          output_signature = 'Experimental Music Transformer',
                                                          output_file_name = '/content/Experimental-Music-Transformer-Composition',
                                                          track_name='Project Los Angeles',
                                                          list_of_MIDI_patches=patches
                                                          )

print('Done!')

In [None]:
# Plot the matrix of pairwise cosine distances between the token-embedding rows
# and save the figure as a PNG.
embedding_rows = model.net.token_emb.emb.weight.detach().cpu().tolist()

distance_matrix = metrics.pairwise_distances(embedding_rows, metric='cosine')

plt.figure(figsize=(7, 7))
plt.imshow(distance_matrix, cmap="inferno", interpolation="nearest")

# Scale the colorbar by the image's height/width ratio.
n_rows, n_cols = distance_matrix.shape[0], distance_matrix.shape[1]
plt.colorbar(fraction=0.046 * (n_rows / n_cols), pad=0.04)

plt.xlabel("Position")
plt.ylabel("Position")
plt.tight_layout()
plt.plot()
plt.savefig("/content/Experimental-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")

# Congrats! You did it! :)