In [1]:
import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM

In [2]:
# Checkpoint identifier passed to `from_pretrained` when loading the tokenizer and the model.
model_type = "google/flan-t5-xl"

In [3]:
# Default to CPU and switch to the first CUDA device only when one is available.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Load the tokenizer and the seq2seq model for `model_type`, then place the
# model on the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_type)
model = AutoModelForSeq2SeqLM.from_pretrained(model_type).to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
def generate_answer(model, tokenizer, prompt):
    """Generate a short completion for one prompt and return its first word.

    Args:
        model: Object exposing ``generate(**inputs, ...)`` and a ``device``
            attribute (Hugging Face ``PreTrainedModel`` instances provide both).
        tokenizer: Callable tokenizer whose result supports ``.to(device)`` and
            can be unpacked with ``**``; must also provide ``decode``.
        prompt: Prompt text to tokenize and feed to the model.

    Returns:
        The first whitespace-delimited word of the decoded generation, or an
        empty string when the generation decodes to nothing.
    """
    # Place the inputs on the device the model actually lives on rather than
    # relying on a notebook-global `device` that may be stale or undefined.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    # Single beam, at most 6 newly generated tokens.
    generated = model.generate(**inputs, max_new_tokens=6, num_beams=1)
    decoded = tokenizer.decode(generated[0], skip_special_tokens=True)
    # split() with no argument ignores leading/repeated whitespace, so text
    # such as " yes" yields "yes" instead of an empty first field.
    words = decoded.split()
    return words[0] if words else ''

In [5]:
def generate_batch_answers(model, tokenizer, prompts, batch_size=10):
    """Generate short completions for many prompts and return each first word.

    Prompts are processed in consecutive slices of ``batch_size``; each slice is
    tokenized with padding and truncation enabled and passed to
    ``model.generate`` in a single call.

    Args:
        model: Object exposing ``generate(**inputs, ...)`` and a ``device``
            attribute (Hugging Face ``PreTrainedModel`` instances provide both).
        tokenizer: Callable tokenizer whose result supports ``.to(device)`` and
            can be unpacked with ``**``; must also provide ``decode``.
        prompts: Sliceable sequence of prompt strings.
        batch_size: Number of prompts per ``generate`` call; must be >= 1.

    Returns:
        A list with one entry per prompt, in input order: the first
        whitespace-delimited word of the decoded generation, or an empty string
        when the generation decodes to nothing.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A negative step would make range() empty and silently return no answers.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    answers = []
    for start in range(0, len(prompts), batch_size):
        batch = prompts[start:start + batch_size]
        # Place the inputs on the device the model actually lives on rather
        # than relying on a notebook-global `device`.
        inputs = tokenizer(
            batch, return_tensors='pt', padding=True, truncation=True
        ).to(model.device)
        # Single beam, at most 6 newly generated tokens per prompt.
        generated = model.generate(**inputs, max_new_tokens=6, num_beams=1)
        for sequence in generated:
            decoded = tokenizer.decode(sequence, skip_special_tokens=True)
            # split() with no argument ignores leading/repeated whitespace.
            words = decoded.split()
            answers.append(words[0] if words else '')

    return answers


In [6]:
# Augmentation setting; selects both the input and the output sub-directory
# ('demons': 'D', 'demons+query': 'D and q', 'query': 'q').
aug_setting = 'query'
# Task ids for which answers are generated.
selected_tasks = [1,12,14,16]
# Number of prompts sent to the model per generate() call.
batch_size = 60

for task in selected_tasks:
    input_path = Path(f'../data/{aug_setting}_v1/task_{task}.jsonl')
    # Adjust the output directory to match the dataset and the bootstrap
    # version of the names being evaluated.
    output_path = Path(f'../results/bAbI/v_1_names/{aug_setting}/task_{task}.csv')
    # Create the destination up front so a missing directory cannot discard
    # the results after the (slow) generation step has already run.
    output_path.parent.mkdir(parents=True, exist_ok=True)

    # Read the dataset: one JSON record per line; blank lines are skipped.
    with input_path.open(encoding='utf-8') as f:
        data = [json.loads(line) for line in f if line.strip()]

    num_examples = len(data)
    prompt_list = [record['prompt'] for record in data]
    entity_list = [record['entity'] for record in data]
    actual_ans_list = [record['answer'] for record in data]
    gen_ans_list = generate_batch_answers(model, tokenizer, prompt_list, batch_size)
    assert len(gen_ans_list) == num_examples, "expected one generated answer per prompt"

    # Build the frame in one step; column order matches the saved CSV header.
    df = pd.DataFrame({
        'prompt': prompt_list,
        'entity': entity_list,
        'actual_ans': actual_ans_list,
        'gen_ans': gen_ans_list,
    })
    df.to_csv(output_path, index=False)