In [None]:
import io
import os
from datetime import datetime
import re

import dash
from dash import dcc, html, Input, Output, State
from dash.exceptions import PreventUpdate
import datashader as ds
import dask.dataframe as dd
import ipywidgets as widgets
from matplotlib.path import Path as MplPath
import numpy as np
import pandas as pd
import plotly.express as px
from IPython.display import clear_output, display
from scipy.ndimage import gaussian_filter, median_filter
import tifffile

## Image reconstruction

This notebook extracts laser-position and TOF data from individual CSV files produced by Nu Instruments TOFMS systems to reconstruct multiplexed images.  
Upon entering a directory path below, the data are extracted from the CSV files and stored on a datashader canvas.  
  
Once loaded the data can be visualised as an interactive image for inspection and exported as TIFF images.   
If required, regions of interest (ROIs) can be added and exported to a CSV file.  
  
To begin, enter your file path and export directory (optional) in the input fields of the next cell.  
  
**NOTE**  
- The code in this notebook assumes metadata labels and structure based on the conventions used in a specific combination of laser and TOFMS vendors and corresponding software.
- Pixel data **must** be square. Overlap (if used) must be the same in X- and Y-dimension.
- The given file directory **must not** contain CSV files other than data files.

Specifically we have successfully used this notebook to reconstruct images from CSV files produced by the following software:  
  
| **Software** | **Version** | **Manufacturer**                       |
|--------------|-------------|----------------------------------------|
| NuQuant      | 1.2.8739.1  | Nu Instruments Ltd., Wrexham, UK       |
| Chromium     | 3.2         | Teledyne Photon Machines, Bozeman, USA |
| ActiveView2  | 1.5.1.30    | Elemental Scientific Lasers (ESL), Bozeman, USA                 |


In [None]:
# directory widgets
directory_path = widgets.Text(
    description="File path:",
    placeholder="Your file path…",
    layout=dict(width="40%"),
    style=dict(description_width="100px"),
)
export_dir = widgets.Text(
    description="Export directory:",
    placeholder="Your export directory…",
    layout=dict(width="40%"),
    style=dict(description_width="100px"),
)

# output area that _check_files prints its status message into
file_status = widgets.Output()

def _check_files(change):
    """Report in ``file_status`` how many CSV files the entered directory holds.

    Used as an observer callback; ``change`` is the notification object
    handed over by the widget and is not inspected here.
    """
    with file_status:
        file_status.clear_output()
        path = directory_path.value.strip()

        if os.path.isdir(path):
            # case-insensitive match on the ".csv" extension
            count = sum(
                1 for entry in os.listdir(path) if entry.lower().endswith(".csv")
            )
            print(f"Found {count} CSV file(s) in “{path}”")
        else:
            print(f"{path} is not a valid directory.")
        

# observer: re-run the CSV count whenever the text in the file-path field changes
directory_path.observe(_check_files, names='value')

# show both input fields followed by the status output
display(directory_path, export_dir, file_status)

**NOTE**  
The upcoming cell assumes the metadata fields for X- and Y-positions in the CSV files are called: ```"x [um]", "y [um]"```  
Change these (the ```X, Y = ...``` assignment in the cell below) according to the specifications of your metadata.  
Additionally, it is assumed that the CSV files contain a metadata block/header that is 11 rows long. Change the ```skip_metadata``` variable on line 1 of the cell below to accommodate differences.

In [None]:
skip_metadata = 11

# Getting CSV columns
# Read the header from the first CSV file (sorted by name). Taking
# os.listdir(...)[0] directly is fragile: the listing order is arbitrary and
# the first entry may be a sub-directory or a non-CSV file (e.g. ".DS_Store").
data_dir = directory_path.value.strip()
csv_names = sorted(f for f in os.listdir(data_dir) if f.lower().endswith(".csv"))
if not csv_names:
    raise FileNotFoundError(f"No CSV files found in “{data_dir}”.")
path = os.path.join(data_dir, csv_names[0])
header_df = pd.read_csv(path, skiprows=skip_metadata, nrows=0)
all_channels = list(header_df.columns)

# Spatial axes & isotope channels
X, Y = "x [um]", "y [um]"
# A wrong skip_metadata or differently named position columns would otherwise
# only surface later as an opaque usecols error while loading the data.
missing_axes = [axis for axis in (X, Y) if axis not in all_channels]
if missing_axes:
    raise ValueError(
        f"Position column(s) {missing_axes} not found in the header of {path}. "
        "Check skip_metadata and the X/Y labels."
    )
