From 000203f1d6179d390487e1c0488cd96247c0a84a Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 3 Oct 2024 13:49:11 +0000 Subject: [PATCH 001/132] Cleaning up autoencoder loader --- .../model.py => Custom/gemmascope.py} | 9 +- .../{OpenAI/model.py => Custom/openai.py} | 0 .../autoencoders/DeepMind/__init__.py | 53 ------ .../autoencoders/Neurons/__init__.py | 2 +- .../autoencoders/OpenAI/__init__.py | 2 +- sae_auto_interp/autoencoders/wrapper.py | 154 +++++++++++++++++- 6 files changed, 153 insertions(+), 67 deletions(-) rename sae_auto_interp/autoencoders/{DeepMind/model.py => Custom/gemmascope.py} (90%) rename sae_auto_interp/autoencoders/{OpenAI/model.py => Custom/openai.py} (100%) delete mode 100644 sae_auto_interp/autoencoders/DeepMind/__init__.py diff --git a/sae_auto_interp/autoencoders/DeepMind/model.py b/sae_auto_interp/autoencoders/Custom/gemmascope.py similarity index 90% rename from sae_auto_interp/autoencoders/DeepMind/model.py rename to sae_auto_interp/autoencoders/Custom/gemmascope.py index 33ffac5c..a06b1aa1 100644 --- a/sae_auto_interp/autoencoders/DeepMind/model.py +++ b/sae_auto_interp/autoencoders/Custom/gemmascope.py @@ -3,6 +3,7 @@ import torch.nn as nn from huggingface_hub import hf_hub_download + # This is from the GemmaScope tutorial class JumpReLUSAE(nn.Module): def __init__(self, d_model, d_sae): @@ -30,10 +31,10 @@ def forward(self, acts): return recon @classmethod - def from_pretrained(cls, path,type,device): + def from_pretrained(cls, model_name_or_path,position,device): path_to_params = hf_hub_download( - repo_id="google/gemma-scope-9b-pt-"+type, - filename=f"{path}/params.npz", + repo_id=model_name_or_path, + filename=f"{position}/params.npz", force_download=False, ) params = np.load(path_to_params) @@ -43,5 +44,3 @@ def from_pretrained(cls, path,type,device): if device == "cuda": model.cuda() return model - - diff --git a/sae_auto_interp/autoencoders/OpenAI/model.py b/sae_auto_interp/autoencoders/Custom/openai.py similarity index 100% rename from sae_auto_interp/autoencoders/OpenAI/model.py rename to sae_auto_interp/autoencoders/Custom/openai.py diff --git a/sae_auto_interp/autoencoders/DeepMind/__init__.py b/sae_auto_interp/autoencoders/DeepMind/__init__.py deleted file mode 100644 index ea342634..00000000 --- a/sae_auto_interp/autoencoders/DeepMind/__init__.py +++ /dev/null @@ -1,53 +0,0 @@ -from functools import partial -from .model import JumpReLUSAE -from typing import List, Dict -import torch -from ..wrapper import AutoencoderLatents -DEVICE = "cuda:0" - - - - -def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str,randomize:bool=False): - submodules = {} - - for layer in ae_layers: - if randomize: - d_model = model.config.hidden_size - d_sae = 131072 - sae = JumpReLUSAE(d_model,d_sae) - #Randomize the weights - sae.W_enc.data.uniform_(-1,1) - sae.W_dec.data.uniform_(-1,1) - # This does not work - sae.threshold.data.uniform_(-1,1) - sae.b_enc.data.uniform_(-1,1) - sae.b_dec.data.uniform_(-1,1) - else: - path = f"layer_{layer}/width_{size}/average_l0_{average_l0s[layer]}" - sae = JumpReLUSAE.from_pretrained(path,type,"cuda") - - sae.half() - def _forward(sae, x): - encoded = sae.encode(x) - return encoded - if type == "res": - submodule = model.model.layers[layer] - elif type == "mlp": - submodule = model.model.layers[layer].post_feedforward_layernorm - submodule.ae = AutoencoderLatents( - sae, partial(_forward, sae), width=sae.W_enc.shape[1] - ) - - submodules[submodule.path] = submodule - - with 
model.edit(" ") as edited: - for _, submodule in submodules.items(): - if type == "res": - acts = submodule.output[0] - else: - acts = submodule.output - submodule.ae(acts, hook=True) - - return submodules, edited - diff --git a/sae_auto_interp/autoencoders/Neurons/__init__.py b/sae_auto_interp/autoencoders/Neurons/__init__.py index 5e5a9f7d..bd4e06ff 100644 --- a/sae_auto_interp/autoencoders/Neurons/__init__.py +++ b/sae_auto_interp/autoencoders/Neurons/__init__.py @@ -1,7 +1,7 @@ from typing import List import torch from functools import partial -from ..OpenAI.model import ACTIVATIONS_CLASSES, TopK +from ..Custom.openai import ACTIVATIONS_CLASSES, TopK DEVICE = "cuda:0" diff --git a/sae_auto_interp/autoencoders/OpenAI/__init__.py b/sae_auto_interp/autoencoders/OpenAI/__init__.py index a3d06bd8..207053ff 100644 --- a/sae_auto_interp/autoencoders/OpenAI/__init__.py +++ b/sae_auto_interp/autoencoders/OpenAI/__init__.py @@ -4,7 +4,7 @@ import torch from ..wrapper import AutoencoderLatents -from .model import Autoencoder +from ..Custom.openai import Autoencoder DEVICE = "cuda:0" diff --git a/sae_auto_interp/autoencoders/wrapper.py b/sae_auto_interp/autoencoders/wrapper.py index 761939bc..159f6e59 100644 --- a/sae_auto_interp/autoencoders/wrapper.py +++ b/sae_auto_interp/autoencoders/wrapper.py @@ -1,23 +1,163 @@ -from typing import Callable - +from typing import Callable, Optional, Union, Any, Literal, Dict import torch +from simple_parsing import Serializable +from functools import partial +class AutoencoderConfig(Serializable): + model_name_or_path: str + autoencoder_type: Literal["SAE", "SAE_LENS", "NEURONS", "CUSTOM"] = "SAE" + device: Optional[str] = None + hookpoints: Optional[List[str]] = None + kwargs: Dict[str, Any] = {} class AutoencoderLatents(torch.nn.Module): """ - Wrapper module to simplify capturing of autoencoder latents. + Unified wrapper for different types of autoencoders, compatible with nnsight. 
""" def __init__( self, - ae: torch.nn.Module, - _forward: Callable, + autoencoder: Any, + forward_function: Callable, width: int, ) -> None: super().__init__() - self.ae = ae - self._forward = _forward + self.ae = autoencoder + self._forward = forward_function self.width = width def forward(self, x: torch.Tensor) -> torch.Tensor: return self._forward(x) + + @classmethod + def from_pretrained( + cls, + config: AutoencoderConfig, + hookpoint: str, + **kwargs + ): + device = config.device or ('cuda' if torch.cuda.is_available() else 'cpu') + autoencoder_type = config.autoencoder_type + model_name_or_path = config.model_name_or_path + + if autoencoder_type == "SAE": + from sae import Sae + local = kwargs.get("local",None) + assert local is not None, "local must be specified for SAE" + if local: + sae = Sae.load_from_disk(model_name_or_path+"/"+hookpoint, device=device, **kwargs) + else: + sae = Sae.load_from_hub(model_name_or_path,hookpoint, device=device, **kwargs) + forward_function = choose_forward_function(config, sae) + width = sae.encoder.weight.shape[0] + + elif autoencoder_type == "SAE_LENS": + from sae_lens import SAE + sae, cfg_dict, sparsity = SAE.from_pretrained( + release = model_name_or_path, # see other options in sae_lens/pretrained_saes.yaml + sae_id = hookpoint, + device = device + ) + forward_function = choose_forward_function(config, sae) + width = sae.d_sae + + elif autoencoder_type == "NEURONS": + raise NotImplementedError("Neurons autoencoder not implemented yet") + + elif autoencoder_type == "CUSTOM": + # to use a custom autoencoder, you must make own custom autoencoder class and implement the forward function + # it should have a specific name that you use here + custom_name = config.kwargs.get("custom_name", None) + if custom_name is None: + raise ValueError("custom_name must be specified for CUSTOM autoencoder") + if custom_name == "gemmascope": + from Custom.gemmascope import JumpReLUSAE + position = config.kwargs.get("position", None) + assert position is not None, "position must be specified for gemmascope autoencoder" + sae = JumpReLUSAE.from_pretrained(model_name_or_path,position,device) + forward_function = choose_forward_function(config, sae) + if custom_name == "openai": + raise NotImplementedError("OpenAI autoencoder not implemented yet") + from Custom.openai import Autoencoder + path = f"{model_name_or_path}/{hookpoint}.pt" + state_dict = torch.load(path) + ae = Autoencoder.from_state_dict(state_dict=state_dict) + + else: + raise ValueError(f"Unsupported autoencoder type: {autoencoder_type}") + + return cls(sae, forward_function, width) + @classmethod + def random(cls, config: AutoencoderConfig, hookpoint: str, **kwargs): + pass + +def choose_forward_function(autoencoder_config: AutoencoderConfig, autoencoder: Any): + if autoencoder_config.autoencoder_type == "SAE": + from .Custom.openai import ACTIVATIONS_CLASSES, TopK + + def _forward(sae, k,x): + encoded = sae.pre_acts(x) + if k is not None: + trained_k = k + else: + trained_k = sae.cfg.k + topk = TopK(trained_k, postact_fn=ACTIVATIONS_CLASSES["Identity"]()) + return topk(encoded) + k = autoencoder_config.kwargs.get("k", None) + return partial(_forward, autoencoder, k) + + elif autoencoder_config.autoencoder_type == "SAE_LENS": + return autoencoder.encode + + elif autoencoder_config.autoencoder_type == "CUSTOM": + if autoencoder_config.kwargs.get("custom_name", None) == "gemmascope": + return autoencoder.encode + else: + raise ValueError(f"Unsupported custom autoencoder: 
{autoencoder_config.kwargs.get('custom_name', None)}") + +def hook_submodule(autoencoder_config: AutoencoderConfig, submodule: Any, model: Any): + + with model.edit("") as edited: + if "embed" not in submodule.path and "mlp" not in submodule.path: + acts = submodule.output[0] + else: + acts = submodule.output + submodule.ae(acts, hook=True) + return submodule,edited + + +def load_autoencoder_into_model( + model: Any, + autoencoder_config: AutoencoderConfig, + **kwargs +) -> Tuple[List[Any], Any]: + """ + Load an autoencoder and hook it into the model using nnsight. + + Args: + model (Any): The main model to hook the autoencoder into. + hookpoints (list[str]): List of paths to the submodules to hook the autoencoder into. + autoencoder_config (AutoencoderConfig): Configuration for the autoencoder. + + Returns: + Tuple[List[Any], Any]: The list of submodules with the autoencoder attached and the edited model. + """ + + submodules = {} + edited_model = model + hookpoints = autoencoder_config.hookpoints + assert hookpoints is not None, "Hookpoints must be specified in autoencoder_config" + for module_path in hookpoints: + autoencoder = AutoencoderLatents.from_pretrained( + autoencoder_config, + module_path, + ) + submodule = model.get_submodule(module_path) + autoencoder.ae = autoencoder + submodule,edited_model = hook_submodule(autoencoder_config, submodule, edited_model) + + submodules[submodule.path] = submodule + + return submodules, edited_model + + From c42a32284264a3243b9996a28f7604eee496fe1f Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 11 Oct 2024 08:41:44 +0000 Subject: [PATCH 002/132] Preparing generation scorer --- sae_auto_interp/scorers/__init__.py | 3 +-- sae_auto_interp/scorers/generation/generation.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/sae_auto_interp/scorers/__init__.py b/sae_auto_interp/scorers/__init__.py index 71f20094..01c0e4d3 100644 --- a/sae_auto_interp/scorers/__init__.py +++ b/sae_auto_interp/scorers/__init__.py @@ -2,14 +2,13 @@ from .classifier.neighbor import NeighborScorer from .classifier.recall import RecallScorer from .classifier.utils import get_neighbors, load_neighbors -from .generation.generation import GenerationScorer +#from .generation.generation import GenerationScorer from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer from .embedding.embedding import EmbedingScorer __all__ = [ "FuzzingScorer", - "GenerationScorer", "NeighborScorer", "OpenAISimulator", "RecallScorer", diff --git a/sae_auto_interp/scorers/generation/generation.py b/sae_auto_interp/scorers/generation/generation.py index 5988fa64..51499524 100644 --- a/sae_auto_interp/scorers/generation/generation.py +++ b/sae_auto_interp/scorers/generation/generation.py @@ -1,4 +1,4 @@ -from ...clients import Client, create_response_model +from ...clients import Client from ..scorer import Scorer, ScorerResult from .prompts import get_gen_scorer_template import re From 96b8a74921ae0fcf6e0f69c37eee0a617a40a1b9 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 11 Oct 2024 08:42:08 +0000 Subject: [PATCH 003/132] Re-doing autoencoders --- sae_auto_interp/autoencoders/__init__.py | 7 ++- sae_auto_interp/autoencoders/eleuther.py | 2 +- sae_auto_interp/autoencoders/wrapper.py | 58 +++++++++++++++++------- 3 files changed, 46 insertions(+), 21 deletions(-) diff --git a/sae_auto_interp/autoencoders/__init__.py b/sae_auto_interp/autoencoders/__init__.py index b416a37d..19a0eb9d 100644 --- 
a/sae_auto_interp/autoencoders/__init__.py +++ b/sae_auto_interp/autoencoders/__init__.py @@ -1,13 +1,12 @@ -from .eleuther import load_eai_autoencoders from .Neurons import load_llama3_neurons from .OpenAI import load_oai_autoencoders from .Sam import load_sam_autoencoders -from .DeepMind import load_gemma_autoencoders +from .wrapper import load_autoencoder_into_model, AutoencoderConfig __all__ = [ - "load_eai_autoencoders", - "load_gemma_autoencoders", + "load_autoencoder_into_model", "load_llama3_neurons", "load_oai_autoencoders", "load_sam_autoencoders", + "AutoencoderConfig", ] diff --git a/sae_auto_interp/autoencoders/eleuther.py b/sae_auto_interp/autoencoders/eleuther.py index d38846b7..634b550c 100644 --- a/sae_auto_interp/autoencoders/eleuther.py +++ b/sae_auto_interp/autoencoders/eleuther.py @@ -4,7 +4,7 @@ from sae import Sae -from .OpenAI.model import ACTIVATIONS_CLASSES, TopK +from .Custom.openai import ACTIVATIONS_CLASSES, TopK from .wrapper import AutoencoderLatents DEVICE = "cuda:0" diff --git a/sae_auto_interp/autoencoders/wrapper.py b/sae_auto_interp/autoencoders/wrapper.py index 159f6e59..263384a6 100644 --- a/sae_auto_interp/autoencoders/wrapper.py +++ b/sae_auto_interp/autoencoders/wrapper.py @@ -1,14 +1,18 @@ -from typing import Callable, Optional, Union, Any, Literal, Dict +from dataclasses import dataclass, field +from functools import partial +from typing import Any, Callable, Dict, List, Literal, Optional, Tuple + import torch from simple_parsing import Serializable -from functools import partial + +@dataclass class AutoencoderConfig(Serializable): - model_name_or_path: str + model_name_or_path: str = "model" autoencoder_type: Literal["SAE", "SAE_LENS", "NEURONS", "CUSTOM"] = "SAE" device: Optional[str] = None hookpoints: Optional[List[str]] = None - kwargs: Dict[str, Any] = {} + kwargs: Dict[str, Any] = field(default_factory=dict) class AutoencoderLatents(torch.nn.Module): """ @@ -71,11 +75,10 @@ def from_pretrained( if custom_name is None: raise ValueError("custom_name must be specified for CUSTOM autoencoder") if custom_name == "gemmascope": - from Custom.gemmascope import JumpReLUSAE - position = config.kwargs.get("position", None) - assert position is not None, "position must be specified for gemmascope autoencoder" - sae = JumpReLUSAE.from_pretrained(model_name_or_path,position,device) + from .Custom.gemmascope import JumpReLUSAE + sae = JumpReLUSAE.from_pretrained(model_name_or_path,hookpoint,device) forward_function = choose_forward_function(config, sae) + width = sae.W_enc.data.shape[1] if custom_name == "openai": raise NotImplementedError("OpenAI autoencoder not implemented yet") from Custom.openai import Autoencoder @@ -115,10 +118,34 @@ def _forward(sae, k,x): else: raise ValueError(f"Unsupported custom autoencoder: {autoencoder_config.kwargs.get('custom_name', None)}") -def hook_submodule(autoencoder_config: AutoencoderConfig, submodule: Any, model: Any): - +def get_submodule(model:Any,autoencoder_config:AutoencoderConfig,hookpoint:str)->Any: + + if autoencoder_config.autoencoder_type == "SAE": + if "res" in hookpoint: + submodule = model.model.get_submodule(hookpoint) + elif "mlp" in hookpoint: + layer = int(hookpoint.split(".")[-1]) + submodule = model.model.layers[layer].mlp + else: + raise ValueError(f"Unsupported hookpoint: {hookpoint}") + return submodule + elif autoencoder_config.autoencoder_type == "SAE_LENS": + raise NotImplementedError("SAE_LENS not implemented yet") + #return model.get_submodule(hookpoint) + elif 
autoencoder_config.autoencoder_type == "CUSTOM": + if autoencoder_config.kwargs.get("custom_name", None) == "gemmascope": + layer = int(hookpoint.split("/")[0].split("_")[-1]) + model_name = autoencoder_config.model_name_or_path + if "res" in model_name: + submodule = model.model.layers[layer] + if "mlp" in model_name: + submodule = model.model.layers[layer].post_feedforward_layernorm + return submodule + +def hook_submodule( submodule: Any, model: Any,module_path:str,autoencoder_config:AutoencoderConfig)->Tuple[Any,Any]: + #TODO: This should take into account the autoencoder config, but for now I think this is valid for all with model.edit("") as edited: - if "embed" not in submodule.path and "mlp" not in submodule.path: + if "embed" not in module_path and "mlp" not in module_path: acts = submodule.output[0] else: acts = submodule.output @@ -130,13 +157,12 @@ def load_autoencoder_into_model( model: Any, autoencoder_config: AutoencoderConfig, **kwargs -) -> Tuple[List[Any], Any]: +) -> Tuple[Dict[str,Any], Any]: """ Load an autoencoder and hook it into the model using nnsight. Args: model (Any): The main model to hook the autoencoder into. - hookpoints (list[str]): List of paths to the submodules to hook the autoencoder into. autoencoder_config (AutoencoderConfig): Configuration for the autoencoder. Returns: @@ -152,9 +178,9 @@ def load_autoencoder_into_model( autoencoder_config, module_path, ) - submodule = model.get_submodule(module_path) - autoencoder.ae = autoencoder - submodule,edited_model = hook_submodule(autoencoder_config, submodule, edited_model) + submodule = get_submodule(edited_model,autoencoder_config,module_path) + submodule.ae = autoencoder + submodule,edited_model = hook_submodule(submodule, edited_model,module_path,autoencoder_config) submodules[submodule.path] = submodule From 8a47fbb05480936775957184daf47d0ede36a191 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 11 Oct 2024 08:42:30 +0000 Subject: [PATCH 004/132] Probability and conditional probability --- .../scorers/classifier/classifier.py | 29 +++++++++++-------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/sae_auto_interp/scorers/classifier/classifier.py b/sae_auto_interp/scorers/classifier/classifier.py index 7740fe9e..4b6398ea 100644 --- a/sae_auto_interp/scorers/classifier/classifier.py +++ b/sae_auto_interp/scorers/classifier/classifier.py @@ -39,7 +39,6 @@ async def __call__( record: FeatureRecord, ) -> list[ClassifierOutput]: samples = self._prepare(record) - random.shuffle(samples) samples = self._batch(samples) results = await self._query( @@ -84,19 +83,22 @@ async def _generate( prompt = self._build_prompt(explanation, batch) if self.log_prob: self.generation_kwargs["logprobs"] = True - self.generation_kwargs["top_logprobs"] = 5 + self.generation_kwargs["top_logprobs"] = 10 response = await self.client.generate(prompt, **self.generation_kwargs) if response is None: array = [-1] * self.batch_size + conditional_probabilities = [-1] * self.batch_size probabilities = [-1] * self.batch_size + else: selections = response.text logprobs = response.logprobs if self.log_prob else None try: - array, probabilities = self._parse(selections, logprobs) + array,conditional_probabilities, probabilities = self._parse(selections, logprobs) except Exception as e: logger.error(f"Parsing selections failed: {e}") array = [-1] * self.batch_size + conditional_probabilities = [-1] * self.batch_size probabilities = [-1] * self.batch_size results = [] @@ -111,6 +113,7 @@ async def _generate( 
response.append(prediction) if self.log_prob: result.probability = probabilities[i] + result.conditional_probability = conditional_probabilities[i] results.append(result) if self.verbose: @@ -125,21 +128,21 @@ def _parse(self, string, logprobs=None): array = json.loads(match.group(0)) assert len(array) == self.batch_size if self.log_prob: - probabilities = self._parse_logprobs(logprobs) - assert len(probabilities) == self.batch_size - return array, probabilities + conditional_probabilities,probabilities = self._parse_logprobs(logprobs) + assert len(conditional_probabilities) == self.batch_size + return array, conditional_probabilities, probabilities + conditional_probabilities = None probabilities = None - return array, probabilities + return array, conditional_probabilities, probabilities except (json.JSONDecodeError, AssertionError, AttributeError) as e: logger.error(f"Parsing array failed: {e}") - if self.log_prob: - return [-1] * self.batch_size, [-1] * self.batch_size - return [-1] * self.batch_size + return [-1] * self.batch_size, [-1] * self.batch_size, [-1] * self.batch_size def _parse_logprobs(self, logprobs): #Logprobs will be a list of 5 probabilites for each token in the response # The response will be in the form of [x, x, x, ...] for each position we # need to get the probability of 1 or 0 + conditional_probabilities = [] probabilities = [] for i in range(len(logprobs)): @@ -155,10 +158,12 @@ def _parse_logprobs(self, logprobs): elif "1" in token: prob_1 += np.exp(logprob).item() if prob_0+prob_1>0: - probabilities.append(prob_1/(prob_0+prob_1)) + conditional_probabilities.append(prob_1/(prob_0+prob_1)) + probabilities.append(prob_1) else: + conditional_probabilities.append(0) probabilities.append(0) - return probabilities + return conditional_probabilities,probabilities From da30b38cf375d2ed74c3e15153f514e222895eff Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 11 Oct 2024 08:42:51 +0000 Subject: [PATCH 005/132] First pass on clients --- sae_auto_interp/clients/__init__.py | 8 +- sae_auto_interp/clients/client.py | 18 ++-- sae_auto_interp/clients/local.py | 33 ++---- sae_auto_interp/clients/offline.py | 138 +++++++++++--------------- sae_auto_interp/clients/openrouter.py | 46 ++------- 5 files changed, 88 insertions(+), 155 deletions(-) diff --git a/sae_auto_interp/clients/__init__.py b/sae_auto_interp/clients/__init__.py index df64fd67..f0b089d2 100644 --- a/sae_auto_interp/clients/__init__.py +++ b/sae_auto_interp/clients/__init__.py @@ -1,7 +1,7 @@ from .client import Client -from .local import Local +#from .local import Local from .offline import Offline -from .openrouter import OpenRouter -from .outlines import Outlines +#from .openrouter import OpenRouter +#from .outlines import Outlines -__all__ = ["Client", "Local", "OpenRouter", "Outlines", "HuggingFace", "Offline"] +__all__ = ["Client", "Offline"] diff --git a/sae_auto_interp/clients/client.py b/sae_auto_interp/clients/client.py index 7866afe0..c6fc409f 100644 --- a/sae_auto_interp/clients/client.py +++ b/sae_auto_interp/clients/client.py @@ -1,13 +1,13 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import List +from typing import List, Union, Dict, Any @dataclass class Response: text: str - logprobs: List[float] - prompt_logprobs: List[float] + logprobs: List[float] = None + prompt_logprobs: List[float] = None class Client(ABC): @@ -15,10 +15,10 @@ def __init__(self, model: str): self.model = model @abstractmethod - async def generate( - self, - prompt: str, - **kwargs - 
): - raise NotImplementedError + async def generate(self, prompt: Union[str, List[Dict[str, str]]], **kwargs) -> Response: + pass + + @abstractmethod + async def process_response(self, raw_response: Any) -> Response: + pass diff --git a/sae_auto_interp/clients/local.py b/sae_auto_interp/clients/local.py index 71ab484c..ec993c26 100644 --- a/sae_auto_interp/clients/local.py +++ b/sae_auto_interp/clients/local.py @@ -13,19 +13,9 @@ class Local(Client): def __init__(self, model: str, base_url="http://localhost:8000/v1"): super().__init__(model) self.client = AsyncOpenAI(base_url=base_url, api_key="EMPTY", timeout=None) - self.model = model self.tokenizer = AutoTokenizer.from_pretrained(model) - async def generate( - self, - prompt: str, - use_legacy_api: bool = False, - max_retries: int = 2, - **kwargs - ) -> Response: - """ - Wrapper method for vLLM post requests. - """ + async def generate(self, prompt: Union[str, List[Dict[str, str]]], use_legacy_api: bool = False, max_retries: int = 2, **kwargs) -> Response: try: for attempt in range(max_retries): try: @@ -43,29 +33,22 @@ async def generate( ) if response is None: raise Exception("Response is None") - return self.postprocess(response) - + return await self.process_response(response) except json.JSONDecodeError as e: logger.warning(f"Attempt {attempt + 1}: Invalid JSON response, retrying... {e}") - except Exception as e: logger.warning(f"Attempt {attempt + 1}: {str(e)}, retrying...") - await asyncio.sleep(1) except Exception as e: logger.error(f"All retry attempts failed. Most recent error: {e}") raise - - def postprocess(self, response: dict) -> Response: - """ - Postprocess the response from the API. - """ - new_response=Response( - text=response.choices[0].message.content, - logprobs=response.choices[0].logprobs, - prompt_logprobs=response.choices[0].prompt_logprobs + + async def process_response(self, raw_response: Any) -> Response: + return Response( + text=raw_response.choices[0].message.content, + logprobs=raw_response.choices[0].logprobs, + prompt_logprobs=raw_response.choices[0].prompt_logprobs ) - return new_response diff --git a/sae_auto_interp/clients/offline.py b/sae_auto_interp/clients/offline.py index 8f92440f..fca664f9 100644 --- a/sae_auto_interp/clients/offline.py +++ b/sae_auto_interp/clients/offline.py @@ -10,8 +10,6 @@ from .client import Client, Response from vllm.distributed.parallel_state import destroy_model_parallel, destroy_distributed_environment - - @dataclass class Top_Logprob: token: str @@ -25,89 +23,38 @@ class Logprobs: class Offline(Client): provider = "offline" - def __init__(self, model: str, max_memory: float=0.85,prefix_caching:bool=True,batch_size:int=100,max_model_len:int=4096,num_gpus:int=2,enforce_eager:bool=False): + def __init__(self, model: str, max_memory: float=0.85, prefix_caching: bool=True, batch_size: int=100, max_model_len: int=4096, num_gpus: int=2, enforce_eager: bool=False, lora_path: str=None): super().__init__(model) - self.model = model self.queue = asyncio.Queue() self.task = None - self.client = LLM(model=model, gpu_memory_utilization=max_memory, enable_prefix_caching=prefix_caching, tensor_parallel_size=num_gpus, max_model_len=max_model_len,enforce_eager=enforce_eager) - self.sampling_params = SamplingParams(max_tokens=500, temperature=0.7) - self.tokenizer= AutoTokenizer.from_pretrained(model) - self.batch_size=batch_size - - - async def process_func(self, batches: Union[str, List[Dict[str, str]]], kwargs): - """ - Process a single request. 
- """ - - # This is actually stupid - for kwarg in kwargs: - if "logprobs" in kwarg: - self.sampling_params.logprobs = kwarg["top_logprobs"] - if "prompt_logprobs" in kwarg: - self.sampling_params.prompt_logprobs = kwarg["prompt_logprobs"] - loop = asyncio.get_running_loop() - prompts=[] - for batch in batches: - prompt = self.tokenizer.apply_chat_template(batch, add_generation_prompt=True, tokenize=True) - prompts.append(prompt) - response = await loop.run_in_executor( - None, - partial(self.client.generate, prompt_token_ids=prompts, sampling_params=self.sampling_params, use_tqdm=False) - ) + self.client = LLM(model=model, gpu_memory_utilization=max_memory, enable_prefix_caching=prefix_caching, tensor_parallel_size=num_gpus, max_model_len=max_model_len, enforce_eager=enforce_eager,enable_lora=True) + self.sampling_params = SamplingParams(max_tokens=500, temperature=0.01) + self.tokenizer = AutoTokenizer.from_pretrained(model) + self.batch_size = batch_size + if lora_path is not None: + from vllm.lora.request import LoRARequest + request = LoRARequest("lora_adapter",1,lora_path) + self.lora_request = request + else: + self.lora_request = None - new_response = [] - for r in response: - logprobs,prompt_logprobs=self._parse_logprobs(r) - new_response.append(Response(text=r.outputs[0].text, logprobs=logprobs, prompt_logprobs=prompt_logprobs)) - return new_response - async def generate(self, prompt: Union[str, List[Dict[str, str]]], **kwargs) -> str: - """ - Enqueue a request and wait for the result. - """ + async def generate(self, prompt: Union[str, List[Dict[str, str]]], **kwargs) -> Response: future = asyncio.Future() if self.task is None: self.task = asyncio.create_task(self._process_batches()) await self.queue.put((prompt, future, kwargs)) - #print(f"Current queue size: {self.queue.qsize()} prompts") return await future - def _parse_logprobs(self,response): - logprobs=response.outputs[0].logprobs - prompt_logprobs=response.prompt_logprobs - if logprobs is None and prompt_logprobs is None: - return None,None - logprobs_list=None - if logprobs is not None: - logprobs_list=[] - for i in range(len(logprobs)): - log_prob_dict = logprobs[i] - top_logprobs = [] - decoded_token = "" - for token, logprob in log_prob_dict.items(): - if logprob.rank==1: - decoded_token = logprob.decoded_token - top_logprobs.append(Top_Logprob(token=decoded_token, logprob=logprob.logprob)) - else: - top_logprobs.append(Top_Logprob(token=logprob.decoded_token, logprob=logprob.logprob)) - logprobs_list.append(Logprobs(token=decoded_token, top_logprobs=top_logprobs)) - - return logprobs_list,prompt_logprobs - - + async def process_response(self, raw_response: Any) -> Response: + logprobs, prompt_logprobs = self._parse_logprobs(raw_response) + return Response(text=raw_response.outputs[0].text, logprobs=logprobs, prompt_logprobs=prompt_logprobs) async def _process_batches(self): - """ - Continuously process batches of requests. 
- """ - batch_count = 0 while True: batch = [] batch_futures = [] batch_kwargs = [] - # Collect a batch of requests start_time = asyncio.get_event_loop().time() while len(batch) < self.batch_size: try: @@ -116,21 +63,19 @@ async def _process_batches(self): batch_futures.append(future) batch_kwargs.append(kwargs) except asyncio.QueueEmpty: - if batch: # If we have any items, process them + if batch: break - await asyncio.sleep(0.1) # Short sleep if queue is empty + await asyncio.sleep(0.1) continue - if asyncio.get_event_loop().time() - start_time > 1: # Time-based batch cutoff + if asyncio.get_event_loop().time() - start_time > 1: break if not batch: continue - # Process the batch + try: - results = await self.process_func(batch, batch_kwargs) - batch_count += 1 - #print(f"Batch {batch_count} finished processing. {len(results)} prompts processed.") + results = await self._process_func(batch, batch_kwargs) for result, future in zip(results, batch_futures): if not future.done(): future.set_result(result) @@ -140,11 +85,44 @@ async def _process_batches(self): if not future.done(): future.set_exception(e) + async def _process_func(self, batches: Union[str, List[Dict[str, str]]], kwargs): + for kwarg in kwargs: + if "logprobs" in kwarg: + self.sampling_params.logprobs = kwarg["top_logprobs"] + if "prompt_logprobs" in kwarg: + self.sampling_params.prompt_logprobs = kwarg["prompt_logprobs"] + + prompts = [self.tokenizer.apply_chat_template(batch, add_generation_prompt=True, tokenize=True) for batch in batches] + + loop = asyncio.get_running_loop() + response = await loop.run_in_executor( + None, + partial(self.client.generate, prompt_token_ids=prompts, sampling_params=self.sampling_params, use_tqdm=False, lora_request=self.lora_request) + ) + + return [await self.process_response(r) for r in response] + + def _parse_logprobs(self, response): + logprobs = response.outputs[0].logprobs + prompt_logprobs = response.prompt_logprobs + if logprobs is None and prompt_logprobs is None: + return None, None + + logprobs_list = None + if logprobs is not None: + logprobs_list = [] + for log_prob_dict in logprobs: + top_logprobs = [] + decoded_token = "" + for token, logprob in log_prob_dict.items(): + if logprob.rank == 1: + decoded_token = logprob.decoded_token + top_logprobs.append(Top_Logprob(token=logprob.decoded_token, logprob=logprob.logprob)) + logprobs_list.append(Logprobs(token=decoded_token, top_logprobs=top_logprobs)) + + return logprobs_list, prompt_logprobs async def close(self): - """ - Clean up resources when the client is no longer needed. - """ destroy_model_parallel() destroy_distributed_environment() del self.client @@ -155,6 +133,4 @@ async def close(self): await self.task except asyncio.CancelledError: pass - - - \ No newline at end of file + diff --git a/sae_auto_interp/clients/openrouter.py b/sae_auto_interp/clients/openrouter.py index 6b5c5e6f..649f5a44 100644 --- a/sae_auto_interp/clients/openrouter.py +++ b/sae_auto_interp/clients/openrouter.py @@ -1,19 +1,9 @@ import json from asyncio import sleep - import httpx - from ..logger import logger -from .client import Client - -# Preferred provider routing arguments. -# Change depending on what model you'd like to use. 
-PROVIDER = {"order": ["Together", "DeepInfra"]} - -class Response: - def __init__(self, response): - self.text = response - +from .client import Client, Response +from typing import Union, List, Dict, Any class OpenRouter(Client): def __init__( self, @@ -22,26 +12,15 @@ def __init__( base_url="https://openrouter.ai/api/v1/chat/completions", ): super().__init__(model) - self.headers = {"Authorization": f"Bearer {api_key}"} - self.url = base_url self.client = httpx.AsyncClient() - def postprocess(self, response): - response_json = response.json() - msg = response_json["choices"][0]["message"]["content"] - return Response(msg) - - async def generate( - self, prompt: str, raw: bool = False, max_retries: int = 2, **kwargs - ) -> str: + async def generate(self, prompt: Union[str, List[Dict[str, str]]], max_retries: int = 2, **kwargs) -> Response: kwargs.pop("schema", None) - data = { "model": self.model, "messages": prompt, - # "provider": PROVIDER, **kwargs, } @@ -50,22 +29,17 @@ async def generate( response = await self.client.post( url=self.url, json=data, headers=self.headers ) - if raw: - return response.json() - - result = self.postprocess(response) - - return result - + return await self.process_response(response) except json.JSONDecodeError: - logger.warning( - f"Attempt {attempt + 1}: Invalid JSON response, retrying..." - ) - + logger.warning(f"Attempt {attempt + 1}: Invalid JSON response, retrying...") except Exception as e: logger.warning(f"Attempt {attempt + 1}: {str(e)}, retrying...") - await sleep(1) logger.error("All retry attempts failed.") raise RuntimeError("Failed to generate text after multiple attempts.") + + async def process_response(self, raw_response: Any) -> Response: + response_json = raw_response.json() + text = response_json["choices"][0]["message"]["content"] + return Response(text=text) From f59aadd755c44087dfe670249ea4d7a1fa66e982 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 11 Oct 2024 08:43:19 +0000 Subject: [PATCH 006/132] Redoing load_tokenized_data --- sae_auto_interp/utils.py | 45 +++++++++++++++++++++++++++++++--------- 1 file changed, 35 insertions(+), 10 deletions(-) diff --git a/sae_auto_interp/utils.py b/sae_auto_interp/utils.py index c42c15db..2dc857b0 100644 --- a/sae_auto_interp/utils.py +++ b/sae_auto_interp/utils.py @@ -1,5 +1,5 @@ from transformers import AutoTokenizer - +import torch def load_tokenized_data( ctx_len: int, @@ -7,26 +7,45 @@ def load_tokenized_data( dataset_repo: str, dataset_split: str, dataset_name: str = "", + column_name: str = "text", seed: int = 22, -): +) -> torch.Tensor: """ - Load a huggingface dataset, tokenize it, and shuffle. + Load a Hugging Face dataset, tokenize it, and shuffle. + + Args: + ctx_len (int): Context length for tokenization. + tokenizer (AutoTokenizer): The tokenizer to use. + dataset_repo (str): The dataset repository name. + dataset_split (str): The dataset split to use. + dataset_name (str, optional): The dataset name. Defaults to "". + column_name (str, optional): The column name to use for tokenization. Defaults to "text". + seed (int, optional): Random seed for shuffling. Defaults to 22. + + Returns: + torch.Tensor: The tokenized and shuffled dataset. 
""" from datasets import load_dataset from transformer_lens import utils - print(dataset_repo,dataset_name,dataset_split) data = load_dataset(dataset_repo, name=dataset_name, split=dataset_split) - if "rpj" in dataset_repo: - tokens = utils.tokenize_and_concatenate(data, tokenizer, max_length=ctx_len,column_name="raw_content") - else: - tokens = utils.tokenize_and_concatenate(data, tokenizer, max_length=ctx_len,column_name="text") + tokens = utils.tokenize_and_concatenate(data, tokenizer, max_length=ctx_len,column_name=column_name) tokens = tokens.shuffle(seed)["tokens"] return tokens -def load_filter(path: str, device: str = "cuda:0"): +def load_filter(path: str, device: str = "cuda:0") -> dict: + """ + Load a filter from a JSON file and convert values to tensors. + + Args: + path (str): Path to the JSON file containing the filter. + device (str, optional): The device to load the tensors to. Defaults to "cuda:0". + + Returns: + dict: A dictionary with tensor values on the specified device. + """ import json import torch @@ -39,9 +58,15 @@ def load_filter(path: str, device: str = "cuda:0"): -def load_tokenizer(model): +def load_tokenizer(model: str) -> AutoTokenizer: """ Loads tokenizer to the default NNsight configuration. + + Args: + model (str): The model name or path to load the tokenizer from. + + Returns: + AutoTokenizer: The configured tokenizer. """ tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left") From 25839f197046cd1588eb9630c79c2814f6b133e1 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 16 Oct 2024 13:36:25 +0000 Subject: [PATCH 007/132] Small tweaks --- sae_auto_interp/clients/__init__.py | 4 ++-- sae_auto_interp/clients/offline.py | 8 ++++++-- sae_auto_interp/explainers/default/default.py | 7 +++---- sae_auto_interp/explainers/default/prompt_builder.py | 2 +- sae_auto_interp/utils.py | 2 +- 5 files changed, 13 insertions(+), 10 deletions(-) diff --git a/sae_auto_interp/clients/__init__.py b/sae_auto_interp/clients/__init__.py index f0b089d2..ca772d25 100644 --- a/sae_auto_interp/clients/__init__.py +++ b/sae_auto_interp/clients/__init__.py @@ -1,7 +1,7 @@ from .client import Client #from .local import Local from .offline import Offline -#from .openrouter import OpenRouter +from .openrouter import OpenRouter #from .outlines import Outlines -__all__ = ["Client", "Offline"] +__all__ = ["Client", "Offline", "OpenRouter"] diff --git a/sae_auto_interp/clients/offline.py b/sae_auto_interp/clients/offline.py index fca664f9..ef18c5f3 100644 --- a/sae_auto_interp/clients/offline.py +++ b/sae_auto_interp/clients/offline.py @@ -27,8 +27,12 @@ def __init__(self, model: str, max_memory: float=0.85, prefix_caching: bool=True super().__init__(model) self.queue = asyncio.Queue() self.task = None - self.client = LLM(model=model, gpu_memory_utilization=max_memory, enable_prefix_caching=prefix_caching, tensor_parallel_size=num_gpus, max_model_len=max_model_len, enforce_eager=enforce_eager,enable_lora=True) - self.sampling_params = SamplingParams(max_tokens=500, temperature=0.01) + if lora_path is not None: + enable_lora = True + else: + enable_lora = False + self.client = LLM(model=model, gpu_memory_utilization=max_memory, enable_prefix_caching=prefix_caching, tensor_parallel_size=num_gpus, max_model_len=max_model_len, enforce_eager=enforce_eager,enable_lora=enable_lora) + self.sampling_params = SamplingParams(max_tokens=500, temperature=0.7) self.tokenizer = AutoTokenizer.from_pretrained(model) self.batch_size = batch_size if lora_path is not None: diff --git 
a/sae_auto_interp/explainers/default/default.py b/sae_auto_interp/explainers/default/default.py index dcedf94d..fed56094 100644 --- a/sae_auto_interp/explainers/default/default.py +++ b/sae_auto_interp/explainers/default/default.py @@ -86,10 +86,9 @@ def check(i): def _join_activations(self, example): activations = [] - threshold = 0.6 - for i, normalized in enumerate(example.normalized_activations): - if example.normalized_activations[i] > threshold: - activations.append((example.str_toks[i], int(normalized))) + for i, activation in enumerate(example.activations): + if activation > example.max_activation * self.threshold: + activations.append((example.str_toks[i], int(example.normalized_activations[i]))) acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) diff --git a/sae_auto_interp/explainers/default/prompt_builder.py b/sae_auto_interp/explainers/default/prompt_builder.py index 7d1ac215..0a4926bb 100644 --- a/sae_auto_interp/explainers/default/prompt_builder.py +++ b/sae_auto_interp/explainers/default/prompt_builder.py @@ -17,7 +17,7 @@ def build_examples( "content": prompt, }, { - "role": "system", + "role": "assistant", "content": response, }, ] diff --git a/sae_auto_interp/utils.py b/sae_auto_interp/utils.py index 2dc857b0..c5dfd6eb 100644 --- a/sae_auto_interp/utils.py +++ b/sae_auto_interp/utils.py @@ -7,7 +7,7 @@ def load_tokenized_data( dataset_repo: str, dataset_split: str, dataset_name: str = "", - column_name: str = "text", + column_name: str = "raw_content", seed: int = 22, ) -> torch.Tensor: """ From e65287493c146046c8d54e34d232fb5d84a6e1b3 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 16 Oct 2024 13:37:14 +0000 Subject: [PATCH 008/132] update gitignore --- .gitignore | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/.gitignore b/.gitignore index 17f1d051..0b4c1eff 100644 --- a/.gitignore +++ b/.gitignore @@ -17,6 +17,11 @@ processed* extras* redoing* weights* +temp* +tests* +wandb* +simulation_data* +output* .nfs* processed_features/ From 264e3fb71a609f1a37b402aa94223d26a218a976 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 22 Oct 2024 16:00:53 +0000 Subject: [PATCH 009/132] Testing out api --- app.py | 81 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 81 insertions(+) create mode 100644 app.py diff --git a/app.py b/app.py new file mode 100644 index 00000000..db4e82b9 --- /dev/null +++ b/app.py @@ -0,0 +1,81 @@ +from flask import Flask, request, jsonify +import asyncio +import torch +from functools import partial + +from sae_auto_interp.clients import OpenRouter +from sae_auto_interp.config import ExperimentConfig, FeatureConfig +from sae_auto_interp.explainers import DefaultExplainer +from sae_auto_interp.features import FeatureDataset, FeatureLoader +from sae_auto_interp.features.constructors import default_constructor +from sae_auto_interp.features.samplers import sample +from sae_auto_interp.pipeline import Pipeline, process_wrapper +from sae_auto_interp.counterfactuals import ExplainerInterventionExample, ExplainerNeuronFormatter, get_explainer_prompt, fs_examples + +app = Flask(__name__) + +# Global variables +client = None +explainer_pipe = None + +def initialize_globals(): + # Make the folder for storing the explanations + os.makedirs("explanations", exist_ok=True) + + # Make the folder for storing the scores + +@app.before_first_request +def before_first_request(): + initialize_globals() + +@app.route('/generate_explanation', methods=['POST']) +def generate_explanation(): + """ + Generate an explanation 
for a given set of activations. This endpoint expects + a JSON object with the following fields: + - activations: A list of dictionaries, each containing a 'tokens' key with a list of token strings and a 'values' key with a list of activation values. + - api_key: The API key to use for the request. + - model: The model to use for the request. + + We could potentially allow for more options, eg we have a threshold that is set "in stone", we don't do COT and always show activations. + We don't currently support that, but we could allow for custom prompts as well. + """ + + data = request.json + + if not data or 'activations' not in data: + return jsonify({"error": "Missing required data"}), 400 + if 'api_key' not in data: + return jsonify({"error": "Missing API key"}), 400 + if 'model' not in data: + return jsonify({"error": "Missing model"}), 400 + try: + feature = Feature(f"feature", 0) + examples = [] + for activation in data['activations']: + example = Example(activation['tokens'], activation['values']) + examples.append(example) + feature_record = FeatureRecord(feature) + feature_record.train = [examples] + + client = OpenRouter(api_key=data['api_key'], model=data['model']) + + explainer = DefaultExplainer(client, tokenizer=None, threshold=0.6) + explanation = explainer(feature_record).explanation + + return jsonify({"explanation": explanation}), 200 + + except Exception as e: + return jsonify({"error": str(e)}), 500 + + +@app.route('/generate_score_fuzz', methods=['POST']) +def generate_score_fuzz(): + """ + Generate a score for a given set of activations. This endpoint expects + a JSON object with the following fields: + - activations: A list of dictionaries, each containing a 'tokens' key with a list of token strings and a 'values' key with a list of activation values. 
+ """ + +if __name__ == '__main__': + app.run(debug=True) From e2aa838594bb0fa72256c1c6fffe5247d67d8e91 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 23 Oct 2024 11:13:52 +0000 Subject: [PATCH 010/132] Making a explanation server api --- examples/example_server.py | 67 ++++++++ examples/server.py | 157 ++++++++++++++++++ sae_auto_interp/explainers/default/default.py | 12 +- .../scorers/classifier/classifier.py | 6 +- sae_auto_interp/scorers/classifier/sample.py | 37 +++-- 5 files changed, 264 insertions(+), 15 deletions(-) create mode 100644 examples/example_server.py create mode 100644 examples/server.py diff --git a/examples/example_server.py b/examples/example_server.py new file mode 100644 index 00000000..13109490 --- /dev/null +++ b/examples/example_server.py @@ -0,0 +1,67 @@ +import json +import requests +import random +from transformers import AutoTokenizer + +# Load the activation data from the JSON file +with open("/mnt/ssd-1/gpaulo/SAE-Zoology/extras/neuronpedia/formatted_contexts/activating_contexts_16k/mlp/0/layer_0_contexts_chunk_1.json", "r") as f: + activation_data = json.load(f) +# Load the explanation data +with open("/mnt/ssd-1/gpaulo/SAE-Zoology/extras/explanations_16k/model.layers.0.post_feedforward_layernorm_feature.json", "r") as f: + explanation_data = json.load(f) +tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") + +actual_data = activation_data["features"][0] +activations = actual_data["activations"] +for activation in activations: + activation["tokens"] = tokenizer.batch_decode(activation["tokens"]) # If you have the tokens already decoded, you can skip this step +feature_index = actual_data["feature_index"] +print(feature_index) +explanation = explanation_data[str(feature_index)] +# Server URL +BASE_URL = "http://localhost:5000" + +# API key and model (replace these with your actual values) +API_KEY = "your_api_key_here" +MODEL = "meta-llama/llama-3.1-70b-instruct:free" + +def test_generate_explanation(): + url = f"{BASE_URL}/generate_explanation" + + # Prepare the request data + data = { + "activations": activations[:10], # Using only 10 activations for testing + "api_key": API_KEY, + "model": MODEL + } + + # Send the request + response = requests.post(url, json=data) + + print("Generate Explanation Response:") + print(response.status_code) + print(response.json()) + +def test_generate_score(score_type): + url = f"{BASE_URL}/generate_score_fuzz_detection" + + + data = { + "activations": activations[10:], # Using only the other activations for testing + "explanation": explanation, + "api_key": API_KEY, + "model": MODEL, + "type": score_type + } + + # Send the request + response = requests.post(url, json=data) + + print(f"Generate Score ({score_type}) Response:") + print(response.status_code) + print(response.json()) + +if __name__ == "__main__": + #test_generate_explanation() + #test_generate_score("fuzz") + test_generate_score("detection") diff --git a/examples/server.py b/examples/server.py new file mode 100644 index 00000000..39d8474a --- /dev/null +++ b/examples/server.py @@ -0,0 +1,157 @@ +from flask import Flask, request, jsonify +import asyncio +import torch +import os +import pandas as pd +from functools import partial + +from sae_auto_interp.clients import OpenRouter +from sae_auto_interp.config import ExperimentConfig, FeatureConfig +from sae_auto_interp.explainers import DefaultExplainer +from sae_auto_interp.features import FeatureDataset, FeatureLoader, FeatureRecord, Feature, Example +from sae_auto_interp.features.constructors 
import default_constructor +from sae_auto_interp.features.samplers import sample +from sae_auto_interp.pipeline import Pipeline, process_wrapper +from sae_auto_interp.scorers import FuzzingScorer, DetectionScorer + +app = Flask(__name__) + +# Global variables +client = None +explainer_pipe = None + +def calculate_balanced_accuracy(dataframe): + tp = len(dataframe[(dataframe["ground_truth"]==True) & (dataframe["correct"]==True)]) + tn = len(dataframe[(dataframe["ground_truth"]==False) & (dataframe["correct"]==True)]) + fp = len(dataframe[(dataframe["ground_truth"]==False) & (dataframe["correct"]==False)]) + fn = len(dataframe[(dataframe["ground_truth"]==True) & (dataframe["correct"]==False)]) + if tp+fn == 0: + recall = 0 + else: + recall = tp/(tp+fn) + if tn+fp == 0: + balanced_accuracy = 0 + else: + balanced_accuracy = (recall+tn/(tn+fp))/2 + return balanced_accuracy + +def per_feature_scores_fuzz_detection(score_data): + + data = [d for d in score_data if d.prediction != -1] + + data_df = pd.DataFrame(data) + print(data_df) + + balanced_accuracy = calculate_balanced_accuracy(data_df) + return balanced_accuracy + +def initialize_globals(): + # Make the folder for storing the explanations + os.makedirs("explanations", exist_ok=True) + # Make the folder for storing the scores + os.makedirs("scores", exist_ok=True) + +# Initialize globals when the app starts +initialize_globals() + +@app.route('/generate_explanation', methods=['POST']) +def generate_explanation(): + """ + Generate an explanation for a given set of activations. This endpoint expects + a JSON object with the following fields: + - activations: A list of dictionaries, each containing a 'tokens' key with a list of token strings and a 'values' key with a list of activation values. + - api_key: The API key to use for the request. + - model: The model to use for the request. + + We could potentially allow for more options, eg we have a threshold that is set "in stone", we don't do COT and always show activations. + We don't currently support that, but we could allow for custom prompts as well. + """ + + data = request.json + + if not data or 'activations' not in data: + return jsonify({"error": "Missing required data"}), 400 + if 'api_key' not in data: + return jsonify({"error": "Missing API key"}), 400 + if 'model' not in data: + return jsonify({"error": "Missing model"}), 400 + try: + feature = Feature(f"feature", 0) + examples = [] + for activation in data['activations']: + example = Example(activation['tokens'], torch.tensor(activation['values'])) + examples.append(example) + feature_record = FeatureRecord(feature) + feature_record.train = examples + + client = OpenRouter(api_key=data['api_key'], model=data['model']) + + explainer = DefaultExplainer(client, tokenizer=None, threshold=0.6) + result = explainer.call_sync(feature_record) # Use call_sync instead of __call__ + + return jsonify({"explanation": result.explanation}), 200 + + except Exception as e: + return jsonify({"error": str(e)}), 500 + +@app.route('/generate_score_fuzz_detection', methods=['POST']) +def generate_score_fuzz_detection(): + """ + Generate a score for a given set of activations and explanation. This endpoint expects + a JSON object with the following fields: + - activations: A list of dictionaries, each containing a 'tokens' key with a list of token strings and a 'values' key with a list of activation values. + - explanation: The explanation to use for the score. + - type: Whether to do detection or fuzzing. 
+ - api_key: The API key to use for the request. + - model: The model to use for the request. + + We could potentially allow for more options, eg we hardcode showing 5 examples at a time. + We don't currently support that, but we could allow for custom prompts as well. + OpenRouter doesn't support log_prob, so we can't use that. + """ + + data = request.json + + if not data or 'activations' not in data: + return jsonify({"error": "Missing required data"}), 400 + if 'explanation' not in data: + return jsonify({"error": "Missing explanation"}), 400 + if 'api_key' not in data: + return jsonify({"error": "Missing API key"}), 400 + if 'model' not in data: + return jsonify({"error": "Missing model"}), 400 + if 'type' not in data: + return jsonify({"error": "Missing type"}), 400 + try: + feature = Feature(f"feature", 0) + activating_examples = [] + non_activating_examples = [] + for activation in data['activations']: + example = Example(activation['tokens'], torch.tensor(activation['values'])) + if sum(activation['values']) > 0: + activating_examples.append(example) + else: + non_activating_examples.append(example) + feature_record = FeatureRecord(feature) + feature_record.test = [activating_examples] + feature_record.extra_examples = non_activating_examples + feature_record.random_examples = non_activating_examples + feature_record.explanation = data['explanation'] + + client = OpenRouter(api_key=data['api_key'], model=data['model']) + if data['type'] == 'fuzz': + # We can't use log_prob as it's not supported by OpenRouter + scorer = FuzzingScorer(client, tokenizer=None, batch_size=5,verbose=False,log_prob=False) + elif data['type'] == 'detection': + # We can't use log_prob as it's not supported by OpenRouter + scorer = DetectionScorer(client, tokenizer=None, batch_size=5,verbose=False,log_prob=False) + result = scorer.call_sync(feature_record) # Use call_sync instead of __call__ + #print(result.score) + score = per_feature_scores_fuzz_detection(result.score) + return jsonify({"score": score,"breakdown": result.score}), 200 + + except Exception as e: + return jsonify({"error": str(e)}), 500 + +if __name__ == '__main__': + app.run(debug=True) diff --git a/sae_auto_interp/explainers/default/default.py b/sae_auto_interp/explainers/default/default.py index fed56094..d46e77fc 100644 --- a/sae_auto_interp/explainers/default/default.py +++ b/sae_auto_interp/explainers/default/default.py @@ -1,4 +1,5 @@ import re +import asyncio import torch @@ -61,8 +62,12 @@ def _highlight(self, index, example): result = f"Example {index}: " threshold = example.max_activation * self.threshold - str_toks = self.tokenizer.batch_decode(example.tokens) - example.str_toks = str_toks + if self.tokenizer is not None: + str_toks = self.tokenizer.batch_decode(example.tokens) + example.str_toks = str_toks + else: + str_toks = example.tokens + example.str_toks = str_toks activations = example.activations def check(i): @@ -110,3 +115,6 @@ def _build_prompt(self, examples): activations=self.activations, cot=self.cot, ) + + def call_sync(self, record): + return asyncio.run(self.__call__(record)) diff --git a/sae_auto_interp/scorers/classifier/classifier.py b/sae_auto_interp/scorers/classifier/classifier.py index 4b6398ea..ee68ebdc 100644 --- a/sae_auto_interp/scorers/classifier/classifier.py +++ b/sae_auto_interp/scorers/classifier/classifier.py @@ -45,7 +45,7 @@ async def __call__( record.explanation, samples, ) - + #print(results) return ScorerResult(record=record, score=results) @abstractmethod @@ -85,6 +85,7 @@ async def 
_generate( self.generation_kwargs["logprobs"] = True self.generation_kwargs["top_logprobs"] = 10 response = await self.client.generate(prompt, **self.generation_kwargs) + #print(response) if response is None: array = [-1] * self.batch_size conditional_probabilities = [-1] * self.batch_size @@ -188,3 +189,6 @@ def _batch(self, samples): samples[i : i + self.batch_size] for i in range(0, len(samples), self.batch_size) ] + + def call_sync(self, record: FeatureRecord) -> list[ClassifierOutput]: + return asyncio.run(self.__call__(record)) diff --git a/sae_auto_interp/scorers/classifier/sample.py b/sae_auto_interp/scorers/classifier/sample.py index 7ab8aee1..9b29e3dc 100644 --- a/sae_auto_interp/scorers/classifier/sample.py +++ b/sae_auto_interp/scorers/classifier/sample.py @@ -17,25 +17,33 @@ @dataclass class ClassifierOutput: - text: str - """Text""" + str_tokens: list[str] + """List of strings""" + + activations: list[float] + """List of floats""" distance: float | int """Quantile or neighbor distance""" ground_truth: bool - """Whether the example is correct or not""" + """Whether the example is activating or not""" prediction: bool = False - """Whether the model predicted the example correctly""" + """Whether the model predicted the example activating or not""" highlighted: bool = False """Whether the sample is highlighted""" + probability: float = 0.0 + """The probability of the example activating""" + + correct: bool = False + """Whether the prediction is correct""" + class Sample(NamedTuple): text: str - data: ClassifierOutput @@ -50,13 +58,15 @@ def examples_to_samples( samples = [] for example in examples: - text,clean = _prepare_text(example, tokenizer, n_incorrect, threshold, highlighted) + text,clean,str_toks = _prepare_text(example, tokenizer, n_incorrect, threshold, highlighted) samples.append( Sample( text=text, data=ClassifierOutput( - text=clean, highlighted=highlighted, **sample_kwargs + str_tokens=str_toks, + activations=example.activations.tolist(), + highlighted=highlighted, **sample_kwargs ), ) ) @@ -73,11 +83,14 @@ def _prepare_text( threshold: float, highlighted: bool, ): - str_toks = tokenizer.batch_decode(example.tokens) + if tokenizer is None: # If we don't have a tokenizer, we assume the tokens are already strings + str_toks = example.tokens + else: + str_toks = tokenizer.batch_decode(example.tokens) clean = "".join(str_toks) # Just return text if there's no highlighting if not highlighted: - return clean,clean + return clean,clean,str_toks threshold = threshold * example.max_activation @@ -88,7 +101,7 @@ def _prepare_text( def check(i): return example.activations[i] >= threshold - return _highlight(str_toks, check),clean + return _highlight(str_toks, check),clean,str_toks # Highlight n_incorrect tokens with activations # below threshold if incorrect example @@ -97,7 +110,7 @@ def check(i): # Rare case where there are no tokens below threshold if below_threshold.dim() == 0: logger.error("Failed to prepare example.") - return DEFAULT_MESSAGE + return DEFAULT_MESSAGE,DEFAULT_MESSAGE,str_toks random.seed(22) @@ -108,7 +121,7 @@ def check(i): def check(i): return i in random_indices - return _highlight(str_toks, check),clean + return _highlight(str_toks, check),clean,str_toks def _highlight(tokens, check): From 6e0ca1ac202cf50d76bf9bf5fd8d5b18e22c54bc Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 8 Nov 2024 12:40:56 +0000 Subject: [PATCH 011/132] Fixes in the prompts --- .../scorers/classifier/prompts/detection_prompt.py | 6 +++--- 
sae_auto_interp/scorers/classifier/prompts/fuzz_prompt.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sae_auto_interp/scorers/classifier/prompts/detection_prompt.py b/sae_auto_interp/scorers/classifier/prompts/detection_prompt.py index 4caa71f6..fd62b092 100644 --- a/sae_auto_interp/scorers/classifier/prompts/detection_prompt.py +++ b/sae_auto_interp/scorers/classifier/prompts/detection_prompt.py @@ -10,7 +10,7 @@ # https://www.neuronpedia.org/gpt2-small/6-res-jb/6048 DSCORER_EXAMPLE_ONE = """Feature explanation: Words related to American football positions, specifically the tight end position. -Text examples: +Test examples: Example 0:<|endoftext|>Getty ImagesÄŠÄŠPatriots tight end Rob Gronkowski had his bossâĢĻ Example 1: names of months used in The Lord of the Rings:ĊĊâĢľâ̦the @@ -24,7 +24,7 @@ # https://www.neuronpedia.org/gpt2-small/6-res-jb/9396 DSCORER_EXAMPLE_TWO = """Feature explanation: The word "guys" in the phrase "you guys". -Text examples: +Test examples: Example 0: enact an individual health insurance mandate?âĢĿ, Pelosi's response was to dismiss both Example 1: birth control access<|endoftext|> but I assure you women in Kentucky aren't laughing as they struggle @@ -38,7 +38,7 @@ # https://www.neuronpedia.org/gpt2-small/8-res-jb/12654 DSCORER_EXAMPLE_THREE = """Feature explanation: "of" before words that start with a capital letter. -Text examples: +Test examples: Example 0: climate, TomblinâĢĻs Chief of Staff Charlie Lorensen said.ÄŠ Example 1: no wonderworking relics, no true Body and Blood of Christ, no true Baptism diff --git a/sae_auto_interp/scorers/classifier/prompts/fuzz_prompt.py b/sae_auto_interp/scorers/classifier/prompts/fuzz_prompt.py index e405c5de..00198911 100644 --- a/sae_auto_interp/scorers/classifier/prompts/fuzz_prompt.py +++ b/sae_auto_interp/scorers/classifier/prompts/fuzz_prompt.py @@ -11,7 +11,7 @@ # https://www.neuronpedia.org/gpt2-small/6-res-jb/6048 DSCORER_EXAMPLE_ONE = """Feature explanation: Words related to American football positions, specifically the tight end position. -Text examples: +Test examples: Example 0:<|endoftext|>Getty ImagesÄŠÄŠPatriots<< tight end>> Rob Gronkowski had his bossâĢĻ Example 1: posted<|endoftext|>You should know this<< about>> offensive line coaches: they are large, demanding<< men>> @@ -33,7 +33,7 @@ # https://www.neuronpedia.org/gpt2-small/6-res-jb/9396 DSCORER_EXAMPLE_TWO = """Feature explanation: The word "guys" in the phrase "you guys". -Text examples: +Test examples: Example 0: if you are<< comfortable>> with it. You<< guys>> support me in many other ways already and Example 1: birth control access<|endoftext|> but I assure you<< women>> in Kentucky aren't laughing as they struggle @@ -55,7 +55,7 @@ # https://www.neuronpedia.org/gpt2-small/8-res-jb/12654 DSCORER_EXAMPLE_THREE = """Feature explanation: "of" before words that start with a capital letter. 
-Text examples: +Test examples: Example 0: climate, TomblinâĢĻs Chief<< of>> Staff Charlie Lorensen said.ÄŠ Example 1: no wonderworking relics, no true Body and Blood<< of>> Christ, no true Baptism From 3df816017e0b93ce00ffd28227ba84f19edc25e9 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 8 Nov 2024 12:41:07 +0000 Subject: [PATCH 012/132] Early exiting --- sae_auto_interp/features/constructors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sae_auto_interp/features/constructors.py b/sae_auto_interp/features/constructors.py index c479edc4..268e9488 100644 --- a/sae_auto_interp/features/constructors.py +++ b/sae_auto_interp/features/constructors.py @@ -80,6 +80,8 @@ def random_activation_windows( n_random (int): The number of random examples to generate. """ torch.manual_seed(22) + if n_random == 0: + return batch_size = tokens.shape[0] unique_batch_pos = buffer_output.locations[:, 0].unique() From 131effff506a998d2a20dc2e8e6988bc7dbe4189 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 8 Nov 2024 12:41:28 +0000 Subject: [PATCH 013/132] Str_toks and text --- sae_auto_interp/scorers/classifier/sample.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/sae_auto_interp/scorers/classifier/sample.py b/sae_auto_interp/scorers/classifier/sample.py index 9b29e3dc..611850c7 100644 --- a/sae_auto_interp/scorers/classifier/sample.py +++ b/sae_auto_interp/scorers/classifier/sample.py @@ -58,7 +58,7 @@ def examples_to_samples( samples = [] for example in examples: - text,clean,str_toks = _prepare_text(example, tokenizer, n_incorrect, threshold, highlighted) + text,str_toks = _prepare_text(example, tokenizer, n_incorrect, threshold, highlighted) samples.append( Sample( @@ -90,7 +90,7 @@ def _prepare_text( clean = "".join(str_toks) # Just return text if there's no highlighting if not highlighted: - return clean,clean,str_toks + return clean,str_toks threshold = threshold * example.max_activation @@ -101,7 +101,7 @@ def _prepare_text( def check(i): return example.activations[i] >= threshold - return _highlight(str_toks, check),clean,str_toks + return _highlight(str_toks, check),str_toks # Highlight n_incorrect tokens with activations # below threshold if incorrect example @@ -110,7 +110,7 @@ def check(i): # Rare case where there are no tokens below threshold if below_threshold.dim() == 0: logger.error("Failed to prepare example.") - return DEFAULT_MESSAGE,DEFAULT_MESSAGE,str_toks + return DEFAULT_MESSAGE,str_toks random.seed(22) @@ -121,7 +121,7 @@ def check(i): def check(i): return i in random_indices - return _highlight(str_toks, check),clean,str_toks + return _highlight(str_toks, check),str_toks def _highlight(tokens, check): From 0fbbf0940bf60c04133c0bd8deca6fe23011ddb3 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 8 Nov 2024 12:41:47 +0000 Subject: [PATCH 014/132] Remove extra keyword --- sae_auto_interp/autoencoders/wrapper.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sae_auto_interp/autoencoders/wrapper.py b/sae_auto_interp/autoencoders/wrapper.py index 263384a6..80d02177 100644 --- a/sae_auto_interp/autoencoders/wrapper.py +++ b/sae_auto_interp/autoencoders/wrapper.py @@ -11,7 +11,6 @@ class AutoencoderConfig(Serializable): model_name_or_path: str = "model" autoencoder_type: Literal["SAE", "SAE_LENS", "NEURONS", "CUSTOM"] = "SAE" device: Optional[str] = None - hookpoints: Optional[List[str]] = None kwargs: Dict[str, Any] = field(default_factory=dict) class AutoencoderLatents(torch.nn.Module): @@ -156,6 +155,7 @@ def 
hook_submodule( submodule: Any, model: Any,module_path:str,autoencoder_confi def load_autoencoder_into_model( model: Any, autoencoder_config: AutoencoderConfig, + hookpoints: List[str], **kwargs ) -> Tuple[Dict[str,Any], Any]: """ @@ -171,7 +171,6 @@ def load_autoencoder_into_model( submodules = {} edited_model = model - hookpoints = autoencoder_config.hookpoints assert hookpoints is not None, "Hookpoints must be specified in autoencoder_config" for module_path in hookpoints: autoencoder = AutoencoderLatents.from_pretrained( From fbcd7ecfafcf7e365c2350bc6cd30d18e197351b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 8 Nov 2024 12:42:10 +0000 Subject: [PATCH 015/132] Delete wrong example --- app.py | 81 ---------------------------------------------------------- 1 file changed, 81 deletions(-) delete mode 100644 app.py diff --git a/app.py b/app.py deleted file mode 100644 index db4e82b9..00000000 --- a/app.py +++ /dev/null @@ -1,81 +0,0 @@ -from flask import Flask, request, jsonify -import asyncio -import torch -from functools import partial - -from sae_auto_interp.clients import OpenRouter -from sae_auto_interp.config import ExperimentConfig, FeatureConfig -from sae_auto_interp.explainers import DefaultExplainer -from sae_auto_interp.features import FeatureDataset, FeatureLoader -from sae_auto_interp.features.constructors import default_constructor -from sae_auto_interp.features.samplers import sample -from sae_auto_interp.pipeline import Pipeline, process_wrapper -from sae_auto_interp.counterfactuals import ExplainerInterventionExample, ExplainerNeuronFormatter, get_explainer_prompt, fs_examples - -app = Flask(__name__) - -# Global variables -client = None -explainer_pipe = None - -def initialize_globals(): - # Make the folder for storing the explanations - os.makedirs("explanations", exist_ok=True) - - # Make the folder for storing the scores - -@app.before_first_request -def before_first_request(): - initialize_globals() - -@app.route('/generate_explanation', methods=['POST']) -def generate_explanation(): - """ - Generate an explanation for a given set of activations. This endpoint expects - a JSON object with the following fields: - - activations: A list of dictionaries, each containing a 'tokens' key with a list of token strings and a 'values' key with a list of activation values. - - api_key: The API key to use for the request. - - model: The model to use for the request. - - We could potentially allow for more options, eg we have a threshold that is set "in stone", we don't do COT and always show activations. - We don't currently support that, but we could allow for custom prompts as well. 
- """ - - data = request.json - - if not data or 'activations' not in data: - return jsonify({"error": "Missing required data"}), 400 - if 'api_key' not in data: - return jsonify({"error": "Missing API key"}), 400 - if 'model' not in data: - return jsonify({"error": "Missing model"}), 400 - try: - feature = Feature(f"feature", 0) - examples = [] - for activation in data['activations']: - example = Example(activation['tokens'], activation['values']) - examples.append(example) - feature_record = FeatureRecord(feature) - feature_record.train = [examples] - - client = OpenRouter(api_key=data['api_key'], model=data['model']) - - explainer = DefaultExplainer(client, tokenizer=None, threshold=0.6) - explanation = explainer(feature_record).explanation - - return jsonify({"explanation": explanation}), 200 - - except Exception as e: - return jsonify({"error": str(e)}), 500 - - -@app.route('/generate_score_fuzz', methods=['POST']) -def generate_score_fuzz(): - """ - Generate a score for a given set of activations. This endpoint expects - a JSON object with the following fields: - - activations: A list of dictionaries, each containing a 'tokens' key with a list of token strings and a 'values' key with a list of activation values. - """ - -if __name__ == '__main__': - app.run(debug=True) From 8ea7df7f65cf1bba35997782bb8326b3d3b63b09 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 13 Nov 2024 15:25:04 +0000 Subject: [PATCH 016/132] Autoencoder's stuff --- sae_auto_interp/autoencoders/DeepMind/__init__.py | 2 +- sae_auto_interp/autoencoders/__init__.py | 3 ++- sae_auto_interp/autoencoders/wrapper.py | 9 +++++---- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/sae_auto_interp/autoencoders/DeepMind/__init__.py b/sae_auto_interp/autoencoders/DeepMind/__init__.py index b81806f1..e6461e41 100644 --- a/sae_auto_interp/autoencoders/DeepMind/__init__.py +++ b/sae_auto_interp/autoencoders/DeepMind/__init__.py @@ -1,5 +1,5 @@ from functools import partial -from .model import JumpReLUSAE +from ..Custom.gemmascope import JumpReLUSAE from typing import List, Dict import torch from ..wrapper import AutoencoderLatents diff --git a/sae_auto_interp/autoencoders/__init__.py b/sae_auto_interp/autoencoders/__init__.py index 7a4c8815..57e3f91b 100644 --- a/sae_auto_interp/autoencoders/__init__.py +++ b/sae_auto_interp/autoencoders/__init__.py @@ -3,7 +3,7 @@ from .Sam import load_sam_autoencoders from .eleuther import load_eai_autoencoders from .DeepMind import load_gemma_autoencoders -from .wrapper import load_autoencoder_into_model, AutoencoderConfig +from .wrapper import load_autoencoder_into_model, AutoencoderConfig, AutoencoderLatents __all__ = [ "load_autoencoder_into_model", @@ -11,4 +11,5 @@ "load_oai_autoencoders", "load_sam_autoencoders", "AutoencoderConfig", + "AutoencoderLatents", ] diff --git a/sae_auto_interp/autoencoders/wrapper.py b/sae_auto_interp/autoencoders/wrapper.py index 80d02177..cb30b4b6 100644 --- a/sae_auto_interp/autoencoders/wrapper.py +++ b/sae_auto_interp/autoencoders/wrapper.py @@ -23,12 +23,13 @@ def __init__( autoencoder: Any, forward_function: Callable, width: int, + hookpoint: str, ) -> None: super().__init__() self.ae = autoencoder self._forward = forward_function self.width = width - + self.hookpoint = hookpoint def forward(self, x: torch.Tensor) -> torch.Tensor: return self._forward(x) @@ -42,7 +43,6 @@ def from_pretrained( device = config.device or ('cuda' if torch.cuda.is_available() else 'cpu') autoencoder_type = config.autoencoder_type model_name_or_path = 
config.model_name_or_path - if autoencoder_type == "SAE": from sae import Sae local = kwargs.get("local",None) @@ -88,7 +88,7 @@ def from_pretrained( else: raise ValueError(f"Unsupported autoencoder type: {autoencoder_type}") - return cls(sae, forward_function, width) + return cls(sae, forward_function, width, hookpoint) @classmethod def random(cls, config: AutoencoderConfig, hookpoint: str, **kwargs): pass @@ -166,7 +166,8 @@ def load_autoencoder_into_model( autoencoder_config (AutoencoderConfig): Configuration for the autoencoder. Returns: - Tuple[List[Any], Any]: The list of submodules with the autoencoder attached and the edited model. + Tuple[List[Any], Any]: The list of submodules with the autoencoder attached + Model with the autoencoder hooked in """ submodules = {} From ad77d3ed2b88a8111070f58dd785abd7c249953e Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 13 Nov 2024 15:25:48 +0000 Subject: [PATCH 017/132] Making new pipeline --- .../counterfactuals/output_explainer.py | 102 ++++++++++++++++++ 1 file changed, 102 insertions(+) create mode 100644 sae_auto_interp/counterfactuals/output_explainer.py diff --git a/sae_auto_interp/counterfactuals/output_explainer.py b/sae_auto_interp/counterfactuals/output_explainer.py new file mode 100644 index 00000000..7222e13c --- /dev/null +++ b/sae_auto_interp/counterfactuals/output_explainer.py @@ -0,0 +1,102 @@ +import re +import asyncio + +import torch + +from sae_auto_interp.explainers.explainer import Explainer, ExplainerResult +from sae_auto_interp.explainers.prompt_builder import build_prompt +import time +from ...logger import logger +import numpy as np + +class OutputExplainer(Explainer): + name = "output" + + def __init__( + self, + explanation_client, + prepared_examples, + tokenizer, + verbose: bool = False, + logit_lens: bool = False, + **generation_kwargs, + ): + self.client = explanation_client + self.prepared_examples = prepared_examples + self.tokenizer = tokenizer + self.verbose = verbose + self.logit_lens = logit_lens + self.generation_kwargs = generation_kwargs + + + async def __call__(self, record): + + examples = self.prepared_examples + assert len(examples) > 0, "Prepared examples first" + + + + + messages = self._build_prompt(record.train) + + response = await self.client.generate(messages, **self.generation_kwargs) + + try: + explanation = self.parse_explanation(response.text) + if self.verbose: + return ( + messages[-1]["content"], + response, + ExplainerResult(record=record, explanation=explanation), + ) + + return ExplainerResult(record=record, explanation=explanation) + except Exception as e: + logger.error(f"Explanation parsing failed: {e}") + return ExplainerResult(record=record, explanation="Explanation could not be parsed.") + + + + + def parse_explanation(self, text: str) -> str: + try: + match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) + return match.group(1).strip() if match else "Explanation could not be parsed." 
+ except Exception as e: + logger.error(f"Explanation parsing regex failed: {e}") + raise + + + + + + def _join_activations(self, example): + activations = [] + + for i, activation in enumerate(example.activations): + if activation > example.max_activation * self.threshold: + activations.append((example.str_toks[i], int(example.normalized_activations[i]))) + + acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) + + return "Activations: " + acts + + def _build_prompt(self, examples): + highlighted_examples = [] + + for i, example in enumerate(examples): + highlighted_examples.append(self._highlight(i + 1, example)) + + if self.activations: + highlighted_examples.append(self._join_activations(example)) + + highlighted_examples = "\n".join(highlighted_examples) + + return build_prompt( + examples=highlighted_examples, + activations=self.activations, + cot=self.cot, + ) + + def call_sync(self, record): + return asyncio.run(self.__call__(record)) From 35d8e3c98c8846695094efbe466e1985f6aeee10 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 11 Dec 2024 08:50:33 +0000 Subject: [PATCH 018/132] Transcoder --- sae_auto_interp/autoencoders/eleuther.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/sae_auto_interp/autoencoders/eleuther.py b/sae_auto_interp/autoencoders/eleuther.py index fd7ae72d..4a824b56 100644 --- a/sae_auto_interp/autoencoders/eleuther.py +++ b/sae_auto_interp/autoencoders/eleuther.py @@ -15,6 +15,7 @@ def load_eai_autoencoders( ae_layers: List[int], weight_dir: str, module: str, + transcoder: bool = False, randomize: bool = False, seed: int = 42, k: Optional[int] = None @@ -77,7 +78,7 @@ def _forward(sae, k,x): else: submodule = model.gpt_neox.layers[layer].mlp submodule.ae = AutoencoderLatents( - sae, partial(_forward, sae, k), width=sae.encoder.weight.shape[0] + sae, partial(_forward, sae, k), width=sae.encoder.weight.shape[0],hookpoint=submodule.path ) submodules[submodule.path] = submodule @@ -85,9 +86,15 @@ def _forward(sae, k,x): with model.edit("") as edited: for path, submodule in submodules.items(): if "embed" not in path and "mlp" not in path: - acts = submodule.output[0] + if transcoder: + acts = submodule.input[0] + else: + acts = submodule.output[0] else: - acts = submodule.output + if transcoder: + acts = submodule.input + else: + acts = submodule.output submodule.ae(acts, hook=True) return submodules,edited From f296d362f94b89ce13ef43b7371e5232d8cb2d0f Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 11 Dec 2024 08:50:44 +0000 Subject: [PATCH 019/132] Dataset collumn --- sae_auto_interp/config.py | 3 +++ sae_auto_interp/features/cache.py | 1 - sae_auto_interp/features/loader.py | 1 + 3 files changed, 4 insertions(+), 1 deletion(-) diff --git a/sae_auto_interp/config.py b/sae_auto_interp/config.py index 7f31a22a..353eeedd 100644 --- a/sae_auto_interp/config.py +++ b/sae_auto_interp/config.py @@ -58,6 +58,9 @@ class CacheConfig(Serializable): dataset_name: str = "" """Dataset name to use""" + dataset_column_name: str = "text" + """Dataset column name to use""" + batch_size: int = 32 """Number of sequences to process in a batch""" diff --git a/sae_auto_interp/features/cache.py b/sae_auto_interp/features/cache.py index 8bbca50b..400f45ef 100644 --- a/sae_auto_interp/features/cache.py +++ b/sae_auto_interp/features/cache.py @@ -120,7 +120,6 @@ def get_nonzeros( else: nonzero_feature_locations = torch.nonzero(latents.abs() > 1e-5) nonzero_feature_activations = latents[latents.abs() > 1e-5] - # Return all nonzero 
features if no filter is provided if self.filters is None: return nonzero_feature_locations, nonzero_feature_activations diff --git a/sae_auto_interp/features/loader.py b/sae_auto_interp/features/loader.py index 291afe57..c41aa625 100644 --- a/sae_auto_interp/features/loader.py +++ b/sae_auto_interp/features/loader.py @@ -143,6 +143,7 @@ def __init__( cache_config["dataset_repo"], cache_config["dataset_split"], cache_config["dataset_name"], + cache_config["dataset_column_name"], ) def _edges(self): From 06ebde18c31c121aee57c8ca786a50cf5b34fbab Mon Sep 17 00:00:00 2001 From: neverix Date: Mon, 13 Jan 2025 02:06:42 +0000 Subject: [PATCH 020/132] Breaking change: save tokens in cache, make it the primary source of truth --- README.md | 5 +++ sae_auto_interp/features/cache.py | 25 ++++++++++++--- sae_auto_interp/features/constructors.py | 22 +++++++++++-- sae_auto_interp/features/loader.py | 41 ++++++++++++++++++------ 4 files changed, 77 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index 5597ff32..3379f0ce 100644 --- a/README.md +++ b/README.md @@ -238,8 +238,13 @@ Surprisal scoring computes the loss over some examples and uses a base model. We Embedding scoring uses a small embedding model through `sentence_transformers` to embed the examples do retrival. It also does not use VLLM but run the model directly. The setup is similar as above but for a example check `embedding.py` in the experiments folder. +# Breaking changes in v0.2 +`features.cache`: Dataset tokens are now saved in safetensors files together with the activations. + +`features.constructors.default_constructor`: `tokens` was renamed to `token_loader`, which must be a callable for lazy loading. Instead of passing `tokens=dataset.tokens`, pass `token_loader=lambda: dataset.load_tokens()` (assuming `dataset` is a `FeatureDataset` instance). + # Scripts Example scripts can be found in `demos`. Some of these scripts can be called from the CLI, as seen in examples found in `scripts`. These baseline scripts should allow anyone to start generating and scoring explanations in any SAE they are interested in. One always needs to first cache the activations of the features of any given SAE, and then generating explanations and scoring them can be done at the same time. diff --git a/sae_auto_interp/features/cache.py b/sae_auto_interp/features/cache.py index 400f45ef..01ad71ca 100644 --- a/sae_auto_interp/features/cache.py +++ b/sae_auto_interp/features/cache.py @@ -30,12 +30,14 @@ def __init__( """ self.feature_locations = defaultdict(list) self.feature_activations = defaultdict(list) + self.tokens = defaultdict(list) self.filters = filters self.batch_size = batch_size def add( self, latents: TensorType["batch", "sequence", "feature"], + tokens: TensorType["batch", "sequence"], batch_number: int, module_path: str, ): @@ -44,17 +46,20 @@ def add( Args: latents (TensorType["batch", "sequence", "feature"]): Latent activations. + tokens (TensorType["batch", "sequence"]): Input tokens. batch_number (int): Current batch number. module_path (str): Path of the module. 
""" feature_locations, feature_activations = self.get_nonzeros(latents, module_path) feature_locations = feature_locations.cpu() feature_activations = feature_activations.cpu() + tokens = tokens.cpu() # Adjust batch indices feature_locations[:, 0] += batch_number * self.batch_size self.feature_locations[module_path].append(feature_locations) self.feature_activations[module_path].append(feature_activations) + self.tokens[module_path].append(tokens) def save(self): """ @@ -68,6 +73,10 @@ def save(self): self.feature_activations[module_path] = torch.cat( self.feature_activations[module_path], dim=0 ) + + self.tokens[module_path] = torch.cat( + self.tokens[module_path], dim=0 + ) def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]): """ @@ -112,7 +121,8 @@ def get_nonzeros( module_path (str): Path of the module. Returns: - Tuple[torch.Tensor, torch.Tensor]: Non-zero feature locations and activations. + Tuple[TensorType["num_nonzero", 3], TensorType["num_nonzero"]]: + Non-zero feature locations and activations. """ size = latents.shape[1] * latents.shape[0] * latents.shape[2] if size > torch.iinfo(torch.int32).max: @@ -228,7 +238,7 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): for module_path, submodule in self.submodule_dict.items(): buffer[module_path] = submodule.ae.output.save() for module_path, latents in buffer.items(): - self.cache.add(latents, batch_number, module_path) + self.cache.add(latents, batch, batch_number, module_path) del buffer torch.cuda.empty_cache() @@ -240,12 +250,13 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): print(f"Total tokens processed: {total_tokens:,}") self.cache.save() - def save(self, save_dir): + def save(self, save_dir, save_tokens: bool = True): """ Save the cached features to disk. Args: save_dir (str): Directory to save the features. + save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to """ for module_path in self.cache.feature_locations.keys(): output_file = f"{save_dir}/{module_path}.safetensors" @@ -254,6 +265,8 @@ def save(self, save_dir): "locations": self.cache.feature_locations[module_path], "activations": self.cache.feature_activations[module_path], } + if save_tokens: + data["tokens"] = self.cache.tokens[module_path] save_file(data, output_file) @@ -272,19 +285,21 @@ def _generate_split_indices(self, n_splits): # Adjust end by one return list(zip(boundaries[:-1], boundaries[1:] - 1)) - def save_splits(self, n_splits: int, save_dir): + def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True): """ Save the cached features in splits. Args: n_splits (int): Number of splits to generate. save_dir (str): Directory to save the splits. + save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True. 
""" split_indices = self._generate_split_indices(n_splits) for module_path in self.cache.feature_locations.keys(): feature_locations = self.cache.feature_locations[module_path] feature_activations = self.cache.feature_activations[module_path] + tokens = self.cache.tokens[module_path].numpy() features = feature_locations[:, 2] for start, end in split_indices: @@ -309,6 +324,8 @@ def save_splits(self, n_splits: int, save_dir): "locations": masked_locations, "activations": masked_activations, } + if save_tokens: + split_data["tokens"] = tokens save_file(split_data, output_file) diff --git a/sae_auto_interp/features/constructors.py b/sae_auto_interp/features/constructors.py index 268e9488..d6df9677 100644 --- a/sae_auto_interp/features/constructors.py +++ b/sae_auto_interp/features/constructors.py @@ -1,5 +1,6 @@ import torch from torchtyping import TensorType +from typing import Callable, Optional from .features import FeatureRecord, prepare_examples from .loader import BufferOutput @@ -100,7 +101,7 @@ def random_activation_windows( def default_constructor( record: FeatureRecord, - tokens: TensorType["batch", "seq"], + token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], buffer_output: BufferOutput, n_random: int, ctx_len: int, @@ -111,12 +112,29 @@ def default_constructor( Args: record (FeatureRecord): The feature record to update. - tokens (TensorType["batch", "seq"]): The input tokens. + token_loader (Optional[Callable[[], TensorType["batch", "seq"]]]): + An optional function that creates the dataset tokens. buffer_output (BufferOutput): The buffer output containing activations and locations. n_random (int): Number of random examples to generate. ctx_len (int): Context length for each example. max_examples (int): Maximum number of examples to generate. """ + tokens = buffer_output.tokens + if tokens is None: + if token_loader is None: + raise ValueError("Either tokens or token_loader must be provided") + try: + tokens = token_loader() + except TypeError: + raise ValueError( + "Starting with v0.2, `tokens` was renamed to `token_loader`, " + "which must be a callable for lazy loading.\n\n" + "Instead of passing\n" + "` tokens=dataset.tokens`,\n" + "pass\n" + "` token_loader=lambda: dataset.load_tokens()`,\n" + "(assuming `dataset` is a `FeatureDataset` instance)." + ) pool_max_activation_windows( record, tokens=tokens, diff --git a/sae_auto_interp/features/loader.py b/sae_auto_interp/features/loader.py index c41aa625..72c3e67f 100644 --- a/sae_auto_interp/features/loader.py +++ b/sae_auto_interp/features/loader.py @@ -26,10 +26,12 @@ class BufferOutput(NamedTuple): feature (Feature): The feature associated with this output. locations (TensorType["locations", 2]): Tensor of feature locations. activations (TensorType["locations"]): Tensor of feature activations. + tokens (TensorType["tokens"]): Tensor of all tokens. 
""" feature: Feature locations: TensorType["locations", 2] activations: TensorType["locations"] + tokens: TensorType["tokens"] class TensorBuffer: @@ -69,6 +71,10 @@ def __iter__(self): first_feature = int(self.tensor_path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) locations = torch.tensor(split_data["locations"].astype(np.int64)) + if hasattr(split_data, "tokens"): + tokens = torch.tensor(split_data["tokens"].astype(np.int64)) + else: + tokens = None locations[:,2] = locations[:,2] + first_feature @@ -94,7 +100,8 @@ def __iter__(self): yield BufferOutput( Feature(self.module_path, int(features[i].item())), feature_locations, - feature_activations + feature_activations, + tokens ) def reset(self): @@ -137,15 +144,29 @@ def __init__( with open(cache_config_dir, "r") as f: cache_config = json.load(f) self.tokenizer = load_tokenizer(cache_config["model_name"]) - self.tokens = load_tokenized_data( - cache_config["ctx_len"], - self.tokenizer, - cache_config["dataset_repo"], - cache_config["dataset_split"], - cache_config["dataset_name"], - cache_config["dataset_column_name"], - ) - + self.cache_config = cache_config + + def load_tokens(self): + """ + Load tokenized data for the dataset. + Caches the tokenized data if not already loaded. + + Returns: + torch.Tensor: The tokenized dataset. + """ + if not hasattr(self, "tokens"): + self.tokens = load_tokenized_data( + self.cache_config["ctx_len"], + self.tokenizer, + self.cache_config["dataset_repo"], + self.cache_config["dataset_split"], + self.cache_config["dataset_name"], + column_name=self.cache_config.get( + "column_name", self.cache_config.get("dataset_row", "raw_content") + ), + ) + return self.tokens + def _edges(self): """Generate edge indices for feature splits.""" return torch.linspace(0, self.cfg.width, steps=self.cfg.n_splits + 1).long() From 0f77e43ffce753202810a1bf8ede3388712a5786 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 13 Jan 2025 12:59:47 +0000 Subject: [PATCH 021/132] Fixing loader bug --- sae_auto_interp/features/loader.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/sae_auto_interp/features/loader.py b/sae_auto_interp/features/loader.py index c41aa625..521aaf18 100644 --- a/sae_auto_interp/features/loader.py +++ b/sae_auto_interp/features/loader.py @@ -8,6 +8,7 @@ from safetensors.numpy import load_file from torchtyping import TensorType from tqdm import tqdm +from nnsight import LanguageModel from sae_auto_interp.utils import ( load_tokenized_data, @@ -136,7 +137,9 @@ def __init__( cache_config_dir = f"{raw_dir}/{modules[0]}/config.json" with open(cache_config_dir, "r") as f: cache_config = json.load(f) - self.tokenizer = load_tokenizer(cache_config["model_name"]) + temp_model = LanguageModel(cache_config["model_name"], device_map="cpu", dispatch=False) + self.tokenizer = temp_model.tokenizer + print(cache_config) self.tokens = load_tokenized_data( cache_config["ctx_len"], self.tokenizer, @@ -145,6 +148,7 @@ def __init__( cache_config["dataset_name"], cache_config["dataset_column_name"], ) + print(self.tokenizer.decode(self.tokens[0])) def _edges(self): """Generate edge indices for feature splits.""" From 162f7d1a12176be4a1116c79a9b0cde03b31e0ac Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 17 Jan 2025 13:04:48 +0000 Subject: [PATCH 022/132] tokens in, not hasattr --- sae_auto_interp/features/loader.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sae_auto_interp/features/loader.py b/sae_auto_interp/features/loader.py index 
afa66725..81121580 100644 --- a/sae_auto_interp/features/loader.py +++ b/sae_auto_interp/features/loader.py @@ -72,10 +72,11 @@ def __iter__(self): first_feature = int(self.tensor_path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) locations = torch.tensor(split_data["locations"].astype(np.int64)) - if hasattr(split_data, "tokens"): + if "tokens" in split_data: tokens = torch.tensor(split_data["tokens"].astype(np.int64)) else: tokens = None + print(tokens.shape) locations[:,2] = locations[:,2] + first_feature @@ -146,7 +147,6 @@ def __init__( cache_config = json.load(f) temp_model = LanguageModel(cache_config["model_name"], device_map="cpu", dispatch=False) self.tokenizer = temp_model.tokenizer - self.cache_config = cache_config def load_tokens(self): From 79272a3f708240705780f7ec6a896f3e884f84da Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 17 Jan 2025 13:17:52 +0000 Subject: [PATCH 023/132] Sensible default config --- sae_auto_interp/config.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/sae_auto_interp/config.py b/sae_auto_interp/config.py index 353eeedd..3873d798 100644 --- a/sae_auto_interp/config.py +++ b/sae_auto_interp/config.py @@ -7,13 +7,13 @@ @dataclass class ExperimentConfig(Serializable): - n_examples_train: int = 40 + n_examples_train: int = 50 """Number of examples to sample for training""" - n_examples_test: int = 5 + n_examples_test: int = 50 """Number of examples to sample for testing""" - n_quantiles: int = 20 + n_quantiles: int = 10 """Number of quantiles to sample""" example_ctx_len: int = 32 @@ -22,7 +22,7 @@ class ExperimentConfig(Serializable): n_random: int = 50 """Number of random examples to sample""" - train_type: Literal["top", "random", "quantiles"] = "random" + train_type: Literal["top", "random", "quantiles"] = "quantiles" """Type of sampler to use for training""" test_type: Literal["quantiles", "activation"] = "quantiles" From 34c8748154413d64d9d1668ed210de3bff031ce3 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 17 Jan 2025 13:32:00 +0000 Subject: [PATCH 024/132] Remove debug print --- sae_auto_interp/features/loader.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sae_auto_interp/features/loader.py b/sae_auto_interp/features/loader.py index 81121580..371c3272 100644 --- a/sae_auto_interp/features/loader.py +++ b/sae_auto_interp/features/loader.py @@ -76,7 +76,6 @@ def __iter__(self): tokens = torch.tensor(split_data["tokens"].astype(np.int64)) else: tokens = None - print(tokens.shape) locations[:,2] = locations[:,2] + first_feature From 43c2d9f169670aea905c6d2c4898e7bf63a86b68 Mon Sep 17 00:00:00 2001 From: Goncalo Paulo <30472805+SrGonao@users.noreply.github.com> Date: Fri, 31 Jan 2025 16:13:39 +0000 Subject: [PATCH 025/132] Update pyproject.toml --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index bb76b40d..54faaf94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "sae_auto_interp" -version = "0.1.0" +version = "0.2.0" description = "Automated Interpretability" readme = "README.md" requires-python = ">=3.10" From 922dc1424c633b9f36c1b3d8b96a5dce70c7228e Mon Sep 17 00:00:00 2001 From: neverix Date: Fri, 31 Jan 2025 11:29:22 -0500 Subject: [PATCH 026/132] Make token_loader None by default --- sae_auto_interp/features/constructors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git 
a/sae_auto_interp/features/constructors.py b/sae_auto_interp/features/constructors.py index d6df9677..cca82360 100644 --- a/sae_auto_interp/features/constructors.py +++ b/sae_auto_interp/features/constructors.py @@ -101,23 +101,23 @@ def random_activation_windows( def default_constructor( record: FeatureRecord, - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], buffer_output: BufferOutput, n_random: int, ctx_len: int, max_examples: int, + token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] = None, ): """ Construct feature examples using pool max activation windows and random activation windows. Args: record (FeatureRecord): The feature record to update. - token_loader (Optional[Callable[[], TensorType["batch", "seq"]]]): - An optional function that creates the dataset tokens. buffer_output (BufferOutput): The buffer output containing activations and locations. n_random (int): Number of random examples to generate. ctx_len (int): Context length for each example. max_examples (int): Maximum number of examples to generate. + token_loader (Optional[Callable[[], TensorType["batch", "seq"]]]): + An optional function that creates the dataset tokens. """ tokens = buffer_output.tokens if tokens is None: From 640145e7178d13bdb3aea2e149f8a8813284527f Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 4 Feb 2025 19:25:39 +0000 Subject: [PATCH 027/132] Abstract load function --- sae_auto_interp/features/loader.py | 36 +++++++++++++++++++----------- 1 file changed, 23 insertions(+), 13 deletions(-) diff --git a/sae_auto_interp/features/loader.py b/sae_auto_interp/features/loader.py index b39f480d..6a76455b 100644 --- a/sae_auto_interp/features/loader.py +++ b/sae_auto_interp/features/loader.py @@ -60,6 +60,8 @@ def __init__( self.features = features self.min_examples = min_examples + + def __iter__(self): """ Iterate over the buffer, yielding BufferOutput objects. @@ -67,6 +69,22 @@ def __iter__(self): Yields: Union[BufferOutput, None]: BufferOutput if enough examples, None otherwise. 
""" + features, split_locations, split_activations, tokens = self.load() + + for i in range(len(features)): + feature_locations = split_locations[i] + feature_activations = split_activations[i] + if len(feature_locations) < self.min_examples: + yield None + else: + yield BufferOutput( + Feature(self.module_path, int(features[i].item())), + feature_locations, + feature_activations, + tokens + ) + + def load(self): split_data = load_file(self.tensor_path) first_feature = int(self.tensor_path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) @@ -90,19 +108,11 @@ def __iter__(self): features = unique_features split_locations = torch.split(locations, counts.tolist()) split_activations = torch.split(activations, counts.tolist()) - - for i in range(len(features)): - feature_locations = split_locations[i] - feature_activations = split_activations[i] - if len(feature_locations) < self.min_examples: - yield None - else: - yield BufferOutput( - Feature(self.module_path, int(features[i].item())), - feature_locations, - feature_activations, - tokens - ) + + return features, split_locations, split_activations, tokens + + + def reset(self): """Reset the buffer state.""" From d468e0d83a62281543d8b7cb4a3b15f7d619a4bf Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 4 Feb 2025 19:25:53 +0000 Subject: [PATCH 028/132] Naive implementation --- sae_auto_interp/features/neighbours.py | 209 +++++++++++++++++++++++++ 1 file changed, 209 insertions(+) create mode 100644 sae_auto_interp/features/neighbours.py diff --git a/sae_auto_interp/features/neighbours.py b/sae_auto_interp/features/neighbours.py new file mode 100644 index 00000000..2b40a3f2 --- /dev/null +++ b/sae_auto_interp/features/neighbours.py @@ -0,0 +1,209 @@ +import json +from typing import Dict, List, Optional + +import cupy as cp +import cupyx.scipy.sparse as cusparse +import numpy as np +import torch +from safetensors.numpy import load_file + + +class NeighbourCalculator: + """ + Class to compute the neighbours of selected features using different methods: + - similarity: uses autoencoder weights + - correlation: uses pre-activation records and autoencoder + - co-occurrence: uses feature dataset statistics + """ + + def __init__( + self, + + feature_dataset: Optional['FeatureDataset'] = None, + autoencoder: Optional["Autoencoder"] = None, + pre_activation_record: Optional['PreActivationRecord'] = None, + number_of_neighbours: int = 10, + neighbour_cache: Optional[Dict[str, Dict[int, List[int]]]] = None, + ): + + """ + Initialize a NeighbourCalculator. + + Args: + feature_dataset (Optional[FeatureDataset]): Dataset containing feature activations + autoencoder (Optional[Autoencoder]): The trained autoencoder model + pre_activation_record (Optional[PreActivationRecord]): Record of pre-activation values + """ + self.feature_dataset = feature_dataset + self.autoencoder = autoencoder + self.pre_activation_record = pre_activation_record + + + # load the neighbour cache from the path + if neighbour_cache is not None: + self.neighbour_cache = neighbour_cache + else: + # Dictionary to cache computed neighbour lists + self.neighbour_cache: Dict[str, Dict[int, List[int]]] = {} + + + def _compute_neighbour_list(self, method: str) -> None: + """ + Compute complete neighbour lists using specified method. 
+ + Args: + method (str): One of 'similarity', 'correlation', or 'co-occurrence' + """ + if method == 'similarity': + if self.autoencoder is None: + raise ValueError("Autoencoder is required for similarity-based neighbours") + self.neighbour_cache[method] = self._compute_similarity_neighbours() + + elif method == 'correlation': + if self.autoencoder is None or self.pre_activation_record is None: + raise ValueError("Autoencoder and pre-activation record are required for correlation-based neighbours") + self.neighbour_cache[method] = self._compute_correlation_neighbours() + + elif method == 'co-occurrence': + if self.feature_dataset is None: + raise ValueError("Feature dataset is required for co-occurrence-based neighbours") + self.neighbour_cache[method] = self._compute_cooccurrence_neighbours() + + else: + raise ValueError(f"Unknown method: {method}. Use 'similarity', 'correlation', or 'co-occurrence'") + + def _compute_similarity_neighbours(self) -> Dict[int, List[int]]: + """ + Compute neighbour lists based on weight similarity in the autoencoder. + """ + + # We use the encoder vectors to compute the similarity between features + encoder = self.autoencoder.encoder + + weight_matrix_normalized = encoder.weight / encoder.weight.norm(dim=1, keepdim=True) + + # Compute the similarity between features + similarity_matrix = weight_matrix_normalized.T @ weight_matrix_normalized + + # Get the indices of the top k neighbours for each feature + top_k_indices = torch.topk(similarity_matrix, self.number_of_neighbours, dim=1).indices + + # Return the neighbour lists + return {i: top_k_indices[i].tolist() for i in range(len(top_k_indices))} + + + def _compute_correlation_neighbours(self) -> Dict[int, List[int]]: + """ + Compute neighbour lists based on activation correlation patterns. + """ + + # the preactivation_matrix has the shape (number_of_samples,hidden_dimension) + preactivation_matrix = self.pre_activation_record + + # compute the covariance matrix of the preactivation matrix + covariance_matrix = torch.cov(preactivation_matrix.T) + + # load the encoder + encoder_matrix = self.autoencoder.encoder.weight + + # covariance between the features is u^T * covariance_matrix * u + covariance_between_features = encoder_matrix.T @ covariance_matrix @ encoder_matrix + + # the correlation is then the covariance devided by the product of the standard deviations + + product_of_std = torch.diag(covariance_matrix)**2 + + correlation_matrix = covariance_between_features / product_of_std + + # get the indices of the top k neighbours for each feature + top_k_indices = torch.topk(correlation_matrix, self.number_of_neighbours, dim=1).indices + + # return the neighbour lists + return {i: top_k_indices[i].tolist() for i in range(len(top_k_indices))} + + + def _compute_cooccurrence_neighbours(self) -> Dict[int, List[int]]: + """ + Compute neighbour lists based on feature co-occurrence in the dataset. 
+ """ + # To be implemented + + paths = [] + for buffer in self.feature_dataset.buffers: + paths.append(buffer.tensor_path) + + all_locations = [] + all_activations = [] + for path in paths: + split_data = load_file(path) + first_feature = int(path.split("/")[-1].split("_")[0]) + activations = torch.tensor(split_data["activations"]) + locations = torch.tensor(split_data["locations"].astype(np.int64)) + locations[:,2] = locations[:,2] + first_feature + all_locations.append(locations) + all_activations.append(activations) + + # concatenate the locations and activations + locations = torch.cat(all_locations).cuda() + activations = torch.cat(all_activations).cuda() + n_features = int(torch.max(locations[:,2])) + 1 + + # 1. Get unique values of first 2 dims (i.e. absolute token index) and their counts + # Trick is to use Cantor pairing function to have a bijective mapping between (batch_id, ctx_pos) and a unique 1D index + # Faster than running `torch.unique_consecutive` on the first 2 dims + idx_cantor = (locations[:,0] + locations[:,1]) * (locations[:,0] + locations[:,1] + 1) // 2 + locations[:,1] + unique_idx, idx_counts = torch.unique_consecutive(idx_cantor, return_counts=True) + n_tokens = len(unique_idx) + + # 2. The Cantor indices are not consecutive, so we create sorted ones from the counts + locations_flat = torch.repeat_interleave(torch.arange(n_tokens, device=locations.device), idx_counts) + del idx_cantor,unique_idx,idx_counts + + rows = cp.asarray(locations[:, 2]) + cols = cp.asarray(locations_flat) + data = cp.ones(len(rows)) + sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_features, n_tokens)) + cooc_matrix = sparse_matrix @ sparse_matrix.T + + # Compute Jaccard similarity + def compute_jaccard(cooc_matrix): + self_occurrence = cooc_matrix.diagonal() + jaccard_matrix = cooc_matrix / (self_occurrence[:, None] + self_occurrence - cooc_matrix) + return jaccard_matrix + + del rows, cols, data, sparse_matrix + # Compute Jaccard similarity matrix + jaccard_matrix = compute_jaccard(cooc_matrix) + + # get the indices of the top k neighbours for each feature + top_k_indices = torch.topk(jaccard_matrix, self.number_of_neighbours, dim=1).indices + + # return the neighbour lists + return {i: top_k_indices[i].tolist() for i in range(len(top_k_indices))} + + + + def populate_neighbour_cache(self, methods: List[str]) -> None: + """ + Populate the neighbour cache with the computed neighbour lists + """ + for method in methods: + self._compute_neighbour_list(method) + + + def save_neighbour_cache(self) -> None: + """ + Save the neighbour cache to the path as a json file + """ + with open(self.path, 'w') as f: + json.dump(self.neighbour_cache, f) + + def load_neighbour_cache(self) -> Dict[str, Dict[int, List[int]]]: + """ + Load the neighbour cache from the path as a json file + """ + with open(self.path, 'r') as f: + return json.load(f) + + + From 360b79a19b7d609d8d2528fe2175e1989cffe5c2 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 5 Feb 2025 16:56:11 +0000 Subject: [PATCH 029/132] Update name --- pyproject.toml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d41b7fc3..9865718e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] -name = "sae_auto_interp" +name = "delphi" version = "0.2.0" description = "Automated Interpretability" readme = "README.md" @@ -27,11 +27,11 @@ dependencies = [ ] [tool.pyright] -include = 
["sae_auto_interp*"] +include = ["delphi*"] reportPrivateImportUsage = false [tool.setuptools.packages.find] -include = ["sae_auto_interp*"] +include = ["delphi*"] [tool.ruff] # TODO: Clean up or remove experiments folder. From 6015908370a7e7a14db17546da6f8891e725d4aa Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:07:49 +0000 Subject: [PATCH 030/132] Update config name --- delphi/config.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/delphi/config.py b/delphi/config.py index c2bfef9b..53ce7815 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -54,8 +54,8 @@ class CacheConfig(Serializable): dataset_name: str = "" """Dataset name to use.""" - dataset_row: str = "raw_content" - """Dataset row to use.""" + dataset_column_name: str = "raw_content" + """Dataset column name to use.""" batch_size: int = 32 """Number of sequences to process in a batch.""" From 0d30fe519268bf89eed6b7e67c9beda4994a86ad Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:08:09 +0000 Subject: [PATCH 031/132] Update config name and sae --- delphi/utils.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/delphi/utils.py b/delphi/utils.py index f6368261..88c75de3 100644 --- a/delphi/utils.py +++ b/delphi/utils.py @@ -9,7 +9,7 @@ def load_tokenized_data( dataset_repo: str, dataset_split: str, dataset_name: str = "", - dataset_row: str = "raw_content", + dataset_column_name: str = "raw_content", seed: int = 22, add_bos_token: bool = True, ): @@ -17,12 +17,10 @@ def load_tokenized_data( Load a huggingface dataset, tokenize it, and shuffle. """ from datasets import load_dataset - from sae.data import chunk_and_tokenize + from sparsify.data import chunk_and_tokenize - print(dataset_repo,dataset_name,dataset_split) - data = load_dataset(dataset_repo, name=dataset_name, split=dataset_split) - tokens_ds = chunk_and_tokenize(data, tokenizer, max_seq_len=ctx_len, text_key=dataset_row) + tokens_ds = chunk_and_tokenize(data, tokenizer, max_seq_len=ctx_len, text_key=dataset_column_name) tokens_ds = tokens_ds.shuffle(seed) tokens = cast(TensorType["batch", "seq"], tokens_ds["input_ids"]) From bb9f2a42ccae7c6a8d173515c744a8500216dc1c Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:08:22 +0000 Subject: [PATCH 032/132] idk --- sae_auto_interp/config.py | 74 --------------------------------------- 1 file changed, 74 deletions(-) delete mode 100644 sae_auto_interp/config.py diff --git a/sae_auto_interp/config.py b/sae_auto_interp/config.py deleted file mode 100644 index 3873d798..00000000 --- a/sae_auto_interp/config.py +++ /dev/null @@ -1,74 +0,0 @@ -from dataclasses import dataclass -from typing import Literal - -from simple_parsing import Serializable - - -@dataclass -class ExperimentConfig(Serializable): - - n_examples_train: int = 50 - """Number of examples to sample for training""" - - n_examples_test: int = 50 - """Number of examples to sample for testing""" - - n_quantiles: int = 10 - """Number of quantiles to sample""" - - example_ctx_len: int = 32 - """Length of each example""" - - n_random: int = 50 - """Number of random examples to sample""" - - train_type: Literal["top", "random", "quantiles"] = "quantiles" - """Type of sampler to use for training""" - - test_type: Literal["quantiles", "activation"] = "quantiles" - """Type of sampler to use for testing""" - - - - -@dataclass -class FeatureConfig(Serializable): - width: int = 131072 - """Number of features in the autoencoder""" - - min_examples: int = 200 - """Minimum number of 
examples for a feature to be included""" - - max_examples: int = 10000 - """Maximum number of examples for a feature to included""" - - n_splits: int = 5 - """Number of splits that features were devided into""" - - -@dataclass -class CacheConfig(Serializable): - - dataset_repo: str = "kh4dien/fineweb-100m-sample" - """Dataset repository to use""" - - dataset_split: str = "train" - """Dataset split to use""" - - dataset_name: str = "" - """Dataset name to use""" - - dataset_column_name: str = "text" - """Dataset column name to use""" - - batch_size: int = 32 - """Number of sequences to process in a batch""" - - ctx_len: int = 256 - """Context length of the autoencoder. Each batch is shape (batch_size, ctx_len)""" - - n_tokens: int = 10_000_000 - """Number of tokens to cache""" - - n_splits: int = 5 - """Number of splits to divide .safetensors into""" From 49517e9ea89ed6581782980b2a87b0bf8708a153 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:08:36 +0000 Subject: [PATCH 033/132] Add neighbour transform --- delphi/features/transforms.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) create mode 100644 delphi/features/transforms.py diff --git a/delphi/features/transforms.py b/delphi/features/transforms.py new file mode 100644 index 00000000..07cbcf5a --- /dev/null +++ b/delphi/features/transforms.py @@ -0,0 +1,20 @@ +from typing import Callable, Optional + +import torch +from torchtyping import TensorType + +from .features import FeatureRecord, prepare_examples +from .loader import BufferOutput + +import json +def set_neighbours( + record: FeatureRecord, + neighbours_path: str, + neighbours_type: str, +): + """ + Set the neighbours for the feature record. + """ + with open(neighbours_path, "r") as f: + neighbours = json.load(f) + record.neighbours = neighbours[neighbours_type][record.feature.feature_index] From e6fa6f4489a16c76ab6fdd90f51b8d473251252b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:08:49 +0000 Subject: [PATCH 034/132] Add neighbour calculator --- delphi/features/neighbours.py | 168 ++++++++++++++++++++++++---------- 1 file changed, 120 insertions(+), 48 deletions(-) diff --git a/delphi/features/neighbours.py b/delphi/features/neighbours.py index 2b40a3f2..3ec9170e 100644 --- a/delphi/features/neighbours.py +++ b/delphi/features/neighbours.py @@ -1,8 +1,7 @@ import json from typing import Dict, List, Optional +from tqdm import tqdm -import cupy as cp -import cupyx.scipy.sparse as cusparse import numpy as np import torch from safetensors.numpy import load_file @@ -21,7 +20,7 @@ def __init__( feature_dataset: Optional['FeatureDataset'] = None, autoencoder: Optional["Autoencoder"] = None, - pre_activation_record: Optional['PreActivationRecord'] = None, + residual_stream_record: Optional['ResidualStreamRecord'] = None, number_of_neighbours: int = 10, neighbour_cache: Optional[Dict[str, Dict[int, List[int]]]] = None, ): @@ -32,12 +31,12 @@ def __init__( Args: feature_dataset (Optional[FeatureDataset]): Dataset containing feature activations autoencoder (Optional[Autoencoder]): The trained autoencoder model - pre_activation_record (Optional[PreActivationRecord]): Record of pre-activation values + residual_stream_record (Optional[ResidualStreamRecord]): Record of residual stream values """ self.feature_dataset = feature_dataset self.autoencoder = autoencoder - self.pre_activation_record = pre_activation_record - + self.residual_stream_record = residual_stream_record + self.number_of_neighbours = number_of_neighbours # load the neighbour 
cache from the path if neighbour_cache is not None: @@ -60,8 +59,8 @@ def _compute_neighbour_list(self, method: str) -> None: self.neighbour_cache[method] = self._compute_similarity_neighbours() elif method == 'correlation': - if self.autoencoder is None or self.pre_activation_record is None: - raise ValueError("Autoencoder and pre-activation record are required for correlation-based neighbours") + if self.autoencoder is None or self.residual_stream_record is None: + raise ValueError("Autoencoder and residual stream record are required for correlation-based neighbours") self.neighbour_cache[method] = self._compute_correlation_neighbours() elif method == 'co-occurrence': @@ -76,76 +75,109 @@ def _compute_similarity_neighbours(self) -> Dict[int, List[int]]: """ Compute neighbour lists based on weight similarity in the autoencoder. """ - + print("Computing similarity neighbours") # We use the encoder vectors to compute the similarity between features - encoder = self.autoencoder.encoder + encoder = self.autoencoder.encoder.cuda() weight_matrix_normalized = encoder.weight / encoder.weight.norm(dim=1, keepdim=True) - + wT = weight_matrix_normalized.T # Compute the similarity between features - similarity_matrix = weight_matrix_normalized.T @ weight_matrix_normalized - - # Get the indices of the top k neighbours for each feature - top_k_indices = torch.topk(similarity_matrix, self.number_of_neighbours, dim=1).indices - - # Return the neighbour lists - return {i: top_k_indices[i].tolist() for i in range(len(top_k_indices))} + done = False + batch_size = weight_matrix_normalized.shape[0] + number_features = batch_size + + neighbour_lists = {} + while not done: + try: + + for start in tqdm(range(0,number_features,batch_size)): + rows = wT[start:start+batch_size] + similarity_matrix = weight_matrix_normalized @ rows + top_k_indices = torch.topk(similarity_matrix, self.number_of_neighbours+1, dim=1).indices + neighbour_lists.update({i+start: top_k_indices[i].tolist()[1:] for i in range(len(top_k_indices))}) + del similarity_matrix + torch.cuda.empty_cache() + done = True + except Exception: + batch_size = batch_size // 2 + if batch_size < 2: + raise ValueError("Batch size is too small to compute similarity matrix. You don't have enough memory.") + + + return neighbour_lists def _compute_correlation_neighbours(self) -> Dict[int, List[int]]: """ Compute neighbour lists based on activation correlation patterns. 
""" - - # the preactivation_matrix has the shape (number_of_samples,hidden_dimension) - preactivation_matrix = self.pre_activation_record + print("Computing correlation neighbours") - # compute the covariance matrix of the preactivation matrix - covariance_matrix = torch.cov(preactivation_matrix.T) + # the activation_matrix has the shape (number_of_samples,hidden_dimension) + + activations = torch.tensor(load_file(self.residual_stream_record+".safetensors")["activations"]) - # load the encoder - encoder_matrix = self.autoencoder.encoder.weight + + estimator = CovarianceEstimator(activations.shape[1]) + # batch the activations + batch_size = 10000 + for i in tqdm(range(0,activations.shape[0],batch_size)): + estimator.update(activations[i:i+batch_size]) + + covariance_matrix = estimator.cov().cuda().half() - # covariance between the features is u^T * covariance_matrix * u - covariance_between_features = encoder_matrix.T @ covariance_matrix @ encoder_matrix + # load the encoder + encoder_matrix = self.autoencoder.encoder.weight.cuda().half() - # the correlation is then the covariance devided by the product of the standard deviations + covariance_between_features = torch.zeros((encoder_matrix.shape[0],encoder_matrix.shape[0]),device="cpu") - product_of_std = torch.diag(covariance_matrix)**2 + # do batches of features + batch_size = 1024 + for start in tqdm(range(0,encoder_matrix.shape[0],batch_size)): + end = min(encoder_matrix.shape[0],start+batch_size) + encoder_rows = encoder_matrix[start:end] + + correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T + covariance_between_features[start:end] = correlation.cpu() + # the correlation is then the covariance divided by the product of the standard deviations + diagonal_covariance = torch.diag(covariance_between_features) + product_of_std = torch.sqrt(torch.outer(diagonal_covariance,diagonal_covariance)+1e-6) correlation_matrix = covariance_between_features / product_of_std # get the indices of the top k neighbours for each feature - top_k_indices = torch.topk(correlation_matrix, self.number_of_neighbours, dim=1).indices + top_k_indices = torch.topk(correlation_matrix, self.number_of_neighbours+1, dim=1).indices # return the neighbour lists - return {i: top_k_indices[i].tolist() for i in range(len(top_k_indices))} + return {i: top_k_indices[i].tolist()[1:] for i in range(len(top_k_indices))} def _compute_cooccurrence_neighbours(self) -> Dict[int, List[int]]: """ Compute neighbour lists based on feature co-occurrence in the dataset. + If you run out of memory try reducing the token_batch_size """ - # To be implemented + + + import cupy as cp + import cupyx.scipy.sparse as cusparse + print("Computing co-occurrence neighbours") paths = [] for buffer in self.feature_dataset.buffers: paths.append(buffer.tensor_path) all_locations = [] - all_activations = [] for path in paths: split_data = load_file(path) first_feature = int(path.split("/")[-1].split("_")[0]) - activations = torch.tensor(split_data["activations"]) locations = torch.tensor(split_data["locations"].astype(np.int64)) locations[:,2] = locations[:,2] + first_feature + # compute number of tokens all_locations.append(locations) - all_activations.append(activations) - + # concatenate the locations and activations - locations = torch.cat(all_locations).cuda() - activations = torch.cat(all_activations).cuda() + locations = torch.cat(all_locations) n_features = int(torch.max(locations[:,2])) + 1 # 1. Get unique values of first 2 dims (i.e. 
absolute token index) and their counts @@ -163,23 +195,41 @@ def _compute_cooccurrence_neighbours(self) -> Dict[int, List[int]]: cols = cp.asarray(locations_flat) data = cp.ones(len(rows)) sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_features, n_tokens)) - cooc_matrix = sparse_matrix @ sparse_matrix.T - + token_batch_size = 100_000 + cooc_matrix = cp.zeros((n_features, n_features), dtype=cp.float32) + + sparse_matrix_csc = sparse_matrix.tocsc() + for start in tqdm(range(0, n_tokens, token_batch_size)): + end = min(n_tokens, start + token_batch_size) + # Slice the sparse matrix to get a batch of tokens. + sub_matrix = sparse_matrix_csc[:, start:end] + # Compute the partial co-occurrence matrix for this batch. + partial_cooc = (sub_matrix @ sub_matrix.T).toarray() + cooc_matrix += partial_cooc + + # Free temporary variables. + del rows, cols, data, sparse_matrix, sparse_matrix_csc + + # Compute Jaccard similarity def compute_jaccard(cooc_matrix): self_occurrence = cooc_matrix.diagonal() jaccard_matrix = cooc_matrix / (self_occurrence[:, None] + self_occurrence - cooc_matrix) + # remove the diagonal and keep the upper triangle return jaccard_matrix - del rows, cols, data, sparse_matrix # Compute Jaccard similarity matrix jaccard_matrix = compute_jaccard(cooc_matrix) - # get the indices of the top k neighbours for each feature - top_k_indices = torch.topk(jaccard_matrix, self.number_of_neighbours, dim=1).indices + jaccard_torch = torch.as_tensor(cp.asnumpy(jaccard_matrix)) + # get the indices of the top k neighbours for each feature + top_k_indices = torch.topk(jaccard_torch, self.number_of_neighbours+1, dim=1).indices + del jaccard_matrix,cooc_matrix,jaccard_torch + torch.cuda.empty_cache() + # return the neighbour lists - return {i: top_k_indices[i].tolist() for i in range(len(top_k_indices))} + return {i: top_k_indices[i].tolist()[1:] for i in range(len(top_k_indices))} @@ -191,19 +241,41 @@ def populate_neighbour_cache(self, methods: List[str]) -> None: self._compute_neighbour_list(method) - def save_neighbour_cache(self) -> None: + def save_neighbour_cache(self, path: str) -> None: """ Save the neighbour cache to the path as a json file """ - with open(self.path, 'w') as f: + with open(path, 'w') as f: json.dump(self.neighbour_cache, f) - def load_neighbour_cache(self) -> Dict[str, Dict[int, List[int]]]: + def load_neighbour_cache(self, path: str) -> Dict[str, Dict[int, List[int]]]: """ Load the neighbour cache from the path as a json file """ - with open(self.path, 'r') as f: + with open(path, 'r') as f: return json.load(f) +class CovarianceEstimator: + def __init__(self, n_features, *, device = None): + self.mean = torch.zeros(n_features, device=device) + self.cov_ = torch.zeros(n_features, n_features, device=device) + self.n = 0 + + def update(self, x: torch.Tensor): + n, d = x.shape + assert d == len(self.mean) + + self.n += n + + # Welford's online algorithm + delta = x - self.mean + self.mean.add_(delta.sum(dim=0), alpha=1 / self.n) + delta2 = x - self.mean + + self.cov_.addmm_(delta.mH, delta2) + + def cov(self): + """Return the estimated covariance matrix.""" + return self.cov_ / self.n \ No newline at end of file From 2c77007e2508b9de3b79a8b6c440c41473c42c04 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:09:12 +0000 Subject: [PATCH 035/132] Add neighbour constructor and update feature record --- delphi/features/constructors.py | 119 ++++++++++++++++++++++++++++++-- delphi/features/features.py | 3 +- delphi/features/loader.py | 105 
+++++++++++----------------- 3 files changed, 158 insertions(+), 69 deletions(-) diff --git a/delphi/features/constructors.py b/delphi/features/constructors.py index 29396762..1619e6a3 100644 --- a/delphi/features/constructors.py +++ b/delphi/features/constructors.py @@ -1,6 +1,7 @@ +from typing import Callable, Optional + import torch from torchtyping import TensorType -from typing import Callable, Optional from .features import FeatureRecord, prepare_examples from .loader import BufferOutput @@ -76,7 +77,7 @@ def pool_max_activation_windows( record.examples = prepare_examples(token_windows, activation_windows) def random_activation_windows( - record, + record: FeatureRecord, tokens: TensorType["batch", "seq"], buffer_output: BufferOutput, ctx_len: int, @@ -95,6 +96,7 @@ def random_activation_windows( torch.manual_seed(22) if n_random == 0: return + batch_size = tokens.shape[0] unique_batch_pos = buffer_output.locations[:, 0].unique() @@ -118,6 +120,70 @@ def random_activation_windows( torch.zeros_like(toks), ) +def neighbour_random_activation_windows( + record: FeatureRecord, + tokens: TensorType["batch", "seq"], + buffer_output: BufferOutput, + everything: BufferOutput, + ctx_len: int, + n_random: int, +): + """ + Generate random activation windows and update the feature record. + + Args: + record (FeatureRecord): The feature record to update. + tokens (TensorType["batch", "seq"]): The input tokens. + buffer_output (BufferOutput): The buffer output containing activations and locations. + ctx_len (int): The context length. + n_random (int): The number of random examples to generate. + """ + torch.manual_seed(22) + if n_random == 0: + return + + assert record.neighbours is not None, "Neighbours are not set, add them via a transform" + + batch_size = tokens.shape[0] + + mask = torch.zeros(batch_size, dtype=torch.bool) + + for neighbour in record.neighbours: + # Get the possible batch positions where the neighbour is active + possible_locations = everything.locations[everything.locations[:, 0] == neighbour] + # Get the unique locations + unique_possible_locations = possible_locations.unique(dim=0) + # Set the mask to True for the unique locations + mask[unique_possible_locations[:, 1]] = True + + # Get the unique batch positions where the latent is active + unique_batch_pos_active = buffer_output.locations[:, 0].unique() + # Set the mask to False for the unique locations where the latent is active + mask[unique_batch_pos_active] = False + + available_indices = mask.nonzero().squeeze() + + # TODO:What to do when the latent is active at least once in each batch? 
+ if available_indices.numel() < n_random: + print("No available indices") + record.random_examples = [] + return + else: + # Select the batch positions + selected_indices = available_indices[torch.randint(0,len(available_indices),size=(n_random,))] + # Select the token positions + selected_positions = torch.randint(0, tokens.shape[1] - ctx_len, size=(n_random,)) + + + # Get tokens + toks = tokens[selected_indices, selected_positions : selected_positions + ctx_len] + + record.random_examples = prepare_examples( + toks, + torch.zeros_like(toks), + ) + + def default_constructor( record: FeatureRecord, token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], @@ -156,8 +222,8 @@ def default_constructor( ) pool_max_activation_windows( record, - tokens=tokens, buffer_output=buffer_output, + tokens=tokens, ctx_len=ctx_len, max_examples=max_examples, ) @@ -167,4 +233,49 @@ def default_constructor( buffer_output=buffer_output, n_random=n_random, ctx_len=ctx_len, - ) \ No newline at end of file + ) + +def neighbour_constructor( + record: FeatureRecord, + token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], + buffer_output: BufferOutput, + everything: BufferOutput, + n_random: int, + ctx_len: int, + max_examples: int, +): + """ + Construct feature examples using pool max activation windows and random activation windows from neighbours. + """ + tokens = everything.tokens + if tokens is None: + if token_loader is None: + raise ValueError("Either tokens or token_loader must be provided") + try: + tokens = token_loader() + except TypeError: + raise ValueError( + "Starting with v0.2, `tokens` was renamed to `token_loader`, " + "which must be a callable for lazy loading.\n\n" + "Instead of passing\n" + "` tokens=dataset.tokens`,\n" + "pass\n" + "` token_loader=lambda: dataset.load_tokens()`,\n" + "(assuming `dataset` is a `FeatureDataset` instance)." 
+ ) + pool_max_activation_windows( + record, + buffer_output=buffer_output, + tokens=tokens, + ctx_len=ctx_len, + max_examples=max_examples, + ) + neighbour_random_activation_windows( + record, + tokens=tokens, + buffer_output=buffer_output, + everything=everything, + n_random=n_random, + ctx_len=ctx_len, + ) + \ No newline at end of file diff --git a/delphi/features/features.py b/delphi/features/features.py index ad86fa9b..fb4b70ad 100644 --- a/delphi/features/features.py +++ b/delphi/features/features.py @@ -97,7 +97,8 @@ def __init__( self.examples = [] self.train: list[list[Example]] = [] self.test: list[list[Example]] = [] - + self.neighbours: list[int] = [] + @property def max_activation(self): """ diff --git a/delphi/features/loader.py b/delphi/features/loader.py index 30a6af98..3f247737 100644 --- a/delphi/features/loader.py +++ b/delphi/features/loader.py @@ -145,12 +145,15 @@ def __init__( """ self.cfg = cfg self.buffers = [] + if features is None: self._build(raw_dir, modules) else: self._build_selected(raw_dir, modules, features) + self.everything = self._build_everything(raw_dir, modules) + cache_config_dir = f"{raw_dir}/{modules[0]}/config.json" with open(cache_config_dir, "r") as f: cache_config = json.load(f) @@ -244,76 +247,48 @@ def _build_selected( ) ) - def __len__(self): - """Return the number of buffers in the dataset.""" - return len(self.buffers) - - def load( - self, - collate: bool = False, - constructor: Optional[Callable] = None, - sampler: Optional[Callable] = None, - transform: Optional[Callable] = None, - ): + def _build_everything(self, raw_dir: str, modules: List[str]): """ - Load and process feature records from the dataset. - - Args: - collate (bool): Whether to collate all records into a single list. - constructor (Optional[Callable]): Function to construct feature records. - sampler (Optional[Callable]): Function to sample from feature records. - transform (Optional[Callable]): Function to transform feature records. - - Returns: - Union[List[FeatureRecord], Generator]: Processed feature records. + Build a BufferOutput with the locations and activations of all features. """ - def _process(buffer_output: BufferOutput): - record = FeatureRecord(buffer_output.feature) - if constructor is not None: - constructor(record=record, buffer_output=buffer_output) - - if sampler is not None: - sampler(record) - - if transform is not None: - transform(record) - - return record - - def _worker(buffer): - return [ - _process(data) - for data in tqdm(buffer, desc=f"Loading {buffer.module_path}") - if data is not None - ] + edges = self._edges() + everything = {} + all_locations = [] + all_activations = [] + tokens = None + for module in modules: + for start, end in zip(edges[:-1], edges[1:]): + path = f"{raw_dir}/{module}/{start}_{end-1}.safetensors" + split_data = load_file(path) + first_feature = int(path.split("/")[-1].split("_")[0]) + activations = torch.tensor(split_data["activations"]) + locations = torch.tensor(split_data["locations"].astype(np.int64)) + if tokens is None: + if "tokens" in split_data: + tokens = torch.tensor(split_data["tokens"].astype(np.int64)) + else: + tokens = None + + locations[:,2] = locations[:,2] + first_feature + all_locations.append(locations) + all_activations.append(activations) - return self._load(collate, _worker) - - def _load(self, collate: bool, _worker: Callable): - """ - Internal method to load feature records. 
+ all_locations = torch.cat(all_locations) + all_activations = torch.cat(all_activations) + everything[module] = BufferOutput(-1, all_locations, all_activations, tokens) - Args: - collate (bool): Whether to collate all records into a single list. - _worker (Callable): Function to process each buffer. + - Returns: - Union[List[FeatureRecord], Generator]: Processed feature records. - """ - if collate: - all_records = [] - for buffer in self.buffers: - all_records.extend(_worker(buffer)) - return all_records - else: - for buffer in self.buffers: - yield _worker(buffer) - + def __len__(self): + """Return the number of buffers in the dataset.""" + return len(self.buffers) + def reset(self): """Reset all buffers in the dataset.""" for buffer in self.buffers: buffer.reset() + class FeatureLoader: """ Loader class for processing feature records from a FeatureDataset. @@ -379,12 +354,13 @@ async def _aprocess_feature(self, buffer_output: BufferOutput): Optional[FeatureRecord]: Processed feature record or None. """ record = FeatureRecord(buffer_output.feature) + if self.transform is not None: + self.transform(record) if self.constructor is not None: self.constructor(record=record, buffer_output=buffer_output) if self.sampler is not None: self.sampler(record) - if self.transform is not None: - self.transform(record) + return record def __iter__(self): @@ -425,10 +401,11 @@ def _process_feature(self, buffer_output: BufferOutput): Optional[FeatureRecord]: Processed feature record or None. """ record = FeatureRecord(buffer_output.feature) + if self.transform is not None: + self.transform(record) if self.constructor is not None: self.constructor(record=record, buffer_output=buffer_output) if self.sampler is not None: self.sampler(record) - if self.transform is not None: - self.transform(record) + return record \ No newline at end of file From 16bb4b01763181e896957fa995ec3ee8e629266b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:09:26 +0000 Subject: [PATCH 036/132] Remove unused method --- delphi/clients/client.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/delphi/clients/client.py b/delphi/clients/client.py index c6fc409f..0d6b54b8 100644 --- a/delphi/clients/client.py +++ b/delphi/clients/client.py @@ -18,7 +18,3 @@ def __init__(self, model: str): async def generate(self, prompt: Union[str, List[Dict[str, str]]], **kwargs) -> Response: pass - @abstractmethod - async def process_response(self, raw_response: Any) -> Response: - pass - From f3781b5652fb44b282e0f894767fa7c09fbe8731 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:09:50 +0000 Subject: [PATCH 037/132] Make an abstract cache, and an activation cache --- delphi/features/cache.py | 333 +++++++++++++++++++++++++++++---------- 1 file changed, 247 insertions(+), 86 deletions(-) diff --git a/delphi/features/cache.py b/delphi/features/cache.py index d1abb260..a09070b7 100644 --- a/delphi/features/cache.py +++ b/delphi/features/cache.py @@ -5,16 +5,100 @@ import numpy as np import torch -from safetensors.numpy import save_file +from safetensors.numpy import save_file, load_file from torchtyping import TensorType from tqdm import tqdm from delphi.config import CacheConfig -class Cache: + +class BaseCache: + """ + Base class for caching activations. + """ + def __init__( + self, + model, + submodule_dict: Dict, + batch_size: int, + filters: Dict[str, TensorType["indices"]] = None, + ): + """ + Initialize the FeatureCache. + + Args: + model: The model to cache features for. 
+ submodule_dict (Dict): Dictionary of submodules to cache. + batch_size (int): Size of batches for processing. + filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features. + """ + self.model = model + self.submodule_dict = submodule_dict + + self.batch_size = batch_size + + if filters is not None: + self._filter_submodules(filters) + + def _filter_submodules(self, filters: Dict[str, TensorType["indices"]]): + """ + Filter submodules based on the provided filters. + + Args: + filters (Dict[str, TensorType["indices"]]): Filters for selecting specific features. + """ + filtered_submodules = {} + for module_path in self.submodule_dict.keys(): + if module_path in filters: + filtered_submodules[module_path] = self.submodule_dict[module_path] + self.submodule_dict = filtered_submodules + + def _load_token_batches( + self, n_tokens: int, tokens: TensorType["batch", "sequence"] + ): + """ + Load and prepare token batches for processing. + + Args: + n_tokens (int): Total number of tokens to process. + tokens (TensorType["batch", "sequence"]): Input tokens. + + Returns: + List[torch.Tensor]: List of token batches. + """ + max_batches = n_tokens // tokens.shape[1] + tokens = tokens[:max_batches] + + n_mini_batches = len(tokens) // self.batch_size + + token_batches = [ + tokens[self.batch_size * i : self.batch_size * (i + 1), :] + for i in range(n_mini_batches) + ] + + return token_batches + + def _generate_split_indices(self, n_splits): + """ + Generate indices for splitting the feature space. + + Args: + n_splits (int): Number of splits to generate. + + Returns: + List[Tuple[int, int]]: List of start and end indices for each split. + """ + boundaries = torch.linspace(0, self.width, steps=n_splits + 1).long() + + # Adjust end by one + return list(zip(boundaries[:-1], boundaries[1:] - 1)) + + +#TODO: I don't like this class name. +class LatentAtivationBuffer: """ - The Cache class stores feature locations and activations for modules. + The LatentAtivationBuffer class stores feature locations and activations for modules. It provides methods for adding, saving, and retrieving non-zero activations. """ @@ -50,7 +134,7 @@ def add( batch_number (int): Current batch number. module_path (str): Path of the module. """ - feature_locations, feature_activations = self.get_nonzeros(latents, module_path) + feature_locations, feature_activations = self._get_nonzeros(latents, module_path) feature_locations = feature_locations.cpu() feature_activations = feature_activations.cpu() tokens = tokens.cpu() @@ -78,7 +162,7 @@ def save(self): self.tokens[module_path], dim=0 ) - def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]): + def _get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]): """ Get non-zero activations for large batches that exceed int32 max value. 
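The _generate_split_indices helper moved into BaseCache above partitions the feature axis with linspace into inclusive, non-overlapping ranges. A toy sketch of the behaviour (small width for readability, not the real latent dimension):

import torch

def generate_split_indices(width, n_splits):
    boundaries = torch.linspace(0, width, steps=n_splits + 1).long()
    # Shift each end down by one so ranges are inclusive and do not overlap.
    return list(zip(boundaries[:-1].tolist(), (boundaries[1:] - 1).tolist()))

print(generate_split_indices(100, 4))  # [(0, 24), (25, 49), (50, 74), (75, 99)]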
@@ -110,7 +194,7 @@ def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]): nonzero_feature_activations = torch.cat(nonzero_feature_activations, dim=0) return nonzero_feature_locations, nonzero_feature_activations - def get_nonzeros( + def _get_nonzeros( self, latents: TensorType["batch", "seq", "feature"], module_path: str ): """ @@ -125,7 +209,7 @@ def get_nonzeros( """ size = latents.shape[1] * latents.shape[0] * latents.shape[2] if size > torch.iinfo(torch.int32).max: - nonzero_feature_locations, nonzero_feature_activations = self.get_nonzeros_batch(latents) + nonzero_feature_locations, nonzero_feature_activations = self._get_nonzeros_batch(latents) else: nonzero_feature_locations = torch.nonzero(latents.abs() > 1e-5) nonzero_feature_activations = latents[latents.abs() > 1e-5] @@ -142,7 +226,8 @@ def get_nonzeros( return nonzero_feature_locations[mask], nonzero_feature_activations[mask] -class FeatureCache: + +class FeatureCache(BaseCache): """ FeatureCache manages the caching of feature activations for a model. It handles the process of running the model, storing activations, and saving them to disk. @@ -164,54 +249,11 @@ def __init__( batch_size (int): Size of batches for processing. filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features. """ - self.model = model - self.submodule_dict = submodule_dict - - self.batch_size = batch_size + super().__init__(model, submodule_dict, batch_size, filters) self.width = list(submodule_dict.values())[0].ae.width - self.cache = Cache(filters, batch_size=batch_size) - if filters is not None: - self.filter_submodules(filters) - - - def load_token_batches( - self, n_tokens: int, tokens: TensorType["batch", "sequence"] - ): - """ - Load and prepare token batches for processing. - - Args: - n_tokens (int): Total number of tokens to process. - tokens (TensorType["batch", "sequence"]): Input tokens. - - Returns: - List[torch.Tensor]: List of token batches. - """ - max_batches = n_tokens // tokens.shape[1] - tokens = tokens[:max_batches] - - n_mini_batches = len(tokens) // self.batch_size - - token_batches = [ - tokens[self.batch_size * i : self.batch_size * (i + 1), :] - for i in range(n_mini_batches) - ] - - return token_batches - - def filter_submodules(self, filters: Dict[str, TensorType["indices"]]): - """ - Filter submodules based on the provided filters. - - Args: - filters (Dict[str, TensorType["indices"]]): Filters for selecting specific features. - """ - filtered_submodules = {} - for module_path in self.submodule_dict.keys(): - if module_path in filters: - filtered_submodules[module_path] = self.submodule_dict[module_path] - self.submodule_dict = filtered_submodules + self.buffer = LatentAtivationBuffer(filters, batch_size=batch_size) + def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): """ @@ -221,7 +263,7 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): n_tokens (int): Total number of tokens to process. tokens (TensorType["batch", "seq"]): Input tokens. 
""" - token_batches = self.load_token_batches(n_tokens, tokens) + token_batches = self._load_token_batches(n_tokens, tokens) total_tokens = 0 total_batches = len(token_batches) @@ -232,14 +274,14 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): total_tokens += tokens_per_batch with torch.no_grad(): - buffer = {} + activations_buffer = {} with self.model.trace(batch): for module_path, submodule in self.submodule_dict.items(): - buffer[module_path] = submodule.ae.output.save() - for module_path, latents in buffer.items(): - self.cache.add(latents, batch, batch_number, module_path) + activations_buffer[module_path] = submodule.ae.output.save() + for module_path, latents in activations_buffer.items(): + self.buffer.add(latents, batch, batch_number, module_path) - del buffer + del activations_buffer torch.cuda.empty_cache() # Update the progress bar @@ -247,7 +289,7 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): pbar.set_postfix({"Total Tokens": f"{total_tokens:,}"}) print(f"Total tokens processed: {total_tokens:,}") - self.cache.save() + self.buffer.save() def save(self, save_dir, save_tokens: bool = True): """ @@ -257,33 +299,18 @@ def save(self, save_dir, save_tokens: bool = True): save_dir (str): Directory to save the features. save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True. """ - for module_path in self.cache.feature_locations.keys(): + for module_path in self.buffer.feature_locations.keys(): output_file = f"{save_dir}/{module_path}.safetensors" data = { - "locations": self.cache.feature_locations[module_path], - "activations": self.cache.feature_activations[module_path], + "locations": self.buffer.feature_locations[module_path], + "activations": self.buffer.feature_activations[module_path], } if save_tokens: - data["tokens"] = self.cache.tokens[module_path] + data["tokens"] = self.buffer.tokens[module_path] save_file(data, output_file) - def _generate_split_indices(self, n_splits): - """ - Generate indices for splitting the feature space. - - Args: - n_splits (int): Number of splits to generate. - - Returns: - List[Tuple[int, int]]: List of start and end indices for each split. - """ - boundaries = torch.linspace(0, self.width, steps=n_splits + 1).long() - - # Adjust end by one - return list(zip(boundaries[:-1], boundaries[1:] - 1)) - def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True): """ Save the cached features in splits. @@ -294,10 +321,10 @@ def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True): save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True. """ split_indices = self._generate_split_indices(n_splits) - for module_path in self.cache.feature_locations.keys(): - feature_locations = self.cache.feature_locations[module_path] - feature_activations = self.cache.feature_activations[module_path] - tokens = self.cache.tokens[module_path].numpy() + for module_path in self.buffer.feature_locations.keys(): + feature_locations = self.buffer.feature_locations[module_path] + feature_activations = self.buffer.feature_activations[module_path] + tokens = self.buffer.tokens[module_path].numpy() feature_indices = feature_locations[:, 2] @@ -339,9 +366,143 @@ def save_config(self, save_dir: str, cfg: CacheConfig, model_name: str): cfg (CacheConfig): Configuration object. model_name (str): Name of the model. 
""" - for module_path in self.cache.feature_locations.keys(): + for module_path in self.buffer.feature_locations.keys(): config_file = f"{save_dir}/{module_path}/config.json" with open(config_file, "w") as f: config_dict = cfg.to_dict() config_dict["model_name"] = model_name - json.dump(config_dict, f) \ No newline at end of file + json.dump(config_dict, f) + +# TODO: This looks like duplicate code + +class ActivationBuffer: + """ + The ActivationBuffer class stores activations for modules. + It provides methods for adding, saving, and retrieving activations. + """ + + def __init__( + self, + ): + """ + Initialize the Buffer. + """ + self.activations = defaultdict(list) + + def add( + self, + activations: TensorType["batch", "sequence", "feature"], + module_path: str, + ): + """ + Add the activations from a module to the cache. + + Args: + activations (TensorType["batch", "sequence", "hidden_dimension"]): Activations. + module_path (str): Path of the module. + """ + activations = activations.reshape(-1, activations.shape[2]).cpu() + + self.activations[module_path].append(activations) + + def save(self): + """ + Concatenate the pre-activations for all modules. + """ + for module_path in self.pre_activations.keys(): + self.pre_activations[module_path] = torch.cat( + self.pre_activations[module_path], dim=0 + ) + + + +class ResidualStreamCache(BaseCache): + """ + ResidualStreamCache manages the caching of residual stream of a model. + It handles the process of running the model, storing activations, and saving them to disk. + """ + + def __init__( + self, + model, + submodule_dict: Dict, + batch_size: int, + ): + """ + Initialize the ResidualStreamCache. + + Args: + model: The model to cache features for. + submodule_dict (Dict): Dictionary of submodules to cache. + batch_size (int): Size of batches for processing. + filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features. + """ + super().__init__(model, submodule_dict, batch_size, None) + + self.buffer = ActivationBuffer() + + + def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): + """ + Run the feature caching process. + + Args: + n_tokens (int): Total number of tokens to process. + tokens (TensorType["batch", "seq"]): Input tokens. + """ + token_batches = self._load_token_batches(n_tokens, tokens) + + total_tokens = 0 + total_batches = len(token_batches) + tokens_per_batch = token_batches[0].numel() + + with tqdm(total=total_batches, desc="Caching features") as pbar: + for batch in token_batches: + total_tokens += tokens_per_batch + + with torch.no_grad(): + activations_buffer = {} + with self.model.trace(batch): + for module_path, submodule in self.submodule_dict.items(): + if "input" in module_path: + activations_buffer[module_path] = submodule.input.save() + else: + activations_buffer[module_path] = submodule.output.save() + for module_path, pre_activations in activations_buffer.items(): + self.buffer.add(pre_activations, module_path) + + del activations_buffer + torch.cuda.empty_cache() + + # Update the progress bar + pbar.update(1) + pbar.set_postfix({"Total Tokens": f"{total_tokens:,}"}) + + print(f"Total tokens processed: {total_tokens:,}") + self.buffer.save() + + def save(self, save_dir): + """ + Save the cached features to disk. + + Args: + save_dir (str): Directory to save the features. + save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True. 
+ """ + for module_path in self.buffer.activations.keys(): + output_file = f"{save_dir}/{module_path}.safetensors" + + data = { + "activations": self.buffer.activations[module_path].half().numpy(), + } + save_file(data, output_file) + + + def load(self, load_dir: str): + """ + Load the cached features from disk. + """ + for module_path in self.buffer.activations.keys(): + input_file = f"{load_dir}/{module_path}.safetensors" + data = load_file(input_file) + self.buffer.activations[module_path] = data["activations"] From cc7d5d6728358bb3ac25fa3c58b2dc0cd3e973c8 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 6 Feb 2025 18:10:02 +0000 Subject: [PATCH 038/132] small update to autoencoders --- delphi/autoencoders/Custom/gemmascope.py | 14 ++++----- delphi/autoencoders/eleuther.py | 36 +++++++++++++++--------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/delphi/autoencoders/Custom/gemmascope.py b/delphi/autoencoders/Custom/gemmascope.py index efccbcf8..1a09036a 100644 --- a/delphi/autoencoders/Custom/gemmascope.py +++ b/delphi/autoencoders/Custom/gemmascope.py @@ -31,16 +31,12 @@ def forward(self, acts): recon = self.decode(acts) return recon - @classmethod -<<<<<<<< HEAD:sae_auto_interp/autoencoders/Custom/gemmascope.py - def from_pretrained(cls, model_name_or_path,position,device): -======== - def from_pretrained(cls, path: str, type: str, device: str) -> nn.Module: ->>>>>>>> 30d5e27537e1c108a9dda37e87e51dda9bfa4206:delphi/autoencoders/DeepMind/model.py +@classmethod +def from_pretrained(cls, path: str, type: str, device: str) -> nn.Module: path_to_params = hf_hub_download( - repo_id=model_name_or_path, - filename=f"{position}/params.npz", - force_download=False, + repo_id=path, + filename=f"{type}/params.npz", + force_download=False, ) params = np.load(path_to_params) pt_params = {k: torch.from_numpy(v) for k, v in params.items()} diff --git a/delphi/autoencoders/eleuther.py b/delphi/autoencoders/eleuther.py index dc4bbab8..4ad37c32 100644 --- a/delphi/autoencoders/eleuther.py +++ b/delphi/autoencoders/eleuther.py @@ -1,10 +1,12 @@ from functools import partial, reduce -from typing import List, Any, Tuple, Optional, Dict +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + import torch from nnsight import LanguageModel from sparsify import Sae -from pathlib import Path -from .OpenAI.model import ACTIVATIONS_CLASSES, TopK + +from .Custom.openai import ACTIVATIONS_CLASSES, TopK from .wrapper import AutoencoderLatents DEVICE = "cuda:0" @@ -15,6 +17,7 @@ def load_eai_autoencoders( ae_layers: List[int], weight_dir: str, module: str, + transcoder: bool = False, randomize: bool = False, seed: int = 42, k: Optional[int] = None @@ -63,21 +66,22 @@ def _forward(sae, k,x): trained_k = sae.cfg.k topk = TopK(trained_k, postact_fn=ACTIVATIONS_CLASSES["Identity"]()) return topk(encoded) - - if "llama" in weight_dir: + if "pythia" in weight_dir: if module == "res": - submodule = model.model.layers[layer] + submodule = model.gpt_neox.layers[layer] else: - submodule = model.model.layers[layer].mlp + submodule = model.gpt_neox.layers[layer].mlp + elif "gpt2" in weight_dir: submodule = model.transformer.h[layer] else: if module == "res": - submodule = model.gpt_neox.layers[layer] + submodule = model.model.layers[layer] else: - submodule = model.gpt_neox.layers[layer].mlp + submodule = model.model.layers[layer].mlp + submodule.ae = AutoencoderLatents( - sae, partial(_forward, sae, k), width=sae.encoder.weight.shape[0] + sae, partial(_forward, sae, k), 
width=sae.encoder.weight.shape[0],hookpoint=submodule.path ) submodules[submodule.path] = submodule @@ -85,14 +89,18 @@ def _forward(sae, k,x): with model.edit("") as edited: for path, submodule in submodules.items(): if "embed" not in path and "mlp" not in path: - acts = submodule.output[0] + if transcoder: + acts = submodule.input[0] + else: + acts = submodule.output[0] else: - acts = submodule.output + if transcoder: + acts = submodule.input + else: + acts = submodule.output submodule.ae(acts, hook=True) return submodules, edited - - def resolve_path(model, path_segments: List[str]) -> List[str] | None: """Attempt to resolve the path segments to the model in the case where it has been wrapped (e.g. by a LanguageModel, causal model, or classifier).""" From da8fa2a2b847c24b3b7118c555ce7911cdfe59e9 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 15:53:11 +0000 Subject: [PATCH 039/132] Handle distance in detection and fuzzing --- delphi/scorers/classifier/classifier.py | 3 +-- delphi/scorers/classifier/detection.py | 29 +++++++++++++++++++------ delphi/scorers/classifier/fuzz.py | 29 +++++++++++++++++++------ 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index 0f4dd793..3c454390 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -13,7 +13,6 @@ from ..scorer import Scorer, ScorerResult from .sample import ClassifierOutput, Sample - class Classifier(Scorer): def __init__( self, @@ -39,8 +38,8 @@ async def __call__( record: FeatureRecord, ) -> list[ClassifierOutput]: samples = self._prepare(record) - random.shuffle(samples) + samples = self._batch(samples) results = await self._query( record.explanation, diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 97019d07..89f9ff7d 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -1,7 +1,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from ...clients.client import Client -from ...features import FeatureRecord +from ...features import FeatureRecord, Example from .classifier import Classifier from .prompts.detection_prompt import prompt from .sample import Sample, examples_to_samples @@ -35,12 +35,27 @@ def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: Prepare and shuffle a list of samples for classification. 
""" - samples = examples_to_samples( - record.random_examples, - distance=-1, - ground_truth=False, - tokenizer=self.tokenizer, - ) + # check if random_examples is a list of lists or a list of examples + if isinstance(record.random_examples[0], tuple): + # Here we are using neighbours + samples = [] + for i, (examples, neighbour) in enumerate(record.random_examples): + samples.extend( + examples_to_samples( + examples, + distance=-neighbour.distance, + ground_truth=False, + tokenizer=self.tokenizer, + ) + ) + elif isinstance(record.random_examples[0], Example): + # This is if we dont use neighbours + samples = examples_to_samples( + record.random_examples, + distance=-1, + ground_truth=False, + tokenizer=self.tokenizer, + ) for i, examples in enumerate(record.test): samples.extend( diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index f4794ee2..caec1408 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -66,14 +66,29 @@ def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: all_examples.extend(examples_chunk) n_incorrect = self.mean_n_activations_ceil(all_examples) + if isinstance(record.random_examples[0], tuple): + # Here we are using neighbours + samples = [] + for i, (examples, neighbour) in enumerate(record.random_examples): + samples.extend( + examples_to_samples( + examples, + distance=-neighbour.distance, + ground_truth=False, + n_incorrect=n_incorrect, + **defaults, + ) + ) + elif isinstance(record.random_examples[0], Example): + # This is if we dont use neighbours + samples = examples_to_samples( + record.random_examples, + distance=-1, + ground_truth=False, + n_incorrect=n_incorrect, + **defaults, + ) - samples = examples_to_samples( - record.extra_examples, - distance=-1, - ground_truth=False, - n_incorrect=n_incorrect, - **defaults, - ) for i, examples in enumerate(record.test): samples.extend( From c3d1d972e64847693f9131e4f19769e34f94ecb0 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 15:53:45 +0000 Subject: [PATCH 040/132] Correctly create neighbours in transform --- delphi/features/transforms.py | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/delphi/features/transforms.py b/delphi/features/transforms.py index 07cbcf5a..b660f451 100644 --- a/delphi/features/transforms.py +++ b/delphi/features/transforms.py @@ -2,19 +2,31 @@ import torch from torchtyping import TensorType - +from dataclasses import dataclass from .features import FeatureRecord, prepare_examples from .loader import BufferOutput import json + + +@dataclass +class Neighbour: + distance: float + feature_index: int + def set_neighbours( record: FeatureRecord, - neighbours_path: str, - neighbours_type: str, + neighbours: dict[int, list[tuple[float, int]]], + threshold: float, ): """ Set the neighbours for the feature record. 
""" - with open(neighbours_path, "r") as f: - neighbours = json.load(f) - record.neighbours = neighbours[neighbours_type][record.feature.feature_index] + + neighbours = neighbours[str(record.feature.feature_index)] + + # Each element in neighbours is a tuple of (distance,feature_index) + # We want to keep only the ones with a distance less than the threshold + neighbours = [neighbour for neighbour in neighbours if neighbour[0] > threshold] + + record.neighbours = [Neighbour(distance=neighbour[0], feature_index=neighbour[1]) for neighbour in neighbours] From c29c2cb04a0c61f082b9fceed062835197cd6b14 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 15:54:43 +0000 Subject: [PATCH 041/132] Return neighbour distance --- delphi/features/neighbours.py | 50 +++++++++++++++++++++-------------- 1 file changed, 30 insertions(+), 20 deletions(-) diff --git a/delphi/features/neighbours.py b/delphi/features/neighbours.py index 3ec9170e..b93c1e7b 100644 --- a/delphi/features/neighbours.py +++ b/delphi/features/neighbours.py @@ -1,5 +1,5 @@ import json -from typing import Dict, List, Optional +from typing import Optional from tqdm import tqdm import numpy as np @@ -22,7 +22,7 @@ def __init__( autoencoder: Optional["Autoencoder"] = None, residual_stream_record: Optional['ResidualStreamRecord'] = None, number_of_neighbours: int = 10, - neighbour_cache: Optional[Dict[str, Dict[int, List[int]]]] = None, + neighbour_cache: Optional[dict[str, dict[int, list[tuple[int,float]]]]] = None, ): """ @@ -42,8 +42,8 @@ def __init__( if neighbour_cache is not None: self.neighbour_cache = neighbour_cache else: - # Dictionary to cache computed neighbour lists - self.neighbour_cache: Dict[str, Dict[int, List[int]]] = {} + # dictionary to cache computed neighbour lists + self.neighbour_cache: dict[str, dict[int, list[int]]] = {} def _compute_neighbour_list(self, method: str) -> None: @@ -53,11 +53,14 @@ def _compute_neighbour_list(self, method: str) -> None: Args: method (str): One of 'similarity', 'correlation', or 'co-occurrence' """ - if method == 'similarity': + if method == 'similarity_encoder': if self.autoencoder is None: raise ValueError("Autoencoder is required for similarity-based neighbours") - self.neighbour_cache[method] = self._compute_similarity_neighbours() - + self.neighbour_cache[method] = self._compute_similarity_neighbours("encoder") + elif method == 'similarity_decoder': + if self.autoencoder is None: + raise ValueError("Autoencoder is required for similarity-based neighbours") + self.neighbour_cache[method] = self._compute_similarity_neighbours("decoder") elif method == 'correlation': if self.autoencoder is None or self.residual_stream_record is None: raise ValueError("Autoencoder and residual stream record are required for correlation-based neighbours") @@ -71,15 +74,22 @@ def _compute_neighbour_list(self, method: str) -> None: else: raise ValueError(f"Unknown method: {method}. Use 'similarity', 'correlation', or 'co-occurrence'") - def _compute_similarity_neighbours(self) -> Dict[int, List[int]]: + def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: """ Compute neighbour lists based on weight similarity in the autoencoder. 
""" print("Computing similarity neighbours") # We use the encoder vectors to compute the similarity between features - encoder = self.autoencoder.encoder.cuda() + if method == "encoder": + encoder = self.autoencoder.encoder.cuda() + weight_matrix_normalized = encoder.weight.data / encoder.weight.data.norm(dim=1, keepdim=True) + + elif method == "decoder": + decoder = self.autoencoder.W_dec.cuda() + weight_matrix_normalized = decoder.data / decoder.data.norm(dim=1, keepdim=True) + else: + raise ValueError(f"Unknown method: {method}. Use 'encoder' or 'decoder'") - weight_matrix_normalized = encoder.weight / encoder.weight.norm(dim=1, keepdim=True) wT = weight_matrix_normalized.T # Compute the similarity between features done = False @@ -93,8 +103,8 @@ def _compute_similarity_neighbours(self) -> Dict[int, List[int]]: for start in tqdm(range(0,number_features,batch_size)): rows = wT[start:start+batch_size] similarity_matrix = weight_matrix_normalized @ rows - top_k_indices = torch.topk(similarity_matrix, self.number_of_neighbours+1, dim=1).indices - neighbour_lists.update({i+start: top_k_indices[i].tolist()[1:] for i in range(len(top_k_indices))}) + indices,values = torch.topk(similarity_matrix, self.number_of_neighbours+1, dim=1) + neighbour_lists.update({i+start: list(zip(indices[i].tolist()[1:],values[i].tolist()[1:])) for i in range(len(indices))}) del similarity_matrix torch.cuda.empty_cache() done = True @@ -107,7 +117,7 @@ def _compute_similarity_neighbours(self) -> Dict[int, List[int]]: return neighbour_lists - def _compute_correlation_neighbours(self) -> Dict[int, List[int]]: + def _compute_correlation_neighbours(self) -> dict[int, list[int]]: """ Compute neighbour lists based on activation correlation patterns. """ @@ -146,13 +156,13 @@ def _compute_correlation_neighbours(self) -> Dict[int, List[int]]: correlation_matrix = covariance_between_features / product_of_std # get the indices of the top k neighbours for each feature - top_k_indices = torch.topk(correlation_matrix, self.number_of_neighbours+1, dim=1).indices + indices,values = torch.topk(correlation_matrix, self.number_of_neighbours+1, dim=1) # return the neighbour lists - return {i: top_k_indices[i].tolist()[1:] for i in range(len(top_k_indices))} + return {i: list(zip(indices[i].tolist()[1:],values[i].tolist()[1:])) for i in range(len(indices))} - def _compute_cooccurrence_neighbours(self) -> Dict[int, List[int]]: + def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: """ Compute neighbour lists based on feature co-occurrence in the dataset. 
If you run out of memory try reducing the token_batch_size @@ -224,16 +234,16 @@ def compute_jaccard(cooc_matrix): jaccard_torch = torch.as_tensor(cp.asnumpy(jaccard_matrix)) # get the indices of the top k neighbours for each feature - top_k_indices = torch.topk(jaccard_torch, self.number_of_neighbours+1, dim=1).indices + top_k_indices,values = torch.topk(jaccard_torch, self.number_of_neighbours+1, dim=1) del jaccard_matrix,cooc_matrix,jaccard_torch torch.cuda.empty_cache() # return the neighbour lists - return {i: top_k_indices[i].tolist()[1:] for i in range(len(top_k_indices))} + return {i: list(zip(top_k_indices[i].tolist()[1:],values[i].tolist()[1:])) for i in range(len(top_k_indices))} - def populate_neighbour_cache(self, methods: List[str]) -> None: + def populate_neighbour_cache(self, methods: list[str]) -> None: """ Populate the neighbour cache with the computed neighbour lists """ @@ -248,7 +258,7 @@ def save_neighbour_cache(self, path: str) -> None: with open(path, 'w') as f: json.dump(self.neighbour_cache, f) - def load_neighbour_cache(self, path: str) -> Dict[str, Dict[int, List[int]]]: + def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: """ Load the neighbour cache from the path as a json file """ From 7a56a5d4cb1e4c7eedcd58b36cc4ce74be8b6d46 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 15:55:16 +0000 Subject: [PATCH 042/132] Add all data temporary fix --- delphi/features/loader.py | 118 ++++++++++---------------------------- 1 file changed, 29 insertions(+), 89 deletions(-) diff --git a/delphi/features/loader.py b/delphi/features/loader.py index 3f247737..357a2ea9 100644 --- a/delphi/features/loader.py +++ b/delphi/features/loader.py @@ -18,6 +18,13 @@ from ..features.features import Feature, FeatureRecord +class AllData(NamedTuple): + features: list[TensorType["features"]] + locations: list[TensorType["locations", 2]] + activations: list[TensorType["locations"]] + tokens: TensorType["tokens"] + + class BufferOutput(NamedTuple): """ Represents the output of a TensorBuffer. @@ -32,8 +39,7 @@ class BufferOutput(NamedTuple): locations: TensorType["locations", 2] activations: TensorType["locations"] tokens: TensorType["tokens"] - - + class TensorBuffer: """ Lazy loading buffer for cached splits. @@ -112,8 +118,6 @@ def load(self): return features, split_locations, split_activations, tokens - - def reset(self): """Reset the buffer state.""" self.start = 0 @@ -152,7 +156,7 @@ def __init__( else: self._build_selected(raw_dir, modules, features) - self.everything = self._build_everything(raw_dir, modules) + self.all_data = self._build_everything(raw_dir, modules) cache_config_dir = f"{raw_dir}/{modules[0]}/config.json" with open(cache_config_dir, "r") as f: @@ -252,32 +256,27 @@ def _build_everything(self, raw_dir: str, modules: List[str]): Build a BufferOutput with the locations and activations of all features. 
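The Jaccard step inside _compute_cooccurrence_neighbours reduces to a single array expression over the co-occurrence counts. The package runs it on cupy sparse matrices in token batches; this toy dense torch version is for illustration only:

import torch

def jaccard_from_cooccurrence(cooc):
    # cooc[i, j] counts tokens where features i and j are both active;
    # the diagonal holds each feature's own activation count.
    self_occurrence = cooc.diagonal()
    return cooc / (self_occurrence[:, None] + self_occurrence - cooc)

cooc = torch.tensor([[10., 4., 0.],
                     [ 4., 8., 2.],
                     [ 0., 2., 6.]])
print(jaccard_from_cooccurrence(cooc))
# Diagonal entries are 1; off-diagonal entries are |A ∩ B| / |A ∪ B|.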
""" edges = self._edges() - everything = {} + all_features = [] all_locations = [] all_activations = [] tokens = None + all_data = {} for module in modules: for start, end in zip(edges[:-1], edges[1:]): path = f"{raw_dir}/{module}/{start}_{end-1}.safetensors" - split_data = load_file(path) - first_feature = int(path.split("/")[-1].split("_")[0]) - activations = torch.tensor(split_data["activations"]) - locations = torch.tensor(split_data["locations"].astype(np.int64)) - if tokens is None: - if "tokens" in split_data: - tokens = torch.tensor(split_data["tokens"].astype(np.int64)) - else: - tokens = None - locations[:,2] = locations[:,2] + first_feature - all_locations.append(locations) - all_activations.append(activations) + buffer = TensorBuffer(path, module, min_examples=self.cfg.min_examples) + features, locations, activations, tk = buffer.load() + all_features.extend(features) + all_locations.extend(locations) + all_activations.extend(activations) + if tokens is None: + tokens = tk + all_features = torch.stack(all_features) + all_data[module] = AllData(all_features, all_locations, all_activations, tokens) - all_locations = torch.cat(all_locations) - all_activations = torch.cat(all_activations) - everything[module] = BufferOutput(-1, all_locations, all_activations, tokens) - + return all_data def __len__(self): """Return the number of buffers in the dataset.""" @@ -323,25 +322,14 @@ async def __aiter__(self): FeatureRecord: Processed feature records. """ for buffer in self.feature_dataset.buffers: - async for record in self._aprocess_buffer(buffer): - yield record - - async def _aprocess_buffer(self, buffer): - """ - Asynchronously process a buffer. - - Args: - buffer (TensorBuffer): Buffer to process. - - Yields: - Optional[FeatureRecord]: Processed feature record or None. - """ - for data in buffer: - if data is not None: - record = await self._aprocess_feature(data) - if record is not None: - yield record - await asyncio.sleep(0) # Allow other coroutines to run + for data in buffer: + if data is not None: + + record = await self._aprocess_feature(data) + if record is not None: + print() + yield record + await asyncio.sleep(0) async def _aprocess_feature(self, buffer_output: BufferOutput): """ @@ -360,52 +348,4 @@ async def _aprocess_feature(self, buffer_output: BufferOutput): self.constructor(record=record, buffer_output=buffer_output) if self.sampler is not None: self.sampler(record) - return record - - def __iter__(self): - """ - Synchronous iterator for processing feature records. - - Yields: - FeatureRecord: Processed feature records. - """ - for buffer in self.feature_dataset.buffers: - for record in self._process_buffer(buffer): - yield record - - def _process_buffer(self, buffer): - """ - Process a buffer synchronously. - - Args: - buffer (TensorBuffer): Buffer to process. - - Yields: - Optional[FeatureRecord]: Processed feature record or None. - """ - for data in buffer: - if data is not None: - record = self._process_feature(data) - if record is not None: - yield record - - def _process_feature(self, buffer_output: BufferOutput): - """ - Process a single feature synchronously. - - Args: - buffer_output (BufferOutput): Feature data to process. - - Returns: - Optional[FeatureRecord]: Processed feature record or None. 
- """ - record = FeatureRecord(buffer_output.feature) - if self.transform is not None: - self.transform(record) - if self.constructor is not None: - self.constructor(record=record, buffer_output=buffer_output) - if self.sampler is not None: - self.sampler(record) - - return record \ No newline at end of file From 941cdc2bb16fe7c9dc71750ee884c856d3057ff5 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 15:55:32 +0000 Subject: [PATCH 043/132] Use all data --- delphi/features/constructors.py | 86 ++++++++++++++++++++------------- 1 file changed, 52 insertions(+), 34 deletions(-) diff --git a/delphi/features/constructors.py b/delphi/features/constructors.py index 1619e6a3..ebea029a 100644 --- a/delphi/features/constructors.py +++ b/delphi/features/constructors.py @@ -4,8 +4,7 @@ from torchtyping import TensorType from .features import FeatureRecord, prepare_examples -from .loader import BufferOutput - +from .loader import BufferOutput, AllData def _top_k_pools( max_buffer: TensorType["batch"], @@ -124,7 +123,7 @@ def neighbour_random_activation_windows( record: FeatureRecord, tokens: TensorType["batch", "seq"], buffer_output: BufferOutput, - everything: BufferOutput, + all_data: AllData, ctx_len: int, n_random: int, ): @@ -146,42 +145,61 @@ def neighbour_random_activation_windows( batch_size = tokens.shape[0] - mask = torch.zeros(batch_size, dtype=torch.bool) + # Get the unique batch positions where the latent is active + unique_batch_pos_active = buffer_output.locations[:, 0].unique_consecutive() + mask = torch.zeros(batch_size, dtype=torch.bool) + + # TODO: For now we use at most 10 examples per neighbour, we may want to allow a variable number of examples per neighbour + n_examples_per_neighbour = 10 + + number_examples = 0 + available_features = all_data.features + all_examples = [] for neighbour in record.neighbours: - # Get the possible batch positions where the neighbour is active - possible_locations = everything.locations[everything.locations[:, 0] == neighbour] - # Get the unique locations - unique_possible_locations = possible_locations.unique(dim=0) + if number_examples >= n_random: + break + # find indice in all_data.features that matches the neighbour + indice = torch.where(available_features == neighbour.feature_index)[0] + if len(indice) == 0: + continue + # get the locations of the neighbour + locations = all_data.locations[indice] + # get the unique locations + unique_locations = locations[:,0].unique_consecutive(dim=0) + # Set the mask to True for the unique locations - mask[unique_possible_locations[:, 1]] = True - - # Get the unique batch positions where the latent is active - unique_batch_pos_active = buffer_output.locations[:, 0].unique() - # Set the mask to False for the unique locations where the latent is active - mask[unique_batch_pos_active] = False + mask[unique_locations] = True - available_indices = mask.nonzero().squeeze() + # Set the mask to False for the unique locations where the latent is active + # TODO: we probably want to be less strict here, we could use parts of the batch where the latent is not active + mask[unique_batch_pos_active] = False - # TODO:What to do when the latent is active at least once in each batch? 
- if available_indices.numel() < n_random: - print("No available indices") - record.random_examples = [] - return - else: - # Select the batch positions - selected_indices = available_indices[torch.randint(0,len(available_indices),size=(n_random,))] - # Select the token positions - selected_positions = torch.randint(0, tokens.shape[1] - ctx_len, size=(n_random,)) + available_indices = mask.nonzero().flatten() + + if available_indices.numel() == 0: + continue + size = min(n_examples_per_neighbour, len(available_indices)) + + selected_indices = torch.randint(0, len(unique_locations), size=(size,)) + selected_positions = torch.randint(0, tokens.shape[1] - ctx_len, size=(size,)) + + range_indices = torch.arange(ctx_len, device=tokens.device).unsqueeze(0) # Shape: (1, ctx_len) + # Each selected_positions gives a unique starting index. We add the range tensor to get indices for each example. + positions = selected_positions.unsqueeze(1) + range_indices - # Get tokens - toks = tokens[selected_indices, selected_positions : selected_positions + ctx_len] + # Get tokens + toks = tokens[selected_indices].gather(dim=1, index=positions) - record.random_examples = prepare_examples( - toks, - torch.zeros_like(toks), - ) + examples = prepare_examples(toks, torch.zeros_like(toks)) + number_examples += len(examples) + all_examples.append((examples, neighbour)) + + if len(all_examples) == 0: + print("No examples found") + + record.random_examples = all_examples def default_constructor( @@ -239,7 +257,7 @@ def neighbour_constructor( record: FeatureRecord, token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], buffer_output: BufferOutput, - everything: BufferOutput, + all_data: AllData, n_random: int, ctx_len: int, max_examples: int, @@ -247,7 +265,7 @@ def neighbour_constructor( """ Construct feature examples using pool max activation windows and random activation windows from neighbours. 
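The window extraction at the end of neighbour_random_activation_windows above relies on a broadcasted gather: one random start per selected sequence, expanded with arange so each row yields a ctx_len-long window. A toy sketch of just that step (shapes are illustrative):

import torch

tokens = torch.arange(4 * 12).reshape(4, 12)   # (batch=4, seq=12) toy token ids
selected_indices = torch.tensor([0, 2, 3])     # rows where a neighbour is active
ctx_len = 5
starts = torch.randint(0, tokens.shape[1] - ctx_len, size=(len(selected_indices),))
positions = starts.unsqueeze(1) + torch.arange(ctx_len).unsqueeze(0)  # (3, ctx_len)
windows = tokens[selected_indices].gather(dim=1, index=positions)     # (3, ctx_len)
print(windows.shape)  # torch.Size([3, 5])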
""" - tokens = everything.tokens + tokens = all_data.tokens if tokens is None: if token_loader is None: raise ValueError("Either tokens or token_loader must be provided") @@ -274,7 +292,7 @@ def neighbour_constructor( record, tokens=tokens, buffer_output=buffer_output, - everything=everything, + all_data=all_data, n_random=n_random, ctx_len=ctx_len, ) From 185a6d667a76b675baf65ca385796856c4138c48 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 15:55:46 +0000 Subject: [PATCH 044/132] Small fixes --- delphi/explainers/explainer.py | 16 ++++++++++------ delphi/features/samplers.py | 4 ++-- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 4a1c58be..4c88011c 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -24,13 +24,17 @@ def __call__(self, record: FeatureRecord) -> ExplainerResult: async def explanation_loader(record: FeatureRecord, explanation_dir: str) -> ExplainerResult: - async with aiofiles.open(f'{explanation_dir}/{record.feature}.txt', 'r') as f: - explanation = json.loads(await f.read()) + try: + async with aiofiles.open(f'{explanation_dir}/{record.feature}.txt', 'r') as f: + explanation = json.loads(await f.read()) + return ExplainerResult( + record=record, + explanation=explanation + ) + except FileNotFoundError: + return None - return ExplainerResult( - record=record, - explanation=explanation - ) + async def random_explanation_loader(record: FeatureRecord, explanation_dir: str) -> ExplainerResult: explanations = [f for f in os.listdir(explanation_dir) if f.endswith(".txt")] diff --git a/delphi/features/samplers.py b/delphi/features/samplers.py index fcc0fcc6..1c146b4d 100644 --- a/delphi/features/samplers.py +++ b/delphi/features/samplers.py @@ -150,8 +150,8 @@ def sample( if cfg.n_examples_test > 0: _test = test( examples, - max_activation, - cfg.n_examples_test, + max_activation, + cfg.n_examples_test, cfg.n_quantiles, cfg.test_type, ) From b621f82c0c084300ad3f9d24fbab0bc801df633c Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 16:06:43 +0000 Subject: [PATCH 045/132] Change batch to n_examples --- delphi/scorers/classifier/classifier.py | 18 +++++++++++++++--- delphi/scorers/classifier/detection.py | 17 +++++++++++++++-- delphi/scorers/classifier/fuzz.py | 17 +++++++++++++++-- 3 files changed, 45 insertions(+), 7 deletions(-) diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index 0f4dd793..40a21a47 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -20,15 +20,27 @@ def __init__( client: Client, tokenizer: PreTrainedTokenizer, verbose: bool, - batch_size: int, + n_examples_shown: int, log_prob: bool, **generation_kwargs, ): + """ + Initialize a Classifier. 
+ + Args: + client: The client to use for generation + tokenizer: The tokenizer used to cache the tokens + verbose: Whether to print verbose output + n_examples_shown: The number of examples to show in the prompt, + a larger number can both leak information and make + it harder for models to generate anwers in the correct format + log_prob: Whether to use log probabilities to allow for AUC calculation + generation_kwargs: Additional generation kwargs + """ self.client = client self.tokenizer = tokenizer self.verbose = verbose - - self.batch_size = batch_size + self.n_examples_shown = n_examples_shown self.generation_kwargs = generation_kwargs self.log_prob = log_prob diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 97019d07..1acddd8c 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -15,15 +15,28 @@ def __init__( client: Client, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, verbose: bool = False, - batch_size: int = 10, + n_examples_shown: int = 10, log_prob: bool = False, **generation_kwargs, ): + """ + Initialize a DetectionScorer. + + Args: + client: The client to use for generation. + tokenizer: The tokenizer used to cache the tokens + verbose: Whether to print verbose output. + n_examples_shown: The number of examples to show in the prompt, + a larger number can both leak information and make + it harder for models to generate anwers in the correct format + log_prob: Whether to use log probabilities to allow for AUC calculation + generation_kwargs: Additional generation kwargs + """ super().__init__( client=client, tokenizer=tokenizer, verbose=verbose, - batch_size=batch_size, + n_examples_shown=n_examples_shown, log_prob=log_prob, **generation_kwargs, ) diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index f4794ee2..22e7fe37 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -21,16 +21,29 @@ def __init__( client: Client, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, verbose: bool = False, - batch_size: int = 1, + n_examples_shown: int = 10, threshold: float = 0.3, log_prob: bool = False, **generation_kwargs, ): + """ + Initialize a FuzzingScorer. + + Args: + client: The client to use for generation. + tokenizer: The tokenizer used to cache the tokens + verbose: Whether to print verbose output. 
+            n_examples_shown: The number of examples to show in the prompt,
+                a larger number can both leak information and make
+                it harder for models to generate answers in the correct format
+            log_prob: Whether to use log probabilities to allow for AUC calculation
+            generation_kwargs: Additional generation kwargs
+        """
         super().__init__(
             client=client,
             tokenizer=tokenizer,
             verbose=verbose,
-            batch_size=batch_size,
+            n_examples_shown=n_examples_shown,
             log_prob=log_prob,
             **generation_kwargs,
         )

From 60f06f89c7d72af51cf4675cdc925ec99d7090ee Mon Sep 17 00:00:00 2001
From: SrGonao
Date: Mon, 10 Feb 2025 16:29:48 +0000
Subject: [PATCH 046/132] Delete extra folder

---
 sae_auto_interp/config.py | 74 ---------------------------------------
 1 file changed, 74 deletions(-)
 delete mode 100644 sae_auto_interp/config.py

diff --git a/sae_auto_interp/config.py b/sae_auto_interp/config.py
deleted file mode 100644
index 3873d798..00000000
--- a/sae_auto_interp/config.py
+++ /dev/null
@@ -1,74 +0,0 @@
-from dataclasses import dataclass
-from typing import Literal
-
-from simple_parsing import Serializable
-
-
-@dataclass
-class ExperimentConfig(Serializable):
-
-    n_examples_train: int = 50
-    """Number of examples to sample for training"""
-
-    n_examples_test: int = 50
-    """Number of examples to sample for testing"""
-
-    n_quantiles: int = 10
-    """Number of quantiles to sample"""
-
-    example_ctx_len: int = 32
-    """Length of each example"""
-
-    n_random: int = 50
-    """Number of random examples to sample"""
-
-    train_type: Literal["top", "random", "quantiles"] = "quantiles"
-    """Type of sampler to use for training"""
-
-    test_type: Literal["quantiles", "activation"] = "quantiles"
-    """Type of sampler to use for testing"""
-
-
-
-@dataclass
-class FeatureConfig(Serializable):
-    width: int = 131072
-    """Number of features in the autoencoder"""
-
-    min_examples: int = 200
-    """Minimum number of examples for a feature to be included"""
-
-    max_examples: int = 10000
-    """Maximum number of examples for a feature to included"""
-
-    n_splits: int = 5
-    """Number of splits that features were devided into"""
-
-
-@dataclass
-class CacheConfig(Serializable):
-
-    dataset_repo: str = "kh4dien/fineweb-100m-sample"
-    """Dataset repository to use"""
-
-    dataset_split: str = "train"
-    """Dataset split to use"""
-
-    dataset_name: str = ""
-    """Dataset name to use"""
-
-    dataset_column_name: str = "text"
-    """Dataset column name to use"""
-
-    batch_size: int = 32
-    """Number of sequences to process in a batch"""
-
-    ctx_len: int = 256
-    """Context length of the autoencoder.
Each batch is shape (batch_size, ctx_len)""" - - n_tokens: int = 10_000_000 - """Number of tokens to cache""" - - n_splits: int = 5 - """Number of splits to divide .safetensors into""" From 40d9d86f4521f1d4f22ae53258ac9943bba22caf Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 16:31:31 +0000 Subject: [PATCH 047/132] Keep experiments in legacy code, stop updating them --- experiments/caching_gemma/cache.py | 220 ------------------ experiments/caching_gemma/cache_neurons.py | 65 ------ experiments/caching_gemma/embedding.py | 120 ---------- experiments/caching_gemma/example_bash.sh | 1 - .../generate_explanations_and_scores.py | 155 ------------ .../generate_only_explanations.py | 115 --------- .../caching_gemma/score_explanations.py | 134 ----------- .../caching_gemma/score_explanations_fig1.py | 134 ----------- .../caching_gemma/simulation_scoring.py | 131 ----------- experiments/caching_gemma/surprisal.py | 116 --------- experiments/caching_llama/cache.py | 82 ------- experiments/caching_llama/cache_neurons.py | 56 ----- .../generate_explanations_and_scores.py | 154 ------------ .../generate_only_explanations.py | 114 --------- .../caching_llama/score_explanations.py | 139 ----------- .../caching_llama/simulation_scoring.py | 135 ----------- experiments/caching_llama/surprisal.py | 116 --------- 17 files changed, 1987 deletions(-) delete mode 100644 experiments/caching_gemma/cache.py delete mode 100644 experiments/caching_gemma/cache_neurons.py delete mode 100644 experiments/caching_gemma/embedding.py delete mode 100644 experiments/caching_gemma/example_bash.sh delete mode 100644 experiments/caching_gemma/generate_explanations_and_scores.py delete mode 100644 experiments/caching_gemma/generate_only_explanations.py delete mode 100644 experiments/caching_gemma/score_explanations.py delete mode 100644 experiments/caching_gemma/score_explanations_fig1.py delete mode 100644 experiments/caching_gemma/simulation_scoring.py delete mode 100644 experiments/caching_gemma/surprisal.py delete mode 100644 experiments/caching_llama/cache.py delete mode 100644 experiments/caching_llama/cache_neurons.py delete mode 100644 experiments/caching_llama/generate_explanations_and_scores.py delete mode 100644 experiments/caching_llama/generate_only_explanations.py delete mode 100644 experiments/caching_llama/score_explanations.py delete mode 100644 experiments/caching_llama/simulation_scoring.py delete mode 100644 experiments/caching_llama/surprisal.py diff --git a/experiments/caching_gemma/cache.py b/experiments/caching_gemma/cache.py deleted file mode 100644 index d45d64f8..00000000 --- a/experiments/caching_gemma/cache.py +++ /dev/null @@ -1,220 +0,0 @@ -from nnsight import LanguageModel -from simple_parsing import ArgumentParser -import torch -from delphi.autoencoders import load_gemma_autoencoders -from delphi.config import CacheConfig -from delphi.features import FeatureCache -from delphi.utils import load_tokenized_data -import os - - -l0_dict_mlp = { - "16k": {0:50, - 1:56, - 2:33, - 3:55, - 4:66, - 5:46, - 6:46, - 7:47, - 8:55, - 9:40, - 10:49, - 11:34, - 12:42, - 13:40, - 14:41, - 15:45, - 16:37, - 17:41, - 18:36, - 19:38, - 20: 41, - 21:34, - 22:34, - 23:73, - 24:32, - 25:72, - 26:57, - 27:52, - 28:50, - 29:49, - 30:51, - 31:43, - 32:44, - 33:48, - 34:47, - 35:46, - 36:47, - 37:53, - 38:45, - 39:43, - 40:37, - 41:58 - - }, - "131k":{ - 20: 41, - 24: 33, - 28: 47, - 32: 40 -}} -l0_dict_res = { - "16k": {0:35, - 1:69, - 2:67, - 3:37, - 4:37, - 5:37, - 6:47, - 7:46, - 8:51, - 9:51, - 10:57, 
- 11:32, - 12:33, - 13:34, - 14:35, - 15:34, - 16:39, - 17:38, - 18:37, - 19:35, - 20: 36, - 21:36, - 22: 35, - 23: 35, - 24: 34, - 25: 34, - 26: 35, - 27:36, - 28: 37, - 29:38, - 30:37, - 31:35, - 32: 34, - 33:34, - 34:34, - 35:34, - 36:34, - 37:34, - 38:34, - 39:34, - 40:32, - 41:52 - }, - "131k": {0:30, - 1:33, - 2:36, - 3:46, - 4:51, - 5:51, - 6:66,#Doesnt work - 7:38, - 8:41, - 9:42, - 10:47, - 11:49, - 12:52, - 13:30,#Doesnt work - 14:56, - 15:55, - 16:35, - 17:35, - 18:34, - 19:32, - 20:34, - 21:33, - 22:32, - 23:32, - 24: 55, - 25:54, - 26:32, - 27:33, - 28: 32, - 29:33, - 30:32, - 31:52, - 32: 51, - 33:51, - 34:51, - 35:51, - 36: 51, - 37:53, - 38:53, - 39:54, - 40: 49, - 41:45, - } -} - -def main(cfg: CacheConfig,args): - layers = args.layers - size = args.size - type = args.type - name = args.name - random = args.random - model = LanguageModel("google/gemma-2-9b", device_map="cuda", dispatch=True,torch_dtype="float16") - layers = [int(layer) for layer in layers.split(",")] - if type == "res": - dict_l0 = l0_dict_res - elif type == "mlp": - dict_l0 = l0_dict_mlp - - submodule_dict,model = load_gemma_autoencoders( - model, - layers, - {layer: dict_l0[size][layer] for layer in layers}, - size, - type, - random - ) - - tokens = load_tokenized_data( - cfg.ctx_len, - model.tokenizer, - cfg.dataset_repo, - cfg.dataset_split, - cfg.dataset_name - ) - - cache = FeatureCache( - model, - submodule_dict, - batch_size=cfg.batch_size, - ) - - cache.run(10000000, tokens) - if name not in [""]: - name = f"_{name}" - if random: - name = name + "_random" - - if not os.path.exists(f"raw_features/gemma/{size}{name}"): - os.makedirs(f"raw_features/gemma/{size}{name}") - - cache.save_splits( - n_splits=cfg.n_splits, - save_dir=f"raw_features/gemma/{size}{name}" - ) - - cache.save_config( - save_dir=f"raw_features/gemma/{size}{name}", - cfg=cfg, - model_name="google/gemma-2-9b" - ) - -if __name__ == "__main__": - - parser = ArgumentParser() - #ctx len 256 - parser.add_arguments(CacheConfig, dest="options") - parser.add_argument("--layers", type=str, default="23,27") - parser.add_argument("--size", type=str, default="16k") - parser.add_argument("--type", type=str, default="res") - parser.add_argument("--name", type=str, default="") - parser.add_argument("--random", action="store_true") - args = parser.parse_args() - cfg = args.options - - main(cfg,args) diff --git a/experiments/caching_gemma/cache_neurons.py b/experiments/caching_gemma/cache_neurons.py deleted file mode 100644 index 6abb378c..00000000 --- a/experiments/caching_gemma/cache_neurons.py +++ /dev/null @@ -1,65 +0,0 @@ -from nnsight import LanguageModel -from simple_parsing import ArgumentParser - -from delphi.autoencoders import load_llama3_neurons -from delphi.config import CacheConfig -from delphi.features import FeatureCache -from delphi.utils import load_tokenized_data -import os - -def main(cfg: CacheConfig,args): - model = LanguageModel("google/gemma-2-9b", device_map="cuda", dispatch=True,torch_dtype="float16") - layers = args.layers - k=args.k - layers = [int(layer) for layer in layers.split(",")] - - submodule_dict,model = load_llama3_neurons( - model, - layers, - args.k - ) - - - tokens = load_tokenized_data( - cfg.ctx_len, - model.tokenizer, - cfg.dataset_repo, - cfg.dataset_split, - cfg.dataset_name - ) - - cache = FeatureCache( - model, - submodule_dict, - batch_size=cfg.batch_size, - ) - - cache.run(cfg.n_tokens, tokens) - - if not os.path.exists(f"raw_features/gemma/neurons_{k}"): - 
os.makedirs(f"raw_features/gemma/neurons_{k}") - - cache.save_splits( - n_splits=cfg.n_splits, - save_dir=f"raw_features/gemma/neurons_{k}" - ) - - cache.save_config( - save_dir=f"raw_features/gemma/neurons_{k}", - cfg=cfg, - model_name="google/gemma-2-9b" - ) - - -if __name__ == "__main__": - - parser = ArgumentParser() - #ctx len 256 - parser.add_arguments(CacheConfig, dest="options") - parser.add_argument("--k", type=int, default=32) - parser.add_argument("--layers", type=str, default="23,27") - - args = parser.parse_args() - cfg = args.options - - main(cfg,args) diff --git a/experiments/caching_gemma/embedding.py b/experiments/caching_gemma/embedding.py deleted file mode 100644 index 5a286023..00000000 --- a/experiments/caching_gemma/embedding.py +++ /dev/null @@ -1,120 +0,0 @@ -import asyncio -import os -import random -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser -from transformers import AutoTokenizer -from sentence_transformers import SentenceTransformer -import torch - -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader, random_explanation_loader -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import EmbeddingScorer - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - random = args.random - experiment_name = args.experiment_name - ### Load dataset ### - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - - - ### Build Explainer pipe ### - if random: - explainer = partial(random_explanation_loader, explanation_dir=f"results/explanations/{sae_model}/{experiment_name}/") - results_dir = f"results/scores/{sae_model}/random_explanation/" - - else: - explainer = partial(explanation_loader, explanation_dir=f"results/explanations/{sae_model}/{experiment_name}/") - results_dir = f"results/scores/{sae_model}/{experiment_name}/" - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - record.extra_examples = record.not_active - - - return record - if random: - experiment_name = "random-explanation" - def scorer_postprocess(result, score_dir): - with open(f"{results_dir}/{score_dir}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - print(results_dir) - os.makedirs(f"{results_dir}/embedding", exist_ok=True) - - model = SentenceTransformer("dunzhang/stella_en_400M_v5", trust_remote_code=True).cuda() - #model = SentenceTransformer("nvidia/NV-Embed-v2", trust_remote_code=True).cuda() - scorer_pipe = Pipe( - process_wrapper( - 
EmbeddingScorer(model,tokenizer, batch_size=shown_examples,verbose=False), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="embedding"), - ), - - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer, - scorer_pipe, - ) - - asyncio.run(pipeline.run(1)) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - parser.add_argument("--random", action="store_true") - args = parser.parse_args() - - - - main(args) diff --git a/experiments/caching_gemma/example_bash.sh b/experiments/caching_gemma/example_bash.sh deleted file mode 100644 index e33161ab..00000000 --- a/experiments/caching_gemma/example_bash.sh +++ /dev/null @@ -1 +0,0 @@ -export VLLM_WORKER_MULTIPROC_METHOD=spawn; CUDA_VISIBLE_DEVICES=0,1 python caching_gemma/generate_explanations.py --module .model.layers.8 --train_type "top" --n_examples_train 40 --model gemma/131k --experiment_name top40 --n_quantiles 10 --n_random 100 --n_examples_test 10 --features 300 --example_ctx_len 32 --width 131072 diff --git a/experiments/caching_gemma/generate_explanations_and_scores.py b/experiments/caching_gemma/generate_explanations_and_scores.py deleted file mode 100644 index 3c05eeb9..00000000 --- a/experiments/caching_gemma/generate_explanations_and_scores.py +++ /dev/null @@ -1,155 +0,0 @@ -import asyncio -import json -import os -from functools import partial - -import orjson -import torch -import time -from simple_parsing import ArgumentParser - -from delphi.clients import Offline -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import DefaultExplainer -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe,Pipeline, process_wrapper -from delphi.scorers import FuzzingScorer, DetectionScorer - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - activations = args.activations - cot = args.cot - quantization = args.quantization - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - ### Load client ### - if args.quantization == "awq": - client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.8,max_model_len=5120) - elif args.quantization == "none": - client = 
Offline("meta-llama/Meta-Llama-3.1-70B-Instruct",max_memory=0.8,max_model_len=5120,num_gpus=2) - - ### Build Explainer pipe ### - def explainer_postprocess(result): - - with open(f"results/explanations/{sae_model}/{experiment_name}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.explanation)) - - return result - #try making the directory if it doesn't exist - os.makedirs(f"results/explanations/{sae_model}/{experiment_name}", exist_ok=True) - - explainer_pipe = process_wrapper( - DefaultExplainer( - client, - tokenizer=dataset.tokenizer, - threshold=0.3, - activations=activations, - cot=cot - ), - postprocess=explainer_postprocess, - ) - - #save the experiment config - with open(f"results/explanations/{sae_model}/{experiment_name}/experiment_config.json", "w") as f: - print(experiment_cfg.to_dict()) - f.write(json.dumps(experiment_cfg.to_dict())) - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - record.explanation = result.explanation - record.extra_examples = record.not_active - - return record - - def scorer_postprocess(result, score_dir): - record = result.record - with open(f"results/scores/{sae_model}/{experiment_name}/{score_dir}/{record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - - os.makedirs(f"results/scores/{sae_model}/{experiment_name}/detection", exist_ok=True) - os.makedirs(f"results/scores/{sae_model}/{experiment_name}/fuzz", exist_ok=True) - - #save the experiment config - with open(f"results/scores/{sae_model}/{experiment_name}/detection/experiment_config.json", "w") as f: - f.write(json.dumps(experiment_cfg.to_dict())) - - with open(f"results/scores/{sae_model}/{experiment_name}/fuzz/experiment_config.json", "w") as f: - f.write(json.dumps(experiment_cfg.to_dict())) - - - scorer_pipe = Pipe(process_wrapper( - DetectionScorer(client, tokenizer=dataset.tokenizer, batch_size=shown_examples,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="detection"), - ), - process_wrapper( - FuzzingScorer(client, tokenizer=dataset.tokenizer, batch_size=shown_examples,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="fuzz"), - ) - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - start_time = time.time() - asyncio.run(pipeline.run(50)) - end_time = time.time() - print(f"Time taken: {end_time - start_time} seconds") - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--no-activations", action="store_false", dest="activations", help="Disable activations") - parser.add_argument("--cot", type=bool, default=False) - parser.add_argument("--quantization", type=str, default="awq") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - args = parser.parse_args() - print(args) - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_gemma/generate_only_explanations.py 
b/experiments/caching_gemma/generate_only_explanations.py deleted file mode 100644 index 567cd37e..00000000 --- a/experiments/caching_gemma/generate_only_explanations.py +++ /dev/null @@ -1,115 +0,0 @@ -import asyncio -import json -import os -from functools import partial - -import orjson -import torch -import time -from simple_parsing import ArgumentParser - -from delphi.clients import Offline,OpenRouter -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import DefaultExplainer -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipeline, process_wrapper - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - explainer_size = args.explainer_size - - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - ### Load client ### - if explainer_size == "70b": - client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.8,max_model_len=5120,num_gpus=1) - elif explainer_size == "8b": - client = Offline("meta-llama/Meta-Llama-3.1-8B-Instruct",max_memory=0.8,max_model_len=5120,num_gpus=1) - elif explainer_size == "claude": - client = OpenRouter("anthropic/claude-3.5-sonnet",api_key=os.getenv("OPENROUTER_API_KEY")) - - ### Build Explainer pipe ### - def explainer_postprocess(result): - - with open(f"results/explanations/{sae_model}/{experiment_name}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.explanation)) - del result - return None - - #try making the directory if it doesn't exist - os.makedirs(f"results/explanations/{sae_model}/{experiment_name}", exist_ok=True) - - explainer_pipe = process_wrapper( - DefaultExplainer( - client, - tokenizer=dataset.tokenizer, - threshold=0.3, - activations=True - ), - postprocess=explainer_postprocess, - ) - - #save the experiment config - with open(f"results/explanations/{sae_model}/{experiment_name}/experiment_config.json", "w") as f: - print(experiment_cfg.to_dict()) - f.write(json.dumps(experiment_cfg.to_dict())) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - ) - start_time = time.time() - asyncio.run(pipeline.run(100)) - end_time = time.time() - print(f"Time taken: {end_time - start_time} seconds") - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--activations", type=bool, default=True) - parser.add_argument("--explainer_size", type=str, default="70b") - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - 
parser.add_argument("--experiment_name", type=str, default="default") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_gemma/score_explanations.py b/experiments/caching_gemma/score_explanations.py deleted file mode 100644 index 4d6b150f..00000000 --- a/experiments/caching_gemma/score_explanations.py +++ /dev/null @@ -1,134 +0,0 @@ -import asyncio -import os -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser - -from delphi.clients import Offline,OpenRouter -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader,random_explanation_loader -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import FuzzingScorer, DetectionScorer - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - experiment_name = args.experiment_name - quantization = args.quantization - scorer_size = args.scorer_size - log_prob = args.prob - - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - ### Load client ### - - EXPLAINER_OUT_DIR = f"results/explanations/{sae_model}/{experiment_name}" - - ### Build Explainer pipe ### - if args.random: - explainer_pipe = partial(random_explanation_loader, explanation_dir=EXPLAINER_OUT_DIR) - experiment_name = "random_explanation" - else: - explainer_pipe = partial(explanation_loader, explanation_dir=EXPLAINER_OUT_DIR) - - - if scorer_size == "70b": - client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.8, max_model_len=5120,num_gpus=2) - SCORE_OUT_DIR = f"results/scores/{sae_model}/{experiment_name}" - elif scorer_size == "8b": - client = Offline("meta-llama/Meta-Llama-3.1-8B-Instruct",max_memory=0.8,max_model_len=5120,num_gpus=1) - SCORE_OUT_DIR = f"results/scores/{sae_model}/{experiment_name}_scorer_8b" - elif scorer_size == "claude": - client = OpenRouter("anthropic/claude-3.5-sonnet",api_key=os.getenv("OPENROUTER_API_KEY")) - SCORE_OUT_DIR = f"results/scores/{sae_model}/{experiment_name}_scorer_claude" - - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - record.extra_examples = record.not_active - - return record - - def scorer_postprocess(result, score_dir): - with open(f"{SCORE_OUT_DIR}/{score_dir}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - os.makedirs(f"{SCORE_OUT_DIR}/detection", exist_ok=True) - os.makedirs(f"{SCORE_OUT_DIR}/fuzz", 
exist_ok=True) - scorer_pipe = Pipe( - process_wrapper( - detectionScorer(client, tokenizer=dataset.tokenizer, batch_size=shown_examples,verbose=False,log_prob=log_prob), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="detection"), - ), - process_wrapper( - FuzzingScorer(client, tokenizer=dataset.tokenizer, batch_size=shown_examples,verbose=False,log_prob=log_prob), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="fuzz"), - ), - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - - asyncio.run(pipeline.run(50)) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--scorer_size", type=str, default="70b") - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--quantization", type=str, default="awq",choices=["awq","bnb","normal"]) - parser.add_argument("--random", type=bool, default=False) - parser.add_argument("--prob",action="store_false") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - args = parser.parse_args() - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_gemma/score_explanations_fig1.py b/experiments/caching_gemma/score_explanations_fig1.py deleted file mode 100644 index ab76b620..00000000 --- a/experiments/caching_gemma/score_explanations_fig1.py +++ /dev/null @@ -1,134 +0,0 @@ -import asyncio -import os -from functools import partial -from typing import NamedTuple - -import orjson -import torch -from simple_parsing import ArgumentParser - -from delphi.clients import Offline, OpenRouter -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.features import FeatureDataset, FeatureLoader -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import FuzzingScorer, DetectionScorer -import aiofiles -import json -from delphi.features import FeatureRecord - -class ExplainerResult(NamedTuple): - record: FeatureRecord - """Feature record passed through to scorer.""" - - explanation: str - """Generated explanation for feature.""" - - -async def explanation_loader(record: FeatureRecord, explanation_dir: str) -> ExplainerResult: - feature = str(record.feature) - layer = feature.split("_feature")[0].split(".")[-1] - feature = feature.split("_feature")[1].split(".")[0] - async with aiofiles.open(f'extras/explanations_131k/model.layers.{layer}_feature.json', 'r') as f: - explanations = json.loads(await f.read()) - explanation = explanations[feature] - print() - print(explanation) - print(record.feature) - return ExplainerResult( - record=record, - explanation=explanation - ) - - - -def main(args): - feature_cfg = args.feature_options - experiment_cfg = ExperimentConfig() - experiment_name = "fig1" - - feature_dict = {#".model.layers.0":torch.tensor([44,2528,2011]), - #".model.layers.5":torch.tensor([3031,8603]), - #".model.layers.15":torch.tensor([6680]), - #".model.layers.20":torch.tensor([7343]), - 
".model.layers.40":torch.tensor([4661]), - } - dataset = FeatureDataset( - raw_dir="raw_features/gemma/131k", - cfg=feature_cfg, - modules=[".model.layers.40"], - features=feature_dict, - ) - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=100, - ctx_len=32, - max_examples=10000 - ) - experiment_cfg.n_examples_test=10 - experiment_cfg.n_quantiles=10 - experiment_cfg.example_ctx_len=32 - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - ### Load client ### - - EXPLAINER_OUT_DIR = f"results/explanations/gemma/131k/{experiment_name}" - - ### Build Explainer pipe ### - explainer_pipe = partial(explanation_loader, explanation_dir=EXPLAINER_OUT_DIR) - - - client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.8, max_model_len=5120,num_gpus=2) - SCORE_OUT_DIR = f"results/scores/gemma/131k/{experiment_name}" - - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - record.extra_examples = record.not_active - - return record - - def scorer_postprocess(result, score_dir): - with open(f"{SCORE_OUT_DIR}/{score_dir}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - os.makedirs(f"{SCORE_OUT_DIR}/detection", exist_ok=True) - os.makedirs(f"{SCORE_OUT_DIR}/fuzz", exist_ok=True) - scorer_pipe = Pipe( - process_wrapper( - DetectionScorer(client, tokenizer=dataset.tokenizer, batch_size=5,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="detection"), - ), - process_wrapper( - FuzzingScorer(client, tokenizer=dataset.tokenizer, batch_size=5,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="fuzz"), - ), - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - - asyncio.run(pipeline.run(50)) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_arguments(FeatureConfig, dest="feature_options") - args = parser.parse_args() - - - main(args) diff --git a/experiments/caching_gemma/simulation_scoring.py b/experiments/caching_gemma/simulation_scoring.py deleted file mode 100644 index e3647cc3..00000000 --- a/experiments/caching_gemma/simulation_scoring.py +++ /dev/null @@ -1,131 +0,0 @@ -import asyncio -import json -import os -import time -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser -from transformers import AutoTokenizer - -from delphi.clients import Offline -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader -from delphi.features import FeatureDataset, FeatureLoader -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import OpenAISimulator - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - all_at_once = args.all_at_once - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - scorer_size = args.scorer_size - - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, 
- ) - - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - - - - #client = Local("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",base_url=f"http://localhost:{args.port}/v1") - if scorer_size == "70b": - client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.5, max_model_len=5120,num_gpus=2,batch_size=5,prefix_caching=False) - SCORER_OUT_DIR = f"results/scores/{sae_model}/{experiment_name}" - elif scorer_size == "8b": - client = Offline("meta-llama/Meta-Llama-3.1-8B-Instruct",max_memory=0.5,max_model_len=5120,num_gpus=1,batch_size=5,prefix_caching=False) - SCORER_OUT_DIR = f"results/scores/{sae_model}/{experiment_name}_scorer_8b" - elif scorer_size == "claude": - client = OpenRouter("anthropic/claude-3.5-sonnet",api_key="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") - SCORER_OUT_DIR = f"results/scores/{sae_model}/{experiment_name}_scorer_claude" - - ### Set directories ### - - EXPLAINER_OUT_DIR = f"results/explanations/{sae_model}/{experiment_name}/" - - ### Build Explainer pipe ### - - explainer_pipe = partial(explanation_loader, explanation_dir=EXPLAINER_OUT_DIR) - - ### Build Scorer pipe ### - - - def scorer_preprocess(result): - record = result.record - record.explanation = result.explanation - new_test = [] - for i in range(len(record.test)): - new_test.append(record.test[i]) - record.test = new_test - #record.test = record.test[0][:2] - return record - - - def scorer_postprocess(result): - with open(f"{SCORER_OUT_DIR}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - os.makedirs(f"{SCORER_OUT_DIR}", exist_ok=True) - - tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - scorer_pipe = Pipe( - process_wrapper( - OpenAISimulator(client, tokenizer=tokenizer, all_at_once=all_at_once), - preprocess=scorer_preprocess, - postprocess=scorer_postprocess, - ) - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - - asyncio.run( - pipeline.run(10) - ) - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--batch_size", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--no_all_at_once", action='store_false', dest="all_at_once") - parser.add_argument("--scorer_size", type=str, default="70b") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - parser.add_argument("--port", type=int, default=8001) - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_gemma/surprisal.py b/experiments/caching_gemma/surprisal.py deleted file mode 100644 index e1b80eb1..00000000 --- a/experiments/caching_gemma/surprisal.py +++ /dev/null @@ -1,116 +0,0 @@ -import asyncio -import os -import random -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser -from 
transformers import AutoModelForCausalLM, AutoTokenizer -import torch - -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader, random_explanation_loader -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import SurprisalScorer - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - random = args.random - experiment_name = args.experiment_name - ### Load dataset ### - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - - - ### Build Explainer pipe ### - if random: - explainer = partial(random_explanation_loader, explanation_dir=f"results/explanations/{sae_model}/{experiment_name}/") - else: - explainer = partial(explanation_loader, explanation_dir=f"results/explanations/{sae_model}/{experiment_name}/") - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - record.extra_examples = record.not_active - - - return record - if random: - experiment_name = "random-explanation" - def scorer_postprocess(result, score_dir): - with open(f"results/scores/{sae_model}/{experiment_name}/{score_dir}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - os.makedirs(f"results/scores/{sae_model}/{experiment_name}/surprisal", exist_ok=True) - - model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-70B", device_map="auto",load_in_8bit=True) - model.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-70B") - scorer_pipe = Pipe( - process_wrapper( - SurprisalScorer(model,tokenizer, batch_size=shown_examples,verbose=False), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="surprisal"), - ), - - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer, - scorer_pipe, - ) - - asyncio.run(pipeline.run(1)) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - parser.add_argument("--random", action="store_true") - args = parser.parse_args() - - - - main(args) diff --git a/experiments/caching_llama/cache.py 
b/experiments/caching_llama/cache.py deleted file mode 100644 index 98453228..00000000 --- a/experiments/caching_llama/cache.py +++ /dev/null @@ -1,82 +0,0 @@ -from nnsight import LanguageModel -from simple_parsing import ArgumentParser -import torch -from delphi.autoencoders import load_eai_autoencoders -from delphi.config import CacheConfig -from delphi.features import FeatureCache -from delphi.utils import load_tokenized_data -import os - -def main(cfg: CacheConfig, args): - - size = args.size - randomize = args.randomize - k = args.k - type = args.type - - model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="auto",dispatch=True,torch_dtype=torch.bfloat16) - - if size == "32x": - weight_dir = "EleutherAI/sae-llama-3.1-8b-32x" - elif size == "64x": - weight_dir = "EleutherAI/sae-llama-3.1-8b-64x" - elif size == "64x_no_multi": - weight_dir = "/mnt/ssd-1/nora/llama-64x-no-multitopk" - - - submodule_dict,model = load_eai_autoencoders( - model, - [23,29], - weight_dir, - module=type, - randomize=randomize, - k=k - ) - - - tokens = load_tokenized_data( - cfg.ctx_len, - model.tokenizer, - cfg.dataset_repo, - cfg.dataset_split, - ) - print(submodule_dict) - cache = FeatureCache( - model, - submodule_dict, - batch_size=cfg.batch_size, - ) - name="" - if k is not None: - name += f"_topk_{k}" - if randomize: - name += "_random" - if type != "res": - name += f"_{type}" - cache.run(10000000, tokens) - os.makedirs(f"raw_features/llama/{size}{name}", exist_ok=True) - - cache.save_splits( - n_splits=cfg.n_splits, - save_dir=f"raw_features/llama/{size}{name}" - ) - cache.save_config( - save_dir=f"raw_features/llama/{size}{name}", - cfg=cfg, - model_name="meta-llama/Meta-Llama-3.1-8B" - ) - -if __name__ == "__main__": - - parser = ArgumentParser() - #ctx len 256 - parser.add_arguments(CacheConfig, dest="options") - parser.add_argument("--size", type=str, default="32x") - parser.add_argument("--type", type=str, default="res") - parser.add_argument("--randomize", action="store_true") - parser.add_argument("--k", type=int, default=None) - args = parser.parse_args() - cfg = args.options - print(cfg) - - main(cfg, args) diff --git a/experiments/caching_llama/cache_neurons.py b/experiments/caching_llama/cache_neurons.py deleted file mode 100644 index 0f3bc822..00000000 --- a/experiments/caching_llama/cache_neurons.py +++ /dev/null @@ -1,56 +0,0 @@ -from nnsight import LanguageModel -from simple_parsing import ArgumentParser - -from delphi.autoencoders import load_llama3_neurons -from delphi.config import CacheConfig -from delphi.features import FeatureCache -from delphi.utils import load_tokenized_data -import os - -def main(cfg: CacheConfig,args): - model = LanguageModel("meta-llama/Meta-Llama-3.1-8B", device_map="auto", dispatch=True) - - submodule_dict,model = load_llama3_neurons( - model, - [23,29], - args.k - ) - - - tokens = load_tokenized_data( - cfg.ctx_len, - model.tokenizer, - cfg.dataset_repo, - cfg.dataset_split, - ) - - cache = FeatureCache( - model, - submodule_dict, - batch_size=cfg.batch_size, - ) - - cache.run(10000000, tokens) - os.makedirs(f"raw_features/llama/neurons_{args.k}", exist_ok=True) - - cache.save_splits( - n_splits=cfg.n_splits, - save_dir=f"raw_features/llama/neurons_{args.k}" - ) - cache.save_config( - save_dir=f"raw_features/llama/neurons_{args.k}", - cfg=cfg, - model_name="meta-llama/Meta-Llama-3.1-8B" - ) - - -if __name__ == "__main__": - - parser = ArgumentParser() - #ctx len 256 - parser.add_arguments(CacheConfig, dest="options") - parser.add_argument("--k", 
type=int, default=32) - args = parser.parse_args() - cfg = args.options - - main(cfg,args) diff --git a/experiments/caching_llama/generate_explanations_and_scores.py b/experiments/caching_llama/generate_explanations_and_scores.py deleted file mode 100644 index d6134424..00000000 --- a/experiments/caching_llama/generate_explanations_and_scores.py +++ /dev/null @@ -1,154 +0,0 @@ -import asyncio -import json -import os -from functools import partial - -import orjson -import torch -import time -from simple_parsing import ArgumentParser - -from delphi.clients import Offline -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import DefaultExplainer -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe,Pipeline, process_wrapper -from delphi.scorers import FuzzingScorer, DetectionScorer - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - quantization = args.quantization - activations = args.activations - - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features/{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - ### Load client ### - if args.quantization == "awq": - client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.8,max_model_len=5120) - elif args.quantization == "none": - client = Offline("meta-llama/Meta-Llama-3.1-70B-Instruct",max_memory=0.8,max_model_len=5120,num_gpus=4) - #client = Offline("Qwen/Qwen2-72B-Instruct-AWQ") - - ### Build Explainer pipe ### - def explainer_postprocess(result): - - with open(f"results/explanations/{sae_model}/{experiment_name}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.explanation)) - - return result - #try making the directory if it doesn't exist - os.makedirs(f"results/explanations/{sae_model}/{experiment_name}", exist_ok=True) - - explainer_pipe = process_wrapper( - DefaultExplainer( - client, - tokenizer=dataset.tokenizer, - threshold=0.3, - activations=activations - ), - postprocess=explainer_postprocess, - ) - - #save the experiment config - with open(f"results/explanations/{sae_model}/{experiment_name}/experiment_config.json", "w") as f: - print(experiment_cfg.to_dict()) - f.write(json.dumps(experiment_cfg.to_dict())) - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - record.explanation = result.explanation - record.extra_examples = record.not_active - - return record - - def scorer_postprocess(result, score_dir): - record = result.record - with open(f"results/scores/{sae_model}/{experiment_name}/{score_dir}/{record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - - os.makedirs(f"results/scores/{sae_model}/{experiment_name}/detection", exist_ok=True) - os.makedirs(f"results/scores/{sae_model}/{experiment_name}/fuzz", 
exist_ok=True) - - #save the experiment config - with open(f"results/scores/{sae_model}/{experiment_name}/detection/experiment_config.json", "w") as f: - f.write(json.dumps(experiment_cfg.to_dict())) - - with open(f"results/scores/{sae_model}/{experiment_name}/fuzz/experiment_config.json", "w") as f: - f.write(json.dumps(experiment_cfg.to_dict())) - - - scorer_pipe = Pipe(process_wrapper( - DetectionScorer(client, tokenizer=dataset.tokenizer, batch_size=shown_examples,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="detection"), - ), - process_wrapper( - FuzzingScorer(client, tokenizer=dataset.tokenizer, batch_size=shown_examples,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="fuzz"), - ) - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - start_time = time.time() - asyncio.run(pipeline.run(50)) - end_time = time.time() - print(f"Time taken: {end_time - start_time} seconds") - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--activations", type=bool, default=True) - parser.add_argument("--quantization", type=str, default="awq") - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_llama/generate_only_explanations.py b/experiments/caching_llama/generate_only_explanations.py deleted file mode 100644 index 0759e0e4..00000000 --- a/experiments/caching_llama/generate_only_explanations.py +++ /dev/null @@ -1,114 +0,0 @@ -import asyncio -import json -import os -from functools import partial - -import orjson -import torch -import time -from simple_parsing import ArgumentParser - -from delphi.clients import Local -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import DefaultExplainer -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipeline, process_wrapper - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - quantization = args.quantization - - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features_{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - - ### Load client ### - if quantization == 
"awq": - client = Local("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",base_url=f"http://localhost:{args.port}/v1") - else: - client = Local("meta-llama/Meta-Llama-3.1-70B-Instruct",base_url=f"http://localhost:{args.port}/v1") - - - ### Build Explainer pipe ### - def explainer_postprocess(result): - - with open(f"results/explanations/{sae_model}_{experiment_name}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.explanation)) - - return result - #try making the directory if it doesn't exist - os.makedirs(f"results/explanations/{sae_model}_{experiment_name}", exist_ok=True) - - explainer_pipe = process_wrapper( - DefaultExplainer( - client, - tokenizer=dataset.tokenizer, - threshold=0.1 - ), - postprocess=explainer_postprocess, - ) - - #save the experiment config - with open(f"results/explanations/{sae_model}_{experiment_name}/experiment_config.json", "w") as f: - print(experiment_cfg.to_dict()) - f.write(json.dumps(experiment_cfg.to_dict())) - - ### Build Scorer pipe ### - - pipeline = Pipeline( - loader, - explainer_pipe, - ) - start_time = time.time() - asyncio.run(pipeline.run(100)) - end_time = time.time() - print(f"Time taken: {end_time - start_time} seconds") - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--quantization", type=str, default="awq",choices=["awq","bnb","normal"]) - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - parser.add_argument("--port", type=int, default=8001) - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_llama/score_explanations.py b/experiments/caching_llama/score_explanations.py deleted file mode 100644 index ee15648f..00000000 --- a/experiments/caching_llama/score_explanations.py +++ /dev/null @@ -1,139 +0,0 @@ -import asyncio -import os -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser - -from delphi.clients import Local -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader -from delphi.features import ( - FeatureDataset, -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import FuzzingScorer, DetectionScorer -from delphi.utils import ( - load_tokenized_data, - load_tokenizer, -) - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - quantization = args.quantization - - ### Load dataset ### - #TODO: we should probably save the token information when we save the features that - # when we load it we don't have to remember all the details. 
- tokenizer = load_tokenizer("meta-llama/Meta-Llama-3.1-8B") - tokens = load_tokenized_data( - 256, - tokenizer, - "kh4dien/fineweb-100m-sample", - "train", - ) - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features_{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - #TODO: should be configurable - - loader = partial(dataset.load, - constructor=partial( - default_constructor, - tokens=tokens, - n_random=experiment_cfg.n_random, - ctx_len=feature_cfg.example_ctx_len, - max_examples=10_000 - ), - sampler=partial(sample,cfg=experiment_cfg) - ) - - ### Load client ### - - if quantization == "awq": - client = Local("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",base_url=f"http://localhost:{args.port}/v1") - else: - client = Local("meta-llama/Meta-Llama-3.1-70B-Instruct",base_url=f"http://localhost:{args.port}/v1") - - - EXPLAINER_OUT_DIR = f"results/explanations/{sae_model}_{experiment_name}" - - ### Build Explainer pipe ### - explainer_pipe = partial(explanation_loader, explanation_dir=EXPLAINER_OUT_DIR) - - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - record.extra_examples = record.not_active - - return record - - def scorer_postprocess(result, score_dir): - with open(f"results/scores/{sae_model}_{experiment_name}_{score_dir}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - os.makedirs(f"results/scores/{sae_model}_{experiment_name}_detection", exist_ok=True) - os.makedirs(f"results/scores/{sae_model}_{experiment_name}_fuzz", exist_ok=True) - - scorer_pipe = Pipe( - process_wrapper( - DetectionScorer(client, tokenizer=tokenizer, batch_size=shown_examples,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="detection"), - ), - process_wrapper( - FuzzingScorer(client, tokenizer=tokenizer, batch_size=shown_examples,verbose=False,log_prob=True), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="fuzz"), - ), - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - - asyncio.run(pipeline.run(max_processes=10)) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--quantization", type=str, default="awq",choices=["awq","bnb","normal"]) - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - parser.add_argument("--port", type=int, default=8001) - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_llama/simulation_scoring.py b/experiments/caching_llama/simulation_scoring.py deleted file mode 100644 index 94d86532..00000000 --- a/experiments/caching_llama/simulation_scoring.py +++ /dev/null @@ -1,135 +0,0 @@ -import asyncio -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser -import os -from delphi.clients 
import Local,Outlines -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader -from delphi.features import FeatureDataset -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import OpenAISimulator -from delphi.utils import ( - load_tokenized_data, - load_tokenizer, -) - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - batch_size = args.batch_size - all_at_once = args.all_at_once - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - - ### Load dataset ### - #TODO: we should probably save the token information when we save the features that - # when we load it we don't have to remember all the details. - tokenizer = load_tokenizer("meta-llama/Meta-Llama-3.1-8B") - tokens = load_tokenized_data( - 256, - tokenizer, - "kh4dien/fineweb-100m-sample", - "train", - ) - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features_{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - #TODO: should be configurable - - loader = partial(dataset.load, - constructor=partial( - default_constructor, - tokens=tokens, - n_random=experiment_cfg.n_random, - ctx_len=feature_cfg.example_ctx_len, - max_examples=10_000 - ), - sampler=partial(sample,cfg=experiment_cfg) - ) - - - client = Outlines("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",base_url=f"http://localhost:{args.port}") - - - ### Set directories ### - - EXPLAINER_OUT_DIR = f"results/explanations/{sae_model}_{experiment_name}/" - SCORER_OUT_DIR = f"results/scores/{sae_model}_{experiment_name}_{'all_at_once' if all_at_once else 'token_by_token'}/" - - ### Build Explainer pipe ### - - explainer_pipe = partial(explanation_loader, explanation_dir=EXPLAINER_OUT_DIR) - - ### Build Scorer pipe ### - - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - new_test = [] - for i in range(len(record.test)): - new_test.extend(record.test[i][:2]) - record.test = new_test - #record.test = record.test[0][:2] - return record - - - def scorer_postprocess(result): - with open(f"{SCORER_OUT_DIR}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - os.makedirs(f"{SCORER_OUT_DIR}", exist_ok=True) - - - scorer_pipe = Pipe( - process_wrapper( - OpenAISimulator(client, tokenizer=tokenizer, all_at_once=all_at_once), - preprocess=scorer_preprocess, - postprocess=scorer_postprocess, - ) - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer_pipe, - scorer_pipe, - ) - - asyncio.run( - pipeline.run(max_processes=1) - ) - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--batch_size", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - parser.add_argument("--all_at_once", action='store_true', default=False) - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - 
parser.add_argument("--port", type=int, default=8001) - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) diff --git a/experiments/caching_llama/surprisal.py b/experiments/caching_llama/surprisal.py deleted file mode 100644 index f3efbd75..00000000 --- a/experiments/caching_llama/surprisal.py +++ /dev/null @@ -1,116 +0,0 @@ -import asyncio -import os -import random -from functools import partial - -import orjson -import torch -from simple_parsing import ArgumentParser -from transformers import AutoTokenizer,AutoModelForCausalLM -import torch - -from delphi.config import ExperimentConfig, FeatureConfig -from delphi.explainers import explanation_loader -from delphi.features import ( - FeatureDataset, - FeatureLoader -) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample -from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import SurprisalScorer - - -def main(args): - module = args.module - feature_cfg = args.feature_options - experiment_cfg = args.experiment_options - shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature - sae_model = args.model - quantization = args.quantization - - ### Load dataset ### - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir=f"raw_features_{sae_model}", - cfg=feature_cfg, - modules=[module], - features=feature_dict, - ) - - - constructor=partial( - default_constructor, - tokens=dataset.tokens, - n_random=experiment_cfg.n_random, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples - ) - sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) - - - ### Build Explainer pipe ### - - explainer = partial(explanation_loader, explanation_dir=f"results/explanations/{sae_model}_{experiment_name}/") - - ### Build Scorer pipe ### - - def scorer_preprocess(result): - record = result.record - - record.explanation = result.explanation - record.extra_examples = record.not_active - - - return record - - def scorer_postprocess(result, score_dir): - with open(f"results/scores/{sae_model}_{experiment_name}_{score_dir}/{result.record.feature}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) - - os.makedirs(f"results/scores/{sae_model}_{experiment_name}_surprisal", exist_ok=True) - tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B") - model = AutoModelForCausalLM.from_pretrained("meta-llama/Meta-Llama-3.1-70B", device_map="auto",load_in_8bit=True) - model.tokenizer = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-70B") - scorer_pipe = Pipe( - process_wrapper( - SurprisalScorer(model,tokenizer, batch_size=shown_examples,verbose=False), - preprocess=scorer_preprocess, - postprocess=partial(scorer_postprocess, score_dir="surprisal"), - ), - - ) - - ### Build the pipeline ### - - pipeline = Pipeline( - loader, - explainer, - scorer_pipe, - ) - - asyncio.run(pipeline.run(1)) - - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--shown_examples", type=int, default=5) - parser.add_argument("--model", type=str, default="128k") - parser.add_argument("--module", type=str, default=".transformer.h.0") - parser.add_argument("--features", type=int, default=100) - parser.add_argument("--start_feature", type=int, default=0) - parser.add_argument("--experiment_name", type=str, default="default") - 
parser.add_argument("--quantization", type=str, default="awq",choices=["awq","bnb","normal"]) - parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") - parser.add_argument("--port", type=int, default=8001) - args = parser.parse_args() - - experiment_name = args.experiment_name - - - - main(args) From 83fdbf82389e2b657b2b626a94ccfd7d738edb9b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 18:04:41 +0000 Subject: [PATCH 048/132] random -> non_activating --- delphi/config.py | 4 +- delphi/scorers/classifier/classifier.py | 18 +- examples/caching_activations.ipynb | 2279 ++++++++++++++++++++++- examples/generate_explanations.ipynb | 90 +- examples/score_explanations.ipynb | 93 +- examples/server.py | 4 +- 6 files changed, 2343 insertions(+), 145 deletions(-) diff --git a/delphi/config.py b/delphi/config.py index 43cf0ebf..9dc01870 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -19,8 +19,8 @@ class ExperimentConfig(Serializable): """Length of each sampled example sequence. Longer sequences reduce detection scoring performance in weak models.""" - n_random: int = 50 - """Number of random examples to sample.""" + n_non_activating: int = 50 + """Number of non-activating examples to sample.""" train_type: Literal["top", "random", "quantiles"] = "quantiles" """Type of sampler to use for feature explanation generation.""" diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index 5394015a..d6ecde79 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -104,8 +104,8 @@ async def _generate( logger.error(f"Error generating text: {e}") response = None if response is None: - predictions = [-1] * self.batch_size - probabilities = [-1] * self.batch_size + predictions = [-1] * self.n_examples_shown + probabilities = [-1] * self.n_examples_shown else: selections = response.text logprobs = response.logprobs if self.log_prob else None @@ -113,8 +113,8 @@ async def _generate( predictions, probabilities = self._parse(selections, logprobs) except Exception as e: logger.error(f"Parsing selections failed: {e}") - predictions = [-1] * self.batch_size - probabilities = [-1] * self.batch_size + predictions = [-1] * self.n_examples_shown + probabilities = [-1] * self.n_examples_shown results = [] correct = [] @@ -143,11 +143,11 @@ def _parse(self, string, logprobs=None): match = re.search(pattern, string) predictions: list[int] = json.loads(match.group(0)) - assert len(predictions) == self.batch_size + assert len(predictions) == self.n_examples_shown probabilities = ( self._parse_logprobs(logprobs) if logprobs is not None - else [None] * self.batch_size + else [None] * self.n_examples_shown ) return predictions, probabilities @@ -183,7 +183,7 @@ def _parse_logprobs(self, logprobs: list): else: binary_probabilities.append(0.) 
- assert len(binary_probabilities) == self.batch_size + assert len(binary_probabilities) == self.n_examples_shown return binary_probabilities @@ -204,8 +204,8 @@ def _build_prompt( def _batch(self, samples): return [ - samples[i : i + self.batch_size] - for i in range(0, len(samples), self.batch_size) + samples[i : i + self.n_examples_shown] + for i in range(0, len(samples), self.n_examples_shown) ] def call_sync(self, record: FeatureRecord) -> list[ClassifierOutput]: diff --git a/examples/caching_activations.ipynb b/examples/caching_activations.ipynb index df09271e..2f2a2bbd 100644 --- a/examples/caching_activations.ipynb +++ b/examples/caching_activations.ipynb @@ -30,7 +30,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -46,85 +46,2234 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6c0b5075561a4a8eae6c48703b964161", + "model_id": "74f50f88e37d4803a9e156692791596d", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Loading checkpoint shards: 0%| | 0/8 [00:00" + "[None, None, None]" ] }, - "execution_count": 10, + "execution_count": 13, "metadata": {}, "output_type": "execute_result" } @@ -201,8 +214,15 @@ " explainer_pipe,\n", ")\n", "number_of_parallel_latents = 10\n", - "asyncio.run(pipeline.run(number_of_parallel_latents)) # This will start generating the explanations." + "await pipeline.run(number_of_parallel_latents) # This will start generating the explanations." ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { @@ -221,7 +241,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.13" + "version": "3.12.7" } }, "nbformat": 4, diff --git a/examples/score_explanations.ipynb b/examples/score_explanations.ipynb index 641b1031..760f463c 100644 --- a/examples/score_explanations.ipynb +++ b/examples/score_explanations.ipynb @@ -39,12 +39,12 @@ "\n", "\n", "\n", - "API_KEY = os.getenv(\"OPENROUTER_API_KEY\")" + "API_KEY = \"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\"#os.getenv(\"OPENROUTER_API_KEY\")" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [], "source": [ @@ -58,27 +58,12 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 3, "metadata": {}, - "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "106d483d37a84fafb7371e167f94181a", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Resolving data files: 0%| | 0/150 [00:00 Date: Mon, 10 Feb 2025 18:05:05 +0000 Subject: [PATCH 049/132] fix gemma loader --- delphi/autoencoders/Custom/gemmascope.py | 4 ---- delphi/autoencoders/DeepMind/__init__.py | 3 ++- delphi/autoencoders/wrapper.py | 5 ++--- delphi/clients/client.py | 6 +++--- delphi/features/constructors.py | 2 +- 5 files changed, 8 insertions(+), 12 deletions(-) diff --git a/delphi/autoencoders/Custom/gemmascope.py b/delphi/autoencoders/Custom/gemmascope.py index efccbcf8..0564627d 100644 --- a/delphi/autoencoders/Custom/gemmascope.py +++ b/delphi/autoencoders/Custom/gemmascope.py @@ -32,11 +32,7 @@ def forward(self, acts): return recon @classmethod -<<<<<<<< HEAD:sae_auto_interp/autoencoders/Custom/gemmascope.py def from_pretrained(cls, model_name_or_path,position,device): -======== - def from_pretrained(cls, path: str, type: str, device: str) -> nn.Module: ->>>>>>>> 
30d5e27537e1c108a9dda37e87e51dda9bfa4206:delphi/autoencoders/DeepMind/model.py path_to_params = hf_hub_download( repo_id=model_name_or_path, filename=f"{position}/params.npz", diff --git a/delphi/autoencoders/DeepMind/__init__.py b/delphi/autoencoders/DeepMind/__init__.py index e6461e41..536354d5 100644 --- a/delphi/autoencoders/DeepMind/__init__.py +++ b/delphi/autoencoders/DeepMind/__init__.py @@ -12,9 +12,10 @@ def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,in submodules = {} for layer in ae_layers: + model_name = f"google/gemma-scope-9b-pt-{type}" path = f"layer_{layer}/width_{size}/average_l0_{average_l0s[layer]}" - sae = JumpReLUSAE.from_pretrained(path,type,"cuda") + sae = JumpReLUSAE.from_pretrained(model_name,path,"cuda") sae.half() def _forward(sae, x): diff --git a/delphi/autoencoders/wrapper.py b/delphi/autoencoders/wrapper.py index cb30b4b6..299a1d77 100644 --- a/delphi/autoencoders/wrapper.py +++ b/delphi/autoencoders/wrapper.py @@ -1,3 +1,5 @@ + + from dataclasses import dataclass, field from functools import partial from typing import Any, Callable, Dict, List, Literal, Optional, Tuple @@ -17,19 +19,16 @@ class AutoencoderLatents(torch.nn.Module): """ Unified wrapper for different types of autoencoders, compatible with nnsight. """ - def __init__( self, autoencoder: Any, forward_function: Callable, width: int, - hookpoint: str, ) -> None: super().__init__() self.ae = autoencoder self._forward = forward_function self.width = width - self.hookpoint = hookpoint def forward(self, x: torch.Tensor) -> torch.Tensor: return self._forward(x) diff --git a/delphi/clients/client.py b/delphi/clients/client.py index c6fc409f..d41580ff 100644 --- a/delphi/clients/client.py +++ b/delphi/clients/client.py @@ -18,7 +18,7 @@ def __init__(self, model: str): async def generate(self, prompt: Union[str, List[Dict[str, str]]], **kwargs) -> Response: pass - @abstractmethod - async def process_response(self, raw_response: Any) -> Response: - pass + # @abstractmethod + # async def process_response(self, raw_response: Any) -> Response: + # pass diff --git a/delphi/features/constructors.py b/delphi/features/constructors.py index e2f97112..f3bad327 100644 --- a/delphi/features/constructors.py +++ b/delphi/features/constructors.py @@ -122,7 +122,7 @@ def random_non_activating_windows( def default_constructor( record: FeatureRecord, - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], + token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None, buffer_output: BufferOutput, n_not_active: int, ctx_len: int, From 92be5eaab206b74b8d9d6c4ba3bfbf618c5daf9a Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 18:08:39 +0000 Subject: [PATCH 050/132] Remove old key --- examples/score_explanations.ipynb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/score_explanations.ipynb b/examples/score_explanations.ipynb index 760f463c..b3787132 100644 --- a/examples/score_explanations.ipynb +++ b/examples/score_explanations.ipynb @@ -39,7 +39,7 @@ "\n", "\n", "\n", - "API_KEY = \"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\"#os.getenv(\"OPENROUTER_API_KEY\")" + "API_KEY = os.getenv(\"OPENROUTER_API_KEY\")" ] }, { From ae82aa24b376d607b8a9790cd743e541856facd4 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 18:14:47 +0000 Subject: [PATCH 051/132] None instead of -1 --- delphi/scorers/classifier/classifier.py | 21 +++++++++------------ delphi/scorers/classifier/sample.py | 8 ++++---- 2 
files changed, 13 insertions(+), 16 deletions(-) diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index d6ecde79..0374659d 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -104,8 +104,8 @@ async def _generate( logger.error(f"Error generating text: {e}") response = None if response is None: - predictions = [-1] * self.n_examples_shown - probabilities = [-1] * self.n_examples_shown + predictions = [None] * self.n_examples_shown + probabilities = [None] * self.n_examples_shown else: selections = response.text logprobs = response.logprobs if self.log_prob else None @@ -113,21 +113,18 @@ async def _generate( predictions, probabilities = self._parse(selections, logprobs) except Exception as e: logger.error(f"Parsing selections failed: {e}") - predictions = [-1] * self.n_examples_shown - probabilities = [-1] * self.n_examples_shown + predictions = [None] * self.n_examples_shown + probabilities = [None] * self.n_examples_shown results = [] - correct = [] - response = [] - for sample, prediction, probability in zip(batch, predictions, probabilities): result = sample.data result.prediction = prediction - result.correct = prediction == result.ground_truth - correct.append(result.ground_truth) - response.append(prediction) - if probability is not None: - result.probability = probability + if prediction is not None: + result.correct = prediction == result.ground_truth + else: + result.correct = None + result.probability = probability results.append(result) if self.verbose: diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py index d71e426e..be72eb27 100644 --- a/delphi/scorers/classifier/sample.py +++ b/delphi/scorers/classifier/sample.py @@ -29,16 +29,16 @@ class ClassifierOutput: ground_truth: bool """Whether the example is activating or not""" - prediction: bool = False + prediction: bool = False | None """Whether the model predicted the example activating or not""" - highlighted: bool = False + highlighted: bool = False """Whether the sample is highlighted""" - probability: float = 0.0 + probability: float = 0.0 | None """The probability of the example activating""" - correct: bool = False + correct: bool = False | None """Whether the prediction is correct""" From eb964e3861c9958ec771ad84b03c0a8b7424b086 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 10 Feb 2025 18:20:08 +0000 Subject: [PATCH 052/132] Update feature -> latent --- .gitignore | 2 +- delphi/__main__.py | 68 +++---- delphi/autoencoders/Neurons/__init__.py | 4 +- delphi/autoencoders/OpenAI/__init__.py | 2 +- delphi/autoencoders/Sam/model.py | 16 +- delphi/config.py | 26 +-- delphi/counterfactuals/pipeline.py | 38 ++-- delphi/explainers/default/prompts.py | 6 +- delphi/explainers/explainer.py | 20 +-- delphi/{features => latents}/__init__.py | 16 +- delphi/{features => latents}/cache.py | 130 +++++++------- delphi/{features => latents}/constructors.py | 20 +-- .../features.py => latents/latents.py} | 44 ++--- delphi/{features => latents}/loader.py | 168 +++++++++--------- delphi/{features => latents}/neighbours.py | 42 ++--- delphi/{features => latents}/samplers.py | 4 +- delphi/{features => latents}/stats.py | 34 ++-- delphi/scorers/classifier/classifier.py | 8 +- delphi/scorers/classifier/detection.py | 4 +- delphi/scorers/classifier/fuzz.py | 6 +- .../classifier/prompts/detection_prompt.py | 12 +- .../scorers/classifier/prompts/fuzz_prompt.py | 12 +- delphi/scorers/embedding/embedding.py | 8 +- 
delphi/scorers/scorer.py | 10 +- delphi/scorers/simulator/oai_simulator.py | 2 +- delphi/scorers/surprisal/prompts.py | 10 +- delphi/scorers/surprisal/surprisal.py | 6 +- delphi/tests/e2e.py | 26 +-- examples/caching_activations.ipynb | 6 +- examples/example_script.py | 42 ++--- examples/example_server.py | 10 +- examples/generate_explanations.ipynb | 26 +-- examples/latent_contexts.ipynb | 26 +-- examples/latents/.model.layers.10/config.json | 1 + examples/score_explanations.ipynb | 30 ++-- examples/server.py | 46 ++--- experiments/output_features/analysis.ipynb | 4 +- 37 files changed, 468 insertions(+), 467 deletions(-) rename delphi/{features => latents}/__init__.py (61%) rename delphi/{features => latents}/cache.py (67%) rename delphi/{features => latents}/constructors.py (92%) rename delphi/{features/features.py => latents/latents.py} (80%) rename delphi/{features => latents}/loader.py (67%) rename delphi/{features => latents}/neighbours.py (84%) rename delphi/{features => latents}/samplers.py (98%) rename delphi/{features => latents}/stats.py (72%) create mode 100644 examples/latents/.model.layers.10/config.json diff --git a/.gitignore b/.gitignore index 3bd21fdd..845b44cc 100644 --- a/.gitignore +++ b/.gitignore @@ -4,7 +4,7 @@ *.safetensors *.log *workspace -raw_features/* +latents/* results/* extras/* temp/* diff --git a/delphi/__main__.py b/delphi/__main__.py index e3bd9d95..a6f169c9 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -25,15 +25,15 @@ from sparsify.data import chunk_and_tokenize from simple_parsing import field, list_field -from delphi.config import ExperimentConfig, FeatureConfig +from delphi.config import ExperimentConfig, LatentConfig from delphi.explainers import DefaultExplainer -from delphi.features import FeatureDataset, FeatureLoader -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample +from delphi.latents import LatentDataset, LatentLoader +from delphi.latents.constructors import default_constructor +from delphi.latents.samplers import sample from delphi.pipeline import Pipeline, process_wrapper from delphi.clients import Offline, OpenRouter from delphi.config import CacheConfig -from delphi.features import FeatureCache +from delphi.latents import LatentCache from delphi.utils import assert_type from delphi.scorers import FuzzingScorer, DetectionScorer from delphi.pipeline import Pipe @@ -116,8 +116,8 @@ class RunConfig: overwrite: list[str] = list_field() """Whether to overwrite existing parts of the run. 
Options are 'cache', 'scores', and 'visualize'.""" - max_features: int | None = None - """Maximum number of features to explain for each SAE.""" + max_latents: int | None = None + """Maximum number of latents to explain for each SAE.""" load_in_8bit: bool = False """Load the model in 8-bit mode.""" @@ -171,7 +171,7 @@ def load_artifacts(run_cfg: RunConfig): model, # type: ignore run_cfg.sparse_model, run_cfg.hookpoints, - k=run_cfg.max_features, + k=run_cfg.max_latents, ) else: # Doing a hack @@ -193,7 +193,7 @@ def load_artifacts(run_cfg: RunConfig): return run_cfg.hookpoints, submodule_name_to_submodule, model, model.tokenizer async def process_cache( - feature_cfg: FeatureConfig, + latent_cfg: LatentConfig, run_cfg: RunConfig, experiment_cfg: ExperimentConfig, latents_path: Path, @@ -202,11 +202,11 @@ async def process_cache( # The layers to explain hookpoints: list[str], tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - feature_range: Tensor | None, + latent_range: Tensor | None, ): """ - Converts SAE feature activations in on-disk cache in the `latents_path` directory - to feature explanations in the `explanations_path` directory and explanation + Converts SAE latent activations in on-disk cache in the `latents_path` directory + to latent explanations in the `explanations_path` directory and explanation scores in the `fuzz_scores_path` directory. """ explanations_path.mkdir(parents=True, exist_ok=True) @@ -216,19 +216,19 @@ async def process_cache( fuzz_scores_path.mkdir(parents=True, exist_ok=True) detection_scores_path.mkdir(parents=True, exist_ok=True) - if feature_range is None: - feature_dict = None + if latent_range is None: + latent_dict = None else: - feature_dict = { - hook: feature_range for hook in hookpoints + latent_dict = { + hook: latent_range for hook in hookpoints } # The latent range to explain - feature_dict = cast(dict[str, int | Tensor], feature_dict) + latent_dict = cast(dict[str, int | Tensor], latent_dict) - dataset = FeatureDataset( + dataset = LatentDataset( raw_dir=str(latents_path), - cfg=feature_cfg, + cfg=latent_cfg, modules=hookpoints, - features=feature_dict, + latents=latent_dict, tokenizer=tokenizer, ) @@ -260,13 +260,13 @@ async def process_cache( token_loader=None, n_random=experiment_cfg.n_random, ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples, + max_examples=latent_cfg.max_examples, ) sampler = partial(sample, cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) + loader = LatentLoader(dataset, constructor=constructor, sampler=sampler) def explainer_postprocess(result): - with open(explanations_path / f"{result.record.feature}.txt", "wb") as f: + with open(explanations_path / f"{result.record.latent}.txt", "wb") as f: f.write(orjson.dumps(result.explanation)) return result @@ -288,9 +288,9 @@ def scorer_preprocess(result): # Saves the score to a file def scorer_postprocess(result, score_dir): - safe_feature_name = str(result.record.feature).replace("/", "--") + safe_latent_name = str(result.record.latent).replace("/", "--") - with open(score_dir / f"{safe_feature_name}.txt", "wb") as f: + with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: f.write(orjson.dumps(result.score)) scorer_pipe = Pipe( @@ -337,7 +337,7 @@ def populate_cache( filter_bos: bool, ): """ - Populates an on-disk cache in `latents_path` with SAE feature activations. + Populates an on-disk cache in `latents_path` with SAE latent activations. 
""" latents_path.mkdir(parents=True, exist_ok=True) @@ -365,7 +365,7 @@ def populate_cache( tokens = cast(TensorType["batch", "seq"], tokens) - cache = FeatureCache( + cache = LatentCache( hooked_model, submodule_name_to_submodule, batch_size=cfg.batch_size, @@ -381,7 +381,7 @@ def populate_cache( cache.save_config(save_dir=str(latents_path), cfg=cfg, model_name=run_cfg.model) -async def run(experiment_cfg: ExperimentConfig, feature_cfg: FeatureConfig, cache_cfg: CacheConfig, run_cfg: RunConfig): +async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_cfg: CacheConfig, run_cfg: RunConfig): base_path = Path.cwd() / "results" if run_cfg.name: base_path = base_path / run_cfg.name @@ -394,8 +394,8 @@ async def run(experiment_cfg: ExperimentConfig, feature_cfg: FeatureConfig, cach explanations_path = base_path / "explanations" scores_path = base_path / "scores" - feature_range = ( - torch.arange(run_cfg.max_features) if run_cfg.max_features else None + latent_range = ( + torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None ) hookpoints, submodule_name_to_submodule, hooked_model, tokenizer = load_artifacts( @@ -425,7 +425,7 @@ async def run(experiment_cfg: ExperimentConfig, feature_cfg: FeatureConfig, cach or "scores" in run_cfg.overwrite ): await process_cache( - feature_cfg, + latent_cfg, run_cfg, experiment_cfg, latents_path, @@ -433,7 +433,7 @@ async def run(experiment_cfg: ExperimentConfig, feature_cfg: FeatureConfig, cach scores_path, hookpoints, tokenizer, - feature_range, + latent_range, ) else: print(f"Files found in {scores_path}, skipping...") @@ -442,9 +442,9 @@ async def run(experiment_cfg: ExperimentConfig, feature_cfg: FeatureConfig, cach if __name__ == "__main__": parser = ArgumentParser() parser.add_arguments(ExperimentConfig, dest="experiment_cfg") - parser.add_arguments(FeatureConfig, dest="feature_cfg") + parser.add_arguments(LatentConfig, dest="latent_cfg") parser.add_arguments(CacheConfig, dest="cache_cfg") parser.add_arguments(RunConfig, dest="run_cfg") args = parser.parse_args() - asyncio.run(run(args.experiment_cfg, args.feature_cfg, args.cache_cfg, args.run_cfg)) + asyncio.run(run(args.experiment_cfg, args.latent_cfg, args.cache_cfg, args.run_cfg)) diff --git a/delphi/autoencoders/Neurons/__init__.py b/delphi/autoencoders/Neurons/__init__.py index bd4e06ff..bda1814b 100644 --- a/delphi/autoencoders/Neurons/__init__.py +++ b/delphi/autoencoders/Neurons/__init__.py @@ -44,8 +44,8 @@ def load_llama3_neurons( for layer in layers: submodule = model.model.layers[layer].mlp.down_proj - submodule.ae = TopKNeurons(k, input_dim=submodule.in_features, rotate=rotate, seed=seed,device=DEVICE) - submodule.ae.width = submodule.in_features + submodule.ae = TopKNeurons(k, input_dim=submodule.in_latents, rotate=rotate, seed=seed,device=DEVICE) + submodule.ae.width = submodule.in_latents submodule_dict[layer] = submodule with model.edit(" ") as edited: diff --git a/delphi/autoencoders/OpenAI/__init__.py b/delphi/autoencoders/OpenAI/__init__.py index 28d8d03e..9587c729 100644 --- a/delphi/autoencoders/OpenAI/__init__.py +++ b/delphi/autoencoders/OpenAI/__init__.py @@ -66,7 +66,7 @@ def load_random_oai_autoencoders( for layer in ae_layers: submodule = model.model.layers[layer] - sae = Autoencoder(n_latents, submodule.mlp.gate_proj.in_features, activation=ACTIVATIONS_CLASSES["TopK"](k=k), normalize=False, tied=False) + sae = Autoencoder(n_latents, submodule.mlp.gate_proj.in_latents, activation=ACTIVATIONS_CLASSES["TopK"](k=k), normalize=False, 
tied=False) sae.to(DEVICE).to(model.dtype) # Randomize the weights sae.encoder.weight.data.normal_(0, 1, generator=generator) diff --git a/delphi/autoencoders/Sam/model.py b/delphi/autoencoders/Sam/model.py index 8dc1db78..e5bf8902 100644 --- a/delphi/autoencoders/Sam/model.py +++ b/delphi/autoencoders/Sam/model.py @@ -13,7 +13,7 @@ class Dictionary(ABC): A dictionary consists of a collection of vectors, an encoder, and a decoder. """ - dict_size: int # number of features in the dictionary + dict_size: int # number of latents in the dictionary activation_dim: int # dimension of the activation vectors @abstractmethod @@ -55,17 +55,17 @@ def encode(self, x): def decode(self, f): return self.decoder(f) + self.bias - def forward(self, x, output_features=False, ghost_mask=None): + def forward(self, x, output_latents=False, ghost_mask=None): """ Forward pass of an autoencoder. x : activations to be autoencoded - output_features : if True, return the encoded features as well as the decoded x - ghost_mask : if not None, run in "ghost mode" where features are masked + output_latents : if True, return the encoded latents as well as the decoded x + ghost_mask : if not None, run in "ghost mode" where latents are masked """ if ghost_mask is None: # normal mode f = self.encode(x) x_hat = self.decode(f) - if output_features: + if output_latents: return x_hat, f else: return x_hat @@ -79,7 +79,7 @@ def forward(self, x, output_features=False, ghost_mask=None): f_ghost ) # note that this only applies the decoder weight matrix, no bias x_hat = self.decode(f) - if output_features: + if output_latents: return x_hat, x_ghost, f else: return x_hat, x_ghost @@ -113,8 +113,8 @@ def encode(self, x): def decode(self, f): return f - def forward(self, x, output_features=False, ghost_mask=None): - if output_features: + def forward(self, x, output_latents=False, ghost_mask=None): + if output_latents: return x, x else: return x diff --git a/delphi/config.py b/delphi/config.py index 9dc01870..6dd92b38 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -7,13 +7,13 @@ @dataclass class ExperimentConfig(Serializable): n_examples_train: int = 40 - """Number of examples to sample for feature explanation generation.""" + """Number of examples to sample for latent explanation generation.""" n_examples_test: int = 50 - """Number of examples to sample for feature explanation testing.""" + """Number of examples to sample for latent explanation testing.""" n_quantiles: int = 10 - """Number of feature activation quantiles to sample.""" + """Number of latent activation quantiles to sample.""" example_ctx_len: int = 32 """Length of each sampled example sequence. Longer sequences @@ -23,36 +23,36 @@ class ExperimentConfig(Serializable): """Number of non-activating examples to sample.""" train_type: Literal["top", "random", "quantiles"] = "quantiles" - """Type of sampler to use for feature explanation generation.""" + """Type of sampler to use for latent explanation generation.""" test_type: Literal["quantiles", "activation"] = "quantiles" - """Type of sampler to use for feature explanation testing.""" + """Type of sampler to use for latent explanation testing.""" @dataclass -class FeatureConfig(Serializable): +class LatentConfig(Serializable): width: int = 131_072 - """Number of features in each autoencoder""" + """Number of latents in each autoencoder""" min_examples: int = 200 - """Minimum number of examples to generate for a single feature. + """Minimum number of examples to generate for a single latent. 
If the number of activating examples is less than this, the - feature will not be explained and scored.""" + latent will not be explained and scored.""" max_examples: int = 10_000 - """Maximum number of examples to generate for a single feature.""" + """Maximum number of examples to generate for a single latent.""" n_splits: int = 5 - """Number of splits that features will be divided into.""" + """Number of splits that latents will be divided into.""" @dataclass class CacheConfig(Serializable): dataset_repo: str = "EleutherAI/rpj-v2-sample" - """Dataset repository to use for generating feature activations.""" + """Dataset repository to use for generating latent activations.""" dataset_split: str = "train[:1%]" - """Dataset split to use for generating feature activations.""" + """Dataset split to use for generating latent activations.""" dataset_name: str = "" """Dataset name to use.""" diff --git a/delphi/counterfactuals/pipeline.py b/delphi/counterfactuals/pipeline.py index 12b56d22..8b24f252 100644 --- a/delphi/counterfactuals/pipeline.py +++ b/delphi/counterfactuals/pipeline.py @@ -9,14 +9,14 @@ from typing import Callable, Literal import fire import torch -from ..config import ExperimentConfig, FeatureConfig -from ..features import ( - FeatureDataset, - FeatureLoader +from ..config import ExperimentConfig, LatentConfig +from ..latents import ( + LatentDataset, + LatentLoader ) from ..autoencoders.OpenAI.model import TopK, ACTIVATIONS_CLASSES, Autoencoder -from ..features.constructors import default_constructor -from ..features.samplers import sample +from ..latents.constructors import default_constructor +from ..latents.samplers import sample from ..autoencoders.DeepMind.model import JumpReLUSAE from . import ( ExplainerNeuronFormatter, @@ -28,7 +28,7 @@ expl_given_generation_score, LAYER_TO_L0 ) -from ..features import FeatureDataset +from ..latents import LatentDataset from functools import partial import random from transformers import AutoModelForCausalLM, AutoTokenizer @@ -40,22 +40,22 @@ PATH_ROOT = Path(__file__).parent.parent.parent -def get_feature_loader(feat_layer, n_feats, n_train, n_test, n_quantiles, width, latents: Literal["sae", "random"] = "sae"): +def get_latent_loader(feat_layer, n_feats, n_train, n_test, n_quantiles, width, latents: Literal["sae", "random"] = "sae"): module = f".model.layers.{feat_layer}" if latents == "sae": assert width == 131072 raw_dir = PATH_ROOT / "cache/gemma_sae_131k" else: raw_dir = PATH_ROOT / "cache/gemma_topk" - feature_dict = {f"{module}": torch.arange(0, n_feats)} - feature_cfg = FeatureConfig(width=width, n_splits=5, max_examples=100000, min_examples=200) + latent_dict = {f"{module}": torch.arange(0, n_feats)} + latent_cfg = LatentConfig(width=width, n_splits=5, max_examples=100000, min_examples=200) experiment_cfg = ExperimentConfig(n_random=0, example_ctx_len=64, n_quantiles=n_quantiles, n_examples_test=0, n_examples_train=(n_train + n_test) // n_quantiles, train_type="quantiles", test_type="even") - dataset = FeatureDataset( + dataset = LatentDataset( raw_dir=str(raw_dir), - cfg=feature_cfg, + cfg=latent_cfg, modules=[module], - features=feature_dict, # type: ignore + latents=latent_dict, # type: ignore ) constructor=partial( @@ -63,11 +63,11 @@ def get_feature_loader(feat_layer, n_feats, n_train, n_test, n_quantiles, width, tokens=dataset.tokens, # type: ignore n_random=experiment_cfg.n_random, ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples + max_examples=latent_cfg.max_examples ) 
sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) + loader = LatentLoader(dataset, constructor=constructor, sampler=sampler) return loader @@ -171,7 +171,7 @@ def main( with open(config_save_path, "w") as f: json.dump(config, f) - loader = get_feature_loader(feat_layer, n_feats, n_train, n_test, n_quantiles, sae_width, latents=latents) + loader = get_latent_loader(feat_layer, n_feats, n_train, n_test, n_quantiles, sae_width, latents=latents) subject = AutoModelForCausalLM.from_pretrained(subject_name, torch_dtype=torch.bfloat16).to(device) subject_tokenizer = AutoTokenizer.from_pretrained(subject_name) @@ -290,7 +290,7 @@ def load_explainer(): # we use the test set to tune the intervention strengths because we want the KL to be ~exactly `kl_threshold` on the test set for iter, (record, scorer_examples) in enumerate(tqdm(zip(loader, all_scorer_examples), desc="Tuning intervention strengths", total=len(all_scorer_examples))): garbage_collect() - feat_idx = record.feature.feature_index + feat_idx = record.latent.latent_index ids_s = [ids for ids, _ in scorer_examples] if zero_ablate: intervention_strength = None @@ -306,7 +306,7 @@ def load_explainer(): # do generation for iter, (record, scorer_examples) in enumerate(tqdm(zip(loader, all_scorer_examples), desc="Generating completions", total=len(all_scorer_examples))): garbage_collect() - feat_idx = record.feature.feature_index + feat_idx = record.latent.latent_index # get completions completions = [] @@ -327,7 +327,7 @@ def load_explainer(): # get intervention results for iter, (record, explainer_examples) in enumerate(tqdm(zip(loader, all_explainer_examples), desc="Running interventions for explainer", total=len(all_explainer_examples))): garbage_collect() - feat_idx = record.feature.feature_index + feat_idx = record.latent.latent_index intervention_examples = [] for ids, act in explainer_examples: diff --git a/delphi/explainers/default/prompts.py b/delphi/explainers/default/prompts.py index ed6bdbe5..8cbad111 100644 --- a/delphi/explainers/default/prompts.py +++ b/delphi/explainers/default/prompts.py @@ -1,7 +1,7 @@ ### SYSTEM PROMPT ### SYSTEM_SINGLE_TOKEN = """Your job is to look for patterns in text. You will be given a list of WORDS, your task is to provide an explanation for what pattern best describes them. Here are some guidelines: -- Produce a specific final description for the features common in the examples, and what patterns you found. +- Produce a specific final description for the latents common in the examples, and what patterns you found. - Don't focus on giving examples of important tokens, if the examples are uninformative, you don't need to mention them. - Do not make lists of possible explanations. Keep your explanations short and concise. - The last line of your response must be the formatted explanation, using [EXPLANATION]: @@ -28,7 +28,7 @@ You will be given a list of text examples on which special words are selected and between delimiters like <>. If a sequence of consecutive tokens all are important, the entire sequence of tokens will be contained between delimiters <>. How important each token is for the behavior is listed after each example in parentheses. -- Try to produce a concise final description. Simply describe the text features that are common in the examples, and what patterns you found. +- Try to produce a concise final description. Simply describe the text latents that are common in the examples, and what patterns you found. 
- If the examples are uninformative, you don't need to mention them. Don't focus on giving examples of important tokens, but try to summarize the patterns found in the examples. - Do not mention the marker tokens (<< >>) in your explanation. - Do not make lists of possible explanations. Keep your explanations short and concise. @@ -43,7 +43,7 @@ 1.Find the special words that are selected in the examples and list a couple of them. Search for patterns in these words, if there are any. Don't list more than 5 words. -2. Write down general shared features of the text examples. This could be related to the full sentence or to the words surrounding the marked words. +2. Write down general shared latents of the text examples. This could be related to the full sentence or to the words surrounding the marked words. 3. Formulate an hypothesis and write down the final explanation using [EXPLANATION]:. diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 4a1c58be..74ba1e4d 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -6,25 +6,25 @@ import aiofiles -from ..features.features import FeatureRecord +from ..latents.latents import LatentRecord class ExplainerResult(NamedTuple): - record: FeatureRecord - """Feature record passed through to scorer.""" + record: LatentRecord + """Latent record passed through to scorer.""" explanation: str - """Generated explanation for feature.""" + """Generated explanation for latent.""" class Explainer(ABC): @abstractmethod - def __call__(self, record: FeatureRecord) -> ExplainerResult: + def __call__(self, record: LatentRecord) -> ExplainerResult: pass -async def explanation_loader(record: FeatureRecord, explanation_dir: str) -> ExplainerResult: - async with aiofiles.open(f'{explanation_dir}/{record.feature}.txt', 'r') as f: +async def explanation_loader(record: LatentRecord, explanation_dir: str) -> ExplainerResult: + async with aiofiles.open(f'{explanation_dir}/{record.latent}.txt', 'r') as f: explanation = json.loads(await f.read()) return ExplainerResult( @@ -32,10 +32,10 @@ async def explanation_loader(record: FeatureRecord, explanation_dir: str) -> Exp explanation=explanation ) -async def random_explanation_loader(record: FeatureRecord, explanation_dir: str) -> ExplainerResult: +async def random_explanation_loader(record: LatentRecord, explanation_dir: str) -> ExplainerResult: explanations = [f for f in os.listdir(explanation_dir) if f.endswith(".txt")] - if str(record.feature) in explanations: - explanations.remove(str(record.feature)) + if str(record.latent) in explanations: + explanations.remove(str(record.latent)) random_explanation = random.choice(explanations) async with aiofiles.open(f'{explanation_dir}/{random_explanation}', 'r') as f: explanation = json.loads(await f.read()) diff --git a/delphi/features/__init__.py b/delphi/latents/__init__.py similarity index 61% rename from delphi/features/__init__.py rename to delphi/latents/__init__.py index 6e615b0c..302f5cbc 100644 --- a/delphi/features/__init__.py +++ b/delphi/latents/__init__.py @@ -1,19 +1,19 @@ -from .cache import FeatureCache +from .cache import LatentCache from .constructors import ( default_constructor, pool_max_activation_windows, random_non_activating_windows, ) -from .features import Example, Feature, FeatureRecord -from .loader import FeatureDataset, FeatureLoader +from .latents import Example, Latent, LatentRecord +from .loader import LatentDataset, LatentLoader from .samplers import sample from .stats import get_neighbors, 
unigram __all__ = [ - "FeatureCache", - "FeatureDataset", - "Feature", - "FeatureRecord", + "LatentCache", + "LatentDataset", + "Latent", + "LatentRecord", "Example", "pool_max_activation_windows", "random_non_activating_windows", @@ -21,5 +21,5 @@ "sample", "get_neighbors", "unigram", - "FeatureLoader" + "LatentLoader" ] diff --git a/delphi/features/cache.py b/delphi/latents/cache.py similarity index 67% rename from delphi/features/cache.py rename to delphi/latents/cache.py index a060f9fd..ff6d0942 100644 --- a/delphi/features/cache.py +++ b/delphi/latents/cache.py @@ -14,7 +14,7 @@ class Cache: """ - The Cache class stores feature locations and activations for modules. + The Cache class stores latent locations and activations for modules. It provides methods for adding, saving, and retrieving non-zero activations. """ @@ -25,18 +25,18 @@ def __init__( Initialize the Cache. Args: - filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features. + filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific latents. batch_size (int): Size of batches for processing. Defaults to 64. """ - self.feature_locations = defaultdict(list) - self.feature_activations = defaultdict(list) + self.latent_locations = defaultdict(list) + self.latent_activations = defaultdict(list) self.tokens = defaultdict(list) self.filters = filters self.batch_size = batch_size def add( self, - latents: TensorType["batch", "sequence", "feature"], + latents: TensorType["batch", "sequence", "latent"], tokens: TensorType["batch", "sequence"], batch_number: int, module_path: str, @@ -45,53 +45,53 @@ def add( Add the latents from a module to the cache. Args: - latents (TensorType["batch", "sequence", "feature"]): Latent activations. + latents (TensorType["batch", "sequence", "latent"]): Latent activations. tokens (TensorType["batch", "sequence"]): Input tokens. batch_number (int): Current batch number. module_path (str): Path of the module. """ - feature_locations, feature_activations = self.get_nonzeros(latents, module_path) - feature_locations = feature_locations.cpu() - feature_activations = feature_activations.cpu() + latent_locations, latent_activations = self.get_nonzeros(latents, module_path) + latent_locations = latent_locations.cpu() + latent_activations = latent_activations.cpu() tokens = tokens.cpu() # Adjust batch indices - feature_locations[:, 0] += batch_number * self.batch_size - self.feature_locations[module_path].append(feature_locations) - self.feature_activations[module_path].append(feature_activations) + latent_locations[:, 0] += batch_number * self.batch_size + self.latent_locations[module_path].append(latent_locations) + self.latent_activations[module_path].append(latent_activations) self.tokens[module_path].append(tokens) def save(self): """ - Concatenate the feature locations and activations for all modules. + Concatenate the latent locations and activations for all modules. 
""" - for module_path in self.feature_locations.keys(): - self.feature_locations[module_path] = torch.cat( - self.feature_locations[module_path], dim=0 + for module_path in self.latent_locations.keys(): + self.latent_locations[module_path] = torch.cat( + self.latent_locations[module_path], dim=0 ) - self.feature_activations[module_path] = torch.cat( - self.feature_activations[module_path], dim=0 + self.latent_activations[module_path] = torch.cat( + self.latent_activations[module_path], dim=0 ) self.tokens[module_path] = torch.cat( self.tokens[module_path], dim=0 ) - def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]): + def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "latent"]): """ Get non-zero activations for large batches that exceed int32 max value. Args: - latents (TensorType["batch", "seq", "feature"]): Input latent activations. + latents (TensorType["batch", "seq", "latent"]): Input latent activations. Returns: - Tuple[torch.Tensor, torch.Tensor]: Non-zero feature locations and activations. + Tuple[torch.Tensor, torch.Tensor]: Non-zero latent locations and activations. """ # Calculate the maximum batch size that fits within sys.maxsize max_batch_size = torch.iinfo(torch.int32).max // (latents.shape[1] * latents.shape[2]) - nonzero_feature_locations = [] - nonzero_feature_activations = [] + nonzero_latent_locations = [] + nonzero_latent_activations = [] for i in range(0, latents.shape[0], max_batch_size): batch = latents[i:i+max_batch_size] @@ -102,49 +102,49 @@ def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]): # Adjust indices to account for batching batch_locations[:, 0] += i - nonzero_feature_locations.append(batch_locations) - nonzero_feature_activations.append(batch_activations) + nonzero_latent_locations.append(batch_locations) + nonzero_latent_activations.append(batch_activations) # Concatenate results - nonzero_feature_locations = torch.cat(nonzero_feature_locations, dim=0) - nonzero_feature_activations = torch.cat(nonzero_feature_activations, dim=0) - return nonzero_feature_locations, nonzero_feature_activations + nonzero_latent_locations = torch.cat(nonzero_latent_locations, dim=0) + nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0) + return nonzero_latent_locations, nonzero_latent_activations def get_nonzeros( - self, latents: TensorType["batch", "seq", "feature"], module_path: str + self, latents: TensorType["batch", "seq", "latent"], module_path: str ): """ - Get the nonzero feature locations and activations. + Get the nonzero latent locations and activations. Args: - latents (TensorType["batch", "seq", "feature"]): Input latent activations. + latents (TensorType["batch", "seq", "latent"]): Input latent activations. module_path (str): Path of the module. Returns: - Tuple[torch.Tensor, torch.Tensor]: Non-zero feature locations and activations. + Tuple[torch.Tensor, torch.Tensor]: Non-zero latent locations and activations. 
""" size = latents.shape[1] * latents.shape[0] * latents.shape[2] if size > torch.iinfo(torch.int32).max: - nonzero_feature_locations, nonzero_feature_activations = self.get_nonzeros_batch(latents) + nonzero_latent_locations, nonzero_latent_activations = self.get_nonzeros_batch(latents) else: - nonzero_feature_locations = torch.nonzero(latents.abs() > 1e-5) - nonzero_feature_activations = latents[latents.abs() > 1e-5] + nonzero_latent_locations = torch.nonzero(latents.abs() > 1e-5) + nonzero_latent_activations = latents[latents.abs() > 1e-5] - # Return all nonzero features if no filter is provided + # Return all nonzero latents if no filter is provided if self.filters is None: - return nonzero_feature_locations, nonzero_feature_activations + return nonzero_latent_locations, nonzero_latent_activations - # Return only the selected features if a filter is provided + # Return only the selected latents if a filter is provided else: - selected_features = self.filters[module_path] - mask = torch.isin(nonzero_feature_locations[:, 2], selected_features) + selected_latents = self.filters[module_path] + mask = torch.isin(nonzero_latent_locations[:, 2], selected_latents) - return nonzero_feature_locations[mask], nonzero_feature_activations[mask] + return nonzero_latent_locations[mask], nonzero_latent_activations[mask] -class FeatureCache: +class LatentCache: """ - FeatureCache manages the caching of feature activations for a model. + LatentCache manages the caching of latent activations for a model. It handles the process of running the model, storing activations, and saving them to disk. """ @@ -156,13 +156,13 @@ def __init__( filters: Dict[str, TensorType["indices"]] = None, ): """ - Initialize the FeatureCache. + Initialize the LatentCache. Args: - model: The model to cache features for. + model: The model to cache latents for. submodule_dict (Dict): Dictionary of submodules to cache. batch_size (int): Size of batches for processing. - filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific features. + filters (Dict[str, TensorType["indices"]], optional): Filters for selecting specific latents. """ self.model = model self.submodule_dict = submodule_dict @@ -205,7 +205,7 @@ def filter_submodules(self, filters: Dict[str, TensorType["indices"]]): Filter submodules based on the provided filters. Args: - filters (Dict[str, TensorType["indices"]]): Filters for selecting specific features. + filters (Dict[str, TensorType["indices"]]): Filters for selecting specific latents. """ filtered_submodules = {} for module_path in self.submodule_dict.keys(): @@ -215,7 +215,7 @@ def filter_submodules(self, filters: Dict[str, TensorType["indices"]]): def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): """ - Run the feature caching process. + Run the latent caching process. Args: n_tokens (int): Total number of tokens to process. @@ -227,7 +227,7 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): total_batches = len(token_batches) tokens_per_batch = token_batches[0].numel() - with tqdm(total=total_batches, desc="Caching features") as pbar: + with tqdm(total=total_batches, desc="Caching latents") as pbar: for batch_number, batch in enumerate(token_batches): total_tokens += tokens_per_batch @@ -251,18 +251,18 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): def save(self, save_dir, save_tokens: bool = True): """ - Save the cached features to disk. + Save the cached latents to disk. Args: - save_dir (str): Directory to save the features. 
+ save_dir (str): Directory to save the latents. save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True. """ - for module_path in self.cache.feature_locations.keys(): + for module_path in self.cache.latent_locations.keys(): output_file = f"{save_dir}/{module_path}.safetensors" data = { - "locations": self.cache.feature_locations[module_path], - "activations": self.cache.feature_activations[module_path], + "locations": self.cache.latent_locations[module_path], + "activations": self.cache.latent_activations[module_path], } if save_tokens: data["tokens"] = self.cache.tokens[module_path] @@ -271,7 +271,7 @@ def save(self, save_dir, save_tokens: bool = True): def _generate_split_indices(self, n_splits): """ - Generate indices for splitting the feature space. + Generate indices for splitting the latent space. Args: n_splits (int): Number of splits to generate. @@ -286,7 +286,7 @@ def _generate_split_indices(self, n_splits): def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True): """ - Save the cached non-zero feature activations and locations in splits. + Save the cached non-zero latent activations and locations in splits. Args: n_splits (int): Number of splits to generate. @@ -294,19 +294,19 @@ def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True): save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True. """ split_indices = self._generate_split_indices(n_splits) - for module_path in self.cache.feature_locations.keys(): - feature_locations = self.cache.feature_locations[module_path] - feature_activations = self.cache.feature_activations[module_path] + for module_path in self.cache.latent_locations.keys(): + latent_locations = self.cache.latent_locations[module_path] + latent_activations = self.cache.latent_activations[module_path] tokens = self.cache.tokens[module_path].numpy() - feature_indices = feature_locations[:, 2] + latent_indices = latent_locations[:, 2] for start, end in split_indices: - mask = (feature_indices >= start) & (feature_indices <= end) + mask = (latent_indices >= start) & (latent_indices <= end) - masked_activations = feature_activations[mask].half().numpy() + masked_activations = latent_activations[mask].half().numpy() - masked_locations = feature_locations[mask].numpy() + masked_locations = latent_locations[mask].numpy() # Optimization to reduce the max value to enable a smaller dtype masked_locations[:, 2] = masked_locations[:, 2] - start.item() @@ -332,14 +332,14 @@ def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True): def save_config(self, save_dir: str, cfg: CacheConfig, model_name: str): """ - Save the configuration for the cached features. + Save the configuration for the cached latents. Args: save_dir (str): Directory to save the configuration. cfg (CacheConfig): Configuration object. model_name (str): Name of the model. 
""" - for module_path in self.cache.feature_locations.keys(): + for module_path in self.cache.latent_locations.keys(): config_file = f"{save_dir}/{module_path}/config.json" with open(config_file, "w") as f: config_dict = cfg.to_dict() diff --git a/delphi/features/constructors.py b/delphi/latents/constructors.py similarity index 92% rename from delphi/features/constructors.py rename to delphi/latents/constructors.py index f3bad327..75f23a18 100644 --- a/delphi/features/constructors.py +++ b/delphi/latents/constructors.py @@ -2,7 +2,7 @@ from torchtyping import TensorType from typing import Callable, Optional -from .features import FeatureRecord, prepare_examples +from .latents import LatentRecord, prepare_examples from .loader import BufferOutput @@ -42,10 +42,10 @@ def pool_max_activation_windows( max_examples: int, ): """ - Pool max activation windows from the buffer output and update the feature record. + Pool max activation windows from the buffer output and update the latent record. Args: - record (FeatureRecord): The feature record to update. + record (LatentRecord): The latent record to update. buffer_output (BufferOutput): The buffer output containing activations and locations. tokens (TensorType["batch", "seq"]): The input tokens. ctx_len (int): The context length. @@ -76,17 +76,17 @@ def pool_max_activation_windows( record.examples = prepare_examples(token_windows, activation_windows) def random_non_activating_windows( - record: FeatureRecord, + record: LatentRecord, tokens: TensorType["batch", "seq"], buffer_output: BufferOutput, ctx_len: int, n_not_active: int, ): """ - Generate random non-activating sequence windows and update the feature record. + Generate random non-activating sequence windows and update the latent record. Args: - record (FeatureRecord): The feature record to update. + record (LatentRecord): The latent record to update. tokens (TensorType["batch", "seq"]): The input tokens. buffer_output (BufferOutput): The buffer output containing activations and locations. ctx_len (int): The context length. @@ -121,7 +121,7 @@ def random_non_activating_windows( ) def default_constructor( - record: FeatureRecord, + record: LatentRecord, token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None, buffer_output: BufferOutput, n_not_active: int, @@ -129,10 +129,10 @@ def default_constructor( max_examples: int, ): """ - Construct feature examples using pool max activation windows and random activation windows. + Construct latent examples using pool max activation windows and random activation windows. Args: - record (FeatureRecord): The feature record to update. + record (LatentRecord): The latent record to update. token_loader (Optional[Callable[[], TensorType["batch", "seq"]]]): An optional function that creates the dataset tokens. buffer_output (BufferOutput): The buffer output containing activations and locations. @@ -154,7 +154,7 @@ def default_constructor( "` tokens=dataset.tokens`,\n" "pass\n" "` token_loader=lambda: dataset.load_tokens()`,\n" - "(assuming `dataset` is a `FeatureDataset` instance)." + "(assuming `dataset` is a `LatentDataset` instance)." ) pool_max_activation_windows( record, diff --git a/delphi/features/features.py b/delphi/latents/latents.py similarity index 80% rename from delphi/features/features.py rename to delphi/latents/latents.py index fd5c4b9b..e7c0df2b 100644 --- a/delphi/features/features.py +++ b/delphi/latents/latents.py @@ -9,7 +9,7 @@ @dataclass class Example: """ - A single example of feature data. 
+ A single example of latent data. Attributes: tokens (TensorType["seq"]): Tokenized input sequence. @@ -52,48 +52,48 @@ def prepare_examples(tokens, activations): ] @dataclass -class Feature: +class Latent: """ - A feature extracted from a model's activations. + A latent extracted from a model's activations. Attributes: - module_name (str): The module name associated with the feature. - feature_index (int): The index of the feature within the module. + module_name (str): The module name associated with the latent. + latent_index (int): The index of the latent within the module. """ module_name: str - feature_index: int + latent_index: int def __repr__(self) -> str: """ - Return a string representation of the feature. + Return a string representation of the latent. Returns: - str: A string representation of the feature. + str: A string representation of the latent. """ - return f"{self.module_name}_feature{self.feature_index}" + return f"{self.module_name}_latent{self.latent_index}" -class FeatureRecord: +class LatentRecord: """ - A record of feature data. + A record of latent data. Attributes: - feature (Feature): The feature associated with the record. - examples: list[Example]: Example sequences where the feature activations, + latent (Latent): The latent associated with the record. + examples: list[Example]: Example sequences where the latent activations, assumed to be sorted in descending order by max activation """ def __init__( self, - feature: Feature, + latent: Latent, ): """ - Initialize the feature record. + Initialize the latent record. Args: - feature (Feature): The feature associated with the record. + latent (Latent): The latent associated with the record. """ - self.feature = feature + self.latent = latent self.examples: list[Example] = [] self.not_active: list[Example] = [] self.train: list[list[Example]] = [] @@ -102,7 +102,7 @@ def __init__( @property def max_activation(self): """ - Get the maximum activation value for the feature. + Get the maximum activation value for the latent. Returns: float: The maximum activation value. @@ -111,13 +111,13 @@ def max_activation(self): def save(self, directory: str, save_examples=False): """ - Save the feature record to a file. + Save the latent record to a file. Args: directory (str): The directory to save the file in. save_examples (bool): Whether to save the examples. Defaults to False. """ - path = f"{directory}/{self.feature}.json" + path = f"{directory}/{self.latent}.json" serializable = self.__dict__ if not save_examples: @@ -125,7 +125,7 @@ def save(self, directory: str, save_examples=False): serializable.pop("train") serializable.pop("test") - serializable.pop("feature") + serializable.pop("latent") with bf.BlobFile(path, "wb") as f: f.write(orjson.dumps(serializable)) @@ -137,7 +137,7 @@ def display( n: int = 10, ) -> str: """ - Display the feature record in a formatted string. + Display the latent record in a formatted string. Args: tokenizer (AutoTokenizer): The tokenizer to use for decoding. 
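For reference, a minimal usage sketch of the classes renamed in the file above, assuming only the post-rename API visible in this diff (Latent, LatentRecord and Example imported from delphi.latents, as examples/server.py does later in this patch); the module path, latent index and tensor values are illustrative placeholders rather than real cache data:

    import torch
    from delphi.latents import Example, Latent, LatentRecord

    # A latent is identified by its module path and its index within that module;
    # its repr becomes "<module>_latent<index>" after this rename.
    latent = Latent(module_name=".model.layers.10", latent_index=6517)
    record = LatentRecord(latent)

    # Examples pair a token window with that latent's per-token activations,
    # mirroring how examples/server.py builds them from raw activation data.
    record.train = [
        Example(torch.tensor([101, 2023, 102]), torch.tensor([0.0, 3.2, 0.0])),
    ]

    print(record.latent)  # ".model.layers.10_latent6517"
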
diff --git a/delphi/features/loader.py b/delphi/latents/loader.py similarity index 67% rename from delphi/features/loader.py rename to delphi/latents/loader.py index 65724f90..5bbedd24 100644 --- a/delphi/features/loader.py +++ b/delphi/latents/loader.py @@ -14,8 +14,8 @@ load_tokenized_data, ) -from ..config import FeatureConfig -from ..features.features import Feature, FeatureRecord +from ..config import LatentConfig +from .latents import Latent, LatentRecord class BufferOutput(NamedTuple): @@ -23,12 +23,12 @@ class BufferOutput(NamedTuple): Represents the output of a TensorBuffer. Attributes: - feature (Feature): The feature associated with this output. - locations (TensorType["locations", 2]): Tensor of feature locations. - activations (TensorType["locations"]): Tensor of feature activations. + latent (Latent): The latent associated with this output. + locations (TensorType["locations", 2]): Tensor of latent locations. + activations (TensorType["locations"]): Tensor of latent activations. tokens (TensorType["tokens"]): Tensor of all tokens. """ - feature: Feature + latent: Latent locations: TensorType["locations", 2] activations: TensorType["locations"] tokens: TensorType["tokens"] @@ -43,7 +43,7 @@ def __init__( self, path: str, module_path: str, - features: Optional[TensorType["features"]] = None, + latents: Optional[TensorType["latents"]] = None, min_examples: int = 120, ): """ @@ -52,12 +52,12 @@ def __init__( Args: path (str): Path to the tensor file. module_path (str): Path of the module. - features (Optional[TensorType["features"]]): Tensor of feature indices. + latents (Optional[TensorType["latents"]]): Tensor of latent indices. min_examples (int): Minimum number of examples required. Defaults to 120. """ self.tensor_path = path self.module_path = module_path - self.features = features + self.latents = latents self.min_examples = min_examples @@ -69,24 +69,24 @@ def __iter__(self): Yields: Union[BufferOutput, None]: BufferOutput if enough examples, None otherwise. 
""" - features, split_locations, split_activations, tokens = self.load() + latents, split_locations, split_activations, tokens = self.load() - for i in range(len(features)): - feature_locations = split_locations[i] - feature_activations = split_activations[i] - if len(feature_locations) < self.min_examples: + for i in range(len(latents)): + latent_locations = split_locations[i] + latent_activations = split_activations[i] + if len(latent_locations) < self.min_examples: yield None else: yield BufferOutput( - Feature(self.module_path, int(features[i].item())), - feature_locations, - feature_activations, + Latent(self.module_path, int(latents[i].item())), + latent_locations, + latent_activations, tokens ) def load(self): split_data = load_file(self.tensor_path) - first_feature = int(self.tensor_path.split("/")[-1].split("_")[0]) + first_latent = int(self.tensor_path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) locations = torch.tensor(split_data["locations"].astype(np.int64)) if "tokens" in split_data: @@ -94,22 +94,22 @@ def load(self): else: tokens = None - locations[:,2] = locations[:,2] + first_feature + locations[:,2] = locations[:,2] + first_latent - if self.features is not None: - wanted_locations = torch.isin(locations[:,2], self.features) + if self.latents is not None: + wanted_locations = torch.isin(locations[:,2], self.latents) locations = locations[wanted_locations] activations = activations[wanted_locations] indices = torch.argsort(locations[:,2], stable=True) activations = activations[indices] locations = locations[indices] - unique_features, counts = torch.unique_consecutive(locations[:,2], return_counts=True) - features = unique_features + unique_latents, counts = torch.unique_consecutive(locations[:,2], return_counts=True) + latents = unique_latents split_locations = torch.split(locations, counts.tolist()) split_activations = torch.split(activations, counts.tolist()) - return features, split_locations, split_activations, tokens + return latents, split_locations, split_activations, tokens @@ -121,36 +121,36 @@ def reset(self): self.locations = None -class FeatureDataset: +class LatentDataset: """ - Dataset which constructs TensorBuffers for each module and feature. + Dataset which constructs TensorBuffers for each module and latent. """ def __init__( self, raw_dir: str, - cfg: FeatureConfig, + cfg: LatentConfig, tokenizer: Optional[Callable] = None, modules: Optional[List[str]] = None, - features: Optional[Dict[str, Union[int, torch.Tensor]]] = None, + latents: Optional[Dict[str, Union[int, torch.Tensor]]] = None, ): """ - Initialize a FeatureDataset. + Initialize a LatentDataset. Args: - raw_dir (str): Directory containing raw feature data. - cfg (FeatureConfig): Configuration for feature processing. + raw_dir (str): Directory containing raw latent data. + cfg (LatentConfig): Configuration for latent processing. modules (Optional[List[str]]): List of module names to include. - features (Optional[Dict[str, Union[int, torch.Tensor]]]): Dictionary of features per module. + latents (Optional[Dict[str, Union[int, torch.Tensor]]]): Dictionary of latents per module. 
""" self.cfg = cfg self.buffers = [] - if features is None: + if latents is None: self._build(raw_dir, modules) else: # TODO fix type error - self._build_selected(raw_dir, modules, features) # type: ignore + self._build_selected(raw_dir, modules, latents) # type: ignore cache_config_dir = f"{raw_dir}/{modules[0]}/config.json" with open(cache_config_dir, "r") as f: @@ -184,15 +184,15 @@ def load_tokens(self): return self.tokens def _edges(self): - """Generate edge indices for feature splits.""" + """Generate edge indices for latent splits.""" return torch.linspace(0, self.cfg.width, steps=self.cfg.n_splits + 1).long() def _build(self, raw_dir: str, modules: Optional[List[str]] = None): """ - Build dataset buffers which load all cached features. + Build dataset buffers which load all cached latents. Args: - raw_dir (str): Directory containing raw feature data. + raw_dir (str): Directory containing raw latent data. modules (Optional[List[str]]): List of module names to include. """ edges = self._edges() @@ -207,29 +207,29 @@ def _build(self, raw_dir: str, modules: Optional[List[str]] = None): ) def _build_selected( - self, raw_dir: str, modules: List[str], features: Dict[str, Union[int, torch.Tensor]] + self, raw_dir: str, modules: List[str], latents: Dict[str, Union[int, torch.Tensor]] ): """ - Build a dataset buffer which loads only selected features. + Build a dataset buffer which loads only selected latents. Args: - raw_dir (str): Directory containing raw feature data. + raw_dir (str): Directory containing raw latent data. modules (List[str]): List of module names to include. - features (Dict[str, Union[int, torch.Tensor]]): Dictionary of features per module. + latents (Dict[str, Union[int, torch.Tensor]]): Dictionary of latents per module. """ edges = self._edges() for module in modules: - selected_features = features[module] - if isinstance(selected_features, int): - selected_features = torch.tensor([selected_features]) + selected_latents = latents[module] + if isinstance(selected_latents, int): + selected_latents = torch.tensor([selected_latents]) - bucketized = torch.bucketize(selected_features, edges, right=True) + bucketized = torch.bucketize(selected_latents, edges, right=True) unique_buckets = torch.unique(bucketized) for bucket in unique_buckets: mask = bucketized == bucket - _selected_features = selected_features[mask] + _selected_latents = selected_latents[mask] start, end = edges[bucket.item() - 1], edges[bucket.item()] @@ -240,7 +240,7 @@ def _build_selected( TensorBuffer( path, module, - _selected_features, + _selected_latents, min_examples=self.cfg.min_examples, ) ) @@ -257,19 +257,19 @@ def load( transform: Optional[Callable] = None, ): """ - Load and process feature records from the dataset. + Load and process latent records from the dataset. Args: collate (bool): Whether to collate all records into a single list. - constructor (Optional[Callable]): Function to construct feature records. - sampler (Optional[Callable]): Function to sample from feature records. - transform (Optional[Callable]): Function to transform feature records. + constructor (Optional[Callable]): Function to construct latent records. + sampler (Optional[Callable]): Function to sample from latent records. + transform (Optional[Callable]): Function to transform latent records. Returns: - Union[List[FeatureRecord], Generator]: Processed feature records. + Union[List[LatentRecord], Generator]: Processed latent records. 
""" def _process(buffer_output: BufferOutput): - record = FeatureRecord(buffer_output.feature) + record = LatentRecord(buffer_output.latent) if constructor is not None: constructor(record=record, buffer_output=buffer_output) @@ -292,14 +292,14 @@ def _worker(buffer): def _load(self, collate: bool, _worker: Callable): """ - Internal method to load feature records. + Internal method to load latent records. Args: collate (bool): Whether to collate all records into a single list. _worker (Callable): Function to process each buffer. Returns: - Union[List[FeatureRecord], Generator]: Processed feature records. + Union[List[LatentRecord], Generator]: Processed latent records. """ if collate: all_records = [] @@ -315,40 +315,40 @@ def reset(self): for buffer in self.buffers: buffer.reset() -class FeatureLoader: +class LatentLoader: """ - Loader class for processing feature records from a FeatureDataset. + Loader class for processing latent records from a LatentDataset. """ def __init__( self, - feature_dataset: 'FeatureDataset', + latent_dataset: 'LatentDataset', constructor: Optional[Callable] = None, sampler: Optional[Callable] = None, transform: Optional[Callable] = None ): """ - Initialize a FeatureLoader. + Initialize a LatentLoader. Args: - feature_dataset (FeatureDataset): The dataset to load features from. - constructor (Optional[Callable]): Function to construct feature records. - sampler (Optional[Callable]): Function to sample from feature records. - transform (Optional[Callable]): Function to transform feature records. + latent_dataset (LatentDataset): The dataset to load latents from. + constructor (Optional[Callable]): Function to construct latent records. + sampler (Optional[Callable]): Function to sample from latent records. + transform (Optional[Callable]): Function to transform latent records. """ - self.feature_dataset = feature_dataset + self.latent_dataset = latent_dataset self.constructor = constructor self.sampler = sampler self.transform = transform async def __aiter__(self): """ - Asynchronous iterator for processing feature records. + Asynchronous iterator for processing latent records. Yields: - FeatureRecord: Processed feature records. + LatentRecord: Processed latent records. """ - for buffer in self.feature_dataset.buffers: + for buffer in self.latent_dataset.buffers: async for record in self._aprocess_buffer(buffer): yield record @@ -360,26 +360,26 @@ async def _aprocess_buffer(self, buffer): buffer (TensorBuffer): Buffer to process. Yields: - Optional[FeatureRecord]: Processed feature record or None. + Optional[LatentRecord]: Processed latent record or None. """ for data in buffer: if data is not None: - record = await self._aprocess_feature(data) + record = await self._aprocess_latent(data) if record is not None: yield record await asyncio.sleep(0) # Allow other coroutines to run - async def _aprocess_feature(self, buffer_output: BufferOutput): + async def _aprocess_latent(self, buffer_output: BufferOutput): """ - Asynchronously process a single feature. + Asynchronously process a single latent. Args: - buffer_output (BufferOutput): Feature data to process. + buffer_output (BufferOutput): Latent data to process. Returns: - Optional[FeatureRecord]: Processed feature record or None. + Optional[LatentRecord]: Processed latent record or None. 
""" - record = FeatureRecord(buffer_output.feature) + record = LatentRecord(buffer_output.latent) if self.constructor is not None: self.constructor(record=record, buffer_output=buffer_output) if self.sampler is not None: @@ -390,12 +390,12 @@ async def _aprocess_feature(self, buffer_output: BufferOutput): def __iter__(self): """ - Synchronous iterator for processing feature records. + Synchronous iterator for processing latent records. Yields: - FeatureRecord: Processed feature records. + LatentRecord: Processed latent records. """ - for buffer in self.feature_dataset.buffers: + for buffer in self.latent_dataset.buffers: for record in self._process_buffer(buffer): yield record @@ -407,25 +407,25 @@ def _process_buffer(self, buffer): buffer (TensorBuffer): Buffer to process. Yields: - Optional[FeatureRecord]: Processed feature record or None. + Optional[LatentRecord]: Processed latent record or None. """ for data in buffer: if data is not None: - record = self._process_feature(data) + record = self._process_latent(data) if record is not None: yield record - def _process_feature(self, buffer_output: BufferOutput): + def _process_latent(self, buffer_output: BufferOutput): """ - Process a single feature synchronously. + Process a single latent synchronously. Args: - buffer_output (BufferOutput): Feature data to process. + buffer_output (BufferOutput): Latent data to process. Returns: - Optional[FeatureRecord]: Processed feature record or None. + Optional[LatentRecord]: Processed latent record or None. """ - record = FeatureRecord(buffer_output.feature) + record = LatentRecord(buffer_output.latent) if self.constructor is not None: self.constructor(record=record, buffer_output=buffer_output) if self.sampler is not None: diff --git a/delphi/features/neighbours.py b/delphi/latents/neighbours.py similarity index 84% rename from delphi/features/neighbours.py rename to delphi/latents/neighbours.py index 2b40a3f2..e40498bd 100644 --- a/delphi/features/neighbours.py +++ b/delphi/latents/neighbours.py @@ -10,16 +10,16 @@ class NeighbourCalculator: """ - Class to compute the neighbours of selected features using different methods: + Class to compute the neighbours of selected latents using different methods: - similarity: uses autoencoder weights - correlation: uses pre-activation records and autoencoder - - co-occurrence: uses feature dataset statistics + - co-occurrence: uses latent dataset statistics """ def __init__( self, - feature_dataset: Optional['FeatureDataset'] = None, + latent_dataset: Optional['LatentDataset'] = None, autoencoder: Optional["Autoencoder"] = None, pre_activation_record: Optional['PreActivationRecord'] = None, number_of_neighbours: int = 10, @@ -30,11 +30,11 @@ def __init__( Initialize a NeighbourCalculator. 
Args: - feature_dataset (Optional[FeatureDataset]): Dataset containing feature activations + latent_dataset (Optional[LatentDataset]): Dataset containing latent activations autoencoder (Optional[Autoencoder]): The trained autoencoder model pre_activation_record (Optional[PreActivationRecord]): Record of pre-activation values """ - self.feature_dataset = feature_dataset + self.latent_dataset = latent_dataset self.autoencoder = autoencoder self.pre_activation_record = pre_activation_record @@ -65,8 +65,8 @@ def _compute_neighbour_list(self, method: str) -> None: self.neighbour_cache[method] = self._compute_correlation_neighbours() elif method == 'co-occurrence': - if self.feature_dataset is None: - raise ValueError("Feature dataset is required for co-occurrence-based neighbours") + if self.latent_dataset is None: + raise ValueError("Latent dataset is required for co-occurrence-based neighbours") self.neighbour_cache[method] = self._compute_cooccurrence_neighbours() else: @@ -77,15 +77,15 @@ def _compute_similarity_neighbours(self) -> Dict[int, List[int]]: Compute neighbour lists based on weight similarity in the autoencoder. """ - # We use the encoder vectors to compute the similarity between features + # We use the encoder vectors to compute the similarity between latents encoder = self.autoencoder.encoder weight_matrix_normalized = encoder.weight / encoder.weight.norm(dim=1, keepdim=True) - # Compute the similarity between features + # Compute the similarity between latents similarity_matrix = weight_matrix_normalized.T @ weight_matrix_normalized - # Get the indices of the top k neighbours for each feature + # Get the indices of the top k neighbours for each latent top_k_indices = torch.topk(similarity_matrix, self.number_of_neighbours, dim=1).indices # Return the neighbour lists @@ -106,16 +106,16 @@ def _compute_correlation_neighbours(self) -> Dict[int, List[int]]: # load the encoder encoder_matrix = self.autoencoder.encoder.weight - # covariance between the features is u^T * covariance_matrix * u - covariance_between_features = encoder_matrix.T @ covariance_matrix @ encoder_matrix + # covariance between the latents is u^T * covariance_matrix * u + covariance_between_latents = encoder_matrix.T @ covariance_matrix @ encoder_matrix # the correlation is then the covariance devided by the product of the standard deviations product_of_std = torch.diag(covariance_matrix)**2 - correlation_matrix = covariance_between_features / product_of_std + correlation_matrix = covariance_between_latents / product_of_std - # get the indices of the top k neighbours for each feature + # get the indices of the top k neighbours for each latent top_k_indices = torch.topk(correlation_matrix, self.number_of_neighbours, dim=1).indices # return the neighbour lists @@ -124,29 +124,29 @@ def _compute_correlation_neighbours(self) -> Dict[int, List[int]]: def _compute_cooccurrence_neighbours(self) -> Dict[int, List[int]]: """ - Compute neighbour lists based on feature co-occurrence in the dataset. + Compute neighbour lists based on latent co-occurrence in the dataset. 
""" # To be implemented paths = [] - for buffer in self.feature_dataset.buffers: + for buffer in self.latent_dataset.buffers: paths.append(buffer.tensor_path) all_locations = [] all_activations = [] for path in paths: split_data = load_file(path) - first_feature = int(path.split("/")[-1].split("_")[0]) + first_latent = int(path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) locations = torch.tensor(split_data["locations"].astype(np.int64)) - locations[:,2] = locations[:,2] + first_feature + locations[:,2] = locations[:,2] + first_latent all_locations.append(locations) all_activations.append(activations) # concatenate the locations and activations locations = torch.cat(all_locations).cuda() activations = torch.cat(all_activations).cuda() - n_features = int(torch.max(locations[:,2])) + 1 + n_latents = int(torch.max(locations[:,2])) + 1 # 1. Get unique values of first 2 dims (i.e. absolute token index) and their counts # Trick is to use Cantor pairing function to have a bijective mapping between (batch_id, ctx_pos) and a unique 1D index @@ -162,7 +162,7 @@ def _compute_cooccurrence_neighbours(self) -> Dict[int, List[int]]: rows = cp.asarray(locations[:, 2]) cols = cp.asarray(locations_flat) data = cp.ones(len(rows)) - sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_features, n_tokens)) + sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_latents, n_tokens)) cooc_matrix = sparse_matrix @ sparse_matrix.T # Compute Jaccard similarity @@ -175,7 +175,7 @@ def compute_jaccard(cooc_matrix): # Compute Jaccard similarity matrix jaccard_matrix = compute_jaccard(cooc_matrix) - # get the indices of the top k neighbours for each feature + # get the indices of the top k neighbours for each latent top_k_indices = torch.topk(jaccard_matrix, self.number_of_neighbours, dim=1).indices # return the neighbour lists diff --git a/delphi/features/samplers.py b/delphi/latents/samplers.py similarity index 98% rename from delphi/features/samplers.py rename to delphi/latents/samplers.py index fcc0fcc6..2cd086bc 100644 --- a/delphi/features/samplers.py +++ b/delphi/latents/samplers.py @@ -4,7 +4,7 @@ from torchtyping import TensorType from ..config import ExperimentConfig -from .features import Example, FeatureRecord +from .latents import Example, LatentRecord from ..logger import logger @@ -134,7 +134,7 @@ def test( def sample( - record: FeatureRecord, + record: LatentRecord, cfg: ExperimentConfig, ): examples = record.examples diff --git a/delphi/features/stats.py b/delphi/latents/stats.py similarity index 72% rename from delphi/features/stats.py rename to delphi/latents/stats.py index 0d87235a..44c88b4d 100644 --- a/delphi/features/stats.py +++ b/delphi/latents/stats.py @@ -6,11 +6,11 @@ import torch import torch.nn.functional as F -from . import FeatureRecord +from . import LatentRecord def logits( - records: list[FeatureRecord], + records: list[LatentRecord], W_U: torch.nn.Module, W_dec: torch.nn.Module, k: int = 10, @@ -20,7 +20,7 @@ def logits( Compute the top k logits via direct logit attribution for a set of records. Args: - records (list[FeatureRecord]): A list of feature records. + records (list[LatentRecord]): A list of latent records. W_U (torch.nn.Module): The linear layer for the encoder. W_dec (torch.nn.Module): The linear layer for the decoder. k (int): The number of top logits to compute. @@ -30,9 +30,9 @@ def logits( decoded_top_logits (list[list[str]]): A list of top k logits for each record. 
""" - feature_indices = [record.feature.feature_index for record in records] + latent_indices = [record.latent.latent_index for record in records] - narrowed_logits = torch.matmul(W_U, W_dec[:, feature_indices]) + narrowed_logits = torch.matmul(W_U, W_dec[:, latent_indices]) top_logits = torch.topk(narrowed_logits, k, dim=0).indices @@ -48,7 +48,7 @@ def logits( def unigram( - record: FeatureRecord, k: int = 10, threshold: float = 0.0, negative_shift: int = 0 + record: LatentRecord, k: int = 10, threshold: float = 0.0, negative_shift: int = 0 ): avg_nonzero = [] top_tokens = [] @@ -73,8 +73,8 @@ def unigram( return -1, np.mean(avg_nonzero) -def cos(matrix, selected_features=[0]): - a = matrix[:, selected_features] +def cos(matrix, selected_latents=[0]): + a = matrix[:, selected_latents] b = matrix a = F.normalize(a, p=2, dim=0) @@ -85,25 +85,25 @@ def cos(matrix, selected_features=[0]): return cos_sim -def get_neighbors(submodule_dict, feature_filter, k=10): +def get_neighbors(submodule_dict, latent_filter, k=10): """ - Get the required features for neighbor scoring. + Get the required latents for neighbor scoring. Returns: neighbors_dict: Nested dictionary of modules -> neighbors -> indices, values - per_layer_features (dict): A dictionary of features per layer + per_layer_latents (dict): A dictionary of latents per layer """ neighbors_dict = defaultdict(dict) - per_layer_features = {} + per_layer_latents = {} for module_path, submodule in submodule_dict.items(): - selected_features = feature_filter.get(module_path, False) - if not selected_features: + selected_latents = latent_filter.get(module_path, False) + if not selected_latents: continue W_D = submodule.ae.autoencoder._module.decoder.weight - cos_sim = cos(W_D, selected_features=selected_features) + cos_sim = cos(W_D, selected_latents=selected_latents) top = torch.topk(cos_sim, k=k) top_indices = top.indices @@ -115,6 +115,6 @@ def get_neighbors(submodule_dict, feature_filter, k=10): "values": values.tolist()[1:], } - per_layer_features[module_path] = torch.unique(top_indices).tolist() + per_layer_latents[module_path] = torch.unique(top_indices).tolist() - return neighbors_dict, per_layer_features + return neighbors_dict, per_layer_latents diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index 0374659d..1279b63a 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -8,7 +8,7 @@ from transformers import PreTrainedTokenizer from ...clients.client import Client -from ...features import FeatureRecord +from ...latents import LatentRecord from ...logger import logger from ..scorer import Scorer, ScorerResult from .sample import ClassifierOutput, Sample @@ -48,7 +48,7 @@ def __init__( async def __call__( self, - record: FeatureRecord, + record: LatentRecord, ) -> list[ClassifierOutput]: samples = self._prepare(record) @@ -62,7 +62,7 @@ async def __call__( return ScorerResult(record=record, score=results) @abstractmethod - def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[list[Sample]]: pass @@ -205,5 +205,5 @@ def _batch(self, samples): for i in range(0, len(samples), self.n_examples_shown) ] - def call_sync(self, record: FeatureRecord) -> list[ClassifierOutput]: + def call_sync(self, record: LatentRecord) -> list[ClassifierOutput]: return asyncio.run(self.__call__(record)) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 
a1543720..8b0e017f 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -1,7 +1,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from ...clients.client import Client -from ...features import FeatureRecord +from ...latents import LatentRecord from .classifier import Classifier from .prompts.detection_prompt import prompt from .sample import Sample, examples_to_samples @@ -45,7 +45,7 @@ def __init__( self.prompt = prompt - def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[list[Sample]]: """ Prepare and shuffle a list of samples for classification. """ diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index b3583033..317cc318 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -5,8 +5,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from ...clients.client import Client -from ...features.features import Example -from ...features import FeatureRecord +from ...latents.latents import Example +from ...latents import LatentRecord from ..scorer import Scorer from .classifier import Classifier from .prompts.fuzz_prompt import prompt @@ -63,7 +63,7 @@ def mean_n_activations_ceil(self, examples: list[Example]): return ceil(avg) - def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[list[Sample]]: """ Prepare and shuffle a list of samples for classification. """ diff --git a/delphi/scorers/classifier/prompts/detection_prompt.py b/delphi/scorers/classifier/prompts/detection_prompt.py index fd62b092..be5699b1 100644 --- a/delphi/scorers/classifier/prompts/detection_prompt.py +++ b/delphi/scorers/classifier/prompts/detection_prompt.py @@ -1,14 +1,14 @@ DSCORER_SYSTEM_PROMPT = """You are an intelligent and meticulous linguistics researcher. -You will be given a certain feature of text, such as "male pronouns" or "text with negative sentiment". +You will be given a certain latent of text, such as "male pronouns" or "text with negative sentiment". -You will then be given several text examples. Your task is to determine which examples possess the feature. +You will then be given several text examples. Your task is to determine which examples possess the latent. For each example in turn, return 1 if the sentence is correctly labeled or 0 if the tokens are mislabeled. You must return your response in a valid Python list. Do not return anything else besides a Python list. """ # https://www.neuronpedia.org/gpt2-small/6-res-jb/6048 -DSCORER_EXAMPLE_ONE = """Feature explanation: Words related to American football positions, specifically the tight end position. +DSCORER_EXAMPLE_ONE = """Latent explanation: Words related to American football positions, specifically the tight end position. Test examples: @@ -22,7 +22,7 @@ DSCORER_RESPONSE_ONE = "[1,0,0,0,1]" # https://www.neuronpedia.org/gpt2-small/6-res-jb/9396 -DSCORER_EXAMPLE_TWO = """Feature explanation: The word "guys" in the phrase "you guys". +DSCORER_EXAMPLE_TWO = """Latent explanation: The word "guys" in the phrase "you guys". Test examples: @@ -36,7 +36,7 @@ DSCORER_RESPONSE_TWO = "[0,0,0,0,0]" # https://www.neuronpedia.org/gpt2-small/8-res-jb/12654 -DSCORER_EXAMPLE_THREE = """Feature explanation: "of" before words that start with a capital letter. +DSCORER_EXAMPLE_THREE = """Latent explanation: "of" before words that start with a capital letter. 
Test examples: @@ -49,7 +49,7 @@ DSCORER_RESPONSE_THREE = "[1,1,1,1,1]" -GENERATION_PROMPT = """Feature explanation: {explanation} +GENERATION_PROMPT = """Latent explanation: {explanation} Text examples: diff --git a/delphi/scorers/classifier/prompts/fuzz_prompt.py b/delphi/scorers/classifier/prompts/fuzz_prompt.py index 00198911..13207654 100644 --- a/delphi/scorers/classifier/prompts/fuzz_prompt.py +++ b/delphi/scorers/classifier/prompts/fuzz_prompt.py @@ -1,15 +1,15 @@ # %% DSCORER_SYSTEM_PROMPT = """You are an intelligent and meticulous linguistics researcher. -You will be given a certain feature of text, such as "male pronouns" or "text with negative sentiment". You will be given a few examples of text that contain this feature. Portions of the sentence which strongly represent this feature are between tokens << and >>. +You will be given a certain latent of text, such as "male pronouns" or "text with negative sentiment". You will be given a few examples of text that contain this latent. Portions of the sentence which strongly represent this latent are between tokens << and >>. -Some examples might be mislabeled. Your task is to determine if every single token within << and >> is correctly labeled. Consider that all provided examples could be correct, none of the examples could be correct, or a mix. An example is only correct if every marked token is representative of the feature +Some examples might be mislabeled. Your task is to determine if every single token within << and >> is correctly labeled. Consider that all provided examples could be correct, none of the examples could be correct, or a mix. An example is only correct if every marked token is representative of the latent For each example in turn, return 1 if the sentence is correctly labeled or 0 if the tokens are mislabeled. You must return your response in a valid Python list. Do not return anything else besides a Python list. """ # https://www.neuronpedia.org/gpt2-small/6-res-jb/6048 -DSCORER_EXAMPLE_ONE = """Feature explanation: Words related to American football positions, specifically the tight end position. +DSCORER_EXAMPLE_ONE = """Latent explanation: Words related to American football positions, specifically the tight end position. Test examples: @@ -31,7 +31,7 @@ DSCORER_RESPONSE_ONE = "[1,0,0,1,1]" # https://www.neuronpedia.org/gpt2-small/6-res-jb/9396 -DSCORER_EXAMPLE_TWO = """Feature explanation: The word "guys" in the phrase "you guys". +DSCORER_EXAMPLE_TWO = """Latent explanation: The word "guys" in the phrase "you guys". Test examples: @@ -53,7 +53,7 @@ DSCORER_RESPONSE_TWO = "[0,0,0,0,0]" # https://www.neuronpedia.org/gpt2-small/8-res-jb/12654 -DSCORER_EXAMPLE_THREE = """Feature explanation: "of" before words that start with a capital letter. +DSCORER_EXAMPLE_THREE = """Latent explanation: "of" before words that start with a capital letter. 
Test examples: @@ -74,7 +74,7 @@ DSCORER_RESPONSE_THREE = "[1,1,1,1,1]" -GENERATION_PROMPT = """Feature explanation: {explanation} +GENERATION_PROMPT = """Latent explanation: {explanation} Text examples: diff --git a/delphi/scorers/embedding/embedding.py b/delphi/scorers/embedding/embedding.py index 64011e89..95651660 100644 --- a/delphi/scorers/embedding/embedding.py +++ b/delphi/scorers/embedding/embedding.py @@ -12,7 +12,7 @@ from sentence_transformers import SentenceTransformer from transformers import PreTrainedTokenizer from ...clients.client import Client -from ...features import Example, FeatureRecord +from ...latents import Example, LatentRecord from ..scorer import Scorer, ScorerResult @@ -53,7 +53,7 @@ def __init__( async def __call__( self, - record: FeatureRecord, + record: LatentRecord, ) -> list[EmbeddingOutput]: samples = self._prepare(record) @@ -65,11 +65,11 @@ async def __call__( return ScorerResult(record=record, score=results) - def call_sync(self, record: FeatureRecord) -> list[EmbeddingOutput]: + def call_sync(self, record: LatentRecord) -> list[EmbeddingOutput]: return asyncio.run(self.__call__(record)) - def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[list[Sample]]: """ Prepare and shuffle a list of samples for classification. """ diff --git a/delphi/scorers/scorer.py b/delphi/scorers/scorer.py index 3b34f05a..fa5a0ae5 100644 --- a/delphi/scorers/scorer.py +++ b/delphi/scorers/scorer.py @@ -1,18 +1,18 @@ from abc import ABC, abstractmethod from typing import Any, NamedTuple -from ..features.features import FeatureRecord +from ..latents.latents import LatentRecord class ScorerResult(NamedTuple): - record: FeatureRecord - """Feature record passed through.""" + record: LatentRecord + """Latent record passed through.""" score: Any - """Generated score for feature.""" + """Generated score for latent.""" class Scorer(ABC): @abstractmethod - def __call__(self, record: FeatureRecord) -> ScorerResult: + def __call__(self, record: LatentRecord) -> ScorerResult: pass diff --git a/delphi/scorers/simulator/oai_simulator.py b/delphi/scorers/simulator/oai_simulator.py index 1c252ee1..a9c71465 100644 --- a/delphi/scorers/simulator/oai_simulator.py +++ b/delphi/scorers/simulator/oai_simulator.py @@ -1,6 +1,6 @@ from typing import List -from ...features import Example +from ...latents import Example from .oai_autointerp import ( ActivationRecord, ExplanationNeuronSimulator, diff --git a/delphi/scorers/surprisal/prompts.py b/delphi/scorers/surprisal/prompts.py index 7a229ff8..81dccfbf 100644 --- a/delphi/scorers/surprisal/prompts.py +++ b/delphi/scorers/surprisal/prompts.py @@ -1,8 +1,8 @@ -# BASEPROMPT= ("The following is a description of a certain feature of text and a list of examples that contain the feature.\n" +# BASEPROMPT= ("The following is a description of a certain latent of text and a list of examples that contain the latent.\n" # "Description: \n" # "References to the Antichrist, the Apocalypse and conspiracy theories related to those topics. \n" # "Sentences:\n" -# " by which he distinguishes Antichrist is, that he would rob God of his honour and take it to himself, he gives the leading feature which we ought \n" +# " by which he distinguishes Antichrist is, that he would rob God of his honour and take it to himself, he gives the leading latent which we ought \n" # " would be destroyed. The worlds economy would likely collapse as a result and could usher in a one world government movement. 
I wrote a small 6 page \n" # "3 begins. And the rise of Antichrist. Get ready with God that you would be found worthy to escape the horrors that will fall on earth. \n" # "Description: \n" @@ -18,7 +18,7 @@ # " refurbishing the Bank’s branches.\nBIP reached 400 thousand users in one year\nThe use of BIP has already doubled\nThe \n" # " the Federal Deposit Insurance Corp.\nMr. Ranzini would not say where University plans to open loan offices in the future.Beyond Michigan and the \n" # ) -# BASEPROMPT= ("The following is a description of a certain feature of text and a list of examples that contain the feature.\n" +# BASEPROMPT= ("The following is a description of a certain latent of text and a list of examples that contain the latent.\n" # "Description: \n" # "References to the Antichrist, the Apocalypse and conspiracy theories related to those topics. \n" # "Sentences:\n" @@ -38,11 +38,11 @@ # " refurbishing the Bank’s branches. \n" # " the Federal Deposit Insurance Corp. \n" # ) -BASEPROMPT= ("The following is a description of a certain feature of text and a list of examples that contain the feature.\n" +BASEPROMPT= ("The following is a description of a certain latent of text and a list of examples that contain the latent.\n" "Description: \n" "References to the Antichrist, the Apocalypse and conspiracy theories related to those topics. \n" "Sentences: \n" - " by which he distinguishes Antichrist is, that he would rob God of his honour and take it to himself, he gives the leading feature which we ought \n" + " by which he distinguishes Antichrist is, that he would rob God of his honour and take it to himself, he gives the leading latent which we ought \n" "3 begins. And the rise of Antichrist. Get ready with \n" " would be destroyed. The worlds economy would likely collapse as a result and could usher in a one world government movement. I wrote a small 6 page \n" "Description: \n" diff --git a/delphi/scorers/surprisal/surprisal.py b/delphi/scorers/surprisal/surprisal.py index bbe62bde..23b53aa7 100644 --- a/delphi/scorers/surprisal/surprisal.py +++ b/delphi/scorers/surprisal/surprisal.py @@ -13,7 +13,7 @@ from transformers import PreTrainedTokenizer from ...clients.client import Client -from ...features import Example, FeatureRecord +from ...latents import Example, LatentRecord from ..scorer import Scorer, ScorerResult from .prompts import BASEPROMPT as base_prompt @@ -60,7 +60,7 @@ def __init__( async def __call__( self, - record: FeatureRecord, + record: LatentRecord, ) -> list[SurprisalOutput]: samples = self._prepare(record) @@ -72,7 +72,7 @@ async def __call__( return ScorerResult(record=record, score=results) - def _prepare(self, record: FeatureRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[list[Sample]]: """ Prepare and shuffle a list of samples for classification. 
""" diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py index b44516e8..9e56e982 100644 --- a/delphi/tests/e2e.py +++ b/delphi/tests/e2e.py @@ -8,7 +8,7 @@ import numpy as np -from delphi.config import ExperimentConfig, FeatureConfig, CacheConfig +from delphi.config import ExperimentConfig, LatentConfig, CacheConfig from delphi.__main__ import run, RunConfig def parse_score_file(file_path): @@ -92,7 +92,7 @@ def build_df(path: Path, target_modules: list[str], range: Tensor | None): ] df_data = { col: [] - for col in ["file_name", "score_type", "feature_idx", "module"] + metrics_cols + for col in ["file_name", "score_type", "latent_idx", "module"] + metrics_cols } # Get subdirectories in the scores path @@ -104,16 +104,16 @@ def build_df(path: Path, target_modules: list[str], range: Tensor | None): for score_file in list(score_type_path.glob(f"*{module}*")) + list( score_type_path.glob(f".*{module}*") ): - feature_idx = int(score_file.stem.split("feature")[-1]) - if range is not None and feature_idx not in range: + latent_idx = int(score_file.stem.split("latent")[-1]) + if range is not None and latent_idx not in range: continue df = parse_score_file(score_file) - # Calculate the accuracy and cross entropy loss for this feature + # Calculate the accuracy and cross entropy loss for this latent df_data["file_name"].append(score_file.stem) df_data["score_type"].append(score_type) - df_data["feature_idx"].append(feature_idx) + df_data["latent_idx"].append(latent_idx) df_data["module"].append(module) for col in metrics_cols: df_data[col].append(df.loc[0, col]) @@ -138,10 +138,10 @@ async def test(): n_examples_train=40, n_examples_test=50, ) - feature_cfg = FeatureConfig( + latent_cfg = LatentConfig( width=32_768, - min_examples=200, # The minimum number of examples to consider for the feature to be explained - max_examples=10_000, # The maximum number of examples a feature may activate on before being excluded from explanation + min_examples=200, # The minimum number of examples to consider for the latent to be explained + max_examples=10_000, # The maximum number of examples a latent may activate on before being excluded from explanation ) run_cfg = RunConfig( name='test', @@ -151,23 +151,23 @@ async def test(): explainer_model="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", hookpoints=["layers.3"], explainer_model_max_len=4208, - max_features=100, + max_latents=100, seed=22, num_gpus=torch.cuda.device_count(), filter_bos=True ) start_time = time.time() - await run(experiment_cfg, feature_cfg, cache_cfg, run_cfg) + await run(experiment_cfg, latent_cfg, cache_cfg, run_cfg) end_time = time.time() print(f"Time taken: {end_time - start_time} seconds") scores_path = Path("results") / run_cfg.name / "scores" - feature_range = torch.arange(run_cfg.max_features) if run_cfg.max_features else None + latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None hookpoints, *_ = load_artifacts(run_cfg) - df = build_df(scores_path, hookpoints, feature_range) + df = build_df(scores_path, hookpoints, latent_range) # Performs better than random guessing for score_type in df["score_type"].unique(): diff --git a/examples/caching_activations.ipynb b/examples/caching_activations.ipynb index 2f2a2bbd..737a2261 100644 --- a/examples/caching_activations.ipynb +++ b/examples/caching_activations.ipynb @@ -89,13 +89,13 @@ "outputs": [], "source": [ "from delphi.config import CacheConfig\n", - "from delphi.features import FeatureCache\n", + "from delphi.latents import LatentCache\n", "from 
delphi.utils import load_tokenized_data\n" ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": {}, "outputs": [ { @@ -2299,7 +2299,7 @@ "\n", "\n", "\n", - "cache = FeatureCache(\n", + "cache = LatentCache(\n", " model,\n", " submodule_dict,\n", " batch_size = cfg.batch_size,\n", diff --git a/examples/example_script.py b/examples/example_script.py index 98c0dcd6..5cba2f9e 100644 --- a/examples/example_script.py +++ b/examples/example_script.py @@ -9,34 +9,34 @@ from simple_parsing import ArgumentParser from delphi.clients import Offline -from delphi.config import ExperimentConfig, FeatureConfig +from delphi.config import ExperimentConfig, LatentConfig from delphi.explainers import DefaultExplainer -from delphi.features import ( - FeatureDataset, - FeatureLoader +from delphi.latents import ( + LatentDataset, + LatentLoader ) -from delphi.features.constructors import default_constructor -from delphi.features.samplers import sample +from delphi.latents.constructors import default_constructor +from delphi.latents.samplers import sample from delphi.pipeline import Pipe,Pipeline, process_wrapper from delphi.scorers import FuzzingScorer, DetectionScorer -# run with python examples/example_script.py --model gemma/16k --module .model.layers.10 --features 100 --experiment_name test +# run with python examples/example_script.py --model gemma/16k --module .model.layers.10 --latents 100 --experiment_name test def main(args): module = args.module - feature_cfg = args.feature_options + latent_cfg = args.latent_options experiment_cfg = args.experiment_options shown_examples = args.shown_examples - n_features = args.features - start_feature = args.start_feature + n_latents = args.latents + start_latent = args.start_latent sae_model = args.model - feature_dict = {f"{module}": torch.arange(start_feature,start_feature+n_features)} - dataset = FeatureDataset( - raw_dir="raw_features", - cfg=feature_cfg, + latent_dict = {f"{module}": torch.arange(start_latent,start_latent+n_latents)} + dataset = LatentDataset( + raw_dir="raw_latents", + cfg=latent_cfg, modules=[module], - features=feature_dict, + latents=latent_dict, ) @@ -45,10 +45,10 @@ def main(args): token_loader=lambda: dataset.load_tokens(), n_random=experiment_cfg.n_random, ctx_len=experiment_cfg.example_ctx_len, - max_examples=feature_cfg.max_examples + max_examples=latent_cfg.max_examples ) sampler=partial(sample,cfg=experiment_cfg) - loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler) + loader = LatentLoader(dataset, constructor=constructor, sampler=sampler) ### Load client ### client = Offline("hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4",max_memory=0.8,max_model_len=5120) @@ -56,7 +56,7 @@ def main(args): ### Build Explainer pipe ### def explainer_postprocess(result): - with open(f"results/explanations/{sae_model}/{experiment_name}/{result.record.feature}.txt", "wb") as f: + with open(f"results/explanations/{sae_model}/{experiment_name}/{result.record.latent}.txt", "wb") as f: f.write(orjson.dumps(result.explanation)) return result @@ -88,7 +88,7 @@ def scorer_preprocess(result): def scorer_postprocess(result, score_dir): record = result.record - with open(f"results/scores/{sae_model}/{experiment_name}/{score_dir}/{record.feature}.txt", "wb") as f: + with open(f"results/scores/{sae_model}/{experiment_name}/{score_dir}/{record.latent}.txt", "wb") as f: f.write(orjson.dumps(result.score)) @@ -133,10 +133,10 @@ def scorer_postprocess(result, score_dir): 
parser.add_argument("--shown_examples", type=int, default=5) parser.add_argument("--model", type=str, default="gemma/16k") parser.add_argument("--module", type=str, default=".model.layers.10") - parser.add_argument("--features", type=int, default=100) + parser.add_argument("--latents", type=int, default=100) parser.add_argument("--experiment_name", type=str, default="default") parser.add_arguments(ExperimentConfig, dest="experiment_options") - parser.add_arguments(FeatureConfig, dest="feature_options") + parser.add_arguments(LatentConfig, dest="latent_options") args = parser.parse_args() experiment_name = args.experiment_name diff --git a/examples/example_server.py b/examples/example_server.py index ad812dcf..babafbbc 100644 --- a/examples/example_server.py +++ b/examples/example_server.py @@ -7,17 +7,17 @@ with open("/mnt/ssd-1/gpaulo/SAE-Zoology/extras/neuronpedia/formatted_contexts/activating_contexts_16k/mlp/0/layer_0_contexts_chunk_1.json", "r") as f: activation_data = json.load(f) # Load the explanation data -with open("/mnt/ssd-1/gpaulo/SAE-Zoology/extras/explanations_16k/model.layers.0.post_feedforward_layernorm_feature.json", "r") as f: +with open("/mnt/ssd-1/gpaulo/SAE-Zoology/extras/explanations_16k/model.layers.0.post_feedforward_layernorm_latent.json", "r") as f: explanation_data = json.load(f) tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-9b") -actual_data = activation_data["features"][10] +actual_data = activation_data["latents"][10] activations = actual_data["activations"] for activation in activations: activation["tokens"] = tokenizer.batch_decode(activation["tokens"]) # If you have the tokens already decoded, you can skip this step -feature_index = actual_data["feature_index"] -print(feature_index) -explanation = explanation_data[str(feature_index)] +latent_index = actual_data["latent_index"] +print(latent_index) +explanation = explanation_data[str(latent_index)] # Server URL BASE_URL = "http://localhost:5000" diff --git a/examples/generate_explanations.ipynb b/examples/generate_explanations.ipynb index c6dfee9d..5c40b19c 100644 --- a/examples/generate_explanations.ipynb +++ b/examples/generate_explanations.ipynb @@ -28,11 +28,11 @@ "import torch\n", "\n", "from delphi.clients import OpenRouter\n", - "from delphi.config import ExperimentConfig, FeatureConfig\n", + "from delphi.config import ExperimentConfig, LatentConfig\n", "from delphi.explainers import DefaultExplainer\n", - "from delphi.features import FeatureDataset, FeatureLoader\n", - "from delphi.features.constructors import default_constructor\n", - "from delphi.features.samplers import sample\n", + "from delphi.latents import LatentDataset, LatentLoader\n", + "from delphi.latents.constructors import default_constructor\n", + "from delphi.latents.samplers import sample\n", "from delphi.pipeline import Pipeline, process_wrapper\n", "\n", "API_KEY = os.getenv(\"OPENROUTER_API_KEY\")\n" @@ -44,9 +44,9 @@ "metadata": {}, "outputs": [], "source": [ - "feature_cfg = FeatureConfig(\n", + "latent_cfg = LatentConfig(\n", " width=131072, # The number of latents of your SAE\n", - " min_examples=200, # The minimum number of examples to consider for the feature to be explained\n", + " min_examples=200, # The minimum number of examples to consider for the latent to be explained\n", " max_examples=10000, # The maximum number of examples to be sampled from\n", " n_splits=5 # How many splits was the cache split into\n", ")\n" @@ -59,13 +59,13 @@ "outputs": [], "source": [ "module = \".model.layers.10\" # The layer to 
explain\n", - "feature_dict = {module: torch.arange(0,5)} # The what latents to explain\n", + "latent_dict = {module: torch.arange(0,5)} # The what latents to explain\n", "\n", - "dataset = FeatureDataset(\n", + "dataset = LatentDataset(\n", " raw_dir=\"latents\", # The folder where the cache is stored\n", - " cfg=feature_cfg,\n", + " cfg=latent_cfg,\n", " modules=[module],\n", - " features=feature_dict,\n", + " latents=latent_dict,\n", ")\n" ] }, @@ -113,10 +113,10 @@ " token_loader=None,\n", " n_not_active=experiment_cfg.n_non_activating, \n", " ctx_len=experiment_cfg.example_ctx_len, \n", - " max_examples=feature_cfg.max_examples\n", + " max_examples=latent_cfg.max_examples\n", " )\n", "sampler=partial(sample,cfg=experiment_cfg)\n", - "loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler)\n", + "loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n", " " ] }, @@ -137,7 +137,7 @@ "\n", "# The function that saves the explanations\n", "def explainer_postprocess(result):\n", - " with open(f\"results/explanations/{result.record.feature}.txt\", \"wb\") as f:\n", + " with open(f\"results/explanations/{result.record.latent}.txt\", \"wb\") as f:\n", " f.write(orjson.dumps(result.explanation))\n", " del result\n", " return None\n", diff --git a/examples/latent_contexts.ipynb b/examples/latent_contexts.ipynb index 53650437..4fc6c228 100644 --- a/examples/latent_contexts.ipynb +++ b/examples/latent_contexts.ipynb @@ -22,10 +22,10 @@ "from IPython.display import HTML, clear_output, display\n", "from nnsight import LanguageModel\n", "\n", - "from delphi.config import ExperimentConfig, FeatureConfig\n", - "from delphi.features import FeatureDataset, FeatureLoader\n", - "from delphi.features.constructors import default_constructor\n", - "from delphi.features.samplers import sample\n" + "from delphi.config import ExperimentConfig, LatentConfig\n", + "from delphi.latents import LatentDataset, LatentLoader\n", + "from delphi.latents.constructors import default_constructor\n", + "from delphi.latents.samplers import sample\n" ] }, { @@ -142,18 +142,18 @@ " \n", "def load_examples(layer_name=None,sae_size=\"131k\"):\n", " \n", - " feature_cfg = FeatureConfig(width=131072)\n", - " raw_dir = f\"raw_features/gemma/{sae_size}\"\n", + " latent_cfg = LatentConfig(width=131072)\n", + " raw_dir = f\"raw_latents/gemma/{sae_size}\"\n", " experiment_cfg = ExperimentConfig(n_random=0,train_type=\"quantiles\",n_examples_train=40,n_quantiles=10,example_ctx_len=32)\n", "\n", " #module = f\".model.layers.{layer_name}.post_feedforward_layernorm\"\n", " module = f\".model.layers.{layer_name}\"\n", " \n", - " dataset = FeatureDataset(\n", + " dataset = LatentDataset(\n", " raw_dir=raw_dir,\n", - " cfg=feature_cfg,\n", + " cfg=latent_cfg,\n", " modules=[module],\n", - " features={module:torch.tensor([3254, 6517, 8812,1318, 4834, 7605,])},\n", + " latents={module:torch.tensor([3254, 6517, 8812,1318, 4834, 7605,])},\n", " )\n", " constructor=partial(\n", " default_constructor,\n", @@ -162,14 +162,14 @@ " max_examples=10000\n", " )\n", " sampler=partial(sample,cfg=experiment_cfg)\n", - " loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler)\n", + " loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n", "\n", " all_examples = {}\n", " maximum_activations = {}\n", " for record in loader:\n", " train_examples = record.train\n", - " all_examples[str(record.feature)] = train_examples\n", - " maximum_activations[str(record.feature)] = 
record.max_activation\n", + " all_examples[str(record.latent)] = train_examples\n", + " maximum_activations[str(record.latent)] = record.max_activation\n", "\n", " return all_examples, maximum_activations\n", "\n", @@ -285,7 +285,7 @@ "name": "stdout", "output_type": "stream", "text": [ - ".model.layers.10_feature6517\n", + ".model.layers.10_latent6517\n", "40\n", "40\n" ] diff --git a/examples/latents/.model.layers.10/config.json b/examples/latents/.model.layers.10/config.json new file mode 100644 index 00000000..505b423f --- /dev/null +++ b/examples/latents/.model.layers.10/config.json @@ -0,0 +1 @@ +{"dataset_repo": "EleutherAI/rpj-v2-sample", "dataset_split": "train[:1%]", "dataset_name": "", "dataset_row": "raw_content", "batch_size": 8, "ctx_len": 256, "n_tokens": 1000000, "n_splits": 5, "model_name": "google/gemma-2-9b"} \ No newline at end of file diff --git a/examples/score_explanations.ipynb b/examples/score_explanations.ipynb index b3787132..1d789b8d 100644 --- a/examples/score_explanations.ipynb +++ b/examples/score_explanations.ipynb @@ -26,14 +26,14 @@ "import orjson\n", "import asyncio\n", "from delphi.clients import OpenRouter\n", - "from delphi.config import ExperimentConfig, FeatureConfig\n", + "from delphi.config import ExperimentConfig, LatentConfig\n", "from delphi.explainers import explanation_loader\n", - "from delphi.features import (\n", - " FeatureDataset,\n", - " FeatureLoader\n", + "from delphi.latents import (\n", + " LatentDataset,\n", + " LatentLoader\n", ")\n", - "from delphi.features.constructors import default_constructor\n", - "from delphi.features.samplers import sample\n", + "from delphi.latents.constructors import default_constructor\n", + "from delphi.latents.samplers import sample\n", "from delphi.pipeline import Pipeline, process_wrapper\n", "from delphi.scorers import FuzzingScorer\n", "\n", @@ -48,9 +48,9 @@ "metadata": {}, "outputs": [], "source": [ - "feature_cfg = FeatureConfig(\n", + "latent_cfg = LatentConfig(\n", " width=131072, # The number of latents of your SAE\n", - " min_examples=200, # The minimum number of examples to consider for the feature to be explained\n", + " min_examples=200, # The minimum number of examples to consider for the latent to be explained\n", " max_examples=10000, # The maximum number of examples to be sampled from\n", " n_splits=5 # How many splits was the cache split into\n", ")\n" @@ -63,13 +63,13 @@ "outputs": [], "source": [ "module = \".model.layers.10\" # The layer to score\n", - "feature_dict = {module: torch.arange(0,3)} # The what latents to score\n", + "latent_dict = {module: torch.arange(0,3)} # The what latents to score\n", "\n", - "dataset = FeatureDataset(\n", + "dataset = LatentDataset(\n", " raw_dir=\"latents\", # The folder where the cache is stored\n", - " cfg=feature_cfg,\n", + " cfg=latent_cfg,\n", " modules=[module],\n", - " features=feature_dict,\n", + " latents=latent_dict,\n", ")\n" ] }, @@ -117,10 +117,10 @@ " token_loader=None,\n", " n_not_active=experiment_cfg.n_non_activating, \n", " ctx_len=experiment_cfg.example_ctx_len, \n", - " max_examples=feature_cfg.max_examples\n", + " max_examples=latent_cfg.max_examples\n", " )\n", "sampler=partial(sample,cfg=experiment_cfg)\n", - "loader = FeatureLoader(dataset, constructor=constructor, sampler=sampler)\n", + "loader = LatentLoader(dataset, constructor=constructor, sampler=sampler)\n", " " ] }, @@ -153,7 +153,7 @@ "\n", "# Saves the score to a file\n", "def scorer_postprocess(result, score_dir):\n", - " with 
open(f\"results/scores/{result.record.feature}.txt\", \"wb\") as f:\n", + " with open(f\"results/scores/{result.record.latent}.txt\", \"wb\") as f:\n", " f.write(orjson.dumps(result.score))\n", "\n", "\n", diff --git a/examples/server.py b/examples/server.py index b02c54d3..43655326 100644 --- a/examples/server.py +++ b/examples/server.py @@ -10,7 +10,7 @@ from delphi.explainers import DefaultExplainer -from delphi.features import FeatureRecord, Feature, Example +from delphi.latents import LatentRecord, Latent, Example from delphi.scorers import FuzzingScorer, DetectionScorer,EmbeddingScorer app = Flask(__name__) @@ -32,7 +32,7 @@ def calculate_balanced_accuracy(dataframe): balanced_accuracy = (recall+tn/(tn+fp))/2 return balanced_accuracy -def per_feature_scores_fuzz_detection(score_data): +def per_latent_scores_fuzz_detection(score_data): data = [d for d in score_data if d.prediction != -1] @@ -41,7 +41,7 @@ def per_feature_scores_fuzz_detection(score_data): balanced_accuracy = calculate_balanced_accuracy(data_df) return balanced_accuracy -def per_feature_scores_embedding(score_data): +def per_latent_scores_embedding(score_data): data_df = pd.DataFrame(score_data) data_df["ground_truth"] = data_df["distance"]>0 print(data_df) @@ -77,18 +77,18 @@ def generate_explanation(): if 'model' not in data: return jsonify({"error": "Missing model"}), 400 try: - feature = Feature(f"feature", 0) + latent = Latent(f"latent", 0) examples = [] for activation in data['activations']: example = Example(activation['tokens'], torch.tensor(activation['values'])) examples.append(example) - feature_record = FeatureRecord(feature) - feature_record.train = examples + latent_record = LatentRecord(latent) + latent_record.train = examples client = OpenRouter(api_key=data['api_key'], model=data['model']) explainer = DefaultExplainer(client, tokenizer=None, threshold=0.6) - result = explainer.call_sync(feature_record) # Use call_sync instead of __call__ + result = explainer.call_sync(latent_record) # Use call_sync instead of __call__ return jsonify({"explanation": result.explanation}), 200 @@ -124,7 +124,7 @@ def generate_score_fuzz_detection(): if 'type' not in data: return jsonify({"error": "Missing type"}), 400 try: - feature = Feature(f"feature", 0) + latent = Latent(f"latent", 0) activating_examples = [] non_activating_examples = [] for activation in data['activations']: @@ -133,11 +133,11 @@ def generate_score_fuzz_detection(): activating_examples.append(example) else: non_activating_examples.append(example) - feature_record = FeatureRecord(feature) - feature_record.test = [activating_examples] - feature_record.extra_examples = non_activating_examples - feature_record.not_active = non_activating_examples - feature_record.explanation = data['explanation'] + latent_record = LatentRecord(latent) + latent_record.test = [activating_examples] + latent_record.extra_examples = non_activating_examples + latent_record.not_active = non_activating_examples + latent_record.explanation = data['explanation'] client = OpenRouter(api_key=data['api_key'], model=data['model']) if data['type'] == 'fuzz': @@ -146,9 +146,9 @@ def generate_score_fuzz_detection(): elif data['type'] == 'detection': # We can't use log_prob as it's not supported by OpenRouter scorer = DetectionScorer(client, tokenizer=None, n_examples_shown=5,verbose=False,log_prob=False) - result = scorer.call_sync(feature_record) # Use call_sync instead of __call__ + result = scorer.call_sync(latent_record) # Use call_sync instead of __call__ #print(result.score) - 
score = per_feature_scores_fuzz_detection(result.score) + score = per_latent_scores_fuzz_detection(result.score) return jsonify({"score": score,"breakdown": result.score}), 200 except Exception as e: @@ -171,7 +171,7 @@ def generate_score_embedding(): if 'explanation' not in data: return jsonify({"error": "Missing explanation"}), 400 try: - feature = Feature(f"feature", 0) + latent = Latent(f"latent", 0) activating_examples = [] non_activating_examples = [] for activation in data['activations']: @@ -180,15 +180,15 @@ def generate_score_embedding(): activating_examples.append(example) else: non_activating_examples.append(example) - feature_record = FeatureRecord(feature) - feature_record.test = [activating_examples] - feature_record.extra_examples = non_activating_examples - feature_record.negative_examples = non_activating_examples - feature_record.explanation = data['explanation'] + latent_record = LatentRecord(latent) + latent_record.test = [activating_examples] + latent_record.extra_examples = non_activating_examples + latent_record.negative_examples = non_activating_examples + latent_record.explanation = data['explanation'] scorer = EmbeddingScorer(model) - result = scorer.call_sync(feature_record) # Use call_sync instead of __call__ + result = scorer.call_sync(latent_record) # Use call_sync instead of __call__ #print(result.score) - score = per_feature_scores_embedding(result.score) + score = per_latent_scores_embedding(result.score) return jsonify({"score": score,"breakdown": result.score}), 200 except Exception as e: diff --git a/experiments/output_features/analysis.ipynb b/experiments/output_features/analysis.ipynb index 59bcb958..59148a77 100644 --- a/experiments/output_features/analysis.ipynb +++ b/experiments/output_features/analysis.ipynb @@ -537,7 +537,7 @@ " with open(path, \"r\") as f:\n", " data = json.load(f)\n", "\n", - " input_df = pd.concat([input_df, pd.DataFrame([{\"feat_idx\": int(k.split(\"feature\")[-1]), \"feat\": k, \"score\": data[k][\"score\"], \"explanations\": data[k][\"explanations\"]} for k in data.keys()])])\n", + " input_df = pd.concat([input_df, pd.DataFrame([{\"feat_idx\": int(k.split(\"latent\")[-1]), \"feat\": k, \"score\": data[k][\"score\"], \"explanations\": data[k][\"explanations\"]} for k in data.keys()])])\n", " cfg = base_cfg.copy()\n", " cfg[\"feat_layer\"] = fl\n", " output_df = pd.concat([output_df, load_result(cfg, convert_to_single=convert_to_single)])\n", @@ -677,7 +677,7 @@ " with open(path, \"r\") as f:\n", " data = json.load(f)\n", "\n", - " input_df = pd.DataFrame([{\"feat_idx\": int(k.split(\"feature\")[-1]), \"feat\": k, \"score\": data[k][\"score\"], \"explanations\": data[k][\"explanations\"]} for k in data.keys()])\n", + " input_df = pd.DataFrame([{\"feat_idx\": int(k.split(\"latent\")[-1]), \"feat\": k, \"score\": data[k][\"score\"], \"explanations\": data[k][\"explanations\"]} for k in data.keys()])\n", " joined_df = input_df.merge(output_df, on=\"feat_idx\")\n", " linr = linregress(joined_df[\"score\"], joined_df[\"max_delta_conditional_entropy\"])\n", " plt.scatter(joined_df[\"score\"], joined_df[\"max_delta_conditional_entropy\"], alpha=0.3)\n", From b65c5f9fb9964937fafd82d41974827c89500bc5 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 01:56:47 +0000 Subject: [PATCH 053/132] log results in __main__ by default --- delphi/__main__.py | 105 +++++++++++---------- delphi/log/result_analysis.py | 172 ++++++++++++++++++++++++++++++++++ delphi/tests/e2e.py | 168 +-------------------------------- 3 files changed, 
233 insertions(+), 212 deletions(-) create mode 100644 delphi/log/result_analysis.py diff --git a/delphi/__main__.py b/delphi/__main__.py index a6f169c9..5bef1bb7 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -40,41 +40,7 @@ from delphi.autoencoders.eleuther import load_and_hook_sparsify_models from delphi.autoencoders.DeepMind import JumpReLUSAE from delphi.autoencoders.wrapper import AutoencoderLatents - - -def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str, hookpoints): - submodules = {} - - for layer in ae_layers: - - path = f"layer_{layer}/width_{size}/average_l0_{average_l0s[layer]}" - sae = JumpReLUSAE.from_pretrained(path,type,"cuda") - - sae.half() - def _forward(sae, x): - encoded = sae.encode(x) - return encoded - if type == "res": - submodule = model.model.layers[layer] - elif type == "mlp": - submodule = model.model.layers[layer].post_feedforward_layernorm - submodule.ae = AutoencoderLatents( - sae, partial(_forward, sae), width=sae.W_enc.shape[1] - ) - - hookpoint = [hookpoint for hookpoint in hookpoints if f"layers.{layer}" in hookpoint][0] - - submodules[hookpoint] = submodule - - with model.edit(" ") as edited: - for _, submodule in submodules.items(): - if type == "res": - acts = submodule.output[0] - else: - acts = submodule.output - submodule.ae(acts, hook=True) - - return submodules, edited +from delphi.log.result_analysis import log_results @dataclass @@ -89,16 +55,16 @@ class RunConfig: default="EleutherAI/sae-llama-3-8b-32x", positional=True, ) - """Name of the models associated with the model to explain, or path to - directory containing its weights. Models must be loadable with sparsify.""" + """Name of sparse models associated with the model to explain, or path to + directory containing their weights. Models must be loadable with sparsify.""" hookpoints: list[str] = list_field() - """List of hookpoints to load SAEs for.""" + """List of model hookpoints to attach sparse models to.""" explainer_model: str = field( default="hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4", ) - """Name of the model to use for explanation generation.""" + """Name of the model to use for explanation and scoring.""" explainer_model_max_len: int = field( default=5120, @@ -108,16 +74,16 @@ class RunConfig: explainer_provider: str = field( default="offline", ) - """Provider to use for explanation generation. Options are 'offline' for local models and 'openrouter' for API calls.""" + """Provider to use for explanation and scoring. Options are 'offline' for local models and 'openrouter' for API calls.""" name: str = "" - """The name of the run. Results will be saved in a directory with this name.""" - - overwrite: list[str] = list_field() - """Whether to overwrite existing parts of the run. Options are 'cache', 'scores', and 'visualize'.""" + """The name of the run. 
Results are saved in a directory with this name.""" max_latents: int | None = None - """Maximum number of latents to explain for each SAE.""" + """Maximum number of features to explain for each sparse model.""" + + filter_bos: bool = False + """Whether to filter out BOS tokens from the cache.""" load_in_8bit: bool = False """Load the model in 8-bit mode.""" @@ -133,15 +99,55 @@ class RunConfig: num_gpus: int = field( default=1, ) - """Number of GPUs to use for explanation generation.""" + """Number of GPUs to use for explanation and scoring.""" seed: int = field( default=22, ) """Seed for the random number generator.""" - filter_bos: bool = False - """Tokens to filter out from the cache.""" + log: bool = field( + default=True, + ) + """Whether to log summary statistics and results of the run.""" + + overwrite: list[str] = list_field() + """Whether to overwrite existing parts of the run. Options are 'cache', 'scores', and 'visualize'.""" + + +def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str, hookpoints): + submodules = {} + + for layer in ae_layers: + + path = f"layer_{layer}/width_{size}/average_l0_{average_l0s[layer]}" + sae = JumpReLUSAE.from_pretrained(path,type,"cuda") + + sae.half() + def _forward(sae, x): + encoded = sae.encode(x) + return encoded + if type == "res": + submodule = model.model.layers[layer] + elif type == "mlp": + submodule = model.model.layers[layer].post_feedforward_layernorm + submodule.ae = AutoencoderLatents( + sae, partial(_forward, sae), width=sae.W_enc.shape[1] + ) + + hookpoint = [hookpoint for hookpoint in hookpoints if f"layers.{layer}" in hookpoint][0] + + submodules[hookpoint] = submodule + + with model.edit(" ") as edited: + for _, submodule in submodules.items(): + if type == "res": + acts = submodule.output[0] + else: + acts = submodule.output + submodule.ae(acts, hook=True) + + return submodules, edited def load_artifacts(run_cfg: RunConfig): @@ -438,6 +444,9 @@ async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_ else: print(f"Files found in {scores_path}, skipping...") + if run_cfg.log: + log_results(scores_path, run_cfg.hookpoints) + if __name__ == "__main__": parser = ArgumentParser() diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py new file mode 100644 index 00000000..8261bb2f --- /dev/null +++ b/delphi/log/result_analysis.py @@ -0,0 +1,172 @@ +import orjson +import pandas as pd +from torch import Tensor +from pathlib import Path +import numpy as np + + +def feature_balanced_score_metrics(df: pd.DataFrame, score_type: str): + # Calculate weights based on non-errored examples + valid_examples = df['total_examples'] + weights = valid_examples / valid_examples.sum() + + weighted_mean_metrics = { + 'accuracy': np.average(df['accuracy'], weights=weights), + 'f1_score': np.average(df['f1_score'], weights=weights), + 'precision': np.average(df['precision'], weights=weights), + 'recall': np.average(df['recall'], weights=weights), + 'false_positives': np.average(df['false_positives'], weights=weights), + 'false_negatives': np.average(df['false_negatives'], weights=weights), + 'true_positives': np.average(df['true_positives'], weights=weights), + 'true_negatives': np.average(df['true_negatives'], weights=weights), + 'positive_class_ratio': np.average(df['positive_class_ratio'], weights=weights), + 'negative_class_ratio': np.average(df['negative_class_ratio'], weights=weights), + 'total_positives': np.average(df['total_positives'], weights=weights), + 
'total_negatives': np.average(df['total_negatives'], weights=weights), + 'true_positive_rate': np.average(df['true_positive_rate'], weights=weights), + 'true_negative_rate': np.average(df['true_negative_rate'], weights=weights), + 'false_positive_rate': np.average(df['false_positive_rate'], weights=weights), + 'false_negative_rate': np.average(df['false_negative_rate'], weights=weights), + } + + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Accuracy: {weighted_mean_metrics['accuracy']:.3f}") + print(f"F1 Score: {weighted_mean_metrics['f1_score']:.3f}") + print(f"Precision: {weighted_mean_metrics['precision']:.3f}") + print(f"Recall: {weighted_mean_metrics['recall']:.3f}") + + fractions_failed = [failed_count / total_examples for failed_count, total_examples in zip(df['failed_count'], df['total_examples'])] + print(f"Average fraction of failed examples: {sum(fractions_failed) / len(fractions_failed):.3f}") + + print("\nConfusion Matrix:") + print(f"True Positive Rate: {weighted_mean_metrics['true_positive_rate']:.3f}") + print(f"True Negative Rate: {weighted_mean_metrics['true_negative_rate']:.3f}") + print(f"False Positive Rate: {weighted_mean_metrics['false_positive_rate']:.3f}") + print(f"False Negative Rate: {weighted_mean_metrics['false_negative_rate']:.3f}") + + print(f"\nClass Distribution:") + print(f"Positives: {df['total_positives'].sum():.0f} ({weighted_mean_metrics['positive_class_ratio']:.1%})") + print(f"Negatives: {df['total_negatives'].sum():.0f} ({weighted_mean_metrics['negative_class_ratio']:.1%})") + print(f"Total: {df['total_examples'].sum():.0f}") + + return weighted_mean_metrics + + +def parse_score_file(file_path): + with open(file_path, "rb") as f: + data = orjson.loads(f.read()) + + df = pd.DataFrame([{ + "text": "".join(example["str_tokens"]), + "distance": example["distance"], + "ground_truth": example["ground_truth"], + "prediction": example["prediction"], + "probability": example["probability"], + "correct": example["correct"], + "activations": example["activations"], + "highlighted": example["highlighted"] + } for example in data]) + + # Calculate basic counts + failed_count = (df['prediction'] == -1).sum() + df = df[df['prediction'] != -1] + df.reset_index(drop=True, inplace=True) + total_examples = len(df) + total_positives = (df["ground_truth"]).sum() + total_negatives = (~df["ground_truth"]).sum() + + # Calculate confusion matrix elements + true_positives = ((df["prediction"] == 1) & (df["ground_truth"])).sum() + true_negatives = ((df["prediction"] == 0) & (~df["ground_truth"])).sum() + false_positives = ((df["prediction"] == 1) & (~df["ground_truth"])).sum() + false_negatives = ((df["prediction"] == 0) & (df["ground_truth"])).sum() + + # Calculate rates + true_positive_rate = true_positives / total_positives if total_positives > 0 else 0 + true_negative_rate = true_negatives / total_negatives if total_negatives > 0 else 0 + false_positive_rate = false_positives / total_negatives if total_negatives > 0 else 0 + false_negative_rate = false_negatives / total_positives if total_positives > 0 else 0 + + # Calculate precision, recall, f1 (using sklearn for verification) + precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 + recall = true_positive_rate # Same as TPR + f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 + + # Calculate accuracy + accuracy = (true_positives + true_negatives) / total_examples + + # Add metrics to first row + metrics = 
{ + "true_positive_rate": true_positive_rate, + "true_negative_rate": true_negative_rate, + "false_positive_rate": false_positive_rate, + "false_negative_rate": false_negative_rate, + "true_positives": true_positives, + "true_negatives": true_negatives, + "false_positives": false_positives, + "false_negatives": false_negatives, + "precision": precision, + "recall": recall, + "f1_score": f1_score, + "accuracy": accuracy, + "total_examples": total_examples, + "total_positives": total_positives, + "total_negatives": total_negatives, + "positive_class_ratio": total_positives / total_examples, + "negative_class_ratio": total_negatives / total_examples, + "failed_count": failed_count, + } + + for key, value in metrics.items(): + df.loc[0, key] = value + + return df + + +def build_scores_df(path: Path, target_modules: list[str], range: Tensor | None = None): + metrics_cols = [ + "accuracy", "probability", "precision", "recall", "f1_score", + "true_positives", "true_negatives", "false_positives", "false_negatives", + "true_positive_rate", "true_negative_rate", "false_positive_rate", "false_negative_rate", + "total_examples", "total_positives", "total_negatives", + "positive_class_ratio", "negative_class_ratio", "failed_count" + ] + df_data = { + col: [] + for col in ["file_name", "score_type", "feature_idx", "module"] + metrics_cols + } + + # Get subdirectories in the scores path + scores_types = [d.name for d in path.iterdir() if d.is_dir()] + + for score_type in scores_types: + score_type_path = path / score_type + + for module in target_modules: + for score_file in list(score_type_path.glob(f"*{module}*")) + list( + score_type_path.glob(f".*{module}*") + ): + feature_idx = int(score_file.stem.split("feature")[-1]) + if range is not None and feature_idx not in range: + continue + + df = parse_score_file(score_file) + + # Calculate the accuracy and cross entropy loss for this feature + df_data["file_name"].append(score_file.stem) + df_data["score_type"].append(score_type) + df_data["feature_idx"].append(feature_idx) + df_data["module"].append(module) + for col in metrics_cols: df_data[col].append(df.loc[0, col]) + + + df = pd.DataFrame(df_data) + assert not df.empty + return df + + +def log_results(scores_path: Path, target_modules: list[str]): + df = build_scores_df(scores_path, target_modules) + for score_type in df["score_type"].unique(): + score_df = df[df['score_type'] == score_type] + feature_balanced_score_metrics(score_df, score_type) diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py index 9e56e982..ba6b87a6 100644 --- a/delphi/tests/e2e.py +++ b/delphi/tests/e2e.py @@ -1,126 +1,12 @@ from pathlib import Path -import orjson import torch -from torch import Tensor -import pandas as pd import asyncio import time -import numpy as np from delphi.config import ExperimentConfig, LatentConfig, CacheConfig from delphi.__main__ import run, RunConfig - -def parse_score_file(file_path): - with open(file_path, "rb") as f: - data = orjson.loads(f.read()) - - df = pd.DataFrame([{ - "text": "".join(example["str_tokens"]), - "distance": example["distance"], - "ground_truth": example["ground_truth"], - "prediction": example["prediction"], - "probability": example["probability"], - "correct": example["correct"], - "activations": example["activations"], - "highlighted": example["highlighted"] - } for example in data]) - - # Calculate basic counts - failed_count = (df['prediction'] == -1).sum() - df = df[df['prediction'] != -1] - df.reset_index(drop=True, inplace=True) - total_examples = len(df) - 
total_positives = (df["ground_truth"]).sum() - total_negatives = (~df["ground_truth"]).sum() - - # Calculate confusion matrix elements - true_positives = ((df["prediction"] == 1) & (df["ground_truth"])).sum() - true_negatives = ((df["prediction"] == 0) & (~df["ground_truth"])).sum() - false_positives = ((df["prediction"] == 1) & (~df["ground_truth"])).sum() - false_negatives = ((df["prediction"] == 0) & (df["ground_truth"])).sum() - - # Calculate rates - true_positive_rate = true_positives / total_positives if total_positives > 0 else 0 - true_negative_rate = true_negatives / total_negatives if total_negatives > 0 else 0 - false_positive_rate = false_positives / total_negatives if total_negatives > 0 else 0 - false_negative_rate = false_negatives / total_positives if total_positives > 0 else 0 - - # Calculate precision, recall, f1 (using sklearn for verification) - precision = true_positives / (true_positives + false_positives) if (true_positives + false_positives) > 0 else 0 - recall = true_positive_rate # Same as TPR - f1_score = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0 - - # Calculate accuracy - accuracy = (true_positives + true_negatives) / total_examples - - # Add metrics to first row - metrics = { - "true_positive_rate": true_positive_rate, - "true_negative_rate": true_negative_rate, - "false_positive_rate": false_positive_rate, - "false_negative_rate": false_negative_rate, - "true_positives": true_positives, - "true_negatives": true_negatives, - "false_positives": false_positives, - "false_negatives": false_negatives, - "precision": precision, - "recall": recall, - "f1_score": f1_score, - "accuracy": accuracy, - "total_examples": total_examples, - "total_positives": total_positives, - "total_negatives": total_negatives, - "positive_class_ratio": total_positives / total_examples, - "negative_class_ratio": total_negatives / total_examples, - "failed_count": failed_count, - } - - for key, value in metrics.items(): - df.loc[0, key] = value - - return df - - -def build_df(path: Path, target_modules: list[str], range: Tensor | None): - metrics_cols = [ - "accuracy", "probability", "precision", "recall", "f1_score", - "true_positives", "true_negatives", "false_positives", "false_negatives", - "true_positive_rate", "true_negative_rate", "false_positive_rate", "false_negative_rate", - "total_examples", "total_positives", "total_negatives", - "positive_class_ratio", "negative_class_ratio", "failed_count" - ] - df_data = { - col: [] - for col in ["file_name", "score_type", "latent_idx", "module"] + metrics_cols - } - - # Get subdirectories in the scores path - scores_types = [d.name for d in path.iterdir() if d.is_dir()] - print(scores_types) - for score_type in scores_types: - score_type_path = path / score_type - for module in target_modules: - for score_file in list(score_type_path.glob(f"*{module}*")) + list( - score_type_path.glob(f".*{module}*") - ): - latent_idx = int(score_file.stem.split("latent")[-1]) - if range is not None and latent_idx not in range: - continue - - df = parse_score_file(score_file) - - # Calculate the accuracy and cross entropy loss for this latent - df_data["file_name"].append(score_file.stem) - df_data["score_type"].append(score_type) - df_data["latent_idx"].append(latent_idx) - df_data["module"].append(module) - for col in metrics_cols: df_data[col].append(df.loc[0, col]) - - - df = pd.DataFrame(df_data) - assert not df.empty - return df +from delphi.log.result_analysis import build_scores_df, 
feature_balanced_score_metrics async def test(): @@ -162,58 +48,12 @@ async def test(): end_time = time.time() print(f"Time taken: {end_time - start_time} seconds") - scores_path = Path("results") / run_cfg.name / "scores" - - latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None - hookpoints, *_ = load_artifacts(run_cfg) - - df = build_df(scores_path, hookpoints, latent_range) - # Performs better than random guessing + scores_path = Path("results") / run_cfg.name / "scores" + df = build_scores_df(scores_path, run_cfg.hookpoints) for score_type in df["score_type"].unique(): score_df = df[df['score_type'] == score_type] - # Calculate weights based on non-errored examples - valid_examples = score_df['total_examples'] - weights = valid_examples / valid_examples.sum() - - weighted_mean_metrics = { - 'accuracy': np.average(score_df['accuracy'], weights=weights), - 'f1_score': np.average(score_df['f1_score'], weights=weights), - 'precision': np.average(score_df['precision'], weights=weights), - 'recall': np.average(score_df['recall'], weights=weights), - 'false_positives': np.average(score_df['false_positives'], weights=weights), - 'false_negatives': np.average(score_df['false_negatives'], weights=weights), - 'true_positives': np.average(score_df['true_positives'], weights=weights), - 'true_negatives': np.average(score_df['true_negatives'], weights=weights), - 'positive_class_ratio': np.average(score_df['positive_class_ratio'], weights=weights), - 'negative_class_ratio': np.average(score_df['negative_class_ratio'], weights=weights), - 'total_positives': np.average(score_df['total_positives'], weights=weights), - 'total_negatives': np.average(score_df['total_negatives'], weights=weights), - 'true_positive_rate': np.average(score_df['true_positive_rate'], weights=weights), - 'true_negative_rate': np.average(score_df['true_negative_rate'], weights=weights), - 'false_positive_rate': np.average(score_df['false_positive_rate'], weights=weights), - 'false_negative_rate': np.average(score_df['false_negative_rate'], weights=weights), - } - - print(f"\n=== {score_type.title()} Metrics ===") - print(f"Accuracy: {weighted_mean_metrics['accuracy']:.3f}") - print(f"F1 Score: {weighted_mean_metrics['f1_score']:.3f}") - print(f"Precision: {weighted_mean_metrics['precision']:.3f}") - print(f"Recall: {weighted_mean_metrics['recall']:.3f}") - - fractions_failed = [failed_count / total_examples for failed_count, total_examples in zip(score_df['failed_count'], score_df['total_examples'])] - print(f"Average fraction of failed examples: {sum(fractions_failed) / len(fractions_failed):.3f}") - - print("\nConfusion Matrix:") - print(f"True Positive Rate: {weighted_mean_metrics['true_positive_rate']:.3f}") - print(f"True Negative Rate: {weighted_mean_metrics['true_negative_rate']:.3f}") - print(f"False Positive Rate: {weighted_mean_metrics['false_positive_rate']:.3f}") - print(f"False Negative Rate: {weighted_mean_metrics['false_negative_rate']:.3f}") - - print(f"\nClass Distribution:") - print(f"Positives: {score_df['total_positives'].sum():.0f} ({weighted_mean_metrics['positive_class_ratio']:.1%})") - print(f"Negatives: {score_df['total_negatives'].sum():.0f} ({weighted_mean_metrics['negative_class_ratio']:.1%})") - print(f"Total: {score_df['total_examples'].sum():.0f}") + weighted_mean_metrics = feature_balanced_score_metrics(score_df, score_type) assert weighted_mean_metrics['accuracy'] > 0.55, f"Score type {score_type} has an accuracy of {weighted_mean_metrics['accuracy']}" From 
ea3916e37d756bdcf02e183dd17f2a7160afbae2 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:13:14 +1100 Subject: [PATCH 054/132] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e9dc390b..58f9ba38 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ To run a minimal pipeline from the command line, you can use the following comma `python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --max_features 100 --hookpoints layers.5 --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` -This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. +This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. Summary metrics including F1 and confusion matrices for each scorer are logged by default. # Loading Autoencoders From 9cf4a17d72fb2ff3d60514880f0997d66e0d74c7 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:14:06 +1100 Subject: [PATCH 055/132] Update README.md --- README.md | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/README.md b/README.md index 58f9ba38..f7531488 100644 --- a/README.md +++ b/README.md @@ -6,13 +6,13 @@ This library provides utilities for generating and scoring text explanations of The branch used for the article [Automatically Interpreting Millions of Features in Large Language Models](https://arxiv.org/pdf/2410.13928) is the legacy branch [article_version](https://github.com/EleutherAI/delphi/tree/article_version), that branch contains the scripts to reproduce our experiments. Note that we're still actively improving the codebase and that the newest version on the main branch could require slightly different usage. -## Installation +# Installation Install this library as a local editable installation. Run the following command from the `delphi` directory. ```pip install -e .``` -## Getting Started +# Getting Started To run a minimal pipeline from the command line, you can use the following command: @@ -20,11 +20,11 @@ To run a minimal pipeline from the command line, you can use the following comma This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. Summary metrics including F1 and confusion matrices for each scorer are logged by default. -# Loading Autoencoders +## Loading Autoencoders This library uses NNsight to load and edit a model with sparse auxiliary models. We provide wrappers to load GPT-2 autoencoders trained by [OpenAI](https://github.com/openai/sparse_autoencoder), for the [GemmaScope SAEs](https://arxiv.org/abs/2408.05147) and for SAEs and transcoders trained by EleutherAI using [SAE](https://github.com/EleutherAI/sae). See the [examples](examples/loading_saes.ipynb) directory for specific examples. -# Caching +## Caching The first step to generate explanations is to cache sparse model activations. 
To do so, load your sparse models into the base model, load the tokens you want to cache the activations from, create a `FeatureCache` object and run it. We recommend caching over at least 10M tokens. @@ -55,7 +55,7 @@ cache.save_splits( Safetensors are split into shards over the width of the autoencoder. -# Loading Feature Records +## Loading Feature Records The `.features` module provides utilities for reconstructing and sampling various statistics for sparse features. In this version of the code you needed to specify the width of the autoencoder, the minimum number examples for a feature to be included and the maximum number of examples to include, as well as the number of splits to divide the features into. @@ -106,7 +106,7 @@ constructor = partial(default_constructor, tokens=dataset.tokens, n_random=cfg.n sampler = partial(sample, cfg=cfg) ``` -# Generating Explanations +## Generating Explanations We currently support using OpenRouter's OpenAI compatible API or running locally with VLLM. Define the client you want to use, then create an explainer from the `.explainers` module. @@ -157,7 +157,7 @@ pipeline = Pipeline( asyncio.run(pipeline.run(n_processes)) ``` -# Scoring Explanations +## Scoring Explanations The process of running a scorer is similar to that of an explainer. You need to have a client running, and you need to create a Scorer from the '.scorer' module. You can either load the explanations you generated earlier, or generate new ones using the explainer pipe. @@ -228,27 +228,27 @@ pipeline = Pipeline( asyncio.run(pipeline.run()) ``` -## Simulation +### Simulation To do simulation scoring we forked and modified OpenAIs neuron explainer. The name of the scorer is `OpenAISimulator`, and it can be run with the same setup as described above. -## Surprisal +### Surprisal Surprisal scoring computes the loss over some examples and uses a base model. We don't use VLLM but run the model using the `AutoModelForCausalLM` wrapper from HuggingFace. The setup is similar as above but for a example check `surprisal.py` in the experiments folder. -## Embedding +### Embedding Embedding scoring uses a small embedding model through `sentence_transformers` to embed the examples do retrival. It also does not use VLLM but run the model directly. The setup is similar as above but for a example check `embedding.py` in the experiments folder. -# Scripts +## Scripts Example scripts can be found in `demos`. Some of these scripts can be called from the CLI, as seen in examples found in `scripts`. These baseline scripts should allow anyone to start generating and scoring explanations in any SAE they are interested in. One always needs to first cache the activations of the features of any given SAE, and then generating explanations and scoring them can be done at the same time. -# Experiments +## Experiments The experiments discussed in [the blog post](https://blog.eleuther.ai/autointerp/) were mostly run in a legacy version of this code, which can be found in the [Experiments](https://github.com/EleutherAI/delphi/tree/Experiments) branch. 
-# License +## License Copyright 2024 the EleutherAI Institute From 544115ce2733c00a745e4c13bb0c65265c6fd55f Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:17:39 +1100 Subject: [PATCH 056/132] Update README.md --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index f7531488..d7f1b7b7 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,8 @@ To run a minimal pipeline from the command line, you can use the following comma This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. Summary metrics including F1 and confusion matrices for each scorer are logged by default. +The pipeline is highly configurable and can be invoked programmatically. For an example, see the [end-to-end test](https://github.com/EleutherAI/delphi/blob/main/delphi/tests/e2e.py). + ## Loading Autoencoders This library uses NNsight to load and edit a model with sparse auxiliary models. We provide wrappers to load GPT-2 autoencoders trained by [OpenAI](https://github.com/openai/sparse_autoencoder), for the [GemmaScope SAEs](https://arxiv.org/abs/2408.05147) and for SAEs and transcoders trained by EleutherAI using [SAE](https://github.com/EleutherAI/sae). See the [examples](examples/loading_saes.ipynb) directory for specific examples. From 60ff74f1de2d2e9775502441ae7d10daded835e5 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:18:03 +1100 Subject: [PATCH 057/132] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d7f1b7b7..e4b5fba8 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Install this library as a local editable installation. Run the following command # Getting Started -To run a minimal pipeline from the command line, you can use the following command: +To run a minimal pipeline from the command line, use the following command: `python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --max_features 100 --hookpoints layers.5 --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` From 679301cfc2f70884bc29c50e85972a639b7fb57e Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:19:39 +1100 Subject: [PATCH 058/132] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index e4b5fba8..49376535 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ Install this library as a local editable installation. 
Run the following command # Getting Started -To run a minimal pipeline from the command line, use the following command: +To run an example pipeline from the command line, use the following command: `python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --max_features 100 --hookpoints layers.5 --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` From e45127041569f4a23cfba4fb71148c8104d0fe54 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:20:16 +1100 Subject: [PATCH 059/132] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 49376535..5bea03b7 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ Delphi was the home of a temple to Phoebus Apollo, which famously had the inscription, 'Know Thyself.' This library lets language models know themselves through automated interpretability. -This library provides utilities for generating and scoring text explanations of sparse autoencoder (SAE) features. The explainer and scorer models can be run locally or accessed using API calls via OpenRouter. +This library provides utilities for generating and scoring text explanations of sparse autoencoder (SAE) and transcoder features. The explainer and scorer models can be run locally or accessed using API calls via OpenRouter. The branch used for the article [Automatically Interpreting Millions of Features in Large Language Models](https://arxiv.org/pdf/2410.13928) is the legacy branch [article_version](https://github.com/EleutherAI/delphi/tree/article_version), that branch contains the scripts to reproduce our experiments. Note that we're still actively improving the codebase and that the newest version on the main branch could require slightly different usage. From ba25408d4eafc63d818dec866d11db952394579c Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:21:28 +1100 Subject: [PATCH 060/132] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 5bea03b7..25b505f2 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,7 @@ To run an example pipeline from the command line, use the following command: `python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --max_features 100 --hookpoints layers.5 --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` -This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. Summary metrics including F1 and confusion matrices for each scorer are logged by default. +This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. Summary metrics including per-scorer F1s and confusion matrices are logged by default. The pipeline is highly configurable and can be invoked programmatically. For an example, see the [end-to-end test](https://github.com/EleutherAI/delphi/blob/main/delphi/tests/e2e.py). 
From c274e768bf90c2cc45d290a44996827a37682d7c Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:33:24 +1100 Subject: [PATCH 061/132] Update README.md --- README.md | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 25b505f2..acde1293 100644 --- a/README.md +++ b/README.md @@ -14,13 +14,19 @@ Install this library as a local editable installation. Run the following command # Getting Started -To run an example pipeline from the command line, use the following command: +To run the default pipeline from the command line, use the following command: -`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --max_features 100 --hookpoints layers.5 --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` +`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` --n_tokens 10_000_000 --max_features 100 --hookpoints layers.5 -This will cache the activations of the first 10 million tokens of EleutherAI/rpj-v2-sample, generate explanations for the first 100 features of layer 5 using the explainer model, then score the explanations using fuzzing and detection scorers. Summary metrics including per-scorer F1s and confusion matrices are logged by default. +This command will: +1. Cache activations for the first 10 million tokens of EleutherAI/rpj-v2-sample. +2. Generate explanations for the first 100 features of layer 5 using the specified explainer model. +3. Score the explanations uses fuzzing and detection scorers. +4. Log summary metrics including per-scorer F1 scores and confusion matrices. -The pipeline is highly configurable and can be invoked programmatically. For an example, see the [end-to-end test](https://github.com/EleutherAI/delphi/blob/main/delphi/tests/e2e.py). +The pipeline is highly configurable and can also be called programmatically (see the [end-to-end test](https://github.com/EleutherAI/delphi/blob/main/delphi/tests/e2e.py) for an example). + +To use other scorer types, instantiate a custom pipeline. ## Loading Autoencoders From 329d0f5dce288f68e5633dd7c1c559395efe367a Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 13:34:09 +1100 Subject: [PATCH 062/132] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index acde1293..b0bae286 100644 --- a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Install this library as a local editable installation. Run the following command To run the default pipeline from the command line, use the following command: -`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]'` --n_tokens 10_000_000 --max_features 100 --hookpoints layers.5 +`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/rpj-v2-sample' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_features 100 --hookpoints layers.5` This command will: 1. Cache activations for the first 10 million tokens of EleutherAI/rpj-v2-sample. 
From fede84f91e583681ce8c98e97597aee46dae452e Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 02:39:26 +0000 Subject: [PATCH 063/132] fix bugs --- delphi/scorers/classifier/sample.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py index be72eb27..a6b6efc5 100644 --- a/delphi/scorers/classifier/sample.py +++ b/delphi/scorers/classifier/sample.py @@ -1,11 +1,11 @@ import random from dataclasses import dataclass -from typing import List, NamedTuple +from typing import NamedTuple import torch from transformers import PreTrainedTokenizer -from ...features import Example +from ...latents import Example from ...logger import logger L = "<<" @@ -29,16 +29,16 @@ class ClassifierOutput: ground_truth: bool """Whether the example is activating or not""" - prediction: bool = False | None + prediction: bool | None = False """Whether the model predicted the example activating or not""" highlighted: bool = False """Whether the sample is highlighted""" - probability: float = 0.0 | None + probability: float | None = 0.0 """The probability of the example activating""" - correct: bool = False | None + correct: bool | None = False """Whether the prediction is correct""" From 265538948de850edd08bb8061efbc943d5245c0b Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 04:48:40 +0000 Subject: [PATCH 064/132] Fix v0.2 bug --- delphi/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 5bef1bb7..5940ebc8 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -264,7 +264,7 @@ async def process_cache( constructor = partial( default_constructor, token_loader=None, - n_random=experiment_cfg.n_random, + n_not_active=experiment_cfg.n_non_activating, ctx_len=experiment_cfg.example_ctx_len, max_examples=latent_cfg.max_examples, ) From 137dd7abf6601293b1fb7bc2375683beca06b916 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 05:02:25 +0000 Subject: [PATCH 065/132] Small fixes --- delphi/__main__.py | 20 ++++++++++++++------ delphi/autoencoders/DeepMind/__init__.py | 2 ++ delphi/autoencoders/wrapper.py | 2 +- delphi/pipeline.py | 2 +- 4 files changed, 18 insertions(+), 8 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 5940ebc8..8ea8d69c 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -114,6 +114,12 @@ class RunConfig: overwrite: list[str] = list_field() """Whether to overwrite existing parts of the run. Options are 'cache', 'scores', and 'visualize'.""" + num_examples_per_scorer_prompt: int = field( + default=1, + ) + """Number of examples to use for each scorer prompt. 
Using more than 1 improves scoring speed but can + leak information to the fuzzing scorer, and increases the scorer LLM task difficulty.""" + def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str, hookpoints): submodules = {} @@ -131,6 +137,8 @@ def _forward(sae, x): submodule = model.model.layers[layer] elif type == "mlp": submodule = model.model.layers[layer].post_feedforward_layernorm + else: + raise ValueError(f"Invalid autoencoder type: {type}") submodule.ae = AutoencoderLatents( sae, partial(_forward, sae), width=sae.W_enc.shape[1] ) @@ -304,7 +312,7 @@ def scorer_postprocess(result, score_dir): DetectionScorer( client, tokenizer=dataset.tokenizer, # type: ignore - batch_size=10, + batch_size=run_cfg.num_examples_per_scorer_prompt, verbose=False, log_prob=False, ), @@ -315,7 +323,7 @@ def scorer_postprocess(result, score_dir): FuzzingScorer( client, tokenizer=dataset.tokenizer, # type: ignore - batch_size=10, + batch_size=run_cfg.num_examples_per_scorer_prompt, verbose=False, log_prob=False, ), @@ -363,10 +371,10 @@ def populate_cache( flattened_tokens = tokens.flatten() mask = ~torch.isin(flattened_tokens, torch.tensor([tokenizer.bos_token_id])) masked_tokens = flattened_tokens[mask] - truncated_tokens = masked_tokens[ - : len(masked_tokens) - (len(masked_tokens) % cfg.ctx_len) - ] - tokens = truncated_tokens.reshape(-1, cfg.ctx_len) + truncated_tokens = masked_tokens[ + : len(masked_tokens) - (len(masked_tokens) % cfg.ctx_len) + ] + tokens = truncated_tokens.reshape(-1, cfg.ctx_len) tokens = cast(TensorType["batch", "seq"], tokens) diff --git a/delphi/autoencoders/DeepMind/__init__.py b/delphi/autoencoders/DeepMind/__init__.py index 536354d5..52e668db 100644 --- a/delphi/autoencoders/DeepMind/__init__.py +++ b/delphi/autoencoders/DeepMind/__init__.py @@ -25,6 +25,8 @@ def _forward(sae, x): submodule = model.model.layers[layer] elif type == "mlp": submodule = model.model.layers[layer].post_feedforward_layernorm + else: + raise ValueError(f"Invalid autoencoder type: {type}") submodule.ae = AutoencoderLatents( sae, partial(_forward, sae), width=sae.W_enc.shape[1] ) diff --git a/delphi/autoencoders/wrapper.py b/delphi/autoencoders/wrapper.py index 299a1d77..765eb54d 100644 --- a/delphi/autoencoders/wrapper.py +++ b/delphi/autoencoders/wrapper.py @@ -43,7 +43,7 @@ def from_pretrained( autoencoder_type = config.autoencoder_type model_name_or_path = config.model_name_or_path if autoencoder_type == "SAE": - from sae import Sae + from sparsify import Sae local = kwargs.get("local",None) assert local is not None, "local must be specified for SAE" if local: diff --git a/delphi/pipeline.py b/delphi/pipeline.py index 386503c9..dc8c5d21 100644 --- a/delphi/pipeline.py +++ b/delphi/pipeline.py @@ -5,7 +5,7 @@ from tqdm.asyncio import tqdm -def process_wrapper(function: Callable, preprocess: Callable = None, postprocess: Callable = None) -> Callable: +def process_wrapper(function: Callable, preprocess: Callable | None = None, postprocess: Callable | None = None) -> Callable: """ Wraps a function with optional preprocessing and postprocessing steps. 
From 331391b5ef4488f6a0e63aa15ba609e76da79985 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 09:19:11 +0000 Subject: [PATCH 066/132] Feature -> Latent --- delphi/latents/neighbours.py | 30 +++++++++++++++--------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 48f92ec7..765a9dcd 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -79,7 +79,7 @@ def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: Compute neighbour lists based on weight similarity in the autoencoder. """ print("Computing similarity neighbours") - # We use the encoder vectors to compute the similarity between features + # We use the encoder vectors to compute the similarity between latents if method == "encoder": encoder = self.autoencoder.encoder.cuda() weight_matrix_normalized = encoder.weight.data / encoder.weight.data.norm(dim=1, keepdim=True) @@ -91,16 +91,16 @@ def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: raise ValueError(f"Unknown method: {method}. Use 'encoder' or 'decoder'") wT = weight_matrix_normalized.T - # Compute the similarity between features + # Compute the similarity between latents done = False batch_size = weight_matrix_normalized.shape[0] - number_features = batch_size + number_latents = batch_size neighbour_lists = {} while not done: try: - for start in tqdm(range(0,number_features,batch_size)): + for start in tqdm(range(0,number_latents,batch_size)): rows = wT[start:start+batch_size] similarity_matrix = weight_matrix_normalized @ rows indices,values = torch.topk(similarity_matrix, self.number_of_neighbours+1, dim=1) @@ -139,21 +139,21 @@ def _compute_correlation_neighbours(self) -> dict[int, list[int]]: # load the encoder encoder_matrix = self.autoencoder.encoder.weight.cuda().half() - covariance_between_features = torch.zeros((encoder_matrix.shape[0],encoder_matrix.shape[0]),device="cpu") + covariance_between_latents = torch.zeros((encoder_matrix.shape[0],encoder_matrix.shape[0]),device="cpu") - # do batches of features + # do batches of latents batch_size = 1024 for start in tqdm(range(0,encoder_matrix.shape[0],batch_size)): end = min(encoder_matrix.shape[0],start+batch_size) encoder_rows = encoder_matrix[start:end] correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T - covariance_between_features[start:end] = correlation.cpu() + covariance_between_latents[start:end] = correlation.cpu() # the correlation is then the covariance divided by the product of the standard deviations - diagonal_covariance = torch.diag(covariance_between_features) + diagonal_covariance = torch.diag(covariance_between_latents) product_of_std = torch.sqrt(torch.outer(diagonal_covariance,diagonal_covariance)+1e-6) - correlation_matrix = covariance_between_features / product_of_std + correlation_matrix = covariance_between_latents / product_of_std # get the indices of the top k neighbours for each feature indices,values = torch.topk(correlation_matrix, self.number_of_neighbours+1, dim=1) @@ -188,7 +188,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: # concatenate the locations and activations locations = torch.cat(all_locations) - n_features = int(torch.max(locations[:,2])) + 1 + n_latents = int(torch.max(locations[:,2])) + 1 # 1. Get unique values of first 2 dims (i.e. 
absolute token index) and their counts # Trick is to use Cantor pairing function to have a bijective mapping between (batch_id, ctx_pos) and a unique 1D index @@ -204,9 +204,9 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: rows = cp.asarray(locations[:, 2]) cols = cp.asarray(locations_flat) data = cp.ones(len(rows)) - sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_features, n_tokens)) + sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_latents, n_tokens)) token_batch_size = 100_000 - cooc_matrix = cp.zeros((n_features, n_features), dtype=cp.float32) + cooc_matrix = cp.zeros((n_latents, n_latents), dtype=cp.float32) sparse_matrix_csc = sparse_matrix.tocsc() for start in tqdm(range(0, n_tokens, token_batch_size)): @@ -268,9 +268,9 @@ def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: class CovarianceEstimator: - def __init__(self, n_features, *, device = None): - self.mean = torch.zeros(n_features, device=device) - self.cov_ = torch.zeros(n_features, n_features, device=device) + def __init__(self, n_latents, *, device = None): + self.mean = torch.zeros(n_latents, device=device) + self.cov_ = torch.zeros(n_latents, n_latents, device=device) self.n = 0 def update(self, x: torch.Tensor): From 809c0105edd262de6277f229fbb039e3ff0bb38b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 10:42:49 +0000 Subject: [PATCH 067/132] Add to the docstring --- delphi/latents/latents.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index a870571f..86b6e2a4 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -81,6 +81,10 @@ class LatentRecord: latent (Latent): The latent associated with the record. 
examples: list[Example]: Example sequences where the latent activations, assumed to be sorted in descending order by max activation + not_active: list[Example]: Some sequences where the latent is not active + train: list[list[Example]]: Examples used for the explainer model + test: list[list[Example]]: Examples used for the scorer models + neighbours: list[int]: The indices of the neighbouring latents """ def __init__( From b419849d2c99a4ded043a67b6a54562d501e0ce1 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 11:17:07 +0000 Subject: [PATCH 068/132] tokens doesn't need to be in BufferOutput --- delphi/latents/loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index b7cd7bf8..92c8580e 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -37,7 +37,6 @@ class BufferOutput(NamedTuple): latent: Latent locations: TensorType["locations", 2] activations: TensorType["locations"] - tokens: TensorType["tokens"] class TensorBuffer: """ @@ -86,7 +85,6 @@ def __iter__(self): Latent(self.module_path, int(latents[i].item())), latent_locations, latent_activations, - tokens ) def load(self): From d0b499bb1d9cd6f2c8511559734036003cf44e0e Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 11:17:29 +0000 Subject: [PATCH 069/132] Adding correct functions to init --- delphi/latents/__init__.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/delphi/latents/__init__.py b/delphi/latents/__init__.py index 302f5cbc..5db9c716 100644 --- a/delphi/latents/__init__.py +++ b/delphi/latents/__init__.py @@ -1,13 +1,14 @@ from .cache import LatentCache from .constructors import ( - default_constructor, + constructor, pool_max_activation_windows, random_non_activating_windows, + neighbour_non_activation_windows, ) from .latents import Example, Latent, LatentRecord from .loader import LatentDataset, LatentLoader from .samplers import sample -from .stats import get_neighbors, unigram +from .stats import unigram __all__ = [ "LatentCache", @@ -17,9 +18,9 @@ "Example", "pool_max_activation_windows", "random_non_activating_windows", - "default_constructor", + "neighbour_non_activation_windows", + "constructor", "sample", - "get_neighbors", "unigram", "LatentLoader" ] From 70442ab9f6b9ca11726520bf92fdd4842fe98e6f Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 11:17:52 +0000 Subject: [PATCH 070/132] Reformulating constructors --- delphi/latents/constructors.py | 328 +++++++++++++++------------------ 1 file changed, 146 insertions(+), 182 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index da537d64..79b20a8f 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -32,10 +32,12 @@ def _top_k_pools( return token_windows, activation_windows + def pool_max_activation_windows( - record, - buffer_output: BufferOutput, - tokens: TensorType["batch", "seq"], + activations: TensorType["n_examples"], + tokens: TensorType["windows", "seq"], + ctx_indices: TensorType["n_examples"], + index_within_ctx: TensorType["n_examples"], ctx_len: int, max_examples: int, ): @@ -43,112 +45,137 @@ def pool_max_activation_windows( Pool max activation windows from the buffer output and update the latent record. Args: - record (LatentRecord): The latent record to update. - buffer_output (BufferOutput): The buffer output containing activations and locations. - tokens (TensorType["batch", "seq"]): The input tokens. 
+ activations (TensorType["n_examples"]): The activations. + tokens (TensorType["windows", "seq"]): The input tokens. + ctx_indices (TensorType["n_examples"]): The context indices. + index_within_ctx (TensorType["n_examples"]): The index within the context. ctx_len (int): The context length. max_examples (int): The maximum number of examples. """ - flat_indices = buffer_output.locations[:, 0] * tokens.shape[1] + buffer_output.locations[:, 1] - ctx_indices = flat_indices // ctx_len - index_within_ctx = flat_indices % ctx_len - - # unique_ctx_indices: array of distinct context window indices in order of first appearance. i.e. sequential integers from 0 to 3903 + # unique_ctx_indices: array of distinct context window indices in order of first appearance. i.e. sequential integers from 0 to batch_size * cache_token_length // ctx_len # inverses: maps each activation back to its index in unique_ctx_indices (can be used to dereference the context window idx of each activation) # lengths: the number of activations per unique context window index unique_ctx_indices, inverses, lengths = torch.unique_consecutive(ctx_indices, return_counts=True, return_inverse=True) + # Get the max activation magnitude within each context window - max_buffer = torch.segment_reduce(buffer_output.activations, 'max', lengths=lengths) + max_buffer = torch.segment_reduce(activations, 'max', lengths=lengths) # Deduplicate the context windows - new_tensor= torch.zeros(len(unique_ctx_indices), ctx_len, dtype=buffer_output.activations.dtype) - new_tensor[inverses, index_within_ctx] = buffer_output.activations + new_tensor= torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype) + new_tensor[inverses, index_within_ctx] = activations - buffer_tokens = tokens.reshape(-1, ctx_len) - buffer_tokens = buffer_tokens[unique_ctx_indices] + + tokens = tokens[unique_ctx_indices] token_windows, activation_windows = _top_k_pools( - max_buffer, new_tensor, buffer_tokens, max_examples + max_buffer, new_tensor, tokens, max_examples ) - record.examples = prepare_examples(token_windows, activation_windows) + return token_windows, activation_windows -def random_non_activating_windows( +def constructor( record: LatentRecord, - tokens: TensorType["batch", "seq"], buffer_output: BufferOutput, - ctx_len: int, + all_data: AllData, n_not_active: int, + max_examples: int, + ctx_len: int, + constructor_type: str = "random", + token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None = None ): - """ - Generate random non-activating sequence windows and update the latent record. + tokens = all_data.tokens + if tokens is None: + if token_loader is None: + raise ValueError("Either tokens or token_loader must be provided") + try: + tokens = token_loader() + except TypeError: + raise ValueError( + "Starting with v0.2, `tokens` was renamed to `token_loader`, " + "which must be a callable for lazy loading.\n\n" + "Instead of passing\n" + "` tokens=dataset.tokens`,\n" + "pass\n" + "` token_loader=lambda: dataset.load_tokens()`,\n" + "(assuming `dataset` is a `LatentDataset` instance)." + ) - Args: - record (LatentRecord): The latent record to update. - tokens (TensorType["batch", "seq"]): The input tokens. - buffer_output (BufferOutput): The buffer output containing activations and locations. - ctx_len (int): The context length. - n_not_active (int): The number of non activating examples to generate. 
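A toy, self-contained sketch of the pooling step implemented above (made-up values, not the project's data): activations are grouped by their sorted context-window index, and the per-window maximum is taken with segment_reduce.

import torch

ctx_indices = torch.tensor([0, 0, 0, 2, 2, 5])                    # window index of each activation (sorted)
activations = torch.tensor([0.25, 0.75, 0.5, 1.0, 0.125, 0.375])

unique_ctx, inverses, lengths = torch.unique_consecutive(
    ctx_indices, return_inverse=True, return_counts=True
)
max_per_window = torch.segment_reduce(activations, "max", lengths=lengths)
print(unique_ctx.tolist())      # [0, 2, 5]
print(max_per_window.tolist())  # [0.75, 1.0, 0.375]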
- """ - torch.manual_seed(22) - if n_not_active == 0: - record.not_active = [] - return - batch_size = tokens.shape[0] - unique_batch_pos = buffer_output.locations[:, 0].unique() + cache_token_length = tokens.shape[1] - mask = torch.ones(batch_size, dtype=torch.bool) - mask[unique_batch_pos] = False - - available_indices = mask.nonzero().squeeze() + # Get all positions where the latent is active + flat_indices = buffer_output.locations[:, 0] * cache_token_length + buffer_output.locations[:, 1] + ctx_indices = flat_indices // ctx_len + index_within_ctx = flat_indices % ctx_len + reshaped_tokens = tokens.reshape(-1, ctx_len) + n_windows = reshaped_tokens.shape[0] + + unique_batch_pos = ctx_indices.unique() - # TODO:What to do when the latent is active at least once in each batch? - if available_indices.numel() < n_not_active: - print("No available randomly sampled non-activating sequences") - record.not_active = [] - return - else: - selected_indices = available_indices[torch.randint(0,len(available_indices),size=(n_not_active,))] + mask = torch.ones(n_windows, dtype=torch.bool) + mask[unique_batch_pos] = False + # Indices where the latent is active + active_indices = mask.nonzero(as_tuple=False).squeeze() + activations = buffer_output.activations + + # Add activation examples to the record in place + print(activations.shape, reshaped_tokens.shape, ctx_indices.shape, index_within_ctx.shape, max_examples) + token_windows, activation_windows = pool_max_activation_windows(record, + activations=activations, + tokens=reshaped_tokens, + ctx_indices=ctx_indices, + index_within_ctx=index_within_ctx, + max_examples=max_examples) + record.examples = prepare_examples(token_windows, activation_windows) - toks = tokens[selected_indices, 10 : 10 + ctx_len] + if constructor_type == "random": + # Add random non-activating examples to the record in place + random_non_activating_windows( + record, + available_indices=active_indices, + reshaped_tokens=reshaped_tokens, + n_not_active=n_not_active, + ) + elif constructor_type == "neighbour": + neighbour_non_activation_windows( + record, + not_active_mask=mask, + tokens=tokens, + all_data=all_data, + ctx_len=ctx_len, + n_not_active=n_not_active, + ) - record.not_active = prepare_examples( - toks, - torch.zeros_like(toks), - ) -def neighbour_random_activation_windows( +def neighbour_non_activation_windows( record: LatentRecord, + not_active_mask: TensorType["n_windows"], tokens: TensorType["batch", "seq"], - buffer_output: BufferOutput, all_data: AllData, ctx_len: int, - n_random: int, + n_not_active: int, ): """ Generate random activation windows and update the latent record. Args: record (LatentRecord): The latent record to update. + not_active_mask (TensorType["n_windows"]): The mask of the non-active windows. tokens (TensorType["batch", "seq"]): The input tokens. - buffer_output (BufferOutput): The buffer output containing activations and locations. + all_data (AllData): The all data containing activations and locations. ctx_len (int): The context length. n_random (int): The number of random examples to generate. 
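Roughly, the random non-activating selection boils down to a boolean mask over context windows; a tiny sketch with invented sizes (illustrative only):

import torch

n_windows, ctx_len, n_not_active = 10, 4, 3
reshaped_tokens = torch.arange(n_windows * ctx_len).reshape(n_windows, ctx_len)
active_window_ids = torch.tensor([1, 4, 7])      # windows where the latent fires

mask = torch.ones(n_windows, dtype=torch.bool)
mask[active_window_ids] = False                  # True only for non-activating windows
available = mask.nonzero().flatten()

picked = available[torch.randint(0, available.shape[0], (n_not_active,))]
non_activating = reshaped_tokens[picked]         # later paired with all-zero activations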
""" torch.manual_seed(22) - if n_random == 0: + if n_not_active == 0: + record.not_active = [] return assert record.neighbours is not None, "Neighbours are not set, add them via a transform" - batch_size = tokens.shape[0] - - # Get the unique batch positions where the latent is active - unique_batch_pos_active = buffer_output.locations[:, 0].unique_consecutive() - - mask = torch.zeros(batch_size, dtype=torch.bool) - + cache_token_length = tokens.shape[1] + reshaped_tokens = tokens.reshape(-1, ctx_len) + n_windows = reshaped_tokens.shape[0] # TODO: For now we use at most 10 examples per neighbour, we may want to allow a variable number of examples per neighbour n_examples_per_neighbour = 10 @@ -156,7 +183,7 @@ def neighbour_random_activation_windows( available_features = all_data.features all_examples = [] for neighbour in record.neighbours: - if number_examples >= n_random: + if number_examples >= n_not_active: break # find indice in all_data.features that matches the neighbour indice = torch.where(available_features == neighbour.feature_index)[0] @@ -164,135 +191,72 @@ def neighbour_random_activation_windows( continue # get the locations of the neighbour locations = all_data.locations[indice] - # get the unique locations - unique_locations = locations[:,0].unique_consecutive(dim=0) - + activations = all_data.activations[indice] + # get the active window indices + flat_indices = locations[:, 0] * cache_token_length + locations[:, 1] + ctx_indices = flat_indices // ctx_len + index_within_ctx = flat_indices % ctx_len # Set the mask to True for the unique locations - mask[unique_locations] = True - - # Set the mask to False for the unique locations where the latent is active - # TODO: we probably want to be less strict here, we could use parts of the batch where the latent is not active - mask[unique_batch_pos_active] = False - - available_indices = mask.nonzero().flatten() - - if available_indices.numel() == 0: - continue - size = min(n_examples_per_neighbour, len(available_indices)) - - selected_indices = torch.randint(0, len(unique_locations), size=(size,)) - selected_positions = torch.randint(0, tokens.shape[1] - ctx_len, size=(size,)) - - range_indices = torch.arange(ctx_len, device=tokens.device).unsqueeze(0) # Shape: (1, ctx_len) - - # Each selected_positions gives a unique starting index. We add the range tensor to get indices for each example. 
- positions = selected_positions.unsqueeze(1) + range_indices - - # Get tokens - toks = tokens[selected_indices].gather(dim=1, index=positions) - - examples = prepare_examples(toks, torch.zeros_like(toks)) - number_examples += len(examples) - all_examples.append((examples, neighbour)) + unique_batch_pos_active = ctx_indices.unique() + + mask = torch.zeros(n_windows, dtype=torch.bool) + mask[unique_batch_pos_active] = True + + # Get the indices where mask is True but active_indices is False + new_mask = mask & not_active_mask + + available_indices = new_mask.nonzero().flatten() + + mask_ctx = torch.isin(ctx_indices, available_indices) + available_ctx_indices = ctx_indices[mask_ctx] + available_index_within_ctx = index_within_ctx[mask_ctx] + token_windows, activation_windows = pool_max_activation_windows(record, + activations=activations, + tokens=reshaped_tokens, + ctx_indices=available_ctx_indices, + index_within_ctx=available_index_within_ctx, + max_examples=n_examples_per_neighbour) + # use the first n_examples_per_neighbour examples, which will be the most active examples + examples_used = len(token_windows) + all_examples.append(prepare_examples(token_windows, torch.zeros_like(token_windows))) + number_examples += examples_used if len(all_examples) == 0: print("No examples found") - record.random_examples = all_examples - + record.not_active = all_examples -def default_constructor( +def random_non_activating_windows( record: LatentRecord, - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None, - buffer_output: BufferOutput, + available_indices: TensorType["n_windows"], + reshaped_tokens: TensorType["n_windows", "ctx_len"], n_not_active: int, - ctx_len: int, - max_examples: int, ): """ - Construct latent examples using pool max activation windows and random activation windows. + Generate random non-activating sequence windows and update the latent record. Args: record (LatentRecord): The latent record to update. - token_loader (Optional[Callable[[], TensorType["batch", "seq"]]]): - An optional function that creates the dataset tokens. - buffer_output (BufferOutput): The buffer output containing activations and locations. - n_not_active (int): Number of non-activating examples to randomly generate. - ctx_len (int): Context length for each example. - max_examples (int): Maximum number of examples to generate. + available_indices (TensorType["n_windows"]): The indices of the windows where the latent is not active. + reshaped_tokens (TensorType["n_windows", "ctx_len"]): The tokens reshaped to the context length. + n_not_active (int): The number of non activating examples to generate. """ - tokens = buffer_output.tokens - if tokens is None: - if token_loader is None: - raise ValueError("Either tokens or token_loader must be provided") - try: - tokens = token_loader() - except TypeError: - raise ValueError( - "Starting with v0.2, `tokens` was renamed to `token_loader`, " - "which must be a callable for lazy loading.\n\n" - "Instead of passing\n" - "` tokens=dataset.tokens`,\n" - "pass\n" - "` token_loader=lambda: dataset.load_tokens()`,\n" - "(assuming `dataset` is a `LatentDataset` instance)." 
- ) - pool_max_activation_windows( - record, - buffer_output=buffer_output, - tokens=tokens, - ctx_len=ctx_len, - max_examples=max_examples, - ) - random_non_activating_windows( - record, - tokens=tokens, - buffer_output=buffer_output, - n_not_active=n_not_active, - ctx_len=ctx_len, - ) + torch.manual_seed(22) + if n_not_active == 0: + record.not_active = [] + return + + # If this happens it means that the latent is active in every window, so it is a bad latent + if available_indices.numel() < n_not_active: + print("No available randomly sampled non-activating sequences") + record.not_active = [] + return + else: + selected_indices = available_indices[torch.randint(0, available_indices.shape[0], size=(n_not_active,))] + + toks = reshaped_tokens[selected_indices] -def neighbour_constructor( - record: LatentRecord, - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]], - buffer_output: BufferOutput, - all_data: AllData, - n_random: int, - ctx_len: int, - max_examples: int, -): - """ - Construct feature examples using pool max activation windows and random activation windows from neighbours. - """ - tokens = all_data.tokens - if tokens is None: - if token_loader is None: - raise ValueError("Either tokens or token_loader must be provided") - try: - tokens = token_loader() - except TypeError: - raise ValueError( - "Starting with v0.2, `tokens` was renamed to `token_loader`, " - "which must be a callable for lazy loading.\n\n" - "Instead of passing\n" - "` tokens=dataset.tokens`,\n" - "pass\n" - "` token_loader=lambda: dataset.load_tokens()`,\n" - "(assuming `dataset` is a `FeatureDataset` instance)." - ) - pool_max_activation_windows( - record, - buffer_output=buffer_output, - tokens=tokens, - ctx_len=ctx_len, - max_examples=max_examples, - ) - neighbour_random_activation_windows( - record, - tokens=tokens, - buffer_output=buffer_output, - all_data=all_data, - n_random=n_random, - ctx_len=ctx_len, - ) - \ No newline at end of file + record.not_active = prepare_examples( + toks, + torch.zeros_like(toks), + ) \ No newline at end of file From 93120cd11609862afcb6973266984b5604dd9510 Mon Sep 17 00:00:00 2001 From: Lucia Quirke Date: Tue, 11 Feb 2025 14:05:09 +0000 Subject: [PATCH 071/132] Add semantic index --- delphi/__main__.py | 54 +++++++-- delphi/explainers/__init__.py | 3 +- delphi/explainers/contrastive_explainer.py | 121 +++++++++++++++++++++ delphi/latents/cache.py | 7 +- delphi/semantic_index/index.py | 91 ++++++++++++++++ 5 files changed, 260 insertions(+), 16 deletions(-) create mode 100644 delphi/explainers/contrastive_explainer.py create mode 100644 delphi/semantic_index/index.py diff --git a/delphi/__main__.py b/delphi/__main__.py index 8ea8d69c..f5e3d809 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -8,7 +8,9 @@ from multiprocessing import cpu_count import asyncio import os +import json +from datasets import Dataset from simple_parsing import ArgumentParser, field import torch from torch import Tensor @@ -26,7 +28,7 @@ from simple_parsing import field, list_field from delphi.config import ExperimentConfig, LatentConfig -from delphi.explainers import DefaultExplainer +from delphi.explainers import DefaultExplainer, ContrastiveExplainer from delphi.latents import LatentDataset, LatentLoader from delphi.latents.constructors import default_constructor from delphi.latents.samplers import sample @@ -41,6 +43,7 @@ from delphi.autoencoders.DeepMind import JumpReLUSAE from delphi.autoencoders.wrapper import AutoencoderLatents from delphi.log.result_analysis import 
log_results +from delphi.semantic_index.index import build_or_load_index, load_index @dataclass @@ -120,6 +123,9 @@ class RunConfig: """Number of examples to use for each scorer prompt. Using more than 1 improves scoring speed but can leak information to the fuzzing scorer, and increases the scorer LLM task difficulty.""" + semantic_index: bool = False + """Whether to build semantic index of token sequences.""" + def load_gemma_autoencoders(model, ae_layers: list[int],average_l0s: Dict[int,int],size:str,type:str, hookpoints): submodules = {} @@ -206,10 +212,12 @@ def load_artifacts(run_cfg: RunConfig): return run_cfg.hookpoints, submodule_name_to_submodule, model, model.tokenizer -async def process_cache( +async def run_pipeline( + cache_cfg: CacheConfig, latent_cfg: LatentConfig, run_cfg: RunConfig, experiment_cfg: ExperimentConfig, + base_path: Path, latents_path: Path, explanations_path: Path, scores_path: Path, @@ -246,6 +254,9 @@ async def process_cache( tokenizer=tokenizer, ) + if run_cfg.semantic_index: + index = load_index(base_path, cache_cfg) + if run_cfg.explainer_provider == "offline": client = Offline( run_cfg.explainer_model, @@ -284,14 +295,21 @@ def explainer_postprocess(result): f.write(orjson.dumps(result.explanation)) return result - explainer_pipe = process_wrapper( - DefaultExplainer( + if run_cfg.semantic_index: + explainer = ContrastiveExplainer( client, tokenizer=dataset.tokenizer, + index=index, threshold=0.3, - ), - postprocess=explainer_postprocess, - ) + ) + else: + explainer = DefaultExplainer( + client, + tokenizer=dataset.tokenizer, + threshold=0.3, + ) + + explainer_pipe = process_wrapper(explainer, postprocess=explainer_postprocess) # Builds the record from result returned by the pipeline def scorer_preprocess(result): @@ -341,24 +359,31 @@ def scorer_postprocess(result, score_dir): await pipeline.run(run_cfg.pipeline_num_proc) -def populate_cache( +def prepare_data( run_cfg: RunConfig, cfg: CacheConfig, hooked_model: LanguageModel, submodule_name_to_submodule: dict[str, nn.Module], latents_path: Path, + base_path: Path, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, filter_bos: bool, ): """ - Populates an on-disk cache in `latents_path` with SAE latent activations. + Populates an on-disk cache in `latents_path` with SAE latent activations. + Optionally builds a semantic index of token sequences. 
""" latents_path.mkdir(parents=True, exist_ok=True) data = load_dataset( cfg.dataset_repo, name=cfg.dataset_name, split=cfg.dataset_split ) + data = assert_type(Dataset, data) data = data.shuffle(run_cfg.seed) + + if run_cfg.semantic_index: + build_or_load_index(data, base_path, cfg) + data = chunk_and_tokenize( data, tokenizer, max_seq_len=cfg.ctx_len, text_key=cfg.dataset_row ) @@ -420,28 +445,35 @@ async def run(experiment_cfg: ExperimentConfig, latent_cfg: LatentConfig, cache_ not glob(str(latents_path / ".*")) + glob(str(latents_path / "*")) or "cache" in run_cfg.overwrite ): - populate_cache( + prepare_data( run_cfg, cache_cfg, hooked_model, submodule_name_to_submodule, latents_path, + base_path, tokenizer, filter_bos=run_cfg.filter_bos, ) else: print(f"Files found in {latents_path}, skipping cache population...") + if run_cfg.semantic_index: + index = load_index(base_path, cache_cfg) + del hooked_model, submodule_name_to_submodule if ( not glob(str(scores_path / ".*")) + glob(str(scores_path / "*")) or "scores" in run_cfg.overwrite ): - await process_cache( + + await run_pipeline( + cache_cfg, latent_cfg, run_cfg, experiment_cfg, + base_path, latents_path, explanations_path, scores_path, diff --git a/delphi/explainers/__init__.py b/delphi/explainers/__init__.py index caa4bcdf..9f843330 100644 --- a/delphi/explainers/__init__.py +++ b/delphi/explainers/__init__.py @@ -1,5 +1,6 @@ from .default.default import DefaultExplainer from .single_token_explainer import SingleTokenExplainer from .explainer import Explainer, explanation_loader, random_explanation_loader +from .contrastive_explainer import ContrastiveExplainer -__all__ = ["Explainer", "DefaultExplainer", "SingleTokenExplainer", "explanation_loader", "random_explanation_loader"] +__all__ = ["Explainer", "DefaultExplainer", "ContrastiveExplainer", "SingleTokenExplainer", "explanation_loader", "random_explanation_loader"] diff --git a/delphi/explainers/contrastive_explainer.py b/delphi/explainers/contrastive_explainer.py new file mode 100644 index 00000000..ddb85886 --- /dev/null +++ b/delphi/explainers/contrastive_explainer.py @@ -0,0 +1,121 @@ +import re +import asyncio + +import faiss + +from delphi.explainers.explainer import Explainer, ExplainerResult +from delphi.explainers.default.prompt_builder import build_single_token_prompt +from delphi.logger import logger + +class ContrastiveExplainer(Explainer): + name = "contrastive" + + def __init__( + self, + client, + tokenizer, + index: faiss.IndexFlatL2, + verbose: bool = False, + activations: bool = False, + cot: bool = False, + threshold: float = 0.6, + temperature: float = 0., + **generation_kwargs, + ): + self.client = client + self.tokenizer = tokenizer + self.index = index + self.verbose = verbose + + self.activations = activations + self.cot = cot + self.threshold = threshold + self.temperature = temperature + self.generation_kwargs = generation_kwargs + + async def __call__(self, record): + + messages = self._build_prompt(record.train) + + response = await self.client.generate( + messages, temperature=self.temperature, **self.generation_kwargs + ) + + try: + explanation = self.parse_explanation(response.text) + if self.verbose: + return ( + messages[-1]["content"], + response, + ExplainerResult(record=record, explanation=explanation), + ) + + return ExplainerResult(record=record, explanation=explanation) + except Exception as e: + logger.error(f"Explanation parsing failed: {e}") + return ExplainerResult(record=record, explanation="Explanation could not be parsed.") + + 
def parse_explanation(self, text: str) -> str: + try: + match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) + return match.group(1).strip() if match else "Explanation could not be parsed." + except Exception as e: + logger.error(f"Explanation parsing regex failed: {e}") + raise + + def _highlight(self, index, example): + # result = f"Example {index}: " + result = f"" + threshold = example.max_activation * self.threshold + if self.tokenizer is not None: + str_toks = self.tokenizer.batch_decode(example.tokens) + example.str_toks = str_toks + else: + str_toks = example.tokens + example.str_toks = str_toks + activations = example.activations + + def check(i): + return activations[i] > threshold + + i = 0 + while i < len(str_toks): + if check(i): + # result += "<<" + + while i < len(str_toks) and check(i): + result += str_toks[i] + i += 1 + # result += ">>" + else: + # result += str_toks[i] + i += 1 + + return "".join(result) + + def _join_activations(self, example): + activations = [] + + for i, activation in enumerate(example.activations): + if activation > example.max_activation * self.threshold: + activations.append((example.str_toks[i], int(example.normalized_activations[i]))) + + acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) + + return "Activations: " + acts + + def _build_prompt(self, examples): + highlighted_examples = [] + + for i, example in enumerate(examples): + highlighted_examples.append(self._highlight(i + 1, example)) + + if self.activations: + highlighted_examples.append(self._join_activations(example)) + + return build_single_token_prompt( + examples=highlighted_examples, + ) + + def call_sync(self, record): + return asyncio.run(self.__call__(record)) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index ff6d0942..18123c66 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -169,17 +169,15 @@ def __init__( self.batch_size = batch_size self.width = list(submodule_dict.values())[0].ae.width - self.cache = Cache(filters, batch_size=batch_size) if filters is not None: self.filter_submodules(filters) - def load_token_batches( self, n_tokens: int, tokens: TensorType["batch", "sequence"] ): """ - Load and prepare token batches for processing. + Split tokens into a list of token batches. Args: n_tokens (int): Total number of tokens to process. @@ -215,7 +213,7 @@ def filter_submodules(self, filters: Dict[str, TensorType["indices"]]): def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): """ - Run the latent caching process. + Cache latents from the model. Args: n_tokens (int): Total number of tokens to process. @@ -249,6 +247,7 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): print(f"Total tokens processed: {total_tokens:,}") self.cache.save() + def save(self, save_dir, save_tokens: bool = True): """ Save the cached latents to disk. 
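Before moving on to the semantic index, here is a small illustration of the [EXPLANATION]: parsing convention the explainers above rely on, with a made-up model response:

import re

response_text = "The latent fires on chemistry terms.\n[EXPLANATION]: chemical element names"
match = re.search(r"\[EXPLANATION\]:\s*(.*)", response_text, re.DOTALL)
explanation = match.group(1).strip() if match else "Explanation could not be parsed."
print(explanation)  # chemical element names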
diff --git a/delphi/semantic_index/index.py b/delphi/semantic_index/index.py new file mode 100644 index 00000000..2ef11b5d --- /dev/null +++ b/delphi/semantic_index/index.py @@ -0,0 +1,91 @@ +import json +from pathlib import Path + +from torch.utils.data import DataLoader +import faiss +import torch +from torch import Tensor +from sparsify.data import chunk_and_tokenize +import numpy as np +from datasets import Dataset +from transformers import AutoTokenizer, AutoModel + +from delphi.config import CacheConfig +from delphi.utils import assert_type +from delphi.logger import logger + + +def get_neighbors_by_id(index: faiss.IndexIDMap, vector_id: int, k: int = 10): + # First reconstruct the vector for the given ID + vector = index.reconstruct(vector_id) + + # Reshape to match FAISS expectations (needs 2D array) + vector = vector.reshape(1, -1) + + # Search for nearest neighbors + distances, neighbor_ids = index.search(vector, k + 1) # k+1 since it will find itself + + # Remove the first result (which will be the query vector itself) + return distances[0][1:], neighbor_ids[0][1:] + +def get_index_path(base_path: Path, cfg: CacheConfig): + return base_path / f"{cfg.dataset_repo.replace('/', '_')}_{cfg.dataset_split}_{cfg.ctx_len}.idx" + + +def load_index(base_path: Path, cfg: CacheConfig) -> faiss.IndexFlatL2: + index_path = get_index_path(base_path, cfg) + return faiss.read_index(str(index_path)) + + +def save_index(index: faiss.IndexFlatL2, base_path: Path, cfg: CacheConfig): + index_path = get_index_path(base_path, cfg) + + faiss.write_index(index, str(index_path)) + + with open(index_path.with_suffix(".json"), "w") as f: + json.dump({"index_path": str(index_path), "embedding_model": "sentence-transformers/all-MiniLM-L6-v2"}, f) + + +def build_semantic_index(data: Dataset, cfg: CacheConfig): + """ + Build a semantic index of the token sequences. 
+ """ + index_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') + index_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to("cuda") + + index_tokens = chunk_and_tokenize(data, index_tokenizer, max_seq_len=cfg.ctx_len, text_key=cfg.dataset_row) + index_tokens = index_tokens["input_ids"] + index_tokens = assert_type(Tensor, index_tokens) + + token_embeddings = index_model(index_tokens[:2].to("cuda")).last_hidden_state + + base_index = faiss.IndexFlatL2(token_embeddings.shape[-1]) + index = faiss.IndexIDMap(base_index) + + batch_size = 512 + dataloader = DataLoader(index_tokens, batch_size=batch_size) # type: ignore + + from tqdm import tqdm + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(dataloader)): + batch = batch.to("cuda") + token_embeddings = index_model(batch).last_hidden_state + sentence_embeddings = token_embeddings.mean(dim=1) + sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32) + + ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch)) + index.add_with_ids(sentence_embeddings, ids) + + return index + + +def build_or_load_index(data: Dataset, base_path: Path, cfg: CacheConfig): + index_path = get_index_path(base_path, cfg) + + if not index_path.exists(): + logger.info(f"Building semantic index for {cfg.dataset_repo} {cfg.dataset_split} seq_len={cfg.ctx_len}...") + index = build_semantic_index(data, cfg) + save_index(index, base_path, cfg) + return index + else: + return load_index(base_path, cfg) From aca072c6b57899e6098297d652aad058bb900605 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 14:21:01 +0000 Subject: [PATCH 072/132] Fail in a smarter way --- delphi/explainers/explainer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index c9873ecf..529683ec 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -25,14 +25,18 @@ def __call__(self, record: LatentRecord) -> ExplainerResult: async def explanation_loader(record: LatentRecord, explanation_dir: str) -> ExplainerResult: try: - async with aiofiles.open(f'{explanation_dir}/{record.feature}.txt', 'r') as f: + async with aiofiles.open(f'{explanation_dir}/{record.latent}.txt', 'r') as f: explanation = json.loads(await f.read()) return ExplainerResult( record=record, explanation=explanation ) except FileNotFoundError: - return None + print(f"No explanation found for {record.latent}") + return ExplainerResult( + record=record, + explanation="No explanation found" + ) From 784539174925ea5b93fc3220cc1a31d756c86d8b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 14:22:24 +0000 Subject: [PATCH 073/132] Fix constructor + update neighbours --- delphi/latents/constructors.py | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 79b20a8f..4e69545f 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -119,12 +119,11 @@ def constructor( activations = buffer_output.activations # Add activation examples to the record in place - print(activations.shape, reshaped_tokens.shape, ctx_indices.shape, index_within_ctx.shape, max_examples) - token_windows, activation_windows = pool_max_activation_windows(record, - activations=activations, + token_windows, activation_windows = pool_max_activation_windows(activations=activations, tokens=reshaped_tokens, 
ctx_indices=ctx_indices, index_within_ctx=index_within_ctx, + ctx_len=ctx_len, max_examples=max_examples) record.examples = prepare_examples(token_windows, activation_windows) @@ -182,6 +181,7 @@ def neighbour_non_activation_windows( number_examples = 0 available_features = all_data.features all_examples = [] + used_neighbours = [] for neighbour in record.neighbours: if number_examples >= n_not_active: break @@ -210,17 +210,19 @@ def neighbour_non_activation_windows( mask_ctx = torch.isin(ctx_indices, available_indices) available_ctx_indices = ctx_indices[mask_ctx] available_index_within_ctx = index_within_ctx[mask_ctx] - token_windows, activation_windows = pool_max_activation_windows(record, - activations=activations, + activations = activations[mask_ctx] + token_windows, activation_windows = pool_max_activation_windows(activations=activations, tokens=reshaped_tokens, ctx_indices=available_ctx_indices, index_within_ctx=available_index_within_ctx, - max_examples=n_examples_per_neighbour) + max_examples=n_examples_per_neighbour, + ctx_len=ctx_len) # use the first n_examples_per_neighbour examples, which will be the most active examples examples_used = len(token_windows) all_examples.append(prepare_examples(token_windows, torch.zeros_like(token_windows))) + used_neighbours.append(neighbour) number_examples += examples_used - + record.neighbours = used_neighbours if len(all_examples) == 0: print("No examples found") From d3b3d71d6866bf3aa88830d7e6d8f7e1ac708aec Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 14:22:43 +0000 Subject: [PATCH 074/132] Update fuzz and detection logic for neighbours --- delphi/scorers/classifier/detection.py | 12 ++++++------ delphi/scorers/classifier/fuzz.py | 10 +++++----- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 0861e20d..0036d1de 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -50,23 +50,23 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]: Prepare and shuffle a list of samples for classification. 
""" - # check if random_examples is a list of lists or a list of examples - if isinstance(record.random_examples[0], tuple): + # check if not_active is a list of lists or a list of examples + if isinstance(record.not_active[0], list): # Here we are using neighbours samples = [] - for i, (examples, neighbour) in enumerate(record.random_examples): + for i, examples in enumerate(record.not_active): samples.extend( examples_to_samples( examples, - distance=-neighbour.distance, + distance=-record.neighbours[i].distance, ground_truth=False, tokenizer=self.tokenizer, ) ) - elif isinstance(record.random_examples[0], Example): + elif isinstance(record.not_active[0], Example): # This is if we dont use neighbours samples = examples_to_samples( - record.random_examples, + record.not_active, distance=-1, ground_truth=False, tokenizer=self.tokenizer, diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index 80aafaa3..5fbe6019 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -81,23 +81,23 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]: all_examples.extend(examples_chunk) n_incorrect = self.mean_n_activations_ceil(all_examples) - if isinstance(record.random_examples[0], tuple): + if isinstance(record.not_active[0], list): # Here we are using neighbours samples = [] - for i, (examples, neighbour) in enumerate(record.random_examples): + for i, examples in enumerate(record.not_active): samples.extend( examples_to_samples( examples, - distance=-neighbour.distance, + distance=-record.neighbours[i].distance, ground_truth=False, n_incorrect=n_incorrect, **defaults, ) ) - elif isinstance(record.random_examples[0], Example): + elif isinstance(record.not_active[0], Example): # This is if we dont use neighbours samples = examples_to_samples( - record.random_examples, + record.not_active, distance=-1, ground_truth=False, n_incorrect=n_incorrect, From 873f185c740b1729910532ba7437cd6041a627c7 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 14:29:36 +0000 Subject: [PATCH 075/132] Fix neighbours non activating --- delphi/latents/constructors.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 4e69545f..701a8686 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -211,6 +211,9 @@ def neighbour_non_activation_windows( available_ctx_indices = ctx_indices[mask_ctx] available_index_within_ctx = index_within_ctx[mask_ctx] activations = activations[mask_ctx] + # If there are no available indices, skip this neighbour + if activations.numel() == 0: + continue token_windows, activation_windows = pool_max_activation_windows(activations=activations, tokens=reshaped_tokens, ctx_indices=available_ctx_indices, From 60623d07984567d03d3bc8c34e4ff7ebe54120d0 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 11 Feb 2025 15:28:25 +0000 Subject: [PATCH 076/132] Moving code to class --- delphi/explainers/contrastive_explainer.py | 50 +------------- delphi/explainers/default/default.py | 70 ------------------- delphi/explainers/explainer.py | 76 +++++++++++++++++++-- delphi/explainers/single_token_explainer.py | 70 ------------------- 4 files changed, 73 insertions(+), 193 deletions(-) diff --git a/delphi/explainers/contrastive_explainer.py b/delphi/explainers/contrastive_explainer.py index ddb85886..a2f68706 100644 --- a/delphi/explainers/contrastive_explainer.py +++ b/delphi/explainers/contrastive_explainer.py @@ -34,6 +34,7 @@ def 
__init__( self.generation_kwargs = generation_kwargs async def __call__(self, record): + # Need to change __call__ to use index messages = self._build_prompt(record.train) @@ -55,55 +56,6 @@ async def __call__(self, record): logger.error(f"Explanation parsing failed: {e}") return ExplainerResult(record=record, explanation="Explanation could not be parsed.") - def parse_explanation(self, text: str) -> str: - try: - match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) - return match.group(1).strip() if match else "Explanation could not be parsed." - except Exception as e: - logger.error(f"Explanation parsing regex failed: {e}") - raise - - def _highlight(self, index, example): - # result = f"Example {index}: " - result = f"" - threshold = example.max_activation * self.threshold - if self.tokenizer is not None: - str_toks = self.tokenizer.batch_decode(example.tokens) - example.str_toks = str_toks - else: - str_toks = example.tokens - example.str_toks = str_toks - activations = example.activations - - def check(i): - return activations[i] > threshold - - i = 0 - while i < len(str_toks): - if check(i): - # result += "<<" - - while i < len(str_toks) and check(i): - result += str_toks[i] - i += 1 - # result += ">>" - else: - # result += str_toks[i] - i += 1 - - return "".join(result) - - def _join_activations(self, example): - activations = [] - - for i, activation in enumerate(example.activations): - if activation > example.max_activation * self.threshold: - activations.append((example.str_toks[i], int(example.normalized_activations[i]))) - - acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) - - return "Activations: " + acts - def _build_prompt(self, examples): highlighted_examples = [] diff --git a/delphi/explainers/default/default.py b/delphi/explainers/default/default.py index 20725fa6..b1d9b2c8 100644 --- a/delphi/explainers/default/default.py +++ b/delphi/explainers/default/default.py @@ -29,76 +29,6 @@ def __init__( self.temperature = temperature self.generation_kwargs = generation_kwargs - async def __call__(self, record): - - messages = self._build_prompt(record.train) - - response = await self.client.generate( - messages, temperature=self.temperature, **self.generation_kwargs - ) - - try: - explanation = self.parse_explanation(response.text) - if self.verbose: - return ( - messages[-1]["content"], - response, - ExplainerResult(record=record, explanation=explanation), - ) - - return ExplainerResult(record=record, explanation=explanation) - except Exception as e: - logger.error(f"Explanation parsing failed: {e}") - return ExplainerResult(record=record, explanation="Explanation could not be parsed.") - - def parse_explanation(self, text: str) -> str: - try: - match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) - return match.group(1).strip() if match else "Explanation could not be parsed." 
- except Exception as e: - logger.error(f"Explanation parsing regex failed: {e}") - raise - - def _highlight(self, index, example): - result = f"Example {index}: " - - threshold = example.max_activation * self.threshold - if self.tokenizer is not None: - str_toks = self.tokenizer.batch_decode(example.tokens) - example.str_toks = str_toks - else: - str_toks = example.tokens - example.str_toks = str_toks - activations = example.activations - - def check(i): - return activations[i] > threshold - - i = 0 - while i < len(str_toks): - if check(i): - result += "<<" - - while i < len(str_toks) and check(i): - result += str_toks[i] - i += 1 - result += ">>" - else: - result += str_toks[i] - i += 1 - - return "".join(result) - - def _join_activations(self, example): - activations = [] - - for i, activation in enumerate(example.activations): - if activation > example.max_activation * self.threshold: - activations.append((example.str_toks[i], int(example.normalized_activations[i]))) - - acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) - - return "Activations: " + acts def _build_prompt(self, examples): highlighted_examples = [] diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 74ba1e4d..5263933d 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -3,11 +3,12 @@ import random from abc import ABC, abstractmethod from typing import NamedTuple +import re import aiofiles from ..latents.latents import LatentRecord - +from ..logger import logger class ExplainerResult(NamedTuple): record: LatentRecord @@ -18,9 +19,76 @@ class ExplainerResult(NamedTuple): class Explainer(ABC): - @abstractmethod - def __call__(self, record: LatentRecord) -> ExplainerResult: - pass + async def __call__(self, record): + + messages = self._build_prompt(record.train) + + response = await self.client.generate( + messages, temperature=self.temperature, **self.generation_kwargs + ) + + try: + explanation = self.parse_explanation(response.text) + if self.verbose: + return ( + messages[-1]["content"], + response, + ExplainerResult(record=record, explanation=explanation), + ) + + return ExplainerResult(record=record, explanation=explanation) + except Exception as e: + logger.error(f"Explanation parsing failed: {e}") + return ExplainerResult(record=record, explanation="Explanation could not be parsed.") + + def parse_explanation(self, text: str) -> str: + try: + match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) + return match.group(1).strip() if match else "Explanation could not be parsed." 
+ except Exception as e: + logger.error(f"Explanation parsing regex failed: {e}") + raise + + def _highlight(self, index, example): + # result = f"Example {index}: " + result = "" + threshold = example.max_activation * self.threshold + if self.tokenizer is not None: + str_toks = self.tokenizer.batch_decode(example.tokens) + example.str_toks = str_toks + else: + str_toks = example.tokens + example.str_toks = str_toks + activations = example.activations + + def check(i): + return activations[i] > threshold + + i = 0 + while i < len(str_toks): + if check(i): + # result += "<<" + + while i < len(str_toks) and check(i): + result += str_toks[i] + i += 1 + # result += ">>" + else: + # result += str_toks[i] + i += 1 + + return "".join(result) + + def _join_activations(self, example): + activations = [] + + for i, activation in enumerate(example.activations): + if activation > example.max_activation * self.threshold: + activations.append((example.str_toks[i], int(example.normalized_activations[i]))) + + acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) + + return "Activations: " + acts async def explanation_loader(record: LatentRecord, explanation_dir: str) -> ExplainerResult: diff --git a/delphi/explainers/single_token_explainer.py b/delphi/explainers/single_token_explainer.py index bde3b7e6..a335bb5d 100644 --- a/delphi/explainers/single_token_explainer.py +++ b/delphi/explainers/single_token_explainer.py @@ -29,76 +29,6 @@ def __init__( self.temperature = temperature self.generation_kwargs = generation_kwargs - async def __call__(self, record): - - messages = self._build_prompt(record.train) - - response = await self.client.generate( - messages, temperature=self.temperature, **self.generation_kwargs - ) - - try: - explanation = self.parse_explanation(response.text) - if self.verbose: - return ( - messages[-1]["content"], - response, - ExplainerResult(record=record, explanation=explanation), - ) - - return ExplainerResult(record=record, explanation=explanation) - except Exception as e: - logger.error(f"Explanation parsing failed: {e}") - return ExplainerResult(record=record, explanation="Explanation could not be parsed.") - - def parse_explanation(self, text: str) -> str: - try: - match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) - return match.group(1).strip() if match else "Explanation could not be parsed." 
- except Exception as e: - logger.error(f"Explanation parsing regex failed: {e}") - raise - - def _highlight(self, index, example): - # result = f"Example {index}: " - result = f"" - threshold = example.max_activation * self.threshold - if self.tokenizer is not None: - str_toks = self.tokenizer.batch_decode(example.tokens) - example.str_toks = str_toks - else: - str_toks = example.tokens - example.str_toks = str_toks - activations = example.activations - - def check(i): - return activations[i] > threshold - - i = 0 - while i < len(str_toks): - if check(i): - # result += "<<" - - while i < len(str_toks) and check(i): - result += str_toks[i] - i += 1 - # result += ">>" - else: - # result += str_toks[i] - i += 1 - - return "".join(result) - - def _join_activations(self, example): - activations = [] - - for i, activation in enumerate(example.activations): - if activation > example.max_activation * self.threshold: - activations.append((example.str_toks[i], int(example.normalized_activations[i]))) - - acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) - - return "Activations: " + acts def _build_prompt(self, examples): highlighted_examples = [] From 910f45b41ebca1f4d8827229e181a66b1ee7d0e1 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 12 Feb 2025 15:01:00 +0000 Subject: [PATCH 077/132] Mask name change --- delphi/latents/constructors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 701a8686..0b1b3230 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -202,10 +202,10 @@ def neighbour_non_activation_windows( mask = torch.zeros(n_windows, dtype=torch.bool) mask[unique_batch_pos_active] = True - # Get the indices where mask is True but active_indices is False - new_mask = mask & not_active_mask + # Get the indices where mask and not_active_mask are True + mask = mask & not_active_mask - available_indices = new_mask.nonzero().flatten() + available_indices = mask.nonzero().flatten() mask_ctx = torch.isin(ctx_indices, available_indices) available_ctx_indices = ctx_indices[mask_ctx] From 498a0bc221c47ce56b4f0443a113b1b52e40a166 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 12 Feb 2025 15:01:11 +0000 Subject: [PATCH 078/132] Remove debug pring --- delphi/latents/loader.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 92c8580e..eba08f36 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -321,10 +321,8 @@ async def __aiter__(self): for buffer in self.latent_dataset.buffers: for data in buffer: if data is not None: - record = await self._aprocess_latent(data) if record is not None: - print() yield record await asyncio.sleep(0) From c87873388f8532a80cba391471a6ba5f45a03d60 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 12 Feb 2025 15:01:58 +0000 Subject: [PATCH 079/132] Deal with the case where there's only active --- delphi/scorers/classifier/detection.py | 39 ++++++++++++----------- delphi/scorers/classifier/fuzz.py | 44 ++++++++++++++------------ 2 files changed, 44 insertions(+), 39 deletions(-) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index 0036d1de..a9411805 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -51,26 +51,29 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]: """ # check if not_active is a list of lists or a list of examples - if 
isinstance(record.not_active[0], list): - # Here we are using neighbours - samples = [] - for i, examples in enumerate(record.not_active): - samples.extend( - examples_to_samples( - examples, - distance=-record.neighbours[i].distance, - ground_truth=False, - tokenizer=self.tokenizer, + if len(record.not_active) > 0: + if isinstance(record.not_active[0], list): + # Here we are using neighbours + samples = [] + for i, examples in enumerate(record.not_active): + samples.extend( + examples_to_samples( + examples, + distance=-record.neighbours[i].distance, + ground_truth=False, + tokenizer=self.tokenizer, + ) ) + elif isinstance(record.not_active[0], Example): + # This is if we dont use neighbours + samples = examples_to_samples( + record.not_active, + distance=-1, + ground_truth=False, + tokenizer=self.tokenizer, ) - elif isinstance(record.not_active[0], Example): - # This is if we dont use neighbours - samples = examples_to_samples( - record.not_active, - distance=-1, - ground_truth=False, - tokenizer=self.tokenizer, - ) + else: + samples = [] for i, examples in enumerate(record.test): samples.extend( diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index 5fbe6019..1e7fdb9e 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -81,29 +81,31 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]: all_examples.extend(examples_chunk) n_incorrect = self.mean_n_activations_ceil(all_examples) - if isinstance(record.not_active[0], list): - # Here we are using neighbours - samples = [] - for i, examples in enumerate(record.not_active): - samples.extend( - examples_to_samples( - examples, - distance=-record.neighbours[i].distance, - ground_truth=False, - n_incorrect=n_incorrect, - **defaults, + if len(record.not_active) > 0: + if isinstance(record.not_active[0], list): + # Here we are using neighbours + samples = [] + for i, examples in enumerate(record.not_active): + samples.extend( + examples_to_samples( + examples, + distance=-record.neighbours[i].distance, + ground_truth=False, + n_incorrect=n_incorrect, + **defaults, + ) ) + elif isinstance(record.not_active[0], Example): + # This is if we dont use neighbours + samples = examples_to_samples( + record.not_active, + distance=-1, + ground_truth=False, + n_incorrect=n_incorrect, + **defaults, ) - elif isinstance(record.not_active[0], Example): - # This is if we dont use neighbours - samples = examples_to_samples( - record.not_active, - distance=-1, - ground_truth=False, - n_incorrect=n_incorrect, - **defaults, - ) - + else: + samples = [] for i, examples in enumerate(record.test): samples.extend( From ee2d98ba462df589f02c6466dbab9bff1191c33b Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 12 Feb 2025 16:36:55 +0000 Subject: [PATCH 080/132] typehint --- delphi/latents/neighbours.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 765a9dcd..214e8d3f 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -5,7 +5,7 @@ import numpy as np import torch from safetensors.numpy import load_file - +from delphi.latents.latent_dataset import LatentDataset class NeighbourCalculator: """ From 920d669e4ab897aedafc20f95375e9e2931c9f52 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 13 Feb 2025 09:37:14 +0000 Subject: [PATCH 081/132] Working on semantic_index --- delphi/semantic_index/index.py | 195 ++++++++++++++++++++++++++++++--- 1 file changed, 177 insertions(+), 18 
deletions(-) diff --git a/delphi/semantic_index/index.py b/delphi/semantic_index/index.py index 2ef11b5d..6dc9034f 100644 --- a/delphi/semantic_index/index.py +++ b/delphi/semantic_index/index.py @@ -1,19 +1,190 @@ import json from pathlib import Path +from typing import Optional -from torch.utils.data import DataLoader import faiss -import torch -from torch import Tensor -from sparsify.data import chunk_and_tokenize import numpy as np +import torch from datasets import Dataset -from transformers import AutoTokenizer, AutoModel +from safetensors.numpy import load_file +from sparsify.data import chunk_and_tokenize +from torch import Tensor +from torch.utils.data import DataLoader +from tqdm import tqdm +from transformers import AutoModel, AutoTokenizer from delphi.config import CacheConfig -from delphi.utils import assert_type +from delphi.latents import LatentDataset from delphi.logger import logger +from delphi.utils import assert_type +import time + + +#TODO: Think if this code should be moved to neighbours + +class AdversarialContexts: + """ + Class to compute the neighbours of selected latents using different methods: + - similarity: uses autoencoder weights + - correlation: uses pre-activation records and autoencoder + - co-occurrence: uses latent dataset statistics + """ + + def __init__( + self, + + latent_dataset: Optional['LatentDataset'] = None, + index_path: Optional[str] = None, + context_length: int = 32, + number_of_neighbours: int = 50, + ): + + """ + Initialize a NeighbourCalculator. + + Args: + latent_dataset (Optional[LatentDataset]): Dataset containing latent activations + autoencoder (Optional[Autoencoder]): The trained autoencoder model + residual_stream_record (Optional[ResidualStreamRecord]): Record of residual stream values + """ + self.latent_dataset = latent_dataset + self.context_length = context_length + self.index_path = index_path + # try to load index + self.index = self.load_index(index_path) + + + def _compute_similar_contexts(self) -> dict[int, list[int]]: + """ + Compute neighbour lists based on feature co-occurrence in the dataset. 
+ If you run out of memory try reducing the token_batch_size + """ + + print("Creating index") + paths = [] + for buffer in self.latent_dataset.buffers: + paths.append(buffer.tensor_path) + + all_locations = [] + all_activations = [] + tokens = None + for path in paths: + split_data = load_file(path) + first_feature = int(path.split("/")[-1].split("_")[0]) + locations = torch.tensor(split_data["locations"].astype(np.int64)) + locations[:,2] = locations[:,2] + first_feature + activations = torch.tensor(split_data["activations"].astype(np.float32)) + # compute number of tokens + all_locations.append(locations) + all_activations.append(activations) + if tokens is None: + tokens = split_data["tokens"] + tokens = tokens[:10000] + reshaped_tokens = tokens.reshape(-1, self.context_length) + strings = self.latent_dataset.tokenizer.batch_decode(reshaped_tokens, skip_special_tokens=True) + if self.index is None: + index = self._build_index(strings) + self.save_index(index) + else: + index = self.index + + locations = torch.cat(all_locations) + activations = torch.cat(all_activations) + + indices = torch.argsort(locations[:,2], stable=True) + locations = locations[indices] + activations = activations[indices] + unique_latents, counts = torch.unique_consecutive(locations[:,2], return_counts=True) + cache_ctx_len = torch.max(locations[:,1])+1 + + latents = unique_latents + split_locations = torch.split(locations, counts.tolist()) + split_activations = torch.split(activations, counts.tolist()) + latents = unique_latents + dict_of_adversarial_contexts = {} + for latent, locations, activations in tqdm(zip(latents, split_locations, split_activations)): + flat_indices = locations[:,0]*cache_ctx_len+locations[:,1] + ctx_indices = flat_indices // self.context_length + index_within_ctx = flat_indices % self.context_length + unique_ctx_indices, inverses, lengths = torch.unique_consecutive(ctx_indices, return_counts=True, return_inverse=True) + # Get the max activation magnitude within each context window + max_buffer = torch.segment_reduce(activations, 'max', lengths=lengths) + k = 100 + _, top_indices = torch.topk(max_buffer, k, sorted=True) + top_indices = unique_ctx_indices[top_indices] + # find the context in the index, that are not activating contexts but are the most similar to top_indices + activating_contexts_indices = ctx_indices.unique() + print(top_indices.shape) + print(len(top_indices)) + query_vectors = [] + for i in range(len(top_indices)): + print(top_indices[i].item()) + breakpoint() + query_vectors.append(index.reconstruct(top_indices[i].item())) + + print(query_vectors.shape) + distances, indices = index.search(query_vectors, len(top_indices)+len(activating_contexts_indices)+self.number_of_neighbours) + filtered_indices = [] + + for index, distance in zip(indices, distances): + valid = [id_ for id_ in index if id_ not in activating_contexts_indices] + filtered_indices.append(valid[:k]) + print(len(filtered_indices)) + dict_of_adversarial_contexts[latent] = filtered_indices + self.save_adversarial_contexts(dict_of_adversarial_contexts) + + + + def _build_index(self,strings: list[str]) -> faiss.IndexIDMap: + index_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') + index_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to("cuda") + + tokenized = index_tokenizer(strings, return_tensors="pt", padding=True,max_length=512,padding_side="right",truncation=True) + index_tokens = tokenized["input_ids"] + index_initializer = 
index_model(index_tokens[:2].to("cuda")).last_hidden_state + + base_index = faiss.IndexFlatL2(index_initializer.shape[-1]) + index = faiss.IndexIDMap(base_index) + + + batch_size = 512 + dataloader = DataLoader(index_tokens, batch_size=batch_size) # type: ignore + from tqdm import tqdm + with torch.no_grad(): + for batch_idx, batch in enumerate(tqdm(dataloader)): + batch = batch.to("cuda") + token_embeddings = index_model(batch).last_hidden_state + sentence_embeddings = token_embeddings.mean(dim=1) + sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32) + ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch)) + index.add_with_ids(sentence_embeddings, ids) + + return index + + def populate_neighbour_cache(self, methods: list[str]) -> None: + """ + Populate the neighbour cache with the computed neighbour lists + """ + for method in methods: + self._compute_neighbour_list(method) + + + def load_index(self, base_path: str) -> faiss.IndexFlatL2: + # check if index exists + index_path = base_path + "/index.faiss" + if not Path(index_path).exists(): + return None + return faiss.read_index(str(index_path)) + + def save_index(self,index: faiss.IndexFlatL2, ): + index_path = self.index_path + "/index.faiss" + + faiss.write_index(index, str(index_path)) + + def save_adversarial_contexts(self, dict_of_adversarial_contexts: dict[int, list[int]]): + with open(self.index_path + "adversarial_contexts.json", "w") as f: + json.dump(dict_of_adversarial_contexts, f) def get_neighbors_by_id(index: faiss.IndexIDMap, vector_id: int, k: int = 10): # First reconstruct the vector for the given ID @@ -32,18 +203,6 @@ def get_index_path(base_path: Path, cfg: CacheConfig): return base_path / f"{cfg.dataset_repo.replace('/', '_')}_{cfg.dataset_split}_{cfg.ctx_len}.idx" -def load_index(base_path: Path, cfg: CacheConfig) -> faiss.IndexFlatL2: - index_path = get_index_path(base_path, cfg) - return faiss.read_index(str(index_path)) - - -def save_index(index: faiss.IndexFlatL2, base_path: Path, cfg: CacheConfig): - index_path = get_index_path(base_path, cfg) - - faiss.write_index(index, str(index_path)) - - with open(index_path.with_suffix(".json"), "w") as f: - json.dump({"index_path": str(index_path), "embedding_model": "sentence-transformers/all-MiniLM-L6-v2"}, f) def build_semantic_index(data: Dataset, cfg: CacheConfig): From 8730b6c8f7990ad39b55f9b88811d1a7ad1caffd Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:16:28 +0000 Subject: [PATCH 082/132] New constructor function --- delphi/__main__.py | 36 ++------ delphi/latents/constructors.py | 151 +++++++++++++++++++-------------- 2 files changed, 98 insertions(+), 89 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index bd271328..e7e7e95e 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -26,7 +26,7 @@ from delphi.config import CacheConfig, ExperimentConfig, LatentConfig, RunConfig from delphi.explainers import DefaultExplainer from delphi.latents import LatentCache, LatentDataset -from delphi.latents.constructors import default_constructor +from delphi.latents.constructors import constructor from delphi.latents.samplers import sample from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper @@ -63,7 +63,6 @@ async def process_cache( latent_cfg: LatentConfig, run_cfg: RunConfig, experiment_cfg: ExperimentConfig, - base_path: Path, latents_path: Path, explanations_path: Path, scores_path: Path, @@ -91,10 +90,11 @@ async def 
process_cache( } # The latent range to explain latent_dict = cast(dict[str, int | Tensor], latent_dict) - constructor = partial( - default_constructor, + example_constructor = partial( + constructor, token_loader=None, n_not_active=experiment_cfg.n_non_activating, + constructor_type="random", ctx_len=experiment_cfg.example_ctx_len, max_examples=latent_cfg.max_examples, ) @@ -106,13 +106,10 @@ async def process_cache( modules=hookpoints, latents=latent_dict, tokenizer=tokenizer, - constructor=constructor, + constructor=example_constructor, sampler=sampler, ) - if run_cfg.semantic_index: - index = load_index(base_path, cache_cfg) - if run_cfg.explainer_provider == "offline": client = Offline( run_cfg.explainer_model, @@ -147,15 +144,8 @@ def explainer_postprocess(result): f.write(orjson.dumps(result.explanation)) return result - if run_cfg.semantic_index: - explainer = ContrastiveExplainer( - client, - tokenizer=dataset.tokenizer, - index=index, - threshold=0.3, - ) - else: - explainer = DefaultExplainer( + explainer_pipe = process_wrapper( + DefaultExplainer( client, tokenizer=dataset.tokenizer, threshold=0.3, @@ -212,13 +202,12 @@ def scorer_postprocess(result, score_dir): await pipeline.run(run_cfg.pipeline_num_proc) -def prepare_data( +def populate_cache( run_cfg: RunConfig, cfg: CacheConfig, model: PreTrainedModel, hookpoint_to_sparse_encode: dict[str, Callable], latents_path: Path, - base_path: Path, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, ): """ @@ -229,12 +218,7 @@ def prepare_data( data = load_dataset( cfg.dataset_repo, name=cfg.dataset_name, split=cfg.dataset_split ) - data = assert_type(Dataset, data) data = data.shuffle(run_cfg.seed) - - if run_cfg.semantic_index: - build_or_load_index(data, base_path, cfg) - data = chunk_and_tokenize( data, tokenizer, max_seq_len=cfg.ctx_len, text_key=cfg.dataset_column ) @@ -299,13 +283,12 @@ async def run( not glob(str(latents_path / ".*")) + glob(str(latents_path / "*")) or "cache" in run_cfg.overwrite ): - prepare_data( + populate_cache( run_cfg, cache_cfg, model, hookpoint_to_sparse_encode, latents_path, - base_path, tokenizer, ) else: @@ -321,7 +304,6 @@ async def run( latent_cfg, run_cfg, experiment_cfg, - base_path, latents_path, explanations_path, scores_path, diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 0b1b3230..b2dedd93 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -1,16 +1,18 @@ -from typing import Callable, Optional +from typing import Callable, Literal, Optional import torch from torchtyping import TensorType + from .latents import LatentRecord, prepare_examples -from .loader import BufferOutput, AllData +from .loader import ActivationData + def _top_k_pools( - max_buffer: TensorType["batch"], - split_activations: list[TensorType["activations"]], - buffer_tokens: TensorType["batch", "ctx_len"], - max_examples: int - ): + max_buffer: TensorType["batch"], + split_activations: list[TensorType["activations"]], + buffer_tokens: TensorType["batch", "ctx_len"], + max_examples: int, +): """ Get the top k activation pools. @@ -22,7 +24,8 @@ def _top_k_pools( max_examples (int): The maximum number of examples. Returns: - Tuple[TensorType["examples", "ctx_len"], TensorType["examples", "ctx_len"]]: The token windows and activation windows. + Tuple[TensorType["examples", "ctx_len"], TensorType["examples", "ctx_len"]]: + The token windows and activation windows. 
""" k = min(max_examples, len(max_buffer)) top_values, top_indices = torch.topk(max_buffer, k, sorted=True) @@ -52,19 +55,22 @@ def pool_max_activation_windows( ctx_len (int): The context length. max_examples (int): The maximum number of examples. """ - # unique_ctx_indices: array of distinct context window indices in order of first appearance. i.e. sequential integers from 0 to batch_size * cache_token_length // ctx_len - # inverses: maps each activation back to its index in unique_ctx_indices (can be used to dereference the context window idx of each activation) + # unique_ctx_indices: array of distinct context window indices in order of first + # appearance. sequential integers from 0 to batch_size * cache_token_length//ctx_len + # inverses: maps each activation back to its index in unique_ctx_indices + # (can be used to dereference the context window idx of each activation) # lengths: the number of activations per unique context window index - unique_ctx_indices, inverses, lengths = torch.unique_consecutive(ctx_indices, return_counts=True, return_inverse=True) - + unique_ctx_indices, inverses, lengths = torch.unique_consecutive( + ctx_indices, return_counts=True, return_inverse=True + ) + # Get the max activation magnitude within each context window - max_buffer = torch.segment_reduce(activations, 'max', lengths=lengths) + max_buffer = torch.segment_reduce(activations, "max", lengths=lengths) # Deduplicate the context windows - new_tensor= torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype) + new_tensor = torch.zeros(len(unique_ctx_indices), ctx_len, dtype=activations.dtype) new_tensor[inverses, index_within_ctx] = activations - tokens = tokens[unique_ctx_indices] token_windows, activation_windows = _top_k_pools( @@ -73,17 +79,18 @@ def pool_max_activation_windows( return token_windows, activation_windows + def constructor( record: LatentRecord, - buffer_output: BufferOutput, - all_data: AllData, + activation_data: ActivationData, n_not_active: int, max_examples: int, ctx_len: int, - constructor_type: str = "random", - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None = None + constructor_type: Literal["random", "neighbour"], + token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None = None, + all_data: Optional[ActivationData] = None, ): - tokens = all_data.tokens + tokens = activation_data.tokens if tokens is None: if token_loader is None: raise ValueError("Either tokens or token_loader must be provided") @@ -100,49 +107,54 @@ def constructor( "(assuming `dataset` is a `LatentDataset` instance)." 
) - batch_size = tokens.shape[0] + tokens.shape[0] cache_token_length = tokens.shape[1] # Get all positions where the latent is active - flat_indices = buffer_output.locations[:, 0] * cache_token_length + buffer_output.locations[:, 1] + flat_indices = ( + activation_data.locations[:, 0] * cache_token_length + + activation_data.locations[:, 1] + ) ctx_indices = flat_indices // ctx_len index_within_ctx = flat_indices % ctx_len reshaped_tokens = tokens.reshape(-1, ctx_len) n_windows = reshaped_tokens.shape[0] - + unique_batch_pos = ctx_indices.unique() mask = torch.ones(n_windows, dtype=torch.bool) mask[unique_batch_pos] = False # Indices where the latent is active active_indices = mask.nonzero(as_tuple=False).squeeze() - activations = buffer_output.activations - - # Add activation examples to the record in place - token_windows, activation_windows = pool_max_activation_windows(activations=activations, - tokens=reshaped_tokens, - ctx_indices=ctx_indices, - index_within_ctx=index_within_ctx, - ctx_len=ctx_len, - max_examples=max_examples) - record.examples = prepare_examples(token_windows, activation_windows) + activations = activation_data.activations + + # Add activation examples to the record in place + token_windows, act_windows = pool_max_activation_windows( + activations=activations, + tokens=reshaped_tokens, + ctx_indices=ctx_indices, + index_within_ctx=index_within_ctx, + ctx_len=ctx_len, + max_examples=max_examples, + ) + record.examples = prepare_examples(token_windows, act_windows) if constructor_type == "random": # Add random non-activating examples to the record in place random_non_activating_windows( - record, - available_indices=active_indices, - reshaped_tokens=reshaped_tokens, - n_not_active=n_not_active, + record, + available_indices=active_indices, + reshaped_tokens=reshaped_tokens, + n_not_active=n_not_active, ) elif constructor_type == "neighbour": neighbour_non_activation_windows( - record, - not_active_mask=mask, - tokens=tokens, - all_data=all_data, - ctx_len=ctx_len, - n_not_active=n_not_active, + record, + not_active_mask=mask, + tokens=tokens, + all_data=all_data, + ctx_len=ctx_len, + n_not_active=n_not_active, ) @@ -150,7 +162,7 @@ def neighbour_non_activation_windows( record: LatentRecord, not_active_mask: TensorType["n_windows"], tokens: TensorType["batch", "seq"], - all_data: AllData, + all_data: ActivationData, ctx_len: int, n_not_active: int, ): @@ -169,13 +181,16 @@ def neighbour_non_activation_windows( if n_not_active == 0: record.not_active = [] return - - assert record.neighbours is not None, "Neighbours are not set, add them via a transform" + + assert ( + record.neighbours is not None + ), "Neighbours are not set, add them via a transform" cache_token_length = tokens.shape[1] reshaped_tokens = tokens.reshape(-1, ctx_len) n_windows = reshaped_tokens.shape[0] - # TODO: For now we use at most 10 examples per neighbour, we may want to allow a variable number of examples per neighbour + # TODO: For now we use at most 10 examples per neighbour, we may want to allow a + # variable number of examples per neighbour n_examples_per_neighbour = 10 number_examples = 0 @@ -214,15 +229,20 @@ def neighbour_non_activation_windows( # If there are no available indices, skip this neighbour if activations.numel() == 0: continue - token_windows, activation_windows = pool_max_activation_windows(activations=activations, - tokens=reshaped_tokens, - ctx_indices=available_ctx_indices, - index_within_ctx=available_index_within_ctx, - max_examples=n_examples_per_neighbour, - 
ctx_len=ctx_len) - # use the first n_examples_per_neighbour examples, which will be the most active examples + token_windows, act_windows = pool_max_activation_windows( + activations=activations, + tokens=reshaped_tokens, + ctx_indices=available_ctx_indices, + index_within_ctx=available_index_within_ctx, + max_examples=n_examples_per_neighbour, + ctx_len=ctx_len, + ) + # use the first n_examples_per_neighbour examples, + # which will be the most active examples examples_used = len(token_windows) - all_examples.append(prepare_examples(token_windows, torch.zeros_like(token_windows))) + all_examples.append( + prepare_examples(token_windows, torch.zeros_like(token_windows)) + ) used_neighbours.append(neighbour) number_examples += examples_used record.neighbours = used_neighbours @@ -231,6 +251,7 @@ def neighbour_non_activation_windows( record.not_active = all_examples + def random_non_activating_windows( record: LatentRecord, available_indices: TensorType["n_windows"], @@ -242,26 +263,32 @@ def random_non_activating_windows( Args: record (LatentRecord): The latent record to update. - available_indices (TensorType["n_windows"]): The indices of the windows where the latent is not active. - reshaped_tokens (TensorType["n_windows", "ctx_len"]): The tokens reshaped to the context length. + available_indices (TensorType["n_windows"]): The indices of the windows where + the latent is not active. + reshaped_tokens (TensorType["n_windows", "ctx_len"]): The tokens reshaped + to the context length. n_not_active (int): The number of non activating examples to generate. """ torch.manual_seed(22) if n_not_active == 0: record.not_active = [] return - - # If this happens it means that the latent is active in every window, so it is a bad latent + + # If this happens it means that the latent is active in every window, + # so it is a bad latent if available_indices.numel() < n_not_active: print("No available randomly sampled non-activating sequences") record.not_active = [] return else: - selected_indices = available_indices[torch.randint(0, available_indices.shape[0], size=(n_not_active,))] - + random_indices = torch.randint( + 0, available_indices.shape[0], size=(n_not_active,) + ) + selected_indices = available_indices[random_indices] + toks = reshaped_tokens[selected_indices] record.not_active = prepare_examples( toks, torch.zeros_like(toks), - ) \ No newline at end of file + ) From a89f1b77dc08a28351f2eebb5b4e5027002b4cb8 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:16:58 +0000 Subject: [PATCH 083/132] Correct format explainer --- delphi/explainers/explainer.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index f441a909..6d0d1b38 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -29,12 +29,10 @@ async def __call__(self, record): try: explanation = self.parse_explanation(response.text) if self.verbose: - return ( - messages[-1]["content"], - response, - ExplainerResult(record=record, explanation=explanation), - ) - + logger.info(f"Explanation: {explanation}") + logger.info(f"Messages: {messages[-1]['content']}") + logger.info(f"Response: {response}") + return ExplainerResult(record=record, explanation=explanation) except Exception as e: logger.error(f"Explanation parsing failed: {e}") From 672c1143f7185b0dfe647b0e18e67d99de1a79f3 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:20:14 +0000 Subject: [PATCH 084/132] Remove dataset loader --- 
delphi/latents/__init__.py | 5 +- delphi/latents/neighbours.py | 223 ++++++++++++++++++++++------------- 2 files changed, 140 insertions(+), 88 deletions(-) diff --git a/delphi/latents/__init__.py b/delphi/latents/__init__.py index 5db9c716..760e56a1 100644 --- a/delphi/latents/__init__.py +++ b/delphi/latents/__init__.py @@ -1,12 +1,12 @@ from .cache import LatentCache from .constructors import ( constructor, + neighbour_non_activation_windows, pool_max_activation_windows, random_non_activating_windows, - neighbour_non_activation_windows, ) from .latents import Example, Latent, LatentRecord -from .loader import LatentDataset, LatentLoader +from .loader import LatentDataset from .samplers import sample from .stats import unigram @@ -22,5 +22,4 @@ "constructor", "sample", "unigram", - "LatentLoader" ] diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 214e8d3f..d0821fbe 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -1,11 +1,13 @@ import json -from typing import Optional -from tqdm import tqdm +from typing import Optional import numpy as np import torch from safetensors.numpy import load_file -from delphi.latents.latent_dataset import LatentDataset +from tqdm import tqdm + +from delphi.latents import LatentDataset + class NeighbourCalculator: """ @@ -17,34 +19,33 @@ class NeighbourCalculator: def __init__( self, - - latent_dataset: Optional['LatentDataset'] = None, + latent_dataset: Optional["LatentDataset"] = None, autoencoder: Optional["Autoencoder"] = None, - residual_stream_record: Optional['ResidualStreamRecord'] = None, + residual_stream_record: Optional["ResidualStreamRecord"] = None, number_of_neighbours: int = 10, - neighbour_cache: Optional[dict[str, dict[int, list[tuple[int,float]]]]] = None, + neighbour_cache: Optional[dict[str, dict[int, list[tuple[int, float]]]]] = None, ): - """ Initialize a NeighbourCalculator. 
Args: - latent_dataset (Optional[LatentDataset]): Dataset containing latent activations + latent_dataset (Optional[LatentDataset]): Dataset + containing latent activations autoencoder (Optional[Autoencoder]): The trained autoencoder model - residual_stream_record (Optional[ResidualStreamRecord]): Record of residual stream values + residual_stream_record (Optional[ResidualStreamRecord]): Record of + residual stream values """ self.latent_dataset = latent_dataset self.autoencoder = autoencoder self.residual_stream_record = residual_stream_record self.number_of_neighbours = number_of_neighbours - # load the neighbour cache from the path + # load the neighbour cache from the path if neighbour_cache is not None: self.neighbour_cache = neighbour_cache else: # dictionary to cache computed neighbour lists self.neighbour_cache: dict[str, dict[int, list[int]]] = {} - def _compute_neighbour_list(self, method: str) -> None: """ @@ -53,27 +54,43 @@ def _compute_neighbour_list(self, method: str) -> None: Args: method (str): One of 'similarity', 'correlation', or 'co-occurrence' """ - if method == 'similarity_encoder': + if method == "similarity_encoder": if self.autoencoder is None: - raise ValueError("Autoencoder is required for similarity-based neighbours") - self.neighbour_cache[method] = self._compute_similarity_neighbours("encoder") - elif method == 'similarity_decoder': + raise ValueError( + "Autoencoder is required for similarity-based neighbours" + ) + self.neighbour_cache[method] = self._compute_similarity_neighbours( + "encoder" + ) + elif method == "similarity_decoder": if self.autoencoder is None: - raise ValueError("Autoencoder is required for similarity-based neighbours") - self.neighbour_cache[method] = self._compute_similarity_neighbours("decoder") - elif method == 'correlation': + raise ValueError( + "Autoencoder is required for similarity-based neighbours" + ) + self.neighbour_cache[method] = self._compute_similarity_neighbours( + "decoder" + ) + elif method == "correlation": if self.autoencoder is None or self.residual_stream_record is None: - raise ValueError("Autoencoder and residual stream record are required for correlation-based neighbours") + raise ValueError( + "Autoencoder and residual stream record are required " + "for correlation-based neighbours" + ) self.neighbour_cache[method] = self._compute_correlation_neighbours() - - elif method == 'co-occurrence': + + elif method == "co-occurrence": if self.latent_dataset is None: - raise ValueError("Latent dataset is required for co-occurrence-based neighbours") + raise ValueError( + "Latent dataset is required for co-occurrence-based neighbours" + ) self.neighbour_cache[method] = self._compute_cooccurrence_neighbours() - + else: - raise ValueError(f"Unknown method: {method}. Use 'similarity', 'correlation', or 'co-occurrence'") - + raise ValueError( + f"Unknown method: {method}. Use 'similarity', 'correlation', " + "or 'co-occurrence'" + ) + def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: """ Compute neighbour lists based on weight similarity in the autoencoder. 
@@ -82,11 +99,15 @@ def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: # We use the encoder vectors to compute the similarity between latents if method == "encoder": encoder = self.autoencoder.encoder.cuda() - weight_matrix_normalized = encoder.weight.data / encoder.weight.data.norm(dim=1, keepdim=True) - + weight_matrix_normalized = encoder.weight.data / encoder.weight.data.norm( + dim=1, keepdim=True + ) + elif method == "decoder": decoder = self.autoencoder.W_dec.cuda() - weight_matrix_normalized = decoder.data / decoder.data.norm(dim=1, keepdim=True) + weight_matrix_normalized = decoder.data / decoder.data.norm( + dim=1, keepdim=True + ) else: raise ValueError(f"Unknown method: {method}. Use 'encoder' or 'decoder'") @@ -99,115 +120,146 @@ def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: neighbour_lists = {} while not done: try: - - for start in tqdm(range(0,number_latents,batch_size)): - rows = wT[start:start+batch_size] + for start in tqdm(range(0, number_latents, batch_size)): + rows = wT[start : start + batch_size] similarity_matrix = weight_matrix_normalized @ rows - indices,values = torch.topk(similarity_matrix, self.number_of_neighbours+1, dim=1) - neighbour_lists.update({i+start: list(zip(indices[i].tolist()[1:],values[i].tolist()[1:])) for i in range(len(indices))}) + indices, values = torch.topk( + similarity_matrix, self.number_of_neighbours + 1, dim=1 + ) + neighbour_lists.update( + { + i + + start: list( + zip(indices[i].tolist()[1:], values[i].tolist()[1:]) + ) + for i in range(len(indices)) + } + ) del similarity_matrix torch.cuda.empty_cache() done = True except Exception: batch_size = batch_size // 2 if batch_size < 2: - raise ValueError("Batch size is too small to compute similarity matrix. You don't have enough memory.") - + raise ValueError( + "Batch size is too small to compute similarity matrix. " + "You don't have enough memory." + ) return neighbour_lists - def _compute_correlation_neighbours(self) -> dict[int, list[int]]: """ Compute neighbour lists based on activation correlation patterns. 
""" print("Computing correlation neighbours") - # the activation_matrix has the shape (number_of_samples,hidden_dimension) - - activations = torch.tensor(load_file(self.residual_stream_record+".safetensors")["activations"]) + # the activation_matrix has the shape (number_of_samples,hidden_dimension) + + activations = torch.tensor( + load_file(self.residual_stream_record + ".safetensors")["activations"] + ) - estimator = CovarianceEstimator(activations.shape[1]) # batch the activations batch_size = 10000 - for i in tqdm(range(0,activations.shape[0],batch_size)): - estimator.update(activations[i:i+batch_size]) + for i in tqdm(range(0, activations.shape[0], batch_size)): + estimator.update(activations[i : i + batch_size]) covariance_matrix = estimator.cov().cuda().half() # load the encoder encoder_matrix = self.autoencoder.encoder.weight.cuda().half() - covariance_between_latents = torch.zeros((encoder_matrix.shape[0],encoder_matrix.shape[0]),device="cpu") + covariance_between_latents = torch.zeros( + (encoder_matrix.shape[0], encoder_matrix.shape[0]), device="cpu" + ) # do batches of latents batch_size = 1024 - for start in tqdm(range(0,encoder_matrix.shape[0],batch_size)): - end = min(encoder_matrix.shape[0],start+batch_size) + for start in tqdm(range(0, encoder_matrix.shape[0], batch_size)): + end = min(encoder_matrix.shape[0], start + batch_size) encoder_rows = encoder_matrix[start:end] - + correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T covariance_between_latents[start:end] = correlation.cpu() - # the correlation is then the covariance divided by the product of the standard deviations + # the correlation is then the covariance divided + # by the product of the standard deviations diagonal_covariance = torch.diag(covariance_between_latents) - product_of_std = torch.sqrt(torch.outer(diagonal_covariance,diagonal_covariance)+1e-6) + product_of_std = torch.sqrt( + torch.outer(diagonal_covariance, diagonal_covariance) + 1e-6 + ) correlation_matrix = covariance_between_latents / product_of_std # get the indices of the top k neighbours for each feature - indices,values = torch.topk(correlation_matrix, self.number_of_neighbours+1, dim=1) + indices, values = torch.topk( + correlation_matrix, self.number_of_neighbours + 1, dim=1 + ) # return the neighbour lists - return {i: list(zip(indices[i].tolist()[1:],values[i].tolist()[1:])) for i in range(len(indices))} + return { + i: list(zip(indices[i].tolist()[1:], values[i].tolist()[1:])) + for i in range(len(indices)) + } - def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: """ Compute neighbour lists based on feature co-occurrence in the dataset. 
If you run out of memory try reducing the token_batch_size + Code adapted from https://github.com/taha-yassine/SAE-features/blob/main/cooccurrences/compute.py """ - - import cupy as cp import cupyx.scipy.sparse as cusparse + print("Computing co-occurrence neighbours") paths = [] for buffer in self.latent_dataset.buffers: paths.append(buffer.tensor_path) - + all_locations = [] for path in paths: split_data = load_file(path) first_feature = int(path.split("/")[-1].split("_")[0]) locations = torch.tensor(split_data["locations"].astype(np.int64)) - locations[:,2] = locations[:,2] + first_feature - # compute number of tokens + locations[:, 2] = locations[:, 2] + first_feature + # compute number of tokens all_locations.append(locations) - + # concatenate the locations and activations locations = torch.cat(all_locations) - n_latents = int(torch.max(locations[:,2])) + 1 + n_latents = int(torch.max(locations[:, 2])) + 1 - # 1. Get unique values of first 2 dims (i.e. absolute token index) and their counts - # Trick is to use Cantor pairing function to have a bijective mapping between (batch_id, ctx_pos) and a unique 1D index + # 1. Get unique values of first 2 dims (i.e. absolute token index) + # and their counts + # Trick is to use Cantor pairing function to have a bijective mapping between + # (batch_id, ctx_pos) and a unique 1D index # Faster than running `torch.unique_consecutive` on the first 2 dims - idx_cantor = (locations[:,0] + locations[:,1]) * (locations[:,0] + locations[:,1] + 1) // 2 + locations[:,1] - unique_idx, idx_counts = torch.unique_consecutive(idx_cantor, return_counts=True) + idx_cantor = (locations[:, 0] + locations[:, 1]) * ( + locations[:, 0] + locations[:, 1] + 1 + ) // 2 + locations[:, 1] + unique_idx, idx_counts = torch.unique_consecutive( + idx_cantor, return_counts=True + ) n_tokens = len(unique_idx) - # 2. The Cantor indices are not consecutive, so we create sorted ones from the counts - locations_flat = torch.repeat_interleave(torch.arange(n_tokens, device=locations.device), idx_counts) - del idx_cantor,unique_idx,idx_counts + # 2. The Cantor indices are not consecutive, + # so we create sorted ones from the counts + locations_flat = torch.repeat_interleave( + torch.arange(n_tokens, device=locations.device), idx_counts + ) + del idx_cantor, unique_idx, idx_counts rows = cp.asarray(locations[:, 2]) cols = cp.asarray(locations_flat) data = cp.ones(len(rows)) - sparse_matrix = cusparse.coo_matrix((data, (rows, cols)), shape=(n_latents, n_tokens)) + sparse_matrix = cusparse.coo_matrix( + (data, (rows, cols)), shape=(n_latents, n_tokens) + ) token_batch_size = 100_000 cooc_matrix = cp.zeros((n_latents, n_latents), dtype=cp.float32) - + sparse_matrix_csc = sparse_matrix.tocsc() for start in tqdm(range(0, n_tokens, token_batch_size)): end = min(n_tokens, start + token_batch_size) @@ -219,29 +271,32 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: # Free temporary variables. 
del rows, cols, data, sparse_matrix, sparse_matrix_csc - - + # Compute Jaccard similarity def compute_jaccard(cooc_matrix): self_occurrence = cooc_matrix.diagonal() - jaccard_matrix = cooc_matrix / (self_occurrence[:, None] + self_occurrence - cooc_matrix) + jaccard_matrix = cooc_matrix / ( + self_occurrence[:, None] + self_occurrence - cooc_matrix + ) # remove the diagonal and keep the upper triangle return jaccard_matrix # Compute Jaccard similarity matrix jaccard_matrix = compute_jaccard(cooc_matrix) - jaccard_torch = torch.as_tensor(cp.asnumpy(jaccard_matrix)) - # get the indices of the top k neighbours for each feature - top_k_indices,values = torch.topk(jaccard_torch, self.number_of_neighbours+1, dim=1) - del jaccard_matrix,cooc_matrix,jaccard_torch + # get the indices of the top k neighbours for each feature + top_k_indices, values = torch.topk( + jaccard_torch, self.number_of_neighbours + 1, dim=1 + ) + del jaccard_matrix, cooc_matrix, jaccard_torch torch.cuda.empty_cache() - - # return the neighbour lists - return {i: list(zip(top_k_indices[i].tolist()[1:],values[i].tolist()[1:])) for i in range(len(top_k_indices))} - + # return the neighbour lists + return { + i: list(zip(top_k_indices[i].tolist()[1:], values[i].tolist()[1:])) + for i in range(len(top_k_indices)) + } def populate_neighbour_cache(self, methods: list[str]) -> None: """ @@ -250,25 +305,23 @@ def populate_neighbour_cache(self, methods: list[str]) -> None: for method in methods: self._compute_neighbour_list(method) - def save_neighbour_cache(self, path: str) -> None: """ Save the neighbour cache to the path as a json file """ - with open(path, 'w') as f: + with open(path, "w") as f: json.dump(self.neighbour_cache, f) def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: """ Load the neighbour cache from the path as a json file """ - with open(path, 'r') as f: + with open(path, "r") as f: return json.load(f) - class CovarianceEstimator: - def __init__(self, n_latents, *, device = None): + def __init__(self, n_latents, *, device=None): self.mean = torch.zeros(n_latents, device=device) self.cov_ = torch.zeros(n_latents, n_latents, device=device) self.n = 0 @@ -288,4 +341,4 @@ def update(self, x: torch.Tensor): def cov(self): """Return the estimated covariance matrix.""" - return self.cov_ / self.n \ No newline at end of file + return self.cov_ / self.n From edeaf2a60256d9eb815adf1e725c5778431f5c76 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:20:58 +0000 Subject: [PATCH 085/132] Create all data tensor, rename dataclasses --- delphi/latents/loader.py | 119 ++++++++++++++++++++++++++------------- 1 file changed, 81 insertions(+), 38 deletions(-) diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 47b5b03b..8ed159e3 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -19,22 +19,34 @@ from .latents import Latent, LatentRecord -class BufferOutput(NamedTuple): +class ActivationData(NamedTuple): """ - Represents the output of a TensorBuffer. + Represents the activation data for a latent. """ - latent: Latent - """The latent associated with this output.""" - locations: TensorType["locations", 2] """Tensor of latent locations.""" activations: TensorType["locations"] """Tensor of latent activations.""" - tokens: TensorType["tokens"] - """Tensor of all tokens.""" + tokens: Optional[TensorType["tokens"]] = None + """Tensor of tokens.""" + + +class LatentData(NamedTuple): + """ + Represents the output of a TensorBuffer. 
+ """ + + latent: Latent + """The latent associated with this output.""" + + module: str + """The module associated with this output.""" + + activation_data: ActivationData + """The activation data for this latent.""" @dataclass @@ -55,6 +67,9 @@ class TensorBuffer: min_examples: int = 120 """Minimum number of examples required. Defaults to 120.""" + tokens: Optional[TensorType["tokens"]] = None + """Tensor of tokens.""" + def __iter__(self): """ Iterate over the buffer, yielding BufferOutput objects. @@ -63,7 +78,7 @@ def __iter__(self): Union[BufferOutput, None]: BufferOutput if enough examples, None otherwise. """ - latents, split_locations, split_activations, tokens = self.load() + latents, split_locations, split_activations = self.load_data_per_latent() for i in range(len(latents)): latent_locations = split_locations[i] @@ -71,13 +86,28 @@ def __iter__(self): if len(latent_locations) < self.min_examples: yield None else: - yield BufferOutput( + yield LatentData( Latent(self.module_path, int(latents[i].item())), - latent_locations, - latent_activations, - tokens, + self.module_path, + ActivationData(latent_locations, latent_activations, self.tokens), ) + def load_data_per_latent(self): + locations, activations, tokens = self.load() + if tokens is not None: + self.tokens = tokens + indices = torch.argsort(locations[:, 2], stable=True) + activations = activations[indices] + locations = locations[indices] + unique_latents, counts = torch.unique_consecutive( + locations[:, 2], return_counts=True + ) + latents = unique_latents + split_locations = torch.split(locations, counts.tolist()) + split_activations = torch.split(activations, counts.tolist()) + + return latents, split_locations, split_activations + def load(self): split_data = load_file(self.path) first_latent = int(self.path.split("/")[-1].split("_")[0]) @@ -95,23 +125,7 @@ def load(self): locations = locations[wanted_locations] activations = activations[wanted_locations] - indices = torch.argsort(locations[:, 2], stable=True) - activations = activations[indices] - locations = locations[indices] - unique_latents, counts = torch.unique_consecutive( - locations[:, 2], return_counts=True - ) - latents = unique_latents - split_locations = torch.split(locations, counts.tolist()) - split_activations = torch.split(activations, counts.tolist()) - - return latents, split_locations, split_activations, tokens - - def reset(self): - """Reset the buffer state.""" - self.start = 0 - self.activations = None - self.locations = None + return locations, activations, tokens class LatentDataset: @@ -140,7 +154,8 @@ def __init__( latents: Dictionary of latents per module. """ self.cfg = cfg - self.buffers = [] + self.buffers: list[TensorBuffer] = [] + self.all_data: dict[str, ActivationData | None] = {} if latents is None: self._build(raw_dir, modules) @@ -159,6 +174,11 @@ def __init__( self.sampler = sampler self.transform = transform + # TODO: is it possible to do this without loading all data? + if self.constructor is not None: + if self.constructor.keywords["constructor_type"] == "neighbour": + self.all_data = self._load_all_data(raw_dir, modules) + def load_tokens(self): """ Load tokenized data for the dataset. 
@@ -208,6 +228,7 @@ def _build(self, raw_dir: str, modules: Optional[list[str]] = None): self.buffers.append( TensorBuffer(path, module, min_examples=self.cfg.min_examples) ) + self.all_data[module] = None def _build_selected( self, @@ -253,15 +274,33 @@ def _build_selected( min_examples=self.cfg.min_examples, ) ) + self.all_data[module] = None def __len__(self): """Return the number of buffers in the dataset.""" return len(self.buffers) - def reset(self): - """Reset all buffers in the dataset.""" + def _load_all_data(self, raw_dir: str, modules: list[str]): + """For each module, load all locations and activations""" + all_locations = {} + all_activations = {} + all_data = {} + tokens = None for buffer in self.buffers: - buffer.reset() + module = buffer.module_path + if module not in all_locations: + all_locations[module] = [] + all_activations[module] = [] + activations, locations, tokens = buffer.load() + all_locations[module].append(locations) + all_activations[module].append(activations) + for module in all_locations: + all_data[module] = ActivationData( + torch.cat(all_locations[module]), + torch.cat(all_activations[module]), + tokens, + ) + return all_data def __iter__(self): """ @@ -304,7 +343,7 @@ async def _aprocess_buffer(self, buffer: TensorBuffer): yield record await asyncio.sleep(0) # Allow other coroutines to run - async def _aprocess_latent(self, buffer_output: BufferOutput) -> LatentRecord: + async def _aprocess_latent(self, latent_data: LatentData) -> LatentRecord: """ Asynchronously process a single latent. @@ -314,11 +353,15 @@ async def _aprocess_latent(self, buffer_output: BufferOutput) -> LatentRecord: Returns: Optional[LatentRecord]: Processed latent record or None. """ - record = LatentRecord(buffer_output.latent) + record = LatentRecord(latent_data.latent) + if self.transform is not None: + self.transform(record) if self.constructor is not None: - self.constructor(record=record, buffer_output=buffer_output) + self.constructor( + record=record, + activation_data=latent_data.activation_data, + all_data=self.all_data[latent_data.module], + ) if self.sampler is not None: self.sampler(record) - if self.transform is not None: - self.transform(record) return record From 44fbeeae5fab7fd9dde2bc68b53431211f5c8a05 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:21:26 +0000 Subject: [PATCH 086/132] pre-commit stuff --- delphi/latents/neighbours.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index d0821fbe..449311d0 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -29,10 +29,10 @@ def __init__( Initialize a NeighbourCalculator. 
Args: - latent_dataset (Optional[LatentDataset]): Dataset + latent_dataset (Optional[LatentDataset]): Dataset containing latent activations autoencoder (Optional[Autoencoder]): The trained autoencoder model - residual_stream_record (Optional[ResidualStreamRecord]): Record of + residual_stream_record (Optional[ResidualStreamRecord]): Record of residual stream values """ self.latent_dataset = latent_dataset @@ -184,7 +184,7 @@ def _compute_correlation_neighbours(self) -> dict[int, list[int]]: correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T covariance_between_latents[start:end] = correlation.cpu() - # the correlation is then the covariance divided + # the correlation is then the covariance divided # by the product of the standard deviations diagonal_covariance = torch.diag(covariance_between_latents) product_of_std = torch.sqrt( @@ -244,7 +244,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: ) n_tokens = len(unique_idx) - # 2. The Cantor indices are not consecutive, + # 2. The Cantor indices are not consecutive, # so we create sorted ones from the counts locations_flat = torch.repeat_interleave( torch.arange(n_tokens, device=locations.device), idx_counts From b17d7e86de7389baa551750d5af7b6e5a0045853 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:21:40 +0000 Subject: [PATCH 087/132] Remove code that shouldn't be here --- delphi/semantic_index/index.py | 250 --------------------------------- 1 file changed, 250 deletions(-) delete mode 100644 delphi/semantic_index/index.py diff --git a/delphi/semantic_index/index.py b/delphi/semantic_index/index.py deleted file mode 100644 index 6dc9034f..00000000 --- a/delphi/semantic_index/index.py +++ /dev/null @@ -1,250 +0,0 @@ -import json -from pathlib import Path -from typing import Optional - -import faiss -import numpy as np -import torch -from datasets import Dataset -from safetensors.numpy import load_file -from sparsify.data import chunk_and_tokenize -from torch import Tensor -from torch.utils.data import DataLoader -from tqdm import tqdm -from transformers import AutoModel, AutoTokenizer - -from delphi.config import CacheConfig -from delphi.latents import LatentDataset -from delphi.logger import logger -from delphi.utils import assert_type -import time - - -#TODO: Think if this code should be moved to neighbours - - -class AdversarialContexts: - """ - Class to compute the neighbours of selected latents using different methods: - - similarity: uses autoencoder weights - - correlation: uses pre-activation records and autoencoder - - co-occurrence: uses latent dataset statistics - """ - - def __init__( - self, - - latent_dataset: Optional['LatentDataset'] = None, - index_path: Optional[str] = None, - context_length: int = 32, - number_of_neighbours: int = 50, - ): - - """ - Initialize a NeighbourCalculator. - - Args: - latent_dataset (Optional[LatentDataset]): Dataset containing latent activations - autoencoder (Optional[Autoencoder]): The trained autoencoder model - residual_stream_record (Optional[ResidualStreamRecord]): Record of residual stream values - """ - self.latent_dataset = latent_dataset - self.context_length = context_length - self.index_path = index_path - # try to load index - self.index = self.load_index(index_path) - - - def _compute_similar_contexts(self) -> dict[int, list[int]]: - """ - Compute neighbour lists based on feature co-occurrence in the dataset. 
- If you run out of memory try reducing the token_batch_size - """ - - print("Creating index") - paths = [] - for buffer in self.latent_dataset.buffers: - paths.append(buffer.tensor_path) - - all_locations = [] - all_activations = [] - tokens = None - for path in paths: - split_data = load_file(path) - first_feature = int(path.split("/")[-1].split("_")[0]) - locations = torch.tensor(split_data["locations"].astype(np.int64)) - locations[:,2] = locations[:,2] + first_feature - activations = torch.tensor(split_data["activations"].astype(np.float32)) - # compute number of tokens - all_locations.append(locations) - all_activations.append(activations) - if tokens is None: - tokens = split_data["tokens"] - tokens = tokens[:10000] - reshaped_tokens = tokens.reshape(-1, self.context_length) - strings = self.latent_dataset.tokenizer.batch_decode(reshaped_tokens, skip_special_tokens=True) - if self.index is None: - index = self._build_index(strings) - self.save_index(index) - else: - index = self.index - - locations = torch.cat(all_locations) - activations = torch.cat(all_activations) - - indices = torch.argsort(locations[:,2], stable=True) - locations = locations[indices] - activations = activations[indices] - unique_latents, counts = torch.unique_consecutive(locations[:,2], return_counts=True) - cache_ctx_len = torch.max(locations[:,1])+1 - - latents = unique_latents - split_locations = torch.split(locations, counts.tolist()) - split_activations = torch.split(activations, counts.tolist()) - latents = unique_latents - dict_of_adversarial_contexts = {} - for latent, locations, activations in tqdm(zip(latents, split_locations, split_activations)): - flat_indices = locations[:,0]*cache_ctx_len+locations[:,1] - ctx_indices = flat_indices // self.context_length - index_within_ctx = flat_indices % self.context_length - unique_ctx_indices, inverses, lengths = torch.unique_consecutive(ctx_indices, return_counts=True, return_inverse=True) - # Get the max activation magnitude within each context window - max_buffer = torch.segment_reduce(activations, 'max', lengths=lengths) - k = 100 - _, top_indices = torch.topk(max_buffer, k, sorted=True) - top_indices = unique_ctx_indices[top_indices] - # find the context in the index, that are not activating contexts but are the most similar to top_indices - activating_contexts_indices = ctx_indices.unique() - print(top_indices.shape) - print(len(top_indices)) - query_vectors = [] - for i in range(len(top_indices)): - print(top_indices[i].item()) - breakpoint() - query_vectors.append(index.reconstruct(top_indices[i].item())) - - print(query_vectors.shape) - distances, indices = index.search(query_vectors, len(top_indices)+len(activating_contexts_indices)+self.number_of_neighbours) - filtered_indices = [] - - for index, distance in zip(indices, distances): - valid = [id_ for id_ in index if id_ not in activating_contexts_indices] - filtered_indices.append(valid[:k]) - print(len(filtered_indices)) - dict_of_adversarial_contexts[latent] = filtered_indices - self.save_adversarial_contexts(dict_of_adversarial_contexts) - - - - def _build_index(self,strings: list[str]) -> faiss.IndexIDMap: - index_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') - index_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to("cuda") - - tokenized = index_tokenizer(strings, return_tensors="pt", padding=True,max_length=512,padding_side="right",truncation=True) - index_tokens = tokenized["input_ids"] - index_initializer = 
index_model(index_tokens[:2].to("cuda")).last_hidden_state - - base_index = faiss.IndexFlatL2(index_initializer.shape[-1]) - index = faiss.IndexIDMap(base_index) - - - batch_size = 512 - dataloader = DataLoader(index_tokens, batch_size=batch_size) # type: ignore - from tqdm import tqdm - with torch.no_grad(): - for batch_idx, batch in enumerate(tqdm(dataloader)): - batch = batch.to("cuda") - token_embeddings = index_model(batch).last_hidden_state - sentence_embeddings = token_embeddings.mean(dim=1) - sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32) - ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch)) - index.add_with_ids(sentence_embeddings, ids) - - return index - - def populate_neighbour_cache(self, methods: list[str]) -> None: - """ - Populate the neighbour cache with the computed neighbour lists - """ - for method in methods: - self._compute_neighbour_list(method) - - - def load_index(self, base_path: str) -> faiss.IndexFlatL2: - # check if index exists - index_path = base_path + "/index.faiss" - if not Path(index_path).exists(): - return None - return faiss.read_index(str(index_path)) - - def save_index(self,index: faiss.IndexFlatL2, ): - index_path = self.index_path + "/index.faiss" - - faiss.write_index(index, str(index_path)) - - def save_adversarial_contexts(self, dict_of_adversarial_contexts: dict[int, list[int]]): - with open(self.index_path + "adversarial_contexts.json", "w") as f: - json.dump(dict_of_adversarial_contexts, f) - -def get_neighbors_by_id(index: faiss.IndexIDMap, vector_id: int, k: int = 10): - # First reconstruct the vector for the given ID - vector = index.reconstruct(vector_id) - - # Reshape to match FAISS expectations (needs 2D array) - vector = vector.reshape(1, -1) - - # Search for nearest neighbors - distances, neighbor_ids = index.search(vector, k + 1) # k+1 since it will find itself - - # Remove the first result (which will be the query vector itself) - return distances[0][1:], neighbor_ids[0][1:] - -def get_index_path(base_path: Path, cfg: CacheConfig): - return base_path / f"{cfg.dataset_repo.replace('/', '_')}_{cfg.dataset_split}_{cfg.ctx_len}.idx" - - - - -def build_semantic_index(data: Dataset, cfg: CacheConfig): - """ - Build a semantic index of the token sequences. 
- """ - index_tokenizer = AutoTokenizer.from_pretrained('sentence-transformers/all-MiniLM-L6-v2') - index_model = AutoModel.from_pretrained('sentence-transformers/all-MiniLM-L6-v2').to("cuda") - - index_tokens = chunk_and_tokenize(data, index_tokenizer, max_seq_len=cfg.ctx_len, text_key=cfg.dataset_row) - index_tokens = index_tokens["input_ids"] - index_tokens = assert_type(Tensor, index_tokens) - - token_embeddings = index_model(index_tokens[:2].to("cuda")).last_hidden_state - - base_index = faiss.IndexFlatL2(token_embeddings.shape[-1]) - index = faiss.IndexIDMap(base_index) - - batch_size = 512 - dataloader = DataLoader(index_tokens, batch_size=batch_size) # type: ignore - - from tqdm import tqdm - with torch.no_grad(): - for batch_idx, batch in enumerate(tqdm(dataloader)): - batch = batch.to("cuda") - token_embeddings = index_model(batch).last_hidden_state - sentence_embeddings = token_embeddings.mean(dim=1) - sentence_embeddings = sentence_embeddings.cpu().numpy().astype(np.float32) - - ids = np.arange(batch_idx * batch_size, batch_idx * batch_size + len(batch)) - index.add_with_ids(sentence_embeddings, ids) - - return index - - -def build_or_load_index(data: Dataset, base_path: Path, cfg: CacheConfig): - index_path = get_index_path(base_path, cfg) - - if not index_path.exists(): - logger.info(f"Building semantic index for {cfg.dataset_repo} {cfg.dataset_split} seq_len={cfg.ctx_len}...") - index = build_semantic_index(data, cfg) - save_index(index, base_path, cfg) - return index - else: - return load_index(base_path, cfg) From 5ec2b70375dd8cc548eb085512d2783d06a524b1 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 16:26:35 +0000 Subject: [PATCH 088/132] Added new non_activating_source argument --- delphi/__main__.py | 2 +- delphi/config.py | 6 ++++++ delphi/tests/e2e.py | 1 + 3 files changed, 8 insertions(+), 1 deletion(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index e7e7e95e..7bf153d7 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -94,7 +94,7 @@ async def process_cache( constructor, token_loader=None, n_not_active=experiment_cfg.n_non_activating, - constructor_type="random", + constructor_type=experiment_cfg.non_activating_source, ctx_len=experiment_cfg.example_ctx_len, max_examples=latent_cfg.max_examples, ) diff --git a/delphi/config.py b/delphi/config.py index 7d12b64e..079b1e0e 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -29,6 +29,12 @@ class ExperimentConfig(Serializable): test_type: Literal["quantiles", "activation"] = "quantiles" """Type of sampler to use for latent explanation testing.""" + non_activating_source: Literal["random", "neighbours"] = "random" + """Source of non-activating examples. Random uses non-activating contexts + sampled from any non activating window. Neighbours uses actvating contexts + from pre-computed latent neighbours. 
+    They are still non-activating but have a higher chance of being similar to the activating examples."""
+

 @dataclass
 class LatentConfig(Serializable):
diff --git a/delphi/tests/e2e.py b/delphi/tests/e2e.py
index c8eae835..c3d82c9e 100644
--- a/delphi/tests/e2e.py
+++ b/delphi/tests/e2e.py
@@ -24,6 +24,7 @@ async def test():
         test_type="quantiles",
         n_examples_train=40,
         n_examples_test=50,
+        non_activating_source="random",
     )
     latent_cfg = LatentConfig(
         min_examples=200,

From 12af689b5ff482e53c7d83f4b52902cb1f04bddf Mon Sep 17 00:00:00 2001
From: SrGonao
Date: Mon, 17 Feb 2025 21:58:04 +0000
Subject: [PATCH 089/132] else bug

---
 delphi/scorers/classifier/detection.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py
index 484d6d23..d3b04e3f 100644
--- a/delphi/scorers/classifier/detection.py
+++ b/delphi/scorers/classifier/detection.py
@@ -72,7 +72,7 @@ def _prepare(self, record: LatentRecord) -> list[list[Sample]]:
                 ground_truth=False,
                 tokenizer=self.tokenizer,
             )
-        else:
+        else:
            samples = []

        for i, examples in enumerate(record.test):

From b8ddc8e08315a1a87b1a2841b1f6b756ec7f1227 Mon Sep 17 00:00:00 2001
From: SrGonao
Date: Mon, 17 Feb 2025 21:58:44 +0000
Subject: [PATCH 090/132] circular import

---
 delphi/latents/neighbours.py |  2 +-
 delphi/latents/transforms.py | 14 ++++++--------
 2 files changed, 7 insertions(+), 9 deletions(-)

diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py
index 449311d0..4b7bbb44 100644
--- a/delphi/latents/neighbours.py
+++ b/delphi/latents/neighbours.py
@@ -216,7 +216,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]:
         print("Computing co-occurrence neighbours")
         paths = []
         for buffer in self.latent_dataset.buffers:
-            paths.append(buffer.tensor_path)
+            paths.append(buffer.path)

         all_locations = []
         for path in paths:
diff --git a/delphi/latents/transforms.py b/delphi/latents/transforms.py
index 2ced7244..3cb61c3b 100644
--- a/delphi/latents/transforms.py
+++ b/delphi/latents/transforms.py
@@ -1,10 +1,5 @@
-from dataclasses import dataclass
-from .latents import LatentRecord
+from .latents import LatentRecord, Neighbour

-@dataclass
-class Neighbour:
-    distance: float
-    feature_index: int

 def set_neighbours(
     record: LatentRecord,
@@ -14,11 +9,14 @@
     """
     Set the neighbours for the latent record.
""" - + neighbours = neighbours[str(record.latent.latent_index)] # Each element in neighbours is a tuple of (distance,feature_index) # We want to keep only the ones with a distance less than the threshold neighbours = [neighbour for neighbour in neighbours if neighbour[0] > threshold] - record.neighbours = [Neighbour(distance=neighbour[0], feature_index=neighbour[1]) for neighbour in neighbours] + record.neighbours = [ + Neighbour(distance=neighbour[0], latent_index=neighbour[1]) + for neighbour in neighbours + ] From 9a166f654b5ebb33053c5fbad4856fae835dcadb Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 21:59:05 +0000 Subject: [PATCH 091/132] also circular import fix --- delphi/latents/latents.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index eedbac5b..39cb68ba 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -7,6 +7,11 @@ from transformers import AutoTokenizer +@dataclass +class Neighbour: + distance: float + latent_index: int + @dataclass class Example: """ @@ -97,6 +102,9 @@ class LatentRecord: test: list[list[Example]] = field(default_factory=list) """Test examples.""" + neighbours: list[Neighbour] = field(default_factory=list) + """Neighbours of the latent.""" + @property def max_activation(self) -> float: """ From c849aa079029f207d2551f918b22408e4938f722 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 21:59:48 +0000 Subject: [PATCH 092/132] Making constructor work. Loader changes --- delphi/latents/constructors.py | 37 ++++------------- delphi/latents/loader.py | 75 ++++++++++++++++++---------------- 2 files changed, 49 insertions(+), 63 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index b2dedd93..8de12089 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -1,4 +1,4 @@ -from typing import Callable, Literal, Optional +from typing import Literal, Optional import torch from torchtyping import TensorType @@ -87,27 +87,9 @@ def constructor( max_examples: int, ctx_len: int, constructor_type: Literal["random", "neighbour"], - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None = None, + tokens: TensorType["tokens"], all_data: Optional[ActivationData] = None, ): - tokens = activation_data.tokens - if tokens is None: - if token_loader is None: - raise ValueError("Either tokens or token_loader must be provided") - try: - tokens = token_loader() - except TypeError: - raise ValueError( - "Starting with v0.2, `tokens` was renamed to `token_loader`, " - "which must be a callable for lazy loading.\n\n" - "Instead of passing\n" - "` tokens=dataset.tokens`,\n" - "pass\n" - "` token_loader=lambda: dataset.load_tokens()`,\n" - "(assuming `dataset` is a `LatentDataset` instance)." 
- ) - - tokens.shape[0] cache_token_length = tokens.shape[1] # Get all positions where the latent is active @@ -194,19 +176,17 @@ def neighbour_non_activation_windows( n_examples_per_neighbour = 10 number_examples = 0 - available_features = all_data.features all_examples = [] used_neighbours = [] for neighbour in record.neighbours: if number_examples >= n_not_active: break - # find indice in all_data.features that matches the neighbour - indice = torch.where(available_features == neighbour.feature_index)[0] - if len(indice) == 0: - continue # get the locations of the neighbour - locations = all_data.locations[indice] - activations = all_data.activations[indice] + if neighbour.latent_index not in all_data: + print(f"Neighbour {neighbour.latent_index} not found in all_data") + continue + locations = all_data[neighbour.latent_index].locations + activations = all_data[neighbour.latent_index].activations # get the active window indices flat_indices = locations[:, 0] * cache_token_length + locations[:, 1] ctx_indices = flat_indices // ctx_len @@ -228,6 +208,7 @@ def neighbour_non_activation_windows( activations = activations[mask_ctx] # If there are no available indices, skip this neighbour if activations.numel() == 0: + print(f"No available indices for neighbour {neighbour.latent_index}") continue token_windows, act_windows = pool_max_activation_windows( activations=activations, @@ -245,10 +226,10 @@ def neighbour_non_activation_windows( ) used_neighbours.append(neighbour) number_examples += examples_used + # We change neighbours in place to be the list of neighbours used record.neighbours = used_neighbours if len(all_examples) == 0: print("No examples found") - record.not_active = all_examples diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 8ed159e3..86ab3116 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -27,12 +27,9 @@ class ActivationData(NamedTuple): locations: TensorType["locations", 2] """Tensor of latent locations.""" - activations: TensorType["locations"] + activations: TensorType["activations"] """Tensor of latent activations.""" - tokens: Optional[TensorType["tokens"]] = None - """Tensor of tokens.""" - class LatentData(NamedTuple): """ @@ -67,7 +64,7 @@ class TensorBuffer: min_examples: int = 120 """Minimum number of examples required. 
Defaults to 120.""" - tokens: Optional[TensorType["tokens"]] = None + _tokens: Optional[TensorType["tokens"]] = None """Tensor of tokens.""" def __iter__(self): @@ -89,13 +86,17 @@ def __iter__(self): yield LatentData( Latent(self.module_path, int(latents[i].item())), self.module_path, - ActivationData(latent_locations, latent_activations, self.tokens), + ActivationData(latent_locations, latent_activations), ) + @property + def tokens(self) -> TensorType["tokens"]: + if self._tokens is None: + self._tokens = self.load_tokens() + return self._tokens + def load_data_per_latent(self): - locations, activations, tokens = self.load() - if tokens is not None: - self.tokens = tokens + locations, activations, _ = self.load() indices = torch.argsort(locations[:, 2], stable=True) activations = activations[indices] locations = locations[indices] @@ -127,6 +128,10 @@ def load(self): return locations, activations, tokens + def load_tokens(self): + _, _, tokens = self.load() + return tokens + class LatentDataset: """ @@ -155,7 +160,8 @@ def __init__( """ self.cfg = cfg self.buffers: list[TensorBuffer] = [] - self.all_data: dict[str, ActivationData | None] = {} + self.all_data: dict[str, dict[int, ActivationData] | None] = {} + self.tokens = None if latents is None: self._build(raw_dir, modules) @@ -179,6 +185,8 @@ def __init__( if self.constructor.keywords["constructor_type"] == "neighbour": self.all_data = self._load_all_data(raw_dir, modules) + self.load_tokens() + def load_tokens(self): """ Load tokenized data for the dataset. @@ -225,9 +233,13 @@ def _build(self, raw_dir: str, modules: Optional[list[str]] = None): edges = self._edges(raw_dir, module) for start, end in edges: path = f"{raw_dir}/{module}/{start}_{end}.safetensors" - self.buffers.append( - TensorBuffer(path, module, min_examples=self.cfg.min_examples) + tensor_buffer = TensorBuffer( + path, module, min_examples=self.cfg.min_examples ) + if self.tokens is None: + self.tokens = tensor_buffer.tokens + self.buffers.append(tensor_buffer) + self.all_data[module] = None self.all_data[module] = None def _build_selected( @@ -265,15 +277,12 @@ def _build_selected( start, end = boundaries[bucket.item() - 1], boundaries[bucket.item()] # Adjust end by one as the path avoids overlap path = f"{raw_dir}/{module}/{start}_{end-1}.safetensors" - - self.buffers.append( - TensorBuffer( - path, - module, - _selected_latents, - min_examples=self.cfg.min_examples, - ) + tensor_buffer = TensorBuffer( + path, module, _selected_latents, min_examples=self.cfg.min_examples ) + if self.tokens is None: + self.tokens = tensor_buffer.tokens + self.buffers.append(tensor_buffer) self.all_data[module] = None def __len__(self): @@ -282,24 +291,19 @@ def __len__(self): def _load_all_data(self, raw_dir: str, modules: list[str]): """For each module, load all locations and activations""" - all_locations = {} - all_activations = {} all_data = {} - tokens = None for buffer in self.buffers: module = buffer.module_path - if module not in all_locations: - all_locations[module] = [] - all_activations[module] = [] - activations, locations, tokens = buffer.load() - all_locations[module].append(locations) - all_activations[module].append(activations) - for module in all_locations: - all_data[module] = ActivationData( - torch.cat(all_locations[module]), - torch.cat(all_activations[module]), - tokens, - ) + if module not in all_data: + all_data[module] = {} + temp_latents = buffer.latents + # we remove the filter on latents + buffer.latents = None + latents, locations, activations = 
buffer.load_data_per_latent() + # we restore the filter on latents + buffer.latents = temp_latents + for latent, location, activation in zip(latents, locations, activations): + all_data[module][latent.item()] = ActivationData(location, activation) return all_data def __iter__(self): @@ -361,6 +365,7 @@ async def _aprocess_latent(self, latent_data: LatentData) -> LatentRecord: record=record, activation_data=latent_data.activation_data, all_data=self.all_data[latent_data.module], + tokens=self.tokens, ) if self.sampler is not None: self.sampler(record) From 1126c96a9c05f1277c65210f96029fe6d1c7a685 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Mon, 17 Feb 2025 22:00:44 +0000 Subject: [PATCH 093/132] Adding neighbours to main, config --- delphi/__main__.py | 83 ++++++++++++++++++++++++++++++++++++++++++++-- delphi/config.py | 2 +- 2 files changed, 82 insertions(+), 3 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 7bf153d7..e0ab935f 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -27,7 +27,9 @@ from delphi.explainers import DefaultExplainer from delphi.latents import LatentCache, LatentDataset from delphi.latents.constructors import constructor +from delphi.latents.neighbours import NeighbourCalculator from delphi.latents.samplers import sample +from delphi.latents.transforms import set_neighbours from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper from delphi.scorers import DetectionScorer, FuzzingScorer @@ -59,11 +61,60 @@ def load_artifacts(run_cfg: RunConfig): return run_cfg.hookpoints, hookpoint_to_sparse_encode, model +async def create_neighbours( + latent_cfg: LatentConfig, + experiment_cfg: ExperimentConfig, + hookpoints: list[str], + latents_path: Path, + neighbours_path: Path, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, + latent_range: Tensor | None, +): + """ + Creates a neighbours file for the given hookpoints. 
+ """ + if latent_range is None: + latent_dict = None + else: + latent_dict = { + hook: latent_range for hook in hookpoints + } # The latent range to explain + latent_dict = cast(dict[str, int | Tensor], latent_dict) + example_constructor = partial( + constructor, + n_not_active=experiment_cfg.n_non_activating, + constructor_type=experiment_cfg.non_activating_source, + ctx_len=experiment_cfg.example_ctx_len, + max_examples=latent_cfg.max_examples, + ) + sampler = partial(sample, cfg=experiment_cfg) + + dataset = LatentDataset( + raw_dir=str(latents_path), + cfg=latent_cfg, + modules=hookpoints, + latents=latent_dict, + tokenizer=tokenizer, + constructor=example_constructor, + sampler=sampler, + ) + + neighbour_calculator = NeighbourCalculator( + latent_dataset=dataset, number_of_neighbours=100 + ) + + neighbour_calculator.populate_neighbour_cache(["co-occurrence"]) + neighbours_path.mkdir(parents=True, exist_ok=True) + + neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/neighbours.json") + + async def process_cache( latent_cfg: LatentConfig, run_cfg: RunConfig, experiment_cfg: ExperimentConfig, latents_path: Path, + neighbours_path: Path, explanations_path: Path, scores_path: Path, hookpoints: list[str], @@ -92,14 +143,22 @@ async def process_cache( example_constructor = partial( constructor, - token_loader=None, n_not_active=experiment_cfg.n_non_activating, constructor_type=experiment_cfg.non_activating_source, ctx_len=experiment_cfg.example_ctx_len, max_examples=latent_cfg.max_examples, ) sampler = partial(sample, cfg=experiment_cfg) - + if experiment_cfg.non_activating_source == "neighbours": + with open(neighbours_path / "neighbours.json", "r") as f: + neighbours = json.load(f)["co-occurrence"] + transform = partial( + set_neighbours, + neighbours=neighbours, + threshold=0.0, + ) + else: + transform = None dataset = LatentDataset( raw_dir=str(latents_path), cfg=latent_cfg, @@ -108,6 +167,7 @@ async def process_cache( tokenizer=tokenizer, constructor=example_constructor, sampler=sampler, + transform=transform, ) if run_cfg.explainer_provider == "offline": @@ -272,6 +332,7 @@ async def run( latents_path = base_path / "latents" explanations_path = base_path / "explanations" scores_path = base_path / "scores" + neighbours_path = base_path / "neighbours" visualize_path = base_path / "visualize" latent_range = torch.arange(run_cfg.max_latents) if run_cfg.max_latents else None @@ -295,6 +356,23 @@ async def run( print(f"Files found in {latents_path}, skipping cache population...") del model, hookpoint_to_sparse_encode + if ( + not glob(str(neighbours_path / ".*")) + glob(str(neighbours_path / "*")) + or "neighbours" in run_cfg.overwrite + ): + # TODO: we probably want to use less arguments and load the latent dataset + # only once? 
+ await create_neighbours( + latent_cfg, + experiment_cfg, + hookpoints, + latents_path, + neighbours_path, + tokenizer, + latent_range, + ) + else: + print(f"Files found in {neighbours_path}, skipping...") if ( not glob(str(scores_path / ".*")) + glob(str(scores_path / "*")) @@ -305,6 +383,7 @@ async def run( run_cfg, experiment_cfg, latents_path, + neighbours_path, explanations_path, scores_path, hookpoints, diff --git a/delphi/config.py b/delphi/config.py index 079b1e0e..39c0122f 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -151,6 +151,6 @@ class RunConfig: scoring speed but can leak information to the fuzzing and detection scorer, as well as increasing the scorer LLM task difficulty.""" - overwrite: list[Literal["cache", "scores"]] = list_field() + overwrite: list[Literal["cache", "neighbours", "scores"]] = list_field() """List of run stages to recompute. This is a debugging tool and may be removed in the future.""" From 46bfcb9875518d6d934bde86f5daa7624512d706 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 11:39:07 +0000 Subject: [PATCH 094/132] Fix all typing errors not envolving Nones or TensorTypes (I think) --- delphi/explainers/default/default.py | 111 ++++--------------- delphi/explainers/default/prompt_builder.py | 5 +- delphi/explainers/explainer.py | 116 +++++++++++++++++++- delphi/latents/constructors.py | 4 +- delphi/latents/latents.py | 9 +- delphi/latents/loader.py | 22 ++-- delphi/latents/neighbours.py | 10 +- delphi/latents/samplers.py | 6 +- 8 files changed, 159 insertions(+), 124 deletions(-) diff --git a/delphi/explainers/default/default.py b/delphi/explainers/default/default.py index d6a69037..c1419e70 100644 --- a/delphi/explainers/default/default.py +++ b/delphi/explainers/default/default.py @@ -1,8 +1,6 @@ import asyncio -import re -from ...logger import logger -from ..explainer import Explainer, ExplainerResult +from ..explainer import Example, Explainer from .prompt_builder import build_prompt @@ -20,98 +18,35 @@ def __init__( temperature: float = 0.0, **generation_kwargs, ): - self.client = client - self.tokenizer = tokenizer - self.verbose = verbose - - self.activations = activations - self.cot = cot - self.threshold = threshold - self.temperature = temperature - self.generation_kwargs = generation_kwargs - - async def __call__(self, record): - messages = self._build_prompt(record.train) - - response = await self.client.generate( - messages, temperature=self.temperature, **self.generation_kwargs + super().__init__( + client, + tokenizer, + verbose, + activations, + cot, + threshold, + temperature, + **generation_kwargs, ) - try: - explanation = self.parse_explanation(response.text) - if self.verbose: - logger.info(f"Explanation: {explanation}") - logger.info(f"Final message to explainer: {messages[-1]['content']}") - logger.info(f"Response from explainer: {response.text}") - - return ExplainerResult(record=record, explanation=explanation) - except Exception as e: - logger.error(f"Explanation parsing failed: {e}") - return ExplainerResult( - record=record, explanation="Explanation could not be parsed." - ) - - def parse_explanation(self, text: str) -> str: - try: - match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) - return ( - match.group(1).strip() if match else "Explanation could not be parsed." 
- ) - except Exception as e: - logger.error(f"Explanation parsing regex failed: {e}") - raise - - def _highlight(self, index, example): - result = f"Example {index}: " - - threshold = example.max_activation * self.threshold - if self.tokenizer is not None: - str_toks = self.tokenizer.batch_decode(example.tokens) - example.str_toks = str_toks - else: - str_toks = example.tokens - example.str_toks = str_toks - activations = example.activations - - def check(i): - return activations[i] > threshold - - i = 0 - while i < len(str_toks): - if check(i): - result += "<<" - - while i < len(str_toks) and check(i): - result += str_toks[i] - i += 1 - result += ">>" - else: - result += str_toks[i] - i += 1 - - return "".join(result) - - def _join_activations(self, example): - activations = [] - - for i, activation in enumerate(example.activations): - if activation > example.max_activation * self.threshold: - activations.append( - (example.str_toks[i], int(example.normalized_activations[i])) - ) - - acts = ", ".join(f'("{item[0]}" : {item[1]})' for item in activations) - - return "Activations: " + acts - - def _build_prompt(self, examples): + def _build_prompt(self, examples: list[Example]) -> list[dict]: highlighted_examples = [] for i, example in enumerate(examples): - highlighted_examples.append(self._highlight(i + 1, example)) + str_toks = self.tokenizer.batch_decode(example.tokens) + activations = example.activations.tolist() + highlighted_examples.append(self._highlight(str_toks, activations)) if self.activations: - highlighted_examples.append(self._join_activations(example)) + assert ( + example.normalized_activations is not None + ), "Normalized activations are required for activations in explainer" + normalized_activations = example.normalized_activations.tolist() + highlighted_examples.append( + self._join_activations( + str_toks, activations, normalized_activations + ) + ) highlighted_examples = "\n".join(highlighted_examples) diff --git a/delphi/explainers/default/prompt_builder.py b/delphi/explainers/default/prompt_builder.py index 8ca78140..55eddcae 100644 --- a/delphi/explainers/default/prompt_builder.py +++ b/delphi/explainers/default/prompt_builder.py @@ -26,10 +26,10 @@ def build_examples( def build_prompt( - examples, + examples: str, activations: bool = False, cot: bool = False, -): +) -> list[dict]: messages = system( cot=cot, ) @@ -49,7 +49,6 @@ def build_prompt( "content": user_start, } ) - print(messages) return messages diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 2ca9d35f..8eda3dd7 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -1,12 +1,14 @@ import json import os import random +import re from abc import ABC, abstractmethod from typing import NamedTuple import aiofiles -from ..latents.latents import LatentRecord +from ..latents.latents import Example, LatentRecord +from ..logger import logger class ExplainerResult(NamedTuple): @@ -18,18 +20,120 @@ class ExplainerResult(NamedTuple): class Explainer(ABC): + """ + Abstract base class for explainers. 
+ """ + + def __init__( + self, + client, + tokenizer, + verbose: bool = False, + activations: bool = False, + cot: bool = False, + threshold: float = 0.6, + temperature: float = 0.0, + **generation_kwargs, + ): + self.client = client + self.tokenizer = tokenizer + self.verbose = verbose + self.activations = activations + self.cot = cot + self.threshold = threshold + self.temperature = temperature + self.generation_kwargs = generation_kwargs + + async def __call__(self, record: LatentRecord) -> ExplainerResult: + messages = self._build_prompt(record.train) + + response = await self.client.generate( + messages, temperature=self.temperature, **self.generation_kwargs + ) + + try: + explanation = self.parse_explanation(response.text) + if self.verbose: + logger.info(f"Explanation: {explanation}") + logger.info(f"Messages: {messages[-1]['content']}") + logger.info(f"Response: {response}") + + return ExplainerResult(record=record, explanation=explanation) + except Exception as e: + logger.error(f"Explanation parsing failed: {e}") + return ExplainerResult( + record=record, explanation="Explanation could not be parsed." + ) + + def parse_explanation(self, text: str) -> str: + try: + match = re.search(r"\[EXPLANATION\]:\s*(.*)", text, re.DOTALL) + if match: + return match.group(1).strip() + else: + return "Explanation could not be parsed." + except Exception as e: + logger.error(f"Explanation parsing regex failed: {e}") + raise + + def _highlight(self, str_toks: list[str], activations: list[float]) -> str: + result = "" + threshold = max(activations) * self.threshold + + def check(i): + return activations[i] > threshold + + i = 0 + while i < len(str_toks): + if check(i): + result += "<<" + + while i < len(str_toks) and check(i): + result += str_toks[i] + i += 1 + result += ">>" + else: + result += str_toks[i] + i += 1 + + return "".join(result) + + def _join_activations( + self, + str_toks: list[str], + token_activations: list[float], + normalized_activations: list[float], + ) -> str: + acts = "" + activation_count = 0 + for str_tok, token_activation, normalized_activation in zip( + str_toks, token_activations, normalized_activations + ): + if token_activation > max(token_activations) * self.threshold: + # TODO: for each example, we only show the first 10 activations + # decide on the best way to do this + if activation_count > 10: + break + acts += f'("{str_tok}" : {int(normalized_activation)}), ' + activation_count += 1 + + return "Activations: " + acts + @abstractmethod - def __call__(self, record: LatentRecord) -> ExplainerResult: + def _build_prompt(self, examples: list[Example]) -> list[dict]: pass async def explanation_loader( record: LatentRecord, explanation_dir: str ) -> ExplainerResult: - async with aiofiles.open(f"{explanation_dir}/{record.latent}.txt", "r") as f: - explanation = json.loads(await f.read()) - - return ExplainerResult(record=record, explanation=explanation) + try: + async with aiofiles.open(f"{explanation_dir}/{record.latent}.txt", "r") as f: + explanation = json.loads(await f.read()) + return ExplainerResult(record=record, explanation=explanation) + except FileNotFoundError: + print(f"No explanation found for {record.latent}") + return ExplainerResult(record=record, explanation="No explanation found") async def random_explanation_loader( diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 2742362e..96a388b6 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -11,7 +11,6 @@ def _top_k_pools( max_buffer: 
TensorType["batch"], split_activations: list[TensorType["activations"]], buffer_tokens: TensorType["batch", "ctx_len"], - ctx_len: int, max_examples: int, ): """ @@ -21,7 +20,6 @@ def _top_k_pools( max_buffer: The maximum buffer values. split_activations: The split activations. buffer_tokens: The buffer tokens. - ctx_len: The context length. max_examples: The maximum number of examples. Returns: @@ -81,7 +79,7 @@ def pool_max_activation_windows( buffer_tokens = buffer_tokens[unique_ctx_indices] token_windows, activation_windows = _top_k_pools( - max_buffer, new_tensor, buffer_tokens, ctx_len, max_examples + max_buffer, new_tensor, buffer_tokens, max_examples ) record.examples = prepare_examples(token_windows, activation_windows) diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index c1f6e2f2..67c01d36 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -91,7 +91,7 @@ class LatentRecord: not_active: list[Example] = field(default_factory=list) """Non-activating examples.""" - train: list[list[Example]] = field(default_factory=list) + train: list[Example] = field(default_factory=list) """Training examples.""" test: list[list[Example]] = field(default_factory=list) @@ -132,7 +132,7 @@ def display( tokenizer: AutoTokenizer, threshold: float = 0.0, n: int = 10, - ) -> str: + ): """ Display the latent record in a formatted string. @@ -147,9 +147,7 @@ def display( """ from IPython.core.display import HTML, display - def _to_string( - tokens: TensorType["seq"], activations: TensorType["seq"] - ) -> str: + def _to_string(tokens: list[str], activations: TensorType["seq"]) -> str: """ Convert tokens and activations to a string. @@ -177,6 +175,7 @@ def _to_string( result.append(tokens[i]) i += 1 return "".join(result) + return "" strings = [ _to_string(tokenizer.batch_decode(example.tokens), example.activations) diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 47b5b03b..eb0f53b7 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -141,13 +141,16 @@ def __init__( """ self.cfg = cfg self.buffers = [] - + if modules is None: + self.modules = os.listdir(raw_dir) + else: + self.modules = modules if latents is None: - self._build(raw_dir, modules) + self._build(raw_dir) else: - self._build_selected(raw_dir, modules, latents) + self._build_selected(raw_dir, latents) # TODO: this assumes that all modules have the same config - cache_config_dir = f"{raw_dir}/{modules[0]}/config.json" + cache_config_dir = f"{raw_dir}/{self.modules[0]}/config.json" with open(cache_config_dir, "r") as f: cache_config = json.load(f) if tokenizer is None: @@ -170,7 +173,7 @@ def load_tokens(self): if not hasattr(self, "tokens"): self.tokens = load_tokenized_data( self.cache_config["ctx_len"], - self.tokenizer, + self.tokenizer, # type: ignore self.cache_config["dataset_repo"], self.cache_config["dataset_split"], self.cache_config["dataset_name"], @@ -191,7 +194,7 @@ def _edges(self, raw_dir: str, module: str) -> list[tuple[int, int]]: edges.sort(key=lambda x: x[0]) return edges - def _build(self, raw_dir: str, modules: Optional[list[str]] = None): + def _build(self, raw_dir: str): """ Build dataset buffers which load all cached latents. @@ -199,9 +202,8 @@ def _build(self, raw_dir: str, modules: Optional[list[str]] = None): raw_dir (str): Directory containing raw latent data. modules (Optional[list[str]]): list of module names to include. 
""" - modules = os.listdir(raw_dir) if modules is None else modules - for module in modules: + for module in self.modules: edges = self._edges(raw_dir, module) for start, end in edges: path = f"{raw_dir}/{module}/{start}_{end}.safetensors" @@ -212,7 +214,6 @@ def _build(self, raw_dir: str, modules: Optional[list[str]] = None): def _build_selected( self, raw_dir: str, - modules: list[str], latents: dict[str, Union[int, torch.Tensor]], ): """ @@ -220,12 +221,11 @@ def _build_selected( Args: raw_dir (str): Directory containing raw latent data. - modules (list[str]): list of module names to include. latents (dict[str, Union[int, torch.Tensor]]): Dictionary of latents per module. """ - for module in modules: + for module in self.modules: edges = self._edges(raw_dir, module) selected_latents = latents[module] if isinstance(selected_latents, int): diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 0b370584..d1163a35 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -39,6 +39,7 @@ def __init__( self.latent_dataset = latent_dataset self.autoencoder = autoencoder self.pre_activation_record = pre_activation_record + self.number_of_neighbours = number_of_neighbours # load the neighbour cache from the path if neighbour_cache is not None: @@ -139,7 +140,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: """ Compute neighbour lists based on latent co-occurrence in the dataset. """ - # To be implemented paths = [] for buffer in self.latent_dataset.buffers: @@ -214,16 +214,16 @@ def populate_neighbour_cache(self, methods: list[str]) -> None: for method in methods: self._compute_neighbour_list(method) - def save_neighbour_cache(self) -> None: + def save_neighbour_cache(self, path: str) -> None: """ Save the neighbour cache to the path as a json file """ - with open(self.path, "w") as f: + with open(path, "w") as f: json.dump(self.neighbour_cache, f, indent=4) - def load_neighbour_cache(self) -> dict[str, dict[int, list[int]]]: + def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: """ Load the neighbour cache from the path as a json file """ - with open(self.path, "r") as f: + with open(path, "r") as f: return json.load(f) diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py index 3c0a2ca5..4a99feb6 100644 --- a/delphi/latents/samplers.py +++ b/delphi/latents/samplers.py @@ -31,7 +31,7 @@ def split_activation_quantiles( """ random.seed(seed) - examples = deque(examples) + queue_examples = deque(examples) max_activation = examples[0].max_activation # For 4 quantiles, thresholds are 0.25, 0.5, 0.75 @@ -41,8 +41,8 @@ def split_activation_quantiles( for threshold in thresholds: # Get all examples in quantile quantile = [] - while examples and examples[0].max_activation < threshold: - quantile.append(examples.popleft()) + while queue_examples and queue_examples[0].max_activation < threshold: + quantile.append(queue_examples.popleft()) sample = random.sample(quantile, n_samples) samples.append(sample) From c344cb2d82d7c338ca907e72883e8cf93c9c4bdd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Feb 2025 11:44:15 +0000 Subject: [PATCH 095/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 2 +- delphi/sparse_coders/__init__.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/README.md b/README.md index 0dbdb88a..a720fc48 100644 --- 
a/README.md +++ b/README.md @@ -16,7 +16,7 @@ Install this library as a local editable installation. Run the following command To run the default pipeline from the command line, use the following command: -`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B` +`python -m delphi meta-llama/Meta-Llama-3-8B EleutherAI/sae-llama-3-8b-32x --explainer_model 'hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4' --dataset_repo 'EleutherAI/fineweb-edu-dedup-10b' --dataset_split 'train[:1%]' --n_tokens 10_000_000 --max_latents 100 --hookpoints layers.5 --filter_bos --name llama-3-8B` This command will: 1. Cache activations for the first 10 million tokens of EleutherAI/rpj-v2-sample. diff --git a/delphi/sparse_coders/__init__.py b/delphi/sparse_coders/__init__.py index 731e9606..d72b3094 100644 --- a/delphi/sparse_coders/__init__.py +++ b/delphi/sparse_coders/__init__.py @@ -1,4 +1,3 @@ from .sparse_model import load_hooks_sparse_coders, load_sparse_coders __all__ = ["load_hooks_sparse_coders", "load_sparse_coders"] - From 4d3a17421d050bb249b2219326c42625ac324504 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 13:53:29 +0000 Subject: [PATCH 096/132] torchtyping to jaxtyping and other typing fixes --- delphi/__main__.py | 5 +- delphi/latents/cache.py | 78 +++++++++++++++-------- delphi/latents/collect_activations.py | 2 +- delphi/latents/constructors.py | 19 +++--- delphi/latents/latents.py | 23 +++---- delphi/latents/loader.py | 27 ++++---- delphi/latents/neighbours.py | 4 +- delphi/sparse_coders/custom/gemmascope.py | 2 +- delphi/sparse_coders/load_sparsify.py | 2 +- delphi/sparse_coders/sparse_model.py | 10 +-- delphi/tests/conftest.py | 35 ++++++---- delphi/tests/test_latents/test_cache.py | 8 +-- 12 files changed, 129 insertions(+), 86 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 186e9a9e..33e35e6d 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -12,7 +12,6 @@ from simple_parsing import ArgumentParser from sparsify.data import chunk_and_tokenize from torch import Tensor -from torchtyping import TensorType from transformers import ( AutoModel, AutoTokenizer, @@ -88,7 +87,7 @@ async def process_cache( latent_dict = { hook: latent_range for hook in hookpoints } # The latent range to explain - latent_dict = cast(dict[str, int | Tensor], latent_dict) + latent_dict = cast(dict[str, Tensor], latent_dict) constructor = partial( default_constructor, @@ -235,8 +234,6 @@ def populate_cache( ] tokens = truncated_tokens.reshape(-1, cfg.ctx_len) - tokens = cast(TensorType["batch", "seq"], tokens) - cache = LatentCache( model, hookpoint_to_sparse_encode, diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 3e2c0531..4fd8606a 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -5,10 +5,11 @@ import numpy as np import torch +from jaxtyping import Float from safetensors.numpy import save_file from torch import Tensor -from torchtyping import TensorType from tqdm import tqdm +from transformers import PreTrainedModel from delphi.config import CacheConfig from delphi.latents.collect_activations import collect_activations @@ -21,7 +22,9 @@ class Cache: """ def __init__( - self, filters: dict[str, TensorType["indices"]] = None, batch_size: int = 64 + self, + filters: 
dict[str, Float[Tensor, "indices"]] | None = None, + batch_size: int = 64, ): """ Initialize the Cache. @@ -30,16 +33,31 @@ def __init__( filters: Filters for selecting specific latents. batch_size: Size of batches for processing. Defaults to 64. """ - self.latent_locations = defaultdict(list) - self.latent_activations = defaultdict(list) - self.tokens = defaultdict(list) + self.latent_locations_batches: dict[ + str, list[Float[Tensor, "batch sequence num_latents"]] + ] = defaultdict(list) + self.latent_activations_batches: dict[ + str, list[Float[Tensor, "batch sequence num_latents"]] + ] = defaultdict(list) + self.tokens_batches: dict[ + str, list[Float[Tensor, "batch sequence"]] + ] = defaultdict(list) + + self.latent_locations: dict[ + str, Float[Tensor, "batch sequence num_latents"] + ] = {} + self.latent_activations: dict[ + str, Float[Tensor, "batch sequence num_latents"] + ] = {} + self.tokens: dict[str, Float[Tensor, "batch sequence"]] = {} + self.filters = filters self.batch_size = batch_size def add( self, - latents: TensorType["batch", "sequence", "latent"], - tokens: TensorType["batch", "sequence"], + latents: Float[Tensor, "mini_batch sequence num_latents"], + tokens: Float[Tensor, "mini_batch sequence"], batch_number: int, module_path: str, ): @@ -59,28 +77,32 @@ def add( # Adjust batch indices latent_locations[:, 0] += batch_number * self.batch_size - self.latent_locations[module_path].append(latent_locations) - self.latent_activations[module_path].append(latent_activations) - self.tokens[module_path].append(tokens) + self.latent_locations_batches[module_path].append(latent_locations) + self.latent_activations_batches[module_path].append(latent_activations) + self.tokens_batches[module_path].append(tokens) def save(self): """ Concatenate the latent locations and activations for all modules. """ - for module_path in self.latent_locations.keys(): + for module_path in self.latent_locations_batches.keys(): self.latent_locations[module_path] = torch.cat( - self.latent_locations[module_path], dim=0 + self.latent_locations_batches[module_path], dim=0 ) self.latent_activations[module_path] = torch.cat( - self.latent_activations[module_path], dim=0 + self.latent_activations_batches[module_path], dim=0 ) - self.tokens[module_path] = torch.cat(self.tokens[module_path], dim=0) + self.tokens[module_path] = torch.cat( + self.tokens_batches[module_path], dim=0 + ) def get_nonzeros_batch( - self, latents: TensorType["batch", "seq", "latent"] - ) -> tuple[Tensor, Tensor]: + self, latents: Float[Tensor, "batch sequence num_latents"] + ) -> tuple[ + Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "] + ]: """ Get non-zero activations for large batches that exceed int32 max value. @@ -115,8 +137,11 @@ def get_nonzeros_batch( return nonzero_latent_locations, nonzero_latent_activations def get_nonzeros( - self, latents: TensorType["batch", "seq", "latent"], module_path: str - ) -> tuple[Tensor, Tensor]: + self, latents: Float[Tensor, "batch sequence num_latents"], module_path: str + ) -> tuple[ + Float[Tensor, "batch sequence num_latents"], + Float[Tensor, "batch sequence num_latents"], + ]: """ Get the nonzero latent locations and activations. @@ -157,10 +182,10 @@ class LatentCache: def __init__( self, - model, + model: PreTrainedModel, hookpoint_to_sparse_encode: dict[str, Callable], batch_size: int, - filters: dict[str, TensorType["indices"]] | None = None, + filters: dict[str, Float[Tensor, "indices"]] | None = None, ): """ Initialize the LatentCache. 
@@ -181,8 +206,8 @@ def __init__( self.filter_submodules(filters) def load_token_batches( - self, n_tokens: int, tokens: TensorType["batch", "sequence"] - ) -> list[Tensor]: + self, n_tokens: int, tokens: Float[Tensor, "batch sequence"] + ) -> list[Float[Tensor, "batch sequence"]]: """ Load and prepare token batches for processing. @@ -205,7 +230,7 @@ def load_token_batches( return token_batches - def filter_submodules(self, filters: dict[str, TensorType["indices"]]): + def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]): """ Filter submodules based on the provided filters. @@ -218,7 +243,7 @@ def filter_submodules(self, filters: dict[str, TensorType["indices"]]): filtered_submodules[hookpoint] = self.hookpoint_to_sae[hookpoint] self.hookpoint_to_sae = filtered_submodules - def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]): + def run(self, n_tokens: int, tokens: Float[Tensor, "batch sequence"]): """ Run the latent caching process. @@ -277,9 +302,9 @@ def save(self, save_dir: Path, save_tokens: bool = True): if save_tokens: data["tokens"] = self.cache.tokens[module_path] - save_file(data, output_file) + save_file(data, output_file) # type: ignore - def _generate_split_indices(self, n_splits: int) -> list[tuple[int, int]]: + def _generate_split_indices(self, n_splits: int) -> list[tuple[Tensor, Tensor]]: """ Generate indices for splitting the latent space. @@ -289,6 +314,7 @@ def _generate_split_indices(self, n_splits: int) -> list[tuple[int, int]]: Returns: list[tuple[int, int]]: list of start and end indices for each split. """ + assert self.width is not None, "Width must be set before generating splits" boundaries = torch.linspace(0, self.width, steps=n_splits + 1).long() # Adjust end by one diff --git a/delphi/latents/collect_activations.py b/delphi/latents/collect_activations.py index 63c6a4b0..97489054 100644 --- a/delphi/latents/collect_activations.py +++ b/delphi/latents/collect_activations.py @@ -23,7 +23,7 @@ def collect_activations(model: PreTrainedModel, hookpoints: list[str]): handles = [] def create_hook(hookpoint: str): - def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor: + def hook_fn(module: nn.Module, input: Any, output: Tensor) -> Tensor | None: # If output is a tuple (like in some transformer layers), take first element if isinstance(output, tuple): activations[hookpoint] = output[0] diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 96a388b6..1a287506 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -1,18 +1,19 @@ from typing import Callable, Optional import torch -from torchtyping import TensorType +from jaxtyping import Float +from torch import Tensor from .latents import LatentRecord, prepare_examples from .loader import BufferOutput def _top_k_pools( - max_buffer: TensorType["batch"], - split_activations: list[TensorType["activations"]], - buffer_tokens: TensorType["batch", "ctx_len"], + max_buffer: Float[Tensor, "batch"], + split_activations: Float[Tensor, "activations ctx_len"], + buffer_tokens: Float[Tensor, "batch ctx_len"], max_examples: int, -): +) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]: """ Get the top k activation pools. @@ -23,7 +24,7 @@ def _top_k_pools( max_examples: The maximum number of examples. 
Returns: - tuple[TensorType["examples", "ctx_len"], TensorType["examples", "ctx_len"]]: + tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]: The token windows and activation windows. """ k = min(max_examples, len(max_buffer)) @@ -38,7 +39,7 @@ def _top_k_pools( def pool_max_activation_windows( record, buffer_output: BufferOutput, - tokens: TensorType["batch", "seq"], + tokens: Float[Tensor, "batch sequence"], ctx_len: int, max_examples: int, ): @@ -87,7 +88,7 @@ def pool_max_activation_windows( def random_non_activating_windows( record: LatentRecord, - tokens: TensorType["batch", "seq"], + tokens: Float[Tensor, "batch sequence"], buffer_output: BufferOutput, ctx_len: int, n_not_active: int, @@ -136,7 +137,7 @@ def random_non_activating_windows( def default_constructor( record: LatentRecord, - token_loader: Optional[Callable[[], TensorType["batch", "seq"]]] | None, + token_loader: Optional[Callable[[], Float[Tensor, "batch sequence"]]] | None, buffer_output: BufferOutput, n_not_active: int, ctx_len: int, diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 67c01d36..b4db1aa9 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -1,10 +1,11 @@ from dataclasses import dataclass, field -from typing import List, Optional +from typing import Optional import blobfile as bf import orjson -from torchtyping import TensorType -from transformers import AutoTokenizer +from jaxtyping import Float +from torch import Tensor +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast @dataclass @@ -13,13 +14,13 @@ class Example: A single example of latent data. """ - tokens: TensorType["seq"] + tokens: Float[Tensor, "ctx_len"] """Tokenized input sequence.""" - activations: TensorType["seq"] + activations: Float[Tensor, "ctx_len"] """Activation values for the input sequence.""" - normalized_activations: Optional[TensorType["seq"]] = None + normalized_activations: Optional[Float[Tensor, "ctx_len"]] = None """Activations quantized to integers in [0, 10].""" @property @@ -34,9 +35,9 @@ def max_activation(self) -> float: def prepare_examples( - tokens: List[TensorType["seq"]], - activations: List[TensorType["seq"]], -) -> List[Example]: + tokens: Float[Tensor, "examples ctx_len"], + activations: Float[Tensor, "examples ctx_len"], +) -> list[Example]: """ Prepare a list of examples from input tokens and activations. @@ -129,7 +130,7 @@ def save(self, directory: str, save_examples: bool = False): def display( self, - tokenizer: AutoTokenizer, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, threshold: float = 0.0, n: int = 10, ): @@ -147,7 +148,7 @@ def display( """ from IPython.core.display import HTML, display - def _to_string(tokens: list[str], activations: TensorType["seq"]) -> str: + def _to_string(tokens: list[str], activations: Float[Tensor, "ctx_len"]) -> str: """ Convert tokens and activations to a string. 
diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index eb0f53b7..db295891 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -3,12 +3,13 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import Callable, NamedTuple, Optional, Union +from typing import Callable, NamedTuple, Optional import numpy as np import torch +from jaxtyping import Float from safetensors.numpy import load_file -from torchtyping import TensorType +from torch import Tensor from transformers import AutoTokenizer from delphi.utils import ( @@ -27,13 +28,13 @@ class BufferOutput(NamedTuple): latent: Latent """The latent associated with this output.""" - locations: TensorType["locations", 2] + locations: Float[Tensor, "locations 2"] """Tensor of latent locations.""" - activations: TensorType["locations"] + activations: Float[Tensor, "locations"] """Tensor of latent activations.""" - tokens: TensorType["tokens"] + tokens: Float[Tensor, "tokens"] """Tensor of all tokens.""" @@ -49,12 +50,15 @@ class TensorBuffer: module_path: str """Path of the module.""" - latents: Optional[TensorType["latents"]] = None + latents: Optional[Float[Tensor, "num_latents"]] = None """Tensor of latent indices.""" min_examples: int = 120 """Minimum number of examples required. Defaults to 120.""" + tokens: Optional[Float[Tensor, "batch sequence"]] = None + """Tensor of all tokens.""" + def __iter__(self): """ Iterate over the buffer, yielding BufferOutput objects. @@ -64,7 +68,10 @@ def __iter__(self): None otherwise. """ latents, split_locations, split_activations, tokens = self.load() - + if tokens is None: + tokens = self.tokens + if tokens is None: + raise ValueError("No tokens found") for i in range(len(latents)): latent_locations = split_locations[i] latent_activations = split_activations[i] @@ -125,7 +132,7 @@ def __init__( cfg: LatentConfig, tokenizer: Optional[Callable] = None, modules: Optional[list[str]] = None, - latents: Optional[dict[str, Union[int, torch.Tensor]]] = None, + latents: Optional[dict[str, torch.Tensor]] = None, constructor: Optional[Callable] = None, sampler: Optional[Callable] = None, transform: Optional[Callable] = None, @@ -214,7 +221,7 @@ def _build(self, raw_dir: str): def _build_selected( self, raw_dir: str, - latents: dict[str, Union[int, torch.Tensor]], + latents: dict[str, torch.Tensor], ): """ Build a dataset buffer which loads only selected latents. 
@@ -228,8 +235,6 @@ def _build_selected( for module in self.modules: edges = self._edges(raw_dir, module) selected_latents = latents[module] - if isinstance(selected_latents, int): - selected_latents = torch.tensor([selected_latents]) boundaries = [edges[0][0]] + [edge[1] + 1 for edge in edges] bucketized = torch.bucketize( diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index d1163a35..0a355a9c 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -6,7 +6,9 @@ import numpy as np import torch from safetensors.numpy import load_file +from torch import nn +from delphi.latents.latents import PreActivationRecord from delphi.latents.loader import LatentDataset @@ -21,7 +23,7 @@ class NeighbourCalculator: def __init__( self, latent_dataset: Optional["LatentDataset"] = None, - autoencoder: Optional["Autoencoder"] = None, + autoencoder: Optional["nn.Module"] = None, pre_activation_record: Optional["PreActivationRecord"] = None, number_of_neighbours: int = 10, neighbour_cache: Optional[dict[str, dict[int, list[int]]]] = None, diff --git a/delphi/sparse_coders/custom/gemmascope.py b/delphi/sparse_coders/custom/gemmascope.py index 1372ef4e..8a66f20c 100644 --- a/delphi/sparse_coders/custom/gemmascope.py +++ b/delphi/sparse_coders/custom/gemmascope.py @@ -14,7 +14,7 @@ def load_gemma_autoencoders( type: str, dtype: torch.dtype = torch.bfloat16, device: torch.device = torch.device("cuda"), -): +) -> dict[str, nn.Module]: saes = {} for layer, size, l0 in zip(ae_layers, sizes, average_l0s): diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py index 6da26ce9..7c58c621 100644 --- a/delphi/sparse_coders/load_sparsify.py +++ b/delphi/sparse_coders/load_sparsify.py @@ -46,7 +46,7 @@ def load_sparsify_sparse_coders( hookpoints: list[str], device: str | torch.device | None = None, compile: bool = False, -) -> dict[str, Callable]: +) -> dict[str, Sae]: """ Load sparsify sparse coders for specified hookpoints. diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py index 85f747cd..5b77feae 100644 --- a/delphi/sparse_coders/sparse_model.py +++ b/delphi/sparse_coders/sparse_model.py @@ -6,7 +6,7 @@ from delphi.config import RunConfig from .custom.gemmascope import load_gemma_autoencoders -from .load_sparsify import load_sparsify_hooks, load_sparsify_sparse_coders +from .load_sparsify import Sae, load_sparsify_hooks, load_sparsify_sparse_coders def load_hooks_sparse_coders( @@ -62,7 +62,9 @@ def load_hooks_sparse_coders( dtype=model.dtype, device=model.device, ) - + # throw an error if the dictionary is empty + if not hookpoint_to_sparse_encode: + raise ValueError("No sparse coders loaded") return hookpoint_to_sparse_encode @@ -70,7 +72,7 @@ def load_sparse_coders( model: PreTrainedModel, run_cfg: RunConfig, compile: bool = False, -) -> dict[str, nn.Module]: +) -> dict[str, nn.Module] | dict[str, Sae]: """ Load sparse coders for specified hookpoints. @@ -79,7 +81,7 @@ def load_sparse_coders( run_cfg (RunConfig): The run configuration. Returns: - dict[str, Callable]: A dictionary mapping hookpoints to sparse coders. + dict[str, nn.Module]: A dictionary mapping hookpoints to sparse coders. 
""" # Add SAE hooks to the model diff --git a/delphi/tests/conftest.py b/delphi/tests/conftest.py index e241c4ee..63151218 100644 --- a/delphi/tests/conftest.py +++ b/delphi/tests/conftest.py @@ -1,6 +1,15 @@ +from typing import cast + import pytest import torch -from transformers import AutoModelForCausalLM, AutoTokenizer +from torch import Tensor +from transformers import ( + AutoModel, + AutoTokenizer, + PreTrainedModel, + PreTrainedTokenizer, + PreTrainedTokenizerFast, +) from delphi.config import CacheConfig, RunConfig from delphi.latents import LatentCache @@ -24,30 +33,31 @@ @pytest.fixture(scope="module") -def tokenizer(): - tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-70m") +def tokenizer() -> PreTrainedTokenizer | PreTrainedTokenizerFast: + tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-160m") tokenizer.pad_token = tokenizer.eos_token return tokenizer @pytest.fixture(scope="module") -def model(): - model = AutoModelForCausalLM.from_pretrained("EleutherAI/pythia-70m") +def model() -> PreTrainedModel: + model = AutoModel.from_pretrained("EleutherAI/pythia-160m") return model @pytest.fixture(scope="module") -def mock_dataset(tokenizer: AutoTokenizer) -> torch.Tensor: +def mock_dataset(tokenizer: PreTrainedTokenizer) -> torch.Tensor: tokens = tokenizer( random_text, return_tensors="pt", truncation=True, max_length=16, padding=True )["input_ids"] + tokens = cast(Tensor, tokens) + print(tokens) + print(tokens.shape) return tokens @pytest.fixture(scope="module") -def cache_setup( - tmp_path_factory, mock_dataset: torch.Tensor, model: AutoModelForCausalLM -): +def cache_setup(tmp_path_factory, mock_dataset: torch.Tensor, model: PreTrainedModel): """ This fixture creates a temporary directory, loads the model, initializes the cache, runs the cache once, saves the cache splits @@ -58,12 +68,12 @@ def cache_setup( # Load model and set run configuration run_cfg_gemma = RunConfig( - model="EleutherAI/pythia-70m", - sparse_model="EleutherAI/sae-pythia-70m-32k", + model="EleutherAI/pythia-160m", + sparse_model="EleutherAI/sae-pythia-160m-32k", hookpoints=["layers.1"], ) hookpoint_to_sparse_encode = load_hooks_sparse_coders(model, run_cfg_gemma) - + print(hookpoint_to_sparse_encode) # Define cache config and initialize cache cache_cfg = CacheConfig(batch_size=1, ctx_len=16, n_tokens=100) cache = LatentCache( @@ -81,7 +91,6 @@ def cache_setup( # Save the cache config cache.save_config(temp_dir, cache_cfg, "EleutherAI/pythia-70m") - return { "cache": cache, "tokens": tokens, diff --git a/delphi/tests/test_latents/test_cache.py b/delphi/tests/test_latents/test_cache.py index 198f0cf3..3d60d094 100644 --- a/delphi/tests/test_latents/test_cache.py +++ b/delphi/tests/test_latents/test_cache.py @@ -13,7 +13,7 @@ def test_latent_locations(cache_setup: dict[str, Any]): shape and values. """ cache = cache_setup["cache"] - locations = cache.cache.latent_locations["gpt_neox.layers.1"] + locations = cache.cache.latent_locations["layers.1"] max_values, _ = locations.max(axis=0) # Expected values based on the cache run assert max_values[0] == 5, "Expected first dimension max value to be 5" @@ -25,7 +25,7 @@ def test_split_files_created(cache_setup: dict[str, Any]): """ Test that exactly 5 cache split files have been created. 
""" - save_dir = cache_setup["temp_dir"] / "gpt_neox.layers.1" + save_dir = cache_setup["temp_dir"] / "layers.1" cache_files = [f for f in os.listdir(save_dir) if f.endswith(".safetensors")] assert len(cache_files) == 5, "Expected 5 split files in the cache directory" @@ -37,7 +37,7 @@ def test_split_file_contents(cache_setup: dict[str, Any]): - tokens were correctly stored and match the input tokens. - latent max values are as expected. """ - save_dir = cache_setup["temp_dir"] / "gpt_neox.layers.1" + save_dir = cache_setup["temp_dir"] / "layers.1" tokens = cache_setup["tokens"] # Choose one file to verify cache_files = os.listdir(save_dir) @@ -67,7 +67,7 @@ def test_config_file(cache_setup: dict[str, Any]): """ Test that the saved configuration file contains the correct parameters. """ - config_path = cache_setup["temp_dir"] / "gpt_neox.layers.1" / "config.json" + config_path = cache_setup["temp_dir"] / "layers.1" / "config.json" with open(config_path, "r") as f: config = json.load(f) cache_cfg = cache_setup["cache_cfg"] From c0963b4dba77afa658637090b61eb0253c4065f4 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 13:56:54 +0000 Subject: [PATCH 097/132] torchtyping -> jaxtyping --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index ed9ae112..36d1b782 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ dependencies = [ "sparsify@git+https://github.com/EleutherAI/sparsify", "safetensors", "simple_parsing", - "torchtyping", + "jaxtyping", "fire", "blobfile", "bitsandbytes", From 813a189bdb162ed703186767f7409a3b82681c00 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 14:20:06 +0000 Subject: [PATCH 098/132] Reformulating neighbours --- delphi/latents/neighbours.py | 250 ++++++++++++++++++----------------- 1 file changed, 127 insertions(+), 123 deletions(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 99e3932e..b41725de 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -1,13 +1,14 @@ import json -from typing import Optional +import os +from pathlib import Path +from typing import Literal, Optional import numpy as np import torch from safetensors.numpy import load_file +from sparsify import Sae from torch import nn - -from delphi.latents.latents import PreActivationRecord -from delphi.latents.loader import LatentDataset +from tqdm import tqdm class NeighbourCalculator: @@ -20,9 +21,9 @@ class NeighbourCalculator: def __init__( self, - latent_dataset: Optional["LatentDataset"] = None, - autoencoder: Optional["nn.Module"] = None, - pre_activation_record: Optional["PreActivationRecord"] = None, + cache_dir: Optional[Path] = None, + autoencoder: Optional[nn.Module | Sae] = None, + # pre_activation_record: Optional["PreActivationRecord"] = None, number_of_neighbours: int = 10, neighbour_cache: Optional[dict[str, dict[int, list[tuple[int, float]]]]] = None, ): @@ -36,9 +37,9 @@ def __init__( residual_stream_record (Optional[ResidualStreamRecord]): Record of residual stream values """ - self.latent_dataset = latent_dataset + self.cache_dir = cache_dir self.autoencoder = autoencoder - self.residual_stream_record = residual_stream_record + # self.residual_stream_record = residual_stream_record self.number_of_neighbours = number_of_neighbours # load the neighbour cache from the path @@ -46,9 +47,12 @@ def __init__( self.neighbour_cache = neighbour_cache else: # dictionary to cache computed neighbour lists - self.neighbour_cache: 
dict[str, dict[int, list[int]]] = {} + self.neighbour_cache: dict[str, dict[int, list[tuple[int, float]]]] = {} - def _compute_neighbour_list(self, method: str) -> None: + def _compute_neighbour_list( + self, + method: Literal["similarity_encoder", "similarity_decoder", "co-occurrence"], + ) -> None: """ Compute complete neighbour lists using specified method. @@ -56,59 +60,53 @@ def _compute_neighbour_list(self, method: str) -> None: method (str): One of 'similarity', 'correlation', or 'co-occurrence' """ if method == "similarity_encoder": - if self.autoencoder is None: - raise ValueError( - "Autoencoder is required for similarity-based neighbours" - ) self.neighbour_cache[method] = self._compute_similarity_neighbours( "encoder" ) elif method == "similarity_decoder": - if self.autoencoder is None: - raise ValueError( - "Autoencoder is required for similarity-based neighbours" - ) self.neighbour_cache[method] = self._compute_similarity_neighbours( "decoder" ) - elif method == "correlation": - if self.autoencoder is None or self.residual_stream_record is None: - raise ValueError( - "Autoencoder and residual stream record are required " - "for correlation-based neighbours" - ) - self.neighbour_cache[method] = self._compute_correlation_neighbours() + # elif method == "correlation": + # if self.autoencoder is None or self.residual_stream_record is None: + # raise ValueError( + # "Autoencoder and residual stream record are required " + # "for correlation-based neighbours" + # ) + # self.neighbour_cache[method] = self._compute_correlation_neighbours() elif method == "co-occurrence": - if self.latent_dataset is None: - raise ValueError( - "Latent dataset is required for co-occurrence-based neighbours" - ) self.neighbour_cache[method] = self._compute_cooccurrence_neighbours() else: raise ValueError( - f"Unknown method: {method}. Use 'similarity', 'correlation', " - "or 'co-occurrence'" + f"Unknown method: {method}. Use 'similarity_encoder'," + "'similarity_decoder', or 'co-occurrence'" ) - def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: + def _compute_similarity_neighbours( + self, method: Literal["encoder", "decoder"] + ) -> dict[int, list[tuple[int, float]]]: """ Compute neighbour lists based on weight similarity in the autoencoder. """ + assert ( + self.autoencoder is not None + ), "Autoencoder is required for similarity-based neighbours" print("Computing similarity neighbours") # We use the encoder vectors to compute the similarity between latents if method == "encoder": - encoder = self.autoencoder.encoder.cuda() - weight_matrix_normalized = encoder.weight.data / encoder.weight.data.norm( - dim=1, keepdim=True - ) + encoder = self.autoencoder.encoder.weight.data.cuda() + weight_matrix_normalized = encoder / encoder.norm(dim=1, keepdim=True) elif method == "decoder": - decoder = self.autoencoder.W_dec.cuda() - weight_matrix_normalized = decoder.data / decoder.data.norm( - dim=1, keepdim=True - ) + # TODO: we would probably go around this by + # having a autoencoder wrapper + assert isinstance( + self.autoencoder, Sae + ), "Autoencoder must be a sparsify.Sae for decoder similarity" + decoder = self.autoencoder.W_dec.data.cuda() # type: ignore + weight_matrix_normalized = decoder / decoder.norm(dim=1, keepdim=True) else: raise ValueError(f"Unknown method: {method}. 
Use 'encoder' or 'decoder'") @@ -149,62 +147,7 @@ def _compute_similarity_neighbours(self, method: str) -> dict[int, list[int]]: return neighbour_lists - def _compute_correlation_neighbours(self) -> dict[int, list[int]]: - """ - Compute neighbour lists based on activation correlation patterns. - """ - print("Computing correlation neighbours") - - # the activation_matrix has the shape (number_of_samples,hidden_dimension) - - activations = torch.tensor( - load_file(self.residual_stream_record + ".safetensors")["activations"] - ) - - estimator = CovarianceEstimator(activations.shape[1]) - # batch the activations - batch_size = 10000 - for i in tqdm(range(0, activations.shape[0], batch_size)): - estimator.update(activations[i : i + batch_size]) - - covariance_matrix = estimator.cov().cuda().half() - - # load the encoder - encoder_matrix = self.autoencoder.encoder.weight.cuda().half() - - covariance_between_latents = torch.zeros( - (encoder_matrix.shape[0], encoder_matrix.shape[0]), device="cpu" - ) - - # do batches of latents - batch_size = 1024 - for start in tqdm(range(0, encoder_matrix.shape[0], batch_size)): - end = min(encoder_matrix.shape[0], start + batch_size) - encoder_rows = encoder_matrix[start:end] - - correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T - covariance_between_latents[start:end] = correlation.cpu() - - # the correlation is then the covariance divided - # by the product of the standard deviations - diagonal_covariance = torch.diag(covariance_between_latents) - product_of_std = torch.sqrt( - torch.outer(diagonal_covariance, diagonal_covariance) + 1e-6 - ) - correlation_matrix = covariance_between_latents / product_of_std - - # get the indices of the top k neighbours for each feature - indices, values = torch.topk( - correlation_matrix, self.number_of_neighbours + 1, dim=1 - ) - - # return the neighbour lists - return { - i: list(zip(indices[i].tolist()[1:], values[i].tolist()[1:])) - for i in range(len(indices)) - } - - def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: + def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]]: """ Compute neighbour lists based on feature co-occurrence in the dataset. 
If you run out of memory try reducing the token_batch_size @@ -215,9 +158,10 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[int]]: import cupyx.scipy.sparse as cusparse print("Computing co-occurrence neighbours") - paths = [] - for buffer in self.latent_dataset.buffers: - paths.append(buffer.path) + assert ( + self.cache_dir is not None + ), "Cache directory is required for co-occurrence-based neighbours" + paths = os.listdir(self.cache_dir) all_locations = [] for path in paths: @@ -299,7 +243,12 @@ def compute_jaccard(cooc_matrix): for i in range(len(top_k_indices)) } - def populate_neighbour_cache(self, methods: list[str]) -> None: + def populate_neighbour_cache( + self, + methods: list[ + Literal["similarity_encoder", "similarity_decoder", "co-occurrence"] + ], + ) -> None: """ Populate the neighbour cache with the computed neighbour lists """ @@ -321,25 +270,80 @@ def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: return json.load(f) -class CovarianceEstimator: - def __init__(self, n_latents, *, device=None): - self.mean = torch.zeros(n_latents, device=device) - self.cov_ = torch.zeros(n_latents, n_latents, device=device) - self.n = 0 - - def update(self, x: torch.Tensor): - n, d = x.shape - assert d == len(self.mean) - - self.n += n - - # Welford's online algorithm - delta = x - self.mean - self.mean.add_(delta.sum(dim=0), alpha=1 / self.n) - delta2 = x - self.mean - - self.cov_.addmm_(delta.mH, delta2) - - def cov(self): - """Return the estimated covariance matrix.""" - return self.cov_ / self.n +# TODO: add correlation neighbours, by re-adding activation records +# def _compute_correlation_neighbours(self) -> dict[int, list[int]]: +# """ +# Compute neighbour lists based on activation correlation patterns. 
+# """ +# print("Computing correlation neighbours") + +# # the activation_matrix has the shape (number_of_samples,hidden_dimension) + +# activations = torch.tensor( +# load_file(self.residual_stream_record + ".safetensors")["activations"] +# ) + +# estimator = CovarianceEstimator(activations.shape[1]) +# # batch the activations +# batch_size = 10000 +# for i in tqdm(range(0, activations.shape[0], batch_size)): +# estimator.update(activations[i : i + batch_size]) + +# covariance_matrix = estimator.cov().cuda().half() + +# # load the encoder +# encoder_matrix = self.autoencoder.encoder.weight.cuda().half() + +# covariance_between_latents = torch.zeros( +# (encoder_matrix.shape[0], encoder_matrix.shape[0]), device="cpu" +# ) + +# # do batches of latents +# batch_size = 1024 +# for start in tqdm(range(0, encoder_matrix.shape[0], batch_size)): +# end = min(encoder_matrix.shape[0], start + batch_size) +# encoder_rows = encoder_matrix[start:end] + +# correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T +# covariance_between_latents[start:end] = correlation.cpu() + +# # the correlation is then the covariance divided +# # by the product of the standard deviations +# diagonal_covariance = torch.diag(covariance_between_latents) +# product_of_std = torch.sqrt( +# torch.outer(diagonal_covariance, diagonal_covariance) + 1e-6 +# ) +# correlation_matrix = covariance_between_latents / product_of_std + +# # get the indices of the top k neighbours for each feature +# indices, values = torch.topk( +# correlation_matrix, self.number_of_neighbours + 1, dim=1 +# ) + +# # return the neighbour lists +# return { +# i: list(zip(indices[i].tolist()[1:], values[i].tolist()[1:])) +# for i in range(len(indices)) +# } +# class CovarianceEstimator: +# def __init__(self, n_latents, *, device=None): +# self.mean = torch.zeros(n_latents, device=device) +# self.cov_ = torch.zeros(n_latents, n_latents, device=device) +# self.n = 0 + +# def update(self, x: torch.Tensor): +# n, d = x.shape +# assert d == len(self.mean) + +# self.n += n + +# # Welford's online algorithm +# delta = x - self.mean +# self.mean.add_(delta.sum(dim=0), alpha=1 / self.n) +# delta2 = x - self.mean + +# self.cov_.addmm_(delta.mH, delta2) + +# def cov(self): +# """Return the estimated covariance matrix.""" +# return self.cov_ / self.n From f001b41ef885eed41d9374388cdb55f98a609f98 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 14:39:10 +0000 Subject: [PATCH 099/132] Type hints and simplification --- delphi/__main__.py | 40 +--------------------------------- delphi/latents/constructors.py | 37 +++++++++++++++---------------- delphi/latents/latents.py | 4 ++++ delphi/latents/loader.py | 18 ++++++++++----- 4 files changed, 35 insertions(+), 64 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index a39f178d..b77eda95 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -61,45 +61,14 @@ def load_artifacts(run_cfg: RunConfig): async def create_neighbours( - latent_cfg: LatentConfig, - experiment_cfg: ExperimentConfig, - hookpoints: list[str], latents_path: Path, neighbours_path: Path, - tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, - latent_range: Tensor | None, ): """ Creates a neighbours file for the given hookpoints. 
""" - if latent_range is None: - latent_dict = None - else: - latent_dict = { - hook: latent_range for hook in hookpoints - } # The latent range to explain - latent_dict = cast(dict[str, int | Tensor], latent_dict) - example_constructor = partial( - constructor, - n_not_active=experiment_cfg.n_non_activating, - constructor_type=experiment_cfg.non_activating_source, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=latent_cfg.max_examples, - ) - sampler = partial(sample, cfg=experiment_cfg) - - dataset = LatentDataset( - raw_dir=str(latents_path), - cfg=latent_cfg, - modules=hookpoints, - latents=latent_dict, - tokenizer=tokenizer, - constructor=example_constructor, - sampler=sampler, - ) - neighbour_calculator = NeighbourCalculator( - latent_dataset=dataset, number_of_neighbours=100 + cache_dir=latents_path, number_of_neighbours=100 ) neighbour_calculator.populate_neighbour_cache(["co-occurrence"]) @@ -357,16 +326,9 @@ async def run( not glob(str(neighbours_path / ".*")) + glob(str(neighbours_path / "*")) or "neighbours" in run_cfg.overwrite ): - # TODO: we probably want to use less arguments and load the latent dataset - # only once? await create_neighbours( - latent_cfg, - experiment_cfg, - hookpoints, latents_path, neighbours_path, - tokenizer, - latent_range, ) else: print(f"Files found in {neighbours_path}, skipping...") diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 9b4f4580..f27d43f3 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -24,7 +24,6 @@ def _top_k_pools( max_examples: The maximum number of examples. Returns: - Tuple[TensorType["examples", "ctx_len"], TensorType["examples", "ctx_len"]]: The token windows and activation windows. """ k = min(max_examples, len(max_buffer)) @@ -37,23 +36,23 @@ def _top_k_pools( def pool_max_activation_windows( - activations: TensorType["n_examples"], - tokens: TensorType["windows", "seq"], - ctx_indices: TensorType["n_examples"], - index_within_ctx: TensorType["n_examples"], + activations: Float[Tensor, "examples"], + tokens: Float[Tensor, "windows seq"], + ctx_indices: Float[Tensor, "examples"], + index_within_ctx: Float[Tensor, "examples"], ctx_len: int, max_examples: int, -): +) -> tuple[Float[Tensor, "examples ctx_len"], Float[Tensor, "examples ctx_len"]]: """ Pool max activation windows from the buffer output and update the latent record. Args: - activations (TensorType["n_examples"]): The activations. - tokens (TensorType["windows", "seq"]): The input tokens. - ctx_indices (TensorType["n_examples"]): The context indices. - index_within_ctx (TensorType["n_examples"]): The index within the context. - ctx_len (int): The context length. - max_examples (int): The maximum number of examples. + activations : The activations. + tokens : The input tokens. + ctx_indices : The context indices. + index_within_ctx : The index within the context. + ctx_len : The context length. + max_examples : The maximum number of examples. """ # unique_ctx_indices: array of distinct context window indices in order of first # appearance. 
sequential integers from 0 to batch_size * cache_token_length//ctx_len @@ -86,8 +85,8 @@ def constructor( n_not_active: int, max_examples: int, ctx_len: int, - constructor_type: Literal["random", "neighbour"], - tokens: TensorType["tokens"], + constructor_type: Literal["random", "neighbours"], + tokens: Float[Tensor, "batch seq"], all_data: Optional[dict[int, ActivationData]] = None, ): cache_token_length = tokens.shape[1] @@ -129,7 +128,7 @@ def constructor( reshaped_tokens=reshaped_tokens, n_not_active=n_not_active, ) - elif constructor_type == "neighbour": + elif constructor_type == "neighbours": assert all_data is not None, "All data is required for neighbour constructor" neighbour_non_activation_windows( record, @@ -143,8 +142,8 @@ def constructor( def neighbour_non_activation_windows( record: LatentRecord, - not_active_mask: TensorType["n_windows"], - tokens: TensorType["batch", "seq"], + not_active_mask: Float[Tensor, "windows"], + tokens: Float[Tensor, "batch seq"], all_data: dict[int, ActivationData], ctx_len: int, n_not_active: int, @@ -236,8 +235,8 @@ def neighbour_non_activation_windows( def random_non_activating_windows( record: LatentRecord, - available_indices: TensorType["n_windows"], - reshaped_tokens: TensorType["n_windows", "ctx_len"], + available_indices: Float[Tensor, "windows"], + reshaped_tokens: Float[Tensor, "windows ctx_len"], n_not_active: int, ): """ diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 5616e929..b104890b 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -13,6 +13,7 @@ class Neighbour: distance: float latent_index: int + @dataclass class Example: """ @@ -106,6 +107,9 @@ class LatentRecord: neighbours: list[Neighbour] = field(default_factory=list) """Neighbours of the latent.""" + explanation: str = "" + """Explanation of the latent.""" + @property def max_activation(self) -> float: """ diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index d19522eb..cb29d0e7 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -25,10 +25,10 @@ class ActivationData(NamedTuple): Represents the activation data for a latent. """ - locations: TensorType["locations", 2] + locations: Float[Tensor, "locations 2"] """Tensor of latent locations.""" - activations: TensorType["activations"] + activations: Float[Tensor, "activations"] """Tensor of latent activations.""" @@ -65,7 +65,7 @@ class TensorBuffer: min_examples: int = 120 """Minimum number of examples required. 
Defaults to 120.""" - _tokens: Optional[TensorType["tokens"]] = None + _tokens: Optional[Float[Tensor, "batch seq"]] = None """Tensor of tokens.""" def __iter__(self): @@ -91,7 +91,7 @@ def __iter__(self): ) @property - def tokens(self) -> TensorType["tokens"]: + def tokens(self) -> Float[Tensor, "batch seq"] | None: if self._tokens is None: self._tokens = self.load_tokens() return self._tokens @@ -110,7 +110,13 @@ def load_data_per_latent(self): return latents, split_locations, split_activations - def load(self): + def load( + self, + ) -> tuple[ + Float[Tensor, "locations 2"], + Float[Tensor, "activations"], + Float[Tensor, "batch seq"] | None, + ]: split_data = load_file(self.path) first_latent = int(self.path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) @@ -129,7 +135,7 @@ def load(self): return locations, activations, tokens - def load_tokens(self): + def load_tokens(self) -> Float[Tensor, "batch seq"] | None: _, _, tokens = self.load() return tokens From 5b0ea4e454dd5e48bb99071039fd0ce23382b8cc Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 14:41:54 +0000 Subject: [PATCH 100/132] Last torchtypings --- delphi/latents/samplers.py | 11 ++++------- delphi/utils.py | 5 +++-- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py index 4a99feb6..d3e685be 100644 --- a/delphi/latents/samplers.py +++ b/delphi/latents/samplers.py @@ -1,8 +1,6 @@ import random from collections import deque -from typing import Literal, cast - -from torchtyping import TensorType +from typing import Literal from ..config import ExperimentConfig from ..logger import logger @@ -121,10 +119,9 @@ def train( selected_examples = [] for quantile in selected_examples_quantiles: for example in quantile: - example.normalized_activations = cast( - TensorType["seq"], - (example.activations * 10 / max_activation).floor(), - ) + example.normalized_activations = ( + example.activations * 10 / max_activation + ).floor() selected_examples.extend(quantile) return selected_examples diff --git a/delphi/utils.py b/delphi/utils.py index cb2b56aa..6796f028 100644 --- a/delphi/utils.py +++ b/delphi/utils.py @@ -1,6 +1,7 @@ from typing import Any, Type, TypeVar, cast -from torchtyping import TensorType +from jaxtyping import Float +from torch import Tensor from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast @@ -28,7 +29,7 @@ def load_tokenized_data( ) tokens_ds = tokens_ds.shuffle(seed) - tokens = cast(TensorType["batch", "seq"], tokens_ds["input_ids"]) + tokens = cast(Float[Tensor, "batch seq"], tokens_ds["input_ids"]) return tokens From fa54ba4cdcb0ee72c7f924840fc2a5b247e00717 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 14:50:22 +0000 Subject: [PATCH 101/132] Use utils tokenized_data --- delphi/__main__.py | 18 +++++++++--------- delphi/utils.py | 38 +++++--------------------------------- 2 files changed, 14 insertions(+), 42 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index b77eda95..d5d3b38d 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -8,9 +8,7 @@ import orjson import torch -from datasets import load_dataset from simple_parsing import ArgumentParser -from sparsify.data import chunk_and_tokenize from torch import Tensor from transformers import ( AutoModel, @@ -33,6 +31,7 @@ from delphi.pipeline import Pipe, Pipeline, process_wrapper from delphi.scorers import DetectionScorer, FuzzingScorer from delphi.sparse_coders import 
load_hooks_sparse_coders +from delphi.utils import load_tokenized_data def load_artifacts(run_cfg: RunConfig): @@ -243,14 +242,15 @@ def populate_cache( """ latents_path.mkdir(parents=True, exist_ok=True) - data = load_dataset( - cfg.dataset_repo, name=cfg.dataset_name, split=cfg.dataset_split + tokens = load_tokenized_data( + cfg.ctx_len, + tokenizer, + cfg.dataset_repo, + cfg.dataset_split, + cfg.dataset_name, + cfg.dataset_column, + run_cfg.seed, ) - data = data.shuffle(run_cfg.seed) - data = chunk_and_tokenize( - data, tokenizer, max_seq_len=cfg.ctx_len, text_key=cfg.dataset_column - ) - tokens = data["input_ids"] if run_cfg.filter_bos: if tokenizer.bos_token_id is None: diff --git a/delphi/utils.py b/delphi/utils.py index 65e0ff85..b5831011 100644 --- a/delphi/utils.py +++ b/delphi/utils.py @@ -1,54 +1,26 @@ -from typing import Any, Type, TypeVar, cast - -from jaxtyping import Float -from torch import Tensor -from transformers import AutoTokenizer, PreTrainedTokenizer, PreTrainedTokenizerFast +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast def load_tokenized_data( ctx_len: int, - tokenizer: AutoTokenizer | PreTrainedTokenizer | PreTrainedTokenizerFast, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, dataset_repo: str, dataset_split: str, dataset_name: str = "", column_name: str = "text", seed: int = 22, - add_bos_token: bool = True, ): """ Load a huggingface dataset, tokenize it, and shuffle. + Using this function ensures we are using the same tokens everywhere. """ from datasets import load_dataset from sparsify.data import chunk_and_tokenize data = load_dataset(dataset_repo, name=dataset_name, split=dataset_split) + data = data.shuffle(seed) tokens_ds = chunk_and_tokenize( data, tokenizer, max_seq_len=ctx_len, text_key=column_name ) - tokens_ds = tokens_ds.shuffle(seed) - - tokens = cast(Float[Tensor, "batch seq"], tokens_ds["input_ids"]) - - return tokens - - -def load_filter(path: str, device: str = "cuda:0"): - import json - - import torch - - with open(path) as f: - filter = json.load(f) - - return {key: torch.tensor(value, device=device) for key, value in filter.items()} - - -T = TypeVar("T") - - -def assert_type(typ: Type[T], obj: Any) -> T: - """Assert that an object is of a given type at runtime and return it.""" - if not isinstance(obj, typ): - raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}") - return cast(typ, obj) + return tokens_ds From 5346b13aef67782b2a37f321e7ed2860879529d4 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 16:22:36 +0000 Subject: [PATCH 102/132] ground_truth->activating --- delphi/log/result_analysis.py | 16 +++++++--------- examples/server.py | 14 ++++++-------- 2 files changed, 13 insertions(+), 17 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 1cccf0a6..34b265a1 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -77,18 +77,16 @@ def latent_balanced_score_metrics( def parse_score_file(file_path): with open(file_path, "rb") as f: data = orjson.loads(f.read()) - df = pd.DataFrame( [ { "text": "".join(example["str_tokens"]), "distance": example["distance"], - "ground_truth": example["ground_truth"], + "activating": example["activating"], "prediction": example["prediction"], "probability": example["probability"], "correct": example["correct"], "activations": example["activations"], - "highlighted": example["highlighted"], } for example in data ] @@ -99,14 +97,14 @@ def parse_score_file(file_path): df = 
df[df["prediction"].notna()] df.reset_index(drop=True, inplace=True) total_examples = len(df) - total_positives = (df["ground_truth"]).sum() - total_negatives = (~df["ground_truth"]).sum() + total_positives = (df["activating"]).sum() + total_negatives = (~df["activating"]).sum() # Calculate confusion matrix elements - true_positives = ((df["prediction"] == 1) & (df["ground_truth"])).sum() - true_negatives = ((df["prediction"] == 0) & (~df["ground_truth"])).sum() - false_positives = ((df["prediction"] == 1) & (~df["ground_truth"])).sum() - false_negatives = ((df["prediction"] == 0) & (df["ground_truth"])).sum() + true_positives = ((df["prediction"] == 1) & (df["activating"])).sum() + true_negatives = ((df["prediction"] == 0) & (~df["activating"])).sum() + false_positives = ((df["prediction"] == 1) & (~df["activating"])).sum() + false_negatives = ((df["prediction"] == 0) & (df["activating"])).sum() # Calculate rates true_positive_rate = true_positives / total_positives if total_positives > 0 else 0 diff --git a/examples/server.py b/examples/server.py index 96296f44..d06f1cde 100644 --- a/examples/server.py +++ b/examples/server.py @@ -17,18 +17,16 @@ def calculate_balanced_accuracy(dataframe): tp = len( - dataframe[(dataframe["ground_truth"] is True) & (dataframe["correct"] is True)] + dataframe[(dataframe["activating"] is True) & (dataframe["correct"] is True)] ) tn = len( - dataframe[(dataframe["ground_truth"] is False) & (dataframe["correct"] is True)] + dataframe[(dataframe["activating"] is False) & (dataframe["correct"] is True)] ) fp = len( - dataframe[ - (dataframe["ground_truth"] is False) & (dataframe["correct"] is False) - ] + dataframe[(dataframe["activating"] is False) & (dataframe["correct"] is False)] ) fn = len( - dataframe[(dataframe["ground_truth"] is True) & (dataframe["correct"] is False)] + dataframe[(dataframe["activating"] is True) & (dataframe["correct"] is False)] ) if tp + fn == 0: recall = 0 @@ -52,9 +50,9 @@ def per_latent_scores_fuzz_detection(score_data): def per_latent_scores_embedding(score_data): data_df = pd.DataFrame(score_data) - data_df["ground_truth"] = data_df["distance"] > 0 + data_df["activating"] = data_df["distance"] > 0 print(data_df) - auc_score = roc_auc_score(data_df["ground_truth"], data_df["similarity"]) + auc_score = roc_auc_score(data_df["activating"], data_df["similarity"]) return auc_score From 5b54f63d339978591025c5c137bd2912f0308c8e Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 16:24:08 +0000 Subject: [PATCH 103/132] (Non)Activating examples as children of Examples --- delphi/latents/latents.py | 39 ++++++++------- delphi/scorers/classifier/detection.py | 47 ++++++------------ delphi/scorers/classifier/fuzz.py | 67 +++++++++----------------- delphi/scorers/classifier/sample.py | 31 +++++++----- 4 files changed, 76 insertions(+), 108 deletions(-) diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index b104890b..458295c3 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -40,24 +40,27 @@ def max_activation(self) -> float: return float(self.activations.max()) -def prepare_examples( - tokens: Float[Tensor, "examples ctx_len"], - activations: Float[Tensor, "examples ctx_len"], -) -> list[Example]: +@dataclass +class ActivatingExample(Example): + """ + An example of a latent that activates a model. """ - Prepare a list of examples from input tokens and activations. - Args: - tokens: Tokenized input sequences. - activations: Activation values for the input sequences. 
+ quantile: int = 0 + """The quantile of the activating example.""" + + +@dataclass +class NonActivatingExample(Example): + """ + An example of a latent that does not activate a model. + """ - Returns: - list[Example]: A list of prepared examples. + distance: float = 0.0 + """ + The distance from the neighbouring latent. + Defaults to -1.0 if not using neighbours. """ - return [ - Example(tokens=toks, activations=acts, normalized_activations=None) - for toks, acts in zip(tokens, activations) - ] @dataclass @@ -91,17 +94,17 @@ class LatentRecord: latent: Latent """The latent associated with the record.""" - examples: list[Example] = field(default_factory=list) + examples: list[ActivatingExample] = field(default_factory=list) """Example sequences where the latent activations, assumed to be sorted in descending order by max activation.""" - not_active: list[Example] = field(default_factory=list) + not_active: list[NonActivatingExample] = field(default_factory=list) """Non-activating examples.""" - train: list[Example] = field(default_factory=list) + train: list[ActivatingExample] = field(default_factory=list) """Training examples.""" - test: list[list[Example]] = field(default_factory=list) + test: list[ActivatingExample] = field(default_factory=list) """Test examples.""" neighbours: list[Neighbour] = field(default_factory=list) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index d3b04e3f..ffb84340 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -1,7 +1,7 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from ...clients.client import Client -from ...latents import LatentRecord, Example +from ...latents import LatentRecord from .classifier import Classifier from .prompts.detection_prompt import prompt from .sample import Sample, examples_to_samples @@ -43,46 +43,29 @@ def __init__( **generation_kwargs, ) - self.prompt = prompt + def prompt(self, examples: str, explanation: str) -> list[dict]: + return prompt(examples, explanation) - def _prepare(self, record: LatentRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[Sample]: """ Prepare and shuffle a list of samples for classification. 
""" # check if not_active is a list of lists or a list of examples if len(record.not_active) > 0: - if isinstance(record.not_active[0], list): - # Here we are using neighbours - samples = [] - for i, examples in enumerate(record.not_active): - samples.extend( - examples_to_samples( - examples, - distance=-record.neighbours[i].distance, - ground_truth=False, - tokenizer=self.tokenizer, - ) - ) - elif isinstance(record.not_active[0], Example): - # This is if we dont use neighbours - samples = examples_to_samples( - record.not_active, - distance=-1, - ground_truth=False, - tokenizer=self.tokenizer, - ) + samples = examples_to_samples( + record.not_active, + tokenizer=self.tokenizer, + ) + else: - samples = [] + samples = [] - for i, examples in enumerate(record.test): - samples.extend( - examples_to_samples( - examples, - distance=i + 1, - ground_truth=True, - tokenizer=self.tokenizer, - ) + samples.extend( + examples_to_samples( + record.test, + tokenizer=self.tokenizer, ) + ) return samples diff --git a/delphi/scorers/classifier/fuzz.py b/delphi/scorers/classifier/fuzz.py index 3924279e..833952f4 100644 --- a/delphi/scorers/classifier/fuzz.py +++ b/delphi/scorers/classifier/fuzz.py @@ -4,8 +4,8 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from ...clients.client import Client -from ...latents.latents import Example from ...latents import LatentRecord +from ...latents.latents import ActivatingExample from ..scorer import Scorer from .classifier import Classifier from .prompts.fuzz_prompt import prompt @@ -50,9 +50,11 @@ def __init__( ) self.threshold = threshold - self.prompt = prompt - def mean_n_activations_ceil(self, examples: list[Example]): + def prompt(self, examples: str, explanation: str) -> list[dict]: + return prompt(examples, explanation) + + def mean_n_activations_ceil(self, examples: list[ActivatingExample]): """ Calculate the ceiling of the average number of activations in each example. """ @@ -62,56 +64,31 @@ def mean_n_activations_ceil(self, examples: list[Example]): return ceil(avg) - def _prepare(self, record: LatentRecord) -> list[list[Sample]]: + def _prepare(self, record: LatentRecord) -> list[Sample]: """ Prepare and shuffle a list of samples for classification. 
""" - assert len(record.test) > 0 and len(record.test[0]) > 0, "No test records found" + assert len(record.test) > 0, "No test records found" - defaults = { - "highlighted": True, - "tokenizer": self.tokenizer, - } - all_examples = [] - for examples_chunk in record.test: - all_examples.extend(examples_chunk) + n_incorrect = self.mean_n_activations_ceil(record.test) - n_incorrect = self.mean_n_activations_ceil(all_examples) if len(record.not_active) > 0: - if isinstance(record.not_active[0], list): - # Here we are using neighbours - samples = [] - for i, examples in enumerate(record.not_active): - samples.extend( - examples_to_samples( - examples, - distance=-record.neighbours[i].distance, - ground_truth=False, - n_incorrect=n_incorrect, - **defaults, - ) - ) - elif isinstance(record.not_active[0], Example): - # This is if we dont use neighbours - samples = examples_to_samples( - record.not_active, - distance=-1, - ground_truth=False, - n_incorrect=n_incorrect, - **defaults, - ) + samples = examples_to_samples( + record.not_active, + tokenizer=self.tokenizer, + n_incorrect=n_incorrect, + highlighted=True, + ) + else: samples = [] - for i, examples in enumerate(record.test): - samples.extend( - examples_to_samples( - examples, - distance=i + 1, - ground_truth=True, - n_incorrect=0, - **defaults, - ) + samples.extend( + examples_to_samples( + record.test, + tokenizer=self.tokenizer, + n_incorrect=0, + highlighted=True, ) - + ) return samples diff --git a/delphi/scorers/classifier/sample.py b/delphi/scorers/classifier/sample.py index 11020ceb..4703508b 100644 --- a/delphi/scorers/classifier/sample.py +++ b/delphi/scorers/classifier/sample.py @@ -3,9 +3,9 @@ from typing import NamedTuple import torch -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast -from ...latents import Example +from ...latents import ActivatingExample, NonActivatingExample from ...logger import logger L = "<<" @@ -26,15 +26,12 @@ class ClassifierOutput: distance: float | int """Quantile or neighbor distance""" - ground_truth: bool + activating: bool """Whether the example is activating or not""" prediction: bool | None = False """Whether the model predicted the example activating or not""" - highlighted: bool = False - """Whether the sample is highlighted""" - probability: float | None = 0.0 """The probability of the example activating""" @@ -48,8 +45,8 @@ class Sample(NamedTuple): def examples_to_samples( - examples: list[Example], - tokenizer: PreTrainedTokenizer, + examples: list[ActivatingExample] | list[NonActivatingExample], + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, n_incorrect: int = 0, threshold: float = 0.3, highlighted: bool = False, @@ -61,6 +58,13 @@ def examples_to_samples( text, str_toks = _prepare_text( example, tokenizer, n_incorrect, threshold, highlighted ) + match example: + case ActivatingExample(): + activating = True + distance = example.quantile + case NonActivatingExample(): + activating = False + distance = example.distance samples.append( Sample( @@ -68,7 +72,8 @@ def examples_to_samples( data=ClassifierOutput( str_tokens=str_toks, activations=example.activations.tolist(), - highlighted=highlighted, + activating=activating, + distance=distance, **sample_kwargs, ), ) @@ -82,11 +87,11 @@ def examples_to_samples( def _prepare_text( example, - tokenizer: PreTrainedTokenizer, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, n_incorrect: int, threshold: float, highlighted: bool, -): +) -> tuple[str, 
list[str]]: if ( tokenizer is None ): # If we don't have a tokenizer, we assume the tokens are already strings @@ -104,10 +109,10 @@ def _prepare_text( # if correct example if n_incorrect == 0: - def check(i): + def threshold_check(i): return example.activations[i] >= threshold - return _highlight(str_toks, check), str_toks + return _highlight(str_toks, threshold_check), str_toks # Highlight n_incorrect tokens with activations # below threshold if incorrect example From 97d2dbc88d140deba3477749cdba4526e9b502ad Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 16:25:02 +0000 Subject: [PATCH 104/132] Mostly type hints --- delphi/latents/samplers.py | 113 +++++------------- .../classifier/prompts/detection_prompt.py | 2 +- delphi/utils.py | 9 +- 3 files changed, 35 insertions(+), 89 deletions(-) diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py index d3e685be..85cbc18a 100644 --- a/delphi/latents/samplers.py +++ b/delphi/latents/samplers.py @@ -1,59 +1,24 @@ import random -from collections import deque from typing import Literal from ..config import ExperimentConfig from ..logger import logger -from .latents import Example, LatentRecord +from .latents import ActivatingExample, LatentRecord -def split_activation_quantiles( - examples: list[Example], n_quantiles: int, n_samples: int, seed: int = 22 -): - """ - TODO review this, there is a possible bug here: - `examples[0].max_activation < threshold` - - Split the examples into n_quantiles and sample n_samples from each quantile. - - Args: - examples: list of Examples, assumed to be in descending sorted order - by max_activation - n_quantiles: number of quantiles to split the examples into - n_samples: number of samples to sample from each quantile - seed: seed for the random number generator - - Returns: - list of lists of Examples. - Each inner list contains n_samples from a unique quantile. - """ - random.seed(seed) - - queue_examples = deque(examples) - max_activation = examples[0].max_activation - - # For 4 quantiles, thresholds are 0.25, 0.5, 0.75 - thresholds = [max_activation * i / n_quantiles for i in range(1, n_quantiles)] - - samples: list[list[Example]] = [] - for threshold in thresholds: - # Get all examples in quantile - quantile = [] - while queue_examples and queue_examples[0].max_activation < threshold: - quantile.append(queue_examples.popleft()) - - sample = random.sample(quantile, n_samples) - samples.append(sample) - - sample = random.sample(examples, n_samples) - samples.append(sample) - - return samples +def normalize_activations( + examples: list[ActivatingExample], max_activation: float + ) -> list[ActivatingExample]: + for example in examples: + example.normalized_activations = ( + example.activations * 10 / max_activation + ).floor() + return examples def split_quantiles( - examples: list[Example], n_quantiles: int, n_samples: int, seed: int = 22 -): + examples: list[ActivatingExample], n_quantiles: int, n_samples: int, seed: int = 22 +) -> list[ActivatingExample]: """ Randomly select (n_samples // n_quantiles) samples from each quantile. """ @@ -61,7 +26,7 @@ def split_quantiles( quantile_size = len(examples) // n_quantiles samples_per_quantile = n_samples // n_quantiles - samples: list[list[Example]] = [] + samples: list[ActivatingExample] = [] for i in range(n_quantiles): # Take an evenly spaced slice of the examples for the quantile. 
quantile = examples[i * quantile_size : (i + 1) * quantile_size] @@ -74,13 +39,16 @@ def split_quantiles( ) else: sample = random.sample(quantile, samples_per_quantile) - samples.append(sample) + # set the quantile index + for example in sample: + example.quantile = i + samples.extend(sample) return samples def train( - examples: list[Example], + examples: list[ActivatingExample], max_activation: float, n_train: int, train_type: Literal["top", "random", "quantiles"], @@ -90,44 +58,29 @@ def train( match train_type: case "top": selected_examples = examples[:n_train] - for example in selected_examples: - example.normalized_activations = ( - example.activations * 10 / max_activation - ).floor() + selected_examples = normalize_activations(selected_examples, max_activation) return selected_examples case "random": random.seed(seed) - if n_train > len(examples): + n_sample = min(n_train, len(examples)) + if n_sample < n_train: logger.warning( "n_train is greater than the number of examples, using all examples" ) - for example in examples: - example.normalized_activations = ( - example.activations * 10 / max_activation - ).floor() - return examples + selected_examples = random.sample(examples, n_train) - for example in selected_examples: - example.normalized_activations = ( - example.activations * 10 / max_activation - ).floor() + selected_examples = normalize_activations(selected_examples, max_activation) return selected_examples case "quantiles": - selected_examples_quantiles = split_quantiles( + selected_examples = split_quantiles( examples, n_quantiles, n_train ) - selected_examples = [] - for quantile in selected_examples_quantiles: - for example in quantile: - example.normalized_activations = ( - example.activations * 10 / max_activation - ).floor() - selected_examples.extend(quantile) + selected_examples = normalize_activations(selected_examples, max_activation) return selected_examples def test( - examples: list[Example], + examples: list[ActivatingExample], max_activation: float, n_test: int, n_quantiles: int, @@ -136,22 +89,10 @@ def test( match test_type: case "quantiles": selected_examples = split_quantiles(examples, n_quantiles, n_test) - for quantile in selected_examples: - for example in quantile: - example.normalized_activations = ( - example.activations * 10 / max_activation - ).floor() + selected_examples = normalize_activations(selected_examples, max_activation) return selected_examples case "activation": - selected_examples = split_activation_quantiles( - examples, n_quantiles, n_test - ) - for quantile in selected_examples: - for example in quantile: - example.normalized_activations = ( - example.activations * 10 / max_activation - ).floor() - return selected_examples + raise NotImplementedError("Activation sampling not implemented") def sample( diff --git a/delphi/scorers/classifier/prompts/detection_prompt.py b/delphi/scorers/classifier/prompts/detection_prompt.py index be5699b1..75b82478 100644 --- a/delphi/scorers/classifier/prompts/detection_prompt.py +++ b/delphi/scorers/classifier/prompts/detection_prompt.py @@ -66,7 +66,7 @@ ] -def prompt(examples, explanation): +def prompt(examples: str, explanation: str) -> list[dict]: generation_prompt = GENERATION_PROMPT.format( explanation=explanation, examples=examples ) diff --git a/delphi/utils.py b/delphi/utils.py index b5831011..ac0f88db 100644 --- a/delphi/utils.py +++ b/delphi/utils.py @@ -20,7 +20,12 @@ def load_tokenized_data( data = load_dataset(dataset_repo, name=dataset_name, split=dataset_split) data = 
data.shuffle(seed) tokens_ds = chunk_and_tokenize( -        data, tokenizer, max_seq_len=ctx_len, text_key=column_name +        data,  # type: ignore +        tokenizer, +        max_seq_len=ctx_len, +        text_key=column_name, ) -    return tokens_ds + +    tokens = tokens_ds["input_ids"] + +    return tokens From b9be963888aa15918a0b49b62e292674b6be7d04 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 16:26:06 +0000 Subject: [PATCH 105/132] Mostly type hints --- delphi/latents/loader.py | 6 ++--- delphi/latents/transforms.py | 10 +++++---- delphi/scorers/classifier/classifier.py | 30 ++++++++++++++++--------- 3 files changed, 29 insertions(+), 17 deletions(-) diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index cb29d0e7..bb686a70 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -25,10 +25,10 @@ class ActivationData(NamedTuple): Represents the activation data for a latent. """ -    locations: Float[Tensor, "locations 2"] +    locations: Float[Tensor, "n_examples 2"] """Tensor of latent locations.""" -    activations: Float[Tensor, "activations"] +    activations: Float[Tensor, "n_examples"] """Tensor of latent activations.""" @@ -193,7 +193,7 @@ def __init__( # TODO: is it possible to do this without loading all data? if self.constructor is not None: -            if self.constructor.keywords["constructor_type"] == "neighbour": +            if self.constructor.keywords["constructor_type"] == "neighbours": self.all_data = self._load_all_data(raw_dir, self.modules) self.load_tokens() diff --git a/delphi/latents/transforms.py b/delphi/latents/transforms.py index 3cb61c3b..3f4f88e9 100644 --- a/delphi/latents/transforms.py +++ b/delphi/latents/transforms.py @@ -3,20 +3,22 @@ def set_neighbours( record: LatentRecord, -    neighbours: dict[int, list[tuple[float, int]]], +    neighbours: dict[str, list[tuple[float, int]]], threshold: float, ): """ Set the neighbours for the latent record.
""" - neighbours = neighbours[str(record.latent.latent_index)] + latent_neighbours = neighbours[str(record.latent.latent_index)] # Each element in neighbours is a tuple of (distance,feature_index) # We want to keep only the ones with a distance less than the threshold - neighbours = [neighbour for neighbour in neighbours if neighbour[0] > threshold] + latent_neighbours = [ + neighbour for neighbour in latent_neighbours if neighbour[0] > threshold + ] record.neighbours = [ Neighbour(distance=neighbour[0], latent_index=neighbour[1]) - for neighbour in neighbours + for neighbour in latent_neighbours ] diff --git a/delphi/scorers/classifier/classifier.py b/delphi/scorers/classifier/classifier.py index a279eb54..b66733d2 100644 --- a/delphi/scorers/classifier/classifier.py +++ b/delphi/scorers/classifier/classifier.py @@ -5,7 +5,7 @@ from abc import abstractmethod import numpy as np -from transformers import PreTrainedTokenizer +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from ...clients.client import Client from ...latents import LatentRecord @@ -13,11 +13,12 @@ from ..scorer import Scorer, ScorerResult from .sample import ClassifierOutput, Sample + class Classifier(Scorer): def __init__( self, client: Client, - tokenizer: PreTrainedTokenizer, + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, verbose: bool, n_examples_shown: int, log_prob: bool, @@ -46,10 +47,10 @@ def __init__( async def __call__( self, record: LatentRecord, - ) -> list[ClassifierOutput]: + ) -> ScorerResult: samples = self._prepare(record) random.shuffle(samples) - + samples = self._batch(samples) results = await self._query( record.explanation, @@ -116,14 +117,18 @@ async def _generate( result = sample.data result.prediction = prediction if prediction is not None: - result.correct = prediction == result.ground_truth + result.correct = prediction == result.activating else: result.correct = None result.probability = probability results.append(result) if self.verbose: - result.text = sample.text + logger.info( + f"Example {sample.text}, " + f"Prediction: {prediction}, " + f"Probability: {probability}" + ) return results def _parse(self, string, logprobs=None): @@ -132,8 +137,9 @@ def _parse(self, string, logprobs=None): # Matches the first instance of text enclosed in square brackets pattern = r"\[.*?\]" match = re.search(pattern, string) - - predictions: list[int] = json.loads(match.group(0)) + if match is None: + raise ValueError("No match found in string") + predictions: list[bool] = json.loads(match.group(0)) assert len(predictions) == self.n_examples_shown probabilities = ( self._parse_logprobs(logprobs) @@ -184,7 +190,7 @@ def _build_prompt( self, explanation: str, batch: list[Sample], - ) -> str: + ) -> list[dict]: """ Prepare prompt for generation. 
""" @@ -195,11 +201,15 @@ def _build_prompt( return self.prompt(explanation=explanation, examples=examples) + @abstractmethod + def prompt(self, examples: str, explanation: str) -> list[dict]: + pass + def _batch(self, samples): return [ samples[i : i + self.n_examples_shown] for i in range(0, len(samples), self.n_examples_shown) ] - def call_sync(self, record: LatentRecord) -> list[ClassifierOutput]: + def call_sync(self, record: LatentRecord) -> ScorerResult: return asyncio.run(self.__call__(record)) From 21dc1012140106cb74b2c7f76ad80f8746376b3f Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 16:26:37 +0000 Subject: [PATCH 106/132] Using ActivatingExample --- delphi/explainers/default/default.py | 4 +- delphi/explainers/explainer.py | 7 +--- delphi/latents/__init__.py | 10 ++++- delphi/latents/constructors.py | 58 +++++++++++++++++++++++----- 4 files changed, 61 insertions(+), 18 deletions(-) diff --git a/delphi/explainers/default/default.py b/delphi/explainers/default/default.py index c1419e70..5171a75d 100644 --- a/delphi/explainers/default/default.py +++ b/delphi/explainers/default/default.py @@ -1,6 +1,6 @@ import asyncio -from ..explainer import Example, Explainer +from ..explainer import ActivatingExample, Explainer from .prompt_builder import build_prompt @@ -29,7 +29,7 @@ def __init__( **generation_kwargs, ) - def _build_prompt(self, examples: list[Example]) -> list[dict]: + def _build_prompt(self, examples: list[ActivatingExample]) -> list[dict]: highlighted_examples = [] for i, example in enumerate(examples): diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index fcd6abd6..67fb6840 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -4,13 +4,10 @@ import re from abc import ABC, abstractmethod from typing import NamedTuple -import re import aiofiles -from ..logger import logger -from ..latents.latents import LatentRecord -from ..latents.latents import Example, LatentRecord +from ..latents.latents import ActivatingExample, LatentRecord from ..logger import logger @@ -123,7 +120,7 @@ def _join_activations( return "Activations: " + acts @abstractmethod - def _build_prompt(self, examples: list[Example]) -> list[dict]: + def _build_prompt(self, examples: list[ActivatingExample]) -> list[dict]: pass diff --git a/delphi/latents/__init__.py b/delphi/latents/__init__.py index 760e56a1..b41d2026 100644 --- a/delphi/latents/__init__.py +++ b/delphi/latents/__init__.py @@ -5,7 +5,13 @@ pool_max_activation_windows, random_non_activating_windows, ) -from .latents import Example, Latent, LatentRecord +from .latents import ( + ActivatingExample, + Example, + Latent, + LatentRecord, + NonActivatingExample, +) from .loader import LatentDataset from .samplers import sample from .stats import unigram @@ -16,6 +22,8 @@ "Latent", "LatentRecord", "Example", + "ActivatingExample", + "NonActivatingExample", "pool_max_activation_windows", "random_non_activating_windows", "neighbour_non_activation_windows", diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index f27d43f3..56946399 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -4,10 +4,52 @@ from jaxtyping import Float from torch import Tensor -from .latents import LatentRecord, prepare_examples +from .latents import ActivatingExample, LatentRecord, NonActivatingExample from .loader import ActivationData +def prepare_activating_examples( + tokens: Float[Tensor, "examples ctx_len"], + activations: Float[Tensor, "examples 
ctx_len"], +) -> list[ActivatingExample]: + """ + Prepare a list of examples from input tokens and activations. + + Args: + tokens: Tokenized input sequences. + activations: Activation values for the input sequences. + + Returns: + list[Example]: A list of prepared examples. + """ + return [ + ActivatingExample(tokens=toks, activations=acts, normalized_activations=None) + for toks, acts in zip(tokens, activations) + ] + + +def prepare_non_activating_examples( + tokens: Float[Tensor, "examples ctx_len"], + distance: float, +) -> list[NonActivatingExample]: + """ + Prepare a list of non-activating examples from input tokens and distance. + + Args: + tokens: Tokenized input sequences. + distance: The distance from the neighbouring latent. + """ + return [ + NonActivatingExample( + tokens=toks, + activations=torch.zeros_like(toks), + normalized_activations=None, + distance=distance, + ) + for toks in tokens + ] + + def _top_k_pools( max_buffer: Float[Tensor, "batch"], split_activations: Float[Tensor, "activations ctx_len"], @@ -118,7 +160,7 @@ def constructor( ctx_len=ctx_len, max_examples=max_examples, ) - record.examples = prepare_examples(token_windows, act_windows) + record.examples = prepare_activating_examples(token_windows, act_windows) if constructor_type == "random": # Add random non-activating examples to the record in place @@ -177,7 +219,6 @@ def neighbour_non_activation_windows( number_examples = 0 all_examples = [] - used_neighbours = [] for neighbour in record.neighbours: if number_examples >= n_not_active: break @@ -221,13 +262,10 @@ def neighbour_non_activation_windows( # use the first n_examples_per_neighbour examples, # which will be the most active examples examples_used = len(token_windows) - all_examples.append( - prepare_examples(token_windows, torch.zeros_like(token_windows)) + all_examples.extend( + prepare_non_activating_examples(token_windows, neighbour.distance) ) - used_neighbours.append(neighbour) number_examples += examples_used - # We change neighbours in place to be the list of neighbours used - record.neighbours = used_neighbours if len(all_examples) == 0: print("No examples found") record.not_active = all_examples @@ -269,7 +307,7 @@ def random_non_activating_windows( toks = reshaped_tokens[selected_indices] - record.not_active = prepare_examples( + record.not_active = prepare_non_activating_examples( toks, - torch.zeros_like(toks), + -1.0, ) From c898f160694c64c8ed1ffe0f94069dc34d092b3c Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 16:47:33 +0000 Subject: [PATCH 107/132] Better names for things --- delphi/latents/constructors.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 56946399..2825249a 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -147,8 +147,8 @@ def constructor( mask = torch.ones(n_windows, dtype=torch.bool) mask[unique_batch_pos] = False - # Indices where the latent is active - active_indices = mask.nonzero(as_tuple=False).squeeze() + # Indices where the latent is not active + non_active_indices = mask.nonzero(as_tuple=False).squeeze() activations = activation_data.activations # Add activation examples to the record in place @@ -166,7 +166,7 @@ def constructor( # Add random non-activating examples to the record in place random_non_activating_windows( record, - available_indices=active_indices, + available_indices=non_active_indices, reshaped_tokens=reshaped_tokens, n_not_active=n_not_active, ) @@ 
-251,7 +251,7 @@ def neighbour_non_activation_windows( if activations.numel() == 0: print(f"No available indices for neighbour {neighbour.latent_index}") continue - token_windows, act_windows = pool_max_activation_windows( + token_windows, _ = pool_max_activation_windows( activations=activations, tokens=reshaped_tokens, ctx_indices=available_ctx_indices, From eb11854732b2b0c3582f0a934f17accd6bd1e603 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 17:07:01 +0000 Subject: [PATCH 108/132] Ruff stuff --- delphi/explainers/contrastive_explainer.py | 12 ++-- delphi/explainers/single_token_explainer.py | 5 +- delphi/scorers/embedding/embedding.py | 1 + examples/caching_activations.ipynb | 25 ++++++--- examples/generate_explanations.ipynb | 15 ++--- examples/latent_contexts.ipynb | 62 ++++++++++++++------- examples/score_explanations.ipynb | 13 +++-- 7 files changed, 81 insertions(+), 52 deletions(-) diff --git a/delphi/explainers/contrastive_explainer.py b/delphi/explainers/contrastive_explainer.py index a2f68706..9245537a 100644 --- a/delphi/explainers/contrastive_explainer.py +++ b/delphi/explainers/contrastive_explainer.py @@ -1,12 +1,12 @@ -import re import asyncio import faiss -from delphi.explainers.explainer import Explainer, ExplainerResult from delphi.explainers.default.prompt_builder import build_single_token_prompt +from delphi.explainers.explainer import Explainer, ExplainerResult from delphi.logger import logger + class ContrastiveExplainer(Explainer): name = "contrastive" @@ -19,7 +19,7 @@ def __init__( activations: bool = False, cot: bool = False, threshold: float = 0.6, - temperature: float = 0., + temperature: float = 0.0, **generation_kwargs, ): self.client = client @@ -37,7 +37,7 @@ async def __call__(self, record): # Need to change __call__ to use index messages = self._build_prompt(record.train) - + response = await self.client.generate( messages, temperature=self.temperature, **self.generation_kwargs ) @@ -54,7 +54,9 @@ async def __call__(self, record): return ExplainerResult(record=record, explanation=explanation) except Exception as e: logger.error(f"Explanation parsing failed: {e}") - return ExplainerResult(record=record, explanation="Explanation could not be parsed.") + return ExplainerResult( + record=record, explanation="Explanation could not be parsed." 
+ ) def _build_prompt(self, examples): highlighted_examples = [] diff --git a/delphi/explainers/single_token_explainer.py b/delphi/explainers/single_token_explainer.py index 638c6517..ffef2832 100644 --- a/delphi/explainers/single_token_explainer.py +++ b/delphi/explainers/single_token_explainer.py @@ -1,9 +1,7 @@ import asyncio -import re from delphi.explainers.default.prompt_builder import build_single_token_prompt -from delphi.explainers.explainer import Explainer, ExplainerResult -from delphi.logger import logger +from delphi.explainers.explainer import Explainer class SingleTokenExplainer(Explainer): @@ -30,7 +28,6 @@ def __init__( self.temperature = temperature self.generation_kwargs = generation_kwargs - def _build_prompt(self, examples): highlighted_examples = [] diff --git a/delphi/scorers/embedding/embedding.py b/delphi/scorers/embedding/embedding.py index 3c9e26eb..677f12a7 100644 --- a/delphi/scorers/embedding/embedding.py +++ b/delphi/scorers/embedding/embedding.py @@ -4,6 +4,7 @@ from typing import NamedTuple from transformers import PreTrainedTokenizer + from ...latents import Example, LatentRecord from ..scorer import Scorer, ScorerResult diff --git a/examples/caching_activations.ipynb b/examples/caching_activations.ipynb index 4b2969af..02f574f9 100644 --- a/examples/caching_activations.ipynb +++ b/examples/caching_activations.ipynb @@ -36,8 +36,8 @@ "source": [ "from transformers import AutoModel\n", "\n", - "from delphi.sparse_coders import load_hooks_sparse_coders\n", - "from delphi.config import RunConfig\n" + "from delphi.config import RunConfig\n", + "from delphi.sparse_coders import load_hooks_sparse_coders\n" ] }, { @@ -76,11 +76,15 @@ ], "source": [ "# Load the model\n", - "model = AutoModel.from_pretrained(\"google/gemma-2-2b\", device_map=\"cuda\", torch_dtype=\"float16\")\n", + "model = AutoModel.from_pretrained(\"google/gemma-2-2b\",\n", + " device_map=\"cuda\",\n", + " torch_dtype=\"float16\")\n", "\n", - "# Load the autoencoders, the function returns a dictionary of the submodules with the autoencoders and the edited model.\n", + "# Load the autoencoders, the function returns a dictionary of the submodules\n", + "# with the autoencoders and the edited model.\n", "# it takes as arguments the model, the layers to load the autoencoders into,\n", - "# the average L0 sparsity per layer, the size of the autoencoders and the type of autoencoders (residuals or MLPs).\n", + "# the average L0 sparsity per layer, the size of the autoencoders \n", + "# and the type of autoencoders (residuals or MLPs).\n", "\n", "run_cfg = RunConfig(\n", " sparse_model=\"google/gemma-scope-2b-pt-res\",\n", @@ -114,7 +118,8 @@ "metadata": {}, "outputs": [], "source": [ - "# There is a default cache config that can also be modified when using a \"production\" script.\n", + "# There is a default cache config that can also be modified\n", + "# when making a \"production\" script.\n", "cfg = CacheConfig(\n", " dataset_repo=\"EleutherAI/rpj-v2-sample\",\n", " dataset_split=\"train[:1%]\",\n", @@ -175,11 +180,13 @@ "cache.run(cfg.n_tokens, tokens)\n", "\n", "cache.save_splits(\n", - " n_splits=cfg.n_splits, # We split the activation and location indices into different files to make loading faster\n", + " n_splits=cfg.n_splits, \n", + " # We split the activation and location indices into different for faster loading\n", " save_dir=\"latents\"\n", ")\n", "\n", - "# The config of the cache should be saved with the results such that it can be loaded later.\n", + "# The config of the cache should be 
saved with the results~\n", + "# such that it can be loaded later.\n", "\n", "cache.save_config(\n", " save_dir=\"latents\",\n", @@ -205,7 +212,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.7" + "version": "3.12.2" } }, "nbformat": 4, diff --git a/examples/generate_explanations.ipynb b/examples/generate_explanations.ipynb index 7ac4892a..b49df9b1 100644 --- a/examples/generate_explanations.ipynb +++ b/examples/generate_explanations.ipynb @@ -30,7 +30,7 @@ "from delphi.config import ExperimentConfig, LatentConfig\n", "from delphi.explainers import DefaultExplainer\n", "from delphi.latents import LatentDataset\n", - "from delphi.latents.constructors import default_constructor\n", + "from delphi.latents.constructors import constructor\n", "from delphi.latents.samplers import sample\n", "from delphi.pipeline import Pipeline, process_wrapper\n", "\n", @@ -45,7 +45,8 @@ "source": [ "latent_cfg = LatentConfig(\n", " width=131072, # The number of latents of your SAE\n", - " min_examples=200, # The minimum number of examples to consider for the latent to be explained\n", + " min_examples=200,\n", + " # The minimum number of examples to consider for the latent to be explained\n", " max_examples=10000, # The maximum number of examples to be sampled from\n", " n_splits=5 # How many splits was the cache split into\n", ")\n" @@ -102,9 +103,8 @@ "metadata": {}, "outputs": [], "source": [ - "constructor=partial(\n", - " default_constructor,\n", - " token_loader=None,\n", + "example_constructor=partial(\n", + " constructor,\n", " n_not_active=experiment_cfg.n_non_activating, \n", " ctx_len=experiment_cfg.example_ctx_len, \n", " max_examples=latent_cfg.max_examples\n", @@ -115,7 +115,7 @@ " cfg=latent_cfg,\n", " modules=[module],\n", " latents=latent_dict,\n", - " constructor=constructor,\n", + " constructor=example_constructor,\n", " sampler=sampler\n", ") " ] @@ -214,7 +214,8 @@ " explainer_pipe,\n", ")\n", "number_of_parallel_latents = 10\n", - "await pipeline.run(number_of_parallel_latents) # This will start generating the explanations." + "await pipeline.run(number_of_parallel_latents)\n", + " # This will start generating the explanations." 
] }, { diff --git a/examples/latent_contexts.ipynb b/examples/latent_contexts.ipynb index e5710901..ac6e0d5b 100644 --- a/examples/latent_contexts.ipynb +++ b/examples/latent_contexts.ipynb @@ -24,7 +24,7 @@ "\n", "from delphi.config import ExperimentConfig, LatentConfig\n", "from delphi.latents import LatentDataset\n", - "from delphi.latents.constructors import default_constructor\n", + "from delphi.latents.constructors import constructor\n", "from delphi.latents.samplers import sample\n" ] }, @@ -34,7 +34,12 @@ "metadata": {}, "outputs": [], "source": [ - "def make_colorbar(min_value, max_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):\n", + "def make_colorbar(min_value,\n", + " max_value,\n", + " white = 255,\n", + " red_blue_ness = 250,\n", + " positive_threshold = 0.01,\n", + " negative_threshold = 0.01):\n", " # Add color bar\n", " colorbar = \"\"\n", " num_colors = 4\n", @@ -43,27 +48,33 @@ " ratio = i / (num_colors)\n", " value = round((min_value*ratio),1)\n", " text_color = \"255,255,255\" if ratio > 0.5 else \"0,0,0\"\n", - " colorbar += f' {value} '\n", + " colorbar += f' {value} ' # noqa: E501\n", " # Do zero\n", - " colorbar += f' 0.0 '\n", + " colorbar += f' 0.0 ' # noqa: E501\n", " # Do positive\n", " if(max_value > positive_threshold):\n", " for i in range(1, num_colors+1):\n", " ratio = i / (num_colors)\n", " value = round((max_value*ratio),1)\n", " text_color = \"255,255,255\" if ratio > 0.5 else \"0,0,0\"\n", - " colorbar += f' {value} '\n", + " colorbar += f' {value} ' # noqa: E501\n", " return colorbar\n", "\n", - "def value_to_color(activation, max_value, min_value, white = 255, red_blue_ness = 250, positive_threshold = 0.01, negative_threshold = 0.01):\n", + "def value_to_color(activation,\n", + " max_value,\n", + " min_value,\n", + " white = 255,\n", + " red_blue_ness = 250, \n", + " positive_threshold = 0.01,\n", + " negative_threshold = 0.01):\n", " if activation > positive_threshold:\n", " ratio = activation/max_value\n", " text_color = \"0,0,0\" if ratio <= 0.5 else \"255,255,255\" \n", - " background_color = f'rgba({int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},255,1)'\n", + " background_color = f'rgba({int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},255,1)' # noqa: E501\n", " elif activation < -negative_threshold:\n", " ratio = activation/min_value\n", " text_color = \"0,0,0\" if ratio <= 0.5 else \"255,255,255\" \n", - " background_color = f'rgba(255, {int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},1)'\n", + " background_color = f'rgba(255, {int(red_blue_ness-(red_blue_ness*ratio))},{int(red_blue_ness-(red_blue_ness*ratio))},1)' # noqa: E501\n", " else:\n", " text_color = \"0,0,0\"\n", " background_color = f'rgba({white},{white},{white},1)'\n", @@ -83,11 +94,18 @@ " array = [array]\n", " return array\n", "\n", - "def tokens_and_activations_to_html(toks, activations, tokenizer, logit_diffs=None, model_type=\"causal\"):\n", + "def tokens_and_activations_to_html(\n", + " toks,\n", + " activations,\n", + " tokenizer,\n", + " logit_diffs=None,\n", + " model_type=\"causal\"):\n", " text_spacing = \"0.00em\"\n", " toks = convert_token_array_to_list(toks)\n", " activations = convert_token_array_to_list(activations)\n", - " toks = [[tokenizer.decode(t).replace('Ä ', ' ').replace('\\n', '\\\\n') for t in tok] for tok in toks]\n", + " toks = [\n", + " [tokenizer.decode(t).replace('Ä ', ' ').replace('\\n', 
'\\\\n') for t in tok]\n", + " for tok in toks]\n", " print(len(activations))\n", " print(len(toks))\n", " highlighted_text = []\n", @@ -104,7 +122,8 @@ " highlighted_text.append(\"Token Activations: \" + make_colorbar(min_value, max_value))\n", " if(logit_diffs is not None and model_type != \"reward_model\"):\n", " highlighted_text.append('
')\n", - " highlighted_text.append(\"Logit Diff: \" + make_colorbar(logit_min_value, logit_max_value))\n", + " highlighted_text.append(\"Logit Diff: \" + make_colorbar(logit_min_value,\n", + " logit_max_value))\n", " \n", " highlighted_text.append('
')\n", " for seq_ind, (act, tok) in enumerate(zip(activations, toks)):\n", @@ -112,15 +131,17 @@ " if(logit_diffs is not None and model_type != \"reward_model\"):\n", " highlighted_text.append('
')\n", " text_color, background_color = value_to_color(a, max_value, min_value)\n", - " highlighted_text.append(f'{t.replace(\" \", \" \").replace(\"\",\"BOS\")}')\n", + " highlighted_text.append(f'{t.replace(\" \", \" \").replace(\"\",\"BOS\")}') # noqa: E501\n", " if(logit_diffs is not None and model_type != \"reward_model\"):\n", " logit_diffs_act = logit_diffs[seq_ind][act_ind]\n", - " _, logit_background_color = value_to_color(logit_diffs_act, logit_max_value, logit_min_value)\n", - " highlighted_text.append(f'
')\n", + " _, logit_background_color = value_to_color(logit_diffs_act,\n", + " logit_max_value,\n", + " logit_min_value)\n", + " highlighted_text.append(f'
') # noqa: E501\n", " if(logit_diffs is not None and model_type==\"reward_model\"):\n", " reward_change = logit_diffs[seq_ind].item()\n", " text_color, background_color = value_to_color(reward_change, 10, -10)\n", - " highlighted_text.append(f'
Reward: {reward_change:.2f}')\n", + " highlighted_text.append(f'
Reward: {reward_change:.2f}') # noqa: E501\n", " highlighted_text.append('
')\n", " highlighted_text = ''.join(highlighted_text)\n", " return highlighted_text\n" @@ -137,9 +158,8 @@ "\n", " experiment_cfg = ExperimentConfig(n_non_activating=0)\n", "\n", - " constructor = partial(\n", - " default_constructor,\n", - " token_loader=None,\n", + " example_constructor = partial(\n", + " constructor,\n", " n_not_active=experiment_cfg.n_non_activating,\n", " ctx_len=experiment_cfg.example_ctx_len,\n", " max_examples=latent_cfg.max_examples,\n", @@ -152,7 +172,7 @@ " cfg=latent_cfg,\n", " modules=[hookpoint],\n", " latents={hookpoint: torch.arange(100)},\n", - " constructor=constructor,\n", + " constructor=example_constructor,\n", " sampler=sampler\n", " )\n", " \n", @@ -167,7 +187,7 @@ "\n", "\n", "async def plot_examples( raw_dir, hookpoint: str):\n", - " all_examples, maximum_activations,tokenizer = await load_examples(raw_dir, hookpoint)\n", + " all_examples, maximum_acts,tokenizer = await load_examples(raw_dir, hookpoint)\n", " keys = list(all_examples.keys())\n", "\n", " current_index = [\n", @@ -182,7 +202,7 @@ " list_activations = []\n", " for example in all_examples[key]:\n", " example_tokens = example.tokens\n", - " activations = example.activations / maximum_activations[key]\n", + " activations = example.activations / maximum_acts[key]\n", " list_tokens.append(example_tokens)\n", " list_activations.append(activations.tolist())\n", "\n", diff --git a/examples/score_explanations.ipynb b/examples/score_explanations.ipynb index 5e65acc0..2b5dd354 100644 --- a/examples/score_explanations.ipynb +++ b/examples/score_explanations.ipynb @@ -30,7 +30,7 @@ "from delphi.config import ExperimentConfig, LatentConfig\n", "from delphi.explainers import explanation_loader\n", "from delphi.latents import LatentDataset\n", - "from delphi.latents.constructors import default_constructor\n", + "from delphi.latents.constructors import constructor\n", "from delphi.latents.samplers import sample\n", "from delphi.pipeline import Pipeline, process_wrapper\n", "from delphi.scorers import FuzzingScorer\n", @@ -46,7 +46,8 @@ "source": [ "latent_cfg = LatentConfig(\n", " width=131072, # The number of latents of your SAE\n", - " min_examples=200, # The minimum number of examples to consider for the latent to be explained\n", + " min_examples=200, \n", + " # The minimum number of examples to consider for the latent to be explained\n", " max_examples=10000, # The maximum number of examples to be sampled from\n", " n_splits=5 # How many splits was the cache split into\n", ")\n" @@ -103,9 +104,8 @@ "metadata": {}, "outputs": [], "source": [ - "constructor=partial(\n", - " default_constructor,\n", - " token_loader=None,\n", + "example_constructor=partial(\n", + " constructor,\n", " n_not_active=experiment_cfg.n_non_activating, \n", " ctx_len=experiment_cfg.example_ctx_len, \n", " max_examples=latent_cfg.max_examples\n", @@ -220,7 +220,8 @@ " scorer_pipe,\n", ")\n", "number_of_parallel_latents = 10\n", - "await pipeline.run(number_of_parallel_latents) # This will start generating the explanations." + "await pipeline.run(number_of_parallel_latents) \n", + "# This will start generating the explanations." 
] }, { From 222ab8b4d50dc892cf036a529b19f59c92906dec Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 18 Feb 2025 17:08:11 +0000 Subject: [PATCH 109/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/latents/cache.py | 6 +++--- delphi/latents/samplers.py | 10 ++++------ 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 4fd8606a..dbc9d17f 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -39,9 +39,9 @@ def __init__( self.latent_activations_batches: dict[ str, list[Float[Tensor, "batch sequence num_latents"]] ] = defaultdict(list) - self.tokens_batches: dict[ - str, list[Float[Tensor, "batch sequence"]] - ] = defaultdict(list) + self.tokens_batches: dict[str, list[Float[Tensor, "batch sequence"]]] = ( + defaultdict(list) + ) self.latent_locations: dict[ str, Float[Tensor, "batch sequence num_latents"] diff --git a/delphi/latents/samplers.py b/delphi/latents/samplers.py index 85cbc18a..f4b448be 100644 --- a/delphi/latents/samplers.py +++ b/delphi/latents/samplers.py @@ -7,8 +7,8 @@ def normalize_activations( - examples: list[ActivatingExample], max_activation: float - ) -> list[ActivatingExample]: + examples: list[ActivatingExample], max_activation: float +) -> list[ActivatingExample]: for example in examples: example.normalized_activations = ( example.activations * 10 / max_activation @@ -67,14 +67,12 @@ def train( logger.warning( "n_train is greater than the number of examples, using all examples" ) - + selected_examples = random.sample(examples, n_train) selected_examples = normalize_activations(selected_examples, max_activation) return selected_examples case "quantiles": - selected_examples = split_quantiles( - examples, n_quantiles, n_train - ) + selected_examples = split_quantiles(examples, n_quantiles, n_train) selected_examples = normalize_activations(selected_examples, max_activation) return selected_examples From a232531e21a7c88e3ced06a8de0e76e3c0f4018c Mon Sep 17 00:00:00 2001 From: SrGonao Date: Tue, 18 Feb 2025 17:12:11 +0000 Subject: [PATCH 110/132] Shouldn't be here --- delphi/explainers/contrastive_explainer.py | 75 ---------------------- 1 file changed, 75 deletions(-) delete mode 100644 delphi/explainers/contrastive_explainer.py diff --git a/delphi/explainers/contrastive_explainer.py b/delphi/explainers/contrastive_explainer.py deleted file mode 100644 index 9245537a..00000000 --- a/delphi/explainers/contrastive_explainer.py +++ /dev/null @@ -1,75 +0,0 @@ -import asyncio - -import faiss - -from delphi.explainers.default.prompt_builder import build_single_token_prompt -from delphi.explainers.explainer import Explainer, ExplainerResult -from delphi.logger import logger - - -class ContrastiveExplainer(Explainer): - name = "contrastive" - - def __init__( - self, - client, - tokenizer, - index: faiss.IndexFlatL2, - verbose: bool = False, - activations: bool = False, - cot: bool = False, - threshold: float = 0.6, - temperature: float = 0.0, - **generation_kwargs, - ): - self.client = client - self.tokenizer = tokenizer - self.index = index - self.verbose = verbose - - self.activations = activations - self.cot = cot - self.threshold = threshold - self.temperature = temperature - self.generation_kwargs = generation_kwargs - - async def __call__(self, record): - # Need to change __call__ to use index - - messages = self._build_prompt(record.train) - - response = 
await self.client.generate( - messages, temperature=self.temperature, **self.generation_kwargs - ) - - try: - explanation = self.parse_explanation(response.text) - if self.verbose: - return ( - messages[-1]["content"], - response, - ExplainerResult(record=record, explanation=explanation), - ) - - return ExplainerResult(record=record, explanation=explanation) - except Exception as e: - logger.error(f"Explanation parsing failed: {e}") - return ExplainerResult( - record=record, explanation="Explanation could not be parsed." - ) - - def _build_prompt(self, examples): - highlighted_examples = [] - - for i, example in enumerate(examples): - highlighted_examples.append(self._highlight(i + 1, example)) - - if self.activations: - highlighted_examples.append(self._join_activations(example)) - - return build_single_token_prompt( - examples=highlighted_examples, - ) - - def call_sync(self, record): - return asyncio.run(self.__call__(record)) From 035281e77eccdf1e3bb21c32430bced8dd170f1d Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 11:36:11 +0000 Subject: [PATCH 111/132] Fixing typing, switching to torch, and removing comments --- delphi/latents/neighbours.py | 133 ++++++----------------------------- 1 file changed, 23 insertions(+), 110 deletions(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index b41725de..45ded0ee 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -22,8 +22,7 @@ class NeighbourCalculator: def __init__( self, cache_dir: Optional[Path] = None, - autoencoder: Optional[nn.Module | Sae] = None, - # pre_activation_record: Optional["PreActivationRecord"] = None, + autoencoder: Optional[nn.Module] = None, number_of_neighbours: int = 10, neighbour_cache: Optional[dict[str, dict[int, list[tuple[int, float]]]]] = None, ): @@ -67,14 +66,6 @@ def _compute_neighbour_list( self.neighbour_cache[method] = self._compute_similarity_neighbours( "decoder" ) - # elif method == "correlation": - # if self.autoencoder is None or self.residual_stream_record is None: - # raise ValueError( - # "Autoencoder and residual stream record are required " - # "for correlation-based neighbours" - # ) - # self.neighbour_cache[method] = self._compute_correlation_neighbours() - elif method == "co-occurrence": self.neighbour_cache[method] = self._compute_cooccurrence_neighbours() @@ -137,7 +128,7 @@ def _compute_similarity_neighbours( del similarity_matrix torch.cuda.empty_cache() done = True - except Exception: + except RuntimeError: # Out of memory batch_size = batch_size // 2 if batch_size < 2: raise ValueError( @@ -154,9 +145,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] Code adapted from https://github.com/taha-yassine/SAE-features/blob/main/cooccurrences/compute.py """ - import cupy as cp - import cupyx.scipy.sparse as cusparse - print("Computing co-occurrence neighbours") assert ( self.cache_dir is not None @@ -165,7 +153,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] all_locations = [] for path in paths: - split_data = load_file(path) + split_data = load_file(self.cache_dir / path) first_feature = int(path.split("/")[-1].split("_")[0]) locations = torch.tensor(split_data["locations"].astype(np.int64)) locations[:, 2] = locations[:, 2] + first_feature @@ -174,16 +162,20 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] # concatenate the locations and activations locations = torch.cat(all_locations) - n_latents = 
int(torch.max(locations[:, 2])) + 1 + latent_index = locations[:, 2] + batch_index = locations[:, 0] + ctx_index = locations[:, 1] + + n_latents = int(torch.max(latent_index)) + 1 # 1. Get unique values of first 2 dims (i.e. absolute token index) # and their counts # Trick is to use Cantor pairing function to have a bijective mapping between # (batch_id, ctx_pos) and a unique 1D index # Faster than running `torch.unique_consecutive` on the first 2 dims - idx_cantor = (locations[:, 0] + locations[:, 1]) * ( - locations[:, 0] + locations[:, 1] + 1 - ) // 2 + locations[:, 1] + idx_cantor = (batch_index + ctx_index) * ( + batch_index + ctx_index + 1 + ) // 2 + ctx_index unique_idx, idx_counts = torch.unique_consecutive( idx_cantor, return_counts=True ) @@ -196,26 +188,27 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] ) del idx_cantor, unique_idx, idx_counts - rows = cp.asarray(locations[:, 2]) - cols = cp.asarray(locations_flat) - data = cp.ones(len(rows)) - sparse_matrix = cusparse.coo_matrix( - (data, (rows, cols)), shape=(n_latents, n_tokens) + # rows = cp.asarray(locations[:, 2]) + # cols = cp.asarray(locations_flat) + # data = cp.ones(len(rows)) + sparse_matrix_indices = torch.stack([locations_flat, latent_index]) + sparse_matrix = torch.sparse_coo_tensor( + sparse_matrix_indices, torch.ones(len(latent_index)), (n_latents, n_tokens) ) token_batch_size = 100_000 - cooc_matrix = cp.zeros((n_latents, n_latents), dtype=cp.float32) + cooc_matrix = torch.zeros((n_latents, n_latents), dtype=torch.float32) - sparse_matrix_csc = sparse_matrix.tocsc() + sparse_matrix_csc = sparse_matrix.to_sparse_csr() for start in tqdm(range(0, n_tokens, token_batch_size)): end = min(n_tokens, start + token_batch_size) # Slice the sparse matrix to get a batch of tokens. sub_matrix = sparse_matrix_csc[:, start:end] # Compute the partial co-occurrence matrix for this batch. - partial_cooc = (sub_matrix @ sub_matrix.T).toarray() + partial_cooc = (sub_matrix @ sub_matrix.T).to_dense() cooc_matrix += partial_cooc # Free temporary variables. - del rows, cols, data, sparse_matrix, sparse_matrix_csc + del sparse_matrix, sparse_matrix_csc # Compute Jaccard similarity def compute_jaccard(cooc_matrix): @@ -229,12 +222,11 @@ def compute_jaccard(cooc_matrix): # Compute Jaccard similarity matrix jaccard_matrix = compute_jaccard(cooc_matrix) - jaccard_torch = torch.as_tensor(cp.asnumpy(jaccard_matrix)) # get the indices of the top k neighbours for each feature top_k_indices, values = torch.topk( - jaccard_torch, self.number_of_neighbours + 1, dim=1 + jaccard_matrix, self.number_of_neighbours + 1, dim=1 ) - del jaccard_matrix, cooc_matrix, jaccard_torch + del jaccard_matrix, cooc_matrix torch.cuda.empty_cache() # return the neighbour lists @@ -268,82 +260,3 @@ def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: """ with open(path, "r") as f: return json.load(f) - - -# TODO: add correlation neighbours, by re-adding activation records -# def _compute_correlation_neighbours(self) -> dict[int, list[int]]: -# """ -# Compute neighbour lists based on activation correlation patterns. 
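The Jaccard step used by the neighbour computation above turns raw co-occurrence counts into a similarity in [0, 1]. A minimal sketch of that computation, assuming a dense square count matrix whose diagonal holds each latent's total occurrence count (the function and variable names here are illustrative, not part of the patch):

import torch

def jaccard_from_cooccurrence(cooc: torch.Tensor) -> torch.Tensor:
    # cooc[i, j] counts contexts in which latents i and j are active together;
    # the diagonal therefore holds each latent's own occurrence count.
    occurrences = cooc.diagonal()
    union = occurrences[:, None] + occurrences[None, :] - cooc
    return cooc / union.clamp_min(1)  # guard against empty unions
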
-# """ -# print("Computing correlation neighbours") - -# # the activation_matrix has the shape (number_of_samples,hidden_dimension) - -# activations = torch.tensor( -# load_file(self.residual_stream_record + ".safetensors")["activations"] -# ) - -# estimator = CovarianceEstimator(activations.shape[1]) -# # batch the activations -# batch_size = 10000 -# for i in tqdm(range(0, activations.shape[0], batch_size)): -# estimator.update(activations[i : i + batch_size]) - -# covariance_matrix = estimator.cov().cuda().half() - -# # load the encoder -# encoder_matrix = self.autoencoder.encoder.weight.cuda().half() - -# covariance_between_latents = torch.zeros( -# (encoder_matrix.shape[0], encoder_matrix.shape[0]), device="cpu" -# ) - -# # do batches of latents -# batch_size = 1024 -# for start in tqdm(range(0, encoder_matrix.shape[0], batch_size)): -# end = min(encoder_matrix.shape[0], start + batch_size) -# encoder_rows = encoder_matrix[start:end] - -# correlation = encoder_rows @ covariance_matrix @ encoder_matrix.T -# covariance_between_latents[start:end] = correlation.cpu() - -# # the correlation is then the covariance divided -# # by the product of the standard deviations -# diagonal_covariance = torch.diag(covariance_between_latents) -# product_of_std = torch.sqrt( -# torch.outer(diagonal_covariance, diagonal_covariance) + 1e-6 -# ) -# correlation_matrix = covariance_between_latents / product_of_std - -# # get the indices of the top k neighbours for each feature -# indices, values = torch.topk( -# correlation_matrix, self.number_of_neighbours + 1, dim=1 -# ) - -# # return the neighbour lists -# return { -# i: list(zip(indices[i].tolist()[1:], values[i].tolist()[1:])) -# for i in range(len(indices)) -# } -# class CovarianceEstimator: -# def __init__(self, n_latents, *, device=None): -# self.mean = torch.zeros(n_latents, device=device) -# self.cov_ = torch.zeros(n_latents, n_latents, device=device) -# self.n = 0 - -# def update(self, x: torch.Tensor): -# n, d = x.shape -# assert d == len(self.mean) - -# self.n += n - -# # Welford's online algorithm -# delta = x - self.mean -# self.mean.add_(delta.sum(dim=0), alpha=1 / self.n) -# delta2 = x - self.mean - -# self.cov_.addmm_(delta.mH, delta2) - -# def cov(self): -# """Return the estimated covariance matrix.""" -# return self.cov_ / self.n From e09da3490dd09311b965516d32a46a2d1609daf2 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 11:37:10 +0000 Subject: [PATCH 112/132] Removing extra function, adding seed --- delphi/latents/constructors.py | 37 +++++++++++----------------------- 1 file changed, 12 insertions(+), 25 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 2825249a..ad3149f6 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -8,26 +8,6 @@ from .loader import ActivationData -def prepare_activating_examples( - tokens: Float[Tensor, "examples ctx_len"], - activations: Float[Tensor, "examples ctx_len"], -) -> list[ActivatingExample]: - """ - Prepare a list of examples from input tokens and activations. - - Args: - tokens: Tokenized input sequences. - activations: Activation values for the input sequences. - - Returns: - list[Example]: A list of prepared examples. 
- """ - return [ - ActivatingExample(tokens=toks, activations=acts, normalized_activations=None) - for toks, acts in zip(tokens, activations) - ] - - def prepare_non_activating_examples( tokens: Float[Tensor, "examples ctx_len"], distance: float, @@ -130,6 +110,7 @@ def constructor( constructor_type: Literal["random", "neighbours"], tokens: Float[Tensor, "batch seq"], all_data: Optional[dict[int, ActivationData]] = None, + seed: int = 42, ): cache_token_length = tokens.shape[1] @@ -138,7 +119,7 @@ def constructor( activation_data.locations[:, 0] * cache_token_length + activation_data.locations[:, 1] ) - ctx_indices = flat_indices // ctx_len + ctx_indices, index_within_ctx = flat_indices // ctx_len, flat_indices % ctx_len index_within_ctx = flat_indices % ctx_len reshaped_tokens = tokens.reshape(-1, ctx_len) n_windows = reshaped_tokens.shape[0] @@ -160,8 +141,10 @@ def constructor( ctx_len=ctx_len, max_examples=max_examples, ) - record.examples = prepare_activating_examples(token_windows, act_windows) - + record.examples = [ + ActivatingExample(tokens=toks, activations=acts, normalized_activations=None) + for toks, acts in zip(token_windows, act_windows) + ] if constructor_type == "random": # Add random non-activating examples to the record in place random_non_activating_windows( @@ -169,6 +152,7 @@ def constructor( available_indices=non_active_indices, reshaped_tokens=reshaped_tokens, n_not_active=n_not_active, + seed=seed, ) elif constructor_type == "neighbours": assert all_data is not None, "All data is required for neighbour constructor" @@ -179,6 +163,7 @@ def constructor( all_data=all_data, ctx_len=ctx_len, n_not_active=n_not_active, + seed=seed, ) @@ -189,6 +174,7 @@ def neighbour_non_activation_windows( all_data: dict[int, ActivationData], ctx_len: int, n_not_active: int, + seed: int = 42, ): """ Generate random activation windows and update the latent record. @@ -201,7 +187,7 @@ def neighbour_non_activation_windows( ctx_len (int): The context length. n_random (int): The number of random examples to generate. """ - torch.manual_seed(22) + torch.manual_seed(seed) if n_not_active == 0: record.not_active = [] return @@ -276,6 +262,7 @@ def random_non_activating_windows( available_indices: Float[Tensor, "windows"], reshaped_tokens: Float[Tensor, "windows ctx_len"], n_not_active: int, + seed: int = 42, ): """ Generate random non-activating sequence windows and update the latent record. @@ -288,7 +275,7 @@ def random_non_activating_windows( to the context length. n_not_active (int): The number of non activating examples to generate. 
""" - torch.manual_seed(22) + torch.manual_seed(seed) if n_not_active == 0: record.not_active = [] return From 6a5a7acdc8fe06fe8fe04b24420d562a3f2c8e17 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 11:37:34 +0000 Subject: [PATCH 113/132] Trying out dataclass --- delphi/explainers/explainer.py | 35 ++++++++++++++++------------------ 1 file changed, 16 insertions(+), 19 deletions(-) diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 67fb6840..6cf9c237 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -3,10 +3,13 @@ import random import re from abc import ABC, abstractmethod +from dataclasses import dataclass from typing import NamedTuple import aiofiles +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from ..clients.client import Client from ..latents.latents import ActivatingExample, LatentRecord from ..logger import logger @@ -19,30 +22,24 @@ class ExplainerResult(NamedTuple): """Generated explanation for latent.""" +@dataclass class Explainer(ABC): """ Abstract base class for explainers. """ - def __init__( - self, - client, - tokenizer, - verbose: bool = False, - activations: bool = False, - cot: bool = False, - threshold: float = 0.6, - temperature: float = 0.0, - **generation_kwargs, - ): - self.client = client - self.tokenizer = tokenizer - self.verbose = verbose - self.activations = activations - self.cot = cot - self.threshold = threshold - self.temperature = temperature - self.generation_kwargs = generation_kwargs + client: Client + """Client to use for explanation generation. """ + tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast + """The tokenizer used to collect activations.""" + verbose: bool + """Whether to print verbose output.""" + threshold: float + """The activation threshold to select tokens to highlight.""" + temperature: float + """The temperature for explanation generation.""" + generation_kwargs: dict + """Additional keyword arguments for the generation client.""" async def __call__(self, record: LatentRecord) -> ExplainerResult: messages = self._build_prompt(record.train) From 7fee61ce338c50878f2b727c4415c8ad1537bb31 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 11:38:07 +0000 Subject: [PATCH 114/132] Dataclass things --- delphi/explainers/default/default.py | 29 ++++++---------------------- 1 file changed, 6 insertions(+), 23 deletions(-) diff --git a/delphi/explainers/default/default.py b/delphi/explainers/default/default.py index 5171a75d..d3c31664 100644 --- a/delphi/explainers/default/default.py +++ b/delphi/explainers/default/default.py @@ -1,33 +1,16 @@ import asyncio +from dataclasses import dataclass from ..explainer import ActivatingExample, Explainer from .prompt_builder import build_prompt +@dataclass class DefaultExplainer(Explainer): - name = "default" - - def __init__( - self, - client, - tokenizer, - verbose: bool = False, - activations: bool = False, - cot: bool = False, - threshold: float = 0.6, - temperature: float = 0.0, - **generation_kwargs, - ): - super().__init__( - client, - tokenizer, - verbose, - activations, - cot, - threshold, - temperature, - **generation_kwargs, - ) + activations: bool + """Whether to show activations to the explainer.""" + cot: bool + """Whether to use chain of thought reasoning.""" def _build_prompt(self, examples: list[ActivatingExample]) -> list[dict]: highlighted_examples = [] From d4abc2e12a247953d42863d8b561139d1dab898a Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 
11:38:29 +0000 Subject: [PATCH 115/132] Adding tensor alias --- delphi/latents/cache.py | 40 +++++++++++++++++++++------------------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index dbc9d17f..e515596a 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -14,6 +14,8 @@ from delphi.config import CacheConfig from delphi.latents.collect_activations import collect_activations +location_tensor_shape = Float[Tensor, "batch sequence num_latents"] +token_tensor_shape = Float[Tensor, "batch sequence"] class Cache: """ @@ -34,30 +36,30 @@ def __init__( batch_size: Size of batches for processing. Defaults to 64. """ self.latent_locations_batches: dict[ - str, list[Float[Tensor, "batch sequence num_latents"]] + str, list[location_tensor_shape] ] = defaultdict(list) self.latent_activations_batches: dict[ - str, list[Float[Tensor, "batch sequence num_latents"]] + str, list[location_tensor_shape] ] = defaultdict(list) - self.tokens_batches: dict[str, list[Float[Tensor, "batch sequence"]]] = ( + self.tokens_batches: dict[str, list[token_tensor_shape]] = ( defaultdict(list) ) self.latent_locations: dict[ - str, Float[Tensor, "batch sequence num_latents"] + str, location_tensor_shape ] = {} self.latent_activations: dict[ - str, Float[Tensor, "batch sequence num_latents"] + str, location_tensor_shape ] = {} - self.tokens: dict[str, Float[Tensor, "batch sequence"]] = {} + self.tokens: dict[str, token_tensor_shape] = {} self.filters = filters self.batch_size = batch_size def add( self, - latents: Float[Tensor, "mini_batch sequence num_latents"], - tokens: Float[Tensor, "mini_batch sequence"], + latents: location_tensor_shape, + tokens: token_tensor_shape, batch_number: int, module_path: str, ): @@ -99,7 +101,7 @@ def save(self): ) def get_nonzeros_batch( - self, latents: Float[Tensor, "batch sequence num_latents"] + self, latents: location_tensor_shape ) -> tuple[ Float[Tensor, "batch sequence num_latents"], Float[Tensor, "batch sequence "] ]: @@ -137,10 +139,10 @@ def get_nonzeros_batch( return nonzero_latent_locations, nonzero_latent_activations def get_nonzeros( - self, latents: Float[Tensor, "batch sequence num_latents"], module_path: str + self, latents: location_tensor_shape, module_path: str ) -> tuple[ - Float[Tensor, "batch sequence num_latents"], - Float[Tensor, "batch sequence num_latents"], + location_tensor_shape, + location_tensor_shape, ]: """ Get the nonzero latent locations and activations. @@ -206,8 +208,8 @@ def __init__( self.filter_submodules(filters) def load_token_batches( - self, n_tokens: int, tokens: Float[Tensor, "batch sequence"] - ) -> list[Float[Tensor, "batch sequence"]]: + self, n_tokens: int, tokens: token_tensor_shape + ) -> list[token_tensor_shape]: """ Load and prepare token batches for processing. @@ -243,7 +245,7 @@ def filter_submodules(self, filters: dict[str, Float[Tensor, "indices"]]): filtered_submodules[hookpoint] = self.hookpoint_to_sae[hookpoint] self.hookpoint_to_sae = filtered_submodules - def run(self, n_tokens: int, tokens: Float[Tensor, "batch sequence"]): + def run(self, n_tokens: int, tokens: token_tensor_shape): """ Run the latent caching process. 
@@ -296,13 +298,13 @@ def save(self, save_dir: Path, save_tokens: bool = True): output_file = save_dir / f"{module_path}.safetensors" data = { - "locations": self.cache.latent_locations[module_path], - "activations": self.cache.latent_activations[module_path], + "locations": self.cache.latent_locations[module_path].numpy(), + "activations": self.cache.latent_activations[module_path].numpy(), } if save_tokens: - data["tokens"] = self.cache.tokens[module_path] + data["tokens"] = self.cache.tokens[module_path].numpy() - save_file(data, output_file) # type: ignore + save_file(data, output_file) def _generate_split_indices(self, n_splits: int) -> list[tuple[Tensor, Tensor]]: """ From 1a85b54cb6ac457595cebb3c2c83061665460402 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 11:48:53 +0000 Subject: [PATCH 116/132] Remove old comment --- delphi/scorers/classifier/detection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/delphi/scorers/classifier/detection.py b/delphi/scorers/classifier/detection.py index ffb84340..b01a07f3 100644 --- a/delphi/scorers/classifier/detection.py +++ b/delphi/scorers/classifier/detection.py @@ -51,7 +51,6 @@ def _prepare(self, record: LatentRecord) -> list[Sample]: Prepare and shuffle a list of samples for classification. """ - # check if not_active is a list of lists or a list of examples if len(record.not_active) > 0: samples = examples_to_samples( record.not_active, From eb877446467f05126258dea07d8293c3d3ff3ebb Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 11:49:09 +0000 Subject: [PATCH 117/132] Add defaults --- delphi/explainers/explainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/delphi/explainers/explainer.py b/delphi/explainers/explainer.py index 6cf9c237..531d5f4a 100644 --- a/delphi/explainers/explainer.py +++ b/delphi/explainers/explainer.py @@ -3,7 +3,7 @@ import random import re from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import NamedTuple import aiofiles @@ -32,13 +32,13 @@ class Explainer(ABC): """Client to use for explanation generation. 
""" tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast """The tokenizer used to collect activations.""" - verbose: bool + verbose: bool = False """Whether to print verbose output.""" - threshold: float + threshold: float = 0.3 """The activation threshold to select tokens to highlight.""" - temperature: float + temperature: float = 0.0 """The temperature for explanation generation.""" - generation_kwargs: dict + generation_kwargs: dict = field(default_factory=dict) """Additional keyword arguments for the generation client.""" async def __call__(self, record: LatentRecord) -> ExplainerResult: From 8bfbe330dcf68570efe7f141b3ec67fa76c72ac3 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 14:06:09 +0000 Subject: [PATCH 118/132] Remove cupy dependency, --- delphi/latents/neighbours.py | 95 ++++++++++++++++++------------------ 1 file changed, 47 insertions(+), 48 deletions(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 45ded0ee..10611140 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -153,62 +153,61 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] all_locations = [] for path in paths: - split_data = load_file(self.cache_dir / path) - first_feature = int(path.split("/")[-1].split("_")[0]) - locations = torch.tensor(split_data["locations"].astype(np.int64)) - locations[:, 2] = locations[:, 2] + first_feature - # compute number of tokens - all_locations.append(locations) + if path.endswith(".safetensors"): + split_data = load_file(self.cache_dir / path) + first_feature = int(path.split("/")[-1].split("_")[0]) + locations = torch.tensor(split_data["locations"].astype(np.int64)) + locations[:, 2] = locations[:, 2] + first_feature + + all_locations.append(locations) # concatenate the locations and activations locations = torch.cat(all_locations) - latent_index = locations[:, 2] + batch_index = locations[:, 0] ctx_index = locations[:, 1] - + latent_index = locations[:, 2] + n_latents = int(torch.max(latent_index)) + 1 - # 1. Get unique values of first 2 dims (i.e. absolute token index) - # and their counts - # Trick is to use Cantor pairing function to have a bijective mapping between - # (batch_id, ctx_pos) and a unique 1D index - # Faster than running `torch.unique_consecutive` on the first 2 dims - idx_cantor = (batch_index + ctx_index) * ( - batch_index + ctx_index + 1 - ) // 2 + ctx_index - unique_idx, idx_counts = torch.unique_consecutive( - idx_cantor, return_counts=True - ) - n_tokens = len(unique_idx) + # Convert from (batch_id, ctx_pos) to a unique 1D index + + idx_cantor = batch_index * ctx_index + ctx_index - # 2. 
The Cantor indices are not consecutive, - # so we create sorted ones from the counts - locations_flat = torch.repeat_interleave( - torch.arange(n_tokens, device=locations.device), idx_counts - ) - del idx_cantor, unique_idx, idx_counts - - # rows = cp.asarray(locations[:, 2]) - # cols = cp.asarray(locations_flat) - # data = cp.ones(len(rows)) - sparse_matrix_indices = torch.stack([locations_flat, latent_index]) - sparse_matrix = torch.sparse_coo_tensor( - sparse_matrix_indices, torch.ones(len(latent_index)), (n_latents, n_tokens) - ) - token_batch_size = 100_000 - cooc_matrix = torch.zeros((n_latents, n_latents), dtype=torch.float32) + # Sort the indices, because they are not sorted after concatenation + idx_cantor, idx_cantor_sorted_idx = idx_cantor.sort(dim=0,stable=True) + latent_index = latent_index[idx_cantor_sorted_idx] + + n_tokens = int(idx_cantor.max().item()) + + token_batch_size = 20_000 + + # Find indices where idx_cantor crosses each batch boundary + bounday_values = torch.arange(token_batch_size,n_tokens,token_batch_size) + + batch_boundaries_tensor = torch.searchsorted(idx_cantor,bounday_values) + batch_boundaries = [0] + batch_boundaries_tensor.tolist() + + if batch_boundaries[-1] != len(idx_cantor): + batch_boundaries.append(len(idx_cantor)) - sparse_matrix_csc = sparse_matrix.to_sparse_csr() - for start in tqdm(range(0, n_tokens, token_batch_size)): - end = min(n_tokens, start + token_batch_size) - # Slice the sparse matrix to get a batch of tokens. - sub_matrix = sparse_matrix_csc[:, start:end] - # Compute the partial co-occurrence matrix for this batch. - partial_cooc = (sub_matrix @ sub_matrix.T).to_dense() - cooc_matrix += partial_cooc + co_occurrence_matrix = torch.zeros((n_latents, n_latents), dtype=torch.int32).cuda() + + for start,end in tqdm(zip(batch_boundaries[:-1],batch_boundaries[1:])): + # get all ind_cantor values between start and start + token_batch_size + selected_idx_cantor = idx_cantor[start:end] + selected_latent_index = latent_index[start:end] + + # create a sparse matrix of the selected indices + sparse_matrix_indices = torch.stack([selected_latent_index,selected_idx_cantor],dim=0) + sparse_matrix = torch.sparse_coo_tensor( + sparse_matrix_indices, torch.ones(len(selected_latent_index)), (n_latents, token_batch_size) + ) + sparse_matrix = sparse_matrix.cuda() + partial_cooc = (sparse_matrix @ sparse_matrix.T).to_dense() + co_occurrence_matrix += partial_cooc.int() + del sparse_matrix,partial_cooc - # Free temporary variables. 
- del sparse_matrix, sparse_matrix_csc # Compute Jaccard similarity def compute_jaccard(cooc_matrix): @@ -220,13 +219,13 @@ def compute_jaccard(cooc_matrix): return jaccard_matrix # Compute Jaccard similarity matrix - jaccard_matrix = compute_jaccard(cooc_matrix) + jaccard_matrix = compute_jaccard(co_occurrence_matrix) # get the indices of the top k neighbours for each feature top_k_indices, values = torch.topk( jaccard_matrix, self.number_of_neighbours + 1, dim=1 ) - del jaccard_matrix, cooc_matrix + del jaccard_matrix, co_occurrence_matrix torch.cuda.empty_cache() # return the neighbour lists From 4dda923ee22779b4a4bb22725b2d05d714d09611 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 14:12:55 +0000 Subject: [PATCH 119/132] Correctly handle more than one hookpoint --- delphi/__main__.py | 24 ++++++++++++++---------- delphi/latents/neighbours.py | 12 ++++++++---- delphi/latents/transforms.py | 9 +++++++-- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index d5d3b38d..1a548f89 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -62,18 +62,21 @@ def load_artifacts(run_cfg: RunConfig): async def create_neighbours( latents_path: Path, neighbours_path: Path, + hookpoints: list[str], ): """ Creates a neighbours file for the given hookpoints. """ - neighbour_calculator = NeighbourCalculator( - cache_dir=latents_path, number_of_neighbours=100 - ) - - neighbour_calculator.populate_neighbour_cache(["co-occurrence"]) neighbours_path.mkdir(parents=True, exist_ok=True) - neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/neighbours.json") + for hookpoint in hookpoints: + neighbour_calculator = NeighbourCalculator( + cache_dir=latents_path / hookpoint, number_of_neighbours=100 + ) + + neighbour_calculator.populate_neighbour_cache(["co-occurrence"]) + + neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}.json") async def process_cache( @@ -117,8 +120,10 @@ async def process_cache( ) sampler = partial(sample, cfg=experiment_cfg) if experiment_cfg.non_activating_source == "neighbours": - with open(neighbours_path / "neighbours.json", "r") as f: - neighbours = json.load(f)["co-occurrence"] + neighbours = {} + for hookpoint in hookpoints: + with open(neighbours_path / f"{hookpoint}.json", "r") as f: + neighbours[hookpoint] = json.load(f)["co-occurrence"] transform = partial( set_neighbours, neighbours=neighbours, @@ -327,8 +332,7 @@ async def run( or "neighbours" in run_cfg.overwrite ): await create_neighbours( - latents_path, - neighbours_path, + latents_path, neighbours_path, hookpoints ) else: print(f"Files found in {neighbours_path}, skipping...") diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 10611140..142cdfe3 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -191,7 +191,8 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] if batch_boundaries[-1] != len(idx_cantor): batch_boundaries.append(len(idx_cantor)) - co_occurrence_matrix = torch.zeros((n_latents, n_latents), dtype=torch.int32).cuda() + co_occurrence_matrix = torch.zeros((n_latents, n_latents), dtype=torch.int32) + co_occurrence_matrix = co_occurrence_matrix.cuda() for start,end in tqdm(zip(batch_boundaries[:-1],batch_boundaries[1:])): # get all ind_cantor values between start and start + token_batch_size @@ -199,14 +200,17 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] selected_latent_index = 
latent_index[start:end] # create a sparse matrix of the selected indices - sparse_matrix_indices = torch.stack([selected_latent_index,selected_idx_cantor],dim=0) + sparse_matrix_indices = torch.stack( + [selected_latent_index,selected_idx_cantor],dim=0) sparse_matrix = torch.sparse_coo_tensor( - sparse_matrix_indices, torch.ones(len(selected_latent_index)), (n_latents, token_batch_size) + sparse_matrix_indices, + torch.ones(len(selected_latent_index)), + (n_latents, token_batch_size), ) sparse_matrix = sparse_matrix.cuda() partial_cooc = (sparse_matrix @ sparse_matrix.T).to_dense() co_occurrence_matrix += partial_cooc.int() - del sparse_matrix,partial_cooc + del sparse_matrix, partial_cooc # Compute Jaccard similarity diff --git a/delphi/latents/transforms.py b/delphi/latents/transforms.py index 3f4f88e9..1bf53c9c 100644 --- a/delphi/latents/transforms.py +++ b/delphi/latents/transforms.py @@ -3,14 +3,19 @@ def set_neighbours( record: LatentRecord, - neighbours: dict[str, list[tuple[float, int]]], + neighbours: dict[str, dict[str, list[tuple[float, int]]]], threshold: float, ): """ Set the neighbours for the latent record. + Neighbours should be a dictionary with module names as keys, + where the values are a dictionary of latent indices as keys, + and a list of tuples of (distance,feature_index) as values. """ - latent_neighbours = neighbours[str(record.latent.latent_index)] + latent_neighbours = neighbours[record.latent.module_name][ + str(record.latent.latent_index) + ] # Each element in neighbours is a tuple of (distance,feature_index) # We want to keep only the ones with a distance less than the threshold From c6d9dca24dc95a77873271fc0e55733b87e88660 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Feb 2025 14:14:46 +0000 Subject: [PATCH 120/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/__main__.py | 4 +--- delphi/latents/cache.py | 27 ++++++++++----------------- delphi/latents/neighbours.py | 26 +++++++++++++------------- 3 files changed, 24 insertions(+), 33 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 1a548f89..552ce8e7 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -331,9 +331,7 @@ async def run( not glob(str(neighbours_path / ".*")) + glob(str(neighbours_path / "*")) or "neighbours" in run_cfg.overwrite ): - await create_neighbours( - latents_path, neighbours_path, hookpoints - ) + await create_neighbours(latents_path, neighbours_path, hookpoints) else: print(f"Files found in {neighbours_path}, skipping...") diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index e515596a..59907425 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -17,6 +17,7 @@ location_tensor_shape = Float[Tensor, "batch sequence num_latents"] token_tensor_shape = Float[Tensor, "batch sequence"] + class Cache: """ The Cache class stores latent locations and activations for modules. @@ -35,22 +36,16 @@ def __init__( filters: Filters for selecting specific latents. batch_size: Size of batches for processing. Defaults to 64. 
""" - self.latent_locations_batches: dict[ - str, list[location_tensor_shape] - ] = defaultdict(list) - self.latent_activations_batches: dict[ - str, list[location_tensor_shape] - ] = defaultdict(list) - self.tokens_batches: dict[str, list[token_tensor_shape]] = ( + self.latent_locations_batches: dict[str, list[location_tensor_shape]] = ( + defaultdict(list) + ) + self.latent_activations_batches: dict[str, list[location_tensor_shape]] = ( defaultdict(list) ) + self.tokens_batches: dict[str, list[token_tensor_shape]] = defaultdict(list) - self.latent_locations: dict[ - str, location_tensor_shape - ] = {} - self.latent_activations: dict[ - str, location_tensor_shape - ] = {} + self.latent_locations: dict[str, location_tensor_shape] = {} + self.latent_activations: dict[str, location_tensor_shape] = {} self.tokens: dict[str, token_tensor_shape] = {} self.filters = filters @@ -138,9 +133,7 @@ def get_nonzeros_batch( nonzero_latent_activations = torch.cat(nonzero_latent_activations, dim=0) return nonzero_latent_locations, nonzero_latent_activations - def get_nonzeros( - self, latents: location_tensor_shape, module_path: str - ) -> tuple[ + def get_nonzeros(self, latents: location_tensor_shape, module_path: str) -> tuple[ location_tensor_shape, location_tensor_shape, ]: @@ -304,7 +297,7 @@ def save(self, save_dir: Path, save_tokens: bool = True): if save_tokens: data["tokens"] = self.cache.tokens[module_path].numpy() - save_file(data, output_file) + save_file(data, output_file) def _generate_split_indices(self, n_splits: int) -> list[tuple[Tensor, Tensor]]: """ diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 142cdfe3..219dad00 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -163,29 +163,29 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] # concatenate the locations and activations locations = torch.cat(all_locations) - + batch_index = locations[:, 0] ctx_index = locations[:, 1] latent_index = locations[:, 2] - + n_latents = int(torch.max(latent_index)) + 1 # Convert from (batch_id, ctx_pos) to a unique 1D index - + idx_cantor = batch_index * ctx_index + ctx_index # Sort the indices, because they are not sorted after concatenation - idx_cantor, idx_cantor_sorted_idx = idx_cantor.sort(dim=0,stable=True) + idx_cantor, idx_cantor_sorted_idx = idx_cantor.sort(dim=0, stable=True) latent_index = latent_index[idx_cantor_sorted_idx] - + n_tokens = int(idx_cantor.max().item()) - + token_batch_size = 20_000 # Find indices where idx_cantor crosses each batch boundary - bounday_values = torch.arange(token_batch_size,n_tokens,token_batch_size) - - batch_boundaries_tensor = torch.searchsorted(idx_cantor,bounday_values) + bounday_values = torch.arange(token_batch_size, n_tokens, token_batch_size) + + batch_boundaries_tensor = torch.searchsorted(idx_cantor, bounday_values) batch_boundaries = [0] + batch_boundaries_tensor.tolist() if batch_boundaries[-1] != len(idx_cantor): @@ -194,14 +194,15 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] co_occurrence_matrix = torch.zeros((n_latents, n_latents), dtype=torch.int32) co_occurrence_matrix = co_occurrence_matrix.cuda() - for start,end in tqdm(zip(batch_boundaries[:-1],batch_boundaries[1:])): + for start, end in tqdm(zip(batch_boundaries[:-1], batch_boundaries[1:])): # get all ind_cantor values between start and start + token_batch_size selected_idx_cantor = idx_cantor[start:end] selected_latent_index = latent_index[start:end] - + # 
create a sparse matrix of the selected indices sparse_matrix_indices = torch.stack( - [selected_latent_index,selected_idx_cantor],dim=0) + [selected_latent_index, selected_idx_cantor], dim=0 + ) sparse_matrix = torch.sparse_coo_tensor( sparse_matrix_indices, torch.ones(len(selected_latent_index)), @@ -212,7 +213,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] co_occurrence_matrix += partial_cooc.int() del sparse_matrix, partial_cooc - # Compute Jaccard similarity def compute_jaccard(cooc_matrix): self_occurrence = cooc_matrix.diagonal() From 161b224caef16d51cae1f393a88c16529e22c9bf Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 16:14:42 +0000 Subject: [PATCH 121/132] Remove debug print --- delphi/latents/constructors.py | 1 - 1 file changed, 1 deletion(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index ad3149f6..f4c309c7 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -235,7 +235,6 @@ def neighbour_non_activation_windows( activations = activations[mask_ctx] # If there are no available indices, skip this neighbour if activations.numel() == 0: - print(f"No available indices for neighbour {neighbour.latent_index}") continue token_windows, _ = pool_max_activation_windows( activations=activations, From f576dcf7f694afb7a7670a9d72245b3a4de5fee5 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:18:36 +0000 Subject: [PATCH 122/132] add defaults --- delphi/explainers/default/default.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/delphi/explainers/default/default.py b/delphi/explainers/default/default.py index d3c31664..b0db2cdb 100644 --- a/delphi/explainers/default/default.py +++ b/delphi/explainers/default/default.py @@ -7,9 +7,9 @@ @dataclass class DefaultExplainer(Explainer): - activations: bool + activations: bool = True """Whether to show activations to the explainer.""" - cot: bool + cot: bool = False """Whether to use chain of thought reasoning.""" def _build_prompt(self, examples: list[ActivatingExample]) -> list[dict]: From b913b96286b9b772611e834fbefcaae718164664 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:25:20 +0000 Subject: [PATCH 123/132] Removing transforms --- delphi/latents/latents.py | 12 ++++++++++++ delphi/latents/loader.py | 12 ++++++++---- delphi/latents/transforms.py | 29 ----------------------------- 3 files changed, 20 insertions(+), 33 deletions(-) delete mode 100644 delphi/latents/transforms.py diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 458295c3..b3c328a8 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -143,6 +143,18 @@ def save(self, directory: str, save_examples: bool = False): with bf.BlobFile(path, "wb") as f: f.write(orjson.dumps(serializable)) + def set_neighbours( + self, + neighbours: list[tuple[float, int]], + ): + """ + Set the neighbours for the latent record. 
+ """ + self.neighbours = [ + Neighbour(distance=neighbour[0], latent_index=neighbour[1]) + for neighbour in neighbours + ] + def display( self, tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index bb686a70..3e1cc253 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -154,7 +154,7 @@ def __init__( latents: Optional[dict[str, torch.Tensor]] = None, constructor: Optional[Callable] = None, sampler: Optional[Callable] = None, - transform: Optional[Callable] = None, + neighbours: Optional[dict[str, dict[str, list[tuple[float, int]]]]] = None, ): """ Initialize a LatentDataset. @@ -189,7 +189,7 @@ def __init__( self.cache_config = cache_config self.constructor = constructor self.sampler = sampler - self.transform = transform + self.neighbours = neighbours # TODO: is it possible to do this without loading all data? if self.constructor is not None: @@ -364,8 +364,12 @@ async def _aprocess_latent(self, latent_data: LatentData) -> LatentRecord: Optional[LatentRecord]: Processed latent record or None. """ record = LatentRecord(latent_data.latent) - if self.transform is not None: - self.transform(record) + if self.neighbours is not None: + record.set_neighbours( + self.neighbours[latent_data.module][ + str(latent_data.latent.latent_index) + ], + ) if self.constructor is not None: self.constructor( record=record, diff --git a/delphi/latents/transforms.py b/delphi/latents/transforms.py deleted file mode 100644 index 1bf53c9c..00000000 --- a/delphi/latents/transforms.py +++ /dev/null @@ -1,29 +0,0 @@ -from .latents import LatentRecord, Neighbour - - -def set_neighbours( - record: LatentRecord, - neighbours: dict[str, dict[str, list[tuple[float, int]]]], - threshold: float, -): - """ - Set the neighbours for the latent record. - Neighbours should be a dictionary with module names as keys, - where the values are a dictionary of latent indices as keys, - and a list of tuples of (distance,feature_index) as values. - """ - - latent_neighbours = neighbours[record.latent.module_name][ - str(record.latent.latent_index) - ] - - # Each element in neighbours is a tuple of (distance,feature_index) - # We want to keep only the ones with a distance less than the threshold - latent_neighbours = [ - neighbour for neighbour in latent_neighbours if neighbour[0] > threshold - ] - - record.neighbours = [ - Neighbour(distance=neighbour[0], latent_index=neighbour[1]) - for neighbour in latent_neighbours - ] From c0a74192e6519a2d9fe10f1a7bbd56780549cddd Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:30:12 +0000 Subject: [PATCH 124/132] Removing useless code --- delphi/latents/stats.py | 119 ---------------------------------------- 1 file changed, 119 deletions(-) delete mode 100644 delphi/latents/stats.py diff --git a/delphi/latents/stats.py b/delphi/latents/stats.py deleted file mode 100644 index 46a87df4..00000000 --- a/delphi/latents/stats.py +++ /dev/null @@ -1,119 +0,0 @@ -from collections import defaultdict -from math import floor - -import numpy as np -import torch -import torch.nn.functional as F - -from . import LatentRecord - - -def logits( - records: list[LatentRecord], - W_U: torch.nn.Module, - W_dec: torch.nn.Module, - k: int = 10, - tokenizer=None, -) -> list[list[str]]: - """ - Compute the top k logits via direct logit attribution for a set of records. - - Args: - records (list[LatentRecord]): A list of latent records. - W_U (torch.nn.Module): The linear layer for the encoder. 
- W_dec (torch.nn.Module): The linear layer for the decoder. - k (int): The number of top logits to compute. - tokenizer (Optional): A tokenizer for decoding logits. - - Returns: - decoded_top_logits (list[list[str]]): A list of top k logits for each record. - """ - - latent_indices = [record.latent.latent_index for record in records] - - narrowed_logits = torch.matmul(W_U, W_dec[:, latent_indices]) - - top_logits = torch.topk(narrowed_logits, k, dim=0).indices - - per_example_top_logits = top_logits.T - - decoded_top_logits = [] - - for record_index in range(len(records)): - decoded = tokenizer.batch_decode(per_example_top_logits[record_index]) - decoded_top_logits.append(decoded) - - records[record_index].top_logits = decoded - - -def unigram( - record: LatentRecord, k: int = 10, threshold: float = 0.0, negative_shift: int = 0 -): - avg_nonzero = [] - top_tokens = [] - - n_examples = floor(len(record.examples) * threshold) - - for example in record.examples[:n_examples]: - # Get the number of nonzero activations per example - avg_nonzero.append(np.count_nonzero(example.activations)) - - # Get the max activating token per example - index = np.argmax(example.activations) - negative_shift - - if index < 0: - continue - - top_tokens.append(example.tokens[index].item()) - - if len(set(top_tokens)) < k: - return set(top_tokens), np.mean(avg_nonzero) - - return -1, np.mean(avg_nonzero) - - -def cos(matrix, selected_latents=[0]): - a = matrix[:, selected_latents] - b = matrix - - a = F.normalize(a, p=2, dim=0) - b = F.normalize(b, p=2, dim=0) - - cos_sim = torch.mm(a.t(), b) - - return cos_sim - - -def get_neighbors(submodule_dict, latent_filter, k=10): - """ - Get the required latents for neighbor scoring. - - Returns: - neighbors_dict: Nested dictionary of modules -> neighbors -> indices, values - per_layer_latents (dict): A dictionary of latents per layer - """ - - neighbors_dict = defaultdict(dict) - per_layer_latents = {} - - for module_path, submodule in submodule_dict.items(): - selected_latents = latent_filter.get(module_path, False) - if not selected_latents: - continue - - W_D = submodule.ae.autoencoder._module.decoder.weight - cos_sim = cos(W_D, selected_latents=selected_latents) - top = torch.topk(cos_sim, k=k) - - top_indices = top.indices - top_values = top.values - - for i, (indices, values) in enumerate(zip(top_indices, top_values)): - neighbors_dict[module_path][i] = { - "indices": indices.tolist()[1:], - "values": values.tolist()[1:], - } - - per_layer_latents[module_path] = torch.unique(top_indices).tolist() - - return neighbors_dict, per_layer_latents From 12ec375668e2937cd6d081bff8d3b0aea237a06f Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:48:24 +0000 Subject: [PATCH 125/132] Not changing latent record in place --- delphi/latents/constructors.py | 4 ++-- delphi/latents/samplers.py | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index f4c309c7..3ae2d81c 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -111,7 +111,7 @@ def constructor( tokens: Float[Tensor, "batch seq"], all_data: Optional[dict[int, ActivationData]] = None, seed: int = 42, -): +) -> LatentRecord: cache_token_length = tokens.shape[1] # Get all positions where the latent is active @@ -165,7 +165,7 @@ def constructor( n_not_active=n_not_active, seed=seed, ) - + return record def neighbour_non_activation_windows( record: LatentRecord, diff --git a/delphi/latents/samplers.py 
b/delphi/latents/samplers.py index f4b448be..f1315857 100644 --- a/delphi/latents/samplers.py +++ b/delphi/latents/samplers.py @@ -93,10 +93,10 @@ def test( raise NotImplementedError("Activation sampling not implemented") -def sample( +def sampler( record: LatentRecord, cfg: ExperimentConfig, -): +) -> LatentRecord: examples = record.examples max_activation = record.max_activation _train = train( @@ -116,3 +116,4 @@ def sample( cfg.test_type, ) record.test = _test + return record From 1b78ce336d7e3fbdba183e1c5bb82aa657f143c0 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:48:41 +0000 Subject: [PATCH 126/132] Changing name --- delphi/latents/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/delphi/latents/__init__.py b/delphi/latents/__init__.py index b41d2026..1a91cd6a 100644 --- a/delphi/latents/__init__.py +++ b/delphi/latents/__init__.py @@ -13,8 +13,7 @@ NonActivatingExample, ) from .loader import LatentDataset -from .samplers import sample -from .stats import unigram +from .samplers import sampler __all__ = [ "LatentCache", @@ -28,6 +27,5 @@ "random_non_activating_windows", "neighbour_non_activation_windows", "constructor", - "sample", - "unigram", + "sampler", ] From d5b04805dd55a067cd1421c52c2a833154d115ba Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:49:17 +0000 Subject: [PATCH 127/132] adding neighbour type --- delphi/config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/delphi/config.py b/delphi/config.py index 39c0122f..79690028 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -35,6 +35,9 @@ class ExperimentConfig(Serializable): from pre-computed latent neighbours. They are still non-activating but have a higher chance of being similar to the activating examples.""" + neighbours_type: str = "co-occurrence" + """Type of neighbours to use. 
Only used if non_activating_source is 'neighbours'.""" + @dataclass class LatentConfig(Serializable): From cffcb332fb710d52e6b506717d6eb4cf5093438d Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:49:52 +0000 Subject: [PATCH 128/132] Not passing constructor or sampler anymore --- delphi/__main__.py | 36 +++----------------- delphi/latents/loader.py | 72 ++++++++++++++++++++++++++-------------- 2 files changed, 51 insertions(+), 57 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 552ce8e7..ee72c3c4 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -4,7 +4,7 @@ from functools import partial from glob import glob from pathlib import Path -from typing import Callable, cast +from typing import Callable import orjson import torch @@ -23,10 +23,7 @@ from delphi.config import CacheConfig, ExperimentConfig, LatentConfig, RunConfig from delphi.explainers import DefaultExplainer from delphi.latents import LatentCache, LatentDataset -from delphi.latents.constructors import constructor from delphi.latents.neighbours import NeighbourCalculator -from delphi.latents.samplers import sample -from delphi.latents.transforms import set_neighbours from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper from delphi.scorers import DetectionScorer, FuzzingScorer @@ -84,7 +81,6 @@ async def process_cache( run_cfg: RunConfig, experiment_cfg: ExperimentConfig, latents_path: Path, - neighbours_path: Path, explanations_path: Path, scores_path: Path, hookpoints: list[str], @@ -109,37 +105,14 @@ async def process_cache( latent_dict = { hook: latent_range for hook in hookpoints } # The latent range to explain - latent_dict = cast(dict[str, Tensor], latent_dict) - - example_constructor = partial( - constructor, - n_not_active=experiment_cfg.n_non_activating, - constructor_type=experiment_cfg.non_activating_source, - ctx_len=experiment_cfg.example_ctx_len, - max_examples=latent_cfg.max_examples, - ) - sampler = partial(sample, cfg=experiment_cfg) - if experiment_cfg.non_activating_source == "neighbours": - neighbours = {} - for hookpoint in hookpoints: - with open(neighbours_path / f"{hookpoint}.json", "r") as f: - neighbours[hookpoint] = json.load(f)["co-occurrence"] - transform = partial( - set_neighbours, - neighbours=neighbours, - threshold=0.0, - ) - else: - transform = None + dataset = LatentDataset( raw_dir=str(latents_path), - cfg=latent_cfg, + latent_cfg=latent_cfg, + experiment_cfg=experiment_cfg, modules=hookpoints, latents=latent_dict, tokenizer=tokenizer, - constructor=example_constructor, - sampler=sampler, - transform=transform, ) if run_cfg.explainer_provider == "offline": @@ -344,7 +317,6 @@ async def run( run_cfg, experiment_cfg, latents_path, - neighbours_path, explanations_path, scores_path, hookpoints, diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 3e1cc253..4c215271 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -16,8 +16,10 @@ load_tokenized_data, ) -from ..config import LatentConfig +from ..config import ExperimentConfig, LatentConfig +from .constructors import constructor from .latents import Latent, LatentRecord +from .samplers import sampler class ActivationData(NamedTuple): @@ -148,24 +150,26 @@ class LatentDataset: def __init__( self, raw_dir: str, - cfg: LatentConfig, + latent_cfg: LatentConfig, + experiment_cfg: ExperimentConfig, tokenizer: Optional[Callable] = None, modules: Optional[list[str]] = None, latents: Optional[dict[str, 
torch.Tensor]] = None, - constructor: Optional[Callable] = None, - sampler: Optional[Callable] = None, - neighbours: Optional[dict[str, dict[str, list[tuple[float, int]]]]] = None, ): """ Initialize a LatentDataset. Args: raw_dir: Directory containing raw latent data. - cfg: Configuration for latent processing. + latent_cfg: Configuration for latent processing. + experiment_cfg: Configuration for example creation + and sampling. + tokenizer: Tokenizer used to tokenize the data. modules: list of module names to include. latents: Dictionary of latents per module. """ - self.cfg = cfg + self.latent_config = latent_cfg + self.experiment_config = experiment_cfg self.buffers: list[TensorBuffer] = [] self.all_data: dict[str, dict[int, ActivationData] | None] = {} self.tokens = None @@ -187,14 +191,18 @@ def __init__( else: self.tokenizer = tokenizer self.cache_config = cache_config - self.constructor = constructor - self.sampler = sampler - self.neighbours = neighbours - # TODO: is it possible to do this without loading all data? - if self.constructor is not None: - if self.constructor.keywords["constructor_type"] == "neighbours": - self.all_data = self._load_all_data(raw_dir, self.modules) + if self.experiment_config.non_activating_source == "neighbours": + # path is always going to end with /latents + split_path = raw_dir.split("/")[:-1] + neighbours_path = "/".join(split_path) + "/neighbours" + self.neighbours = self.load_neighbours( + neighbours_path, self.experiment_config.neighbours_type + ) + # TODO: is it possible to do this without loading all data? + self.all_data = self._load_all_data(raw_dir, self.modules) + else: + self.neighbours = None self.load_tokens() @@ -220,6 +228,13 @@ def load_tokens(self): ) return self.tokens + def load_neighbours(self, neighbours_path: str, neighbours_type: str): + neighbours = {} + for hookpoint in self.modules: + with open(neighbours_path + f"{hookpoint}.json", "r") as f: + neighbours[hookpoint] = json.load(f)[neighbours_type] + return neighbours + def _edges(self, raw_dir: str, module: str) -> list[tuple[int, int]]: module_dir = Path(raw_dir) / module safetensor_files = [f for f in module_dir.glob("*.safetensors")] @@ -244,7 +259,7 @@ def _build(self, raw_dir: str): for start, end in edges: path = f"{raw_dir}/{module}/{start}_{end}.safetensors" tensor_buffer = TensorBuffer( - path, module, min_examples=self.cfg.min_examples + path, module, min_examples=self.latent_config.min_examples ) if self.tokens is None: self.tokens = tensor_buffer.tokens @@ -284,7 +299,10 @@ def _build_selected( # Adjust end by one as the path avoids overlap path = f"{raw_dir}/{module}/{start}_{end-1}.safetensors" tensor_buffer = TensorBuffer( - path, module, _selected_latents, min_examples=self.cfg.min_examples + path, + module, + _selected_latents, + min_examples=self.latent_config.min_examples, ) if self.tokens is None: self.tokens = tensor_buffer.tokens @@ -363,6 +381,8 @@ async def _aprocess_latent(self, latent_data: LatentData) -> LatentRecord: Returns: Optional[LatentRecord]: Processed latent record or None. 
""" + if self.tokens is None: + raise ValueError("Tokens are not loaded") record = LatentRecord(latent_data.latent) if self.neighbours is not None: record.set_neighbours( @@ -370,13 +390,15 @@ async def _aprocess_latent(self, latent_data: LatentData) -> LatentRecord: str(latent_data.latent.latent_index) ], ) - if self.constructor is not None: - self.constructor( - record=record, - activation_data=latent_data.activation_data, - all_data=self.all_data[latent_data.module], - tokens=self.tokens, - ) - if self.sampler is not None: - self.sampler(record) + record = constructor( + record=record, + activation_data=latent_data.activation_data, + n_not_active=self.experiment_config.n_non_activating, + constructor_type=self.experiment_config.non_activating_source, + ctx_len=self.experiment_config.example_ctx_len, + max_examples=self.latent_config.max_examples, + tokens=self.tokens, + all_data=self.all_data[latent_data.module], + ) + record = sampler(record, self.experiment_config) return record From 1cdfd01d045cdad2b772124465c4e0f39153a73c Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 17:55:12 +0000 Subject: [PATCH 129/132] Fixing circular imports --- delphi/latents/constructors.py | 8 +++- delphi/latents/latents.py | 72 +++++++++++++++++++++++----------- delphi/latents/loader.py | 31 +-------------- 3 files changed, 57 insertions(+), 54 deletions(-) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index 3ae2d81c..e0d77827 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -4,8 +4,12 @@ from jaxtyping import Float from torch import Tensor -from .latents import ActivatingExample, LatentRecord, NonActivatingExample -from .loader import ActivationData +from .latents import ( + ActivatingExample, + ActivationData, + LatentRecord, + NonActivatingExample, +) def prepare_non_activating_examples( diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index b3c328a8..560e73e1 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -1,5 +1,5 @@ from dataclasses import dataclass, field -from typing import Optional +from typing import NamedTuple, Optional import blobfile as bf import orjson @@ -8,6 +8,54 @@ from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +@dataclass +class Latent: + """ + A latent extracted from a model's activations. + """ + + module_name: str + """The module name associated with the latent.""" + + latent_index: int + """The index of the latent within the module.""" + + def __repr__(self) -> str: + """ + Return a string representation of the latent. + + Returns: + str: A string representation of the latent. + """ + return f"{self.module_name}_latent{self.latent_index}" + + +class ActivationData(NamedTuple): + """ + Represents the activation data for a latent. + """ + + locations: Float[Tensor, "n_examples 2"] + """Tensor of latent locations.""" + + activations: Float[Tensor, "n_examples"] + """Tensor of latent activations.""" + + +class LatentData(NamedTuple): + """ + Represents the output of a TensorBuffer. + """ + + latent: Latent + """The latent associated with this output.""" + + module: str + """The module associated with this output.""" + + activation_data: ActivationData + """The activation data for this latent.""" + @dataclass class Neighbour: distance: float @@ -63,28 +111,6 @@ class NonActivatingExample(Example): """ -@dataclass -class Latent: - """ - A latent extracted from a model's activations. 
- """ - - module_name: str - """The module name associated with the latent.""" - - latent_index: int - """The index of the latent within the module.""" - - def __repr__(self) -> str: - """ - Return a string representation of the latent. - - Returns: - str: A string representation of the latent. - """ - return f"{self.module_name}_latent{self.latent_index}" - - @dataclass class LatentRecord: """ diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 4c215271..ecd67125 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -3,7 +3,7 @@ import os from dataclasses import dataclass from pathlib import Path -from typing import Callable, NamedTuple, Optional +from typing import Callable, Optional import numpy as np import torch @@ -18,37 +18,10 @@ from ..config import ExperimentConfig, LatentConfig from .constructors import constructor -from .latents import Latent, LatentRecord +from .latents import ActivationData, Latent, LatentData, LatentRecord from .samplers import sampler -class ActivationData(NamedTuple): - """ - Represents the activation data for a latent. - """ - - locations: Float[Tensor, "n_examples 2"] - """Tensor of latent locations.""" - - activations: Float[Tensor, "n_examples"] - """Tensor of latent activations.""" - - -class LatentData(NamedTuple): - """ - Represents the output of a TensorBuffer. - """ - - latent: Latent - """The latent associated with this output.""" - - module: str - """The module associated with this output.""" - - activation_data: ActivationData - """The activation data for this latent.""" - - @dataclass class TensorBuffer: """ From 614473b95dbc016fcec70439df2dc7bbcb0337fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Feb 2025 17:58:17 +0000 Subject: [PATCH 130/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/latents/constructors.py | 1 + delphi/latents/latents.py | 1 + 2 files changed, 2 insertions(+) diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index e0d77827..b8e060d0 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -171,6 +171,7 @@ def constructor( ) return record + def neighbour_non_activation_windows( record: LatentRecord, not_active_mask: Float[Tensor, "windows"], diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 560e73e1..4cafbf10 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -56,6 +56,7 @@ class LatentData(NamedTuple): activation_data: ActivationData """The activation data for this latent.""" + @dataclass class Neighbour: distance: float From 369eb3b7f4df64848ba854356184055e04e690c1 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Wed, 19 Feb 2025 18:38:19 +0000 Subject: [PATCH 131/132] Adding encoder/decoder similarity neighbours --- delphi/__main__.py | 36 +++++++++++++++++----- delphi/config.py | 4 ++- delphi/latents/constructors.py | 1 - delphi/latents/loader.py | 4 +-- delphi/latents/neighbours.py | 37 ++++++++++------------- delphi/sparse_coders/custom/gemmascope.py | 4 +-- delphi/sparse_coders/load_sparsify.py | 7 ++--- delphi/sparse_coders/sparse_model.py | 10 +++--- 8 files changed, 59 insertions(+), 44 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index ee72c3c4..ac417901 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -27,7 +27,7 @@ from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, 
Pipeline, process_wrapper from delphi.scorers import DetectionScorer, FuzzingScorer -from delphi.sparse_coders import load_hooks_sparse_coders +from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import load_tokenized_data @@ -57,23 +57,39 @@ def load_artifacts(run_cfg: RunConfig): async def create_neighbours( + run_cfg: RunConfig, latents_path: Path, neighbours_path: Path, hookpoints: list[str], + experiment_cfg: ExperimentConfig, ): """ Creates a neighbours file for the given hookpoints. """ neighbours_path.mkdir(parents=True, exist_ok=True) + if experiment_cfg.neighbours_type != "co-occurrence": + saes = load_sparse_coders(run_cfg, device="cuda") + for hookpoint in hookpoints: - neighbour_calculator = NeighbourCalculator( - cache_dir=latents_path / hookpoint, number_of_neighbours=100 - ) - neighbour_calculator.populate_neighbour_cache(["co-occurrence"]) + if experiment_cfg.neighbours_type == "co-occurrence": + neighbour_calculator = NeighbourCalculator( + cache_dir=latents_path / hookpoint, number_of_neighbours=100 + ) + + elif experiment_cfg.neighbours_type == "decoder_similarity": + + neighbour_calculator = NeighbourCalculator( + autoencoder=saes[hookpoint], number_of_neighbours=100 + ) - neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}.json") + elif experiment_cfg.neighbours_type == "encoder_similarity": + neighbour_calculator = NeighbourCalculator( + autoencoder=saes[hookpoint], number_of_neighbours=100 + ) + neighbour_calculator.populate_neighbour_cache(experiment_cfg.neighbours_type) + neighbour_calculator.save_neighbour_cache(f"{neighbours_path}/{hookpoint}") async def process_cache( @@ -304,7 +320,13 @@ async def run( not glob(str(neighbours_path / ".*")) + glob(str(neighbours_path / "*")) or "neighbours" in run_cfg.overwrite ): - await create_neighbours(latents_path, neighbours_path, hookpoints) + await create_neighbours( + run_cfg, + latents_path, + neighbours_path, + hookpoints, + experiment_cfg, + ) else: print(f"Files found in {neighbours_path}, skipping...") diff --git a/delphi/config.py b/delphi/config.py index 79690028..4ea6cf73 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -35,7 +35,9 @@ class ExperimentConfig(Serializable): from pre-computed latent neighbours. They are still non-activating but have a higher chance of being similar to the activating examples.""" - neighbours_type: str = "co-occurrence" + neighbours_type: Literal[ + "co-occurrence", "decoder_similarity", "encoder_similarity" + ] = "co-occurrence" """Type of neighbours to use. 
Only used if non_activating_source is 'neighbours'.""" diff --git a/delphi/latents/constructors.py b/delphi/latents/constructors.py index e0d77827..1f781cd6 100644 --- a/delphi/latents/constructors.py +++ b/delphi/latents/constructors.py @@ -214,7 +214,6 @@ def neighbour_non_activation_windows( break # get the locations of the neighbour if neighbour.latent_index not in all_data: - print(f"Neighbour {neighbour.latent_index} not found in all_data") continue locations = all_data[neighbour.latent_index].locations activations = all_data[neighbour.latent_index].activations diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index ecd67125..4f261ceb 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -204,8 +204,8 @@ def load_tokens(self): def load_neighbours(self, neighbours_path: str, neighbours_type: str): neighbours = {} for hookpoint in self.modules: - with open(neighbours_path + f"{hookpoint}.json", "r") as f: - neighbours[hookpoint] = json.load(f)[neighbours_type] + with open(neighbours_path + f"/{hookpoint}-{neighbours_type}.json", "r") as f: + neighbours[hookpoint] = json.load(f) return neighbours def _edges(self, raw_dir: str, module: str) -> list[tuple[int, int]]: diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 219dad00..c76e7e4b 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -24,7 +24,7 @@ def __init__( cache_dir: Optional[Path] = None, autoencoder: Optional[nn.Module] = None, number_of_neighbours: int = 10, - neighbour_cache: Optional[dict[str, dict[int, list[tuple[int, float]]]]] = None, + neighbour_cache: Optional[dict[int, list[tuple[int, float]]]] = None, ): """ Initialize a NeighbourCalculator. @@ -40,17 +40,16 @@ def __init__( self.autoencoder = autoencoder # self.residual_stream_record = residual_stream_record self.number_of_neighbours = number_of_neighbours - # load the neighbour cache from the path if neighbour_cache is not None: self.neighbour_cache = neighbour_cache else: # dictionary to cache computed neighbour lists - self.neighbour_cache: dict[str, dict[int, list[tuple[int, float]]]] = {} + self.neighbour_cache: dict[int, list[tuple[int, float]]] = {} def _compute_neighbour_list( self, - method: Literal["similarity_encoder", "similarity_decoder", "co-occurrence"], + method: Literal["encoder_similarity", "decoder_similarity", "co-occurrence"], ) -> None: """ Compute complete neighbour lists using specified method. @@ -58,21 +57,20 @@ def _compute_neighbour_list( Args: method (str): One of 'similarity', 'correlation', or 'co-occurrence' """ - if method == "similarity_encoder": - self.neighbour_cache[method] = self._compute_similarity_neighbours( - "encoder" - ) - elif method == "similarity_decoder": - self.neighbour_cache[method] = self._compute_similarity_neighbours( - "decoder" - ) + if method == "encoder_similarity": + self.method = "encoder_similarity" + self.neighbour_cache = self._compute_similarity_neighbours("encoder") + elif method == "decoder_similarity": + self.method = "decoder_similarity" + self.neighbour_cache = self._compute_similarity_neighbours("decoder") elif method == "co-occurrence": - self.neighbour_cache[method] = self._compute_cooccurrence_neighbours() + self.method = "co-occurrence" + self.neighbour_cache = self._compute_cooccurrence_neighbours() else: raise ValueError( - f"Unknown method: {method}. Use 'similarity_encoder'," - "'similarity_decoder', or 'co-occurrence'" + f"Unknown method: {method}. 
Use 'encoder similarity'," + "'decoder similarity', or 'co-occurrence'" ) def _compute_similarity_neighbours( @@ -240,21 +238,18 @@ def compute_jaccard(cooc_matrix): def populate_neighbour_cache( self, - methods: list[ - Literal["similarity_encoder", "similarity_decoder", "co-occurrence"] - ], + method: Literal["encoder_similarity", "decoder_similarity", "co-occurrence"], ) -> None: """ Populate the neighbour cache with the computed neighbour lists """ - for method in methods: - self._compute_neighbour_list(method) + self._compute_neighbour_list(method) def save_neighbour_cache(self, path: str) -> None: """ Save the neighbour cache to the path as a json file """ - with open(path, "w") as f: + with open(path + f"-{self.method}.json", "w") as f: json.dump(self.neighbour_cache, f) def load_neighbour_cache(self, path: str) -> dict[str, dict[int, list[int]]]: diff --git a/delphi/sparse_coders/custom/gemmascope.py b/delphi/sparse_coders/custom/gemmascope.py index 8a66f20c..27511d58 100644 --- a/delphi/sparse_coders/custom/gemmascope.py +++ b/delphi/sparse_coders/custom/gemmascope.py @@ -13,7 +13,7 @@ def load_gemma_autoencoders( sizes: list[str], type: str, dtype: torch.dtype = torch.bfloat16, - device: torch.device = torch.device("cuda"), + device: str | torch.device = torch.device("cuda"), ) -> dict[str, nn.Module]: saes = {} @@ -45,7 +45,7 @@ def load_gemma_hooks( sizes: list[str], type: str, dtype: torch.dtype = torch.bfloat16, - device: torch.device = torch.device("cuda"), + device: str | torch.device = torch.device("cuda"), ): saes = load_gemma_autoencoders( model_path, diff --git a/delphi/sparse_coders/load_sparsify.py b/delphi/sparse_coders/load_sparsify.py index 7c58c621..b9528cb4 100644 --- a/delphi/sparse_coders/load_sparsify.py +++ b/delphi/sparse_coders/load_sparsify.py @@ -41,10 +41,9 @@ def resolve_path(model: PreTrainedModel, path_segments: list[str]) -> list[str] def load_sparsify_sparse_coders( - model: PreTrainedModel, name: str, hookpoints: list[str], - device: str | torch.device | None = None, + device: str | torch.device, compile: bool = False, ) -> dict[str, Sae]: """ @@ -62,8 +61,6 @@ def load_sparsify_sparse_coders( Returns: dict[str, Any]: A dictionary mapping hookpoints to sparse models. """ - if device is None: - device = model.device or "cpu" # Load the sparse models sparse_model_dict = {} @@ -112,8 +109,8 @@ def load_sparsify_hooks( Returns: dict[str, Callable]: A dictionary mapping hookpoints to encode functions. 
""" + device = model.device or "cpu" sparse_model_dict = load_sparsify_sparse_coders( - model, name, hookpoints, device, diff --git a/delphi/sparse_coders/sparse_model.py b/delphi/sparse_coders/sparse_model.py index 5b77feae..af2e9a12 100644 --- a/delphi/sparse_coders/sparse_model.py +++ b/delphi/sparse_coders/sparse_model.py @@ -1,5 +1,6 @@ from typing import Callable +import torch import torch.nn as nn from transformers import PreTrainedModel @@ -69,8 +70,8 @@ def load_hooks_sparse_coders( def load_sparse_coders( - model: PreTrainedModel, run_cfg: RunConfig, + device: str | torch.device, compile: bool = False, ) -> dict[str, nn.Module] | dict[str, Sae]: """ @@ -87,10 +88,10 @@ def load_sparse_coders( # Add SAE hooks to the model if "gemma" not in run_cfg.sparse_model: hookpoint_to_sparse_model = load_sparsify_sparse_coders( - model, run_cfg.sparse_model, run_cfg.hookpoints, - compile=compile, + device, + compile, ) else: # model path will always be of the form google/gemma-scope--pt-/ @@ -118,8 +119,7 @@ def load_sparse_coders( average_l0s=l0s, sizes=sae_sizes, type=type, - dtype=model.dtype, - device=model.device, + device=device, ) return hookpoint_to_sparse_model From 2c939693d13d26a46ad83e156c2ca405f38a6440 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 19 Feb 2025 18:39:12 +0000 Subject: [PATCH 132/132] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/latents/loader.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 4f261ceb..06eba648 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -204,7 +204,9 @@ def load_tokens(self): def load_neighbours(self, neighbours_path: str, neighbours_type: str): neighbours = {} for hookpoint in self.modules: - with open(neighbours_path + f"/{hookpoint}-{neighbours_type}.json", "r") as f: + with open( + neighbours_path + f"/{hookpoint}-{neighbours_type}.json", "r" + ) as f: neighbours[hookpoint] = json.load(f) return neighbours