<a href="https://colab.research.google.com/github/DerekGloudemans/segmentation-medical-images/blob/master/Select_Final_Model.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
# Mount Google Drive so the NIfTI data, checkpoints and save files under
# /content/drive are reachable from this runtime.
from google.colab import drive
drive.mount(mountpoint='/content/drive', force_remount=True)

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3aietf%3awg%3aoauth%3a2.0%3aoob&response_type=code&scope=email%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdocs.test%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive%20https%3a%2f%2fwww.googleapis.com%2fauth%2fdrive.photos.readonly%20https%3a%2f%2fwww.googleapis.com%2fauth%2fpeopleapi.readonly

Enter your authorization code:
··········
Mounted at /content/drive


In [0]:
#%%capture 
!pip install -q --upgrade ipython==5.5.0
!pip install -q --upgrade ipykernel==4.6.0
!pip3 install torchvision
!pip3 install opencv-python

import ipywidgets
import traitlets
# imports

# this seems to be a popular thing to do so I've done it here
#from __future__ import print_function, division


# torch and specific torch packages for convenience
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torch.utils import data
from torch import multiprocessing
from google.colab.patches import cv2_imshow

# for convenient data loading, image representation and dataset management
from torchvision import models, transforms
import torchvision.transforms.functional as FT
from PIL import Image, ImageFile, ImageStat
ImageFile.LOAD_TRUNCATED_IMAGES = True
from scipy.ndimage import affine_transform
import cv2

# always good to have
import time
import os
import numpy as np    
import _pickle as pickle
import random
import copy
import matplotlib.pyplot as plt
import math

import nibabel as nib