channels = [c for c in all_channels if c not in (X, Y)]

In [None]:
# current isotope selection; reassigned by on_isotope_change / on_upload_change
isotopes = []

# isotope widgets
isotope_mode = widgets.RadioButtons(
    description="Choose isotope selection method:",
    options=["Manual selection", "Upload CSV file"],
    style=dict(description_width="initial"),
)

isotope_list_upload = widgets.FileUpload(
    description="Upload isotope list",
    accept=".csv",
    multiple=False,
    layout=widgets.Layout(width="auto"),
    style=dict(description_width="initial"),
)

isotope_select = widgets.SelectMultiple(
    description="Select isotopes:",
    options=channels,
    rows=20,
    style=dict(description_width="initial"),
)

# isotope_out hosts the mode-specific controls, isotope_list_out the echo of
# the currently chosen isotopes
isotope_out = widgets.Output()
isotope_list_out = widgets.Output()


# callbacks
def on_isotope_change(change):
    """Store the manual selection in ``isotopes`` and echo it.

    ``change["new"]`` is the new value of the selection widget; it is copied
    into the module-level ``isotopes`` list and printed, one label per line,
    into ``isotope_list_out``.
    """
    global isotopes
    with isotope_list_out:
        clear_output(wait=True)
        isotopes = [*change["new"]]
        # a single print call: heading first, then one label per line
        print("\n".join(["Selected isotopes:", *map(str, isotopes)]))


def on_upload_change(change):
    """Read the uploaded single-column CSV into the global ``isotopes`` list.

    Observer for ``isotope_list_upload``. The first column of the uploaded
    file (read without a header row) is converted to strings and stored in
    ``isotopes``; the labels are echoed in ``isotope_list_out``.

    Does nothing when the widget currently holds no file. An empty file is
    reported in ``isotope_list_out`` and leaves ``isotopes`` unchanged.
    """
    global isotopes

    files = isotope_list_upload.value
    if not files:
        # The observer also fires when the value becomes empty; there is
        # nothing to parse then and ``files[0]`` would raise IndexError.
        return
    uploaded = files[0]

    with isotope_list_out:
        clear_output(wait=True)
        try:
            df = pd.read_csv(io.BytesIO(uploaded.content), header=None)
        except pd.errors.EmptyDataError:
            print("The uploaded file is empty.")
            return

        # dropna(): a missing first-column cell would otherwise become the
        # literal label "nan"
        isotopes = df.iloc[:, 0].dropna().astype(str).tolist()
        print("Uploaded isotopes:")
        for iso in isotopes:
            print(iso)


def on_mode_change(change):
    """Show the controls that belong to the selection mode in ``change.new``.

    Both output areas are cleared first; then either the manual multi-select
    (next to the echo area) or the upload button (above the echo area) is
    displayed inside ``isotope_out``.
    """
    for area in (isotope_out, isotope_list_out):
        area.clear_output()

    with isotope_out:
        if change.new != "Manual selection":
            print("Upload a CSV with a single column containing isotope labels.")
            display(isotope_list_upload, isotope_list_out)
            return

        print("Please select isotopes from the list")
        isotope_select.observe(on_isotope_change, names="value")
        display(widgets.HBox([isotope_select, isotope_list_out]))


# observers
isotope_mode.observe(on_mode_change, names="value")
isotope_list_upload.observe(on_upload_change, names="value")

display(isotope_mode)
display(isotope_out)

# initial view
# on_mode_change only reads ``change.new``, so a throwaway class carrying a
# ``new`` attribute is enough to render the controls for the default mode.
on_mode_change(type("X", (), {"new": isotope_mode.value}))

**Note**: Uncomment the cell below to save the manual selection of isotopes as a CSV for recurrent usage.

In [None]:
# if len(isotope_select.value) > 0:
#     export_path = os.path.join(export_dir.value, "isotope_selection.csv")
#     isotope_csv = pd.DataFrame(isotope_select.value).to_csv(
#         export_path,
#         index=False,
#         header=False,
#     )
#     print(f"Isotope selection saved to {export_path}")
# else:
#     print("No isotopes selected. Please select at least one isotope.")

In [None]:
# Validate the selection up front; an empty or unknown selection would
# otherwise only fail later (usecols error / IndexError on columns[2]).
if not isotopes:
    raise ValueError("No isotopes selected. Select or upload at least one isotope.")
