# Neurogenesis Demo
Pretrain an autoencoder on a small MNIST subset, then grow the network one class at a time with neurogenesis and generate intrinsic replay samples.

In [2]:
import os

import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.loggers import MLFlowLogger

# project modules
from data.mnist_datamodule import MNISTDataModule
from training.neurogenesis_lightning_module import NeurogenesisLightningModule
from training.neurogenesis_trainer import NeurogenesisTrainer
from utils.viz_utils import plot_recon_error_history, plot_recon_grid


In [3]:
# Location of the notebook config. Set the NEUROGENESIS_CONFIG environment
# variable to point at your own checkout; when it is unset, the original
# hard-coded path is used so existing setups keep working.
CONFIG_PATH = os.environ.get(
    "NEUROGENESIS_CONFIG",
    "C:/Users/Admin/Documents/GitHub/Neurogenesis/config/notebook_config.yaml",
)
cfg = OmegaConf.load(CONFIG_PATH)
# Echo the full config so the run's settings are visible in the notebook.
print(OmegaConf.to_yaml(cfg))

# Data module for the pretraining phase; the class selection comes from
# cfg.datamodule.cls_pretraining.
dm = MNISTDataModule(
    data_dir=cfg.datamodule.data_dir,
    batch_size=cfg.datamodule.batch_size,
    num_workers=cfg.datamodule.num_workers,
    classes=cfg.datamodule.cls_pretraining,  # or list of classes, e.g. [1,7]
)
# Download and set up datasets
dm.prepare_data()
dm.setup()
# Preview the shapes of one training batch as a quick sanity check.
batch = next(iter(dm.train_dataloader()))
print([x.shape for x in batch])

# Last expression of the cell: display the trainer section of the config.
cfg.trainer


In [5]:
# Short aliases for the two config sections that parameterize the model.
arch_cfg = cfg.model
growth_cfg = cfg.neurogenesis

# Lightning module wrapping the autoencoder; input_dim is 28 * 28 to match
# the MNIST image size used throughout this notebook.
model = NeurogenesisLightningModule(
    input_dim=28 * 28,
    hidden_sizes=arch_cfg.hidden_sizes,
    activation=arch_cfg.activation,
    activation_last=arch_cfg.activation_last,
    thresholds=growth_cfg.thresholds,
    max_nodes=growth_cfg.max_nodes,
    max_outliers=growth_cfg.max_outliers,
    base_lr=growth_cfg.base_lr,
    plasticity_epochs=growth_cfg.plasticity_epochs,
    stability_epochs=growth_cfg.stability_epochs,
    next_layer_epochs=growth_cfg.next_layer_epochs,
)
# Structural self-check of the freshly built autoencoder.
model.ae.assert_valid_structure()

# MLflow logger configured from the `mlflow` section of the config.
logger = MLFlowLogger(
    experiment_name=cfg.mlflow.experiment_name,
    tracking_uri=cfg.mlflow.tracking_uri,
)

# The number of pretraining epochs is read from the module's own hparams
# rather than from the config.
trainer = pl.Trainer(
    logger=logger,
    max_epochs=model.hparams.pretrain_epochs,
    accelerator=cfg.trainer.accelerator,
    log_every_n_steps=cfg.trainer.log_every_n_steps,
)
# Re-check the structure right before training starts.
model.ae.assert_valid_structure()

# Pretrain the autoencoder on the data module prepared earlier.
trainer.fit(model, dm)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type          | Params | Mode 
--------------------------------------------------
0 | ae      | NGAutoEncoder | 369 K  | train
1 | loss_fn | MSELoss       | 0      | train
--------------------------------------------------
369 K     Trainable params
0         Non-trainable params
369 K     Total params
1.478     Total estimated model params size (MB)
20        Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\Admin\Anaconda3\envs\neurogenesis\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


                                                                           

c:\Users\Admin\Anaconda3\envs\neurogenesis\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


Epoch 19: 100%|██████████| 204/204 [00:06<00:00, 33.20it/s, v_num=c018]

`Trainer.fit` stopped: `max_epochs=20` reached.


Epoch 19: 100%|██████████| 204/204 [00:06<00:00, 32.76it/s, v_num=c018]


In [7]:
# The pretrained autoencoder must still be structurally valid after fitting.
model.ae.assert_valid_structure()

# Pull the pretrained autoencoder (AE) and the IR object off the Lightning module.
ae, ir = model.ae, model.ir

# Same neurogenesis hyperparameters that were passed to the Lightning module.
growth_kwargs = {
    "thresholds": cfg.neurogenesis.thresholds,
    "max_nodes": cfg.neurogenesis.max_nodes,
    "max_outliers": cfg.neurogenesis.max_outliers,
    "base_lr": cfg.neurogenesis.base_lr,
    "plasticity_epochs": cfg.neurogenesis.plasticity_epochs,
    "stability_epochs": cfg.neurogenesis.stability_epochs,
    "next_layer_epochs": cfg.neurogenesis.next_layer_epochs,
}
trainer_ng = NeurogenesisTrainer(ae=ae, ir=ir, **growth_kwargs)
trainer_ng.ae.assert_valid_structure()

# Incremental phase: present the classes in the order given by cfg.ir.class_sequence.
for cls in cfg.ir.class_sequence:
    loader = dm.get_class_dataloader(cls)

    # Validate the structure on both sides of the learning step so that a
    # broken network is caught at the class that caused it.
    trainer_ng.ae.assert_valid_structure()
    trainer_ng.learn_class(class_id=cls, loader=loader)
    trainer_ng.ae.assert_valid_structure()

    # Report the layer sizes and plot the reconstruction-error history for this class.
    print(f"After class {cls}, hidden sizes = {ae.hidden_sizes}")
    fig = plot_recon_error_history(trainer_ng, cls)
    fig.show()

    # Reconstruction grid for the first element of one batch from this
    # class's loader (presumably the images — confirm against the loader).
    batch = next(iter(loader))[0]
    grid = plot_recon_grid(ae, batch, view_shape=(1, 28, 28))
    grid.show()

✅ intrinsic-replay artifacts logged under run a3fdaeac425c4d128fb17291bddbc018
