In [None]:
# !pip install pytorch-transformers
# !pip install torch  # PyTorch is published on PyPI as `torch`, not `pytorch`

In [None]:
import torch
import os
import string
import copy
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from pytorch_transformers import *
import numpy as np
import json
import collections
from sklearn.utils import shuffle
from sklearn.metrics import accuracy_score

In [None]:
# Shared pretrained RoBERTa tokenizer used by every helper below.
# Its encode() turns a string into a list of vocabulary token IDs (the result is
# later copied and extended as a plain list during padding).
tokenizer = RobertaTokenizer.from_pretrained('roberta-base') 

In [None]:
#load_data reads a JSON Lines (.jsonl) file: one JSON document per line
def load_data(filename):
  """Read a JSON Lines file and return its records in file order.

  Args:
    filename: path of the .jsonl file to read.

  Returns:
    A list holding the decoded JSON value of every non-blank line.

  Raises:
    OSError: if the file cannot be opened.
    json.JSONDecodeError: if a non-blank line is not valid JSON.
  """
  records = []
  # Decode explicitly as UTF-8 so the result does not depend on the platform locale
  with open(filename, mode = "r", encoding = "utf-8") as file:
    for line in file:
      # Blank lines (e.g. a trailing newline at the end of the file) carry no record
      if line.strip():
        records.append(json.loads(line))
  return records

#Takes a list of words (strings) and a sentence (as a RoBERTa tokenized ID list) and returns a list
#of pairs indicating the tokens' start and end positions in the sentence for each word
      
def find_word_in_tokenized_sentence(word,token_ids): #gets the same word in a sentence
  """Locate the first occurrence of `word` inside a tokenized sentence.

  The word is encoded with the module-level `tokenizer` and the resulting ID
  sequence is searched for as a contiguous sub-list of `token_ids`.

  Args:
    word: the word (string) to look for.
    token_ids: list of token IDs of the sentence to search.

  Returns:
    (start, end) inclusive token positions of the first match, or (-1, -1)
    when the word does not occur or encodes to no tokens at all.
  """
  # NOTE(review): a BPE tokenizer may encode a word differently depending on
  # whether a space precedes it, so mid-sentence words may fail to match here;
  # confirm against the tokenizer version in use.
  decomposedWord = tokenizer.encode(word)
  wordLen = len(decomposedWord)
  # An empty encoding (e.g. a word that consisted only of punctuation) can never
  # match; return the sentinel instead of indexing into an empty list.
  if wordLen == 0:
    return (-1,-1)
  # Only start positions that leave room for the whole word can match
  for i in range(len(token_ids) - wordLen + 1):
    if token_ids[i:i+wordLen] == decomposedWord:
      return (i,i+wordLen-1) #matching word found
  # if no matching word --> return this
  return (-1,-1)
  
def find_words_in_tokenized_sentences(wordList,token_ids): #gets list of word positions in a sentence
  """Locate several words, in order, inside one tokenized sentence.

  Each word is searched for only after the end of the previously found word,
  so repeated words resolve to successive occurrences.

  Args:
    wordList: words (strings) to locate, in the order they appear.
    token_ids: list of token IDs of the sentence to search.

  Returns:
    A list with one (start, end) inclusive token span per word, expressed in
    positions of the full `token_ids`; a word that is not found yields
    (-1, -1).
  """
  intList = []
  searchFrom = 0 # first token position still eligible for a match
  for word in wordList:
    start, end = find_word_in_tokenized_sentence(word,token_ids[searchFrom:])
    if start < 0:
      # Keep the not-found sentinel intact: shifting it by the search offset
      # would turn it into a bogus span pointing at the previous word.
      intList.append((-1,-1))
      continue
    # Translate slice-relative positions back to positions in token_ids
    actualPositions = (start + searchFrom,end + searchFrom)
    intList.append(actualPositions)
    searchFrom = actualPositions[1] + 1
  return intList #returns list of positions

# Fraction of examples whose highest-scoring class equals the label
def flat_accuracy(preds, labels, return_predict_correctness = False):
  """Compute classification accuracy from per-class scores.

  Args:
    preds: 2-D array of scores, one row per example and one column per class.
    labels: array of integer class labels; it is flattened before comparison.
    return_predict_correctness: when True, also return the per-example
      boolean array telling which predictions were right.

  Returns:
    The accuracy, or the pair (accuracy, correctness array) when
    `return_predict_correctness` is set.
  """
  predicted = np.argmax(preds, axis=1).flatten()
  expected = labels.flatten()
  hits = predicted == expected
  accuracy = np.sum(hits) / len(expected)
  if not return_predict_correctness:
    return accuracy
  return accuracy, hits

