In [18]:
import os
import random
import goodfire
import pandas as pd

In [19]:
api_key = os.getenv('API_KEY')
client = goodfire.Client(api_key)

# MODEL = "meta-llama/Meta-Llama-3-70B-Instruct"
MODEL = "meta-llama/Meta-Llama-3-8B-Instruct"
variant = goodfire.Variant(MODEL)

# Create Data

In [20]:
import json
with open("./pair_attacks.json", 'r') as f:
    pair_attacks = json.load(f)

### Hijacking Data

In [None]:
from time import sleep
from typing import Any
from src.goodfire_utils import clone_variant, run_conversation_through_goodfire
from src.utils import retry_with_backoff


def get_response_for_prompt(row: dict[str, Any]) -> dict:
    """
    row: {"preprompt": str, "attack": str, "bad": int}
    bad is 0 or 1, 1 meanind bad, 0 meaning good
    """
    conv = [
        {"role": "user", "content": row["prompt"]},
    ]

    # sleep to distribute load of API
    sleep(random.random() * 2)

    def get_response():
        # clone the model variant as it's stateful and we don't want
        # different api hits impacting each other
        cloned_variant = clone_variant(variant)
        return run_conversation_through_goodfire(conv, cloned_variant, client)

    # try to generate response, using backoff so we don't hit rate limits
    try:
        response = retry_with_backoff(get_response, max_retries=5)

        # return response with the label added
        return {
            "prompt": response["prompt"],
            "response": response["response"],
            "features": response["features"],
            "label": row["success"],
        }

    except Exception as e:
        return {"error": e}

In [22]:
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm

with ThreadPoolExecutor(30) as executor:
    conversation_responses = list(
        tqdm(
            executor.map(get_response_for_prompt, pair_attacks),
            total=len(pair_attacks),
        )
    )

errors = [r for r in conversation_responses if "error" in r]
print(f"Number of errors: {len(errors)}")

conversation_responses = [r for r in conversation_responses if "error" not in r]

100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 462/462 [06:10<00:00,  1.25it/s]

Number of errors: 30





In [23]:
errors

[{'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''},
 {'error': ''}]

In [17]:
from src.utils import save_as_npz
save_as_npz(conversation_responses, "./data/pair_attacks.npz")