## Imports

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import h5py
import numpy as np
import scipy as sp
import skimage as ski
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from ipywidgets import interact
from matplotlib import patches
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from matplotlib.colors import LinearSegmentedColormap, hsv_to_rgb

from fusion import FusionAdam, FusionCG
from baseline import baseline
from utils import mono_exponential_decay_numpy as decay
from utils import (
    RESOURCES_PATH,
    FIGURES_PATH,
    spectral_volume_to_color,
    time_volume_to_lifetime,
    preprocess_raw_spc,
    linear_to_srgb,
    wavelength_to_srgb,
    sam,
    rmse,
)

# Side length (in pixels) the CMOS stack is resized to along both lateral axes.
XY_DIM = 128
# Bounds of the slice applied to the first axis of the CMOS stack after loading
# (cmos[Z_START:Z_END]); a negative Z_END drops slices from the end.
Z_START = 0
Z_END = -6

# LaTeX tau symbol used in the lifetime legend labels.
TAU = r"$\tau$"

# Folder holding the beads acquisition and its calibration files.
BEADS_PATH = RESOURCES_PATH / "beads"
# Where the preprocessing cell saves the reconstructed SPC volume.
# NOTE: SPC_PATH is reassigned in the data-loading cell to a TVAL3 .mat file.
SPC_PATH = BEADS_PATH / "480_3beads_triangle_505_500_575_SPC_calib_cut_binned_tlxy.npz"

# Set to True to run the SPC preprocessing cell below.
PREPROCESS = False

## SPC preprocessing

In [None]:
if PREPROCESS:
    from utils import TVAL3
    from scipy.linalg import lstsq

    # TVAL3 solver instance. It is not passed to the call below (lstsq is the
    # selected algorithm); it is kept available for switching `algo`.
    tval = TVAL3(img_shape=(32, 32))

    N_MEASUREMENTS = 1024
    MAX_TIMES = 2048
    N_BINS = 64

    # All options of the raw-SPC preprocessing, gathered in one place.
    spc_preprocessing_options = dict(
        raw_spc_path=BEADS_PATH / "480_3beads_triangle_505_500_575_SPC_raw.mat",
        reconstruction_save_path=SPC_PATH,
        forward_matrix_path=BEADS_PATH / "FLIM_Scrambled-Hadamard_1024.mat",
        efficiency_calib_path=BEADS_PATH / "Efficiency_L16_575.mat",
        offset_calib_path=BEADS_PATH / "L16_temporal_offsets_20220520.mat",
        temporal_axis_path=BEADS_PATH / "t.npy",
        n_measurements=N_MEASUREMENTS,
        max_times=MAX_TIMES,
        algo=lstsq,
        compression=1,
        n_bins=N_BINS,
        # When `algo` is the TVAL3 instance, set n_jobs=1: TVAL3 drives the
        # MATLAB engine and parallel jobs cause problems.
        n_jobs=8,
        dtype=np.float32,
    )

    # Preprocess SPC
    preprocess_raw_spc(**spc_preprocessing_options)

## Data loading

In [None]:
LAMBDA_PATH = RESOURCES_PATH / "beads" / "575_Lambda_L16.mat"
CMOS_PATH = RESOURCES_PATH / "beads" / "3beads_triangle_w4_rec_Hil2D_FOVcorrected.mat"

# CMOS loading: only the read needs the file handle, so the HDF5 file is closed
# before any array processing starts.
with h5py.File(CMOS_PATH, "r") as f:
    cmos = np.array(f["I"])

# Fixes alignment between CMOS and SPC: every slice along the first axis is
# circularly shifted by (-15, -7) pixels. A single vectorised roll over the two
# trailing axes is equivalent to rolling each slice in a Python loop.
cmos = np.roll(cmos, shift=(-15, -7), axis=(1, 2))
# Resize the two lateral axes to XY_DIM x XY_DIM; the stack axis is moved last
# for skimage and then back to the front (the lateral axes end up swapped).
cmos = np.transpose(cmos, (1, 2, 0))
cmos = ski.transform.resize(cmos, (XY_DIM, XY_DIM, cmos.shape[2]))
cmos = np.transpose(cmos, (2, 1, 0))
cmos = cmos[Z_START:Z_END]

