In [1]:
try:
    import transformers
except ImportError as e:
    print('transformers not installed')
    print('Installing now...')
    !pip install -q git+https://github.com/huggingface/transformers.git
    print("Install complete.")
    pass  

In [2]:
import torch
import io 
import os
from torch.utils.data import Dataset,DataLoader,TensorDataset
from sklearn.metrics import classification_report,accuracy_score
import transformers
import json
from tqdm.notebook import tqdm
from transformers.utils.dummy_pt_objects import AutoModelForSequenceClassification
from transformers import AutoModelForTokenClassification,AutoConfig,AutoModel,AutoTokenizer,BertModel,BertConfig,AdamW, get_constant_schedule,BertForSequenceClassification,get_linear_schedule_with_warmup
import random
import numpy as np
import torch.nn as nn
from sklearn.metrics import confusion_matrix
from sklearn.metrics import ConfusionMatrixDisplay
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.model_selection import train_test_split

#Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
# if using Google Colab, set colab = True
colab = False

if colab == True:
    #Mounting Drive
    from google.colab import drive
    drive.mount('/content/gdrive')
    
    %cd '/content/gdrive/Shareddrives/523 Project/Data'
    %ls
else:
    DATA_DIR = '/projectnb2/dl523/students/kjv/EC523_Project/Data/'

In [4]:
#Reading in headlines data
df = pd.read_json(DATA_DIR + "News Headlines/Sarcasm_Headlines_Dataset_v2.json",lines = True)
df = df.rename(columns={'is_sarcastic': 'label'})
df = df.drop('article_link', 1)
df.head()

#splits for headlines training, test, and validation

train_headlines, temporary_text, train_label, temporary_label = train_test_split(df['headline'], df['label'], 
                                                                    random_state=200, 
                                                                    test_size=0.2, 
                                                                    stratify=df['label'])







validation_headlines, test_headlines, validation_label, test_label = train_test_split(temporary_text, temporary_label, 
                                                                    random_state=200, 
                                                                    test_size=0.5, 
                                                                    stratify=temporary_label)

In [6]:
# set max length for padding/clipping during tokenization
max_length = 35

# create tokenized training, validation, and test splits
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

training_tokens = tokenizer.__call__(train_headlines.tolist(),max_length = max_length,padding = True,truncation = True)
validation_tokens = tokenizer.__call__(validation_headlines.tolist(),max_length = max_length,padding = True,truncation = True)
test_tokens = tokenizer.__call__(test_headlines.tolist(),max_length = max_length,padding= True,truncation = True)

# Stacking the inputs as tensors for use in the BERT model

training_set = TensorDataset(torch.tensor(training_tokens['input_ids']),torch.tensor(training_tokens['attention_mask']),torch.tensor(train_label.tolist()))
validation_set = TensorDataset(torch.tensor(validation_tokens['input_ids']),torch.tensor(validation_tokens['attention_mask']),torch.tensor(validation_label.tolist()))
test_set = TensorDataset(torch.tensor(test_tokens['input_ids']),torch.tensor(test_tokens['attention_mask']),torch.tensor(test_label.tolist()))

In [17]:
class multihead_attn_bert(nn.Module):

    def __init__(self, bert_encoder, embed_dim, num_attn_layers, num_heads):
        super(multihead_attn_bert, self).__init__()
        
        self.bert = bert_encoder
        
        self.multiheads = nn.ModuleList([nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)])
        self.multiheads.extend([nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) for i in range(num_attn_layers-1)])
        
        self.GRU = nn.GRU(input_size=embed_dim, hidden_size=512, bidirectional=True, batch_first=True)
        
        # 1024 = hidden_size of GRU x 2 (for bidirectionality of GRU)
        self.fc = nn.Linear(in_features=1024, out_features=1)
        
        # use dropout set of 0.2 as in paper
        self.dropout = nn.Dropout(p=0.2)
        
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, tokenized_input_values, attention_mask):
        
        output = self.bert(tokenized_input_values, attention_mask=attention_mask).last_hidden_state
        
        for multihead_layer in self.multiheads:
            output,_ = multihead_layer(query=output,key=output,value=output,key_padding_mask=(~attention_mask.bool()))
        
        _,hidden = self.GRU(output)
        
        # concatenate bidirectional outputs from GRU to pass to linear layer
        hidden = torch.cat([hidden[0,:, :], hidden[1,:,:]], dim=1).unsqueeze(0)
        
        output = self.fc(hidden)
        
        output = self.dropout(output)
        
        output = self.sigmoid(output)
        
        return output

