In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pyvista as pv
import torch
from ipywidgets import interact
from matplotlib import cm
from matplotlib.colors import hsv_to_rgb, LinearSegmentedColormap
from scipy.linalg import hadamard
from scipy.optimize import curve_fit

from datafusion.fusion import FusionCG
from datafusion.baseline import baseline
from datafusion.utils import mono_exponential_decay_numpy as decay
from datafusion.utils import (
    RESOURCES_PATH,
    spectral_volume_to_color,
    time_volume_to_lifetime,
    linear_to_srgb,
    wavelength_to_srgb,
    load_data,
)

# Lateral size handed to load_data as max_xy_size.
XY_DIM = 128
# Bounds of the slice applied to the first axis of the CMOS stack after loading.
Z_START, Z_END = 2, -4

# Acquisition archive read by load_data.
CELLS_PATH = RESOURCES_PATH / "acquisitions" / "cells" / "cells_0.25cr.npz"
# Mathtext tau used in the lifetime-fit legend.
TAU = r"$\tau$"

## Data loading


In [None]:
# Load the SPC measurement, the CMOS stack and the time / wavelength axes.
spc, cmos, t, lam = load_data(CELLS_PATH, max_xy_size=XY_DIM)
# Keep only the CMOS slices between Z_START and Z_END along the first axis.
cmos = cmos[slice(Z_START, Z_END)]
# Spacing between the first two samples of the time axis.
dt = t[1] - t[0]

## Data exploration: CMOS Volume

In [None]:
pv.set_jupyter_backend("trame")
pl = pv.Plotter()
print("Volume shape: ", cmos.shape)

# Min-max scale the stack to [0, 1] before rendering.
vol = (cmos - cmos.min()) / np.ptp(cmos)

# The first and last axes are swapped for the renderer; the default scalar bar
# is suppressed so that a custom horizontal one can be added instead.
_ = pl.add_volume(np.swapaxes(vol, 0, -1), cmap="gray_r", show_scalar_bar=False)
pl.add_scalar_bar(
    title="Intensity [a.u.]",
    vertical=False,
    title_font_size=20,
    label_font_size=16,
)
pl.show()

## Data exploration: CMOS + SPC


In [None]:
# SPC summed over its first axis (plotted against `lam` below), with the
# global range used to keep the per-pixel y-axis fixed.
initial_spectrums = spc.sum(axis=0)
min_spectrums, max_spectrums = initial_spectrums.min(), initial_spectrums.max()

# SPC summed over its second axis (plotted against `t` below), with its global range.
initial_times = spc.sum(axis=1)
min_times, max_times = initial_times.min(), initial_times.max()

# Ratio between the last-axis sizes of the CMOS stack and the SPC measurement,
# used to place SPC pixel markers on the CMOS image.
resolution_diff_factor = int(cmos.shape[-1] / spc.shape[-1])
# CMOS stack scaled so that its maximum is 1.
normalized_cmos = cmos / cmos.max()


