In [24]:
import os

import matplotlib.pyplot as plt
import torch

from src.crosscoder.crosscoder import CrossCoder

torch.set_grad_enabled(False)
import einops
import numpy as np
from datasets import Dataset
from PIL import Image
from tqdm import tqdm

# Font sizes (in points) used by every figure in this notebook.
SMALL_SIZE = 22
MEDIUM_SIZE = 24
BIGGER_SIZE = 26

# One entry per matplotlib rc group; each is forwarded to `plt.rc(group, **kwargs)`.
_RC_SETTINGS = {
    "font": {"size": SMALL_SIZE, "family": "Times New Roman"},  # default text
    "axes": {"titlesize": BIGGER_SIZE, "labelsize": MEDIUM_SIZE},  # axes title, x/y labels
    "xtick": {"labelsize": SMALL_SIZE},  # x tick labels
    "ytick": {"labelsize": SMALL_SIZE},  # y tick labels
    "legend": {"fontsize": SMALL_SIZE},  # legend entries
    "figure": {"titlesize": BIGGER_SIZE, "labelsize": SMALL_SIZE},  # suptitle / figure-level labels
}
for _rc_group, _rc_kwargs in _RC_SETTINGS.items():
    plt.rc(_rc_group, **_rc_kwargs)

# Shared colour palette indexed by position in the plotting cells below.
colors = ["#386EC2", "#B5B5B2", "#990006", "#625D0A", "#B9741F", "#213958"]


In [2]:
# Precision the crosscoder weights are cast to after loading and the dtype
# requested when the activation datasets are formatted as torch tensors.
dtype = torch.float16
# Root directory of the trained crosscoder checkpoint; the hookpoint name is
# appended to it when the weights are loaded.
ckpt_path = "/home/bcywinski/code/diffing/sae-ckpts/crosscoder-sdxl/expansion_factor4_l120.0_dec_init_norm0.1_4_steps"
# Sub-directory name used both under the checkpoint root and under each
# activation dataset root.
hookpoint = "down_blocks.2"


In [3]:
# Load the crosscoder stored under <ckpt_path>/<hookpoint> onto the GPU and
# cast it to the working precision.
crosscoder_dir = os.path.join(ckpt_path, hookpoint)
crosscoder = CrossCoder.load_from_disk(crosscoder_dir, device="cuda").to(dtype)


In [4]:
# Root directory that the `file_name` entries of the activation datasets are
# joined onto when the original images are opened for plotting.
dataset_path = "/data/SHARE/datasets/"

In [5]:
# Roots of the cached activation datasets; the hookpoint name is appended when
# they are loaded. Going by the directory names, model 1 is SDXL and model 2 is
# SDXL-Turbo — NOTE(review): confirm against how the activations were dumped.
model1_dataset_path = "/data/bcywinski/activations/coco2017/sdxl_4/output/"
model2_dataset_path = "/data/bcywinski/activations/coco2017/sdxl-turbo/output/"
# Number of timesteps per model, forwarded to PairedDataset.
model1_num_timesteps = 4
model2_num_timesteps = 4

In [None]:
from src.scripts.train_crosscoder import PairedDataset

def _load_activation_dataset(root_dir):
    """Load the activation dataset stored under `<root_dir>/<hookpoint>`.

    The dataset is opened with `keep_in_memory=False` and formatted so that the
    `activations`, `timestep` and `file_name` columns are returned as torch
    values using the notebook-wide `dtype`.
    """
    loaded = Dataset.load_from_disk(
        os.path.join(root_dir, hookpoint), keep_in_memory=False
    )
    loaded.set_format(
        type="torch", columns=["activations", "timestep", "file_name"], dtype=dtype
    )
    return loaded


model1_dataset = _load_activation_dataset(model1_dataset_path)
model2_dataset = _load_activation_dataset(model2_dataset_path)

# The positional arguments `100` and `True` are passed through unchanged; their
# meaning is defined by PairedDataset's signature — TODO confirm in
# src/scripts/train_crosscoder.py.
paired_dataset = PairedDataset(
    model1_dataset,
    model2_dataset,
    100,
    True,
    model1_num_timesteps,
    model2_num_timesteps,
)


In [7]:
# Re-binds `dtype` to the same value already assigned in the checkpoint
# configuration cell near the top of the notebook.
dtype = torch.float16

In [None]:
# Mean activation of every crosscoder latent for each dataset item, computed
# with encoder index 0 on the "model1" activations. Rows follow dataset order
# because the loader is not shuffled.
batch_size = 16
avg_activations_per_sample1 = torch.zeros(
    (len(paired_dataset), crosscoder.num_latents), dtype=torch.float16
)
# The loader uses the same `batch_size` as the row-offset arithmetic below, so
# changing the value in one place can no longer scramble the output rows.
dl = torch.utils.data.DataLoader(
    paired_dataset, batch_size=batch_size, shuffle=False, num_workers=4
)
for i, batch in tqdm(enumerate(dl), total=len(dl)):
    acts = batch["model1"].to(crosscoder.W_dec.device)
    items_in_batch = acts.shape[0]  # the final batch may be shorter
    acts = einops.rearrange(
        acts,
        "batch sample_size d_model -> (batch sample_size) d_model",
    )
    out = einops.einsum(
        acts,
        crosscoder.W_enc[0],
        "batch d_model, d_model d_latent -> batch d_latent",
    )
    out = torch.nn.functional.relu(out + crosscoder.b_enc[0])
    # Back to [batch, sample_size, num_latents], then average over sample_size.
    out = out.view(items_in_batch, -1, crosscoder.num_latents)
    batch_avg_activations = out.mean(dim=1).to(
        dtype=torch.float16
    )  # [batch, num_latents]

    # Write into the rows that belong to this batch; the span is derived from
    # the actual batch length rather than the nominal batch size.
    start_idx = i * batch_size
    end_idx = start_idx + items_in_batch
    avg_activations_per_sample1[start_idx:end_idx] = batch_avg_activations


In [39]:
def find_topk_activating_examples(activations_per_sample, latent_idx, k=10):
    """Return the row indices with the largest values in one latent's column.

    Args:
        activations_per_sample: 2-D tensor indexed as [row, latent].
        latent_idx: Column (latent) to rank the rows by.
        k: Maximum number of indices to return.

    Returns:
        1-D tensor of at most `k` row indices, ordered from the highest value
        of the selected column to the lowest.
    """
    latent_column = activations_per_sample[:, latent_idx]
    ranking = torch.argsort(latent_column, dim=0, descending=True)
    return ranking[:k]


In [40]:
# Index of the crosscoder latent inspected by the top-k and overlay cells below.
latent_idx = 2735


In [None]:
k = 10
# Rank dataset items by their *mean* activation of `latent_idx` (the table was
# built from per-item means) and keep the k highest.
topk_indices1 = find_topk_activating_examples(
    avg_activations_per_sample1, latent_idx, k
)
topk_samples1 = paired_dataset[topk_indices1.tolist()]
activations1 = topk_samples1["model1"]
timesteps1 = topk_samples1["model1_timestep"]
file_names_topk1 = topk_samples1["file_name"]

# Re-encode the selected items with encoder index 0 to recover the activation
# of every latent at every position.
activations1 = einops.rearrange(
    activations1,
    "batch sample_size d_model -> (batch sample_size) d_model",
)
activations1 = activations1.to(crosscoder.W_dec.device)
out = einops.einsum(
    activations1,
    crosscoder.W_enc[0],
    "batch d_model, d_model d_latent -> batch d_latent",
)
out = torch.nn.functional.relu(out + crosscoder.b_enc[0])
# Reshape by the number of items actually returned: fewer than `k` indices come
# back when the dataset has fewer than `k` items, and `view(k, ...)` would then
# fail or mis-group rows.
sae_latents1 = out.view(len(topk_indices1), -1, crosscoder.num_latents)
sae_latents1.shape


In [None]:
# Two-row figure for the top-k examples of `latent_idx`:
#   row 1 - the original image,
#   row 2 - the image blended with a heatmap of the latent's activations.
num_examples = len(topk_indices1)
# squeeze=False keeps `axes` two-dimensional even when only one example exists.
fig, axes = plt.subplots(2, num_examples, figsize=(18, 6), squeeze=False)

image_size = 512
# Positions are laid out on a square grid; its side length is the same for every
# example, so compute it once.
grid_side = int(torch.sqrt(torch.tensor(sae_latents1.shape[1])).item())


def _hide_ticks_and_frame(ax):
    """Remove ticks and spines but keep the axis label visible.

    `ax.axis("off")` would also hide the y-label used to name each row.
    """
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)


for i in range(num_examples):
    # Original image, resized to the common display size.
    img1 = Image.open(os.path.join(dataset_path, file_names_topk1[i]))
    img1 = img1.resize((image_size, image_size))
    img1 = img1.convert("RGB")

    # Activation of the selected latent at every grid position.
    sae_latent_activations1 = sae_latents1[i].reshape(grid_side, grid_side, -1)[
        :, :, latent_idx
    ]
    # Min-max normalise to [0, 1]; the epsilon guards against a constant map.
    activation_map1 = sae_latent_activations1[:, :].detach().cpu().numpy()
    activation_map1 = (activation_map1 - activation_map1.min()) / (
        activation_map1.max() - activation_map1.min() + 1e-8
    )

    # Upscale by block replication so each grid cell covers a square patch.
    patch_size1 = image_size // activation_map1.shape[0]
    activation_map1 = np.kron(activation_map1, np.ones((patch_size1, patch_size1)))

    # Colour-map the activations and turn them into an RGB image.
    heatmap1 = np.uint8(plt.cm.jet(activation_map1)[..., :3] * 255)
    heatmap1 = Image.fromarray(heatmap1)
    # When the grid side does not divide the image size the replicated map is
    # slightly smaller than the image and Image.blend would raise; resize then.
    if heatmap1.size != img1.size:
        heatmap1 = heatmap1.resize(img1.size, Image.NEAREST)

    blended_img1 = Image.blend(img1, heatmap1, alpha=0.4)

    # Mean activation of the latent over the whole image.
    avg_activation = sae_latent_activations1.mean().item()

    # Row 1: original image.
    axes[0, i].imshow(img1)
    _hide_ticks_and_frame(axes[0, i])
    axes[0, i].set_title(
        f"Activation: {avg_activation:.2f}\nTimestep: {int(timesteps1[i].item())}",
        fontsize=SMALL_SIZE,
    )
    if i == 0:
        axes[0, 0].set_ylabel("Original Images", fontsize=SMALL_SIZE)

    # Row 2: heatmap overlay.
    axes[1, i].imshow(blended_img1)
    _hide_ticks_and_frame(axes[1, i])
    if i == 0:
        axes[1, 0].set_ylabel("Activations", fontsize=SMALL_SIZE)

plt.suptitle(f"Max Activating Examples for Neuron {latent_idx}", fontsize=BIGGER_SIZE)
plt.tight_layout()


In [None]:
# Measure the L0 sparsity of the latent space over the whole dataset.
def calculate_average_l0(dataloader, threshold=0.0):
    """
    Calculate the L0 sparsity (number of active latents) of the model-1 encoding.

    Each batch's "model1" activations are flattened to one row per position,
    encoded with encoder index 0 of the global `crosscoder`, and a latent counts
    as active when its post-ReLU activation is strictly greater than `threshold`.

    Args:
        dataloader: Iterable of batches whose "model1" entry has shape
            [batch, sample_size, d_model].
        threshold: Activation threshold (latents with activation > threshold
            are considered active).

    Returns:
        average_l0: Mean number of active latents per position (NaN if the
            dataloader yields nothing).
        l0_per_sample: NumPy array with one L0 value per flattened position,
            in dataloader order.
        neuron_activated_number: CPU tensor of length `crosscoder.num_latents`
            holding, for each latent, the number of positions where it was
            active.
    """
    l0_per_sample = []
    # float64 keeps the running counts exact; float32 stops representing every
    # integer above 2**24, which large datasets can exceed.
    neuron_activated_number = torch.zeros(crosscoder.num_latents, dtype=torch.float64)

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Calculating L0 sparsity"):
            acts = batch["model1"].to(crosscoder.W_dec.device)
            acts = einops.rearrange(
                acts,
                "batch sample_size d_model -> (batch sample_size) d_model",
            )

            # Encode to latent space.
            latents = einops.einsum(
                acts,
                crosscoder.W_enc[0],
                "batch d_model, d_model d_latent -> batch d_latent",
            )
            latents = torch.nn.functional.relu(latents + crosscoder.b_enc[0])

            # Threshold once and reuse the mask for both statistics.
            is_active = latents > threshold
            neuron_activated_number += is_active.sum(dim=0).cpu()
            l0_per_sample.extend(is_active.sum(dim=1).cpu().tolist())

    l0_per_sample = np.array(l0_per_sample)
    average_l0 = l0_per_sample.mean()

    return average_l0, l0_per_sample, neuron_activated_number


