<a href="https://colab.research.google.com/github/ShankarPoudel441/U-Net_implementations/blob/main/U_Net_implementations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf
from tensorflow import keras

In [2]:
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model


def conv_block(inputs, num_filters):
  """Apply two back-to-back Conv2D -> BatchNormalization -> ReLU stages.

  Both convolutions use a 3x3 kernel with "same" padding and
  `num_filters` output channels.

  Args:
    inputs: Tensor fed into the first convolution.
    num_filters: Number of filters used by each of the two convolutions.

  Returns:
    The tensor produced by the second ReLU activation.
  """
  features = inputs
  # The two stages are identical, so build them in a loop.
  for _ in range(2):
    features = Conv2D(num_filters, 3, padding="same")(features)
    features = BatchNormalization()(features)
    features = Activation("relu")(features)
  return features

def encoder_block(inputs, num_filters):
  """Run a conv block and then downsample its output with 2x2 max pooling.

  Args:
    inputs: Tensor entering this encoder level.
    num_filters: Filter count passed to `conv_block`.

  Returns:
    A `(skip, pooled)` tuple: the conv-block output before pooling (used
    later as a skip connection) and the max-pooled tensor.
  """
  skip = conv_block(inputs, num_filters)
  pooled = MaxPool2D(pool_size=(2, 2))(skip)
  return skip, pooled
  
def decoder_block(inputs, skip_features, num_filters):
  """Upsample `inputs`, join it with `skip_features`, and refine the result.

  Args:
    inputs: Tensor coming from the previous (coarser) level.
    skip_features: Encoder tensor concatenated with the upsampled input.
    num_filters: Filter count for the transposed convolution and the
      following `conv_block`.

  Returns:
    The output tensor of the conv block applied to the concatenation.
  """
  upsampled = Conv2DTranspose(
      num_filters, (2, 2), strides=2, padding="same"
  )(inputs)
  merged = Concatenate()([upsampled, skip_features])
  return conv_block(merged, num_filters)

def build_unet(input_shape):
  """Assemble a four-level U-Net with a one-channel sigmoid output.

  The encoder uses 64/128/256/512 filters, the bridge uses 1024, and the
  decoder mirrors the encoder while consuming the skip tensors in reverse
  order.

  Args:
    input_shape: Shape passed to `Input` for the model's input tensor.

  Returns:
    A Keras `Model` named "U-Net".
  """
  inputs = Input(input_shape)

  # Encoder: keep each level's pre-pooling output for the skip connections.
  skips = []
  x = inputs
  for filters in (64, 128, 256, 512):
    skip, x = encoder_block(x, filters)
    skips.append(skip)

  # Bridge between the encoder and the decoder.
  x = conv_block(x, 1024)

  # Decoder: deepest skip first, mirroring the encoder filter counts.
  for filters, skip in zip((512, 256, 128, 64), reversed(skips)):
    x = decoder_block(x, skip, filters)

  # Output: 1x1 convolution producing a single sigmoid-activated channel.
  outputs = Conv2D(1, (1, 1), padding="same", activation="sigmoid")(x)

  return Model(inputs, outputs, name="U-Net")

In [3]:
# input_shape=(512,512,3)
# model=build_unet(input_shape)
# model.summary()

In [14]:
from json import decoder
#Modified UNET


from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, MaxPool2D, Conv2DTranspose, Concatenate, Input
from tensorflow.keras.models import Model
from tensorflow.keras.applications import VGG16


def conv_block(inputs, num_filters):
  """Stack two Conv2D(3x3, "same") + BatchNormalization + ReLU stages.

  Args:
    inputs: Tensor passed to the first convolution.
    num_filters: Number of filters for each convolution.

  Returns:
    The activated output tensor of the second stage.
  """
  x = inputs
  for _ in range(2):
    # Instantiate one stage's layers in order, then thread the tensor
    # through them.
    stage = (
        Conv2D(num_filters, 3, padding="same"),
        BatchNormalization(),
        Activation("relu"),
    )
    for layer in stage:
      x = layer(x)
  return x


