In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import numpy as np
import json
import pandas as pd
from tqdm.auto import tqdm
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
import warnings
from torch import optim
# Emit every matching warning each time it is raised instead of suppressing repeats.
warnings.filterwarnings('always')
# When True, the model-setup cell restores weights from `model_path` before training starts.
load = False

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('device: ' + str(device))

In [None]:
# Tokenizer for the same 'bert-base-uncased' checkpoint that the model cell loads as its encoder.
# NOTE(review): `truncation=True` is presumably just forwarded to the tokenizer constructor;
# the dataset already requests truncation on every encode call — confirm this kwarg is needed.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', truncation=True)

In [None]:
class ClickBaitDataSet(Dataset):
    """Dataset of tokenised (heading, body) pairs with their label.

    Rows of ``df`` that contain any missing value are dropped. Each item is the
    tuple ``(hi, hm, ht, bi, bm, bt, label)`` where the ``h*`` / ``b*`` entries
    are the input ids, attention mask and token type ids of the heading / body,
    each a 1-D integer array padded or truncated to ``max_length``.

    Args:
        df: frame with "heading", "body" and "label" columns.
        tokenizer: object exposing a HuggingFace-style ``encode_plus`` method.
        max_length: fixed token length of every encoded field. The model in
            this notebook is built for 512 steps, so keep the default when
            feeding it.
    """

    # Text substituted for blank fields so the tokenizer never receives empty input.
    PLACEHOLDER_TEXT = "something"

    def __init__(self, df, tokenizer, max_length=512):
        # drop=True: do not keep the pre-dropna index around as a stray "index" column.
        self.df = df.dropna().reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length

    def _encode(self, text):
        """Tokenise one text field into (input_ids, attention_mask, token_type_ids) arrays."""
        if not isinstance(text, str):
            # e.g. a numeric cell parsed from the CSV; the tokenizer needs a string.
            text = str(text)
        if not text.strip():
            # Covers "", " " and every other whitespace-only value.
            text = self.PLACEHOLDER_TEXT
        encoded = self.tokenizer.encode_plus(
            text,
            None,
            add_special_tokens=True,
            max_length=self.max_length,
            padding="max_length",
            return_token_type_ids=True,
            truncation=True,
        )
        return tuple(
            np.array(encoded[key]).astype(np.int_)
            for key in ("input_ids", "attention_mask", "token_type_ids")
        )

    def __getitem__(self, idx):
        """Return the encoded heading, encoded body and label of row ``idx``."""
        hi, hm, ht = self._encode(self.df["heading"].iloc[idx])
        bi, bm, bt = self._encode(self.df["body"].iloc[idx])
        label = self.df["label"].iloc[idx]
        return hi, hm, ht, bi, bm, bt, label

    def __len__(self):
        """Number of rows left after dropping incomplete ones."""
        return len(self.df)

