In [None]:
import numpy as np
# Compatibility shim: some NumPy releases do not expose `np.bool`, while code
# imported later may still reference it. Only fill the gap when it is missing.
try:
    np.bool
except AttributeError:
    np.bool = bool

In [None]:
from dataclasses import dataclass, field
from pathlib import Path

import torch as th
from tqdm.auto import trange
from transformers import AutoTokenizer
from dictionary_learning.cache import PairedActivationCache
from dictionary_learning.dictionary import BatchTopKCrossCoder
import json
from tqdm import tqdm

In [None]:
@th.no_grad()
def compute_stats(
    crosscoder: BatchTopKCrossCoder,
    cache: PairedActivationCache,
    device,
    batch_size: int = 2048,
    num_workers: int = 16,
):
    """Compute per-latent maximum activation and firing frequency over a cache.

    Args:
        crosscoder: Dictionary model; its ``dict_size`` attribute and
            ``get_activations`` method are used.
        cache: Dataset of activations, iterated through a ``DataLoader``.
        device: Device on which batches are placed and statistics accumulated.
        batch_size: Number of cache items per ``DataLoader`` batch.
        num_workers: Number of ``DataLoader`` worker processes.

    Returns:
        A tuple ``(max_activations, frequencies, total_tokens)``:
        ``max_activations`` and ``frequencies`` are CPU tensors of shape
        ``(dict_size,)``; ``total_tokens`` is the number of rows processed.
        The running maximum starts at zero, so a latent that never exceeds
        zero reports 0. If the cache yields no batches, ``frequencies`` is
        NaN (0 / 0).
    """
    dataloader = th.utils.data.DataLoader(
        cache, batch_size=batch_size, num_workers=num_workers
    )
    dict_size = crosscoder.dict_size
    max_activations = th.zeros(dict_size, device=device)
    # Count in int64: a float32 accumulator is only exact up to 2**24, so
    # large caches would silently round the number of firings per latent.
    nonzero_counts = th.zeros(dict_size, dtype=th.long, device=device)
    total_tokens = 0

    for batch in tqdm(dataloader):
        activations = crosscoder.get_activations(
            batch.to(device, dtype=th.float32)
        )  # (batch_size, dict_size)
        assert activations.shape == (len(batch), dict_size)
        max_activations = th.max(max_activations, activations.max(dim=0).values)
        nonzero_counts += (activations != 0).sum(dim=0)
        total_tokens += activations.shape[0]

    # Integer tensor / Python int is a true division and yields a float tensor.
    frequencies = nonzero_counts / total_tokens
    assert max_activations.shape == (dict_size,)
    return max_activations.cpu(), frequencies.cpu(), total_tokens

In [None]:
def load_dictionary_model(
    model_name: str | Path, is_sae: bool | None = None,
):
    """Load a dictionary model from a local path or HuggingFace Hub.

    Args:
        model_name: Name or path of the model to load

    Returns:
        The loaded dictionary model
    """
    # Local model
    model_path = Path(model_name)
    if not model_path.exists():
        raise ValueError(f"Local model {model_name} does not exist")

    # Load the config
    with open(model_path.parent / "config.json", "r") as f:
        config = json.load(f)["trainer"]

    # Determine model class based on config
    if "dict_class" in config and config["dict_class"] in [
        "BatchTopKSAE",
        "CrossCoder",
        "BatchTopKCrossCoder",
    ]:
        return eval(f"{config['dict_class']}.from_pretrained(model_path)")
    else:
        raise ValueError(f"Unknown model type: {config['dict_class']}")


