https://github.com/lucidrains/perceiver-ar-pytorch

In [None]:
!nvidia-smi

In [None]:
!git clone https://github.com/asigalov61/tegridy-tools

In [None]:
!pip install einops
!pip install torch-summary
!pip install sklearn

In [None]:
import pickle
import os
import random
import secrets
import tqdm

import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

%cd /notebooks/tegridy-tools/tegridy-tools/

import TMIDIX

%cd /notebooks/tegridy-tools/tegridy-tools/Perceiver-AR/

from perceiver_ar_pytorch import PerceiverAR
from autoregressive_wrapper import AutoregressiveWrapper

%cd /notebooks/

# Load training data

In [None]:
dataset_addr = "/notebooks/Euterpe-INTs"

# Collect every file below the dataset directory. Sorting makes the
# concatenation order (and therefore train_data) reproducible.
filez = []
for dirpath, _dirnames, filenames in os.walk(dataset_addr):
    filez.extend(os.path.join(dirpath, name) for name in filenames)
print('=' * 70)

filez.sort()

print('Processing MIDI files. Please wait...')

# Seed with an empty tensor so train_data is still defined (and empty)
# when the directory contains no files.
chunks = [torch.Tensor()]

for f in tqdm.tqdm(filez):
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    # The context manager closes each file handle as soon as it is read.
    with open(f, 'rb') as fin:
        chunks.append(torch.Tensor(pickle.load(fin)))
    print('Loaded file:', f)

# Concatenate once at the end instead of re-copying the growing tensor
# for every file.
train_data = torch.cat(chunks)
del chunks

In [None]:
# Size of train_data along its first dimension (shown as the cell output)
len(train_data)

In [None]:
# Sanity check: first and last 15 entries of the loaded data
train_data[:15], train_data[-15:]

# Initialize the Model

In [None]:
# Setup model

# constants

# Sequence geometry
SEQ_LEN = 4 * 8192  # 32k
PREFIX_SEQ_LEN = SEQ_LEN - 1024
BATCH_SIZE = 4

