In [None]:
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import os
from tqdm import tqdm
import tyro
import numpy as np
import plotly.graph_objects as go

from sative.datasets import FeatureDataset
from sative.models import SparseAutoencoder

In [None]:
# Analysis configuration.
# in_dim * sae_expansion is used as the length of the per-feature activation
# counter further down, i.e. presumably the number of SAE latent features
# (input width times expansion factor) — TODO confirm against the trained model.
in_dim: int = 768
sae_expansion: int = 16
# Layer index; passed as `layer=` when building the dataset below.
target_layer: int = 12

# Dataset constructed from "features" for `target_layer`. What it actually
# loads is defined in sative.datasets and is not visible from this notebook.
dataset = FeatureDataset("features", layer=target_layer)

In [None]:
# Directory holding the saved SAE activation batches. Every batch file is read
# with np.load and is assumed to be a 2-D array laid out as (items, features)
# — TODO confirm against the script that wrote them.
# sae_feat_dir = Path("sae_features/12_02048_004/weights_004")
sae_feat_dir = Path("sae_features/12_32768_004/weights_004")

# Per-feature count of items on which that feature had a non-zero activation.
cumulative_item_sparsity = np.zeros(in_dim * sae_expansion, dtype=int)
# sparsity_per_item = []

# Only regular, non-hidden files are treated as batches. A plain directory
# listing also returns sub-directories and dot-files (e.g. .ipynb_checkpoints,
# .DS_Store), which np.load cannot read. Sorting makes the run order
# deterministic; the accumulated sums do not depend on it.
batch_paths = sorted(
    path
    for path in sae_feat_dir.iterdir()
    if path.is_file() and not path.name.startswith(".")
)

batches_len = 0  # total number of rows (items) seen across all batch files
for batch_path in tqdm(batch_paths):
    batch = np.load(batch_path)
    # Column-wise count of entries whose magnitude is strictly positive.
    cumulative_item_sparsity += np.count_nonzero(np.abs(batch) > 0, axis=0)
    # sparsity_per_item.append(np.count_nonzero(np.abs(batch) > 0, axis=1))

    batches_len += batch.shape[0]

# Optional per-item variant (number of active features per row), kept disabled:
# sparsity_per_item = np.concatenate(sparsity_per_item)
# log_sparsity_per_item = np.log10(sparsity_per_item / (in_dim * sae_expansion))

In [None]:
# Fraction of items on which each feature fired, on a log10 scale. Features
# that never fired are dropped first because log10(0) is undefined. The
# denominator is batches_len — the number of rows actually counted while
# accumulating cumulative_item_sparsity — so the fraction stays correct even
# if the saved batches do not cover exactly len(dataset) items.
active_counts = cumulative_item_sparsity[cumulative_item_sparsity.nonzero()]
sparsity = np.log10(active_counts / batches_len)
hist, bin_edges = np.histogram(sparsity, 100)

# np.histogram returns one more edge than it returns counts, so the edges
# cannot be used directly as bar positions. Place each bar at the centre of
# its bin and give it the bin's width so the bars tile the axis exactly.
bin_centers = (bin_edges[:-1] + bin_edges[1:]) / 2
bin_widths = np.diff(bin_edges)

fig = go.Figure(data=[go.Bar(
    x=bin_centers,
    y=hist,
    width=bin_widths,
)])

fig.update_layout(
    title="Batch Size 32768",
    xaxis_title="Log 10 sparsity",
    yaxis_title="Count"
)

fig.show()