# Set up the API and dataset

In [None]:
!pip install goodfire

In [None]:
# Load the Goodfire API key: prefer the Colab secret store, and fall back to
# a local key file one directory up when the Colab path is unavailable.
try:
    from google.colab import userdata

    # Add your Goodfire API Key to your Colab secrets
    GOODFIRE_API_KEY = userdata.get('GOODFIRE_API_KEY')
except Exception:
    # Any failure in the Colab path (e.g. ImportError when not running in
    # Colab) falls back to the key file. Catching Exception rather than using
    # a bare `except:` keeps KeyboardInterrupt/SystemExit working.
    # The `with` block closes the file. strip() removes the trailing newline
    # that readline() keeps, which would otherwise become part of the key.
    with open("../goodfire_apikey.txt", "r") as keyfile:
        GOODFIRE_API_KEY = keyfile.readline().strip()

import goodfire

client = goodfire.Client(GOODFIRE_API_KEY)

# Instantiate the two model variants evaluated below.
variant_small = goodfire.Variant("meta-llama/Meta-Llama-3-8B-Instruct")
variant_large = goodfire.Variant("meta-llama/Meta-Llama-3.1-70B-Instruct")

In [None]:
# first try to install datasets
!pip install datasets

In [None]:
# try tqdm (it doesnt seem to work for some reason)
!pip install tqdm
from tqdm.notebook import tqdm
import time

for i in tqdm(range(100)):
    time.sleep(0.01)

In [None]:
# Download the TriviaQA dataset (the "unfiltered.nocontext" configuration)
# from the Hugging Face Hub. The evaluation cells below read from its "train" split.
from datasets import load_dataset

ds = load_dataset(path="mandarjoshi/trivia_qa", name="unfiltered.nocontext")

# Run the on-question suppression

Run the on-question suppression for each model at each nudge level. (We still need to determine which combination of nudge strength and `num_features` produces the best results.)

In [None]:
# Suppression experiment on the small model. For each TriviaQA question, nudge
# the features that a search for the correct answer returns, then check
# whether the model still names that answer.
from itertools import islice
import time

prompt = "You are an AI assitant answering trivia questions. Give a single concise answer."
nudge_values = [0.1, 0.25, 0.5, 0.9] # need to check these are reasonable
num_features = 3
timeout = 2  # seconds to sleep between samples (rate limiting), not a request timeout
max_samples = 300  # number of train questions evaluated per nudge value

for nudge_amount in nudge_values:
    num_correct = 0
    num_evaluated = 0
    train = ds["train"]
    # islice stops after max_samples. Passing total keeps the progress bar
    # accurate instead of showing the whole train split.
    samples = islice(train, max_samples)
    for sample in tqdm(samples, total=min(max_samples, len(train))):
        sample_question = sample["question"]
        sample_answers = sample['answer']['normalized_aliases']

        # Find the features associated with the correct answer.
        # note: this could be changed to a different method like inspect
        nudged_features, relevance = client.features.search(
            sample_answers[0], # should probably be random or something else
            model=variant_small,
            top_k=num_features
        )

        # Clear the previous sample's edits, then nudge this sample's features.
        variant_small.reset()
        variant_small.set(nudged_features[0:num_features], nudge_amount, mode="nudge")

        # Query the nudged model.
        response = client.chat.completions.create(
            [
                {"role": "system", "content": prompt},
                {"role": "user", "content": sample_question}
            ],
            model=variant_small,
            stream=False,
            max_completion_tokens=50,
        )

        given_answer = response.choices[0].message["content"].lower()

        # Correct if any alias appears as a substring of the lowercased reply.
        # NOTE(review): the aliases are the dataset's normalized_aliases, but the
        # reply is only lowercased, not normalized the same way. This may miss
        # answers with punctuation/article differences, and very short aliases
        # can match spuriously.
        if any(answer in given_answer for answer in sample_answers):
            num_correct += 1
        num_evaluated += 1

        # make sure not to spam the api
        time.sleep(timeout)

    # Leave the shared variant un-nudged for any later cells.
    variant_small.reset()

    # Divide by the number of samples actually evaluated. The old code divided
    # by the 0-based last index, which is off by one.
    print(f"nudge value:{nudge_amount}")
    print(num_correct)
    print(num_evaluated)
    print(num_correct / num_evaluated if num_evaluated else 0.0)


In [None]:
# Suppression experiment on the large model. This uses the same procedure as
# the small-model cell, but is self-contained: it defines its own prompt, so it
# no longer depends on the previous cell having been run.
from itertools import islice
import time

prompt = "You are an AI assitant answering trivia questions. Give a single concise answer."
nudge_values = [0.1, 0.25, 0.5, 0.9] # need to check these are reasonable
num_features = 3
timeout = 2  # seconds to sleep between samples (rate limiting), not a request timeout
max_samples = 300  # number of train questions evaluated per nudge value

for nudge_amount in nudge_values:
    num_correct = 0
    num_evaluated = 0
    train = ds["train"]
    # islice stops after max_samples. Passing total keeps the progress bar
    # accurate instead of showing the whole train split.
    samples = islice(train, max_samples)
    for sample in tqdm(samples, total=min(max_samples, len(train))):
        sample_question = sample["question"]
        sample_answers = sample['answer']['normalized_aliases']

        # Find the features associated with the correct answer.
        # note: this could be changed to a different method like inspect
        nudged_features, relevance = client.features.search(
            sample_answers[0], # should probably be random or something else
            model=variant_large,
            top_k=num_features
        )

        # Clear the previous sample's edits, then nudge this sample's features.
        variant_large.reset()
        variant_large.set(nudged_features[0:num_features], nudge_amount, mode="nudge")

        # Query the nudged model.
        response = client.chat.completions.create(
            [
                {"role": "system", "content": prompt},
                {"role": "user", "content": sample_question}
            ],
            model=variant_large,
            stream=False,
            max_completion_tokens=50,
        )

        given_answer = response.choices[0].message["content"].lower()

        # Correct if any alias appears as a substring of the lowercased reply.
        # NOTE(review): the aliases are the dataset's normalized_aliases, but the
        # reply is only lowercased, not normalized the same way. This may miss
        # answers with punctuation/article differences, and very short aliases
        # can match spuriously.
        if any(answer in given_answer for answer in sample_answers):
            num_correct += 1
        num_evaluated += 1

        # make sure not to spam the api
        time.sleep(timeout)

    # Leave the shared variant un-nudged for any later cells.
    variant_large.reset()

    # Divide by the number of samples actually evaluated. The old code divided
    # by the 0-based last index, which is off by one.
    print(f"nudge value:{nudge_amount}")
    print(num_correct)
    print(num_evaluated)
    print(num_correct / num_evaluated if num_evaluated else 0.0)
