
## Experiment 2: Using NLI to obtain constraints between WDIK facts and raw answers, and solving them using Z3 to flip answers if required

In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# project modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import experiment_setup

In [4]:
from belief.evaluation import load_facts
from belief.nli import load_nli_model, load_nli_tokenizer, run_nli
from belief.lmbb import Proposition, LMBB
from tqdm import tqdm
import json
import random
from z3 import *

In [5]:
# Calibration facts: load_facts is asked for a single batch, and that batch is kept.
fact_batches = load_facts('data/calibration_facts.json', num_batches=1)
facts = fact_batches[0]

# Cached WDIK entries; the solving loop indexes this by a proposition's subject.
with open('cache/wdik.json', 'r') as wdik_file:
    wdik = json.load(wdik_file)

# Constraint file; its 'links' entry is what the evaluator receives.
with open('data/constraints_v2.json', 'r') as constraints_file:
    constraint_data = json.load(constraints_file)

# Cached raw outputs for the calibration set; only the 'outs' entry is needed.
with open('cache/raw_outs_calib.json', 'r') as raw_outs_file:
    raw_outs = json.load(raw_outs_file)['outs']

In [6]:
# The evaluator is built without a language model or tokenizer; in this
# notebook it is only used to score beliefs against the constraint links.
constraint_links = constraint_data['links']
evaluator = LMBB(model=None, tokenizer=None, raw_constraints=constraint_links)

In [7]:
# Index the raw outputs by proposition sentence, keeping just the 'yes'/'no'
# values. If a sentence appears more than once, the last entry wins.
raw_outs_dict = {
    entry['prop']: {'yes': entry['yes'], 'no': entry['no']}
    for entry in raw_outs
}

In [9]:
# Load the NLI model first, then its tokenizer (same call order as before);
# both are handed to run_nli in the solving loop.
nli_model, nli_tokenizer = load_nli_model(), load_nli_tokenizer()

[autoreload of google.protobuf.descriptor failed: Traceback (most recent call last):
  File "/home2/kushal/miniconda3/envs/ID/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 261, in check
    superreload(m, reload, self.old_objects)
  File "/home2/kushal/miniconda3/envs/ID/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 484, in superreload
    update_generic(old_obj, new_obj)
  File "/home2/kushal/miniconda3/envs/ID/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 381, in update_generic
    update(a, b)
  File "/home2/kushal/miniconda3/envs/ID/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 349, in update_class
    update_instances(old, new)
  File "/home2/kushal/miniconda3/envs/ID/lib/python3.9/site-packages/IPython/extensions/autoreload.py", line 307, in update_instances
    object.__setattr__(ref, "__class__", new)
TypeError: can't apply this __setattr__ to DescriptorMetaclass object
]


In [10]:
# Tunable settings for the solving loop.

# Scale factor applied to every soft-constraint weight before it is truncated
# to an integer.
z3mult = 1_000

# Unscaled weight of the soft constraint asserting that a WDIK fact is true.
wdik_weight = 1.5

# Number of WDIK facts sampled per subject; a value <= 0 uses every
# available fact instead of sampling.
NUM_FACTS = 3

In [11]:
# For every calibration fact, build a small weighted MaxSAT problem in Z3:
#   * the raw 'yes' / 'no' values are soft preferences on the fact itself,
#   * each sampled WDIK fact is softly asserted to be true,
#   * NLI scores in both directions add soft links between the WDIK fact
#     and the fact under test.
# The truth value Z3 assigns to the fact becomes the new belief.

# Dedicated, seeded generator so that Restart & Run All reproduces the same
# WDIK subsets (and therefore the same metrics) without touching the global
# `random` state.
SAMPLE_SEED = 0
sampler = random.Random(SAMPLE_SEED)

new_beliefs = {}

for prop in tqdm(facts):
    sent = prop.sentence
    raw_out = raw_outs_dict[sent]
    subject = prop.subject
    subject_facts = wdik[subject]
    if NUM_FACTS > 0:
        # Cap the request at what is available: sample() raises ValueError
        # when asked for more items than the population holds.
        wdik_facts = sampler.sample(subject_facts, min(NUM_FACTS, len(subject_facts)))
    else:
        wdik_facts = subject_facts

    bools = {}
    optim = Optimize()

    # Soft preference for the raw answer; weights are scaled by z3mult and
    # truncated to int.
    bools[sent] = Bool(sent)
    optim.add_soft(bools[sent], int(raw_out['yes'] * z3mult))
    optim.add_soft(Not(bools[sent]), int(raw_out['no'] * z3mult))

    assertion = prop.get_assertion()

    for wik in wdik_facts:
        bools[wik] = Bool(wik)
        # Integer weight, consistent with every other soft constraint here.
        optim.add_soft(bools[wik], int(wdik_weight * z3mult))

        # NOTE(review): Not(Implies(a, b)) is equivalent to (a AND NOT b), which
        # is stronger than the usual reading of contradiction, Implies(a, Not(b)).
        # Kept unchanged to preserve the experiment's semantics -- confirm intent.

        # wik -> fact
        nli_out = run_nli(premise=wik, hypothesis=assertion, model=nli_model, tokenizer=nli_tokenizer)
        optim.add_soft(Implies(bools[wik], bools[sent]), int(nli_out['entailment'] * z3mult))
        optim.add_soft(Not(Implies(bools[wik], bools[sent])), int(nli_out['contradiction'] * z3mult))

        # fact -> wik
        nli_out = run_nli(premise=assertion, hypothesis=wik, model=nli_model, tokenizer=nli_tokenizer)
        optim.add_soft(Implies(bools[sent], bools[wik]), int(nli_out['entailment'] * z3mult))
        optim.add_soft(Not(Implies(bools[sent], bools[wik])), int(nli_out['contradiction'] * z3mult))

    optim.check()
    mod = optim.model()

    # model_completion=True forces a concrete True/False even if the solver left
    # the variable unassigned; is_true avoids casting a symbolic expression.
    fact_value = is_true(mod.evaluate(bools[sent], model_completion=True))
    new_beliefs[sent] = Proposition.from_sent(sent, boolean=fact_value)

100%|███████████████████████████████████████| 1072/1072 [02:15<00:00,  7.93it/s]


In [12]:
# Install the solver-adjusted beliefs, then report F1 against the calibration
# facts followed by the consistency score.
evaluator.set_beliefs(new_beliefs)

f1_score = evaluator.calculate_f1(facts)
print("F1:", f1_score)

consistency_score = evaluator.calculate_consistency()
print("Consistency:", consistency_score)

F1: 0.862876249181329
Consistency: 0.980098481739844
