In [None]:
# Fetch the project repository and make its root the working directory: later
# cells import `models` and `data` from it and resolve relative paths
# (`torchph`, `logging`, checkpoints) against it.
# NOTE: not idempotent — re-running this cell from inside the checkout clones a
# nested VKR25/VKR25 copy and moves into it.
!git clone https://github.com/FedorZaitsev/VKR25
%cd VKR25

In [None]:
import os
os.environ['TORCH_CUDA_ARCH_LIST']="5.0 5.2 5.3 6.0 6.1 6.2 7.0 7.2 7.5 8.0 8.6 8.7 8.9 9.0"

!pip install Ninja
!git clone https://github.com/c-hofer/torchph.git
!pip install -e torchph

import sys
sys.path.append("/kaggle/working/VKR25/torchph")

In [None]:
import os

# Run-wide settings shared by the data pipeline and the model.
# The special input ids BOS/EOS/padding are 4096-4098, so VOCAB_SIZE = 4099 is
# one past the largest of them. TAR_PAD_TOKEN = -100 marks target positions
# that the loss criterion ignores (its ignore_index).
config = dict(
    SEED=228,

    BOS_TOKEN=4096,
    EOS_TOKEN=4097,
    INP_PAD_TOKEN=4098,
    TAR_PAD_TOKEN=-100,
    VOCAB_SIZE=4099,
    MAX_LENGTH=256,
    OVERLAP=64,

    NUM_WORKERS=4,
    BATCH_SIZE=16,

    ACCUM_STEPS=1,
)

# Mirror every setting into the process environment as a string, presumably so
# the project's `data` / `models` modules can read them — TODO confirm.
os.environ.update({name: str(setting) for name, setting in config.items()})

In [None]:
import torch
import random
import numpy as np

# Seed Python's, NumPy's and PyTorch's global RNGs with the same value so that
# anything downstream drawing from them is repeatable (torch.manual_seed also
# seeds the CUDA generators).
random.seed(config['SEED'])
np.random.seed(config['SEED'])
torch.manual_seed(config['SEED'])

# All models are placed on the GPU; there is no CPU fallback.
device = 'cuda'
# Kaggle input dataset that data.read_sequences reads from.
root_dir = '/kaggle/input/groove-tokens'

In [None]:
from models.topotransformer_model import TopoTransformerModel, PositionalEncoding, CustomTransformerEncoderLayer

# `torch.serialization.safe_globals(...)` is a context manager: calling it bare
# (outside a `with` block) registers nothing. `add_safe_globals` registers the
# classes process-wide so `torch.load(..., weights_only=True)` may unpickle them.
# NOTE(review): a checkpoint pickled as a whole module also references its own
# model class and the torch.nn layers inside it, which would need registering
# too — or load such a (trusted) file with weights_only=False.
torch.serialization.add_safe_globals([TopoTransformerModel, PositionalEncoding, CustomTransformerEncoderLayer])

In [None]:
def load_old_model(model_old, model_new):
    """Copy the weights of a pretrained model into a (larger) new model.

    Every entry of ``model_old.state_dict()`` (parameters and buffers) is
    loaded into the entry with the same name in ``model_new``. Entries that
    exist only in ``model_new`` keep their current values.

    Args:
        model_old: source ``nn.Module`` whose state is transferred.
        model_new: destination ``nn.Module``; modified in place.

    Returns:
        ``model_new`` itself, so the call can be chained.

    Raises:
        KeyError: if ``model_old`` has entries that ``model_new`` lacks. This is
            checked before anything is copied, so ``model_new`` is left untouched.
        RuntimeError: if an entry shared by both models has a different shape.
            A plain element-wise ``copy_`` would instead silently *broadcast*,
            e.g. a (3, 1) weight into a (3, 2) one.
    """
    old_state = model_old.state_dict()
    new_keys = model_new.state_dict().keys()
    unexpected = [name for name in old_state if name not in new_keys]
    if unexpected:
        raise KeyError(f'model_old has entries missing from model_new: {unexpected}')
    # strict=False lets entries that exist only in model_new (its extra layers)
    # stay as they are; shape mismatches still raise RuntimeError.
    model_new.load_state_dict(old_state, strict=False)
    return model_new

