In [6]:
from notebooks.bootstrap_notebooks import bootstrap_notebook, load_dataset, train, test
from pl_bolts.models.autoencoders import VAE
from src.base.training.models.aes_utils import get_enc_dec, preview_dims
from torch.nn import Flatten, Unflatten, Sequential, Linear

In [7]:
EXP_NAME = "vae_nb"

# Settings handed to the project's notebook bootstrap helper.
# NOTE(review): "logs" and "training_out" are absolute paths on one specific
# machine; anyone else running this notebook has to adjust them first.
config = dict(
    exp_name=EXP_NAME,
    enable_gpu=True,
    training_n_dev=1,
    logs="C:\\Users\\micdu\\Code\\pythonProject\\dmtl\\notebooks\\logs",
    training_out="C:\\Users\\micdu\\Code\\pythonProject\\dmtl\\notebooks\\lightning_data",
    tracking_uri="http://localhost:5000",
)
bootstrap_notebook(config)

In [8]:
# Select the dataset used by the rest of the notebook: MNIST in batches of 200.
# NOTE(review): data_path is an absolute path on one specific machine.
load_dataset(
    dict(
        dataset="mnist",
        batch_size=200,
        data_path="C:\\Users\\micdu\\Code\\pythonProject\\dmtl\\data",
    )
)

In [15]:
class LenetVAE(VAE):
    """VAE whose encoder/decoder are built from a two-stage convolutional spec.

    The parent ``VAE`` is initialised first, then its ``encoder`` and
    ``decoder`` attributes are replaced by the pair returned from
    ``get_enc_dec`` for the layer specification below.

    Args:
        input_height: Side length of the (square) input image.
        kl_coeff: Forwarded unchanged to ``VAE.__init__``.
        latent_dim: Size of the latent vector; also the input width of the
            decoder's first linear layer.
        lr: Forwarded unchanged to ``VAE.__init__``.
    """

    def __init__(
        self,
        input_height: int = 28,
        kl_coeff: float = 0.1,
        latent_dim: int = 200,
        lr: float = 1e-4,
    ):
        first_stage = {
            "in_channels": 1,
            "out_channels": 8,
            "padding": 2,
            "kernel_size": 5,
            "scale_factor": 0.5,
        }
        second_stage = {
            "in_channels": 8,
            "out_channels": 16,
            "padding": 0,
            "kernel_size": 5,
            "scale_factor": 0.5,
        }
        conv_specs = [first_stage, second_stage]

        # Shape of the encoder's output for a square input of this size; its
        # product is the flattened width the parent needs as enc_out_dim.
        n_channels, feat_h, feat_w = preview_dims((input_height, input_height), conv_specs)
        flat_dim = n_channels * feat_h * feat_w

        super(LenetVAE, self).__init__(
            input_height=input_height,
            enc_out_dim=flat_dim,
            kl_coeff=kl_coeff,
            latent_dim=latent_dim,
            lr=lr,
        )

        # Swap the parent's default networks for the ones built from conv_specs.
        self.encoder, conv_decoder = get_enc_dec(conv_specs)
        # Flatten the encoder output to (batch, features).
        self.encoder.append(Flatten(start_dim=1))

        # Decoder: latent vector -> flat features -> feature map -> conv decoder.
        self.decoder = Sequential(
            Linear(latent_dim, flat_dim),
            Unflatten(1, (n_channels, feat_h, feat_w)),
            conv_decoder,
        )

In [14]:
# Build the VAE with its default hyperparameters and pass it to the project's
# training helper with epochs=50.
model = LenetVAE()
train(model, epochs=50)

  super(LenetVAE, self).__init__(
Auto select gpus: [0]
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 3.4 K 
1 | decoder | Sequential | 83.8 K
2 | fc_mu   | Linear     | 80.2 K
3 | fc_var  | Linear     | 80.2 K
---------------------------------------
247 K     Trainable params
0         Non-trainable params
247 K     Total params
0.991     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


<src.base.training.models.experiment.Experiment at 0x1cc3b418850>

In [None]:
# Path of a saved checkpoint, relative to the current working directory.
checkpoint = "./lightning_data/checkpoints/6/5ac72ec78b5f4835a2c46599e50e1f92/checkpoints/epoch=27-step=44632.ckpt"
# Checkpoint loading is disabled; the in-memory `model` from the training cell
# is used instead, so that cell must have been run in this kernel.
# NOTE(review): if re-enabled, the line below would discard its result.
# Lightning's load_from_checkpoint returns the restored model, so the intended
# form is presumably `loaded_model = LenetVAE.load_from_checkpoint(checkpoint)`
# - confirm against the installed Lightning version.
#LenetVAE().load_from_checkpoint(checkpoint)
loaded_model = model

In [None]:
from matplotlib.pyplot import imshow, figure
from torchvision.utils import make_grid
import torch

figure(figsize=(8, 3), dpi=300)

num_preds = 16

# Width of the feature vector that fc_mu / fc_var accept. It is read from the
# layer instead of being hard-coded: the training summary reports 80.2 K
# parameters for fc_mu, i.e. 400 inputs x 200 latents + 200 biases for the
# default latent_dim, so a fixed width of 1600 raises a shape-mismatch error.
enc_feat_dim = loaded_model.fc_mu.in_features

# Pure inference: no autograd graph is needed for any of these steps.
with torch.no_grad():
    # Create the random features on the model's device so the linear layers
    # never receive a CPU tensor while their weights live on the GPU.
    x = torch.rand((num_preds, enc_feat_dim), device=loaded_model.device)
    mu = loaded_model.fc_mu(x)
    log_var = loaded_model.fc_var(x)
    p, q, z = loaded_model.sample(mu, log_var)
    # Decode the sampled latents and bring the images back to the CPU for plotting.
    pred = loaded_model.decoder(z.to(loaded_model.device)).cpu()

# make_grid tiles the batch into one (C, H, W) image; imshow expects (H, W, C).
img = make_grid(pred).permute(1, 2, 0).numpy()

# PLOT IMAGES
imshow(img);