In [1]:
import numpy as np
if not hasattr(np, "bool"):
    np.bool = bool 

  if not hasattr(np, "bool"):


In [2]:
from dataclasses import dataclass, field
from pathlib import Path

import torch as th
from tqdm.auto import trange
from transformers import AutoTokenizer
from dictionary_learning.cache import PairedActivationCache
from dictionary_learning.dictionary import BatchTopKCrossCoder
import pandas as pd
import json

In [3]:
def load_dictionary_model(
    model_name: str | Path, is_sae: bool | None = None,
):
    """Load a dictionary model from a local path or HuggingFace Hub.

    Args:
        model_name: Name or path of the model to load

    Returns:
        The loaded dictionary model
    """
    # Check if it's a HuggingFace Hub model
    if "/" not in str(model_name) or not Path(model_name).exists():
        # Legacy model
        if str(model_name) in df_hf_repo_legacy:
            model_name = df_hf_repo_legacy[str(model_name)]
        else:
            model_name = str(model_name)
        if "/" not in str(model_name):
            model_id = f"{author}/{str(model_name)}"
        else:
            model_id = model_name
        # Download config to determine model type
        if file_exists(model_id, "trainer_config.json", repo_type="model"):
            config_path = hf_hub_download(
                repo_id=model_id, filename="trainer_config.json"
            )
            with open(config_path, "r") as f:
                config = json.load(f)["trainer"]

            # Determine model class based on config
            if "dict_class" in config and config["dict_class"] in [
                "BatchTopKSAE",
                "CrossCoder",
                "BatchTopKCrossCoder",
            ]:
                return eval(
                    f"{config['dict_class']}.from_pretrained(model_id, from_hub=True)"
                )
            else:
                raise ValueError(f"Unknown model type: {config['dict_class']}")
        else:
            logger.info(
                f"No config found for {model_id}, relying on is_sae={is_sae} arg to determine model type"
            )
            # If no model_type in config, try to infer from other fields
            if is_sae:
                return BatchTopKSAE.from_pretrained(model_id, from_hub=True)
            else:
                return CrossCoder.from_pretrained(model_id, from_hub=True)
    else:
        # Local model
        model_path = Path(model_name)
        if not model_path.exists():
            raise ValueError(f"Local model {model_name} does not exist")

        # Load the config
        with open(model_path.parent / "config.json", "r") as f:
            config = json.load(f)["trainer"]

        # Determine model class based on config
        if "dict_class" in config and config["dict_class"] in [
            "BatchTopKSAE",
            "CrossCoder",
            "BatchTopKCrossCoder",
        ]:
            return eval(f"{config['dict_class']}.from_pretrained(model_path)")
        else:
            raise ValueError(f"Unknown model type: {config['dict_class']}")


