Merged
2 changes: 0 additions & 2 deletions delphi/__main__.py
@@ -203,7 +203,6 @@ def scorer_postprocess(result, score_dir):
 
 def populate_cache(
     run_cfg: RunConfig,
-    latent_cfg: LatentConfig,
     cfg: CacheConfig,
     model: PreTrainedModel,
     hookpoint_to_sae_encode: dict[str, Callable],
@@ -286,7 +285,6 @@ async def run(
 ):
     populate_cache(
         run_cfg,
-        latent_cfg,
         cache_cfg,
         model,
         hookpoint_to_sae_encode,
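On the caller side, the effect is simply that populate_cache no longer receives a LatentConfig. A minimal sketch of the post-PR signature, restricted to the parameters visible in the hunk above (the config classes are stubbed here for illustration; the real definitions live elsewhere in delphi, and further parameters elided by the diff are omitted):

from typing import Callable

from transformers import PreTrainedModel


# Stubs for illustration only; delphi defines the real config classes.
class RunConfig: ...
class CacheConfig: ...


def populate_cache(
    run_cfg: RunConfig,
    cfg: CacheConfig,
    model: PreTrainedModel,
    hookpoint_to_sae_encode: dict[str, Callable],
    # ...remaining parameters are not shown in the diff.
):
    ...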
13 changes: 7 additions & 6 deletions delphi/latents/cache.py
@@ -1,11 +1,12 @@
 import json
 from collections import defaultdict
 from pathlib import Path
+from typing import Callable
 
 import numpy as np
 import torch
 from safetensors.numpy import save_file
-from torch import Tensor, nn
+from torch import Tensor
 from torchtyping import TensorType
 from tqdm import tqdm
 
@@ -157,7 +158,7 @@ class LatentCache:
     def __init__(
         self,
         model,
-        hookpoint_to_sae_encode: dict[str, nn.Module],
+        hookpoint_to_sparse_encode: dict[str, Callable],
         batch_size: int,
         filters: dict[str, TensorType["indices"]] | None = None,
     ):
@@ -166,12 +167,12 @@ def __init__(

         Args:
             model: The model to cache latents for.
-            hookpoint_to_sae_encode: Dictionary of submodules to cache.
+            hookpoint_to_sparse_encode: Dictionary of sparse encoding functions.
             batch_size: Size of batches for processing.
             filters: Filters for selecting specific latents.
         """
         self.model = model
-        self.hookpoint_to_sae_encode = hookpoint_to_sae_encode
+        self.hookpoint_to_sparse_encode = hookpoint_to_sparse_encode
 
         self.batch_size = batch_size
         self.width = None
@@ -237,12 +238,12 @@ def run(self, n_tokens: int, tokens: TensorType["batch", "seq"]):

            with torch.no_grad():
                with collect_activations(
-                    self.model, list(self.hookpoint_to_sae_encode.keys())
+                    self.model, list(self.hookpoint_to_sparse_encode.keys())
                ) as activations:
                    self.model(batch.to(self.model.device))
 
                    for hookpoint, latents in activations.items():
-                        sae_latents = self.hookpoint_to_sae_encode[hookpoint](
+                        sae_latents = self.hookpoint_to_sparse_encode[hookpoint](
                            latents
                        )
                        self.cache.add(sae_latents, batch, batch_number, hookpoint)
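The type change from dict[str, nn.Module] to dict[str, Callable] is the substantive part of the rename: LatentCache no longer assumes each encoder is a module, only that it can be called on the hooked activations. A minimal sketch of building the dict under that contract (ToySae, its dimensions, and the hookpoint names are illustrative assumptions, not delphi code):

import torch
from torch import Tensor


class ToySae(torch.nn.Module):
    # Illustrative stand-in for a sparse autoencoder; not part of delphi.
    def __init__(self, d_model: int, d_sae: int):
        super().__init__()
        self.enc = torch.nn.Linear(d_model, d_sae)

    def encode(self, x: Tensor) -> Tensor:
        # ReLU keeps latents nonnegative and sparse-ish; real SAEs vary.
        return torch.relu(self.enc(x))


sae = ToySae(d_model=512, d_sae=4096)

# With dict[str, Callable], bound methods, lambdas, and plain functions are
# all valid values; the cache simply applies them to collected activations.
hookpoint_to_sparse_encode = {
    "model.layers.5": sae.encode,
    "model.layers.9": lambda acts: torch.relu(acts * 2.0),  # any callable
}

# With a loaded model in scope, construction then matches the diff:
# cache = LatentCache(model, hookpoint_to_sparse_encode, batch_size=32)

Accepting any callable also covers transcoders or pre-/post-processing wrappers around an encode function without forcing them to subclass nn.Module.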