In [1]:
import sys
# Put the directory three levels above the working directory first on the
# import path, so packages located there take precedence over installed ones
# (presumably the local deeplay checkout - TODO confirm).
sys.path.insert(0, "../../../")

In [2]:
import deeplay as dtm
import torchvision.transforms as transforms

import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl

In [3]:
# Load the MNIST train and test splits (stored under ./data, downloaded if missing).
# ToTensor() converts each image to a float tensor scaled to [0, 1].
# NOTE(review): this comment previously said [-1, 1]; that range would additionally
# need a Normalize((0.5,), (0.5,)) step - confirm which range is intended.

mnist = torchvision.datasets.MNIST(
    root="data", train=True, download=True, transform=torchvision.transforms.ToTensor()
)

mnist_test = torchvision.datasets.MNIST(
    root="data", train=False, download=True, transform=torchvision.transforms.ToTensor()
)

In [4]:
# Training loader: reshuffle the samples every epoch so the model does not see
# the batches in the same fixed order each time.
mnist_dataloader = torch.utils.data.DataLoader(
    mnist, batch_size=32, shuffle=True, num_workers=4
)
# Evaluation loader: keep the dataset order so results are reproducible.
mnist_test_dataloader = torch.utils.data.DataLoader(
    mnist_test, batch_size=32, shuffle=False, num_workers=4
)

In [5]:

def generate_examples(model, latent_range, n=5):
    """Decode a regular ``n`` x ``n`` grid of 2-D latent points and plot the results.

    Args:
        model: Object exposing ``decode(z)``. It is called once with a float
            tensor of shape ``(n * n, 2)`` and must return a tensor whose
            entries along the first axis are 3-axis images (they are
            transposed with ``(1, 2, 0)`` before plotting).
        latent_range: Pair ``(lower, upper)``; ``lower[k]`` and ``upper[k]``
            are the grid bounds for latent dimension ``k`` (k = 0, 1).
        n: Number of grid points per latent dimension.

    Returns:
        None. The images are drawn into a new 10 x 10 inch matplotlib figure.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torchvision
    import torch

    plt.figure(figsize=(10, 10))

    lower, upper = latent_range[0], latent_range[1]
    axis_0 = np.linspace(lower[0], upper[0], n)
    axis_1 = np.linspace(lower[1], upper[1], n)

    # meshgrid varies the first latent dimension along columns and the second
    # along rows; flattening gives one (dim0, dim1) pair per grid cell.
    grid_0, grid_1 = np.meshgrid(axis_0, axis_1)
    latent_points = torch.stack(
        [torch.tensor(grid_0).float(), torch.tensor(grid_1).float()],
        dim=-1,
    ).view(-1, 2)

    decoded = model.decode(latent_points).detach().cpu().numpy()

    for cell in range(n * n):
        plt.subplot(n, n, cell + 1)
        image = decoded[cell]
        # Move the leading axis last so the image is laid out as imshow expects.
        plt.imshow(image.transpose((1, 2, 0)), cmap="Greys_r")
        plt.axis("off")

    plt.tight_layout()


    

In [6]:


def scatterplot(model, testset):
    """Encode every batch of ``testset`` and scatter-plot the 2-D latent codes.

    Each point is one sample, positioned by its first two latent coordinates
    and colored by its label.

    Args:
        model: Object exposing ``encode(x)``. It is called with each batch
            reshaped to ``(-1, 1, 28, 28)`` and must return a tensor with one
            row per sample and (at least) two columns.
        testset: Iterable yielding ``(x, y)`` batches of inputs and labels,
            e.g. a ``DataLoader``. Batches may have different sizes.

    Returns:
        Tuple ``(lower, upper)`` of tensors holding the per-dimension minimum
        and maximum of all latent codes.

    Raises:
        ValueError: If ``testset`` yields no batches.
    """
    import matplotlib.pyplot as plt
    import torch

    # Collect per-batch results and concatenate once at the end. This keeps
    # every sample exactly once regardless of batch size (a preallocated,
    # zero-filled buffer would leave spurious zero rows and needs a running
    # write offset).
    latent_batches = []
    label_batches = []
    for x, y in testset:
        codes = model.encode(x.view(-1, 1, 28, 28)).detach().cpu().float()
        latent_batches.append(codes)
        label_batches.append(y.detach().cpu().reshape(-1).float())

    if not latent_batches:
        raise ValueError("scatterplot() received an empty testset; nothing to encode")

    xy = torch.cat(latent_batches, dim=0)
    test_classes = torch.cat(label_batches, dim=0)

    latent_range = (
        xy.min(dim=0)[0],
        xy.max(dim=0)[0],
    )

    plt.figure(figsize=(10, 10))
    plt.scatter(
        xy[:, 0].numpy(),
        xy[:, 1].numpy(),
        c=test_classes.numpy(),
        cmap="jet",
    )
    plt.colorbar()
    plt.tight_layout()
    plt.show()

    return latent_range

In [8]:
# Build the autoencoder with its default configuration.
autoencoder = dtm.Autoencoder()
# Run one forward pass on a random batch of one 1x28x28 image
# (presumably so lazily-built layers get their shapes before the model is
# displayed and trained - TODO confirm).
autoencoder(torch.randn(1, 1, 28, 28))
# Bare last expression: the notebook displays the module structure.
autoencoder

Autoencoder(
  (encoder): ImageToImageEncoder(
    (blocks): ModuleList(
      (0): Template(
        (layer): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (activation): ReLU()
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): Template(
        (layer): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (activation): ReLU()
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (output): Identity()
  )
  (bottleneck): Bottleneck(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (layer): Linear(in_features=784, out_features=2, bias=True)
    (activation): Identity()
  )
  (decoder): SpatialBroadcastDecoder2d(
    (input): Identity()
    (encoding): PositionalEncodingLinear2d()
    (blocks): ModuleList(
      (0): Template(
        (layer): Conv2d(4, 128, kernel_size=(1, 1), stride=(1, 1))
        (activation): ReLU()
      

In [7]:
# Train for 10 epochs. accelerator="auto" lets Lightning use a GPU when one is
# available and fall back to the CPU otherwise, so this cell also runs on
# machines without CUDA.
trainer = pl.Trainer(max_epochs=10, accelerator="auto")
# NOTE(review): the MNIST test split is also used as the validation loader here.
trainer.fit(autoencoder, mnist_dataloader, mnist_test_dataloader)
# No model is passed, so Lightning chooses which weights from the preceding
# fit to evaluate (see the Trainer.test documentation).
trainer.test(dataloaders=mnist_test_dataloader)

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


VariationalAutoencoder(
  (encoder): ImageToImageEncoder(
    (blocks): ModuleList(
      (0): Template(
        (layer): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (activation): ReLU()
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
      (1): Template(
        (layer): Conv2d(8, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (activation): ReLU()
        (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      )
    )
    (output): Identity()
  )
  (bottleneck): VariationalBottleneck(
    (flatten): Flatten(start_dim=1, end_dim=-1)
    (layer): Linear(in_features=784, out_features=4, bias=True)
    (samplers): ModuleList(
      (0): NormalDistribtionSampler(
        (loc): Identity()
      )
      (1): NormalDistribtionSampler(
        (loc): Identity()
      )
    )
    (activation): Identity()
  )
  (decoder): SpatialBroadcastDecoder2d(
    (input): Identit

In [None]:
# Plot the latent codes of the test set, then decode a 10 x 10 grid of points
# spanning the observed latent range.
latent_range = scatterplot(model=autoencoder, testset=mnist_test_dataloader)
generate_examples(model=autoencoder, latent_range=latent_range, n=10)