In [1]:
# import required packages
from deepcell.applications import Application
from deepcell.model_zoo.panopticnet import PanopticNet
import sys
sys.path.append("../src")
from cell_classification.semantic_head import create_semantic_head
import tensorflow as tf
import os
from tifffile import imread, imwrite
import numpy as np
import matplotlib.pyplot as plt
from skimage.segmentation import find_boundaries
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))
import cv2
import pandas as pd
from tqdm.notebook import tqdm
from joblib import Parallel, delayed
import random
import json

## 0: Set root directory and download example dataset
Here we are using the example data located in `/data/example_dataset/input_data`. To modify this notebook to run using your own data, simply change `base_dir` to point to your own sub-directory within the data folder. Set `base_dir`, the path to all of your imaging data (i.e. multiplexed images and segmentation masks). Subdirectory `cell_classification` will contain all of the data generated by this notebook. In the following, we expect this folder structure:
```
|-- base_dir
|   |-- image_data
|   |   |-- fov_1
|   |   |-- fov_2
|   |-- segmentation
|   |   |-- deepcell_output
|   |-- cell_classification
```

In [2]:
# set up the base directory
# NOTE(review): hard-coded absolute local path; point this at your own dataset root before
# running the notebook. All other paths used below are derived from base_dir.
base_dir = "E:/angelo_lab/data/TONIC/raw"

## 1: set file paths and parameters

### All data, images, files, etc. must be placed in the 'data' directory, and referenced via '../data/path_to_your_data'

If you're interested in directly interfacing with Google Drive, consult the documentation [here](https://ark-analysis.readthedocs.io/en/latest/_rtd/google_docs_usage.html).

In [3]:
# set up file paths (all of them are sub-folders of base_dir)
# NOTE(review): the folder tree in the section 0 markdown lists `image_data/fov_*` and
# `segmentation/deepcell_output`, while the paths below use `image_data/samples` and
# `segmentation_data/deepcell_output` — confirm which layout is intended.
# folder whose entries are listed as fovs further down
tiff_dir = os.path.join(base_dir, "image_data/samples")
# folder in which segmentation_naming_convention looks for the segmentation masks
deepcell_output_dir = os.path.join(base_dir, "segmentation_data/deepcell_output")
# folder that receives everything this notebook writes
cell_classification_output_dir = os.path.join(base_dir, "cell_classification")


## 2: Load data and prepare normalization dictionary
The next step is to iterate through a random subset of the fovs (10 by default, see `n_subset`) and calculate the 0.999 marker expression quantile for each marker individually; the per-fov quantiles are then averaged per marker. These values are used for normalizing the marker expressions prior to predicting marker positivity/negativity with our model.

In [None]:
# Make output directory
os.makedirs(cell_classification_output_dir, exist_ok=True)

# define the channels to exclude
exclude_channels = ['H3K9ac', 'H3K27me3', "Au", "Fe", "Noodle", "Ca"]

# load data and prepare normalization dict
# Every fov is expected to be a sub-folder of tiff_dir. Stray files (e.g. .DS_Store) are
# skipped because the per-fov os.listdir calls further down would fail on them, and the
# names are sorted because os.listdir returns its entries in arbitrary order.
fov_names = sorted(
    fov_name for fov_name in os.listdir(tiff_dir)
    if os.path.isdir(os.path.join(tiff_dir, fov_name))
)
# fov_names = ["TONIC_TMA10_R1C1", "TONIC_TMA10_R3C6", "TONIC_TMA10_R5C4"]
fov_paths = [os.path.join(tiff_dir, fov_name) for fov_name in fov_names]

# change this function to match your naming convention
def segmentation_naming_convention(fov_path):
    """Builds the path of the segmentation mask that belongs to a fov
    Args:
        fov_path (str): path to the fov folder
    Returns:
        seg_path (str): path to the file `<fov name>_feature_0.tif` inside deepcell_output_dir
    """
    mask_file_name = os.path.basename(fov_path) + "_feature_0.tif"
    seg_path = os.path.join(deepcell_output_dir, mask_file_name)
    return seg_path