# Number of outer training iterations, derived from the amount of loaded data
NUM_BATCHES = (len(train_data) // SEQ_LEN) // BATCH_SIZE

# Optimisation schedule
GRADIENT_ACCUMULATE_EVERY = 4
LEARNING_RATE = 2e-4

# Periodic actions, counted in outer training iterations
VALIDATE_EVERY = 100
GENERATE_EVERY = 200
SAVE_EVERY = 1000

# Number of tokens requested when sampling during training
GENERATE_LENGTH = 32

# helpers

def cycle(loader):
    """Yield items from ``loader`` endlessly, restarting it after each pass.

    ``loader`` must be re-iterable (e.g. a DataLoader or a list). If a full
    pass produces no items, the generator stops instead of spinning forever
    without ever yielding.
    """
    while True:
        produced = False
        for data in loader:
            produced = True
            yield data
        if not produced:
            # Empty loader: another pass would be empty too, so bail out
            # rather than busy-loop.
            return

# instantiate model

# Build the Perceiver AR network and wrap it in AutoregressiveWrapper in
# one expression, then move the wrapped model to the GPU.
model = AutoregressiveWrapper(
    PerceiverAR(
        num_tokens=512,
        dim=1024,
        depth=16,
        heads=8,
        dim_head=64,
        cross_attn_dropout=0.5,
        max_seq_len=SEQ_LEN,
        cross_attn_seq_len=PREFIX_SEQ_LEN,
    )
)
model.cuda()

print('Done!')

# Print the layer/parameter summary of the wrapped model
summary(model)

# prepare music dataset

class MusicDataset(Dataset):
    """Consecutive, fixed-stride windows over a token tensor.

    Item ``i`` is ``data[i * seq_len : i * seq_len + seq_len + 1]`` cast to
    int64 and moved to the GPU, i.e. ``seq_len + 1`` entries per item
    (presumably so the wrapper can shift inputs against targets — confirm
    against AutoregressiveWrapper).

    Args:
        data: tensor indexed along its first dimension.
        seq_len: stride between consecutive windows.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return window ``index`` as a CUDA int64 tensor.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``. Without
                this check an out-of-range index silently yields a short or
                empty slice.
        """
        if not 0 <= index < len(self):
            raise IndexError(f'index {index} out of range for {len(self)} windows')

        # Consecutive sampling seems to be better at 64k seq_len than
        # picking a random window per item.
        start = index * self.seq_len

        full_seq = self.data[start: start + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        """Return the number of windows, holding one back so every window has its extra entry.

        Clamped at 0: data shorter than two windows would otherwise give a
        negative length, which makes ``len()`` raise ValueError.
        """
        return max(0, self.data.size(0) // self.seq_len - 1)

# NOTE(review): both datasets wrap the same train_data tensor, so the
# "validation" metrics are computed on training data.
train_dataset = MusicDataset(data=train_data, seq_len=SEQ_LEN)
val_dataset = MusicDataset(data=train_data, seq_len=SEQ_LEN)

# Wrap each loader in cycle() so next() can be called on it indefinitely
train_loader = cycle(DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE))
val_loader = cycle(DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE))

# optimizer
# NOTE(review): this rebinds the name `optim`, which the import cell bound
# to the torch.optim module; the training loop relies on the new binding.

optim = torch.optim.Adam(params=model.parameters(), lr=LEARNING_RATE)

# TRAIN

In [None]:
# Train the model

def _show_curve(label, values):
    """Show one metric history as a blue line plot, then close the figure."""
    print(f'Plotting {label} graph...')
    plt.plot(list(range(len(values))), values, 'b')
    plt.show()
    plt.close()
    print('Done!')


# Metric histories: training values get one entry per iteration,
# validation values one entry per validation pass.
train_losses, val_losses = [], []
train_accs, val_accs = [], []

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):
    model.train()

    # Gradients of all micro-batches accumulate in .grad before the single
    # optimizer step below.
    # NOTE(review): the loss is not divided by GRADIENT_ACCUMULATE_EVERY,
    # so the accumulated gradient is a sum, not a mean — confirm intended.
    for _micro_step in range(GRADIENT_ACCUMULATE_EVERY):
        loss, acc = model(next(train_loader))
        loss.backward()

    # Only the last micro-batch's metrics are reported and recorded.
    step_loss = loss.item()
    step_acc = acc.item()

    print(f'Training loss: {step_loss}')
    print(f'Training acc: {step_acc}')

    train_losses.append(step_loss)
    train_accs.append(step_acc)

    # Clip the global gradient norm, apply the update, reset gradients.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = model(next(val_loader))

        print(f'Validation loss: {val_loss.item()}')
        print(f'Validation acc: {val_acc.item()}')
        val_losses.append(val_loss.item())
        val_accs.append(val_acc.item())

        _show_curve('training loss', train_losses)
        _show_curve('training acc', train_accs)
        _show_curve('validation loss', val_losses)
        _show_curve('validation acc', val_accs)

    if i % GENERATE_EVERY == 0:
        model.eval()

        # Prompt with a random dataset window, minus its last entry.
        inp = random.choice(val_dataset)[:-1]

        print(inp)

        sample = model.generate(inp[None, ...], GENERATE_LENGTH)

        print(sample)

    if i % SAVE_EVERY == 0:

        print('Saving model progress. Please wait...')

        # The file name records the iteration and the latest training loss.
        ckpt_name = f'model_checkpoint_{i}_steps_{round(float(train_losses[-1]), 4)}_loss.pth'
        print(ckpt_name)

        fname = '/notebooks/' + ckpt_name

        # Only the model weights are stored; optimizer state is not saved.
        torch.save(model.state_dict(), fname)

        print('Done!')

In [None]:
# Save training loss graph

# x axis is the index of each recorded training loss
plt.plot(range(len(train_losses)), train_losses, 'b')
plt.savefig('/notebooks/training_loss_graph.png')
print('Done!')

In [None]:
# Save validation loss graph