unknown_isotopes = [iso for iso in isotopes if iso not in all_channels]
if unknown_isotopes:
    raise ValueError(f"Isotope(s) not found in the CSV header: {unknown_isotopes}")

# Keep the channels in file order, without duplicates and without X/Y.
wanted = set(isotopes)
canvas_columns = [X, Y] + [c for c in channels if c in wanted]
if len(canvas_columns) == 2:
    raise ValueError("The selection contains no isotope channels.")

ddf = dd.read_csv(
    os.path.join(directory_path.value.strip(), "*_*.csv"),
    skiprows=skip_metadata,  # skip metadata block
    usecols=canvas_columns,
    assume_missing=True,
)
# Later cells treat columns[2:] as the channel list, so force X and Y to be
# the first two columns regardless of their position in the files.
ddf = ddf[canvas_columns]

# rows with non-positive coordinates are dropped
image_valid = ddf[(ddf[X] > 0) & (ddf[Y] > 0)]

# One dask pass for both axes instead of reading every CSV twice.
unique_x, unique_y = (
    np.sort(values)
    for values in dd.compute(image_valid[X].unique(), image_valid[Y].unique())
)
if len(unique_x) < 2 or len(unique_y) < 2:
    raise ValueError(
        "Need at least two distinct positive X and Y positions to build an image "
        f"(found {len(unique_x)} X, {len(unique_y)} Y)."
    )

# pixel pitch = median spacing between neighbouring positions
dx = np.median(np.diff(unique_x))
dy = np.median(np.diff(unique_y))

# extend the range by half a pixel so positions sit at pixel centres
x_range = (unique_x[0] - dx / 2, unique_x[-1] + dx / 2)
y_range = (unique_y[0] - dy / 2, unique_y[-1] + dy / 2)

# one canvas pixel per distinct position along each axis
plot_width = len(unique_x)
plot_height = len(unique_y)

canvas = ds.Canvas(
    plot_width=plot_width, plot_height=plot_height, x_range=x_range, y_range=y_range
)


# Per-channel cache of aggregated images. ``image_valid`` is a lazy dask
# frame, so every canvas.points() call re-reads all CSV files; the Dash
# callbacks request the same channel many times.
_agg_cache = {}


def compute_agg(channel):
    """Aggregate one channel onto the canvas and return it as a 2-D image.

    Parameters
    ----------
    channel : str
        Name of the column in ``image_valid`` to aggregate.

    Returns
    -------
    numpy.ndarray
        float32 array holding the mean of ``channel`` per canvas pixel, with
        NaN (pixels without data) replaced by 0. A fresh copy is returned on
        every call, so callers may modify it freely.
    """
    if channel not in _agg_cache:
        agg = canvas.points(
            image_valid,
            x=X,
            y=Y,
            agg=ds.mean(channel),
        )
        _agg_cache[channel] = np.nan_to_num(agg.values, nan=0).astype(np.float32)
    return _agg_cache[channel].copy()

### Dash App - Image Visualisation

The following cells are used to visualise the loaded mass channels in an interactive dash app.  
Regions of interest (ROIs) can be drawn onto the image and subsequently exported. Currently supported shapes for ROIs are: circle, rectangle, open freeform (if closed manually) and closed freeform.  
  
**NOTE**: To visualise a mass channel, select a single channel from the dropdown menu. To export multiple channels, select multiple channels at once. 

In [None]:
# Channels available for display/export: everything after the X and Y columns.
loaded_channels = image_valid.columns[2:].tolist()
channel_options = [{"label": c, "value": c} for c in loaded_channels]
channel_options.append({"label": "Select all (export)", "value": "all"})

app = dash.Dash(__name__)

# Style for labels
label_style = {"fontWeight": "bold", "marginRight": "5px"}

# Grid container for controls
grid_style = {
    "display": "grid",
    "gridTemplateColumns": "repeat(auto-fit, minmax(150px, 1fr))",
    "gridGap": "20px",
    "alignItems": "center",
    "marginBottom": "20px",
}


def _labelled(text, control):
    """Return a Div holding a bold label followed by ``control``."""
    return html.Div([html.Label(text, style=label_style), control])


