# Piano Hands Maker (ver. 1.0)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

WARNING: This complete implementation is a functioning model of Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2024

***

# GPU check

In [None]:
# Show the NVIDIA GPUs visible to this machine (later cells call .cuda() unconditionally)
!nvidia-smi

# Setup environment

In [None]:
# Shallow-clone tegridy-tools into the current working directory.
# NOTE(review): later cells %cd into /home/ubuntu/tegridy-tools/..., so this
# presumably has to be run from /home/ubuntu — confirm.
!git clone --depth 1 https://github.com/asigalov61/tegridy-tools

In [None]:
# Install the PyTorch stack; the second line repeats the install with -U (upgrade) and without sudo
!sudo pip3 install torch torchvision torchaudio
!pip3 install -U torch torchvision torchaudio
# NOTE(review): einops is presumably required by the bundled X-Transformer module — confirm
!sudo pip install einops
# NOTE(review): presumably the package behind the `torchsummary` import used in the next cell
!sudo pip install torch-summary
!sudo pip3 install -U tqdm
# NOTE(review): matplotlib and scikit-learn are imported in the next cell but are
# not installed here — presumably preinstalled on the target image.

In [None]:
# Load modules and make data dir

print('Loading modules...')

import os
import pickle
import random
import secrets
import tqdm
import math
import copy
import gc

# NOTE(review): `!` runs the command in a separate shell, so this line presumably
# does not change the kernel's environment; the os.environ assignment below is the
# one the Python process (and torch, imported afterwards) can see.
!set USE_FLASH_ATTENTION=1
os.environ['USE_FLASH_ATTENTION'] = '1'

import torch
import torch.optim as optim

from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

# TMIDIX is imported from the cloned repo by changing the working directory first
%cd /home/ubuntu/tegridy-tools/tegridy-tools/

import TMIDIX

# Same trick for the bundled X-Transformer implementation
%cd /home/ubuntu/tegridy-tools/tegridy-tools/X-Transformer

# Star import: brings TransformerWrapper, Decoder, AutoregressiveWrapper, top_k
# (all used by later cells) into the notebook namespace
from x_transformer_1_23_2 import *

torch.set_float32_matmul_precision('high')
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
# Enable the flash scaled-dot-product-attention backend, disable the cuDNN one
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_cudnn_sdp(False)

!set USE_FLASH_ATTENTION=1

# Return to the home directory; all data and output paths below are under /home/ubuntu
%cd /home/ubuntu/

if not os.path.exists('/home/ubuntu/Dataset'):
    os.makedirs('/home/ubuntu/Dataset')

# Duplicate of the import at the top of this cell (harmless)
import random

print('Done')

print('Torch version:', torch.__version__)

# Download and unzip Piano Hands MIDI dataset

In [None]:
# Work inside the dataset directory created by the setup cell
%cd /home/ubuntu/Dataset/
# Fetch the dataset archive, unpack it quietly (unzip output is discarded), then delete the archive
!wget "https://github.com/asigalov61/Tegridy-MIDI-Dataset/raw/refs/heads/master/Piano-Hands/Piano-Hands-MIDI-Dataset-CC-BY-NC-SA.zip"
!unzip Piano-Hands-MIDI-Dataset-CC-BY-NC-SA.zip > /dev/null
!rm Piano-Hands-MIDI-Dataset-CC-BY-NC-SA.zip
# NOTE(review): the next cell scans /home/ubuntu/Dataset/MIDIs/, so the archive
# presumably unpacks into a MIDIs/ folder — confirm.
%cd /home/ubuntu/

# Create MIDI files list

In [None]:
dataset_addr = "/home/ubuntu/Dataset/MIDIs/"

#==========================================================================

# Recursively gather the full path of every file whose name ends with '.mid'
# (the suffix test is case-sensitive), in os.walk order.
filez = [
    os.path.join(folder, midi_name)
    for folder, _subfolders, names in os.walk(dataset_addr)
    for midi_name in names
    if midi_name.endswith('.mid')
]
print('=' * 70)