In [None]:
# Training hyper-parameters shared by the data loaders and the training loop.
BATCH_SIZE = 32 # examples per DataLoader batch; lowered until the machine stopped running out of resources
EPOCHS = 10 # upper bound on training epochs; more could raise accuracy but takes too long
# Early-stopping patience: the training loop stops once more than PATIENCE epochs
# pass without a new best validation accuracy.
# NOTE(review): with PATIENCE equal to EPOCHS that condition cannot be met before
# the epoch cap ends training, so the patience stop is effectively disabled.
PATIENCE = 10

In [None]:
# data preprocessing function (plus the private helpers it is built from)
def _locate_word_index(words, char_span):
  """Return the index of the space-separated word covering `char_span`, or -1.

  `words` is a sentence split on single spaces and `char_span` is the
  (start, end) character offsets of the target word in that sentence. The scan
  stops at the first word that starts at/after the span start or ends at/after
  the span end.
  """
  span_start, span_end = char_span
  chars_seen = 0 # characters consumed before the current word
  for idx, token in enumerate(words):
    if chars_seen >= span_start or chars_seen + len(token) >= span_end:
      return idx
    chars_seen += len(token) + 1 # Plus one for the space
  return -1


def _span_weights(span, length):
  """Build a length-`length` averaging vector: 1/width on the span, 0 elsewhere.

  A (-1, -1) "not found" span produces an all-zero vector.
  """
  first, last = span
  width = last - first + 1
  return [1.0 / width if first <= pos <= last else 0.0 for pos in range(length)]


def _segment_ids(padded_ids):
  """Mark which sentence each token belongs to (0 = first, 1 = second).

  The switch to the second sentence happens at the first ID 0 seen after some
  non-zero ID; padding (also ID 0) therefore ends up in segment 1.
  """
  segment_ids = []
  first_sentence = True
  sentence_start = True
  for token in padded_ids:
    if first_sentence and sentence_start and token == 0:
      # Allows 0 at the start of the first sentence
      segment_ids.append(0)
    elif first_sentence and token > 0:
      sentence_start = False
      segment_ids.append(0)
    elif first_sentence and not sentence_start and token == 0:
      # Start of second sentence
      first_sentence = False
      segment_ids.append(1)
    else:
      # Second sentence
      segment_ids.append(1)
  return segment_ids


