In [1]:
import tensorflow as tf
from tensorflow.keras import layers


class DCRBlock(layers.Layer):
    """
    Densely Connected Residual (DCR) Block
    Paper-faithful implementation (Fig. 3)

    Three ReLU convolutions with dense (concatenative) connections, followed
    by a local residual addition:

        x1  = conv1(x0)
        x2  = conv2([x0, x1])
        x3  = conv3([x0, x1, x2])
        out = x0 + x3

    Input  : (B, H, W, C)   with C == in_channels
    Output : (B, H, W, C)

    Args:
        in_channels: channel count C of the input. conv3 projects back to
            this many channels so the residual addition is shape-compatible.
        growth_rate: number of feature maps produced by conv1 and conv2.
        kernel_size: kernel size shared by all three convolutions.
        name       : optional layer name.
        **kwargs   : forwarded to ``layers.Layer`` (e.g. ``trainable``,
            ``dtype``); required so ``from_config`` can rebuild the layer.

    Raises:
        ValueError: from ``call`` when the statically known channel count of
            the input differs from ``in_channels``.
    """

    def __init__(self, in_channels, growth_rate, kernel_size=3, name=None, **kwargs):
        super().__init__(name=name, **kwargs)

        # Kept on the instance so get_config() can serialize them.
        self.in_channels = in_channels
        self.growth_rate = growth_rate
        self.kernel_size = kernel_size

        # Conv 1: x0 -> x1
        self.conv1 = layers.Conv2D(
            filters=growth_rate,
            kernel_size=kernel_size,
            padding="same",
            activation="relu"
        )

        # Conv 2: [x0, x1] -> x2
        self.conv2 = layers.Conv2D(
            filters=growth_rate,
            kernel_size=kernel_size,
            padding="same",
            activation="relu"
        )

        # Conv 3: [x0, x1, x2] -> x3 (project back to C)
        self.conv3 = layers.Conv2D(
            filters=in_channels,
            kernel_size=kernel_size,
            padding="same",
            activation="relu"
        )

    def call(self, x):
        """
        Apply the block.

        x: (B, H, W, C) with C == in_channels
        Returns a tensor with the same shape as x.
        """
        # Without this guard, an input with a broadcast-compatible channel
        # count (e.g. C == 1) would be silently broadcast in the residual add
        # below instead of failing.
        channels = x.shape[-1]
        if channels is not None and channels != self.in_channels:
            raise ValueError(
                f"DCRBlock expected {self.in_channels} input channels, "
                f"got {channels}."
            )

        x0 = x

        # First conv
        x1 = self.conv1(x0)

        # Dense concat: [x0, x1]
        x2_input = tf.concat([x0, x1], axis=-1)
        x2 = self.conv2(x2_input)

        # Dense concat: [x0, x1, x2]
        x3_input = tf.concat([x0, x1, x2], axis=-1)
        x3 = self.conv3(x3_input)

        # Local residual connection
        out = x0 + x3

        return out

    def get_config(self):
        """Return the constructor arguments so the layer can be re-created."""
        config = super().get_config()
        config.update(
            {
                "in_channels": self.in_channels,
                "growth_rate": self.growth_rate,
                "kernel_size": self.kernel_size,
            }
        )
        return config


In [2]:
import tensorflow as tf
from tensorflow.keras import layers, models


