In [1]:
import tensorflow as tf
from tensorflow.keras import layers,models
class DCRBlockUNet(layers.Layer):
    """Dense-connected residual (DCR) block used at every U-Net stage.

    Three 3x3 conv + PReLU stages. Each stage receives the concatenation of
    the block input and all earlier stage outputs (dense connectivity). The
    last stage's output goes through a 1x1 conv and is added back to the
    block input (local residual), so the output shape equals the input shape.

    Args:
        in_channels: Channel count of the block input and of its output. The
            input's last dimension must equal this value for the residual
            addition to be valid (checked in ``build``).
        growth_rate: Number of filters produced by each of the first two
            dense conv stages.
        name: Optional layer name.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``dtype``,
            ``trainable``).
    """

    def __init__(self, in_channels, growth_rate, name=None, **kwargs):
        super().__init__(name=name, **kwargs)
        # Stored so get_config() can serialize the constructor arguments.
        self.in_channels = in_channels
        self.growth_rate = growth_rate

        # Dense conv layers; PReLU slopes are shared over axes 1 and 2.
        self.conv1 = layers.Conv2D(growth_rate, 3, padding="same")
        self.act1 = layers.PReLU(shared_axes=[1, 2])

        self.conv2 = layers.Conv2D(growth_rate, 3, padding="same")
        self.act2 = layers.PReLU(shared_axes=[1, 2])

        self.conv3 = layers.Conv2D(in_channels, 3, padding="same")
        self.act3 = layers.PReLU(shared_axes=[1, 2])

        # conv3 already emits `in_channels` maps, so this 1x1 conv is an extra
        # learned channel mix before the residual add, not a shape fix. It is
        # kept so the architecture (and its weight layout) is unchanged.
        self.proj = layers.Conv2D(in_channels, 1, padding="same")

    def build(self, input_shape):
        # Fail early with a clear message instead of an opaque broadcasting
        # error at `x0 + x3` when the block is wired to the wrong input.
        channels = tf.TensorShape(input_shape)[-1]
        if channels is not None and channels != self.in_channels:
            raise ValueError(
                f"{type(self).__name__} expects {self.in_channels} input "
                f"channels, got {channels}."
            )
        super().build(input_shape)

    def call(self, x):
        x0 = x

        x1 = self.act1(self.conv1(x0))
        x2 = self.act2(self.conv2(tf.concat([x0, x1], axis=-1)))
        x3 = self.act3(self.conv3(tf.concat([x0, x1, x2], axis=-1)))

        x3 = self.proj(x3)
        # Local residual: requires x0 to have `in_channels` channels.
        return x0 + x3

    def get_config(self):
        # Required for model serialization: tf.keras cannot infer the custom
        # constructor arguments on its own.
        config = super().get_config()
        config.update(
            {"in_channels": self.in_channels, "growth_rate": self.growth_rate}
        )
        return config



