## Experiment 3: Using WDIK as feedback and constraint solving within the batch

In [2]:
# Enable IPython's autoreload extension; mode 2 re-imports all modules
# before each cell executes, so edits to the local project code are picked up
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [3]:
import experiment_setup

In [4]:
import copy
import json
import random
from collections import defaultdict

from tqdm import tqdm
from z3 import *

from belief.evaluation import load_facts
from belief.lmbb import Proposition, LMBB
from belief.nli import load_nli_model, load_nli_tokenizer, run_nli
from belief.utils import load_macaw, load_tokenizer
from belief.utils import macaw_input, run_macaw, get_macaw_scores, get_macaw_outs

In [5]:
def _read_json(path):
    """Parse and return the JSON document stored at ``path``.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the locale's default encoding.
    """
    with open(path, 'r', encoding='utf-8') as handle:
        return json.load(handle)


# Facts to evaluate against: only the first element of what load_facts returns.
facts = load_facts('data/calibration_facts.json', num_batches=1)[0]

# Indexed by entity; each value is a collection of strings that are sampled
# and joined into the QA context further down.
wdik = _read_json('cache/wdik.json')

# Its 'links' entry is handed to the evaluator as the raw constraints.
constraint_data = _read_json('data/constraints_v2.json')

# Indexed as [sentence_a][sentence_b]; each entry is read for its
# 'entailment' and 'contradiction' scores.
nli_constraints = _read_json('cache/nli_constraints.json')

In [6]:
# Bucket the facts by their subject so every entity can be handled on its own.
facts_by_entity = defaultdict(list)
for entry in facts:
    bucket = facts_by_entity[entry.subject]
    bucket.append(entry)

In [7]:
# The evaluator is built without a model or tokenizer; in the visible cells it
# is only asked to hold beliefs and compute F1 / consistency.
raw_links = constraint_data['links']
evaluator = LMBB(model=None, tokenizer=None, raw_constraints=raw_links)

In [8]:
# QA model and its tokenizer, used to score every yes/no question.
model, tokenizer = load_macaw(), load_tokenizer()

# NLI model and tokenizer. NOTE(review): the visible cells read NLI scores
# from the cached `nli_constraints` and never reference these two objects.
nli_model, nli_tokenizer = load_nli_model(), load_nli_tokenizer()




Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home2/abhijit.manatkar/miniconda3/envs/advnlp/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda116.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 116
CUDA SETUP: Loading binary /home2/abhijit.manatkar/miniconda3/envs/advnlp/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda116.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


In [26]:
# Tunable settings for the experiment loop.

# Seed the global RNG so the random context sampling is repeatable when this
# cell is run before the experiment loop.
SEED = 42
random.seed(SEED)

# Factor applied to float scores before truncating them to the integer
# weights handed to the solver.
z3mult = 1000

# How many entries of wdik[entity] are sampled as context for each question.
NUM_FACTS = 3

# Not referenced by the cells visible in this notebook.
NUM_CONSTRAINTS = 200

# Extra multiplier applied to the weights of the NLI-derived constraints.
constraint_mult = 1

# True: add a soft clause for both the 'yes' and the 'no' score of a fact.
# False: add only the clause for whichever of the two scores is larger.
both = True

# Not referenced by the cells visible in this notebook.
nn = 50

In [27]:
# Per-entity pipeline:
#   1. ask the QA model for 'yes'/'no' scores on every fact of the entity,
#   2. turn those scores and the high-scoring cached NLI relations into
#      weighted soft clauses,
#   3. let Z3 pick the truth assignment that maximises the satisfied weight,
#   4. hand the resulting beliefs to the evaluator and report F1 / consistency.
yes_no = ['yes', 'no']

# Only NLI relations whose cached score exceeds this value become constraints.
NLI_THRESHOLD = 0.9

cumulative_ground_truth = []

