In [21]:
# Mount Google Drive at /content/drive (Colab-only helper) so the project
# files stored on Drive can be read by the cells below.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [22]:
!pip install SimpleITK

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [23]:
import os
import tensorflow as tf
from tensorflow.keras import layers, metrics, losses
from tensorflow.keras.models import Model 
from tensorflow.keras.optimizers import Adam
import matplotlib.pyplot as plt
import numpy as np
import SimpleITK as sitk

In [24]:
# Work from the project folder on the mounted Drive; the relative dataset path
# used later ('./dataset/train') is resolved against this directory.
os.chdir('/content/drive/MyDrive/nnfl')

In [25]:
# Mini-batch size shared by the dataset pipeline and the model's input layer.
BATCH_SIZE = 2

In [26]:
class CustomLayer(Model):
    """One stage of the network: stacked 5x5x5 convolutions, a residual add, then a resampling convolution.

    Args:
        input_shape: Spatial shape (3 values) of the tensor this stage receives.
        filters: Pair ``(stage_filters, resample_filters)``; the first is the
            channel count of the stacked convolutions, the second the channel
            count produced by the stride-2 resampling convolution.
        n_layers: Number of stacked 5x5x5 convolutions.
        is_down: When True the stage ends in a stride-2 ``Conv3D``;
            otherwise it ends in a stride-2 ``Conv3DTranspose``.
    """
    def __init__(self,input_shape, filters, n_layers, is_down = True):
        super().__init__(name='CustomLayer')
        self.clayers = []
        stage_filters, resample_filters = filters

        self.input_layer = layers.InputLayer(input_shape=(*input_shape, stage_filters))

        # Same-padded, stride-1 convolutions: spatial size is preserved.
        self.clayers.extend(
            layers.Conv3D(stage_filters, 5, strides=1, padding='same')
            for _ in range(n_layers)
        )

        # Both variants use kernel 2 / stride 2; only the layer type differs.
        resample_cls = layers.Conv3D if is_down else layers.Conv3DTranspose
        self.out = resample_cls(resample_filters, 2, strides=2, padding='valid')

        self.prelu = layers.PReLU()

    def call(self, input_tensor1, input_tensor2=None, training=False):
        """Run the stage.

        Args:
            input_tensor1: Main input; it is also the tensor added back as the residual.
            input_tensor2: Optional second tensor concatenated to the main
                input along axis 4 before the convolutions.
            training: Accepted for Keras compatibility; not used by this stage.

        Returns:
            Tuple ``(features, resampled)``: the residual-summed features before
            resampling, and the PReLU-activated output of the resampling convolution.
        """
        if input_tensor2 is None:
            merged = input_tensor1
        else:
            merged = layers.Concatenate(axis=4)([input_tensor1, input_tensor2])

        features = self.input_layer(merged)
        for conv in self.clayers:
            features = conv(features)

        # The residual always uses the first input, even when a second one was concatenated.
        features = features + input_tensor1
        resampled = self.out(features)

        return features, self.prelu(resampled)


