In [None]:

import matplotlib.pyplot as plt
import torch
from muutils.dbg import dbg_auto

from spd.clustering.activations import component_activations, process_activations
from spd.clustering.merge import compute_merge_costs, merge_iteration
from spd.clustering.merge_matrix import GroupMerge
from spd.experiments.resid_mlp.resid_mlp_dataset import ResidualMLPDataset
from spd.models.component_model import ComponentModel
from spd.utils.data_utils import DatasetGeneratedDataLoader

# Seed torch's global RNG so that "Restart Kernel -> Run All" starts from the same
# random state every time.
# NOTE(review): only torch is seeded here; if the dataset or merge code draws from
# `random` or numpy, seed those as well.
SEED: int = 0
torch.manual_seed(SEED)

# Run on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE: str = "cuda" if torch.cuda.is_available() else "cpu"


In [None]:
# Enable IPython's autoreload extension in mode 2: imported modules are reloaded
# before each cell runs, so edits to library code on disk are picked up without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Identifier of the trained run to analyse (a "wandb:" reference).
WANDB_RUN_ID = "wandb:goodfire/spd/runs/dcjm9g2n"

# `from_pretrained` hands back three things, unpacked here as the component model,
# its config, and a path; the model is then moved onto DEVICE.
component_model, cfg, path = ComponentModel.from_pretrained(WANDB_RUN_ID)
component_model.to(DEVICE);
# Debug helpers for inspecting what was loaded (uncomment as needed):
# dbg_auto(component_model)
# dbg_auto(cfg)
# dbg_auto(path)
# dir(component_model)

In [None]:

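# Scratch notes (not executed): ways to locate the `*_features` settings on the loaded objects.
# `grep_repr` is not imported above; import it before uncommenting any of these.
# grep_repr((component_model, cfg, path, dir(component_model)), "_features")
# cfg.task_config
# grep_repr(<object to search>, "_features")  # the first argument was left blank in the original note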

In [None]:
# Batch size handed to the data loader below.
N_SAMPLES: int = 1_000

# Build the input dataset from the loaded run: the feature count comes from the
# model's own config, the sampling settings from the run's task config. Label
# computation is disabled, so every label-related option is left unset.
dataset = ResidualMLPDataset(
    n_features=component_model.model.config.n_features,
    feature_probability=cfg.task_config.feature_probability,
    data_generation_type=cfg.task_config.data_generation_type,
    device=DEVICE,
    calc_labels=False,  # Our labels will be the output of the target model
    label_type=None,
    act_fn_name=None,
    label_fn_seed=None,
    label_coeffs=None,
    # synced_inputs=synced_inputs,
)

# Echo the settings the dataset object actually ended up with.
dbg_auto(dict(
    n_features=dataset.n_features,
    feature_probability=dataset.feature_probability,
    data_generation_type=dataset.data_generation_type,
))

# Unshuffled loader that serves N_SAMPLES items per batch.
dataloader = DatasetGeneratedDataLoader(
    dataset,
    batch_size=N_SAMPLES,
    shuffle=False,
)


In [None]:
# Feed the loader through the component model and keep whatever
# `component_activations` returns; it is inspected right below and passed on to
# `process_activations` in the next cell.
# Optional extra argument that is currently disabled: threshold=0.1
ci = component_activations(component_model, dataloader, device=DEVICE)

dbg_auto(ci);

In [None]:
# Post-process the collected activations (with plots enabled). The result is
# indexed in later cells by "activations", "coactivations", "labels" and
# "n_components_alive".
# NOTE(review): `filter_dead_threshold` presumably drops components that never
# exceed this value - confirm against `process_activations`.
coa = process_activations(ci, filter_dead_threshold=0.001, plots=True);

$$
	F_g := \frac{\alpha}{n}
	\Bigg[
		d(A(g)) \cdot Q^T 
		+ Q \cdot d(A(g))^T
		- \Big(
			R \mathbf{1}^T
			+ \mathbf{1} R^T + \alpha^{-1}
		\Big) 
		\odot A(g)
	\Bigg]
$$

In [None]:
# Starting point: the grouping produced by `GroupMerge.identity`, sized to the
# number of components that `process_activations` reported as alive.
gm_ident = GroupMerge.identity(n_components=coa["n_components_alive"])
gm_ident.plot(component_labels=coa["labels"], figsize=(10, 2))

# Merge costs for that starting point, drawn as a matrix plot with a colour bar.
costs = compute_merge_costs(merges=gm_ident, coact=coa["coactivations"])
cost_image = plt.matshow(costs.cpu(), cmap="viridis")  # moved to CPU for matplotlib
plt.colorbar(cost_image);

In [None]:
# Binarise the activations: an entry counts as "active" when it exceeds 0.01.
act_bool = coa["activations"] > 0.01
# Float copy of the mask, used to form the co-activation matrix as mask^T @ mask.
act_float = act_bool.float()

merge_iteration(
    coact=act_float.T @ act_float,
    activation_mask=act_bool,
    component_labels=coa["labels"],
    # initial_merge=?,
    alpha=1e-5,  # alternative: alpha=100
    rank_cost_fn=lambda _: 0.001,  # constant rank cost; alternative: rank_cost=lambda _: 100
    check_threshold=1e-4,
    pop_component_prob=1,
    iters=200,
    plot_every=20,  # alternative: plot_every=None
    plot_every_min=20,
);

In [None]:
#%% Example Hyperparameter Sweep Usage

import numpy as np

from spd.clustering.sweep import SweepConfig, run_hyperparameter_sweep

# Sweep grid. The three numeric axes are log-spaced; the hand-picked alternatives
# are kept in the trailing comments for reference.
sweep_config = SweepConfig(
    activation_thresholds=np.logspace(-4, -1, 4).tolist(),  # alt: [0.001, 0.002, 0.005]
    check_thresholds=np.logspace(-2, 0, 3).tolist(),  # alt: [0.05, 0.1, 0.2]
    alphas=np.logspace(-3, 0, 4).tolist(),  # alt: [0.1, 1.0, 10.0]
    # Rank-cost functions keyed by the name under which results are reported
    # (the plotting cell filters on these names via `rank_cost_name`).
    rank_cost_funcs={
        # Other constant costs that can be swapped in: 0.001, 0.01, 0.1, 10.0, 100.0
        "constant_1": lambda _: 1.0,
        "linear": lambda c: c,
        "log": lambda c: np.log(c + 1),
    },
    iters=50,
)

# Run sweep
# NOTE(review): the manual merge cell thresholds coa['activations'] to build its
# mask, while this passes coa['coactivations']; confirm which tensor
# `run_hyperparameter_sweep` expects before trusting the activation-threshold axis.
sweep_results = run_hyperparameter_sweep(coa["coactivations"], sweep_config)

# "\n" (not "\\n"): emit a blank line before the count instead of a literal backslash-n.
print(f"\n{len(sweep_results) = }")

# Show the first three results as a quick sanity check.
for i, result in enumerate(sweep_results[:3], start=1):
    print(
        f"{i}: thresh={result.activation_threshold}, check={result.check_threshold}, "
        f"α={result.alpha}, rank={result.rank_cost_name}"
    )
    print(f"   iters={result.total_iterations}, groups={result.final_k_groups}")

In [None]:
# Evolution histories
from spd.clustering.sweep import plot_evolution_histories

# One plot per rank-cost function used in the sweep. Each plot fixes the rank-cost
# name and lays the remaining sweep axes out as columns (activation threshold),
# rows (alpha) and lines (check threshold).
for rank_cost_name in ("constant_1", "linear", "log"):
    plot_evolution_histories(
        sweep_results,
        metric="costs_range",  # alternative metric: 'non_diag_costs_min'
        cols_by="activation_threshold",
        rows_by="alpha",
        lines_by="check_threshold",
        fixed_params={"rank_cost_name": rank_cost_name},
    )

# Other slicings that use the same helper (kept for reference, not executed):
#   metric='non_diag_costs_min', lines_by='alpha', cols_by='activation_threshold',
#   rows_by='check_threshold', fixed_params={'rank_cost_name': 'constant_0.1'}
#     ('constant_0.1' is not among the rank-cost names in the sweep config above)
#   metric='non_diag_costs_min', lines_by='activation_threshold', cols_by='rank_cost_name',
#   rows_by='check_threshold', fixed_params={'alpha': 0.001}

In [None]:
# Heatmaps with smart parameter selection
from spd.clustering.sweep import create_smart_heatmap


def _last_non_diag_cost_min(result):
    """Return the last entry of `result.non_diag_costs_min`, or 0 when it is empty."""
    history = result.non_diag_costs_min
    if not history:
        return 0
    return history[-1]


# Number of groups left at the end of each run.
create_smart_heatmap(
    sweep_results,
    statistic_name="Final Groups",
    statistic_func=lambda res: res.final_k_groups,
)

# Number of iterations each run took.
create_smart_heatmap(
    sweep_results,
    statistic_name="Total Iterations",
    statistic_func=lambda res: res.total_iterations,
)

# Last recorded `non_diag_costs_min` value, log-scaled and normalised.
create_smart_heatmap(
    sweep_results,
    statistic_name="Final Cost",
    statistic_func=_last_non_diag_cost_min,
    log_scale=True,
    normalize=True,
)

In [None]:
#%% Stopping Condition Examples

# Example: Stop when cost reaches 2x original
from spd.clustering.sweep import cost_ratio_stopping_condition

stop_at_2x = cost_ratio_stopping_condition(2.0)

# Build the mask from the activations, exactly as the manual merge cell does.
# The previous version thresholded coa['coactivations'] instead, so the
# "activation mask" was a thresholded co-activation matrix and `coact` was the
# product of that matrix with itself.
stop_act_bool = coa["activations"] > 0.002
stop_act_float = stop_act_bool.float()

# Run with stopping condition (plots disabled)
result_with_stop = merge_iteration(
    coact=stop_act_float.T @ stop_act_float,
    activation_mask=stop_act_bool,
    alpha=1.0,
    check_threshold=0.1,
    stopping_condition=stop_at_2x,
    plot_every=None,
    plot_final=False,
)

# NOTE(review): the result is indexed like a mapping here, whereas sweep results
# are read via attributes (`result.total_iterations`); confirm what
# `merge_iteration` returns.
print(f"{result_with_stop['total_iterations'] = }")
print(f"{result_with_stop['final_k_groups'] = }")

# You can also use the stopping condition in sweeps by modifying sweep_config.iters
# or adding the stopping condition to merge_iteration calls in the sweep