# In [None]:
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.optimizers import Adam
import albumentations as A
import numpy as np

# Albumentations Zoom Blur augmentation function
# The pipeline is built once and reused for every image instead of being
# re-created on each call. ZoomBlur is configured through `max_factor` and
# `step_factor`; it has no `blur_limit` argument, so the library defaults are
# used here.
_ZOOM_BLUR_PIPELINE = A.OneOf([
    A.ZoomBlur(p=0.5),
    A.NoOp()
], p=1)


def zoom_blur_aug(image):
    """Randomly apply a zoom blur to a single image.

    Intended as the `preprocessing_function` of an `ImageDataGenerator`.
    NOTE(review): Keras is documented to call this hook before `rescale`,
    so pixel values are expected to still be in the 0-255 range - confirm
    against the installed Keras version.

    Args:
        image: NumPy array holding one image.

    Returns:
        A float32 NumPy array with the same shape as `image`, either
        zoom-blurred or unchanged (OneOf chooses between ZoomBlur and NoOp).
    """
    # Clip before the cast: out-of-range floats would otherwise wrap around
    # when converted to uint8.
    pixels = np.clip(image, 0, 255).astype(np.uint8)
    blurred = _ZOOM_BLUR_PIPELINE(image=pixels)['image']

    # Keras requires the hook to return the input shape. Restore it in case
    # the transform dropped a singleton channel axis (e.g. (H, W, 1) -> (H, W)).
    if blurred.shape != pixels.shape:
        blurred = blurred.reshape(pixels.shape)

    return blurred.astype(np.float32)

# Data augmentation setup
# Training images get random geometric jitter, random mirroring and the
# zoom-blur hook.
train_datagen = ImageDataGenerator(
    # Custom per-image hook (random zoom blur).
    preprocessing_function=zoom_blur_aug,
    # Geometric jitter.
    rotation_range=30,
    shear_range=0.15,
    zoom_range=0.15,
    width_shift_range=0.1,
    height_shift_range=0.1,
    # Random mirroring along both axes.
    horizontal_flip=True,
    vertical_flip=True,
    # How pixels exposed by a transform are filled in.
    fill_mode='nearest',
    # Intensity scaling factor.
    rescale=1.0 / 255
)

# Validation images are only rescaled, never augmented.
val_datagen = ImageDataGenerator(rescale=1.0 / 255)

# Residual Block
def residual_block(x, filters, downsample=False):
    """Two 3x3 conv + batch-norm layers with an additive shortcut.

    Args:
        x: Input feature map (assumes channels-last layout, matching the
            default axis used by the BatchNormalization layers here).
        filters: Number of output channels of both convolutions.
        downsample: When True the first convolution and the shortcut use
            stride 2, halving the spatial resolution.

    Returns:
        The ReLU-activated sum of the shortcut and the main path.
    """
    strides = (2, 2) if downsample else (1, 1)
    shortcut = x

    # Main path: conv -> BN -> ReLU -> conv -> BN.
    x = Conv2D(filters, (3, 3), strides=strides, padding="same")(x)
    x = BatchNormalization()(x)
    x = Activation("relu")(x)
    x = Conv2D(filters, (3, 3), padding="same")(x)
    x = BatchNormalization()(x)

    # Add() needs identical shapes. Project the shortcut with a 1x1 conv
    # whenever the main path changed the resolution (downsample) or the
    # channel count; otherwise the identity shortcut is used as before.
    if downsample or shortcut.shape[-1] != filters:
        shortcut = Conv2D(filters, (1, 1), strides=strides, padding="same")(shortcut)
        shortcut = BatchNormalization()(shortcut)

    x = Add()([shortcut, x])
    x = Activation("relu")(x)
    return x

# Atrous Convolution Block
def atrous_block(x, filters):
    """Run parallel dilated 3x3 convolutions and fuse them with a 1x1 conv.

    Args:
        x: Input feature map.
        filters: Number of channels of each branch and of the fused output.

    Returns:
        Feature map produced by a ReLU-activated 1x1 convolution over the
        concatenated branches (dilation rates 1, 3 and 5).
    """
    # One branch per dilation rate, all reading the same input.
    branches = [
        Conv2D(filters, (3, 3), dilation_rate=rate, padding="same", activation='relu')(x)
        for rate in (1, 3, 5)
    ]
    merged = Concatenate()(branches)
    # The 1x1 convolution brings the channel count back down to `filters`.
    return Conv2D(filters, (1, 1), activation='relu')(merged)

