In [None]:
# Load IPython's autoreload extension and switch it to mode 2 so that edits to
# imported modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os
import json

# Fall back to the placeholder only when no key is configured yet, so a real
# key already exported in the environment is never overwritten. Replace the
# placeholder with your own key if you are not setting it elsewhere.
os.environ.setdefault("ANTHROPIC_API_KEY", "YOUR-API-KEY")

from neuron_explainer.activations.activation_records import calculate_max_activation
from neuron_explainer.activations.activations import ActivationRecordSliceParams, load_neuron
from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator, LinearCalibratedNeuronSimulator
from neuron_explainer.explanations.prompt_builder import PromptFormat
from neuron_explainer.explanations.scoring import simulate_and_score
from neuron_explainer.explanations.simulator import LogprobFreeExplanationTokenSimulator

# Model used to simulate activations from the explanation.
# Alternative that can be swapped in: "claude-3-5-haiku-20241022".
SIMULATOR_MODEL_NAME = "claude-3-7-sonnet-20250219"

# Coordinates of the neuron under study; the saved explanation is looked up in
# a JSONL file named after the neuron index.
layer, index = 0, 0
neuronfile = f"{index}.jsonl"

# Fetch the record for the chosen neuron; it should be the same neuron the
# saved explanation was produced for.
neuron_record = load_neuron(layer, index)

# Slice out the train and validation activation records with identical slicing
# parameters (train first, then validation).
slice_params = ActivationRecordSliceParams(n_examples_per_split=5)
train_activation_records, valid_activation_records = (
    fetch_split(activation_record_slice_params=slice_params)
    for fetch_split in (
        neuron_record.train_activation_records,
        neuron_record.valid_activation_records,
    )
)

# Instead of generating an explanation, load a previously saved one from the
# JSONL file. Only the first line is read; it is expected to hold an object with
# a "scored_explanations" list (adjust the keys if your JSONL layout differs).
# JSON text is UTF-8, so decode it explicitly rather than relying on the
# platform's default encoding.
with open(neuronfile, "r", encoding="utf-8") as f:
    line = f.readline()

if not line.strip():
    # Fail with a clear message instead of an opaque JSONDecodeError.
    raise ValueError(
        f"{neuronfile} is empty; expected a JSON object on its first line"
    )

data = json.loads(line)
explanation = data["scored_explanations"][0]["explanation"]

explanations = [explanation]
print(f"{explanation=}")

# Simulate and score the explanation.
simulator = LinearCalibratedNeuronSimulator(
    LogprobFreeExplanationTokenSimulator(
        SIMULATOR_MODEL_NAME,
        explanation,
        max_concurrent=1,
        prompt_format=PromptFormat.INSTRUCTION_FOLLOWING,
    )
)
await simulator.calibrate(train_activation_records)

scored_simulation = await simulate_and_score(simulator, valid_activation_records)
print(f"score={scored_simulation.get_preferred_score():.2f}")