In [1]:
try:
    import transformers
except ImportError as e:
    print('transformers not installed')
    print('Installing now...')
    !pip install -q git+https://github.com/huggingface/transformers.git
    print("Install complete.")
    pass  

In [2]:
import torch
import io 
import os
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.metrics import classification_report,accuracy_score
import transformers
import json
from tqdm.notebook import tqdm
from transformers.utils.dummy_pt_objects import AutoModelForSequenceClassification
from transformers import AutoModelForTokenClassification,AutoConfig,AutoModel,AutoTokenizer,BertModel,BertConfig,AdamW, get_constant_schedule,BertForSequenceClassification,get_linear_schedule_with_warmup
import random
import numpy as np
import torch.nn as nn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

#Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# if using Google Colab, set colab = True
colab = False

if colab == True:
    #Mounting Drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    
    %cd '/content/gdrive/Shareddrives/523 Project/Data'
    %ls
else:
    DATA_DIR = '/projectnb2/dl523/students/kjv/EC523_Project/Data/'

In [4]:
#Reading in the data
df = pd.read_json(DATA_DIR + "News Headlines/Sarcasm_Headlines_Dataset_v2.json",lines = True)
df = df.rename(columns={'is_sarcastic': 'label'})
df = df.drop('article_link', 1)
df.head()

#splits for training test validation

train_headlines, temporary_text, train_label, temporary_label = train_test_split(df['headline'], df['label'], 
                                                                    random_state=200, 
                                                                    test_size=0.2, 
                                                                    stratify=df['label'])







validation_headlines, test_headlines, validation_label, test_label = train_test_split(temporary_text, temporary_label, 
                                                                    random_state=200, 
                                                                    test_size=0.5, 
                                                                    stratify=temporary_label)

In [5]:
# set max length for padding/clipping during tokenization
max_length = 35

# create tokenized training, validation, and test splits
tokenizer = AutoTokenizer.from_pretrained("bert-large-uncased")

training_tokens = tokenizer.__call__(train_headlines.tolist(),max_length = max_length,padding = True,truncation = True)
validation_tokens = tokenizer.__call__(validation_headlines.tolist(),max_length = max_length,padding = True,truncation = True)
test_tokens = tokenizer.__call__(test_headlines.tolist(),max_length = max_length,padding= True,truncation = True)

# Stacking the inputs as tensors for use in the BERT model

training_set = TensorDataset(torch.tensor(training_tokens['input_ids']),torch.tensor(training_tokens['attention_mask']),torch.tensor(train_label.tolist()))
validation_set = TensorDataset(torch.tensor(validation_tokens['input_ids']),torch.tensor(validation_tokens['attention_mask']),torch.tensor(validation_label.tolist()))
test_set = TensorDataset(torch.tensor(test_tokens['input_ids']),torch.tensor(test_tokens['attention_mask']),torch.tensor(test_label.tolist()))

In [6]:
# initialize pre-trained BERT large

bertconfig = BertConfig()
bert_large = BertModel.from_pretrained("bert-large-uncased")

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertModel: ['cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [7]:
# freeze pre-trained layers in BERT

for param in bert_large.parameters():
    param.requires_grad = False

In [33]:
class multihead_attn_bert(nn.Module):

    def __init__(self, bert_encoder, embed_dim, num_attn_layers, num_heads):
        super(multihead_attn_bert, self).__init__()
        
        self.bert = bert_encoder
        
        self.multiheads = nn.ModuleList([nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)])
        self.multiheads.extend([nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) for i in range(num_attn_layers-1)])
        
        self.GRU = nn.GRU(input_size=embed_dim, hidden_size=512, bidirectional=True, batch_first=True)
        
        # 1024 = hidden_size of GRU x 2 (for bidirectionality of GRU)
        self.fc = nn.Linear(in_features=1024, out_features=1)
        
        # use dropout set of 0.2 as in paper
        #self.dropout = nn.Dropout(p=0.2)
        
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, tokenized_input_values, attention_mask):
        
        output = self.bert(tokenized_input_values, attention_mask=attention_mask).last_hidden_state
        
        for multihead_layer in self.multiheads:
            output,_ = multihead_layer(query=output,key=output,value=output,key_padding_mask=(~attention_mask.bool()))
        
        _,hidden = self.GRU(output)
        
        # concatenate bidirectional outputs from GRU to pass to linear layer
        hidden = torch.cat([hidden[0,:, :], hidden[1,:,:]], dim=1).unsqueeze(0)
        
        output = self.fc(hidden)
        
        #output = self.dropout(output)
        
        output = self.sigmoid(output)
        
        return output

