***

# Perceiver Multi-instrumental Maker (Version 1.0)

https://github.com/asigalov61/tegridy-tools

https://github.com/lucidrains/perceiver-ar-pytorch

## Project Los Angeles
## Tegridy Code 2022

***

# GPU check

In [None]:
# Show the visible NVIDIA GPU(s); later cells move the model and data to CUDA.
!nvidia-smi

# Install dependencies and import modules

In [None]:
# Clone tegridy-tools (provides TMIDIX and the Perceiver-AR code imported below).
# NOTE(review): clones into the current directory; the import cell expects the
# repo at /notebooks/tegridy-tools, so this presumably must run from /notebooks.
!git clone https://github.com/asigalov61/tegridy-tools

In [None]:
!pip install einops
!pip install torch-summary
!pip install sklearn

In [None]:
# Load modules and make data dir

print('Loading modules...')

import pickle
import os
import random
import secrets
import tqdm

import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

%cd /notebooks/tegridy-tools/tegridy-tools/

import TMIDIX

%cd /notebooks/tegridy-tools/tegridy-tools/Perceiver-AR/

from perceiver_ar_pytorch import PerceiverAR
from autoregressive_wrapper import AutoregressiveWrapper

%cd /notebooks/

if not os.path.exists('/notebooks/INTS'):
    os.makedirs('/notebooks/INTS')

print('Done')

# Download and unzip training data

In [None]:
# Perceiver Multi-Instrumental Training Data Pack
%cd /notebooks/INTS/
!wget --no-check-certificate -O 'Perceiver-MI-Training-Data.zip' "https://onedrive.live.com/download?cid=8A0D502FC99C608F&resid=8A0D502FC99C608F%2118738&authkey=AF6k171kUSX-Yrk"
!unzip 'Perceiver-MI-Training-Data.zip'
!rm 'Perceiver-MI-Training-Data.zip'
%cd /notebooks/

# Load training data

In [None]:
# Load training data
#
# Concatenates every pickled file found under dataset_addr (visited in sorted
# path order) into one tensor, train_data.

dataset_addr = "/notebooks/INTS"

filez = list()
for (dirpath, dirnames, filenames) in os.walk(dataset_addr):
    filez += [os.path.join(dirpath, file) for file in filenames]
print('=' * 70)

filez.sort()

print('Loading training data... Please wait...')

# Collect per-file tensors and concatenate once at the end. Calling torch.cat
# inside the loop re-copies everything loaded so far for every file (quadratic).
chunks = []

for f in tqdm.tqdm(filez):
    # SECURITY NOTE: pickle.load can execute arbitrary code; only load files
    # from a source you trust.
    with open(f, 'rb') as fh:
        chunks.append(torch.Tensor(pickle.load(fh)))
    print('Loaded file:', f)

# An empty directory yields an empty tensor, as before.
train_data = torch.cat(chunks) if chunks else torch.Tensor()

print('Done!')

In [None]:
# Size of train_data along its first dimension (presumably the total token
# count, assuming the loaded data is 1-D — confirm).
len(train_data)

In [None]:
# Sanity check: show the first and last 15 entries of the loaded data.
train_data[:15], train_data[-15:]

# Initialize the Model

In [None]:
# Setup model

# constants

SEQ_LEN = 32768                  # 8192 * 4, i.e. 32k tokens per window
PREFIX_SEQ_LEN = SEQ_LEN - 1024  # passed to PerceiverAR as cross_attn_seq_len
BATCH_SIZE = 4

# Number of outer training steps: SEQ_LEN-sized windows in train_data,
# divided by the batch size.
NUM_BATCHES = len(train_data) // SEQ_LEN // BATCH_SIZE

GRADIENT_ACCUMULATE_EVERY = 4    # batches drawn per optimizer step
LEARNING_RATE = 2e-4

# Reporting cadence, in outer steps.
VALIDATE_EVERY = 100
GENERATE_EVERY = 200
SAVE_EVERY = 2000

GENERATE_LENGTH = 32             # tokens sampled for each progress sample

# helpers

def cycle(loader):
    """Yield items from ``loader`` forever, re-iterating it after each pass.

    Unlike ``itertools.cycle`` nothing is cached: ``loader`` is iterated
    afresh on every pass, so an iterable that changes between passes (e.g. a
    shuffling DataLoader) is honoured.

    Raises:
        ValueError: if a complete pass over ``loader`` yields nothing. Without
            this check the generator would spin forever without producing a
            value, hanging the caller's ``next()``.
    """
    while True:
        produced = False
        for data in loader:
            produced = True
            yield data
        if not produced:
            raise ValueError('cycle() got an empty loader; nothing to yield')

# instantiate model

