From 715d963b6e8f56494d7ddfba46563bb93a402498 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Tue, 4 Mar 2025 23:13:41 +0000
Subject: [PATCH 1/6] Add pyright to CI

---
 .github/workflows/tests.yml | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index daf7e650..c8deb68a 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -27,3 +27,8 @@ jobs:

       - name: Run tests
         run: pytest
+
+      - name: Type Checking
+        uses: jakebailey/pyright-action@v1
+        with:
+          version: 1.1.378

From d86b611b348e31308234dc13942ca532b1fc0405 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Tue, 4 Mar 2025 23:45:11 +0000
Subject: [PATCH 2/6] fix minor type issues; add type ignore to loosely typed files

---
 delphi/__main__.py                            | 23 ++++++---
 delphi/latents/latents.py                     |  5 +-
 .../oai_autointerp/explanations/simulator.py  | 48 +++++++++----------
 delphi/scorers/surprisal/surprisal.py         | 20 +++++---
 delphi/sparse_coders/load_sparsify.py         | 14 +++---
 delphi/sparse_coders/sparse_model.py          |  3 +-
 delphi/tests/e2e.py                           |  3 +-
 delphi/utils.py                               | 13 +++++
 8 files changed, 82 insertions(+), 47 deletions(-)

diff --git a/delphi/__main__.py b/delphi/__main__.py
index d4ea1259..19d9c7c3 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -26,7 +26,7 @@
 from delphi.pipeline import Pipe, Pipeline, process_wrapper
 from delphi.scorers import DetectionScorer, FuzzingScorer
 from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
-from delphi.utils import load_tokenized_data
+from delphi.utils import assert_type, load_tokenized_data


 def load_artifacts(run_cfg: RunConfig):
@@ -325,8 +325,11 @@ async def run(
     hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
     tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

-    nrh = non_redundant_hookpoints(
-        hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
+    nrh = assert_type(
+        dict,
+        non_redundant_hookpoints(
+            hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
+        ),
     )
     if nrh:
         populate_cache(
@@ -340,8 +343,11 @@ async def run(
     del model, hookpoint_to_sparse_encode

     if run_cfg.constructor_cfg.non_activating_source == "neighbours":
-        nrh = non_redundant_hookpoints(
-            hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
+        nrh = assert_type(
+            list,
+            non_redundant_hookpoints(
+                hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
+            ),
         )
         if nrh:
             create_neighbours(
@@ -353,8 +359,11 @@ async def run(
     else:
         print("Skipping neighbour creation")

-    nrh = non_redundant_hookpoints(
-        hookpoints, scores_path, "scores" in run_cfg.overwrite
+    nrh = assert_type(
+        list,
+        non_redundant_hookpoints(
+            hookpoints, scores_path, "scores" in run_cfg.overwrite
+        ),
     )
     if nrh:
         await process_cache(
diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index a68f146b..b5eb6030 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -134,7 +134,7 @@ class LatentRecord:
     train: list[ActivatingExample] = field(default_factory=list)
     """Training examples."""

-    test: list[ActivatingExample] = field(default_factory=list)
+    test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
     """Test examples."""

     neighbours: list[Neighbour] = field(default_factory=list)
@@ -143,6 +143,9 @@
     explanation: str = ""
     """Explanation of the latent."""

+    extra_examples: Optional[list[Example]] = None
+    """Extra examples to include in the record."""
+
     @property
     def max_activation(self) -> float:
         """
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
index eb29e0dd..6e7d7a7b 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
@@ -107,8 +107,8 @@ def parse_top_logprobs(top_logprobs: dict[str, float]) -> OrderedDict[int, float
     """
     probabilities_by_distribution_value = OrderedDict()
     for token, contents in top_logprobs.items():
-        logprob = contents.logprob
-        decoded_token = contents.decoded_token
+        logprob = contents.logprob  # type: ignore
+        decoded_token = contents.decoded_token  # type: ignore
         if decoded_token in VALID_ACTIVATION_TOKENS:
             token_as_int = int(decoded_token)
             probabilities_by_distribution_value[token_as_int] = np.exp(logprob)
@@ -134,7 +134,7 @@ def compute_predicted_activation_stats_for_token(


 def parse_simulation_response(
-    response: dict[str, Any],
+    response: Any,
     tokenized_prompt: list[int],
     tab_token: int,
     tokens: Sequence[str],
@@ -250,11 +250,11 @@ async def simulate(
         else:
             assert isinstance(prompt, str)

-        response = await self.client.generate(prompt, **sampling_params)
-        tokenized_prompt = self.client.tokenizer.apply_chat_template(
+        response = await self.client.generate(prompt, **sampling_params)  # type: ignore
+        tokenized_prompt = self.client.tokenizer.apply_chat_template(  # type: ignore
             prompt, add_generation_prompt=True
         )
-        tab_token = self.client.tokenizer.encode("\t")[1]
+        tab_token = self.client.tokenizer.encode("\t")[1]  # type: ignore
         logger.debug("response in score_explanation_by_activations is %s", response)
         try:
             result = parse_simulation_response(
@@ -287,7 +287,7 @@ def make_simulation_prompt(
         # Consider reconciling them.
         prompt_builder = PromptBuilder()
         prompt_builder.add_message(
-            "system",
+            "system",  # type: ignore
             """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document.
 Look at summary of what the neuron does, and try to predict how it will fire on each token.
@@ -299,7 +299,7 @@ def make_simulation_prompt(
         few_shot_examples = self.few_shot_example_set.get_examples()
         for i, example in enumerate(few_shot_examples):
             prompt_builder.add_message(
-                "user",
+                "user",  # type: ignore
                 f"\n\nNeuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX}"
                 f"{example.explanation}",
             )
@@ -309,17 +309,17 @@ def make_simulation_prompt(
                 start_indices=example.first_revealed_activation_indices,
             )
             prompt_builder.add_message(
-                "assistant", f"\nActivations: {formatted_activation_records}\n"
+                "assistant", f"\nActivations: {formatted_activation_records}\n"  # type: ignore
             )

         prompt_builder.add_message(
-            "user",
+            "user",  # type: ignore
             f"\n\nNeuron {len(few_shot_examples) + 1}\nExplanation of neuron "
             f"{len(few_shot_examples) + 1} behavior: {EXPLANATION_PREFIX} "
             f"{self.explanation.strip()}",
         )
         prompt_builder.add_message(
-            "assistant",
+            "assistant",  # type: ignore
             f"\nActivations: {format_sequences_for_simulation([tokens])}",
         )
         return prompt_builder.build(self.prompt_format)
@@ -595,7 +595,7 @@ async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:
         result = SequenceSimulation(
             activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
-            expected_activations=predicted_activations,
+            expected_activations=predicted_activations,  # type: ignore
             # Since the predicted activation is just a sampled token, we don't have a distribution.
             distribution_values=[],
             distribution_probabilities=[],
@@ -614,7 +614,7 @@ def _make_simulation_prompt_json(
     assert explanation != ""
     prompt_builder = PromptBuilder()
     prompt_builder.add_message(
-        "system",
+        "system",  # type: ignore
         """We're studying neurons in a neural network. Each neuron looks for certain things in a short document. Your task is to read the explanation of what the neuron does, and predict the neuron's activations for each token in the document.

For each document, you will see the full text of the document, then the tokens in the document with the activation left blank. You will print, in valid json, the exact same tokens verbatim, but with the activation values filled in according to the explanation. Pay special attention to the explanation's description of the context and order of tokens or words.
@@ -638,7 +638,7 @@ def _make_simulation_prompt_json(
 }
 """
     prompt_builder.add_message(
-        "user",
+        "user",  # type: ignore
         _format_record_for_logprob_free_simulation_json(
             explanation=example.explanation,
             activation_record=example.activation_records[0],
@@ -658,7 +658,7 @@ def _make_simulation_prompt_json(
 }
 """
     prompt_builder.add_message(
-        "assistant",
+        "assistant",  # type: ignore
         _format_record_for_logprob_free_simulation_json(
             explanation=example.explanation,
             activation_record=example.activation_records[0],
@@ -678,10 +678,10 @@ def _make_simulation_prompt_json(
 }
 """
     prompt_builder.add_message(
-        "user",
+        "user",  # type: ignore
         _format_record_for_logprob_free_simulation_json(
             explanation=explanation,
-            activation_record=ActivationRecord(tokens=tokens, activations=[]),
+            activation_record=ActivationRecord(tokens=tokens, activations=[]),  # type: ignore
             include_activations=False,
         ),
     )
@@ -698,7 +698,7 @@ def _make_simulation_prompt(
     assert explanation != ""
     prompt_builder = PromptBuilder()
     prompt_builder.add_message(
-        "system",
+        "system",  # type: ignore
         """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

The activation format is token<tab>activation, and activations range from 0 to 10. Most activations will be 0.
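
A note on the `# type: ignore` comments scattered over these `add_message` calls: they all silence the same diagnostic, namely that the role argument of `prompt_builder.add_message` is typed more strictly than the plain strings passed here. A sketch of a signature that would make the per-call ignores unnecessary, assuming `PromptBuilder` is free to change; the `Role` alias below is illustrative, not the actual delphi API:

    from typing import Literal

    Role = Literal["system", "user", "assistant"]

    class PromptBuilder:
        def __init__(self) -> None:
            self.messages: list[dict[str, str]] = []

        def add_message(self, role: Role, content: str) -> None:
            # Pyright narrows the string literals "system", "user", and
            # "assistant" to Role at each call site, so calls like the
            # ones in this diff would type-check without ignores.
            self.messages.append({"role": role, "content": content})
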
@@ -716,7 +716,7 @@ def _make_simulation_prompt(
         example.activation_records[0], include_activations=False
     )
     prompt_builder.add_message(
-        "user",
+        "user",  # type: ignore
         f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
         f"{example.explanation}\n\n"
         f"Sequence 1 Tokens without Activations:\n{tokens_without_activations}\n\n"
@@ -728,7 +728,7 @@ def _make_simulation_prompt(
         max_activation=few_shot_example_max_activation,
     )
     prompt_builder.add_message(
-        "assistant",
+        "assistant",  # type: ignore
         f"{tokens_with_activations}\n\n",
     )

@@ -737,7 +737,7 @@ def _make_simulation_prompt(
             record, include_activations=False
         )
         prompt_builder.add_message(
-            "user",
+            "user",  # type: ignore
             f"Sequence {record_index + 2} Tokens without Activations:\n{tks_without}\n\n"
             f"Sequence {record_index + 2} Tokens with Activations:\n",
         )
@@ -747,16 +747,16 @@ def _make_simulation_prompt(
             max_activation=few_shot_example_max_activation,
         )
         prompt_builder.add_message(
-            "assistant",
+            "assistant",  # type: ignore
             f"{tokens_with_activations}\n\n",
         )

     neuron_index = len(few_shot_examples) + 1
     tokens_without_activations = _format_record_for_logprob_free_simulation(
-        ActivationRecord(tokens=tokens, activations=[]), include_activations=False
+        ActivationRecord(tokens=tokens, activations=[]), include_activations=False  # type: ignore
     )
     prompt_builder.add_message(
-        "user",
+        "user",  # type: ignore
         f"Neuron {neuron_index}\nExplanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} "
         f"{explanation}\n\n"
         f"Sequence 1 Tokens without Activations:\n{tokens_without_activations}\n\n"
diff --git a/delphi/scorers/surprisal/surprisal.py b/delphi/scorers/surprisal/surprisal.py
index da0ccfeb..1231450e 100644
--- a/delphi/scorers/surprisal/surprisal.py
+++ b/delphi/scorers/surprisal/surprisal.py
@@ -3,10 +3,13 @@
 from typing import NamedTuple

 import torch
+from simple_parsing import field
 from torch.nn.functional import cross_entropy
 from transformers import PreTrainedTokenizer

-from ...latents import Example, LatentRecord
+from delphi.utils import assert_type
+
+from ...latents import ActivatingExample, Example, LatentRecord
 from ..scorer import Scorer, ScorerResult
 from .prompts import BASEPROMPT as base_prompt

@@ -19,13 +22,13 @@ class SurprisalOutput:
     distance: float | int
     """Quantile or neighbor distance"""

-    no_explanation: list[float] = 0
+    no_explanation: list[float] = field(default_factory=list)
     """What is the surprisal of the model with no explanation"""

-    explanation: list[float] = 0
+    explanation: list[float] = field(default_factory=list)
     """What is the surprisal of the model with an explanation"""

-    activations: list[float] = 0
+    activations: list[float] = field(default_factory=list)
     """What are the activations of the model"""

@@ -55,7 +58,7 @@ def __init__(
     async def __call__(
         self,
         record: LatentRecord,
-    ) -> list[SurprisalOutput]:
+    ) -> ScorerResult:
         samples = self._prepare(record)

         random.shuffle(samples)
@@ -66,7 +69,7 @@ async def __call__(

         return ScorerResult(record=record, score=results)

-    def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:
         """
         Prepare and shuffle a list of samples for classification.
""" @@ -74,6 +77,8 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]: defaults = { "tokenizer": self.tokenizer, } + + assert record.extra_examples is not None, "No extra examples provided" samples = examples_to_samples( record.extra_examples, distance=-1, @@ -81,6 +86,7 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]: ) for i, examples in enumerate(record.test): + examples = assert_type(list, examples) samples.extend( examples_to_samples( examples, @@ -181,7 +187,7 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[SurprisalOutpu def examples_to_samples( - examples: list[Example], + examples: list[Example] | list[ActivatingExample], tokenizer: PreTrainedTokenizer, **sample_kwargs, ) -> list[Sample]: diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py index c9c259e3..1d4bfa9c 100644 --- a/delphi/sparse_coders/load_sparsify.py +++ b/delphi/sparse_coders/load_sparsify.py @@ -3,19 +3,21 @@ from typing import Callable import torch -from sparsify import Sae +from sparsify import SparseCoder from torch import Tensor from transformers import PreTrainedModel -def sae_dense_latents(x: Tensor, sae: Sae) -> Tensor: +def sae_dense_latents(x: Tensor, sae: SparseCoder) -> Tensor: """Run `sae` on `x`, yielding the dense activations.""" pre_acts = sae.pre_acts(x) acts, indices = sae.select_topk(pre_acts) return torch.zeros_like(pre_acts).scatter_(-1, indices, acts) -def resolve_path(model: PreTrainedModel, path_segments: list[str]) -> list[str] | None: +def resolve_path( + model: PreTrainedModel | torch.nn.Module, path_segments: list[str] +) -> list[str] | None: """Attempt to resolve the path segments to the model in the case where it has been wrapped (e.g. by a LanguageModel, causal model, or classifier).""" # If the first segment is a valid attribute, return the path segments @@ -45,7 +47,7 @@ def load_sparsify_sparse_coders( hookpoints: list[str], device: str | torch.device, compile: bool = False, -) -> dict[str, Sae]: +) -> dict[str, SparseCoder]: """ Load sparsify sparse coders for specified hookpoints. @@ -67,7 +69,7 @@ def load_sparsify_sparse_coders( name_path = Path(name) if name_path.exists(): for hookpoint in hookpoints: - sparse_model_dict[hookpoint] = Sae.load_from_disk( + sparse_model_dict[hookpoint] = SparseCoder.load_from_disk( name_path / hookpoint, device=device ) if compile: @@ -76,7 +78,7 @@ def load_sparsify_sparse_coders( ) else: # Load on CPU first to not run out of memory - sparse_models = Sae.load_many(name, device="cpu") + sparse_models = SparseCoder.load_many(name, device="cpu") for hookpoint in hookpoints: sparse_model_dict[hookpoint] = sparse_models[hookpoint].to(device) if compile: diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py index ad55901d..fdd5769a 100644 --- a/delphi/sparse_coders/sparse_model.py +++ b/delphi/sparse_coders/sparse_model.py @@ -2,6 +2,7 @@ import torch import torch.nn as nn +from sparsify import SparseCoder from transformers import PreTrainedModel from delphi.config import RunConfig @@ -74,7 +75,7 @@ def load_sparse_coders( run_cfg: RunConfig, device: str | torch.device, compile: bool = False, -) -> dict[str, nn.Module]: +) -> dict[str, nn.Module] | dict[str, SparseCoder]: """ Load sparse coders for specified hookpoints. 
diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py
index a984bb5b..421686fe 100644
--- a/delphi/tests/e2e.py
+++ b/delphi/tests/e2e.py
@@ -3,6 +3,7 @@
 from pathlib import Path

 import torch
+from pandas import DataFrame

 from delphi.__main__ import RunConfig, run
 from delphi.config import CacheConfig, ConstructorConfig, SamplerConfig
@@ -60,7 +61,7 @@ async def test():
     scores_path = Path("results") / run_cfg.name / "scores"
     df = build_scores_df(scores_path, run_cfg.hookpoints)
     for score_type in df["score_type"].unique():
-        score_df = df[df["score_type"] == score_type]
+        score_df: DataFrame = df[df["score_type"] == score_type]
         weighted_mean_metrics = latent_balanced_score_metrics(
             score_df, score_type, verbose=False
         )
diff --git a/delphi/utils.py b/delphi/utils.py
index ac0f88db..2b278cb9 100644
--- a/delphi/utils.py
+++ b/delphi/utils.py
@@ -1,3 +1,5 @@
+from typing import Any, Type, TypeVar, cast
+
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -29,3 +31,14 @@ def load_tokenized_data(
     tokens = tokens_ds["input_ids"]

     return tokens
+
+
+T = TypeVar("T")
+
+
+def assert_type(typ: Type[T], obj: Any) -> T:
+    """Assert that an object is of a given type at runtime and return it."""
+    if not isinstance(obj, typ):
+        raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
+
+    return cast(typ, obj)

From 2d4f743e746ade874c2d553db5477c180768acd0 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Wed, 5 Mar 2025 00:34:56 +0000
Subject: [PATCH 3/6] fix more types

---
 delphi/__main__.py        | 5 +++++
 delphi/clients/client.py  | 4 ++--
 delphi/clients/offline.py | 8 ++++----
 delphi/tests/e2e.py       | 8 ++++----
 pyproject.toml            | 3 ++-
 5 files changed, 17 insertions(+), 11 deletions(-)

diff --git a/delphi/__main__.py b/delphi/__main__.py
index 19d9c7c3..bc2974f3 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -90,6 +90,11 @@ def create_neighbours(
         neighbour_calculator = NeighbourCalculator(
             autoencoder=saes[hookpoint].cuda(), number_of_neighbours=100
         )
+    else:
+        raise ValueError(
+            f"Neighbour type {constructor_cfg.neighbours_type} not supported"
+        )
+
     neighbour_calculator.populate_neighbour_cache(constructor_cfg.neighbours_type)
     neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}")
diff --git a/delphi/clients/client.py b/delphi/clients/client.py
index e7afc37d..a9551fcc 100644
--- a/delphi/clients/client.py
+++ b/delphi/clients/client.py
@@ -6,8 +6,8 @@
 @dataclass
 class Response:
     text: str
-    logprobs: list[float] = None
-    prompt_logprobs: list[float] = None
+    logprobs: list[float] | None = None
+    prompt_logprobs: list[float] | None = None


 class Client(ABC):
diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index cd70bd48..a0581a01 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -91,8 +91,8 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
                 self.sampling_params.temperature = kwarg["temperature"]
         loop = asyncio.get_running_loop()
         prompts = []
-        if self.statistics:
-            statistics = []
+        statistics = []
+
         for batch in batches:
             prompt = self.tokenizer.apply_chat_template(
                 batch, add_generation_prompt=True, tokenize=True
@@ -101,7 +101,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
             if self.statistics:
                 non_cached_tokens = len(
                     self.tokenizer.apply_chat_template(
-                        batch[-1:], add_generation_prompt=True, tokenize=True
+                        batch[-1:], add_generation_prompt=True, tokenize=True  # type: ignore
                     )
                 )
                 statistics.append(
@@ -114,7 +114,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
         response = await loop.run_in_executor(
             None,
             partial(
-                self.client.generate,
+                self.client.generate,  # type: ignore
                 prompt_token_ids=prompts,
                 sampling_params=self.sampling_params,
                 use_tqdm=False,
diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py
index 421686fe..073ef3d0 100644
--- a/delphi/tests/e2e.py
+++ b/delphi/tests/e2e.py
@@ -3,10 +3,9 @@
 from pathlib import Path

 import torch
-from pandas import DataFrame

-from delphi.__main__ import RunConfig, run
-from delphi.config import CacheConfig, ConstructorConfig, SamplerConfig
+from delphi.__main__ import run
+from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
 from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics
@@ -61,7 +60,8 @@ async def test():
     scores_path = Path("results") / run_cfg.name / "scores"
     df = build_scores_df(scores_path, run_cfg.hookpoints)
     for score_type in df["score_type"].unique():
-        score_df: DataFrame = df[df["score_type"] == score_type]
+        score_df = df.query(f"score_type == '{score_type}'")
+
         weighted_mean_metrics = latent_balanced_score_metrics(
             score_df, score_type, verbose=False
         )
diff --git a/pyproject.toml b/pyproject.toml
index 4c67bcb7..45a35294 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,8 @@ dependencies = [
 dev = ["pytest"]
 visualize = [
     "kaleido==0.2.1",
-    "plotly>=5.0.0rc2"
+    "plotly>=5.0.0rc2",
+    "pandas"
 ]

 [tool.pyright]

From 2a3ffa896d326223d123a49d89d77e2e26e9e636 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Wed, 5 Mar 2025 00:50:45 +0000
Subject: [PATCH 4/6] fix types

---
 .github/workflows/tests.yml                 |  2 +-
 delphi/__main__.py                          |  7 +++++--
 delphi/clients/offline.py                   |  7 +++----
 delphi/clients/openrouter.py                |  2 +-
 delphi/explainers/single_token_explainer.py | 15 +++++++++++++--
 delphi/latents/cache.py                     |  1 -
 delphi/latents/constructors.py              |  3 +++
 delphi/latents/latents.py                   |  2 +-
 delphi/latents/neighbours.py                |  3 ++-
 delphi/pipeline.py                          | 16 ++++++++--------
 10 files changed, 37 insertions(+), 21 deletions(-)

diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index c8deb68a..bfc1c435 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,7 +23,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install ".[dev]"
+          pip install ".[dev,visualize]"

       - name: Run tests
         run: pytest
diff --git a/delphi/__main__.py b/delphi/__main__.py
index bc2974f3..85474114 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -70,8 +70,11 @@ def create_neighbours(
     neighbours_path.mkdir(parents=True, exist_ok=True)

     constructor_cfg = run_cfg.constructor_cfg
-    if constructor_cfg.neighbours_type != "co-occurrence":
-        saes = load_sparse_coders(run_cfg, device="cpu")
+    saes = (
+        load_sparse_coders(run_cfg, device="cpu")
+        if constructor_cfg.neighbours_type != "co-occurrence"
+        else {}
+    )

     for hookpoint in hookpoints:
diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index a0581a01..2c4d078a 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -127,10 +127,10 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
             if self.statistics:
                 statistics[i].num_generated_tokens = len(r.outputs[0].token_ids)
                 # save the statistics to a file, name is a hash of the prompt
-                statistics[i].prompt = batches[i][-1]["content"]
+                statistics[i].prompt = batches[i][-1]["content"]  # type: ignore
                 statistics[i].response = r.outputs[0].text
                 with open(
- f"statistics/{hash(batches[i][-1]['content'][-100:])}.json", "w" + f"statistics/{hash(batches[i][-1]['content'][-100:])}.json", "w" # type: ignore ) as f: json.dump(statistics[i].__dict__, f, indent=4) new_response.append( @@ -142,7 +142,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs): ) return new_response - async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str: + async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str: # type: ignore """ Enqueue a request and wait for the result. """ @@ -150,7 +150,6 @@ async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> if self.task is None: self.task = asyncio.create_task(self._process_batches()) await self.queue.put((prompt, future, kwargs)) - # print(f"Current queue size: {self.queue.qsize()} prompts") return await future def _parse_logprobs(self, response): diff --git a/delphi/clients/openrouter.py b/delphi/clients/openrouter.py index acb80345..e74680cc 100644 --- a/delphi/clients/openrouter.py +++ b/delphi/clients/openrouter.py @@ -37,7 +37,7 @@ def postprocess(self, response): async def generate( self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs - ) -> Response: + ) -> Response: # type: ignore kwargs.pop("schema", None) max_tokens = kwargs.pop("max_tokens", 500) temperature = kwargs.pop("temperature", 1.0) diff --git a/delphi/explainers/single_token_explainer.py b/delphi/explainers/single_token_explainer.py index ffef2832..439f51bc 100644 --- a/delphi/explainers/single_token_explainer.py +++ b/delphi/explainers/single_token_explainer.py @@ -32,10 +32,21 @@ def _build_prompt(self, examples): highlighted_examples = [] for i, example in enumerate(examples): - highlighted_examples.append(self._highlight(i + 1, example)) + highlighted_examples.append( + self._highlight(example.str_tokens, example.activations.tolist()) + ) if self.activations: - highlighted_examples.append(self._join_activations(example)) + assert ( + example.normalized_activations is not None + ), "Normalized activations are required for activations in explainer" + highlighted_examples.append( + self._join_activations( + example.str_tokens, + example.activations.tolist(), + example.normalized_activations.tolist(), + ) + ) return build_single_token_prompt( examples=highlighted_examples, diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 800bd0a4..1530d983 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -279,7 +279,6 @@ def run(self, n_tokens: int, tokens: token_tensor_shape): print(f"Total tokens processed: {total_tokens:,}") self.cache.save() - del sae_latents def save(self, save_dir: Path, save_tokens: bool = True): """ diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 18f695ec..94e17fd1 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -189,6 +189,9 @@ def constructor( seed=seed, tokenizer=tokenizer, ) + else: + raise ValueError(f"Invalid non-activating source: {source_non_activating}") + record.not_active = non_activating_examples return record diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index b5eb6030..cf142f35 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -206,7 +206,7 @@ def display( Returns: str: The formatted string. 
""" - from IPython.core.display import HTML, display + from IPython.core.display import HTML, display # type: ignore def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str: """ diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index af8aaaca..397637e8 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -179,6 +179,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] n_tokens = int(idx_cantor.max().item()) token_batch_size = 20_000 + co_occurrence_matrix = None done = False while not done: try: @@ -197,7 +198,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] co_occurrence_matrix = torch.zeros( (n_latents, n_latents), dtype=torch.int32 ) - # co_occurrence_matrix = co_occurrence_matrix.cuda() for start, end in tqdm( zip(batch_boundaries[:-1], batch_boundaries[1:]) @@ -239,6 +239,7 @@ def compute_jaccard(cooc_matrix): return jaccard_matrix # Compute Jaccard similarity matrix + assert co_occurrence_matrix is not None, "Co-occurrence matrix is not computed" jaccard_matrix = compute_jaccard(co_occurrence_matrix) # get the indices of the top k neighbours for each feature diff --git a/delphi/pipeline.py b/delphi/pipeline.py index 0b3f64ed..b3ead562 100644 --- a/delphi/pipeline.py +++ b/delphi/pipeline.py @@ -81,7 +81,9 @@ def __init__(self, loader: AsyncIterable | Callable, *pipes: Pipe | Callable): loader (Callable): The loader to be executed first. *pipes (list[Pipe]): Pipes to be executed in the pipeline. """ - self.pipes = [loader] + list(pipes) + + self.loader = loader + self.pipes = pipes async def run(self, max_concurrent: int = 10) -> list[Any]: """ @@ -136,13 +138,11 @@ async def generate_items(self) -> AsyncIterable[Any]: Raises: TypeError: If the first pipe is neither an async iterable nor a callable. 
""" - first_pipe = self.pipes[0] - - if isinstance(first_pipe, AsyncIterable): - async for item in first_pipe: + if isinstance(self.loader, AsyncIterable): + async for item in self.loader: yield item - elif callable(first_pipe): - for item in first_pipe(): + elif callable(self.loader): + for item in self.loader(): yield item await asyncio.sleep(0) # Allow other coroutines to run else: @@ -164,6 +164,6 @@ async def process_item( """ async with semaphore: result = item - for pipe in self.pipes[1:]: + for pipe in self.pipes: result = await pipe(result) return result From a814d77233442f25f36e826c6c28dfd0c524807e Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Wed, 5 Mar 2025 01:03:09 +0000 Subject: [PATCH 5/6] add type ignores --- delphi/clients/openrouter.py | 4 ++-- delphi/log/result_analysis.py | 2 +- delphi/scorers/classifier/classifier.py | 8 +++---- delphi/scorers/classifier/detection.py | 4 ++-- delphi/scorers/classifier/fuzz.py | 6 ++--- delphi/scorers/embedding/embedding.py | 22 +++++++++---------- .../oai_autointerp/explanations/scoring.py | 12 +++++----- .../oai_autointerp/explanations/simulator.py | 1 + delphi/scorers/simulator/oai_simulator.py | 8 +++---- delphi/scorers/surprisal/surprisal.py | 8 +++---- 10 files changed, 38 insertions(+), 37 deletions(-) diff --git a/delphi/clients/openrouter.py b/delphi/clients/openrouter.py index e74680cc..73e5b864 100644 --- a/delphi/clients/openrouter.py +++ b/delphi/clients/openrouter.py @@ -20,7 +20,7 @@ class OpenRouter(Client): def __init__( self, model: str, - api_key: str = None, + api_key: str | None = None, base_url="https://openrouter.ai/api/v1/chat/completions", ): super().__init__(model) @@ -36,7 +36,7 @@ def postprocess(self, response): return Response(msg) async def generate( - self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs + self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs # type: ignore ) -> Response: # type: ignore kwargs.pop("schema", None) max_tokens = kwargs.pop("max_tokens", 500) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index ccab5327..ce52ce18 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -258,5 +258,5 @@ def log_results(scores_path: Path, visualize_path: Path, target_modules: list[st plot_line(df, visualize_path) for score_type in df["score_type"].unique(): - score_df = df[df["score_type"] == score_type] + score_df = df.query(f"score_type == '{score_type}'") latent_balanced_score_metrics(score_df, score_type) diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index e5f8f915..45cb6b94 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -41,10 +41,10 @@ def __init__( self.generation_kwargs = generation_kwargs self.log_prob = log_prob - async def __call__( - self, - record: LatentRecord, - ) -> ScorerResult: + async def __call__( # type: ignore + self, # type: ignore + record: LatentRecord, # type: ignore + ) -> ScorerResult: # type: ignore samples = self._prepare(record) random.shuffle(samples) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 9724ad46..bd78fbdf 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -42,7 +42,7 @@ def __init__( def prompt(self, examples: str, explanation: str) -> list[dict]: return prompt(examples, explanation) - def _prepare(self, record: LatentRecord) -> list[Sample]: + def _prepare(self, 
         """
         Prepare and shuffle a list of samples for classification.
         """
@@ -57,7 +57,7 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:

         samples.extend(
             examples_to_samples(
-                record.test,
+                record.test,  # type: ignore
             )
         )

diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py
index 31c1d8e2..667db798 100644
--- a/delphi/scorers/classifier/fuzz.py
+++ b/delphi/scorers/classifier/fuzz.py
@@ -61,13 +61,13 @@ def mean_n_activations_ceil(self, examples: list[ActivatingExample]):

         return ceil(avg)

-    def _prepare(self, record: LatentRecord) -> list[Sample]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:  # type: ignore
         """
         Prepare and shuffle a list of samples for classification.
         """
         assert len(record.test) > 0, "No test records found"

-        n_incorrect = self.mean_n_activations_ceil(record.test)
+        n_incorrect = self.mean_n_activations_ceil(record.test)  # type: ignore

         if len(record.not_active) > 0:
             samples = examples_to_samples(
@@ -81,7 +81,7 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:

         samples.extend(
             examples_to_samples(
-                record.test,
+                record.test,  # type: ignore
                 n_incorrect=0,
                 highlighted=True,
             )
diff --git a/delphi/scorers/embedding/embedding.py b/delphi/scorers/embedding/embedding.py
index 677f12a7..2de89874 100644
--- a/delphi/scorers/embedding/embedding.py
+++ b/delphi/scorers/embedding/embedding.py
@@ -42,22 +42,22 @@ def __init__(
         self.tokenizer = tokenizer
         self.generation_kwargs = generation_kwargs

-    async def __call__(
-        self,
-        record: LatentRecord,
-    ) -> list[EmbeddingOutput]:
+    async def __call__(  # type: ignore
+        self,  # type: ignore
+        record: LatentRecord,  # type: ignore
+    ) -> ScorerResult:  # type: ignore
         samples = self._prepare(record)

         random.shuffle(samples)
         results = self._query(
             record.explanation,
-            samples,
+            samples,  # type: ignore
         )

         return ScorerResult(record=record, score=results)

     def call_sync(self, record: LatentRecord) -> list[EmbeddingOutput]:
-        return asyncio.run(self.__call__(record))
+        return asyncio.run(self.__call__(record))  # type: ignore

     def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
         """
@@ -68,21 +68,21 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
             "tokenizer": self.tokenizer,
         }
         samples = examples_to_samples(
-            record.extra_examples,
+            record.extra_examples,  # type: ignore
             distance=-1,
-            **defaults,
+            **defaults,  # type: ignore
         )

         for i, examples in enumerate(record.test):
             samples.extend(
                 examples_to_samples(
-                    examples,
+                    examples,  # type: ignore
                     distance=i + 1,
-                    **defaults,
+                    **defaults,  # type: ignore
                 )
             )

-        return samples
+        return samples  # type: ignore

     def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutput]:
         explanation_string = (
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
index 73eca99c..aa2f6b8e 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
@@ -78,7 +78,7 @@ async def _simulate_and_score_sequence(
     scored_sequence_simulation = ScoredSequenceSimulation(
         distance=quantile,
         simulation=simulation,
-        true_activations=activations.activations.tolist(),
+        true_activations=activations.activations.tolist(),  # type: ignore
         ev_correlation_score=score_from_simulation(
             activations, simulation, correlation_score
         ),
@@ -135,14 +135,14 @@ def aggregate_scored_sequence_simulations(
         rsquared_score = 0
        absolute_dev_explained_score = 0

-    scored_sequence_simulations = [default(s) for s in scored_sequence_simulations]
+    scored_sequence_simulations = [default(s) for s in scored_sequence_simulations]  # type: ignore
     ev_correlation_score = fix_nan(ev_correlation_score)

     return ScoredSimulation(
         distance=distance,
         scored_sequence_simulations=scored_sequence_simulations,
-        ev_correlation_score=ev_correlation_score,
+        ev_correlation_score=ev_correlation_score,  # type: ignore
         rsquared_score=float(rsquared_score),
         absolute_dev_explained_score=float(absolute_dev_explained_score),
     )
@@ -164,7 +164,7 @@ async def simulate_and_score(
                 _simulate_and_score_sequence(
                     simulator, activation_record, quantile + 1
                 )
-                for activation_record in activation_quantile
+                for activation_record in activation_quantile  # type: ignore
             ]
         )
         for quantile, activation_quantile in enumerate(activation_records)
@@ -173,7 +173,7 @@ async def simulate_and_score(
     if len(non_activation_records) > 0:
         non_activating_scored_seq_simulations = await asyncio.gather(
             *[
-                _simulate_and_score_sequence(simulator, non_activation_record[0], -1)
+                _simulate_and_score_sequence(simulator, non_activation_record[0], -1)  # type: ignore
                 for non_activation_record in non_activation_records
             ]
         )
@@ -196,4 +196,4 @@ async def simulate_and_score(
     if len(non_activation_records) > 0:
         all_data = all_activated + non_activating_scored_seq_simulations
         values.append(aggregate_scored_sequence_simulations(all_data, 0))
-    return values
+    return values  # type: ignore
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
index 6e7d7a7b..b2ed1de8 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
@@ -345,6 +345,7 @@ def _format_record_for_logprob_free_simulation(
             token = END_OF_TEXT_TOKEN_REPLACEMENT
         # We use a weird unicode character here to make it easier to parse the response (can split on "༗\n").
         if include_activations:
+            assert normalized_activations is not None
             response += f"{token}\t{normalized_activations[i]}༗\n"
         else:
             response += f"{token}\t༗\n"
diff --git a/delphi/scorers/simulator/oai_simulator.py b/delphi/scorers/simulator/oai_simulator.py
index 53b69d6c..b3038315 100644
--- a/delphi/scorers/simulator/oai_simulator.py
+++ b/delphi/scorers/simulator/oai_simulator.py
@@ -37,9 +37,9 @@ async def __call__(self, record):
             record.explanation,
         )

-        valid_activation_records = self.to_activation_records(record.test)
+        valid_activation_records = self.to_activation_records(record.test)  # type: ignore
         if len(record.not_active) > 0:
-            non_activation_records = self.to_activation_records([record.not_active])
+            non_activation_records = self.to_activation_records([record.not_active])  # type: ignore
         else:
             non_activation_records = []
@@ -53,13 +53,13 @@ async def __call__(self, record):
     def to_activation_records(self, examples: list[Example]) -> list[ActivationRecord]:
-        return [
+        return [  # type: ignore
             [
                 ActivationRecord(
                     self.tokenizer.batch_decode(example.tokens),
                     example.normalized_activations.half(),
                 )
-                for example in quantiles
+                for example in quantiles  # type: ignore
             ]
             for quantiles in examples
         ]
diff --git a/delphi/scorers/surprisal/surprisal.py b/delphi/scorers/surprisal/surprisal.py
index 1231450e..ee92b1c1 100644
--- a/delphi/scorers/surprisal/surprisal.py
+++ b/delphi/scorers/surprisal/surprisal.py
@@ -55,10 +55,10 @@ def __init__(
         self.batch_size = batch_size
         self.generation_kwargs = generation_kwargs

-    async def __call__(
-        self,
-        record: LatentRecord,
-    ) -> ScorerResult:
+    async def __call__(  # type: ignore
+        self,  # type: ignore
+        record: LatentRecord,  # type: ignore
+    ) -> ScorerResult:  # type: ignore
         samples = self._prepare(record)

         random.shuffle(samples)

From a1a5bb48c24192fa0fde9d93881063779d33e588 Mon Sep 17 00:00:00 2001
From: Lucia Quirke
Date: Wed, 5 Mar 2025 01:11:45 +0000
Subject: [PATCH 6/6] add type ignores

---
 delphi/clients/openrouter.py                                    | 2 +-
 delphi/scorers/simulator/oai_autointerp/explanations/scoring.py | 2 ++
 .../scorers/simulator/oai_autointerp/explanations/simulator.py  | 2 ++
 delphi/scorers/simulator/oai_simulator.py                       | 2 +-
 4 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/delphi/clients/openrouter.py b/delphi/clients/openrouter.py
index 73e5b864..2a6e8b85 100644
--- a/delphi/clients/openrouter.py
+++ b/delphi/clients/openrouter.py
@@ -35,7 +35,7 @@ def postprocess(self, response):
         msg = response_json["choices"][0]["message"]["content"]
         return Response(msg)

-    async def generate(
+    async def generate(  # type: ignore
         self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs  # type: ignore
     ) -> Response:  # type: ignore
         kwargs.pop("schema", None)
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
index aa2f6b8e..dbdc2a16 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
@@ -177,6 +177,8 @@ async def simulate_and_score(
                 for non_activation_record in non_activation_records
             ]
         )
+    else:
+        non_activating_scored_seq_simulations = []

     # with open('test.txt', 'w') as f:
     #     f.write(str(scored_sequence_simulations))
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
index b2ed1de8..d8c056c4 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
@@ -331,6 +331,7 @@ def _format_record_for_logprob_free_simulation(
     max_activation: Optional[float] = None,
 ) -> str:
     response = ""
+    normalized_activations = None
     if include_activations:
         assert max_activation is not None
         assert len(activation_record.tokens) == len(
@@ -339,6 +340,7 @@ def _format_record_for_logprob_free_simulation(
         normalized_activations = normalize_activations(
             activation_record.activations, max_activation=max_activation
         )
+
     for i, token in enumerate(activation_record.tokens):
         # Edge Case #3: End tokens confuse the chat-based simulator. Replace end token with "not end token".
         if token.strip() == END_OF_TEXT_TOKEN:
diff --git a/delphi/scorers/simulator/oai_simulator.py b/delphi/scorers/simulator/oai_simulator.py
index b3038315..e9436956 100644
--- a/delphi/scorers/simulator/oai_simulator.py
+++ b/delphi/scorers/simulator/oai_simulator.py
@@ -25,7 +25,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.all_at_once = all_at_once

-    async def __call__(self, record):
+    async def __call__(self, record):  # type: ignore
         # Simulate and score the explanation.
         cls = (
             ExplanationNeuronSimulator