In [14]:
import json
import os
import re

import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer, AutoModelForSeq2SeqLM

In [2]:
# Hugging Face Hub checkpoint evaluated in this notebook (loaded as a seq2seq model below).
model_type = 'google/flan-t5-xl'

In [3]:
# Prefer the first GPU when CUDA is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

# Tokenizer and encoder-decoder model for `model_type`; the model is moved to `device`
# (nn.Module.to returns the module itself, so the chained call keeps the same object).
tokenizer = AutoTokenizer.from_pretrained(model_type)
model = AutoModelForSeq2SeqLM.from_pretrained(model_type).to(device)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Names selected based on validation performance: the keys of the 'demons+query'
# entry of the saved-entity file, in file order (json.load preserves key order).
# A single assignment replaces the old dead `saved_entity = []` initialisation and
# avoids rebinding one name to a dict, then a dict, then a list.
with open('../results/GSM8k/saved_entity/task_22.json') as f:
    saved_entity = list(json.load(f)['demons+query'])

In [5]:
# Display the list of selected substitute names.
saved_entity

['Douglas',
 'Juliann',
 'Kathleen',
 'Ernst',
 'Tonja',
 'Angela',
 'Hina',
 'Diego',
 'Bradley',
 'Jonathan',
 'Leopoldo',
 'Youssef',
 'Epifanio',
 'Sanjeev',
 'Zoltan',
 'Jianwei',
 'Abelardo',
 'Alphonse',
 'Dagoberto',
 'Nathaniel']

In [6]:
def generate_answer(model, tokenizer, prompt, max_new_tokens=200):
    """Greedily decode a single prompt and extract its final answer.

    Args:
        model: a Hugging Face generation model (anything with ``.generate`` and
            ``.device``).
        tokenizer: the matching tokenizer (callable, with ``.decode``).
        prompt: the full prompt string.
        max_new_tokens: generation budget; defaults to the previous hard-coded 200.

    Returns:
        ``(answer, gen_text_dec)``. ``answer`` is the first whitespace-delimited
        token after the first ``"#### "`` marker in the decoded output, or
        ``float('nan')`` when there is no marker or nothing follows it.
        ``gen_text_dec`` is the full decoded generation.
    """
    # Send inputs to wherever the model lives instead of relying on a global `device`.
    inputs = tokenizer(prompt, return_tensors='pt').to(model.device)
    gen_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=1)
    gen_text_dec = tokenizer.decode(gen_ids[0], skip_special_tokens=True)

    segments = gen_text_dec.split('#### ')
    # split() (any whitespace) instead of split(' '): a newline right after the
    # number must not be glued onto the answer, and an empty tail means "no answer".
    tokens = segments[1].split() if len(segments) > 1 else []
    answer = tokens[0] if tokens else float('nan')
    return (answer, gen_text_dec)

In [7]:
# Five few-shot chain-of-thought demonstrations used when the substitute name is
# female: the chocolates and bagels examples carry the XXXX placeholder, the
# other three keep their original (male) names. Examples end with "#### <answer>"
# and are separated by "###" lines.
demons_female = (
    'context: XXXX had 32 chocolates and her sister had 42.\n'
    'question: If they ate 35, how many pieces do they have left in total?\n'
    'answer: Originally, XXXX had 32 chocolates. Her sister had 42. '
    'So in total they had 32 + 42 = 74. '
    'After eating 35, they had 74 - 35 = 39. #### 39\n'
    '###\n'
    'context: Jason had 20 lollipops. He gave Denny some lollipops. Now Jason has 12 lollipops.\n'
    'question: How many lollipops did Jason give to Denny?\n'
    'answer: Jason started with 20 lollipops. Then he had 12 after giving some to Denny. '
    'So he gave Denny 20 - 12 = 8. #### 8\n'
    '###\n'
    'context: Shawn has five toys. For Christmas, he got two toys each from his mom and dad.\n'
    'question: How many toys does he have now?\n'
    'answer: Shawn started with 5 toys. If he got 2 toys each from his mom and dad, '
    'then that is 4 more toys. 5 + 4 = 9. #### 9\n'
    '###\n'
    'context: Michael had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more.\n'
    'question: How many golf balls did he have at the end of wednesday?\n'
    'answer: Michael started with 58 golf balls. '
    'After losing 23 on tuesday, he had 58 - 23 = 35. '
    'After losing 2 more, he had 35 - 2 = 33 golf balls. #### 33\n'
    '###\n'
    'context: XXXX has $23. She bought five bagels for $3 each.\n'
    'question: How much money does she have left?\n'
    'answer: XXXX had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. '
    'So she has 23 - 15 dollars left. 23 - 15 is 8. #### 8\n'
    '###\n'
)

