In [None]:
import torch
from occhio import ToyModel
from occhio.model_grid import ModelGrid, Axis
from occhio.autoencoder import TiedLinearRelu
from occhio.distributions import CorrelatedPairs, HierarchicalPairs, SparseUniform
from occhio.visualization.phase_change import plot_phase_change
from occhio.visualization.export import export_figure
from occhio.visualization.phase_change import plot_phase_change_multi

In [None]:
# One CPU random generator, seeded once, shared by every cell that needs randomness.
generator = (
    torch.Generator(device="cpu")
    .manual_seed(42)
)

In [None]:
# Sizes handed to SparseUniform(N_FEATURES, ...) and TiedLinearRelu(N_FEATURES, N_HIDDEN, ...).
N_FEATURES = 2
N_HIDDEN = 1
# NOTE(review): P_INDIVIDUAL and P_FOLLOW are not read by any live code visible in this
# notebook; presumably they parameterize the CorrelatedPairs / HierarchicalPairs
# variants — confirm before removing.
P_INDIVIDUAL = 1
P_FOLLOW = 1
# Label stored under the "data" key of the dict passed to export_figure.
# NOTE(review): presumably has to be kept in sync by hand with the distribution class
# used in model_trainer — confirm.
DATA = "uniform"
# Number of log-spaced points along each sweep axis.
EXPERIMENT_SIZE = 24

In [None]:
# Sweep axes, each with EXPERIMENT_SIZE log-spaced points (torch.logspace, default base 10):
#   relative importance rises from 10**-1 to 10**1,
#   density falls from 10**0 to 10**-2.
relative_importances = torch.logspace(start=-1, end=1, steps=EXPERIMENT_SIZE)
densities = torch.logspace(start=0, end=-2, steps=EXPERIMENT_SIZE)

In [None]:
def model_trainer(relative_importance, density, n_epochs=32000):
    """Build and fit one ToyModel for a single (relative importance, density) point.

    Args:
        relative_importance: Base of the per-feature importance weights. Feature
            ``i`` is weighted ``relative_importance ** i``, so feature 0 always
            has weight 1.
        density: Forwarded to ``SparseUniform`` as ``p_active``.
        n_epochs: Number of epochs passed to ``model.fit``. Defaults to 32000,
            the value this sweep was previously hard-coded to; pass a smaller
            value for quick smoke runs.

    Returns:
        The ``ToyModel`` after ``fit`` has been called on it.
    """
    # Per-feature weights: relative_importance ** [0, 1, ..., N_FEATURES - 1].
    importances = relative_importance ** torch.arange(N_FEATURES)

    # Both the data distribution and the autoencoder draw from the shared
    # notebook-level `generator`.
    # NOTE(review): presumably this makes each model's randomness depend on how
    # many models were built before it — confirm if per-point reproducibility matters.
    model = ToyModel(
        distribution=SparseUniform(N_FEATURES, p_active=density, generator=generator),
        ae=TiedLinearRelu(N_FEATURES, N_HIDDEN, generator=generator),
        importances=importances,
    )
    model.fit(n_epochs=n_epochs)
    return model


# Sweep model_trainer over two axes: relative importance on x, density on y.
# NOTE(review): presumably one model is trained per (x, y) pair, either here or
# lazily on first use — confirm against ModelGrid before re-running casually.
model_grid = ModelGrid(
    model_trainer,
    x_axis=Axis("Relative Importance", relative_importances),
    y_axis=Axis("Density", densities),
)

In [None]:
# Phase-change plot over the grid, tracking feature index 0 (the feature whose
# importance weight is relative_importance ** 0 in model_trainer).
# `fig` is reused by the show and export cells further down.
fig = plot_phase_change(model_grid, tracked_feature=0)

In [None]:
# Display the phase-change figure built in the previous cell.
fig.show()

In [None]:
# The result is not stored, so it is only visible as this cell's output.
# NOTE(review): up_to=2 equals N_FEATURES here — presumably it should follow that
# constant rather than a literal; confirm against plot_phase_change_multi.
plot_phase_change_multi(model_grid, up_to=2)

In [None]:
# Export `fig` (the plot_phase_change figure) under the "phase-changes" subdir.
# NOTE(review): the dict is presumably run metadata used to label the exported
# file — confirm against export_figure.
export_figure(
    fig,
    {"data": DATA, "n_hidden": N_HIDDEN, "n_features": N_FEATURES},
    # Alternative metadata dicts kept for other runs (swap in by hand):
    # {"data": DATA, "p_individual": P_INDIVIDUAL, "n_hidden": N_HIDDEN},
    # {"data": DATA, "p_follow": P_FOLLOW},
    subdir="phase-changes",
)