# Sweep the full paired dataset in its stored order and count a latent as
# active only above a small threshold (0.01).
dl = torch.utils.data.DataLoader(
    paired_dataset,
    batch_size=16,
    shuffle=False,
    num_workers=4,
)
l0_results = calculate_average_l0(dl, threshold=0.01)
avg_l0, l0_values, neuron_activated_number = l0_results

# Report the mean L0 as an absolute count and as a share of all latents.
print(
    f"Average L0 sparsity (active neurons): {avg_l0:.2f} out of {crosscoder.num_latents}"
)
print(f"Average L0 percent: {100 * avg_l0 / crosscoder.num_latents:.2f}%")


In [None]:
# Histogram of the per-position L0 values, with the overall mean marked.
fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(l0_values, bins=50, alpha=0.7)
ax.axvline(
    avg_l0,
    color="blue",
    linestyle="dashed",
    linewidth=2,
    label=f"Avg: {avg_l0:.2f}",
)
ax.set_xlabel("Number of Active Neurons")
ax.set_ylabel("Count")
ax.set_title("L0 of Latent Space - Base Model")
ax.legend()
ax.grid(alpha=0.3)
fig.tight_layout()


In [10]:
# Timestep values of the first four rows of the model-1 dataset; later cells
# use position `j` of this list to label the j-th timestep slot.
timestep_map = [model1_dataset[row]["timestep"].item() for row in range(4)]


In [None]:
# One L0 histogram per timestep slot. The flat L0 array is viewed as
# (len(model1_dataset), 256), i.e. 256 values per dataset row, and rows
# start, start+4, start+8, ... are taken for each slot.
# NOTE(review): this assumes l0_values holds exactly len(model1_dataset) * 256
# entries and that consecutive rows cycle through the 4 timesteps — confirm
# against PairedDataset's ordering.
fig, axes = plt.subplots(2, 2, figsize=(14, 10))
axes = axes.flatten()

l0_by_row = l0_values.reshape(len(model1_dataset), 256)

for i, start in enumerate(range(4)):
    # Every 4th row starting at `start`, flattened back to one value per position.
    l0_subset = l0_by_row[start::4].flatten()
    print(l0_subset.shape)

    # Plot histogram
    axes[i].hist(l0_subset, bins=50, alpha=0.7, color=f"C{i}")

    # Add mean line
    subset_mean = l0_subset.mean()
    axes[i].axvline(
        subset_mean,
        color="red",
        linestyle="dashed",
        linewidth=2,
        label=f"Avg: {subset_mean:.2f}",
    )

    axes[i].set_xlabel("Number of Active Neurons")
    axes[i].set_ylabel("Count")
    axes[i].set_title(f"L0 timestep {int(timestep_map[start])}")

    axes[i].legend()
    axes[i].grid(alpha=0.3, axis="y")

plt.suptitle("L0 across timesteps - Base Model")
plt.tight_layout()


In [None]:
# Mean L0 per timestep slot: view the flat L0 array as 256 values per dataset
# row and average rows offset, offset+4, offset+8, ... for each of the 4 slots.
l0_rows = l0_values.reshape(len(model1_dataset), 256)
timestep_averages = [l0_rows[offset::4].flatten().mean() for offset in range(4)]
# The same averages expressed as a percentage of all latents.
timestep_percentages = [
    100 * slot_mean / crosscoder.num_latents for slot_mean in timestep_averages
]

# Bar chart with one bar per slot, labelled with the actual timestep value.
fig, ax = plt.subplots(figsize=(10, 6))
x_labels = [f"Timestep {int(timestep_map[i])}" for i in range(4)]
bars = ax.bar(x_labels, timestep_averages, color=colors[:4])

# Print each bar's height just above it.
for bar in bars:
    bar_top = bar.get_height()
    bar_center = bar.get_x() + bar.get_width() / 2.0
    ax.text(
        bar_center,
        bar_top + 5,  # small offset above the bar
        f"{bar_top:.1f}",
        ha="center",
        va="bottom",
    )

ax.set_title("Average Number of Active Neurons by Timestep - Base Model")
ax.set_ylabel("Average L0 (Active Neurons)")
ax.set_ylim(0, 100)
ax.grid(axis="y", alpha=0.3)
fig.tight_layout()


In [None]:
# Display the per-timestep mean L0 values computed in the previous cell.
timestep_averages


In [None]:
# Per-timestep mean L0 values for the second model, entered as literals rather
# than computed in this notebook.
# NOTE(review): presumably copied from a run of the same analysis on the
# distilled model's activations — confirm their provenance.
timestep_averages2 = [
    np.float64(3.3881645833333334),
    np.float64(10.302678776041667),
    np.float64(29.474933463541667),
    np.float64(53.709834244791665),
]
# Create bar plot
plt.figure(figsize=(10, 7))

# Use the timestep_map to get actual timestep values for x-axis labels
x_labels = [f"Timestep {int(timestep_map[i])}" for i in range(4)]

# Set width and positions for the bars
bar_width = 0.35
x_positions = np.arange(len(x_labels))

# Grouped bars: one colour per model, taken from the shared `colors` palette.
bars1 = plt.bar(
    x_positions - bar_width / 2,
    timestep_averages,
    bar_width,
    color=colors[1],  # second palette entry for all left (base) bars
    label="Base Model",
    alpha=0.7,
)
bars2 = plt.bar(
    x_positions + bar_width / 2,
    timestep_averages2,
    bar_width,
    color=colors[-2],  # second-to-last palette entry for all right (distilled) bars
    label="Distilled Model",
    alpha=0.7,
)

# Add value labels on top of each bar
for bars in [bars1, bars2]:
    for bar in bars:
        height = bar.get_height()
        plt.text(
            bar.get_x() + bar.get_width() / 2.0,
            height + 2,  # Small offset above the bar
            f"{height:.1f}",
            ha="center",
            va="bottom",
        )

plt.title("Average Number of Active Neurons per Timestep")
plt.ylabel("Average L0 (Active Neurons)")
plt.ylim(0, 60)
plt.grid(axis="y", alpha=0.25, linestyle="--")
plt.xticks(x_positions, x_labels)
plt.legend()
plt.tight_layout()
# savefig does not create missing directories, so make sure the target exists.
os.makedirs("figures", exist_ok=True)
plt.savefig("figures/l0_timesteps.pdf", dpi=300, format="pdf", bbox_inches="tight")


In [None]:
# Display the shape of the per-latent activation counts returned by
# calculate_average_l0.
neuron_activated_number.shape

In [None]:
# Split the latents into three groups by how much of their decoder norm sits
# in the column at index 1 (values in [0, 1]).
# NOTE(review): index 1 presumably corresponds to model 2 (the distilled/turbo
# model) — confirm against CrossCoder's decoder layout.
norms = crosscoder.W_dec.norm(dim=-1).cpu()
norm_totals = norms.sum(dim=-1)
relative_norms = norms[:, 1] / norm_totals

base_cutoff, turbo_cutoff = 0.25, 0.75
only_base_features_mask = relative_norms < base_cutoff
only_turbo_features_mask = relative_norms > turbo_cutoff
shared_features_mask = torch.logical_and(
    relative_norms >= base_cutoff, relative_norms <= turbo_cutoff
)
only_base_features_mask.shape

In [None]:
# Mean fire count per latent within each group. The counts come from the
# model-1 sweep performed by calculate_average_l0.
base_activated = neuron_activated_number[only_base_features_mask].mean().item()
turbo_activated = neuron_activated_number[only_turbo_features_mask].mean().item()
shared_activated = neuron_activated_number[shared_features_mask].mean().item()

categories = ["Base", "Shared", "Distilled"]
values = [base_activated, shared_activated, turbo_activated]
# NOTE(review): this rebinds the notebook-wide `colors` palette to a 3-entry
# list; every cell executed afterwards that indexes `colors` sees these values.
colors = ["#3498db", "#2ecc71", "#e74c3c"]

fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.bar(categories, values, color=colors)

# Print each bar's value just above it.
for bar in bars:
    bar_top = bar.get_height()
    ax.text(
        bar.get_x() + bar.get_width() / 2.0,
        bar_top + 0.01,
        f"{bar_top:.4f}",
        ha="center",
        va="bottom",
    )

ax.set_title("Average Number of Times a Neuron Fired - Base Model")
ax.set_ylim(0, 800000)  # fixed upper limit leaving headroom above the bars
ax.grid(axis="y", alpha=0.3)
fig.tight_layout(pad=1.1)
plt.show()

In [None]:
# Plot number of activated neurons from base, shared, and distilled groups per timestep for both models
def _new_group_counts():
    """Return a fresh accumulator with one zeroed entry per timestep in `timestep_map`."""
    return {
        ts: {"base": 0, "shared": 0, "distilled": 0, "samples": 0}
        for ts in timestep_map
    }


def _accumulate_group_counts(acts, timesteps, encoder_idx, counts, threshold):
    """Encode one batch with the given encoder and add per-group L0 means to `counts`.

    For every item in the batch, the number of active latents (activation >
    threshold) in each group is computed per position and averaged over the
    positions; that average is added to the entry of the item's timestep.
    Items whose timestep is not a key of `counts` are skipped.
    """
    batch_items = acts.shape[0]
    flat_acts = acts.to(crosscoder.W_dec.device).reshape(-1, acts.shape[-1])
    latents = einops.einsum(
        flat_acts,
        crosscoder.W_enc[encoder_idx],
        "batch d_model, d_model d_latent -> batch d_latent",
    )
    latents = torch.nn.functional.relu(latents + crosscoder.b_enc[encoder_idx])

    # [item, position, latent] boolean mask, moved to the CPU where the group
    # masks live.
    active = (latents > threshold).reshape(batch_items, -1, latents.shape[-1]).cpu()
    group_masks = {
        "base": only_base_features_mask,
        "shared": shared_features_mask,
        "distilled": only_turbo_features_mask,
    }
    # Per item: mean over positions of the number of active latents in the group.
    group_means = {
        name: active[:, :, mask].sum(dim=2).float().mean(dim=1)
        for name, mask in group_masks.items()
    }

    for item in range(batch_items):
        ts = int(timesteps[item].item())
        if ts not in counts:
            continue
        for name in group_masks:
            counts[ts][name] += group_means[name][item].item()
        counts[ts]["samples"] += 1


def _average_group_counts(counts):
    """Turn accumulated sums into per-timestep averages, skipping empty timesteps.

    Returns four parallel lists: timesteps (sorted), base, shared and distilled
    averages.
    """
    timesteps, base, shared, distilled = [], [], [], []
    for ts in sorted(counts.keys()):
        n_samples = counts[ts]["samples"]
        if n_samples > 0:
            timesteps.append(ts)
            base.append(counts[ts]["base"] / n_samples)
            shared.append(counts[ts]["shared"] / n_samples)
            distilled.append(counts[ts]["distilled"] / n_samples)
    return timesteps, base, shared, distilled


