<a href="https://colab.research.google.com/github/AntonioVispi/Recursive_Segmentation/blob/main/Inference_Script.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -U -q segmentation-models
!pip install tensorflow==2.9.3
!pip install h5py==2.10.0
!pip install plotly==5.3.1

In [None]:
# Mount Google Drive so the /content/drive/... paths used further down resolve.
from google.colab import drive
drive.mount(mountpoint='/content/drive')

In [None]:
import numpy as np
import os

import nibabel as nib
import tensorflow as tf

import matplotlib.pyplot as plt
from matplotlib import pyplot as plt
from tqdm import tqdm

import skimage
from skimage.io import imread, imshow, imsave
from skimage.transform import resize

from tensorflow import keras
from keras.callbacks import ModelCheckpoint
from keras.callbacks import CSVLogger
from keras.callbacks import EarlyStopping
from keras.utils.np_utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.models import load_model
from keras import metrics
from keras.callbacks import ReduceLROnPlateau

from segmentation_models import Unet
import math
from math import floor

import random
from random import seed
from random import random

In [None]:
def visualizer(segm, IMG_HEIGHT, IMG_WIDTH):
    """Render a 4-channel one-hot segmentation as an RGB image for plotting.

    Parameters
    ----------
    segm : array-like
        One-hot segmentation whose last axis holds 4 class channels
        (presumably background, kidney, tumour, cyst, in that order — the
        channel meaning comes from the training labels). It is resampled to
        (IMG_HEIGHT, IMG_WIDTH, 4) before colouring.
    IMG_HEIGHT, IMG_WIDTH : int
        Size of the returned image in pixels.

    Returns
    -------
    numpy.ndarray
        Float array of shape (IMG_HEIGHT, IMG_WIDTH, 3). Channel 0 pixels are
        red, channel 1 green, channel 2 blue, and channel 3 yellow. Pixels
        where no channel equals exactly 1 stay black.
    """
    # Nearest-neighbour resampling (order=0) without anti-aliasing keeps the
    # one-hot values exactly 0/1. The default linear interpolation (and the
    # Gaussian smoothing applied when downsampling) would create fractional
    # values at class borders; those fail the `== 1` test below and would show
    # up as black gaps.
    segm = resize(segm, (IMG_HEIGHT, IMG_WIDTH, 4), order=0, mode='constant',
                  preserve_range=True, anti_aliasing=False)

    # RGB colour per channel. Colours are applied in channel order, so a later
    # channel wins wherever two masks overlap.
    channel_colours = (
        (1, 0, 0),  # channel 0
        (0, 1, 0),  # channel 1
        (0, 0, 1),  # channel 2
        (1, 1, 0),  # channel 3
    )

    all_segments = np.zeros((IMG_HEIGHT, IMG_WIDTH, 3))
    for channel, colour in enumerate(channel_colours):
        all_segments[segm[:, :, channel] == 1] = colour
    return all_segments


def _list_volume_files(dataset_path):
    """Return the sorted names of the regular, non-hidden files in dataset_path."""
    # Skipping hidden entries and sub-directories avoids nib.load crashing on
    # things like `.ipynb_checkpoints` or `.DS_Store`.
    return sorted(
        name for name in os.listdir(dataset_path)
        if not name.startswith('.')
        and os.path.isfile(os.path.join(dataset_path, name))
    )


def _predict_volume(model, volume_data, img_height, img_width, img_channels, num_classes):
    """Segment every slice along axis 0 of a 3-D volume, one slice at a time.

    Returns ``(X, Y_automatic)``. ``X`` holds the resized network inputs, uint8,
    shape (n_slice, img_height, img_width, img_channels). ``Y_automatic`` holds
    the argmax class labels, uint8, shape (n_slice, img_height, img_width).
    """
    n_slice = volume_data.shape[0]
    X = np.zeros((n_slice, img_height, img_width, img_channels), dtype=np.uint8)
    Y_automatic = np.zeros((n_slice, img_height, img_width), dtype=np.uint8)

    for k in range(n_slice):
        # Resizing to (H, W, 1) adds a trailing channel axis. Assigning it into
        # X broadcasts the single grey channel over all img_channels channels.
        # NOTE(review): the int16 intensities are stored in a uint8 array, so
        # values outside 0..255 do not survive unchanged. Presumably this
        # mirrors the training preprocessing — confirm before changing it.
        X[k] = resize(volume_data[k], (img_height, img_width, 1),
                      mode='constant', preserve_range=True)

        # verbose=0: a progress bar per slice would flood the notebook output.
        softmax = model.predict(X[k][np.newaxis], verbose=0)
        softmax = np.reshape(softmax, (img_height, img_width, num_classes))
        Y_automatic[k] = np.argmax(softmax, axis=2).astype(np.uint8)

    return X, Y_automatic


def _to_native_grid(Y_automatic, height, width):
    """Resample label slices back to the source volume's (height, width) size."""
    if Y_automatic.shape[1:] == (height, width):
        return Y_automatic
    native = np.zeros((Y_automatic.shape[0], height, width), dtype=np.uint8)
    for k in range(Y_automatic.shape[0]):
        # Nearest-neighbour without anti-aliasing keeps labels as exact integers.
        native[k] = resize(Y_automatic[k], (height, width), order=0, mode='constant',
                           preserve_range=True, anti_aliasing=False)
    return native


def _show_predictions(X, Y_automatic, img_height, img_width, num_classes):
    """Plot every input slice next to its colour-coded predicted segmentation."""
    for n in range(X.shape[0]):
        fig = plt.figure(figsize=(10, 10))

        ax1 = fig.add_subplot(121)
        ax1.imshow(X[n])
        ax1.set_title('TAC Image')

        ax2 = fig.add_subplot(122)
        support = to_categorical(Y_automatic[n], num_classes=num_classes, dtype='float32')
        ax2.imshow(visualizer(support, img_height, img_width))
        ax2.set_title('Automatic Segmentation')

        plt.show()
        # Release each figure; one figure per slice would otherwise accumulate.
        plt.close(fig)


def Test(dataset_path, model_path, output_path, view_dataset, save_prediction):
    """Segment every volume in ``dataset_path`` with a saved Keras model.

    Parameters
    ----------
    dataset_path : str
        Directory whose regular, non-hidden files are loaded with nibabel.
        Each must be 3-D; axis 0 is treated as the slice axis.
    model_path : str
        Model file passed to ``load_model`` (loaded with ``compile=False``).
    output_path : str
        Directory for saved predictions; created if missing.
    view_dataset : str
        ``'on'`` plots every slice of every volume next to its prediction.
        Any other value disables plotting.
    save_prediction : str
        ``'on'`` writes each predicted label volume as a NIfTI file in
        ``output_path``. The file keeps the input filename and affine and is
        resampled back to the input's in-plane size. Any other value
        disables saving.

    Returns
    -------
    None
    """
    IMG_HEIGHT = 512
    IMG_WIDTH = 512
    IMG_CHANNELS = 3
    NUM_CLASSES = 4

    model = load_model(model_path, custom_objects=None, compile=False)

    for name in _list_volume_files(dataset_path):
        volume = nib.load(os.path.join(dataset_path, name))
        print(volume.shape)
        # Unpacking also enforces that the volume is 3-D.
        _, height, width = volume.shape
        volume_data = volume.get_fdata().astype(np.int16)

        X, Y_automatic = _predict_volume(model, volume_data, IMG_HEIGHT, IMG_WIDTH,
                                         IMG_CHANNELS, NUM_CLASSES)

        if save_prediction == 'on':
            os.makedirs(output_path, exist_ok=True)
            # The saved labels must share the source voxel grid; otherwise the
            # reused affine would describe the wrong geometry.
            Y_native = _to_native_grid(Y_automatic, height, width)
            Y_automatic_nii = nib.Nifti1Image(Y_native, affine=volume.affine)
            nib.save(Y_automatic_nii, os.path.join(output_path, name))
            print('Subject saved successfully')

        # Viewing runs per case. Previously it ran after the loop, so only the
        # last volume was shown, and an empty directory raised NameError.
        if view_dataset == 'on':
            _show_predictions(X, Y_automatic, IMG_HEIGHT, IMG_WIDTH, NUM_CLASSES)




In [None]:
# Paths consumed by the inference call below — edit them for your setup.
# NOTE(review): Colab usually exposes Drive files under /content/drive/MyDrive/;
# confirm these placeholders point inside the mounted tree.
dataset_path = '/content/drive/your_dataset_path'    # directory of input volumes
output_path = '/content/drive/output_training_path'  # where predictions are written
model_path = '/content/drive/output_training_path/Epoch_15-Val_Loss0.00022.h5'  # trained Keras model file

```
Test(dataset_path, model_path, output_path, view_dataset, save_prediction)

```
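*   Set `view_dataset = 'on'` to display the CT slices alongside their predicted segmentations.
*   Set `save_prediction = 'on'` to save the predicted masks as NIfTI files in `output_path`, using the same filename as each input volume.
*   Any other value (e.g. `'off'`) disables the corresponding option.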

In [None]:
# Run inference: save every prediction to output_path, skip the slice-by-slice plots.
Test(dataset_path, model_path, output_path,
     view_dataset='off',
     save_prediction='on')