In [1]:
# Enable IPython's autoreload extension in mode 2, so that imported modules
# are reloaded before each cell executes and edits to project code are picked
# up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Pretraining backbones with Minerva Learn from Randomness

This notebook provides a demonstration of how to pretrain feature extraction backbones using the Minerva Learn from Randomness model. 

## 1. Introduction

### Learn from Randomness (LFR)

LFR is a self-supervised learning (SSL) method for learning representations without requiring labeled data. 

In [None]:
# Number of pretraining epochs (a full run took about 10 hours on the author's computer)
n_epochs = 81

# --- Dataloader / datamodule settings ---
DL_BATCH_SIZE = 256  # 2 ** 8
DL_NUM_WORKERS = 4
DL_TRAIN_SAMPLES_PER_CLASS = 128_000
DL_VAL_SAMPLES_PER_CLASS = None  # presumably "use every validation sample" — TODO confirm

# --- LFR settings ---
N_PROJECTORS = 6
PREDICTOR_EPOCHS = 3
LEARNING_RATE = 3e-2

# --- Dataset location on disk ---
data_dir = "./data"

### 2.3 Importing basic modules

Let's import the basic modules, such as lightning, torch, minerva, and other utility modules.

In [3]:
import torch
import torchvision
import lightning
import minerva

# Report the library versions this notebook is running against.
version_report = [
    f"PyTorch version: {torch.__version__}",
    f"torchvision version: {torchvision.__version__}",
    f"Lightning version: {lightning.__version__}",
]
print("\n".join(version_report))

# Fix the random seed; kept as the last expression so the cell still displays it.
lightning.seed_everything(1969)

  from .autonotebook import tqdm as notebook_tqdm
Seed set to 1969


PyTorch version: 2.7.1+cu126
torchvision version: 0.22.1+cu126
Lightning version: 2.5.1.post0


1969

## <a id="sec_3">3. Setting up the Dataset</a>

We will use the PCam (PatchCamelyon) dataset, loaded through `PCamDataModule`, to pretrain our backbone. 
During pretraining, we will apply a series of data transformations to generate randomly augmented views of each training image.

For a detailed discussion of the data augmentation strategies used in the next code block, please refer to the tutorial:
`08_minerva_data_transforms.ipynb`.

In [4]:
from dataset_pcam import PCamDataModule

# Build the PCam datamodule from the notebook-level configuration values.
datamodule = PCamDataModule(
    data_dir=data_dir,
    batch_size=DL_BATCH_SIZE,
    num_workers=DL_NUM_WORKERS,
    val_samples_per_class=DL_VAL_SAMPLES_PER_CLASS,
)

# Class names exposed by the datamodule's underlying dataset.
class_names = datamodule.full_dataset.classes

In [5]:
import torch
from torchvision.transforms.v2 import Compose, ToImage, ToDtype, Normalize, RandomHorizontalFlip, RandomVerticalFlip, ColorJitter, CenterCrop, RandomGrayscale

# Per-channel mean / std consumed by the final Normalize step
# (presumably measured on the PCam images — TODO confirm the source).
precomputed_dataset_stats = {
    "mean": torch.tensor([0.6982, 0.5344, 0.6907]),
    "std": torch.tensor([0.2343, 0.2761, 0.2113]),
}

# Training-time image pipeline: convert to a float image scaled by ToDtype,
# crop the central 42x42 region, apply random flips / grayscale / colour
# jitter, and finally normalise with the statistics above.
train_transform_pipeline = Compose(
    [
        ToImage(),
        ToDtype(torch.float32, scale=True),
        CenterCrop(42),
        RandomHorizontalFlip(),
        RandomVerticalFlip(),
        RandomGrayscale(),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.1),
        Normalize(
            mean=precomputed_dataset_stats["mean"],
            std=precomputed_dataset_stats["std"],
        ),
    ]
)

## <a id="sec_4">4. Create the Model for the Pretext Task</a>

### 4.1 Backbone and Projection Head Generation

We will use a modified version of the DenseNet201 model as the backbone. 
Specifically, we replace its final classification layer with an identity layer—`torch.nn.Identity()`—which effectively removes any operation at that stage, allowing us to extract raw feature representations (1920 features per image, as shown in the summary below).

