In [133]:
# installations (uncomment if necessary)

# !pip install checklist
# !pip install allennlp
# !pip install allennlp_models
# !pip install --upgrade google-cloud-storage

In [134]:
# imports (uncomment nltk download if necessary)

import json
import numpy as np
import nltk
# nltk.download('omw-1.4')

import spacy
nlp = spacy.load("en_core_web_sm")

import checklist
from checklist.editor import Editor
from checklist.perturb import Perturb
from checklist.test_types import MFT, INV, DIR
from checklist.expect import Expect
from checklist.pred_wrapper import PredictorWrapper
from checklist.test_suite import TestSuite

from allennlp_models.pretrained import load_predictor

In [135]:
# load models
# Two pretrained AllenNLP SRL predictors are loaded by their model ids so the
# same challenge set can be run against both. `srl_predictor` is the one the
# later cells refer to as the LSTM model; `srl_predictor_bert` is the BERT one.

srl_predictor = load_predictor('structured-prediction-srl')
srl_predictor_bert = load_predictor('structured-prediction-srl-bert')

error loading _jsonnet (this is expected on Windows), treating C:\Users\Shark\anaconda3\Lib\site-packages\allennlp_models\modelcards\coref-spanbert.json as plain json
error loading _jsonnet (this is expected on Windows), treating C:\Users\Shark\anaconda3\Lib\site-packages\allennlp_models\modelcards\evaluate_rc-lerc.json as plain json
lerc is not a registered model.
error loading _jsonnet (this is expected on Windows), treating C:\Users\Shark\anaconda3\Lib\site-packages\allennlp_models\modelcards\generation-bart.json as plain json
error loading _jsonnet (this is expected on Windows), treating C:\Users\Shark\anaconda3\Lib\site-packages\allennlp_models\modelcards\glove-sst.json as plain json
error loading _jsonnet (this is expected on Windows), treating C:\Users\Shark\anaconda3\Lib\site-packages\allennlp_models\modelcards\lm-masked-language-model.json as plain json
error loading _jsonnet (this is expected on Windows), treating C:\Users\Shark\anaconda3\Lib\site-packages\allennlp_models\mod

In [136]:
# get srl model predictions into PredictorWrapper format

### added by pia ###

def predict_srl(data):
    """Run ``srl_predictor`` on every item of ``data``.

    Returns a list holding one prediction per input item, in input order.
    """
    return [srl_predictor.predict(sentence) for sentence in data]

# wrap predict_srl with CheckList's PredictorWrapper; this wrapped callable is
# what test.run / suite.run receive further down
wrapper_srl = PredictorWrapper.wrap_predict(predict_srl)

def predict_bert(data):
    """Run ``srl_predictor_bert`` on every item of ``data``.

    Returns a list holding one prediction per input item, in input order.
    """
    return [srl_predictor_bert.predict(sentence) for sentence in data]

# wrap predict_bert with CheckList's PredictorWrapper; this wrapped callable is
# what suite.run receives for the BERT run
wrapper_bert = PredictorWrapper.wrap_predict(predict_bert)

# summary format functions

def format_srl_last(x, pred, conf, label=None, meta=None):
    """Summary formatter: return the 'description' of the LAST verb frame in ``pred``.

    Only ``pred`` is read; the remaining parameters are accepted and ignored.
    """
    verb_frames = pred['verbs']
    return verb_frames[-1]['description']

def format_srl_first(x, pred, conf, label=None, meta=None):
    """Summary formatter: return the 'description' of the FIRST verb frame in ``pred``.

    Only ``pred`` is read; the remaining parameters are accepted and ignored.
    """
    verb_frames = pred['verbs']
    return verb_frames[0]['description']


In [137]:
# get target arg functions

def get_argTMP(pred, arg_target='ARGM-TMP'):
    """Return the words whose tag ends with ``arg_target`` for the first predicate.

    Parameters
    ----------
    pred : dict
        A prediction with a 'words' list and a 'verbs' list; each verb frame
        carries a 'tags' list that is zipped against 'words'.
    arg_target : str
        Tag suffix to look for (default 'ARGM-TMP').

    Returns
    -------
    str
        The matching words joined by single spaces, or '' when nothing
        matches or when ``pred['verbs']`` is empty.
    """
    verb_frames = pred['verbs']
    if not verb_frames:
        # No predicate was detected: report "no argument" instead of raising
        # IndexError, so the calling expectation simply fails for this example.
        return ''

    # we assume one predicate, so only the first frame is inspected
    tags = verb_frames[0]['tags']
    return ' '.join(
        word for tag, word in zip(tags, pred['words']) if tag.endswith(arg_target)
    )

