https://github.com/lucidrains/sinkhorn-transformer

In [None]:
!nvidia-smi

In [None]:
!pip install sinkhorn_transformer

In [None]:
!git clone https://github.com/asigalov61/tegridy-tools

In [None]:
# TMIDIX is imported straight from the cloned tegridy-tools checkout:
# switch into the directory that contains it, import, then switch back to
# /notebooks/ so later relative paths are unaffected.
%cd /notebooks/tegridy-tools/tegridy-tools/
import TMIDIX
%cd /notebooks/

In [None]:
import pickle
import os
import tqdm
import torch 

dataset_addr = "/notebooks/Euterpe-INTs"
# os.chdir(dataset_addr)

# Collect every file below the dataset directory (recursively). The list is
# sorted afterwards so the concatenation order is the same on every run.
filez = list()
for (dirpath, dirnames, filenames) in os.walk(dataset_addr):
    filez += [os.path.join(dirpath, file) for file in filenames]
print('=' * 70)

filez.sort()

print('Processing MIDI files. Please wait...')

# Gather one tensor per file and concatenate once at the end; calling
# torch.cat inside the loop would re-copy everything loaded so far for every
# file. The leading empty tensor keeps the result an empty tensor when no
# files are found, instead of torch.cat failing on an empty list.
chunks = [torch.Tensor()]

for f in tqdm.tqdm(filez):
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only point dataset_addr at data you trust.
    # The context manager closes each file handle as soon as it is read.
    with open(f, 'rb') as fh:
        chunks.append(torch.Tensor(pickle.load(fh)))
    print('Loaded file:', f)

train_data = torch.cat(chunks)

In [None]:
# Size of the concatenated training tensor along its first dimension.
len(train_data)

In [None]:
# Peek at the first and the last 15 entries of the training tensor.
train_data[:15], train_data[-15:]

# TRAIN

In [None]:
# Display the training batch generator.
# NOTE(review): `train_loader` is only created in the following cell (the one
# that builds the DataLoaders), so this cell raises NameError when the
# notebook is run top to bottom.
train_loader

In [None]:
from sinkhorn_transformer import SinkhornTransformerLM
from sinkhorn_transformer.autoregressive_wrapper import AutoregressiveWrapper

import random
import tqdm
import torch
import torch.optim as optim
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
import secrets
import matplotlib.pyplot as plt

# constants

SEQ_LEN = 4096 # window length used by the dataset; also passed to the model as max_seq_len
BATCH_SIZE = 16 # windows per DataLoader batch

# Number of training-loop iterations: one batch is drawn per iteration, so
# this is roughly one pass over len(train_data) entries.
NUM_BATCHES = len(train_data) // SEQ_LEN // BATCH_SIZE

GRADIENT_ACCUMULATE_EVERY = 1 # backward passes accumulated before each optimizer step
LEARNING_RATE = 1e-4 # learning rate handed to Adam
VALIDATE_EVERY  = 50 # compute and plot a validation loss every this many iterations
GENERATE_EVERY  = 150 # print a generated sample every this many iterations
SAVE_EVERY = 50 # write a checkpoint and plot the training loss every this many iterations
GENERATE_LENGTH = 32 # length argument passed to model.generate for the periodic sample


# helpers

def cycle(loader):
    """Yield the items of *loader* endlessly, restarting it when exhausted.

    The iterable is re-iterated from scratch on every pass, so it must
    support being iterated more than once.
    """
    while True:
        yield from loader

# instantiate model

# Build the language model and wrap it in AutoregressiveWrapper in one
# expression, then move the wrapped model to the GPU. The wrapped `model` is
# what the rest of the notebook calls with `return_loss = True` and
# `.generate(...)`.
model = AutoregressiveWrapper(
    SinkhornTransformerLM(
        # token and embedding sizes
        num_tokens = 512,
        emb_dim = 128,
        # width, depth and attention heads
        dim = 1024,
        depth = 24,
        heads = 8,
        n_local_attn_heads = 4,
        # sequence settings (max_seq_len is tied to the dataset window length)
        max_seq_len = SEQ_LEN,
        bucket_size = 128,
        causal = True,
        # remaining options
        ff_chunks = 4,
        reversible = True,
        attn_dropout = 0.1,
    )
)
model.cuda()