# Perceiver-AR over a 512-token vocabulary, wrapped so that calling the model
# returns (loss, acc) for training and exposes generate() for sampling.
model = AutoregressiveWrapper(
    PerceiverAR(
        num_tokens=512,
        dim=1024,
        depth=24,
        heads=16,
        dim_head=64,
        cross_attn_dropout=0.5,
        max_seq_len=SEQ_LEN,
        cross_attn_seq_len=PREFIX_SEQ_LEN,
    )
)
model.cuda()

print('Done!')

summary(model)

# Dataloader

class MusicDataset(Dataset):
    """Non-overlapping fixed-length windows over a flat token tensor.

    Item ``i`` is ``data[i * seq_len : i * seq_len + seq_len + 1]`` as ``long``:
    ``seq_len + 1`` tokens, so a window carries its inputs plus one extra
    token. Windows are taken consecutively rather than at random offsets (the
    author found consecutive sampling worked better at long sequence lengths).

    Args:
        data: token tensor indexed along dim 0.
        seq_len: window length; every item has ``seq_len + 1`` tokens.
        device: device each item is moved to (default ``'cuda'``, as before);
            ``None`` leaves items on ``data``'s device.
    """

    def __init__(self, data, seq_len, device='cuda'):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.device = device

    def __getitem__(self, index):
        idx = index * self.seq_len
        full_seq = self.data[idx: idx + self.seq_len + 1].long()
        if self.device is not None:
            full_seq = full_seq.to(self.device)
        return full_seq

    def __len__(self):
        # Window i needs tokens up to i*seq_len + seq_len (inclusive), so the
        # count of full windows is (N - 1) // seq_len; clamp at 0 so short
        # data gives an empty dataset instead of a negative length.
        return max(0, (self.data.size(0) - 1) // self.seq_len)

# NOTE(review): val_dataset wraps the same train_data as train_dataset, so the
# "validation" metrics reported during training are measured on training
# windows, not held-out data.
train_dataset = MusicDataset(train_data, SEQ_LEN)
val_dataset   = MusicDataset(train_data, SEQ_LEN)
# cycle() turns each finite DataLoader into an endless stream consumed via next().
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

# NOTE: this rebinds `optim`, shadowing the `torch.optim` module alias imported
# earlier; the training cell calls optim.step() / optim.zero_grad() on it.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# TRAIN

In [None]:
# Train the model
#
# Each outer step accumulates gradients over GRADIENT_ACCUMULATE_EVERY
# batches, clips them to norm 0.5 and takes one optimizer step. Periodically
# it evaluates one validation batch (and plots every metric history), prints
# a short generated sample, and saves a checkpoint under /notebooks/.

train_losses = []
val_losses = []

train_accs = []
val_accs = []


def _plot_history(values, title):
    """Show ``values`` as an inline line plot, then close the figure."""
    print(f'Plotting {title} graph...')
    plt.plot(range(len(values)), values, 'b')
    plt.show()
    plt.close()
    print('Done!')


for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):
    model.train()

    step_loss = 0.
    step_acc = 0.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss, acc = model(next(train_loader))
        # Divide before backward() so the accumulated gradient is the mean over
        # micro-batches rather than their sum. Otherwise its scale, and how
        # often the clip threshold below triggers, grows with
        # GRADIENT_ACCUMULATE_EVERY.
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()
        # Report the mean over all micro-batches, not just the last one.
        step_loss += loss.item() / GRADIENT_ACCUMULATE_EVERY
        step_acc += acc.item() / GRADIENT_ACCUMULATE_EVERY

    print(f'Training loss: {step_loss}')
    print(f'Training acc: {step_acc}')

    train_losses.append(step_loss)
    train_accs.append(step_acc)

    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = model(next(val_loader))
        print(f'Validation loss: {val_loss.item()}')
        print(f'Validation acc: {val_acc.item()}')
        val_losses.append(val_loss.item())
        val_accs.append(val_acc.item())

        _plot_history(train_losses, 'training loss')
        _plot_history(train_accs, 'training acc')
        _plot_history(val_losses, 'validation loss')
        _plot_history(val_accs, 'validation acc')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Drop the window's extra last token and use the rest as the prompt.
        inp = random.choice(val_dataset)[:-1]

        print(inp)

        # Sampling needs no autograd graph.
        with torch.no_grad():
            sample = model.generate(inp[None, ...], GENERATE_LENGTH)

        print(sample)

    if i % SAVE_EVERY == 0:
        print('Saving model progress. Please wait...')
        ckpt_name = ('model_checkpoint_' + str(i) + '_steps_'
                     + str(round(float(train_losses[-1]), 4)) + '_loss.pth')
        print(ckpt_name)
        torch.save(model.state_dict(), '/notebooks/' + ckpt_name)
        print('Done!')

# Save stats graphs

