In [23]:
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from belief.evaluation import load_facts
import random
import string

In [3]:
# Prefer the GPU when one is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
# Single source of truth for the checkpoint so tokenizer and model cannot drift apart.
macaw_checkpoint = "allenai/macaw-large"

tokenizer = AutoTokenizer.from_pretrained(macaw_checkpoint)
# Load the seq2seq weights in 8-bit and let the library decide device placement.
model = AutoModelForSeq2SeqLM.from_pretrained(
    macaw_checkpoint,
    load_in_8bit=True,
    device_map='auto',
)



In [96]:
# Ask PyTorch to release cached CUDA memory it is no longer using
# (see the torch.cuda.empty_cache documentation for exact semantics).
# NOTE(review): this cell's execution count (96) is far out of order relative
# to its position, so it was run ad hoc late in the session.
torch.cuda.empty_cache()

In [8]:
# Load the calibration facts as a single batch and keep that one batch.
fact_batches = load_facts('./data/calibration_facts.json', num_batches=1)
facts = fact_batches[0]

In [84]:
def macaw_input(question="", answer="", options=(), explanation="", context="", targets='AE'):
    """Build a slot-formatted prompt string for the Macaw model used in this notebook.

    Each slot is rendered in one of three ways:

    * ``"$slot$ = value ; "`` when a non-empty value is supplied,
    * ``"$slot$ ; "`` when no value is supplied but the slot's letter is in
      ``targets`` (i.e. the model is asked to produce it),
    * ``""`` otherwise.

    Slots are emitted in the fixed order question, explanation, mcoptions,
    answer, context.

    Args:
        question: Question text; target letter ``'Q'``.
        answer: Answer text; target letter ``'A'``.
        options: Sequence of multiple-choice option strings; target letter
            ``'M'``. Options are labelled ``(A)``, ``(B)``, ... and any options
            beyond the 26th are dropped because labels come from A-Z.
        explanation: Explanation text; target letter ``'E'``.
        context: Context text. It has no target letter and, being last, is
            emitted without a trailing ``" ; "`` separator.
        targets: String of slot letters the model should generate.

    Returns:
        The concatenated prompt string (empty if nothing is given or targeted).
    """

    def render_slot(name, value, letter):
        # Given value wins over "requested as target"; otherwise the slot is omitted.
        if len(value) > 0:
            return f"${name}$ = {value} ; "
        if letter in targets:
            return f"${name}$ ; "
        return ""

    # Fix: previously left unassigned (UnboundLocalError) when the question was
    # empty and 'Q' was not a target.
    question_str = render_slot("question", question, 'Q')
    explanation_str = render_slot("explanation", explanation, 'E')
    # Fix: the answer slot previously interpolated `explanation` instead of `answer`.
    answer_str = render_slot("answer", answer, 'A')

    if len(options) > 0:
        # zip stops at the shorter input, so at most 26 labelled options are emitted.
        labelled = "".join(
            f"({letter}) {option} "
            for letter, option in zip(string.ascii_uppercase, options)
        )
        option_str = "$mcoptions$ = " + labelled + "; "
    elif 'M' in targets:
        option_str = "$mcoptions$ ; "
    else:
        option_str = ""

    context_str = "$context$ = " + context if len(context) > 0 else ""

    return f"{question_str}{explanation_str}{option_str}{answer_str}{context_str}"

In [90]:
# Pick one fact at random to inspect. No random seed is set in the visible
# cells, so a different fact may be selected on each run.
fact = random.choice(facts)
fact  # bare last expression so the notebook displays the selected fact

(ant,IsA,house, False, -99999.0)

In [94]:
# Ask for an answer and an explanation ('AE') to a fixed question.
inpstr = macaw_input(
    question='What is an ant capable of?',
    targets='AE',
)
print(inpstr)

$question$ = What is an ant capable of? ; $explanation$ ; $answer$ ; 


In [95]:
# Tokenize the prompt into a PyTorch tensor, then move it to the selected device.
inpids = tokenizer.encode(inpstr, return_tensors="pt")
inpids = inpids.to(device)

# Generate up to 500 tokens and decode them back to text without special tokens.
output = model.generate(inpids, max_length=500)
tokenizer.batch_decode(output, skip_special_tokens=True)

['$explanation$ = An insect can perform photosynthesis. Photosynthesis is a source of food for the plant by converting carbon dioxide, water, and sunlight into carbohydrates. ; $answer$ = absorbing light']