diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index daf7e650..bfc1c435 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -23,7 +23,12 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          pip install ".[dev]"
+          pip install ".[dev,visualize]"
       - name: Run tests
         run: pytest
+
+      - name: Type Checking
+        uses: jakebailey/pyright-action@v1
+        with:
+          version: 1.1.378
diff --git a/delphi/__main__.py b/delphi/__main__.py
index d4ea1259..85474114 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -26,7 +26,7 @@
 from delphi.pipeline import Pipe, Pipeline, process_wrapper
 from delphi.scorers import DetectionScorer, FuzzingScorer
 from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders
-from delphi.utils import load_tokenized_data
+from delphi.utils import assert_type, load_tokenized_data


 def load_artifacts(run_cfg: RunConfig):
@@ -70,8 +70,11 @@ def create_neighbours(
     neighbours_path.mkdir(parents=True, exist_ok=True)

     constructor_cfg = run_cfg.constructor_cfg
-    if constructor_cfg.neighbours_type != "co-occurrence":
-        saes = load_sparse_coders(run_cfg, device="cpu")
+    saes = (
+        load_sparse_coders(run_cfg, device="cpu")
+        if constructor_cfg.neighbours_type != "co-occurrence"
+        else {}
+    )

     for hookpoint in hookpoints:
@@ -90,6 +93,11 @@ def create_neighbours(
             neighbour_calculator = NeighbourCalculator(
                 autoencoder=saes[hookpoint].cuda(), number_of_neighbours=100
             )
+        else:
+            raise ValueError(
+                f"Neighbour type {constructor_cfg.neighbours_type} not supported"
+            )
+
         neighbour_calculator.populate_neighbour_cache(constructor_cfg.neighbours_type)
         neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}")
@@ -325,8 +333,11 @@ async def run(
     hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg)
     tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token)