# prepare music token data

class MusicDataset(Dataset):
    """Dataset that serves random fixed-length windows of one long tensor.

    ``data`` is sliced along its first dimension. Windows start on multiples
    of ``seq_len`` and contain ``seq_len + 1`` entries (presumably so the
    autoregressive wrapper can shift them into inputs and targets -- confirm
    against the wrapper).

    Args:
        data: tensor holding the whole corpus.
        seq_len: window length; must be positive.

    Raises:
        ValueError: if ``seq_len`` is not positive or ``data`` holds fewer
            than two full windows. Without this check the dataset would
            report a length of zero or less, and an endlessly cycled loader
            over it would never yield a batch.
    """

    def __init__(self, data, seq_len):
        super().__init__()
        if seq_len <= 0 or data.size(0) // seq_len < 2:
            raise ValueError(
                'MusicDataset needs seq_len > 0 and at least 2 * seq_len '
                f'entries of data (got seq_len={seq_len}, '
                f'data length={data.size(0)})'
            )
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        """Return a randomly chosen window as a long tensor on the GPU.

        ``index`` is ignored: every call draws a fresh random window, so
        asking for the same index twice generally returns different data.
        """
        # Pick one of the len(self) window slots. The last full window is
        # excluded by __len__, so the extra "+ 1" entry always exists.
        idx = secrets.randbelow(len(self)) * self.seq_len

        full_seq = self.data[idx: idx + self.seq_len + 1].long()
        return full_seq.cuda()

    def __len__(self):
        """Return the number of window slots (full windows minus one)."""
        return (self.data.size(0) // self.seq_len) - 1

# NOTE(review): both datasets wrap the same `train_data` tensor, so the
# "validation" loss below is measured on training data, not on a held-out
# split.
train_dataset = MusicDataset(train_data, SEQ_LEN)
val_dataset   = MusicDataset(train_data, SEQ_LEN)
# cycle() turns each finite DataLoader into an endless batch generator, so
# the training loop can simply call next() on it.
train_loader  = cycle(DataLoader(train_dataset, batch_size = BATCH_SIZE))
val_loader    = cycle(DataLoader(val_dataset, batch_size = BATCH_SIZE))

# optimizer

# NOTE: this rebinds the name `optim`, which the imports above also use as an
# alias for the torch.optim module; from here on `optim` is the optimizer.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

# training

# Per-iteration training losses and periodic validation losses, kept for the
# plots below and for the checkpoint file names.
train_losses = []
val_losses = []


for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Gradients from every backward pass add up until zero_grad() below.
    # NOTE(review): the loss is not divided by GRADIENT_ACCUMULATE_EVERY, and
    # only the last micro-batch's loss is logged.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader), return_loss = True)
        loss.backward()

    step_loss = loss.item()
    print(f'training loss: {step_loss}')

    train_losses.append(step_loss)

    # Clip the global gradient norm to 0.5 before applying the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            val_loss = model(next(val_loader), return_loss = True)
            print(f'validation loss: {val_loss.item()}')
            val_losses.append(val_loss.item())

            # The graph is only shown inline; saving is left disabled.
            print('Saving validation loss graph...')
            plt.plot(range(len(val_losses)), val_losses, 'b')
            plt.show()
            # plt.savefig('/notebooks/validation_loss_graph.png')
            plt.close()
            print('Done!')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # Use a random window minus its last entry as the prompt.
        inp = random.choice(val_dataset)[:-1]

        # Print the prompt followed by a separator line of asterisks.
        print(f'{inp} \n\n {"*" * 100}')

        sample = model.generate(inp, GENERATE_LENGTH)

        print(sample)

    if i % SAVE_EVERY == 0:

        print('Saving model progress. Please wait...')
        # The file name records the step index and the latest training loss.
        ckpt_name = f'model_checkpoint_{i}_steps_{round(float(train_losses[-1]), 4)}_loss.pth'
        print(ckpt_name)
        torch.save(model.state_dict(), '/notebooks/' + ckpt_name)
        print('Done!')

        # The graph is only shown inline; saving is left disabled.
        print('Saving training loss graph...')
        plt.plot(range(len(train_losses)), train_losses, 'b')
        plt.show()
        # plt.savefig('/notebooks/training_loss_graph.png')
        plt.close()
        print('Done!')
        

In [None]:
# Plot the recorded training losses against their position in the list (blue
# line) and write the figure to disk.
plt.plot(list(range(len(train_losses))), train_losses, 'b')
plt.savefig('/notebooks/training_loss_graph.png')
print('Done!')

In [None]:
# Plot the recorded validation losses against their position in the list
# (blue line) and write the figure to disk.
plt.plot(list(range(len(val_losses))), val_losses, 'b')
plt.savefig('/notebooks/validation_loss_graph.png')
print('Done!')

# EVAL

In [None]:
import time

In [None]:
# Generate 512 new tokens from a 512-token prompt and report how long it took.
model.eval()
# MusicDataset ignores the index and returns a random window, so the `2` here
# does not select a specific piece; the first 512 entries become the prompt.
inp = val_dataset[2][:512]

# Print the prompt followed by a separator line of asterisks.
print(f'{inp} \n\n {"*" * 100}')
# torch.LongTensor([6]).cuda()
start_time = time.time()
out = model.generate(inp, 512)
print(time.time() - start_time, "seconds")
print(out)

In [None]:
# Copy the generated tensor to the CPU and convert it to a plain Python list
# for the decoding cell below.
out1 = out.cpu().tolist()

In [None]:
# Decode the flat token list in `out1` into note events and write a MIDI file.
#
# As read from the arithmetic below, a note is four consecutive tokens:
#   s[0] <= 127 : channel * 11 + velocity step  (velocity = step * 19)
#   s[1]        : delta start time, stored as 128 + delta / 16
#   s[2]        : duration, stored as 256 + duration / 32
#   s[3]        : pitch, stored as 384 + pitch
# Only "token > 127" is checked while grouping; the three value tokens are
# not validated against their ranges.
#
# `len(out1) != 0` is kept instead of a truthiness test so the check also
# works when `out1` is a tensor rather than a list.
if len(out1) != 0:

    song = out1
    song_f = []
    # Running absolute start time. Deliberately not named `time`: that would
    # rebind the imported `time` module and break `time.time()` when the
    # generation cell is run again.
    abs_time = 0

    son = []

    song1 = []

    # A token <= 127 starts a new note; tokens > 127 extend the current one.
    # A group is kept only if it has exactly four tokens.
    for s in song:
        if s > 127:
            son.append(s)

        else:
            if len(son) == 4:
                song1.append(son)
            son = []
            son.append(s)

    # Flush the final group: no further start token follows it, so the loop
    # above never gets the chance to append it.
    if len(son) == 4:
        song1.append(son)

    for s in song1:

        channel = s[0] // 11

        vel = (s[0] % 11) * 19

        # Start times are stored as deltas; accumulate them.
        abs_time += (s[1] - 128) * 16

        dur = (s[2] - 256) * 32

        pitch = (s[3] - 384)

        song_f.append(['note', abs_time, dur, channel, pitch, vel])

    detailed_stats = TMIDIX.Tegridy_SONG_to_MIDI_Converter(song_f,
                                                        output_signature = 'Mini Muse',
                                                        output_file_name = '/notebooks/Mini-Muse-Music-Composition',
                                                        track_name='Project Los Angeles',
                                                        list_of_MIDI_patches=[0, 24, 32, 40, 42, 46, 56, 71, 73, 0, 53, 19, 0, 0, 0, 0],
                                                        number_of_ticks_per_quarter=500)

    print('Done!')

In [None]:
# Replace the decoder input with the first 160000 entries of the training
# data -- presumably so the decoding cell can be re-run to render training
# material to MIDI; TODO confirm.
# NOTE(review): unlike the generated output, this is a tensor slice, not a
# list of Python ints.
out1 = train_data[:160000]