# Baseline Implementation of Sequential Coreference Resolution using BERT

[Literature reference](https://arxiv.org/pdf/1906.03695.pdf)

## Setup necessary depedencies
This involves downloading pre-trained BERT models and importing required python dependecies

In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [4]:
# For BERT dependecies
!pip install transformers

# For Span Extraction part
!pip install allennlp



In [5]:
import numpy as np # For Linear Algebra ops
import pandas as pd # Loading GAP dataset

from pathlib import Path

from __future__ import absolute_import, division, print_function # Stuff 

import collections # Standard python lib but not used 
import logging 
import math 

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset # For data handling in PyTorch NN

# For training
from sklearn.model_selection import StratifiedKFold 
from sklearn.metrics import log_loss

from tqdm import trange

from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor # To utilize the full span of the (and their context vectors) of each name candidate. 
# Mostly target pronouns are always tokenized into single token, so we only need to extract one context vector per pronoun.

from transformers import BertTokenizer, BertConfig, BertModel, AdamW  # Adam optimizer 

logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)


## Handling GAP data
The GAP dataset is hosted on github and so we need to download it and load it into memory in an appropriate container. Since, this was provided as a part of a shared task, we have access to only the development, validation and test data. So, we need to combine development and validation data sets to create a train data set

In [7]:
# GAP data set urls
test_data_url = "https://raw.githubusercontent.com/google-research-datasets/gap-coreference/master/gap-test.tsv"
dev_data_url = "https://raw.githubusercontent.com/google-research-datasets/gap-coreference/master/gap-development.tsv"
val_data_url = "https://raw.githubusercontent.com/google-research-datasets/gap-coreference/master/gap-validation.tsv"

# Load tsv data into memory using Panda's dataframe
development_df = pd.read_csv(test_data_url, delimiter="\t")
test_df = pd.read_csv(dev_data_url, delimiter="\t")
validation_df = pd.read_csv(val_data_url, delimiter="\t")

# Print one of the data frames to depict data structure
development_df

Unnamed: 0,ID,Text,Pronoun,Pronoun-offset,A,A-offset,A-coref,B,B-offset,B-coref,URL
0,test-1,Upon their acceptance into the Kontinental Hoc...,His,383,Bob Suter,352,False,Dehner,366,True,http://en.wikipedia.org/wiki/Jeremy_Dehner
1,test-2,"Between the years 1979-1981, River won four lo...",him,430,Alonso,353,True,Alfredo Di St*fano,390,False,http://en.wikipedia.org/wiki/Norberto_Alonso
2,test-3,Though his emigration from the country has aff...,He,312,Ali Aladhadh,256,True,Saddam,295,False,http://en.wikipedia.org/wiki/Aladhadh
3,test-4,"At the trial, Pisciotta said: ``Those who have...",his,526,Alliata,377,False,Pisciotta,536,True,http://en.wikipedia.org/wiki/Gaspare_Pisciotta
4,test-5,It is about a pair of United States Navy shore...,his,406,Eddie,421,True,Rock Reilly,559,False,http://en.wikipedia.org/wiki/Chasers
...,...,...,...,...,...,...,...,...,...,...,...
1995,test-1996,"The sole exception was Wimbledon, where she pl...",She,479,Goolagong Cawley,400,True,Peggy Michel,432,False,http://en.wikipedia.org/wiki/Evonne_Goolagong_...
1996,test-1997,"According to news reports, both Moore and Fily...",her,338,Esther Sheryl Wood,263,True,Barbara Morgan,404,False,http://en.wikipedia.org/wiki/Hastings_Arthur_Wise
1997,test-1998,"In June 2009, due to the popularity of the Sab...",She,328,Kayla,364,True,Natasha Henstridge,412,False,http://en.wikipedia.org/wiki/Raya_Meddine
1998,test-1999,She was delivered to the Norwegian passenger s...,she,305,Irma,255,True,Bergen,274,False,http://en.wikipedia.org/wiki/SS_Irma_(1905)


Apart from this, we also define a few PyTorch primitives (classes and functions) to help handle GAP data during modelling phase. 

One point to note is that since the Dataset is very structured, we don't need to worry about generating all possible spans for coreference resolution. Instead, we focus on constructing spans for the A and B resolution targets

In [10]:
# Create an integer mapping for Gendered Pronouns in English
GENDER_MAPPING = {'he': 0, 'him': 0, 'his': 0,
                  'her': 1, 'she': 1, 'hers': 1,
                  'it': 2, 'that': 2} # 'it' and 'that' were added for sake of completeness


