In [None]:
# Useful during development when helper functions may be updated:
# `%autoreload 2` makes IPython re-import modules (e.g. the project's `utils` package)
# before each cell is executed, so edits to them are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2


In [None]:
import sys
# Make the parent directory importable so the project-local `utils` package used below
# resolves. The membership check keeps the cell idempotent: re-running it does not
# append a duplicate entry.
if '../' not in sys.path:
    sys.path.append('../')

from utils.utils import show_images
import torch
import torchvision
from torchvision.transforms import InterpolationMode
from aicsimageio import AICSImage

from utils.tiling import sliding_window_inference
from utils.utils import save_image_with_label_overlay, _move_channel_axis, _choose_device, show_images
from utils.utils import labels_to_features
import json
import os




In [None]:
# Load the TorchScript InstanSeg model. Its `pixel_size` attribute (microns) is the
# resolution the input is rescaled to before inference (see the rescaling cell below).
instanseg = torch.jit.load("../torchscript_models/instanseg_1735176.pt")
model_pixel_size = instanseg.pixel_size

device = _choose_device()
instanseg.to(device)

img = AICSImage("../examples/LuCa1.tif")
input_data = img.get_image_data("CYX")

# Pixel size (microns) of the input image: read from the file metadata when available,
# otherwise fall back to DEFAULT_PIXEL_SIZE -- edit that value if it is wrong for your image.
DEFAULT_PIXEL_SIZE = 0.5

if img.physical_pixel_sizes.X is not None:
    pixel_size = img.physical_pixel_sizes.X
    print("Pixel size was found in the metadata, pixel size is set to: ", pixel_size)
else:
    pixel_size = DEFAULT_PIXEL_SIZE
    # Say explicitly which value is being used so a silent fallback is not mistaken for metadata.
    print("Pixel size was not found in the metadata, falling back to", pixel_size,
          "microns. Please set the pixel size of the input image in microns manually if this is wrong")

In [None]:
from utils.augmentations import Augmentations

# Helper object used below for tensor conversion, normalisation, rescaling and colourizing.
Augmenter = Augmentations()
Augmenter.integer_channels_labels = [2, 2]  # helper-internal label-channel setting; can be left as is

# Convert the raw (C, Y, X) array to a tensor. `normalize=False`, so normalisation is
# NOT done here; it is applied explicitly on the next line instead.
input_tensor, _ = Augmenter.to_tensor(input_data, normalize=False)
input_tensor, _ = Augmenter.normalize(input_tensor, percentile=0.)  # see Augmentations.normalize for the exact scheme

# Spatial size before any rescaling; the outputs are resized back to this later on.
original_shape = input_tensor.shape[1:]

print(input_tensor.shape)

In [None]:
# Display-only RGB rendering of the input at its original resolution.
# c_nuclei is the index of the nucleus channel (shown in blue by default).
input_tensor_to_rgb, _ = Augmenter.colourize(
    input_tensor,
    c_nuclei=6,
    random_seed=1,
)

# Rescale from the image's pixel size to the model's pixel size (model_pixel_size).
# `input_crop` (crop=True) feeds the quick single-pass demo; `input_tensor` is replaced by
# the full rescaled image used for tiled inference. Note: re-running this cell rescales again.
input_crop, _ = Augmenter.torch_rescale(
    input_tensor,
    labels=None,
    current_pixel_size=pixel_size,
    requested_pixel_size=model_pixel_size,
    crop=True,
    random_seed=1,
)
input_tensor, _ = Augmenter.torch_rescale(
    input_tensor,
    labels=None,
    current_pixel_size=pixel_size,
    requested_pixel_size=model_pixel_size,
    crop=False,
    random_seed=1,
)

# Display-only RGB rendering of the rescaled crop.
input_crop_to_rgb, _ = Augmenter.colourize(
    input_crop,
    c_nuclei=6,
    random_seed=1,
)



In [None]:
# Run InstanSeg in a single forward pass over the cropped input (no tiling).
# The model is called with a batch of shape (1, C, H, W), hence the added leading axis.
# Autograd is disabled: gradients are never needed for inference, and tracking them
# only costs memory.
with torch.no_grad():
    labeled_output = instanseg(input_crop.to(device)[None])

# Expected shape: (1, 1, H, W) for a nucleus-or-cell model, (1, 2, H, W) for nucleus + cell.
output_dimension = labeled_output.shape[1]

print(labeled_output.shape)

In [None]:
# Low-resolution preview of the rescaled input tensor.
show_images(input_tensor, dpi=50)

In [None]:
# Show the cropped input next to the mask(s) from the single-pass prediction.
if output_dimension == 1:  # one mask: nucleus *or* cell
    show_images(
        input_crop_to_rgb,
        labeled_output[0, 0],
        labels=[1],
        colorbar=False,
        titles=["Input (RGB Display)", "Predicted mask"],
        dpi=100,
    )