# x axis is the index of each recorded validation loss
plt.plot(range(len(val_losses)), val_losses, 'b')
plt.savefig('/notebooks/validation_loss_graph.png')
print('Done!')

# Load/Reload the Model

In [None]:
# Load model

# constants (same values as in the training cell)

SEQ_LEN = 4 * 8192  # 32k
PREFIX_SEQ_LEN = SEQ_LEN - 1024

# Rebuild the network with the same arguments as the training cell, wrap
# it, and move it to the GPU before restoring the weights.
model = AutoregressiveWrapper(
    PerceiverAR(
        num_tokens=512,
        dim=1024,
        depth=16,
        heads=8,
        dim_head=64,
        cross_attn_dropout=0.5,
        max_seq_len=SEQ_LEN,
        cross_attn_seq_len=PREFIX_SEQ_LEN,
    )
)
model.cuda()

# Checkpoint path is relative to the current working directory
state_dict = torch.load('model_checkpoint_2000_steps_1.218_loss.pth')
model.load_state_dict(state_dict)

model.eval()

# Model stats

summary(model)

In [None]:
# Plot pairwise cosine similarity of the token embeddings

# One embedding row per token id, so cos_sim[a, b] compares tokens a and b.
cos_sim = metrics.pairwise.cosine_similarity(
   model.net.token_emb.weight.detach().cpu().numpy()
)
plt.figure(figsize=(8, 8))
plt.imshow(cos_sim, cmap="inferno", interpolation="none")
# Scale the colorbar with the image aspect ratio so it matches the image height
im_ratio = cos_sim.shape[0] / cos_sim.shape[1]
plt.colorbar(fraction=0.046 * im_ratio, pad=0.04)
# Both axes index rows of the token embedding table, not sequence positions.
plt.xlabel("Token ID")
plt.ylabel("Token ID")
plt.tight_layout()
# NOTE(review): plot() with no data draws nothing; likely removable — confirm.
plt.plot()
# The output file name is kept as-is so existing references to it still work.
plt.savefig("/notebooks/Euterpe-Positional-Embeddings-Plot.png", bbox_inches="tight")

# EVAL

In [None]:
import time

model.eval()

# Prompt: the first dataset window with its last 512 entries removed
inp = val_dataset[0][:-512]

print(inp)

start_time = time.time()

# Add a leading batch dimension and ask generate() for 512 tokens
out = model.generate(inp.unsqueeze(0), 512, temperature=1)

elapsed = time.time() - start_time
print(elapsed, "seconds")
print(out)

In [None]:
def _group_note_tokens(tokens):
    """Split a flat token list into 4-token note groups.

    A token <= 127 opens a new group; tokens > 127 are appended to the group
    currently being built. Only groups holding exactly four tokens are kept.
    """
    groups = []
    current = []
    for tok in tokens:
        if tok <= 127:
            if len(current) == 4:
                groups.append(current)
            current = [tok]
        else:
            current.append(tok)
    # Flush the trailing group: no later opening token exists to trigger the
    # check inside the loop, so the last note would otherwise be dropped.
    if len(current) == 4:
        groups.append(current)
    return groups


def _notes_to_events(groups):
    """Decode 4-token groups into ['note', start, duration, channel, pitch, velocity] rows."""
    events = []
    start = 0
    for lead, delta_tok, dur_tok, pitch_tok in groups:
        # The first token packs channel and velocity together.
        channel = lead // 11
        velocity = (lead % 11) * 19

        # Start times accumulate across notes; each group carries a delta.
        start += (delta_tok - 128) * 16

        duration = (dur_tok - 256) * 32
        pitch = pitch_tok - 384

        events.append(['note', start, duration, channel, pitch, velocity])
    return events


# First (only) sequence of the generated batch as a plain Python list
out1 = out.cpu().tolist()[0]

if out1:

    # The running start time lives inside the helper, so the `time` module
    # imported by the generation cell is no longer overwritten by an int.
    song_f = _notes_to_events(_group_note_tokens(out1))

    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Euterpe',  
                                                        output_file_name = '/notebooks/Euterpe-Music-Composition', 
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 16, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')