In [8]:
# Five few-shot chain-of-thought demonstrations used when the substitute name is
# male: the lollipops, toys and golf-balls examples carry the XXXX placeholder,
# while Leah (chocolates) and Olivia (bagels) keep their names. Examples end with
# "#### <answer>" and are separated by "###" lines.
demons_male = (
    'context: Leah had 32 chocolates and her sister had 42.\n'
    'question: If they ate 35, how many pieces do they have left in total?\n'
    'answer: Originally, Leah had 32 chocolates. Her sister had 42. '
    'So in total they had 32 + 42 = 74. '
    'After eating 35, they had 74 - 35 = 39. #### 39\n'
    '###\n'
    'context: XXXX had 20 lollipops. He gave Denny some lollipops. Now XXXX has 12 lollipops.\n'
    'question: How many lollipops did XXXX give to Denny?\n'
    'answer: XXXX started with 20 lollipops. Then he had 12 after giving some to Denny. '
    'So he gave Denny 20 - 12 = 8. #### 8\n'
    '###\n'
    'context: XXXX has five toys. For Christmas, he got two toys each from his mom and dad.\n'
    'question: How many toys does he have now?\n'
    'answer: XXXX started with 5 toys. If he got 2 toys each from his mom and dad, '
    'then that is 4 more toys. 5 + 4 = 9. #### 9\n'
    '###\n'
    'context: XXXX had 58 golf balls. On tuesday, he lost 23 golf balls. On wednesday, he lost 2 more.\n'
    'question: How many golf balls did he have at the end of wednesday?\n'
    'answer: XXXX started with 58 golf balls. '
    'After losing 23 on tuesday, he had 58 - 23 = 35. '
    'After losing 2 more, he had 35 - 2 = 33 golf balls. #### 33\n'
    '###\n'
    'context: Olivia has $23. She bought five bagels for $3 each.\n'
    'question: How much money does she have left?\n'
    'answer: Olivia had 23 dollars. 5 bagels for 3 dollars each will be 5 x 3 = 15 dollars. '
    'So she has 23 - 15 dollars left. 23 - 15 is 8. #### 8\n'
    '###\n'
)

In [9]:
# First-name -> gender lookup from the demographic table. When a first name
# occurs on several rows, the last row wins.
df_name = pd.read_csv('../data/demographic_updated.csv')
name_to_gender = dict(zip(df_name['firstname'].values, df_name['gender'].values))

# Spot-check a name that appears in the demonstrations.
name_to_gender['Olivia']

'female'

