In [None]:
import argparse
import itertools
import os
import pickle
from collections.abc import Generator
from dataclasses import dataclass, field
from datetime import datetime
from typing import Any, List, Optional, Tuple, Union
from scipy.spatial.distance import pdist, squareform
from fastcluster import linkage

import numpy as np
import numpy.typing as npt
import torch
import torch.multiprocessing as mp
import torch.nn as nn
import torch.optim as optim
from matplotlib import pyplot as plt
from scipy.stats import ortho_group
from torchtyping import TensorType
from tqdm import tqdm

In [None]:
@dataclass
class ToyArgs:
    """Configuration for the toy sparse-autoencoder (SAE) experiments in this notebook.

    Result grids are indexed as ``[p_id, lp_id]``: the first axis follows ``ps()``
    and the second axis follows ``lp_alphas``.
    """

    # The default is evaluated once, when the class body is executed.
    device: Union[str, Any] = "cuda:0" if torch.cuda.is_available() else "cpu"
    tied_ae: bool = False
    seed: int = 103  # seeds torch/numpy and names the checkpoint, metrics and image folders
    learned_dict_ratio: float = 1.0
    output_folder: str = "outputs"
    # dtype: torch.dtype = torch.float32
    activation_dim: int = 32  # dimensionality of the generated activation vectors
    feature_prob_decay: float = 0.99  # feature i's activation probability is scaled by decay**i
    feature_num_nonzero: int = 8  # base probability is feature_num_nonzero / n_ground_truth_components
    correlated_components: bool = False
    n_ground_truth_components: int = 128  # number of ground-truth feature directions
    batch_size: int = 4_096
    lr: float = 4e-4
    epochs: int = 300_000
    n_components_dictionary: int = 256  # number of SAE dictionary features assumed by the metric cells
    n_components_dictionary_trans: int = 256
    l1_alpha: float = 5e-3
    use_topk: bool = False  # when True, ps() returns `topk` instead of `p_values`/`eps`
    topk: list[int] = field(
        default_factory=lambda: [
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16
            # 7, 8, 9, 10, 11
        ]
    )
    norm_output: bool = False  # forwarded to AutoEncoder.forward(norm_output=...)
    # Sparsity-penalty coefficients; one trained SAE is expected per (p, lp_alpha) pair.
    lp_alphas: list[float] = field(
        # Alternative sweeps that were used previously (kept for reference):
        #   [1e-4, 1.8e-4, 3e-4, 5.6e-4, 1e-3, 1.8e-3, 3e-3, 5.6e-3, 1e-2, 1.8e-2,
        #    3e-2, 5.6e-2, 1e-1, 1.8e-1, 3e-1, 5.6e-1, 1, 1.8, 3]
        #   [5.6, 1e1, 1.8e1, 3e1, 5.6e1, 1e2, 1.8e2, 3e2]
        #   [1e-2, 1.8e-2, 3e-2, 5.6e-2, 1e-1, 1.8e-1, 3e-1, 5.6e-1, 1, 1.8, 3]  # fine
        default_factory=lambda: [
            3,
        ]
    )
    # Further alternative lp_alphas sweeps (kept for reference):
    #   [1e-7, 3e-7, 1e-6, 3e-6, 1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2, 3e-2, 1e-1, 3e-1, 1, 3]
    #   [1]
    p_values: list[float] = field(
        # default_factory=lambda: [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1, 1.1]
        # default_factory=lambda: [0.1, 0.4, 0.7, 1,]
        default_factory=lambda: [1,]
    )
    anneal: bool = False
    loss_fn: str =  "lp^p" # options for sparsity penalty: lp_norm, lp^p, log, gated, None
    # Values returned by ps() when loss_fn == "log".
    eps: list[float] = field(
        default_factory=lambda: [
            0.01, 0.1
        ]
    )
    
    def ps(self):
        """Return the list swept along the first axis of the result grids.

        ``topk`` when ``use_topk`` is set, ``eps`` when ``loss_fn == "log"``,
        otherwise ``p_values``. The stored list itself is returned, not a copy.
        """
        if self.use_topk:
            return self.topk
        elif self.loss_fn=="log":
            return self.eps
        else:
            return self.p_values

# Define data generators and autoencoders

In [None]:
@dataclass
class RandomDatasetGenerator(Generator):
    """Endless generator of synthetic sparse-feature batches.

    Each ``next(gen)`` / ``gen.send(...)`` returns ``(ground_truth, data)`` where
    ``ground_truth`` holds the per-sample feature strengths and ``data`` is
    ``ground_truth @ feats``.

    Feature ``i`` is active with probability
    ``feature_prob_decay**i * feature_num_nonzero / n_ground_truth_components``.
    If ``feats`` is not supplied, random unit-norm feature directions are drawn via
    ``generate_rand_feats``.

    NOTE: ``correlated`` is accepted for interface compatibility but correlated
    sampling is not implemented here; ``corr_matrix`` is therefore always ``None``.
    """

    activation_dim: int
    n_ground_truth_components: int
    batch_size: int
    feature_num_nonzero: int
    feature_prob_decay: float
    correlated: bool
    device: Union[torch.device, str]

    feats: Optional[TensorType["n_ground_truth_components", "activation_dim"]] = None
    # Used as the torch seed for the next batch and incremented after every batch.
    generated_so_far: int = 0

    frac_nonzero: float = field(init=False)
    decay: TensorType["n_ground_truth_components"] = field(init=False)
    corr_matrix: Optional[
        TensorType["n_ground_truth_components", "n_ground_truth_components"]
    ] = field(init=False)
    component_probs: Optional[TensorType["n_ground_truth_components"]] = field(init=False)

    def __post_init__(self):
        self.frac_nonzero = self.feature_num_nonzero / self.n_ground_truth_components

        # Define the probabilities of each component being included in the data
        self.decay = torch.tensor(
            [self.feature_prob_decay**i for i in range(self.n_ground_truth_components)]
        ).to(self.device)  # FIXME: 1 / i

        self.component_probs = self.decay * self.frac_nonzero  # Only if non-correlated

        # The field is declared with init=False and no default, so it must be assigned
        # here; otherwise repr()/attribute access raises AttributeError.
        self.corr_matrix = None

        if self.feats is None:
            self.feats = generate_rand_feats(
                self.activation_dim,
                self.n_ground_truth_components,
                device=self.device,
            )

    def send(self, ignored_arg: Any) -> Tuple[
        TensorType["dataset_size", "n_ground_truth_components"],
        TensorType["dataset_size", "activation_dim"],
    ]:
        """Generate the next batch; the sent value is ignored.

        Side effect: re-seeds torch's *global* RNG with ``generated_so_far`` so that
        batch ``k`` is reproducible, then increments the counter.
        """
        torch.manual_seed(self.generated_so_far)  # Set a deterministic seed for reproducibility
        self.generated_so_far += 1

        _, ground_truth, data = generate_rand_dataset(
            self.n_ground_truth_components,
            self.batch_size,
            self.component_probs,
            self.feats,
            self.device,
        )
        return ground_truth, data

    def throw(self, type: Any = None, value: Any = None, traceback: Any = None) -> None:
        """Terminate the generator: always raises StopIteration, whatever was thrown in."""
        raise StopIteration


def generate_rand_dataset(
    n_ground_truth_components: int,  #
    dataset_size: int,
    feature_probs: TensorType["n_ground_truth_components"],
    feats: TensorType["n_ground_truth_components", "activation_dim"],
    device: Union[torch.device, str],
) -> Tuple[
    TensorType["n_ground_truth_components", "activation_dim"],
    TensorType["dataset_size", "n_ground_truth_components"],
    TensorType["dataset_size", "activation_dim"],
]:
    """Sample one synthetic dataset of sparse feature mixtures.

    Feature ``j`` is switched on in a sample when a uniform draw is
    ``<= feature_probs[j]``; active features get a uniform [0, 1) strength.

    Returns ``(feats, codes, codes @ feats)``; ``feats`` is passed through unchanged.
    """
    code_shape = (dataset_size, n_ground_truth_components)

    # Draw order matters for seeded reproducibility: strengths first, then gates.
    strengths = torch.rand(code_shape, device=device)
    gate_draws = torch.rand(code_shape, device=device)

    is_active = gate_draws <= feature_probs
    # dim: dataset_size x n_ground_truth_components
    codes = torch.where(is_active, strengths, torch.zeros_like(gate_draws, device=device))

    return feats, codes, codes @ feats


def generate_rand_feats(
    feat_dim: int,
    num_feats: int,
    device: Union[torch.device, str],
) -> TensorType["n_ground_truth_components", "activation_dim"]:
    """Draw ``num_feats`` random unit-norm feature directions in ``feat_dim`` dimensions.

    Uses NumPy's global RNG, so call ``np.random.seed`` beforehand for reproducible
    features. Returns a float32 tensor of shape ``(num_feats, feat_dim)`` on ``device``.
    """
    # Kept as multivariate_normal (rather than e.g. randn) so that previously seeded
    # runs keep producing exactly the same features.
    feats = np.random.multivariate_normal(np.zeros(feat_dim), np.eye(feat_dim), size=num_feats)
    # Normalize every row to unit L2 norm.
    feats = feats / np.linalg.norm(feats, axis=1, keepdims=True)

    return torch.from_numpy(feats).to(device).float()


