# Model Architecture

```
RGB ----> RGB Encoder ----\
                            ----> Fusion ---> Classifier ---> cube/sphere
LiDAR -> LiDAR Encoder ----/
```



**The Architecture Flow:**

```
RGB Input (4ch)       LiDAR Input (4ch)
      │                     │
[RGB Encoder]         [XYZ Encoder]    <-- Learn specific features independently
      │                     │
  RGB Features          XYZ Features   <-- (e.g. 128 channels each)
      └──────────┬──────────┘
                 │
           Concatenation               <-- Fuse at the "Feature Level"
                 │
         [Classifier Head]             <-- Learn relationships between features
                 │
         Output (cube/sphere)
```

Multimodal fusion describes how information from different modalities (e.g., RGB and LiDAR) is combined.
There are three canonical levels of fusion:

* **Early fusion:** combine the raw inputs or very low-level features before any deep processing.
* **Intermediate fusion:** combine learned feature representations produced by modality-specific encoders.
* **Late fusion:** combine decisions or latent vectors at the end of the pipeline. This resembles an ensemble in which each unimodal model has a weighted vote in the final result.

Each level has different strengths and limitations.
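To make the three levels concrete, here is a minimal sketch in plain Python. Short lists stand in for tensors, and `encoder`, `classifier`, the scale factors, and the vote weights are hypothetical placeholders rather than anything used in this notebook. Only the position of the fusion step differs between the three variants.

```python
# Toy stand-ins: each "modality" is a short feature list instead of a tensor.
rgb = [0.2, 0.8]
lidar = [1.5, 0.5]


def encoder(x, scale):
    """Hypothetical one-layer encoder: scale every input value."""
    return [scale * v for v in x]


def classifier(features):
    """Hypothetical classifier: the score is the sum of the features."""
    return sum(features)


# Early fusion: concatenate the raw inputs, then run one shared network.
early = classifier(encoder(rgb + lidar, scale=1.0))

# Intermediate fusion: encode each modality separately, fuse the features,
# then classify once.
intermediate = classifier(encoder(rgb, scale=2.0) + encoder(lidar, scale=0.5))

# Late fusion: run a complete model per modality and combine the decisions
# with a weighted vote (weights chosen arbitrarily here).
score_rgb = classifier(encoder(rgb, scale=2.0))
score_lidar = classifier(encoder(lidar, scale=0.5))
late = 0.7 * score_rgb + 0.3 * score_lidar

print(early, intermediate, late)
```

In a real model the encoders and the classifier are learned networks; the sketch only shows where in the pipeline the modalities meet.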

# Setup

In [None]:
%%capture
%pip install wandb weave

In [None]:
%%capture
%pip install fiftyone==1.10.0 sympy==1.12 torch==2.9.0 torchvision==0.20.0 numpy open-clip-torch

In [None]:
import os
from pathlib import Path
from google.colab import userdata
import time

from PIL import Image
from tqdm import tqdm
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from torch.utils.data import Dataset, DataLoader, Subset
from torch.optim import Adam

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as colors
import pandas as pd

import wandb
import cv2
import albumentations as A

In [None]:
from google.colab import drive
drive.mount('/content/drive')

STORAGE_PATH = Path("/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/")

Mounted at /content/drive


In [None]:
DATA_PATH = STORAGE_PATH / "multimodal_training_workshop/data"
print(f"Data path: {DATA_PATH}")
print(f"Data path exists: {DATA_PATH.exists()}")

Data path: /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/multimodal_training_workshop/data
Data path exists: True


In [None]:
# !rsync -ah --progress "/content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/data" "/content/data/"



In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.cuda.is_available()

True

In [None]:
SEED = 51
NUM_WORKERS = os.cpu_count()  # Number of CPU cores

BATCH_SIZE = 32
IMG_SIZE = 64

CLASSES = ["cubes", "spheres"]
LABEL_MAP = {"cubes": 0, "spheres": 1}

In [None]:
## Antonio
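VALID_BATCHES = 10   # number of batches held out for validation
N = 6500             # number of samples used in total (training + validation)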

# Integrate Wandb

In [None]:
# Load W&B API key from .env file and make it available as env variable
# from dotenv import load_dotenv
# load_dotenv()  # loads .env automatically

# os.environ["WANDB_API_KEY"]

In [None]:
# Load W&B API key from Colab Secrets and make it available as env variable
wandb_key = userdata.get('WANDB_API_KEY')
os.environ["WANDB_API_KEY"] = wandb_key

In [None]:
wandb.login()

wandb: Currently logged in as: michele-marschner (michele-marschner-university-of-potsdam) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin


True

In [None]:
def init_wandb(model, fusion_name, num_params, opt_name, batch_size=BATCH_SIZE, epochs=15):
  config = {
    # "embedding_size": embedding_size,      ## TODO: does this change? Is it available for the fusion models?
    "optimizer_type": opt_name,
    "fusion_strategy": fusion_name,
    "model_architecture": model.__class__.__name__,
    "batch_size": batch_size,
    "num_epochs": epochs,
    "num_parameters": num_params
  }

  run = wandb.init(
    project="cilp-extended-assessment",
    name=f"{fusion_name}_run",
    config=config,
    reinit='finish_previous',                           # allows multiple runs in one script
  )

  return run

# Reproducibility

In [None]:
def set_seeds(seed=SEED):
    """
    Set seeds for complete reproducibility across all libraries and operations.

    Args:
        seed (int): Random seed value
    """
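    # Environment variables: PYTHONHASHSEED only affects Python processes started afterwards;
    # CUBLAS_WORKSPACE_CONFIG should be set before CUDA is first used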
    os.environ['PYTHONHASHSEED'] = str(seed)
    os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'

    # Python random module
    random.seed(seed)

    # NumPy
    np.random.seed(seed)

    # PyTorch CPU
    torch.manual_seed(seed)

    # PyTorch GPU (all devices)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # For multi-GPU setups

        # CUDA deterministic operations
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    # OpenCV
    cv2.setRNGSeed(seed)

    # Albumentations (for data augmentation)
    try:
        A.seed_everything(seed)
    except AttributeError:
        # Older versions of albumentations
        pass

    # PyTorch deterministic algorithms (may impact performance)
    try:
        torch.use_deterministic_algorithms(True)
    except RuntimeError:
        # Some operations don't have deterministic implementations
        print("Warning: Some operations may not be deterministic")

    print(f"All random seeds set to {seed} for reproducibility")



# Usage: Call this function at the beginning and before each training phase
set_seeds(SEED)

# Additional reproducibility considerations:

def create_deterministic_training_dataloader(dataset, batch_size, shuffle=True, **kwargs):
    """
    Create a DataLoader with deterministic behavior.

    Args:
        dataset: PyTorch Dataset instance
        batch_size: Batch size
        shuffle: Whether to shuffle data
        **kwargs: Additional DataLoader arguments

    Returns:
        Training DataLoader with reproducible behavior
    """
    # Use a generator with fixed seed for reproducible shuffling
    generator = torch.Generator()
    generator.manual_seed(SEED)

    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        generator=generator if shuffle else None,
        **kwargs
    )



All random seeds set to 51 for reproducibility
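The idea behind passing a dedicated, seeded generator to the training DataLoader can be illustrated without PyTorch. The sketch below uses Python's `random.Random` as a stand-in for `torch.Generator`; the helper `seeded_order` is hypothetical and only demonstrates that a private RNG with a fixed seed yields the same shuffle order on every call, independent of the global random state.

```python
import random

SEED = 51  # same value as the notebook's SEED


def seeded_order(n, seed):
    """Return the indices 0..n-1 shuffled with a private, seeded RNG."""
    rng = random.Random(seed)   # independent of the global random state
    order = list(range(n))
    rng.shuffle(order)
    return order


first = seeded_order(10, SEED)
random.random()                  # touching the global RNG has no effect
second = seeded_order(10, SEED)

print(first == second)                    # True: same seed, same order
print(sorted(first) == list(range(10)))   # True: still a permutation
```

The same pattern appears in `AssessmentXYZADataset`, which shuffles its sample list with `random.Random(SEED)` so that the training and validation slices are cut from an identical ordering.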


# Utility Functions

In [None]:
def format_time(seconds):
    m = int(seconds // 60)
    s = int(seconds % 60)
    return f"{m:02d}m {s:02d}s"

In [None]:
def hex_to_RGB(hex_str):
    """ #FFFFFF -> [255,255,255]"""
    #Pass 16 to the integer function for change of base
    return [int(hex_str[i:i+2], 16) for i in range(1,6,2)]


def get_color_gradient(c1, c2, c3, n1, n2):
    """
    Given three hex colors, returns a two-segment color gradient:
    n1 colors from c1 to c2, followed by n2 colors from c2 to c3.
    """
    c1_rgb = np.array(hex_to_RGB(c1))/255
    c2_rgb = np.array(hex_to_RGB(c2))/255
    c3_rgb = np.array(hex_to_RGB(c3))/255
    mix_pcts_c1_c2 = [x/(n1-1) for x in range(n1)]
    mix_pcts_c2_c3 = [x/(n2-1) for x in range(n2)]
    rgb_c1_c2 = [((1-mix)*c1_rgb + (mix*c2_rgb)) for mix in mix_pcts_c1_c2]
    rgb_c2_c3 = [((1-mix)*c2_rgb + (mix*c3_rgb)) for mix in mix_pcts_c2_c3]
    rgb_colors = rgb_c1_c2 + rgb_c2_c3
    return ["#" + "".join([format(int(round(val*255)), "02x") for val in item]) for item in rgb_colors]


cmap = colors.ListedColormap(get_color_gradient("#000000", "#76b900", "#f1ffd9", 64, 128))


In [None]:
def get_outputs(model, batch, inputs_idx):
    inputs = batch[inputs_idx].to(device)
    target = batch[-1].to(device)
    outputs = model(inputs)
    return outputs, target

In [None]:
def get_torch_xyza(lidar_depth, azimuth, zenith):
    x = lidar_depth * torch.sin(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    y = lidar_depth * torch.cos(-azimuth[:, None]) * torch.cos(-zenith[None, :])
    z = lidar_depth * torch.sin(-zenith[None, :])
    a = torch.where(lidar_depth < 50.0, torch.ones_like(lidar_depth), torch.zeros_like(lidar_depth))
    xyza = torch.stack((x, y, z, a))
    return xyza
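As a worked example of the conversion in `get_torch_xyza`, the sketch below applies the same formulas to a single ray using only the `math` module. It assumes the angles are given in radians and that the fourth channel `a` acts as a mask for returns closer than 50 units; the helper name `ray_to_xyza` and the parameter `max_range` are illustrative and not part of the notebook.

```python
import math


def ray_to_xyza(depth, azimuth, zenith, max_range=50.0):
    """Convert one LiDAR ray to (x, y, z, a) with the same sign convention
    as get_torch_xyza. Angles are assumed to be in radians."""
    x = depth * math.sin(-azimuth) * math.cos(-zenith)
    y = depth * math.cos(-azimuth) * math.cos(-zenith)
    z = depth * math.sin(-zenith)
    a = 1.0 if depth < max_range else 0.0   # 1 where depth < max_range
    return x, y, z, a


# A ray with both angles zero lies on the y axis.
x, y, z, a = ray_to_xyza(depth=10.0, azimuth=0.0, zenith=0.0)
print(x, y, z, a)

# The Euclidean norm of (x, y, z) recovers the depth for any angles.
x2, y2, z2, a2 = ray_to_xyza(depth=60.0, azimuth=0.3, zenith=-0.2)
print(math.sqrt(x2 ** 2 + y2 ** 2 + z2 ** 2), a2)
```

The batched version broadcasts `azimuth[:, None]` against `zenith[None, :]`, so every pixel of the depth map receives its own pair of angles.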

In [None]:
def format_positions(positions):
    return ['{0: .8f}'.format(x) for x in positions]

In [None]:
def train_model(model, optimizer, input_fn, loss_fn, epochs, train_dataloader, val_dataloader, model_save_path, target_idx=-1, log_to_wandb=False, model_name=None):
    train_losses = []
    valid_losses = []
    epoch_times = []

    best_val_loss = float('inf')
    best_model = None

    # for GPU memory tracking
    max_gpu_mem_mb = 0.0
    use_cuda = torch.cuda.is_available()

    if use_cuda:
        torch.cuda.reset_peak_memory_stats()

    for epoch in tqdm(range(epochs)):
        start_time = time.time()                  # to track the train time per model
        print(f"Epoch and start time: {epoch} and {start_time}")
        model.train()
        train_loss = 0
        for step, batch in enumerate(train_dataloader):

            # input_fn(batch) below selects the model inputs; no separate unpacking is needed here

            optimizer.zero_grad()
            target = batch[target_idx].to(device)
            outputs = model(*input_fn(batch))

            loss = loss_fn(outputs, target)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss = train_loss / (step + 1)
        train_losses.append(train_loss)
        print_loss(epoch, train_loss, outputs, target, is_train=True)

        # ----- validation -----
        model.eval()
        valid_loss = 0
        with torch.no_grad():
          for step, batch in enumerate(val_dataloader):
              target = batch[target_idx].to(device)
              outputs = model(*input_fn(batch))
              valid_loss += loss_fn(outputs, target).item()
        valid_loss = valid_loss / (step + 1)
        valid_losses.append(valid_loss)
        print_loss(epoch, valid_loss, outputs, target, is_train=False)

        if valid_loss < best_val_loss:
          best_val_loss = valid_loss
          best_model = model
          # Save the best model
          torch.save(best_model.state_dict(), model_save_path)
          print('Found and saved better weights for the model')

        # timing
        epoch_time = time.time() - start_time
        epoch_time_formatted = format_time(epoch_time)
        epoch_times.append(epoch_time_formatted)

        # GPU memory
        if use_cuda:
            gpu_mem_mb = torch.cuda.max_memory_allocated() / (1024 ** 2)
            max_gpu_mem_mb = max(max_gpu_mem_mb, gpu_mem_mb)

        # wandb logging
        if log_to_wandb:
            wandb.log(
                {
                    "model": model.__class__.__name__,
                    "epoch": epoch + 1,
                    "train_loss": train_loss,
                    "valid_loss": valid_loss,
                    "lr": optimizer.param_groups[0]["lr"],
                    "epoch_time": epoch_time_formatted,
                    "max_gpu_mem_mb_epoch": gpu_mem_mb if use_cuda else 0.0,
                }
            )

    return train_losses, valid_losses, epoch_times, max_gpu_mem_mb
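One detail of the best-model bookkeeping is worth spelling out: `best_model = model` stores a reference, not a snapshot. The weights written to `model_save_path` are correct because they are saved at the moment of improvement, but the in-memory `best_model` keeps changing as training continues. The sketch below illustrates the difference with a plain dictionary standing in for the model parameters; all values are made up.

```python
import copy

weights = {"w": 0.0}            # stand-in for a model's parameters
val_losses = [0.9, 0.5, 0.7]    # made-up validation losses per epoch

best_loss = float("inf")
best_ref = None                 # plain reference, like `best_model = model`
best_snapshot = None            # independent copy of the parameters

for epoch, loss in enumerate(val_losses):
    weights["w"] = float(epoch + 1)          # "training" changes the weights
    if loss < best_loss:
        best_loss = loss
        best_ref = weights
        best_snapshot = copy.deepcopy(weights)

print(best_loss)        # 0.5
print(best_ref)         # {'w': 3.0}  (follows the live object)
print(best_snapshot)    # {'w': 2.0}  (weights from the best epoch)
```

If the best weights are needed in memory after training, reloading the saved state dict or keeping a deep copy would be two possible options.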

In [None]:
def print_loss(epoch, loss, outputs, target, is_train=True, is_debug=False):
    loss_type = "train loss:" if is_train else "valid loss:"
    print("epoch", str(epoch), loss_type, str(loss))
    if is_debug:
        print("example pred:", format_positions(outputs[0].tolist()))
        print("example real:", format_positions(target[0].tolist()))

In [None]:
def plot_losses(losses, title="Training & Validation Loss Comparison", figsize=(10,6)):
    plt.figure(figsize=figsize)

    for model_name, log in losses.items():
        train = log["train_losses"]
        valid = log["valid_losses"]

        # plot train + valid with different line styles
        plt.plot(train, label=f"{model_name} - train", linewidth=2)
        plt.plot(valid, label=f"{model_name} - valid", linestyle="--", linewidth=2)

    plt.title(title, fontsize=16)
    plt.xlabel("Epochs", fontsize=14)
    plt.ylabel("Loss", fontsize=14)
    plt.grid(True, alpha=0.3)
    plt.legend()
    plt.tight_layout()
    plt.show()

# Load and prepare Data

In [None]:
class ReplicatorDataset(Dataset):
    def __init__(self, root_dir, start_idx, stop_idx, transform):
        self.root_dir = Path(root_dir)

        # indices this dataset will cover
        self.indices = list(range(start_idx, stop_idx))

        # positions: (N, 3) or (N, 4) depending on file
        all_positions = pd.read_csv(
            "https://github.com/andandandand/practical-computer-vision/blob/main/artifacts/positions.csv?raw=1"
        ).values
        self.positions = all_positions[start_idx:stop_idx]

        # azimuth / zenith loaded once
        azimuth_path = self.root_dir / "azimuth.npy"
        zenith_path = self.root_dir / "zenith.npy"

        if not azimuth_path.exists():
            raise FileNotFoundError(f"azimuth.npy not found at {azimuth_path}")
        if not zenith_path.exists():
            raise FileNotFoundError(f"zenith.npy not found at {zenith_path}")

        azimuth = np.load(azimuth_path)
        zenith = np.load(zenith_path)
        # keep them as CPU tensors; move to device in training if needed
        self.azimuth = torch.from_numpy(azimuth)      # shape (H,)
        self.zenith = torch.from_numpy(zenith)        # shape (W,)

        # dirs
        self.rgb_dir = self.root_dir / "rgb"
        self.lidar_dir = self.root_dir / "lidar"

        self.transform = transform

    def __len__(self):
        return len(self.indices)

    def __getitem__(self, idx):
        # map dataset index -> real frame index
        frame_idx = self.indices[idx]
        file_number = f"{frame_idx:04d}"

        # --- RGB ---
        rgb_path = self.rgb_dir / f"{file_number}.png"
        rgb_img = Image.open(rgb_path)
        rgb_tensor = self.transform(rgb_img)   # still on CPU

        # --- LiDAR depth ---
        lidar_path = self.lidar_dir / f"{file_number}.npy"
        lidar_depth = np.load(lidar_path)      # (H, W)
        lidar_depth = torch.from_numpy(lidar_depth).to(torch.float32)  # CPU

        # --- XYZA ---
        lidar_xyza = get_torch_xyza(lidar_depth, self.azimuth, self.zenith)  # (4, H, W)

        # --- position ---
        position_np = self.positions[idx]     # numpy row
        position = torch.from_numpy(position_np).to(torch.float32)  # CPU

        return rgb_tensor, lidar_xyza, position

In [None]:
class AssessmentXYZADataset(Dataset):
    def __init__(self, root_dir, start_idx=0, end_idx=None,
                 transform_rgb=None, transform_lidar=None, shuffle=True):

        self.root_dir = Path(root_dir)
        self.transform_rgb = transform_rgb
        self.transform_lidar = transform_lidar

        self.classes = ["cubes", "spheres"]
        self.label_map = {"cubes": 0, "spheres": 1}

        samples = []

        print(f"Scanning dataset in {root_dir}...")
        for cls in self.classes:
            cls_dir = self.root_dir / cls
            rgb_dir = cls_dir / "rgb"
            lidar_dir = cls_dir / "lidar_xyza"

            rgb_files = sorted(rgb_dir.glob("*.png"))
            print(f"{cls}: {len(rgb_files)} RGB files found. Matching XYZA...")

            for rgb_path in tqdm(rgb_files, desc=f"{cls} matching", leave=False):
                stem = rgb_path.stem
                lidar_path = lidar_dir / f"{stem}.npy"
                if lidar_path.exists():
                    samples.append({
                        "rgb": rgb_path,
                        "lidar_xyza": lidar_path,
                        "label": self.label_map[cls],
                    })

        if shuffle:
            rng = random.Random(SEED)
            rng.shuffle(samples)

        if end_idx is None:
            end_idx = len(samples)
        self.samples = samples[start_idx:end_idx]

        print(f"Preloading LiDAR XYZA tensors into RAM...")
        self.lidar_tensors = []
        for item in tqdm(self.samples, desc="Loading XYZA", leave=False):
            lidar_np = np.load(item["lidar_xyza"])        # (4, H, W)
            lidar_t  = torch.from_numpy(lidar_np).float() # CPU tensor
            self.lidar_tensors.append(lidar_t)

        print(
            f"Dataset ready: {len(self.samples)} samples loaded.\n"
            f"Slice [{start_idx}:{end_idx}]"
        )

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        item  = self.samples[idx]
        lidar = self.lidar_tensors[idx]   # already a tensor

        # RGB
        rgb = Image.open(item["rgb"])
        if self.transform_rgb:
            rgb = self.transform_rgb(rgb)

        if self.transform_lidar:
            lidar = self.transform_lidar(lidar)

        label = torch.tensor(item["label"], dtype=torch.long)
        return rgb, lidar, label


In [None]:
## Old version: averages per-batch statistics, so the returned std only approximates the dataset std
def compute_mean_std_chatty(dataset):
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)

    mean = 0.
    std = 0.
    total = 0

    for images, _, _ in tqdm(loader, desc="Computing mean/std"):
        images = images.float()       # B, C, H, W
        batch_size = images.size(0)

        # compute mean over batch (channels only!)
        mean += images.mean(dim=[0, 2, 3]) * batch_size

        # compute std over batch
        std += images.std(dim=[0, 2, 3]) * batch_size

        total += batch_size

    mean /= total
    std /= total

    return mean, std


In [None]:
## Link: https://www.kozodoi.me/blog/computing-mean-std-in-image-dataset
def compute_mean_std(dataset):
    loader = DataLoader(dataset, batch_size=64, shuffle=False, num_workers=2)

    psum = None
    psum_sq = None
    total_pixels = 0

    for batch in tqdm(loader):
        imgs = batch[0]                       # (B, C, H, W), C can be 3 or 4
        # If first batch: initialize accumulators with correct channel size
        if psum is None:
            C = imgs.size(1)
            psum = torch.zeros(C, device=imgs.device)
            psum_sq = torch.zeros(C, device=imgs.device)

        # accumulate sums
        B, C, H, W = imgs.shape
        psum += imgs.sum(dim=[0, 2, 3])
        psum_sq += (imgs ** 2).sum(dim=[0, 2, 3])
        total_pixels += B * H * W

    mean = psum / total_pixels
    var = psum_sq / total_pixels - mean ** 2
    std = torch.sqrt(var)

    # move to CPU for transforms.Normalize
    return mean.cpu(), std.cpu()
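The difference between the two statistics functions can be checked on toy data. The sketch below uses only the standard library, with two small "batches" of single-channel pixel values that are invented for illustration. It compares the average of per-batch standard deviations with the sum / sum-of-squares approach, which relies on Var[x] = E[x²] − (E[x])². For simplicity it uses the population standard deviation throughout, whereas `torch.std` applies Bessel's correction by default.

```python
import math
import statistics

# Two "batches" of pixel values for a single channel (toy data).
batches = [[0.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 1.0]]
all_pixels = [p for batch in batches for p in batch]

# Approach 1: average the per-batch standard deviations.
avg_of_batch_stds = sum(statistics.pstdev(b) for b in batches) / len(batches)

# Approach 2: accumulate the sum and the sum of squares, then use
# var = E[x^2] - (E[x])^2 over the whole dataset.
n = len(all_pixels)
psum = sum(all_pixels)
psum_sq = sum(p * p for p in all_pixels)
mean = psum / n
std = math.sqrt(psum_sq / n - mean ** 2)

print(avg_of_batch_stds)                    # 0.25
print(std, statistics.pstdev(all_pixels))   # both are the true dataset std
```

Averaging per-batch values ignores how the batch means differ from each other, which is why the first approach underestimates the spread in this example.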

In [None]:
stats_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),  # [0,1], 4 channels
])

In [None]:
#root = DATA_PATH / "replicator_data_cubes"

#stats_dataset = ReplicatorDataset(
#    root_dir=root,
#    start_idx=0,
#    stop_idx=N,          # or e.g. 1000 to subsample
#    transform=stats_transforms,
#)

In [None]:
# root = STORAGE_PATH / "data"

# stats_dataset = AssessmentXYZADataset(
#    root_dir=root,
#    start_idx=0,
#    end_idx=None,          # or e.g. 1000 to subsample
#    transform_rgb=stats_transforms,
# )

AssessmentXYZADataset: 12500 samples loaded from /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/data (slice 0:12500)


In [None]:
#mean, std = compute_mean_std(stats_dataset)
# print(mean, std)

In [None]:
#from torch.utils.data import Subset

#NUM_SAMPLES_FOR_STATS = 2000  # e.g. 2k out of 12k

#N = len(stats_dataset)
#num = min(NUM_SAMPLES_FOR_STATS, N)
#indices = np.random.choice(N, size=num, replace=False)

#subset_for_stats = Subset(stats_dataset, indices)


#mean, std = compute_mean_std_chatty(subset_for_stats)
#print(mean, std)

Computing mean/std: 100%|██████████| 32/32 [03:43<00:00,  6.99s/it]

tensor([0.0051, 0.0052, 0.0051, 1.0000]) tensor([5.8023e-02, 5.8933e-02, 5.8108e-02, 2.4509e-07])





In [None]:
img_transforms = transforms.Compose([
    transforms.ToImage(),
    transforms.Resize(IMG_SIZE),
    transforms.ToDtype(torch.float32, scale=True),   # scales data into [0, 1]
    #transforms.Normalize(([0.0138, 0.0137, 0.0132, 1.0000]), ([9.4118e-02, 9.3919e-02, 9.1684e-02, 4.9019e-07])),    ## replicator dataset
    # Note: the alpha channel is almost constant, so its std is close to 0 and normalization strongly amplifies tiny deviations
    transforms.Normalize(([0.0051, 0.0052, 0.0051, 1.0000]), ([5.8023e-02, 5.8933e-02, 5.8108e-02, 2.4509e-07]))     ## assessment dataset
])

In [None]:
def get_dataloaders(root_dir):
    train_data = AssessmentXYZADataset(root_dir, 0, N-VALID_BATCHES*BATCH_SIZE, img_transforms)
    train_dataloader = create_deterministic_training_dataloader(train_data, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
    valid_data = AssessmentXYZADataset(root_dir, N-VALID_BATCHES*BATCH_SIZE, N, img_transforms)
    val_dataloader = DataLoader(valid_data, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)
    return train_data, train_dataloader, valid_data, val_dataloader
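The slice boundaries used in `get_dataloaders` follow from simple arithmetic, sketched below. The helper `split_sizes` is hypothetical; it reproduces the index computation and the effect of `drop_last=True` on the number of training batches.

```python
def split_sizes(n, valid_batches, batch_size):
    """Return (train_samples, valid_samples, train_batches) for a split that
    reserves the last valid_batches * batch_size samples for validation."""
    valid_samples = valid_batches * batch_size
    train_samples = n - valid_samples
    # drop_last=True discards a final incomplete batch
    train_batches = train_samples // batch_size
    return train_samples, valid_samples, train_batches


# Values from the configuration cells: N = 6500, VALID_BATCHES = 10, BATCH_SIZE = 32
print(split_sizes(6500, 10, 32))    # (6180, 320, 193)

# The logged slices [0:12179] and [12179:12499] correspond to N = 12499
print(split_sizes(12499, 10, 32))   # (12179, 320, 380)
```

With N = 6500, the training slice is `[0:6180]` and the validation slice is `[6180:6500]`; four training samples per epoch are left out because 6180 is not a multiple of 32.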


In [None]:
train_data, train_dataloader, valid_data, val_dataloader = get_dataloaders(str(STORAGE_PATH / "data"))
#train_data, train_dataloader, valid_data, val_dataloader = get_dataloaders("/content/data")

for i, sample in enumerate(train_data):
    print(i, *(x.shape for x in sample))
    break

Scanning dataset in /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/data...
cubes: 2501 RGB files found. Matching XYZA...
spheres: 9999 RGB files found. Matching XYZA...
Preloading LiDAR XYZA tensors into RAM...
Dataset ready: 12179 samples loaded.
Slice [0:12179]
Scanning dataset in /content/drive/MyDrive/Colab Notebooks/Applied Computer Vision/Applied-Computer-Vision-Projects/Multimodal_Learning_02/data...
cubes: 2501 RGB files found. Matching XYZA...
spheres: 9999 RGB files found. Matching XYZA...
Preloading LiDAR XYZA tensors into RAM...
Dataset ready: 320 samples loaded.
Slice [12179:12499]
0 torch.Size([4, 64, 64]) torch.Size([4, 64, 64]) torch.Size([])




# Create Models

Take the EmbedderMaxPool architecture from the workshop and turn it into an encoder that outputs an embedding instead of 9 positions.

## Early Fusion Model

**Concept:** Fuse modalities before any deep processing, usually by concatenating channels or inputs.

```
input = concat(RGB, XYZ)  → shape (8, H, W)
-> shared CNN processes everything together
```



**Advantages:**

* **Captures Early Cross-Modal Interactions:** Learns joint low-level correlations directly from raw signals.
* **Simple & Lightweight**: Easiest fusion method to implement; minimal architectural overhead.
* **Effective with Perfect Alignment:** Works well when modalities are tightly synchronized and spatially aligned.

**Limitations:**

* **Noise Sensitivity:** One noisy or corrupted modality directly contaminates the shared feature space.
* **Strict Alignment Requirement:** Modalities must have matching spatial resolution, alignment, and synchronization.
* **Feature Space Mismatch:** Raw modalities differ in scale, units, and distribution; one modality can dominate without careful normalization.
* **High Input Dimensionality:** Channel concatenation increases the input size and can require more data and compute to train effectively.
* **Limited Flexibility:** Assumes combining low-level signals is beneficial; may underperform when modalities carry different types of information.
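The channel concatenation behind early fusion, together with its strict alignment requirement, can be sketched with nested lists in place of tensors. The helpers `shape` and `concat_channels` are illustrative only; in PyTorch the corresponding operation would be `torch.cat` along the channel dimension.

```python
def shape(x):
    """Return (C, H, W) of a nested list."""
    return (len(x), len(x[0]), len(x[0][0]))


def concat_channels(a, b):
    """Concatenate two (C, H, W) nested lists along the channel axis."""
    if shape(a)[1:] != shape(b)[1:]:
        raise ValueError("spatial sizes must match for early fusion")
    return a + b


H, W = 2, 3
rgb = [[[0.5] * W for _ in range(H)] for _ in range(4)]      # 4 channels
xyza = [[[10.0] * W for _ in range(H)] for _ in range(4)]    # 4 channels

fused = concat_channels(rgb, xyza)
print(shape(fused))   # (8, 2, 3)
```

The toy values 0.5 and 10.0 also hint at the feature-space mismatch listed above: without normalization, the modality with the larger numeric range can dominate the first convolution.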

In [None]:
output_dim = 2

class EmbedderMaxPool(nn.Module):
    """
    Embedder where all spatial downsampling is done
    via MaxPool2d.
    """
    def __init__(self, in_ch, feature_dim=200):
        super().__init__()
        kernel_size = 3
        self.conv1 = nn.Conv2d(in_ch, 50, kernel_size, padding=1)
        self.conv2 = nn.Conv2d(50, 100, kernel_size, padding=1)
        self.conv3 = nn.Conv2d(100, feature_dim, kernel_size, padding=1)
        self.pool = nn.MaxPool2d(2)
        # Three 2x2 poolings reduce a 64x64 input to 8x8
        self.flatten_dim = feature_dim * 8 * 8


    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = self.pool(F.relu(self.conv3(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch

        return x
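The value of `flatten_dim` follows from the spatial size after three pooling steps. The small helper below, a hypothetical `compute_flatten_dim` that is not part of the notebook, reproduces that arithmetic under the assumption that the 3x3 convolutions with `padding=1` keep height and width unchanged and that each `MaxPool2d(2)` halves them, rounding down.

```python
def compute_flatten_dim(img_size, feature_dim, num_pools=3, pool=2):
    """Flattened feature size after num_pools MaxPool2d(pool) layers
    applied to a square img_size x img_size input."""
    side = img_size
    for _ in range(num_pools):
        side //= pool            # MaxPool2d(2) floors odd sizes
    return feature_dim * side * side


print(compute_flatten_dim(64, 200))    # 12800
print(compute_flatten_dim(128, 200))   # 51200
```

For the notebook's `IMG_SIZE = 64`, the spatial size goes 64 → 32 → 16 → 8, which gives 200 · 8 · 8 = 12800 features per encoder. A different input size would require a different `flatten_dim`.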

In [None]:
class FullyConnectedHead(nn.Module):
    """
    Fully connected layers that take flattened features and predict class logits (cube vs. sphere).
    """
    def __init__(self, input_dim, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, 1000)
        self.fc2 = nn.Linear(1000, 100)
        self.fc3 = nn.Linear(100, output_dim)

    def forward(self, x):
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)

        return x

In [None]:
class EarlyFusionModel(nn.Module):
    def __init__(self, in_ch=8, output_dim=2):
        super().__init__()

        # 1. Embedder
        self.embedder = EmbedderMaxPool(in_ch)

        # 2. FullyConnected head
        self.fullyConnected = FullyConnectedHead(
            input_dim=self.embedder.flatten_dim,
            output_dim=output_dim
        )

    def forward(self, x):
        # x shape: (B, 8, 64, 64)
        features = self.embedder(x)     # → (B, 12800)
        preds = self.fullyConnected(features)  # → (B, output_dim)
        return preds

## Intermediate Fusion Models

**Concept:** Each modality has its own encoder / feature extractor, and fusion happens after some layers but before classification.

```
RGB → RGB_conv → RGB_features (C, H, W)
LiDAR → LiDAR_conv → LiDAR_features (C, H, W)

Fusion → joint_features → FC → output
```



**Advantages:**

* **Specialized Processing:** Each modality gets its own encoder, tailored to its characteristics.
* **Learned Representations:** Fusion occurs on higher-level, more discriminative features rather than raw data.
* **Flexible Design:** The fusion point can be chosen at different network depths, allowing fine-grained architectural control.
* **Easily Extendable:** New modalities can be added by including additional modality-specific branches.


**Limitations:**

* **Architectural Complexity:** Requires designing separate modality-specific encoders and choosing an appropriate fusion point.
* **Higher Computational Cost:** More expensive than early fusion due to duplicated feature extractors.
* **Fusion Design Sensitivity:** Performance depends on the chosen fusion mechanism (concat, addition, multiplicative, bilinear, attention), which often requires experimentation.
* **Depth Selection Challenge:** Deciding how much unimodal processing to perform before fusion can be non-trivial and task-dependent.

Four fusion variants are implemented:

* Concatenation
* Addition
* Hadamard product (element-wise multiplication)
* Matrix multiplication (outer product of the two embeddings)
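The four fusion operators differ mainly in the size of the fused vector. The sketch below applies each one to two toy embeddings of length D = 3, using plain lists instead of tensors. The function names are illustrative; `outer_flat` mirrors the flattened `(B, D, 1) @ (B, 1, D)` product for a single sample.

```python
def concat(a, b):
    return a + b                                   # length 2*D


def add(a, b):
    return [x + y for x, y in zip(a, b)]           # length D


def hadamard(a, b):
    return [x * y for x, y in zip(a, b)]           # length D


def outer_flat(a, b):
    # Outer product (D, 1) @ (1, D), flattened row by row
    return [x * y for x in a for y in b]           # length D*D


rgb_feat = [1.0, 2.0, 3.0]
xyz_feat = [4.0, 0.0, -1.0]

for name, fuse in [("concat", concat), ("add", add),
                   ("hadamard", hadamard), ("matmul", outer_flat)]:
    fused = fuse(rgb_feat, xyz_feat)
    print(f"{name:9s} len={len(fused):2d} {fused}")
```

The zero in `xyz_feat` shows the gating behaviour of the multiplicative variants: any product involving that entry becomes zero regardless of the RGB value, while addition and concatenation pass the RGB value through.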



| Fusion Method | Advantages | Limitations |
|---------------|------------|-------------|
| **Concatenation** | - Very expressive and flexible<br>- Lets the network learn arbitrary cross-modal interactions<br>- Robust and widely used baseline | - Doubles channel count → more parameters & memory<br>- Computationally heavier<br>- Fusion is unguided; model must discover interactions itself |
| **Addition** | - Lightweight (no increase in channels)<br>- Fast and parameter-efficient<br>- Enforces similar feature spaces between modalities | - Assumes features are aligned and comparable<br>- One noisy modality corrupts the other<br>- Sensitive to scale differences between modalities |
| **Multiplicative (Hadamard Product)** | - Gating effect: highlights features important in *both* modalities<br>- More expressive than addition, cheaper than concat<br>- Natural for attention-like fusion | - Suppresses features when one modality has low magnitude<br>- Requires careful normalization<br>- Can amplify noise if both activations are high |
| **Matrix Multiplication (Bilinear-like)** | - Captures rich pairwise correlations between modalities<br>- Most expressive among all four<br>- Enables true 2nd-order interaction learning | - Very heavy in compute & memory<br>- Requires flattening or dimensionality reduction<br>- Easily overfits; harder to train and tune |
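The compute and memory remarks in the table can be made concrete by counting the parameters of the first fully connected layer, `Linear(fused_dim, 1000)`, for each variant. The sketch assumes the embedding size D = 200 · 8 · 8 used by the encoders in this notebook and float32 weights; the helper `fc1_parameters` is illustrative.

```python
def fc1_parameters(fused_dim, hidden=1000):
    """Weights plus biases of the first Linear(fused_dim, hidden) layer."""
    return fused_dim * hidden + hidden


D = 200 * 8 * 8   # flattened embedding size per modality (12800)

fused_dims = {
    "concat": 2 * D,
    "add": D,
    "hadamard": D,
    "matmul": D * D,
}

for name, dim in fused_dims.items():
    params = fc1_parameters(dim)
    gigabytes = params * 4 / 1e9          # float32 = 4 bytes per parameter
    print(f"{name:9s} fused_dim={dim:>11,d} fc1_params={params:>15,d} (~{gigabytes:,.2f} GB)")
```

Under these assumptions, the outer-product variant needs about 1.6 · 10¹¹ parameters (roughly 655 GB) in its first layer alone. In practice the embeddings would likely have to be reduced to a much smaller D before the product, as the table's remark on dimensionality reduction suggests.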


In [None]:
class ConcatIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        output_dim = 2
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=4, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=4, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?

        # Calculate combined dimension
        # (200 * 8 * 8) + (200 * 8 * 8)
        combined_dim = self.rgb_encoder.flatten_dim + self.xyz_encoder.flatten_dim

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=combined_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # 1. Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # 2. Fuse (Concatenate) at the feature level
        x_fused = torch.cat((x_rgb, x_xyz), dim=1)                      # (B, 2*D)

        # 3. Predict
        output = self.head(x_fused)

        return output

In [None]:
class AddIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch):
        output_dim = 2
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=4, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=4, feature_dim=200)    ## TODO: why was feature_dim=128 in the original?

        # For addition, shapes must match
        fused_dim = self.rgb_encoder.flatten_dim                        # same size after addition

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=fused_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # 1. Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # 2. Additive fusion in feature space
        x_fused = x_rgb + x_xyz                                         # (B, D)

        # 3. Predict
        output = self.head(x_fused)                                     # (B, output_dim)

        return output

In [None]:
class MatmulIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch, output_dim=2):
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=rgb_ch, feature_dim=200)   # TODO: why was the original feature_dim=128?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=xyz_ch, feature_dim=200)   # TODO: why was the original feature_dim=128?

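        # The outer product pairs every RGB feature with every LiDAR feature,
        # so the fused width grows quadratically with the embedding size
        embedding_dim = self.rgb_encoder.flatten_dim
        fused_dim = embedding_dim * embedding_dim                       # D * D after the outer product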

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=fused_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # 1. Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # 2. Batched outer product via matmul: (B, D, 1) @ (B, 1, D)
        x_fused = torch.matmul(x_rgb.unsqueeze(2), x_xyz.unsqueeze(1))  # (B, D, D)
        x_fused = x_fused.flatten(start_dim=1)                          # (B, D*D)

        # 3. Predict
        output = self.head(x_fused)                                     # (B, output_dim)

        return output

In [None]:
class HadamardIntermediateNet(nn.Module):
    def __init__(self, rgb_ch, xyz_ch, output_dim=2):
        super().__init__()

        # Independent Encoders
        # RGB learns textures/colors
        self.rgb_encoder = EmbedderMaxPool(in_ch=rgb_ch, feature_dim=200)   # TODO: why was the original feature_dim=128?
        # LiDAR learns geometry/depth
        self.xyz_encoder = EmbedderMaxPool(in_ch=xyz_ch, feature_dim=200)   # TODO: why was the original feature_dim=128?

        # For elementwise multiplication, shapes must match
        fused_dim = self.rgb_encoder.flatten_dim                        # same size after elementwise multiplication

        # Shared FullyConnected Head
        self.head = FullyConnectedHead(input_dim=fused_dim, output_dim=output_dim)

    def forward(self, x_rgb, x_xyz):
        # 1. Extract features independently
        x_rgb = self.rgb_encoder(x_rgb)                                 # (B, D)
        x_xyz = self.xyz_encoder(x_xyz)                                 # (B, D)

        # 2. Multiplicative / gating-like fusion
        x_fused = x_rgb * x_xyz                                         # shape: (B, D)

        # 3. Predict
        output = self.head(x_fused)                                     # (B, output_dim)

        return output
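
The four intermediate fusion operators above differ mainly in the width of the vector they hand to the head. The following is a minimal sketch with dummy embeddings; the batch size `B` and embedding width `D` are illustrative values chosen for readability, not the notebook's actual `flatten_dim`.

```python
import torch

B, D = 2, 8  # illustrative sizes only (assumption), not the notebook's flatten_dim
x_rgb = torch.randn(B, D)
x_xyz = torch.randn(B, D)

# Batched outer product: (B, D, 1) @ (B, 1, D) -> (B, D, D), then flattened
outer = torch.matmul(x_rgb.unsqueeze(2), x_xyz.unsqueeze(1)).flatten(start_dim=1)

fused = {
    "concat": torch.cat((x_rgb, x_xyz), dim=1),  # (B, 2*D)
    "add": x_rgb + x_xyz,                        # (B, D)
    "hadamard": x_rgb * x_xyz,                   # (B, D)
    "outer": outer,                              # (B, D*D)
}

fused_shapes = {name: tuple(t.shape) for name, t in fused.items()}
for name, shape in fused_shapes.items():
    print(f"{name:9s} -> {shape}")
```

Addition and the Hadamard product keep the head's input width at `D`, concatenation doubles it, and the outer product grows it quadratically, so the first linear layer of the head becomes very large for wide embeddings.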

## Late Fusion Model

**Concept:** Each modality is processed completely separately, and only the final predictions or high-level embeddings are fused.

```
RGB → RGB-Embedder → logits_rgb
LiDAR → LiDAR-Embedder → logits_lidar

Fusion → final decision
```

**Advantages:**

* **Robust to Missing Modalities:** The system can still operate if one modality is noisy, unreliable, or absent.
* **Best for Heterogeneous Modalities:** Works well when modalities differ greatly in structure or resolution, because no shared low-level representation is required.
* **Modular & Simple:** Unimodal models can be trained, debugged, and replaced independently.
* **Leverages Existing Models:** Allows the reuse of strong off-the-shelf unimodal experts without architectural changes.


**Limitations:**

* **Missed Interactions:** There is no joint feature learning, so the modalities never influence each other during representation learning.
* **Limited Expressiveness:** Simple fusion rules (e.g., averaging, weighted sum) cannot capture complex cross-modal relationships.
* **Information Loss:** By the time unimodal predictors output logits/embeddings, rich spatial and semantic details may already be discarded, limiting the power of fusion.
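
As an illustration of the simple fusion rules mentioned above, here is a minimal sketch of decision-level fusion by a weighted average of class probabilities. The helper `fuse_logits`, the weight `w_rgb`, and the example logits are hypothetical and are not part of the notebook's models.

```python
import torch
import torch.nn.functional as F

def fuse_logits(logits_rgb, logits_xyz, w_rgb=0.5):
    """Weighted average of per-modality class probabilities (hypothetical helper)."""
    probs_rgb = F.softmax(logits_rgb, dim=1)
    probs_xyz = F.softmax(logits_xyz, dim=1)
    return w_rgb * probs_rgb + (1.0 - w_rgb) * probs_xyz

# Made-up logits for one sample with two classes
logits_rgb = torch.tensor([[2.0, 0.0]])  # RGB branch favours class 0
logits_xyz = torch.tensor([[0.0, 1.0]])  # LiDAR branch favours class 1

probs = fuse_logits(logits_rgb, logits_xyz, w_rgb=0.5)
pred = probs.argmax(dim=1)

# Setting w_rgb=0.0 ignores the RGB branch, e.g. when that sensor is unavailable
probs_lidar_only = fuse_logits(logits_rgb, logits_xyz, w_rgb=0.0)
```

Because the fusion rule here has no learned parameters, either branch can be swapped out or down-weighted without retraining the other, which is the modularity described in the advantages above.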

In [None]:
rgb_net = EmbedderMaxPool(4).to(device)
xyz_net = EmbedderMaxPool(4).to(device)

# TODO: doesn't this happen elsewhere?
networks = [rgb_net, xyz_net]

class LateNet(nn.Module):
    def __init__(self, output_dim=2):
        super().__init__()
        # Reuses the module-level encoders created above
        self.rgb = rgb_net
        self.xyz = xyz_net

        # each embedder outputs flatten_dim features (e.g. 12800)
        fusion_dim = self.rgb.flatten_dim * 2  # rgb + xyz

        # single fully connected head that sees the fused features
        self.fullyConnected = FullyConnectedHead(
            input_dim=fusion_dim,
            output_dim=output_dim,
        )

    def forward(self, x_rgb, x_xyz):
        x_rgb = self.rgb(x_rgb)     # (B, flatten_dim), e.g. (B, 12800)
        x_xyz = self.xyz(x_xyz)     # (B, flatten_dim)

        # concatenate the embeddings from the two branches
        x_fused = torch.cat((x_rgb, x_xyz), dim=1)    # (B, 2 * flatten_dim)

        preds = self.fullyConnected(x_fused)           # (B, output_dim)
        return preds

# Model Training

In [None]:
def get_early_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    inputs_mm_early = torch.cat((inputs_rgb, inputs_xyz), 1)
    return (inputs_mm_early,)

In [None]:
def get_inputs(batch):
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    return (inputs_rgb, inputs_xyz)
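
The two helpers differ only in where the modalities are joined: `get_early_inputs` stacks them along the channel axis before the network sees them, while `get_inputs` passes them separately. A small sketch with dummy tensors shows the resulting shapes; the batch size and the 64×64 resolution are assumptions for illustration only.

```python
import torch

B, H, W = 2, 64, 64  # illustrative sizes (assumption)
rgb = torch.rand(B, 4, H, W)  # 4-channel RGB input
xyz = torch.rand(B, 4, H, W)  # 4-channel LiDAR input

# Early fusion: concatenate along the channel dimension (dim=1)
early = torch.cat((rgb, xyz), dim=1)
print(early.shape)  # (B, 8, H, W)
```

The eight stacked channels correspond to the `in_ch=8` passed to `EarlyFusionModel` in the training cell.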

In [None]:
def compute_class_weights(train_labels=None):
    # Handle class imbalance with inverse-frequency weights.
    # NOTE: train_labels is currently unused; the per-class counts are hard-coded.
    n_cubes = 2500
    n_spheres = 9999
    class_counts = torch.tensor([n_cubes, n_spheres], dtype=torch.float32)

    # Inverse-frequency class weights (rare classes get a larger weight)
    class_weights = class_counts.sum() / (class_counts + 1e-6)
    class_weights = class_weights / class_weights.mean()  # normalize to mean 1

    print("class_weights:", class_weights)

    return class_weights

In [None]:
class_weights = compute_class_weights()

class_weights: tensor([1.6000, 0.4000])
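
The counts above are hard-coded. As a sketch of how the same weights could be derived from a label tensor instead, the hypothetical helper `class_weights_from_labels` below uses `torch.bincount`. It assumes integer labels with class 0 = cube and class 1 = sphere, which mirrors the order of `class_counts` but is not confirmed by the dataset code.

```python
import torch

def class_weights_from_labels(labels, num_classes=2):
    """Inverse-frequency weights normalized to mean 1 (hypothetical helper)."""
    counts = torch.bincount(labels, minlength=num_classes).float()
    weights = counts.sum() / (counts + 1e-6)
    return weights / weights.mean()

# Synthetic labels with the same class counts as above: 2500 cubes, 9999 spheres
labels = torch.cat((
    torch.zeros(2500, dtype=torch.long),
    torch.ones(9999, dtype=torch.long),
))
weights = class_weights_from_labels(labels)
print(weights)
```

With these counts the result matches the hard-coded version up to rounding, so the rarer cube class contributes roughly four times as much per sample to the weighted cross-entropy loss as the sphere class.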


In [None]:
set_seeds(SEED)

EPOCHS = 20
LR = 0.0001

#loss_func = nn.MSELoss()
# loss_func = nn.CrossEntropyLoss()
loss_func = nn.CrossEntropyLoss(weight=class_weights.to(device))

metrics = {}   # store losses for each model

models_to_train = {
    "early_fusion": EarlyFusionModel(in_ch=8, output_dim=2).to(device),
    "intermediate_fusion_concat": ConcatIntermediateNet(4, 4).to(device),
    #"intermediate_fusion_matmul": MatmulIntermediateNet(4, 4).to(device),
    #"intermediate_fusion_hadamard": HadamardIntermediateNet(4, 4).to(device),
    "intermediate_fusion_add": AddIntermediateNet(4, 4).to(device),
    "late_fusion": LateNet().to(device),
}

checkpoint_dir = STORAGE_PATH / "checkpoints"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

for name, model in models_to_train.items():
  model_save_path = checkpoint_dir / f"{name}.pth"

  # metrics for comparison table
  num_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

  opt = Adam(model.parameters(), lr=LR)

  # initialize wandb
  init_wandb(
      model=model,
      fusion_name=name,
      num_params=num_params,
      opt_name = opt.__class__.__name__)

  if name == "early_fusion":
    input_fn = get_early_inputs
  else:
    input_fn = get_inputs

  train_losses, valid_losses, epoch_times, max_gpu_mem_mb = train_model(
    model=model,
    optimizer=opt,
    input_fn=input_fn,
    epochs=EPOCHS,
    loss_fn=loss_func,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    model_save_path=model_save_path,
    target_idx=-1,   # last element in batch is target
    log_to_wandb=True,
    model_name=name
  )

  metrics[name] = {
      "train_losses": train_losses,
      "valid_losses": valid_losses,
      "epoch_times": epoch_times,
      "best_valid_loss": min(valid_losses),
      "max_gpu_mem_mb": max_gpu_mem_mb,
      "num_params": num_params,
  }

  # End wandb run
  wandb.finish()

All random seeds set to 51 for reproducibility


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch and start time: 0 und 1764836897.8672874
epoch 0 train loss: 0.5585812245152498


  5%|▌         | 1/20 [01:05<20:52, 65.90s/it]

epoch 0 valid loss: 0.3816437155008316
Found and saved better weights for the model
Epoch and start time: 1 und 1764836963.7771957
epoch 1 train loss: 0.17547983887948487


 10%|█         | 2/20 [01:57<17:14, 57.49s/it]

epoch 1 valid loss: 0.029946841485798358
Found and saved better weights for the model
Epoch and start time: 2 und 1764837015.3744483
epoch 2 train loss: 0.10920584960411744


 15%|█▌        | 3/20 [02:48<15:27, 54.55s/it]

epoch 2 valid loss: 0.013084906525909901
Found and saved better weights for the model
Epoch and start time: 3 und 1764837066.4209025
epoch 3 train loss: 0.011336241636768376


 20%|██        | 4/20 [03:39<14:08, 53.05s/it]

epoch 3 valid loss: 0.0017347613931633532
Found and saved better weights for the model
Epoch and start time: 4 und 1764837117.1689832
epoch 4 train loss: 0.005309679763520666


 25%|██▌       | 5/20 [04:30<13:05, 52.36s/it]

epoch 4 valid loss: 0.004932388651650399
Epoch and start time: 5 und 1764837168.31144
epoch 5 train loss: 0.0026075686581335552


 30%|███       | 6/20 [05:20<12:00, 51.46s/it]

epoch 5 valid loss: 0.0006988686356635299
Found and saved better weights for the model
Epoch and start time: 6 und 1764837218.021929
epoch 6 train loss: 0.00025436518726325486


 35%|███▌      | 7/20 [06:10<11:05, 51.21s/it]

epoch 6 valid loss: 0.0002015374544498627
Found and saved better weights for the model
Epoch and start time: 7 und 1764837268.7210448
epoch 7 train loss: 0.0001364297884145727


 40%|████      | 8/20 [07:02<10:14, 51.19s/it]

epoch 7 valid loss: 0.00016747532272347597
Found and saved better weights for the model
Epoch and start time: 8 und 1764837319.8715165
epoch 8 train loss: 8.988414320140835e-05


 45%|████▌     | 9/20 [07:53<09:22, 51.16s/it]

epoch 8 valid loss: 0.00015562504977424397
Found and saved better weights for the model
Epoch and start time: 9 und 1764837370.9727955
epoch 9 train loss: 6.436552277687281e-05


 50%|█████     | 10/20 [08:43<08:29, 50.91s/it]

epoch 9 valid loss: 0.00011407112365304783
Found and saved better weights for the model
Epoch and start time: 10 und 1764837421.3113136
epoch 10 train loss: 4.312989108795687e-05


 55%|█████▌    | 11/20 [09:34<07:38, 50.92s/it]

epoch 10 valid loss: 7.791419648128795e-05
Found and saved better weights for the model
Epoch and start time: 11 und 1764837472.2557786
epoch 11 train loss: 2.9460039861293664e-05


 60%|██████    | 12/20 [10:25<06:47, 50.90s/it]

epoch 11 valid loss: 5.549872456640514e-05
Found and saved better weights for the model
Epoch and start time: 12 und 1764837523.103322
epoch 12 train loss: 2.0776022850106145e-05


 65%|██████▌   | 13/20 [11:16<05:56, 50.94s/it]

epoch 12 valid loss: 6.000027776735806e-05
Epoch and start time: 13 und 1764837574.1413097
epoch 13 train loss: 1.60574520063356e-05


 70%|███████   | 14/20 [12:06<05:04, 50.68s/it]

epoch 13 valid loss: 4.276095007185177e-05
Found and saved better weights for the model
Epoch and start time: 14 und 1764837624.2296546
epoch 14 train loss: 1.1659553273166e-05


 75%|███████▌  | 15/20 [12:56<04:13, 50.66s/it]

epoch 14 valid loss: 3.306014577901806e-05
Found and saved better weights for the model
Epoch and start time: 15 und 1764837674.8316007
epoch 15 train loss: 8.856823297408387e-06


 80%|████████  | 16/20 [13:48<03:23, 50.79s/it]

epoch 15 valid loss: 2.6042740370257888e-05
Found and saved better weights for the model
Epoch and start time: 16 und 1764837725.9242024
epoch 16 train loss: 6.665781138269123e-06


 85%|████████▌ | 17/20 [14:38<02:32, 50.75s/it]

epoch 16 valid loss: 2.5249311767083783e-05
Found and saved better weights for the model
Epoch and start time: 17 und 1764837776.5988429
epoch 17 train loss: 5.158981888797598e-06


 90%|█████████ | 18/20 [15:29<01:41, 50.68s/it]

epoch 17 valid loss: 2.026262887682151e-05
Found and saved better weights for the model
Epoch and start time: 18 und 1764837827.0930145
epoch 18 train loss: 3.994675644117993e-06


 95%|█████████▌| 19/20 [16:20<00:50, 50.74s/it]

epoch 18 valid loss: 1.3670982637847828e-05
Found and saved better weights for the model
Epoch and start time: 19 und 1764837877.9832208
epoch 19 train loss: 2.8623905338888202e-06
epoch 19 valid loss: 1.1013141453020125e-05
Found and saved better weights for the model


100%|██████████| 20/20 [17:11<00:00, 51.57s/it]


Run history:
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▃▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
epoch,20
epoch_time,00m 51s
lr,0.0001
max_gpu_mem_mb_epoch,980.10352
model,EarlyFusionModel
train_loss,0.0
valid_loss,1e-05


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch and start time: 0 und 1764837931.5466022
epoch 0 train loss: 0.4023977580039125
epoch 0 valid loss: 0.1252416606992483


  5%|▌         | 1/20 [01:05<20:39, 65.23s/it]

Found and saved better weights for the model
Epoch and start time: 1 und 1764837996.7800279
epoch 1 train loss: 0.07579853330450868
epoch 1 valid loss: 0.013657275587320328


 10%|█         | 2/20 [02:06<18:50, 62.79s/it]

Found and saved better weights for the model
Epoch and start time: 2 und 1764838057.862651
epoch 2 train loss: 0.017279307984254526


 15%|█▌        | 3/20 [03:05<17:21, 61.28s/it]

epoch 2 valid loss: 0.02397695678519085
Epoch and start time: 3 und 1764838117.3398602
epoch 3 train loss: 0.01198422223385837
epoch 3 valid loss: 0.00238522933650529


 20%|██        | 4/20 [04:03<16:00, 60.03s/it]

Found and saved better weights for the model
Epoch and start time: 4 und 1764838175.4555688
epoch 4 train loss: 0.002504486173491291
epoch 4 valid loss: 0.0003450334865192417


 25%|██▌       | 5/20 [05:05<15:07, 60.51s/it]

Found and saved better weights for the model
Epoch and start time: 5 und 1764838236.8193915
epoch 5 train loss: 0.00026974697584831265
epoch 5 valid loss: 0.00016920034340728307


 30%|███       | 6/20 [06:05<14:07, 60.55s/it]

Found and saved better weights for the model
Epoch and start time: 6 und 1764838297.4533145
epoch 6 train loss: 8.889130799855736e-05


 35%|███▌      | 7/20 [07:05<13:04, 60.38s/it]

epoch 6 valid loss: 0.0001773521753420937
Epoch and start time: 7 und 1764838357.487715
epoch 7 train loss: 5.088366988005825e-05


 40%|████      | 8/20 [08:03<11:53, 59.45s/it]

epoch 7 valid loss: 0.00020980342314942392
Epoch and start time: 8 und 1764838414.9238555
epoch 8 train loss: 2.9393170569098977e-05
epoch 8 valid loss: 0.00012579327200228362


 45%|████▌     | 9/20 [09:02<10:51, 59.25s/it]

Found and saved better weights for the model
Epoch and start time: 9 und 1764838473.7418413
epoch 9 train loss: 1.887609219774109e-05
epoch 9 valid loss: 7.460560629510838e-05


 50%|█████     | 10/20 [10:02<09:56, 59.69s/it]

Found and saved better weights for the model
Epoch and start time: 10 und 1764838534.4294279
epoch 10 train loss: 1.2694496766354596e-05
epoch 10 valid loss: 7.117384014918571e-05


 55%|█████▌    | 11/20 [11:03<08:59, 59.94s/it]

Found and saved better weights for the model
Epoch and start time: 11 und 1764838594.9411561
epoch 11 train loss: 8.015743252012965e-06
epoch 11 valid loss: 4.045499504456984e-05


 60%|██████    | 12/20 [12:03<07:59, 59.93s/it]

Found and saved better weights for the model
Epoch and start time: 12 und 1764838654.845683
epoch 12 train loss: 5.472998673246372e-06


 65%|██████▌   | 13/20 [13:02<06:58, 59.73s/it]

epoch 12 valid loss: 6.825777649197562e-05
Epoch and start time: 13 und 1764838714.1006541
epoch 13 train loss: 4.00395536091181e-06
epoch 13 valid loss: 2.7172749059900526e-05


 70%|███████   | 14/20 [14:01<05:56, 59.43s/it]

Found and saved better weights for the model
Epoch and start time: 14 und 1764838772.846917
epoch 14 train loss: 2.803049287702919e-06
epoch 14 valid loss: 2.043707117138638e-05


 75%|███████▌  | 15/20 [15:02<04:59, 59.90s/it]

Found and saved better weights for the model
Epoch and start time: 15 und 1764838833.838495
epoch 15 train loss: 2.1052451830752113e-06
epoch 15 valid loss: 1.2946841614791538e-05


 80%|████████  | 16/20 [16:03<04:00, 60.23s/it]

Found and saved better weights for the model
Epoch and start time: 16 und 1764838894.8195002
epoch 16 train loss: 1.5366288710747405e-06


 85%|████████▌ | 17/20 [17:02<03:00, 60.01s/it]

epoch 16 valid loss: 1.6610488451362926e-05
Epoch and start time: 17 und 1764838954.3305879
epoch 17 train loss: 1.1685127034842753e-06
epoch 17 valid loss: 1.0083784022540954e-05


 90%|█████████ | 18/20 [18:00<01:58, 59.34s/it]

Found and saved better weights for the model
Epoch and start time: 18 und 1764839012.1224854
epoch 18 train loss: 8.788274613308041e-07
epoch 18 valid loss: 5.952613072679469e-06


 95%|█████████▌| 19/20 [19:00<00:59, 59.66s/it]

Found and saved better weights for the model
Epoch and start time: 19 und 1764839072.5159395
epoch 19 train loss: 7.376770365172696e-07


100%|██████████| 20/20 [20:01<00:00, 60.06s/it]

epoch 19 valid loss: 8.724446775154604e-06





Run history:
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
epoch,20
epoch_time,01m 00s
lr,0.0001
max_gpu_mem_mb_epoch,1255.12012
model,ConcatIntermediateNe...
train_loss,0.0
valid_loss,1e-05


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch and start time: 0 und 1764839134.9598112
epoch 0 train loss: 0.3945861776653481


  5%|▌         | 1/20 [00:59<18:58, 59.92s/it]

epoch 0 valid loss: 0.1073454849421978
Found and saved better weights for the model
Epoch and start time: 1 und 1764839194.882309
epoch 1 train loss: 0.067811636033615


 10%|█         | 2/20 [01:58<17:42, 59.03s/it]

epoch 1 valid loss: 0.010296216764254495
Found and saved better weights for the model
Epoch and start time: 2 und 1764839253.288337
epoch 2 train loss: 0.016674193818528835


 15%|█▌        | 3/20 [02:55<16:32, 58.37s/it]

epoch 2 valid loss: 0.003607427859969903
Found and saved better weights for the model
Epoch and start time: 3 und 1764839310.882789
epoch 3 train loss: 0.003087576736417855


 20%|██        | 4/20 [03:53<15:29, 58.09s/it]

epoch 3 valid loss: 0.007759220633306541
Epoch and start time: 4 und 1764839368.5449765
epoch 4 train loss: 0.010377619485697515


 25%|██▌       | 5/20 [04:50<14:23, 57.56s/it]

epoch 4 valid loss: 0.004239433896873379
Epoch and start time: 5 und 1764839425.1504705
epoch 5 train loss: 0.0006102407669291815


 30%|███       | 6/20 [05:47<13:23, 57.39s/it]

epoch 5 valid loss: 0.001142928433569068
Found and saved better weights for the model
Epoch and start time: 6 und 1764839482.2180076
epoch 6 train loss: 0.00016439077584388877


 35%|███▌      | 7/20 [06:45<12:29, 57.63s/it]

epoch 6 valid loss: 0.0003237387881767972
Found and saved better weights for the model
Epoch and start time: 7 und 1764839540.3313904
epoch 7 train loss: 6.755714484092168e-05


 40%|████      | 8/20 [07:42<11:30, 57.58s/it]

epoch 7 valid loss: 0.00047088315404835156
Epoch and start time: 8 und 1764839597.8051314
epoch 8 train loss: 4.5596835867163586e-05


 45%|████▌     | 9/20 [08:39<10:29, 57.24s/it]

epoch 8 valid loss: 0.0004298862039547657
Epoch and start time: 9 und 1764839654.2858417
epoch 9 train loss: 3.031922695469509e-05


 50%|█████     | 10/20 [09:35<09:30, 57.06s/it]

epoch 9 valid loss: 0.0003919552261066883
Epoch and start time: 10 und 1764839710.9549687
epoch 10 train loss: 2.0758007718693774e-05


 55%|█████▌    | 11/20 [10:32<08:32, 56.99s/it]

epoch 10 valid loss: 0.00011178693620479407
Found and saved better weights for the model
Epoch and start time: 11 und 1764839767.7718635
epoch 11 train loss: 1.4725692732968292e-05


 60%|██████    | 12/20 [11:31<07:39, 57.42s/it]

epoch 11 valid loss: 0.00014457758069994498
Epoch and start time: 12 und 1764839826.1832483
epoch 12 train loss: 1.025899104882237e-05


 65%|██████▌   | 13/20 [12:27<06:40, 57.15s/it]

epoch 12 valid loss: 0.00021138426368878528
Epoch and start time: 13 und 1764839882.7021198
epoch 13 train loss: 7.549162105115558e-06


 70%|███████   | 14/20 [13:23<05:41, 56.85s/it]

epoch 13 valid loss: 7.166516097640851e-05
Found and saved better weights for the model
Epoch and start time: 14 und 1764839938.8681107
epoch 14 train loss: 5.485746092354959e-06


 75%|███████▌  | 15/20 [14:21<04:44, 56.99s/it]

epoch 14 valid loss: 0.0001462668061290806
Epoch and start time: 15 und 1764839996.1769495
epoch 15 train loss: 4.46502109632263e-06


 80%|████████  | 16/20 [15:17<03:47, 56.76s/it]

epoch 15 valid loss: 8.099806065384741e-05
Epoch and start time: 16 und 1764840052.4016142
epoch 16 train loss: 3.0254041515739705e-06


 85%|████████▌ | 17/20 [16:13<02:49, 56.64s/it]

epoch 16 valid loss: 0.00011162021269148781
Epoch and start time: 17 und 1764840108.7588782
epoch 17 train loss: 2.454910776747859e-06


 90%|█████████ | 18/20 [17:09<01:52, 56.45s/it]

epoch 17 valid loss: 4.238222719434859e-05
Found and saved better weights for the model
Epoch and start time: 18 und 1764840164.783308
epoch 18 train loss: 1.6569131744157885e-06


 95%|█████████▌| 19/20 [18:06<00:56, 56.65s/it]

epoch 18 valid loss: 0.00011012024325329772
Epoch and start time: 19 und 1764840221.9027023
epoch 19 train loss: 1.3168773440737431e-06


100%|██████████| 20/20 [19:03<00:00, 57.16s/it]

epoch 19 valid loss: 3.992538243868804e-05
Found and saved better weights for the model





Run history:
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
epoch,20
epoch_time,00m 56s
lr,0.0001
max_gpu_mem_mb_epoch,1206.69727
model,AddIntermediateNet
train_loss,0.0
valid_loss,4e-05


  0%|          | 0/20 [00:00<?, ?it/s]

Epoch and start time: 0 und 1764840280.719046
epoch 0 train loss: 0.34750779500280166
epoch 0 valid loss: 0.08781693726778031


  5%|▌         | 1/20 [01:02<19:47, 62.48s/it]

Found and saved better weights for the model
Epoch and start time: 1 und 1764840343.2001631
epoch 1 train loss: 0.06240729872676495
epoch 1 valid loss: 0.010004714038223029


 10%|█         | 2/20 [02:02<18:21, 61.21s/it]

Found and saved better weights for the model
Epoch and start time: 2 und 1764840403.5290892
epoch 2 train loss: 0.01651705772678006
epoch 2 valid loss: 0.009514522179961205


 15%|█▌        | 3/20 [03:03<17:15, 60.94s/it]

Found and saved better weights for the model
Epoch and start time: 3 und 1764840464.1360083
epoch 3 train loss: 0.005161507825543691
epoch 3 valid loss: 0.0008390896204218734


 20%|██        | 4/20 [04:03<16:10, 60.66s/it]

Found and saved better weights for the model
Epoch and start time: 4 und 1764840524.3676503
epoch 4 train loss: 0.0007658871954318832


 25%|██▌       | 5/20 [05:03<15:08, 60.54s/it]

epoch 4 valid loss: 0.008796438775061689
Epoch and start time: 5 und 1764840584.693313
epoch 5 train loss: 0.02224542109478247


 30%|███       | 6/20 [06:01<13:54, 59.62s/it]

epoch 5 valid loss: 0.0019300449966976885
Epoch and start time: 6 und 1764840642.517503
epoch 6 train loss: 0.162288186346067


 35%|███▌      | 7/20 [06:59<12:45, 58.86s/it]

epoch 6 valid loss: 0.011959890642901882
Epoch and start time: 7 und 1764840699.8326986
epoch 7 train loss: 0.0029517495508308025
epoch 7 valid loss: 0.0005355776949727443


 40%|████      | 8/20 [07:56<11:42, 58.55s/it]

Found and saved better weights for the model
Epoch and start time: 8 und 1764840757.6956353
epoch 8 train loss: 0.0006668009016050889
epoch 8 valid loss: 0.00044527680711325957


 45%|████▌     | 9/20 [08:57<10:51, 59.26s/it]

Found and saved better weights for the model
Epoch and start time: 9 und 1764840818.5233555
epoch 9 train loss: 0.00036855063166345455
epoch 9 valid loss: 0.0001858173973232624


 50%|█████     | 10/20 [09:58<09:57, 59.78s/it]

Found and saved better weights for the model
Epoch and start time: 10 und 1764840879.4668183
epoch 10 train loss: 7.962843866320518e-05
epoch 10 valid loss: 0.00010822587760230817


 55%|█████▌    | 11/20 [10:59<09:01, 60.13s/it]

Found and saved better weights for the model
Epoch and start time: 11 und 1764840940.3795629
epoch 11 train loss: 5.0442298843216506e-05
epoch 11 valid loss: 6.732805120464036e-05


 60%|██████    | 12/20 [12:00<08:02, 60.36s/it]

Found and saved better weights for the model
Epoch and start time: 12 und 1764841001.2596622
epoch 12 train loss: 3.674605502630201e-05
epoch 12 valid loss: 6.226522581300741e-05


 65%|██████▌   | 13/20 [13:01<07:03, 60.56s/it]

Found and saved better weights for the model
Epoch and start time: 13 und 1764841062.296493
epoch 13 train loss: 2.4315850784037855e-05
epoch 13 valid loss: 3.993015109244879e-05


 70%|███████   | 14/20 [14:02<06:03, 60.57s/it]

Found and saved better weights for the model
Epoch and start time: 14 und 1764841122.8890805
epoch 14 train loss: 1.9241488836030385e-05
epoch 14 valid loss: 3.245523613486512e-05


 75%|███████▌  | 15/20 [15:03<05:03, 60.66s/it]

Found and saved better weights for the model
Epoch and start time: 15 und 1764841183.7673843
epoch 15 train loss: 1.3876729350745096e-05
epoch 15 valid loss: 2.3825772166219396e-05


 80%|████████  | 16/20 [16:03<04:02, 60.54s/it]

Found and saved better weights for the model
Epoch and start time: 16 und 1764841244.0329447
epoch 16 train loss: 9.52297122397358e-06
epoch 16 valid loss: 1.6848700525429194e-05


 85%|████████▌ | 17/20 [17:03<03:01, 60.45s/it]

Found and saved better weights for the model
Epoch and start time: 17 und 1764841304.2612195
epoch 17 train loss: 7.498201940816635e-06


 90%|█████████ | 18/20 [18:04<02:01, 60.57s/it]

epoch 17 valid loss: 1.6859146688830152e-05
Epoch and start time: 18 und 1764841365.1260242
epoch 18 train loss: 5.660604108274245e-06
epoch 18 valid loss: 9.776575144826438e-06


 95%|█████████▌| 19/20 [19:02<00:59, 59.78s/it]

Found and saved better weights for the model
Epoch and start time: 19 und 1764841423.0684865
epoch 19 train loss: 3.790752075803395e-06
epoch 19 valid loss: 5.986662091572725e-06


100%|██████████| 20/20 [20:02<00:00, 60.14s/it]

Found and saved better weights for the model





Run history:
epoch,▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
lr,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
max_gpu_mem_mb_epoch,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
train_loss,█▂▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁
valid_loss,█▂▂▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁

Run summary:
epoch,20
epoch_time,01m 00s
lr,0.0001
max_gpu_mem_mb_epoch,1406.19629
model,LateNet
train_loss,0.0
valid_loss,1e-05


# Evaluation

In [None]:
def plot_losses(loss_dict, title="Validation Loss per Model", ylabel="Loss", xlabel="Epoch"):
    """
    loss_dict: dict of { "model_name": list_of_losses }
               Every list must have the same length.

    """

    plt.figure(figsize=(8,5))

    # Auto-generate x-axis based on first model
    any_key = next(iter(loss_dict))
    epochs = range(len(loss_dict[any_key]))

    for model_name, losses in loss_dict.items():
        plt.plot(epochs, losses, label=model_name)

    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.title(title)
    plt.legend()
    plt.grid(True, linestyle="--", alpha=0.3)
    plt.show()
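
The validation losses logged above fall from roughly 1e-1 to 1e-5, so a linear y-axis flattens most of each curve. The sketch below shows a log-scale variant; `plot_losses_log` is a hypothetical helper, and the example values are the first four validation losses of the early-fusion and concat runs, rounded from the training log.

```python
import matplotlib.pyplot as plt

def plot_losses_log(loss_dict, title="Validation Loss per Model (log scale)"):
    """Hypothetical variant of plot_losses with a logarithmic y-axis."""
    fig, ax = plt.subplots(figsize=(8, 5))
    for model_name, losses in loss_dict.items():
        ax.plot(range(len(losses)), losses, label=model_name)
    ax.set_yscale("log")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title(title)
    ax.legend()
    ax.grid(True, which="both", linestyle="--", alpha=0.3)
    return ax

# First four validation losses per run, rounded from the training log
example = {
    "EarlyNet": [0.3816, 0.0299, 0.0131, 0.0017],
    "CatNet": [0.1252, 0.0137, 0.0240, 0.0024],
}
ax = plot_losses_log(example)
```

On a log axis, the late-epoch differences between models stay visible instead of collapsing onto the x-axis.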

In [None]:
# Display names for every fusion strategy; only the models that were actually
# trained (i.e. present in `metrics`) are plotted, which avoids a KeyError for
# the variants that are commented out in `models_to_train`.
display_names = {
    "early_fusion": "EarlyNet",
    "intermediate_fusion_concat": "CatNet",
    "intermediate_fusion_matmul": "MatmulNet",
    "intermediate_fusion_add": "AddNet",
    "intermediate_fusion_hadamard": "HadamardNet",
    "late_fusion": "LateNet",
}

loss_dict = {
    label: metrics[key]["valid_losses"]
    for key, label in display_names.items()
    if key in metrics
}

plot_losses(loss_dict, title="Validation Loss Comparison")

In [None]:
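# compute avg_epoch_time
# NOTE: `epoch_times` still holds the values of the last trained model only (late_fusion);
# per-model averages are computed for the comparison table in the next cell.
avg_epoch_time = sum(epoch_times) / len(epoch_times)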

In [None]:
import numpy as np
import pandas as pd

# Optional: nice display names for each key in `metrics`
name_map = {
    "early_fusion": "Early Fusion",
    "late_fusion": "Late Fusion",
    "intermediate_fusion_concat": "Intermediate (Concat)",
    "intermediate_fusion_matmul": "Intermediate (Matmul)",
    "intermediate_fusion_hadamard": "Intermediate (Hadamard)",
    "intermediate_fusion_add": "Intermediate (Add)",
}

rows = []

for key, m in metrics.items():
    avg_valid_loss = float(np.mean(m["valid_losses"]))
    avg_epoch_time = float(np.mean(m["epoch_times"]))    # assumed to be in seconds
    # Format as min:s so the value matches the column header
    minutes, seconds = divmod(int(round(avg_epoch_time)), 60)

    rows.append({
        "Fusion Strategy": name_map.get(key, key),
        "Avg Valid Loss": avg_valid_loss,
        "Best Valid Loss": float(m["best_valid_loss"]),
        "Num of params": int(m["num_params"]),
        "Avg time per epoch (min:s)": f"{minutes:02d}:{seconds:02d}",
        "GPU Memory (MB, max)": float(m["max_gpu_mem_mb"]),
    })

df_comparison = pd.DataFrame(rows)
df_comparison


In [None]:
# logs the comparison table to wandb
wandb.init(
    project="cilp-extended-assessment",   # your project name
    name="fusion_comparison_all",
    job_type="analysis",
)

fusion_comparison_table = wandb.Table(dataframe=df_comparison)
wandb.log({"fusion_comparison": fusion_comparison_table})

wandb.finish()

**When to use**

**Early Fusion:**
* Spatially aligned, closely related low-level modalities with comparable features
* Simplest setup, but avoid it if either sensor is noisy, since the noise enters the shared network from the first layer

**Intermediate Fusion:**
* Modalities with different structure that benefit from separate early processing to learn modality-specific features
* Best overall balance of performance and flexibility

**Late Fusion:**
* Strong, independent unimodal predictors whose strengths should be combined
* Ideal for heterogeneous or missing modalities
* Robust fallback when one modality fails