<a href="https://colab.research.google.com/github/aju22/UNet/blob/main/U_NET_from_Scratch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [9]:
import tensorflow as tf
import tensorflow.keras.layers as layers

# Model Architecture

U-Net is an architecture for semantic segmentation. It consists of a contracting path and an expansive path. 

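The contracting path follows the typical architecture of a convolutional network. It consists of the repeated application of two 3x3 convolutions (unpadded in the original paper), each followed by a rectified linear unit (ReLU), and a 2x2 max pooling operation with stride 2 for downsampling. At each downsampling step the number of feature channels is doubled. (The implementation below deviates slightly: it uses 'same'-padded convolutions, each followed by batch normalization, so the spatial size is preserved within each block.)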

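Every step in the expansive path consists of an upsampling of the feature map followed by a 2x2 convolution (“up-convolution”) that halves the number of feature channels, a concatenation with the correspondingly cropped feature map from the contracting path, and two 3x3 convolutions, each followed by a ReLU. At the final layer, a 1x1 convolution maps each feature vector to the desired number of output channels. (In the code below, the up-convolution is a 2x2 transposed convolution with stride 2, and any size mismatch with the skip connection is resolved by resizing the upsampled map instead of cropping.)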

![](https://production-media.paperswithcode.com/methods/Screen_Shot_2020-07-07_at_9.08.00_PM_rpNArED.png)

In [20]:
class DoubleCNN(layers.Layer):
  """Two stacked ``Conv3x3 -> BatchNorm -> ReLU`` stages.

  Both convolutions use stride 1 and 'same' padding, so the spatial size of
  the input is preserved; the output has ``out_channels`` channels. The
  convolutions are created without a bias term (``use_bias=False``) because
  each one is immediately followed by BatchNormalization.

  Args:
    in_channels: Kept for symmetry with the caller; not used, since Keras
      infers the input width when the layer is first built.
    out_channels: Number of filters in each of the two convolutions.
  """

  def __init__(self, in_channels, out_channels):
    super().__init__()

    # Build the two identical conv/BN/ReLU stages in order.
    stages = []
    for _ in range(2):
      stages += [
          layers.Conv2D(out_channels, kernel_size=3, strides=1,
                        padding='same', use_bias=False),
          layers.BatchNormalization(),
          layers.Activation('relu'),
      ]
    self.convblock = tf.keras.Sequential(stages)

  def call(self, x):
    return self.convblock(x)

In [21]:
class UNet(tf.keras.Model):
  """U-Net built from ``DoubleCNN`` blocks.

  Encoder: for each width in ``features`` a ``DoubleCNN`` is applied, its
  output is kept as a skip connection, and the result is 2x2 max-pooled.
  Bottleneck: a ``DoubleCNN`` that doubles the deepest width.
  Decoder: for each width in reverse order, a 2x2 stride-2 transposed
  convolution upsamples, the matching skip connection is concatenated on the
  channel axis, and a ``DoubleCNN`` fuses the result.
  Head: a 1x1 convolution to ``out_channels`` with no activation.

  Args:
    in_channels: Input channel count handed to the first ``DoubleCNN``
      (which does not use it; Keras infers input widths).
    out_channels: Number of channels produced by the final 1x1 convolution.
    features: Encoder widths, shallowest first. Any non-empty sequence of
      ints; the decoder uses the same widths in reverse order.

  Raises:
    ValueError: If ``features`` is empty.
  """

  def __init__(self, in_channels = 3, out_channels = 1, features = (64, 128, 256, 512)):
    super().__init__()

    # Copy so tuples/lists (or other iterables) all work and a caller's list
    # is never shared with the model.
    features = list(features)
    if not features:
      raise ValueError("UNet requires at least one entry in `features`.")

    # Decoder layers are stored flat, in pairs:
    # ups[2k] is the Conv2DTranspose, ups[2k + 1] the DoubleCNN that follows.
    self.ups = []
    self.downs = []
    self.pool = layers.MaxPooling2D(pool_size = 2, strides = 2)

    for feature in features:
      self.downs.append(DoubleCNN(in_channels, feature))
      in_channels = feature

    for feature in reversed(features):
      self.ups.append(layers.Conv2DTranspose(feature, kernel_size = 2, strides = 2))
      # Input is skip (feature ch) concatenated with upsampled (feature ch).
      self.ups.append(DoubleCNN(feature * 2, feature))

    self.bottleneck = DoubleCNN(features[-1], features[-1] * 2)
    self.final_out = layers.Conv2D(out_channels, kernel_size = 1, padding = 'same')

  def call(self, x):
    skip_connections = []

    for down in self.downs:
      x = down(x)
      skip_connections.append(x)
      x = self.pool(x)

    x = self.bottleneck(x)

    # The deepest encoder output pairs with the first decoder stage.
    skip_connections = skip_connections[::-1]

    for i in range(0, len(self.ups), 2):
      x = self.ups[i](x)

      skip_connection = skip_connections[i // 2]
      # Max pooling floors odd sizes, so after upsampling the map can be
      # smaller than its skip connection. Only the spatial dims (H, W) need to
      # agree for the channel-axis concat, so compare just those and resize.
      if x.shape[1:3] != skip_connection.shape[1:3]:
        x = tf.image.resize(x, tf.shape(skip_connection)[1:3], method = 'bilinear')

      # Plain tf.concat instead of instantiating a Concatenate layer on every
      # forward pass; axis=-1 matches Concatenate's default.
      x = self.ups[i + 1](tf.concat([skip_connection, x], axis = -1))

    return self.final_out(x)

In [25]:
# Instantiate a U-Net with 3 output channels (other arguments at their defaults).
model = UNet(out_channels=3)

In [26]:
# Dummy batch: one 224x224 image with 3 channels, channels-last layout.
x = tf.random.uniform(shape=(1, 224, 224, 3))

In [27]:
# Forward pass sanity check: the output keeps the input's spatial size.
prediction = model(x)
print(prediction.shape)

(1, 224, 224, 3)
