# Giant Music Transformer MIDI Search (ver. 1.0)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

WARNING: This complete implementation is a functioning model of Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2024

***

# (GPU CHECK)

In [None]:
#@title NVIDIA GPU check
# Print the status of the NVIDIA GPU(s) visible to this machine.
# The model cell below moves the model to CUDA, so a GPU must be available.
!nvidia-smi

# (SETUP ENVIRONMENT)

In [None]:
#@title Install dependencies
# Fetch the Giant Music Transformer repo into the current working directory.
# The import cell later changes into /home/ubuntu/Giant-Music-Transformer/,
# so this cell is presumably meant to be run from /home/ubuntu/ — TODO confirm.
!git clone --depth 1 https://github.com/asigalov61/Giant-Music-Transformer

# Each package is installed both for the current user and via sudo.
# NOTE(review): the import cell also needs sklearn, huggingface_hub and tqdm,
# none of which are installed here — confirm they are preinstalled on the target image.
!pip install -U torch
!sudo pip install -U torch

!pip install einops
!sudo pip install einops

!pip install -U joblib
!sudo pip install -U joblib

!sudo pip install torch-summary

!pip install matplotlib
!sudo pip install matplotlib

!sudo pip install ipywidgets

# (IMPORT MODULES)

In [None]:
#@title Import modules

print('=' * 70)
print('Loading core Giant Music Transformer modules...')

import os
import copy
import pickle
import secrets
import statistics
from time import time
import tqdm

print('=' * 70)
print('Loading main Giant Music Transformer modules...')

# Set before torch and the transformer module are imported
os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

# Allow TF32 math and enable every scaled-dot-product-attention backend
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

# Change into the cloned repo; TMIDIX and x_transformer_1_23_2 are presumably
# modules shipped inside it — TODO confirm
%cd /home/ubuntu/Giant-Music-Transformer/

import TMIDIX

# Star import; presumably provides TransformerWrapper, Decoder and
# AutoregressiveWrapper used by the model loading cell — TODO confirm
from x_transformer_1_23_2 import *

import random

import numpy as np

from joblib import Parallel, delayed, parallel_config

# Return to the home dir; later cells use absolute paths under /home/ubuntu/
%cd /home/ubuntu/
print('=' * 70)
print('Loading aux Giant Music Transformer modules...')

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

from IPython.display import Audio, display

from huggingface_hub import hf_hub_download

print('=' * 70)
print('Setting-up dataset dir...')

# The trailing slash makes dirname() return the Dataset dir itself
os.makedirs(os.path.dirname('/home/ubuntu/Dataset/'), exist_ok=True)

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)

# (LOAD MODEL)

In [None]:
#@title Load Giant Music Transformer Pre-Trained Model

#@markdown Choose model

select_model_to_load = "482M-8L-Ultra-Fast-Medium" # @param ["482M-8L-Ultra-Fast-Medium","585M-32L-Very-Fast-Large","786M-44L-Fast-Extra-Large"]

#@markdown Model precision option

model_precision = "bfloat16" # @param ["bfloat16", "float16"]

#@markdown bfloat16 == Half precision/faster speed (if supported, otherwise the model will default to float16)

#@markdown float16 == Half precision/fast speed

# Which token range of the embedding table to plot after loading ("None" skips the plot)
plot_tokens_embeddings = "None" # @param ["None", "Start Times", "Durations Velocities", "Piano Pitches", "Drums Pitches", "Aux"]

print('=' * 70)
print('Loading Giant Music Transformer', select_model_to_load,'Pre-Trained Model...')
print('Please wait...')
print('=' * 70)

full_path_to_models_dir = "/home/ubuntu/Giant-Music-Transformer/Models"

# Per-model settings:
# (checkpoint file name, models sub-directory, model dim, number of layers, rotary pos emb)
models_configs = {
    '786M-44L-Fast-Extra-Large': ('Giant_Music_Transformer_Extra_Large_Trained_Model_18001_steps_0.2657_loss_0.9272_acc.pth',
                                  'Extra Large',
                                  1024,
                                  44,
                                  False
                                  ),

    '585M-32L-Very-Fast-Large': ('Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
                                 'Large',
                                 1024,
                                 32,
                                 False
                                 ),

    '482M-8L-Ultra-Fast-Medium': ('Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth',
                                  'Medium',
                                  2048,
                                  8,
                                  True
                                  ),
}

# Fail right here with a readable message instead of a NameError further down
if select_model_to_load not in models_configs:
    raise ValueError('Unknown model: ' + str(select_model_to_load) + '. Please choose one of: ' + ', '.join(models_configs))

