### Parse Data

In [27]:
%load_ext autoreload
%autoreload 1

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [185]:
import pandas as pd
import numpy as np

# Directory holding the WNLI splits, relative to this notebook.
DATA_DIR = './../../data'

train_tsv = DATA_DIR + '/train.tsv'
dev_tsv = DATA_DIR + '/dev.tsv'
# Previously this pointed at dev.tsv, which made `test_data` a copy of the dev split.
# NOTE(review): assumes test.tsv sits next to train.tsv/dev.tsv - confirm.
test_tsv = DATA_DIR + '/test.tsv'

In [187]:
# Read the three tab-separated splits, in train / dev / test order.
train_data, dev_data, test_data = (
    pd.read_csv(tsv_path, sep='\t', encoding='utf-8')
    for tsv_path in (train_tsv, dev_tsv, test_tsv)
)

In [20]:
# First premise of the training split (row label 0), fetched with a single lookup.
train_data.loc[0, 'sentence1']

'I stuck a pin through a carrot. When I pulled the pin out, it had a hole.'

In [10]:
# First hypothesis of the training split (row label 0), fetched with a single lookup.
train_data.loc[0, 'sentence2']

'The carrot had a hole.'

### Load Neuralcoref

https://github.com/huggingface/neuralcoref

In [98]:
from string import punctuation
def strip_punctuation(s):
    """Return ``s`` with every character listed in ``string.punctuation`` removed."""
    kept_chars = []
    for ch in s:
        if ch in punctuation:
            continue
        kept_chars.append(ch)
    return ''.join(kept_chars)

In [2]:
import spacy
# Load the 'en_coref_lg' model through spaCy.
# Note: spacy version needed to be 2.0.12
# View related discussion here: https://github.com/huggingface/neuralcoref/issues/102
nlp = spacy.load('en_coref_lg')

In [419]:
ENTAILMENT = 1
NOT_ENTAILMENT = 0
# Label returned whenever the coreference pipeline cannot produce a prediction.
MAJORITY = NOT_ENTAILMENT


