In [1]:
%%capture
!pip install allennlp
!pip install allennlp-models==2.1.0
#==2.1.0 allennlp-models==2.1.0


In [2]:
from typing import List
 
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import math
import torch
from torch.utils.data import Dataset, DataLoader

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


# Initialize the end-to-end coreference predictor from AllenNLP

In [3]:
import textwrap
from allennlp.predictors.predictor import Predictor
import allennlp_models.tagging
import pandas as pd
# Pretrained SpanBERT-large coreference model published by AllenNLP (2021-03-10).
# Run on the current GPU when one is available; otherwise pass cuda_device=-1
# (CPU). Calling torch.cuda.current_device() on a CPU-only runtime fails.
predictor = Predictor.from_path(
    "https://storage.googleapis.com/allennlp-public-models/coref-spanbert-large-2021.03.10.tar.gz",
    cuda_device=torch.cuda.current_device() if torch.cuda.is_available() else -1,
)
# Wrap text to 80 characters.
wrapper = textwrap.TextWrapper(width=80)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Unzipping corpora/wordnet.zip.


Plugin allennlp_models could not be loaded: No module named 'nltk.translate.meteor_score'
downloading: 100%|##########| 1345986155/1345986155 [00:33<00:00, 40734377.60B/s]


Downloading:   0%|          | 0.00/414 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665M [00:00<?, ?B/s]

Some weights of BertModel were not initialized from the model checkpoint at SpanBERT/spanbert-large-cased and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


# Load ECB+ Data

In [14]:
import random
import json
import pandas as pd

from google.colab import drive
# Mount Google Drive so the shared project folder is reachable from the Colab runtime.
drive.mount('/content/drive',force_remount=True)

# Make sure to click "Add shortcut to drive" for the "Coref-for-GPT" folder
# (the path below expects that shortcut directly under "My Drive").
gdrive_dir_path = "/content/drive/MyDrive/Coref-for-GPT"


Mounted at /content/drive


In [15]:
# Project root to use when running outside Colab; fill it in before switching root_path.
local_path = ""

# Change this to "local_path" if you run the notebook locally.
# The data and result paths in later cells are all built from root_path.
root_path = gdrive_dir_path

In [16]:
# Path to the ECB+ data directory. Keep the trailing slash: file paths are
# built from it by plain string concatenation (ecb_path + "processed/...").
ecb_path = f"{root_path}/Data/ECB+/"

In [17]:
file_path = ecb_path + "processed/dev_with_sent_idx.json"
# JSON is UTF-8; say so explicitly so decoding does not depend on the
# platform's locale default when the notebook is run outside Colab.
with open(file_path, encoding="utf-8") as f:
    dev = json.load(f)
print(len(dev))

169


# Get predicted mentions and clusters

In [18]:
def get_sentence_id(sents_divider, tok_id):
    """Return the index of the sentence that contains token ``tok_id``.

    ``sents_divider`` is scanned in order and the first position whose value is
    strictly greater than ``tok_id`` is returned. It is therefore expected to
    hold increasing exclusive end offsets, one per sentence (presumably
    cumulative token counts -- confirm against the preprocessing step).

    Raises:
        Exception: if ``tok_id`` is not below any divider.
    """
    sentence_idx = next(
        (pos for pos, bound in enumerate(sents_divider) if tok_id < bound),
        None,
    )
    if sentence_idx is None:
        raise Exception(f"Error: {tok_id} out of range {sents_divider}")
    return sentence_idx

In [19]:
def get_mentions(output, gold_toks, sents_divider):
    """Flatten predicted coreference clusters into a span-keyed mention table.

    Args:
        output: predictor result; ``output["clusters"]`` is a list of clusters,
            each a list of ``[start, end]`` token spans with an inclusive end.
        gold_toks: document tokens that the spans index into.
        sents_divider: sentence boundary offsets forwarded to ``get_sentence_id``.

    Returns:
        dict mapping ``str([start, end])`` to a mention dict with keys
        ``tokens``, ``tokens_ids``, ``sentence_id`` and ``m_id``. ``m_id`` is
        the mention's rank when ordered by start token. Entries are inserted in
        that same order, which ``get_clusters``/``matched_gold_mentions`` rely
        on when indexing per-mention lists by ``m_id``.
    """
    flat = []
    for cluster in output["clusters"]:
        for span in cluster:
            first, last = span[0], span[1]
            ids = list(range(first, last + 1))
            flat.append({
                "tokens": gold_toks[first:last + 1],
                "tokens_ids": ids,
                "sentence_id": get_sentence_id(sents_divider, ids[0]),
            })

    # Stable sort by start token: ties keep their cluster order.
    flat.sort(key=lambda mention: mention["tokens_ids"][0])

    table = {}
    for rank, mention in enumerate(flat):
        mention["m_id"] = rank
        span_key = str([mention["tokens_ids"][0], mention["tokens_ids"][-1]])
        table[span_key] = mention
    return table

