In [1]:
!pip install tokenizers
!pip install torchtext
!pip install pytorch_lightning
!pip install datasets
!pip install tensorboard
!pip install lion_pytorch

Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch>=2.3.0->torchtext)
  Using cached nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
Collecting nvidia-cuda-runtime-cu12==12.1.105 (from torch>=2.3.0->torchtext)
  Using cached nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
Collecting nvidia-cuda-cupti-cu12==12.1.105 (from torch>=2.3.0->torchtext)
  Using cached nvidia_cuda_cupti_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (14.1 MB)
Collecting nvidia-cudnn-cu12==8.9.2.26 (from torch>=2.3.0->torchtext)
  Using cached nvidia_cudnn_cu12-8.9.2.26-py3-none-manylinux1_x86_64.whl (731.7 MB)
Collecting nvidia-cublas-cu12==12.1.3.1 (from torch>=2.3.0->torchtext)
  Using cached nvidia_cublas_cu12-12.1.3.1-py3-none-manylinux1_x86_64.whl (410.6 MB)
Collecting nvidia-cufft-cu12==11.0.2.54 (from torch>=2.3.0->torchtext)
  Using cached nvidia_cufft_cu12-11.0.2.54-py3-none-manylinux1_x86_64.whl (121.6 MB)
Collecting nvidia-curand-cu12==10.3.2.106 

In [1]:
from lion_pytorch import Lion
import torch
from tqdm import tqdm
import torchmetrics
from torch.utils.tensorboard import SummaryWriter
import torch.nn as nn

In [2]:
from config_file import get_config, get_weights_file_path
from train import train_model, get_ds, get_model

# Start from the project's default configuration and override the values
# used for this run.
config = get_config()
config["batch_size"] = 24
config["preload"] = None  # presumably: do not resume from a saved checkpoint -- confirm in config_file
# NOTE(review): the training loop in this notebook iterates over EPOCHS,
# not config["num_epochs"]; keep the two in sync if both are meant to apply.
config["num_epochs"] = 25

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device : {device}")

# Mixed precision is NOT enabled here: autocast only takes effect while its
# context manager is entered (`with torch.autocast(...):`), which the training
# loop does around the forward pass. Constructing an autocast object and
# discarding it, as this cell previously did, has no effect.





Using device : cuda


<torch.cuda.amp.autocast_mode.autocast at 0x7aa2566e6f20>

In [3]:
# Build the dataloaders and both tokenizers, then a model sized to the
# source and target vocabularies, placed on the selected device.
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
model = get_model(
    config,
    tokenizer_src.get_vocab_size(),
    tokenizer_tgt.get_vocab_size(),
).to(device)

# Lion optimizer with decoupled weight decay.
optimizer = Lion(model.parameters(), lr=config["lr"], weight_decay=1e-2)

# TensorBoard writer; events are written under config["experiment_name"].
writer = SummaryWriter(config["experiment_name"])

# The labels fed to the loss are target-language token ids (the logits are
# reshaped with the target vocabulary size), so the padding id to ignore must
# be looked up in the *target* tokenizer, not the source one.
loss_fn = nn.CrossEntropyLoss(
    ignore_index=tokenizer_tgt.token_to_id("[PAD]"),
    label_smoothing=0.1,
)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Max length of the source sentence : 309
Max length of the source target : 274


In [4]:
# Constants consumed by the one-cycle learning-rate schedule and the training loop.
MAX_LR = 1e-4                            # peak learning rate of the cycle
STEPS_PER_EPOCH = len(train_dataloader)  # number of batches yielded per epoch
EPOCHS = 18                              # number of epochs the loop will run


In [5]:
# Fraction of the schedule spent ramping the learning rate up: a whole number
# of epochs, rounded down from 30% of the run. A single-epoch run would round
# down to zero, so it spends half of its steps warming up instead.
if EPOCHS == 1:
    warmup_fraction = 0.5
else:
    warmup_fraction = int(0.3 * EPOCHS) / EPOCHS

# Two-phase, linearly annealed one-cycle schedule, stepped once per batch.
# The learning rate starts at MAX_LR / div_factor, rises to MAX_LR, then
# decays to (MAX_LR / div_factor) / final_div_factor.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=MAX_LR,
    epochs=EPOCHS,
    steps_per_epoch=STEPS_PER_EPOCH,
    pct_start=warmup_fraction,
    anneal_strategy="linear",
    three_phase=False,
    div_factor=100,
    final_div_factor=100,
)

In [6]:
# Mixed-precision training loop: one optimizer step and (normally) one
# scheduler step per batch, with a checkpoint written after every epoch.
initial_epoch = 0
global_step = 0

# Scales the loss before backward so float16 gradients do not underflow.
scaler = torch.cuda.amp.GradScaler()
# Learning-rate history: a 0.0 placeholder followed by one
# `scheduler.get_last_lr()` result (a list) per training step.
lr = [0.0]

# Loop-invariant: used to flatten the logits to (-1, vocab) for the loss.
tgt_vocab_size = tokenizer_tgt.get_vocab_size()

