In [24]:
import torch
from bottleneck_models import Bottleneck_Model, gen_artificial_data, train_bottleneck_model, plot_learned_features
from sae import Config, SAE, train_sae

In [25]:
# Seed torch's global RNG (CPU and all CUDA devices) before anything stochastic runs,
# so model init, artificial data and SAE training are reproducible under Restart & Run All.
# NOTE(review): assumes the helpers in bottleneck_models.py / sae.py draw from torch's RNG — confirm.
SEED = 0
torch.manual_seed(SEED)

cfg = Config()
device = cfg.device

In [26]:
# Train the toy bottleneck model (see bottleneck_models.py for explanations/source).
# Feature importances decay geometrically: feature i gets weight 0.9**i.
# FEATURE_PROB is passed both to training and to the plot below, so they must agree;
# the SAE cells further down also use 0.01 and should be kept in sync with it.
FEATURE_PROB = 0.01  # presumably the per-feature activation probability (sparsity) — see bottleneck_models.py
importances = (0.9 ** torch.arange(cfg.input_dim)).to(device)
model = Bottleneck_Model(cfg).to(device)
train_bottleneck_model(cfg, model, feature_prob=FEATURE_PROB, importances=importances)

# Visualize the model's learned feature representations (its weight matrix W).
# NOTE(review): "2D" in the title assumes a 2-dimensional hidden layer — update if cfg changes.
plot_learned_features(
    cfg,
    [model.W],
    title=f'{cfg.input_dim} features represented in 2D space',
    feature_probs=[FEATURE_PROB],
    importances=importances,
)

100%|██████████| 10000/10000 [00:31<00:00, 315.56it/s, loss=2.21e-5]


In [27]:
# Build the sparse autoencoder, move it to the same device as the bottleneck model,
# and train it against that model's hidden activations.
# feature_prob should match the value the bottleneck model was trained with (0.01 above).
sae = SAE(cfg)
sae = sae.to(device)
train_sae(cfg, sae, model, feature_prob=0.01)

100%|██████████| 10000/10000 [00:51<00:00, 192.81it/s, loss=0.00823]


In [33]:
# Visualize the bottleneck model's hidden state on a fresh batch of artificial data.
# data is the input batch, presumably [batch_size, input_dim] — TODO confirm against gen_artificial_data.
data = gen_artificial_data(cfg, feature_prob=0.01)

# Pure inference: skip autograd so no graph back to model.W is built or kept alive.
with torch.no_grad():
    model_embeddings = data @ model.W  # projection into the bottleneck: [batch_size, hidden_dim]

plot_learned_features(
    cfg,
    [model_embeddings],
    title='Model hidden state on some random data',
)

In [31]:
# Reconstruct the bottleneck model's hidden state with the trained SAE.
# Depends on `model_embeddings` from the previous cell — re-run that cell first so both
# plots show the same batch (this cell was previously executed out of order).
# FIXME: the reconstruction looks off — possibly dead SAE neurons, or a bug in train_sae.
with torch.no_grad():  # inference only; no autograd graph needed
    reconstruction, _ = sae(model_embeddings)  # second output unused here

plot_learned_features(
    cfg,
    [reconstruction],
    title='SAE reconstruction of the same hidden data',
)