# Channel Importance Analysis

Three complementary views on whether early fusion models ignore input channels:

1. **Weight norm analysis** — directly inspect the first-layer conv weights; requires only the `.pth` checkpoint
2. **Channel ablation** — zero out one marker at a time and measure embedding shift; requires the model + h5 data
3. **Embedding effective rank** — how many dimensions are actually used in the representation space

In [None]:
import sys
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
import h5py
from pathlib import Path

sys.path.insert(0, str(Path('..').resolve()))
from src.models import WideModel
from src.models_early_fusion import ResNetBaseline

In [None]:
# ── Configuration ─────────────────────────────────────────────────────────────
# Ablation / patch settings.
PATCH_SIZE       = 32    # side length of the patch window, in pixels
ABLATION_N_CELLS = 512   # how many val cells to use for ablation (keep small for speed)

# Pick the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# Backbone constructor arguments (must match the run's backbone config).
CIM_CFG = dict(
    in_channels=41,
    stem_width=32,
    block_width=4,
    layer_config=[1, 1],
)
RESNET_CFG = dict(
    in_channels=41,
    base_width=64,
)

# Directory holding one sub-directory per training run.
BASE = Path('/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/src/MCA/z_RUNS')

# Display label → run directory. The labels are reused as lookup keys further down.
RUNS = {
    'CIM (channel-sep.)': BASE / 'CODEX_cHL_CIM_VICReg',
    'ResNet (early fus.)': BASE / 'CODEX_cHL_ResNet_VICReg',
}

# Dataset files: marker list, image container and validation index list.
MARKERS_TXT = Path('/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/data/MCI_data/h5_files/CODEX_cHL/used_markers.txt')
H5_FILE     = Path('/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/data/MCI_data/h5_files/CODEX_cHL/CODEX_cHL.h5')
VAL_IDX_TXT = Path('/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/data/MCI_data/h5_files/CODEX_cHL/val.txt')

print(f'Using device: {DEVICE}')

## Helper — load backbone from checkpoint

In [None]:
def load_backbone(run_dir: Path, backbone: torch.nn.Module) -> torch.nn.Module:
    """Load backbone weights from an mmengine checkpoint in run_dir.

    The file ``run_dir / 'last_checkpoint'`` is read as a pointer: its
    stripped text is taken as the path of the checkpoint to load. If that
    path does not exist on this machine (e.g. the run directory is viewed
    through a different mount than the one used for training), the file
    with the same name directly inside ``run_dir`` is used instead.

    Args:
        run_dir: Directory containing ``last_checkpoint``.
        backbone: Freshly constructed module whose parameter names match the
            checkpoint entries stored under the ``backbone.`` prefix.

    Returns:
        The same ``backbone`` instance, with weights loaded and switched to
        eval mode.

    Raises:
        FileNotFoundError: If the pointer file or the checkpoint is missing.
        KeyError: If the checkpoint has no ``'state_dict'`` entry.
        RuntimeError: If the filtered weights do not match ``backbone``
            exactly (``strict=True``).
    """
    prefix = 'backbone.'
    run_dir = Path(run_dir)

    ckpt_path = Path((run_dir / 'last_checkpoint').read_text().strip())
    if not ckpt_path.is_file():
        # The pointer holds the path recorded when it was written; fall back
        # to the same file name next to the pointer when that path is stale.
        relocated = run_dir / ckpt_path.name
        if relocated.is_file():
            ckpt_path = relocated

    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    ckpt = torch.load(ckpt_path, map_location='cpu')
    state = ckpt['state_dict']

    # Keep only the backbone entries and strip the prefix so the keys line up
    # with the bare module's own parameter names.
    bb_state = {k[len(prefix):]: v for k, v in state.items() if k.startswith(prefix)}
    backbone.load_state_dict(bb_state, strict=True)
    backbone.eval()
    return backbone

# Marker (channel) names as strings, read from the comma-delimited marker file;
# used below as tick labels.
marker_names = np.loadtxt(
    MARKERS_TXT,
    delimiter=',',
    dtype=str,
)
print(f'{len(marker_names)} markers: {marker_names}')

---
## 1 · Weight norm analysis

For each model the first conv layer determines which input channels are attended to.  
- **ResNet stem**: `Conv2d(41 → base_width, 3×3, groups=1)` — weight shape `[base_width, 41, 3, 3]`  
  → per-channel norm = `w.norm(dim=(0,2,3))`  
