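# Figures

Generates the figures for one U-Net prediction run: per-cell-type IoU histograms for each data split, enhanced images with the predicted mask overlaid, and side-by-side prediction / ground-truth / IoU panels.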



## Imports

In [None]:
import os
import pathlib

import pandas as pd
import numpy as np
import cv2 as cv
import imageio
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches
import seaborn as sns

import skimage
from skimage import io
from skimage import color
from skimage import filters
from skimage.exposure import equalize_adapthist
from skimage.exposure import adjust_gamma

from PIL import Image

## Specify which predictions to use and create corresponding figure folder

**Note:** Browse `<repository_folder>/data/data_preprocessed/mask_predicted`, pick one of its subfolders (one per prediction run), and assign that folder's name to `pred_dir_pick` in the cell below.

In [None]:
pred_dir_pick = "U-net-from-scratch_astroS20E50_cortS60E10_shsy5yS24E25" # Folder from which predictions are taken

curr_dir = os.getcwd()
# Assumes the notebook runs two levels below the repository root
# (presumably <repo>/notebooks/U-net, matching figures_dir below) — TODO confirm.
parent_dir = pathlib.Path(curr_dir).parents[1]

# Folder holding the selected run's predictions (trailing "/" because file
# names are appended directly to it, e.g. f"{pred_dir}IoU.csv").
pred_dir = f"{parent_dir}/data/data_preprocessed/mask_predicted/{pred_dir_pick}/"

# All figures live under figures_dir; each prediction run gets its own subfolder.
figures_dir = f"{parent_dir}/notebooks/U-net/figures"
fig_dir = f"{figures_dir}/{pred_dir_pick}"

# makedirs also creates figures_dir if missing. exist_ok=True makes re-running
# this cell harmless, while genuine failures (e.g. permissions) still raise.
os.makedirs(fig_dir, exist_ok=True)

In [None]:
# Path to the original images
img_path = f"{parent_dir}/data/data_original/train/"

# Path to the original masks
msk_path = f"{parent_dir}/data/data_preprocessed/masks_used_for_model/"

# Path to the predicted masks
prd_path = f"{parent_dir}/data/data_preprocessed/mask_predicted/{pred_dir_pick}/"

## Load the IoU data

In [None]:
# Per-image IoU scores of the selected prediction run. The cells below use the
# columns "id", "dataset", "cell_type" and "IoU".
df_iou = pd.read_csv(pred_dir + "IoU.csv")
df_iou.head(n=5)

## Figure: IoU histogram

In [None]:
# One figure per data split, with one IoU histogram panel per cell type.
for dataset in ["train", "eval", "test"]:

    df_split = df_iou[df_iou["dataset"] == dataset]
    cell_types = np.sort(df_split["cell_type"].unique())

    n_panels = len(cell_types)
    if n_panels == 0:
        # No rows for this split in IoU.csv: nothing to draw (subplots(1, 0) would fail).
        continue

    # One panel per cell type (4 inches wide each). squeeze=False keeps `axes`
    # 2-D, so axes[0] is always the row of panels, even for a single cell type.
    fig, axes = plt.subplots(1, n_panels, squeeze=False, figsize=(4 * n_panels, 4))
    fig.suptitle(f"IoU distributions per cell type ({dataset} data only)", fontsize=20)

    # Fixed y-range per split keeps panels comparable; the test split uses a
    # smaller range (presumably because it has fewer images). The dashed mean
    # line stops just below the top of the range.
    y_max, line_top = (25, 22) if dataset == "test" else (60, 55)

    for ax, cell_type in zip(axes[0], cell_types):
        df_cell = df_split[df_split["cell_type"] == cell_type]

        sns.histplot(ax=ax, data=df_cell, x="IoU", binwidth=0.05)

        mean_iou = np.mean(df_cell["IoU"])

        ax.set(xlim=(-0.01, 1), ylim=(0, y_max))
        ax.vlines(mean_iou, 0, line_top, colors='k', linestyles='dashed')

        # Blended transform: x in data coordinates, y as a fraction of the axes
        # height, so the label sits near the top next to the mean line.
        trans = ax.get_xaxis_transform()
        ax.text(mean_iou + 0.02, 0.875, str(round(mean_iou, 2)), transform=trans, fontsize=14)

        ax.set_title(cell_type, fontsize=16)
        ax.set_xlabel('IoU', fontsize=14)
        ax.set_ylabel('N images', fontsize=14)

    sns.despine(fig=fig, top=True, right=True)

    # Lay out and save once per split, after all panels have been drawn.
    fig.tight_layout()
    fig.savefig(f"{fig_dir}/mean_IoU_{dataset}.png", dpi=300, transparent=True)

## Define functions for image transformation

In [None]:
def make_IoU_img(msk, prd):
    """Render the overlap of a ground-truth and a predicted mask as an RGB image.

    The red channel marks the union of both masks and the green channel marks
    their intersection. Pixels covered by both masks therefore appear yellow
    (red + green), and pixels covered by only one mask appear red.

    Parameters
    ----------
    msk, prd : array-like
        Ground-truth and predicted masks of identical shape. Any non-zero value
        counts as foreground. Presumably 2-D (H, W) label images — TODO confirm.

    Returns
    -------
    numpy.ndarray
        uint8 array of shape (H, W, 3) for 2-D inputs, channel values 0 or 255.
    """
    fg_msk = np.asarray(msk) != 0
    fg_prd = np.asarray(prd) != 0

    union = (fg_msk | fg_prd).astype(np.uint8) * 255
    intersection = (fg_msk & fg_prd).astype(np.uint8) * 255
    # Keep every channel uint8. A float zeros array here would promote the image
    # to float64 with values up to 255, which imshow clips (float RGB must be in [0, 1]).
    blue = np.zeros_like(union)

    return np.stack((union, intersection, blue), axis=-1)


def make_mask(img, transparent_background = False):
    """Render a mask as green foreground on a black (or transparent) background.

    Parameters
    ----------
    img : array-like
        Mask image. Any non-zero pixel counts as foreground. Presumably a 2-D
        (H, W) label image — TODO confirm.
    transparent_background : bool, default False
        If True, return an RGBA PIL image whose background is fully transparent,
        suitable for overlaying on the original image.

    Returns
    -------
    numpy.ndarray or PIL.Image.Image
        If ``transparent_background`` is False: a uint8 array of shape (H, W, 3),
        with green = 255 on foreground pixels and 0 elsewhere.
        If True: an RGBA PIL image in which foreground pixels are
        (0, 255, 0, 255) and background pixels are (255, 255, 255, 0).
    """
    # uint8 throughout, so imshow interprets the values on the 0-255 scale.
    green = (np.asarray(img) != 0).astype(np.uint8) * 255
    blank = np.zeros_like(green)
    rgb = np.stack((blank, green, blank), axis=-1)

    if not transparent_background:
        return rgb

    # Vectorized replacement for a per-pixel loop. The alpha channel is opaque
    # exactly where the mask is foreground, i.e. it equals the green channel.
    # Background pixels become fully transparent white.
    background = green == 0
    rgba = np.dstack((rgb, green))
    rgba[background, :3] = 255

    return Image.fromarray(rgba)


def enhance_image(img, clip_limit=0.025, gamma=0.6):
    """Boost an image's contrast and brightness for display.

    Applies contrast-limited adaptive histogram equalization
    (``skimage.exposure.equalize_adapthist``) and then a gamma correction
    (``skimage.exposure.adjust_gamma``). A gamma below 1 brightens the image.

    Parameters
    ----------
    img : numpy.ndarray
        Input image, as read from disk.
    clip_limit : float, default 0.025
        Clipping limit passed to ``equalize_adapthist``. Higher values give more contrast.
    gamma : float, default 0.6
        Exponent passed to ``adjust_gamma``.

    Returns
    -------
    numpy.ndarray
        Enhanced image. It is float-valued, since ``equalize_adapthist`` returns floats.
    """
    equalized = equalize_adapthist(img, clip_limit=clip_limit)
    return adjust_gamma(equalized, gamma=gamma, gain=1)


def create_images(img_id, img_path, msk_path, prd_path):
    """Load one image with its ground-truth and predicted masks, and build display variants.

    Parameters
    ----------
    img_id : str
        Image identifier. The files read are ``<img_path>/<img_id>.png``,
        ``<msk_path>/<img_id>_mask.png`` and ``<prd_path>/masks/<img_id>_pred.png``.
    img_path, msk_path, prd_path : str or path-like
        Folders holding the raw images, the ground-truth masks and the prediction
        run. A trailing path separator is optional.

    Returns
    -------
    dict
        "image"          : contrast/brightness-enhanced raw image (``enhance_image``)
        "msk_orig"       : ground-truth mask rendered by ``make_mask``
        "msk_pred"       : predicted mask rendered by ``make_mask``
        "msk_orig_trans" : ground-truth mask with a transparent background, for overlays
        "msk_pred_trans" : predicted mask with a transparent background, for overlays
        "IoU"            : union/intersection visualization (``make_IoU_img``)
    """
    # os.path.join yields the same path whether or not the folders end in "/".
    img = imageio.imread(os.path.join(img_path, f"{img_id}.png"))
    msk = imageio.imread(os.path.join(msk_path, f"{img_id}_mask.png"))
    prd = imageio.imread(os.path.join(prd_path, "masks", f"{img_id}_pred.png"))

    return {
        "image": enhance_image(img),
        "msk_orig": make_mask(msk),
        "msk_pred": make_mask(prd),
        "msk_orig_trans": make_mask(msk, transparent_background = True),
        "msk_pred_trans": make_mask(prd, transparent_background = True),
        "IoU": make_IoU_img(msk, prd),
    }

## Figure: Enhanced image with predicted mask overlay

In [None]:
# To be visualized image IDs
# Example images visualized in the figures below. Every id must appear in the
# "id" column of IoU.csv, because the lookups take the first matching row.
img_ids = [
    "0140b3c8f445",
    "6b165d790e33",
    "db8bc8f09776",
]

In [None]:
for img_id in img_ids:

    # Cell type of this image; only used to name the output file.
    cell_type = df_iou.loc[df_iou["id"] == img_id, "cell_type"].values[0]

    # Read the image and masks, and build their display variants.
    img_dict = create_images(img_id, img_path, msk_path, prd_path)

    # Enhanced grayscale image with the predicted mask blended on top at 30 % opacity.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(img_dict["image"], cmap="gray")
    ax.imshow(img_dict["msk_pred_trans"], alpha=0.3)
    ax.axis('off')
    fig.savefig(f"{fig_dir}/pred_overlay_{cell_type}_{img_id}.png", dpi=500)

## Figure: IoU example

In [None]:
# (img_dict key, panel title) for the three side-by-side panels, left to right.
panels = [
    ("msk_pred", "Prediction"),
    ("msk_orig", "Original"),
    ("IoU", "Intersection over Union"),
]

for img_id in img_ids:
    # Look the image up once and read both fields from the same row.
    row = df_iou.loc[df_iou["id"] == img_id].iloc[0]
    IoU, cell_type = row["IoU"], row["cell_type"]

    # Read the image and masks, and build their display variants.
    img_dict = create_images(img_id, img_path, msk_path, prd_path)

    fig, axes = plt.subplots(1, len(panels), figsize=(10, 3.3))
    fig.suptitle("Cell type: {}, intersection over union: {}".format(cell_type, round(IoU,2)),
                 fontsize = 16)

    for ax, (key, title) in zip(axes, panels):
        ax.imshow(img_dict[key])
        ax.set_title(title)
        # Pixel coordinates carry no information here, so hide the ticks.
        ax.set_xticks([])
        ax.set_yticks([])

    fig.tight_layout()
    fig.savefig(f"{fig_dir}/example_pred2_{cell_type}_{img_id}.png", dpi=300)