<a href="https://colab.research.google.com/github/abcdjdj/cs-766-project/blob/main/double_u_net.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
%cd '/content/drive/MyDrive/Colab Notebooks/'

/content/drive/MyDrive/Colab Notebooks


Utility Functions

In [None]:
import torch
import cv2
import glob

'''
Reads the image specified by 'path' and returns it
param : path - path of image file
return : image as a numpy array
'''
def read_img(path):
    image = cv2.imread(path, cv2.IMREAD_COLOR)
    return image

'''
Converts numpy img to tensor
param : img - numpy arr containing image data
return : t - torch tensor of shape [1, 3, H, W]
'''
def img_to_tensor(img):
    t = torch.from_numpy(img)
    t = t.view(-1, 3, t.shape[0], t.shape[1])
    return t

'''
Converts tensor back to numpy img
param : t - torch tensor of shape [1, 3, H, W]
return : img - numpy arr containing image data
'''
def tensor_to_img(t):
    t = t.view(t.shape[2], t.shape[3], 3)
    return t.numpy()

Double U-Net Architecture

Wrap Up inside nn.Module

In [36]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchsummary import summary
from torchvision import models
from imageio import imread as imread
import matplotlib.pyplot as plt

class SqueezeAndExcite(nn.Module):
  def __init__(self, x, ratio = 8):
    super(SqueezeAndExcite, self).__init__()

    channel_axis = 1
    filters = x.shape[channel_axis]
    # Architecture
    self.avgpool2d = nn.AvgPool2d(kernel_size = (x.shape[2], x.shape[3]))
    self.sequential = nn.Sequential(nn.Linear(filters, filters//ratio, bias = False), nn.ReLU(), nn.Linear(filters//ratio, filters, bias = False), nn.Sigmoid())

  def forward(self, x):
    init = x
    channel_axis = 1
    filters = init.shape[channel_axis]
    x = self.avgpool2d(x)
    x = x.view(init.shape[0] , filters)
    x = self.sequential(x)
    x = x.view(init.shape[0], filters, 1, 1)

    return torch.mul(init, x)

class ConvBlock(nn.Module):
  def __init__(self, x, filters):
      super().__init__()

      self.layer1_conv2d = nn.Conv2d(in_channels = x.shape[1], out_channels = filters, kernel_size = 3, padding='same')
      x = self.layer1_conv2d(x)
      self.layer1_batchnorm2d = nn.BatchNorm2d(num_features = x.shape[1])
      x = self.layer1_batchnorm2d(x)
      self.layer1_relu = nn.ReLU()
      x = self.layer1_relu(x)

      self.layer2_conv2d = nn.Conv2d(in_channels = x.shape[1], out_channels = filters, kernel_size = 3, padding='same')
      x = self.layer2_conv2d(x)
      self.layer2_batchnorm2d = nn.BatchNorm2d(num_features = x.shape[1])
      x = self.layer2_batchnorm2d(x)
      self.layer2_relu = nn.ReLU()
      x = self.layer2_relu(x)

      self.squeeze_and_excite = SqueezeAndExcite(x)

  def forward(self, x):
      x = self.layer1_conv2d(x)
      x = self.layer1_batchnorm2d(x)
      x = self.layer1_relu(x)

      x = self.layer2_conv2d(x)
      x = self.layer2_batchnorm2d(x)
      x = self.layer2_relu(x)

      x = self.squeeze_and_excite.forward(x)
      return x

class ASPP(nn.Module):
    def __init__(self, x, filter_count):
      super().__init__()

      self.layer1_avgpool2d = nn.AvgPool2d(kernel_size = (x.shape[2], x.shape[3]))
      se = self.layer1_avgpool2d(x)
      self.layer1_conv2d = nn.Conv2d(in_channels = se.shape[1], out_channels = filter_count, kernel_size = 1, padding='same')
      se = self.layer1_conv2d(se)
      self.layer1_batchnorm2d = nn.BatchNorm2d(num_features = se.shape[1])
      se = self.layer1_batchnorm2d(se)
      self.layer1_relu = nn.ReLU()
      se = self.layer1_relu(se)
      self.layer1_upsampling = nn.UpsamplingBilinear2d(size=(x.shape[2], x.shape[3]))
      se = self.layer1_upsampling(se)

      self.layer2_conv2d = nn.Conv2d(dilation=1, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)
      y1 = self.layer2_conv2d(x)
      self.layer2_batchnorm2d = nn.BatchNorm2d(num_features = y1.shape[1])
      y1 = self.layer2_batchnorm2d(y1)
      self.layer2_relu = nn.ReLU()
      y1 = self.layer2_relu(y1)

      self.layer3_conv2d = nn.Conv2d(dilation=6, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)
      y2 = self.layer3_conv2d(x)
      self.layer3_batchnorm2d = nn.BatchNorm2d(num_features = y2.shape[1])
      y2 = self.layer3_batchnorm2d(y2)
      self.layer3_relu = nn.ReLU()
      y2 = self.layer3_relu(y2)

      self.layer4_conv2d = nn.Conv2d(dilation=12, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)
      y3 = self.layer4_conv2d(x)
      self.layer4_batchnorm2d = nn.BatchNorm2d(num_features = y3.shape[1])
      y3 = self.layer4_batchnorm2d(y3)
      self.layer4_relu = nn.ReLU()
      y3 = self.layer4_relu(y3)

      self.layer5_conv2d = nn.Conv2d(dilation=18, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)
      y4 = self.layer5_conv2d(x)
      self.layer5_batchnorm2d = nn.BatchNorm2d(num_features = y4.shape[1])
      y4 = self.layer5_batchnorm2d(y4)
      self.layer5_relu = nn.ReLU()
      y4 = self.layer5_relu(y4)

      y = torch.cat([se, y1, y2, y3, y4], dim=1)
      self.layer6_conv2d = nn.Conv2d(dilation=1, in_channels = y.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)
      y = self.layer6_conv2d(y)
      self.layer6_batchnorm2d = nn.BatchNorm2d(num_features = y.shape[1])
      y = self.layer6_batchnorm2d(y)
      self.layer6_relu = nn.ReLU()
      y = self.layer6_relu(y)

    def forward(self, x, filter_count):
      se = self.layer1_avgpool2d(x)
      se = self.layer1_conv2d(se)
      se = self.layer1_batchnorm2d(se)
      se = self.layer1_relu(se)
      se = self.layer1_upsampling(se)
      #print(se.shape)

      y1 = self.layer2_conv2d(x)
      y1 = self.layer2_batchnorm2d(y1)
      y1 = self.layer2_relu(y1)
      #print(y1.shape)

      y2 = self.layer3_conv2d(x)
      y2 = self.layer3_batchnorm2d(y2)
      y2 = self.layer3_relu(y2)
      #print(y2.shape)

      y3 = self.layer4_conv2d(x)
      y3 = self.layer4_batchnorm2d(y3)
      y3 = self.layer4_relu(y3)
      #print(y3.shape)

      y4 = self.layer5_conv2d(x)
      y4 = self.layer5_batchnorm2d(y4)
      y4 = self.layer5_relu(y4)
      #print(y4.shape)

      y = torch.cat([se, y1, y2, y3, y4], dim=1)
      y = self.layer6_conv2d(y)
      y = self.layer6_batchnorm2d(y)
      y = self.layer6_relu(y)
      #print(y.shape)
      return y

class Encoder1(nn.Module):
    def __init__(self):
      super().__init__()
      self.model = models.vgg19()
    
    def forward(self, inputs):
      #skip connections from pre-trained VGG-19
      names = ["ReLU-4", "ReLU-9", "ReLU-18", "ReLU-27", "ReLU-36"]

      indices = [3, 8, 17, 26, 35]

      skip_connections = []

      def encoder1_receive_outputs(layer, _, output):
          skip_connections.append(output)

      for name, layer in self.model.named_children():
          for idx in indices:
              layer[idx].register_forward_hook(encoder1_receive_outputs)
          break

      self.model(inputs)

      return skip_connections[-1], skip_connections[0:-1]


class DoubleUNet(nn.Module):
  def __init__(self):
    super(DoubleUNet, self).__init__()
    #self.blocks = nn.ModuleList()

    # Encoder 1
    #self.encoder1_vgg19 = models.vgg19()
    #self.conv_block = ConvBlock(torch.ones(1, 3, 256, 256), filters = 8)
    #self.squeeze = SqueezeAndExcite(torch.ones(1, 10, 256, 256))
    # To get picked up - type(self.xxx) == nn.Module
    #self.ASPP_model = ASPP(torch.ones(2,512,16,16), 64)
    #self.encoder1 = Encoder1()
    

  def squeeze_and_excite(self, inputs, ratio = 8):
    init = inputs  #(b, 32, 128, 128)
    channel_axis = 1
    filters = init.shape[channel_axis]
    se = nn.AvgPool2d(kernel_size = (init.shape[2], init.shape[3]))(init) # (b, 32) -> (b,4)
    se = se.view(init.shape[0] , filters)
    se = nn.Sequential(nn.Linear(filters, filters//ratio, bias = False), nn.ReLU(), nn.Linear(filters//ratio, filters, bias = False), nn.Sigmoid())(se) # (b,32)
    se = se.view(init.shape[0],filters,1,1) #(b, 32, 1, 1)

    return torch.mul(init,se) #(b,32,128,128)
  
  """
  Function: ASPP to get high resolution feature maps
  Inputs: feature maps, output channels desired 
  Outputs: High Res feature maps
  """
  def ASPP(self, x, filter_count):
      se = nn.AvgPool2d(kernel_size = (x.shape[2], x.shape[3]))(x)
      se = nn.Conv2d(in_channels = se.shape[1], out_channels = filter_count, kernel_size = 1, padding='same')(se)
      se = nn.BatchNorm2d(num_features = se.shape[1])(se)
      se = nn.ReLU()(se)
      se = nn.UpsamplingBilinear2d(size=(x.shape[2], x.shape[3]))(se)
      #print(se.shape)

      y1 = nn.Conv2d(dilation=1, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)(x)
      y1 = nn.BatchNorm2d(num_features = y1.shape[1])(y1)
      y1 = nn.ReLU()(y1)
      #print(y1.shape)

      y2 = nn.Conv2d(dilation=6, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)(x)
      y2 = nn.BatchNorm2d(num_features = y2.shape[1])(y2)
      y2 = nn.ReLU()(y2)
      #print(y2.shape)

      y3 = nn.Conv2d(dilation=12, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)(x)
      y3 = nn.BatchNorm2d(num_features = y3.shape[1])(y3)
      y3 = nn.ReLU()(y3)
      #print(y3.shape)

      y4 = nn.Conv2d(dilation=18, in_channels = x.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)(x)
      y4 = nn.BatchNorm2d(num_features = y4.shape[1])(y4)
      y4 = nn.ReLU()(y4)
      #print(y4.shape)

      y = torch.cat([se, y1, y2, y3, y4], dim=1)
      y = nn.Conv2d(dilation=1, in_channels = y.shape[1], out_channels = filter_count, kernel_size = 1, padding='same', bias=False)(y)
      y = nn.BatchNorm2d(num_features = y.shape[1])(y)
      y = nn.ReLU()(y)
      #print(y.shape)
      return y


  """
  function: This is Encoder 1
  params: Medical Image Input
  return: Output of Encoder1, 4 Skip Conns for Decoder 1
  """
  def encoder1(self, inputs):
      model = self.encoder1_vgg19
      #print(summary(model,(3,256,256)))

      #skip connections from pre-trained VGG-19
      names = ["ReLU-4", "ReLU-9", "ReLU-18", "ReLU-27", "ReLU-36"]

      indices = [3, 8, 17, 26, 35]

      skip_connections = []

      def encoder1_receive_outputs(layer, _, output):
          skip_connections.append(output)

      for name, layer in model.named_children():
          for idx in indices:
              layer[idx].register_forward_hook(encoder1_receive_outputs)
          break

      model(inputs)

      return skip_connections[-1], skip_connections[0:-1]

  """
  Function: 2 Blocks of Convolution + BN + ReLU
  Input: Input Activation Map, Desired output channels
  Output: Convolved Activation Maps
  """
  def conv_block(self, x, filters):
      x = nn.Conv2d(in_channels = x.shape[1], out_channels = filters, kernel_size = 3, padding='same')(x)
      x = nn.BatchNorm2d(num_features = x.shape[1])(x)
      x = nn.ReLU()(x)

      x = nn.Conv2d(in_channels = x.shape[1], out_channels = filters, kernel_size = 3, padding='same')(x)
      x = nn.BatchNorm2d(num_features = x.shape[1])(x)
      x = nn.ReLU()(x)

      x = self.squeeze_and_excite(x)

      return x

  """
  Function: Decoder 1
  Params: ASPP Output, Skip Connections from Encoder1
  Output: To be passed through output_block to get mask
  """
  def decoder1(self, inputs, skip_connections):
      num_filters = [256, 128, 64, 32]

      skip_connections.reverse()

      x = inputs

      for i,f in enumerate(num_filters):
          x = nn.UpsamplingBilinear2d(size=(2*x.shape[2], 2*x.shape[3]))(x)
          x = torch.cat([x, skip_connections[i]], dim=1)
          x = self.conv_block(x, f)

      return x

  """
  Function: To get mask from decoder1 output
  Input: Decoder1 output
  Output: Mask for Network1
  """
  def output_block(self, inputs):
      x = nn.Conv2d(in_channels = inputs.shape[1], out_channels = 1, kernel_size = 1, padding = "same")(inputs)
      x = nn.Sigmoid()(x)
      return x


  def encoder2(self, inputs):
      num_filters = [32, 64, 128, 256]
      skip_connections = []
      x = inputs

      for f in num_filters:
          x = self.conv_block(x, f)
          skip_connections.append(x)
          x = nn.MaxPool2d(kernel_size = (2,2))(x)

      return x, skip_connections

  def decoder2(self, inputs, skip_1, skip_2):
      num_filters = [256, 128, 64, 32]

      skip_2.reverse()

      x = inputs

      for i,f in enumerate(num_filters):
          x = nn.UpsamplingBilinear2d(size=(2*x.shape[2], 2*x.shape[3]))(x)
          #print(f"X -> {x.shape}")
          #print(f"Skip1 -> {torch.Tensor(skip_1[i]).shape}")
          #print(f"Skip2 -> {torch.Tensor(skip_2[i]).shape}")
          x = torch.cat([x, skip_1[i], skip_2[i]], dim=1)
          x = self.conv_block(x, f)

      return x
  
  def forward(self, inputs):
    encoder1_op, encoder1_skip_conns = self.encoder1(inputs)
    #print(f"Encoder 1 o/p shape {encoder1_op.shape}")
    aspp_op = self.ASPP(encoder1_op, 64)
    #print(f"ASPP o/p shape {aspp_op.shape}")
    decoder1_op = self.decoder1(aspp_op, encoder1_skip_conns)
    #print(f"Decoder 1 o/p shape {decoder1_op.shape}")
    mask = self.output_block(decoder1_op)
    #print(f"Mask shape {mask.shape}")
    network1_op = inputs * mask
    #print(f"Network 1 o/p shape {network1_op.shape}")
    encoder2_op,encoder2_skip_conns = self.encoder2(network1_op)
    #print(f"Encoder2 o/p shape {encoder2_op.shape}")
    aspp2_op = self.ASPP(encoder2_op, 64)
    #print(f"ASPP2 o/p shape {aspp2_op.shape}")
    decoder2_op = self.decoder2(aspp2_op, encoder1_skip_conns, encoder2_skip_conns)
    #print(f"Decoder2 o/p shape {decoder2_op.shape}")
    network2_op = self.output_block(decoder2_op)
    #print(f"Network 2 o/p shape {network2_op.shape}")
    final_output = torch.cat([mask, network2_op], dim = 1)
    #print(f"Final o/p shape {self.final_output.shape}")
    return final_output

Hyperparameters

In [None]:
learning_rate = 1e-5
num_epochs = 300
batch_size = 7
num_batches = num_epochs//batch_size

Data Pre-Processing

In [None]:
img_list = sorted(glob.glob("out/image/*"))
mask_list = sorted(glob.glob("out/mask/*"))

In [None]:
img_list = [img_to_tensor(read_img(ele)) for ele in img_list]
mask_list = [img_to_tensor(read_img(ele)) for ele in mask_list]

img_data = list(zip(img_list,mask_list))

data_len = len(img_list)

In [None]:
#Splitting into 80-10-10

train_set, val_set, test_set = torch.utils.data.random_split(img_data, [round(0.8*data_len), round(0.1*data_len), data_len - round(0.8*data_len) - round(0.1*data_len)])

In [None]:
#Divide Train Data Into List of Batches for Training Loop
train_loader_x = []
train_loader_y = []

for idx in range(0, len(train_set), batch_size):
  x_list, y_list = list(zip(*(list(train_set)[idx:idx + batch_size])))
  train_loader_x.append(x_list)
  train_loader_y.append(y_list)

Define Optimizer, Loss Function

In [37]:
double_u_net = DoubleUNet()
for parameter in double_u_net.parameters():
  print(f"Parameter = {parameter}")

optimizer = optim.NAdam(double_u_net.parameters(), lr = 0.001)

loss = nn.BCELoss()

Parameter = Parameter containing:
tensor([[[[-4.1079e-02, -9.4198e-03, -3.2224e-02],
          [ 1.4317e-02, -1.6763e-02, -1.0992e-01],
          [ 8.9170e-02,  8.3375e-02,  6.5552e-02]],

         [[-2.4049e-02,  3.9336e-02,  6.4972e-02],
          [ 2.8402e-02, -5.6396e-02,  7.6450e-02],
          [ 1.6318e-01,  5.1375e-03,  7.0324e-02]],

         [[ 9.7755e-02, -1.6002e-04,  1.4582e-02],
          [ 3.8255e-02, -5.4637e-02,  5.9850e-02],
          [ 8.8532e-02, -3.9805e-02,  9.3292e-02]]],


        [[[ 1.0265e-02,  2.3144e-02, -2.4944e-02],
          [ 5.8861e-03, -5.1325e-02,  2.3130e-02],
          [-1.2266e-02,  3.3044e-02,  1.5702e-01]],

         [[-1.1765e-01, -3.2747e-02,  6.2388e-02],
          [ 4.1264e-02,  2.9642e-02,  5.2829e-02],
          [ 1.9636e-02, -3.7679e-02, -1.5139e-01]],

         [[ 3.4216e-02, -7.2346e-02,  7.8419e-03],
          [-6.6289e-02, -4.2798e-02, -1.0571e-01],
          [ 8.4092e-02,  7.4742e-02,  6.4863e-02]]],


        [[[ 1.6296e-02,  4.8895e

Training Loop

In [None]:
for epochs in range(num_epochs):
  for idx in range(num_batches):
    input = torch.cat(train_loader_x[idx]) #Shape (batch_size, 3, 288, 384)
    #final_output = build_model(input.float()) shape (batch_size, 2 , 288, 384)
    final_output = double_u_net.forward(input.float())
    print(final_output.shape)
    break
  break

AttributeError: ignored