In [None]:
from config import get_config, get_weights_file_path
import warnings
from tqdm import tqdm
import os
from pathlib import Path
import torchmetrics
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

# Start from the project defaults and override only what this run changes.
config = get_config()
config['batch_size'] = 64
config['preload'] = None
# NOTE(review): the training cell below iterates over EPOCHS, not
# config['num_epochs'] — confirm which of the two values is intended.
config['num_epochs'] = 60

from train import get_model, get_ds, run_validation
import torch
# Mixed precision is switched on per forward pass by the `torch.autocast(...)`
# context manager inside the training loop. The bare
# `torch.cuda.amp.autocast(enabled=True)` call that used to sit here only
# constructed a context manager and discarded it, so it had no effect and
# was removed.

In [None]:
# Probe whether this PyTorch build exposes
# torch.nn.functional.scaled_dot_product_attention; the cell displays the
# resulting boolean.
attention_fn_name = 'scaled_dot_product_attention'
hasattr(torch.nn.functional, attention_fn_name)

In [None]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f'device : {device}')

# Create the model folder up front (presumably where get_weights_file_path
# places the checkpoints written by the training loop — confirm in config.py).
Path(config['model_folder']).mkdir(parents=True, exist_ok=True)

train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)

writer = SummaryWriter(config['experiment_name'])

# The loss is computed over target-vocabulary logits and target labels (the
# training loop reshapes the logits with tokenizer_tgt.get_vocab_size()), so
# the padding id to ignore must come from the *target* tokenizer. The source
# tokenizer's id only matches when both vocabularies happen to assign "[PAD]"
# the same index.
loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_tgt.token_to_id("[PAD]"), label_smoothing=0.1).to(device)

# config['lr'] is only the optimizer's construction-time value; the
# OneCycleLR scheduler created in a later cell drives the per-step rate.
optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

In [None]:
# Name of the reduced-precision dtype this machine supports: bfloat16 when a
# CUDA device reports support for it, float16 otherwise.
# NOTE(review): nothing in the visible cells reads this value — the training
# loop passes torch.float16 to autocast directly. Confirm whether it should.
if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
    dtype = 'bfloat16'
else:
    dtype = 'float16'
dtype

In [None]:
def get_lr(optimizer):
    """Return the learning rate of the optimizer's first parameter group.

    Args:
        optimizer: any object exposing ``param_groups``, an iterable of
            dicts that each carry an ``'lr'`` entry.

    Returns:
        The ``'lr'`` value of the first group, or ``None`` when the
        optimizer has no parameter groups.
    """
    first_group = next(iter(optimizer.param_groups), None)
    return None if first_group is None else first_group['lr']

In [None]:
# Hyper-parameters shared by the one-cycle scheduler and the training loop.
MAX_LR = 1e-3  # peak learning rate handed to the scheduler
# NOTE(review): config['num_epochs'] is set to 60 in the first cell, yet the
# scheduler and the training loop both use this value — confirm the intent.
EPOCHS = 10
STEPS_PER_EPOCH = len(train_dataloader)  # optimizer steps in one pass over the training data

In [None]:
# One-cycle schedule with linear ramps in three phases. The rising phase
# takes 10% of all steps, or 50% when training for a single epoch.
# div_factor=10 starts the rate at MAX_LR / 10, and final_div_factor=10 ends
# it a further factor of 10 below that starting rate.
warmup_fraction = 0.5 if EPOCHS == 1 else 1/10

scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    pct_start=warmup_fraction,
    anneal_strategy='linear',
    three_phase=True,
    div_factor=10,
    final_div_factor=10,
)

In [None]:
initial_epoch = 0
global_step = 0

# Mixed precision only applies on CUDA. Gating both autocast and the gradient
# scaler on the device type keeps a CPU run in plain fp32 instead of relying
# on PyTorch to warn and disable them implicitly.
use_amp = device.type == 'cuda'
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Learning-rate trace: one float per training step, seeded with 0.0.
lr = [0.0]

for epoch in range(initial_epoch, EPOCHS):
    torch.cuda.empty_cache()
    model.train()
    batch_iterator = tqdm(train_dataloader, desc=f'Processing epoch {epoch: 02d}')
    for batch in batch_iterator:
        optimizer.zero_grad(set_to_none=True)

        # Move every tensor of the batch to the training device.
        encoder_input = batch['encoder_input'].to(device)
        decoder_input = batch['decoder_input'].to(device)
        encoder_mask = batch['encoder_mask'].to(device)
        decoder_mask = batch['decoder_mask'].to(device)
        label = batch['label'].to(device)

        # Forward pass and loss run under autocast (fp16 on CUDA).
        with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
            encoder_output = model.encode(encoder_input, encoder_mask)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
            proj_output = model.project(decoder_output)

            # Flatten logits to (-1, target vocab size) and labels to (-1,).
            loss = loss_fn(proj_output.view(-1, tokenizer_tgt.get_vocab_size()), label.view(-1))

        # Read the scalar once; each .item() call synchronises with the device.
        loss_value = loss.item()
        batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}"})

        writer.add_scalar('train_loss', loss_value, global_step)
        writer.flush()

        scaler.scale(loss).backward()

        # GradScaler skips optimizer.step() when it finds inf/NaN gradients
        # and then lowers its scale in update(). A drop in scale therefore
        # marks a skipped step, and the scheduler must not advance for it.
        scale_before = scaler.get_scale()
        scaler.step(optimizer)
        scaler.update()
        optimizer_step_skipped = scale_before > scaler.get_scale()

        if not optimizer_step_skipped:
            scheduler.step()
        # get_last_lr() returns one value per parameter group; the optimizer
        # is built from model.parameters() and so has a single group. Storing
        # the float keeps `lr` a flat list, consistent with its 0.0 seed.
        lr.append(scheduler.get_last_lr()[0])

        global_step += 1

    run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

    # Checkpoint after every epoch. Scheduler and scaler state are stored
    # next to the model and optimizer so a resumed run can restore the
    # learning-rate schedule and the loss scale as well.
    model_filename = get_weights_file_path(config, f"{epoch:02d}")
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'global_step': global_step
    }, model_filename)