From 09d618893bb51bb0c5d76287c189eda6200014e4 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Fri, 21 Feb 2025 17:27:55 +0000 Subject: [PATCH 1/9] Caching first time --- delphi/__main__.py | 5 ++- delphi/latents/cache.py | 86 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 90 insertions(+), 1 deletion(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index e4dff184..da9fc748 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -258,6 +258,9 @@ def populate_cache( ) cache.run(cache_cfg.n_tokens, tokens) + if run_cfg.verbose: + cache.generate_statistics_cache() + cache.save_splits( # Split the activation and location indices into different files to make # loading faster @@ -279,7 +282,7 @@ async def run( run_cfg.save_json(base_path / "run_config.json", indent=4) - latents_path = base_path / "latents" + latents_path = Path("temperature") / "latents" explanations_path = base_path / "explanations" scores_path = base_path / "scores" neighbours_path = base_path / "neighbours" diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 59907425..1b4a2591 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -384,3 +384,89 @@ def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str): config_dict = cfg.to_dict() config_dict["model_name"] = model_name json.dump(config_dict, f, indent=4) + + def generate_statistics_cache(self): + + # Token frequency + for module_path in self.cache.latent_locations.keys(): + + tokens = self.cache.tokens[module_path] + total_n_tokens = tokens.shape[0] * tokens.shape[1] + + latent_locations = self.cache.latent_locations[module_path][:,:2] + + latents = self.cache.latent_locations[module_path][:, 2] + activations = self.cache.latent_activations[module_path] + + #torch always sorts for unique, so we might as well do it + sorted_latents , latent_indices = latents.sort() + sorted_activations = activations[latent_indices] + sorted_tokens = tokens[latent_locations[latent_indices]] + + unique_latents, counts = torch.unique_consecutive(sorted_latents, return_counts=True) + + # How many unique latents ever activated on the cached tokens + fraction_alive = counts.shape[0] / self.width + + print(f"Fraction of latents alive: {fraction_alive}") + # Compute densities of latents + densities = counts / total_n_tokens + + # How many fired more than 1% of the time + one_percent = (densities > 0.01).sum()/self.width + print(f"Fraction of latents fired more than 1% of the time: {one_percent}") + # How many fired more than 10% of the time + ten_percent = (densities > 0.1).sum()/self.width + print(f"Fraction of latents fired more than 10% of the time: {ten_percent}") + # Try to estimate simple feature frequency + split_indices = torch.cumsum(counts, dim=0) + activation_splits = torch.tensor_split(sorted_activations, split_indices[:-1]) + token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1]) + + + # This might take a while and we may only care for statistics + # but for now we do the full loop + num_single_token_features=0 + maybe_single_token_features=0 + for latent_idx, activation_group, token_group in zip(unique_latents, activation_splits, token_splits): + sorted_activation_group, sorted_indices = activation_group.sort() + sorted_token_group = token_group[sorted_indices] + + number_activations = sorted_activation_group.shape[0] + # Get the first 50 elements if possible + num_elements = min(50, number_activations) + + wanted_tokens = sorted_token_group[:num_elements] + + # Check how many of them are exactly the same 
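+                # note: unique_consecutive only merges *adjacent* duplicates, so
+                # max_count below is the longest contiguous run of a single token,
+                # not that token's total count in wanted_tokens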
+ _, unique_counts = torch.unique_consecutive(wanted_tokens, return_counts=True) + + max_count = unique_counts.max() + maybe_single_token = False + if max_count > 0.9 * num_elements: + # Single token feature + maybe_single_token = True + + # Randomly sample 100 activations from the top 50% + num_samples = min(100, number_activations/2) + top_50_percent = sorted_token_group[:number_activations//2] + sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples] + sampled_tokens = top_50_percent[sampled_indices] + _, unique_counts = torch.unique_consecutive(sampled_tokens, return_counts=True) + + max_count = unique_counts.max() + if max_count > 0.75 * num_samples: + if maybe_single_token: + num_single_token_features += 1 + else: + maybe_single_token_features += 1 + + single_token_fraction = num_single_token_features/maybe_single_token_features + strong_token_fraction = maybe_single_token_features/num_single_token_features + print(f"Fraction of weak token latents: {single_token_fraction}") + print(f"Fraction of strong token latents: {strong_token_fraction}") + + + + pass + # \ No newline at end of file From d56f911cc48304cc997db24ffa7a68b6405218fc Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Feb 2025 17:28:41 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/latents/cache.py | 59 +++++++++++++++++++++--------------- delphi/latents/neighbours.py | 16 +++++++--- 2 files changed, 46 insertions(+), 29 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 1b4a2591..640b4fa5 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -393,17 +393,19 @@ def generate_statistics_cache(self): tokens = self.cache.tokens[module_path] total_n_tokens = tokens.shape[0] * tokens.shape[1] - latent_locations = self.cache.latent_locations[module_path][:,:2] + latent_locations = self.cache.latent_locations[module_path][:, :2] latents = self.cache.latent_locations[module_path][:, 2] activations = self.cache.latent_activations[module_path] - - #torch always sorts for unique, so we might as well do it - sorted_latents , latent_indices = latents.sort() + + # torch always sorts for unique, so we might as well do it + sorted_latents, latent_indices = latents.sort() sorted_activations = activations[latent_indices] sorted_tokens = tokens[latent_locations[latent_indices]] - - unique_latents, counts = torch.unique_consecutive(sorted_latents, return_counts=True) + + unique_latents, counts = torch.unique_consecutive( + sorted_latents, return_counts=True + ) # How many unique latents ever activated on the cached tokens fraction_alive = counts.shape[0] / self.width @@ -413,22 +415,25 @@ def generate_statistics_cache(self): densities = counts / total_n_tokens # How many fired more than 1% of the time - one_percent = (densities > 0.01).sum()/self.width + one_percent = (densities > 0.01).sum() / self.width print(f"Fraction of latents fired more than 1% of the time: {one_percent}") # How many fired more than 10% of the time - ten_percent = (densities > 0.1).sum()/self.width + ten_percent = (densities > 0.1).sum() / self.width print(f"Fraction of latents fired more than 10% of the time: {ten_percent}") # Try to estimate simple feature frequency split_indices = torch.cumsum(counts, dim=0) - activation_splits = torch.tensor_split(sorted_activations, split_indices[:-1]) + activation_splits = torch.tensor_split( + 
sorted_activations, split_indices[:-1] + ) token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1]) - # This might take a while and we may only care for statistics # but for now we do the full loop - num_single_token_features=0 - maybe_single_token_features=0 - for latent_idx, activation_group, token_group in zip(unique_latents, activation_splits, token_splits): + num_single_token_features = 0 + maybe_single_token_features = 0 + for latent_idx, activation_group, token_group in zip( + unique_latents, activation_splits, token_splits + ): sorted_activation_group, sorted_indices = activation_group.sort() sorted_token_group = token_group[sorted_indices] @@ -439,20 +444,24 @@ def generate_statistics_cache(self): wanted_tokens = sorted_token_group[:num_elements] # Check how many of them are exactly the same - _, unique_counts = torch.unique_consecutive(wanted_tokens, return_counts=True) + _, unique_counts = torch.unique_consecutive( + wanted_tokens, return_counts=True + ) max_count = unique_counts.max() maybe_single_token = False if max_count > 0.9 * num_elements: # Single token feature maybe_single_token = True - + # Randomly sample 100 activations from the top 50% - num_samples = min(100, number_activations/2) - top_50_percent = sorted_token_group[:number_activations//2] + num_samples = min(100, number_activations / 2) + top_50_percent = sorted_token_group[: number_activations // 2] sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples] sampled_tokens = top_50_percent[sampled_indices] - _, unique_counts = torch.unique_consecutive(sampled_tokens, return_counts=True) + _, unique_counts = torch.unique_consecutive( + sampled_tokens, return_counts=True + ) max_count = unique_counts.max() if max_count > 0.75 * num_samples: @@ -460,13 +469,15 @@ def generate_statistics_cache(self): num_single_token_features += 1 else: maybe_single_token_features += 1 - - single_token_fraction = num_single_token_features/maybe_single_token_features - strong_token_fraction = maybe_single_token_features/num_single_token_features + + single_token_fraction = ( + num_single_token_features / maybe_single_token_features + ) + strong_token_fraction = ( + maybe_single_token_features / num_single_token_features + ) print(f"Fraction of weak token latents: {single_token_fraction}") print(f"Fraction of strong token latents: {strong_token_fraction}") - - pass - # \ No newline at end of file + # diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 747b9bad..f80ef7b0 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -177,14 +177,16 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] latent_index = latent_index[idx_cantor_sorted_idx] n_tokens = int(idx_cantor.max().item()) - + token_batch_size = 20_000 done = False while not done: try: print("Trying with batch size", token_batch_size) # Find indices where idx_cantor crosses each batch boundary - bounday_values = torch.arange(token_batch_size, n_tokens, token_batch_size) + bounday_values = torch.arange( + token_batch_size, n_tokens, token_batch_size + ) batch_boundaries_tensor = torch.searchsorted(idx_cantor, bounday_values) batch_boundaries = [0] + batch_boundaries_tensor.tolist() @@ -192,10 +194,14 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] if batch_boundaries[-1] != len(idx_cantor): batch_boundaries.append(len(idx_cantor)) - co_occurrence_matrix = torch.zeros((n_latents, n_latents), dtype=torch.int32) - #co_occurrence_matrix = 
co_occurrence_matrix.cuda() + co_occurrence_matrix = torch.zeros( + (n_latents, n_latents), dtype=torch.int32 + ) + # co_occurrence_matrix = co_occurrence_matrix.cuda() - for start, end in tqdm(zip(batch_boundaries[:-1], batch_boundaries[1:])): + for start, end in tqdm( + zip(batch_boundaries[:-1], batch_boundaries[1:]) + ): # get all ind_cantor values between start and start + token_batch_size selected_idx_cantor = idx_cantor[start:end] selected_latent_index = latent_index[start:end] From e57c0fd736bd4363c8f7d7e33bff761af71df3de Mon Sep 17 00:00:00 2001 From: nev Date: Sun, 9 Mar 2025 19:01:18 +0000 Subject: [PATCH 3/9] Fix OOM issue --- delphi/latents/cache.py | 200 ++++++++++++++++++++------------------- delphi/latents/loader.py | 5 + 2 files changed, 108 insertions(+), 97 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index f2e6c8ca..3ed3a057 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -5,7 +5,7 @@ import numpy as np import torch -from jaxtyping import Float +from jaxtyping import Float, Int from safetensors.numpy import save_file from torch import Tensor from tqdm import tqdm @@ -387,99 +387,105 @@ def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str): config_dict["model_name"] = model_name json.dump(config_dict, f, indent=4) - def generate_statistics_cache(self): - - # Token frequency - for module_path in self.cache.latent_locations.keys(): - - tokens = self.cache.tokens[module_path] - total_n_tokens = tokens.shape[0] * tokens.shape[1] - - latent_locations = self.cache.latent_locations[module_path][:, :2] - - latents = self.cache.latent_locations[module_path][:, 2] - activations = self.cache.latent_activations[module_path] - - # torch always sorts for unique, so we might as well do it - sorted_latents, latent_indices = latents.sort() - sorted_activations = activations[latent_indices] - sorted_tokens = tokens[latent_locations[latent_indices]] - - unique_latents, counts = torch.unique_consecutive( - sorted_latents, return_counts=True - ) - - # How many unique latents ever activated on the cached tokens - fraction_alive = counts.shape[0] / self.width - - print(f"Fraction of latents alive: {fraction_alive}") - # Compute densities of latents - densities = counts / total_n_tokens - - # How many fired more than 1% of the time - one_percent = (densities > 0.01).sum() / self.width - print(f"Fraction of latents fired more than 1% of the time: {one_percent}") - # How many fired more than 10% of the time - ten_percent = (densities > 0.1).sum() / self.width - print(f"Fraction of latents fired more than 10% of the time: {ten_percent}") - # Try to estimate simple feature frequency - split_indices = torch.cumsum(counts, dim=0) - activation_splits = torch.tensor_split( - sorted_activations, split_indices[:-1] - ) - token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1]) - - # This might take a while and we may only care for statistics - # but for now we do the full loop - num_single_token_features = 0 - maybe_single_token_features = 0 - for latent_idx, activation_group, token_group in zip( - unique_latents, activation_splits, token_splits - ): - sorted_activation_group, sorted_indices = activation_group.sort() - sorted_token_group = token_group[sorted_indices] - - number_activations = sorted_activation_group.shape[0] - # Get the first 50 elements if possible - num_elements = min(50, number_activations) - - wanted_tokens = sorted_token_group[:num_elements] - - # Check how many of them are exactly the same - _, 
unique_counts = torch.unique_consecutive( - wanted_tokens, return_counts=True - ) - - max_count = unique_counts.max() - maybe_single_token = False - if max_count > 0.9 * num_elements: - # Single token feature - maybe_single_token = True - - # Randomly sample 100 activations from the top 50% - num_samples = min(100, number_activations / 2) - top_50_percent = sorted_token_group[: number_activations // 2] - sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples] - sampled_tokens = top_50_percent[sampled_indices] - _, unique_counts = torch.unique_consecutive( - sampled_tokens, return_counts=True - ) - - max_count = unique_counts.max() - if max_count > 0.75 * num_samples: - if maybe_single_token: - num_single_token_features += 1 - else: - maybe_single_token_features += 1 - - single_token_fraction = ( - num_single_token_features / maybe_single_token_features - ) - strong_token_fraction = ( - maybe_single_token_features / num_single_token_features - ) - print(f"Fraction of weak token latents: {single_token_fraction}") - print(f"Fraction of strong token latents: {strong_token_fraction}") - - pass - # +@torch.inference_mode() +def generate_statistics_cache( + tokens: Int[Tensor, "batch sequence"], + latent_locations: Int[Tensor, "n_activations 3"], + activations: Float[Tensor, "n_activations"], + width: int, +): + total_n_tokens = tokens.shape[0] * tokens.shape[1] + + latent_locations, latents = latent_locations[:, :2], latent_locations[:, 2] + + # torch always sorts for unique, so we might as well do it + sorted_latents, latent_indices = latents.sort() + sorted_activations = activations[latent_indices] + sorted_tokens = tokens[latent_locations[latent_indices]] + + unique_latents, counts = torch.unique_consecutive( + sorted_latents, return_counts=True + ) + + # How many unique latents ever activated on the cached tokens + fraction_alive = counts.shape[0] / width + + print(f"Fraction of latents alive: {fraction_alive}") + # Compute densities of latents + densities = counts / total_n_tokens + + # How many fired more than 1% of the time + one_percent = (densities > 0.01).sum() / width + print(f"Fraction of latents fired more than 1% of the time: {one_percent}") + # How many fired more than 10% of the time + ten_percent = (densities > 0.1).sum() / width + print(f"Fraction of latents fired more than 10% of the time: {ten_percent}") + # Try to estimate simple feature frequency + split_indices = torch.cumsum(counts, dim=0) + activation_splits = torch.tensor_split( + sorted_activations, split_indices[:-1] + ) + token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1]) + + # This might take a while and we may only care for statistics + # but for now we do the full loop + num_single_token_features = 0 + maybe_single_token_features = 0 + for latent_idx, activation_group, token_group in zip( + unique_latents, activation_splits, token_splits + ): + maybe_single_token, single_token = check_single_feature( + latent_idx, activation_group, token_group + ) + num_single_token_features += single_token + maybe_single_token_features += maybe_single_token + + single_token_fraction = ( + maybe_single_token / width + ) + strong_token_fraction = ( + single_token / width + ) + print(f"Fraction of weak token latents: {single_token_fraction:%}") + print(f"Fraction of strong token latents: {strong_token_fraction:%}") + + +@torch.inference_mode() +def check_single_feature(latent_idx, activation_group, token_group): + sorted_activation_group, sorted_indices = activation_group.sort() + sorted_token_group 
= token_group[sorted_indices] + + number_activations = sorted_activation_group.shape[0] + # Get the first 50 elements if possible + num_elements = min(50, number_activations) + + wanted_tokens = sorted_token_group[:num_elements] + + # Check how many of them are exactly the same + _, unique_counts = torch.unique_consecutive( + wanted_tokens, return_counts=True + ) + + max_count = unique_counts.max() + maybe_single_token = False + if max_count > 0.9 * num_elements: + # Single token feature + maybe_single_token = True + + # Randomly sample 100 activations from the top 50% + n_top = max(1, int(number_activations * 0.5)) + num_samples = min(100, n_top) + top_50_percent = sorted_token_group[:n_top] + sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples] + sampled_tokens = top_50_percent[sampled_indices] + _, unique_counts = torch.unique_consecutive( + sampled_tokens, return_counts=True + ) + + max_count = unique_counts.max() + if max_count > 0.75 * num_samples: + if maybe_single_token: + return 1, 0 + else: + return 0, 1 + return 0, 0 diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index dcab0d50..710e7273 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -86,6 +86,11 @@ def load( Float[Tensor, "activations"], Float[Tensor, "batch seq"] | None, ]: + """Load the tensor buffer's data. + + Returns: + Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations, and tokens (if present in the cache). + """ split_data = load_file(self.path) first_latent = int(self.path.split("/")[-1].split("_")[0]) activations = torch.tensor(split_data["activations"]) From 97f379d805818953cdb488cf749b13c37a8ffa3c Mon Sep 17 00:00:00 2001 From: nev Date: Sun, 9 Mar 2025 19:52:36 +0000 Subject: [PATCH 4/9] Return caching statistics, fix single token bug --- delphi/latents/cache.py | 83 ++++++++++++++++++++++++++++++++--------- 1 file changed, 66 insertions(+), 17 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 3ed3a057..11ec1d48 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -2,6 +2,7 @@ from collections import defaultdict from pathlib import Path from typing import Callable +from dataclasses import dataclass import numpy as np import torch @@ -371,6 +372,21 @@ def save_splits(self, n_splits: int, save_dir: Path, save_tokens: bool = True): save_file(split_data, output_file) + def generate_statistics_cache(self): + """Print statistics (number of dead features, number of single token features) to the console. + """ + print("Feature statistics:") + # Token frequency + for module_path in self.cache.latent_locations.keys(): + print(f"# Module: {module_path}") + generate_statistics_cache( + self.cache.tokens[module_path], + self.cache.latent_locations[module_path], + self.cache.latent_activations[module_path], + self.width, + verbose=True + ) + def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str): """ Save the configuration for the cached latents. 
@@ -387,13 +403,35 @@ def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str):
         config_dict = cfg.to_dict()
         config_dict["model_name"] = model_name
         json.dump(config_dict, f, indent=4)
 
+
+@dataclass
+class CacheStatistics:
+    frac_alive: float
+    frac_fired_1pct: float
+    frac_fired_10pct: float
+    frac_weak_single_token: float
+    frac_strong_single_token: float
+
+
 @torch.inference_mode()
 def generate_statistics_cache(
     tokens: Int[Tensor, "batch sequence"],
     latent_locations: Int[Tensor, "n_activations 3"],
     activations: Float[Tensor, "n_activations"],
     width: int,
-):
+    verbose: bool = False,
+) -> CacheStatistics:
+    """Generate global statistics for the cache.
+
+    Args:
+        tokens (Int[Tensor, "batch sequence"]): Tokens used to generate the cache.
+        latent_locations (Int[Tensor, "n_activations 3"]): Indices of the latent activations, corresponding to `tokens`.
+        activations (Float[Tensor, "n_activations"]): Activations of the latents, as stored by the cache.
+        width (int): Width of the cache to test.
+        verbose (bool, optional): Print results to stdout. Defaults to False.
+    Returns:
+        CacheStatistics: the statistics
+    """
     total_n_tokens = tokens.shape[0] * tokens.shape[1]
 
     latent_locations, latents = latent_locations[:, :2], latent_locations[:, 2]
@@ -408,18 +446,20 @@ def generate_statistics_cache(
     )
 
     # How many unique latents ever activated on the cached tokens
-    fraction_alive = counts.shape[0] / width
-
-    print(f"Fraction of latents alive: {fraction_alive}")
+    num_alive = counts.shape[0]
+    fraction_alive = num_alive / width
+    if verbose:
+        print(f"Fraction of latents alive: {fraction_alive:%}")
     # Compute densities of latents
     densities = counts / total_n_tokens
 
     # How many fired more than 1% of the time
     one_percent = (densities > 0.01).sum() / width
-    print(f"Fraction of latents fired more than 1% of the time: {one_percent}")
     # How many fired more than 10% of the time
     ten_percent = (densities > 0.1).sum() / width
-    print(f"Fraction of latents fired more than 10% of the time: {ten_percent}")
+    if verbose:
+        print(f"Fraction of latents fired more than 1% of the time: {one_percent:%}")
+        print(f"Fraction of latents fired more than 10% of the time: {ten_percent:%}")
     # Try to estimate simple feature frequency
     split_indices = torch.cumsum(counts, dim=0)
     activation_splits = torch.tensor_split(
@@ -431,27 +471,36 @@ def generate_statistics_cache(
     # but for now we do the full loop
     num_single_token_features = 0
     maybe_single_token_features = 0
-    for latent_idx, activation_group, token_group in zip(
+    for _latent_idx, activation_group, token_group in zip(
         unique_latents, activation_splits, token_splits
     ):
         maybe_single_token, single_token = check_single_feature(
-            latent_idx, activation_group, token_group
+            activation_group, token_group
         )
         num_single_token_features += single_token
         maybe_single_token_features += maybe_single_token
 
     single_token_fraction = (
-        maybe_single_token / width
+        maybe_single_token_features / num_alive
     )
     strong_token_fraction = (
-        single_token / width
+        num_single_token_features / num_alive
+    )
+    if verbose:
+        print(f"Fraction of weak token latents: {single_token_fraction:%}")
+        print(f"Fraction of strong token latents: 
{strong_token_fraction:%}") @torch.inference_mode() -def check_single_feature(latent_idx, activation_group, token_group): +def check_single_feature(activation_group, token_group): sorted_activation_group, sorted_indices = activation_group.sort() sorted_token_group = token_group[sorted_indices] @@ -484,8 +533,8 @@ def check_single_feature(latent_idx, activation_group, token_group): max_count = unique_counts.max() if max_count > 0.75 * num_samples: - if maybe_single_token: - return 1, 0 - else: return 0, 1 - return 0, 0 + elif maybe_single_token: + return 1, 0 + else: + return 0, 0 From d5d1a2c57dc8547a6378cdd46be967f5f4b0cf2b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 9 Mar 2025 19:57:56 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/latents/cache.py | 29 +++++++++-------------------- 1 file changed, 9 insertions(+), 20 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index a5575157..c625a54d 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -1,8 +1,8 @@ import json from collections import defaultdict +from dataclasses import dataclass from pathlib import Path from typing import Callable -from dataclasses import dataclass import numpy as np import torch @@ -388,8 +388,7 @@ def save_splits(self, n_splits: int, save_dir: Path, save_tokens: bool = True): save_file(split_data, output_file) def generate_statistics_cache(self): - """Print statistics (number of dead features, number of single token features) to the console. - """ + """Print statistics (number of dead features, number of single token features) to the console.""" print("Feature statistics:") # Token frequency for module_path in self.cache.latent_locations.keys(): @@ -399,7 +398,7 @@ def generate_statistics_cache(self): self.cache.latent_locations[module_path], self.cache.latent_activations[module_path], self.width, - verbose=True + verbose=True, ) def save_config(self, save_dir: Path, cfg: CacheConfig, model_name: str): @@ -490,9 +489,7 @@ def generate_statistics_cache( print(f"Fraction of latents fired more than 10% of the time: {ten_percent:%}") # Try to estimate simple feature frequency split_indices = torch.cumsum(counts, dim=0) - activation_splits = torch.tensor_split( - sorted_activations, split_indices[:-1] - ) + activation_splits = torch.tensor_split(sorted_activations, split_indices[:-1]) token_splits = torch.tensor_split(sorted_tokens, split_indices[:-1]) # This might take a while and we may only care for statistics @@ -508,16 +505,12 @@ def generate_statistics_cache( num_single_token_features += single_token maybe_single_token_features += maybe_single_token - single_token_fraction = ( - maybe_single_token_features / num_alive - ) - strong_token_fraction = ( - num_single_token_features / num_alive - ) + single_token_fraction = maybe_single_token_features / num_alive + strong_token_fraction = num_single_token_features / num_alive if verbose: print(f"Fraction of weak token latents: {single_token_fraction:%}") print(f"Fraction of strong token latents: {strong_token_fraction:%}") - + return CacheStatistics( frac_alive=fraction_alive, frac_fired_1pct=one_percent, @@ -539,9 +532,7 @@ def check_single_feature(activation_group, token_group): wanted_tokens = sorted_token_group[:num_elements] # Check how many of them are exactly the same - _, unique_counts = torch.unique_consecutive( - wanted_tokens, return_counts=True - ) 
+ _, unique_counts = torch.unique_consecutive(wanted_tokens, return_counts=True) max_count = unique_counts.max() maybe_single_token = False @@ -555,9 +546,7 @@ def check_single_feature(activation_group, token_group): top_50_percent = sorted_token_group[:n_top] sampled_indices = torch.randperm(top_50_percent.shape[0])[:num_samples] sampled_tokens = top_50_percent[sampled_indices] - _, unique_counts = torch.unique_consecutive( - sampled_tokens, return_counts=True - ) + _, unique_counts = torch.unique_consecutive(sampled_tokens, return_counts=True) max_count = unique_counts.max() if max_count > 0.75 * num_samples: From e001b2d240ef0458c7228f11f78b6b1da9345ed4 Mon Sep 17 00:00:00 2001 From: nev Date: Sun, 9 Mar 2025 20:04:22 +0000 Subject: [PATCH 6/9] CI fixes --- delphi/latents/cache.py | 12 ++++++++---- delphi/latents/loader.py | 3 ++- delphi/latents/neighbours.py | 5 ++--- delphi/log/result_analysis.py | 4 ++-- delphi/utils.py | 2 +- 5 files changed, 15 insertions(+), 11 deletions(-) diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index a5575157..7cffb487 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -1,8 +1,8 @@ import json from collections import defaultdict +from dataclasses import dataclass from pathlib import Path from typing import Callable -from dataclasses import dataclass import numpy as np import torch @@ -388,8 +388,10 @@ def save_splits(self, n_splits: int, save_dir: Path, save_tokens: bool = True): save_file(split_data, output_file) def generate_statistics_cache(self): - """Print statistics (number of dead features, number of single token features) to the console. + """Print statistics (number of dead features, number of single token features) + to the console. """ + assert self.width is not None, "Width must be set before generating statistics" print("Feature statistics:") # Token frequency for module_path in self.cache.latent_locations.keys(): @@ -453,8 +455,10 @@ def generate_statistics_cache( Args: tokens (Int[Tensor, "batch sequence"]): Tokens used to generate the cache. - latent_locations (Int[Tensor, "n_activations 3"]): Indices of the latent activations, corresponding to `tokens`. - activations (Float[Tensor, "n_activations"]): Activations of the latents, as stored by the cache. + latent_locations (Int[Tensor, "n_activations 3"]): Indices of the latent + activations, corresponding to `tokens`. + activations (Float[Tensor, "n_activations"]): Activations of the latents, + as stored by the cache. width (int): Width of the cache to test. verbose (bool, optional): Print results to stdout. Defaults to False. Returns: diff --git a/delphi/latents/loader.py b/delphi/latents/loader.py index 710e7273..a3d55e58 100644 --- a/delphi/latents/loader.py +++ b/delphi/latents/loader.py @@ -89,7 +89,8 @@ def load( """Load the tensor buffer's data. Returns: - Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations, and tokens (if present in the cache). + Tuple[Tensor, Tensor, Optional[Tensor]]: Locations, activations, + and tokens (if present in the cache). 
""" split_data = load_file(self.path) first_latent = int(self.path.split("/")[-1].split("_")[0]) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 791e6388..f4878de4 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -179,8 +179,6 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] idx_cantor, idx_cantor_sorted_idx = idx_cantor.sort(dim=0, stable=True) latent_index = latent_index[idx_cantor_sorted_idx] - n_tokens = int(idx_cantor.max().item()) - token_batch_size = 20_000 co_occurrence_matrix = None done = False @@ -213,7 +211,8 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] for start, end in tqdm( zip(batch_boundaries[:-1], batch_boundaries[1:]) ): - # get all ind_cantor values between start and start + token_batch_size + # get all ind_cantor values between start + # and start + token_batch_size selected_idx_cantor = idx_cantor[start:end] # shift the indices to start from 0 selected_idx_cantor = selected_idx_cantor - selected_idx_cantor[0] diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index bd0ca176..81776ccc 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -9,8 +9,8 @@ def import_plotly(): try: - import plotly.express as px - import plotly.io as pio + import plotly.express as px # type: ignore + import plotly.io as pio # type: ignore except ImportError: raise ImportError( "Plotly is not installed.\n" diff --git a/delphi/utils.py b/delphi/utils.py index 2b278cb9..a148d238 100644 --- a/delphi/utils.py +++ b/delphi/utils.py @@ -41,4 +41,4 @@ def assert_type(typ: Type[T], obj: Any) -> T: if not isinstance(obj, typ): raise TypeError(f"Expected {typ.__name__}, got {type(obj).__name__}") - return cast(typ, obj) + return cast(typ, obj) # type: ignore From 57f2bdbba48b77e6694f4fffcc551a475b510033 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 13 Mar 2025 12:10:23 +0000 Subject: [PATCH 7/9] Correct neighbours --- delphi/latents/neighbours.py | 10 ++-------- 1 file changed, 2 insertions(+), 8 deletions(-) diff --git a/delphi/latents/neighbours.py b/delphi/latents/neighbours.py index 586938f5..a1879598 100644 --- a/delphi/latents/neighbours.py +++ b/delphi/latents/neighbours.py @@ -181,7 +181,7 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] idx_cantor, idx_cantor_sorted_idx = idx_cantor.sort(dim=0, stable=True) latent_index = latent_index[idx_cantor_sorted_idx] - token_batch_size = 20_000 + token_batch_size = 100_000 co_occurrence_matrix = None done = False while not done: @@ -205,16 +205,10 @@ def _compute_cooccurrence_neighbours(self) -> dict[int, list[tuple[int, float]]] (n_latents, n_latents), dtype=torch.int64 ) - co_occurrence_matrix = torch.zeros( - (n_latents, n_latents), dtype=torch.int32 - ) - # co_occurrence_matrix = co_occurrence_matrix.cuda() - for start, end in tqdm( zip(batch_boundaries[:-1], batch_boundaries[1:]) ): - # get all ind_cantor values between start - # and start + token_batch_size + # get all ind_cantor values between start and start selected_idx_cantor = idx_cantor[start:end] # shift the indices to start from 0 selected_idx_cantor = selected_idx_cantor - selected_idx_cantor[0] From bde095c3670f8e528a8b50f0926401435d8c7720 Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 13 Mar 2025 12:28:35 +0000 Subject: [PATCH 8/9] Fix print --- delphi/__main__.py | 1 - delphi/latents/cache.py | 9 +++++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff 
--git a/delphi/__main__.py b/delphi/__main__.py index 2c71c32b..7cf5e21b 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -279,7 +279,6 @@ def populate_cache( # Save firing counts to the run-specific log directory cache.save_firing_counts() - if run_cfg.verbose: cache.generate_statistics_cache() diff --git a/delphi/latents/cache.py b/delphi/latents/cache.py index 9f741404..00e082e8 100644 --- a/delphi/latents/cache.py +++ b/delphi/latents/cache.py @@ -514,8 +514,8 @@ def generate_statistics_cache( single_token_fraction = maybe_single_token_features / num_alive strong_token_fraction = num_single_token_features / num_alive if verbose: - print(f"Fraction of weak token latents: {single_token_fraction:%}") - print(f"Fraction of strong token latents: {strong_token_fraction:%}") + print(f"Fraction of weak single token latents: {single_token_fraction:%}") + print(f"Fraction of strong single token latents: {strong_token_fraction:%}") return CacheStatistics( frac_alive=fraction_alive, @@ -555,9 +555,10 @@ def check_single_feature(activation_group, token_group): _, unique_counts = torch.unique_consecutive(sampled_tokens, return_counts=True) max_count = unique_counts.max() - if max_count > 0.75 * num_samples: + other_maybe_single_token = max_count > 0.75 * num_samples + if other_maybe_single_token and maybe_single_token: return 0, 1 - elif maybe_single_token: + elif maybe_single_token or other_maybe_single_token: return 1, 0 else: return 0, 0 From b53e4385cb43bb0227d541e3fb28af8aac64beba Mon Sep 17 00:00:00 2001 From: SrGonao Date: Thu, 13 Mar 2025 12:32:30 +0000 Subject: [PATCH 9/9] Only save if verbose --- delphi/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 7cf5e21b..ff547c92 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -278,8 +278,8 @@ def populate_cache( cache.run(cache_cfg.n_tokens, tokens) # Save firing counts to the run-specific log directory - cache.save_firing_counts() if run_cfg.verbose: + cache.save_firing_counts() cache.generate_statistics_cache() cache.save_splits(