<a href="https://colab.research.google.com/github/DerekGloudemans/segmentation-medical-images/blob/master/Generate_Final_Results.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
# Mount drive
from google.colab import drive
drive.mount('/content/drive',force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [6]:
#%%capture 
!pip install -q --upgrade ipython==5.5.0
!pip install -q --upgrade ipykernel==4.6.0
!pip3 install torchvision
!pip3 install opencv-python

import ipywidgets
import traitlets
# imports

# this seems to be a popular thing to do so I've done it here
#from __future__ import print_function, division


# torch and specific torch packages for convenience
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils import data
from torch import multiprocessing
from google.colab.patches import cv2_imshow

# for convenient data loading, image representation and dataset management
from torchvision import models, transforms
import torchvision.transforms.functional as FT
from PIL import Image, ImageFile, ImageStat
ImageFile.LOAD_TRUNCATED_IMAGES = True
from scipy.ndimage import affine_transform
import cv2

# always good to have
import time
import os
import numpy as np    
import _pickle as pickle
import random
import copy
import matplotlib.pyplot as plt
import math

import nibabel as nib



[?25l[K     |███▏                            | 10kB 36.4MB/s eta 0:00:01[K     |██████▎                         | 20kB 2.1MB/s eta 0:00:01[K     |█████████▍                      | 30kB 3.0MB/s eta 0:00:01[K     |████████████▋                   | 40kB 2.0MB/s eta 0:00:01[K     |███████████████▊                | 51kB 2.5MB/s eta 0:00:01[K     |██████████████████▉             | 61kB 3.0MB/s eta 0:00:01[K     |██████████████████████          | 71kB 3.4MB/s eta 0:00:01[K     |█████████████████████████▏      | 81kB 3.9MB/s eta 0:00:01[K     |████████████████████████████▎   | 92kB 4.3MB/s eta 0:00:01[K     |███████████████████████████████▍| 102kB 3.3MB/s eta 0:00:01[K     |████████████████████████████████| 112kB 3.3MB/s 


In [0]:
def track_segment(data,pred,model):
  "Note that pred must be integer, 0 or 1"

  num_slices = data.shape[2]
  result = torch.zeros(data.shape)

  for idx in range(num_slices):
    # resize to 256 x 256
    slice = data[:,:,idx]
    
    # get slice of 
    if idx == 0:
      prev_slice = torch.zeros(slice.shape)
    else:
      prev_slice = pred[:,:,idx-1]

    original_shape = slice.shape

    slice =  Image.fromarray(slice.data.numpy()).copy()
    slice = FT.to_grayscale(slice)
    slice = FT.to_tensor(slice)
    slice = slice.unsqueeze(0)
    prev_slice = prev_slice.unsqueeze(0).unsqueeze(0)

    slice = F.interpolate(slice,size = [256,256],mode = 'bilinear')
    prev_slice = F.interpolate(prev_slice, size = [256,256],mode = 'nearest')

    x = torch.cat((prev_slice,slice),dim = 1)
    x = x.to(device).float()

    out_slice = model(x)
    out_slice = F.interpolate(out_slice,original_shape,mode = 'nearest')
    result[:,:,idx] = out_slice.data.cpu()

    if False:
      plt.figure(figsize = (5,15))
      plt.subplot(132)
      plt.imshow(prev_slice[0][0],cmap = "gray")
      plt.clim(0,1)
      plt.subplot(131)
      plt.imshow(slice[0][0],cmap = "gray")
      plt.clim(0,1)
      plt.subplot(133)
      plt.imshow(result[:,:,idx],cmap = "gray")
      plt.clim(0,1)
      plt.show()

  return result

In [0]:
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=3,
        wf=4,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
    ):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597
        Using the default arguments will yield the exact version used
        in the original paper
        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

        innum = 1000
        midnum = 100
        outnum = 4*n_classes
        self.reg = nn.Sequential(
            nn.BatchNorm1d(innum),
            nn.Linear(innum,midnum),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(midnum,outnum),
            nn.Sigmoid()
        )

        #for param in self.parameters():
        #  param.requires_grad = True

    def forward(self, x,BBOX = False):
        blocks = []
        
        # encoder
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        # do bbox regression here
        if BBOX:
          x_reg = x.view(-1)
          x_reg = self.reg(x_reg)
          bboxes = x_reg.view(4,-1)

        # decoder
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        #CHANGE THIS LINE FOR MULTIPLE OUTPUT CHANNELS
        # apply per_class last layer and per-class Softmax 
        #x = nn.Softmax2d(self.last(x)) 
        x = torch.sigmoid(self.last(x))
        
        if BBOX:
          return x, bboxes
        else:
          return x


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out


class UNet2(nn.Module):
    def __init__(
        self,
        in_channels=2,
        n_classes=1,
        depth=3,
        wf=4,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
    ):
        """
        Implementation of
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597
        Using the default arguments will yield the exact version used
        in the original paper
        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): depth of the network
            wf (int): number of filters in the first layer is 2**wf
            padding (bool): if True, apply padding such that the input shape
                            is the same as the output.
                            This may introduce artifacts
            batch_norm (bool): Use BatchNorm after layers with an
                               activation function
            up_mode (str): one of 'upconv' or 'upsample'.
                           'upconv' will use transposed convolutions for
                           learned upsampling.
                           'upsample' will use bilinear upsampling.
        """
        super(UNet2, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

        innum = 1000
        midnum = 100
        outnum = 4*n_classes
        self.reg = nn.Sequential(
            nn.BatchNorm1d(innum),
            nn.Linear(innum,midnum),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(midnum,outnum),
            nn.Sigmoid()
        )

        #for param in self.parameters():
        #  param.requires_grad = True

    def forward(self, x,BBOX = False):
        blocks = []
        
        # encoder
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        # do bbox regression here
        if BBOX:
          x_reg = x.view(-1)
          x_reg = self.reg(x_reg)
          bboxes = x_reg.view(4,-1)

        # decoder
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        #CHANGE THIS LINE FOR MULTIPLE OUTPUT CHANNELS
        # apply per_class last layer and per-class Softmax 
        #x = nn.Softmax2d(self.last(x)) 
        x = torch.sigmoid(self.last(x))
        
        if BBOX:
          return x, bboxes
        else:
          return x


class UNetConvBlock(nn.Module):
    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        block = []

        block.append(nn.Conv2d(in_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        block.append(nn.Conv2d(out_size, out_size, kernel_size=3, padding=int(padding)))
        block.append(nn.ReLU())
        if batch_norm:
            block.append(nn.BatchNorm2d(out_size))

        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = self.block(x)
        return out


class UNetUpBlock(nn.Module):
    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                nn.Upsample(mode='bilinear', scale_factor=2),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )

        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        out = self.conv_block(out)

        return out

class Nifti_Dataset(data.Dataset):
  def __init__(self,mode = "view",dim = 2,class_id = 1):
    """
    Save the last 0.15 proportion of files after sorting for use as validation set.
    Loads each slice of the input images as a separate image
    mode - view,train or val, defined in same dataset to maintain data separation
      view - performs transforms but does not normalize images
      train - normalizes data and performs transforms
      val - normalizes data, no augmenting transforms
    dim - specifies dimension along which to slice image
    """

    self.mode = mode
    self.dim = dim
    self.class_id = class_id
  
    data_dir = "/content/drive/My Drive/Colab Notebooks/Segmentation/RawData/Training/img"
    label_dir = "/content/drive/My Drive/Colab Notebooks/Segmentation/RawData/Training/label"

    # get all data and label file names
    self.data_files = []
    for file in os.listdir(data_dir):
      self.data_files.append(os.path.join(data_dir,file))
    self.data_files.sort()

    self.label_files = []
    for file in os.listdir(label_dir):
      self.label_files.append(os.path.join(label_dir,file))
    self.label_files.sort()

    # for each data_file
    self.train_data = []
    self.val_data = []

    for i in range(len(self.data_files)):
      data = nib.load(self.data_files[i])
      data = np.array(data.get_fdata())

      label = nib.load(self.label_files[i])
      label = np.array(label.get_fdata()).astype(float)

      identifier = self.data_files[i].split("_")[0]
      for slice in range(0,data.shape[dim]):

        # get slices
        if dim == 0:
          data_slice = data[slice,:,:]
          label_slice = label[slice,:,:]
        elif dim == 1:
          data_slice = data[:,slice,:]
          label_slice = label[:,slice,:]
        elif dim == 2:
          data_slice = data[:,:,slice]
          label_slice = label[:,:,slice]

        mean,std = np.mean(data_slice),np.std(data_slice)
        # define item dict to store info
        item = {
            "identifier":identifier,
            "slice":slice,
            "data":data_slice,
            "label":label_slice,
            "mean":mean,
            "std":std
            }

        # check to make sure this example actually has organs in it
        test = np.bincount(label_slice.astype(int).reshape(-1))
        if len(test) == 1: 
          continue

        # make sure 50% of examples have organ of interest in them
        if class_id not in np.unique(label_slice.astype(int).reshape(-1)):
            if np.random.rand() > 0.5:
              continue

        # assign to either training or validation data
        if i < len(self.data_files) * 0.85:
          self.train_data.append(item)
        else:
          self.val_data.append(item)

      #break # to shorten loading time

    # define some transforms for training dataset
    self.train_transforms = transforms.Compose([
          transforms.ColorJitter(brightness = 0.2,contrast = 0.2,saturation = 0.1),
          transforms.ToTensor(),
          transforms.RandomErasing(p=0.0015, scale=(0.4, 0.6), ratio=(0.3, 3.3), value=0, inplace=False), # big
          transforms.RandomErasing(p=0.003, scale=(0.1, 0.3), ratio=(0.3, 3.3), value=0, inplace=False), # medium
          transforms.RandomErasing(p=0.004, scale=(0.05, 0.15), ratio=(0.3, 3.3), value=0, inplace=False),# small
          transforms.RandomErasing(p=0.0035, scale=(0.05, 0.2), ratio=(0.3, 3.3), value=0, inplace=False) # small
          

        ])

  def __getitem__(self,index):
      #get relevant dictionary from self memory

      if self.mode in ['train','view']:
        item = self.train_data[index]
      else:
        item = self.val_data[index]

      x = Image.fromarray(item['data']).copy()
      y = Image.fromarray(item['label']).copy()

      # to grayscale
      x  = FT.to_grayscale(x)
      y = FT.to_grayscale(y)

      if self.mode in ['train','view']:
        # randomly flip and rotate both
        FLIP = 0 #np.random.rand()
        if FLIP > 0.5:
          x = FT.hflip(x)
          y = FT.hflip(y)

        ROTATE = 0 #np.random.rand()*60 - 30
        x  = x.rotate(ROTATE)
        y = y.rotate(ROTATE,Image.NEAREST)

      # resize to 224 on shorter dimension
      x = FT.resize(x, 256)
      y = FT.resize(y,256,Image.NEAREST)

      if self.mode in ['train','view']:
        # randomly jitter color of data and randomly erase data
        x = self.train_transforms(x)
      # to tensor
      try:
        x = FT.to_tensor(x)
      except:
        pass
      y = FT.to_tensor(y)
            
      # normalize and repeat along color dimension if in train or val mode
      if self.mode in ['train','val']:
        #x = FT.normalize(x,[item['mean']],[item['std']])
        #x = x.repeat(3,1,1)
        pass

      return x,y

  def __len__(self):
    if self.mode in ["train","view"]:
      return len(self.train_data)
    else:
      return len(self.val_data)

  def show(self,index):
    data,label = self[index]
    plt.figure()
    plt.subplot(121)
    data = data.detach()
    plt.imshow(data[0],cmap = "gray")

    plt.subplot(122)
    plt.imshow(label[0],cmap = "gray")
    plt.show()
    # convert each tensor to numpy array

  def show_slices(self,idx = 0,dim = 0,organ_id = None):
    """
    A nice utility function for plotting all of the slices along a given dimension
    idx - indexes all NIfTI images in dataset
    dim - indexes dimension of image
    organ_id - if not None, all other organs removed from label
    """
    data = nib.load(self.data_files[idx])
    label = nib.load(self.label_files[idx])

    data = data.get_fdata()
    data = np.array(data)
    label = label.get_fdata()
    label = np.array(label)

    for slice in range(0,data.shape[dim]):
      if dim == 0:
            data_slice = data[slice,:,:]
            label_slice = label[slice,:,:]
      elif dim == 1:
        data_slice = data[:,slice,:]
        label_slice = label[:,slice,:]
      elif dim == 2:
        data_slice = data[:,:,slice]
        label_slice = label[:,:,slice]

      if organ_id is not None:
        # if a specific label is to be looked at, 0 all others
        label_slice = 1.0 - np.ceil(np.abs(label_slice.astype(float)-organ_id)/15.0)

      print(np.unique(label_slice))
      plt.figure()
      plt.subplot(121)
      plt.imshow(data_slice,cmap = "gray")

      plt.subplot(122)
      plt.imshow(label_slice,cmap = "gray")
      plt.show()

  def len_3d(self):
    return len(self.data_files)

  def get_3d_array(self,idx):
      """
      Loads a 3D image as a tensor as well as its label, mode, and file name
      """
      assert idx < len(self.data_files) , "3D image index out of range, there are {} 3D images".format(len(self.data_files))

      # load data and label as tensors
      data = nib.load(self.data_files[idx])
      label = nib.load(self.label_files[idx])

      data = data.get_fdata()
      data = torch.from_numpy(np.array(data))
      label = label.get_fdata()
      label = torch.from_numpy(np.array(label)).int()

      # note whether image is training or validation set
      if idx < 0.85 * len(self.data_files):
        mode = "train"
      else:
        mode = "val"
      
      return data,label,mode,self.data_files[idx]

def load_model(checkpoint_file,model,optimizer):
  """
  Reloads a checkpoint, loading the model and optimizer state_dicts and 
  setting the start epoch
  """
  checkpoint = torch.load(checkpoint_file)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  all_losses = checkpoint['losses']
  all_accs = checkpoint['accs']

  return model,optimizer,epoch,all_losses,all_accs


def dt_refine(orig,tree):
  """
  Applies decision tree ensembling to original outputs
  orig = classes x l x w x h
  """
  output = torch.zeros((orig.shape[1],orig.shape[2],orig.shape[3]))

  for i in range(0,orig.shape[1]):
    for j in range(0,orig.shape[2]):
      
        inp = np.zeros([orig.shape[3],17])
        inp[:,:14] = orig[:,i,j,:].data.numpy().transpose()
        inp[:,14] = i
        inp[:,15] = j
        for k in range(0,orig.shape[3]):
          inp[k,16] = k

        output[i,j,:] = torch.from_numpy(tree.predict(inp))
  
  return output


def segment(model,data,device,axis = 2,outfile = None):
  """
  Takes in 3D tensor, slices and segments using UNET, and returns result
  Note that result will be a [0,1] tensor corresponding to a certain class (whatever model class was used)
  if outfile is not none, saves results in  that file 
  """

  num_slices = data.shape[axis]
  result = torch.zeros(data.shape)

  for idx in range(num_slices):
    # resize to 256 x 256
    if axis == 0:
      slice = data[idx,:,:]
    elif axis == 1:
      slice = data[:,idx,:]
    elif axis == 2:
      slice = data[:,:,idx]

    original_shape = slice.shape

    slice =  Image.fromarray(slice.data.numpy()).copy()
    slice = FT.to_grayscale(slice)
    slice = FT.to_tensor(slice)
    slice = slice.unsqueeze(0)

    slice = F.interpolate(slice,size = [256,256],mode = 'bilinear')

    x = slice.to(device).float()

    out_slice = model(x)
    out_slice = F.interpolate(out_slice,original_shape,mode = 'bilinear')

    if axis == 0:
      result[idx,:,:] = out_slice.data.cpu()
    elif axis == 1:
      result[:,idx,:] = out_slice.data.cpu()
    elif axis == 2:
      result[:,:,idx] = out_slice.data.cpu()

  if outfile:
    torch.save(result, outfile)

  return result

#def segment(model,checkpoint_dict,data,device,outfile = None):
#  """
#  Takes a 3D tensor and segments it with a series of models, predicting maximum class for each
#  """
#  result = torch.zeros((data.shape[0],data.shape[1],data.shape[2]))
#  result = predict_by_slices(model,data,device,outfile = None)
#  return result


def dice_3D(output,target,threshold= 0.5,eps=1e-07):
    ones = torch.ones(target.shape)
    zeros = torch.zeros(target.shape)
    true_pos = torch.where(target == 1,ones,zeros)
    pred_pos = torch.where(output > threshold,ones,zeros)

    numerator = 2.0 * torch.mul(true_pos,pred_pos) + eps
    denominator = true_pos + pred_pos + eps

    dice = (numerator.sum())/denominator.sum()
    return dice


In [0]:
def thresh_plot(data,step = 0.02):
  plt.figure(figsize = (10,15))
  from matplotlib import style
  style.use("fivethirtyeight")

  x = np.arange(0,1,step)
  legend = []
  for key in data:
    series = data[key]
    plt.plot(x,series)
    legend.append(key)
  
  plt.xlabel("Discrimination Threshold")
  plt.ylabel("Dice Accuracy")
  plt.xlim([0,1])
  plt.ylim([0,1])
  plt.title("Discrimination Threshold Accuracy for Various Pooling Sizes")
  plt.legend(legend)
  plt.show()

In [0]:
def reachability_cluster_mask(input):
  """
  Uses reachability to cluster a 3D tensor of 0,1
  """
  data = input.clone()
  cluster_val = 2
  cluster_sizes = {}

  ones = torch.ones(input.shape)
  zeros = torch.zeros(input.shape)
  original_length = len(torch.where(data == 1, ones, zeros).nonzero())
  queue = []
  contin = True
  while True:
    if cluster_val > 1000:
      break
    cluster_sizes[cluster_val] = 0
    mask = torch.where(data==1,ones,zeros)
    indices = mask.nonzero()

    # if any values in tensor are still 1 continue (unclustered)
    if len(indices) > 1:
      first = indices[0]
      queue.append(first)
    else:
      contin = False
      break

    #print("Assigning cluster {}".format(cluster_val))

    while len(queue) > 0:
      cur = queue.pop(0)
      val = data[cur[0],cur[1],cur[2]]
      if val == 1:
        cluster_sizes[cluster_val] += 1
        data[cur[0],cur[1],cur[2]] = cluster_val # assign to cluster
        # add neighbors to queue
        add_indices = [[cur[0],cur[1],cur[2]+1],
                        [cur[0],cur[1],cur[2]-1],
                        [cur[0],cur[1]+1,cur[2]],
                        [cur[0],cur[1]-1,cur[2]],
                        [cur[0]-1,cur[1],cur[2]],
                        [cur[0]+1,cur[1],cur[2]]]
        if cur[0] < data.shape[0]-1 and cur[1] < data.shape[1]-1 and cur[2] < data.shape[2] - 1:
          for item in add_indices:
            queue.append(item)

    cluster_val += 1
    
  biggest_idx = 2
  biggest_size = 0
  for key in cluster_sizes:
    if cluster_sizes[key] > biggest_size:
      biggest_idx = key
      biggest_size = cluster_sizes[key]
  # get biggest cluster
  mask1 = torch.where((data == biggest_idx), ones, 1.0 - ones)
  mask2 = torch.where((data == 0), ones, 1.0 - ones)
  mask = mask1+mask2
  return mask

In [0]:
def predict(volume,model,device):
  checkpoint = torch.load(checkpoint_dict["dim0"])
  model.load_state_dict(checkpoint['model_state_dict'])
  baseline0 = segment(model,volume,device,axis = 0)

  checkpoint = torch.load(checkpoint_dict["dim1"])
  model.load_state_dict(checkpoint['model_state_dict'])
  baseline1 = segment(model,volume,device,axis = 1)

  checkpoint = torch.load(checkpoint_dict["dim2"])
  model.load_state_dict(checkpoint['model_state_dict'])
  baseline2 = segment(model,volume,device,axis = 2)

  # combine results and avgpool with kernel size 3, and threshold at 0.4
  combo = 0.4*baseline0 + 0.2* baseline1 + 0.4* baseline2
  p=3
  avg_combo = F.avg_pool3d(combo.unsqueeze(0),p,stride = 1,padding = int((p-1)/2)).squeeze(0)
  ones = torch.ones(avg_combo.shape)
  avg_combo = torch.where(avg_combo > 0.4,ones, 1.0-ones)

  # remove clusters from combination
  mask = reachability_cluster_mask(avg_combo)
  declustered = torch.mul(mask,combo)

  # apply final average pooling with threshold ___
  final = F.avg_pool3d(combo.unsqueeze(0),p,stride = 1,padding = int((p-1)/2)).squeeze(0)
  final_thresholded = torch.where(final > 0.4,ones, 1.0-ones)

  return final_thresholded


# Load everything up

In [0]:
checkpoint_dict = {
    "dim0":"/content/drive/My Drive/Colab Notebooks/Segmentation/checkpoints/Keepers/Final/organ1_dim0_e22.pt",
    "dim1":"/content/drive/My Drive/Colab Notebooks/Segmentation/checkpoints/Keepers/Final/organ1_dim1_e15.pt",
    "dim2":"/content/drive/My Drive/Colab Notebooks/Segmentation/checkpoints/Keepers/Final/organ1_dim2_e120.pt"
}

In [13]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
torch.cuda.empty_cache()   

model = UNet()
print ("Model loaded.")
model = model.to(device)
model.eval()

checkpoint = torch.load(checkpoint_dict['dim1'])
model.load_state_dict(checkpoint['model_state_dict'])

try:
  dataset
except:
  dataset = Nifti_Dataset(mode = "train",dim = 2)

Model loaded.


In [19]:
train_accs = []
val_accs = []

for idx in range(dataset.len_3d()):
  volume,label,mode,name = dataset.get_3d_array(idx)

  output = predict(volume,model,device)
  dice = dice_3D(output,label)  
  if mode == "train":
    train_accs.append(dice)
  elif mode == "val":
    val_accs.append(dice)
  
  print("{} volume {} dice score: {}".format(mode,idx,dice))

train_acc = sum(train_accs)/len(train_accs)
val_acc = sum(val_accs)/len(val_accs)

  "See the documentation of nn.Upsample for details.".format(mode))


train volume 0 dice score: 0.9473950862884521
train volume 1 dice score: 0.92239910364151
train volume 2 dice score: 0.8817955255508423
train volume 3 dice score: 0.9443104267120361
train volume 4 dice score: 0.9367591738700867
train volume 5 dice score: 0.8528770208358765
train volume 6 dice score: 0.8491819500923157
train volume 7 dice score: 0.9404333233833313
train volume 8 dice score: 0.9493824243545532
train volume 9 dice score: 0.921882152557373
train volume 10 dice score: 0.9110795855522156
train volume 11 dice score: 0.9425414204597473
train volume 12 dice score: 0.9118952751159668
train volume 13 dice score: 0.950798511505127
train volume 14 dice score: 0.8930566906929016
train volume 15 dice score: 0.9424028396606445
train volume 16 dice score: 0.9366868734359741
train volume 17 dice score: 0.9535695314407349
train volume 18 dice score: 0.9387070536613464
train volume 19 dice score: 0.9370836019515991
train volume 20 dice score: 0.9355854988098145
train volume 21 dice score:

In [27]:
train = np.array(train_accs)
print(np.mean(train))
print(np.std(train))

0.9197182
0.049323957


In [32]:
val = np.array(val_accs)
print(np.mean(val))
print(np.std(val))


0.90080816
0.05478016


In [31]:
# generate test volumes
test_directory = "/content/drive/My Drive/Colab Notebooks/Segmentation/RawData/Testing/img"

for file in os.listdir(test_directory):
  full_file = os.path.join(test_directory,file)

  data = nib.load(full_file)
  outname = "predictions_" + file
  data = data.get_fdata()
  volume = torch.from_numpy(np.array(data))
  
  output = predict(volume,model,device).data.numpy()
  out_file = nib.Nifti1Image(output,affine=np.eye(4))
  nib.save(out_file, os.path.join(test_directory,outname))
  print("Wrote output file {}".format(outname))


  "See the documentation of nn.Upsample for details.".format(mode))


Wrote output file predictions_img0061.nii.gz
Wrote output file predictions_img0062.nii.gz
Wrote output file predictions_img0063.nii.gz
Wrote output file predictions_img0064.nii.gz
Wrote output file predictions_img0065.nii.gz
Wrote output file predictions_img0066.nii.gz
Wrote output file predictions_img0067.nii.gz
Wrote output file predictions_img0068.nii.gz
Wrote output file predictions_img0069.nii.gz
Wrote output file predictions_img0070.nii.gz
Wrote output file predictions_img0071.nii.gz
Wrote output file predictions_img0072.nii.gz
Wrote output file predictions_img0073.nii.gz
Wrote output file predictions_img0074.nii.gz
Wrote output file predictions_img0075.nii.gz
Wrote output file predictions_img0076.nii.gz
Wrote output file predictions_img0077.nii.gz
Wrote output file predictions_img0078.nii.gz
Wrote output file predictions_img0079.nii.gz
Wrote output file predictions_img0080.nii.gz
