Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
155 commits
Select commit Hold shift + click to select a range
000203f
Cleaning up autoencoder loader
Oct 3, 2024
c42a322
Preparing generation scorer
Oct 11, 2024
96b8a74
Re-doing autoencoders
Oct 11, 2024
8a47fbb
Probability and conditional probability
Oct 11, 2024
da30b38
First pass on clients
Oct 11, 2024
f59aadd
Redoing load_tokenized_data
Oct 11, 2024
25839f1
Small tweaks
Oct 16, 2024
e652874
update gitignore
Oct 16, 2024
cb3ce63
Merge branch 'main' of https://github.com/EleutherAI/sae-auto-interp …
Oct 16, 2024
264e3fb
Testing out api
Oct 22, 2024
e2aa838
Making a explanation server api
Oct 23, 2024
6e0ca1a
Fixes in the prompts
Nov 8, 2024
3df8160
Early exiting
Nov 8, 2024
131efff
Str_toks and text
Nov 8, 2024
0fbbf09
Remove extra keyword
Nov 8, 2024
fbcd7ec
Delete wrong example
Nov 8, 2024
8ea7df7
Autoencoder's stuff
Nov 13, 2024
ad77d3e
Making new pipeline
Nov 13, 2024
35d8e3c
Transcoder
Dec 11, 2024
f296d36
Dataset collumn
Dec 11, 2024
06ebde1
Breaking change: save tokens in cache, make it the primary source of …
neverix Jan 13, 2025
0f77e43
Fixing loader bug
Jan 13, 2025
8d08dbe
Merge branch 'v0.2' into v0.2-save-cache
Jan 17, 2025
38e81e8
Merge pull request #43 from EleutherAI:v0.2-save-cache
SrGonao Jan 17, 2025
162f7d1
tokens in, not hasattr
Jan 17, 2025
79272a3
Sensible default config
Jan 17, 2025
34c8748
Remove debug print
Jan 17, 2025
43c2d9f
Update pyproject.toml
SrGonao Jan 31, 2025
922dc14
Make token_loader None by default
neverix Jan 31, 2025
49fc9cf
Merge branch 'v0.2' of https://github.com/EleutherAI/sae-auto-interp …
Feb 4, 2025
640145e
Abstract load function
Feb 4, 2025
d468e0d
Naive implementation
Feb 4, 2025
5edc76b
Merge branch 'main' of https://github.com/eleutherai/delphi into v0.2
Feb 5, 2025
360b79a
Update name
Feb 5, 2025
6015908
Update config name
Feb 6, 2025
0d30fe5
Update config name and sae
Feb 6, 2025
bb9f2a4
idk
Feb 6, 2025
49517e9
Add neighbour transform
Feb 6, 2025
e6fa6f4
Add neighbour calculator
Feb 6, 2025
2c77007
Add neighbour constructor and update feature record
Feb 6, 2025
16bb4b0
Remove unused method
Feb 6, 2025
f3781b5
Make an abstract cache, and an activation cache
Feb 6, 2025
cc7d5d6
small update to autoencoders
Feb 6, 2025
da8fa2a
Handle distance in detection and fuzzing
Feb 10, 2025
c3d1d97
Correctly create neighbours in transform
Feb 10, 2025
c29c2cb
Return neighbour distance
Feb 10, 2025
7a56a5d
Add all data temporary fix
Feb 10, 2025
941cdc2
Use all data
Feb 10, 2025
185a6d6
Small fixes
Feb 10, 2025
4f17bce
Update from main
Feb 10, 2025
b621f82
Change batch to n_examples
Feb 10, 2025
560e580
Merge branch 'main' of https://github.com/eleutherai/delphi into v0.2
Feb 10, 2025
d044d90
Merge branch 'main' of https://github.com/eleutherai/delphi into v0.2
Feb 10, 2025
60f06f8
Delete extra folder
Feb 10, 2025
40d9d86
Keep experiments in legacy code, stop updating them
Feb 10, 2025
83fdbf8
random -> non_activating
Feb 10, 2025
83b77bb
fix gemma loader
Feb 10, 2025
92be5ea
Remove old key
Feb 10, 2025
ae82aa2
None instead of -1
Feb 10, 2025
eb964e3
Update feature -> latent
Feb 10, 2025
be82d15
Merge pull request #59 from EleutherAI:v0.2
SrGonao Feb 10, 2025
b65c5f9
log results in __main__ by default
luciaquirke Feb 11, 2025
ea3916e
Update README.md
luciaquirke Feb 11, 2025
9cf4a17
Update README.md
luciaquirke Feb 11, 2025
544115c
Update README.md
luciaquirke Feb 11, 2025
60ff74f
Update README.md
luciaquirke Feb 11, 2025
679301c
Update README.md
luciaquirke Feb 11, 2025
e451270
Update README.md
luciaquirke Feb 11, 2025
ba25408
Update README.md
luciaquirke Feb 11, 2025
c274e76
Update README.md
luciaquirke Feb 11, 2025
329d0f5
Update README.md
luciaquirke Feb 11, 2025
fede84f
fix bugs
luciaquirke Feb 11, 2025
b42fade
Merge pull request #60 from EleutherAI/log-results
luciaquirke Feb 11, 2025
2655389
Fix v0.2 bug
luciaquirke Feb 11, 2025
137dd7a
Small fixes
luciaquirke Feb 11, 2025
b5a02fe
Merge branch 'main' of https://github.com/eleutherai/delphi into adve…
Feb 11, 2025
331391b
Feature -> Latent
Feb 11, 2025
809c010
Add to the docstring
Feb 11, 2025
b419849
tokens doesn't need to be in BufferOutput
Feb 11, 2025
d0b499b
Adding correct functions to init
Feb 11, 2025
70442ab
Reformulating constructors
Feb 11, 2025
93120cd
Add semantic index
luciaquirke Feb 11, 2025
aca072c
Fail in a smarter way
Feb 11, 2025
7845391
Fix constructor + update neighbours
Feb 11, 2025
d3b3d71
Update fuzz and detection logic for neighbours
Feb 11, 2025
873f185
Fix neighbours non activating
Feb 11, 2025
60623d0
Moving code to class
Feb 11, 2025
910f45b
Mask name change
Feb 12, 2025
498a0bc
Remove debug pring
Feb 12, 2025
c878733
Deal with the case where there's only active
Feb 12, 2025
ee2d98b
typehint
Feb 12, 2025
abda33e
Merge branch 'faiss' of https://github.com/EleutherAI/delphi into adv…
Feb 12, 2025
920d669
Working on semantic_index
Feb 13, 2025
54a50ad
Merge branch 'main' of https://github.com/EleutherAI/delphi into adve…
Feb 17, 2025
f94e05e
Merge branch 'main' of https://github.com/EleutherAI/delphi into adve…
Feb 17, 2025
8730b6c
New constructor function
Feb 17, 2025
a89f1b7
Correct format explainer
Feb 17, 2025
672c114
Remove dataset loader
Feb 17, 2025
edeaf2a
Create all data tensor, rename dataclasses
Feb 17, 2025
44fbeea
pre-commit stuff
Feb 17, 2025
b17d7e8
Remove code that shouldn't be here
Feb 17, 2025
5ec2b70
Added new non_activating_source argument
Feb 17, 2025
12af689
else bug
Feb 17, 2025
b8ddc8e
circular import
Feb 17, 2025
9a166f6
also circular import fix
Feb 17, 2025
c849aa0
Making constructor work. Loader changes
Feb 17, 2025
1126c96
Adding neighbours to main, config
Feb 17, 2025
90c5a36
Merge branch 'main' of https://github.com/EleutherAI/delphi into neig…
Feb 18, 2025
46bfcb9
Fix all typing errors not envolving Nones or TensorTypes (I think)
Feb 18, 2025
c344cb2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2025
aba7e0c
Merge branch 'main' of https://github.com/EleutherAI/delphi into neig…
Feb 18, 2025
6ebbeea
Merge branch 'fix_types' of https://github.com/EleutherAI/delphi into…
Feb 18, 2025
4d3a174
torchtyping to jaxtyping and other typing fixes
Feb 18, 2025
b32012b
Merge branch 'fix_types' of https://github.com/EleutherAI/delphi into…
Feb 18, 2025
c0963b4
torchtyping -> jaxtyping
Feb 18, 2025
fd3fdb5
Merge branch 'fix_types' of https://github.com/EleutherAI/delphi into…
Feb 18, 2025
813a189
Reformulating neighbours
Feb 18, 2025
f001b41
Type hints and simplification
Feb 18, 2025
5b0ea4e
Last torchtypings
Feb 18, 2025
3f3f2f4
Merge branch 'fix_types' of https://github.com/EleutherAI/delphi into…
Feb 18, 2025
fa54ba4
Use utils tokenized_data
Feb 18, 2025
5346b13
ground_truth->activating
Feb 18, 2025
5b54f63
(Non)Activating examples as children of Examples
Feb 18, 2025
97d2dbc
Mostly type hints
Feb 18, 2025
b9be963
Mostly type hinys
Feb 18, 2025
21dc101
Using ActivatingExample
Feb 18, 2025
c898f16
Better names for things
Feb 18, 2025
f3b4178
Merge remote-tracking branch 'origin/main' into neighbour_latents
Feb 18, 2025
eb11854
Ruff stuff
Feb 18, 2025
222ab8b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 18, 2025
a232531
Shouldn't be here
Feb 18, 2025
035281e
Fixing typing, switching to torch, and removing comments
Feb 19, 2025
e09da34
Removing extra function, adding seed
Feb 19, 2025
6a5a7ac
Trying out dataclass
Feb 19, 2025
7fee61c
Dataclass things
Feb 19, 2025
d4abc2e
Adding tensor alias
Feb 19, 2025
1a85b54
Remove old comment
Feb 19, 2025
eb87744
Add defaults
Feb 19, 2025
8bfbe33
Remove cupy dependency,
Feb 19, 2025
4dda923
Correctly handle more than one hookpoint
Feb 19, 2025
c6d9dca
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2025
161b224
Remove debug print
Feb 19, 2025
f576dcf
add defaults
Feb 19, 2025
b913b96
Removing transforms
Feb 19, 2025
c0a7419
Removing useless code
Feb 19, 2025
12ec375
Not changing latent record in place
Feb 19, 2025
1b78ce3
Changing name
Feb 19, 2025
d5b0480
adding neighbour type
Feb 19, 2025
cffcb33
Not passing constructor or sampler anymore
Feb 19, 2025
1cdfd01
Fixing circular imports
Feb 19, 2025
e41edd3
Merge branch 'main' into neighbour_latents
SrGonao Feb 19, 2025
614473b
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2025
369eb3b
Adding encoder/decoder similarity neighbours
Feb 19, 2025
7b36be7
Merge branch 'neighbour_latents' of https://github.com/EleutherAI/del…
Feb 19, 2025
2c93969
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 19, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ Install this library as a local editable installation. Run the following command

