# BYOL Pretraining with Attention-Weighted Pooling

This notebook demonstrates a **self-supervised learning** pipeline using the **Bootstrap Your Own Latent (BYOL)** framework, augmented with an **attention-weighted pooling** layer. Applied to the **PituPhase** dataset of endoscopic pituitary surgery frames, our goal is to learn rich visual embeddings without requiring manual labels.

By integrating attention-weighted pooling into BYOL, we capture nuanced spatial details—critical in differentiating subtle surgical phases—while benefiting from BYOL’s stable, negative-free learning paradigm. These pretrained embeddings will later be evaluated in a downstream linear classification task for surgical phase recognition.

**Key Objectives:**
1. **Data Augmentation**  
   Create two distinct views of each input frame through stochastic transformations (random resized crop, color jitter, Gaussian blur, etc.), forming positive pairs for representation alignment.

2. **Online & Target Networks with Attention**  
   Use a ResNet-50 encoder in both online and target networks. After the backbone, apply an attention-weighted pooling operator to highlight salient spatial features before projection.

3. **Prediction & Exponential Moving Average (EMA)**  
   Pass pooled features through a predictor MLP in the online network. Update the target network weights via EMA of the online network’s parameters, ensuring stable learning without negative samples.

4. **Mean Squared Error Loss**  
   Optimize the mean squared error between L2-normalized online predictions and target projections (for unit vectors this equals $2 - 2\cos(q, z')$ per pair), computed symmetrically for both view orderings.

5. **Training Workflow & Logging**  
   Implement the model using PyTorch Lightning, track loss curves on a validation set, and save the best checkpoints automatically.


![BYOL%20attention%20architecture.png](attachment:BYOL%20attention%20architecture.png)

### Imports

In [1]:
from __future__ import print_function, division
import os
import torch
import torchvision
import pandas as pd
from skimage import io, transform, util
from sklearn import metrics
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from torch.utils.data import Dataset, DataLoader, TensorDataset, Subset
from torchvision import transforms, models, utils
from PIL import Image
import random
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler

# PyTorch Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping

Device: cuda:0
Number of workers: 16


## 1. Data Augmentation for Contrastive Learning


In [2]:
class ContrastiveTransformations:
    """
    Stochastic multi-view augmentation wrapper.

    Runs the same random transform pipeline ``n_views`` times on one input
    image, so each call yields several correlated views of the same example.
    Those views are treated as a positive pair by the training objective.

    Args:
        base_transforms (callable): Transform applied independently to the
            image once per view.
        n_views (int, optional): Number of views generated per sample.
    """
    def __init__(self, base_transforms, n_views=2):
        self.base_transforms = base_transforms
        self.n_views = n_views

    def __call__(self, sample):
        """Return ``{'image': [view_1, ..., view_n], 'label': label}`` for one sample dict."""
        source = sample['image']
        label = sample['label']

        # NumPy arrays are wrapped as PIL images before the pipeline runs;
        # any other input type is passed through untouched.
        if isinstance(source, np.ndarray):
            source = Image.fromarray(source)

        views = []
        for _ in range(self.n_views):
            views.append(self.base_transforms(source))

        return {'image': views, 'label': label}

In [None]:
# BYOL augmentation pipeline. Each call produces one randomly augmented,
# normalized tensor view of the input image.
_color_jitter = transforms.ColorJitter(brightness=0.6, contrast=0.4, saturation=0.5, hue=0.5)
_gaussian_blur = transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))

byol_transform = transforms.Compose([
    transforms.RandomResizedCrop(size=224),
    transforms.RandomHorizontalFlip(p=0.5),
    # Photometric augmentations are each applied with probability 0.6.
    transforms.RandomApply([_color_jitter], p=0.6),
    transforms.RandomApply([_gaussian_blur], p=0.6),
    transforms.ToTensor(),
    # presumably per-channel statistics of the training frames — TODO confirm
    transforms.Normalize(mean=[0.312, 0.120, 0.117], std=[0.280, 0.158, 0.160]),
])

# Wrap the pipeline so every sample yields two independently augmented views.
contrastive_augment = ContrastiveTransformations(byol_transform, n_views=2)

## 2. BYOL Implementation

### Load dataset

