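## Inference Notebook

Segments a raw TIFF volume with a trained model, using three alternative approaches: the project's patch-based `inference` helper, MONAI's 3D `SlidingWindowInferer`, and a 2D model applied slice-by-slice with MONAI's `SliceInferer`.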

In [1]:
# import core libraries
import numpy as np
import tkinter as tk
from tkinter import filedialog
import yaml
from yaml.loader import SafeLoader
# Hidden Tk root so the filedialog pickers used below can open without an
# empty main window appearing.
root = tk.Tk()
root.withdraw()

import glob
import os
import sys
import tifffile

# Put the parent of the notebook's directory on sys.path so `src.processing`
# can be imported.
try:
    # Defined only when the notebook runs under the VS Code Jupyter extension.
    SCRIPT_DIR = os.path.dirname(os.path.abspath(__vsc_ipynb_file__))
except NameError:
    # Other front-ends (JupyterLab, classic Notebook, nbconvert) do not define it;
    # fall back to the kernel's working directory (typically the notebook folder).
    SCRIPT_DIR = os.getcwd()
sys.path.append(os.path.dirname(SCRIPT_DIR))
# NOTE(review): wildcard import. Names used below include MinMaxScalerVectorized,
# patch_imgs, MyImageDataset, inference, process_masks, WholeVolumeDataset, to_numpy.
from src.processing.processing_functions import *

# get working directory
path = os.getcwd()
sys.path.append(path)

# import machine learning libraries
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from monai.inferers.inferer import SlidingWindowInferer, SliceInferer
from monai.networks.nets import BasicUNet, UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

In [2]:
# initialize cuda if available
# Select the compute device: the first CUDA GPU when available, else the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0") if use_cuda else torch.device("cpu")

In [None]:
# Display the selected torch device (cuda:0 or cpu).
device

In [None]:
# ONNX model filename. It is a bare name, so it is resolved relative to the
# working directory when opened. Earlier alternative: "+s+d+f_ResUNet.onnx".
model_soma_dendrite = "Soma+Dendrite.onnx"

In [None]:
# Patch-extraction settings for the custom inference path.
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)  # (z, y, x)
batch_size = 64
dim_order = (0,4,1,2,3) # define the image and mask dimension order

# Pick the raw TIFF volume. `askopenfilename` returns "" (or an empty tuple on
# some Tk builds) when cancelled. glob.escape stops characters such as "[" in
# the chosen filename from being interpreted as a glob pattern.
raw_path = filedialog.askopenfilename() or ""
raw_img = glob.glob(glob.escape(raw_path))
if not raw_img:
    raise FileNotFoundError(f"No raw image selected or file not found: {raw_path!r}")
orig_shape = tifffile.imread(raw_img).shape

# Min-max scale the volume, then cut it into `patch_size` patches
# (xy_step = lateral_steps, z_step = axial_steps).
patch_transform = transforms.Compose([MinMaxScalerVectorized(),
                                      patch_imgs(xy_step = lateral_steps, z_step = axial_steps, patch_size = patch_size, is_mask = False)])


processed_test_img = MyImageDataset(raw_list = raw_img,
                                    mask_list = None,
                                    transform = patch_transform,
                                    device = device,
                                    mask_order = dim_order,
                                    num_classes = None,
                                    train=False)

## Inference with the custom patch-based `inference` helper

In [None]:

# Run the project's patch-based `inference` helper with the ONNX model chosen
# above. orig_shape is passed, presumably so the patches are reassembled to the
# full volume — confirm against the helper.
reconstructed_img = inference(processed_test_img,
                              model_soma_dendrite,
                              batch_size,
                              patch_size,
                              orig_shape,
                              )

# NOTE(review): if exactly two foreground labels are present, label 1 is
# relabelled to 2. With labels {0, 1, 2} this merges classes 1 and 2; confirm
# this is the intended mapping.
if len(np.unique(reconstructed_img)) - 1 == 2:
    reconstructed_img[reconstructed_img == 1] = 2

# Label values present after relabelling.
np.unique(reconstructed_img)

In [None]:
# Inspect the type of object returned by the custom `inference` helper.
type(reconstructed_img)

In [None]:
# Save next to the raw image, replacing its extension with an experiment suffix,
# e.g. "stack.tif" -> "stack_+s+d+f.tif" (not "stack.tif_+s+d+f.tif").
raw_stem, _ = os.path.splitext(raw_path)
# uint16 matches the MONAI export further down. astype(int) would write int64,
# which many TIFF viewers handle poorly. This assumes label values are small
# non-negative class ids.
tifffile.imwrite(f'{raw_stem}_+s+d+f.tif', reconstructed_img.astype(np.uint16))

## Inference with MONAI `SlidingWindowInferer` (3D)

In [3]:
# initialize cuda if available
# Re-select the compute device so this section can be run on its own.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [36]:
experiment = "+s_+d_-f"
model_soma_dendrite = "+s_+d_-f_ResUNet_3_121.pth"
# Run directory holding both the checkpoint and its training config.
# A raw string keeps backslashes literal, so no invalid escapes such as "\R".
# TODO: make this configurable instead of a hard-coded absolute path.
results_dir = r"C:\Users\Fungj\Google Drive\Masters Project\Segmentation-Model\3D-Neuron-Segment\results\ResUNET\20230202\16_128_128_+s_+d_-f"
model_state = f"{results_dir}\\{model_soma_dendrite}"
yaml_file = f"{results_dir}\\20230202_143720_3D_ResUNet_experiment.yml"

In [37]:
# Load the training-run configuration; safe_load builds plain YAML types only.
with open(yaml_file) as config_fh:
    config = yaml.safe_load(config_fh)
    print(config)

{'DATASET': {'AUGMENTATION': {'augment': True, 'gamma_lower': -0.5, 'gamma_upper': 0.6, 'mean_noise': 0.05, 'std_noise': 0.025, 'x_deg': 0, 'y_deg': 0, 'z_deg': 25}, 'artifacts': [7], 'axial_steps': 16, 'batch_size': 64, 'ex_autofluorescence': False, 'ex_melanocytes': True, 'exp': '+s_+d_-f', 'folds': 4, 'lateral_steps': 128, 'remove_artifacts': False, 'x_patch': 128, 'y_patch': 128, 'z_patch': 16}, 'MODEL': {'channel_layers': [32, 64, 128, 256, 512], 'dropout': 0.15, 'input_dim': 1, 'l2': 0.0, 'learning_rate': 7.54e-05, 'model_arch': 'UNET', 'norm': 'instance', 'num_res_units': 2, 'spatial_dim': 3, 'strides': [2, 2, 2, 2]}, 'RESULTS': {'log_file_path': '/home/jsfung/projects/def-haas/jsfung/results/ResUNet/20230202/16_128_128_+s_+d_-f/143720_log_fold_2', 'model_states_path': '/home/jsfung/projects/def-haas/jsfung/results/ResUNet/20230202/16_128_128_+s_+d_-f/+s_+d_-f_ResUNet_2_149.pth'}, 'TRAINING': {'shuffle': True}, 'date': 'now', 'end_cycle': 20, 'loss': 'dice', 'mask_path': '/home/

In [38]:
# Window geometry and batch size come from the training config.
lateral_steps = config['DATASET']['lateral_steps']
axial_steps = config['DATASET']['axial_steps']
patch_size = (axial_steps, lateral_steps, lateral_steps)  # (z, y, x) window
batch_size = config['DATASET']['batch_size']
input_chnl = 1
output_chnl = 4  # NOTE(review): not read from config; confirm it matches the trained head
norm_type = config['MODEL']['norm']
# Dropout has no weights (the checkpoint loads either way) and is disabled by
# model.eval() below. Read it from config so the architecture mirrors training.
dropout = config['MODEL'].get('dropout', 0.1)

model = UNet(spatial_dims=3,
            in_channels = input_chnl,
            out_channels = output_chnl,
            channels = config['MODEL']['channel_layers'],
            strides=config['MODEL']['strides'],
            num_res_units=config['MODEL'].get('num_res_units', 2),
            norm = norm_type,
            dropout = dropout)

model.load_state_dict(torch.load(model_state, map_location = device))
model = model.to(device)
# Inference mode. Without it dropout stays active and predictions are random,
# even under torch.no_grad().
model.eval()

inferer = SlidingWindowInferer(roi_size=patch_size, sw_batch_size=batch_size, progress=True)

In [39]:
from torch.utils.data import DataLoader

In [31]:
# Pick the test volume. `askopenfilename` returns "" (or an empty tuple on some
# Tk builds) when cancelled. glob.escape stops characters such as "[" in the
# chosen filename from being interpreted as a glob pattern.
raw_path = filedialog.askopenfilename() or ""
raw_img = glob.glob(glob.escape(raw_path))
if not raw_img:
    raise FileNotFoundError(f"No test image selected or file not found: {raw_path!r}")

segmentation_exp = experiment
ex_autofluor = False # True/False
ex_melanocytes = True # True/False
dim_order = (0,4,1,2,3) # define the image and mask dimension order

patch_transform = transforms.Compose([MinMaxScalerVectorized()])
label_transform = transforms.Compose([process_masks(exp = segmentation_exp,
                                                    ex_autofluor=ex_autofluor,
                                                    ex_melanocytes=ex_melanocytes,
                                                     )])

# No mask_directory is passed. Presumably `mask` is then empty/None for this
# unlabeled volume; confirm against WholeVolumeDataset.
processed_set = WholeVolumeDataset(raw_directory = raw_img,
                                   num_classes = output_chnl,
                                   raw_transform = patch_transform,
                                   label_transform = label_transform,
                                   mask_order = dim_order,
                                   device = device,
                                   )

raw, mask = next(iter(processed_set))


reading from list


In [32]:
# The inferer expects a batched input (B, C, *spatial), i.e. len(patch_size) + 2
# dims. Only add the batch axis while it is missing, so that re-running this
# cell does not stack extra axes.
if raw.dim() < len(patch_size) + 2:
    raw = torch.unsqueeze(raw, dim = 0)

In [40]:
# Predict over the whole volume with overlapping sliding windows; autograd is
# not needed for inference.

with torch.no_grad():
    pred = inferer(network=model, inputs=raw)

100%|██████████| 18/18 [04:23<00:00, 14.65s/it]


In [34]:
# Collapse axis 1 (the output-channel axis) to per-voxel label indices, as NumPy.
pred_from_categorical = to_numpy(pred.argmax(dim=1))

In [None]:
# Check what `to_numpy` returned; the export below relies on `.astype`.
type(pred_from_categorical)

In [35]:
# Save the label volume as uint16 TIFF in the working directory.
tifffile.imwrite(
    "SLAP_INFERENCE_TEST_16x128x128.tif",
    data=pred_from_categorical.astype(np.uint16),
)

In [None]:
# `pred_from_categorical` is a NumPy array (it is exported with ndarray.astype
# above), so use np.unique. torch.unique rejects ndarray inputs.
np.unique(pred_from_categorical)

In [None]:
# import napari
# viewer = napari.Viewer()
# orig_img = tifffile.imread(raw_img)
# raw_image = viewer.add_image(orig_img, rgb=False)

In [None]:
# label_img = viewer.add_labels(reconstructed_img.astype(int))

## 2D inference with MONAI `SliceInferer`

In [None]:
experiment = "+s_+d_-f"
model_soma_dendrite = "+s_+d_-f_ResUNet_1_77.pth"
# Root of the 2023-01-11 results. A raw string keeps backslashes literal, so no
# invalid escapes such as "\M". TODO: make this configurable.
results_root = r"C:\Users\Fungj\Google Drive\Masters Project\Segmentation-Model\3D-Neuron-Segment\results\ResUNET\20230111"
model_path = f"{results_root}\\{experiment}\\{model_soma_dendrite}"

In [None]:
# Deserialize the checkpoint onto the CPU; its weights are copied into the model
# (which is then moved to `device`) in the next cell.
checkpoint = torch.load(model_path, map_location="cpu")

In [None]:
lateral_steps = 512
patch_size = (lateral_steps, lateral_steps)  # 2D (y, x) slice window
batch_size = 1
input_chnl = 1
output_chnl = 4
norm_type = "batch"
dropout = 0.1

model = UNet(spatial_dims=2,
            in_channels = input_chnl,
            out_channels = output_chnl,
            channels = (32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm = norm_type,
            dropout = dropout)

model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)
# Inference mode. BatchNorm must use its stored running statistics (not the
# current slice's), and dropout must be off.
model.eval()

# Apply the 2D network slice by slice along spatial dim 0 of the 3D volume.
inferer = SliceInferer(roi_size=patch_size, sw_batch_size=batch_size, spatial_dim = 0)

In [None]:
# Fixed test volume and its ground-truth mask (interactive picking disabled).
raw_img = [r'E:\Image_Folder\Raw\000_ML_20180613_N4_4a61736f.tif']
mask_img = [r'E:\Image_Folder\Mask\000_ML_20180613_N4_4a61736f.tif']

# Dataset options for this experiment.
segmentation_exp = experiment
ex_autofluor = False    # True/False
ex_melanocytes = True   # True/False
dim_order = (0, 4, 1, 2, 3)  # image and mask dimension order

patch_transform = transforms.Compose([MinMaxScalerVectorized()])
label_transform = transforms.Compose([
    process_masks(
        exp=segmentation_exp,
        ex_autofluor=ex_autofluor,
        ex_melanocytes=ex_melanocytes,
    )
])

processed_set = WholeVolumeDataset(
    raw_directory=raw_img,
    mask_directory=mask_img,
    num_classes=output_chnl,
    raw_transform=patch_transform,
    label_transform=label_transform,
    mask_order=dim_order,
    device=device,
)



In [None]:
# Take the first (raw, mask) pair yielded by the dataset.
(raw, mask) = next(iter(processed_set))

In [None]:
# Inspect the shape of the mask returned by the dataset.
mask.shape

In [None]:
# SliceInferer needs a batched 3D input (B, C, D, H, W). The dataset sample
# needed a leading batch axis in the 3D section too, so add it here when it is
# missing. `raw` is not rebound, so re-running this cell is safe.
raw_batch = raw if raw.dim() >= 5 else torch.unsqueeze(raw, dim=0)
with torch.no_grad():
    pred = inferer(inputs = raw_batch, network=model)

In [None]:
# Convert logits to per-class probabilities along axis 1 (the output channels).
probabilities = pred.softmax(dim=1)

In [None]:
# Display the softmax probabilities (this is a large output for a full volume).
probabilities

In [None]:
# Most likely class per voxel (argmax over axis 1), converted to NumPy.
pred_from_categorical = to_numpy(probabilities.argmax(dim=1))

In [None]:
# Inspect the shape of the argmax label volume.
pred_from_categorical.shape

In [None]:
# Label values present in the 2D-model prediction.
np.unique(pred_from_categorical)

In [None]:
# Save the label volume as uint16, the same dtype as the 3D sliding-window
# export. astype(int) would write int64, which many TIFF viewers handle poorly.
# NOTE(review): "fold_1" is hard-coded to match checkpoint "+s_+d_-f_ResUNet_1_77.pth".
tifffile.imwrite("000_ML_20180613_N4_4a61736f_INFERENCED_fold_1.tif", pred_from_categorical.astype(np.uint16))