# Processing organoid images

## Import relevant python packages

In [None]:
import os
import errno
import numpy as np
import tensorflow as tf
# Enable memory growth on every visible GPU so TensorFlow allocates GPU memory as
# needed instead of reserving all of it at start-up. Per the TensorFlow docs this has
# to happen before the GPUs are initialized, i.e. before any model is built below.
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)  # show which GPUs TensorFlow can see (empty list -> CPU only)
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)
import deepcell
import imagecodecs

## Load application

In [None]:
from deepcell.model_zoo.panopticnet import PanopticNet
from deepcell.applications import Application
from deepcell_toolbox.processing import histogram_normalization
from deepcell_toolbox.deep_watershed import deep_watershed as watershed_postprocessing

import functools

def preprocess(*args, **kwargs):
    """Histogram-normalize images before they are passed to the model.

    Thin wrapper around ``deepcell_toolbox.processing.histogram_normalization`` that
    uses ``kernel_size=[32, 32]`` unless the caller supplies its own ``kernel_size``.
    Any other keyword arguments are forwarded unchanged. Presumably deepcell's
    ``Application`` forwards ``predict(..., preprocess_kwargs=...)`` here as keyword
    arguments — the previous ``*args``-only signature rejected them with a TypeError.
    """
    kwargs.setdefault('kernel_size', [32, 32])
    return histogram_normalization(*args, **kwargs)

def postprocess(*args, **kwargs):
    """Turn model outputs into a label image with deep-watershed post-processing.

    Thin wrapper around ``deepcell_toolbox.deep_watershed.deep_watershed`` with the
    organoid defaults ``detection_threshold=0.25``, ``distance_threshold=0.1`` and
    ``min_distance=2.5``. Caller-supplied keyword arguments override these defaults and
    are forwarded together with any extra options (presumably deepcell's ``Application``
    passes ``predict(..., postprocess_kwargs=...)`` here; the previous ``*args``-only
    signature rejected them with a TypeError).
    """
    kwargs.setdefault('detection_threshold', 0.25)
    kwargs.setdefault('distance_threshold', 0.1)
    kwargs.setdefault('min_distance', 2.5)
    return watershed_postprocessing(*args, **kwargs)

# Build the PanopticNet architecture used for organoid segmentation and load the
# pretrained weights into it. The constructor arguments presumably have to match the
# architecture the weights were saved from — confirm before changing any of them.
model = PanopticNet('resnet50',
                   input_shape=(256,256,1),  # single-channel 256x256 input tiles
                   norm_method=None,
                   num_semantic_heads=3,
                   num_semantic_classes=[1,1,2],  # output classes per semantic head
                   location=True,
                   include_top=True,
                   interpolation='bilinear',
                   lite=True)
# NOTE(review): cluster-specific absolute path — update it for your environment.
model.load_weights('/scratch/users/jschaepe/pfordyce/Microwells/organoid_resnet50.h5')

class OrganoidSegmenter(Application):
    """deepcell ``Application`` that segments organoids with a PanopticNet model.

    Configured with the module-level ``preprocess`` (histogram normalization) and
    ``postprocess`` (deep watershed) functions and ``model_mpp=0.5``; the model's input
    tile shape is taken from ``model.input_shape``.
    """

    def __init__(self, model=None):
        """Wrap ``model`` (required).

        Raises:
            ValueError: if ``model`` is None.
        """
        if model is None:
            raise ValueError('Provide a model')

        super().__init__(
            model,
            model_image_shape=model.input_shape[1:],
            model_mpp=0.5,
            preprocessing_fn=preprocess,
            postprocessing_fn=postprocess
        )

    def predict(self,
                image,
                batch_size=4,
                image_mpp=None,
                preprocess_kwargs=None,
                postprocess_kwargs=None):
        """Segment ``image`` and return the label output of ``_predict_segmentation``.

        Args:
            image: input image batch (callers in this notebook pass arrays shaped
                (1, H, W, 1)).
            batch_size: number of tiles per model call.
            image_mpp: resolution of ``image``; None means the model's resolution.
            preprocess_kwargs: extra keyword arguments for the preprocessing step
                (default: none).
            postprocess_kwargs: extra keyword arguments for the postprocessing step
                (default: none).
        """
        # None defaults: a ``{}`` default would be one dict shared by every call
        if preprocess_kwargs is None:
            preprocess_kwargs = {}
        if postprocess_kwargs is None:
            postprocess_kwargs = {}

        return self._predict_segmentation(
            image,
            batch_size=batch_size,
            image_mpp=image_mpp,
            preprocess_kwargs=preprocess_kwargs,
            postprocess_kwargs=postprocess_kwargs)

