In [None]:
!pip install -q transformers ftfy regex tqdm fvcore imageio imageio-ffmpeg openai pattern

import pandas as pd
import openai
import re
from collections import Counter
from google.colab import drive
import os
import spacy
import nltk
from nltk.stem import WordNetLemmatizer
import json

In [None]:
# Mount Google Drive at /gdrive (the mountpoint must not point inside the Drive tree itself);
# the user's Drive folders are then reachable under /gdrive/MyDrive.
drive.mount('/gdrive')

# Work from the project folder so the relative data paths used below resolve there.
# Best-effort: only switch if the folder exists — TODO confirm the GAP .tsv files live here.
PROJECT_DIR = '/gdrive/MyDrive/DL4NLP'
if os.path.isdir(PROJECT_DIR):
  os.chdir(PROJECT_DIR)

Mounted at /gdrive


In [None]:
# WordNet data must be available before the lemmatizer used by process_text is called.
nltk.download("wordnet")
lemmatizer = WordNetLemmatizer()

# Small English spaCy pipeline; process_text uses it to tokenise and tag the input.
nlp = spacy.load("en_core_web_sm")

# Gender-neutral replacements for gendered third-person pronouns
# ('her' is handled separately: it can be possessive or an object).
NEUTRAL_PRONOUNS = {
    'he': 'they', 'she': 'they',
    'him': 'them',
    'his': 'their',
    'hers': 'theirs',
    'himself': 'themselves', 'herself': 'themselves',
}
# Singular verb forms whose plural counterpart is not simply their lemma.
IRREGULAR_PLURAL_VERBS = {'is': 'are', 'was': 'were'}


def _match_case(source, replacement):
    '''Return `replacement` capitalised like `source` ("She" -> "They", "HE" -> "THEY").'''
    if len(source) > 1 and source.isupper():
        return replacement.upper()
    if source[:1].isupper():
        return replacement[:1].upper() + replacement[1:]
    return replacement


def _neutral_pronoun(token):
    '''Gender-neutral replacement for a gendered pronoun token, or None if it is not one.'''
    lower = token.text.lower()
    if lower == 'her':
        # possessive determiner ("her book" -> "their book") vs object ("saw her" -> "saw them")
        return 'their' if token.tag_ == 'PRP$' else 'them'
    return NEUTRAL_PRONOUNS.get(lower)


def _subject_is_gendered_pronoun(verb):
    '''True if the grammatical subject of `verb` is 'he' or 'she' (any case).'''
    # auxiliaries ("she has gone") attach to the main verb, which carries the subject
    if verb.dep_ in ('aux', 'auxpass'):
        verb = verb.head
    while True:
        for child in verb.children:
            if child.dep_ in ('nsubj', 'nsubjpass'):
                return child.text.lower() in ('he', 'she')
        # conjoined verbs ("she sings and dances") share the first verb's subject
        if verb.dep_ == 'conj' and verb.head.i != verb.i:
            verb = verb.head
        else:
            return False


def process_text(input_text):
    '''
    Rewrite a tokenised sentence with gender-neutral pronouns.

    Gendered third-person pronouns become they/them/their/theirs/themselves.
    Verbs whose subject was 'he'/'she' are made to agree with 'they':
    'is' -> 'are', 'was' -> 'were', and other third-person-singular present
    verbs (tag VBZ) -> their WordNet lemma ("wants" -> "want"). Past tenses
    and verbs with any other subject are left unchanged. Capitalisation of
    replaced words is preserved.

    Uses the notebook-level spaCy pipeline `nlp` and WordNet `lemmatizer`.

    input_text: list of word strings (joined with spaces before parsing)
    returns: list of output words (the rewritten text split on whitespace)
    '''
    doc = nlp(" ".join(input_text))
    modified_tokens = []

    for token in doc:
        lower = token.text.lower()
        replacement = _neutral_pronoun(token)

        # only pluralise verbs whose subject is one of the pronouns being replaced
        if (replacement is None and token.pos_ in ('VERB', 'AUX')
                and _subject_is_gendered_pronoun(token)):
            if lower in IRREGULAR_PLURAL_VERBS:
                replacement = IRREGULAR_PLURAL_VERBS[lower]
            elif token.tag_ == 'VBZ':
                replacement = lemmatizer.lemmatize(lower, 'v')

        if replacement is None:
            modified_tokens.append(token.text)
        else:
            modified_tokens.append(_match_case(token.text, replacement))

    return " ".join(modified_tokens).split()