In [None]:
# Save the four metric histories collected during training as PNG files,
# one figure per history.
for history, png_path in (
    (train_losses, '/notebooks/training_loss_graph.png'),
    (train_accs, '/notebooks/training_acc_graph.png'),
    (val_losses, '/notebooks/validation_loss_graph.png'),
    (val_accs, '/notebooks/validation_acc_graph.png'),
):
    plt.plot(list(range(len(history))), history, 'b')
    plt.savefig(png_path)
    plt.close()
    print('Done!')

# Load/Reload the Model

In [None]:
# Load model

# constants (should match the values the checkpoint was trained with)

SEQ_LEN = 32768                  # 8192 * 4 = 32k
PREFIX_SEQ_LEN = SEQ_LEN - 1024  # cross_attn_seq_len

# Rebuild the network with the training hyperparameters, then restore weights.
model = AutoregressiveWrapper(
    PerceiverAR(
        num_tokens=512,
        dim=1024,
        depth=24,
        heads=16,
        dim_head=64,
        cross_attn_dropout=0.5,
        max_seq_len=SEQ_LEN,
        cross_attn_seq_len=PREFIX_SEQ_LEN,
    )
)
model.cuda()

# Relative path: resolved against the current directory, which earlier cells
# leave at /notebooks/ (where the training cell saves checkpoints).
state_dict = torch.load('model_checkpoint_0_steps_6.8286_loss.pth')
model.load_state_dict(state_dict)

model.eval()

# Model stats

summary(model)

In [None]:
# Plot Token Embeddings
#
# Heat map of pairwise cosine similarity between the learned token embedding
# vectors. Assumes token_emb.weight is (num_tokens, dim), one row per token id
# as in nn.Embedding — confirm. The previous code took only row 0 and compared
# its individual scalars, which is not a meaningful similarity.

token_embeddings = model.net.token_emb.weight.detach().cpu().numpy()
cos_sim = metrics.pairwise.cosine_similarity(token_embeddings)
plt.figure(figsize=(8, 8))
plt.imshow(cos_sim, cmap="inferno", interpolation="none")
im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
plt.xlabel("Token")
plt.ylabel("Token")
plt.tight_layout()
# NOTE: the file name says "Positional" but the plot shows token embeddings;
# the path is kept unchanged so existing references to it still work.
plt.savefig("/notebooks/Perceiver-Positional-Embeddings-Plot.png", bbox_inches="tight")
plt.show()

# EVAL

In [None]:
# Custom MIDI option
#
# Reads a MIDI file, maps each note onto one of the 12 instrument channels in
# patch_map (notes on channel 9, the drums, keep their channel), quantizes
# timings and encodes every note as 4 tokens in `inputs`, which the Generate
# cell uses as its prompt.

full_path_to_custom_MIDI_file = "/notebooks/tegridy-tools/tegridy-tools/seed2.mid" #@param {type:"string"}

print('Loading custom MIDI file...')
with open(full_path_to_custom_MIDI_file, 'rb') as midi_file:
    score = TMIDIX.midi2ms_score(midi_file.read())

# Current patch (program) of each of the 16 MIDI channels, updated as
# 'patch_change' events are replayed in time order.
patches = [0] * 16

# Output channel i is used for any source patch listed in patch_map[i].
patch_map = [[0, 1, 2, 3, 4, 5, 6, 7], # Piano 
              [24, 25, 26, 27, 28, 29, 30], # Guitar
              [32, 33, 34, 35, 36, 37, 38, 39], # Bass
              [40, 41], # Violin
              [42, 43], # Cello
              [46], # Harp
              [56, 57, 58, 59, 60], # Trumpet
              [71, 72], # Clarinet
              [73, 74, 75], # Flute
              [-1], # Fake Drums
              [52, 53], # Choir
              [16, 17, 18, 19, 20] # Organ
            ]

# Note and patch-change events from every track except score[0], ordered by
# start time (stable sort, so same-time events keep track order).
events_matrix = [event
                 for track in score[1:]
                 for event in track
                 if event[0] in ('note', 'patch_change')]
events_matrix.sort(key=lambda x: x[1])

events_matrix1 = []
for event in events_matrix:
    if event[0] == 'patch_change':
        # [..., channel, patch]: remember the channel's current patch.
        patches[event[2]] = event[3]

    if event[0] == 'note':
        # Append the note's current patch as event[6].
        event.append(patches[event[3]])

        if event[3] != 9:  # the drum channel is never remapped
            for chan_idx, group in enumerate(patch_map):
                if event[6] in group:  # groups are disjoint: first hit is the only hit
                    event[3] = chan_idx
                    break
            else:
                # Unmapped patches go to channel 0 with a velocity floor of 80.
                event[3] = 0
                event[5] = max(80, event[5])

        # After remapping every channel is already < 12; kept as a safeguard.
        if event[3] < 12:
            events_matrix1.append(event)

