<a href="https://colab.research.google.com/github/Shiveshrane/Research_paper_implementations/blob/main/UNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import tensorflow as tf
import cv2
import pandas
import numpy as np
import matplotlib.pyplot as plt


In [None]:
def doubleConv(outdim):
  """Build a block of two stacked 3x3 ReLU convolutions.

  Args:
    outdim: number of filters used by each of the two convolutions.

  Returns:
    A `tf.keras.Sequential` holding the two `Conv2D` layers in order. No
    `padding` argument is passed, so Keras' default padding mode applies.
  """
  stacked=[
      tf.keras.layers.Conv2D(outdim, 3, activation='relu')
      for _ in range(2)
  ]
  return tf.keras.Sequential(stacked)

In [None]:
def crop_img(tensor, target):
  """Center-crop `tensor` to the spatial size of `target`.

  Both arguments are indexed as (batch, height, width, channels): axes 1 and 2
  are cropped, axes 0 and 3 are returned untouched.

  Args:
    tensor: rank-4 array/tensor to crop (anything exposing `.shape` and
      supporting 4-axis slicing).
    target: object whose `.shape[1]` and `.shape[2]` give the desired
      height and width.

  Returns:
    The central `target_h` x `target_w` window of `tensor`. When the size
    difference is odd, the extra row/column is removed from the bottom/right.

  Raises:
    ValueError: if `target` is spatially larger than `tensor`. Without this
      check the offsets would be negative and the slice would silently
      return a window of the wrong size.
  """
  _, h, w, _ = tensor.shape
  target_h, target_w = target.shape[1], target.shape[2]
  if target_h > h or target_w > w:
    raise ValueError(
        f"Cannot crop spatial size ({h}, {w}) to larger target "
        f"({target_h}, {target_w})")
  offset_h = (h - target_h) // 2
  offset_w = (w - target_w) // 2
  return tensor[:, offset_h:offset_h + target_h, offset_w:offset_w + target_w, :]

In [None]:
class UNet(tf.keras.Model):
  """U-Net style encoder-decoder for image segmentation.

  The encoder applies five double-convolution blocks (64 to 1024 filters)
  separated by 2x2 max-pooling. The decoder upsamples four times with
  stride-2 transposed convolutions; after each upsampling the matching
  encoder feature map is center-cropped, concatenated on the channel axis
  and passed through another double-convolution block. A final 1x1
  convolution produces the output map.

  The output is spatially smaller than the input: in this notebook a
  572x572 input yields a 388x388 map.

  Args:
    out_channels: number of channels produced by the final 1x1 convolution.
      Defaults to 1 (single-channel mask).
    out_activation: activation of the final 1x1 convolution. Defaults to
      'sigmoid'.
    **kwargs: forwarded to `tf.keras.Model` (e.g. `name`).
  """

  def __init__(self, out_channels=1, out_activation='sigmoid', **kwargs):
    super().__init__(**kwargs)
    # Encoder blocks.
    self.down1=doubleConv(64)
    self.down2=doubleConv(128)
    self.down3=doubleConv(256)
    self.down4=doubleConv(512)
    self.down5=doubleConv(1024)
    # One pooling layer is reused between all encoder stages.
    self.maxpool=tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=2)
    # Transposed convolutions that upsample between decoder stages.
    self.uptrans_1=tf.keras.layers.Conv2DTranspose(512,2,strides=2,padding='same')
    self.uptrans_2=tf.keras.layers.Conv2DTranspose(256,2,strides=2,padding='same')
    self.uptrans_3=tf.keras.layers.Conv2DTranspose(128,2,strides=2,padding='same')
    self.uptrans_4=tf.keras.layers.Conv2DTranspose(64,2,strides=2,padding='same')

    # Decoder blocks applied after each skip concatenation.
    self.up1=doubleConv(512)
    self.up2=doubleConv(256)
    self.up3=doubleConv(128)
    self.up4=doubleConv(64)
    # Output head: 1x1 convolution mapping features to the prediction map.
    self.conv=tf.keras.layers.Conv2D(out_channels,1,activation=out_activation)

  def call(self, inputs):
    """Run the forward pass.

    Args:
      inputs: batch of images indexed as (batch, height, width, channels).

    Returns:
      The output of the final 1x1 convolution.
    """
    # Encoder: keep every pre-pooling feature map for the skip connections.
    x1=self.down1(inputs)
    x2=self.down2(self.maxpool(x1))
    x3=self.down3(self.maxpool(x2))
    x4=self.down4(self.maxpool(x3))
    x5=self.down5(self.maxpool(x4))

    # Decoder: deepest skip connection first.
    x=self._up_step(x5, x4, self.uptrans_1, self.up1)
    x=self._up_step(x, x3, self.uptrans_2, self.up2)
    x=self._up_step(x, x2, self.uptrans_3, self.up3)
    x=self._up_step(x, x1, self.uptrans_4, self.up4)
    return self.conv(x)

  def _up_step(self, x, skip, upsample, convs):
    """Upsample `x`, concatenate the cropped `skip` map, then apply `convs`."""
    x=upsample(x)
    # The encoder map is larger than the upsampled one, so crop it to match.
    skip=crop_img(skip, x)
    return convs(tf.concat([skip, x], axis=-1))





In [None]:
# Smoke test: push one random 572x572 RGB image through an untrained network.
unet=UNet()
unet(tf.random.uniform((1,572,572,3))) # default head: 1 sigmoid channel (binary mask)

<tf.Tensor: shape=(1, 388, 388, 1), dtype=float32, numpy=
array([[[[0.49149704],
         [0.49796534],
         [0.50300246],
         ...,
         [0.49622294],
         [0.4953434 ],
         [0.4965124 ]],

        [[0.4939871 ],
         [0.4941092 ],
         [0.48149043],
         ...,
         [0.5010762 ],
         [0.48987046],
         [0.48773307]],

        [[0.5119833 ],
         [0.4970885 ],
         [0.4915133 ],
         ...,
         [0.4985525 ],
         [0.505046  ],
         [0.50246906]],

        ...,

        [[0.50317097],
         [0.49991584],
         [0.4934016 ],
         ...,
         [0.48956007],
         [0.503519  ],
         [0.49633968]],

        [[0.4962588 ],
         [0.50139284],
         [0.4878854 ],
         ...,
         [0.49714774],
         [0.4913343 ],
         [0.49181226]],

        [[0.49240118],
         [0.50439084],
         [0.48667938],
         ...,
         [0.49425632],
         [0.4987619 ],
         [0.50031507]]]], dty