## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import h5py
import numpy as np
import scipy as sp
import skimage as ski
import matplotlib.pyplot as plt
from ipywidgets import interact

from fusion import fuse
from baseline import baseline
from utils import (
    RESOURCES_PATH,
    spectral_volume_to_color,
    preprocess_raw_spc,
)

# Lateral size (in pixels) the CMOS stack is resized to when it is loaded.
XY_DIM = 128
# Slice bounds applied to the first axis of the CMOS stack when it is loaded.
Z_START, Z_END = 0, -6

# All bead-related resources live in a single folder.
BEADS_PATH = RESOURCES_PATH / "beads"
# Archive holding the preprocessed SPC data; it is read back in the loading cell.
SPC_PATH = (
    BEADS_PATH / "480_3beads_triangle_505_500_575_SPC_calib_cut_binned_tlxy.npz"
)


## SPC preprocessing

In [None]:
# Settings forwarded to `preprocess_raw_spc` for the raw SPC measurement.
N_MEASUREMENTS, MAX_TIMES, N_BINS = 1024, 2048, 32

# Preprocess the raw SPC acquisition. The archive at SPC_PATH is what the
# data-loading cell reads back afterwards (presumably written by this call via
# `reconstruction_save_path` -- confirm in utils.preprocess_raw_spc).
# The trailing semicolon keeps the notebook from echoing the return value.
preprocess_raw_spc(
    # Input measurement and output archive.
    raw_spc_path=BEADS_PATH / "480_3beads_triangle_505_500_575_SPC_raw.mat",
    reconstruction_save_path=SPC_PATH,
    # Calibration / auxiliary files.
    forward_matrix_path=BEADS_PATH / "FLIM_Scrambled-Hadamard_1024.mat",
    efficiency_calib_path=BEADS_PATH / "Efficiency_L16_575.mat",
    offset_calib_path=BEADS_PATH / "L16_temporal_offsets_20220520.mat",
    temporal_axis_path=BEADS_PATH / "t.npy",
    # Numerical settings.
    n_measurements=N_MEASUREMENTS,
    max_times=MAX_TIMES,
    n_bins=N_BINS,
    dtype=np.float64,
    n_jobs=8,
);

## Data loading

In [None]:
# Reuse BEADS_PATH instead of rebuilding RESOURCES_PATH / "beads" by hand.
LAMBDA_PATH = BEADS_PATH / "575_Lambda_L16.mat"
CMOS_PATH = BEADS_PATH / "3beads_triangle_w4_rec_Hil2D_FOVcorrected.mat"

# CMOS loading
# Copy the dataset into memory inside the `with` so the HDF5 file is closed
# before the (slower) resize runs.
with h5py.File(CMOS_PATH, "r") as f:
    cmos = np.array(f["I"])
# Move the stack axis last so that `resize` only rescales the two lateral axes,
# then bring the stack axis back to the front and crop it to [Z_START, Z_END).
cmos = np.transpose(cmos, (1, 2, 0))
cmos = ski.transform.resize(cmos, (XY_DIM, XY_DIM, cmos.shape[2]))
cmos = np.transpose(cmos, (2, 1, 0))
cmos = cmos[Z_START:Z_END]

# SPC + time axis loading
# Open the archive once and close it deterministically; the arrays extracted
# here are held in memory and remain valid after the archive is closed.
with np.load(SPC_PATH) as spc_archive:
    spc = spc_archive["spc_recon"].swapaxes(-2, -1)
    t = spc_archive["t_cut_binned"]

# FIXME: Replace with correct [0,0] pixel.
# Until then pixel (0, 0) is overwritten with its neighbour (1, 0).
spc[:, :, 0, 0] = spc[:, :, 1, 0]

# Shift the time axis so that it starts at zero.
t = t - t.min()
dt = t[1] - t[0]

# Wavelength axis loading
lam = np.squeeze(sp.io.loadmat(LAMBDA_PATH)["lambda"])

# Normalization for plotting
cmos_max = cmos / cmos.max()
spc_max = spc / spc.max()

In [None]:
# Collapse axis 0 of the SPC volume; the result is plotted against `lam` below.
initial_spectrums = spc.sum(axis=0)
min_spectrums = initial_spectrums.min()
max_spectrums = initial_spectrums.max()

# Collapse axis 1 of the SPC volume; the result is plotted against `t` below.
initial_times = spc.sum(axis=1)
min_times = initial_times.min()
max_times = initial_times.max()

# How many CMOS pixels fit in one SPC pixel along the last axis.
resolution_diff_factor = int(cmos.shape[-1] / spc.shape[-1])


