In [1]:
# Enable autoreload in Jupyter
%load_ext autoreload
%autoreload 2

# Imports and Seed Management

In [2]:
import os

# Set environment variables for reproducibility BEFORE importing torch
os.environ['PYTHONHASHSEED'] = '51'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import sys
from pathlib import Path

# Add project root to sys.path for module imports
PROJECT_ROOT = Path.cwd().parent
sys.path.append(str(PROJECT_ROOT))

import wandb
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader
import fiftyone as fo
from torch.optim import Adam
from pathlib import Path
from tabulate import tabulate

from src.datasets import CustomTorchImageDataset
from src.models import ConcatIntermediateNet, ConcatIntermediateNetWithStride
from src.training import train_model
from src.utils import (
    set_seeds,
    create_deterministic_training_dataloader,
    infer_model,
    get_mm_intermediate_inputs,
)

set_seeds(51)

PROJECT_NAME = "cilp-extended-assessment"

All random seeds set to 51 for reproducibility


# Dataset Loading

In [3]:
IMG_SIZE = 64

dataset_name = "cilp_assessment"

# Load the FiftyOne dataset from disk
dataset = fo.Dataset.from_dir(
    dataset_dir=Path.cwd().parent / dataset_name,
    dataset_type=fo.types.FiftyOneDataset,
)

print(f"Total samples in dataset: {len(dataset)}")

Importing samples...
 100% |███████████████| 3228/3228 [93.1ms elapsed, 0s remaining, 34.7K samples/s]   
Total samples in dataset: 1076


Extract train and test split of the dataset.

In [4]:
train_dataset = dataset.match_tags("train")
val_dataset = dataset.match_tags("validation")

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 897
Validation samples: 179


Generate custom torch datasets to use dataloader.

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print("Device: ", device)

torch_train_dataset = CustomTorchImageDataset(
    fiftyone_dataset=train_dataset,
    img_size=IMG_SIZE,
    device=device,
)

torch_val_dataset = CustomTorchImageDataset(
    fiftyone_dataset=val_dataset,
    img_size=IMG_SIZE,
    device=device,
)

Device:  cpu
CustomTorchImageDataset initialized with 897 samples.
CustomTorchImageDataset initialized with 179 samples.


Create a DataLoader and use a deterministic setup for training to make the results reproducible

In [6]:
train_dataloader = create_deterministic_training_dataloader(
    torch_train_dataset,
    batch_size=32,
    shuffle=True,
)

val_dataloader = DataLoader(
    torch_val_dataset,
    batch_size=32,
    shuffle=False,
)

Create a concatinated dataset for inference.

In [7]:
concat_dataset = ConcatDataset([torch_train_dataset, torch_val_dataset])
print(f"Total samples in concat dataset: {len(concat_dataset)}")

concat_dataloader = DataLoader(
    concat_dataset,
    batch_size=32,
    shuffle=False,
)

Total samples in concat dataset: 1076


# Hyperparameters

For the loss function, we use the same one as in the assessment: **BCEWithLogitsLoss**. This loss works well with a single output neuron for binary classification. In later tasks, we set *num_classes=1* to ensure the model has only one output neuron, which aligns with this loss function.

In [8]:
loss_func = nn.BCEWithLogitsLoss()

We initialize our table for the comparison of the different architectures at the end.

In [9]:
table = [
    ["Metric", "Validation Loss", "Parameters (M)", "Training Time", "Final Accuracy"]
]

We define hyperparameters that are shared between both architectures.

In [10]:
epochs = 2
rgb_ch = 4
xyz_ch = 4
lr = 0.0001

# MaxPool2d

First, we start with the **MaxPool2d** version of the fusion model. This was already trained in the last notebook. To make them more independent, we train it here again.

In [11]:
mm_max_pool_net = ConcatIntermediateNet(rgb_ch, xyz_ch).to(device)
mm_max_pool_opt = Adam(mm_max_pool_net.parameters(), lr=lr)
mm_max_pool_save_path = Path.cwd().parent / "checkpoints" / "03_mm_max_pool_model.pth"
mm_max_pool_run = wandb.init(project=PROJECT_NAME, name=f"{ConcatIntermediateNet.__name__}")

print("Training mm_max_pool_net")
set_seeds(51)
mm_max_pool_train_loss, mm_max_pool_valid_loss, mm_max_pool_train_time = train_model(
    mm_max_pool_net,
    mm_max_pool_opt,
    loss_func,
    get_mm_intermediate_inputs,
    epochs,
    train_dataloader,
    val_dataloader,
    save_path=mm_max_pool_save_path,
    run=mm_max_pool_run,
)
mm_max_pool_num_params = mm_max_pool_net.get_number_of_parameters() / 1e6  # in millions

mm_max_pool_run.finish()