# AutoEncoder Definition
class AutoEncoder(nn.Module):
    """Sparse autoencoder: Linear + ReLU encoder, bias-free linear decoder.

    The decoder columns (one per dictionary component) are re-normalized to unit
    L2 norm on every forward pass and whenever the dictionary is requested.
    """

    def __init__(self, activation_size, n_dict_components):
        super().__init__()

        # Construction order (encoder, decoder, orthogonal init) determines how the
        # torch RNG is consumed, so it is kept as-is.
        self.encoder = nn.Sequential(nn.Linear(activation_size, n_dict_components), nn.ReLU())
        self.decoder = nn.Linear(n_dict_components, activation_size, bias=False)
        nn.init.orthogonal_(self.decoder.weight)

    def forward(self, x, topk=None, norm_output=False):
        """Encode ``x`` and reconstruct it.

        :param x: input batch, (batch_size, activation_size)
        :param topk: if given, keep only the ``topk`` largest codes per sample
        :param norm_output: if True, divide both the reconstruction and the codes
            by the reconstruction's L2 norm (clamped at 1e-8)
        :return: ``(x_hat, c)``
        """
        codes = self.encoder(x)

        if topk is not None:
            _, keep_idx = torch.topk(codes, topk, dim=-1, largest=True, sorted=False)
            # Copy the kept entries into an all-zero tensor of the same shape.
            codes = torch.zeros_like(codes).scatter_(-1, keep_idx, codes.gather(-1, keep_idx))

        # Unit-norm constraint on the decoder weights (applied in place on .data).
        self.decoder.weight.data = nn.functional.normalize(self.decoder.weight.data, dim=0)

        x_hat = self.decoder(codes)

        if norm_output:
            # Clamp so an all-zero reconstruction does not cause a division by zero.
            scale = torch.clamp(torch.norm(x_hat, p=2, dim=1, keepdim=True), min=1e-8)
            x_hat = x_hat / scale
            codes = codes / scale

        return x_hat, codes

    def get_dictionary(self):
        """Return the decoder weight parameter after normalizing its columns."""
        weight = self.decoder.weight
        weight.data = nn.functional.normalize(weight.data, dim=0)
        return weight

    @property
    def device(self):
        """Device of the first registered parameter."""
        return next(self.parameters()).device


def cosine_sim(
    vecs1: Union[torch.Tensor, torch.nn.parameter.Parameter, npt.NDArray],
    vecs2: Union[torch.Tensor, torch.nn.parameter.Parameter, npt.NDArray],
) -> np.ndarray:
    """Pairwise cosine similarity between the rows of two 2-D collections.

    Torch inputs are detached and moved to CPU first. Entry ``[i, j]`` of the
    result is the cosine similarity of ``vecs1[i]`` and ``vecs2[j]``.
    """
    arrays = [
        v if isinstance(v, np.ndarray) else v.detach().cpu().numpy()  # type: ignore
        for v in (vecs1, vecs2)
    ]
    # Scale every row to unit L2 norm, then a matrix product gives all cosines.
    unit_rows1, unit_rows2 = (a / np.linalg.norm(a, axis=1, keepdims=True) for a in arrays)
    return unit_rows1 @ unit_rows2.T


def mean_max_cosine_similarity(ground_truth_features, learned_dictionary, debug=False):
    """Mean over ground-truth features of their best cosine match in the dictionary.

    ``debug`` is accepted but not used.
    """
    # Rows index ground-truth features, columns index learned features.
    pairwise = cosine_sim(ground_truth_features, learned_dictionary)
    best_match_per_truth = pairwise.max(axis=1)
    return best_match_per_truth.mean()


def calculate_mmcs(auto_encoder, ground_truth_features):
    """Mean-max cosine similarity between ground-truth features and the SAE dictionary.

    The decoder weight is transposed so that each learned feature is one row.
    """
    dictionary_rows = auto_encoder.decoder.weight.data.t()
    with torch.no_grad():
        truth_on_model_device = ground_truth_features.to(auto_encoder.device)
        return mean_max_cosine_similarity(truth_on_model_device, dictionary_rows)


def get_alive_neurons(auto_encoder, data_generator, n_batches=10, topk=None):
    """
    Estimate which dictionary neurons are alive.

    Runs ``n_batches`` batches through the autoencoder and averages each neuron's
    activation over all samples. A neuron is considered alive if that mean is > 0
    (the codes come after the ReLU, so no absolute value is needed).

    :param auto_encoder: callable ``auto_encoder(batch, topk=...)`` returning ``(x_hat, c)``
    :param data_generator: iterator yielding ``(ground_truth, batch)`` tuples;
        ``n_batches`` items are consumed from it
    :param n_batches: number of batches used for the estimate
    :param topk: forwarded unchanged to ``auto_encoder``
    :return: boolean tensor of shape (n_dict_components,), True for alive neurons
    """
    activations = []
    # Pure evaluation: without no_grad every batch would keep its autograd graph alive.
    with torch.no_grad():
        for _ in range(n_batches):
            _, batch = next(data_generator)
            # c: (batch_size, n_dict_components)
            _, c = auto_encoder(batch, topk=topk)
            activations.append(c)

    # (n_batches * batch_size, n_dict_components) -> (n_dict_components,)
    mean_activations = torch.cat(activations).mean(dim=0)
    return mean_activations > 0

In [None]:
cfg = ToyArgs()

# Seed both RNGs used in this notebook (torch for batches, numpy for random features).
torch.manual_seed(cfg.seed)
np.random.seed(cfg.seed)

# Single source of truth for the device: `cfg.device` and the module-level `device`
# are both used by later cells, so they must refer to the same device.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
cfg.device = device

## Loading Saved SAEs

In [None]:
# load saes
seed_string = f"{cfg.seed}" if cfg.seed >= 0 else ""
print(f"loading from toy_saes{seed_string}")

# NOTE(review): absolute, machine-specific path; consider making it configurable.
model_dir = f"/root/sparsify/trained_models/toy_saes{seed_string}"

# auto_encoders[p_id][lp_id] holds the SAE trained with cfg.ps()[p_id] and cfg.lp_alphas[lp_id].
# SECURITY: torch.load unpickles arbitrary objects; only load checkpoints you trust.
# NOTE(review): these checkpoints are used as callable modules, so they are presumably
# pickled nn.Module objects; recent PyTorch versions may need weights_only=False here.
auto_encoders = [[None for l1_coef in cfg.lp_alphas] for p in cfg.ps()]
for p_id, p in enumerate(cfg.ps()):
    for lp_id, lp_alpha in enumerate(cfg.lp_alphas):
        save_name = f"sae_l{p}_{lp_alpha}"
        # map_location puts the model on the device the batches are generated on,
        # whatever device it was saved from.
        auto_encoder = torch.load(f"{model_dir}/{save_name}.pt", map_location=device)
        auto_encoders[p_id][lp_id] = auto_encoder

# load ground truth features
ground_truth_features = torch.load(
    f"{model_dir}/ground_truth_features.pt", map_location=device
)
data_generator = RandomDatasetGenerator(
    activation_dim=cfg.activation_dim,
    n_ground_truth_components=cfg.n_ground_truth_components,
    batch_size=cfg.batch_size,
    feature_num_nonzero=cfg.feature_num_nonzero,
    feature_prob_decay=cfg.feature_prob_decay,
    correlated=cfg.correlated_components,
    device=device,
    # Start the per-batch seed counter at cfg.seed.
    generated_so_far=cfg.seed,
    feats=ground_truth_features,
)

In [None]:
# Sum of the per-feature activation probabilities, i.e. the expected number of
# active ground-truth features per generated sample.
data_generator.component_probs.sum()

In [None]:
# Recompute reconstruction and sparsity losses for every loaded SAE on fresh batches.
# opt_topk[p_id, lp_id] is the extra top-k offset used by the metrics cell; it stays 0
# unless the (currently disabled) search branch below is re-enabled.
opt_topk = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))
if cfg.use_topk and False:  # deliberately disabled: search over additional top-k offsets
    # recompute losses
    additional_topks = range(24)
    final_reconstruction = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas), len(additional_topks)))
    final_lp = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas), len(additional_topks)))

    num_inner_epochs = 100
    for epoch in tqdm(range(num_inner_epochs)):
        ground_truth, batch = next(data_generator)

        for p_id, p in enumerate(cfg.ps()):
            for lp_id, lp_alpha in enumerate(cfg.lp_alphas):
                with torch.no_grad():
                    auto_encoder = auto_encoders[p_id][lp_id]
                    for t_id, add_topk in enumerate(additional_topks):

                        # Forward pass
                        topk = p if cfg.use_topk else None
                        x_hat, c = auto_encoder(batch, topk=topk + add_topk, norm_output=cfg.norm_output)
                        # also rescale the normalized output for an accurate reconstruction loss:
                        if cfg.norm_output:
                            x_hat = x_hat * torch.norm(batch, p=2, dim=1, keepdim=True)
                            c = c * torch.norm(batch, p=2, dim=1, keepdim=True)
                            # batch = batch/torch.clamp(torch.norm(batch, p=2, dim=1, keepdim=True), min=1e-8)

                        # Compute the reconstruction loss and L1 regularization
                        l_reconstruction = torch.nn.MSELoss()(batch, x_hat)
                        l_lp = lp_alpha * torch.norm(c, p, dim=1).mean() / c.size(1)

                        final_reconstruction[p_id, lp_id, t_id] += l_reconstruction.detach().cpu().clone()
                        final_lp[p_id, lp_id, t_id] += l_lp.detach().cpu().clone()

    final_reconstruction = final_reconstruction/num_inner_epochs
    final_lp = final_lp/num_inner_epochs

    # calculate best add_topk (argmax of the negated loss == argmin of the loss):
    opt_topk = torch.argmax(-final_reconstruction, dim=-1)
    final_reconstruction = final_reconstruction[torch.arange(final_reconstruction.size(0)).unsqueeze(1), torch.arange(final_reconstruction.size(1)), opt_topk]
    final_lp = final_lp[torch.arange(final_lp.size(0)).unsqueeze(1), torch.arange(final_lp.size(1)), opt_topk]