In [None]:
def load_activation_dataset(
    activation_store_dir: Path,
    base_model: str = "base",
    finetune_model: str = "finetune",
    layer: int = 20,
    split: str = "train",
    true_split: str | None = None,
    false_split: str | None = None,
    true_name: str = "MATS_true_processed",
    false_name: str = "MATS_false_processed",
):
    """
    Build paired (base, finetune) activation caches for the true and false datasets.

    Every cache directory is laid out as
    ``<activation_store_dir>/<model>/<dataset name>/<split>/layer_<layer>_out``.

    Args:
        activation_store_dir: Root directory where activations are stored
        base_model: Directory name of the base model
        finetune_model: Directory name of the finetuned model
        layer: Layer index used to form the ``layer_<layer>_out`` directory
        split: Split used for both datasets unless overridden
        true_split: Split for the true dataset (defaults to ``split``)
        false_split: Split for the false dataset (defaults to ``split``)
        true_name: Directory name of the true-facts dataset
        false_name: Directory name of the false-facts dataset

    Returns:
        A tuple (true_cache, false_cache) where each is a PairedActivationCache
    """
    store_root = Path(activation_store_dir)
    layer_dirname = f"layer_{layer}_out"

    # Per-dataset overrides fall back to the shared default split.
    split_for_true = split if true_split is None else true_split
    split_for_false = split if false_split is None else false_split

    def _layer_dir(model: str, dataset: str, dataset_split: str) -> Path:
        # Directory holding one model's activations for one dataset split.
        return store_root / model / dataset / dataset_split / layer_dirname

    true_base_dir = _layer_dir(base_model, true_name, split_for_true)
    true_finetune_dir = _layer_dir(finetune_model, true_name, split_for_true)
    false_base_dir = _layer_dir(base_model, false_name, split_for_false)
    false_finetune_dir = _layer_dir(finetune_model, false_name, split_for_false)

    print(
        f"Loading true cache from {true_base_dir} "
        f"and {true_finetune_dir}"
    )
    true_cache = PairedActivationCache(true_base_dir, true_finetune_dir)

    print(
        f"Loading false cache from {false_base_dir} "
        f"and {false_finetune_dir}"
    )
    false_cache = PairedActivationCache(false_base_dir, false_finetune_dir)

    return true_cache, false_cache

In [None]:
# Crosscoder checkpoint to analyse; its config.json is read from the same directory.
dict_model_path = (
    "crosscoder_checkpoints/"
    "Qwen3-1.7B-L20-k100-lr1e-04-ep2-run_1-Crosscoder/"
    "checkpoint_90000.pt"
)

In [None]:
# Runtime configuration for the statistics pass.
# Use the GPU when one is present; otherwise fall back to CPU instead of
# failing later when the model and batches are moved to `device`.
device = "cuda" if th.cuda.is_available() else "cpu"
batch_size = 2048  # items per DataLoader batch
num_workers = 16  # DataLoader worker processes

In [None]:
# Load the crosscoder once, then gather per-latent statistics on the true and
# false activation caches for each requested split.
crosscoder_model = load_dictionary_model(dict_model_path).to(device)

# Maps split name -> statistics. It has to exist before the loop stores into
# it; otherwise this cell raises NameError on a fresh kernel.
results = {}
for split in ["test"]:
    true_dataset, false_dataset = load_activation_dataset(
        "model_activations",
        split=split,
    )

    max_activations_true, frequencies_true, total_tokens_true = compute_stats(
        crosscoder_model, true_dataset, device, batch_size, num_workers
    )
    max_activations_false, frequencies_false, total_tokens_false = (
        compute_stats(
            crosscoder_model,
            false_dataset,
            device,
            batch_size,
            num_workers,
        )
    )
    # Tensors are converted to plain lists so the results are easy to
    # serialise or load into a DataFrame.
    results[split] = {
        "max_activations_true": max_activations_true.tolist(),
        "max_activations_false": max_activations_false.tolist(),
        "frequencies_true": frequencies_true.tolist(),
        "frequencies_false": frequencies_false.tolist(),
        "total_tokens_true": total_tokens_true,
        "total_tokens_false": total_tokens_false,
    }

In [None]:
# Record the statistics of the current `split` under its name in `results`
# (same keys as the assignment made inside the split loop).
results[split] = dict(
    max_activations_true=max_activations_true.tolist(),
    max_activations_false=max_activations_false.tolist(),
    frequencies_true=frequencies_true.tolist(),
    frequencies_false=frequencies_false.tolist(),
    total_tokens_true=total_tokens_true,
    total_tokens_false=total_tokens_false,
)

In [None]:
# Sanity check: number of per-latent entries recorded for the test split.
len(results["test"]["max_activations_true"])

In [None]:
import pandas as pd

In [None]:
# One row per latent: list-valued entries become columns, and the scalar
# token totals are repeated on every row.
latent_df = pd.DataFrame(data=results["test"])

In [None]:
# Keep each latent's index as an explicit column so it survives the CSV export.
latent_df["latent_id"] = range(len(latent_df))

In [None]:
# Write the per-latent table to disk without the DataFrame index.
latent_df.to_csv("latent_df.csv", index=False)