diff --git a/README.md b/README.md
index 0dbdb88a..a720fc48 100644
--- a/README.md
+++ b/README.md
@@ -16,7 +16,7 @@ Install this library as a local editable installation. Run the following command
 
 To run the default pipeline from the command line, use the following command:
 
-`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`
+`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B`
 
 This command will:
 1. Cache activations for the first 10 million tokens of EleutherAI/rpj-v2-sample.
diff --git a/delphi/__main__.py b/delphi/__main__.py
index 186e9a9e..33e35e6d 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -12,7 +12,6 @@
 from simple_parsing import ArgumentParser
 from sparsify.data import chunk_and_tokenize
 from torch import Tensor
-from torchtyping import TensorType
 from transformers import (
     AutoModel,
     AutoTokenizer,
@@ -88,7 +87,7 @@ async def process_cache(
     latent_dict = {
         hook: latent_range for hook in hookpoints
     }  # The latent range to explain
-    latent_dict = cast(dict[str, int | Tensor], latent_dict)
+    latent_dict = cast(dict[str, Tensor], latent_dict)
 
     constructor = partial(
         default_constructor,
@@ -235,8 +234,6 @@ def populate_cache(
     ]
     tokens = truncated_tokens.reshape(-1, cfg.ctx_len)
 
-    tokens = cast(TensorType["batch", "seq"], tokens)
-
     cache = LatentCache(
         model,
         hookpoint_to_sparse_encode,
diff --git a/delphi/explainers/default/default.py b/delphi/explainers/default/default.py
index d6a69037..c1419e70 100644
--- a/delphi/explainers/default/default.py
+++ b/delphi/explainers/default/default.py
@@ -1,8 +1,6 @@
 import asyncio
-import re
 
-from ...logger import logger
-from ..explainer import Explainer, ExplainerResult
+from ..explainer import Example, Explainer
 from .prompt_builder import build_prompt
 
 
@@ -20,98 +18,35 @@ def __init__(
         temperature: float = 0.0,
         **generation_kwargs,
     ):
-        self.client = client
-        self.tokenizer = tokenizer
-        self.verbose = verbose
-
-        self.activations = activations
-        self.cot = cot
-        self.threshold = threshold
-        self.temperature = temperature
-        self.generation_kwargs = generation_kwargs
-
-    async def __call__(self, record):
-        messages = self._build_prompt(record.train)
-
-        response = await self.client.generate(
-            messages, temperature=self.temperature, **self.generation_kwargs
+        super().__init__(
+            client,
+            tokenizer,
+            verbose,
+            activations,
+            cot,
+            threshold,
+            temperature,
+            **generation_kwargs,
         )
 
-        try:
-            explanation = self.parse_explanation(response.text)
-            if self.verbose:
-                logger.info(f"Explanation: {explanation}")
-                logger.info(f"Final message to explainer: {messages[-1]['content']}")
-                logger.info(f"Response from explainer: {response.text}")
-
-            return ExplainerResult(record=record, explanation=explanation)
-        except Exception as e:
-            logger.error(f"Explanation parsing failed: {e}")
-            return ExplainerResult(
-                record=record, explanation="Explanation could not be parsed."
-            )
-
-    def parse_explanation(self, text: str) -> str:
-        try:
-            match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL)
-            return (
-                match.group(1).strip() if match else "Explanation could not be parsed."
-            )
-        except Exception as e:
-            logger.error(f"Explanation parsing regex failed: {e}")
-            raise
-
-    def _highlight(self, index, example):
-        result = f"Example {index}: "
-
-        threshold = example.max_activation * self.threshold
-        if self.tokenizer is not None:
-            str_toks = self.tokenizer.batch_decode(example.tokens)
-            example.str_toks = str_toks
-        else:
-            str_toks = example.tokens
-            example.str_toks = str_toks
-        activations = example.activations
-
-        def check(i):
-            return activations[i] > threshold
-
-        i = 0
-        while i < len(str_toks):
-            if check(i):
-                result += "<<"
-
-                while i < len(str_toks) and check(i):
-                    result += str_toks[i]
-                    i += 1
-                result += ">>"
-            else:
-                result += str_toks[i]
-                i += 1
-
-        return "".join(result)
-
-    def _join_activations(self, example):
-        activations = []
-
-        for i, activation in enumerate(example.activations):
-            if activation > example.max_activation * self.threshold:
-                activations.append(
-                    (example.str_toks[i], int(example.normalized_activations[i]))
-                )
-
-        acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations)
-
-        return "Activations: " + acts
-
-    def _build_prompt(self, examples):
+    def _build_prompt(self, examples: list[Example]) -> list[dict]:
         highlighted_examples = []
 
         for i, example in enumerate(examples):
-            highlighted_examples.append(self._highlight(i + 1, example))
+            str_toks = self.tokenizer.batch_decode(example.tokens)
+            activations = example.activations.tolist()
+            highlighted_examples.append(self._highlight(str_toks, activations))
 
             if self.activations:
-                highlighted_examples.append(self._join_activations(example))
+                assert (
+                    example.normalized_activations is not None
+                ), "Normalized activations are required for activations in explainer"
+                normalized_activations = example.normalized_activations.tolist()
+                highlighted_examples.append(
+                    self._join_activations(
+                        str_toks, activations, normalized_activations
+                    )
+                )
 
         highlighted_examples = "\n".join(highlighted_examples)
diff --git a/delphi/explainers/default/prompt_builder.py b/delphi/explainers/default/prompt_builder.py
index 8ca78140..55eddcae 100644
--- a/delphi/explainers/default/prompt_builder.py
+++ b/delphi/explainers/default/prompt_builder.py
@@ -26,10 +26,10 @@ def build_examples(
 
 
 def build_prompt(
-    examples,
+    examples: str,
     activations: bool = False,
     cot: bool = False,
-):
+) -> list[dict]:
     messages = system(
         cot=cot,
     )
@@ -49,7 +49,6 @@ def build_prompt(
             "content": user_start,
         }
     )
-    print(messages)
 
     return messages
diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py
index 2ca9d35f..8eda3dd7 100644
--- a/delphi/explainers/explainer.py
+++ b/delphi/explainers/explainer.py
@@ -1,12 +1,14 @@
 import json
 import os
 import random
+import re
 from abc import ABC, abstractmethod
 from typing import NamedTuple
 
 import aiofiles
 
-from ..latents.latents import LatentRecord
+from ..latents.latents import Example, LatentRecord
+from ..logger import logger
 
 
 class ExplainerResult(NamedTuple):
@@ -18,18 +20,120 @@ class ExplainerResult(NamedTuple):
 
 
 class Explainer(ABC):
+    """
+    Abstract base class for explainers.
+    """
+
+    def __init__(
+        self,
+        client,
+        tokenizer,
+        verbose: bool = False,
+        activations: bool = False,
+        cot: bool = False,
+        threshold: float = 0.6,
+        temperature: float = 0.0,
+        **generation_kwargs,
+    ):
+        self.client = client
+        self.tokenizer = tokenizer
+        self.verbose = verbose
+        self.activations = activations
+        self.cot = cot
+        self.threshold = threshold
+        self.temperature = temperature
+        self.generation_kwargs = generation_kwargs
+
+    async def __call__(self, record: LatentRecord) -> ExplainerResult:
+        messages = self._build_prompt(record.train)
+
+        response = await self.client.generate(
+            messages, temperature=self.temperature, **self.generation_kwargs
+        )
+
+        try:
+            explanation = self.parse_explanation(response.text)
+            if self.verbose:
+                logger.info(f"Explanation: {explanation}")
+                logger.info(f"Messages: {messages[-1]['content']}")
+                logger.info(f"Response: {response}")
+
+            return ExplainerResult(record=record, explanation=explanation)
+        except Exception as e:
+            logger.error(f"Explanation parsing failed: {e}")
+            return ExplainerResult(
+                record=record, explanation="Explanation could not be parsed."
+            )
+
+    def parse_explanation(self, text: str) -> str:
+        try:
+            match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL)
+            if match:
+                return match.group(1).strip()
+            else:
+                return "Explanation could not be parsed."
+        except Exception as e:
+            logger.error(f"Explanation parsing regex failed: {e}")
+            raise
+
+    def _highlight(self, str_toks: list[str], activations: list[float]) -> str:
+        result = ""
+        threshold = max(activations) * self.threshold
+
+        def check(i):
+            return activations[i] > threshold
+
+        i = 0
+        while i < len(str_toks):
+            if check(i):
+                result += "<<"
+
+                while i < len(str_toks) and check(i):
+                    result += str_toks[i]
+                    i += 1
+                result += ">>"
+            else:
+                result += str_toks[i]
+                i += 1
+
+        return "".join(result)
+
+    def _join_activations(
+        self,
+        str_toks: list[str],
+        token_activations: list[float],
+        normalized_activations: list[float],
+    ) -> str:
+        acts = ""
+        activation_count = 0
+        for str_tok, token_activation, normalized_activation in zip(
+            str_toks, token_activations, normalized_activations
+        ):
+            if token_activation > max(token_activations) * self.threshold:
+                # TODO: for each example, we only show the first 10 activations
+                # decide on the best way to do this
+                if activation_count > 10:
+                    break
+                acts += f'("{str_tok}" : {int(normalized_activation)}), '
+                activation_count += 1
+
+        return "Activations: " + acts
+
     @abstractmethod
-    def __call__(self, record: LatentRecord) -> ExplainerResult:
+    def _build_prompt(self, examples: list[Example]) -> list[dict]:
         pass
 
 
 async def explanation_loader(
     record: LatentRecord, explanation_dir: str
 ) -> ExplainerResult:
-    async with aiofiles.open(f"{explanation_dir}/{record.latent}.txt", "r") as f:
-        explanation = json.loads(await f.read())
-
-    return ExplainerResult(record=record, explanation=explanation)
+    try:
+        async with aiofiles.open(f"{explanation_dir}/{record.latent}.txt", "r") as f:
+            explanation = json.loads(await f.read())
+        return ExplainerResult(record=record, explanation=explanation)
+    except FileNotFoundError:
+        print(f"No explanation found for {record.latent}")
+        return ExplainerResult(record=record, explanation="No explanation found")
 
 
 async def random_explanation_loader(
diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py
index 3e2c0531..4fd8606a 100644
--- a/delphi/latents/cache.py
+++ b/delphi/latents/cache.py
@@ -5,10 +5,11 @@
 
 import numpy as np
 import torch
+from jaxtyping import Float
 from safetensors.numpy import save_file
 from torch import Tensor
-from torchtyping import TensorType
 from tqdm import tqdm
+from transformers import PreTrainedModel
 
 from delphi.config import CacheConfig
 from delphi.latents.collect_activations import collect_activations
@@ -21,7 +22,9 @@ class Cache:
     """
 
     def __init__(
-        self, filters: dict[str, TensorType["indices"]] = None, batch_size: int = 64
+        self,
+        filters: dict[str, Float[Tensor, "indices"]] | None = None,
+        batch_size: int = 64,
     ):
         """
         Initialize the Cache.
@@ -30,16 +33,31 @@ def __init__(
             filters: Filters for selecting specific latents.
             batch_size: Size of batches for processing. Defaults to 64.
         """
-        self.latent_locations = defaultdict(list)
-        self.latent_activations = defaultdict(list)
-        self.tokens = defaultdict(list)
+        self.latent_locations_batches: dict[
+            str, list[Float[Tensor, "batch sequence num_latents"]]
+        ] = defaultdict(list)
+        self.latent_activations_batches: dict[
+            str, list[Float[Tensor, "batch sequence num_latents"]]
+        ] = defaultdict(list)
+        self.tokens_batches: dict[
+            str, list[Float[Tensor, "batch sequence"]]
+        ] = defaultdict(list)
+
+        self.latent_locations: dict[
+            str, Float[Tensor, "batch sequence num_latents"]
+        ] = {}
+        self.latent_activations: dict[
+            str, Float[Tensor, "batch sequence num_latents"]
+        ] = {}
+        self.tokens: dict[str, Float[Tensor, "batch sequence"]] = {}
+
         self.filters = filters
         self.batch_size = batch_size
 
     def add(
         self,
-        latents: TensorType["batch", "sequence", "latent"],
-        tokens: TensorType["batch", "sequence"],
+        latents: Float[Tensor, "mini_batch sequence num_latents"],
+        tokens: Float[Tensor, "mini_batch sequence"],
         batch_number: int,
         module_path: str,
     ):
@@ -59,28 +77,32 @@ def add(
         # Adjust batch indices
         latent_locations[:, 0] += batch_number * self.batch_size
 
-        self.latent_locations[module_path].append(latent_locations)
-        self.latent_activations[module_path].append(latent_activations)
-        self.tokens[module_path].append(tokens)
+        self.latent_locations_batches[module_path].append(latent_locations)
+        self.latent_activations_batches[module_path].append(latent_activations)
+        self.tokens_batches[module_path].append(tokens)
 
     def save(self):
         """
         Concatenate the latent locations and activations for all modules.
         """
-        for module_path in self.latent_locations.keys():
+        for module_path in self.latent_locations_batches.keys():
             self.latent_locations[module_path] = torch.cat(
-                self.latent_locations[module_path], dim=0
+                self.latent_locations_batches[module_path], dim=0
             )
             self.latent_activations[module_path] = torch.cat(
-                self.latent_activations[module_path], dim=0
+                self.latent_activations_batches[module_path], dim=0
             )
-            self.tokens[module_path] = torch.cat(self.tokens[module_path], dim=0)
+            self.tokens[module_path] = torch.cat(
+                self.tokens_batches[module_path], dim=0
+            )
 
     def get_nonzeros_batch(
-        self, latents: TensorType["batch", "seq", "latent"]
-    ) -> tuple[Tensor, Tensor]:
+        self, latents: Float[Tensor, "batch sequence num_latents"]
+    ) -> tuple[
+        Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "]
+    ]:
         """
         Get non-zero activations for large batches that exceed int32 max value.
@@ -115,8 +137,11 @@ def get_nonzeros_batch(
         return nonzero_latent_locations, nonzero_latent_activations
 
     def get_nonzeros(
-        self, latents: TensorType["batch", "seq", "latent"], module_path: str
-    ) -> tuple[Tensor, Tensor]:
+        self, latents: Float[Tensor, "batch sequence num_latents"], module_path: str
+    ) -> tuple[
+        Float[Tensor, "batch sequence num_latents"],
+        Float[Tensor, "batch sequence num_latents"],
+    ]:
         """
         Get the nonzero latent locations and activations.
 
@@ -157,10 +182,10 @@ class LatentCache:
 
     def __init__(
         self,
-        model,
+        model: PreTrainedModel,
         hookpoint_to_sparse_encode: dict[str, Callable],
         batch_size: int,
-        filters: dict[str, TensorType["indices"]] | None = None,
+        filters: dict[str, Float[Tensor, "indices"]] | None = None,
     ):
         """
         Initialize the LatentCache.
@@ -181,8 +206,8 @@ def __init__(
         self.filter_submodules(filters)
 
     def load_token_batches(
-        self, n_tokens: int, tokens: TensorType["batch", "sequence"]
-    ) -> list[Tensor]:
+        self, n_tokens: int, tokens: Float[Tensor, "batch sequence"]
+    ) -> list[Float[Tensor, "batch sequence"]]:
         """
         Load and prepare token batches for processing.
 
@@ -205,7 +230,7 @@ def load_token_batches(
 
         return token_batches
 
-    def filter_submodules(self, filters: dict[str, TensorType["indices"]]):
+    def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]):
         """
         Filter submodules based on the provided filters.
 
@@ -218,7 +243,7 @@ def filter_submodules(self, filters: dict[str, TensorType["indices"]]):
                 filtered_submodules[hookpoint] = self.hookpoint_to_sae[hookpoint]
         self.hookpoint_to_sae = filtered_submodules
 
-    def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
+    def run(self, n_tokens: int, tokens: Float[Tensor, "batch sequence"]):
         """
         Run the latent caching process.
 
@@ -277,9 +302,9 @@ def save(self, save_dir: Path, save_tokens: bool = True):
             if save_tokens:
                 data["tokens"] = self.cache.tokens[module_path]
 
-            save_file(data, output_file)
+            save_file(data, output_file)  # type: ignore
 
-    def _generate_split_indices(self, n_splits: int) -> list[tuple[int, int]]:
+    def _generate_split_indices(self, n_splits: int) -> list[tuple[Tensor, Tensor]]:
         """
         Generate indices for splitting the latent space.
 
@@ -289,6 +314,7 @@ def _generate_split_indices(self, n_splits: int) -> list[tuple[int, int]]:
         Returns:
             list[tuple[int, int]]: list of start and end indices for each split.
         """
+        assert self.width is not None, "Width must be set before generating splits"
         boundaries = torch.linspace(0, self.width, steps=n_splits + 1).long()
 
         # Adjust end by one
diff --git a/delphi/latents/collect_activations.py b/delphi/latents/collect_activations.py
index 63c6a4b0..97489054 100644
--- a/delphi/latents/collect_activations.py
+++ b/delphi/latents/collect_activations.py
@@ -23,7 +23,7 @@ def collect_activations(model: PreTrainedModel, hookpoints: list[str]):
     handles = []
 
     def create_hook(hookpoint: str):
-        def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor:
+        def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor | None:
             # If output is a tuple (like in some transformer layers), take first element
             if isinstance(output, tuple):
                 activations[hookpoint] = output[0]
diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py
index 2742362e..1a287506 100644
--- a/delphi/latents/constructors.py
+++ b/delphi/latents/constructors.py
@@ -1,19 +1,19 @@
 from typing import Callable, Optional
 
 import torch
-from torchtyping import TensorType
+from jaxtyping import Float
+from torch import Tensor
 
 from .latents import LatentRecord, prepare_examples
 from .loader import BufferOutput
 
 
 def _top_k_pools(
-    max_buffer: TensorType["batch"],
-    split_activations: list[TensorType["activations"]],
-    buffer_tokens: TensorType["batch", "ctx_len"],
-    ctx_len: int,
+    max_buffer: Float[Tensor, "batch"],
+    split_activations: Float[Tensor, "activations ctx_len"],
+    buffer_tokens: Float[Tensor, "batch ctx_len"],
     max_examples: int,
-):
+) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
     """
     Get the top k activation pools.
 
@@ -21,11 +21,10 @@ def _top_k_pools(
         max_buffer: The maximum buffer values.
         split_activations: The split activations.
         buffer_tokens: The buffer tokens.
-        ctx_len: The context length.
         max_examples: The maximum number of examples.
 
     Returns:
-        tuple[TensorType["examples", "ctx_len"], TensorType["examples", "ctx_len"]]:
+        tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]:
             The token windows and activation windows.
     """
     k = min(max_examples, len(max_buffer))
@@ -40,7 +39,7 @@ def _top_k_pools(
 def pool_max_activation_windows(
     record,
     buffer_output: BufferOutput,
-    tokens: TensorType["batch", "seq"],
+    tokens: Float[Tensor, "batch sequence"],
     ctx_len: int,
     max_examples: int,
 ):
@@ -81,7 +80,7 @@ def pool_max_activation_windows(
     buffer_tokens = buffer_tokens[unique_ctx_indices]
 
     token_windows, activation_windows = _top_k_pools(
-        max_buffer, new_tensor, buffer_tokens, ctx_len, max_examples
+        max_buffer, new_tensor, buffer_tokens, max_examples
    )
 
     record.examples = prepare_examples(token_windows, activation_windows)
@@ -89,7 +88,7 @@ def pool_max_activation_windows(
 
 def random_non_activating_windows(
     record: LatentRecord,
-    tokens: TensorType["batch", "seq"],
+    tokens: Float[Tensor, "batch sequence"],
     buffer_output: BufferOutput,
     ctx_len: int,
     n_not_active: int,
@@ -138,7 +137,7 @@ def random_non_activating_windows(
 
 def default_constructor(
     record: LatentRecord,
-    token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None,
+    token_loader: Optional[Callable[[], Float[Tensor, "batch sequence"]]] | None,
     buffer_output: BufferOutput,
     n_not_active: int,
     ctx_len: int,
diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py
index c1f6e2f2..b4db1aa9 100644
--- a/delphi/latents/latents.py
+++ b/delphi/latents/latents.py
@@ -1,10 +1,11 @@
 from dataclasses import dataclass, field
-from typing import List, Optional
+from typing import Optional
 
 import blobfile as bf
 import orjson
-from torchtyping import TensorType
-from transformers import AutoTokenizer
+from jaxtyping import Float
+from torch import Tensor
+from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast
 
 
 @dataclass
@@ -13,13 +14,13 @@ class Example:
     A single example of latent data.
     """
 
-    tokens: TensorType["seq"]
+    tokens: Float[Tensor, "ctx_len"]
     """Tokenized input sequence."""
 
-    activations: TensorType["seq"]
+    activations: Float[Tensor, "ctx_len"]
     """Activation values for the input sequence."""
 
-    normalized_activations: Optional[TensorType["seq"]] = None
+    normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None
    """Activations quantized to integers in [0, 10]."""
 
     @property
@@ -34,9 +35,9 @@ def max_activation(self) -> float:
 
 
 def prepare_examples(
-    tokens: List[TensorType["seq"]],
-    activations: List[TensorType["seq"]],
-) -> List[Example]:
+    tokens: Float[Tensor, "examples ctx_len"],
+    activations: Float[Tensor, "examples ctx_len"],
+) -> list[Example]:
     """
     Prepare a list of examples from input tokens and activations.
 
@@ -91,7 +92,7 @@ class LatentRecord:
     not_active: list[Example] = field(default_factory=list)
     """Non-activating examples."""
 
-    train: list[list[Example]] = field(default_factory=list)
+    train: list[Example] = field(default_factory=list)
     """Training examples."""
 
     test: list[list[Example]] = field(default_factory=list)
@@ -129,10 +130,10 @@ def save(self, directory: str, save_examples: bool = False):
 
     def display(
         self,
-        tokenizer: AutoTokenizer,
+        tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast,
         threshold: float = 0.0,
         n: int = 10,
-    ) -> str:
+    ):
         """
         Display the latent record in a formatted string.
 
@@ -147,9 +148,7 @@ def display(
         """
         from IPython.core.display import HTML, display
 
-        def _to_string(
-            tokens: TensorType["seq"], activations: TensorType["seq"]
-        ) -> str:
+        def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str:
             """
             Convert tokens and activations to a string.
@@ -177,6 +176,7 @@ def _to_string(
                     result.append(tokens[i])
                     i += 1
             return "".join(result)
+        return ""
 
         strings = [
             _to_string(tokenizer.batch_decode(example.tokens), example.activations)
diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py
index 47b5b03b..db295891 100644
--- a/delphi/latents/loader.py
+++ b/delphi/latents/loader.py
@@ -3,12 +3,13 @@
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Callable, NamedTuple, Optional, Union
+from typing import Callable, NamedTuple, Optional
 
 import numpy as np
 import torch
+from jaxtyping import Float
 from safetensors.numpy import load_file
-from torchtyping import TensorType
+from torch import Tensor
 from transformers import AutoTokenizer
 
 from delphi.utils import (
@@ -27,13 +28,13 @@ class BufferOutput(NamedTuple):
     latent: Latent
     """The latent associated with this output."""
 
-    locations: TensorType["locations", 2]
+    locations: Float[Tensor, "locations 2"]
     """Tensor of latent locations."""
 
-    activations: TensorType["locations"]
+    activations: Float[Tensor, "locations"]
     """Tensor of latent activations."""
 
-    tokens: TensorType["tokens"]
+    tokens: Float[Tensor, "tokens"]
     """Tensor of all tokens."""
 
 
@@ -49,12 +50,15 @@ class TensorBuffer:
     module_path: str
     """Path of the module."""
 
-    latents: Optional[TensorType["latents"]] = None
+    latents: Optional[Float[Tensor, "num_latents"]] = None
     """Tensor of latent indices."""
 
     min_examples: int = 120
     """Minimum number of examples required. Defaults to 120."""
 
+    tokens: Optional[Float[Tensor, "batch sequence"]] = None
+    """Tensor of all tokens."""
+
     def __iter__(self):
         """
         Iterate over the buffer, yielding BufferOutput objects.
@@ -64,7 +68,10 @@ def __iter__(self):
                 None otherwise.
         """
         latents, split_locations, split_activations, tokens = self.load()
-
+        if tokens is None:
+            tokens = self.tokens
+        if tokens is None:
+            raise ValueError("No tokens found")
         for i in range(len(latents)):
             latent_locations = split_locations[i]
             latent_activations = split_activations[i]
@@ -125,7 +132,7 @@ def __init__(
         cfg: LatentConfig,
         tokenizer: Optional[Callable] = None,
         modules: Optional[list[str]] = None,
-        latents: Optional[dict[str, Union[int, torch.Tensor]]] = None,
+        latents: Optional[dict[str, torch.Tensor]] = None,
         constructor: Optional[Callable] = None,
         sampler: Optional[Callable] = None,
         transform: Optional[Callable] = None,
@@ -141,13 +148,16 @@ def __init__(
         """
         self.cfg = cfg
         self.buffers = []
-
+        if modules is None:
+            self.modules = os.listdir(raw_dir)
+        else:
+            self.modules = modules
         if latents is None:
-            self._build(raw_dir, modules)
+            self._build(raw_dir)
         else:
-            self._build_selected(raw_dir, modules, latents)
+            self._build_selected(raw_dir, latents)
         # TODO: this assumes that all modules have the same config
-        cache_config_dir = f"{raw_dir}/{modules[0]}/config.json"
+        cache_config_dir = f"{raw_dir}/{self.modules[0]}/config.json"
         with open(cache_config_dir, "r") as f:
             cache_config = json.load(f)
         if tokenizer is None:
@@ -170,7 +180,7 @@ def load_tokens(self):
         if not hasattr(self, "tokens"):
             self.tokens = load_tokenized_data(
                 self.cache_config["ctx_len"],
-                self.tokenizer,
+                self.tokenizer,  # type: ignore
                 self.cache_config["dataset_repo"],
                 self.cache_config["dataset_split"],
                 self.cache_config["dataset_name"],
@@ -191,7 +201,7 @@ def _edges(self, raw_dir: str, module: str) -> list[tuple[int, int]]:
         edges.sort(key=lambda x: x[0])
         return edges
 
-    def _build(self, raw_dir: str, modules: Optional[list[str]] = None):
+    def _build(self, raw_dir: str):
         """
         Build dataset buffers which load all cached latents.
@@ -199,9 +209,8 @@
         Args:
             raw_dir (str): Directory containing raw latent data.
             modules (Optional[list[str]]): list of module names to include.
         """
-        modules = os.listdir(raw_dir) if modules is None else modules
 
-        for module in modules:
+        for module in self.modules:
             edges = self._edges(raw_dir, module)
             for start, end in edges:
                 path = f"{raw_dir}/{module}/{start}_{end}.safetensors"
@@ -212,24 +221,20 @@ def _build_selected(
     def _build_selected(
         self,
         raw_dir: str,
-        modules: list[str],
-        latents: dict[str, Union[int, torch.Tensor]],
+        latents: dict[str, torch.Tensor],
     ):
         """
         Build a dataset buffer which loads only selected latents.
 
         Args:
             raw_dir (str): Directory containing raw latent data.
-            modules (list[str]): list of module names to include.
            latents (dict[str, Union[int, torch.Tensor]]):
                 Dictionary of latents per module.
         """
-        for module in modules:
+        for module in self.modules:
             edges = self._edges(raw_dir, module)
             selected_latents = latents[module]
-            if isinstance(selected_latents, int):
-                selected_latents = torch.tensor([selected_latents])
 
             boundaries = [edges[0][0]] + [edge[1] + 1 for edge in edges]
 
             bucketized = torch.bucketize(
diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py
index 0b370584..0a355a9c 100644
--- a/delphi/latents/neighbours.py
+++ b/delphi/latents/neighbours.py
@@ -6,7 +6,9 @@
 import numpy as np
 import torch
 from safetensors.numpy import load_file
+from torch import nn
 
+from delphi.latents.latents import PreActivationRecord
 from delphi.latents.loader import LatentDataset
 
 
@@ -21,7 +23,7 @@ class NeighbourCalculator:
     def __init__(
         self,
         latent_dataset: Optional["LatentDataset"] = None,
-        autoencoder: Optional["Autoencoder"] = None,
+        autoencoder: Optional["nn.Module"] = None,
         pre_activation_record: Optional["PreActivationRecord"] = None,
         number_of_neighbours: int = 10,
         neighbour_cache: Optional[dict[str, dict[int, list[int]]]] = None,
@@ -39,6 +41,7 @@ def __init__(
         self.latent_dataset = latent_dataset
         self.autoencoder = autoencoder
         self.pre_activation_record = pre_activation_record
+        self.number_of_neighbours = number_of_neighbours
 
         # load the neighbour cache from the path
         if neighbour_cache is not None:
@@ -139,7 +142,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]:
         """
         Compute neighbour lists based on latent co-occurrence in the dataset.
         """
-        # To be implemented
         paths = []
 
         for buffer in self.latent_dataset.buffers:
@@ -214,16 +216,16 @@ def populate_neighbour_cache(self, methods: list[str]) -> None:
         for method in methods:
             self._compute_neighbour_list(method)
 
-    def save_neighbour_cache(self) -> None:
+    def save_neighbour_cache(self, path: str) -> None:
         """
         Save the neighbour cache to the path as a json file
         """
-        with open(self.path, "w") as f:
+        with open(path, "w") as f:
             json.dump(self.neighbour_cache, f, indent=4)
 
-    def load_neighbour_cache(self) -> dict[str, dict[int, list[int]]]:
+    def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]:
         """
         Load the neighbour cache from the path as a json file
         """
-        with open(self.path, "r") as f:
+        with open(path, "r") as f:
             return json.load(f)
diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py
index 3c0a2ca5..d3e685be 100644
--- a/delphi/latents/samplers.py
+++ b/delphi/latents/samplers.py
@@ -1,8 +1,6 @@
 import random
 from collections import deque
-from typing import Literal, cast
-
-from torchtyping import TensorType
+from typing import Literal
 
 from ..config import ExperimentConfig
 from ..logger import logger
@@ -31,7 +29,7 @@ def split_activation_quantiles(
     """
     random.seed(seed)
 
-    examples = deque(examples)
+    queue_examples = deque(examples)
     max_activation = examples[0].max_activation
 
     # For 4 quantiles, thresholds are 0.25, 0.5, 0.75
@@ -41,8 +39,8 @@
     for threshold in thresholds:
         # Get all examples in quantile
         quantile = []
-        while examples and examples[0].max_activation < threshold:
-            quantile.append(examples.popleft())
+        while queue_examples and queue_examples[0].max_activation < threshold:
+            quantile.append(queue_examples.popleft())
 
         sample = random.sample(quantile, n_samples)
         samples.append(sample)
@@ -121,10 +119,9 @@ def train(
     selected_examples = []
     for quantile in selected_examples_quantiles:
         for example in quantile:
-            example.normalized_activations = cast(
-                TensorType["seq"],
-                (example.activations * 10 / max_activation).floor(),
-            )
+            example.normalized_activations = (
+                example.activations * 10 / max_activation
+            ).floor()
         selected_examples.extend(quantile)
 
     return selected_examples
diff --git a/delphi/sparse_coders/__init__.py b/delphi/sparse_coders/__init__.py
index 731e9606..d72b3094 100644
--- a/delphi/sparse_coders/__init__.py
+++ b/delphi/sparse_coders/__init__.py
@@ -1,4 +1,3 @@
 from .sparse_model import load_hooks_sparse_coders, load_sparse_coders
 
 __all__ = ["load_hooks_sparse_coders", "load_sparse_coders"]
-
diff --git a/delphi/sparse_coders/custom/gemmascope.py b/delphi/sparse_coders/custom/gemmascope.py
index 1372ef4e..8a66f20c 100644
--- a/delphi/sparse_coders/custom/gemmascope.py
+++ b/delphi/sparse_coders/custom/gemmascope.py
@@ -14,7 +14,7 @@ def load_gemma_autoencoders(
     type: str,
     dtype: torch.dtype = torch.bfloat16,
     device: torch.device = torch.device("cuda"),
-):
+) -> dict[str, nn.Module]:
     saes = {}
 
     for layer, size, l0 in zip(ae_layers, sizes, average_l0s):
diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py
index 6da26ce9..7c58c621 100644
--- a/delphi/sparse_coders/load_sparsify.py
+++ b/delphi/sparse_coders/load_sparsify.py
@@ -46,7 +46,7 @@ def load_sparsify_sparse_coders(
     hookpoints: list[str],
     device: str | torch.device | None = None,
     compile: bool = False,
-) -> dict[str, Callable]:
+) -> dict[str, Sae]:
     """
     Load sparsify sparse coders for specified hookpoints.
diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py
index 85f747cd..5b77feae 100644
--- a/delphi/sparse_coders/sparse_model.py
+++ b/delphi/sparse_coders/sparse_model.py
@@ -6,7 +6,7 @@
 from delphi.config import RunConfig
 
 from .custom.gemmascope import load_gemma_autoencoders
-from .load_sparsify import load_sparsify_hooks, load_sparsify_sparse_coders
+from .load_sparsify import Sae, load_sparsify_hooks, load_sparsify_sparse_coders
 
 
 def load_hooks_sparse_coders(
@@ -62,7 +62,9 @@ def load_hooks_sparse_coders(
             dtype=model.dtype,
             device=model.device,
         )
-
+    # throw an error if the dictionary is empty
+    if not hookpoint_to_sparse_encode:
+        raise ValueError("No sparse coders loaded")
     return hookpoint_to_sparse_encode
 
 
@@ -70,7 +72,7 @@ def load_sparse_coders(
     model: PreTrainedModel,
     run_cfg: RunConfig,
     compile: bool = False,
-) -> dict[str, nn.Module]:
+) -> dict[str, nn.Module] | dict[str, Sae]:
     """
     Load sparse coders for specified hookpoints.
 
@@ -79,7 +81,7 @@ def load_sparse_coders(
         run_cfg (RunConfig): The run configuration.
 
     Returns:
-        dict[str, Callable]: A dictionary mapping hookpoints to sparse coders.
+        dict[str, nn.Module]: A dictionary mapping hookpoints to sparse coders.
     """
 
     # Add SAE hooks to the model
diff --git a/delphi/tests/conftest.py b/delphi/tests/conftest.py
index e241c4ee..63151218 100644
--- a/delphi/tests/conftest.py
+++ b/delphi/tests/conftest.py
@@ -1,6 +1,15 @@
+from typing import cast
+
 import pytest
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from torch import Tensor
+from transformers import (
+    AutoModel,
+    AutoTokenizer,
+    PreTrainedModel,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
 
 from delphi.config import CacheConfig, RunConfig
 from delphi.latents import LatentCache
@@ -24,30 +33,31 @@
 
 
 @pytest.fixture(scope="module")
-def tokenizer():
-    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m")
+def tokenizer() -> PreTrainedTokenizer | PreTrainedTokenizerFast:
+    tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m")
     tokenizer.pad_token = tokenizer.eos_token
     return tokenizer
 
 
 @pytest.fixture(scope="module")
-def model():
-    model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m")
+def model() -> PreTrainedModel:
+    model = AutoModel.from_pretrained("EleutherAI/pythia-160m")
     return model
 
 
 @pytest.fixture(scope="module")
-def mock_dataset(tokenizer: AutoTokenizer) -> torch.Tensor:
+def mock_dataset(tokenizer: PreTrainedTokenizer) -> torch.Tensor:
     tokens = tokenizer(
         random_text, return_tensors="pt", truncation=True, max_length=16, padding=True
     )["input_ids"]
+    tokens = cast(Tensor, tokens)
+    print(tokens)
+    print(tokens.shape)
     return tokens
 
 
 @pytest.fixture(scope="module")
-def cache_setup(
-    tmp_path_factory, mock_dataset: torch.Tensor, model: AutoModelForCausalLM
-):
+def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedModel):
     """
     This fixture creates a temporary directory, loads the model,
     initializes the cache, runs the cache once, saves the cache splits
@@ -58,12 +68,12 @@ def cache_setup(
 
     # Load model and set run configuration
     run_cfg_gemma = RunConfig(
-        model="EleutherAI/pythia-70m",
-        sparse_model="EleutherAI/sae-pythia-70m-32k",
+        model="EleutherAI/pythia-160m",
+        sparse_model="EleutherAI/sae-pythia-160m-32k",
         hookpoints=["layers.1"],
     )
     hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg_gemma)
-
+    print(hookpoint_to_sparse_encode)
     # Define cache config and initialize cache
     cache_cfg = CacheConfig(batch_size=1, ctx_len=16, n_tokens=100)
     cache = LatentCache(
@@ -81,7 +91,6 @@
 
     # Save the cache config
     cache.save_config(temp_dir, cache_cfg, "EleutherAI/pythia-70m")
-
     return {
         "cache": cache,
         "tokens": tokens,
diff --git a/delphi/tests/test_latents/test_cache.py b/delphi/tests/test_latents/test_cache.py
index 198f0cf3..3d60d094 100644
--- a/delphi/tests/test_latents/test_cache.py
+++ b/delphi/tests/test_latents/test_cache.py
@@ -13,7 +13,7 @@ def test_latent_locations(cache_setup: dict[str, Any]):
     shape and values.
     """
     cache = cache_setup["cache"]
-    locations = cache.cache.latent_locations["gpt_neox.layers.1"]
+    locations = cache.cache.latent_locations["layers.1"]
     max_values, _ = locations.max(axis=0)
     # Expected values based on the cache run
     assert max_values[0] == 5, "Expected first dimension max value to be 5"
@@ -25,7 +25,7 @@ def test_split_files_created(cache_setup: dict[str, Any]):
     """
     Test that exactly 5 cache split files have been created.
     """
-    save_dir = cache_setup["temp_dir"] / "gpt_neox.layers.1"
+    save_dir = cache_setup["temp_dir"] / "layers.1"
     cache_files = [f for f in os.listdir(save_dir) if f.endswith(".safetensors")]
     assert len(cache_files) == 5, "Expected 5 split files in the cache directory"
 
@@ -37,7 +37,7 @@ def test_split_file_contents(cache_setup: dict[str, Any]):
     - tokens were correctly stored and match the input tokens.
     - latent max values are as expected.
     """
-    save_dir = cache_setup["temp_dir"] / "gpt_neox.layers.1"
+    save_dir = cache_setup["temp_dir"] / "layers.1"
     tokens = cache_setup["tokens"]
     # Choose one file to verify
     cache_files = os.listdir(save_dir)
@@ -67,7 +67,7 @@ def test_config_file(cache_setup: dict[str, Any]):
     """
     Test that the saved configuration file contains the correct parameters.
     """
-    config_path = cache_setup["temp_dir"] / "gpt_neox.layers.1" / "config.json"
+    config_path = cache_setup["temp_dir"] / "layers.1" / "config.json"
     with open(config_path, "r") as f:
         config = json.load(f)
     cache_cfg = cache_setup["cache_cfg"]
diff --git a/delphi/utils.py b/delphi/utils.py
index cb2b56aa..6796f028 100644
--- a/delphi/utils.py
+++ b/delphi/utils.py
@@ -1,6 +1,7 @@
 from typing import Any, Type, TypeVar, cast
 
-from torchtyping import TensorType
+from jaxtyping import Float
+from torch import Tensor
 from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast
 
 
@@ -28,7 +29,7 @@ def load_tokenized_data(
     )
     tokens_ds = tokens_ds.shuffle(seed)
 
-    tokens = cast(TensorType["batch", "seq"], tokens_ds["input_ids"])
+    tokens = cast(Float[Tensor, "batch seq"], tokens_ds["input_ids"])
 
     return tokens
diff --git a/pyproject.toml b/pyproject.toml
index ed9ae112..36d1b782 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -15,7 +15,7 @@ dependencies = [
     "sparsify@git+https://github.com/EleutherAI/sparsify",
     "safetensors",
     "simple_parsing",
-    "torchtyping",
+    "jaxtyping",
     "fire",
     "blobfile",
     "bitsandbytes",