# example: a gendered subject pronoun followed by a third-person singular verb
input_text = ["she", "wants", "to", "take", "the", "book"]
output_text = process_text(input_text=input_text)
print(output_text)

In [None]:
class GAP_sentence:
  '''
  One GAP example (a row of a GAP-format TSV) together with its gold clusters.

  Reads the 'ID', 'Text', 'Pronoun', 'A', 'B', '*-offset' and
  'A-coref'/'B-coref' columns of the row.
  '''

  def __init__(self, sentence_object: pd.core.series.Series):
    self.id = sentence_object['ID']
    self.text = sentence_object['Text']

    self.pronoun = sentence_object['Pronoun']
    self.option_a = sentence_object['A']
    self.option_b = sentence_object['B']

    # character spans [start, end) of the three mentions in self.text
    self.pronoun_start = sentence_object['Pronoun-offset']
    self.pronoun_end = sentence_object['Pronoun-offset']+len(self.pronoun)

    self.option_a_start = sentence_object['A-offset']
    self.option_a_end = sentence_object['A-offset']+len(self.option_a)

    self.option_b_start = sentence_object['B-offset']
    self.option_b_end = sentence_object['B-offset']+len(self.option_b)

    # gold labels: does the pronoun corefer with A / with B?
    self.option_a_identity = sentence_object['A-coref']
    self.option_b_identity = sentence_object['B-coref']

    self.modified_text = ''
    self.prompt = ''

    # the three mentions in textual order; set by add_clusters()
    self.entity_1 = ''
    self.entity_2 = ''
    self.entity_3 = ''

    self.true_clusters = []
    self.pred_clusters = []
    self.mentions = [self.pronoun, self.option_a, self.option_b]

    self.true_reference = ''
    self.pronoun_cluster = -1
    self.option_a_cluster = -1
    self.option_b_cluster = -1

  def add_clusters(self):
    '''
    Mark the three mentions in the text and build the gold clusters.

    Sets self.modified_text to self.text with every mention wrapped as
    "[mention](#)" (surrounding text, including spaces, kept intact), sets
    entity_1..entity_3 to the mentions in textual order, and sets
    self.true_clusters to the gold clusters as lists of mention strings.
    Calling it again recomputes the same values rather than appending.

    Raises ValueError if two mentions share an offset.
    '''
    # offset -> mention text; GAP offsets of pronoun, A and B are distinct
    reversed_offset_dict = {self.pronoun_start: self.pronoun,
                            self.option_a_start: self.option_a,
                            self.option_b_start: self.option_b}
    if len(reversed_offset_dict) != 3:
      raise ValueError('GAP example {} has overlapping mention offsets'.format(self.id))
    ordered_offsets = sorted(reversed_offset_dict)

    pieces = []
    cursor = 0
    for offset in ordered_offsets:
      entity = reversed_offset_dict[offset]
      # copy the text up to the mention unchanged, then the marked-up mention
      pieces.append(self.text[cursor:offset])
      pieces.append('[' + entity + '](#)')
      cursor = offset + len(entity)
    pieces.append(self.text[cursor:])
    self.modified_text = ''.join(pieces)

    self.entity_1, self.entity_2, self.entity_3 = (
        reversed_offset_dict[offset] for offset in ordered_offsets)

    # gold clusters: the pronoun plus its referent (if any); other candidates are singletons
    if self.option_a_identity:
      true_clusters = [[self.pronoun, self.option_a], [self.option_b]]
    elif self.option_b_identity:
      true_clusters = [[self.pronoun, self.option_b], [self.option_a]]
    else:
      true_clusters = [[self.pronoun], [self.option_a], [self.option_b]]

    # assign rather than extend so repeated calls do not duplicate clusters
    self.true_clusters = true_clusters