# Sorting...
events_matrix1.sort(key=lambda x: (x[1], x[3]))

# Quantize: start times to units of 16 and durations to units of 32
# (midi2ms_score presumably yields milliseconds — confirm).
for e in events_matrix1:
    e[1] = int(e[1] / 16)
    e[2] = int(e[2] / 32)

if not events_matrix1:
    raise ValueError('No usable note events found in ' + full_path_to_custom_MIDI_file)

# final processing...

inputs = []

melody = []

melody_chords = []

pe = events_matrix1[0]
for e in events_matrix1:

    # Clamp every field into its token range. `dtime` is the delta from the
    # previous note's start; it is not named `time` so it does not shadow the
    # time module used by the Generate cell.
    dtime = max(0, min(127, e[1] - pe[1]))
    dur = max(1, min(127, e[2]))
    cha = max(0, min(11, e[3]))
    ptc = max(1, min(127, e[4]))
    vel = max(19, min(127, e[5]))

    div_vel = int(vel / 19)  # 1..6 for vel in 19..127

    # Channel and coarse velocity share one token; it stays < 128 because
    # cha <= 11 and div_vel <= 6.
    chan_vel = (cha * 11) + div_vel

    # Continuation / Inpainting
    inputs.extend([chan_vel, dtime+128, dur+256, ptc+384])

    # Melody Orchestration: notes starting later than the previous note,
    # converted to piano, with pitches below 60 folded into the 60-71 octave.
    if dtime != 0:
        mel_ptc = (ptc % 12) + 60 if ptc < 60 else ptc
        melody.extend([div_vel, dtime+128, dur+256, mel_ptc+384])

    # For future development; records the pitch as encoded in `inputs`
    # (the melody transposition above no longer leaks into it).
    melody_chords.append([dtime, dur, cha, ptc, vel])

    pe = e

# =================================
print('Done!')

In [None]:
# Length of the encoded prompt token list (4 tokens per note event).
len(inputs)

In [None]:
# Generate
#
# Builds a SEQ_LEN-token prompt (neutral filler events followed by the start
# of the custom MIDI `inputs`), samples 1024 new tokens and reports the time.
# Only names defined by the Load/Reload cell and the MIDI cell are used, so
# this also works when the training setup cell was never run.

import time

model.eval()

prompt = inputs[:1024]

# Filler event [chan_vel 0, time 127, dur 127, pitch 0] is 4 tokens. Repeat it
# just enough that filler + prompt totals SEQ_LEN tokens (the model's
# max_seq_len) instead of building a list several times longer than the
# model can attend to.
num_filler_events = max(0, (SEQ_LEN - len(prompt)) // 4)
inp = [0, 127+128, 127+256, 0+384] * num_filler_events
inp += prompt

inp = torch.LongTensor(inp).cuda()

print(inp)

start_time = time.time()

# Sampling needs no autograd graph.
with torch.no_grad():
    out = model.generate(inp[None, ...], 
                         1024, 
                         temperature=0.6)

print(time.time() - start_time, "seconds")
print(out)

In [None]:
# Convert to MIDI
#
# Decodes the generated tokens into TMIDIX 'note' events and writes a MIDI
# file. Each event is 4 tokens: chan_vel (< 128), time+128, dur+256,
# pitch+384, mirroring the encoding in the Custom MIDI cell.

out1 = out.cpu().tolist()[0]


def _is_valid_event(tokens):
    """Return True if ``tokens`` is one well-formed 4-token event."""
    return (len(tokens) == 4
            and 0 <= tokens[0] < 128
            and 128 <= tokens[1] < 256
            and 256 <= tokens[2] < 384
            and 384 <= tokens[3] < 512)


if out1:

    song = out1
    song_f = []
    # Running absolute start time; not named `time` so it does not shadow the
    # time module used by the Generate cell.
    abs_time = 0
    dur = 0
    vel = 0
    pitch = 0
    channel = 0

    # Split the stream into events: a token < 128 starts a new event and the
    # following tokens >= 128 complete it. Groups that are not a well-formed
    # event (wrong length or fields out of order/range) are skipped.
    son = []

    song1 = []

    for s in song:
        if s > 127:
            son.append(s)
        else:
            if _is_valid_event(son):
                song1.append(son)
            son = [s]

    # The loop only emits an event when the next one starts; flush the last.
    if _is_valid_event(son):
        song1.append(son)

    for s in song1:

        channel = s[0] // 11

        vel = (s[0] % 11) * 19

        abs_time += (s[1] - 128) * 16  # undo the /16 start-time quantization

        dur = (s[2] - 256) * 32        # undo the /32 duration quantization

        pitch = s[3] - 384

        song_f.append(['note', abs_time, dur, channel, pitch, vel])

    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Perceiver',  
                                                        output_file_name = '/notebooks/Perceiver-Music-Composition', 
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')