# Notebook-wide segmenter instance; process_well_timecourse calls OS.predict on each
# image. (Not to be confused with the `os` module imported at the top.)
OS = OrganoidSegmenter(model)

## Data processing functions

In [None]:
from skimage.exposure import equalize_adapthist, rescale_intensity
import pandas as pd
import skimage
import fnmatch
import imageio
import datetime
from pathlib import Path
from matplotlib import pyplot as plt
from skimage import io
import time
import h5py

def get_well_info(metadata_file, experiment_id):
    """Return the metadata rows of every well in one experiment.

    Reads ``metadata_file`` (a CSV), keeps the rows whose ``Experiment`` column equals
    ``experiment_id``, appends a ``row_number`` column mapping the ``Well`` label
    (A1..C4, rows A-C and columns 1-4) to its 1-based position, and drops
    ``Experiment``.

    Returns:
        A numpy array with one row per well: the remaining CSV columns in file order
        followed by ``row_number``. Callers unpack each row as
        ``(well_id, mutant, well_number)``, which assumes the CSV has exactly the
        columns Experiment, Well and one mutant column — NOTE(review): confirm.
        An unknown ``experiment_id`` yields an array with zero rows.

    Raises:
        KeyError: if a ``Well`` value is not one of A1..C4.
    """
    metadata = pd.read_csv(metadata_file)
    # .copy(): the column assignment below must target an independent frame, not a
    # filtered view of the original (avoids pandas' SettingWithCopyWarning)
    metadata = metadata[metadata['Experiment'] == experiment_id].copy()
    well_dict = {'A1': 1, 'A2': 2, 'A3': 3, 'A4': 4,
                 'B1': 5, 'B2': 6, 'B3': 7, 'B4': 8,
                 'C1': 9, 'C2': 10, 'C3': 11, 'C4': 12}
    # Series.map(dict.__getitem__) still raises KeyError for unknown wells, and unlike a
    # row-wise DataFrame.apply it also works when no rows matched the experiment
    metadata['row_number'] = metadata['Well'].map(well_dict.__getitem__)
    metadata = metadata.drop(['Experiment'], axis=1)
    return metadata.to_numpy()

def setup_h5file(experiment_folder, experiment_id, experiment_date, well_info,
                 label_dtype=np.int32):
    """Create ``<experiment_id>_predicted_images.h5`` with empty prediction datasets.

    The file is created (mode 'w', overwriting) in the current working directory. For
    each ``(well_id, mutant, well_number)`` row of ``well_info``, a group ``well_id`` is
    created with one dataset per image in the well's stack, named '0', '1', ... Each
    dataset is shaped ``(n_timepoints, height, width)``; height and width are taken from
    the upscaled images ``load_images`` returns for the well's first timepoint.

    Args:
        experiment_folder: folder with one sub-folder per well (must end with '/').
        experiment_id: prefix of the output file name.
        experiment_date: date suffix of the per-timepoint .tif/.csv files.
        well_info: rows as returned by ``get_well_info``.
        label_dtype: dtype of the stored label images. Defaults to int32 so label ids
            above 255 survive (a uint8 dataset cannot represent them).

    Returns:
        The open ``h5py.File``; the caller is responsible for closing it. If setup fails
        part-way, the file is closed before the exception propagates, so no unreachable
        open handle is left behind.
    """
    print('setting up h5file...')
    title = experiment_id + '_predicted_images.h5'
    h5file = h5py.File(title, 'w')
    try:
        for well_id, mutant, well_number in well_info:
            input_folder = experiment_folder + well_id + '/'
            timepoints = read_timepoints(input_folder, well_id)
            h5group = h5file.create_group(well_id)
            # load the first timepoint's stack to learn how many images it holds and
            # how large the (upscaled) predictions will be
            mCherry_imagePath = input_folder + 'mCherry/timepoint_{0}-{1}-'.format(
                timepoints[0], well_number) + experiment_date + '.tif'
            imgs, n = load_images(mCherry_imagePath)
            # load_images returns arrays shaped (1, height, width, 1)
            height, width = imgs[0].shape[1:3]
            for i in range(n):
                # scaleoffset=True on integer data lets HDF5 pick the minimum number of
                # bits losslessly, so the wider dtype stays compact on disk
                h5group.create_dataset(name=str(i), dtype=label_dtype,
                                       shape=(len(timepoints), height, width),
                                       chunks=True, compression='gzip',
                                       scaleoffset=True, shuffle=True)
    except BaseException:
        # also on KeyboardInterrupt: nothing else holds a reference to this file yet
        h5file.close()
        raise
    return h5file