def wic_preprocessing(json_objects, training = True, shuffle_data = False):
  """Turn WiC records into padded, model-ready columns.

  Each record must provide 'sentence1', 'sentence2', 'start1', 'end1',
  'start2', 'end2' and, when `training` is True, 'label'. Both sentences are
  joined as "<s>s1</s><s>s2</s>", encoded with the module-level `tokenizer`,
  and the target word of each sentence is located in the token sequence.

  Args:
    json_objects: sequence of record dicts (may be empty).
    training: when True, labels are read and returned under "labels".
    shuffle_data: when True (and `training` is True), all columns are shuffled
      together with sklearn's `shuffle`.

  Returns:
    A dict of parallel lists:
      "input_ids":       token IDs padded with 0 to the longest example,
      "attention_mask":  1 for real tokens, 0 for padding,
      "token_type_ids":  0 for the first sentence, 1 for the rest,
      "word1_locs"/"word2_locs": one [weights] row per example that averages
                         the tokens of the target word (all zeros if the word
                         was not found among the tokens),
      "index":           position of the example in `json_objects`,
      "labels":          1/0 labels (only when `training` is True).
  """
  strip_punctuation = str.maketrans('', '', string.punctuation)
  # NOTE(review): padding reuses ID 0, which _segment_ids also treats as the
  # sentence-start marker; presumably the tokenizer's real pad ID differs, but
  # padded positions are masked out by attention_mask - confirm before changing.
  pad_id = 0

  wic_encoded = []
  wic_labels = []
  wic_word_locs = []
  wic_indexes = []
  for index, example in enumerate(json_objects):
    wic_indexes.append(index)
    #combine the sentences with marks at beginnings and ends, then encode
    sentence = f"<s>{example['sentence1']}</s><s>{example['sentence2']}</s>"
    encoded = tokenizer.encode(sentence, add_special_tokens=False)
    wic_encoded.append(encoded)
    # Map the character offsets of the target word to word indices
    sent1_split = example['sentence1'].split(' ')
    sent2_split = example['sentence2'].split(' ')
    word1_index = _locate_word_index(sent1_split, (example['start1'], example['end1']))
    word2_index = _locate_word_index(sent2_split, (example['start2'], example['end2']))
    # An index of -1 (offsets past the sentence) falls back to the last word
    # for either sentence, instead of indexing one past the end of the list.
    word1 = sent1_split[word1_index].translate(strip_punctuation)
    word2 = sent2_split[word2_index].translate(strip_punctuation)
    wic_word_locs.append(find_words_in_tokenized_sentences([word1, word2], encoded))
    # get label if there is one avail
    if training:
      wic_labels.append(1 if example['label'] else 0)

  # Pad the sequences and find the encoded word location in the combined input
  max_len = max((len(ex) for ex in wic_encoded), default=0) # 0 for empty input
  raw_set = {"input_ids": [], "token_type_ids": [], "attention_mask": [],
             "word1_locs": [], "word2_locs": [], "index": wic_indexes}
  for enc_sentence, token_word_locs in zip(wic_encoded, wic_word_locs):
    ex_len = len(enc_sentence)
    padding = max_len - ex_len
    padded_sentence = list(enc_sentence) + [pad_id] * padding
    raw_set["input_ids"].append(padded_sentence)
    raw_set["attention_mask"].append([1] * ex_len + [0] * padding)
    # Vectors used to pick the target words out of the model output;
    # each one is wrapped in a list so it becomes a 1 x max_len row matrix
    raw_set["word1_locs"].append([_span_weights(token_word_locs[0], max_len)])
    raw_set["word2_locs"].append([_span_weights(token_word_locs[1], max_len)])
    # token_type_ids is a mask that tells where the first and second sentences are
    raw_set["token_type_ids"].append(_segment_ids(padded_sentence))

  if training:
    raw_set["labels"] = wic_labels
    # Only labelled data is ever shuffled; the testing set keeps its order
    if shuffle_data and wic_indexes:
      columns = list(raw_set)
      # Shuffling all columns in one call keeps the rows aligned
      shuffled = shuffle(*(raw_set[column] for column in columns))
      raw_set = dict(zip(columns, shuffled))
  # Return the raw data (Need to put them in a PyTorch tensor and dataset)
  return raw_set

In [None]:
# Process the data
train_json_objs = load_data("train.jsonl")
raw_train_set = wic_preprocessing(train_json_objs, shuffle_data=True)
# Sanity check: print one preprocessed example next to the record it came from.
# The position is clamped so the cell also runs on data sets with fewer than 16 rows.
SAMPLE_POSITION = min(15, len(raw_train_set["index"]) - 1)
print(train_json_objs[raw_train_set["index"][SAMPLE_POSITION]])
for column in ("input_ids", "token_type_ids", "attention_mask", "labels", "word1_locs", "word2_locs"):
  print(raw_train_set[column][SAMPLE_POSITION])

In [None]:
# Number of training examples divided by the batch size, i.e. roughly how many
# batches one epoch takes (the value is not rounded)
print(len(raw_train_set["labels"])/BATCH_SIZE)

In [None]:
# Wrap the preprocessed training columns in a PyTorch dataset.
# The column order fixes the order in which a batch unpacks later on.
train_data = TensorDataset(*[
    torch.tensor(raw_train_set[column])
    for column in ("input_ids", "token_type_ids", "attention_mask", "labels",
                   "word1_locs", "word2_locs", "index")
])
# Training batches are drawn in random order
train_sampler = RandomSampler(train_data)
train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=BATCH_SIZE)

In [None]:
# Read and preprocess the validation split (labels kept, order untouched)
valid_json_objs = load_data("val.jsonl")
raw_valid_set = wic_preprocessing(valid_json_objs)
# Wrap the columns in a PyTorch dataset, in the same column order as the training set
validation_data = TensorDataset(*[
    torch.tensor(raw_valid_set[column])
    for column in ("input_ids", "token_type_ids", "attention_mask", "labels",
                   "word1_locs", "word2_locs", "index")
])
# Validation batches are visited in their original order
validation_sampler = SequentialSampler(validation_data)
validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=BATCH_SIZE)

