# Giant Music Transformer Auto Generator (ver. 1.0)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

WARNING: This complete implementation is a functioning model of Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2024

***

# (GPU CHECK)

In [None]:
#@title NVIDIA GPU check
# Runs the `nvidia-smi` shell command through the notebook `!` escape so the
# visible NVIDIA GPU(s) are printed in the cell output.
# The model-loading cell calls `model.cuda()`, so a CUDA device is required.
!nvidia-smi

# (SETUP ENVIRONMENT)

In [None]:
#@title Install dependencies
# Shallow-clones the Giant-Music-Transformer repository into the current
# working directory and installs the Python packages used by later cells.
# Several packages are installed both with and without `sudo` — presumably to
# cover both the system-wide and the user/kernel environment; confirm for your setup.
# NOTE(review): the import cell runs `from sklearn import metrics`, but
# scikit-learn is not installed here — confirm it is preinstalled on the target image.
!git clone --depth 1 https://github.com/asigalov61/Giant-Music-Transformer
!sudo pip install -U torch
!pip install -U torch
!sudo pip install einops
!pip install einops
!sudo pip install torch-summary
!sudo pip install matplotlib
!pip install matplotlib

!sudo pip install tqdm

!pip install huggingface_hub
!sudo pip install huggingface_hub

!sudo pip install ipywidgets

In [None]:
#@title Import modules
# Imports the standard-library, PyTorch and project modules used by the rest
# of the notebook, and configures the PyTorch CUDA backends.

print('=' * 70)
print('Loading core Giant Music Transformer modules...')

import os
import copy
import pickle
import secrets
import statistics
# NOTE: the rendering cell later assigns `time = 0`, which rebinds this name.
from time import time
import tqdm

print('=' * 70)
print('Loading main Giant Music Transformer modules...')

# Set before `import torch` — presumably read at import time by torch or the
# transformer module; TODO confirm which component consumes it.
os.environ['USE_FLASH_ATTENTION'] = '1'

import torch

# Numeric-precision and attention-backend switches for CUDA.
torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# Turn on every scaled-dot-product-attention backend toggle exposed here.
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(True)

# Change into the cloned repository before importing the project modules —
# presumably TMIDIX and x_transformer_1_23_2 live there; confirm.
%cd /home/ubuntu/Giant-Music-Transformer

import TMIDIX

# Star import: later cells use TransformerWrapper, Decoder,
# AutoregressiveWrapper and top_p unqualified — presumably provided by this
# module; verify against its contents.
from x_transformer_1_23_2 import *

import random

import numpy as np

# Return to the home directory; later cells use absolute /home/ubuntu/ paths.
%cd /home/ubuntu/

print('=' * 70)
print('Loading aux Giant Music Transformer modules...')

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

from IPython.display import Audio, display

from huggingface_hub import hf_hub_download

print('=' * 70)
print('PyTorch version:', torch.__version__)
print('=' * 70)
print('Done!')
print('Enjoy! :)')
print('=' * 70)

# (LOAD MODEL)

In [None]:
#@title Load Giant Music Transformer Pre-Trained Model

#@markdown Choose model

select_model_to_load = "482M-8L-Ultra-Fast-Medium" # @param ["482M-8L-Ultra-Fast-Medium","585M-32L-Very-Fast-Large","786M-44L-Fast-Extra-Large"]

#@markdown Model precision option

model_precision = "bfloat16" # @param ["bfloat16", "float16"]

#@markdown bfloat16 == Half precision/faster speed (if supported, otherwise the model will default to float16)

#@markdown float16 == Half precision/fast speed

plot_tokens_embeddings = "None" # @param ["None", "Start Times", "Durations Velocities", "Piano Pitches", "Drums Pitches", "Aux"]

# This cell:
#   1. resolves the selected checkpoint and downloads it if it is not on disk,
#   2. builds the transformer and loads the checkpoint weights,
#   3. moves the model to CUDA in eval mode and prints a summary,
#   4. optionally plots pairwise cosine distances between token embeddings.
# It defines the globals used by later cells: `model`, `ctx`, `dtype`,
# `SEQ_LEN`, `PAD_IDX`.

print('=' * 70)
print('Loading Giant Music Transformer', select_model_to_load,'Pre-Trained Model...')
print('Please wait...')
print('=' * 70)

full_path_to_models_dir = "/home/ubuntu/Giant-Music-Transformer/Models"

