# Model Architecture

```
RGB ----> RGB Encoder ----\
                            ----> Fusion ---> Classifier ---> cube/sphere
LiDAR -> LiDAR Encoder ----/
```



**The Architecture Flow:**

```
RGB Input (4ch)       LiDAR Input (4ch)
      │                     │
[RGB Encoder]         [XYZ Encoder]    <-- Learn specific features independently
      │                     │
  RGB Features          XYZ Features   <-- (e.g. 128 channels each)
      └──────────┬──────────┘
                 │
           Concatenation               <-- Fuse at the "Feature Level"
                 │
       [Classification Head]           <-- Learn relationships between features
                 │
       Output (cube/sphere)
```

Multimodal fusion refers to how we combine information from different modalities (e.g., RGB and LiDAR).
There are three canonical levels of fusion:

* **Early fusion:** combine raw inputs or early-level features.
* **Intermediate fusion:** combine learned feature representations.
* **Late fusion:** combine decisions or latent vectors at the end of the pipeline. This resembles an ensemble model, where each branch casts a weighted vote on the final result.

Each level has its own strengths and limitations.
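
To make the three levels concrete, here is a minimal NumPy sketch on random toy arrays. The shapes (4-channel RGB and 4-channel XYZA inputs at 64x64, two classes) follow this notebook; `toy_encoder`, `toy_classifier` and the late-fusion weights `w_rgb`/`w_lidar` are hypothetical stand-ins for trained networks, used only to show *where* the modalities are combined.

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, W = 2, 64, 64
rgb = rng.random((B, 4, H, W))      # toy RGBA batch
xyza = rng.random((B, 4, H, W))     # toy LiDAR XYZA batch

def toy_encoder(x, seed, dim=16):
    """Hypothetical encoder: global average pool + fixed random projection to `dim` features."""
    pooled = x.mean(axis=(2, 3))                                   # (B, C)
    proj = np.random.default_rng(seed).standard_normal((x.shape[1], dim))
    return pooled @ proj                                           # (B, dim)

def toy_classifier(features, seed, n_classes=2):
    """Hypothetical classifier: fixed random linear map + softmax -> class probabilities."""
    w = np.random.default_rng(seed).standard_normal((features.shape[1], n_classes))
    logits = features @ w
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)                        # (B, n_classes)

# Early fusion: concatenate raw inputs along the channel axis, then one shared network
early_input = np.concatenate([rgb, xyza], axis=1)                  # (B, 8, H, W)
p_early = toy_classifier(toy_encoder(early_input, seed=1), seed=2)

# Intermediate fusion: one encoder per modality, then concatenate the features
features = np.concatenate([toy_encoder(rgb, seed=3), toy_encoder(xyza, seed=4)], axis=1)  # (B, 32)
p_inter = toy_classifier(features, seed=5)

# Late fusion: each modality predicts on its own, then a weighted vote of the probabilities
w_rgb, w_lidar = 0.6, 0.4
p_late = (w_rgb * toy_classifier(toy_encoder(rgb, seed=3), seed=6)
          + w_lidar * toy_classifier(toy_encoder(xyza, seed=4), seed=7))
```

Because the late-fusion weights sum to 1, `p_late` is still a valid probability distribution per sample.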

# Setup

## Installations & Imports

In [1]:
%%capture
%pip install wandb weave

In [2]:
%%capture
# torchvision must match the torch release (torchvision 0.24.x pairs with torch 2.9.x)
%pip install fiftyone==1.10.0 sympy==1.12 torch==2.9.0 torchvision==0.24.0 numpy open-clip-torch

In [3]:
import os
from pathlib import Path
from google.colab import userdata
import time

from PIL import Image
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import pandas as pd

import wandb
import cv2
import albumentations as A

## Storage

In [4]:
## stays (move to wherever it is needed)

from google.colab import drive
drive.mount('/content/drive')

STORAGE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/")

DATA_PATH = STORAGE_PATH / "multimodal_training_workshop/data"
print(f"Data path: {DATA_PATH}")
print(f"Data path exists: {DATA_PATH.exists()}")

Mounted at /content/drive
Data path: /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/multimodal_training_workshop/data
Data path exists: True


In [5]:
# !rsync -ah --progress "/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/data" "/content/data/"


In [6]:
## stays

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.is_available()

True

## Constants

In [7]:
## stays

SEED = 51
NUM_WORKERS = os.cpu_count()  # Number of CPU cores

BATCH_SIZE = 32
IMG_SIZE = 64

CLASSES = ["cubes", "spheres"]
NUM_CLASSES = len(CLASSES)
LABEL_MAP = {"cubes": 0, "spheres": 1}

VALID_BATCHES = 10
N = 6500
EPOCHS = 10
LR = 0.0001

# Integration of Wandb

In [8]:
## stays

# Load W&B API key from Colab Secrets and make it available as env variable
wandb_key = userdata.get('WANDB_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_key
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmichele-marschner[0m ([33mmichele-marschner-university-of-potsdam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [9]:
## move: training.py

def init_wandb(model, fusion_name, num_params, opt_name, batch_size=BATCH_SIZE, epochs=15):
  """
  Initialize a Weights & Biases run for a given fusion model.

  Args:
      model (nn.Module): The PyTorch model to track.
      fusion_name (str): Short name of the fusion strategy (e.g. "early_fusion").
      num_params (int): Total number of trainable parameters of the model.
      opt_name (str): Name of the optimizer (e.g. "Adam").
      batch_size (int, optional): Batch size used during training.
      epochs (int, optional): Number of training epochs.

  Returns:
      wandb.sdk.wandb_run.Run: The initialized W&B run object.
  """

  config = {
    # "embedding_size": embedding_size,      ## TODO: does this change? Do I have it for fusion?
    "optimizer_type": opt_name,
    "fusion_strategy": fusion_name,
    "model_architecture": model.__class__.__name__,
    "batch_size": batch_size,
    "num_epochs": epochs,
    "num_parameters": num_params
  }

  run = wandb.init(
    project="cilp-extended-assessment",
    name=f"{fusion_name}_run",
    config=config,
    reinit='finish_previous',                           # allows starting a new run inside one script
  )

  return run

# Utility Functions

Move all of these into utility.py.



In [10]:
def set_seeds(seed=SEED):
    """
    Set seeds for complete reproducibility across all libraries and operations.

    Args:
        seed (int): Random seed value
    """
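    # Set environment variables. Note: PYTHONHASHSEED set at runtime only affects
    # newly started Python processes, and CUBLAS_WORKSPACE_CONFIG must be set
    # before the first cuBLAS call to take effect.
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'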

    # Python random module
    random.seed(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch CPU
    torch.manual_seed(seed)

    # PyTorch GPU (all devices)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

        # CUDA deterministic operations
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # OpenCV
    cv2.setRNGSeed(seed)

    # Albumentations (for data augmentation)
    try:
        A.seed_everything(seed)
    except AttributeError:
        # seed_everything is not available in every albumentations version
        pass

    # PyTorch deterministic algorithms (may impact performance).
    # The call itself does not raise; an error would only occur later, when an
    # op without a deterministic implementation runs. warn_only=True turns that
    # error into a warning instead.
    torch.use_deterministic_algorithms(True, warn_only=True)

    print(f"All random seeds set to {seed} for reproducibility")



# Usage: Call this function at the beginning and before each training phase
set_seeds(SEED)

# Additional reproducibility considerations:

def create_deterministic_training_dataloader(dataset, batch_size, shuffle=True, **kwargs):
    """
    Create a DataLoader with deterministic shuffling behaviour.

    Args:
        dataset (Dataset): PyTorch Dataset instance.
        batch_size (int): Batch size.
        shuffle (bool): Whether to shuffle data.
        **kwargs: Additional DataLoader keyword arguments.

    Returns:
        DataLoader: Training DataLoader with reproducible sampling.
    """
    # Use a generator with fixed seed for reproducible shuffling
    generator = torch.Generator()
    generator.manual_seed(SEED)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator if shuffle else None,
        **kwargs
    )



All random seeds set to 51 for reproducibility


In [11]:
def format_time(seconds):
    """
    Convert a duration in seconds to a human-readable 'MMm SSs' string.

    Args:
        seconds (float): Duration in seconds.

    Returns:
        str: Formatted duration, e.g. "02m 15s".
    """
    m = int(seconds // 60)
    s = int(seconds % 60)
    return f"{m:02d}m {s:02d}s"

In [12]:
def get_torch_xyza(lidar_depth, azimuth, zenith):
    """
    Convert LiDAR depth + angular coordinates into an XYZA tensor.

    Args:
        lidar_depth (torch.Tensor): Radial distances of shape (H, W).
        azimuth (torch.Tensor): Azimuth angles in radians, shape (H,).
        zenith (torch.Tensor): Zenith angles in radians, shape (W,).

    Returns:
        torch.Tensor: Stacked tensor of shape (4, H, W) containing
            X, Y, Z coordinates and a validity mask A.
    """
    # Broadcast azimuth (per row) and zenith (per column) to full image grid
    # and convert polar coordinates into Cartesian coordinates.
    x = lidar_depth * torch.sin(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    y = lidar_depth * torch.cos(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    z = lidar_depth * torch.sin(-zenith[None, :])

    # A is a binary mask: 1 for valid points, 0 for "no return" / far-away
    a = torch.where(lidar_depth < 50.0,
                    torch.ones_like(lidar_depth),
                    torch.zeros_like(lidar_depth))

    xyza = torch.stack([x, y, z, a], dim=0)
    return xyza
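
The conversion can be sanity-checked with a NumPy mirror of `get_torch_xyza` (a sketch, assuming the same angle conventions and the 50 m no-return cutoff used in the function; `get_numpy_xyza` is a hypothetical name). Two properties are worth verifying: the Euclidean norm of (x, y, z) recovers the measured depth, and a beam at azimuth 0 and zenith 0 points along +y.

```python
import numpy as np

def get_numpy_xyza(lidar_depth, azimuth, zenith, max_range=50.0):
    """NumPy mirror of get_torch_xyza: depth (H, W), azimuth (H,), zenith (W,) -> (4, H, W)."""
    x = lidar_depth * np.sin(-azimuth[:, None]) * np.cos(-zenith[None, :])
    y = lidar_depth * np.cos(-azimuth[:, None]) * np.cos(-zenith[None, :])
    z = lidar_depth * np.sin(-zenith[None, :])
    a = (lidar_depth < max_range).astype(lidar_depth.dtype)   # 1 = valid return
    return np.stack([x, y, z, a], axis=0)

# Tiny 2x3 example: one row of beams at azimuth 0, one at 90 degrees
depth = np.array([[10.0, 20.0, 60.0],
                  [5.0, 5.0, 5.0]])
azimuth = np.array([0.0, np.pi / 2])      # one angle per row
zenith = np.array([0.0, 0.3, -0.2])       # one angle per column

xyza = get_numpy_xyza(depth, azimuth, zenith)
print(xyza.shape)                          # (4, 2, 3)
print(np.linalg.norm(xyza[:3], axis=0))   # recovers depth
print(xyza[3])                             # the 60 m point is marked invalid
```

Since x² + y² = d²·cos²(zenith) and z² = d²·sin²(zenith), the norm equals d for every pixel, independent of the angles.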

In [13]:
def format_positions(positions):
    """
    Format a sequence of numerical positions as nicely aligned strings.

    Args:
        positions (Iterable[float]): Sequence of scalar values.

    Returns:
        list[str]: List of strings with 4 decimal places.
    """
    return ['{0: .4f}'.format(x) for x in positions]

In [14]:
def create_subset(size, dataset):
    """
    Create a random subset of a given dataset.

    Args:
        size (int): Desired number of samples in the subset.
        dataset (torch.utils.data.Dataset): Dataset to sample from.

    Returns:
        torch.utils.data.Subset: Random subset of the given dataset.
    """
    # Sample unique indices uniformly at random from the dataset
    indices = np.random.choice(len(dataset), size=size, replace=False)
    return Subset(dataset, indices)
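
The difference between sampling from `range(size)` and from `range(len(dataset))` is easy to see with plain NumPy. In this sketch the dataset is replaced by a hypothetical length of 12500 (the number of samples scanned further below). Drawing from `range(size)` can only ever pick the first `size` items, which for an unshuffled dataset (cubes are scanned first) would mean cubes only; drawing from `range(len(dataset))` spreads the subset over the whole dataset.

```python
import numpy as np

rng = np.random.default_rng(51)
dataset_len = 12500   # hypothetical stand-in for len(dataset)
size = 100

# Samples only from range(size): just a permutation of the first `size` indices
restricted = rng.choice(size, size=size, replace=False)

# Samples from range(len(dataset)): unique indices spread over the whole dataset
full = rng.choice(dataset_len, size=size, replace=False)

print(restricted.max() < size)         # True
print(len(np.unique(full)) == size)    # True (no duplicates)
```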

In [15]:
def print_loss(epoch, loss, outputs, target, is_train=True, is_debug=False):
    """
    Print a formatted loss line and optionally one example prediction.

    Args:
        epoch (int): Current epoch index.
        loss (float or Tensor): Loss value for this epoch.
        outputs (torch.Tensor): Model predictions for the current batch.
        target (torch.Tensor): Ground-truth targets for the current batch.
        is_train (bool): If True, label as training loss; else validation.
        is_debug (bool): If True, also print one prediction/target pair.
    """

    loss_type = "train loss:" if is_train else "valid loss:"
    print("epoch", str(epoch), loss_type, str(loss))

    if is_debug:
        print("example pred:", format_positions(outputs[0].tolist()))
        print("example real:", format_positions(target[0].tolist()))

In [16]:
## Final: delete
## move: visualization.py

def plot_losses(losses, title="Training & Validation Loss Comparison", figsize=(10,6)):
    """
    Legacy plotting helper to show train/valid losses for multiple models.

    Args:
        losses (dict): Mapping model_name -> {"train_losses": [...],
                                              "valid_losses": [...]}.
        title (str): Plot title.
        figsize (tuple): Matplotlib figure size.
    """
    plt.figure(figsize=figsize)

    for model_name, log in losses.items():
        train = log["train_losses"]
        valid = log["valid_losses"]

        # plot train + valid with different line styles
        plt.plot(train, label=f"{model_name} - train", linewidth=2)
        plt.plot(valid, label=f"{model_name} - valid", linestyle="--", linewidth=2)

    plt.title(title, fontsize=16)
    plt.xlabel("Epochs", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Loading and preparation of Data

In [17]:
## Final: reconsider where the dataset should come from
## move: datasets.py

class AssessmentXYZADataset(Dataset):
    """
    Dataset for the CILP XYZ + RGB assessment data.

    It expects the following folder structure:

        root/
          cubes/
            rgb/*.png
            lidar_xyza/*.npy
          spheres/
            rgb/*.png
            lidar_xyza/*.npy

    Each sample consists of an RGB image, a LiDAR XYZA tensor and a class label.
    """
    def __init__(self, root_dir, start_idx=0, end_idx=None,
                 transform_rgb=None, transform_lidar=None, shuffle=True):
        """
        Args:
            root_dir (str or Path): Root directory of the dataset.
            start_idx (int): Start index (inclusive) for slicing the dataset.
            end_idx (int or None): End index (exclusive); if None use all.
            transform_rgb (callable or None): Transform applied to RGB images.
            transform_lidar (callable or None): Transform applied to LiDAR tensors.
            shuffle (bool): If True, shuffle the full list of samples once.
        """
        self.root_dir = Path(root_dir)
        self.transform_rgb = transform_rgb
        self.transform_lidar = transform_lidar

        self.classes = ["cubes", "spheres"]
        self.label_map = {"cubes": 0, "spheres": 1}

        samples = []

        print(f"Scanning dataset in {root_dir}...")
        for cls in self.classes:
            cls_dir = self.root_dir / cls
            rgb_dir = cls_dir / "rgb"
            lidar_dir = cls_dir / "lidar_xyza"

            rgb_files = sorted(rgb_dir.glob("*.png"))
            print(f"{cls}: {len(rgb_files)} RGB files found. Matching XYZA...")

            for rgb_path in tqdm(rgb_files, desc=f"{cls} matching", leave=False):
                stem = rgb_path.stem
                lidar_path = lidar_dir / f"{stem}.npy"
                if lidar_path.exists():
                    samples.append({
                        "rgb": rgb_path,
                        "lidar_xyza": lidar_path,
                        "label": self.label_map[cls],
                    })

        if shuffle:
            rng = random.Random(SEED)
            rng.shuffle(samples)

        if end_idx is None:
            end_idx = len(samples)
        self.samples = samples[start_idx:end_idx]

        # Preload LiDAR tensors into memory since they are small and fast to cache
        print(f"Preloading LiDAR XYZA tensors into RAM...")
        self.lidar_tensors = []
        for item in tqdm(self.samples, desc="Loading XYZA", leave=False):
            lidar_np = np.load(item["lidar_xyza"])        # (4, H, W)
            lidar_t  = torch.from_numpy(lidar_np).float() # CPU tensor
            self.lidar_tensors.append(lidar_t)

        print(
            f"Dataset ready: {len(self.samples)} samples loaded.\n"
            f"Slice [{start_idx}:{end_idx}]"
        )

    def __len__(self):
        """Return the number of samples in this dataset slice."""
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Load a single (rgb, lidar, label) triplet.

        Returns:
            tuple: (rgb_tensor, lidar_tensor, label_tensor)
        """
        item  = self.samples[idx]
        lidar = self.lidar_tensors[idx]

        # RGB image is loaded on the fly
        rgb = Image.open(item["rgb"])
        if self.transform_rgb:
            rgb = self.transform_rgb(rgb)

        if self.transform_lidar:
            lidar = self.transform_lidar(lidar)

        label = torch.tensor(item["label"], dtype=torch.long)
        return rgb, lidar, label
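
The pairing rule in `__init__` (an RGB file is kept only if a LiDAR file with the same stem exists) can be sketched with the standard library alone. `match_by_stem` is a hypothetical helper, not part of the notebook; it also returns the unmatched stems, which makes it easy to see how many RGB files the dataset skips silently.

```python
from pathlib import Path
import tempfile

def match_by_stem(rgb_dir, lidar_dir):
    """Pair rgb_dir/*.png with lidar_dir/<stem>.npy; also return RGB stems without LiDAR."""
    pairs, missing = [], []
    for rgb_path in sorted(Path(rgb_dir).glob("*.png")):
        lidar_path = Path(lidar_dir) / f"{rgb_path.stem}.npy"
        if lidar_path.exists():
            pairs.append((rgb_path, lidar_path))
        else:
            missing.append(rgb_path.stem)
    return pairs, missing

# Usage on a throwaway directory tree mirroring root/cubes/{rgb,lidar_xyza}
with tempfile.TemporaryDirectory() as tmp:
    rgb_dir = Path(tmp) / "cubes" / "rgb"
    lidar_dir = Path(tmp) / "cubes" / "lidar_xyza"
    rgb_dir.mkdir(parents=True)
    lidar_dir.mkdir(parents=True)
    for stem in ["0001", "0002", "0003"]:
        (rgb_dir / f"{stem}.png").touch()
    for stem in ["0001", "0003"]:
        (lidar_dir / f"{stem}.npy").touch()

    pairs, missing = match_by_stem(rgb_dir, lidar_dir)
    print([p[0].stem for p in pairs], missing)   # ['0001', '0003'] ['0002']
```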


In [18]:
## move: datasets.py

def compute_dataset_mean_std(root_dir):
    """
    Estimate the per-channel mean and std for the RGB+LiDAR data.

    Args:
        root_dir (str or Path): Root directory passed to AssessmentXYZADataset.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            mean and std tensors with shape (C,).
    """

    stats_transforms = transforms.Compose([
      transforms.Resize(IMG_SIZE),
      transforms.ToImage(),
      transforms.ToDtype(torch.float32, scale=True),  # [0,1], 4 channels
    ])

    stats_dataset = AssessmentXYZADataset(
        root_dir=root_dir,
        start_idx=0,
        end_idx=None,          # or e.g. 1000 to subsample
        transform_rgb=stats_transforms,
    )

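    # Estimate the statistics on a random subset (at most 2000 samples or 30 %)
    subset_size = int(min(2000, len(stats_dataset) * 0.3))
    subset_for_stats = create_subset(subset_size, stats_dataset)

    loader = DataLoader(subset_for_stats, batch_size=64, shuffle=False, num_workers=2)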

    mean = 0.
    std = 0.
    total = 0

    for images, _, _ in tqdm(loader, desc="Computing mean/std"):
        images = images.float()       # B, C, H, W
        batch_size = images.size(0)

        # compute mean over batch (channels only!)
        mean += images.mean(dim=[0, 2, 3]) * batch_size

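        # compute std over batch (note: averaging per-batch stds only approximates
        # the dataset std; compute_dataset_mean_std_neu computes it exactly)
        std += images.std(dim=[0, 2, 3]) * batch_size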

        total += batch_size

    mean /= total
    std /= total

    return mean, std


In [19]:
def compute_dataset_mean_std_neu(root_dir):
    """
    Estimate the per-channel mean and std for the RGB+LiDAR data.

    Args:
        root_dir (str or Path): Root directory passed to AssessmentXYZADataset.

    Returns:
        tuple[torch.Tensor, torch.Tensor]:
            mean and std tensors with shape (C,).
    """
    stats_transforms = transforms.Compose([
        transforms.Resize(IMG_SIZE),
        transforms.ToImage(),
        transforms.ToDtype(torch.float32, scale=True),  # [0,1], 4 channels
    ])

    stats_dataset = AssessmentXYZADataset(
        root_dir=root_dir,
        transform_rgb=stats_transforms,
        transform_lidar=None,
        shuffle=False,
    )

    loader = DataLoader(stats_dataset, batch_size=64, shuffle=False)

    # Accumulate running sum and sum of squares to compute mean/std
    channel_sum = torch.zeros(4)
    channel_sq_sum = torch.zeros(4)
    num_pixels = 0

    for rgb, _, _ in tqdm(loader, desc="Computing mean/std"):
        # rgb shape: (B, C, H, W)
        b, c, h, w = rgb.shape
        num_pixels += b * h * w
        channel_sum += rgb.sum(dim=[0, 2, 3])
        channel_sq_sum += (rgb ** 2).sum(dim=[0, 2, 3])

    mean = channel_sum / num_pixels
    std = torch.sqrt(channel_sq_sum / num_pixels - mean ** 2)
    return mean, std
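
The two statistics functions can disagree because averaging per-batch standard deviations (as `compute_dataset_mean_std` does) is not the standard deviation of the whole dataset. Accumulating per-channel sums and sums of squares (as `compute_dataset_mean_std_neu` does) gives the exact population value via Var[x] = E[x^2] - E[x]^2. The NumPy sketch below uses three hypothetical batches whose means differ (as with unshuffled classes) to make the gap visible: the averaged value should stay near 1.0, while the exact one should be close to sqrt(1 + 8/3), about 1.91.

```python
import numpy as np

rng = np.random.default_rng(0)
# Three batches of shape (B, C, H, W) with different means
batches = [rng.normal(loc=shift, scale=1.0, size=(8, 4, 16, 16)) for shift in (0.0, 2.0, 4.0)]

# Approach 1: average of per-batch stds (approximation)
approx_std = np.mean([b.std(axis=(0, 2, 3)) for b in batches], axis=0)

# Approach 2: running sums and sums of squares (exact population std)
channel_sum = np.zeros(4)
channel_sq_sum = np.zeros(4)
num_pixels = 0
for b in batches:
    n, _, h, w = b.shape
    num_pixels += n * h * w
    channel_sum += b.sum(axis=(0, 2, 3))
    channel_sq_sum += (b ** 2).sum(axis=(0, 2, 3))
mean = channel_sum / num_pixels
exact_std = np.sqrt(channel_sq_sum / num_pixels - mean ** 2)

all_data = np.concatenate(batches, axis=0)
print(np.allclose(exact_std, all_data.std(axis=(0, 2, 3))))  # True
print(approx_std.round(2), exact_std.round(2))
```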


In [None]:
## Final: make dynamic
root = STORAGE_PATH / "data"
mean, std = compute_dataset_mean_std(root_dir=root)

Scanning dataset in /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/data...
cubes: 2501 RGB files found. Matching XYZA...
spheres: 9999 RGB files found. Matching XYZA...
Preloading LiDAR XYZA tensors into RAM...
Loading XYZA:   0%|          | 38/12500 [00:28<2:35:57,  1.33it/s]

In [None]:
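## Final: make dynamic
img_transforms = transforms.Compose([
    transforms.ToImage(),   # PIL image -> tv_tensors.Image (uint8, not yet scaled)
    transforms.Resize(IMG_SIZE),
    transforms.ToDtype(torch.float32, scale=True),  # scales uint8 [0, 255] to float32 [0, 1]
    # Hard-coded per-channel statistics of the assessment dataset (4 channels).
    # Note: the 4th channel is essentially constant (std ~2.5e-07), so dividing by
    # its std turns any value that deviates even slightly from 1.0 into a huge number.
    transforms.Normalize([0.0051, 0.0052, 0.0051, 1.0000], [5.8023e-02, 5.8933e-02, 5.8108e-02, 2.4509e-07]),     ## assessment dataset
    # transforms.Normalize(mean.tolist(), std.tolist())     ## assessment dataset
])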
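
`Normalize` computes (x - mean) / std per channel. With the hard-coded statistics of the fourth channel (mean 1.0, std 2.4509e-07, i.e. essentially constant, presumably the PNG alpha channel), a value of exactly 1.0 maps to 0, but a value of 254/255 already maps to roughly -16000. A plain-Python sketch; the first three values of the toy pixels are arbitrary, and the `min_std` floor is a hypothetical safeguard, not something the notebook uses.

```python
MEAN = [0.0051, 0.0052, 0.0051, 1.0000]
STD = [5.8023e-02, 5.8933e-02, 5.8108e-02, 2.4509e-07]

def normalize_pixel(pixel, mean, std):
    """Per-channel (x - mean) / std, the same formula transforms.Normalize applies."""
    return [(x - m) / s for x, m, s in zip(pixel, mean, std)]

opaque = normalize_pixel([0.0, 0.0, 0.0, 1.0], MEAN, STD)
almost_opaque = normalize_pixel([0.0, 0.0, 0.0, 254 / 255], MEAN, STD)
print(opaque[3], almost_opaque[3])

# Hypothetical safeguard: floor the std so a near-constant channel is not blown up
min_std = 1e-3
safe_std = [max(s, min_std) for s in STD]
safe_alpha = normalize_pixel([0.0, 0.0, 0.0, 254 / 255], MEAN, safe_std)[3]
print(round(safe_alpha, 2))
```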

In [None]:
## move: datasets.py

def get_dataloaders(root_dir):
    """
    Create train and validation datasets + dataloaders.

    The first N samples (in the dataset's seeded shuffle order) are split:
    the first N - VALID_BATCHES * BATCH_SIZE go to training and the last
    VALID_BATCHES * BATCH_SIZE to validation.

    Args:
        root_dir (str or Path): Root path to the cubes/spheres data.

    Returns:
        tuple: (train_dataset, train_dataloader,
                valid_dataset, val_dataloader)
    """
    train_data = AssessmentXYZADataset(
        root_dir,
        0,
        N-VALID_BATCHES*BATCH_SIZE,
        img_transforms
    )

    train_dataloader = create_deterministic_training_dataloader(
        train_data,
        batch_size=BATCH_SIZE,
        shuffle=True,
        drop_last=True
    )

    valid_data = AssessmentXYZADataset(
        root_dir,
        N-VALID_BATCHES*BATCH_SIZE,
        N,
        img_transforms
    )

    val_dataloader = DataLoader(
        valid_data,
        batch_size=BATCH_SIZE,
        shuffle=False,
        drop_last=True
    )

    return train_data, train_dataloader, valid_data, val_dataloader


In [None]:
train_data, train_dataloader, valid_data, val_dataloader = get_dataloaders(str(STORAGE_PATH / "data"))

for i, sample in enumerate(train_data):
    print(i, *(x.shape for x in sample))
    break
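
The split sizes follow directly from the constants. A quick check in plain Python (mirroring `N`, `VALID_BATCHES` and `BATCH_SIZE` from the Constants cell) shows how many samples and full batches each loader yields with `drop_last=True`:

```python
N = 6500
VALID_BATCHES = 10
BATCH_SIZE = 32

n_valid = VALID_BATCHES * BATCH_SIZE        # samples [N - 320, N)
n_train = N - n_valid                       # samples [0, N - 320)

# drop_last=True discards the final incomplete batch
train_batches = n_train // BATCH_SIZE
dropped = n_train % BATCH_SIZE
valid_batches = n_valid // BATCH_SIZE

print(n_train, n_valid)                       # 6180 320
print(train_batches, dropped, valid_batches)  # 193 4 10
```

Only the first 6500 of the 12500 matched samples scanned above are used; the remaining samples are never seen by either loader.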

# Models

Take the EmbedderMaxPool architecture from the workshop and turn it into an encoder that outputs an embedding instead of 9 positions.

## Embedder

In [None]:
## move: model.py

class EmbedderMaxPool(nn.Module):
    """
    Convolutional encoder that down-samples via MaxPool2d and outputs a flat feature vector.

    This is used as a shared building block for the early and intermediate
    fusion architectures.
    """
    def __init__(self, in_ch, feature_dim=200):
        """
        Args:
            in_ch (int): Number of input channels.
            feature_dim (int): Number of output channels in the last conv layer.
        """
        kernel_size = 3
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, 50, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(50, 100, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(100, feature_dim, kernel_size, padding=1)
        self.pool = nn.MaxPool2d(2)

        # For 64x64 input (IMG_SIZE) and 3 pooling steps we end up at 8x8 spatial size.
        self.flatten_dim = feature_dim * 8 * 8


    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, C, H, W).

        Returns:
            torch.Tensor: Flattened feature tensor of shape (B, flatten_dim).
        """
        x = self.pool(F.relu(self.conv1(x)))    # 64x64 -> 32x32
        x = self.pool(F.relu(self.conv2(x)))    # 32x32 -> 16x16
        x = self.pool(F.relu(self.conv3(x)))    # 16x16 -> 8x8
        x = torch.flatten(x, 1) # flatten all dimensions except batch

        return x
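
The hard-coded 8x8 spatial size in `EmbedderMaxPool` holds only for 64x64 inputs. Each `MaxPool2d(2)` halves height and width (with floor division), while the padded 3x3 convolutions keep them unchanged, so the flattened size can be computed for any square input. `embedder_flatten_dim` is a hypothetical helper for illustration:

```python
def embedder_flatten_dim(img_size, feature_dim=200, n_pools=3):
    """Flattened output size of EmbedderMaxPool for a square img_size x img_size input."""
    spatial = img_size
    for _ in range(n_pools):
        spatial //= 2          # MaxPool2d(2) floors odd sizes
    return feature_dim * spatial * spatial

print(embedder_flatten_dim(64))                     # 12800 = 200 * 8 * 8
print(embedder_flatten_dim(64, feature_dim=128))    # 8192
print(embedder_flatten_dim(100))                    # 28800 = 200 * 12 * 12
```

Inside the module, an alternative is to push a dummy tensor of shape (1, in_ch, IMG_SIZE, IMG_SIZE) through the conv stack once in `__init__` and read off the flattened size.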

## Early Fusion Model

**Concept:** Fuse modalities before any deep processing — usually by concatenating channels or inputs.

```
input = concat(RGB, XYZ)  → shape (8, H, W)
-> shared CNN processes everything together
```



**Advantages:**

* **Captures Early Cross-Modal Interactions:** Learns joint low-level correlations directly from raw signals.
* **Simple & Lightweight:** Easiest fusion method to implement; minimal architectural overhead.
* **Effective with Perfect Alignment:** Works well when modalities are tightly synchronized and spatially aligned.

**Limitations:**

* **Noise Sensitivity:** One noisy or corrupted modality directly contaminates the shared feature space.
* **Strict Alignment Requirement:** Modalities must have matching spatial resolution, alignment, and synchronization.
* **Feature Space Mismatch:** Raw modalities differ in scale, units, and distribution; one modality can dominate without careful normalization.
* **High Input Dimensionality:** Channel concatenation increases the input size and can require more data and compute to train effectively.
* **Limited Flexibility:** Assumes combining low-level signals is beneficial; may underperform when modalities carry different types of information.
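
`EarlyFusionModel.forward` takes a single 8-channel tensor, so the RGB and XYZA batches from the dataloader have to be concatenated along the channel axis first (in PyTorch, `torch.cat((rgb, lidar), dim=1)`). The NumPy sketch below shows the same step on toy data, plus an optional per-channel standardization as one hypothetical way to address the feature space mismatch listed above:

```python
import numpy as np

rng = np.random.default_rng(0)
B, H, W = 4, 64, 64
rgb = rng.random((B, 4, H, W))                                # values in [0, 1]
lidar = rng.normal(loc=20.0, scale=10.0, size=(B, 4, H, W))   # toy stand-in for XYZA, larger scale

# Early fusion input: stack the modalities as extra channels
fused = np.concatenate([rgb, lidar], axis=1)                   # (B, 8, H, W)

# Optional: standardize each channel so neither modality dominates by scale
mean = fused.mean(axis=(0, 2, 3), keepdims=True)
std = fused.std(axis=(0, 2, 3), keepdims=True) + 1e-8
fused_std = (fused - mean) / std

print(fused.shape)                                             # (4, 8, 64, 64)
```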

In [None]:
## move: model.py

class FullyConnectedHead(nn.Module):
    """
    Fully-connected classification head mapping features to class logits.
    """
    def __init__(self, input_dim, output_dim=2):
        """
        Args:
            input_dim (int): Dimensionality of the flattened feature vector.
            output_dim (int): Number of output classes.
        """
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, output_dim)

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input features of shape (B, input_dim).

        Returns:
            torch.Tensor: Class logits of shape (B, output_dim).
        """
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

In [None]:
## move: model.py

class EarlyFusionModel(nn.Module):
    """
    Early fusion model that concatenates RGB and LiDAR channels at the input.

    The 8-channel tensor (4 RGB-like + 4 XYZA) is passed through a shared
    CNN embedder and a fully-connected classification head.
    """
    def __init__(self, in_ch=8, output_dim=2):
        """
        Args:
            in_ch (int): Number of input channels after concatenation.
            output_dim (int): Number of output classes.
        """
        super().__init__()

        # Shared embedder for all channels
        self.embedder = EmbedderMaxPool(in_ch)

        # Fully-connected head on top of the shared embedding
        self.fullyConnected = FullyConnectedHead(
            input_dim=self.embedder.flatten_dim,
            output_dim=output_dim
        )

    def forward(self, x):
        """
        Args:
            x (torch.Tensor): Input tensor of shape (B, in_ch, 64, 64).

        Returns:
            torch.Tensor: Class logits of shape (B, output_dim).
        """
        features = self.embedder(x)     # → (B, 12800)
        preds = self.fullyConnected(features)  # → (B, output_dim)
        return preds

## Intermediate Fusion Model

**Concept:** Each modality has its own encoder / feature extractor, and fusion happens after some layers but before classification.

```
RGB → RGB_conv → RGB_features (C, H, W)
LiDAR → LiDAR_conv → LiDAR_features (C, H, W)

Fusion → joint_features → FC → output
```



**Advantages:**

* **Specialized Processing:** Each modality gets its own encoder, tailored to its characteristics.
* **Learned Representations:** Fusion occurs on higher-level, more discriminative features rather than raw data.
* **Flexible Design:** The fusion point can be chosen at different network depths, allowing fine-grained architectural control.
* **Easily Extendable:** New modalities can be added by including additional modality-specific branches.


**Limitations:**

* **Architectural Complexity:** Requires designing separate modality-specific encoders and choosing an appropriate fusion point.
* **Higher Computational Cost:** More expensive than early fusion due to duplicated feature extractors.
* **Fusion Design Sensitivity:** Performance depends on the chosen fusion mechanism (concat, addition, multiplicative, bilinear, attention), which often requires experimentation.
* **Depth Selection Challenge:** Deciding how much unimodal processing to perform before fusion can be non-trivial and task-dependent.

Four fusion variants are implemented:

* Concatenation
* Addition
* Hadamard product (element-wise multiplication)
* Matrix multiplication (outer product of the two embeddings)



| Fusion Method | Advantages | Limitations |
|---------------|------------|-------------|
| **Concatenation** | - Very expressive and flexible<br>- Lets the network learn arbitrary cross-modal interactions<br>- Robust and widely used baseline | - Doubles channel count → more parameters & memory<br>- Computationally heavier<br>- Fusion is unguided; model must discover interactions itself |
| **Addition** | - Lightweight (no increase in channels)<br>- Fast and parameter-efficient<br>- Enforces similar feature spaces between modalities | - Assumes features are aligned and comparable<br>- One noisy modality corrupts the other<br>- Sensitive to scale differences between modalities |
| **Multiplicative (Hadamard Product)** | - Gating effect: highlights features important in *both* modalities<br>- More expressive than addition, cheaper than concat<br>- Natural for attention-like fusion | - Suppresses features when one modality has low magnitude<br>- Requires careful normalization<br>- Can amplify noise if both activations are high |
| **Matrix Multiplication (Bilinear-like)** | - Captures rich pairwise correlations between modalities<br>- Most expressive among all four<br>- Enables true 2nd-order interaction learning | - Very heavy in compute & memory<br>- Requires flattening or dimensionality reduction<br>- Easily overfits; harder to train and tune |
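
The four operations differ mainly in the size of the fused vector that is handed to the head. A NumPy sketch on toy embeddings of size D = 5 (a hypothetical size; in this notebook D = 200 * 8 * 8 = 12800):

```python
import numpy as np

B, D = 2, 5
rng = np.random.default_rng(0)
x_rgb = rng.random((B, D))
x_xyz = rng.random((B, D))

concat = np.concatenate([x_rgb, x_xyz], axis=1)            # (B, 2D)
added = x_rgb + x_xyz                                       # (B, D)
hadamard = x_rgb * x_xyz                                    # (B, D)
outer = np.matmul(x_rgb[:, :, None], x_xyz[:, None, :])     # (B, D, 1) @ (B, 1, D) -> (B, D, D)
outer_flat = outer.reshape(B, -1)                           # (B, D*D)

for name, t in [("concat", concat), ("add", added), ("hadamard", hadamard), ("outer", outer_flat)]:
    print(name, t.shape)
```

The diagonal of each outer product is exactly the Hadamard product; the off-diagonal entries are the additional cross-feature terms that make the matrix-multiplication variant more expressive, and much larger.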


In [None]:
## move: model.py

class ConcatIntermediateNet(nn.Module):
    """
    Intermediate fusion model using feature concatenation.

    Two separate EmbedderMaxPool encoders are applied to RGB and XYZA
    inputs, their flattened features are concatenated, and a shared
    FullyConnectedHead maps the joint representation to class logits.
    """
    def __init__(self, rgb_ch, xyz_ch, output_dim=2):
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=rgb_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=xyz_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?

        # Calculate combined dimension
        # (200 * 8 * 8) + (200 * 8 * 8)
        combined_dim = self.rgb_encoder.flatten_dim + self.xyz_encoder.flatten_dim

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=combined_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # Fuse (Concatenate) at the feature level
        x_fused = torch.cat((x_rgb, x_xyz), dim=1)                      # (B, 2*D)

        # Predict
        output = self.head(x_fused)

        return output

In [None]:
## move: model.py

class AddIntermediateNet(nn.Module):
    """
    Intermediate fusion model using element-wise addition.

    Two separate encoders process each modality independently.
    The resulting feature vectors must have the same size; they are
    added element-wise and passed to a shared FullyConnectedHead.
    """
    def __init__(self, rgb_ch, xyz_ch, output_dim):
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=rgb_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=xyz_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?

        # For addition, shapes must match
        fused_dim = self.rgb_encoder.flatten_dim                        # same size after addition

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=fused_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # Additive fusion in feature space
        x_fused = x_rgb + x_xyz                                         # (B, D)

        # Predict
        output = self.head(x_fused)                                     # (B, output_dim)

        return output

In [None]:
## move: model.py

class MatmulIntermediateNet(nn.Module):
    """
    Intermediate fusion model using matrix multiplication.

    The two modality-specific embeddings are reshaped into matrices
    and combined via a bilinear interaction (matrix product) before
    the fully-connected head.
    """
    def __init__(self, rgb_ch, xyz_ch, output_dim):
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=rgb_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=xyz_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?

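        # Both encoders output the same size D, so the outer product has D * D entries
        embedding_dim = self.rgb_encoder.flatten_dim
        fused_dim = embedding_dim * embedding_dim                       # D * D after matmul
        # Warning: for D = 200 * 8 * 8 = 12800, fused_dim = 163,840,000, so fc1 of the
        # head alone needs ~1.6e11 weights (~655 GB in float32). In practice, D must be
        # reduced (e.g. by a learned projection) before the outer product.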

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=fused_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # Matrix multiplication: (B, D, 1) @ (B, 1, D)
        x_fused = torch.matmul(x_rgb.unsqueeze(2), x_xyz.unsqueeze(1))  # (B, D, D)
        x_fused = x_fused.flatten(start_dim=1)                          # (B, D*D)

        # Predict
        output = self.head(x_fused)                                     # (B, output_dim)

        return output
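To make the shapes concrete, here is a standalone check with toy sizes (B and D are arbitrary). It confirms that the `matmul` in `forward` is a batched outer product, the same as `einsum("bi,bj->bij")`. It also shows how quickly the head's input grows. The flattened size is D², so if `flatten_dim` were as large as the 12800 mentioned in the late-fusion comments, the head would receive 163,840,000 features per sample.

```python
import torch

B, D = 2, 4  # toy sizes, chosen for illustration
x_rgb = torch.randn(B, D)
x_xyz = torch.randn(B, D)

# Batched outer product via matmul, as in MatmulIntermediateNet.forward
outer_mm = torch.matmul(x_rgb.unsqueeze(2), x_xyz.unsqueeze(1))   # (B, D, D)

# Same operation with einsum: entry [b, i, j] = x_rgb[b, i] * x_xyz[b, j]
outer_es = torch.einsum("bi,bj->bij", x_rgb, x_xyz)

fused = outer_mm.flatten(start_dim=1)                              # (B, D*D)
print(fused.shape)                          # torch.Size([2, 16])
print(torch.allclose(outer_mm, outer_es))   # True

# Input size the first head layer must accept, for a few embedding sizes
for d in (128, 200, 12800):
    print(d, "->", d * d)
```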

In [None]:
## move: model.py

class HadamardIntermediateNet(nn.Module):
    """
    Intermediate fusion model using the Hadamard (element-wise) product.

    After independent encoding, the feature vectors are multiplied
    element-wise to capture multiplicative interactions between
    modalities, then fed to the classification head.
    """
    def __init__(self, rgb_ch, xyz_ch, output_dim):
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=rgb_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=xyz_ch, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?

        # For element-wise multiplication, shapes must match
        fused_dim = self.rgb_encoder.flatten_dim                        # element-wise product keeps shape (B, D)

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=fused_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # Multiplicative / gating-like fusion
        x_fused = x_rgb * x_xyz                                         # shape: (B, D)

        # Predict
        output = self.head(x_fused)                                     # (B, output_dim)

        return output
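The element-wise product is exactly the diagonal of the outer product used in `MatmulIntermediateNet`. It keeps only the interactions between matching feature indices, so the head sees D inputs instead of D². The small check below uses toy tensors. The last two lines illustrate the "gating" intuition: a zero in one modality's feature silences that feature entirely.

```python
import torch

B, D = 2, 4  # toy sizes
x_rgb = torch.randn(B, D)
x_xyz = torch.randn(B, D)

hadamard = x_rgb * x_xyz                                   # (B, D)
outer = torch.einsum("bi,bj->bij", x_rgb, x_xyz)           # (B, D, D)
diag = torch.diagonal(outer, dim1=1, dim2=2)               # (B, D)
print(torch.allclose(hadamard, diag))  # True

# Gating intuition: a 0/1 pattern in one modality masks the other
gate = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
features = torch.tensor([[2.0, 3.0, 4.0, 5.0]])
print(gate * features)  # tensor([[2., 0., 4., 0.]])
```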

## Late Fusion Model

**Concept:** Each modality is processed completely separately, and only the final predictions or high-level embeddings are fused.

```
RGB → RGB-Embedder → logits_rgb
LiDAR → LiDAR-Embedder → logits_lidar

Fusion → final decision
```

**Advantages:**

* **Robust to Missing Modalities:** The system can still operate if one modality is noisy, unreliable, or absent.
* **Best for Heterogeneous Modalities:** Works well when modalities differ greatly in structure or resolution (e.g., camera images vs. LiDAR point data).
* **Modular & Simple:** Unimodal models can be trained, debugged, and replaced independently.
* **Leverages Existing Models:** Allows the reuse of strong off-the-shelf unimodal experts without architectural changes.


**Limitations:**

* **Missed Interactions:** No joint feature learning — modalities never influence each other during representation learning.
* **Limited Expressiveness:** Simple fusion rules (e.g., averaging, weighted sum) cannot capture complex cross-modal relationships.
* **Information Loss:** By the time unimodal predictors output logits/embeddings, rich spatial and semantic details may already be discarded, limiting the power of fusion.
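The "weighted vote" idea from the overview can be sketched as decision-level late fusion. `WeightedLogitLateFusion` below is a hypothetical module and is not used elsewhere in this notebook. Each modality has its own classifier, and their logits are combined with learnable, softmax-normalized weights. If a modality is missing, the weights are renormalized over the modalities that are present, which illustrates the robustness to missing modalities listed above. The tiny `nn.Linear` classifiers are stand-ins.

```python
import torch
import torch.nn as nn

class WeightedLogitLateFusion(nn.Module):
    """Hypothetical decision-level late fusion over per-modality logits."""
    def __init__(self, rgb_model: nn.Module, xyz_model: nn.Module):
        super().__init__()
        self.rgb_model = rgb_model
        self.xyz_model = xyz_model
        # One raw score per modality; softmax turns them into vote weights summing to 1
        self.vote_scores = nn.Parameter(torch.zeros(2))

    def forward(self, x_rgb=None, x_xyz=None):
        logits, scores = [], []
        if x_rgb is not None:
            logits.append(self.rgb_model(x_rgb))
            scores.append(self.vote_scores[0])
        if x_xyz is not None:
            logits.append(self.xyz_model(x_xyz))
            scores.append(self.vote_scores[1])
        if not logits:
            raise ValueError("At least one modality must be provided")
        # Renormalize the votes over the modalities that are actually present
        weights = torch.softmax(torch.stack(scores), dim=0)        # (M,)
        stacked = torch.stack(logits, dim=0)                       # (M, B, C)
        return (weights.view(-1, 1, 1) * stacked).sum(dim=0)       # (B, C)

# Usage with tiny stand-in classifiers (4-channel 8x8 inputs, 2 classes)
torch.manual_seed(0)
rgb_clf = nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, 2))
xyz_clf = nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, 2))
fusion = WeightedLogitLateFusion(rgb_clf, xyz_clf)

x_rgb = torch.randn(3, 4, 8, 8)
x_xyz = torch.randn(3, 4, 8, 8)
both = fusion(x_rgb, x_xyz)        # (3, 2), equal-weight average at initialization
rgb_only = fusion(x_rgb=x_rgb)     # LiDAR missing: falls back to the RGB classifier
```

Because the vote scores start at zero, the initial fusion is a plain average of the two logit vectors. Training then learns how much to trust each modality.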

In [None]:
## move: model.py

class LateNet(nn.Module):
    """
    Late fusion model combining high-level unimodal embeddings.

    Two independent encoders process RGB and LiDAR. Their final
    embeddings are concatenated and passed to a single fully-connected
    head that produces the prediction. (A purely decision-level variant
    would instead combine per-modality logits, e.g. by a weighted average.)
    """
    def __init__(self, output_dim, in_ch=4):
        super().__init__()
        # Each LateNet instance owns its encoders, so separate instances never
        # share weights through module-level globals
        self.rgb = EmbedderMaxPool(in_ch)
        self.xyz = EmbedderMaxPool(in_ch)

        # each embedder outputs flatten_dim (e.g. 12800)
        fusion_dim = self.rgb.flatten_dim * 2  # rgb + xyz

        # single FullyConnected head in which data is fused
        self.fullyConnected = FullyConnectedHead(
            input_dim=fusion_dim,
            output_dim=output_dim,
        )

    def forward(self, x_rgb, x_xyz):
        # Extract features independently
        x_rgb = self.rgb(x_rgb)     # (B, 12800)
        x_xyz = self.xyz(x_xyz)     # (B, 12800)

        # this concatenates the features from the two branches
        x_fused = torch.cat((x_rgb, x_xyz), dim=1)    # (B, 25600)

        # Predict
        preds = self.fullyConnected(x_fused)           # (B, output_dim)
        return preds
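One advantage listed for late fusion is reusing strong pretrained unimodal models. A common way to do this is to freeze the encoders and train only the fusion head. The sketch below is illustrative only. The `freeze` helper is hypothetical, and the small `nn.Sequential` encoders stand in for trained `EmbedderMaxPool` instances that would normally be loaded from checkpoints.

```python
import torch
import torch.nn as nn

def freeze(module: nn.Module) -> nn.Module:
    """Disable gradients for all parameters of `module` (e.g. a pretrained encoder)."""
    for p in module.parameters():
        p.requires_grad_(False)
    # If the encoder contains BatchNorm/Dropout, also call module.eval():
    # requires_grad_(False) does not stop BatchNorm running-stat updates.
    return module

D = 16  # toy embedding size
rgb_encoder = freeze(nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, D)))
xyz_encoder = freeze(nn.Sequential(nn.Flatten(), nn.Linear(4 * 8 * 8, D)))
head = nn.Linear(2 * D, 2)  # only the fusion head is trained

params_to_train = [
    p for m in (rgb_encoder, xyz_encoder, head) for p in m.parameters() if p.requires_grad
]
optimizer = torch.optim.Adam(params_to_train, lr=1e-3)

n_trainable = sum(p.numel() for p in params_to_train)
print(n_trainable)  # 66 = (2*D)*2 weights + 2 biases
```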

# Model Training

In [None]:
## move: training.py

def train_model(model, optimizer, input_fn, loss_fn, epochs, train_dataloader, val_dataloader, model_save_path, target_idx=-1, log_to_wandb=False, model_name=None):
    """
    Generic training loop for all fusion models.

    Args:
        model (nn.Module): Model to train.
        optimizer (torch.optim.Optimizer): Optimizer instance.
        input_fn (callable): Function that maps a batch to model inputs.
                             Takes a batch tuple and returns a tuple of tensors
                             already moved to `device`.
        loss_fn (callable): Loss function (e.g. CrossEntropyLoss).
        epochs (int): Number of training epochs.
        train_dataloader (DataLoader): Dataloader for training data.
        val_dataloader (DataLoader): Dataloader for validation data.
        model_save_path (str or Path): Where to save the best model checkpoint.
        target_idx (int): Index of the target within each batch tuple
                          (-1 = last element, i.e. the label).
        log_to_wandb (bool): If True, log metrics to Weights & Biases.
        model_name (str or None): Optional label for logging / printing.

    Returns:
        tuple: (train_losses, valid_losses, epoch_times, max_gpu_mem_mb)
            train_losses (list[float]): Mean training loss per epoch.
            valid_losses (list[float]): Mean validation loss per epoch.
            epoch_times (list[float]): Wall-clock time per epoch in seconds.
            max_gpu_mem_mb (float): Peak allocated GPU memory in MB (0.0 without CUDA).

        The weights with the lowest validation loss are saved to `model_save_path`.
    """
    train_losses = []
    valid_losses = []
    epoch_times = []

    best_val_loss = float('inf')

    # Track peak GPU memory usage (if CUDA is available)
    max_gpu_mem_mb = 0.0
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    for epoch in tqdm(range(epochs)):
        start_time = time.time()                  # to track the training time per epoch

        # ----- Training loop -----
        model.train()
        train_loss = 0.0
        for batch in train_dataloader:
            optimizer.zero_grad()
            target = batch[target_idx].to(device)
            outputs = model(*input_fn(batch))     # input_fn moves the inputs to `device`

            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / len(train_dataloader)
        train_losses.append(train_loss)
        print_loss(epoch, train_loss, outputs, target, is_train=True)

        # ----- Validation loop -----
        model.eval()
        valid_loss = 0.0
        with torch.no_grad():
            for batch in val_dataloader:
                target = batch[target_idx].to(device)
                outputs = model(*input_fn(batch))
                valid_loss += loss_fn(outputs, target).item()
        valid_loss = valid_loss / len(val_dataloader)
        valid_losses.append(valid_loss)
        print_loss(epoch, valid_loss, outputs, target, is_train=False)

        # Save best model based on validation loss
        if valid_loss < best_val_loss:
            best_val_loss = valid_loss
            torch.save(model.state_dict(), model_save_path)
            print('Found and saved better weights for the model')

        # Epoch time: store raw seconds so they can be averaged later
        epoch_time = time.time() - start_time
        epoch_times.append(epoch_time)
        print(f"Epoch {epoch + 1}/{epochs} took {format_time(epoch_time)}")

        # GPU memory usage (peak since the reset above, i.e. cumulative over epochs)
        gpu_mem_mb = 0.0
        if use_cuda:
            gpu_mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
            max_gpu_mem_mb = max(max_gpu_mem_mb, gpu_mem_mb)

        # wandb logging
        if log_to_wandb:
            wandb.log(
                {
                    "model": model.__class__.__name__,
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "lr": optimizer.param_groups[0]["lr"],
                    "epoch_time": epoch_time,
                    "max_gpu_mem_mb_epoch": gpu_mem_mb,
                }
            )

    return train_losses, valid_losses, epoch_times, max_gpu_mem_mb
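To keep the best weights in memory as well as on disk, a reference is not enough. Both the model object and the tensors returned by `model.state_dict()` keep changing as training continues. A minimal sketch of snapshotting with `copy.deepcopy` follows. The toy model, the in-place "update", and the validation losses are made up for illustration.

```python
import copy
import torch
import torch.nn as nn

model = nn.Linear(3, 1)
init_w = model.weight.detach().clone()
best_state, best_val = None, float("inf")

# Hypothetical validation losses for three epochs; epoch 1 is the best
for epoch, val_loss in enumerate([0.9, 0.5, 0.7]):
    with torch.no_grad():
        model.weight.add_(1.0)            # stand-in for an optimizer step
    if val_loss < best_val:
        best_val = val_loss
        best_state = copy.deepcopy(model.state_dict())   # snapshot, not a reference

model.load_state_dict(best_state)         # restore the best-epoch weights (init + 2)
```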

In [None]:
## move: training.py

def get_early_inputs(batch):
    """
    Prepare inputs for the early fusion model.

    Concatenates RGB and XYZA along the channel dimension to obtain
    an 8-channel tensor.

    Args:
        batch (tuple): (rgb, xyz, label) from the dataset.

    Returns:
        tuple[torch.Tensor]: Single-element tuple (inputs_mm_early,).
    """
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)

    # Concatenate along channel dimension: (B, 4, H, W) + (B, 4, H, W) -> (B, 8, H, W)
    inputs_mm_early = torch.cat((inputs_rgb, inputs_xyz), 1)
    return (inputs_mm_early,)
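Channel-wise concatenation only works if both tensors share the batch size and spatial size (H, W). This is why early fusion assumes pixel-aligned modalities. A toy shape check (sizes and labels are arbitrary):

```python
import torch

B, H, W = 2, 32, 32  # toy sizes
rgb = torch.randn(B, 4, H, W)
xyza = torch.randn(B, 4, H, W)
batch = (rgb, xyza, torch.tensor([0, 1]))

early = torch.cat((batch[0], batch[1]), dim=1)
print(early.shape)  # torch.Size([2, 8, 32, 32])
# Channels 0-3 come from the RGB tensor, channels 4-7 from LiDAR XYZA
```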

In [None]:
## move: training.py

def get_inputs(batch):
    """
    Prepare inputs for intermediate/late fusion models.

    Returns RGB and XYZA tensors separately so that each modality
    can be passed to its own encoder.

    Args:
        batch (tuple): (rgb, xyz, label) from the dataset.

    Returns:
        tuple[torch.Tensor, torch.Tensor]: (inputs_rgb, inputs_xyz)
    """
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    return (inputs_rgb, inputs_xyz)

In [None]:
## move: training.py

def compute_class_weights(dataset):
  """
  Compute inverse-frequency class weights to handle class imbalance.

  Args:
      dataset (Dataset): Dataset whose items are (rgb, xyz, label) tuples.

  Returns:
      torch.Tensor: Class weights of shape (num_classes,), normalized to mean 1.
          Entry i belongs to the i-th smallest label present in the dataset,
          so it matches class index i only if every class 0..C-1 occurs.
  """
  # Extract all labels from the dataset (note: this loads every sample once)
  labels = [dataset[i][2] for i in range(len(dataset))]
  labels = torch.tensor(labels, dtype=torch.long)

  # Count occurrences of each class (torch.unique returns the classes sorted)
  _, counts = torch.unique(labels, return_counts=True)
  class_counts = counts.float()

  # Compute inverse-frequency weights (rarer class -> higher weight)
  class_weights = class_counts.sum() / (class_counts + 1e-6)
  class_weights = class_weights / class_weights.mean()

  return class_weights
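A worked example of the same formula on a hypothetical 900/100 label split. The raw weights are 1000/900 ≈ 1.11 and 1000/100 = 10.0. Dividing by their mean (≈ 5.56) gives about [0.2, 1.8].

```python
import torch

# Hypothetical imbalanced labels: 900 samples of class 0, 100 of class 1
labels = torch.cat([torch.zeros(900, dtype=torch.long), torch.ones(100, dtype=torch.long)])

_, counts = torch.unique(labels, return_counts=True)
class_counts = counts.float()
weights = class_counts.sum() / (class_counts + 1e-6)   # ≈ [1.11, 10.0]
weights = weights / weights.mean()                      # ≈ [0.2, 1.8]
print(weights)
```

With these weights, each class contributes the same total weight to the loss (900 × 0.2 = 100 × 1.8 = 180), which is the point of inverse-frequency weighting.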

In [None]:
## stay

set_seeds(SEED)

class_weights = compute_class_weights(train_data).to(device)
loss_func = nn.CrossEntropyLoss(weight=class_weights.to(device))

metrics = {}   # store losses for each model

# Defines fusion models to train and compare
models_to_train = {
    "early_fusion": EarlyFusionModel(in_ch=8, output_dim=NUM_CLASSES).to(device),
    #"intermediate_fusion_concat": ConcatIntermediateNet(4, 4, output_dim=NUM_CLASSES).to(device),
    #"intermediate_fusion_matmul": MatmulIntermediateNet(4, 4, output_dim=NUM_CLASSES).to(device),
    #"intermediate_fusion_hadamard": HadamardIntermediateNet(4, 4, output_dim=NUM_CLASSES).to(device),
    #"intermediate_fusion_add": AddIntermediateNet(4, 4, output_dim=NUM_CLASSES).to(device),
    #"late_fusion": LateNet(output_dim=NUM_CLASSES).to(device),
}

# Directory where best model is saved
checkpoint_dir = STORAGE_PATH / "checkpoints"
checkpoint_dir.mkdir(exist_ok=True)


# === Main experiment loop over all fusion strategies ===
for name, model in models_to_train.items():
  model_save_path = checkpoint_dir / f"{name}.pth"

  # Number of trainable parameters (for the comparison table)
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

  opt = Adam(model.parameters(), lr=LR)

  # Initialize a new Weights & Biases run for this model.
  init_wandb(
      model=model,
      fusion_name=name,
      num_params=num_params,
      opt_name = opt.__class__.__name__)

  # Choose the proper input function depending on the fusion strategy:
  if name == "early_fusion":
    input_fn = get_early_inputs
  else:
    input_fn = get_inputs

  train_losses, valid_losses, epoch_times, max_gpu_mem_mb = train_model(
    model=model,
    optimizer=opt,
    input_fn=input_fn,
    epochs=EPOCHS,
    loss_fn=loss_func,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_save_path=model_save_path,
    target_idx=-1,   # last element in batch is target
    log_to_wandb=True,
    model_name=name
  )

  metrics[name] = {
      "train_losses": train_losses,
      "valid_losses": valid_losses,
      "epoch_times": epoch_times,
      "best_valid_loss": min(valid_losses),
      "max_gpu_mem_mb": max_gpu_mem_mb,
      "num_params": num_params,
  }

  # End wandb run before starting the next model
  wandb.finish()

# Evaluation

In [None]:
## move: visualization.py

def plot_losses(loss_dict, title="Validation Loss per Model", ylabel="Loss", xlabel="Epoch"):
    """
    Plot validation loss curves for multiple models.

    Args:
        loss_dict (dict): Mapping "model_name" -> list_of_losses (same length).
        title (str): Plot title.
        ylabel (str): Label for y-axis.
        xlabel (str): Label for x-axis.
    """
    plt.figure(figsize=(8,5))

    # Auto-generate x-axis based on first model
    any_key = next(iter(loss_dict))
    epochs = range(len(loss_dict[any_key]))

    for model_name, losses in loss_dict.items():
        plt.plot(epochs, losses, label=model_name)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.show()

In [None]:
# Average epoch time of the most recently trained model (`epoch_times` from the last loop iteration)
avg_epoch_time = sum(epoch_times) / len(epoch_times)

In [None]:
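def build_fusion_comparison_df(metrics, name_map=None):
    """
    Build a summary table comparing all trained fusion strategies.

    Args:
        metrics (dict): Mapping model key -> dict with "train_losses", "valid_losses",
            "epoch_times" (seconds), "best_valid_loss", "num_params", "max_gpu_mem_mb".
        name_map (dict or None): Optional mapping model key -> display name.

    Returns:
        pd.DataFrame: One row per fusion strategy.
    """
    rows = []
    for key, m in metrics.items():
        avg_train_loss = float(np.mean(m["train_losses"]))
        avg_valid_loss = float(np.mean(m["valid_losses"]))
        avg_epoch_time = float(np.mean(m["epoch_times"]))
        rows.append({
            "Fusion Strategy": name_map.get(key, key) if name_map else key,
            "Avg Train Loss": avg_train_loss,
            "Avg Valid Loss": avg_valid_loss,
            "Best Valid Loss": float(m["best_valid_loss"]),
            "Trainable params": int(m["num_params"]),
            "Avg time per epoch (s)": avg_epoch_time,
            "GPU Memory (MB, max)": float(m["max_gpu_mem_mb"]),
        })
    return pd.DataFrame(rows)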

In [None]:
loss_dict = {name: m["valid_losses"] for name, m in metrics.items()}
plot_losses(loss_dict, title="Validation Loss per Model")

df_comparison = build_fusion_comparison_df(metrics, name_map)
df_comparison

In [None]:
# logs the comparison table to wandb
wandb.init(
    project="cilp-extended-assessment",   # your project name
    name="fusion_comparison_all",
    job_type="analysis",
)

fusion_comparison_table = wandb.Table(dataframe=df_comparison)
wandb.log({"fusion_comparison": fusion_comparison_table})

wandb.finish()

**When to use**

**Early Fusion:**
* Modalities that are spatially aligned (e.g. pixel-registered RGB and LiDAR maps) and carry comparable low-level features
* Simplest setup; avoid if a sensor is noisy, since its noise enters the network from the very first layer

**Intermediate Fusion:**
* Modalities with different structure that benefit from separate early processing to learn modality-specific features
* Often the best overall balance of performance and flexibility

**Late Fusion:**
* Strong, independent unimodal predictors whose strengths should be combined
* Well suited to heterogeneous modalities or modalities that may be missing
* Robust fallback when one modality fails