## Imports

In [None]:
import h5py
import numpy as np
import scipy as sp
import skimage as ski
import matplotlib.pyplot as plt
from scipy.optimize import nnls
from ipywidgets import interact

from fusion import optimize, optimize_with_continuous_time
from baseline import baseline
from utils import (
    RESOURCES_PATH,
    spectral_volume_to_srgb,
    get_discrete_time_decay,
    load_raw_spc,
    calibrate_spc,
    cut_spc,
    bin_spc,
    reconstruct_spc,
)

# Input files for the three-bead sample, all stored under the "beads" resource folder.
_BEADS_DIR = RESOURCES_PATH / "beads"
LAMBDA_PATH = _BEADS_DIR / "575_Lambda_L16.mat"
CMOS_PATH = _BEADS_DIR / "3beads_triangle_w4_rec_Hil2D_FOVcorrected.mat"
SPC_PATH = _BEADS_DIR / "480_3beads_triangle_505_500_575_SPC_calib_cut_binned_tlxy.npz"

# When True, the preprocessing cell rebuilds the reconstructed SPC archive from raw data.
PREPROCESS = False
# Lateral size (in pixels) the CMOS volume is resized to when loaded.
XY_DIM = 128
# z-range kept from the CMOS volume, applied as cmos[Z_START:Z_END].
# NOTE(review): as a slice stop, -1 drops the last z-slice; use None to keep every slice.
Z_START, Z_END = 0, -1

## SPC preprocessing

In [None]:
if PREPROCESS:
    # Rebuild the calibrated, cut, binned and reconstructed SPC archive from the
    # raw measurements. The result is saved to SPC_PATH, so the data-loading
    # cell reads exactly the file produced here (no second copy of the name).
    DATA_PATH = RESOURCES_PATH / "beads"
    RAW_SPC_PATH = DATA_PATH / "480_3beads_triangle_505_500_575_SPC_raw.mat"
    RECONSTRUCTION_SAVE_PATH = SPC_PATH
    SH_PATH = DATA_PATH / "FLIM_Scrambled-Hadamard_1024.mat"
    EFFICIENCY_CALIB_PATH = DATA_PATH / "Efficiency_L16_575.mat"
    OFFSET_CALIB_PATH = DATA_PATH / "L16_temporal_offsets_20220520.mat"
    TEMPORAL_AXIS_PATH = DATA_PATH / "t.npy"

    spc = load_raw_spc(RAW_SPC_PATH)  # (n_times, n_spectra, n_measurements)
    # Keep every second row of the pattern matrix.
    # NOTE(review): presumably one pattern of each +/- pair; confirm against the acquisition.
    forward_matrix = sp.io.loadmat(SH_PATH)["M"].astype(np.float64)[::2]  # (n_measurements, pattern_size)
    t = np.load(TEMPORAL_AXIS_PATH).flatten().astype(np.float64)  # (n_times,)

    # Calibration -> temporal cut -> temporal binning -> per-bin NNLS reconstruction.
    spc_calib = calibrate_spc(spc, EFFICIENCY_CALIB_PATH, OFFSET_CALIB_PATH)
    spc_calib_cut, t_cut = cut_spc(spc_calib, t, max_times=2048)
    spc_calib_cut_binned, t_cut_binned, dt_cut_binned = bin_spc(spc_calib_cut, t_cut, n_bins=32)
    spc_recon = reconstruct_spc(
        spc_calib_cut_binned,
        forward_matrix,
        algo=nnls,
        n_jobs=8,
    )  # (n_times, n_spectra, img_dim, img_dim)
    np.savez_compressed(
        RECONSTRUCTION_SAVE_PATH,
        spc_recon=spc_recon,
        t_cut_binned=t_cut_binned,
        dt_cut_binned=dt_cut_binned,
    )

## Data loading

In [None]:
# CMOS volume: read the "I" dataset, resize both lateral axes to XY_DIM and keep
# the requested z-range. After the two transposes, axis 0 of the stored dataset
# is axis 0 here (the sliced axis), and the other two axes are resized and swapped.
# NOTE(review): the swap presumably undoes MATLAB v7.3 (HDF5) axis order; confirm.
with h5py.File(CMOS_PATH, "r") as f:
    cmos = np.array(f["I"])
cmos = np.transpose(cmos, (1, 2, 0))
cmos = ski.transform.resize(cmos, (XY_DIM, XY_DIM, cmos.shape[2]))
cmos = np.transpose(cmos, (2, 1, 0))
cmos = cmos[Z_START:Z_END]