def get_arg2(pred, arg_target='ARG2'):
    """Return the words whose tag ends with ``arg_target`` for the first predicate.

    Parameters
    ----------
    pred : dict
        A prediction with a 'words' list and a 'verbs' list; each verb frame
        carries a 'tags' list that is zipped against 'words'.
    arg_target : str
        Tag suffix to look for (default 'ARG2').

    Returns
    -------
    str
        The matching words joined by single spaces, or '' when nothing
        matches or when ``pred['verbs']`` is empty.
    """
    verb_frames = pred['verbs']
    if not verb_frames:
        # No predicate was detected: report "no argument" instead of raising
        # IndexError, so the calling expectation simply fails for this example.
        return ''

    # we assume one predicate, so only the first frame is inspected
    tags = verb_frames[0]['tags']
    return ' '.join(
        word for tag, word in zip(tags, pred['words']) if tag.endswith(arg_target)
    )

def get_dative_args(pred):
    """Return the ARG0, ARG1 and ARG2 word spans of the first predicate.

    Parameters
    ----------
    pred : dict
        A prediction with a 'words' list and a 'verbs' list; each verb frame
        carries a 'tags' list that is zipped against 'words'.

    Returns
    -------
    tuple of (str, str, str)
        The words whose tag ends with 'ARG0', 'ARG1' and 'ARG2' respectively,
        each joined by single spaces. A role that is absent yields ''. When
        ``pred['verbs']`` is empty all three strings are ''.
    """
    roles = ('ARG0', 'ARG1', 'ARG2')

    verb_frames = pred['verbs']
    if not verb_frames:
        # No predicate was detected: return empty spans instead of raising
        # IndexError, so the calling expectation simply fails for this example.
        return '', '', ''

    # we assume one predicate, so only the first frame is inspected
    tags = verb_frames[0]['tags']

    # one pass over the sentence, collecting the words of each role
    collected = {role: [] for role in roles}
    for tag, word in zip(tags, pred['words']):
        for role in roles:
            if tag.endswith(role):
                collected[role].append(word)

    return tuple(' '.join(collected[role]) for role in roles)
            
def get_arg_span_first(pred, target_span=()):
    """Return the role labels the FIRST predicate assigns to the target words.

    Parameters
    ----------
    pred : dict
        A prediction with a 'words' list and a 'verbs' list; each verb frame
        carries a 'tags' list that is zipped against 'words'.
    target_span : str or collection of str
        The word(s) to look up. A plain string is split on whitespace into
        tokens; any other value is used as the collection of tokens directly.

    Returns
    -------
    list of str
        One label per sentence word that is a target token, in sentence
        order. [] when nothing matches or when ``pred['verbs']`` is empty.
    """
    verb_frames = pred['verbs']
    if not verb_frames:
        # no predicate detected: nothing to label (avoids IndexError)
        return []

    # The callers pass a single string. Testing `word in "<string>"` would be a
    # substring test (e.g. 'at' is "in" 'boat'), so unrelated words would be
    # picked up. Compare whole tokens instead.
    if isinstance(target_span, str):
        target_words = target_span.split()
    else:
        target_words = target_span

    # we assume one predicate, so only the first frame is inspected
    tags = verb_frames[0]['tags']

    arg_list = []
    for tag, word in zip(tags, pred['words']):
        if word in target_words:
            # keep the second '-'-separated field ('B-ARG1' -> 'ARG1'); a tag
            # without '-' (e.g. 'O') is kept as is.
            # NOTE(review): this reduces 'B-ARGM-TMP' to 'ARGM' -- confirm that
            # dropping the modifier subtype is intended.
            arg_list.append(tag.split('-')[1] if '-' in tag else tag)
    return arg_list

