In [None]:
import pandas as pd
import numpy as np
import os
import torch
import re
from transformers import AutoModelForCausalLM, AutoTokenizer


In [None]:
# Load the instruction-tuned model and its tokenizer from a single shared id so the
# two can never silently refer to different checkpoints.
# NOTE(review): no torch_dtype is passed, so weights load in the library default dtype.
MODEL_ID = "mistralai/Mistral-7B-Instruct-v0.3"
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map='mps')
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)


In [None]:
# Directory holding the WinoBias files (relative to the notebook's working directory; keep the trailing slash).
wino_bias = "../self_contra/wino_bias/"

In [None]:
# os.listdir returns entries in arbitrary order; sort so the listing is reproducible across runs and machines.
wino_bias_files = sorted(os.listdir(wino_bias))
wino_bias_files

In [None]:
# Exclude the two occupation files from the list of files to process.
# Filtering (instead of list.remove) keeps the cell idempotent: re-running it no longer raises ValueError.
wino_bias_files = [name for name in wino_bias_files
                   if name not in ('male_occupations.txt', 'female_occupations.txt')]
wino_bias_files

In [None]:

def get_mistral_pred(question):  
    messages = [
        {
            "role":"user",
            "content":f"""Question: {question} 
            Give your reason first, then answer, answer should be in maximum 3 to 4 words. 
            Follow the format: reason:[reason]:\nAnswer:[Answer]\n"""
        }
    ]


    model_inputs = tokenizer.apply_chat_template(messages, return_tensors = "pt").to('mps')
    pad_token_id = tokenizer.pad_token_id or tokenizer.eos_token_id
    attention_mask = model_inputs.ne(pad_token_id).long()

    generated_ids = model.generate(model_inputs, max_new_tokens = 1000, do_sample = True, 
                                   attention_mask=attention_mask, pad_token_id=pad_token_id)
    output = tokenizer.batch_decode(generated_ids, skip_special_tokens = True)[0]
    #print(output)
    if 'Answer:[Answer]' in output:
        index = output.index('Answer:[Answer]')
        a = output[index+25:]
        b = a.index('Answer')
        
        reason = a[:b]
        result = a[b+8:]
        
        return reason, result
    return None, None
        

In [None]:
def _build_question(line):
    """Turn one annotated WinoBias line into ``(question, ground_truth)``.

    The parsing assumes lines of the form
    ``"<n> [span A] ... [span B] ..."`` (TODO confirm against the data files):
    the first bracketed span is kept as the ground truth, and the second is the
    mention the model is asked to resolve.

    Returns:
        ``(question, ground_truth)``, or ``None`` if the line has fewer than two
        bracketed spans.
    """
    spans = re.findall(r'\[(.*?)\]', line)
    if len(spans) < 2:
        return None
    # Remove only the leading sentence number; a global `\d+` removal would also
    # delete digits that belong to the sentence itself.
    sentence = re.sub(r'^\s*\d+\s*', '', line)
    sentence = re.sub(r'[\[\]]+', '', sentence)
    return sentence + " Who is '" + spans[1] + "'?", spans[0]


# Only these two splits are evaluated; this intentionally overrides the directory listing built earlier.
wino_bias_files = ['pro_stereotyped_type1.txt.dev',
                  'pro_stereotyped_type2.txt.test']
for num, file_name in enumerate(wino_bias_files):
    print(num, file_name)
    # Read the sentences as plain lines (blank lines skipped) rather than through a CSV
    # parser, whose quoting rules can mangle sentences that contain double quotes.
    with open(os.path.join(wino_bias, file_name), encoding='utf-8') as fh:
        lines = [ln.strip() for ln in fh if ln.strip()]

    records = []
    for i, line in enumerate(lines):
        if i % 100 == 0:
            print(f"{file_name}: line {i}/{len(lines)}")
        built = _build_question(line)
        if built is None:
            # Report and continue instead of aborting the whole run with an IndexError.
            print(f"skipping line {i} of {file_name}: expected two [bracketed] spans")
            continue
        question, gt_span = built
        reason, result = get_mistral_pred(question)
        records.append({'ques': question, 'gt': gt_span, 'reasons': reason, 'pred': result})

    # Build the frame once from records; explicit columns keep the schema even for an empty file.
    df = pd.DataFrame(records, columns=['ques', 'gt', 'reasons', 'pred'])
    df.to_csv(os.path.join('..', 'self_contra', file_name + '.csv'))
    