# spc = sp.io.loadmat(SPC_PATH)["im"]
# t = np.squeeze(sp.io.loadmat(SPC_PATH)["t"])
# Open the .npz archive once and close it deterministically. np.load on an .npz
# returns an NpzFile that holds the file open until it is closed.
with np.load(SPC_PATH) as spc_archive:
    spc = spc_archive["spc_recon"].swapaxes(-2, -1)
    t = spc_archive["t_cut_binned"]

# FIXME: Replace with correct [0,0] pixel.
spc[:, :, 0, 0] = spc[:, :, 1, 0]
# gt_spc[:, :, 0, 0] = gt_spc[:, :, 1, 0]

# Time axis shifted to start at 0; dt is the (assumed uniform) bin width.
t = t - t.min()
dt = t[1] - t[0]
lam = np.squeeze(sp.io.loadmat(LAMBDA_PATH)["lambda"])

In [None]:
_, (ax_decay, ax_spectrum) = plt.subplots(1, 2, figsize=(7, 3))

# Left: decay summed over every spectral channel and pixel (spc axis 0 is time).
ax_decay.plot(t, spc.sum(axis=(1, 2, 3)))
ax_decay.set_title(f"Global time decay - {len(t)} temporal points")
ax_decay.set_xlabel("Time [ns]")

# Right: spectrum summed over every time bin and pixel (spc axis 1 is wavelength).
ax_spectrum.plot(lam, spc.sum(axis=(0, 2, 3)))
ax_spectrum.set_title(f"Global spectrum - {len(lam)} channels")
ax_spectrum.set_xlabel("Wavelength [nm]")

for panel in (ax_decay, ax_spectrum):
    panel.grid()

plt.tight_layout()
plt.show()

In [None]:
# Total energies of the raw inputs, captured before any rescaling.
cmos_energy, spc_energy = cmos.sum(), spc.sum()

# Peak-normalised copies (maximum == 1), used for display.
cmos_max, spc_max = cmos / cmos.max(), spc / spc.max()

# Rescale the working arrays in place so that each sums to 1.
cmos /= cmos_energy
spc /= spc_energy

In [None]:
# Per-pixel spectra (time summed out) and per-pixel decays (wavelength summed out),
# plus their global extrema so the interactive plots can share fixed y-limits.
initial_spectrums = spc.sum(axis=0)
initial_times = spc.sum(axis=1)
min_spectrums, max_spectrums = initial_spectrums.min(), initial_spectrums.max()
min_times, max_times = initial_times.min(), initial_times.max()

# Integer ratio between the CMOS and SPC lateral sizes (last axis).
resolution_diff_factor = cmos.shape[-1] // spc.shape[-1]


def spc_spectrum_time_in_a_point(cmos_z=5, spc_i=19, spc_j=17):
    """Show an SPC pixel on a CMOS slice next to its spectrum and time decay.

    Top row: CMOS slice ``cmos_z`` (peak-normalised) with the pixel marked,
    the global SPC spectrum, and the global SPC time decay.
    Bottom row: the spectrally coloured SPC image with the pixel marked, then
    the spectrum and decay at SPC pixel (spc_i, spc_j). The last two use fixed
    y-limits so that different pixels can be compared.

    Args:
        cmos_z: index of the CMOS z-slice to display.
        spc_i: row of the SPC pixel.
        spc_j: column of the SPC pixel.
    """
    _, axes = plt.subplots(2, 3, figsize=(9, 6))
    cmos_ax, global_spectrum_ax, global_decay_ax = axes[0]
    color_ax, pixel_spectrum_ax, pixel_decay_ax = axes[1]

    # The SPC pixel position is scaled to CMOS pixel coordinates.
    cmos_ax.imshow(cmos_max[cmos_z], cmap="gray", vmin=0, vmax=1)
    cmos_ax.scatter([spc_j * resolution_diff_factor], [spc_i * resolution_diff_factor], c="w")
    cmos_ax.set_title(f"CMOS in Z={cmos_z}")

    # Time-integrated SPC cube, given a singleton axis at position 1; the first output element is shown.
    color_ax.imshow(spectral_volume_to_srgb(lam, spc.sum(axis=0)[:, np.newaxis])[0])
    color_ax.scatter([spc_j], [spc_i], c="w")
    color_ax.set_title("Colored with spectrum")

    # (axis, x data, y data, title, x label, y limits or None)
    line_panels = (
        (
            global_spectrum_ax,
            lam,
            spc.sum(axis=(0, 2, 3)),
            f"Global spectrum - {len(lam)} channels",
            "Wavelength [nm]",
            None,
        ),
        (
            global_decay_ax,
            t,
            spc.sum(axis=(1, 2, 3)),
            f"Global time decay - {len(t)} temporal points",
            "Time [ns]",
            None,
        ),
        (
            pixel_spectrum_ax,
            lam,
            initial_spectrums[:, spc_i, spc_j],
            f"Spectrum in ({spc_i},{spc_j})",
            "Wavelength [nm]",
            (min_spectrums, max_spectrums),
        ),
        (
            pixel_decay_ax,
            t,
            initial_times[:, spc_i, spc_j],
            f"Time in ({spc_i},{spc_j})",
            "Time [ns]",
            (min_times, max_times),
        ),
    )
    for panel, xs, ys, title, xlabel, ylim in line_panels:
        panel.plot(xs, ys)
        if ylim is not None:
            panel.set_ylim(*ylim)
        panel.set_title(title)
        panel.set_xlabel(xlabel)
        panel.grid()

    plt.tight_layout()
    plt.show()