In [20]:
def get_clusters(output, mentions, matched_map):
    """Project predicted clusters onto the gold mentions they matched.

    Args:
        output: predictor result; ``output["clusters"]`` lists clusters of
            ``[start, end]`` token spans.
        mentions: mention table from ``get_mentions``, keyed by
            ``str([start, end])``.
        matched_map: sequence indexed by mention ``m_id`` that gives the
            matched gold mention id, or ``-1`` when the prediction matched no
            gold mention.

    Returns:
        dict mapping predicted-cluster index -> list of matched gold mention
        ids, in cluster order and possibly with repeats. Clusters in which no
        mention matched a gold mention are omitted.
    """
    clusters = {}
    for cluster_idx, cluster in enumerate(output["clusters"]):
        matched_gold_ids = []
        for span in cluster:
            # Rebuild the key exactly as get_mentions does (str of a list of
            # two Python ints). This way the lookup does not depend on whether
            # spans arrive as lists, tuples or numpy ints.
            span_key = str([int(span[0]), int(span[1])])
            gold_id = matched_map[mentions[span_key]["m_id"]]
            if gold_id != -1:
                matched_gold_ids.append(gold_id)
        if matched_gold_ids:
            clusters[cluster_idx] = matched_gold_ids
    return clusters


In [21]:
def match(m_token_ids, gold_mentions, threshold):
    """Find the first gold mention sufficiently covered by a predicted span.

    Coverage is ``|pred & gold| / |gold|``, i.e. measured relative to the gold
    mention's length, and it must be strictly greater than ``threshold``. Gold
    mentions are tried in the dict's iteration order, and the first hit wins
    even if a later one overlaps more.

    Returns:
        The matching gold mention id, or ``-1`` if none qualifies.
    """
    predicted = set(m_token_ids)
    for gold_id, gold in gold_mentions.items():
        gold_ids = set(gold["tokens_ids"])
        if len(gold_ids & predicted) / len(gold_ids) > threshold:
            return gold_id
    return -1

In [22]:
def matched_gold_mentions(mentions, gold_mentions, threshold = 0.5):
    """Match every predicted mention against the gold mentions.

    Args:
        mentions: predicted mention table (values carry ``tokens_ids``).
        gold_mentions: gold mention dict passed through to ``match``.
        threshold: coverage ratio forwarded to ``match``.

    Returns:
        ``(pred_to_gold, matched)``:
        - ``pred_to_gold[k]`` is the match result (gold id or ``-1``) for the
          k-th mention in ``mentions``' iteration order. For tables built by
          ``get_mentions``, ``k`` equals the mention's ``m_id``.
        - ``matched`` lists the non-``-1`` results in the same order. It is not
          de-duplicated, so a gold id repeats when several predictions match it.
    """
    pred_to_gold = [
        match(mention["tokens_ids"], gold_mentions, threshold)
        for mention in mentions.values()
    ]
    matched = [gold_id for gold_id in pred_to_gold if gold_id != -1]
    return pred_to_gold, matched

In [23]:
def get_gold_clusters(gold_mentions):
    """Group gold mention ids by their ``cluster_id``.

    Returns:
        dict ``cluster_id -> [mention ids]``. Both the clusters and the ids
        inside each cluster appear in first-seen order of ``gold_mentions``.
    """
    grouped = {}
    for mention_id, mention in gold_mentions.items():
        grouped.setdefault(mention["cluster_id"], []).append(mention_id)
    return grouped

In [24]:
def get_pairwise_labels(gold_mentions, mentions_ids, clusters):
    """Enumerate unordered mention pairs and label them as coreferent or not.

    Args:
        gold_mentions: unused; kept so existing call sites keep working.
        mentions_ids: mention ids to pair up. Every combination of positions
            ``i < j`` is emitted once, in row-major order.
        clusters: dict mapping a cluster key to a list of mention ids.

    Returns:
        ``(pairs, labels)`` as numpy arrays. ``pairs[k]`` is the ``set`` of the
        two ids, and ``labels[k]`` is 1 if some cluster contains both ids,
        else 0.
    """
    # Precompute id -> clusters it belongs to once, instead of rebuilding
    # set(cluster_mentions) for every pair and every cluster. An id may belong
    # to several clusters (e.g. several predicted mentions matched to the same
    # gold mention), so membership is a set of cluster indices.
    membership = {}
    for cluster_idx, cluster_mentions in enumerate(clusters.values()):
        for mention_id in cluster_mentions:
            membership.setdefault(mention_id, set()).add(cluster_idx)

    no_clusters = frozenset()
    pairs = []
    labels = []
    for i in range(len(mentions_ids) - 1):
        m1_id = mentions_ids[i]
        m1_clusters = membership.get(m1_id, no_clusters)
        for j in range(i + 1, len(mentions_ids)):
            m2_id = mentions_ids[j]
            shares_cluster = bool(m1_clusters & membership.get(m2_id, no_clusters))
            pairs.append(set([m1_id, m2_id]))
            labels.append(1 if shares_cluster else 0)

    return np.array(pairs), np.array(labels)


In [25]:
def _pair_key(pair):
    """Order-independent string key for an unordered pair of mention ids."""
    return str(sorted(str(mention_id) for mention_id in pair))


