# Architecture Comparison

This notebook tests all implemented backbone architectures for multiplex spatial proteomics cell phenotyping.

The central research question is: **is channel separability a better inductive bias than early fusion?**

Each cell below instantiates one model, runs a forward pass with dummy data, and reports parameter count and output shape.

**Test setup:** `n_channels=41` (cHL panel), `input_size=24` (cutter_size), `batch_size=4`

In [None]:
import sys
sys.path.insert(0, '..')

import torch
import torch.nn as nn

from src.models import WideModel, SharedStemModel, WideModelAttention
from src.models_attention import MCIANet
from src.models_early_fusion import EarlyFusionModel, MidFusionModel, ProjectionFusionModel, ResNetBaseline

# Shared test inputs
N_CHANNELS  = 41
INPUT_SIZE  = 24
BATCH_SIZE  = 4
STEM_WIDTH  = 32   # features_per_marker

x = torch.randn(BATCH_SIZE, N_CHANNELS, INPUT_SIZE, INPUT_SIZE)

def summarise(name, model, x):
    model.eval()
    with torch.no_grad():
        out = model(x)
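    # Models may return a tuple (features first) or a bare tensor; handle both
    # so that a bare tensor is not silently indexed along the batch axis.
    out_tensor = out[0] if isinstance(out, (tuple, list)) else out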
    n_params = sum(p.numel() for p in model.parameters())
    print(f"{'Model':<30} {'Output shape':<25} {'Params':>10}")
    print(f"{name:<30} {str(tuple(out_tensor.shape)):<25} {n_params:>10,}")
    return out_tensor

print('Setup complete. Input shape:', x.shape)

---
## Channel-Separable Architectures
These models process each marker channel separately in the convolutional stem and, with the exception of MCIANet's joint stages (model 4), in the processing layers as well. Cross-channel mixing is deferred to late fusion or attention.

### 1. WideModel (CIM — Channel-Independent Model)

**Key design:** Depthwise grouped convolutions (`groups=in_channels`) throughout stem and all ConvBlocks. Each marker channel has its own independent set of convolutional filters. Channels are **never mixed** during spatial feature extraction — the model treats each marker as a completely independent spatial signal.

Optional `late_fusion=True` adds a single 1×1 conv at the end to allow cross-channel mixing after global pooling.

**Inductive bias:** Spatial patterns within each marker channel are independently meaningful. Cell identity emerges from aggregating per-marker features, not from pixel-level cross-marker interactions.

**This is the proposed architecture / positive hypothesis.**
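
The independence claim can be checked on a single grouped convolution. The sketch below is not `WideModel` itself. It assumes the stem maps `C` markers to `C * K` feature maps with `groups=C`, then perturbs one marker and reports which markers' features moved.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, K = 41, 32  # marker channels and features per marker (from the test setup)

# groups=C gives every marker its own K filters; no weight spans two markers.
stem = nn.Conv2d(C, C * K, kernel_size=3, padding=1, groups=C)

x = torch.randn(4, C, 24, 24)
x_perturbed = x.clone()
x_perturbed[:, 0] += 1.0  # change marker 0 only

with torch.no_grad():
    diff = (stem(x) - stem(x_perturbed)).abs()

# Output maps are laid out as C consecutive blocks of K maps, one block per marker.
max_change = diff.reshape(4, C, K, 24, 24).amax(dim=(0, 2, 3, 4))
changed = (max_change > 1e-6).nonzero().flatten().tolist()
print("markers whose features changed:", changed)
```

Only marker 0's block of feature maps responds to the perturbation, which is the property the grouped stem is meant to guarantee.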

In [None]:
model_cim = WideModel(
    in_channels=N_CHANNELS,
    stem_width=STEM_WIDTH,
    block_width=2,
    layer_config=[1, 1],
    late_fusion=False,
    drop_prob=0.05,
)
summarise('WideModel (CIM)', model_cim, x)

### 2. SharedStemModel (SSM — Shared Stem Model)

**Key design:** A single shared convolutional stem (weights shared across all channels) is applied independently to each marker channel. Unlike WideModel which has separate weights per channel, here all channels share the same filter — the stem learns a universal single-channel spatial feature extractor.

**Inductive bias:** All marker channels share the same spatial statistics (e.g., punctate vs diffuse staining). Marker identity is encoded only via the channel position after concatenation.

**Compared to WideModel:** SSM has dramatically fewer parameters in the stem (1 filter set vs C filter sets). This is more parameter-efficient but loses per-marker specialisation.
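
A minimal sketch of the weight-sharing idea, assuming the shared stem is a single-channel convolution applied to every marker by folding the channel axis into the batch axis. The layer shapes are illustrative and are not taken from `SharedStemModel`.

```python
import torch
import torch.nn as nn

C, K = 41, 32  # marker channels and features per marker

# Shared stem: one single-channel filter bank reused for every marker.
shared = nn.Conv2d(1, K, kernel_size=3, padding=1)
# Per-channel stem (WideModel-style): a separate filter bank for each marker.
separate = nn.Conv2d(C, C * K, kernel_size=3, padding=1, groups=C)

x = torch.randn(4, C, 24, 24)
B, _, H, W = x.shape

# Fold markers into the batch axis so the same weights see each marker alone,
# then unfold back to C blocks of K feature maps.
feats = shared(x.reshape(B * C, 1, H, W)).reshape(B, C * K, H, W)

n_shared = sum(p.numel() for p in shared.parameters())
n_separate = sum(p.numel() for p in separate.parameters())
print(tuple(feats.shape), n_shared, n_separate)
```

Both stems produce the same output shape, but the per-channel stem holds exactly `C` times as many parameters as the shared one.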

In [None]:
model_ssm = SharedStemModel(
    in_channels=N_CHANNELS,
    stem_width=STEM_WIDTH,
    block_width=2,
    n_layers=2,
    late_fusion=False,
)
summarise('SharedStemModel (SSM)', model_ssm, x)

### 3. WideModelAttention (CIMATT — CIM with Attention)

**Key design:** Channel-separable stem and ConvBlocks (identical to WideModel), followed by **multi-head self-attention across channel tokens**. After spatial processing, each channel contributes one token (spatially pooled). Attention computes cross-channel relationships and gates the spatial feature maps before final pooling.

**Inductive bias:** Same as WideModel for spatial feature extraction, but adds an explicit mechanism to learn which marker co-expression patterns matter (e.g., CD4 and CD25 together → Treg signal).

**Compared to WideModel:** The attention layer is the only place where channel interactions occur. This is a principled, interpretable form of late fusion.
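
The token-and-gate mechanism described above might look roughly like the following. This is a sketch under assumptions: the sigmoid gate, the `to_gate` projection, and the 6×6 spatial size are hypothetical choices, and `WideModelAttention` may implement the gating differently.

```python
import torch
import torch.nn as nn

B, C, D, H, W = 4, 41, 32, 6, 6  # D = features per marker; H, W are illustrative

feats = torch.randn(B, C, D, H, W)   # per-marker feature maps after the ConvBlocks
tokens = feats.mean(dim=(3, 4))      # (B, C, D): one spatially pooled token per marker

attn = nn.MultiheadAttention(embed_dim=D, num_heads=4, batch_first=True)
to_gate = nn.Linear(D, D)            # hypothetical projection from token to gate

# Self-attention over the C marker tokens; weights are averaged over heads.
mixed, weights = attn(tokens, tokens, tokens)   # (B, C, D), (B, C, C)
gate = torch.sigmoid(to_gate(mixed))            # values in (0, 1)

gated = feats * gate[..., None, None]           # broadcast the gate over H and W
pooled = gated.mean(dim=(3, 4)).flatten(1)      # (B, C * D) feature vector
print(tuple(weights.shape), tuple(pooled.shape))
```

The `(C, C)` attention matrix is what makes this form of late fusion inspectable: row `i` shows how strongly marker `i` attends to every other marker.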

In [None]:
model_cimatt = WideModelAttention(
    in_channels=N_CHANNELS,
    stem_width=STEM_WIDTH,
    block_width=2,
    layer_config=[1, 1],
    drop_prob=0.05,
    n_heads=4,
)
summarise('WideModelAttention (CIMATT)', model_cimatt, x)

### 4. MCIANet (ATT — Multi-Channel Image Analysis Network)

**Key design:** Per-channel **shared** stem (stride-2, same weights across channels) + **SpatialFocusMap** (learnable Gaussian-initialised spatial attention that suppresses background outside the cell) + two stages of joint convolutional processing interspersed with **cross-channel attention** at 6×6 and 3×3 spatial resolutions.

**Key novelties over CIMATT:**
- Shared stem (parameter-efficient, assumes all channels have similar spatial statistics)
- SpatialFocusMap: explicitly learns to focus on the cell body and ignore neighbouring cells
- Attention is applied at multiple spatial scales (early and late)
- Joint convolutional stages after the stem allow cross-channel spatial reasoning

**Inductive bias:** Spatial filtering of the cell body is the first priority; cross-channel co-expression reasoning happens iteratively at decreasing resolutions.
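
To make the focus-map idea concrete, here is a hypothetical stand-in for `SpatialFocusMap`: a trainable per-pixel mask initialised as a centred Gaussian. The class name `GaussianFocus` and the reading of `sigma_fraction` as a fraction of the patch side are assumptions. The cell below passes `spatial_init_mode='ones'`, which presumably selects a different initialisation.

```python
import torch
import torch.nn as nn

class GaussianFocus(nn.Module):
    """Hypothetical sketch of a learnable spatial mask; not the notebook's module."""

    def __init__(self, size: int, sigma_fraction: float = 0.35):
        super().__init__()
        # Pixel coordinates measured from the patch centre.
        coords = torch.arange(size, dtype=torch.float32) - (size - 1) / 2
        yy, xx = torch.meshgrid(coords, coords, indexing="ij")
        sigma = sigma_fraction * size  # assumption: sigma scales with the patch side
        init = torch.exp(-(xx**2 + yy**2) / (2 * sigma**2))
        self.mask = nn.Parameter(init)  # trainable, starts as a centred Gaussian

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The (size, size) mask broadcasts over the batch and channel axes.
        return x * self.mask

focus = GaussianFocus(size=12)
out = focus(torch.randn(4, 41, 12, 12))
print(tuple(out.shape), float(focus.mask[5, 5]), float(focus.mask[0, 0]))
```

At initialisation the mask is close to 1 at the patch centre and decays towards the corners, so features from neighbouring cells are down-weighted before any cross-channel step.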

In [None]:
model_att = MCIANet(
    in_channels=N_CHANNELS,
    input_size=INPUT_SIZE,
    stem_dim=STEM_WIDTH,
    n_heads=4,
    stem_blocks=2,
    stage1_blocks=2,
    stage2_blocks=2,
    expansion=2,
    drop_prob=0.05,
    sigma_fraction=0.35,
    spatial_init_mode='ones',
)
summarise('MCIANet (ATT)', model_att, x)

---
## Early Fusion Baselines
These models mix input channels early: in the first convolution, in a 1×1 projection placed before it, or (for MidFusionModel) immediately after a per-channel stem. They serve as the null hypothesis: if channel separability provides no benefit, these should match or exceed the channel-separable models above.

### 5. EarlyFusionModel

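**Key design:** Same architecture and hyperparameters as WideModel (CIM), with the same `stem_width`, `block_width`, and `layer_config`, **except** that all Conv2d operations use `groups=1` (standard convolutions) instead of `groups=in_channels` (depthwise). All 41 channels are mixed together in the very first 3×3 convolution. The parameter count is not preserved: at equal layer widths, a `groups=1` convolution has `in_channels` (here 41) times as many weights as the same layer with `groups=in_channels`, so compare the counts printed for the two models.

**This is the cleanest available ablation:** a single hyperparameter change (`groups`) separates this model from WideModel. A performance difference can therefore be attributed to the depthwise vs. standard convolution choice, bearing in mind that this choice changes model capacity as well as channel mixing.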

**Null hypothesis:** If EarlyFusionModel ≥ WideModel → channel separability provides no benefit for this task.
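
Because `groups` also sets how many input channels each filter sees, it changes the weight count of a layer whose other arguments are identical. The sketch below counts weights for one stem-like convolution. It assumes the first layer maps `C` channels to `C * K` in both models, which may not match the actual implementations.

```python
import torch.nn as nn

C, K = 41, 32  # marker channels and features per marker

def n_weights(groups: int) -> int:
    """Weight count of a 3x3 conv from C to C*K channels (bias omitted)."""
    conv = nn.Conv2d(C, C * K, kernel_size=3, padding=1, groups=groups, bias=False)
    return conv.weight.numel()

depthwise = n_weights(groups=C)  # each filter sees 1 input channel
standard = n_weights(groups=1)   # each filter sees all C input channels
print(depthwise, standard, standard // depthwise)  # 11808 484128 41
```

The factor of 41 is worth keeping in mind when reading the parameter counts in the summary table.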

In [None]:
model_ef = EarlyFusionModel(
    in_channels=N_CHANNELS,
    stem_width=STEM_WIDTH,
    block_width=2,
    layer_config=[1, 1],
    late_fusion=False,
    drop_prob=0.05,
)
summarise('EarlyFusionModel', model_ef, x)

### 6. MidFusionModel

**Key design:** The stem is **channel-separable** (depthwise conv, identical to WideModel), but all subsequent ConvBlocks use **standard convolutions** that freely mix channels. Fusion happens after the first layer of per-channel spatial feature extraction.

**Research question:** Is one layer of per-channel spatial processing sufficient to capture the per-marker spatial statistics, after which standard cross-channel processing is fine? Or does maintaining separability throughout the network matter?

**Expected outcome:** Should fall between WideModel and EarlyFusionModel. If MidFusion ≈ WideModel, the benefit of separability comes mainly from the stem. If MidFusion ≈ EarlyFusion, the benefit requires separability in the processing layers too.
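
The fusion point can be located empirically by perturbing one marker and checking how many markers' feature blocks respond at each depth. The two-layer stack below is a toy illustration of the mid-fusion pattern, not `MidFusionModel`. `K` is kept small to keep the standard convolution light.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, K = 41, 8

mid_fusion = nn.Sequential(
    nn.Conv2d(C, C * K, 3, padding=1, groups=C),  # per-marker stem
    nn.ReLU(),
    nn.Conv2d(C * K, C * K, 3, padding=1),        # groups=1: mixes all markers
)

x = torch.randn(2, C, 24, 24)
x_perturbed = x.clone()
x_perturbed[:, 0] += 1.0  # change marker 0 only

def blocks_touched(diff: torch.Tensor) -> int:
    """Count the K-wide feature blocks in which any value changed."""
    per_block = diff.abs().reshape(diff.shape[0], C, -1).amax(dim=(0, 2))
    return int((per_block > 1e-6).sum())

with torch.no_grad():
    after_stem = blocks_touched(mid_fusion[0](x) - mid_fusion[0](x_perturbed))
    after_block = blocks_touched(mid_fusion(x) - mid_fusion(x_perturbed))

print(after_stem, after_block)
```

After the stem only one block responds. After the first standard convolution every block does, so per-marker structure is preserved for exactly one layer.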

In [None]:
model_mf = MidFusionModel(
    in_channels=N_CHANNELS,
    stem_width=STEM_WIDTH,
    block_width=2,
    layer_config=[1, 1],
    drop_prob=0.05,
)
summarise('MidFusionModel', model_mf, x)

### 7. ProjectionFusionModel

**Key design:** A **1×1 convolution** (kernel_size=1, no spatial context) collapses all input channels into the joint feature space **before any spatial processing**. All subsequent layers operate on already-fused features.

**This is the most aggressive early fusion strategy:** unlike EarlyFusionModel (which uses a 3×3 stem that at least has local spatial context when mixing channels), the 1×1 projection has zero spatial receptive field. The model never has access to per-channel spatial structure.

**Interpretation:** If ProjectionFusion ≈ EarlyFusion, spatial context at the fusion point doesn't matter and the damage is done simply by mixing channels. If ProjectionFusion << EarlyFusion, local spatial context at the fusion point does matter: mixing channels before any spatial filtering discards information that a 3×3 mixed-channel conv retains, so the order of spatial and channel mixing is important.
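
The receptive-field difference between the two fusion layers can be seen by perturbing a single pixel and counting the output pixels that change. This compares two standalone convolutions only. The fused width `D = 64` is a hypothetical value, not a setting from either model.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
C, D = 41, 64  # marker channels; D is an illustrative fused feature width

proj_1x1 = nn.Conv2d(C, D, kernel_size=1)             # ProjectionFusion-style mixing
mix_3x3 = nn.Conv2d(C, D, kernel_size=3, padding=1)   # EarlyFusion-style mixing

x = torch.randn(1, C, 24, 24)
x_perturbed = x.clone()
x_perturbed[0, :, 10, 10] += 1.0  # perturb one pixel in every marker

def pixels_changed(conv: nn.Conv2d) -> int:
    """Number of output locations affected by the single-pixel perturbation."""
    with torch.no_grad():
        diff = (conv(x) - conv(x_perturbed)).abs().amax(dim=1)  # (1, 24, 24)
    return int((diff > 1e-6).sum())

n_1x1, n_3x3 = pixels_changed(proj_1x1), pixels_changed(mix_3x3)
print(n_1x1, n_3x3)  # 1 9
```

The 1×1 projection touches only the perturbed location, while the 3×3 convolution spreads it over the surrounding 3×3 neighbourhood.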

In [None]:
model_pf = ProjectionFusionModel(
    in_channels=N_CHANNELS,
    stem_width=STEM_WIDTH,
    block_width=2,
    layer_config=[1, 1],
    drop_prob=0.05,
)
summarise('ProjectionFusionModel', model_pf, x)

### 8. ResNetBaseline

**Key design:** A lightweight ResNet-style architecture with standard early-fusion convolutions, scaled for small patches (no strided conv at input, 2 residual stages only). Takes all `in_channels` as a standard multi-channel input with no assumptions about channel structure.

**Role:** External reference point outside the WideModel family. A well-understood architecture that makes no assumptions about channel separability whatsoever and treats the multiplex image identically to how a standard CNN treats an RGB image.

**Note:** Output dimension is `base_width * 4 = 256`, independent of `n_channels`. The neck `in_channels` must be set to 256 in the experiment config.
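
The following toy network illustrates why the feature dimension depends on `base_width` alone. `TinyResNet` is hypothetical and is not the notebook's `ResNetBaseline`. It assumes a stride-1 stem and two stride-2 residual stages that each double the width, which is one way to arrive at `base_width * 4`.

```python
import torch
import torch.nn as nn

class TinyResNet(nn.Module):
    """Hypothetical two-stage residual net for small patches."""

    def __init__(self, in_channels: int, base_width: int = 64):
        super().__init__()
        # Only the stem depends on in_channels; everything after it is fixed-width.
        self.stem = nn.Conv2d(in_channels, base_width, 3, padding=1)  # stride 1
        self.stage1 = self._stage(base_width, base_width * 2)
        self.stage2 = self._stage(base_width * 2, base_width * 4)
        self.pool = nn.AdaptiveAvgPool2d(1)

    @staticmethod
    def _stage(c_in: int, c_out: int) -> nn.ModuleDict:
        return nn.ModuleDict({
            "body": nn.Sequential(
                nn.Conv2d(c_in, c_out, 3, stride=2, padding=1), nn.ReLU(),
                nn.Conv2d(c_out, c_out, 3, padding=1),
            ),
            # 1x1 strided conv so the skip path matches the body's output shape.
            "skip": nn.Conv2d(c_in, c_out, 1, stride=2),
        })

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = torch.relu(self.stem(x))
        for stage in (self.stage1, self.stage2):
            x = torch.relu(stage["body"](x) + stage["skip"](x))
        return self.pool(x).flatten(1)

shapes = {n: tuple(TinyResNet(n)(torch.randn(4, n, 24, 24)).shape) for n in (3, 41)}
print(shapes)
```

A 3-channel and a 41-channel input both yield a 256-dimensional feature vector, which is why the neck width is fixed in the config rather than derived from `n_channels`.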

In [None]:
model_rn = ResNetBaseline(
    in_channels=N_CHANNELS,
    base_width=64,
    drop_prob=0.05,
)
summarise('ResNetBaseline', model_rn, x)

---
## Summary Comparison

In [None]:
models = [
    ('WideModel (CIM)',          model_cim,    'Channel-separable throughout',              'Proposed'),
    ('SharedStemModel (SSM)',    model_ssm,    'Shared stem, channel-separable',            'Proposed variant'),
    ('WideModelAttention (CIMATT)', model_cimatt, 'Channel-separable + cross-channel attn',    'Proposed + attention'),
    ('MCIANet (ATT)',            model_att,    'Shared stem + SpatialFocus + multi-scale attn', 'Proposed + focus'),
    ('EarlyFusionModel',        model_ef,     'Standard convs (groups=1) throughout',      'Ablation'),
    ('MidFusionModel',          model_mf,     'Depthwise stem, then standard convs',       'Ablation'),
    ('ProjectionFusionModel',   model_pf,     '1x1 pixel fusion first, then standard convs','Ablation'),
    ('ResNetBaseline',          model_rn,     'Standard ResNet, no channel assumptions',   'External baseline'),
]

print(f"{'Model':<30} {'Params':>10}  {'Fusion strategy':<45} {'Role'}")
print('-' * 110)
for name, model, strategy, role in models:
    n_params = sum(p.numel() for p in model.parameters())
    model.eval()
    with torch.no_grad():
        out = model(x)[0]
    print(f"{name:<30} {n_params:>10,}  {strategy:<45} {role}")