def update_h5file(h5file, label_imgs_arr, well_id, j, n):
    """Store the label images of one timepoint of one well.

    For every stack index ``i`` in ``range(n)``, ``label_imgs_arr[i]`` is written into
    row ``j`` of the dataset ``h5file[well_id][str(i)]``.

    Returns:
        The same ``h5file`` object that was passed in.
    """
    for stack_index in range(n):
        stack_key = str(stack_index)
        target = h5file[well_id][stack_key]
        target[j, :, :] = label_imgs_arr[stack_index, :, :]
    return h5file

def load_images(filepath):
    """Load a TIFF and prepare every image in it for ``OrganoidSegmenter.predict``.

    Each 2280x2280 image is converted to float32, upscaled 2x with
    ``skimage.transform.rescale`` (to 4560x4560), and given a leading batch axis and a
    trailing channel axis.

    Args:
        filepath: path to a TIFF holding one 2280x2280 image or a stack of them
            (images along the first axis).

    Returns:
        ``(loaded_imgs, n)``: a list of ``n`` arrays shaped (1, 4560, 4560, 1) and the
        number of images in the file.

    Raises:
        ValueError: if the images are not 2280x2280.
    """
    # import the submodule explicitly: a bare ``import skimage`` does not guarantee that
    # ``skimage.transform`` has been loaded
    import skimage.transform

    imgs = io.imread(filepath)
    if imgs.ndim == 2:
        # a single-page TIFF is read as one 2-D image; treat it as a stack of one
        imgs = imgs[np.newaxis, :, :]
    if imgs.shape[1:] != (2280, 2280):
        # str() the shape: concatenating the tuple itself raises TypeError and would
        # hide this message
        raise ValueError('incorrect image size: ' + str(imgs.shape))
    loaded_imgs = []
    for img in imgs:
        img = skimage.transform.rescale(np.float32(img), 2)
        # (H, W) -> (1, H, W, 1): batch and channel axes expected by predict()
        loaded_imgs.append(img[np.newaxis, :, :, np.newaxis])
    return loaded_imgs, imgs.shape[0]

def read_timepoints(input_folder, well_id):
    """Return the timepoint names listed in ``<input_folder><well_id>_timepoints.txt``.

    The file holds comma-separated timepoint names. Single quotes and spaces are
    removed, '-' is replaced by '_', surrounding whitespace (e.g. the file's trailing
    newline) is stripped from each entry, and empty entries are dropped.

    Args:
        input_folder: the well's folder; concatenated as-is, so it must end with '/'.
        well_id: well label, e.g. 'A1'.

    Returns:
        list of timepoint name strings, in file order.
    """
    cleanup = str.maketrans({"'": None, ' ': None, '-': '_'})
    # ``with`` closes the file handle instead of leaking it
    with open(input_folder + well_id + '_timepoints.txt', 'r') as timepoints_file:
        content = timepoints_file.read().translate(cleanup)
    # strip() removes newlines/tabs that translate() keeps; otherwise a trailing newline
    # ends up inside the last timepoint name and breaks the file paths built from it
    return [entry.strip() for entry in content.split(',') if entry.strip()]

def get_slices(xrange, yrange):
    """Map a chamber's slice in the original image to bounds in the 2x-upscaled prediction.

    Only the first element of ``xrange``/``yrange`` is used: it is snapped down to a
    multiple of the 115-pixel window, doubled, and offset by one. The returned stop is
    229 (= 2 * 115 - 1) past the start.

    Returns:
        ``(xstart, xstop, ystart, ystop)`` as ints.
    """
    window_size = 115

    def upscaled_bounds(axis_range):
        # window-aligned start in the upscaled image, then the fixed-width stop
        first = int(2 * window_size * np.floor(axis_range[0] / window_size) + 1)
        return first, first + 2 * window_size - 1

    x_bounds = upscaled_bounds(xrange)
    y_bounds = upscaled_bounds(yrange)
    return x_bounds + y_bounds

