# Model Architecture

```
RGB ----> RGB Encoder ----\
                            ----> Fusion ---> Classifier ---> cube/sphere
LiDAR -> LiDAR Encoder ----/
```



Multimodal fusion refers to how we combine information from different modalities (e.g., RGB and LiDAR).
There are three canonical levels of fusion:

Early fusion – combine raw or early-level features

Intermediate fusion – combine learned feature representations

Late fusion – combine decisions or latent vectors at the end of the pipeline

Each level has different strengths + limitations.

# Setup

In [1]:
%%capture
%pip install wandb weave

In [2]:
%%capture
%pip install fiftyone==1.10.0 sympy==1.12 torch==2.9.0 torchvision==0.20.0 numpy open-clip-torch

In [3]:
import os
from pathlib import Path
from google.colab import userdata
import time

from PIL import Image
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import pandas as pd

import wandb
import cv2
import albumentations as A

In [5]:
from google.colab import drive
drive.mount('/content/drive')

STORAGE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/")

Mounted at /content/drive


In [6]:
DATA_PATH = STORAGE_PATH / "multimodal_training_workshop/data"
print(f"Data path: {DATA_PATH}")
print(f"Data path exists: {DATA_PATH.exists()}")

Data path: /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/multimodal_training_workshop/data
Data path exists: True


In [7]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.is_available()

True

In [8]:
SEED = 51
NUM_WORKERS = os.cpu_count()  # Number of CPU cores

BATCH_SIZE = 32
IMG_SIZE = 64

CLASSES = ["cubes", "spheres"]
LABEL_MAP = {"cubes": 0, "spheres": 1}

In [9]:
## Antonio
VALID_BATCHES = 10
N = 1000

# Integrate Wandb

In [10]:
# Load W&B API key from .env file and make it available as env variable
# from dotenv import load_dotenv
# load_dotenv()  # loads .env automatically

# os.environ["WANDB_API_KEY"]

