In [1]:
# Re-import edited modules (e.g. the neuron_explainer package) automatically
# before each cell executes, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

In [2]:
import os

# Load the OpenAI API key from ~/.openai_api_key into the environment so the
# explainer/simulator clients can find it.
# The `with` block guarantees the file handle is closed. strip() removes the
# trailing newline (and any CR or surrounding whitespace) rather than
# unconditionally dropping the last character, which would truncate a key
# saved without a final newline.
with open(os.path.join(os.path.expanduser("~"), ".openai_api_key"), "r") as api_key_file:
    os.environ["OPENAI_API_KEY"] = api_key_file.read().strip()

from neuron_explainer.activations.activation_records import calculate_max_activation
from neuron_explainer.activations.activations import ActivationRecordSliceParams, load_neuron
from neuron_explainer.explanations.calibrated_simulator import UncalibratedNeuronSimulator
from neuron_explainer.explanations.explainer import TokenActivationPairExplainer, SummaryExplainer, HighlightExplainer, HighlightSummaryExplainer
from neuron_explainer.explanations.prompt_builder import PromptFormat
from neuron_explainer.explanations.scoring import simulate_and_score
from neuron_explainer.explanations.simulator import ExplanationNeuronSimulator

# --- Configuration ---------------------------------------------------------
EXPLAINER_MODEL_NAME = "gpt-3.5-turbo"           # model that writes the explanation
SIMULATOR_MODEL_NAME = "gpt-3.5-turbo-instruct"  # model used by the optional simulate-and-score step

# Which explainer variant to run. Values handled further down in this cell:
#   "Original", "Summary", "Highlight", "HighlightSummary", "AVHS", "Fixed"
MODE = "HighlightSummary"

to_print = False  # when True, the explainer prints the prompt it sends

# --- Load the neuron under study ---------------------------------------------
# NOTE(review): the arguments are presumably (layer_index, neuron_index); confirm against load_neuron.
neuron_record = load_neuron(9, 6236)

# The third quantile boundary is passed to the cutoff-based explainers.
# Presumably, tokens above it are treated as "highly activating" (TODO confirm).
cutoff = neuron_record.quantile_boundaries[2]
print(f"Explainer: {MODE} - Cutoff:{cutoff:.3f}\n")

# Five examples per split. The train records feed the explainer, and the valid
# records are reserved for the simulation/scoring step.
slice_params = ActivationRecordSliceParams(n_examples_per_split=5)
train_activation_records, valid_activation_records = (
    neuron_record.train_activation_records(activation_record_slice_params=slice_params),
    neuron_record.valid_activation_records(activation_record_slice_params=slice_params),
)

if MODE=="Summary":
    explainer = SummaryExplainer(
        model_name=EXPLAINER_MODEL_NAME,
        prompt_format=PromptFormat.HARMONY_V4,
        max_concurrent=1
    )

    explanations = await explainer.generate_explanations(
        all_activation_records=train_activation_records,
        cutoff=cutoff,
        num_samples=1,
        to_print = to_print
    )
    assert len(explanations) == 1
    explanation = explanations[0]

if MODE=="Highlight":
    explainer = HighlightExplainer(
        model_name=EXPLAINER_MODEL_NAME,
        prompt_format=PromptFormat.HARMONY_V4,
        max_concurrent=1,
    )

    explanations = await explainer.generate_explanations(
        all_activation_records=train_activation_records,
        cutoff=cutoff,
        num_samples=1,
        to_print = to_print,
    )
    assert len(explanations) == 1
    explanation = explanations[0]
    
if MODE=="HighlightSummary":
    explainer = HighlightSummaryExplainer(
        model_name=EXPLAINER_MODEL_NAME,
        prompt_format=PromptFormat.HARMONY_V4,
        max_concurrent=1,
    )

    explanations = await explainer.generate_explanations(
        all_activation_records=train_activation_records,
        cutoff=cutoff,
        num_samples=1,
        to_print = to_print,
    )
    assert len(explanations) == 1
    explanation = explanations[0]
    
    
elif MODE=="Original":
    explainer = TokenActivationPairExplainer(
        model_name=EXPLAINER_MODEL_NAME,
        prompt_format=PromptFormat.HARMONY_V4,
        max_concurrent=1,
    )
    explanations = await explainer.generate_explanations(
        all_activation_records=train_activation_records,
        max_activation=calculate_max_activation(train_activation_records),
        num_samples=1,
        to_print = to_print,
    )
    assert len(explanations) == 1
    explanation = explanations[0]

elif MODE=="AVHS":
    elif MODE=="AVHS":
    explainer = AVHSExplainer(
        model_name=EXPLAINER_MODEL_NAME,
        prompt_format=PromptFormat.HARMONY_V4,
        max_concurrent=1,
    )

    explanations = await explainer.generate_explanations(
        all_activation_records=train_activation_records,
        cutoff=cutoff,
        max_activation=calculate_max_activation(train_activation_records),
        num_samples=1,
        to_print = to_print
    )
    assert len(explanations) == 1
    explanation = explanations[0]

elif MODE=="Fixed":
    explanation = "transition words at the beginning of sentences."
    
print(f"{explanation=}")

# Simulate and score the explanation.
# simulator = UncalibratedNeuronSimulator(
#     ExplanationNeuronSimulator(
#         SIMULATOR_MODEL_NAME,
#         explanation,
#         max_concurrent=1,
#         prompt_format=PromptFormat.INSTRUCTION_FOLLOWING,
#     )
# )
# scored_simulation = await simulate_and_score(simulator, valid_activation_records)
# print(f"score={scored_simulation.get_preferred_score():.2f}")

Explainer: HighlightSummary - Cutoff:1.008

Printing explanation prompt
-
system: We're studying neurons in a neural network, trying to identify their roles. Look at the parts/tokens of the document the neuron activates highly for and summarize in a single sentence what the neuron is looking for. Don't list examples of words.

We will show short text excerpts, where each highly activating token(part of word) is highlighted using square brackets, i.e. [token]. This is followed by a comma separated list of only the highly activating tokens. Your task is to summarize what the highly activating tokens have in common, taking their context into account.

user: 

Neuron 1
Activations:

 Text excerpt - A:
<start>turturro is fabulously funny and over the top as a 'very sneaky' butler who excels in the art of impossible[ disappearing]/reapp[earing] acts<end> 

Highly activating tokens:
 disappearing,earing

 Text excerpt - B:
<start>esc[aping][ the] studio , piccoli is warmly[ affecting] and so 