def calculate_normalization(channel_path, quantile):
    """Computes the normalization value of a single channel image
    Args:
        channel_path (str): path to the channel image file
        quantile (float): quantile of the pixel values to use as normalization value
    Returns:
        tuple: (chan, normalization_value), where chan is the file name up to its first "."
            and normalization_value is the requested quantile over all pixels of the image
    """
    pixel_values = imread(channel_path)
    file_name = os.path.basename(channel_path)
    chan = file_name.split(".")[0]
    return chan, np.quantile(pixel_values, quantile)

def prepare_normalization_dict(
        fov_paths, quantile=0.999, exclude_channels=[], n_subset=10, n_jobs=8
    ):
    """Prepares the normalization dict for a list of fovs
    Args:
        fov_paths (list): list of paths to fovs, the list itself is left untouched
        quantile (float): quantile to use for normalization
        exclude_channels (list): list of channels to exclude
        n_subset (int): number of randomly drawn fovs to use for normalization,
            None uses all fovs
        n_jobs (int): number of joblib workers used per fov, values <= 1 run serially
    Returns:
        normalization_dict (dict): dict with channel names as keys and the quantile value
            averaged over the selected fovs as values (plain python floats)
    """
    normalization_dict = {}
    # work on a copy: shuffling the caller's list in place would silently reorder it
    fov_paths = list(fov_paths)
    if n_subset is not None:
        random.shuffle(fov_paths)
        fov_paths = fov_paths[:n_subset]
    print("Iterate over fovs...")
    for fov_path in tqdm(fov_paths):
        # only image files are read; "._*" files (also skipped by the prediction loop) and
        # other non-tiff entries would make imread fail
        channels = [
            channel for channel in os.listdir(fov_path)
            if channel.lower().endswith((".tif", ".tiff"))
            and not channel.startswith("._")
            and channel.split(".")[0] not in exclude_channels
        ]
        channel_paths = [os.path.join(fov_path, channel) for channel in channels]
        if n_jobs > 1:
            normalization_values = Parallel(n_jobs=n_jobs)(
                delayed(calculate_normalization)(channel_path, quantile)
                for channel_path in channel_paths
            )
        else:
            normalization_values = [
                calculate_normalization(channel_path, quantile)
                for channel_path in channel_paths
            ]
        # collect one value per fov for every channel
        for channel, normalization_value in normalization_values:
            normalization_dict.setdefault(channel, []).append(normalization_value)
    # average over fovs; float() keeps the values serializable with json.dump even when
    # numpy returns a non-float64 scalar
    return {
        channel: float(np.mean(values)) for channel, values in normalization_dict.items()
    }

# Prepare or load training data normalization dict
# The dict is computed from a random subset of fovs, so its values differ between runs.
# An existing file is therefore re-used, which keeps the values used for prediction identical
# to the ones stored on disk; delete the file to force a re-computation.
normalization_dict_path = os.path.join(cell_classification_output_dir, 'normalization_dict.json')
if os.path.exists(normalization_dict_path):
    with open(normalization_dict_path) as f:
        normalization_dict = json.load(f)
else:
    normalization_dict = prepare_normalization_dict(fov_paths, exclude_channels=exclude_channels)
    # save normalization dict; float() makes sure numpy scalars can be serialized
    with open(normalization_dict_path, 'w') as f:
        json.dump({chan: float(value) for chan, value in normalization_dict.items()}, f)



Iterate over fovs...


  0%|          | 0/10 [00:00<?, ?it/s]