def decoder_block(inputs, skip_features, num_filters):
  """Upsample `inputs`, concatenate with `skip_features`, then run a conv block.

  Prints the shapes of the upsampled tensor and the skip tensor before
  they are concatenated.

  Args:
    inputs: Tensor from the previous (coarser) level.
    skip_features: Encoder tensor to concatenate with the upsampled input.
    num_filters: Filter count for the transposed convolution and the
      following `conv_block`.

  Returns:
    The conv-block output computed from the concatenated tensors.
  """
  upsample = Conv2DTranspose(num_filters, (2, 2), strides=2, padding="same")
  upsampled = upsample(inputs)
  # Shape trace emitted while the graph is being built.
  print("x and skip features shape", upsampled.shape, skip_features.shape)
  merged = Concatenate()([upsampled, skip_features])
  return conv_block(merged, num_filters)

def build_VGG_16_unet(input_shape):
  """Build a U-Net whose encoder is a VGG16 created with ImageNet weights.

  Skip connections are taken from the last convolution of VGG16 blocks 1-4
  and the bridge is the last convolution of block 5. Four decoder blocks
  then upsample back to the input resolution, followed by a 1x1 sigmoid
  convolution for binary segmentation.

  Args:
    input_shape: Shape passed to `Input(shape=...)`. Height and width
      should be multiples of 16: the bridge sits behind four 2x poolings,
      and each decoder block doubles the size before concatenating with
      the matching skip tensor.

  Returns:
    A Keras `Model` named "VGG16_U-Net" mapping the input tensor to a
    single-channel sigmoid output.
  """
  inputs = Input(shape=input_shape)

  vgg16 = VGG16(include_top=False, weights="imagenet", input_tensor=inputs)

  # Encoder: tap the LAST conv of each block (just before its pooling layer).
  # Block 1 previously tapped "block1_conv1", which was inconsistent with
  # the other levels and skipped the second block-1 convolution. Both
  # layers output 64 channels at full resolution, so shapes are unchanged.
  s1 = vgg16.get_layer("block1_conv2").output
  s2 = vgg16.get_layer("block2_conv2").output
  s3 = vgg16.get_layer("block3_conv3").output
  s4 = vgg16.get_layer("block4_conv3").output

  # Bridge: deepest VGG16 feature map used by this model.
  b1 = vgg16.get_layer("block5_conv3").output

  # Decoder: upsample and merge with the skips, deepest first.
  d1 = decoder_block(b1, s4, 512)
  d2 = decoder_block(d1, s3, 256)
  d3 = decoder_block(d2, s2, 128)
  d4 = decoder_block(d3, s1, 64)

  # Output for binary segmentation: one sigmoid-activated channel.
  outputs = Conv2D(1, 1, padding="same", activation="sigmoid")(d4)

  # Output for multiclass segmentation (alternative head, left disabled):
  # outputs = Conv2D(5, 1, padding="same", activation="softmax")(d4)

  model = Model(inputs, outputs, name="VGG16_U-Net")
  return model

In [15]:
# Build the VGG16-encoder U-Net for a (512, 512, 3) input and print its
# layer-by-layer summary.
input_shape = (512, 512, 3)
model=build_VGG_16_unet(input_shape)
model.summary()

x and skip features shape (None, 64, 64, 512) (None, 64, 64, 512)
x and skip features shape (None, 128, 128, 256) (None, 128, 128, 256)
x and skip features shape (None, 256, 256, 128) (None, 256, 256, 128)
x and skip features shape (None, 512, 512, 64) (None, 512, 512, 64)
Model: "VGG16_U-Net"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_6 (InputLayer)           [(None, 512, 512, 3  0           []                               
                                )]                                                                
                                                                                                  
 block1_conv1 (Conv2D)          (None, 512, 512, 64  1792        ['input_6[0][0]']                
                                )                                                                 
            