## PyTorch Prediction

This notebook uses the trained semantic segmentation network to predict a segmentation for a test image.

In [12]:
from skimage.io import imread as skimread
from pathlib import Path
import torch
import numpy as np
from tnia.deeplearning.dl_helper import quantile_normalization
from torchvision import transforms
from torchvision.transforms import v2

## Setup Paths

In [16]:

# Root folder of the image collection. Alternative location used previously:
#   Path("D:/images/tnia-python-images")
tnia_images_path = Path(r'/home/bnorthan/images/tnia-python-images')
parent_path = tnia_images_path / r'imagesc/2024_08_08_2photon_vessel'

# Images are read straight from the project folder; 'patches' and 'labels'
# are sub-folders of it.
images_path = parent_path
patch_path, labels_path = parent_path / 'patches', parent_path / 'labels'

# Read the image that the prediction is run on and report its shape.
test_name = r'image1.jpg'
testim = skimread(images_path / test_name)
print(testim.shape)

# Axis layout flag consumed by the normalization step further down.
# NOTE(review): the printed shape has a trailing channel axis of size 3 while
# this is set to 'YX' - confirm that is the intended setting.
axes = 'YX'

(764, 762, 3)


## Load model

In [17]:
# Load the complete pickled network (architecture + weights) from the patches folder.
# - weights_only=False: the file holds a whole nn.Module rather than a state_dict;
#   torch >= 2.6 defaults to weights_only=True and refuses to unpickle it.
# - map_location: lets a model that was saved on the GPU load on a CPU-only machine.
# NOTE(review): unpickling can execute arbitrary code - only load model files you trust.
load_device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.load(patch_path / 'model1', map_location=load_device, weights_only=False)
# Last expression of the cell, so the network structure is displayed below.
model.to(load_device)

BasicUNet(
  (conv_0): TwoConv(
    (conv_0): Convolution(
      (conv): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (adn): ADN(
        (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (D): Dropout(p=0.25, inplace=False)
        (A): ReLU()
      )
    )
    (conv_1): Convolution(
      (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (adn): ADN(
        (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (D): Dropout(p=0.25, inplace=False)
        (A): ReLU()
      )
    )
  )
  (down_1): Down(
    (max_pooling): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (convs): TwoConv(
      (conv_0): Convolution(
        (conv): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (adn): ADN(
          (N): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (D)

In [15]:
# Device the input tensor is moved to: the GPU when CUDA is available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Convert image to tensor...

and predict output segmentation

In [18]:
# Work on a float32 copy so the raw image stays available for display later.
testim_ = testim.copy().astype(np.float32)

if axes == 'YXC':
    # Quantile-normalize each channel on its own. This loop used to be
    # range(1), which normalized channel 0 only and left the remaining
    # channels un-normalized.
    # NOTE(review): make sure this matches the preprocessing used for training.
    for i in range(testim.shape[-1]):
        testim_[:,:,i] = quantile_normalization(
            testim[:,:,i],
            quantile_low=0.01,
            quantile_high=0.998,
            clip=True).astype(np.float32)
else:
    # Single normalization over the whole array.
    testim_ = quantile_normalization(
        testim_,
        quantile_low=0.01,
        quantile_high=0.998,
        clip=True).astype(np.float32)

# Convert the numpy image to a channel-first tensor, add a leading batch axis
# and move it to the compute device.
tensor_transform = transforms.Compose([
    v2.ToTensor(),
])
x = tensor_transform(testim_)
x = x.unsqueeze(0).to(device)

print(x.shape)

# Inference only: eval() switches the Dropout / BatchNorm layers listed in the
# model printout to evaluation behaviour, and no_grad() avoids building an
# autograd graph for the forward pass.
model.eval()
with torch.no_grad():
    y = model(x)


torch.Size([1, 3, 764, 762])




## Find bounding boxes of labels

Load the bounding boxes of any labels drawn on this image. This makes it possible to compare self-prediction (prediction on areas that were labeled) with validation prediction (prediction on areas of the image that were not labeled).

In [20]:
# Collect the label bounding boxes (ROIs) stored as JSON files for the test image.
import json
import re

labels_image_path = labels_path / 'input0'

# Sorted so the ROI order does not depend on the file system's listing order.
json_names = sorted(Path(labels_image_path).glob('*.json'))
base_name = test_name.split('.')[0]

# Match the image name as a whole token. A plain substring test would also
# accept e.g. 'image10_0.json' when looking for 'image1'.
name_pattern = re.compile(
    rf'(?<![A-Za-z0-9]){re.escape(base_name)}(?![A-Za-z0-9])')
json_names_ = [p for p in json_names if name_pattern.search(p.name)]

test_ = base_name  # same value as before; kept in case other cells use it

rois = []

for json_name in json_names_:
    # Parse the file, then release the handle before working with the content.
    with open(json_name, 'r') as f:
        json_ = json.load(f)
    print(json_)

    # NOTE(review): the variable names follow the original cell; the stored
    # axis order of 'bbox' is not documented here. The corners are passed on as
    # [[bbox[1], bbox[0]], [bbox[3], bbox[2]]] - confirm against the code that
    # writes these files.
    y1, x1, y2, x2 = json_['bbox'][:4]
    rois.append([[x1, y1], [x2, y2]])
        


{'base_name': 'image1_0', 'bbox': [192, 487, 738, 748]}


## View in napari

View image, prediction and bounding box in napari

In [21]:
import napari

# First channel of the first batch item, brought back to the CPU as a numpy array.
prediction = y.cpu().detach()[0, 0].numpy()

# Threshold at 0.5 and store the foreground with label value 2.
binary = (prediction > 0.5).astype(np.uint8) * 2

viewer = napari.Viewer()
viewer.add_image(testim, name='testim')
viewer.add_image(prediction, name='prediction')
viewer.add_labels(binary, name='prediction binary')

# Empty shapes layer for the label boxes; the ROIs collected above are added to it.
box_style = dict(
    name="Label box",
    face_color="transparent",
    edge_color="green",
    edge_width=2,
)
boxes_layer = viewer.add_shapes(**box_style)
boxes_layer.add(rois)