In [None]:
# The checkpoint is a whole pickled nn.Module (it is used via .state_dict()
# below), not a bare state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which rejects arbitrary classes, so opt out explicitly.
# This assumes the file is our own trusted artifact — never do this for
# checkpoints from untrusted sources, since unpickling can execute arbitrary code.
model_old = torch.load('/kaggle/input/transformer/pytorch/default/1/checkpoint_400.pt', map_location=device, weights_only=False)

In [None]:
# Fresh TopoTransformerModel on the target device, then initialised from the
# pretrained model's state dict (entries matched by name).
ttm = TopoTransformerModel().to(device)
ttm = load_old_model(model_old, ttm)

In [None]:
# Freeze every parameter whose name also appears in the pretrained model's
# state dict, i.e. the weights transferred by load_old_model. Parameters that
# exist only in TopoTransformerModel keep their current requires_grad.
old_model_param_names = model_old.state_dict().keys()
for param_name, param in ttm.named_parameters():
    if param_name not in old_model_param_names:
        continue
    param.requires_grad = False

# Regardless of the above, keep ttm.linear's weight and bias trainable
# (presumably the output projection — TODO confirm).
for tensor in (ttm.linear.weight, ttm.linear.bias):
    tensor.requires_grad = True

In [None]:
from data import data

# Read the sequences under root_dir, split them into train/validation datasets,
# and wrap each split with the project's loader factory (same settings for
# both; NOTE(review): if get_loader shuffles, validation is shuffled too).
sequences = data.read_sequences(root_dir)
train_dataset, valid_dataset = data.get_train_val_dataset(sequences)
train_loader, valid_loader = map(data.get_loader, (train_dataset, valid_dataset))

In [None]:
import os
import sys

# Make the repo's `logging/` helpers importable without changing the working
# directory. `%cd logging` would redirect every later relative path
# (checkpoints, loss plots, the wandb run dir) into logging/, and it fails when
# the cell is re-run. The path is resolved against the VKR25 checkout (the
# current directory) and added only once, so this cell is idempotent.
_logging_dir = os.path.abspath('logging')
if _logging_dir not in sys.path:
    sys.path.append(_logging_dir)

from wandb_logger import WandBLogger

In [None]:
import matplotlib.pyplot as plt

def plot_losses(train_l, valid_l, eval_every, name):
    """Plot training and validation loss curves and save the figure to `name`.

    Args:
        train_l: training loss for every epoch so far, in order.
        valid_l: validation losses, one recorded after every `eval_every`-th
            epoch (epochs eval_every, 2*eval_every, ...).
        eval_every: validation period in epochs.
        name: output image path passed to `savefig`.

    Each call draws on its own figure and closes it after saving. Repeated calls
    (e.g. once per epoch) therefore do not overlay earlier curves, duplicate
    legend entries, or accumulate open figures. The figure is not shown inline;
    the saved file is the output.
    """
    fig, ax = plt.subplots()
    # 1-based epoch numbers, matching the "Epoch N" messages printed in training.
    epochs = np.arange(1, len(train_l) + 1)
    ax.grid()
    ax.set_xlabel('epoch')
    ax.set_ylabel('loss')
    ax.plot(epochs, train_l, label='training', c='blue')
    # Validation points sit on the epochs after which validation ran.
    ax.scatter(np.arange(eval_every, len(train_l) + 1, eval_every), valid_l, label='validation', c='orange')
    ax.legend()
    fig.savefig(name)
    plt.close(fig)

In [None]:
import torch.nn as nn
from torch.optim import Adam, AdamW, SGD
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from transformers import get_cosine_schedule_with_warmup

In [None]:
EPOCHS = 50
LR = 1e-5
EPS = 1e-8
WD = 1e-2