def add_gender_target_details(df):
  """
  Modify given data frame to add pronoun gender
  and target data.
  
  The pronoun gender column maps the target pronoun to an
  integer code defined in `GENDER_MAPPING`.

  The target column provides a simplied representation of the
  A-coref and B-coref columns found in the GAP dataset.
  """
  entries = len(df['ID'])
  # Introduce new columns in the data frame
  df['Gender'] = pd.Series(np.ones(entries, dtype=int), index=df.index)
  df['Target'] = pd.Series(np.ones(entries, dtype=int), index=df.index)

  for index, row in df.iterrows():
    df.loc[index, 'Gender'] = GENDER_MAPPING[row['Pronoun'].lower()]
    if row['A-coref']:
      df.loc[index, 'Target'] = 0
    elif row['B-coref']:
      df.loc[index, 'Target'] = 1
    else:
      df.loc[index, 'Target'] = 2
  
  return None # redundant statment. Needed for compiling code :/


def tokenize_data(row, tokenizer):
  break_points = sorted([("A", row["A-offset"], row["A"]),
                         ("B", row["B-offset"], row["B"]),
                         ("P", row["Pronoun-offset"], row["Pronoun"]),
                         ], key=lambda x: x[1])
  tokens, spans, current_pos = [], {}, 0
  for name, offset, text in break_points:
    tokens.extend(tokenizer.tokenize(row["Text"][current_pos:offset]))
        
    # Tokenize the target
    tmp_tokens = tokenizer.tokenize(row["Text"][offset:offset+len(text)])

    spans[name] = [len(tokens) - 2 , len(tokens) + len(tmp_tokens) + 1] # inclusive
    tokens.extend(tmp_tokens)
    current_pos = offset + len(text)
  tokens.extend(tokenizer.tokenize(row["Text"][current_pos:offset]))

  # Handle edge cases for span indices
  for key in spans.keys():
    start_idx = spans[key][0]
    end_idx = spans[key][-1]
    if start_idx > end_idx:
      start_idx = -1
      end_idx = -1
    if start_idx < 0:
      start_idx = 0
    if end_idx >= len(tokens):
      end_idx = len(tokens) -1
    
    spans[key][0] = start_idx
    spans[key][1] = end_idx

  return tokens, (spans["A"] + spans["B"] + [spans["P"][0]])




# PyTorch Dataset specialization
class ModelDataset(Dataset):
  """Custom Dataset for GAP datasets"""
  
  def __init__(self, df, tokenizer, is_data_labelled=True):
    # Modify current df to add pronoun gender and target information
    self.is_labelled = is_data_labelled
    add_gender_target_details(df)
    
    # Extract expected output from dataset if labelled
    if is_data_labelled:
      self._outputs = df['Target'].values.astype("uint8")

    # For each data entry in the data frame, tokenize data and extract out inputs to
    # coreference model   
    self._offsets, self._tokens, self._ids = [], [], []
    for _, row in df.iterrows():
      tokens, offsets = tokenize_data(row, tokenizer)
      self._offsets.append(offsets)
      self._tokens.append(
          tokenizer.convert_tokens_to_ids(["[CLS]"] + tokens + ["[SEP]"]))
      self._ids.append(row['ID'])
      
  def __len__(self):
    return len(self._tokens)

  def __getitem__(self, idx):
    if self.is_labelled:
      return self._ids[idx], self._tokens[idx], self._offsets[idx], self._outputs[idx]
    return self._ids[idx], self._tokens[idx], self._offsets[idx], None

## Building the Neural Layer
Here, we deal with the details of constructing a Neural Coreference resolution system which uses BERT embeddings.

In [11]:
class AntecedentScorer(nn.Module):
  """
  NN layer for handling antecedent ranking
  """

  def __init__(self, input_size):
    super().__init__()
    self._inp_size = input_size
    self._span_extractor = SelfAttentiveSpanExtractor(input_size)
    self.fc = nn.Sequential(
        nn.Dropout(0.2),
        nn.Linear(input_size * 3, 512),
        nn.ReLU(),
        nn.Linear(512, 3)
    )

  def forward(self, inputs, offsets):
    span_embds = self._span_extractor(
            inputs, 
            offsets[:, :4].reshape(-1, 2, 2))
    target_pn = torch.gather(
        inputs, 1, offsets[:, [4]].unsqueeze(2).expand(-1, -1, self._inp_size))
    input = torch.cat([span_embds, target_pn], dim=1)
    input = input.reshape(offsets.size()[0], -1)
    return self.fc(input)