elif output_dimension == 2:  # two masks: nucleus *and* cell
    show_images(
        input_crop_to_rgb,
        labeled_output[0, 0],
        labeled_output[0, 1],
        labels=[1, 2],
        colorbar=True,
        titles=["Input (RGB Display)", "Predicted Nucleus mask", "Predicted Cell Mask"],
        dpi=100,
    )

In [None]:
# Tiled inference over the full rescaled image: 512x512 windows, overlap_size=124/512
# (presumably a fraction of the window), computed on `device` with results gathered on the CPU.
# This may take a few seconds. Autograd is disabled since no gradients are needed.
with torch.no_grad():
    tiled_labels = sliding_window_inference(
        input_tensor,
        instanseg,
        window_size=(512, 512),
        overlap_size=124 / 512,
        sw_device=device,
        device='cpu',
        output_channels=output_dimension,
    )

# Resize the input and the labels back to `original_shape`, the spatial size before rescaling
# to the model pixel size. When the pixel sizes already match, this only corrects a pixel or so
# of rounding drift. Labels use NEAREST so instance ids are never blended together.
input_tensor = torchvision.transforms.Resize(original_shape, interpolation=InterpolationMode.BILINEAR)(input_tensor)
tiled_labels = torchvision.transforms.Resize(original_shape, interpolation=InterpolationMode.NEAREST)(tiled_labels)

print(tiled_labels.shape)

In [None]:
from utils.biological_utils import resolve_cell_and_nucleus_boundaries

# Post-process on the compute device selected earlier instead of a hard-coded 'cuda'
# (which fails on CPU-only or Apple-silicon machines), then move the result back to the CPU.
tiled_labels = resolve_cell_and_nucleus_boundaries(tiled_labels.to(device)).to('cpu')

In [None]:
from utils.pytorch_utils import torch_sparse_onehot, fast_sparse_dual_iou

# IoU between nucleus instances (channel 0) and cell instances (channel 1).
# Only defined when the model predicts both; a single-channel output has no channel 1,
# so indexing it would raise an IndexError.
if output_dimension == 2:
    onehot1, _ = torch_sparse_onehot(tiled_labels[0, 0], flatten=True)
    onehot2, _ = torch_sparse_onehot(tiled_labels[0, 1], flatten=True)
    iou = fast_sparse_dual_iou(onehot1, onehot2)
else:
    iou = None

In [None]:

# Inspect the raw tiled label tensor(s).
show_images(tiled_labels)

In [None]:
# Export the predicted objects as GeoJSON features ("Nuclei" / "Cells" classifications
# when the model predicts both).
out_path = "./Luca1_labels.geojson"

if output_dimension == 1:
    features = labels_to_features(tiled_labels[0, 0].numpy(), object_type="detection")
elif output_dimension == 2:
    features = (
        labels_to_features(tiled_labels[0, 0].numpy(), object_type="detection", classification="Nuclei")
        + labels_to_features(tiled_labels[0, 1].numpy(), object_type="detection", classification="Cells")
    )
else:
    # Fail loudly instead of hitting a NameError on `features` below.
    raise ValueError(f"Unexpected number of output channels: {output_dimension}")

geojson = json.dumps(features)
with open(out_path, "w") as outfile:
    outfile.write(geojson)


In [None]:

# Show the full-resolution input next to the mask(s) from tiled inference.
if output_dimension == 1:  # one mask: nucleus *or* cell
    show_images(
        input_tensor_to_rgb,
        tiled_labels[0, 0],
        labels=[1],
        colorbar=False,
        titles=["Input (RGB Display)", "Predicted mask"],
        dpi=100,
    )
elif output_dimension == 2:  # two masks: nucleus *and* cell
    show_images(
        input_tensor_to_rgb,
        tiled_labels[0, 0],
        tiled_labels[0, 1],
        labels=[1, 2],
        colorbar=False,
        titles=["Input (RGB Display)", "Predicted Nucleus mask", "Predicted Cell Mask"],
        dpi=100,
    )

In [None]:
import numpy as np

# 8-bit display image: clamp the float rendering (expected (3, H, W)) to [0, 1], scale to
# [0, 255], move channels last -> (H, W, 3).
im_for_display = torch.clamp(input_tensor_to_rgb, 0, 1).cpu().numpy() * 255
im_for_display = _move_channel_axis(im_for_display, to_back=True).astype(np.uint8)

# Draw the predicted label boundaries over the display image.
if output_dimension == 1:  # one mask: nucleus *or* cell
    labels_for_display = tiled_labels[0, 0].cpu().numpy()  # presumably (H, W)
    image_overlay = save_image_with_label_overlay(
        im_for_display,
        lab=labels_for_display,
        return_image=True,
        label_boundary_mode="thick",
        label_colors=None,
        thickness=10,
        alpha=0.5,
    )
