In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D, Conv2DTranspose, MaxPooling2D, concatenate, Dropout
from tensorflow_addons.layers.normalizations import GroupNormalization as CustomGroupNorm



TensorFlow Addons (TFA) has ended development and introduction of new features.
TFA has entered a minimal maintenance and release mode until a planned end of life in May 2024.
Please modify downstream libraries to take dependencies from other repositories in our TensorFlow community (e.g. Keras, Keras-CV, and Keras-NLP). 

For more information see: https://github.com/tensorflow/addons/issues/2807 



In [2]:
class UNetMidBlock2DCrossAttn(tf.keras.Model):
    """Bottleneck block used between the U-Net encoder and decoder.

    Applies two 3x3, 'same'-padded, ReLU ``Conv2D`` layers followed by ``Dropout``.
    The spatial size is preserved. The output always has ``filters`` channels,
    regardless of how many channels the input has.

    NOTE(review): despite its name, this block contains no attention (neither
    cross- nor self-attention). It is a plain conv + dropout stack. The name is
    kept unchanged so existing callers keep working.

    Args:
        filters: Number of output channels of both convolutions. Defaults to
            128, the value that was previously hard-coded.
        dropout_rate: Fraction of units dropped in training mode. Defaults to 0.5.
        **kwargs: Forwarded to ``tf.keras.Model`` (e.g. ``name``).
    """

    def __init__(self, filters=128, dropout_rate=0.5, **kwargs):
        super().__init__(**kwargs)

        self.conv1 = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')
        self.conv2 = Conv2D(filters=filters, kernel_size=(3, 3), padding='same', activation='relu')
        self.dropout = Dropout(dropout_rate)

    def call(self, inputs, training=None):
        """Run conv -> conv -> dropout on ``inputs``.

        Args:
            inputs: 4-D feature map; presumably channels-last
                ``(batch, H, W, C)`` (Keras default) — confirm for other setups.
            training: Optional bool. Keras fills it in from the surrounding call
                context (e.g. ``fit`` vs ``predict``) when not given explicitly.

        Returns:
            Feature map with the same spatial size and ``filters`` channels.
        """
        x = self.conv1(inputs)
        x = self.conv2(x)
        # Forward `training` explicitly so the dropout mode is unambiguous.
        return self.dropout(x, training=training)

In [3]:
def _double_conv(x, filters):
    """Apply two 3x3, 'same'-padded, ReLU Conv2D layers with ``filters`` channels to ``x``."""
    for _ in range(2):
        x = Conv2D(filters=filters, kernel_size=(3, 3), padding="same", activation='relu')(x)
    return x


def Code2Model(height, width, start_neurons, channels=3):
    """Build a 4-level U-Net-style encoder/decoder ending in a 1-channel sigmoid map.

    Architecture:
      * Encoder: four stages of [two 3x3 ReLU convs -> 2x2 max-pool -> dropout].
        The stages use ``start_neurons * (1, 2, 4, 8)`` filters and dropout rates
        (0.25, 0.5, 0.5, 0.5).
      * Bottleneck: two 3x3 ReLU convs with ``start_neurons * 16`` filters,
        followed by ``UNetMidBlock2DCrossAttn()`` (constructed with its defaults).
      * Decoder: four stages of [3x3 stride-2 transposed conv -> concatenate with
        the matching encoder skip -> dropout 0.5 -> two 3x3 ReLU convs]. The stages
        use ``start_neurons * (8, 4, 2, 1)`` filters.
      * Head: group normalization (16 groups) -> dropout 0.5 -> 1x1 ReLU conv to
        4 channels -> 1x1 sigmoid conv to 1 channel.

    Args:
        height: Input height in pixels, or ``None`` for a dynamic size. Must be a
            positive multiple of 16 so the four 2x max-poolings are exactly undone
            by the four stride-2 transposed convolutions (e.g. 32 for CIFAR-10).
        width: Input width in pixels, or ``None``; same constraint as ``height``.
        start_neurons: Filter count of the first encoder stage. Must be a positive
            multiple of 16, because the final group normalization splits the
            ``start_neurons`` channels into 16 groups.
        channels: Number of input channels. Defaults to 3 (RGB).

    Returns:
        An uncompiled ``tf.keras.Model``. It maps inputs of shape
        ``(batch, height, width, channels)`` to outputs of shape
        ``(batch, height, width, 1)``, with values in [0, 1] (sigmoid).

    Raises:
        ValueError: If ``height``/``width`` is not a positive multiple of 16, or
            if ``start_neurons`` is not a positive multiple of 16.
    """
    downsample_factor = 16  # 2 ** number of encoder pooling stages
    group_norm_groups = 16

    # Fail early with a clear message. Otherwise the mismatch only surfaces as an
    # opaque shape error inside `concatenate` or the group-norm build.
    for dim_name, size in (("height", height), ("width", width)):
        if size is not None and (size < 1 or size % downsample_factor != 0):
            raise ValueError(
                f"{dim_name} must be a positive multiple of {downsample_factor}, got {size}"
            )
    if start_neurons < 1 or start_neurons % group_norm_groups != 0:
        raise ValueError(
            f"start_neurons must be a positive multiple of {group_norm_groups} "
            f"(group normalization groups), got {start_neurons}"
        )

    input_layer = Input(shape=(height, width, channels))

    # Encoder: keep each stage's pre-pooling activation as a skip connection.
    skips = []
    x = input_layer
    for multiplier, rate in zip((1, 2, 4, 8), (0.25, 0.5, 0.5, 0.5)):
        x = _double_conv(x, start_neurons * multiplier)
        skips.append(x)
        x = MaxPooling2D((2, 2))(x)
        x = Dropout(rate)(x)

    # Bottleneck.
    x = _double_conv(x, start_neurons * 16)
    x = UNetMidBlock2DCrossAttn()(x)

    # Decoder: upsample 2x, then fuse with the encoder skip of matching resolution.
    for multiplier, skip in zip((8, 4, 2, 1), reversed(skips)):
        x = Conv2DTranspose(start_neurons * multiplier, (3, 3), strides=(2, 2), padding="same")(x)
        x = concatenate([x, skip])
        x = Dropout(0.5)(x)
        x = _double_conv(x, start_neurons * multiplier)

    # Head.
    x = CustomGroupNorm(groups=group_norm_groups, axis=-1)(x)
    x = Dropout(0.5)(x)
    x = Conv2D(4, (1, 1), padding="same", activation='relu')(x)
    output_layer = Conv2D(1, (1, 1), padding="same", activation='sigmoid')(x)

    return tf.keras.Model(inputs=input_layer, outputs=output_layer)
