In [None]:
# Set paths and initialize model
import os
import re
import time
import torch

import numpy as np
import pandas as pd

from scipy.spatial import KDTree
from cellpose import models, core, io

# Paths to the annotated test data. REPO_ROOT may be overridden with the
# LS_EVAL_ROOT environment variable so the notebook can run on another machine;
# the default reproduces the original absolute paths.
REPO_ROOT = os.environ.get('LS_EVAL_ROOT', '/home/elyse/Documents/GitHub/LS_evaluation_tool')
base_dir = os.path.join(REPO_ROOT, 'test_data', 'images') + os.sep
output_root = os.path.join(REPO_ROOT, 'test_data', 'cellpose_train')

# create the output directory (and any missing parents); no-op if it already exists
os.makedirs(output_root, exist_ok = True)

# have cellpose log run
io.logger_setup()

# initialize the pretrained 'nuclei' model, using the GPU if cellpose reports one is usable
use_GPU = core.use_gpu()
model = models.CellposeModel(gpu = use_GPU, model_type = 'nuclei')

In [None]:
# Load custom functions
import warnings
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

from bs4 import BeautifulSoup
from collections import defaultdict
from pycm import ConfusionMatrix


# function to turn .xml files into x, y, z coordinates
def get_locations(path):
    # read xml file
    with open(path, 'r') as f:
        data = f.read()
    # get xml data in iterable format
    markers = BeautifulSoup(data, features='xml').find_all('Marker')
    cell_loci = []
    # iterate and output marker locations
    for marker in markers:
        coords = re.findall('[0-9]+', marker.text)
        coords = [int(x) for x in coords]
        cell_loci.append(coords)
    
    return cell_loci

def _as_points(points):
    """Return `points` as a float array of coordinate rows; empty input gives shape (0, 3).

    The (0, 3) shape lets empty annotation/prediction lists be concatenated with
    the 3-column (x, y, z) coordinate arrays used by annot_pred_overlap.
    """
    arr = np.asarray(points, dtype = float)
    if arr.size == 0:
        return arr.reshape(0, 3)
    return arr


