diff --git a/delphi/__main__.py b/delphi/__main__.py
index 4186b284..ff547c92 100644
--- a/delphi/__main__.py
+++ b/delphi/__main__.py
@@ -278,7 +278,9 @@ def populate_cache(
     cache.run(cache_cfg.n_tokens, tokens)
 
     # Save firing counts to the run-specific log directory
-    cache.save_firing_counts()
+    if run_cfg.verbose:
+        cache.save_firing_counts()
+    cache.generate_statistics_cache()
 
     cache.save_splits(
         # Split the activation and location indices into different files to make
diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py
index f9bf52bf..00e082e8 100644
--- a/delphi/latents/cache.py
+++ b/delphi/latents/cache.py
@@ -1,11 +1,12 @@
 import json
 from collections import defaultdict
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable
 
 import numpy as np
 import torch
-from jaxtyping import Float
+from jaxtyping import Float, Int
 from safetensors.numpy import save_file
 from torch import Tensor
 from tqdm import tqdm
@@ -386,6 +387,24 @@ def save_splits(self, n_splits: int, save_dir: Path, save_tokens: bool = True):
 
             save_file(split_data, output_file)
 
+    def generate_statistics_cache(self):
+        """
+        Print statistics (number of dead features, number of single token features)
+        to the console.
+        """
+        assert self.width is not None, "Width must be set before generating statistics"
+        print("Feature statistics:")
+        # Token frequency
+        for module_path in self.cache.latent_locations.keys():
+            print(f"# Module: {module_path}")
+            generate_statistics_cache(
+                self.cache.tokens[module_path],
+                self.cache.latent_locations[module_path],
+                self.cache.latent_activations[module_path],
+                self.width,
+                verbose=True,
+            )
+
     def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str):
         """
         Save the configuration for the cached latents.
@@ -414,3 +433,147 @@ def save_firing_counts(self):
         log_path.parent.mkdir(parents=True, exist_ok=True)
 
         torch.save(self.hookpoint_firing_counts, log_path)
+
+
+@dataclass
+class CacheStatistics:
+    frac_alive: float
+    frac_fired_1pct: float
+    frac_fired_10pct: float
+    frac_weak_single_token: float
+    frac_strong_single_token: float
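+    # frac_alive and the frac_fired_* fields are fractions of the full
+    # autoencoder width; the two single-token fields are fractions of the
+    # latents that are alive (fired at least once).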
+
+
+@torch.inference_mode()
+def generate_statistics_cache(
+    tokens: Int[Tensor, "batch sequence"],
+    latent_locations: Int[Tensor, "n_activations 3"],
+    activations: Float[Tensor, "n_activations"],
+    width: int,
+    verbose: bool = False,
+) -> CacheStatistics:
+    """Generate global statistics for the cache.
+
+    Args:
+        tokens (Int[Tensor, "batch sequence"]): Tokens used to generate the cache.
+        latent_locations (Int[Tensor, "n_activations 3"]): Indices of the latent
+            activations, corresponding to `tokens`.
+        activations (Float[Tensor, "n_activations"]): Activations of the latents,
+            as stored by the cache.
+        width (int): Width of the cache to test.
+        verbose (bool, optional): Print results to stdout. Defaults to False.
+
+    Returns:
+        CacheStatistics: The computed statistics.
+    """
+    total_n_tokens = tokens.shape[0] * tokens.shape[1]
+
+    # Each row of latent_locations is (batch index, sequence index, latent index)
+    latent_locations, latents = latent_locations[:, :2], latent_locations[:, 2]
+
+    # torch always sorts for unique, so we might as well do it
+    sorted_latents, latent_indices = latents.sort()
+    sorted_activations = activations[latent_indices]
+    # Look up the token id at each (batch, seq) activation location
+    sorted_locations = latent_locations[latent_indices]
+    sorted_tokens = tokens[sorted_locations[:, 0], sorted_locations[:, 1]]
+
+    unique_latents, counts = torch.unique_consecutive(
+        sorted_latents, return_counts=True
+    )
+
+    # How many unique latents ever activated on the cached tokens
+    num_alive = counts.shape[0]
+    fraction_alive = num_alive / width
+    if verbose:
+        print(f"Fraction of latents alive: {fraction_alive:%}")
+
+    # Compute firing densities of the alive latents
+    densities = counts / total_n_tokens
+    # How many fired on more than 1% of the tokens
+    one_percent = (densities > 0.01).sum().item() / width
+    # How many fired on more than 10% of the tokens
+    ten_percent = (densities > 0.1).sum().item() / width
+    if verbose:
+        print(f"Fraction of latents fired more than 1% of the time: {one_percent:%}")
+        print(f"Fraction of latents fired more than 10% of the time: {ten_percent:%}")
+
+    # Try to estimate simple (single-token) feature frequency
+    split_indices = torch.cumsum(counts, dim=0)
+    activation_splits = torch.tensor_split(sorted_activations, split_indices[:-1])
+    token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1])
+
+    # This might take a while and we may only care about the statistics,
+    # but for now we do the full loop
+    num_single_token_features = 0
+    maybe_single_token_features = 0
+    for _latent_idx, activation_group, token_group in zip(
+        unique_latents, activation_splits, token_splits
+    ):
+        maybe_single_token, single_token = check_single_feature(
+            activation_group, token_group
+        )
+        num_single_token_features += single_token
+        maybe_single_token_features += maybe_single_token
+
+    single_token_fraction = maybe_single_token_features / num_alive
+    strong_token_fraction = num_single_token_features / num_alive
+    if verbose:
+        print(f"Fraction of weak single token latents: {single_token_fraction:%}")
+        print(f"Fraction of strong single token latents: {strong_token_fraction:%}")
+
+    return CacheStatistics(
+        frac_alive=fraction_alive,
+        frac_fired_1pct=one_percent,
+        frac_fired_10pct=ten_percent,
+        frac_weak_single_token=single_token_fraction,
+        frac_strong_single_token=strong_token_fraction,
+    )
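+
+
+# Heuristic: a latent is treated as a single-token feature when its strongest
+# activations overwhelmingly land on one token id. Two tests run below: the
+# top (up to) 50 activations must share a token for more than 90% of them, and
+# a random sample of (up to) 100 from the top half must share one for more
+# than 75%.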
+@torch.inference_mode()
+def check_single_feature(activation_group, token_group):
+    # Sort activations in descending order so the strongest come first
+    sorted_activation_group, sorted_indices = activation_group.sort(descending=True)
+    sorted_token_group = token_group[sorted_indices]
+
+    number_activations = sorted_activation_group.shape[0]
+    # Take the first 50 elements if possible
+    num_elements = min(50, number_activations)
+
+    wanted_tokens = sorted_token_group[:num_elements]
+
+    # Check how many of them are exactly the same token
+    _, unique_counts = torch.unique(wanted_tokens, return_counts=True)
+
+    max_count = unique_counts.max()
+    maybe_single_token = False
+    if max_count > 0.9 * num_elements:
+        # Single token feature
+        maybe_single_token = True
+
+    # Randomly sample up to 100 activations from the top 50%
+    n_top = max(1, int(number_activations * 0.5))
+    num_samples = min(100, n_top)
+    top_50_percent = sorted_token_group[:n_top]
+    sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples]
+    sampled_tokens = top_50_percent[sampled_indices]
+    _, unique_counts = torch.unique(sampled_tokens, return_counts=True)
+
+    max_count = unique_counts.max()
+    other_maybe_single_token = max_count > 0.75 * num_samples
+
+    # Return (weak, strong): strong if both tests agree, weak if only one fires
+    if other_maybe_single_token and maybe_single_token:
+        return 0, 1
+    elif maybe_single_token or other_maybe_single_token:
+        return 1, 0
+    else:
+        return 0, 0
diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py
index dcab0d50..a3d55e58 100644
--- a/delphi/latents/loader.py
+++ b/delphi/latents/loader.py
@@ -86,6 +86,12 @@ def load(
         Float[Tensor, "activations"],
         Float[Tensor, "batch seq"] | None,
     ]:
+        """Load the tensor buffer's data.
+
+        Returns:
+            Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations,
+                and tokens (if present in the cache).
+        """
         split_data = load_file(self.path)
         first_latent = int(self.path.split("/")[-1].split("_")[0])
         activations = torch.tensor(split_data["activations"])
diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py
index bd0ca176..81776ccc 100644
--- a/delphi/log/result_analysis.py
+++ b/delphi/log/result_analysis.py
@@ -9,8 +9,8 @@
 
 def import_plotly():
     try:
-        import plotly.express as px
-        import plotly.io as pio
+        import plotly.express as px  # type: ignore
+        import plotly.io as pio  # type: ignore
     except ImportError:
         raise ImportError(
             "Plotly is not installed.\n"
diff --git a/delphi/utils.py b/delphi/utils.py
index 2b278cb9..a148d238 100644
--- a/delphi/utils.py
+++ b/delphi/utils.py
@@ -41,4 +41,4 @@ def assert_type(typ: Type[T], obj: Any) -> T:
     if not isinstance(obj, typ):
         raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
 
-    return cast(typ, obj)
+    return cast(typ, obj)  # type: ignore