# Shuffle in place so downstream processing does not follow directory order
random.shuffle(filez)

print('Loaded', len(filez), 'data files')
print('=' * 70)

# Process MIDIs

In [None]:
# Convert every MIDI file into token sequences, one per transposition.
#
# Each note becomes four tokens:
#   delta start time (0..127), duration + 128, pitch + 256, hand + 384
# and each sequence opens with four 386 tokens.
# Results are stored as [file_stem, tokens] entries in events_matrix_final.

events_matrix_final = []
good_files_count = 0

print('Processing MIDI files. Please wait...')

for f in tqdm.tqdm(filez):

    try:

        # File name up to its first dot, kept as a label next to the tokens
        fn1 = os.path.basename(f).split('.')[0]

        # Raw single-track ms score -> enhanced score notes -> augmented notes
        raw_score = TMIDIX.midi2single_track_ms_score(f)
        notes = TMIDIX.advanced_score_processor(raw_score, return_enhanced_score_notes=True)[0]
        notes = TMIDIX.augment_enhanced_score_notes(notes, timings_divider=32)

        # One sequence per transposition of -3..+3 semitones
        for tv in range(-3, 4):

            tokens = [386] * 4

            prev_note = notes[0]

            for note in notes:

                delta_start = max(0, min(127, note[1]-prev_note[1]))
                duration = max(1, min(127, note[2]))
                pitch = max(1, min(127, note[4]+tv))
                hand = max(0, min(1, note[3]))

                tokens += [delta_start, duration+128, pitch+256, hand+384]

                prev_note = note

            events_matrix_final.append([fn1, tokens])

        good_files_count += 1

    except KeyboardInterrupt:
        # Stops the loop; whatever was collected so far stays in events_matrix_final
        print('Saving current progress and quitting...')
        break

    except Exception as ex:
        # Any failure while reading or encoding a file skips that file
        print('Bad MIDI:', f)
        print(ex)
        continue

print('=' * 70)
print('Done!')
print('=' * 70)

print('Resulting Stats:')
print('=' * 70)
print('Total good MIDI Files:', good_files_count)
print('=' * 70)

# Save/Load processed MIDIs

In [None]:
# Persist the processed [file_stem, tokens] entries so MIDI processing can be skipped next time
TMIDIX.Tegridy_Any_Pickle_File_Writer(events_matrix_final, '/home/ubuntu/Piano_Hands_Processed_MIDIs')

In [None]:
# Reload the processed entries written by the previous cell.
# NOTE(review): the helper name suggests pickle-based loading — only read files you produced yourself.
events_matrix_final = TMIDIX.Tegridy_Any_Pickle_File_Reader('/home/ubuntu/Piano_Hands_Processed_MIDIs')

# Prep Training Data

In [None]:
SEQ_LEN = 2102
PAD_IDX = 390 # Model pad index

# Chunking of each piece's token stream: 1200 tokens (300 four-token notes)
# per chunk, with a new chunk starting every 600 tokens (50% overlap).
CHUNK_LEN = 1200
CHUNK_STRIDE = 600

#==========================================================================

print('=' * 70)
print('Loading data files...')
print('Please wait...')
print('=' * 70)

# Every training sequence has the layout
#   [387] + 900 tokens (time, duration, pitch of each note) + [388]
#         + 1200 tokens (the same notes with their hand token) + [389]
# which is 2103 tokens, i.e. SEQ_LEN + 1.
train_data = []

chunks_counter = 0