else:
    # recompute losses
    final_reconstruction = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))
    final_lp = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))

    num_inner_epochs = 100
    for epoch in tqdm(range(num_inner_epochs)):
        # One batch per epoch, shared by all SAEs so their losses are comparable.
        ground_truth, batch = next(data_generator)

        for p_id, p in enumerate(cfg.ps()):
            for lp_id, lp_alpha in enumerate(cfg.lp_alphas):
                with torch.no_grad():
                    auto_encoder = auto_encoders[p_id][lp_id]

                    # Forward pass
                    topk = p if cfg.use_topk else None
                    x_hat, c = auto_encoder(batch, topk=topk, norm_output=cfg.norm_output)
                    # also rescale the normalized output for an accurate reconstruction loss:
                    if cfg.norm_output:
                        batch_norm = torch.norm(batch, p=2, dim=1, keepdim=True)
                        x_hat = x_hat * batch_norm
                        c = c * batch_norm

                    # Compute the reconstruction loss and the sparsity penalty
                    l_reconstruction = torch.nn.MSELoss()(batch, x_hat)

                    if cfg.loss_fn == "lp_norm":
                        l_lp = lp_alpha * torch.norm(c, p, dim=1).mean() / c.size(1)
                    elif cfg.loss_fn == "lp^p":
                        l_lp = lp_alpha *torch.pow(c, p).sum(dim=-1).mean() / c.size(1)
                    else:
                        # No penalty formula is implemented here for the other loss_fn
                        # options; record NaN instead of hitting an unbound (or stale)
                        # l_lp from a previous iteration.
                        l_lp = torch.full((), float("nan"))

                    final_reconstruction[p_id, lp_id] += l_reconstruction.detach().cpu().clone()
                    final_lp[p_id, lp_id] += l_lp.detach().cpu().clone()

    # Average over the evaluation batches.
    final_reconstruction = final_reconstruction/num_inner_epochs
    final_lp = final_lp/num_inner_epochs

In [None]:
def effective_rank(matrix: torch.Tensor) -> torch.Tensor:
    """Exponentiated Shannon entropy of each row, treating the row as a distribution.

    Negative entries are clamped to zero and each row is normalized to sum to 1.
    A row with ``k`` equal positive entries (and zeros elsewhere) yields ``k``;
    an all-zero row yields 1 (entropy 0).
    """
    non_negative = matrix.clamp(min=0.0)  # should be unnecessary
    # 0/0 for all-zero rows becomes 0 rather than NaN.
    probs = (non_negative / non_negative.sum(dim=-1, keepdim=True)).nan_to_num(0)
    # 0 * log(0) is NaN; define it as 0, the usual entropy convention.
    p_log_p = (probs * probs.log()).nan_to_num(0)
    entropy = -p_log_p.sum(dim=-1)
    return entropy.exp()


def centroid(matrix: torch.Tensor) -> torch.Tensor:
    """Value-weighted mean rank of each row's entries, 1-based.

    Each row is clamped at zero and sorted in descending order; the result is the
    weighted average of the (0-based) sorted positions, weighted by the values,
    plus 1. All-zero rows yield 1. Output shape is (rows, 1).
    """
    non_negative = torch.clamp(matrix, min=0.0)  # should be unnecessary
    sorted_vals, _ = torch.sort(non_negative, dim=1, descending=True)
    ranks = torch.arange(sorted_vals.shape[1], device=sorted_vals.device).unsqueeze(0)

    weighted_rank_sum = (sorted_vals * ranks).sum(dim=-1, keepdim=True)
    total_weight = sorted_vals.sum(dim=-1, keepdim=True)
    # 0/0 for all-zero rows becomes 0 before the 1-based shift.
    return (weighted_rank_sum / total_weight).nan_to_num(0) + 1


def count_top(matrix: torch.Tensor, threshold=0.5) -> torch.Tensor:
    """Count, per row, the entries that are at least ``threshold`` times the row maximum.

    Negative entries are clamped to zero first. Output shape is (rows, 1).
    Note that an all-zero row counts every entry (0 >= 0).
    """
    non_negative = torch.clamp(matrix, min=0.0)  # should be unnecessary
    row_max = non_negative.max(dim=-1, keepdim=True).values
    cutoff = row_max * threshold
    return (non_negative >= cutoff).sum(dim=-1, keepdim=True)


def effective_count(matrix: torch.Tensor) -> torch.Tensor:
    """Per-row feature count used by the similarity-concentration metric.

    Currently delegates to ``count_top`` with its default threshold.
    """
    return count_top(matrix)


def similarity_concentration_oneway(matrix, filter_dead_rows=False):
    """One direction of the similarity-concentration metric.

    For each row, divides the row maximum by ``effective_count`` of the row, then
    averages over rows. Using only one direction checks either monosemanticity or
    superposition, depending on which orientation of the matrix is passed.

    :param matrix: 2-D tensor of similarities / co-occurrence rates
    :param filter_dead_rows: if True, rows whose maximum is not > 0 are excluded
        from the average
    :return: 0-dim tensor; 0 when the matrix has no columns
    """
    if matrix.shape[-1] == 0:
        # Return a tensor (not a Python int) so callers can keep doing tensor
        # arithmetic such as `(a * b).sqrt()` on the result.
        return matrix.new_zeros(())
    max_sim = matrix.max(dim=-1, keepdim=True)[0]
    eff_feature_count = effective_count(matrix)
    rowwise_concentration = max_sim / eff_feature_count
    if filter_dead_rows:
        rowwise_concentration = rowwise_concentration[max_sim > 0]
    return rowwise_concentration.mean()


def similarity_concentration(matrix, ground=None, empirical=None, alive_neurons=None):
    """
    Compute the similarity_concentration metric given:
        matrix[ground_feature, sae_feature]: number of co-occurrences, or cosine similarities
        ground[ground_feature]: total number of appearances of ground features, or None if cosine_sims
        empirical[sae_feature]: total number of appearances of empirical features, or None if cosine_sims
        alive_neurons[sae_feature]: boolean mask of SAE features to keep, or None to keep all
    method: (max(M)/effective_count(M)) for both matrix and matrix.mT, then combined

    Returns (geometric mean of both directions, ground-truth direction, SAE direction).
    """

    if ground is None:
        ground = torch.ones_like(matrix[:, 0])
    if empirical is None:
        empirical = torch.ones_like(matrix[0, :])

    if alive_neurons is None:
        # Must be a boolean mask: a float tensor of ones (what full_like on a float
        # matrix produces) cannot be used as an index.
        alive_neurons = torch.ones(matrix.shape[-1], dtype=torch.bool, device=matrix.device)

    # Drop dead SAE features from both the matrix and their appearance counts.
    matrix = matrix[:, alive_neurons]
    empirical = empirical[alive_neurons]

    ground_truth_sc = similarity_concentration_oneway(matrix / ground[:, None])
    sae_feature_sc = similarity_concentration_oneway(
        matrix.mT / empirical[:, None], filter_dead_rows=True
    )

    return (ground_truth_sc * sae_feature_sc).sqrt(), ground_truth_sc, sae_feature_sc  # geo mean

# def zero_out_except_topk(tensor, topk):
#     # Keep only the topk values, set others to zero
#     _, indices = torch.topk(tensor, topk, dim=-1, largest=True, sorted=False)
#     zero_tensor = torch.zeros_like(tensor)
#     zero_tensor.scatter_(-1, indices, tensor.gather(-1, indices))
#     return zero_tensor

class BatchCorrelationCalculator:
    """Streaming Pearson correlation between every x-feature and every y-feature.

    Keeps running sums (sum x, sum y, sum x^2, sum y^2, sum x*y and the sample
    count) so the correlation matrix can be computed after any number of
    ``update`` calls without storing the batches.
    """

    def __init__(self, feature_dim_x, feature_dim_y, device="cuda"):
        # Initialize sums needed for correlation calculation for each feature dimension
        self.sum_x = torch.zeros(feature_dim_x, device=device)
        self.sum_y = torch.zeros(feature_dim_y, device=device)
        self.sum_x2 = torch.zeros(feature_dim_x, device=device)
        self.sum_y2 = torch.zeros(feature_dim_y, device=device)
        # Cross sums for every (x-feature, y-feature) pair, hence a matrix.
        self.sum_xy = torch.zeros((feature_dim_x, feature_dim_y), device=device)
        self.n = 0
        self.device = device

    def update(self, x_batch, y_batch):
        """Accumulate one batch; x_batch is (batch, feature_dim_x), y_batch is (batch, feature_dim_y)."""
        self.sum_x += torch.sum(x_batch, dim=0)
        self.sum_y += torch.sum(y_batch, dim=0)
        self.sum_x2 += torch.sum(x_batch**2, dim=0)
        self.sum_y2 += torch.sum(y_batch**2, dim=0)
        self.sum_xy += torch.einsum("bg, bs -> gs", x_batch, y_batch)
        self.n += x_batch.shape[0]

    def compute_correlation(self):
        """Return the (feature_dim_x, feature_dim_y) Pearson correlation matrix.

        Pairs where either feature has no variance get a correlation of 0.
        """
        n = self.n
        # torch.outer replaces the deprecated alias torch.ger.
        numerator = n * self.sum_xy - torch.outer(self.sum_x, self.sum_y)

        # n * sum(x^2) - sum(x)^2 is non-negative in exact arithmetic; clamp so that
        # floating-point cancellation cannot make it slightly negative and turn the
        # square root (and the whole correlation entry) into NaN.
        scaled_var_x = (n * self.sum_x2 - self.sum_x**2).clamp(min=0)
        scaled_var_y = (n * self.sum_y2 - self.sum_y**2).clamp(min=0)
        denominator = torch.sqrt(torch.outer(scaled_var_x, scaled_var_y))

        # Entries with zero denominator (no variance) keep their initial value of 0.
        valid = denominator != 0
        correlation_matrix = torch.zeros_like(denominator)
        correlation_matrix[valid] = numerator[valid] / denominator[valid]

        return correlation_matrix

