In [11]:
from dataclasses import dataclass

import torch
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Run on the GPU when CUDA is visible to PyTorch; otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


Using device: cpu


In [12]:
@dataclass
class SweepResult:
    """Ensemble-averaged metrics for one phase-diagram panel.

    As populated by ``run_phase_sweep_for_p_feature1``, every field is a CPU
    tensor of shape ``(ny, nx)``: rows follow ``conditional_density_axis`` and
    columns follow ``importance_axis``. Each entry is an average over the
    ensemble for that grid cell after the member with the highest evaluation
    loss has been dropped.
    """

    # Mean absolute value of the feature-2 weight.
    mean_norm: torch.Tensor
    # Mean squared projection of the feature-1 weight onto the feature-2 direction.
    mean_superposition: torch.Tensor
    # Mean importance-weighted reconstruction loss on the evaluation batch.
    mean_eval_loss: torch.Tensor


def sample_hierarchical_batch(
    batch_size: int,
    p_feature1: torch.Tensor,
    p_feature2_conditional: torch.Tensor,
) -> torch.Tensor:
    """Sample sparse two-feature inputs where feature 2 requires feature 1.

    For every (sample, model) pair, feature 1 is active with probability
    ``p_feature1[model]``. Feature 2 is active only when feature 1 is active
    AND an independent draw falls below ``p_feature2_conditional[model]``.
    Active features take a uniform value in ``[0, 1)``; inactive ones are 0.

    The batch is created on the device of ``p_feature1`` rather than on a
    notebook-level global, so the function works for tensors on any device and
    does not depend on hidden notebook state.

    Args:
        batch_size: Number of samples to draw per model.
        p_feature1: Per-model activation probability of feature 1; one entry
            per model.
        p_feature2_conditional: Per-model probability that feature 2 is active
            given that feature 1 is active; same length as ``p_feature1``.

    Returns:
        Tensor of shape ``(batch_size, n_models, 2)`` holding
        ``[feature1, feature2]`` values for every sample and model.
    """
    n_models = p_feature1.numel()
    target_device = p_feature1.device
    p_feature2_conditional = p_feature2_conditional.to(target_device)

    # The draw order (primary gate, secondary gate, magnitudes) is kept fixed
    # so that seeded runs consume the random stream exactly as before.
    primary_active = (
        torch.rand(batch_size, n_models, device=target_device) < p_feature1.unsqueeze(0)
    )
    secondary_active = primary_active & (
        torch.rand(batch_size, n_models, device=target_device)
        < p_feature2_conditional.unsqueeze(0)
    )

    magnitudes = torch.rand(batch_size, n_models, 2, device=target_device)
    # Zero out the magnitude of every feature whose gate is closed.
    active = torch.stack((primary_active, secondary_active), dim=-1)
    return magnitudes * active


def initialize_tied_relu_models(n_models: int) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    """Create the trainable parameters for ``n_models`` independent models.

    Each model has a single hidden unit (n_hidden = 1), so its weights are one
    row of two scalars, one per feature. Every scalar is divided by its own
    magnitude, which leaves entries of +1 or -1 (unless the draw is smaller
    than 1e-8 in magnitude). Biases start at zero.

    Returns:
        ``(weights, biases)``, both ``torch.nn.Parameter`` of shape
        ``(n_models, 2)`` on the notebook-level ``device``.
    """
    raw_draw = torch.randn(n_models, 2, device=device)
    # Guard the division so an (unlikely) exact zero does not produce NaN.
    magnitude = raw_draw.abs().clamp_min(1e-8)
    unit_weights = raw_draw / magnitude
    zero_biases = torch.zeros(n_models, 2, device=device)

    weight_param = torch.nn.Parameter(unit_weights)
    bias_param = torch.nn.Parameter(zero_biases)
    return weight_param, bias_param


def tied_relu_forward(
    x: torch.Tensor, weights: torch.Tensor, biases: torch.Tensor
) -> torch.Tensor:
    """Encode to one hidden scalar and decode with the same (tied) weights.

    The last axis of ``x`` is contracted against ``weights`` to give a single
    hidden value per (sample, model). That value is then multiplied by the
    same weights, shifted by ``biases`` and passed through a ReLU.

    ``weights`` and ``biases`` are broadcast over a new leading axis, so ``x``
    is expected to carry one extra leading (batch) dimension.
    """
    shared_weights = weights.unsqueeze(0)
    hidden = (x * shared_weights).sum(dim=-1, keepdim=True)
    pre_activation = hidden * shared_weights + biases.unsqueeze(0)
    return pre_activation.relu()


