In [2]:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import Input, Conv2D, MaxPooling2D, Concatenate, Conv2DTranspose

In [52]:
class DoubleConv(keras.layers.Layer):
    def __init__(self, num_filters):
        super(DoubleConv, self).__init__()
        self.dl1 = Conv2D(num_filters, kernel_size=(3,3), strides=(1,1), padding='valid', activation='relu')
        self.dl2 = Conv2D(num_filters, kernel_size=(3,3), strides=(1,1), padding='valid', activation='relu')
        
    def call(self, x):
        x = self.dl1(x)
        x = self.dl2(x)
        return x

In [53]:
def crop_tensor(given, target):
    delta = (given.shape[2] - target.shape[2]) // 2
    return given[:, delta:given.shape[2]-delta, delta:given.shape[2]-delta,:]
    

In [54]:
# Check if crop_tensor function is working
given = tf.ones((1, 64, 64, 512))
target = tf.ones((1, 56, 56, 512))
y = crop_tensor(given, target)
print(y.shape)

(1, 56, 56, 512)


In [55]:
class DoubleUpConv(keras.layers.Layer):
    def __init__(self, num_filters):
        super(DoubleUpConv, self).__init__()
        self.uc = Conv2DTranspose(num_filters, kernel_size=(2,2), strides=(2,2), padding='valid', activation='relu')
        self.dl1 = Conv2D(num_filters, kernel_size=(3,3), strides=(1,1), padding='valid', activation='relu')
        self.dl2 = Conv2D(num_filters, kernel_size=(3,3), strides=(1,1), padding='valid', activation='relu')
        
    def call(self, x, prev):
        x1 = self.uc(x)
        x2 = crop_tensor(prev, x1)
#         print(f"x1 shape: {x1.shape}")
#         print(f"x2 shape: {x2.shape}")
        x = tf.concat([x2, x1], axis=3)
#         print(f"In up_conv{x.shape}")
        x = self.dl1(x)
        x = self.dl2(x)
        return x

In [56]:
class UNET(keras.Model):
    def __init__(self):
        super(UNET, self).__init__()
        self.dc64 = DoubleConv(64)
        self.dc128 = DoubleConv(128)
        self.dc256 = DoubleConv(256)
        self.dc512 = DoubleConv(512)
        self.dc1024 = DoubleConv(1024)
        self.maxpool = MaxPooling2D(pool_size=(2, 2), strides=None, padding='valid')
        
        self.uc1 = DoubleUpConv(512)
        self.uc2 = DoubleUpConv(256)
        self.uc3 = DoubleUpConv(128)
        self.uc4 = DoubleUpConv(64)
        
        self.conv_1x1 = Conv2D(2, kernel_size=(1,1), strides=(1,1), padding='valid', activation='relu')
        
    def call(self, x):
        x1 = self.dc64(x)
        x2 = self.maxpool(x1)
        x2 = self.dc128(x2)
        x3 = self.maxpool(x2)
        x3 = self.dc256(x3)
        x4 = self.maxpool(x3)
        x4 = self.dc512(x4)
        x5 = self.maxpool(x4)
        x5 = self.dc1024(x5)
        print(x5.shape)
        
        x6 = self.uc1(x5, x4)
        x7 = self.uc2(x6, x3)
        x8 = self.uc3(x7, x2)
        x9 = self.uc4(x8, x1)
        print(x9.shape)
        
        x10 = self.conv_1x1(x9)
        print(x10.shape)
        return x10      

In [None]:
# Hyperparameters
BATCH_SIZE = 1

In [57]:
model = UNET()
x = tf.ones(shape=(BATCH_SIZE, 572, 572, 1))
y = model(x)
#model.build((1, 572, 572, 1))

(1, 28, 28, 1024)
(1, 388, 388, 64)
(1, 388, 388, 2)


In [31]:
# input_shape = (572, 572, 1)
# inputs = Input(shape=input_shape)
# x1 = Conv2D(64, (3, 3), padding='valid', activation='relu')(inputs)
# x2 = Conv2D(64, (3, 3), padding='valid', activation='relu')(x1)
# x3 = MaxPooling2D((2, 2), padding='valid')(x2)
# x4 = Conv2D(128, (3, 3), padding='valid', activation='relu')(x3)
# x5 = Conv2D(128, (3, 3), padding='valid', activation='relu')(x4)
# x6 = MaxPooling2D((2, 2), padding='valid')(x5)
# x7 = Conv2D(256, (3, 3), padding='valid', activation='relu')(x6)
# x8 = Conv2D(256, (3, 3), padding='valid', activation='relu')(x7)
# x9 = MaxPooling2D((2, 2), padding='valid')(x8)
# x10 = Conv2D(512, (3, 3), padding='valid', activation='relu')(x9)
# x11 = Conv2D(512, (3, 3), padding='valid', activation='relu')(x10)
# x12 = MaxPooling2D((2, 2), padding='valid')(x11)
# x13 = Conv2D(1024, (3, 3), padding='valid', activation='relu')(x12)
# x14 = Conv2D(1024, (3, 3), padding='valid', activation='relu')(x13)
# x15 = Conv2DTranspose(1024, (2, 2), strides=(2,2), padding='valid', activation='relu')(x14)
# x16 = Conv2D(512, (3, 3), padding='valid', activation='relu')(x15)
# x17 = Conv2D(512, (3, 3), padding='valid', activation='relu')(x16)
# x18 = Conv2DTranspose(1024, (2, 2), strides=(2,2), padding='valid', activation='relu')(x17)
# print(x17.shape)

(None, 52, 52, 512)