def _mode_dropdown(component_id):
    """Return a Raw/Processed/Both dropdown (default "both") with the given id."""
    return dcc.Dropdown(
        id=component_id,
        options=[
            {"label": label, "value": val}
            for label, val in [
                ("Raw", "raw"),
                ("Processed", "processed"),
                ("Both", "both"),
            ]
        ],
        value="both",
        style={"width": "100%"},
    )


def _number_input(component_id, width="100%", **props):
    """Return a numeric dcc.Input; ``props`` carry value/min/max/step."""
    return dcc.Input(id=component_id, type="number", style={"width": width}, **props)


app.layout = html.Div(
    [
        # Top controls in grid
        html.Div(
            [
                _labelled(
                    "Channels:",
                    dcc.Dropdown(
                        id="channel-dropdown",
                        options=channel_options,
                        # first loaded channel, or nothing when none is loaded
                        value=loaded_channels[:1],
                        multi=True,
                        closeOnSelect=False,
                        style={"width": "100%"},
                    ),
                ),
                _labelled(
                    "Colorscale:",
                    dcc.Dropdown(
                        id="colorscale-dropdown",
                        options=[
                            {"label": cs, "value": cs}
                            for cs in px.colors.named_colorscales()
                        ],
                        value="viridis",
                        style={"width": "100%"},
                    ),
                ),
                _labelled("ROI Export:", _mode_dropdown("roi-export-dropdown")),
                _labelled(
                    "Intensity Min (%):",
                    _number_input(
                        "intensity-min",
                        width="20%",
                        value=0.0,
                        min=0,
                        max=100,
                        step=0.001,
                    ),
                ),
                _labelled(
                    "Intensity Max (%):",
                    _number_input(
                        "intensity-max",
                        width="20%",
                        value=100.0,
                        min=0,
                        max=100,
                        step=0.001,
                    ),
                ),
                html.Div(
                    [
                        dcc.Checklist(
                            id="toggle-histogram",
                            options=[{"label": "Show Histogram", "value": "show"}],
                            value=["show"],
                            labelStyle={"marginLeft": "10px"},
                        ),
                    ]
                ),
            ],
            style=grid_style,
        ),
        # Second row of filters in grid
        html.Div(
            [
                _labelled(
                    "Gaussian σ:",
                    _number_input("gaussian-sigma", value=None, min=0, step=0.1),
                ),
                _labelled(
                    "Gaussian Radius:",
                    _number_input("gaussian-radius", value=None, min=1, step=1),
                ),
                _labelled(
                    "Median size/footprint:",
                    # scipy.ndimage.median_filter needs an integer window
                    # size, so only whole numbers are offered here
                    _number_input("median-size", value=None, min=1, step=1),
                ),
                _labelled("Image Export:", _mode_dropdown("export-items")),
                html.Div(
                    [
                        html.Button(
                            "Export TIFF to Folders",
                            id="export-image-btn",
                            n_clicks=0,
                            style={"width": "100%"},
                        ),
                    ]
                ),
                html.Div(
                    [
                        html.Button("Save ROIs", id="save-rois-btn"),
                        html.Div(id="roi-output"),
                    ]
                ),
            ],
            style=grid_style,
        ),
        # image and histogram
        html.Div(
            [
                html.Div(
                    [
                        dcc.Graph(
                            id="graph-picture",
                            config={
                                # drawing tools used to mark ROIs on the image
                                "modeBarButtonsToAdd": [
                                    "drawopenpath",
                                    "drawclosedpath",
                                    "drawcircle",
                                    "drawrect",
                                    "eraseshape",
                                ]
                            },
                        )
                    ],
                    id="image-container",
                    style={"flex": "0 0 65%", "padding": "0 10px"},
                ),
                html.Div(
                    [dcc.Graph(id="histogram")],
                    id="histogram-container",
                    style={"flex": "0 0 35%", "padding": "0 10px"},
                ),
            ],
            style={
                "display": "flex",
                "justifyContent": "center",
                "maxWidth": "auto",
                "margin": "0 auto",
            },
        ),
        html.Div(id="export-status", style={"margin": "10px 0", "fontWeight": "bold"}),
    ],
    style={"padding": "20px"},
)