def weighted_reconstruction_loss(
    x_true: torch.Tensor,
    x_pred: torch.Tensor,
    importances: torch.Tensor,
) -> torch.Tensor:
    """Importance-weighted squared reconstruction error.

    Squared errors are scaled by ``importances`` (broadcast over a new leading
    axis), summed over the last axis and then averaged over the first axis.
    """
    residual = x_true - x_pred
    scaled_sq_error = residual.pow(2) * importances.unsqueeze(0)
    per_sample = scaled_sq_error.sum(dim=-1)
    return per_sample.mean(dim=0)


def drop_worst_and_average(values: torch.Tensor, losses: torch.Tensor) -> torch.Tensor:
    """Average ``values`` along dim 1 after discarding the highest-loss member.

    Args:
        values: Tensor of shape ``(n_configs, ensemble_size)`` to average.
        losses: Tensor of the same shape; in each row the entry with the
            largest loss marks the ensemble member to discard (the first one
            on ties).

    Returns:
        Tensor of shape ``(n_configs,)`` with the mean of the kept values.

    Notes:
        * The discarded entry is overwritten with zero instead of being
          multiplied by the mask. Multiplication would turn a diverged member
          (``inf`` or ``nan``) into ``nan`` and poison the whole row, which
          defeats the purpose of dropping the outlier.
        * With a single ensemble member there is nothing left after dropping,
          so that member is kept and returned instead of reporting 0.
    """
    keep_mask = torch.ones_like(losses, dtype=torch.bool)
    if losses.size(1) > 1:
        worst_index = losses.argmax(dim=1, keepdim=True)
        keep_mask.scatter_(1, worst_index, False)

    # Zero-fill (not multiply) so non-finite outliers really are removed.
    kept_values = values.masked_fill(~keep_mask, 0)
    keep_count = keep_mask.sum(dim=1).clamp_min(1)
    return kept_values.sum(dim=1) / keep_count