class Predictor(GAP_sentence):
  '''
  A GAP_sentence sent to a language model for zero-shot coreference resolution.

  Typical use: add_clusters() -> prompt_llm() -> retrieve_predictions() ->
  get_eval_metrics(...).
  '''

  # a Markdown cluster annotation "[mention](#label)" with a non-empty label;
  # the unlabelled input markers "[mention](#)" do not match
  _ANNOTATION_PATTERN = re.compile(r'\[([^\]]*)\]\(#([^)\s]+)\)')

  def __init__(self, sentence_object, model_name):
    super().__init__(sentence_object)

    self.predicted_text = ''
    self.model_name = model_name

  def return_prompt(self, sentence):
    '''Zero-shot instruction asking the model to label every "[entity](#)" marker in `sentence`.'''
    return "Annotate all entity mentions, annotated as [entity](#) in the following text with coreference clusters. Use Markdown tags to indicate clusters in the output, with the following format [mention](#cluster_name). \n Input: {} \n Output:".format(sentence)

  def prompt_llm(self, max_tokens=750, temperature=0.5, stop=None):
    '''
    Predicts current sentence using chosen language model.

    model_name 'gpt': queries gpt-3.5-turbo with one worked example through
    the legacy `openai.ChatCompletion` interface (requires openai<1.0). The
    API key is taken from the OPENAI_API_KEY environment variable, falling
    back to any key already assigned to `openai.api_key`.
    model_name 'example': returns a fixed annotated text without any API call.

    max_tokens, temperature, stop: forwarded to the completion request.
    Returns the raw model output, also stored in self.predicted_text.
    Raises ValueError for an unknown model_name.
    '''
    if self.model_name == "gpt":
      # never hard-code credentials in the notebook
      openai.api_key = os.environ.get("OPENAI_API_KEY", openai.api_key)

      # sentence 1 of GAP development is used as an example for GPT3.5-Turbo
      example_sentence = "He grew up in Evanston, Illinois the second oldest of five children including his brothers, Fred and Gordon and sisters, Marge (Peppy) and Marilyn. His high school days were spent at New Trier High School in Winnetka, Illinois.[MacKenzie](#) studied with[Bernard Leach](#) from 1949 to 1952.[His](#) simple, wheel-thrown functional pottery is heavily influenced by the oriental aesthetic of Shoji Hamada and Kanjiro Kawai."
      example_solved_sentence = "He grew up in Evanston, Illinois the second oldest of five children including his brothers, Fred and Gordon and sisters, Marge (Peppy) and Marilyn. His high school days were spent at New Trier High School in Winnetka, Illinois.[MacKenzie](#cluster_1) studied with[Bernard Leach](#cluster_2) from 1949 to 1952.[His](#cluster_1) simple, wheel-thrown functional pottery is heavily influenced by the oriental aesthetic of Shoji Hamada and Kanjiro Kawai."
      example_prompt = self.return_prompt(example_sentence)

      prompt = self.return_prompt(self.modified_text)
      print("Given prompt: ", prompt)
      response = openai.ChatCompletion.create(model="gpt-3.5-turbo", messages=[
        {"role": "system", "content": "You are a helpful assistant for coreference resolution."},
        {"role": "user", "content": example_prompt},
        {"role": "assistant", "content": example_solved_sentence},
        {"role": "user", "content": prompt}
        ], max_tokens=max_tokens, temperature=temperature, stop=stop)

      output = response['choices'][0]['message']['content'].strip()

    elif self.model_name == 'example':
      output = "He grew up in Evanston, Illinois the second oldest of five children including his brothers, Fred and Gordon and sisters, Marge (Peppy) and Marilyn. His high school days were spent at New Trier High School in Winnetka, Illinois. [MacKenzie](#cluster1) studied with [Bernard Leach](#cluster2) from 1949 to 1952. [His](#cluster1) simple, wheel-thrown functional pottery is heavily influenced by the oriental aesthetic of Shoji Hamada and Kanjiro Kawai."

    else:
      raise ValueError("Unknown model_name {!r}; expected 'gpt' or 'example'".format(self.model_name))

    print('Output text LLM: ', output)
    self.predicted_text = output

    return output

  def _cluster_labels(self):
    '''
    Predicted cluster labels for entity_1, entity_2, entity_3 (textual order).

    Annotations whose mention text matches the expected entities, in order,
    are preferred, so extra mentions annotated by the model do not shift the
    alignment; otherwise the first three annotations are used positionally.
    Raises ValueError if the output contains fewer than three annotations.
    '''
    annotations = self._ANNOTATION_PATTERN.findall(self.predicted_text)
    expected = [self.entity_1, self.entity_2, self.entity_3]

    labels = []
    search_from = 0
    for entity in expected:
      target = entity.strip().lower()
      match_idx = next((idx for idx in range(search_from, len(annotations))
                        if annotations[idx][0].strip().lower() == target), None)
      if match_idx is None:
        break
      labels.append(annotations[match_idx][1])
      search_from = match_idx + 1

    if len(labels) == len(expected):
      return labels

    if len(annotations) < len(expected):
      raise ValueError('expected {} annotated mentions in the model output, found {}'.format(
          len(expected), len(annotations)))
    return [label for _, label in annotations[:len(expected)]]

  def retrieve_predictions(self):
    '''
    Extract the predicted clusters from self.predicted_text.

    Sets self.true_reference ('a', 'b', or None when neither candidate is the
    referent), self.pronoun_cluster / option_a_cluster / option_b_cluster
    (the predicted cluster labels, e.g. 'cluster_1'), and self.pred_clusters
    (the predicted clusters as lists of mention strings, the pronoun's first).
    Requires add_clusters() and prompt_llm() to have been called.

    Raises ValueError if the output has fewer than three annotated mentions.
    '''
    if not self.option_a_identity and not self.option_b_identity:
      true_reference = None
    else:
      true_reference = "a" if self.option_a_identity else "b"
    self.true_reference = true_reference

    # labels come back in textual order; map them to the three roles via their offsets
    labels = self._cluster_labels()
    ordered_offsets = sorted([self.pronoun_start, self.option_a_start, self.option_b_start])
    label_at_offset = dict(zip(ordered_offsets, labels))

    pronoun_cluster = label_at_offset[self.pronoun_start]
    option_a_cluster = label_at_offset[self.option_a_start]
    option_b_cluster = label_at_offset[self.option_b_start]
    self.pronoun_cluster = pronoun_cluster
    self.option_a_cluster = option_a_cluster
    self.option_b_cluster = option_b_cluster

    # group mentions by predicted label (insertion order keeps the pronoun's cluster first)
    grouped = {}
    for mention, label in ((self.pronoun, pronoun_cluster),
                           (self.option_a, option_a_cluster),
                           (self.option_b, option_b_cluster)):
      grouped.setdefault(label, []).append(mention)

    # assign rather than extend so repeated calls do not duplicate clusters
    self.pred_clusters = list(grouped.values())

  def get_eval_metrics(self, eval_metric='B3_0'):
    '''
    Arguments for the metric function of the same name.

    'Acc' / 'F1' -> (true_reference, pronoun_cluster, option_a_cluster, option_b_cluster)
    'B3_0'       -> (mentions, true_clusters, pred_clusters)
    Raises ValueError for any other metric name.
    '''
    if eval_metric in ('Acc', 'F1'):
      return self.true_reference, self.pronoun_cluster, self.option_a_cluster, self.option_b_cluster

    if eval_metric == 'B3_0':
      return self.mentions, self.true_clusters, self.pred_clusters

    raise ValueError('Metric not defined yet; choose a different one')


