In [1]:
%load_ext autoreload
%autoreload 2
import experiment_setup
from belief.evaluation import load_facts
from belief.nli import load_nli_model, load_nli_tokenizer, run_nli
from belief.lmbb import Proposition, LMBB
from tqdm import tqdm
import json
import random
from z3 import *

In [2]:
def _read_json(path):
    """Parse the JSON document stored at ``path`` and return it."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Calibration propositions: one batch is requested and that batch is used.
facts = load_facts('data/calibration_facts.json', num_batches=1)[0]

# Cached facts, indexed by proposition subject further down in the notebook.
wdik = _read_json('cache/wdik.json')

# Constraint file; its 'links' entry is handed to the LMBB evaluator later.
constraint_data = _read_json('data/constraints_v2.json')

# Cached raw outputs for the calibration facts (the 'outs' entry of the cache).
raw_outs = _read_json('cache/raw_outs_calib.json')['outs']

In [4]:
# Evaluator built without a model or tokenizer; it only receives the raw
# constraint links from the constraints file.
evaluator = LMBB(model=None, tokenizer=None, raw_constraints=constraint_data['links'])

# NLI model and tokenizer passed to every run_nli call below.
nli_model = load_nli_model()
nli_tokenizer = load_nli_tokenizer()

# Index the cached raw outputs by proposition sentence -> {'yes': ..., 'no': ...}.
raw_outs_dict = {
    record['prop']: {'yes': record['yes'], 'no': record['no']}
    for record in raw_outs
}



In [5]:
# --- Experiment configuration ------------------------------------------------
z3mult = 1000        # scores are multiplied by this and truncated to int soft-constraint weights
wdik_weight = 1.5    # weight (before scaling) of the soft constraint asserting each sampled fact
NUM_FACTS = 3        # facts sampled per subject; a value <= 0 uses every fact of the subject
NLI_THRESHOLD = 0.9  # an NLI score must exceed this for its soft link to be added
SAMPLE_SEED = 0      # seed of the private RNG used for fact sampling (makes the cell repeatable)

sample_rng = random.Random(SAMPLE_SEED)
new_beliefs = {}

for prop in tqdm(facts):
    sent = prop.sentence
    raw_out = raw_outs_dict[sent]
    subject = prop.subject

    candidate_facts = wdik[subject]
    if NUM_FACTS > 0:
        # Clamp the sample size: random.sample raises ValueError when asked
        # for more items than the population holds.
        wdik_facts = sample_rng.sample(candidate_facts, min(NUM_FACTS, len(candidate_facts)))
    else:
        wdik_facts = candidate_facts

    # One independent weighted optimisation problem per proposition.
    bools = {}
    optim = Optimize()

    # Cached yes/no scores pull the proposition towards True / False.
    bools[sent] = Bool(sent)
    optim.add_soft(bools[sent], int(raw_out['yes'] * z3mult))
    optim.add_soft(Not(bools[sent]), int(raw_out['no'] * z3mult))

    assertion = prop.get_assertion()

    for wik in wdik_facts:
        bools[wik] = Bool(wik)
        # Integer weight, consistent with every other soft constraint here.
        optim.add_soft(bools[wik], int(wdik_weight * z3mult))

        # Both NLI directions are handled identically:
        # first wik -> fact, then fact -> wik.
        directions = (
            (wik, assertion, bools[wik], bools[sent]),
            (assertion, wik, bools[sent], bools[wik]),
        )
        for premise, hypothesis, antecedent, consequent in directions:
            nli_out = run_nli(premise=premise, hypothesis=hypothesis, model=nli_model, tokenizer=nli_tokenizer)
            link = Implies(antecedent, consequent)
            if nli_out['entailment'] > NLI_THRESHOLD:
                optim.add_soft(link, int(nli_out['entailment'] * z3mult))
            if nli_out['contradiction'] > NLI_THRESHOLD:
                # NOTE(review): Not(Implies(a, b)) is logically `a AND (NOT b)`,
                # which is stronger than "a contradicts b" (`a -> NOT b`) and, in
                # the fact -> wik direction, rewards the proposition being True.
                # Encoding kept unchanged -- confirm it is intended.
                optim.add_soft(Not(link), int(nli_out['contradiction'] * z3mult))

    optim.check()
    mod = optim.model()

    # model_completion=True guarantees a concrete True/False value; is_true
    # converts it to a Python bool without relying on expression truthiness.
    belief = is_true(mod.evaluate(bools[sent], model_completion=True))
    new_beliefs[sent] = Proposition.from_sent(sent, boolean=belief)

100%|███████████████████████████████████████| 1072/1072 [02:09<00:00,  8.28it/s]


In [6]:
# Hand the solver-derived beliefs to the evaluator, then report both metrics.
evaluator.set_beliefs(new_beliefs)

f1_score = evaluator.calculate_f1(facts)
print("F1:", f1_score)

consistency_score = evaluator.calculate_consistency()
print("Consistency:", consistency_score)

F1: 0.8716323246099944
Consistency: 0.9731226918342224


In [13]:
def exp_4(t, num_facts=None, seed=None):
    """Run the NLI-constrained belief revision with threshold ``t`` and print the scores.

    For each proposition in the notebook-level ``facts`` a separate z3
    ``Optimize`` problem is solved. It holds soft constraints for the cached
    yes/no scores of the proposition, one soft constraint asserting each
    sampled ``wdik`` fact, and soft links between fact and proposition for
    every NLI entailment/contradiction score above ``t``. The value the solver
    assigns to the proposition becomes its belief, and the resulting beliefs
    are scored with a fresh ``LMBB`` evaluator.

    Relies on notebook globals: ``facts``, ``wdik``, ``raw_outs_dict``,
    ``constraint_data``, ``nli_model``, ``nli_tokenizer``, ``z3mult``,
    ``wdik_weight`` and (unless ``num_facts`` is given) ``NUM_FACTS``.

    Args:
        t: Value an NLI ``'entailment'`` / ``'contradiction'`` score must
            strictly exceed for the corresponding soft link to be added.
        num_facts: Number of ``wdik`` facts sampled per subject; a value
            ``<= 0`` uses all of them. Defaults to the global ``NUM_FACTS``.
        seed: Optional seed for a private ``random.Random`` used for the
            sampling, which makes a run repeatable. With ``None`` (the
            default) the module-level ``random`` state is used, as before.

    Returns:
        None. The F1 and consistency scores are printed.
    """
    if num_facts is None:
        num_facts = NUM_FACTS
    # The ``random`` module itself exposes ``sample``, so it can stand in for
    # a Random instance when no seed is requested.
    sampler = random if seed is None else random.Random(seed)

    new_beliefs = {}
    evaluator = LMBB(
        model=None,
        tokenizer=None,
        raw_constraints=constraint_data['links'],
    )

    for prop in tqdm(facts):
        sent = prop.sentence
        raw_out = raw_outs_dict[sent]
        subject = prop.subject

        candidate_facts = wdik[subject]
        if num_facts > 0:
            # Clamp the sample size: random.sample raises ValueError when
            # asked for more items than the population holds.
            wdik_facts = sampler.sample(candidate_facts, min(num_facts, len(candidate_facts)))
        else:
            wdik_facts = candidate_facts

        # One independent weighted optimisation problem per proposition.
        bools = {}
        optim = Optimize()

        # Cached yes/no scores pull the proposition towards True / False.
        bools[sent] = Bool(sent)
        optim.add_soft(bools[sent], int(raw_out['yes'] * z3mult))
        optim.add_soft(Not(bools[sent]), int(raw_out['no'] * z3mult))

        assertion = prop.get_assertion()

        for wik in wdik_facts:
            bools[wik] = Bool(wik)
            # Integer weight, consistent with every other soft constraint here.
            optim.add_soft(bools[wik], int(wdik_weight * z3mult))

            # Both NLI directions are handled identically:
            # first wik -> fact, then fact -> wik.
            directions = (
                (wik, assertion, bools[wik], bools[sent]),
                (assertion, wik, bools[sent], bools[wik]),
            )
            for premise, hypothesis, antecedent, consequent in directions:
                nli_out = run_nli(premise=premise, hypothesis=hypothesis, model=nli_model, tokenizer=nli_tokenizer)
                link = Implies(antecedent, consequent)
                if nli_out['entailment'] > t:
                    optim.add_soft(link, int(nli_out['entailment'] * z3mult))
                if nli_out['contradiction'] > t:
                    # NOTE(review): Not(Implies(a, b)) is logically
                    # `a AND (NOT b)`, which is stronger than "a contradicts b"
                    # (`a -> NOT b`). Encoding kept unchanged -- confirm it is
                    # intended.
                    optim.add_soft(Not(link), int(nli_out['contradiction'] * z3mult))

        optim.check()
        mod = optim.model()

        # model_completion=True guarantees a concrete True/False value; is_true
        # converts it to a Python bool without relying on expression truthiness.
        belief = is_true(mod.evaluate(bools[sent], model_completion=True))
        new_beliefs[sent] = Proposition.from_sent(sent, boolean=belief)

    evaluator.set_beliefs(new_beliefs)
    print("F1:", evaluator.calculate_f1(facts))
    print("Consistency:", evaluator.calculate_consistency())

# Repeat the experiment with looser NLI thresholds, in this order.
for threshold in (0.8, 0.7):
    exp_4(threshold)


100%|███████████████████████████████████████| 1072/1072 [02:09<00:00,  8.27it/s]


F1: 0.8737864027475624
Consistency: 0.978046778826426


100%|███████████████████████████████████████| 1072/1072 [02:10<00:00,  8.20it/s]

F1: 0.8405315564582069
Consistency: 0.9741485432909315





In [14]:
# Repeat the experiment with stricter NLI thresholds, in this order.
for threshold in (0.92, 0.95):
    exp_4(threshold)


100%|███████████████████████████████████████| 1072/1072 [02:10<00:00,  8.22it/s]


F1: 0.8662420331922188
Consistency: 0.9772260976610587


100%|███████████████████████████████████████| 1072/1072 [02:09<00:00,  8.28it/s]

F1: 0.8607594886454495
Consistency: 0.973533032416906





In [15]:
# Very strict setting: only NLI scores above 0.99 add a soft link.
exp_4(0.99)

100%|███████████████████████████████████████| 1072/1072 [02:09<00:00,  8.29it/s]

F1: 0.8720238045065369
Consistency: 0.973533032416906