-    nrh = non_redundant_hookpoints(
-        hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
+    nrh = assert_type(
+        dict,
+        non_redundant_hookpoints(
+            hookpoint_to_sparse_encode, latents_path, "cache" in run_cfg.overwrite
+        ),
     )
     if nrh:
         populate_cache(
@@ -340,8 +351,11 @@ async def run(
     del model, hookpoint_to_sparse_encode

     if run_cfg.constructor_cfg.non_activating_source == "neighbours":
-        nrh = non_redundant_hookpoints(
-            hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
+        nrh = assert_type(
+            list,
+            non_redundant_hookpoints(
+                hookpoints, neighbours_path, "neighbours" in run_cfg.overwrite
+            ),
         )
         if nrh:
             create_neighbours(
@@ -353,8 +367,11 @@ async def run(
     else:
         print("Skipping neighbour creation")

-    nrh = non_redundant_hookpoints(
-        hookpoints, scores_path, "scores" in run_cfg.overwrite
+    nrh = assert_type(
+        list,
+        non_redundant_hookpoints(
+            hookpoints, scores_path, "scores" in run_cfg.overwrite
+        ),
     )
     if nrh:
         await process_cache(
diff --git a/delphi/clients/client.py b/delphi/clients/client.py
index e7afc37d..a9551fcc 100644
--- a/delphi/clients/client.py
+++ b/delphi/clients/client.py
@@ -6,8 +6,8 @@
 @dataclass
 class Response:
     text: str
-    logprobs: list[float] = None
-    prompt_logprobs: list[float] = None
+    logprobs: list[float] | None = None
+    prompt_logprobs: list[float] | None = None


 class Client(ABC):
diff --git a/delphi/clients/offline.py b/delphi/clients/offline.py
index cd70bd48..2c4d078a 100644
--- a/delphi/clients/offline.py
+++ b/delphi/clients/offline.py
@@ -91,8 +91,8 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
         self.sampling_params.temperature = kwarg["temperature"]
         loop = asyncio.get_running_loop()
         prompts = []
-        if self.statistics:
-            statistics = []
+        statistics = []
+
         for batch in batches:
             prompt = self.tokenizer.apply_chat_template(
                 batch, add_generation_prompt=True, tokenize=True
@@ -101,7 +101,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
             if self.statistics:
                 non_cached_tokens = len(
                     self.tokenizer.apply_chat_template(
-                        batch[-1:], add_generation_prompt=True, tokenize=True
+                        batch[-1:], add_generation_prompt=True, tokenize=True  # type: ignore
                     )
                 )
                 statistics.append(
@@ -114,7 +114,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
         response = await loop.run_in_executor(
             None,
             partial(
-                self.client.generate,
+                self.client.generate,  # type: ignore
                 prompt_token_ids=prompts,
                 sampling_params=self.sampling_params,
                 use_tqdm=False,
@@ -127,10 +127,10 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
             if self.statistics:
                 statistics[i].num_generated_tokens = len(r.outputs[0].token_ids)
                 # save the statistics to a file, name is a hash of the prompt
-                statistics[i].prompt = batches[i][-1]["content"]
+                statistics[i].prompt = batches[i][-1]["content"]  # type: ignore
                 statistics[i].response = r.outputs[0].text
                 with open(
-                    f"statistics/{hash(batches[i][-1]['content'][-100:])}.json", "w"
+                    f"statistics/{hash(batches[i][-1]['content'][-100:])}.json", "w"  # type: ignore
                 ) as f:
                     json.dump(statistics[i].__dict__, f, indent=4)
             new_response.append(
@@ -142,7 +142,7 @@ async def process_func(self, batches: Union[str, list[dict[str, str]]], kwargs):
             )
         return new_response

-    async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str:
+    async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) -> str:  # type: ignore
         """
         Enqueue a request and wait for the result.
         """
@@ -150,7 +150,6 @@ async def generate(self, prompt: Union[str, list[dict[str, str]]], **kwargs) ->
         if self.task is None:
             self.task = asyncio.create_task(self._process_batches())
         await self.queue.put((prompt, future, kwargs))
-        # print(f"Current queue size: {self.queue.qsize()} prompts")
         return await future

     def _parse_logprobs(self, response):
diff --git a/delphi/clients/openrouter.py b/delphi/clients/openrouter.py
index acb80345..2a6e8b85 100644
--- a/delphi/clients/openrouter.py
+++ b/delphi/clients/openrouter.py
@@ -20,7 +20,7 @@ class OpenRouter(Client):
     def __init__(
         self,
         model: str,
-        api_key: str = None,
+        api_key: str | None = None,
         base_url="https://openrouter.ai/api/v1/chat/completions",
     ):
         super().__init__(model)
@@ -35,9 +35,9 @@ def postprocess(self, response):
         msg = response_json["choices"][0]["message"]["content"]
         return Response(msg)

-    async def generate(
-        self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs
-    ) -> Response:
+    async def generate(  # type: ignore
+        self, prompt: str, raw: bool = False, max_retries: int = 1, **kwargs  # type: ignore
+    ) -> Response:  # type: ignore
         kwargs.pop("schema", None)
         max_tokens = kwargs.pop("max_tokens", 500)
         temperature = kwargs.pop("temperature", 1.0)
diff --git a/delphi/explainers/single_token_explainer.py b/delphi/explainers/single_token_explainer.py
index ffef2832..439f51bc 100644
--- a/delphi/explainers/single_token_explainer.py
+++ b/delphi/explainers/single_token_explainer.py
@@ -32,10 +32,21 @@ def _build_prompt(self, examples):
         highlighted_examples = []

         for i, example in enumerate(examples):
-            highlighted_examples.append(self._highlight(i + 1, example))
+            highlighted_examples.append(
+                self._highlight(example.str_tokens, example.activations.tolist())
+            )

             if self.activations:
-                highlighted_examples.append(self._join_activations(example))
+                assert (
+                    example.normalized_activations is not None
+                ), "Normalized activations are required for activations in explainer"
+                highlighted_examples.append(
+                    self._join_activations(
+                        example.str_tokens,
+                        example.activations.tolist(),
+                        example.normalized_activations.tolist(),
+                    )
+                )

         return build_single_token_prompt(
             examples=highlighted_examples,
diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py
index 800bd0a4..1530d983 100644
--- a/delphi/latents/cache.py
+++ b/delphi/latents/cache.py
@@ -279,7 +279,6 @@ def run(self, n_tokens: int, tokens: token_tensor_shape):
         print(f"Total tokens processed: {total_tokens:,}")
         self.cache.save()
-        del sae_latents

     def save(self, save_dir: Path, save_tokens: bool = True):
         """
diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py
index 18f695ec..94e17fd1 100644
--- a/delphi/latents/constructors.py
+++ b/delphi/latents/constructors.py
@@ -189,6 +189,9 @@ def constructor(
             seed=seed,
             tokenizer=tokenizer,
         )
+    else:
+        raise ValueError(f"Invalid non-activating source: {source_non_activating}")
+
     record.not_active = non_activating_examples
     return record
diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index a68f146b..cf142f35 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -134,7 +134,7 @@ class LatentRecord:
     train: list[ActivatingExample] = field(default_factory=list)
     """Training examples."""

-    test: list[ActivatingExample] = field(default_factory=list)
+    test: list[ActivatingExample] | list[list[Example]] = field(default_factory=list)
     """Test examples."""

     neighbours: list[Neighbour] = field(default_factory=list)
@@ -143,6 +143,9 @@
     explanation: str = ""
     """Explanation of the latent."""

+    extra_examples: Optional[list[Example]] = None
+    """Extra examples to include in the record."""
+
     @property
     def max_activation(self) -> float:
         """
@@ -203,7 +206,7 @@ def display(
         Returns:
             str: The formatted string.
         """
-        from IPython.core.display import HTML, display
+        from IPython.core.display import HTML, display  # type: ignore

         def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
             """
diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py
index af8aaaca..397637e8 100644
--- a/delphi/latents/neighbours.py
+++ b/delphi/latents/neighbours.py
@@ -179,6 +179,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]]
         n_tokens = int(idx_cantor.max().item())

         token_batch_size = 20_000
+        co_occurrence_matrix = None
         done = False
         while not done:
             try:
@@ -197,7 +198,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]]
                 co_occurrence_matrix = torch.zeros(
                     (n_latents, n_latents), dtype=torch.int32
                 )
-                # co_occurrence_matrix = co_occurrence_matrix.cuda()

                 for start, end in tqdm(
                     zip(batch_boundaries[:-1], batch_boundaries[1:])
@@ -239,6 +239,7 @@ def compute_jaccard(cooc_matrix):
             return jaccard_matrix

         # Compute Jaccard similarity matrix
+        assert co_occurrence_matrix is not None, "Co-occurrence matrix is not computed"
         jaccard_matrix = compute_jaccard(co_occurrence_matrix)

         # get the indices of the top k neighbours for each feature
diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py
index ccab5327..ce52ce18 100644
--- a/delphi/log/result_analysis.py
+++ b/delphi/log/result_analysis.py
@@ -258,5 +258,5 @@ def log_results(scores_path: Path, visualize_path: Path, target_modules: list[st
     plot_line(df, visualize_path)

     for score_type in df["score_type"].unique():
-        score_df = df[df["score_type"] == score_type]
+        score_df = df.query(f"score_type == '{score_type}'")
         latent_balanced_score_metrics(score_df, score_type)
diff --git a/delphi/pipeline.py b/delphi/pipeline.py
index 0b3f64ed..b3ead562 100644
--- a/delphi/pipeline.py
+++ b/delphi/pipeline.py
@@ -81,7 +81,9 @@ def __init__(self, loader: AsyncIterable | Callable, *pipes: Pipe | Callable):
             loader (Callable): The loader to be executed first.
             *pipes (list[Pipe]): Pipes to be executed in the pipeline.
         """
-        self.pipes = [loader] + list(pipes)
+
+        self.loader = loader
+        self.pipes = pipes

     async def run(self, max_concurrent: int = 10) -> list[Any]:
         """
@@ -136,13 +138,11 @@ async def generate_items(self) -> AsyncIterable[Any]:
         Raises:
             TypeError: If the first pipe is neither an async iterable nor a callable.
         """
-        first_pipe = self.pipes[0]
-
-        if isinstance(first_pipe, AsyncIterable):
-            async for item in first_pipe:
+        if isinstance(self.loader, AsyncIterable):
+            async for item in self.loader:
                 yield item
-        elif callable(first_pipe):
-            for item in first_pipe():
+        elif callable(self.loader):
+            for item in self.loader():
                 yield item
                 await asyncio.sleep(0)  # Allow other coroutines to run
         else:
@@ -164,6 +164,6 @@ async def process_item(
         """
         async with semaphore:
             result = item
-            for pipe in self.pipes[1:]:
+            for pipe in self.pipes:
                 result = await pipe(result)
             return result
diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py
index e5f8f915..45cb6b94 100644
--- a/delphi/scorers/classifier/classifier.py
+++ b/delphi/scorers/classifier/classifier.py
@@ -41,10 +41,10 @@ def __init__(
         self.generation_kwargs = generation_kwargs
         self.log_prob = log_prob

-    async def __call__(
-        self,
-        record: LatentRecord,
-    ) -> ScorerResult:
+    async def __call__(  # type: ignore
+        self,  # type: ignore
+        record: LatentRecord,  # type: ignore
+    ) -> ScorerResult:  # type: ignore
         samples = self._prepare(record)

         random.shuffle(samples)
diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py
index 9724ad46..bd78fbdf 100644
--- a/delphi/scorers/classifier/detection.py
+++ b/delphi/scorers/classifier/detection.py
@@ -42,7 +42,7 @@ def __init__(
     def prompt(self, examples: str, explanation: str) -> list[dict]:
         return prompt(examples, explanation)

-    def _prepare(self, record: LatentRecord) -> list[Sample]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:  # type: ignore
         """
         Prepare and shuffle a list of samples for classification.
         """
@@ -57,7 +57,7 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:

         samples.extend(
             examples_to_samples(
-                record.test,
+                record.test,  # type: ignore
             )
         )

diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py
index 31c1d8e2..667db798 100644
--- a/delphi/scorers/classifier/fuzz.py
+++ b/delphi/scorers/classifier/fuzz.py
@@ -61,13 +61,13 @@ def mean_n_activations_ceil(self, examples: list[ActivatingExample]):

         return ceil(avg)

-    def _prepare(self, record: LatentRecord) -> list[Sample]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:  # type: ignore
         """
         Prepare and shuffle a list of samples for classification.
         """
         assert len(record.test) > 0, "No test records found"
-        n_incorrect = self.mean_n_activations_ceil(record.test)
+        n_incorrect = self.mean_n_activations_ceil(record.test)  # type: ignore

         if len(record.not_active) > 0:
             samples = examples_to_samples(
@@ -81,7 +81,7 @@ def _prepare(self, record: LatentRecord) -> list[Sample]:

         samples.extend(
             examples_to_samples(
-                record.test,
+                record.test,  # type: ignore
                 n_incorrect=0,
                 highlighted=True,
             )
diff --git a/delphi/scorers/embedding/embedding.py b/delphi/scorers/embedding/embedding.py
index 677f12a7..2de89874 100644
--- a/delphi/scorers/embedding/embedding.py
+++ b/delphi/scorers/embedding/embedding.py
@@ -42,22 +42,22 @@ def __init__(
         self.tokenizer = tokenizer
         self.generation_kwargs = generation_kwargs

-    async def __call__(
-        self,
-        record: LatentRecord,
-    ) -> list[EmbeddingOutput]:
+    async def __call__(  # type: ignore
+        self,  # type: ignore
+        record: LatentRecord,  # type: ignore
+    ) -> ScorerResult:  # type: ignore
         samples = self._prepare(record)
         random.shuffle(samples)
         results = self._query(
             record.explanation,
-            samples,
+            samples,  # type: ignore
         )

         return ScorerResult(record=record, score=results)

     def call_sync(self, record: LatentRecord) -> list[EmbeddingOutput]:
-        return asyncio.run(self.__call__(record))
+        return asyncio.run(self.__call__(record))  # type: ignore

     def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
         """
@@ -68,21 +68,21 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
             "tokenizer": self.tokenizer,
         }
         samples = examples_to_samples(
-            record.extra_examples,
+            record.extra_examples,  # type: ignore
             distance=-1,
-            **defaults,
+            **defaults,  # type: ignore
         )

         for i, examples in enumerate(record.test):
             samples.extend(
                 examples_to_samples(
-                    examples,
+                    examples,  # type: ignore
                     distance=i + 1,
-                    **defaults,
+                    **defaults,  # type: ignore
                 )
             )

-        return samples
+        return samples  # type: ignore

     def _query(self, explanation: str, samples: list[Sample]) -> list[EmbeddingOutput]:
         explanation_string = (
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
index 73eca99c..dbdc2a16 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/scoring.py
@@ -78,7 +78,7 @@ async def _simulate_and_score_sequence(
     scored_sequence_simulation = ScoredSequenceSimulation(
         distance=quantile,
         simulation=simulation,
-        true_activations=activations.activations.tolist(),
+        true_activations=activations.activations.tolist(),  # type: ignore
         ev_correlation_score=score_from_simulation(
             activations, simulation, correlation_score
         ),
@@ -135,14 +135,14 @@ def aggregate_scored_sequence_simulations(
         rsquared_score = 0
         absolute_dev_explained_score = 0

-    scored_sequence_simulations = [default(s) for s in scored_sequence_simulations]
+    scored_sequence_simulations = [default(s) for s in scored_sequence_simulations]  # type: ignore
     ev_correlation_score = fix_nan(ev_correlation_score)

     return ScoredSimulation(
         distance=distance,
         scored_sequence_simulations=scored_sequence_simulations,
-        ev_correlation_score=ev_correlation_score,
+        ev_correlation_score=ev_correlation_score,  # type: ignore
         rsquared_score=float(rsquared_score),
         absolute_dev_explained_score=float(absolute_dev_explained_score),
     )
@@ -164,7 +164,7 @@ async def simulate_and_score(
                 _simulate_and_score_sequence(
                     simulator, activation_record, quantile + 1
                 )
-                for activation_record in activation_quantile
+                for activation_record in activation_quantile  # type: ignore
             ]
         )
         for quantile, activation_quantile in enumerate(activation_records)
@@ -173,10 +173,12 @@ async def simulate_and_score(
     if len(non_activation_records) > 0:
         non_activating_scored_seq_simulations = await asyncio.gather(
             *[
-                _simulate_and_score_sequence(simulator, non_activation_record[0], -1)
+                _simulate_and_score_sequence(simulator, non_activation_record[0], -1)  # type: ignore
                 for non_activation_record in non_activation_records
             ]
         )
+    else:
+        non_activating_scored_seq_simulations = []

     # with open('test.txt', 'w') as f:
     #     f.write(str(scored_sequence_simulations))
@@ -196,4 +198,4 @@ async def simulate_and_score(
     if len(non_activation_records) > 0:
         all_data = all_activated + non_activating_scored_seq_simulations
         values.append(aggregate_scored_sequence_simulations(all_data, 0))
-    return values
+    return values  # type: ignore
diff --git a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
index eb29e0dd..d8c056c4 100644
--- a/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
+++ b/delphi/scorers/simulator/oai_autointerp/explanations/simulator.py
@@ -107,8 +107,8 @@ def parse_top_logprobs(top_logprobs: dict[str, float]) -> OrderedDict[int, float
     """
     probabilities_by_distribution_value = OrderedDict()
     for token, contents in top_logprobs.items():
-        logprob = contents.logprob
-        decoded_token = contents.decoded_token
+        logprob = contents.logprob  # type: ignore
+        decoded_token = contents.decoded_token  # type: ignore
         if decoded_token in VALID_ACTIVATION_TOKENS:
             token_as_int = int(decoded_token)
             probabilities_by_distribution_value[token_as_int] = np.exp(logprob)
@@ -134,7 +134,7 @@ def compute_predicted_activation_stats_for_token(


 def parse_simulation_response(
-    response: dict[str, Any],
+    response: Any,
     tokenized_prompt: list[int],
     tab_token: int,
     tokens: Sequence[str],
@@ -250,11 +250,11 @@ async def simulate(
         else:
             assert isinstance(prompt, str)

-        response = await self.client.generate(prompt, **sampling_params)
-        tokenized_prompt = self.client.tokenizer.apply_chat_template(
+        response = await self.client.generate(prompt, **sampling_params)  # type: ignore
+        tokenized_prompt = self.client.tokenizer.apply_chat_template(  # type: ignore
             prompt, add_generation_prompt=True
         )
-        tab_token = self.client.tokenizer.encode("\t")[1]
+        tab_token = self.client.tokenizer.encode("\t")[1]  # type: ignore
         logger.debug("response in score_explanation_by_activations is %s", response)
         try:
             result = parse_simulation_response(
@@ -287,7 +287,7 @@ def make_simulation_prompt(
         # Consider reconciling them.
         prompt_builder = PromptBuilder()
         prompt_builder.add_message(
-            "system",
+            "system",  # type: ignore
             """We're studying neurons in a neural network.
 Each neuron looks for some particular thing in a short document.
 Look at summary of what the neuron does, and try to predict how it will fire on each token.
@@ -299,7 +299,7 @@ def make_simulation_prompt(
         few_shot_examples = self.few_shot_example_set.get_examples()
         for i, example in enumerate(few_shot_examples):
             prompt_builder.add_message(
-                "user",
+                "user",  # type: ignore
                 f"\n\nNeuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX}"
                 f"{example.explanation}",
             )
@@ -309,17 +309,17 @@ def make_simulation_prompt(
                 start_indices=example.first_revealed_activation_indices,
             )
             prompt_builder.add_message(
-                "assistant", f"\nActivations: {formatted_activation_records}\n"
+                "assistant", f"\nActivations: {formatted_activation_records}\n"  # type: ignore
             )

         prompt_builder.add_message(
-            "user",
+            "user",  # type: ignore
             f"\n\nNeuron {len(few_shot_examples) + 1}\nExplanation of neuron "
             f"{len(few_shot_examples) + 1} behavior: {EXPLANATION_PREFIX} "
             f"{self.explanation.strip()}",
         )
         prompt_builder.add_message(
-            "assistant",
+            "assistant",  # type: ignore
             f"\nActivations: {format_sequences_for_simulation([tokens])}",
         )
         return prompt_builder.build(self.prompt_format)
@@ -331,6 +331,7 @@ def _format_record_for_logprob_free_simulation(
     max_activation: Optional[float] = None,
 ) -> str:
     response = ""
+    normalized_activations = None
     if include_activations:
         assert max_activation is not None
         assert len(activation_record.tokens) == len(
@@ -339,12 +340,14 @@ def _format_record_for_logprob_free_simulation(
         normalized_activations = normalize_activations(
             activation_record.activations, max_activation=max_activation
         )
+
     for i, token in enumerate(activation_record.tokens):
         # Edge Case #3: End tokens confuse the chat-based simulator. Replace end token with "not end token".
         if token.strip() == END_OF_TEXT_TOKEN:
             token = END_OF_TEXT_TOKEN_REPLACEMENT
         # We use a weird unicode character here to make it easier to parse the response (can split on "༗\n").
         if include_activations:
+            assert normalized_activations is not None
             response += f"{token}\t{normalized_activations[i]}༗\n"
         else:
             response += f"{token}\t༗\n"
@@ -595,7 +598,7 @@ async def simulate(self, tokens: Sequence[str]) -> SequenceSimulation:

         result = SequenceSimulation(
             activation_scale=ActivationScale.SIMULATED_NORMALIZED_ACTIVATIONS,
-            expected_activations=predicted_activations,
+            expected_activations=predicted_activations,  # type: ignore
             # Since the predicted activation is just a sampled token, we don't have a distribution.
             distribution_values=[],
             distribution_probabilities=[],
@@ -614,7 +617,7 @@ def _make_simulation_prompt_json(
         assert explanation != ""
         prompt_builder = PromptBuilder()
         prompt_builder.add_message(
-            "system",
+            "system",  # type: ignore
             """We're studying neurons in a neural network. Each neuron looks for certain things in a short document. Your task is to read the explanation of what the neuron does, and predict the neuron's activations for each token in the document.
 For each document, you will see the full text of the document, then the tokens in the document with the activation left blank. You will print, in valid json, the exact same tokens verbatim, but with the activation values filled in according to the explanation.
 Pay special attention to the explanation's description of the context and order of tokens or words.
@@ -638,7 +641,7 @@ def _make_simulation_prompt_json(
 }
 """
             prompt_builder.add_message(
-                "user",
+                "user",  # type: ignore
                 _format_record_for_logprob_free_simulation_json(
                     explanation=example.explanation,
                     activation_record=example.activation_records[0],
@@ -658,7 +661,7 @@ def _make_simulation_prompt_json(
 }
 """
             prompt_builder.add_message(
-                "assistant",
+                "assistant",  # type: ignore
                 _format_record_for_logprob_free_simulation_json(
                     explanation=example.explanation,
                     activation_record=example.activation_records[0],
@@ -678,10 +681,10 @@ def _make_simulation_prompt_json(
 }
 """
         prompt_builder.add_message(
-            "user",
+            "user",  # type: ignore
             _format_record_for_logprob_free_simulation_json(
                 explanation=explanation,
-                activation_record=ActivationRecord(tokens=tokens, activations=[]),
+                activation_record=ActivationRecord(tokens=tokens, activations=[]),  # type: ignore
                 include_activations=False,
             ),
         )
@@ -698,7 +701,7 @@ def _make_simulation_prompt(
         assert explanation != ""
         prompt_builder = PromptBuilder()
         prompt_builder.add_message(
-            "system",
+            "system",  # type: ignore
             """We're studying neurons in a neural network. Each neuron looks for some particular thing in a short document. Look at an explanation of what the neuron does, and try to predict its activations on a particular token.

 The activation format is tokenactivation, and activations range from 0 to 10. Most activations will be 0.
@@ -716,7 +719,7 @@ def _make_simulation_prompt(
                 example.activation_records[0], include_activations=False
             )
             prompt_builder.add_message(
-                "user",
+                "user",  # type: ignore
                 f"Neuron {i + 1}\nExplanation of neuron {i + 1} behavior: {EXPLANATION_PREFIX} "
                 f"{example.explanation}\n\n"
                 f"Sequence 1 Tokens without Activations:\n{tokens_without_activations}\n\n"
@@ -728,7 +731,7 @@ def _make_simulation_prompt(
                 max_activation=few_shot_example_max_activation,
             )
             prompt_builder.add_message(
-                "assistant",
+                "assistant",  # type: ignore
                 f"{tokens_with_activations}\n\n",
             )

@@ -737,7 +740,7 @@ def _make_simulation_prompt(
                     record, include_activations=False
                 )
                 prompt_builder.add_message(
-                    "user",
+                    "user",  # type: ignore
                     f"Sequence {record_index + 2} Tokens without Activations:\n{tks_without}\n\n"
                     f"Sequence {record_index + 2} Tokens with Activations:\n",
                 )
@@ -747,16 +750,16 @@ def _make_simulation_prompt(
                     max_activation=few_shot_example_max_activation,
                 )
                 prompt_builder.add_message(
-                    "assistant",
+                    "assistant",  # type: ignore
                     f"{tokens_with_activations}\n\n",
                 )

         neuron_index = len(few_shot_examples) + 1

         tokens_without_activations = _format_record_for_logprob_free_simulation(
-            ActivationRecord(tokens=tokens, activations=[]), include_activations=False
+            ActivationRecord(tokens=tokens, activations=[]), include_activations=False  # type: ignore
         )
         prompt_builder.add_message(
-            "user",
+            "user",  # type: ignore
             f"Neuron {neuron_index}\nExplanation of neuron {neuron_index} behavior: {EXPLANATION_PREFIX} "
             f"{explanation}\n\n"
             f"Sequence 1 Tokens without Activations:\n{tokens_without_activations}\n\n"
diff --git a/delphi/scorers/simulator/oai_simulator.py b/delphi/scorers/simulator/oai_simulator.py
index 53b69d6c..e9436956 100644
--- a/delphi/scorers/simulator/oai_simulator.py
+++ b/delphi/scorers/simulator/oai_simulator.py
@@ -25,7 +25,7 @@ def __init__(
         self.tokenizer = tokenizer
         self.all_at_once = all_at_once

-    async def __call__(self, record):
+    async def __call__(self, record):  # type: ignore
         # Simulate and score the explanation.
         cls = (
             ExplanationNeuronSimulator
@@ -37,9 +37,9 @@ async def __call__(self, record):
             record.explanation,
         )

-        valid_activation_records = self.to_activation_records(record.test)
+        valid_activation_records = self.to_activation_records(record.test)  # type: ignore
         if len(record.not_active) > 0:
-            non_activation_records = self.to_activation_records([record.not_active])
+            non_activation_records = self.to_activation_records([record.not_active])  # type: ignore
         else:
             non_activation_records = []
@@ -53,13 +53,13 @@ async def __call__(self, record):
         )

     def to_activation_records(self, examples: list[Example]) -> list[ActivationRecord]:
-        return [
+        return [  # type: ignore
             [
                 ActivationRecord(
                     self.tokenizer.batch_decode(example.tokens),
                     example.normalized_activations.half(),
                 )
-                for example in quantiles
+                for example in quantiles  # type: ignore
             ]
             for quantiles in examples
         ]
diff --git a/delphi/scorers/surprisal/surprisal.py b/delphi/scorers/surprisal/surprisal.py
index da0ccfeb..ee92b1c1 100644
--- a/delphi/scorers/surprisal/surprisal.py
+++ b/delphi/scorers/surprisal/surprisal.py
@@ -3,10 +3,13 @@
 from typing import NamedTuple

 import torch
+from simple_parsing import field
 from torch.nn.functional import cross_entropy
 from transformers import PreTrainedTokenizer

-from ...latents import Example, LatentRecord
+from delphi.utils import assert_type
+
+from ...latents import ActivatingExample, Example, LatentRecord
 from ..scorer import Scorer, ScorerResult
 from .prompts import BASEPROMPT as base_prompt
@@ -19,13 +22,13 @@ class SurprisalOutput:
     distance: float | int
     """Quantile or neighbor distance"""

-    no_explanation: list[float] = 0
+    no_explanation: list[float] = field(default_factory=list)
     """What is the surprisal of the model with no explanation"""

-    explanation: list[float] = 0
+    explanation: list[float] = field(default_factory=list)
     """What is the surprisal of the model with an explanation"""

-    activations: list[float] = 0
+    activations: list[float] = field(default_factory=list)
     """What are the activations of the model"""
@@ -52,10 +55,10 @@ def __init__(
         self.batch_size = batch_size
         self.generation_kwargs = generation_kwargs

-    async def __call__(
-        self,
-        record: LatentRecord,
-    ) -> list[SurprisalOutput]:
+    async def __call__(  # type: ignore
+        self,  # type: ignore
+        record: LatentRecord,  # type: ignore
+    ) -> ScorerResult:  # type: ignore
         samples = self._prepare(record)

         random.shuffle(samples)
@@ -66,7 +69,7 @@ async def __call__(

         return ScorerResult(record=record, score=results)

-    def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
+    def _prepare(self, record: LatentRecord) -> list[Sample]:
         """
         Prepare and shuffle a list of samples for classification.
         """
@@ -74,6 +77,8 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
         defaults = {
             "tokenizer": self.tokenizer,
         }
+
+        assert record.extra_examples is not None, "No extra examples provided"
         samples = examples_to_samples(
             record.extra_examples,
             distance=-1,
@@ -81,6 +86,7 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
         )

         for i, examples in enumerate(record.test):
+            examples = assert_type(list, examples)
             samples.extend(
                 examples_to_samples(
                     examples,
@@ -181,7 +187,7 @@ def _query(self, explanation: str, samples: list[Sample]) -> list[SurprisalOutpu


 def examples_to_samples(
-    examples: list[Example],
+    examples: list[Example] | list[ActivatingExample],
     tokenizer: PreTrainedTokenizer,
     **sample_kwargs,
 ) -> list[Sample]:
diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py
index c9c259e3..1d4bfa9c 100644
--- a/delphi/sparse_coders/load_sparsify.py
+++ b/delphi/sparse_coders/load_sparsify.py
@@ -3,19 +3,21 @@
 from typing import Callable

 import torch
-from sparsify import Sae
+from sparsify import SparseCoder
 from torch import Tensor
 from transformers import PreTrainedModel


-def sae_dense_latents(x: Tensor, sae: Sae) -> Tensor:
+def sae_dense_latents(x: Tensor, sae: SparseCoder) -> Tensor:
     """Run `sae` on `x`, yielding the dense activations."""
     pre_acts = sae.pre_acts(x)
     acts, indices = sae.select_topk(pre_acts)
     return torch.zeros_like(pre_acts).scatter_(-1, indices, acts)


-def resolve_path(model: PreTrainedModel, path_segments: list[str]) -> list[str] | None:
+def resolve_path(
+    model: PreTrainedModel | torch.nn.Module, path_segments: list[str]
+) -> list[str] | None:
     """Attempt to resolve the path segments to the model in the case where it has
     been wrapped (e.g. by a LanguageModel, causal model, or classifier)."""
     # If the first segment is a valid attribute, return the path segments
@@ -45,7 +47,7 @@ def load_sparsify_sparse_coders(
     hookpoints: list[str],
     device: str | torch.device,
     compile: bool = False,
-) -> dict[str, Sae]:
+) -> dict[str, SparseCoder]:
     """
     Load sparsify sparse coders for specified hookpoints.

@@ -67,7 +69,7 @@ def load_sparsify_sparse_coders(
     name_path = Path(name)
     if name_path.exists():
         for hookpoint in hookpoints:
-            sparse_model_dict[hookpoint] = Sae.load_from_disk(
+            sparse_model_dict[hookpoint] = SparseCoder.load_from_disk(
                 name_path / hookpoint, device=device
             )
             if compile:
@@ -76,7 +78,7 @@ def load_sparsify_sparse_coders(
                 )
     else:
         # Load on CPU first to not run out of memory
-        sparse_models = Sae.load_many(name, device="cpu")
+        sparse_models = SparseCoder.load_many(name, device="cpu")
         for hookpoint in hookpoints:
             sparse_model_dict[hookpoint] = sparse_models[hookpoint].to(device)
             if compile:
diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py
index ad55901d..fdd5769a 100644
--- a/delphi/sparse_coders/sparse_model.py
+++ b/delphi/sparse_coders/sparse_model.py
@@ -2,6 +2,7 @@

 import torch
 import torch.nn as nn
+from sparsify import SparseCoder
 from transformers import PreTrainedModel

 from delphi.config import RunConfig
@@ -74,7 +75,7 @@ def load_sparse_coders(
     run_cfg: RunConfig,
     device: str | torch.device,
     compile: bool = False,
-) -> dict[str, nn.Module]:
+) -> dict[str, nn.Module] | dict[str, SparseCoder]:
     """
     Load sparse coders for specified hookpoints.

diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py
index a984bb5b..073ef3d0 100644
--- a/delphi/tests/e2e.py
+++ b/delphi/tests/e2e.py
@@ -4,8 +4,8 @@

 import torch

-from delphi.__main__ import RunConfig, run
-from delphi.config import CacheConfig, ConstructorConfig, SamplerConfig
+from delphi.__main__ import run
+from delphi.config import CacheConfig, ConstructorConfig, RunConfig, SamplerConfig
 from delphi.log.result_analysis import build_scores_df, latent_balanced_score_metrics

@@ -60,7 +60,8 @@ async def test():
     scores_path = Path("results") / run_cfg.name / "scores"
     df = build_scores_df(scores_path, run_cfg.hookpoints)
     for score_type in df["score_type"].unique():
-        score_df = df[df["score_type"] == score_type]
+        score_df = df.query(f"score_type == '{score_type}'")
+
         weighted_mean_metrics = latent_balanced_score_metrics(
             score_df, score_type, verbose=False
         )
diff --git a/delphi/utils.py b/delphi/utils.py
index ac0f88db..2b278cb9 100644
--- a/delphi/utils.py
+++ b/delphi/utils.py
@@ -1,3 +1,5 @@
+from typing import Any, Type, TypeVar, cast
+
 from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast

@@ -29,3 +31,14 @@ def load_tokenized_data(
     tokens = tokens_ds["input_ids"]

     return tokens
+
+
+T = TypeVar("T")
+
+
+def assert_type(typ: Type[T], obj: Any) -> T:
+    """Assert that an object is of a given type at runtime and return it."""
+    if not isinstance(obj, typ):
+        raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
+
+    return cast(typ, obj)
diff --git a/pyproject.toml b/pyproject.toml
index 4c67bcb7..45a35294 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,7 +30,8 @@ dependencies = [
 dev = ["pytest"]
 visualize = [
     "kaleido==0.2.1",
-    "plotly>=5.0.0rc2"
+    "plotly>=5.0.0rc2",
+    "pandas"
 ]

 [tool.pyright]