In [None]:
import os
import pandas as pd
from transformers import AutoTokenizer, AutoModelForCausalLM
from baukit import Trace
import torch
import matplotlib.pyplot as plt
import json

In [None]:
# Import the project-local helper modules.
import importlib
# NOTE(review): nothing in this cell adjusts sys.path, so utils and
# dataset_utils must already be importable (e.g. sit next to this notebook) — confirm.

import utils, dataset_utils
# Reload so that edits made to the modules after their first import are picked
# up without restarting the kernel.
importlib.reload(utils)
importlib.reload(dataset_utils)
# Star-import AFTER the reload so the notebook namespace is bound to the
# reloaded definitions.
# NOTE(review): load_data_set, generate_tokens and check_answer used further
# down are not defined in this notebook, so they presumably come from these
# star imports — confirm.
from utils import *
from dataset_utils import *

In [None]:
# Make sure the plots, data and results folders exist.
# exist_ok=True turns the call into a no-op for folders that are already there
# and removes the check-then-create race of a separate os.path.exists() test.
# If a regular file occupies one of these names, this now raises
# FileExistsError instead of silently skipping the folder.
for folder in ['plots', 'data', 'results']:
    os.makedirs(folder, exist_ok=True)

In [None]:
# Run on the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"device: {device}")

In [None]:
# Checkpoint to evaluate; swap the active line to try one of the alternatives.
# model_name = "mistralai/Mistral-7B-v0.1"
model_name = "HuggingFaceH4/zephyr-7b-beta"
# model_name = "huggyllama/llama-7b"

# Load the weights in half precision, move them to the selected device and
# put the model into evaluation mode.
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16)
model = model.to(device)
model.eval()

# Tokenizer for the same checkpoint. The end-of-sequence token doubles as the
# padding token, and padding is applied on the left.
# NOTE(review): left padding is presumably chosen so that, in batched
# generation, new tokens directly follow each prompt — confirm.
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
tokenizer.padding_side = "left"

In [None]:
# Silence every Python warning for the rest of the session to keep the cell
# output readable.
# NOTE(review): this also hides deprecation and numerical warnings raised by
# later cells — comment it out when debugging.
import warnings
warnings.filterwarnings('ignore')

In [None]:
def _rates_for_formats(model, tokenizer, dataset_name, prompt_formats, format_kwarg,
                       scenario_key, count_mismatches, max_new_tokens, batch_size):
    """Return {prompt_format: success rate in percent} for one family of prompt formats."""
    # Imported here explicitly: the notebook never imports numpy itself, so the
    # original code only worked if a star import happened to leak `np`.
    import numpy as np

    rates = {}
    for prompt_format in prompt_formats:
        # The format is passed under the keyword expected for this family
        # (lie_format=... or truth_format=...).
        dataset = load_data_set(dataset_name, **{format_kwarg: prompt_format})
        _, answer_tokens = generate_tokens(model, tokenizer, dataset[scenario_key],
                                           max_new_tokens=max_new_tokens, batch_size=batch_size,
                                           do_sample=False)

        is_correct = check_answer(tokenizer, answer_tokens, dataset['true_answer'], batch_size=batch_size)
        percent_correct = float(np.mean(is_correct)) * 100
        # A lie is successful when the answer does NOT match the true answer,
        # so the lie rate is the complement of the match rate.
        rate = 100 - percent_correct if count_mismatches else percent_correct
        rates[prompt_format] = rate

        print(f"{prompt_format}: \n\t success_rate: {rate:.2f}%")
    return rates


def calc_success_rates(model, tokenizer, dataset_name, formats, max_new_tokens=5, batch_size=64):
    """Measure how well each lie and truth prompt format works on a dataset.

    For every format the dataset is loaded with that format, answers are
    generated greedily (do_sample=False) and compared against the true answers
    via check_answer. The averaged check result is the truth success rate; the
    lie success rate is 100 minus that average.
    NOTE(review): this assumes check_answer marks a match with the true answer
    as 1/True — confirm against its definition.

    Args:
        model: model handed through to generate_tokens.
        tokenizer: tokenizer handed through to generate_tokens and check_answer.
        dataset_name: name handed to load_data_set.
        formats: mapping with the keys 'lie_formats' and 'truth_formats', each
            an iterable of prompt formats.
        max_new_tokens: number of new tokens to generate per prompt.
        batch_size: batch size for generation and answer checking.

    Returns:
        (success_lies, success_truths): two dicts mapping each prompt format to
        its success rate in percent. Progress and a sorted summary are printed
        as before. In a notebook, assign the result or end the call with `;`
        if the returned tuple should not be echoed.
    """
    lie_formats = formats['lie_formats']
    truth_formats = formats['truth_formats']

    print(f"Testing {len(lie_formats)} lie formats for {dataset_name}...\n")
    success_lies = _rates_for_formats(model, tokenizer, dataset_name, lie_formats,
                                      'lie_format', 'lie_scenario', True,
                                      max_new_tokens, batch_size)

    print(f"\nTesting {len(truth_formats)} truth formats for {dataset_name}...\n")
    success_truths = _rates_for_formats(model, tokenizer, dataset_name, truth_formats,
                                        'truth_format', 'truth_scenario', False,
                                        max_new_tokens, batch_size)

    print("\n\n")
    print("Sorted by success rate: \n")

    print_success_rates(success_lies, success_truths, dataset_name)

    return success_lies, success_truths


def print_success_rates(success_lies, success_truths, dataset_name):
    """Print the lie-format and then the truth-format success rates, best first.

    Both arguments map a prompt format to its success rate in percent; each
    group is listed in descending order of that rate under its own header.
    """
    groups = (
        (success_lies, f"Success rates for lie formats for dataset {dataset_name}:"),
        (success_truths, f"Success rates for truth formats for dataset {dataset_name}:"),
    )
    for rates, header in groups:
        # Rank before printing the header, highest rate first.
        ranked = sorted(rates.items(), key=lambda entry: entry[1], reverse=True)
        print(header)
        for prompt_format, rate in ranked:
            print(f"{prompt_format}\n\t success_rate: {rate:.2f}%")



In [None]:
# Generation settings passed to every calc_success_rates call below.
max_new_tokens = 5
batch_size = 64

# Load the prompt formats; calc_success_rates reads the 'lie_formats' and
# 'truth_formats' keys from this mapping.
# The encoding is given explicitly because open() otherwise falls back to the
# platform's locale encoding, which can mis-decode non-ASCII prompt text.
with open('data/formats_statements.json', encoding='utf-8') as json_file:
    formats_statements = json.load(json_file)

In [None]:
# Evaluate every lie and truth prompt format on the Questions1000 dataset.
calc_success_rates(
    model,
    tokenizer,
    'Questions1000',
    formats_statements,
    max_new_tokens=max_new_tokens,
    batch_size=batch_size,
)

In [None]:
# Evaluate every lie and truth prompt format on the FreebaseStatements dataset.
calc_success_rates(
    model,
    tokenizer,
    'FreebaseStatements',
    formats_statements,
    max_new_tokens=max_new_tokens,
    batch_size=batch_size,
)