class GAPResolver(nn.Module):
  """
  Neural Net for performing pronoun resolution on GAP dataset
  """

  def __init__(self, bert_model, device):
    super().__init__()
    self._device = torch_device
    if bert_model in ("bert-base-uncased", "bert-base-cased"):
      self._bert_hidden_size = 768
    elif bert_model in ("bert-large-uncased", "bert-large-cased"):
      self._bert_hidden_size = 1024
    # Bare Bert Model transformer outputting raw hidden-states without any specific head on top
    self._bert = BertModel.from_pretrained(bert_model).to(device)

    self._ac_scorer = AntecedentScorer(self._bert_hidden_size).to(device)

  def forward(self, tok_tnsr, offsets):
    tok_tnsr = tok_tnsr.to(self._device)
    bert_outputs = self._bert(tok_tnsr, attention_mask=(tok_tnsr > 0).long(), 
            token_type_ids=None)
    inputs = bert_outputs.last_hidden_state
    fin_outs = self._ac_scorer(inputs, offsets.to(self._device))

    return fin_outs
    

## Putting it all together
With all the core components defined, we now need to add some code to train the neural net for coreference resolution

In [12]:
def collate_examples(batch, truncate_len=490):
  """ For handling data in batches
  1. Pad the sequences
  2. Transform the target.
  """
  
  transposed = list(zip(*batch))
  id = transposed[0]
  max_len = min( max((len(x) for x in transposed[1])), truncate_len)
  
  tokens = np.zeros((len(batch), max_len), dtype=np.int64)
  for i, row in enumerate(transposed[1]):
    row = np.array(row[:truncate_len])
    tokens[i, :len(row)] = row
  
  token_tensor = torch.from_numpy(tokens)
  # Offsets
  offsets = torch.stack([torch.LongTensor(x) for x in transposed[2]], dim=0) + 1 # Account for the [CLS] token
  # Labels
  if len(transposed) == 3 or transposed[3][0] is None:
    return (id, token_tensor, offsets, None)
  labels = torch.LongTensor(transposed[3])
  return (id, token_tensor, offsets, labels)


# Define function for training the NN for coreference resolution
def train(dev_df, val_df, tokenizer, model, device, epochs, learn_rate, train_batch_size, val_batch_size):
  # Convert Pandas dataframe into PyTorch Dataset
  # Since we don't have a decdicated train data set, we will combine the development
  # and validation data sets and do stratified 5 folding on them
  train_df = pd.concat([dev_df, val_df])
  kfold = StratifiedKFold(n_splits=5) 

  # Creating a copy of the DF for KFold loop
  kf_df = test_df.copy()
  add_gender_target_details(kf_df) # We split our training data based on Gender of pronoun

  # Some logic to set different layers as trainable
  def get_children(model):
    return model if isinstance(model, (list, tuple)) else list(model.children())
  
  def set_trainable_attr(model, flag):
    model.trainable = flag
    for param in model.parameters():
      param.requires_grad = flag
      
  def apply_leaf(model, func):
    c = get_children(model)
    if isinstance(model, nn.Module):
      func(model)
    if len(c) > 0:
      for l in c:
        apply_leaf(l, func)
  
  def set_trainable(l, flag):
    apply_leaf(l, lambda model: set_trainable_attr(model, flag))
    
  set_trainable(model._bert, True)
  set_trainable(model._ac_scorer, True)
  # For fine-tuning?
  #for i in range(12,24):
  #  set_trainable(model._bert.encoder.layer[i], True)
    
  optimizer = torch.optim.Adam(model.parameters(), lr=learn_rate)
  loss_criterion = nn.CrossEntropyLoss()

  fold_count = 1
  val_preds, train_loss, val_loss, val_outs = [], [], [], []
  for train_index, valid_index in kfold.split(kf_df, kf_df["Gender"]):
    print("~~~~~" * 20)
    print(f"Fold {fold_count}")
    print("~~~~" * 20)

    fold_count += 1 # A premature update of fold count

    # Creating test and validation data sets
    train_ds = ModelDataset(train_df.iloc[train_index], tokenizer)
    train_loader = DataLoader(train_ds,
                              collate_fn = collate_examples,
                              batch_size=train_batch_size,
                              num_workers=2,
                              pin_memory=True,
                              drop_last=False)
    val_ds = ModelDataset(train_df.iloc[valid_index], tokenizer)
    val_loader = DataLoader(val_ds,
                            collate_fn = collate_examples,
                            batch_size=val_batch_size,
                            num_workers=2,
                            pin_memory=True)
    
    for _ in trange(epochs):
      model.train()

      # Train in batches
      tr_loss, val_acc = 0, 0
      train_steps, val_steps = 0, 0
      for step, batch in enumerate(train_loader):
        # Link batch to device
        batch = batch[1:]
        batch = tuple(t.to(device) for t in batch)

        # Reset Optimizer Gradient
        optimizer.zero_grad()

        # Get batch data
        tokens, offsets, label = batch

        # Forward pass
        predictions = model(tokens, offsets)
        batch_loss = loss_criterion(predictions, label) 
        train_loss.append(batch_loss.item())

        # Backward propagation
        batch_loss.backward()

        # Take a step along computed gradient
        optimizer.step()

        tr_loss += train_loss[-1]
        train_steps += 1
      
      print("Average Training Loss: {} for {} batches".format(tr_loss / train_steps, train_steps))
      
      # Switch over to batch validation
      model.eval()
      for batch in val_loader:
        # Link batch to device
        batch = batch[1:]
        batch = tuple(t.to(device) for t in batch)

        # Reset Optimizer Gradient
        optimizer.zero_grad()

        # Get batch data
        tokens, offset, label = batch

        # We don't want to comupute gradient descent
        with torch.no_grad():
          predictions = model(tokens, offset)
          val_outs.append(label.cpu().numpy())
          val_preds.append(torch.nn.functional.softmax(predictions, -1).clamp(1e-4, 1-1e-4).cpu().numpy())
          loss = loss_criterion(predictions, label)
          val_loss.append(loss.item())

          val_acc += val_loss[-1]
          val_steps +=1 
      
      print("Average Validation Accuracy: {} for {} batches".format(val_acc / val_steps, val_steps))

    # Add code for train and validation
  return val_preds, val_loss, train_loss, model