The `generate_backbone()` function handles this process: it instantiates a DenseNet201 model, replaces its classification layer with an identity layer, and returns the modified model.

In the following code block, we instantiate the backbone and display its architecture using the summary() function from the torchinfo package.

In [6]:
from pcam.backbone import generate_backbone
from torchinfo import summary

# Generate the backbone and check its structure
backbone = generate_backbone()
summary(backbone,
        input_size=(DL_BATCH_SIZE, 3, 42, 42), # input data shape (N x C x H x W)
        col_names=["input_size", "output_size", "num_params", "trainable"],
        col_width=20,
        row_settings=["var_names"]
)

Layer (type (var_name))                       Input Shape          Output Shape         Param #              Trainable
DenseNet (DenseNet)                           [256, 3, 42, 42]     [256, 1920]          --                   True
├─Sequential (features)                       [256, 3, 42, 42]     [256, 1920, 1, 1]    --                   True
│    └─Conv2d (conv0)                         [256, 3, 42, 42]     [256, 64, 21, 21]    9,408                True
│    └─BatchNorm2d (norm0)                    [256, 64, 21, 21]    [256, 64, 21, 21]    128                  True
│    └─ReLU (relu0)                           [256, 64, 21, 21]    [256, 64, 21, 21]    --                   --
│    └─MaxPool2d (pool0)                      [256, 64, 21, 21]    [256, 64, 11, 11]    --                   --
│    └─_DenseBlock (denseblock1)              [256, 64, 11, 11]    [256, 256, 11, 11]   --                   True
│    │    └─_DenseLayer (denselayer1)         [256, 64, 11, 11]    [256, 32, 11, 11]   

In [7]:
class Projector(torch.nn.Module):
    """Small convolutional network mapping a 3-channel image to a `dim`-d vector.

    The network is four 3x3 convolutions (8, 16, 32, 32 channels) with ReLU
    activations and two 2x2 max-pooling stages, followed by a flatten and a
    single linear layer.

    Args:
        dim: Size of the output vector produced for each image.
        input_size: Height/width (in pixels) of the square input images. The
            default of 42 matches the previously hard-coded flattened width
            of 3200 features.

    Raises:
        ValueError: If `input_size` is too small to survive the two pooling
            stages (i.e. smaller than 4).
    """

    def __init__(self, dim, input_size=42):
        super().__init__()

        # The padded 3x3 convolutions keep the spatial size; each
        # MaxPool2d(2, 2) halves it, rounding down.
        pooled_size = (input_size // 2) // 2
        if pooled_size < 1:
            raise ValueError(
                f"input_size must be at least 4 to survive two 2x2 poolings, got {input_size}"
            )
        # 32 output channels on a pooled_size x pooled_size map (3200 when input_size=42).
        flat_features = 32 * pooled_size * pooled_size

        self.network = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, kernel_size=3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),

            torch.nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 32, kernel_size=3, stride=1, padding=1),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),

            torch.nn.Flatten(),
            torch.nn.Linear(flat_features, dim),
        )

    def forward(self, x):
        """Run a batch of images of shape (N, 3, input_size, input_size) through the network."""
        return self.network(x)

In [8]:
class Predictor(torch.nn.Module):
    """Two-layer MLP head: Linear (no bias) -> BatchNorm1d -> ReLU -> Linear.

    Args:
        in_features: Width of the input vectors.
        hidden_dim: Width of the hidden layer.
        out_features: Width of the output vectors.
    """

    def __init__(self, in_features, hidden_dim, out_features):
        super().__init__()
        nn = torch.nn

        # Layers are created in application order, so parameter
        # initialisation consumes the RNG in a fixed sequence.
        hidden_block = [
            nn.Linear(in_features, hidden_dim, bias=False),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
        ]
        output_layer = nn.Linear(hidden_dim, out_features)

        self.sequential = nn.Sequential(*hidden_block, output_layer)

    def forward(self, X):
        """Apply the MLP to a batch `X` of shape (N, in_features)."""
        prediction = self.sequential(X)
        return prediction

In [None]:
from minerva.models.ssl.lfr import RepeatedModuleList
from pcam.projectors import create_targets

# Build N_PROJECTORS target networks; each call to the factory yields a fresh
# Projector with a 2048-dimensional output.
projectors = create_targets(
    lambda: Projector(2048),
    N_PROJECTORS,
    datamodule,
)

