In [None]:
import os
import skimage
from skimage import io
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Raw images and their manual (ground-truth) masks.
im_dir = "../data/images/manual/"
gt_dir = "../data/masks/manual/"

# Segmentation method to evaluate; set to 'deepcell' to evaluate DeepCell instead.
method_name = 'maskrcnn'

# Predictions of the selected method live next to the manual masks.
pred_dir = "../data/masks/%s_manual/" % method_name

In [None]:
# Figures are written to ./images/<method_name>/, created up front if missing.
PROJECT_ROOT_DIR = "."
IMAGES_PATH = os.path.join(PROJECT_ROOT_DIR, "images", method_name)
os.makedirs(IMAGES_PATH, exist_ok=True)

def save_fig(fig_id, tight_layout=True, fig_extension="png", resolution=300):
    """Save the current matplotlib figure as IMAGES_PATH/<fig_id>.<fig_extension>.

    tight_layout: when true, plt.tight_layout() is applied before saving.
    fig_extension: file extension, also passed as the output format.
    resolution: dpi forwarded to plt.savefig; the background is forced to white.
    """
    file_name = fig_id + "." + fig_extension
    target = os.path.join(IMAGES_PATH, file_name)
    print("Saving figure", fig_id)
    if tight_layout:
        plt.tight_layout()
    plt.savefig(target, format=fig_extension, facecolor='white', dpi=resolution)

# def save_jpg(fig_id, tight_layout=True, resolution=300):
#     path = os.path.join(IMAGES_PATH, fig_id + ".jpg")
#     print("Saving figure", fig_id)
#     if tight_layout:
#         plt.tight_layout()
    
#     plt.savefig(path, format="jpg", dpi=resolution, quality=100)

In [None]:
# Samples are paired across directories purely by list position. os.listdir
# returns entries in arbitrary order, so every listing is sorted; otherwise
# index i in one list need not refer to the same sample in another.
# NOTE(review): assumes sorted image names line up with sorted mask names — confirm.
gt_filenames = sorted(os.listdir(gt_dir))
im_filenames = sorted(os.listdir(im_dir))
pred_filenames = sorted(os.listdir(pred_dir))

gt_filenames_full = [os.path.join(gt_dir, s) for s in gt_filenames]
im_filenames_full = [os.path.join(im_dir, s) for s in im_filenames]
# Predictions are stored under the same file names as the ground-truth masks.
pred_filenames_full = [os.path.join(pred_dir, s) for s in gt_filenames]


## Load one example (ground truth and prediction) for analysis

In [None]:
# Binarise one example: any non-zero label counts as foreground. The prediction
# is squeezed to drop singleton axes before comparison.
idx = 2
gt_image = np.greater(io.imread(gt_filenames_full[idx]), 0)
pred_image = np.greater(np.squeeze(io.imread(pred_filenames_full[idx])), 0)

# print("image shape: {}".format(gt_image.shape))

# plt.figure(figsize=(8,16))
# plt.subplot(1,2,1)
# plt.imshow(gt_image[len(gt_image)//2])
# plt.title("ground truth")
# plt.axis("off")

# plt.subplot(1,2,2)
# plt.imshow(pred_image[len(pred_image)//2])
# plt.title("prediction")
# plt.axis("off")


## Metric definitions

In [None]:
def jaccard_index(y_true, y_pred):
    """Foreground intersection-over-union (Jaccard index) of two masks.

    y_true and y_pred are boolean arrays (non-zero entries count as foreground).
    Every foreground voxel is pooled together, so with several objects in a
    mask this is a semantic score, not a per-instance one.

    Returns 0 (after printing a warning) when both masks are empty.
    """
    overlap = np.logical_and(y_true, y_pred).astype(int).sum()
    combined = np.logical_or(y_true, y_pred).astype(int).sum()
    if not combined > 0:
        # Both masks are empty: the ratio is undefined.
        print('bizarre!')
        return 0
    return overlap / combined

In [None]:
def dice_coef(y_true, y_pred):
    """Dice (Sørensen) coefficient of two masks: 2|A∩B| / (|A| + |B|).

    y_true and y_pred are converted to boolean, so any non-zero entry counts as
    one foreground voxel. Label images therefore give the same result as their
    binarised version; previously label ids were summed in the denominator.

    Returns 0.0 when both masks are empty, consistent with jaccard_index,
    instead of nan (which would poison averages over a dataset).
    """
    y_true = np.asarray(y_true, dtype=bool)
    y_pred = np.asarray(y_pred, dtype=bool)
    intersection = np.count_nonzero(np.logical_and(y_true, y_pred))
    total = np.count_nonzero(y_true) + np.count_nonzero(y_pred)
    if total == 0:
        return 0.0
    return (2 * intersection) / total