In [3]:
class PituDataset(Dataset):
    """Pituitary Endoscopy dataset.

    Each item is a dict ``{'image': ..., 'label': ...}``. The image is read
    from ``<root_dir>/images/<id>.png`` and the label is always an
    ``np.int64`` (``-1`` when the dataset is used in unlabeled mode).
    """

    def __init__(self, csv_file, root_dir, transform=None, maxSize=0, unlabeled=False):
        """
        Args:
            csv_file (string): Path to the csv file with annotations
                (columns ``id`` and ``label``).
            root_dir (string): Directory that contains the ``images`` folder.
            transform (callable, optional): Optional transform to be applied
                on a sample dict.
            maxSize (int, optional): Maximum size of the dataset (number of
                samples). ``0`` keeps every row.
            unlabeled (bool, optional): If True, ignore labels.
        """
        self.dataset = pd.read_csv(csv_file, header=0, dtype={'id': str, 'label': int})

        if maxSize > 0:
            # Fixed seed so the same random subset is selected on every run.
            idx = np.random.RandomState(seed=42).permutation(range(len(self.dataset)))
            self.dataset = self.dataset.iloc[idx[:maxSize]].reset_index(drop=True)

        self.root_dir = root_dir
        self.img_dir = os.path.join(root_dir, 'images')
        self.transform = transform
        self.unlabeled = unlabeled
        self.classes = ['Desconocida', 'Preparacion colgajo', 'Etmoidectomia', 'Apertura selar',
                        'Apertura dural', 'Reseccion tumoral', 'Cierre']

    def __len__(self):
        """Return the number of rows kept from the annotation file."""
        return len(self.dataset)

    def _get_label(self, idx):
        """Return the label of row ``idx`` as ``np.int64`` (``-1`` if unlabeled)."""
        if self.unlabeled:
            return np.int64(-1)
        # np.int64 instead of np.long: np.long is missing in NumPy 1.24-1.26
        # and is the platform C long in NumPy 2.x, so the dtype was not stable.
        return np.int64(self.dataset.label[idx])

    def __getitem__(self, idx):
        """Load the image for row ``idx`` and return the (optionally transformed) sample dict."""
        if torch.is_tensor(idx):
            idx = idx.tolist()
        # Read the image
        img_name = os.path.join(self.img_dir, self.dataset.id[idx] + '.png')
        image = io.imread(img_name)

        sample = {'image': image, 'label': self._get_label(idx)}

        if self.transform:
            sample = self.transform(sample)
        return sample

In [None]:
# Training split: labels are ignored (every sample gets label -1) and each
# frame is passed through the multi-view augmentation.
unlabeled_train_dataset = PituDataset(
    csv_file="/home/train_set.csv",
    root_dir='/home',
    unlabeled=True,
    transform=contrastive_augment,
    # maxSize=100000 can be passed here to cap the number of samples.
)

# Validation split: same augmentation, labels are read from the CSV.
val_dataset = PituDataset(
    csv_file="/home/acenteno/val_set.csv",
    root_dir='/home',
    transform=contrastive_augment,
)

### Define Attention-Weighted Pooling Layer

