# MCA Experiment Visualisation

End-to-end analysis of a training run from the `.npz` result files.

| Section | What it shows |
|---|---|
| **1. Load Results** | Feature shapes, class distribution |
| **2. Classification Metrics** | Accuracy table + confusion matrix |
| **3. UMAP Embedding** | 2-D projection coloured by label / prediction / confidence |
| **4. Marker Activations** | Per-marker feature means overlaid on UMAP |
| **5. Per-class Marker Profile** | Heatmap of which markers drive each cell-type |

> **Model attention maps** are covered in `attention_maps.ipynb`.

## Configuration
Edit these variables before running the notebook.

In [None]:
# ── User settings: adjust before running ──────────────────────────────────────
# Name of the run folder under RUNS_DIR; also the stem of the dataset config file.
DATASET_NAME = 'CODEX_cHL_CIM_MASK_VP_LONG'
# Root directory that holds the run folders.
RUNS_DIR = '/home/simon_g/isilon_images_mnt/10_MetaSystems/MetaSystemsData/_simon/src/MCA/z_RUNS'

# How many samples of each split are drawn for the UMAP embedding.
N_TRAIN = 10_000
N_VAL = 5_000

# Number of UMAP output dimensions (2 or 3).
UMAP_DIMS = 2
# Feature channels per marker (must match model config).
FEAT_PER_CH = 32
# Seed used for the subsampling generator and for UMAP.
RANDOM_SEED = 42

## Imports

In [None]:
%load_ext autoreload
%autoreload 2

import json
from pathlib import Path

import einops
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import umap
from IPython.display import display
from sklearn.metrics import confusion_matrix

# All files read by this notebook live in <RUNS_DIR>/<DATASET_NAME>.
RUN_DIR = Path(RUNS_DIR).joinpath(DATASET_NAME)
print(f'Run directory: {RUN_DIR}')

---
## 1. Load Results

The `.npz` files are written by `val_hook.py` after training.
Each file contains features, string labels, top-1 / top-2 predictions, per-class logits, and sample IDs.

In [None]:
# Read every array we need while the archives are open, then let the context
# managers close them: np.load on an .npz keeps a file handle open until the
# returned NpzFile is closed.
with np.load(RUN_DIR / 'train_results.npz') as train_file:
    train_features   = train_file['features']
    train_labels_str = train_file['labels_str']
    train_preds_str  = train_file['top1_pred_str']
    # Stored under the key 'logits'; downstream cells treat the row-wise max
    # as a class probability — TODO confirm these are probabilities, not raw logits.
    train_logits     = train_file['logits']

with np.load(RUN_DIR / 'val_results.npz') as val_file:
    val_features   = val_file['features']
    val_labels_str = val_file['labels_str']
    val_preds_str  = val_file['top1_pred_str']
    val_logits     = val_file['logits']
    # The class list is taken from the validation file and used for both splits.
    classes = list(val_file['classes'])

n_classes = len(classes)

print(f'Dataset : {DATASET_NAME}')
print(f'Classes : {classes}')
print()
print(f'Train  → {len(train_features):>7,} cells  |  feature dim: {train_features.shape[1]}')
print(f'Val    → {len(val_features):>7,} cells  |  feature dim: {val_features.shape[1]}')

### Class distribution

In [None]:
# One bar chart per split showing how many cells each class contains,
# with the classes ordered from most to least frequent.
fig, axes = plt.subplots(1, 2, figsize=(14, 4), sharey=False)