[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Training mm_max_pool_net
All random seeds set to 51 for reproducibility


0,1
epoch,▁█
learning_rate,▁▁
train_loss,█▁
valid_loss,▁█

0,1
epoch,2.0
learning_rate,0.0001
train_loss,0.23266
valid_loss,0.21223


Load best model and calculate accuracy.

In [12]:
best_mm_max_pool_model = ConcatIntermediateNet(rgb_ch, xyz_ch).to(device)
best_mm_max_pool_model.load_state_dict(torch.load(mm_max_pool_save_path))

mm_max_pool_accuracy, _ = infer_model(
    best_mm_max_pool_model,
    concat_dataloader,
    get_mm_intermediate_inputs,
)

Add metrics to table.

In [13]:
table.append([
    "MaxPool2d",
    np.min(mm_max_pool_valid_loss),
    mm_max_pool_num_params,
    mm_max_pool_train_time,
    mm_max_pool_accuracy,
])

# Strided Conv

Now, we train the model using the **Strided Conv** version. Here, we removed the pooling layers and added stride to the convolutional layers.

In [14]:
mm_stride_net = ConcatIntermediateNetWithStride(rgb_ch, xyz_ch).to(device)
mm_stride_opt = Adam(mm_stride_net.parameters(), lr=lr)
mm_stride_save_path = Path.cwd().parent / "checkpoints" / "03_mm_stride_model.pth"
mm_stride_run = wandb.init(project=PROJECT_NAME, name=f"{ConcatIntermediateNetWithStride.__name__}")

print("Training mm_stride_net")
set_seeds(51)
mm_stride_train_loss, mm_stride_valid_loss, mm_stride_train_time = train_model(
    mm_stride_net,
    mm_stride_opt,
    loss_func,
    get_mm_intermediate_inputs,
    epochs,
    train_dataloader,
    val_dataloader,
    save_path=mm_stride_save_path,
    run=mm_stride_run,
)
mm_stride_num_params = mm_stride_net.get_number_of_parameters() / 1e6  # in millions

mm_stride_run.finish()

Training mm_stride_net
All random seeds set to 51 for reproducibility


0,1
epoch,▁█
learning_rate,▁▁
train_loss,█▁
valid_loss,█▁

0,1
epoch,2.0
learning_rate,0.0001
train_loss,0.23844
valid_loss,0.20609


Load best model and calculate accuracy.

In [15]:
best_mm_stride_model = ConcatIntermediateNetWithStride(rgb_ch, xyz_ch).to(device)
best_mm_stride_model.load_state_dict(torch.load(mm_stride_save_path))

mm_stride_accuracy, _ = infer_model(
    best_mm_stride_model,
    concat_dataloader,
    get_mm_intermediate_inputs,
)

Add metrics to table.

In [16]:
table.append([
    "Strid-2 Conv2d",
    np.min(mm_stride_valid_loss),
    mm_stride_num_params,
    mm_stride_train_time,
    mm_stride_accuracy,
])

# Comparison

We calculate the differences for all our metrics.

In [17]:
differences = ["Difference"]
for i in range(1, len(table[0])):
    differences.append(table[2][i] - table[1][i])
table.append(differences)

Print the comparision table.

In [18]:
rows = list(zip(*table)) # transpose for tabulate
print(tabulate(rows[1:], headers=rows[0], tablefmt="grid"))

+-----------------+-------------+------------------+--------------+
| Metric          |   MaxPool2d |   Strid-2 Conv2d |   Difference |
| Validation Loss |   0.207008  |       0.20609    | -0.000918161 |
+-----------------+-------------+------------------+--------------+
| Parameters (M)  |  13.0159    |      13.0159     |  0           |
+-----------------+-------------+------------------+--------------+
| Training Time   |   7.09494   |       4.75805    | -2.33689     |
+-----------------+-------------+------------------+--------------+
| Final Accuracy  |   0.0288104 |       0.00743494 | -0.0213755   |
+-----------------+-------------+------------------+--------------+


## Theoretical Differences
Both approaches reduce the spatial size of an activation map, but they do so in different ways.

MaxPool2D selects the maximum value within each 2×2 window (with kernel size of 2). This halves the height and width of the activation map without learning any new parameters. Pooling is therefore parameter-free and purely statistical.

Strided Convolution reduces spatial resolution by moving the convolution filter. The stride defines how many pixels the filter shifts after each convolution step. A stride of 2 skips every second position in both spatial dimensions. Unlike pooling, this operation is learnable because the convolution weights are trained. The downsampling is therefore coupled with feature extraction.

## Impact on Gradient Flow and learned Features
Because MaxPool2D is not learnable, gradients do not pass through pooling weights (there are none). Gradients only flow back to the max-selected elements, meaning some activations receive no gradient signal. This can make the model slightly more selective and introduce sparsity.

In contrast, Strided Convolution learns both feature extraction and downsampling jointly. This means, that gradients propagate through the convolution weights and more elements contribute to the backward pass.

## Recommendation with Justification
Use Strided Convolution when you want:

the model to learn how to downsample

richer, more flexible feature learning

tighter control over how spatial information is preserved

Use MaxPool2D when you want:

a simpler, parameter-free downsampling method

stronger spatial invariance and feature sparsity

a slight regularization effect