# Importing Libraries

In [None]:
from datasets import load_dataset
import numpy as np
import torch
from transformers import BertTokenizer, BertModel,BertForSequenceClassification
from tqdm.notebook import tqdm
from torch import nn
import torch.nn.functional as F
import transformers
import random
from sklearn.metrics import f1_score 
from sklearn.metrics import accuracy_score
import gc
from sklearn.metrics import roc_auc_score
from sklearn.metrics import f1_score
import torchmetrics
from sklearn.metrics import auc
from sklearn.metrics import precision_recall_curve, average_precision_score
from sklearn.metrics import auc

# Downloading Tokenizer and Model

In [None]:
# Use the default CUDA device when torch reports one is available, else the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Module-level WordPiece tokenizer shared by helper functions in this notebook.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# NOTE(review): this global `model` appears to be rebound (to model_mrp.bert)
# before it is ever read, so this pretrained load may be redundant.
model = BertModel.from_pretrained("bert-base-uncased")

# Loading Dataset

In [None]:
# Load the HateXplain dataset with `datasets.load_dataset`; later cells index it
# by the "train", "validation" and "test" splits.
dataset = load_dataset("hatexplain")

# Pre-processing

In [None]:
# Create Dataset Class
class HateXplainDataset(torch.utils.data.Dataset):
    """HateXplain posts paired with majority-vote token rationales and labels.

    ``data`` must expose three equally long columns:
      * ``"annotators"``: each entry has a ``"label"`` list (one label per annotator);
      * ``"rationales"``: each entry is a list of per-token 0/1 lists, one per
        annotator that supplied a rationale (possibly empty);
      * ``"post_tokens"``: the pre-split words of each post.

    After construction ``self.data`` is a dict with keys ``"post_tokens"``,
    ``"rationales"`` (per-token rounded mean vote as the strings ``'0'``/``'1'``)
    and ``"labels"`` (most frequent annotator label per post).

    Raises:
        AttributeError: if the three columns differ in length.
    """

    def __init__(self, data):
        n_annotated = len(data["annotators"])
        if n_annotated != len(data["rationales"]) or n_annotated != len(data["post_tokens"]):
            raise AttributeError("Incorrect length in data_dict")

        posts = data['post_tokens']

        token_votes = []
        for post_idx, per_annotator in enumerate(data['rationales']):
            as_arrays = [np.array(marks) for marks in per_annotator]
            # Round the mean vote at each token position (round() ties go to even).
            votes = [str(int(round(np.mean(column)))) for column in zip(*as_arrays)]
            if not votes:
                # No rationale was supplied for this post: mark every token '0'.
                votes = ['0'] * len(posts[post_idx])
            token_votes.append(votes)

        majority_labels = []
        for annotation in data["annotators"]:
            annot_labels = annotation["label"]
            # Most frequent label; ties resolved by set iteration order.
            majority_labels.append(max(set(annot_labels), key=annot_labels.count))

        self.data = {
            "post_tokens": data["post_tokens"],
            "rationales": token_votes,
            "labels": majority_labels,
        }

    def __len__(self):
        """Number of posts."""
        return len(self.data["post_tokens"])

    def __getitem__(self, idx):
        """Return ``(post_tokens, rationales, label)`` for post ``idx``."""
        record = self.data
        return record["post_tokens"][idx], record["rationales"][idx], record["labels"][idx]
    

In [None]:
# Create Loader builder
def build_loader(data_dict: dict, batch_size: int = 64, shuffle: bool = False):
    """Return a zero-argument generator factory that yields mini-batches.

    Each call to the returned ``loader()`` walks the zipped
    ``(post_tokens, rationales, labels)`` records in chunks of ``batch_size``.
    For each chunk it yields a ``(tokens, rationales, labels)`` tuple of tuples;
    the last chunk may be smaller. With ``shuffle=True``, the shared record list
    is shuffled in place with the global ``random`` module before every pass.
    """
    records = list(zip(data_dict["post_tokens"], data_dict["rationales"], data_dict["labels"]))

    def loader():
        if shuffle:
            random.shuffle(records)
        for start in range(0, len(records), batch_size):
            tokens, rationales, labels = zip(*records[start:start + batch_size])
            yield tokens, rationales, labels

    return loader