In [None]:
# compute similarity concentration for each sae
# Every result grid below is indexed [p_id, lp_id]; the sim_conc_* grids carry a last
# axis of size 3 holding (combined score, ground-truth direction, SAE direction).
sim_conc_co_occur = torch.zeros(
    (len(cfg.ps()), len(cfg.lp_alphas), 3)
)  # p, lp_alpha, ground or sae sc
sim_conc_cosim = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas), 3))  # p, lp_alpha, ground or sae sc
sim_conc_corr = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas), 3))  # p, lp_alpha, ground or sae sc
mmcs = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))
features_per_ground = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))
l2_ratio = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))
l0 = torch.zeros((len(cfg.ps()), len(cfg.lp_alphas)))

for p_id, p in enumerate(tqdm(cfg.ps())):
    for lp_id, lp_alpha in enumerate(cfg.lp_alphas):
        auto_encoder = auto_encoders[p_id][lp_id]

        # adjacency_matrix[ground_truth_feature,sae_feature] = number of times they co-occur
        adjacency_matrix = torch.zeros(
            (cfg.n_ground_truth_components, cfg.n_components_dictionary), device=device
        )

        # iteratively calculate correlations
        corr_calculator = BatchCorrelationCalculator(
            cfg.n_ground_truth_components, cfg.n_components_dictionary, device=cfg.device
        )

        # feat_corr_calculator = BatchCorrelationCalculator(
        #     cfg.n_components_dictionary, cfg.n_components_dictionary, device=cfg.device
        # )

        # Ground-truth directions times the column-normalized decoder weights.
        cosim_matrix = data_generator.feats @ nn.functional.normalize(
            auto_encoder.decoder.weight.data, dim=0
        )  # of shape (ground_feats, sae_feats)
        # Mean over ground-truth features of the best-matching SAE feature.
        mmcs[p_id, lp_id] = cosim_matrix.max(dim=-1)[0].mean()
        # might need to flip data_generator.feats?
        ground_truth_appearances = torch.zeros((cfg.n_ground_truth_components), device=device)
        sae_appearances = torch.zeros((cfg.n_components_dictionary), device=device)

        # Running sums of the mean input / reconstruction L2 norms (for l2_ratio).
        inp_norm = 0
        outp_norm = 0

        # Filter out dead neurons: either do it out here, or inside the metric
        topk = int(p + opt_topk[p_id,lp_id]) if cfg.use_topk else None
        # Note: this consumes batches from data_generator before the loop below.
        alive = get_alive_neurons(auto_encoder, data_generator, topk=topk)

        num_inner_epochs = 4
        for epoch in range(num_inner_epochs):  # range(cfg.epochs):
            ground_truth, batch = next(data_generator)

            # Forward pass
            with torch.no_grad():
                x_hat1, c1 = auto_encoder(batch, topk=topk, norm_output=cfg.norm_output)

                # also rescale the normalized output for an accurate reconstruction loss:
                if cfg.norm_output:
                    x_hat1 = x_hat1 * torch.norm(batch, p=2, dim=1, keepdim=True)
                    c1 = c1 * torch.norm(batch, p=2, dim=1, keepdim=True)
                    # batch = batch/torch.clamp(torch.norm(batch, p=2, dim=1, keepdim=True), min=1e-8)

                # x_hat2, c2 = auto_encoder2(Mbatch)
                inp_norm += batch.norm(p=2, dim=-1).mean()
                outp_norm += x_hat1.norm(p=2, dim=-1).mean()

                corr_calculator.update(ground_truth, c1) # calc corr before binarizing
                # feat_corr_calculator.update(c1, c1)

                # Binarize in place: from here on both tensors are 0/1 activity indicators.
                ground_truth[ground_truth != 0] = 1
                c1[c1 != 0] = 1

                # Count samples in which ground feature g and SAE feature s are both active.
                adjacency_matrix += torch.einsum("bg, bs -> gs", ground_truth, c1)

                ground_truth_appearances += ground_truth.sum(dim=0)
                sae_appearances += c1.sum(dim=0)
                # Mean number of active SAE features per sample in this batch.
                l0[p_id, lp_id] += c1.sum(dim=-1).mean().cpu()

        l0[p_id, lp_id] = l0[p_id, lp_id] / num_inner_epochs

        # Similarity concentration from co-occurrence counts, normalized by appearances.
        sc, gsc, ssc = similarity_concentration(
            adjacency_matrix, ground_truth_appearances, sae_appearances, alive_neurons=alive
        )
        sim_conc_co_occur[p_id, lp_id, :] = torch.tensor([sc, gsc, ssc])

        # Similarity concentration from decoder cosine similarities.
        sc, gsc, ssc = similarity_concentration(cosim_matrix, alive_neurons=alive)
        sim_conc_cosim[p_id, lp_id, :] = torch.tensor([sc, gsc, ssc])

        # Similarity concentration from activation correlations.
        corr_matrix = corr_calculator.compute_correlation()
        sc, gsc, ssc = similarity_concentration(corr_matrix, alive_neurons=alive)
        sim_conc_corr[p_id, lp_id, :] = torch.tensor([sc, gsc, ssc])

        # Average number of distinct SAE features that ever co-occur with a ground feature.
        features_per_ground[p_id, lp_id] = (adjacency_matrix > 0).sum(dim=-1).mean(dtype=float)
        # Ratio of summed mean norms; the common 1/num_inner_epochs factor cancels.
        l2_ratio[p_id, lp_id] = outp_norm / inp_norm

In [None]:
# Persist every metric grid under metrics_<seed>/ (relative to the working directory).
pre = f"metrics_{cfg.seed}"
# Create exactly the directory the saves below write to. Creating an absolute path
# here while saving to a relative one only works from one specific working directory.
os.makedirs(pre, exist_ok=True)

metrics_to_save = {
    "sim_conc_co_occur": sim_conc_co_occur,
    "sim_conc_corr": sim_conc_corr,
    "sim_conc_cosim": sim_conc_cosim,
    "mmcs": mmcs,
    "features_per_ground": features_per_ground,
    "l2_ratio": l2_ratio,
    "l0": l0,
    "final_reconstruction": final_reconstruction,
    "final_lp": final_lp,
}
for metric_name, metric_tensor in metrics_to_save.items():
    torch.save(metric_tensor, f"{pre}/{metric_name}_{cfg.seed}.pt")

# Load computed matrices

In [None]:
def load_mean(seeds, matrix_string):
    """Load the saved ``matrix_string`` tensor of every seed and return their mean.

    Files are read from ``metrics_<seed>/<matrix_string>_<seed>.pt`` relative to the
    working directory. ``seeds`` must be non-empty.
    """

    def _saved_path(seed):
        pre = f"metrics_{seed}"
        return f"{pre}/{matrix_string}_{seed}.pt"

    per_seed = [torch.load(_saved_path(seed)) for seed in seeds]
    return sum(per_seed) / len(per_seed)

# Seeds whose saved metrics are averaged below.
seeds = [cfg.seed,]
# seeds = [15,16,17,18]  # lp_norm
# seeds = [22,24,25]    # lp^p
# seeds = [22]
sim_conc_co_occur = load_mean(seeds, "sim_conc_co_occur")
sim_conc_corr = load_mean(seeds, "sim_conc_corr")
sim_conc_cosim = load_mean(seeds, "sim_conc_cosim")
mmcs = load_mean(seeds, "mmcs")
features_per_ground = load_mean(seeds, "features_per_ground")
l2_ratio = load_mean(seeds, "l2_ratio")
l0 = load_mean(seeds, "l0")
final_reconstruction = load_mean(seeds, "final_reconstruction")
final_lp = load_mean(seeds, "final_lp")

# The plotting cells save to the relative folder images_<seed>/, so create that exact
# folder rather than an absolute path that only matches one working directory.
os.makedirs(f"images_{cfg.seed}", exist_ok=True)


In [None]:
# For the two penalties with multi-seed sweeps, replace l0 and final_reconstruction by
# per-group seed averages concatenated along the lp_alpha axis (dim=1).
# NOTE(review): the seed groups are hard-coded and require those metrics_<seed>/ folders.
_SEED_GROUPS_BY_LOSS = {
    "lp_norm": [[15,16,17,18],[21]],
    "lp^p": [[22,24,25],[26]],
}

seed_groups = _SEED_GROUPS_BY_LOSS.get(cfg.loss_fn)
if seed_groups is not None:
    l0 = torch.cat([load_mean(group, "l0") for group in seed_groups], dim=1)
    final_reconstruction = torch.cat(
        [load_mean(group, "final_reconstruction") for group in seed_groups], dim=1
    )

In [None]:
# Heatmap of the reconstruction-to-input L2 norm ratio over the (p, lp_alpha) grid.
alpha_tick_labels = cfg.lp_alphas
p_tick_labels = cfg.ps()

plt.imshow(l2_ratio)  # cosim_matrix.detach().cpu()
plt.colorbar()
plt.title("L2 Ratio")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")