def spc_spectrum_time_in_a_point(cmos_z=5, spc_i=16, spc_j=17):
    """Plot the SPC spectrum and time decay at one SPC pixel next to the CMOS.

    Draws a 2x3 figure: the normalised CMOS slice with the selected location
    marked, the spectrally coloured SPC image, the global spectrum and decay,
    and the spectrum and decay at the selected SPC pixel.

    The function reads the notebook-level arrays ``cmos_max``, ``lam``, ``t``,
    ``initial_spectrums``, ``initial_times`` and their min/max bounds, so the
    data-loading cells must have been run first.

    Args:
        cmos_z: Index of the CMOS slice to display.
        spc_i: Row index of the SPC pixel to inspect.
        spc_j: Column index of the SPC pixel to inspect.
    """
    _, ax = plt.subplots(2, 3, figsize=(9, 6))

    # CMOS slice; the SPC coordinates are scaled up to CMOS pixel coordinates.
    ax[0, 0].imshow(cmos_max[cmos_z], cmap="gray", vmin=0, vmax=1)
    ax[0, 0].scatter(
        [spc_j * resolution_diff_factor], [spc_i * resolution_diff_factor], c="w"
    )
    ax[0, 0].set_title(f"CMOS in Z={cmos_z}")

    # `initial_spectrums` already is spc summed over axis 0, so it is reused
    # here instead of reducing the whole SPC volume on every slider event.
    ax[1, 0].imshow(
        spectral_volume_to_color(lam, initial_spectrums[:, np.newaxis])[0]
    )
    ax[1, 0].scatter([spc_j], [spc_i], c="w")
    ax[1, 0].set_title("Colored with spectrum")

    # Global curves: sum the precomputed marginals over the two pixel axes.
    ax[0, 1].plot(lam, initial_spectrums.sum(axis=(1, 2)))
    ax[0, 1].set_title(f"Global spectrum - {len(lam)} channels")
    ax[0, 1].set_xlabel("Wavelength [nm]")
    ax[0, 1].grid()

    ax[0, 2].plot(t, initial_times.sum(axis=(1, 2)))
    ax[0, 2].set_title(f"Global time decay - {len(t)} temporal points")
    ax[0, 2].set_xlabel("Time [ns]")
    ax[0, 2].grid()

    # Per-pixel curves share fixed y-limits so they stay comparable while the
    # sliders move.
    ax[1, 1].plot(lam, initial_spectrums[:, spc_i, spc_j])
    ax[1, 1].set_ylim(min_spectrums, max_spectrums)
    ax[1, 1].set_title(f"Spectrum in ({spc_i},{spc_j})")
    ax[1, 1].set_xlabel("Wavelength [nm]")
    ax[1, 1].grid()

    ax[1, 2].plot(t, initial_times[:, spc_i, spc_j])
    ax[1, 2].set_ylim(min_times, max_times)
    ax[1, 2].set_title(f"Time in ({spc_i},{spc_j})")
    ax[1, 2].set_xlabel("Time [ns]")
    ax[1, 2].grid()

    plt.tight_layout()
    plt.show()


# Sliders cover every CMOS slice index and every SPC pixel coordinate
# (min, max, step). The trailing semicolon hides the widget's repr.
interact(
    spc_spectrum_time_in_a_point,
    cmos_z=(0, len(cmos) - 1, 1),
    spc_i=(0, spc.shape[-2] - 1, 1),
    spc_j=(0, spc.shape[-1] - 1, 1),
);

## Fusion through optimization

In [None]:
# Weights handed to `fuse` (presumably one per loss term -- confirm in
# fusion.fuse); the "global" entry is set to zero here.
weights = {"spatial": 1.0, "lambda_time": 1.0, "global": 0.0}

# Run the optimisation-based fusion of the SPC and CMOS data on the CPU and
# get the three results back as NumPy arrays.
x, spc_out, cmos_out = fuse(
    spc,
    cmos,
    weights=weights,
    # Optimiser settings.
    iterations=100,
    lr=1e-3,
    l2_regularization=0,
    device="cpu",
    # Initialisation and constraints.
    init_type="random",
    mask_initializations=True,
    mask_gradients=True,
    non_neg=True,
    total_energy=10000,
    # Output format.
    return_numpy=True,
)

In [None]:
# Sanity check: print the totals of the fused volume and of the two other
# arrays returned by `fuse`, in the order x, cmos_out, spc_out.
print(*(volume.sum() for volume in (x, cmos_out, spc_out)))

## Fusion through baseline

In [None]:
# Baseline fusion, run on the CPU and returned as a NumPy array.
# NOTE(review): the arguments are passed as (cmos, spc), the reverse of
# fuse(spc, cmos) above -- confirm this matches baseline's signature.
x_baseline = baseline(cmos, spc, device="cpu", return_numpy=True)

## Visualize results

In [None]:
# Colour renderings of both fused volumes, each averaged over its first axis.
slices_rgb, slices_rgb_baseline = (
    spectral_volume_to_color(lam, np.mean(volume, axis=0))
    for volume in (x, x_baseline)
)

# Per-slice curves of the optimised result, transposed so the slice index
# comes first; their extrema fix the y-limits of the interactive plots.
means_spectrums = np.sum(x, axis=(0, 3, 4)).T
min_mean_spectrum = means_spectrums.min()
max_mean_spectrum = means_spectrums.max()