# SPC loading
# Alternative source (output of the preprocessing cell):
# spc = np.load(SPC_PATH)["spc_recon"].swapaxes(-2, -1)

SPC_PATH = BEADS_PATH / "TVAL3" / "480_3beads_triangle_505_500_575_SPC_raw_proc_tlxy_bin200ps_cut_CR50_b5m8.mat"
# The .mat file holds both the SPC volume ("im") and its time axis ("t");
# parse it once and reuse the result for both.
_spc_mat = sp.io.loadmat(SPC_PATH)
spc = _spc_mat["im"]
# FIXME: Replace with correct [0,0] pixel.
# Until then the [0, 0] pixel is overwritten with its [1, 0] neighbour.
spc[:, :, 0, 0] = spc[:, :, 1, 0]

# Time axis loading
# Alternative source (output of the preprocessing cell):
# t = np.load(SPC_PATH)["t_cut_binned"]
t = np.squeeze(_spc_mat["t"])
t = t - t.min()  # make the time axis start at zero
dt = t[1] - t[0]

# Wavelength axis loading
lam = np.squeeze(sp.io.loadmat(LAMBDA_PATH)["lambda"])

# Crop where there are beads
# cmos = cmos[:, 42:82, 62:102]
# spc = spc[:, :, 10:20, 15:25]

# Normalization for plotting (each volume divided by its global maximum)
cmos_max = cmos / cmos.max()
spc_max = spc / spc.max()

In [None]:
# Per-pixel spectra: the SPC volume summed over its first (time) axis, plus the
# global extrema used to fix the y-limits of the per-pixel spectrum plot.
initial_spectrums = spc.sum(axis=0)
min_spectrums = initial_spectrums.min()
max_spectrums = initial_spectrums.max()

# Per-pixel decays: the SPC volume summed over its second (wavelength) axis,
# plus the global extrema used to fix the y-limits of the per-pixel time plot.
initial_times = spc.sum(axis=1)
min_times = initial_times.min()
max_times = initial_times.max()

# How many CMOS pixels fit in one SPC pixel along the last axis.
resolution_diff_factor = cmos.shape[-1] // spc.shape[-1]


def spc_spectrum_time_in_a_point(cmos_z=5, spc_i=16, spc_j=17):
    """Show the raw data around one SPC pixel in a 2x3 grid of panels.

    Top row: the CMOS slice ``cmos_z`` with the selected pixel marked, the
    global spectrum and the global time decay of the SPC volume.
    Bottom row: the spectrally coloured SPC image with the pixel marked, and the
    spectrum and time decay of the SPC pixel ``(spc_i, spc_j)``.

    Relies on the notebook-level arrays ``cmos_max``, ``spc``, ``lam``, ``t``,
    ``initial_spectrums``, ``initial_times`` and their precomputed extrema.
    """
    _, axes = plt.subplots(2, 3, figsize=(9, 6))
    (cmos_ax, global_spectrum_ax, global_time_ax), (color_ax, pixel_spectrum_ax, pixel_time_ax) = axes

    # CMOS slice, renormalised to its own maximum; the SPC pixel is marked at
    # its position on the finer CMOS grid.
    cmos_slice = cmos_max[cmos_z]
    cmos_ax.imshow(cmos_slice / cmos_slice.max(), cmap="gray", vmin=0, vmax=1)
    cmos_ax.scatter([spc_j * resolution_diff_factor], [spc_i * resolution_diff_factor], c="w")
    cmos_ax.set_title(f"CMOS in Z={cmos_z}")

    # SPC image coloured by its time-integrated spectrum.
    spectral_image = spectral_volume_to_color(lam, spc.sum(axis=0)[:, np.newaxis])[0]
    color_ax.imshow(spectral_image)
    color_ax.scatter([spc_j], [spc_i], c="w")
    color_ax.set_title("Colored with spectrum")

    # Curves integrated over the whole field of view.
    global_spectrum_ax.plot(lam, spc.sum(axis=(0, 2, 3)))
    global_spectrum_ax.set_title(f"Global spectrum - {len(lam)} channels")

    global_time_ax.plot(t, spc.sum(axis=(1, 2, 3)))
    global_time_ax.set_title(f"Global time decay - {len(t)} temporal points")

    # Curves of the selected pixel, on y-limits shared by all pixels.
    pixel_spectrum_ax.plot(lam, initial_spectrums[:, spc_i, spc_j])
    pixel_spectrum_ax.set_ylim(min_spectrums, max_spectrums)
    pixel_spectrum_ax.set_title(f"Spectrum in ({spc_i},{spc_j})")

    pixel_time_ax.plot(t, initial_times[:, spc_i, spc_j])
    pixel_time_ax.set_ylim(min_times, max_times)
    pixel_time_ax.set_title(f"Time in ({spc_i},{spc_j})")

    # Axis labels and grids shared by the curve panels.
    for spectrum_ax in (global_spectrum_ax, pixel_spectrum_ax):
        spectrum_ax.set_xlabel("Wavelength [nm]")
        spectrum_ax.grid()
    for time_ax in (global_time_ax, pixel_time_ax):
        time_ax.set_xlabel("Time [ns]")
        time_ax.grid()

    plt.tight_layout()
    plt.show()


