## Imports

In [None]:
# Enable IPython's autoreload extension. Mode 2 reloads imported modules before
# each cell runs, so edits to the local fusion/baseline/utils modules take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import h5py
import numpy as np
import scipy as sp
import skimage as ski
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
from ipywidgets import interact
from matplotlib import patches
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
from matplotlib.colors import LinearSegmentedColormap, hsv_to_rgb

from fusion import fuse
from baseline import baseline
from utils import (
    RESOURCES_PATH,
    spectral_volume_to_color,
    preprocess_raw_spc,
    linear_to_srgb,
    wavelength_to_srgb
)

# Lateral size (in pixels) that the CMOS stack is resized to on load.
XY_DIM = 128

# Slice range kept along the CMOS stack axis (applied as cmos[Z_START:Z_END]);
# the negative end drops the last 6 slices.
Z_START, Z_END = 0, -6

# Bead dataset directory and the preprocessed SPC reconstruction inside it.
BEADS_PATH = RESOURCES_PATH / "beads"
SPC_PATH = (
    BEADS_PATH
    / "480_3beads_triangle_505_500_575_SPC_calib_cut_binned_tlxy.npz"
)


## SPC preprocessing

In [None]:
N_MEASUREMENTS, MAX_TIMES, N_BINS = 1024, 2048, 32

# Reconstruct the raw SPC acquisition with the calibration files stored in
# BEADS_PATH, saving the reconstruction to SPC_PATH (the data-loading cell
# reads it back from there).
preprocess_raw_spc(
    # Acquisition / binning sizes and execution settings.
    n_measurements=N_MEASUREMENTS,
    max_times=MAX_TIMES,
    n_bins=N_BINS,
    n_jobs=8,
    dtype=np.float64,
    # Inputs.
    raw_spc_path=BEADS_PATH / "480_3beads_triangle_505_500_575_SPC_raw.mat",
    forward_matrix_path=BEADS_PATH / "FLIM_Scrambled-Hadamard_1024.mat",
    efficiency_calib_path=BEADS_PATH / "Efficiency_L16_575.mat",
    offset_calib_path=BEADS_PATH / "L16_temporal_offsets_20220520.mat",
    temporal_axis_path=BEADS_PATH / "t.npy",
    # Output.
    reconstruction_save_path=SPC_PATH,
);

## Data loading

In [None]:
LAMBDA_PATH = BEADS_PATH / "575_Lambda_L16.mat"
CMOS_PATH = BEADS_PATH / "3beads_triangle_w4_rec_Hil2D_FOVcorrected.mat"

# CMOS loading. The .mat file is read with h5py (HDF5-based MATLAB file); only
# the read itself needs the file open.
with h5py.File(CMOS_PATH, "r") as f:
    cmos = np.array(f["I"])
# Move the leading (stack) axis last so resize() changes only the two lateral
# axes, then bring the stack axis back to the front. The (2, 1, 0) transpose
# also swaps the two lateral axes relative to the file (presumably to undo
# MATLAB's column-major layout -- confirm).
cmos = np.transpose(cmos, (1, 2, 0))
cmos = ski.transform.resize(cmos, (XY_DIM, XY_DIM, cmos.shape[2]))
cmos = np.transpose(cmos, (2, 1, 0))
cmos = cmos[Z_START:Z_END]

# SPC loading. np.load on an .npz returns an NpzFile that holds the archive
# open, so read both arrays once inside a context manager.
with np.load(SPC_PATH) as spc_archive:
    spc = spc_archive["spc_recon"].swapaxes(-2, -1)
    t = spc_archive["t_cut_binned"]
# FIXME: Replace with correct [0,0] pixel.
spc[:, :, 0, 0] = spc[:, :, 1, 0]

# Time axis: shift to start at zero; dt assumes uniformly spaced bins.
t = t - t.min()
dt = t[1] - t[0]

# Wavelength axis loading
lam = np.squeeze(sp.io.loadmat(LAMBDA_PATH)["lambda"])

# Normalization for plotting
cmos_max = cmos / cmos.max()
spc_max = spc / spc.max()
spc_max = spc / spc.max()