def get_arg_span_last(pred, target_span=()):
    """Return the role labels the LAST predicate assigns to the target words.

    Parameters
    ----------
    pred : dict
        A prediction with a 'words' list and a 'verbs' list; each verb frame
        carries a 'tags' list that is zipped against 'words'.
    target_span : str or collection of str
        The word(s) to look up. A plain string is split on whitespace into
        tokens; any other value is used as the collection of tokens directly.

    Returns
    -------
    list of str
        One label per sentence word that is a target token, in sentence
        order. [] when nothing matches or when ``pred['verbs']`` is empty.
    """
    verb_frames = pred['verbs']
    if not verb_frames:
        # no predicate detected: nothing to label (avoids IndexError)
        return []

    # The callers pass a single string. Testing `word in "<string>"` would be a
    # substring test (e.g. 'at' is "in" 'boat'), so unrelated words would be
    # picked up. Compare whole tokens instead.
    if isinstance(target_span, str):
        target_words = target_span.split()
    else:
        target_words = target_span

    # only the last verb frame is inspected
    tags = verb_frames[-1]['tags']

    arg_list = []
    for tag, word in zip(tags, pred['words']):
        if word in target_words:
            # keep the second '-'-separated field ('B-ARG1' -> 'ARG1'); a tag
            # without '-' (e.g. 'O') is kept as is.
            # NOTE(review): this reduces 'B-ARGM-TMP' to 'ARGM' -- confirm that
            # dropping the modifier subtype is intended.
            arg_list.append(tag.split('-')[1] if '-' in tag else tag)
    return arg_list


In [138]:
# expectation functions

def expect_argTMP(x, pred, conf, label=None, meta=None):
    """Pass when the span returned by ``get_argTMP(pred)`` is exactly 'last week'.

    Only ``pred`` is read; the remaining parameters are accepted and ignored.
    """
    return get_argTMP(pred) == 'last week'

def expect_arg2(x, pred, conf, label=None, meta=None):
    """Pass when the span returned by ``get_arg2(pred)`` is exactly 'with a knife'.

    Only ``pred`` is read; the remaining parameters are accepted and ignored.
    """
    return get_arg2(pred) == 'with a knife'

def expect_dativeIOC(x, pred, conf, label=None, meta=None):
    """Expectation for the '<name1> gave <mask> to <name2>' sentences.

    Passes when, for the spans returned by ``get_dative_args(pred)``:
    ARG0 equals ``meta['first_name1']``, ARG1 equals the mask filler, and
    ARG2 equals 'to ' + ``meta['first_name2']``.

    Only ``pred`` and ``meta`` are read; the other parameters are ignored.
    Returns a bool.
    """
    arg0, arg1, arg2 = get_dative_args(pred)

    filler = meta['mask']
    # ' '.join on a plain string would interleave spaces between its
    # characters ('book' -> 'b o o k'), which could never equal ARG1.
    # NOTE(review): whether a single-mask template stores a plain string in
    # meta['mask'] should be confirmed against CheckList; a sequence of words
    # is joined exactly as before.
    expected_arg1 = filler if isinstance(filler, str) else ' '.join(filler)

    return (
        arg0 == meta['first_name1']
        and arg1 == expected_arg1
        and arg2 == 'to ' + meta['first_name2']
    )

def expect_dativeDOC(x, pred, conf, label=None, meta=None):
    """Expectation for the '<name1> gave <name2> <mask>' sentences.

    Passes when, for the spans returned by ``get_dative_args(pred)``:
    ARG0 equals ``meta['first_name1']``, ARG1 equals the mask filler, and
    ARG2 equals ``meta['first_name2']``.

    Only ``pred`` and ``meta`` are read; the other parameters are ignored.
    Returns a bool.
    """
    arg0, arg1, arg2 = get_dative_args(pred)

    filler = meta['mask']
    # ' '.join on a plain string would interleave spaces between its
    # characters ('book' -> 'b o o k'), which could never equal ARG1.
    # NOTE(review): whether a single-mask template stores a plain string in
    # meta['mask'] should be confirmed against CheckList; a sequence of words
    # is joined exactly as before.
    expected_arg1 = filler if isinstance(filler, str) else ' '.join(filler)

    return (
        arg0 == meta['first_name1']
        and arg1 == expected_arg1
        and arg2 == meta['first_name2']
    )

def compare_spans(orig_pred, pred, orig_conf, conf, labels=None, meta=None):
    """Pairwise expectation: the ``meta['mask1']`` filler gets the same label set in both sentences.

    Both predictions are inspected through their first verb frame
    (``get_arg_span_first``). Returns True when the two label sets are equal.
    """
    target = meta['mask1']
    roles_before = set(get_arg_span_first(orig_pred, target))
    roles_after = set(get_arg_span_first(pred, target))
    return roles_before == roles_after