In [None]:
# The tokenizer may split one word into several word pieces; this function
# repeats each word's rationale label once per piece so the labels line up
# with the tokenized sequence.
def get_token_rationales(token_ls: "list[list[str]]", rationale_ls: "list[list[int]]"):
    """Expand word-level rationale labels to word-piece level.

    Uses the module-level ``tokenizer``. The piece count of a word is the
    length of its ``input_ids`` after dropping the first and last id
    (presumably the [CLS]/[SEP] specials added to every encoded string).
    The elements of ``rationale_ls`` are copied through unchanged; in this
    notebook they are the '0'/'1' strings built by HateXplainDataset.

    Returns:
        (expanded, sep_positions): ``expanded[i]`` is the piece-level label list
        of sentence ``i``. ``sep_positions[i]`` is ``len(expanded[i]) + 1``,
        i.e. the index right after the last piece once a leading special token
        occupies index 0 (presumably where [SEP] sits).
    """
    expanded = []
    sep_positions = []
    for sent_idx, words in enumerate(token_ls):
        piece_labels = []
        for word_idx, word in enumerate(words):
            n_pieces = len(tokenizer(word)["input_ids"][1:-1])
            piece_labels.extend([rationale_ls[sent_idx][word_idx]] * n_pieces)
        expanded.append(piece_labels)
        sep_positions.append(len(piece_labels) + 1)
    return expanded, sep_positions

In [None]:
# Builds the MRP encoder input: token embeddings plus a partially masked copy
# of the rationale embeddings. p is the fraction of real tokens whose
# rationale embedding is zeroed; the 0.15 default mirrors BERT's MLM rate.
# Output
# - comb: seq + masked copy of rat
# - masked_indices: per-sentence tensor of the positions whose rationale was zeroed
def mask_emb(seq, rat, sep_id, p=0.15):
    """Zero out part of the rationale embeddings and add them to ``seq``.

    Args:
        seq: token embeddings, indexed as (batch, position, feature).
        rat: rationale embeddings with the same shape as ``seq``; not modified.
        sep_id: per-row index of the first position after the real tokens.
            Positions ``1..sep_id-1`` are the real tokens.
        p: fraction of real tokens (rounded with ``round``) to mask per row.

    Returns:
        ``(seq + masked_rat, masked_indices)``. ``masked_rat`` is zero at
        position 0, at every position ``>= sep_id[row]``, and at the randomly
        chosen positions listed in ``masked_indices[row]`` (drawn with
        ``torch.randperm``).
    """
    masked_rat = rat.detach().clone()
    # Never add a rationale signal at position 0 (the leading special token).
    masked_rat[:, 0, :] = 0
    masked_indices = []
    for row in range(len(seq)):
        sep = sep_id[row]
        n_real = sep - 1
        # Drop rationale embeddings from the separator position onwards.
        masked_rat[row, sep:] = 0
        n_masked = round(p * n_real)
        # randperm yields 0..n_real-1; shift by one to land on positions 1..n_real.
        chosen = torch.randperm(n_real)[:n_masked] + 1
        masked_rat[row, chosen] = 0
        masked_indices.append(chosen)
    return seq + masked_rat, masked_indices

# Pre-finetuning

### MRP Class