class WNLI():
    '''
    Coreference-based predictor for (sentence1, sentence2) pairs.

    For each pair the predictor
      1. locates the span of sentence 2 that differs from sentence 1 and the
         sentence-1 token it stands in for (find_overlap / find_sen1_ref),
      2. runs the coreference pipeline `nlp` on sentence 1 and looks up the
         cluster of that token (get_label),
      3. predicts ENTAILMENT when the sentence-2 span occurs in that cluster.

    Whenever a step cannot produce an answer the `majority` label is used.
    If use_coref is False, we always return majority label.
    '''

    # Words dropped from the sentence-2 span before it is compared to a cluster.
    _ARTICLES = ("the", "a", "an")
    # Words skipped when looking for the reference right after query1 in sentence 1.
    _SKIPPED_AFTER_QUERY1 = ("the", "a", "an", "of")
    # Longest sentence-2 span (in words) tried as the substituted reference.
    _MAX_REF_WORDS = 5

    def __init__(self, nlp, data, majority=MAJORITY, use_coref=True, debug=False):
        '''
        nlp       -- callable turning a sentence into a doc exposing the
                     `_.coref_clusters` / `_.in_coref` extension attributes
        data      -- table with 'sentence1', 'sentence2' and 'label' columns
        majority  -- label returned when no coreference-based answer exists
        use_coref -- when False every prediction is `majority`
        debug     -- when True, print a trace for every example
        '''
        self.nlp = nlp
        self.data = data
        self.debug = debug
        self.use_coref = use_coref
        self.majority = majority
        # Starts at 0 (not None) so predict_single() is usable before predict().
        self.none_count = 0

    def remove_article(self, sen):
        '''
        Drop the stand-alone words "the", "a" and "an" from `sen`.

        Only whole words are removed, so "banana split" is left intact.
        If nothing but articles is present, the stripped input is returned
        so callers never receive an empty string for a non-empty input.
        '''
        kept = [word for word in sen.split() if word not in self._ARTICLES]
        return " ".join(kept) if kept else sen.strip()

    def find_overlap(self, sen1, sen2):
        '''
        Find a possible overlap between sentence1 and sentence 2.

        Sentence 2 is viewed as `query1 + <span> + query2`, where both queries
        also occur in sentence 1 (case and punctuation ignored) and <span> is
        1 to _MAX_REF_WORDS words long; shorter spans are tried first.

        Return (sen1_ref_ind, sen1_ref, sen2_ref_ind, sen2_ref, result_query1,
        result_query2), or None when no usable split exists.
        '''
        ori_sen1 = sen1
        ori_sen2_tokens = sen2.split(" ")

        sen1 = strip_punctuation(sen1.lower())
        tokens = strip_punctuation(sen2.lower()).split(" ")

        for slack in range(self._MAX_REF_WORDS):
            # The span covers tokens[ind : ind + slack + 1]
            for ind in range(len(tokens) - slack):
                span_end = ind + slack + 1
                query1 = " ".join(tokens[:ind])
                query2 = " ".join(tokens[span_end:])

                if query1 not in sen1 or query2 not in sen1:
                    continue

                sen1_ref_result = self.find_sen1_ref(ori_sen1, query1.strip(), query2.strip())

                # The split could not be mapped onto sentence 1; try the next one
                if sen1_ref_result is None:
                    continue

                (sen1_ref_ind, sen1_ref) = sen1_ref_result

                result_query1 = " ".join(ori_sen2_tokens[:ind]).strip()
                result_query2 = " ".join(ori_sen2_tokens[span_end:]).strip()
                sen2_ref = " ".join(ori_sen2_tokens[ind:span_end])

                return (sen1_ref_ind, sen1_ref, ind, sen2_ref, result_query1, result_query2)

        return None

    def find_sen1_ref(self, sen1, query1, query2):
        '''
        Find the corresponding reference in sentence 1 from sentence 2.

        `query1` / `query2` are the lower-cased, punctuation-free parts of
        sentence 2 before / after the substituted span (either may be "").

        Return (index, token) of the space-separated sentence-1 token taken
        to be the reference, or None when the queries do not line up with
        whole tokens, appear in the wrong order, or the computed position
        falls outside sentence 1.
        '''
        ori_sen1_tokens = sen1.split(" ")
        sen1_tokens = strip_punctuation(sen1.lower()).split(" ")

        query1_tokens = query1.split(" ") if query1 else []
        query2_tokens = query2.split(" ") if query2 else []
        query1_len, query2_len = len(query1_tokens), len(query2_tokens)

        # Get the indices of each query: the last start position whose suffix
        # still contains the query
        query1_ind, query2_ind = -1, -1
        for ind in range(len(sen1_tokens)):
            sen1_substring = " ".join(sen1_tokens[ind:])

            if query1 in sen1_substring:
                query1_ind = ind

            if query2 in sen1_substring:
                query2_ind = ind

            # we no longer have the query in the substring
            if query1_ind != ind and query2_ind != ind and query1_ind != -1 and query2_ind != -1:
                break

        # Sanity check: the substring match must coincide with whole tokens
        for q_ind, q_tokens in ((query1_ind, query1_tokens), (query2_ind, query2_tokens)):
            if not q_tokens:
                continue
            if q_ind < 0 or sen1_tokens[q_ind:q_ind + len(q_tokens)] != q_tokens:
                return None

        # Do a sanity check, making sure we have query 2 after query 1
        if query1_len > 1 and query2_len > 0 and query1_ind > query2_ind:
            return None

        if query1 in self._ARTICLES or query2_len > 0:
            # if query 1 is an article, or query 2 is non-empty, the reference
            # is taken to be the token right before query 2
            ori_ref_ind = query2_ind - 1
        else:
            # query2 is empty: the reference should come right after query 1,
            # skipping one article / "of" if present
            ori_ref_ind = query1_ind + query1_len
            if (ori_ref_ind < len(ori_sen1_tokens)
                    and ori_sen1_tokens[ori_ref_ind] in self._SKIPPED_AFTER_QUERY1):
                ori_ref_ind += 1

        # A position outside sentence 1 means this split is unusable. Without
        # this check -1 would wrap around to the last token and an index past
        # the end would raise IndexError.
        if not 0 <= ori_ref_ind < len(ori_sen1_tokens):
            return None

        return (ori_ref_ind, ori_sen1_tokens[ori_ref_ind])

    def get_label(self, sen1, sen1_ref_ind, sen1_ref, sen2_ref):
        '''
        Using the coreference model, try to make prediction.

        Return ENTAILMENT when `sen2_ref` (articles and punctuation removed)
        occurs in the main mention or any mention of the cluster of the
        reference token, NOT_ENTAILMENT when it does not.
        If the model fails to give us an answer, just return None.
        '''
        doc_sen1 = self.nlp(sen1)

        # we don't have any coreference detected (None or empty), so just return None
        if not doc_sen1._.coref_clusters:
            if self.debug:
                print("No coreference detected, returning None")
            return None

        # the indices of token by nlp does not match the token from previous
        # parsing - hence, manually align by scanning forward (might incur
        # some error). The scan is bounded by the doc length so a token the
        # pipeline split differently cannot run past the end.
        target = strip_punctuation(sen1_ref.lower())
        spacy_ind = None
        for cand_ind in range(max(sen1_ref_ind, 0), len(doc_sen1)):
            if strip_punctuation(str(doc_sen1[cand_ind]).lower()) == target:
                spacy_ind = cand_ind
                break

        if spacy_ind is None:
            if self.debug:
                print("Could not align reference token, returning None")
            return None

        token = doc_sen1[spacy_ind]

        # we think the token is not in coref - we can't do anything, return None
        if not token._.in_coref:
            if self.debug:
                print("Token not in coref, returning None")
            return None

        cluster = token._.coref_clusters[0]
        query = self.remove_article(strip_punctuation(sen2_ref.lower()))

        # An empty query is a substring of everything and would always "match"
        if not query:
            return None

        if self.debug:
            print("Query:" + query)
            print("Cluster : " + str(cluster))

        candidates = [cluster.main] + list(cluster.mentions)
        for mention in candidates:
            if query in strip_punctuation(str(mention).lower()):
                return ENTAILMENT

        return NOT_ENTAILMENT

    def predict_single(self, ind, sen1, sen2):
        '''
        Predict the label of one pair; `ind` is the row label in self.data
        (only used to print the gold label in debug mode).

        Falls back to self.majority, and increments self.none_count, when
        the coreference route yields no answer.
        '''
        if not self.use_coref:
            return self.majority

        if self.debug:
            print("="*50)

        result = self.find_overlap(sen1, sen2)
        pred_label = None

        if result is not None:
            (sen1_ref_ind, sen1_ref, sen2_ref_ind, sen2_ref, result_query1, result_query2) = result
            pred_label = self.get_label(sen1, sen1_ref_ind, sen1_ref, sen2_ref)

            if self.debug:
                print("[sen1]: " + sen1)
                print("[sen2]: " + sen2)
                print("[sen1_ref_ind]: " + str(sen1_ref_ind))
                print("[sen1_ref]: " + sen1_ref)
                print("[sen2_ref_ind]: " + str(sen2_ref_ind))
                print("[sen2_ref]: " + sen2_ref)
                print("[result_query1]: " + result_query1)
                print("[result_query2]: " + result_query2)
                print("[pred_label]:" + str(pred_label) + (" -> " + str(self.majority) if pred_label is None else ""))
                print("[actual_label]:" + str(self.data['label'][ind]))

        if pred_label is None:
            self.none_count += 1
            return self.majority

        return pred_label

    def predict(self):
        '''
        Predict a label for every row of self.data and return them as a list.

        Resets self.none_count first and prints how many examples fell back
        to the majority label.
        '''
        self.none_count = 0
        labels = [self.predict_single(ind, row['sentence1'], row['sentence2']) for ind, row in self.data.iterrows()]

        print("Could not use coref model for {}/{} examples".format(self.none_count, len(self.data)))
        return labels

    def score(self, pred_labels):
        '''
        Return the fraction of `pred_labels` equal to self.data['label'].
        '''
        true_labels = self.data['label']
        return (pred_labels == true_labels).mean()
    

