In [1]:
import multiprocessing
import os

import torch
import torch.nn as nn
import torchvision
import pytorch_lightning as pl

from lightly.data import LightlyDataset
from lightly.data import SwaVCollateFunction
from lightly.loss import SwaVLoss
from lightly.models.modules import SwaVProjectionHead, SwaVPrototypes
from lightly.models.utils import batch_shuffle, batch_unshuffle


In [8]:
# ==========================================
# 1. Configuration (Thesis Hyperparameters)
# ==========================================
# Hardware config
BATCH_SIZE = 128         # Fits comfortably on RTX 5090 (32GB VRAM)
NUM_WORKERS = 12         # DataLoader worker processes (i9-14900K)
INPUT_SIZE = 224         # Global crop size in pixels
MAX_EPOCHS = 100

# Model config
NUM_FTRS = 2048          # ResNet-50 feature dim (use 512 if switching to ResNet-18)
PROJ_HIDDEN_DIM = 2048   # Hidden layer in projection head
PROJ_OUTPUT_DIM = 128    # Projection (unit-sphere) dimension, standard for SwAV
N_PROTOTYPES = 50        # Number of clusters (set > expected classes, e.g. 50 for 10 classes)
QUEUE_LENGTH = 3840      # Stored high-res features. With 2 global views x BATCH_SIZE=128,
                         # each step adds 256 rows, so the queue holds the last 15 steps.
                         # Important on a single GPU, where the batch alone is small.

# Folder containing the extracted frames. Set the EXONET_DATASET_PATH environment
# variable to point elsewhere without editing the notebook.
DATASET_PATH = os.environ.get("EXONET_DATASET_PATH", r"F:\ExoNet_Images_curated")

In [3]:
# ==========================================
# 2. The SwAV Lightning Module
# ==========================================
class SwAVModel(pl.LightningModule):
    """SwAV self-supervised model: ResNet-50 backbone + projection head + prototypes.

    Each training step takes a multi-crop batch. The first two crops are treated as
    high-resolution views and the remaining crops as low-resolution views. Both are
    scored against the prototypes and trained with lightly's ``SwaVLoss``.

    A FIFO queue of past high-resolution features (at most ``QUEUE_LENGTH`` rows) is
    filled from the very first step. Once ``use_queue`` is enabled, the queue's
    prototype scores are passed to the loss as ``queue_outputs``.

    Args:
        queue_start_epoch: ``use_queue`` is switched on at the end of the epoch whose
            index is >= this value, so the queue is first used in the following epoch.
            Defaults to 15 (the previously hard-coded value).
    """

    def __init__(self, queue_start_epoch: int = 15):
        super().__init__()
        # 1. Architecture: ResNet-50 with its final (fc) child removed.
        resnet = torchvision.models.resnet50()
        self.backbone = nn.Sequential(*list(resnet.children())[:-1])
        self.projection_head = SwaVProjectionHead(NUM_FTRS, PROJ_HIDDEN_DIM, PROJ_OUTPUT_DIM)
        self.prototypes = SwaVPrototypes(PROJ_OUTPUT_DIM, n_prototypes=N_PROTOTYPES)

        # 2. Loss & queue (single GPU: no cross-process gather inside Sinkhorn).
        self.criterion = SwaVLoss(sinkhorn_gather_distributed=False)
        # Starts with 0 rows and grows to at most `queue_length` rows (see `_enqueue`).
        self.register_buffer("queue", torch.zeros(0, PROJ_OUTPUT_DIM))
        self.queue_length = QUEUE_LENGTH
        self.queue_start_epoch = queue_start_epoch
        self.use_queue = False

    def forward(self, x):
        """Map a batch of images to L2-normalized projection features (unit sphere)."""
        features = self.backbone(x).flatten(start_dim=1)
        features = self.projection_head(features)
        # Prototype scores are dot products. They are cosine similarities only if the
        # features are unit length, and the projection head does not normalize by itself.
        return nn.functional.normalize(features, dim=1, p=2)

    def training_step(self, batch, batch_idx):
        """Compute the SwAV loss for one multi-crop batch and update the feature queue.

        ``batch`` is unpacked as ``(crops, _, _)``, where ``crops`` is a sequence of
        image tensors (one per crop). ``crops[:2]`` are used as high-resolution views.
        """
        crops, _, _ = batch

        # The optimizer step moves the prototype vectors off the unit sphere, so
        # re-normalize them before scoring.
        self.prototypes.normalize()

        # 1. Normalized features for all crops.
        multi_crop_features = [self.forward(crop) for crop in crops]

        # 2. Prototype scores (logits) for all crops.
        multi_crop_logits = [self.prototypes(f, step=self.global_step) for f in multi_crop_features]
        high_res_logits = multi_crop_logits[:2]
        low_res_logits = multi_crop_logits[2:]

        # 3. Queue. Give the queued scores to the loss through `queue_outputs` (one
        #    entry per high-res crop) rather than concatenating them onto the high-res
        #    logits. Concatenating produced targets with more rows than the low-res
        #    predictions, which breaks the loss.
        queue_logits = None
        if self.use_queue and len(self.queue) > 0:
            with torch.no_grad():
                queued = self.prototypes(self.queue.clone(), step=self.global_step)
            queue_logits = [queued for _ in high_res_logits]

        # 4. Loss.
        loss = self.criterion(high_res_logits, low_res_logits, queue_outputs=queue_logits)

        # 5. Enqueue this batch's high-res features after the loss, so the queue only
        #    ever holds features from earlier steps.
        self._enqueue(torch.cat(multi_crop_features[:2]))

        self.log("train_loss", loss)
        return loss

    @torch.no_grad()
    def _enqueue(self, features):
        """Append ``features`` to the FIFO queue, keeping only the newest ``queue_length`` rows."""
        if self.queue_length <= 0:
            return
        # Slicing from the end discards the oldest rows, and the buffer can never
        # exceed `queue_length` even when that is not a multiple of the step size.
        self.queue = torch.cat([self.queue, features.detach()])[-self.queue_length:]

    def configure_optimizers(self):
        """SGD (high LR, as commonly used for SwAV/SSL) with per-epoch cosine decay."""
        optimizer = torch.optim.SGD(
            self.parameters(),
            lr=0.6,
            momentum=0.9,
            weight_decay=1e-6,
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=MAX_EPOCHS,
        )
        return [optimizer], [scheduler]

    def on_train_epoch_end(self):
        # Enable the queue once epoch `queue_start_epoch` has finished (to stabilize
        # early training). The queue is therefore first used in the following epoch.
        if self.current_epoch >= self.queue_start_epoch:
            self.use_queue = True

    def on_load_checkpoint(self, checkpoint):
        """Prepare non-parameter state before Lightning loads ``checkpoint["state_dict"]``.

        The queue buffer starts with 0 rows but is saved at whatever length it had.
        Without resizing it to the saved shape first, loading reports a size mismatch
        for ``queue``. ``use_queue`` is not in the state dict, so it is re-derived from
        the saved epoch using the same rule as ``on_train_epoch_end``.
        """
        saved_queue = checkpoint.get("state_dict", {}).get("queue")
        if saved_queue is not None:
            self.queue = torch.zeros_like(saved_queue, device=self.queue.device)
        self.use_queue = checkpoint.get("epoch", -1) >= self.queue_start_epoch

