In [None]:
!pip install transformers pytorch-lightning

In [None]:
# Colab-only: opens a browser file picker and uploads the selected files to the runtime.
# Presumably used to provide the CSVs read from /content later in this notebook — confirm.
from google.colab import files
files.upload()

In [None]:
import random
import warnings
from typing import List

import numpy as np
import pandas as pd
warnings.filterwarnings('ignore')

In [None]:
import torch
import torch.nn as nn 
from torch.utils.data import DataLoader, Dataset 

In [None]:
import transformers
import pytorch_lightning as pl

In [None]:
# Configs

SEED=42
EPOCHS=3
# Fall back to CPU so the notebook still runs on a runtime without a GPU.
DEVICE='cuda' if torch.cuda.is_available() else 'cpu'

# Seed Python, NumPy and torch (torch.manual_seed seeds every device, CUDA included) so the
# DataLoader shuffling, dropout masks and classifier-head initialisation are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

TRAIN_BATCH_SIZE=2
VALID_BATCH_SIZE=2
TEST_BATCH_SIZE=1

# NOTE(review): ACCUMULATION_STEPS is not referenced by the training loop in this notebook
# — TODO wire up gradient accumulation or remove it.
ACCUMULATION_STEPS=4
# Sequence length handed to the tokenizer as `max_length` in BERTDataset.
MAX_LEN=128

TRAINING_DATASET='/content/train.csv'
TESTING_DATASET='/content/evaluation.csv'
# NOTE(review): not used by any cell in this notebook — TODO save checkpoints here or remove.
CHECKPOINT='/content/checkpoints'

TOKENIZER=transformers.BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=True)
BERTMODEL=transformers.BertModel.from_pretrained('bert-base-uncased')

In [None]:
# Bert Dataset Approach 1 

class BERTDataset(Dataset):
    """Map-style dataset that tokenizes (text, reason) sentence pairs for BERT.

    Each item is encoded as a sentence pair (``text`` first, ``reason`` second). It is padded
    AND truncated to exactly ``max_len`` tokens, so the default DataLoader collate can stack
    items into a batch.

    Args:
        texts: First segment of each pair.
        reasons: Second segment of each pair.
        targets: Binary labels (anything ``int()`` accepts, e.g. 0/1).
        tokenizer: Optional tokenizer exposing ``encode_plus``; defaults to the notebook-level
            ``TOKENIZER``.
        max_len: Optional sequence length; defaults to the notebook-level ``MAX_LEN``.

    Raises:
        ValueError: if ``texts``, ``reasons`` and ``targets`` differ in length.
    """

    def __init__(self, texts: List[str], reasons: List[str], targets: List[int],
                 tokenizer=None, max_len=None):
        if not (len(texts) == len(reasons) == len(targets)):
            raise ValueError(
                f"texts, reasons and targets must have equal length, got "
                f"{len(texts)}, {len(reasons)}, {len(targets)}"
            )
        self.texts = texts
        self.reasons = reasons
        self.targets = targets
        self.tokenizer = TOKENIZER if tokenizer is None else tokenizer
        self.max_len = MAX_LEN if max_len is None else max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, item: int):
        """Return the tokenized pair at ``item`` as a dict of tensors.

        Keys (in this order): ``ids``, ``mask``, ``token_type_ids`` (long tensors of length
        ``max_len``) and ``targets`` (float32 scalar, as expected by BCEWithLogitsLoss).
        """
        # Collapse every run of whitespace (spaces, tabs, newlines) into a single space.
        text = " ".join(str(self.texts[item]).split())
        reason = " ".join(str(self.reasons[item]).split())

        inputs = self.tokenizer.encode_plus(
            text,
            reason,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            # Without truncation, pairs longer than max_len are NOT cut down, so items end up
            # with different lengths and the default collate fails to stack them.
            truncation=True,
        )

        return {
            'ids': torch.tensor(inputs['input_ids'], dtype=torch.long),
            'mask': torch.tensor(inputs['attention_mask'], dtype=torch.long),
            'token_type_ids': torch.tensor(inputs['token_type_ids'], dtype=torch.long),
            'targets': torch.tensor(int(self.targets[item]), dtype=torch.float32),
        }

In [None]:
# utility functions 

