## Imports

In [None]:
# Enable IPython's autoreload extension so edits to imported modules
# (e.g. the datafusion package) are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import skimage as ski
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from scipy.optimize import curve_fit
from ipywidgets import interact
from matplotlib import patches
from matplotlib.colors import LinearSegmentedColormap, hsv_to_rgb

from datafusion.fusion import FusionAdam, FusionCG
from datafusion.baseline import baseline
from datafusion.utils import mono_exponential_decay_numpy as decay
from datafusion.utils import (
    RESOURCES_PATH,
    FIGURES_PATH,
    spectral_volume_to_color,
    time_volume_to_lifetime,
    linear_to_srgb,
    wavelength_to_srgb,
    sam,
)

# LaTeX tau symbol used in plot titles and legend labels.
TAU = "$\\tau$"
# Side length (pixels) the CMOS stack is resized to in the data-loading cell.
XY_DIM = 128
# Bounds of the z-slices kept from the CMOS stack (Z_END = -6 drops the last six).
Z_START, Z_END = 0, -6

# Bead acquisition archive (CMOS stack, SPC data, axes and metadata).
DATA_PATH = RESOURCES_PATH / "beads.npz"

## Data loading

In [None]:
# Read every array up front. np.load keeps the .npz file open until it is
# closed, so the context manager releases the handle once everything is loaded.
with np.load(DATA_PATH) as data:
    # CMOS
    cmos = data["cmos"]
    # SPC loading
    spc = data["spc"]
    # Time axis loading
    t = data["time_axis"]
    # Wavelength axis loading
    lam = data["spectral_axis"]
    metadata = data["metadata"]

# Fixes alignment between CMOS and SPC for this slide: shift every slice by
# (-15, -7) pixels along its two in-plane axes, for all slices at once.
cmos = np.roll(cmos, shift=(-15, -7), axis=(1, 2))

# Move the slice axis last and keep its length in the resize, so only the two
# in-plane axes are resampled to XY_DIM x XY_DIM.
cmos = np.transpose(cmos, (1, 2, 0))
cmos = ski.transform.resize(cmos, (XY_DIM, XY_DIM, cmos.shape[2]))
# NOTE(review): (2, 1, 0) is not the inverse of (1, 2, 0). It also swaps the
# two in-plane axes, so each slice ends up transposed relative to the file.
# This may be intentional (to match the SPC orientation, together with the
# roll above). Confirm before changing it, because the fusion depends on the
# resulting alignment.
cmos = np.transpose(cmos, (2, 1, 0))
cmos = cmos[Z_START:Z_END]

# Time-bin width; assumes uniformly spaced time bins.
dt = t[1] - t[0]

print(metadata)

In [None]:
# Per-pixel SPC spectra (time summed out) and per-pixel decays (wavelength
# summed out), plus their global extrema so the interactive plots below use
# y-limits shared by every pixel.
initial_spectrums = spc.sum(axis=0)
initial_times = spc.sum(axis=1)
min_spectrums, max_spectrums = initial_spectrums.min(), initial_spectrums.max()
min_times, max_times = initial_times.min(), initial_times.max()

# Truncated ratio of CMOS width to SPC width: CMOS pixels per SPC pixel.
resolution_diff_factor = int(cmos.shape[-1] / spc.shape[-1])


