# Neurogenesis Demo
Train an autoencoder on the MNIST digits 0 and 1, then generate intrinsic replay (IR) samples for each class, log them to MLflow, and display them.

In [None]:
import torch
import torchvision
import pytorch_lightning as pl
from pytorch_lightning.loggers import MLFlowLogger

from data.mnist_datamodule import MNISTDataModule
from models.autoencoder import AutoEncoder
from training.intrinsic_replay_runner import run_intrinsic_replay

In [None]:
# MNIST datamodule that keeps only the digits 0 and 1; setup() is called
# up front rather than left to the Trainer.
# NOTE(review): prepare_data() is not called here — confirm that
# MNISTDataModule.setup() can obtain the dataset on its own.
dm = MNISTDataModule(
    classes=[0, 1],
    batch_size=64,
    num_workers=0,
)
dm.setup()

In [None]:
# LightningModule wrapping the AutoEncoder
class LitWrapper(pl.LightningModule):
    """Lightning wrapper that trains ``model`` as an autoencoder with MSE loss.

    The wrapped model is called as ``model(imgs)`` and must return a mapping
    with a ``'recon'`` entry. That reconstruction is compared against
    ``imgs.view(batch, -1)``, so it is expected to be flat per sample.

    Attributes:
        ae: The wrapped autoencoder.
        loss_fn: Mean-squared-error reconstruction loss.
        train_losses: Training loss of every training step, as Python floats.
        val_losses: Validation loss of every validation step, as Python
            floats, excluding Lightning's sanity-check batches.
        example_grid: Image grid built at the end of the most recent
            validation epoch (first row: inputs, second row:
            reconstructions), or ``None`` before any validation batch ran.
    """

    def __init__(self, model):
        super().__init__()
        self.ae = model
        self.loss_fn = torch.nn.MSELoss()
        self.train_losses = []
        self.val_losses = []
        # Initialised here so these attributes always exist, even if no
        # validation batch is ever run.
        self._last_imgs = None
        self._last_recons = None
        self.example_grid = None

    def forward(self, x):
        return self.ae(x)

    def _reconstruction_loss(self, imgs, recon):
        """Return the MSE between ``recon`` and the per-sample-flattened ``imgs``."""
        return self.loss_fn(recon, imgs.view(imgs.size(0), -1))

    def training_step(self, batch, batch_idx):
        imgs, _ = batch
        recon = self(imgs)['recon']
        loss = self._reconstruction_loss(imgs, recon)
        self.train_losses.append(loss.item())
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        imgs, _ = batch
        recon = self(imgs)['recon']
        loss = self._reconstruction_loss(imgs, recon)
        # Keep a detached CPU copy of the latest batch for the example grid,
        # so the grid does not hold on to device memory.
        self._last_imgs = imgs.detach().cpu()
        self._last_recons = recon.detach().view_as(imgs).cpu()
        # Lightning runs a few validation batches on the untrained model
        # before training starts (the sanity check); keep them out of the
        # recorded loss history.
        if not self.trainer.sanity_checking:
            self.val_losses.append(loss.item())
        self.log('val_loss', loss)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)

    def on_validation_epoch_end(self):
        """Build ``example_grid`` from the last validation batch, then clear it."""
        if self._last_imgs is None:
            return  # no validation batch was seen this epoch
        # One row of inputs followed by one row of their reconstructions.
        # NOTE(review): ToPILImage expects float pixels in [0, 1]; if the
        # datamodule normalises inputs, consider make_grid(..., normalize=True).
        grid = torchvision.utils.make_grid(
            torch.cat([self._last_imgs, self._last_recons], dim=0),
            nrow=self._last_imgs.size(0),
        )
        self.example_grid = grid
        self._last_imgs = None
        self._last_recons = None

In [None]:
# Autoencoder over flattened 28x28 images with hidden sizes 64 and 32,
# wrapped for Lightning and trained for a single epoch with MLflow logging
# under the experiment "demo".
model = AutoEncoder(
    input_dim=28 * 28,
    hidden_sizes=[64, 32],
    activation='relu',
)
lit = LitWrapper(model)
logger = MLFlowLogger(experiment_name='demo')
trainer = pl.Trainer(logger=logger, max_epochs=1)
trainer.fit(lit, datamodule=dm)

In [None]:
# Per-step loss histories recorded by LitWrapper during fit().
print('Train losses:', lit.train_losses)
print('Val losses:', lit.val_losses)
# Show the input/reconstruction grid from the last validation epoch. The
# attribute can be missing or None when no validation batch ran, so guard
# instead of crashing. (`display` is provided by the IPython/Jupyter runtime.)
example_grid = getattr(lit, 'example_grid', None)
if example_grid is not None:
    display(torchvision.transforms.ToPILImage()(example_grid))
else:
    print('No example grid available (no validation epoch completed).')

In [None]:
# Generate 16 intrinsic-replay samples per class from the trained encoder and
# decoder, passing the training MLflow logger so the results are logged there.
# NOTE(review): Lightning may move the module back to CPU at the end of fit();
# confirm run_intrinsic_replay moves encoder/decoder onto `device` itself.
run_intrinsic_replay(
    dataloader=dm.train_dataloader(),
    encoder=model.encoder,
    decoder=model.decoder,
    n_samples_per_class=16,
    mlf_logger=logger,
    device=trainer.strategy.root_device,
)

In [None]:
# display the intrinsic replay images logged for class 0 and 1
from pathlib import Path
from urllib.parse import urlparse
from urllib.request import url2pathname

import matplotlib.pyplot as plt


def _local_artifact_dir(artifact_uri):
    """Convert an MLflow artifact location into a local filesystem ``Path``.

    MLflow's local file store reports artifact locations as ``file://`` URIs,
    which ``Path`` cannot consume directly (``Path('file:///x')`` is a bogus
    relative path). Scheme-less paths, including Windows drive-letter paths
    whose "scheme" is a single letter, are returned as-is.

    Raises:
        ValueError: if the artifacts live on a non-local store (s3://, ...).
    """
    parsed = urlparse(artifact_uri)
    if parsed.scheme == 'file':
        return Path(url2pathname(parsed.path))
    if not parsed.scheme or len(parsed.scheme) == 1:
        return Path(artifact_uri)
    raise ValueError(f'artifact store is not local: {artifact_uri!r}')


artifacts = _local_artifact_dir(
    logger.experiment.get_run(logger.run_id).info.artifact_uri
)
for cls in [0, 1]:
    img = plt.imread(artifacts / f'ir_replay/class_{cls}/ir_class_{cls}.png')
    plt.figure()
    plt.imshow(img)
    plt.axis('off')
    plt.title(f'IR samples for class {cls}')
    plt.show()