In [41]:
class VNet(Model):
    """V-Net style encoder/decoder for 3-D segmentation with skip connections.

    The per-stage shapes are hard-coded for a 128x128x64 single-channel volume.

    Args:
        input_shape: Shape of one input sample, passed to the input layer.
        batch_size: Batch size passed to the input layer.
    """
    def __init__(self, input_shape, batch_size):
        super(VNet, self).__init__(name='VNet')
        self.batch_size = batch_size
        self.input_layer = layers.InputLayer(input_shape=(input_shape), batch_size=self.batch_size)
        # Encoder: each stage halves the spatial size.
        self.layer1 = CustomLayer((128,128,64), (1,16), 1)
        self.layer2 = CustomLayer((64,64,32), (16,32), 2)
        self.layer3 = CustomLayer((32,32,16), (32,64), 3)
        self.layer4 = CustomLayer((16,16,8), (64,128), 3)
        # Decoder: each stage doubles the spatial size (is_down=False).
        self.layer5 = CustomLayer((8,8,4), (128,256), 3, False)
        self.layer6 = CustomLayer((16,16,8), (256,128), 3, False)
        self.layer7 = CustomLayer((32,32,16), (128,64), 3, False)
        self.layer8 = CustomLayer((64,64,32), (64,32), 2, False)
        self.layer9 = layers.Conv3D(32, 5, strides=1, padding='same')
        self.layer10 = layers.Conv3D(1, 1, padding='same')



    def call(self, input_tensor, training = False):
        """Return a single-channel map of per-voxel probabilities in [0, 1].

        Args:
            input_tensor: Batch of input volumes.
            training: Accepted for Keras compatibility; not used here.
        """
        input_tensor = self.input_layer(input_tensor)
        o1, l1 = self.layer1(input_tensor)
        o2, l2 = self.layer2(l1)
        o3, l3 = self.layer3(l2)
        o4, l4 = self.layer4(l3)
        _, l5 = self.layer5(l4)
        # Decoder stages receive the matching encoder features as a second input.
        _, l6 = self.layer6(l5, o4)
        _, l7 = self.layer7(l6, o3)
        _, l8 = self.layer8(l7, o2)
        l8_ = layers.Concatenate(axis=4)([l8, o1])
        l9 = self.layer9(l8_)
        l9 += l8
        l10 = self.layer10(l9)
        # The output has a single channel: softmax over that axis would always
        # yield 1.0 (constant output, zero gradient), so use a sigmoid instead.
        return tf.nn.sigmoid(l10)

class VNet_without_skip(Model):
    """Ablation of the V-Net style network: same stages, but no encoder-to-decoder skip inputs.

    The per-stage shapes are hard-coded for a 128x128x64 single-channel volume.

    Args:
        input_shape: Shape of one input sample, passed to the input layer.
        batch_size: Batch size passed to the input layer.
    """
    def __init__(self, input_shape, batch_size):
        # NOTE(review): the Keras model name is 'VNet', same as the other variants; kept unchanged.
        super(VNet_without_skip, self).__init__(name='VNet')
        self.batch_size = batch_size
        self.input_layer = layers.InputLayer(input_shape=(input_shape), batch_size=self.batch_size)
        # Encoder: each stage halves the spatial size.
        self.layer1 = CustomLayer((128,128,64), (1,16), 1)
        self.layer2 = CustomLayer((64,64,32), (16,32), 2)
        self.layer3 = CustomLayer((32,32,16), (32,64), 3)
        self.layer4 = CustomLayer((16,16,8), (64,128), 3)
        # Decoder: each stage doubles the spatial size (is_down=False).
        self.layer5 = CustomLayer((8,8,4), (128,256), 3, False)
        self.layer6 = CustomLayer((16,16,8), (256,128), 3, False)
        self.layer7 = CustomLayer((32,32,16), (128,64), 3, False)
        self.layer8 = CustomLayer((64,64,32), (64,32), 2, False)
        self.layer9 = layers.Conv3D(32, 5, strides=1, padding='same')
        self.layer10 = layers.Conv3D(1, 1, padding='same')



    def call(self, input_tensor, training = False):
        """Return a single-channel map of per-voxel probabilities in [0, 1].

        Args:
            input_tensor: Batch of input volumes.
            training: Accepted for Keras compatibility; not used here.
        """
        input_tensor = self.input_layer(input_tensor)
        # Only the resampled output of each stage is used; the pre-resampling
        # features (the skip tensors) are deliberately discarded.
        _, l1 = self.layer1(input_tensor)
        _, l2 = self.layer2(l1)
        _, l3 = self.layer3(l2)
        _, l4 = self.layer4(l3)
        _, l5 = self.layer5(l4)
        _, l6 = self.layer6(l5)
        _, l7 = self.layer7(l6)
        _, l8 = self.layer8(l7)
        l9 = self.layer9(l8)
        l9 += l8
        l10 = self.layer10(l9)
        # The output has a single channel: softmax over that axis would always
        # yield 1.0 (constant output, zero gradient), so use a sigmoid instead.
        return tf.nn.sigmoid(l10)

