# Test `merfish3d-analysis` on simulated MERFISH data
The goal of this notebook is to show the performance of [`merfish3d-analysis`](https://github.com/QI2lab/merfish3d-analysis) on simulated MERFISH data. The output metric is the [F1-score](https://en.wikipedia.org/wiki/F-score), which measures how well `merfish3d-analysis` recovers the ground-truth location and identity of the RNA molecules used to generate the simulation. We will use a single field of view (FOV) with uniformly distributed RNA molecules. **Note:** `merfish3d-analysis` requires a GPU runtime and will not run without one.


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/QI2lab/merfish3d-analysis/blob/fix-install-script-for-numpy2-issue/examples/notebooks/Simulated_uniform_molecules.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/QI2lab/merfish3d-analysis/blob/fix-install-script-for-numpy2-issue/examples/notebooks/Simulated_uniform_molecules.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Install `merfish3d-analysis`
This is a modified version of the library installation that allows it to run on Google Colab servers. It is missing the visualization tools and the ability to automatically stitch tiled data.  
  
**Note:** This installation can take 5-10 minutes, because all of the CUDA libraries have to be validated.
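
Because the library needs a GPU runtime, it can help to confirm that one is attached before starting the installation. The following is a minimal sketch, not part of `merfish3d-analysis`: the helper name `gpu_available` is hypothetical, and it assumes that a working `nvidia-smi` on the path is a good enough proxy for a usable GPU.

```python
import shutil
import subprocess


def gpu_available() -> bool:
    """Return True if `nvidia-smi` is on PATH and lists at least one GPU."""
    if shutil.which("nvidia-smi") is None:
        return False
    try:
        result = subprocess.run(
            ["nvidia-smi", "-L"],  # -L lists the attached GPUs, one per line
            capture_output=True,
            text=True,
            timeout=30,
        )
    except (OSError, subprocess.TimeoutExpired):
        return False
    return result.returncode == 0 and "GPU" in result.stdout


if gpu_available():
    print("GPU runtime detected.")
else:
    print("No NVIDIA GPU detected. In Colab, pick a GPU under Runtime > Change runtime type.")
```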

In [None]:
%%capture
!git clone https://github.com/qi2lab/merfish3d-analysis/
%cd merfish3d-analysis
!git checkout fix-install-script-for-numpy2-issue
!pip install -e .
!setup-colab

## Download simulation data
The simulation contains roughly 1200 individual RNA molecules that are randomly distributed in space within a 41.6 𝜇m by 41.6 𝜇m by 15.0 𝜇m volume (x,y,z). The individual RNA molecules are imaged using a 16-bit, Hamming weight 4, Hamming distance 4 codebook and a simulated microscope with realistic parameters and noise. The imaging simulation is performed in 8 rounds, each with 3 channels containing 2 MERFISH bits and a fiducial marker. Three simulations of the same RNA molecules are performed with three different axial step sizes (0.315 𝜇m, 1.0 𝜇m, 1.5 𝜇m) to explore the impact of sufficient axial sampling when imaging.
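
To make the codebook constraints concrete, the sketch below builds a set of 16-bit words that each have exactly 4 "on" bits and differ from every other word in at least 4 positions. It is only an illustration under stated assumptions: the greedy construction and the names `hamming_distance` and `greedy_codebook` are hypothetical, and the codebook shipped with the simulated data may have been built differently.

```python
from itertools import combinations

N_BITS = 16        # bits per codeword
WEIGHT = 4         # number of "on" bits per codeword (Hamming weight)
MIN_DISTANCE = 4   # minimum Hamming distance between any two codewords


def hamming_distance(a: int, b: int) -> int:
    """Number of bit positions in which two integer-encoded words differ."""
    return bin(a ^ b).count("1")


def greedy_codebook(n_bits=N_BITS, weight=WEIGHT, min_distance=MIN_DISTANCE):
    """Greedily keep every weight-`weight` word far enough from those already kept."""
    codebook = []
    for on_bits in combinations(range(n_bits), weight):
        word = sum(1 << b for b in on_bits)
        if all(hamming_distance(word, other) >= min_distance for other in codebook):
            codebook.append(word)
    return codebook


codebook = greedy_codebook()
print(f"{len(codebook)} codewords; first three:")
for word in codebook[:3]:
    print(f"{word:0{N_BITS}b}")
```

With a minimum distance of 4, a measured word with a single flipped bit is still closer to its true codeword than to any other, which is what makes single-bit error correction possible during decoding.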

In [None]:
%%capture
import zipfile
import os

%cd /content/
!gdown 1rRcV72oknYV2rL-etJzF1eRdWZJnKeuS

# Destination path for the unzipped content
unzip_destination = '/content/synthetic_data'

# Create the destination directory if it doesn't exist
os.makedirs(unzip_destination, exist_ok=True)

# Unzip the file
try:
    with zipfile.ZipFile("/content/synthetic_data.zip", 'r') as zip_ref:
        zip_ref.extractall(unzip_destination)
    print(f"File unzipped successfully to {unzip_destination}")
except zipfile.BadZipFile:
    print("Error: The downloaded file is not a valid zip file.")
except FileNotFoundError:
    print("Error: The file /content/synthetic_data.zip was not found.")
except Exception as e:
    print(f"An error occurred during unzipping: {e}")

## Test merfish3d-analysis on randomly distributed RNA with 𝚫z=0.315 𝞵m.
Because an axial spacing of 𝚫z=0.315 𝞵m satisfies Shannon-Nyquist sampling for the objective (NA=1.35), here we decode in 3D.
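
As a rough check on the sampling argument, the largest axial step that still satisfies Nyquist for a widefield microscope can be estimated from the axial extent of the optical transfer function, $k_{z,\max} = (n - \sqrt{n^2 - \mathrm{NA}^2})/\lambda$, giving $\Delta z \le 1/(2 k_{z,\max})$. The sketch below evaluates this. Only NA=1.35 comes from this notebook; the immersion index (1.51) and emission wavelength (0.6 𝞵m) are assumed values chosen for illustration, and the function name is hypothetical.

```python
import math


def axial_nyquist_step_um(wavelength_um: float, na: float, n_immersion: float) -> float:
    """Largest axial step (in microns) that satisfies Nyquist for a widefield OTF."""
    if na >= n_immersion:
        raise ValueError("NA must be smaller than the immersion refractive index")
    # Highest axial spatial frequency passed by the objective
    kz_max = (n_immersion - math.sqrt(n_immersion**2 - na**2)) / wavelength_um
    return 1.0 / (2.0 * kz_max)


# NA is from the notebook; wavelength and refractive index are ASSUMED values.
limit_um = axial_nyquist_step_um(wavelength_um=0.6, na=1.35, n_immersion=1.51)
for dz_um in (0.315, 1.0, 1.5):
    status = "meets" if dz_um <= limit_um else "exceeds"
    print(f"dz = {dz_um} um {status} the estimated limit of {limit_um:.3f} um")
```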
  
The steps are:  
1. Convert simulation data format to our (qi2lab) experimental format.
2. Convert qi2lab format to `merfish3d-analysis` datastore.
3. 3D deconvolution and 3D prediction of "spot-like" features in every bit.
4. 3D decoding to find RNA molecules and filter to limit blank codewords as necessary.
5. Calculate F1-score using ground truth RNA molecule locations.
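
The final step compares decoded molecules against the ground truth. The sketch below shows one way such a score could be computed: each decoded molecule is greedily matched to the nearest unused ground-truth molecule of the same identity within a search radius, and precision, recall, and F1 follow from the match counts. This is an assumption-laden illustration, not the implementation behind `sim-f1score`; the function name, the greedy matching, and the 1.0 𝞵m radius are all hypothetical.

```python
import numpy as np


def f1_from_matches(gt_zyx, gt_gene, found_zyx, found_gene, radius_um=1.0):
    """Greedy one-to-one matching of decoded to ground-truth molecules."""
    gt_zyx = np.asarray(gt_zyx, dtype=float).reshape(-1, 3)
    found_zyx = np.asarray(found_zyx, dtype=float).reshape(-1, 3)
    gt_gene = np.asarray(gt_gene)
    gt_used = np.zeros(len(gt_zyx), dtype=bool)

    true_pos = 0
    for pos, gene in zip(found_zyx, found_gene):
        # Candidates: unused ground-truth molecules with the same identity
        candidates = np.flatnonzero(~gt_used & (gt_gene == gene))
        if candidates.size == 0:
            continue
        dist = np.linalg.norm(gt_zyx[candidates] - pos, axis=1)
        best = int(np.argmin(dist))
        if dist[best] <= radius_um:
            gt_used[candidates[best]] = True
            true_pos += 1

    false_pos = len(found_zyx) - true_pos   # decoded, but no matching ground truth
    false_neg = len(gt_zyx) - true_pos      # ground truth that was never recovered
    precision = true_pos / (true_pos + false_pos) if len(found_zyx) else 0.0
    recall = true_pos / (true_pos + false_neg) if len(gt_zyx) else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return precision, recall, f1


# Toy example: two correct detections, one spurious one, one missed molecule.
gt_zyx = [(0, 0, 0), (0, 5, 5), (0, 10, 10)]
gt_gene = ["A", "B", "A"]
found_zyx = [(0, 0.2, 0), (0, 5, 5.1), (0, 20, 20)]
found_gene = ["A", "B", "C"]
precision, recall, f1 = f1_from_matches(gt_zyx, gt_gene, found_zyx, found_gene)
print(f"precision={precision:.3f} recall={recall:.3f} F1={f1:.3f}")
```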

In [None]:
!sim-convert "/content/synthetic_data/example_16bit_flat/0.315"
!sim-datastore "/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition"
!sim-preprocess "/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition"
!sim-decode "/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition"
!sim-f1score "/content/synthetic_data/example_16bit_flat/0.315"

## Display MERFISH data, ground truth RNA molecules, and decoded RNA molecules for 𝚫z=0.315 𝞵m.
  
For display purposes, we do not differentiate by RNA molecule identity here and simply plot all molecules. For further exploration, we suggest using [`FISHscale`](https://github.com/linnarsson-lab/FISHscale).

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import IntSlider, Checkbox, HBox, VBox, interactive_output, FloatText, Button, Layout
from IPython.display import display, clear_output
from merfish3danalysis.qi2labDataStore import qi2labDataStore

datastore = qi2labDataStore("/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition")
bit_images = []

for bit_idx in datastore.merfish_bits:
    bit_images.append(datastore.load_local_registered_image(tile=0,bit=bit_idx,return_future=False))
bit_images = np.asarray(bit_images,dtype=np.uint16)

SPOTS_CSV = "/content/synthetic_data/example_16bit_flat/0.315/GT_spots.csv"

# Optional: physical voxel sizes (dz, dy, dx). Leave as (1,1,1) if using pixel coords.
voxel_size = np.asarray(datastore.voxel_size_zyx_um,dtype=np.float32)

def compute_vmin_vmax(a, lo=1.0, hi=99.5):
    # Handle empty / constant arrays gracefully
    if not np.isfinite(a).any():
        return 0.0, 1.0
    p = np.percentile(a[np.isfinite(a)], (lo, hi))
    if p[0] == p[1]:
        return float(np.min(a)), float(np.max(a) if np.max(a) > np.min(a) else np.min(a)+1.0)
    return float(p[0]), float(p[1])

def precompute_projections(img_czyx: np.ndarray):
    # img shape: (C, Z, Y, X)
    C, Z, Y, X = img_czyx.shape
    proj_xy = []  # (Y, X) max over Z
    proj_xz = []  # (Z, X) max over Y
    proj_yz = []  # (Z, Y) max over X
    # Using np.max to keep simple; if your data are huge, consider np.nanmax or chunked compute.
    for c in range(C):
        vol = img_czyx[c]
        proj_xy.append(np.max(vol, axis=0))
        proj_xz.append(np.max(vol, axis=1))
        proj_yz.append(np.max(vol, axis=2))
    return proj_xy, proj_xz, proj_yz


def load_spots(csv_path):
    # Return None when no ground-truth file is available
    if csv_path is None or not os.path.exists(csv_path):
        return None
    df = pd.read_csv(csv_path)
    # Match column names case-insensitively by normalizing them to lowercase
    df = df.rename(columns={c: str(c).lower() for c in df.columns})
    required = {"z", "y", "x"}
    if not required.issubset(df.columns):
        raise ValueError("CSV must have columns: z, y, x")
    # Return columns in (z, y, x) order, which the viewer indexes as 0, 1, 2
    return df[["z", "y", "x"]].copy()

def show_czyx_with_spots(img_czyx: np.ndarray,
                         spots_df: pd.DataFrame | None = None,
                         voxel_size=(1.0, 1.0, 1.0)):
    C, Z, Y, X = img_czyx.shape
    dz, dy, dx = voxel_size

    # Precompute projections for responsiveness
    proj_xy, proj_xz, proj_yz = precompute_projections(img_czyx)

    # Pre-filter spots to be in-bounds and pre-scale for extents
    if spots_df is not None and len(spots_df):
        # Keep only points within volume bounds
        in_bounds = (
            (spots_df["z"] >= 0) & (spots_df["z"] < Z) &
            (spots_df["y"] >= 0) & (spots_df["y"] < Y) &
            (spots_df["x"] >= 0) & (spots_df["x"] < X)
        )
        spots = spots_df.loc[in_bounds].to_numpy(dtype=float)
        # Shift y and x by half the field of view, in the same physical units as voxel_size
        offset = [0, Y / 2 * dy, X / 2 * dx]
        spots = spots + offset
        sx = spots[:, 2]
        sy = spots[:, 1]
        sz = spots[:, 0]
    else:
        spots = None

    # Figure & axes layout
    fig, axes = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)
    ax_xy, ax_xz, ax_yz = axes

    # Extents so scatters align with pixels (origin lower-left)
    extent_xy = [0, X * dx, 0, Y * dy]   # x across, y up
    extent_xz = [0, X * dx, 0, Z * dz]   # x across, z up
    extent_yz = [0, Y * dy, 0, Z * dz]   # y across, z up

    # Initial channel
    c0 = 0
    im_xy = ax_xy.imshow(proj_xy[c0], origin='lower', extent=extent_xy, interpolation='nearest')
    im_xz = ax_xz.imshow(proj_xz[c0], origin='lower', extent=extent_xz, interpolation='nearest')
    im_yz = ax_yz.imshow(proj_yz[c0], origin='lower', extent=extent_yz, interpolation='nearest')

    # Titles & labels
    ax_xy.set_title("XY (max over Z)")
    ax_xz.set_title("XZ (max over Y)")
    ax_yz.set_title("YZ (max over X)")
    ax_xy.set_xlabel("X")
    ax_xy.set_ylabel("Y")
    ax_xz.set_xlabel("X")
    ax_xz.set_ylabel("Z")
    ax_yz.set_xlabel("Y")
    ax_yz.set_ylabel("Z")

    # Robust contrast per view, per channel
    def apply_contrast(ci):
        vmin, vmax = compute_vmin_vmax(proj_xy[ci])
        im_xy.set_clim(vmin, vmax)
        vmin, vmax = compute_vmin_vmax(proj_xz[ci])
        im_xz.set_clim(vmin, vmax)
        vmin, vmax = compute_vmin_vmax(proj_yz[ci])
        im_yz.set_clim(vmin, vmax)

    apply_contrast(c0)

    # Spot overlays
    scat_xy = scat_xz = scat_yz = None
    if spots is not None and len(spots):
        scat_xy = ax_xy.scatter(sx, sy, s=10, marker='o', linewidths=0, alpha=0.6)
        scat_xz = ax_xz.scatter(sx, sz, s=10, marker='o', linewidths=0, alpha=0.6)
        scat_yz = ax_yz.scatter(sy, sz, s=10, marker='o', linewidths=0, alpha=0.6)

    # Widgets
    slider = IntSlider(description="Channel", min=0, max=C-1, step=1, value=0, continuous_update=False)
    toggle_spots = Checkbox(description="Show spots", value=(spots is not None and len(spots) > 0))
    dz_box = FloatText(description="dz", value=dz, layout=Layout(width='120px'))
    dy_box = FloatText(description="dy", value=dy, layout=Layout(width='120px'))
    dx_box = FloatText(description="dx", value=dx, layout=Layout(width='120px'))
    apply_vox_btn = Button(description="Apply voxel size", layout=Layout(width='160px'))

    def on_apply_vox_clicked(_btn):
        # Update extents according to new voxel sizes
        new_dz, new_dy, new_dx = float(dz_box.value), float(dy_box.value), float(dx_box.value)
        im_xy.set_extent([0, X * new_dx, 0, Y * new_dy])
        im_xz.set_extent([0, X * new_dx, 0, Z * new_dz])
        im_yz.set_extent([0, Y * new_dy, 0, Z * new_dz])
        # Update scatter coords (scaled) if present
        if spots is not None and len(spots):
            sx2, sy2, sz2 = spots[:, 2] * new_dx, spots[:, 1] * new_dy, spots[:, 0] * new_dz
            if scat_xy: 
                scat_xy.set_offsets(np.c_[sx2, sy2])
            if scat_xz: 
                scat_xz.set_offsets(np.c_[sx2, sz2])
            if scat_yz: 
                scat_yz.set_offsets(np.c_[sy2, sz2])
        fig.canvas.draw_idle()

    apply_vox_btn.on_click(on_apply_vox_clicked)

    def update(ci, show_spots):
        im_xy.set_data(proj_xy[ci])
        im_xz.set_data(proj_xz[ci])
        im_yz.set_data(proj_yz[ci])
        apply_contrast(ci)
        # toggle visibility of scatters
        if scat_xy: scat_xy.set_visible(show_spots)
        if scat_xz: scat_xz.set_visible(show_spots)
        if scat_yz: scat_yz.set_visible(show_spots)
        fig.canvas.draw_idle()

    out = interactive_output(update, {"ci": slider, "show_spots": toggle_spots})

    controls = HBox([slider, toggle_spots, dx_box, dy_box, dz_box, apply_vox_btn])
    display(VBox([controls, out]))
    return fig

spots_df = load_spots(SPOTS_CSV)

# Show the viewer
fig = show_czyx_with_spots(bit_images, spots_df=spots_df, voxel_size=voxel_size)
plt.show()


## Test merfish3d-analysis on randomly distributed RNA with 𝚫z=1.0 𝞵m.
Because an axial spacing of 𝚫z=1.0 𝞵m is coarser than Shannon-Nyquist sampling for the objective (NA=1.35), here we decode plane-by-plane and then collapse spots in adjacent z planes.
  
The steps are:  
1. Convert simulation data format to our (qi2lab) experimental format.
2. Convert qi2lab format to `merfish3d-analysis` datastore.
3. 2D deconvolution and 2D prediction of "spot-like" features plane-by-plane in every bit.
4. 2D decoding to find RNA molecules plane-by-plane, then collapse identical molecules in adjacent z-planes into one decoded molecule, and filter to limit blank codewords as necessary.
5. Calculate F1-score using ground truth RNA molecule locations.
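
The collapsing step can be pictured as a clustering problem: detections of the same RNA identity that sit in neighboring z planes and nearly the same lateral position are treated as one molecule. The sketch below illustrates that idea with a small union-find. It is a hedged example rather than the library's actual procedure; the function name, the tuple layout `(gene, z_index, y, x)`, and the 0.5 lateral radius are assumptions.

```python
import math
from collections import defaultdict


def collapse_adjacent_planes(spots, xy_radius=0.5):
    """Merge same-gene spots in neighboring z planes that are close in y/x.

    spots: list of (gene, z_index, y, x) tuples.
    Returns one (gene, mean_z, mean_y, mean_x) tuple per merged molecule.
    """
    parent = list(range(len(spots)))

    def find(i):
        # Follow parent links to the cluster root, shortening the path as we go
        while parent[i] != i:
            parent[i] = parent[parent[i]]
            i = parent[i]
        return i

    for i, (gene_i, z_i, y_i, x_i) in enumerate(spots):
        for j in range(i + 1, len(spots)):
            gene_j, z_j, y_j, x_j = spots[j]
            same_gene = gene_i == gene_j
            neighbors_in_z = abs(z_i - z_j) == 1
            close_in_xy = math.hypot(y_i - y_j, x_i - x_j) <= xy_radius
            if same_gene and neighbors_in_z and close_in_xy:
                parent[find(i)] = find(j)

    clusters = defaultdict(list)
    for i, spot in enumerate(spots):
        clusters[find(i)].append(spot)

    molecules = []
    for members in clusters.values():
        n = len(members)
        molecules.append((
            members[0][0],
            sum(m[1] for m in members) / n,
            sum(m[2] for m in members) / n,
            sum(m[3] for m in members) / n,
        ))
    return molecules


# Toy example: gene "A" seen in planes 3, 4, and 5 at nearly the same position.
spots = [
    ("A", 3, 10.0, 10.0),
    ("A", 4, 10.1, 10.0),
    ("A", 5, 10.0, 10.1),
    ("B", 4, 10.0, 10.0),   # different identity, kept separate
    ("A", 4, 30.0, 30.0),   # same identity but far away, kept separate
]
molecules = collapse_adjacent_planes(spots)
print(len(molecules), "molecules after collapsing")
```

Because merging is transitive, a molecule detected across a run of consecutive planes ends up as a single entry even though its first and last planes are not direct neighbors.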


In [None]:
!sim-convert "/content/synthetic_data/example_16bit_flat/1.0"
!sim-datastore "/content/synthetic_data/example_16bit_flat/1.0/sim_acquisition"
!sim-preprocess "/content/synthetic_data/example_16bit_flat/1.0/sim_acquisition"
!sim-decode "/content/synthetic_data/example_16bit_flat/1.0/sim_acquisition"
!sim-f1score "/content/synthetic_data/example_16bit_flat/1.0"

## Test merfish3d-analysis on randomly distributed RNA with 𝚫z=1.5 𝞵m.
Because an axial spacing of 𝚫z=1.5 𝞵m is coarser than Shannon-Nyquist sampling for the objective (NA=1.35), here we decode plane-by-plane and then collapse spots in adjacent z planes.
  
The steps are:  
1. Convert simulation data format to our (qi2lab) experimental format.
2. Convert qi2lab format to `merfish3d-analysis` datastore.
3. 2D deconvolution and 2D prediction of "spot-like" features plane-by-plane in every bit.
4. 2D decoding to find RNA molecules plane-by-plane, then collapse identical molecules in adjacent z-planes into one decoded molecule, and filter to limit blank codewords as necessary.
5. Calculate F1-score using ground truth RNA molecule locations.

In [None]:
!sim-convert "/content/synthetic_data/example_16bit_flat/1.5"
!sim-datastore "/content/synthetic_data/example_16bit_flat/1.5/sim_acquisition"
!sim-preprocess "/content/synthetic_data/example_16bit_flat/1.5/sim_acquisition"
!sim-decode "/content/synthetic_data/example_16bit_flat/1.5/sim_acquisition"
!sim-f1score "/content/synthetic_data/example_16bit_flat/1.5"