def spc_spectrum_time_in_a_point(cmos_z=10, spc_i=19, spc_j=17):
    """Show one CMOS slice next to the SPC spectrum and decay at one SPC pixel.

    Top row: CMOS slice ``cmos_z`` with the selected pixel marked, the spectrum
    summed over every other axis, and the decay summed over every other axis.
    Bottom row: colour rendering of the SPC spectrum image with the pixel
    marked, and the spectrum / decay at pixel ``(spc_i, spc_j)``.

    Reads the notebook globals ``spc``-derived ``initial_spectrums`` and
    ``initial_times`` (with their min/max), ``normalized_cmos``,
    ``resolution_diff_factor``, ``lam`` and ``t``.

    Args:
        cmos_z: index of the CMOS slice to display.
        spc_i: row index of the SPC pixel.
        spc_j: column index of the SPC pixel.
    """
    _, ax = plt.subplots(2, 3, figsize=(9, 6))

    # CMOS slice; the SPC pixel index is scaled up to CMOS pixel coordinates.
    ax[0, 0].imshow(normalized_cmos[cmos_z], cmap="gray", vmin=0, vmax=1)
    ax[0, 0].scatter([spc_j * resolution_diff_factor], [spc_i * resolution_diff_factor], c="w")
    ax[0, 0].set_title(f"CMOS in z={cmos_z}")

    # `initial_spectrums` already is spc summed over axis 0, so reuse it rather
    # than reducing the full 4-D array again on every slider move.
    ax[1, 0].imshow(spectral_volume_to_color(lam, initial_spectrums[:, np.newaxis])[0])
    ax[1, 0].scatter([spc_j], [spc_i], c="w")
    ax[1, 0].set_title("SPC Spectrum Colored")

    # Global curves: collapse the two pixel axes of the precomputed marginals.
    ax[0, 1].plot(lam, initial_spectrums.sum(axis=(1, 2)))
    ax[0, 1].set_title(f"Global spectrum - {len(lam)} channels")
    ax[0, 1].set_xlabel("Wavelength [nm]")
    ax[0, 1].grid()

    ax[0, 2].plot(t, initial_times.sum(axis=(1, 2)))
    ax[0, 2].set_title(f"Global time decay - {len(t)} temporal points")
    ax[0, 2].set_xlabel("Time [ns]")
    ax[0, 2].grid()

    # Per-pixel curves share the global y-range so pixels stay comparable.
    ax[1, 1].plot(lam, initial_spectrums[:, spc_i, spc_j])
    ax[1, 1].set_ylim(min_spectrums, max_spectrums)
    ax[1, 1].set_title(f"Spectrum in ({spc_i},{spc_j})")
    ax[1, 1].set_xlabel("Wavelength [nm]")
    ax[1, 1].grid()

    ax[1, 2].plot(t, initial_times[:, spc_i, spc_j])
    ax[1, 2].set_ylim(min_times, max_times)
    ax[1, 2].set_title(f"Time in ({spc_i},{spc_j})")
    ax[1, 2].set_xlabel("Time [ns]")
    ax[1, 2].grid()

    plt.tight_layout()
    plt.show()


# One slider per argument, covering every valid CMOS slice and SPC pixel index.
# The trailing semicolon hides the repr of the returned widget.
interact(
    spc_spectrum_time_in_a_point,
    cmos_z=(0, len(cmos) - 1, 1),
    spc_i=(0, spc.shape[2] - 1, 1),
    spc_j=(0, spc.shape[3] - 1, 1),
);

In [None]:
# Relative weights handed to FusionCG for its two terms.
weights = {
    "spatial": 0.5,
    "spectro_temporal": 0.5,
}

# "mps" is only available on Apple-silicon builds of PyTorch. Use it when
# present and fall back to the CPU otherwise, so the notebook also runs on
# other machines. NOTE(review): assumes FusionCG forwards `device` to PyTorch.
_mps_backend = getattr(torch.backends, "mps", None)
fusion_device = "mps" if _mps_backend is not None and _mps_backend.is_available() else "cpu"

fuse_with_cg = FusionCG(
    spc, cmos,
    weights=weights,
    init_type="baseline",
    tol=1e-6,
    mask_noise=False,
    total_energy=1,
    device=fusion_device,
    seed=42,
)

In [None]:
%%time
# Run the configured fusion (max_iterations=20, eps=1e-8) and time the cell.
# `x` is the fused volume; `spc_out` / `cmos_out` are the SPC and CMOS arrays
# returned alongside it. return_numpy=True presumably yields NumPy arrays,
# which the np.sum calls in the next cell rely on.
x, spc_out, cmos_out = fuse_with_cg(max_iterations=20, eps=1e-8, return_numpy=True)

In [None]:
# Baseline result computed from the arrays returned by the fusion, for reference.
x_baseline = baseline(cmos_out, spc_out, device="cpu", return_numpy=True)