panels = zip(axes, (train_labels_str, val_labels_str), ('Train', 'Val'))
for ax, labels, split in panels:
    names, freq = np.unique(labels, return_counts=True)
    ranking = np.argsort(freq)[::-1]
    names, freq = names[ranking], freq[ranking]

    rects = ax.bar(names, freq, color='steelblue', edgecolor='white', linewidth=0.5)

    # Print the count just above each bar; the gap is 1 % of the tallest bar.
    gap = freq.max() * 0.01
    for rect, n_cells in zip(rects, freq):
        x_mid = rect.get_x() + rect.get_width() / 2
        ax.text(x_mid, rect.get_height() + gap, f'{n_cells:,}',
                ha='center', va='bottom', fontsize=8)

    ax.set_title(f'{split} class distribution  (n={len(labels):,})', fontsize=11)
    ax.set_xlabel('Cell type')
    ax.set_ylabel('Count')
    ax.tick_params(axis='x', rotation=40)
    ax.spines[['top', 'right']].set_visible(False)

plt.tight_layout()
plt.show()

---
## 2. Classification Metrics

Metrics are computed by a logistic regression probe trained on the frozen features (see `val_hook.py`).

In [None]:
metrics_path = RUN_DIR / 'metrics.json'

if not metrics_path.exists():
    # No saved metrics: derive a reduced set directly from the stored predictions.
    print(f'metrics.json not found at {metrics_path}\nCompute metrics inline:')
    from sklearn.metrics import accuracy_score, balanced_accuracy_score, f1_score
    for split, labels, preds in (
        ('Train', train_labels_str, train_preds_str),
        ('Val',   val_labels_str,   val_preds_str),
    ):
        acc  = accuracy_score(labels, preds)
        bacc = balanced_accuracy_score(labels, preds)
        f1   = f1_score(labels, preds, average='weighted')
        print(f'  {split}: Acc={acc:.3f}  Bal.Acc={bacc:.3f}  F1={f1:.3f}')
else:
    with open(metrics_path) as f:
        metrics = json.load(f)

    # Table column → key in metrics.json, for the values shown with 3 decimals.
    float_columns = {
        'Top-1 Acc':         'top1_accuracy',
        'Top-2 Acc':         'top2_accuracy',
        'Bal. Acc (top-1)':  'top1_balanced_accuracy',
        'Bal. Acc (top-2)':  'top2_balanced_accuracy',
        'F1 (weighted)':     'f1',
    }

    rows = []
    for split in ('train', 'val'):
        m = metrics[split]
        row = {'Split': split.capitalize()}
        for column, key in float_columns.items():
            row[column] = f"{m[key]:.3f}"
        row['N samples'] = f"{m['n_samples']:,}"
        rows.append(row)

    table = pd.DataFrame(rows).set_index('Split')
    display(table.style.set_caption(
        f'Logistic-regression probe — {n_classes} classes'
    ))

### Confusion matrices

Row-normalised (true label on y-axis, predicted on x-axis).  
The diagonal is the per-class recall.

In [None]:
# Size the figure so that each matrix cell stays readable for any class count.
cell_size = max(0.7, 6 / n_classes)
fig_width  = cell_size * n_classes * 2 + 2
fig_height = cell_size * n_classes + 1
fig, axes = plt.subplots(1, 2, figsize=(fig_width, fig_height))

tick_positions = range(n_classes)
split_data = (
    ('Train', train_labels_str, train_preds_str),
    ('Val',   val_labels_str,   val_preds_str),
)

for ax, (title, labels, preds) in zip(axes, split_data):
    # normalize='true' divides each row by its number of true samples.
    cm = confusion_matrix(labels, preds, labels=classes, normalize='true')
    im = ax.imshow(cm, cmap='Blues', vmin=0, vmax=1)

    # Annotate every cell; switch to white text on the darker cells.
    for (i, j), frac in np.ndenumerate(cm):
        text_color = 'white' if frac > 0.55 else 'black'
        ax.text(j, i, f'{frac:.2f}',
                ha='center', va='center', fontsize=8, color=text_color)

    ax.set_xticks(tick_positions)
    ax.set_yticks(tick_positions)
    ax.set_xticklabels(classes, rotation=45, ha='right', fontsize=9)
    ax.set_yticklabels(classes, fontsize=9)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title(f'{title} confusion matrix (row-normalised)', fontsize=11)
    plt.colorbar(im, ax=ax, shrink=0.8)