In [None]:
# This class creates a new BERT instance and puts an MLP head on top of it.
# The head predicts, for every position, the probability that the (masked)
# rationale label is 1.
class Bert_MRP(nn.Module):
    """BERT encoder plus a per-position sigmoid head for Masked Rationale Prediction.

    Args:
        n_classes: width of the final linear layer (one sigmoid output each).
            Defaults to 1, the single rationale probability used in training.
    """

    def __init__(self, n_classes=1):
        super().__init__()
        self.bert = BertModel.from_pretrained("bert-base-uncased", output_hidden_states=True, output_attentions=True)
        self.out = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(512, n_classes),
        )
        # Default, unweighted loss; training builds a weighted BCELoss per batch.
        self.criterion = nn.BCELoss()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    def forward(self, combined_encoding, attention_mask=None):
        """Run the BERT encoder on precomputed embeddings and apply the head.

        Args:
            combined_encoding: embedding-layer output fed straight to the
                encoder (assumed (batch, seq, hidden) — TODO confirm for new callers).
            attention_mask: one of three forms.
                * ``None``: no masking.
                * A 2-D 0/1 padding mask of shape (batch, seq): converted with
                  ``get_extended_attention_mask`` into the additive, broadcastable
                  form that ``BertModel.forward`` normally builds before it calls
                  the encoder.
                * An already-expanded mask: passed through unchanged.

        Returns:
            Sigmoid probabilities with ``n_classes`` as the last dimension.
        """
        if attention_mask is not None and attention_mask.dim() == 2:
            # The encoder is called directly here, so the mask expansion that
            # BertModel.forward performs has to be done by hand.
            attention_mask = self.bert.get_extended_attention_mask(attention_mask, attention_mask.shape)
        output = self.bert.encoder(combined_encoding, attention_mask=attention_mask)
        return torch.sigmoid(self.out(output.last_hidden_state))

    def get_criterion(self):
        """Return the default (unweighted) BCE loss for this head."""
        return self.criterion

    def assign_optimizer(self, **kwargs):
        """Create a RAdam optimizer over all parameters; ``kwargs`` go to RAdam."""
        return torch.optim.RAdam(self.parameters(), **kwargs)

    def tokenize(
        self,
        tok: "list[str]",
        max_length: int = 128,
        truncation: bool = True,
        padding: bool = True,
    ):
        """Tokenize a batch of pre-split word lists and move it to ``device``.

        NOTE(review): ``max_length``, ``truncation`` and ``padding`` are ignored.
        Sequences are padded to the longest in the batch and never truncated.
        Callers compute positions (e.g. separator indices) on untruncated text,
        so confirm before wiring these parameters up.
        """
        return self.tokenizer(tok, return_tensors='pt', is_split_into_words=True, padding='longest').to(device)

    def slice_cls_hidden_state(
        self, x: transformers.modeling_outputs.BaseModelOutput
    ) -> torch.Tensor:
        """Return the last hidden state at position 0 of every sequence in ``x``."""
        return x.last_hidden_state[:, 0]

### MRP Train Loop

In [None]:
# Training function for the intermediate MRP stage
def train_bertMRP(model, loader, device):
    """Run one epoch of Masked Rationale Prediction training.

    For each batch, the post and its word-piece-expanded rationale labels are
    tokenized separately. Their embedding-layer outputs are summed, with a
    random half (p=0.5) of the real rationale positions zeroed by ``mask_emb``.
    The head is then trained with a BCE loss weighted to those masked positions.
    The encoder attends only to non-padding positions.

    Uses the module-level ``optimizer``.

    Returns:
        Mean batch loss (0.0 for an empty loader).
    """
    # eval_bertRP() switches the model to eval mode; switch back so dropout is
    # active on every epoch, not only the first one.
    model.train()
    # Id of the rationale "word" '1'; every other id ('0', specials, padding) is target 0.
    one_id = model.tokenizer.convert_tokens_to_ids("1")
    total_loss = 0.0
    n_batches = 0
    for tokens, rationale, _target in tqdm(loader()):
        n_batches += 1
        optimizer.zero_grad()

        inputs = model.tokenize(tokens).to(device)
        rationale_in, sep_id = get_token_rationales(tokens, rationale)
        rationale_token = model.tokenize(rationale_in).to(device)

        with torch.no_grad():
            # Only the embedding-layer output (hidden_states[0] of a full forward
            # pass) is used, so skip running the whole encoder twice.
            input_emb = model.bert.embeddings(
                input_ids=inputs["input_ids"], token_type_ids=inputs["token_type_ids"]
            )
            rat_emb = model.bert.embeddings(
                input_ids=rationale_token["input_ids"], token_type_ids=rationale_token["token_type_ids"]
            )
            combined_encoding, masked_indices = mask_emb(input_emb, rat_emb, sep_id, p=0.5)
            # Padding mask in the expanded additive form the encoder expects.
            # (Previously layer-0 attention probabilities were passed here,
            # which does not mask padding.)
            attention = model.bert.get_extended_attention_mask(
                inputs["attention_mask"], inputs["attention_mask"].shape
            )

        pred = model(combined_encoding.to(device), attention_mask=attention)
        # Drop only the trailing size-1 head dimension; a bare squeeze() would
        # also drop the batch dimension of a single-example batch.
        pred = pred.squeeze(-1)

        rationale_l = (rationale_token["input_ids"] == one_id).float()

        # Only the masked positions contribute to the loss.
        weights = torch.zeros_like(rationale_l)
        for row, idx in enumerate(masked_indices):
            weights[row, idx.to(weights.device)] = 1
        loss = nn.BCELoss(weight=weights)(pred, rationale_l)

        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        del pred, rationale_in, inputs, combined_encoding
        gc.collect()
        torch.cuda.empty_cache()
    return total_loss / n_batches if n_batches else 0.0