for td in tqdm.tqdm(events_matrix_final):

    t = td[1]

    # Final data integrity check: the sequence must be non-empty and EVERY
    # token must lie in [0, PAD_IDX). Checking only max(t) would let negative
    # tokens through and would raise on an empty sequence.
    if t and min(t) >= 0 and max(t) < PAD_IDX:

        for i in range(0, len(t), CHUNK_STRIDE):

            chunk = t[i: i+CHUNK_LEN]

            # Only full-length chunks are used; the short tail is dropped
            if len(chunk) == CHUNK_LEN:

                # Drop the 4th token (hand label) of every note to build the prompt part
                triplets = []

                for j in range(0, len(chunk), 4):
                    triplets.extend(chunk[j:j+3])

                seq = [387] + triplets + [388] + chunk + [389]

                train_data.append(seq)

                chunks_counter += 1

    else:
        print('Bad data!!!')
        #print(t)
        #break

#==========================================================================

# All sequences must share one length; an empty train_data reports False
# instead of raising ValueError from max()/min() on an empty list.
seq_lengths = {len(seq) for seq in train_data}

print('Done!')
print('=' * 70)
print('Total number of main chunks:', chunks_counter)
print('All data is good:', len(seq_lengths) == 1)
print('=' * 70)
print('Randomizing train data...')
random.shuffle(train_data)
print('Done!')
print('=' * 70)
print('Total length of train data:', len(train_data))
print('=' * 70)

In [None]:
# Length of one training sequence (raises IndexError if train_data is empty)
len(train_data[0])

In [None]:
# First 15 tokens of one training sequence, for a quick visual check
train_data[0][:15]

# Setup model

In [None]:
# Setup model

# ---- training hyper-parameters ------------------------------------------

NUM_EPOCHS = 3

BATCH_SIZE = 32
GRADIENT_ACCUMULATE_EVERY = 1

LEARNING_RATE = 1e-4
GRAD_CLIP = 1.5

# ---- reporting / checkpoint cadence (in batches within an epoch) ---------

PRINT_STATS_EVERY = 10
VALIDATE_EVERY  = 100
GENERATE_EVERY  = 500
GENERATE_LENGTH = 902
SAVE_EVERY = 500

# ---- model ----------------------------------------------------------------

# Attention stack, built first and then handed to the token/position wrapper
decoder_stack = Decoder(
    dim = 1024,
    depth = 4,
    heads = 32,
    rotary_pos_emb = True,
    attn_flash = True,
)

# Vocabulary covers token ids 0..PAD_IDX inclusive
model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = decoder_stack,
)

# PAD_IDX is used both as the padding value and as the index ignored by the loss
model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)

model.cuda()

print('Done!')

summary(model)

# Dataloader

def get_train_data_batch(tdata, index, seq_len, batch_size, pad_idx, device='cuda'):
    """Return batch number `index` of `tdata` as a LongTensor on `device`.

    Args:
        tdata: Sequence of equal-length integer token lists.
        index: Zero-based batch number; the batch covers
            tdata[index*batch_size : (index+1)*batch_size].
        seq_len: Unused; kept so existing call sites stay valid.
        batch_size: Number of sequences per batch.
        pad_idx: Unused; kept so existing call sites stay valid.
        device: Target device for the returned tensor. Defaults to 'cuda',
            which matches the previous hard-coded .cuda() call.

    Returns:
        torch.LongTensor of shape (rows, sequence_length), where rows is
        batch_size, or fewer when the slice runs past the end of `tdata`.
    """

    start = index * batch_size

    batch = tdata[start:start + batch_size]

    return torch.LongTensor(batch).to(device)
        
# optimizer / precision / scaler

# NOTE: this rebinds the name `optim` (imported earlier as the torch.optim
# module) to the Adam optimizer instance that the training loop uses.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Forward passes run under bfloat16 autocast on CUDA
dtype = torch.bfloat16
ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)

# Gradient scaler used around backward() and optimizer.step() in the training loop
scaler = torch.amp.GradScaler('cuda')

# Train the model

In [None]:
# Train the model
#
# For each epoch: reshuffle train_data, run every batch once, and periodically
# (a) evaluate, (b) generate a sample MIDI, (c) save a checkpoint plus the
# metric history. All cadences are counted in batches within the epoch (i),
# so each of them also fires on the first batch (i == 0) of every epoch.

