## U-Net: Convolutional Networks for Biomedical Image Segmentation

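Paper: [U-Net: Convolutional Networks for Biomedical Image Segmentation](https://arxiv.org/abs/1505.04597) (Ronneberger, Fischer & Brox, 2015). This notebook implements the architecture with TensorFlow/Keras and prints the tensor shape after every stage so it can be compared with the figure below.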

![U-Net](https://lmb.informatik.uni-freiburg.de/people/ronneber/u-net/u-net-architecture.png)

In [1]:
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers

# Ask TensorFlow to allocate GPU memory on demand instead of reserving it all
# up front. Iterating over the list (rather than indexing [0]) keeps this cell
# runnable on CPU-only machines, where the list is empty, and applies the same
# setting to every GPU when more than one is present.
physical_devices = tf.config.list_physical_devices('GPU')
for gpu in physical_devices:
    tf.config.experimental.set_memory_growth(gpu, True)

print(f'''TensorFlow: {tf.__version__}
Physical Devices: {physical_devices}''')

TensorFlow: 2.9.2
Physical Devices: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]


In [2]:
class DoubleConv(layers.Layer):
    """Two consecutive Conv2D -> BatchNormalization -> ReLU stages.

    Both convolutions use a 3x3 kernel with stride 1 and do not set
    ``padding``, so Keras' default ('valid') applies and each stage trims one
    pixel from every border of the feature map.

    Args:
        out_filters: Number of output channels produced by each convolution.
    """

    def __init__(self, out_filters):
        super().__init__()
        self.conv1 = self._make_stage(out_filters)
        self.conv2 = self._make_stage(out_filters)

    @staticmethod
    def _make_stage(n_filters):
        """Build one Conv2D -> BatchNormalization -> ReLU pipeline."""
        return keras.Sequential([
            layers.Conv2D(filters=n_filters, kernel_size=3, strides=1),
            layers.BatchNormalization(),
            layers.Activation('relu'),
        ])

    def call(self, x, training=False):
        """Run ``x`` through both stages, forwarding the ``training`` flag."""
        for stage in (self.conv1, self.conv2):
            x = stage(x, training=training)
        return x

In [3]:
def crop_image(x, target):
    """Center-crop the two middle axes of a 4-D tensor to a target size.

    Args:
        x: Array/tensor indexed as ``x[:, rows, cols, :]`` whose ``shape``
            has static (integer) entries at positions 1 and 2.
        target: Shape-like sequence; ``target[1]`` and ``target[2]`` are the
            desired sizes of axes 1 and 2. Other entries are ignored.

    Returns:
        The slice of ``x`` whose axes 1 and 2 have exactly the target sizes.
        When the size difference is odd, the extra row/column is dropped from
        the bottom/right edge.

    Raises:
        ValueError: If ``x`` is smaller than the target along axis 1 or 2.
    """
    target_h, target_w = target[1], target[2]
    top = (x.shape[1] - target_h) // 2
    left = (x.shape[2] - target_w) // 2
    if top < 0 or left < 0:
        raise ValueError(
            f'Cannot crop input of shape {tuple(x.shape)} to larger target '
            f'size ({target_h}, {target_w}).'
        )
    # Slice with explicit end indices: a `t:-t` slice would be empty when
    # t == 0 and one pixel too large when the size difference is odd.
    return x[:, top:top + target_h, left:left + target_w, :]

In [4]:
class UNet(keras.Model):
    """U-Net style encoder/decoder assembled from ``DoubleConv`` blocks.

    The contracting path applies four ``DoubleConv`` blocks (64 to 512
    filters), each followed by 2x2 max pooling, then a 1024-filter bottleneck.
    The expanding path applies four stride-2 transposed convolutions; after
    each one the matching encoder output is center-cropped with
    ``crop_image``, concatenated, and passed through another ``DoubleConv``.
    A final convolution maps the result to 2 output channels.

    ``call`` prints the tensor shape after every stage so the data flow can be
    followed in the notebook output.
    """

    def __init__(self):
        super().__init__()
        # Contracting path.
        self.down_conv1 = DoubleConv(out_filters=64)
        self.down_conv2 = DoubleConv(out_filters=128)
        self.down_conv3 = DoubleConv(out_filters=256)
        self.down_conv4 = DoubleConv(out_filters=512)
        self.down_conv5 = DoubleConv(out_filters=1024)

        # Expanding path: each transposed convolution doubles the spatial size
        # and is followed by a DoubleConv applied to the concatenated tensor.
        self.up_conv1_trans = layers.Conv2DTranspose(filters=512, kernel_size=2, strides=2, padding='same')
        self.up_conv1 = DoubleConv(out_filters=512)
        self.up_conv2_trans = layers.Conv2DTranspose(filters=256, kernel_size=2, strides=2, padding='same')
        self.up_conv2 = DoubleConv(out_filters=256)
        self.up_conv3_trans = layers.Conv2DTranspose(filters=128, kernel_size=2, strides=2, padding='same')
        self.up_conv3 = DoubleConv(out_filters=128)
        self.up_conv4_trans = layers.Conv2DTranspose(filters=64, kernel_size=2, strides=2, padding='same')
        self.up_conv4 = DoubleConv(out_filters=64)

        # Output head producing 2 channels.
        # NOTE(review): the U-Net paper uses a 1x1 convolution here; this
        # notebook uses a 3x3 'same' convolution — confirm this is intended.
        self.conv = layers.Conv2D(filters=2, kernel_size=3, strides=1, padding='same')
        # A single pooling layer is reused after every encoder block.
        self.maxpool = layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))

    def call(self, x, training=False):
        """Run the forward pass, printing the shape after every stage.

        Args:
            x: Input tensor; the notebook feeds a (1, 572, 572, 3) tensor.
            training: Forwarded to every ``DoubleConv`` block so that their
                batch-normalization layers run in the requested mode.

        Returns:
            Tensor with 2 channels in the last axis.
        """
        encoder_blocks = (
            self.down_conv1,
            self.down_conv2,
            self.down_conv3,
            self.down_conv4,
        )
        decoder_stages = (
            (self.up_conv1_trans, self.up_conv1),
            (self.up_conv2_trans, self.up_conv2),
            (self.up_conv3_trans, self.up_conv3),
            (self.up_conv4_trans, self.up_conv4),
        )

        # Contracting path: keep each block's pre-pooling output as a skip.
        skips = []
        for idx, block in enumerate(encoder_blocks, start=1):
            x = block(x, training=training)
            skips.append(x)
            print(f'Down Conv{idx}: {x.shape}')
            x = self.maxpool(x)
            print(f'Down Conv{idx} Maxpool: {x.shape}')

        # Bottleneck. The training flag must be forwarded here as well;
        # otherwise these batch-norm layers ignore the requested mode.
        x = self.down_conv5(x, training=training)
        print(f'Down Conv5: {x.shape}')

        # Expanding path: consume the skips deepest-first.
        for idx, (upsample, block) in enumerate(decoder_stages, start=1):
            x = upsample(x)
            skip = crop_image(skips.pop(), x.shape)
            print(f'Up ConvTranspose{idx}: {x.shape}')
            x = block(layers.concatenate([x, skip]), training=training)
            print(f'Up Conv{idx}: {x.shape}')

        x = self.conv(x)
        print(f'Conv: {x.shape}')

        return x
        

In [6]:
# Smoke test: push one random 572x572 RGB-sized tensor through the network.
model = UNet()
x = tf.random.normal((1, 572, 572, 3))
y = model(x)

# Unpadded convolutions shrink the feature maps, so a 572x572 input is
# expected to produce a 388x388 map with 2 channels.
assert tuple(y.shape) == (1, 388, 388, 2), f'unexpected output shape: {y.shape}'

Down Conv1: (1, 568, 568, 64)
Down Conv1 Maxpool: (1, 284, 284, 64)
Down Conv2: (1, 280, 280, 128)
Down Conv2 Maxpool: (1, 140, 140, 128)
Down Conv3: (1, 136, 136, 256)
Down Conv3 Maxpool: (1, 68, 68, 256)
Down Conv4: (1, 64, 64, 512)
Down Conv4 Maxpool: (1, 32, 32, 512)
Down Conv5: (1, 28, 28, 1024)
Up ConvTranspose1: (1, 56, 56, 512)
Up Conv1: (1, 52, 52, 512)
Up ConvTranspose2: (1, 104, 104, 256)
Up Conv2: (1, 100, 100, 256)
Up ConvTranspose3: (1, 200, 200, 128)
Up Conv3: (1, 196, 196, 128)
Up ConvTranspose4: (1, 392, 392, 64)
Up Conv4: (1, 388, 388, 64)
Conv: (1, 388, 388, 2)
