# Perceiver Music Transformer Maker (ver. 0.5)

***

Powered by tegridy-tools: https://github.com/asigalov61/tegridy-tools

***

WARNING: This complete implementation is a functioning model of Artificial Intelligence. Please exercise great humility, care, and respect. https://www.nscai.gov/

***

#### Project Los Angeles

#### Tegridy Code 2023

***

# GPU check

In [None]:
# Print the NVIDIA driver/GPU status report to confirm a GPU is attached to this runtime
!nvidia-smi

# Setup environment

In [None]:
# Fetch tegridy-tools; later cells import TMIDIX and the Perceiver-AR code from this checkout
!git clone https://github.com/asigalov61/tegridy-tools

In [None]:
!pip install torch
!pip install einops
!pip install torch-summary
!pip install sklearn
!pip install tqdm
!pip install matplotlib

In [None]:
# Load modules and make data dir

print('Loading modules...')

import os
import pickle
import random
import secrets
import tqdm

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import matplotlib.pyplot as plt

from torchsummary import summary
from sklearn import metrics

%cd /content/tegridy-tools/tegridy-tools/

import TMIDIX

%cd /content/tegridy-tools/tegridy-tools/Perceiver-AR

from perceiver_ar_pytorch_full import PerceiverAR, AutoregressiveWrapper

%cd /content/

# Create the dataset directory if it is missing; exist_ok avoids the
# check-then-create race of a separate os.path.exists() test
os.makedirs('/content/INTS', exist_ok=True)

print('Done')

# Load training data

In [None]:
# Load training data

dataset_addr = "/content/INTS"

# Collect every file found under the dataset directory (recursively)
filez = list()
for (dirpath, dirnames, filenames) in os.walk(dataset_addr):
    filez += [os.path.join(dirpath, file) for file in filenames]
print('=' * 70)

# Sort so the files are always loaded in the same order
filez.sort()

print('Loading training data... Please wait...')

# Quick dirty hack: four leading zeros offset the training data for proper loading.
# The per-file tensors are gathered in a list and concatenated once at the end;
# calling torch.cat inside the loop re-copies everything loaded so far for each file.
chunks = [torch.Tensor([0, 0, 0, 0])]

for f in tqdm.tqdm(filez):
    # SECURITY: pickle.load can execute arbitrary code - only load trusted files.
    # The context manager closes each file handle as soon as it has been read.
    with open(f, 'rb') as fh:
        chunks.append(torch.Tensor(pickle.load(fh)))
    print('Loaded file:', f)

train_data = torch.cat(chunks)
del chunks  # drop the per-file tensors now that they are merged

print('Done!')

In [None]:
# Total number of loaded tokens (this count includes the 4 leading offset zeros)
len(train_data)