def Acc(true_reference, pronoun_cluster, option_a_cluster, option_b_cluster):
  '''
  Decide whether a single binary coreference prediction (e.g. GAP or
  WinoBias) is correct, for computing accuracy.

  true_reference: 'a', 'b', or a falsy value when neither candidate is the referent
  *_cluster: predicted cluster labels of the pronoun and the two candidates

  returns True if the prediction is correct, False otherwise
  '''
  if not true_reference:
    # gold: the pronoun must share a cluster with neither candidate
    return not (pronoun_cluster == option_a_cluster or pronoun_cluster == option_b_cluster)

  if true_reference == 'a':
    return True if pronoun_cluster == option_a_cluster else False

  if true_reference == 'b':
    return True if pronoun_cluster == option_b_cluster else False

  # unrecognised reference label
  return False

def F1(true_reference, pronoun_cluster, option_a_cluster, option_b_cluster):
  '''
  Classify one binary coreference prediction (e.g. GAP or WinoBias) as
  TP/FP/FN/TN for calculating P/R/F1.

  The model "links" the pronoun to a candidate when they share a predicted cluster.

  - gold referent exists, pronoun linked to it              -> 'TP'
  - gold referent exists, pronoun linked to neither         -> 'FN'
  - gold referent exists, pronoun linked only to the other  -> 'FP'
    (a spurious link; pair-wise scoring would also count the missed gold link)
  - no gold referent, pronoun linked to a candidate         -> 'FP'
  - no gold referent, pronoun linked to neither             -> 'TN'

  returns str in ['TP', 'FP', 'FN', 'TN']
  '''
  links_a = pronoun_cluster == option_a_cluster
  links_b = pronoun_cluster == option_b_cluster
  links_any = links_a or links_b

  if not true_reference:
    return 'FP' if links_any else 'TN'

  links_gold = {'a': links_a, 'b': links_b}.get(true_reference, False)
  if links_gold:
    return 'TP'
  if not links_any:
    return 'FN'
  return 'FP'