def evaluate_test_data(test_df, model, tokenizer, test_batch_size, device):
  test_ds = ModelDataset(test_df, tokenizer)
  test_loader = DataLoader(test_ds, 
                           collate_fn =collate_examples,
                           batch_size=test_batch_size,
                           num_workers=2,
                           pin_memory=True,
                           shuffle=False)
  
  test_preds = []
  model.eval()

  for batch in test_loader:
    # Link batch to device
    ids = batch[0]
    batch = batch[1:]
    batch = tuple(t.to(device) for t in batch)

    # Get batch data
    tokens, offset, label = batch
    
    with torch.no_grad():
      predictions = model(tokens, offset)

      label = label.cpu().numpy()
      preds = torch.nn.functional.softmax(predictions, -1).clamp(1e-4, 1-1e-4).cpu().numpy()
      for i in range(len(ids)):
        id = ids[i]
        target = label[i]
        pred = preds[i]
        test_preds.append({id: {'target': target, 'prediction': pred}})


  return test_preds

## CorefSeq in Action
Finally, we connect all pieces together and train our model

In [14]:
import time

# Create torch tensor for device
torch_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

bert_model_type = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_type, do_lower_case=True, never_split = ("[UNK]", "[SEP]", "[PAD]", "[CLS]", "[MASK]"))
model = GAPResolver(bert_model_type, torch_device)

#model.load_state_dict(torch.load("./saved_models/baseline_gap_model"))

# Training Model on given data
train_start_time = time.time() * 1000
val_preds, val_loss, train_loss, model = train(development_df, validation_df,
                                               tokenizer, model, torch_device,
                                               15, 1e-5, 10, 32)
train_end_time = time.time() * 1000

print("Time for training: {} ms".format(train_end_time - train_start_time))

~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fold 1
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~


A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  isetter(loc, value)
  0%|          | 0/15 [00:00<?, ?it/s]

Average Training Loss: 0.958288985863328 for 160 batches


  7%|▋         | 1/15 [00:43<10:04, 43.20s/it]

Average Validation Accuracy: 0.8409944222523615 for 13 batches
Average Training Loss: 0.6769433764740824 for 160 batches


 13%|█▎        | 2/15 [01:28<09:27, 43.69s/it]

Average Validation Accuracy: 0.5712034060404851 for 13 batches
Average Training Loss: 0.39215081962756815 for 160 batches


 20%|██        | 3/15 [02:14<08:54, 44.55s/it]

