In [76]:
from TransformerWrapper import TransformerWrapper
from AutoregressiveWrapper import AutoregressiveWrapper
from x_transformers import Decoder

from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from torch import tensor
import random
import torch
import tqdm
import pickle

# Extend the module search path before importing the project-local modules
# below; the two `from ... import` lines must stay after the appends.
# NOTE(review): these are absolute paths on one developer's Windows machine.
# Presumably `tokenizer` and `data_to_MIDI` live in those folders - on any other
# machine the paths have to be adjusted (or the project installed as a package).
import sys
sys.path.append('C:\\Users\\ilove\\CODING\\PYStuff\\MusicNet\\Midi2Numpy\\MIDI-Generator-with-Transformers\\Data_Extraction')
sys.path.append('C:\\Users\\ilove\\CODING\\PYStuff\\MusicNet\\Midi2Numpy\\MIDI-Generator-with-Transformers')
# NOTE(review): `Tokenizer` is not referenced by name in the visible cells; it is
# presumably imported so the pickled tokenizer can be restored - confirm.
from tokenizer import Tokenizer
from data_to_MIDI import data_to_MIDI

In [59]:
# Restore the previously pickled tokenizer object; the training loop uses its
# `detokenize` method on generated samples.
# NOTE: pickle.load can execute arbitrary code - only open pickle files you trust.
with open('small_tokenizer.pickle', 'rb') as f:
    small_tokenizer = pickle.load(f)

In [60]:
# Load the pickled dataset. The inspection cells that follow show it holds
# 5 pieces, each with shape (length, 3) (e.g. (242, 3) and (460, 3)).
# NOTE: pickle.load can execute arbitrary code - only open pickle files you trust.
with open('small_data.pickle', 'rb') as f:
    data = pickle.load(f)

In [61]:
# --- Training schedule ---
# Number of iterations of the outer training loop (one optimizer step each).
NUM_BATCHES = 1000
# Batch size handed to both DataLoaders.
BATCH_SIZE = 4
# Micro-batches whose gradients are accumulated before each optimizer step.
GRADIENT_ACCUMULATE_EVERY = 4
# Learning rate passed to Adam.
LEARNING_RATE = 1e-4
# Validation loss is computed on iterations where i % VALIDATE_EVERY == 0.
VALIDATE_EVERY  = 10
# A sample is generated on iterations where i % GENERATE_EVERY == 0; with
# NUM_BATCHES = 1000 that condition only holds at i == 0.
GENERATE_EVERY  = 1000
# `seq_len` passed to model.generate inside the training loop.
GENERATE_LENGTH = 256
# Length pieces are padded to, and `max_seq_len` of the model.
SEQ_LEN = 512
# --- Vocabulary sizes passed to TransformerWrapper ---
# Refer to small_tokenizer_test (presumably where these counts were derived).
NUM_TOKENS_VALUES = 106
NUM_TOKENS_TIMES = 16
NUM_TOKENS_INSTRUMENTS = 17

In [62]:
class PieceDataset(Dataset):
    """Dataset of token matrices padded up to a common number of rows.

    Every element of ``data`` is an array-like (or tensor) of shape
    ``(length, 3)``. Pieces with fewer than ``seq_length`` rows are extended
    with padding rows ``[1, padding_value, padding_value]`` (``[1, 0, 0]``).
    Pieces that already have ``seq_length`` rows or more are stored unchanged;
    nothing is truncated.

    Args:
        data: sequence of ``(length, 3)`` array-likes, one per piece.
        seq_length: number of rows short pieces are padded up to.
    """

    def __init__(self, data, seq_length):
        self.data = data
        self.seq_length = seq_length
        self.padding_value = 0  # fill value for every padding column except the first

        # Pad once up front so that __getitem__ is a cheap lookup.
        self.padded_data = [self.pad_sequence(matrix) for matrix in self.data]

    def __len__(self):
        """Return the number of pieces."""
        return len(self.padded_data)

    def __getitem__(self, idx):
        """Return a fresh ``torch.long`` copy of the padded piece at ``idx``."""
        # padded_data already holds tensors. Copy-constructing with
        # torch.tensor(existing_tensor) emits a UserWarning on every access, so
        # copy with detach().clone() instead. The copy also keeps callers from
        # mutating the stored data.
        return self.padded_data[idx].detach().clone().to(torch.long)

    def pad_sequence(self, matrix):
        """Return ``matrix`` as a tensor with at least ``seq_length`` rows.

        Args:
            matrix: array-like or tensor of shape ``(length, 3)``.

        Returns:
            A new tensor. When ``length < seq_length`` the missing rows are
            filled with ``[1, padding_value, padding_value]``; otherwise the
            content is returned as-is (no truncation).
        """
        if isinstance(matrix, torch.Tensor):
            # Avoid the torch.tensor(tensor) copy-construct warning.
            matrix = matrix.detach().clone()
        else:
            matrix = torch.tensor(matrix)

        missing_rows = self.seq_length - matrix.shape[0]
        if missing_rows > 0:
            padding = torch.full((missing_rows, 3), self.padding_value)
            # The first column of a padding row is 1, giving [1, 0, 0]: the
            # same triple the notebook passes to generate() as eos_token.
            padding[:, 0] = 1
            matrix = torch.cat((matrix, padding), dim=0)
        return matrix