def _draw_stacked_group_bars(ax, counts, timesteps, base, shared, distilled, title):
    """Draw one stacked bar per timestep (base / shared / distilled) with annotations."""
    x = np.arange(len(timesteps))
    width = 0.6
    bottoms = [b + s for b, s in zip(base, shared)]

    ax.bar(x, base, width, label="Base-specific", color=colors[0])
    ax.bar(x, shared, width, bottom=base, label="Shared", color=colors[1])
    ax.bar(
        x,
        distilled,
        width,
        bottom=bottoms,
        label="Distilled-specific",
        color=colors[2],
    )

    # (segment heights, height at which the segment starts, label colour)
    segments = (
        (base, [0] * len(base), "white"),
        (shared, base, "black"),
        (distilled, bottoms, "white"),
    )
    for i in range(len(x)):
        # Value in the middle of each segment that is tall enough to hold it.
        for heights, starts, text_color in segments:
            if heights[i] > 0.5:
                ax.text(
                    x[i],
                    starts[i] + heights[i] / 2,
                    f"{heights[i]:.1f}",
                    ha="center",
                    va="center",
                    color=text_color,
                    fontweight="bold",
                )

        # Total on top of the stack.
        total = base[i] + shared[i] + distilled[i]
        ax.text(
            x[i],
            bottoms[i] + distilled[i] + 0.5,
            f"Total: {total:.1f}",
            ha="center",
            va="bottom",
            fontweight="bold",
        )

    # Number of contributing items below the x-axis.
    for i, ts in enumerate(timesteps):
        ax.text(
            i,
            -1.5,
            f"n={counts[ts]['samples']}",
            ha="center",
            fontsize=10,
        )

    ax.set_xlabel("Timestep")
    ax.set_ylabel("Average Number of Activated Neurons")
    ax.set_title(title)
    ax.set_xticks(x)
    ax.set_xticklabels([str(ts) for ts in timesteps])
    ax.legend(loc="upper right")
    ax.grid(axis="y", alpha=0.3)


def plot_activated_neurons_by_group_both_models(threshold=0.0):
    """Plot, per timestep, how many latents of each group are active for both models.

    The paired dataset is swept once. "model1" activations are encoded with
    encoder index 0 and "model2" activations with encoder index 1; a latent is
    active when its post-ReLU activation exceeds `threshold`. Latents are
    grouped with the notebook-level base / shared / turbo masks.

    Two figures are produced and written to `figures/`: a pair of stacked bar
    charts (one per model) and a grouped bar chart comparing the totals on the
    timesteps both models have data for.

    Args:
        threshold: Activation threshold above which a latent counts as active.

    Returns:
        Dict with keys "model1" and "model2", each mapping a timestep to the
        accumulated "base", "shared" and "distilled" sums and the number of
        contributing "samples".
    """
    print("Analyzing activated neurons by group and timestep for both models...")

    model1_activated_counts = _new_group_counts()
    model2_activated_counts = _new_group_counts()

    dl = torch.utils.data.DataLoader(
        paired_dataset, batch_size=16, shuffle=False, num_workers=4
    )

    # Each batch is encoded once per model rather than once per item.
    with torch.no_grad():
        for batch in tqdm(dl, desc="Processing samples"):
            _accumulate_group_counts(
                batch["model1"],
                batch["model1_timestep"],
                0,
                model1_activated_counts,
                threshold,
            )
            _accumulate_group_counts(
                batch["model2"],
                batch["model2_timestep"],
                1,
                model2_activated_counts,
                threshold,
            )

    (
        model1_timesteps,
        model1_base_counts,
        model1_shared_counts,
        model1_distilled_counts,
    ) = _average_group_counts(model1_activated_counts)
    (
        model2_timesteps,
        model2_base_counts,
        model2_shared_counts,
        model2_distilled_counts,
    ) = _average_group_counts(model2_activated_counts)

    # Figure 1: stacked bars, one panel per model.
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 8))
    _draw_stacked_group_bars(
        ax1,
        model1_activated_counts,
        model1_timesteps,
        model1_base_counts,
        model1_shared_counts,
        model1_distilled_counts,
        "Base Model - Average Activated Neurons by Group",
    )
    _draw_stacked_group_bars(
        ax2,
        model2_activated_counts,
        model2_timesteps,
        model2_base_counts,
        model2_shared_counts,
        model2_distilled_counts,
        "Distilled Model - Average Activated Neurons by Group",
    )

    plt.suptitle(
        "Comparison of Activated Neurons by Group and Timestep", fontsize=BIGGER_SIZE
    )
    plt.tight_layout()
    # savefig does not create missing directories, so make sure the target exists.
    os.makedirs("figures", exist_ok=True)
    plt.savefig(
        "figures/activated_neurons_by_group_both_models.pdf",
        dpi=300,
        format="pdf",
        bbox_inches="tight",
    )

    # Figure 2: total active latents per model on the timesteps both share.
    plt.figure(figsize=(12, 8))

    common_timesteps = sorted(set(model1_timesteps).intersection(set(model2_timesteps)))

    model1_totals = []
    model2_totals = []
    for ts in common_timesteps:
        idx1 = model1_timesteps.index(ts)
        idx2 = model2_timesteps.index(ts)
        model1_totals.append(
            model1_base_counts[idx1]
            + model1_shared_counts[idx1]
            + model1_distilled_counts[idx1]
        )
        model2_totals.append(
            model2_base_counts[idx2]
            + model2_shared_counts[idx2]
            + model2_distilled_counts[idx2]
        )

    x = np.arange(len(common_timesteps))
    width = 0.35

    plt.bar(x - width / 2, model1_totals, width, label="Base Model", color=colors[0])
    plt.bar(
        x + width / 2, model2_totals, width, label="Distilled Model", color=colors[2]
    )

    # Value label above every bar.
    for i in range(len(x)):
        plt.text(
            x[i] - width / 2,
            model1_totals[i] + 0.5,
            f"{model1_totals[i]:.1f}",
            ha="center",
            va="bottom",
        )
        plt.text(
            x[i] + width / 2,
            model2_totals[i] + 0.5,
            f"{model2_totals[i]:.1f}",
            ha="center",
            va="bottom",
        )

    plt.xlabel("Timestep")
    plt.ylabel("Average Number of Activated Neurons")
    plt.title("Total Activated Neurons Comparison")
    plt.xticks(x, [str(ts) for ts in common_timesteps])
    plt.legend()
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()
    plt.savefig(
        "figures/total_activated_neurons_comparison.pdf",
        dpi=300,
        format="pdf",
        bbox_inches="tight",
    )

    return {"model1": model1_activated_counts, "model2": model2_activated_counts}


# Run the analysis with a small activation threshold (0.01) and keep the
# returned per-timestep accumulators for both models.
activated_neurons_data_both_models = plot_activated_neurons_by_group_both_models(
    threshold=0.01
)


In [None]:
# Count activated neurons from base, shared, and distilled groups per timestep for both models
def _accumulate_group_activations(acts, timesteps, model_idx, counts, threshold):
    """Add one batch of per-sample group activation counts to ``counts`` in place.

    Args:
        acts: Batch of activations for one model; the last dimension is the one
            fed to the crosscoder encoder.
        timesteps: Per-sample timesteps aligned with the first dimension of ``acts``.
        model_idx: Index of the encoder weights/bias to use (0 for model1, 1 for model2).
        counts: Dict keyed by timestep whose values hold "base", "shared",
            "distilled" and "samples" entries. Samples whose timestep is not a
            key of this dict are skipped.
        threshold: A latent counts as activated when it is strictly above this value.
    """
    group_masks = (
        ("base", only_base_features_mask),
        ("shared", shared_features_mask),
        ("distilled", only_turbo_features_mask),
    )

    for i in range(acts.shape[0]):
        ts = int(timesteps[i].item())
        if ts not in counts:
            continue

        # Slicing (instead of indexing) keeps the leading dimension; everything
        # except the feature dimension is then flattened into encoder rows.
        act = acts[i : i + 1].to(crosscoder.W_dec.device)
        act = act.reshape(-1, act.shape[-1])

        # Encoder forward pass for the selected model
        latents = einops.einsum(
            act,
            crosscoder.W_enc[model_idx],
            "batch d_model, d_model d_latent -> batch d_latent",
        )
        latents = torch.nn.functional.relu(latents + crosscoder.b_enc[model_idx])

        # 1.0 where a latent is active at a row, 0.0 otherwise
        activated_per_position = (latents > threshold).float().cpu()

        # Per group: number of active latents per row, averaged over the rows
        for group, mask in group_masks:
            counts[ts][group] += (
                activated_per_position[:, mask].sum(dim=1).mean()
            ).item()
        counts[ts]["samples"] += 1


def plot_activated_neurons_by_group_both_models(threshold=0.0):
    """Accumulate, per timestep, how many latents of each feature group activate.

    Despite its name this function does not draw anything: it only collects the
    counts and returns them; the name is kept so existing callers keep working.

    Every batch of ``paired_dataset`` is encoded twice: the "model1" activations
    with encoder 0 and the "model2" activations with encoder 1.

    NOTE(review): no ``norm_scaling_factors`` are applied here, unlike the cells
    that read the per-model datasets directly — this assumes ``paired_dataset``
    already yields normalised activations; confirm.

    Args:
        threshold: A latent counts as activated when its post-ReLU value is
            strictly above this value.

    Returns:
        ``{"model1": counts, "model2": counts}`` where each ``counts`` maps a
        timestep from ``timestep_map`` to a dict with the accumulated "base",
        "shared" and "distilled" values and the number of "samples" seen.
        Divide by "samples" to obtain per-sample averages.
    """
    print("Analyzing activated neurons by group and timestep for both models...")

    def _empty_counts():
        return {
            ts: {"base": 0, "shared": 0, "distilled": 0, "samples": 0}
            for ts in timestep_map
        }

    # Store counts for each group by timestep for both models
    model1_activated_counts = _empty_counts()
    model2_activated_counts = _empty_counts()

    # Create dataloader for batch processing
    dl = torch.utils.data.DataLoader(
        paired_dataset, batch_size=16, shuffle=False, num_workers=4
    )

    # Process dataset in batches
    with torch.no_grad():
        for batch in tqdm(dl, desc="Processing samples"):
            # Base model (model1) goes through encoder 0
            _accumulate_group_activations(
                batch["model1"],
                batch["model1_timestep"],
                0,
                model1_activated_counts,
                threshold,
            )
            # Distilled model (model2) goes through encoder 1
            _accumulate_group_activations(
                batch["model2"],
                batch["model2_timestep"],
                1,
                model2_activated_counts,
                threshold,
            )

    return {"model1": model1_activated_counts, "model2": model2_activated_counts}


# Collect the per-timestep activation counts for both models
activated_neurons_data_both_models = plot_activated_neurons_by_group_both_models(
    threshold=0.01
)

# Pull out the per-model dictionaries used by the aggregation below
model1_data = activated_neurons_data_both_models["model1"]
model2_data = activated_neurons_data_both_models["model2"]

# Timesteps that appear for both models, in ascending order
common_timesteps = sorted(set(model1_data) & set(model2_data))

# Per-timestep averages (accumulated value / number of samples) for each model
model1_timesteps, model2_timesteps = [], []
model1_base_counts, model2_base_counts = [], []
model1_shared_counts, model2_shared_counts = [], []
model1_distilled_counts, model2_distilled_counts = [], []

for ts in common_timesteps:
    model1_stats = model1_data[ts]
    model2_stats = model2_data[ts]

    # Skip timesteps for which either model contributed no samples
    if model1_stats["samples"] <= 0 or model2_stats["samples"] <= 0:
        continue

    # Base model averages
    model1_timesteps.append(ts)
    model1_base_counts.append(model1_stats["base"] / model1_stats["samples"])
    model1_shared_counts.append(model1_stats["shared"] / model1_stats["samples"])
    model1_distilled_counts.append(
        model1_stats["distilled"] / model1_stats["samples"]
    )

    # Distilled model averages
    model2_timesteps.append(ts)
    model2_base_counts.append(model2_stats["base"] / model2_stats["samples"])
    model2_shared_counts.append(model2_stats["shared"] / model2_stats["samples"])
    model2_distilled_counts.append(
        model2_stats["distilled"] / model2_stats["samples"]
    )