In [2]:
def build_dcr_unet(input_shape=(320, 320, 2), base_filters=32, depth=3):
    """Build the DCR U-Net with a global residual connection.

    Encoder: `depth` levels. Each level has two DCR blocks, then 2x2
    max-pooling and a 3x3 conv that doubles the channel count.
    Bottleneck: two DCR blocks.
    Decoder: mirrors the encoder. Each level does 2x upsampling, a 3x3 conv
    back to that level's channel count, concatenation with the matching
    encoder output, and two DCR blocks.
    A final 1x1 conv maps back to the input channel count, and the network
    input is added to it, so the model learns a residual correction.

    Args:
        input_shape: (H, W, C) of one sample. Known H and W must be divisible
            by 2**depth. The model output has the same shape.
        base_filters: Channels at the first (full-resolution) level; doubled
            at each deeper level. The default 32 reproduces the original net.
        depth: Number of pooling levels. The default 3 reproduces the
            original net.

    Returns:
        A `tf.keras.Model` named "DCR_UNet".

    Raises:
        ValueError: If the channel dimension is unknown, or a known spatial
            dimension is not divisible by 2**depth.
    """
    *spatial_dims, out_channels = input_shape
    if out_channels is None:
        raise ValueError("input_shape must specify the channel dimension.")
    factor = 2 ** depth
    for dim in spatial_dims:
        if dim is not None and dim % factor != 0:
            raise ValueError(
                f"Spatial dims {tuple(spatial_dims)} must be divisible by "
                f"2**depth = {factor} so decoder and skip shapes match."
            )

    inputs = layers.Input(shape=input_shape)

    # Lift the input to `base_filters` channels with a 1x1 conv.
    x = layers.Conv2D(base_filters, 1, padding="same")(inputs)
    x = layers.PReLU(shared_axes=[1, 2])(x)

    # ================= ENCODER ======================
    skips = []
    for level in range(depth):
        filters = base_filters * 2 ** level
        x = DCRBlockUNet(filters, filters // 2)(x)
        x = DCRBlockUNet(filters, filters // 2)(x)
        skips.append(x)
        x = layers.MaxPooling2D(2)(x)
        # Double channels for the next, coarser level.
        x = layers.Conv2D(filters * 2, 3, padding="same")(x)

    # ================= BOTTLENECK ===================
    filters = base_filters * 2 ** depth
    x = DCRBlockUNet(filters, filters // 2)(x)
    x = DCRBlockUNet(filters, filters // 2)(x)

    # ================= DECODER ======================
    for level in reversed(range(depth)):
        filters = base_filters * 2 ** level
        x = layers.UpSampling2D(2)(x)
        x = layers.Conv2D(filters, 3, padding="same")(x)
        # Concatenating the same-level encoder output gives 2 * filters channels.
        x = layers.Concatenate()([x, skips[level]])
        x = DCRBlockUNet(2 * filters, filters)(x)
        x = DCRBlockUNet(2 * filters, filters)(x)

    # Final reconstruction + global residual. Output channels follow the
    # input so the Add is valid for any C (previously hard-coded to 2).
    out = layers.Conv2D(out_channels, 1, padding="same")(x)
    out = layers.Add()([out, inputs])

    return models.Model(inputs, out, name="DCR_UNet")


In [3]:
def nmse_metric(y_true, y_pred, eps=1e-12):
    """Normalized mean squared error, aggregated over the whole batch.

    NMSE = sum((y_true - y_pred)^2) / sum(y_true^2), where both sums run over
    every element of the batch (not per sample). The result is a scalar.

    Args:
        y_true: Reference tensor.
        y_pred: Prediction with the same shape as `y_true`; cast to its dtype.
        eps: Added to the denominator so an all-zero `y_true` batch gives a
            finite value instead of inf/nan, which would otherwise poison the
            epoch-averaged metric.

    Returns:
        Scalar tensor.
    """
    y_true = tf.convert_to_tensor(y_true)
    y_pred = tf.cast(y_pred, y_true.dtype)
    error_energy = tf.reduce_sum(tf.square(y_true - y_pred))
    signal_energy = tf.reduce_sum(tf.square(y_true))
    return error_energy / (signal_energy + eps)
def psnr_metric(y_true, y_pred, max_val=1.0):
    """Peak signal-to-noise ratio in dB, computed by `tf.image.psnr`.

    Args:
        y_true: Reference images; the last three dims are (H, W, C).
        y_pred: Predicted images with the same shape as `y_true`.
        max_val: Dynamic range of the data. The default 1.0 matches the
            original metric. NOTE(review): this is only correct if the
            targets are normalized to [0, 1] — confirm against the data
            pipeline.

    Returns:
        One PSNR value per image (shape = leading batch dims).
    """
    return tf.image.psnr(y_true, y_pred, max_val=max_val)


In [4]:
# Tunable settings for this training run.
SEED = 42
LEARNING_RATE = 1e-3
INPUT_SHAPE = (320, 320, 2)

# Seed TensorFlow's global RNG so weight initialization is reproducible
# across Restart & Run All.
tf.random.set_seed(SEED)

optimizer = tf.keras.optimizers.RMSprop(learning_rate=LEARNING_RATE)

model = build_dcr_unet(input_shape=INPUT_SHAPE)
model.summary()

model.compile(
    optimizer=optimizer,
    loss="mse",
    metrics=[psnr_metric, nmse_metric],
)


Model: "DCR_UNet"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 320, 320, 2  0           []                               
                                )]                                                                
                                                                                                  
 conv2d (Conv2D)                (None, 320, 320, 32  96          ['input_1[0][0]']                
                                )                                                                 
                                                                                                  
 p_re_lu (PReLU)                (None, 320, 320, 32  32          ['conv2d[0][0]']                 
                                )                                                          