# Level 3: Transfer Learning
## Use pretrained models

https://lightning.ai/docs/pytorch/stable/advanced/transfer_learning.html

In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
import torch.nn.functional as F

from lightning import LightningModule
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.loggers import CSVLogger

from torchvision import datasets
import torchvision.transforms as transforms

import os

In [2]:
# Load MNIST train/test sets as float tensors (ToTensor), downloading if needed.
transform = transforms.ToTensor()
train_set = datasets.MNIST(root="../data/MNIST", download=True, train=True, transform=transform)
test_set = datasets.MNIST(root="../data/MNIST", download=True, train=False, transform=transform)

# Hold out 20% of the training images for validation. The seeded generator
# makes the split identical on every run.
train_set_size = int(len(train_set) * 0.8)
valid_set_size = len(train_set) - train_set_size
seed = torch.Generator().manual_seed(42)
train_set, valid_set = random_split(train_set, [train_set_size, valid_set_size], generator=seed)

# Worker / memory options shared by all three loaders; only the training
# loader shuffles.
loader_opts = dict(num_workers=16, persistent_workers=True, pin_memory=True)
train_loader = DataLoader(train_set, batch_size=128, shuffle=True, **loader_opts)
valid_loader = DataLoader(valid_set, batch_size=128, **loader_opts)
test_loader = DataLoader(test_set, batch_size=1024, **loader_opts)

### Pretraining phase
Define the encoder and the decoder. In the pretraining phase we train the encoder to produce a compact, informative representation of each image, in such a way that a decoder is able to reconstruct the full image from that representation.

In [3]:
class Encoder(nn.Module):
    """Three-layer MLP mapping a flattened image to a latent vector.

    Layout: Linear -> ReLU -> Linear -> ReLU -> Linear, held in ``self.ff``
    (an ``nn.Sequential``), so the parameter keys are ``ff.0.*``, ``ff.2.*``
    and ``ff.4.*``.

    Args:
        in_dim: Size of the flattened input (28*28 by default).
        hidden_nodes_1: Width of the first hidden layer.
        hidden_nodes_2: Width of the second hidden layer.
        out_dim: Size of the produced representation.
    """

    def __init__(self, in_dim=28*28, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=4):
        super().__init__()
        sizes = (in_dim, hidden_nodes_1, hidden_nodes_2, out_dim)
        layers = []
        for depth, (n_in, n_out) in enumerate(zip(sizes, sizes[1:])):
            # A ReLU goes between consecutive Linear layers, never after the last one.
            if depth > 0:
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_in, n_out))
        self.ff = nn.Sequential(*layers)

    def forward(self, x):
        return self.ff(x)

class Decoder(nn.Module):
    """Three-layer MLP mapping a latent vector back to a flattened image.

    Mirrors ``Encoder``'s layout (Linear -> ReLU -> Linear -> ReLU -> Linear in
    ``self.ff``). The final layer is linear with no output activation, so the
    reconstruction values are not constrained to any range.

    Args:
        in_dim: Size of the latent input.
        hidden_nodes_1: Width of the first hidden layer.
        hidden_nodes_2: Width of the second hidden layer.
        out_dim: Size of the reconstructed, flattened image (28*28 by default).
    """

    def __init__(self, in_dim=4, hidden_nodes_1=64, hidden_nodes_2=64, out_dim=28*28):
        super().__init__()
        dims = [in_dim, hidden_nodes_1, hidden_nodes_2, out_dim]
        # Create the three Linear layers first (in order), then interleave ReLUs.
        linears = [nn.Linear(a, b) for a, b in zip(dims[:-1], dims[1:])]
        self.ff = nn.Sequential(
            linears[0], nn.ReLU(),
            linears[1], nn.ReLU(),
            linears[2],
        )

    def forward(self, x):
        return self.ff(x)

class LitAutoEncoder(LightningModule):
    """LightningModule that trains an encoder/decoder pair to reconstruct its input.

    Each image is flattened to a vector and passed through ``encoder`` and then
    ``decoder``. The result is scored against the flattened input with
    mean-squared error. The label half of each batch is ignored.

    Args:
        encoder: Module mapping a flattened image to a latent vector.
        decoder: Module mapping the latent vector back to a flattened image.
        lr: Adam learning rate. It is stored as ``self.lr`` and saved as a
            hyperparameter.
    """

    def __init__(self, encoder, decoder, lr=1e-5):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.lr = lr
        # The sub-modules already live in the state_dict, so only `lr` is
        # recorded as a hyperparameter.
        self.save_hyperparameters(ignore=["encoder", "decoder"])

    def forward(self, x):
        return self.decoder(self.encoder(x))

    def _get_loss(self, batch):
        # Unsupervised objective: targets are the inputs themselves.
        images, _labels = batch
        flat = images.view(images.size(0), -1)
        return F.mse_loss(self.forward(flat), flat)

    def training_step(self, batch, batch_idx):
        loss = self._get_loss(batch)
        # "train/loss" records step- and epoch-level values; "train_loss" is the
        # epoch-level value shown in the progress bar.
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log("val_loss", self._get_loss(batch), prog_bar=True)

    def test_step(self, batch, batch_idx):
        self.log("test_loss", self._get_loss(batch))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

In [4]:
# Autoencoder architecture: 784 -> 512 -> 256 -> 100 (latent) -> 128 -> 256 -> 784.
pretrain_encoder = Encoder(in_dim=28 * 28, hidden_nodes_1=512, hidden_nodes_2=256, out_dim=100)
pretrain_decoder = Decoder(in_dim=100, hidden_nodes_1=128, hidden_nodes_2=256, out_dim=28 * 28)
model = LitAutoEncoder(encoder=pretrain_encoder, decoder=pretrain_decoder)

In [5]:
# Logging and callbacks for the autoencoder pretraining run.
# No metric-key `prefix` is set, so metrics.csv keeps the plain
# train/val metric names (matching the classifier run's logger).
logger = CSVLogger(
    save_dir='logs',
    name='autoencoder_mnist',
    version=None,
)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(logger.log_dir, "checkpoints"),
    filename="autoencoder_best-{epoch:02d}-{val_loss:.3f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,     # keep the 3 checkpoints with the lowest val_loss
    save_last=True,   # additionally keep last.ckpt (most recent state)
)

# Stop once val_loss has not improved for 3 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
    mode="min",
)

In [6]:
# Single-GPU trainer. max_epochs is only an upper bound: the EarlyStopping
# callback can end the run sooner.
trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1000,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [7]:
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

You are using a CUDA device ('NVIDIA L40S') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name    | Type    | Params | Mode  | FLOPs
----------------------------------------------------
0 | encoder | Encoder | 558 K  | train | 0    
1 | decoder | Decoder | 247 K  | train | 0    
----------------------------------------------------
806 K     Trainable params
0         Non-trainable params
806 K     Total params
3.226     Total estimated model params size (MB)
14        Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 928: 100%|██████████| 375/375 [00:03<00:00, 106.33it/s, v_num=0, val_loss=0.00538, train_loss=0.00536]


In [8]:
from lightning.pytorch.callbacks.early_stopping import EarlyStoppingReason

In [9]:
# Check why training stopped. The three common reasons get explicit messages.
# Any other EarlyStoppingReason member falls through to the generic branch
# instead of silently printing nothing.
stop_reason = early_stop_callback.stopping_reason
if stop_reason == EarlyStoppingReason.PATIENCE_EXHAUSTED:
    print("Training stopped due to patience exhaustion")
elif stop_reason == EarlyStoppingReason.STOPPING_THRESHOLD:
    print("Training stopped due to reaching stopping threshold")
elif stop_reason == EarlyStoppingReason.NOT_STOPPED:
    print("Training completed normally without early stopping")
else:
    print(f"Training stopped early: {stop_reason}")

Training stopped due to patience exhaustion


In [10]:
# Print the EarlyStopping callback's human-readable explanation, if it has one.
reason_message = early_stop_callback.stopping_reason_message
if reason_message:
    print(f"Details: {reason_message}")

Details: Monitored metric val_loss did not improve in the last 3 records. Best score: 0.005. Signaling Trainer to stop.


### Training Phase
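Now we define the classifier: it reuses the pretrained encoder as a frozen feature extractor and trains only a linear layer on top of its representation.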

In [13]:
# Path of the best (lowest val_loss) autoencoder checkpoint from the pretraining run.
# NOTE: `checkpoint_callback` is rebound later for the classifier run, so this
# cell must run right after pretraining. best_model_path is an empty string if
# no checkpoint was saved; fail fast here rather than inside load_from_checkpoint.
AutoEncoder_checkpoint_path = checkpoint_callback.best_model_path
assert AutoEncoder_checkpoint_path, "No autoencoder checkpoint was saved; run the pretraining cells first."
print(f"Best model checkpoint path: {AutoEncoder_checkpoint_path}")

Best model checkpoint path: /storage3/DSIP/rriva/tutorials/pl_tutorial/basic/logs/autoencoder_mnist/version_0/checkpoints/autoencoder_best-epoch=925-val_loss=0.005.ckpt


In [None]:
class MNISTClassifier(LightningModule):
    """Linear classifier on top of a frozen, pretrained autoencoder encoder.

    The encoder weights are restored from a ``LitAutoEncoder`` checkpoint and
    frozen with ``requires_grad_(False)``. Only the linear head
    (``self.classifier``) receives gradients.

    Args:
        checkpoint_path: Path to a ``LitAutoEncoder`` checkpoint.
        encoder: Freshly built ``Encoder`` with the pretraining architecture.
            The checkpoint's weights are loaded into it.
        decoder: Matching ``Decoder``. It is only needed to rebuild the
            autoencoder for loading and is not kept.
        lr: Learning rate for Adam. ``None`` (default) reuses the learning rate
            stored in the autoencoder checkpoint.
        num_classes: Number of output classes (10 MNIST digits by default).
    """

    def __init__(self, checkpoint_path, encoder, decoder, lr=None, num_classes=10):
        super().__init__()
        # encoder/decoder are nn.Modules whose weights are already in the
        # state_dict; keep them out of the hyperparameters (as LitAutoEncoder does).
        self.save_hyperparameters(ignore=["encoder", "decoder"])

        # init the pretrained LightningModule
        autoencoder = LitAutoEncoder.load_from_checkpoint(
            checkpoint_path,
            encoder=encoder,
            decoder=decoder,
        )
        self.feature_extractor = autoencoder.encoder
        # Freeze the pretrained encoder; only the head below is trained.
        self.feature_extractor.requires_grad_(False)
        self.lr = autoencoder.lr if lr is None else lr

        # The head's input size is the encoder's representation size, read from
        # its last Linear layer (100 for the encoder built in this notebook).
        self.classifier = nn.Linear(self.feature_extractor.ff[-1].out_features, num_classes)

    def training_step(self, batch, batch_idx):
        # training_step defines the train loop.
        loss = self._get_loss(batch)
        self.log("train/loss", loss, on_step=True, on_epoch=True)
        self.log("train_loss", loss, prog_bar=True, on_step=False, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # this is the validation loop
        loss = self._get_loss(batch)
        self.log("val_loss", loss, prog_bar=True)

    def test_step(self, batch, batch_idx):
        # this is the test loop
        loss = self._get_loss(batch)
        self.log("test_loss", loss)

    def _get_loss(self, batch):
        # Flatten images to match the encoder's input, then score the class
        # logits with cross-entropy against the integer labels.
        x, y = batch
        x = x.view(x.size(0), -1)
        y_hat = self.forward(x)
        return F.cross_entropy(y_hat, y)

    def configure_optimizers(self):
        # Frozen encoder parameters never receive gradients, so Adam leaves them unchanged.
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def forward(self, x):
        representations = self.feature_extractor(x)
        return self.classifier(representations)

In [50]:
# encoder/decoder must match the pretraining architecture exactly so the
# checkpoint weights can be loaded into them.
classifier_model = MNISTClassifier(
    AutoEncoder_checkpoint_path,
    encoder=Encoder(in_dim=28 * 28, hidden_nodes_1=512, hidden_nodes_2=256, out_dim=100),
    decoder=Decoder(in_dim=100, hidden_nodes_1=128, hidden_nodes_2=256, out_dim=28 * 28),
)

In [51]:
# Logging and callbacks for the classifier training run.
logger = CSVLogger(
    save_dir='logs',
    name='classifier_mnist',
    version=None,
)

checkpoint_callback = ModelCheckpoint(
    dirpath=os.path.join(logger.log_dir, "checkpoints"),
    filename="classifier-{epoch:02d}-{val_loss:.3f}",
    monitor="val_loss",
    mode="min",
    save_top_k=3,     # keep the 3 checkpoints with the lowest val_loss
    save_last=True,   # additionally keep last.ckpt (most recent state)
)

# Stop once val_loss has not improved for 3 consecutive validation checks.
early_stop_callback = EarlyStopping(
    monitor="val_loss",
    patience=3,
    verbose=False,
    mode="min",
)

In [52]:
# Single-GPU trainer for the classifier. max_epochs is an upper bound; the
# EarlyStopping callback can end the run sooner.
classifier_trainer = Trainer(
    accelerator="gpu",
    devices=1,
    max_epochs=1000,
    logger=logger,
    callbacks=[checkpoint_callback, early_stop_callback],
)

/storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib/python3.13/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /storage3/DSIP/rriva/tutorials/pl_tutorial/.venv/lib ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


In [None]:
classifier_trainer.fit(classifier_model, train_dataloaders=train_loader, val_dataloaders=valid_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]

  | Name              | Type    | Params | Mode  | FLOPs
--------------------------------------------------------------
0 | feature_extractor | Encoder | 558 K  | train | 0    
1 | classifier        | Linear  | 1.0 K  | train | 0    
--------------------------------------------------------------
1.0 K     Trainable params
558 K     Non-trainable params
559 K     Total params
2.240     Total estimated model params size (MB)
8         Modules in train mode
0         Modules in eval mode
0         Total Flops


Epoch 494:  69%|██████▊   | 257/375 [00:01<00:00, 143.71it/s, v_num=3, val_loss=0.384, train_loss=0.363]

In [1]:
# Print the classifier EarlyStopping callback's explanation, if any.
# Requires the classifier training cells above to have run in this kernel
# session; on a fresh kernel `early_stop_callback` does not exist yet.
reason_message = early_stop_callback.stopping_reason_message
if reason_message:
    print(f"Details: {reason_message}")

NameError: name 'early_stop_callback' is not defined

## Automated Finetuning with Callbacks

PyTorch Lightning provides the `BackboneFinetuning` callback to automate the finetuning process. This callback keeps your model's backbone frozen at the start of training, unfreezes it at a chosen epoch, and then gradually increases the backbone's learning rate. This is particularly useful when working with large pretrained models: you first train only the new task-specific head on top of a frozen backbone, and then fine-tune the backbone as well with a carefully ramped-up learning rate.

The `BackboneFinetuning` callback expects your model to have a specific structure:

In [None]:
class MyModel(LightningModule):
    """Minimal template of the model structure ``BackboneFinetuning`` expects.

    The pieces that were free placeholders are now constructor arguments, so
    the template can actually be instantiated.

    Args:
        backbone: The pretrained module to finetune. It must be stored as
            ``self.backbone``, the attribute the callback works on.
        backbone_features: Size of the feature vector produced by ``backbone``.
        num_classes: Number of outputs of the task-specific head.
    """

    def __init__(self, backbone: nn.Module, backbone_features: int, num_classes: int):
        super().__init__()

        # REQUIRED: Your model must have a 'backbone' attribute
        # This should be the pretrained part you want to finetune
        self.backbone = backbone

        # Your task-specific layers (head, classifier, etc.)
        self.head = nn.Linear(backbone_features, num_classes)

    def configure_optimizers(self):
        # Only optimize the head initially - backbone will be added automatically
        return torch.optim.Adam(self.head.parameters(), lr=1e-3)

### Example: Computer Vision with ResNet

In [4]:
import torch
import torch.nn as nn
import torchvision.models as models
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import BackboneFinetuning


class ResNetClassifier(LightningModule):
    """Pretrained ResNet-50 backbone (torchvision default weights) with an MLP head.

    The ``backbone`` / ``head`` split is what ``BackboneFinetuning`` relies on.
    Only the head is in the optimizer at first; the callback adds the backbone
    when it unfreezes it.

    Args:
        num_classes: Number of output classes.
        learning_rate: Adam learning rate for the head.
    """

    def __init__(self, num_classes=10, learning_rate=1e-3):
        super().__init__()
        self.save_hyperparameters()

        pretrained = models.resnet50(weights="DEFAULT")
        # Everything except the final fully connected layer becomes the backbone.
        *feature_layers, _final_fc = pretrained.children()
        self.backbone = nn.Sequential(*feature_layers)

        feature_dim = pretrained.fc.in_features
        self.head = nn.Sequential(
            nn.Flatten(),
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        return self.head(self.backbone(x))

    def training_step(self, batch, batch_idx):
        images, targets = batch
        logits = self(images)
        loss = nn.functional.cross_entropy(logits, targets)
        self.log('train_loss', loss)
        return loss

    def configure_optimizers(self):
        # Head only for now; BackboneFinetuning adds the backbone's param group later.
        head_params = self.head.parameters()
        return torch.optim.Adam(head_params, lr=self.hparams.learning_rate)


# Setup the finetuning callback
backbone_finetuning = BackboneFinetuning(
    unfreeze_backbone_at_epoch=10,  # Start unfreezing backbone at epoch 10
    lambda_func=lambda epoch: 1.5,  # Gradually increase backbone learning rate
    backbone_initial_ratio_lr=0.1,  # Backbone starts at 10% of head learning rate
    should_align=True,  # Align rates when backbone rate reaches head rate
    verbose=True  # Print learning rates during training
)

model = ResNetClassifier()
trainer = Trainer(callbacks=[backbone_finetuning], max_epochs=20)

Trainer will use only 1 of 8 GPUs because it is running inside an interactive / notebook environment. You may try to set `Trainer(devices=8)` but please note that multi-GPU inside interactive / notebook environments is considered experimental and unstable. Your mileage may vary.
💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


### Custom Finetuning Strategies
For more control, you can create custom finetuning strategies by subclassing `BaseFinetuning`:

In [None]:
from lightning.pytorch.callbacks.finetuning import BaseFinetuning
class CustomFinetuning(BaseFinetuning):
    """Freeze ``pl_module.backbone``, then unfreeze its children top-down in stages.

    From ``unfreeze_at_epoch`` on, each epoch unfreezes the next
    ``layers_per_epoch`` children of the backbone, counting from the last
    child. Their parameters are added to the optimizer as a new param group.
    Each child is unfrozen exactly once. When no frozen children remain, later
    epochs do nothing.

    Args:
        unfreeze_at_epoch: First epoch at which backbone children are unfrozen.
        layers_per_epoch: Number of backbone children unfrozen per epoch.
        lr: Learning rate of every newly added param group.
    """

    def __init__(self, unfreeze_at_epoch=5, layers_per_epoch=2, lr=1e-4):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch
        self.layers_per_epoch = layers_per_epoch
        self.lr = lr

    def freeze_before_training(self, pl_module):
        # Freeze the entire backbone initially
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, epoch, optimizer):
        stage = epoch - self.unfreeze_at_epoch
        if stage < 0:
            return  # still in the frozen-backbone phase

        backbone_children = list(pl_module.backbone.children())
        # Stage s covers the slice that ends s * layers_per_epoch children from
        # the top. Successive stages therefore walk down the backbone instead of
        # re-unfreezing the same top layers every epoch.
        end = len(backbone_children) - stage * self.layers_per_epoch
        start = max(end - self.layers_per_epoch, 0)
        for layer in backbone_children[start:max(end, 0)]:
            self.unfreeze_and_add_param_group(layer, optimizer, lr=self.lr)