In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, LayerNormalization, LeakyReLU, Concatenate
from tensorflow.keras.models import Model

def unet(H=320, W=320, channels=2, kshape=(3, 3), out_channels=2,
         base_filters=32, depth=4):
    """Build a 2-D U-Net made of layer-normalized LeakyReLU conv blocks.

    The encoder applies ``depth`` stages of (conv block -> 2x2 max-pool),
    doubling the filter count at each stage starting from ``base_filters``.
    A bottleneck conv block follows, then ``depth`` decoder stages of
    (stride-2 transposed conv -> concat with the matching encoder output ->
    conv block). A final 1x1 linear convolution produces the output.

    With the default arguments this is a 4-level U-Net with 32/64/128/256
    encoder filters, a 512-filter bottleneck and a 2-channel output.

    Args:
        H: Input height in pixels, or ``None`` for a variable height (which
            must still be divisible by ``2 ** depth`` at call time).
        W: Input width in pixels, or ``None`` for a variable width (same
            divisibility requirement).
        channels: Number of input channels (last axis of the input).
        kshape: Kernel size of every non-output convolution and of the
            transposed convolutions.
        out_channels: Number of channels produced by the final 1x1 conv.
        base_filters: Filters in the first encoder stage; doubled per stage.
        depth: Number of pooling / upsampling stages.

    Returns:
        An uncompiled ``Model`` mapping ``(batch, H, W, channels)`` to
        ``(batch, H, W, out_channels)``.

    Raises:
        ValueError: If ``depth`` is negative, or if ``H``/``W`` is a fixed
            size not divisible by ``2 ** depth``. In the latter case the
            skip connections could not be concatenated.
    """
    if depth < 0:
        raise ValueError(f"depth must be non-negative, got {depth}")
    factor = 2 ** depth
    for dim_name, size in (("H", H), ("W", W)):
        # Each max-pool floors the size to half while each transposed conv
        # exactly doubles it, so a size not divisible by 2**depth produces
        # mismatched shapes at a skip-connection Concatenate.
        if size is not None and size % factor != 0:
            raise ValueError(
                f"{dim_name}={size} must be divisible by 2**depth={factor}"
            )

    def _conv_block(x, filters):
        """Apply two (conv -> LayerNorm -> LeakyReLU) stages."""
        for _ in range(2):
            x = Conv2D(filters, kshape, padding='same')(x)
            # Normalizes over the last (channel) axis only, with no learned
            # scale/offset. NOTE(review): confirm per-pixel channel
            # normalization (rather than e.g. instance norm) is intended.
            x = LayerNormalization(scale=False, center=False)(x)
            x = LeakyReLU(alpha=0.2)(x)
        return x

    inputs = Input(shape=(H, W, channels))

    # Encoder: keep each stage's pre-pool output for the skip connections.
    skips = []
    x = inputs
    for level in range(depth):
        x = _conv_block(x, base_filters * 2 ** level)
        skips.append(x)
        x = tf.keras.layers.MaxPooling2D(pool_size=(2, 2))(x)

    # Bottleneck
    x = _conv_block(x, base_filters * 2 ** depth)

    # Decoder: learned stride-2 upsampling (Conv2DTranspose instead of
    # UpSampling2D), then fuse with the encoder output of the same size.
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        x = Conv2DTranspose(filters, kshape, strides=(2, 2), padding='same')(x)
        x = Concatenate()([x, skips[level]])
        x = _conv_block(x, filters)

    outputs = Conv2D(out_channels, (1, 1), activation='linear')(x)
    return Model(inputs=inputs, outputs=outputs)