Average Validation Accuracy: 0.5879032485760175 for 13 batches
Average Training Loss: 0.1994369146297686 for 160 batches


 27%|██▋       | 4/15 [03:00<08:16, 45.11s/it]

Average Validation Accuracy: 0.6798873910537133 for 13 batches
Average Training Loss: 0.1057316745922435 for 160 batches


 33%|███▎      | 5/15 [03:47<07:36, 45.65s/it]

Average Validation Accuracy: 0.67845152089229 for 13 batches
Average Training Loss: 0.056625592816271816 for 160 batches


 40%|████      | 6/15 [04:34<06:52, 45.85s/it]

Average Validation Accuracy: 0.7351138488604472 for 13 batches
Average Training Loss: 0.026400737966469023 for 160 batches


 47%|████▋     | 7/15 [05:21<06:09, 46.15s/it]

Average Validation Accuracy: 0.8021753464753811 for 13 batches
Average Training Loss: 0.015704687527613715 for 160 batches


 53%|█████▎    | 8/15 [06:07<05:23, 46.24s/it]

Average Validation Accuracy: 0.8598805006880027 for 13 batches
Average Training Loss: 0.008857141900807618 for 160 batches


 60%|██████    | 9/15 [06:54<04:38, 46.34s/it]

Average Validation Accuracy: 0.9135048217498339 for 13 batches
Average Training Loss: 0.01372988889052067 for 160 batches


 67%|██████▋   | 10/15 [07:40<03:52, 46.48s/it]

Average Validation Accuracy: 0.931015646801545 for 13 batches
Average Training Loss: 0.01300792220645235 for 160 batches


 73%|███████▎  | 11/15 [08:27<03:05, 46.50s/it]

Average Validation Accuracy: 0.9565372237792382 for 13 batches
Average Training Loss: 0.013341928262889269 for 160 batches


 80%|████████  | 12/15 [09:13<02:19, 46.49s/it]

Average Validation Accuracy: 0.9830866490419095 for 13 batches
Average Training Loss: 0.015183712556245154 for 160 batches


 87%|████████▋ | 13/15 [10:00<01:33, 46.50s/it]

Average Validation Accuracy: 0.9672655761241913 for 13 batches
Average Training Loss: 0.010740587284635695 for 160 batches


 93%|█████████▎| 14/15 [10:47<00:46, 46.54s/it]

Average Validation Accuracy: 1.0056672342694724 for 13 batches
Average Training Loss: 0.0031549727881611035 for 160 batches


100%|██████████| 15/15 [11:33<00:00, 46.26s/it]

Average Validation Accuracy: 1.0211418271064758 for 13 batches
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fold 2
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



  0%|          | 0/15 [00:00<?, ?it/s]

Average Training Loss: 0.2217607130529359 for 160 batches


  7%|▋         | 1/15 [00:46<10:56, 46.88s/it]

Average Validation Accuracy: 0.005886858902298487 for 13 batches
Average Training Loss: 0.0655596659409639 for 160 batches


 13%|█▎        | 2/15 [01:33<10:09, 46.86s/it]

Average Validation Accuracy: 0.004862012090877845 for 13 batches
Average Training Loss: 0.02086710789517383 for 160 batches


 20%|██        | 3/15 [02:20<09:21, 46.76s/it]

Average Validation Accuracy: 0.0033382712236533943 for 13 batches
Average Training Loss: 0.012788765174991567 for 160 batches


 27%|██▋       | 4/15 [03:06<08:34, 46.74s/it]

Average Validation Accuracy: 0.007393418179932409 for 13 batches
Average Training Loss: 0.009614279092238576 for 160 batches


 33%|███▎      | 5/15 [03:53<07:47, 46.80s/it]

Average Validation Accuracy: 0.002567300877462213 for 13 batches
Average Training Loss: 0.003952711708006973 for 160 batches


 40%|████      | 6/15 [04:40<07:00, 46.76s/it]

Average Validation Accuracy: 0.001938638606449016 for 13 batches
Average Training Loss: 0.002012336052757746 for 160 batches


 47%|████▋     | 7/15 [05:27<06:14, 46.79s/it]

Average Validation Accuracy: 0.002345608659267712 for 13 batches
Average Training Loss: 0.0052827766738118955 for 160 batches


 53%|█████▎    | 8/15 [06:13<05:27, 46.72s/it]

Average Validation Accuracy: 0.002834900215160675 for 13 batches
Average Training Loss: 0.0063020003063684275 for 160 batches


 60%|██████    | 9/15 [07:00<04:40, 46.79s/it]