for entity, entity_facts in facts_by_entity.items():

    print(f"### Entity = {entity} ###")

    initial_outs = {}

    print("Getting initial outs from QA model...")
    for fact in tqdm(entity_facts):
        question = fact.get_question()
        # Cap the sample size: random.sample raises ValueError when asked for
        # more items than the population holds.
        known = wdik[entity]
        context = ' '.join(random.sample(known, min(NUM_FACTS, len(known))))
        inp_str = macaw_input(question=question, options=yes_no, context=context, targets='A')
        initial_outs[fact.sentence] = get_macaw_scores(inp_str, yes_no, model, tokenizer)

    # One Boolean variable per sentence, softly pulled towards the QA answer.
    optim = Optimize()
    bools = {}
    for sent, scores in initial_outs.items():
        var = Bool(sent)
        bools[sent] = var
        yes_weight = int(scores['yes'] * z3mult)
        no_weight = int(scores['no'] * z3mult)
        if both:
            optim.add_soft(var, yes_weight)
            optim.add_soft(Not(var), no_weight)
        elif scores['yes'] > scores['no']:
            optim.add_soft(var, yes_weight)
        else:
            optim.add_soft(Not(var), no_weight)

    # Collect only the relations above the threshold; the low-scoring ones
    # were never added to the solver, so there is no need to store or sort them.
    constraints = []

    print("Getting NLI outs...")
    for fact1 in entity_facts:
        for fact2 in entity_facts:
            if fact1.sentence == fact2.sentence:
                continue

            nli_scores = nli_constraints[fact1.sentence][fact2.sentence]

            for relation in ('entailment', 'contradiction'):
                if nli_scores[relation] > NLI_THRESHOLD:
                    constraints.append({
                        "type": relation,
                        "src": fact1.sentence,
                        "dest": fact2.sentence,
                        "weight": nli_scores[relation]
                    })

    # Strongest relations are added first (stable sort keeps pair order on ties).
    constraints.sort(key=lambda c: -c['weight'])
    for constraint in constraints:
        src = bools[constraint['src']]
        dest = bools[constraint['dest']]
        # Entailment: src -> dest.  Contradiction: src -> not dest.
        consequent = dest if constraint['type'] == 'entailment' else Not(dest)
        optim.add_soft(
            Implies(src, consequent),
            int(constraint['weight'] * constraint_mult * z3mult)
        )

    print("Solving MaxSAT problem...")
    result = optim.check()
    if result != sat:
        # optim.model() is only meaningful after a satisfiable check.
        raise RuntimeError(f"Z3 returned {result} for entity {entity!r}; no model available")
    mod = optim.model()

    updated_beliefs = {}
    for fact in entity_facts:
        new_fact = copy.deepcopy(fact)
        # model_completion makes the solver assign a concrete value even to a
        # variable it left unconstrained; is_true converts that to a Python
        # bool without raising on a symbolic expression.
        new_fact.boolean = is_true(mod.evaluate(bools[fact.sentence], model_completion=True))
        updated_beliefs[new_fact.sentence] = new_fact

    evaluator.set_beliefs(updated_beliefs)
    cumulative_ground_truth += entity_facts
    # NOTE(review): the ground truth accumulates over all entities seen so far
    # while set_beliefs receives only this entity's beliefs -- confirm that
    # LMBB.set_beliefs merges rather than replaces, otherwise F1 is scored
    # against facts the evaluator no longer holds.
    f1 = evaluator.calculate_f1(cumulative_ground_truth)
    consistency = evaluator.calculate_consistency()
    print(f"F1 = {f1}, Consistency = {consistency}")
    print()
            

### Entity = adder ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 139/139 [01:13<00:00,  1.88it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.6538461489571006, Consistency = 0.8974148543290932

### Entity = albatross ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 156/156 [01:23<00:00,  1.88it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.7259259210491085, Consistency = 0.906442347148133

### Entity = daffodil ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 143/143 [01:16<00:00,  1.87it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.7448979542956061, Consistency = 0.9113664341403365

### Entity = cypress ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 163/163 [01:26<00:00,  1.88it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.7438596442848877, Consistency = 0.939474764054165

### Entity = ape ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 135/135 [01:11<00:00,  1.90it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.7669616470340496, Consistency = 0.961427985227739

### Entity = computer ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 189/189 [01:39<00:00,  1.90it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.7526881672282113, Consistency = 0.9534263438654083

### Entity = ant ###
Getting initial outs from QA model...


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 147/147 [01:17<00:00,  1.90it/s]


Getting NLI outs...
Solving MaxSAT problem...
F1 = 0.7457627070627499, Consistency = 0.964710709889208