def train_valid_split(df, test_split=0.2, shuffle=False, random_state=None):
    """Split ``df`` into a training frame and a validation frame.

    By default the split is positional. The first ``int(len(df) * (1 - test_split))`` rows
    become the training set and the rest the validation set, so row order in ``df`` matters.
    Pass ``shuffle=True`` to permute the rows first; this is advisable when the file may be
    sorted, e.g. by label.

    Args:
        df: Input DataFrame.
        test_split: Fraction of rows for validation, in ``[0, 1]``.
        shuffle: Shuffle rows before splitting.
        random_state: Seed forwarded to ``DataFrame.sample`` when ``shuffle`` is True.

    Returns:
        ``(train_data, valid_data)`` as independent copies (original index labels kept).

    Raises:
        ValueError: if ``test_split`` is outside ``[0, 1]``.
    """
    if not 0.0 <= test_split <= 1.0:
        raise ValueError(f"test_split must be in [0, 1], got {test_split!r}")
    if shuffle:
        df = df.sample(frac=1.0, random_state=random_state)
    train_length = int(len(df) * (1 - test_split))
    # .copy() gives each split its own data, so later edits to one never touch `df`.
    train_data = df.iloc[:train_length].copy()
    valid_data = df.iloc[train_length:].copy()
    return (train_data, valid_data)

def get_text_reason_target(df):
    """Extract the model inputs from ``df`` as plain Python lists.

    Returns:
        A list ``[texts, reasons, labels]`` built from the ``'text'``, ``'reason'`` and
        ``'label'`` columns (in that order) via ``Series.tolist()``.
    """
    wanted_columns = ('text', 'reason', 'label')
    return [df[column].tolist() for column in wanted_columns]

def bce_loss(outputs, targets):
    """Mean binary cross-entropy computed directly on raw logits.

    The sigmoid is folded into the loss (numerically stable form), so ``outputs`` must be
    unnormalised logits, not probabilities; ``targets`` are float labels of the same shape.
    """
    return nn.functional.binary_cross_entropy_with_logits(outputs, targets, reduction='mean')


def get_train_valid_test_dataset(train_df, valid_df, test_df, loaders=True):
    """Build a BERTDataset per split and, optionally, wrap each in a DataLoader.

    Args:
        train_df, valid_df, test_df: Frames with 'text', 'reason' and 'label' columns.
        loaders: When True, return DataLoaders (only the training loader shuffles);
            when False, return the bare datasets.

    Returns:
        A 3-element list ``[train, valid, test]``.
    """
    split_datasets = []
    for frame in (train_df, valid_df, test_df):
        texts, reasons, labels = get_text_reason_target(frame)
        split_datasets.append(BERTDataset(texts=texts, reasons=reasons, targets=labels))

    if not loaders:
        return split_datasets

    train_ds, valid_ds, test_ds = split_datasets
    return [
        DataLoader(train_ds, batch_size=TRAIN_BATCH_SIZE, shuffle=True),
        DataLoader(valid_ds, batch_size=VALID_BATCH_SIZE),
        DataLoader(test_ds, batch_size=TEST_BATCH_SIZE),
    ]

In [None]:
dataset = pd.read_csv(TRAINING_DATASET)
test_dataset = pd.read_csv(TESTING_DATASET)

# Fail fast with a readable message instead of a KeyError (missing column) in
# get_text_reason_target, or a ValueError from int(NaN) inside BERTDataset.
REQUIRED_COLUMNS = {'text', 'reason', 'label'}
for split_name, split_frame in (('training', dataset), ('testing', test_dataset)):
    missing_columns = REQUIRED_COLUMNS - set(split_frame.columns)
    assert not missing_columns, f"{split_name} CSV is missing columns: {sorted(missing_columns)}"
    assert split_frame['label'].notna().all(), f"{split_name} CSV has missing labels"

train_dataset, valid_dataset = train_valid_split(dataset)
assert len(train_dataset) + len(valid_dataset) == len(dataset)

In [None]:
# Turn the three frames into DataLoaders (the helper shuffles only the training split).
train_loader, valid_loader, test_loader = get_train_valid_test_dataset(
    train_df=train_dataset,
    valid_df=valid_dataset,
    test_df=test_dataset,
    loaders=True,
)

