In [1]:
import torch
import torchinfo
import torch.nn.functional as F
from torch import nn
from torch import Tensor
from torch.utils.data import (
    DataLoader,
    SequentialSampler,
    BatchSampler,
    RandomSampler
)
from torch.utils.tensorboard import SummaryWriter
from torch.optim import Adam
from torch.optim.lr_scheduler import ExponentialLR, CyclicLR
from torcheval import metrics
from torch.profiler import (
    profile,
    ProfilerActivity,
    tensorboard_trace_handler
)
import datasets
from datasets import (
    Features,
    Array2D
)
from transformers import (
    AutoTokenizer,
    BertTokenizer
)
import os
import math
from tqdm import tqdm
from functools import partial
from typing import (
    Optional,
    Dict 
)
import lightning as L

import pandas as pd
from src import data, neural, func

# `config` is read by most of the cells below. The import is best-effort so the
# import cell still runs without a config.py, but the failure is reported
# instead of being hidden. Only ImportError is tolerated: any other exception
# raised while importing config.py is a real error and is allowed to propagate.
try:
    from config import config
except ImportError as exc:
    import warnings

    warnings.warn(
        f"Could not import `config` from config.py ({exc}); "
        "cells that read `config` will fail until it is defined."
    )

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Build the three data loaders from the project config. The helper returns
# them in the order train, test, validation.
(
    train_dataloader,
    test_dataloader,
    validation_dataloader,
) = data.create_dataloaders(config)

## Pure torch

### Test different configs

In [None]:
# Run the project's `test_model` helper with the current config on the training loader.
# NOTE(review): what the helper measures is defined in src/func; given the section
# title it is presumably a timing/throughput check per configuration — confirm.
func.test_model(config, train_dataloader)

In [None]:
# Scratch arithmetic over three hand-entered triples: one sum and two means,
# printed one per line.
# NOTE(review): the literals look like measurements copied from three runs;
# their meaning and units are not recorded here — confirm and name them.
print(
    20 + 16 + 16,
    (48.5 + 60.1 + 60.6) / 3,
    (0.006 + 0.006 + 0.005) / 3,
    sep = "\n",
)

# Echo the config values the numbers above were obtained with.
print(f'''
WORKERS: {config["DATA"]["DATALOADER_NUM_WORKERS"]};
NON_BLOCKING: {config["DATA"]["NON_BLOCKING"]};
DTYPE: {config["MODEL"]["DTYPE"]};
WARMUP: {config["MODEL"]["WARMUP"]};
TF32: {config["MODEL"]["TF32"]}'''
)

### Train model

In [None]:
# Tokenizer: pretrained BERT vocabulary with "[EOS]" / "[BOS]" registered as the
# eos / bos special tokens.
bert_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert_tokenizer.add_special_tokens({
    "eos_token": "[EOS]",
    "bos_token": "[BOS]"
})

# Model: hyper-parameters come from the config, except the activation and the
# encoder/decoder layer counts which are fixed here.
# `len(bert_tokenizer)` is used for the vocabulary size so that the tokens added
# above are counted as well.
model = neural.Transformer(
    vocab_size = len(bert_tokenizer),
    seq_len = config["MODEL"]["SEQ_LEN"],
    emb_dim = config["MODEL"]["EMB_DIM"],
    n_heads = config["MODEL"]["ATTN_HEADS"],
    feedforward_dim = config["MODEL"]["FF_DIM"],
    dropouts = config["MODEL"]["DROPOUTS"],
    dtype = config["MODEL"]["DTYPE"],
    activation = nn.LeakyReLU(),
    enc_num = 5,
    dec_num = 5,
).to(config["DEVICE"])

writer = func.create_run_logger(config["TRAINING"]["LOGS_FOLDER"], model)

# Padding positions are excluded from the loss. The pad id is taken from the
# tokenizer rather than hard-coded (it is 0 for bert-base-uncased).
loss_fn = nn.CrossEntropyLoss(ignore_index = bert_tokenizer.pad_token_id)

# No explicit lr is passed to Adam: CyclicLR drives the learning rate between
# base_lr and max_lr. `cycle_momentum = False` stops the scheduler from also
# cycling the optimizer's momentum/beta settings.
# NOTE(review): `step_size_up = 10` is counted in scheduler steps; whether the
# trainer steps the scheduler per batch or per epoch is not visible here — confirm.
optimizer = Adam(params = model.parameters(), betas = (0.9, 0.999), eps = 1e-8)
model_metrics = {"BLEU": metrics.BLEUScore(n_gram = 3)}
scheduler = CyclicLR(optimizer, base_lr = 1e-4, max_lr = 1e-3, step_size_up = 10, cycle_momentum = False)