- **CIM stem**: `Conv2d(41 → 41×stem_width, 3×3, groups=41)` — weight shape `[41×stem_width, 1, 3, 3]`  
  → each group of `stem_width` output rows belongs to one input channel  
  → per-channel norm = `w.view(41, stem_width, 1, 3, 3).norm(dim=(1,2,3,4))`

In [None]:
# Build both backbones and restore their trained weights.
cim    = load_backbone(RUNS['CIM (channel-sep.)'],  WideModel(**CIM_CFG))
resnet = load_backbone(RUNS['ResNet (early fus.)'], ResNetBaseline(**RESNET_CFG))

C = CIM_CFG['in_channels']   # number of input channels
D = CIM_CFG['stem_width']    # stem filters per input channel (not used further in this cell)

# ── CIM: depthwise stem weight, expected [C*D, 1, 3, 3].
# Rows [i*D : (i+1)*D] belong to input channel i, so reshaping to C rows
# collects every weight of one channel into a single row.
w_cim     = cim.stem[0].weight
cim_rows  = w_cim.reshape(C, -1)
norms_cim = cim_rows.norm(dim=1).detach().numpy()

# ── ResNet: standard conv stem weight, expected [base_width, C, 3, 3].
# Bring the input-channel axis to the front, then flatten the rest per channel.
w_res     = resnet.stem[0].weight
res_rows  = w_res.permute(1, 0, 2, 3).reshape(C, -1)
norms_res = res_rows.norm(dim=1).detach().numpy()

# Scale each model by its own largest norm so both lie in [0, 1].
norms_cim_n, norms_res_n = (v / v.max() for v in (norms_cim, norms_res))

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(16, 7), sharex=True)
x = np.arange(len(marker_names))

# Both panels share one x-order: channels sorted by descending CIM norm.
# Computed once here (the ordering does not depend on which panel is drawn).
order = np.argsort(norms_cim_n)[::-1]

