In [1]:
import torch
from torch import Tensor
from occhio import ToyModel
from occhio.autoencoder import TiedLinearRelu
from occhio.distributions import SparseUniform
from occhio.model_grid import ModelGrid, Axis
from occhio.visualization.phase_change import plot_phase_change

In [2]:
generator = torch.Generator("cpu").manual_seed(42)

In [3]:
N_FEATURES = 2
N_HIDDEN = 1
EXPERIMENT_SIZE = 2

In [4]:
# density = 1
# relative_importance = 10

In [5]:
# model = ToyModel(
#     distribution=SparseUniform(
#         n_features=N_FEATURES,
#         p_active=density,
#     ),
#     ae=TiedLinearRelu(N_FEATURES, N_HIDDEN, generator=generator),
#     importances=Tensor([1, relative_importance]),
# )
# model.fit(n_epochs=15000)

In [6]:
# model.W

In [7]:
relative_importances = torch.logspace(-1, 1, EXPERIMENT_SIZE)
densities = torch.logspace(0, -2, EXPERIMENT_SIZE)

In [8]:
def model_trainer(relative_importance, density):
    print(density, relative_importance)
    model = ToyModel(
        distribution=SparseUniform(N_FEATURES, p_active=density, generator=generator),
        ae=TiedLinearRelu(N_FEATURES, N_HIDDEN, generator=generator),
        importances=relative_importance ** torch.arange(N_FEATURES),
    )
    model.fit(n_epochs=15000)
    return model


model_grid = ModelGrid(
    model_trainer,
    x_axis=Axis("Relative Importance", relative_importances),
    y_axis=Axis("Density", densities),
)

Training Model Grid:   0%|          | 0/4 [00:00<?, ?model/s]

tensor(1.) tensor(0.1000)


Training Model Grid:  25%|██▌       | 1/4 [00:03<00:10,  3.38s/model]

tensor(1.) tensor(10.)


Training Model Grid:  50%|█████     | 2/4 [00:06<00:06,  3.01s/model]

tensor(0.0100) tensor(0.1000)


Training Model Grid:  75%|███████▌  | 3/4 [00:09<00:02,  2.96s/model]

tensor(0.0100) tensor(10.)


Training Model Grid: 100%|██████████| 4/4 [00:11<00:00,  2.92s/model]


In [9]:
fig = plot_phase_change(model_grid)

[[2.663311e-04 9.999769e-01]
 [9.878439e-01 1.002113e+00]]
[[1.0001149e+00 1.8654116e-06]
 [1.0014465e+00 9.8798364e-01]]
[[ 0.1 10. ]
 [ 0.1 10. ]]
[[1.   1.  ]
 [0.01 0.01]]


In [10]:
fig.show()