# Per-model settings: checkpoint file name, sub-directory under the models
# dir, model width, number of layers and whether rotary embeddings are used.
MODEL_CONFIGS = {
    '786M-44L-Fast-Extra-Large': {
        'checkpoint': 'Giant_Music_Transformer_Extra_Large_Trained_Model_18001_steps_0.2657_loss_0.9272_acc.pth',
        'subdir': 'Extra Large',
        'mdim': 1024,
        'num_layers': 44,
        'mrpe': False,
    },
    '585M-32L-Very-Fast-Large': {
        'checkpoint': 'Giant_Music_Transformer_Large_Trained_Model_36074_steps_0.3067_loss_0.927_acc.pth',
        'subdir': 'Large',
        'mdim': 1024,
        'num_layers': 32,
        'mrpe': False,
    },
    '482M-8L-Ultra-Fast-Medium': {
        'checkpoint': 'Giant_Music_Transformer_Medium_Trained_Model_42174_steps_0.5211_loss_0.8542_acc.pth',
        'subdir': 'Medium',
        'mdim': 2048,
        'num_layers': 8,
        'mrpe': True,
    },
}

# Token-id ranges [start, stop) plotted for each embeddings option.
TOKEN_EMBEDDING_RANGES = {
    'Start Times': [0, 256],
    'Durations Velocities': [256, 2304],
    'Piano Pitches': [2304, 2304+128],
    'Drums Pitches': [18945-128, 18945],
    'Aux': [18945, 19465],
}

# Fail fast on an unknown selection instead of hitting a NameError later or
# silently reusing values left over from a previous run of this cell.
if select_model_to_load not in MODEL_CONFIGS:
    raise ValueError(
        f'Unknown model {select_model_to_load!r}; expected one of {sorted(MODEL_CONFIGS)}'
    )

if plot_tokens_embeddings != 'None' and plot_tokens_embeddings not in TOKEN_EMBEDDING_RANGES:
    raise ValueError(
        f'Unknown embeddings plot option {plot_tokens_embeddings!r}; '
        f'expected "None" or one of {sorted(TOKEN_EMBEDDING_RANGES)}'
    )

model_config = MODEL_CONFIGS[select_model_to_load]

model_checkpoint_file_name = model_config['checkpoint']
model_dir = full_path_to_models_dir+'/'+model_config['subdir']
model_path = model_dir+'/'+model_checkpoint_file_name

mdim = model_config['mdim']
num_layers = model_config['num_layers']
mrpe = model_config['mrpe']

if os.path.isfile(model_path):
    print('Model already exists...')

else:
    # Download the checkpoint into the same directory `model_path` points at.
    hf_hub_download(repo_id='asigalov61/Giant-Music-Transformer',
                    filename=model_checkpoint_file_name,
                    local_dir=model_dir,
                    )

print('=' * 70)
print('Instantiating model...')

device_type = 'cuda'

# bfloat16 is used only when requested AND supported; everything else runs
# in float16.
if model_precision == 'bfloat16' and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'