def compare_voice(orig_pred, pred, orig_conf, conf, labels=None, meta=None):
    """Pairwise expectation: the ``meta['mask1']`` filler gets the same label set in both sentences.

    The original prediction is inspected through its LAST verb frame
    (``get_arg_span_last``), the other prediction through its FIRST verb frame
    (``get_arg_span_first``). Returns True when the two label sets are equal.
    """
    target = meta['mask1']
    roles_before = set(get_arg_span_last(orig_pred, target))
    roles_after = set(get_arg_span_first(pred, target))
    return roles_before == roles_after

# Turn the plain functions above into CheckList expectation objects.
# Expect.single is used for the checks that look at one prediction at a time;
# Expect.pairwise is used for the checks that take an (orig_pred, pred) pair.
expectTMP = Expect.single(expect_argTMP)
expect2 = Expect.single(expect_arg2)
expectIOC = Expect.single(expect_dativeIOC)
expectDOC = Expect.single(expect_dativeDOC)
expectSpan = Expect.pairwise(compare_spans)
expectVoice = Expect.pairwise(compare_voice)

In [139]:
# load checklist Suite and Editor

# suite collects the tests for the in-notebook summary; editor generates the
# templated sentences
suite = TestSuite()
editor = Editor()

# dict that maps a subset name to its generated data; dumped to JSON after the
# data has been created
data_dict = {}

In [140]:
# create data
#
# Every subset is generated with editor.template and stored in data_dict under
# the key that the test-construction cell later reads back from the JSON file.

# long-range dependency spans: six {mask} tokens are inserted between the
# object and the sentence-final argument

## non-core role (the expectation looks for 'last week')
set1 = editor.template('The killer killed the victim, {mask} {mask} {mask} {mask} {mask} {mask}, last week.', meta=True, save=True)
data_dict['span non-core'] = set1  # save to data dict for storing

## core role (the expectation looks for 'with a knife')
set2 = editor.template('The killer killed the victim, {mask} {mask} {mask} {mask} {mask} {mask}, with a knife.', meta=True, save=True)
data_dict['span core'] = set2

# complex sentence structures: six {mask} tokens are inserted between the
# subject and the verb

## non-core role
set3 = editor.template('The killer, {mask} {mask} {mask} {mask} {mask} {mask}, killed the victim last week.', meta=True, save=True)
data_dict['complex non-core'] = set3

## core role
set4 = editor.template('The killer, {mask} {mask} {mask} {mask} {mask} {mask}, killed the victim with a knife.', meta=True, save=True)
data_dict['complex core'] = set4

# verb alternation

## dative alternation DOC: recipient directly after the verb, followed by one
## to five {mask} tokens (20 samples per length)
set5 = editor.template('{first_name1} gave {first_name2} {mask}.', nsamples=20, meta=True, save=True)
set5 += editor.template('{first_name1} gave {first_name2} {mask} {mask}.', nsamples=20, meta=True, save=True)
set5 += editor.template('{first_name1} gave {first_name2} {mask} {mask} {mask}.', nsamples=20, meta=True, save=True)
set5 += editor.template('{first_name1} gave {first_name2} {mask} {mask} {mask} {mask}.', nsamples=20, meta=True, save=True)
set5 += editor.template('{first_name1} gave {first_name2} {mask} {mask} {mask} {mask} {mask}.', nsamples=20, meta=True, save=True)
data_dict['dative DOC'] = set5

## dative alternation IOC: one to five {mask} tokens after the verb, recipient
## in a trailing 'to' phrase (20 samples per length)
set6 = editor.template('{first_name1} gave {mask} to {first_name2}.', nsamples=20, meta=True, save=True)
set6 += editor.template('{first_name1} gave {mask} {mask} to {first_name2}.', nsamples=20, meta=True, save=True)
set6 += editor.template('{first_name1} gave {mask} {mask} {mask} to {first_name2}.', nsamples=20, meta=True, save=True)
set6 += editor.template('{first_name1} gave {mask} {mask} {mask} {mask} to {first_name2}.', nsamples=20, meta=True, save=True)
set6 += editor.template('{first_name1} gave {mask} {mask} {mask} {mask} {mask} to {first_name2}.', nsamples=20, meta=True, save=True)
data_dict['dative IOC'] = set6

## causative/inchoative alternation: sentence pairs sharing the same {mask1}
set7 = editor.template(['Mary broke the {mask1}.', 'The {mask1} broke.'], nsamples=100, meta=True, save=True)
data_dict['inchoative'] = set7

# voice + robustness