# Evaluation function for the intermediate MRP stage
@torch.no_grad()
def eval_bertRP(model, loader, device):
    """Predict a rationale probability for every position, with no rationale hints.

    The encoder sees only the posts' embedding-layer output (every rationale
    position is effectively masked). Padding positions are masked out of
    attention, as in training.

    Returns:
        preds: list with one (seq_len, 1) probability tensor per post, on the
            CPU. Index 0 is the leading special token.
        targets: list with each post's word-piece-level gold labels as
            produced by ``get_token_rationales``.
        mean loss: batch mean of the model's default (unweighted) BCE over all
            positions (0.0 for an empty loader).
    """
    model.eval()
    criterion = model.get_criterion()
    # Id of the rationale "word" '1'; every other id is target 0.
    one_id = model.tokenizer.convert_tokens_to_ids("1")
    targets = []
    preds = []
    n_batches = 0
    total_loss = 0.0
    for tokens, rationale, _target in loader():
        n_batches += 1

        inputs = model.tokenize(tokens).to(device)
        rationale_in, _sep_id = get_token_rationales(tokens, rationale)
        rationale_token = model.tokenize(rationale_in).to(device)

        # hidden_states[0] of a full forward pass is just the embedding output.
        combined_encoding = model.bert.embeddings(
            input_ids=inputs["input_ids"], token_type_ids=inputs["token_type_ids"]
        )
        attention = model.bert.get_extended_attention_mask(
            inputs["attention_mask"], inputs["attention_mask"].shape
        )
        pred = model(combined_encoding.to(device), attention_mask=attention)

        # Keep the per-post tensors on the CPU instead of accumulating GPU memory.
        preds += pred.cpu()
        targets += rationale_in

        rationale_l = (rationale_token["input_ids"] == one_id).float()
        # squeeze(-1) keeps the batch dimension even for a single-example batch.
        loss = criterion(pred.squeeze(-1), rationale_l)
        total_loss += loss.item()
    return preds, targets, (total_loss / n_batches if n_batches else 0.0)



In [None]:
# Initializing loaders: shuffled training batches, fixed-order validation batches.
batch_size = 64
train_loader, valid_loader = (
    build_loader(HateXplainDataset(dataset[split]).data, batch_size=batch_size, shuffle=shuffle)
    for split, shuffle in (("train", True), ("validation", False))
)

In [None]:
# Initializing training parameters for the MRP stage.
n_epochs = 20
model_mrp = Bert_MRP().to(device)
model_mrp.train()
optimizer = model_mrp.assign_optimizer(lr=5e-5)
# NOTE(review): v_loss_threshhold is only ever reassigned in this notebook, never
# compared against, so no early stopping happens yet.
v_loss_threshhold, v_loss = 100, 99

In [None]:
# Training loop for MRP training stage
for epoch in range(n_epochs):
    # Previous epoch's validation loss (not yet used for early stopping).
    v_loss_threshhold = v_loss
    print("Epoch:", epoch)
    # eval_bertRP() leaves the model in eval mode; make sure dropout is back on.
    model_mrp.train()
    loss = train_bertMRP(model_mrp, train_loader, device=device)
    preds, targets, v_loss = eval_bertRP(model_mrp, valid_loader, device=device)

    print("Training loss:", loss)
    print("Validation loss:", v_loss)

    # Skip position 0 (leading special token) and keep one prediction per
    # word-piece target, dropping the trailing special/padding positions.
    preds = [x.tolist()[1:len(y) + 1] for x, y in zip(preds, targets)]

    pred_prob = [p[0] for nested in preds for p in nested]
    preds = [int(round(p)) for p in pred_prob]
    targets = [int(t) for sublist in targets for t in sublist]
    print("Avg precision", average_precision_score(targets, pred_prob))
    # precision_recall_curve takes (y_true, scores): the gold labels and the
    # predicted probabilities. The arguments were previously swapped and used
    # thresholded predictions.
    precision, recall, thresholds = precision_recall_curve(targets, pred_prob)
    auc_precision_recall = auc(recall, precision)

    print("AUCPR", auc_precision_recall)