In [None]:
# Per-pixel SPC spectra (time axis summed out) and per-pixel decays
# (wavelength axis summed out), with global extrema so the interactive
# per-pixel plots share fixed y-limits.
initial_spectrums = spc.sum(axis=0)
min_spectrums = initial_spectrums.min()
max_spectrums = initial_spectrums.max()

initial_times = spc.sum(axis=1)
min_times = initial_times.min()
max_times = initial_times.max()

# Integer ratio between the CMOS and SPC lateral grids (last axis).
resolution_diff_factor = cmos.shape[-1] // spc.shape[-1]


def spc_spectrum_time_in_a_point(cmos_z=5, spc_i=16, spc_j=17):
    """Show the SPC data around one SPC pixel next to a CMOS slice.

    Layout (2 x 3):
      top row    -- CMOS slice ``cmos_z`` with the pixel's position on the
                    CMOS grid marked; SPC spectrum summed over time and space;
                    SPC decay summed over wavelength and space.
      bottom row -- SPC image colored by its spectrum with the pixel marked;
                    spectrum and decay of SPC pixel (``spc_i``, ``spc_j``),
                    drawn on y-limits shared by all pixels.
    """
    axes = plt.subplots(2, 3, figsize=(9, 6))[1]
    cmos_ax, global_lam_ax, global_t_ax = axes[0]
    color_ax, pixel_lam_ax, pixel_t_ax = axes[1]

    # The CMOS grid is finer than the SPC grid by resolution_diff_factor.
    cmos_ax.imshow(cmos_max[cmos_z], cmap="gray", vmin=0, vmax=1)
    cmos_ax.scatter(
        [spc_j * resolution_diff_factor],
        [spc_i * resolution_diff_factor],
        c="w",
    )
    cmos_ax.set_title(f"CMOS in Z={cmos_z}")

    rgb_image = spectral_volume_to_color(lam, spc.sum(axis=0)[:, np.newaxis])[0]
    color_ax.imshow(rgb_image)
    color_ax.scatter([spc_j], [spc_i], c="w")
    color_ax.set_title("Colored with spectrum")

    # (axis, x values, y values, y-limits or None, title, x label)
    line_panels = (
        (global_lam_ax, lam, spc.sum(axis=(0, 2, 3)), None,
         f"Global spectrum - {len(lam)} channels", "Wavelength [nm]"),
        (global_t_ax, t, spc.sum(axis=(1, 2, 3)), None,
         f"Global time decay - {len(t)} temporal points", "Time [ns]"),
        (pixel_lam_ax, lam, initial_spectrums[:, spc_i, spc_j],
         (min_spectrums, max_spectrums),
         f"Spectrum in ({spc_i},{spc_j})", "Wavelength [nm]"),
        (pixel_t_ax, t, initial_times[:, spc_i, spc_j],
         (min_times, max_times),
         f"Time in ({spc_i},{spc_j})", "Time [ns]"),
    )
    for panel, xs, ys, y_limits, title, x_label in line_panels:
        panel.plot(xs, ys)
        if y_limits is not None:
            panel.set_ylim(*y_limits)
        panel.set_title(title)
        panel.set_xlabel(x_label)
        panel.grid()

    plt.tight_layout()
    plt.show()


# One slider per viewer argument: CMOS slice, SPC row, SPC column.
n_cmos_slices, n_spc_rows, n_spc_cols = cmos.shape[0], spc.shape[-2], spc.shape[-1]
interact(
    spc_spectrum_time_in_a_point,
    cmos_z=(0, n_cmos_slices - 1, 1),
    spc_i=(0, n_spc_rows - 1, 1),
    spc_j=(0, n_spc_cols - 1, 1),
);

## Fusion through optimization

In [None]:
# Weights handed to fuse(); presumably they weight the terms of its
# objective -- the "global" term is switched off here.
weights = {"spatial": 2.0, "lambda_time": 1.0, "global": 0.0}

x, spc_out, cmos_out = fuse(
    spc,
    cmos,
    weights=weights,
    # Optimisation settings.
    lr=1e-3,
    iterations=150,
    l2_regularization=0,
    device="cpu",
    # Initialisation and constraints.
    init_type="random",
    mask_initializations=True,
    mask_gradients=True,
    non_neg=True,
    total_energy=1000,
    # Output format.
    return_numpy=True,
)