# prediction and annotation overlap and build dataframe for validation metrics
# max distance a prediction can be off and still count as cell. very conservative at 1 px
def annot_pred_overlap(blocks, max_dist, type_2D, train = False, get_keys = None):
    """Match annotated points to predicted centroids and tabulate the result.

    Each annotated point (blocks[key]['detect'] plus the optional
    blocks[key]['reject']) is paired with the closest predicted centroid within
    `max_dist`. A single prediction may still be claimed by several annotations.
    Predictions that no annotation claimed are appended as extra rows.

    Class codes in the 'annotated' / 'predicted' columns:
    1 = cell (detect), 2 = noise (reject), 3 = absent from that source.

    Parameters
    ----------
    blocks : dict
        Per-block dicts holding 'detect' (and optionally 'reject') annotation
        coordinates, plus prediction lists keyed '<prefix><type_2D>' and
        '<prefix><type_2D>_mask'.
    max_dist : float
        Maximum Euclidean distance (px) between an annotation and a prediction
        for the two to be matched.
    type_2D : str
        Method suffix of the prediction keys, e.g. 'flow'.
    train : bool, default False
        Read 'train_detect_' / 'train_reject_' predictions instead of
        'pred_detect_' / 'pred_reject_'.
    get_keys : list of str, optional
        Blocks to evaluate. None or an empty list evaluates every block.

    Returns
    -------
    pandas.DataFrame
        Float columns x, y, z, block, annotated, predicted, mask_id. mask_id is
        0 for rows without a matched prediction.
    """
    columns = ['x', 'y', 'z', 'block', 'annotated', 'predicted', 'mask_id']

    if train:
        type_key = ['train_detect_', 'train_reject_']
    else:
        # This was ['pred_detect_', 'pred_detect_'], which counted the detect
        # predictions a second time as "rejected" and ignored the real rejects.
        type_key = ['pred_detect_', 'pred_reject_']

    key_list = list(get_keys) if get_keys else list(blocks.keys())

    frames = []
    for key in key_list:
        block = blocks[key]
        block_id = float(int(key))

        # annotated points: type 1 = detect, type 2 = reject (the reject file is optional)
        annot_detect = _as_points(block['detect'])
        annot_reject = _as_points(block.get('reject', []))
        annot_points = np.concatenate((annot_detect, annot_reject), axis = 0)
        annot_types = np.concatenate((np.ones((len(annot_detect), 1)),
                                      np.full((len(annot_reject), 1), 2.0)), axis = 0)

        # predicted points: type 1 = detect predictions, type 2 = reject predictions
        pred_detect = _as_points(block[type_key[0] + type_2D])
        pred_reject = _as_points(block[type_key[1] + type_2D])
        pred_points = np.concatenate((pred_detect, pred_reject), axis = 0)
        pred_types = np.concatenate((np.ones((len(pred_detect), 1)),
                                     np.full((len(pred_reject), 1), 2.0)), axis = 0)
        masks = np.concatenate((np.asarray(block[type_key[0] + type_2D + '_mask'], dtype = float).reshape(-1),
                                np.asarray(block[type_key[1] + type_2D + '_mask'], dtype = float).reshape(-1)))

        # 3 = annotated but not predicted; overwritten below for matched annotations
        pred_id = np.full((len(annot_points), 1), 3.0)
        mask_id = np.zeros((len(annot_points), 1))
        matched = np.zeros(len(pred_points), dtype = bool)

        if len(annot_points) and len(pred_points):
            # neighbours[i] lists every prediction within max_dist of annotation i
            neighbours = KDTree(annot_points).query_ball_tree(KDTree(pred_points), r = max_dist)
            for c, idx in enumerate(neighbours):
                if not idx:
                    continue
                # the neighbour list is not ordered by distance, so take the closest
                dists = np.linalg.norm(pred_points[idx] - annot_points[c], axis = 1)
                best = idx[int(np.argmin(dists))]
                pred_id[c] = pred_types[best]
                mask_id[c] = masks[best]
                matched[best] = True

        rows = [np.concatenate((annot_points,
                                np.full((len(annot_points), 1), block_id),
                                annot_types,
                                pred_id,
                                mask_id), axis = 1)]

        # predictions that no annotation matched: annotated = 3, mask_id = 0
        extra = ~matched
        n_extra = int(extra.sum())
        if n_extra:
            rows.append(np.concatenate((pred_points[extra],
                                        np.full((n_extra, 1), block_id),
                                        np.full((n_extra, 1), 3.0),
                                        pred_types[extra],
                                        np.zeros((n_extra, 1))), axis = 1))

        frames.append(pd.DataFrame(np.vstack(rows), columns = columns))

    if not frames:
        return pd.DataFrame(columns = columns)
    # concatenate once instead of growing a DataFrame inside the loop
    return pd.concat(frames)

# Identify prediction and annotation overlap and compute validation metrics
# Key: 1: signal that is detected and considered to be a cell
#      2: signal that is detected but found to be noise and rejected
#      3: signal that is annotated but not predicted to be signal
def _metric_to_float(value):
    """Return a pycm statistic as float, mapping undefined values (e.g. the string 'None') to 0.0."""
    try:
        return float(value)
    except (TypeError, ValueError):
        return 0.0


def get_performance(data, method, trained, get_keys = None, max_dist = 10):
    """Precision, recall and F1 of class 1 (cell) for matching distances 1..max_dist px.

    Parameters
    ----------
    data : dict
        Block dictionary passed to annot_pred_overlap.
    method : str
        Prediction key suffix, e.g. 'flow'.
    trained : bool
        Use the 'train_*' predictions instead of the 'pred_*' ones.
    get_keys : list of str, optional
        Blocks to evaluate. None or empty evaluates every block.
    max_dist : int, default 10
        Largest matching distance (px) to sweep, starting from 1.

    Returns
    -------
    numpy.ndarray of float, shape (3 * max_dist, 2)
        Rows cycle precision (PPV), recall (TPR) and F1 for each distance in
        increasing order. Column 0 is the score and column 1 the distance.
        Undefined scores are 0.
    """
    # annot_pred_overlap treats an empty list as "all blocks"
    keys = list(get_keys) if get_keys else []

    perf_metrics = []
    # a different algorithm may need more slack, so sweep the distance
    for dist in range(1, max_dist + 1):
        flow_df = annot_pred_overlap(data, dist, method, trained, keys)
        cm_flow = ConfusionMatrix(flow_df['annotated'].values, flow_df['predicted'].values, classes = [1.0, 2.0, 3.0])
        for stat in ('PPV', 'TPR', 'F1'):
            perf_metrics.append([_metric_to_float(cm_flow.class_stat[stat][1.0]), dist])

    # Build a float array directly. Mixing 'None' strings into np.asarray produced
    # a string array, so argmax / plotting compared scores as text.
    return np.asarray(perf_metrics, dtype = float)