# Fine-tuning

### Final Classifier Class

In [None]:
# This class wraps the BERT encoder produced by the previous intermediate stage
# and adds a new head that classifies a post into one of three labels.
# Output classes: "Normal", "Offensive", "Hate Speech"
# NOTE(review): the index -> class-name mapping comes from the dataset's label
# column; confirm it before reporting per-class results.
class BERT_HSD(nn.Module):
    """Three-way post classifier on top of a (pre-finetuned) BERT encoder.

    Args:
        bert: encoder called as ``bert(**inputs)``. Its output's
            ``last_hidden_state`` at position 0 feeds the classification head.
    """

    def __init__(self, bert):
        super().__init__()
        self.bert = bert
        self.out = nn.Sequential(
            nn.Linear(self.bert.config.hidden_size, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(p=0.2),
            nn.Linear(512, 3),
        )
        # forward() returns log-probabilities. CrossEntropyLoss applies its own
        # log-softmax, which leaves log-probabilities unchanged, so this pairing
        # is redundant but equivalent to cross-entropy on the raw logits.
        self.softmax = nn.LogSoftmax(dim=1)
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.criterion = nn.CrossEntropyLoss()

    def slice_cls_hidden_state(
        self, x: transformers.modeling_outputs.BaseModelOutput
    ) -> torch.Tensor:
        """Return the last hidden state at position 0 of every sequence in ``x``."""
        return x.last_hidden_state[:, 0]

    def get_criterion(self):
        """Return the cross-entropy loss used for fine-tuning."""
        return self.criterion

    def assign_optimizer(self, **kwargs):
        """Create a RAdam optimizer over all parameters; ``kwargs`` go to RAdam."""
        return torch.optim.RAdam(self.parameters(), **kwargs)

    def tokenize(
        self,
        tok: "list[str]",
        max_length: int = 128,
        truncation: bool = True,
        padding: bool = True,
    ):
        """Tokenize a batch of pre-split word lists with this model's tokenizer.

        The result stays on the CPU; callers move it with ``.to(device)``.
        NOTE(review): ``max_length``, ``truncation`` and ``padding`` are ignored.
        Sequences are padded to the longest in the batch and never truncated.
        """
        return self.tokenizer(tok, return_tensors='pt', is_split_into_words=True, padding='longest')

    def forward(self, inputs):
        """Return (batch, 3) log-probabilities for a tokenized batch."""
        # The encoder may have been loaded with attention/hidden-state outputs
        # enabled (the MRP stage does this); only last_hidden_state is needed
        # here, so don't materialize the rest.
        x = self.bert(**inputs, output_attentions=False, output_hidden_states=False)
        x = self.slice_cls_hidden_state(x)
        x = self.out(x)
        return self.softmax(x)

In [None]:
# Initialize Loaders: only the training split is shuffled.
batch_size = 64
train_loader, valid_loader, test_loader = (
    build_loader(HateXplainDataset(dataset[split]).data, batch_size=batch_size, shuffle=(split == "train"))
    for split in ("train", "validation", "test")
)

In [None]:
# Training function for the final fine-tuning stage
def train_berthsd(model, loader, device):
    """Run one epoch of classifier fine-tuning with the module-level ``optimizer``.

    Returns:
        Mean batch cross-entropy loss (0.0 for an empty loader).
    """
    # eval_berthsd() leaves the model in eval mode; re-enable dropout for training.
    model.train()
    criterion = model.get_criterion()
    total_loss = 0.0
    n_batches = 0
    for token, _rat, target in tqdm(loader()):
        n_batches += 1
        optimizer.zero_grad()

        inputs = model.tokenize(token).to(device)
        target = torch.tensor(target, dtype=torch.long, device=device)

        pred = model(inputs)
        loss = criterion(pred, target)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / n_batches if n_batches else 0.0

# Evaluation loop used for validation and test datasets
@torch.no_grad()
def eval_berthsd(model, loader, device):
    """Score every batch from ``loader`` with ``model`` in eval mode.

    Returns:
        A tuple of tensors concatenated over all batches:
        ``(predicted class ids, gold class ids, class probabilities)``. The
        probabilities are ``exp`` of the model's log-probability output.
    """
    model.eval()

    batch_preds, batch_targets, batch_logprobs = [], [], []
    for token, _rat, target in loader():
        log_probs = model(model.tokenize(token).to(device))
        _, predicted = torch.max(log_probs, dim=1)
        gold = torch.tensor(target, dtype=torch.long).to(device, dtype=torch.long)
        batch_preds.append(predicted)
        batch_targets.append(gold)
        batch_logprobs.append(log_probs)

    all_preds = torch.cat(batch_preds)
    all_targets = torch.cat(batch_targets)
    return all_preds, all_targets, torch.exp(torch.cat(batch_logprobs))


In [None]:
# Initial training parameters
n_epochs = 10
# The classifier reuses the MRP encoder: model_hsd.bert and model_mrp.bert are
# the same module object.
model = model_mrp.bert
model_hsd = BERT_HSD(model)
model_hsd = model_hsd.to(device)
optimizer = model_hsd.assign_optimizer(lr=2e-6)

v_score_threshhold = 0
score = 0.00000001

# Fine-tuning training loop
for epoch in range(n_epochs):
    # Previous epoch's validation accuracy (not yet used for early stopping).
    v_score_threshhold = score
    # Both evaluation helpers below put the (shared) encoder in eval mode;
    # switch back before training so dropout is active.
    model_hsd.train()
    loss = train_berthsd(model_hsd, train_loader, device=device)

    preds, targets, _ = eval_berthsd(model_hsd, valid_loader, device=device)
    score = accuracy_score(targets.cpu(), preds.cpu())
    print("Epoch:", epoch)
    print("Training loss:", loss)
    # `score` is accuracy; it was previously mislabelled as an F1 score.
    print("Validation accuracy:", score)
    print()

    # NOTE(review): the test split is scored every epoch. Use these numbers for
    # reporting only, not for picking an epoch.
    preds, targets, outs = eval_berthsd(model_hsd, test_loader, device=device)
    targets_cpu, preds_cpu, outs_cpu = targets.cpu(), preds.cpu(), outs.cpu()
    acc_score = accuracy_score(targets_cpu, preds_cpu)
    auroc = roc_auc_score(targets_cpu, outs_cpu, multi_class='ovr')
    mf1 = f1_score(targets_cpu, preds_cpu, average='macro')
    # The metric object lives on the CPU, so feed it CPU tensors.
    pr_curve = torchmetrics.PrecisionRecallCurve(task="multiclass", num_classes=3)
    precision, recall, thresholds = pr_curve(outs_cpu, targets_cpu)
    # Macro AUPRC: mean of the three per-class PR-curve areas.
    auprc = np.average([auc(recall[i].cpu(), precision[i].cpu()) for i in range(3)])

    print("Test Accuracy Score:", acc_score)
    print("Test AUROC:", auroc)
    print("Test Macro-F1:", mf1)
    print("Test AUPRC:", auprc)

    # Track how well the shared encoder still supports rationale prediction.
    preds, targets, v_loss = eval_bertRP(model_mrp, valid_loader, device=device)

    print("Validation loss:", v_loss)

    # Skip position 0 (leading special token) and keep one prediction per
    # word-piece target.
    preds = [x.tolist()[1:len(y) + 1] for x, y in zip(preds, targets)]
    pred_prob = [p[0] for nested in preds for p in nested]
    preds = [int(round(p)) for p in pred_prob]
    targets = [int(t) for sublist in targets for t in sublist]
    print("Avg precision", average_precision_score(targets, pred_prob))

    # precision_recall_curve takes (y_true, scores); the arguments were
    # previously swapped and used thresholded predictions.
    precision, recall, thresholds = precision_recall_curve(targets, pred_prob)

    auc_precision_recall = auc(recall, precision)
    print("AUCPR", auc_precision_recall)
    print()