## Inferencing Notebook

In [5]:
# import core libaries
import numpy as np
import tkinter as tk
from tkinter import filedialog
root = tk.Tk()
root.withdraw()

import glob
import os
import sys
import tifffile

SCRIPT_DIR = os.path.dirname(os.path.abspath(__vsc_ipynb_file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from processing.processing_functions import *

# get working directory
path = os.getcwd()
sys.path.append(path)

# import machine learning libraries
import torch
from torch.utils.data import DataLoader
from torchvision import transforms, utils
from monai.inferers.inferer import SlidingWindowInferer, SliceInferer
from monai.networks.nets import BasicUNet, UNet
from monai.transforms import (
    Activations,
    AsDiscrete,
    Compose,
)

  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: /Users/jasonfung/miniforge3/envs/ml_env/lib/python3.9/site-packages/torchvision/image.so
  Expected in: /Users/jasonfung/miniforge3/envs/ml_env/lib/python3.9/site-packages/torch/lib/libc10.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# initialize cuda if available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [3]:
device

device(type='cpu')

In [4]:
# model = "+s+d+f_ResUNet.onnx"
model_soma_dendrite = "Soma+Dendrite.onnx"

In [None]:
# processing raw image
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)
batch_size = 64
# split_size = 0.9
dim_order = (0,4,1,2,3) # define the image and mask dimension order

raw_path = filedialog.askopenfilename()
raw_img = glob.glob(raw_path)
orig_shape = tifffile.imread(raw_img).shape

# Use patch transform to normalize and transform ndarray(z,y,x) -> tensor(
patch_transform = transforms.Compose([MinMaxScalerVectorized(),
                                      patch_imgs(xy_step = lateral_steps, z_step = axial_steps, patch_size = patch_size, is_mask = False)])


processed_test_img = MyImageDataset(raw_list = raw_img,
                                    mask_list = None,
                                    transform = patch_transform,
                                    device = device,
                                    mask_order = dim_order,
                                    num_classes = None,
                                    train=False)

## Using Custom Inferencing

In [None]:

reconstructed_img = inference(processed_test_img, 
                              model, 
                              batch_size, 
                              patch_size, 
                              orig_shape,
                              )

np.unique(reconstructed_img)

if len(np.unique(reconstructed_img))-1 == 2:
    reconstructed_img[reconstructed_img==1] = 2

In [None]:
type(reconstructed_img)

In [None]:
tifffile.imwrite(f'{raw_path}_+s+d+f.tif', reconstructed_img.astype(int))

## Using MONAI Sliding Window Inferencing

In [22]:
experiment = "+s_+d_-f"
model_soma_dendrite = "+s_+d_-f_ResUNet_2_57.pth"
model_state = f"C:/Users/Fungj/Documents/github/Label_Seg_Program/models/{experiment}/{model_soma_dendrite}"

In [23]:
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)
batch_size = 64
input_chnl = 1
output_chnl = 4
norm_type = "batch"
dropout = 0.1