plt.xticks(range(len(alpha_tick_labels)), alpha_tick_labels, rotation=30)
plt.yticks(range(len(p_tick_labels)), p_tick_labels)
plt.savefig(f"images_{cfg.seed}/l2_ratio.png")

In [None]:
# Display the raw L2-ratio grid (rows follow cfg.ps(), columns follow cfg.lp_alphas).
l2_ratio

In [None]:
# NOTE(review): repeats the preceding display cell; candidate for removal.
l2_ratio

In [None]:
# Display the mean-max-cosine-similarity grid (rows follow cfg.ps(), columns follow cfg.lp_alphas).
mmcs

In [None]:
# Heatmap of log10 reconstruction loss over the (p, lp_alpha) grid.
alpha_tick_labels = cfg.lp_alphas
p_tick_labels = cfg.ps()
log_reconstruction = final_reconstruction.log10()

plt.imshow(log_reconstruction)  # cosim_matrix.detach().cpu()
plt.colorbar()
plt.title("Reconstruction Loss")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")

plt.xticks(range(len(alpha_tick_labels)), alpha_tick_labels, rotation=30)
plt.yticks(range(len(p_tick_labels)), p_tick_labels)
plt.savefig(f"images_{cfg.seed}/reconstruction.png")

In [None]:
# Heatmap of L0 (mean number of active SAE features per sample) over the grid.
# plt.get_cmap is used because matplotlib.cm.get_cmap (reached via plt.cm.get_cmap)
# was deprecated and later removed from Matplotlib.
cmap_reversed = plt.get_cmap('viridis_r')
plt.imshow(l0, cmap=cmap_reversed)  # cosim_matrix.detach().cpu()
plt.colorbar()
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.title("L0")

plt.xticks(range(len(cfg.lp_alphas)), cfg.lp_alphas, rotation=30)
plt.yticks(
    range(len(cfg.ps())),
    cfg.ps(),
)
plt.savefig(f"images_{cfg.seed}/L0.png")

In [None]:
# pareto curve
for p in cfg.ps():
    p_id = cfg.ps().index(p)
    print(p_id)
    l0_p = l0[p_id]
    mse_p = final_reconstruction[p_id]
    # l0 = (df_filtered["l0"] + df_filtered["l02"])/2
    # mse = (df_filtered["mse"] + df_filtered["mse2"])/2
    plt.plot(l0_p, np.log10(mse_p), label=f"L{p}")

plt.axvline(6, linestyle="dashed", c="green")
plt.axvline(32, linestyle="dashed", c="black")
plt.xlabel("L0")
plt.ylabel("Log10(MSE)")
plt.legend()
plt.title("Toy SAEs Pareto Curve")
plt.savefig(f"images_{cfg.seed}/pareto.png")

In [None]:
import seaborn as sns

# Reversed "magma" seaborn palette; later cells pick individual colours from it
# by index. display() renders the palette in the cell output.
palette = sns.color_palette("magma_r")
display(palette)

In [None]:
# Pareto curve (log10 MSE vs. L0) for a subset of p values. The seed-averaged
# curve is drawn as a line and each individual seed is overlaid as scatter
# points in the same colour.
plt.style.use('default')
plt.axvline(8, linestyle="dashed", c="blue")
plt.axvline(32, linestyle="dashed", c="black")

# Choose which p values to draw and which palette entries colour them,
# depending on the penalty being analysed.
method_string = cfg.loss_fn
if cfg.loss_fn=="lp_norm":
    method_string = "$L_p$"
    use_ps = cfg.ps()[3::2]
    palette_generator = iter([palette[i] for i in [1,2,3,5]])
if cfg.loss_fn=="lp^p":
    method_string = "$L_p^p$"
    use_ps = cfg.ps()[1::2]
    palette_generator = iter([palette[i] for i in [0,1,2,3,5]])

# Seeds whose individual runs are scattered around each averaged curve
# (does not depend on p, so it is set once before the loop).
seeds = [15,16,17,18,21] if cfg.loss_fn=="lp_norm" else [22,24,25]

for p in use_ps:
    p_id = cfg.ps().index(p)
    print(p_id)
    color = next(palette_generator)
    l0_p = l0[p_id]
    mse_p = final_reconstruction[p_id]

    label = f"p={p}"
    plt.plot(l0_p, np.log10(mse_p), label=label, color=color)

    # add scatter plot for individual points
    for seed in seeds:
        l0_seed = load_mean([seed], "l0")[p_id]
        mse_seed = load_mean([seed], "final_reconstruction")[p_id]
        plt.scatter(l0_seed, np.log10(mse_seed), s=10, color=color)


plt.xlabel("$L_0$")
# Raw string: "\l" is not a valid escape sequence, so the non-raw form emits a
# SyntaxWarning on Python 3.12+. The rendered label is unchanged.
plt.ylabel(r"$\log_{10}(MSE)$")
plt.xlim(left=-2,right=35)
plt.legend()
plt.title(f"Synthetic Data MSE vs. $L_0$, with penalty {method_string}")
plt.savefig(f"images_{cfg.seed}/pareto.png")

In [None]:
# Heatmap of features_per_ground over the sweep grid; y ticks come from
# cfg.ps(), x ticks from cfg.lp_alphas.
plt.imshow(features_per_ground)
plt.colorbar()

plt.title("Number of Active Features")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.xticks(ticks=range(len(cfg.lp_alphas)), labels=cfg.lp_alphas, rotation=30)
plt.yticks(ticks=range(len(cfg.ps())), labels=cfg.ps())

plt.savefig(f"images_{cfg.seed}/active.png")

In [None]:
# Pick one autoencoder from the sweep by its (p, coefficient) grid indices.
p_id, l1_id = 7, 10

# Product of data_generator.feats with the decoder weights after normalising
# the weights along dim 0; the result is moved to the CPU.
decoder_weight = auto_encoders[p_id][l1_id].decoder.weight.data
cosim_matrix = (
    data_generator.feats @ nn.functional.normalize(decoder_weight, dim=0)
).detach().cpu()

# Sort each row in descending order so the strongest matches come first.
sorted_cosim_matrix = torch.sort(cosim_matrix, dim=-1, descending=True)[0].cpu()

plt.imshow(sorted_cosim_matrix)
plt.colorbar()
plt.title(f"Sorted Cosim matrix for p {cfg.ps()[p_id]}, lp_alpha {cfg.lp_alphas[l1_id]}")
plt.xlabel("Sorted SAE features")
plt.ylabel("Ground truth features")

# plt.xticks(range(len(cfg.lp_alphas)), cfg.lp_alphas, rotation=30)
# plt.yticks(
#     range(len(cfg.ps)),
#     cfg.ps,
# )

In [None]:
# Compare three "effective count" summaries for row 0 of cosim_matrix by
# drawing each as a vertical line over that row's sorted, non-negative values.
threshold = 0.5
first_row = cosim_matrix[0:1]  # kept 2-D: the helpers are called on a one-row slice
sorted_cosim = cosim_matrix[0].clamp(min=0).sort(descending=True)[0].cpu()

plt.plot(sorted_cosim)
plt.axvline(effective_rank(first_row).cpu().item(), label="Effective Rank", color="black")
plt.axvline(centroid(first_row).cpu().item(), label="Center of Mass", color="brown")
plt.axvline(
    count_top(first_row, threshold=threshold).cpu().item(),
    label=f">{threshold}*max",
    color="green",
)

plt.title("Different Effective Count Methods")
plt.xlabel("Sorted Features")
plt.ylabel("Cosine Similarity to True Feature 0")
plt.legend()

In [None]:
# Heatmap of MMCS over the sweep grid; the raw grid is also shown as the
# cell output.
plt.imshow(mmcs)
plt.colorbar()

plt.title("MMCS")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.xticks(ticks=range(len(cfg.lp_alphas)), labels=cfg.lp_alphas, rotation=30)
plt.yticks(ticks=range(len(cfg.ps())), labels=cfg.ps())

plt.savefig(f"images_{cfg.seed}/mmcs.png")
mmcs

In [None]:
# Heatmap of MMCS over the sweep grid (sim_conc_cosim / sim_conc_co_occur
# could be plotted here instead).
plt.imshow(mmcs)
plt.colorbar()

plt.title("MMCS")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.xticks(ticks=range(len(cfg.lp_alphas)), labels=cfg.lp_alphas, rotation=30)
plt.yticks(ticks=range(len(cfg.ps())), labels=cfg.ps())

plt.savefig(f"images_{cfg.seed}/mmcs.png")

# for every ground truth feature, you take the maximum cosine similarity with any of the learned SAE features
# take the mean over all ground truth features

# Lp = |x - x'|_p  "almost a norm"
#    =  [ Sum_i |x_i - x'_i|^p ]^(1/p)

# Lp^p = |x - x'|_p^p  - has diminishing returns to scale, treats each feature independently
#      = Sum_i |x_i - x'_i|^p

# Lp^p -> L0
# Lp does not limit to L0

In [None]:
# Largest MMCS value anywhere in the grid.
mmcs.max()  # no-anneal

In [None]:
# Show the raw MMCS grid as the cell output.
mmcs

In [None]:
# Heatmap of slice 0 along the last axis of sim_conc_cosim over the sweep grid.
plt.imshow(sim_conc_cosim[..., 0])
plt.colorbar()

plt.title("Similarity Concentration (Using Cosine-sim Matrix)")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.xticks(ticks=range(len(cfg.lp_alphas)), labels=cfg.lp_alphas, rotation=30)
plt.yticks(ticks=range(len(cfg.ps())), labels=cfg.ps())

plt.savefig(f"images_{cfg.seed}/cosim.png")