In [63]:
len(data)

5

In [64]:
data[0].shape

(242, 3)

In [65]:
data[4].shape

(460, 3)

In [66]:
def cycle(loader):
    """Yield the items of ``loader`` endlessly, restarting it whenever it runs out."""
    while True:
        # Each pass re-iterates the loader from its beginning.
        yield from loader

In [67]:
# Training set: every piece except the first one, padded up to SEQ_LEN rows.
train = PieceDataset(data[1:], SEQ_LEN)

In [68]:
# Validation set: only the first piece, which is excluded from `train`.
val = PieceDataset([data[0]], SEQ_LEN)

In [69]:
# Endless stream of training batches; shuffle=False keeps the piece order fixed.
train_loader = cycle(DataLoader(train, batch_size=BATCH_SIZE, shuffle=False))

In [70]:
# Endless stream of validation batches. `val` holds a single piece, so every
# batch contains just that one piece regardless of BATCH_SIZE.
val_loader = cycle(DataLoader(val, batch_size=BATCH_SIZE, shuffle=False))

In [71]:
# Build the base model with the project-local TransformerWrapper. It receives
# three vocabulary sizes (values, times, instruments).
# NOTE(review): presumably one vocabulary per column of the (length, 3) token
# matrices - confirm against TransformerWrapper.
model = TransformerWrapper(
    num_tokens_values=NUM_TOKENS_VALUES,
    num_tokens_times=NUM_TOKENS_TIMES,
    num_tokens_instruments=NUM_TOKENS_INSTRUMENTS,
    max_seq_len=SEQ_LEN,
    # Absolute positional embeddings are disabled; rotary_pos_emb is enabled
    # on the Decoder below instead.
    use_abs_pos_emb = False,
    post_emb_norm=True,
    # Small x_transformers decoder stack: 1 layer, width 32, 2 attention heads.
    attn_layers=Decoder(
        dim = 32,
        depth = 1,
        heads = 2,
        rotary_pos_emb=True,
        attn_flash=True,
        use_scalenorm=True,
        ff_glu=True,
    )
)

In [72]:
# Wrap the base model with the project-local AutoregressiveWrapper; the cells
# below call the wrapped object to obtain a loss and use its `generate` method.
# NOTE: this rebinds `model`, so re-running this cell without re-running the
# previous one wraps the model a second time.
model = AutoregressiveWrapper(model)

In [73]:
# Adam over all parameters of the wrapped model, using LEARNING_RATE.
optim = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

In [74]:
# Loss history filled by the training loop: `loss_list` gets one entry per
# micro-batch, `validate_loss_list` one entry per validation run.
loss_list = []
validate_loss_list = []

In [75]:
# Main training loop: NUM_BATCHES optimizer steps with gradient accumulation,
# periodic validation and periodic sample generation.
for i in tqdm.tqdm(range(NUM_BATCHES), mininterval=10., desc='training'):
    model.train()

    # Accumulate gradients over several micro-batches. Each loss is divided by
    # the number of micro-batches so the summed gradient is their average.
    for __ in range(GRADIENT_ACCUMULATE_EVERY):
        loss = model(next(train_loader))
        loss_list.append(loss.item())
        (loss / GRADIENT_ACCUMULATE_EVERY).backward()

    # Only the loss of the last micro-batch is printed here; every micro-batch
    # loss is recorded in loss_list.
    print(f'training loss: {loss.item()}')
    # Clip the total gradient norm to 0.5 before the update.
    torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
    optim.step()
    optim.zero_grad()

    if i % VALIDATE_EVERY == 0:
        model.eval()
        with torch.no_grad():
            loss = model(next(val_loader))
            validate_loss_list.append(loss.item())
            print(f'validation loss: {loss.item()}')

    if i % GENERATE_EVERY == 0:
        model.eval()
        # random.choice indexes the dataset, so `inp` is one padded validation
        # piece including its trailing [1, 0, 0] padding rows.
        # NOTE(review): the whole padded piece is used as the prompt - confirm
        # that this is intended.
        inp = random.choice(val)
        print(inp)
        # Add a leading batch dimension of size 1 to the prompt.
        inp = inp.unsqueeze(0)

        # eos_token is the same [1, 0, 0] triple PieceDataset uses for padding rows.
        sample = model.generate(
            prompts = inp,
            seq_len = GENERATE_LENGTH,
            cache_kv = True,
            eos_token = tensor([1,0,0])
        )
        # NOTE(review): sample[:, -1] selects index -1 along the second
        # dimension only; confirm this is the slice detokenize expects.
        output_str = small_tokenizer.detokenize(sample[:,-1])
        print(output_str)

  return tensor(x,dtype=torch.long)