def microwell_id_label(row):
    """Return the ``'<x>_<y>_<stack_indexer>'`` identifier of one summary row.

    ``row`` is anything indexable by column name, e.g. the row Series that
    ``DataFrame.apply(..., axis=1)`` passes in; each value is converted with ``str``.
    """
    return '_'.join(str(row[column]) for column in ('x', 'y', 'stack_indexer'))

def _timepoint_path(input_folder, timepoint, well_number, experiment_date, extension):
    """Return the mCherry .csv/.tif path of one well at one timepoint."""
    return (input_folder + 'mCherry/timepoint_{0}-{1}-'.format(timepoint, well_number)
            + experiment_date + extension)


def _chamber_statistics(mCherry_well):
    """Return ``(cell_count, cell_areas, total_area, centroid_x, centroid_y)`` of a label crop.

    Label 0 is background. ``cell_areas`` is a list of pixel counts per non-zero label in
    ascending label order. The centroids are the mean column (x) and row (y) index of all
    labelled pixels; they are NaN (numpy emits a RuntimeWarning) when there are no cells.
    """
    foreground = mCherry_well != 0
    # a single pass over the crop instead of one full comparison per label
    cells, counts = np.unique(mCherry_well[foreground], return_counts=True)
    areas = counts.tolist()
    rows, cols = np.nonzero(foreground)
    return len(cells), areas, np.sum(areas), np.average(cols), np.average(rows)


def process_well_timecourse(input_folder, experiment_date, well_id, well_number, mutant,
                            h5file, experiment_id=None):
    """Segment every timepoint of one well, store predictions and per-chamber statistics.

    For each timepoint listed by ``read_timepoints``: loads the mCherry image stack,
    predicts label images with the global ``OS`` segmenter, writes them to ``h5file`` via
    ``update_h5file``, and computes per chamber (one row of the timepoint's CSV) the cell
    count, per-cell areas, total area and labelled-pixel centroid. All timepoints are
    written to ``<experiment_id>_<well_id>_processed_summary.csv`` in the working
    directory.

    Args:
        input_folder: the well's folder (must end with '/').
        experiment_date: date suffix of the per-timepoint .tif/.csv files.
        well_id, well_number, mutant: identify the well; copied into the summary.
        h5file: open HDF5 file prepared by ``setup_h5file``.
        experiment_id: experiment name. When omitted, falls back to the notebook-level
            global ``experiment_id`` (the previous implicit behaviour).

    Returns:
        ``h5file`` (as returned by ``update_h5file``).

    Raises:
        ValueError: if the well has no timepoints.
    """
    if experiment_id is None:
        experiment_id = globals()['experiment_id']
    start = time.time()
    prevt = start
    timepoints = read_timepoints(input_folder, well_id)
    frames = []
    for j, timepoint in enumerate(timepoints):
        # print statements to see progress
        print('well_id: ', well_id, ', loop: ', j, ', timepoint: ', timepoint)
        csv_path = _timepoint_path(input_folder, timepoint, well_number, experiment_date, '.csv')
        print(csv_path)
        mCherry_reimport = pd.read_csv(csv_path)
        print('loading image...')
        imgs, n = load_images(
            _timepoint_path(input_folder, timepoint, well_number, experiment_date, '.tif'))
        print('predicting...')
        label_imgs_arr = np.asarray([OS.predict(img)[0, ..., 0] for img in imgs])

        # save predicted arrays to the h5 file
        h5file = update_h5file(h5file, label_imgs_arr, well_id, j, n)

        n_chambers = len(mCherry_reimport)
        cell_count = np.zeros(n_chambers)
        cell_areas = []
        total_area = np.zeros(n_chambers)
        centroid_x = np.zeros(n_chambers)
        centroid_y = np.zeros(n_chambers)
        x_slice = []
        y_slice = []

        # one CSV row per chamber (microwell)
        for i in range(n_chambers):
            chamberInfo = mCherry_reimport.iloc[i]
            # NOTE(review): eval executes text read from the CSV — only run this on
            # trusted files. ast.literal_eval would be safer if the stored slices are
            # plain literals (lists/tuples) rather than e.g. ``range(...)`` expressions.
            x_slice.append(eval(chamberInfo.summaryImg_xslice))
            y_slice.append(eval(chamberInfo.summaryImg_yslice))
            # the predicted image is 2x the original, so the slices must be rescaled
            xstart, xstop, ystart, ystop = get_slices(x_slice[i], y_slice[i])
            mCherry_well = label_imgs_arr[chamberInfo.stack_indexer, xstart:xstop, ystart:ystop]
            (cell_count[i], areas, total_area[i],
             centroid_x[i], centroid_y[i]) = _chamber_statistics(mCherry_well)
            cell_areas.append(areas)

        data = {'timepoint': np.full(n_chambers, timepoint),
                'x': mCherry_reimport.x.to_numpy(),
                'y': mCherry_reimport.y.to_numpy(),
                'stack_indexer': mCherry_reimport.stack_indexer.to_numpy(),
                'cell_count': cell_count,
                'cell_areas': cell_areas,
                'total_area': total_area,
                'centroid_x': centroid_x,
                'centroid_y': centroid_y,
                'hash_str': mCherry_reimport.hash_str.to_numpy(),
                'experiment_id': np.full(n_chambers, experiment_id),
                'well_number': np.full(n_chambers, well_number),
                'well_id': np.full(n_chambers, well_id),
                'mutant': np.full(n_chambers, mutant),
                'x_slice': x_slice,
                'y_slice': y_slice}
        # collect per-timepoint frames; DataFrame.append was removed in pandas 2.0
        frames.append(pd.DataFrame(data))

        # print and keep track of time to monitor each loop
        currt = time.time()
        print('loop time: ', currt - prevt)
        prevt = currt

    if not frames:
        raise ValueError('no timepoints found for well ' + str(well_id))
    df = pd.concat(frames, ignore_index=True)
    # add microwell id to every row
    df.insert(0, 'microwell_id', df.apply(microwell_id_label, axis=1).to_numpy(), False)
    df.to_csv(experiment_id + '_' + well_id + '_processed_summary.csv')
    return h5file