def B3_0(mentions, true_clusters, pred_clusters):
  '''
  Calculates per-mention precision, recall and F1 for the B3(0) metric,
  based on the formulation in https://aclanthology.org/W10-4305.pdf

  mentions: mention strings to score
  true_clusters / pred_clusters: lists of lists of mention strings (key / response)

  A mention absent from every predicted cluster (twinless) scores 0 for
  precision, recall and F1. F1 is 0 when precision + recall is 0.

  returns precision, recall and f1 as lists (one entry per mention)
  '''
  precision_scores = []
  recall_scores = []
  f1_scores = []

  for mention in mentions:
    # key and response clusters: the first cluster that contains the current mention
    mention_key_cluster = next((cluster for cluster in true_clusters if mention in cluster), None)
    assert mention_key_cluster, "At least one true cluster must contain mention!"
    mention_response_cluster = next((cluster for cluster in pred_clusters if mention in cluster), None)

    precision = 0
    recall = 0
    if mention_response_cluster:
      # multiset intersection, so repeated mention strings are counted correctly
      overlap_count = sum((Counter(mention_key_cluster) & Counter(mention_response_cluster)).values())
      precision = overlap_count / len(mention_response_cluster)
      recall = overlap_count / len(mention_key_cluster)

    precision_scores.append(precision)
    recall_scores.append(recall)
    f1_scores.append((2*precision*recall)/(precision+recall) if precision + recall else 0)

  return precision_scores, recall_scores, f1_scores

def global_Acc(correct_list):
  '''Fraction of correct predictions in a list of per-example True/False outcomes.'''
  n_correct = sum(correct_list)
  n_predictions = len(correct_list)
  return n_correct / n_predictions

def global_P_R_F1(TP_count, FP_count, FN_count):
  '''
  Calculates global precision, recall and F1 from aggregated TP, FP and FN counts.

  A score whose denominator is zero (no positive predictions, no gold
  positives, or precision + recall == 0) is reported as 0.0 instead of
  raising ZeroDivisionError.

  returns (precision, recall, F1)
  '''
  predicted_positive = TP_count + FP_count
  gold_positive = TP_count + FN_count

  precision = TP_count / predicted_positive if predicted_positive else 0.0
  recall = TP_count / gold_positive if gold_positive else 0.0
  F1 = (2*precision*recall)/(precision+recall) if precision + recall else 0.0

  return precision, recall, F1

def global_B3_0(precision_scores, recall_scores, F1_scores):
  '''
  Global B3(0) scores: the mean of the per-mention precision, recall and F1
  lists (as produced by B3_0 and accumulated across sentences).

  returns (precision, recall, F1)
  '''
  return tuple(sum(scores) / len(scores)
               for scores in (precision_scores, recall_scores, F1_scores))



# Next steps:
# - run the GPT evaluation on GAP and on the modified (gender-neutral) GAP
# - evaluate another language model (e.g. Flan or Alpaca 2) the same way
# - write the zero-shot section of the paper, check the state of the paper as a whole,
#   and make sure it reads well (help out with the writing)