model = UNet(spatial_dims=3, 
            in_channels = input_chnl,
            out_channels = output_chnl,
            channels = (32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm = norm_type,
            dropout = dropout)

model.load_state_dict(torch.load(model_state, map_location = device))
model = model.to(device)

inferer = SlidingWindowInferer(roi_size=patch_size, sw_batch_size=batch_size)

In [24]:
from torch.utils.data import DataLoader

In [25]:
#pick test image
# raw_path = filedialog.askopenfilename()
# raw_img = glob.glob(raw_path)

# mask_path = filedialog.askopenfilename()
# mask_img = glob.glob(mask_path)

raw_img = ['E:\\Image_Folder\\Raw\\000_B_181107_A_N1B2_4a61736f.tif']
mask_img = ['E:\\Image_Folder\\Mask\\000_B_181107_A_N1B2_4a61736f.tif']

segmentation_exp = experiment
ex_autofluor = False # True/False
ex_melanocytes = True # True/False
dim_order = (0,4,1,2,3) # define the image and mask dimension order

patch_transform = transforms.Compose([MinMaxScalerVectorized()])
label_transform = transforms.Compose([process_masks(exp = segmentation_exp,
                                                    ex_autofluor=ex_autofluor,
                                                    ex_melanocytes=ex_melanocytes,
                                                     )])

processed_set = WholeVolumeDataset(raw_directory = raw_img,
                                   mask_directory = mask_img,
                                   num_classes = output_chnl,
                                   raw_transform = patch_transform,
                                   label_transform = label_transform,
                                   mask_order = dim_order,
                                   device = device,
                                   )

# processed_dataloader = DataLoader(processed_set, batch_size=1, shuffle= False)

raw, mask = next(iter(processed_set))


In [9]:
raw.shape

torch.Size([1, 1, 135, 800, 1280])

In [11]:
raw.max()

tensor(1.)

In [26]:
# predict using shifted windows

with torch.no_grad():
    pred = inferer(inputs = raw, network=model)

In [27]:
pred.shape

torch.Size([1, 4, 35, 512, 512])

In [32]:
pred.min()

tensor(-36.2151)

In [28]:
pred_from_categorical = to_numpy(torch.argmax(pred, 1))

In [29]:
type(pred_from_categorical)

numpy.ndarray

In [30]:
tifffile.imwrite("test_inference.tif", pred_from_categorical.astype(int))

In [None]:
torch.unique(pred_from_categorical)

In [None]:
# import napari
# viewer = napari.Viewer()
# orig_img = tifffile.imread(raw_img)
# raw_image = viewer.add_image(orig_img, rgb=False)

In [None]:
# label_img = viewer.add_labels(reconstructed_img.astype(int))

## 2D Inferencing using SliceInferer

In [17]:
from pathlib import Path

In [27]:
cwd = os.getcwd()
path = Path(cwd)
path = path.parent.absolute()
model_soma_dendrite = "2D_Soma+Dendrite.pth"
model_path = os.path.join(path,'models',model_soma_dendrite)

In [28]:
model_path

'/Users/jasonfung/Documents/Label_Seg_Program/models/2D_Soma+Dendrite.pth'

In [29]:
checkpoint = torch.load(model_path, map_location=torch.device('cpu'))

In [35]:
lateral_steps = 512
patch_size = (lateral_steps, lateral_steps)
batch_size = 1
input_chnl = 1
output_chnl = 4
norm_type = "batch"
dropout = 0.1

model = UNet(spatial_dims=2, 
            in_channels = input_chnl,
            out_channels = output_chnl,
            channels = (32, 64, 128, 256, 512),
            strides=(2, 2, 2, 2),
            num_res_units=2,
            norm = norm_type,
            dropout = dropout)

model.load_state_dict(checkpoint['model_state_dict'])
model = model.to(device)

# inferer = SlidingWindowInferer(roi_size=patch_size, sw_batch_size=batch_size)
inferer = SliceInferer(roi_size=patch_size, sw_batch_size=batch_size, spatial_dim = 0)

In [39]:
#pick test image
# raw_path = filedialog.askopenfilename()
# raw_img = glob.glob(raw_path)

# mask_path = filedialog.askopenfilename()
# mask_img = glob.glob(mask_path)
raw_img = ['E:\\Image_Folder\\Raw\\000_ML_20180613_N4_4a61736f.tif']
mask_img = ['E:\\Image_Folder\\Mask\\000_ML_20180613_N4_4a61736f.tif']

segmentation_exp = experiment
ex_autofluor = False # True/False
ex_melanocytes = True # True/False
dim_order = (0,4,1,2,3) # define the image and mask dimension order

patch_transform = transforms.Compose([MinMaxScalerVectorized()])
label_transform = transforms.Compose([process_masks(exp = segmentation_exp,
                                                    ex_autofluor=ex_autofluor,
                                                    ex_melanocytes=ex_melanocytes,
                                                     )])

processed_set = WholeVolumeDataset(raw_directory = raw_img,
                                   mask_directory = mask_img,
                                   num_classes = output_chnl,
                                   raw_transform = patch_transform,
                                   label_transform = label_transform,
                                   mask_order = dim_order,
                                   device = device,
                                   )

# processed_dataloader = DataLoader(processed_set, batch_size=1, shuffle= False)



In [40]:
raw, mask = next(iter(processed_set))

In [41]:
mask.shape

torch.Size([1, 4, 190, 512, 512])

In [42]:
with torch.no_grad():
    pred = inferer(inputs = raw, network=model)

In [43]:
probabilities = torch.softmax(pred,1)

In [44]:
probabilities

tensor([[[[[9.9388e-01, 9.9838e-01, 9.9981e-01,  ..., 9.9937e-01,
            9.9987e-01, 9.9640e-01],
           [9.9893e-01, 9.9996e-01, 9.9997e-01,  ..., 1.0000e+00,
            1.0000e+00, 9.9995e-01],
           [9.9993e-01, 9.9998e-01, 1.0000e+00,  ..., 1.0000e+00,
            1.0000e+00, 9.9998e-01],
           ...,
           [9.8824e-01, 9.9906e-01, 9.9852e-01,  ..., 9.9997e-01,
            9.9995e-01, 9.9917e-01],
           [9.9774e-01, 9.9978e-01, 9.9997e-01,  ..., 9.9994e-01,
            9.9997e-01, 9.9559e-01],
           [9.8325e-01, 9.9918e-01, 9.9843e-01,  ..., 9.9526e-01,
            9.7970e-01, 8.1134e-01]],

          [[9.9443e-01, 9.9929e-01, 9.9940e-01,  ..., 9.9990e-01,
            9.9995e-01, 9.9592e-01],
           [9.9850e-01, 9.9995e-01, 9.9980e-01,  ..., 1.0000e+00,
            1.0000e+00, 9.9989e-01],
           [9.9932e-01, 9.9983e-01, 9.9994e-01,  ..., 9.9999e-01,
            1.0000e+00, 9.9981e-01],
           ...,
           [1.0000e+00, 1.0000e+00, 1.0

In [45]:
pred_from_categorical = to_numpy(torch.argmax(probabilities, 1))

In [46]:
pred_from_categorical.shape

(1, 190, 512, 512)

In [16]:
np.unique(pred_from_categorical)

array([0, 1, 2, 3], dtype=int64)

In [48]:
tifffile.imwrite("000_ML_20180613_N4_4a61736f_INFERENCED_fold_1.tif", pred_from_categorical.astype(int))