## Apply the metrics to the example

In [None]:
# Foreground IoU between the example's ground truth and prediction.
jaccard_index(y_true=gt_image, y_pred=pred_image)

In [None]:
# Dice coefficient between the example's ground truth and prediction.
dice_coef(y_true=gt_image, y_pred=pred_image)

## Average the metrics over the whole dataset

In [None]:
# Per-image Dice and Jaccard over the whole manual dataset. Predictions are
# squeezed exactly as for the single example above. Otherwise a singleton axis
# in the prediction can broadcast against the ground truth and silently distort
# (or break) the metrics.
dice = []
jaccard = []

for gt_path, pred_path in zip(gt_filenames_full, pred_filenames_full):
    gt_image = io.imread(gt_path) > 0
    pred_image = np.squeeze(io.imread(pred_path)) > 0
    assert gt_image.shape == pred_image.shape, (
        "shape mismatch for {}: {} vs {}".format(gt_path, gt_image.shape, pred_image.shape))

    dice.append(dice_coef(gt_image, pred_image))
    jaccard.append(jaccard_index(gt_image, pred_image))

    print("{}: dice={:.4f}, jaccard={:.4f}".format(
        os.path.basename(gt_path), dice[-1], jaccard[-1]))

In [None]:
np.asarray(dice).mean()

In [None]:
np.asarray(jaccard).mean()

## Stats

In [None]:
# Distribution of per-image Dice scores for the selected method.
fig, ax = plt.subplots()
ax.hist(dice, bins=20, rwidth=0.5)
ax.set_xlabel('dice index')
ax.set_ylabel('number of images')
# Pass a bool: a string such as 'on' is merely truthy ('off' would also enable it).
ax.grid(True)
save_fig('dice')
plt.show()


In [None]:
# Distribution of per-image Jaccard scores, labelled like the Dice histogram.
fig, ax = plt.subplots()
ax.hist(jaccard, bins=20, rwidth=0.5)
ax.set_xlabel('jaccard index')
ax.set_ylabel('number of images')
ax.grid(True)
plt.show()

## Display predicted volumes

In [None]:
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import random
from skimage.measure import regionprops
#from scipy import ndimage

In [None]:
# Render, for one sample, the Mask R-CNN and DeepCell predictions and the manual
# mask as 3D voxel plots next to a projection of the raw image, then save it.
def generate_3d(idx):
    """Build, save and show a 2x2 comparison figure for sample ``idx``.

    Panels: 221 Mask R-CNN prediction, 222 DeepCell prediction,
    223 sum of the raw image over axis 0, 224 manual (ground-truth) mask.
    All volumes are cropped to the extent of the Mask R-CNN prediction and
    rendered with ``ax.voxels``. The figure is saved with ``save_fig`` as
    ``3d_<ground-truth file name>``.

    Uses the module-level ``gt_filenames``, ``gt_filenames_full`` and
    ``im_filenames_full`` lists, which are assumed to be index-aligned.
    """
    default_color = ["#138FD080"]

    def load_pred_method(method_name, idx):
        # Predictions of a method are stored under the ground-truth file name.
        pred_dir = "../data/masks/{}_manual/".format(method_name)
        return io.imread(os.path.join(pred_dir, gt_filenames[idx]))

    def voxel_colors(plot_masks):
        # Label 0 (background) gets no colour; every other label gets default_color.
        colors = np.empty(plot_masks.shape, dtype='<U9')
        for label in np.unique(plot_masks):
            color = default_color if label != 0 else None
            colors = np.where(plot_masks == label, color, colors)
        return colors

    def plot_volume(position, volume, title):
        # Draw one label volume as voxels in a 3D subplot at ``position``.
        ax = fig.add_subplot(position, projection='3d')
        ax.set_xlim3d(x_min, x_max)
        ax.set_ylim3d(y_min, y_max)
        ax.set_zlim3d(z_min, z_max)
        ax.set_xlabel('X axis')
        ax.set_ylabel('Y axis')
        ax.set_zlabel('Z axis')
        # Move axis 0 to the end so it is drawn along the plot's Z axis.
        plot_masks = np.rollaxis(np.squeeze(volume), 0, 3)
        ax.voxels(plot_masks, facecolors=voxel_colors(plot_masks))
        ax.set_title(title)

    y_true = io.imread(gt_filenames_full[idx])
    raw_im = io.imread(im_filenames_full[idx])
    y_pred_1 = load_pred_method('maskrcnn', idx)
    y_pred_2 = load_pred_method('deepcell', idx)

    fig = plt.figure(figsize=(15, 15))
    # Plot extents come from the Mask R-CNN prediction; the other volumes are
    # cropped to the same extent.
    x_min, x_max = 0, y_pred_1.shape[1]
    y_min, y_max = 0, y_pred_1.shape[2]
    z_min, z_max = 0, y_pred_1.shape[0]
    extent = (slice(z_min, z_max), slice(x_min, x_max), slice(y_min, y_max))

    plot_volume(221, y_pred_1[extent], 'Predicted nucleus volume: Mask R-CNN')
    plot_volume(222, y_pred_2[extent], 'Predicted nucleus volume: DeepCell')
    plot_volume(224, y_true[extent], 'Manually segmented nucleus volume')

    # Sum of the raw image over axis 0, shown as a 2D image.
    ax = fig.add_subplot(223)
    ax.set_title("Z-projection of the original image of the nucleus")
    ax.imshow(np.sum(raw_im, axis=0), cmap='gray')
    ax.set_xlabel('Y axis')
    ax.set_ylabel('X axis')

    save_fig('3d_{}'.format(gt_filenames[idx]))
    plt.show()