Average Validation Accuracy: 0.0037853720602400312 for 13 batches
Average Training Loss: 0.0017852271658739482 for 160 batches


 67%|██████▋   | 10/15 [07:47<03:53, 46.71s/it]

Average Validation Accuracy: 0.012609692477361443 for 13 batches
Average Training Loss: 0.0016766657774951454 for 160 batches


 73%|███████▎  | 11/15 [08:34<03:07, 46.81s/it]

Average Validation Accuracy: 0.0013373802027602394 for 13 batches
Average Training Loss: 0.0010669680940736726 for 160 batches


 80%|████████  | 12/15 [09:21<02:20, 46.75s/it]

Average Validation Accuracy: 0.009255979534198279 for 13 batches
Average Training Loss: 0.002436991043123271 for 160 batches


 87%|████████▋ | 13/15 [10:08<01:33, 46.82s/it]

Average Validation Accuracy: 0.005061068533485433 for 13 batches
Average Training Loss: 0.0009736553472976083 for 160 batches


 93%|█████████▎| 14/15 [10:54<00:46, 46.81s/it]

Average Validation Accuracy: 0.0037189577974808905 for 13 batches
Average Training Loss: 0.0006245094515179517 for 160 batches


100%|██████████| 15/15 [11:41<00:00, 46.77s/it]

Average Validation Accuracy: 0.0013801426108469828 for 13 batches
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fold 3
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



  0%|          | 0/15 [00:00<?, ?it/s]

Average Training Loss: 0.013220405990296058 for 160 batches


  7%|▋         | 1/15 [00:47<10:59, 47.08s/it]

Average Validation Accuracy: 0.0008580285771481263 for 13 batches
Average Training Loss: 0.007866737923723121 for 160 batches


 13%|█▎        | 2/15 [01:33<10:10, 46.98s/it]

Average Validation Accuracy: 0.000224181040091655 for 13 batches
Average Training Loss: 0.01975283504825711 for 160 batches


 20%|██        | 3/15 [02:20<09:22, 46.91s/it]

Average Validation Accuracy: 0.0009152453567367047 for 13 batches
Average Training Loss: 0.003389641381727415 for 160 batches


 27%|██▋       | 4/15 [03:07<08:36, 46.96s/it]

Average Validation Accuracy: 0.00026092327271516505 for 13 batches
Average Training Loss: 0.0011569778281227626 for 160 batches


 33%|███▎      | 5/15 [03:54<07:48, 46.90s/it]

Average Validation Accuracy: 0.00019249435102280515 for 13 batches
Average Training Loss: 0.004606316414219691 for 160 batches


 40%|████      | 6/15 [04:41<07:02, 46.92s/it]

Average Validation Accuracy: 0.00026320161011356575 for 13 batches
Average Training Loss: 0.0011317124680999767 for 160 batches


 47%|████▋     | 7/15 [05:28<06:14, 46.85s/it]

Average Validation Accuracy: 0.00018712747204027927 for 13 batches
Average Training Loss: 0.004747668070513101 for 160 batches


 53%|█████▎    | 8/15 [06:14<05:28, 46.86s/it]

Average Validation Accuracy: 0.0008243496424536436 for 13 batches
Average Training Loss: 0.0011795750666351522 for 160 batches


 60%|██████    | 9/15 [07:01<04:41, 46.92s/it]

Average Validation Accuracy: 0.00018657175342713555 for 13 batches
Average Training Loss: 0.00040062965662173154 for 160 batches


 67%|██████▋   | 10/15 [07:48<03:54, 46.85s/it]

Average Validation Accuracy: 0.00015857002934073814 for 13 batches
Average Training Loss: 0.00031404112706923115 for 160 batches


 73%|███████▎  | 11/15 [08:35<03:07, 46.83s/it]

Average Validation Accuracy: 0.00012390839737445975 for 13 batches
Average Training Loss: 0.007737730181440838 for 160 batches


 80%|████████  | 12/15 [09:22<02:20, 46.83s/it]

Average Validation Accuracy: 0.015748504427812386 for 13 batches
Average Training Loss: 0.012105705311387282 for 160 batches


 87%|████████▋ | 13/15 [10:09<01:33, 46.82s/it]

Average Validation Accuracy: 0.0007918249220193292 for 13 batches
Average Training Loss: 0.0009360734841720841 for 160 batches


 93%|█████████▎| 14/15 [10:56<00:46, 46.86s/it]

Average Validation Accuracy: 0.0003066354309759425 for 13 batches
Average Training Loss: 0.0029359435976289206 for 160 batches