In [None]:
# Sanity check: total intensity of the fused volume and of the two other arrays returned by fuse().
print(*(volume.sum() for volume in (x, cmos_out, spc_out)))

## Fusion through baseline

In [None]:
# Baseline reconstruction from the same CMOS and SPC inputs; it is plotted
# against the optimisation-based fusion result in the visualisation cells.
x_baseline = baseline(cmos, spc, device="cpu", return_numpy=True)

## Visualize results

In [None]:
# Spectrum-colored RGB rendering (time axis summed out) of the fused and the
# baseline volumes.
slices_rgb, slices_rgb_baseline = (
    spectral_volume_to_color(lam, np.sum(volume, axis=0))
    for volume in (x, x_baseline)
)

# Per-z spectrum and per-z decay of the fused volume, summed over the lateral
# axes and transposed so the first index is z (assuming x is indexed
# [time, wavelength, z, row, col] as elsewhere in this notebook). Despite the
# "mean" in their names, the extrema below are of these sums.
spectrums = np.sum(x, axis=(0, 3, 4)).T
times = np.sum(x, axis=(1, 3, 4)).T
min_mean_spectrum, max_mean_spectrum = spectrums.min(), spectrums.max()
min_mean_times, max_mean_times = times.min(), times.max()


def plot_across_z(z=5, i=0, j=0):
    """Compare fused, baseline and SPC data at one z-slice and one pixel.

    Panels:
      (0, 0) baseline slice ``z`` colored by spectrum, pixel (i, j) marked;
      (1, 0) fused slice ``z`` colored by spectrum, pixel (i, j) marked;
      (0, 1) spectrum at (i, j): fused (scaled), baseline and SPC;
      (0, 2) time decay at (i, j): fused (scaled), baseline and SPC;
      (1, 1) / (1, 2) fused spectrum / decay summed over slice ``z``.

    ``i`` and ``j`` index the fused (CMOS-resolution) grid. The SPC curves come
    from the SPC pixel containing (i, j), found by integer division with
    ``resolution_diff_factor``.
    """
    _, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].imshow(slices_rgb_baseline[z])
    ax[0, 0].scatter([j], [i], c="w")
    ax[0, 0].set_title(f"Baseline z={z}")

    ax[1, 1].plot(lam, spectrums[z])
    ax[1, 1].set_ylim(min_mean_spectrum, max_mean_spectrum)
    ax[1, 1].set_title(f"Global Spectrum in z={z}")
    ax[1, 1].grid()

    ax[1, 2].plot(t, times[z])
    ax[1, 2].set_ylim(min_mean_times, max_mean_times)
    ax[1, 2].set_title(f"Global Time in z={z}")
    ax[1, 2].grid()

    # Per-pixel spectra (time summed out) and decays (wavelength summed out)
    # of slice z for the fused and baseline volumes, and of the SPC data.
    reconstructed_spectrums = np.sum(x[:, :, z, :, :], axis=0)
    baseline_spectrums = np.sum(x_baseline[:, :, z, :, :], axis=0)
    lxy_spc = np.sum(spc, axis=0)

    reconstructed_times = np.sum(x[:, :, z, :, :], axis=1)
    baseline_times = np.sum(x_baseline[:, :, z, :, :], axis=1)
    txy_spc = np.sum(spc, axis=1)

    # Map (i, j) to the SPC grid with the notebook's computed CMOS/SPC ratio
    # (the same mapping spc_spectrum_time_in_a_point uses) instead of a
    # hard-coded constant that can disagree with the actual grid sizes.
    spc_i = i // resolution_diff_factor
    spc_j = j // resolution_diff_factor

    ax[1, 0].imshow(slices_rgb[z])
    ax[1, 0].scatter([j], [i], c="w")
    ax[1, 0].set_title(f"Reconstruction Spectral Colored z={z}")

    # NOTE(review): the fused curves are scaled by 15 * 16, presumably to bring
    # them onto the baseline/SPC intensity range; the origin of this factor is
    # not visible here -- confirm.
    ax[0, 1].plot(lam, reconstructed_spectrums[:, i, j] * 15 * 16, label="Datafusion")
    ax[0, 1].plot(lam, baseline_spectrums[:, i, j], label="Baseline")
    ax[0, 1].plot(lam, lxy_spc[:, spc_i, spc_j], label="SPC")
    ax[0, 1].set_title(f"Spectrum in ({i},{j})")
    ax[0, 1].legend(loc="upper right")
    ax[0, 1].grid()

    # t was already shifted to start at zero when it was loaded.
    ax[0, 2].plot(t, reconstructed_times[:, i, j] * 15 * 16, label="Datafusion")
    ax[0, 2].plot(t, baseline_times[:, i, j], label="Baseline")
    ax[0, 2].plot(t, txy_spc[:, spc_i, spc_j], label="SPC")
    ax[0, 2].set_title(f"Time in ({i},{j})")
    ax[0, 2].legend(loc="upper right")
    ax[0, 2].grid()

    plt.tight_layout()
    plt.show()


