In [1]:
# Enable autoreload in Jupyter so that edits to imported modules (e.g. the
# project's `src` package) are picked up without restarting the kernel.
# Mode 2 re-imports all modules before each cell is executed.
%load_ext autoreload
%autoreload 2

# Imports and Seed Management

In [None]:
import os

# Set environment variables for reproducibility BEFORE importing torch
os.environ['PYTHONHASHSEED'] = '51'
os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

import sys
from pathlib import Path

# Add project root to sys.path for module imports
PROJECT_ROOT = Path.cwd().parent
sys.path.append(str(PROJECT_ROOT))

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import ConcatDataset, DataLoader
import fiftyone as fo
from torch.optim import Adam
from pathlib import Path

from src.datasets import CustomTorchImageDataset
from src.models import (
    ContrastivePretraining,
    Embedder,
    Projector,
    RGB2LiDARClassifier,
)
from src.training import train_model
from src.utils import (
    set_seeds,
    create_deterministic_training_dataloader,
    get_rgb_input,
    get_mm_intermediate_inputs,
    infer_model,
)

set_seeds(51)

All random seeds set to 51 for reproducibility
All random seeds set to 51 for reproducibility


# Dataset Loading

In [3]:
# Image size handed to CustomTorchImageDataset further down.
IMG_SIZE = 64

# The exported dataset lives in a folder next to the notebook directory.
dataset_name = "cilp_assessment"
dataset_dir = Path.cwd().parent / dataset_name

# Import the on-disk export (FiftyOneDataset format) into a FiftyOne dataset.
dataset = fo.Dataset.from_dir(
    dataset_dir=dataset_dir,
    dataset_type=fo.types.FiftyOneDataset,
)

print(f"Total samples in dataset: {len(dataset)}")

Importing samples...
 100% |█████████████| 32253/32253 [1.0s elapsed, 0s remaining, 31.6K samples/s]         
Total samples in dataset: 10751


Extract the train and validation splits of the dataset.

In [4]:
# Views over the samples tagged for each split.
train_split = dataset.match_tags("train")
val_split = dataset.match_tags("validation")

# Keep 10% of each split; the fixed seed makes the selection repeatable.
train_dataset = train_split.take(int(0.1 * len(train_split)), seed=51)
val_dataset = val_split.take(int(0.1 * len(val_split)), seed=51)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

Training samples: 895
Validation samples: 179


Wrap the FiftyOne views in custom torch datasets so they can be used with a DataLoader.

In [None]:
# Prefer the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

print("Device: ", device)

# Build the torch datasets in a fixed order (train first, then validation),
# both with the same image size.
torch_train_dataset, torch_val_dataset = (
    CustomTorchImageDataset(fiftyone_dataset=view, img_size=IMG_SIZE)
    for view in (train_dataset, val_dataset)
)

CustomTorchImageDataset initialized with 895 samples.
CustomTorchImageDataset initialized with 179 samples.


Create the DataLoaders. The training loader uses a deterministic setup to make the results reproducible.

In [None]:
# Mini-batch size shared by the training and validation loaders.
loader_batch_size = 32

# Training: shuffled, created through the project's deterministic helper.
train_dataloader = create_deterministic_training_dataloader(
    torch_train_dataset, batch_size=loader_batch_size, shuffle=True
)

# Validation: plain DataLoader with a fixed sample order.
val_dataloader = DataLoader(
    torch_val_dataset, batch_size=loader_batch_size, shuffle=False
)

# Contrastive Pretraining

First, we create the embedders for the RGB and LiDAR data and train them using contrastive pretraining.

In [None]:
# Size of the embedding produced by each modality encoder.
CILP_EMB_SIZE = 200

# Both encoders are built with identical settings; only their role differs.
embedder_kwargs = dict(in_ch=4, emb_size=CILP_EMB_SIZE)
img_embedder = Embedder(**embedder_kwargs).to(device)
lidar_embedder = Embedder(**embedder_kwargs).to(device)

We define a custom loss function for contrastive pretraining that aligns the embeddings of the two modalities.

