# Level 6: Predict with your model

## Saving and loading checkpoints (basic)
When a model is training, the performance changes as it continues to see more data. It is a best practice to save the state of a model throughout the training process. This gives you a version of the model, a checkpoint, at each key point during the development of the model. Once training has completed, use the checkpoint that corresponds to the best performance you found during the training process.

Checkpoints also enable your training to resume from where it was in case the training process is interrupted.

PyTorch Lightning checkpoints are fully usable in plain PyTorch.

Contents of a checkpoint
A Lightning checkpoint contains a dump of the model’s entire internal state. Unlike plain PyTorch, Lightning saves everything you need to restore a model even in the most complex distributed training environments.

Inside a Lightning checkpoint you’ll find:

- 16-bit scaling factor (if using 16-bit precision training)

- Current epoch

- Global step

- LightningModule’s state_dict

- State of all optimizers

- State of all learning rate schedulers

- State of all callbacks (for stateful callbacks)

- State of datamodule (for stateful datamodules)

- The hyperparameters (init arguments) with which the model was created

- The hyperparameters (init arguments) with which the datamodule was created

- State of Loops

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import EarlyStopping
from lightning.pytorch.callbacks import ModelSummary
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [2]:
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="../data/MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="../data/MNIST", download=True, train=False, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16, persistent_workers=True, pin_memory=True)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=16, persistent_workers=True, pin_memory=True)
test_loader = DataLoader(test_set, batch_size=1024, num_workers=16, persistent_workers=True, pin_memory=True)

In [None]:
class Encoder(nn.Module):
    def __init__(self, in_dim=28*28, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=4):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class Decoder(nn.Module):
    def __init__(self, in_dim=4, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=28*28):
        super().__init__()
        self.ff = nn.Sequential(
            nn.Linear(in_dim, hidden_nodes_1),
            nn.ReLU(),
            nn.Linear(hidden_nodes_1, hidden_nodes_2),
            nn.ReLU(),
            nn.Linear(hidden_nodes_2, out_dim)
        )

    def forward(self, x):
        return self.ff(x)

class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder, lr=1e-5, example_input_array=None):
        super().__init__()
        self.example_input_array = example_input_array
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss = self._get_loss(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss = self._get_loss(batch)
        self.log("test_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        x = x.view(x.size(0), -1)
        return self.forward(x)

    def _get_loss(self, batch):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [4]:
model = LitAutoEncoder(
    encoder=Encoder(
        in_dim=28*28,
        hidden_nodes_1=512,
        hidden_nodes_2=256,
        out_dim=100
    ),
    decoder=Decoder(
        in_dim=100,
        hidden_nodes_1=128,
        hidden_nodes_2=256,
        out_dim=28*28
    ),
    lr=1e-4,
    example_input_array=torch.zeros(32, 1, 28*28)
)

In [5]:
logger = CSVLogger(
    save_dir='logs',
    name='mnist_test',
    version=None,
    prefix='test_'
)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(logger.log_dir, "checkpoints"),
    filename="autoencoder-{epoch:02d}-{val_loss:.6f}",
    monitor="val_loss",    
    mode="min",
    save_top_k=3,     # keep ONLY the best
    save_last=True    # ALSO save last.ckpt
)

early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
    mode="min"
)

model_summary_callback = ModelSummary(-1)