model_checkpoint_file_name, model_dir_name, mdim, num_layers, mrpe = models_configs[select_model_to_load]

# The download dir and the load path are derived from the same base dir,
# so changing full_path_to_models_dir can not make them disagree
model_dir = full_path_to_models_dir+'/'+model_dir_name
model_path = model_dir+'/'+model_checkpoint_file_name

if os.path.isfile(model_path):
    print('Model already exists...')

else:
    hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
                    filename=model_checkpoint_file_name,
                    local_dir=model_dir,
                    )

print('=' * 70)
print('Instantiating model...')

device_type = 'cuda'

# bfloat16 is used only when it is both requested and supported by the GPU;
# every other combination resolves to float16
dtype = 'bfloat16' if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported() else 'float16'

ptdtype = torch.bfloat16 if dtype == 'bfloat16' else torch.float16

# Autocast context shared by every forward pass in the cells below
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 8192
PAD_IDX = 19463

# The pad token is the last id of the vocabulary, hence PAD_IDX+1 tokens
attn_layers = Decoder(dim = mdim,
                      depth = num_layers,
                      heads = 32,
                      rotary_pos_emb = mrpe,
                      attn_flash = True
                      )

model = TransformerWrapper(num_tokens = PAD_IDX+1,
                           max_seq_len = SEQ_LEN,
                           attn_layers = attn_layers
                           )

model = AutoregressiveWrapper(model,
                              ignore_index=PAD_IDX,
                              pad_value=PAD_IDX,
                              return_cache=True
                              )

print('=' * 70)
print('Loading model checkpoint...')

model.load_state_dict(torch.load(model_path))

print('=' * 70)

# Move the weights to the GPU and switch to inference mode
model.cuda()
model.eval()

print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)

# Model stats
print('Model summary...')
summary(model)

# Plot Token Embeddings
if plot_tokens_embeddings != 'None':

    tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()

    # Token id range [start, end) covered by each plot option
    tokens_ranges = {'Start Times': [0, 256],
                     'Durations Velocities': [256, 2304],
                     'Piano Pitches': [2304, 2304+128],
                     'Drums Pitches': [18945-128, 18945],
                     'Aux': [18945, 19465]
                     }

    if plot_tokens_embeddings in tokens_ranges:
        tok_range = tokens_ranges[plot_tokens_embeddings]

    range_start, range_end = tok_range[0], tok_range[1]

    tok_emb1 = list(tok_emb[range_start:range_end])

    # Pairwise cosine distances between the selected token embeddings
    cos_sim = metrics.pairwise_distances(tok_emb1, metric='cosine')

    plt.figure(figsize=(7, 7))
    plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")

    # Scale the colorbar with the image aspect ratio
    im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
    plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)

    plt.xlabel("Position")
    plt.ylabel("Position")
    plt.tight_layout()
    plt.plot()
    plt.savefig("/home/ubuntu/Giant-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")

# (LOAD FUNCTIONS)

In [None]:
def load_midi(input_midi):
    """Encode a MIDI file as a flat list of Giant Music Transformer tokens.

    The sequence starts with a three-token intro:

        [19461, drums flag, 19332 + patch of the first note]

    where the drums flag is 19331 when patch 128 (drums) is present and
    19330 otherwise. The intro is followed by three tokens per note:

        delta start time             (0..255)
        256 + 8 * duration + velocity (duration 0..255, velocity 0..7)
        2304 + 129 * patch + pitch    (pitch 1..127, drums patch == 128)

    Args:
        input_midi: MIDI file to load; passed straight to
            TMIDIX.midi2single_track_ms_score.

    Returns:
        The list of integer tokens, or the falsy value returned by TMIDIX
        when the file yields no notes.
    """

    raw_score = TMIDIX.midi2single_track_ms_score(input_midi)

    escore_notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)

    if not escore_notes:
        return escore_notes

    escore_notes = TMIDIX.augment_enhanced_score_notes(escore_notes[0], timings_divider=16)

    # The first note is indexed below, so an empty notes list must stop here
    # instead of raising IndexError
    if not escore_notes:
        return escore_notes

    #=======================================================
    # FINAL PROCESSING
    #=======================================================

    # Intro seq

    instruments = {note[6] for note in escore_notes}

    drums_present = 19331 if 128 in instruments else 19330

    first_patch = escore_notes[0][6]

    melody_chords = [19461, drums_present, 19332+first_patch]

    #=======================================================
    # MAIN PROCESSING CYCLE
    #=======================================================

    # Previous note; the first note is compared with itself, so its delta is 0
    pe = escore_notes[0]

    for e in escore_notes:

        # Timings, durations and channels (all values are clipped)
        delta_time = max(0, min(255, e[1]-pe[1]))
        dur = max(0, min(255, e[2]))
        cha = max(0, min(15, e[3]))

        # Patches (drums channel always maps to patch 128)
        pat = 128 if cha == 9 else e[6]

        # Pitches
        ptc = max(1, min(127, e[4]))

        # Octo-velocity: clipped velocity reduced to one of 8 levels (0..7)
        vel = max(8, min(127, e[5]))
        velocity = round(vel / 15)-1

        #=======================================================
        # FINAL NOTE SEQ
        #=======================================================

        dur_vel = (8 * dur) + velocity
        pat_ptc = (129 * pat) + ptc

        melody_chords.extend([delta_time, dur_vel+256, pat_ptc+2304])

        pe = e

    return melody_chords