In [4]:
class LatentActivationCache:
    def __init__(
        self,
        latent_activations_dir: Path,
        expand=True,
        offset=0,
        use_sparse_tensor=False,
        device: th.device = None,
    ):
        if isinstance(latent_activations_dir, str):
            latent_activations_dir = Path(latent_activations_dir)

        # Create progress bar for 7 files to load
        pbar = tqdm(total=7, desc="Loading cache files")

        pbar.set_postfix_str("Loading out_acts.pt")
        self.acts = th.load(latent_activations_dir / "out_acts.pt", weights_only=True)
        pbar.update(1)

        pbar.set_postfix_str("Loading out_ids.pt")
        self.ids = th.load(latent_activations_dir / "out_ids.pt", weights_only=True)
        pbar.update(1)

        pbar.set_postfix_str("Loading max_activations.pt")
        self.max_activations = th.load(
            latent_activations_dir / "max_activations.pt", weights_only=True
        )
        pbar.update(1)

        pbar.set_postfix_str("Loading latent_ids.pt")
        self.latent_ids = th.load(
            latent_activations_dir / "latent_ids.pt", weights_only=True
        )
        pbar.update(1)

        pbar.set_postfix_str("Loading padded_sequences.pt")
        self.padded_sequences = th.load(
            latent_activations_dir / "padded_sequences.pt", weights_only=True
        )
        pbar.update(1)

        self.dict_size = self.max_activations.shape[0]

        pbar.set_postfix_str("Loading seq_lengths.pt")
        self.sequence_lengths = th.load(
            latent_activations_dir / "seq_lengths.pt", weights_only=True
        )
        pbar.update(1)

        pbar.set_postfix_str("Loading seq_ranges.pt")
        self.sequence_ranges = th.load(
            latent_activations_dir / "seq_ranges.pt", weights_only=True
        )
        pbar.update(1)
        pbar.close()

        self.expand = expand
        self.offset = offset
        self.use_sparse_tensor = use_sparse_tensor
        self.device = device
        if device is not None:
            self.to(device)

    def __len__(self):
        return len(self.padded_sequences) - self.offset

    def __getitem__(self, index: int):
        """
        Retrieves tokens and latent activations for a specific sequence.

        Args:
            index (int): The index of the sequence to retrieve.

        Returns:
            tuple: A pair containing:
                - The token sequence for the sample
                - If self.expand is True:
                    - If use_sparse_tensor is True:
                        A sparse tensor of shape (sequence_length, dict_size) containing the latent activations
                    - If use_sparse_tensor is False:
                        A dense tensor of shape (sequence_length, dict_size) containing the latent activations
                - If self.expand is False:
                    A tuple of (indices, values) representing sparse latent activations where:
                    - indices: Tensor of shape (N, 2) containing (token_idx, dict_idx) pairs
                    - values: Tensor of shape (N,) containing activation values
        """
        return self.get_sequence(index), self.get_latent_activations(
            index, expand=self.expand, use_sparse_tensor=self.use_sparse_tensor
        )

    def get_sequence(self, index: int):
        return self.padded_sequences[index + self.offset][
            : self.sequence_lengths[index + self.offset]
        ]

    def get_latent_activations(
        self, index: int, expand: bool = True, use_sparse_tensor: bool = False
    ):
        start_index = self.sequence_ranges[index + self.offset]
        end_index = self.sequence_ranges[index + self.offset + 1]
        seq_indices = self.ids[start_index:end_index]
        assert th.all(
            seq_indices[:, 0] == index + self.offset
        ), f"Was supposed to find {index + self.offset} but found {seq_indices[:, 0].unique()}"
        seq_indices = seq_indices[:, 1:]  # remove seq_idx column

        if expand:
            if use_sparse_tensor:
                # Create sparse tensor directly
                indices = (
                    seq_indices.t()
                )  # Transpose to get 2xN format required by sparse tensors
                values = self.acts[start_index:end_index]
                sparse_shape = (
                    self.sequence_lengths[index + self.offset],
                    self.dict_size,
                )
                return th.sparse_coo_tensor(indices, values, sparse_shape)
            else:
                # Create dense tensor as before
                latent_activations = th.zeros(
                    self.sequence_lengths[index + self.offset],
                    self.dict_size,
                    device=self.acts.device,
                )
                latent_activations[seq_indices[:, 0], seq_indices[:, 1]] = self.acts[
                    start_index:end_index
                ]
                return latent_activations
        else:
            return (seq_indices, self.acts[start_index:end_index])

    def to(self, device: th.device):
        self.acts = self.acts.to(device)
        self.ids = self.ids.to(device)
        self.max_activations = self.max_activations.to(device)
        self.latent_ids = self.latent_ids.to(device)
        self.padded_sequences = self.padded_sequences.to(device)
        self.device = device
        return self

