In [1]:
import os
from torch import optim, nn, utils, Tensor
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
import lightning as L
import torch

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))
decoder = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))


# define the LightningModule
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder, lr=1e-3):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        # it is independent of forward
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        # Logging to TensorBoard (if installed) by default
        self.log(
            "train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = nn.functional.mse_loss(x_hat, x)
        self.log(
            "val_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True
        )
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.parameters(), lr=self.lr)
        return optimizer


# init the autoencoder
autoencoder = LitAutoEncoder(encoder, decoder)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# setup data


from lightning import LightningDataModule
from torch.utils.data import DataLoader


train_dataset = MNIST(os.getcwd(), train=True, download=True, transform=ToTensor())
val_dataset = MNIST(os.getcwd(), train=False, download=True, transform=ToTensor())


class LitDataModule(LightningDataModule):
    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size

    def train_dataloader(self):
        return DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            persistent_workers=True,
            num_workers=16,
            pin_memory=True,
            shuffle=True,
        )

    def val_dataloader(self):
        return DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            persistent_workers=True,
            num_workers=16,
            pin_memory=True,
            shuffle=False,
        )


datamodule = LitDataModule(batch_size=2048)
print(datamodule.batch_size)

2048


In [7]:
from lightning.pytorch.tuner import Tuner
from lightning.pytorch.callbacks import StochasticWeightAveraging
from pytorch_lightning.loggers import TensorBoardLogger

# train the model (hint: here are some helpful Trainer arguments for rapid idea iteration)
tb_logger = TensorBoardLogger("./logs", "VAE", version="b2048-lr0.005-swa-le-2")
trainer = L.Trainer(
    # limit_train_batches=100,
    max_epochs=100,
    callbacks=[StochasticWeightAveraging(swa_lrs=1e-2)],
    logger=tb_logger,
    profiler="simple",
)

tuner = Tuner(trainer)
# tuner.scale_batch_size(autoencoder, datamodule, mode="power")

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [8]:
# lr_finder = tuner.lr_find(autoencoder, datamodule)
# assert lr_finder is not None, "No learning rate finder found"
# fig = lr_finder.plot(suggest=True)
# assert fig is not None, "No plot returned"
# fig.show()

# # Pick point based on plot, or get suggestion
# new_lr = lr_finder.suggestion()
new_lr = 0.005754399373371567


autoencoder.lr = new_lr  # type: ignore
print(new_lr)

0.005754399373371567


In [9]:
trainer.fit(autoencoder, datamodule)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Epoch 99: 100%|██████████| 30/30 [00:00<00:00, 39.46it/s, v_num=le-2, val_loss=0.0372, train_loss=0.0366]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 30/30 [00:00<00:00, 37.89it/s, v_num=le-2, val_loss=0.0372, train_loss=0.0366]


FIT Profiler Report

--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                         	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                          	|  -              	|  138224         	|

In [6]:
# load checkpoint
checkpoint = "./lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt"
autoencoder = LitAutoEncoder.load_from_checkpoint(
    checkpoint, encoder=encoder, decoder=decoder
)

# choose your trained nn.Module
encoder = autoencoder.encoder
encoder.eval()

# embed 4 fake images!
fake_image_batch = torch.rand(4, 28 * 28, device=autoencoder.device)
embeddings = encoder(fake_image_batch)
print("⚡" * 20, "\nPredictions (4 image embeddings):\n", embeddings, "\n", "⚡" * 20)

FileNotFoundError: [Errno 2] No such file or directory: '/app/lightning_logs/version_0/checkpoints/epoch=0-step=100.ckpt'