[?25l[K     |███▏                            | 10kB 29.6MB/s eta 0:00:01[K     |██████▎                         | 20kB 1.7MB/s eta 0:00:01[K     |█████████▍                      | 30kB 2.3MB/s eta 0:00:01[K     |████████████▋                   | 40kB 1.7MB/s eta 0:00:01[K     |███████████████▊                | 51kB 1.9MB/s eta 0:00:01[K     |██████████████████▉             | 61kB 2.2MB/s eta 0:00:01[K     |██████████████████████          | 71kB 2.4MB/s eta 0:00:01[K     |█████████████████████████▏      | 81kB 2.6MB/s eta 0:00:01[K     |████████████████████████████▎   | 92kB 2.9MB/s eta 0:00:01[K     |███████████████████████████████▍| 102kB 2.8MB/s eta 0:00:01[K     |████████████████████████████████| 112kB 2.8MB/s 


In [0]:
class UNet(nn.Module):
    def __init__(
        self,
        in_channels=1,
        n_classes=1,
        depth=3,
        wf=4,
        padding=True,
        batch_norm=False,
        up_mode='upconv',
    ):
        """
        U-Net encoder/decoder adapted from
        U-Net: Convolutional Networks for Biomedical Image Segmentation
        (Ronneberger et al., 2015)
        https://arxiv.org/abs/1505.04597

        The defaults here (depth=3, wf=4, padding=True) build a much smaller
        network than the one in the paper (depth=5, wf=6, unpadded convs).

        Args:
            in_channels (int): number of input channels
            n_classes (int): number of output channels
            depth (int): number of encoder stages (the decoder has depth - 1)
            wf (int): number of filters in the first layer is 2**wf; each
                deeper encoder stage doubles it
            padding (bool): if True, apply padding such that the input shape
                is the same as the output. This may introduce artifacts
            batch_norm (bool): use BatchNorm after layers with an activation
            up_mode (str): one of 'upconv' or 'upsample'.
                'upconv' uses transposed convolutions for learned upsampling.
                'upsample' uses bilinear upsampling followed by a 1x1 conv.
        """
        super(UNet, self).__init__()
        assert up_mode in ('upconv', 'upsample')
        self.padding = padding
        self.depth = depth
        prev_channels = in_channels
        self.down_path = nn.ModuleList()
        for i in range(depth):
            self.down_path.append(
                UNetConvBlock(prev_channels, 2 ** (wf + i), padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.up_path = nn.ModuleList()
        for i in reversed(range(depth - 1)):
            self.up_path.append(
                UNetUpBlock(prev_channels, 2 ** (wf + i), up_mode, padding, batch_norm)
            )
            prev_channels = 2 ** (wf + i)

        self.last = nn.Conv2d(prev_channels, n_classes, kernel_size=1)

        # Optional bounding-box regression head on the flattened bottleneck.
        # The layer layout (and therefore the state_dict keys) must stay as-is
        # so saved checkpoints keep loading. It expects exactly `innum`
        # bottleneck features per sample (see forward).
        innum = 1000
        midnum = 100
        outnum = 4*n_classes
        self.reg = nn.Sequential(
            nn.BatchNorm1d(innum),
            nn.Linear(innum,midnum),
            nn.ReLU(),
            nn.Dropout(),
            nn.Linear(midnum,outnum),
            nn.Sigmoid()
        )

    def forward(self, x, BBOX=False):
        """
        Run the network.

        Args:
            x (Tensor): input batch of shape (N, in_channels, H, W)
            BBOX (bool): if True, also run the bounding-box regression head
                on the flattened bottleneck features
        Returns:
            Per-pixel sigmoid outputs (N, n_classes, H', W'); when BBOX is
            True, also a (N, 4, n_classes) tensor of sigmoid box outputs.
        Raises:
            ValueError: BBOX is True but the bottleneck does not have the
                number of features the regression head was built for.
        """
        blocks = []

        # encoder: keep every stage output except the bottleneck for the skips
        for i, down in enumerate(self.down_path):
            x = down(x)
            if i != len(self.down_path) - 1:
                blocks.append(x)
                x = F.max_pool2d(x, 2)

        if BBOX:
            # Flatten per sample. Flattening the whole batch into one 1-D
            # vector (x.view(-1)) is rejected by BatchNorm1d, which needs
            # (N, C) input. In training mode BatchNorm1d also needs N > 1.
            x_reg = x.reshape(x.size(0), -1)
            expected = self.reg[1].in_features
            if x_reg.size(1) != expected:
                raise ValueError(
                    "BBOX head expects {} bottleneck features per sample, got {}".format(
                        expected, x_reg.size(1)))
            bboxes = self.reg(x_reg).view(x.size(0), 4, -1)

        # decoder: upsample and merge with the matching encoder output
        for i, up in enumerate(self.up_path):
            x = up(x, blocks[-i - 1])

        # independent per-channel sigmoid (each class is a binary map)
        x = torch.sigmoid(self.last(x))

        if BBOX:
            return x, bboxes
        return x


class UNetConvBlock(nn.Module):
    """Two 3x3 convolutions, each followed by ReLU and (optionally) BatchNorm2d."""

    def __init__(self, in_size, out_size, padding, batch_norm):
        super(UNetConvBlock, self).__init__()
        pad = int(padding)
        layers = []
        # The Sequential indices (block.0, block.1, ...) are what checkpoint
        # state_dict keys refer to, so the conv -> relu -> [bn] order is fixed.
        for channels_in in (in_size, out_size):
            layers += [
                nn.Conv2d(channels_in, out_size, kernel_size=3, padding=pad),
                nn.ReLU(),
            ]
            if batch_norm:
                layers.append(nn.BatchNorm2d(out_size))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        return self.block(x)


class UNetUpBlock(nn.Module):
    """Decoder stage: upsample, concatenate the center-cropped skip connection, then convolve."""

    def __init__(self, in_size, out_size, up_mode, padding, batch_norm):
        """
        in_size - channels coming from the deeper stage
        out_size - channels after upsampling and after the conv block
        up_mode - 'upconv' (learned transposed conv) or 'upsample'
          (bilinear x2 followed by a 1x1 conv)
        padding, batch_norm - forwarded to UNetConvBlock
        Raises ValueError for any other up_mode.
        """
        super(UNetUpBlock, self).__init__()
        if up_mode == 'upconv':
            self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
        elif up_mode == 'upsample':
            self.up = nn.Sequential(
                # align_corners=False is the default behaviour; passing it
                # explicitly silences PyTorch's UserWarning
                nn.Upsample(mode='bilinear', scale_factor=2, align_corners=False),
                nn.Conv2d(in_size, out_size, kernel_size=1),
            )
        else:
            # previously fell through silently and failed later in forward()
            raise ValueError("up_mode must be 'upconv' or 'upsample', got {!r}".format(up_mode))

        # input to the conv block is the upsampled tensor (out_size channels)
        # concatenated with the skip connection
        self.conv_block = UNetConvBlock(in_size, out_size, padding, batch_norm)

    def center_crop(self, layer, target_size):
        """Crop the last two dims of `layer` (N, C, H, W) to `target_size`, centered."""
        _, _, layer_height, layer_width = layer.size()
        diff_y = (layer_height - target_size[0]) // 2
        diff_x = (layer_width - target_size[1]) // 2
        return layer[
            :, :, diff_y : (diff_y + target_size[0]), diff_x : (diff_x + target_size[1])
        ]

    def forward(self, x, bridge):
        """Upsample x, crop `bridge` to match, concatenate along channels and convolve."""
        up = self.up(x)
        crop1 = self.center_crop(bridge, up.shape[2:])
        out = torch.cat([up, crop1], 1)
        return self.conv_block(out)

class Nifti_Dataset(data.Dataset):
    """
    2-D slice dataset built from the NIfTI training volumes on Google Drive.

    Every slice along `dim` of every volume can become one example. Volumes
    whose sorted index is below 85% of the file count go to the training
    split; the remaining ~15% go to the validation split.
    """

    # fraction of (sorted) volumes assigned to the training split
    _TRAIN_FRACTION = 0.85

    def __init__(self, mode="view", dim=2, class_id=1, seed=None):
        """
        mode - view, train or val; both splits live in the same object so the
          data separation is maintained
          view  - training split, augmenting transforms applied
          train - training split, augmenting transforms applied
          val   - validation split, no augmenting transforms
          (per-slice normalization is currently disabled in every mode)
        dim - axis (0, 1 or 2) along which each volume is sliced
        class_id - label value of the organ of interest; slices that contain
          labelled organs but not this one are kept with probability 0.5
        seed - optional int seeding that random subsampling so the dataset is
          reproducible; None (default) uses the global numpy RNG
        """
        self.mode = mode
        self.dim = dim
        self.class_id = class_id

        data_dir = "/content/drive/My Drive/Colab Notebooks/Segmentation/RawData/Training/img"
        label_dir = "/content/drive/My Drive/Colab Notebooks/Segmentation/RawData/Training/label"

        # images and labels are paired purely by sorted file-name order
        self.data_files = sorted(os.path.join(data_dir, f) for f in os.listdir(data_dir))
        self.label_files = sorted(os.path.join(label_dir, f) for f in os.listdir(label_dir))
        if len(self.data_files) != len(self.label_files):
            raise ValueError("found {} image files but {} label files".format(
                len(self.data_files), len(self.label_files)))

        rand = np.random.rand if seed is None else np.random.RandomState(seed).rand

        self.train_data = []
        self.val_data = []
        n_files = len(self.data_files)

        for i in range(n_files):
            volume = np.array(nib.load(self.data_files[i]).get_fdata())
            labels = np.array(nib.load(self.label_files[i]).get_fdata()).astype(float)

            identifier = self.data_files[i].split("_")[0]
            for slice_idx in range(volume.shape[dim]):
                data_slice = self._take_slice(volume, slice_idx, dim)
                label_slice = self._take_slice(labels, slice_idx, dim)
                label_ints = label_slice.astype(int)

                # skip slices that are pure background
                if not label_ints.any():
                    continue

                # drop half of the slices that lack the organ of interest
                if class_id not in label_ints and rand() > 0.5:
                    continue

                item = {
                    "identifier": identifier,
                    "slice": slice_idx,
                    "data": data_slice,
                    "label": label_slice,
                    "mean": np.mean(data_slice),
                    "std": np.std(data_slice),
                }
                if i < n_files * self._TRAIN_FRACTION:
                    self.train_data.append(item)
                else:
                    self.val_data.append(item)

        # augmentations for the training split (ToTensor inside the pipeline,
        # so the output is already a tensor)
        self.train_transforms = transforms.Compose([
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
            transforms.ToTensor(),
            transforms.RandomErasing(p=0.0015, scale=(0.4, 0.6), ratio=(0.3, 3.3), value=0, inplace=False),  # big
            transforms.RandomErasing(p=0.003, scale=(0.1, 0.3), ratio=(0.3, 3.3), value=0, inplace=False),  # medium
            transforms.RandomErasing(p=0.004, scale=(0.05, 0.15), ratio=(0.3, 3.3), value=0, inplace=False),  # small
            transforms.RandomErasing(p=0.0035, scale=(0.05, 0.2), ratio=(0.3, 3.3), value=0, inplace=False),  # small
        ])

    @staticmethod
    def _take_slice(volume, index, dim):
        """Return the 2-D slice `index` of a 3-D array along axis `dim` (0, 1 or 2)."""
        if dim not in (0, 1, 2):
            raise ValueError("dim must be 0, 1 or 2, got {}".format(dim))
        return np.take(volume, index, axis=dim)

    def __getitem__(self, index):
        """
        Return (image, label) tensors for example `index` of the current split.

        Both are single-channel tensors whose shorter side has been resized to
        256 (labels with nearest-neighbour so values stay integral).
        """
        if self.mode in ['train', 'view']:
            item = self.train_data[index]
        else:
            item = self.val_data[index]

        x = Image.fromarray(item['data']).copy()
        y = Image.fromarray(item['label']).copy()

        # NOTE(review): the slices are float arrays, so PIL builds 'F' mode
        # images and to_grayscale converts them to 8 bits; intensities outside
        # 0-255 are presumably clipped - confirm this windowing is intended.
        x = FT.to_grayscale(x)
        y = FT.to_grayscale(y)

        augment = self.mode in ['train', 'view']
        if augment:
            # flip / rotate augmentation is currently disabled (fixed at 0);
            # the random versions were np.random.rand() and np.random.rand()*60 - 30
            FLIP = 0
            if FLIP > 0.5:
                x = FT.hflip(x)
                y = FT.hflip(y)

            ROTATE = 0
            x = x.rotate(ROTATE)
            y = y.rotate(ROTATE, Image.NEAREST)

        # resize so the shorter side is 256
        x = FT.resize(x, 256)
        y = FT.resize(y, 256, Image.NEAREST)

        if augment:
            # colour jitter + random erasing; returns a tensor
            x = self.train_transforms(x)
        if not isinstance(x, torch.Tensor):
            x = FT.to_tensor(x)
        y = FT.to_tensor(y)

        return x, y

    def __len__(self):
        if self.mode in ["train", "view"]:
            return len(self.train_data)
        return len(self.val_data)

    def show(self, index):
        """Plot example `index`: image on the left, label on the right."""
        image, label = self[index]
        plt.figure()
        plt.subplot(121)
        plt.imshow(image.detach()[0], cmap="gray")

        plt.subplot(122)
        plt.imshow(label[0], cmap="gray")
        plt.show()

    def show_slices(self, idx=0, dim=0, organ_id=None):
        """
        Plot every slice of 3-D volume `idx` along `dim` next to its label.

        idx - index into the sorted NIfTI file list
        dim - axis (0, 1 or 2) to slice along
        organ_id - if not None, the label is shown as a binary mask of that organ
        """
        volume = np.array(nib.load(self.data_files[idx]).get_fdata())
        labels = np.array(nib.load(self.label_files[idx]).get_fdata())

        for slice_idx in range(volume.shape[dim]):
            data_slice = self._take_slice(volume, slice_idx, dim)
            label_slice = self._take_slice(labels, slice_idx, dim)

            if organ_id is not None:
                # 1 where the label equals organ_id, 0 elsewhere
                label_slice = (label_slice == organ_id).astype(float)

            print(np.unique(label_slice))
            plt.figure()
            plt.subplot(121)
            plt.imshow(data_slice, cmap="gray")

            plt.subplot(122)
            plt.imshow(label_slice, cmap="gray")
            plt.show()

    def len_3d(self):
        """Number of 3-D volumes (NIfTI image files) found."""
        return len(self.data_files)

    def get_3d_array(self, idx):
        """
        Load 3-D volume `idx` and its labels.

        Returns (data, label, mode, file_name): data is a float64 tensor,
        label an int32 tensor, and mode is "train" or "val" using the same
        split rule as the slice dataset.
        """
        assert idx < len(self.data_files), "3D image index out of range, there are {} 3D images".format(len(self.data_files))

        volume = torch.from_numpy(np.array(nib.load(self.data_files[idx]).get_fdata()))
        labels = torch.from_numpy(np.array(nib.load(self.label_files[idx]).get_fdata())).int()

        mode = "train" if idx < self._TRAIN_FRACTION * len(self.data_files) else "val"
        return volume, labels, mode, self.data_files[idx]

def load_model(checkpoint_file, model, optimizer, map_location=None):
  """
  Reload a training checkpoint into `model` and `optimizer`.

  checkpoint_file - path to a torch.save'd dict with keys 'model_state_dict',
    'optimizer_state_dict', 'epoch', 'losses' and 'accs'
  map_location - forwarded to torch.load (e.g. 'cpu' to open a GPU-saved
    checkpoint on a CPU-only runtime); None keeps the saved devices
  Returns (model, optimizer, epoch, all_losses, all_accs).
  """
  checkpoint = torch.load(checkpoint_file, map_location=map_location)
  model.load_state_dict(checkpoint['model_state_dict'])
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  epoch = checkpoint['epoch']
  all_losses = checkpoint['losses']
  all_accs = checkpoint['accs']

  return model, optimizer, epoch, all_losses, all_accs


def dt_refine(orig, tree):
  """
  Re-label every voxel with a fitted per-voxel classifier (decision-tree ensembling).

  orig - (classes, l, w, h) tensor of per-class model outputs
  tree - fitted classifier with a predict(X) method; each row of X is
    [orig[0, i, j, k], ..., orig[classes-1, i, j, k], i, j, k], i.e.
    classes + 3 columns (17 for the 14-class setup), and must match the
    column layout the tree was trained on
  Returns a float tensor (l, w, h) of predicted labels.
  """
  scores = orig.detach().cpu().numpy()
  n_cls, n_i, n_j, n_k = scores.shape
  output = torch.zeros((n_i, n_j, n_k))

  # j / k coordinates of every voxel in one i-plane, row-major (j outer, k inner)
  jj, kk = np.meshgrid(np.arange(n_j), np.arange(n_k), indexing='ij')
  feats = np.zeros([n_j * n_k, n_cls + 3])
  feats[:, n_cls + 1] = jj.reshape(-1)
  feats[:, n_cls + 2] = kk.reshape(-1)

  for i in range(n_i):
    # one predict() call per i-plane instead of one per (i, j) column
    feats[:, :n_cls] = scores[:, i, :, :].reshape(n_cls, -1).T
    feats[:, n_cls] = i
    preds = np.asarray(tree.predict(feats))
    output[i] = torch.from_numpy(preds.reshape(n_j, n_k))

  return output


def predict_by_slices(model, data, device, outfile=None):
  """
  Segment a 3-D volume slice by slice (along axis 2) with a 2-D model.

  data - 3-D tensor; each data[:, :, idx] slice is converted to an 8-bit
    grayscale image, resized to 256 x 256, run through `model`, and the
    output is resized back (nearest) to the slice's original shape
  Returns a float tensor shaped like `data` holding the model's [0, 1]
  output for whatever class the model was trained on. If `outfile` is
  given, the result is also saved there with torch.save.
  """
  num_slices = data.shape[2]
  result = torch.zeros(data.shape)

  # inference only: avoid building an autograd graph for every slice
  with torch.no_grad():
    for idx in range(num_slices):
      slice_2d = data[:, :, idx]
      original_shape = slice_2d.shape

      img = Image.fromarray(slice_2d.detach().cpu().numpy()).copy()
      x = FT.to_tensor(FT.to_grayscale(img)).unsqueeze(0)

      # align_corners=False is the default; passed explicitly to silence the warning
      x = F.interpolate(x, size=[256, 256], mode='bilinear', align_corners=False)
      x = x.to(device).float()

      out_slice = model(x)
      out_slice = F.interpolate(out_slice, original_shape, mode='nearest')
      # (1, 1, H, W) -> (H, W); fails loudly if the model has >1 output channel
      result[:, :, idx] = out_slice.cpu().reshape(original_shape)

  if outfile:
    torch.save(result, outfile)

  return result

def segment(model, checkpoint_dict, data, device, outfile=None):
  """
  Segment a 3-D volume with one binary model checkpoint per class.

  checkpoint_dict - maps class index 0..n-1 to a checkpoint path; each
    checkpoint's 'model_state_dict' is loaded into `model` in turn, so
    `model` is left holding the last class's weights
  Returns a (n, *data.shape) float tensor of per-class outputs (no argmax is
  taken). If `outfile` is given, the tensor is also saved there with torch.save.
  """
  n_classes = len(checkpoint_dict)
  result = torch.zeros((n_classes,) + tuple(data.shape))

  for cls in range(n_classes):
    # map_location lets GPU-saved checkpoints load on a CPU-only runtime
    checkpoint = torch.load(checkpoint_dict[cls], map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    result[cls] = predict_by_slices(model, data, device, outfile=None)

  if outfile:
    torch.save(result, outfile)

  return result

def Dice_3D(output, target, eps=1e-07, threshold=0.5):
  """
  Per-class and overall Dice score for per-class probability volumes.

  output - (classes, l, w, h) tensor; class c is predicted wherever
    output[c] > threshold (classes may overlap)
  target - (l, w, h) integer label volume
  eps - smoothing term; a class absent from both prediction and target scores 1
  Returns (per_class_dice, total_dice):
    per_class_dice - (classes,) float tensor
    total_dice - average Dice over classes 1.., weighted by each class's voxel
      count in target (background class 0 ignored); 0 when target has no
      foreground voxels
  """
  n_classes = output.shape[0]
  per_class_dice = torch.zeros(n_classes)
  per_class_counts = torch.zeros(n_classes)

  for idx in range(n_classes):
    target_mask = (target == idx).float()
    pred_mask = (output[idx] > threshold).float()
    overlap = (target_mask * pred_mask).sum()

    per_class_dice[idx] = (2.0 * overlap + eps) / (target_mask.sum() + pred_mask.sum() + eps)
    per_class_counts[idx] = target_mask.sum()

  # Normalize by the foreground voxel count so the weights sum to 1; dividing
  # by numel(target) would scale the score down by the background fraction.
  fg_counts = per_class_counts[1:]
  fg_total = fg_counts.sum()
  if fg_total > 0:
    total_dice = (per_class_dice[1:] * fg_counts).sum() / fg_total
  else:
    total_dice = torch.tensor(0.0)
  return per_class_dice, total_dice

def Dice_3D_alt(output, target, eps=1e-07, n_classes=14):
  """
  Per-class and overall Dice score for a predicted label volume.

  output - (l, w, h) tensor of predicted class ids
  target - (l, w, h) tensor of true class ids
  eps - smoothing term; a class absent from both volumes scores 1
  n_classes - number of class ids (0..n_classes-1) to score
  Returns (per_class_dice, total_dice):
    per_class_dice - (n_classes,) float tensor
    total_dice - average Dice over classes 1.., weighted by each class's voxel
      count in target (background class 0 ignored); 0 when target has no
      foreground voxels
  """
  per_class_dice = torch.zeros(n_classes)
  per_class_counts = torch.zeros(n_classes)

  for idx in range(n_classes):
    target_mask = (target == idx).float()
    pred_mask = (output == idx).float()
    overlap = (target_mask * pred_mask).sum()

    per_class_dice[idx] = (2.0 * overlap + eps) / (target_mask.sum() + pred_mask.sum() + eps)
    per_class_counts[idx] = target_mask.sum()

  # Normalize by the foreground voxel count so the weights sum to 1; dividing
  # by numel(target) would scale the score down by the background fraction.
  fg_counts = per_class_counts[1:]
  fg_total = fg_counts.sum()
  if fg_total > 0:
    total_dice = (per_class_dice[1:] * fg_counts).sum() / fg_total
  else:
    total_dice = torch.tensor(0.0)
  return per_class_dice, total_dice

def single_class_dice_3D(output, target, threshold, idx=1, eps=1e-07):
    """
    Dice score of one class.

    output - probability tensor; a voxel is predicted positive where output > threshold
    target - label tensor of the same shape; voxels equal to `idx` are positive
    eps - smoothing term added once to numerator and denominator, so an empty
      prediction on an empty target scores 1
    Returns a 0-d tensor.
    """
    target_mask = (target == idx).float()
    pred_mask = (output > threshold).float()

    # eps is added after summing; adding it per voxel inflates both sums by
    # eps * numel, which biases the score noticeably on large volumes
    numerator = 2.0 * (target_mask * pred_mask).sum() + eps
    denominator = target_mask.sum() + pred_mask.sum() + eps
    return numerator / denominator


In [0]:
# Best checkpoint per organ class: key = class index, value = file whose name
# records the organ id and the training epoch it was taken from.
checkpoint_dict = {
    organ: "/content/drive/My Drive/Colab Notebooks/Segmentation/checkpoints/Keepers/Final/best_UNET11_organ{}_e{}.pt".format(organ, epoch)
    for organ, epoch in enumerate([7, 120, 119, 14, 2, 26, 45, 2, 26, 37, 80, 33, 46, 34])
}

In [0]:
# CUDA for PyTorch
use_cuda = torch.cuda.is_available()
device = torch.device("cuda:0" if use_cuda else "cpu")
if use_cuda:
  torch.cuda.empty_cache()

# segment() later loads every class's checkpoint into this same model;
# organ 6 is only loaded here as an initial set of weights.
model = UNet()
model = model.to(device)
model.eval()

# map_location lets GPU-saved checkpoints load on a CPU-only runtime
checkpoint = torch.load(checkpoint_dict[6], map_location=device)
model.load_state_dict(checkpoint['model_state_dict'])
print("Model loaded.")

# Building the dataset reads every NIfTI volume from Drive, so reuse the
# `dataset` from an earlier run of this cell when it already exists.
try:
  dataset
except NameError:
  dataset = Nifti_Dataset(mode="train", dim=2)

Model loaded.


## Sample Decision-Tree Training Data from All Training Volumes

In [0]:
# Build decision-tree training data: segment every training volume with the
# per-class models, smooth with a 9x9x9 stride-1 average pool, and sample
# random voxels. Each row is [class scores..., i, j, k, true label].
BUILD_SAMPLE_POINTS = False  # expensive (every class model x every training volume)

if BUILD_SAMPLE_POINTS:
  file_path = "/content/drive/My Drive/Colab Notebooks/Segmentation/Save_files/"
  samples_per_file = 100000
  n_cls = len(checkpoint_dict)
  sample_points_avg9 = np.zeros([samples_per_file * dataset.len_3d(), n_cls + 4])
  completed = 0

  for idx in range(dataset.len_3d()):
    x, y, mode, name = dataset.get_3d_array(idx)
    if mode != "train":
      continue

    result = segment(model, checkpoint_dict, x, device)
    result_avg9 = F.avg_pool3d(result.unsqueeze(0), 9, stride=1, padding=4).squeeze(0)
    print("Segmented train file {}".format(idx))

    # draw all voxel coordinates at once instead of looping per sample
    ii = torch.from_numpy(np.random.randint(0, result.shape[1], samples_per_file))
    jj = torch.from_numpy(np.random.randint(0, result.shape[2], samples_per_file))
    kk = torch.from_numpy(np.random.randint(0, result.shape[3], samples_per_file))

    start = completed * samples_per_file
    stop = start + samples_per_file
    sample_points_avg9[start:stop, :n_cls] = result_avg9[:, ii, jj, kk].numpy().T
    sample_points_avg9[start:stop, n_cls] = ii.numpy()
    sample_points_avg9[start:stop, n_cls + 1] = jj.numpy()
    sample_points_avg9[start:stop, n_cls + 2] = kk.numpy()
    sample_points_avg9[start:stop, n_cls + 3] = y[ii, jj, kk].numpy()

    completed += 1
    print("Sampled points for train file {}".format(idx))

    # Checkpoint progress. Only the rows filled so far are written; the array
    # also reserves rows for validation files, and those all-zero rows would
    # otherwise reach the decision tree as fake (label 0) samples.
    np.save(file_path + "sample_points_avg9_{}.npy".format(idx),
            sample_points_avg9[:completed * samples_per_file])
    print("Saved copy of sample points.")

    del result_avg9
    del result
    del y
      

## Fit a Decision Tree to Each Sampled Dataset

In [0]:
from sklearn.tree import DecisionTreeClassifier
import _pickle as pickle      

SAVE_DIR = "/content/drive/My Drive/Colab Notebooks/Segmentation/Save_files/"


def fit_decision_tree(sample_file, tree_file, random_state=None):
  """
  Fit a DecisionTreeClassifier on saved sample points and pickle it.

  sample_file - .npy file in SAVE_DIR; columns 0-16 are features, column 17 the label
  tree_file - destination file in SAVE_DIR for the pickled tree
  random_state - forwarded to DecisionTreeClassifier (None = sklearn default)
  Returns the fitted tree.
  """
  all_data = np.load(SAVE_DIR + sample_file)
  tree = DecisionTreeClassifier(random_state=random_state)
  tree.fit(all_data[:, :17], all_data[:, 17])
  with open(SAVE_DIR + tree_file, 'wb') as f:
    pickle.dump(tree, f)
  return tree


# (sample points file, output tree file, fit on this run?)
TREE_JOBS = [
    ("sample_points.npy", "decision_tree.cpkl", False),
    ("sample_points_avg5.npy", "decision_tree_avg5.cpkl", False),
    # NOTE(review): the sampling cell above saves sample_points_avg9_{idx}.npy
    # (per-file suffix); confirm how sample_points_avg9.npy is produced.
    ("sample_points_avg9.npy", "decision_tree_avg9.cpkl", True),
]

for sample_file, tree_file, enabled in TREE_JOBS:
  if enabled:
    tree = fit_decision_tree(sample_file, tree_file)

## Get results on Validation Dataset
Load each validation file (the last 15% of volumes) and score its predicted labels with each of the following strategies. Keep a running list of Dice scores for each strategy, and report the per-class averages at the end.
- Base Model (threshold 0.5, the `Dice_3D` default)
- Base Model with Average Pooling Kernel size 5 (threshold 0.5, 0.75)
- Base Model with Average Pooling Kernel size 9 (threshold 0.45, 0.75)
- DT (currently disabled)
- DT with Average Pooling Kernel size 5 (currently disabled)
- DT with Average Pooling Kernel size 9 (currently disabled)



In [0]:
# One (per_class_dice, total_dice) entry per evaluated volume for each strategy.
# The dt* strategies are currently disabled, so their lists stay empty.
all_results = dict((strategy, []) for strategy in (
    "base",
    "base_avg5_0.5",
    "base_avg5_0.75",
    "base_avg9_0.45",
    "base_avg9_0.75",
    "dt",
    "dt_avg5",
    "dt_avg9",
))

if True:
  for idx in range(0, dataset.len_3d()):
    print("Processing file {}".format(idx))
    x, y, mode, name = dataset.get_3d_array(idx)
    if mode != "val":
      continue  # only held-out volumes are scored

    result = segment(model, checkpoint_dict, x, device)
    all_results["base"].append(Dice_3D(result, y))
    print("Got base result")

    # (pooling kernel, ((threshold, results key), ...)) for the smoothed variants
    for kernel, scored_thresholds in (
        (5, ((0.5, "base_avg5_0.5"), (0.75, "base_avg5_0.75"))),
        (9, ((0.45, "base_avg9_0.45"), (0.75, "base_avg9_0.75"))),
    ):
      # stride-1 "same" average pooling smooths the per-class probability volumes
      pooled = F.avg_pool3d(result.unsqueeze(0), kernel, stride=1, padding=kernel // 2).squeeze(0)
      for threshold, key in scored_thresholds:
        all_results[key].append(Dice_3D(pooled, y, threshold=threshold))
      if kernel == 5:
        print("Got base avgpooled results")
      del pooled

    del result

Processing file 0
Processing file 1
Processing file 2
Processing file 3
Processing file 4
Processing file 5
Processing file 6
Processing file 7
Processing file 8
Processing file 9
Processing file 10
Processing file 11
Processing file 12
Processing file 13
Processing file 14
Processing file 15
Processing file 16
Processing file 17
Processing file 18
Processing file 19
Processing file 20
Processing file 21
Processing file 22
Processing file 23
Processing file 24
Processing file 25
Processing file 26


  "See the documentation of nn.Upsample for details.".format(mode))


Got base result
Got base avgpooled results
Processing file 27
Got base result
Got base avgpooled results
Processing file 28
Got base result
Got base avgpooled results
Processing file 29
Got base result
Got base avgpooled results


In [0]:
# Parse all_results: replace each strategy's list of (per_class_dice, total_dice)
# entries with its mean per-class Dice tensor. Strategies with no entries stay [].
# This rewrites all_results in place; entries that are already averaged
# (i.e. not lists) are left untouched, so re-running the cell is harmless.
for key, entries in all_results.items():
  if not isinstance(entries, list) or not entries:
    continue
  all_results[key] = torch.stack([per_class for per_class, _ in entries]).mean(dim=0)
  #print(key, all_results[key])



In [0]:
# Show the averaged per-class Dice for each strategy.
for key, value in all_results.items():
  print(key, value)

base tensor([0.9881, 0.6069, 0.0096, 0.8458, 0.5231, 0.5941, 0.8153, 0.0486, 0.7745,
        0.6009, 0.5919, 0.3361, 0.4871, 0.3413])
base_avg5_0.5 tensor([0.9877, 0.7234, 0.0025, 0.8551, 0.4902, 0.5810, 0.8481, 0.0543, 0.8093,
        0.6863, 0.4647, 0.4102, 0.2722, 0.2865])
base_avg5_0.75 tensor([9.6354e-01, 6.9975e-01, 1.1105e-04, 6.8099e-01, 2.7119e-01, 3.8267e-01,
        8.0533e-01, 6.8593e-13, 7.2050e-01, 6.2072e-01, 1.8807e-01, 2.9839e-01,
        4.7550e-02, 8.4411e-02])
base_avg9_0.45 tensor([9.8671e-01, 7.6390e-01, 6.3435e-13, 8.2418e-01, 3.5322e-01, 4.8172e-01,
        8.6415e-01, 5.1004e-02, 8.1755e-01, 7.1167e-01, 2.0717e-01, 3.7460e-01,
        9.6815e-03, 1.0621e-01])
base_avg9_0.75 tensor([9.5002e-01, 5.7018e-01, 1.0975e-12, 4.4801e-01, 4.5906e-02, 9.5308e-02,
        7.3506e-01, 6.8593e-13, 5.4302e-01, 4.9191e-01, 3.2134e-02, 1.7405e-01,
        7.5296e-11, 1.3124e-02])
dt []
dt_avg5 []
dt_avg9 []


In [0]:
 # Persist the averaged per-class Dice scores for later comparison.
 with open("/content/drive/My Drive/Colab Notebooks/Segmentation/Save_files/all_results.cpkl", 'wb') as results_file:
     pickle.dump(all_results, results_file)