In [None]:
# One empty score container per dataset split; each split currently tracks only the 'gpt' model.
results = {
    split_name: {'gpt': {'Acc': [],
                         'F1': {'TP': 0, 'FP': 0, 'FN': 0, 'TN': 0},
                         'B3_0': {'Precision': [], 'Recall': [], 'F1': []}}}
    for split_name in ("GAP-dev", "GAP-test", "GAP-valid",
                       "NEUTRALGAP-dev", "NEUTRALGAP-test", "NEUTRALGAP-valid")
}

# GAP dev
GAP_development = pd.read_table('gap-development.tsv')
GAP_development['Predicted'] = None

# gpt
model = 'gpt'
dev_scores = results["GAP-dev"][model]  # alias: updates below mutate `results` in place

count = 0  # sentences processed so far; only the first 10 are printed
for sentence_id in range(len(GAP_development)):
  sentence_predictor = Predictor(GAP_development.iloc[sentence_id], model)
  sentence_predictor.add_clusters()          # pre-processing: mark mentions, build gold clusters
  prediction = sentence_predictor.prompt_llm()  # query the language model
  sentence_predictor.retrieve_predictions()  # post-processing: parse predicted clusters

  accuracy_evaluation = Acc(*sentence_predictor.get_eval_metrics(eval_metric='Acc'))
  F1_evaluation = F1(*sentence_predictor.get_eval_metrics(eval_metric='F1'))
  B3_0_Pr, B3_0_Rec, B3_0_F1 = B3_0(*sentence_predictor.get_eval_metrics(eval_metric='B3_0'))

  GAP_development.at[sentence_id, 'Predicted'] = prediction
  dev_scores["Acc"].append(accuracy_evaluation)
  dev_scores["F1"][F1_evaluation] += 1
  dev_scores["B3_0"]["Precision"].extend(B3_0_Pr)
  dev_scores["B3_0"]["Recall"].extend(B3_0_Rec)
  dev_scores["B3_0"]["F1"].extend(B3_0_F1)

  # checkpoint after every sentence so an API failure does not lose finished work
  with open('zero_shot_results.json', 'w') as results_file:
    json.dump(results, results_file)

  count += 1
  if count <= 10:
    print(accuracy_evaluation, F1_evaluation, B3_0_Pr, B3_0_Rec, B3_0_F1)





