## Dependencies and Environment Setup

In [None]:
try:
    import transformers
except ImportError as e:
    print('transformers not installed')
    print('Installing now...')
    !pip install -q git+https://github.com/huggingface/transformers.git
    print("Install complete.")
    pass  

In [None]:
import torch
import io 
import os
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.metrics import classification_report,accuracy_score
import transformers
import json
from tqdm.notebook import tqdm
from transformers.utils.dummy_pt_objects import AutoModelForSequenceClassification
from transformers import AutoModelForTokenClassification,AutoConfig,AutoModel,AutoTokenizer,BertModel,BertConfig,AdamW, get_constant_schedule,BertForSequenceClassification,get_linear_schedule_with_warmup
import random
import numpy as np
import torch.nn as nn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

# Pick the compute device: use CUDA when torch can see a GPU, otherwise the CPU
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# if using Google Colab, set colab = True
colab = False

if colab == True:
    #Mounting Drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    
    %cd '/content/gdrive/Shareddrives/523 Project/Data'
    %ls
else:
    DATA_DIR = '/projectnb2/dl523/students/kjv/EC523_Project/Data/'

## Define Model Class and Training, Validation, and Testing Functions

In [None]:
class multihead_attn_bert(nn.Module):
    """Sarcasm classifier: encoder -> stacked self-attention -> BiGRU -> sigmoid.

    The encoder's ``last_hidden_state`` is refined by ``num_attn_layers``
    multi-head self-attention layers, summarised by a single-layer
    bidirectional GRU, and the concatenated final hidden states of both
    directions are mapped to one probability per sample.

    Args:
        bert_encoder: module called as ``bert_encoder(ids, attention_mask=mask)``
            whose result exposes ``last_hidden_state`` with ``embed_dim`` features.
        embed_dim: feature size of the encoder output (and of every attention layer).
        num_attn_layers: number of stacked self-attention layers; at least one
            layer is always created.
        num_heads: number of attention heads per layer.
        gru_hidden_size: hidden size of each GRU direction (default 512).
        dropout: dropout probability applied to the GRU features (default 0.2).
    """

    def __init__(self, bert_encoder, embed_dim, num_attn_layers, num_heads,
                 gru_hidden_size=512, dropout=0.2):
        super().__init__()

        self.bert = bert_encoder

        # Stack of self-attention layers operating on (batch, seq, feature) tensors
        self.multiheads = nn.ModuleList([
            nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
            for _ in range(max(1, num_attn_layers))
        ])

        self.GRU = nn.GRU(input_size=embed_dim, hidden_size=gru_hidden_size,
                          bidirectional=True, batch_first=True)

        # The classifier sees both GRU directions concatenated: 2 x hidden size
        self.fc = nn.Linear(in_features=2 * gru_hidden_size, out_features=1)

        # use dropout set of 0.2 as in paper
        self.dropout = nn.Dropout(p=dropout)

        self.sigmoid = nn.Sigmoid()

    def forward(self, tokenized_input_values, attention_mask):
        """Return sarcasm probabilities of shape ``(1, batch, 1)``.

        Args:
            tokenized_input_values: token ids passed straight to the encoder.
            attention_mask: mask with non-zero entries on real tokens; its
                boolean negation is used as the attention ``key_padding_mask``.
        """
        output = self.bert(tokenized_input_values, attention_mask=attention_mask).last_hidden_state

        # Padding positions must not be attended to; the mask is the same for every layer
        padding_mask = ~attention_mask.bool()
        for multihead_layer in self.multiheads:
            output, _ = multihead_layer(query=output, key=output, value=output,
                                        key_padding_mask=padding_mask)

        _, hidden = self.GRU(output)

        # concatenate bidirectional outputs from GRU to pass to linear layer
        hidden = torch.cat([hidden[0], hidden[1]], dim=1).unsqueeze(0)

        # Regularise the features, not the prediction: dropout on the single
        # output logit would force a fraction of the training predictions to 0.5.
        hidden = self.dropout(hidden)

        output = self.fc(hidden)

        return self.sigmoid(output)