In [5]:
def load_activation_dataset(
    activation_store_dir: Path,
    base_model: str = "base",
    finetune_model: str = "finetune",
    layer: int = 20,
    split: str = "train",
    true_split: str | None = None,
    false_split: str | None = None,
    true_name: str = "MATS_true_processed",
    false_name: str = "MATS_false_processed",
):
    """
    Load the saved activations of the base and finetuned models for a given layer.

    Args:
        activation_store_dir: Root directory where activations are stored
        base_model: The base model name
        finetune_model: The finetuned model name
        layer: Layer index to load
        split: Default split to load ("train", "val", etc.)
        true_split: Override split for true dataset
        false_split: Override split for false dataset
        true_name: Dataset name for true facts
        false_name: Dataset name for false facts

    Returns:
        A tuple (true_cache, false_cache) where each is a PairedActivationCache
    """
    # Resolve splits
    if true_split is None:
        true_split = split
    if false_split is None:
        false_split = split

    activation_store_dir = Path(activation_store_dir)

    # Build paths for true dataset
   
    base_model_dir_true = activation_store_dir / base_model
    finetune_model_dir_true = activation_store_dir / finetune_model
    
    # Build paths for false dataset
    base_model_dir_false = activation_store_dir / base_model
    finetune_model_dir_false = activation_store_dir / finetune_model

    submodule_name = f"layer_{layer}_out"

    # Final dataset directories
    base_model_true = base_model_dir_true / true_name / true_split
    finetune_model_true = finetune_model_dir_true / true_name / true_split

    base_model_false = base_model_dir_false / false_name / false_split
    finetune_model_false = finetune_model_dir_false / false_name / false_split

    # Load activation caches
    print(
        f"Loading true cache from {base_model_true / submodule_name} "
        f"and {finetune_model_true / submodule_name}"
    )
    true_cache = PairedActivationCache(
        base_model_true / submodule_name, finetune_model_true / submodule_name
    )

    print(
        f"Loading false cache from {base_model_false / submodule_name} "
        f"and {finetune_model_false / submodule_name}"
    )
    false_cache = PairedActivationCache(
        base_model_false / submodule_name, finetune_model_false / submodule_name
    )

    return true_cache, false_cache

In [6]:
def split_into_sequences(tokens, seq_ranges):
    
    indices_of_bos = seq_ranges
    
    if indices_of_bos[-1] == len(tokens):
        indices_of_bos = indices_of_bos[:-1]
    
    # Split tokens into sequences starting with BOS token
    sequences = []
    index_to_seq_pos = []  # List of (sequence_idx, idx_in_sequence) tuples
    ranges = []
    for i in trange(len(indices_of_bos)):
        start_idx = indices_of_bos[i]
        end_idx = indices_of_bos[i + 1] if i < len(indices_of_bos) - 1 else len(tokens)
        sequence = tokens[start_idx:end_idx]
        sequences.append(sequence)
        ranges.append((start_idx, end_idx))
        # Add mapping for each token in this sequence
        for j in range(len(sequence)):
            index_to_seq_pos.append((i, j))

    return sequences, index_to_seq_pos, ranges

In [7]:
@th.no_grad()
def get_positive_activations(sequences, ranges, dataset, cc, latent_ids):
    """
    Extract positive activations and their indices from sequences.
    Also compute the maximum activation for each latent feature.

    Args:
        sequences: List of sequences
        ranges: List of (start_idx, end_idx) tuples for each sequence
        dataset: Dataset containing activations
        cc: Object with get_activations method
        latent_ids: Tensor of latent indices to extract

    Returns:
        Tuple of:
        - activations tensor: positive activation values
        - indices tensor: in (seq_idx, seq_pos, feature_pos) format
        - max_activations: maximum activation value for each latent feature
    """
    out_activations = []
    out_ids = []
    seq_ranges = [0]

    # Initialize tensors to track max activations for each latent
    max_activations = th.zeros(len(latent_ids), device="cuda")

    for seq_idx in trange(len(sequences)):
        activations = th.stack(
            [dataset[j].cuda().to(dtype=th.float32) for j in range(ranges[seq_idx][0], ranges[seq_idx][1])]
        )
        feature_activations = cc.get_activations(activations)[:, latent_ids]

        assert feature_activations.shape == (
            len(activations),
            len(latent_ids),
        ), f"Feature activations shape: {feature_activations.shape}, expected: {(len(activations), len(latent_ids))}"

        # Track maximum activations
        # For each latent feature, find the max activation in this sequence
        seq_max_values, seq_max_positions = feature_activations.max(dim=0)

        # Update global maximums where this sequence has a higher value
        update_mask = seq_max_values > max_activations
        max_activations[update_mask] = seq_max_values[update_mask]

        # Get indices where feature activations are positive
        pos_mask = feature_activations > 0
        pos_indices = th.nonzero(pos_mask, as_tuple=True)

        # Get the positive activation values
        pos_activations = feature_activations[pos_mask]

        # Create sequence indices tensor matching size of positive indices
        seq_idx_tensor = th.full_like(pos_indices[0], seq_idx)

        # Stack indices into (seq_idx, seq_pos, feature_pos) format
        pos_ids = th.stack([seq_idx_tensor, pos_indices[0], pos_indices[1]], dim=1)

        out_activations.append(pos_activations)
        out_ids.append(pos_ids)
        seq_ranges.append(seq_ranges[-1] + len(pos_ids))

    out_activations = th.cat(out_activations).cpu()
    out_ids = th.cat(out_ids).cpu()
    return out_activations, out_ids, seq_ranges, max_activations

