# An example of autoencoder (AE) with pytorch lightning

In [1]:
# library
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch import Tensor
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import pytorch_lightning as pl

In [2]:
# model: 784 -> 64 -> 3 -> 64 -> 784
encoder = nn.Sequential(nn.Linear(28 * 28, 64), 
                        nn.ReLU(), 
                        nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), 
                        nn.ReLU(), 
                        nn.Linear(64, 28 * 28))

class LitAutoEncoder(pl.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
    
    def training_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        # Include extra logging here
        self.log('train_loss', loss)
        print(f'train_loss: {loss}')
        return loss
    
    def configure_optimizers(self):
        optimizer = optim.AdamW(self.parameters(), lr=0.001)
        return optimizer

ae = LitAutoEncoder(encoder, decoder)


In [3]:
# data
dataset = MNIST(root='./data', download=True, transform=ToTensor())
train_loader = DataLoader(dataset, batch_size=32, shuffle=True)

# trainer
trainer = pl.Trainer(max_epochs=1)
trainer.fit(model=ae, train_dataloaders=train_loader)

Trainer will use only 1 of 2 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=2)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
-----------------

Training: |          | 0/? [00:00<?, ?it/s]

train_loss: 0.12982700765132904
train_loss: 0.14043591916561127
train_loss: 0.13210667669773102
train_loss: 0.11799132823944092
train_loss: 0.12972961366176605
train_loss: 0.1236516609787941
train_loss: 0.12556129693984985
train_loss: 0.1233401745557785
train_loss: 0.10441011190414429
train_loss: 0.11241642385721207
train_loss: 0.110621877014637
train_loss: 0.10509290546178818
train_loss: 0.10522322356700897
train_loss: 0.09652478992938995
train_loss: 0.09180828183889389
train_loss: 0.10036342591047287
train_loss: 0.08846341073513031
train_loss: 0.09337251633405685
train_loss: 0.0965307354927063
train_loss: 0.0909583568572998
train_loss: 0.09523572772741318
train_loss: 0.09019909799098969
train_loss: 0.08766265958547592
train_loss: 0.08130418509244919
train_loss: 0.08146239817142487
train_loss: 0.07984609156847
train_loss: 0.08516711741685867
train_loss: 0.0824030190706253
train_loss: 0.081952765583992
train_loss: 0.0801665335893631
train_loss: 0.07392263412475586
train_loss: 0.0758946

`Trainer.fit` stopped: `max_epochs=1` reached.


train_loss: 0.03918258100748062
train_loss: 0.048604048788547516
train_loss: 0.03833622485399246
train_loss: 0.04008423537015915
train_loss: 0.040796250104904175
train_loss: 0.041862428188323975
train_loss: 0.039605725556612015
train_loss: 0.04285390302538872
train_loss: 0.04430767893791199
train_loss: 0.04532242566347122
train_loss: 0.04274338111281395
train_loss: 0.03967258334159851
train_loss: 0.039830755442380905
train_loss: 0.03838515281677246
train_loss: 0.04989895597100258
train_loss: 0.04129233583807945
train_loss: 0.04614856839179993