def analyze_experiment(well_info, experiment_folder, experiment_id, experiment_date, h5file):
    """Run ``process_well_timecourse`` for every well of an experiment.

    Each row of ``well_info`` is unpacked as ``(well_id, mutant, well_number)``; the
    well's files are read from ``experiment_folder + well_id + '/'``. The HDF5 handle
    returned by each call is passed to the next one and finally returned.
    NOTE(review): ``experiment_id`` is accepted but not used in this function.
    """
    print('processing experiment...')
    for row in well_info:
        well_id, mutant, well_number = row
        well_folder = experiment_folder + well_id + '/'
        h5file = process_well_timecourse(well_folder, experiment_date, well_id,
                                         well_number, mutant, h5file)
    return h5file


## Process organoid microwell experiment
To process a new experiment, you only need to update `experiment_id`, `experiment_folder`, `experiment_date` and `metadata_file`. Each loop over one well at one timepoint takes roughly 100 s. The total expected run time is therefore roughly `(# wells)*(# timepoints per well)*(# stacks per well)*11` seconds, i.e. about 11 s per image stack. If the processing cell does not run all the way through, you may need to run the `h5file.close()` cell directly below this text before running the processing cell again. You can disregard the warnings printed during the first loop.

This script assumes the following file structure: <br>
> data <br>
> > experiment_id <br>
> > > well_id <br>
> > > > well_id_timepoints.txt <br>
> > > > mCherry <br>
> > > > > timepoint_experiment_id_timepoint-well_number-experiment_date.csv <br>
> > > > > timepoint_experiment_id_timepoint-well_number-experiment_date.tif <br>


In [None]:
# If the processing cell below did not finish running, the HDF5 file it opened may still
# be open; run this cell to close it before trying again. It does nothing if no `h5file`
# has been created yet in this session (e.g. in a freshly started kernel).
if 'h5file' in globals():
    h5file.close()

In [None]:
# These are the only four values that need to be changed to process different experiments
experiment_id = 'experiment_id' # fill this in with your experiment folder name
# path to folder where you saved timelapse images from timecourse processing
experiment_folder = '/home/users/jschaepe/scratch/pfordyce/data/' + experiment_id + '/'
# date that is at the end of each tif or csv in that experiment
experiment_date = '20201104'
# path to the CSV file with experiment and well metadata (read by get_well_info)
metadata_file = '/home/users/jschaepe/scratch/pfordyce/Microwells/microwell_well_info.csv'