# One integer slider per argument, each spanning every valid index:
# CMOS z-slices, then SPC rows and columns.
_spc_rows, _spc_cols = spc.shape[-2:]
interact(
    spc_spectrum_time_in_a_point,
    cmos_z=(0, len(cmos) - 1, 1),
    spc_i=(0, _spc_rows - 1, 1),
    spc_j=(0, _spc_cols - 1, 1),
);

## Fusion through optimization

In [None]:
# Report the float64 resolution and the input dtypes before optimising.
print(f"Machine precision: {np.finfo(np.float64).eps}")
print(f"SPC dtype: {spc.dtype} CMOS dtype: {cmos.dtype}")

# Loss-term weights passed to `optimize`.
weights = {"spatial": 10, "lambda_time": 10, "global": 0.001}

# Optimiser settings, kept in one place and forwarded as keyword arguments.
_optimize_kwargs = dict(
    weights=weights,
    lr=1e-5,
    iterations=100,
    device="cpu",
    init_type="random",
    mask_initializations=True,
    mask_gradients=True,
    non_neg=True,
    return_numpy=True,
)
x = optimize(spc, cmos, **_optimize_kwargs)

In [None]:
print(*(volume.sum() for volume in (x, cmos, spc)))  # fused / CMOS / SPC totals

## Fusion through optimization with continuous time

In [None]:
# Loss-term weights for the continuous-time model.
weights = {
    "spectral": 1.0,
    "time": 2.0,
    "spatial": 0.1,
    "spectral_time": 0.0,
}

# spc_denoised = np.zeros_like(spc)
# for li in range(spc.shape[0]):
#     for yi in range(spc.shape[3]):
#         for xi in range(spc.shape[2]):
#             spc_denoised[li, :, xi, yi] = wavelet_denoising(spc[li, :, xi, yi], wavelet="db2", threshold=0.1)
# NOTE(review): `wavelet_denoising` is not imported in this notebook.

# Continuous time grid with step dt, starting just above 0. It is built once so
# that the fit and the discretisation below always use the same grid.
t_continuous = np.arange(0.00001, dt * len(t), dt)

x = optimize_with_continuous_time(
    spc,
    cmos,
    weights=weights,
    t=t_continuous,
    n_decays=1,
    lr=0.005,
    iterations=1000,
    device="cpu",
    init_type="random",
    mask_initializations=False,
    mask_gradients=False,
    non_neg=True,
    return_numpy=False,
)

# If you use optimize with continuous time, you can use this to get the discrete time decay
x_cont = x.cpu().detach().numpy().copy()

# Sample the fitted continuous model on the same grid. This overwrites `x` from
# the discrete-time optimisation cell with a NumPy array.
x = get_discrete_time_decay(x.detach(), t_continuous)
x = x.cpu().numpy()

## Fusion through baseline

In [None]:
# Baseline fusion, kept as a reference for the optimisation results.
x_baseline = baseline(
    cmos,
    spc,
    device="cpu",
    return_numpy=True,
)

## Visualize results

In [None]:
# Spectrally coloured rendering of each z-slice, with time (axis 0) averaged out,
# for the data-fusion result and for the baseline.
slices_rgb, slices_rgb_baseline = (
    spectral_volume_to_srgb(lam, volume.mean(axis=0)) for volume in (x, x_baseline)
)

# Per-z spectra (rows: z, columns: wavelength) and their global range.
means_spectrums = x.sum(axis=(0, 3, 4)).T
min_mean_spectrum, max_mean_spectrum = means_spectrums.min(), means_spectrums.max()