In [None]:
class BertBaseUncasedSingleSentence(nn.Module):
    """BERT encoder + dropout + single-logit linear head for sentence-pair classification.

    The encoder's ``pooler_output`` goes through dropout and ``Linear(hidden_size, 1)``,
    producing one raw (unbounded) logit per example. Apply ``torch.sigmoid`` for a
    probability, or feed the logit straight to ``BCEWithLogitsLoss``.
    """

    def __init__(self, dropout_prob=0.2, bert_model=None):
        """
        Args:
            dropout_prob: Dropout probability applied to the pooled output.
            bert_model: Optional encoder to wrap. Defaults to the notebook-level ``BERTMODEL``
                moved to ``DEVICE``. It must accept ``input_ids``/``attention_mask``/
                ``token_type_ids`` keywords, return an object with ``pooler_output``, and
                expose ``config.hidden_size``.
        """
        super().__init__()
        self.bert = BERTMODEL.to(DEVICE) if bert_model is None else bert_model
        self.bert_drop = nn.Dropout(dropout_prob)
        # Take the width from the encoder config rather than hard-coding 768 (bert-base).
        self.classifier = nn.Linear(self.bert.config.hidden_size, 1)

    def forward(self, ids: torch.Tensor, mask: torch.Tensor, token_type_ids: torch.Tensor) -> torch.Tensor:
        """Score a batch of tokenized sentence pairs.

        The positional order ``(ids, mask, token_type_ids)`` matches
        ``BertModel.forward(input_ids, attention_mask, token_type_ids)``. The second positional
        argument is used as the attention mask, so existing positional callers such as
        ``model(ids, mask, token_type_ids)`` are unaffected. Keyword callers get the tensor
        they name.

        Args:
            ids (torch.Tensor): Token ids.
            mask (torch.Tensor): Attention mask (1 = real token, 0 = padding).
            token_type_ids (torch.Tensor): Segment ids (0 = first sentence, 1 = second).

        Returns:
            torch.Tensor: Raw logits of shape ``(batch, 1)`` (not probabilities).
        """
        # Move inputs to wherever the model's weights live instead of a fixed global device.
        device = self.classifier.weight.device
        output = self.bert(
            input_ids=ids.to(device),
            attention_mask=mask.to(device),
            token_type_ids=token_type_ids.to(device),
        )
        embeddings = self.bert_drop(output.pooler_output)
        return self.classifier(embeddings)

In [None]:
def compute_accuracy(logits, targets, threshold=0.5, conversion=True):
    """Fraction of examples whose thresholded score equals the binary target.

    Despite the parameter name, the training loop passes probabilities
    (``torch.sigmoid(logits)``). Scores below ``threshold`` are predicted as 0, all others as 1.

    Args:
        logits: Per-example scores. A torch tensor when ``conversion`` is True, otherwise
            anything ``np.asarray`` accepts. Flattened before comparison.
        targets: Binary labels with the same number of elements.
        threshold: Decision threshold.
        conversion: If True, detach both inputs and copy them from torch to NumPy first.

    Returns:
        Accuracy in ``[0, 1]``; ``0.0`` for empty input (instead of NaN from 0/0).
    """
    if conversion:
        logits = logits.detach().cpu().numpy()
        targets = targets.detach().cpu().numpy()

    scores = np.asarray(logits).reshape(-1)
    labels = np.asarray(targets).reshape(-1)
    if scores.size == 0:
        return 0.0
    # `score < threshold -> 0 else 1`, vectorised; NaN scores (NaN < t is False) map to 1.
    predictions = np.where(scores < threshold, 0, 1)
    return np.sum(predictions == labels) / scores.size