def spc_spectrum_time_in_a_point(cmos_z=5, spc_i=16, spc_j=17):
    """Inspect the raw SPC data at one pixel, next to a CMOS slice.

    Top row:
        - CMOS slice ``cmos_z`` scaled to [0, 1], with the SPC pixel marked
          on the CMOS grid.
        - The spectrum of the whole SPC acquisition.
        - The time decay of the whole SPC acquisition.

    Bottom row:
        - The SPC image coloured by spectrum, with the pixel marked.
        - The spectrum at SPC pixel (``spc_i``, ``spc_j``).
        - The decay at that pixel.
        Both pixel curves use y-limits shared by all pixels.

    Reads the notebook globals ``cmos``, ``spc``, ``lam``, ``t``,
    ``initial_spectrums`` / ``initial_times`` and their min/max, and
    ``resolution_diff_factor``.
    """

    def draw_curve(axis, xs, ys, title, xlabel, ylim=None):
        # Line plot with the shared decorations; ``ylim`` pins the y-range.
        axis.plot(xs, ys)
        if ylim is not None:
            axis.set_ylim(*ylim)
        axis.set_title(title)
        axis.set_xlabel(xlabel)
        axis.grid()

    _, axes = plt.subplots(2, 3, figsize=(9, 6))
    (ax_cmos, ax_all_lam, ax_all_t), (ax_rgb, ax_px_lam, ax_px_t) = axes

    cmos_slice = cmos[cmos_z]
    ax_cmos.imshow(cmos_slice / cmos_slice.max(), cmap="gray", vmin=0, vmax=1)
    # The SPC pixel index is scaled onto the finer CMOS grid.
    ax_cmos.scatter(
        [spc_j * resolution_diff_factor],
        [spc_i * resolution_diff_factor],
        c="w",
    )
    ax_cmos.set_title(f"CMOS in Z={cmos_z}")

    ax_rgb.imshow(spectral_volume_to_color(lam, spc.sum(axis=0)[:, np.newaxis])[0])
    ax_rgb.scatter([spc_j], [spc_i], c="w")
    ax_rgb.set_title("Colored with spectrum")

    draw_curve(
        ax_all_lam, lam, spc.sum(axis=(0, 2, 3)),
        f"Global spectrum - {len(lam)} channels", "Wavelength [nm]",
    )
    draw_curve(
        ax_all_t, t, spc.sum(axis=(1, 2, 3)),
        f"Global time decay - {len(t)} temporal points", "Time [ns]",
    )
    draw_curve(
        ax_px_lam, lam, initial_spectrums[:, spc_i, spc_j],
        f"Spectrum in ({spc_i},{spc_j})", "Wavelength [nm]",
        ylim=(min_spectrums, max_spectrums),
    )
    draw_curve(
        ax_px_t, t, initial_times[:, spc_i, spc_j],
        f"Time in ({spc_i},{spc_j})", "Time [ns]",
        ylim=(min_times, max_times),
    )

    plt.tight_layout()
    plt.show()


# One slider per index: every CMOS z-slice and every SPC row / column.
interact(
    spc_spectrum_time_in_a_point,
    cmos_z=(0, len(cmos) - 1, 1),
    spc_i=(0, spc.shape[-2] - 1, 1),
    spc_j=(0, spc.shape[-1] - 1, 1),
);

## Fusion through optimization

In [None]:
# Weights of the fusion loss terms (presumably the "Space" and "LambdaTime"
# terms reported by .loss() further down). A "global": 0.0 entry was left
# disabled.
weights = {"spatial": 0.2, "lambda_time": 0.8}

fuse_with_adam = FusionAdam(
    spc,
    cmos,
    weights=weights,
    init_type="baseline",
    mask_noise=True,
    tol=1e-6,
    total_energy=1,
    device="cpu",
    seed=42,
)
x, spc_out, cmos_out = fuse_with_adam(
    iterations=100,
    lr=1e-8,
    return_numpy=True,
    non_neg=False,
)

In [None]:
# Same fusion problem and weights, solved with FusionCG for comparison with
# the Adam run above.
fuse_with_cg = FusionCG(
    spc,
    cmos,
    weights=weights,
    init_type="baseline",
    tol=1e-6,
    mask_noise=True,
    total_energy=1,
    device="cpu",
    seed=42,
)
x_cg, spc_out_cg, cmos_out_cg = fuse_with_cg(
    iterations=40,
    eps=1e-8,
    return_numpy=True,
)

In [None]:
# Final objective of each solver, split into its Space and LambdaTime terms.
(loss_adam_space, loss_adam_lambda_time), (loss_cg_space, loss_cg_lambda_time) = (
    fuse_with_adam.loss(),
    fuse_with_cg.loss(),
)
print(f"F(x_adam): Space: {loss_adam_space:.2E}, LambdaTime: {loss_adam_lambda_time:.2E}")
print(f"F(x_cg): Space: {loss_cg_space:.2E}, LambdaTime: {loss_cg_lambda_time:.2E}")

