In [2]:
!pip install einops

Collecting einops
  Downloading einops-0.7.0-py3-none-any.whl.metadata (13 kB)
Downloading einops-0.7.0-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.6/44.6 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: einops
Successfully installed einops-0.7.0
[0m

In [3]:
import os
import pickle
import random
import secrets
import tqdm

import zipfile


import json
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset

import math

from torch.optim.lr_scheduler import LambdaLR 
from torch.optim import Adam 
from torch.cuda.amp import autocast, GradScaler # Mixed Precision

from x_transformer2 import TransformerWrapper, Decoder, AutoregressiveWrapper
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
class LrStepTracker:
    """
    ----------
    Author: Ryan Marshall
    Modified: Damon Gwinn
    ----------
    Class for custom learn rate scheduler (to be used by torch.optim.lr_scheduler.LambdaLR).
    Learn rate for each step (batch) given the warmup steps is:
        lr = [ 1/sqrt(d_model) ] * min[ 1/sqrt(step) , step * (warmup_steps)^-1.5 ]
    This is from Attention is All you Need (https://arxiv.org/abs/1706.03762)
    ----------
    """

    def __init__(self, model_dim=1024, warmup_steps=4000, init_steps=0):
        # Store Values
        self.warmup_steps = warmup_steps
        self.model_dim = model_dim
        self.init_steps = init_steps

        # Begin Calculations
        self.invsqrt_dim = (1 / math.sqrt(model_dim))
        self.invsqrt_warmup = (1 / (warmup_steps * math.sqrt(warmup_steps)))

    # step
    def step(self, step):
        """
        ----------
        Author: Ryan Marshall
        Modified: Damon Gwinn
        ----------
        Method to pass to LambdaLR. Increments the step and computes the new learn rate.
        ----------
        """

        step += self.init_steps
        if(step <= self.warmup_steps):
            return self.invsqrt_dim * self.invsqrt_warmup * step
        else:
            invsqrt_step = (1 / math.sqrt(step))
            return self.invsqrt_dim * invsqrt_step



# get_lr
def get_lr(optimizer):
    """
    ----------
    Author: Damon Gwinn
    ----------
    Hack to get the current learn rate of the model
    ----------
    """

    for param_group in optimizer.param_groups:
        return param_group['lr']

In [18]:
tokens_train = list()
train_path = "data_remi_norm_bpe/train"
for (dirpath, dirnames, filenames) in os.walk(train_path):
    # Filter files with only .json
    tokens_train += [os.path.join(dirpath, file) for file in filenames if file.endswith(".json")]

tokens_train_aug = list()
train_path = "data_remi_norm_bpe_aug/train"
for (dirpath, dirnames, filenames) in os.walk(train_path):
    # Filter files with only .json
    tokens_train_aug += [os.path.join(dirpath, file) for file in filenames if file.endswith(".json")]

tokens_val = list()
val_path = "data_remi_norm_bpe/val"
for (dirpath, dirnames, filenames) in os.walk(val_path):
    # Filter files with only .json
    tokens_val += [os.path.join(dirpath, file) for file in filenames if file.endswith(".json")]


In [19]:
print(f"Train len: {len(tokens_train)}, Train Aug len: {len(tokens_train_aug)}, Val len: {len(tokens_val)}")

Train len: 21919, Train Aug len: 136653, Val len: 1159


In [5]:
train_data = torch.Tensor([0, 0, 0, 0]) # Quick dirty hack to offset the training data for proper loading

for f in tqdm.tqdm(tokens_train):
    a = json.load(open(f))['ids'][0]
    a.append(2001) # Append END_TOKEN
    train_data = torch.cat((train_data, torch.Tensor(a)))

#val_data = torch.Tensor([0, 0, 0, 0]) # Quick dirty hack to offset the training data for proper loading

for f in tqdm.tqdm(tokens_val):
    a = json.load(open(f))['ids'][0]
    a.append(2001) # Append END_TOKEN
    val_data = torch.cat((val_data, torch.Tensor(a)))

100%|██████████| 21919/21919 [10:50<00:00, 33.69it/s]
100%|██████████| 1159/1159 [00:01<00:00, 643.21it/s]


In [5]:
#torch.save(train_data, 'train_data_r_norm_bpe_tensor.pt')
#torch.save(val_data, 'val_data_r_norm_bpe_tensor.pt')
train_data = torch.load('train_data_r_norm_bpe_tensor.pt')
val_data = torch.load('val_data_r_norm_bpe_tensor.pt')

In [6]:
lr_init = None
batch_size = 2
max_sequence = 1024
model_dim = 1024
model_depth = 32
epochs = 10
gradient_accum = 16

num_batches = (len(train_data) // max_sequence // batch_size) * epochs

VALIDATE_EVERY  = 500
SAVE_EVERY = 5000
SAVE_STATS_EVERY = 500

# helpers

def cycle(loader):
    while True:
        for data in loader:
            yield data

# Dataloader

class MusicDataset(Dataset):
    def __init__(self, data, seq_len):
        super().__init__()
        self.data = data
        self.seq_len = seq_len

    def __getitem__(self, index):
        
        # random sampling
        
        idx = secrets.randbelow(self.data.size(0) - self.seq_len - 1)
        full_seq = self.data[idx: idx + self.seq_len + 1].long()
        
        return full_seq.cuda()

    def __len__(self):
        return self.data.size(0)

train_dataset = MusicDataset(train_data, max_sequence)
val_dataset   = MusicDataset(val_data, max_sequence)
train_loader  = cycle(DataLoader(train_dataset, batch_size = batch_size))
val_loader    = cycle(DataLoader(val_dataset, batch_size = batch_size))

In [7]:
def last_5000(lst):
    if len(lst) > 10:
        return lst[-10:]
    else:
        return lst

In [24]:
model = TransformerWrapper(
    num_tokens = 2002,
    max_seq_len = max_sequence,
    attn_layers = Decoder(dim = model_dim, depth = model_depth, heads = 16)
)
model = AutoregressiveWrapper(model)
model = torch.nn.DataParallel(model)
model.cuda()

DataParallel(
  (module): AutoregressiveWrapper(
    (net): TransformerWrapper(
      (token_emb): TokenEmbedding(
        (emb): Embedding(2002, 1024)
      )
      (pos_emb): AbsolutePositionalEmbedding(
        (emb): Embedding(1024, 1024)
      )
      (post_emb_norm): Identity()
      (emb_dropout): Dropout(p=0.0, inplace=False)
      (project_emb): Identity()
      (attn_layers): Decoder(
        (layers): ModuleList(
          (0): ModuleList(
            (0): ModuleList(
              (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (1-2): 2 x None
            )
            (1): Attention(
              (to_q): Linear(in_features=1024, out_features=1024, bias=False)
              (to_k): Linear(in_features=1024, out_features=1024, bias=False)
              (to_v): Linear(in_features=1024, out_features=1024, bias=False)
              (attend): Attend(
                (attn_dropout): Dropout(p=0.0, inplace=False)
              )
              (to_out): Li

In [9]:
model = TransformerWrapper(
    num_tokens = 2002,
    max_seq_len = max_sequence,
    attn_layers = Decoder(
        dim = model_dim,
        depth = model_depth,
        heads = 16,
        layer_dropout = 0.15,   # stochastic depth - dropout entire layer
        attn_dropout = 0.1,    # dropout post-attention
        ff_dropout = 0.15       # feedforward dropout
    ))
model = AutoregressiveWrapper(model)
model = torch.nn.DataParallel(model)
model.cuda()

DataParallel(
  (module): AutoregressiveWrapper(
    (net): TransformerWrapper(
      (token_emb): TokenEmbedding(
        (emb): Embedding(2002, 1024)
      )
      (pos_emb): AbsolutePositionalEmbedding(
        (emb): Embedding(1024, 1024)
      )
      (post_emb_norm): Identity()
      (emb_dropout): Dropout(p=0.0, inplace=False)
      (project_emb): Identity()
      (attn_layers): Decoder(
        (layers): ModuleList(
          (0): ModuleList(
            (0): ModuleList(
              (0): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)
              (1-2): 2 x None
            )
            (1): Attention(
              (to_q): Linear(in_features=1024, out_features=1024, bias=False)
              (to_k): Linear(in_features=1024, out_features=1024, bias=False)
              (to_v): Linear(in_features=1024, out_features=1024, bias=False)
              (attend): Attend(
                (attn_dropout): Dropout(p=0.1, inplace=False)
              )
              (to_out): Li

In [10]:
ADAM_BETA_1             = 0.9
ADAM_BETA_2             = 0.98
ADAM_EPSILON            = 10e-9

LR_DEFAULT_START        = 1.0

if lr_init is None: 
    init_step = 0 
    lr = LR_DEFAULT_START 
    lr_stepper = LrStepTracker(1024, 4000, 0)
else: 
    lr = lr_init

opt = Adam(model.parameters(), lr=lr, betas=(ADAM_BETA_1, ADAM_BETA_2), eps=ADAM_EPSILON)

if lr_init is None: 
    lr_scheduler = LambdaLR(opt, lr_stepper.step) #LambdaLR(optimizer, lr_lambda=lf,last_epoch = start_epoch)
else:
    lr_scheduler = None

In [11]:
# Initialize the scaler for mixed-precision training
scaler = GradScaler()

In [50]:
torch.cuda.empty_cache()

In [13]:
ckpt_dir = "./ckpt"
train_losses = []
val_losses = []

train_accs = []
val_accs = []

train_losses_a = 0
val_losses_a = 0

train_accs_a = 0
val_accs_a = 0

last_val_accs_a = 0

In [14]:
checkpoint = torch.load("./ckpt/good_bpe/latest_bpe2.pth", map_location="cpu")
model.load_state_dict(checkpoint['state_dict'])
opt.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
scaler.load_state_dict(checkpoint['scaler'])

In [None]:
with tqdm.tqdm(range(num_batches), mininterval=10., desc='Training') as pbar:
    for i in pbar:
        model.train()
    
        with autocast():
            loss, acc = model(next(train_loader))
            # Backward pass with autocasting and gradient scaling
        scaler.scale(loss).backward()
            
        if i % SAVE_STATS_EVERY == 0:
            with open('train_acc.txt', 'a') as file:
                file.write(f"{acc.mean().item()}\n")
            with open('train_loss.txt', 'a') as file:
                file.write(f"{loss.mean().item()}\n")
        
        train_losses.append(loss.mean().item())
        train_accs.append(acc.mean().item())
        train_losses_a = sum(last_5000(train_losses))/len(last_5000(train_losses))
        train_accs_a = sum(last_5000(train_accs))/len(last_5000(train_accs))

        if ((i + 1) % gradient_accum == 0) or (i + 1 == num_batches):
        # Gradient scaling step and optimizer update
            torch.nn.utils.clip_grad_norm_(model.parameters(), 2)
            scaler.step(opt)
            scaler.update()
            opt.zero_grad()
            lr_scheduler.step() 

        pbar.set_postfix(stats=f" TL: {round(train_losses_a, 3)}, TA: {round(train_accs_a, 3)}, VL: {round(val_losses_a, 3)}, VA: {round(val_accs_a, 3)} -----")
    
        if i % VALIDATE_EVERY == 0:
            model.eval()
            with torch.no_grad():
                with autocast():
                    val_loss, val_acc = model(next(val_loader))
                
                val_losses.append(val_loss.mean().item())
                val_accs.append(val_acc.mean().item())
                val_losses_a = sum(last_5000(val_losses))/len(last_5000(val_losses))
                val_accs_a = sum(last_5000(val_accs))/len(last_5000(val_accs))
                #print(f"val acc:{val_accs_a}")
    
                with open('val_acc.txt', 'a') as file:
                    file.write(f"{val_acc.mean().item()}\n")
                with open('val_loss.txt', 'a') as file:
                    file.write(f"{val_loss.mean().item()}\n")
        if i % SAVE_EVERY == 0:
            if len(val_accs) > 1:
                if val_accs_a > last_val_accs_a:
                    torch.save({
                        'state_dict': model.state_dict(),
                        'optimizer': opt.state_dict(),
                        'lr_scheduler': lr_scheduler.state_dict(),
                        'scaler': scaler.state_dict()
                    }, os.path.join(ckpt_dir, 'latest_bpe_orig.pth'))
                    #print(f"last acc before:{last_val_accs_a}")
                    last_val_accs_a = val_accs_a
                    #print(f"last acc:{last_val_accs_a}")

Training:   9%|▉         | 30322/328390 [1:26:32<12:29:26,  6.63it/s, stats=TL: 1.202, TA: 0.687, VL: 1.448, VA: 0.648 -----]

In [32]:
val_accs_a

0.6470703125

In [14]:
torch.save({
    'state_dict': model.state_dict(),
    'optimizer': opt.state_dict(),
    'lr_scheduler': lr_scheduler.state_dict(),
    'scaler': scaler.state_dict(),
}, os.path.join(ckpt_dir, 'latest_bpe_orig.pth'))

In [21]:
from IPython.display import FileLink
FileLink("ckpt/latest.pth")