In [None]:
# Rebuild the per-timestep averages used by the combined plot
# Per-model dictionaries of accumulated counts keyed by timestep
model1_data = activated_neurons_data_both_models["model1"]
model2_data = activated_neurons_data_both_models["model2"]

# Timesteps present for both models, restricted to 249..999 and ordered from
# the highest timestep (999) down to the lowest (249)
common_timesteps = sorted(
    (
        ts
        for ts in set(model1_data.keys()) & set(model2_data.keys())
        if 249 <= ts <= 999
    ),
    reverse=True,
)

# Per-timestep averages (accumulated value / number of samples) for each model
model1_timesteps, model2_timesteps = [], []
model1_base_counts, model2_base_counts = [], []
model1_shared_counts, model2_shared_counts = [], []
model1_distilled_counts, model2_distilled_counts = [], []

for ts in common_timesteps:
    model1_stats = model1_data[ts]
    model2_stats = model2_data[ts]

    # Only keep timesteps for which both models contributed samples
    if model1_stats["samples"] <= 0 or model2_stats["samples"] <= 0:
        continue

    # Base model averages
    model1_timesteps.append(ts)
    model1_base_counts.append(model1_stats["base"] / model1_stats["samples"])
    model1_shared_counts.append(model1_stats["shared"] / model1_stats["samples"])
    model1_distilled_counts.append(
        model1_stats["distilled"] / model1_stats["samples"]
    )

    # Distilled model averages
    model2_timesteps.append(ts)
    model2_base_counts.append(model2_stats["base"] / model2_stats["samples"])
    model2_shared_counts.append(model2_stats["shared"] / model2_stats["samples"])
    model2_distilled_counts.append(
        model2_stats["distilled"] / model2_stats["samples"]
    )

# Create plot with grouped, stacked bars for each timestep
# The count lists only contain timesteps that passed the "samples > 0" filter,
# so bar positions and tick labels must come from the filtered timestep list
# rather than from common_timesteps (model1_timesteps and model2_timesteps are
# filled together and therefore identical).
plotted_timesteps = model1_timesteps

plt.figure(figsize=(14, 6))
plt.ylim(0, 60)
# Set positions for grouped bars
bar_width = 0.35
x_positions = np.arange(len(plotted_timesteps))

# Plot model1 (Base model) bars with hatching: base-specific, then shared,
# then distilled-specific stacked on top of each other
plt.bar(
    x_positions - bar_width / 2,
    model1_base_counts,
    bar_width,
    color=colors[1],
    hatch="//",
    edgecolor="black",
    alpha=0.4,
)
plt.bar(
    x_positions - bar_width / 2,
    model1_shared_counts,
    bar_width,
    bottom=model1_base_counts,
    color=colors[2],
    hatch="//",
    edgecolor="black",
    alpha=0.4,
)
# Calculate bottoms for the third bar segment
model1_bottoms = [b + s for b, s in zip(model1_base_counts, model1_shared_counts)]
plt.bar(
    x_positions - bar_width / 2,
    model1_distilled_counts,
    bar_width,
    bottom=model1_bottoms,
    color=colors[-2],
    hatch="//",
    edgecolor="black",
    alpha=0.4,
)

# Plot model2 (Distilled model) bars, same stacking order but without hatching
plt.bar(
    x_positions + bar_width / 2,
    model2_base_counts,
    bar_width,
    color=colors[1],
    alpha=0.4,
)
plt.bar(
    x_positions + bar_width / 2,
    model2_shared_counts,
    bar_width,
    bottom=model2_base_counts,
    color=colors[2],
    alpha=0.4,
)
# Calculate bottoms for the third bar segment
model2_bottoms = [b + s for b, s in zip(model2_base_counts, model2_shared_counts)]
plt.bar(
    x_positions + bar_width / 2,
    model2_distilled_counts,
    bar_width,
    bottom=model2_bottoms,
    color=colors[-2],
    alpha=0.4,
)

# Add model labels above each bar
for i in range(len(x_positions)):
    # Calculate the total height for each model's bar
    model1_total = (
        model1_base_counts[i] + model1_shared_counts[i] + model1_distilled_counts[i]
    )
    model2_total = (
        model2_base_counts[i] + model2_shared_counts[i] + model2_distilled_counts[i]
    )

    # Add a small offset to position the text right above each bar
    offset = 0.8

    plt.text(
        x_positions[i] - bar_width / 2,
        model1_total + offset,
        "Base",
        ha="center",
        va="bottom",
        fontsize=SMALL_SIZE,
    )

    plt.text(
        x_positions[i] + bar_width / 2,
        model2_total + offset,
        "Distilled",
        ha="center",
        va="bottom",
        fontsize=SMALL_SIZE,
    )

# Set chart labels and properties
plt.ylabel("Average L0 (Active Neurons)")
plt.title("Average Number of Activated Neurons from each Feature Group")
plt.xticks(x_positions, [f"Timestep {int(ts)}" for ts in plotted_timesteps])

# Create a simplified legend with only the feature types (the model is
# identified by the text label above each bar)
custom_handles = [
    plt.Rectangle((0, 0), 1, 1, color=colors[1], alpha=0.4),
    plt.Rectangle((0, 0), 1, 1, color=colors[2], alpha=0.4),
    plt.Rectangle((0, 0), 1, 1, color=colors[-2], alpha=0.4),
]
custom_labels = ["Base-specific", "Shared", "Distilled-specific"]
plt.legend(custom_handles, custom_labels, loc="upper left")

plt.grid(axis="y", alpha=0.25, linestyle="--")
plt.tight_layout()

# savefig does not create missing directories
os.makedirs("figures", exist_ok=True)
plt.savefig(
    "figures/activated_neurons_both_models_combined.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)


In [None]:
# Find neurons that activate most frequently on a given timestep (999 by default)
def find_most_active_neurons_for_timestep(target_timestep=999, top_k=30, threshold=0.0):
    """Rank latents by how often they fire on samples of one timestep.

    Iterates the module-level ``dl`` dataloader, keeps the samples whose
    "timestep" equals ``target_timestep``, flattens their "activations" into
    rows, scales them by ``paired_dataset.norm_scaling_factors[0]`` and encodes
    them with encoder 0 of the crosscoder. For every latent the fraction of
    rows with an activation strictly above ``threshold`` is computed.

    NOTE(review): the encoder/scaling index is 0, while earlier inline comments
    labelled it "Model 2" — confirm that ``dl`` iterates the dataset that
    belongs to encoder 0.

    Args:
        target_timestep: Timestep whose samples are analysed.
        top_k: Number of latents to return; clamped to the number of latents.
        threshold: A latent counts as activated when it is strictly above this value.

    Returns:
        ``(top_indices, top_values, neuron_types)``: indices of the most
        frequently active latents, their activation frequencies in [0, 1], and
        a list with "Base", "Turbo", "Shared" or "Unknown" per latent.
        ``(None, None, None)`` when no sample of ``target_timestep`` was found.

    Raises:
        ValueError: If ``target_timestep`` is not contained in ``timestep_map``.
    """
    print(f"Finding neurons most active on timestep {target_timestep}...")

    # Reject timesteps the analysis does not know about
    if not any(ts == target_timestep for ts in timestep_map):
        raise ValueError(f"Timestep {target_timestep} not found in timestep_map")

    # Initialize counters for neuron activation
    neuron_activation_count = torch.zeros(crosscoder.num_latents, device="cpu")
    total_samples = 0

    # Process dataset in batches
    with torch.no_grad():
        for batch in tqdm(dl, desc=f"Processing timestep {target_timestep}"):
            acts = batch["activations"]
            timesteps = batch["timestep"]

            # Only process samples for our target timestep
            mask = timesteps == target_timestep
            if not mask.any():
                continue

            timestep_acts = acts[mask].to(crosscoder.W_dec.device)

            # Flatten (batch, sample_size) into rows and normalize
            timestep_acts = einops.rearrange(
                timestep_acts,
                "batch sample_size d_model -> (batch sample_size) d_model",
            )
            timestep_acts = timestep_acts * paired_dataset.norm_scaling_factors[0]

            # Forward through encoder 0 to get latent activations
            latents = einops.einsum(
                timestep_acts,
                crosscoder.W_enc[0],
                "batch d_model, d_model d_latent -> batch d_latent",
            )
            latents = torch.nn.functional.relu(latents + crosscoder.b_enc[0])

            # Count activated neurons (activation > threshold)
            activated = (latents > threshold).float()
            neuron_activation_count += activated.sum(dim=0).cpu()
            total_samples += latents.shape[0]

    # Without any matching sample there is nothing to rank; return one None per
    # element of the regular result so callers can always unpack three values
    if total_samples == 0:
        print(f"Warning: No samples found for timestep {target_timestep}")
        return None, None, None

    # Fraction of rows on which each neuron was active
    activation_frequency = neuron_activation_count / total_samples

    # Get top-k most frequently activated neurons (topk raises if k exceeds the
    # number of latents, so clamp it)
    k = min(top_k, activation_frequency.numel())
    top_values, top_indices = torch.topk(activation_frequency, k)

    # Determine neuron types
    neuron_types = []
    for idx in top_indices:
        if only_base_features_mask[idx]:
            neuron_types.append("Base")
        elif only_turbo_features_mask[idx]:
            neuron_types.append("Turbo")
        elif shared_features_mask[idx]:
            neuron_types.append("Shared")
        else:
            neuron_types.append("Unknown")

    return top_indices, top_values, neuron_types


# Find the top 50 most active neurons for timestep 999
top_indices, top_frequencies, neuron_types = find_most_active_neurons_for_timestep(
    target_timestep=999, top_k=50, threshold=0.0
)

# Color used for each neuron type in the charts below; "Unknown" is reported
# for latents that fall in none of the three feature masks
neuron_type_colors = {
    "Base": colors[0],
    "Turbo": colors[2],
    "Shared": colors[1],
    "Unknown": colors[3],
}

# Display results
if top_indices is not None:
    # Create a table with results
    plt.figure(figsize=(12, 8))
    table_data = []
    for i in range(len(top_indices)):
        neuron_idx = top_indices[i].item()
        freq = top_frequencies[i].item()
        neuron_type = neuron_types[i]
        table_data.append(
            [neuron_idx, f"{freq:.4f}", f"{freq * 100:.2f}%", neuron_type]
        )

    # Create table
    plt.axis("off")
    table = plt.table(
        cellText=table_data,
        colLabels=["Neuron Index", "Activation Frequency", "Percentage", "Type"],
        loc="center",
        cellLoc="center",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.5)

    plt.title(
        f"Top {len(top_indices)} Neurons Most Frequently Active at Timestep 999",
        fontsize=16,
        pad=20,
    )
    plt.tight_layout()

    # Also create a bar chart
    plt.figure(figsize=(14, 8))
    bars = plt.bar(range(len(top_indices)), top_frequencies.numpy() * 100)

    # Color bars by neuron type
    for i, bar in enumerate(bars):
        if neuron_types[i] in neuron_type_colors:
            bar.set_color(neuron_type_colors[neuron_types[i]])

    # Add neuron indices as x-tick labels
    plt.xticks(
        range(len(top_indices)), [f"{idx.item()}" for idx in top_indices], rotation=45
    )

    # Add neuron type annotations above each bar
    for i, bar in enumerate(bars):
        plt.text(
            bar.get_x() + bar.get_width() / 2,
            bar.get_height() + 1,
            neuron_types[i],
            ha="center",
            va="bottom",
            rotation=90,
            fontsize=8,
        )

    plt.title("Top Neurons by Activation Frequency at Timestep 999", fontsize=16)
    plt.xlabel("Neuron Index", fontsize=14)
    plt.ylabel("Activation Frequency (%)", fontsize=14)
    plt.grid(axis="y", alpha=0.3)
    plt.tight_layout()

    # Create a pie chart showing the distribution of neuron types in the top active neurons.
    # The three regular types are always listed; any other reported type (e.g.
    # "Unknown") is added on first occurrence instead of raising a KeyError.
    type_counts = {"Base": 0, "Turbo": 0, "Shared": 0}
    for t in neuron_types:
        type_counts[t] = type_counts.get(t, 0) + 1

    plt.figure(figsize=(7, 7))
    plt.pie(
        list(type_counts.values()),
        labels=list(type_counts.keys()),
        autopct="%1.1f%%",
        startangle=90,
        colors=[neuron_type_colors.get(t, colors[3]) for t in type_counts],
    )
    plt.title(
        f"Distribution of Neuron Types Among Top {len(top_indices)} Active at Timestep 999",
        fontsize=14,
    )
    plt.tight_layout()


In [None]:
# Find neurons that activate most frequently on a specific timestep for both models
def _activation_frequency_at_timestep(
    dataset, model_idx, target_timestep, threshold, desc
):
    """Return the per-latent activation frequency of one model at one timestep.

    Args:
        dataset: Dataset whose batches provide "activations" and "timestep".
        model_idx: Index of the encoder weights/bias and of the entry in
            ``paired_dataset.norm_scaling_factors`` to use.
        target_timestep: Only samples with this timestep are counted.
        threshold: A latent counts as activated when it is strictly above this value.
        desc: Progress bar description.

    Returns:
        A CPU tensor with one frequency in [0, 1] per latent, or ``None`` when
        the dataset holds no sample for ``target_timestep``.
    """
    activation_count = torch.zeros(crosscoder.num_latents, device="cpu")
    total_rows = 0

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=16, shuffle=False, num_workers=4
    )

    with torch.no_grad():
        for batch in tqdm(loader, desc=desc):
            # Only process samples for our target timestep
            mask = batch["timestep"] == target_timestep
            if not mask.any():
                continue

            timestep_acts = batch["activations"][mask].to(crosscoder.W_dec.device)

            # Flatten (batch, sample_size) into rows and normalize
            timestep_acts = einops.rearrange(
                timestep_acts,
                "batch sample_size d_model -> (batch sample_size) d_model",
            )
            timestep_acts = (
                timestep_acts * paired_dataset.norm_scaling_factors[model_idx]
            )

            # Forward through encoder to get latent activations
            latents = einops.einsum(
                timestep_acts,
                crosscoder.W_enc[model_idx],
                "batch d_model, d_model d_latent -> batch d_latent",
            )
            latents = torch.nn.functional.relu(latents + crosscoder.b_enc[model_idx])

            # Count activated neurons (activation > threshold)
            activated = (latents > threshold).float()
            activation_count += activated.sum(dim=0).cpu()
            total_rows += latents.shape[0]

    if total_rows == 0:
        return None
    return activation_count / total_rows


