4 changes: 3 additions & 1 deletion delphi/__main__.py
@@ -278,7 +278,9 @@ def populate_cache(
     cache.run(cache_cfg.n_tokens, tokens)
 
     # Save firing counts to the run-specific log directory
-    cache.save_firing_counts()
+    if run_cfg.verbose:
+        cache.save_firing_counts()
+    cache.generate_statistics_cache()
 
     cache.save_splits(
         # Split the activation and location indices into different files to make
150 changes: 149 additions & 1 deletion delphi/latents/cache.py
@@ -1,11 +1,12 @@
 import json
 from collections import defaultdict
+from dataclasses import dataclass
 from pathlib import Path
 from typing import Callable
 
 import numpy as np
 import torch
-from jaxtyping import Float
+from jaxtyping import Float, Int
 from safetensors.numpy import save_file
 from torch import Tensor
 from tqdm import tqdm
@@ -386,6 +387,24 @@ def save_splits(self, n_splits: int, save_dir: Path, save_tokens: bool = True):

             save_file(split_data, output_file)
 
+    def generate_statistics_cache(self):
+        """
+        Print cache statistics (fraction of alive latents, firing densities,
+        and single-token latent fractions) to the console.
+        """
+        assert self.width is not None, "Width must be set before generating statistics"
+        print("Feature statistics:")
+        # Report statistics for each cached module
+        for module_path in self.cache.latent_locations:
+            print(f"# Module: {module_path}")
+            generate_statistics_cache(
+                self.cache.tokens[module_path],
+                self.cache.latent_locations[module_path],
+                self.cache.latent_activations[module_path],
+                self.width,
+                verbose=True,
+            )
+
     def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str):
         """
         Save the configuration for the cached latents.
@@ -414,3 +433,132 @@ def save_firing_counts(self):
 
         log_path.parent.mkdir(parents=True, exist_ok=True)
         torch.save(self.hookpoint_firing_counts, log_path)
+
+
+@dataclass
+class CacheStatistics:
+    frac_alive: float
+    frac_fired_1pct: float
+    frac_fired_10pct: float
+    frac_weak_single_token: float
+    frac_strong_single_token: float
+
+
+@torch.inference_mode()
+def generate_statistics_cache(
+    tokens: Int[Tensor, "batch sequence"],
+    latent_locations: Int[Tensor, "n_activations 3"],
+    activations: Float[Tensor, "n_activations"],
+    width: int,
+    verbose: bool = False,
+) -> CacheStatistics:
+    """Generate global statistics for the cache.
+
+    Args:
+        tokens (Int[Tensor, "batch sequence"]): Tokens used to generate the cache.
+        latent_locations (Int[Tensor, "n_activations 3"]): Indices of the latent
+            activations, corresponding to `tokens`.
+        activations (Float[Tensor, "n_activations"]): Activations of the latents,
+            as stored by the cache.
+        width (int): Width of the cache to test.
+        verbose (bool, optional): Print results to stdout. Defaults to False.
+    Returns:
+        CacheStatistics: the computed statistics.
+    """
+    total_n_tokens = tokens.shape[0] * tokens.shape[1]
+
+    latent_locations, latents = latent_locations[:, :2], latent_locations[:, 2]
+
+    # torch.unique sorts internally anyway, so sort once and reuse the order
+    sorted_latents, latent_indices = latents.sort()
+    sorted_activations = activations[latent_indices]
+    sorted_tokens = tokens[latent_locations[latent_indices, 0], latent_locations[latent_indices, 1]]
+
+    unique_latents, counts = torch.unique_consecutive(
+        sorted_latents, return_counts=True
+    )
+
+    # How many unique latents ever activated on the cached tokens
+    num_alive = counts.shape[0]
+    fraction_alive = num_alive / width
+    if verbose:
+        print(f"Fraction of latents alive: {fraction_alive:%}")
+    # Compute firing densities of the alive latents
+    densities = counts / total_n_tokens
+
+    # How many latents fired on more than 1% of tokens
+    one_percent = ((densities > 0.01).sum() / width).item()
+    # How many latents fired on more than 10% of tokens
+    ten_percent = ((densities > 0.1).sum() / width).item()
+    if verbose:
+        print(f"Fraction of latents fired more than 1% of the time: {one_percent:%}")
+        print(f"Fraction of latents fired more than 10% of the time: {ten_percent:%}")
+    # Estimate the frequency of simple (single-token) features
+    split_indices = torch.cumsum(counts, dim=0)
+    activation_splits = torch.tensor_split(sorted_activations, split_indices[:-1])
+    token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1])
+
+    # Looping over every alive latent might take a while, and we may only care
+    # about aggregate statistics, but for now we do the full loop
+    num_single_token_features = 0
+    maybe_single_token_features = 0
+    for _latent_idx, activation_group, token_group in zip(
+        unique_latents, activation_splits, token_splits
+    ):
+        maybe_single_token, single_token = check_single_feature(
+            activation_group, token_group
+        )
+        num_single_token_features += single_token
+        maybe_single_token_features += maybe_single_token
+
+    single_token_fraction = maybe_single_token_features / num_alive
+    strong_token_fraction = num_single_token_features / num_alive
+    if verbose:
+        print(f"Fraction of weak single token latents: {single_token_fraction:%}")
+        print(f"Fraction of strong single token latents: {strong_token_fraction:%}")
+
+    return CacheStatistics(
+        frac_alive=fraction_alive,
+        frac_fired_1pct=one_percent,
+        frac_fired_10pct=ten_percent,
+        frac_weak_single_token=single_token_fraction,
+        frac_strong_single_token=strong_token_fraction,
+    )
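
For consumers that want the numbers rather than the printout, the returned dataclass can be used directly. A minimal sketch with synthetic inputs (all shapes, sizes, and the 0.5 threshold here are illustrative, not from the PR):

import torch

tokens = torch.randint(0, 50_000, (64, 256))  # hypothetical cached tokens
locations = torch.stack(
    [torch.randint(0, 64, (1_000,)),      # batch indices
     torch.randint(0, 256, (1_000,)),     # sequence indices
     torch.randint(0, 16_384, (1_000,))], # latent indices
    dim=1,
)
activations = torch.rand(1_000)
stats = generate_statistics_cache(tokens, locations, activations, width=16_384)
if stats.frac_alive < 0.5:
    print(f"warning: only {stats.frac_alive:%} of latents ever fired")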


+@torch.inference_mode()
+def check_single_feature(activation_group: Tensor, token_group: Tensor) -> tuple[int, int]:
+    sorted_activation_group, sorted_indices = activation_group.sort(descending=True)
+    sorted_token_group = token_group[sorted_indices]
+
+    number_activations = sorted_activation_group.shape[0]
+    # Take up to the 50 strongest activations
+    num_elements = min(50, number_activations)
+
+    wanted_tokens = sorted_token_group[:num_elements]
+
+    # Check how many of them are exactly the same token
+    _, unique_counts = torch.unique(wanted_tokens, return_counts=True)
+
+    max_count = unique_counts.max()
+    maybe_single_token = False
+    if max_count > 0.9 * num_elements:
+        # >90% of the strongest activations fire on a single token
+        maybe_single_token = True
+
+    # Randomly sample up to 100 tokens from the top 50% of activations
+    n_top = max(1, int(number_activations * 0.5))
+    num_samples = min(100, n_top)
+    top_50_percent = sorted_token_group[:n_top]
+    sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples]
+    sampled_tokens = top_50_percent[sampled_indices]
+    _, unique_counts = torch.unique(sampled_tokens, return_counts=True)
+
+    max_count = unique_counts.max()
+    other_maybe_single_token = max_count > 0.75 * num_samples
+    if other_maybe_single_token and maybe_single_token:
+        return 0, 1  # strong single-token feature
+    elif maybe_single_token or other_maybe_single_token:
+        return 1, 0  # weak single-token feature
+    else:
+        return 0, 0  # not a single-token feature
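
As a sanity check of the heuristic, a hypothetical synthetic case (assuming the duplicate counting above): a latent that fires almost exclusively on one token id should pass both the top-50 check and the sampled top-50% check and be counted as strong.

import torch

acts = torch.rand(200)
toks = torch.full((200,), 42)       # fires on token id 42 ...
toks[:3] = torch.tensor([7, 8, 9])  # ... with three exceptions
weak, strong = check_single_feature(acts, toks)
assert (weak, strong) == (0, 1)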
6 changes: 6 additions & 0 deletions delphi/latents/loader.py
@@ -86,6 +86,12 @@ def load(
         Float[Tensor, "activations"],
         Float[Tensor, "batch seq"] | None,
     ]:
+        """Load the tensor buffer's data.
+
+        Returns:
+            Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations,
+                and tokens (if present in the cache).
+        """
         split_data = load_file(self.path)
         first_latent = int(self.path.split("/")[-1].split("_")[0])
         activations = torch.tensor(split_data["activations"])
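A hypothetical call site for the documented contract, assuming a buffer instance constructed elsewhere in the loader (the `buffer` name is illustrative):

locations, activations, tokens = buffer.load()
if tokens is None:
    print("this cache split was saved without tokens")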
4 changes: 2 additions & 2 deletions delphi/log/result_analysis.py
@@ -9,8 +9,8 @@

 def import_plotly():
     try:
-        import plotly.express as px
-        import plotly.io as pio
+        import plotly.express as px  # type: ignore
+        import plotly.io as pio  # type: ignore
     except ImportError:
         raise ImportError(
             "Plotly is not installed.\n"
2 changes: 1 addition & 1 deletion delphi/utils.py
@@ -41,4 +41,4 @@ def assert_type(typ: Type[T], obj: Any) -> T:
     if not isinstance(obj, typ):
         raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}")
 
-    return cast(typ, obj)
+    return cast(typ, obj)  # type: ignore
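
The ignore presumably silences a checker complaint about passing a runtime `Type[T]` value to `cast`; runtime behavior is unchanged. A hypothetical usage sketch (`config` is illustrative):

from typing import Any

config: Any = {"n_tokens": 10_000_000}
cfg = assert_type(dict, config)  # statically narrowed to dict; raises TypeError otherwise
print(cfg["n_tokens"])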