# NOTE(review): max_src_len / max_tgt_len are hard-coded; presumably they match
# the lengths produced by the data pipeline — confirm or move them to the config.
trainer = neural.Trainer(
    model = model,
    optimizer = optimizer,
    loss_fn = loss_fn,
    scheduler = scheduler,
    tokenizer = bert_tokenizer,
    epoch = config["TRAINING"]["EPOCH"],
    device = config["DEVICE"],
    checkpoint_path = config["TRAINING"]["CHECKPOINT_PATH"],
    checkpoint_by = config["TRAINING"]["CHECKPOINT_BY"],
    non_blocking = config["DATA"]["NON_BLOCKING"],
    warmup = config["MODEL"]["WARMUP"],
    max_src_len = 191,
    max_tgt_len = 80,
    batch_size = config["DATA"]["BATCH_SIZE"],
    model_metrics = model_metrics
)

In [None]:
# Restore model weights and scheduler state from a saved training state.
state_dict = torch.load("train_states/best_state_5dec.pt")

# `state_dict()` returns a fresh dictionary, so calling `.update()` on it changes
# only that temporary object and restores nothing. `load_state_dict` is the call
# that actually copies the saved values into the model / scheduler.
model.load_state_dict(state_dict["model_state"])
scheduler.load_state_dict(state_dict["scheduler_state"])

# Reset the BLEU value stored on the trainer.
# NOTE(review): presumably the best-score tracker used for checkpointing — confirm.
trainer.BLEU = 0

In [5]:
# Load the checkpoint file named in the config and hand it to the trainer's
# `load_state_dict`.
# NOTE(review): which components (model, optimizer, scheduler, counters) the trainer
# restores from it is defined in src/neural — confirm.
trainer.load_state_dict(torch.load(config["TRAINING"]["CHECKPOINT_PATH"]))

In [None]:
# Start training on the train loader, using the validation loader for periodic
# validation and the run logger created with the model for logging.
# NOTE(review): the unit of `validation_rate` (epochs vs. steps) is defined by
# the trainer — confirm.
trainer.train(
    train_dataloader,
    validation_dataloader,
    validation_rate = 2,
    writer = writer,
)

In [None]:
# Run the trainer's validation pass over the validation loader with the metrics
# dictionary built alongside the model (currently only BLEU).
trainer.validate(validation_dataloader, model_metrics)

In [None]:
# Qualitative check: take the first validation batch, run the model on it and
# show the decoded prediction next to the decoded reference for one example.
sample = next(iter(validation_dataloader))
index = 2

# The reference summary is passed to the model as its second argument.
# NOTE(review): this is presumably a teacher-forced prediction rather than free
# generation — confirm against the model's forward().
model.eval()
with torch.inference_mode():
    pred = model(
        sample["document"].to(config["DEVICE"]),
        sample["summary"].to(config["DEVICE"]),
    ).argmax(-1)

# (predicted text, reference text) for the selected example.
(
    bert_tokenizer.decode(pred[index], skip_special_tokens = True),
    bert_tokenizer.decode(sample["summary"][index], skip_special_tokens = True),
)

In [None]:
# Gradient diagnostics: for every training batch, run forward + backward (no
# optimizer step) and log the mean absolute gradient of each non-bias parameter
# to the run logger under "test_grads".

# Drop gradients left over from earlier cells so the first logged step is not
# contaminated by them.
optimizer.zero_grad()

for step, batch in enumerate(train_dataloader):
    x = batch["document"].to(config["DEVICE"], non_blocking = config["DATA"]["NON_BLOCKING"])
    y = batch["summary"].to(config["DEVICE"], non_blocking = config["DATA"]["NON_BLOCKING"])

    y_pred = model(x, y)

    # Build the target: shift the summary left by one token and append one
    # padding (0) column, then overwrite the first 0 of each row with EOS.
    # The row index is sized from the actual batch, not the configured
    # BATCH_SIZE, so a smaller final batch does not raise a shape error.
    summary_with_eos = F.pad(y[:, 1:], pad = (0, 1), value = 0)
    row_idx = torch.arange(y.shape[0], device = config["DEVICE"])
    summary_with_eos[row_idx, summary_with_eos.argmin(dim = 1)] = bert_tokenizer.eos_token_id

    loss = loss_fn(y_pred.view(-1, y_pred.shape[2]), summary_with_eos.view(-1))
    loss.backward()

    ave_grads = []
    max_grads = []
    layers = []
    for n, p in model.named_parameters():
        # Skip biases, frozen parameters, and parameters that received no
        # gradient in this backward pass (their `.grad` is None).
        if p.requires_grad and p.grad is not None and "bias" not in n:
            grad_abs = p.grad.abs()
            layers.append(n)
            ave_grads.append(grad_abs.mean().cpu())
            max_grads.append(grad_abs.max().cpu())

    # Only the mean magnitudes are logged; `max_grads` is collected but not
    # written anywhere.
    writer.add_scalars("test_grads", dict(zip(layers, ave_grads)), step)

    # No optimizer step is taken; clear the gradients before the next batch.
    optimizer.zero_grad()