In [None]:
# `%pip` installs into the kernel's own environment; package names are
# space-separated (a trailing comma makes pip reject "tensorboard," as invalid).
# torchvision and matplotlib are also used below — install them too if missing.
%pip install tensorboard lightning

# Define a Lightning model
A `LightningModule` organizes your PyTorch `nn.Module`s and defines how they interact during training in `training_step` (the optional `validation_step` and `test_step` hooks do the same for validation and testing).

In [None]:
import os

import lightning.pytorch as pl
import torch
import torch.nn.functional as F
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor

# Plain PyTorch modules (swap in your own if you have them):
# the encoder squeezes a flattened 28x28 image down to 3 numbers,
# and the decoder mirrors it back up to 28 * 28 values.
encoder = nn.Sequential(
    nn.Linear(28 * 28, 64),
    nn.ReLU(),
    nn.Linear(64, 3),
)
decoder = nn.Sequential(
    nn.Linear(3, 64),
    nn.ReLU(),
    nn.Linear(64, 28 * 28),
)


# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    """Autoencoder trained to reconstruct its (flattened) input with an MSE loss.

    Args:
        encoder: module mapping a flattened image batch to embeddings.
        decoder: module mapping embeddings back to flattened images.
        lr: learning rate for the Adam optimizer (default ``1e-3``).
    """

    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr

    def forward(self, x):
        """Encode then decode ``x``; return ``(reconstruction, embedding)``."""
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat, z

    def _reconstruction_loss(self, batch):
        """Return the MSE between a flattened image batch and its reconstruction.

        ``batch`` is an ``(images, labels)`` pair; labels are ignored because
        the reconstruction target is the input itself.
        """
        x, _ = batch
        # flatten everything but the batch axis, e.g. (B, 1, 28, 28) -> (B, 784) for MNIST
        x = x.view(x.size(0), -1)
        x_hat, _ = self(x)
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        # One iteration of the train loop; the reconstruction goes through
        # forward (via self(...)), shared with validation_step.
        loss = self._reconstruction_loss(batch)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop: same loss, logged under its own key
        val_loss = self._reconstruction_loss(batch)
        self.log("val_loss", val_loss)

    def configure_optimizers(self):
        return optim.Adam(self.parameters(), lr=self.lr)


# wrap the plain encoder/decoder modules in the LightningModule
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)

# Define a dataset

Lightning accepts any iterable (a `DataLoader`, NumPy arrays, etc.) for the train/val/test/predict splits. Below, the MNIST test split doubles as the validation set.

In [None]:
# setup data: the MNIST training split for fitting, and the official test
# split reused as the validation set
train_set = MNIST(root="MNIST", train=True, download=True, transform=ToTensor())
test_set = MNIST(root="MNIST", train=False, download=True, transform=ToTensor())

# No shuffling: later cells pick example digits from the first batch by index,
# which only works while the loader order is fixed.
train_loader = utils.data.DataLoader(dataset=train_set, batch_size=32)
val_loader = utils.data.DataLoader(dataset=test_set, batch_size=32)

In [None]:
# display the training dataset object
train_set

In [None]:
# display the training DataLoader object
train_loader

In [None]:
# Grab just the first (images, labels) batch. The loader is not shuffled, so
# this is always the same 32 training examples.
first_batch = next(iter(train_loader))

In [None]:
# shape of the image tensor in the first batch — presumably (32, 1, 28, 28), since batch_size=32; confirm in the output
first_batch[0].shape

In [None]:
# digit labels of the first batch (first_batch is an (images, labels) pair)
first_batch[1]

In [None]:
# raw pixel values of the first image in the batch
first_batch[0][0]

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

# Each image is channel-first with a single channel; drop that axis so imshow
# gets a 2-D array, render in grayscale, and title it with the true label.
plt.imshow(first_batch[0][0].squeeze(0), cmap="gray")
plt.title(f"label: {first_batch[1][0].item()}");

In [None]:
# image 11 of the first batch, as a 2-D grayscale plot titled with its label
plt.imshow(first_batch[0][11].squeeze(0), cmap="gray")
plt.title(f"label: {first_batch[1][11].item()}");

In [None]:
# image 1 of the first batch, as a 2-D grayscale plot titled with its label
plt.imshow(first_batch[0][1].squeeze(0), cmap="gray")
plt.title(f"label: {first_batch[1][1].item()}");

# Train the model

The Lightning Trainer “mixes” any LightningModule with any dataset and abstracts away all the engineering complexity needed for scale.

In [None]:
# Train the model. These Trainer arguments keep iteration fast:
# limit_train_batches caps every epoch at 100 training batches.
trainer = pl.Trainer(
    max_epochs=50,
    limit_train_batches=100,
)
trainer.fit(autoencoder, train_dataloaders=train_loader, val_dataloaders=val_loader)

# Use the model
Once you’ve trained the model, you can export it to ONNX or TorchScript and put it into production, or simply load the weights and run predictions.

In [None]:
# Load the checkpoint written by the training run above. A hard-coded
# lightning_logs/version_N/... path breaks as soon as training is re-run,
# because every run gets a new version directory.
# map_location="cpu" keeps the weights on the same device as the CPU inputs below.
checkpoint = trainer.checkpoint_callback.best_model_path
autoencoder = LitAutoEncoder.load_from_checkpoint(
    checkpoint, map_location="cpu", encoder=encoder, decoder=decoder
)

# choose your trained nn.Module
encoder = autoencoder.encoder
decoder = autoencoder.decoder
encoder.eval()
decoder.eval()

# Inference only, so skip building an autograd graph.
with torch.no_grad():
    # embed 4 random "images" (uniform noise in [0, 1))
    fake_image_batch = torch.rand(4, 28 * 28).to("cpu")
    embeddings = encoder(fake_image_batch)
    print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

    # and decode these embeddings
    reconstructed = decoder(embeddings)

plt.imshow(reconstructed[0].reshape(28, 28).cpu().numpy(), cmap="gray");

In [None]:
# the random input whose reconstruction is shown above, for comparison
plt.imshow(fake_image_batch[0].reshape(28, 28), cmap="gray");

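# Play with the embeddings

The encoder compresses each image into a 3-dimensional embedding, so we can add and subtract embeddings and decode the result back into an image.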

In [None]:
# labels of the first batch again, used to pick example digits by index below
first_batch[1]

In [None]:
# Pick example digits from the (unshuffled) first training batch by index.
# The names assume labels [5, 5, 0] at indices [0, 11, 1]. Assert this so a
# reordered loader fails loudly instead of silently mislabelling the plots below.
assert first_batch[1][[0, 11, 1]].tolist() == [5, 5, 0], first_batch[1]
five_1 = first_batch[0][0].reshape(1, 28 * 28).to('cpu')
five_2 = first_batch[0][11].reshape(1, 28 * 28).to('cpu')
zero = first_batch[0][1].reshape(1, 28 * 28).to('cpu')

In [None]:
# Embedding arithmetic: shift the "zero" embedding by the difference between
# two "five" embeddings, then decode the result back to an image.
with torch.no_grad():
    shifted_zero = decoder(encoder(zero) + (encoder(five_1) - encoder(five_2)))
plt.imshow(shifted_zero[0].reshape(28, 28).cpu().numpy(), cmap="gray")
plt.title("decode(zero + (five_1 - five_2))");

In [None]:
# decode the difference between a "five" embedding and the "zero" embedding
with torch.no_grad():
    five_minus_zero = decoder(encoder(five_1) - encoder(zero))
plt.imshow(five_minus_zero[0].reshape(28, 28).cpu().numpy(), cmap="gray")
plt.title("decode(five_1 - zero)");

# Visualize training
If you have TensorBoard installed, you can use it to visualize experiments.

In [None]:
%load_ext tensorboard
# Point at the relative lightning_logs directory the Trainer writes to, so every
# run (version_0, version_1, ...) is listed. This avoids a hard-coded absolute
# home path and a single version that may not match the checkpoint loaded above.
%tensorboard --logdir lightning_logs --bind_all