# Level 2: Add a validation and test set

## Validate and test a model (basic)

### Find the train and test splits
Datasets such as MNIST ship with predefined train and test splits. Refer to the dataset documentation to find them.

In [1]:
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import random_split
from torch.utils.data import DataLoader
from torchvision import datasets
import torchvision.transforms as transforms

# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

import lightning as L

### nn.Module

In [2]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

### Define the test loop
To add a test loop, implement the `test_step` method of the LightningModule.

In [3]:
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

In [4]:
train_loader = DataLoader(train_set, batch_size=1024, shuffle=True, num_workers=16)
test_loader = DataLoader(test_set, batch_size=1024, num_workers=16)


In [5]:
autoencoder = LitAutoEncoder(Encoder(), Decoder())

# initialize the Trainer
trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1
)

# Optional on Tensor Core GPUs: trade float32 matmul precision for speed.
# torch.set_float32_matmul_precision("high")  # or "medium"


# train model
trainer.fit(model=autoencoder, train_dataloaders=train_loader)

# test the model
trainer.test(model=autoencoder, dataloaders=test_loader)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Param

Epoch 0: 100%|██████████| 59/59 [00:00<00:00, 136.10it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|██████████| 59/59 [00:00<00:00, 130.72it/s, v_num=0]


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]


Testing DataLoader 0: 100%|██████████| 10/10 [00:00<00:00, 159.14it/s]
────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────
        test_loss           0.06368378549814224
────────────────────────────────────────────────────────────────────────

[{'test_loss': 0.06368378549814224}]
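
The single `test_loss` value reported above summarizes all test batches. The sketch below uses plain PyTorch to illustrate one way such a summary can be formed, assuming the epoch-level value is the batch-size-weighted mean of the per-batch losses (this is an assumption about the default mean reduction, not something the output above confirms). The tensors are random stand-ins, not MNIST data.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in data (assumption): random "images" and "reconstructions".
x = torch.rand(2500, 28 * 28)
x_hat = torch.rand(2500, 28 * 28)

batch_size = 1024
weighted_sum = 0.0
n_seen = 0
for start in range(0, len(x), batch_size):
    xb = x[start:start + batch_size]
    xb_hat = x_hat[start:start + batch_size]
    batch_loss = F.mse_loss(xb_hat, xb)  # the value a test_step would log for this batch
    weighted_sum += batch_loss.item() * len(xb)  # weight by the number of samples
    n_seen += len(xb)

epoch_loss = weighted_sum / n_seen
full_loss = F.mse_loss(x_hat, x).item()  # loss computed over all samples at once
print(f"{epoch_loss:.6f} {full_loss:.6f}")
```

Weighting by batch size matters because the last batch is smaller than the others; an unweighted mean of the three batch losses would give it too much influence.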

In [6]:
trainer = L.Trainer(
    accelerator="gpu",
    devices=1
)
print(trainer.strategy.root_device)

💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


cuda:0


### Add a validation loop
Split the training data into a training set and a validation set.

In [7]:
# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)
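
Passing a seeded generator makes the split repeatable across runs. The following is a minimal sketch of that property on a small stand-in dataset (the 100-sample `TensorDataset` and the `split` helper are illustrative assumptions, not part of the tutorial):

```python
import torch
from torch.utils.data import TensorDataset, random_split

# Stand-in dataset (assumption): 100 fake samples instead of the MNIST images.
full_set = TensorDataset(torch.arange(100).float().unsqueeze(1))

train_size = int(len(full_set) * 0.8)
valid_size = len(full_set) - train_size


def split(seed):
    # A fresh generator with a fixed seed makes the permutation repeatable.
    generator = torch.Generator().manual_seed(seed)
    return random_split(full_set, [train_size, valid_size], generator=generator)


train_a, valid_a = split(42)
train_b, valid_b = split(42)

# Same seed gives the same indices, and the two subsets never overlap.
same_split = train_a.indices == train_b.indices and valid_a.indices == valid_b.indices
overlap = set(train_a.indices) & set(valid_a.indices)
print(len(train_a), len(valid_a), same_split, len(overlap))
```

Computing the validation size as the remainder, as the cell above does, guarantees that the two lengths always sum to the dataset size even when 80% is not a whole number.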

### Define the validation loop
To add a validation loop, implement the `validation_step` method of the LightningModule.

In [None]:
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        loss = F.mse_loss(x_hat, x)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        z = self.encoder(x)
        x_hat = self.decoder(z)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer

### Train with the validation loop
To run the validation loop, pass the validation DataLoader to `trainer.fit` along with the training DataLoader.

In [9]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16)
valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False, num_workers=16)
model = LitAutoEncoder(Encoder(), Decoder())

# train with both splits
trainer = L.Trainer(max_epochs=10)
trainer.fit(model, train_loader, valid_loader)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 50.4 K | train | 0    
1 | decoder | Decoder | 51.2 K | train | 0    
----------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB

Epoch 9: 100%|██████████| 375/375 [00:02<00:00, 155.22it/s, v_num=1]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 375/375 [00:02<00:00, 154.06it/s, v_num=1]


## Saving and loading checkpoints (basic)

In [20]:
# saves checkpoints to 'some/path/checkpoints/'
trainer = L.Trainer(default_root_dir="some/path/", logger=False, max_epochs=5)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
from lightning.pytorch.loggers import CSVLogger

# checkpoints will be saved to 'logs/my_experiment/version_0/checkpoints/'
# NOT to 'some/path/checkpoints/'
trainer = L.Trainer(
    default_root_dir="some/path/",  # This will be ignored for checkpoints!
    logger=CSVLogger("logs", "my_experiment"),
    max_epochs=5
)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
from torch.utils.data import DataLoader

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16)
valid_loader = DataLoader(valid_set, batch_size=128, shuffle=False, num_workers=16)
model = LitAutoEncoder(Encoder(), Decoder())

# train with both splits
trainer.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 50.4 K | train | 0    
1 | decoder | Decoder | 51.2 K | train | 0    
----------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 4: 100%|██████████| 375/375 [00:02<00:00, 152.68it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 375/375 [00:02<00:00, 151.57it/s]


In [None]:
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

# explicitly set checkpoint directory
checkpoint_callback = ModelCheckpoint(dirpath="my/custom/checkpoint/path/")
trainer = L.Trainer(
    logger=CSVLogger("logs", "my_experiment"),
    callbacks=[checkpoint_callback],
    max_epochs=5
)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [24]:
# train with both splits
trainer.fit(model, train_loader, valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 50.4 K | train | 0    
1 | decoder | Decoder | 51.2 K | train | 0    
----------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 4: 100%|██████████| 375/375 [00:02<00:00, 149.71it/s, v_num=0]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 375/375 [00:02<00:00, 148.48it/s, v_num=0]


In [47]:
class LitAutoEncoder(L.LightningModule):
    def __init__(self, encoder, decoder):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self(x)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self(x)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
        return optimizer
    
    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [46]:
model = LitAutoEncoder.load_from_checkpoint("/storage3/DSIP/rriva/tutorials/pl_tutorial/basic/my/custom/checkpoint/path/epoch=4-step=1875.ckpt",
    encoder=Encoder(),
    decoder=Decoder()
)

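# switch to evaluation mode (disables dropout and similar train-only behavior)
model.eval()

# sample input, created on the same device as the loaded model
x = torch.randn(1, 28 * 28, device=model.device)

# predict with the model, without tracking gradients
with torch.no_grad():
    y_hat = model(x)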
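
The absolute checkpoint path used above is specific to the machine the notebook ran on. A minimal sketch of one way to avoid hard-coding it is to pick the most recently modified `.ckpt` file in the checkpoint directory. The `latest_checkpoint` helper is hypothetical (it is not a Lightning function), and the demo uses empty placeholder files rather than real checkpoints. `ModelCheckpoint` also exposes a `best_model_path` attribute after training, which may be the simpler choice when the callback object is still in scope.

```python
import os
import tempfile
import time
from pathlib import Path


def latest_checkpoint(ckpt_dir):
    """Return the most recently modified .ckpt file in ckpt_dir, or None."""
    candidates = list(Path(ckpt_dir).glob("*.ckpt"))
    if not candidates:
        return None
    return max(candidates, key=lambda p: p.stat().st_mtime)


# Demo with empty placeholder files (not real checkpoints).
with tempfile.TemporaryDirectory() as tmp:
    older = Path(tmp) / "epoch=3-step=1500.ckpt"
    newer = Path(tmp) / "epoch=4-step=1875.ckpt"
    older.touch()
    newer.touch()

    # Set explicit modification times so the ordering is unambiguous.
    now = time.time()
    os.utime(older, (now - 60, now - 60))
    os.utime(newer, (now, now))

    empty_dir = Path(tmp) / "empty"
    empty_dir.mkdir()

    found_name = latest_checkpoint(tmp).name
    empty_result = latest_checkpoint(empty_dir)

print(found_name, empty_result)
```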

### Save hyperparameters

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [2]:
# Load data sets
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="MNIST", download=True, train=False, transform=transform)

# use 20% of training data for validation
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size

# split the train set into two
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)

train_loader = DataLoader(train_set, batch_size=128, shuffle=True, num_workers=16)
valid_loader = DataLoader(valid_set, batch_size=128, num_workers=16)
test_loader = DataLoader(test_set, batch_size=1024, num_workers=16)

In [3]:
class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(3, 64), nn.ReLU(), nn.Linear(64, 28 * 28))

    def forward(self, x):
        return self.l1(x)

class LitAutoEncoder(LightningModule):
    def __init__(self, encoder, decoder, lr=1e-5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self(x)
        loss = F.mse_loss(x_hat, x)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self(x)
        val_loss = F.mse_loss(x_hat, x)
        self.log("val_loss", val_loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        x, _ = batch
        x = x.view(x.size(0), -1)
        x_hat = self(x)
        test_loss = F.mse_loss(x_hat, x)
        self.log("test_loss", test_loss)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
        return optimizer
    
    def forward(self, x):
        z = self.encoder(x)
        x_hat = self.decoder(z)
        return x_hat

In [4]:
model = LitAutoEncoder(
    encoder=Encoder(),
    decoder=Decoder()
)

In [5]:
# explicitly set checkpoint directory
logger = CSVLogger("logs", "experiment_with_hparams")
checkpoint_callback = ModelCheckpoint(dirpath=os.path.join(logger.log_dir,'checkpoints'))

trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=5
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [6]:
trainer.fit(model, train_loader, valid_loader)

You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 50.4 K | train | 0    
1 | decoder | Decoder | 51.2 K | train | 0    
----------------------------------------------------
101 K     Trainable params
0         Non-trainable params
101 K     Total params
0.407     Total estimated model params size (MB)
10        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 4: 100%|██████████| 375/375 [00:02<00:00, 142.60it/s, v_num=3, val_loss=0.078]

`Trainer.fit` stopped: `max_epochs=5` reached.


Epoch 4: 100%|██████████| 375/375 [00:02<00:00, 141.61it/s, v_num=3, val_loss=0.078]


In [7]:
checkpoint_path = '/storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/experiment_with_hparams/version_2/checkpoints/epoch=4-step=1875.ckpt'

checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
print(checkpoint["hyper_parameters"])
# prints the arguments recorded by save_hyperparameters(), here {"lr": 1e-05}

{'lr': 1e-05}


In [8]:
model = LitAutoEncoder.load_from_checkpoint(checkpoint_path,
    encoder=Encoder(),
    decoder=Decoder(),
)
print(model.hparams.lr)

1e-05


In [9]:
model = LitAutoEncoder.load_from_checkpoint(checkpoint_path,
    encoder=Encoder(),
    decoder=Decoder(),
    lr=1e-3
)

print(model.hparams.lr)

0.001


In [10]:
checkpoint = torch.load(checkpoint_path)
encoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("encoder.")}
decoder_weights = {k: v for k, v in checkpoint["state_dict"].items() if k.startswith("decoder.")}
print(checkpoint.keys())

dict_keys(['epoch', 'global_step', 'pytorch-lightning_version', 'state_dict', 'loops', 'callbacks', 'optimizer_states', 'lr_schedulers', 'hparams_name', 'hyper_parameters'])
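
The filtered dictionaries above still carry the `encoder.` and `decoder.` prefixes, so they cannot be passed to a standalone `Encoder()` or `Decoder()` as they are. The sketch below shows one way to strip the prefix before calling `load_state_dict`. It uses a plain `nn.Module` wrapper as a stand-in for `checkpoint["state_dict"]` (an assumption made so the example runs without a checkpoint file).

```python
import torch
from torch import nn


class Encoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.l1 = nn.Sequential(nn.Linear(28 * 28, 64), nn.ReLU(), nn.Linear(64, 3))

    def forward(self, x):
        return self.l1(x)


# Stand-in for checkpoint["state_dict"] (assumption): a wrapper module whose
# attribute is named "encoder", so its keys carry the "encoder." prefix.
wrapper = nn.Module()
wrapper.encoder = Encoder()
state_dict = wrapper.state_dict()

# Keep only the encoder entries and drop the prefix from each key.
prefix = "encoder."
encoder_weights = {
    k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)
}

# A fresh Encoder accepts the renamed keys (strict matching is the default).
fresh_encoder = Encoder()
result = fresh_encoder.load_state_dict(encoder_weights)
print(sorted(encoder_weights))
```

The same pattern with `prefix = "decoder."` would apply to the decoder weights.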


### Disable checkpointing

In [None]:
trainer = Trainer(enable_checkpointing=False, max_epochs=10)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


### Resume training state

In [11]:
trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback],
    accelerator="gpu",
    devices=1,
    max_epochs=10
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [12]:
model = LitAutoEncoder(
    encoder=Encoder(),
    decoder=Decoder()
)
# automatically restores model, epoch, step, LR schedulers, etc...
trainer.fit(model, train_loader, valid_loader, ckpt_path=checkpoint_path)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:881: Checkpoint directory /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/experiment_with_hparams/version_3/checkpoints exists and is not empty.
Restoring states from the checkpoint path at /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/experiment_with_hparams/version_2/checkpoints/epoch=4-step=1875.ckpt
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:566: The dirpath has changed from '/storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/experiment_with_hparams/version_2/checkpoints' to '/storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/experiment_with_hparams/version_3/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.
LOCAL_RANK: 0 - CUDA_VI

Epoch 9: 100%|██████████| 375/375 [00:02<00:00, 139.67it/s, v_num=3, val_loss=0.0633]

`Trainer.fit` stopped: `max_epochs=10` reached.


Epoch 9: 100%|██████████| 375/375 [00:02<00:00, 138.72it/s, v_num=3, val_loss=0.0633]
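
The numbers in the checkpoint file name and in the progress bars can be checked with a little arithmetic. The sketch below assumes one optimizer step per batch (no gradient accumulation) and uses the sizes from this notebook: 48,000 training images after the 80/20 split and a batch size of 128.

```python
import math

n_train = 48_000   # 80% of the 60,000 MNIST training images
batch_size = 128

# The DataLoader keeps a final partial batch, hence the ceiling.
steps_per_epoch = math.ceil(n_train / batch_size)

# The file name says epoch=4; epochs are zero-based, so 5 epochs are complete.
epochs_in_checkpoint = 5
step_in_checkpoint = epochs_in_checkpoint * steps_per_epoch

# Resuming with max_epochs=10 trains only the epochs that are still missing.
max_epochs = 10
remaining_epochs = max_epochs - epochs_in_checkpoint
final_step = step_in_checkpoint + remaining_epochs * steps_per_epoch

print(steps_per_epoch, step_in_checkpoint, remaining_epochs, final_step)
# 375 1875 5 3750
```

This matches the `375/375` progress bar and the `epoch=4-step=1875.ckpt` file name, and explains why the resumed run ends at epoch 9 after only five additional epochs.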


## Early stopping

In [13]:
from lightning.pytorch.callbacks.early_stopping import EarlyStopping

In [14]:
trainer = Trainer(callbacks=[EarlyStopping(monitor="val_loss", mode="min")])
trainer.fit(model, train_loader, valid_loader)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or

Epoch 85: 100%|██████████| 375/375 [00:02<00:00, 181.39it/s, v_num=3, val_loss=0.0471]


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
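
The run above was interrupted by hand at epoch 85, so the stopping rule never fired. To make the rule itself concrete, here is a simplified pure-Python illustration of patience-based stopping in `min` mode. It is a sketch, not Lightning's implementation: `early_stop_epoch` is a hypothetical helper, the loss sequences are made up, and `patience=3` with `min_delta=0.0` are assumed to mirror the callback's defaults.

```python
def early_stop_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the zero-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            # Improvement: remember the new best value and reset the counter.
            best = loss
            wait = 0
        else:
            # No improvement: count this check against the patience budget.
            wait += 1
            if wait >= patience:
                return epoch
    return None


improving = [0.10, 0.09, 0.08, 0.07, 0.06]            # made-up values
plateau = [0.10, 0.08, 0.081, 0.082, 0.080, 0.079]    # made-up values

print(early_stop_epoch(improving))  # None
print(early_stop_epoch(plateau))    # 4
```

In the second sequence the best value (0.08) is reached at epoch 1, the next three checks fail to beat it, and training would stop at epoch 4 even though epoch 5 would have improved slightly. A larger `patience` trades longer training for tolerance of such plateaus.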


https://lightning.ai/docs/pytorch/stable/common/early_stopping.html