## voice comparison: passive/active sentence pairs sharing the same {mask1}
set8 = editor.template(['The {mask1} was written.', 'She wrote the {mask1}.'], nsamples=20, meta=True, save=True)
set8 += editor.template(['The {mask1} was seen.', 'They saw the {mask1}.'], nsamples=20, meta=True, save=True)
set8 += editor.template(['The {mask1} was felt.', 'I felt the {mask1}.'], nsamples=20, meta=True, save=True)
set8 += editor.template(['The {mask1} was heard.', 'I heard the {mask1}.'], nsamples=20, meta=True, save=True)
set8 += editor.template(['The {mask1} was created.', 'I created the {mask1}.'], nsamples=20, meta=True, save=True)
# typo-perturbed copy of the voice data for a robustness check
# NOTE(review): set8_typo is not stored in data_dict and is not referenced
# again in the visible cells, so only the unperturbed set8 reaches the saved
# challenge set -- confirm whether the typo data was meant to be tested.
set8_typo = Perturb.perturb(set8.data, Perturb.add_typos)
data_dict['voice/robustness'] = set8

# store data dict containing full challenge set as json file
with open('challenge_set.json', 'w') as json_file:
    json.dump(data_dict, json_file)

In [141]:
# open challenge set from json file and load into dict for use
# (the tests below are built from this reloaded copy, not from data_dict)

with open('challenge_set.json', 'r') as f:
    challenge_set = json.load(f)

In [142]:
# initialize test objects
#
# Each test is built from one reloaded subset, registered in `suite` (used for
# the in-notebook summary) and appended to `individual_tests` (used by the
# result-export loop).

# create list to store tests
individual_tests = []

# define challenge subsets from full challenge set dict
ret1 = challenge_set['span non-core'] 
ret2 = challenge_set['span core'] 
ret3 = challenge_set['complex non-core'] 
ret4 = challenge_set['complex core'] 
ret5 = challenge_set['dative DOC'] 
ret6 = challenge_set['dative IOC'] 
ret7 = challenge_set['inchoative'] 
ret8 = challenge_set['voice/robustness']


# long-range dependency spans

# MFT non-core role
test1 = MFT(**ret1, name = 'span non-core role', capability='Long-range dependency spans', expect=expectTMP)
suite.add(test1, format_example_fn=format_srl_first, overwrite=True)   # add to test suite for easy running and viewing in the notebook
individual_tests.append(test1)      # add tests to list for extracting and storing results

# MFT core role
test2 = MFT(**ret2, name = 'span core role', capability='Long-range dependency spans', expect=expect2)
suite.add(test2, format_example_fn=format_srl_first, overwrite=True)
individual_tests.append(test2)

# complex sentence structures
# (summary examples show the last verb frame here, via format_srl_last)

# MFT non-core role
test3 = MFT(**ret3, name = 'complex non-core role', capability='Complex sentence structures', expect=expectTMP)
suite.add(test3, format_example_fn=format_srl_last, overwrite=True)
individual_tests.append(test3)

# MFT core role
test4 = MFT(**ret4, name = 'complex core role', capability='Complex sentence structures', expect=expect2)
suite.add(test4, format_example_fn=format_srl_last, overwrite=True)
individual_tests.append(test4)

# verb alternation 

# MFT dative alternation DOC
test5 = MFT(**ret5, name = 'Dative shift DOC', capability='Verb alternation', expect=expectDOC)
suite.add(test5, format_example_fn=format_srl_first, overwrite=True)
individual_tests.append(test5)
# MFT dative alternation IOC
test6 = MFT(**ret6, name = 'Dative shift IOC', capability='Verb alternation', expect=expectIOC)
suite.add(test6, format_example_fn=format_srl_first, overwrite=True)
individual_tests.append(test6)
# DIR causative/inchoative alternation (pairwise expectation on the shared mask1 word)
test7 = DIR(**ret7, name='Inchoative/causative', capability='Verb alternation', 
           description='Takes an inchoative/causative pair and checks if the generated word has the same label.', expect=expectSpan)
suite.add(test7, format_example_fn=format_srl_first, overwrite=True)
individual_tests.append(test7)

# voice

# active/passive comparison, built from the 'voice/robustness' subset
# NOTE(review): this test is registered under capability 'Verb alternation'
# although it sits under the voice heading -- confirm the intended capability.
test8 = DIR(**ret8, name='voice comparison', capability='Verb alternation', 
           description='Takes an active/passive sentence pair and checks if the generated word has the same label.', expect=expectVoice)
suite.add(test8, format_example_fn=format_srl_last, overwrite=True)
individual_tests.append(test8)