# One slider per fused-volume axis: z-slice (axis 2), row (-2), column (-1).
interact(
    plot_across_z,
    **{
        argument: (0, x.shape[axis] - 1, 1)
        for argument, axis in (("z", 2), ("i", -2), ("j", -1))
    },
);

In [None]:
def decay(t, A, tau, c):
    """Single-exponential decay model ``A * exp(-t / tau) + c``."""
    return A * np.exp(-t / tau) + c


def time_image_to_lifetime(t, img, max_value, threshold=0.0):
    """Fit a single-exponential decay to every pixel of a time-resolved image.

    Parameters
    ----------
    t : array_like
        Time axis, one entry per frame of ``img``. The lifetime bounds below
        (1e-6 to 6) are in the units of ``t`` (labelled ns in this notebook).
    img : ndarray, shape (T, H, W)
        Decay curve of each pixel along axis 0. Integer images are accepted.
    max_value : float
        Upper bound of the fitted amplitude ``A``; ``max_value / 2`` is the
        initial guess. Must be positive for the fit bounds to be valid.
    threshold : float, optional
        Pixels whose time-summed intensity, normalised by the image maximum,
        is below this value are not fitted and keep zeros. The default 0.0
        fits every pixel of non-negative data.

    Returns
    -------
    a_out, tau_out, c_out : ndarray of float64, shape (H, W)
        Fitted amplitude, lifetime and constant offset per pixel.

    Raises
    ------
    RuntimeError
        From ``curve_fit`` when a pixel fit does not converge within 5000
        function evaluations.
    """
    img = np.asarray(img)
    n_rows, n_cols = img.shape[1], img.shape[2]
    a_out = np.zeros((n_rows, n_cols), dtype=np.float64)
    tau_out = np.zeros_like(a_out)
    c_out = np.zeros_like(a_out)

    # Normalise in float64: an in-place true division fails on integer sums,
    # and an all-zero image must not divide by zero.
    intensity_image = img.sum(axis=0).astype(np.float64)
    peak = intensity_image.max()
    if peak != 0:
        intensity_image /= peak

    # Bounds on (A, tau, c); the offset c is limited to [0, 0.1].
    bounds = ([0.0, 1e-6, 0.0], [max_value, 6.0, 0.1])
    p0 = (max_value / 2.0, 2.0, 1e-6)

    for i, j in np.ndindex(n_rows, n_cols):
        if intensity_image[i, j] < threshold:
            continue  # too dim to fit: leave the zeros in place
        params, _ = curve_fit(
            decay,
            t,
            img[:, i, j],
            bounds=bounds,
            p0=p0,
            maxfev=5000,
        )
        a_out[i, j], tau_out[i, j], c_out[i, j] = params

    return a_out, tau_out, c_out


In [None]:
def _add_scalebar(ax, length_pixels, length_micrometers):
    """Draw a white scale bar labelled in micrometres in the upper-left corner of ``ax``."""
    ax.add_artist(
        AnchoredSizeBar(
            ax.transData,
            length_pixels,  # Length of scalebar in pixels
            f'{length_micrometers} µm',  # Label for the scalebar
            'upper left',
            pad=0.3,
            color='white',
            frameon=False,
            size_vertical=2,
            fontproperties={"size": 12, "weight": "bold"},
        )
    )


def _add_panel_letter(ax, letter, font_size):
    """Write a bold white panel label such as "(a)" in the lower-left corner of ``ax``."""
    ax.text(
        0.05, 0.05, letter,
        transform=ax.transAxes,
        fontsize=font_size,
        fontweight='bold',
        va='bottom',
        c="w"
    )