To run the default pipeline from the command line, use the following command:

`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`
`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`

This command will:
1. Cache activations for the first 10 million tokens of EleutherAI/rpj-v2-sample.
Expand Down
93 changes: 64 additions & 29 deletions delphi/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,12 @@
from functools import partial
from glob import glob
from pathlib import Path
from typing import Callable, cast
from typing import Callable

import orjson
import torch
from datasets import load_dataset
from simple_parsing import ArgumentParser
from sparsify.data import chunk_and_tokenize
from torch import Tensor
from torchtyping import TensorType
from transformers import (
AutoModel,
AutoTokenizer,
Expand All @@ -26,12 +23,12 @@
from delphi.config import CacheConfig, ExperimentConfig, LatentConfig, RunConfig
from delphi.explainers import DefaultExplainer
from delphi.latents import LatentCache, LatentDataset
from delphi.latents.constructors import default_constructor
from delphi.latents.samplers import sample
from delphi.latents.neighbours import NeighbourCalculator
from delphi.log.result_analysis import log_results
from delphi.pipeline import Pipe, Pipeline, process_wrapper
from delphi.scorers import DetectionScorer, FuzzingScorer
from delphi.sparse_coders import load_hooks_sparse_coders
from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
from delphi.utils import load_tokenized_data


def load_artifacts(run_cfg: RunConfig):
Expand Down Expand Up @@ -59,6 +56,42 @@ def load_artifacts(run_cfg: RunConfig):
return run_cfg.hookpoints, hookpoint_to_sparse_encode, model


async def create_neighbours(
run_cfg: RunConfig,
latents_path: Path,
neighbours_path: Path,
hookpoints: list[str],
experiment_cfg: ExperimentConfig,
):
"""
Creates a neighbours file for the given hookpoints.
"""
neighbours_path.mkdir(parents=True, exist_ok=True)

if experiment_cfg.neighbours_type != "co-occurrence":
saes = load_sparse_coders(run_cfg, device="cuda")

for hookpoint in hookpoints:

if experiment_cfg.neighbours_type == "co-occurrence":
neighbour_calculator = NeighbourCalculator(
cache_dir=latents_path / hookpoint, number_of_neighbours=100
)

elif experiment_cfg.neighbours_type == "decoder_similarity":

neighbour_calculator = NeighbourCalculator(
autoencoder=saes[hookpoint], number_of_neighbours=100
)

elif experiment_cfg.neighbours_type == "encoder_similarity":
neighbour_calculator = NeighbourCalculator(
autoencoder=saes[hookpoint], number_of_neighbours=100
)
neighbour_calculator.populate_neighbour_cache(experiment_cfg.neighbours_type)
neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}")