# Multi-Kernel Pooling Block
def mkp_block(x, filters=256, pool_sizes=(2, 3, 5, 6)):
    """Max-pool the input with several window sizes and fuse the results.

    Every pooling branch uses stride 1 and "same" padding, so the spatial
    resolution of the input is preserved.

    Args:
        x: Input feature map.
        filters: Number of channels of the fused output. Defaults to 256,
            the value that was previously hard-coded.
        pool_sizes: Side lengths of the square pooling windows, one branch
            per entry. Defaults to the previously hard-coded (2, 3, 5, 6).

    Returns:
        Feature map produced by a ReLU-activated 1x1 convolution over the
        concatenated pooling branches.
    """
    pooled = [
        MaxPooling2D(pool_size=(size, size), strides=1, padding='same')(x)
        for size in pool_sizes
    ]
    # Concatenate() needs at least two inputs; a single branch is used as is.
    concat = Concatenate()(pooled) if len(pooled) > 1 else pooled[0]
    output = Conv2D(filters, (1, 1), activation='relu')(concat)
    return output

# Decoder Block
def decoder_block(x, skip, filters):
    """Upsample by 2 and merge with a skip connection by addition.

    Args:
        x: Feature map to upsample.
        skip: Encoder feature map added to the upsampled tensor; it has to
            have the same shape as the output of the transposed convolution.
        filters: Number of channels produced by the transposed convolution.

    Returns:
        The batch-normalized, ReLU-activated sum.
    """
    # Stride-2 transposed convolution doubles height and width.
    upsampled = Conv2DTranspose(filters, (3, 3), strides=(2, 2), padding="same")(x)
    fused = Add()([upsampled, skip])
    normalized = BatchNormalization()(fused)
    return Activation("relu")(normalized)

# Complete Model
def build_proposed_model(input_shape=(448, 448, 1)):
    """Build the encoder / atrous / multi-kernel-pooling / decoder network.

    The network downsamples four times by a factor of 2 (stem plus three
    encoder stages) and upsamples four times, so the output has the same
    height and width as the input. For the default 448x448 input the
    resolutions are: stem 224, e1 112, e2 56, e3 28.

    Each decoder stage doubles the resolution before adding its skip tensor,
    so the skip must come from the encoder level one step shallower than the
    decoder input: e2 for the first stage, then e1, then the stem output.
    (Adding e3 to the upsampled bottleneck cannot work: 56x56 vs 28x28.)

    Args:
        input_shape: (height, width, channels) of the input images. Height
            and width must be divisible by 16 (or None) so that every
            upsampled tensor matches its skip connection.

    Returns:
        A Keras `Model` mapping the input image to a single-channel sigmoid
        map of the same height and width.

    Raises:
        ValueError: If a known spatial dimension is not divisible by 16.
    """
    for size in input_shape[:2]:
        if size is not None and size % 16 != 0:
            raise ValueError(
                "input height and width must be divisible by 16, got "
                f"{tuple(input_shape)}"
            )

    inputs = Input(input_shape)

    # Stem: 7x7 stride-2 convolution, kept as the shallowest skip source.
    stem = Conv2D(64, (7, 7), strides=(2, 2), padding="same")(inputs)
    stem = BatchNormalization()(stem)
    stem = Activation("relu")(stem)

    # Encoder: each stage halves the resolution, then refines it.
    e1 = residual_block(stem, 64, downsample=True)
    e1 = residual_block(e1, 64)
    e2 = residual_block(e1, 128, downsample=True)
    e2 = residual_block(e2, 128)
    e3 = residual_block(e2, 256, downsample=True)
    e3 = residual_block(e3, 256)

    # Bottleneck at the resolution of e3.
    ac = atrous_block(e3, 256)
    mkp = mkp_block(ac)

    # Decoder: filters match the channel count of the skip being added.
    d1 = decoder_block(mkp, e2, 128)
    d2 = decoder_block(d1, e1, 64)
    d3 = decoder_block(d2, stem, 64)
    # Final upsampling back to the input resolution (no skip at this level).
    d4 = Conv2DTranspose(64, (3, 3), strides=(2, 2), padding="same")(d3)

    output = Conv2D(1, (1, 1), activation='sigmoid')(d4)

    model = Model(inputs, output)
    return model

