## This notebook takes the output of deepcell, processes it, segments cells, and outputs the extracted channel information

In [None]:
import numpy as np
import os
import skimage.io as io
import matplotlib.pyplot as plt
import sys

sys.path.append("../")

from segmentation.utils import data_utils, segmentation_utils
import xarray as xr

## This notebook is configured as a template that runs on the provided example data. If you are running your own data, first make a copy of this notebook (File -> Make a Copy) and modify the copy.

In [None]:
# set up filepaths

# name of the deepcell output subfolder, also used as the prefix of the files inside it; no need to change this
base_name = "deepcell_output"

# root folder of the dataset; the deepcell network output is read from its base_name subfolder
base_dir = '../data/example_dataset/'
deepcell_dir = os.path.join(base_dir, base_name)

## The output of the deepcell network is first smoothed to avoid oversplitting

In [None]:
# parameters for data extraction:
# candidate amounts of smoothing to apply to the watershed image; each is passed on to save_deepcell_tifs
pixel_smooth = [4, 6, 8, 10]

In [None]:
# read the raw deepcell pixel output and hand it to save_deepcell_tifs together with the smoothing values above
raw_pixel_path = os.path.join(deepcell_dir, base_name + "_pixel.xr")
pixel_xr_data = xr.load_dataarray(raw_pixel_path)
pixel_xr_data.name = base_name + "_pixel"
data_utils.save_deepcell_tifs(pixel_xr_data, save_path=deepcell_dir, transform='pixel',
                              pixel_smooth=pixel_smooth)

In [None]:
# read back the processed deepcell output (presumably written by save_deepcell_tifs above) and the network input
processed_path = os.path.join(deepcell_dir, '{}_pixel_processed.xr'.format(base_name))
input_path = os.path.join(base_dir, "input_data/deepcell_input.xr")
pixel_xr = xr.load_dataarray(processed_path)
input_xr = xr.load_dataarray(input_path)

### We can then plot specific points and compare the smoothing levels to assess which one performs best

In [None]:
# FOV to display, and which smoothed channel of the processed output to show for it
point, smooth = "Point8", "pixel_interior_smoothed_8"

In [None]:
# show the selected smoothed channel for the selected point
smoothed_image = pixel_xr.loc[point, :, :, smooth]
plt.figure(figsize=(13, 13))
plt.imshow(smoothed_image)

### There are a few key tunable parameters for performing watershed

In [None]:
# Threshold applied to the deepcell probability mask to tell cell from background:
# lower values keep more pixels as cell, higher values keep fewer.
background_threshold = 0.35

# Smoothing level chosen after inspecting the plots above. It should be one of the values in the
# pixel_smooth list used earlier, since the matching "pixel_interior_smoothed_<n>" channel is selected.
pixel_smooth = 6

# Channels plotted to assess segmentation accuracy; change these to match your dataset.
# Each inner list is one plot: a single channel is plotted by itself, several channels are plotted together.
overlay_channels = [
    ["HH3"],
    ["Membrane"],
    ["HH3", "Membrane"],
]

# None for whole-cell prediction; set a single value when doing nuclear prediction
nuclear_expansion = None

### We then segment the data

In [None]:
# One output folder per parameter combination, so runs with different settings do not overwrite each other.
# os.path.join avoids the doubled separator that string concatenation produced when base_dir already ends
# in "/"; the trailing "" keeps a trailing separator on the path, as the original path had.
segmentation_dir = os.path.join(
    base_dir,
    "segmentation_output_threshold_{}_smooth_{}_expansion_{}".format(
        background_threshold, pixel_smooth, nuclear_expansion),
    "")
# exist_ok replaces the separate isdir check
os.makedirs(segmentation_dir, exist_ok=True)

In [None]:
# watershed over the processed deepcell output
segmentation_utils.watershed_transform(pixel_xr=pixel_xr, channel_xr=input_xr, 
                                       background_threshold=background_threshold,
                                       pixel_smooth="pixel_interior_smoothed_{}".format(pixel_smooth),
                                       overlay_channels=overlay_channels, output_dir=segmentation_dir, 
                                       rescale_factor=1.5, nuclear_expansion=nuclear_expansion)