In [8]:
def collect_dictionary_activations(
    dictionary_model_name: str,
    dictionary_model_path: str,
    activation_store_dir: str | Path = "model_activations/",
    base_model: str = "base",
    finetune_model: str = "finetune",
    layer: int = 20,
    latent_activations_dir: str | Path = "latent_activations/",
    split: str = "test",
    load_from_disk: bool = True,
) -> None:
    """
    Compute and save latent activations for a given dictionary model.

    This function processes activations from specified datasets (e.g., FineWeb and LMSYS),
    applies the provided dictionary model to compute latent activations, and saves the results
    to disk. Optionally, it can upload the computed activations to the Hugging Face Hub.

    Args:
        dictionary_model (str): Path or identifier for the dictionary (crosscoder) model to use.
        activation_store_dir (str, optional): Directory containing the raw activation datasets.
            Defaults to $DATASTORE/activations/.
        base_model (str, optional): Name or path of the base model (e.g., "google/gemma-2-2b").
            Defaults to "google/gemma-2-2b".
        chat_model (str, optional): Name or path of the chat/instruct model.
            Defaults to "google/gemma-2-2b-it".
        layer (int, optional): The layer index from which to extract activations.
            Defaults to 13.
        latent_ids (th.Tensor or None, optional): Tensor of latent indices to compute activations for.
            If None, uses all latents in the dictionary model.
        latent_activations_dir (str, optional): Directory to save computed latent activations.
            Defaults to $DATASTORE/latent_activations/.
        upload_to_hub (bool, optional): Whether to upload the computed activations to the Hugging Face Hub.
            Defaults to False.
        split (str, optional): Dataset split to use (e.g., "validation").
            Defaults to "validation".
        load_from_disk (bool, optional): If True, load precomputed activations from disk instead of recomputing.
            Defaults to False.
        is_sae (bool, optional): Whether the model is an SAE rather than a crosscoder.
            Defaults to False.
        is_difference_sae (bool, optional): Whether the SAE is trained on activation differences.
            Defaults to False.

    Returns:
        None
    """

    out_dir = Path(latent_activations_dir) / dictionary_model_name
    out_dir.mkdir(parents=True, exist_ok=True)


    true_cache, false_cache = load_activation_dataset(
        activation_store_dir=activation_store_dir,
        split=split,
    )


    tokens_true = true_cache.tokens[0]
    tokens_false = false_cache.tokens[0]

    true_seq_ranges = true_cache.sequence_ranges
    false_seq_ranges = false_cache.sequence_ranges

    # Load the dictionary model
    dictionary_model = load_dictionary_model(
        dictionary_model_path).to("cuda")

    df = pd.read_csv("latent_df.csv")
    latent_ids = th.tensor(df.index[df["latent_tag"] == "Finetune_only"].to_numpy())
    
    print(latent_ids)

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained("Qwen3-1.7B")

    seq_false, idx_to_seq_pos_false, ranges_false = split_into_sequences(
        tokens_false, false_seq_ranges
    )
    seq_true, idx_to_seq_pos_true, ranges_true = split_into_sequences(
        tokens_true, true_seq_ranges
    )

    print(
        f"Collecting activations for {len(seq_true)} True sequences and {len(seq_false)} False sequences"
    )

    (
        out_acts_true,
        out_ids_true,
        seq_ranges_true,
        max_activations_true,
    ) = get_positive_activations(
        seq_true, ranges_true, true_cache, dictionary_model, latent_ids
    )
    out_acts_false, out_ids_false, seq_ranges_false, max_activations_false = (
        get_positive_activations(
            seq_false, ranges_false, false_cache, dictionary_model, latent_ids
        )
    )

    out_acts = th.cat([out_acts_true, out_acts_false])
    # add offset to seq_idx in out_ids_false
    out_ids_false[:, 0] += len(seq_true)
    out_ids = th.cat([out_ids_true, out_ids_false])

    seq_ranges_false = [i + len(out_acts_true) for i in seq_ranges_false]
    seq_ranges = th.cat(
        [th.tensor(seq_ranges_true[:-1]), th.tensor(seq_ranges_false)]
    )

    # Combine max activations, taking the maximum between both datasets
    combined_max_activations = th.maximum(
        max_activations_true, max_activations_false
    )

    sequences_all = seq_true + seq_false

    # Find max length
    max_len = max(len(s) for s in sequences_all)
    seq_lengths = th.tensor([len(s) for s in sequences_all])
    # Pad each sequence to max length
    padded_seqs = [
        th.cat(
            [
                s,
                th.full(
                    (max_len - len(s),), tokenizer.pad_token_id, device=s.device
                ),
            ]
        )
        for s in sequences_all
    ]
    # Convert to tensor and save
    padded_tensor = th.stack(padded_seqs)

    # Save tensors
    th.save(out_acts.cpu(), out_dir / "out_acts.pt")
    th.save(out_ids.cpu(), out_dir / "out_ids.pt")
    th.save(padded_tensor.cpu(), out_dir / "padded_sequences.pt")
    th.save(latent_ids.cpu(), out_dir / "latent_ids.pt")
    th.save(seq_ranges.cpu(), out_dir / "seq_ranges.pt")
    th.save(seq_lengths.cpu(), out_dir / "seq_lengths.pt")
    th.save(combined_max_activations.cpu(), out_dir / "max_activations.pt")

    # Print some stats about max activations
    print("Maximum activation statistics:")
    print(f"  Average: {combined_max_activations.mean().item():.4f}")
    print(f"  Maximum: {combined_max_activations.max().item():.4f}")
    print(f"  Minimum: {combined_max_activations.min().item():.4f}")

   
    return LatentActivationCache(out_dir)