In [None]:
# Total sum of each solver's outputs (fused volume, CMOS output, SPC output),
# to compare against the total_energy=1 passed to both solvers.
for _outputs in ((x, cmos_out, spc_out), (x_cg, cmos_out_cg, spc_out_cg)):
    print(*(arr.sum() for arr in _outputs))

## Fusion through baseline

In [None]:
# Reference reconstruction from datafusion.baseline.
# NOTE(review): x_baseline is not used by any later cell in this notebook.
x_baseline = baseline(
    cmos,
    spc,
    device="cpu",
    return_numpy=True,
)

## Visualize results

In [None]:
# Inputs for plot_across_z:
#   - spectral-colour renderings of every fused z-slice and of the SPC output
#     (a singleton z axis is inserted so it has the same layout);
#   - per-z-slice spectra and decays of the fused volume, transposed so z
#     comes first.
slices_rgb = spectral_volume_to_color(lam, x.sum(axis=0))
slices_rgb_spc = spectral_volume_to_color(lam, spc_out.sum(axis=0)[:, np.newaxis])[0]

spectrums = x.sum(axis=(0, 3, 4)).T
times = x.sum(axis=(1, 3, 4)).T


def plot_across_z(z=5, i=60, j=70):
    """Compare the fused reconstruction with the SPC output at one pixel.

    Layout (2 x 3):
        [0, 0] SPC image coloured by spectrum, SPC pixel under (i, j) marked.
        [0, 1] DF spectrum at (z, i, j) vs. SPC spectrum at that SPC pixel.
        [0, 2] Peak-normalised SPC and DF decays, plus a mono-exponential fit
               of the DF decay (omitted if the fit fails).
        [1, 0] DF spectral-colour image of slice ``z`` with (i, j) marked.
        [1, 1] Spectrum of slice ``z`` summed over the image.
        [1, 2] Decay of slice ``z`` summed over the image.

    Parameters
    ----------
    z : int
        z-slice of the fused volume ``x``.
    i, j : int
        Row / column on the fused (CMOS-resolution) grid.

    Reads the notebook globals ``x``, ``spc_out``, ``lam``, ``t``,
    ``slices_rgb``, ``slices_rgb_spc``, ``spectrums``, ``times``, ``TAU`` and
    ``decay``.
    """
    # Fused-grid pixels per SPC pixel, derived from the array shapes instead of
    # a hard-coded 4 so the SPC lookup stays correct at other resolutions.
    factor = max(x.shape[-1] // spc_out.shape[-1], 1)
    si, sj = i // factor, j // factor

    def peak_normalised(v):
        # Scale to a peak of 1; an all-zero curve is returned unchanged rather
        # than turned into NaNs (which would also make curve_fit raise).
        peak = v.max()
        return v / peak if peak != 0 else v

    _, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].imshow(slices_rgb_spc)
    ax[0, 0].scatter([sj], [si], c="w")
    ax[0, 0].set_title("SPC")

    ax[1, 1].plot(lam, spectrums[z])
    ax[1, 1].set_ylim(spectrums.min(), spectrums.max())
    ax[1, 1].set_title(f"Global Spectrum in z={z}")
    ax[1, 1].grid()

    ax[1, 2].plot(t, times[z])
    ax[1, 2].set_ylim(times.min(), times.max())
    ax[1, 2].set_title(f"Global Time in z={z}")
    ax[1, 2].grid()

    # Only the selected pixel is needed, so reduce just that pixel's
    # (time, wavelength) slab instead of whole images on every slider move.
    # Axis order (time, wavelength, z, row, col) for x and
    # (time, wavelength, row, col) for spc_out follows from how spectrums /
    # times / slices_rgb_spc are built.
    df_pixel = x[:, :, z, i, j]
    spc_pixel = spc_out[:, :, si, sj]
    df_spectrum = df_pixel.sum(axis=0)
    df_decay = peak_normalised(df_pixel.sum(axis=1))
    spc_spectrum = spc_pixel.sum(axis=0)
    spc_decay = peak_normalised(spc_pixel.sum(axis=1))

    ax[1, 0].imshow(slices_rgb[z])
    ax[1, 0].scatter([j], [i], c="w")
    ax[1, 0].set_title(f"Reconstruction Spectral Colored z={z}")

    # NOTE(review): 10 * 16 is an unexplained display scale for the DF
    # spectrum (16 may be factor**2 fused pixels per SPC pixel) - confirm.
    ax[0, 1].plot(lam, df_spectrum * 10 * 16, label="DF")
    ax[0, 1].plot(lam, spc_spectrum, label="SPC")
    ax[0, 1].set_title(f"Spectrum in ({i},{j})")
    ax[0, 1].legend(loc="upper right")
    ax[0, 1].grid()

    ax[0, 2].plot(t, spc_decay, label="SPC", c="C1")
    ax[0, 2].scatter(t, df_decay, label="DF", c="C0", marker="+")
    try:
        (a, tau, c), _ = curve_fit(
            decay,
            t,
            df_decay,
            bounds=([0.0, 1e-6, 0.0], [1, 6.0, 0.1]),
            p0=(0.5, 2.0, 0.00001),
            maxfev=5000,
        )
    except (RuntimeError, ValueError):
        # No convergence within maxfev, or degenerate data: keep the widget
        # usable and just omit the fitted curve.
        pass
    else:
        ax[0, 2].plot(t, decay(t, a, tau, c), label=f"DF Fit {TAU}: {tau:.2f} ns", c="C0")
    ax[0, 2].set_title(f"Time in ({i},{j})")
    ax[0, 2].legend(loc="upper right")
    ax[0, 2].grid()

    plt.tight_layout()
    plt.show()


