In [1]:
import os
from torch import optim, nn
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import torch.nn.functional as F
import pytorch_lightning as pl

# Layer widths shared by both halves of the autoencoder. Naming them once keeps
# the decoder an exact mirror of the encoder when a size is tuned.
INPUT_DIM = 28 * 28  # width of the first/last Linear layer (flattened input)
HIDDEN_DIM = 64      # width of the single hidden layer on each side
LATENT_DIM = 3       # width of the bottleneck between encoder and decoder

# define any number of nn.Modules (or use your current ones)
encoder = nn.Sequential(
    nn.Linear(INPUT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, LATENT_DIM),
)
decoder = nn.Sequential(
    nn.Linear(LATENT_DIM, HIDDEN_DIM),
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, INPUT_DIM),
)

# define the LightningModule
class LitAutoEncoder(pl.LightningModule):
    """Autoencoder trained to reconstruct its own (flattened) input.

    Args:
        encoder: module applied to the flattened input batch.
        decoder: module applied to the encoder output to rebuild the input.
    """

    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def _reconstruction_loss(self, batch):
        """Return the MSE between the flattened inputs and their reconstruction.

        Shared by the training and validation steps so both loops are
        guaranteed to measure the same quantity.

        Args:
            batch: a two-element ``(inputs, targets)`` pair; only the inputs
                are used.
        """
        x, _ = batch
        # Collapse everything after the batch dimension into one feature axis.
        x = x.view(x.size(0), -1)
        x_hat = self.decoder(self.encoder(x))
        return F.mse_loss(x_hat, x)

    def training_step(self, batch, batch_idx):
        """Compute, log (as ``train_loss``) and return the loss for one batch."""
        # training_step defines the train loop.
        # it is independent of forward
        loss = self._reconstruction_loss(batch)
        # Logging to TensorBoard (if installed) by default
        self.log("train_loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute, log (as ``val_loss``) and return the loss for one batch."""
        loss = self._reconstruction_loss(batch)
        # Logging to TensorBoard (if installed) by default
        self.log("val_loss", loss)
        return loss

    def configure_optimizers(self):
        """Return an Adam optimizer over all module parameters (lr=1e-3)."""
        return optim.Adam(self.parameters(), lr=1e-3)


# Wrap the encoder/decoder pair defined above in the LightningModule.
autoencoder = LitAutoEncoder(encoder=encoder, decoder=decoder)

In [2]:
# setup data
# Train on the MNIST training split and validate on the held-out test split.
# Building both loaders from the same split would make `val_loss` just a second
# measurement on training data.
loader_kwargs = dict(num_workers=8, batch_size=128)
dataset = MNIST("~/data/", train=True, download=True, transform=ToTensor())
val_dataset = MNIST("~/data/", train=False, download=True, transform=ToTensor())
train_loader = DataLoader(dataset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

In [3]:
# DRY RUN
# With fast_dev_run=True the trainer runs a single batch through the loops
# (logging and checkpointing suppressed) to smoke-test the wiring.
logger = pl.loggers.TensorBoardLogger(save_dir="tb_logs", name="autoencoder")
trainer = pl.Trainer(
    max_epochs=2,
    accelerator="cpu",
    devices=1,
    logger=logger,
    fast_dev_run=True,
)
trainer.fit(
    autoencoder,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
Running in `fast_dev_run` mode: will run the requested loop using 1 batch(es). Logging and checkpointing is suppressed.

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_steps=1` reached.


In [4]:
# TRAIN FROM SCRATCH
# Two full epochs on CPU using the trainer's default logger (logger=True).
# NOTE(review): `autoencoder` is the same instance the dry run already stepped
# for one batch, so its weights are not freshly initialised here - confirm
# that is acceptable.
trainer = pl.Trainer(
    max_epochs=2,
    accelerator="cpu",
    devices=1,
    logger=True,
    fast_dev_run=False,
)
trainer.fit(
    autoencoder,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=2` reached.


In [5]:
# RESUME CHECKPOINT
# Ask the trainer from the previous run where it saved its checkpoint instead
# of hard-coding "lightning_logs/version_0/checkpoints/epoch=1-step=938.ckpt".
# The literal path only matches the first run in a clean directory with this
# exact dataset/batch size; otherwise it is missing or points at a stale run.
resume_ckpt_path = trainer.checkpoint_callback.best_model_path
assert resume_ckpt_path, "no checkpoint recorded - run the training cell first"

# max_epochs=3 leaves one more epoch to train after the restored 2-epoch state.
trainer = pl.Trainer(
    max_epochs=3,
    accelerator="cpu",
    devices=1,
    logger=False,
    fast_dev_run=False,
)
trainer.fit(
    model=autoencoder,
    train_dataloaders=train_loader,
    val_dataloaders=test_loader,
    ckpt_path=resume_ckpt_path,
)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at lightning_logs/version_0/checkpoints/epoch=1-step=938.ckpt

  | Name    | Type       | Params
---------------------------------------
0 | encoder | Sequential | 50.4 K
1 | decoder | Sequential | 51.2 K
---------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
Restored all states from the checkpoint file at lightning_logs/version_0/checkpoints/epoch=1-step=938.ckpt


Sanity Checking: 0it [00:00, ?it/s]

Training: 469it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

`Trainer.fit` stopped: `max_epochs=3` reached.