def dice_coef(y_true, y_pred):
    """Soft Dice coefficient computed over all elements of the batch.

    No thresholding is applied: predictions are used as given.

    Args:
        y_true: Ground-truth mask; cast to the dtype of `y_pred`.
        y_pred: Predicted values (expected to be floating point).

    Returns:
        Scalar tensor 2 * sum(y_true * y_pred) / (sum(y_true) + sum(y_pred)),
        or 0 when the denominator is zero (instead of NaN).
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so integer masks can be multiplied with float predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    denominator = tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f)
    # divide_no_nan avoids a NaN that would poison the running epoch average.
    return tf.math.divide_no_nan(2. * intersection, denominator)

def jaccard_index(y_true, y_pred):
    """Soft Jaccard index (intersection over union) over the whole batch.

    No thresholding is applied: predictions are used as given.

    Args:
        y_true: Ground-truth mask; cast to the dtype of `y_pred`.
        y_pred: Predicted values (expected to be floating point).

    Returns:
        Scalar tensor intersection / union, where intersection is
        sum(y_true * y_pred) and union is sum(y_true) + sum(y_pred) minus
        the intersection; 0 when the union is zero (instead of NaN).
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so integer masks can be multiplied with float predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    intersection = tf.keras.backend.sum(y_true_f * y_pred_f)
    union = tf.keras.backend.sum(y_true_f) + tf.keras.backend.sum(y_pred_f) - intersection
    # divide_no_nan avoids a NaN that would poison the running epoch average.
    return tf.math.divide_no_nan(intersection, union)

def sensitivity(y_true, y_pred):
    """Soft sensitivity (recall) computed over all elements of the batch.

    No thresholding is applied: predictions are used as given.

    Args:
        y_true: Ground-truth mask; cast to the dtype of `y_pred`.
        y_pred: Predicted values (expected to be floating point).

    Returns:
        Scalar tensor tp / (tp + fn) with tp = sum(y_true * y_pred) and
        fn = sum(y_true * (1 - y_pred)); 0 when tp + fn is zero, which
        happens for a batch without positive ground-truth elements.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so integer masks can be multiplied with float predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    tp = tf.keras.backend.sum(y_true_f * y_pred_f)
    fn = tf.keras.backend.sum(y_true_f * (1 - y_pred_f))
    # divide_no_nan avoids a NaN that would poison the running epoch average.
    return tf.math.divide_no_nan(tp, tp + fn)

def specificity(y_true, y_pred):
    """Soft specificity computed over all elements of the batch.

    No thresholding is applied: predictions are used as given.

    Args:
        y_true: Ground-truth mask; cast to the dtype of `y_pred`.
        y_pred: Predicted values (expected to be floating point).

    Returns:
        Scalar tensor tn / (tn + fp) with tn = sum((1 - y_true) * (1 - y_pred))
        and fp = sum((1 - y_true) * y_pred); 0 when tn + fp is zero, which
        happens for a batch without negative ground-truth elements.
    """
    y_pred = tf.convert_to_tensor(y_pred)
    # Align dtypes so integer masks can be combined with float predictions.
    y_true = tf.cast(y_true, y_pred.dtype)
    y_true_f = tf.keras.backend.flatten(y_true)
    y_pred_f = tf.keras.backend.flatten(y_pred)
    tn = tf.keras.backend.sum((1 - y_true_f) * (1 - y_pred_f))
    fp = tf.keras.backend.sum((1 - y_true_f) * y_pred_f)
    # divide_no_nan avoids a NaN that would poison the running epoch average.
    return tf.math.divide_no_nan(tn, tn + fp)

# Instantiate the network with its default input shape.
model = build_proposed_model()

# Binary cross-entropy loss with Adam; accuracy plus the overlap-based
# metrics defined in this file are tracked.
model.compile(
    loss='binary_crossentropy',
    optimizer=Adam(learning_rate=1e-4),
    metrics=['accuracy', dice_coef, jaccard_index, sensitivity, specificity]
)

# Print the layer-by-layer overview.
model.summary()