In [None]:
class Attention(nn.Module):
    """Attention pooling over the step axis of a (batch, step_dim, feature_dim) input.

    Every step is scored with a learned projection (plus an optional per-step
    offset), the scores are passed through tanh and exp, optionally masked,
    normalised across steps, and used to take a weighted sum of the step vectors.

    Args:
        feature_dim: size of the last axis of the input.
        step_dim: number of steps; inputs must have exactly this many.
        bias: when true, a learnable per-step offset ``b`` is added to the scores.
    """

    def __init__(self, feature_dim, step_dim, bias=True, **kwargs):
        super().__init__(**kwargs)

        self.supports_masking = True
        self.bias = bias
        self.feature_dim = feature_dim
        self.step_dim = step_dim
        self.features_dim = 0

        # Projection from a feature vector to a single score, Kaiming-uniform initialised.
        self.weight = nn.Parameter(
            nn.init.kaiming_uniform_(torch.zeros(feature_dim, 1))
        )
        if bias:
            self.b = nn.Parameter(torch.zeros(step_dim))

    def forward(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` over its step axis.

        ``mask``, when given, is multiplied into the unnormalised weights, so
        steps whose mask value is 0 contribute nothing.
        """
        # Score all steps of all samples in one matrix product, then restore (batch, step).
        flat = x.contiguous().view(-1, self.feature_dim)
        scores = torch.mm(flat, self.weight).view(-1, self.step_dim)
        if self.bias:
            scores = scores + self.b

        weights = torch.tanh(scores).exp()
        if mask is not None:
            weights = weights * mask

        # The epsilon keeps the division finite when every step is masked out.
        weights = weights / (weights.sum(dim=1, keepdim=True) + 1e-10)
        return (x * weights.unsqueeze(-1)).sum(dim=1)

In [None]:
class Model(nn.Module):
    """Heading/body similarity classifier on top of a shared BERT encoder.

    Heading and body are encoded by the same BERT instance and each pooled to
    one 768-dim vector by its own attention layer. The forward pass returns
    three sets of 2-class logits: local (from chunk-wise similarities), global
    (from the whole-vector similarity) and the final prediction.
    """

    def __init__(self):
        super().__init__()
        self.bert = AutoModel.from_pretrained("bert-base-uncased")
        # Separate pooling heads: 768 features over 512 token positions.
        self.h_atten = Attention(768, 512)
        self.b_atten = Attention(768, 512)
        # Pools the 4 chunk similarities (one feature each) into a single value.
        self.sim_atten = Attention(1, 4)
        self.cos = torch.nn.CosineSimilarity(dim=1, eps=1e-08)
        # 1540 = 768 (heading) + 4 (chunk similarities) + 768 (body).
        self.pred = nn.Linear(1540, 2)
        self.local = nn.Linear(1, 2)

    def _encode_and_pool(self, ids, mask, type_ids, pooler):
        """Run BERT on one text field and attention-pool its token states."""
        token_states = self.bert(ids, attention_mask=mask, token_type_ids=type_ids)[0]
        return pooler(token_states, mask)

    def forward(self, hi, hm, ht, bi, bm, bt):
        """Return ``(local_logits, global_logits, prediction_logits)``.

        ``hi/hm/ht`` and ``bi/bm/bt`` are the input ids, attention mask and
        token type ids of the heading and of the body respectively.
        """
        # Headline and body representations.
        h_att = self._encode_and_pool(hi, hm, ht, self.h_atten)
        b_att = self._encode_and_pool(bi, bm, bt, self.b_atten)

        # Global prediction: whole-vector cosine similarity and its complement.
        glob_sim = self.cos(h_att, b_att).unsqueeze(-1)
        global_logits = torch.cat((glob_sim, 1 - glob_sim), 1)

        # Similarity vector: cosine similarity of each of 4 aligned feature chunks.
        chunk_pairs = zip(h_att.chunk(4, dim=1), b_att.chunk(4, dim=1))
        sim_vector = torch.stack([self.cos(hc, bc) for hc, bc in chunk_pairs], dim=1)

        # Local prediction: attention-pool the chunk similarities, then classify.
        local_logits = self.local(self.sim_atten(sim_vector.unsqueeze(-1)))

        # Final prediction from both pooled vectors plus the similarity vector.
        final_vector = torch.cat((h_att, sim_vector, b_att), dim=1)
        prediction_logits = self.pred(final_vector)
        return local_logits, global_logits, prediction_logits

In [None]:
def train_func_epoch(epoch, model, dataloader, device, optimizer):
    """Train ``model`` for one pass over ``dataloader``.

    Each batch is the 7-tuple produced by the dataset (heading ids/mask/type
    ids, body ids/mask/type ids, labels). The loss is the sum of the
    cross-entropies of the model's three logit outputs against the labels.

    Args:
        epoch: value shown in the progress-bar description.
        model: module returning ``(local, global, prediction)`` logits.
        dataloader: iterable of batches; must support ``len``.
        device: device every batch tensor is moved to.
        optimizer: optimizer stepped once per batch.

    Returns:
        The mean per-batch loss over the epoch.
    """
    model.train()
    running_loss = 0
    with tqdm(dataloader, unit="batch", total=len(dataloader)) as progress:
        for batch_idx, batch in enumerate(progress):
            progress.set_description(f"Training- Epoch {epoch}")

            hi, hm, ht, bi, bm, bt, labels = [tensor.to(device) for tensor in batch]

            model.zero_grad()
            local_logits, global_logits, pred_logits = model(hi, hm, ht, bi, bm, bt)
            loss = (
                F.cross_entropy(local_logits, labels)
                + F.cross_entropy(global_logits, labels)
                + F.cross_entropy(pred_logits, labels)
            )
            running_loss += loss.item()

            loss.backward()
            optimizer.step()
            model.zero_grad()

            # Show the running mean loss on the progress bar.
            progress.set_postfix(train_loss=running_loss / (batch_idx + 1))
    return running_loss / len(dataloader)

In [None]:
def eval_func_epoch(model, dataloader, device, epoch):
    """Evaluate ``model`` on ``dataloader`` without updating it.

    The loss is the sum of the cross-entropies of the model's three logit
    outputs; predictions are the argmax of the final prediction logits.

    Args:
        model: module returning ``(local, global, prediction)`` logits.
        dataloader: iterable of 7-tuples as produced by the dataset; must
            support ``len`` and be non-empty.
        device: device every batch tensor is moved to.
        epoch: value shown in the progress bar. When it equals the string
            "TESTING", a confusion-matrix plot is also saved to
            "confusion.png".

    Returns:
        ``(mean_loss, report, tn, fp, fn, tp)`` where ``report`` is the
        ``classification_report`` dict for labels 0 and 1.
    """
    model.eval()
    total_loss = 0
    targets = []
    predictions = []
    with tqdm(dataloader, unit="batch", total=len(dataloader)) as single_epoch:
        for step, batch in enumerate(single_epoch):
            single_epoch.set_description(f"Evaluating- Epoch {epoch}")
            hi, hm, ht, bi, bm, bt, l = (tensor.to(device) for tensor in batch)
            # No gradients are needed (or produced) during evaluation.
            with torch.no_grad():
                ll, gl, pl = model(hi, hm, ht, bi, bm, bt)
                loss = (
                    F.cross_entropy(ll, l)
                    + F.cross_entropy(gl, l)
                    + F.cross_entropy(pl, l)
                )
                total_loss += loss.item()
            # Labelled as an evaluation loss; it was previously shown as "train_loss".
            single_epoch.set_postfix(eval_loss=total_loss / (step + 1))
            predictions.append(torch.argmax(pl, dim=1).flatten().cpu().numpy())
            targets.append(l.cpu().numpy())
    targets = np.concatenate(targets, axis=0)
    predictions = np.concatenate(predictions, axis=0)
    epoch_validation_loss = total_loss / len(dataloader)
    report = classification_report(targets, predictions, output_dict=True, labels=[0, 1])
    # Fix the label set so the matrix is always 2x2; without it a split in which
    # only one class occurs yields a 1x1 matrix and the 4-way unpack fails.
    tn, fp, fn, tp = confusion_matrix(targets, predictions, labels=[0, 1]).ravel()
    if epoch == "TESTING":
        ConfusionMatrixDisplay.from_predictions(targets, predictions, labels=[0, 1])
        plt.savefig("confusion.png", dpi=300)
    return epoch_validation_loss, report, tn, fp, fn, tp

In [None]:
# Build the model and, when `load` is set, restore the weights saved at `model_path`.
model = Model()
model_path="models/model.pt"
if load:
    loaded_state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(loaded_state_dict)
# Move the model to its device BEFORE building the optimizer, so the optimizer is
# constructed over the parameters in the location where they will be trained.
model.to(device)
opt = optim.Adam(model.parameters(), lr=1e-5)
# Training hyper-parameters used by the data-loader and training-loop cells.
batch_size=4
epochs=10

In [None]:
# Load the three splits. The test split was previously read from train.csv, which
# made the final "test" metrics a measurement on the training data.
# NOTE(review): assumes test.csv lives next to train.csv / val.csv — confirm the path.
train_df = pd.read_csv("../../data/clean/train.csv")#.head(500)
test_df = pd.read_csv("../../data/clean/test.csv")#.head(500)
val_df = pd.read_csv("../../data/clean/val.csv")#.head(500)
train_data = ClickBaitDataSet(train_df,tokenizer)
test_data = ClickBaitDataSet(test_df,tokenizer)
val_data = ClickBaitDataSet(val_df,tokenizer)
# Shuffle only the training data so batches are not tied to the CSV row order;
# evaluation loaders keep a fixed order.
train_data_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DataLoader(val_data, batch_size=batch_size)
test_data_loader = DataLoader(test_data, batch_size=batch_size)

In [None]:
# Train for `epochs` epochs, validating after each one and checkpointing the
# weights whenever the validation loss reaches a new minimum.
best_loss, best_epoch = np.inf, 0
for epoch in range(epochs):
    print(f"\n---------------------- Epoch: {epoch+1} ---------------------------------- \n")

    # One optimisation pass, then one evaluation pass on the validation split.
    train_loss = train_func_epoch(epoch+1, model, train_data_loader, device, opt)
    val_loss, report, tn, fp, fn, tp = eval_func_epoch(
        model, val_data_loader, device, epoch+1
    )

    print(f"\nEpoch: {epoch+1} | Training loss: {train_loss} | Validation Loss: {val_loss}")
    # The report is framed by one blank line on each side.
    print("", report, "", sep="\n")
    print(f"TP: {tp} | FP: {fp} | TN: {tn}, FN: {fn} ")
    print(f"\n----------------------------------------------------------------------------")

    # Keep only the best checkpoint seen so far (epoch numbers are 1-based).
    if val_loss < best_loss:
        torch.save(model.state_dict(), model_path)
        best_loss, best_epoch = val_loss, epoch+1

In [None]:
# Restore the best checkpoint written by the training loop and score it on the test loader.
loaded_state_dict = torch.load(model_path,  map_location=device)
model.load_state_dict(loaded_state_dict)
print(f"\n---------------------- Testing best model (at epoch: {best_epoch} )---------------------------------- \n")

# Passing "TESTING" as the epoch also makes the evaluation save the confusion-matrix plot.
test_loss, report, tn, fp, fn, tp = eval_func_epoch(
    model, test_data_loader, device, "TESTING"
)
print(f"\nTest loss: {test_loss}")
# The report is framed by one blank line on each side.
print("", report, "", sep="\n")
print(f"TP: {tp} | FP: {fp} | TN: {tn}, FN: {fn} ")

# Persist the classification report next to the notebook.
with open("./report.json","w") as f:
    f.write(json.dumps(report, indent=4))