# Inference

In [1]:
!pip install deepcell

Collecting deepcell
  Downloading DeepCell-0.12.9.tar.gz (147 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m147.3/147.3 kB[0m [31m7.6 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l- \ | / - \ | done
[?25h  Getting requirements to build wheel ... [?25l- done
[?25h  Preparing metadata (pyproject.toml) ... [?25l- done
Collecting tensorflow~=2.8.0 (from deepcell)
  Obtaining dependency information for tensorflow~=2.8.0 from https://files.pythonhosted.org/packages/e9/c6/b2e8c2e6c8537775f30eac2ee93b7e375a9bac9abec87b624415cc2d8cab/tensorflow-2.8.4-cp310-cp310-manylinux2010_x86_64.whl.metadata
  Downloading tensorflow-2.8.4-cp310-cp310-manylinux2010_x86_64.whl.metadata (2.9 kB)
Collecting tensorflow-addons~=0.16.1 (from deepcell)
  Obtaining dependency information for tensorflow-addons~=0.16.1 from https://files.pythonhosted.org/packages/91/7a/371dc8fc995ecfc6680cbbefb9467a2fdba45e5905beeb5f2fd2f533fe06/ten

In [2]:
from scipy.spatial.qhull import QhullError
from scipy import spatial
spatial.QhullError = QhullError

import os
import sys

import numpy as np
import tensorflow as tf

from deepcell.applications import NuclearSegmentation, CellTracking
from deepcell_tracking.trk_io import load_trks, save_trk

import copy
import imageio
import matplotlib as mpl
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt



In [3]:
source_data = '/kaggle/input/dynamicnuclearnet-tracking-v1-0/test.trks'

data_dir = '/kaggle/working/benchmarking/DeepCell/data'
gt_seg_dir = os.path.join(data_dir, 'SEG_GT')
pred_seg_dir = os.path.join(data_dir, 'SEG_PRED')

image_dir = '/kaggle/working/benchmarking/images'
os.makedirs(image_dir)

for d in [data_dir, gt_seg_dir, pred_seg_dir]:
    if not os.path.exists(d):
        os.makedirs(d)
        
model_urls = {
    'NuclearSegmentation': 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearSegmentation-75.tar.gz',
    'NuclearTrackingNE': 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingNE-75.tar.gz',
    'NuclearTrackingInf': 'https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingInf-75.tar.gz'
}

In [4]:
%%writefile /kaggle/working/benchmarking/utils.py
import os

import numpy as np
from tifffile import imwrite

from deepcell_tracking.isbi_utils import trk_to_isbi
from deepcell_tracking.utils import contig_tracks


def find_zero_padding(X):
    """Remove zero padding to avoid adverse effects on model performance"""
    # Calculate position of padding based on first frame
    # Assume that padding is in blocks on the edges of image
    good_rows = np.where(X[0].any(axis=0))[0]
    good_cols = np.where(X[0].any(axis=1))[0]

    slc = (
        slice(None),
        slice(good_cols[0], good_cols[-1] + 1),
        slice(good_rows[0], good_rows[-1] + 1),
        slice(None)
    )

    return slc


def save_ctc_raw(exp_dir, batch, X):
    raw_dir = os.path.join(exp_dir, '{:03}'.format(batch))
    
    if not os.path.exists(raw_dir):
        os.makedirs(raw_dir)
        
    # Save each frame as a tiff file
    for i in range(X.shape[0]):
        imwrite(os.path.join(raw_dir, 't{:03}.tif'.format(i)), X[i])
        

def save_ctc_gt(exp_dir, batch, y, lineage):
    gt_dir = os.path.join(exp_dir, '{:03}_GT'.format(batch))
    seg_dir = os.path.join(gt_dir, 'SEG')
    tra_dir = os.path.join(gt_dir, 'TRA')
    
    for d in [gt_dir, seg_dir, tra_dir]:
        if not os.path.exists(d):
            os.makedirs(d)
            
    # Save lineage to isbi txt
    df = trk_to_isbi(lineage)
    df.to_csv(os.path.join(tra_dir, 'man_track.txt'), sep=' ', header=False, index=False)
    
    # Save each frame as a tiff file
    for i in range(y.shape[0]):
        imwrite(os.path.join(seg_dir, 'man_seg{:03}.tif'.format(i)), y[i].astype('uint16'))
        imwrite(os.path.join(tra_dir, 'man_track{:03}.tif'.format(i)), y[i].astype('uint16'))
        
            
def save_ctc_res(exp_dir, batch, y, lineage=None, seg=False):
    if seg:
        name = '{:03}_SEG_RES'.format(batch)
    else:
        name = '{:03}_RES'.format(batch)
    res_dir = os.path.join(exp_dir, name)
    
    if not os.path.exists(res_dir):
        os.makedirs(res_dir)
        
    # Save lineage to isbi txt
    if lineage:
        df = trk_to_isbi(lineage)
        df.to_csv(os.path.join(res_dir, 'res_track.txt'), sep=' ', header=False, index=False)
    
    # Save each frame as a tiff file
    for i in range(y.shape[0]):
        imwrite(os.path.join(res_dir, 'mask{:03}.tif'.format(i)), y[i].astype('uint16'))
        
        
def convert_to_contiguous(y, lineage):
    done_labels = []
    while set(done_labels) != set(lineage.keys()):
        leftover_labels = [l for l in lineage.keys() if l not in done_labels]
        for label in leftover_labels:
            lineage, y = contig_tracks(label, lineage, y)
            done_labels.append(label)
    
    return y, lineage

Writing /kaggle/working/benchmarking/utils.py


In [5]:
sys.path.append('/kaggle/working/benchmarking')
import utils

In [6]:
def get_iou(boxA, boxB):
# determine the (x, y)-coordinates of the intersection rectangle
    xA = max(boxA[0], boxB[0])
    yA = max(boxA[1], boxB[1])
    xB = min(boxA[2], boxB[2])
    yB = min(boxA[3], boxB[3])

    interArea = max(0, xB - xA + 1) * max(0, yB - yA + 1)

    boxAArea = (boxA[2] - boxA[0] + 1) * (boxA[3] - boxA[1] + 1)
    boxBArea = (boxB[2] - boxB[0] + 1) * (boxB[3] - boxB[1] + 1)

    iou = interArea / float(boxAArea + boxBArea - interArea)
    # return the intersection over union value
    return iou


def mask_iou(curr_box, curr_label, curr_image, prev_box, prev_label, prev_image):
    new_prev_image = np.zeros(prev_image.shape)
    new_curr_image = np.zeros(curr_image.shape)
    
    curr_pos = np.where(curr_image == curr_label)
    prev_pos = np.where(prev_image == prev_label)
#     print(curr_pos)
#     print(prev_pos)
    
    for i in range(len(curr_pos[0])):
        new_curr_image[curr_pos[0][i]][curr_pos[1][i]][curr_pos[2][i]] = 1
    for i in range(len(prev_pos[0])):
        new_prev_image[prev_pos[0][i]][prev_pos[1][i]][prev_pos[2][i]] = 1
    
    intersection = (new_curr_image * new_prev_image).sum()
    if intersection == 0:
        return 0.0
    union = np.logical_or(new_curr_image, new_prev_image).astype(int).sum()
    return intersection / union



# get bounding boxes for all cells in one image
def get_all_box(image):
    label_box_dict = {}
    unique_labels = np.unique(image)
    
    for label in unique_labels:
        positions = np.where(image == label)
        if image[positions[0][0]][positions[1][0]][positions[2][0]] != 0:
            box = [min(positions[0]), min(positions[1]), max(positions[0]), max(positions[1])]
            label_box_dict[int(label)] = box
    
    return label_box_dict



def get_keys_with_same_value(relation_dict):
    rev_multidict = {}
    relation_dict_copy = relation_dict.copy()
    for key, value in relation_dict_copy.items():
        rev_multidict.setdefault(value, set()).add(key)
    
    remove_keys = []
    for key, value in rev_multidict.items():
        if len(value) == 1:
            remove_keys.append(key)
        else:
            for v in value:
                relation_dict_copy.pop(v)
            
    for key in remove_keys:
        rev_multidict.pop(key)
        
    return rev_multidict, relation_dict_copy



def map_new_label_to_image(ori_image, tar_image, prev_label, curr_label):
    ori_pos = np.where(ori_image == prev_label)
    for i in range(len(ori_pos[0])):
        tar_image[ori_pos[0][i]][ori_pos[1][i]][ori_pos[2][i]] = curr_label
    return tar_image

In [7]:
# Initialize a dictionary by the first frame
# each key represent the id of a cell
# the format for each frame is (frame_id, bbox, intensity)

def init_dict(X, y):
    
    trk_dict = {'tracks': {}, 'X': X.copy(), 'y': y.copy(), 'y_tracked': y.copy()}
    prev_box_dict = get_all_box(y[len(y) - 1])
    y_tracked = np.zeros(y[0].shape)
    
    label_rev = 1
    
    for label, box in prev_box_dict.items():
        trk_dict['tracks'][label_rev] = {
                                        'label': label_rev,
                                        'frames': [len(y) - 1],
                                        'daughters': [],
                                        'capped': False,
                                        'frame_div': None,
                                        'parent': None,
                                    }
        y_tracked = map_new_label_to_image(trk_dict['y_tracked'][len(y) - 1], y_tracked, label, label_rev)
        label_rev += 1
        
    prev_box_dict = get_all_box(y_tracked)
    trk_dict['y_tracked'][len(y) - 1] = y_tracked
    
    return trk_dict, prev_box_dict
        
    

def finalize_dict(init_trk_dict, prev_box_dict, X, y):
    
    trk_dict = init_trk_dict
    next_label = int(max(prev_box_dict.keys())) + 1

    
    # Iterate over all frames
    for frame_id in range(len(y)-2, -1, -1):
        prev_frame_id = frame_id + 1
        y_tracked = np.zeros(y[0].shape)
        curr_box_dict = get_all_box(trk_dict['y_tracked'][frame_id])
        relation_dict = {}
    
        curr_labels = curr_box_dict.keys()
        present_dict = {}
        for label in curr_labels:
            present_dict[label] = False
            
        # go through all detected cells in the previous frame
        for prev_label in prev_box_dict:
            prev_bbox = prev_box_dict[prev_label]
            
            best_iou = 0
            best_id = 0
            best_bbox = []
#                 best_mask = None
            
            # go through all detected cells in the current frame
            for curr_label in curr_box_dict:
                curr_bbox = curr_box_dict[curr_label]
                box_iou = get_iou(prev_bbox, curr_bbox)

                # see if one of the box in the current frame intersects with one in the previous frame
                if box_iou > 0:
                    iou = mask_iou(curr_bbox, curr_label, trk_dict['y_tracked'][frame_id], prev_bbox, prev_label, trk_dict['y_tracked'][prev_frame_id])

                    # choose the one that has the largest intersection
                    if iou > best_iou:
                        best_iou = iou
                        best_label = curr_label
                        best_bbox = curr_bbox
                
                # if we find a cell in the current frame that does intersect one in the previous frame,
                # we assume that they are the same cell
                if best_iou > 0:
                    relation_dict[prev_label] = best_label
        
        for prev_label, best_label in relation_dict.items():
            present_dict[best_label] = True
            
        # process cells without divisions
        rev_dict_more_values, rev_dict_one_value = get_keys_with_same_value(relation_dict)

        for prev_label, curr_label in rev_dict_one_value.items():
            trk_dict['tracks'][prev_label]['frames'].insert(0, frame_id)
            y_tracked = map_new_label_to_image(trk_dict['y_tracked'][frame_id], y_tracked, curr_label, prev_label)
        
        # process cells with divisions
#         if len(rev_dict_more_values) >0 :
#             print('merge occurs at label:', next_label)
        for curr_label, prev_labels in rev_dict_more_values.items():
            children_list = list(prev_labels)
            trk_dict['tracks'][next_label] = {
                                        'label': next_label,
                                        'frames': [frame_id],
                                        'daughters': [int(x) for x in children_list],
                                        'capped': False,
                                        'frame_div': frame_id,
                                        'parent': None,
                                    }
            y_tracked = map_new_label_to_image(trk_dict['y_tracked'][frame_id], y_tracked, curr_label, next_label)
            
            for prev_label in prev_labels:
                trk_dict['tracks'][prev_label]['parent'] = next_label
                
            next_label += 1
            
            
        # process unseen cells
        for curr_label, flag in present_dict.items():
            if not flag:
                y_tracked = map_new_label_to_image(trk_dict['y_tracked'][frame_id], y_tracked, curr_label, next_label)
                trk_dict['tracks'][next_label] = {
                                    'label': next_label,
                                    'frames': [frame_id],
                                    'daughters': [],
                                    'capped': False,
                                    'frame_div': None,
                                    'parent': None,
                                }
                next_label += 1
        
        trk_dict['y_tracked'][frame_id] = y_tracked
        prev_box_dict = get_all_box(y_tracked)

    
    return trk_dict



def backtracking(X, y):
    init_trk_dict, prev_box_dict = init_dict(X, y)
    track_gt = finalize_dict(init_trk_dict, prev_box_dict, X, y)
    return track_gt

In [8]:
def plot(x, y, ymax):
    yy = copy.deepcopy(y)
    yy = yy.astype(np.float64)
    yy = np.ma.masked_equal(yy, 0)
    yy /= ymax / 30
    
    fig, ax = plt.subplots(1, 2, figsize=(12, 6))
    ax[0].imshow(x, cmap='Greys_r')
    ax[0].axis('off')
    ax[0].set_title('Raw')
    ax[1].imshow(x, cmap='Greys_r')
    ax[1].imshow(yy)
    ax[1].set_title('Tracked')
    ax[1].axis('off')

    fig.canvas.draw()  # draw the canvas, cache the renderer
    image = np.frombuffer(fig.canvas.tostring_rgb(), dtype='uint8')
    image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    plt.close(fig)

    return image



def generate_images(x, y_tracked, filename):
    ymax = np.max(y_tracked)
#     cmap = shuffle_colors(ymax, 'tab20')
    imageio.mimsave(
        filename,
        [plot(x[i,...,0], y_tracked[i,...,0], ymax)
         for i in range(y_tracked.shape[0])]
    )

In [9]:
# Load test data
data = load_trks(source_data)

In [10]:
# Download and load each model
models = {}
for m, url in model_urls.items():
    archive_path = tf.keras.utils.get_file(f'{m}.tgz', url, extract=True, cache_subdir='models')
    model_path = os.path.splitext(archive_path)[0]
    model = tf.keras.models.load_model(model_path)
    models[m] = model

Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearSegmentation-75.tar.gz
Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingNE-75.tar.gz
Downloading data from https://deepcell-data.s3-us-west-1.amazonaws.com/saved-models/NuclearTrackingInf-75.tar.gz


In [11]:
# Load segmentation and tracking applications
app_seg = NuclearSegmentation(models['NuclearSegmentation'])
app_trk = CellTracking(models['NuclearTrackingInf'], models['NuclearTrackingNE'])

In [12]:
for batch_no in range(len(data['lineages'])):
    print('batch number:', batch_no)
    
    # Pull out relevant data for this batch
    X = data['X'][batch_no]
    y = data['y'][batch_no]
    lineage = data['lineages'][batch_no]
    
    # Correct discontiguous tracks, which are not allowed by CTC
    y, lineage = utils.convert_to_contiguous(y, lineage)
    
    # Determine position of zero padding for removal
    slc = utils.find_zero_padding(X)
    X = X[slc]
    y = y[slc]

    # Determine which frames are zero padding
    frames = np.sum(y, axis=(1,2)) # True if image not blank
    good_frames = np.where(frames)[0]
    X = X[:len(good_frames)]
    y = y[:len(good_frames)]
    
    print(y.shape)
    
    # Generate tracks on GT segmentations
#     track_gt = app_trk.track(X, y)
    track_gt = backtracking(X, y)
    track_gt['y_tracked'], track_gt['tracks'] = utils.convert_to_contiguous(track_gt['y_tracked'], track_gt['tracks'])
    utils.save_ctc_res(gt_seg_dir, batch_no + 1, track_gt['y_tracked'][..., 0], track_gt['tracks'])
    utils.save_ctc_gt(gt_seg_dir, batch_no + 1, y[..., 0], lineage)
    
    generate_images(X, track_gt['y_tracked'], './benchmarking/images/tracks-gt-' + str(batch_no) + '.tiff')
    
    # Generate tracks on predicted segmentations
    y_pred = app_seg.predict(X)
#     track_pred = app_trk.track(X, y_pred)
    track_pred = backtracking(X, y_pred)
    track_pred['y_tracked'], track_pred['tracks'] = utils.convert_to_contiguous(track_pred['y_tracked'], track_pred['tracks'])
    utils.save_ctc_res(pred_seg_dir, batch_no + 1, track_pred['y_tracked'][..., 0], track_pred['tracks'])
    utils.save_ctc_gt(pred_seg_dir, batch_no + 1, y[..., 0], lineage)
    
    generate_images(X, track_pred['y_tracked'], './benchmarking/images/tracks-pred-' + str(batch_no) + '.tiff')

batch number: 0
(42, 540, 540, 1)
batch number: 1
(42, 540, 540, 1)
batch number: 2
(42, 540, 540, 1)
batch number: 3
(42, 540, 540, 1)
batch number: 4
(42, 540, 540, 1)
batch number: 5
(50, 584, 584, 1)
batch number: 6
(50, 584, 584, 1)
batch number: 7
(71, 512, 512, 1)
batch number: 8
(71, 568, 600, 1)
batch number: 9
(65, 540, 540, 1)
batch number: 10
(45, 540, 540, 1)
batch number: 11
(55, 540, 540, 1)


# Evaluation

In [13]:
!wget http://public.celltrackingchallenge.net/software/EvaluationSoftware.zip
!unzip EvaluationSoftware.zip -d CTC_Evaluation_Software
!chmod u=rwx,g=rwx,o=rwx -R CTC_Evaluation_Software

--2024-03-02 06:20:25--  http://public.celltrackingchallenge.net/software/EvaluationSoftware.zip
Resolving public.celltrackingchallenge.net (public.celltrackingchallenge.net)... 147.251.52.183
Connecting to public.celltrackingchallenge.net (public.celltrackingchallenge.net)|147.251.52.183|:80... connected.
HTTP request sent, awaiting response... 301 Moved Permanently
Location: https://public.celltrackingchallenge.net/software/EvaluationSoftware.zip [following]
--2024-03-02 06:20:25--  https://public.celltrackingchallenge.net/software/EvaluationSoftware.zip
Connecting to public.celltrackingchallenge.net (public.celltrackingchallenge.net)|147.251.52.183|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 16279967 (16M) [application/zip]
Saving to: 'EvaluationSoftware.zip'


2024-03-02 06:20:25 (70.1 MB/s) - 'EvaluationSoftware.zip' saved [16279967/16279967]

Archive:  EvaluationSoftware.zip
  inflating: CTC_Evaluation_Software/Evaluation software.pdf 

In [14]:
import glob
import os
import re
import subprocess

import numpy as np
import pandas as pd

from deepcell_tracking.metrics import TrackingMetrics

In [15]:
data_dir = '/kaggle/working/benchmarking/DeepCell/data'
gt_seg_dir = os.path.join(data_dir, 'SEG_GT')
pred_seg_dir = os.path.join(data_dir, 'SEG_PRED')

pattern = re.compile('\d{3}_GT')
data_ids = [f.split('_')[0] for f in os.listdir(gt_seg_dir) if pattern.fullmatch(f)]

node_match_threshold = 0.6

ctc_software = '/kaggle/working/CTC_Evaluation_Software'
operating_system = 'Linux' # or 'Mac' or 'Win'
num_digits = '3'

In [16]:
benchmarks = []

for results_dir, s in zip([gt_seg_dir, pred_seg_dir], ['GT', 'Deepcell']):
    for data_id in data_ids:
        results = {
            'model': f'CellBacTrack - {s}',
            'data_id': os.path.splitext(data_id)[0]
        }
        gt_dir = os.path.join(results_dir, f'{data_id}_GT/TRA')
        res_dir = os.path.join(results_dir, f'{data_id}_RES')
        
        # Deepcell benchmarking
        m = TrackingMetrics.from_isbi_dirs(gt_dir, res_dir, threshold=node_match_threshold)
        results.update(m.stats)
        
        # CTC metrics
        for metric, path in [('DET', 'DETMeasure'), ('SEG', 'SEGMeasure'), ('TRA', 'TRAMeasure')]:
            p = subprocess.run([os.path.join(ctc_software, operating_system, path), results_dir, data_id, num_digits],
                               stdout=subprocess.PIPE)
            outstring = p.stdout
            
            try:
                val = float(outstring.decode('utf-8').split()[-1])
                results[metric] = val
            except:
                print('Benchmarking failure', path, results_dir, data_id)
                print(outstring.decode('utf-8'))
        
        benchmarks.append(results)

df = pd.DataFrame(benchmarks)
df.to_csv('benchmarks.csv')

missed node 1_29 division completely
missed node 21_34 division completely
missed node 26_25 division completely
missed node 17_24 division completely
missed node 29_17 division completely
missed node 77_16 division completely
missed node 83_18 division completely
missed node 104_34 division completely
missed node 121_42 division completely
missed node 57_10 division completely
missed node 60_0 division completely
18_16 out degree = 2, daughters mismatch, gt and res degree equal.
missed node 42_38 division completely
missed node 1_29 division completely
missed node 3_23 division completely
corrected division 3_23 as a frameshift division not an error
missed node 21_34 division completely
26_25 out degree = 2, daughters mismatch.
missed node 14_10 division completely
missed node 37_35 division completely
missed node 41_26 division completely
missed node 50_42 division completely
missed node 54_28 division completely
missed node 73_5 division completely
corrected division 54_28 as a fram

In [17]:
df

Unnamed: 0,model,data_id,correct_division,mismatch_division,false_positive_division,false_negative_division,total_divisions,aa_tp,aa_total,te_tp,te_total,DET,SEG,TRA
0,CellBacTrack - GT,5,2,0,0,0,2,690,690,712,712,1.0,1.0,1.0
1,CellBacTrack - GT,2,1,0,0,0,1,1075,1075,1109,1109,1.0,1.0,1.0
2,CellBacTrack - GT,3,7,0,0,1,8,2049,2055,2116,2122,1.0,1.0,0.999836
3,CellBacTrack - GT,7,1,0,0,0,1,199,199,206,206,1.0,1.0,1.0
4,CellBacTrack - GT,1,1,0,0,2,3,997,997,1030,1030,1.0,1.0,0.999703
5,CellBacTrack - GT,11,14,0,3,1,15,3782,3833,3924,3975,1.0,1.0,0.999638
6,CellBacTrack - GT,12,13,0,47,3,16,9245,9796,9546,10097,1.0,1.0,0.998436
7,CellBacTrack - GT,6,0,0,5,0,0,336,383,350,397,1.0,1.0,0.995819
8,CellBacTrack - GT,9,55,0,40,2,57,14265,14806,14633,15174,1.0,1.0,0.999414
9,CellBacTrack - GT,8,18,0,0,2,20,4945,4945,5056,5056,1.0,1.0,0.999966
