In [1]:
import tensorflow as tf
from tensorflow.keras import layers, Model
import numpy as np

# ============================================================
# 1) Weight Normalized Convolution (WNConv2D)
# ============================================================
class WNConv2D(layers.Layer):
    """2-D convolution whose kernel is re-parameterized by weight normalization.

    The kernel applied in ``call`` is ``w = v * (g / ||v||)``, where the norm
    is taken per output filter over the flattened (kH, kW, in_channels) axes.

    Args:
        filters: Number of output channels.
        kernel_size: Int (square kernel) or a pair ``(kH, kW)``.
        strides: Int (same stride on both axes) or a pair ``(sH, sW)``.
        padding: Padding mode string; it is upper-cased and handed to
            ``tf.nn.conv2d`` (e.g. ``"same"`` / ``"valid"``).
        use_bias: Whether to add a learnable per-filter bias.
        name: Optional layer name.
    """

    def __init__(self, filters, kernel_size, strides=1, padding="same", use_bias=True, name=None):
        super().__init__(name=name)
        self.filters = filters
        self.kernel_size = (
            (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
        )
        # Normalize strides like kernel_size: wrapping an existing pair in
        # another tuple would yield a nested, invalid strides argument in call().
        self.strides = (strides, strides) if isinstance(strides, int) else tuple(strides)
        self.padding = padding.upper()
        self.use_bias = use_bias

    def build(self, input_shape):
        """Create ``v`` (direction), ``g`` (per-filter scale) and the optional bias."""
        in_channels = int(input_shape[-1])
        kH, kW = self.kernel_size

        # v: unnormalized weights
        self.v = self.add_weight(
            name="v",
            shape=(kH, kW, in_channels, self.filters),
            initializer="glorot_uniform",
            trainable=True
        )

        # g: per-filter scale, initialized to ||v|| so the initial effective
        # kernel equals v (up to the epsilon added in call()).
        v_flat = tf.reshape(self.v, [-1, self.filters])
        g_init = tf.norm(v_flat, axis=0)
        self.g = self.add_weight(
            name="g",
            shape=(self.filters,),
            initializer=tf.keras.initializers.Constant(g_init.numpy()),
            trainable=True
        )

        if self.use_bias:
            self.bias = self.add_weight(
                name="bias",
                shape=(self.filters,),
                initializer="zeros",
                trainable=True
            )
        else:
            self.bias = None

    def call(self, x):
        """Apply the weight-normalized convolution (and bias, if enabled) to ``x``."""
        v_flat = tf.reshape(self.v, [-1, self.filters])
        # Small epsilon guards against division by zero for an all-zero filter.
        v_norm = tf.norm(v_flat, axis=0) + 1e-8
        w = self.v * (self.g / v_norm)

        x = tf.nn.conv2d(x, w, strides=(1, *self.strides, 1), padding=self.padding)
        if self.bias is not None:
            x = tf.nn.bias_add(x, self.bias)
        return x


# ============================================================
# 2) Channel Attention (CA)
# ============================================================
class ChannelAttention(layers.Layer):
    """Channel attention: re-weights channels using a gate from pooled features.

    The input is averaged over axes 1 and 2, passed through a two-layer dense
    bottleneck (ReLU then sigmoid), and the resulting per-channel gate is
    multiplied onto the input.

    Args:
        channels: Number of channels of the input (and of the gate).
        reduction: Bottleneck reduction ratio; the hidden width is
            ``channels // reduction``, clamped to at least 1.
        name: Optional layer name.
    """

    def __init__(self, channels, reduction=16, name=None):
        super().__init__(name=name)
        self.channels = channels
        self.reduction = reduction
        # Clamp so that channels < reduction does not request a zero-unit Dense.
        hidden_units = max(1, channels // reduction)
        self.fc1 = layers.Dense(hidden_units, activation='relu')
        self.fc2 = layers.Dense(channels, activation='sigmoid')

    def call(self, x):
        """Return ``x`` scaled channel-wise by the learned gate."""
        s = tf.reduce_mean(x, axis=[1, 2])  # global average pooling
        s = self.fc1(s)
        s = self.fc2(s)
        # Reshape so the gate broadcasts over axes 1 and 2.
        s = tf.reshape(s, [-1, 1, 1, self.channels])
        return x * s


# ============================================================
# 3) Spatial Attention (SA)
# ============================================================
class SpatialAttention(layers.Layer):
    """Spatial attention: gates every position with a single-channel mask.

    The mask comes from a 7x7 sigmoid convolution applied to the channel-wise
    mean and channel-wise max of the input, stacked along the last axis.
    """

    def __init__(self):
        super().__init__()
        self.conv = layers.Conv2D(1, kernel_size=7, padding='same', activation='sigmoid')

    def call(self, x):
        """Return ``x`` multiplied by its spatial attention mask."""
        channel_mean = tf.reduce_mean(x, axis=-1, keepdims=True)
        channel_max = tf.reduce_max(x, axis=-1, keepdims=True)
        descriptor = tf.concat([channel_mean, channel_max], axis=-1)
        mask = self.conv(descriptor)
        return x * mask


# ============================================================
# 4) Wide Activation ConvBlock (WACB)
#     Figure structure: Conv3×3 → WN → ReLU → Conv3×3 → WN → residual
#     NOTE: WACB.call below returns the second conv output only; no residual is added.
# ============================================================
class WACB(layers.Layer):
    """Wide Activation ConvBlock: WNConv3x3 -> ReLU -> WNConv3x3.

    No residual connection is applied here: the block is also used to change
    the channel count, in which case input and output could not be added.

    Args:
        channels: Number of output channels of both convolutions.
        name: Optional layer name.
    """

    def __init__(self, channels, name=None):
        super().__init__(name=name)
        self.channels = channels
        self.conv1 = WNConv2D(channels, 3)
        self.conv2 = WNConv2D(channels, 3)

    def call(self, x):
        """Return the output of the two stacked weight-normalized convolutions."""
        out = self.conv1(x)
        out = tf.nn.relu(out)
        out = self.conv2(out)
        return out


# ============================================================
# 5) Wide Activation TransBlock (WATB)
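#     Figure structure: TransConv2×2 → WN → ReLU → Conv3×3 → WN → residual
#     NOTE: WATB.call below applies TransConv2×2 (stride 2) → WNConv3×3 → ReLU; no residual is added.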
# ============================================================
class WATB(layers.Layer):
    """Wide Activation TransBlock: stride-2 transposed conv -> WNConv3x3 -> ReLU.

    The 2x2 transposed convolution with stride 2 upsamples the spatial axes;
    the weight-normalized 3x3 convolution and ReLU follow. No residual
    connection is applied.

    Args:
        channels: Number of output channels of both convolutions.
        name: Optional layer name.
    """

    def __init__(self, channels, name=None):
        super().__init__(name=name)
        self.channels = channels

        # Transposed convolution performing the upsampling step.
        self.trans = layers.Conv2DTranspose(channels, kernel_size=2, strides=2, padding='same')

        self.conv = WNConv2D(channels, 3)

    def call(self, x):
        """Return the upsampled, convolved and ReLU-activated features."""
        out = self.trans(x)
        out = self.conv(out)
        out = tf.nn.relu(out)
        return out


# ============================================================
# 6) Hybrid Attention Block (HAB)
# ============================================================
class HAB(layers.Layer):
    """Hybrid Attention Block.

    Features from a WACB are sent through channel attention and spatial
    attention in parallel; the two results are concatenated on the channel
    axis, projected back to ``channels`` by a 1x1 weight-normalized
    convolution, and added to the block input.

    Args:
        channels: Channel count of the block. The final addition needs the
            input to have this many channels as well.
        name: Optional layer name.
    """

    def __init__(self, channels, name=None):
        super().__init__(name=name)
        self.channels = channels
        self.wacb = WACB(channels)
        self.ca = ChannelAttention(channels)
        self.sa = SpatialAttention()
        self.conv1x1 = WNConv2D(channels, 1)

    def call(self, x):
        """Return ``x`` plus the fused attention features."""
        features = self.wacb(x)

        channel_branch = self.ca(features)
        spatial_branch = self.sa(features)

        # Stack both attention branches along channels, then project back.
        stacked = tf.concat([channel_branch, spatial_branch], axis=-1)
        projected = self.conv1x1(stacked)

        return x + projected


# ============================================================
# 7) SRUN (main module)
# ============================================================
class SRUN(layers.Layer):
    """U-shaped structure branch: head, 3-level encoder/decoder, tail.

    Channel widths per level are 16 / 32 / 64 / 128. Each encoder level
    average-pools by 2 and applies a WACB; each decoder level upsamples with
    a WATB, concatenates the matching skip tensor and applies a WACB.

    ``call`` returns ``(shallow, out)``: ``shallow`` is the 16-channel output
    of the head WACB and ``out`` is the 2-channel output of the final
    weight-normalized 3x3 convolution.

    NOTE(review): three 2x poolings are mirrored by three stride-2 transposed
    convs, so the skip concatenations presumably need H and W divisible by 8
    -- confirm against callers.
    """

    def __init__(self, name="SRUN"):
        super().__init__(name=name)

        # Channel width of each resolution level.
        self.ch0, self.ch1, self.ch2, self.ch3 = 16, 32, 64, 128

        # Head: shallow features followed by hybrid attention.
        self.wacb_head = WACB(self.ch0)
        self.hab_head = HAB(self.ch0)

        # Encoder: pool -> WACB at each level; the last WACB is the bottleneck.
        self.pool1 = layers.AveragePooling2D(pool_size=2)
        self.enc2 = WACB(self.ch1)
        self.pool2 = layers.AveragePooling2D(pool_size=2)
        self.enc3 = WACB(self.ch2)
        self.pool3 = layers.AveragePooling2D(pool_size=2)
        self.bottleneck = WACB(self.ch3)

        # Decoder: WATB (upsample) -> concat skip -> WACB at each level.
        self.dec3_watb = WATB(self.ch2)
        self.dec3_wacb = WACB(self.ch2)
        self.dec2_watb = WATB(self.ch1)
        self.dec2_wacb = WACB(self.ch1)
        self.dec1_watb = WATB(self.ch0)
        self.dec1_wacb = WACB(self.ch0)

        # Tail: hybrid attention and projection to 2 channels.
        self.hab_tail = HAB(self.ch0)
        self.out_conv = WNConv2D(2, 3)

    @staticmethod
    def _decode_level(upsampler, refiner, features, skip):
        """Upsample ``features``, concatenate ``skip`` on channels, then refine."""
        upsampled = upsampler(features)
        merged = tf.concat([upsampled, skip], axis=-1)
        return refiner(merged)

    def call(self, undersampled_image):
        """Run the network and return ``(shallow, out)``."""
        # Head
        shallow = self.wacb_head(undersampled_image)   # 16 ch
        head = self.hab_head(shallow)                  # 16 ch

        # Encoder (outputs double as skip tensors)
        enc_l2 = self.enc2(self.pool1(head))           # 32 ch
        enc_l3 = self.enc3(self.pool2(enc_l2))         # 64 ch
        bottom = self.bottleneck(self.pool3(enc_l3))   # 128 ch

        # Decoder (mirror of the encoder)
        dec_l3 = self._decode_level(self.dec3_watb, self.dec3_wacb, bottom, enc_l3)   # 64 ch
        dec_l2 = self._decode_level(self.dec2_watb, self.dec2_wacb, dec_l3, enc_l2)   # 32 ch
        dec_l1 = self._decode_level(self.dec1_watb, self.dec1_wacb, dec_l2, head)     # 16 ch

        # Tail and final projection to 2 channels
        tail_features = self.hab_tail(dec_l1)
        out = self.out_conv(tail_features)

        return shallow, out


# ============================================================
# 8) Test SRUN correctness
# # ============================================================
# if __name__ == "__main__":
#     x = np.random.randn(1, 256, 256, 2).astype(np.float32)
#     srun = SRUN()
#     y = srun(x)
#     #print("Output shape:", y.shape)


In [2]:
import tensorflow as tf
from tensorflow.keras import layers
import numpy as np

class DRCM(layers.Layer):
    """
    Detail Representation Construction Module (DRCM).

    Applies fixed (non-trainable) 3x3 Sobel kernels to channel 0 and channel 1
    of the input separately, takes the gradient magnitude of each, stacks the
    two magnitudes on the channel axis and projects them with a trainable
    3x3 Conv2D.

    Args:
        out_channels: Number of filters of the final projection.
        name: Layer name; the projection conv is named ``name + "_conv"``.
    """
    def __init__(self, out_channels=2, name="DRCM"):
        super().__init__(name=name)

        # Trainable projection applied after the fixed edge detectors.
        self.proj = layers.Conv2D(out_channels, kernel_size=3, padding='same', name=name+"_conv")

        # Fixed Sobel kernels, shaped (3, 3, in=1, out=1) for tf.nn.conv2d.
        horizontal = np.array([[ -1., 0., 1.],
                               [ -2., 0., 2.],
                               [ -1., 0., 1.]], dtype=np.float32)
        vertical = np.array([[ -1., -2., -1.],
                             [  0.,  0.,  0.],
                             [  1.,  2.,  1.]], dtype=np.float32)

        self.sobel_x = tf.constant(horizontal[:, :, np.newaxis, np.newaxis], dtype=tf.float32)
        self.sobel_y = tf.constant(vertical[:, :, np.newaxis, np.newaxis], dtype=tf.float32)

    def _gradient_magnitude(self, plane):
        """Return sqrt(Gx^2 + Gy^2 + 1e-12) for a single-channel tensor."""
        grad_x = tf.nn.conv2d(plane, self.sobel_x, strides=[1,1,1,1], padding="SAME")
        grad_y = tf.nn.conv2d(plane, self.sobel_y, strides=[1,1,1,1], padding="SAME")
        # The tiny constant keeps the sqrt gradient finite where both terms are zero.
        return tf.sqrt(grad_x**2 + grad_y**2 + 1e-12)

    def call(self, x):
        """Return the projected edge-magnitude map of ``x``."""
        # Presumably channel 0 is the real part and channel 1 the imaginary
        # part of a complex image -- confirm against the data pipeline.
        real = x[..., 0:1]
        imag = x[..., 1:2]

        magnitudes = tf.concat(
            [self._gradient_magnitude(real), self._gradient_magnitude(imag)],
            axis=-1,
        )

        return self.proj(magnitudes)
# if __name__ == "__main__":
#     # dummy random-normal example (batch 1, 320x320, 2 channels)
#     inp = np.random.randn(1,320,320,2).astype(np.float32)
#     drcm = DRCM(out_channels=2)
#     out = drcm(inp)
#     print("DRCM input shape:", inp.shape)
#     print("DRCM output shape:", out.shape)

In [3]:
import tensorflow as tf
from tensorflow.keras import layers

class LRDB(layers.Layer):
    """
    Lite Residual Dense Block.

    What ``call`` computes:
      - four 3x3 ReLU convs with filter pattern 32 -> 16 -> 32 -> 16
      - conv2 sees concat(x, y1); conv4 sees concat(x + y2, y3)
      - the 16-channel outputs are merged by addition, not concatenation:
        result = x + (x + y2) + y4

    Because of the additions, the input must have 16 channels.

    NOTE(review): ``in_channels`` is accepted but unused, and the 1x1
    ``bottleneck`` conv is created but never called. The original description
    mentioned a concat + bottleneck projection -- confirm which is intended.
    """
    def __init__(self, in_channels=16, out_channels=16, name="LRDB"):
        super().__init__(name=name)
        # (filters, name suffix) for the four 3x3 convolutions, in order.
        conv_specs = ((32, "_c1"), (16, "_c2"), (32, "_c3"), (16, "_c4"))
        self.conv1, self.conv2, self.conv3, self.conv4 = (
            layers.Conv2D(width, 3, padding='same', activation='relu', name=name+suffix)
            for width, suffix in conv_specs
        )

        # Kept for interface parity; not used by call().
        self.bottleneck = layers.Conv2D(out_channels, 1, padding='same', name=name+"_bottleneck")

    def call(self, x):
        """Return ``x + (x + y2) + y4`` as described in the class docstring."""
        feat_1 = self.conv1(x)                                         # 32 ch
        feat_2 = self.conv2(tf.concat([x, feat_1], axis=-1))           # 16 ch

        skip_2 = x + feat_2                                            # 16 ch
        feat_3 = self.conv3(skip_2)                                    # 32 ch
        feat_4 = self.conv4(tf.concat([skip_2, feat_3], axis=-1))      # 16 ch

        # Additive fusion of the 16-channel tensors (output stays 16 ch).
        return x + skip_2 + feat_4


In [4]:
# if __name__ == "__main__":
#     import numpy as np
#     lrdb = LRDB(in_channels=16, out_channels=16)
#     x = np.random.randn(1, 128, 128, 16).astype('float32')
#     y = lrdb(x)
#     print("LRDB out shape:", y.shape)  # expect (1,128,128,16)


In [26]:
# ============================================================
# DFRM (Detail Feature Refinement Module)
# ============================================================
class DFRM(layers.Layer):
    """
    Detail Feature Refinement Module.

    Inputs:
        D_in      : detail input (B,H,W,C_in); projected to ``channels`` by a 3x3 conv
        S_shallow : shallow structure features, concatenated with the projected
                    detail features on the channel axis

    Output:
        D_out : refined detail output with 2 channels (final 1x1 conv)

    NOTE(review): both LRDB sub-layers are created with the default name
    "LRDB"; confirm this does not cause name clashes when saving weights.
    """
    def __init__(self, channels=16, name="DFRM"):
        super().__init__(name=name)
        self.channels = channels

        # Shallow detail feature extraction.
        self.conv_shallow = layers.Conv2D(filters=channels, kernel_size=3, padding='same')

        # Head WACB applied to the fused detail/structure features.
        self.wacb_head = WACB(channels)

        # Refinement body: two (LRDB -> concat attention -> 1x1 conv) stages
        # sharing one channel-attention output.
        self.lrd1 = LRDB(channels)
        self.ca = ChannelAttention(channels)
        self.conv1 = layers.Conv2D(filters=channels, kernel_size=1, padding='same')
        self.lrd2 = LRDB(channels)
        self.conv2 = layers.Conv2D(filters=channels, kernel_size=1, padding='same')

        # Tail WACB.
        self.wacb_tail = WACB(channels)

        # Output convs: 3x3 at ``channels`` width, then 1x1 down to 2 channels.
        self.conv_out = layers.Conv2D(filters=channels, kernel_size=3, padding='same')
        self.dfrm_final = layers.Conv2D(2, 1, padding='same', name=name+"_DFRM_1x1")

    def call(self, D_in, S_shallow):
        """Refine ``D_in`` using ``S_shallow`` and return the 2-channel result."""
        # (1) Shallow detail features.
        detail_shallow = self.conv_shallow(D_in)

        # (2) Fuse with the structure features and run the head WACB.
        head = self.wacb_head(tf.concat([detail_shallow, S_shallow], axis=-1))

        # (3) Refinement body; the attention output of the head is reused
        #     in both stages.
        dense_1 = self.lrd1(head)
        attended = self.ca(head)
        stage_1 = self.conv1(tf.concat([dense_1, attended], axis=-1))

        dense_2 = self.lrd2(stage_1)
        stage_2 = self.conv2(tf.concat([dense_2, attended], axis=-1))
        tail = self.wacb_tail(stage_2)

        # (4) Long skip from the shallow features, then the output convs.
        summed = detail_shallow + tail
        return self.dfrm_final(self.conv_out(summed))

In [27]:
import tensorflow as tf
from tensorflow.keras import layers

# ------------------------
# Instance Normalization
# ------------------------
class InstanceNorm(layers.Layer):
    """Instance normalization with a learnable per-channel scale and shift.

    Mean and variance are computed over axes 1 and 2 separately for every
    sample and channel.

    Args:
        epsilon: Constant added to the variance before the square root.
        name: Layer name.
    """

    def __init__(self, epsilon=1e-5, name="InstanceNorm"):
        super().__init__(name=name)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create ``gamma`` (ones) and ``beta`` (zeros), one value per channel."""
        param_shape = (int(input_shape[-1]),)
        self.gamma = self.add_weight(
            name="gamma", shape=param_shape, initializer="ones", trainable=True
        )
        self.beta = self.add_weight(
            name="beta", shape=param_shape, initializer="zeros", trainable=True
        )

    def call(self, x):
        """Return ``gamma * (x - mean) / sqrt(var + epsilon) + beta``."""
        mu, sigma_sq = tf.nn.moments(x, axes=[1,2], keepdims=True)
        normalized = (x - mu) / tf.sqrt(sigma_sq + self.epsilon)
        return self.gamma * normalized + self.beta


# ------------------------
# Memory-safe, adaptive DGFM
# ------------------------
class DGFM(layers.Layer):
    """
    Detail Guided Fusion Module (memory-safe and adaptive).

    This implementation:
      - Infers the channel count from the static shape of S_out (raises if unknown).
      - Computes attention on a reduced spatial grid when H*W exceeds
        max_attn_elements, to bound the size of the attention matrix.
      - Uses 1x1 convs for query/key/value:
          * query, key <- D_out
          * value      <- F_input = S_out + D_out
      - Applies InstanceNorm to the attention output, then adds F_input as a
        global skip.

    Args:
      max_attn_elements: maximum number of spatial positions (H*W) attended
                         over (default 4096, i.e. a 64x64 grid). The attention
                         matrix itself has (H*W) x (H*W) entries per sample.
      name: layer name
    """
    def __init__(self, max_attn_elements=4096, name="DGFM"):
        super().__init__(name=name)
        self.max_attn_elements = int(max_attn_elements)
        self.IN = InstanceNorm(name=name + "/IN")

        # The 1x1 convs are created lazily (see _make_conv_layers) because the
        # channel count is only known once the first input is seen.
        self.conv_query = None
        self.conv_key = None
        self.conv_value = None

    def build(self, input_shape):
        # Nothing is created here: the q/k/v convs are instantiated from call()
        # once the static channel count has been read from the input.
        super().build(input_shape)

    def _make_conv_layers(self, channels):
        # Create the 1x1 query/key/value convs on first use; later calls are no-ops.
        if self.conv_query is None:
            self.conv_query = layers.Conv2D(channels, kernel_size=1, padding='same', name=self.name + "/conv_q")
            self.conv_key   = layers.Conv2D(channels, kernel_size=1, padding='same', name=self.name + "/conv_k")
            self.conv_value = layers.Conv2D(channels, kernel_size=1, padding='same', name=self.name + "/conv_v")

    def _compute_reduced_size(self, H, W):
        """
        Compute a reduced size (new_H, new_W) for the attention grid.

        When H*W <= self.max_attn_elements, (H, W) is returned unchanged.
        Otherwise both sides are divided by scale = ceil(sqrt(H*W / max)),
        rounded up, and clamped to at least 1.
        H, W are scalar integer tf.Tensors.
        """
        # int64 product avoids overflow for very large spatial sizes.
        HW = tf.cast(H, tf.int64) * tf.cast(W, tf.int64)
        maxE = tf.cast(self.max_attn_elements, tf.int64)

        # Branch taken when HW <= maxE: keep the original size.
        def no_downsample():
            return H, W

        def do_downsample():
            # scale = ceil(sqrt(HW / maxE)); ratio is clamped to >= 1 defensively
            ratio = tf.cast(HW, tf.float32) / tf.cast(maxE, tf.float32)
            ratio = tf.math.maximum(ratio, 1.0)
            scale = tf.math.ceil(tf.math.sqrt(ratio))
            # new_H = ceil(H / scale), new_W = ceil(W / scale)
            new_H = tf.cast(tf.math.ceil(tf.cast(H, tf.float32) / scale), tf.int32)
            new_W = tf.cast(tf.math.ceil(tf.cast(W, tf.float32) / scale), tf.int32)
            # Ensure minimum 1
            new_H = tf.maximum(new_H, 1)
            new_W = tf.maximum(new_W, 1)
            return new_H, new_W

        new_H, new_W = tf.cond(HW <= maxE, no_downsample, do_downsample)
        return new_H, new_W

    def call(self, S_out, D_out):
        """
        Fuse structure and detail tensors with detail-guided attention.

        Args:
          S_out: structure output tensor, shape (B, H, W, C)
          D_out: detail output tensor,    shape (B, H, W, C)
        Returns:
          F_out: fused tensor, same shape (B, H, W, C)
        Raises:
          ValueError: if the channel dimension of S_out is not statically known.
        """
        # Dynamic shapes, expected layout (B, H, W, C).
        s_shape = tf.shape(S_out)
        d_shape = tf.shape(D_out)
        # Both inputs must agree on the channel count (checked at run time).
        C_s = s_shape[-1]
        C_d = d_shape[-1]
        tf.debugging.assert_equal(C_s, C_d, message="S_out and D_out must have the same channel count")

        # Fused input (Eq 14)
        F_input = S_out + D_out  # (B, H, W, C)

        B = s_shape[0]
        H = s_shape[1]
        W = s_shape[2]
        C = C_s  # number of channels (inferred)

        # Conv2D needs a Python int for its filter count, so the static
        # channel dimension is required to create the q/k/v convs.
        static_C = S_out.shape[-1]
        channels_for_conv = int(static_C) if static_C is not None else None
        if channels_for_conv is not None:
            self._make_conv_layers(channels_for_conv)
        else:
            # No safe fallback exists: the filter count cannot come from a
            # dynamic tensor, so an unknown channel dimension is rejected.
            raise ValueError("DGFM requires known channel dimension at graph build time. Make sure channel dimension is statically defined.")

        # Decide reduced spatial size so that new_H * new_W stays near max_attn_elements
        new_H, new_W = self._compute_reduced_size(H, W)  # tf.int32

        # If new_H == H and new_W == W -> compute attention at full resolution
        full_res = tf.logical_and(tf.equal(new_H, H), tf.equal(new_W, W))

        def _attend_at_full_res():
            # q/k from the detail branch, v from the fused input, no resizing.
            D_q = self.conv_query(D_out)
            D_k = self.conv_key(D_out)
            F_v = self.conv_value(F_input)
            return D_q, D_k, F_v, H, W  # return original sizes

        def _attend_at_reduced_res():
            # Bilinear (antialiased) downsampling of both sources before the
            # 1x1 convs, so attention runs on the smaller grid.
            D_small = tf.image.resize(D_out, size=(new_H, new_W), method="bilinear", antialias=True)
            F_small = tf.image.resize(F_input, size=(new_H, new_W), method="bilinear", antialias=True)
            D_q = self.conv_query(D_small)
            D_k = self.conv_key(D_small)
            F_v = self.conv_value(F_small)
            return D_q, D_k, F_v, new_H, new_W

        D_q, D_k, F_v, Hs, Ws = tf.cond(full_res, _attend_at_full_res, _attend_at_reduced_res)

        # Flatten spatial dims: (B, Hs*Ws, C)
        HW_small = Hs * Ws
        Q = tf.reshape(D_q, [B, HW_small, C])
        K = tf.reshape(D_k, [B, HW_small, C])
        V = tf.reshape(F_v, [B, HW_small, C])

        # Attention logits: (B, HW_small, HW_small); no scaling factor is applied
        att_logits = tf.matmul(Q, K, transpose_b=True)

        # Softmax along last axis (for each query position)
        att_map = tf.nn.softmax(att_logits, axis=-1)

        # Apply attention: (B, HW_small, C)
        F_active_flat = tf.matmul(att_map, V)

        # Reshape back to spatial (B, Hs, Ws, C)
        F_active_small = tf.reshape(F_active_flat, [B, Hs, Ws, C])

        # Resize back to the original (H, W). This runs unconditionally; in the
        # full-resolution case the target size equals the current size.
        F_active = tf.image.resize(F_active_small, size=(H, W), method="bilinear", antialias=True)

        # InstanceNorm then global skip (Eq 18)
        F_active_norm = self.IN(F_active)
        F_out = F_input + F_active_norm

        return F_out


In [30]:
from tensorflow.keras.layers import Input
def cascade_block(F_in, D_in, srun, dfrm, dfgm, dc=None, zf=None, name="Cascade"):
    """Run one SRUN -> DFRM -> DGFM cascade, optionally followed by ``dc``.

    Args:
        F_in: Image tensor fed to ``srun``.
        D_in: Detail tensor fed to ``dfrm``.
        srun: Callable returning ``(shallow, S)`` for ``F_in``.
        dfrm: Callable invoked as ``dfrm(D_in, shallow)``.
        dfgm: Callable invoked as ``dfgm(S, D_out)``.
        dc: Optional callable invoked as ``dc(F_out, zf)`` on the fused output.
        zf: Second argument forwarded to ``dc``; unused when ``dc`` is None.
        name: Name scope wrapping the cascade's ops.

    Returns:
        Tuple ``(F_out, D_out)``: fused output and refined detail output.
    """
    with tf.name_scope(name):
        shallow_features, structure = srun(F_in)
        detail = dfrm(D_in, shallow_features)
        fused = dfgm(structure, detail)

        if dc is None:
            return fused, detail
        return dc(fused, zf), detail

def build_DSMENet_functional(
        N=4, M=1, T=2,
        H=256, W=256, C=2,
        DC_cls=None
):
    """Assemble the cascaded network as a Keras functional ``Model``.

    The graph is: initial DRCM, then ``M`` cascades with their own weights,
    then ``T`` rounds of ``N - M`` cascades that all reuse one shared set of
    SRUN/DFRM/DGFM (and optional DC) layers.

    Args:
        N, M: ``M`` unique cascades are built; each shared round runs
            ``N - M`` cascades.
        T: Number of shared rounds.
        H, W, C: Static input height, width and channel count.
        DC_cls: Optional class; when truthy it is instantiated with a
            ``name`` keyword and the instance is called as ``dc(F, zf)``
            after every cascade. It also adds a second model input
            ``zf_image``.

    Returns:
        A ``Model`` with outputs ``[F_M, F_final]``: the image after the
        unique cascades and the image after all shared rounds.
    """
    inp = Input(shape=(H, W, C), name="input_image")
    zf_in = Input(shape=(H, W, C), name="zf_image") if DC_cls else None

    # Initial detail representation from the raw input.
    initial_drcm = DRCM(out_channels=2, name="DRCM_initial")
    D = initial_drcm(inp)
    F = inp

    # Unique cascades: every stage owns its layers.
    for stage in range(1, M + 1):
        stage_srun = SRUN(name=f"SRUN_{stage}")
        stage_dfrm = DFRM(channels=16, name=f"DFRM_{stage}")
        stage_dfgm = DGFM(name=f"DGFM_{stage}")
        stage_dc = DC_cls(name=f"DC_{stage}") if DC_cls else None

        F, D = cascade_block(
            F, D,
            stage_srun, stage_dfrm, stage_dfgm,
            dc=stage_dc, zf=zf_in,
            name=f"Cascade_{stage}"
        )

    # Image after the M unique cascades; exposed as the first model output.
    F_M = F

    # Shared cascades: one set of layers reused for every step of every round.
    shared_srun = SRUN(name="SRUN_shared")
    shared_dfrm = DFRM(channels=16, name="DFRM_shared")
    shared_dfgm = DGFM(name="DGFM_shared")
    shared_dc = DC_cls(name="DC_shared") if DC_cls else None

    for round_idx in range(1, T + 1):
        for step_idx in range(1, N - M + 1):
            F, D = cascade_block(
                F, D,
                shared_srun, shared_dfrm, shared_dfgm,
                dc=shared_dc, zf=zf_in,
                name=f"SharedCascade_round{round_idx}_{step_idx}"
            )

    F_final = F

    # Both the intermediate and the final image are returned.
    model_inputs = [inp, zf_in] if DC_cls else inp
    return Model(model_inputs, [F_M, F_final], name="DSMENet_Functional")


In [31]:
# Build the functional network with M=1 unique cascade followed by T=2 shared
# rounds of N-M=3 cascades each, then print its layer summary.
model = build_DSMENet_functional(N=4, M=1, T=2)
model.summary()


Model: "DSMENet_Functional"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_image (InputLayer)       [(None, 256, 256, 2  0           []                               
                                )]                                                                
                                                                                                  
 SRUN_1 (SRUN)                  ((None, 256, 256, 1  542204      ['input_image[0][0]']            
                                6),                                                               
                                 (None, 256, 256, 2                                               
                                ))                                                                
                                                                                 

In [16]:
class DMSENet(tf.keras.Model):
    """
    Single-stage wrapper chaining SRUN, DRCM, DFRM and DGFM.

    Input:
        undersampled image (B,H,W,2)
    Output (4-tuple):
        S0       : second output of SRUN (structure branch)
        D0       : DRCM output (initial detail representation)
        D_out    : DFRM output, computed from D0 and SRUN's shallow features
        dfgm_out : DGFM fusion of S0 and D_out
    """
    def __init__(self, name="DMSENet"):
        super().__init__(name=name)

        self.srun = SRUN(name="SRUN")
        self.drcm = DRCM(out_channels=2, name="DRCM")
        self.dfrm = DFRM(channels=16)
        self.dfgm = DGFM()

    def call(self, x):
        """
        x: undersampled image (B,H,W,2)
        returns: (S0, D0, D_out, dfgm_out)
        """
        shallow_features, structure = self.srun(x)       # structure branch
        detail_init = self.drcm(x)                       # initial detail map
        detail_refined = self.dfrm(detail_init, shallow_features)
        fused = self.dfgm(structure, detail_refined)
        return structure, detail_init, detail_refined, fused



        
# Smoke test: push a random-normal dummy input (batch=1, 256×256, 2 channels)
# through the wrapper model and print the shape of each of its four outputs.
inp = tf.random.normal([1, 256, 256, 2])

model = DMSENet()
# Order matches DMSENet.call: SRUN structure output, DRCM output, DFRM output, DGFM output.
S0, D0,DF,dfgm_out = model(inp)

print("SRUN output shape:", S0.shape)
print("DRCM output shape:", D0.shape)
print("DFRM output shape:", DF.shape)
print("DFGM output shape:", dfgm_out.shape)


SRUN output shape: (1, 256, 256, 2)
DRCM output shape: (1, 256, 256, 2)
DFRM output shape: (1, 256, 256, 2)
DFGM output shape: (1, 256, 256, 2)