## Actual testing

### Using Large Coreference Model

In [321]:
# Loads 'en_coref_lg', the same model name that was loaded into `nlp` earlier in this notebook.
# NOTE(review): the two loads look redundant - confirm before reusing `nlp` here.
nlp_lg = spacy.load('en_coref_lg')

#### Using no coreference model and just predicting the majority label

In [335]:
# Majority baseline: use_coref=False makes every prediction equal to MAJORITY.
wnli_lg_train_majority = WNLI(
    nlp=nlp_lg, data=train_data, majority=MAJORITY, use_coref=False, debug=False
)
wnli_lg_dev_majority = WNLI(
    nlp=nlp_lg, data=dev_data, majority=MAJORITY, use_coref=False, debug=False
)

# Train split first, then dev, matching the order the scores are reported in.
train_labels = wnli_lg_train_majority.predict()
train_score = wnli_lg_train_majority.score(pred_labels=train_labels)

dev_labels = wnli_lg_dev_majority.predict()
dev_score = wnli_lg_dev_majority.score(pred_labels=dev_labels)

print("Train score: %s" % train_score)
print("Dev score: %s" % dev_score)

Train score: 0.5086614173228347
Dev score: 0.5633802816901409


#### Using coreference model

In [421]:
# Run the coreference-based predictor on the dev split with per-example tracing enabled.
wnli_lg_dev = WNLI(nlp_lg, dev_data, majority=MAJORITY, use_coref=True, debug=True)
lg_dev_labels = wnli_lg_dev.predict()
# Score the predictions produced just above. `dev_labels` holds the
# majority-baseline predictions, so scoring it would not measure this model.
lg_dev_score = wnli_lg_dev.score(lg_dev_labels)