# Per-step training metrics and periodic evaluation metrics
train_losses = []
val_losses = []

train_accs = []
val_accs = []

# Optimizer steps taken across all epochs (used in checkpoint file names)
nsteps = 0

for ep in range(NUM_EPOCHS):

        print('=' * 70)
        print('Randomizing train data...')
        random.shuffle(train_data)
        print('=' * 70)

        print('=' * 70)
        print('Epoch #', ep)
        print('=' * 70)

        # Full batches only; any leftover sequences are not used in this epoch
        NUM_BATCHES = len(train_data) // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY

        model.train()

        for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):

            optim.zero_grad()

            # Accumulate gradients over GRADIENT_ACCUMULATE_EVERY micro-batches
            for j in range(GRADIENT_ACCUMULATE_EVERY):
                with ctx:
                    loss, acc = model(get_train_data_batch(train_data, (i*GRADIENT_ACCUMULATE_EVERY)+j, SEQ_LEN, BATCH_SIZE, PAD_IDX))
                    # NOTE(review): the division is disabled while the logging below still
                    # multiplies by GRADIENT_ACCUMULATE_EVERY; equivalent only while it is 1.
                    #loss = loss / GRADIENT_ACCUMULATE_EVERY
                scaler.scale(loss).backward()

            if i % PRINT_STATS_EVERY == 0:
                print(f'Training loss: {loss.item() * GRADIENT_ACCUMULATE_EVERY}')
                print(f'Training acc: {acc.item()}')

            # Only the last micro-batch's loss/acc are recorded for this step
            train_losses.append(loss.item() * GRADIENT_ACCUMULATE_EVERY)
            train_accs.append(acc.item())

            # Unscale before clipping so GRAD_CLIP applies to the unscaled gradients
            scaler.unscale_(optim)
            torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
            scaler.step(optim)
            scaler.update()

            nsteps += 1

            if i % VALIDATE_EVERY == 0:
                model.eval()
                with torch.no_grad():
                    with ctx:
                        # NOTE(review): this batch is drawn from train_data (batch index i),
                        # not from a held-out set, so these are not true validation metrics.
                        val_loss, val_acc = model(get_train_data_batch(train_data, i, SEQ_LEN, BATCH_SIZE, PAD_IDX))

                        print(f'Validation loss: {val_loss.item()}')
                        print(f'Validation acc: {val_acc.item()}')

                        val_losses.append(val_loss.item())
                        val_accs.append(val_acc.item())

                        # The four plots below reuse tr_loss_list as a scratch name
                        # for whichever series is being drawn
                        print('Plotting training loss graph...')

                        tr_loss_list = train_losses
                        plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
                        plt.show()
                        plt.close()
                        print('Done!')

                        print('Plotting training acc graph...')

                        tr_loss_list = train_accs
                        plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
                        plt.show()
                        plt.close()
                        print('Done!')

                        print('Plotting validation loss graph...')
                        tr_loss_list = val_losses
                        plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
                        plt.show()
                        plt.close()
                        print('Done!')

                        print('Plotting validation acc graph...')
                        tr_loss_list = val_accs
                        plt.plot([i for i in range(len(tr_loss_list))] ,tr_loss_list, 'b')
                        plt.show()
                        plt.close()
                        print('Done!')

                model.train()

            if i % GENERATE_EVERY == 0:
                model.eval()

                # Prompt: the first GENERATE_LENGTH tokens of a random sequence from batch i
                inp = random.choice(get_train_data_batch(train_data, i, SEQ_LEN, BATCH_SIZE, PAD_IDX))[:GENERATE_LENGTH]

                print(inp)

                with ctx:
                    sample = model.generate(inp[None, ...], GENERATE_LENGTH)

                print(sample)

                data = sample.tolist()[0]

                print('Sample INTs', data[:15])

                # NOTE(review): song_f and patches are assigned only inside this `if`,
                # but the converter call after it uses them unconditionally.
                if len(data) != 0:

                    song = data
                    song_f = []

                    # Running decoder state; a note is emitted when a hand token arrives
                    time = 0
                    dur = 0
                    vel = 90
                    pitch = 0
                    channel = 0

                    patches = [0] * 16

                    for ss in song:

                        # 0..127: delta start time, accumulated into absolute time
                        if 0 <= ss < 128:

                            time += ss

                        # 128..255: duration
                        if 128 <= ss < 256:

                            dur = ss-128

                        # 256..383: pitch; velocity follows pitch with a floor of 40
                        if 256 <= ss < 384:

                            pitch = ss-256
                            vel = max(40, pitch)

                        # 384..385: hand label, written to the channel field; closes the note
                        if 384 <= ss < 386:

                            channel = ss-384

                            song_f.append(['note', time, dur, channel, pitch, vel, 0])

                detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                                          output_signature = 'Piano Hands Music Transformer',
                                                                          output_file_name = '/home/ubuntu/Piano-Hands-Music-Transformer-Composition',
                                                                          track_name='Project Los Angeles',
                                                                          list_of_MIDI_patches=patches,
                                                                          timings_multiplier=32,
                                                                          )

                print('Done!')

                model.train()

            if i % SAVE_EVERY == 0:

                # Checkpoint name carries the step count and the latest training loss/accuracy
                print('Saving model progress. Please wait...')
                print('model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth')

                fname = '/home/ubuntu/model_checkpoint_' + str(nsteps) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss_' + str(round(float(train_accs[-1]), 4)) + '_acc.pth'

                torch.save(model.state_dict(), fname)

                # Metric history is overwritten at the same path on every save
                data = [train_losses, train_accs, val_losses, val_accs]

                TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accs')

                print('Done!')