plt.tight_layout()
plt.show()

---
## 3. UMAP Embedding

UMAP projects the high-dimensional features into 2-D for visual inspection.
We subsample from the full split to keep the embedding tractable.

In [None]:
# Seeded generator shared by every subsampling call in this notebook.
rng = np.random.default_rng(RANDOM_SEED)


def _subsample(feat, labels, preds, logits, n):
    """Return the same random selection of at most ``n`` rows from each array.

    A single random permutation of the row indices of ``feat`` is drawn from
    the module-level ``rng`` and its first ``n`` entries are used to index all
    four arrays, so the returned rows stay aligned with each other.
    """
    keep = rng.permutation(len(feat))[:n]
    return tuple(arr[keep] for arr in (feat, labels, preds, logits))

# Draw the train subsample first, then the val subsample, from the shared generator.
tr_feat, tr_lab, tr_pred, tr_logit = _subsample(
    train_features, train_labels_str, train_preds_str, train_logits, N_TRAIN,
)
va_feat, va_lab, va_pred, va_logit = _subsample(
    val_features, val_labels_str, val_preds_str, val_logits, N_VAL,
)

# Stack both splits (train rows first) so they share one embedding.
all_feat  = np.vstack((tr_feat, va_feat))
all_lab   = np.concatenate((tr_lab, va_lab))
all_pred  = np.concatenate((tr_pred, va_pred))
all_split = np.array(['Train'] * len(tr_feat) + ['Val'] * len(va_feat))

# Row-wise maximum of the stored per-class scores, used as the confidence value.
all_conf = np.concatenate((tr_logit.max(axis=1), va_logit.max(axis=1)))

_is_correct = all_lab == all_pred
all_corr    = np.where(_is_correct, 'Correct', 'Incorrect')

print(f'Subsampled: {len(tr_feat):,} train + {len(va_feat):,} val = {len(all_feat):,} total')
print(f'Subsample accuracy: {_is_correct.mean():.3f}')

In [None]:
# Fit one UMAP model on the combined train + val subsample.
print('Fitting UMAP…')
reducer = umap.UMAP(
    n_neighbors=15,
    min_dist=0.05,
    n_components=UMAP_DIMS,
    metric='euclidean',
    random_state=RANDOM_SEED,
    # umap-learn forces single-threaded execution (and warns) whenever
    # random_state is set, so request one job explicitly instead of asking
    # for parallelism that is silently discarded. Drop random_state if
    # speed matters more than a reproducible embedding.
    n_jobs=1,
    verbose=True,
)
embedding = reducer.fit_transform(all_feat)
print(f'Done. Embedding shape: {embedding.shape}')

In [None]:
# Build a single DataFrame used by all downstream plots
df = pd.DataFrame({
    'x':          embedding[:, 0],
    'y':          embedding[:, 1],
    'label':      all_lab,
    'predicted':  all_pred,
    'correct':    all_corr,
    'confidence': all_conf,
    'split':      all_split,
})
if UMAP_DIMS == 3:
    df['z'] = embedding[:, 2]

# Consistent colour palette (same colour = same class, across all plots).
# The palette covers predicted classes as well as true labels: the error plot
# colours points by 'predicted', and a class that only ever shows up as a
# prediction in the subsample would otherwise have no entry in the map.
_palette = px.colors.qualitative.Alphabet
_plotted_classes = sorted(set(df['label']) | set(df['predicted']))
label_colors = {
    lab: _palette[i % len(_palette)]
    for i, lab in enumerate(_plotted_classes)
}

SCATTER_KW = dict(  # shared keyword arguments for scatter plots
    x='x', y='y',
    symbol='split', symbol_map={'Train': 'circle', 'Val': 'cross'},
    hover_data=['label', 'predicted', 'correct', 'confidence', 'split'],
    width=1100, height=700,
    opacity=0.65,
)

def _style(fig):
    fig.update_traces(marker=dict(size=4, line=dict(width=0)))
    fig.update_layout(xaxis_title='UMAP-1', yaxis_title='UMAP-2')
    return fig