In [None]:
class ContrastiveLoss(nn.Module):
    """
    Symmetric cross-entropy loss for matching two modalities within a batch.

    Each of the two inputs is treated as a matrix of class scores in which
    row ``i`` should score highest at index ``i`` (i.e. the matching sample of
    the other modality). The final loss is the mean of the two cross-entropy
    terms.
    """
    def __init__(self):
        super().__init__()
        self.loss_img = nn.CrossEntropyLoss()
        self.loss_lidar = nn.CrossEntropyLoss()

    def forward(self, embeddings: tuple[torch.Tensor, torch.Tensor], _=None) -> torch.Tensor:
        """
        Compute the averaged image/LiDAR cross-entropy loss.

        Args:
            embeddings: Pair ``(img_scores, lidar_scores)`` as produced by the
                contrastive model. Each needs at least ``batch_size`` columns,
                because the targets are ``0 .. batch_size - 1``.
            _: Unused placeholder so the loss can be called like the other
                loss functions (prediction, target).

        Returns:
            torch.Tensor: Scalar loss, the mean of both cross-entropy terms.
        """
        img_embeddings, lidar_embeddings = embeddings

        batch_size = img_embeddings.size(0)
        # Create the targets directly on the device of the incoming scores
        # instead of relying on a notebook-global `device`; this keeps the
        # loss usable wherever the model output lives.
        ground_truth = torch.arange(
            batch_size, dtype=torch.long, device=img_embeddings.device
        )

        loss_img = self.loss_img(img_embeddings, ground_truth)
        loss_lidar = self.loss_lidar(lidar_embeddings, ground_truth)

        return (loss_img + loss_lidar) / 2

We use contrastive pretraining to pretrain our embedders for RGB and LiDAR data.

In [None]:
epochs = 2

# Joint model wrapping both encoders; optimised with Adam and the symmetric
# contrastive loss defined in this notebook.
CILP_model = ContrastivePretraining(img_embedder, lidar_embedder).to(device)
optimizer = Adam(CILP_model.parameters(), lr=0.0001)
loss_func = ContrastiveLoss()

cilp_save_path = Path.cwd().parent / "checkpoints" / "04_cilp_contrastive_best.pth"
# Make sure the checkpoint folder exists so saving cannot fail on a fresh checkout.
cilp_save_path.parent.mkdir(parents=True, exist_ok=True)

print("Training contrastive pretraining model...")
# Re-seed right before training so this cell is reproducible on its own.
set_seeds(51)
mm_cilp_train_loss, mm_cilp_valid_loss, mm_cilp_train_time = train_model(
    CILP_model,
    optimizer,
    loss_func,
    get_mm_intermediate_inputs,
    epochs,
    train_dataloader,
    val_dataloader,
    save_path=cilp_save_path,
)

# Report the lowest validation loss seen across the epochs.
print("Validation loss: ", np.min(mm_cilp_valid_loss))

Training contrastive pretraining model...