# One Predictor per projector. The trailing arguments are presumably forwarded
# as Predictor(in_features=1920, hidden_dim=256, out_features=2048) — TODO
# confirm against RepeatedModuleList. 1920 matches the backbone output width
# shown in the summary table; 2048 matches the projector output size.
predictors = RepeatedModuleList(
    len(projectors),
    Predictor,
    1920,
    256,
    2048,
)

In [10]:
from pcam.lfr import LFRModel

# Create the backbone instance that will be pretrained.
backbone = generate_backbone()

# Assemble the LFR model from the backbone, the projectors and their
# predictors, using the notebook-level hyperparameters.
model = LFRModel(
    backbone=backbone,
    projectors=projectors,
    predictors=predictors,
    predictor_epochs=PREDICTOR_EPOCHS,
    lr=LEARNING_RATE,
)

### 5.2 Create the Downstream Benchmark

In [12]:
from pcam.benchmark import SGDBenchmark
# Downstream benchmark on the backbone being pretrained; it is later passed
# to the Trainer as a callback.
downstream_benchmark = SGDBenchmark(
    datamodule=datamodule,
    backbone=backbone,
    train_samples=DL_TRAIN_SAMPLES_PER_CLASS,
    predictor_epochs=PREDICTOR_EPOCHS,
)

## <a id="sec_6">6. Training the model</a>

In [21]:
from lightning import Trainer
from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor
from lightning.pytorch.loggers import TensorBoardLogger, CSVLogger

# Root directory for the logs and checkpoints of the pretext (pretraining) run.
log_ckpt_dir = "logs/PCam/Pretext"
# Both checkpoint callbacks write into the same folder.
checkpoint_dir = f"{log_ckpt_dir}/checkpoints"
# Shared run name for both loggers.
experiment_name = "LFR-DenseNet201"

# Tracks the weights-only checkpoint with the lowest `val_loss`;
# `save_last="link"` additionally maintains a link to the most recent checkpoint.
checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_weights_only=True,
    mode="min",
    monitor="val_loss",
    save_last="link",
)

# Keeps a weights-only snapshot every 4 epochs; save_top_k=-1 retains all of them.
periodic_checkpoint_callback = ModelCheckpoint(
    dirpath=checkpoint_dir,
    save_weights_only=True,
    every_n_epochs=4,
    save_top_k=-1,
)

trainer = Trainer(
    max_epochs=n_epochs,
    log_every_n_steps=5,
    benchmark=True,  # sets torch.backends.cudnn.benchmark
    callbacks=[
        checkpoint_callback,
        periodic_checkpoint_callback,
        LearningRateMonitor("epoch"),
        downstream_benchmark,
    ],
    logger=[
        TensorBoardLogger(save_dir=log_ckpt_dir, name=experiment_name),
        CSVLogger(save_dir=log_ckpt_dir, name=experiment_name),
    ],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [None]:
# Training loader: capped at DL_TRAIN_SAMPLES_PER_CLASS samples per class and
# using the augmentation pipeline.
train_loader = datamodule.train_dataloader(
    samples_per_class=DL_TRAIN_SAMPLES_PER_CLASS,
    transform=train_transform_pipeline,
)
# Validation loader with the datamodule's default settings.
val_loader = datamodule.val_dataloader()

# Run the pretraining loop.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [27]:
# Restore the checkpoint that the monitored ModelCheckpoint callback recorded as best.
best_ckpt_path = checkpoint_callback.best_model_path
print(f"Loading best model from {best_ckpt_path}")
best_model = LFRModel.load_from_checkpoint(
    best_ckpt_path,
    backbone=backbone,
    predictors=predictors,
    projectors=projectors,
    predictor_epochs=PREDICTOR_EPOCHS,
    lr=LEARNING_RATE,
)

# Export the backbone weights of the restored best checkpoint. Saving
# `best_model.backbone` rather than `model.backbone` makes sure the exported
# file matches the checkpoint announced above, not the last training state.
torch.save(best_model.backbone.state_dict(), f"{log_ckpt_dir}/checkpoints/backbone.ckpt")

Loading best model from /home/igor/Desktop/mo810/course-work/logs/PCam/Pretext/checkpoints/epoch=81-step=168000.ckpt