### 3a. Coloured by true cell type

In [None]:
# Scatter of the embedding, one colour per true class (palette shared with the other plots).
fig = _style(
    px.scatter(
        df,
        color='label',
        color_discrete_map=label_colors,
        title='UMAP — true cell type',
        **SCATTER_KW,
    )
)
fig.update_layout(legend_title='Cell type')
fig.show()

### 3b. Correct vs. incorrect predictions

Green points are correctly classified; red points are errors.  
Clusters of red points suggest confusion between specific cell types.

In [None]:
# Order the rows so the misclassified cells come last: they are then drawn
# after the correct ones and stay visible instead of being covered.
df_sorted = pd.concat(
    [df[df['correct'] == outcome] for outcome in ('Correct', 'Incorrect')]
)

# Same shared scatter settings, but with a slightly lower opacity.
correctness_kw = {**SCATTER_KW, 'opacity': 0.55}

fig = px.scatter(
    df_sorted,
    color='correct',
    color_discrete_map={'Correct': '#2ecc71', 'Incorrect': '#e74c3c'},
    category_orders={'correct': ['Correct', 'Incorrect']},
    title='UMAP — prediction correctness',
    **correctness_kw,
)
_style(fig).update_layout(legend_title='Prediction')
fig.show()

### 3c. Coloured by classifier confidence

Confidence = max predicted probability from the logistic regression probe.  
Low-confidence regions (red/yellow) indicate class boundaries or ambiguous cells.

In [None]:
# Colour every point by its confidence value on a red → yellow → green scale.
# Points are expected to be drawn in row order, so sort from high to low
# confidence: the least confident points are rendered last and end up on top.
# (An ascending sort would bury them under the confident majority.)
# The split symbols are dropped here so all points share one continuous-colour trace.
fig = px.scatter(
    df.sort_values('confidence', ascending=False),
    color='confidence',
    color_continuous_scale='RdYlGn',
    range_color=[0, 1],
    title='UMAP — classifier confidence (max class probability)',
    **{k: v for k, v in SCATTER_KW.items() if k not in ('symbol', 'symbol_map')},
)
_style(fig)
fig.show()

### 3d. Error analysis — incorrect predictions only

Shows *where* in the embedding errors occur and *what* the model predicted instead.

In [None]:
# Keep only the misclassified cells and report their share of the subsample.
df_err = df.loc[df['correct'] == 'Incorrect'].copy()
n_err, n_tot = len(df_err), len(df)
err_pct = 100 * n_err / n_tot
print(f'Errors: {n_err:,} / {n_tot:,}  ({err_pct:.1f} %)')

# Colour each error by the class that was predicted instead of the true one.
fig = _style(
    px.scatter(
        df_err,
        color='predicted',
        color_discrete_map=label_colors,
        title=f'Errors only (n={n_err:,}) — coloured by predicted class',
        **SCATTER_KW,
    )
)
fig.update_layout(legend_title='Predicted as')
fig.show()

---
## 4. Marker Activations in UMAP Space

The feature vector is structured as `(C × FEAT_PER_CH,)` where `C` is the number of markers.  
We take the mean over the `FEAT_PER_CH` channels per marker to get a scalar activation per cell per marker,
then overlay this on the UMAP.

**Requires** the dataset config to resolve marker names.

In [None]:
from mmengine import Config
from mmengine.registry import DATASETS

# Build the training dataset from the config stored in the run directory;
# it is only needed here for its marker-name → index mapping.
cfg_path = RUN_DIR / f'{DATASET_NAME}.py'
cfg      = Config.fromfile(str(cfg_path))
dataset  = DATASETS.build(cfg['train_dataset'])

markers   = [*dataset.marker2idx.keys()]
n_markers = len(markers)
print(f'Markers ({n_markers}): {markers}')