[34m[1mwandb[0m: Currently logged in as: [33mkarl-schuetz[0m ([33mkarl-schuetz-hasso-plattner-institut[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here
here


0,1
epoch,▁█
learning_rate,▁▁
train_loss,█▁
valid_loss,▁█

0,1
epoch,2.0
learning_rate,0.0001
train_loss,3.46436
valid_loss,3.37909


We load the best-performing model and freeze all of its parameters.

In [None]:
# Restore the checkpoint written during contrastive pretraining, mapping the
# tensors onto the device that is active in this session.
CILP_model.load_state_dict(torch.load(cilp_save_path, map_location=device))

# Freeze the pretrained encoders. The flag has to be set on each parameter;
# assigning `requires_grad` on the module itself only creates an unrelated
# attribute and leaves every parameter trainable.
for param in CILP_model.parameters():
    param.requires_grad = False

# Cross-Modal Projector

The projector takes embedded RGB features as input. We obtain these embeddings using the best-performing RGB encoder from the contrastive pretraining stage.

In [None]:
def get_projector_inputs(batch):
    """
    Build the projector's model inputs from a raw batch.

    Args:
        batch: A three-element batch whose first element is the RGB image
            tensor; the remaining two elements are not used here.

    Returns:
        A single-element list holding the image embeddings produced by the
        CILP model's image embedder.
    """
    rgb_img, _, _ = batch

    # The image encoder only acts as a fixed feature extractor for the
    # projector, so skip autograd bookkeeping for its forward pass.
    with torch.no_grad():
        img_embs = CILP_model.img_embedder(rgb_img).to(device)

    return [img_embs]

For the loss computation, we compare the projected RGB embeddings with the corresponding LiDAR embeddings. To simplify the training loop, we introduce an auxiliary loss function that computes the LiDAR embeddings on the fly. Alternatively, one could precompute the embeddings and construct a separate dataset.

In [None]:
class ProjectorLoss(nn.Module):
    """
    Projector loss that aligns projected image embeddings with LiDAR embeddings.

    The LiDAR target embeddings are computed on the fly with the CILP model's
    LiDAR embedder and compared to the projector output using mean squared
    error (MSE).
    """
    def __init__(self):
        super().__init__()
        self.loss_func = nn.MSELoss()

    def forward(self, embeddings: torch.Tensor, lidar_data: torch.Tensor) -> torch.Tensor:
        """
        Compute the MSE between projected embeddings and LiDAR embeddings.

        Args:
            embeddings: Output of the projector for the current batch.
            lidar_data: Raw LiDAR input of the same batch; it is embedded here
                to obtain the regression target.

        Returns:
            torch.Tensor: Scalar MSE loss.
        """
        # The LiDAR embeddings are a fixed regression target, so no gradient
        # should be tracked through the LiDAR encoder.
        with torch.no_grad():
            lidar_embeddings = CILP_model.lidar_embedder(lidar_data).to(device)

        return self.loss_func(embeddings, lidar_embeddings)

We train a projector that maps RGB embeddings to LiDAR embeddings. Since this is a regression task, we use a mean squared error (MSE) loss. Because the target is an embedding rather than a class label, we need a custom loss function.

In [None]:
epochs = 2

# Move the projector to the same device as the encoders; its inputs are the
# image embeddings that `get_projector_inputs` places on `device`.
projector = Projector(in_emb_size=CILP_EMB_SIZE, out_emb_size=CILP_EMB_SIZE).to(device)
projector_opt = torch.optim.Adam(projector.parameters())
# We want to minimize the MSE between the projected RGB embeddings and the Lidar embeddings
projector_loss_func = ProjectorLoss()

projector_save_path = Path.cwd().parent / "checkpoints" / "04_mm_projector_best.pth"
# Make sure the checkpoint folder exists so saving cannot fail on a fresh checkout.
projector_save_path.parent.mkdir(parents=True, exist_ok=True)

print("Training projector model...")
# Re-seed right before training so this cell is reproducible on its own.
set_seeds(51)
mm_projector_train_loss, mm_projector_valid_loss, mm_projector_train_time = train_model(
    projector,
    projector_opt,
    projector_loss_func,
    get_projector_inputs,
    epochs,
    train_dataloader,
    val_dataloader,
    target_idx=1, # We want to predict lidar embeddings
    save_path=projector_save_path,
)

# Report the lowest validation loss seen across the epochs.
print("Validation loss: ", np.min(mm_projector_valid_loss))

SyntaxError: invalid syntax. Perhaps you forgot a comma? (2119248113.py, line 16)

# Final Classifier

Load the saved CILP and projector checkpoints.

In [None]:
# Restore both saved checkpoints, mapping the tensors onto the active device
# so the cell also works when the checkpoints were written on another device.
CILP_model.load_state_dict(torch.load(cilp_save_path, map_location=device))
projector.load_state_dict(torch.load(projector_save_path, map_location=device))

We train the RGB2LiDARClassifier, which embeds and projects RGB images into the LiDAR embedding space and then applies a lightweight LiDAR-based classifier to produce the final predictions.

In [None]:
epochs = 2

# Assemble the classifier from the pretrained image embedder and the trained
# projector, and place the whole module on the training device (any layers
# the classifier creates itself would otherwise stay on the CPU).
rgb_2_lidar_classifier = RGB2LiDARClassifier(
    img_embedder=CILP_model.img_embedder,
    projector=projector,
).to(device)
rgb_2_lidar_classifier_opt = torch.optim.Adam(rgb_2_lidar_classifier.parameters())
bce_loss_func = nn.BCEWithLogitsLoss()
rgb_2_lidar_save_path = Path.cwd().parent / "checkpoints" / "04_rgb2lidar_classifier.pth"
# Make sure the checkpoint folder exists so saving cannot fail on a fresh checkout.
rgb_2_lidar_save_path.parent.mkdir(parents=True, exist_ok=True)

print("Training RGB2LiDARClassifier model...")
# Re-seed right before training so this cell is reproducible on its own.
set_seeds(51)
mm_rgb2lidar_train_loss, mm_rgb2lidar_valid_loss, mm_rgb2lidar_train_time = train_model(
    rgb_2_lidar_classifier,
    rgb_2_lidar_classifier_opt,
    bce_loss_func,
    get_rgb_input,
    epochs,
    train_dataloader,
    val_dataloader,
    save_path=rgb_2_lidar_save_path,
)

# Report the lowest validation loss seen across the epochs.
print("Validation loss: ", np.min(mm_rgb2lidar_valid_loss))

Create a concatenated dataset for inference.

In [None]:
# Chain the training and validation datasets back to back (training first).
inference_parts = [torch_train_dataset, torch_val_dataset]
concat_dataset = ConcatDataset(inference_parts)
print(f"Total samples in concat dataset: {len(concat_dataset)}")

# Inference only: keep the sample order fixed.
concat_dataloader = DataLoader(concat_dataset, batch_size=32, shuffle=False)

Load best model and calculate accuracy.

In [None]:
# Restore the saved classifier checkpoint, mapping the tensors onto the
# device that is active in this session.
rgb_2_lidar_classifier.load_state_dict(
    torch.load(rgb_2_lidar_save_path, map_location=device)
)

# NOTE(review): the loader below covers training AND validation samples, so
# this number is not a held-out estimate of generalisation.
accuracy, _ = infer_model(
    rgb_2_lidar_classifier,
    concat_dataloader,
    get_rgb_input,
)

print(f"Final accuracy on combined train and validation set: {accuracy*100:.2f}%")