well_info = get_well_info(metadata_file, experiment_id)
# create h5file for storing predicted images
h5file = setup_h5file(experiment_folder, experiment_id, experiment_date, well_info)
try:
    h5file = analyze_experiment(well_info, experiment_folder, experiment_id, experiment_date, h5file)
finally:
    # close the HDF5 file even if processing fails part-way, so the cell can simply be
    # re-run without closing the file by hand first
    h5file.close()

## Plotting specific well or microwell prediction

In [None]:
def plot_well(experiment_id, timepoint, well_id, stack_index, experiment_folder):
    """Display the full predicted label image of one stack of a well at one timepoint.

    Args:
        experiment_id: prefix of ``<experiment_id>_predicted_images.h5`` in the working
            directory.
        timepoint: timepoint name as returned by ``read_timepoints``
            (e.g. '20200306_185209').
        well_id: well label, e.g. 'A1'.
        stack_index: index of the image within the well's stack (h5 dataset name).
        experiment_folder: folder with one sub-folder per well (must end with '/').

    Raises:
        ValueError: if ``timepoint`` is not listed in the well's timepoints file.
    """
    timepoints = read_timepoints(experiment_folder + well_id + '/', well_id)
    try:
        # list.index raises for an unknown timepoint; np.argmax of an all-False mask
        # would silently return 0 and plot the first timepoint instead
        timepoint_index = timepoints.index(timepoint)
    except ValueError:
        raise ValueError('timepoint {!r} not found for well {}'.format(timepoint, well_id)) from None

    h5file_name = experiment_id + '_predicted_images.h5'
    # the context manager closes the file even if the read fails
    with h5py.File(h5file_name, 'r') as h5file:
        img = h5file[well_id][str(stack_index)][timepoint_index, :, :]

    plt.figure(figsize=(30, 30))
    plt.imshow(img)
    plt.title(experiment_id + ', ' + well_id)
    plt.xlabel('pixels of predicted image')
    plt.ylabel('pixels of predicted image')
    plt.show()

# Change these parameters to plot the full predicted image of one stack of a well
# at its first and last timepoints.
experiment_id = 'experiment_id' # fill this in with your experiment folder name
experiment_folder = '/home/users/jschaepe/scratch/pfordyce/data/' + experiment_id + '/'
well_id = 'A1'
# index of the image within the well's stack (the dataset name in the h5 file)
stack_index = 0
# timepoint names as returned by read_timepoints (dashes replaced by underscores)
first_timepoint = '20200306_185209'
last_timepoint = '20200314_114229'
plot_well(experiment_id, first_timepoint, well_id, stack_index, experiment_folder)
plot_well(experiment_id, last_timepoint, well_id, stack_index, experiment_folder)



In [None]:
def plot_microwell(experiment_id, timepoint, well_id, microwell_id, experiment_folder):
    """Display the predicted labels of one microwell at one timepoint.

    Looks up the microwell's slices in ``<experiment_id>_<well_id>_processed_summary.csv``,
    crops the matching region of the 2x-upscaled prediction from
    ``<experiment_id>_predicted_images.h5``, prints its number of labelled pixels and
    plots it.

    Args:
        microwell_id: identifier of the form 'x_y_stackindex', e.g. '5_11_0'.
        timepoint: timepoint name as returned by ``read_timepoints``.

    Raises:
        ValueError: if the timepoint or the microwell is not found.
    """
    df = pd.read_csv(experiment_id + '_' + well_id + '_processed_summary.csv')
    timepoints = read_timepoints(experiment_folder + well_id + '/', well_id)
    try:
        # list.index raises for an unknown timepoint instead of silently using index 0
        timepoint_index = timepoints.index(timepoint)
    except ValueError:
        raise ValueError('timepoint {!r} not found for well {}'.format(timepoint, well_id)) from None

    microwell = df[df['microwell_id'] == microwell_id]
    if microwell.empty:
        raise ValueError('microwell {!r} not found in the summary of well {}'.format(microwell_id, well_id))
    # NOTE(review): eval executes text from the summary CSV; only use on trusted files
    x_range = list(eval(microwell['x_slice'].iloc[0]))
    y_range = list(eval(microwell['y_slice'].iloc[0]))
    xstart, xstop, ystart, ystop = get_slices(x_range, y_range)
    # stack index = everything after the last '_' (the last character alone is wrong
    # for stack indices >= 10)
    stack_index = microwell_id.rsplit('_', 1)[-1]

    h5file_name = experiment_id + '_predicted_images.h5'
    with h5py.File(h5file_name, 'r') as h5file:
        img = h5file[well_id][stack_index][timepoint_index, xstart:xstop, ystart:ystop]
    # number of labelled (non-zero) pixels in the crop
    print(np.count_nonzero(img))

    plt.figure()
    plt.imshow(img)
    plt.title(experiment_id + ', ' + well_id + ', ' + microwell_id)
    plt.xlabel('pixels of predicted image')
    plt.ylabel('pixels of predicted image')
    plt.show()

