5 changes: 5 additions & 0 deletions README.md
@@ -238,8 +238,13 @@ Surprisal scoring computes the loss over some examples and uses a base model. We

Embedding scoring uses a small embedding model through `sentence_transformers` to embed the examples and do retrieval. It also does not use VLLM but runs the model directly. The setup is similar to the one above, but for an example check `embedding.py` in the experiments folder.
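As a rough sketch of the idea (not the actual contents of `embedding.py`; the model name and texts below are illustrative placeholders), embedding scoring comes down to encoding the explanation and the candidate examples and ranking them by cosine similarity:

```python
# Minimal sketch of embedding-based retrieval with sentence_transformers.
# Illustrative only: the model name and texts are placeholders, not the repo's defaults.
from sentence_transformers import SentenceTransformer, util

model = SentenceTransformer("all-MiniLM-L6-v2")

explanation = "words related to maritime navigation"
examples = [
    "the ship sailed past the lighthouse",
    "he filed his tax return early this year",
]

# Encode the explanation and the examples, then score each example by cosine similarity.
query_emb = model.encode(explanation, convert_to_tensor=True)
example_embs = model.encode(examples, convert_to_tensor=True)
scores = util.cos_sim(query_emb, example_embs)  # shape (1, num_examples)
```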

# Breaking changes in v0.2


`features.cache`: Dataset tokens are now saved in safetensors files together with the activations.
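The extra tensor can be read back with the standard `safetensors` loaders. A minimal sketch, assuming a cache saved without splits (the file path is illustrative; the keys match what the cache's `save` method writes):

```python
# Minimal sketch: read a cached module file back, including the new "tokens" entry.
# The path is a placeholder; the keys are "locations", "activations" and (optionally) "tokens".
from safetensors.torch import load_file

data = load_file("raw_features/model.layers.10.safetensors")
locations = data["locations"]      # (num_nonzero, 3): batch, sequence, feature indices
activations = data["activations"]  # (num_nonzero,): activation values
tokens = data["tokens"]            # (batch, seq): the dataset tokens used for caching
```

If the tokens are not needed, `save(save_dir, save_tokens=False)` and `save_splits(n_splits, save_dir, save_tokens=False)` skip them.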

`features.constructors.default_constructor`: `tokens` was renamed to `token_loader`, which must be a callable for lazy loading. Instead of passing `tokens=dataset.tokens`, pass `token_loader=lambda: dataset.load_tokens()` (assuming `dataset` is a `FeatureDataset` instance).
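A minimal before/after sketch of the call, where `record`, `buffer_output`, and `dataset` come from the usual feature-loading setup and the numeric arguments are illustrative:

```python
from sae_auto_interp.features.constructors import default_constructor

# v0.1 (no longer works): tokens were materialized up front.
#   default_constructor(record, tokens=dataset.tokens, buffer_output=buffer_output, ...)

# v0.2: pass a callable so tokens are only loaded when they are actually needed.
default_constructor(
    record,
    token_loader=lambda: dataset.load_tokens(),  # dataset is a FeatureDataset instance
    buffer_output=buffer_output,
    n_random=50,        # illustrative values
    ctx_len=20,
    max_examples=100,
)
```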

# Scripts

Example scripts can be found in `demos`. Some of these scripts can be called from the CLI, as shown by the examples in `scripts`. These baseline scripts should allow anyone to start generating and scoring explanations for any SAE they are interested in. The activations of a given SAE's features always need to be cached first; after that, generating explanations and scoring them can be done at the same time.
25 changes: 21 additions & 4 deletions sae_auto_interp/features/cache.py
@@ -30,12 +30,14 @@ def __init__(
"""
self.feature_locations = defaultdict(list)
self.feature_activations = defaultdict(list)
self.tokens = defaultdict(list)
self.filters = filters
self.batch_size = batch_size

def add(
self,
latents: TensorType["batch", "sequence", "feature"],
tokens: TensorType["batch", "sequence"],
batch_number: int,
module_path: str,
):
@@ -44,17 +46,20 @@

Args:
latents (TensorType["batch", "sequence", "feature"]): Latent activations.
tokens (TensorType["batch", "sequence"]): Input tokens.
batch_number (int): Current batch number.
module_path (str): Path of the module.
"""
feature_locations, feature_activations = self.get_nonzeros(latents, module_path)
feature_locations = feature_locations.cpu()
feature_activations = feature_activations.cpu()
tokens = tokens.cpu()

# Adjust batch indices
feature_locations[:, 0] += batch_number * self.batch_size
self.feature_locations[module_path].append(feature_locations)
self.feature_activations[module_path].append(feature_activations)
self.tokens[module_path].append(tokens)

def save(self):
"""
@@ -68,6 +73,10 @@ def save(self):
self.feature_activations[module_path] = torch.cat(
self.feature_activations[module_path], dim=0
)

self.tokens[module_path] = torch.cat(
self.tokens[module_path], dim=0
)

def get_nonzeros_batch(self, latents: TensorType["batch", "seq", "feature"]):
"""
@@ -112,7 +121,8 @@ def get_nonzeros(
module_path (str): Path of the module.

Returns:
Tuple[torch.Tensor, torch.Tensor]: Non-zero feature locations and activations.
Tuple[TensorType["num_nonzero", 3], TensorType["num_nonzero"]]:
Non-zero feature locations and activations.
"""
size = latents.shape[1] * latents.shape[0] * latents.shape[2]
if size > torch.iinfo(torch.int32).max:
@@ -228,7 +238,7 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
for module_path, submodule in self.submodule_dict.items():
buffer[module_path] = submodule.ae.output.save()
for module_path, latents in buffer.items():
self.cache.add(latents, batch_number, module_path)
self.cache.add(latents, batch, batch_number, module_path)

del buffer
torch.cuda.empty_cache()
@@ -240,12 +250,13 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):
print(f"Total tokens processed: {total_tokens:,}")
self.cache.save()

def save(self, save_dir):
def save(self, save_dir, save_tokens: bool = True):
"""
Save the cached features to disk.

Args:
save_dir (str): Directory to save the features.
save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True.
"""
for module_path in self.cache.feature_locations.keys():
output_file = f"{save_dir}/{module_path}.safetensors"
@@ -254,6 +265,8 @@ def save(self, save_dir):
"locations": self.cache.feature_locations[module_path],
"activations": self.cache.feature_activations[module_path],
}
if save_tokens:
data["tokens"] = self.cache.tokens[module_path]

save_file(data, output_file)

@@ -272,19 +285,21 @@ def _generate_split_indices(self, n_splits):
# Adjust end by one
return list(zip(boundaries[:-1], boundaries[1:] - 1))

def save_splits(self, n_splits: int, save_dir):
def save_splits(self, n_splits: int, save_dir, save_tokens: bool = True):
"""
Save the cached features in splits.

Args:
n_splits (int): Number of splits to generate.
save_dir (str): Directory to save the splits.
save_tokens (bool): Whether to save the dataset tokens used to generate the cache. Defaults to True.
"""
split_indices = self._generate_split_indices(n_splits)

for module_path in self.cache.feature_locations.keys():
feature_locations = self.cache.feature_locations[module_path]
feature_activations = self.cache.feature_activations[module_path]
tokens = self.cache.tokens[module_path].numpy()
features = feature_locations[:, 2]

for start, end in split_indices:
@@ -309,6 +324,8 @@ def save_splits(self, n_splits: int, save_dir):
"locations": masked_locations,
"activations": masked_activations,
}
if save_tokens:
split_data["tokens"] = tokens

save_file(split_data, output_file)

22 changes: 20 additions & 2 deletions sae_auto_interp/features/constructors.py
@@ -1,5 +1,6 @@
import torch
from torchtyping import TensorType
from typing import Callable, Optional

from .features import FeatureRecord, prepare_examples
from .loader import BufferOutput
@@ -100,7 +101,7 @@ def random_activation_windows(

def default_constructor(
record: FeatureRecord,
tokens: TensorType["batch", "seq"],
token_loader: Optional[Callable[[], TensorType["batch", "seq"]]],
buffer_output: BufferOutput,
n_random: int,
ctx_len: int,
@@ -111,12 +112,29 @@

Args:
record (FeatureRecord): The feature record to update.
tokens (TensorType["batch", "seq"]): The input tokens.
token_loader (Optional[Callable[[], TensorType["batch", "seq"]]]):
An optional function that creates the dataset tokens.
buffer_output (BufferOutput): The buffer output containing activations and locations.
n_random (int): Number of random examples to generate.
ctx_len (int): Context length for each example.
max_examples (int): Maximum number of examples to generate.
"""
tokens = buffer_output.tokens
if tokens is None:
if token_loader is None:
raise ValueError("Either tokens or token_loader must be provided")
try:
tokens = token_loader()
except TypeError:
raise ValueError(
"Starting with v0.2, `tokens` was renamed to `token_loader`, "
"which must be a callable for lazy loading.\n\n"
"Instead of passing\n"
"` tokens=dataset.tokens`,\n"
"pass\n"
"` token_loader=lambda: dataset.load_tokens()`,\n"
"(assuming `dataset` is a `FeatureDataset` instance)."
)
pool_max_activation_windows(
record,
tokens=tokens,
44 changes: 32 additions & 12 deletions sae_auto_interp/features/loader.py
@@ -27,10 +27,12 @@ class BufferOutput(NamedTuple):
feature (Feature): The feature associated with this output.
locations (TensorType["locations", 2]): Tensor of feature locations.
activations (TensorType["locations"]): Tensor of feature activations.
tokens (TensorType["tokens"]): Tensor of all tokens.
"""
feature: Feature
locations: TensorType["locations", 2]
activations: TensorType["locations"]
tokens: TensorType["tokens"]


class TensorBuffer:
@@ -70,6 +72,10 @@ def __iter__(self):
first_feature = int(self.tensor_path.split("/")[-1].split("_")[0])
activations = torch.tensor(split_data["activations"])
locations = torch.tensor(split_data["locations"].astype(np.int64))
if "tokens" in split_data:
tokens = torch.tensor(split_data["tokens"].astype(np.int64))
else:
tokens = None

locations[:,2] = locations[:,2] + first_feature

@@ -95,7 +101,8 @@ def __iter__(self):
yield BufferOutput(
Feature(self.module_path, int(features[i].item())),
feature_locations,
feature_activations
feature_activations,
tokens
)

def reset(self):
@@ -139,17 +146,30 @@ def __init__(
cache_config = json.load(f)
temp_model = LanguageModel(cache_config["model_name"], device_map="cpu", dispatch=False)
self.tokenizer = temp_model.tokenizer
print(cache_config)
self.tokens = load_tokenized_data(
cache_config["ctx_len"],
self.tokenizer,
cache_config["dataset_repo"],
cache_config["dataset_split"],
cache_config["dataset_name"],
cache_config["dataset_column_name"],
)
print(self.tokenizer.decode(self.tokens[0]))


self.cache_config = cache_config

def load_tokens(self):
"""
Load tokenized data for the dataset.
Caches the tokenized data if not already loaded.

Returns:
torch.Tensor: The tokenized dataset.
"""
if not hasattr(self, "tokens"):
self.tokens = load_tokenized_data(
self.cache_config["ctx_len"],
self.tokenizer,
self.cache_config["dataset_repo"],
self.cache_config["dataset_split"],
self.cache_config["dataset_name"],
column_name=self.cache_config.get(
"column_name", self.cache_config.get("dataset_row", "raw_content")
),
)
return self.tokens

def _edges(self):
"""Generate edge indices for feature splits."""
return torch.linspace(0, self.cfg.width, steps=self.cfg.n_splits + 1).long()