# Fusion Architecture Comparison

In this notebook I compare late fusion with three variants of intermediate fusion by training four models and comparing their performance.

In [None]:
import sys

# Colab-only setup
if "google.colab" in sys.modules:
    print("Running in Google Colab. Setting up repo")

    !git clone https://github.com/MatthiasCr/Computer-Vision-Assignment-2.git
    %cd Computer-Vision-Assignment-2
    !pip install -r requirements.txt

In [None]:
# insert wandb token
!wandb login

In [None]:
import fiftyone as fo
from fiftyone.utils.huggingface import load_from_hub
import numpy as np
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.optim.lr_scheduler import CosineAnnealingLR
import torchvision.transforms.v2 as transforms
from torch.utils.data import DataLoader
import wandb

from pathlib import Path

project_root = Path("..").resolve()
sys.path.append(str(project_root))
from src import datasets
from src import training
from src import visualization
from src import models

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

## Load Data 

As explained in the previous notebook, I use a subset of the data that I uploaded to Hugging Face.

In [None]:
# hyperparameters will be the same for all experiments to make them comparable
IMG_SIZE = 64
BATCH_SIZE = 32
EPOCHS = 20

# Initial and target value for learning rate scheduler
START_LR = 1e-3
END_LR = 1e-6

In [None]:
# load fiftyone dataset from huggingface
dataset = load_from_hub(
    "MatthiasCr/multimodal-shapes-subset", 
    name="multimodal-shapes-subset",
    # fewer workers and greater batch size to hopefully avoid getting rate limited
    num_workers=2,
    batch_size=1000,
    overwrite=True,
)

Now I convert this fiftyone dataset to torch datasets using the already existing tags for the train / val split. For this I have defined a class `MultimodalDataset`. I also create dataloaders for training and validation, as well as a separate dataloader (`log_loader`) that I will use for sample predictions on the validation dataset. For all dataloaders with `shuffle=True` I specify a generator with a fixed seed to make the shuffling deterministic and the results reproducible.
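
The effect of the seeded generator can be checked in isolation. The following is a minimal sketch that uses a plain list of integers as a stand-in dataset (an assumption for illustration, not the `MultimodalDataset`); `epoch_order` is a hypothetical helper name.

```python
import torch
from torch.utils.data import DataLoader


def epoch_order(seed):
    """Return the batches of one shuffled epoch for a given generator seed."""
    generator = torch.Generator()
    generator.manual_seed(seed)
    # A list of ints stands in for the real dataset (illustration only).
    loader = DataLoader(list(range(10)), batch_size=5, shuffle=True, generator=generator)
    return [batch.tolist() for batch in loader]


order_a = epoch_order(51)
order_b = epoch_order(51)

# Same seed, same shuffling order.
print(order_a == order_b)
```

Note that a generator shared between several loaders advances every time one of them starts a new epoch, so the order stays reproducible only if the loaders are always consumed in the same sequence.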

In [None]:
img_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
])

train_dataset = datasets.MultimodalDataset(dataset, "train", img_transforms)
val_dataset = datasets.MultimodalDataset(dataset, "val", img_transforms)

# use generator with fixed seed for reproducible shuffling
generator = torch.Generator()
generator.manual_seed(51)

train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True, generator=generator)
val_dataloader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, drop_last=True)

# loader to conduct sample predictions
log_loader = DataLoader(val_dataset, batch_size=5, shuffle=True, num_workers=0, generator=generator)

# number of train batches, needed for learning rate scheduling
steps_per_epoch = len(train_dataloader)

## Experiments

The `MultimodalDataset` returns the RGB image data, the lidar XYZA data, and the class index (0 = cube, 1 = sphere) for each item.

I have defined a general model training function (`training.train_model()`) that I use for all training loops across all notebooks. This training function needs a function that tells it how to apply the model to a batch (the forward pass). So let's define this function for this notebook:

In [None]:
# function that tells the training process how to apply the model on a batch of the dataset
def apply_model(model, batch):
    target = batch[2].to(device)
    inputs_rgb = batch[0].to(device)
    inputs_xyz = batch[1].to(device)
    outputs = model(inputs_rgb, inputs_xyz)
    return outputs, target

Now I define a function that runs a complete experiment cycle for a given model. Every experiment is logged to Wandb, so each one starts by initializing a new run. Here we pass relevant metrics and hyperparameters as config, such as the number of parameters, the optimizer type, and the learning rate scheduler. Then we run the training loop, which logs training loss, validation loss, and validation accuracy every epoch. Training is also checkpointed by saving the weights with the best validation loss. After training we load the best checkpoint and run predictions on 5 validation samples. For each sample we log the RGB image, the lidar XYZA data projected to an image, the label, the prediction, and the prediction probability to Wandb.
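
To get a feeling for the learning rate schedule used in these experiments, here is a small sketch of the closed-form cosine annealing expression given in the PyTorch documentation for `CosineAnnealingLR`. It assumes the scheduler is stepped once per batch (which `T_max=EPOCHS * steps_per_epoch` suggests), and the value of `steps_per_epoch` is a hypothetical placeholder.

```python
import math

# Values from the hyperparameter cell of this notebook.
START_LR = 1e-3
END_LR = 1e-6
EPOCHS = 20

# Hypothetical value for illustration; the notebook uses len(train_dataloader).
steps_per_epoch = 100
T_MAX = EPOCHS * steps_per_epoch


def cosine_lr(step: int) -> float:
    """Closed-form cosine annealing from START_LR down to END_LR."""
    cosine = (1 + math.cos(math.pi * step / T_MAX)) / 2
    return END_LR + (START_LR - END_LR) * cosine


# Learning rate at the start, the midpoint, and the end of training.
for step in (0, T_MAX // 2, T_MAX):
    print(f"step {step:5d}: lr = {cosine_lr(step):.3e}")
```

The learning rate starts at `START_LR`, passes through the mean of both values halfway through training, and reaches `END_LR` at the last step.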

In [None]:
def log_experiment(model, best_model, fusion_type, device, output_name):
    num_params = sum(p.numel() for p in model.parameters())
    optim = Adam(model.parameters(), lr=START_LR)
    scheduler = CosineAnnealingLR(optim, T_max=EPOCHS * steps_per_epoch, eta_min=END_LR)
    loss_func = nn.BCEWithLogitsLoss()

    # init wandb run and log config hyperparameters
    run = training.initWandbRun(
        fusion_type, EPOCHS, BATCH_SIZE, num_params, "Adam", "Cosine Annealing", START_LR, END_LR
    )

    # train and log loss
    train_loss, val_loss = training.train_model(
        model, optim, apply_model, loss_func, EPOCHS, train_dataloader, val_dataloader, device, run, scheduler=scheduler, output_name=output_name
    )

    # load best model
    model_save_path = f"../checkpoints/{output_name}.pt"
    best_model.load_state_dict(torch.load(model_save_path, map_location=device))
    best_model = best_model.to(device)

    # predict on 1 batch of 5 samples. Log predictions to wandb
    training.log_predictions(best_model, log_loader, device, run, num_batches=1)
    
    run.finish()
    return train_loss, val_loss

### Late Fusion

Let's run the experiment for the different fusion models, starting with late fusion.
The `LateFusionNet` has its own `LateEmbedder` for each of RGB and lidar. Each late embedder creates a full, flattened embedding with 100 dimensions for its modality. These embeddings are then concatenated and passed into a linear head.
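
The fusion step described above can be sketched as follows. This is not the actual `LateFusionNet` from `src/models`; the class name `LateFusionHeadSketch`, the single linear layer, and the single output logit are assumptions for illustration. Only the 100-dimensional embeddings and the concatenation come from the description.

```python
import torch
import torch.nn as nn


class LateFusionHeadSketch(nn.Module):
    """Hypothetical stand-in for the fusion step of a late fusion model."""

    def __init__(self, embed_dim: int = 100):
        super().__init__()
        # One logit per sample (assumed), which fits nn.BCEWithLogitsLoss.
        self.head = nn.Linear(2 * embed_dim, 1)

    def forward(self, rgb_emb, lidar_emb):
        # Concatenate the two per-modality embeddings along the feature axis.
        fused = torch.cat([rgb_emb, lidar_emb], dim=1)
        return self.head(fused)


# Random tensors stand in for the outputs of the two embedders.
rgb_emb = torch.randn(4, 100)
lidar_emb = torch.randn(4, 100)
logits = LateFusionHeadSketch()(rgb_emb, lidar_emb)
print(logits.shape)
```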

In [None]:
late_model = models.LateFusionNet().to(device)
late_model_best = models.LateFusionNet().to(device)
late_train_loss, late_val_loss = log_experiment(late_model, late_model_best, "late", device, output_name="late")

visualization.plot_loss(EPOCHS,
    {
        "Late Train Loss": late_train_loss,
        "Late Val Loss": late_val_loss
    }
)

### Intermediate Fusion

For intermediate fusion I have defined a model `IntermediateFusionNet` that has one `IntermediateEmbedder` per modality. These embedders do not produce full embeddings but return intermediate feature maps right after the second convolution. These feature maps are then combined in a way that depends on the fusion type: concatenation, element-wise addition, or element-wise multiplication (Hadamard product). After the fusion there is another shared convolutional layer and a linear classification head.
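
The three fusion types can be sketched on dummy feature maps. The function `fuse` and the tensor shapes are hypothetical and only illustrate the operations; the actual implementation in `src/models` may differ.

```python
import torch


def fuse(rgb_feat, lidar_feat, fusion_type):
    """Combine two feature maps of shape (batch, channels, height, width)."""
    if fusion_type == "cat":
        # Stack along the channel axis: the channel count doubles.
        return torch.cat([rgb_feat, lidar_feat], dim=1)
    if fusion_type == "add":
        # Element-wise sum: the shape is unchanged.
        return rgb_feat + lidar_feat
    if fusion_type == "had":
        # Element-wise (Hadamard) product: the shape is unchanged.
        return rgb_feat * lidar_feat
    raise ValueError(f"unknown fusion type: {fusion_type}")


# Assumed shapes for illustration only.
rgb_feat = torch.randn(2, 8, 16, 16)
lidar_feat = torch.randn(2, 8, 16, 16)

shapes = {kind: tuple(fuse(rgb_feat, lidar_feat, kind).shape) for kind in ("cat", "add", "had")}
print(shapes)
```

Because concatenation doubles the number of channels, the layer that follows the fusion needs more input channels than in the other two variants, which would be consistent with the higher parameter count of the concatenation model.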

In [None]:
cat_model = models.IntermediateFusionNet(fusion_type="cat").to(device)
cat_model_best = models.IntermediateFusionNet(fusion_type="cat").to(device)
cat_train_loss, cat_val_loss = log_experiment(cat_model, cat_model_best, "intermediate (concatenation)", device, output_name="cat")

add_model = models.IntermediateFusionNet(fusion_type="add").to(device)
add_model_best = models.IntermediateFusionNet(fusion_type="add").to(device)
add_train_loss, add_val_loss = log_experiment(add_model, add_model_best, "intermediate (addition)", device, output_name="add")

had_model = models.IntermediateFusionNet(fusion_type="had").to(device)
had_model_best = models.IntermediateFusionNet(fusion_type="had").to(device)
had_train_loss, had_val_loss = log_experiment(had_model, had_model_best, "intermediate (hadamard)", device, output_name="had")

## Analysis

We can now compare how these models performed on the validation dataset:

In [None]:
visualization.plot_loss(EPOCHS,
    {
        "Concat Valid Loss": cat_val_loss,
        "Addition Valid Loss": add_val_loss,
        "Hadamard Valid Loss": had_val_loss,
        "Late Valid Loss": late_val_loss
    }
)

We can also visualize and analyze the experiments in Wandb. The following screenshots show some visualizations from the Wandb dashboard. Most importantly, we can plot the loss and accuracy curves and determine which model performed best. We can also inspect the sample predictions to get a rough impression of how confident the model predictions are.

<img src="../results/wandb-t3-graphs.png">
<img src="../results/wandb-t3-table.png">

<div style="display:grid; grid-template-columns: 1fr 1fr; gap:30px;">
  <img src="../results/wandb-t3-valid-loss.png">
  <img src="../results/wandb-t3-predictions.png">
</div>


The following table summarizes the performance of the 4 models. Validation loss and accuracy are measured from the **best checkpoint**. Train time includes the entire train loop for all 20 epochs and includes the validation phases. GPU memory is measured by wandb by default. 

| Metric | Late Fusion | Intermediate (Cat) | Intermediate (Add) | Intermediate (Had) |
| --- | --- | --- | --- | --- |
| Valid Loss | 0.039 | 0.032 | 0.119 | 0.070 |
| Valid Accuracy | 0.987 | 0.992 | 0.971 | 0.984 |
| Parameter Count | 1,415,051 | 914,201 | 824,201 | 824,201 |
| Train Time (s) | 22.17 | 25.71 | 24.38 | 24.76 |
| GPU Memory (MB) | 601 | 773 | 775 | 777 |

The best performing model was **Intermediate Fusion with Concatenation**. Its best checkpoint has a validation loss of 0.032 and a validation accuracy of 0.992. Late fusion and intermediate fusion with the Hadamard product come close. Intermediate fusion with addition has the lowest accuracy and the highest loss.

It is interesting that the late fusion model has roughly 55% more parameters than the intermediate concat model, yet the intermediate concat model is competitive and even slightly better. The intermediate Hadamard model also performs quite well given that it has even fewer parameters. This suggests that intermediate fusion is a well-suited architecture for this problem. The idea of intermediate fusion is to combine information from both modalities during the creation of one joint representation, rather than creating two fully separate embeddings. This way intermediate fusion allows for more joint, cross-modal feature interactions before a decision is made. Concatenation specifically preserves the full information from both modalities and lets the model learn by itself how to combine them inside the next convolutional layer. This gives the model a bit more flexibility than forcing it to combine the modalities with addition or multiplication.
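
The parameter comparison can be made concrete with a few lines of arithmetic on the counts from the table above. The dictionary keys are just short labels chosen for this sketch.

```python
# Parameter counts taken from the results table.
param_counts = {
    "late": 1_415_051,
    "cat": 914_201,
    "add": 824_201,
    "had": 824_201,
}

# How much larger is the late fusion model than the concat model?
late_vs_cat = param_counts["late"] / param_counts["cat"]

# How many extra parameters does concatenation cost compared to addition?
cat_extra = param_counts["cat"] - param_counts["add"]

print(f"late / cat ratio: {late_vs_cat:.2f}")
print(f"cat - add:        {cat_extra:,}")
```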

A possible hypothesis for the comparatively poor performance of the intermediate addition model is that addition indiscriminately blends features from both modalities, summing their activations and thereby weakening **modality-specific** signals. If, for example, one modality is noisy or less informative, its features can directly corrupt the other modality's representation. In contrast, with the Hadamard product, features are amplified when **both modalities** produce strong activations and suppressed where one of the modalities is weak. This encourages agreement between the modalities and reduces the influence of noisy or irrelevant features, which can lead to more meaningful joint representations and better performance than simple addition.
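
The hypothesis above can be illustrated with a toy example. The activation values are made up for illustration and are not taken from the trained models; the sketch only shows how the two operations treat features that are active in one modality but not in the other.

```python
import torch

# Made-up activations for three features (illustration only).
# Feature 0: strong in both modalities.
# Feature 1: active only in lidar (for example noise).
# Feature 2: active only in RGB.
rgb = torch.tensor([2.0, 0.0, 2.0])
lidar = torch.tensor([2.0, 1.5, 0.0])

added = rgb + lidar       # single-modality activations pass through
multiplied = rgb * lidar  # only jointly active features survive

print("addition:", added.tolist())
print("hadamard:", multiplied.tolist())
```

With addition, the lidar-only activation reaches the fused representation unchanged, while the product keeps only the feature on which both modalities agree. The flip side is that the product also discards the RGB-only feature, so which behavior helps depends on the data.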