def compare_most_active_neurons(target_timestep=999, top_k=20, threshold=0.0):
    """Compare how often the base model's most active latents fire in both models.

    The activation frequency of every latent is measured on ``model1_dataset``
    (encoder 0, base model) and on ``model2_dataset`` (encoder 1, distilled
    model) for samples of ``target_timestep``. The ``top_k`` latents are chosen
    by their frequency in the base model.

    Args:
        target_timestep: Timestep whose samples are analysed.
        top_k: Number of latents to return; clamped to the number of latents.
        threshold: A latent counts as activated when it is strictly above this value.

    Returns:
        ``(top_indices, top_values, distilled_values, neuron_types)``: the
        selected latent indices, their frequencies in the base model, the
        frequencies of the same latents in the distilled model, and a list with
        "Base", "Distilled", "Shared" or "Unknown" per latent.
        ``(None, None, None, None)`` when either dataset has no sample for
        ``target_timestep``.

    Raises:
        ValueError: If ``target_timestep`` is not contained in ``timestep_map``.
    """
    print(f"Finding neurons most active on timestep {target_timestep}...")

    # Reject timesteps the analysis does not know about
    if not any(ts == target_timestep for ts in timestep_map):
        raise ValueError(f"Timestep {target_timestep} not found in timestep_map")

    # Base model: encoder 0 on model1_dataset
    base_activation_frequency = _activation_frequency_at_timestep(
        model1_dataset,
        0,
        target_timestep,
        threshold,
        f"Processing base model - timestep {target_timestep}",
    )
    if base_activation_frequency is None:
        print(f"Warning: No base model samples found for timestep {target_timestep}")
        # One None per element of the regular result so callers can always
        # unpack four values
        return None, None, None, None

    # Distilled model: encoder 1 on model2_dataset
    distilled_activation_frequency = _activation_frequency_at_timestep(
        model2_dataset,
        1,
        target_timestep,
        threshold,
        f"Processing distilled model - timestep {target_timestep}",
    )
    if distilled_activation_frequency is None:
        print(
            f"Warning: No distilled model samples found for timestep {target_timestep}"
        )
        return None, None, None, None

    # Get top-k most frequently activated neurons in the base model (topk raises
    # if k exceeds the number of latents, so clamp it)
    k = min(top_k, base_activation_frequency.numel())
    top_values, top_indices = torch.topk(base_activation_frequency, k)

    # Get activation frequencies for those same neurons in the distilled model
    distilled_values = distilled_activation_frequency[top_indices]

    # Determine neuron types
    neuron_types = []
    for idx in top_indices:
        if only_base_features_mask[idx]:
            neuron_types.append("Base")
        elif only_turbo_features_mask[idx]:
            neuron_types.append("Distilled")
        elif shared_features_mask[idx]:
            neuron_types.append("Shared")
        else:
            neuron_types.append("Unknown")

    return top_indices, top_values, distilled_values, neuron_types


# Rank the 20 neurons most active at timestep 999 in the base model and fetch
# the activation frequency of the same neurons in the distilled model
top_indices, base_frequencies, distilled_frequencies, neuron_types = (
    compare_most_active_neurons(target_timestep=999, top_k=20, threshold=0.0)
)

# Grouped bar chart: one pair of bars (base / distilled) per selected neuron
if top_indices is not None:
    fig, ax = plt.subplots(figsize=(14, 8))

    # Bar geometry
    bar_width = 0.35
    x_positions = np.arange(len(top_indices))

    # Base model on the left of each tick, distilled model on the right;
    # frequencies are shown as percentages
    bars1 = ax.bar(
        x_positions - bar_width / 2,
        base_frequencies.numpy() * 100,
        bar_width,
        label="Base Model",
        color=colors[0],
    )
    bars2 = ax.bar(
        x_positions + bar_width / 2,
        distilled_frequencies.numpy() * 100,
        bar_width,
        label="Distilled Model",
        color=colors[2],
    )

    # Axis labels, title and neuron indices as tick labels
    ax.set_xlabel("Neuron Index")
    ax.set_ylabel("Activation Frequency (%)")
    ax.set_title(
        f"Activation Comparison of Top {len(top_indices)} Most Active Neurons at Timestep 999"
    )
    ax.set_xticks(x_positions)
    ax.set_xticklabels([f"{idx.item()}" for idx in top_indices], rotation=45)

    # Write the neuron type slightly above the taller bar of each pair
    for i, neuron_type in enumerate(neuron_types):
        ax.text(
            x_positions[i],
            max(base_frequencies[i], distilled_frequencies[i]) * 100 + 2,
            neuron_type,
            ha="center",
            rotation=45,
            fontsize=9,
            color="black",
        )

    ax.legend()
    ax.grid(axis="y", alpha=0.3)
    fig.tight_layout()

    # Write the figure to disk as a PDF
    fig.savefig(
        "figures/neuron_activation_comparison.pdf",
        dpi=300,
        format="pdf",
        bbox_inches="tight",
    )


In [None]:
# Display the shapes of the crosscoder's decoder weights, encoder weights and
# decoder bias as the cell output
(
    crosscoder.W_dec.shape,
    crosscoder.W_enc.shape,
    crosscoder.b_dec.shape,
)

In [None]:
# Function to analyze activations for a single sample
def analyze_sample_activations(sample_idx=0, threshold=0.1, top_k=20):
    """
    Analyze and display neurons activated in both models for a single sample

    The sample's "model1" activations are encoded with encoder 0 (base model)
    and its "model2" activations with encoder 1 (distilled model). Latents are
    averaged over the first dimension of the encoded sample before the
    threshold is applied. Statistics are printed and a two-panel figure is
    drawn: the sample image on the left and the top activating neurons of both
    models, ordered by rank, on the right.

    In the printed statistics the base-specific count is taken from the neurons
    active in the base model and the distilled-specific count from those active
    in the distilled model; only the shared count is restricted to neurons
    active in both models.

    Args:
        sample_idx: Index of the sample in the paired dataset
        threshold: Activation threshold to consider a neuron as "active"
        top_k: Number of top activated neurons to display (clamped to the
            number of latents)

    Returns:
        Tuple ``(latents_model1, latents_model2)`` with the post-ReLU latent
        activations of the base and distilled model before averaging.
    """
    # Get the sample from paired dataset
    sample = paired_dataset[sample_idx]

    # Extract activations for both models
    acts_model1 = sample["model1"].to(crosscoder.W_dec.device)  # Base model
    acts_model2 = sample["model2"].to(crosscoder.W_dec.device)  # Distilled model
    timestep1 = sample["model1_timestep"].item()
    timestep2 = sample["model2_timestep"].item()

    # Get latent activations for each model
    # Base model
    latents_model1 = einops.einsum(
        acts_model1,
        crosscoder.W_enc[0],
        "sample_size d_model, d_model d_latent -> sample_size d_latent",
    )
    latents_model1 = torch.nn.functional.relu(latents_model1 + crosscoder.b_enc[0])

    # Distilled model
    latents_model2 = einops.einsum(
        acts_model2,
        crosscoder.W_enc[1],
        "sample_size d_model, d_model d_latent -> sample_size d_latent",
    )
    latents_model2 = torch.nn.functional.relu(latents_model2 + crosscoder.b_enc[1])

    # Calculate mean activations across spatial positions
    avg_latents_model1 = latents_model1.mean(dim=0)
    avg_latents_model2 = latents_model2.mean(dim=0)

    # Identify active neurons in each model
    active_neurons_model1 = (avg_latents_model1 > threshold).cpu()
    active_neurons_model2 = (avg_latents_model2 > threshold).cpu()

    # Find neurons active in both models
    active_in_both = active_neurons_model1 & active_neurons_model2

    # Count active neurons
    num_active_model1 = active_neurons_model1.sum().item()
    num_active_model2 = active_neurons_model2.sum().item()
    num_active_both = active_in_both.sum().item()

    # Check for which feature types the neurons belong to
    active_base_specific = active_neurons_model1 & only_base_features_mask
    active_distilled_specific = active_neurons_model2 & only_turbo_features_mask
    active_shared = active_in_both & shared_features_mask

    # Print statistics
    print(f"Sample Index: {sample_idx}, File: {sample['file_name']}")
    print(f"Timesteps: Base = {int(timestep1)}, Distilled = {int(timestep2)}")
    print(f"Active neurons in Base model: {num_active_model1}")
    print(f"Active neurons in Distilled model: {num_active_model2}")
    print(f"Neurons active in both models: {num_active_both}")
    print(f"  - Base-specific neurons active: {active_base_specific.sum().item()}")
    print(
        f"  - Distilled-specific neurons active: {active_distilled_specific.sum().item()}"
    )
    print(f"  - Shared neurons active: {active_shared.sum().item()}")

    # Find top-k activating neurons for each model (topk raises if k exceeds the
    # number of latents, so clamp it; the bar positions below use the same k)
    top_k = min(top_k, avg_latents_model1.numel())
    top_neurons_model1 = torch.topk(avg_latents_model1, top_k)
    top_neurons_model2 = torch.topk(avg_latents_model2, top_k)

    # Create visualization
    fig, axes = plt.subplots(1, 2, figsize=(14, 6))

    # Display image; the context manager closes the underlying file once the
    # resized copy has been created
    with Image.open(os.path.join(dataset_path, sample["file_name"])) as raw_img:
        img = raw_img.resize((512, 512))
    axes[0].imshow(img)
    axes[0].set_title(f"Sample Image\n{os.path.basename(sample['file_name'])}")
    axes[0].axis("off")

    # Plot the top activating neurons
    bar_width = 0.35
    x_positions = np.arange(top_k)

    axes[1].bar(
        x_positions - bar_width / 2,
        top_neurons_model1.values.cpu().numpy(),
        bar_width,
        label="Base Model",
        color=colors[0],
        alpha=0.7,
    )

    axes[1].bar(
        x_positions + bar_width / 2,
        top_neurons_model2.values.cpu().numpy(),
        bar_width,
        label="Distilled Model",
        color=colors[2],
        alpha=0.7,
    )

    # Annotate every bar with the index of the neuron it represents; the two
    # bars at one rank generally belong to different neurons
    for i in range(top_k):
        idx1 = top_neurons_model1.indices[i].item()
        idx2 = top_neurons_model2.indices[i].item()

        axes[1].text(
            x_positions[i] - bar_width / 2,
            top_neurons_model1.values[i].item() + 0.01,
            f"{idx1}",
            ha="center",
            va="bottom",
            rotation=45,
        )

        axes[1].text(
            x_positions[i] + bar_width / 2,
            top_neurons_model2.values[i].item() + 0.01,
            f"{idx2}",
            ha="center",
            va="bottom",
            rotation=45,
        )

    axes[1].set_xticks(x_positions)
    axes[1].set_xticklabels([f"{i + 1}" for i in range(top_k)])
    axes[1].set_xlabel("Rank")
    axes[1].set_ylabel("Activation Value")
    axes[1].set_title("Top Activating Neurons")
    axes[1].legend()
    axes[1].grid(alpha=0.3)

    plt.tight_layout()

    return latents_model1, latents_model2


