# 4-Beam Cross-Talk Demixing

Problem: When one beam scans, other detectors pick up leaked signal (cross-talk). Each detector sees a mix of its own beam plus bleed from the other 3.

Solution: Subtract the estimated cross-talk from each detector's signal using calibration data.

In [None]:
import os
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

import mbo_utilities as mbo
from mbo_utilities import imread, imwrite

In [None]:
# ---- configuration ----------------------------------------------------------
# beam labels, in the order used for every per-beam axis in this notebook
beam_labels = ["NW", "SW", "SE", "NE"]

# calibration files are prefixed FS1..FS4; keep a lookup in both directions
fs_to_label = dict(FS1="NW", FS2="SW", FS3="SE", FS4="NE")
label_to_fs = {label: fs for fs, label in fs_to_label.items()}

# data locations
base_recording = r"D:\cj\2025-11-21\raw"
base_calib = r"D:\cj\2025-11-21\calibration"  # adjust path as needed

# temporal projection applied to calibration scans: "mean" or "mip"
cal_mode = "mip"

## Load Data

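- VD0/VD1: two vDAQ channels (`vdaq0`/`vdaq1` for the recording, `FS<n>_VD0`/`FS<n>_VD1` for calibration), each with 2 planes
- Reorder planes to match the beam labels (NW, SW, SE, NE). Note that the recording and calibration loaders use different permutations (`[3, 2, 1, 0]` vs `[2, 3, 1, 0]`), so NW/SW take the VD1 planes in opposite order. Verify this with the alignment check in the Diagnostics section.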

In [None]:
def load_mean_recording(base_path):
    """Return the temporal mean image of every beam, keyed by beam label.

    Reads the ``vdaq0`` and ``vdaq1`` stacks under ``base_path`` (two planes
    each), converts them to float32, and averages over axis 0 (presumably
    time — TODO confirm). The four plane means are mapped onto
    ``beam_labels`` using the plane order [3, 2, 1, 0].
    """
    plane_means = []
    for channel in ("vdaq0", "vdaq1"):
        stack = imread(os.path.join(base_path, channel), mode="r")[:].astype(np.float32)
        avg = stack.mean(axis=0)
        plane_means.append(avg[0])
        plane_means.append(avg[1])
    # [3, 2, 1, 0] is just the reverse of (vd0 plane0, vd0 plane1, vd1 plane0, vd1 plane1)
    return {label: img for label, img in zip(beam_labels, reversed(plane_means))}


def load_full_recording(base_path):
    """Return all four beam timeseries stacked as (4, T, Y, X).

    Planes 0 and 1 of the ``vdaq0`` and ``vdaq1`` stacks (converted to
    float32) are gathered in file order and then reversed — equivalent to
    the permutation [3, 2, 1, 0] — so axis 0 follows ``beam_labels``.
    """
    planes = []
    for channel in ("vdaq0", "vdaq1"):
        stack = imread(os.path.join(base_path, channel), mode="r")[:].astype(np.float32)
        planes.extend([stack[:, 0], stack[:, 1]])
    return np.stack(planes[::-1], axis=0)


def load_calibration(base_path, scan_label, mode="mip"):
    """Load one single-beam calibration scan and project it over axis 0.

    Parameters
    ----------
    base_path : str or path-like
        Directory holding the ``{FS}_VD0`` / ``{FS}_VD1`` calibration stacks.
    scan_label : str
        Beam that was scanned alone; mapped to its ``FS<n>`` file prefix via
        the module-level ``label_to_fs``.
    mode : {"mip", "mean"}
        Projection over axis 0 (presumably time — TODO confirm): maximum
        intensity ("mip") or mean ("mean").

    Returns
    -------
    list of 4 ndarrays
        Per-detector response images, in the order given by ``perm`` below
        (intended to match ``beam_labels``).

    Raises
    ------
    ValueError
        If ``mode`` is not "mip" or "mean".
    KeyError
        If ``scan_label`` has no entry in ``label_to_fs``.
    """
    # Validate before touching disk: previously any value other than "mean"
    # (e.g. a typo such as "max" or "Mean") silently fell through to MIP.
    projections = {"mean": np.mean, "mip": np.max}
    if mode not in projections:
        raise ValueError(f"mode must be 'mean' or 'mip', got {mode!r}")
    project = projections[mode]

    fs_label = label_to_fs[scan_label]

    z0 = imread(os.path.join(base_path, f"{fs_label}_VD0"), mode="r")[:].astype(np.float32)
    z1 = imread(os.path.join(base_path, f"{fs_label}_VD1"), mode="r")[:].astype(np.float32)
    vd0, vd1 = project(z0, axis=0), project(z1, axis=0)

    raw = [vd0[0], vd0[1], vd1[0], vd1[1]]
    # NOTE(review): this permutation differs from the recording loaders'
    # [3, 2, 1, 0] — here NW/SW take the VD1 planes in the opposite order.
    # Presumably the calibration files store planes differently; confirm with
    # the alignment-check cell.
    perm = [2, 3, 1, 0]
    return [raw[k] for k in perm]

In [None]:
# load recording
R = load_full_recording(base_recording)
print(f"Recording shape [det, T, Y, X]: {R.shape}")

# Per-beam temporal mean, computed from the stack already in memory instead of
# re-reading and re-converting both vDAQ files via load_mean_recording. Axis 0
# of R is already in beam_labels order (same [3, 2, 1, 0] plane permutation).
mean_rec = dict(zip(beam_labels, R.mean(axis=1)))

In [None]:
# read every single-beam calibration scan, then assemble the mixing matrix
cal_all = {}
for label in beam_labels:
    cal_all[label] = load_calibration(base_calib, label, cal_mode)

ny, nx = mean_rec["NW"].shape
M_init = np.zeros((4, 4, ny, nx), dtype=np.float32)

# row = beam scanned alone, column = what each detector recorded
for row, scan in enumerate(beam_labels):
    responses = cal_all[scan]
    for col in range(len(beam_labels)):
        M_init[row, col] = responses[col]

print(f"Mixing matrix shape [src, det, Y, X]: {M_init.shape}")

## Mixing Matrix

```
M_init[src, det, y, x]
```
- Row = source beam being scanned alone during calibration
- Column = detector response
- Diagonal = direct signal
- Off-diagonal = cross-talk

In [None]:
# 4x4 grid of calibration responses: rows = scanned source, cols = detector
fig, axes = plt.subplots(4, 4, figsize=(10, 10))
fig.suptitle(f"Mixing Matrix ({cal_mode})")

for (row, col), ax in np.ndenumerate(axes):
    ax.imshow(M_init[row, col], cmap="gray")
    ax.set_xticks([])
    ax.set_yticks([])
    # label only the outer edge: detector names on top, source names on the left
    if row == 0:
        ax.set_title(beam_labels[col])
    if col == 0:
        ax.set_ylabel(beam_labels[row], rotation=0, labelpad=20)

plt.tight_layout()
plt.show()

## Demixing Algorithm

For each detector, subtract the estimated cross-talk from every other source, then clip negative values to 0:

```
S[det] = clip(R[det] - sum(conf[det] * w[src,det] * R[src] for src != det), 0)
```

**Components:**

| Term | Formula | Purpose |
|------|---------|--------|
| `noise_floor[det]` | `percentile(M[det,det], 20)` | per-detector background estimate |
| `conf[det]` | `M[det,det] / (M[det,det] + noise_floor[det])` | pixel-wise confidence (0-1 for non-negative calibration data) |
| `w[src,det]` | `M[src,det] / M[det,det]` | cross-talk ratio |

**No matrix inversion**: each pixel is corrected with a single weighted subtraction, and the mixing matrix is never inverted.

In [None]:
def demix_recording(R, M_init, noise_percentile=20):
    """Confidence-weighted, pixel-wise cross-talk subtraction.

    For each detector ``det``, the estimated leak from every other source is
    subtracted from its raw signal, and negatives are clipped to 0::

        S[det] = clip(R[det] - sum_{src != det} conf[det] * w[src, det] * R[src], 0)

    Parameters
    ----------
    R : array-like, shape (D, T, Y, X)
        Recording, one timeseries per detector (same order as ``M_init``).
    M_init : array-like, shape (D, D, ...)
        Mixing matrix indexed ``[src, det, ...]``; trailing axes are the
        per-pixel calibration images and must broadcast against (Y, X).
    noise_percentile : float, default 20
        Percentile of each diagonal image ``M_init[det, det]`` used as that
        detector's noise floor.

    Returns
    -------
    ndarray, shape (D, T, Y, X)
        Non-negative demixed recording. Same dtype as ``R`` for floating
        input, float32 otherwise.

    Raises
    ------
    ValueError
        If ``R`` is not 4-D or ``M_init`` does not start with ``(D, D)``.
    """
    R = np.asarray(R)
    M_init = np.asarray(M_init)
    if R.ndim != 4:
        raise ValueError(f"R must be 4-D [det, T, Y, X], got shape {R.shape}")
    D, T, Y, X = R.shape
    if M_init.ndim < 2 or M_init.shape[:2] != (D, D):
        raise ValueError(
            f"M_init must start with ({D}, {D}) [src, det], got shape {M_init.shape}"
        )

    eps = 1e-12  # keeps the divisions below from producing 0/0 -> NaN

    # noise floor per detector, taken from that detector's own direct response
    noise_floor = np.array(
        [np.percentile(M_init[i, i], noise_percentile) for i in range(D)],
        dtype=np.float32,
    )

    # Confidence: ~1 where the direct response is well above the noise floor,
    # ~0 near it, so dim pixels get less subtraction. Clipped so negative
    # calibration values cannot push it outside the documented 0-1 range.
    conf = np.zeros((D, Y, X), dtype=np.float32)
    for det in range(D):
        diag = M_init[det, det]
        conf[det] = np.clip(diag / (diag + noise_floor[det] + eps), 0.0, 1.0)

    # Cross-talk ratio: leak of src into det relative to det's own response.
    # Float dtype so integer calibration data does not truncate the ratios.
    w = np.zeros(M_init.shape, dtype=np.result_type(M_init.dtype, np.float32))
    for src in range(D):
        for det in range(D):
            if src != det:
                w[src, det] = M_init[src, det] / (M_init[det, det] + eps)

    # Float output so the in-place subtraction also works for integer recordings.
    out_dtype = R.dtype if np.issubdtype(R.dtype, np.floating) else np.float32
    S = np.empty(R.shape, dtype=out_dtype)
    for det in range(D):
        S[det] = R[det]
        for src in range(D):
            if src == det:
                continue
            coeff = conf[det] * w[src, det]  # (Y, X) map, broadcast over time
            S[det] -= coeff[None, :, :] * R[src]
        np.clip(S[det], 0, None, out=S[det])

    return S

In [None]:
# single demixing pass with the default noise percentile
S_est = demix_recording(R, M_init, noise_percentile=20)
print(f"Demixed shape: {S_est.shape}")

## Results: Raw vs Demixed

In [None]:
fig, axes = plt.subplots(4, 3, figsize=(12, 16))

for i, label in enumerate(beam_labels):
    raw_mip = R[i].max(axis=0)
    demix_mip = S_est[i].max(axis=0)
    diff = raw_mip - demix_mip

    # Symmetric colour limits around 0 so white means "unchanged". The old
    # +/-diff.max() inverted the limits whenever diff.max() < 0, was off-centre
    # when diff had negatives (demixing clips to >= 0, so pixels whose raw MIP
    # is negative come out larger), and was singular for an all-zero diff.
    lim = float(np.abs(diff).max()) or 1.0

    axes[i, 0].imshow(raw_mip, cmap="gray")
    axes[i, 0].set_title(f"{label} Raw MIP")
    axes[i, 0].axis("off")

    axes[i, 1].imshow(demix_mip, cmap="gray")
    axes[i, 1].set_title(f"{label} Demixed MIP")
    axes[i, 1].axis("off")

    axes[i, 2].imshow(diff, cmap="RdBu_r", vmin=-lim, vmax=lim)
    axes[i, 2].set_title(f"{label} Difference")
    axes[i, 2].axis("off")

plt.tight_layout()
plt.show()

## Diagnostics

If cross-talk persists, check:

1. **Alignment** - calibration vs recording spatial match
2. **Contaminated source** - `R[src]` itself contains cross-talk (the subtraction uses the raw source signal, not its demixed estimate)
3. **Confidence under-subtraction** - dim regions subtract less
4. **Wrong noise percentile** - try 10 or 30 instead of 20

In [None]:
# check calibration-recording alignment
#
# The calibration diagonal (a max projection when cal_mode == "mip") and the
# recording's temporal mean have different intensity scales, so a raw difference
# is dominated by brightness rather than misregistration. Normalize each image to
# its own 1st-99th percentile range first, and centre the diverging colormap on 0
# so residual spatial structure (edges, doubled features) stands out.
def _robust_norm(img, lo_pct=1, hi_pct=99):
    """Rescale img so its lo_pct/hi_pct percentiles map to 0/1."""
    lo, hi = np.percentile(img, [lo_pct, hi_pct])
    return (img - lo) / (hi - lo + 1e-12)


fig, axes = plt.subplots(2, 2, figsize=(10, 10))

for i, (label, ax) in enumerate(zip(beam_labels, axes.flat)):
    diff = _robust_norm(M_init[i, i]) - _robust_norm(mean_rec[label])
    # percentile-based limit so a few hot pixels don't wash out the map
    lim = float(np.percentile(np.abs(diff), 99)) or 1.0
    ax.imshow(diff, cmap="RdBu_r", vmin=-lim, vmax=lim)
    ax.set_title(f"{label}: calib - recording")
    ax.axis("off")

plt.suptitle("Alignment Check, normalized (should be ~uniform if aligned)")
plt.tight_layout()
plt.show()

In [None]:
# try iterative (Jacobi-style) demixing
#
# Feeding the demixed output back in as the *raw* input (S <- demix(S))
# re-subtracts the full cross-talk on every pass — ignoring the clip it computes
# (I - C)^k R and over-subtracts. The fixed point we want is
#     S[det] = R[det] - sum_{src != det} conf[det] * w[src, det] * S[src]
# i.e. always start from the raw detector signal and only refine the estimate of
# the sources that leak into it. demix_recording(X) returns
# clip(X[det] - sum conf*w*X[src]), so for each detector we pass an input whose
# row `det` is raw R and whose other rows are the current estimate, and keep only
# row `det` of the result. (This recomputes all rows per call, i.e. ~D x the work
# of a single pass, but needs nothing beyond demix_recording's signature.)
n_iter = 3
S_iter = R  # iteration 0: sources estimated by the raw signal (never mutated)
for iteration in range(n_iter):
    X_in = S_iter.copy()
    S_next = np.empty_like(S_iter)
    for det in range(R.shape[0]):
        X_in[det] = R[det]
        S_next[det] = demix_recording(X_in, M_init)[det]
        X_in[det] = S_iter[det]  # restore before handling the next detector
    S_iter = S_next
    print(f"Iteration {iteration + 1} complete")

In [None]:
# side-by-side MIPs: one demixing pass vs the iterated estimate
fig, axes = plt.subplots(4, 2, figsize=(10, 16))

for row, (label, single, iterated) in enumerate(zip(beam_labels, S_est, S_iter)):
    panels = (
        (single, f"{label} Single-pass"),
        (iterated, f"{label} 3 Iterations"),
    )
    for col, (stack, title) in enumerate(panels):
        ax = axes[row, col]
        ax.imshow(stack.max(axis=0), cmap="gray")
        ax.set_title(title)
        ax.axis("off")

plt.tight_layout()
plt.show()

## Save Results

In [None]:
# # uncomment to save
# output_root = r"D:\cj\2025-11-21\demixed"
# os.makedirs(output_root, exist_ok=True)
# 
# for idx, label in enumerate(beam_labels):
#     outpath = os.path.join(output_root, label)
#     imwrite(
#         lazy_array=imread(S_est[idx]),
#         outpath=outpath,
#         ext=".zarr",
#         overwrite=True,
#     )
#     print(f"Saved {label} to {outpath}")