In [1]:
from pathlib import Path

import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from torch.utils.data import DataLoader

from dataset import TextPairDataset
from model import CrossEncoder
from utils import train

In [2]:
# Prefer the first GPU, but fall back to the CPU when CUDA is unavailable so the
# notebook fails early and clearly (or runs slowly) instead of crashing mid-training.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(device)

cuda:0


In [35]:
# Pretrained checkpoint name; reused below for the dataset wrapper and the model.
model_name = "prajjwal1/bert-tiny"

# Training split of the "ms_marco" dataset, configuration "v2.1".
msmarco_train_dataset = load_dataset(
    path="ms_marco",
    name="v2.1",
    split="train",
)

In [36]:
# Wrap the raw MS MARCO split in the project's pair dataset, built with the same
# checkpoint name that is later used for the model.
train_dataset = TextPairDataset(msmarco_train_dataset, model_name)

# Reshuffled mini-batches of 128 examples.
train_loader = DataLoader(
    dataset=train_dataset,
    shuffle=True,
    batch_size=128,
)

In [5]:
# Build the cross-encoder from the checkpoint name and print its module tree
# as a quick sanity check of the architecture.
model = CrossEncoder(model_name)
print(model)

CrossEncoder(
  (transformer_base): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 128, padding_idx=0)
      (position_embeddings): Embedding(512, 128)
      (token_type_embeddings): Embedding(2, 128)
      (LayerNorm): LayerNorm((128,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-1): 2 x BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=128, out_features=128, bias=True)
              (key): Linear(in_features=128, out_features=128, bias=True)
              (value): Linear(in_features=128, out_features=128, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=128, out_features=128, bias=True)
              (LayerNorm): LayerNorm((128,), eps=1e-12, elemen

In [6]:
# Binary cross-entropy on the model output.
# NOTE(review): nn.BCELoss expects probabilities in [0, 1]; this assumes CrossEncoder
# already applies a sigmoid — confirm against model.py.
criterion = nn.BCELoss()
# NOTE(review): the recorded train loss only moves from ~0.240 to ~0.234 over 15 epochs;
# lr=1e-3 may be too high for fine-tuning a pretrained transformer — worth re-checking.
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Training-only run (no validation loader, eval disabled). `train` returns the list of
# per-epoch training losses; bind it to a name so the result of this multi-hour run can
# be plotted or saved later instead of living only in the cell's output.
train_losses = train(
    model=model,
    train_loader=train_loader,
    val_loader=None,
    criterion=criterion,
    optimizer=optimizer,
    num_epoch=15,
    device=device,
    eval=False,
)
train_losses

  7%|▋         | 1/15 [24:00<5:36:08, 1440.62s/it]

Epoch 1/15: train_loss = 0.24010867713198178


 13%|█▎        | 2/15 [49:04<5:20:11, 1477.81s/it]

Epoch 2/15: train_loss = 0.24173076811857777


 20%|██        | 3/15 [1:14:30<5:00:00, 1500.01s/it]

Epoch 3/15: train_loss = 0.24149719883929063


 27%|██▋       | 4/15 [1:38:04<4:28:46, 1466.04s/it]

Epoch 4/15: train_loss = 0.2410212155454501


 33%|███▎      | 5/15 [2:01:38<4:01:10, 1447.05s/it]

Epoch 5/15: train_loss = 0.24034988023323015


 40%|████      | 6/15 [2:25:12<3:35:21, 1435.75s/it]

Epoch 6/15: train_loss = 0.2399771871161076


 47%|████▋     | 7/15 [2:48:44<3:10:24, 1428.08s/it]

Epoch 7/15: train_loss = 0.2395195520065417


 53%|█████▎    | 8/15 [3:12:55<2:47:28, 1435.48s/it]

Epoch 8/15: train_loss = 0.23888007855633236


 60%|██████    | 9/15 [3:36:38<2:23:09, 1431.58s/it]

Epoch 9/15: train_loss = 0.23793003944764693


 67%|██████▋   | 10/15 [4:00:18<1:59:00, 1428.10s/it]

Epoch 10/15: train_loss = 0.23725168486904694


 73%|███████▎  | 11/15 [4:23:58<1:35:01, 1425.37s/it]

Epoch 11/15: train_loss = 0.23634171194079595


 80%|████████  | 12/15 [4:47:37<1:11:10, 1423.43s/it]

Epoch 12/15: train_loss = 0.23579277183661865


 87%|████████▋ | 13/15 [5:11:11<47:21, 1420.60s/it]  

Epoch 13/15: train_loss = 0.23569030152560977


 93%|█████████▎| 14/15 [10:11:52<1:47:20, 6440.71s/it]

Epoch 14/15: train_loss = 0.23530033503737627


100%|██████████| 15/15 [10:36:42<00:00, 2546.85s/it]  

Epoch 15/15: train_loss = 0.23445593489652042





[0.24010867713198178,
 0.24173076811857777,
 0.24149719883929063,
 0.2410212155454501,
 0.24034988023323015,
 0.2399771871161076,
 0.2395195520065417,
 0.23888007855633236,
 0.23793003944764693,
 0.23725168486904694,
 0.23634171194079595,
 0.23579277183661865,
 0.23569030152560977,
 0.23530033503737627,
 0.23445593489652042]

In [7]:
# Make sure the target directory exists first: a missing "models/" folder would
# otherwise raise only now, after the whole training run has finished.
model_path = Path("models/ce-model-v21.pth")
model_path.parent.mkdir(parents=True, exist_ok=True)

# NOTE: this pickles the whole module object (not just a state_dict), so loading the
# file later requires the CrossEncoder class to be importable from the same module path.
# The format is kept as-is so existing loaders of this checkpoint keep working.
torch.save(model, model_path)