training loss: 10.19651985168457
validation loss: 10.169755935668945
tensor([[ 0,  0,  0],
        [10,  0, 15],
        [10,  0, 16],
        ...,
        [ 1,  0,  0],
        [ 1,  0,  0],
        [ 1,  0,  0]])



100%|██████████| 1/1 [00:00<?, ?it/s][A


[[ 0  1 -1 -1 -1 -1 -1  0 -1 -1 -1]]
training loss: 10.195898056030273
training loss: 10.195276260375977
training loss: 10.194652557373047
training loss: 10.194031715393066
training loss: 10.193408966064453
training loss: 10.192789077758789
training loss: 10.192166328430176
training loss: 10.191547393798828
training loss: 10.190927505493164


training:   1%|          | 11/1000 [00:01<02:35,  6.34it/s]

training loss: 10.1903076171875
validation loss: 10.164663314819336





KeyboardInterrupt: 

In [None]:
# Plot the training loss (one point per micro-batch) together with the
# validation loss. Validation runs once every VALIDATE_EVERY optimizer steps and
# each step logs GRADIENT_ACCUMULATE_EVERY training losses, so every validation
# value is repeated VALIDATE_EVERY * GRADIENT_ACCUMULATE_EVERY times to line up
# with the training curve on a shared x-axis.
# The expanded series lives in its own variable: overwriting validate_loss_list
# would expand it again each time this cell is re-run and would corrupt the log
# if training were resumed afterwards.
points_per_validation = VALIDATE_EVERY * GRADIENT_ACCUMULATE_EVERY
validate_loss_expanded = [
    item for item in validate_loss_list for _ in range(points_per_validation)
]

plt.plot(loss_list, label='training loss')
plt.plot(validate_loss_expanded, label='validation loss')
plt.xlabel('micro-batch')
plt.ylabel('loss')
plt.legend();

In [None]:
# Save only the parameters (state_dict) of the wrapped model to model.pth; the
# load cell below rebuilds the same architecture before restoring them.
torch.save(model.state_dict(), 'model.pth')

In [None]:
# Load model: rebuild the architecture used for training, then restore the
# saved weights. The hyperparameters below must match the model that produced
# model.pth, otherwise load_state_dict reports missing or mismatched keys.
model = TransformerWrapper(
    num_tokens_values=NUM_TOKENS_VALUES,
    num_tokens_times=NUM_TOKENS_TIMES,
    num_tokens_instruments=NUM_TOKENS_INSTRUMENTS,
    max_seq_len=SEQ_LEN,
    use_abs_pos_emb = False,
    post_emb_norm=True,
    attn_layers=Decoder(
        dim = 32,
        depth = 1,
        heads = 2,
        rotary_pos_emb=True,
        attn_flash=True,
        use_scalenorm=True,
        ff_glu=True,
    )
)
model = AutoregressiveWrapper(model)
# The restored model is used for generation, so switch it to evaluation mode
# (a freshly constructed module starts in training mode).
model.eval()
# map_location='cpu' lets a checkpoint written on a GPU machine load anywhere.
# weights_only=True restricts unpickling to tensors and plain containers instead
# of arbitrary pickled objects.
model.load_state_dict(torch.load('model.pth', map_location='cpu', weights_only=True))

In [None]:
# Generate from a minimal prompt: one batch entry holding two [0, 0, 0] rows.
# seq_len is written as the literal 512, which equals SEQ_LEN; eos_token is the
# same [1, 0, 0] triple that PieceDataset uses for its padding rows.
output = model.generate(
    prompts = tensor([[[0, 0, 0],[0,0,0]]]),
    seq_len = 512,
    cache_kv = True,
    eos_token = tensor([1,0,0]))


In [None]:
output[:,0]