# Call the function for sample index 6 (threshold 0.0, top 10 neurons)
latents1, latents2 = analyze_sample_activations(sample_idx=6, threshold=0.0, top_k=10)


In [None]:
def visualize_top_activations(sample_idx, dataset_path, threshold=0.01, top_k=5):
    """
    Visualize the top activating crosscoder latents of both models on one sample.

    The sample's activations are encoded with each model's encoder weights, the
    latents are averaged over the flattened positions, and the `top_k` latents
    per model are drawn as heatmaps blended onto the source image (row 0: base
    model, row 1: distilled model). The figure is saved to
    ``figures/top_activations_sample_{sample_idx}.pdf``.

    Relies on the notebook globals `paired_dataset` and `crosscoder`.

    Args:
        sample_idx: Index of the sample in `paired_dataset`.
        dataset_path: Directory that the sample's "file_name" is relative to.
        threshold: A latent counts as active when its mean activation exceeds this.
        top_k: Number of top latents to display per model.

    Returns:
        dict with keys "base_model" and "distilled_model"; each maps to a dict
        with "top_neurons" (numpy array of latent indices), "activations"
        (numpy array of their mean activations) and "total_active" (int number
        of latents whose mean activation exceeds `threshold`).
    """
    sample = paired_dataset[sample_idx]
    device = crosscoder.W_dec.device

    # Load the source image at a fixed 512x512 RGB size so heatmaps can be blended onto it
    img_path = os.path.join(dataset_path, sample["file_name"])
    original_img = Image.open(img_path)
    original_img = original_img.resize((512, 512))
    original_img = original_img.convert("RGB")

    def encode(activations, model_idx):
        """Encode one model's activations into per-position latent activations."""
        flat = activations.to(device)
        flat = flat.reshape(-1, flat.shape[-1])
        with torch.no_grad():
            pre_acts = einops.einsum(
                flat,
                crosscoder.W_enc[model_idx],
                "batch d_model, d_model d_latent -> batch d_latent",
            )
            # NOTE(review): assumes b_enc has a leading per-model axis - confirm against CrossCoder
            return torch.nn.functional.relu(pre_acts + crosscoder.b_enc[model_idx])

    def plot_activation_overlay(ax, neuron_idx, latents):
        """Blend one latent's spatial activation map onto the image and draw it on `ax`."""
        # Assumes the flattened positions form a square grid; the reshape raises otherwise
        spatial_dim = int(np.sqrt(latents.shape[0]))
        # Cast to float32 so the 1e-8 epsilon below is not rounded away in half precision
        spatial_activations = (
            latents[:, neuron_idx]
            .reshape(spatial_dim, spatial_dim)
            .detach()
            .float()
            .cpu()
            .numpy()
        )

        # Min-max normalise to [0, 1]; the epsilon guards a constant map
        activation_map = (spatial_activations - spatial_activations.min()) / (
            spatial_activations.max() - spatial_activations.min() + 1e-8
        )

        # Upscale each grid cell to a block of pixels to match the 512px image
        patch_size = 512 // activation_map.shape[0]
        activation_map = np.kron(activation_map, np.ones((patch_size, patch_size)))

        # Jet-coloured heatmap (alpha channel dropped), blended at 40% opacity
        heatmap = Image.fromarray(np.uint8(plt.cm.jet(activation_map)[..., :3] * 255))
        blended_img = Image.blend(original_img, heatmap, alpha=0.4)

        ax.imshow(np.array(blended_img))
        ax.set_xticks([])
        ax.set_yticks([])

    # squeeze=False keeps `axes` two-dimensional even when top_k == 1
    fig, axes = plt.subplots(2, top_k, figsize=(top_k * 3, 6), squeeze=False)

    rows = (
        ("base_model", "model1", "Base Model"),
        ("distilled_model", "model2", "Distilled Model"),
    )
    stats = {}
    for row, (stats_key, sample_key, row_label) in enumerate(rows):
        latents = encode(sample[sample_key], row)

        # Average activation per latent across spatial positions
        avg_activations = latents.mean(dim=0)
        top_neurons = (
            torch.argsort(avg_activations, descending=True)[:top_k].cpu().numpy()
        )

        for col, neuron_idx in enumerate(top_neurons):
            ax = axes[row, col]
            plot_activation_overlay(ax, neuron_idx, latents)
            # Latent id goes above the top row and below the bottom row
            if row == 0:
                ax.set_title(f"#{neuron_idx}")
            else:
                ax.set_xlabel(f"#{neuron_idx}")

        axes[row, 0].set_ylabel(row_label)

        stats[stats_key] = {
            "top_neurons": top_neurons,
            "activations": avg_activations[top_neurons].cpu().numpy(),
            # Count directly: len() of a squeezed nonzero() result fails when
            # exactly one latent is active (the array becomes 0-d)
            "total_active": int((avg_activations > threshold).sum().item()),
        }

    plt.tight_layout()
    # Headroom originally reserved for a figure title; kept so the saved layout is unchanged
    plt.subplots_adjust(top=0.9)

    # Save the figure, creating the output directory on first use
    os.makedirs("figures", exist_ok=True)
    plt.savefig(
        f"figures/top_activations_sample_{sample_idx}.pdf",
        dpi=300,
        format="pdf",
        bbox_inches="tight",
    )

    return stats


# Example usage: render the top activations for one sample, then print a
# per-model summary of the returned statistics.
dataset_path = "/data/SHARE/datasets"  # Replace with actual dataset path
activation_stats = visualize_top_activations(
    sample_idx=15, dataset_path=dataset_path, threshold=0.01, top_k=5
)

# One pass per model; the heading text is the only thing that differs
for heading, model_key in (
    ("Base Model activation summary:", "base_model"),
    ("\nDistilled Model activation summary:", "distilled_model"),
):
    model_stats = activation_stats[model_key]
    print(heading)
    ranked = zip(model_stats["top_neurons"], model_stats["activations"])
    for rank, (neuron, act) in enumerate(ranked, start=1):
        print(f"  #{rank}: Neuron {neuron} - Activation: {act:.4f}")
    print(f"Total active neurons: {model_stats['total_active']}")


In [None]:
def visualize_top_activations(sample_idx, dataset_path, threshold=0.01, top_k=5):
    """
    Visualize the top activating crosscoder latents of both models on one sample.

    Variant that also shows each latent's mean activation under its id and puts
    the sample's timesteps in the figure title. The sample's activations are
    encoded with each model's encoder weights, the latents are averaged over the
    flattened positions, and the `top_k` latents per model are drawn as heatmaps
    blended onto the source image (row 0: base model, row 1: distilled model).
    The figure is saved to ``figures/top_activations_sample_{sample_idx}.pdf``.

    Relies on the notebook globals `paired_dataset` and `crosscoder`.

    Args:
        sample_idx: Index of the sample in `paired_dataset`.
        dataset_path: Directory that the sample's "file_name" is relative to.
        threshold: A latent counts as active when its mean activation exceeds this.
        top_k: Number of top latents to display per model.

    Returns:
        dict with keys "base_model" and "distilled_model"; each maps to a dict
        with "top_neurons" (numpy array of latent indices), "activations"
        (numpy array of their mean activations) and "total_active" (int number
        of latents whose mean activation exceeds `threshold`).
    """
    sample = paired_dataset[sample_idx]
    device = crosscoder.W_dec.device

    # Timesteps are only used in the figure title
    model1_ts = sample["model1_timestep"].item()
    model2_ts = sample["model2_timestep"].item()

    # Load the source image at a fixed 512x512 RGB size so heatmaps can be blended onto it
    img_path = os.path.join(dataset_path, sample["file_name"])
    original_img = Image.open(img_path)
    original_img = original_img.resize((512, 512))
    original_img = original_img.convert("RGB")

    def encode(activations, model_idx):
        """Encode one model's activations into per-position latent activations."""
        flat = activations.to(device)
        flat = flat.reshape(-1, flat.shape[-1])
        with torch.no_grad():
            pre_acts = einops.einsum(
                flat,
                crosscoder.W_enc[model_idx],
                "batch d_model, d_model d_latent -> batch d_latent",
            )
            # NOTE(review): assumes b_enc has a leading per-model axis - confirm against CrossCoder
            return torch.nn.functional.relu(pre_acts + crosscoder.b_enc[model_idx])

    def plot_activation_overlay(ax, neuron_idx, latents):
        """Blend one latent's spatial activation map onto the image and draw it on `ax`."""
        # Assumes the flattened positions form a square grid; the reshape raises otherwise
        spatial_dim = int(np.sqrt(latents.shape[0]))
        # Cast to float32 so the 1e-8 epsilon below is not rounded away in half precision
        spatial_activations = (
            latents[:, neuron_idx]
            .reshape(spatial_dim, spatial_dim)
            .detach()
            .float()
            .cpu()
            .numpy()
        )

        # Min-max normalise to [0, 1]; the epsilon guards a constant map
        activation_map = (spatial_activations - spatial_activations.min()) / (
            spatial_activations.max() - spatial_activations.min() + 1e-8
        )

        # Upscale each grid cell to a block of pixels to match the 512px image
        patch_size = 512 // activation_map.shape[0]
        activation_map = np.kron(activation_map, np.ones((patch_size, patch_size)))

        # Jet-coloured heatmap (alpha channel dropped), blended at 40% opacity
        heatmap = Image.fromarray(np.uint8(plt.cm.jet(activation_map)[..., :3] * 255))
        blended_img = Image.blend(original_img, heatmap, alpha=0.4)

        ax.imshow(np.array(blended_img))
        ax.set_xticks([])
        ax.set_yticks([])

    # squeeze=False keeps `axes` two-dimensional even when top_k == 1
    fig, axes = plt.subplots(2, top_k, figsize=(top_k * 3, 6), squeeze=False)

    fig.suptitle(
        f"Top {top_k} Activated Neurons for Sample {sample_idx} (Timestep Base: {model1_ts}, Distilled: {model2_ts})",
        fontsize=14,
    )

    rows = (
        ("base_model", "model1", "Base Model"),
        ("distilled_model", "model2", "Distilled Model"),
    )
    stats = {}
    for row, (stats_key, sample_key, row_label) in enumerate(rows):
        latents = encode(sample[sample_key], row)

        # Average activation per latent across spatial positions
        avg_activations = latents.mean(dim=0)
        top_neurons = (
            torch.argsort(avg_activations, descending=True)[:top_k].cpu().numpy()
        )

        for col, neuron_idx in enumerate(top_neurons):
            ax = axes[row, col]
            activation_value = avg_activations[neuron_idx].item()
            plot_activation_overlay(ax, neuron_idx, latents)
            # Latent id and mean activation go above the top row and below the bottom row
            if row == 0:
                ax.set_title(f"#{neuron_idx}\n{activation_value:.2f}", fontsize=10)
            else:
                ax.set_xlabel(f"#{neuron_idx}\n{activation_value:.2f}", fontsize=10)

        axes[row, 0].set_ylabel(row_label, fontsize=12)

        stats[stats_key] = {
            "top_neurons": top_neurons,
            "activations": avg_activations[top_neurons].cpu().numpy(),
            # Count directly: len() of a squeezed nonzero() result fails when
            # exactly one latent is active (the array becomes 0-d)
            "total_active": int((avg_activations > threshold).sum().item()),
        }

    plt.tight_layout()
    plt.subplots_adjust(top=0.9)  # Adjust for suptitle

    # Save the figure, creating the output directory on first use
    os.makedirs("figures", exist_ok=True)
    plt.savefig(
        f"figures/top_activations_sample_{sample_idx}.pdf",
        dpi=300,
        format="pdf",
        bbox_inches="tight",
    )

    return stats


