In [None]:
import numpy as np
import pandas as pd
import xtiff

from os import PathLike
from pathlib import Path
from typing import List, Union
from zipfile import ZipFile

from steinbock import io
from steinbock.preprocessing import imc
from steinbock.segmentation import deepcell
from steinbock.measurement import intensities, regionprops, neighbors

# IMC preprocessing pipeline

**steinbock:**  
Documentation: https://bodenmillergroup.github.io/steinbock

## Settings

Example data can be downloaded using the `download_examples.ipynb` notebook.

### Input and output directories

In [None]:
# Root folder of the project; every other folder is resolved against it
base_dir = Path("..")

# Input: folder holding the zipped acquisition files
raw_dir = base_dir.joinpath("raw")

# Outputs: one folder per pipeline stage
img_dir = base_dir.joinpath("img")
masks_dir = base_dir.joinpath("masks")
segstack_dir = base_dir.joinpath("segstacks")
intensities_dir = base_dir.joinpath("intensities")
regionprops_dir = base_dir.joinpath("regionprops")
neighbors_dir = base_dir.joinpath("neighbors")

## Extract images from `.mcd` files

Documentation: https://bodenmillergroup.github.io/steinbock/latest/cli/preprocessing/#image-conversion

### Prepare the panel
#### Panel file and column names

In [None]:
# Names of the columns expected in the panel file.
# NOTE(review): the cells visible in this notebook do not reference these
# variables; they index the panel with the literal column names "keep",
# "channel" and "deepcell" instead — confirm which set matches the panel file.
panel_cellseg_col = "deepcell"
panel_keep_col = "full"
panel_name_col = "name"
panel_channel_col = "metal"

# Location of the panel file (stored next to the raw acquisitions)
panel_file = raw_dir.joinpath("panel.csv")

#### Import the panel

In [None]:
# Load the panel and preview its first rows
imc_panel = pd.read_csv(filepath_or_buffer=panel_file)
imc_panel.head(n=5)

### Unzip

Unzip function

In [None]:
def extract_zips(
    path: Union[str, PathLike], suffix: str, dest: Union[str, PathLike]
) -> List[Path]:
    """Extract all files with a given suffix from the zip archives below a folder.

    Archives are searched recursively under ``path``; archives whose file name
    starts with a dot are skipped. Within each archive, members are processed
    in order of their file name.

    Args:
        path: Folder that is searched recursively for ``*.zip`` archives.
        suffix: File-name ending a member must have to be extracted.
        dest: Folder into which the matching members are extracted.

    Returns:
        Paths of the extracted files, in extraction order.
    """
    results: List[Path] = []
    for archive_path in Path(path).rglob("[!.]*.zip"):
        with ZipFile(archive_path) as archive:
            # Keep regular files with the requested suffix only
            members = [
                info
                for info in archive.infolist()
                if info.filename.endswith(suffix) and not info.is_dir()
            ]
            members.sort(key=lambda info: info.filename)
            results.extend(
                Path(archive.extract(member, path=dest)) for member in members
            )
    return results

In [None]:
# Unpack every .mcd file found in the zipped acquisitions into the raw folder
extract_zips(raw_dir, ".mcd", raw_dir)

In [None]:
# Unpack every .txt file found in the zipped acquisitions into the raw folder
extract_zips(raw_dir, ".txt", raw_dir)

### Convert to tiff
#### Settings

**To extract xml from ome tiff, see https://github.com/BodenmillerGroup/xtiff/blob/main/xtiff/ome.py**

In [None]:
# File extension (without the leading dot) of the images to write
tiff_type = "ome.tiff"

# Threshold used for hot pixel filtering
hpf = 50

# Channels of the panel rows flagged with keep == 1
channel_names = imc_panel.loc[imc_panel["keep"] == 1, "channel"]

# Make sure the image output folder exists
img_dir.mkdir(exist_ok=True)

# Acquisition files (.mcd) and their .txt counterparts found in the raw folder
mcd_files = imc.list_mcd_files(raw_dir)
txt_files = imc.list_txt_files(raw_dir)

