In [None]:
import glob
import io
import os

import ipywidgets as widgets
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import display
from PIL import Image
import imageio.v3 as iio
from skimage.color import label2rgb
from skimage.measure import regionprops
from skimage.segmentation import clear_border

%matplotlib widget

## Data extraction
This notebook extracts region-based (single-cell) information from a segmentation mask and a set of raw TIFF images.  

In [None]:
cell_mask_output = widgets.Output()
cell_labels = None  # 2-D int32 label image decoded from the upload (None until a valid mask is loaded)

cell_mask_upload = widgets.FileUpload(
    accept='.tif,.tiff,.png',
    multiple=False,
    description='Upload Segmentation Mask',
    layout=widgets.Layout(width='auto'),
)


# Callback
def on_cell_mask_upload(change):
    """Decode the uploaded segmentation mask into the global ``cell_labels``.

    The file is read with imageio and singleton dimensions are squeezed away.
    The result must be a 2-D image; anything else (e.g. an RGB PNG) is rejected.
    On any failure ``cell_labels`` is left as ``None`` so the following cells
    report a missing mask instead of silently reusing an older one.

    Args:
        change: traitlets change notification from ``cell_mask_upload``
            (unused; the current widget value is read directly).
    """
    global cell_labels
    with cell_mask_output:
        cell_mask_output.clear_output()
        if not cell_mask_upload.value:
            return
        uploaded_file = cell_mask_upload.value[0]
        # Reset first so a failed upload never leaves a stale mask in place.
        cell_labels = None
        try:
            mask = np.squeeze(iio.imread(io.BytesIO(uploaded_file['content'])))
        except Exception as e:
            print(f"Error reading '{uploaded_file['name']}': {e}")
            return
        if mask.ndim != 2:
            print(
                f"Error: '{uploaded_file['name']}' has shape {mask.shape}; "
                "expected a 2-D label image."
            )
            return
        cell_labels = mask.astype(np.int32)
        print(f"Uploaded: {uploaded_file['name']}")


# Attach and display
cell_mask_upload.observe(on_cell_mask_upload, names='value')
display(cell_mask_upload, cell_mask_output)

### Remove border objects (cells)

In [None]:
processing_ouput = widgets.Output()
cell_labels_cleaned = None

# Remove every labelled object that touches the image border, report how many
# objects were dropped, and show the mask before/after side by side.
with processing_ouput:
    processing_ouput.clear_output()
    if cell_labels is not None:
        print("Processing Cell Mask by removing border-touching cells...\n")
        labels_before = cell_labels.copy()
        labels_after = clear_border(labels_before)
        cell_labels_cleaned = labels_after

        # regionprops yields one entry per labelled region, so len() is the cell count.
        count_before, count_after = (
            len(regionprops(lbl)) for lbl in (labels_before, labels_after)
        )
        for caption, value in (
            ("Original number of cells", count_before),
            ("Number of cells after removal", count_after),
            ("Number of cells removed", count_before - count_after),
        ):
            print(f"{caption}: {value}")

        # Panel title -> label image; each label is colourised with the tab20 palette.
        panels = {
            "Original Cell Mask": labels_before,
            "Cell Mask after Removing Border Cells": labels_after,
        }
        fig, axes = plt.subplots(1, 2, figsize=(14, 7))
        for ax, (title, labels) in zip(axes, panels.items()):
            ax.imshow(
                label2rgb(labels, bg_label=0, colors=plt.cm.tab20.colors, alpha=1.0)
            )
            ax.set_title(title, fontsize=16)
            ax.axis("off")
        plt.tight_layout()
        plt.show()
    else:
        print("Segmentation mask not uploaded yet. Please upload a Cell mask in Cell 1.")

# Display
display(processing_ouput)

### Filename (image) --> feature name mapping

Choose whether to use a Python dictionary or an uploaded CSV file to map the original image file names to meaningful feature names in the resulting single-cell table. File names are given without their extension (`.tif`, `.tiff`, `.ome.tif`, `.ome.tiff`).  
The CSV must have two columns and no header row, as in the following example:  
 
<table>
  <tr><td>image_channel_1</td><td>DNA</td></tr>
  <tr><td>image_channel_2</td><td>membrane</td></tr>
  <tr><td>image_channel_3</td><td>collagen</td></tr>
</table>

In [None]:
# Storage for manual or CSV-loaded mapping: original filename stem -> feature name
filename_mapping = {}

# Widgets
mapping_mode = widgets.RadioButtons(
    options=["Manual dictionary entry", "Upload CSV file"],
    description="Choose mapping method:",
    style={"description_width": "initial"},
)
mapping_output = widgets.Output()
csv_upload = widgets.FileUpload(
    accept=".csv",
    multiple=False,
    description="Upload mapping CSV",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="auto"),
)


def _render_mapping_mode(mode):
    """Redraw ``mapping_output`` with the instructions/widgets for ``mode``."""
    mapping_output.clear_output()
    with mapping_output:
        if mode == "Manual dictionary entry":
            print(
                "Please uncomment and edit the `filename_mapping` dict below.\n"
                "Format: 'original_name': 'new_name'\n"
                "\n"
                "Example:\n"
                "filename_mapping = {\n"
                "    'filename_image_1': 'channel_1',\n"
                "    'filename_image_2': 'channel_2',\n"
                "}"
            )
        else:
            display(csv_upload)
            print(
                "Upload a CSV with two columns (no header):\n"
                "  • Column 1: original filenames\n"
                "  • Column 2: new filenames"
            )


# Show manual instructions or CSV widget
def on_mode_change(change):
    """Observer for ``mapping_mode``: show the instructions for the new choice."""
    _render_mapping_mode(change.new)


# Process uploaded CSV into dictionary
def on_csv_upload(change):
    """Parse the uploaded two-column CSV into the global ``filename_mapping``.

    Every cell is read as a string, so names such as ``001`` stay text and can
    still match filename stems. Surrounding whitespace is stripped, and rows
    with an empty first or second column are skipped. Only the first two
    columns are used; if a name appears twice, the later row wins. The upload
    widget is re-displayed after the output is cleared so it stays usable.

    Args:
        change: traitlets change notification from ``csv_upload``
            (unused; the current widget value is read directly).
    """
    global filename_mapping
    if not csv_upload.value:
        return
    with mapping_output:
        mapping_output.clear_output()
        display(csv_upload)
        try:
            uploaded = csv_upload.value[0]
            df = pd.read_csv(
                io.BytesIO(uploaded["content"]),
                header=None,
                dtype=str,
                keep_default_na=False,  # keep names like "NA" as text
            )
            if df.shape[1] < 2:
                print("Error: CSV needs at least 2 columns.")
                return
            pairs = df.iloc[:, :2].fillna("").apply(lambda col: col.str.strip())
            pairs = pairs[(pairs.iloc[:, 0] != "") & (pairs.iloc[:, 1] != "")]
            # Build mapping
            filename_mapping = dict(zip(pairs.iloc[:, 0], pairs.iloc[:, 1]))
            print(f"Loaded {len(filename_mapping)} mappings from CSV.")
            print("\nPreview:")
            for orig, new in filename_mapping.items():
                print(f"  '{orig}' → '{new}'")
        except Exception as e:
            print(f"Error processing CSV: {e}")


# observers
mapping_mode.observe(on_mode_change, names="value")
csv_upload.observe(on_csv_upload, names="value")

display(mapping_mode)
display(mapping_output)

# Trigger initial view
_render_mapping_mode(mapping_mode.value)

In [None]:
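# Uncomment and edit the dictionary below with the correct filename mappings.
# Keys are image filenames without their extension (.tif/.tiff/.ome.tif/.ome.tiff);
# values are the feature names used as column prefixes in the output table.
# Running this cell replaces any mapping loaded from a CSV file above. The mapping
# is applied when the Image Directory cell below is run or its path is changed.

# filename_mapping = {
#     '[193Ir]+': 'nucleus',
#     'filename_image_2': 'channel_2',
#     # Add more mappings as needed
# }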

### Image Directory

Enter the directory containing the image channels you wish to extract single-cell data from.  
**NOTE:** All .tif/.tiff and .ome.tif/.ome.tiff files directly inside the directory (subdirectories are not searched) will be listed and processed.

In [None]:
# Loaded images
image_channels = {}

directory_path = widgets.Text(
    description="Directory path:",
    placeholder="Enter path to TIFF files",
    style={"description_width": "initial"},
    layout=widgets.Layout(width="80%"),
)

load_output = widgets.Output()


# Callback
def on_directory_change(change):
    global image_channels
    path = change.new.strip()
    image_channels = {}
    with load_output:
        load_output.clear_output()
        if not path:
            print("Please enter a directory path.")
            return
        if not os.path.isdir(path):
            print(f"'{path}' is not a valid directory.")
            return

        # Collect matching TIFF variants
        patterns = ["*.ome.tif", "*.ome.tiff", "*.tif", "*.tiff"]
        image_files = []
        for pat in patterns:
            image_files.extend(glob.glob(os.path.join(path, pat)))
        if not image_files:
            print(f"No TIFF/OME-TIFF files found in '{path}'.")
            return

        loaded = []
        for fp in image_files:
            base = os.path.basename(fp)
            name, _ = os.path.splitext(base)
            if name.lower().endswith(".ome"):
                name = os.path.splitext(name)[0]
            new = filename_mapping.get(name, name)
            try:
                img = Image.open(fp)
                image_channels[new] = np.array(img)
                loaded.append((name, new))
            except Exception as e:
                print(f"Error loading '{base}': {e}")

        print(f"\nLoaded {len(loaded)} image(s):")
        for orig, new in loaded:
            print(f"  {orig} → {new}")

        # report failed mappings
        missing = set(filename_mapping) - {o for o, _ in loaded}
        if missing:
            print("\nMappings not found on disk:")
            for m in missing:
                print(f"  {m}")


# Observer and display
directory_path.observe(on_directory_change, names="value")
display(directory_path, load_output)

# Initial load
on_directory_change(type("X", (), {"new": directory_path.value}))

### Extract and save single-cell data

Integrated (summed) single-cell intensities for all loaded image channels, together with cell region properties, are extracted and saved as ```single_cell_data.csv``` in the image directory entered above.  
The cell region information currently contains **cell_label**, **Center_X**, **Center_Y**, **Area** and **Diameter**. To add more properties, edit the ```row``` dictionary in the cell below.  
For the available region properties, see the scikit-image documentation: [skimage.measure.regionprops](https://scikit-image.org/docs/0.25.x/api/skimage.measure.html#skimage.measure.regionprops)

In [None]:
output_extraction = widgets.Output()
cell_data_df = None

# Build one row per cell (region properties + summed intensity per channel),
# write it to single_cell_data.csv in the image directory, and show the table.
with output_extraction:
    output_extraction.clear_output()
    if cell_labels_cleaned is None:
        print("Cell mask not available. Please upload and process the cell mask first.")
    elif not image_channels:
        print("Marker images not available. Please load the marker images from a directory first.")
    else:
        # Shape compatibility is a per-channel property, so check it once
        # instead of once per cell (which flooded the output with warnings).
        mismatched = set()
        for name, img in image_channels.items():
            if img.shape != cell_labels_cleaned.shape:
                print(
                    f"Warning: dimension mismatch for '{name}' "
                    f"({img.shape} vs mask {cell_labels_cleaned.shape}), inserting NaN."
                )
                mismatched.add(name)

        data = []
        for prop in regionprops(cell_labels_cleaned):
            row = {
                "cell_label": prop.label,
                "Center_Y": prop.centroid[0],
                "Center_X": prop.centroid[1],
                "Area": prop.area,
                "Diameter": prop.equivalent_diameter_area,
            }
            # sum signal per marker: prop.slice is the region's bounding box and
            # prop.image its boolean mask inside that box, so this selects the same
            # pixels as `cell_labels_cleaned == prop.label` without scanning the
            # whole image for every cell.
            for name, img in image_channels.items():
                if name in mismatched:
                    signal = float("nan")
                else:
                    signal = img[prop.slice][prop.image].sum()
                row[f"{name}_signal"] = signal
            data.append(row)

        # Build/save DataFrame
        cell_data_df = pd.DataFrame(data)
        # Strip like the directory loader does, so stray whitespace in the text
        # box does not point the CSV at a non-existent directory.
        file_path = os.path.join(directory_path.value.strip(), "single_cell_data.csv")
        cell_data_df.to_csv(file_path, index=False)
        print(f"Data extracted and saved to: {file_path}")

# Display
display(output_extraction)
if cell_data_df is not None:
    display(cell_data_df)