async def process_cache(
latent_cfg: LatentConfig,
run_cfg: RunConfig,
Expand Down Expand Up @@ -88,25 +121,14 @@ async def process_cache(
latent_dict = {
hook: latent_range for hook in hookpoints
} # The latent range to explain
latent_dict = cast(dict[str, int | Tensor], latent_dict)

constructor = partial(
default_constructor,
token_loader=None,
n_not_active=experiment_cfg.n_non_activating,
ctx_len=experiment_cfg.example_ctx_len,
max_examples=latent_cfg.max_examples,
)
sampler = partial(sample, cfg=experiment_cfg)

dataset = LatentDataset(
raw_dir=str(latents_path),
cfg=latent_cfg,
latent_cfg=latent_cfg,
experiment_cfg=experiment_cfg,
modules=hookpoints,
latents=latent_dict,
tokenizer=tokenizer,
constructor=constructor,
sampler=sampler,
)

if run_cfg.explainer_provider == "offline":
Expand Down Expand Up @@ -214,14 +236,15 @@ def populate_cache(
"""
latents_path.mkdir(parents=True, exist_ok=True)

data = load_dataset(
cfg.dataset_repo, name=cfg.dataset_name, split=cfg.dataset_split
tokens = load_tokenized_data(
cfg.ctx_len,
tokenizer,
cfg.dataset_repo,
cfg.dataset_split,
cfg.dataset_name,
cfg.dataset_column,
run_cfg.seed,
)
data = data.shuffle(run_cfg.seed)
data = chunk_and_tokenize(
data, tokenizer, max_seq_len=cfg.ctx_len, text_key=cfg.dataset_column
)
tokens = data["input_ids"]

if run_cfg.filter_bos:
if tokenizer.bos_token_id is None:
Expand All @@ -235,8 +258,6 @@ def populate_cache(
]
tokens = truncated_tokens.reshape(-1, cfg.ctx_len)

tokens = cast(TensorType["batch", "seq"], tokens)

cache = LatentCache(
model,
hookpoint_to_sparse_encode,
Expand Down Expand Up @@ -271,6 +292,7 @@ async def run(
latents_path = base_path / "latents"
explanations_path = base_path / "explanations"
scores_path = base_path / "scores"
neighbours_path = base_path / "neighbours"
visualize_path = base_path / "visualize"

latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None
Expand All @@ -294,6 +316,19 @@ async def run(
print(f"Files found in {latents_path}, skipping cache population...")

del model, hookpoint_to_sparse_encode
if (
not glob(str(neighbours_path / ".*")) + glob(str(neighbours_path / "*"))
or "neighbours" in run_cfg.overwrite
):
await create_neighbours(
run_cfg,
latents_path,
neighbours_path,
hookpoints,
experiment_cfg,
)
else:
print(f"Files found in {neighbours_path}, skipping...")

if (
not glob(str(scores_path / ".*")) + glob(str(scores_path / "*"))
Expand Down
13 changes: 12 additions & 1 deletion delphi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,17 @@ class ExperimentConfig(Serializable):
test_type: Literal["quantiles", "activation"] = "quantiles"
"""Type of sampler to use for latent explanation testing."""

non_activating_source: Literal["random", "neighbours"] = "random"
"""Source of non-activating examples. Random uses non-activating contexts
sampled from any non activating window. Neighbours uses actvating contexts
from pre-computed latent neighbours. They are still non-activating but
have a higher chance of being similar to the activating examples."""

neighbours_type: Literal[
"co-occurrence", "decoder_similarity", "encoder_similarity"
] = "co-occurrence"
"""Type of neighbours to use. Only used if non_activating_source is 'neighbours'."""


@dataclass
class LatentConfig(Serializable):
Expand Down Expand Up @@ -145,6 +156,6 @@ class RunConfig:
scoring speed but can leak information to the fuzzing and detection scorer,
as well as increasing the scorer LLM task difficulty."""

overwrite: list[Literal["cache", "scores"]] = list_field()
overwrite: list[Literal["cache", "neighbours", "scores"]] = list_field()
"""List of run stages to recompute. This is a debugging tool
and may be removed in the future."""
122 changes: 20 additions & 102 deletions delphi/explainers/default/default.py
Original file line number Diff line number Diff line change
@@ -1,117 +1,35 @@
import asyncio
import re
from dataclasses import dataclass

from ...logger import logger
from ..explainer import Explainer, ExplainerResult
from ..explainer import ActivatingExample, Explainer
from .prompt_builder import build_prompt


@dataclass
class DefaultExplainer(Explainer):
name = "default"
activations: bool = True
"""Whether to show activations to the explainer."""
cot: bool = False
"""Whether to use chain of thought reasoning."""

def __init__(
self,
client,
tokenizer,
verbose: bool = False,
activations: bool = False,
cot: bool = False,
threshold: float = 0.6,
temperature: float = 0.0,
**generation_kwargs,
):
self.client = client
self.tokenizer = tokenizer
self.verbose = verbose

self.activations = activations
self.cot = cot
self.threshold = threshold
self.temperature = temperature
self.generation_kwargs = generation_kwargs

async def __call__(self, record):
messages = self._build_prompt(record.train)

response = await self.client.generate(
messages, temperature=self.temperature, **self.generation_kwargs
)

try:
explanation = self.parse_explanation(response.text)
if self.verbose:
logger.info(f"Explanation: {explanation}")
logger.info(f"Final message to explainer: {messages[-1]['content']}")
logger.info(f"Response from explainer: {response.text}")

return ExplainerResult(record=record, explanation=explanation)
except Exception as e:
logger.error(f"Explanation parsing failed: {e}")
return ExplainerResult(
record=record, explanation="Explanation could not be parsed."
)

def parse_explanation(self, text: str) -> str:
try:
match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL)
return (
match.group(1).strip() if match else "Explanation could not be parsed."
)
except Exception as e:
logger.error(f"Explanation parsing regex failed: {e}")
raise

def _highlight(self, index, example):
result = f"Example {index}: "

threshold = example.max_activation * self.threshold
if self.tokenizer is not None:
str_toks = self.tokenizer.batch_decode(example.tokens)
example.str_toks = str_toks
else:
str_toks = example.tokens
example.str_toks = str_toks
activations = example.activations

def check(i):
return activations[i] > threshold

i = 0
while i < len(str_toks):
if check(i):
result += "<<"

while i < len(str_toks) and check(i):
result += str_toks[i]
i += 1
result += ">>"
else:
result += str_toks[i]
i += 1

return "".join(result)

def _join_activations(self, example):
activations = []

for i, activation in enumerate(example.activations):
if activation > example.max_activation * self.threshold:
activations.append(
(example.str_toks[i], int(example.normalized_activations[i]))
)

acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations)

return "Activations: " + acts

def _build_prompt(self, examples):
def _build_prompt(self, examples: list[ActivatingExample]) -> list[dict]:
highlighted_examples = []

for i, example in enumerate(examples):
highlighted_examples.append(self._highlight(i + 1, example))
str_toks = self.tokenizer.batch_decode(example.tokens)
activations = example.activations.tolist()
highlighted_examples.append(self._highlight(str_toks, activations))

if self.activations:
highlighted_examples.append(self._join_activations(example))
assert (
example.normalized_activations is not None
), "Normalized activations are required for activations in explainer"
normalized_activations = example.normalized_activations.tolist()
highlighted_examples.append(
self._join_activations(
str_toks, activations, normalized_activations
)
)

highlighted_examples = "\n".join(highlighted_examples)

Expand Down
6 changes: 3 additions & 3 deletions delphi/explainers/default/prompt_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ def build_examples(


def build_prompt(
examples,
examples: str,
activations: bool = False,
cot: bool = False,
):
) -> list[dict]:
messages = system(
cot=cot,
)
Expand All @@ -49,7 +49,7 @@ def build_prompt(
"content": user_start,
}
)

return messages


Expand Down
Loading