# Level 2: Add a validation and test set
## Early Stopping

https://lightning.ai/docs/pytorch/stable/common/early_stopping.html

In [38]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [39]:
# Load MNIST as tensors; files are downloaded into MNIST_ROOT when missing.
MNIST_ROOT = "../data/MNIST"
transform = transforms.ToTensor()
full_train_set = datasets.MNIST(root=MNIST_ROOT, download=True, train=True, transform=transform)
test_set = datasets.MNIST(root=MNIST_ROOT, download=True, train=False, transform=transform)

# Hold out 20% of the official training split for validation.
train_set_size = int(len(full_train_set) * 0.8)
valid_set_size = len(full_train_set) - train_set_size

# `seed` is a torch.Generator seeded with 42 (not an int); passing it to
# random_split makes the train/validation partition repeatable.
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(
    full_train_set,
    [train_set_size, valid_set_size],
    generator=seed,
)

# Worker and pinning options shared by all three loaders.
loader_options = dict(num_workers=16, persistent_workers=True, pin_memory=True)

# Only the training loader shuffles; validation and test keep dataset order.
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, **loader_options)
valid_loader = DataLoader(valid_set, batch_size=128, **loader_options)
test_loader = DataLoader(test_set, batch_size=1024, **loader_options)

In [40]:
# Print the signature and docstring of LightningModule.log as a reference for
# the logging arguments (prog_bar, on_step, on_epoch, ...).
help(LightningModule.log)

Help on function log in module lightning.pytorch.core.module:

log(
    self,
    name: str,
    value: Union[torchmetrics.metric.Metric, torch.Tensor, int, float],
    prog_bar: bool = False,
    logger: Optional[bool] = None,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
    reduce_fx: Union[str, Callable[[Any], Any]] = 'mean',
    enable_graph: bool = False,
    sync_dist: bool = False,
    sync_dist_group: Optional[Any] = None,
    add_dataloader_idx: bool = True,
    batch_size: Optional[int] = None,
    metric_attribute: Optional[str] = None,
    rank_zero_only: bool = False
) -> None
    Log a key, value pair.

    Example::

        self.log('train_loss', loss)

    The default behavior per hook is documented here: :ref:`extensions/logging:Automatic Logging`.

    Args:
        name: key to log. Must be identical across all processes if using DDP or any other distributed strategy.
        value: value to log. Can be a ``float``, ``Tensor``, or a ``Met