# Final Save

In [None]:
# Final checkpoint, metric history and metric graphs

print('Saving model progress. Please wait...')

# Shared tail of the checkpoint name: step count plus latest training loss/accuracy
checkpoint_stats = (
    str(nsteps)
    + '_steps_' + str(round(float(train_losses[-1]), 4))
    + '_loss_' + str(round(float(train_accs[-1]), 4))
    + '_acc.pth'
)

print('model_checkpoint_' + checkpoint_stats)

fname = '/home/ubuntu/model_checkpoint_' + checkpoint_stats

torch.save(model.state_dict(), fname)
#torch.save(optim.state_dict(), fname+'_opt')

print('Done!')

data = [train_losses, train_accs, val_losses, val_accs]

TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/home/ubuntu/losses_accuracies')

# One PNG per metric series, drawn as a blue line over the sample index
graph_targets = (
    (train_losses, '/home/ubuntu/training_loss_graph.png'),
    (train_accs, '/home/ubuntu/training_acc_graph.png'),
    (val_losses, '/home/ubuntu/validation_loss_graph.png'),
    (val_accs, '/home/ubuntu/validation_acc_graph.png'),
)

for series, png_path in graph_targets:
    plt.plot(list(range(len(series))), series, 'b')
    plt.savefig(png_path)
    plt.close()
    print('Done!')

# Eval the model

In [None]:
# Load/re-load the model
#
# Rebuilds the same architecture used for training, restores the weights from
# a checkpoint file, then moves the model to the GPU in eval mode.

SEQ_LEN = 2102
PAD_IDX = 390 # Model pad index

model = TransformerWrapper(
    num_tokens = PAD_IDX+1,
    max_seq_len = SEQ_LEN,
    attn_layers = Decoder(dim = 1024,
                          depth = 4,
                          heads = 32,
                          rotary_pos_emb = True,
                          attn_flash = True
                         )
    )

model = AutoregressiveWrapper(model, ignore_index = PAD_IDX, pad_value=PAD_IDX)

print('=' * 70)
print('Loading model checkpoint...')