# Change these parameters to plot a single microwell at one timepoint.
experiment_id = 'experiment_id' # fill this in with your experiment folder name
experiment_folder = '/home/users/jschaepe/scratch/pfordyce/data/' + experiment_id + '/'
well_id = 'A1'
# in the form of 'x_y_stackindex'
microwell_id = '5_11_0'
# timepoint name as returned by read_timepoints (dashes replaced by underscores)
timepoint = '20200314_114229'
plot_microwell(experiment_id, timepoint, well_id, microwell_id, experiment_folder)

In [None]:
def plot_microwell_all_timepoints(experiment_id, well_id, microwell_id, experiment_folder,
                                  save_path=None, ncols=10):
    """Plot one microwell's predicted labels at every timepoint in a grid.

    Args:
        experiment_id: prefix of the summary CSV and ``_predicted_images.h5`` files.
        well_id: well label, e.g. 'A1'.
        microwell_id: identifier of the form 'x_y_stackindex', e.g. '5_11_0'.
        experiment_folder: folder with one sub-folder per well (must end with '/').
        save_path: if given, the figure is also saved there, e.g.
            ``experiment_id + '_' + well_id + '_' + microwell_id + '_all_timepoints.png'``.
        ncols: number of panels per grid row.

    Raises:
        ValueError: if the microwell is not in the well's summary CSV.
    """
    df = pd.read_csv(experiment_id + '_' + well_id + '_processed_summary.csv')
    timepoints = read_timepoints(experiment_folder + well_id + '/', well_id)
    microwell = df[df['microwell_id'] == microwell_id]
    if microwell.empty:
        raise ValueError('microwell {!r} not found in the summary of well {}'.format(microwell_id, well_id))
    # NOTE(review): eval executes text from the summary CSV; only use on trusted files
    x_range = list(eval(microwell['x_slice'].iloc[0]))
    y_range = list(eval(microwell['y_slice'].iloc[0]))
    xstart, xstop, ystart, ystop = get_slices(x_range, y_range)
    # stack index = everything after the last '_' (not just the last character)
    stack_index = microwell_id.rsplit('_', 1)[-1]

    h5file_name = experiment_id + '_predicted_images.h5'
    with h5py.File(h5file_name, 'r') as h5file:
        # read from the requested well (this used to be hard-coded to 'A1')
        dataset = h5file[well_id][stack_index]
        crops = [dataset[i, xstart:xstop, ystart:ystop] for i in range(len(timepoints))]

    nrows = int(np.ceil(len(timepoints) / ncols))
    # squeeze=False keeps axs 2-D even when everything fits in a single row
    fig, axs = plt.subplots(nrows, ncols, figsize=(50, 50), squeeze=False)
    for ax, timepoint, img in zip(axs.flat, timepoints, crops):
        ax.imshow(img)
        ax.set_title(timepoint, fontsize=16)
    # remove the unused panels at the end of the last row
    for ax in axs.flat[len(timepoints):]:
        fig.delaxes(ax)

    fig.suptitle(experiment_id + ', ' + well_id + ', ' + microwell_id, fontsize=50, y=0.9)
    if save_path is not None:
        # save before plt.show(), which may close the figure and leave nothing to save
        fig.savefig(save_path)
    plt.show()

# Change these parameters to plot one microwell at every timepoint of its well.
experiment_id = 'experiment_id' # fill this in with your experiment folder name
experiment_folder = '/home/users/jschaepe/scratch/pfordyce/data/' + experiment_id + '/'
well_id = 'A1'
# NOTE: `timepoint` is not passed to plot_microwell_all_timepoints, which plots every
# timepoint listed in the well's timepoints file
timepoint = '20200314_114229'
# in the form of 'x_y_stackindex'
microwell_id = '5_11_0'