In [41]:
class Encoder(nn.Module):
    """Feed-forward encoder: three linear layers with ReLU between them.

    Args:
        in_dim: size of each (already flattened) input vector; defaults to 28*28.
        hidden_nodes_1: width of the first hidden layer.
        hidden_nodes_2: width of the second hidden layer.
        out_dim: size of the produced code vector; defaults to 4.
    """

    def __init__(self, in_dim=28*28, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=4):
        super().__init__()
        widths = [in_dim, hidden_nodes_1, hidden_nodes_2, out_dim]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            # A ReLU sits between consecutive linear layers, never after the last one.
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # Resulting order: Linear, ReLU, Linear, ReLU, Linear (indices 0..4).
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stack and return the ``out_dim``-sized code."""
        code = self.ff(x)
        return code


class Decoder(nn.Module):
    """Feed-forward decoder: three linear layers with ReLU between them.

    Args:
        in_dim: size of each input code vector; defaults to 4.
        hidden_nodes_1: width of the first hidden layer.
        hidden_nodes_2: width of the second hidden layer.
        out_dim: size of each output vector; defaults to 28*28.
    """

    def __init__(self, in_dim=4, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=28*28):
        super().__init__()
        widths = [in_dim, hidden_nodes_1, hidden_nodes_2, out_dim]
        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            # A ReLU sits between consecutive linear layers, never after the last one.
            if layers:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(fan_in, fan_out))
        # Resulting order: Linear, ReLU, Linear, ReLU, Linear (indices 0..4).
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        """Run the code ``x`` through the stack and return the ``out_dim``-sized output."""
        reconstruction = self.ff(x)
        return reconstruction

class LitAutoEncoder(LightningModule):
    """LightningModule that trains an encoder/decoder pair on reconstruction MSE.

    Args:
        encoder: module mapping a flattened input to a code.
        decoder: module mapping a code back to a flattened input.
        lr: learning rate handed to Adam in ``configure_optimizers``.
    """

    def __init__(self, encoder, decoder, lr=1e-5):
        super().__init__()
        self.encoder, self.decoder = encoder, decoder
        self.lr = lr
        # The two sub-modules are left out of the recorded hyperparameters.
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def forward(self, x):
        """Encode ``x`` and decode the result; ``x`` must already be flattened."""
        return self.decoder(self.encoder(x))

    def _get_loss(self, batch):
        """Return the MSE between the flattened inputs and their reconstruction.

        ``batch`` is unpacked as ``(inputs, _)``; the second element is ignored.
        """
        inputs, _ = batch
        # Collapse everything after the batch dimension into one vector per sample.
        flat = inputs.view(inputs.size(0), -1)
        return F.mse_loss(self.forward(flat), flat)

    def training_step(self, batch, batch_idx):
        """Compute the training loss, log it under two keys, and return it."""
        loss = self._get_loss(batch)
        # "train/loss" is logged both per step and per epoch; "train_loss" is
        # the per-epoch value shown in the progress bar.
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log the validation loss as "val_loss" (also shown in the progress bar)."""
        self.log("val_loss", self._get_loss(batch), prog_bar=True)

    def test_step(self, batch, batch_idx):
        """Log the test loss as "test_loss"."""
        self.log("test_loss", self._get_loss(batch))

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters, using ``self.lr``."""
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [42]:
# Build the autoencoder from a default-sized Encoder and Decoder.
# Note: `model` is rebound to a fresh instance further down, before training.
model = LitAutoEncoder(encoder=Encoder(), decoder=Decoder())

In [43]:
# Print the CSVLogger signature and docstring (save_dir, name, version, prefix, ...).
help(CSVLogger)

Help on class CSVLogger in module lightning.pytorch.loggers.csv_logs:

class CSVLogger(lightning.pytorch.loggers.logger.Logger, lightning.fabric.loggers.csv_logs.CSVLogger)
 |  CSVLogger(
 |      save_dir: Union[str, pathlib._local.Path],
 |      name: Optional[str] = 'lightning_logs',
 |      version: Union[int, str, NoneType] = None,
 |      prefix: str = '',
 |      flush_logs_every_n_steps: int = 100
 |  )
 |
 |  Log to local file system in yaml and CSV format.
 |
 |  Logs are saved to ``os.path.join(save_dir, name, version)``.
 |
 |  Example:
 |      >>> from lightning.pytorch import Trainer
 |      >>> from lightning.pytorch.loggers import CSVLogger
 |      >>> logger = CSVLogger("logs", name="my_exp_name")
 |      >>> trainer = Trainer(logger=logger)
 |
 |  Args:
 |      save_dir: Save directory
 |      name: Experiment name, optional. Defaults to ``'lightning_logs'``. If name is ``None``, logs
 |          (versions) will be stored to the save dir directly.
 |      version: Expe

In [44]:
# Print the ModelCheckpoint signature and docstring (monitor, save_top_k, save_last, ...).
help(ModelCheckpoint)

Help on class ModelCheckpoint in module lightning.pytorch.callbacks.model_checkpoint:

class ModelCheckpoint(lightning.pytorch.callbacks.checkpoint.Checkpoint)
 |  ModelCheckpoint(
 |      dirpath: Union[str, pathlib._local.Path, NoneType] = None,
 |      filename: Optional[str] = None,
 |      monitor: Optional[str] = None,
 |      verbose: bool = False,
 |      save_last: Union[bool, Literal['link'], NoneType] = None,
 |      save_top_k: int = 1,
 |      save_on_exception: bool = False,
 |      save_weights_only: bool = False,
 |      mode: str = 'min',
 |      auto_insert_metric_name: bool = True,
 |      every_n_train_steps: Optional[int] = None,
 |      train_time_interval: Optional[datetime.timedelta] = None,
 |      every_n_epochs: Optional[int] = None,
 |      save_on_train_epoch_end: Optional[bool] = None,
 |      enable_version_counter: bool = True
 |  )
 |
 |  Save the model after every epoch by monitoring a quantity. Every logged metrics are passed to the
 |  :class:`~light

In [45]:
# Print the EarlyStopping signature and docstring (monitor, min_delta, patience, ...).
help(EarlyStopping)

Help on class EarlyStopping in module lightning.pytorch.callbacks.early_stopping:

class EarlyStopping(lightning.pytorch.callbacks.callback.Callback)
 |  EarlyStopping(
 |      monitor: str,
 |      min_delta: float = 0.0,
 |      patience: int = 3,
 |      verbose: bool = False,
 |      mode: str = 'min',
 |      strict: bool = True,
 |      check_finite: bool = True,
 |      stopping_threshold: Optional[float] = None,
 |      divergence_threshold: Optional[float] = None,
 |      check_on_train_epoch_end: Optional[bool] = None,
 |      log_rank_zero_only: bool = False
 |  )
 |
 |  Monitor a metric and stop training when it stops improving.
 |
 |  Args:
 |      monitor: quantity to be monitored.
 |      min_delta: minimum change in the monitored quantity to qualify as an improvement, i.e. an absolute
 |          change of less than or equal to `min_delta`, will count as no improvement.
 |      patience: number of checks with no improvement
 |          after which training will be stopp

In [46]:
# Metric watched by both callbacks; it is logged in LitAutoEncoder.validation_step.
MONITORED_METRIC = "val_loss"

# CSV logs go to os.path.join(save_dir, name, version); version=None leaves
# the version choice to the logger.
# NOTE(review): `prefix` is presumably prepended to logged metric keys rather
# than to file names -- confirm that 'test_' is intended here.
logger = CSVLogger(save_dir='logs', name='autoencoder_mnist', version=None, prefix='test_')

# Checkpoints are written next to this run's logs.
checkpoint_dir = os.path.join(logger.log_dir, "checkpoints")
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    filename="autoencoder_best-{epoch:02d}-{val_loss:.3f}",
    monitor=MONITORED_METRIC,
    mode="min",
    save_top_k=3,    # keep the three best checkpoints (lowest val_loss)
    save_last=True,  # additionally save last.ckpt
)

# Stop once val_loss has failed to improve for 3 consecutive checks.
early_stop_callback = EarlyStopping(
    monitor=MONITORED_METRIC,
    mode="min",
    patience=3,
    verbose=False,
)

In [47]:
# Start the training run from a freshly initialised autoencoder.
model = LitAutoEncoder(encoder=Encoder(), decoder=Decoder())

In [48]:
# Configure the training run.
# accelerator="auto" lets Lightning pick the available hardware (the GPU when
# CUDA is present), so this cell also runs on machines without a GPU instead
# of raising at construction time.
# max_epochs is only an upper bound: the EarlyStopping callback is expected to
# end training much earlier.
trainer = Trainer(
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
    accelerator="auto",
    devices=1,
    max_epochs=1000,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [49]:
# Train on the training loader and run validation on the held-out 20% split.
trainer.fit(
    model,
    train_dataloaders=train_loader,
    val_dataloaders=valid_loader,
)

You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 54.7 K | train | 0    
1 | decoder | Decoder | 55.4 K | train | 0    
----------------------------------------------------
110 K     Trainable params
0         Non-trainable params
110 K     Total params
0.440     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 5:  20%|██        | 75/375 [00:00<00:02, 144.99it/s, v_num=1, val_loss=0.0647, train_loss=0.0661] 


Detected KeyboardInterrupt, attempting graceful shutdown ...


SystemExit: 1

In [None]:
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason

In [None]:
# Report why training ended by comparing the callback's recorded reason with
# each known case in turn; nothing is printed when none of them matches.
reason_messages = (
    (EarlyStoppingReason.PATIENCE_EXHAUSTED, "Training stopped due to patience exhaustion"),
    (EarlyStoppingReason.STOPPING_THRESHOLD, "Training stopped due to reaching stopping threshold"),
    (EarlyStoppingReason.NOT_STOPPED, "Training completed normally without early stopping"),
)
for known_reason, message in reason_messages:
    if early_stop_callback.stopping_reason == known_reason:
        print(message)
        break

In [None]:
# Show the callback's human-readable explanation, if it recorded one.
stopping_details = early_stop_callback.stopping_reason_message
if stopping_details:
    print(f"Details: {stopping_details}")