## This notebook takes the output of deepcell, processes it, segments cells, and outputs the extracted channel information

In [None]:
import os
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import skimage.io as io
import xarray as xr

from segmentation.utils import data_utils, segmentation_utils, plot_utils

## This notebook is currently configured as a template to run with the provided example data. If you are running your own data, make a copy of this notebook before modifying it: if you edit this notebook directly, you'll get merge conflicts when updating to the latest version. To create a copy, go to File -> Make a Copy.

In [None]:
# set up file paths
base_dir = "../data/example_dataset/"
input_dir = os.path.join(base_dir, "input_data")
# per-FOV folders of single-channel TIFFs; read below when building expression matrices
tiff_dir = os.path.join(input_dir, 'single_channel_inputs')
# DeepCell label maps are read from here, and derived outputs (xarray, overlays) are written here
label_dir = os.path.join(base_dir, 'deepcell_output')

# points to look at (None for all)
points = None

# validate file paths up front (add extra paths to this list).
# tiff_dir is included because later cells read from it unconditionally; validating it here
# surfaces a bad path immediately instead of partway through processing.
# NOTE(review): assumes validate_paths raises on a missing/invalid path — confirm in data_utils
data_utils.validate_paths([
    base_dir,
    input_dir,
    tiff_dir,
    label_dir,
])

### We compute the paths for the deepcell input TIFFs

In [None]:
# Gather the DeepCell input files: keep only entries of input_dir whose final
# '.'-separated suffix is exactly 'tif' or 'tiff'.
points_input = [
    fname
    for fname in os.listdir(input_dir)
    if fname.split('.')[-1] in ('tif', 'tiff')
]

# When specific points were requested, keep only files whose name before
# '_deepcell_input' is one of them.
if points:
    points_input = [
        fname
        for fname in points_input
        if fname.split('_deepcell_input')[0] in points
    ]

### We can then load the segmentation mask from DeepCell via label-map TIFFs and save it as an xarray

In [None]:
# Read the DeepCell label-map TIFFs from label_dir into a single array whose
# image dimension is named 'compartments' with the entry 'whole_cell'.
# NOTE(review): '_feature_0' is passed as the filename delimiter; presumably it is
# used to derive FOV names from file names — confirm in data_utils.
segmentation_labels = data_utils.load_imgs_from_dir(
    data_dir=label_dir,
    imgdim_name='compartments',
    image_name='whole_cell',
    delimiter='_feature_0',
)

# Persist the labels next to the DeepCell output, deleting any file left over
# from a previous run before writing the new one.
save_name = os.path.join(label_dir, 'segmentation_labels.xr')
if os.path.exists(save_name):
    print("overwriting previously generated processed output file")
    os.remove(save_name)
segmentation_labels.to_netcdf(save_name, format="NETCDF3_64BIT")

### We can also save the segmentation mask overlaid on the imaging data

In [None]:

# Load the DeepCell input multi-TIFFs selected above; these are the images the
# labels get drawn on top of.
input_data_xr = data_utils.load_imgs_from_multitiff(input_dir, multitiff_files=points_input)

# For every FOV, write '<fov>_overlay.tif' into label_dir combining that FOV's
# whole-cell labels with all of its input channels.
for fov in input_data_xr.fovs:
    plot_utils.plot_overlay(
        segmentation_labels.loc[fov, :, :, "whole_cell"].values,
        input_data_xr.loc[fov].values,
        path=os.path.join(label_dir, f'{fov.values}_overlay.tif'),
    )

### Afterwards, we can generate expression matrices from the labeling + imaging data

In [None]:
# if loading your own dataset, make sure all imaging data is in the same folder, with each FOV given its own folder
# All FOVs must have the same channels

# If the TIFs are in a subfolder, specify the name here
img_sub_folder = "TIFs"

if not points:
    # load channel data: default to every sub-directory of tiff_dir whose name starts with "Point"
    all_points = os.listdir(tiff_dir)
    all_points = [point for point in all_points if os.path.isdir(os.path.join(tiff_dir, point))
                  and point.startswith("Point")]
    points = all_points
# sorted() returns a new list, so a user-supplied list is not mutated and a tuple/other
# iterable also works; later cells slice and .loc-index with it, so it must be a list.
points = sorted(points)

# os.path.join works whether or not base_dir ends with a path separator
# (plain string concatenation silently produced a sibling path when it did not)
single_cell_dir = os.path.join(base_dir, "single_cell_output")

# exist_ok avoids the check-then-create race and keeps reruns idempotent
os.makedirs(single_cell_dir, exist_ok=True)

In [None]:
# if loading more data than can fit into memory at once, we loop through in smaller increments.
# Stepping the start index by batch_size also covers a final, partial batch, so no separate
# remainder pass is needed.
batch_size = 5
cohort_len = len(points)
normalized_batches = []
transformed_batches = []

for batch_start in range(0, cohort_len, batch_size):
    current_points = points[batch_start:batch_start + batch_size]
    image_data = data_utils.load_imgs_from_tree(data_dir=tiff_dir, img_sub_folder=img_sub_folder,
                                                fovs=current_points)
    current_labels = segmentation_labels.loc[current_points, :, :, :]

    # segment the imaging data
    normalized_data, transformed_data = segmentation_utils.generate_expression_matrix(
        segmentation_labels=current_labels, image_data=image_data)

    # Collect per-batch frames and concatenate once at the end: DataFrame.append was
    # deprecated in pandas 1.4 and removed in 2.0, and appending in a loop copies every time.
    normalized_batches.append(normalized_data)
    transformed_batches.append(transformed_data)

# pd.concat raises on an empty list, so keep the original "empty DataFrame" result
# when there are no points to process.
combined_normalized_data = pd.concat(normalized_batches) if normalized_batches else pd.DataFrame()
combined_transformed_data = pd.concat(transformed_batches) if transformed_batches else pd.DataFrame()

In [None]:
# Write both expression matrices into single_cell_dir as CSV files; the
# DataFrame index is dropped from the output.
combined_normalized_data.to_csv(
    os.path.join(single_cell_dir, 'normalized_data.csv'),
    index=False,
)
combined_transformed_data.to_csv(
    os.path.join(single_cell_dir, 'transformed_data.csv'),
    index=False,
)