# An example of autoencoder (AE) with pytorch lightning

In [1]:
# library
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl

In [2]:
# model: 784 -> 64 -> 3 -> 64 -> 784
encoder = nn.Sequential(nn.Linear(28 * 28, 64), 
                        nn.ReLU(), 
                        nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), 
                        nn.ReLU(), 
                        nn.Linear(64, 28 * 28))

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Include extra logging here
        self.log('train_loss', loss)
        print(f'train_loss: {loss}')
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=0.001)
        return optimizer

ae = LitAutoEncoder(encoder, decoder)


In [3]:
# data
dataset = MNIST(root='./data', download=True, transform=ToTensor())
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# trainer
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model=ae, train_dataloaders=train_loader)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz


100%|██████████| 9912422/9912422 [00:14<00:00, 678687.25it/s]


Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz


100%|██████████| 28881/28881 [00:00<00:00, 117479.16it/s]


Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz


100%|██████████| 1648877/1648877 [00:02<00:00, 668030.19it/s]


Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz


100%|██████████| 4542/4542 [00:00<00:00, 2315892.14it/s]
Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.


Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw



GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
Missing logger folder: /mnt/nas/augix/sandbox/mnist_ae/lightning_logs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
/home/augix/miniconda/envs/mamba/lib/python3.11/site-pack

Training: |          | 0/? [00:00<?, ?it/s]

train_loss: 0.12751959264278412
train_loss: 0.14317581057548523
train_loss: 0.11914743483066559
train_loss: 0.13077330589294434
train_loss: 0.1213809996843338
train_loss: 0.12607617676258087
train_loss: 0.11274271458387375
train_loss: 0.10770411789417267
train_loss: 0.10215792804956436
train_loss: 0.09805278480052948
train_loss: 0.10442468523979187
train_loss: 0.10403739660978317
train_loss: 0.09111207723617554
train_loss: 0.10049887746572495
train_loss: 0.097576804459095
train_loss: 0.10033366084098816
train_loss: 0.09294869005680084
train_loss: 0.09833518415689468
train_loss: 0.09394189715385437
train_loss: 0.09700331091880798
train_loss: 0.09154554456472397
train_loss: 0.08839244395494461
train_loss: 0.0779370367527008
train_loss: 0.08739665895700455
train_loss: 0.0886666476726532
train_loss: 0.08363693207502365
train_loss: 0.0731501430273056
train_loss: 0.07998725771903992
train_loss: 0.07254442572593689
train_loss: 0.07517071068286896
train_loss: 0.07385583221912384
train_loss: 0.

`Trainer.fit` stopped: `max_epochs=1` reached.


train_loss: 0.042352091521024704
train_loss: 0.04379018396139145
train_loss: 0.04103630408644676
train_loss: 0.041470468044281006
train_loss: 0.040585920214653015
train_loss: 0.0384526252746582
train_loss: 0.03974176198244095
train_loss: 0.03917107358574867
train_loss: 0.04524951055645943
train_loss: 0.04062916338443756
train_loss: 0.043837886303663254
train_loss: 0.044590551406145096
train_loss: 0.04531894251704216
train_loss: 0.042006928473711014
train_loss: 0.03977338224649429
train_loss: 0.04366181790828705
train_loss: 0.047427307814359665
train_loss: 0.046864818781614304
train_loss: 0.04250354692339897
train_loss: 0.039451614022254944
train_loss: 0.047011759132146835