# NOTE: the training cells write files named 'model_checkpoint_<steps>_steps_...';
# point model_path at the checkpoint you actually want to load.
model_path = '/home/ubuntu/Piano_Hands_Music_Transformer_Trained_Model_3222_steps_0.1875_loss_0.942_acc.pth'

# map_location='cpu': the model is still on the CPU here, so load the weights
# there too instead of materialising them on the device they were saved from.
# weights_only=True: a state_dict holds only tensors, so restrict unpickling
# to tensor data rather than allowing arbitrary pickled objects.
model.load_state_dict(torch.load(model_path, map_location='cpu', weights_only=True))

print('=' * 70)

model.cuda()
model.eval()

print('Done!')

summary(model)

# Inference runs under the same bfloat16 autocast context as training
dtype = torch.bfloat16

ctx = torch.amp.autocast(device_type='cuda', dtype=dtype)

In [None]:
# Load source MIDI
#
# Produces:
#   original_score_ticks / original_score_notes / original_score_events
#       the source score in its native timing, notes ordered by start time
#       and, within one start time, by descending pitch
#   melody_chords
#       one [delta_time, duration+128, pitch+256] triplet per note (timings in
#       1/32 of the ms score values), preceded by a [386, 386, 386] start triplet

midi_file = '/home/ubuntu/tegridy-tools/tegridy-tools/seed2.mid'

# Context manager closes the file handle as soon as the bytes are read
with open(midi_file, 'rb') as midi_fh:
    score = TMIDIX.midi2score(midi_fh.read())

events_matrix0 = []
original_score_ticks = score[0]
original_score_notes = []
original_score_events = []

# score[0] is the ticks value; every following element is a track of events
for track in score[1:]:
    for event in track:

        if event[0] == 'note':
            original_score_notes.append(event)
        else:
            original_score_events.append(event)

        events_matrix0.append(event)

# Two stable sorts: primary key start time, ties broken by descending pitch
original_score_notes.sort(key=lambda x: x[4], reverse=True)
original_score_notes.sort(key=lambda x: x[1])

# Millisecond-timed copy of the merged events, used for tokenization
opus = TMIDIX.score2opus([score[0], events_matrix0])
ms_score = TMIDIX.opus2score(TMIDIX.to_millisecs(opus))[1]

ms_score = [y for y in ms_score if y[0] == 'note']

if not ms_score:
    # Without notes there is nothing to label; fail with a clear message
    # instead of an IndexError on ms_score[0] below.
    raise ValueError('No notes found in ' + midi_file)

# Same ordering as original_score_notes so labels can later be applied by index
ms_score.sort(key=lambda x: x[4], reverse=True)
ms_score.sort(key=lambda x: x[1])

melody_chords = [[386, 386, 386]]

pe = ms_score[0]

for e in ms_score:

    dtime = max(0, min(127, abs(int(e[1] / 32)) - abs(int(pe[1] / 32))))

    dur = max(1, min(127, abs(int(e[2] / 32))))

    ptc = max(1, min(127, e[4]))

    melody_chords.append([dtime, dur+128, ptc+256])

    pe = e

#=======================================================

print('Done!')
print('=' * 70)
# melody_chords holds one extra start triplet, so the note count is len - 1
print('Composition has', len(melody_chords) - 1, 'notes')
print('=' * 70)

In [None]:
# Generate piano hands labels
#
# The model is prompted with [387] + 300 note triplets (900 tokens) + [388],
# i.e. 902 tokens, and then fed one note triplet at a time; after each triplet
# it predicts a single token, which is taken as that note's hand label.
#
# Outputs:
#   song          flat list of 4-token notes (triplet + predicted label token)
#   hands_labels  predicted label per note, stored as token - 384

memory = True # Memory option

song = []
hands_labels = []
hands_labels_last = []

#================================================================================