# plots detection performance as the distance from the annotated centroid is varied
def plot_performance(performance, method = 'flow'):
    """Plot precision, recall and F1 against the matching distance.

    Parameters
    ----------
    performance : array-like, shape (3 * n_dist, 2)
        Output of get_performance. Rows cycle precision, recall and F1 for each
        distance; column 0 is the score and column 1 the distance in px.
    method : str, default 'flow'
        Detection method shown in the title and the printed summary. This is
        now explicit instead of being read from a notebook global.
    """
    # accept string arrays from older get_performance output as well
    performance = np.asarray(performance, dtype = float)
    colors = ['r', 'g', 'b']
    mets = ['Precision', 'Recall', 'F1 Score']

    f1_dists = performance[2::3, 1]
    f1_scores = performance[2::3, 0]
    opt_idx = int(np.argmax(f1_scores))

    fig, ax = plt.subplots(figsize = (4, 4))
    fig.suptitle('Distance between annotated and predicted Cell')

    for i in range(3):
        sns.lineplot(x = performance[i::3, 1], y = performance[i::3, 0], color = colors[i], label = mets[i], ax = ax, zorder = i + 2)

    # axes setup and the best-F1 marker only need to be drawn once
    ax.set_xlabel('Max distance (px)')
    ax.set_ylabel('Metric score')
    ax.set_title(method + ' detection')
    ax.set_ylim(0.2, 1)
    ax.axvline(x = f1_dists[opt_idx], color = 'k', linestyle = '--', zorder = 1)

    sns.despine()
    plt.tight_layout(rect = [0, 0, 1, 0.95])

    print('Top F1 Scores for {0} detection: {1}'.format(method, f1_scores[opt_idx]))

In [None]:
# Get annotated data and image file paths.
# blocks[block_num] holds, per annotation block, the 'signal' / 'background'
# image paths and the 'detect' / 'reject' marker coordinates.
blocks = defaultdict(dict)

# loop through base directory and collect all data from "block_#" subfolders
for root, dirs, files in os.walk(base_dir):
    if 'block' not in root:
        continue
    # sorted so the result does not depend on filesystem listing order
    for file in sorted(files):
        # The first run of digits is the block number. '[0-9]' alone would read
        # block 12 as '1'.
        block_num = re.findall('[0-9]+', file)[0]
        file_path = os.path.join(root, file)
        if 'signal' in file:
            blocks[block_num]['signal'] = file_path
        elif 'background' in file:
            blocks[block_num]['background'] = file_path
        elif 'detect' in file:
            blocks[block_num]['detect'] = get_locations(file_path)
        else:
            # any other file in a block folder is read as the reject annotations
            blocks[block_num]['reject'] = get_locations(file_path)

# print info on blocks (a reject annotation file is optional)
for key in blocks.keys():
    rejected = len(blocks[key].get('reject', []))
    print('Pulled annotation data for block {0}: {1} cells and {2} noncells.'.format(str(key), len(blocks[key]['detect']), rejected))

In [None]:
# There are 3 ways to run 3D Cellpose. The first 2 extrapolate 2D to 3D.
# The third way uses the new Cellpose3D repo in another script.
import skimage.io
from skimage.measure import regionprops

# Set parameters
# grayscale=0, R=1, G=2, B=3
# channels = [cytoplasm, nucleus]; use [0, 0] for grayscale images
channels = [0,0]
prob_thresh = 0.9     # passed to model.eval as cellprob_threshold
stitch_thresh = 0.5   # only used by the 'stitch' method
diameter = 9 # unit = px
min_size = 60
detect_times = defaultdict(int)
method = 'flow'