100%|██████████| 15/15 [11:42<00:00, 46.84s/it]

Average Validation Accuracy: 0.0037233176433521574 for 13 batches
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fold 4
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



  0%|          | 0/15 [00:00<?, ?it/s]

Average Training Loss: 0.0045365424881765645 for 160 batches


  7%|▋         | 1/15 [00:47<11:03, 47.37s/it]

Average Validation Accuracy: 0.00012737563967955514 for 13 batches
Average Training Loss: 0.004187617772879548 for 160 batches


 13%|█▎        | 2/15 [01:33<10:12, 47.11s/it]

Average Validation Accuracy: 0.0009346154735025126 for 13 batches
Average Training Loss: 0.004254275124151263 for 160 batches


 20%|██        | 3/15 [02:20<09:24, 47.05s/it]

Average Validation Accuracy: 0.0005739533684950752 for 13 batches
Average Training Loss: 0.005665647909518156 for 160 batches


 27%|██▋       | 4/15 [03:07<08:36, 46.95s/it]

Average Validation Accuracy: 0.000733032832403506 for 13 batches
Average Training Loss: 0.00244438955660371 for 160 batches


 33%|███▎      | 5/15 [03:54<07:48, 46.85s/it]

Average Validation Accuracy: 0.00011197373113156154 for 13 batches
Average Training Loss: 0.0024037234681486552 for 160 batches


 40%|████      | 6/15 [04:40<07:01, 46.83s/it]

Average Validation Accuracy: 0.00025826802280230017 for 13 batches
Average Training Loss: 0.0007978317893730491 for 160 batches


 47%|████▋     | 7/15 [05:27<06:14, 46.82s/it]

Average Validation Accuracy: 9.60553427900707e-05 for 13 batches
Average Training Loss: 0.00025531347201876995 for 160 batches


 53%|█████▎    | 8/15 [06:14<05:27, 46.80s/it]

Average Validation Accuracy: 8.668251183945149e-05 for 13 batches
Average Training Loss: 0.0006971542915607642 for 160 batches


 60%|██████    | 9/15 [07:01<04:40, 46.81s/it]

Average Validation Accuracy: 0.0001428762797541612 for 13 batches
Average Training Loss: 0.0003636313417018755 for 160 batches


 67%|██████▋   | 10/15 [07:48<03:54, 46.91s/it]

Average Validation Accuracy: 5.28521153543037e-05 for 13 batches
Average Training Loss: 0.0001479737877559728 for 160 batches


 73%|███████▎  | 11/15 [08:35<03:07, 46.83s/it]

Average Validation Accuracy: 5.080100198448725e-05 for 13 batches
Average Training Loss: 0.0002744060963323136 for 160 batches


 80%|████████  | 12/15 [09:22<02:20, 46.87s/it]

Average Validation Accuracy: 4.918883377495849e-05 for 13 batches
Average Training Loss: 0.00011774061092069132 for 160 batches


 87%|████████▋ | 13/15 [10:08<01:33, 46.75s/it]

Average Validation Accuracy: 5.316138329082885e-05 for 13 batches
Average Training Loss: 9.521567916408458e-05 for 160 batches


 93%|█████████▎| 14/15 [10:55<00:46, 46.83s/it]

Average Validation Accuracy: 5.046901473038955e-05 for 13 batches
Average Training Loss: 0.00012848915990559817 for 160 batches


100%|██████████| 15/15 [11:42<00:00, 46.81s/it]

Average Validation Accuracy: 3.286843345379636e-05 for 13 batches
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Fold 5
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~



  0%|          | 0/15 [00:00<?, ?it/s]

Average Training Loss: 0.0015806850761777014 for 160 batches


  7%|▋         | 1/15 [00:47<11:06, 47.60s/it]

Average Validation Accuracy: 4.97498781791924e-05 for 13 batches
Average Training Loss: 0.004998047873669975 for 160 batches


 13%|█▎        | 2/15 [01:34<10:14, 47.28s/it]

Average Validation Accuracy: 8.141126363625517e-05 for 13 batches
Average Training Loss: 0.010788063883194355 for 160 batches


 20%|██        | 3/15 [02:21<09:26, 47.21s/it]

Average Validation Accuracy: 0.00013087596016703174 for 13 batches
Average Training Loss: 0.0007655439669633779 for 160 batches


 27%|██▋       | 4/15 [03:08<08:38, 47.13s/it]