elif output_dimension == 2:  # nuclei in red (thick), cells in green (inner boundary)
    nuclei_labels_for_display = tiled_labels[0, 0].cpu().numpy()  # presumably (H, W)
    cell_labels_for_display = tiled_labels[0, 1].cpu().numpy()  # presumably (H, W)
    image_overlay = save_image_with_label_overlay(
        im_for_display,
        lab=nuclei_labels_for_display,
        return_image=True,
        label_boundary_mode="thick",
        label_colors="red",
        thickness=10,
    )
    image_overlay = save_image_with_label_overlay(
        image_overlay,
        lab=cell_labels_for_display,
        return_image=True,
        label_boundary_mode="inner",
        label_colors="green",
        thickness=1,
    )





In [None]:
# High-resolution view of the boundary overlay.
show_images(image_overlay, colorbar=False, dpi=200)

In [None]:
# Now let's wrap the whole pipeline in a single function:

def run_instanseg(model_name: str, image_path, pixel_size: "float | None" = None, output_dir=None, model_pixel_size=0.5, cell_and_nuclei=True, export_geojson: bool = True, number_of_input_channels=None, model_dir="../examples/torchscript_models/", nuclei_channel: int = 1):
    """Segment one image with an InstanSeg TorchScript model and save the results.

    The image is rescaled to ``model_pixel_size``, segmented with tiled inference, rescaled
    back to its own pixel size, and written to ``output_dir`` as ``<stem>_rendered.tif``
    (display rendering), ``<stem>_rendered_markup.tif`` (rendering with label boundaries)
    and, if ``export_geojson``, ``<stem>_labels.geojson``.

    Args:
        model_name: File name of the TorchScript model inside ``model_dir``.
        image_path: Path of the input image (anything ``AICSImage`` can open).
        pixel_size: Pixel size of the input image in microns. If None it is read from
            the image metadata.
        output_dir: Output directory (str or Path), created if missing. Defaults to
            ``<image directory>/Results_<model_name>``.
        model_pixel_size: Pixel size (microns) the image is rescaled to before inference.
        cell_and_nuclei: If True, request two output channels (nuclei, cells); otherwise one.
        export_geojson: Whether to also write the labels as GeoJSON features.
        number_of_input_channels: If given, images with a different channel count are
            skipped with a warning.
        model_dir: Directory containing the TorchScript models.
        nuclei_channel: Index of the nucleus channel used to colourize non-3-channel
            images for display.

    Returns:
        None (also when the image is skipped).

    Raises:
        ValueError: ``pixel_size`` is None and the metadata has no pixel size.
        AssertionError: the metadata pixel size is outside (0, 3) microns.
    """
    import json
    import os
    import warnings
    from pathlib import Path

    import numpy as np
    import torch
    from aicsimageio import AICSImage
    from skimage import io

    from utils.augmentations import Augmentations
    from utils.tiling import sliding_window_inference
    from utils.utils import _choose_device, _move_channel_axis, labels_to_features, save_image_with_label_overlay

    img = AICSImage(image_path)

    # An explicit pixel_size wins; otherwise fall back to the metadata.
    if pixel_size is None:
        if img.physical_pixel_sizes.X is None:
            raise ValueError("Pixel size was not found in the metadata, please set the pixel size of the input image in microns manually")
        pixel_size = img.physical_pixel_sizes.X
        assert 0 < pixel_size < 3, "Pixel size is not in microns, please check the metadata"
        print("Pixel size was found in the metadata, pixel size is set to: ", pixel_size)

    # Read channels from the "S" (samples) axis when it holds more channels than "C"
    # (presumably interleaved RGB data), otherwise from "C".
    channel_number = img.dims.C
    if "S" in img.dims.order and img.dims.S > img.dims.C:
        channel_number = img.dims.S
        input_data = img.get_image_data("SYX")
    else:
        input_data = img.get_image_data("CYX")

    if number_of_input_channels is not None and channel_number != number_of_input_channels:
        warnings.warn("Skipping images which don't fit the number of input channels of the model")
        return None

    # Load the model only once we know the image will actually be processed.
    instanseg = torch.jit.load(Path(model_dir) / model_name)
    device = _choose_device(verbose=False)
    instanseg.to(device)

    Augmenter = Augmentations()
    if cell_and_nuclei:
        output_dimension = 2
        Augmenter.integer_channels_labels = [2, 2]
    else:
        output_dimension = 1
        Augmenter.integer_channels_labels = [2]  # helper-internal setting; can be left as is

    # Convert to a tensor with normalisation enabled (normalize=True).
    input_tensor, _ = Augmenter.to_tensor(input_data, normalize=True)
    channel_number = input_tensor.shape[0]

    input_tensor, _ = Augmenter.torch_rescale(input_tensor, labels=None, current_pixel_size=pixel_size, requested_pixel_size=model_pixel_size, crop=False, random_seed=1)

    # Display rendering only: 3-channel inputs are shown as-is, others are colourized
    # with the nucleus channel shown in blue by default.
    if channel_number != 3:
        input_tensor_rendered, _ = Augmenter.colourize(input_tensor, c_nuclei=nuclei_channel, random_seed=1)
    else:
        input_tensor_rendered = input_tensor

    # Tiled inference; may take a few seconds. No gradients are needed.
    with torch.no_grad():
        tiled_labels = sliding_window_inference(input_tensor, instanseg, window_size=(512, 512), overlap_size=0.1, sw_device=device, device='cpu', output_channels=output_dimension)

    # Rescale rendering and labels back to the image's own pixel size. The leading axis is
    # dropped for torch_rescale and restored afterwards.
    input_tensor_rendered, tiled_labels = Augmenter.torch_rescale(input_tensor_rendered, labels=tiled_labels.squeeze(0), current_pixel_size=model_pixel_size, requested_pixel_size=pixel_size, crop=False)
    tiled_labels = tiled_labels.unsqueeze(0)

    # Output location: defined for both the default and a user-supplied output_dir.
    if output_dir is None:
        output_dir = Path(image_path).parent / ("Results_" + model_name)
    else:
        output_dir = Path(output_dir)
    os.makedirs(output_dir, exist_ok=True)
    output_path = output_dir / Path(image_path).stem

    if export_geojson:
        if output_dimension == 1:
            features = labels_to_features(tiled_labels[0, 0].numpy(), object_type="detection")
        else:
            features = (
                labels_to_features(tiled_labels[0, 0].numpy(), object_type="detection", classification="Nuclei")
                + labels_to_features(tiled_labels[0, 1].numpy(), object_type="detection", classification="Cells")
            )
        with open(str(output_path) + "_labels.geojson", "w") as outfile:
            outfile.write(json.dumps(features))

    # 8-bit display image: expected (3, H, W) floats -> clamp to [0, 1] -> (H, W, 3) uint8.
    im_for_display = torch.clamp(input_tensor_rendered, 0, 1).cpu().numpy() * 255
    im_for_display = _move_channel_axis(im_for_display, to_back=True).astype(np.uint8)

    if output_dimension == 1:  # one mask: nucleus *or* cell
        labels_for_display = tiled_labels[0, 0].cpu().numpy()  # presumably (H, W)
        image_overlay = save_image_with_label_overlay(im_for_display, lab=labels_for_display, return_image=True, label_boundary_mode="thick", label_colors=None, thickness=5, alpha=1)
    else:  # nuclei in red, cells in green
        nuclei_labels_for_display = tiled_labels[0, 0].cpu().numpy()  # presumably (H, W)
        cell_labels_for_display = tiled_labels[0, 1].cpu().numpy()  # presumably (H, W)
        image_overlay = save_image_with_label_overlay(im_for_display, lab=nuclei_labels_for_display, return_image=True, label_boundary_mode="inner", label_colors="red", thickness=1)
        image_overlay = save_image_with_label_overlay(image_overlay, lab=cell_labels_for_display, return_image=True, label_boundary_mode="inner", label_colors="green", thickness=1)

    io.imsave(str(output_path) + "_rendered_markup.tif", image_overlay)
    io.imsave(str(output_path) + "_rendered.tif", im_for_display)