for ax, norms, label, color in zip(
    axes,
    [norms_cim_n, norms_res_n],
    ['CIM  (channel-sep.)', 'ResNet (early fus.)'],
    ['steelblue', 'tomato'],
):
    sorted_norms = norms[order]
    mean_norm = sorted_norms.mean()
    ax.bar(x, sorted_norms, color=color, alpha=0.85, edgecolor='white', linewidth=0.4)
    # Dashed reference line at the mean normalised norm of this model.
    ax.axhline(mean_norm, color='black', lw=1.2, ls='--', label=f'mean = {mean_norm:.3f}')
    ax.set_ylabel('Normalised stem\nweight norm', fontsize=10)
    ax.set_title(label, fontsize=11, fontweight='bold')
    ax.set_ylim(0, 1.08)
    ax.legend(fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(marker_names[order], rotation=75, ha='right', fontsize=8)

fig.suptitle('First-layer weight norm per input channel', fontsize=13, y=1.01)
plt.tight_layout()
plt.savefig('../notebooks/channel_weight_norms.pdf', bbox_inches='tight')
plt.show()

# Summary statistics: count channels whose normalised norm falls below 0.1.
for name, norms in [('CIM', norms_cim_n), ('ResNet', norms_res_n)]:
    low = (norms < 0.1).sum()
    print(f'{name}: {low}/{len(norms)} channels with norm < 0.1 (effectively ignored)')

---
## 2 · Channel ablation

Zero out one channel at a time across a batch of val cells.  
Sensitivity = `1 - mean cosine_similarity(baseline_embedding, ablated_embedding)`  
A channel the model ignores → sensitivity ≈ 0.

In [None]:
# ── Load a small batch of raw patches from the h5 file ───────────────────────
def _crop_padded_patch(img, d1, d2, half):
    """Cut the (2*half, 2*half, C) window of `img` centred on (d1, d2).

    `img` is indexed as [H, W, C]. Where the window extends past the image
    border the missing rows/columns are filled with zeros, so every returned
    patch has the same spatial size.
    """
    # Plain Python ints: an unsigned coordinate dtype must not wrap around
    # when `half` is subtracted near the top/left border.
    d1, d2 = int(d1), int(d2)
    lo1, hi1 = d1 - half, d1 + half
    lo2, hi2 = d2 - half, d2 + half

    # Part of the window that actually lies inside the image.
    s1, e1 = max(0, lo1), min(img.shape[0], hi1)
    s2, e2 = max(0, lo2), min(img.shape[1], hi2)
    crop = img[s1:e1, s2:e2, :]   # [H, W, C]

    # Pad by exactly what the clipping removed on each side
    # (all zeros for cells far enough from the border).
    return np.pad(
        crop,
        ((s1 - lo1, hi1 - e1), (s2 - lo2, hi2 - e2), (0, 0)),
        mode='constant',
    )


val_idx = np.atleast_1d(np.loadtxt(VAL_IDX_TXT, dtype=int))
rng     = np.random.default_rng(42)
# Never ask for more cells than the validation split contains.
n_cells = min(ABLATION_N_CELLS, len(val_idx))
chosen  = rng.choice(val_idx, size=n_cells, replace=False)
chosen.sort()   # index lists passed to h5py datasets must be in increasing order

half = PATCH_SIZE // 2

with h5py.File(H5_FILE, 'r') as h5:
    DIM1      = h5['coords']['DIM1'][chosen]
    DIM2      = h5['coords']['DIM2'][chosen]
    sample_id = h5['coords']['sample_id'][chosen].astype(str)

    patches = [
        _crop_padded_patch(h5['data'][sid]['image'], d1, d2, half)
        for d1, d2, sid in zip(DIM1, DIM2, sample_id)
    ]

# [N, H, W, C] → [N, C, H, W]
X = torch.from_numpy(np.stack(patches)).permute(0, 3, 1, 2).float()
print(f'Loaded {X.shape[0]} patches, shape {tuple(X.shape)}')

In [None]:
import torch.nn.functional as F

def channel_sensitivity(backbone, X, device):
    """Score how much each input channel influences the backbone embedding.

    For every channel index ``c`` along ``X``'s second axis, that channel is
    set to zero in a copy of ``X`` and the resulting embeddings are compared
    with the unmodified ones. The score is ``1 - mean cosine similarity``.

    Args:
        backbone: Module called as ``backbone(X)``; element ``[0]`` of its
            output is used as the embedding (presumably the backbone returns
            a tuple/list of feature maps — confirm against the model code).
        X: Input batch; channels are on axis 1.
        device: Device on which the forward passes run.

    Returns:
        NumPy array with one score per channel. The backbone is moved back
        to the CPU before returning.
    """
    backbone = backbone.to(device)
    X = X.to(device)

    def _unit_embeddings(batch):
        # Drop up to two trailing singleton axes, then L2-normalise so that a
        # plain dot product equals the cosine similarity.
        feats = backbone(batch)[0].squeeze(-1).squeeze(-1)
        return F.normalize(feats, dim=-1)

    scores = []
    with torch.no_grad():
        reference = _unit_embeddings(X)
        for channel in range(X.shape[1]):
            # Work on a copy so the caller's tensor is never modified.
            masked = X.clone()
            masked[:, channel] = 0.0
            cosine = (reference * _unit_embeddings(masked)).sum(dim=-1).mean().item()
            scores.append(1.0 - cosine)

    backbone.cpu()
    return np.array(scores)

# Score every marker for both backbones on the same patch batch (CIM first).
sens_cim, sens_resnet = (
    channel_sensitivity(model, X, DEVICE) for model in (cim, resnet)
)
print('Done.')

In [None]:
fig, axes = plt.subplots(2, 1, figsize=(16, 7), sharex=True)

# Shared x-order for both panels: channels by descending CIM sensitivity.
order = np.argsort(sens_cim)[::-1]

# (scores, panel title, bar colour) for each model, top panel first.
panels = [
    (sens_cim,    'CIM  (channel-sep.)', 'steelblue'),
    (sens_resnet, 'ResNet (early fus.)', 'tomato'),
]

# `x` holds the bar positions created in the weight-norm plotting cell.
for ax, (sens, label, color) in zip(axes, panels):
    ranked = sens[order]
    avg = ranked.mean()
    ax.bar(x, ranked, color=color, alpha=0.85, edgecolor='white', linewidth=0.4)
    # Dashed reference line at the mean sensitivity of this model.
    ax.axhline(avg, color='black', lw=1.2, ls='--',
               label=f'mean = {avg:.4f}')
    ax.set_ylabel('Ablation sensitivity\n(1 − cos sim)', fontsize=10)
    ax.set_title(label, fontsize=11, fontweight='bold')
    ax.legend(fontsize=9)
    ax.set_xticks(x)
    ax.set_xticklabels(marker_names[order], rotation=75, ha='right', fontsize=8)

fig.suptitle('Channel ablation: embedding sensitivity to zeroing each marker', fontsize=13, y=1.01)
plt.tight_layout()
plt.savefig('../notebooks/channel_ablation.pdf', bbox_inches='tight')
plt.show()

# Print the most and least sensitive channels for each model
for name, sens in [('CIM', sens_cim), ('ResNet', sens_resnet)]:
    ascending = np.argsort(sens)
    top5      = ascending[::-1][:5]
    bottom5   = ascending[:5]
    most  = [(marker_names[i], round(sens[i],4)) for i in top5]
    least = [(marker_names[i], round(sens[i],4)) for i in bottom5]
    print(f'\n{name} — top-5 most sensitive:    {most}')
    print(f'{name} — top-5 least sensitive:   {least}')

---
## 3 · Effective rank of embedding space

Are the ResNet embeddings collapsing into a low-dimensional subspace?  
Uses the saved `val_results.npz` files — no model or data loading needed.

In [None]:
def effective_rank(Z):
    """Roy & Vetterli 2007: exp(entropy of normalised singular value distribution).

    Args:
        Z: 2-D array-like with one embedding per row. Columns are
            mean-centred before the decomposition.

    Returns:
        Tuple ``(erank, S, p)``: the effective rank, the singular values of
        the centred matrix, and the squared singular values normalised to
        sum to one.

    Note:
        If every row of ``Z`` is identical, all singular values are zero and
        the normalisation divides by zero, so ``erank`` and ``p`` are NaN.
    """
    Z = np.asarray(Z)
    centred = Z - Z.mean(axis=0)
    # Only the singular values are needed; skip computing U and Vt.
    S = np.linalg.svd(centred, compute_uv=False)
    energy = S**2
    p = energy / energy.sum()
    # The small epsilon keeps log() finite for exactly-zero entries of p.
    erank = np.exp(-np.sum(p * np.log(p + 1e-12)))
    return erank, S, p

# Effective rank of the saved validation embeddings, one entry per run label.
results = {}
for label, run_dir in RUNS.items():
    # np.load on an .npz returns a lazy archive handle; the context manager
    # closes it as soon as the 'features' array has been read into memory.
    with np.load(run_dir / 'val_results.npz') as npz:
        Z = npz['features']
    er, S, p = effective_rank(Z)
    results[label] = dict(S=S, p=p, erank=er, dim=Z.shape[1])
    print(f'{label}: effective rank = {er:.1f} / {Z.shape[1]}  ({100*er/Z.shape[1]:.1f}% of dims used)')

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 4))