ptdtype = {'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
# Autocast context reused by the embedding and generation cells.
ctx = torch.amp.autocast(device_type=device_type, dtype=ptdtype)

SEQ_LEN = 8192
PAD_IDX = 19463

model = TransformerWrapper(
        num_tokens = PAD_IDX+1,
        max_seq_len = SEQ_LEN,
        attn_layers = Decoder(dim = mdim,
                              depth = num_layers,
                              heads = 32,
                              rotary_pos_emb = mrpe,
                              attn_flash = True
                              )
)

model = AutoregressiveWrapper(model, ignore_index=PAD_IDX, pad_value=PAD_IDX, return_cache=True)

print('=' * 70)
print('Loading model checkpoint...')

model.load_state_dict(torch.load(model_path))

print('=' * 70)

model.cuda()
model.eval()

print('Done!')
print('=' * 70)
print('Model will use', dtype, 'precision...')
print('=' * 70)

# Model stats
print('Model summary...')
summary(model)

# Plot Token Embeddings
if plot_tokens_embeddings != 'None':

    tok_emb = model.net.token_emb.emb.weight.detach().cpu().tolist()

    tok_range = TOKEN_EMBEDDING_RANGES[plot_tokens_embeddings]

    # Embedding rows for the selected token-id range.
    tok_emb1 = tok_emb[tok_range[0]:tok_range[1]]

    # Pairwise cosine *distance* matrix between the selected embeddings.
    cos_sim = metrics.pairwise_distances(
        tok_emb1, metric='cosine'
    )
    plt.figure(figsize=(7, 7))
    plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")
    im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
    plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
    plt.xlabel("Position")
    plt.ylabel("Position")
    plt.tight_layout()
    plt.plot()
    plt.savefig("/home/ubuntu/Giant-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")

# (LOAD FUNCTIONS)

In [None]:
def load_midi(input_midi):
    """Convert a MIDI file into the flat token list used as model input.

    The file is read with TMIDIX, turned into enhanced score notes and
    augmented with ``timings_divider=16``. The returned list starts with a
    three-token intro sequence and then holds one token triple per note:

    * intro: ``19461``, ``19331`` if patch 128 (drums) occurs else ``19330``,
      and ``19332 + patch`` of the first note;
    * per note: delta start time clipped to 0..255,
      ``256 + 8 * duration + velocity_bucket`` (duration clipped to 0..255),
      ``2304 + 129 * patch + pitch`` (pitch clipped to 1..127, patch forced
      to 128 for channel 9).

    Args:
        input_midi: MIDI input passed straight to
            ``TMIDIX.midi2single_track_ms_score``.

    Returns:
        list: the token sequence described above.

    Raises:
        IndexError: if the processed score contains no notes.
    """

    raw_score = TMIDIX.midi2single_track_ms_score(input_midi)

    processed_score = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)

    escore_notes = TMIDIX.augment_enhanced_score_notes(processed_score[0], timings_divider=16)

    # Distinct patch numbers found in the score (note field 6).
    instruments = {note[6] for note in escore_notes}

    # Intro sequence: start marker, drums flag, patch of the first note.
    drums_present = 19331 if 128 in instruments else 19330
    first_patch = escore_notes[0][6]

    melody_chords = [19461, drums_present, 19332 + first_patch]

    # Start time of the previous note; the first note therefore gets delta 0.
    prev_start = escore_notes[0][1]

    for note in escore_notes:

        # Timing values, clipped to the token ranges.
        delta_time = max(0, min(255, note[1] - prev_start))
        dur = max(0, min(255, note[2]))
        cha = max(0, min(15, note[3]))

        # Channel 9 is drums and always maps to patch 128.
        pat = 128 if cha == 9 else note[6]

        ptc = max(1, min(127, note[4]))

        # Octo-velocity bucket derived from the clipped velocity.
        vel = max(8, min(127, note[5]))
        velocity = round(vel / 15) - 1

        dur_vel = (8 * dur) + velocity
        pat_ptc = (129 * pat) + ptc

        melody_chords.extend([delta_time, dur_vel + 256, pat_ptc + 2304])

        prev_start = note[1]

    return melody_chords

In [None]:
def get_embeddings(inputs):
    """Return the last cached hidden layer of the global ``model`` for ``inputs``.

    The forward pass runs under the global autocast context ``ctx`` with
    gradients disabled. Element 2 of the model output is treated as the cache
    object, and the final entry of its ``layer_hiddens`` is returned.

    Args:
        inputs: batch passed directly to ``model(...)``.

    Returns:
        numpy.ndarray: the selected hidden states, moved to the CPU.
    """

    with ctx, torch.no_grad():
        model_output = model(inputs)

    last_hidden = model_output[2].layer_hiddens[-1]

    return last_hidden.cpu().detach().numpy()

In [None]:
def select_best_output(outputs, embeddings, prime_embeddings):
    """Pick the output whose embeddings are most similar to the prime's.

    Each entry of ``embeddings`` is scored against ``prime_embeddings`` with
    ``overall_cosine_similarity``; the element of ``outputs`` at the position
    of the highest score is returned (the first one on ties).

    Args:
        outputs: indexable collection of candidates, aligned with ``embeddings``.
        embeddings: iterable of per-candidate embedding arrays.
        prime_embeddings: embedding array of the prime sequence.

    Returns:
        The best-scoring element of ``outputs``.
    """

    scores = [
        overall_cosine_similarity(prime_embeddings, candidate)
        for candidate in embeddings
    ]

    best_idx = scores.index(max(scores))

    return outputs[best_idx]