In [9]:
dict_model_path = "/pscratch/sd/r/ritesh11/temp_dir/crosscoder_checkpoints/Qwen3-1.7B-L20-k100-lr1e-04-ep2-run_1-Crosscoder/checkpoint_90000.pt"

In [10]:
latent_activation = collect_dictionary_activations('crosscoder',dict_model_path)

Loading true cache from model_activations/base/MATS_true_processed/test/layer_20_out and model_activations/finetune/MATS_true_processed/test/layer_20_out


  self.activation_cache_1 = ActivationCache(store_dir_1, submodule_name)
  self.activation_cache_2 = ActivationCache(store_dir_2, submodule_name)


Loading false cache from model_activations/base/MATS_false_processed/test/layer_20_out and model_activations/finetune/MATS_false_processed/test/layer_20_out
tensor([ 1774,  3609,  6425, 15943, 18232, 24833, 30340, 36090, 43592, 44647,
        51653, 51802, 53929, 56428, 57787, 58237])


  0%|          | 0/3968 [00:00<?, ?it/s]

  0%|          | 0/3968 [00:00<?, ?it/s]

Collecting activations for 3968 True sequences and 3968 False sequences


  0%|          | 0/3968 [00:00<?, ?it/s]

  0%|          | 0/3968 [00:00<?, ?it/s]

Maximum activation statistics:
  Average: 282.5135
  Maximum: 479.5937
  Minimum: 175.3419


NameError: name 'tqdm' is not defined

In [12]:
from tqdm import tqdm

In [15]:
la = LatentActivationCache("/pscratch/sd/r/ritesh11/temp_dir/latent_activations/crosscoder")

Loading cache files: 100%|██████████| 7/7 [00:00<00:00, 111.48it/s, Loading seq_ranges.pt]      


In [26]:
for q in la:
    print(q[0].shape)
    break

torch.Size([2023])