for i in range(1):
    generate_3d(i)

In [None]:

from skimage.filters import threshold_otsu
from skimage.measure import label, regionprops
import skimage.measure
from skimage.transform import resize
from mpl_toolkits.axes_grid1.anchored_artists import AnchoredSizeBar
import matplotlib.font_manager as fm


In [None]:
# Paper figure for one sample: Y-projection of the raw image, the ilastik
# (semi-automatic) segmentation and the Mask R-CNN prediction, all cropped
# around the largest bright object of the raw image, saved as SVG.
def generate_3d_data_paper_cta(idx):
    """Build and save the 1x3 paper figure for sample ``idx``.

    Panels: 131 sum of the raw image over axis 2 with a scale bar,
    132 ground-truth mask as voxels, 133 Mask R-CNN prediction as voxels.
    All volumes are cropped to the bounding box (plus a 10-voxel margin) of the
    largest connected component of the Otsu-thresholded raw image. The figure
    is written with ``save_fig`` as ``3d_<ground-truth file name>.svg``.
    The path of every file read is printed.

    Raises ValueError if the thresholded raw image has no foreground.
    """
    default_color = ["#138FD080"]

    def load_pred_method(method_name, idx):
        # Predictions of a method are stored under the ground-truth file name.
        pred_dir = "../data/masks/{}_manual/".format(method_name)
        pred_path = os.path.join(pred_dir, gt_filenames[idx])
        y_pred = io.imread(pred_path)
        print(pred_path)
        return y_pred

    def crop_bbox(img, bbox, margin=20):
        """Crop the first three axes of ``img`` to ``bbox`` enlarged by ``margin``.

        bbox: (ax, ay, az, bx, by, bz); lower bounds inclusive, upper bounds
        exclusive (the ``regionprops(...).bbox`` convention for 3D images).
        Returns the crop and the clipped bounds as a single array.
        """
        shape = np.array(img.shape[:3])
        lower = np.clip(np.array(bbox[:3]) - margin, 0, shape)
        # Upper bounds are exclusive, so they may equal the axis length; clipping
        # them to length - 1 (as before) silently dropped the last slice.
        upper = np.clip(np.array(bbox[3:]) + margin, 0, shape)
        ax, ay, az = lower
        bx, by, bz = upper
        return img[ax:bx, ay:by, az:bz], np.append(lower, upper)

    def voxel_colors(plot_masks):
        # Label 0 (background) gets no colour; every other label gets default_color.
        colors = np.empty(plot_masks.shape, dtype='<U9')
        for label in np.unique(plot_masks):
            color = default_color if label != 0 else None
            colors = np.where(plot_masks == label, color, colors)
        return colors

    def set_fontsizes(label_items, tick_items, tick_size):
        # Title and axis labels at size 25, tick labels at ``tick_size``.
        for item in label_items:
            item.set_fontsize(25)
        for item in tick_items:
            item.set_fontsize(tick_size)

    def plot_volume(position, volume, title):
        # Draw one label volume as voxels in a 3D subplot at ``position``.
        ax = fig.add_subplot(position, projection='3d')
        ax.set_xlim3d(x_min, x_max)
        ax.set_ylim3d(y_min, y_max)
        ax.set_zlim3d(z_min, z_max)
        ax.set_xlabel('X axis')
        ax.set_ylabel('Y axis')
        ax.set_zlabel('Z axis')
        # Move axis 0 to the end so it is drawn along the plot's Z axis.
        plot_masks = np.rollaxis(np.squeeze(volume), 0, 3)
        ax.voxels(plot_masks, facecolors=voxel_colors(plot_masks))
        ax.set_title(title)
        set_fontsizes([ax.title, ax.xaxis.label, ax.yaxis.label, ax.zaxis.label],
                      ax.get_xticklabels() + ax.get_yticklabels() + ax.get_zticklabels(),
                      tick_size=12)

    # Ground-truth mask, raw image and Mask R-CNN prediction.
    y_true = io.imread(gt_filenames_full[idx])
    print(gt_filenames_full[idx])
    raw_im = io.imread(im_filenames_full[idx])
    print(im_filenames_full[idx])
    y_pred_1 = load_pred_method('maskrcnn', idx)

    fig = plt.figure(figsize=(22, 8), facecolor='white')

    # Locate the nucleus: largest connected component of the Otsu-thresholded
    # raw image. All three volumes are cropped around its bounding box.
    binary = raw_im > threshold_otsu(raw_im)
    props = regionprops(skimage.measure.label(binary))
    if not props:
        raise ValueError("no foreground found in {}".format(im_filenames_full[idx]))
    largest = np.argmax([prop.area for prop in props])
    bbox = props[largest].bbox
    raw_im, y_true, y_pred_1 = [crop_bbox(im, bbox, margin=10)[0]
                                for im in (raw_im, y_true, y_pred_1)]

    # Plot extents come from the cropped prediction; the ground truth is cut to match.
    x_min, x_max = 0, y_pred_1.shape[1]
    y_min, y_max = 0, y_pred_1.shape[2]
    z_min, z_max = 0, y_pred_1.shape[0]
    extent = (slice(z_min, z_max), slice(x_min, x_max), slice(y_min, y_max))

    plot_volume(133, y_pred_1[extent], 'Mask R-CNN prediction')
    plot_volume(132, y_true[extent], 'Semi-automatic segmentation with ilastik')

    # Sum of the raw image over axis 2 (the Y axis), shown with Z vertical.
    ax = fig.add_subplot(131)
    ax.set_title("Y-projection of the raw image")
    proj = np.sum(raw_im, axis=2)
    ax.imshow(proj, cmap='gray', origin='lower', aspect=1)
    ax.set_xlabel('X axis')
    ax.set_ylabel('Z axis')
    set_fontsizes([ax.title, ax.xaxis.label, ax.yaxis.label],
                  ax.get_xticklabels() + ax.get_yticklabels(),
                  tick_size=15)
    # NOTE(review): 25 data units (pixels) are labelled as 3 μm — confirm against the pixel size.
    scalebar = AnchoredSizeBar(ax.transData,
                               25, '3 μm', 'lower right',
                               pad=0.1,
                               color='white',
                               frameon=False,
                               size_vertical=1,
                               fontproperties=fm.FontProperties(size=25))
    ax.add_artist(scalebar)

    save_fig('3d_{}'.format(gt_filenames[idx]), tight_layout=False, fig_extension="svg", resolution=30)

for i in range(4,5):
    generate_3d_data_paper_cta(i)

In [None]:
# Convert saved PNG renderings to JPEG (alpha channel dropped) under images/3d/.
# NOTE: generate_3d_data_paper_cta saves SVG, so '3d_<name>.png' must come from an
# earlier save_fig(..., fig_extension="png") call for the same sample (e.g. generate_3d).
JPG_DIR = os.path.join("images", "3d")
# io.imsave does not create missing parent directories.
os.makedirs(JPG_DIR, exist_ok=True)
for idx in range(4, 5):
    fig_name = '3d_{}'.format(gt_filenames[idx])
    im = io.imread(os.path.join(IMAGES_PATH, fig_name + ".png"))
    im = im[:, :, :3]  # drop the alpha channel; JPEG has none
    io.imsave(os.path.join(JPG_DIR, fig_name + ".jpg"), im, quality=100)