In [None]:
def overall_cosine_similarity(A, B):
    """Return the mean cosine similarity over all row pairs of ``A`` and ``B``.

    Every row of both inputs is scaled to unit L2 norm, the full matrix of
    row-by-row dot products is formed, and its mean is returned.

    Args:
        A: 2-D array-like, one vector per row.
        B: 2-D array-like with the same row length as ``A``.

    Returns:
        NumPy floating scalar: the average of all pairwise cosine similarities.
    """

    def _unit_rows(rows):
        # Divide each row by its own L2 norm.
        return rows / np.linalg.norm(rows, axis=1, keepdims=True)

    pairwise = np.dot(_unit_rows(A), _unit_rows(B).T)

    return np.mean(pairwise)

# (GENERATE)

In [None]:
# Tokenize the bundled piano seed MIDI; the generation cell takes its prime
# tokens from `melody_chords`.
seed_midi_path = '/home/ubuntu/Giant-Music-Transformer/Seeds/Giant-Music-Transformer-Piano-Seed-1.mid'
melody_chords = load_midi(seed_midi_path)

In [None]:
# Generation settings.
num_prime_tokens = 600        # tokens taken from the seed as the prime
num_gen_tokens = 600          # tokens generated per block
num_blocks_to_generate = 5    # number of generate-and-select rounds

num_batches = 4               # candidate continuations sampled per round

temperature = 0.9
sampling_top_p_value = 0.96

#=======================================================================

print('=' * 70)
print('Giant Music Transformer Model Auto Generator')
print('=' * 70)


torch.cuda.empty_cache()

prime = melody_chords[:num_prime_tokens]

# `song` is the same list object as `prime`; it grows in place as blocks are appended.
song = prime

# Reference embeddings every candidate continuation is compared against.
prime_embeddings = get_embeddings(torch.LongTensor([prime]).cuda())[0]

for block_idx in tqdm.tqdm(range(num_blocks_to_generate)):

    try:

        # One copy of the song so far per candidate.
        batch_input = torch.LongTensor([song] * num_batches).cuda()

        with ctx, torch.inference_mode():

            out = model.generate(
                batch_input,
                num_gen_tokens,
                filter_logits_fn=top_p,
                filter_kwargs={'thres': sampling_top_p_value},
                temperature=temperature,
                return_prime=False,
                verbose=False,
            )

            embeddings = get_embeddings(out)

            # Keep only the candidate most similar to the prime.
            best_output = select_best_output(out, embeddings, prime_embeddings)
            song.extend(best_output.tolist())

    except KeyboardInterrupt:
        print('Stopping inpainting...')
        break

    except Exception as e:
        print('Error', e)
        break

print('Done!')
print('=' * 70)

torch.cuda.empty_cache()

#==================================================
# Decode the token list in `song` back into notes, write them to a MIDI file
# through TMIDIX and plot the result.

print('Rendering results...')
print('=' * 70)

if song:

    song_f = []

    # Running decoder state, updated as tokens are consumed.
    # NOTE: `time` here rebinds the name imported by `from time import time`.
    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    # patches[c] is the patch assigned to MIDI channel c (-1 == unassigned);
    # channels[c] is 1 once channel c is taken. Channel 9 is reserved for drums.
    patches = [-1] * 16

    channels = [0] * 16
    channels[9] = 1

    for token in song:

        if 0 <= token < 256:

            # Delta start time token.
            time += token * 16

        elif 256 <= token < 2304:

            # Combined duration / velocity token.
            dur_step, vel_step = divmod(token - 256, 8)
            dur = dur_step * 16
            vel = (vel_step + 1) * 15

        elif 2304 <= token < 18945:

            # Combined patch / pitch token; it completes one note.
            patch, pitch = divmod(token - 2304, 129)

            if patch == 128:
                channel = 9

            else:
                if patch not in patches:
                    # First use of this patch: take the first free channel,
                    # or fall back to channel 15 when none is left.
                    if 0 in channels:
                        cha = channels.index(0)
                        channels[cha] = 1
                    else:
                        cha = 15

                    patches[cha] = patch

                channel = patches.index(patch)

            song_f.append(['note', time, dur, channel, pitch, vel, patch])

    # Channels that never received a patch default to patch 0.
    patches = [x if x != -1 else 0 for x in patches]

    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(
        song_f,
        output_signature = 'Giant Music Transformer',
        output_file_name = '/home/ubuntu/Giant-Music-Transformer-Composition',
        track_name='Project Los Angeles',
        list_of_MIDI_patches=patches,
    )


    print('=' * 70)
    print('Displaying resulting composition...')
    print('=' * 70)

    TMIDIX.plot_ms_SONG(song_f, plot_title='Giant Music Transformer Composition')

# Congrats! You did it! :)