In [1]:
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, Conv2DTranspose, Dropout, BatchNormalization, MaxPooling2D
from tensorflow.keras import mixed_precision

# Set the global Keras dtype policy to 'mixed_float16' before any layer is
# created, so the layers built by the functions below pick it up by default.
mixed_precision.set_global_policy('mixed_float16')
def resBlock(input_layer, filter_nbr, dropout_rate, kernel_size=(3, 3), stride=1, layer_name="rb", training=True,
             pooling=True, repetition=1):
    """Build one residual encoder block on top of ``input_layer``.

    The block is a 1x1 shortcut convolution added to a main path made of two
    ``kernel_size`` convolutions, each followed by batch normalization. The
    sum is passed through dropout and, optionally, 2x2 max pooling.

    Every explicitly named layer is prefixed with ``layer_name`` so the block
    can be instantiated several times in one model; Keras requires layer
    names to be unique when a ``Model`` is built.

    Args:
        input_layer: Symbolic Keras tensor to build on.
        filter_nbr: Number of filters used by every convolution in the block.
        dropout_rate: Rate given to the ``Dropout`` layer.
        kernel_size: Kernel size of the two main-path convolutions.
        stride: Stride of the 1x1 shortcut convolution only. The main path is
            not strided, so any value other than 1 produces mismatched
            spatial sizes at the addition.
        layer_name: Prefix for the names of the layers created here; must
            differ between calls that end up in the same model.
        training: Unused; kept for interface compatibility.
        pooling: When true, also apply 2x2 max pooling after dropout.
        repetition: Unused; kept for interface compatibility.

    Returns:
        ``(pooled, pre_dropout_sum)`` when ``pooling`` is true, where the
        second element is the residual sum intended as a skip connection;
        otherwise only the dropout output.
    """
    shortcut = Conv2D(filter_nbr, (1, 1), strides=stride, activation=tf.nn.leaky_relu,
                      name=layer_name + '_s_0')(input_layer)

    resA = Conv2D(filter_nbr, kernel_size, padding='same', activation=tf.nn.leaky_relu,
                  name=layer_name + '_conv1')(input_layer)
    resA = BatchNormalization()(resA)
    resA = Conv2D(filter_nbr, kernel_size, padding='same', activation=tf.nn.leaky_relu,
                  name=layer_name + '_conv2')(resA)
    resA = BatchNormalization()(resA)

    # A real Keras layer (instead of a raw tf.add on symbolic tensors) keeps
    # the residual sum a named, serializable part of the functional graph.
    resA = tf.keras.layers.Add(name=layer_name + '_add')([resA, shortcut])

    # The dropout name must be block-specific: a fixed "dropout" name repeated
    # across blocks makes Model construction fail on duplicate layer names.
    resB = Dropout(dropout_rate, name=layer_name + '_dropout')(resA)
    if not pooling:
        return resB

    resB = MaxPooling2D((2, 2), padding='same')(resB)
    return resB, resA

def upBlock(input_layer, skip_layer, filter_nbr, dropout_rate, kernel_size=(3, 3), layer_name="dec", training=True):
    """Build one decoder block: upsample, merge the skip tensor, then refine.

    ``input_layer`` is upsampled by a stride-2 transposed convolution, added
    element-wise to ``skip_layer``, and refined by three convolutions, each
    followed by batch normalization. Dropout is applied after the upsampling,
    after the addition and after the last convolution.

    Every explicitly named layer is prefixed with ``layer_name`` so the block
    can be instantiated several times in one model; Keras requires layer
    names to be unique when a ``Model`` is built.

    Args:
        input_layer: Symbolic Keras tensor coming from the coarser stage.
        skip_layer: Tensor added to the upsampled input; it must have the
            same shape as the transposed-convolution output.
        filter_nbr: Number of filters used by every convolution in the block.
        dropout_rate: Rate given to each ``Dropout`` layer.
        kernel_size: Kernel size of the transposed and regular convolutions.
        layer_name: Prefix for the names of the layers created here; must
            differ between calls that end up in the same model.
        training: Unused; kept for interface compatibility.

    Returns:
        The output tensor of the final dropout layer.
    """
    upA = Conv2DTranspose(filter_nbr, kernel_size, strides=2, padding='same', activation=tf.nn.leaky_relu,
                          name=layer_name + "_up_tconv")(input_layer)
    upA = Dropout(dropout_rate, name=layer_name + "_dropout")(upA)

    # Block-specific names throughout: fixed names such as "dropout" or "add"
    # would collide as soon as a second block is added to the same model.
    upB = tf.keras.layers.Add(name=layer_name + "_add")([upA, skip_layer])
    upB = Dropout(dropout_rate, name=layer_name + "_dropout_add")(upB)

    upE = upB
    for conv_idx in (1, 2, 3):
        upE = Conv2D(filter_nbr, kernel_size, padding='same', activation=tf.nn.leaky_relu,
                     name=layer_name + "_conv" + str(conv_idx))(upE)
        upE = BatchNormalization()(upE)
    upE = Dropout(dropout_rate, name=layer_name + "_dropout_conv")(upE)

    return upE