In [34]:
# initialize multihead attention sarcasm model with BERT large embedder

# change 1024 to 768 if using bert_base
mh_sarcasm_model = multihead_attn_bert(bert_large, 1024, 3, 8)
mh_sarcasm_model.to(device)

multihead_attn_bert(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 1024, padding_idx=0)
      (position_embeddings): Embedding(512, 1024)
      (token_type_embeddings): Embedding(2, 1024)
      (LayerNorm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=True)
              (key): Linear(in_features=1024, out_features=1024, bias=True)
              (value): Linear(in_features=1024, out_features=1024, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=1024, out_features=1024, bias=True)
              (LayerNorm): LayerNorm((1024,), eps=1e-12, elem

In [22]:
#Dataloaders for the sets

batch_size = 64

trainloader = DataLoader(training_set, batch_size = batch_size, num_workers=2, shuffle = True)
validationloader = DataLoader(validation_set, batch_size = batch_size, num_workers=2, shuffle = True)
testloader = DataLoader(test_set, batch_size = batch_size, num_workers=2, shuffle = True)

#Loss function (used in reference multihead attention paper)
criterion = nn.BCELoss()

In [35]:
# define training, testing, and validation loss functions

def train_mh_bert(model, trainloader, validationloader, optimizer, criterion, num_epochs, scheduler=None):
        
        # conditions for early stopping
        last_val_loss = float('inf')
        patience = 3
        es_counter = 0
        
        print("Starting training...")
        
        for epoch in range(1, num_epochs+1):
    
            model.train()
            if scheduler != None:
                print("Learning rate: ", scheduler.get_last_lr())
            
            running_loss = 0
            print('Epoch: ',epoch)

            for idx, (inputs,attention_mask,label) in enumerate(tqdm(trainloader,total = len(trainloader))):

                inputs, attention_mask, label = inputs.to(device),attention_mask.to(device),label.to(device)

                optimizer.zero_grad()

                output = mh_sarcasm_model(inputs, attention_mask)
                output = torch.flatten(output)

                # convert label type from int to float for use in BCELoss
                label = label.float()
                loss = criterion(output,label)

                loss.backward()
                running_loss += loss.item()
                optimizer.step()

                # print loss every 100 batches
                if idx % 100 == 0:
                    print('Loss: ',float(loss))
                    running_loss = 0.0
            
            # adjust scheduler after every epoch
            if scheduler != None:
                scheduler.step()
                
            # save model after every epoch
            save_dir = "/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/trained_MH_BERT.pth"
            torch.save(model.state_dict(), save_dir)

            # check for changes in total validation loss to determine if early stopping is needed
            print("Checking validation loss...")
            curr_val_loss = validation_loss(model, validationloader)
            print("Average validation loss after last epoch: ", curr_val_loss)

            if curr_val_loss > last_val_loss:
                es_counter += 1

                if es_counter >= patience:
                    print("Early stopping triggered. Ending training..")
                    return
                else:
                    print(f"Increase in validation loss! {patience-es_counter} more consecutive loss increase(s) until early stop.")

            else:
                print("Decrease in validation loss. Early stop counter reset to 0.")
                es_counter = 0

            last_val_loss = curr_val_loss

def test_mh_bert(model, testloader):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data in testloader:
            
            inputs, attn_mask, labels = data
            inputs, attn_mask, labels = inputs.to(device), attn_mask.to(device), labels.to(device)
            
            output = torch.flatten(model(inputs, attn_mask))
            
            # convert output to class predictions
            output[output<0.5] = 0
            output[output>=0.5] = 1
            
            total += labels.size(0)
            correct += (output==labels).float().sum().item()
            
        acc = correct/total * 100
    
    return acc

def validation_loss(model, validationloader):
    model.eval()
    total_val_loss = 0
    criterion = nn.BCELoss()
    
    with torch.no_grad():
        for data in validationloader:
            
            inputs, attn_mask, labels = data
            inputs, attn_mask, labels = inputs.to(device), attn_mask.to(device), labels.to(device)
            labels = labels.float()
            
            output = torch.flatten(model(inputs, attn_mask))
            loss = criterion(output, labels)
            
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss/len(validationloader)
    
    return avg_val_loss

In [36]:
#Training mh_sarcasm bert
num_epochs = 50

# optimizer using learning rate from reference paper
tuning_parameters = [parameter for parameter in mh_sarcasm_model.parameters() if parameter.requires_grad]
optimizer = torch.optim.Adam(tuning_parameters,lr = 1e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps = 0,num_training_steps = len(trainloader)*Epochs)

# train the model with early stopping (add scheduler once done)
train_mh_bert(mh_sarcasm_model, trainloader, validationloader, optimizer, criterion, num_epochs)

Starting training...
Epoch:  1


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.687972366809845
Loss:  0.4920232892036438
Loss:  0.44000837206840515
Loss:  0.32191818952560425
Checking validation loss...
Average validation loss after last epoch:  0.3535509688986672
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  2


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.34363505244255066
Loss:  0.29265064001083374
Loss:  0.35571590065956116
Loss:  0.32597094774246216
Checking validation loss...
Average validation loss after last epoch:  0.31470103363196056
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  3


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.43116313219070435
Loss:  0.4073447585105896
Loss:  0.23980936408042908
Loss:  0.16097748279571533
Checking validation loss...
Average validation loss after last epoch:  0.25843374994066026
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  4


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.24841777980327606
Loss:  0.224084734916687
Loss:  0.3205772936344147
Loss:  0.2939577102661133
Checking validation loss...
Average validation loss after last epoch:  0.2553435183233685
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  5


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.278423935174942
Loss:  0.3267265558242798
Loss:  0.18211877346038818
Loss:  0.3260405659675598
Checking validation loss...
Average validation loss after last epoch:  0.2739714854293399
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  6


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.2212839424610138
Loss:  0.24746596813201904
Loss:  0.22732441127300262
Loss:  0.24572569131851196
Checking validation loss...
Average validation loss after last epoch:  0.24341251651446025
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  7


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.34916549921035767
Loss:  0.30558091402053833
Loss:  0.12516215443611145
Loss:  0.2945457100868225
Checking validation loss...
Average validation loss after last epoch:  0.279259838991695
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  8


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.2583935260772705
Loss:  0.2532022297382355
Loss:  0.2221299260854721
Loss:  0.17339572310447693
Checking validation loss...
Average validation loss after last epoch:  0.2845269375377231
Increase in validation loss! 1 more consecutive loss increase(s) until early stop.
Epoch:  9


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.2381218820810318
Loss:  0.18849679827690125
Loss:  0.24173861742019653
Loss:  0.3654114603996277
Checking validation loss...
Average validation loss after last epoch:  0.23871274292469025
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  10


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.2823105752468109
Loss:  0.25548726320266724
Loss:  0.1877935379743576
Loss:  0.19111143052577972
Checking validation loss...
Average validation loss after last epoch:  0.2880510431196954
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  11


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.14366579055786133
Loss:  0.22477151453495026
Loss:  0.2320987582206726
Loss:  0.14392969012260437
Checking validation loss...
Average validation loss after last epoch:  0.2572141614225176
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  12


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.24185577034950256
Loss:  0.14140695333480835
Loss:  0.15432396531105042
Loss:  0.12921954691410065
Checking validation loss...
Average validation loss after last epoch:  0.242532836066352
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  13


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.06344521045684814
Loss:  0.15792135894298553
Loss:  0.2288358360528946
Loss:  0.1911040097475052
Checking validation loss...
Average validation loss after last epoch:  0.2334957629442215
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  14


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.11442434787750244
Loss:  0.18163517117500305
Loss:  0.19798102974891663
Loss:  0.20956553518772125
Checking validation loss...
Average validation loss after last epoch:  0.24201238056023916
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  15


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.17317679524421692
Loss:  0.18041856586933136
Loss:  0.2365204095840454
Loss:  0.132924422621727
Checking validation loss...
Average validation loss after last epoch:  0.24777863307131662
Increase in validation loss! 1 more consecutive loss increase(s) until early stop.
Epoch:  16


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.06124235689640045
Loss:  0.06969709694385529
Loss:  0.14631755650043488
Loss:  0.25236865878105164
Checking validation loss...
Average validation loss after last epoch:  0.25289272550079556
Early stopping triggered. Ending training..


In [37]:
# test mh_sarcasm_bert

acc = test_mh_bert(mh_sarcasm_model, testloader)
print(f"Accuracy of network: {acc}")

Accuracy of network: 90.63591893780573