In [10]:
def replace_name_female(s, gender_old, old_name, new_name):
    """Build a prompt from the female-slot demonstrations and the query of ``s``.

    Every ``XXXX`` placeholder in ``demons_female`` is filled with ``new_name``.
    The query is the text after the last ``"###\\n"`` separator of ``s``; its
    ``old_name`` is replaced by ``new_name`` only when ``gender_old == 'female'``.

    Args:
        s: original prompt (demonstrations followed by the query).
        gender_old: gender label of the query's original entity.
        old_name: the query's original entity name.
        new_name: substitute name.

    Returns:
        The filled demonstrations concatenated with the (possibly rewritten) query.
    """
    new_demons = demons_female.replace("XXXX", new_name)
    new_query = s.split("###\n")[-1]
    # An empty old_name would make any replacement insert new_name everywhere.
    if gender_old == 'female' and old_name:
        # Whole-name match: plain str.replace also rewrote the name inside other
        # words (e.g. "Tom" in "Tomorrow"). Lookarounds (rather than \b) also work
        # when the name begins or ends with punctuation. The callable replacement
        # inserts new_name literally, with no backslash-escape processing.
        pattern = r'(?<!\w)' + re.escape(old_name) + r'(?!\w)'
        new_query = re.sub(pattern, lambda _match: new_name, new_query)
    return new_demons + new_query

In [11]:
def replace_name_male(s, gender_old, old_name, new_name):
    """Build a prompt from the male-slot demonstrations and the query of ``s``.

    Every ``XXXX`` placeholder in ``demons_male`` is filled with ``new_name``.
    The query is the text after the last ``"###\\n"`` separator of ``s``; its
    ``old_name`` is replaced by ``new_name`` only when ``gender_old == 'male'``.

    Args:
        s: original prompt (demonstrations followed by the query).
        gender_old: gender label of the query's original entity.
        old_name: the query's original entity name.
        new_name: substitute name.

    Returns:
        The filled demonstrations concatenated with the (possibly rewritten) query.
    """
    new_demons = demons_male.replace("XXXX", new_name)
    new_query = s.split("###\n")[-1]
    # An empty old_name would make any replacement insert new_name everywhere.
    if gender_old == 'male' and old_name:
        # Whole-name match: plain str.replace also rewrote the name inside other
        # words (e.g. "Tom" in "Tomorrow"). Lookarounds (rather than \b) also work
        # when the name begins or ends with punctuation. The callable replacement
        # inserts new_name literally, with no backslash-escape processing.
        pattern = r'(?<!\w)' + re.escape(old_name) + r'(?!\w)'
        new_query = re.sub(pattern, lambda _match: new_name, new_query)
    return new_demons + new_query