In [None]:
# Metadata
def extract_metadata(
    img_file,
    mcd_file,
    img,
    matched_txt,
    recovered,
    acquisition=None,
):
    """Build a one-row table describing a converted image.

    Args:
        img_file: Path of the written image; only its file name is recorded.
        mcd_file: Path of the source acquisition file; only its file name is
            recorded.
        img: Image array; ``shape[0]`` is recorded as the number of channels,
            ``shape[1]`` as the height and ``shape[2]`` as the width.
        matched_txt: Path of the matching .txt file, or ``None``.
        recovered: Flag stored as-is in the ``recovered`` column.
        acquisition: Optional acquisition object providing ``id``,
            ``description``, ``roi_points_um``, ``width_um`` and ``height_um``.
            When omitted, a module-level variable named ``acquisition`` is used
            if one exists, which keeps existing five-argument calls made from
            inside the conversion loop working.

    Returns:
        A ``pandas.DataFrame`` with a single row. The ``acquisition_*`` columns
        are only present when an acquisition is available.
    """
    if acquisition is None:
        # Backward compatibility: earlier versions of this function silently
        # read the notebook-level loop variable `acquisition`.
        acquisition = globals().get("acquisition")

    recovery_file_name = matched_txt.name if matched_txt is not None else None

    image_info_row = {
        "image": img_file.name,
        "width_px": img.shape[2],
        "height_px": img.shape[1],
        "num_channels": img.shape[0],
        "source_file": mcd_file.name,
        "recovery_file": recovery_file_name,
        "recovered": recovered,
    }

    if acquisition is not None:
        # First and third ROI points are stored as start and end coordinates
        start_point = acquisition.roi_points_um[0]
        end_point = acquisition.roi_points_um[2]
        image_info_row.update({
            "acquisition_id": acquisition.id,
            "acquisition_description": acquisition.description,
            "acquisition_start_x_um": start_point[0],
            "acquisition_start_y_um": start_point[1],
            "acquisition_end_x_um": end_point[0],
            "acquisition_end_y_um": end_point[1],
            "acquisition_width_um": acquisition.width_um,
            "acquisition_height_um": acquisition.height_um,
        })

    return pd.DataFrame([image_info_row])

#### Convert

The image metadata are also saved, to `images.csv` in the base directory.

In [None]:
# Convert to ome.tiff and collect one metadata row per written image
if tiff_type == "ome.tiff":
    # A filtered pandas Series keeps its original row labels, so any lookup by
    # position (0, 1, 2, ...) on it may fail or return the wrong channel.
    # Passing a plain list avoids that.
    # NOTE(review): presumably the TIFF writer looks names up by position — confirm.
    kept_channel_names = list(channel_names)

    # Collect the per-image tables and concatenate once after the loop
    # (growing a DataFrame inside the loop copies it on every iteration).
    image_info_frames = []

    for mcd_file, acquisition, img, matched_txt, recovered in imc.try_preprocess_images_from_disk(
        mcd_files=mcd_files,
        txt_files=txt_files,
        hpf=hpf,
        channel_names=kept_channel_names,
    ):
        img_file = Path(img_dir) / f"{mcd_file.stem}_{acquisition.description}.{tiff_type}"
        xtiff.to_tiff(
            img=img,
            file=img_file,
            channel_names=kept_channel_names,
        )

        image_info_frames.append(
            extract_metadata(img_file, mcd_file, img, matched_txt, recovered)
        )

    if image_info_frames:
        image_info_data = pd.concat(image_info_frames, ignore_index=True)
    else:
        # No image was produced: still write an (empty) metadata file
        image_info_data = pd.DataFrame()

    image_info_data.to_csv(base_dir / "images.csv", index=False)

In [None]:
# # OPTIONAL: Convert to .tiff
# if tiff_type == "tiff":
#     image_info_data = pd.DataFrame()
    
#     for mcd_file, acquisition, img, matched_txt, recovered in imc.try_preprocess_images_from_disk(
#         mcd_files = mcd_files,
#         txt_files = txt_files,
#         hpf = hpf,
#         channel_names = channel_names
#     ):
#         img_file = Path(img_dir) / f"{mcd_file.stem}_{acquisition.description}.{tiff_type}"
#         io.write_image(img, img_file)

#         image_info = extract_metadata(img_file, mcd_file, img, matched_txt, recovered)
#         image_info_data = pd.concat([image_info_data, image_info])

#     image_info_data.to_csv(base_dir / "images.csv", index=False)

## Cell segmentation

### Prepare segmentation stacks

In [None]:
# Segmentation type
segmentation_type = "whole-cell"

# Make sure the segmentation stack output folder exists
segstack_dir.mkdir(exist_ok=True)

# Panel rows flagged with keep == 1, and their values in the "deepcell" column
keep = imc_panel["keep"].eq(1)
channel_groups = imc_panel.loc[keep, "deepcell"].values

# Standardize each channel, then aggregate each channel group by summing
channelwise_zscore = True
aggr_func = np.sum

In [None]:
# `keep` has one entry per panel row; turn it into a plain boolean array so it
# can be compared against (and applied to) the channel axis of each image.
keep_mask = np.asarray(keep, dtype=bool)

