# SAE Bench Eval Template

## Overview

Every eval type has the following:
1. A corresponding sub-package.
2. A `main.py`, which includes:
   1. An argparse interface (`arg_parser`) for running the eval from the command line.
   2. A `run_eval` function which operates on a set of SAEs, producing one JSON file of results per SAE.
3. An `eval_config.py` with config specific to that eval, with defaults set to recommended values.
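
The layout above can be sketched as a single hypothetical `main.py`. The names `DemoEvalConfig` and `config_from_args`, the `run_eval` signature, and the output file naming are assumptions made for illustration, not the actual SAE Bench API, and the eval body is a placeholder.

```python
import argparse
import json
import os
import tempfile
from dataclasses import asdict, dataclass, fields


@dataclass
class DemoEvalConfig:
    # Hypothetical eval-specific settings; the defaults stand in for recommended values.
    random_seed: int = 42
    model_name: str = "pythia-70m-deduped"


def arg_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="Run a demo eval")
    # Eval-specific arguments mirror the config fields and their defaults.
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--model_name", type=str, default="pythia-70m-deduped")
    # A common argument shared by all evals.
    parser.add_argument("--output_folder", type=str, default="results")
    return parser


def config_from_args(args: argparse.Namespace) -> DemoEvalConfig:
    # Copy only the attributes that are fields of the config dataclass.
    return DemoEvalConfig(**{f.name: getattr(args, f.name) for f in fields(DemoEvalConfig)})


def run_eval(config: DemoEvalConfig, sae_ids: list, output_folder: str) -> list:
    # Write one JSON file per SAE; the results object is left empty in this sketch.
    os.makedirs(output_folder, exist_ok=True)
    paths = []
    for sae_id in sae_ids:
        path = os.path.join(output_folder, f"{sae_id}.json")
        record = {"sae_lens_id": sae_id, "eval_config": asdict(config), "eval_results": {}}
        with open(path, "w") as f:
            json.dump(record, f)
        paths.append(path)
    return paths


# Demo: parse an explicit argument list and write results to a temporary folder.
demo_args = arg_parser().parse_args(["--random_seed", "7"])
with tempfile.TemporaryDirectory() as tmp_dir:
    written = run_eval(config_from_args(demo_args), ["demo_sae"], tmp_dir)
    print([os.path.basename(p) for p in written])
```

Keeping the parser and the config dataclass side by side makes it easier to keep their argument names and defaults in sync.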

## CLI and Eval Config

The CLI takes a combination of common arguments (the same for all evals) and eval-specific arguments. Eval-specific arguments should match those in the `eval_config` of that sub-package. The common arguments should include:
- `sae_regex_pattern` and `sae_block_pattern`, regex patterns used to select SAEs from the SAE Lens library.
- `output_folder`, the folder to place the output in.
- `model_name`, for loading a model from TransformerLens.
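
As a rough illustration of how the two regex arguments could narrow down a list of SAEs, here is a minimal sketch using only the standard `re` module. The `select_saes` helper, the candidate list, and the choice of full-string matching are assumptions; the actual selection logic in SAE Bench may differ.

```python
import re

# Illustrative (release, sae_id) pairs; real values come from the SAE Lens library.
candidates = [
    ("sae_bench_pythia70m_sweep_standard_ctx128_0712", "blocks.4.hook_resid_post__trainer_0"),
    ("sae_bench_pythia70m_sweep_standard_ctx128_0712", "blocks.4.hook_resid_post__trainer_1"),
    ("sae_bench_pythia70m_sweep_standard_ctx128_0712", "blocks.3.hook_resid_post__trainer_10"),
    ("gpt2-small-res-jb", "blocks.4.hook_resid_pre"),
]


def select_saes(sae_regex_pattern, sae_block_pattern, pairs):
    """Keep the pairs whose release and SAE id both match their pattern."""
    release_re = re.compile(sae_regex_pattern)
    block_re = re.compile(sae_block_pattern)
    # Assumption: each pattern must match the whole string.
    return [
        (release, sae_id)
        for release, sae_id in pairs
        if release_re.fullmatch(release) and block_re.fullmatch(sae_id)
    ]


selected = select_saes(
    "sae_bench_pythia70m_sweep_standard_ctx128_0712",
    r"blocks.4.hook_resid_post__trainer_.*",
    candidates,
)
for release, sae_id in selected:
    print(release, sae_id)
```

With these patterns, only the two layer-4 entries of the Pythia release are kept.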

To see which SAEs you can select via the regex arguments, use the SAE selection utils like this:

```python
from sae_bench_utils.sae_selection_utils import print_all_sae_releases, print_release_details

# Callum came up with this format which I like visually.
# Each release has a corresponding model / repo_id. We recommend you don't select
# releases with different models when running evals.
print_all_sae_releases()
# Each release has a number of possible SAEs.
print_release_details('gpt2-small-res-jb')
```


The first rows of the output (truncated):

```text
┌─────────────────────────────────────┬─────────────────────────────────────────────────────┬────────────────────────────────────────────────────────┬──────────┐
│ model                               │ release                                             │ repo_id                                                │   n_saes │
├─────────────────────────────────────┼─────────────────────────────────────────────────────┼────────────────────────────────────────────────────────┼──────────┤
│ gemma-2-27b                         │ gemma-scope-27b-pt-res                              │ google/gemma-scope-27b-pt-res                          │       18 │
│ gemma-2-27b                         │ gemma-scope-27b-pt-res-canonical                    │ google/gemma-scope-27b-pt-res                          │        3 │
│ gemma-2-2b                          │ gemma-scope-2b-pt-res                               │ google/gemma-scope-2b-pt-res                           │      310 │
...
```

Here's an example call to the absorption eval. Note that we are selecting just one release / SAE (though we could select more) and that we're using the defaults for the eval-specific args by not setting them via the CLI.

```bash
python evals/absorption/main.py \
--sae_regex_pattern "sae_bench_pythia70m_sweep_standard_ctx128_0712" \
--sae_block_pattern "blocks.4.hook_resid_post__trainer_.*" \
--model_name pythia-70m-deduped \
--output_folder results
```

To create such an interface, define an `arg_parser` function in `main.py`, as in the example below, and define an `EvalConfig` in an `eval_config.py` file inside the eval sub-package. Eval configs should be dataclasses whose values are serializable, so that they are easy to save and load.

You can test whether you've set up the `EvalConfig` and `arg_parser` correctly by using the `validate_eval_cli_interface` testing util. Feel free to change the CLI args / Eval Config to test the validation.

```python
import argparse
from dataclasses import dataclass
from sae_bench_utils.testing_utils import validate_eval_cli_interface

def arg_parser():
    parser = argparse.ArgumentParser(description="Run absorption evaluation")
    # Eval-specific arguments: names and defaults mirror the EvalConfig fields below.
    parser.add_argument("--arg1", type=int, default=42, help="Description for arg1")
    parser.add_argument("--arg2", type=float, default=0.03, help="Description for arg2")
    parser.add_argument("--arg3", type=int, default=10, help="Description for arg3")
    parser.add_argument("--arg4", type=str, default="{word} has the first letter:", help="Description for arg4")
    parser.add_argument("--arg5", type=int, default=-6, help="Description for arg5")
    parser.add_argument("--model_name", type=str, default="pythia-70m-deduped", help="Description for model_name")
    # Common arguments shared by all evals.
    parser.add_argument("--sae_regex_pattern", type=str, required=True, help="Regex pattern for SAE selection")
    parser.add_argument("--sae_block_pattern", type=str, required=True, help="Regex pattern for SAE block selection")
    parser.add_argument("--output_folder", type=str, default="evals/absorption/results", help="Output folder")
    parser.add_argument("--force_rerun", action="store_true", help="Force rerun of experiments")

    return parser

@dataclass
class EvalConfig:
    arg1: int = 42
    arg2: float = 0.03
    arg3: int = 10
    arg4: str = "{word} has the first letter:"
    arg5: int = -6
    model_name: str = "pythia-70m-deduped"

validate_eval_cli_interface(arg_parser(), eval_config_cls=EvalConfig)
```
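
Because eval configs are meant to be dataclasses with serializable values, saving and loading one can be sketched with the standard library alone. `DemoEvalConfig` is a hypothetical config whose field names are borrowed from the absorption example in this document; it is not the real absorption `EvalConfig`.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class DemoEvalConfig:
    # Only JSON-friendly types (int, float, str), so the config round-trips cleanly.
    random_seed: int = 42
    f1_jump_threshold: float = 0.03
    prompt_template: str = "{word} has the first letter:"
    model_name: str = "pythia-70m-deduped"


config = DemoEvalConfig(random_seed=7)

# Save: convert the dataclass to a plain dict, then to a JSON string.
serialized = json.dumps(asdict(config), indent=2)

# Load: rebuild the dataclass from the parsed dict.
restored = DemoEvalConfig(**json.loads(serialized))

print(restored == config)
```

Fields holding non-serializable objects (for example a loaded model) would break this round trip, which is one reason to keep such objects out of the config.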


## Output Format

Each output JSON should correspond to one SAE and have the following fields:

- `eval_instance_id`: A unique UUID identifying this specific eval run.
- `sae_lens_release`: The release identifier of the SAE from SAE Lens.
- `sae_lens_id`: The specific identifier of the SAE within the release.
- `eval_type_id`: The type of evaluation being performed.
- `sae_lens_version`: Version of SAE Lens used.
- `sae_bench_version`: Version of SAE Bench used.
- `date_time`: Timestamp when the eval was run.
- `eval_config`: Object containing all configuration parameters used for this eval.
- `eval_results`: Object containing the results of the evaluation.

See an example here:

```json
{
    "eval_instance_id": "0c057d5e-973e-410e-8e32-32569323b5e6",
    "sae_lens_release": "sae_bench_pythia70m_sweep_standard_ctx128_0712",
    "sae_lens_id": "blocks.3.hook_resid_post__trainer_10",
    "eval_type_id": "absorption",
    "sae_lens_version": "4.0.0",
    "sae_bench_version": "57e9be0ac9199dba6b9f87fe92f80532e9aefced",
    "date_time": "2024-10-22T10:46:36.799610",
    "eval_config": {
        "random_seed": 42,
        "f1_jump_threshold": 0.03,
        "max_k_value": 10,
        "prompt_template": "{word} has the first letter:",
        "prompt_token_pos": -6,
        "model_name": "pythia-70m-deduped"
    },
    "eval_results": {...}
}
```
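
A record with these fields could be assembled along the following lines. This is a minimal sketch: `build_eval_output` and `DemoEvalConfig` are hypothetical names, and the version strings are passed in by the caller because how SAE Bench determines them is not described here.

```python
import json
import uuid
from dataclasses import asdict, dataclass
from datetime import datetime


@dataclass
class DemoEvalConfig:
    # Hypothetical config with two of the fields from the example above.
    random_seed: int = 42
    model_name: str = "pythia-70m-deduped"


def build_eval_output(config, sae_lens_release, sae_lens_id, eval_type_id,
                      eval_results, sae_lens_version, sae_bench_version):
    """Assemble one per-SAE output record with the fields listed above."""
    return {
        "eval_instance_id": str(uuid.uuid4()),    # fresh UUID for each run
        "sae_lens_release": sae_lens_release,
        "sae_lens_id": sae_lens_id,
        "eval_type_id": eval_type_id,
        "sae_lens_version": sae_lens_version,
        "sae_bench_version": sae_bench_version,
        "date_time": datetime.now().isoformat(),  # ISO 8601 timestamp
        "eval_config": asdict(config),
        "eval_results": eval_results,
    }


record = build_eval_output(
    DemoEvalConfig(),
    sae_lens_release="sae_bench_pythia70m_sweep_standard_ctx128_0712",
    sae_lens_id="blocks.3.hook_resid_post__trainer_10",
    eval_type_id="absorption",
    eval_results={},
    sae_lens_version="4.0.0",
    sae_bench_version="placeholder-version",  # placeholder, not a real version
)
print(json.dumps(record, indent=4))
```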

To validate an output JSON, use the `validate_eval_output_format` function from our testing utils. Feel free to break the JSON and watch the test fail (e.g. remove a field like `sae_lens_release`).

```python
import json
import os
from sae_bench_utils.testing_utils import validate_eval_output_format

eval_results_temp = {
    "eval_instance_id": "0c057d5e-973e-410e-8e32-32569323b5e6",
    "sae_lens_release": "sae_bench_pythia70m_sweep_standard_ctx128_0712",
    "sae_lens_id": "blocks.3.hook_resid_post__trainer_10",
    "eval_type_id": "absorption",
    "sae_lens_version": "4.0.0",
    "sae_bench_version": "57e9be0ac9199dba6b9f87fe92f80532e9aefced",
    "date_time": "2024-10-22T10:46:36.799610",
    "eval_config": {
        "random_seed": 42,
        "f1_jump_threshold": 0.03,
        "max_k_value": 10,
        "prompt_template": "{word} has the first letter:",
        "prompt_token_pos": -6,
        "model_name": "pythia-70m-deduped"
    },
    "eval_results": {},
}

# save to file
with open('eval_results_temp.json', 'w') as f:
    json.dump(eval_results_temp, f)

validate_eval_output_format('eval_results_temp.json', eval_type="absorption")

# delete file
os.remove('eval_results_temp.json')
```

We can then load the eval result JSONs across many different SAEs and see clearly which evals were run with which parameters and code.
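
Loading a folder of result files and listing what was run could look roughly like this. The helpers `load_eval_results` and `summarize` are hypothetical, and the flat folder of `*.json` files is an assumption about how results are stored.

```python
import glob
import json
import os
import tempfile


def load_eval_results(folder):
    """Load every JSON file in `folder` into a list of dicts."""
    results = []
    for path in sorted(glob.glob(os.path.join(folder, "*.json"))):
        with open(path) as f:
            results.append(json.load(f))
    return results


def summarize(results):
    """One row per result: which SAE, which eval, which code version, which config."""
    return [
        (r["sae_lens_id"], r["eval_type_id"], r["sae_bench_version"], r["eval_config"])
        for r in results
    ]


# Demo with two made-up result files in a temporary folder.
with tempfile.TemporaryDirectory() as tmp_dir:
    for trainer, seed in [(0, 42), (1, 43)]:
        demo = {
            "sae_lens_id": f"blocks.4.hook_resid_post__trainer_{trainer}",
            "eval_type_id": "absorption",
            "sae_bench_version": "placeholder-version",
            "eval_config": {"random_seed": seed},
            "eval_results": {},
        }
        with open(os.path.join(tmp_dir, f"result_{trainer}.json"), "w") as f:
            json.dump(demo, f)

    rows = summarize(load_eval_results(tmp_dir))

for row in rows:
    print(row)
```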


## A note on cached results

A variety of evals can share the results of intermediate computations, such as model activations or trained probes. Most of these are specific to a model and hook point, so they should be saved under a path of the format `f'{artifact_dir}/{eval_type}/{model}/{hook_point}/{artifact_id}'`.
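
A minimal sketch of caching under that path format follows. The helpers `artifact_path` and `load_or_compute`, the use of `pickle`, and the example hook point and artifact name are all assumptions for illustration; real artifacts such as activations may be stored in a different format.

```python
import os
import pickle
import tempfile


def artifact_path(artifact_dir, eval_type, model, hook_point, artifact_id):
    # Same layout as the path format described above.
    return f"{artifact_dir}/{eval_type}/{model}/{hook_point}/{artifact_id}"


def load_or_compute(path, compute_fn):
    """Return the cached artifact at `path`, computing and saving it if missing."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    value = compute_fn()
    os.makedirs(os.path.dirname(path), exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(value, f)
    return value


compute_calls = []


def fake_activations():
    # Stand-in for an expensive computation such as collecting model activations.
    compute_calls.append(1)
    return [0.1, 0.2, 0.3]


with tempfile.TemporaryDirectory() as tmp_dir:
    path = artifact_path(
        tmp_dir, "absorption", "pythia-70m-deduped", "blocks.4.hook_resid_post", "activations.pkl"
    )
    first = load_or_compute(path, fake_activations)   # computed and written to disk
    second = load_or_compute(path, fake_activations)  # read back from the cache

print(len(compute_calls), first == second)
```

Keying the path on model and hook point lets different evals that need the same activations reuse one cached copy.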