In [6]:
trainer = Trainer(
    profiler = 'simple',
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback, model_summary_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=100
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
Trainer already configured with model summary callbacks: [<class 'lightning.pytorch.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [7]:
help(trainer.test)

Help on method test in module lightning.pytorch.trainer.trainer:

test(
    model: Optional[ForwardRef('pl.LightningModule')] = None,
    dataloaders: Union[Any, lightning.pytorch.core.datamodule.LightningDataModule, NoneType] = None,
    ckpt_path: Union[str, pathlib._local.Path, NoneType] = None,
    verbose: bool = True,
    datamodule: Optional[lightning.pytorch.core.datamodule.LightningDataModule] = None,
    weights_only: Optional[bool] = None
) -> list[collections.abc.Mapping[str, float]] method of lightning.pytorch.trainer.trainer.Trainer instance
    Perform one evaluation epoch over the test set. It's separated from fit to make sure you never run on your
    test set until you want to.

    Args:
        model: The model to test.

        dataloaders: An iterable or collection of iterables specifying test samples.
            Alternatively, a :class:`~lightning.pytorch.core.datamodule.LightningDataModule` that defines
            the :class:`~lightning.pytorch.core.hooks.Data

In [8]:
trainer.fit(model, train_loader, valid_loader)
trainer.test(model, test_loader)

You are using a CUDA device ('NVIDIA A40') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

   | Name         | Type       | Params | Mode  | FLOPs  | In sizes     | Out sizes   
--------------------------------------------------------------------------------------------
0  | encoder      | Encoder    | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
1  | encoder.ff   | Sequential | 558 K  | train | 35.7 M | [32, 1, 784] | [32, 1, 100]
2  | encoder.ff.0 | Linear     | 401 K  | train | 25.7 M | [32, 1, 784] | [32, 1, 512]
3  | encoder.ff.1 | ReLU       | 0      | train | 0      | [32, 1, 512] | [32, 1, 512]
4  | encoder.ff.2 | Linear     | 131 K  | train | 8.4 M  | [32, 

Epoch 0:   0%|          | 0/375 [00:00<?, ?it/s]                           

Epoch 99: 100%|██████████| 375/375 [00:04<00:00, 76.13it/s, v_num=1, val_loss=0.00617, train_loss=0.00616]

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 375/375 [00:05<00:00, 73.91it/s, v_num=1, val_loss=0.00617, train_loss=0.00616]


FIT Profiler Report

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                               	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                	|  -       

Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 149.54it/s]
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_loss          0.006019039079546928
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


TEST Profiler Report

----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Action                                                                                                                                                               	|  Mean duration (s)	|  Num calls      	|  Total time (s) 	|  Percentage %   	|
----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
|  Total                                                                                                                                                                	|  -      

[{'test_loss': 0.006019039079546928}]

In [9]:
# Access human-readable message
if early_stop_callback.stopping_reason_message:
    print(f"Details: {early_stop_callback.stopping_reason_message}")
else:
    print("Early stopping was not triggered.")

Early stopping was not triggered.


### Load a checkpoint and predict
The easiest way to use a model for predictions is to load the weights using load_from_checkpoint found in the LightningModule.

In [18]:
model_new = LitAutoEncoder.load_from_checkpoint(
    checkpoint_callback.best_model_path,
    encoder=model.encoder,
    decoder=model.decoder
)
model_new.eval()
x = torch.randn(128, 1, 28*28)
x = x.to(model_new.device)

with torch.no_grad():
    y_hat = model_new(x)

### Predict step with your LightningModule
Loading a checkpoint and predicting still leaves you with a lot of boilerplate around the predict epoch. The predict step in the LightningModule removes this boilerplate.

In [19]:
class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder, lr=1e-5, example_input_array=None):
        super().__init__()
        self.example_input_array = example_input_array
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss = self._get_loss(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss = self._get_loss(batch)
        self.log("test_loss", loss)

    def predict_step(self, batch, batch_idx, dataloader_idx=0):
        x, _ = batch
        x = x.view(x.size(0), -1)
        return self.forward(x)

    def _get_loss(self, batch):
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self.forward(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer

    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [21]:
trainer.predict(model, dataloaders=test_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Predicting DataLoader 0:   0%|          | 0/10 [00:00<?, ?it/s]

TypeError: linear(): argument 'input' (position 1) must be Tensor, not list

In [23]:
class LitMCdropoutModel(LightningModule):
    def __init__(self, model, mc_iteration):
        super().__init__()
        self.model = model
        self.dropout = nn.Dropout()
        self.mc_iteration = mc_iteration

    def predict_step(self, batch, batch_idx):
        # enable Monte Carlo Dropout
        self.dropout.train()

        # take average of `self.mc_iteration` iterations
        pred = [self.dropout(self.model(x)).unsqueeze(0) for _ in range(self.mc_iteration)]
        pred = torch.vstack(pred).mean(dim=0)
        return pred

In [None]:
model