In [None]:
def process_midis(midi_file):
    """Compute one pooled embedding vector for a single MIDI file.

    The file is tokenized with load_midi, run through the model in chunks,
    and the last entry of the cache's layer_hiddens is collected for every
    position. Embeddings are then max-pooled per distinct token id (only ids
    in [0, 18945) are used) and the per-token maxima are averaged into a
    single vector.

    Args:
        midi_file: Path to the MIDI file, using '/' separators. The parent
            folder name and the file name (up to '.mid') are returned
            alongside the embedding.

    Returns:
        [parent folder name, file name, embedding as a list of floats], or
        None when the file yields no tokens or any error occurs (the error
        and the file path are printed; this is deliberately best-effort so
        one bad file does not stop a dataset run).
    """

    emb_step = 8190 # Number of positions embedded per forward pass
    num_pooled_tokens = 18945 # Only token ids below this value are pooled

    try:

        path_parts = midi_file.split('/')

        fdir = path_parts[-2]
        fn1 = path_parts[-1].split('.mid')[0]

        src_inp = load_midi(midi_file)

        if not src_inp:
            return None

        chunks_embeddings = []

        for j in range(0, len(src_inp), emb_step):

            # One extra token per chunk; presumably the autoregressive wrapper
            # drops the last input token as the prediction target — TODO confirm
            inp_chunk = src_inp[j:j+emb_step+1]

            inp = torch.LongTensor([inp_chunk]).to('cuda')

            with ctx, torch.no_grad():
                out = model(inp)

            cache = out[2]
            src_embeddings = cache.layer_hiddens[-1]

            # float() keeps the conversion working if the hiddens are bfloat16,
            # which NumPy can not represent
            chunks_embeddings.append(src_embeddings.detach().float().cpu().numpy()[0])

        src_embeddings_np = np.concatenate(chunks_embeddings)

        # Row i of the embeddings belongs to token i of the source sequence
        tokens = np.asarray(src_inp[:len(src_embeddings_np)], dtype=np.int64)

        pooled_mask = (tokens >= 0) & (tokens < num_pooled_tokens)

        pooled_tokens, pooled_rows = np.unique(tokens[pooled_mask], return_inverse=True)

        # Per-token element-wise maximum. The accumulator is explicitly
        # floating point: an integer matrix (e.g. np.full(shape, -10000))
        # silently truncates every embedding value on assignment.
        tokens_max = np.full((len(pooled_tokens), src_embeddings_np.shape[1]),
                             -np.inf,
                             dtype=src_embeddings_np.dtype
                             )

        np.maximum.at(tokens_max, pooled_rows, src_embeddings_np[pooled_mask])

        # Average the per-token maxima into the final file embedding
        src_embeddings_matrix = np.mean(tokens_max, axis=0)

        return [fdir, fn1, src_embeddings_matrix.tolist()]

    except Exception as ex:
        print(ex)
        print(midi_file)

        return None