def _plot_roi_comparison(ax, axis_values, fused_curve, spc_curve, color, title, x_label):
    """Overlay peak-normalised SPC (dashed black) and fused (solid) curves of one ROI."""
    ax.plot(axis_values, spc_curve / spc_curve.max(), label="SPC", c="black", linestyle="--")
    ax.plot(axis_values, fused_curve / fused_curve.max(), label="Datafusion", c=color, linestyle="-")
    ax.set_title(title)
    ax.set_xlabel(x_label)
    ax.set_ylabel("Intensity [a.u.]")
    ax.grid()
    ax.legend()


def plot_zoom_results(
        x, t, lam, z_index,
        font_size_letters=16,
        save_name=None,
        spc_volume=None,
):
    """Figure comparing the fused reconstruction with SPC on two bead ROIs.

    Left column: slice ``z_index`` of the fused volume colored by spectrum (a)
    and by fitted lifetime (d), each with a 30 µm scale bar and both ROIs
    outlined. Middle/right columns: peak-normalised spectra (top row) and
    time decays (bottom row) of the green and yellow ROIs, fused vs. SPC.

    Parameters
    ----------
    x : ndarray
        Fused volume indexed ``x[time, wavelength, z, row, col]``, the axis
        order used throughout this notebook.
    t, lam : ndarray
        Time axis [ns] and wavelength axis [nm] of ``x``.
    z_index : int
        Slice shown in the two images.
    font_size_letters : int, optional
        Font size of the "(a)" / "(d)" panel labels.
    save_name : str or path-like, optional
        If given, the figure is also saved there via ``Figure.savefig``.
    spc_volume : ndarray, optional
        SPC data indexed ``[time, wavelength, row, col]``. Defaults to the
        notebook-level ``spc``.
    """
    if spc_volume is None:
        spc_volume = spc

    # Spectral colormap over 510-649 nm, used for the wavelength colorbar.
    wavelengths_ticks = np.arange(510, 650, 1)
    spectral_colors = wavelength_to_srgb(wavelengths_ticks, "basic").T
    spectral_colors = linear_to_srgb(spectral_colors)
    spectral_colors /= spectral_colors.max(axis=1)[..., np.newaxis]
    spectral_cmap = LinearSegmentedColormap.from_list("spectrum", spectral_colors, N=len(wavelengths_ticks))

    # Scale bar geometry of the fused image.
    fused_micro_width = 150  # field of view of the fused image, in micrometers
    pixel_size = fused_micro_width / x.shape[-1]
    scalebar_length_micrometers = 30
    scalebar_length_pixels = scalebar_length_micrometers / pixel_size

    # Lifetime fit on the selected slice: per-pixel decays (wavelength summed
    # out), normalised to unit total intensity.
    fused = x[:, :, z_index:z_index + 1]
    slice_decays = fused.sum(axis=1)[:, 0]
    slice_decays /= slice_decays.sum()
    fused_a, fused_tau, _ = time_image_to_lifetime(t, slice_decays, slice_decays.max())

    # Hue runs from 260° at the shortest fitted lifetime to 0° at the longest;
    # brightness follows the fitted amplitude.
    fused_tau_min = fused_tau.min()
    fused_tau_max = fused_tau.max()
    fused_h = (260 / 360) * (
            1 - (fused_tau - fused_tau_min) / (fused_tau_max - fused_tau_min))
    fused_s = np.ones_like(fused_tau)
    fused_v = fused_a / fused_a.max()
    fused_lifetime_image = hsv_to_rgb(np.stack([fused_h, fused_s, fused_v], axis=-1))
    fused_spectral_image = spectral_volume_to_color(lam, fused.sum(axis=0))[0]

    fig = plt.figure(figsize=(9, 6))
    gs = fig.add_gridspec(2, 3, width_ratios=[1, 1, 1], height_ratios=[1, 1])

    # (a) Spectral image.
    ax00 = fig.add_subplot(gs[0, 0])
    fused00 = ax00.imshow(fused_spectral_image, cmap=spectral_cmap)
    ax00.set_title("Fused Spectrum")
    _add_panel_letter(ax00, "(a)", font_size_letters)
    _add_scalebar(ax00, scalebar_length_pixels, scalebar_length_micrometers)
    ax00.axis("off")

    cbar = fig.colorbar(fused00, ax=ax00, fraction=0.046, pad=0.03, orientation="horizontal", label="Wavelength [nm]")
    cbar.set_ticks(np.linspace(0, 1, len(wavelengths_ticks[::30])))
    cbar.set_ticklabels([f"{w:.0f}" for w in wavelengths_ticks[::30]])

    # (d) Lifetime image. Its colorbar maps 0 -> shortest, 1 -> longest lifetime.
    hsv_colormap = np.vstack([np.linspace(0, fused_h.max(), 100)[::-1], np.ones(100), np.ones(100)]).T
    fused_lifetime_cmap = LinearSegmentedColormap.from_list("fused_lifetime", hsv_to_rgb(hsv_colormap), N=100)
    ax10 = fig.add_subplot(gs[1, 0])
    fused10 = ax10.imshow(fused_lifetime_image, cmap=fused_lifetime_cmap)
    ax10.set_title("Fused Lifetime")
    _add_panel_letter(ax10, "(d)", font_size_letters)
    _add_scalebar(ax10, scalebar_length_pixels, scalebar_length_micrometers)
    ax10.axis("off")

    cbar = fig.colorbar(fused10, ax=ax10, fraction=0.046, pad=0.03, orientation="horizontal", label="Lifetime [ns]")
    cbar.set_ticks(np.linspace(0, 1, 6))
    cbar.set_ticklabels([f"{tau:.1f}" for tau in np.linspace(fused_tau_min, fused_tau_max, 6)])

    # ROIs on the fused grid: (edge color, bead name, row slice, col slice,
    # outline (col, row, side), outline line width).
    # NOTE(review): the green slice spans 13 pixels while its outline is drawn
    # 12 wide; kept as in the original figure -- confirm which is intended.
    rois = (
        ("green", "Green", slice(45, 58), slice(75, 88), (75, 45, 12), 2),
        ("gold", "Yellow", slice(57, 73), slice(61, 77), (61, 57, 16), 1.5),
    )
    for image_ax in (ax00, ax10):
        for color, _, _, _, (col0, row0, side), line_width in rois:
            image_ax.add_patch(patches.Rectangle(
                (col0, row0), side, side, linewidth=line_width, edgecolor=color, facecolor="none"))

    # Map each ROI onto the coarser SPC grid by the ratio of the lateral sizes
    # rather than a hard-coded factor. The fused ROIs include every z-slice,
    # matching the SPC data, which has no z axis.
    spc_factor = x.shape[-1] // spc_volume.shape[-1]
    roi_data = []
    for color, name, rows, cols, _, _ in rois:
        spc_rows = slice(rows.start // spc_factor, rows.stop // spc_factor)
        spc_cols = slice(cols.start // spc_factor, cols.stop // spc_factor)
        roi_data.append((color, name, x[:, :, :, rows, cols], spc_volume[:, :, spc_rows, spc_cols]))

    # Top row: ROI spectra (time, z and space summed out).
    for column, (color, name, fused_roi, spc_roi) in enumerate(roi_data, start=1):
        _plot_roi_comparison(
            fig.add_subplot(gs[0, column]), lam,
            fused_roi.sum(axis=(0, 2, 3, 4)), spc_roi.sum(axis=(0, 2, 3)),
            color, f"{name} Bead Spectrum", "Wavelength [nm]",
        )

    # Bottom row: ROI time decays (wavelength, z and space summed out).
    for column, (color, name, fused_roi, spc_roi) in enumerate(roi_data, start=1):
        _plot_roi_comparison(
            fig.add_subplot(gs[1, column]), t,
            fused_roi.sum(axis=(1, 2, 3, 4)), spc_roi.sum(axis=(1, 2, 3)),
            color, f"{name} Bead Time Decay", "Time [ns]",
        )

    plt.tight_layout()
    # Save before show(): a non-interactive backend may discard the figure.
    if save_name is not None:
        fig.savefig(save_name)
    plt.show()


# Zoomed comparison figure for fused slice 5.
plot_zoom_results(x, t, lam, z_index=5)

