In [23]:
%pip install fastcoref



In [22]:
from fastcoref import FCoref

# One FCoref instance per device. Loading the model is expensive and
# resolve_coreferences is called once per dataset row.
_FCOREF_MODELS = {}


def resolve_coreferences(text, device='cpu'):
    """Run fastcoref on a single text and return its coreference clusters.

    The model for a given device is created on first use and reused by every
    later call, instead of being reloaded for each text.

    Args:
        text: The document to resolve.
        device: Device string forwarded to ``FCoref`` (e.g. 'cpu', or
            'cuda:0' for a GPU). Defaults to 'cpu'.

    Returns:
        (clusters_as_indexes, clusters_as_strings): each cluster as a list of
        (start, end) character-offset tuples, and the same clusters as the
        matched substrings.
    """
    model = _FCOREF_MODELS.get(device)
    if model is None:
        model = FCoref(device=device)
        _FCOREF_MODELS[device] = model

    # predict() takes a batch; we send one text and read back its result.
    pred = model.predict(texts=[text])[0]
    clusters_as_indexes = pred.get_clusters(as_strings=False)
    clusters_as_strings = pred.get_clusters()

    return clusters_as_indexes, clusters_as_strings


# Smoke test on a toy sentence before running on the dataset.
# Alternative input: 'We are so happy to see you using our coref package. This package is very fast!'
text = "John and Mary are good friends. He likes her company. They often go out together."
clusters_indexes, clusters_strings = resolve_coreferences(text)

for heading, clusters in (("Clusters as indexes:", clusters_indexes),
                          ("\nClusters as strings:", clusters_strings)):
    print(heading)
    print(clusters)



Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Clusters as indexes:
[[(0, 4), (32, 34)], [(9, 13), (41, 44)]]

Clusters as strings:
[['John', 'He'], ['Mary', 'her']]


In [27]:
import pandas as pd
# Evaluation split: one row per pronoun with two candidate antecedents (A, B),
# their offsets in 'Text', and the gold 'A-coref' / 'B-coref' flags.
file_path = 'mini_test.csv'
df = pd.read_csv(filepath_or_buffer=file_path)
df.head()



Unnamed: 0,ID,Text,Pronoun,Pronoun-offset,A,A-offset,A-coref,B,B-offset,B-coref,URL
0,neu-0,They originally trained as an occupational the...,their,436,Holly Kenyon,336,True,Jane Alison,410,False,http://en.wikipedia.org/wiki/Holly_Kenyon
1,neu-1,"Since May 2011 , they have written a biweekly ...",They,291,Postrel,206,True,Sally Satel,277,False,http://en.wikipedia.org/wiki/Virginia_Postrel
2,neu-2,They advocated treating offenders like the men...,their,352,Karl Menninger,366,True,Thomas Szasz,399,False,http://en.wikipedia.org/wiki/Karl_Menninger
3,neu-3,"At age 17 , Rinker moved to Seattle and worked...",their,280,Rinker,215,False,Bailey,318,True,http://en.wikipedia.org/wiki/Mildred_Bailey
4,neu-4,"They wintered in 1836 - 1837 at New Orleans , ...",they,340,Stewart,273,True,Alfred Jacob Miller,313,False,http://en.wikipedia.org/wiki/William_Drummond_...


In [47]:
def flatten(internal_list):
    """Splice list elements of ``internal_list`` into a single tuple.

    Exactly one level is flattened: an item that is a ``list`` contributes
    its elements, and any other item (including tuples) is kept whole.
    """
    return tuple(
        element
        for item in internal_list
        for element in (item if isinstance(item, list) else [item])
    )



In [48]:
def get_antecedents(index, data=None):
  """Build the gold and distractor offset tuples for one dataset row.

  A row names a pronoun and two candidate antecedents, A and B, with offsets
  ('Pronoun-offset', 'A-offset', 'B-offset') and flags ('A-coref',
  'B-coref') saying whether each candidate co-refers with the pronoun.

  Args:
    index: Row label to read.
    data: Frame to read from. Defaults to the notebook-global ``df`` so
      existing calls keep working; pass another frame to score a different
      split without rebinding ``df``.

  Returns:
    (gold_cluster, false_cluster):
      gold_cluster  -- (true_antecedent_offset, pronoun_offset). The first
                       element is None when neither candidate is flagged
                       coreferent; if both are, A wins.
      false_cluster -- offsets of every candidate flagged non-coreferent
                       (A before B), followed by the pronoun offset, so its
                       length is 1, 2 or 3.
  """
  data = df if data is None else data

  offset_pron = data.loc[index, 'Pronoun-offset']
  A_coref = data.loc[index, 'A-coref']
  B_coref = data.loc[index, 'B-coref']

  # Explicit `== True` / `== False` rather than truthiness: a missing flag
  # (NaN) matches neither, so that candidate lands in neither tuple.
  if A_coref == True:
    true_antecedent = data.loc[index, 'A-offset']
  elif B_coref == True:
    true_antecedent = data.loc[index, 'B-offset']
  else:
    true_antecedent = None

  false_antecedents = []
  if A_coref == False:
    false_antecedents.append(data.loc[index, 'A-offset'])
  if B_coref == False:
    false_antecedents.append(data.loc[index, 'B-offset'])

  gold_cluster = (true_antecedent, offset_pron)
  # Pronoun offset is always the last element; calc_TP_FN_FP relies on it.
  false_cluster = (*false_antecedents, offset_pron)

  return gold_cluster, false_cluster




