# Coreference Resolution

Configuration

In [None]:
import os
# Presumably read by setup_all.sh / the repo's scripts as the data directory
# (not visible here) — confirm against setup_all.sh.
os.environ['data_dir'] = "."
# Model to set up; presumably consumed by setup_all.sh (not visible here).
# The Input cell's `model_name` should name the same model.
os.environ['CHOSEN_MODEL'] = 'spanbert_base'
# Remove the preinstalled TensorFlow first — presumably so requirements.txt can
# install the version the model code expects; verify against requirements.txt.
! pip uninstall -y tensorflow
# --log keeps pip's full output in install-log.txt; -q keeps the cell output short.
! pip install -r requirements.txt --log install-log.txt -q
! chmod u+x setup_all.sh
! ./setup_all.sh

Input

In [None]:
import os

import pandas as pd

# Document to analyse. It must contain at least one sentence before the
# Preprocess cell is run.
original_text = ''
# Clusters with a mention containing this string (case-insensitive) decide
# which sentences the Postprocess cell reports.
keyword = 'Keppel'
# Written as the document's `doc_key`; 'nw' presumably means newswire — confirm.
genre = 'nw'
# Reuse the model chosen in the Configuration cell so the experiment name passed
# to predict.py (and looked up in experiments.conf) cannot drift from it.
model_name = os.environ.get('CHOSEN_MODEL', 'spanbert_base')

Preprocess

In [None]:
import nltk
import tokenization
import json
import itertools

nltk.download('punkt')

# Split the input into sentences. Newlines are flattened first so a hard line
# break inside a sentence is treated like ordinary whitespace.
text = nltk.tokenize.sent_tokenize(original_text.replace('\n', ' '))

# sentence_dict[k] is the cumulative number of whitespace-separated words in
# sentences 0..k. A 1-based word position w therefore falls in the first
# sentence k with w <= sentence_dict[k]; the Postprocess cell relies on this.
sentence_dict = list(itertools.accumulate(len(sentence.split()) for sentence in text))

# Skeleton of the single document written to in.json. The first segment opens
# with [CLS] (speaker placeholder [SPL]) mapped to sentence 0 / word 0; the
# tokenization step below appends word-pieces and closes segments with [SEP].
data = dict(
    doc_key=genre,
    sentences=[["[CLS]"]],
    speakers=[["[SPL]"]],
    clusters=[],
    sentence_map=[0],
    subtoken_map=[0],
)

# Determine Max Segment
# Read `max_segment_len` from the `<model_name>` block of experiments.conf;
# the tokenization step below uses it to cap segment length.
max_segment = None
in_model_block = False
with open('experiments.conf') as conf_file:
    for line in conf_file:
        stripped = line.strip()
        # Require the model name to be followed by an assignment/brace so a
        # longer name that merely starts with model_name is not mistaken for it.
        is_header = (line.startswith(model_name)
                     and line[len(model_name):].lstrip()[:1] in ('=', '{', ':'))
        if is_header:
            in_model_block = True
        elif in_model_block and stripped.startswith("max_segment_len"):
            max_segment = int(stripped.split()[-1])
            break

if max_segment is None:
    # Fail here rather than with an obscure TypeError during tokenization.
    raise ValueError(
        "Could not find max_segment_len for '{}' in experiments.conf".format(model_name)
    )

tokenizer = tokenization.FullTokenizer(vocab_file="cased_config_vocab/vocab.txt", do_lower_case=False)

if not text:
    # With no sentences the loop below never binds `sent_num`, and closing the
    # last segment would fail with a NameError; report the real problem instead.
    raise ValueError("No sentences to process: set `original_text` in the Input cell.")

