In [1]:
import torch
import torch.nn as nn

def count_trainable_params(model: nn.Module) -> int:
    """
    Tally how many scalar weights in a PyTorch model will receive gradients.

    Args:
        model (nn.Module): Model whose parameters are walked via
            ``model.parameters()``.

    Returns:
        int: Sum of ``numel()`` over every parameter whose ``requires_grad``
        flag is set; 0 if there are none.
    """
    trainable_total = 0
    for param in model.parameters():
        # Parameters with requires_grad=False (e.g. frozen ones) are skipped.
        if param.requires_grad:
            trainable_total += param.numel()
    return trainable_total

def report(model: nn.Module) -> None:
    """
    Print a short summary of a model's trainable vs. total parameter counts.

    Args:
        model (nn.Module): The model to inspect.

    Returns:
        None. Three lines are written to stdout: the trainable parameter
        count, the total parameter count, and the trainable share as a
        percentage. For a model with no parameters the percentage is
        reported as "n/a" because the ratio is undefined.
    """
    total_trainable = count_trainable_params(model)
    total_params    = sum(p.numel() for p in model.parameters())
    print(f"Trainable parameters: {total_trainable:,}")
    print(f"Total parameters:     {total_params:,}")
    if total_params == 0:
        # A parameter-free module (e.g. nn.Identity) would otherwise raise
        # ZeroDivisionError in the percentage computation below.
        print("Trainable %:          n/a")
    else:
        print(f"Trainable %:          {100 * total_trainable / total_params:.2f}%")

In [4]:
from base_model import BaseClassifier
from torch_utils import freeze_layers

# Build the classifier with its default constructor arguments.
model = BaseClassifier()
# NOTE(review): freeze_layers is imported from torch_utils and its body is not
# visible here; presumably it disables gradients on the modules passed in
# (here only the backbone) — confirm against torch_utils.
freeze_layers([model.backbone])

# Print trainable vs. total parameter counts for the model after the call above.
report(model)

Trainable parameters: 876,826
Total parameters:     12,053,338
Trainable %:          7.27%


In [3]:
from base_model import BaseClassifier
from torch_utils import freeze_layers

# Smaller variant: no residual blocks and hidden_dim=96. This reassigns the
# notebook-level name `model`, so any later cell sees this variant.
# NOTE(review): the exact meaning of num_res_blocks / hidden_dim is defined in
# base_model and not visible here — confirm there.
model = BaseClassifier(num_res_blocks=0, hidden_dim=96)
# NOTE(review): freeze_layers is imported from torch_utils and its body is not
# visible here; presumably it disables gradients on the modules passed in
# (here only the backbone) — confirm against torch_utils.
freeze_layers([model.backbone])

# Print trainable vs. total parameter counts for this variant.
report(model)

Trainable parameters: 178,394
Total parameters:     11,354,906
Trainable %:          1.57%


In [3]:
# Scratch check: shows how the float literal 1e-3 is rendered when printed.
print(1e-3)

0.001