In [None]:
# Load the base RoBERTa model
model = RobertaForMaskedLM.from_pretrained('roberta-base')
# state_dict() hands back references to the live parameter tensors, so take a
# deep copy; otherwise these "initial" weights would silently follow training.
roberta_init_weights = copy.deepcopy(model.state_dict())

In [None]:
class WiC_Head(torch.nn.Module):
    def __init__(self, roberta_based_model, embedding_size = 768):
        """
        Keeps a reference to the provided RoBERTa model (it must expose the
        encoder as its `.roberta` attribute).
        It then adds a two-layer classifier that takes the difference between the
        embeddings of the two target words and scores the two classes.

        Args:
            roberta_based_model: model whose `.roberta(input_ids=..., attention_mask=...)`
                returns the per-token embeddings as its first output.
            embedding_size: width of the per-token embeddings.
        """
        super(WiC_Head, self).__init__()
        self.embedding_size = embedding_size
        self.embedder = roberta_based_model
        self.linear_diff = torch.nn.Linear(embedding_size, 250, bias = True)
        self.linear_seperator = torch.nn.Linear(250, 2, bias = True)
        self.loss = torch.nn.CrossEntropyLoss()
        self.activation = torch.nn.ReLU()
        # Normalise over the class dimension explicitly rather than relying on
        # the implicit choice made when `dim` is omitted
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, input_ids=None, attention_mask=None, labels=None, word1_locs = None, word2_locs = None):
        """
        Takes in the same argument as RoBERTa forward plus two tensors for the location of the 2 words to compare

        Args:
            input_ids, attention_mask: forwarded unchanged to the embedder.
            labels: optional class indices, one per example.
            word1_locs, word2_locs: (batch, 1, seq_len) weight rows; multiplying
                them with the embeddings averages the tokens of each target word.

        Returns:
            The (batch, 2) class probabilities, or the tuple (loss, probabilities)
            when `labels` is given.
        """
        batch_size = word1_locs.shape[0]
        # Get the embeddings (first output of the encoder)
        embs = self.embedder.roberta(input_ids=input_ids, attention_mask=attention_mask)[0]
        # Get the words: (batch, 1, seq) x (batch, seq, emb) -> (batch, emb)
        word1s = torch.matmul(word1_locs, embs).view(batch_size, self.embedding_size)
        word2s = torch.matmul(word2_locs, embs).view(batch_size, self.embedding_size)
        diff = word1s - word2s
        # Calculate outputs using activation
        layer1_results = self.activation(self.linear_diff(diff))
        raw_logits = self.linear_seperator(layer1_results)
        # Callers keep receiving normalised scores, exactly as before
        probabilities = self.softmax(raw_logits)
        if labels is None:
            return probabilities
        # CrossEntropyLoss applies log-softmax itself, so it must be fed the raw
        # scores; feeding it probabilities would normalise twice and flatten the
        # gradients.
        loss = self.loss(raw_logits.view(-1, 2), labels.view(-1))
        return (loss, probabilities)

In [None]:
# Wrap the pretrained RoBERTa model in the WiC classification head; 768 is the
# embedding width handed to the head's first linear layer
class_model = WiC_Head(model, embedding_size = 768)

In [None]:
# Fine-tune the WiC classifier, validating after every epoch and keeping the
# weights of the best validation epoch.
# ideal min accuracy: training stops as soon as a new best epoch reaches it
MIN_ACCURACY = 0.73
REACHED_MIN_ACCURACY = False
# The classifier built above is `class_model`; `Model` is kept as an alias so any
# cell that refers to it by that name still works on a fresh kernel.
Model = class_model
# state_dict() returns references to the live tensors, so snapshot with deepcopy
best_weights = copy.deepcopy(class_model.state_dict())
# (best validation accuracy, epoch it was reached at); maximize from 0
max_val_acc = (0, 0)
# Create the optimizer: biases and LayerNorm weights are excluded from weight decay
param_optimizer = list(class_model.named_parameters())
no_decay = ['bias', 'LayerNorm.weight']
# AdamW reads the per-group key 'weight_decay'; any other key name is ignored
# and the group silently falls back to the optimizer default.
optimizer_grouped_parameters = [
    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
     'weight_decay': 0.01},
    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]