def get_comparison(text, gold_mentions, gold_toks, output, sents_divider, sentence_restriction = 2):
    """Build a per-pair table comparing gold coreference labels with predictions.

    Every unordered pair of gold mentions gets one row. Optionally, only pairs
    whose sentences are fewer than ``sentence_restriction`` apart are kept.
    ``pred`` is 1 when the predicted clusters, mapped onto gold mentions, put
    both mentions in the same cluster. Otherwise it is 0, which includes pairs
    in which a mention was never predicted.

    Args:
        text: document text with sentences separated by ``"[EOS]"``.
        gold_mentions: gold mention dict ``id -> {"tokens_ids", "sentence_id",
            "cluster_id", ...}``.
        gold_toks: document tokens that the predictor spans index into.
        output: predictor output holding ``"clusters"``.
        sents_divider: sentence boundary offsets (see ``get_sentence_id``).
        sentence_restriction: keep only pairs whose sentence distance is below
            this value; a falsy value disables the filter.

    Returns:
        DataFrame with columns ``pair``, ``label``, ``mention_pair``,
        ``sent_idx``, ``sentence``, ``sent_filter`` and ``pred``.
    """
    # gold
    new_gold_clusters = get_gold_clusters(gold_mentions)
    gold_pairs, gold_labels = get_pairwise_labels(gold_mentions, list(gold_mentions.keys()), new_gold_clusters)
    gold_df = pd.DataFrame([gold_pairs, gold_labels]).T
    gold_df.columns = ["pair", "label"]
    gold_df["mention_pair"] = gold_df["pair"].apply(lambda x: [gold_mentions[list(x)[0]], gold_mentions[list(x)[1]]])
    # The join key must not depend on set iteration order. str() of a set of
    # strings can differ for the same two ids depending on insertion order
    # (hash collisions, per-process hash randomisation), which used to drop
    # matches silently in the merge below.
    gold_df["pair_key"] = gold_df["pair"].apply(_pair_key).astype(str)
    gold_df["pair"] = gold_df["pair"].astype(str)

    gold_df["sent_idx"] = gold_df["mention_pair"].apply(lambda x: [x[0]["sentence_id"], x[1]["sentence_id"]])
    sents = text.split("[EOS]")
    gold_df["sentence"] = gold_df["sent_idx"].apply(lambda x: [sents[x[0]], sents[x[1]]])
    gold_df["sent_filter"] = gold_df["sent_idx"].apply(lambda x: np.abs(x[1]-x[0]))
    if sentence_restriction:
        gold_df = gold_df[gold_df["sent_filter"] < sentence_restriction].reset_index(drop = True)

    # predictions
    mentions = get_mentions(output, gold_toks, sents_divider)
    pred_to_gold_map, matched_mentions = matched_gold_mentions(mentions, gold_mentions, threshold = 0.5)
    clusters = get_clusters(output, mentions, pred_to_gold_map)
    pred_pairs, pred_labels = get_pairwise_labels(gold_mentions, matched_mentions, clusters)
    pred_df = pd.DataFrame([pred_pairs, pred_labels]).T
    pred_df.columns = ["pair", "pred"]
    pred_df["pair_key"] = pred_df["pair"].apply(_pair_key).astype(str)
    # Several predicted mentions can match the same gold mention, which repeats
    # pairs here. A pair's label depends only on its two ids, so the repeats
    # agree. Dropping them stops the left merge from duplicating gold rows.
    pred_df = pred_df.drop(columns="pair").drop_duplicates("pair_key")

    result = pd.merge(left = gold_df, right = pred_df, how = "left", on = "pair_key")
    result = result.drop(columns="pair_key")
    # Gold pairs absent from the predictions count as "not coreferent".
    result["pred"] = result["pred"].fillna(0)

    return result


In [26]:
def annotate(data, output_file):
    """Run the coreference predictor on every document and append pair tables to a CSV.

    Args:
        data: dict ``doc_name -> (text, gold_toks, gold_mentions,
            gold_clusters, sents_divider)``; ``gold_clusters`` is not used here.
        output_file: CSV path. Each document's rows are appended without a
            header or index, so an existing file is extended rather than
            replaced.
    """
    for doc_name, record in tqdm(data.items(), total=len(data)):
        text, gold_toks, gold_mentions, _gold_clusters, sents_divider = record
        prediction = predictor.predict_tokenized(gold_toks)
        comparison = get_comparison(
            text, gold_mentions, gold_toks, prediction, sents_divider,
            sentence_restriction=2,
        )
        comparison["doc_name"] = doc_name
        comparison.to_csv(output_file, mode="a", index=False, header=None)

In [27]:
import os

output_file = f"{root_path}/Results/e2e-coref/pairwise_result.csv"
# annotate() appends one chunk per document. Start from a clean file so that
# re-running this cell regenerates the results instead of duplicating them,
# and create the results directory if it does not exist yet.
os.makedirs(os.path.dirname(output_file), exist_ok=True)
if os.path.exists(output_file):
    os.remove(output_file)
annotate(dev, output_file)

100%|██████████| 169/169 [01:10<00:00,  2.39it/s]