In [None]:
# load segmentation generated by watershed
# load the label array from segmentation_dir (expected to have been written by watershed_transform above)
labels_file = '{}_pixel_processed_segmentation_labels.xr'.format(base_name)
segmentation_labels = xr.load_dataarray(os.path.join(segmentation_dir, labels_file))


### We can then visualize the segmented mask generated by the watershed

In [None]:
point = 'Point8'  # FOV whose segmentation is displayed in the cells below

In [None]:
# show the watershed label image for the selected point
label_image = segmentation_labels.loc[point, :, :, "segmentation_label"]
plt.figure(figsize=(13, 13))
plt.imshow(label_image)

### We can also visualize the segmented mask overlaid on the imaging data

In [None]:
# Overlay of the segmentation on the first entry of overlay_channels for the selected point.
# Uses `point` rather than a hard-coded "Point8" so this plot follows the point chosen above.
# NOTE(review): overlay_channels[0] is a list, so it is formatted into the file name as e.g. "['HH3']";
# this assumes watershed_transform names its overlay files the same way -- confirm.
overlay_name = "{}_{}_overlay.tiff".format(point, overlay_channels[0])
plt.figure(figsize=(13, 13))
plt.imshow(plt.imread(os.path.join(segmentation_dir, overlay_name)))

### Once you're happy with the segmentation parameters, we extract the single-cell data

In [None]:
# If loading your own dataset, make sure all imaging data is in the same folder, with each FOV given its own folder.
# All FOVs must have the same channels.

# If the TIFs are in a subfolder of each FOV folder, specify its name here
tif_folder = 'TIFs'

# Load channel data: every subfolder of points_folder is treated as one FOV.
# NOTE(review): this folder is "Input_Data" while the deepcell input above is read from "input_data";
# these differ on case-sensitive file systems -- confirm which one matches your dataset.
points_folder = os.path.join(base_dir, "Input_Data")
# sorted so the batches below are built in a reproducible order (os.listdir order is arbitrary)
points = sorted(fov for fov in os.listdir(points_folder)
                if os.path.isdir(os.path.join(points_folder, fov)))

# One output folder per parameter combination. os.path.join keeps this path correct even when
# base_dir has no trailing slash (plain concatenation would glue the folder name onto it).
single_cell_dir = os.path.join(
    base_dir,
    "single_cell_output_threshold_{}_smooth_{}_expansion_{}".format(
        background_threshold, pixel_smooth, nuclear_expansion))
os.makedirs(single_cell_dir, exist_ok=True)

In [None]:
# if loading more data than can fit into memory at once, we loop through in smaller increments
batch_size = 5
cohort_len = len(points)
num_batch = int(np.floor(cohort_len / batch_size))

for i in range(num_batch):
    current_points = points[i * batch_size:(i + 1) * batch_size]
    image_data = data_utils.load_tifs_from_points_dir(point_dir=points_folder, tif_folder=tif_folder, 
                                                      points=current_points)
    current_labels = segmentation_labels.loc[current_points, :, :, :]
    
    # segment the imaging data
    segmentation_utils.extract_single_cell_data(segmentation_labels=current_labels, image_data=image_data,
                                              save_dir=single_cell_dir)

# if batch did not divide evenly into total, process remainder
if cohort_len % batch_size != 0:
    current_points = points[num_batch * batch_size:]
    image_data = data_utils.load_tifs_from_points_dir(point_dir=points_folder, tif_folder=tif_folder, 
                                                      points=current_points)
    current_labels = segmentation_labels.loc[current_points, :, :, :]
    
    # segment the imaging data
    segmentation_utils.extract_single_cell_data(segmentation_labels=current_labels, image_data=image_data,
                                              save_dir=single_cell_dir)

In [None]:
# combine CSV files together
csv_files = os.listdir(single_cell_dir)
csv_files = [x for x in csv_files if 'transformed' in x]

segmentation_utils.concatenate_csv(base_dir=single_cell_dir, csv_files=csv_files)