In [145]:
%load_ext autoreload
%autoreload 2

In [37]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Config, AutoModelForSequenceClassification
from belief.evaluation import load_facts
import random
import string
from belief.macaw_utils import decompose_slots, compute_answer, run_model_with_outputs
import time
import json

In [38]:
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [4]:
tokenizer = AutoTokenizer.from_pretrained("allenai/macaw-large")
model = AutoModelForSeq2SeqLM.from_pretrained("allenai/macaw-large", load_in_8bit=True, device_map='auto')




Welcome to bitsandbytes. For bug reports, please run

python -m bitsandbytes

 and submit this information together with your error trace to: https://github.com/TimDettmers/bitsandbytes/issues
bin /home2/abhijit.manatkar/miniconda3/envs/advnlp/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda116.so
CUDA SETUP: CUDA runtime path found: /usr/local/cuda/lib64/libcudart.so.11.0
CUDA SETUP: Highest compute capability among GPUs detected: 7.5
CUDA SETUP: Detected CUDA version 116
CUDA SETUP: Loading binary /home2/abhijit.manatkar/miniconda3/envs/advnlp/lib/python3.8/site-packages/bitsandbytes/libbitsandbytes_cuda116.so...


  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
  warn(msg)
Either way, this might cause trouble in the future:
If you get `CUDA error: invalid device function` errors, the above might be the cause and the solution is to make sure only one ['libcudart.so', 'libcudart.so.11.0', 'libcudart.so.12.0'] in the paths that we search based on your env.
  warn(msg)


In [5]:
torch.cuda.empty_cache()

In [39]:
facts = load_facts('./data/calibration_facts.json', num_batches=1)[0]
silver_facts = load_facts('./data/silver_facts.json', num_batches=1)[0]

In [7]:
def macaw_input(question="", answer="", options=(), explanation="", context="", targets='AE'):
    
    if len(question) > 0:
        question_str = '$question$ = ' + question  + " ; "
    elif 'Q' in targets:
        question_str = "$question$ ; "
    
    if len(explanation) > 0:
        explanation_str = "$explanation$ = " + explanation + " ; "
    elif 'E' in targets:
        explanation_str = "$explanation$ ; "
    else:
        explanation_str = ""
    
    if len(answer) > 0:
        answer_str = "$answer$ = " + answer 
        if len(context) > 0:
            answer_str += " ; "
    elif 'A' in targets:
        answer_str = "$answer$"
        if len(context) > 0:
            answer_str += " ; "
    else:
        answer_str = ""
    
    if len(context) > 0:
        context_str = "$context$ = " + context
    else:
        context_str = ""
    
    letters = list(string.ascii_uppercase)
    if len(options) > 0:
        option_str = "$mcoptions$ = "
        for letter, option in zip(letters, options):
            option_str += f"({letter}) {option} "
        option_str += "; "
    elif 'M' in targets:
        option_str = "$mcoptions$ ; "
    else:
        option_str = ""
    
    return f"{question_str}{explanation_str}{option_str}{answer_str}{context_str}"

def run_macaw(input_str, model, tokenizer):
    input_ids = tokenizer.encode(input_str, return_tensors="pt").to(device)
    outs = model.generate(input_ids, max_length=500, early_stopping=True)
    return tokenizer.batch_decode(outs, skip_special_tokens=True)

In [8]:
fact = random.choice(facts)
fact

(albatross,IsA,expert, False, -99999.0)

### Facts generation

In [10]:
def get_questions(entity):

    a_an = "an" if entity[0] in ['a', 'e', 'i', 'o', 'u'] else "a"
    
    what_is = f"What is {a_an} {entity}?"
    made_of = f"What is {a_an} {entity} made of?"
    capable_of = f"What is {a_an} {entity} capable of?"
    has_what_part = f"What parts does {a_an} {entity} have?"
    has_what_property = f"What properties does {a_an} {entity} have?"
    which_category = f"Which category does {a_an} {entity} belong to?"

    return [what_is, made_of, capable_of, has_what_part, has_what_property, which_category]