## 3: Load model and initialize deepcell application
The following code initializes the deepcell application and loads the model checkpoint. The checkpoint needs to be downloaded from [here](https://charitede-my.sharepoint.com/:u:/g/personal/josef-lorenz_rumberger_charite_de/Ed5iVEMreE5DqJ_WczdXS9EBFeD75ZmaLdYWXENvUvUbSg?e=r2hxK8) and put under path `checkpoints/checkpoint_125000.h5`.

In [None]:
# load model
# NOTE(review): the markdown above asks for the checkpoint at `checkpoints/checkpoint_125000.h5`,
# while this path points to a differently named file — confirm which checkpoint is meant.
checkpoint_path = os.path.normpath("../checkpoints/halfres_512_checkpoint_160000.h5")
# backbone name handed to PanopticNet below
backbone = "efficientnetv2bs"
# handed to PanopticNet (input_shape) and Application (model_image_shape) below;
# presumably height, width, channels — the 2 matches the two arrays stacked on the last
# axis by prepare_input_data
input_shape = [1024,1024,2]

def cell_preprocess(image, **kwargs):
    """Preprocess input data for cell classification model.
    Args:
        image (np.array): 4D array to be processed, it is not modified
        **kwargs: optional settings
            normalize (bool): whether to normalize and clip the data, defaults to True
            marker (str): key used to look up the normalization factor
            normalization_dict (dict): maps marker names to normalization factors; if it is
                missing or has no entry for marker, the 0.999 quantile of channel 0 is used
    Returns:
        np.array: processed copy of image; with normalize set, channel 0 is divided by the
            normalization factor and all values are clipped to [0, 1]
    Raises:
        ValueError: if image is not 4D
    """
    if len(image.shape) != 4:
        raise ValueError("Image data must be 4D, got image of shape {}".format(image.shape))
    # all changes are made on a copy so that the caller's array stays untouched
    output = np.copy(image)
    normalize = kwargs.get('normalize', True)
    marker = kwargs.get('marker')
    normalization_dict = kwargs.get('normalization_dict')
    if normalize:
        # in-place division is not possible on integer arrays
        if not np.issubdtype(output.dtype, np.floating):
            output = output.astype(np.float32)
        if normalization_dict is not None and marker in normalization_dict:
            print("Norm_factor found for marker {}".format(marker))
            norm_factor = normalization_dict[marker]
        else:
            print("Norm_factor not found for marker {}".format(marker))
            norm_factor = np.quantile(output[...,0], 0.999)
        output[...,0] /= norm_factor
        output = output.clip(0, 1)
    return output

def cell_postprocess(model_output):
    """Post-processing hook that hands the model output back unchanged.
    Args:
        model_output: output to be post-processed
    Returns:
        the very same object that was passed in
    """
    return model_output

def format_output(model_output):
    """Picks the first entry of the model output.
    Args:
        model_output: indexable collection of outputs
    Returns:
        the element stored at index 0
    """
    return model_output[0]

# build the network and load the trained weights from the checkpoint
model = PanopticNet(
    backbone=backbone,
    input_shape=input_shape,
    norm_method="std",
    num_semantic_classes=[1],
    create_semantic_head=create_semantic_head,
    location=False,
)
model.load_weights(checkpoint_path)

# NOTE(review): `prep` is not handed to the Application below and reads a global `marker`
# that is not defined in the visible cells — looks unused, confirm before removing
prep = lambda x: cell_preprocess(x, normalize=True, marker=marker, normalization_dict=normalization_dict)

# wrap the model together with the pre-/post-processing hooks defined above
app = Application(
    model=model,
    model_image_shape=input_shape,
    preprocessing_fn=cell_preprocess,
    postprocessing_fn=cell_postprocess,
    format_model_output_fn=format_output,
)

## 4: Make predictions with the model
Determine if you want to (a) plot the predictions, (b) save the prediction images, (c) use test-time augmentation during inference and (d) run inference at half resolution. The script will iterate through your samples and store the predictions together with a file named `pred_cell_table.csv` that contains the mean-per-cell predicted marker activity.

In [None]:
# plot and save images
# show marker image, binary mask and prediction for every channel
plot_predictions = False
# write every prediction as uint8 tiff into the output folder of its fov
save_predictions = True
# NOTE(review): the function `test_time_aug` defined further down in this cell rebinds this
# name, so `if test_time_aug:` in the prediction loop is always true, whatever is set here
test_time_aug = True
# predict on inputs scaled by 0.5 and resize the prediction back afterwards
half_resolution = True

def prepare_input_data(mplex_img, instance_mask):
    """Stacks a marker image and a binary cell mask into one model input
    Args:
        mplex_img (np.array): marker image
        instance_mask (np.array): instance segmentation mask, values > 0 count as cells
    Returns:
        np.array: marker image and binary mask stacked on the last axis, with an added
            leading batch axis of size 1
    """
    inner_edge = find_boundaries(instance_mask, mode="inner").astype(np.uint8)
    # a pixel is foreground if it belongs to a cell and is not on the cell's inner boundary
    is_cell_interior = np.logical_and(inner_edge == 0, instance_mask > 0)
    stacked = np.stack([mplex_img, is_cell_interior.astype(np.float32)], axis=-1)
    return stacked[np.newaxis,...] # bhwc

def segment_mean(instance_mask, prediction):
    """Averages the prediction over the pixels of every labelled cell
    Args:
        instance_mask (np.array): instance segmentation mask with non-negative labels,
            label 0 is background
        prediction (np.array): predicted values, needs as many elements as instance_mask
    Returns:
        list: [labels, mean_per_cell]; labels are the sorted cell labels found in the mask
            (int32, background excluded), mean_per_cell the mean prediction of each of
            them (float32)
    Raises:
        ValueError: if instance_mask and prediction differ in their number of elements
    """
    labels_flat = np.asarray(instance_mask).reshape(-1).astype(np.int32)  # (h*w)
    pred_flat = np.asarray(prediction).reshape(-1).astype(np.float32)
    if labels_flat.size != pred_flat.size:
        raise ValueError(
            "instance_mask and prediction must have the same number of elements, "
            "got {} and {}".format(labels_flat.size, pred_flat.size)
        )
    # per label: number of pixels and sum of the predictions, indexed by label id
    pixel_counts = np.bincount(labels_flat)
    pred_sums = np.bincount(labels_flat, weights=pred_flat)
    uniques = np.flatnonzero(pixel_counts)
    # discard background by its label; dropping the first entry instead would remove a
    # real cell whenever the mask has no background pixel
    uniques = uniques[uniques != 0]
    mean_per_cell = pred_sums[uniques] / pixel_counts[uniques]
    return [uniques.astype(np.int32), mean_per_cell.astype(np.float32)]


def test_time_aug(input_data, channel, model, rotate=True, flip=True):
    """Predicts with test-time augmentation and averages the re-aligned predictions
    Args:
        input_data (np.array): model input with a leading batch axis (bhwc)
        channel (str): marker name, handed to the preprocessing as "marker"
        model: not used, the predictions are made with the global `app`
        rotate (bool): add rotations by 0, 90, 180 and 270 degrees
        flip (bool): add a left-right and an up-down flip
    Returns:
        np.array: mean over the predictions of all augmentations, with a trailing axis
            of size 1
    """
    forward_augmentations = []
    backward_augmentations = []
    if rotate:
        for k in [0,1,2,3]:
            # k is bound as default argument: a plain closure would look k up when it is
            # called, i.e. after the loop, and all four entries would rotate by k=3
            forward_augmentations.append(lambda x, k=k: tf.image.rot90(x, k=k))
            backward_augmentations.append(lambda x, k=k: tf.image.rot90(x, k=-k))
    if flip:
        # both flips are their own inverse
        forward_augmentations += [
            lambda x: tf.image.flip_left_right(x),
            lambda x: tf.image.flip_up_down(x)
        ]
        backward_augmentations += [
            lambda x: tf.image.flip_left_right(x),
            lambda x: tf.image.flip_up_down(x)
        ]
    if not forward_augmentations:
        # nothing selected: predict on the unchanged input instead of failing in np.stack
        forward_augmentations.append(lambda x: tf.identity(x))
        backward_augmentations.append(lambda x: tf.identity(x))
    input_batch = []
    for forw_aug in forward_augmentations:
        input_data_tmp = forw_aug(input_data).numpy() # bhwc
        # joins the batch entries along the first axis; for a batch of size 1 this only
        # drops the batch axis
        input_batch.append(np.concatenate(input_data_tmp))
    # NOTE(review): np.stack needs equally shaped entries, so the 90/270 degree rotations
    # presumably require square inputs — confirm
    input_batch = np.stack(input_batch, 0)
    seg_map = app._predict_segmentation(input_batch, preprocess_kwargs={"normalize": True, "marker": channel, "normalization_dict": normalization_dict}, batch_size=64)
    tmp = []
    # undo every augmentation on its own prediction before averaging
    for backw_aug, seg_map_tmp in zip(backward_augmentations, seg_map):
        seg_map_tmp = backw_aug(seg_map_tmp[np.newaxis,...])
        seg_map_tmp = np.squeeze(seg_map_tmp)
        tmp.append(seg_map_tmp)
    seg_map = np.stack(tmp, -1)
    seg_map = np.mean(seg_map, axis = -1, keepdims = True)
    return seg_map

# Predict every channel of every fov, average the prediction per cell and collect the
# results in one table with a row per cell and a `<channel>_pred` column per channel.
fov_dict_list = []
for fov_path in fov_paths:
    out_fov_path = os.path.join(os.path.normpath(cell_classification_output_dir), os.path.basename(fov_path))
    fov_dict = {}
    # the segmentation mask is the same for all channels of a fov: read it once, and only
    # when the first valid channel shows up
    instance_mask = None
    # sorted, because os.listdir returns the files in arbitrary order
    for channel in sorted(os.listdir(fov_path)):
        channel_path = os.path.join(fov_path, channel)
        if not channel.endswith(".tiff"):
            continue
        if channel[:2] == "._":
            continue
        channel = channel.split(".")[0]
        if channel in exclude_channels:
            continue
        mplex_img = np.squeeze(imread(channel_path))
        if instance_mask is None:
            instance_path = segmentation_naming_convention(fov_path)
            instance_mask = np.squeeze(imread(instance_path))
        input_data = prepare_input_data(mplex_img, instance_mask)
        if half_resolution:
            scale = 0.5
            h, w = input_data.shape[1:3]
            # cv2.resize expects dsize as (width, height), not (height, width)
            half_size = (int(w*scale), int(h*scale))
            img = cv2.resize(input_data[0,...,0], half_size)
            # nearest neighbour keeps the mask binary
            binary_mask = cv2.resize(input_data[0,...,1], half_size, interpolation=cv2.INTER_NEAREST)
            input_data = np.stack([img, binary_mask], axis=-1)[np.newaxis,...]
        # NOTE(review): `test_time_aug` refers to the function defined above, which rebinds
        # the flag of the same name; the condition is therefore always true and the else
        # branch is never taken
        if test_time_aug:
            prediction = test_time_aug(input_data, channel, model)
        else:
            prediction = app._predict_segmentation(input_data, preprocess_kwargs={"normalize": True, "marker": channel, "normalization_dict": normalization_dict}, batch_size=2)
        prediction = np.squeeze(prediction)
        if half_resolution:
            # back to the size of the segmentation mask, dsize again as (width, height)
            prediction = cv2.resize(prediction, (w, h))
        labels, mean_per_cell = segment_mean(instance_mask, prediction)
        # fov name and cell labels are identical for all channels, store them only once
        if "segmentation_label" not in fov_dict.keys():
            fov_dict["fov"] = [os.path.basename(fov_path)]*len(labels)
            fov_dict["segmentation_label"] = labels
        fov_dict[channel+"_pred"] = mean_per_cell
        if plot_predictions:
            fig, ax = plt.subplots(1,3, figsize=(16,16))
            # plot stuff
            ax[0].imshow(np.squeeze(input_data[...,0]), vmin=0, vmax=np.quantile(input_data[...,0], 0.999))
            ax[0].set_title(channel)
            ax[1].imshow(np.squeeze(input_data[...,1]))
            ax[1].set_title("binary")
            ax[2].imshow(np.squeeze(prediction), vmin=0, vmax=1)
            ax[2].set_title(channel+"_pred")
            for a in ax:
                a.set_xticks([])
                a.set_yticks([])
            plt.tight_layout()
            plt.show()
        if save_predictions:
            os.makedirs(out_fov_path, exist_ok=True)
            # predictions are scaled by 255 and stored as uint8
            pred_int = tf.cast(prediction*255.0, tf.uint8).numpy()
            imwrite(os.path.join(out_fov_path, channel+".tiff"), pred_int, photometric="minisblack", compression="zlib")
    fov_dict_list.append(pd.DataFrame(fov_dict))
cell_table = pd.concat(fov_dict_list, ignore_index=True)
cell_table.to_csv(os.path.join(cell_classification_output_dir, "pred_cell_table.csv"), index=False)