# Model Architecture

**Stage 1:** Contrastive Pretraining: CILP_model

**Goal:** Align RGB and LiDAR inputs in a shared 200-D embedding space, so that both modalities are encoded as vectors of the same dimensionality.
```
RGB ----> Img Encoder ----\
                            ----> CLIP-style similarity
LiDAR -> Lidar Encoder ----/
```
**Outcome:** Shared embedding space where matching RGB/LiDAR pairs have high similarity and non-matches low similarity.
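As a rough illustration of the CLIP-style similarity step, the sketch below L2-normalizes both embedding batches and computes a (B, B) matrix of scaled cosine similarities. The function name `clip_style_logits` and the temperature of 0.07 are assumptions made for this example; the notebook's `ContrastivePretraining` module may compute its logits differently.

```python
import torch
import torch.nn.functional as F

def clip_style_logits(img_emb, lidar_emb, temperature=0.07):
    """Return (logits_per_img, logits_per_lidar), each of shape (B, B).

    Hypothetical helper: the temperature value is an assumption.
    """
    img_emb = F.normalize(img_emb, dim=-1)      # unit-length rows
    lidar_emb = F.normalize(lidar_emb, dim=-1)
    logits_per_img = img_emb @ lidar_emb.t() / temperature
    return logits_per_img, logits_per_img.t()

torch.manual_seed(0)
B, D = 4, 200
lidar_emb = torch.randn(B, D)
img_emb = lidar_emb + 0.05 * torch.randn(B, D)   # toy pairs that are nearly aligned

logits_per_img, logits_per_lidar = clip_style_logits(img_emb, lidar_emb)
best_match = logits_per_img.argmax(dim=1)        # index of the most similar LiDAR sample per image
print(best_match.tolist())
```

In a well-aligned space, row i of `logits_per_img` peaks at column i, which is the property the contrastive loss in Stage 1 rewards.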

----------------------------

**Stage 2:** Projector Training: projector

**Goal:** Learn a mapping from CILP RGB embeddings to the LiDAR embeddings used by lidar_cnn:
ℝ²⁰⁰ (CILP RGB embedding) → ℝ³²⁰⁰ (LiDAR-CNN embedding)

The projector learns to make RGB inputs “look like” LiDAR internally: for each paired RGB/LiDAR sample, projected_RGB_embedding ≈ “real” LiDAR embedding.
```
RGB ----> Img Encoder ----> Projector ----> LiDAR embedding
                                     |
                                     v
                             MSE-loss to true LiDAR embedding

```
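The mapping ℝ²⁰⁰ → ℝ³²⁰⁰ with an MSE objective can be sketched as follows. `ToyProjector` and its hidden width of 1024 are hypothetical stand-ins; the actual `Projector` class lives in `src.models` and its layers may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyProjector(nn.Module):
    """Hypothetical stand-in for the notebook's Projector (real layers may differ)."""

    def __init__(self, in_dim=200, out_dim=3200, hidden=1024):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, out_dim),
        )

    def forward(self, x):
        return self.net(x)

torch.manual_seed(0)
toy_projector = ToyProjector()
rgb_emb = torch.randn(8, 200)            # placeholder for frozen CILP RGB embeddings
true_lidar_emb = torch.randn(8, 3200)    # placeholder for frozen lidar_cnn embeddings

pred = toy_projector(rgb_emb)
loss = F.mse_loss(pred, true_lidar_emb)  # regression target: the "real" LiDAR embedding
loss.backward()                          # gradients reach only the projector
```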
----------------------------

**Stage 3:** Final Classifier: RGB2LiDARClassifier

**Goal:** Chain all models together to classify spheres and cubes from RGB images.

The model maps each RGB image to a LiDAR-like representation in the internal feature space and then applies the LiDAR classifier to it.
```
RGB (img) ----> (CILP Img Encoder) ----> 200-D CILP embedding ----> (Projector) ---> 3200-D LiDAR embedding
---> (LiDAR Classifier) ---> cube/sphere

```
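The chain above might be wired together roughly like this. `ToyRGB2LiDAR` and the three small stand-in modules are assumptions for illustration only; the real `RGB2LiDARClassifier` in `src.models` may be structured differently. The head dimensions (3200 → 100 → 1) follow the LiDAR classifier printed later in this notebook.

```python
import torch
import torch.nn as nn

class ToyRGB2LiDAR(nn.Module):
    """Hypothetical wrapper: frozen image encoder -> trainable projector -> frozen LiDAR head."""

    def __init__(self, img_encoder, projector, lidar_head):
        super().__init__()
        self.img_encoder = img_encoder
        self.projector = projector
        self.lidar_head = lidar_head
        for frozen in (self.img_encoder, self.lidar_head):
            for p in frozen.parameters():
                p.requires_grad = False

    def forward(self, rgb):
        with torch.no_grad():
            emb = self.img_encoder(rgb)       # (B, 200) CILP-style embedding
        lidar_like = self.projector(emb)      # (B, 3200) LiDAR-like embedding
        return self.lidar_head(lidar_like)    # (B, 1) logit for cube vs. sphere

# Stand-in modules with the same input/output sizes as the real pipeline
img_encoder = nn.Sequential(nn.Flatten(), nn.Linear(4 * 64 * 64, 200))
projector = nn.Linear(200, 3200)
lidar_head = nn.Sequential(nn.Linear(3200, 100), nn.ReLU(), nn.Linear(100, 1))

toy_model = ToyRGB2LiDAR(img_encoder, projector, lidar_head)
logits = toy_model(torch.randn(2, 4, 64, 64))
trainable = [name for name, p in toy_model.named_parameters() if p.requires_grad]
print(trainable)
```

Only the projector's parameters remain trainable in this sketch, which mirrors the idea of adapting RGB embeddings while reusing the LiDAR classifier unchanged.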

# Setup

In [None]:
%%capture
%pip install wandb weave

In [None]:
%%capture
%pip install fiftyone==1.10.0 sympy==1.12 torch==2.9.0 torchvision==0.20.0 numpy open-clip-torch

## Imports

In [None]:
import os
from pathlib import Path
from google.colab import userdata

import torch
import torch.nn as nn
import torch.nn.functional as F
#import torchvision.transforms as transforms
import torchvision.transforms.v2 as transforms

import wandb

In [None]:
from google.colab import drive
drive.mount('/content/drive')

STORAGE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/")
DATA_PATH = STORAGE_PATH / "multimodal_training_workshop/data/assessment"
TMP_DATA_PATH = Path("/content/data")

Mounted at /content/drive


In [None]:
!cp -r "/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/multimodal_training_workshop/data/assessment" /content/data

In [None]:
%cd "/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/"

In [None]:
from src.utility import set_seeds
from src.datasets import compute_dataset_mean_std_neu, get_cilp_dataloaders
from src.training import compute_class_weights, get_rgb_inputs, train_model, init_wandb, train_with_batch_loss
#from src.visualization import build_fusion_comparison_df, plot_losses
from src.models import CILPBackbone, ContrastivePretraining, Classifier, Projector, EmbedderMaxPool, RGB2LiDARClassifier

## Constants

In [None]:
SEED = 51
NUM_WORKERS = os.cpu_count()  # Number of CPU cores

BATCH_SIZE = 32
IMG_SIZE = 64

CLASSES = ["cubes", "spheres"]
NUM_CLASSES = len(CLASSES)
LABEL_MAP = {"cubes": 0, "spheres": 1}

VALID_BATCHES = 10
N = 12500

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.is_available()

True

In [None]:
# Usage: Call this function at the beginning and before each training phase
set_seeds(SEED)
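For readers without access to `src.utility`, a seeding helper of this kind typically looks something like the sketch below. `seed_everything` is a hypothetical name, and the real `set_seeds` may also seed NumPy or configure cuDNN determinism.

```python
import random
import torch

def seed_everything(seed):
    """Hypothetical helper; the notebook's set_seeds may do more or less than this."""
    random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

seed_everything(51)
first = torch.rand(3)
seed_everything(51)
second = torch.rand(3)   # identical to `first` because the generator was reset
```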

# Weights & Biases Integration

In [None]:
# Load W&B API key from Colab Secrets and make it available as env variable
wandb_key = userdata.get('WANDB_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_key
wandb.login()

# Data Loading and Preparation

In [None]:
## Final: dynamic
# mean, std = compute_dataset_mean_std(root_dir=root, img_size=IMG_SIZE)
mean, std = compute_dataset_mean_std_neu(root_dir=TMP_DATA_PATH, img_size=IMG_SIZE, seed=SEED)
print(mean, std)
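Per-channel statistics of this kind can be computed as in the sketch below. `channel_mean_std` is a hypothetical helper operating on an in-memory batch; `compute_dataset_mean_std_neu` presumably streams over files on disk and may differ in detail. The toy batch uses a constant fourth channel to show that its standard deviation collapses to zero, in which case clamping before normalization is one possible safeguard.

```python
import torch

def channel_mean_std(images):
    """images: float tensor of shape (N, C, H, W); returns per-channel mean and std."""
    mean = images.mean(dim=(0, 2, 3))
    std = images.std(dim=(0, 2, 3))
    return mean, std

torch.manual_seed(0)
fake_batch = torch.rand(16, 4, 64, 64)   # toy 4-channel images with values in [0, 1]
fake_batch[:, 3] = 1.0                   # constant fourth channel

mean, std = channel_mean_std(fake_batch)
safe_std = std.clamp_min(1e-6)           # avoids dividing by (near) zero in Normalize
print(mean, std)
```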

In [None]:
## Final: dynamic
img_transforms = transforms.Compose([
    transforms.ToImage(),   # Converts the input to a tensor image (does not rescale values)
    transforms.Resize(IMG_SIZE),
    transforms.ToDtype(torch.float32, scale=True),   # Casts to float32 and scales values into [0, 1]
    # Hard-coded per-channel statistics (4 channels) of the assessment dataset
    transforms.Normalize(([0.0051, 0.0052, 0.0051, 1.0000]), ([5.8023e-02, 5.8933e-02, 5.8108e-02, 2.4509e-07]))
    # transforms.Normalize(mean.tolist(), std.tolist())     # dynamic alternative using the statistics computed above
])

In [None]:
train_data, train_dataloader, valid_data, val_dataloader, test_data, test_dataloader  = get_cilp_dataloaders(
    str(TMP_DATA_PATH),
    VALID_BATCHES,
    test_frac=0.10,
    batch_size=BATCH_SIZE,
    img_transforms=img_transforms,
    num_workers=NUM_WORKERS,
    seed=SEED
)

for i, sample in enumerate(train_data):
    print(i, *(x.shape for x in sample))
    break

# The Model

I use a 2-layer projection head with a 512-dimensional hidden layer between the backbone’s CNN features and the final CILP embedding.
A hidden width of 512 is a common choice in contrastive learning setups (CLIP's ViT-B/32 model, for example, uses a 512-dimensional embedding, and SimCLR also places a small MLP head on top of its encoder); it offers a reasonable balance between expressiveness and computational cost.
The head applies a nonlinear transformation from the CNN features into the shared embedding space while staying small enough to limit the risk of overfitting.
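A minimal sketch of such a head is shown below. It assumes the backbone emits 128-dimensional features (the `FEATURE_DIM` value used later in this notebook) and that a ReLU sits between the two linear layers; the actual `CILPBackbone` in `src.models` may be built differently.

```python
import torch
import torch.nn as nn

FEATURE_DIM, HIDDEN_DIM, EMB_SIZE = 128, 512, 200   # 128 and ReLU are assumptions

projection_head = nn.Sequential(
    nn.Linear(FEATURE_DIM, HIDDEN_DIM),   # CNN features -> hidden layer
    nn.ReLU(),
    nn.Linear(HIDDEN_DIM, EMB_SIZE),      # hidden layer -> shared CILP embedding
)

features = torch.randn(32, FEATURE_DIM)   # one batch of backbone features
embeddings = projection_head(features)
n_params = sum(p.numel() for p in projection_head.parameters())
print(embeddings.shape, n_params)
```

With these sizes the head has 168,648 parameters, which is small compared to a typical convolutional backbone.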

## Stage 1: CILP contrastive pretraining

## Stage 2: Projector training

# The Dataset

**TODO:** Reuse the dataset code from Task 3.

```
AssessmentDataset (Pattern A or your MyDataset variant)

create_assessment_splits(...)

make_dataloaders(...)
```



# Training


```
train_model
```



## Stage 1: CILP contrastive pretraining

In [None]:
BEST_EMBEDDER = EmbedderMaxPool           
FEATURE_DIM = 128
CILP_EMB_SIZE = 200

In [None]:
img_embedder = CILPBackbone(
    in_ch=4, 
    embedder_cls=BEST_EMBEDDER, 
    feature_dim=FEATURE_DIM,
    emb_size=CILP_EMB_SIZE
).to(device)

lidar_embedder = CILPBackbone(
    in_ch=1, 
    embedder_cls=BEST_EMBEDDER, 
    feature_dim=FEATURE_DIM,
    emb_size=CILP_EMB_SIZE
).to(device)

In [None]:
# Initialize the model
CILP_model = ContrastivePretraining(img_embedder, lidar_embedder).to(device)

loss_img = nn.CrossEntropyLoss()
loss_lidar = nn.CrossEntropyLoss()

In [None]:
def cilp_batch_loss_fn(model, batch, device):
    """
    outputs: (logits_per_img, logits_per_lidar), each of shape (B, B)

    We build ground-truth indices 0..B-1 so that:
      - row i in logits_per_img should classify LiDAR i as the correct match
      - row i in logits_per_lidar should classify RGB i as the correct match
    """
    rgb, lidar, _ = batch
    rgb = rgb.to(device)
    lidar = lidar.to(device)

    logits_per_img, logits_per_lidar = model(rgb, lidar)   # (B, B)
    B = logits_per_img.size(0)
    ground_truth = torch.arange(B, dtype=torch.long, device=device)

    loss_i = loss_img(logits_per_img, ground_truth)
    loss_l = loss_lidar(logits_per_lidar, ground_truth)

    total_loss = (loss_i + loss_l) / 2.0

    return total_loss, logits_per_img
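To get a feel for the loss values this function produces, the sketch below evaluates the same symmetric cross-entropy on two toy logit matrices. The helper names are hypothetical. With uninformative (all-zero) logits the loss equals ln(B), so for a full batch of 32 a validation loss near ln(32) ≈ 3.47 would indicate chance-level matching.

```python
import math
import torch
import torch.nn.functional as F

def symmetric_contrastive_loss(logits_per_img, logits_per_lidar):
    """Mean of the image-to-LiDAR and LiDAR-to-image cross-entropies."""
    targets = torch.arange(logits_per_img.size(0))
    loss_i = F.cross_entropy(logits_per_img, targets)
    loss_l = F.cross_entropy(logits_per_lidar, targets)
    return (loss_i + loss_l) / 2.0

def in_batch_top1(logits_per_img):
    """Fraction of images whose highest-scoring LiDAR sample is the true partner."""
    targets = torch.arange(logits_per_img.size(0))
    return (logits_per_img.argmax(dim=1) == targets).float().mean().item()

aligned = 10.0 * torch.eye(4)         # strong diagonal: every pair matched correctly
uninformative = torch.zeros(4, 4)     # no preference at all

loss_aligned = symmetric_contrastive_loss(aligned, aligned.t()).item()
loss_chance = symmetric_contrastive_loss(uninformative, uninformative.t()).item()
acc_aligned = in_batch_top1(aligned)
print(loss_aligned, loss_chance, math.log(4), acc_aligned)
```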

In [None]:
## train CILP
EPOCHS_CILP = 5
LR_CILP = 0.0001

opt = torch.optim.Adam(CILP_model.parameters(), LR_CILP)

best_val = float("inf")

# Path where best model is saved
checkpoint_dir = STORAGE_PATH / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)
model_save_path=checkpoint_dir / "cilp_model.pth"

init_wandb(
    model=CILP_model,
    opt_name = opt.__class__.__name__
)

results = train_with_batch_loss(
    model=CILP_model,
    optimizer=opt,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    batch_loss_fn=cilp_batch_loss_fn,
    epochs=EPOCHS_CILP,
    model_save_path=model_save_path,
    log_to_wandb=True,
    device=device
)

best_cilp_val = results["best_valid_loss"]
print(f"[5.1] Best CILP validation loss: {best_cilp_val:.4f}")


wandb.run.summary["cilp_best_val_loss"] = best_cilp_val
# End wandb run before starting the next model
wandb.finish()

In [None]:
## freeze pre-trained model
for param in CILP_model.parameters():
    param.requires_grad = False

CILP_model.eval()
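A quick way to confirm that freezing worked is to count trainable versus frozen parameters. The helper `count_parameters` below is a hypothetical utility, demonstrated on a small toy module rather than on `CILP_model` itself.

```python
import torch.nn as nn

def count_parameters(module):
    """Return (trainable, frozen) parameter counts for a module."""
    trainable = sum(p.numel() for p in module.parameters() if p.requires_grad)
    frozen = sum(p.numel() for p in module.parameters() if not p.requires_grad)
    return trainable, frozen

toy = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
before = count_parameters(toy)

toy.requires_grad_(False)   # same effect as looping over parameters()
toy.eval()
after = count_parameters(toy)
print(before, after)
```

Applied to the frozen model, the first value of the returned tuple should be 0.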

## Stage 2: Cross-Modal Projection

In [None]:
# load pre-trained lidar_cnn classifier
lidar_cnn_path = STORAGE_PATH / "checkpoints/lidar_cnn.pt"

lidar_cnn = Classifier(in_ch=1).to(device)
lidar_cnn.load_state_dict(torch.load(lidar_cnn_path, weights_only=True))

for param in lidar_cnn.parameters():
    param.requires_grad = False

lidar_cnn.eval()

Classifier(
  (embedder): Sequential(
    (0): Conv2d(1, 50, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(50, 100, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (4): ReLU()
    (5): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(100, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU()
    (8): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (9): Conv2d(200, 200, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (10): ReLU()
    (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (12): Flatten(start_dim=1, end_dim=-1)
  )
  (classifier): Sequential(
    (0): Linear(in_features=3200, out_features=100, bias=True)
    (1): ReLU()
    (2): Linear(in_features=100, out_features=1, bias=True)
  )
)

In [None]:
def projector_batch_loss_fn(model, batch, device, CILP_model, lidar_cnn):
    rgb_img, lidar_depth, _ = batch
    rgb_img = rgb_img.to(device)
    lidar_depth = lidar_depth.to(device)

    # Use frozen encoders
    CILP_model.eval()
    lidar_cnn.eval()

    with torch.no_grad():
        img_embs = CILP_model.img_embedder(rgb_img)      # (B, CILP_EMB_SIZE)
        lidar_embs = lidar_cnn.get_embs(lidar_depth)     # (B, lidar_dim)

    pred_lidar_embs = model(img_embs)
    
    loss = F.mse_loss(pred_lidar_embs, lidar_embs)

    # match the convention used in train_with_batch_loss
    return loss, {"loss": loss.item()}


In [None]:
img_dim = CILP_EMB_SIZE
lidar_dim = 200 * 4 * 4

projector = Projector(img_dim, lidar_dim).to(device)
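The value `200 * 4 * 4` follows from the `lidar_cnn` architecture printed above: the 3×3 convolutions with padding 1 keep the spatial size, each of the four `MaxPool2d(2)` layers halves it, and the last convolution has 200 output channels. The small helper below (a hypothetical name, pure Python) reproduces that arithmetic for 64×64 inputs.

```python
def flattened_dim(img_size, channels, num_pools, pool=2):
    """Size of the flattened feature map after `num_pools` stride-2 poolings."""
    side = img_size
    for _ in range(num_pools):
        side //= pool          # 64 -> 32 -> 16 -> 8 -> 4
    return channels * side * side

lidar_dim_check = flattened_dim(img_size=64, channels=200, num_pools=4)
print(lidar_dim_check)   # 3200, matching in_features of the first Linear layer
```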

In [None]:
EPOCHS_PROJ = 20
LR_PROJECTOR = 1e-4
model_save_path=checkpoint_dir / "projector.pth"

opt = torch.optim.Adam(projector.parameters(), LR_PROJECTOR)

best_val_proj = float("inf")

init_wandb(
    model=projector,
    opt_name = opt.__class__.__name__
)

results = train_with_batch_loss(
    model=projector,
    optimizer=opt,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    batch_loss_fn=projector_batch_loss_fn,
    epochs=EPOCHS_PROJ,
    model_save_path=model_save_path,
    log_to_wandb=True,
    device=device,
    extra_args={
        "CILP_model": CILP_model,
        "lidar_cnn": lidar_cnn,
    }
)

best_proj_val = results["best_valid_loss"]
print(f"[5.2] Best projector validation MSE: {best_proj_val:.4f}")

wandb.run.summary["projector_best_val_mse"] = best_proj_val
# End wandb run before starting the next model
wandb.finish()

## Stage 3: RGB2LiDARClassifier

In [None]:
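def train_rgb2lidar_classifier(
    model,
    train_loader,
    val_loader,
    epochs,
    lr,
    device,
):
    """Fine-tune the projector inside an RGB2LiDARClassifier.

    Returns a dict of per-epoch lists: "train_loss", "val_loss", "val_acc".
    """
    model = model.to(device)
    # Only the projector is optimized; the CILP encoder and lidar_cnn stay frozen.
    optimizer = torch.optim.Adam(model.projector.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()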

    history = {
        "train_loss": [],
        "val_loss": [],
        "val_acc": []
    }

    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")

        # ---------- TRAIN ----------
        model.train()
        running_loss = 0.0

        for imgs, _, labels in train_loader:
            imgs = imgs.to(device)
            labels = labels.float().view(-1, 1).to(device)   # [B,1]

            optimizer.zero_grad()

            logits = model(imgs)          # [B,1], no sigmoid
            loss = loss_fn(logits, labels)

            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        history["train_loss"].append(train_loss)


        # ---------- VALID ----------
        model.eval()
        val_loss = 0.0
        correct = 0
        total = 0

        with torch.no_grad():
            for imgs, _, labels in val_loader:
                imgs = imgs.to(device)
                labels = labels.float().view(-1, 1).to(device)

                logits = model(imgs)              # [B,1]
                loss = loss_fn(logits, labels)
                val_loss += loss.item()

                # accuracy
                probs = torch.sigmoid(logits)     # [B,1], 0–1
                preds = (probs >= 0.5).long()     # threshold
                correct += (preds.view(-1) == labels.view(-1).long()).sum().item()
                total += labels.size(0)

        val_loss = val_loss / len(val_loader)
        val_acc = correct / total

        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"train_loss={train_loss:.4f}  val_loss={val_loss:.4f}  val_acc={val_acc*100:.2f}%")

    return history
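Because the returned `history` holds one value per epoch, the best epoch has to be derived from the lists. The helper below is one possible way to do that; `summarize_history` is a hypothetical name and the numbers in `example_history` are made up purely for illustration.

```python
def summarize_history(history):
    """Pick the epoch with the lowest validation loss from per-epoch lists."""
    val_losses = history["val_loss"]
    best_idx = min(range(len(val_losses)), key=val_losses.__getitem__)
    return {
        "best_epoch": best_idx + 1,                  # 1-based, as printed during training
        "best_valid_loss": val_losses[best_idx],
        "best_valid_acc": max(history["val_acc"]),
    }

# Toy values, not real training results
example_history = {
    "train_loss": [0.9, 0.5, 0.4],
    "val_loss": [0.8, 0.45, 0.5],
    "val_acc": [0.6, 0.9, 0.88],
}
summary = summarize_history(example_history)
print(summary)
```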

In [None]:
LR_RGB2LIDAR = 1e-3
EPOCHS_RGB2LIDAR = 5

rgb2lidar_clf = RGB2LiDARClassifier(
    CILP=CILP_model,
    projector=projector,
    lidar_cnn=lidar_cnn,
).to(device)

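#class_weights = compute_class_weights(train_data, NUM_CLASSES).to(device)
#loss_func = nn.CrossEntropyLoss(weight=class_weights.to(device))
# Note: train_rgb2lidar_classifier builds its own BCEWithLogitsLoss and its own Adam
# optimizer (projector parameters only), so loss_func and opt below are not used by it.
loss_func = nn.BCEWithLogitsLoss()

opt = torch.optim.Adam(rgb2lidar_clf.parameters(), lr=LR_RGB2LIDAR)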

results = train_rgb2lidar_classifier(
    model=rgb2lidar_clf,
    train_loader=train_dataloader,
    val_loader=val_dataloader,
    epochs=EPOCHS_RGB2LIDAR,
    lr=LR_RGB2LIDAR,
    device=device
)

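# train_rgb2lidar_classifier returns per-epoch lists ("train_loss", "val_loss", "val_acc"),
# so the best values are derived from those lists here.
best_rgb2lidar_val = min(results["val_loss"])
print(f"[5.3] Best validation loss: {best_rgb2lidar_val:.4f}")
best_rgb2lidar_acc = max(results["val_acc"])
print(f"[5.3] Best validation accuracy: {best_rgb2lidar_acc:.4f}")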