def create_SalsaNet(input_img, num_classes=3, dropout_rate=0.5, is_training=False, kernel_number=32):
    """Assemble the SalsaNet encoder-decoder graph and return its logits.

    The encoder stacks five residual blocks (``res0`` .. ``res4``) with
    1x, 2x, 4x, 8x and 8x ``kernel_number`` filters; the first four are
    followed by 2x2 max pooling and each provides a skip tensor. The decoder
    stacks four up-sampling blocks (``up3`` .. ``up0``) that consume those
    skip tensors in reverse order. A final 1x1 convolution maps to
    ``num_classes`` channels with no activation.

    Args:
        input_img: Symbolic Keras tensor exposing ``.shape``.
            NOTE(review): because of the four pool/upsample stages, height
            and width presumably need to be multiples of 16 for the skip
            additions to line up - confirm against the data pipeline.
        num_classes: Number of output channels of the logits layer.
        dropout_rate: Dropout rate forwarded to every block.
        is_training: Forwarded to the blocks as ``training``.
        kernel_number: Base number of convolution filters.

    Returns:
        The logits tensor produced by the layer named ``'logits'``.
    """
    print("--------------- SalsaNet model --------------------")
    print("input", input_img.shape)

    filter_multipliers = (1, 2, 4, 8)

    # Encoder: pooled residual stages, remembering each pre-pooling tensor.
    features = input_img
    skips = []
    for stage, multiplier in enumerate(filter_multipliers):
        features, skip = resBlock(features, filter_nbr=multiplier * kernel_number, dropout_rate=dropout_rate,
                                  kernel_size=3, stride=1, layer_name="res" + str(stage),
                                  training=is_training, repetition=1)
        skips.append(skip)

    # Bottleneck stage: same width as the last encoder stage, no pooling.
    features = resBlock(features, filter_nbr=8 * kernel_number, dropout_rate=dropout_rate, kernel_size=3,
                        stride=1, layer_name="res4", training=is_training, pooling=False, repetition=1)

    # Decoder: walk the stages back from the coarsest to the finest.
    for stage in reversed(range(len(filter_multipliers))):
        features = upBlock(features, skips[stage], filter_nbr=filter_multipliers[stage] * kernel_number,
                           dropout_rate=dropout_rate, kernel_size=(3, 3), layer_name="up" + str(stage),
                           training=is_training)

    # Keep the model output in float32 even under the global mixed_float16
    # policy, so a loss computed on the logits is not evaluated in half
    # precision.
    logits = Conv2D(num_classes, (1, 1), activation=None, dtype='float32', name='logits')(features)
    print("logits", logits.shape)

    # This closing banner used to sit after the return and never ran.
    print("--------------------------------------------------")
    return logits
# Smoke marker for the cell: printed once the statements above have executed
# without raising. No model is built at this point.
print('no error')

INFO:tensorflow:Mixed precision compatibility check (mixed_float16): OK
Your GPU will likely run quickly with dtype policy mixed_float16 as it has compute capability of at least 7.0. Your GPU: NVIDIA GeForce RTX 3050 Laptop GPU, compute capability 8.6
no error