# Slider ranges as (min, max, step): every CMOS slice and every SPC pixel.
_spc_point_sliders = {
    "cmos_z": (0, cmos.shape[0] - 1, 1),
    "spc_i": (0, spc.shape[-2] - 1, 1),
    "spc_j": (0, spc.shape[-1] - 1, 1),
}
interact(spc_spectrum_time_in_a_point, **_spc_point_sliders);

## Fusion through optimization

In [None]:
# Weights of the loss terms; a "global" term (weight 0.0) is left disabled.
weights = dict(spatial=0.5, lambda_time=0.5)

# Construction options of the Adam-based fusion.
_adam_options = dict(
    weights=weights,
    init_type="baseline",
    mask_noise=True,
    tol=1e-6,
    total_energy=1,
    device="mps",
    seed=42,
)
fuse_with_adam = FusionAdam(spc, cmos, **_adam_options)

# Run the optimisation; results come back as NumPy arrays.
x, spc_out, cmos_out = fuse_with_adam(iterations=100, lr=1e-8, return_numpy=True)

In [None]:
# Construction options of the conjugate-gradient fusion (same weights as Adam).
_cg_options = dict(
    weights=weights,
    init_type="baseline",
    tol=1e-6,
    mask_noise=True,
    total_energy=1,
    device="cpu",
    seed=42,
)
fuse_with_cg = FusionCG(spc, cmos, **_cg_options)

# Run the optimisation; results come back as NumPy arrays.
x_cg, spc_out_cg, cmos_out_cg = fuse_with_cg(iterations=40, eps=1e-8, return_numpy=True)

In [None]:
# Final loss terms of both optimisers, printed side by side for comparison.
loss_adam_space, loss_adam_lambda_time = fuse_with_adam.loss()
loss_cg_space, loss_cg_lambda_time = fuse_with_cg.loss()

_loss_report = (
    ("x_adam", loss_adam_space, loss_adam_lambda_time),
    ("x_cg", loss_cg_space, loss_cg_lambda_time),
)
for _solution, _space_loss, _lambda_time_loss in _loss_report:
    print(f"F({_solution}): Space: {_space_loss:.2E}, LambdaTime: {_lambda_time_loss:.2E}")

In [None]:
# Total sum of each returned volume, one line per optimiser (Adam, then CG).
for _volumes in ((x, cmos_out, spc_out), (x_cg, cmos_out_cg, spc_out_cg)):
    print(*(_volume.sum() for _volume in _volumes))

## Fusion through baseline

In [None]:
# Reference fusion computed with the `baseline` routine on the CPU, returned as
# a NumPy array. Note the argument order here is (cmos, spc), the reverse of the
# (spc, cmos) order used by FusionAdam / FusionCG above.
x_baseline = baseline(cmos, spc, device="cpu", return_numpy=True)

## Visualize results

In [None]:
# Spectrally coloured images: one per slice of the fused volume, and a single
# one for the SPC output (given a singleton slice axis so the same helper applies).
slices_rgb = spectral_volume_to_color(lam, x.sum(axis=0))
slices_rgb_spc = spectral_volume_to_color(lam, spc_out.sum(axis=0)[:, np.newaxis])[0]

