In [1]:
import torch

# Use CUDA when a GPU is available, otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print('Using device:', device)

Using device: cuda


Coreference resolution:

In [2]:
!pip install spacy~=3.3.0
!python -m spacy download en_core_web_sm
!pip install allennlp
!pip install allennlp-models

2024-05-03 13:54:31.347686: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-05-03 13:54:31.347739: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-05-03 13:54:31.349078: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-05-03 13:54:31.356220: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
Collecting en-core-web-sm==3.3.0
  Downloading https:

In [1]:
from allennlp.predictors.predictor import Predictor

In [2]:
import torch

# Pretrained SpanBERT-large coreference model archive (~634 MB download on first use).
model_url = 'https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2020.02.27.tar.gz'
# Predictor.from_path runs on the CPU by default (cuda_device=-1); put the model on GPU 0 when
# CUDA is available so the device check at the start of the notebook actually takes effect.
predictor = Predictor.from_path(model_url, cuda_device=0 if torch.cuda.is_available() else -1)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...


Output()

Downloading:   0%|          | 0.00/414 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/208k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/634M [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


sample text:

In [3]:
# Sample news snippet (several sentences with pronouns, a surname and club names) used as the
# input text by the cells below.
text = """The 20-year-old has played seven games for the Swans since making his £1.75m move to the Liberty Stadium from League Two Exeter City in January 2015.
Grimes made his only Premier League start to date in September, having scored his first Swansea goal in the League Cup a month earlier.
His last appearance came in their 3-2 FA Cup exit against League Two Oxford.
Find all the latest football transfers on our dedicated page. """

In [4]:
# Alternative toy input (uncomment to try it instead of the news snippet):
#text = "Eva and Martha didn't want their friend Jenny to feel lonely so they invited her to the party in Las Vegas."
# Run the coreference model once. The result's 'document' is the token list and 'clusters'
# holds lists of [start, end] token-index spans; both are used by the cells below.
prediction = predictor.predict(document=text)

In [5]:
# Our original text, rebuilt by naively joining the tokens with ' ' (hence the extra whitespace
# around punctuation). Only a cell's last expression is shown automatically, so use display().
display(' '.join(prediction['document']))

# The raw clusters: each one is a list of inclusive [start, end] token-index spans, which is
# hard to read on its own...
display(prediction['clusters'])

# ...whereas this is the text after the predictor's built-in coreference resolution
# (notice the possessives).
predictor.coref_resolved(text)

"The 20-year-old has played seven games for the Swans since making The 20-year-old's £1.75m move to the Liberty Stadium from League Two Exeter City in January 2015.\nThe 20-year-old made The 20-year-old's only Premier League start to date in September, having scored The 20-year-old's first the Swans goal in the League Cup a month earlier.\nThe 20-year-old's last appearance came in the Swans's 3-2 the League Cup exit against League Two Oxford.\nFind all the latest football transfers on our dedicated page. "

In [6]:
def get_span_words(span, document):
    """Return the tokens covered by an inclusive ``[start, end]`` span, joined by single spaces."""
    first, last = span[0], span[1]
    return ' '.join(document[first:last + 1])

def print_clusters(prediction):
    """Print one line per cluster: ``<first mention>: [<mention>; <mention>; ...]``.

    ``prediction`` is the predictor output dict; its ``'document'`` token list and
    ``'clusters'`` span lists are read.
    """
    tokens = prediction['document']
    for mentions in prediction['clusters']:
        rendered = [get_span_words(mention, tokens) for mention in mentions]
        print(f"{rendered[0]}: [{'; '.join(rendered)}]")

In [7]:
# Show every cluster as "<first mention>: [all mentions]".
print_clusters(prediction)

The 20 - year - old: [The 20 - year - old; his; Grimes; his; his; His]
the Swans: [the Swans; Swansea; their]
the League Cup: [the League Cup; FA Cup]


In [18]:
import nltk
import spacy

nlp = spacy.load('en_core_web_sm')
# get_cluster_pos uses nltk.word_tokenize (needs 'punkt') and nltk.pos_tag (needs
# 'averaged_perceptron_tagger'). 'punkt' was only present because it was downloaded as a side
# effect of loading the AllenNLP predictor, so request both explicitly.
nltk.download('punkt')
nltk.download('averaged_perceptron_tagger')

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...
[nltk_data]   Package averaged_perceptron_tagger is already up-to-
[nltk_data]       date!


True

Getting the parts of speech of every mention in each cluster (not necessary for coreference resolution)

In [17]:
# Results of the most recent get_cluster_pos call: one inner list per cluster, one
# space-joined tag string per mention.
pos_cluster_list = []  # spaCy coarse POS tags (token.pos_)
nltk_pos_list = []     # NLTK tags from nltk.pos_tag


def _mention_pos_tags(mention_text):
    """Tag a mention string on its own; return (spaCy pos_ tags, NLTK tags), each space-joined."""
    spacy_tags = " ".join(token.pos_ for token in nlp(mention_text))
    nltk_tags = " ".join(tag for _, tag in nltk.pos_tag(nltk.word_tokenize(mention_text)))
    return spacy_tags, nltk_tags


def get_cluster_pos(prediction):
    """Print each cluster's mentions and the POS tags of every mention (diagnostic only).

    Each mention is re-tagged in isolation (without sentence context) by both spaCy and NLTK.
    The results are stored in the module-level ``pos_cluster_list`` / ``nltk_pos_list``,
    which are reset on every call so repeated calls don't append duplicates. The two lists
    are also returned.
    """
    document, clusters = prediction['document'], prediction['clusters']
    pos_cluster_list.clear()
    nltk_pos_list.clear()
    for cluster in clusters:
        mention_texts = [get_span_words(span, document) for span in cluster]
        print(f"[{'; '.join(mention_texts)}]")
        tags = [_mention_pos_tags(mention) for mention in mention_texts]
        pos_cluster_list.append([spacy_tags for spacy_tags, _ in tags])
        nltk_pos_list.append([nltk_tags for _, nltk_tags in tags])
    # POS tags of all mentions in all clusters
    print("complete_list:", pos_cluster_list)
    print("complete_nltk_list:", nltk_pos_list)
    return pos_cluster_list, nltk_pos_list

get_cluster_pos(prediction);

[nltk_data] Downloading package averaged_perceptron_tagger to
[nltk_data]     /root/nltk_data...


[The 20 - year - old; his; Grimes; his; his; His]
[the Swans; Swansea; their]
[the League Cup; FA Cup]
complete_list: [['DET NUM PUNCT NOUN PUNCT ADJ', 'PRON', 'NOUN', 'PRON', 'PRON', 'PRON'], ['DET PROPN', 'NOUN', 'PRON'], ['DET PROPN PROPN', 'PROPN PROPN']]
complete_nltk_list: [['DT CD : NN : JJ', 'PRP$', 'NNS', 'PRP$', 'PRP$', 'PRP$'], ['DT NNPS', 'NN', 'PRP$'], ['DT NNP NNP', 'NNP NNP']]


[nltk_data]   Unzipping taggers/averaged_perceptron_tagger.zip.


In [8]:
from typing import List, Tuple
from spacy.tokens import Doc, Span

In [9]:
def core_logic_part(document: Doc, coref: List[int], resolved: List[str], mention_span: Span):
    """Replace one mention, inside ``resolved``, with the text of the cluster head.

    ``coref`` is an inclusive ``[start, end]`` token span. The first slot of the span gets the
    head text and all remaining slots are blanked. If the mention's last token is tagged
    ``PRP$`` or ``POS``, "'s" is appended so the possessive survives. The last token's
    trailing whitespace is kept. ``resolved`` is mutated in place and also returned.
    """
    start, end = coref[0], coref[1]
    last_token = document[end]
    possessive_suffix = "'s" if last_token.tag_ in ("PRP$", "POS") else ""
    resolved[start] = f"{mention_span.text}{possessive_suffix}{last_token.whitespace_}"
    for slot in range(start + 1, end + 1):
        resolved[slot] = ""
    return resolved


def original_replace_corefs(document: Doc, clusters: List[List[List[int]]]) -> str:
    """Baseline resolver: replace every later mention of a cluster with its first mention.

    ``clusters`` holds inclusive ``[start, end]`` token spans into ``document``. Returns the
    rebuilt text, with each token's original trailing whitespace preserved.
    """
    resolved = [token.text_with_ws for token in document]
    for cluster in clusters:
        head = cluster[0]
        head_span = document[head[0]:head[1] + 1]
        for coref in cluster[1:]:
            core_logic_part(document, coref, resolved, head_span)
    return "".join(resolved)

**Redundant clusters** - clusters that lack a meaningful mention that could become the head.
We completely ignore (i.e. we don't resolve at all) clusters that don't contain any noun phrase.
**Improvement: check whether clusters are valid, including whether one mention in a cluster is a verb while another is a noun.**

In [10]:
def get_span_noun_indices(doc: Doc, cluster: List[List[int]]) -> Tuple[List[int], List[int]]:
    """Find which mentions of ``cluster`` contain nouns.

    Each mention is an inclusive ``[start, end]`` token span into ``doc``, and the check uses
    the tokens' ``pos_`` tags.

    Returns:
        ``(noun_indices, proper_noun_indices)``. These are positions in ``cluster`` of mentions
        containing at least one NOUN or PROPN token, and of mentions containing at least one
        PROPN token, respectively. The second list is always a subset of the first.
    """
    span_pos_sets = [{token.pos_ for token in doc[span[0]:span[1] + 1]} for span in cluster]
    span_noun_indices = [i for i, tags in enumerate(span_pos_sets)
                         if tags & {'NOUN', 'PROPN'}]  # all noun spans
    span_prpr_noun_indices = [i for i, tags in enumerate(span_pos_sets)
                              if 'PROPN' in tags]  # only proper-noun spans
    return span_noun_indices, span_prpr_noun_indices

**Cataphora problem** - choosing the wrong cluster head.
We redefine the span that becomes a cluster head. Instead of choosing the first mention in the cluster, we pick the one that is the first noun phrase in the cluster - we define it as the first span that contains a noun.
**Improvement: prefer the shortest proper-noun mention, if one exists, over the first span that contains a noun.**

In [11]:
def get_cluster_head(doc: Doc, cluster: List[List[int]], noun_indices: List[int]):
    """Use the first noun-containing mention as the cluster head.

    ``noun_indices`` are positions into ``cluster``, and the first one is taken. Returns the
    head's ``doc`` slice together with its inclusive ``[start, end]`` token span.
    """
    first, last = cluster[noun_indices[0]]
    return doc[first:last + 1], [first, last]

In [21]:
def get_span(doc: Doc, cluster: List[List[int]], indices: List[int], i):
    """Slice the ``i``-th selected mention of ``cluster`` out of ``doc``.

    ``indices`` holds positions into ``cluster``, and ``indices[i]`` picks one inclusive
    ``[start, end]`` token span. Returns ``(doc[start:end + 1], start, end)``.
    """
    start, end = cluster[indices[i]]
    return doc[start:end + 1], start, end

# Knowledge base of non-head proper-noun mentions: each get_cluster_head_pn call appends one
# list of mention spans. It is never reset, so it grows across calls (one entry per cluster).
pn_Kb = []


def _longest_propn_run(span):
    """Return ``(first, last)`` offsets (relative to ``span``) of its longest run of
    consecutive PROPN tokens, preferring the later run on ties, or None if it has no PROPN."""
    best = None
    run_start = None
    for offset, token in enumerate(span):
        if token.pos_ != 'PROPN':
            run_start = None
            continue
        if run_start is None:
            run_start = offset
        if best is None or offset - run_start >= best[1] - best[0]:
            best = (run_start, offset)
    return best


def get_cluster_head_pn(doc: Doc, cluster: List[List[int]], prpr_noun_indices: List[int],
                        max_head_len: int = 4):
    """Pick the shortest proper-noun mention of ``cluster`` as its head.

    Args:
        doc: token sequence the inclusive ``[start, end]`` spans in ``cluster`` index into.
        cluster: mention spans of one coreference cluster.
        prpr_noun_indices: positions in ``cluster`` of mentions that contain a proper noun.
        max_head_len: if the chosen head has more tokens than this and is not made up
            entirely of PROPN tokens, the replacement text is narrowed to its longest run of
            consecutive PROPN tokens.

    Returns:
        ``(head_span, [head_start, head_end], pn_Kb)``. ``head_span`` is the text used for
        replacement (a ``doc`` slice). ``[head_start, head_end]`` is always the full head
        mention's span as it appears in ``cluster``. All other proper-noun mentions, including
        ones tied with the head for shortest length, are appended to ``pn_Kb`` as one list.

    Raises:
        ValueError: if ``prpr_noun_indices`` is empty.
    """
    if not prpr_noun_indices:
        raise ValueError("get_cluster_head_pn needs at least one proper-noun mention")

    candidates = []
    for mention_idx in prpr_noun_indices:
        start, end = cluster[mention_idx]
        candidates.append((doc[start:end + 1], start, end))

    shortest = min(len(span) for span, _, _ in candidates)
    # Ties keep the original behaviour: the last of several equally short mentions is the head.
    head_pos = max(pos for pos, (span, _, _) in enumerate(candidates) if len(span) == shortest)
    head_span, head_start, head_end = candidates[head_pos]
    kb_pn = [span for pos, (span, _, _) in enumerate(candidates) if pos != head_pos]

    # Use spaCy's own PROPN tags for the narrowing check. The previous NLTK-based check compared
    # Penn Treebank tags (NNP, ...) against 'PROPN', so it could never match, and it collapsed the
    # head to its first token.
    if len(head_span) > max_head_len and not all(token.pos_ == 'PROPN' for token in head_span):
        run = _longest_propn_run(head_span)
        if run is not None:
            head_span = doc[head_start + run[0]:head_start + run[1] + 1]

    pn_Kb.append(kb_pn)
    return head_span, [head_start, head_end], pn_Kb

**Nested coreferent mentions**
For nested mentions we resolve only the inner span (e.g. in the mention "his dog", the token "his" is the inner span). In other words, outer spans that contain another mention are not resolved.

In [26]:
def is_containing_other_spans(span: List[int], all_spans: List[List[int]]):
    """Return True if some *other* span in ``all_spans`` lies entirely within ``span``.

    Spans are inclusive ``[start, end]`` pairs; a span equal to ``span`` does not count.
    """
    outer_start, outer_end = span[0], span[1]
    for other in all_spans:
        if outer_start <= other[0] and other[1] <= outer_end and other != span:
            return True
    return False

def improved_replace_corefs(document, clusters):
    """Resolve coreferences in ``document`` using the head-selection heuristics.

    ``clusters`` holds inclusive ``[start, end]`` token spans; they are assumed to index into
    ``document`` with the same tokenization the predictor used (TODO confirm for new inputs).

    The head is chosen per cluster:
      * if any mention contains a proper noun, the shortest proper-noun mention is used
        (``get_cluster_head_pn``), and each non-proper noun mention is printed as a ``kb:`` line;
      * otherwise, if any mention contains a noun, the first noun mention is used;
      * otherwise the cluster is skipped and its mentions are left untouched.

    Mentions that contain another mention (from any cluster) are not replaced, so nested
    coreferences resolve only their inner span. Returns the rebuilt text.
    """
    resolved = list(tok.text_with_ws for tok in document)
    all_spans = [span for cluster in clusters for span in cluster]  # flattened list of all spans

    for cluster in clusters:
        noun_indices, prpr_noun_indices = get_span_noun_indices(document, cluster)

        if prpr_noun_indices:
            mention_span, mention, _ = get_cluster_head_pn(document, cluster, prpr_noun_indices)
            # Report every noun mention that is not a proper noun: it may carry information
            # worth linking to the proper-noun head in a knowledge base.
            for noun_idx in noun_indices:
                if noun_idx not in prpr_noun_indices:
                    start, end = cluster[noun_idx]
                    print("kb:", document[start:end + 1].text, "->", mention_span)
        elif noun_indices:
            # No proper noun: fall back to the first noun mention as the head.
            mention_span, mention = get_cluster_head(document, cluster, noun_indices)
        else:
            # Redundant cluster with no noun phrase to act as head: leave it unresolved.
            continue

        for coref in cluster:
            if coref != mention and not is_containing_other_spans(coref, all_spans):
                core_logic_part(document, coref, resolved, mention_span)

    return "".join(resolved)

In [19]:
# Reuse the clusters already predicted for `text` instead of running SpanBERT a second time.
clusters = prediction['clusters']
doc = nlp(text)
# The cluster spans are token indices, so our spaCy tokens must match the predictor's tokens
# (this also catches a stale `prediction` computed for a different text).
assert [token.text for token in doc] == prediction['document'], \
    "spaCy tokens differ from the predictor's tokens; cluster indices would be misaligned"

In [27]:
# Print the text resolved with the improved heuristics; any "kb:" lines are printed during resolution.
print(improved_replace_corefs(doc, clusters))

kb: The 20-year-old -> Grimes
Grimes has played seven games for Swansea since making Grimes's £1.75m move to the Liberty Stadium from League Two Exeter City in January 2015.
Grimes made Grimes's only Premier League start to date in September, having scored Grimes's first Swansea goal in FA Cup a month earlier.
Grimes's last appearance came in Swansea's 3-2 FA Cup exit against League Two Oxford.
Find all the latest football transfers on our dedicated page. 