# Example usage
# Visualize top activations for a specific sample.
# `dataset_path` is the image root already configured earlier in this notebook;
# it is reused here rather than being overwritten with a placeholder path that
# would make Image.open fail on a fresh Run All.
activation_stats = visualize_top_activations(
    sample_idx=42, dataset_path=dataset_path, threshold=0.01, top_k=5
)

# Print summary statistics
print("Base Model activation summary:")
for i, (neuron, act) in enumerate(
    zip(
        activation_stats["base_model"]["top_neurons"],
        activation_stats["base_model"]["activations"],
    )
):
    print(f"  #{i + 1}: Neuron {neuron} - Activation: {act:.4f}")
print(f"Total active neurons: {activation_stats['base_model']['total_active']}")

print("\nDistilled Model activation summary:")
for i, (neuron, act) in enumerate(
    zip(
        activation_stats["distilled_model"]["top_neurons"],
        activation_stats["distilled_model"]["activations"],
    )
):
    print(f"  #{i + 1}: Neuron {neuron} - Activation: {act:.4f}")
print(f"Total active neurons: {activation_stats['distilled_model']['total_active']}")

In [None]:
# Reconstruction error of the distilled model on its own activations, per timestep
def calculate_distilled_error_by_timestep():
    """
    Measure how well the crosscoder reconstructs the distilled model's activations.

    Iterates `model2_dataset` in batches of 16, pushes every sample through the
    model-2 encoder weights and the model-2 decoder slice, and accumulates the
    MSE weighted by the number of flattened positions in the sample. Only
    timesteps listed in the global `timestep_map` are tracked.

    Relies on the notebook globals `model2_dataset`, `timestep_map`,
    `crosscoder` and `paired_dataset`.

    Returns:
        dict mapping timestep (int) to {"mse": weighted mean MSE,
        "samples": number of positions accumulated}. Timesteps without any
        data are reported on stdout and left out of the dict.
    """
    loader = torch.utils.data.DataLoader(
        model2_dataset, batch_size=16, shuffle=False, num_workers=4
    )

    # Running totals, keyed by integer timestep
    tracked_timesteps = [int(ts) for ts in timestep_map]
    mse_sums = {ts: 0.0 for ts in tracked_timesteps}
    position_counts = {ts: 0 for ts in tracked_timesteps}

    # Distilled-model (index 1) slices of the crosscoder
    encoder_weight = crosscoder.W_enc[1]
    # NOTE(review): the whole b_enc is used here, whereas other cells in this
    # notebook index it as b_enc[1] - confirm which matches CrossCoder's bias shape.
    encoder_bias = crosscoder.b_enc
    decoder_weight = crosscoder.W_dec[:, 1, :]
    scale = paired_dataset.norm_scaling_factors[1]
    device = crosscoder.W_dec.device

    with torch.no_grad():
        for batch in tqdm(loader, desc="Computing reconstruction error by timestep"):
            batch_acts = batch["activations"].to(device)

            for sample_acts, ts_tensor in zip(batch_acts, batch["timestep"]):
                ts = int(ts_tensor.item())
                if ts not in mse_sums:
                    continue

                # Flatten to (positions, d_model) and apply the norm scaling factor
                flat = einops.rearrange(
                    sample_acts.unsqueeze(0),
                    "batch sample_size d_model -> (batch sample_size) d_model",
                )
                flat = flat * scale

                # Encode -> ReLU -> decode
                pre_acts = einops.einsum(
                    flat,
                    encoder_weight,
                    "batch d_model, d_model d_latent -> batch d_latent",
                )
                latents = torch.nn.functional.relu(pre_acts + encoder_bias)
                reconstructed = einops.einsum(
                    latents,
                    decoder_weight,
                    "batch d_latent, d_latent d_model -> batch d_model",
                )

                # Weight the per-sample MSE by its number of positions
                n_positions = flat.shape[0]
                sample_mse = torch.nn.functional.mse_loss(reconstructed, flat)
                mse_sums[ts] += sample_mse.item() * n_positions
                position_counts[ts] += n_positions

    # Turn the running totals into per-timestep averages
    results = {}
    for ts, n_positions in position_counts.items():
        if n_positions == 0:
            print(f"No samples found for timestep {ts}")
            continue
        avg_mse = mse_sums[ts] / n_positions
        results[ts] = {"mse": avg_mse, "samples": n_positions}
        print(f"Timestep {ts}: MSE = {avg_mse:.6f}, Samples = {n_positions}")

    return results


# Calculate reconstruction errors by timestep
distilled_errors = calculate_distilled_error_by_timestep()

# Both figures below are written into ./figures; create it if it is missing
os.makedirs("figures", exist_ok=True)

# --- Figure 1: mean reconstruction MSE per timestep ---
plt.figure(figsize=(10, 6))

timesteps = list(distilled_errors.keys())
errors = [distilled_errors[ts]["mse"] for ts in timesteps]
samples = [distilled_errors[ts]["samples"] for ts in timesteps]

# Create the bar plot
bars = plt.bar(range(len(timesteps)), errors, color=colors[2], alpha=0.8)

# Add sample count as text on each bar
for i, bar in enumerate(bars):
    plt.text(
        bar.get_x() + bar.get_width() / 2,
        bar.get_height() + 0.001,
        f"n={samples[i]}",
        ha="center",
        fontsize=9,
    )

plt.title("Reconstruction Error of Distilled Model by Timestep")
plt.xlabel("Timestep")
plt.ylabel("MSE")
plt.xticks(range(len(timesteps)), [f"{ts}" for ts in timesteps])
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()