# The next cell reshapes the features as (markers × FEAT_PER_CH); warn early
# if the stored feature dimension does not factor that way.
expected_feat_dim = n_markers * FEAT_PER_CH
actual_feat_dim   = all_feat.shape[1]
if actual_feat_dim != expected_feat_dim:
    print(f'\nWARNING: expected {expected_feat_dim} = {n_markers} × {FEAT_PER_CH}'
          f' but features have dim {actual_feat_dim}.')
    print('Adjust FEAT_PER_CH in the Configuration cell.')

In [None]:
# Split each feature vector into C marker groups of F channels, (N, C*F) → (N, C, F),
# then average over the F channels to get one scalar per cell and marker: (N, C).
per_marker = einops.rearrange(
    all_feat, 'N (C F) -> N C F', C=len(markers), F=FEAT_PER_CH
)
activations = per_marker.mean(axis=-1)

# One DataFrame column per marker, taken from that marker's index in the dataset mapping.
for name in markers:
    df[name] = activations[:, dataset.marker2idx[name]]

print(f'Added {len(markers)} marker activation columns to the DataFrame.')
display(df[markers].describe().round(3))

### Marker activation grid

Each sub-plot shows the UMAP coloured by one marker's activation.  
Colour is clipped to the 2nd–98th percentile range to suppress outliers.

In [None]:
# Grid of small UMAP panels, one per marker, N_COLS panels per row.
N_COLS = 4
N_ROWS = int(np.ceil(len(markers) / N_COLS))

fig = make_subplots(
    rows=N_ROWS, cols=N_COLS,
    subplot_titles=markers,
    shared_xaxes=True, shared_yaxes=True,
    vertical_spacing=0.04, horizontal_spacing=0.02,
)

# Only the first panel shows a colour bar.
first_colorbar = dict(title='activation', thickness=10, len=0.3, y=0.85)

for i, marker in enumerate(markers):
    grid_row = i // N_COLS + 1
    grid_col = i % N_COLS + 1
    is_first = (i == 0)

    # Each panel gets its own colour limits from the 2nd and 98th percentiles.
    vals = df[marker].values
    vmin, vmax = np.percentile(vals, [2, 98])

    marker_style = dict(
        size=2,
        color=vals,
        colorscale='Viridis',
        cmin=vmin, cmax=vmax,
        showscale=is_first,
        colorbar=first_colorbar if is_first else {},
    )
    trace = go.Scattergl(
        x=df['x'], y=df['y'],
        mode='markers',
        marker=marker_style,
        showlegend=False,
        hovertemplate=f'<b>{marker}</b>: %{{marker.color:.3f}}<br>label: %{{text}}<extra></extra>',
        text=df['label'],
    )
    fig.add_trace(trace, row=grid_row, col=grid_col)

# Hide ticks, grid lines and zero lines on both axes of every panel.
for update_axes in (fig.update_xaxes, fig.update_yaxes):
    update_axes(showticklabels=False, showgrid=False, zeroline=False)

fig.update_layout(
    title='Marker activations on UMAP',
    height=260 * N_ROWS,
    width=1100,
)
fig.show()

### Single-marker interactive view

Change `MARKER` to inspect any one marker in detail.

In [None]:
MARKER = markers[0]  # ← change this

# Colour limits from the 2nd and 98th percentiles, so outliers do not dominate the scale.
vals = df[MARKER].values
vmin, vmax = np.percentile(vals, [2, 98])

fig = px.scatter(
    df, x='x', y='y',
    color=MARKER,
    color_continuous_scale='Viridis',
    range_color=[vmin, vmax],
    hover_data=['label', 'predicted', MARKER],
    title=f'Marker activation: {MARKER}',
    width=900, height=650, opacity=0.7,
)
# Reuse the shared helper for marker size and axis titles instead of repeating
# the same update_traces / update_layout calls here.
_style(fig)
fig.show()

---
## 5. Per-class Marker Profile

Shows which markers are most active (or uniquely active) for each cell type.