def free_gpu(ids, mask, token_type_ids, targets, logits):
    """Return PyTorch's cached-but-unused CUDA memory to the driver (no-op without CUDA).

    The tensor parameters are kept only for call compatibility; they are deliberately unused.
    ``del`` on a parameter merely unbinds this function's local name. The caller still holds
    its own references, so deleting them here would not free anything. Tensors are released
    only when the caller drops its references.

    ``torch.cuda.empty_cache()`` does not increase the memory available to PyTorch itself,
    and calling it every batch forces re-allocation, so call this sparingly.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Training Batch loop

def _batch_to_device(batch):
    """Return ``(ids, mask, token_type_ids, targets)`` from a BERTDataset batch, moved to DEVICE."""
    # Look tensors up by key instead of relying on the order of batch.values().
    return tuple(batch[key].to(DEVICE) for key in ('ids', 'mask', 'token_type_ids', 'targets'))


def train_batch(model, dataloader, optimizer, show_after_every_batch=100):
    """Run one training epoch over ``dataloader``.

    Switches ``model`` to train mode first, so dropout is active even if an earlier
    ``validate_batch`` call left it in eval mode. Performs one optimizer step per batch and
    prints running means every ``show_after_every_batch`` batches.

    Returns:
        ``(model, mean_loss, mean_accuracy)``, with means taken over batches
        (``0.0`` each for an empty loader).
    """
    model.train()
    running_loss = 0.0
    running_acc = 0.0

    # No per-batch torch.cuda.empty_cache(): it does not lower PyTorch's own memory use and
    # slows every step by forcing re-allocation.
    for batch_idx, batch in enumerate(dataloader):
        ids, mask, token_type_ids, targets = _batch_to_device(batch)

        optimizer.zero_grad()
        # The model's 2nd positional argument is used as the attention mask.
        logits = model(ids, mask, token_type_ids).view(-1)
        loss = bce_loss(logits, targets)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_acc += compute_accuracy(torch.sigmoid(logits), targets)

        seen = batch_idx + 1  # batches processed so far (batch_idx is 0-based)
        if seen % show_after_every_batch == 0:
            print(f'TRAIN: After completing batch {seen} | batch_loss : {(running_loss / seen):.3f} | batch_accuracy: {(running_acc / seen):.3f}')

    n_batches = len(dataloader)
    if n_batches == 0:
        return model, 0.0, 0.0
    return model, running_loss / n_batches, running_acc / n_batches


# Validation Batch loop

def validate_batch(model, dataloader, show_after_every_batch=10):
    """Evaluate ``model`` on ``dataloader`` without gradient tracking.

    Runs in eval mode (dropout off) and restores the model's previous train/eval mode
    afterwards, even if an exception is raised.

    Returns:
        ``(mean_loss, mean_accuracy)`` over batches (``0.0`` each for an empty loader).
    """
    was_training = model.training
    model.eval()
    running_loss = 0.0
    running_acc = 0.0

    try:
        with torch.no_grad():
            for batch_idx, batch in enumerate(dataloader):
                ids, mask, token_type_ids, targets = _batch_to_device(batch)

                logits = model(ids, mask, token_type_ids).view(-1)
                running_loss += bce_loss(logits, targets).item()
                running_acc += compute_accuracy(torch.sigmoid(logits), targets)

                seen = batch_idx + 1
                if seen % show_after_every_batch == 0:
                    print(
                        f'EVAL: After completing batch {seen} | batch_loss : {(running_loss / seen):.3f} | batch_accuracy: {(running_acc / seen):.3f}'
                    )
    finally:
        model.train(was_training)

    n_batches = len(dataloader)
    if n_batches == 0:
        return 0.0, 0.0
    return running_loss / n_batches, running_acc / n_batches

In [None]:
# Build the classifier with its default dropout, then move all of its parameters to DEVICE
# (nn.Module.to moves in place and returns the module itself).
model = BertBaseUncasedSingleSentence()
model = model.to(DEVICE)

In [None]:
# lr=0.01 is orders of magnitude above the learning rates used for fine-tuning all BERT
# weights (the BERT paper uses 2e-5 to 5e-5); at 0.01 training is likely to be unstable.
# AdamW (decoupled weight decay) is the usual optimizer for transformer fine-tuning.
LEARNING_RATE = 2e-5
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)

In [None]:
for epoch in range(EPOCHS):
    print("\nTRAINING\n")
    # validate_batch calls model.eval(); make sure dropout is active again for training.
    model.train()
    model, train_batch_loss, train_batch_acc = train_batch(model, train_loader, optimizer)
    print("\nVALIDATION\n")
    valid_batch_loss, valid_batch_acc = validate_batch(model, valid_loader, 100)

    print("\n=========== Epochs Results ===========")
    print(f"TRAIN: Loss : {train_batch_loss:.3f}, Acc: {train_batch_acc:.3f}")
    print(f"VALID: Loss : {valid_batch_loss:.3f}, Acc: {valid_batch_acc:.3f}")
    print("=======================================\n")

# Score the held-out evaluation set once, after training. Evaluating it every epoch invites
# tuning against test results and costs a full pass over the test loader per epoch.
print("\nEVALUATION\n")
test_batch_loss, test_batch_acc = validate_batch(model, test_loader, 1000)
print(f"TEST: Loss : {test_batch_loss:.3f}, Acc: {test_batch_acc:.3f}")