In [None]:
import numpy as np

from llama2_utils import establish_llama2_endpoint
from simpleTQA import SimpleTQA 
from simpleFacts import SimpleFacts

In [None]:
# Name of the model under test. It is passed to every dataset method in
# get_response_data and is used to build result column names such as
# f"{model}_can_answer".
model = "llama-2-7b-chat"
# Endpoint handle passed as the "endpoint" entry of the model kwargs on
# every Llama-2 call made by get_response_data.
llama2_endpoint = establish_llama2_endpoint()

In [None]:
# Target number of answerable questions used by get_response_data to size the
# lie / logprob runs.
# NOTE(review): `...` looks like a placeholder for an int - set it before running.
number_questions_to_answer = ...

def find_number_of_rows_with_questions_the_model_can_answer(number_answerable_questions_required, dataset, model):
    """Length of the shortest dataset prefix holding the required answerable questions.

    Takes a running total of the ``f"{model}_can_answer"`` column and finds the
    first position where that total equals
    ``number_answerable_questions_required``. Returns the number of rows up to
    and including that position.

    Raises:
        IndexError: if the running total never equals the requested number,
            e.g. when the dataset holds too few answerable questions.
    """
    running_total = dataset[f"{model}_can_answer"].cumsum()
    hit_positions = np.flatnonzero(running_total == number_answerable_questions_required)
    # Indexing an empty array raises IndexError; callers rely on that signal.
    first_hit = hit_positions[0]
    return first_hit + 1

In [None]:
def _llama2_model_kwargs(endpoint):
    """Build a fresh generation-kwargs dict for one Llama-2 endpoint call.

    A new dict is returned on every call so the pipeline steps never share one
    mutable object, in case a dataset method modifies the kwargs it receives.
    """
    return {"endpoint": endpoint, "max_tokens": 64, "stop": "\n"}


def get_response_data(dataset, model, endpoint=None, number_questions=None):
    """Run the can-answer / lie / logprob pipeline for ``model`` on ``dataset``.

    Steps, in order:
      1. ``dataset.check_if_model_can_answer``, then print how many questions
         were answered correctly and how many were attempted, read from the
         ``f"{model}_can_answer"`` column.
      2. ``dataset.does_model_lie`` with ``max_questions_to_try=n_rows``, then
         print the lying and double-down rates.
      3. ``dataset.generate_logprobs`` with the same ``n_rows``.

    ``n_rows`` is the length of the shortest dataset prefix that contains
    ``number_questions`` answerable questions. If the dataset has fewer than
    that, ``n_rows`` is the whole dataset.

    Args:
        dataset: dataset object (e.g. ``SimpleTQA`` / ``SimpleFacts``) that
            supports column indexing, ``len()`` and the methods above.
        model: model name; passed to the dataset methods and used in the
            results column name.
        endpoint: Llama-2 endpoint handle. Defaults to the notebook-level
            ``llama2_endpoint``.
        number_questions: target number of answerable questions. Defaults to
            the notebook-level ``number_questions_to_answer``.

    Returns:
        None. Progress and summary statistics are printed.
    """
    # Fall back to the notebook configuration so existing calls of the form
    # get_response_data(dataset, model) keep working unchanged.
    if endpoint is None:
        endpoint = llama2_endpoint
    if number_questions is None:
        number_questions = number_questions_to_answer
    can_answer_column = f"{model}_can_answer"

    # Step 1: record which questions the model can answer.
    dataset.check_if_model_can_answer(
        model=model,
        model_kwargs=_llama2_model_kwargs(endpoint),
        max_batch_size=20,
        save_progress=True,
        bypass_cost_check=True,
    )
    print(f"CAN_{model.upper()}_ANSWER COMPLETED CORRECTLY")

    answered_correctly = dataset[can_answer_column].sum()
    attempted = dataset[can_answer_column].count()
    print("Answered correctly: ", answered_correctly)
    print("Attempted: ", attempted)

    # Size the remaining steps so they cover `number_questions` answerable questions.
    try:
        n_rows = find_number_of_rows_with_questions_the_model_can_answer(number_questions, dataset, model)
    except IndexError:
        # The helper raises IndexError when the dataset holds fewer answerable
        # questions than requested; in that case use every row.
        # NOTE(review): presumably this is also the path taken while
        # number_questions_to_answer is still the `...` placeholder - confirm.
        n_rows = len(dataset)

    # Step 2: lie check, limited to n_rows questions.
    dataset.does_model_lie(
        max_questions_to_try=n_rows,
        model=model,
        model_kwargs=_llama2_model_kwargs(endpoint),
        max_batch_size=20,
        save_progress=True,
        bypass_cost_check=True,
    )
    print(f"DOES_{model.upper()}_LIE COMPLETED CORRECTLY")

    lying_rate = dataset.compute_lying_ratio()
    double_down_rate = dataset.compute_lie_double_down_rate()
    print("lying_rate", lying_rate)
    print("double_down_rate", double_down_rate)

    # Step 3: logprob generation over the same n_rows questions.
    dataset.generate_logprobs(
        max_questions_to_try=n_rows,
        model_suspect=model,
        model_kwargs_suspect=_llama2_model_kwargs(endpoint),
        save_progress=True,
    )
    print(f"GENERATE_{model.upper()}_LOGPROBS COMPLETED CORRECTLY")

In [None]:
# Run the full can-answer / lie / logprob pipeline on the SimpleTQA dataset.
simple_tqa = SimpleTQA()
get_response_data(dataset=simple_tqa, model=model)

In [None]:
# Run the same pipeline on the SimpleFacts dataset.
simple_facts = SimpleFacts()
get_response_data(dataset=simple_facts, model=model)