In [None]:
# define training, testing, and validation loss functions for headlines data

def train_mh_bert_headlines(model, trainloader, validationloader, optimizer, criterion, num_epochs, scheduler=None, save_path=None):
    """Train ``model`` on the headlines data with early stopping and checkpointing.

    After every epoch the average validation loss is computed with
    ``validation_loss_headlines``. Training stops early once the validation
    loss has increased for 3 consecutive epochs. Whenever the validation loss
    reaches a new minimum the model's ``state_dict`` is written to
    ``save_path``. The per-epoch training and validation losses are plotted
    when training ends, whether normally or by early stopping.

    Args:
        model: network called as ``model(inputs, attn_mask)``.
        trainloader: iterable of ``(inputs, attn_mask, labels)`` batches.
        validationloader: loader handed to ``validation_loss_headlines``.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(output, labels)`` on the
            flattened model output and float labels.
        num_epochs: maximum number of epochs.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        save_path: where the best weights are stored; defaults to the
            project's headlines checkpoint path.

    Returns:
        None.
    """
    if save_path is None:
        save_path = "/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/trained_MH_BERT_headlines.pth"

    avg_val_losses = []
    avg_training_losses = []
    epochs_finished = []

    # conditions for early stopping
    last_val_loss = float('inf')
    min_val_loss = float('inf')
    patience = 3
    es_counter = 0

    print("Starting training...")

    for epoch in range(1, num_epochs+1):

        model.train()
        if scheduler is not None:
            print("Learning rate: ", scheduler.get_last_lr())

        curr_total_train_loss = 0
        print('Epoch: ',epoch)

        for idx, (inputs,attn_mask,labels) in enumerate(tqdm(trainloader,total = len(trainloader))):

            inputs, attn_mask, labels = inputs.to(device), attn_mask.to(device), labels.to(device)

            optimizer.zero_grad()

            # Use the model that was passed in, not a notebook-level global
            output = torch.flatten(model(inputs, attn_mask))

            # convert label type from int to float for use in BCELoss
            labels = labels.float()
            loss = criterion(output,labels)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            curr_total_train_loss += batch_loss

            # print loss every 100 batches
            if idx % 100 == 0:
                print('Loss: ',batch_loss)

        epochs_finished.append(epoch)
        avg_training_losses.append(curr_total_train_loss/len(trainloader))

        # adjust scheduler after every epoch
        if scheduler is not None:
            scheduler.step()

        # check for changes in avg validation loss to determine if early stopping is needed
        print("Checking validation loss...")
        curr_val_loss = validation_loss_headlines(model, validationloader)
        avg_val_losses.append(curr_val_loss)
        print("Average validation loss after last epoch: ", curr_val_loss)

        if curr_val_loss > last_val_loss:
            es_counter += 1

            if es_counter >= patience:
                print("Early stopping triggered. Ending training..")
                # Leave the loop; the loss curves are plotted below
                break

            print(f"Increase in validation loss! {patience-es_counter} more consecutive loss increase(s) until early stop.")

        else:
            print("Decrease in validation loss. Early stop counter reset to 0.")
            es_counter = 0

        last_val_loss = curr_val_loss

        # check to save model if validation loss is lower than min recorded validation loss
        if curr_val_loss < min_val_loss:
            print("New best validation loss - saving model.")
            min_val_loss = curr_val_loss
            torch.save(model.state_dict(), save_path)

    # plot training and validation losses
    plt.plot(epochs_finished, avg_training_losses, label = "Training Loss")
    plt.plot(epochs_finished, avg_val_losses, label = "Validation Loss")
    plt.title("Training and Validation Loss: Multihead Self-Attention Model")
    plt.ylabel("Loss")
    plt.xlabel("Epoch")
    plt.legend()
    