In [None]:
def preprocess(arr, imin, imax, sig=None, rad=None, med=None):
    """Clip an image to a percentile window and optionally smooth it.

    Parameters
    ----------
    arr : numpy.ndarray
        Image to process.
    imin, imax : float or None
        Lower/upper percentile (0-100) of the intensity window. ``None`` falls
        back to 0 and 100 respectively; if ``imin > imax`` the two are swapped.
    sig : float, optional
        Gaussian sigma; the Gaussian filter is applied only when ``sig > 0``.
    rad : number, optional
        Gaussian kernel radius, rounded to an integer because
        ``scipy.ndimage.gaussian_filter`` rejects non-integer radii.
    med : number, optional
        Median window size, rounded to an integer because
        ``scipy.ndimage.median_filter`` rejects non-integer sizes; the median
        filter is applied only when the rounded size is greater than 1.

    Returns
    -------
    clipped : numpy.ndarray
        ``arr`` clipped to ``[lo, hi]``.
    proc : numpy.ndarray
        ``clipped`` after the requested filters (Gaussian first, then median).
    lo, hi : float
        Intensity values of the two percentiles.
    """
    imin = 0.0 if imin is None else float(imin)
    imax = 100.0 if imax is None else float(imax)
    if imin > imax:
        # an inverted window would clip the whole image to a single value
        imin, imax = imax, imin

    lo, hi = np.percentile(arr, (imin, imax))  # compute thresholds
    clipped = np.clip(arr, lo, hi)
    proc = clipped.copy()

    if sig is not None and sig > 0:  # Gaussian filter
        radius = None if rad is None else int(round(rad))
        proc = gaussian_filter(proc, sigma=sig, radius=radius)

    if med is not None:  # Median filter
        size = int(round(med))
        if size > 1:
            proc = median_filter(proc, size=size)

    return clipped, proc, lo, hi


@app.callback(
    [Output("graph-picture", "figure"), Output("histogram", "figure")],
    [
        Input("channel-dropdown", "value"),
        Input("colorscale-dropdown", "value"),
        Input("intensity-min", "value"),
        Input("intensity-max", "value"),
        Input("gaussian-sigma", "value"),
        Input("gaussian-radius", "value"),
        Input("median-size", "value"),
    ],
)
def update_output(channels, cs, imin, imax, sig, rad, med):
    """Redraw the image and its intensity histogram.

    Only the first selected channel is shown ("all" expands to every loaded
    channel). The image is the preprocessed channel, colour-scaled between
    the two percentile thresholds; the histogram shows the unprocessed pixel
    intensities with the threshold window shaded.

    The update is skipped (``PreventUpdate``) when no channel is selected or
    when the intensity limits are missing or inverted, so the previous figure
    stays on screen instead of the callback raising.
    """
    # A numeric input reports None while it is empty or holds an invalid value.
    if not channels or imin is None or imax is None or imin > imax:
        raise PreventUpdate
    if "all" in channels:
        channels = image_valid.columns[2:].tolist()
        if not channels:
            raise PreventUpdate

    arr = compute_agg(channels[0])

    # thresholds and spatial filters
    _, proc, lo, hi = preprocess(arr, imin, imax, sig, rad, med)

    # figures
    fig_img = px.imshow(proc, color_continuous_scale=cs, zmin=lo, zmax=hi)
    # newly drawn ROI outlines are red
    fig_img.update_layout(newshape=dict(line=dict(color="red", width=2)))
    fig_img.update_traces(
        hovertemplate="x: %{x}<br>y: %{y}<br>intensity: %{z:.2f}<extra></extra>"
    )

    fig_hist = px.histogram(arr.ravel(), nbins=512)
    fig_hist.update_layout(
        xaxis_title="pixel intensity",
        yaxis_title="number of pixels",
    )
    # shade the intensity window selected by the percentile limits
    fig_hist.add_vrect(x0=lo, x1=hi, fillcolor="red", opacity=0.15)
    fig_hist.update_layout(yaxis_type="log", showlegend=False)

    return fig_img.to_dict(), fig_hist.to_dict()


@app.callback(
    [Output("image-container", "style"), Output("histogram-container", "style")],
    [Input("toggle-histogram", "value")],
)
def update_container_styles(toggle_value):
    """Return the (image, histogram) container styles for the checkbox state.

    With "show" ticked the two panels share the row 65 % / 35 %; otherwise the
    image takes the full width and the histogram container is hidden.
    """
    if "show" not in toggle_value:
        return (
            {"width": "100%", "flex": "1", "padding": "0 10px"},
            {"display": "none"},
        )

    def _panel(share):
        # fixed-width flex item with the common horizontal padding
        return {"width": share, "flex": f"0 0 {share}", "padding": "0 10px"}

    return _panel("65%"), _panel("35%")