def get_qa_pairs(entity):
    
    questions = get_questions(entity)

    inpstrs = [macaw_input(targets='AE', question=question) for question in questions]
    inpids = tokenizer(inpstrs, truncation=True, padding=True, return_tensors="pt").input_ids.to(device)
    
    num_beams = 3
    num_return_sequences = 3
    
    out = model.generate(
        input_ids=inpids, 
        max_length=500,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
        early_stopping=True
    )
    
    out_text = tokenizer.batch_decode(out, skip_special_tokens=True)
    
    qas = set()
    for i,q in enumerate(questions):
        for j in range(3*i, 3*(i+1)):
            out_str = out_text[j]
            slots = decompose_slots(out_str)
            ans = slots['answer'] if 'answer' in slots else ''
            qas.add((q,ans))
    
    return list(qas)


### QA to Declarative Sentence

In [27]:
weights = torch.load('./t5-statement-conversion-finetune.pt')['model']

In [28]:
config = T5Config.from_pretrained('t5-base')
qads = T5ForConditionalGeneration(config).to(device)

In [29]:
qads.load_state_dict(weights)

<All keys matched successfully>

In [30]:
def run_qads(qas):
    inpstrs = [q + " " + a for (q,a) in qas]
    inpids = tokenizer(inpstrs, truncation=True, padding=True, return_tensors="pt").input_ids.to(device)
    out = qads.generate(inpids, max_length=500)
    out_text = tokenizer.batch_decode(out, skip_special_tokens=True)
    return out_text

In [32]:
qas = get_qa_pairs('ant')
run_qads(qas)

['An ant has specialized cells for feeding on insects.',
 'An ant has properties of cells.',
 'An ant belongs to the predator category.',
 'An ant has eyes, nymphs, and body.',
 'An ant is capable of reproduction.',
 'An ant is an animal.',
 'An ant is made of cellulose.',
 'An ant has eyes, nymphs, and larvae parts.',
 'An ant belongs to the category predators.',
 'An ant is capable of stinging.',
 'An ant has specialized cells for protection.',
 'An ant is capable of sex reproduction.',
 'An ant is a kind of worm.',
 'An ant belongs to the category of insects.',
 'An ant is made of venom.',
 'An ant is a kind of lizard.',
 'An ant has specialized specialized cells for healing.',
 'An ant is made of specialized cells']

In [39]:
ents = list(set([fact.subject for fact in facts]) | set([fact.subject for fact in silver_facts]))

In [60]:
ent_facts = {}

start_t = time.time()
for ent in ents:
    qas = get_qa_pairs(ent)
    decs = run_qads(qas)
    ent_facts[ent] = decs
end_t = time.time()

print(f'Finished generating WDIK statements in {round(end_t - start_t, 2)} seconds')

Finished generating WDIK statements in 1019.94 seconds


In [65]:
with open('./wdik.json', 'w') as f:
    json.dump(ent_facts, f)

### NLI

In [66]:
nli_tokenizer = AutoTokenizer.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")
nli_model = AutoModelForSequenceClassification.from_pretrained("MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli").to(device)



In [67]:
def run_nli(premise, hypothesis):
    input_ids = nli_tokenizer.encode(premise, hypothesis, truncation=True, return_tensors="pt").to(device)
    output = nli_model(input_ids)

    prediction = torch.softmax(output.logits[0], -1).tolist()
    label_names = ["entailment", "neutral", "contradiction"]
    prediction = {name: round(float(pred) , 3) for pred, name in zip(prediction, label_names)}

    return prediction

In [68]:
fact.get_nl_sentence()

'ant is not a plastic.'

In [122]:
fact = random.choice(facts)
print(fact)
print()
qas = get_qa_pairs(fact.subject)
wdik = run_qads(qas)