def test_mh_bert_headlines(model, testloader):
    """Return the accuracy, in percent, of ``model`` over ``testloader``.

    Each batch is an ``(inputs, attn_mask, labels)`` triple. The flattened
    model output is thresholded at 0.5 to obtain class predictions, which are
    compared with the labels.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for input_ids, mask, targets in testloader:
            input_ids = input_ids.to(device)
            mask = mask.to(device)
            targets = targets.to(device)

            scores = torch.flatten(model(input_ids, mask))

            # A prediction is a hit when the score falls on the label's side of 0.5
            predicted_one = scores >= 0.5
            predicted_zero = scores < 0.5
            hits = (predicted_one & (targets == 1)) | (predicted_zero & (targets == 0))

            n_seen += targets.size(0)
            n_correct += hits.float().sum().item()

        accuracy = n_correct/n_seen * 100

    return accuracy

def validation_loss_headlines(model, validationloader):
    """Return the mean per-batch binary cross-entropy of ``model`` on ``validationloader``.

    Each batch is an ``(inputs, attn_mask, labels)`` triple; labels are cast
    to float before being passed to ``nn.BCELoss``. The sum of the batch
    losses is divided by the number of batches.
    """
    model.eval()
    bce = nn.BCELoss()
    batch_losses = []

    with torch.no_grad():
        for input_ids, mask, targets in validationloader:
            input_ids, mask = input_ids.to(device), mask.to(device)
            targets = targets.to(device).float()

            probabilities = torch.flatten(model(input_ids, mask))
            batch_losses.append(bce(probabilities, targets).item())

    return sum(batch_losses)/len(validationloader)

In [None]:
# define training, testing, and validation loss functions for Reddit data

def train_mh_bert_reddit(model, trainloader, validationloader, optimizer, criterion, num_epochs, scheduler=None, save_path=None):
    """Train ``model`` on the Reddit data with early stopping and checkpointing.

    After every epoch the average validation loss is computed with
    ``validation_loss_reddit``. Training stops early once the validation loss
    has increased for 3 consecutive epochs. Whenever the validation loss
    reaches a new minimum the model's ``state_dict`` is written to
    ``save_path``. The per-epoch training and validation losses are plotted
    when training ends, whether normally or by early stopping.

    Args:
        model: network called as ``model(inputs, attn_mask)``.
        trainloader: iterable of ``(encodings, labels)`` batches, where
            ``encodings`` provides ``'input_ids'`` and ``'attention_mask'``.
        validationloader: loader handed to ``validation_loss_reddit``.
        optimizer: optimizer stepped once per batch.
        criterion: loss called as ``criterion(output, labels)`` on the
            flattened model output and float labels.
        num_epochs: maximum number of epochs.
        scheduler: optional learning-rate scheduler, stepped once per epoch.
        save_path: where the best weights are stored; defaults to the
            project's Reddit checkpoint path.

    Returns:
        None.
    """
    if save_path is None:
        save_path = "/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/trained_MH_BERT_reddit.pth"

    avg_val_losses = []
    avg_training_losses = []
    epochs_finished = []

    # conditions for early stopping
    last_val_loss = float('inf')
    min_val_loss = float('inf')
    patience = 3
    es_counter = 0

    print("Starting training...")

    for epoch in range(1, num_epochs+1):

        model.train()
        if scheduler is not None:
            print("Learning rate: ", scheduler.get_last_lr())

        curr_total_train_loss = 0
        print('Epoch: ',epoch)

        for idx, (encodings, labels) in enumerate(tqdm(trainloader,total = len(trainloader))):

            inputs = encodings['input_ids']
            attn_mask = encodings['attention_mask']
            inputs, attn_mask, labels = inputs.to(device), attn_mask.to(device), labels.to(device)

            optimizer.zero_grad()

            # Use the model that was passed in, not a notebook-level global
            output = torch.flatten(model(inputs, attn_mask))

            # convert label type from int to float for use in BCELoss
            labels = labels.float()
            loss = criterion(output,labels)

            loss.backward()
            optimizer.step()

            batch_loss = loss.item()
            curr_total_train_loss += batch_loss

            # print loss every 100 batches
            if idx % 100 == 0:
                print('Loss: ',batch_loss)

        epochs_finished.append(epoch)
        avg_training_losses.append(curr_total_train_loss/len(trainloader))

        # adjust scheduler after every epoch
        if scheduler is not None:
            scheduler.step()

        # check for changes in avg validation loss to determine if early stopping is needed
        print("Checking validation loss...")
        curr_val_loss = validation_loss_reddit(model, validationloader)
        avg_val_losses.append(curr_val_loss)
        print("Average validation loss after last epoch: ", curr_val_loss)

        if curr_val_loss > last_val_loss:
            es_counter += 1

            if es_counter >= patience:
                print("Early stopping triggered. Ending training..")
                # Leave the loop; the loss curves are plotted below
                break

            print(f"Increase in validation loss! {patience-es_counter} more consecutive loss increase(s) until early stop.")

        else:
            print("Decrease in validation loss. Early stop counter reset to 0.")
            es_counter = 0

        last_val_loss = curr_val_loss

        # check to save model if validation loss is lower than min recorded validation loss
        if curr_val_loss < min_val_loss:
            print("New best validation loss - saving model.")
            min_val_loss = curr_val_loss
            torch.save(model.state_dict(), save_path)

    # plot training and validation losses
    plt.plot(epochs_finished, avg_training_losses, label = "Training Loss")
    plt.plot(epochs_finished, avg_val_losses, label = "Validation Loss")
    plt.title("Training and Validation Loss: Multihead Self-Attention Model")
    plt.ylabel("Loss")
    plt.xlabel("Epoch")
    plt.legend()
    

def test_mh_bert_reddit(model, testloader):
    """Return the accuracy, in percent, of ``model`` over ``testloader``.

    Each batch is an ``(encodings, labels)`` pair whose ``encodings`` provide
    ``'input_ids'`` and ``'attention_mask'``. The flattened model output is
    thresholded at 0.5 to obtain class predictions, which are compared with
    the labels.
    """
    model.eval()
    n_correct = 0
    n_seen = 0

    with torch.no_grad():
        for batch_encodings, targets in testloader:
            input_ids = batch_encodings['input_ids'].to(device)
            mask = batch_encodings['attention_mask'].to(device)
            targets = targets.to(device)

            scores = torch.flatten(model(input_ids, mask))

            # A prediction is a hit when the score falls on the label's side of 0.5
            predicted_one = scores >= 0.5
            predicted_zero = scores < 0.5
            hits = (predicted_one & (targets == 1)) | (predicted_zero & (targets == 0))

            n_seen += targets.size(0)
            n_correct += hits.float().sum().item()

        accuracy = n_correct/n_seen * 100

    return accuracy

def validation_loss_reddit(model, validationloader):
    """Return the mean per-batch binary cross-entropy of ``model`` on ``validationloader``.

    Each batch is an ``(encodings, labels)`` pair whose ``encodings`` provide
    ``'input_ids'`` and ``'attention_mask'``; labels are cast to float before
    being passed to ``nn.BCELoss``. The sum of the batch losses is divided by
    the number of batches.
    """
    model.eval()
    bce = nn.BCELoss()
    batch_losses = []

    with torch.no_grad():
        for batch_encodings, targets in validationloader:
            input_ids = batch_encodings['input_ids'].to(device)
            mask = batch_encodings['attention_mask'].to(device)
            targets = targets.to(device).float()

            probabilities = torch.flatten(model(input_ids, mask))
            batch_losses.append(bce(probabilities, targets).item())

    return sum(batch_losses)/len(validationloader)

## Model Initialization and Configuration

In [None]:
# initialize pre-trained BERT

# Default BERT configuration object; it is not passed to from_pretrained below
bertconfig = BertConfig()
# Encoder loaded from the "bert-base-uncased" checkpoint
bert = BertModel.from_pretrained("bert-base-uncased")

In [None]:
# freeze pre-trained layers in BERT
# (the optimizers below are built only from parameters with requires_grad set,
# so these weights receive no updates)

for param in bert.parameters():
    param.requires_grad = False

In [None]:
# initialize multihead attention sarcasm model with BERT embedder

# embed_dim = 768 if using bert_base, 1024 for bert_large
mh_sarcasm_model = multihead_attn_bert(bert, embed_dim=768, num_attn_layers=3, num_heads=8)
# move the model to the device selected at the top of the notebook
mh_sarcasm_model.to(device)

# save untrained model weights
# (reloaded before the Reddit run so that training starts from this initial state)
torch.save(mh_sarcasm_model.state_dict(), "/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/untrained_mhbert.pth")

## Headlines Data Import and Tokenization

In [None]:
# Reading in headlines data (one JSON record per line)
df = pd.read_json(DATA_DIR + "News Headlines/Sarcasm_Headlines_Dataset_v2.json",lines = True)
df = df.rename(columns={'is_sarcastic': 'label'})
# Name the axis by keyword: pandas 2.0 no longer accepts it as a positional argument
df = df.drop(columns='article_link')
df.head()

# splits for headlines training, test, and validation
# First hold out 20% of the data, stratified on the label ...
train_headlines, temporary_text, train_label, temporary_label = train_test_split(
    df['headline'], df['label'],
    random_state=200,
    test_size=0.2,
    stratify=df['label'])

# ... then halve the held-out part into validation and test sets
validation_headlines, test_headlines, validation_label, test_label = train_test_split(
    temporary_text, temporary_label,
    random_state=200,
    test_size=0.5,
    stratify=temporary_label)

In [None]:
# Maximum number of tokens kept per headline (longer inputs are truncated)
max_length = 35

# Tokenizer matching the pretrained BERT checkpoint used by the encoder
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Tokenize every split with identical settings
training_tokens = tokenizer(
    train_headlines.tolist(),
    truncation=True,
    padding=True,
    max_length=max_length,
)
validation_tokens = tokenizer(
    validation_headlines.tolist(),
    truncation=True,
    padding=True,
    max_length=max_length,
)
test_tokens = tokenizer(
    test_headlines.tolist(),
    truncation=True,
    padding=True,
    max_length=max_length,
)

# Bundle (input_ids, attention_mask, label) tensors for each split
training_set = TensorDataset(
    torch.tensor(training_tokens['input_ids']),
    torch.tensor(training_tokens['attention_mask']),
    torch.tensor(train_label.tolist()),
)
validation_set = TensorDataset(
    torch.tensor(validation_tokens['input_ids']),
    torch.tensor(validation_tokens['attention_mask']),
    torch.tensor(validation_label.tolist()),
)
test_set = TensorDataset(
    torch.tensor(test_tokens['input_ids']),
    torch.tensor(test_tokens['attention_mask']),
    torch.tensor(test_label.tolist()),
)

## Training for Headlines Data

In [None]:
# Dataloaders for the headlines sets
batch_size = 64

# Only the training split is shuffled. Validation and test keep a fixed order so
# that the per-batch-averaged loss is reproducible: with a smaller final batch,
# that average depends on which samples end up in it.
trainloader = DataLoader(training_set, batch_size = batch_size, num_workers=2, shuffle = True)
validationloader = DataLoader(validation_set, batch_size = batch_size, num_workers=2, shuffle = False)
testloader = DataLoader(test_set, batch_size = batch_size, num_workers=2, shuffle = False)

# Loss function
criterion = nn.BCELoss()

In [None]:
# Training mh_sarcasm bert
# upper bound on epochs; the training function may stop earlier
num_epochs = 50

# optimizer using learning rate from multihead reference paper
# only parameters that still require gradients are optimized (frozen BERT weights are skipped)
tuning_parameters = [parameter for parameter in mh_sarcasm_model.parameters() if parameter.requires_grad]
optimizer = torch.optim.Adam(tuning_parameters,lr = 1e-4)

# train the model with early stopping (add scheduler once done)
train_mh_bert_headlines(mh_sarcasm_model, trainloader, validationloader, optimizer, criterion, num_epochs)

## Testing on Headlines Data

In [None]:
# load model with lowest validation loss
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU can also be restored on a CPU-only machine
mh_sarcasm_model.load_state_dict(torch.load("/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/trained_MH_BERT_headlines.pth", map_location=device))

# test mh_sarcasm_bert
acc = test_mh_bert_headlines(mh_sarcasm_model, testloader)
print(f"Accuracy of network: {acc}")

## Reddit Data Import and Tokenization

In [None]:
# import Reddit data classes and functions
# The helper module is loaded directly from its file path and bound to the
# name reddit_mod.
import importlib.util
fxns_filepath = "/projectnb/dl523/students/kjv/EC523_Project/sarcasm-detector/reddit_sarcasm/reddit_bert_functions.py"

# build a module spec from the file, create the module object, then execute it
spec = importlib.util.spec_from_file_location("reddit_bert_functions", fxns_filepath)
reddit_mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(reddit_mod)

In [None]:
# Location of the Reddit sarcasm CSV
csv_path = "/projectnb/dl523/students/kjv/EC523_Project/Data/Sarcasm_Reddit/train-balanced-sarcasm.csv"
# split_reddit_data is defined in the helper module; it returns six objects that
# are unpacked here as train/validation/test inputs and labels
x_train, y_train, x_val, y_val, x_test, y_test = reddit_mod.split_reddit_data(csv_path)

In [None]:
# word count of each training sample, used to choose the max length for tokenization
# (whitespace-separated words, so this only approximates the tokenizer's token count)
count = x_train.str.split().str.len()
# histogram restricted to counts between 0 and 100
plt.hist(count, bins=30, range=(0, 100))

In [None]:
# tokenize the reddit data
max_length = 35  #based on the word-count histogram above, 35 is reasonable

tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

# one dataset object per split; Reddit is defined in the helper module
# NOTE(review): presumably tokenization happens inside Reddit using tokenizer and max_length — confirm
reddit_train = reddit_mod.Reddit(x_train, y_train, tokenizer, max_length)
reddit_val = reddit_mod.Reddit(x_val, y_val, tokenizer, max_length)
reddit_test = reddit_mod.Reddit(x_test, y_test, tokenizer, max_length)

## Training for Reddit Data

In [None]:
# initialize Dataloaders
batch_size = 64
num_workers = 2

# NOTE: this rebinds trainloader, validationloader and testloader, replacing the
# headlines loaders created earlier in the notebook
trainloader, validationloader, testloader = reddit_mod.get_data_loaders(reddit_train, reddit_val, reddit_test, batch_size, num_workers)

# define loss function
criterion = nn.BCELoss()

In [None]:
# upper bound on epochs; the training function may stop earlier
num_epochs = 50

# optimizer using learning rate from multihead reference paper
# only parameters that still require gradients are optimized (frozen BERT weights are skipped)
tuning_parameters = [parameter for parameter in mh_sarcasm_model.parameters() if parameter.requires_grad]
optimizer = torch.optim.Adam(tuning_parameters,lr = 1e-4)

# set mh_sarcasm_model to pre-training weights
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU can also be restored on a CPU-only machine
mh_sarcasm_model.load_state_dict(torch.load("/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/untrained_mhbert.pth", map_location=device))

# train the model with early stopping (add scheduler once done)
train_mh_bert_reddit(mh_sarcasm_model, trainloader, validationloader, optimizer, criterion, num_epochs)

## Testing on Reddit Data

In [None]:
# load model with lowest validation loss
# map_location remaps the stored tensors onto the active device, so a checkpoint
# written on a GPU can also be restored on a CPU-only machine
mh_sarcasm_model.load_state_dict(torch.load("/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/trained_MH_BERT_reddit.pth", map_location=device))

# test mh_sarcasm_bert
acc = test_mh_bert_reddit(mh_sarcasm_model, testloader)
print(f"Accuracy of network: {acc}")