@app.callback(
    Output("export-status", "children"),
    [
        Input("export-image-btn", "n_clicks"),
        State("channel-dropdown", "value"),
        State("export-items", "value"),
        State("intensity-min", "value"),
        State("intensity-max", "value"),
        State("gaussian-sigma", "value"),
        State("gaussian-radius", "value"),
        State("median-size", "value"),
    ],
    prevent_initial_call=True,
)
def export_to_folders(n, chs, mode, imin, imax, sig, rad, med):
    """Write the selected channels as float32 TIFFs below the export directory.

    Raw images (the aggregated channel as returned by ``compute_agg``) go to
    ``<export dir>/raw``, preprocessed images to ``<export dir>/processed``;
    ``mode`` ("raw", "processed" or "both") selects which are written. Files
    are named ``<channel>.tiff``.

    Returns the status text shown in the app: one "Raw:"/"Proc:" entry per
    written file, "None" when nothing was written, or an error message when
    processed images are requested while an intensity limit is empty.
    """
    if not chs:
        # nothing selected (the dropdown may also report None)
        return "None"
    if "all" in chs:
        chs = image_valid.columns[2:].tolist()

    want_raw = mode in ("raw", "both")
    want_proc = mode in ("processed", "both")
    if want_proc and (imin is None or imax is None):
        return "Export failed: intensity limits are empty or out of range."

    raw_dir = os.path.join(export_dir.value, "raw")
    proc_dir = os.path.join(export_dir.value, "processed")
    # create the target folders once instead of once per channel
    if want_raw:
        os.makedirs(raw_dir, exist_ok=True)
    if want_proc:
        os.makedirs(proc_dir, exist_ok=True)

    msg = []
    for ch in chs:
        arr = compute_agg(ch)

        # Raw export
        if want_raw:
            path = os.path.join(raw_dir, f"{ch}.tiff")
            tifffile.imwrite(path, arr.astype(np.float32))
            msg.append(f"Raw:{ch}->{path}")

        # Processed export (thresholds and spatial filters); skipped entirely
        # for raw-only exports so no filtering work is wasted
        if want_proc:
            _, proc, _, _ = preprocess(arr, imin, imax, sig, rad, med)
            path = os.path.join(proc_dir, f"{ch}.tiff")
            tifffile.imwrite(path, proc.astype(np.float32))
            msg.append(f"Proc:{ch}->{path}")

    return "; ".join(msg) if msg else "None"


def _roi_pixel_coords(shape):
    """Return the integer pixel coordinates selected by one drawn shape.

    ``shape`` is a Plotly layout-shape dict. Candidate pixels are the integer
    positions from ``floor(min)`` up to (not including) ``ceil(max)`` of the
    shape's bounding box on each axis.

    * ``rect``: every candidate pixel is kept.
    * ``circle``: Plotly draws an ellipse inscribed in the bounding box, so a
      candidate is kept when it lies inside that ellipse.
    * anything else: treated as a path; the ``M``/``L`` vertices are parsed
      from the ``path`` string and candidates inside the polygon are kept.

    Returns ``(xs, ys)`` as two equally long integer arrays in row-major
    order (y outer, x inner). Both are empty when the shape covers no pixel
    or the path has fewer than three vertices.
    """
    empty = np.empty(0, dtype=int)
    typ = shape.get("type", "rect")
    pts = []

    if typ in ("rect", "circle"):
        x_min, x_max = sorted((float(shape.get("x0", 0)), float(shape.get("x1", 0))))
        y_min, y_max = sorted((float(shape.get("y0", 0)), float(shape.get("y1", 0))))
    else:
        path_str = shape.get("path", "")
        # regex for 'M' or 'L' commands/coordinates
        tokens = re.findall(r"([ML])([\d\.\-]+),([\d\.\-]+)", path_str)
        pts = [(float(x), float(y)) for (_, x, y) in tokens]
        if len(pts) < 3:
            # no area to select (and an empty list cannot be unpacked below)
            return empty, empty
        poly_x, poly_y = zip(*pts)
        x_min, x_max = min(poly_x), max(poly_x)
        y_min, y_max = min(poly_y), max(poly_y)

    grid_x, grid_y = np.meshgrid(
        np.arange(int(np.floor(x_min)), int(np.ceil(x_max))),
        np.arange(int(np.floor(y_min)), int(np.ceil(y_max))),
    )
    xs, ys = grid_x.ravel(), grid_y.ravel()
    if xs.size == 0:
        return empty, empty

    if typ == "rect":
        return xs, ys

    if typ == "circle":
        cx, cy = (x_min + x_max) / 2, (y_min + y_max) / 2
        rx, ry = (x_max - x_min) / 2, (y_max - y_min) / 2
        if rx <= 0 or ry <= 0:
            return empty, empty
        # separate radii per axis: identical to the circle test when rx == ry
        inside = ((xs - cx) / rx) ** 2 + ((ys - cy) / ry) ** 2 <= 1.0
    else:
        inside = MplPath(pts).contains_points(np.column_stack((xs, ys)))

    return xs[inside], ys[inside]