Query:hair
Cluster : The drain: [The drain, It]
[sen1]: The drain is clogged with hair. It has to be cleaned.
[sen2]: The hair has to be cleaned.
[sen1_ref_ind]: 6
[sen1_ref]: It
[sen2_ref_ind]: 1
[sen2_ref]: hair
[result_query1]: The
[result_query2]: has to be cleaned.
[pred_label]:0
[actual_label]:0
Query:susan
Cluster : Jane: [Jane, she]
[sen1]: Jane knocked on Susan's door but she did not answer.
[sen2]: Susan did not answer.
[sen1_ref_ind]: 6
[sen1_ref]: she
[sen2_ref_ind]: 0
[sen2_ref]: Susan
[result_query1]: 
[result_query2]: did not answer.
[pred_label]:0
[actual_label]:1
Query:sally
Cluster : Beth: [Beth, her, she]
[sen1]: Beth didn't get angry with Sally, who had cut her off, because she stopped and counted to ten.
[sen2]: Sally stopped and counted to ten.
[sen1_ref_ind]: 12
[sen1_ref]: she
[sen2_ref_ind]: 0
[sen2_ref]: Sally
[result_query1]: 
[result_query2]: stopped and counted to ten.
[pred_label]:0
[actual_label]:0
No coreference detected, returning None
[sen1]: No one jo