# fail early instead of hitting an undefined `masks` below
if method not in ('flow', 'stitch'):
    raise ValueError("method must be 'flow' or 'stitch', got {0!r}".format(method))

# masks and flows are saved here; created once, before the loop
output_path = os.path.join(output_root, method)
os.makedirs(output_path, exist_ok = True)

start_detect = time.time()
for key in blocks.keys():
    for ch in ['signal', 'background']:
        fname = os.path.splitext(os.path.basename(blocks[key][ch]))[0]
        img = skimage.io.imread(blocks[key][ch])

        print(fname)

        if method == 'flow':
            masks, flows, styles = model.eval(img, channels = channels, diameter = diameter, cellprob_threshold = prob_thresh,
                                              do_3D = True, min_size = min_size)
        else:
            # 'stitch': 2D evaluation combined into 3D via stitch_threshold
            masks, flows, styles = model.eval(img, channels = channels, diameter = diameter, cellprob_threshold = prob_thresh,
                                              do_3D = False, stitch_threshold = stitch_thresh)

        # Centroid of every mask, reversed to match the x, y, z column order of
        # annot_pred_overlap, plus the mask's own label. The voxel at a truncated
        # centroid can lie outside a non-convex mask, so indexing `masks` there
        # could return 0 or a neighbouring label.
        # see: https://github.com/MouseLand/cellpose/issues/337
        centroids = []
        mask_num = []
        for region in regionprops(masks):
            location = [int(x) for x in region.centroid]
            centroids.append(location[::-1])
            mask_num.append(region.label)

        prefix = 'pred_detect_' if ch == 'signal' else 'pred_reject_'
        blocks[key][prefix + method] = centroids
        blocks[key][prefix + method + '_mask'] = mask_num

        # save masks as tif
        io.save_masks(img, masks, flows, blocks[key][ch], png = False, tif = True, channels = channels,
                      save_flows = True, save_outlines = True, savedir = output_path)

        # save all data for plotting (`diams` was never defined; pass the diameter used)
        io.masks_flows_to_seg(img, masks, flows, diameter, os.path.join(output_path, fname), channels=None)

detect_times[method] = time.time() - start_detect

In [None]:
# report detections per block and the time each method took
for method in ['flow']:
    print('Using ' + method + ' method for 2D to 3D')
    for key in blocks.keys():
        n_cells = len(blocks[key]['pred_detect_' + method])
        n_noncells = len(blocks[key]['pred_reject_' + method])
        print('Classified {0} cells and {1} noncells for annotation block {2}.'.format(n_cells, n_noncells, str(key)))

    print('Detection and Classification via ' + method + ' took {0} seconds.\n'.format(detect_times[method]))

In [None]:
# Score the pretrained (not retrained) model against every annotated block,
# then plot precision / recall / F1 across matching distances.
trained = False
method = 'flow'
performance = get_performance(blocks, method, trained = trained)
plot_performance(performance)

In [None]:
## Train Model

In [None]:
# Train the model using only masks whose centroids lie within `dist` px of an
# annotated cell in the training block
dist = 1
train_block = '4'

# Restrict matching to the training block: mask ids are per-image labels, so ids
# taken from other blocks would select unrelated masks in this block's labels.
flow_df = annot_pred_overlap(blocks, dist, 'flow', get_keys = [train_block])
is_match = (flow_df['annotated'] == 1.0) & (flow_df['predicted'] == 1)
train_ids = np.unique(flow_df.loc[is_match, 'mask_id'].values)
train_ids = train_ids[train_ids > 0]  # 0 is background, never a cell mask

# _seg.npy written for the training block's signal image by the detection cell
signal_name = os.path.splitext(os.path.basename(blocks[train_block]['signal']))[0]
mask_path = os.path.join(output_root, 'flow', signal_name + '_seg.npy')
model_path = os.path.join(output_root, 'model')
os.makedirs(model_path, exist_ok = True)

# keep only masks that are within the allotted threshold
# NOTE: allow_pickle can execute pickled code; only load _seg.npy files this notebook wrote
seg_file = np.load(mask_path, allow_pickle = True).item()
mask_array = np.where(np.isin(seg_file['masks'], train_ids), seg_file['masks'], 0)
train_array = skimage.io.imread(blocks[train_block]['signal'])