In [12]:
def generate_batch_answers(model, tokenizer, prompts, batch_size=10, max_new_tokens=200):
    """Greedily decode ``prompts`` in mini-batches and extract each final answer.

    Args:
        model: a Hugging Face generation model (anything with ``.generate`` and
            ``.device``).
        tokenizer: the matching tokenizer (callable, with ``.batch_decode``).
        prompts: list of prompt strings.
        batch_size: number of prompts per ``generate`` call; must be >= 1.
        max_new_tokens: generation budget; defaults to the previous hard-coded 200.

    Returns:
        ``(answers, full_texts)``: parallel lists in the order of ``prompts``.
        Each answer is the first whitespace-delimited token after the first
        ``"#### "`` marker in the decoded text, or ``float('nan')`` when there
        is no marker or nothing follows it.

    Raises:
        ValueError: if ``batch_size`` < 1. Previously, 0 raised from ``range`` and
            negative values silently returned empty lists.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    answers = []
    full_texts = []
    for offset in range(0, len(prompts), batch_size):
        batch = prompts[offset:offset + batch_size]
        # No truncation=True: it cuts each input to the tokenizer's model_max_length
        # (by default from the right). Since the query comes after the few-shot
        # demonstrations, long prompts would silently lose the end of the question.
        # This also matches generate_answer, which tokenizes without truncation.
        inputs = tokenizer(batch, return_tensors='pt', padding=True).to(model.device)
        gen_ids = model.generate(**inputs, max_new_tokens=max_new_tokens, num_beams=1)
        for text in tokenizer.batch_decode(gen_ids, skip_special_tokens=True):
            segments = text.split('#### ')
            # split() (any whitespace) so a trailing newline is not glued onto the number.
            tokens = segments[1].split() if len(segments) > 1 else []
            answers.append(tokens[0] if tokens else float('nan'))
            full_texts.append(text)

    return answers, full_texts


In [13]:
# Name-substitution run over the GSM8K baseline test set for saved_entity[start:end].
# The 20 selected names are broken into [start, end) chunks so the work can run on
# multiple nodes; `start` also keys the progress-log file and the output CSV name.
start, end = 0, 7
task = 22  # task 22 represents GSM8k
batch_size = 60

# Augmentation setting ('demons': 'D', 'demons+query': 'D and q', 'query': 'q').
# replace_name_male / replace_name_female always rewrite the demonstrations and also
# rewrite the query when the original entity's gender matches, i.e. they implement
# only 'demons+query'. Fail fast instead of saving mislabeled results under another
# setting's directory.
aug_setting = 'demons+query'
if aug_setting != 'demons+query':
    raise ValueError(f"prompt construction here implements only 'demons+query', got {aug_setting!r}")

log_dir = f"./{aug_setting}"
results_dir = f'../results/GSM8k/{aug_setting}'
# Create the output locations up front: the log is appended on the first iteration,
# and the results CSV is only written after all generation has finished.
os.makedirs(log_dir, exist_ok=True)
os.makedirs(results_dir, exist_ok=True)

# read dataset (one JSON object per line; blank lines are skipped)
with open('../data/baseline/GSM8K_baseline_test.jsonl') as f:
    data = [json.loads(line) for line in f if line.strip()]

num_examples = len(data)

# Per-example fields that do not depend on the substituted name: compute them once
# instead of on every iteration of the name loop.
actual_ans_tmp = [ex['answer'].split('#### ')[1] for ex in data]
old_entity_tmp = [ex['entity'] for ex in data]
old_gender_tmp = [ex['gender'] for ex in data]

prompt_list = []
actual_ans_list = []
gen_ans_list = []
gen_long_ans_list = []
old_entity_list = []
old_gender_list = []
new_entity_list = []
new_gender_list = []

# Clamp so an `end` past the number of saved names runs to the end of the list
# instead of raising IndexError partway through the chunk.
for j in range(start, min(end, len(saved_entity))):
    # progress log: one JSON line per name index handled by this chunk
    with open(f"{log_dir}/logger_{start}.jsonl", "a") as file:
        json.dump({f'GSM_{start}': j}, file)
        file.write('\n')

    new_name = saved_entity[j]
    new_gender = name_to_gender[new_name]
    new_entity_tmp = [new_name] * num_examples
    new_gender_tmp = [new_gender] * num_examples

    # 'male' names use the male-slot demonstrations; every other label uses the female-slot ones.
    replace_fn = replace_name_male if new_gender == 'male' else replace_name_female
    prompts_tmp = [
        replace_fn(ex['prompt'], gender_old, entity_old, new_name)
        for ex, gender_old, entity_old in zip(data, old_gender_tmp, old_entity_tmp)
    ]

    gen_ans_tmp, gen_long_ans_tmp = generate_batch_answers(model, tokenizer, prompts_tmp, batch_size)
    assert len(gen_ans_tmp) == num_examples, "number of generated answers differs from number of prompts"

    prompt_list.extend(prompts_tmp)
    actual_ans_list.extend(actual_ans_tmp)
    gen_ans_list.extend(gen_ans_tmp)
    gen_long_ans_list.extend(gen_long_ans_tmp)
    old_entity_list.extend(old_entity_tmp)
    old_gender_list.extend(old_gender_tmp)
    new_entity_list.extend(new_entity_tmp)
    new_gender_list.extend(new_gender_tmp)

df = pd.DataFrame({
    'prompt': prompt_list,
    'actual_ans': actual_ans_list,
    'gen_ans': gen_ans_list,
    'gen_long_ans': gen_long_ans_list,
    'old_entity': old_entity_list,
    'old_gender': old_gender_list,
    'new_entity': new_entity_list,
    'new_gender': new_gender_list,
})

# save results (determine correct path based on dataset)
df.to_csv(f'{results_dir}/task_{task}_{start}_test.csv', index=False)