means_times = np.sum(x, axis=(1, 3, 4)).T
min_mean_times = means_times.min()
max_mean_times = means_times.max()

# Volume averaged over its first two axes, normalised in place to a maximum of 1.
zxy = np.mean(x, axis=(0, 1))
zxy /= zxy.max()


def plot_across_z(z=5, i=0, j=0):
    """Compare the fused result, the baseline and the raw SPC at one pixel.

    Draws a 2x3 figure for slice ``z``: the colour-rendered baseline and
    optimised reconstruction with pixel ``(i, j)`` marked, the per-slice global
    spectrum and time curve, and the spectrum and time curve at ``(i, j)`` for
    the optimised result, the baseline and the SPC pixel covering ``(i, j)``.

    The function reads the notebook-level arrays ``x``, ``x_baseline``,
    ``spc``, ``lam``, ``t`` and the precomputed ``slices_rgb*`` / ``means_*``
    values, so the fusion and precomputation cells must have been run first.

    Args:
        z: Index of the slice to display.
        i: Row index of the pixel in the fused volume.
        j: Column index of the pixel in the fused volume.
    """
    # Map the fused pixel onto the SPC grid using the actual array shapes
    # rather than a hard-coded factor, so the lookup stays correct whenever
    # XY_DIM or the SPC resolution changes. The clamps guard against shapes
    # that are not exact multiples of each other.
    row_factor = max(x.shape[-2] // spc.shape[-2], 1)
    col_factor = max(x.shape[-1] // spc.shape[-1], 1)
    spc_i = min(i // row_factor, spc.shape[-2] - 1)
    spc_j = min(j // col_factor, spc.shape[-1] - 1)

    # NOTE(review): 15 * 16 rescales the optimised curves for display next to
    # the baseline and SPC curves; the origin of the two factors is not
    # visible in this notebook -- confirm and derive them from the data shapes.
    datafusion_scale = 15 * 16

    # Only one pixel is plotted, so slice it out before reducing instead of
    # summing the full volumes on every slider event.
    fused_pixel = x[:, :, z, i, j]
    baseline_pixel = x_baseline[:, :, z, i, j]
    spc_pixel = spc[:, :, spc_i, spc_j]
    time_axis = t - t.min()

    _, ax = plt.subplots(2, 3, figsize=(12, 8))

    ax[0, 0].imshow(slices_rgb_baseline[z])
    ax[0, 0].scatter([j], [i], c="w")
    ax[0, 0].set_title(f"Baseline z={z}")

    ax[1, 0].imshow(slices_rgb[z])
    ax[1, 0].scatter([j], [i], c="w")
    ax[1, 0].set_title(f"Reconstruction Spectral Colored z={z}")

    # Per-slice global curves, with y-limits shared across all slices.
    ax[1, 1].plot(lam, means_spectrums[z])
    ax[1, 1].set_ylim(min_mean_spectrum, max_mean_spectrum)
    ax[1, 1].set_title(f"Global Spectrum in z={z}")
    ax[1, 1].grid()

    ax[1, 2].plot(t, means_times[z])
    ax[1, 2].set_ylim(min_mean_times, max_mean_times)
    ax[1, 2].set_title(f"Global Time in z={z}")
    ax[1, 2].grid()

    # Spectrum at the pixel: reduce axis 0, leaving the axis plotted vs `lam`.
    ax[0, 1].plot(lam, fused_pixel.sum(axis=0) * datafusion_scale, label="Datafusion")
    ax[0, 1].plot(lam, baseline_pixel.sum(axis=0), label="Baseline")
    ax[0, 1].plot(lam, spc_pixel.sum(axis=0), label="SPC")
    ax[0, 1].set_title(f"Spectrum in ({i},{j})")
    ax[0, 1].legend(loc="upper right")
    ax[0, 1].grid()

    # Time curve at the pixel: reduce axis 1, leaving the axis plotted vs `t`.
    ax[0, 2].plot(time_axis, fused_pixel.sum(axis=1) * datafusion_scale, label="Datafusion")
    ax[0, 2].plot(time_axis, baseline_pixel.sum(axis=1), label="Baseline")
    ax[0, 2].plot(time_axis, spc_pixel.sum(axis=1), label="SPC")
    ax[0, 2].set_title(f"Time in ({i},{j})")
    ax[0, 2].legend(loc="upper right")
    ax[0, 2].grid()

    plt.tight_layout()
    plt.show()


# Sliders cover every slice of `zxy` and every pixel coordinate of the fused
# volume (min, max, step). The trailing semicolon hides the widget's repr.
interact(
    plot_across_z,
    z=(0, zxy.shape[0] - 1, 1),
    i=(0, x.shape[-2] - 1, 1),
    j=(0, x.shape[-1] - 1, 1),
);

In [None]:
# TODO: Figure for the paper with the results on the beads.