Given prompt:  Annotate all entity mentions, annotated as [entity](#) in the following text with coreference clusters. Use Markdown tags to indicate clusters in the output, with the following format [mention](#cluster_name). 
 Input: Killian in 1978--79, an assistant district attorney for Brunswick Judicial Circuit in 1979--80, and a practicing attorney in Glynn County in 1980--90. Williams was elected a Superior Court judge in 1990, taking the bench in 1991. In November 2010[Williams](#) competed against[Mary Helen Moses](#) in[her](#) most recent bid for re-election. 
 Output:
Output text LLM:  ARTA driver Vitantonio Liuzzi will be replaced by former Mugen driver Tomoki Nojiri after a disappointing season last year. After years of being with Real Racing, Toshihiro Kaneishi will not drive for this season, being replaced by former Team Kunimitsu driver Hideki Mutoh.[Kazuki Nakajima](#cluster_1), like[Oliver Jarvis](#cluster_2), will not return to focus on[his](#cluster_1) LMP1 drive in

ServiceUnavailableError: ignored

In [None]:
# code for fixing offset and entity values in GAP-nb

def find_nearest_blank_space(s, index):
    '''Position of the last space in s[:index], or -1 when there is none.'''
    return s.rfind(' ', 0, index)

# do this for all 3 GAP sets
GAP_nb_set = pd.read_table('val-GAP-NB.tsv')
GAP_set = pd.read_table('gap-validation.tsv')
# rows are paired by position, so both files must list the same examples in the same order
assert len(GAP_nb_set) == len(GAP_set), "GAP-NB and GAP files must have the same number of rows"

# (mention column, offset column, is the mention a pronoun?)
MENTION_COLUMNS = (('Pronoun', 'Pronoun-offset', True),
                   ('A', 'A-offset', False),
                   ('B', 'B-offset', False))


def remap_mention(original_text, new_text, old_offset, old_mention, is_pronoun):
  '''
  Locate a GAP mention in the rewritten (GAP-NB) version of its text.

  The mention is identified by its word position: the number of spaces in
  original_text before old_offset. Its new character offset is the length of
  that many space-separated words (plus separators) of new_text. This assumes
  the rewrite keeps the words before the mention aligned one-to-one
  — TODO confirm for GAP-NB.

  Returns (new_mention, new_offset). For a pronoun, new_mention is the leading
  run of letters of the aligned word ("them." -> "them"); for a name, the
  original string is kept and must appear verbatim at new_offset.
  Raises ValueError when the mention cannot be located.
  '''
  prefix = original_text[:old_offset]
  if prefix and not prefix.endswith(' '):
    # e.g. '(Name' or '``Name': the mention does not start a space-separated word
    raise ValueError('mention at offset {} does not start a word'.format(old_offset))
  word_idx = prefix.count(' ')

  new_words = new_text.split(' ')
  if word_idx >= len(new_words):
    raise ValueError('rewritten text has only {} words'.format(len(new_words)))
  new_offset = sum(len(word) + 1 for word in new_words[:word_idx])

  if is_pronoun:
    match = re.match(r'[A-Za-z]+', new_words[word_idx])
    if match is None:
      raise ValueError('no pronoun at word {} of rewritten text'.format(word_idx))
    return match.group(0), new_offset

  if not new_text.startswith(old_mention, new_offset):
    raise ValueError('{!r} not found at offset {} of rewritten text'.format(old_mention, new_offset))
  return old_mention, new_offset


failed_instances = []  # one sentence id per mention that could not be remapped
for sentence_id in range(len(GAP_nb_set)):
  new_sentence = GAP_nb_set.iloc[sentence_id]
  original_sentence = GAP_set.iloc[sentence_id]

  for mention_col, offset_col, is_pronoun in MENTION_COLUMNS:
    try:
      new_mention, new_offset = remap_mention(
          original_sentence['Text'], new_sentence['Text'],
          original_sentence[offset_col], original_sentence[mention_col], is_pronoun)
    except ValueError:
      failed_instances.append(sentence_id)
      continue
    GAP_nb_set.at[sentence_id, mention_col] = new_mention
    GAP_nb_set.at[sentence_id, offset_col] = new_offset

# index=False: the row index is not a GAP column and would be read back as 'Unnamed: 0'
GAP_nb_set.to_csv('val-GAP-NB-fixed-2.tsv', sep="\t", index=False)

In [None]:
# code to test llm evaluation

# `df` is used here and by the RoBERTa cells below; load it explicitly so the
# notebook runs on a fresh kernel. Assumed to be GAP development (the example
# comment in Predictor calls row 1 "sentence 1 of GAP development") — TODO confirm.
df = pd.read_table('gap-development.tsv')

model = 'gpt'

# do this in a loop for all sentences, then run global eval functions to get final evaluation scores:

# test 1: gold clusters only
sentence_1_ex = GAP_sentence(df.iloc[1])
sentence_1_ex.add_clusters()

# test 2: full predict-and-score pipeline on row 2
sentence_1_ex_eval = Predictor(df.iloc[2], model)
sentence_1_ex_eval.add_clusters()          # pre-processing; adding clusters to true sentence
sentence_1_ex_eval.prompt_llm()            # prediction; adding cluster predictions to sentence
sentence_1_ex_eval.retrieve_predictions()  # post-processing; adding cluster prediction data to sentence

accuracy_evaluation = Acc(*sentence_1_ex_eval.get_eval_metrics(eval_metric='Acc'))
F1_evaluation = F1(*sentence_1_ex_eval.get_eval_metrics(eval_metric='F1'))
B3_0_evaluation = B3_0(*sentence_1_ex_eval.get_eval_metrics(eval_metric='B3_0'))

print(accuracy_evaluation, F1_evaluation, B3_0_evaluation)



Given prompt:  Annotate all entity mentions, annotated as [entity](#) in the following text with coreference clusters. Use Markdown tags to indicate clusters in the output, with the following format [mention](#cluster_name). 
 Input: He had been reelected to Congress, but resigned in 1990 to accept a post as Ambassador to Brazil. De la Sota again ran for governor of C*rdoba in 1991. Defeated by Governor[Angeloz](#) by over 15%, this latter setback was significant because it cost[De la Sota](#) much of[his](#) support within the Justicialist Party (which was flush with victory in the 1991 mid-terms), leading to President Carlos Menem 's endorsement of a separate party list in C*rdoba for the 1993 mid-term elections, and to De la Sota's failure to regain a seat in Congress. 
 Output:
Output text LLM:  He had been reelected to Congress, but resigned in 1990 to accept a post as Ambassador to Brazil. De la Sota again ran for governor of C*rdoba in 1991. Defeated by Governor[Angeloz](#cluste

In [None]:
# from here on: old versions of code
def pronoun_info(df_idx, data=None):
  '''
  Text, pronoun and candidate names for one GAP row.

  df_idx: positional row index
  data: GAP-format DataFrame; defaults to the notebook-level `df`

  returns 'invalid' when 'A-coref' equals 'B-coref' (no single correct
  candidate), otherwise (full_text, pronoun, correct_coref, incorrect_coref)
  '''
  frame = df if data is None else data
  row = frame.iloc[df_idx]

  if row['A-coref'] == row['B-coref']:
    return 'invalid'

  # exactly one of the two candidates is the referent here
  if row['A-coref']:
    correct_coref, incorrect_coref = row['A'], row['B']
  else:
    correct_coref, incorrect_coref = row['B'], row['A']

  return row['Text'], row['Pronoun'], correct_coref, incorrect_coref


In [None]:
from transformers import AutoTokenizer, RobertaForMultipleChoice
import torch

# Pretrained checkpoint used for the multiple-choice baseline below.
ROBERTA_CHECKPOINT = "roberta-base"
tokenizer = AutoTokenizer.from_pretrained(ROBERTA_CHECKPOINT)
# NOTE(review): this rebinds `model`, which earlier cells set to the LLM name string 'gpt'.
model = RobertaForMultipleChoice.from_pretrained(ROBERTA_CHECKPOINT)


Downloading (…)lve/main/config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

Downloading (…)olve/main/vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Downloading model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
def multiple_choice_eval_roberta():
  '''
  Zero-shot multiple-choice evaluation of RoBERTa on the rows of `df`.

  For every row with exactly one correct candidate, the model scores two
  completions ("The pronoun refers to <wrong>." vs "... <correct>.") and the
  row counts as correct when the correct completion has the higher logit.
  Rows for which pronoun_info returns 'invalid' are skipped and excluded
  from the accuracy denominator.

  Relies on the notebook-level `df`, `tokenizer` and `model`. NOTE: the
  multiple-choice classifier head is newly initialised when loaded from
  `roberta-base` (see the loading warning), so scores are not meaningful
  until it has been trained.

  returns (number of correct rows, accuracy over evaluated rows)
  '''
  correct = 0
  evaluated = 0
  skipped = 0

  for row_idx in range(len(df['ID'])):
    info = pronoun_info(row_idx)
    if info == 'invalid':
      # no single correct candidate to choose: skip the row, keep evaluating the rest
      skipped += 1
      continue

    full_sentence, pronoun, correct_coref, incorrect_coref = info

    prompt = "What does the pronoun {} refer to in this sentence? Sentence: {}".format(pronoun, full_sentence)
    choice0 = "The pronoun refers to {}.".format(incorrect_coref)
    choice1 = "The pronoun refers to {}.".format(correct_coref)  # the correct choice

    encoding = tokenizer([prompt, prompt], [choice0, choice1], return_tensors="pt", padding=True)
    # inference only: no gradients and no labels/loss needed
    with torch.no_grad():
      outputs = model(**{k: v.unsqueeze(0) for k, v in encoding.items()})  # batch size is 1
    logits = outputs.logits

    evaluated += 1
    if logits[0][0] < logits[0][1]:
      correct += 1

  if skipped:
    print('skipped {} invalid rows'.format(skipped))

  accuracy = correct / evaluated if evaluated else 0.0
  return correct, accuracy

In [None]:
# Run the RoBERTa multiple-choice baseline and display (number correct, accuracy).
correct, accuracy = multiple_choice_eval_roberta()
(correct, accuracy)

invalid


(4, 0.002)