In [8]:
# initialize pre-trained BERT large

bertconfig = BertConfig()
bert = BertModel.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# freeze pre-trained layers in BERT

for param in bert.parameters():
    param.requires_grad = False

In [16]:
# initialize multihead attention sarcasm model with BERT embedder

# embed_dim = 768 if using bert_base, 1024 for bert_large
mh_sarcasm_model = multihead_attn_bert(bert, embed_dim=768, num_attn_layers=3, num_heads=8)
mh_sarcasm_model.to(device)

# save untrained model weights
torch.save(mh_sarcasm_model.state_dict(), "/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/untrained_mhbert.pth")

multihead_attn_bert(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0): BertLayer(
          (attention): BertAttention(
            (self): BertSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affin

In [12]:
#Dataloaders for the headlines sets

batch_size = 64

trainloader = DataLoader(training_set, batch_size = batch_size, num_workers=2, shuffle = True)
validationloader = DataLoader(validation_set, batch_size = batch_size, num_workers=2, shuffle = True)
testloader = DataLoader(test_set, batch_size = batch_size, num_workers=2, shuffle = True)

#Loss function (used in reference multihead attention paper)
criterion = nn.BCELoss()

In [20]:
# define training, testing, and validation loss functions

def train_mh_bert(model, trainloader, validationloader, optimizer, criterion, num_epochs, scheduler=None):
        
        # conditions for early stopping
        last_val_loss = float('inf')
        min_val_loss = float('inf')
        patience = 3
        es_counter = 0
        
        print("Starting training...")
        
        for epoch in range(1, num_epochs+1):
    
            model.train()
            if scheduler != None:
                print("Learning rate: ", scheduler.get_last_lr())
            
            running_loss = 0
            print('Epoch: ',epoch)

            for idx, (inputs,attention_mask,label) in enumerate(tqdm(trainloader,total = len(trainloader))):

                inputs, attention_mask, label = inputs.to(device),attention_mask.to(device),label.to(device)

                optimizer.zero_grad()

                output = mh_sarcasm_model(inputs, attention_mask)
                output = torch.flatten(output)

                # convert label type from int to float for use in BCELoss
                label = label.float()
                loss = criterion(output,label)

                loss.backward()
                running_loss += loss.item()
                optimizer.step()

                # print loss every 100 batches
                if idx % 100 == 0:
                    print('Loss: ',float(loss))
                    running_loss = 0.0
            
            # adjust scheduler after every epoch
            if scheduler != None:
                scheduler.step()

            # check for changes in avg validation loss to determine if early stopping is needed
            print("Checking validation loss...")
            curr_val_loss = validation_loss(model, validationloader)
            print("Average validation loss after last epoch: ", curr_val_loss)

            if curr_val_loss > last_val_loss:
                es_counter += 1

                if es_counter >= patience:
                    print("Early stopping triggered. Ending training..")
                    return
                else:
                    print(f"Increase in validation loss! {patience-es_counter} more consecutive loss increase(s) until early stop.")

            else:
                print("Decrease in validation loss. Early stop counter reset to 0.")
                es_counter = 0

            last_val_loss = curr_val_loss
            
            # check to save model if validation loss is lower than min recorded validation loss
            if curr_val_loss < min_val_loss:
                min_val_loss = curr_val_loss
                save_dir = "/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/trained_MH_BERT.pth"
                torch.save(model.state_dict(), save_dir)

def test_mh_bert(model, testloader):
    model.eval()
    correct = 0
    total = 0
    
    with torch.no_grad():
        for data in testloader:
            
            inputs, attn_mask, labels = data
            inputs, attn_mask, labels = inputs.to(device), attn_mask.to(device), labels.to(device)
            
            output = torch.flatten(model(inputs, attn_mask))
            
            # convert output to class predictions
            output[output<0.5] = 0
            output[output>=0.5] = 1
            
            total += labels.size(0)
            correct += (output==labels).float().sum().item()
            
        acc = correct/total * 100
    
    return acc

def validation_loss(model, validationloader):
    model.eval()
    total_val_loss = 0
    criterion = nn.BCELoss()
    
    with torch.no_grad():
        for data in validationloader:
            
            inputs, attn_mask, labels = data
            inputs, attn_mask, labels = inputs.to(device), attn_mask.to(device), labels.to(device)
            labels = labels.float()
            
            output = torch.flatten(model(inputs, attn_mask))
            loss = criterion(output, labels)
            
            total_val_loss += loss.item()
    
    avg_val_loss = total_val_loss/len(validationloader)
    
    return avg_val_loss

In [18]:
#Training mh_sarcasm bert
num_epochs = 50

# optimizer using learning rate from reference paper
tuning_parameters = [parameter for parameter in mh_sarcasm_model.parameters() if parameter.requires_grad]
optimizer = torch.optim.Adam(tuning_parameters,lr = 1e-4)
# scheduler = get_linear_schedule_with_warmup(optimizer,num_warmup_steps = 0,num_training_steps = len(trainloader)*Epochs)

# train the model with early stopping (add scheduler once done)
train_mh_bert(mh_sarcasm_model, trainloader, validationloader, optimizer, criterion, num_epochs)

Starting training...
Epoch:  1


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.6963216066360474
Loss:  0.418545663356781
Loss:  0.2411714792251587
Loss:  0.42320960760116577
Checking validation loss...
Average validation loss after last epoch:  0.2985598643620809
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  2


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.2823479473590851
Loss:  0.27974405884742737
Loss:  0.3039320111274719
Loss:  0.25664085149765015
Checking validation loss...
Average validation loss after last epoch:  0.2584587554136912
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  3


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.14680689573287964
Loss:  0.4012824594974518
Loss:  0.18251962959766388
Loss:  0.21245181560516357
Checking validation loss...
Average validation loss after last epoch:  0.2478137085835139
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  4


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.29841768741607666
Loss:  0.1815876066684723
Loss:  0.19436872005462646
Loss:  0.22262385487556458
Checking validation loss...
Average validation loss after last epoch:  0.25016516711976794
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  5


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.37482279539108276
Loss:  0.28791648149490356
Loss:  0.19054046273231506
Loss:  0.22295452654361725
Checking validation loss...
Average validation loss after last epoch:  0.23309316055642235
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  6


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.14328540861606598
Loss:  0.19539877772331238
Loss:  0.18976709246635437
Loss:  0.18226853013038635
Checking validation loss...
Average validation loss after last epoch:  0.2478480350640085
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  7


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.32073381543159485
Loss:  0.12422336637973785
Loss:  0.16463278234004974
Loss:  0.13943494856357574
Checking validation loss...
Average validation loss after last epoch:  0.22149229728513295
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  8


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.1693744957447052
Loss:  0.09850645065307617
Loss:  0.16211938858032227
Loss:  0.22823259234428406
Checking validation loss...
Average validation loss after last epoch:  0.2490161649054951
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  9


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.20200148224830627
Loss:  0.16963420808315277
Loss:  0.10567642748355865
Loss:  0.3604779541492462
Checking validation loss...
Average validation loss after last epoch:  0.24725776529974408
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  10


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.2297922521829605
Loss:  0.0750618577003479
Loss:  0.11122405529022217
Loss:  0.2827882468700409
Checking validation loss...
Average validation loss after last epoch:  0.23819700280825298
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  11


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.04376868158578873
Loss:  0.16302360594272614
Loss:  0.0708831325173378
Loss:  0.24929673969745636
Checking validation loss...
Average validation loss after last epoch:  0.23085618697934682
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  12


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.15729427337646484
Loss:  0.22843106091022491
Loss:  0.2104424089193344
Loss:  0.1646895408630371
Checking validation loss...
Average validation loss after last epoch:  0.2216010789076487
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  13


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.17424669861793518
Loss:  0.21771834790706635
Loss:  0.11546191573143005
Loss:  0.09443441033363342
Checking validation loss...
Average validation loss after last epoch:  0.24912890858120387
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  14


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.12638965249061584
Loss:  0.1569819450378418
Loss:  0.0987187922000885
Loss:  0.18995054066181183
Checking validation loss...
Average validation loss after last epoch:  0.27610707845952775
Increase in validation loss! 1 more consecutive loss increase(s) until early stop.
Epoch:  15


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.11561749875545502
Loss:  0.1427774429321289
Loss:  0.15124203264713287
Loss:  0.13119392096996307
Checking validation loss...
Average validation loss after last epoch:  0.250098805460665
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  16


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.09805400669574738
Loss:  0.09547890722751617
Loss:  0.09471965581178665
Loss:  0.0752532109618187
Checking validation loss...
Average validation loss after last epoch:  0.2603272735244698
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  17


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.05152593553066254
Loss:  0.061459437012672424
Loss:  0.08125069737434387
Loss:  0.219839408993721
Checking validation loss...
Average validation loss after last epoch:  0.24760983288288116
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  18


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.054508935660123825
Loss:  0.09287788718938828
Loss:  0.07200321555137634
Loss:  0.08509998023509979
Checking validation loss...
Average validation loss after last epoch:  0.3212813686165545
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  19


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.10377069562673569
Loss:  0.04518425464630127
Loss:  0.21159781515598297
Loss:  0.04977882653474808
Checking validation loss...
Average validation loss after last epoch:  0.2924107218782107
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  20


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.12807635962963104
Loss:  0.1465209573507309
Loss:  0.19472277164459229
Loss:  0.17829164862632751
Checking validation loss...
Average validation loss after last epoch:  0.3001458489232593
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  21


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.03807127848267555
Loss:  0.07529351115226746
Loss:  0.14919498562812805
Loss:  0.04066608101129532
Checking validation loss...
Average validation loss after last epoch:  0.28835202124383713
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  22


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.0627489984035492
Loss:  0.12410268187522888
Loss:  0.03685704618692398
Loss:  0.07532985508441925
Checking validation loss...
Average validation loss after last epoch:  0.3101462074451976
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  23


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.09619129449129105
Loss:  0.09781554341316223
Loss:  0.1956649124622345
Loss:  0.0523945614695549
Checking validation loss...
Average validation loss after last epoch:  0.2714543989963002
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  24


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.11886343359947205
Loss:  0.06035050377249718
Loss:  0.12734098732471466
Loss:  0.0851750373840332
Checking validation loss...
Average validation loss after last epoch:  0.3143978723221355
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  25


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.07477100938558578
Loss:  0.0489264577627182
Loss:  0.08604718744754791
Loss:  0.09632599353790283
Checking validation loss...
Average validation loss after last epoch:  0.30025207532776726
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  26


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.01384538970887661
Loss:  0.06270094215869904
Loss:  0.02390207350254059
Loss:  0.12754742801189423
Checking validation loss...
Average validation loss after last epoch:  0.27340799636311003
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  27


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.07131831347942352
Loss:  0.09173595905303955
Loss:  0.058628544211387634
Loss:  0.0765695571899414
Checking validation loss...
Average validation loss after last epoch:  0.24553730189800263
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  28


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.06321270763874054
Loss:  0.10246194899082184
Loss:  0.1639535129070282
Loss:  0.028557563200592995
Checking validation loss...
Average validation loss after last epoch:  0.29703219764762456
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  29


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.0160438884049654
Loss:  0.05441546067595482
Loss:  0.0500582791864872
Loss:  0.03651735559105873
Checking validation loss...
Average validation loss after last epoch:  0.29620524504118495
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  30


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.07326982170343399
Loss:  0.07617610692977905
Loss:  0.04941151291131973
Loss:  0.0645643025636673
Checking validation loss...
Average validation loss after last epoch:  0.3271255784564548
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  31


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.026505418121814728
Loss:  0.1461707353591919
Loss:  0.04946666583418846
Loss:  0.0724601298570633
Checking validation loss...
Average validation loss after last epoch:  0.28492499738931654
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  32


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.03261936828494072
Loss:  0.060441046953201294
Loss:  0.036171965301036835
Loss:  0.05649522691965103
Checking validation loss...
Average validation loss after last epoch:  0.3582661723097165
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  33


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.046871718019247055
Loss:  0.11241233348846436
Loss:  0.001493986346758902
Loss:  0.031806714832782745
Checking validation loss...
Average validation loss after last epoch:  0.31532224888602894
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  34


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.03097635880112648
Loss:  0.10107386112213135
Loss:  0.15578506886959076
Loss:  0.09268830716609955
Checking validation loss...
Average validation loss after last epoch:  0.2991649904184871
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  35


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.1300961971282959
Loss:  0.06345014274120331
Loss:  0.02058563381433487
Loss:  0.09868909418582916
Checking validation loss...
Average validation loss after last epoch:  0.34221589085128573
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  36


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.055060580372810364
Loss:  0.029261277988553047
Loss:  0.024068081751465797
Loss:  0.031394340097904205
Checking validation loss...
Average validation loss after last epoch:  0.30465191304683686
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  37


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.08994238823652267
Loss:  0.005597518756985664
Loss:  0.00856197252869606
Loss:  0.04358190298080444
Checking validation loss...
Average validation loss after last epoch:  0.27431729518704945
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  38


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.04222264140844345
Loss:  0.030403729528188705
Loss:  0.010294943116605282
Loss:  0.026942409574985504
Checking validation loss...
Average validation loss after last epoch:  0.32540261588162844
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  39


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.02313464879989624
Loss:  0.07231122255325317
Loss:  0.0725988894701004
Loss:  0.07587825506925583
Checking validation loss...
Average validation loss after last epoch:  0.394190099173122
Increase in validation loss! 1 more consecutive loss increase(s) until early stop.
Epoch:  40


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.007143646478652954
Loss:  0.04251345992088318
Loss:  0.049347974359989166
Loss:  0.01922822743654251
Checking validation loss...
Average validation loss after last epoch:  0.37829419490363864
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  41


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.014790802262723446
Loss:  0.07959047704935074
Loss:  0.003014124697074294
Loss:  0.0021444372832775116
Checking validation loss...
Average validation loss after last epoch:  0.3668654439349969
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  42


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.11632373183965683
Loss:  0.07253878563642502
Loss:  0.11754895001649857
Loss:  0.009158804081380367
Checking validation loss...
Average validation loss after last epoch:  0.39316093408399155
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  43


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.009744486771523952
Loss:  0.02714664116501808
Loss:  0.008864670060575008
Loss:  0.010116401128470898
Checking validation loss...
Average validation loss after last epoch:  0.38958994903498223
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  44


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.008632187731564045
Loss:  0.046055980026721954
Loss:  0.002726966282352805
Loss:  0.01510387659072876
Checking validation loss...
Average validation loss after last epoch:  0.46313930973410605
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  45


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.01027741190046072
Loss:  0.010326704941689968
Loss:  0.02248605154454708
Loss:  0.04264972358942032
Checking validation loss...
Average validation loss after last epoch:  0.35057488398419484
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  46


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.006908167153596878
Loss:  0.009162016212940216
Loss:  0.012073447927832603
Loss:  0.024775832891464233
Checking validation loss...
Average validation loss after last epoch:  0.38542638710803456
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  47


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.0041894009336829185
Loss:  0.0646161139011383
Loss:  0.012558885850012302
Loss:  0.012202450074255466
Checking validation loss...
Average validation loss after last epoch:  0.3596651131908099
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  48


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.043110959231853485
Loss:  0.0023991051129996777
Loss:  0.038173235952854156
Loss:  0.06937780976295471
Checking validation loss...
Average validation loss after last epoch:  0.38699375010199016
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.
Epoch:  49


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.0008595343679189682
Loss:  0.051669199019670486
Loss:  0.07512842863798141
Loss:  0.0026416971813887358
Checking validation loss...
Average validation loss after last epoch:  0.3472749236557219
Decrease in validation loss. Early stop counter reset to 0.
Epoch:  50


  0%|          | 0/358 [00:00<?, ?it/s]

Loss:  0.0034973248839378357
Loss:  0.07195261120796204
Loss:  0.030228853225708008
Loss:  0.030221175402402878
Checking validation loss...
Average validation loss after last epoch:  0.38445756344331633
Increase in validation loss! 2 more consecutive loss increase(s) until early stop.


In [19]:
# test mh_sarcasm_bert

acc = test_mh_bert(mh_sarcasm_model, testloader)
print(f"Accuracy of network: {acc}")

Accuracy of network: 91.61425576519916


In [None]:
# hp validation for mh_bert

def perform_hp_tuning(model, trainingset, validationset, lr_array, batchsize_array, criterion, num_epochs=1):

    accuracies = np.empty([len(lr_array), len(batchsize_array)])
    tuning_parameters = [parameter for parameter in model.parameters() if parameter.requires_grad]

    # run validation testing

    for i, lr in enumerate(lr_array):
        
        # reset optimizer with new learning rate
        optimizer = optim.SGD(tuning_parameters, lr=lr, momentum = .9)

        for j, batch_size in enumerate(batchsize_array):

            print("LEARNING RATE: ", lr)
            print("BATCH SIZE: ", batch_size)

            # restore untrained weights for model
            model.load_state_dict(torch.load("/projectnb/dl523/students/kjv/EC523_Project/Saved_Models/Multihead_BERT/untrained_mhbert.pth"))

            # define data loaders for training and testing data
            trainloader = torch.utils.data.DataLoader(trainingset, batch_size=batch_size, num_workers=2)
            testloader = torch.utils.data.DataLoader(validationset, batch_size=batch_size, num_workers=2)
            
            # train model using current hps and early stopping
            train_mh_bert(model, trainloader, testloader, optimizer, criterion, num_epochs)
            
            # test performance on validation dataset
            result = test_mh_bert(model, testloader)

            accuracies[i,j] = result
            print(f"Accuracy for lr={lr} and bs={batch_size}: {accuracies[i,j]}\n")


    # choose learning rate and batch size with best validation accuracy
    print("---HP TESTING COMPLETE---")
    print("Accuracy Matrix: \n", accuracies)
    best_lr_ind, best_bs_ind = np.unravel_index(np.argmax(accuracies, axis=None), accuracies.shape)
    
    optimal_lr = learning_rates[best_lr_ind]
    optimal_batch_size = batchsize_array[best_bs_ind]

    print(f"\nBest learning rate: {optimal_lr}")
    print(f"Best batch size: {optimal_batch_size}")
    return optimal_lr, optimal_batch_size