def run_phase_sweep_for_p_feature1(
    p_feature1: float,
    importance_axis: torch.Tensor,
    conditional_density_axis: torch.Tensor,
    ensemble_size: int = 10,
    train_steps: int = 800,
    batch_size: int = 256,
    eval_batch_size: int = 4096,
    learning_rate: float = 1e-2,
    weight_decay: float = 1e-4,
) -> SweepResult:
    """Train an ensemble for every grid cell and summarise feature 2.

    The grid is the Cartesian product of ``conditional_density_axis`` (rows)
    and ``importance_axis`` (columns). Each cell gets ``ensemble_size``
    independent models, and all models are trained together as one batched
    tensor with Adam.

    Args:
        p_feature1: Activation probability of feature 1, shared by all cells.
        importance_axis: 1-D tensor of feature-2 importances (x-axis).
        conditional_density_axis: 1-D tensor of p(feature2 | feature1) values
            (y-axis).
        ensemble_size: Number of independently initialised models per cell.
        train_steps: Number of optimizer steps.
        batch_size: Samples per model drawn at every training step.
        eval_batch_size: Samples per model drawn for the final evaluation.
        learning_rate: Adam learning rate.
        weight_decay: Adam weight decay.

    Returns:
        ``SweepResult`` whose fields are ``(ny, nx)`` CPU tensors, each an
        ensemble average that excludes the member with the highest eval loss.
    """
    n_rows = conditional_density_axis.numel()
    n_cols = importance_axis.numel()
    n_cells = n_rows * n_cols
    n_models = n_cells * ensemble_size

    # Row-major flattening of the grid: cell index = row * n_cols + col.
    density_per_cell = conditional_density_axis.repeat_interleave(n_cols).to(device)
    importance_per_cell = importance_axis.repeat(n_rows).to(device)

    # Expand every cell into `ensemble_size` consecutive models.
    p_feature1_models = torch.full((n_models,), p_feature1, device=device)
    p_feature2_models = density_per_cell.repeat_interleave(ensemble_size)
    importance_models = importance_per_cell.repeat_interleave(ensemble_size)

    # Feature 1 always has importance 1.0; only feature 2's importance is swept.
    importances = torch.stack(
        (torch.ones_like(importance_models), importance_models), dim=-1
    )

    weights, biases = initialize_tied_relu_models(n_models)
    optimizer = torch.optim.Adam(
        [weights, biases], lr=learning_rate, weight_decay=weight_decay
    )

    for _step in range(train_steps):
        batch = sample_hierarchical_batch(batch_size, p_feature1_models, p_feature2_models)
        reconstruction = tied_relu_forward(batch, weights, biases)
        # Models are independent, so averaging their losses trains them jointly.
        loss = weighted_reconstruction_loss(batch, reconstruction, importances).mean()

        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()

    def per_cell(flat: torch.Tensor) -> torch.Tensor:
        # Regroup a per-model vector into (cell, ensemble member).
        return flat.view(n_cells, ensemble_size)

    with torch.no_grad():
        eval_batch = sample_hierarchical_batch(
            eval_batch_size, p_feature1_models, p_feature2_models
        )
        eval_reconstruction = tied_relu_forward(eval_batch, weights, biases)
        eval_loss = per_cell(
            weighted_reconstruction_loss(eval_batch, eval_reconstruction, importances)
        )

        feature1_weight = weights[:, 0]
        feature2_weight = weights[:, 1]
        feature2_norm = feature2_weight.abs()

        # Superposition for feature 2: sum_{j != 2} (W_hat_2 · W_j)^2, which with
        # two features is just the squared projection of W_1 onto W_hat_2.
        # W_hat_2 is left at zero when feature 2 is effectively unrepresented.
        represented = feature2_norm > 1e-8
        unit_feature2 = torch.where(
            represented,
            feature2_weight / feature2_norm,
            torch.zeros_like(feature2_weight),
        )
        feature2_superposition = (unit_feature2 * feature1_weight).pow(2)

        mean_norm = drop_worst_and_average(per_cell(feature2_norm), eval_loss)
        mean_superposition = drop_worst_and_average(
            per_cell(feature2_superposition), eval_loss
        )
        mean_eval_loss = drop_worst_and_average(eval_loss, eval_loss)

    return SweepResult(
        mean_norm=mean_norm.view(n_rows, n_cols).cpu(),
        mean_superposition=mean_superposition.view(n_rows, n_cols).cpu(),
        mean_eval_loss=mean_eval_loss.view(n_rows, n_cols).cpu(),
    )


def phase_rgb(norm: torch.Tensor, superposition: torch.Tensor) -> torch.Tensor:
    """Map (norm, superposition) pairs to RGB by blending four corner colours.

    Both inputs are clamped to ``[0, 1]``. High norm selects the "top" colours
    and high superposition selects the "right" colours, so the legend reads
    with norm increasing upwards and superposition increasing to the right.

    Returns:
        Tensor with the shape of the inputs plus a trailing axis of size 3,
        with every channel in ``[0, 1]``.
    """
    top_left = torch.tensor([0.10, 0.13, 0.96])
    top_right = torch.tensor([0.97, 0.07, 0.18])
    bottom_left = torch.tensor([0.84, 0.85, 0.88])
    bottom_right = torch.tensor([0.97, 0.94, 0.95])

    # Vertical flip: `toward_bottom` is 0 for high norm, which is drawn on top.
    toward_bottom = (1.0 - norm.clamp(0.0, 1.0)).unsqueeze(-1)
    toward_right = superposition.clamp(0.0, 1.0).unsqueeze(-1)
    toward_top = 1 - toward_bottom
    toward_left = 1 - toward_right

    blended = (
        toward_top * toward_left * top_left
        + toward_top * toward_right * top_right
        + toward_bottom * toward_left * bottom_left
        + toward_bottom * toward_right * bottom_right
    )
    return blended.clamp(0.0, 1.0)


def rgb_to_css(rgb: torch.Tensor) -> list[str]:
    """Convert RGB values in ``[0, 1]`` to CSS ``rgb(r,g,b)`` strings.

    The input is scaled to 0-255, rounded, and flattened to rows of three
    channels; one string is produced per row.
    """
    channels = (rgb * 255).round().to(torch.int64).view(-1, 3)
    css_colors = []
    for red, green, blue in channels.tolist():
        css_colors.append(f"rgb({red},{green},{blue})")
    return css_colors