# verify that the number of kept ids equals the number of labels left (background excluded)
n_kept = np.count_nonzero(np.unique(mask_array))
print('{0} cells identified within {1} px of annotated cells. Mask array contains {2} unique masks.'.format(len(train_ids),
                                                                                                            dist,
                                                                                                            n_kept))

# format needs to be list of arrays with shape (nchan x Ly x Lx), one per plane along axis 0
train_imgs = [plane[np.newaxis] for plane in train_array]
train_masks = [plane[np.newaxis] for plane in mask_array]

# Train a fresh copy so `model` keeps the pretrained nuclei weights. train()
# updates the instance it is called on, which would silently change later
# re-runs of the detection cell.
train_model = models.CellposeModel(gpu = use_GPU, model_type = 'nuclei')
train_model.train(train_data = train_imgs, train_labels = train_masks, channels = [0, 0],
                  save_path = model_path, model_name = 'block_' + train_block + '_model')

In [None]:
# Rerun detection with the trained model (only block 4 was used for training)
train_block = '4'
# cellpose's train() writes weights to <save_path>/models/<model_name>, and the
# training cell used save_path = <output_root>/model
model_path = os.path.join(output_root, 'model', 'models', 'block_' + train_block + '_model')
model_trained = models.CellposeModel(gpu = use_GPU, pretrained_model = model_path)

channels = [0,0]
prob_thresh = 0.0
diameter = 5 # unit = px
min_size = 10
detect_times = defaultdict(int)
method = 'flow'
key_root = ['train_detect_', 'train_reject_']

# outputs go in the same folder as the untrained run, under '_trained' file names
output_path = os.path.join(output_root, method)
os.makedirs(output_path, exist_ok = True)

start_detect = time.time()
for kr, ch in enumerate(['signal', 'background']):
    fname = os.path.splitext(os.path.basename(blocks[train_block][ch]))[0] + '_trained'
    img = skimage.io.imread(blocks[train_block][ch])

    masks, flows, styles = model_trained.eval(img, channels = channels, diameter = diameter, cellprob_threshold = prob_thresh,
                                              do_3D = True, min_size = min_size)

    # Centroid of every mask (reversed to x, y, z order) and the mask's own label.
    # The voxel at a truncated centroid may lie outside a non-convex mask.
    # see: https://github.com/MouseLand/cellpose/issues/337
    centroids = []
    mask_num = []
    for region in regionprops(masks):
        location = [int(x) for x in region.centroid]
        centroids.append(location[::-1])
        mask_num.append(region.label)

    blocks[train_block][key_root[kr] + method] = centroids
    blocks[train_block][key_root[kr] + method + '_mask'] = mask_num

    # Save masks as tif under the '_trained' name. Passing the source image path
    # would reuse its base name and overwrite the untrained results saved earlier
    # to this folder.
    io.save_masks(img, masks, flows, os.path.join(output_path, fname), png = False, tif = True, channels = channels,
                  save_flows = True, save_outlines = True, savedir = output_path)

    # save all data for plotting (`diams` was never defined; pass the diameter used)
    io.masks_flows_to_seg(img, masks, flows, diameter, os.path.join(output_path, fname), channels=None)

# record the run time so the summary cell does not report 0 seconds
detect_times[method] = time.time() - start_detect


In [None]:
# report detections for the retrained block and the time the run took
for method in ['flow']:
    print('Using ' + method + ' method for 2D to 3D')
    for key in ['4']:
        n_cells = len(blocks[key]['train_detect_' + method])
        n_noncells = len(blocks[key]['train_reject_' + method])
        print('Classified {0} cells and {1} noncells for annotation block {2}.'.format(n_cells, n_noncells, str(key)))

    print('Detection and Classification via ' + method + ' took {0} seconds.\n'.format(detect_times[method]))

In [None]:
# Score the retrained model on block 4 only. This is the block it was trained
# on, so these metrics are in-sample. Then plot the metric curves.
get_keys = ['4']
trained = True
method = 'flow'
performance = get_performance(blocks, method, trained, get_keys = get_keys)
plot_performance(performance)