In [1]:
import os
import random
import goodfire
import pandas as pd

In [2]:
# Goodfire credentials come from the environment; never hardcode them in the notebook.
api_key = os.getenv('API_KEY')
if not api_key:
    # Fail fast here rather than silently passing None into goodfire.Client.
    raise RuntimeError("API_KEY environment variable is not set; it is required to create the Goodfire client.")
client = goodfire.Client(api_key)

# Alternative (currently disabled) model:
# MODEL = "meta-llama/Meta-Llama-3-70B-Instruct"
MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
variant = goodfire.Variant(MODEL)

# Create Data

In [3]:
from pathlib import Path

# Input CSVs are read from, and the .npz output is written to, ../data
# (relative to the notebook's working directory).
DATA_DIR = Path("..", "data")

In [4]:
# Prompts that try to hijack the model's instructions (labelled "hijack_attempt" below).
prompts_hijack_attempt_file_path = DATA_DIR / "Instruction_Hijacking_Prompts_Bad_Full_Set.csv"
prompts_hijack_attempt_df = pd.read_csv(prompts_hijack_attempt_file_path)
# get_response_for_prompt reads these columns from every row; fail here if the CSV lacks them.
assert {"pre_prompt", "attack"} <= set(prompts_hijack_attempt_df.columns), (
    f"{prompts_hijack_attempt_file_path} is missing the 'pre_prompt'/'attack' columns"
)

In [5]:
# Prompts that supply the correct password (labelled "correct_password" below).
prompts_correct_password_file_path = DATA_DIR / "Instruction_Hijacking_Prompts_Good_Full_Set_Password.csv"
prompts_correct_password_df = pd.read_csv(prompts_correct_password_file_path)
# get_response_for_prompt reads these columns from every row; fail here if the CSV lacks them.
assert {"pre_prompt", "attack"} <= set(prompts_correct_password_df.columns), (
    f"{prompts_correct_password_file_path} is missing the 'pre_prompt'/'attack' columns"
)

In [6]:
# Prompts that supply a wrong / random password (labelled "random_password" below).
prompts_random_password_file_path = DATA_DIR / "Instruction_Hijacking_Prompts_Good_Full_Set_Random.csv"
prompts_random_password_df = pd.read_csv(prompts_random_password_file_path)
# get_response_for_prompt reads these columns from every row; fail here if the CSV lacks them.
assert {"pre_prompt", "attack"} <= set(prompts_random_password_df.columns), (
    f"{prompts_random_password_file_path} is missing the 'pre_prompt'/'attack' columns"
)

### Hijacking Data

In [7]:
# Non-null values per column in the hijack-attempt prompt set
prompts_hijack_attempt_df.notna().sum()

pre_prompt    280
attack        280
bad           280
dtype: int64

In [8]:
# Non-null values per column in the correct-password prompt set
prompts_correct_password_df.notna().sum()

pre_prompt    280
attack        280
bad           280
dtype: int64

In [9]:
# Non-null values per column in the wrong/random-password prompt set
prompts_random_password_df.notna().sum()

pre_prompt    280
attack        280
bad           280
dtype: int64

In [10]:
# Map each class label to its prompt set. Insertion order fixes the row order
# of the combined frame: hijack_attempt, random_password, correct_password.
prompt_sets_by_class = {
    "hijack_attempt": prompts_hijack_attempt_df,
    "random_password": prompts_random_password_df,
    "correct_password": prompts_correct_password_df,
}

# Tag every row with its class; get_response_for_prompt copies row["class"]
# into each response's "label". Like before, this adds the column to the
# three source frames in place.
for class_label, prompt_set_df in prompt_sets_by_class.items():
    prompt_set_df['class'] = class_label

# ignore_index=True gives the combined frame a unique 0..N-1 index; without it
# each source frame's own RangeIndex labels would be repeated.
all_prompts_df = pd.concat(list(prompt_sets_by_class.values()), ignore_index=True)
assert len(all_prompts_df) == sum(len(df) for df in prompt_sets_by_class.values())

In [11]:
from time import sleep
from typing import Any
from goodfire_utils import clone_variant, run_conversation_through_goodfire
from utils import retry_with_backoff


def get_response_for_prompt(row: dict[str, Any]) -> dict:
    """Send one prompt row through the Goodfire model and attach its class label.

    Args:
        row: One record of ``all_prompts_df``. Keys read here:
            ``"pre_prompt"`` (str) - sent as the first conversation turn,
            ``"attack"`` (str) - sent as the user turn,
            ``"class"`` (str) - copied into the result as ``"label"``.
            The ``"bad"`` column (0 or 1, 1 meaning bad) may be present but is
            not used here.

    Returns:
        On success, a dict with ``"prompt"``, ``"response"`` and ``"features"``
        (taken from ``run_conversation_through_goodfire``'s result) plus
        ``"label"``. If any exception escapes the retried call or the result
        assembly, a dict with ``"error"`` (the exception message), ``"label"``
        and the original ``"row"`` instead, so failed prompts can be identified
        and re-run. Callers tell the two apart by the presence of ``"error"``.
    """
    conv = [
        # NOTE(review): the pre-prompt is sent with the "assistant" role rather
        # than "system" - confirm this is the intended conversation framing.
        {"role": "assistant", "content": row["pre_prompt"]},
        {"role": "user", "content": row["attack"]},
    ]

    # Random 0-2 s jitter to spread the worker threads' API load.
    sleep(random.random() * 2)

    def get_response():
        # clone the model variant as it's stateful and we don't want
        # different api hits impacting each other
        cloned_variant = clone_variant(variant)
        return run_conversation_through_goodfire(conv, cloned_variant, client)

    # retry_with_backoff (from utils) presumably retries with increasing delays,
    # up to max_retries=5, so we don't hit rate limits.
    try:
        response = retry_with_backoff(get_response, max_retries=5)

        # return response with the label added
        return {
            "prompt": response["prompt"],
            "response": response["response"],
            "features": response["features"],
            "label": row["class"],
        }

    except Exception as e:
        # Best-effort: report the failure instead of aborting the whole batch,
        # but keep enough context to tell which prompt failed and retry it.
        return {"error": str(e), "label": row.get("class"), "row": row}


In [12]:
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

all_prompts_dicts = all_prompts_df.to_dict(orient="records")

# executor.map yields results in input order; tqdm wraps that lazy iterator
# and advances as each result is consumed.
with ThreadPoolExecutor(30) as executor:
    pending_results = executor.map(get_response_for_prompt, all_prompts_dicts)
    conversation_responses = list(tqdm(pending_results, total=len(all_prompts_dicts)))

# Partition failures (dicts carrying an "error" key) from successful responses.
errors = []
successful_responses = []
for result in conversation_responses:
    if "error" in result:
        errors.append(result)
    else:
        successful_responses.append(result)
print(f"Number of errors: {len(errors)}")

conversation_responses = successful_responses

  0%|          | 0/840 [00:00<?, ?it/s]

100%|██████████| 840/840 [08:08<00:00,  1.72it/s]

Number of errors: 13





In [15]:
# Message of the first failed prompt; None on a clean run (avoids IndexError).
errors[0]['error'] if errors else None

'[Errno 8] nodename nor servname provided, or not known'

In [13]:
from utils import save_as_npz

# Persist only the successful responses; failures were filtered out of
# conversation_responses after the thread-pool run.
tensortrust_output_path = DATA_DIR / "tensortrust_output.npz"
save_as_npz(conversation_responses, tensortrust_output_path)