# Import Packages

In [None]:
!pip install transformers
!python3 -m spacy download en_core_web_sm

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.28.1-py3-none-any.whl (7.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.0/7.0 MB[0m [31m49.7 MB/s[0m eta [36m0:00:00[0m
Collecting huggingface-hub<1.0,>=0.11.0 (from transformers)
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m224.5/224.5 kB[0m [31m14.2 MB/s[0m eta [36m0:00:00[0m
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1 (from transformers)
  Downloading tokenizers-0.13.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m7.8/7.8 MB[0m [31m21.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: tokenizers, huggingface-hub, transformers
Successfully installed huggingface-hub-0.14.1 tokenizers-0.13.3 transformers-4.28.1
2023-05-01 21

In [None]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
from transformers import BertModel, BertTokenizer, AdamW
from tqdm import tqdm
import spacy
import csv
import random
import xml.etree.ElementTree as ET
#import os
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Data Processing

## Download the semeval2017_task7 dataset

In [None]:
!wget https://alt.qcri.org/semeval2017/task7/data/uploads/semeval2017_task7.tar.xz
!tar -xf semeval2017_task7.tar.xz
#!tar -xvf semeval2017_task7.tar.xz
#%cd semeval2017_task7/
#%cd ..
%ls

--2023-05-02 20:30:04--  https://alt.qcri.org/semeval2017/task7/data/uploads/semeval2017_task7.tar.xz
Resolving alt.qcri.org (alt.qcri.org)... 80.76.166.231
Connecting to alt.qcri.org (alt.qcri.org)|80.76.166.231|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 748424 (731K) [application/x-xz]
Saving to: ‘semeval2017_task7.tar.xz’


2023-05-02 20:30:06 (865 KB/s) - ‘semeval2017_task7.tar.xz’ saved [748424/748424]

[0m[01;34msample_data[0m/  [01;34msemeval2017_task7[0m/  semeval2017_task7.tar.xz


## Homographic

In [None]:
# Read the homographic subtask-1 sentences: every <text> element becomes one
# list of word strings, in the order its child elements appear.
f = 'semeval2017_task7/data/test/subtask1-homographic-test.xml'

mytree = ET.parse(f)
myroot = mytree.getroot()

puns_hom = []
for text_node in myroot.findall('./text'):
    # {text id: {word id: word text}}; dicts preserve insertion order, so the
    # words stay in document order.
    words_by_text = {
        text_node.attrib['id']: {
            word_node.attrib['id']: word_node.text for word_node in text_node
        }
    }
    for words in words_by_text.values():
        # Non-breaking spaces inside a token are replaced by underscores.
        puns_hom.append([word.replace(u'\xa0', '_') for word in words.values()])

print(puns_hom[0])

['They', 'hid', 'from', 'the', 'gunman', 'in', 'a', 'sauna', 'where', 'they', 'could', 'sweat', 'it', 'out', '.']


In [None]:
# Subtask-1 gold labels for the homographic set: the second tab-separated
# field of every row is parsed as the integer label.
gold_hom = []
with open('semeval2017_task7/data/test/subtask1-homographic-test.gold', 'r') as fin:
    gold_hom.extend(int(row.strip().split('\t')[1]) for row in fin)
print(gold_hom)

[1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 0, 

In [None]:
# Word-level pun position per homographic sentence; -10000 marks sentences
# for which the subtask-2 gold file has no entry.
location_hom = [-10000] * len(puns_hom)
with open('semeval2017_task7/data/test/subtask2-homographic-test.gold', 'r') as fin:
    for row in fin:
        # The second tab-separated field is split on '_'; part 1 is used as the
        # sentence number and part 2 as the word number. Both are 1-based in the
        # file, so shift them to 0-based.
        id_parts = row.strip().split('\t')[1].split('_')
        pun_index = int(id_parts[2]) - 1
        location_hom[int(id_parts[1]) - 1] = pun_index
print(location_hom)

[11, 8, 6, 4, 14, -10000, 13, 16, 10, 12, 8, -10000, -10000, 13, 14, 12, -10000, 7, 2, 16, 8, 14, 6, 6, 4, 14, 9, 4, 3, 7, 11, -10000, 3, 14, -10000, -10000, -10000, 8, 19, -10000, 15, -10000, 7, 11, 5, 8, 12, 9, -10000, 8, 7, 18, 7, -10000, 11, -10000, -10000, 14, -10000, 19, 9, 8, 9, -10000, -10000, -10000, 14, -10000, 10, 8, 20, 9, 8, 8, 12, 6, 13, 17, -10000, 4, 8, 5, 3, -10000, 14, 19, 12, 10, -10000, 2, 14, 9, 13, -10000, -10000, 11, 5, 14, 9, 19, -10000, -10000, 6, 16, -10000, 8, 16, 9, 9, 10, -10000, 6, 10, -10000, 10, 13, 4, -10000, 6, 11, -10000, -10000, -10000, 6, 16, 0, 15, 15, -10000, 10, -10000, 3, 11, -10000, 9, 14, 4, 14, 11, 11, 16, 16, -10000, 10, 10, -10000, 7, 14, 8, 7, 13, 20, 1, 8, -10000, 9, 7, -10000, 20, 8, 11, 7, 12, 11, 12, -10000, 5, -10000, 7, 22, -10000, 9, 7, 16, 7, -10000, 7, 14, 9, -10000, 16, -10000, 6, -10000, 11, 12, -10000, 21, 10, -10000, 10, -10000, 13, -10000, 6, 13, 12, 2, 20, -10000, 12, 11, 6, 16, -10000, 14, 7, 12, 16, -10000, -10000, -10000,

In [None]:
# The sentences, labels and pun locations must stay index-aligned.
assert len(puns_hom) == len(gold_hom) == len(location_hom)

## Heterographic

In [None]:
# Read the heterographic subtask-1 sentences: every <text> element becomes one
# list of word strings, in the order its child elements appear.
f = 'semeval2017_task7/data/test/subtask1-heterographic-test.xml'

mytree = ET.parse(f)
myroot = mytree.getroot()

puns_het = []
for text_node in myroot.findall('./text'):
    # {text id: {word id: word text}}; dicts preserve insertion order, so the
    # words stay in document order.
    words_by_text = {
        text_node.attrib['id']: {
            word_node.attrib['id']: word_node.text for word_node in text_node
        }
    }
    for words in words_by_text.values():
        # Non-breaking spaces inside a token are replaced by underscores.
        puns_het.append([word.replace(u'\xa0', '_') for word in words.values()])

print(puns_het[0])

["'", "'", 'I', "'", 'm', 'halfway', 'up', 'a', 'mountain', ',', "'", "'", 'Tom', 'alleged', '.']


In [None]:
# Subtask-1 gold labels for the heterographic set: the second tab-separated
# field of every row is parsed as the integer label.
gold_het = []
with open('semeval2017_task7/data/test/subtask1-heterographic-test.gold', 'r') as fin:
    gold_het.extend(int(row.strip().split('\t')[1]) for row in fin)
print(gold_het)

[1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 0, 1, 0, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 1, 0, 1, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 1, 1, 1, 

In [None]:
# Word-level pun position per heterographic sentence; -10000 marks sentences
# for which the subtask-2 gold file has no entry.
location_het = [-10000] * len(puns_het)
with open('semeval2017_task7/data/test/subtask2-heterographic-test.gold', 'r') as fin:
    for row in fin:
        # The second tab-separated field is split on '_'; part 1 is used as the
        # sentence number and part 2 as the word number. Both are 1-based in the
        # file, so shift them to 0-based.
        id_parts = row.strip().split('\t')[1].split('_')
        pun_index = int(id_parts[2]) - 1
        location_het[int(id_parts[1]) - 1] = pun_index
print(location_het)

[13, 12, -10000, 10, 4, -10000, 5, 3, 5, 7, 12, 11, 23, -10000, 2, 10, 6, -10000, 15, 15, -10000, 16, 12, 5, 6, 11, -10000, 9, -10000, -10000, 13, -10000, 23, 15, 13, 16, -10000, 12, 14, 5, -10000, 10, 9, 14, 7, 9, 7, -10000, 7, 3, 10, -10000, -10000, 28, 10, 6, -10000, 13, 12, 15, -10000, 12, -10000, 15, 17, 7, 6, -10000, 13, 26, 13, 3, 12, 12, 16, 11, 4, 6, -10000, 10, -10000, 10, 10, -10000, -10000, 9, 11, -10000, 10, 5, -10000, 7, -10000, 6, 18, 8, 3, -10000, 11, 6, 15, 5, -10000, 4, 3, 14, 7, 13, -10000, -10000, 7, 10, 12, -10000, -10000, 6, 30, -10000, 6, -10000, 14, 6, 14, 11, 6, 39, -10000, -10000, 11, -10000, 16, -10000, 8, 15, -10000, 24, 23, 17, -10000, 13, 5, 11, 14, 3, 13, -10000, -10000, 4, 6, 8, -10000, -10000, 12, 14, -10000, 1, 3, 6, 12, -10000, -10000, 4, -10000, -10000, 17, 25, -10000, 9, -10000, -10000, -10000, 19, -10000, 6, -10000, -10000, 6, -10000, 16, -10000, 22, -10000, 13, 8, 12, 19, 17, 16, 9, -10000, 14, -10000, 14, 5, 10, 12, 5, -10000, -10000, 16, 6, 6, 8

In [None]:
# The sentences, labels and pun locations must stay index-aligned.
assert len(puns_het) == len(gold_het) == len(location_het)

## Puns of the Day

In [None]:
# Load the "Puns of the Day" CSV from Google Drive.
# Column 0 is the label ("-1" maps to 0, anything else to 1) and column 1 is
# the raw sentence, which is tokenised with spaCy into a list of word strings.
texts_PTD = []
labels_PTD = []
nlp = spacy.load('en_core_web_sm')

# `newline=''` is what the csv module requires so that quoted fields containing
# line breaks are parsed correctly; the encoding is stated explicitly instead
# of relying on the platform default.
with open("/content/drive/My Drive/puns_pos_neg_data.csv", mode='r', newline='', encoding='utf-8') as file:
    csvFile = csv.reader(file)

    # Skip the header row up front instead of tokenising it and deleting the
    # resulting entry afterwards.
    next(csvFile, None)

    for line in csvFile:
        if not line:
            # csv.reader yields [] for blank lines; there is nothing to label.
            continue
        labels_PTD.append(0 if line[0] == "-1" else 1)
        texts_PTD.append([str(tok) for sent in nlp(line[1]).sents for tok in sent])

assert len(texts_PTD) == len(labels_PTD)

## Get Total Dataset

In [None]:
# Balance the SemEval data (more puns than non-puns) by appending the last
# `num_delta` Puns-of-the-Day sentences.
num_pos_task_7 = sum(gold_hom + gold_het)
print(num_pos_task_7)
num_neg_task_7 = len(gold_hom + gold_het) - num_pos_task_7
print(num_neg_task_7)
num_delta = num_pos_task_7 - num_neg_task_7
print(num_delta)

# A negative-index slice `[-num_delta:]` silently returns the WHOLE list when
# num_delta == 0 and slices from the wrong end when it is negative, so validate
# the count and slice from an explicit start index instead.
if not 0 <= num_delta <= len(texts_PTD):
    raise ValueError(
        f"Cannot balance the dataset: need {num_delta} extra negatives "
        f"but Puns of the Day provides {len(texts_PTD)} sentences"
    )
ptd_tail_start = len(texts_PTD) - num_delta

# NOTE(review): the original author noted that PTD entries from index 2423 on
# are negatives; the balance assertion below fails if the tail contains positives.
total_puns = puns_hom + puns_het + texts_PTD[ptd_tail_start:]
total_gold = gold_hom + gold_het + labels_PTD[ptd_tail_start:]
# -10000 means there is no pun word in the sentence
total_location = location_hom + location_het + [-10000 for _ in range(num_delta)]
assert len(total_puns) == len(total_gold)
assert len(total_gold) == len(total_location)
assert sum(total_gold) * 2 == len(total_puns)
print()

print(len(total_puns))
print(sum(total_gold))
print(sum([1 for i in total_location if i == -10000]))

2878
1152
1726

5756
2878
2878


In [None]:
# Sequence / batching sizes shared by the data pipeline and the evaluation cell
max_length = 80
batch_size = 32

# Pretrained checkpoint used for both the tokenizer and the encoder
bert_model_name = 'bert-base-uncased'

# Run on the GPU when one is visible, otherwise fall back to the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer matching the checkpoint above
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [None]:
def _subword_length(subword, is_first):
    """Length of a subword, ignoring a leading WordPiece '##' continuation marker.

    Only non-first subwords are stripped, and only when they really start with
    '##'. Punctuation splits such as 'well', '-', 'known' carry no marker, and
    blindly dropping two characters would shorten them.
    """
    if not is_first and subword.startswith('##'):
        return len(subword) - 2
    return len(subword)


def _locate_pun_subword(words, word_location, tokenizer, max_length):
    """Map a word-level pun index to a subword-level index.

    Args:
        words: list of word strings forming the sentence.
        word_location: 0-based index of the pun word inside `words`.
        tokenizer: object whose `tokenize(str)` returns a list of subword strings.
        max_length: padded sequence length used when encoding the sentences.

    Returns:
        The 0-based index (not counting the start ID that the encoder prepends)
        of the longest subword of the pun word, or None when `word_location`
        does not match any word.
    """
    current_index = 0
    for word_idx, word in enumerate(words):
        subwords = tokenizer.tokenize((' ' if word_idx else '') + word)
        lengths = [_subword_length(sw, i == 0) for i, sw in enumerate(subwords)]
        # Index of the first longest subword; 0 when the tokenizer returned
        # nothing for this word (argmax of an empty sequence would raise).
        longest_index = max(range(len(lengths)), key=lengths.__getitem__) if lengths else 0

        if word_idx == word_location:
            if 1 + current_index + longest_index >= max_length - 1:
                # The input_ids have a start ID and an end ID, hence the +1 / -1.
                # Reported once, for the pun word only, instead of for every word.
                print("The corresponding location index is out of range!")
            return current_index + longest_index

        current_index += len(subwords)
    return None


# For every sentence: the subword index of the pun word after tokenisation
# (0-based), or -10000 when the sentence contains no pun.
total_corresponding_location = []

for pun, gold, location in zip(total_puns, total_gold, total_location):
    if gold != 1:
        total_corresponding_location.append(-10000)
        continue

    subword_location = _locate_pun_subword(pun, location, tokenizer, max_length)
    if subword_location is None:
        # The annotated word index does not exist in the sentence; nothing is
        # appended, so the length assertion below will flag the problem.
        print("Something Wrong!")
    else:
        total_corresponding_location.append(subword_location)

assert len(total_corresponding_location) == len(total_puns)

In [None]:
# Fixed seed for the train/validation split. Without it every session produces
# a different split, so a model saved in an earlier session could later be
# evaluated on sentences it was trained on.
SPLIT_SEED = 42

# Each item holds: input_ids, attention_mask, label, location, corresponding_location
dataset = []

for pun, gold, location, corresponding_location in tqdm(
        zip(total_puns, total_gold, total_location, total_corresponding_location),
        total=len(total_puns)):  # zip has no len(); pass the total so tqdm shows progress
    sentence = ' '.join(pun)
    tokenized = tokenizer(sentence, return_tensors="pt", max_length=max_length, truncation=True, padding="max_length")
    dataset.append(
        {
            'input_ids': tokenized["input_ids"][0],
            'attention_mask': tokenized["attention_mask"][0],
            'label': torch.tensor(gold, dtype=torch.long),
            'location': torch.tensor(location, dtype=torch.long),
            'corresponding_location': torch.tensor(corresponding_location, dtype=torch.long)
        }
    )

# Shuffle with a private, seeded RNG so the global `random` state is untouched.
random.Random(SPLIT_SEED).shuffle(dataset)

# Split the dataset into training (80%) and validation (20%) sets
val_size = int(0.2 * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED))

# Create DataLoaders for each set with a batch size
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

5756it [00:03, 1797.94it/s]


## Setup the Model

In [None]:
# Define the BERT-BiLSTM model
class BertBiLSTM(nn.Module):
    """BERT encoder followed by an LSTM with two single-logit heads.

    * ``fc`` scores the whole sentence from the final LSTM hidden state.
    * ``classifier`` scores one token per sentence from the LSTM output at the
      position given by ``token_index``.

    Both heads return raw logits (no sigmoid), as expected by
    ``nn.BCEWithLogitsLoss``.

    Args:
        bert_model_name: checkpoint name passed to ``BertModel.from_pretrained``.
        num_classes: accepted for interface compatibility; not used, because
            both heads emit a single logit.
        hidden_dim: LSTM hidden size per direction.
        num_layers: number of stacked LSTM layers.
        bidirectional: whether the LSTM is bidirectional.
        dropout: dropout probability for the sentence head and, when
            ``num_layers > 1``, between LSTM layers.
    """

    def __init__(self, bert_model_name, num_classes, hidden_dim, num_layers, bidirectional, dropout):
        super(BertBiLSTM, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        # BERT is fine-tuned together with the heads. To freeze it instead:
        #     for param in self.bert.parameters():
        #         param.requires_grad = False
        self.lstm = nn.LSTM(
            input_size=self.bert.config.hidden_size,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            bidirectional=bidirectional,
            batch_first=True,
            # nn.LSTM warns when dropout is set on a single-layer network
            dropout=dropout if num_layers > 1 else 0
        )
        self.dropout = nn.Dropout(dropout)
        lstm_output_dim = hidden_dim * (2 if bidirectional else 1)
        self.classifier = nn.Linear(lstm_output_dim, 1)  # token-level head (logit)
        self.fc = nn.Linear(lstm_output_dim, 1)  # sentence-level head (logit)

    def forward(self, input_ids, attention_mask, token_index):
        """Run the encoder and both heads.

        Args:
            input_ids: token IDs, one row per sentence.
            attention_mask: mask passed straight through to BERT.
            token_index: one position per sentence (tensor or sequence of ints,
                on any device) selecting the token scored by ``classifier``.

        Returns:
            ``(sentence_logits, token_logits)``, each with one logit per sentence.
        """
        bert_output = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = bert_output['last_hidden_state']
        # NOTE(review): the LSTM also runs over padded positions (no packing),
        # so the final forward state is taken after the padding -- confirm this
        # is intended.
        lstm_output, (hidden, _) = self.lstm(sequence_output)

        if self.lstm.bidirectional:
            # Last layer: concatenate the forward and backward final states
            hidden = torch.cat((hidden[-2, :, :], hidden[-1, :, :]), dim=1)
        else:
            hidden = hidden[-1, :, :]

        assert lstm_output.size(0) == len(input_ids)
        assert lstm_output.size(2) == self.lstm.hidden_size * (2 if self.lstm.bidirectional else 1)

        # Pick the LSTM output at the requested position of every sentence.
        # Advanced indexing keeps the result on the same device as the model
        # (a freshly allocated torch.zeros buffer would live on the CPU and
        # break the classifier when the model is on the GPU) and replaces the
        # per-sentence Python loop.
        batch_positions = torch.arange(lstm_output.size(0), device=lstm_output.device)
        token_positions = torch.as_tensor(
            token_index, dtype=torch.long, device=lstm_output.device).reshape(-1)
        focused_word_hidden = lstm_output[batch_positions, token_positions]
        classification_output = self.classifier(focused_word_hidden)

        # No sigmoid here: nn.BCEWithLogitsLoss combines the sigmoid activation
        # and binary cross-entropy in a numerically stable way.
        return self.fc(self.dropout(hidden)), classification_output


In [None]:
# Architecture hyper-parameters
num_classes, hidden_dim, num_layers = 2, 128, 2
bidirectional, dropout = True, 0.3

# Build the network and move it to the selected device
model = BertBiLSTM(bert_model_name, num_classes, hidden_dim, num_layers, bidirectional, dropout)
model = model.to(device)

# Where the trained model is written after every epoch
model_path = "/content/drive/My Drive/my_MTL_PT_model.pt"  # Choose your desired path and filename

# Optimisation hyper-parameters
# (weight_decay is defined but not passed to the Adam optimizer created below)
num_epochs, learning_rate, weight_decay = 5, 2e-5, 1e-2

# One binary cross-entropy-with-logits loss per output head
criterion_1 = nn.BCEWithLogitsLoss()
criterion_2 = nn.BCEWithLogitsLoss()

# Plain Adam over all parameters, BERT included
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

Downloading pytorch_model.bin:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertModel: ['cls.predictions.bias', 'cls.seq_relationship.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [None]:
def _resolve_token_index(corresponding_location, attention_mask):
    """Turn stored pun locations into positions inside `input_ids`.

    Sentences with a pun get their stored subword index shifted by one, because
    `input_ids` begins with the start ID. Sentences without a pun (negative
    stored index) get a random position among their real tokens, excluding the
    start and end IDs.

    Both arguments are the CPU tensors of one batch; the result is a CPU
    LongTensor with one position per sentence.
    """
    real_token_counts = attention_mask.sum(dim=1).tolist()
    positions = []
    for stored, n_real in zip(corresponding_location.tolist(), real_token_counts):
        if stored < 0:
            # No pun word. randint's upper bound is exclusive; keep it at >= 2
            # so an (unexpected) sentence with no real tokens cannot raise.
            positions.append(int(torch.randint(1, max(2, int(n_real) - 1), (1,))))
        else:
            positions.append(stored + 1)
    return torch.tensor(positions, dtype=torch.long)


def _forward_batch(model, batch):
    """Run one batch through `model`; return (logits_1, logits_2, labels, joint_loss)."""
    token_index = _resolve_token_index(batch['corresponding_location'], batch['attention_mask']).to(device)
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['label'].unsqueeze(1).float().to(device)

    logits_1, logits_2 = model(input_ids, attention_mask, token_index)

    # Both heads are trained against the sentence-level pun label; the joint
    # loss is their unweighted sum.
    joint_loss = criterion_1(logits_1, labels) + criterion_2(logits_2, labels)
    return logits_1, logits_2, labels, joint_loss


def _count_correct(logits, labels):
    """Number of samples whose thresholded sigmoid(logit) equals the label."""
    preds = (torch.sigmoid(logits.view(-1)) > 0.5).unsqueeze(1).float()
    return (preds == labels).sum().item()


# Training loop
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}:')

    model.train()
    train_loss, train_correct_1, train_correct_2, train_samples = 0, 0, 0, 0
    for batch in tqdm(train_dataloader):
        optimizer.zero_grad()

        logits_1, logits_2, labels, joint_loss = _forward_batch(model, batch)

        # Backward pass first, then exactly one optimizer step. Stepping before
        # backward() would update the weights from stale optimizer state.
        joint_loss.backward()
        optimizer.step()

        train_loss += joint_loss.item()
        train_correct_1 += _count_correct(logits_1, labels)
        train_correct_2 += _count_correct(logits_2, labels)
        train_samples += labels.size(0)

    train_avg_loss = train_loss / len(train_dataloader)
    train_accuracy_1 = train_correct_1 / train_samples
    train_accuracy_2 = train_correct_2 / train_samples
    print(f'Training Loss: {train_avg_loss:.4f} - Training Accuracy: {train_accuracy_1:.4f} - Training Word Accuracy: {train_accuracy_2:.4f}')

    # Save the entire model (architecture included) after every epoch, so it can
    # be loaded later without rebuilding the class instance; the file is larger
    # than a state_dict.
    torch.save(model, model_path)

    # Evaluate the model on the validation set
    model.eval()
    val_loss, val_correct_1, val_correct_2, val_samples = 0, 0, 0, 0
    with torch.no_grad():
        for batch in tqdm(val_dataloader):
            logits_1, logits_2, labels, joint_loss = _forward_batch(model, batch)

            val_loss += joint_loss.item()
            val_correct_1 += _count_correct(logits_1, labels)
            val_correct_2 += _count_correct(logits_2, labels)
            val_samples += labels.size(0)

    val_avg_loss = val_loss / len(val_dataloader)
    val_accuracy_1 = val_correct_1 / val_samples
    val_accuracy_2 = val_correct_2 / val_samples
    print(f'Validation Loss: {val_avg_loss:.4f} - Validation Accuracy: {val_accuracy_1:.4f} - Validation Word Accuracy: {val_accuracy_2:.4f}')


Epoch 1/5:


100%|██████████| 144/144 [1:05:51<00:00, 27.44s/it]


Training Loss: 0.7991 - Training Accuracy: 0.8664 - Training Word Accuracy: 0.8916


100%|██████████| 36/36 [05:22<00:00,  8.96s/it]


Validation Loss: 0.4312 - Validation Accuracy: 0.9357 - Validation Word Accuracy: 0.9340
Epoch 2/5:


100%|██████████| 144/144 [1:05:12<00:00, 27.17s/it]


Training Loss: 0.3341 - Training Accuracy: 0.9542 - Training Word Accuracy: 0.9531


100%|██████████| 36/36 [05:18<00:00,  8.84s/it]


Validation Loss: 0.3452 - Validation Accuracy: 0.9409 - Validation Word Accuracy: 0.9453
Epoch 3/5:


100%|██████████| 144/144 [1:05:24<00:00, 27.26s/it]


Training Loss: 0.1991 - Training Accuracy: 0.9744 - Training Word Accuracy: 0.9776


100%|██████████| 36/36 [05:19<00:00,  8.88s/it]


Validation Loss: 0.3393 - Validation Accuracy: 0.9392 - Validation Word Accuracy: 0.9409
Epoch 4/5:


  5%|▍         | 7/144 [03:11<1:02:51, 27.53s/it]

In [None]:
# Load the saved model.
# * map_location puts the weights on the device the inputs are moved to, so a
#   model saved on a GPU runtime can still be loaded on a CPU-only one.
# * weights_only=False is required because the file holds the whole pickled
#   module, not just a state_dict.
# SECURITY: torch.load unpickles arbitrary objects -- only load files you created.
loaded_model = torch.load(model_path, map_location=device, weights_only=False)

# Evaluation mode: disables dropout for inference
loaded_model.eval()

with torch.no_grad():
    num_correct, num_samples, num_accuracy = 0, 0, 0
    for ele in tqdm(val_dataset):
        if ele['label'] < 0.5:
            # This sentence has no pun, do not count it
            continue

        input_ids = ele['input_ids'].reshape(1, max_length).to(device)
        attention_mask = ele['attention_mask'].reshape(1, max_length).to(device)
        # Gold subword index of the pun word (0-based, start ID not counted)
        token_index = int(ele['corresponding_location'])

        # Number of non-padding positions, including the start and end IDs
        num_real_tokens = int(attention_mask[0].sum())

        # Score every real token (positions 1 .. num_real_tokens - 2) with the
        # token-level head; scores_list[k] belongs to input position k + 1,
        # which lines up with the 0-based gold index.
        scores_list = np.zeros((max(num_real_tokens - 2, 0),), dtype=float)
        for used_token_index in range(1, num_real_tokens - 1):
            candidate = torch.tensor([used_token_index], device=device)
            logits_1, logits_2 = loaded_model(input_ids, attention_mask, candidate)
            scores_list[used_token_index - 1] = float(torch.sigmoid(logits_2.view(-1)))

        # TODO: If the matched token ID is belonged to the pun word, it is also correct.
        # (argmax of an empty array raises, hence the size check)
        if scores_list.size and np.argmax(scores_list, axis=0) == token_index:
            num_correct += 1

        num_samples += 1

    # Guard against a validation set without any positive sample
    num_accuracy = num_correct / num_samples if num_samples else 0.0
    print(f'Location Accuracy: {num_accuracy:.4f}')

 52%|█████▏    | 597/1151 [28:40<26:36,  2.88s/it]

Location Accuracy: 0.2800