@app.callback(
    Output("roi-output", "children"),
    [Input("save-rois-btn", "n_clicks")],
    [
        State("graph-picture", "figure"),
        State("roi-export-dropdown", "value"),
        State("intensity-min", "value"),
        State("intensity-max", "value"),
        State("gaussian-sigma", "value"),
        State("gaussian-radius", "value"),
        State("median-size", "value"),
    ],
    prevent_initial_call=True,
)
def save_rois(n, fig, mode, imin, imax, sig, rad, med):
    """Export the pixels inside every drawn ROI to two CSV files.

    For each shape drawn on the image the covered pixels are collected (ROIs
    are numbered in drawing order, starting at 1) and the value of every
    loaded channel is looked up. Two files are written to the export
    directory, both carrying the same timestamp:

    * ``roi_pixels_<ts>.csv``  - one row per ROI pixel,
    * ``roi_summary_<ts>.csv`` - area, sum and mean per ROI.

    ``mode`` ("raw", "processed" or "both") selects which value columns are
    exported. Returns the status text shown next to the button.
    """
    if not n:
        raise PreventUpdate

    # shapes
    shapes = (fig or {}).get("layout", {}).get("shapes", [])
    if not shapes:
        return "No ROI annotations drawn."

    if imin is None or imax is None:
        return "ROI export failed: intensity limits are empty or out of range."

    # Per-channel images.
    # NOTE(review): the "raw" ROI values are the percentile-clipped image,
    # whereas the TIFF export writes the unclipped aggregate as "raw" -
    # confirm which of the two is intended.
    raw_images, proc_images = {}, {}
    for ch in image_valid.columns[2:]:
        clipped, proc, _, _ = preprocess(compute_agg(ch), imin, imax, sig, rad, med)
        raw_images[ch] = clipped
        proc_images[ch] = proc
    if not raw_images:
        return "No ROI data extracted."
    height, width = next(iter(raw_images.values())).shape

    # one frame per ROI, built with array indexing instead of per-pixel dicts
    frames = []
    for roi_id, shape in enumerate(shapes, start=1):
        xs, ys = _roi_pixel_coords(shape)

        # Clip coords to the image
        keep = (xs >= 0) & (xs < width) & (ys >= 0) & (ys < height)
        xs, ys = xs[keep], ys[keep]
        if xs.size == 0:
            continue

        # Extract pixel values (as float64, matching a frame built from
        # individual Python/NumPy scalars)
        columns = {"ROI": np.full(xs.size, roi_id), "x": xs, "y": ys}
        for ch in raw_images:
            columns[f"{ch}_raw (counts_pixel)"] = raw_images[ch][ys, xs].astype(
                np.float64
            )
            columns[f"{ch}_proc (counts_pixel)"] = proc_images[ch][ys, xs].astype(
                np.float64
            )
        frames.append(pd.DataFrame(columns))

    if not frames:
        return "No ROI data extracted."

    df = pd.concat(frames, ignore_index=True)
    unit = " (counts_pixel)"
    raw_cols = [c for c in df.columns if c.endswith("_raw (counts_pixel)")]
    proc_cols = [c for c in df.columns if c.endswith("_proc (counts_pixel)")]

    # value columns selected by the export mode
    targets = []
    if mode in ("raw", "both"):
        targets += raw_cols
    if mode in ("processed", "both"):
        targets += proc_cols

    ts = datetime.now().strftime("%Y%m%d_%H%M%S")
    out_files = []
    if export_dir.value:
        # to_csv does not create missing folders
        os.makedirs(export_dir.value, exist_ok=True)

    # Pixel‐level CSV
    pix_path = os.path.join(export_dir.value, f"roi_pixels_{ts}.csv")
    df[["ROI", "x", "y"] + targets].to_csv(pix_path, index=False)
    out_files.append(pix_path)

    # ROI summary CSV
    area = df.groupby("ROI").size().rename("Area (pixels)").reset_index()

    def _stat_name(col, stat):
        # "<ch>_raw (counts_pixel)" -> "<ch>_raw <stat> (counts_pixel)"
        return f"{col[:-len(unit)]} {stat} (counts_pixel)"

    sum_df = (
        df.groupby("ROI")[targets]
        .sum()
        .reset_index()
        .rename(columns={c: _stat_name(c, "sum") for c in targets})
    )

    mean_df = (
        df.groupby("ROI")[targets]
        .mean()
        .reset_index()
        .rename(columns={c: _stat_name(c, "mean") for c in targets})
    )

    summary_df = area.merge(sum_df, on="ROI").merge(mean_df, on="ROI")

    sum_path = os.path.join(export_dir.value, f"roi_summary_{ts}.csv")
    summary_df.to_csv(sum_path, index=False)
    out_files.append(sum_path)

    return "Saved ROI data to: " + ", ".join(out_files)


