# Test `merfish3d-analysis` on simulated MERFISH data
The goal of this notebook is to show the performance of [`merfish3d-analysis`](https://github.com/QI2lab/merfish3d-analysis) on simulated MERFISH data. The output metric is the [F1-score](https://en.wikipedia.org/wiki/F-score), which measures how well `merfish3d-analysis` recovers the ground truth location and identity of the RNA molecules used to generate the simulation. We use a single field of view (FOV) with uniformly distributed RNA molecules. **Note:** `merfish3d-analysis` requires a GPU runtime and will not run without one.
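For reference, the F1-score is the harmonic mean of precision and recall. Here TP counts decoded molecules matched to a ground truth molecule, FP counts decoded molecules with no ground truth match, and FN counts ground truth molecules that were not recovered. The exact matching rule used by `sim-f1score` is not described in this notebook.

$$
\mathrm{precision} = \frac{TP}{TP + FP}, \qquad \mathrm{recall} = \frac{TP}{TP + FN}, \qquad F_1 = \frac{2 \cdot \mathrm{precision} \cdot \mathrm{recall}}{\mathrm{precision} + \mathrm{recall}}
$$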


<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/QI2lab/merfish3d-analysis/blob/fix-install-script-for-numpy2-issue/examples/notebooks/Simulated_random_molecules.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/QI2lab/merfish3d-analysis/blob/fix-install-script-for-numpy2-issue/examples/notebooks/Simulated_random_molecules.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Install `merfish3d-analysis`
This is a modified version of the library installation that allows it to run on Google Colab servers. It omits the visualization tools and the ability to automatically stitch tiled data.  
  
**Note:** This installation can take 5-10 minutes, because all of the CUDA libraries have to be validated.

In [None]:
%%capture
!git clone https://github.com/qi2lab/merfish3d-analysis/
%cd merfish3d-analysis
!git checkout fix-install-script-for-numpy2-issue
!pip install -e .
!setup-colab

## Download simulation data
The simulation contains roughly 1200 individual RNA molecules randomly distributed within a 41.6 𝜇m × 41.6 𝜇m × 15.0 𝜇m volume (x, y, z). The molecules are imaged using a 16-bit, Hamming weight 4, Hamming distance 4 codebook and a simulated microscope with realistic parameters and noise. The imaging simulation has 8 rounds, each with 3 channels: 2 MERFISH bits and 1 fiducial marker. The same RNA molecules are simulated at three different axial step sizes (0.315 𝜇m, 1.0 𝜇m, and 1.5 𝜇m) to explore the impact of sufficient axial sampling on imaging.
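To illustrate what a 16-bit, weight-4, distance-4 codebook means, the sketch below greedily builds one. Every codeword has exactly four "on" bits, and any two codewords differ in at least four bit positions, so a single bit error lands at distance 1 from the true codeword and at least 3 from any other. This is **not** the codebook used in the simulation; `greedy_codebook` and the bit-to-round assignment (bit `b` imaged in round `b // 2`) are hypothetical.

```python
from itertools import combinations

import numpy as np


def greedy_codebook(n_bits=16, weight=4, min_distance=4):
    """Greedily collect constant-weight binary codewords with a minimum Hamming distance.

    Candidates are visited in lexicographic order of their 'on' bit positions;
    a candidate is kept only if it is at least min_distance from every kept word.
    """
    kept = []
    for on_bits in combinations(range(n_bits), weight):
        candidate = set(on_bits)
        # Two words of equal weight w sharing k 'on' bits differ in 2 * (w - k) positions.
        if all(2 * (weight - len(candidate & k)) >= min_distance for k in kept):
            kept.append(candidate)

    # Convert the kept sets of 'on' positions into a (n_codewords, n_bits) 0/1 matrix.
    codebook = np.zeros((len(kept), n_bits), dtype=np.uint8)
    for row, on in enumerate(kept):
        codebook[row, sorted(on)] = 1
    return codebook


codebook = greedy_codebook()

# Hypothetical acquisition order: 2 MERFISH bits per round, so bit b -> round b // 2.
bit_to_round = {b: b // 2 for b in range(16)}
print(codebook.shape, "rounds used:", sorted(set(bit_to_round.values())))
```

A real MERFISH codebook is typically chosen once and fixed for an experiment; the greedy search here only demonstrates the weight and distance constraints.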

In [None]:
%%capture
import zipfile
import os

%cd /content/
!gdown 1rRcV72oknYV2rL-etJzF1eRdWZJnKeuS

# Destination path for the unzipped content
unzip_destination = '/content/synthetic_data'

# Create the destination directory if it doesn't exist
os.makedirs(unzip_destination, exist_ok=True)

# Unzip the file
try:
    with zipfile.ZipFile("/content/synthetic_data.zip", 'r') as zip_ref:
        zip_ref.extractall(unzip_destination)
    print(f"File unzipped successfully to {unzip_destination}")
except zipfile.BadZipFile:
    print("Error: The downloaded file is not a valid zip file.")
except FileNotFoundError:
    print("Error: The file /content/synthetic_data.zip was not found.")
except Exception as e:
    print(f"An error occurred during unzipping: {e}")

## Test merfish3d-analysis on randomly distributed RNA with 𝚫z=0.315 𝞵m.
Because an axial spacing of 𝚫z=0.315 𝞵m is Shannon-Nyquist sampled for the objective (NA=1.35), here we decode in 3D. 
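To make the sampling argument concrete, the sketch below computes the largest axial step that Nyquist-samples a widefield optical transfer function, $\Delta z_{\mathrm{Nyq}} = \lambda / (2n(1-\cos\alpha))$ with $\sin\alpha = \mathrm{NA}/n$. The emission wavelength (0.52 µm) and immersion index (n = 1.51) are assumptions, not values read from the simulation, so the exact threshold shifts with them. Under these assumptions, 0.315 µm sits at the Nyquist step, while 1.0 µm and 1.5 µm are roughly 3 to 5 times coarser.

```python
import math


def axial_nyquist_step_um(wavelength_um, na, n_immersion):
    """Largest axial step that Nyquist-samples a widefield OTF.

    The widefield OTF has axial support up to n * (1 - cos(alpha)) / lambda,
    with sin(alpha) = NA / n, so Nyquist requires dz <= lambda / (2 n (1 - cos(alpha))).
    """
    if na >= n_immersion:
        raise ValueError("NA must be smaller than the immersion refractive index")
    cos_alpha = math.sqrt(1.0 - (na / n_immersion) ** 2)
    return wavelength_um / (2.0 * n_immersion * (1.0 - cos_alpha))


# Assumed values (not taken from the simulation metadata): emission ~0.52 um,
# oil immersion with n = 1.51, and the NA = 1.35 quoted above.
dz_nyquist = axial_nyquist_step_um(wavelength_um=0.52, na=1.35, n_immersion=1.51)
print(f"Nyquist axial step ~ {dz_nyquist:.3f} um")
for dz in (0.315, 1.0, 1.5):
    print(f"dz = {dz:.3f} um -> {dz / dz_nyquist:.1f}x the Nyquist step")
```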
  
The steps are:  
1. Convert simulation data format to our (qi2lab) experimental format.
2. Convert qi2lab format to `merfish3d-analysis` datastore.
3. 3D deconvolution and 3D prediction of "spot-like" features in every bit.
4. 3D decoding to find RNA molecules and filter to limit blank codewords as necessary.
5. Calculate F1-score using ground truth RNA molecule locations.
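Step 5 relies on matching decoded molecules to ground truth molecules. The criterion used by `sim-f1score` is not described here, so the sketch below is a hypothetical version: it greedily pairs each decoded molecule with the closest unmatched ground truth molecule of the same gene within an assumed tolerance (`max_dist_um`), then computes precision, recall, and F1. The function name `match_f1` and the default tolerance are illustrative only.

```python
import numpy as np


def match_f1(gt_zyx_um, gt_genes, pred_zyx_um, pred_genes, max_dist_um=1.0):
    """Greedy one-to-one matching of decoded spots to ground truth.

    A prediction counts as a true positive if it has the same gene label as an
    unmatched ground truth molecule within max_dist_um (Euclidean, in microns).
    Returns (precision, recall, f1).
    """
    gt = np.asarray(gt_zyx_um, dtype=float).reshape(-1, 3)
    pred = np.asarray(pred_zyx_um, dtype=float).reshape(-1, 3)
    gt_genes = np.asarray(gt_genes)
    pred_genes = np.asarray(pred_genes)

    # Pairwise distances (n_pred, n_gt); pairs with different genes are made unmatchable.
    dist = np.linalg.norm(pred[:, None, :] - gt[None, :, :], axis=-1)
    dist[pred_genes[:, None] != gt_genes[None, :]] = np.inf

    # Visit candidate pairs from closest to farthest and accept a pair only if
    # both the prediction and the ground truth molecule are still unmatched.
    pred_used = np.zeros(len(pred), dtype=bool)
    gt_used = np.zeros(len(gt), dtype=bool)
    tp = 0
    for flat in np.argsort(dist, axis=None):
        i, j = np.unravel_index(flat, dist.shape)
        if dist[i, j] > max_dist_um:
            break  # all remaining pairs are farther apart
        if not pred_used[i] and not gt_used[j]:
            pred_used[i] = gt_used[j] = True
            tp += 1

    fp = len(pred) - tp
    fn = len(gt) - tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


# One correct detection, one spurious detection, one missed molecule.
print(match_f1([[0, 0, 0], [0, 10, 10]], ["A", "B"],
               [[0, 0.2, 0], [5, 5, 5]], ["A", "B"]))  # -> (0.5, 0.5, 0.5)
```

Greedy nearest-first matching is simple and usually adequate for sparse molecules; an optimal assignment (for example, the Hungarian algorithm) can give slightly different counts in dense regions.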

In [None]:
!sim-convert "/content/synthetic_data/example_16bit_flat/0.315"
!sim-datastore "/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition"
!sim-preprocess "/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition"
!sim-decode "/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition"
!sim-f1score "/content/synthetic_data/example_16bit_flat/0.315"

## Display MERFISH data and ground truth RNA molecules for 𝚫z=0.315 𝞵m.
  
For display purposes, we do not differentiate between RNA molecule identities and simply plot all ground truth molecules. For further exploration, we suggest using [`FISHscale`](https://github.com/linnarsson-lab/FISHscale).

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from ipywidgets import IntSlider, Checkbox, HBox, VBox, interactive_output, FloatText, Button, Layout, HTML
from IPython.display import display
from merfish3danalysis.qi2labDataStore import qi2labDataStore

# ------------------------------
# 0) Load imagery and metadata
# ------------------------------
datastore = qi2labDataStore("/content/synthetic_data/example_16bit_flat/0.315/sim_acquisition")

bit_images_list = []
for bit_idx in datastore.merfish_bits:
    vol = datastore.load_local_registered_image(tile=0, bit=bit_idx, return_future=False)
    bit_images_list.append(vol)

bit_images = np.asarray(bit_images_list, dtype=np.uint16)
voxel_size_zyx_um = np.asarray(datastore.voxel_size_zyx_um, dtype=np.float32)
tile_offset_zyx_um = np.asarray(datastore.load_local_stage_position_zyx_um(tile=0, round=0, return_future=False), dtype=np.float32)

spots_csv_path = "/content/synthetic_data/example_16bit_flat/0.315/GT_spots.csv"

# ------------------------------
# 1) Utilities
# ------------------------------
def compute_vmin_vmax(a, lo=1.0, hi=99.5):
    if a is None:
        return 0.0, 1.0

    finite = np.isfinite(a)
    if not finite.any():
        return 0.0, 1.0

    vals = a[finite]
    p_lo, p_hi = np.percentile(vals, (lo, hi))
    if p_lo == p_hi:
        vmin = float(np.min(vals))
        vmax = float(np.max(vals))
        if vmax <= vmin:
            vmax = vmin + 1.0
        return vmin, vmax

    return float(p_lo), float(p_hi)

def precompute_projections(img_czyx):
    # img_czyx shape: (C, Z, Y, X)
    C, Z, Y, X = img_czyx.shape

    proj_xy = []
    proj_xz = []
    proj_yz = []

    for c in range(C):
        vol = img_czyx[c]
        proj_xy.append(np.max(vol, axis=0))  # (Y, X)
        proj_xz.append(np.max(vol, axis=1))  # (Z, X)
        proj_yz.append(np.max(vol, axis=2))  # (Z, Y)

    return proj_xy, proj_xz, proj_yz

def load_spots_with_offset(csv_path, voxel_size_zyx_um, image_zyx_shape):
    if csv_path is None:
        return None

    if not os.path.exists(csv_path):
        return None

    df = pd.read_csv(csv_path)

    lower = {c.lower(): c for c in df.columns}
    required = {"z", "y", "x"}
    if not required.issubset(lower.keys()):
        raise ValueError("CSV must contain columns named z, y, x (microns).")

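    # Note: the CSV's x and y columns are swapped here, presumably because the
    # simulator's ground truth x/y convention is transposed relative to the
    # image (Y, X) axes.
    df = df.rename(columns={
        lower["z"]: "z",
        lower["x"]: "y",
        lower["y"]: "x"
    })

    # The ground truth x/y coordinates appear to be centered on the FOV, so shift
    # them by half the image extent: offset = [0, Y/2 * dy, X/2 * dx] in microns.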
    Z = int(image_zyx_shape[0])
    Y = int(image_zyx_shape[1])
    X = int(image_zyx_shape[2])

    dz = float(voxel_size_zyx_um[0])
    dy = float(voxel_size_zyx_um[1])
    dx = float(voxel_size_zyx_um[2])

    off_z = 0.0
    off_y = 0.5 * Y * dy
    off_x = 0.5 * X * dx

    df["z"] = df["z"] + off_z
    df["y"] = df["y"] + off_y
    df["x"] = df["x"] + off_x

    return df[["z", "y", "x"]].copy()

def world_extents_2d(Z, Y, X, dz, dy, dx, tz, ty, tx):
    # Returns extents in world microns for each projection.
    extent_xy = [tx, tx + X * dx, ty, ty + Y * dy]
    extent_xz = [tx, tx + X * dx, tz, tz + Z * dz]
    extent_yz = [ty, ty + Y * dy, tz, tz + Z * dz]
    return extent_xy, extent_xz, extent_yz

# ------------------------------
# 2) Interactive viewer
# ------------------------------
def show_czyx_with_spots(img_czyx, spots_df, voxel_size_zyx_um, translate_zyx_um):
    C, Z, Y, X = img_czyx.shape

    dz = float(voxel_size_zyx_um[0])
    dy = float(voxel_size_zyx_um[1])
    dx = float(voxel_size_zyx_um[2])

    tz = float(translate_zyx_um[0])
    ty = float(translate_zyx_um[1])
    tx = float(translate_zyx_um[2])

    proj_xy, proj_xz, proj_yz = precompute_projections(img_czyx)

    extent_xy, extent_xz, extent_yz = world_extents_2d(Z, Y, X, dz, dy, dx, tz, ty, tx)

    sz = None
    sy = None
    sx = None
    if spots_df is not None and len(spots_df) > 0:
        xmin = extent_xy[0]
        xmax = extent_xy[1]
        ymin = extent_xy[2]
        ymax = extent_xy[3]
        zmin = extent_xz[2]
        zmax = extent_xz[3]

        in_bounds = (
            (spots_df["z"] >= zmin) &
            (spots_df["z"] <  zmax) &
            (spots_df["y"] >= ymin) &
            (spots_df["y"] <  ymax) &
            (spots_df["x"] >= xmin) &
            (spots_df["x"] <  xmax)
        )
        pts = spots_df.loc[in_bounds].to_numpy(dtype=float)
        sz = pts[:, 0]
        sy = pts[:, 1]
        sx = pts[:, 2]

    fig, axes = plt.subplots(1, 3, figsize=(14, 4), constrained_layout=True)
    ax_xy, ax_xz, ax_yz = axes

    c0 = 0
    im_xy = ax_xy.imshow(proj_xy[c0], origin="lower", extent=extent_xy, interpolation="nearest")
    im_xz = ax_xz.imshow(proj_xz[c0], origin="lower", extent=extent_xz, interpolation="nearest")
    im_yz = ax_yz.imshow(proj_yz[c0], origin="lower", extent=extent_yz, interpolation="nearest")

    ax_xy.set_title("XY (max over Z)")
    ax_xz.set_title("XZ (max over Y)")
    ax_yz.set_title("YZ (max over X)")

    ax_xy.set_xlabel("X (µm)")
    ax_xy.set_ylabel("Y (µm)")
    ax_xz.set_xlabel("X (µm)")
    ax_xz.set_ylabel("Z (µm)")
    ax_yz.set_xlabel("Y (µm)")
    ax_yz.set_ylabel("Z (µm)")

    ax_xy.set_aspect("equal", adjustable="box")
    ax_xz.set_aspect("equal", adjustable="box")
    ax_yz.set_aspect("equal", adjustable="box")

    def apply_contrast(ci):
        vmin_xy, vmax_xy = compute_vmin_vmax(proj_xy[ci])
        im_xy.set_clim(vmin_xy, vmax_xy)

        vmin_xz, vmax_xz = compute_vmin_vmax(proj_xz[ci])
        im_xz.set_clim(vmin_xz, vmax_xz)

        vmin_yz, vmax_yz = compute_vmin_vmax(proj_yz[ci])
        im_yz.set_clim(vmin_yz, vmax_yz)

    apply_contrast(c0)

    scat_xy = None
    scat_xz = None
    scat_yz = None
    if sx is not None:
        scat_xy = ax_xy.scatter(sx, sy, s=10, marker="o", linewidths=0, alpha=0.6)
        scat_xz = ax_xz.scatter(sx, sz, s=10, marker="o", linewidths=0, alpha=0.6)
        scat_yz = ax_yz.scatter(sy, sz, s=10, marker="o", linewidths=0, alpha=0.6)

    slider = IntSlider(description="Channel", min=0, max=C - 1, step=1, value=0, continuous_update=False)
    toggle_spots = Checkbox(description="Show spots", value=(sx is not None))

    dz_box = FloatText(description="dz (µm/px)", value=dz, layout=Layout(width="140px"))
    dy_box = FloatText(description="dy (µm/px)", value=dy, layout=Layout(width="140px"))
    dx_box = FloatText(description="dx (µm/px)", value=dx, layout=Layout(width="140px"))

    tz_box = FloatText(description="tz (µm)", value=tz, layout=Layout(width="120px"))
    ty_box = FloatText(description="ty (µm)", value=ty, layout=Layout(width="120px"))
    tx_box = FloatText(description="tx (µm)", value=tx, layout=Layout(width="120px"))

    apply_btn = Button(description="Apply scale/translate", layout=Layout(width="200px"))
    tip = HTML("<b>World mapping:</b> world = translate + scale * index (napari-style).")

    def on_apply_clicked(_btn):
        new_dz = float(dz_box.value)
        new_dy = float(dy_box.value)
        new_dx = float(dx_box.value)

        new_tz = float(tz_box.value)
        new_ty = float(ty_box.value)
        new_tx = float(tx_box.value)

        ex_xy, ex_xz, ex_yz = world_extents_2d(Z, Y, X, new_dz, new_dy, new_dx, new_tz, new_ty, new_tx)
        im_xy.set_extent(ex_xy)
        im_xz.set_extent(ex_xz)
        im_yz.set_extent(ex_yz)

        ax_xy.set_aspect("equal", adjustable="box")
        ax_xz.set_aspect("equal", adjustable="box")
        ax_yz.set_aspect("equal", adjustable="box")

        fig.canvas.draw_idle()

    def update(ci, show_spots):
        im_xy.set_data(proj_xy[ci])
        im_xz.set_data(proj_xz[ci])
        im_yz.set_data(proj_yz[ci])

        apply_contrast(ci)

        if scat_xy is not None:
            scat_xy.set_visible(show_spots)
        if scat_xz is not None:
            scat_xz.set_visible(show_spots)
        if scat_yz is not None:
            scat_yz.set_visible(show_spots)

        fig.canvas.draw_idle()

    apply_btn.on_click(on_apply_clicked)

    out = interactive_output(update, {"ci": slider, "show_spots": toggle_spots})

    controls_row1 = HBox([slider, toggle_spots])
    controls_row2 = HBox([dx_box, dy_box, dz_box, tx_box, ty_box, tz_box, apply_btn])
    display(VBox([tip, controls_row1, controls_row2, out]))

    return fig

spots_df = load_spots_with_offset(
    csv_path=spots_csv_path,
    voxel_size_zyx_um=voxel_size_zyx_um,
    image_zyx_shape=bit_images[0,:].shape
)

fig = show_czyx_with_spots(
    img_czyx=bit_images,
    spots_df=spots_df,
    voxel_size_zyx_um=voxel_size_zyx_um,
    translate_zyx_um=tile_offset_zyx_um
)

plt.show()

## Test merfish3d-analysis on randomly distributed RNA with 𝚫z=1.0 𝞵m.
Because an axial spacing of 𝚫z=1.0 𝞵m is coarser than Shannon-Nyquist sampling for the objective (NA=1.35), here we decode plane-by-plane and then collapse spots in adjacent z-planes.
  
The steps are:  
1. Convert simulation data format to our (qi2lab) experimental format.
2. Convert qi2lab format to `merfish3d-analysis` datastore.
3. 2D deconvolution and 2D prediction of "spot-like" features plane-by-plane in every bit.
4. 2D decoding to find RNA molecules plane-by-plane, then collapse identical molecules in adjacent z-planes into one decoded molecule, and filter to limit blank codewords as necessary.
5. Calculate F1-score using ground truth RNA molecule locations.
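The collapse of duplicate detections across adjacent z-planes in step 4 can be illustrated with a small union-find sketch. This is not the `merfish3d-analysis` implementation; the function name `collapse_adjacent_planes`, the in-plane tolerance `max_xy_dist_um`, and the centroid averaging are assumptions for illustration. Detections of the same gene in neighboring planes that lie within the tolerance are linked, and each connected chain is reported as one molecule.

```python
import numpy as np


def collapse_adjacent_planes(z_idx, yx_um, gene_ids, max_xy_dist_um=0.5):
    """Merge same-gene detections in adjacent z-planes into single molecules.

    Hypothetical sketch: z_idx are integer plane indices, yx_um is an (N, 2)
    array of in-plane positions in microns, gene_ids is an (N,) array of labels.
    Returns (merged_z, merged_yx, merged_genes), one entry per molecule.
    """
    z_idx = np.asarray(z_idx, dtype=int)
    yx_um = np.asarray(yx_um, dtype=float)
    gene_ids = np.asarray(gene_ids)
    n = len(z_idx)
    parent = list(range(n))  # union-find forest, one node per detection

    def find(i):
        while parent[i] != i:
            parent[i] = parent[parent[i]]  # path halving
            i = parent[i]
        return i

    # Link detections of the same gene in neighboring planes that are close in y/x.
    for i in range(n):
        for j in range(i + 1, n):
            if gene_ids[i] != gene_ids[j] or abs(z_idx[i] - z_idx[j]) != 1:
                continue
            if np.linalg.norm(yx_um[i] - yx_um[j]) <= max_xy_dist_um:
                parent[find(i)] = find(j)

    # Each connected component becomes one molecule at its centroid.
    roots = np.array([find(i) for i in range(n)])
    merged_z, merged_yx, merged_genes = [], [], []
    for root in np.unique(roots):
        members = roots == root
        merged_z.append(z_idx[members].mean())  # centroid plane, may be fractional
        merged_yx.append(yx_um[members].mean(axis=0))
        merged_genes.append(gene_ids[members][0])
    return np.array(merged_z), np.array(merged_yx), np.array(merged_genes)


# A molecule seen in planes 3, 4, 5 plus an unrelated detection far away in y/x.
mz, myx, mg = collapse_adjacent_planes(
    [3, 4, 5, 4], [[1.0, 1.0], [1.1, 1.0], [1.0, 1.1], [8.0, 8.0]], ["A", "A", "A", "A"]
)
print(len(mz), "molecules")
```

Chaining through adjacent planes lets a molecule that spans several planes collapse to one entry, while the requirement of matching gene labels keeps nearby molecules of different genes separate.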


In [None]:
!sim-convert "/content/synthetic_data/example_16bit_flat/1.0"
!sim-datastore "/content/synthetic_data/example_16bit_flat/1.0/sim_acquisition"
!sim-preprocess "/content/synthetic_data/example_16bit_flat/1.0/sim_acquisition"
!sim-decode "/content/synthetic_data/example_16bit_flat/1.0/sim_acquisition"
!sim-f1score "/content/synthetic_data/example_16bit_flat/1.0"

## Test merfish3d-analysis on randomly distributed RNA with 𝚫z=1.5 𝞵m.
Because an axial spacing of 𝚫z=1.5 𝞵m is coarser than Shannon-Nyquist sampling for the objective (NA=1.35), here we decode plane-by-plane and then collapse spots in adjacent z-planes.
  
The steps are:  
1. Convert simulation data format to our (qi2lab) experimental format.
2. Convert qi2lab format to `merfish3d-analysis` datastore.
3. 2D deconvolution and 2D prediction of "spot-like" features plane-by-plane in every bit.
4. 2D decoding to find RNA molecules plane-by-plane, then collapse identical molecules in adjacent z-planes into one decoded molecule, and filter to limit blank codewords as necessary.
5. Calculate F1-score using ground truth RNA molecule locations.

In [None]:
!sim-convert "/content/synthetic_data/example_16bit_flat/1.5"
!sim-datastore "/content/synthetic_data/example_16bit_flat/1.5/sim_acquisition"
!sim-preprocess "/content/synthetic_data/example_16bit_flat/1.5/sim_acquisition"
!sim-decode "/content/synthetic_data/example_16bit_flat/1.5/sim_acquisition"
!sim-f1score "/content/synthetic_data/example_16bit_flat/1.5"