In [1]:
import os

import lightning.pytorch as pl
import torch
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

In [2]:
from lightning.pytorch.cli import LightningCLI

# Define Model Architecture

In [3]:
class LitAutoEncoder(pl.LightningModule):
    """Autoencoder that is trained to reconstruct its (flattened) input.

    Args:
        encoder: Module applied to the flattened input batch to produce a code.
        decoder: Module applied to that code to produce the reconstruction.
        lr: Learning rate passed to the Adam optimizer. Defaults to ``1e-3``,
            the value that was previously hard-coded.
    """

    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr

    def training_step(self, batch, batch_idx):
        """Compute and log the reconstruction loss for one batch.

        Args:
            batch: A ``(inputs, targets)`` pair; the targets are ignored
                because the network is trained against its own input.
            batch_idx: Index of the batch within the epoch (unused).

        Returns:
            The mean-squared error between the reconstruction and the
            flattened input.
        """
        x, _ = batch
        # Collapse every non-batch dimension so each sample is a flat vector.
        x = x.view(x.size(0), -1)

        z = self.encoder(x)
        x_hat = self.decoder(z)

        loss = nn.functional.mse_loss(x_hat, x)
        self.log("train_loss", loss)

        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``self.lr``."""
        return optim.Adam(self.parameters(), lr=self.lr)

In [4]:
# Seed the RNGs before any layer is created so that weight initialisation
# (and therefore the whole run) is repeatable after a kernel restart.
pl.seed_everything(42, workers=True)

# Layer sizes, named once so the encoder and decoder stay mirror images.
INPUT_DIM = 28 * 28  # length of one flattened input sample
HIDDEN_DIM = 64
LATENT_DIM = 3

encoder = nn.Sequential(
    nn.Linear(INPUT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, LATENT_DIM),
)
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, INPUT_DIM),
)

autoencoder = LitAutoEncoder(encoder, decoder)

# Get Data

In [5]:
# Fetch MNIST into the current working directory (downloading it if it is not
# already there) and convert every image to a tensor on access.
data_root = os.getcwd()
dataset = MNIST(root=data_root, transform=ToTensor(), download=True)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/train-images-idx3-ubyte.gz


100%|███████████████████████████████████████████████████████████████████| 9912422/9912422 [00:00<00:00, 37399328.26it/s]


Extracting /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/train-images-idx3-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/train-labels-idx1-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 20263582.11it/s]

Extracting /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/train-labels-idx1-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/t10k-images-idx3-ubyte.gz



100%|███████████████████████████████████████████████████████████████████| 1648877/1648877 [00:00<00:00, 12937273.81it/s]


Extracting /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/t10k-images-idx3-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 23606603.18it/s]


Extracting /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw/t10k-labels-idx1-ubyte.gz to /home/rishabh/cookiecutter-kaggle/{{ cookiecutter.project_name }}/experiments/notebooks/MNIST/raw



In [None]:
# DataLoader defaults to batch_size=1 and shuffle=False, which means one
# optimizer step per image in a fixed order. Use mini-batches and reshuffle
# every epoch instead.
BATCH_SIZE = 64
train_loader = utils.data.DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Training

In [None]:
# Train for at least 5 and at most 10 epochs; every other Trainer option is
# left at its default.
epoch_limits = {"min_epochs": 5, "max_epochs": 10}
trainer = pl.Trainer(**epoch_limits)

In [None]:
# Fit the autoencoder on the training loader (no validation loader is given).
trainer.fit(autoencoder, train_dataloaders=train_loader)

# Inference

In [None]:
# load checkpoint
# The checkpoint file name encodes the epoch and step at which it was written,
# so a hard-coded name breaks as soon as the training settings change. Ask the
# trainer which file its checkpoint callback actually saved.
checkpoint = trainer.checkpoint_callback.best_model_path
if not checkpoint:
    raise FileNotFoundError("No checkpoint was saved; run the training cell before inference.")
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
# `Tensor(4, 28 * 28)` only allocates memory without initialising it, so the
# "images" could contain arbitrary values (including NaN/inf). Draw uniform
# random pixels instead, on the same device as the loaded model.
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
with torch.no_grad():  # inference only: do not build an autograd graph
    embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)