In [None]:
from utils.utils import export_to_torchscript

# Export model "1740051" to TorchScript; presumably this produces the
# instanseg_1740051.pt used by the batch cell below -- TODO confirm output location.
export_to_torchscript('1740051')

In [None]:
import os
from pathlib import Path
from tqdm.auto import tqdm
# Batch-process every image in the dataset folder, skipping the "Results_*" output
# folders written by run_instanseg and the Windows "Thumbs.db" thumbnail cache.
# Sorting gives a deterministic order and lets tqdm show a total.
dataset_dir = Path("../../Datasets/Unannotated/QBI_CROPS")
for image in tqdm(sorted(dataset_dir.iterdir())):
    if "Results" in image.name or image.name == "Thumbs.db":
        continue
    run_instanseg("instanseg_1740051.pt", image, pixel_size=None, model_pixel_size=0.5, cell_and_nuclei=True, export_geojson=False, number_of_input_channels=3)



In [None]:
from pathlib import Path

# Single-image example. The file is looked up in the current user's Downloads folder
# instead of a hard-coded absolute home directory, so the cell is portable.
example_image = Path.home() / "Downloads" / "LuCa-7color_[17572,60173]_3x3component_data.tif"
run_instanseg("instanseg_v0_2_0.pt", str(example_image), pixel_size=0.5, model_pixel_size=0.5, cell_and_nuclei=True, export_geojson=True)