In [None]:
# Heatmap of slice 0 along the last axis of sim_conc_co_occur over the sweep
# grid. The title spelling is corrected to "Co-occurrence".
plt.imshow(sim_conc_co_occur[..., 0])
plt.colorbar()
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.title("Similarity Concentration (Using Co-occurrence Matrix)")

# Tick labels: coefficients on x, p values on y.
plt.xticks(range(len(cfg.lp_alphas)), cfg.lp_alphas, rotation=30)
plt.yticks(
    range(len(cfg.ps())),
    cfg.ps(),
)
plt.savefig(f"images_{cfg.seed}/co_occur.png")

In [None]:
# Heatmap of slice 0 along the last axis of sim_conc_corr over the sweep grid.
plt.imshow(sim_conc_corr[..., 0])
plt.colorbar()

plt.title("Similarity Concentration (Using Correlation Matrix)")
plt.xlabel("Lp Coefficient")
plt.ylabel("p norm")
plt.xticks(ticks=range(len(cfg.lp_alphas)), labels=cfg.lp_alphas, rotation=30)
plt.yticks(ticks=range(len(cfg.ps())), labels=cfg.ps())

plt.savefig(f"images_{cfg.seed}/corr.png")

In [None]:
# For topk, seed 14 with opt_topk
# For every metric grid, find the cell holding the largest value and print the
# p / alpha at that cell together with the reconstruction loss and L2 ratio.
for matrix in (mmcs, sim_conc_cosim, sim_conc_co_occur, sim_conc_corr):
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First grid position whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    p_id, lp_id = max_arg
    print(f"Best p={cfg.ps()[p_id]}, alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max: .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id]: .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id]: .4f}"
    )

In [None]:
# For topk, seed 14
# For every metric grid, find the cell holding the largest value and print the
# p / alpha at that cell together with the reconstruction loss and L2 ratio.
for matrix in (mmcs, sim_conc_cosim, sim_conc_co_occur, sim_conc_corr):
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First grid position whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    p_id, lp_id = max_arg
    print(f"Best p={cfg.ps()[p_id]}, alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max: .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id]: .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id]: .4f}"
    )

In [None]:
# For topk, seed 15
# For every metric grid, find the cell holding the largest value and print the
# p / alpha at that cell together with the reconstruction loss and L2 ratio.
for matrix in (mmcs, sim_conc_cosim, sim_conc_co_occur, sim_conc_corr):
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First grid position whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    p_id, lp_id = max_arg
    print(f"Best p={cfg.ps()[p_id]}, alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max: .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id]: .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id]: .4f}"
    )

In [None]:
# For lp_norm, seed 12
# For every metric grid, find the cell holding the largest value and print the
# p / alpha at that cell together with the reconstruction loss and L2 ratio.
for matrix in (mmcs, sim_conc_cosim, sim_conc_co_occur, sim_conc_corr):
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First grid position whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    p_id, lp_id = max_arg
    print(f"Best p={cfg.ps()[p_id]}, alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max: .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id]: .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id]: .4f}"
    )

In [None]:
# For lp_norm, seed 6
# For every metric grid, find the cell holding the largest value and print the
# p / alpha at that cell together with the reconstruction loss and L2 ratio.
for matrix in [mmcs, sim_conc_cosim, sim_conc_co_occur, sim_conc_corr]:
    # A trailing axis of size 3 is reduced to its first slice.
    if matrix.shape[-1] == 3:
        matrix = matrix[..., 0]
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = torch.max(
        matrix.nan_to_num(0),
    )
    # First grid position whose value equals that maximum.
    max_arg = (matrix == matrix_max).nonzero()[0]
    p_id, lp_id = max_arg
    # cfg.ps is called everywhere else in this notebook; subscripting the
    # method object itself (cfg.ps[p_id]) raises TypeError.
    print(f"Best p={cfg.ps()[p_id]}, alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max: .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id]: .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id]: .4f}"
    )
    # seed 6

In [None]:
# Report the metrics at a hand-picked grid cell.
p_id = 4
lp_id = 6
# cfg.ps is called everywhere else in this notebook; subscripting the method
# object itself (cfg.ps[p_id]) raises TypeError.
print(f"Best p={cfg.ps()[p_id]}, alpha={cfg.lp_alphas[lp_id]}")
# NOTE(review): matrix_max is whatever an earlier cell left in the kernel, not
# a value looked up at (p_id, lp_id) -- confirm this is the intended metric.
print(
    f"     Metric: {matrix_max: .8f} \
        Reconstruction: {final_reconstruction[p_id][lp_id]: .4E}\
        L2 Ratio: {l2_ratio[p_id][lp_id]: .4f}"
)

# Examining just the L1

In [None]:
# get metrics for each layer
# For every p, find the coefficient whose MMCS is largest in that row and print
# the reconstruction loss and L2 ratio there.

for p_id, p_name in enumerate(cfg.ps()):
    matrix = mmcs[p_id]  # sim_conc_cosim[p_id] could be used instead
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First position in the row whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    lp_id = max_arg

    print(f"Given p={cfg.ps()[p_id]}, best alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max.item(): .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id].item(): .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id].item(): .4f}"
    )

In [None]:
# get best L1 coeff
# Selects the autoencoder with the highest MMCS in the chosen p row; later
# cells read p_name, l1_sae and alive from here.

for p_id in (-2,):
    p_name = cfg.ps()[p_id]
    print(f"using p of {p_name}")
    matrix = mmcs[p_id]
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First position in the row whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    lp_id = max_arg

    print(f"Given p={cfg.ps()[p_id]}, best alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max.item(): .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id].item(): .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id].item(): .4f}"
    )
    l1_sae = auto_encoders[p_id][lp_id]
    alive = get_alive_neurons(l1_sae, data_generator, n_batches=100)

In [None]:
# get best L0.6 coeff
# Selects the autoencoder with the highest MMCS in the chosen p row; later
# cells read p_name, l1_sae and alive from here.

for p_id in (5,):
    p_name = cfg.ps()[p_id]
    print(f"using p of {p_name}")
    matrix = mmcs[p_id]
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First position in the row whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    lp_id = max_arg

    print(f"Given p={cfg.ps()[p_id]}, best alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max.item(): .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id].item(): .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id].item(): .4f}"
    )
    l1_sae = auto_encoders[p_id][lp_id]
    alive = get_alive_neurons(l1_sae, data_generator, n_batches=100)

In [None]:
# get best L0.1^p coeff
# Selects the autoencoder with the highest MMCS in the chosen p row; later
# cells read p_name, l1_sae and alive from here.

for p_id in (0,):
    p_name = cfg.ps()[p_id]
    print(f"using p of {p_name}")
    matrix = mmcs[p_id]
    # A trailing axis of size 3 is reduced to its first slice.
    matrix = matrix[..., 0] if matrix.shape[-1] == 3 else matrix
    # NaNs count as 0 only while searching for the maximum.
    matrix_max = matrix.nan_to_num(0).max()
    # First position in the row whose value equals that maximum.
    max_arg = torch.nonzero(matrix == matrix_max)[0]
    lp_id = max_arg

    print(f"Given p={cfg.ps()[p_id]}, best alpha={cfg.lp_alphas[lp_id]}")
    print(
        f"     Metric: {matrix_max.item(): .8f} \
          Reconstruction: {final_reconstruction[p_id][lp_id].item(): .4E}\
          L2 Ratio: {l2_ratio[p_id][lp_id].item(): .4f}"
    )
    l1_sae = auto_encoders[p_id][lp_id]
    alive = get_alive_neurons(l1_sae, data_generator, n_batches=100)

In [None]:
# compute feature corrs
# Streams batches through the selected autoencoder and accumulates
#   * activation correlations (ground truth vs SAE, and SAE vs SAE), and
#   * binarised co-occurrence counts, from which Jaccard indices are derived.
auto_encoder = l1_sae

# adjacency_matrix[ground_truth_feature,sae_feature] = number of times they co-occur
adjacency_matrix = torch.zeros(
    (cfg.n_ground_truth_components, cfg.n_components_dictionary), device=device
)
# adjacency_matrix_3d[g, h, s] = number of samples where ground truths g and h
# and SAE feature s are all active.
adjacency_matrix_3d = torch.zeros(
    (cfg.n_ground_truth_components, cfg.n_ground_truth_components, cfg.n_components_dictionary),
    device=device
)
# auto_adjacency_matrix[s, t] = number of samples where SAE features s and t are both active.
auto_adjacency_matrix = torch.zeros(
    (cfg.n_components_dictionary, cfg.n_components_dictionary), device=device
)

# iteratively calculate correlations
corr_calculator = BatchCorrelationCalculator(
    cfg.n_ground_truth_components, cfg.n_components_dictionary, device=cfg.device
)

feat_corr_calculator = BatchCorrelationCalculator(
    cfg.n_components_dictionary, cfg.n_components_dictionary, device=cfg.device
)

# Decoder weights normalised along dim 0, then multiplied with the
# ground-truth features and with themselves.
sae_feats = nn.functional.normalize(
    auto_encoder.decoder.weight.data, dim=0
)
cosim_matrix = data_generator.feats @ sae_feats  # of shape (ground_feats, sae_feats)
f_cosim_matrix = sae_feats.mT @ sae_feats # of shape (sae_feats, sae_feats)
# might need to flip data_generator.feats?

# Per-feature activation counts.
ground_truth_appearances = torch.zeros((cfg.n_ground_truth_components), device=device)
sae_appearances = torch.zeros((cfg.n_components_dictionary), device=device)

