## Inferencing Notebook

In [1]:
# import core libaries
import numpy as np
import tkinter as tk
from tkinter import filedialog
root = tk.Tk()
root.withdraw()

import glob
import os
import sys
import tifffile

SCRIPT_DIR = os.path.dirname(os.path.abspath(__vsc_ipynb_file__))
sys.path.append(os.path.dirname(SCRIPT_DIR))
from processing.processing_functions import *

# get working directory
path = os.getcwd()
sys.path.append(path)

# import machine learning libraries
import torch
from torchvision import transforms, utils
from monai.inferers.inferer import SlidingWindowInferer

  from .autonotebook import tqdm as notebook_tqdm
  Referenced from: /Users/jasonfung/miniforge3/envs/ml_env/lib/python3.9/site-packages/torchvision/image.so
  Expected in: /Users/jasonfung/miniforge3/envs/ml_env/lib/python3.9/site-packages/torch/lib/libc10.dylib
  warn(f"Failed to load image Python extension: {e}")


In [2]:
# initialize cuda if available
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")

In [3]:
# model = "+s+d+f_ResUNet.onnx"
model_soma_dendrite = "Soma+Dendrite.onnx"

In [None]:
# processing raw image
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)
batch_size = 64
# split_size = 0.9
dim_order = (0,4,1,2,3) # define the image and mask dimension order

raw_path = filedialog.askopenfilename()
raw_img = glob.glob(raw_path)
orig_shape = tifffile.imread(raw_img).shape

# Use patch transform to normalize and transform ndarray(z,y,x) -> tensor(
patch_transform = transforms.Compose([MinMaxScalerVectorized(),
                                      patch_imgs(xy_step = lateral_steps, z_step = axial_steps, patch_size = patch_size, is_mask = False)])


processed_test_img = MyImageDataset(raw_list = raw_img,
                                    mask_list = None,
                                    transform = patch_transform,
                                    device = device,
                                    img_order = dim_order,
                                    mask_order = dim_order,
                                    num_classes = None,
                                    train=False)

## Using Custom Inferencing

In [None]:

reconstructed_img = inference(processed_test_img, 
                              model, 
                              batch_size, 
                              patch_size, 
                              orig_shape,
                              )

np.unique(reconstructed_img)

if len(np.unique(reconstructed_img))-1 == 2:
    reconstructed_img[reconstructed_img==1] = 2

In [None]:
type(reconstructed_img)

In [None]:
tifffile.imwrite(f'{raw_path}_+s+d+f.tif', reconstructed_img.astype(int))

## Using MONAI Sliding Window Inferencing

In [4]:
import onnx
onnx_model = onnx.load(f"/Users/jasonfung/Documents/Label_Seg_Program/models/{model_soma_dendrite}")
onnx.checker.check_model(onnx_model)

In [5]:
lateral_steps = 64
axial_steps = 16
patch_size = (axial_steps, lateral_steps, lateral_steps)
batch_size = 64

inferer = SlidingWindowInferer(roi_size=patch_size, sw_batch_size=batch_size)

In [6]:
#pick test image
raw_path = filedialog.askopenfilename()
raw_img = glob.glob(raw_path)
patch_transform = transforms.Compose([MinMaxScalerVectorized()])

processed_test_img = MyImageDataset(raw_list = raw_img,
                                    mask_list = None,
                                    transform = patch_transform,
                                    device = device,
                                    test_model = True,
                                    train=False)


In [12]:
x = next(iter(processed_test_img))

In [13]:
x = torch.unsqueeze(torch.unsqueeze(x,0),0)

In [14]:
x.shape

torch.Size([1, 1, 35, 512, 512])

In [15]:
import onnxruntime

ort_session = onnxruntime.InferenceSession(onnx_model)
inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
output = to_torch(ort_session.run(None,inputs)) # predict the batches

with torch.no_grad():
    pred = inferer(inputs = x, network=onnx_model)

TypeError: 'ModelProto' object is not callable

In [None]:
# import napari
# viewer = napari.Viewer()
# orig_img = tifffile.imread(raw_img)
# raw_image = viewer.add_image(orig_img, rgb=False)

In [None]:
# label_img = viewer.add_labels(reconstructed_img.astype(int))