# Save the figure
plt.savefig(
    "figures/distilled_error_by_timestep.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)

# --- Figure 2: same bars with the std of per-sample MSE as error bars ---
# The per-sample errors are not returned by the function above, so this takes
# a second pass over the data.
timestep_stds = {int(ts): {"errors": [], "total_samples": 0} for ts in timestep_map}
distilled_dl = torch.utils.data.DataLoader(
    model2_dataset, batch_size=16, shuffle=False, num_workers=4
)
norm_scaling_factor = paired_dataset.norm_scaling_factors[1]
with torch.no_grad():
    for batch in tqdm(distilled_dl, desc="Computing error standard deviations"):
        acts = batch["activations"].to(crosscoder.W_dec.device)
        # Deliberately NOT named `timesteps`: that name holds the x-axis list of
        # timesteps used for plotting below and must not be overwritten here.
        batch_timesteps = batch["timestep"]

        # Process each sample
        for i in range(acts.shape[0]):
            act = acts[i].unsqueeze(0)  # Add batch dimension
            ts = int(batch_timesteps[i].item())

            if ts not in timestep_stds:
                continue

            # Reshape and normalize
            act = einops.rearrange(
                act,
                "batch sample_size d_model -> (batch sample_size) d_model",
            )
            act = act * norm_scaling_factor

            # Forward through encoder to get latent activations
            latents = einops.einsum(
                act,
                crosscoder.W_enc[1],
                "batch d_model, d_model d_latent -> batch d_latent",
            )
            # NOTE(review): b_enc is indexed per model here, while
            # calculate_distilled_error_by_timestep adds the whole b_enc -
            # confirm which matches CrossCoder's bias shape.
            latents = torch.nn.functional.relu(latents + crosscoder.b_enc[1])

            # Forward through decoder to reconstruct
            reconstructed = einops.einsum(
                latents,
                crosscoder.W_dec[:, 1, :],
                "batch d_latent, d_latent d_model -> batch d_model",
            )

            # Calculate MSE
            mse = torch.nn.functional.mse_loss(reconstructed, act)

            # Store individual errors
            timestep_stds[ts]["errors"].append(mse.item())
            timestep_stds[ts]["total_samples"] += 1

# Standard deviation of the per-sample MSE, in the same order as the bars
error_stds = []
for ts in timesteps:
    if timestep_stds[ts]["total_samples"] > 0:
        std = np.std(timestep_stds[ts]["errors"])
        error_stds.append(std)
    else:
        error_stds.append(0)

# Create bar plot with error bars
plt.figure(figsize=(10, 6))
plt.bar(
    range(len(timesteps)),
    errors,
    yerr=error_stds,
    color=colors[2],
    alpha=0.8,
    ecolor="black",
    capsize=5,
)

plt.title(
    "Reconstruction Error of Distilled Model by Timestep (with Standard Deviation)"
)
plt.xlabel("Timestep")
plt.ylabel("MSE")
plt.xticks(range(len(timesteps)), [f"{ts}" for ts in timesteps])
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()

# Save the figure
plt.savefig(
    "figures/distilled_error_by_timestep_with_std.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)


In [None]:
# Compute average norms per timestep for both datasets using batches
def _collect_norms_by_timestep(dataset, batch_size, desc):
    """Group per-sample mean L2 norms and raw activations by integer timestep."""
    norms_by_timestep = {int(ts): [] for ts in timestep_map}
    acts_by_timestep = {int(ts): [] for ts in timestep_map}

    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False, num_workers=4
    )
    for batch in tqdm(loader, desc=desc):
        for act, ts_tensor in zip(batch["activations"], batch["timestep"]):
            ts = int(ts_tensor.item())
            if ts not in norms_by_timestep:
                continue
            # Norm over the last axis in float32, then averaged over the sample
            norms_by_timestep[ts].append(torch.norm(act.float(), dim=-1).mean().item())
            # Kept in the dataset's dtype; only a subsample is used for the variance
            acts_by_timestep[ts].append(act)

    return norms_by_timestep, acts_by_timestep


def _activation_variance_by_timestep(acts_by_timestep, desc, max_samples=1000):
    """Variance around the mean activation per timestep (None when there is no data)."""
    import random  # local import so this helper does not depend on a later cell line

    variances = {}
    for ts in tqdm(acts_by_timestep, desc=desc):
        samples = acts_by_timestep[ts]
        if len(samples) == 0:
            variances[ts] = None
            continue

        # Limit to `max_samples` randomly chosen samples if more are available
        if len(samples) > max_samples:
            samples = random.sample(samples, max_samples)

        # Accumulate in float32: squared deviations summed over the last axis
        # can exceed float16's maximum (65504) and turn the result into inf.
        # Streaming over the samples also avoids stacking them into one tensor.
        mean = torch.zeros_like(samples[0], dtype=torch.float32)
        for act in samples:
            mean += act.float()
        mean /= len(samples)

        # Equal-shaped samples: the mean of per-sample means equals the global mean
        total = 0.0
        for act in samples:
            total += (act.float() - mean).pow(2).sum(-1).mean().item()
        variances[ts] = total / len(samples)

    return variances


def _average_norms(norms_by_timestep, model_name):
    """Mean norm per timestep; 0 (with a printed notice) when a timestep has no samples."""
    avg_norms = {}
    for ts, norms in norms_by_timestep.items():
        if len(norms) > 0:
            avg_norms[ts] = np.mean(norms)
        else:
            print(f"No samples for {model_name}, timestep {ts}")
            avg_norms[ts] = 0
    return avg_norms


def _print_norm_summary(heading, avg_norms, variances, norms_by_timestep):
    """Print one line per timestep with average norm, variance and sample count."""
    print(heading)
    for ts in sorted(avg_norms.keys()):
        if variances[ts] is not None:
            print(
                f"  Timestep {ts}: Avg Norm = {avg_norms[ts]:.4f}, Variance = {variances[ts]:.4f}, Samples = {len(norms_by_timestep[ts])}"
            )
        else:
            print(
                f"  Timestep {ts}: Avg Norm = {avg_norms[ts]:.4f}, Variance = N/A, Samples = {len(norms_by_timestep[ts])}"
            )


def compute_norms_by_timestep(batch_size=32):
    """
    Compute per-timestep activation norm statistics for both model datasets.

    For every timestep in the global `timestep_map`, the function computes the
    average per-sample L2 norm, a variance estimate from at most 1000 randomly
    chosen samples, and the scaling factor that would bring the average norm to
    sqrt(d_model). A summary is printed to stdout.

    Relies on the notebook globals `model1_dataset`, `model2_dataset` and
    `timestep_map`.

    Args:
        batch_size: DataLoader batch size used to iterate both datasets.

    Returns:
        dict with keys "model1_norms", "model2_norms", "model1_variance",
        "model2_variance", "target_norm", "model1_samples", "model2_samples",
        "model1_scaling_factors" and "model2_scaling_factors"; all except
        "target_norm" are dicts keyed by integer timestep.
    """
    print("Computing norms by timestep...")

    print("Processing base model dataset...")
    model1_norms_by_timestep, model1_acts_by_timestep = _collect_norms_by_timestep(
        model1_dataset, batch_size, "Base model"
    )

    print("Processing distilled model dataset...")
    model2_norms_by_timestep, model2_acts_by_timestep = _collect_norms_by_timestep(
        model2_dataset, batch_size, "Distilled model"
    )

    print("Computing variances...")
    model1_variance_by_timestep = _activation_variance_by_timestep(
        model1_acts_by_timestep, "Base model variance"
    )
    model2_variance_by_timestep = _activation_variance_by_timestep(
        model2_acts_by_timestep, "Distilled model variance"
    )

    # Calculate average norms for each timestep
    model1_avg_norms = _average_norms(model1_norms_by_timestep, "model1")
    model2_avg_norms = _average_norms(model2_norms_by_timestep, "model2")

    # Print results
    _print_norm_summary(
        "\nBase Model (Model 1) average norms by timestep:",
        model1_avg_norms,
        model1_variance_by_timestep,
        model1_norms_by_timestep,
    )
    _print_norm_summary(
        "\nDistilled Model (Model 2) average norms by timestep:",
        model2_avg_norms,
        model2_variance_by_timestep,
        model2_norms_by_timestep,
    )

    # Target norm is sqrt(d_model), with d_model read from the first base-model sample
    d_model = model1_dataset[0]["activations"].shape[-1]
    target_norm = torch.sqrt(torch.tensor(d_model)).item()
    print(f"\nTarget norm: {target_norm:.4f}")

    # Scaling factor per timestep; 0 marks timesteps without samples
    model1_scaling_factors = {
        ts: target_norm / norm if norm > 0 else 0
        for ts, norm in model1_avg_norms.items()
    }
    model2_scaling_factors = {
        ts: target_norm / norm if norm > 0 else 0
        for ts, norm in model2_avg_norms.items()
    }

    print("\nScaling factors by timestep:")
    for ts in sorted(model1_scaling_factors.keys()):
        print(
            f"  Timestep {ts}: Model1 = {model1_scaling_factors[ts]:.4f}, Model2 = {model2_scaling_factors[ts]:.4f}"
        )

    return {
        "model1_norms": model1_avg_norms,
        "model2_norms": model2_avg_norms,
        "model1_variance": model1_variance_by_timestep,
        "model2_variance": model2_variance_by_timestep,
        "target_norm": target_norm,
        "model1_samples": {
            ts: len(model1_norms_by_timestep[ts]) for ts in model1_norms_by_timestep
        },
        "model2_samples": {
            ts: len(model2_norms_by_timestep[ts]) for ts in model2_norms_by_timestep
        },
        "model1_scaling_factors": model1_scaling_factors,
        "model2_scaling_factors": model2_scaling_factors,
    }


# `random` is imported here (after the function definition) for the sampling
# step of the variance calculation.
# NOTE(review): no seed is set in the visible code, so any sampling done with
# `random` presumably varies between runs — confirm whether a seed is set
# earlier in the notebook.
import random

# Run the per-timestep norm computation with batch_size=64. The returned dict
# is read below through the keys "model1_norms", "model2_norms",
# "model1_variance", "model2_variance", "target_norm", "model1_samples",
# "model2_samples", "model1_scaling_factors" and "model2_scaling_factors".
norm_results = compute_norms_by_timestep(batch_size=64)

# Figure: grouped bar chart of the average activation norm per timestep for the
# base model (model 1) and the distilled model (model 2), with the target norm
# drawn as a dashed reference line and per-bar sample counts as annotations.

# plt.savefig does not create missing parent directories; make sure the output
# folder exists so saving works on a fresh checkout.
os.makedirs("figures", exist_ok=True)

plt.figure(figsize=(12, 6))

# Model 1's timesteps define the x-axis. Model 2 must have an entry for each of
# them, otherwise the lookup below raises KeyError.
timesteps = sorted(norm_results["model1_norms"].keys())
model1_norms = [norm_results["model1_norms"][ts] for ts in timesteps]
model2_norms = [norm_results["model2_norms"][ts] for ts in timesteps]

# Set positions for the bars (these names are reused by the later figures)
bar_width = 0.35
x_positions = np.arange(len(timesteps))

# Create grouped bars for both models: base on the left, distilled on the right
bars1 = plt.bar(
    x_positions - bar_width / 2,
    model1_norms,
    bar_width,
    label="Base Model",
    color=colors[0],
)

bars2 = plt.bar(
    x_positions + bar_width / 2,
    model2_norms,
    bar_width,
    label="Distilled Model",
    color=colors[2],
)

# Add a horizontal line for the target norm
plt.axhline(
    y=norm_results["target_norm"],
    color="k",
    linestyle="--",
    alpha=0.5,
    label=f"Target Norm ({norm_results['target_norm']:.2f})",
)

# Add sample counts as text just above each bar. A single timestep -> bar index
# map replaces the repeated list scans; timesteps that are not on the x-axis
# (possible only for model 2) are skipped.
timestep_to_idx = {ts: idx for idx, ts in enumerate(timesteps)}
for sample_key, heights, x_shift in (
    ("model1_samples", model1_norms, -bar_width / 2),
    ("model2_samples", model2_norms, bar_width / 2),
):
    for ts, count in norm_results[sample_key].items():
        idx = timestep_to_idx.get(ts)
        if idx is None:
            continue
        plt.text(
            x_positions[idx] + x_shift,
            heights[idx] + 0.5,  # fixed vertical offset above the bar top
            f"n={count}",
            ha="center",
            fontsize=8,
            rotation=45,
        )

plt.title("Average Activation Norms by Timestep")
plt.xlabel("Timestep")
plt.ylabel("Average Norm")
plt.xticks(x_positions, [f"{ts}" for ts in timesteps])
plt.legend()
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()

# Save the figure
plt.savefig(
    "figures/activation_norms_by_timestep.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)

# Figure: grouped bar chart of the per-timestep variance values returned for
# the base and distilled models.
plt.figure(figsize=(12, 6))

# A variance entry can be None; such entries are drawn as zero-height bars.
model1_variance_by_ts = norm_results["model1_variance"]
model2_variance_by_ts = norm_results["model2_variance"]
model1_variances = [
    0 if model1_variance_by_ts[ts] is None else model1_variance_by_ts[ts]
    for ts in timesteps
]
model2_variances = [
    0 if model2_variance_by_ts[ts] is None else model2_variance_by_ts[ts]
    for ts in timesteps
]

# Base model bars sit left of each tick, distilled model bars right of it.
half_width = bar_width / 2
bars1 = plt.bar(
    x_positions - half_width,
    model1_variances,
    width=bar_width,
    label="Base Model",
    color=colors[0],
)
bars2 = plt.bar(
    x_positions + half_width,
    model2_variances,
    width=bar_width,
    label="Distilled Model",
    color=colors[2],
)

# Titles, axis labels and one tick per timestep.
tick_labels = [f"{ts}" for ts in timesteps]
plt.title("Activation Variances by Timestep")
plt.xlabel("Timestep")
plt.ylabel("Variance")
plt.xticks(x_positions, tick_labels)
plt.legend()
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()

# Write the figure to disk as a PDF.
plt.savefig(
    "figures/activation_variances_by_timestep.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)

# Figure: ratio of the base model's average norm to the distilled model's
# average norm at each timestep, with a dashed reference line at 1.0.
plt.figure(figsize=(10, 6))

# Pair the two norm lists element-wise; a non-positive distilled norm yields a
# ratio of 0 instead of dividing.
norm_ratios = [
    base_norm / distilled_norm if distilled_norm > 0 else 0
    for base_norm, distilled_norm in zip(model1_norms, model2_norms)
]

plt.bar(x_positions, norm_ratios, color=colors[1])
plt.axhline(y=1.0, color="k", linestyle="--", alpha=0.5)

# Titles, axis labels and one tick per timestep.
tick_labels = [f"{ts}" for ts in timesteps]
plt.title("Ratio of Base Model Norm to Distilled Model Norm by Timestep")
plt.xlabel("Timestep")
plt.ylabel("Norm Ratio (Base/Distilled)")
plt.xticks(x_positions, tick_labels)
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()

# Write the figure to disk as a PDF.
plt.savefig(
    "figures/norm_ratios_by_timestep.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)

# Figure: grouped bar chart of the per-timestep scaling factors for the base
# and distilled models.
plt.figure(figsize=(10, 6))

# Order the scaling factors to match the x-axis timesteps.
model1_factor_by_ts = norm_results["model1_scaling_factors"]
model2_factor_by_ts = norm_results["model2_scaling_factors"]
model1_scaling = [model1_factor_by_ts[ts] for ts in timesteps]
model2_scaling = [model2_factor_by_ts[ts] for ts in timesteps]

# Base model bars sit left of each tick, distilled model bars right of it.
half_width = bar_width / 2
bars1 = plt.bar(
    x_positions - half_width,
    model1_scaling,
    width=bar_width,
    label="Base Model",
    color=colors[0],
)
bars2 = plt.bar(
    x_positions + half_width,
    model2_scaling,
    width=bar_width,
    label="Distilled Model",
    color=colors[2],
)

# Titles, axis labels and one tick per timestep.
tick_labels = [f"{ts}" for ts in timesteps]
plt.title("Scaling Factors by Timestep")
plt.xlabel("Timestep")
plt.ylabel("Scaling Factor")
plt.xticks(x_positions, tick_labels)
plt.legend()
plt.grid(axis="y", alpha=0.3)
plt.tight_layout()

# Write the figure to disk as a PDF.
plt.savefig(
    "figures/scaling_factors_by_timestep.pdf",
    dpi=300,
    format="pdf",
    bbox_inches="tight",
)