for wik in wdik:
    prop = fact.get_nl_sentence()
    print(wik + " -> " + prop)
    print(run_nli(premise=wik, hypothesis=prop))
    print()
    print(prop + " -> " + wik)
    print(run_nli(premise=prop, hypothesis=wik))
    print()

(ant,IsA,road, False, -99999.0)

An ant is a kind of worm. -> ant is not a road.
{'entailment': 0.689, 'neutral': 0.141, 'contradiction': 0.17}

ant is not a road. -> An ant is a kind of worm.
{'entailment': 0.004, 'neutral': 0.989, 'contradiction': 0.007}

An ant is made of venom. -> ant is not a road.
{'entailment': 0.177, 'neutral': 0.275, 'contradiction': 0.548}

ant is not a road. -> An ant is made of venom.
{'entailment': 0.001, 'neutral': 0.988, 'contradiction': 0.011}

An ant is capable of sex reproduction -> ant is not a road.
{'entailment': 0.221, 'neutral': 0.449, 'contradiction': 0.33}

ant is not a road. -> An ant is capable of sex reproduction
{'entailment': 0.003, 'neutral': 0.989, 'contradiction': 0.008}

An ant has specialized cells for feeding and protection. -> ant is not a road.
{'entailment': 0.563, 'neutral': 0.159, 'contradiction': 0.279}

ant is not a road. -> An ant has specialized cells for feeding and protection.
{'entailment': 0.001, 'neutral': 0.998, 'contr

In [96]:
print("Premises:", wdik)
print("Hypothesis:", fact.get_nl_sentence())

Premises: ['An albatross is a sea animal.', 'An albatross is made of feathers.', 'An albatross is capable of laying eggs.', 'An albatross has gills, talons, and feet.', 'An albatross has blubber and fur.', 'An albatross belongs to the mammals category.']
Hypothesis: albatross is capable of fly.


In [104]:
run_nli(wdik[1], fact.get_nl_sentence())

{'entailment': 0.004, 'neutral': 0.994, 'contradiction': 0.002}

### Macaw scoring

In [11]:
def get_output_strings_from_options(options):
    l = []
    for option in options:
        l.append((f"$answer$ = {option}", option))
    return l

def get_macaw_scores(input_string, options):
    out_str = get_output_strings_from_options(options)
    res = run_model_with_outputs(model, tokenizer, device, inp_str, out_str)
    scores = {}
    for r in res:
        scores[r['output_text']] = r['score']
    return scores

In [15]:
prop = random.choice(facts)
print(prop)

question = prop.get_nl_question()
print(question)

yesno = ("yes", "no")
yesno2 = [("$answer$ = yes", "yes"), ("$answer$ = no", "no")]

yesnomaybe = ("yes", "no", "maybe")

options = ("yes", "no")

inp_str = macaw_input(question=question, options=options, targets='A')
print(inp_str)

# compute_answer(model, tokenizer, device, inp_str, generator_options={})

# res = run_model_with_outputs(model, tokenizer, device, inp_str, yesno2)
# scores = {}
# for r in res:
#     scores[r['output_text']] = r['score']
# print(scores)
# print()

print(get_macaw_scores(inp_str, options))

run_macaw(inp_str)

(ant,IsA,candy, False, -99999.0)
Is a ant a candy?
$question$ = Is a ant a candy? ; $mcoptions$ = (A) yes (B) no ; $answer$
{'yes': 0.47121472070395004, 'no': 0.997046073721489}


['$answer$ = no']

In [75]:
from belief.utils import noun_fluenterer
with open('belief/templates.json', 'r') as f:
    templates = json.load(f)

In [86]:
fact = random.choice(facts)
print(fact)
fact.get_positive_assertion()

(daffodil,HasPart,floral leaf, True, -99999.0)


'daffodil has floral leaf.'

'a daffodil has a floral leaf.'