# Counts for pairs of ground-truth features combined with OR / AND, and their
# co-occurrence with each SAE feature.
g_or_appearances = torch.zeros((cfg.n_ground_truth_components, cfg.n_ground_truth_components), device=device)
g_or_adjacency = torch.zeros((cfg.n_ground_truth_components, cfg.n_ground_truth_components, cfg.n_components_dictionary), device=device)

g_and_appearances = torch.zeros((cfg.n_ground_truth_components, cfg.n_ground_truth_components), device=device)
g_and_adjacency = torch.zeros((cfg.n_ground_truth_components, cfg.n_ground_truth_components, cfg.n_components_dictionary), device=device)

# Running sums of the mean input / reconstruction L2 norms per batch.
inp_norm = 0
outp_norm = 0

num_inner_epochs = 10000
for epoch in tqdm(range(num_inner_epochs)):  # range(cfg.epochs):
    ground_truth, batch = next(data_generator)

    # Forward pass
    with torch.no_grad():
        x_hat1, c1 = auto_encoder(batch)
        inp_norm += batch.norm(p=2, dim=-1).mean()
        outp_norm += x_hat1.norm(p=2, dim=-1).mean()

        corr_calculator.update(ground_truth, c1) # calc corr before binarizing
        feat_corr_calculator.update(c1, c1)

        # Binarise in place: from here on both tensors are 0/1 activity indicators.
        ground_truth[ground_truth != 0] = 1
        c1[c1 != 0] = 1

        adjacency_matrix += torch.einsum("bg, bs -> gs", ground_truth, c1)
        auto_adjacency_matrix += torch.einsum("bg, bs -> gs", c1, c1)
        adjacency_matrix_3d += torch.einsum("bg, bh, bs -> ghs", ground_truth, ground_truth, c1)

        # Pairwise OR via De Morgan (1 - (1-a)(1-b)) and pairwise AND (a*b),
        # both per sample.
        g_or = 1-torch.einsum("bg, bs -> bgs", 1-ground_truth, 1-ground_truth)
        g_and = torch.einsum("bg, bs -> bgs", ground_truth, ground_truth)
        g_or_appearances += g_or.sum(dim=0)
        g_and_appearances += g_and.sum(dim=0)

        g_or_adjacency += torch.einsum("bgs, bf -> gsf", g_or, c1)
        g_and_adjacency += torch.einsum("bgs, bf -> gsf", g_and, c1)

        ground_truth_appearances += ground_truth.sum(dim=0)
        sae_appearances += c1.sum(dim=0)

g_corr_matrix = corr_calculator.compute_correlation()
f_corr_matrix = feat_corr_calculator.compute_correlation()

# compute union (A or B) by PIE
union = ground_truth_appearances[:,None] + sae_appearances[None, :] - adjacency_matrix
auto_union = sae_appearances[:,None] + sae_appearances[None, :] - auto_adjacency_matrix

# Jaccard index = |intersection| / |union|. Entries whose union is empty
# divide 0 by 0 and come out as NaN.
jaccard = (adjacency_matrix/union)
auto_jaccard = (auto_adjacency_matrix/auto_union)

# compute union_3d [(g_A or g_B) and f_C]=[(g_A and f_C) or (g_B and f_C)] by PIE
union_3d = adjacency_matrix.unsqueeze(0) + adjacency_matrix.unsqueeze(1) - adjacency_matrix_3d
jaccard_3d = (adjacency_matrix_3d/union_3d)

# Jaccard between (g_A and g_B) and each SAE feature. The ratio is
# intersection / union, matching `jaccard` above (it was previously inverted).
union_combo = g_and_appearances[:,:,None] + sae_appearances[None, None, :] - g_and_adjacency
jaccard_combo = g_and_adjacency/union_combo

# Jaccard between (g_A or g_B) and each SAE feature, again intersection / union.
union_or = g_or_appearances[:,:,None] + sae_appearances[None, None, :] - g_or_adjacency
jaccard_or = g_or_adjacency/union_or


In [None]:
# Show the (g_A or g_B) vs SAE-feature Jaccard tensor as the cell output.
jaccard_or

In [None]:
# Show the accumulated co-occurrence counts of (g_A or g_B) with each SAE feature.
g_or_adjacency

In [None]:
# Hierarchical Clustering
def seriation(Z,N,cur_index):
    '''
        input:
            - Z is a hierarchical tree (dendrogram); row k holds, in columns
              0 and 1, the two children merged into node N + k
            - N is the number of points given to the clustering process
            - cur_index is the node of the tree to start from
        output:
            - list of leaf indices under cur_index, in left-to-right order

        seriation computes the order implied by a hierarchical tree (dendrogram).

        The tree is walked with an explicit stack rather than recursion, so
        deep, chain-like trees do not hit Python's recursion limit and the
        result is built without repeated list concatenation.
    '''
    order = []
    pending = [cur_index]
    while pending:
        node = pending.pop()
        if node < N:
            # Indices below N are original points (leaves).
            order.append(node)
        else:
            row = node - N
            # Push the right child first so the left child is expanded first.
            pending.append(int(Z[row, 1]))
            pending.append(int(Z[row, 0]))
    return order
    
def compute_serial_matrix(dist_mat,method="ward"):
    '''
        input:
            - dist_mat is the matrix handed to linkage as-is (no squareform
              conversion is applied)
            - method = ["ward","single","average","complete"]
        output:
            - res_order is the leaf order implied by the hierarchical tree
            - res_linkage is the hierarchical tree (dendrogram)

        Builds a hierarchical clustering of dist_mat and returns the leaf
        order obtained by walking the tree from its root, together with the
        tree itself. The re-ordered matrix is not returned.

        NOTE(review): the number of points is taken to be len(dist_mat);
        confirm this matches how linkage interprets the input.
    '''
    n_points = len(dist_mat)
    tree = linkage(dist_mat, method=method, preserve_input=True)
    # The root is the last merge: node index n_points + (n_points - 2).
    root_index = 2 * n_points - 2
    return seriation(tree, n_points, root_index), tree

def sort_cross_similarity(similarity_mat):
    '''
        Given a non-square matrix (shape ground_truth x sae_features) that is
        1 for highly similar variables and 0 for dis-similar variables,
        compute an ordering of just the columns.

        Each column is assigned to the row where it is largest. Columns are
        ordered by that row index, then by decreasing similarity, then by
        original column index. Returns the array of column indices.
    '''
    n_columns = similarity_mat.shape[1]

    # One record per column; similarity is stored negated so that an
    # ascending sort puts the strongest match first within each row group.
    records = np.empty(
        n_columns,
        dtype=[('gt_index', int), ('similarity', float), ('exp_index', int)],
    )
    records['gt_index'] = np.argmax(similarity_mat, axis=0)
    records['similarity'] = -np.max(similarity_mat, axis=0)
    records['exp_index'] = np.arange(n_columns)

    # Primary key gt_index, secondary key similarity; ties fall through to
    # the remaining field (exp_index).
    records.sort(order=['gt_index', 'similarity'])
    return records['exp_index']

Compute hierarchical clustering

In [None]:
# Order the alive SAE features by the ground-truth feature they correlate with
# most, then apply that order to the similarity / correlation matrices.

# Decoder weights restricted to alive features and normalised along dim 0.
alive_dirs = nn.functional.normalize(l1_sae.decoder.weight.data[:,alive], dim=0)

# SAE-vs-SAE and ground-truth-vs-SAE products, moved to the CPU.
cosim_matrix = (alive_dirs.mT @ alive_dirs).detach().cpu()
cosim_matrix_g = (data_generator.feats @ alive_dirs).detach().cpu()

use_matrix = g_corr_matrix[:,alive] #cosim_matrix_g

res_order = sort_cross_similarity(use_matrix.cpu().numpy())
sorted_cosim_matrix_g = cosim_matrix_g[:, res_order]
# Reorder rows, then columns. Indexing with [res_order, res_order] pairs the
# two index arrays element-wise and returns only the (permuted) diagonal.
sorted_cosim_matrix = cosim_matrix[res_order][:, res_order]

sorted_corr_g = g_corr_matrix[:,alive][:, res_order].cpu()
sorted_corr_f = f_corr_matrix[:,alive][alive,:][res_order,:][:, res_order].cpu()

sorted_dirs = alive_dirs.T[res_order,:]
# Encoder bias divided by the per-row encoder weight norm, negated, restricted
# to alive features and put in the same order.
sorted_thresholds = (-l1_sae.encoder[0].bias/l1_sae.encoder[0].weight.norm(dim=1))[alive][res_order]

In [None]:
# Restrict the Jaccard tensors to alive SAE features, put them in res_order,
# and reduce the pairwise ones to a per-SAE-feature maximum over pairs (g, h)
# with g != h.

def _alive_sorted_last_axis(pair_tensor):
    # Select alive features on the last axis, reorder them, move to the CPU.
    return pair_tensor[:,:,alive][:,:,res_order].cpu()


def _max_over_distinct_pairs(pair_tensor):
    # Zero the g == h entries on a copy, then take the max over both pair axes.
    masked = pair_tensor.clone()
    diag = torch.arange(masked.shape[0])
    masked[diag, diag, :] = 0
    return masked.amax(dim=(0,1))


_jaccard_3d_cpu = _alive_sorted_last_axis(jaccard_3d)
sorted_jaccard_3d = _jaccard_3d_cpu.numpy()
sorted_max_jaccard_3d = _max_over_distinct_pairs(_jaccard_3d_cpu)

sorted_jaccard = jaccard[:,alive][:,res_order].cpu().numpy()
sorted_max_jaccard = torch.tensor(sorted_jaccard).amax(dim=0).numpy()
sorted_auto_jaccard = auto_jaccard[alive,:][:,alive][:,res_order][res_order,:].cpu().numpy()

