## Inferencing Notebook

In [1]:
# import core libaries
import numpy as np
import tkinter as tk
from tkinter import filedialog
root = tk.Tk()
root.withdraw()

import glob
import os
import sys
import tifffile

SCRIPT_DIR = os.path.dirname(os.path.abspath(__vsc_ipynb_file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from src.processing.processing_functions import *

# get working directory
path = os.getcwd()
sys.path.append(path)

# import machine learning libraries
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from monai.inferers.inferer import SlidingWindowInferer, SliceInferer
from monai.networks.nets import BasicUNet, UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

In [2]:
# initialize cuda if available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

## Using MONAI Sliding Window Inferencing

In [3]:
# experiment = "+s_+d_-f"
model_soma_dendrite = "2D_Soma+Dendrite.pth"
yaml_filename = "20230216_225012_ResUNet_experiment.yml"
parent_folder = "C:\\Users\\Fungj\\Google Drive\\Masters Project\\Segmentation-Model\\3D-Neuron-Segment\\results\\ResUNET\\20230216\\1_512_512_+s_+d_-f" # windows
# parent_folder = "/Users/jasonfung/Google Drive/Masters Project/Segmentation-Model/3D-Neuron-Segment/results/ResUNET/20230113/+s_+d_-f" # mac
yaml_file = f"{parent_folder}/{yaml_filename}"

In [4]:
import yaml
from yaml.loader import SafeLoader
with open(yaml_file) as f:
    config = yaml.load(f, Loader=SafeLoader)
    print(config)

{'DATASET': {'AUGMENTATION': {'augment': True, 'gamma_lower': -0.5, 'gamma_upper': 0.6, 'mean_noise': 0.05, 'std_noise': 0.025, 'x_deg': 0, 'y_deg': 0, 'z_deg': 25}, 'artifacts': [7], 'axial_steps': 1, 'batch_size': 32, 'ex_autofluorescence': False, 'ex_melanocytes': True, 'exp': '+s_+d_-f', 'folds': 4, 'lateral_steps': 512, 'remove_artifacts': False, 'x_patch': 512, 'y_patch': 512, 'z_patch': 1}, 'MODEL': {'channel_layers': [64, 128, 256, 512, 1024], 'dropout': 0.15, 'input_dim': 1, 'l2': 0.00421, 'learning_rate': 7.54e-05, 'model_arch': 'UNET', 'norm': 'batch', 'num_res_units': 2, 'spatial_dim': 2, 'strides': [2, 2, 2, 2]}, 'RESULTS': {'log_file_path': '/home/jsfung/projects/def-haas/jsfung/results/ResUNet/20230216/1_512_512_+s_+d_-f/225012_log_fold_3', 'model_states_path': '/home/jsfung/projects/def-haas/jsfung/results/ResUNet/20230216/1_512_512_+s_+d_-f/+s_+d_-f_2D_ResUNet_3_141.pth'}, 'TRAINING': {'shuffle': True}, 'date': 'now', 'end_cycle': 20, 'loss': 'dice_ce', 'mask_path': '/

In [5]:
from torch.utils.data import DataLoader

In [6]:
# pick test image
raw_path = filedialog.askopenfilename()
raw_img = glob.glob(raw_path)

mask_path = filedialog.askopenfilename()
mask_img = glob.glob(mask_path)
# raw_img = ['E:\\Image_Folder\\SLAP2_Images\\neuron_threshold.tif']
# raw_img = ['E:\\Image_Folder\\Raw\\000_B_181107_A_N1B2_4a61736f.tif']
# mask_img = ['E:\\Image_Folder\\Mask\\000_B_181107_A_N1B2_4a61736f.tif']
# experiment = "+s_+d_-f"
segmentation_exp = "+s_+d_-f"
ex_autofluor = False # True/False
ex_melanocytes = True # True/False
dim_order = (0,4,1,2,3) # define the image and mask dimension order
output_chnl = 4

patch_transform = transforms.Compose([MinMaxScalerVectorized()])
label_transform = transforms.Compose([process_masks(exp = segmentation_exp,
                                                    ex_autofluor=ex_autofluor,
                                                    ex_melanocytes=ex_melanocytes,
                                                     )])


processed_set = WholeVolumeDataset(raw_directory = raw_img,
                                   mask_directory= mask_img,
                                   num_classes = output_chnl,
                                   raw_transform = patch_transform,
                                   label_transform = label_transform,
                                   mask_order = dim_order,
                                   device = device,
                                   )
processed_dataloader = DataLoader(processed_set, batch_size=1, shuffle= False)

raw, mask = next(iter(processed_dataloader))
mask = torch.squeeze(mask,dim=0)

reading from list
Reading Mask from list


In [7]:
pred_results = []
for fold_idx in range(4):
    model_state = f"{parent_folder}/+s_+d_-f_2D_ResUNet_{fold_idx}.pth"

    lateral_steps = config['DATASET']['lateral_steps']
    axial_steps = config['DATASET']['axial_steps']
    if config['MODEL']['spatial_dim'] == 3:
        patch_size = (axial_steps, lateral_steps, lateral_steps)
    else:
        patch_size = (lateral_steps, lateral_steps)
    batch_size = config['DATASET']['batch_size']
    input_chnl = 1
    output_chnl = 4
    norm_type = config['MODEL']['norm']
    dropout = 0.1

    model = UNet(spatial_dims=config['MODEL']['spatial_dim'], 
                in_channels = input_chnl,
                out_channels = output_chnl,
                channels = config['MODEL']['channel_layers'],
                strides=config['MODEL']['strides'],
                num_res_units=config['MODEL']['num_res_units'],
                norm = norm_type,
                dropout = dropout)

    checkpoint = torch.load(model_state, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)

    # inferer = SlidingWindowInferer(roi_size=patch_size, sw_batch_size=batch_size, progress=True)
    inferer = SliceInferer(roi_size=patch_size, sw_batch_size=batch_size, spatial_dim = 0, progress=True)
    # predict using shifted windows

    with torch.no_grad():
        pred = inferer(inputs = raw, network=model)
    pred_results.append(pred)

100%|██████████| 6/6 [03:51<00:00, 38.51s/it]
100%|██████████| 6/6 [03:51<00:00, 38.65s/it]
100%|██████████| 6/6 [03:48<00:00, 38.11s/it]
100%|██████████| 6/6 [03:57<00:00, 39.58s/it]


In [8]:
print(model.parameters)

<bound method Module.parameters of UNet(
  (model): Sequential(
    (0): ResidualUnit(
      (conv): Sequential(
        (unit0): Convolution(
          (conv): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
          (adn): ADN(
            (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (D): Dropout(p=0.1, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
        (unit1): Convolution(
          (conv): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (adn): ADN(
            (N): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (D): Dropout(p=0.1, inplace=False)
            (A): PReLU(num_parameters=1)
          )
        )
      )
      (residual): Conv2d(1, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (1): SkipConnection(
      (submodule): Sequential(
        (0): ResidualUnit(
          (conv): Sequ

In [42]:
pred_from_categorical = np.squeeze(to_numpy(torch.argmax(pred, 1)),axis=0)
gt_from_categorical = np.squeeze(to_numpy(torch.argmax(mask,1)),axis=0)

In [27]:
tifffile.imwrite("test_inference.tif", pred_from_categorical.astype(int))

## Calculate IOU with Respect to Only Ground Truth

In [9]:
from skimage import morphology
# preprocess the predictions and its categories
dim_order = (0,4,1,2,3)
pred_cat_list = []
for fold_idx in range(4):
    pred_cat = to_categorical_torch(torch.argmax(pred_results[fold_idx], 1), 4)
    pred_cat = to_numpy(torch.permute(pred_cat,dim_order))
    pred_cat_list.append(pred_cat)
    # img_proc_pred_cat = pred_cat.clone()
    
    # img_proc_pred_cat = to_numpy(torch.permute(img_proc_pred_cat,dim_order))
# mask = to_numpy(mask)

# mask[mask==3] = 2 # change from filopodia to dendrites

In [163]:
dendrite_min_size = 50
soma_min_size = 50

for i in range(1,3):
    print(i)
    if i == 1:
        img_proc_pred_cat[0,i,...] = morphology.remove_small_objects(img_proc_pred_cat[0,i,...], min_size=soma_min_size, connectivity=1)
    if i == 2:
        img_proc_pred_cat[0,i,...] = morphology.remove_small_objects(img_proc_pred_cat[0,i,...], min_size=dendrite_min_size, connectivity=1)

1
2


In [126]:
scores = []
for i in range(3):
    intersection = np.count_nonzero(pred_cat[0,i,...]*mask[0,i,...])
    pseudo_iou = intersection/np.count_nonzero(mask[0,i,...])
    scores.append(pseudo_iou)

In [127]:
scores

[0.9995756181268031, 0.6719045165722136, 0.7671754819105499]

## Try Using Morphological Gradient

In [170]:
from skimage import morphology

In [193]:
kernel = morphology.ball(radius=1)
dilated_image = morphology.binary_dilation(img_proc_pred_cat[0,2,...].astype(np.uint16), kernel)
erosion_image = morphology.binary_erosion(img_proc_pred_cat[0,2,...].astype(np.uint16), kernel)

In [194]:
# calculate the dilated and eroded images
dilated_image = dilated_image.astype(np.uint16)
erosion_image = erosion_image.astype(np.uint16)

img_proc_pred_cat_dilated = img_proc_pred_cat.copy()
img_proc_pred_cat_dilated[0,2,...] = dilated_image
img_proc_pred_cat_erosion = img_proc_pred_cat.copy()
img_proc_pred_cat_erosion[0,2,...] = erosion_image

In [197]:
# calculate the gradient of the image

img_proc_pred_cat_gradient = img_proc_pred_cat.copy()
img_proc_pred_cat_gradient[0,2,...] = dilated_image - erosion_image

In [199]:
# calculate the closing of the image: goal is the connect the disconnected areas

img_proc_pred_cat_closing = img_proc_pred_cat.copy()
img_proc_pred_cat_closing[0,2,...] = morphology.binary_closing(img_proc_pred_cat[0,2,...].astype(np.uint16),kernel)



## Calculate Dice Score

In [165]:
print("Pure Dice Scores: ", pure_dice_scores)
print("Processed Dice Scores after getting rid of small objects", processed_dice_scores)

Pure Dice Scores:  tensor([[0.7349, 0.7851, 0.0339]])
Processed Dice Scores after getting rid of small objects tensor([[0.7422, 0.7880, 0.0339]])


## View Images on Napari

In [17]:
import napari
viewer = napari.Viewer()

v0.5.0. It is considered an "implementation detail" of the napari
application, not part of the napari viewer model. If your use case
requires access to qt_viewer, please open an issue to discuss.
  self.tools_menu = ToolsMenu(self, self.qt_viewer.viewer)


In [26]:
fold_idx = 2
from monai.metrics import compute_meandice

scores = []
for i in range(3):
    intersection = np.count_nonzero(pred_cat_list[fold_idx][0,i,...]*to_numpy(mask[0,i,...]))
    pseudo_iou = intersection/np.count_nonzero(mask[0,i,...])
    scores.append(pseudo_iou)

print("Intersection WRT Ground Truth: ", scores[1:])
pure_dice_scores = compute_meandice(to_torch(pred_cat_list[fold_idx]), to_torch(mask), include_background=False)
print("Pure Dice Scores: ", pure_dice_scores[0][0:2])

pred_img = to_numpy(torch.argmax(pred_results[fold_idx], 1))
# tmp_pred_img = np.zeros_like(pred_img)
# tmp_pred_img[pred_img==2] = 1
viewer.add_labels(pred_img, blending = "additive", name = f"Pure Prediction Fold {fold_idx}")
# viewer.add_labels(pred_img, name = f"Pure Prediction Fold {fold_idx}")

Intersection WRT Ground Truth:  [0.7155246436544737, 0.8013483999620169]




Pure Dice Scores:  tensor([0.7408, 0.7926])


<Labels layer 'Pure Prediction Fold 2' at 0x1c1ed6e7e50>

In [48]:
pure_dice_scores[0][0:2]

tensor([0.6301, 0.7262])

In [192]:
processed_img = np.argmax(img_proc_pred_cat,1)
viewer.add_labels(processed_img, name = "Processed Prediction")

<Labels layer 'Processed Prediction' at 0x1e45e67a6d0>

In [195]:
dilated_img = np.argmax(img_proc_pred_cat_dilated,1)
viewer.add_labels(dilated_img, name = "Dilated Prediction")

<Labels layer 'Dilated Prediction [1]' at 0x1e45e67afa0>

In [196]:
erosion_img = np.argmax(img_proc_pred_cat_erosion,1)
viewer.add_labels(erosion_img, name = "Erosion Prediction")

<Labels layer 'Erosion Prediction [1]' at 0x1e45adcc160>

In [198]:
gradient_img = np.argmax(img_proc_pred_cat_gradient,1)
viewer.add_labels(gradient_img, name = "Gradient Prediction")

<Labels layer 'Gradient Prediction' at 0x1e45aa5be80>

In [200]:
closing_img = np.argmax(img_proc_pred_cat_closing,1)
viewer.add_labels(gradient_img, name = "Closing Prediction")

<Labels layer 'Closing Prediction' at 0x1e45f914eb0>

In [13]:
gt_img = np.argmax(mask,1)
gt_img[gt_img==3] = 0
# tmp_gt_img = np.zeros_like(gt_img)
# tmp_gt_img[gt_img==2] = 1
# viewer.add_image(tmp_gt_img, colormap="magenta", blending = "additive", name = "Ground Truth")
viewer.add_labels(gt_img, name = "Ground Truth")

<Labels layer 'Ground Truth [1]' at 0x1b3a2e207f0>