# Use the configured image type rather than a hard-coded "ome.tiff"
for img_path in sorted(Path(img_dir).glob(f"*.{tiff_type}")):
    img = io.read_image(img_path)

    if img.shape[0] == keep_mask.size:
        # Image holds one channel per panel row: drop the rows not kept
        img = img[keep_mask, :, :]
    elif img.shape[0] != keep_mask.sum():
        # Neither the full panel nor the kept subset: the panel and the image
        # do not belong together
        raise ValueError(
            f"{img_path.name}: image has {img.shape[0]} channels, expected "
            f"{keep_mask.size} (full panel) or {keep_mask.sum()} (kept channels)"
        )
    # Otherwise the image already contains the kept channels only (the
    # conversion step above is given the kept channel names), so applying the
    # panel-length mask again would fail.

    if channelwise_zscore:
        # Standardize every channel; channels with zero spread are only centered
        channel_means = np.nanmean(img, axis=(1, 2))
        channel_stds = np.nanstd(img, axis=(1, 2))
        img -= channel_means[:, np.newaxis, np.newaxis]
        img[channel_stds > 0] /= channel_stds[
            channel_stds > 0, np.newaxis, np.newaxis
        ]

    if channel_groups is not None:
        # One output channel per group; channels without a group (NaN) are dropped
        img = np.stack(
            [
                aggr_func(img[channel_groups == channel_group], axis=0)
                for channel_group in np.unique(channel_groups)
                if not np.isnan(channel_group)
            ]
        )

    img_file = Path(segstack_dir) / img_path.name.replace(("." + tiff_type), "_deepcell.tiff")
    io.write_image(img, img_file)

### Segment cells

In [None]:
# Make sure the mask output folder exists
masks_dir.mkdir(exist_ok=True)

# Segmentation stacks written by the previous step, in sorted order
segstacks = sorted(Path(segstack_dir).glob("*deepcell.tiff"))

In [None]:
# Segment every stack and write one mask per image.
# The segmentation type comes from the `segmentation_type` setting, so the
# value used for segmentation always matches the one written into the file name.
for img_path, mask in deepcell.try_segment_objects(
    img_files=segstacks,
    application=deepcell.Application.MESMER,
    pixel_size_um=1.0,
    segmentation_type=segmentation_type,
):
    mask_file = Path(masks_dir) / f"{img_path.stem}_mask_{segmentation_type}.tiff"
    io.write_mask(mask, mask_file)

## Measure cells

#### Create output folders

In [None]:
# Create the measurement output folders (no error if they already exist)
for output_dir in (intensities_dir, regionprops_dir, neighbors_dir):
    output_dir.mkdir(exist_ok=True)

### Measure cell intensities per channel

In [None]:
# The images were written with the channels of the panel rows flagged keep == 1
# only, so the channel names must be restricted the same way; passing every
# panel row would give more names than image channels.
measured_channel_names = list(imc_panel.loc[imc_panel["keep"] == 1, "channel"])

for img_path, mask_path, intens in intensities.try_measure_intensities_from_disk(
    img_files=io.list_image_files(img_dir),
    mask_files=io.list_image_files(masks_dir),
    channel_names=measured_channel_names,
    intensity_aggregation=intensities.IntensityAggregation.MEAN,
):
    # One CSV per mask, named after the mask file
    intensities_file = Path(intensities_dir) / mask_path.name.replace(".tiff", ".csv")
    intens.to_csv(intensities_file)

### Measure cell spatial properties

#### List properties to measure

In [None]:
# Names of the region properties to measure for every cell
skimage_regionprops = [
    "area",
    "centroid",
    "major_axis_length",
    "minor_axis_length",
    "eccentricity",
]

#### Measure region props

In [None]:
# Measure the selected properties for every image/mask pair; one CSV per mask
regionprops_results = regionprops.try_measure_regionprops_from_disk(
    img_files=io.list_image_files(img_dir),
    mask_files=io.list_image_files(masks_dir),
    skimage_regionprops=skimage_regionprops,
)
for img_path, mask_path, region_props in regionprops_results:
    regionprops_file = Path(regionprops_dir) / mask_path.name.replace(".tiff", ".csv")
    region_props.to_csv(regionprops_file)

### Measure cell neighbors

#### Settings

Choose `dmax` (maximum distance between centroids) and/or `kmax` (number of nearest neighbors).

Neighborhood types:
+ `NeighborhoodType.CENTROID_DISTANCE`
+ `NeighborhoodType.EUCLIDEAN_BORDER_DISTANCE`
+ `NeighborhoodType.EUCLIDEAN_PIXEL_EXPANSION`

In [None]:
# Distance threshold and number of nearest neighbors passed to the measurement
dmax = 15
kmax = 5

# How the neighborhood between cells is defined
neighborhood_type = neighbors.NeighborhoodType.CENTROID_DISTANCE

#### Measure cell neighbors

In [None]:
# Measure the neighbors of every cell in every mask; one CSV per mask
neighbors_results = neighbors.try_measure_neighbors_from_disk(
    mask_files=io.list_image_files(masks_dir),
    neighborhood_type=neighborhood_type,
    metric="euclidean",
    dmax=dmax,
    kmax=kmax,
)
for mask_path, neighb in neighbors_results:
    neighb_file = Path(neighbors_dir) / mask_path.name.replace(".tiff", ".csv")
    neighb.to_csv(neighb_file)

In [None]:
# Print the packages installed in the current conda environment, so the
# software versions used for this run are recorded with the notebook
!conda list