Average Validation Accuracy: 4.389218436419749e-05 for 13 batches
Average Training Loss: 0.0005232600414274202 for 160 batches


 33%|███▎      | 5/15 [03:55<07:50, 47.09s/it]

Average Validation Accuracy: 5.2026302280585064e-05 for 13 batches
Average Training Loss: 0.00021228285123839897 for 160 batches


 40%|████      | 6/15 [04:42<07:03, 47.05s/it]

Average Validation Accuracy: 4.331073229807841e-05 for 13 batches
Average Training Loss: 0.00024765932629122744 for 160 batches


 47%|████▋     | 7/15 [05:28<06:15, 46.94s/it]

Average Validation Accuracy: 2.623207011800752e-05 for 13 batches
Average Training Loss: 0.00011338518625336746 for 160 batches


 53%|█████▎    | 8/15 [06:15<05:28, 46.95s/it]

Average Validation Accuracy: 2.3330479962169193e-05 for 13 batches
Average Training Loss: 0.0002414825088607131 for 160 batches


 60%|██████    | 9/15 [07:02<04:41, 46.89s/it]

Average Validation Accuracy: 5.7026025597923974e-05 for 13 batches
Average Training Loss: 0.00013144000771774244 for 160 batches


 67%|██████▋   | 10/15 [07:49<03:54, 46.91s/it]

Average Validation Accuracy: 2.8049866999096524e-05 for 13 batches
Average Training Loss: 0.00011069787991857538 for 160 batches


 73%|███████▎  | 11/15 [08:36<03:07, 46.92s/it]

Average Validation Accuracy: 1.7633955534480403e-05 for 13 batches
Average Training Loss: 7.894124396017333e-05 for 160 batches


 80%|████████  | 12/15 [09:23<02:20, 46.94s/it]

Average Validation Accuracy: 1.5821484601894135e-05 for 13 batches
Average Training Loss: 6.342710197628775e-05 for 160 batches


 87%|████████▋ | 13/15 [10:10<01:33, 46.91s/it]

Average Validation Accuracy: 1.3455823556376765e-05 for 13 batches
Average Training Loss: 6.247559437184691e-05 for 160 batches


 93%|█████████▎| 14/15 [10:56<00:46, 46.85s/it]

Average Validation Accuracy: 1.6604564813440866e-05 for 13 batches
Average Training Loss: 4.7094444208539696e-05 for 160 batches


100%|██████████| 15/15 [11:43<00:00, 46.93s/it]

Average Validation Accuracy: 1.2552395376657772e-05 for 13 batches
Time for training: 3530793.535888672 ms





# Evaluation and Analysis
Now that we have our test predicitions, we 

In [15]:
#torch.save(model.state_dict(), "./drive/MyDrive/Colab Notebooks/baseline_gap_model")

In [16]:
# Evaluating model on test data
test_preds = evaluate_test_data(test_df, model, tokenizer, 32, torch_device)


In [17]:
import csv
a_preds, b_preds, neither_preds,  = [], [], []
a_corefs, b_corefs, corr_preds, ids = [], [], [], []

for data in test_preds:
  key = list(data.keys())[0]
  a_pred = data[key]['prediction'][0]
  b_pred = data[key]['prediction'][1]
  n_pred = data[key]['prediction'][2]
  target = data[key]['target']
  pred_correct = False

  if target == 0:
    if a_pred > b_pred and a_pred > n_pred:
      pred_correct = True
      a_coref = True
      b_coref = False
  elif target == 1:
    if b_pred > a_pred and b_pred > n_pred:
      pred_correct = True
      a_coref = False
      b_coref = True
  else:
    if n_pred > a_pred and n_pred > b_pred:
      pred_correct = True
      a_coref = False
      b_coref = False

  a_preds.append(a_pred)
  b_preds.append(b_pred)
  neither_preds.append(n_pred)
  a_corefs.append(a_coref)
  b_corefs.append(b_coref)
  corr_preds.append(pred_correct)
  ids.append(key)



a_preds = np.array(a_preds)
b_preds = np.array(b_preds)
neither_preds = np.array(neither_preds)

result_df = pd.DataFrame([ids[:], a_preds[:], b_preds[:], neither_preds[:], a_corefs[:], b_corefs[:], corr_preds[:]],
                         index=['ID', 'A', 'B', 'NEITHER', 'A-coref', 'B-coref', 'Predicted Correctly?']).transpose()
result_df.to_csv('result.tsv', index=False, sep='\t', quoting=csv.QUOTE_NONE)