https://lightning.ai/docs/pytorch/stable/starter/introduction.html

## 1: Install PyTorch Lightning

In [6]:
# ! pip install lightning

## 2: Define a LightningModule

In [7]:
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torch

In [2]:
# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines one step of the training loop;
        # it is independent of forward()
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=1e-3)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)
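
`training_step` and `configure_optimizers` are hooks that the `Trainer` calls; the loop itself lives inside `trainer.fit`. As a heavily simplified sketch (not Lightning's actual implementation, which also handles devices, logging, checkpointing and callbacks), the core of that automation resembles the plain PyTorch loop below. The modules are fresh copies of the same architecture, the random tensors stand in for MNIST batches, and `plain_training_step`, `simple_fit` and `toy_batches` are hypothetical names.

```python
import torch
from torch import nn, optim

torch.manual_seed(0)  # reproducible toy run

# Separate modules with the same architecture as encoder/decoder above.
plain_encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
plain_decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


def plain_training_step(batch):
    # Same computation as LitAutoEncoder.training_step: flatten, encode, decode, MSE.
    x, _ = batch
    x = x.view(x.size(0), -1)
    x_hat = plain_decoder(plain_encoder(x))
    return nn.functional.mse_loss(x_hat, x)


def simple_fit(batches, max_epochs):
    # Hypothetical stand-in for trainer.fit with automatic optimization.
    params = list(plain_encoder.parameters()) + list(plain_decoder.parameters())
    optimizer = optim.Adam(params, lr=1e-3)  # what configure_optimizers returns
    losses = []
    for _ in range(max_epochs):
        for batch in batches:
            loss = plain_training_step(batch)
            optimizer.zero_grad()  # clear gradients from the previous step
            loss.backward()        # backpropagate the reconstruction loss
            optimizer.step()       # update encoder and decoder weights
            losses.append(loss.item())
    return losses


# Toy data: 5 batches of 8 random 1x28x28 "images" with dummy labels.
toy_batches = [(torch.rand(8, 1, 28, 28), torch.zeros(8, dtype=torch.long)) for _ in range(5)]
losses = simple_fit(toy_batches, max_epochs=10)
print(f"{len(losses)} steps, loss {losses[0]:.4f} -> {losses[-1]:.4f}")
```

Under automatic optimization the `Trainer` performs a roughly equivalent zero-grad/backward/step sequence for each batch, which is why `training_step` only has to return the loss.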

## 3: Define a dataset

In [3]:
# setup data
dataset = MNIST(os.getcwd(), download=True, transform=ToTensor())
train_loader = utils.data.DataLoader(dataset)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\train-images-idx3-ubyte.gz
100%|██████████| 9912422/9912422 [00:01<00:00, 9871090.69it/s]
Extracting c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\train-images-idx3-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\train-labels-idx1-ubyte.gz
100%|██████████| 28881/28881 [00:00<00:00, 24368475.93it/s]
Extracting c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\train-labels-idx1-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\t10k-images-idx3-ubyte.gz
100%|██████████| 1648877/1648877 [00:00<00:00, 15365198.91it/s]
Extracting c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\t10k-images-idx3-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\t10k-labels-idx1-ubyte.gz
100%|██████████| 4542/4542 [00:00<00:00, 590677.44it/s]
Extracting c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw\t10k-labels-idx1-ubyte.gz to c:\Users\thoma\OneDrive\AUSBILDUNG\MCI\Sem 6 - LV\Bachelorarbeit\Code\MNIST\raw
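
`utils.data.DataLoader(dataset)` relies on the default `batch_size=1`, so each training step sees a single image. The following sketch uses random tensors shaped like MNIST samples instead of the downloaded data (an assumption made to keep it self-contained; `toy_dataset` and the loader names are hypothetical) to show how the batch size changes the number of batches per epoch and the shapes that reach `training_step`, including how `x.view(x.size(0), -1)` flattens each 1x28x28 image into 784 values.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Stand-in for MNIST: 100 random 1x28x28 "images" with integer labels 0-9.
toy_dataset = TensorDataset(torch.rand(100, 1, 28, 28), torch.randint(0, 10, (100,)))

default_loader = DataLoader(toy_dataset)                 # batch_size defaults to 1
batched_loader = DataLoader(toy_dataset, batch_size=32)  # 3 full batches + 1 batch of 4

x, y = next(iter(batched_loader))
flat = x.view(x.size(0), -1)  # the same flattening used in training_step

print(len(default_loader), len(batched_loader))           # 100 4
print(tuple(x.shape), tuple(flat.shape), tuple(y.shape))  # (32, 1, 28, 28) (32, 784) (32,)
```

For the real MNIST training split of 60,000 images, the default therefore yields 60,000 single-image batches per epoch, while `batch_size=32` would yield 1,875.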

## 4: Train the model

In [5]:
# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
trainer = L.Trainer(limit_train_batches=100, max_epochs=10)
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)

`Trainer.fit` stopped: `max_epochs=10` reached.
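
The parameter counts in the summary follow directly from the layer shapes: `nn.Linear(in_features, out_features)` holds `in_features * out_features` weights plus `out_features` biases, and `nn.ReLU` has no parameters. The hypothetical helper `linear_params` below reproduces the reported numbers; the size estimate assumes 4-byte `float32` parameters and megabytes of 10^6 bytes, which matches the 0.407 MB shown.

```python
def linear_params(in_features: int, out_features: int) -> int:
    # Weight matrix plus bias vector of one nn.Linear layer.
    return in_features * out_features + out_features


encoder_params = linear_params(28 * 28, 64) + linear_params(64, 3)
decoder_params = linear_params(3, 64) + linear_params(64, 28 * 28)
total = encoder_params + decoder_params

# float32 parameters take 4 bytes each; 1 MB is taken as 1e6 bytes here.
size_mb = total * 4 / 1e6

print(encoder_params, decoder_params, total, f"{size_mb:.3f} MB")  # 50435 51216 101651 0.407 MB
```

50,435 and 51,216 round to the 50.4 K and 51.2 K listed for `encoder` and `decoder`.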


## 5: Use the model
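
By default the `Trainer` keeps a single checkpoint from the most recent training epoch and names it after the zero-based epoch index and the global step count. With `limit_train_batches=100`, each epoch contributes 100 optimizer steps, so a file named `epoch=0-step=100.ckpt` comes from a one-epoch run, while the 10-epoch run in step 4 ends with `epoch=9-step=1000.ckpt` (inside whichever `lightning_logs/version_N` folder that run was assigned). The helpers below are hypothetical and only reproduce this naming pattern, assuming one optimizer step per batch and no gradient accumulation.

```python
def default_checkpoint_name(epoch: int, global_step: int) -> str:
    # Mirrors the "epoch=...-step=..." pattern of Lightning's default checkpoint filenames.
    return f"epoch={epoch}-step={global_step}.ckpt"


def final_checkpoint_name(max_epochs: int, batches_per_epoch: int) -> str:
    # Assumes one optimizer step per batch (no gradient accumulation).
    last_epoch = max_epochs - 1                   # epochs are counted from 0
    global_step = max_epochs * batches_per_epoch  # steps accumulate across epochs
    return default_checkpoint_name(last_epoch, global_step)


print(final_checkpoint_name(max_epochs=1, batches_per_epoch=100))   # epoch=0-step=100.ckpt
print(final_checkpoint_name(max_epochs=10, batches_per_epoch=100))  # epoch=9-step=1000.ckpt
```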

In [8]:
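# load the checkpoint written by the Trainer above (a hard-coded path such as
# "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt" only matches
# a 1-epoch run logged as version_0)
checkpoint = trainer.checkpoint_callback.best_model_path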
autoencoder = LitAutoEncoder.load_from_checkpoint(checkpoint, encoder=encoder, decoder=decoder)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡ 
Predictions (4 image embeddings):
 tensor([[ 0.3524, -0.0387, -0.3097],
        [ 0.5647, -0.1102, -0.4769],
        [ 0.4923, -0.0377, -0.3691],
        [ 0.4778, -0.0473, -0.3776]], grad_fn=<AddmmBackward0>) 
 ⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡⚡
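
The printed embeddings carry `grad_fn=<AddmmBackward0>` because autograd still records the forward pass after `encoder.eval()`; `eval()` only switches layers such as dropout and batch norm to their inference behavior. For pure inference, wrapping the call in `torch.no_grad()` skips building the graph. The sketch below uses a freshly initialized `demo_encoder` with the same architecture (an assumption so it runs without a checkpoint), so its values are not those of the trained model.

```python
import torch
from torch import nn

# Untrained encoder with the same architecture as in step 2.
demo_encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
demo_encoder.eval()  # inference behavior for dropout/batch-norm layers (none here)

fake_image_batch = torch.rand(4, 28 * 28)

tracked = demo_encoder(fake_image_batch)        # autograd records this call
with torch.no_grad():
    untracked = demo_encoder(fake_image_batch)  # no graph is built

print(type(tracked.grad_fn).__name__)      # AddmmBackward0
print(untracked.grad_fn)                   # None
print(tuple(untracked.shape))              # (4, 3)
print(torch.allclose(tracked, untracked))  # True: same values either way
```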