In [13]:
# One phase-diagram panel is trained per value of p(feature1).
p_feature1_values = [0.01, 0.1, 0.2, 0.4, 0.6, 0.8, 0.9, 0.95, 0.99]
# Log-spaced grid axes: feature-2 importance (x) and p(feature2 | feature1) (y).
importance_axis = torch.logspace(-1, 1, 40)
conditional_density_axis = torch.logspace(-2, 0, 40)

# Training hyperparameters shared by every panel.
ensemble_size = 10
train_steps = 800
batch_size = 256
eval_batch_size = 4096
learning_rate = 1e-2
weight_decay = 1e-4

sweep_settings = dict(
    ensemble_size=ensemble_size,
    train_steps=train_steps,
    batch_size=batch_size,
    eval_batch_size=eval_batch_size,
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)

results: dict[float, SweepResult] = {}

for subplot_index, p_feature1 in enumerate(p_feature1_values):
    # A distinct, fixed seed per panel keeps every panel reproducible on its own.
    torch.manual_seed(10_000 + subplot_index)
    print(
        f"Training subplot {subplot_index + 1}/{len(p_feature1_values)} at p(feature1)={p_feature1:.2f}"
    )

    result = run_phase_sweep_for_p_feature1(
        p_feature1=p_feature1,
        importance_axis=importance_axis,
        conditional_density_axis=conditional_density_axis,
        **sweep_settings,
    )
    results[p_feature1] = result

print("Training complete.")


Training subplot 1/9 at p(feature1)=0.01
Training subplot 2/9 at p(feature1)=0.10
Training subplot 3/9 at p(feature1)=0.20
Training subplot 4/9 at p(feature1)=0.40
Training subplot 5/9 at p(feature1)=0.60
Training subplot 6/9 at p(feature1)=0.80
Training subplot 7/9 at p(feature1)=0.90
Training subplot 8/9 at p(feature1)=0.95
Training subplot 9/9 at p(feature1)=0.99
Training complete.


In [14]:
# 3x3 grid of panels, one per p(feature1) value, drawn as coloured square tiles.
fig = make_subplots(
    rows=3,
    cols=3,
    subplot_titles=[f"p(feature1) = {p:.2f}" for p in p_feature1_values],
    horizontal_spacing=0.04,
    vertical_spacing=0.08,
)

# Tile centres: importance along x, conditional density along y (row-major).
x_grid = importance_axis.unsqueeze(0).repeat(conditional_density_axis.numel(), 1)
y_grid = conditional_density_axis.unsqueeze(1).repeat(1, importance_axis.numel())
tile_size = 7.0

# Everything below is identical for all panels, so it is built once.
tile_x = x_grid.reshape(-1).tolist()
tile_y = y_grid.reshape(-1).tolist()
hover_template = (
    "<br>".join(
        [
            "importance(feature2)=%{x:.4f}",
            "p(feature2|feature1)=%{y:.4f}",
            "||W2||=%{customdata[0]:.4f}",
            "superposition=%{customdata[1]:.4f}",
            "eval_loss=%{customdata[2]:.6f}",
        ]
    )
    + "<extra></extra>"
)
x_axis_style = dict(
    type="log",
    tickmode="array",
    tickvals=[0.1, 1.0, 10.0],
    ticktext=["0.1", "1.0", "10"],
)
y_axis_style = dict(
    type="log",
    tickmode="array",
    tickvals=[0.01, 0.1, 1.0],
    ticktext=["0.01", "0.1", "1.0"],
)

for idx, p_feature1 in enumerate(p_feature1_values):
    # Panels are filled left-to-right, top-to-bottom (1-based for plotly).
    row, col = divmod(idx, 3)
    row, col = row + 1, col + 1

    result = results[p_feature1]
    rgb = phase_rgb(result.mean_norm, result.mean_superposition)

    # One [norm, superposition, eval_loss] triple per tile, for the hover box.
    custom_data = (
        torch.stack(
            [result.mean_norm, result.mean_superposition, result.mean_eval_loss],
            dim=-1,
        )
        .reshape(-1, 3)
        .tolist()
    )

    tiles = go.Scatter(
        x=tile_x,
        y=tile_y,
        mode="markers",
        marker=dict(
            symbol="square",
            size=tile_size,
            color=rgb_to_css(rgb),
            line=dict(width=0),
        ),
        customdata=custom_data,
        hovertemplate=hover_template,
        showlegend=False,
    )
    fig.add_trace(tiles, row=row, col=col)

    fig.update_xaxes(row=row, col=col, **x_axis_style)
    fig.update_yaxes(row=row, col=col, **y_axis_style)