# Start the Dash app when the cell is executed.
# jupyter_height is handed to app.run; presumably the height (in pixels) of
# the app frame shown in the notebook - confirm against the Dash docs.
if __name__ == "__main__":
    app.run(jupyter_height=900)

### TIFF stacking
This cell creates a TIFF-stack from the exported/selected TIFF files (e.g. for segmentation steps).

In [None]:
# List TIFF files in raw/processed directories
def list_tiffs(root):
    """Collect the TIFF files found in ``root/raw`` and ``root/processed``.

    Returns a sorted list of paths relative to ``root`` (e.g.
    ``raw/<name>.tiff``). File names are matched case-insensitively against
    the ``.tif`` and ``.tiff`` extensions; a missing sub-folder is skipped.
    """
    found = []
    for folder in ("raw", "processed"):
        folder_path = os.path.join(root, folder)
        if not os.path.isdir(folder_path):
            continue
        found.extend(
            os.path.join(folder, name)
            for name in os.listdir(folder_path)
            if name.lower().endswith((".tif", ".tiff"))
        )
    found.sort()
    return found


# Merge TIFFs into a stack
def on_merge(btn):
    """Stack the selected TIFF files and write them as one ImageJ TIFF.

    Click handler for the "Create Stack" button (``btn`` is not used). The
    files chosen in ``tiff_select`` are read relative to the export
    directory, stacked along a new first axis in selection order and written
    to ``<export dir>/<stack_name>`` with axes metadata ``CYX``. All messages,
    including read/write failures, are printed into ``merge_output``.
    """
    with merge_output:
        clear_output(wait=True)
        sel = list(tiff_select.value)
        if not sel:
            print("No files selected.")
            return

        out_name = stack_name.value.strip()
        if not out_name:
            # a blank name would make the export directory itself the target
            print("Please enter an output file name.")
            return

        arrays = []
        for rel in sel:
            full = os.path.join(export_dir.value, rel)
            try:
                arr = tifffile.imread(full)
            except Exception as e:
                print(f"Failed to read {rel}: {e}")
                return
            arrays.append(arr)

        try:
            stack = np.stack(arrays, axis=0)
        except ValueError:
            print("Error: TIFFs have mismatched shapes.")
            return

        out_path = os.path.join(export_dir.value, out_name)
        try:
            tifffile.imwrite(
                out_path,
                stack,
                imagej=True,
                photometric="minisblack",
                metadata={"axes": "CYX"},
            )
        except Exception as e:
            # report write problems like read problems instead of a traceback
            print(f"Failed to write {out_path}: {e}")
            return
        print(f"Saved merged stack ({stack.shape[0]} pages) to:\n{out_path}")


# TIFFs available for stacking; the list is read once, when this cell runs
tiff_select = widgets.SelectMultiple(
    description="Select TIFFs:",
    options=list_tiffs(export_dir.value),
    rows=10,
    style=dict(description_width="initial"),
)
stack_name = widgets.Text(description="Output name:", value="merged_stack.tiff")
merge_button = widgets.Button(icon="layer-group", description="Create Stack")
merge_output = widgets.Output()

# wire the button to the merge handler, then show the controls top to bottom
merge_button.on_click(on_merge)
stack_controls = [tiff_select, stack_name, merge_button, merge_output]
display(widgets.VBox(stack_controls))