# Global spectrum and global decay of every fused slice; after the transpose
# the first index selects the slice.
spectrums = x.sum(axis=(0, 3, 4)).T
times = x.sum(axis=(1, 3, 4)).T


def plot_across_z(z=5, i=60, j=70):
    """Inspect the Adam fusion result at slice ``z`` and fused pixel ``(i, j)``.

    Draws a 2x3 grid: the coloured SPC image, the spectrum and the normalised
    time decay at the selected pixel (fused result vs. SPC, with a
    mono-exponential fit of the fused decay) on the top row; the coloured fused
    slice and the global spectrum / decay of slice ``z`` on the bottom row.

    Relies on the notebook-level arrays ``x``, ``spc_out``, ``slices_rgb``,
    ``slices_rgb_spc``, ``spectrums``, ``times``, ``lam`` and ``t``.

    Args:
        z: Index of the fused slice (third axis of ``x``).
        i: Row of the pixel on the fused grid.
        j: Column of the pixel on the fused grid.
    """
    # SPC pixel lying under the fused pixel (i, j). The resolution ratio is
    # derived from the array shapes instead of being hard-coded to 4, so the
    # cell keeps working if XY_DIM or the SPC resolution changes.
    row_factor = max(1, x.shape[-2] // spc_out.shape[-2])
    col_factor = max(1, x.shape[-1] // spc_out.shape[-1])
    spc_i, spc_j = i // row_factor, j // col_factor

    _, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].imshow(slices_rgb_spc)
    ax[0, 0].scatter([spc_j], [spc_i], c="w")
    ax[0, 0].set_title("SPC")

    ax[1, 1].plot(lam, spectrums[z])
    ax[1, 1].set_ylim(spectrums.min(), spectrums.max())
    ax[1, 1].set_title(f"Global Spectrum in z={z}")
    ax[1, 1].grid()

    ax[1, 2].plot(t, times[z])
    ax[1, 2].set_ylim(times.min(), times.max())
    ax[1, 2].set_title(f"Global Time in z={z}")
    ax[1, 2].grid()

    # Per-pixel spectra (time summed out) and decays (wavelength summed out).
    fused_slice = x[:, :, z, :, :]
    reconstructed_spectrums = np.sum(fused_slice, axis=0)
    lxy_spc = np.sum(spc_out, axis=0)

    reconstructed_times = np.sum(fused_slice, axis=1)
    txy_spc = np.sum(spc_out, axis=1)

    ax[1, 0].imshow(slices_rgb[z])
    ax[1, 0].scatter([j], [i], c="w")
    ax[1, 0].set_title(f"Reconstruction Spectral Colored z={z}")

    # NOTE(review): the 10 * 16 factor rescales the fused spectrum for visual
    # comparison with the SPC one; presumably 16 is the 4x4 pixel ratio and 10
    # the number of slices - confirm before changing resolutions.
    ax[0, 1].plot(lam, reconstructed_spectrums[:, i, j] * 10 * 16, label="DF")
    ax[0, 1].plot(lam, lxy_spc[:, spc_i, spc_j], label="SPC")
    ax[0, 1].set_title(f"Spectrum in ({i},{j})")
    ax[0, 1].legend(loc="upper right")
    ax[0, 1].grid()

    # Decays normalised to their own maximum so that shapes can be compared.
    fused_decay = reconstructed_times[:, i, j] / reconstructed_times[:, i, j].max()
    spc_decay = txy_spc[:, spc_i, spc_j] / txy_spc[:, spc_i, spc_j].max()

    # Mono-exponential fit of the fused decay. A failed fit (no convergence, or
    # NaNs coming from an all-zero pixel) must not abort the widget redraw: the
    # data are still shown, only the fitted curve is skipped.
    try:
        (a, tau, c), _ = curve_fit(
            decay,
            t,
            fused_decay,
            bounds=([0.0, 1e-6, 0.0], [1, 6.0, 0.1]),
            p0=(0.5, 2.0, 0.00001),
            maxfev=5000,
        )
    except (RuntimeError, ValueError):
        fit_parameters = None
    else:
        fit_parameters = (a, tau, c)

    ax[0, 2].plot(t, spc_decay, label="SPC", c="C1")
    ax[0, 2].scatter(t, fused_decay, label="DF", c="C0", marker="+")
    if fit_parameters is not None:
        a, tau, c = fit_parameters
        ax[0, 2].plot(t, decay(t, a, tau, c), label=f"DF Fit {TAU}: {tau:.2f} ns", c="C0")
    ax[0, 2].set_title(f"Time in ({i},{j})")
    ax[0, 2].legend(loc="upper right")
    ax[0, 2].grid()

    plt.tight_layout()
    plt.show()


# Slider ranges as (min, max, step): every fused slice and every fused pixel.
_fused_point_sliders = {
    "z": (0, x.shape[2] - 1, 1),
    "i": (0, x.shape[-2] - 1, 1),
    "j": (0, x.shape[-1] - 1, 1),
}
interact(plot_across_z, **_fused_point_sliders);

In [None]:
# Bead regions of interest. `rows` / `cols` are half-open pixel ranges on the
# fused grid used for the ROI sums; `rect_xy` / `rect_size` describe the square
# outline drawn on the image panels (patches.Rectangle takes (x, y) order).
_BEAD_ROIS = (
    {"name": "Green", "color": "green", "rows": (45, 58), "cols": (75, 88),
     "rect_xy": (74, 44), "rect_size": 12, "linewidth": 2},
    {"name": "Yellow", "color": "gold", "rows": (57, 73), "cols": (61, 77),
     "rect_xy": (60, 56), "rect_size": 16, "linewidth": 1.5},
)


def _add_panel_letter(ax, letter, font_size, color):
    """Write the bold panel letter in the lower-left corner of ``ax``."""
    ax.text(
        0.05, 0.05, letter,
        transform=ax.transAxes,
        fontsize=font_size,
        fontweight="bold",
        va="bottom",
        c=color,
    )


def _decorate_image_panel(ax, letter, z_index, scalebar_pixels, scalebar_micrometers, font_size):
    """Add panel letter, scalebar, z label and ROI outlines to an image panel."""
    _add_panel_letter(ax, letter, font_size, "w")
    ax.add_artist(AnchoredSizeBar(
        ax.transData,
        scalebar_pixels,  # Length of scalebar in pixels
        f"{scalebar_micrometers} µm",  # Label for the scalebar
        "upper left",
        pad=0.3,
        color="white",
        frameon=False,
        size_vertical=2,
        fontproperties={"size": 12, "weight": "bold"},
    ))
    ax.text(
        0.95, 0.95, f"z={z_index} µm",
        transform=ax.transAxes,
        fontsize=12,
        fontweight="bold",
        va="top",
        ha="right",
        c="w",
    )
    ax.axis("off")
    for roi in _BEAD_ROIS:
        ax.add_patch(patches.Rectangle(
            roi["rect_xy"], roi["rect_size"], roi["rect_size"],
            linewidth=roi["linewidth"], edgecolor=roi["color"], facecolor="none",
        ))


def _roi_profiles(x, spc_data, roi, keep_axis):
    """Sum the fused and SPC volumes over ``roi``, keeping only ``keep_axis``.

    ``keep_axis`` is 0 for a time decay and 1 for a spectrum. The fused volume
    is summed over all its slices; the ROI bounds are converted to SPC pixels
    with the resolution ratio derived from the two array shapes.
    """
    (row_start, row_stop), (col_start, col_stop) = roi["rows"], roi["cols"]
    row_factor = max(1, x.shape[-2] // spc_data.shape[-2])
    col_factor = max(1, x.shape[-1] // spc_data.shape[-1])

    fused_roi = x[..., row_start:row_stop, col_start:col_stop]
    spc_roi = spc_data[
        ...,
        row_start // row_factor:row_stop // row_factor,
        col_start // col_factor:col_stop // col_factor,
    ]

    def _collapse(volume):
        summed_axes = tuple(axis for axis in range(volume.ndim) if axis != keep_axis)
        return volume.sum(axis=summed_axes)

    return _collapse(fused_roi), _collapse(spc_roi)


def _fit_lifetime(t, curve):
    """Return the lifetime of a mono-exponential fit of ``curve``, NaN if the fit fails."""
    try:
        params, _ = curve_fit(
            decay, t, curve,
            bounds=([0.0, 1e-6, 0.0], [1, 6.0, 0.1]),
            p0=(0.5, 2.0, 0.000001),
            maxfev=5000,
        )
    except RuntimeError:
        # curve_fit signals a non-converged fit this way; keep drawing the figure.
        return float("nan")
    return params[1]


def plot_results(
        x, t, lam, z_index,
        font_size_letters=16,
        save_name=None,
        tau_clip=(1, 4),
        spc_data=None,
):
    """Draw the six-panel summary figure of the fusion result.

    Panels (a) and (d) show slice ``z_index`` of the fused volume coloured by
    spectrum and by lifetime, with a scalebar and the two bead ROIs outlined.
    Panels (b), (c) compare the normalised ROI spectra of the fused volume (DF)
    and of the SPC data; panels (e), (f) compare the normalised ROI time decays
    and report the fitted lifetimes in the legend. ROI curves of the fused
    volume are summed over all slices, not only ``z_index``. SAM and RMSE
    between the DF and SPC curves are printed.

    Args:
        x: Fused volume indexed as (time, wavelength, slice, row, column).
        t: Time axis, same length as the first axis of ``x``.
        lam: Wavelength axis, same length as the second axis of ``x``.
        z_index: Slice shown in the two image panels.
        font_size_letters: Font size of the panel letters.
        save_name: If given, the figure is saved under ``FIGURES_PATH`` with
            this file name (300 dpi).
        tau_clip: Lifetime clipping range forwarded to ``time_volume_to_lifetime``.
        spc_data: SPC volume indexed as (time, wavelength, row, column) used as
            reference. Defaults to the notebook-level ``spc`` array.
    """
    if spc_data is None:
        spc_data = spc  # notebook-level SPC volume

    # Defining spectral colorbar
    wavelengths_ticks = np.arange(510, 650, 1)
    spectral_colors = wavelength_to_srgb(wavelengths_ticks, "basic").T
    spectral_colors = linear_to_srgb(spectral_colors)
    spectral_colors /= spectral_colors.max(axis=1)[..., np.newaxis]
    spectral_cmap = LinearSegmentedColormap.from_list("spectrum", spectral_colors, N=len(wavelengths_ticks))

    # Defining fused scalebar
    fused_micro_width = 150  # in micrometers
    pixel_size = fused_micro_width / x.shape[-1]
    scalebar_length_micrometers = 30
    scalebar_length_pixels = scalebar_length_micrometers / pixel_size

    fused = x[:, :, z_index:z_index + 1]
    fused_spectral_image = spectral_volume_to_color(lam, fused.sum(axis=0))[0]
    fused_lifetime_image, tau_min, tau_max = time_volume_to_lifetime(t, fused.sum(axis=1), tau_clip=tau_clip, max_tau=6)
    fused_lifetime_image = fused_lifetime_image[0]

    fig = plt.figure(figsize=(9, 6))
    gs = fig.add_gridspec(2, 3, width_ratios=[1, 1, 1], height_ratios=[1, 1])

    # Show spectral image
    ax00 = fig.add_subplot(gs[0, 0])
    ax00.imshow(fused_spectral_image, cmap=spectral_cmap)
    ax00.set_title("Multispectral Image")
    _decorate_image_panel(
        ax00, "(a)", z_index, scalebar_length_pixels, scalebar_length_micrometers, font_size_letters,
    )

    # The colorbar gets its own mappable, normalised in nanometres, so that each
    # tick sits on the colour of its wavelength. Spreading the labels evenly over
    # the bar would put "630" on the colour of the last wavelength (649 nm).
    spectral_mappable = plt.cm.ScalarMappable(
        norm=plt.Normalize(vmin=wavelengths_ticks[0], vmax=wavelengths_ticks[-1]),
        cmap=spectral_cmap,
    )
    spectral_mappable.set_array([])
    cbar = fig.colorbar(
        spectral_mappable, ax=ax00, fraction=0.046, pad=0.03, orientation="horizontal", label="Wavelength [nm]",
    )
    labelled_wavelengths = wavelengths_ticks[::30]
    cbar.set_ticks(labelled_wavelengths)
    cbar.set_ticklabels([f"{wavelength:.0f}" for wavelength in labelled_wavelengths])

    # Show lifetime image
    hsv_colormap = np.vstack([np.linspace(0, 26 / 36, 100)[::-1], np.ones(100), np.ones(100)]).T
    lifetime_cmap = LinearSegmentedColormap.from_list("fused_lifetime", hsv_to_rgb(hsv_colormap), N=100)

    ax10 = fig.add_subplot(gs[1, 0])
    ax10.imshow(fused_lifetime_image, cmap=lifetime_cmap)
    ax10.set_title("FLIM Image")
    _decorate_image_panel(
        ax10, "(d)", z_index, scalebar_length_pixels, scalebar_length_micrometers, font_size_letters,
    )

    # Same approach for the lifetime colorbar: it spans [tau_min, tau_max] in ns
    # regardless of the value range of the displayed image.
    lifetime_mappable = plt.cm.ScalarMappable(
        norm=plt.Normalize(vmin=tau_min, vmax=tau_max),
        cmap=lifetime_cmap,
    )
    lifetime_mappable.set_array([])
    cbar = fig.colorbar(
        lifetime_mappable, ax=ax10, fraction=0.046, pad=0.03, orientation="horizontal", label="Lifetime [ns]",
    )
    labelled_lifetimes = np.linspace(tau_min, tau_max, 6)
    cbar.set_ticks(labelled_lifetimes)
    cbar.set_ticklabels([f"{lifetime:.1f}" for lifetime in labelled_lifetimes])

    # Show wavelength comparison of each ROI (curves normalised to unit sum)
    for column, (roi, letter) in enumerate(zip(_BEAD_ROIS, ("(b)", "(c)")), start=1):
        ax = fig.add_subplot(gs[0, column])
        fused_roi_spectrum, spc_roi_spectrum = _roi_profiles(x, spc_data, roi, keep_axis=1)
        fused_roi_spectrum = fused_roi_spectrum / fused_roi_spectrum.sum()
        spc_roi_spectrum = spc_roi_spectrum / spc_roi_spectrum.sum()

        ax.plot(lam, fused_roi_spectrum, label="DF", c=roi["color"], linestyle="-")
        ax.plot(lam, spc_roi_spectrum, label="SPC", c="black", linestyle="--")
        ax.set_title(f"{roi['name']} Bead Spectrum")
        ax.set_xlabel("Wavelength [nm]")
        ax.set_ylabel("Intensity [a.u.]")
        ax.grid()
        print(f"SAM {roi['name']}: {sam(fused_roi_spectrum, spc_roi_spectrum):.2F} rad")
        ax.legend()
        _add_panel_letter(ax, letter, font_size_letters, "black")

    # Show lifetime comparison of each ROI (curves normalised to unit maximum)
    for column, (roi, letter) in enumerate(zip(_BEAD_ROIS, ("(e)", "(f)")), start=1):
        ax = fig.add_subplot(gs[1, column])
        fused_roi_time, spc_roi_time = _roi_profiles(x, spc_data, roi, keep_axis=0)
        fused_roi_time = fused_roi_time / fused_roi_time.max()
        spc_roi_time = spc_roi_time / spc_roi_time.max()

        ax.plot(
            t, fused_roi_time,
            label=f" DF | {TAU}={_fit_lifetime(t, fused_roi_time):.2F} ns",
            c=roi["color"],
            linestyle="-",
        )
        ax.plot(
            t, spc_roi_time,
            label=f"SPC | {TAU}={_fit_lifetime(t, spc_roi_time):.2F} ns",
            c="black",
            linestyle="--",
        )
        ax.set_title(f"{roi['name']} Bead Time Decay")
        ax.set_xlabel("Time [ns]")
        ax.set_ylabel("Intensity [a.u.]")
        ax.grid()
        print(f"{roi['name']} Time RMSE: {rmse(fused_roi_time, spc_roi_time):.2F}")
        ax.legend()
        _add_panel_letter(ax, letter, font_size_letters, "black")

    plt.tight_layout()

    if save_name:
        fig.savefig(FIGURES_PATH / save_name, dpi=300)

    plt.show()


# Final figure for the Adam fusion result: slice index 5 in the image panels,
# lifetimes clipped to (1.5, 3.5), saved under FIGURES_PATH as "results_beads.pdf".
plot_results(x, t, lam, 5, save_name="results_beads.pdf", tau_clip=(1.5, 3.5))