This notebook runs the coreference resolution model described in ["SpanBERT: Improving Pre-training by Representing and Predicting Spans"](https://arxiv.org/pdf/1907.10529.pdf) by Mandar Joshi, Danqi Chen, Yinhan Liu, Daniel S. Weld, Luke Zettlemoyer, Omer Levy, and released here: https://github.com/mandarjoshi90/coref

This Colab is by me, Jonathan K. Kummerfeld. My website is www.jkk.name

Thank you to [Shon Otmazgin](https://github.com/shon-otmazgin) for bugfixes that address software changes since I originally made this colab.

Note:
- This code does not handle text with multiple speakers; for that, you will need to adjust the data preparation process.
- Occasionally I hit a bug where either an assertion about the size of the input mask fails or a sequence is assigned to an array element. It appears to be inconsistent across runs, so I'm not sure what is going on.
- The default model (spanbert_base) is not the best one; I chose it because it is much faster to download.

If you have suggestions, please contact me at jkummerf@umich.edu

# Configuration

First, specify your input. If you are just playing with this, edit the provided text. If you want to run on a larger file (one sentence per line):

1. Upload a file.
2. Set the filename.

In [None]:
filename = "optional-change-to-your-file.txt"

text = [
"Firefly is an American space Western drama television series which ran from 2002-2003, created by writer and director Joss Whedon, under his Mutant Enemy Productions label.",
"Whedon served as an executive producer, along with Tim Minear.",
"The series is set in the year 2517, after the arrival of humans in a new star system and follows the adventures of the renegade crew of Serenity, a 'Firefly-class' spaceship.",
"The ensemble cast portrays the nine characters who live on Serenity.",
"Whedon pitched the show as 'nine people looking into the blackness of space and seeing nine different things.'",
"The show explores the lives of a group of people, some of whom fought on the losing side of a civil war, who make a living on the fringes of society as part of the pioneer culture of their star system.",
"In this future, the only two surviving superpowers, the United States and China, fused to form the central federal government, called the Alliance, resulting in the fusion of the two cultures.",
"According to Whedon's vision, 'nothing will change in the future: technology will advance, but we will still have the same political, moral, and ethical problems as today.'",
"Firefly premiered in the U.S. on the Fox network on September 20, 2002.",
"By mid-December, Firefly had averaged 4.7 million viewers per episode and was 98th in Nielsen ratings.",
"It was canceled after 11 of the 14 produced episodes were aired.",
"Despite the relatively short life span of the series, it received strong sales when it was released on DVD and has large fan support campaigns.",
"It won a Primetime Emmy Award in 2003 for Outstanding Special Visual Effects for a Series.",
"TV Guide ranked the series at No. 5 on their 2013 list of 60 shows that were 'Cancelled Too Soon.'",
"The post-airing success of the show led Whedon and Universal Pictures to produce Serenity, a 2005 film which continues from the story of the series, and the Firefly franchise expanded to other media, including comics and a role-playing game.",
]

if filename != "optional-change-to-your-file.txt":
    # Read one sentence per line into `text`, skipping blank lines
    with open(filename) as f:
        text = [l.strip() for l in f if l.strip()]

Next, specify the data type and model:

In [None]:
genre = "nw"
# The OntoNotes data used to train the model contains text from several sources
# of very different styles. Specify the most suitable one out of:
# "bc": broadcast conversation
# "bn": broadcast news
# "mz": magazine
# "nw": newswire
# "pt": Bible text
# "tc": telephone conversation
# "wb": web data

model_name = "spanbert_base"
# The fine-tuned model to use. Options are:
# bert_base
# spanbert_base
# bert_large
# spanbert_large
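A typo in `genre` or `model_name` would otherwise only show up later in the notebook, possibly as a confusing error. This is a minimal sketch of an early check; `check_config`, `GENRES`, and `MODELS` are hypothetical names, and the allowed values are simply copied from the comments above.

```python
# Allowed values, copied from the comments in the configuration cell above.
GENRES = {
    "bc": "broadcast conversation",
    "bn": "broadcast news",
    "mz": "magazine",
    "nw": "newswire",
    "pt": "Bible text",
    "tc": "telephone conversation",
    "wb": "web data",
}
MODELS = {"bert_base", "spanbert_base", "bert_large", "spanbert_large"}


def check_config(genre: str, model_name: str) -> None:
    """Raise ValueError if either setting is not one of the listed options."""
    if genre not in GENRES:
        raise ValueError(f"Unknown genre {genre!r}; choose one of {sorted(GENRES)}")
    if model_name not in MODELS:
        raise ValueError(f"Unknown model {model_name!r}; choose one of {sorted(MODELS)}")


check_config("nw", "spanbert_base")  # passes silently
```

In the notebook you would call `check_config(genre, model_name)` right after setting both variables.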

# System Installation
Get the code:

In [None]:
! git clone https://github.com/mandarjoshi90/coref.git
%cd coref

Cloning into 'coref'...
remote: Enumerating objects: 6, done.
remote: Counting objects: 100% (6/6), done.
remote: Compressing objects: 100% (6/6), done.
remote: Total 734 (delta 2), reused 0 (delta 0), pack-reused 728
Receiving objects: 100% (734/734), 4.17 MiB | 8.64 MiB/s, done.
Resolving deltas: 100% (441/441), done.
/content/coref


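Temporary hack to fix some requirements (pending a pull request): update the pinned versions of MarkupSafe, scikit-learn, and scipy in requirements.txt, and remove the `-D_GLIBCXX_USE_CXX11_ABI=0` flag from setup_all.sh.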

In [None]:
! sed 's/MarkupSafe==1.0/MarkupSafe==1.1.1/; s/scikit-learn==0.19.1/scikit-learn==0.21/; s/scipy==1.0.0/scipy==1.6.2/' < requirements.txt > tmp
! mv tmp requirements.txt

! sed 's/.D.GLIBCXX.USE.CXX11.ABI.0//' < setup_all.sh  > tmp
! mv tmp setup_all.sh 
! chmod u+x setup_all.sh 
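To see exactly what the `sed` commands above change, here is a rough Python equivalent. It is a sketch, not part of the coref repository: `sed_like`, `REQUIREMENT_FIXES`, and `ABI_FLAG` are hypothetical names. As in `sed`, each pattern is a regular expression in which `.` matches any single character, so `.D.GLIBCXX.USE.CXX11.ABI.0` matches `-D_GLIBCXX_USE_CXX11_ABI=0`; and because the commands have no `g` flag, only the first match on each line is replaced.

```python
import re

# (pattern, replacement) pairs mirroring the requirements.txt sed expressions.
REQUIREMENT_FIXES = [
    (r"MarkupSafe==1.0", "MarkupSafe==1.1.1"),
    (r"scikit-learn==0.19.1", "scikit-learn==0.21"),
    (r"scipy==1.0.0", "scipy==1.6.2"),
]
# Pattern from the setup_all.sh sed command; each "." matches any character.
ABI_FLAG = r".D.GLIBCXX.USE.CXX11.ABI.0"


def sed_like(text, fixes):
    """Apply each substitution to the first match on every line, like sed without /g."""
    out_lines = []
    for line in text.splitlines(keepends=True):
        for pattern, repl in fixes:
            line = re.sub(pattern, repl, line, count=1)
        out_lines.append(line)
    return "".join(out_lines)


print(sed_like("MarkupSafe==1.0\nscipy==1.0.0\n", REQUIREMENT_FIXES), end="")
# MarkupSafe==1.1.1
# scipy==1.6.2
```

For the ABI flag, `sed_like(line, [(ABI_FLAG, "")])` deletes the flag from a compile command (the command line in the test is made up for illustration).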

Set two environment variables: `data_dir` tells the system where its data lives, and `CHOSEN_MODEL` lets the shell commands below use the model chosen above.

In [None]:
import os
os.environ['data_dir'] = "."
os.environ['CHOSEN_MODEL'] = model_name
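Environment variables set through `os.environ` are inherited by shells started from the notebook, which is what lets the later `! ./download_pretrained.sh $CHOSEN_MODEL` cell see the model name. A quick standalone way to confirm this with only the standard library (the value is hard-coded here so the snippet runs on its own):

```python
import os
import subprocess

# Stand-in for the cell above: export the model name to the environment.
os.environ["CHOSEN_MODEL"] = "spanbert_base"

# The child shell inherits os.environ, so $CHOSEN_MODEL expands there.
result = subprocess.run(
    "echo $CHOSEN_MODEL", shell=True, capture_output=True, text=True, check=True
)
print(result.stdout.strip())  # spanbert_base
```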

Run setup. Some incompatibility issues do appear, but I still find that everything installs and runs. I explicitly request TensorFlow 2 and then uninstall it to make sure the setup starts from a clean state.

In [None]:
%tensorflow_version 2.x
! pip uninstall -y tensorflow
! pip install -r requirements.txt --log install-log.txt -q
! ./setup_all.sh

Uninstalling tensorflow-2.4.1:
  Successfully uninstalled tensorflow-2.4.1

Download the fine-tuned model specified above.

In [None]:
! ./download_pretrained.sh $CHOSEN_MODEL

Downloading spanbert_base
--2021-03-30 14:28:26--  http://nlp.cs.washington.edu/pair2vec/spanbert_base.tar.gz
Resolving nlp.cs.washington.edu (nlp.cs.washington.edu)... 128.208.3.120, 2607:4000:200:12::78
Connecting to nlp.cs.washington.edu (nlp.cs.washington.edu)|128.208.3.120|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1633726311 (1.5G) [application/x-gzip]
Saving to: ‘./spanbert_base.tar.gz’


2021-03-30 14:29:04 (41.2 MB/s) - ‘./spanbert_base.tar.gz’ saved [1633726311/1633726311]

spanbert_base/
spanbert_base/checkpoint
spanbert_base/model.max.ckpt.index
spanbert_base/stdout.log
spanbert_base/bert_config.json
spanbert_base/vocab.txt
spanbert_base/model.max.ckpt.data-00000-of-00001
spanbert_base/events.out.tfevents.1561596094.learnfair1413
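Before moving on, it can be worth confirming that the archive extracted completely. This sketch checks for the files listed in the `spanbert_base` output above; `missing_model_files` is a hypothetical helper, and I am assuming the other model archives use the same file names, which I have not verified.

```python
from pathlib import Path

# File names taken from the spanbert_base archive listing above.
EXPECTED_FILES = [
    "checkpoint",
    "model.max.ckpt.index",
    "model.max.ckpt.data-00000-of-00001",
    "bert_config.json",
    "vocab.txt",
]


def missing_model_files(model_dir):
    """Return the expected files that are absent from model_dir."""
    root = Path(model_dir)
    return [name for name in EXPECTED_FILES if not (root / name).is_file()]
```

Run from the `coref` directory, `missing_model_files(model_name)` should return an empty list once the download has finished.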


# Data Preparation and Prediction

Convert the text into the JSON input format that predict.py expects.

In [None]:
from bert import tokenization
import json

data = {
    'doc_key': genre,
    'sentences': [["[CLS]"]],
    'speakers': [["[SPL]"]],
    'clusters': [],
    'sentence_map': [0],
    'subtoken_map': [0],
}

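# Determine the maximum segment length: find the chosen model's section in
# experiments.conf, then read the first max_segment_len value after it
max_segment = None
with open('experiments.conf') as conf:
    for line in conf:
        if line.startswith(model_name):
            max_segment = True
        elif line.strip().startswith("max_segment_len"):
            if max_segment:
                max_segment = int(line.strip().split()[-1])
                break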

tokenizer = tokenization.FullTokenizer(vocab_file="cased_config_vocab/vocab.txt", do_lower_case=False)
subtoken_num = 0  # index of the whitespace token being consumed (global across sentences)
for sent_num, line in enumerate(text):
    raw_tokens = line.split()
    tokens = tokenizer.tokenize(line)
    # Close the current segment and open a new one if this sentence would not fit
    if len(tokens) + len(data['sentences'][-1]) >= max_segment:
        data['sentences'][-1].append("[SEP]")
        data['sentences'].append(["[CLS]"])
        data['speakers'][-1].append("[SPL]")
        data['speakers'].append(["[SPL]"])
        data['sentence_map'].append(sent_num - 1)
        data['subtoken_map'].append(subtoken_num - 1)
        data['sentence_map'].append(sent_num)
        data['subtoken_map'].append(subtoken_num)

    # Walk through the WordPiece subtokens, mapping each one back to its
    # sentence (sentence_map) and its whitespace token (subtoken_map)
    ctoken = raw_tokens[0]
    cpos = 0
    for token in tokens:
        data['sentences'][-1].append(token)
        data['speakers'][-1].append("-")
        data['sentence_map'].append(sent_num)
        data['subtoken_map'].append(subtoken_num)

        # Strip the continuation marker before comparing lengths
        if token.startswith("##"):
            token = token[2:]
        if len(ctoken) == len(token):
            # This subtoken completes the current whitespace token; move on
            subtoken_num += 1
            cpos += 1
            if cpos < len(raw_tokens):
                ctoken = raw_tokens[cpos]
        else:
            # Consume this subtoken's characters from the current whitespace token
            ctoken = ctoken[len(token):]

# Close the final segment
data['sentences'][-1].append("[SEP]")
data['speakers'][-1].append("[SPL]")
data['sentence_map'].append(sent_num - 1)
data['subtoken_map'].append(subtoken_num - 1)

with open("sample.in.json", 'w') as out:
    json.dump(data, out, sort_keys=True)

! cat sample.in.json

{"clusters": [], "doc_key": "nw", "sentence_map": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 1
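The note at the top mentions an occasional assertion about the input mask size. I have not traced that error, but a cheap consistency check on `data` can catch obvious problems before predict.py runs. The invariants here are my reading of the preparation code above (every segment is wrapped in `[CLS]` ... `[SEP]` and has at most `max_segment` subtokens, `speakers` parallels `sentences`, and `sentence_map`/`subtoken_map` have one entry per subtoken), not a documented specification; `check_model_input` is a hypothetical helper.

```python
def check_model_input(data, max_segment):
    """Return a list of problems found in a predict.py input dict (empty if none)."""
    problems = []
    n_subtokens = sum(len(seg) for seg in data["sentences"])
    # One sentence_map / subtoken_map entry per subtoken, including [CLS]/[SEP]
    for key in ("sentence_map", "subtoken_map"):
        if len(data[key]) != n_subtokens:
            problems.append(f"{key} has {len(data[key])} entries, expected {n_subtokens}")
    if len(data["sentences"]) != len(data["speakers"]):
        problems.append("sentences and speakers have different numbers of segments")
    for i, (seg, spk) in enumerate(zip(data["sentences"], data["speakers"])):
        if len(seg) != len(spk):
            problems.append(f"segment {i}: {len(seg)} tokens but {len(spk)} speakers")
        if len(seg) > max_segment:
            problems.append(f"segment {i}: length {len(seg)} exceeds {max_segment}")
        if seg[:1] != ["[CLS]"] or seg[-1:] != ["[SEP]"]:
            problems.append(f"segment {i}: not wrapped in [CLS] ... [SEP]")
    return problems
```

Calling `check_model_input(data, max_segment)` right after building `data` should return an empty list. One way to get a non-empty result is a single line whose subtokens alone exceed `max_segment`, since the preparation code never splits a sentence across segments.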

Run Prediction

In [None]:
! GPU=0 python predict.py $CHOSEN_MODEL sample.in.json sample.out.txt


# Output Handling

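Finally, map the predicted mentions from subtoken indices back to the whitespace-token indices of the input, then print each cluster followed by the candidate mentions that were not placed in any cluster.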

In [None]:
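with open("sample.out.txt") as f:
    output = json.load(f)

# Flatten the output segments into a single list of subtokens
comb_text = [word for sentence in output['sentences'] for word in sentence]

def convert_mention(mention):
    # Map the mention's inclusive subtoken indices to a (start, end) span
    # over the whitespace-separated input tokens, with end exclusive
    start = output['subtoken_map'][mention[0]]
    end = output['subtoken_map'][mention[1]] + 1
    nmention = (start, end)
    # Rebuild the mention text, rejoining "##" continuation subtokens
    mtext = ''.join(' '.join(comb_text[mention[0]:mention[1]+1]).split(" ##"))
    return (nmention, mtext)

# Print each predicted cluster, remembering which spans were clustered
seen = set()
print('Clusters:')
for cluster in output['predicted_clusters']:
    mapped = []
    for mention in cluster:
        seen.add(tuple(mention))
        mapped.append(convert_mention(mention))
    print(mapped, end=",\n")

# Print the remaining candidate spans from top_spans that are in no cluster
print('\nMentions:')
for mention in output['top_spans']:
    if tuple(mention) in seen:
        continue
    print(convert_mention(mention), end=",\n")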

Clusters:
[((15, 20), 'writer and director Joss Whedon'), ((21, 22), 'his'), ((26, 27), 'Whedon'), ((78, 79), 'Whedon'), ((170, 171), "Whedon ' s"), ((304, 305), 'Whedon')],
[((2, 9), 'an American space Western drama television series'), ((36, 38), 'The series'), ((80, 82), 'the show'), ((96, 98), 'The show'), ((195, 196), 'Firefly'), ((210, 211), 'Firefly'), ((224, 225), 'It'), ((243, 245), 'the series'), ((245, 246), 'it'), ((250, 251), 'it'), ((261, 262), 'It'), ((280, 282), 'the series'), ((301, 303), 'the show'), ((320, 322), 'the series'), ((324, 325), 'Firefly')],
[((63, 67), "Serenity , a ' Firefly - class ' spaceship"), ((77, 78), 'Serenity')],
[((50, 54), 'a new star system'), ((134, 137), 'their star system')],
[((12, 13), '2002'), ((207, 208), '2002')],
[((12, 13), '2003'), ((268, 269), '2003')],
[((277, 279), 'TV Guide'), ((286, 287), 'their')],

Mentions:
((0, 1), 'Firefly'),
((0, 2), 'Firefly is'),
((2, 13), 'an American space Western drama television series which ran fr
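The printed spans are easier to judge in context. A minimal sketch, assuming the `(start, end)` spans printed above (end exclusive) index into the whitespace-separated tokens of the whole input: the hypothetical `bracket_clusters` helper writes each mention inline as `[cluster_id ... ]`. It assumes mentions nest rather than cross, which is typical for coreference output but not guaranteed.

```python
def bracket_clusters(words, clusters):
    """Mark each mention inline as '[cid ... ]', where cid is its cluster index.

    words: whitespace-separated tokens of the whole input, in order.
    clusters: list of clusters, each a list of (start, end) word spans, end exclusive.
    """
    opens = {}   # word index -> list of (end, cluster id) for spans starting there
    closes = {}  # word index -> number of spans ending at that word
    for cid, cluster in enumerate(clusters):
        for start, end in cluster:
            opens.setdefault(start, []).append((end, cid))
            closes[end - 1] = closes.get(end - 1, 0) + 1
    out = []
    for i, word in enumerate(words):
        # Open longer spans first so that nested brackets close in the right order
        starting = sorted(opens.get(i, []), key=lambda pair: -pair[0])
        prefix = "".join(f"[{cid} " for _, cid in starting)
        out.append(prefix + word + "]" * closes.get(i, 0))
    return " ".join(out)


print(bracket_clusters("Whedon said he liked his show".split(),
                       [[(0, 1), (2, 3), (4, 5)], [(4, 6)]]))
# [0 Whedon] said [0 he] liked [1 [0 his] show]
```

In the notebook, `words = [w for line in text for w in line.split()]` and `clusters = [[convert_mention(m)[0] for m in c] for c in output['predicted_clusters']]` should produce suitable inputs, provided the subtoken alignment in the data preparation step succeeded for every token.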