def build_dcr_cnn(
    input_shape=(320, 320, 1),
    num_dcr_blocks=10,
    num_features=64,
    growth_rate=32
):
    """
    Densely Connected Residual CNN (DCR-CNN)
    Paper-faithful implementation (Fig. 1, Sec. 2.2.1)

    Pipeline: initial 3x3 conv -> ``num_dcr_blocks`` DCR blocks -> final 3x3
    conv (linear) -> element-wise add with the network input.

    Args:
        input_shape   : (H, W, C) shape of one input image; the default is a
                        single-channel (H, W, 1) image
        num_dcr_blocks: number of DCR blocks (3, 8, or 10)
        num_features  : number of feature maps (64)
        growth_rate   : DCR growth rate (k)

    Returns:
        tf.keras.Model named ``DCR_CNN_<num_dcr_blocks>`` whose output has the
        same shape as its input.
    """

    # The global residual adds the prediction to the input, so the final conv
    # must emit exactly as many channels as the input carries. A hard-coded 1
    # would be silently broadcast by the Add layer for multi-channel inputs.
    out_channels = input_shape[-1]

    # --------------------------------------------------
    # Input (ZF magnitude image)
    # --------------------------------------------------
    inp = layers.Input(shape=input_shape, name="zf_input")

    # --------------------------------------------------
    # Initial convolution (C → num_features)
    # --------------------------------------------------
    x = layers.Conv2D(
        filters=num_features,
        kernel_size=3,
        padding="same",
        activation="relu",
        name="conv_initial"
    )(inp)

    # --------------------------------------------------
    # DCR blocks (num_features → num_features)
    # --------------------------------------------------
    for i in range(num_dcr_blocks):
        x = DCRBlock(
            in_channels=num_features,
            growth_rate=growth_rate,
            name=f"dcr_block_{i+1}"
        )(x)

    # --------------------------------------------------
    # Final convolution (num_features → C), no activation
    # --------------------------------------------------
    x = layers.Conv2D(
        filters=out_channels,
        kernel_size=3,
        padding="same",
        activation=None,
        name="conv_final"
    )(x)

    # --------------------------------------------------
    # Global residual learning: output = input + conv_final output
    # --------------------------------------------------
    out = layers.Add(name="global_residual")([inp, x])

    # --------------------------------------------------
    # Model
    # --------------------------------------------------
    model = models.Model(inputs=inp, outputs=out, name=f"DCR_CNN_{num_dcr_blocks}")

    return model


In [3]:
# from tensorflow.keras.optimizers.experimental import AdamW
# def nmse_metric(y_true, y_pred):
#     return tf.reduce_sum(tf.square(y_true - y_pred)) / tf.reduce_sum(tf.square(y_true))
# def psnr_metric(y_true, y_pred):
#     return tf.image.psnr(y_true, y_pred, max_val=1.0)

# optimizer_cnn = AdamW(
#     learning_rate=1e-4,
#     weight_decay=1e-7,
#     beta_1=0.9,
#     beta_2=0.999
# )
# model = build_dcr_cnn(
#     input_shape=(320, 320, 1),
#     num_dcr_blocks=10,   # or 3 / 8
#     num_features=64,
#     growth_rate=32
# )
# model.compile(
#     optimizer=optimizer_cnn,
#     loss="mse",
#     metrics=[psnr_metric, nmse_metric]
# )
# model.summary()

Model: "DCR_CNN_10"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 zf_input (InputLayer)          [(None, 320, 320, 1  0           []                               
                                )]                                                                
                                                                                                  
 conv_initial (Conv2D)          (None, 320, 320, 64  640         ['zf_input[0][0]']               
                                )                                                                 
                                                                                                  
 dcr_block_1 (DCRBlock)         (None, 320, 320, 64  119936      ['conv_initial[0][0]']           
                                )                                                        

In [4]:
from tensorflow.keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, EarlyStopping

# Training callbacks; all three watch the validation loss.
callbacks = [
    # Write weights to disk only when val_loss improves.
    ModelCheckpoint(
        filepath="dcr_cnn_best.h5",
        monitor="val_loss",
        verbose=1,
        save_weights_only=True,
        save_best_only=True,
    ),
    # Halve the learning rate after 5 epochs without improvement,
    # never going below 1e-6.
    ReduceLROnPlateau(
        monitor="val_loss",
        verbose=1,
        patience=5,
        factor=0.5,
        min_lr=1e-6,
    ),
    # Stop after 10 epochs without improvement and roll back to the
    # best weights seen during training.
    EarlyStopping(
        monitor="val_loss",
        verbose=1,
        patience=10,
        restore_best_weights=True,
    ),
]
