In [None]:
%load_ext autoreload
%autoreload 2
import pandas as pd
import xml.etree.ElementTree as ET
import re
import numpy as np
import nltk
import torch
from torch import nn
import transformers
from nltk.metrics import windowdiff
from utils import *
from torch.utils.data import Dataset, DataLoader

In [None]:
class NSPDataset(Dataset):
    """Dataset of labelled message pairs tokenized for next-sentence prediction.

    Every element of ``message_pairs`` is indexed as
    ``((message_1, message_2), label)``.

    Args:
        message_pairs: Indexable, sized collection of ``((msg_1, msg_2), label)``.
        device: Stored on the instance as ``self.device``; the dataset itself
            never moves tensors to it.
        tokenizer: Optional callable with the Hugging Face tokenizer call
            signature. When omitted, ``AutoTokenizer`` for
            ``"bert-base-uncased"`` is loaded.
        max_length: Length every encoded pair is truncated / padded to.
    """

    # Tokenizer outputs that carry a leading batch dimension of size 1.
    _BATCHED_KEYS = ("input_ids", "token_type_ids", "attention_mask")

    def __init__(self, message_pairs, device=torch.device("cuda:0"), tokenizer=None, max_length=25):
        self.message_pairs = message_pairs
        if tokenizer is None:
            # AutoTokenizer is not imported explicitly in the visible import
            # cell, so import it here to keep the class self-contained.
            from transformers import AutoTokenizer

            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.device = device
        # Derive the labels from this dataset's own pairs instead of reading a
        # notebook-level `labels` global that may not match this split.
        self.labels = [pair[1] for pair in message_pairs]

    def __len__(self):
        """Return the number of message pairs."""
        return len(self.message_pairs)

    def __getitem__(self, idx):
        """Return ``(encoding, label)`` for the pair at position ``idx``.

        ``encoding`` is the tokenizer output with the batch dimension removed
        from ``input_ids``, ``token_type_ids`` and ``attention_mask``;
        ``label`` is the pair's label wrapped in a tensor.
        """
        (message_1, message_2), raw_label = self.message_pairs[idx][0], self.message_pairs[idx][1]

        tokenized_input = self.tokenizer(
            message_1,
            message_2,
            max_length=self.max_length,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )
        label = torch.tensor(raw_label)

        # return_tensors="pt" yields shape (1, max_length); drop the leading
        # axis so the DataLoader can stack items into (batch, max_length).
        for key in self._BATCHED_KEYS:
            tokenized_input[key] = tokenized_input[key][0]

        return tokenized_input, label

In [None]:
torch.manual_seed(42)

# Pair every message pair with its label, then hold out part of them for
# evaluation; each element of either split is (pair, label).
train_message_pairs, test_message_pairs = train_test_split(
    list(zip(message_pairs, labels)),
    random_state=42,
)

train_dataset = NSPDataset(train_message_pairs)
test_dataset = NSPDataset(test_message_pairs)
# The first 20 held-out examples form the small set that train() evaluates
# periodically during an epoch.
quick_test_dataset = NSPDataset(test_message_pairs[:20])

# Only the training loader shuffles.
train_dataloader = DataLoader(train_dataset, batch_size=80, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=80)
quick_test_dataloader = DataLoader(quick_test_dataset, batch_size=80)

In [None]:
# Average fraction of the 25 token slots (the max_length used by NSPDataset)
# holding a non-zero token id across the test set.
# NOTE(review): presumably id 0 is the padding token - confirm against the tokenizer.
s = 0
for encoding, _label in test_dataset:
    s += (encoding.input_ids > 0).sum()
s / 25 / len(test_dataset)

In [None]:
from tqdm import tqdm