# Lock each panel's y-axis scale to its own x-axis so the panels stay square.
for axis_index in range(1, len(p_feature1_values) + 1):
    axis_suffix = str(axis_index) if axis_index != 1 else ""
    fig.layout["yaxis" + axis_suffix].update(
        scaleanchor="x" + axis_suffix,
        scaleratio=1,
    )

# Axis titles only on the outer edge: bottom row for x, left column for y.
for col in (1, 2, 3):
    fig.update_xaxes(title_text="Feature 2 Importance", row=3, col=col)

for row in (1, 2, 3):
    fig.update_yaxes(title_text="p(feature2 | feature1)", row=row, col=1)

fig.update_xaxes(showgrid=False, zeroline=False)
fig.update_yaxes(showgrid=False, zeroline=False)

fig.update_layout(
    title="Hierarchical Superposition Phase Diagram (2 features, 1 latent dim)",
    width=1500,
    height=900,
    plot_bgcolor="white",
    paper_bgcolor="white",
    margin=dict(l=30, r=15, t=70, b=35),
)

fig.show()


In [15]:
# Two-dimensional colour legend for the phase diagram: superposition on x,
# feature-2 norm on y, coloured with the same `phase_rgb` mapping as the panels.
legend_resolution = 100
legend_norm = torch.linspace(0.0, 1.0, legend_resolution)
legend_super = torch.linspace(0.0, 1.0, legend_resolution)
# indexing="xy": superposition varies along columns, norm along rows.
legend_super_grid, legend_norm_grid = torch.meshgrid(legend_super, legend_norm, indexing="xy")
legend_rgb = phase_rgb(legend_norm_grid, legend_super_grid)

# `phase_rgb` clamps both inputs to [0, 1], so the last tick stands for every
# value at or above 1. The label was previously stored as mis-encoded bytes;
# the escape below is the "greater than or equal to" sign and survives any
# re-encoding of this file.
at_least_one_label = "\u2265" + "1"
legend_tick_labels = ["0", "0.5", at_least_one_label]

legend_fig = go.Figure(
    go.Scatter(
        x=legend_super_grid.reshape(-1).tolist(),
        y=legend_norm_grid.reshape(-1).tolist(),
        mode="markers",
        marker={
            "symbol": "square",
            "size": 5.2,
            "color": rgb_to_css(legend_rgb),
            "line": {"width": 0},
        },
        hoverinfo="skip",
        showlegend=False,
    )
)

legend_fig.update_xaxes(
    tickmode="array",
    tickvals=[0.0, 0.5, 1.0],
    ticktext=legend_tick_labels,
    title="Superposition",
    range=[0.0, 1.0],
    showgrid=False,
    zeroline=False,
)
legend_fig.update_yaxes(
    tickmode="array",
    tickvals=[0.0, 0.5, 1.0],
    ticktext=legend_tick_labels,
    title="||W2||",
    range=[0.0, 1.0],
    showgrid=False,
    zeroline=False,
    # Keep the legend square by tying the y scale to the x scale.
    scaleanchor="x",
    scaleratio=1,
)
legend_fig.update_layout(
    title="Color Map: Norm vs Superposition",
    width=450,
    height=420,
    plot_bgcolor="white",
    paper_bgcolor="white",
)
legend_fig.show()


In [None]:
# Debug-only diagnostics. Requires `results` from the training cell; prints, per
# panel, how norm and superposition co-vary and how many grid cells would be
# coloured towards red (high norm, high superposition) or blue (high norm, low
# superposition).
for p_feature1 in p_feature1_values:
    result = results[p_feature1]
    norm_flat = result.mean_norm.flatten()
    super_flat = result.mean_superposition.flatten()

    correlation_matrix = torch.corrcoef(torch.stack([norm_flat, super_flat]))
    corr = float(correlation_matrix[0, 1])

    high_norm = norm_flat > 0.6
    red_like = int((high_norm & (super_flat > 0.6)).sum())
    blue_like = int((high_norm & (super_flat < 0.2)).sum())

    print(
        f"p(feature1)={p_feature1:.2f} | corr(norm,super)={corr:+.3f} | red_like={red_like} | blue_like={blue_like}"
    )