In [11]:
# Load W&B API key from Colab Secrets and make it available as env variable
wandb_key = userdata.get('WANDB_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_key

In [12]:
wandb.login()

  | |_| | '_ \/ _` / _` |  _/ -_)
[34m[1mwandb[0m: Currently logged in as: [33mmichele-marschner[0m ([33mmichele-marschner-university-of-potsdam[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [15]:
def init_wandb(model, fusion_name, num_params, opt_name, batch_size=BATCH_SIZE, epochs=15):
  config = {
    # "embedding_size": embedding_size,      ## TODO: ändert die sich? hab ich die bei fusion?
    "optimizer_type": opt_name,
    "fusion_strategy": fusion_name,
    "model_architecture": model.__class__.__name__,
    "batch_size": batch_size,
    "num_epochs": epochs,
    "num_parameters": num_params
  }

  run = wandb.init(
    project="cilp-extended-assessment",
    name=f"{fusion_name}_run",
    config=config,
    reinit=True,                          # allows multiple runs in one script
  )

  return

# Reproducibility

In [16]:
def set_seeds(seed=SEED):
    """
    Set seeds for complete reproducibility across all libraries and operations.

    Args:
        seed (int): Random seed value
    """
    # Set environment variables before other imports
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    # Python random module
    random.seed(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch CPU
    torch.manual_seed(seed)

    # PyTorch GPU (all devices)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

        # CUDA deterministic operations
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # OpenCV
    cv2.setRNGSeed(seed)

    # Albumentations (for data augmentation)
    try:
        A.seed_everything(seed)
    except AttributeError:
        # Older versions of albumentations
        pass

    # PyTorch deterministic algorithms (may impact performance)
    try:
        torch.use_deterministic_algorithms(True)
    except RuntimeError:
        # Some operations don't have deterministic implementations
        print("Warning: Some operations may not be deterministic")

    print(f"All random seeds set to {seed} for reproducibility")



# Usage: Call this function at the beginning and before each training phase
set_seeds(SEED)

# Additional reproducibility considerations:

def create_deterministic_training_dataloader(dataset, batch_size, shuffle=True, **kwargs):
    """
    Create a DataLoader with deterministic behavior.

    Args:
        dataset: PyTorch Dataset instance
        batch_size: Batch size
        shuffle: Whether to shuffle data
        **kwargs: Additional DataLoader arguments

    Returns:
        Training DataLoader with reproducible behavior
    """
    # Use a generator with fixed seed for reproducible shuffling
    generator = torch.Generator()
    generator.manual_seed(51)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator if shuffle else None,
        **kwargs
    )



All random seeds set to 51 for reproducibility


# Utility Functions

In [None]:
def format_time(seconds):
    m = int(seconds // 60)
    s = int(seconds % 60)
    return f"{m:02d}m {s:02d}s"

In [17]:
def hex_to_RGB(hex_str):
    """ #FFFFFF -> [255,255,255]"""
    #Pass 16 to the integer function for change of base
    return [int(hex_str[i:i+2], 16) for i in range(1,6,2)]


def get_color_gradient(c1, c2, c3, n1, n2):
    """
    Given two hex colors, returns a color gradient
    with n colors.
    """
    c1_rgb = np.array(hex_to_RGB(c1))/255
    c2_rgb = np.array(hex_to_RGB(c2))/255
    c3_rgb = np.array(hex_to_RGB(c3))/255
    mix_pcts_c1_c2 = [x/(n1-1) for x in range(n1)]
    mix_pcts_c2_c3 = [x/(n2-1) for x in range(n2)]
    rgb_c1_c2 = [((1-mix)*c1_rgb + (mix*c2_rgb)) for mix in mix_pcts_c1_c2]
    rgb_c2_c3 = [((1-mix)*c2_rgb + (mix*c3_rgb)) for mix in mix_pcts_c2_c3]
    rgb_colors = rgb_c1_c2 + rgb_c2_c3
    return ["#" + "".join([format(int(round(val*255)), "02x") for val in item]) for item in rgb_colors]


cmap = colors.ListedColormap(get_color_gradient("#000000", "#76b900", "#f1ffd9", 64, 128))


In [18]:
def get_outputs(model, batch, inputs_idx):
    inputs = batch[inputs_idx].to(device)
    target = batch[-1].to(device)
    outputs = model(inputs)
    return outputs, target

In [19]:
def get_torch_xyza(lidar_depth, azimuth, zenith):
    x = lidar_depth * torch.sin(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    y = lidar_depth * torch.cos(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    z = lidar_depth * torch.sin(-zenith[None, :])
    a = torch.where(lidar_depth < 50.0, torch.ones_like(lidar_depth), torch.zeros_like(lidar_depth))
    xyza = torch.stack((x, y, z, a))
    return xyza

In [20]:
def format_positions(positions):
    return ['{0: .3f}'.format(x) for x in positions]

In [40]:
def train_model(model, optimizer, input_fn, loss_fn, epochs, train_dataloader, val_dataloader, model_save_path, target_idx=-1, log_to_wandb=False, model_name=None):
    train_losses = []
    valid_losses = []
    epoch_times = []

    best_val_loss = float('inf')
    best_model = None

    # for GPU memory tracking
    max_gpu_mem_mb = 0.0
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    for epoch in range(epochs):
        start_time = time.time()                  # to track the train time per model
        print(f"Epoch and start time: {epoch} und {start_time}")
        model.train()
        train_loss = 0
        for step, batch in enumerate(train_dataloader):

            rgb, lidar_xyza, position = batch
            rgb = rgb.to(device)
            lidar_xyza = lidar_xyza.to(device)
            position = position.to(device)

            optimizer.zero_grad()
            target = batch[target_idx].to(device)
            outputs = model(*input_fn(batch))

            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / (step + 1)
        train_losses.append(train_loss)
        print_loss(epoch, train_loss, outputs, target, is_train=True)

        # ----- validation -----
        model.eval()
        valid_loss = 0
        with torch.no_grad():
          for step, batch in enumerate(val_dataloader):
              target = batch[target_idx].to(device)
              outputs = model(*input_fn(batch))
              valid_loss += loss_fn(outputs, target).item()
        valid_loss = valid_loss / (step + 1)
        valid_losses.append(valid_loss)
        print_loss(epoch, valid_loss, outputs, target, is_train=False)

        if valid_loss < best_val_loss:
          best_val_loss = valid_loss
          best_model = model
          # Save the best model
          torch.save(best_model.state_dict(), model_save_path)
          print('Found and saved better weights for the model')

        # timing
        epoch_time = time.time() - start_time
        epoch_time_formatted = format_time(epoch_time)
        epoch_times.append(epoch_time_formatted)

        # GPU memory
        if use_cuda:
            gpu_mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
            max_gpu_mem_mb = max(max_gpu_mem_mb, gpu_mem_mb)

        # wandb logging
        if log_to_wandb:
            wandb.log(
                {
                    "model": model.__class__.__name__,
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "lr": optimizer.param_groups[0]["lr"],
                    "epoch_time": epoch_time_formatted,
                    "max_gpu_mem_mb_epoch": gpu_mem_mb if use_cuda else 0.0,
                }
            )

    return train_losses, valid_losses, epoch_times, max_gpu_mem_mb

In [22]:
def print_loss(epoch, loss, outputs, target, is_train=True, is_debug=False):
    loss_type = "train loss:" if is_train else "valid loss:"
    print("epoch", str(epoch), loss_type, str(loss))
    if is_debug:
        print("example pred:", format_positions(outputs[0].tolist()))
        print("example real:", format_positions(target[0].tolist()))

In [23]:
def plot_losses(losses, title="Training & Validation Loss Comparison", figsize=(10,6)):
    plt.figure(figsize=figsize)

    for model_name, log in losses.items():
        train = log["train_losses"]
        valid = log["valid_losses"]

        # plot train + valid with different line styles
        plt.plot(train, label=f"{model_name} - train", linewidth=2)
        plt.plot(valid, label=f"{model_name} - valid", linestyle="--", linewidth=2)

    plt.title(title, fontsize=16)
    plt.xlabel("Epochs", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

In [24]:
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),  # Scales data into [0,1]    ## TODO: transforms.v2?
])



In [25]:
img_transforms = transforms.Compose([
    transforms.ToImage(),   # Scales data into [0,1]    ## TODO: transforms.v2?
    transforms.Resize(IMG_SIZE),
    transforms.ToDtype(torch.float32, scale=True),
    ## normalize
])

# Load and prepare Data

In [26]:
class ReplicatorDataset(Dataset):
    def __init__(self, root_dir, start_idx, stop_idx):
        self.root_dir = Path(root_dir)

        # indices this dataset will cover
        self.indices = list(range(start_idx, stop_idx))

        # positions: (N, 3) or (N, 4) depending on file
        all_positions = pd.read_csv(
            "https://github.com/andandandand/practical-computer-vision/blob/main/artifacts/positions.csv?raw=1"
        ).values
        self.positions = all_positions[start_idx:stop_idx]

        # azimuth / zenith loaded once
        azimuth_path = self.root_dir / "azimuth.npy"
        zenith_path = self.root_dir / "zenith.npy"

        if not azimuth_path.exists():
            raise FileNotFoundError(f"azimuth.npy not found at {azimuth_path}")
        if not zenith_path.exists():
            raise FileNotFoundError(f"zenith.npy not found at {zenith_path}")

        azimuth = np.load(azimuth_path)
        zenith = np.load(zenith_path)
        # keep them as CPU tensors; move to device in training if needed
        self.azimuth = torch.from_numpy(azimuth)      # shape (H,)
        self.zenith = torch.from_numpy(zenith)        # shape (W,)

        # dirs
        self.rgb_dir = self.root_dir / "rgb"
        self.lidar_dir = self.root_dir / "lidar"

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # map dataset index -> real frame index
        frame_idx = self.indices[idx]
        file_number = f"{frame_idx:04d}"

        # --- RGB ---
        rgb_path = self.rgb_dir / f"{file_number}.png"
        rgb_img = Image.open(rgb_path)
        rgb_tensor = img_transforms(rgb_img)   # still on CPU

        # --- LiDAR depth ---
        lidar_path = self.lidar_dir / f"{file_number}.npy"
        lidar_depth = np.load(lidar_path)      # (H, W)
        lidar_depth = torch.from_numpy(lidar_depth).to(torch.float32)  # CPU

        # --- XYZA ---
        lidar_xyza = get_torch_xyza(lidar_depth, self.azimuth, self.zenith)  # (4, H, W)

        # --- position ---
        position_np = self.positions[idx]     # numpy row
        position = torch.from_numpy(position_np).to(torch.float32)  # CPU

        return rgb_tensor, lidar_xyza, position

In [27]:
def get_replicator_dataloaders(root_dir):
    train_data = ReplicatorDataset(root_dir, 0, N-VALID_BATCHES*BATCH_SIZE)
    train_dataloader = create_deterministic_training_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_data = ReplicatorDataset(root_dir, N-VALID_BATCHES*BATCH_SIZE, N)
    val_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    return train_data, train_dataloader, valid_data, val_dataloader


In [28]:
train_data, train_dataloader, valid_data, val_dataloader = get_replicator_dataloaders(str(DATA_PATH / "replicator_data_cubes/"))

for i, sample in enumerate(train_data):
    print(i, *(x.shape for x in sample))
    break

0 torch.Size([4, 64, 64]) torch.Size([4, 64, 64]) torch.Size([9])


# Create Models

Take the Net architecture from the workshop and turn it into an encoder that outputs an embedding instead of 9 positions.

## Early Fusion Model

**Concept:** Fuse modalities before any deep processing — usually by concatenating channels or inputs.

```
input = concat(RGB, XYZ)  → shape (8, H, W)
-> shared CNN processes everything together
```



**Advantages:**

* **Captures Early Cross-Modal Interactions:** Learns joint low-level correlations directly from raw signals.
* **Simple & Lightweight**: Easiest fusion method to implement; minimal architectural overhead.
* **Effective with Perfect Alignment:** Works well when modalities are tightly synchronized and spatially aligned.

**Limitations:**

* **Noise Sensitivity:** One noisy or corrupted modality directly contaminates the shared feature space.
* **Strict Alignment Requirement:** Modalities must have matching spatial resolution, alignment, and synchronization.
* **Feature Space Mismatch:** Raw modalities differ in scale, units, and distribution; one modality can dominate without careful normalization.
* **High Input Dimensionality:** Channel concatenation increases the input size and can require more data and compute to train effectively.
* **Limited Flexibility:** Assumes combining low-level signals is beneficial; may underperform when modalities carry different types of information.

In [29]:
num_positions = 9

class Net(nn.Module):
    def __init__(self, in_ch):
        kernel_size = 3
        super().__init__()
        flattened_size = 200 * 8 * 8
        self.conv1 = nn.Conv2d(in_ch, 50, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(50, 100, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(100, 200, kernel_size, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(flattened_size, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Intermediate Fusion Models

**Concept:** Each modality has its own encoder / feature extractor, and fusion happens after some layers but before classification.

```
RGB → RGB_conv → RGB_features (C, H, W)
LiDAR → LiDAR_conv → LiDAR_features (C, H, W)

Fusion → joint_features → FC → output
```



**Advantages:**

* **Specialized Processing:** Each modality gets its own encoder, tailored to its characteristics.
* **Learned Representations:** Fusion occurs on higher-level, more discriminative features rather than raw data.
* **Flexible Design:** The fusion point can be chosen at different network depths, allowing fine-grained architectural control.
* **Easily Extendable:** New modalities can be added by including additional modality-specific branches.


**Limitations:**

* **Architectural Complexity:** Requires designing separate modality-specific encoders and choosing an appropriate fusion point.
* **Higher Computational Cost:** More expensive than early fusion due to duplicated feature extractors.
* **Fusion Design Sensitivity:** Performance depends on the chosen fusion mechanism (concat, addition, multiplicative, bilinear, attention), which often requires experimentation.
* **Depth Selection Challenge:** Deciding how much unimodal processing to perform before fusion can be non-trivial and task-dependent.

Implemented 4 variants:

*   Concatenation
*   Addition
* Hadamard Product (element-wise multiplication)
* Matrix-Multiplication



| Fusion Method | Advantages | Limitations |
|---------------|------------|-------------|
| **Concatenation** | - Very expressive and flexible<br>- Lets the network learn arbitrary cross-modal interactions<br>- Robust and widely used baseline | - Doubles channel count → more parameters & memory<br>- Computationally heavier<br>- Fusion is unguided; model must discover interactions itself |
| **Addition** | - Lightweight (no increase in channels)<br>- Fast and parameter-efficient<br>- Enforces similar feature spaces between modalities | - Assumes features are aligned and comparable<br>- One noisy modality corrupts the other<br>- Sensitive to scale differences between modalities |
| **Multiplicative (Hadamard Product)** | - Gating effect: highlights features important in *both* modalities<br>- More expressive than addition, cheaper than concat<br>- Natural for attention-like fusion | - Suppresses features when one modality has low magnitude<br>- Requires careful normalization<br>- Can amplify noise if both activations are high |
| **Matrix Multiplication (Bilinear-like)** | - Captures rich pairwise correlations between modalities<br>- Most expressive among all four<br>- Enables true 2nd-order interaction learning | - Very heavy in compute & memory<br>- Requires flattening or dimensionality reduction<br>- Easily overfits; harder to train and tune |


In [30]:
class ConcatIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        kernel_size = 3
        num_positions = 9
        super().__init__()
        self.rgb_conv1 = nn.Conv2d(rgb_ch, 25, kernel_size, padding=1)
        self.rgb_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.rgb_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)

        self.xyz_conv1 = nn.Conv2d(xyz_ch, 25, kernel_size, padding=1)
        self.xyz_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.xyz_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)

        # this downsampling can be done with convolutions of stride 2
        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(200 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.pool(F.relu(self.rgb_conv1(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv2(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv3(x_rgb)))

        x_xyz = self.pool(F.relu(self.xyz_conv1(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv2(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv3(x_xyz)))

        x = torch.cat((x_rgb, x_xyz), 1)
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

In [31]:
class AddIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        super().__init__()
        kernel_size = 3
        num_positions = 9

        # same twin towers as before
        self.rgb_conv1 = nn.Conv2d(rgb_ch, 25, kernel_size, padding=1)
        self.rgb_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.rgb_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)

        self.xyz_conv1 = nn.Conv2d(xyz_ch, 25, kernel_size, padding=1)
        self.xyz_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.xyz_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)

        self.pool = nn.MaxPool2d(2)

        # now we keep 100 channels (not 200), so:
        self.fc1 = nn.Linear(100 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.pool(F.relu(self.rgb_conv1(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv2(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv3(x_rgb)))      # (B, 100, 8, 8)

        x_xyz = self.pool(F.relu(self.xyz_conv1(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv2(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv3(x_xyz)))      # (B, 100, 8, 8)

        # intermediate fusion via addition
        x = x_rgb + x_xyz                                     # (B, 100, 8, 8)

        x = torch.flatten(x, 1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


In [32]:
class MatmulIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        kernel_size = 3
        color_chs = 9
        num_positions = 9
        super().__init__()
        self.rgb_conv1 = nn.Conv2d(rgb_ch, 25, kernel_size, padding=1)
        self.rgb_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.rgb_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)

        self.xyz_conv1 = nn.Conv2d(xyz_ch, 25, kernel_size, padding=1)
        self.xyz_conv2 = nn.Conv2d(25, 50, kernel_size, padding=1)
        self.xyz_conv3 = nn.Conv2d(50, 100, kernel_size, padding=1)

        self.pool = nn.MaxPool2d(2)
        self.fc1 = nn.Linear(100 * 8 * 8, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.pool(F.relu(self.rgb_conv1(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv2(x_rgb)))
        x_rgb = self.pool(F.relu(self.rgb_conv3(x_rgb)))

        x_xyz = self.pool(F.relu(self.xyz_conv1(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv2(x_xyz)))
        x_xyz = self.pool(F.relu(self.xyz_conv3(x_xyz)))

        #x = torch.matmul(x_rgb, x_xyz)
        x = x_rgb * x_xyz
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

## Late Fusion Model

**Concept:** Each modality is processed completely separately, and only the final predictions or high-level embeddings are fused.

```
RGB → RGB-Net → logits_rgb
LiDAR → LiDAR-Net → logits_lidar

Fusion → final decision
```

**Advantages:**

* **Robust to Missing Modalities:** The system can still operate if one modality is noisy, unreliable, or absent.
* **Best for Heterogeneous Modalities:** Works well when modalities differ greatly.
* **Modular & Simple:** Unimodal models can be trained, debugged, and replaced independently.
* **Leverages Existing Models:** Allows the reuse of strong off-the-shelf unimodal experts without architectural changes.


**Limitations:**

* **Missed Interactions:** No joint feature learning — modalities never influence each other during representation learning.
* **Limited Expressiveness:** Simple fusion rules (e.g., averaging, weighted sum) cannot capture complex cross-modal relationships.
* **Information Loss:** By the time unimodal predictors output logits/embeddings, rich spatial and semantic details may already be discarded, limiting the power of fusion.

In [33]:
rgb_net = Net(4).to(device)
xyz_net = Net(4).to(device)

class LateNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.rgb = rgb_net
        self.xyz = xyz_net
        self.fc1 = nn.Linear(num_positions * 2, num_positions * 10)
        self.fc2 = nn.Linear(num_positions * 10, num_positions)

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.rgb(x_rgb)
        x_xyz = self.xyz(x_xyz)
        # this concatenates the features from the two branches
        x = torch.cat((x_rgb, x_xyz), 1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x

In [34]:
### TODO aufpassen: bei Antonio zusätzlich vorher ConvEnc trainiert und gradienten eingefroern - chatty macht das hier nicht??
### aber auch im nächsten lab trainiert antonio nicht vor um besser vergleichen zu können mit early fusion
### old
class LateFusionModel(nn.Module):
    """
    Late fusion:
    - RGB and LiDAR are encoded completely separately.
    - Only at the very end we concatenate their embeddings and optionally project.

    This matches the idea:
        "Separate encoders → final embeddings → fusion → similarity"
    """

    def __init__(self, emb_dim=128, hidden_dim=256, out_dim=2):
        """
        Args:
            emb_dim: size of each individual modality embedding
            fused_dim: size of the final fused embedding
        """
        super().__init__()

        # Separate encoders for each modality
        self.rgb_enc = ConvEncoder(in_ch=4, emb_dim=emb_dim)      ## TODO: mit Antonio abgleichen
        self.lidar_enc = ConvEncoder(in_ch=4, emb_dim=emb_dim)    ## TODO: mit Antonio abgleichen

        # Linear layer to mix and reduce concatenated embeddings
        # Input is [rgb_emb, lidar_emb] of size 2 * emb_dim
        self.fusion_fc1 = nn.Linear(2 * emb_dim, hidden_dim)
        self.fusion_fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, rgb, lidar):
        """
        Args:
            rgb:   (B, 4, 64, 64)
            lidar: (B, 4, 64, 64)

        Returns:
            fused_emb: (B, fused_dim)  # used for similarity / contrastive loss
            rgb_emb:   (B, emb_dim)    # optional, for analysis
            lidar_emb: (B, emb_dim)    # optional, for analysis
        """
        # 1) Encode each modality with its own ConvEncoder
        rgb_emb = self.rgb_enc(rgb)           # (B, emb_dim)
        lidar_emb = self.lidar_enc(lidar)     # (B, emb_dim)

        # 2) Late fusion: concatenate embeddings
        x = torch.cat((rgb_emb, lidar_emb), dim=1)  # (B, 2*emb_dim)

        # 3) Optional projection to a joint fused space   ## TODO auch hier abweichung zu antonio, hat relu + das zweite linear
        x = F.relu(self.fusion_fc1(x))                 # (B, fused_dim)
        out = self.fusion_fc2(x)                       # (B, out_dim)

        return out


In [35]:
## old - aber vielleicht wichtig für Vergleich/Optimization
class IntermediateFusionModel(nn.Module):
    """
    Intermediate fusion:
    - Each modality has its own early conv layers (conv1, conv2).
    - Their feature maps are then concatenated.
    - Shared later layers (conv3 + FCs) operate on the fused feature maps.

    This lets RGB and LiDAR interact earlier and at a more local spatial level.
    """

    def __init__(self, emb_dim=128, hidden_dim=256, out_dim=2):    ## TODO Antonio gibt channels als parameter rein
        super().__init__()
        k = 3
        # this downsampling can be done with convolutions of stride 2
        self.pool = nn.MaxPool2d(2)

        # --- Modality-specific early convolutions ---

        # RGB branch: takes 4-channel input and produces feature maps
        self.rgb_conv1 = nn.Conv2d(4, 50, k, padding=1)
        self.rgb_conv2 = nn.Conv2d(50, 100, k, padding=1)     ## TODO Antonio hat ein Conv2d more, das erste auch nur 4,25?

        # LiDAR branch: same structure but separate weights
        self.lidar_conv1 = nn.Conv2d(4, 50, k, padding=1)
        self.lidar_conv2 = nn.Conv2d(50, 100, k, padding=1)

        # --- Shared later layers ---        ## TODO baut das hier komplett anders auf als Antonio, außerdem hat Matmul, also nicht concatenate sondern multiplication den besten val_loss

        # After conv2+pool in each branch:
        # RGB feature maps:   (B, 100, 16, 16)
        # LiDAR feature maps: (B, 100, 16, 16)
        # Concatenated:       (B, 200, 16, 16)
        self.shared_conv3 = nn.Conv2d(200, 200, k, padding=1)

        # After another pool: (B, 200, 8, 8)
        self.fc1 = nn.Linear(200 * 8 * 8, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, rgb, lidar):
        """
        Args:
            rgb:   (B, 4, 64, 64)
            lidar: (B, 4, 64, 64)

        Returns:
            emb: (B, emb_dim)  # fused embedding for similarity / contrastive loss
        """
        # --- RGB early branch ---
        x_rgb = self.pool(F.relu(self.rgb_conv1(rgb)))     # (B, 50, 32, 32)
        x_rgb = self.pool(F.relu(self.rgb_conv2(x_rgb)))   # (B, 100, 16, 16)

        # --- LiDAR early branch ---
        x_lid = self.pool(F.relu(self.lidar_conv1(lidar))) # (B, 50, 32, 32)
        x_lid = self.pool(F.relu(self.lidar_conv2(x_lid))) # (B, 100, 16, 16)

        # --- Intermediate fusion on feature maps ---
        # Concatenate along channel dimension
        x = torch.cat([x_rgb, x_lid], dim=1)               # (B, 200, 16, 16)

        # Shared conv and pooling
        x = self.pool(F.relu(self.shared_conv3(x)))        # (B, 200, 8, 8)

        # Flatten and project to embedding
        x = torch.flatten(x, 1)                            # (B, 200*8*8)
        x = F.relu(self.fc1(x))                            # (B, 1000)
        out = self.fc2(x)                                  # (B, out_dim)

        return out


# Model Training

In [36]:
def get_early_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    inputs_mm_early = torch.cat((inputs_rgb, inputs_xyz), 1)
    return (inputs_mm_early,)

In [37]:
def get_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    return (inputs_rgb, inputs_xyz)

In [41]:
set_seeds(SEED)

EPOCHS = 20
LR = 0.0001

loss_func = nn.MSELoss()
metrics = {}   # store losses for each model

models_to_train = {
    "early_fusion": Net(8).to(device),
    "intermediate_fusion_concat": ConcatIntermediateNet(4, 4).to(device),
    "intermediate_fusion_matmul": MatmulIntermediateNet(4, 4).to(device),
    "late_fusion": LateNet().to(device),
}

for name, model in models_to_train.items():
  model_save_path = STORAGE_PATH / f"checkpoints/{name}.pth"

  # metrics for comparison table
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

  opt = Adam(model.parameters(), lr=LR)

  # initialize wandb
  init_wandb(
      model=model,
      fusion_name=name,
      num_params=num_params,
      opt_name = opt.__class__.__name__)

  if name == "early_fusion":
    input_fn = get_early_inputs
  else:
    input_fn = get_inputs

  train_losses, valid_losses, epoch_times, max_gpu_mem_mb = train_model(
    model=model,
    optimizer=opt,
    input_fn=input_fn,
    epochs=EPOCHS,
    loss_fn=loss_func,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_save_path=model_save_path,
    target_idx=-1,   # last element in batch is target
    log_to_wandb=True,
    model_name=name
  )

  metrics[name] = {
      "train_losses": train_losses,
      "valid_losses": valid_losses,
      "epoch_times": epoch_times,
      "best_valid_loss": min(valid_losses),
      "max_gpu_mem_mb": max_gpu_mem_mb,
      "num_params": num_params,
  }

  # End wandb run
  wandb.finish()

All random seeds set to 51 for reproducibility


Epoch and start time: 0 und 1764573617.0311217
epoch 0 train loss: 8.515963849567232
epoch 0 valid loss: 8.233570241928101
Found and saved better weights for the model
Epoch and start time: 1 und 1764573624.7899916
epoch 1 train loss: 8.209987685793923
epoch 1 valid loss: 7.8153589248657225
Found and saved better weights for the model
Epoch and start time: 2 und 1764573634.77336
epoch 2 train loss: 7.531391893114362
epoch 2 valid loss: 7.271495199203491
Found and saved better weights for the model
Epoch and start time: 3 und 1764573643.4777575
epoch 3 train loss: 7.028215748923166
epoch 3 valid loss: 7.0777812004089355
Found and saved better weights for the model
Epoch and start time: 4 und 1764573654.0677788
epoch 4 train loss: 6.792626176561628
epoch 4 valid loss: 7.054382944107056
Found and saved better weights for the model
Epoch and start time: 5 und 1764573662.348445
epoch 5 train loss: 6.631522610073998
epoch 5 valid loss: 6.857812833786011
Found and saved better weights for the

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch_time_sec,▂▆▃▇▃█▅▁▅▂▄▂▄▂▁▂▂▁▁▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▇▆▅▅▅▄▄▄▄▃▃▃▃▃▂▂▂▁▁
valid_loss,█▆▄▃▃▂▂▁▁▁▁▁▁▂▂▂▂▂▃▃

0,1
epoch,20
epoch_time_sec,7.4402
lr,0.0001
max_gpu_mem_mb_epoch,1107.18457
model,Net
train_loss,4.37123
valid_loss,7.14445


Epoch and start time: 0 und 1764573789.9061744
epoch 0 train loss: 8.507612932296027
epoch 0 valid loss: 8.294523239135742
Found and saved better weights for the model
Epoch and start time: 1 und 1764573797.8143268
epoch 1 train loss: 8.34981784366426
epoch 1 valid loss: 8.084389305114746
Found and saved better weights for the model
Epoch and start time: 2 und 1764573805.707572
epoch 2 train loss: 7.910055228642055
epoch 2 valid loss: 7.4321009635925295
Found and saved better weights for the model
Epoch and start time: 3 und 1764573813.700791
epoch 3 train loss: 7.155026367732456
epoch 3 valid loss: 6.95517930984497
Found and saved better weights for the model
Epoch and start time: 4 und 1764573821.8236983
epoch 4 train loss: 6.675247714633033
epoch 4 valid loss: 6.647566604614258
Found and saved better weights for the model
Epoch and start time: 5 und 1764573830.836064
epoch 5 train loss: 6.2017669677734375
epoch 5 valid loss: 6.106414413452148
Found and saved better weights for the m

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch_time_sec,▂▂▂▂▅▂█▁▅▇▄▂▇▁▂▅▅▆▁▂
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,██▇▇▆▅▅▄▃▃▃▂▂▂▂▂▁▁▁▁
valid_loss,██▇▆▆▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁

0,1
epoch,20
epoch_time_sec,8.014
lr,0.0001
max_gpu_mem_mb_epoch,1062.55469
model,ConcatIntermediateNe...
train_loss,2.17871
valid_loss,3.31819


Epoch and start time: 0 und 1764573961.521104
epoch 0 train loss: 8.461432411557151
epoch 0 valid loss: 8.26976523399353
Found and saved better weights for the model
Epoch and start time: 1 und 1764573968.9938958
epoch 1 train loss: 8.315308480035691
epoch 1 valid loss: 7.995764970779419
Found and saved better weights for the model
Epoch and start time: 2 und 1764573976.7329223
epoch 2 train loss: 7.531715892610096
epoch 2 valid loss: 7.134623336791992
Found and saved better weights for the model
Epoch and start time: 3 und 1764573984.741451
epoch 3 train loss: 6.984503768739247
epoch 3 valid loss: 6.8982240676879885
Found and saved better weights for the model
Epoch and start time: 4 und 1764573992.270693
epoch 4 train loss: 6.630430584862118
epoch 4 valid loss: 6.493797063827515
Found and saved better weights for the model
Epoch and start time: 5 und 1764574000.9977586
epoch 5 train loss: 5.845290842510405
epoch 5 valid loss: 5.4144923210144045
Found and saved better weights for the 

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch_time_sec,▁▃▄▂▇▁▃▅▂█▄▃▁▆▂▂▅▁▃▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,██▇▇▆▆▅▄▃▃▂▂▂▂▂▂▁▁▁▁
valid_loss,██▇▆▆▄▃▂▂▂▁▁▂▁▁▁▁▁▁▁

0,1
epoch,20
epoch_time_sec,7.37403
lr,0.0001
max_gpu_mem_mb_epoch,1036.26074
model,MatmulIntermediateNe...
train_loss,0.97222
valid_loss,3.04706


Epoch and start time: 0 und 1764574121.6844091
epoch 0 train loss: 8.465805121830531
epoch 0 valid loss: 8.286694240570068
Found and saved better weights for the model
Epoch and start time: 1 und 1764574130.3891132
epoch 1 train loss: 8.411811419895717
epoch 1 valid loss: 8.186436462402344
Found and saved better weights for the model
Epoch and start time: 2 und 1764574139.7838795
epoch 2 train loss: 8.13121934164138
epoch 2 valid loss: 7.802686166763306
Found and saved better weights for the model
Epoch and start time: 3 und 1764574148.613451
epoch 3 train loss: 7.401416506086077
epoch 3 valid loss: 7.208476161956787
Found and saved better weights for the model
Epoch and start time: 4 und 1764574157.3967774
epoch 4 train loss: 6.7741608165559315
epoch 4 valid loss: 6.693150186538697
Found and saved better weights for the model
Epoch and start time: 5 und 1764574166.3995264
epoch 5 train loss: 6.339687006814139
epoch 5 valid loss: 6.282789707183838
Found and saved better weights for the

0,1
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
epoch_time_sec,▁▂▂▁▂▁▁▂▁▃█▂▃▅▆▆▁▄▃▁
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,███▇▆▅▅▄▄▃▃▂▂▂▂▂▁▁▁▁
valid_loss,██▇▆▆▅▅▄▃▃▂▂▂▁▂▁▁▁▁▁

0,1
epoch,20
epoch_time_sec,8.51392
lr,0.0001
max_gpu_mem_mb_epoch,1409.93359
model,LateNet
train_loss,2.80872
valid_loss,3.53529


In [42]:
set_seeds(SEED)

matmul_net2 = MatmulIntermediateNet(4, 4).to(device)
matmul_net2_opt = Adam(matmul_net2.parameters(), lr=LR)
model_save_path = STORAGE_PATH / "checkpoints/intermediate_fusion_hadamard.pth"

matmul_net2_train_losses, matmul_net2_valid_losses, epoch_times, max_gpu_mem_mb = train_model(
    model=matmul_net2,
    optimizer=matmul_net2_opt,
    input_fn=get_inputs,
    epochs=EPOCHS,
    loss_fn=loss_func,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_save_path=model_save_path,
    target_idx=-1,   # last element in batch is target
    log_to_wandb=True,
    model_name="intermediate_fusion_hadamard"
)


All random seeds set to 51 for reproducibility
Epoch and start time: 0 und 1764574313.9564257
epoch 0 train loss: 8.448761167980376
epoch 0 valid loss: 8.241648721694947
Found and saved better weights for the model


Error: You must call wandb.init() before wandb.log()

In [None]:
set_seeds(SEED)
model_save_path = STORAGE_PATH / "checkpoints/intermediate_fusion_add.pth"
add_net = AddIntermediateNet(4, 4).to(device)
add_net_opt = Adam(add_net.parameters(), lr=LR)
add_net_train_losses, add_net_valid_losses, epoch_times, max_gpu_mem_mb = train_model(
    model=add_net,
    optimizer=add_net_opt,
    input_fn=get_inputs,
    epochs=EPOCHS,
    loss_fn=loss_func,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_save_path=model_save_path,
    target_idx=-1,   # last element in batch is target
    log_to_wandb=True,
    model_name="intermediate_fusion_add"
)

In [None]:
single_mode_data = pd.read_csv('https://raw.githubusercontent.com/andandandand/practical-computer-vision/refs/heads/main/artifacts/cubes_only_single_mode_results.csv').values

plot_x = range(len(single_mode_data))
plt.xlabel("Epoch")
plt.ylabel("Average Loss")
plt.plot(plot_x, add_net_valid_losses, "goldenrod", label = "intermediate_fusion_add")
plt.plot(plot_x, matmul_net2_valid_losses, "green", label = "intermediate_fusion_hadamard")
plt.legend()
plt.show()

# Evaluation

In [None]:
def plot_losses(loss_dict, title="Validation Loss per Model", ylabel="Loss", xlabel="Epoch"):
    """
    loss_dict: dict of { "model_name": list_of_losses }
               Every list must have the same length.

    Example:
        loss_dict = {
            "EarlyNet": early_valid_losses,
            "LateNet": late_valid_losses,
            "CatNet": cat_net_valid_losses,
            "MatmulNet": matmul_net_valid_losses,
        }
    """

    plt.figure(figsize=(8,5))

    # Auto-generate x-axis based on first model
    any_key = next(iter(loss_dict))
    epochs = range(len(loss_dict[any_key]))

    for model_name, losses in loss_dict.items():
        plt.plot(epochs, losses, label=model_name)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.show()

In [None]:
loss_dict = {
    "EarlyNet": metrics["early_fusion"]["valid_losses"],
    "CatNet": metrics["intermediate_fusion_concat"]["valid_losses"],
    "MatmulNet": metrics["intermediate_fusion_"]["valid_losses"],
    "AddNet": add_net_valid_losses,
    "MatmulNet2": matmul_net2_valid_losses,
    "LateNet": metrics["late_fusion"]["valid_losses"]
}

plot_losses(loss_dict, title="Validation Loss Comparison")

In [None]:
# compute avg_epoch_time
avg_epoch_time = sum(epoch_times) / len(epoch_times)

In [None]:
import numpy as np
import pandas as pd

# Optional: nice display names for each key in `metrics`
name_map = {
    "early_fusion": "Early Fusion",
    "late_fusion": "Late Fusion",
    "intermediate_fusion_concat": "Intermediate (Concat)",
    "intermediate_fusion_matmul": "Intermediate (Multiplicative)",
    "intermediate_fusion_add": "Intermediate (Add)",   # if you have it
}

rows = []

for key, m in metrics.items():
    avg_train_loss = float(np.mean(m["train_losses"]))
    avg_valid_loss = float(np.mean(m["valid_losses"]))
    avg_epoch_time = float(np.mean(m["epoch_times"]))

    rows.append({
        "Fusion Strategy": name_map.get(key, key),
        "Avg Valid Loss": avg_valid_loss,
        "Best Valid Loss": float(m["best_valid_loss"]),
        "Num of params": int(m["num_params"]),
        "Avg time per epoch (min:s)": avg_epoch_time,
        "GPU Memory (MB, max)": float(m["max_gpu_mem_mb"]),
    })

df_comparison = pd.DataFrame(rows)
df_comparison


In [None]:
# logs the comparison table to wandb
wandb.init(
    project="cilp-extended-assessment",   # your project name
    name="fusion_comparison_all",
    job_type="analysis",
)

fusion_comparison_table = wandb.Table(dataframe=df_comparison)
wandb.log({"fusion_comparison": fusion_comparison_table})

wandb.finish()

**When to use**

**Early Fusion:**
* Aligned, closely related low-level modalities and comparable features
* Simple setup; avoid if sensors are noisy

**Intermediate Fusion:**
* Modalities with different structure that benefit from separate early processing in order to learn modality-specific features   
* best overall balance of performance and flexibility

**Late Fusion:**
* Strong, independent unimodal predictors, to combine their strengths
* ideal for heterogeneous or missing modalities
* robust fallback when one modality fails