Query:larry
Cluster : Larry: [Larry, his, he, him, his, him]
[sen1]: Always before, Larry had helped Dad with his work. But he could not help him now, for Dad said that his boss at the railroad company would not want anyone but him to work in the office.
[sen2]: Larry could not help him now.
[sen1_ref_ind]: 10
[sen1_ref]: he
[sen2_ref_ind]: 0
[sen2_ref]: Larry
[result_query1]: 
[result_query2]: could not help him now.
[pred_label]:1
[actual_label]:1
Query:eric
Cluster : George: [George, he, he]
[sen1]: George got free tickets to the play, but he gave them to Eric, because he was particularly eager to see it.
[sen2]: Eric was particularly eager to see it.
[sen1_ref_ind]: 14
[sen1_ref]: he
[sen2_ref_ind]: 0
[sen2_ref]: Eric
[result_query1]: 
[result_query2]: was particularly eager to see it.
[pred_label]:0
[actual_label]:1
Token not in coref, returning None
[sen1]: They broadcast an announcement, but a subway came into the station and I couldn't hear over it.
[sen2]: I couldn't hear the 

Query:drawing
Cluster : Sam's drawing: [Sam's drawing, it, it]
[sen1]: Sam's drawing was hung just above Tina's and it did look much better with another one below it.
[sen2]: Tina's drawing did look much better with another one below it.
[sen1_ref_ind]: 8
[sen1_ref]: it
[sen2_ref_ind]: 1
[sen2_ref]: drawing
[result_query1]: Tina's
[result_query2]: did look much better with another one below it.
[pred_label]:1
[actual_label]:0
Query:things
Cluster : they had to be denied so many things: [they had to be denied so many things, them]
[sen1]: Papa looked down at the children's faces, so puzzled and sad now. It was bad enough that they had to be denied so many things because he couldn't afford them.
[sen2]: He couldn't afford the things.
[sen1_ref_ind]: 29
[sen1_ref]: them.
[sen2_ref_ind]: 3
[sen2_ref]: the things.
[result_query1]: He couldn't afford
[result_query2]: 
[pred_label]:1
[actual_label]:1
Query:sam
Cluster : Sam: [Sam, he]
[sen1]: Sam took French classes from Adam, because he was 

In [406]:
# Large coreference model on the train split.
wnli_lg_train = WNLI(nlp_lg, train_data, majority=MAJORITY, use_coref=True, debug=False)
lg_train_labels = wnli_lg_train.predict()
# Score this model's own predictions. `train_labels` / `dev_labels` hold the
# majority-baseline predictions, so scoring them would not measure this model.
lg_train_score = wnli_lg_train.score(lg_train_labels)

# Large coreference model on the dev split.
wnli_lg_dev = WNLI(nlp_lg, dev_data, majority=MAJORITY, use_coref=True, debug=False)
lg_dev_labels = wnli_lg_dev.predict()
lg_dev_score = wnli_lg_dev.score(lg_dev_labels)

print("Train score: " + str(lg_train_score))
print("Dev score: " + str(lg_dev_score))

Train score: 0.5118110236220472
Dev score: 0.5352112676056338


### Using medium coreference model

In [409]:
# Load the 'en_coref_md' model that the medium-model runs below pass to WNLI.
nlp_md = spacy.load('en_coref_md')

In [420]:
# Medium coreference model, train split.
wnli_md_train = WNLI(
    nlp=nlp_md, data=train_data, majority=MAJORITY, use_coref=True, debug=False
)
md_train_labels = wnli_md_train.predict()
md_train_score = wnli_md_train.score(pred_labels=md_train_labels)

# Medium coreference model, dev split.
wnli_md_dev = WNLI(
    nlp=nlp_md, data=dev_data, majority=MAJORITY, use_coref=True, debug=False
)
md_dev_labels = wnli_md_dev.predict()
md_dev_score = wnli_md_dev.score(pred_labels=md_dev_labels)

print("Train score: %s" % md_train_score)
print("Dev score: %s" % md_dev_score)

Could not use coref model for 91/635 examples
Could not use coref model for 13/71 examples
Train score: 0.5133858267716536
Dev score: 0.5211267605633803