# subtoken_map gets, for every word-piece, the 0-based index (across the whole
# document) of the whitespace-separated word it came from.
subtoken_num = 0
for sent_num, line in enumerate(text):
    # Tokenize word by word instead of tokenizing the whole sentence and
    # re-aligning pieces by character length. A word that becomes [UNK] (or
    # yields no pieces) then still advances the word counter by exactly one,
    # keeping word indices in step with `sentence_dict`.
    pieces_per_word = [tokenizer.tokenize(word) for word in line.split()]
    num_pieces = sum(len(pieces) for pieces in pieces_per_word)

    # Start a new segment when this sentence plus the closing [SEP] would not
    # fit. A segment that holds only its [CLS] is never closed, so an oversized
    # sentence cannot produce an empty [CLS] [SEP] segment.
    current_segment = data['sentences'][-1]
    if len(current_segment) > 1 and num_pieces + len(current_segment) >= max_segment:
        data['sentences'][-1].append("[SEP]")
        data['sentences'].append(["[CLS]"])
        data['speakers'][-1].append("[SPL]")
        data['speakers'].append(["[SPL]"])
        # [SEP] maps to the previous sentence/word, the new [CLS] to the current ones.
        data['sentence_map'].append(sent_num - 1)
        data['subtoken_map'].append(subtoken_num - 1)
        data['sentence_map'].append(sent_num)
        data['subtoken_map'].append(subtoken_num)

    for pieces in pieces_per_word:
        for token in pieces:
            data['sentences'][-1].append(token)
            data['speakers'][-1].append("-")
            data['sentence_map'].append(sent_num)
            data['subtoken_map'].append(subtoken_num)
        subtoken_num += 1

# Close the final segment.
# NOTE(review): this [SEP] gets `sent_num - 1` (the sentence before the last one),
# unlike the mid-document [SEP]s above; confirm what the model expects here.
data['sentences'][-1].append("[SEP]")
data['speakers'][-1].append("[SPL]")
data['sentence_map'].append(sent_num - 1)
data['subtoken_map'].append(subtoken_num - 1)

with open("in.json", 'w') as out:
    json.dump(data, out, sort_keys=True)

! cat in.json

Predict

In [None]:
! GPU=0 python predict.py model_name in.json out.txt

Postprocess

In [None]:
# Load the prediction written by predict.py (closing the file afterwards) and
# flatten its per-segment word-piece lists, so mention offsets from
# `predicted_clusters` can index into one sequence.
with open("out.txt") as prediction_file:
    output = json.load(prediction_file)

comb_text = [word for sentence in output['sentences'] for word in sentence]

def convert_cluster(mention, subtoken_map=None, words=None):
    """Map a predicted mention span back to the input text.

    Args:
        mention: ``[start, end]`` word-piece offsets (inclusive) into the
            flattened document, as listed in ``output['predicted_clusters']``.
        subtoken_map: word index for every word-piece; defaults to
            ``output['subtoken_map']``.
        words: flattened word-piece sequence; defaults to ``comb_text``.

    Returns:
        ``(end, mtext)``: ``end`` is the 1-based index of the whitespace word
        holding the mention's last word-piece; ``mtext`` is the mention text
        with ``##`` continuation pieces glued back onto the preceding piece.
    """
    if subtoken_map is None:
        subtoken_map = output['subtoken_map']
    if words is None:
        words = comb_text
    start, stop = mention[0], mention[1]
    end = subtoken_map[stop] + 1
    mtext = ' '.join(words[start:stop + 1]).replace(" ##", "")
    return (end, mtext)

# For each predicted cluster, collect per mention its display text (`clusters`)
# and the 1-based word index where it ends (`clusters_idx`); `seen` records
# every raw mention span as a tuple.
seen = set()
clusters, clusters_idx = [], []
for predicted in output['predicted_clusters']:
    converted = [convert_cluster(span) for span in predicted]
    seen.update(tuple(span) for span in predicted)
    clusters.append([mtext for _, mtext in converted])
    clusters_idx.append([end for end, _ in converted])

# Positions of the clusters with at least one mention that contains the
# keyword (case-insensitive substring match).
needle = keyword.lower()
idx = [
    cluster_no
    for cluster_no, mentions in enumerate(clusters)
    if any(needle in mention.lower() for mention in mentions)
]

import bisect

# Report every input sentence containing the end of a mention from a keyword
# cluster, then save them to coref_out.csv (one sentence per row).
# A mention's 1-based end word index w lies in the first sentence k with
# w <= sentence_dict[k] (cumulative word counts), which bisect_left finds
# directly without assuming anything about mention order within a cluster.
print('Relevant sentences for \'' + keyword + '\':')
relevant_sentences = []
if idx:
    relevant_idx = {
        bisect.bisect_left(sentence_dict, word_idx)
        for i in idx
        for word_idx in clusters_idx[i]
    }
    # Sort so sentences come out in document order; a bare set has no
    # ordering guarantee.
    for i in sorted(relevant_idx):
        print(text[i])
        relevant_sentences.append(text[i])
else:
    print('None')
out_df = pd.DataFrame(relevant_sentences)
out_df.to_csv('coref_out.csv', index=False, header=False)