# Line colour per run label (keys must match the keys of `results`).
colors = {'CIM (channel-sep.)': 'steelblue', 'ResNet (early fus.)': 'tomato'}

# Left: singular value spectra, each scaled by its first singular value
ax = axes[0]
for label, res in results.items():
    S_norm = res['S'] / res['S'][0]
    ax.plot(S_norm, label=f"{label}  (erank={res['erank']:.0f}/{res['dim']})",
            color=colors[label], lw=1.5)
ax.set_xlabel('Singular value index')
ax.set_ylabel('Normalised singular value')
ax.set_title('Singular value spectrum')
ax.legend(fontsize=9)
ax.set_yscale('log')

# Right: cumulative variance explained
ax = axes[1]
for label, res in results.items():
    energy  = res['S']**2
    cum_var = np.cumsum(energy) / energy.sum()
    # cum_var[i] is the variance captured by the first i+1 dimensions, so the
    # x-values start at 1 to agree with the 'Number of dimensions' axis.
    n_dims = np.arange(1, len(cum_var) + 1)
    # Smallest number of dimensions whose cumulative variance reaches 90 %.
    k90 = int((cum_var < 0.90).sum()) + 1
    ax.plot(n_dims, cum_var, color=colors[label], lw=1.5,
            label=f'{label}  (90% var @ dim {k90})')
ax.axhline(0.90, color='grey', ls=':', lw=1)
ax.set_xlabel('Number of dimensions')
ax.set_ylabel('Cumulative variance explained')
ax.set_title('Cumulative variance')
ax.legend(fontsize=9)

plt.tight_layout()
plt.savefig('../notebooks/effective_rank.pdf', bbox_inches='tight')
plt.show()