In [None]:
def get_midi_embeddings(input_midi):
    """Compute the pooled embedding vector of a single (query) MIDI file.

    The file is tokenized with load_midi (the same tokenizer process_midis
    uses for the dataset), run through the model in chunks, and the last
    entry of the cache's layer_hiddens is collected for every position.
    Embeddings are then max-pooled per distinct token id (only ids in
    [0, 18945) are used) and the per-token maxima are averaged into a single
    vector.

    Unlike process_midis, errors are not caught here and propagate to the
    caller.

    Args:
        input_midi: MIDI file to embed; passed straight to load_midi.

    Returns:
        A 1-D NumPy array with the file embedding, or the falsy value
        returned by load_midi when the file yields no tokens.
    """

    emb_step = 8190 # Number of positions embedded per forward pass
    num_pooled_tokens = 18945 # Only token ids below this value are pooled

    src_inp = load_midi(input_midi)

    if not src_inp:
        return src_inp

    chunks_embeddings = []

    for j in range(0, len(src_inp), emb_step):

        # One extra token per chunk; presumably the autoregressive wrapper
        # drops the last input token as the prediction target — TODO confirm
        inp_chunk = src_inp[j:j+emb_step+1]

        inp = torch.LongTensor([inp_chunk]).to('cuda')

        with ctx, torch.no_grad():
            out = model(inp)

        cache = out[2]
        src_embeddings = cache.layer_hiddens[-1]

        # float() keeps the conversion working if the hiddens are bfloat16,
        # which NumPy can not represent
        chunks_embeddings.append(src_embeddings.detach().float().cpu().numpy()[0])

    src_embeddings_np = np.concatenate(chunks_embeddings)

    # Row i of the embeddings belongs to token i of the source sequence
    tokens = np.asarray(src_inp[:len(src_embeddings_np)], dtype=np.int64)

    pooled_mask = (tokens >= 0) & (tokens < num_pooled_tokens)

    pooled_tokens, pooled_rows = np.unique(tokens[pooled_mask], return_inverse=True)

    # Per-token element-wise maximum. The accumulator is explicitly
    # floating point: an integer matrix (e.g. np.full(shape, -10000))
    # silently truncates every embedding value on assignment.
    tokens_max = np.full((len(pooled_tokens), src_embeddings_np.shape[1]),
                         -np.inf,
                         dtype=src_embeddings_np.dtype
                         )

    np.maximum.at(tokens_max, pooled_rows, src_embeddings_np[pooled_mask])

    # Average the per-token maxima into the final file embedding
    src_embeddings_matrix = np.mean(tokens_max, axis=0)

    return src_embeddings_matrix

In [None]:
def overall_cosine_similarity(A, B):
    """Return the mean pairwise cosine similarity between two sets of vectors.

    Args:
        A: Array-like of shape (n, d), or a single vector of shape (d,).
        B: Array-like of shape (m, d), or a single vector of shape (d,).

    Returns:
        The average of the n x m cosine similarities between every row of A
        and every row of B. Rows with zero (or NaN) norm have no defined
        direction and contribute a similarity of 0 instead of NaN.
    """

    # Accept lists and single vectors alike
    A = np.atleast_2d(np.asarray(A, dtype=float))
    B = np.atleast_2d(np.asarray(B, dtype=float))

    A_len = np.linalg.norm(A, axis=1, keepdims=True)
    B_len = np.linalg.norm(B, axis=1, keepdims=True)

    # Normalize the vectors; rows that can not be normalized stay all-zero
    A_norm = np.divide(A, A_len, out=np.zeros_like(A), where=A_len > 0)
    B_norm = np.divide(B, B_len, out=np.zeros_like(B), where=B_len > 0)

    # Compute the cosine similarity matrix
    similarity_matrix = np.dot(A_norm, B_norm.T)

    # Compute the overall similarity by averaging the cosine similarities
    overall_similarity = np.mean(similarity_matrix)

    return overall_similarity

# (DOWNLOAD AND PREP MIDI DATASET)

In [None]:
# Download the clean_midi archive into the dataset dir, unpack it there
# (tar output is discarded) and delete the archive afterwards
%cd /home/ubuntu/Dataset/

!wget 'http://hog.ee.columbia.edu/craffel/lmd/clean_midi.tar.gz'
!tar -xvf 'clean_midi.tar.gz' > /dev/null
!rm 'clean_midi.tar.gz'

# Return to the home dir used by the rest of the notebook
%cd /home/ubuntu/

# (PREP MIDIs FILES LIST)

In [None]:
#@title Save file list
###########

print('=' * 70)
print('Loading MIDI files...')
print('This may take a while on a large dataset in particular.')

dataset_addr = "/home/ubuntu/Dataset"

# Collect every file ending with '.mid' anywhere below the dataset dir
filez = []

for dirpath, dirnames, filenames in os.walk(dataset_addr):
    for file in filenames:
        if file.endswith('.mid'):
            filez.append(os.path.join(dirpath, file))

print('=' * 70)

if not filez:
    print('Could not find any MIDI files. Please check Dataset dir...')
    print('=' * 70)

# Shuffle in place so the files are processed in random order
print('Randomizing file list...')
random.shuffle(filez)
print('Done!')
print('=' * 70)
print('Found', len(filez), 'MIDIs')
print('=' * 70)

# (PROCESS MIDIs)

In [None]:
#@title Process MIDIs with TMIDIX MIDI processor