In [9]:
# ==========================================
# 3. Data Loading & Multi-Crop Augmentation
# ==========================================
# SwaVCollateFunction builds the "2 global + 6 local" multi-crop views for each batch.
# SwAVModel.training_step treats the first 2 crops as the high-resolution views.
collate_fn = SwaVCollateFunction(
    crop_sizes=[INPUT_SIZE, 96],      # Global vs local crop size in pixels
    crop_counts=[2, 6],               # 2 global views, 6 local views
    crop_min_scales=[0.14, 0.05],     # Fraction of the image each crop may cover (min)
    crop_max_scales=[1.0, 0.14],      # ... and (max)
)

# LightlyDataset expects a folder of images (e.g. <DATASET_PATH>/frame1.jpg).
dataset = LightlyDataset(input_dir=DATASET_PATH)
if len(dataset) < BATCH_SIZE:
    # With drop_last=True, fewer images than one batch would yield zero training steps.
    raise ValueError(
        f"Found {len(dataset)} images in {DATASET_PATH!r}; need at least BATCH_SIZE={BATCH_SIZE}."
    )

# NOTE(review): "DataLoader worker ... exited unexpectedly" means a worker process died
# without reporting a Python error. On Windows that is often memory or paging-file
# exhaustion with many spawned workers; confirm by lowering NUM_WORKERS. NUM_WORKERS = 0
# runs loading in the main process and surfaces the underlying error.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate_fn,
    drop_last=True,
    num_workers=NUM_WORKERS,
    # DataLoader rejects persistent_workers=True when num_workers == 0.
    persistent_workers=NUM_WORKERS > 0,
    # Pinned host memory only helps (and only applies) when copying to a CUDA device.
    pin_memory=torch.cuda.is_available(),
)

In [10]:
# ==========================================
# 4. Training
# ==========================================
# Seed the global RNGs (and DataLoader workers) so that weight init, shuffling and crop
# sampling are reproducible across runs.
pl.seed_everything(42, workers=True)

print(f"Training on: {torch.cuda.get_device_name(0)}")

# Keep the checkpoint with the lowest logged `train_loss`, plus the final epoch (last.ckpt).
# NOTE(review): `train_loss` is logged from training_step with default settings, so the
# epoch-end comparison presumably uses a single step's value. That is a noisy criterion,
# and last.ckpt is the safer artifact for downstream use.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    dirpath="checkpoints/",
    filename="swav_resnet50_city_{epoch:02d}",
    save_top_k=1,
    monitor="train_loss",
    mode="min",
    save_last=True,
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="gpu",
    devices=1,
    precision="16-mixed",  # Mixed precision: faster and lighter on VRAM
    callbacks=[checkpoint_callback],
    log_every_n_steps=10,
)

model = SwAVModel()
trainer.fit(model=model, train_dataloaders=dataloader)

print("Training Complete. Model saved in 'checkpoints/'")

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores


Training on: NVIDIA GeForce RTX 5090


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type               | Params | Mode  | FLOPs
-----------------------------------------------------------------------
0 | backbone        | Sequential         | 23.5 M | train | 0    
1 | projection_head | SwaVProjectionHead | 4.5 M  | train | 0    
2 | prototypes      | SwaVPrototypes     | 6.5 K  | train | 0    
3 | criterion       | SwaVLoss           | 0      | train | 0    
-----------------------------------------------------------------------
28.0 M    Trainable params
0         Non-trainable params
28.0 M    Total params
111.901   Total estimated model params size (MB)
160       Modules in train mode
0         Modules in eval mode
0         Total Flops


Training: |                                                                                      | 0/? [00:00<…

RuntimeError: DataLoader worker (pid(s) 40920) exited unexpectedly