# I use the one that comes with the models, but any other optimizer could be used
optimizer = AdamW(optimizer_grouped_parameters, lr=1e-5)
# Store our loss and accuracy (one entry per batch) for plotting
fit_history = {"loss": [],  "accuracy": [], "val_loss": [], "val_accuracy": []}
epoch_number = 0
epoch_since_max = 0
continue_learning = True
while epoch_number < EPOCHS and continue_learning:
  epoch_number += 1
  print(f"Training epoch #{epoch_number}")

  # Tracking variables; the accuracy sums are weighted by batch size so that a
  # smaller final batch does not skew the epoch average
  tr_loss, tr_accuracy = 0, 0
  nb_tr_examples, nb_tr_steps = 0, 0
  eval_loss, eval_accuracy = 0, 0
  nb_eval_steps, nb_eval_examples = 0, 0
  # Training
  class_model.train()
  # The whole network (encoder included) is fine-tuned. To freeze the encoder,
  # call class_model.embedder.requires_grad_(False) once before building the
  # optimizer; assigning False to that attribute only overwrites the method.

  # Train the data for each epoch
  for step, batch in enumerate(train_dataloader):
    b_input_ids, b_token_ids, b_input_mask, b_labels, b_word1, b_word2, b_index = batch
    batch_examples = b_input_ids.size(0)
    #reset gradient
    optimizer.zero_grad()
    # get input and compute loss
    loss, logits = class_model(b_input_ids, attention_mask=b_input_mask, labels=b_labels, word1_locs = b_word1, word2_locs = b_word2)
    # get gradient
    loss.backward()
    # Update model
    optimizer.step()

    logits = logits.detach().numpy()
    label_ids = b_labels.numpy()
    # Calculate the accuracy
    b_accuracy = flat_accuracy(logits, label_ids)
    # Append to fit history
    fit_history["loss"].append(loss.item())
    fit_history["accuracy"].append(b_accuracy)
    # Update tracking variables
    tr_loss += loss.item()
    tr_accuracy += b_accuracy * batch_examples
    nb_tr_examples += batch_examples
    nb_tr_steps += 1
    if nb_tr_steps%10 == 0:
      print("\t\tTraining Batch {}: Loss: {}; Accuracy: {}".format(nb_tr_steps, loss.item(), b_accuracy))
  print("Training:\n\tLoss: {}; Accuracy: {}".format(tr_loss/nb_tr_steps, tr_accuracy/nb_tr_examples))

  # Validation
  class_model.eval()
  # Evaluate data for one epoch
  for batch in validation_dataloader:
    b_input_ids, b_token_ids, b_input_mask, b_labels, b_word1, b_word2, b_index = batch
    batch_examples = b_input_ids.size(0)
    # don't store gradients
    with torch.no_grad():
      # get input and compute loss
      loss, logits = class_model(b_input_ids, attention_mask=b_input_mask, labels=b_labels, word1_locs = b_word1, word2_locs = b_word2)

    logits = logits.detach().numpy()
    label_ids = b_labels.numpy()
    # Calculate the accuracy
    b_accuracy = flat_accuracy(logits, label_ids)
    # Append to fit history
    fit_history["val_loss"].append(loss.item())
    fit_history["val_accuracy"].append(b_accuracy)
    # Update tracking variables
    eval_loss += loss.item()
    eval_accuracy += b_accuracy * batch_examples
    nb_eval_examples += batch_examples
    nb_eval_steps += 1
    if nb_eval_steps%10 == 0:
      print("\t\tValidation Batch {}: Loss: {}; Accuracy: {}".format(nb_eval_steps, loss.item(), b_accuracy))
  eval_acc = eval_accuracy/nb_eval_examples
  if eval_acc >= max_val_acc[0]:
    max_val_acc = (eval_acc, epoch_number)
    continue_learning = True
    epoch_since_max = 0 # New max
    best_weights = copy.deepcopy(class_model.state_dict()) # Keep the best weights
    # See if we have reached min_accuracy
    if eval_acc >= MIN_ACCURACY:
      REACHED_MIN_ACCURACY = True
    if REACHED_MIN_ACCURACY:
      continue_learning = False # Stop learning. Reached baseline acc for this model
  else:
    epoch_since_max += 1
    if epoch_since_max > PATIENCE:
      continue_learning = False # Stop learning, starting to overfit
  print("Validation:\n\tLoss={}; Accuracy: {}".format(eval_loss/nb_eval_steps, eval_acc))
print(f"Best accuracy ({max_val_acc[0]}) obtained at epoch #{max_val_acc[1]}.")
# Reload the best weights (from memory)
class_model.load_state_dict(best_weights)