# Colour renderings: the fused volume summed over its first axis, and the SPC
# output summed over its first axis with a singleton axis inserted after the
# wavelength axis.
slices_rgb = spectral_volume_to_color(lam, np.sum(x, axis=0))
slices_rgb_spc = spectral_volume_to_color(
    lam, np.expand_dims(np.sum(spc_out, axis=0), axis=1)
)[0]

# Per-z global curves: rows are indexed by z after the transpose.
spectrums = np.transpose(np.sum(x, axis=(0, 3, 4)))
times = np.transpose(np.sum(x, axis=(1, 3, 4)))


def plot_across_z(z=10, i=77, j=70):
    """Compare the fused volume with the SPC output at depth ``z`` and pixel ``(i, j)``.

    Top row: SPC colour image with the matching SPC pixel marked, spectrum at
    the pixel (fused "DF" vs. SPC), and peak-normalised decay at the pixel
    (fused samples, their mono-exponential fit, and SPC).
    Bottom row: colour slice of the fused volume at ``z`` with the pixel
    marked, and the spectrum / decay of that slice summed over the pixel axes.

    Reads the notebook globals ``x``, ``spc_out``, ``slices_rgb``,
    ``slices_rgb_spc``, ``spectrums``, ``times``, ``lam`` and ``t``.

    Args:
        z: index along the third axis of ``x``.
        i: row index in the fused volume.
        j: column index in the fused volume.
    """

    def unit_peak(curve):
        # Scale to a maximum of 1; a curve without a positive peak is returned
        # unchanged so that no NaNs are produced.
        peak = curve.max()
        return curve / peak if peak > 0 else curve

    # Down-sampling between the fused grid and the SPC grid, derived from the
    # array shapes instead of a hard-coded 4.
    scale_i = max(1, x.shape[-2] // spc_out.shape[-2])
    scale_j = max(1, x.shape[-1] // spc_out.shape[-1])
    spc_i, spc_j = i // scale_i, j // scale_j

    _, ax = plt.subplots(2, 3, figsize=(12, 8))
    ax[0, 0].imshow(slices_rgb_spc)
    ax[0, 0].scatter([spc_j], [spc_i], c="w")
    ax[0, 0].set_title("SPC")

    # Slice-wide curves share the range of all slices so z values stay comparable.
    ax[1, 1].plot(lam, spectrums[z])
    ax[1, 1].set_ylim(spectrums.min(), spectrums.max())
    ax[1, 1].set_title(f"Global Spectrum in z={z}")
    ax[1, 1].set_xlabel("Wavelength [nm]")
    ax[1, 1].grid()

    ax[1, 2].plot(t, times[z])
    ax[1, 2].set_ylim(times.min(), times.max())
    ax[1, 2].set_title(f"Global Time in z={z}")
    ax[1, 2].set_xlabel("Time [ns]")
    ax[1, 2].grid()

    fused_slice = x[:, :, z, :, :]
    reconstructed_spectrums = np.sum(fused_slice, axis=0)
    lxy_spc = np.sum(spc_out, axis=0)

    reconstructed_times = np.sum(fused_slice, axis=1)
    txy_spc = np.sum(spc_out, axis=1)

    ax[1, 0].imshow(slices_rgb[z])
    ax[1, 0].scatter([j], [i], c="w")
    ax[1, 0].set_title(f"Reconstruction Spectral Colored z={z}")

    # NOTE(review): the 10 * 16 display gain is kept from the original; 16 looks
    # like the 4x4 pixel ratio and 10 like a depth-related factor - confirm.
    ax[0, 1].plot(lam, reconstructed_spectrums[:, i, j] * 10 * 16, label="DF")
    ax[0, 1].plot(lam, lxy_spc[:, spc_i, spc_j], label="SPC")
    ax[0, 1].set_title(f"Spectrum in ({i},{j})")
    ax[0, 1].set_xlabel("Wavelength [nm]")
    ax[0, 1].legend(loc="upper right")
    ax[0, 1].grid()

    df_decay = unit_peak(reconstructed_times[:, i, j])

    # Fit only when there is signal; curve_fit raises RuntimeError when the
    # bounded least-squares solver does not converge.
    fit_params = None
    if np.any(df_decay > 0):
        try:
            fit_params, _ = curve_fit(
                decay,
                t,
                df_decay,
                bounds=([0.0, 1e-6, 0.0], [1, 6.0, 0.1]),
                p0=(0.5, 2.0, 0.00001),
                maxfev=5000,
            )
        except RuntimeError:
            fit_params = None

    ax[0, 2].plot(t, unit_peak(txy_spc[:, spc_i, spc_j]), label="SPC", c="C1")
    ax[0, 2].scatter(t, df_decay, label="DF", c="C0", marker="+")
    if fit_params is not None:
        a, tau, c = fit_params
        ax[0, 2].plot(t, decay(t, a, tau, c), label=f"DF Fit {TAU}: {tau:.2f} ns", c="C0")
    else:
        # Keep a legend entry so a missing fit is visible instead of silent.
        ax[0, 2].plot([], [], label=f"DF Fit {TAU}: n/a", c="C0")
    ax[0, 2].set_title(f"Time in ({i},{j})")
    ax[0, 2].set_xlabel("Time [ns]")
    ax[0, 2].legend(loc="upper right")
    ax[0, 2].grid()

    plt.tight_layout()
    plt.show()


# Sliders cover every depth index and every pixel of the fused volume.
# The trailing semicolon hides the repr of the returned widget.
interact(
    plot_across_z,
    z=(0, x.shape[2] - 1, 1),
    i=(0, x.shape[3] - 1, 1),
    j=(0, x.shape[4] - 1, 1),
);

## Presentation Plots

In [None]:
def get_lifetime_tensor(tensor):
    """Return the lifetime volume computed from ``tensor`` on the global time axis ``t``.

    Calls ``time_volume_to_lifetime`` with ``tau_clip=(1, 3)`` and
    ``noise_thr=0.1`` and keeps only the first of its three return values.
    """
    result = time_volume_to_lifetime(t, tensor, tau_clip=(1, 3), noise_thr=0.1)
    # The second and third values (the tau bounds) are not needed by callers.
    lifetime_volume, _tau_min, _tau_max = result
    return lifetime_volume

In [None]:
def plot_lifetime_image(tensor, tau_min, tau_max, save="flim_image.png"):
    """Draw ``tensor`` as a FLIM image with a horizontal lifetime colour bar.

    Args:
        tensor: image passed to ``matshow``.
        tau_min: lifetime printed at the low end of the colour bar.
        tau_max: lifetime printed at the high end of the colour bar.
        save: output path for the figure; a falsy value skips saving.
    """
    n_colors = 100
    n_ticks = 6

    # Fully saturated, full-value HSV ramp whose hue runs from 26/36 down to 0.
    hues = np.linspace(0, 26 / 36, n_colors)[::-1]
    ones = np.ones(n_colors)
    hsv_colormap = np.column_stack((hues, ones, ones))
    lifetime_cmap = LinearSegmentedColormap.from_list(
        "fused_lifetime", hsv_to_rgb(hsv_colormap), N=n_colors
    )

    fig, ax = plt.subplots(1, 1, figsize=(4, 5))
    ax.matshow(tensor, cmap=lifetime_cmap)

    # The colour bar is built from a stand-alone mappable on [0, 1]; its ticks
    # are relabelled so that 0 reads tau_min and 1 reads tau_max.
    mappable = cm.ScalarMappable(norm=None, cmap=lifetime_cmap)
    cbar = fig.colorbar(
        mappable,
        ax=ax,
        fraction=0.046,
        pad=0.03,
        orientation="horizontal",
        label="Lifetime [ns]",
    )
    cbar.set_ticks(np.linspace(0, 1, n_ticks))
    cbar.set_ticklabels([f"{value:.1f}" for value in np.linspace(tau_min, tau_max, n_ticks)])

    plt.title("FLIM Image", fontsize=18)
    plt.tight_layout()
    if save:
        plt.savefig(save)
    plt.show()


# Sum the SPC measurement over its second axis, keep a singleton axis in its
# place, convert to a lifetime volume and take its first element for display.
spc_lifetime = get_lifetime_tensor(np.expand_dims(np.sum(spc, axis=1), axis=1))[0]
# The colour-bar labels use the same 1-3 range as the tau_clip in get_lifetime_tensor.
plot_lifetime_image(spc_lifetime, tau_min=1, tau_max=3)

In [None]:
def plot_multispectral_image(tensor, lambdas, save="spectral_image.png"):
    """Draw a colour rendering of ``tensor`` with a horizontal wavelength colour bar.

    Args:
        tensor: spectral volume passed to ``spectral_volume_to_color``; the
            first element of the result is displayed.
        lambdas: wavelength axis passed to ``spectral_volume_to_color``.
        save: output path for the figure; a falsy value skips saving.
    """
    tick_step = 30

    # Colour-bar palette: one colour per nanometre from 540 to 690, each row
    # rescaled so that its largest channel equals 1.
    wavelengths_ticks = np.arange(540, 691, 1)
    spectral_colors = linear_to_srgb(wavelength_to_srgb(wavelengths_ticks, "basic").T)
    spectral_colors /= spectral_colors.max(axis=1, keepdims=True)
    spectral_cmap = LinearSegmentedColormap.from_list(
        "spectrum", spectral_colors, N=len(wavelengths_ticks)
    )

    tensor = spectral_volume_to_color(lambdas, tensor)[0]

    fig, ax = plt.subplots(1, 1, figsize=(4, 5))
    ax.matshow(tensor, cmap=spectral_cmap)

    # Stand-alone mappable on [0, 1]; ticks are relabelled with every 30th wavelength.
    mappable = cm.ScalarMappable(norm=None, cmap=spectral_cmap)
    cbar = fig.colorbar(
        mappable,
        ax=ax,
        fraction=0.046,
        pad=0.02,
        orientation="horizontal",
        label="Wavelength [nm]",
    )
    labelled = wavelengths_ticks[::tick_step]
    cbar.set_ticks(np.linspace(0, 1, len(labelled)))
    cbar.set_ticklabels([f"{w:.0f}" for w in labelled])

    plt.title("Multispectral Image", fontsize=18)
    plt.tight_layout()
    if save:
        plt.savefig(save)
    plt.show()


# Sum the SPC measurement over its first axis and insert a singleton axis
# after the wavelength axis before rendering.
spc_spectrum = np.expand_dims(np.sum(spc, axis=0), axis=1)
plot_multispectral_image(spc_spectrum, lambdas=lam)

In [None]:
# This cell provides a visualization of the compressed sensing reconstruction process.
# It is for illustration only and does not reflect the actual reconstruction.
# The pattern side length follows XY_DIM (the lateral size requested from
# load_data) instead of repeating the literal 128, so the two cannot drift apart.
img_size = XY_DIM


def normalize(image):
    """Rescale ``image`` linearly so that its values span [0, 1].

    Args:
        image: non-empty array-like of numbers.

    Returns:
        Array of the same shape with the minimum mapped to 0 and the maximum
        mapped to 1. A constant image has no range to stretch, so it maps to
        all zeros instead of dividing by zero and yielding NaNs.
    """
    lowest = np.min(image)
    span = np.max(image) - lowest
    shifted = image - lowest
    if span == 0:
        return np.zeros(np.shape(shifted), dtype=float)
    return shifted / span


def reconstruct(A, y):
    """Least-squares estimate of the image from measurements ``y = A @ x``.

    The solution of ``np.linalg.lstsq`` is reshaped to ``img_size`` x
    ``img_size`` and rescaled with ``normalize``.
    """
    solution = np.linalg.lstsq(A, y, rcond=None)[0]
    image = solution.reshape(img_size, img_size)
    return normalize(image)


def walsh_hadamard(n: int, dtype) -> np.ndarray:
    """Return the ``n`` x ``n`` Hadamard matrix with rows sorted by sign changes.

    Rows of ``scipy.linalg.hadamard(n, dtype)`` are reordered by increasing
    number of value changes between neighbouring entries; ties keep their
    original relative order.
    """
    natural = hadamard(n, dtype)
    # Counted row by row so no second n x n temporary is allocated.
    sign_changes = np.fromiter(
        (np.count_nonzero(np.diff(row)) for row in natural),
        dtype=int,
        count=n,
    )
    order = np.argsort(sign_changes, kind="stable")
    return natural[order]


def measure(A, x):
    """Apply the sensing matrix ``A`` to ``x`` flattened in row-major order."""
    signal = x.reshape(-1)
    return A @ signal


# Sample image: CMOS slice 10 rescaled to [0, 1].
orig_img = normalize(cmos[10])

N = int(np.prod(orig_img.shape[:2]))
# Sequency-ordered patterns, one flattened img_size x img_size mask per row.
W1 = walsh_hadamard(N, float)
# The same patterns with every mask transposed.
W2 = W1.reshape(N, img_size, img_size).transpose(0, 2, 1).reshape(N, N)
# Interleave the two sets row by row: W1[0], W2[0], W1[1], W2[1], ...
W = np.stack([W1, W2], axis=1).reshape(-1, N)

# Frames are written to ./cs; create the directory first so savefig does not
# fail with FileNotFoundError on a fresh checkout.
frames_dir = Path("cs")
frames_dir.mkdir(parents=True, exist_ok=True)

for i in range(32):
    show_index = i + 10
    print(f"Processing step {i + 1} of 32...")
    fig, ax = plt.subplots(1, 4, figsize=(13, 3))

    ax[0].imshow(orig_img, cmap='gray')
    ax[0].set_title("Sample")

    # Pattern shown for this frame (offset by 10 rows into W).
    ax[1].imshow(W[show_index].reshape(img_size, img_size), cmap='gray')
    ax[1].set_title(f"Pattern {i + 1}")

    # Each frame uses 256 more patterns than the previous one.
    A = W[0:(i + 1) * 256]
    # measure() flattens its input itself.
    y = measure(A, orig_img)

    # Only i + 1 measurements, starting at index 10, are drawn.
    ax[2].plot(y[10: show_index + 1])
    ax[2].set_title("Measurements")
    ax[2].set_xlim(0, 32)
    ax[2].set_aspect('auto')
    ax[2].set_yticks([])

    recon = reconstruct(A, y)
    ax[3].imshow(recon, cmap='gray')
    ax[3].set_title("Reconstruction")

    # Hide the axes of the image panels; the measurement plot keeps its x-axis.
    for ax_index, a in enumerate(ax):
        if ax_index != 2:
            a.axis('off')

    print("Saving figure...")
    fig.savefig(frames_dir / f"recon_{i:02d}.png", bbox_inches='tight', dpi=300)
    # Close each figure so 32 high-dpi figures do not accumulate in memory.
    plt.close(fig)

# Stitch the saved frames cs/recon_00.png, cs/recon_01.png, ... into cs/cs.mp4.
# Per the ffmpeg docs: -r 5 reads the frames at 5 fps, -stream_loop 2 replays
# the input two extra times, and -y overwrites an existing output file.
!ffmpeg -stream_loop 2 -r 5 -i cs/recon_%02d.png -vcodec mpeg4 -y cs/cs.mp4