In [None]:
# Automatically reload imported modules before each cell executes, so edits to the
# project's Python files are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

# Make the parent directory importable (presumably the project root, so that the
# `training.*` and `utils.*` imports below resolve). Guarded so that re-running
# this cell does not keep appending duplicate entries to sys.path.
if '..' not in sys.path:
    sys.path.append('..')

In [None]:
import skimage
import matplotlib.pyplot as plt
import torch
import torchio as tio
import argparse
import yaml
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F

from training.segmentation_module import BinarySegmentation
from utils.inference_utils import output_name, load_single_slice
from utils.config_utils import read_transforms

In [None]:
import pylab
# Larger default figure size for every plot rendered in this notebook.
pylab.rcParams.update({'figure.figsize': (15.0, 12.0)})

In [None]:
model_id = 'yurwvjn0'

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
model_path = "../trunk-segmentation/{}/checkpoints/last.ckpt".format(model_id)
data_path = "../data/scan.tif"
axis = 0
idx = 804

In [None]:
model = BinarySegmentation.load_from_checkpoint(model_path).to(device)

In [None]:
config = model.config

In [None]:
general_transforms = read_transforms(config['transforms']['general'])

In [None]:
input_image = load_single_slice(data_path, idx, axis, as_numpy=True)

In [None]:
skimage.io.imshow(input_image)

In [None]:
transformed_image = general_transforms(image=input_image)['image']

In [None]:
transformed_image.shape

In [None]:
def crop(image, mask=None, padding=10):
    """Crop a 2-D image (and optionally its mask) to the bounding box of its content.

    A row/column counts as content when its mean intensity is strictly greater
    than the rounded minimum of all row/column means. The resulting box is
    widened by ``padding`` pixels on every side and clipped to the image bounds.

    Args:
        image: 2-D array indexed as ``image[row, col]``.
        mask: Optional array sliced with the same ``[rows, cols]`` window
            (expected to share the height and width of ``image``).
        padding: Extra pixels kept around the detected box on each side.

    Returns:
        The cropped image, or ``(cropped_image, cropped_mask)`` when ``mask``
        is given. Along an axis where no row/column passes the threshold, the
        full extent of that axis is kept.
    """

    def _window(means, size):
        # Foreground indices along one axis -> half-open [start, stop) slice bounds.
        selected = np.flatnonzero(means > np.round(np.min(means)))
        if selected.size == 0:
            return 0, size
        # Clamp at 0: a negative start would wrap around under NumPy slicing
        # and silently select the wrong (possibly empty) region.
        start = max(int(selected[0]) - padding, 0)
        # +1 because slice ends are exclusive; without it the last foreground
        # row/column would be dropped and the padding would be asymmetric.
        stop = min(int(selected[-1]) + padding + 1, size)
        return start, stop

    min_y, max_y = _window(np.mean(image, axis=1), image.shape[0])  # per-row means
    min_x, max_x = _window(np.mean(image, axis=0), image.shape[1])  # per-column means
    if mask is not None:
        return image[min_y:max_y, min_x:max_x], mask[min_y:max_y, min_x:max_x]
    return image[min_y:max_y, min_x:max_x]

In [None]:
skimage.io.imshow(crop(transformed_image))

In [None]:
image_tensor = torch.from_numpy(transformed_image).to(device)

In [None]:
scalar_input_image =tio.ScalarImage(tensor=image_tensor.unsqueeze(0).unsqueeze(-1))

In [None]:
# Tile the slice into (128, 128, 1) patches for patch-wise inference.
# NOTE: the third GridSampler parameter is the patch *overlap*, not a stride:
# neighbouring patches overlap by 32 pixels along each in-plane axis (0 along the
# singleton axis). A config "stride" value therefore must not be passed here as-is.
# TODO: read patch size / overlap from the config, and choose which axis to tile
# when passing a full volume instead of a single slice.
grid_sampler = tio.inference.GridSampler(
    tio.Subject(one_image=scalar_input_image),
    patch_size=(128, 128, 1),
    patch_overlap=(32, 32, 0),
)
patch_loader = torch.utils.data.DataLoader(grid_sampler, batch_size=config['batch_size'])
# Overlapping patch predictions are blended by averaging when stitched back together.
aggregator = tio.inference.GridAggregator(grid_sampler, overlap_mode='average')

In [None]:
# Run the model on every grid patch and hand the outputs to the aggregator,
# which stitches them back into a full-size map.
model.eval()
with torch.no_grad():
    for patches_batch in tqdm(patch_loader, desc='Predicting'):
        # Cast to float32 directly on the target device. The previous
        # `.type(torch.FloatTensor)` produced a CPU tensor, forcing a
        # GPU -> CPU -> GPU round trip for every batch when running on CUDA.
        input_tensor = patches_batch['one_image'][tio.DATA].to(device=device, dtype=torch.float32)
        locations = patches_batch[tio.LOCATION]
        # Drop the trailing singleton patch axis before calling the 2-D model.
        logits = model(input_tensor.squeeze(-1))
        # Re-add channel and trailing axes for the aggregator; this assumes the model
        # returns a (B, H, W) map — TODO confirm.
        aggregator.add_batch(logits.unsqueeze(1).unsqueeze(-1), locations)

In [None]:
# Stitched full-size output; overlapping regions hold the mean of the patch predictions
# (overlap_mode='average'). Append `.round()` for a hard 0/1 map — only meaningful if the
# model outputs probabilities rather than raw logits (TODO confirm).
output_tensor = aggregator.get_output_tensor()

In [None]:
def visualize_mask(img, mask, color=(255, 0, 0), alpha=0.5):
    """Overlay ``mask`` on a grayscale image by tinting it with ``color``.

    The image is converted grayscale -> RGB -> HSV. Its hue is replaced by the
    hue of ``color`` and its saturation by ``color``'s saturation multiplied by
    ``mask``; the value (brightness) channel is left unchanged. Pixels where
    ``mask`` is 0 therefore stay gray, and pixels where it is 1 take the full colour.

    Args:
        img: Grayscale image accepted by ``skimage.color.gray2rgb``.
        mask: Per-pixel weights broadcastable to ``img``'s height and width.
            Values outside [0, 1] yield saturations outside the valid HSV range.
        color: RGB overlay colour; only its hue and saturation are used.
        alpha: Accepted for interface compatibility but currently unused.

    Returns:
        The tinted RGB image produced by ``skimage.color.hsv2rgb``.
    """
    # credits: https://stackoverflow.com/questions/9193603/applying-a-coloured-overlay-to-an-image-in-either-pil-or-imagemagik
    img_hsv = skimage.color.rgb2hsv(skimage.color.gray2rgb(img))

    # The overlay colour is uniform, so converting a single 1x1 pixel of the
    # caller's `color` is enough; no full-size colour image is needed.
    color_hsv = skimage.color.rgb2hsv(np.asarray([[color]], dtype=float))[0, 0]

    # Replace the hue and saturation of the original image with those of the colour.
    img_hsv[..., 0] = color_hsv[0]
    img_hsv[..., 1] = color_hsv[1] * mask

    return skimage.color.hsv2rgb(img_hsv)

In [None]:
mask = output_tensor.squeeze(0).squeeze(-1).cpu().numpy()

In [None]:
skimage.io.imshow(visualize_mask(transformed_image, mask))

In [None]:
pylab.rcParams['figure.figsize'] = (15.0, 12.0) 
img_cropped, mask_cropped = crop(transformed_image, mask)
skimage.io.imshow(visualize_mask(img_cropped, mask_cropped))