In [8]:
# define the attention layer
class AttentionWeightedPooling(nn.Module):
    """Attention-weighted global average pooling over a convolutional feature map.

    A small stack of 1x1 convolutions predicts one attention value in (0, 1)
    per spatial position. The output is the attention-weighted mean of the
    input features: ``mean(attn * x) / mean(attn)`` over the spatial axes.

    Args:
        input_dim (int): Number of channels of the input feature map.
        eps (float, optional): Lower bound applied to the mean attention
            before dividing, so a fully saturated (all-zero) attention map
            cannot produce inf/NaN.
    """

    def __init__(self, input_dim, eps=1e-6):
        super().__init__()

        # Three 1x1 convolution blocks (conv -> batch norm -> ReLU) that halve
        # the channel count each time, followed by a single-filter convolution
        # whose sigmoid output is the attention map.
        self.conv1 = nn.Conv2d(input_dim, input_dim // 2, kernel_size=1, padding='same')
        self.bn1 = nn.BatchNorm2d(input_dim // 2)
        self.relu1 = nn.ReLU()

        self.conv2 = nn.Conv2d(input_dim // 2, input_dim // 4, kernel_size=1, padding='same')
        self.bn2 = nn.BatchNorm2d(input_dim // 4)
        self.relu2 = nn.ReLU()

        self.conv3 = nn.Conv2d(input_dim // 4, input_dim // 8, kernel_size=1, padding='same')
        self.bn3 = nn.BatchNorm2d(input_dim // 8)
        self.relu3 = nn.ReLU()

        self.attn_conv = nn.Conv2d(input_dim // 8, 1, kernel_size=1, padding='valid')

        self.global_avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.eps = eps

    def forward(self, x):
        """Pool a ``(B, C, H, W)`` feature map into a ``(B, C)`` tensor."""
        attn = self.relu1(self.bn1(self.conv1(x)))
        attn = self.relu2(self.bn2(self.conv2(attn)))
        attn = self.relu3(self.bn3(self.conv3(attn)))
        attn = torch.sigmoid(self.attn_conv(attn))  # (B, 1, H, W), broadcast over channels

        weighted = attn * x

        gap_features = self.global_avg_pool(weighted)  # (B, C, 1, 1)
        # Clamp only guards the degenerate all-zero attention case.
        gap_mask = self.global_avg_pool(attn).clamp_min(self.eps)  # (B, 1, 1, 1)

        pooled = gap_features / gap_mask
        # flatten(1) instead of squeeze(): squeeze() would also drop the batch
        # axis when B == 1 and hand a 1-D tensor to the projection head.
        return pooled.flatten(1)

### Define BYOL Class

In [5]:
# Define the MLP used in the projection head p(·) and predictor q(·)
class MLP(nn.Module):
    """Two-layer perceptron shared by the projection head and the predictor.

    Layout: ``Linear(dim, hidden_dim) -> BatchNorm1d -> ReLU ->
    Linear(hidden_dim, projection_size)``.

    Args:
        dim (int): Size of the input vector.
        hidden_dim (int, optional): Width of the hidden layer.
        projection_size (int, optional): Size of the output vector.
    """

    def __init__(self, dim, hidden_dim=4096, projection_size=512):
        super().__init__()
        layers = [
            nn.Linear(dim, hidden_dim),              # input side (representation h)
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(inplace=True),
            nn.Linear(hidden_dim, projection_size),  # output side (embedding z)
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map ``x`` through the four-layer stack."""
        out = self.net(x)
        return out

In [7]:
def L2_loss(x, y):
    """Return the per-sample loss ``2 - 2 * <x_hat, y_hat>``.

    Both inputs are L2-normalized along dim 1 first, so the result equals the
    squared Euclidean distance between the two unit vectors.
    """
    x_unit = F.normalize(x, dim=1)
    y_unit = F.normalize(y, dim=1)
    similarity = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * similarity

In [None]:
class EMA():
    """Exponential moving average helper with a fixed decay ``alpha``."""

    def __init__(self, alpha):
        super().__init__()
        self.alpha = alpha

    def update_average(self, target, online):
        """Return ``alpha * target + (1 - alpha) * online``."""
        kept = target * self.alpha
        incoming = (1 - self.alpha) * online
        return kept + incoming

In [9]:
# Define the BYOL model
class BYOL(pl.LightningModule):
    """
    BYOL (Bootstrap Your Own Latent) with attention-weighted pooling.

    The student (online) branch is backbone -> attention pooling -> projector
    -> predictor. The teacher (target) branch is an exponential-moving-average
    copy of the student projector and receives no gradients.

    NOTE(review): only the projection head has a separate EMA copy; the
    teacher branch reuses the student's backbone/attention features (computed
    once and consumed under ``torch.no_grad()``). Canonical BYOL keeps an EMA
    copy of the whole encoder — confirm this sharing is intentional.

    Args:
        hidden_dim (int): The size of the hidden vector in the MLPs of the student and teacher projection heads.
        projection_size (int): The size of the output vector from the projection head (dimension of the embedding space).
        lr (float): Learning rate for the optimizer.
        momentum (float): Momentum parameter for the SGD optimizer.
        weight_decay (float): Weight decay for L2 regularization.
        moving_average_decay (float): Decay factor for the exponential moving average used to update the teacher model. e.g. 0.99
    """

    def __init__(self, hidden_dim, projection_size, lr, momentum, weight_decay, moving_average_decay):
        super().__init__()
        self.save_hyperparameters()

        # Base encoder f(.): ResNet-50 without its average pool and fc layers,
        # so it outputs the spatial feature map the attention pooling needs.
        self.backbone = torchvision.models.resnet50()
        in_features = self.backbone.fc.in_features  # 2048 for ResNet-50
        self.backbone = nn.Sequential(*list(self.backbone.children())[:-2])

        # Attention pooling layer a(.): produces the representation vector h
        self.attn_pooling = AttentionWeightedPooling(in_features)

        # Projection head p(.) --> Linear->BN->ReLU->Linear
        self.student_projector = MLP(in_features, hidden_dim, projection_size)  # (2048, 4096, 512)

        # Prediction head q(.) --> used by the student to match the teacher's projection
        self.student_predictor = MLP(projection_size, hidden_dim, projection_size)  # (512, 4096, 512)

        # Teacher model: starts as an exact copy of the student projector.
        # `copy` is imported here because the notebook's import cell does not
        # provide it (the original raised NameError on this line).
        import copy
        self.teacher_projector = copy.deepcopy(self.student_projector)
        # The teacher is only ever updated through the EMA, never by gradients.
        self.teacher_projector.requires_grad_(False)

        self.target_ema_updater = EMA(moving_average_decay)
        self.moving_average_decay = moving_average_decay

    def configure_optimizers(self):
        """Return an SGD optimizer over the trainable (student) parameters only."""
        trainable = [p for p in self.parameters() if p.requires_grad]
        optimizer = optim.SGD(trainable,
                              lr=self.hparams.lr,
                              weight_decay=self.hparams.weight_decay,
                              momentum=self.hparams.momentum)

        return optimizer

    @torch.no_grad()
    def update_moving_average(self):
        """
        Updates the weights of the teacher model as a moving average of the student model's weights.

        teacher = teacher * alpha + (1 - alpha) * student, with alpha = moving_average_decay.
        """
        for student_params, teacher_params in zip(self.student_projector.parameters(), self.teacher_projector.parameters()):
            teacher_params.data = self.target_ema_updater.update_average(teacher_params.data, student_params.data)

    def initializes_target_network(self):
        '''
        Initializes the target (teacher) network with the same weights as the student model.
        Ensures the teacher's parameters do not require gradient updates.
        '''
        for student_params, teacher_params in zip(self.student_projector.parameters(), self.teacher_projector.parameters()):
            teacher_params.data.copy_(student_params.data)  # initialize
            teacher_params.requires_grad = False  # not updated by gradient

    def on_train_start(self):
        """Synchronize the teacher with the student at the start of training."""
        self.initializes_target_network()

    def forward(self, x):
        '''
        Forward pass through the student network: backbone, attention pooling,
        projector and predictor. Returns the student prediction.
        '''
        features = self.backbone(x)
        attn_features = self.attn_pooling(features)
        student_projection = self.student_projector(attn_features)
        student_prediction = self.student_predictor(student_projection)
        return student_prediction

    def shared_step(self, img1, img2):
        """Compute the symmetric BYOL loss for two augmented views of the same batch."""
        # Student representations: backbone + attention pooling
        attn_feats1 = self.attn_pooling(self.backbone(img1))
        attn_feats2 = self.attn_pooling(self.backbone(img2))

        # Student projections followed by the predictor MLP
        student_pred1 = self.student_predictor(self.student_projector(attn_feats1))
        student_pred2 = self.student_predictor(self.student_projector(attn_feats2))

        # Teacher projections: no gradients flow through this branch.
        with torch.no_grad():
            teacher_proj1 = self.teacher_projector(attn_feats1)
            teacher_proj2 = self.teacher_projector(attn_feats2)

        # Symmetric loss: each view's prediction is matched to the other view's target.
        loss = L2_loss(student_pred1, teacher_proj2)
        loss += L2_loss(student_pred2, teacher_proj1)

        return loss.mean()

    def _split_views(self, batch):
        """Return the two views of a batch as float tensors on the module's device."""
        img1, img2 = batch['image'][0], batch['image'][1]
        # self.device instead of a notebook-level `device` global, so the
        # module follows whatever device the Trainer placed it on.
        img1 = img1.to(device=self.device, dtype=torch.float)
        img2 = img2.to(device=self.device, dtype=torch.float)
        return img1, img2

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss, then refresh the teacher via EMA."""
        img1, img2 = self._split_views(batch)

        loss = self.shared_step(img1, img2)
        self.log('train_loss', loss)

        # Update the teacher model. This runs before the optimizer step of the
        # current batch, so the teacher sees the pre-step student weights.
        self.update_moving_average()

        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss (no teacher update)."""
        img1, img2 = self._split_views(batch)

        loss = self.shared_step(img1, img2)
        self.log('val_loss', loss)

        return loss

### Pre-Training 

In [16]:
def train_byol(batch_size, max_epochs=5, **kwargs):
    """Pre-train a BYOL model and return the checkpoint with the lowest validation loss.

    Args:
        batch_size (int): Batch size for both the training and validation loaders.
        max_epochs (int, optional): Upper bound on the number of epochs;
            early stopping on ``val_loss`` may end training sooner.
        **kwargs: BYOL hyperparameters (``hidden_dim``, ``projection_size``,
            ``lr``, ``momentum``, ``weight_decay``, ``moving_average_decay``).
            Any value given here overrides the default listed below.

    Returns:
        BYOL: The model restored from the best checkpoint.
    """
    # mode='min': val_loss is a loss, so the best checkpoint is the one with
    # the smallest value (mode='max' kept the worst one).
    checkpoint_cb = ModelCheckpoint(save_weights_only=True, mode='min', monitor='val_loss')
    early_stop_cb = EarlyStopping(monitor="val_loss", patience=3, mode="min")

    trainer = pl.Trainer(default_root_dir='/home/byol_models',
                         accelerator="gpu",
                         devices=1,
                         max_epochs=max_epochs,
                         callbacks=[checkpoint_cb, LearningRateMonitor('epoch'), early_stop_cb])
    trainer.logger._default_hp_metric = None  # Optional logging argument that we don't need

    train_loader = DataLoader(unlabeled_train_dataset, batch_size=batch_size, shuffle=True,
                              drop_last=True, pin_memory=True, num_workers=NUM_WORKERS)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False,
                            drop_last=False, pin_memory=True, num_workers=NUM_WORKERS)

    pl.seed_everything(42)  # To be reproducible

    # Defaults match the previously hard-coded values; caller-supplied kwargs
    # now take precedence instead of being silently ignored.
    model_hparams = {
        'hidden_dim': 4096,
        'projection_size': 512,
        'lr': 0.03,
        'momentum': 0.9,
        'weight_decay': 0.0004,
        'moving_average_decay': 0.996,
    }
    model_hparams.update(kwargs)

    # Initialize the BYOL model
    byol_model = BYOL(**model_hparams)

    trainer.fit(byol_model, train_loader, val_loader)

    # Load best checkpoint after training
    model = BYOL.load_from_checkpoint(trainer.checkpoint_callback.best_model_path)

    return model

In [None]:
# Launch pre-training: run settings first, then the BYOL hyperparameters.
byol_model = train_byol(
    batch_size=64,
    max_epochs=100,
    hidden_dim=4096,
    projection_size=512,
    lr=0.03,
    momentum=0.9,
    weight_decay=0.0004,
    moving_average_decay=0.996,
)


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
Global seed set to 42
You are using a CUDA device ('NVIDIA GeForce RTX 3090') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name              | Type                     | Params
---------------------------------------------------------------
0 | backbone          | Sequential               | 23.5 M
1 | attn_pooling      | AttentionWeightedPooling | 2.8 M 
2 | student_projector | MLP                      | 10.5 M
3 | student_predictor | MLP                      | 4.2 M 
4 | teacher_projector | MLP                      | 10.5 M
-------

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
# Save the model
save_dir = '/home/byol_models'
os.makedirs(save_dir, exist_ok=True)

# Define the path
save_path = os.path.join(save_dir, 'byol_attention_model.pt')

# Save the state dictionary of the model returned by train_byol.
# (The original referenced `simclr_model`, which is never defined in this notebook.)
torch.save(byol_model.state_dict(), save_path)

print(f"Best model saved to {save_path}")