plot_microwell_all_timepoints(experiment_id, well_id, microwell_id, experiment_folder)




# Interactive plots
HoloViews and Bokeh (used here through hvPlot) are useful for creating interactive plots.

In [None]:
import hvplot.pandas
import holoviews as hv

In [None]:
# Holoviews plot

def plot_total_area_and_centroids(experiment_id, well_id,
                                  save_path='cell_count_and_location_over_timepoints.html'):
    """Build interactive per-timepoint plots of total area and centroids for one well.

    Reads ``<experiment_id>_<well_id>_processed_summary.csv``, keeps the chambers with at
    least one cell, and lays out a histogram of ``total_area`` next to a hexbin of
    ``centroid_x``/``centroid_y``, both grouped by ``timepoint``.

    Args:
        save_path: HTML file the layout is saved to. The default keeps the original
            fixed name, which every call overwrites; pass a per-well name to keep several.

    Returns:
        The HoloViews layout (it is displayed when it is a cell's last expression).
    """
    df = pd.read_csv(experiment_id + '_' + well_id + '_processed_summary.csv')
    df_only_cells = df[df.cell_count > 0]
    p1 = df_only_cells.hvplot.hist('total_area', groupby='timepoint', min_height=0,
                                   width=400, height=400)
    p2 = df_only_cells.hvplot.hexbin(gridsize=10, x='centroid_x', y='centroid_y',
                                     groupby='timepoint', width=500, height=400)
    layout = p1 + p2

    hv.save(layout, save_path)
    return layout

# Build the interactive total-area histogram and centroid hexbin for one well (the
# function also saves them to an HTML file); the bare `layout` line displays them.
well_id = 'A1'
experiment_id = 'experiment_id' # fill this in with your experiment folder name
layout = plot_total_area_and_centroids(experiment_id, well_id)
layout

## Datashader plots
Datashader is useful for plotting and saving large datasets quickly.

In [None]:
import datashader as ds
import colorcet
import datashader.transfer_functions as tf

In [None]:
def datashader_growth_rate(experiment_id, well_id):
    """Render all chambers' ``total_area`` trajectories of one well as a datashader image.

    Reads ``<experiment_id>_<well_id>_processed_summary.csv`` and builds one line per
    position within a timepoint (the k-th summary row of every timepoint). Presumably
    that is one microwell per line, since rows appear in the same order at every
    timepoint — TODO confirm. The lines are aggregated with ``ds.count()`` on a
    1000x400 canvas and shaded with histogram equalization.

    Returns:
        The shaded datashader image.
    """
    # import under a distinct name: this notebook binds the global ``tf`` both to
    # TensorFlow and to datashader.transfer_functions, so the global depends on the
    # order in which cells were run
    import datashader.transfer_functions as ds_tf

    df = pd.read_csv(experiment_id + '_' + well_id + '_processed_summary.csv')
    timepoints = df['timepoint'].unique()
    # rows: timepoints; columns: k-th row within each timepoint (cumcount + 1)
    toplot = pd.pivot_table(df, index=['timepoint'],
                            columns=df.groupby(['timepoint']).cumcount().add(1),
                            values=['total_area'], aggfunc='sum')
    toplot.columns = toplot.columns.map('{0[0]}{0[1]}'.format)
    toplot = toplot.reset_index()
    toplot = toplot.drop('timepoint', axis=1)
    toplot.columns = list(range(len(toplot.columns)))
    # transpose: one row per trajectory, one integer-named column per timepoint
    toplot = toplot.T
    toplot.columns = list(range(len(timepoints)))

    # NOTE(review): only the first len(timepoints) - 1 timepoint columns are drawn, so the
    # last timepoint is excluded — confirm whether that is intentional
    points = len(timepoints) - 1
    x_positions = np.linspace(0, 1, points)
    cvs = ds.Canvas(plot_height=400, plot_width=1000)
    agg = cvs.line(toplot, x=x_positions, y=list(range(points)), agg=ds.count(), axis=1)
    return ds_tf.shade(agg, how='eq_hist')

# change these two to plot all growth rates in well
# (each line presumably is one microwell's total_area over the timepoints)
well_id = 'A1'
experiment_id = 'experiment_id' # fill this in with your experiment folder name

img = datashader_growth_rate(experiment_id, well_id)
# the cell's last expression displays the shaded image
img