for epoch in range(initial_epoch, EPOCHS):
  torch.cuda.empty_cache()
  print("Starting the epoch : ", epoch)
  model.train()
  batch_iterator = tqdm(train_dataloader, desc = f"Processing Epoch {epoch:02d}")

  for batch in batch_iterator:
    optimizer.zero_grad(set_to_none=True)

    # Move this batch's tensors to the training device.
    encoder_input = batch["encoder_input"].to(device)
    decoder_input = batch["decoder_input"].to(device)
    encoder_mask = batch["encoder_mask"].to(device)
    decoder_mask = batch["decoder_mask"].to(device)
    label = batch["label"].to(device)

    # Forward pass and loss under float16 autocast.
    with torch.autocast(device_type='cuda', dtype= torch.float16):
      encoder_output = model.encode(encoder_input, encoder_mask)
      decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask)
      proj_output = model.project(decoder_output)

      # Cross entropy over the flattened logits and labels.
      loss = loss_fn(proj_output.view(-1, tgt_vocab_size), label.view(-1))

    # Read the scalar once: every `.item()` call waits for the GPU.
    loss_value = loss.item()
    batch_iterator.set_postfix({"loss": f"{loss_value:6.3f}", " lr ":f"{lr[-1]}"})

    # Log the loss; the writer is flushed once per epoch below.
    writer.add_scalar('train_loss', loss_value, global_step)

    # Backpropagate the scaled loss.
    scaler.scale(loss).backward()

    # Update weights. GradScaler skips the optimizer step and lowers its scale
    # when it finds inf/NaN gradients; in that case the scheduler must not
    # advance either, so compare the scale before and after the update.
    scale = scaler.get_scale()
    scaler.step(optimizer)
    scaler.update()
    skip_lr_sched = (scale > scaler.get_scale())
    if not skip_lr_sched:
      scheduler.step()
    lr.append(scheduler.get_last_lr())

    global_step+=1

  # Make sure this epoch's events are on disk before checkpointing.
  writer.flush()

  # Validation is intentionally not run here (risk of failing - taking too long).
  # To re-enable:
  # run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, writer, global_step)

  # Checkpoint everything needed to resume: besides model and optimizer, the
  # scheduler position and the GradScaler state are stored so a resumed run
  # can continue the same one-cycle schedule and loss scale.
  model_filename = get_weights_file_path(config, f"{epoch:02d}")
  torch.save(
      {
          "epoch": epoch,
          "model_state_dict": model.state_dict(),
          "optimizer_state_dict": optimizer.state_dict(),
          "scheduler_state_dict": scheduler.state_dict(),
          "scaler_state_dict": scaler.state_dict(),
          "global_step": global_step
      },
      model_filename
  )

Starting the epoch :  0


Processing Epoch 00: 100%|██████████| 1213/1213 [04:08<00:00,  4.89it/s, loss=6.590,  lr =[2.078693931398417e-05]]


Starting the epoch :  1


Processing Epoch 01: 100%|██████████| 1213/1213 [04:08<00:00,  4.88it/s, loss=5.656,  lr =[4.059020448548813e-05]]


Starting the epoch :  2


Processing Epoch 02: 100%|██████████| 1213/1213 [04:09<00:00,  4.86it/s, loss=5.267,  lr =[6.039346965699209e-05]]


Starting the epoch :  3


Processing Epoch 03: 100%|██████████| 1213/1213 [04:12<00:00,  4.80it/s, loss=5.428,  lr =[8.019673482849604e-05]]


Starting the epoch :  4


Processing Epoch 04: 100%|██████████| 1213/1213 [04:10<00:00,  4.84it/s, loss=4.564,  lr =[0.0001]]


Starting the epoch :  5


Processing Epoch 05: 100%|██████████| 1213/1213 [04:08<00:00,  4.87it/s, loss=4.466,  lr =[9.231480246052382e-05]]


Starting the epoch :  6


Processing Epoch 06: 100%|██████████| 1213/1213 [04:09<00:00,  4.86it/s, loss=4.454,  lr =[8.462960492104763e-05]]


Starting the epoch :  7


Processing Epoch 07: 100%|██████████| 1213/1213 [04:10<00:00,  4.85it/s, loss=3.690,  lr =[7.693806645950916e-05]]


Starting the epoch :  8


Processing Epoch 08: 100%|██████████| 1213/1213 [04:09<00:00,  4.87it/s, loss=3.613,  lr =[6.925286892003298e-05]]


Starting the epoch :  9


Processing Epoch 09: 100%|██████████| 1213/1213 [04:08<00:00,  4.88it/s, loss=3.882,  lr =[6.156767138055678e-05]]


Starting the epoch :  10


Processing Epoch 10: 100%|██████████| 1213/1213 [04:08<00:00,  4.87it/s, loss=2.951,  lr =[5.3876132919018325e-05]]


Starting the epoch :  11


Processing Epoch 11: 100%|██████████| 1213/1213 [04:08<00:00,  4.88it/s, loss=2.690,  lr =[4.618459445747987e-05]]


Starting the epoch :  12


Processing Epoch 12: 100%|██████████| 1213/1213 [04:08<00:00,  4.88it/s, loss=2.390,  lr =[3.8499396918003685e-05]]


Starting the epoch :  13


Processing Epoch 13: 100%|██████████| 1213/1213 [04:09<00:00,  4.87it/s, loss=2.458,  lr =[3.080785845646521e-05]]


Starting the epoch :  14


Processing Epoch 14: 100%|██████████| 1213/1213 [04:10<00:00,  4.85it/s, loss=2.146,  lr =[2.3122660916989024e-05]]


Starting the epoch :  15


Processing Epoch 15: 100%|██████████| 1213/1213 [04:09<00:00,  4.87it/s, loss=2.162,  lr =[1.543746337751284e-05]]


Starting the epoch :  16


Processing Epoch 16: 100%|██████████| 1213/1213 [04:08<00:00,  4.88it/s, loss=1.857,  lr =[7.745924915974377e-06]]


Starting the epoch :  17


Processing Epoch 17: 100%|██████████| 1213/1213 [04:10<00:00,  4.84it/s, loss=1.787,  lr =[5.4386454435912854e-08]]