# Per-z decays (rows: z, columns: time) and their global range.
means_times = x.sum(axis=(1, 3, 4)).T
min_mean_times, max_mean_times = means_times.min(), means_times.max()

# z-stack averaged over time and wavelength, normalised to a peak of 1.
zxy = x.mean(axis=(0, 1))
zxy /= zxy.max()
zxy /= zxy.max()


def plot_across_z(z=5, i=0, j=0):
    """Compare data fusion, baseline and raw SPC at one voxel of slice ``z``.

    Layout (2 x 3):
      [0, 0] baseline sRGB slice with pixel (i, j) marked,
      [1, 0] data-fusion sRGB slice with pixel (i, j) marked,
      [0, 1] spectrum at (i, j) for data fusion, baseline and the matching SPC pixel,
      [0, 2] time decay at (i, j) for the same three sources,
      [1, 1] slice-wide spectrum of the data-fusion result (fixed y-limits),
      [1, 2] slice-wide time decay of the data-fusion result (fixed y-limits).

    Args:
        z: z-slice index into the reconstructions.
        i: row index on the reconstruction grid.
        j: column index on the reconstruction grid.
    """
    # Map the reconstruction pixel onto the coarser SPC grid using the real grid
    # ratio instead of a hard-coded 4, so the mapping follows XY_DIM / SPC changes.
    row_factor = max(x.shape[-2] // spc.shape[-2], 1)
    col_factor = max(x.shape[-1] // spc.shape[-1], 1)
    spc_i, spc_j = i // row_factor, j // col_factor

    # NOTE(review): undocumented gain that brings the data-fusion curves to the
    # scale of the baseline/SPC curves; confirm where 15 * 16 comes from.
    fusion_gain = 15 * 16

    # Only the selected voxel is needed, so slice it out directly instead of
    # summing whole slices on every slider move. As used elsewhere in this
    # notebook, summing axis 0 gives a spectrum and axis 1 a time decay.
    fused_voxel = x[:, :, z, i, j]
    baseline_voxel = x_baseline[:, :, z, i, j]
    spc_pixel = spc[:, :, spc_i, spc_j]

    _, ax = plt.subplots(2, 3, figsize=(12, 8))

    ax[0, 0].imshow(slices_rgb_baseline[z])
    ax[0, 0].scatter([j], [i], c="w")
    ax[0, 0].set_title(f"Baseline z={z}")

    ax[1, 0].imshow(slices_rgb[z])
    ax[1, 0].scatter([j], [i], c="w")
    ax[1, 0].set_title(f"Reconstruction Spectral Colored z={z}")

    ax[0, 1].plot(lam, fused_voxel.sum(axis=0) * fusion_gain, label="Datafusion")
    ax[0, 1].plot(lam, baseline_voxel.sum(axis=0), label="Baseline")
    ax[0, 1].plot(lam, spc_pixel.sum(axis=0), label="SPC")
    ax[0, 1].set_title(f"Spectrum in ({i},{j})")
    ax[0, 1].legend(loc="upper right")
    ax[0, 1].grid()

    # t was already shifted to start at 0 when loaded, so it is used as-is,
    # matching the slice-wide decay panel.
    ax[0, 2].plot(t, fused_voxel.sum(axis=1) * fusion_gain, label="Datafusion")
    ax[0, 2].plot(t, baseline_voxel.sum(axis=1), label="Baseline")
    ax[0, 2].plot(t, spc_pixel.sum(axis=1), label="SPC")
    ax[0, 2].set_title(f"Time in ({i},{j})")
    ax[0, 2].legend(loc="upper right")
    ax[0, 2].grid()

    ax[1, 1].plot(lam, means_spectrums[z])
    ax[1, 1].set_ylim(min_mean_spectrum, max_mean_spectrum)
    ax[1, 1].set_title(f"Global Spectrum in z={z}")
    ax[1, 1].grid()

    ax[1, 2].plot(t, means_times[z])
    ax[1, 2].set_ylim(min_mean_times, max_mean_times)
    ax[1, 2].set_title(f"Global Time in z={z}")
    ax[1, 2].grid()

    plt.tight_layout()
    plt.show()


# Sliders over every z-slice of the normalised projection and over the
# reconstruction's lateral grid (rows, columns).
_recon_rows, _recon_cols = x.shape[-2:]
interact(
    plot_across_z,
    z=(0, len(zxy) - 1, 1),
    i=(0, _recon_rows - 1, 1),
    j=(0, _recon_cols - 1, 1),
);


## Plot results