# Sliders over every z-slice, row and column of the fused volume x.
interact(plot_across_z, z=(0, x.shape[2] - 1, 1), i=(0, x.shape[-2] - 1, 1), j=(0, x.shape[-1] - 1, 1));

In [None]:
from datafusion.plot_helpers import add_letter, add_scalebar, add_z_text, exp_fit

def _roi_profiles(fused, spc, rows, cols, factor):
    """Return the raw (DF spectrum, SPC spectrum, DF decay, SPC decay) of one ROI.

    ``rows`` / ``cols`` are slices on the fused grid. The SPC ROI is the same
    region with its bounds floor-divided by ``factor``. ``fused`` is indexed
    (time, wavelength, z, row, col) and is summed over every z-slice;
    ``spc`` is indexed (time, wavelength, row, col).
    """
    spc_rows = slice(rows.start // factor, rows.stop // factor)
    spc_cols = slice(cols.start // factor, cols.stop // factor)
    fused_roi = fused[:, :, :, rows, cols]
    spc_roi = spc[:, :, spc_rows, spc_cols]
    return (
        fused_roi.sum(axis=(0, 2, 3, 4)),
        spc_roi.sum(axis=(0, 2, 3)),
        fused_roi.sum(axis=(1, 2, 3, 4)),
        spc_roi.sum(axis=(1, 2, 3)),
    )


def _label_spectral_colorbar(cbar, wavelengths, step=30):
    """Label every ``step``-th wavelength at its true position on the colorbar.

    The colormap spreads ``wavelengths`` evenly over [0, 1], so wavelength ``w``
    sits at ``(w - w_first) / (w_last - w_first)``. Spreading the labels evenly
    instead would put e.g. "630" at the end of a colormap that actually ends at
    ``wavelengths[-1]``. Assumes the colorbar's range is [0, 1], as the
    original tick placement did.
    """
    shown = wavelengths[::step]
    cbar.set_ticks((shown - wavelengths[0]) / (wavelengths[-1] - wavelengths[0]))
    cbar.set_ticklabels([f"{w:.0f}" for w in shown])


def _label_lifetime_colorbar(cbar, tau_clip, n_ticks=6):
    """Label a [0, 1] colorbar with ``n_ticks`` lifetimes evenly spanning ``tau_clip``."""
    cbar.set_ticks(np.linspace(0, 1, n_ticks))
    cbar.set_ticklabels([f"{tau:.1f}" for tau in np.linspace(tau_clip[0], tau_clip[1], n_ticks)])


def plot_beads_results(
        x, t, lam, z_index, spc,
        font_size_letters=16,
        save_name=None,
        tau_clip=(1, 4),
        fov_size=150,
        cmos_volume=None,
):
    """Draw the 2 x 4 beads results figure comparing SPC, CMOS and data fusion (DF).

    Panels:
        (a) SPC intensity.
        (e) CMOS intensity at ``z_index``.
        (b) / (c) SPC / DF spectral-colour images.
        (f) / (g) SPC / DF lifetime images.
        (d) Spectra of the two bead ROIs (green and gold boxes on the DF panels).
        (h) Time decays of the same two ROIs.

    Parameters
    ----------
    x : fused volume indexed (time, wavelength, z, row, col).
    t : time axis [ns].
    lam : wavelength axis [nm].
    z_index : z-slice shown in the DF and CMOS panels.
    spc : SPC data indexed (time, wavelength, row, col).
    font_size_letters : font size of the panel letters.
    save_name : if truthy, the figure is also saved to
        ``FIGURES_PATH / save_name`` at 300 dpi.
    tau_clip : lifetime range passed to ``time_volume_to_lifetime`` and used
        for the lifetime colorbar labels.
    fov_size : field-of-view size used for the scale bars (presumably
        micrometres, given ``length_micrometers=30``).
    cmos_volume : CMOS stack indexed (z, row, col); defaults to the notebook
        global ``cmos``, as before this parameter existed.

    Also prints the lifetime ranges and the SAM angle of each ROI.
    """
    if cmos_volume is None:
        cmos_volume = cmos
    # Fused-grid pixels per SPC pixel (previously hard-coded as 4 in the ROIs).
    factor = max(x.shape[-1] // spc.shape[-1], 1)

    # SPECTRAL COLORMAP
    wavelengths_ticks = np.arange(510, 650, 1)
    spectral_colors = wavelength_to_srgb(wavelengths_ticks, "basic").T
    spectral_colors = linear_to_srgb(spectral_colors)
    spectral_colors /= spectral_colors.max(axis=1)[..., np.newaxis]
    spectral_cmap = LinearSegmentedColormap.from_list("spectrum", spectral_colors, N=len(wavelengths_ticks))

    # LIFETIME COLORMAP: HSV hue from 26/36 down to 0 at full saturation and value.
    hsv_colormap = np.vstack([np.linspace(0, 26 / 36, 100)[::-1], np.ones(100), np.ones(100)]).T
    lifetime_cmap = LinearSegmentedColormap.from_list("fused_lifetime", hsv_to_rgb(hsv_colormap), N=100)

    fused = x[:, :, z_index:z_index + 1]
    fused_spectral_image = spectral_volume_to_color(lam, fused.sum(axis=0))[0]
    fused_lifetime_image, tau_min, tau_max = time_volume_to_lifetime(
        t, fused.sum(axis=1), tau_clip=tau_clip, max_tau=6,
    )
    fused_lifetime_image = fused_lifetime_image[0]

    spc_spectral_image = spectral_volume_to_color(lam, spc.sum(axis=0)[:, np.newaxis])[0]
    spc_lifetime_image, spc_min_t, spc_max_t = time_volume_to_lifetime(
        t, spc.sum(axis=1)[:, np.newaxis], tau_clip=tau_clip, max_tau=6,
    )
    spc_lifetime_image = spc_lifetime_image[0]

    print(tau_min, tau_max)
    print(spc_min_t, spc_max_t)

    # Bead ROIs: boxes drawn on the DF panels and the slices they are compared
    # over. NOTE(review): the boxes and slices differ by about a pixel (the
    # green box is 12 px wide, its slice 13 px); values kept as tuned.
    roi_boxes = (
        ((74, 44), 12, 12, 2, "green"),
        ((60, 56), 16, 16, 1.5, "gold"),
    )
    green = _roi_profiles(x, spc, slice(45, 58), slice(75, 88), factor)
    gold = _roi_profiles(x, spc, slice(57, 73), slice(61, 77), factor)

    def add_roi_boxes(axis):
        # Outline both ROIs on a DF panel.
        for xy, width, height, line_width, colour in roi_boxes:
            axis.add_patch(patches.Rectangle(
                xy, width, height, linewidth=line_width, edgecolor=colour, facecolor="none",
            ))

    fig = plt.figure(figsize=(11, 6))
    gs = fig.add_gridspec(2, 4, width_ratios=[1, 1, 1, 1.2], height_ratios=[1, 1])
    cbar_kw = dict(fraction=0.046, pad=0.03, orientation="horizontal")

    # SPC INTENSITY IMAGE
    ax = fig.add_subplot(gs[0, 0])
    add_letter(ax, "a", font_size=font_size_letters, pos=(0.05, 0.05))
    spc_intensity = spc.sum(axis=(0, 1))
    # Out-of-place division also works if spc holds integer counts.
    spc_intensity = spc_intensity / spc_intensity.max()
    ax_img = ax.imshow(spc_intensity, cmap="gray")
    cbar = fig.colorbar(ax_img, ax=ax, **cbar_kw, label="Intensity [a.u.]")
    cbar.ax.xaxis.set_major_locator(ticker.MultipleLocator(0.25))
    add_scalebar(ax, fov_size / spc.shape[-1], length_micrometers=30, size_vertical=1)
    ax.set_title("SPC Intensity Image")
    ax.axis("off")

    # CMOS INTENSITY IMAGE
    ax = fig.add_subplot(gs[1, 0])
    add_letter(ax, "e", font_size=font_size_letters, pos=(0.05, 0.05))
    cmos_intensity = cmos_volume[z_index]
    # Out-of-place: cmos_volume[z_index] is a view, so `/=` would silently
    # rescale the caller's CMOS stack.
    cmos_intensity = cmos_intensity / cmos_intensity.max()
    ax_img = ax.imshow(cmos_intensity, cmap="gray")
    cbar = fig.colorbar(ax_img, ax=ax, **cbar_kw, label="Intensity [a.u.]")
    cbar.ax.xaxis.set_major_locator(ticker.MultipleLocator(0.25))
    add_scalebar(ax, fov_size / cmos_volume.shape[-1], length_micrometers=30, size_vertical=4)
    add_z_text(ax, z_index)
    ax.set_title("CMOS Intensity Image")
    ax.axis("off")

    # SPC SPECTRAL
    ax = fig.add_subplot(gs[0, 1])
    add_letter(ax, "b", font_size=font_size_letters, pos=(0.05, 0.05))
    ax_img = ax.imshow(spc_spectral_image, cmap=spectral_cmap)
    cbar = fig.colorbar(ax_img, ax=ax, **cbar_kw, label="Wavelength [nm]")
    _label_spectral_colorbar(cbar, wavelengths_ticks)
    add_scalebar(ax, fov_size / spc.shape[-1], length_micrometers=30, size_vertical=1)
    ax.set_title("SPC Multispectral Image")
    ax.axis("off")

    # FUSED SPECTRAL
    ax = fig.add_subplot(gs[0, 2])
    add_letter(ax, "c", font_size=font_size_letters, pos=(0.05, 0.05))
    ax_img = ax.imshow(fused_spectral_image, cmap=spectral_cmap)
    cbar = fig.colorbar(ax_img, ax=ax, **cbar_kw, label="Wavelength [nm]")
    _label_spectral_colorbar(cbar, wavelengths_ticks)
    add_roi_boxes(ax)
    add_z_text(ax, z_index)
    add_scalebar(ax, fov_size / x.shape[-1], length_micrometers=30, size_vertical=4)
    ax.set_title("DF Multispectral Image")
    ax.axis("off")

    # SPC LIFETIME
    ax = fig.add_subplot(gs[1, 1])
    add_letter(ax, "f", font_size=font_size_letters, color="w")
    ax_img = ax.imshow(spc_lifetime_image, cmap=lifetime_cmap)
    cbar = fig.colorbar(ax_img, ax=ax, **cbar_kw, label="Lifetime [ns]")
    _label_lifetime_colorbar(cbar, tau_clip)
    add_scalebar(ax, fov_size / spc.shape[-1], length_micrometers=30, size_vertical=1)
    ax.set_title("SPC FLIM Image")
    ax.axis("off")

    # FUSED LIFETIME
    ax = fig.add_subplot(gs[1, 2])
    add_letter(ax, "g", font_size=font_size_letters, color="w")
    ax_img = ax.imshow(fused_lifetime_image, cmap=lifetime_cmap)
    cbar = fig.colorbar(ax_img, ax=ax, **cbar_kw, label="Lifetime [ns]")
    _label_lifetime_colorbar(cbar, tau_clip)
    add_roi_boxes(ax)
    add_scalebar(ax, fov_size / x.shape[-1], length_micrometers=30, size_vertical=4)
    add_z_text(ax, z_index)
    ax.set_title("DF FLIM Image")
    ax.axis("off")

    # SPECTRAL COMPARISON ROI (each spectrum normalised to unit sum)
    ax = fig.add_subplot(gs[0, 3])
    add_letter(ax, "d", font_size=font_size_letters, color="black")
    df_spectrum, spc_spectrum = green[0] / green[0].sum(), green[1] / green[1].sum()
    ax.plot(lam, df_spectrum, label="DF", c="green", linestyle="-")
    ax.plot(lam, spc_spectrum, c="black", linestyle="--")
    print(f"SAM Green: {sam(df_spectrum, spc_spectrum):.2F} rad")

    df_spectrum, spc_spectrum = gold[0] / gold[0].sum(), gold[1] / gold[1].sum()
    ax.plot(lam, df_spectrum, label="DF", c="gold", linestyle="-")
    ax.plot(lam, spc_spectrum, label="SPC", c="black", linestyle="--")
    print(f"SAM Yellow: {sam(df_spectrum, spc_spectrum):.2F} rad")

    ax.set_title("Beads Spectra")
    ax.set_xlabel("Wavelength [nm]")
    ax.set_ylabel("Intensity [a.u.]")
    ax.set_ylim(-0.01, 0.17)
    ax.xaxis.set_major_locator(ticker.MultipleLocator(30))
    ax.grid()
    ax.legend()

    # TIME COMPARISON ROI (each decay normalised to a peak of 1)
    ax = fig.add_subplot(gs[1, 3])
    add_letter(ax, "h", font_size=font_size_letters, color="black")

    # Green lifetime ROI
    df_time = green[2] / green[2].max()
    params = exp_fit(t, df_time)
    ax.plot(t, df_time, label=f" DF | {TAU}={params[1]:.2F} ns", c="green", linestyle="-")
    spc_time = green[3] / green[3].max()
    params = exp_fit(t, spc_time)
    ax.plot(t, spc_time, label=f"SPC | {TAU}={params[1]:.2F} ns", c="black", linestyle="--")

    # Yellow lifetime ROI
    df_time = gold[2] / gold[2].max()
    params = exp_fit(t, df_time)
    ax.plot(t, df_time, label=f" DF | {TAU}={params[1]:.2F} ns", c="gold", linestyle="-")
    spc_time = gold[3] / gold[3].max()
    params = exp_fit(t, spc_time)
    ax.plot(t, spc_time, label=f"SPC | {TAU}={params[1]:.2F} ns", c="black", linestyle=":")

    ax.grid()
    ax.legend()
    ax.set_ylabel("Intensity [a.u.]")
    ax.set_xlabel("Time [ns]")
    ax.set_title("Beads Time Decay")

    plt.tight_layout()

    if save_name:
        plt.savefig(FIGURES_PATH / save_name, dpi=300)

    plt.show()


# Final beads figure at z-slice 5, also saved as FIGURES_PATH / "results_beads.pdf".
plot_beads_results(
    x,
    t,
    lam,
    z_index=5,
    spc=spc,
    save_name="results_beads.pdf",
    tau_clip=(0.0, 3.8),
)