_jaccard_combo_cpu = _alive_sorted_last_axis(jaccard_combo)
sorted_jaccard_combo = _jaccard_combo_cpu.numpy()
sorted_max_jaccard_combo = _max_over_distinct_pairs(_jaccard_combo_cpu)

_jaccard_or_cpu = _alive_sorted_last_axis(jaccard_or)
sorted_jaccard_or = _jaccard_or_cpu.numpy()
sorted_max_jaccard_or = _max_over_distinct_pairs(_jaccard_or_cpu)

In [None]:
# Largest value in sorted_thresholds along dim 0, together with its index.
sorted_thresholds.max(dim=0)

In [None]:
# SAE-vs-SAE similarity among alive features, reordered by res_order, plotted
# and saved.
alive_dirs = nn.functional.normalize(l1_sae.decoder.weight.data[:,alive], dim=0)
cosim_matrix = (alive_dirs.mT @ alive_dirs).detach().cpu()

# File-name suffix distinguishing the lp^p penalty from the others.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""

sorted_cosim_matrix = cosim_matrix[res_order][:,res_order]
plt.imshow(sorted_cosim_matrix, cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
cbar.ax.set_ylabel('Cosine Similarity', rotation=270)
plt.xlabel("SAE features")
plt.ylabel("SAE features")
plt.title(f"SAE Features Cosine Similarity (p={p_name})")
plt.savefig(f"images_{cfg.seed}/ffl{p_name}{append}.png")


pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_cosim_matrix, f"{pre}/ffl{p_name}{append}.pt")

In [None]:
# Ground-truth-vs-SAE similarity for alive features, with the SAE axis put in
# res_order, plotted and saved.
cosim_matrix = (
    (
        data_generator.feats
        @ nn.functional.normalize(l1_sae.decoder.weight.data[:,alive], dim=0)
    )
    .detach()
    .cpu()
)
# File-name suffix distinguishing the lp^p penalty from the others.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""

sorted_cosim_matrix = cosim_matrix[:][:,res_order]
plt.imshow(sorted_cosim_matrix, cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
cbar.ax.set_ylabel('Cosine Similarity', rotation=270)
plt.xlabel("Sorted SAE features")
plt.ylabel("Ground truth features")
plt.title(f"Ground Truth vs SAE Features Cosine Similarity (p={p_name})")
plt.savefig(f"images_{cfg.seed}/gfl{p_name}{append}.png")

pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_cosim_matrix, f"{pre}/gfl{p_name}{append}.pt")

In [None]:
# SAE-vs-SAE activation correlation (alive features, in res_order), plotted
# and saved.
# File-name suffix distinguishing the lp^p penalty from the others.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""

plt.imshow(sorted_corr_f, cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
cbar.ax.set_ylabel('Correlation', rotation=270)
plt.xlabel("SAE features")
plt.ylabel("SAE features")
plt.title(f"SAE Features Correlation (p={p_name})")
plt.savefig(f"images_{cfg.seed}/ffl{p_name}{append}_corr.png")


pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_corr_f, f"{pre}/ffl{p_name}{append}_corr.pt")

In [None]:
# Ground-truth-vs-SAE activation correlation (alive features, in res_order),
# plotted and saved.
# File-name suffix distinguishing the lp^p penalty from the others.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""

plt.imshow(sorted_corr_g, cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
cbar.ax.set_ylabel('Correlation', rotation=270)
plt.xlabel("SAE features")
plt.ylabel("Ground truth features")
plt.title(f"Ground Truth vs SAE Features Correlation (p={p_name})")
plt.savefig(f"images_{cfg.seed}/gfl{p_name}{append}_corr.png")


pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_corr_g, f"{pre}/gfl{p_name}{append}_corr.pt")

In [None]:
# 3D jaccard
# Per-SAE-feature maxima of the three Jaccard variants, drawn as line plots;
# the sorted 3-D Jaccard tensor and its per-feature maximum are then saved.
# File-name suffix distinguishing the lp^p penalty from the others.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""

plt.plot(sorted_max_jaccard, label="Max Jaccard")
plt.plot(sorted_max_jaccard_combo, label="Max Jaccard with Combo")
plt.plot(sorted_max_jaccard_or, label="Max Jaccard with Or")
plt.legend()
plt.title(f"Max Ground-Ground Jaccard Conditional on SAE Feature, p={p_name}")

pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_jaccard_3d, f"{pre}/jaccard_3d{p_name}{append}.pt")
torch.save(sorted_max_jaccard_3d, f"{pre}/max_jaccard_3d{p_name}{append}.pt")

In [None]:
# Position of the largest entry of sorted_max_jaccard_3d; later cells read
# most_combo_index.
most_combo_index = sorted_max_jaccard_3d.max(dim=0)[1]

# Ascending sort: the last entries of `indices` point at the largest values.
re_sorted, indices = torch.sort(sorted_max_jaccard_3d)
print(re_sorted)
print(indices)

# Thresholds of the two features with the largest values.
top_threshold = sorted_thresholds[indices[-1]].item()
second_threshold = sorted_thresholds[indices[-2]].item()
print(top_threshold, second_threshold)

In [None]:
# For the SAE feature at most_combo_index, list its column of the sorted
# Jaccard matrix and of sorted_cosim_matrix from largest to smallest.
jaccard_column = list(sorted_jaccard[:,most_combo_index])
cosim_column = list(sorted_cosim_matrix[:,most_combo_index].numpy())
print(sorted(jaccard_column)[::-1])
print(sorted(cosim_column)[::-1])

In [None]:
# Ground-truth x ground-truth slice of the sorted 3-D Jaccard tensor for the
# SAE feature at position 17 on the last axis.
plt.imshow(sorted_jaccard_3d[:,:,17], cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
# The plotted values are Jaccard indices, so the colorbar is labelled
# accordingly (it previously read "Cosine Similarity").
cbar.ax.set_ylabel('Jaccard Index', rotation=270, labelpad=15)

In [None]:
# Line plot of column 3 of the sorted ground-truth-vs-SAE Jaccard matrix.
plt.plot(sorted_jaccard[:,3])

In [None]:
# Ground-truth-vs-SAE Jaccard index (alive features, in res_order), plotted
# and saved.
# File-name suffix distinguishing the lp^p penalty from the others.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""
plt.imshow(sorted_jaccard, cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
# The plotted values are Jaccard indices, so the colorbar is labelled
# accordingly (it previously read "Cosine Similarity").
cbar.ax.set_ylabel('Jaccard Index', rotation=270)
plt.xlabel("Sorted SAE features")
plt.ylabel("Ground truth features")
plt.title(f"Ground Truth vs SAE Features Jaccard Index (p={p_name})")
plt.savefig(f"images_{cfg.seed}/gfl{p_name}{append}_jaccard.png")


pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_jaccard, f"{pre}/gfl{p_name}{append}_jaccard.pt")

In [None]:
# SAE-vs-SAE Jaccard index (alive features, in res_order), plotted and saved.
# File-name suffix distinguishing the lp^p penalty from the others; set here
# so the cell does not rely on another cell having defined it.
if cfg.loss_fn=="lp^p":
    append = "^p"
else:
    append = ""

plt.imshow(sorted_auto_jaccard, cmap="PiYG", vmin=-1, vmax=1)
cbar = plt.colorbar()
# The plotted values are Jaccard indices, so the colorbar is labelled
# accordingly (it previously read "Cosine Similarity").
cbar.ax.set_ylabel('Jaccard Index', rotation=270)
plt.xlabel("Sorted SAE features")
plt.ylabel("Sorted SAE features")
plt.title(f"SAE Features Jaccard Index (p={p_name})")
plt.savefig(f"images_{cfg.seed}/ffl{p_name}{append}_jaccard.png")


pre = f"metrics_{cfg.seed}"
# Create the same relative directory that torch.save writes into. Previously
# the directory was created under a hard-coded absolute path, which only
# matched when the notebook ran from that exact location.
os.makedirs(pre, exist_ok=True)
torch.save(sorted_auto_jaccard, f"{pre}/ffl{p_name}{append}_jaccard.pt")

In [None]:
# Ground-truth-vs-SAE activation correlation for alive features, with the SAE
# axis put in res_order.
sorted_ground_corr = g_corr_matrix[:,alive][:,res_order].cpu()
plt.imshow(sorted_ground_corr, cmap="PiYG", vmin=-1, vmax=1)

cbar = plt.colorbar()
cbar.ax.set_ylabel('Activation Correlation', rotation=270, labelpad=15)

plt.title(f"SAE Feature Correlation (L{p_name})")
plt.xlabel("SAE features")
plt.ylabel("Ground truth features")
plt.savefig(f"images_{cfg.seed}/gfl{p_name}_correlation.png")

# plt.xticks(range(len(cfg.lp_alphas)), cfg.lp_alphas, rotation=30)
# plt.yticks(
#     range(len(cfg.ps)),
#     cfg.ps,
# )

In [None]:
# def get_feature(sae, f, alive=None):
#     # f is the feature id
#     # returns encoder direction, decoder direction, enc bias, decoder bias
#     # if alive is not None, take the f'th alive feature
#     if alive is not None:
        
    
#     return sae.encoder

In [None]:
# Per-feature sum of absolute SAE-vs-SAE correlations, restricted to alive features.
f_corr_matrix[alive][:,alive].abs().sum(dim=-1)

In [None]:
# Shape of the SAE-vs-SAE correlation matrix (before restricting to alive features).
f_corr_matrix.shape

In [None]:
# Sum over `alive` (presumably a boolean mask, making this the number of alive
# SAE features -- confirm against get_alive_neurons).
alive.sum()