# NOTE(review): these architecture constants are only reported to W&B; the
# model is `TopoTransformerModel()` built with its own defaults. Keep them in
# sync with those defaults — TODO confirm they match.
D_MODEL = 512
NHEAD = 8
NUM_LAYERS = 6
DIM_FEEDFORWARD = 2048


model = ttm

# Hand the optimizer only the parameters left trainable by the freezing step.
# Frozen tensors never receive gradients, so including them would only add dead
# entries to the optimizer's param groups.
optimizer = AdamW([p for p in model.parameters() if p.requires_grad], lr=LR, eps=EPS, weight_decay=WD)
# Target positions equal to TAR_PAD_TOKEN are excluded from the loss.
criterion = nn.CrossEntropyLoss(ignore_index=config['TAR_PAD_TOKEN'])

# Scheduler-step budget for the whole run. This assumes train_epoch steps the
# scheduler once per batch — TODO confirm. The +1 per epoch is presumably slack.
total_steps = EPOCHS * (len(train_loader) + 1)
# NOTE(review): CosineAnnealingLR is periodic, not a one-shot decay. With
# T_max = total_steps // 5, the LR reaches eta_min roughly after 1/5, 3/5 and
# 5/5 of the run and climbs back to LR in between. If a single decay was
# intended, use T_max=total_steps. Previously tried alternatives:
# get_cosine_schedule_with_warmup(num_warmup_steps=total_steps // 3) and a
# constant LambdaLR(lambda step: 1).
scheduler = CosineAnnealingLR(optimizer, T_max=total_steps//5, eta_min=1e-6)
# scheduler = LambdaLR(optimizer, lambda x: 1)

In [None]:
import os

# W&B API key: read from the environment (e.g. a Kaggle secret exported as
# WANDB_API_KEY) rather than pasting it into the notebook, where it would be
# saved and shared with the file. Without the variable this falls back to '',
# the value previously hard-coded here.
key = os.environ.get('WANDB_API_KEY', '')
proj_name = 'VKR25'
logger = WandBLogger(
    key=key,
    proj_name=proj_name,
    # Empty run name, as before (presumably W&B picks one — TODO confirm).
    name='',
    cfg={
        # Data-pipeline settings taken straight from `config` (same keys/order as before).
        **{setting: config[setting] for setting in ('MAX_LENGTH', 'OVERLAP', 'NUM_WORKERS', 'BATCH_SIZE', 'ACCUM_STEPS')},

        'D_MODEL' : D_MODEL,
        'NHEAD' : NHEAD,
        'NUM_LAYERS' : NUM_LAYERS,
        'DIM_FEEDFORWARD' : DIM_FEEDFORWARD,

        'OPTIMIZER' : 'AdamW',
        'LR' : LR,
        'EPS' : EPS,
        'WD' : WD,
    }
)

In [None]:
EVAL_EVERY = 1
CHECKPOINT_EVERY = 1

train_losses, valid_losses = [], []

# Counts every parameter, frozen ones included.
print(f'Total parameters: {sum(param.numel() for param in model.parameters())}')

for epoch in range(EPOCHS):
    # One training pass; the model's train_epoch receives the scheduler and logger.
    avg_loss = model.train_epoch(train_loader, optimizer, criterion, scheduler, logger)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}, Avg Loss: {avg_loss:.4f}")

    # Validate after every EVAL_EVERY-th epoch (1-based count).
    if not (epoch + 1) % EVAL_EVERY:
        val_loss = model.validate(valid_loader, criterion, logger)
        valid_losses.append(val_loss)
        print(f"Epoch {epoch+1}, Val Loss: {val_loss:.4f}")

    # Pickle the whole module (reloadable with torch.load) and refresh the loss
    # plot; both paths are relative to the current working directory.
    if not (epoch + 1) % CHECKPOINT_EVERY:
        torch.save(model, f"checkpoint_{epoch+1}.pt")
        plot_losses(train_losses, valid_losses, EVAL_EVERY, f'epoch{epoch+1}.png')

In [None]:
# Shut down the W&B logger once training has finished (presumably finishes the
# run and flushes pending metrics — confirm in logging/wandb_logger.py).
logger.kill()