In [None]:
# Number of full batches available for 6144-token sequences at 32 sequences per batch
# (the same values the setup cell assigns to SEQ_LEN and BATCH_SIZE); the result is a float
((len(train_data) // 6144 // 32) * 32) / 32

In [None]:
# Peek at the first and last 15 tokens as a quick sanity check
train_data[:15], train_data[-15:]

# Setup model

In [None]:
# Setup model

# constants

SEQ_LEN = 6144 # 6k tokens per training sequence; also passed to the model as max_seq_len
PREFIX_SEQ_LEN = 4096 # 4k tokens; passed to the model as cross_attn_seq_len
BATCH_SIZE = 32 # sequences per batch

GRADIENT_ACCUMULATE_EVERY = 1 # forward/backward passes per optimizer step

# We are going to train for 1 full epoch because we are using consecutive sampling.
# The number of steps is calculated from the training data length:
# whole sequences // batch size // gradient accumulation steps.

NUM_BATCHES = (len(train_data) // SEQ_LEN // BATCH_SIZE // GRADIENT_ACCUMULATE_EVERY)

LEARNING_RATE = 2e-4

# The intervals below are measured in training-loop iterations
VALIDATE_EVERY  = 100
SAVE_EVERY = 1000
GENERATE_EVERY  = 200
PRINT_STATS_EVERY = 50

GENERATE_LENGTH = 32 # length argument passed to generate() for the in-training samples

# helpers

def cycle(loader):
    """Yield the items of `loader` endlessly, restarting it whenever it is exhausted."""
    while True:
        yield from loader

# instantiate model

# Build the Perceiver AR network; the two sequence lengths come from the
# SEQ_LEN / PREFIX_SEQ_LEN constants defined earlier in this cell.
model = PerceiverAR(
    num_tokens = 2145, # vocabulary size - presumably matches the token encoding of the INTs data; TODO confirm
    dim = 1024,
    depth = 32,
    ff_mult=2,
    cross_attn_dropout = 0.25,
    max_seq_len = SEQ_LEN,
    cross_attn_seq_len = PREFIX_SEQ_LEN
)

# Wrap the network: the training loop unpacks the wrapper's output as (loss, acc)
# and uses its generate() method for sampling.
model = AutoregressiveWrapper(model)

# DataParallel wrapper; the wrapped module is reached through model.module later on
model = torch.nn.DataParallel(model)

# Move the model to the GPU
model.cuda()

print('Done!')
      
# Print the model summary (torchsummary)
summary(model)

# Dataloader

class MusicDataset(Dataset):
    """Serve fixed-stride, back-to-back windows of a 1-D token tensor.

    Item ``k`` is ``data[k * seq_len : k * seq_len + seq_len + 1]`` cast to
    ``long``: ``seq_len`` tokens plus one extra trailing token (presumably so
    the model wrapper can form shifted input/target pairs - TODO confirm).

    Args:
        data: 1-D tensor of token values.
        seq_len: Stride between windows; every item holds ``seq_len + 1`` tokens.
        batch_size: Batch size used to round the dataset length down to a whole
            number of batches. ``None`` (default) reads the module-level
            ``BATCH_SIZE`` each time ``len()`` is evaluated, as before.
        device: Device every item is moved to. Defaults to ``'cuda'``, which
            matches the previously hard-coded ``.cuda()`` call.

    Note:
        ``__getitem__`` performs no bounds check; an index past ``len(self)``
        yields a short or empty tensor rather than raising ``IndexError``.
    """

    def __init__(self, data, seq_len, batch_size=None, device='cuda'):
        super().__init__()
        self.data = data
        self.seq_len = seq_len
        self.batch_size = batch_size
        self.device = device

    def __getitem__(self, index):
        # consecutive sampling: windows start at multiples of seq_len
        start = index * self.seq_len
        window = self.data[start: start + self.seq_len + 1].long()
        return window.to(self.device)

    def __len__(self):
        batch_size = BATCH_SIZE if self.batch_size is None else self.batch_size
        # Each window needs seq_len + 1 tokens, so one token is reserved before
        # counting windows. Without this, data whose length is an exact multiple
        # of seq_len * batch_size would expose a final window that is one token
        # short, which cannot be stacked with the others into a batch.
        usable = max(self.data.size(0) - 1, 0)
        return (usable // self.seq_len // batch_size) * batch_size

# NOTE(review): the validation dataset wraps the same tensor as the training dataset,
# so validation metrics are measured on training data, not on a held-out split.
train_dataset = MusicDataset(train_data, SEQ_LEN)
val_dataset   = MusicDataset(train_data, SEQ_LEN)
# cycle() makes both loaders endless, so next() on them never raises StopIteration.
# No shuffle argument is given, so batches are served in dataset order.
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

# NOTE(review): this rebinds the name `optim`, which the import cell bound to the
# torch.optim module; from here on `optim` refers to the Adam instance.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [None]:
# Inspect a single dataset item (index 666) as a sanity check
train_dataset[666]

In [None]:
# Number of batches in one pass over the validation dataset
len(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# Train

In [None]:
# Train the model

def _plot_history(values, title):
    """Show one metric history as a blue line plot, printing progress messages."""
    print(f'Plotting {title} graph...')
    plt.plot(range(len(values)), values, 'b')
    plt.show()
    plt.close()
    print('Done!')


# Training metrics get one entry per step; validation metrics one entry per validation run
train_losses = []
val_losses = []

train_accs = []
val_accs = []

for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='Training'):
    model.train()

    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss, acc = model(next(train_loader))
        # loss is presumably non-scalar under DataParallel (one value per replica),
        # so backward() is given an explicit all-ones gradient; ones_like creates
        # it directly with loss's shape, dtype and device.
        loss.backward(torch.ones_like(loss))

    # Read the step metrics back from the device once and reuse them below
    step_loss = loss.mean().item()
    step_acc = acc.mean().item()

    if i % PRINT_STATS_EVERY == 0:
        print(f'Training loss: {step_loss}')
        print(f'Training acc: {step_acc}')

    train_losses.append(step_loss)
    train_accs.append(step_acc)

    # Clip the gradient norm to 0.5 before the optimizer step
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            val_loss, val_acc = model(next(val_loader))

        step_val_loss = val_loss.mean().item()
        step_val_acc = val_acc.mean().item()

        print(f'Validation loss: {step_val_loss}')
        print(f'Validation acc: {step_val_acc}')

        val_losses.append(step_val_loss)
        val_accs.append(step_val_acc)

        _plot_history(train_losses, 'training loss')
        _plot_history(train_accs, 'training acc')
        _plot_history(val_losses, 'validation loss')
        _plot_history(val_accs, 'validation acc')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Prime with a random dataset item minus its last token
        inp = random.choice(val_dataset)[:-1]

        print(inp)

        # generate() lives on the wrapped module, not on the DataParallel wrapper
        sample = model.module.generate(inp[None, ...], GENERATE_LENGTH)

        print(sample)

    if i % SAVE_EVERY == 0:

        # Checkpoint name carries the step index and the latest training loss
        ckpt_name = 'model_checkpoint_' + str(i) + '_steps_' + str(round(float(train_losses[-1]), 4)) + '_loss.pth'

        print('Saving model progress. Please wait...')
        print(ckpt_name)

        fname = '/content/' + ckpt_name

        torch.save(model.state_dict(), fname)

        data = [train_losses, train_accs, val_losses, val_accs]

        TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs')

        print('Done!')

# Final Save

In [None]:
print('Saving model progress. Please wait...')

# The name carries the last loop index `i` and the most recent training loss
final_ckpt_name = f'model_checkpoint_{i}_steps_{round(float(train_losses[-1]), 4)}_loss.pth'
print(final_ckpt_name)

fname = '/content/' + final_ckpt_name

torch.save(model.state_dict(), fname)

print('Done!')

In [None]:
# Save each metric history as a PNG: training loss, training acc,
# validation loss and validation acc, in that order
graph_specs = (
    (train_losses, '/content/training_loss_graph.png'),
    (train_accs, '/content/training_acc_graph.png'),
    (val_losses, '/content/validation_loss_graph.png'),
    (val_accs, '/content/validation_acc_graph.png'),
)

for history, png_path in graph_specs:
    plt.plot(list(range(len(history))), history, 'b')
    plt.savefig(png_path)
    plt.close()
    print('Done!')

In [None]:
# Persist the four metric histories through the TMIDIX pickle helper
# (the exact output file name/extension is decided by TMIDIX - verify there)
data = [train_losses, train_accs, val_losses, val_accs]

TMIDIX.Tegridy_Any_Pickle_File_Writer(data, '/content/losses_accs')

# Eval

In [None]:
model.eval()

# Prime with the first 4164 tokens of a random dataset item. Dataset items are
# already long tensors on the GPU, so only a leading batch axis has to be added.
inp = random.choice(val_dataset)[:4164].unsqueeze(0)

print(inp)

sample = model.module.generate(inp, 512, temperature=0.8, return_prime=False)

print(sample)

In [None]:
#@title Convert to MIDI

def _decode_tokens_to_notes(tokens):
    """Turn a flat list of integer tokens into ``['note', ...]`` event lists.

    Token ranges, as decoded below:
        1..255     advance the running start time by ``token * 8``
        256..511   set the duration to ``(token - 256) * 16``
        512..607   set the channel to ``(token - 512) // 8`` and the velocity
                   to ``((token - 512) % 8 + 1) * 15``
        608..2143  set the pitch to ``(token - 608) % 128`` and emit a note

    Tokens outside these ranges (including 0) are ignored.

    Args:
        tokens: Iterable of integer tokens.

    Returns:
        List of ``['note', start_time, duration, channel, pitch, velocity]``.
    """
    notes = []
    start_time = 0
    duration = 0
    velocity = 0
    channel = 0

    for tok in tokens:
        # The ranges are mutually exclusive, so at most one branch fires per token
        if 0 < tok < 256:
            start_time += tok * 8
        elif 256 <= tok < 512:
            duration = (tok - 256) * 16
        elif 512 <= tok < 608:
            channel, vel_step = divmod(tok - 512, 8)
            velocity = (vel_step + 1) * 15
        elif 608 <= tok < 608 + (12 * 128):
            pitch = (tok - 608) % 128
            notes.append(['note', start_time, duration, channel, pitch, velocity])

    return notes


train_data1 = sample[0].tolist()

print('Sample INTs', train_data1[:15])

# Cap the number of tokens that get converted
out = train_data1[:200000]

if out:

    song_f = _decode_tokens_to_notes(out)

    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Perceiver Music Transformer',  
                                                        output_file_name = '/content/Perceiver-Music-Transformer-Music-Composition', 
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')

# Congrats! You did it! :)