class VNet_with_extra_layer(Model):
    """V-Net style network with one additional down stage and one additional up stage.

    The per-stage shapes are hard-coded for a 128x128x64 single-channel volume.

    Args:
        input_shape: Shape of one input sample, passed to the input layer.
        batch_size: Batch size passed to the input layer.
    """
    def __init__(self, input_shape, batch_size):
        # NOTE(review): the Keras model name is 'VNet', same as the other variants; kept unchanged.
        super(VNet_with_extra_layer, self).__init__(name='VNet')
        self.batch_size = batch_size
        self.input_layer = layers.InputLayer(input_shape=(input_shape), batch_size=self.batch_size)
        # Encoder: each stage halves the spatial size.
        self.layer1 = CustomLayer((128,128,64), (1,16), 1)
        self.layer2 = CustomLayer((64,64,32), (16,32), 2)
        self.layer3 = CustomLayer((32,32,16), (32,64), 3)
        self.layer4 = CustomLayer((16,16,8), (64,128), 3)
        # Extra encoder stage, going one level deeper than the base network.
        self.extra1 = CustomLayer((8,8,4), (128,128), 3)
        # Decoder: each stage doubles the spatial size (is_down=False).
        self.layer5 = CustomLayer((4,4,2), (128,256), 3, False)
        self.extra2 = CustomLayer((8,8,4), (256,256), 3, False)
        self.layer6 = CustomLayer((16,16,8), (256,128), 3, False)
        self.layer7 = CustomLayer((32,32,16), (128,64), 3, False)
        self.layer8 = CustomLayer((64,64,32), (64,32), 2, False)
        self.layer9 = layers.Conv3D(32, 5, strides=1, padding='same')
        self.layer10 = layers.Conv3D(1, 1, padding='same')



    def call(self, input_tensor, training = False):
        """Return a single-channel map of per-voxel probabilities in [0, 1].

        Args:
            input_tensor: Batch of input volumes.
            training: Accepted for Keras compatibility; not used here.
        """
        input_tensor = self.input_layer(input_tensor)
        o1, l1 = self.layer1(input_tensor)
        o2, l2 = self.layer2(l1)
        o3, l3 = self.layer3(l2)
        o4, l4 = self.layer4(l3)
        eo1, el1 = self.extra1(l4)
        _, l5 = self.layer5(el1)
        # Decoder stages receive the matching encoder features as a second input.
        _, el2 = self.extra2(l5, eo1)
        _, l6 = self.layer6(el2, o4)
        _, l7 = self.layer7(l6, o3)
        _, l8 = self.layer8(l7, o2)
        l8_ = layers.Concatenate(axis=4)([l8, o1])
        l9 = self.layer9(l8_)
        l9 += l8
        l10 = self.layer10(l9)
        # The output has a single channel: softmax over that axis would always
        # yield 1.0 (constant output, zero gradient), so use a sigmoid instead.
        return tf.nn.sigmoid(l10)

In [40]:
# x = np.random.rand(1,128,128,64,1)
# x = tf.constant(x)
# mod = VNet_with_extra_layer(input_shape=(128,128,64,1,), batch_size=1)
# mod.compile()

# p = mod.predict(x)

(None, 16, 16, 8, 64) (None, 8, 8, 4, 128)
(None, 8, 8, 4, 128) (None, 4, 4, 2, 128)
(None, 4, 4, 2, 128) (None, 8, 8, 4, 256)
(None, 8, 8, 4, 256) (None, 16, 16, 8, 256)
(None, 16, 16, 8, 64) (None, 8, 8, 4, 128)
(None, 8, 8, 4, 128) (None, 4, 4, 2, 128)
(None, 4, 4, 2, 128) (None, 8, 8, 4, 256)
(None, 8, 8, 4, 256) (None, 16, 16, 8, 256)