In [49]:
# Sanity check: show the gold and distractor offset tuples for every row.
for index in df.index:
  gold_cluster, false_cluster = get_antecedents(index)
  print('gold', gold_cluster, 'false', false_cluster, sep='\n')


gold
(336, 436)
false
(410, 436)
gold
(206, 291)
false
(277, 291)
gold
(366, 352)
false
(399, 352)
gold
(318, 280)
false
(215, 280)
gold
(273, 340)
false
(313, 340)


In [65]:
def calc_TP_FN_FP(preds, gold_cluster_true, gold_cluster_false):
  """Score one example's predicted clusters against its gold pronoun links.

  Args:
    preds: Predicted clusters, each an iterable of mention start offsets,
      e.g. ``[[17, 74, 168, 206, 291], [94, 142], [214, 311, 344]]``.
    gold_cluster_true: ``(antecedent_offset, pronoun_offset)`` for the correct
      link. The antecedent is None when the example has no coreferent
      candidate.
    gold_cluster_false: distractor antecedent offset(s) followed by the
      pronoun offset as the LAST element, e.g. ``(214, 311)`` or
      ``(10, 30, 50)``.

  Returns:
    (TP, FN, FP), each 0 or 1:
      TP -- some cluster contains every offset in gold_cluster_true.
      FN -- a correct link exists but no cluster contains it.
      FP -- some cluster puts the pronoun together with at least one
            distractor.

  Example: with the ``preds`` above, gold_cluster_true=(206, 291) and
  gold_cluster_false=(214, 311) give (1, 0, 1).
  """
  TP = FN = FP = 0

  # With no coreferent candidate there is nothing to recall. Without this
  # guard the None never matches and every such example counts as an FN.
  if None not in gold_cluster_true:
    if any(all(value in sublist for value in gold_cluster_true) for sublist in preds):
      TP = 1
    else:
      FN = 1

  # Linking the pronoun to ANY distractor is a wrong link. Requiring every
  # distractor to share the cluster (the old all(...) check) missed errors
  # whenever both candidates were non-coreferent.
  if gold_cluster_false:
    *distractors, pronoun = gold_cluster_false
    for sublist in preds:
      if pronoun in sublist and any(d in sublist for d in distractors):
        FP = 1
        break

  return TP, FN, FP



In [71]:
# Disabled smoke test: the code below is wrapped in a string literal so it
# does not execute. It scores dataset row 2 end to end: it builds the gold and
# false offset tuples, runs the coref model on that row's text, reduces each
# predicted span to its start offset, and prints the TP/FN/FP counts. Remove
# the surrounding triple quotes to run it.
"""
# test
gold_cluster_true, gold_cluster_false = get_antecedents(2)

text = df.loc[2, 'Text']
pred_clusters_indexes, pred_clusters_strings = resolve_coreferences(text)
preds = [[tup[0] for tup in sublist] for sublist in pred_clusters_indexes]


TP, FN, FP = calc_TP_FN_FP(preds, gold_cluster_true, gold_cluster_false)
print(TP, FN, FP)
"""



Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

0 1 0
GOLD CLUSTER:
(366, 352)
FALSE CLUSTER
(399, 352)
Clusters as indexes:
[[0, 58, 296, 352], [163, 366]]

Clusters as strings:
[['They', 'Their', 'them', 'their'], ['Menninger', 'Karl Menninger']]


In [79]:
def calc_prec_recall(TP, FN, FP):
  """Compute precision, recall and F1 from aggregated TP/FN/FP counts.

  Any ratio whose denominator is zero is reported as 0.0 instead of raising
  ZeroDivisionError. This happens, for example, when the model produces no
  links (TP + FP == 0) or gets every link wrong (precision + recall == 0).

  Returns:
    (precision, recall, f1_score) as floats.
  """
  precision = TP / (TP + FP) if TP + FP else 0.0
  recall = TP / (TP + FN) if TP + FN else 0.0

  if precision + recall:
    f1_score = 2 * (precision * recall) / (precision + recall)
  else:
    f1_score = 0.0

  return precision, recall, f1_score

In [76]:
# Score every row: run the coref model on the passage, keep only each
# predicted mention's start offset, and compare against the gold tuples.
true_pos, false_neg, false_pos = [], [], []

for index in df.index:
  gold_cluster_true, gold_cluster_false = get_antecedents(index)

  text = df.loc[index, 'Text']
  pred_clusters_indexes, pred_clusters_strings = resolve_coreferences(text)
  preds = [[span[0] for span in cluster] for cluster in pred_clusters_indexes]
  TP, FN, FP = calc_TP_FN_FP(preds, gold_cluster_true, gold_cluster_false)

  for bucket, count in ((true_pos, TP), (false_neg, FN), (false_pos, FP)):
    bucket.append(count)

total_TP, total_FN, total_FP = sum(true_pos), sum(false_neg), sum(false_pos)





Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

Map:   0%|          | 0/1 [00:00<?, ? examples/s]

Inference:   0%|          | 0/1 [00:00<?, ?it/s]

In [80]:
prec, recall, f1 = calc_prec_recall(total_TP, total_FN, total_FP)

# Report the aggregate scores, one metric per line.
for metric_name, metric_value in (('precision', prec), ('recall', recall), ('f1', f1)):
  print(f'{metric_name}: {metric_value}')

precision: 1.0
recall: 0.2
f1: 0.33333333333333337