def train(model, optimizer, scheduler, dataloader, device=torch.device("cuda:0")):
    """Train ``model`` for one pass over ``dataloader``.

    For every batch the gradients are reset, the cross-entropy loss between
    ``output.logits`` and the targets is back-propagated, and both the
    optimizer and the scheduler take one step. Every 100 batches (including
    the first) the model is evaluated on the notebook-level
    ``quick_test_dataloader``.

    Args:
        model: Callable as ``model(**inp)`` returning an object with ``.logits``.
        optimizer: Optimizer over the model's parameters.
        scheduler: Learning-rate scheduler, stepped once per batch.
        dataloader: Yields ``(inp, target)`` batches; ``inp`` must support
            ``.to(device)`` and ``**`` unpacking.
        device: Device the batches are moved to.

    Returns:
        The mean training loss over all batches (also printed).
    """
    loss_fn = nn.CrossEntropyLoss()
    # Models may arrive in eval mode (e.g. straight from `from_pretrained`),
    # so switch to train mode before the first update.
    model.train()
    losses = []
    # Wrap the dataloader (not the enumerate object) so tqdm knows the total.
    for num, (inp, target) in enumerate(tqdm(dataloader)):
        # Rebind: `.to` is only guaranteed to return the moved object.
        inp = inp.to(device)
        target = target.to(device)

        # Clear gradients left over from the previous batch; without this
        # they accumulate across steps.
        optimizer.zero_grad()
        output = model(**inp)

        loss = loss_fn(output.logits, target)
        losses.append(loss.item())
        loss.backward()
        optimizer.step()
        scheduler.step()

        if num % 100 == 0:
            validate(model, quick_test_dataloader, device=device)

    mean_loss = np.mean(losses)
    print(mean_loss)
    return mean_loss

def validate(mode, dataloader, device=torch.device("cuda:0")):
    """Compute and print the mean cross-entropy loss of a model on ``dataloader``.

    The model is put in eval mode and run without gradient tracking; it is
    switched back to train mode before returning, even if evaluation fails.

    Args:
        mode: The model to evaluate. The parameter name is kept as-is for
            backward compatibility with existing callers.
        dataloader: Yields ``(inp, target)`` batches; ``inp`` must support
            ``.to(device)`` and ``**`` unpacking.
        device: Device the batches are moved to.

    Returns:
        The mean loss over all batches (also printed).
    """
    # Evaluate the model that was passed in rather than whatever notebook
    # global happens to be called `model`.
    model = mode
    loss_fn = nn.CrossEntropyLoss()
    model.eval()
    losses = []
    try:
        with torch.no_grad():
            for inp, target in tqdm(dataloader):
                # Rebind: `.to` is only guaranteed to return the moved object.
                inp = inp.to(device)
                output = model(**inp)
                loss = loss_fn(output.logits, target.to(device))
                losses.append(loss.item())
    finally:
        # Callers rely on the model being left in train mode afterwards.
        model.train()
    mean_loss = np.mean(losses)
    print(mean_loss)
    return mean_loss


In [None]:
# Free memory between experiments: run Python's garbage collector first, then
# ask PyTorch to release its cached, currently unused CUDA memory.
import gc
gc.collect()
torch.cuda.empty_cache()

# NOTE(review): commented-out manual teardown; it appears meant to be run
# before the two calls above when the objects must be dropped - confirm.
# model.cpu()
# del quick_test_dataloader
# del train_dataloader
# del test_dataloader
# del model
# del optimizer

In [None]:
# An earlier variant of this cell loaded "distilbert-base-uncased" instead.
lr = 1e-5
model = (
    BertForNextSentencePrediction
    .from_pretrained("prajjwal1/bert-medium")
    .to("cuda:0")
)
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [None]:
epochs = 2
# train() calls scheduler.step() once per batch, so the one-cycle schedule
# must be sized in batches (len(train_dataloader)), not in samples
# (len(train_dataset)); sizing it in samples stretches the cycle so that it
# never completes within the training run.
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer,
    max_lr=lr,
    steps_per_epoch=len(train_dataloader),
    epochs=epochs,
    anneal_strategy="linear",
)

In [None]:
# One training pass followed by a full evaluation on the test set, per epoch.
for _epoch in range(epochs):
    train(model, optimizer, scheduler, train_dataloader)
    validate(model, test_dataloader)