# Data Pipeline:

In [29]:
class CustomDataset:
    """Loads 3-D volumes and their segmentation masks from one directory into a batched `tf.data.Dataset`.

    Files whose name contains 'segmentation' are treated as masks, all other
    files as images; any file whose name contains 'raw' is ignored
    (presumably the raw payload of a header file -- confirm against the dataset).
    Images and masks are paired by sorted file-name order.

    Args:
        dir: Directory that holds both images and masks.
        out_shape: Spatial size every volume is resampled/cropped to.
        image_channels: Stored on the instance; not used by the pipeline itself.
        mask_channels: Stored on the instance; not used by the pipeline itself.
        batch_size: Batch size of the returned dataset.
    """

    def __init__(self, dir:str, out_shape = (128,128,64), image_channels = 1, mask_channels = 1, batch_size= 10):
      self.dir = dir
      self.image_channels = image_channels
      self.batch_size = batch_size
      self.mask_channels = mask_channels
      self.params = dict()
      # Target voxel spacing and target volume size used by getNumpyData.
      self.params['dstRes'] = np.asarray([1,1,1.5],dtype=float)
      self.params['VolSize'] = np.asarray([*out_shape],dtype=int)
      # When True, the inverse of the image direction matrix is applied while resampling.
      self.params['normDir'] = False


    def getNumpyData(self,dat,method):
      """Resample a SimpleITK image to the target spacing and centre-crop it to the target size.

      Args:
          dat: SimpleITK image to convert.
          method: SimpleITK interpolator used for resampling.

      Returns:
          Float numpy array with the axes of the SimpleITK array reversed
          (transpose [2, 1, 0]), cropped to ``params['VolSize']``.
      """
      img = dat
      dst_res = self.params['dstRes']
      vol_size = self.params['VolSize']

      #we rotate the image according to its transformation using the direction and according to the final spacing we want
      factor = np.asarray(img.GetSpacing()) / [dst_res[0], dst_res[1], dst_res[2]]

      factorSize = np.asarray(img.GetSize() * factor, dtype=float)

      # Never resample to something smaller than the crop window.
      newSize = np.max([factorSize, vol_size], axis=0)

      newSize = newSize.astype(dtype=int).tolist()

      T=sitk.AffineTransform(3)
      T.SetMatrix(img.GetDirection())

      resampler = sitk.ResampleImageFilter()
      resampler.SetReferenceImage(img)
      resampler.SetOutputSpacing([dst_res[0], dst_res[1], dst_res[2]])
      resampler.SetSize(newSize)
      resampler.SetInterpolator(method)

      if self.params['normDir']:
        resampler.SetTransform(T.GetInverse())

      imgResampled = resampler.Execute(img)

      # Crop a VolSize window around the centre of the resampled image.
      imgCentroid = np.asarray(newSize, dtype=float) / 2.0

      imgStartPx = (imgCentroid - vol_size / 2.0).astype(dtype=int).tolist()

      regionExtractor = sitk.RegionOfInterestImageFilter()
      regionExtractor.SetSize(list(vol_size.astype(dtype=int).tolist()))
      regionExtractor.SetIndex(imgStartPx)

      imgResampledCropped = regionExtractor.Execute(imgResampled)

      return np.transpose(sitk.GetArrayFromImage(imgResampledCropped).astype(dtype=float), [2, 1, 0])


    # loads data
    def load_data(self):
      """Return the sorted image file names and the sorted mask file names found in ``self.dir``."""
      candidates = [f for f in os.listdir(self.dir) if 'raw' not in f]
      images = sorted(f for f in candidates if 'segmentation' not in f)
      masks = sorted(f for f in candidates if 'segmentation' in f)
      return images , masks

    def process_img(self, filepaths):
      """Read, intensity-rescale to [0, 1], resample and stack the given image files.

      Args:
          filepaths: File names relative to ``self.dir``.

      Returns:
          float32 tensor with one entry per file and a trailing channel axis of size 1.
      """
      images = []
      for filepath in filepaths:
        rescalFilt=sitk.RescaleIntensityImageFilter()
        rescalFilt.SetOutputMaximum(1)
        rescalFilt.SetOutputMinimum(0)
        # Build the full path instead of chdir-ing into the directory, so the
        # process working directory is never left changed if a read fails.
        full_path = os.path.join(self.dir, filepath)
        img =rescalFilt.Execute(sitk.Cast(sitk.ReadImage(full_path),sitk.sitkFloat32))
        img = self.getNumpyData(img, sitk.sitkLinear)
        img = img.reshape(*img.shape, 1)
        images.append(img)
      images = np.array(images, dtype= np.float32)
      images =  tf.convert_to_tensor(images)
      return images


    def process_mask(self, maskpaths):
      """Read, binarise, resample and stack the given mask files.

      Args:
          maskpaths: File names relative to ``self.dir``.

      Returns:
          float32 tensor of 0/1 values with one entry per file and a trailing channel axis of size 1.
      """
      masks = []
      for maskpath in maskpaths:
        # Full path instead of chdir, for the same reason as in process_img.
        full_path = os.path.join(self.dir, maskpath)
        mask = sitk.Cast(sitk.ReadImage(full_path)>0.5,sitk.sitkFloat32)
        mask = self.getNumpyData(mask, sitk.sitkLinear)
        # Linear interpolation produces fractional values; threshold back to a binary mask.
        mask = (mask > 0.5).astype(dtype=np.float32)
        mask = mask.reshape(*mask.shape, 1)
        masks.append(mask)
      masks = np.array(masks, dtype= np.float32)
      masks =  tf.convert_to_tensor(masks)
      return masks


    # Call this to get the data
    def get_dataset(self):
      """Load every image/mask pair into memory and return a shuffled, batched dataset.

      Raises:
          ValueError: If the number of image files and mask files differ.
      """
      x,y = self.load_data()
      print(f"Total images:- {len(x)}, masks:- {len(y)}")
      # Pairing is positional, so a count mismatch means the pairs are wrong;
      # fail here rather than after the expensive preprocessing.
      if len(x) != len(y):
        raise ValueError(
            f"Found {len(x)} images but {len(y)} masks in {self.dir!r}; "
            "every image needs exactly one segmentation mask."
        )
      x,y = self.process_img(x), self.process_mask(y)
      dataset = tf.data.Dataset.from_tensor_slices((x,y))
      dataset = dataset.shuffle(buffer_size=10)
      dataset = dataset.batch(batch_size=self.batch_size)
      return  dataset


In [33]:
# Build the batched training dataset from the files under ./dataset/train.
train = CustomDataset('./dataset/train', batch_size=BATCH_SIZE).get_dataset()
# test = CustomDataset('./dataset/test', batch_size=1).get_dataset()

# Sanity check: print the image and mask batch shapes of the first two batches.
for x,y in train.take(2):
  print(x.shape,y.shape)

Total images:- 50, masks:- 50
(2, 128, 128, 64, 1) (2, 128, 128, 64, 1)
(2, 128, 128, 64, 1) (2, 128, 128, 64, 1)


# Train Pipeline: 

In [34]:
# Build the skip-connection network for single-channel 128x128x64 volumes.
mod = VNet(input_shape=(128,128,64,1,), batch_size = BATCH_SIZE)
mod.compile(loss='mse', optimizer= Adam(learning_rate=1e-5))
# `train` is a tf.data.Dataset that is already batched, so `batch_size` must not
# be passed to fit(). Keep the History object so the loss curve can be inspected.
history = mod.fit(train, epochs= 3)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x7fea4a32fa50>