In [143]:
# Maps each test name to {'Passed': [...], 'Failed': [...]} predictions.
result_dict = {}

# Run every test with the wrapper_srl model, split its predictions by outcome
# and finally dump everything to test_results.json.
for test in individual_tests:
    name = test.name
    test.run(wrapper_srl, overwrite=True)

    outcomes = list(test.results.passed)    # one pass/fail flag per test case
    raw_preds = list(test.results.preds)    # one entry per test case

    # numpy arrays cannot be dumped to JSON: an array entry is replaced by its
    # first element, every other entry is kept unchanged
    json_preds = [
        list(entry)[0] if isinstance(entry, np.ndarray) else entry
        for entry in raw_preds
    ]

    # explicit `== True` / `== False` comparisons: an outcome equal to neither
    # value ends up in neither list
    result_dict[name] = {
        'Passed': [p for p, ok in zip(json_preds, outcomes) if ok == True],
        'Failed': [p for p, ok in zip(json_preds, outcomes) if ok == False],
    }

with open('test_results.json', 'w') as json_file:    # store results in json file
    json.dump(result_dict, json_file)

Predicting 100 examples
Predicting 100 examples
Predicting 100 examples
Predicting 100 examples
Predicting 100 examples
Predicting 100 examples
Predicting 200 examples
Predicting 200 examples


In [147]:
# Run the whole suite with the LSTM model (wrapper_srl) and print a summary
# with up to 10 examples per test; the same model's per-test predictions are
# stored in test_results.json.

suite.run(wrapper_srl, overwrite=True)
print()
suite.summary(n=10)

Running span non-core role
Predicting 100 examples
Running span core role
Predicting 100 examples
Running complex non-core role
Predicting 100 examples
Running complex core role
Predicting 100 examples
Running Dative shift DOC
Predicting 100 examples
Running Dative shift IOC
Predicting 100 examples
Running Inchoative/causative
Predicting 200 examples
Running voice comparison
Predicting 200 examples

Verb alternation

Dative shift DOC
Test cases:      100
Fails (rate):    20 (20.0%)

Example fails:
[ARG0: Victoria] [V: gave] [ARG1: Eleanor flowers] .
----
[ARG0: Nancy] [V: gave] [ARG2: Victoria] [ARGM-TMP: a few minutes] .
----
[ARG0: Rachel] [V: gave] [ARG1: Eleanor presents] .
----
[ARG0: Al] [V: gave] [ARG1: Julia] away .
----
[ARG0: Sally] [V: gave] [ARG2: Robert CPR] .
----
[ARG0: Steve] [V: gave] [ARG1: Don something to try] .
----
[ARG0: Victoria] [V: gave] [ARG1: Eleanor " "] .
----
[ARG0: Steve] [V: gave] [ARG1: Don handcuffs] .
----
[ARG0: Christopher] [V: gave] [ARG1: Frances

In [None]:
# Run the whole suite with the BERT model (wrapper_bert) and print a summary
# with up to 10 examples per test.
# NOTE(review): the export loop that writes test_results.json runs the tests
# with wrapper_srl only, so that file does not contain these BERT results.

suite.run(wrapper_bert, overwrite=True)
print()
suite.summary(n=10)

Running span non-core role
Predicting 100 examples
Running span core role
Predicting 100 examples
Running complex non-core role
Predicting 100 examples
Running complex core role
Predicting 100 examples
Running Dative shift DOC
Predicting 100 examples
Running Dative shift IOC
Predicting 100 examples
Running Inchoative/causative
Predicting 200 examples
Running voice comparison
Predicting 200 examples

Verb alternation

Dative shift DOC
Test cases:      100
Fails (rate):    14 (14.0%)

Example fails:
[ARG0: Sally] [V: gave] [ARG2: Robert CPR] .
----
[ARG0: Christopher] [V: gave] [ARG2: Frances] [ARG1: a kiss] [ARGM-LOC: on the table] .
----
[ARG0: Sara] [V: gave] [ARG2: Anthony] [ARG1: a kiss] [ARGM-LOC: on the porch] .
----


Dative shift IOC
Test cases:      100
Fails (rate):    20 (20.0%)

Example fails:
[ARG0: Al] [V: gave] up [ARG1: to Julia] .
----
[ARG0: Jay] [V: gave] [ARG1: it all] [ARGM-MNR: straight] [ARG2: to Donald] .
----
[ARG0: Christopher] [V: gave] [ARG1: it] out [ARG2: t