In [None]:
# Clone the assignment repository from GitHub into the current working directory.
!git clone https://github.com/RInkalshah93/ERA-V2-Assignment_Rinkal-Shah.git

In [None]:
# Work from the S18 assignment folder so that relative paths (requirements.txt)
# and the notebook's local imports (config_file, train) resolve.
%cd ERA-V2-Assignment_Rinkal-Shah/S18_Assignment

In [None]:
!pip install -r requirements.txt

In [None]:
import torch

from config_file import get_config, get_weights_file_path
from train import get_model, get_ds, run_validation

# Start from the project's default configuration and override a few values
# for this notebook run.
config = get_config()
config["batch_size"] = 16
config["preload"] = None  # presumably "do not resume from a checkpoint" -- confirm in train.py
# NOTE(review): the training loop in this notebook does not read "num_epochs";
# confirm whether get_ds/get_model depend on it.
config["num_epochs"] = 1

# NOTE: a bare `torch.cuda.amp.autocast(enabled=True)` call used to live here.
# It only constructed a context-manager object and discarded it, which enables
# nothing. Mixed precision is applied where it matters: inside the training
# loop's `with torch.autocast(...)` block.


In [None]:
import warnings
from tqdm import tqdm
import os
from pathlib import Path

from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using Device:", device)

# Make sure the checkpoint folder (and any missing parent folders) exists.
Path(config["model_folder"]).mkdir(parents=True, exist_ok=True)

# Data loaders, tokenizers, and the model sized to the two vocabularies.
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(
    config,
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
).to(device)

# TensorBoard writer for the training-loss curve.
writer = SummaryWriter(config['experiment_name'])

# Cross-entropy that ignores padding positions and applies label smoothing.
# NOTE(review): the pad id is looked up in the *source* tokenizer although the
# labels are target tokens; this assumes both tokenizers assign the same id to
# '[PAD]' -- confirm against get_ds.
loss_fn = nn.CrossEntropyLoss(
    ignore_index=tokenizer_src.token_to_id('[PAD]'),
    label_smoothing=0.1,
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)


In [None]:
# Hyper-parameters for the one-cycle learning-rate schedule.
EPOCHS = 18                              # epochs the schedule is sized for
STEPS_PER_EPOCH = len(train_dataloader)  # one scheduler step per training batch
MAX_LR = 1e-3                            # peak learning rate of the cycle

In [None]:
# One-cycle learning-rate schedule, sized for EPOCHS * STEPS_PER_EPOCH steps.
# The LR starts at MAX_LR / div_factor, ramps linearly to MAX_LR, then anneals
# linearly down to (MAX_LR / div_factor) / final_div_factor.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    steps_per_epoch=STEPS_PER_EPOCH,
    epochs=EPOCHS,
    # Spend a whole number of epochs (~30% of the run) on the ramp-up; with a
    # single epoch that would be zero, so use half of the epoch instead.
    pct_start=int(0.3 * EPOCHS) / EPOCHS if EPOCHS != 1 else 0.5,
    div_factor=100,
    three_phase=False,
    final_div_factor=100,
    anneal_strategy="linear",
)

In [None]:
initial_epoch = 0
global_step = 0

# GradScaler scales the loss so that float16 gradients do not underflow.
scaler = torch.cuda.amp.GradScaler()
# Learning-rate history; the latest entry is shown in the progress bar.
lr = [0.0]

# Train for EPOCHS epochs: the OneCycleLR schedule is sized for
# EPOCHS * STEPS_PER_EPOCH steps and raises if it is stepped beyond that.
for epoch in range(initial_epoch, EPOCHS):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
    for batch in batch_iterator:
        optimizer.zero_grad(set_to_none=True)

        # Move the batch tensors to the training device.
        encoder_input = batch['encoder_input'].to(device)
        decoder_input = batch['decoder_input'].to(device)
        encoder_mask = batch['encoder_mask'].to(device)
        decoder_mask = batch['decoder_mask'].to(device)
        label = batch['label'].to(device)

        # Forward pass and loss under float16 autocast.
        with torch.autocast(device_type='cuda', dtype=torch.float16):
            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            # Flatten logits to (N, vocab) and labels to (N,) for cross-entropy.
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))

        loss_value = loss.item()
        batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}", "lr": f"{lr[-1]}"})

        writer.add_scalar('train loss', loss_value, global_step)
        writer.flush()

        # Backward pass on the scaled loss.
        scaler.scale(loss).backward()

        # If update() lowers the scale, inf/NaN gradients were found and the
        # optimizer step was skipped; do not advance the LR schedule then.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        optimizer_step_skipped = scale_before > scaler.get_scale()
        if not optimizer_step_skipped:
            scheduler.step()
        # get_last_lr() returns one value per parameter group; record the first
        # so the progress bar shows a number rather than a list.
        lr.append(scheduler.get_last_lr()[0])

        global_step += 1

    # Validate at the end of every epoch, printing through the progress bar.
    run_validation(
        model,
        val_dataloader,
        tokenizer_src,
        tokenizer_tgt,
        config['seq_len'],
        device,
        lambda msg: batch_iterator.write(msg),
        global_step,
    )

    # Save a checkpoint for this epoch.
    model_filename = get_weights_file_path(config, f"{epoch:02d}")
    torch.save({
        'epoch' : epoch,
        'model_state_dict' : model.state_dict(),
        'optimizer_state_dict' : optimizer.state_dict(),
        'global_step' : global_step
        }, model_filename)