# Long compositions: slide a 300-triplet window over melody_chords in steps of 150
if len(melody_chords) >= 300:

    for i in range(0, len(melody_chords), 150):

        # Only full 300-triplet windows are handled here; the tail is handled below
        if len(melody_chords[i:i+300]) == 300:

            triplets = TMIDIX.flatten(melody_chords[i:i+300])

            if i > 0 and memory:
                # Seed the output side with the last 600 tokens (150 labelled notes)
                # already in song, then label only the second half of the window
                seq = [387] + triplets + [388] + song[-600:]
                shift1 = 150
                shift2 = 150

            else:
                # First window (or memory disabled): label the whole window
                seq = [387] + triplets + [388]
                shift1 = 0
                shift2 = 150

            for m in tqdm.tqdm(melody_chords[i+shift1:i+shift2+150], desc='Block # '+str((i // 150)+1)+' / '+str((len(melody_chords) // 150)-1)):

                seq.extend(m)

                # The start triplet is not a note: it gets a 386 filler and no label
                if m != [386, 386, 386]:

                    x = torch.LongTensor(seq).cuda()

                    with ctx:
                        # Generate exactly one token; top_k with k=1 keeps only the most likely token
                        out = model.generate(x,
                                             1,
                                             temperature=1.0,
                                             #filter_logits_fn=top_p,
                                             #filter_kwargs={'thres': 0.96},
                                             filter_logits_fn=top_k,
                                             filter_kwargs={'k': 1},
                                             return_prime=False,
                                             verbose=False)

                    y = out.tolist()[0][0]

                    seq.append(y)
                    hands_labels.append(y-384)

                else:
                    seq.append(386)

            # Keep only the newly produced notes: skip the 902-token prompt and,
            # with memory, the 600 seeded tokens as well
            if i > 0 and memory:
                song.extend(seq[902+600:])

            else:
                song.extend(seq[902:])

        else:
            break

#================================================================================

    # Tail: notes not yet covered by song (4 tokens per note) are labelled
    # using the last 300 triplets as the prompt
    if len(melody_chords)-(len(song) // 4) != 0:

        triplets = TMIDIX.flatten(melody_chords[-300:])

        if memory:
            seq = [387] + triplets + [388] + song[-600:]
            shift = -150

        else:
            seq = [387] + triplets + [388]
            shift = -300

        for m in tqdm.tqdm(melody_chords[shift:], desc='Last block'):

            seq.extend(m)

            x = torch.LongTensor(seq).cuda()

            with ctx:
                out = model.generate(x,
                                     1,
                                     temperature=1.0,
                                     #filter_logits_fn=top_p,
                                     #filter_kwargs={'thres': 0.96},
                                     filter_logits_fn=top_k,
                                     filter_kwargs={'k': 1},
                                     return_prime=False,
                                     verbose=False)

            y = out.tolist()[0][0]

            seq.append(y)
            hands_labels_last.append(y-384)

        # Number of notes still missing from song; only that many are taken
        # from the end of this block's output
        triplets_remainder = len(melody_chords)-(len(song) // 4)

        song.extend(seq[902:][-(triplets_remainder*4):])
        hands_labels.extend(hands_labels_last[-triplets_remainder:])

# Short compositions (fewer than 300 triplets): repeat the material up to 300 triplets
else:

    # NOTE(review): melody_chords_copy is modified here and used for padding_factor,
    # but the padded sequence below is built from melody_chords, not from the copy —
    # confirm which one was intended.
    melody_chords_copy = copy.deepcopy(melody_chords)
    melody_chords_copy[-1][0] = 127

    padding_factor = ((len(melody_chords_copy) // 300)+1)

    melody_chords_padded = melody_chords + melody_chords * padding_factor

    melody_chords_padded = melody_chords_padded[:300]

    triplets = TMIDIX.flatten(melody_chords_padded)

    seq = [387] + triplets + [388]

    for m in tqdm.tqdm(melody_chords_padded, desc='Short block'):

        seq.extend(m)

        x = torch.LongTensor(seq).cuda()

        with ctx:
            out = model.generate(x,
                                 1,
                                 temperature=1.0,
                                 #filter_logits_fn=top_p,
                                 #filter_kwargs={'thres': 0.96},
                                 filter_logits_fn=top_k,
                                 filter_kwargs={'k': 1},
                                 return_prime=False,
                                 verbose=False)

        y = out.tolist()[0][0]

        seq.append(y)
        hands_labels_last.append(y-384)

    # Keep only the part that corresponds to the real (unpadded) material;
    # index 0 of the labels belongs to the start triplet and is skipped
    song.extend(seq[902:][:len(melody_chords)*4])
    hands_labels.extend(hands_labels_last[1:len(melody_chords)])

In [None]:
# Apply generated hands labels to the original MIDI

# Label k goes to note k: original_score_notes is in the same
# (start time, descending pitch) order the labels were generated in.
for note_idx, hand in enumerate(hands_labels):
    labelled_note = original_score_notes[note_idx]
    labelled_note[3] = hand # applying label to channel

# Merge non-note events back in and restore chronological order (stable sort)
original_score = original_score_events + original_score_notes
original_score.sort(key=lambda ev: ev[1])

midi_data = TMIDIX.score2midi([original_score_ticks, original_score])

# Written relative to the current working directory
with open('original_score_MIDI_with_hands_labels.mid', 'wb') as out_file:
    out_file.write(midi_data)

In [None]:
# Or save the ms version of the original MIDI with piano parts split
#
# Decodes the generated token stream in `song` back into notes and writes
# them to a MIDI file, with the predicted hand stored in the channel field.

print('Sample INTs', song[:15])

if song:

    song_f = []

    # Running decoder state; a note is emitted each time a hand token arrives
    time = 0
    dur = 0
    vel = 90
    pitch = 0
    channel = 0

    patches = [0] * 16

    for ss in song:

        # The token ranges are disjoint, so an elif chain is equivalent to
        # testing each range independently.
        if 0 <= ss < 128:
            # delta start time, accumulated into absolute time
            time += ss

        elif 128 <= ss < 256:
            dur = ss-128

        elif 256 <= ss < 384:
            pitch = ss-256
            # velocity follows pitch, with a floor of 40
            vel = max(40, pitch)

        elif 384 <= ss < 386:
            # hand label closes the note
            channel = ss-384

            song_f.append(['note', time, dur, channel, pitch, vel, 0])

    # The conversion lives inside the guard: song_f and patches exist only
    # here, so running it for an empty song would raise NameError or silently
    # reuse a stale song_f left over from the training cell.
    detailed_stats = TMIDIX.Tegridy_ms_SONG_to_MIDI_Converter(song_f,
                                                              output_signature = 'Piano Hands Music Transformer',
                                                              output_file_name = '/home/ubuntu/Piano-Hands-Music-Transformer-Composition',
                                                              track_name='Project Los Angeles',
                                                              list_of_MIDI_patches=patches,
                                                              timings_multiplier=32,
                                                              )

    print('Done!')

else:
    print('Nothing to save: song is empty')

# Plot token embeddings

In [None]:
# Pairwise cosine distances between all token embedding vectors, drawn as a heat map

embedding_weights = model.net.token_emb.emb.weight
tok_emb = embedding_weights.detach().cpu().tolist()

# pairwise_distances with metric='cosine' returns distances, one row/column per token
cos_sim = metrics.pairwise_distances(tok_emb, metric='cosine')

plt.figure(figsize=(7, 7))
plt.imshow(cos_sim, cmap="inferno", interpolation="nearest")

# Scale the colour bar with the image aspect ratio
n_rows, n_cols = cos_sim.shape[0], cos_sim.shape[1]
im_ratio = n_rows / n_cols
plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)

plt.xlabel("Position")
plt.ylabel("Position")
plt.tight_layout()
plt.plot()
plt.savefig("/home/ubuntu/Piano-Hands-Music-Transformer-Tokens-Embeddings-Plot.png", bbox_inches="tight")

# Congrats! You did it! :)