In [12]:
import os
import importlib
import numpy as np
import matplotlib.pyplot as plt
from utils.vedo import plot_slicer_cloud, plot_volume_cloud, plot_two_volumes

import utils.nifti
importlib.reload(utils.nifti)
from utils.nifti import estimate_volume

# neural imaging
import nibabel as nib

# tensorflow
import tensorflow as tf
from evaluation.evaluation import *

# Make numpy printouts easier to read.
np.set_printoptions(precision=3, suppress=True)
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

Num GPUs Available:  1


casi che quasi tutti non riescono a predirre bene
- TBI_PTE_fm_20_08_3w_N4, c'è un vuoto nero che non prende come cervello
- TBI_gv_20_42_N4 e TBI_gv_20_55_N4, uno è riflesso al contrario e l'altro è tutto deforme. Entrambi sono più scuri

In [32]:
from vedo import Volume, show, settings
import vedo
vedo.settings.default_backend= 'vtkplotter'
from vedo import *
from vedo.applications import Slicer3DPlotter

def plot_vol(nii_img, nii_mask, spacing=[1, 1, 1]):
    '''
    Plot a 3D image and a point cloud of the lesion mask in the same plot.
    '''

    plt = Plotter()

    v2 = Volume(nii_img)

    # Set voxel dimension in mm
    v2.spacing(spacing)
    v2.cmap('bone')

    plt += v2

    # Define class labels based on the mask values
    class_labels = np.unique(nii_mask)
    class_len = len(class_labels)
    print('Class labels:', class_labels)
    print('Number of classes:', class_len)
    
    # make a color map that goes from red to blue for each class
    color_map = np.zeros((class_len, 3))
    color_map[:, 0] = np.linspace(1, 0, class_len)
    color_map[:, 2] = np.linspace(0, 1, class_len)
    color_map[:, 1] = np.linspace(0, 0, class_len)
    
    

    # Create an empty dictionary to store point clouds
    point_clouds = {}

    i=0
    for label in class_labels:
        # Get voxel coordinates
        voxel_coords = np.array(np.where(nii_mask == label)).T  * v2.spacing()
        pts = Points(voxel_coords, r=3, c=color_map[i], alpha=0.5)
        # Multiply by voxel dimension to get coordinates in mm

        # Store the point cloud in the dictionary
        point_clouds[label] = pts
        i+=1


    for label in class_labels:
        plt += point_clouds[label]

    plt.add_slider(
        lambda w, e: [point_clouds[label].alpha(w.value) for label in class_labels],
        xmin=0,
        xmax=1,
        value=0.5,
        pos="bottom-right-vertical",
        title="Opacity",
    )

    plt.add_slider(
        lambda w, e: v2.alpha([0, w.value]),
        xmin=0,
        xmax=1,
        value=0.1,
        pos="bottom-left-vertical",
        title="Opacity",
    )

    # set initial alpha of volume 2 to 0.1
    v2.alpha([0, 0.01])

    # Make the plot rotate automatically
    plt += Text2D("Press q to exit", pos=(0.8, 0.05), s=0.8)

    return plt.show(viewup='z').close()

In [21]:
patch_size = (76,76,76)
overlap = (8,8,8)


# Call the preprocessing function, must be the same of preprocessing

# declare random input volume of size 128x128x128
input_volume = np.random.rand(150,150,111)
input_volume = np.random.rand(150,150,111)
# calculate the number of patches in each dimension
num_patches = np.ceil((np.array(input_volume.shape) - np.array(patch_size)) / (np.array(patch_size) - np.array(overlap))+ 1).astype(int)
print('Num patches: ', num_patches)

predicted_patches = np.zeros_like(input_volume)

count = 1
# for each patch make a prediction
for i in range(num_patches[0]):
    for j in range(num_patches[1]):
        for k in range(num_patches[2]):
            start_i = i* (patch_size[0] - overlap[0])
            start_j = j* (patch_size[1] - overlap[1])
            start_k = k* (patch_size[2] - overlap[2])

            # extract the current patch from the input volume
            patch = input_volume[start_i:start_i + patch_size[0],
                                    start_j:start_j + patch_size[1],
                                    start_k:start_k + patch_size[2]]
            
            # perform inference on the patch
            predicted_patch = count
            predicted_patches[start_i:start_i + patch_size[0],
                                start_j:start_j + patch_size[1],
                                start_k:start_k + patch_size[2]] = predicted_patch
            count += 1

# cap the values to 1
#predicted_patches[predicted_patches > 1] = 1

# plot the input volume and the predicted volume
plot_vol(input_volume, predicted_patches)

Num patches:  [3 3 2]
Class labels: [ 1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13. 14. 15. 16. 17. 18.]
Number of classes: 18


In [34]:
def sliding_window_inference(data, patch_size, stride, threshold=0.5):
    depth, height, width = data.shape
    predictions = np.zeros((depth, height, width))
    count_map = np.zeros((depth, height, width))

    count = 1
    # Calculate the number of steps required to cover each dimension
    steps_d = max(1, (depth - patch_size[0]) // stride + 1)
    steps_h = max(1, (height - patch_size[1]) // stride + 1)
    steps_w = max(1, (width - patch_size[2]) // stride + 1)

    for d in range(0, steps_d * stride, stride):
        for h in range(0, steps_h * stride, stride):
            for w in range(0, steps_w * stride, stride):
                patch = data[d:min(d+patch_size[0], depth), h:min(h+patch_size[1], height), w:min(w+patch_size[2], width)]
                prediction = count
                predictions[d:min(d+patch_size[0], depth), h:min(h+patch_size[1], height), w:min(w+patch_size[2], width)] += 1
                count_map[d:min(d+patch_size[0], depth), h:min(h+patch_size[1], height), w:min(w+patch_size[2], width)] += 1



    # Avoid division by zero
    # count_map[count_map == 0] = 1

    # Average predictions
    # predictions /= count_map

    # Apply threshold for binary segmentation
    # binary_predictions = (predictions >= threshold).astype(np.uint8)
    plot_vol(data, predictions)
sliding_window_inference(input_volume, patch_size=(76,76,76), stride=8)

Class labels: [  0.   1.   2.   3.   4.   5.   6.   7.   8.   9.  10.  12.  14.  15.
  16.  18.  20.  21.  24.  25.  27.  28.  30.  32.  35.  36.  40.  42.
  45.  48.  49.  50.  54.  56.  60.  63.  64.  70.  72.  75.  80.  81.
  84.  90.  96.  98. 100. 105. 108. 112. 120. 125. 126. 128. 135. 140.
 144. 147. 150. 160. 162. 168. 175. 180. 189. 192. 196. 200. 210. 216.
 224. 225. 240. 243. 245. 250. 252. 256. 270. 280. 288. 300. 315. 320.
 324. 350. 360. 400. 405. 450. 500.]
Number of classes: 91
