# Introduction

This is the notebook responsible for calling the goodfire API.

We have a simple wrapper designed to form the right prompt and analyse the answer, and some utilities for running experiments over a range of parameters. All experimental results are dumped to csvs in the data/ folder, where they can be read by `analysis.ipynb`

# Setup

In [1]:
GOODFIRE_API_KEY = open("API_KEY.txt", "r").read().strip()

In [2]:
import goodfire
from goodfire import AsyncClient

client = AsyncClient(
    GOODFIRE_API_KEY
  )

# Instantiate a model variant
#base = "meta-llama/Meta-Llama-3-8B-Instruct"
#base = "meta-llama/Meta-Llama-3.1-70B-Instruct"
base = "meta-llama/Llama-3.3-70B-Instruct"
#base = "meta-llama/Meta-Llama-3.1-8B-Instruct"
variant = goodfire.Variant(base)


# Jailbreak the variant using feature "Requests to bypass ethical and moral constraints"
# This can be useful to reduce the frequency of refusals
#variant.set(client.features.list(ids=["3b0f849c-b42a-4ce3-a89e-326e6569d51a"]), -0.5)

feature_counts = {
    "meta-llama/Llama-3.3-70B-Instruct": 65536,
    "meta-llama/Meta-Llama-3.1-8B-Instruct": 65536,
    #65536  for 8B
}
feature_count = feature_counts[base]

# Constants

In [3]:
import json
import shared


#QuestionSource = "OUS"
QuestionSource = "GGB"
Questions = shared.get_questions(QuestionSource)

# How to use the questions in a chatbot prompt.
QuestionTemplate = "\"{0}\"\n\nPlease rate the preceeding statement on 1 to 7 point Likert scale with 1 - Strongly Disagree, 7 - Strongly Agree. Print a single digit then end the message."
MAX_WORKERS=3

In [4]:
# # Persona Nationality 
# from shared import nationality_to_continent

# nationalities = [k for k,_ in nationality_to_continent.items()]
# genders = ['male', 'female']
# ages = [f'{age} year old' for age in ['21', '30', '40', '50', '60', '75']]

# PersonaQuestionTemplate = "Consider the following statement as a {1} person: "+QuestionTemplate
# # OUS_PersonaTemplate = "Imagine you are a {1}.\n\n"+OUS_Template



# Utility Code

In [5]:
from dataclasses import dataclass
import re
from typing import Union
import concurrent.futures as futures
from pprint import pprint
import numpy as np
from scipy.special import softmax
import asyncio
import time
from typing import List, Dict


@dataclass
class RateLimiter:
    requests_per_minute: int
    _request_times: List[float] = None
    
    def __post_init__(self):
        self._request_times = []
    
    async def acquire(self) -> float:
        now = time.time()
        minute_ago = now - 60
        
        # remove timestamps older than 1 minute
        self._request_times = [t for t in self._request_times if t > minute_ago]
        
        if len(self._request_times) >= self.requests_per_minute:
            # wait until oldest request expires
            wait_time = self._request_times[0] - minute_ago
            await asyncio.sleep(wait_time)
            now = time.time()  # Update time after sleeping
            
        self._request_times.append(now)
        return now

@dataclass
class Response:
    question: Union[int, str]
    score: float
    text: str
    logits: dict[int, float]
    mean: float
    stddev: float
    source: str

# sem = asyncio.Semaphore(5)
# async def run_question_with_sem(sem, question: Union[int, str], model=None, persona=None, progress=None):
#     await asyncio.sleep(2)

#     async with sem:
#         return await run_question(question , model, persona, progress)

async def run_question(limiter: RateLimiter, question: Union[int, str], model=None, persona=None, progress=None) -> Response:
    await limiter.acquire()
    orig_question = question
    if isinstance(question, int):
        question = Questions[question]
    model = model or variant
    if persona:
        user_text = PersonaQuestionTemplate.format(question["statement"], persona)
    else:
        user_text = QuestionTemplate.format(question)
    messages = [
            # System prompt discourages refusal
            {"role": "system", "content": "Always answer the user's question."},
            {"role": "user", "content": user_text},
            # Encourages a single value response. Also discourages refusal?
            {"role": "assistant", "content": "I'd rate this statement: "}
        ]
    response = await client.chat.completions.create(
        messages,
        model=model,
        max_completion_tokens=50,
        temperature=0
    )
    text = response.choices[0].message["content"]
    score = None
    # Try some heuristics for finding the score
    match = (
        re.search(r"(\d) out of 7", text) or
        re.search(r"(\d)", text)
    )
    if match:
        try:
            score_text = match.group(1)
            score = int(score_text)
        except:
            pass

    logits = None
    mean = None
    stddev = None
    if score is not None:
        # Attempt to get logits
        logit_messages = messages + [{"role": "assistant", "content": match.string[:match.start(1)]}]
        logits = await client.chat.logits(logit_messages,
            model=model,
            top_k=100, #  has to be reasonably large so we don't drop anything significant
            filter_vocabulary=list('1234567')
        )
        logits = {int(k): v for k,v in logits.logits.items() if k in '1234567'}
        if logits:
            probs = dict(zip(logits.keys(), softmax(np.array(list(logits.values())))))
            mean = np.sum([k*v for k,v in probs.items()])
            stddev = np.sqrt(np.sum([v * (k - mean)**2 for k,v in probs.items()]))

    if progress:
        progress.update()
    return Response(question=orig_question, score=score, text=text, logits=logits, mean=mean, stddev=stddev, source=QuestionSource)


# async def run_questions_with_sem(sem=sem, *args, **kwargs):
#     async with sem:
#         return await run_questions(*args, **kwargs)