NUMBER_OF_PARALLEL_JOBS = 2 # Number of parallel jobs
NUMBER_OF_FILES_PER_ITERATION = 8 # Number of files to queue for each parallel iteration
SAVE_EVERY_NUMBER_OF_ITERATIONS = 160 # Save every...

def save_processed_embeddings(processed_embeddings, total_files_count):
    """Print progress stats and pickle all embeddings processed so far."""

    processed_count = len(processed_embeddings)

    # An empty files list must not crash the report with a division by zero
    good_files_ratio = processed_count / total_files_count if total_files_count else 0

    print('SAVING !!!')
    print('=' * 70)
    print('Saving processed files...')
    print('=' * 70)
    print('Processed so far:', processed_count, 'out of', total_files_count, '===', good_files_ratio, 'good files ratio')
    print('=' * 70)
    TMIDIX.Tegridy_Any_Pickle_File_Writer(processed_embeddings, '/home/ubuntu/processed_MIDIs_embeddings')
    print('=' * 70)

print('=' * 70)
print('TMIDIX MIDI Processor')
print('=' * 70)
print('Starting up...')
print('=' * 70)

###########

# One [folder name, file name, embedding] entry per successfully processed file
melody_chords_f = []

files_count = 0

print('Processing MIDI files. Please wait...')
print('=' * 70)

for i in tqdm.tqdm(range(0, len(filez), NUMBER_OF_FILES_PER_ITERATION)):

  with parallel_config(n_jobs=NUMBER_OF_PARALLEL_JOBS, verbose = 0):

    output = Parallel(backend='threading', n_jobs=NUMBER_OF_PARALLEL_JOBS, verbose=0)(delayed(process_midis)(f) for f in filez[i:i+NUMBER_OF_FILES_PER_ITERATION])

    # process_midis returns None for files that failed or had no notes
    melody_chords_f.extend(o for o in output if o is not None)

    files_count = len(melody_chords_f)

    # Intermediate save every NUMBER_OF_FILES_PER_ITERATION * SAVE_EVERY_NUMBER_OF_ITERATIONS
    # queued files (1280 with the default settings)
    if i % (NUMBER_OF_FILES_PER_ITERATION * SAVE_EVERY_NUMBER_OF_ITERATIONS) == 0 and i != 0:
        save_processed_embeddings(melody_chords_f, len(filez))

# Final save
save_processed_embeddings(melody_chords_f, len(filez))

# (LOAD PROCESSED EMBEDDINGS)

In [None]:
# Reload the [folder name, file name, embedding] entries written by the processing cell.
# NOTE(review): judging by the helper's name this is a pickle file — only load files you created yourself.
all_processed_MIDIs_embeddings = TMIDIX.Tegridy_Any_Pickle_File_Reader('/home/ubuntu/processed_MIDIs_embeddings')

# (SEARCH EMBEDDINGS)

In [None]:
# Embed the query MIDI (one of the seed files from the repo) for the search cell below.
# get_midi_embeddings returns a falsy value instead of an array when the file yields no notes.
print('Loading MIDI and computing embeddings...')
src_midi = get_midi_embeddings('/home/ubuntu/Giant-Music-Transformer/Seeds/Giant-Music-Transformer-MI-Seed-1.mid')
print('Done!')

In [None]:
# Matches with a similarity above this value are ignored (scored as 0)
# NOTE(review): with 1.0, an exact match whose similarity lands a hair above 1.0
# due to float rounding is also scored as 0 — confirm this is intended
max_match_ratio = 1.0

print('=' * 70)
print('Searching embeddings...')
print('=' * 70)

# Fail early with a readable message instead of a shape error or max() on an empty list
if src_midi is None or len(src_midi) == 0:
    raise ValueError('Source MIDI embeddings are empty. Please check the source MIDI file...')

if not all_processed_MIDIs_embeddings:
    raise ValueError('No processed MIDIs embeddings to search. Please process the dataset first...')

cos_sims = []

for emb in tqdm.tqdm(all_processed_MIDIs_embeddings):
    trg_emb = np.array(emb[2])
    sim = overall_cosine_similarity([src_midi], [trg_emb])

    # NaN similarities fail the comparison and are scored as 0 as well
    cos_sims.append(sim if sim <= max_match_ratio else 0)

# First index of the highest similarity
best_match_idx = int(np.argmax(cos_sims))

print('Done!')
print('=' * 70)
print('Best match ratio:', cos_sims[best_match_idx])
print('Best match:', all_processed_MIDIs_embeddings[best_match_idx][:2])
print('=' * 70)

# Congrats! You did it! :)