- **Raw**: mean activation per (class, marker) pair.
- **Z-scored**: each marker's activations are standardised across classes,
  so a high value means that class has unusually high activation for that marker.

In [None]:
from scipy.stats import zscore

# Mean activation of every marker within every true class.
class_means = (
    df.groupby('label')[markers]
    .mean()
    .T   # markers as rows, classes as columns
)

# Z-score each marker (row) across classes. zscore is applied to the raw array
# and the result re-wrapped: DataFrame.apply(zscore, axis=1) yields a Series of
# ndarrays (zscore returns plain arrays, which apply does not expand), whereas
# the heatmap and the ranking table need a DataFrame with the same index/columns.
class_z = pd.DataFrame(
    zscore(class_means.to_numpy(), axis=1),
    index=class_means.index,
    columns=class_means.columns,
)

# Keep individual heatmap cells readable for any number of markers / classes.
cell_h = max(0.35, 6 / len(markers))
cell_w = max(0.7,  6 / n_classes)

fig, axes = plt.subplots(
    1, 2,
    figsize=(cell_w * n_classes * 2 + 3, cell_h * len(markers) + 1),
)

for ax, data, title, cmap, center in [
    (axes[0], class_means, 'Mean activation (raw)',                  'YlOrRd', None),
    (axes[1], class_z,     'Z-scored activation (across classes)',   'RdBu_r', 0.0),
]:
    if center is None:
        # Sequential colormap: span the full data range.
        vmin, vmax = data.values.min(), data.values.max()
    else:
        # Diverging colormap: symmetric limits keep zero at the midpoint.
        # nanmax skips markers whose z-score is undefined (no spread across
        # classes), which would otherwise turn both limits into NaN.
        vabs = np.nanmax(np.abs(data.values))
        vmin, vmax = -vabs, vabs

    im = ax.imshow(data.values, aspect='auto', cmap=cmap, vmin=vmin, vmax=vmax)
    plt.colorbar(im, ax=ax, shrink=0.6)

    ax.set_xticks(range(len(data.columns)))
    ax.set_xticklabels(data.columns, rotation=45, ha='right', fontsize=9)
    ax.set_yticks(range(len(data.index)))
    ax.set_yticklabels(data.index, fontsize=9)
    ax.set_xlabel('Cell type')
    ax.set_ylabel('Marker')
    ax.set_title(title, fontsize=11)

plt.tight_layout()
plt.show()

### Top discriminative markers per class

Ranked by z-scored activation — the markers that are most uniquely high for each class.

In [None]:
N_TOP = 5

# One table row per class: its N_TOP markers with the largest z-scored
# activation, each shown as "marker (z)".
rows = [
    {
        'Class': cls,
        **{
            f'#{rank}': f'{m} ({v:.2f})'
            for rank, (m, v) in enumerate(class_z[cls].nlargest(N_TOP).items(), start=1)
        },
    }
    for cls in sorted(class_z.columns)
]

top_table = pd.DataFrame(rows).set_index('Class')
display(top_table.style.set_caption(
    f'Top {N_TOP} markers per class (z-scored activation)'
))

### Per-marker expression boxplot by class

Change `MARKER` to inspect any marker's distribution across all classes.

In [None]:
MARKER = markers[0]  # ← change this

# Box plot of the chosen marker's activation for each true class, using the shared class colours.
box_kw = dict(
    x='label',
    y=MARKER,
    color='label',
    color_discrete_map=label_colors,
    points='outliers',
    title=f'Activation distribution: {MARKER}',
    labels={'label': 'Cell type', MARKER: 'Activation'},
    width=900,
    height=500,
)
fig = px.box(df, **box_kw)
# The x-axis already names the classes, so the legend is redundant.
fig.update_layout(showlegend=False, xaxis_tickangle=-35)
fig.show()

---
> **Next steps**  
> - For spatial attention masks and channel cross-attention maps, see **`attention_maps.ipynb`**.  
> - For raw dataset exploration (patch browsing, marker coverage), see **`dataset_exploration.ipynb`**.