async def run_questions(*args, **kwargs) -> list[Response]:
        limiter = RateLimiter(requests_per_minute=200)
        async with asyncio.TaskGroup() as tg:
            # async with sem:
            # tasks = [tg.create_task(run_question_with_sem(sem, q, *args, **kwargs)) for q in range(len(Questions))]
            tasks = [tg.create_task(run_question(limiter, q, *args, **kwargs)) for q in range(len(Questions))]

        return [await task for task in tasks]
    
def to_vector(responses: list[Response]) -> np.array:
    return np.array([r.mean if r.mean is not None else np.nan for r in responses])

import datetime

def now_str():
    return datetime.datetime.now().strftime("%Y%m%d%H%M%S")

def clone(variant: goodfire.Variant) -> goodfire.Variant:
    new_variant = goodfire.Variant(variant.base_model)
    for edit in variant.edits:
        new_variant.set(edit[0], edit[1]['value'], mode=edit[1]['mode'])

    return new_variant

In [6]:
# Some testing
#q = run_question(1)
#print(q)
#qs = run_questions()
#pprint(qs)
#print(to_vector(qs))

In [7]:
from typing import Optional
import tqdm
import time
import pandas as pd

async def tabular_experiments(features: list[goodfire.Feature], steerages: list[float], personas: Optional[list[str]] = None, wait: Optional[float]=None, base=base):
    if personas is None:
        personas = [None]
    results = []
    async with asyncio.TaskGroup() as tg:
        combinations = []
        for feature in features:
            for steerage in steerages:
                model = goodfire.Variant(base)
                if feature is None:
                    assert steerage == 0
                else:
                    model.set(feature, steerage)
                for persona in personas:
                    combinations.append((feature, steerage, persona))
        tasks = []
        progress = tqdm.tqdm(total=len(combinations) * len(Questions))
        for combination in combinations :
            feature, steerage, persona = combination
            # task = tg.create_task(run_questions_with_sem(sem, persona=persona, model=model, progress=progress))
            task = tg.create_task(run_questions(persona=persona, model=model, progress=progress))

            tasks.append((feature, steerage, persona, task))
            # TODO: Remove once we get parallellism working better
            await task
            
        for feature, steerage, persona, task in tasks:
            responses: list[Response] = task.result()
            if wait:
                time.sleep(wait)
            for response in responses:
                results.append(dict(
                    base=base,
                    source=response.source,
                    feature=feature.label if feature else "",
                    steerage=steerage,
                    persona=persona,
                    question=response.question,
                    mean_score=response.mean,
                    stddev_score=response.stddev,
                    score=response.score,
                    text=response.text,
                ))
    return pd.DataFrame(results)

## Features of interest

In [8]:
moral_keywords = ['moral', 'altruism', 'greater good', 'ethic', 'integrity', 'dignity']

# Experiments

In [None]:
import time

start_time = time.time()

for keyword in moral_keywords[:1]:
    print(f'Running search and steering for features associated with "{keyword}"\n')
    features = list((await client.features.search(keyword, model=base, top_k=10)))
    steerages = [-.5, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.5]
    experiments = await tabular_experiments(features, steerages)
    experiments.to_csv("data/" + now_str()+''.join(keyword)+".csv", index=False)
    end_time = time.time()
    print(f'Time take for {keyword} -> {end_time-start_time}')


Running search and steering for features associated with "moral"



  1%|          | 76/8100 [00:06<03:28, 38.49it/s]  Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...
  1%|          | 90/8100 [00:20<03:28, 38.49it/s]Rate limit exceeded. Attempting exponential backoff...
Rate limit exceeded. Attempting exponential backoff...


In [None]:
# Run baseline
if True:
    features = [None]
    steerages = [0]
    experiments = await tabular_experiments(features, steerages)
    experiments.to_csv("data/" + now_str()+".csv", index=False)

In [None]:
# Run some random features
if False:
    features = list(client.features.search("elephants", model=base, top_k=1)[0])
    steerages = [-0.8, -0.5, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.5, 0.8]
    personas = [0]
    experiments = tabular_experiments(features, steerages, personas)
    experiments.to_csv("data/" + now_str()+".csv", index=False)

In [None]:
# persona test
if False:
    features = list(client.features.search("moral", model=base, top_k=5)[0])
    steerages = [0]
    persona_tags = ['nationalities', 'ages', 'genders']
    for i, personas in enumerate([nationalities, ages, genders]):
        experiments = tabular_experiments(features[:1], steerages, personas)
        experiments.to_csv("data/" + now_str()+persona_tags[i]+".csv", index=False)

In [None]:
import time
# keywords
#'overall impact','duty', 'dignity', 'greater good', git 
if False:
    for keyword in [#'obligation','ethic']: # 'dignity', 'greater good',
        'obligation','ethic']:
        print(f'Running search and steering for features associated with "{keyword}"\n')
        features = client.features.search(keyword, model=base)[0][:20]
        steerages = [-.5, -0.3, -0.2, -0.1, 0, 0.1, 0.2, 0.3, 0.5]
        experiments = tabular_experiments(features, steerages, personas=None, wait=1.5, base=base)
        experiments.to_csv("data/" + now_str()+''.join(keyword)+".csv", index=False)
        time.sleep(2)

In [None]:
from itertools import batched
if False:
    for feature_ids in batched(range(0, feature_count), 20):
        features = client.features.lookup(list(feature_ids), model=base)
        print(features)


In [None]:
# Experiment with logits
if False:
    logits = await client.chat.logits(
        messages=[
            {"role": "user", "content": "A random number between 0 and 9 is "}
        ],
        model="meta-llama/Llama-3.3-70B-Instruct",
        filter_vocabulary=list('0123456789')
    )
    print(logits.logits) 
    probs = dict(zip(logits.logits.keys(), softmax(np.array(list(logits.logits.values())))))
    print(probs)