In [16]:
%run "C:\Users\DU\aman_fastmri\activation.ipynb"

In [17]:
%run "C:\Users\DU\aman_fastmri\layer.ipynb"

In [18]:
%run "C:\Users\DU\aman_fastmri\fft.ipynb"

In [19]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class TAM(Layer):
    """
    Texture Attention Module (TAM) for TEID-Net.

    Complex representation (packed along the channel axis):
        - first C/2 channels : real part
        - last  C/2 channels : imaginary part

    The real and imaginary halves are each split into four channel groups.
    For every group a low-pass version (area-resize down to p x p, bilinear
    resize back up) is subtracted to obtain a texture residual. The residuals
    are filtered by two complex 3x3 convolutions, squashed with a sigmoid and
    used as a residual attention map on the input.

    Args:
        channels: total number of packed channels C (real + imaginary).
            Must be divisible by 8, because each half (C/2) is split into
            four equally sized groups.
    """

    # Target spatial size of the low-pass branch, one entry per channel group.
    POOL_SIZES = (1, 2, 4, 8)

    def __init__(self, channels):
        super().__init__()
        # call() runs tf.split(<C/2 channels>, 4), so C/2 must be a multiple
        # of 4. The previous "% 4" check let values such as 4 or 12 through
        # and the failure only surfaced later inside tf.split.
        assert channels % 8 == 0, "Channels must be divisible by 8"
        self.channels = channels
        self.half_channels = channels // 2

        # Two sequential complex 3x3 convolutions, each followed by CLeaky_ReLU
        # in call(). filters=channels: the residual product in call() requires
        # each returned part (real / imag) to carry channels/2 feature maps.
        self.conv1 = complex_Conv2D(
            filters=channels,
            kernel_size=3,
            padding="same"
        )
        self.conv2 = complex_Conv2D(
            filters=channels,
            kernel_size=3,
            padding="same"
        )

    def adaptive_avg_pool(self, x, target):
        """
        Approximate adaptive average pooling with two resizes.

        Args:
            x: (B, H, W, c) real-valued tensor.
            target: side length of the pooled grid (target x target).

        Returns:
            Low-pass version of ``x`` with the same spatial size as ``x``.
            Gradients do not flow through it (treated as a fixed statistic).
        """
        h, w = tf.shape(x)[1], tf.shape(x)[2]

        # Low-pass filtering: area-resize down to the target grid.
        x_low = tf.image.resize(x, (target, target), method="area")
        # Block gradients so only the identity branch of (x - x_low) is learnable.
        x_low = tf.stop_gradient(x_low)

        # Restore the original spatial resolution.
        x_low = tf.image.resize(x_low, (h, w), method="bilinear")
        return x_low

    def call(self, x):
        """
        Args:
            x: (B, H, W, C) packed complex features, C == ``channels``.

        Returns:
            (B, H, W, C) tensor: F_out = F_in + F_in * M_t, applied separately
            to the real and imaginary halves.
        """
        # Split into real / imaginary halves: (B, H, W, C/2) each.
        real = x[..., :self.half_channels]
        imag = x[..., self.half_channels:]

        # Four channel groups per half, one per pooling scale.
        real_groups = tf.split(real, 4, axis=-1)
        imag_groups = tf.split(imag, 4, axis=-1)

        tex_real_groups = []
        tex_imag_groups = []

        # Multi-scale texture extraction: texture = input - low-pass(input).
        for r, i, p in zip(real_groups, imag_groups, self.POOL_SIZES):
            r_low = self.adaptive_avg_pool(r, p)
            i_low = self.adaptive_avg_pool(i, p)

            tex_real_groups.append(r - r_low)
            tex_imag_groups.append(i - i_low)

        # Re-assemble the texture features: (B, H, W, C/2) each.
        tex_real = tf.concat(tex_real_groups, axis=-1)
        tex_imag = tf.concat(tex_imag_groups, axis=-1)

        # Texture filtering: two complex conv + complex leaky-ReLU stages.
        tex_real, tex_imag = self.conv1(tex_real, tex_imag)
        tex_real, tex_imag = CLeaky_ReLU(tex_real, tex_imag)

        tex_real, tex_imag = self.conv2(tex_real, tex_imag)
        tex_real, tex_imag = CLeaky_ReLU(tex_real, tex_imag)

        # Texture attention map M_t in (0, 1).
        attn_real = tf.sigmoid(tex_real)
        attn_imag = tf.sigmoid(tex_imag)

        # Residual texture attention: F_out = F_in + F_in ⊗ M_t
        out_real = real + real * attn_real
        out_imag = imag + imag * attn_imag

        # Re-pack as [real | imag].
        return tf.concat([out_real, out_imag], axis=-1)


In [20]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class TEM(Layer):
    """
    Texture Enhancement Module (TEM)
    Section 2.2.1, Fig. 1(d) of TEID-Net.

    Runs the input through a TAM and then through a complex convolutional
    block made of three (complex 3x3 conv -> CLeaky_ReLU) stages.

    Input / Output:
        (B, H, W, C)
    Complex representation:
        first C/2 channels  -> real
        last  C/2 channels  -> imaginary

    Args:
        channels: value of C; forwarded to TAM and used as ``filters`` for
            every convolution of the block.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels

        # Texture attention comes first in the forward pass.
        self.tam = TAM(channels)

        # Complex Convolutional Block: three identically configured convs.
        conv_kwargs = dict(filters=channels, kernel_size=3, padding="same")
        self.conv1 = complex_Conv2D(**conv_kwargs)
        self.conv2 = complex_Conv2D(**conv_kwargs)
        self.conv3 = complex_Conv2D(**conv_kwargs)

    def call(self, x):
        """
        Args:
            x: (B, H, W, C) packed complex features.

        Returns:
            Packed [real | imag] output of the third conv stage.
        """
        attended = self.tam(x)

        # Unpack the attended features into their real / imaginary halves.
        half = self.channels // 2
        feat_re = attended[..., :half]
        feat_im = attended[..., half:]

        # conv -> complex leaky ReLU, three times in a row.
        for conv in (self.conv1, self.conv2, self.conv3):
            feat_re, feat_im = conv(feat_re, feat_im)
            feat_re, feat_im = CLeaky_ReLU(feat_re, feat_im)

        return tf.concat([feat_re, feat_im], axis=-1)


In [21]:
import tensorflow as tf
from tensorflow.keras.layers import Layer

class TENet(Layer):
    """
    Texture Enhancement Network (TE-Net).

    Pipeline: k-space -> ifft2c_tf -> complex 1x1 lifting conv -> three
    residual TEM blocks -> concatenation of the three TEM outputs ->
    complex 1x1 fusion conv.

    All feature tensors are packed as [real | imag] along the channel axis.
    Channel counts quoted below assume complex_Conv2D(filters=F) returns F/2
    real and F/2 imaginary channels, which is what the shapes logged by this
    notebook show — confirm against layer.ipynb.
    """

    def __init__(self):
        super().__init__()

        # Feature lifting from the single complex input channel.
        self.conv_in = complex_Conv2D(
            filters=32,
            kernel_size=1,
            padding="same"
        )

        # Three residual TEM blocks.
        self.tem1 = TEM(channels=32)
        self.tem2 = TEM(channels=32)
        self.tem3 = TEM(channels=32)

        # Feature fusion down to one complex channel.
        self.fuse_conv = complex_Conv2D(
            filters=2,
            kernel_size=1,
            padding="same"
        )

        # NOTE(review): never called in call(); kept so the set of created
        # sub-layers (and their auto-generated names) is unchanged.
        self.out_conv = complex_Conv2D(
            filters=32,
            kernel_size=1,
            padding="same"
        )

    def call(self, kspace):
        """
        Args:
            kspace: (B, H, W, 2) k-space, channel 0 = real, channel 1 = imag.

        Returns:
            Packed [real | imag] output of ``fuse_conv`` followed by
            CLeaky_ReLU; (B, H, W, 2) under the channel assumption stated in
            the class docstring.
        """
        # 1) k-space -> image domain
        x = ifft2c_tf(kspace)          # (B, H, W, 2)

        # 2) Initial complex 1x1 conv on the single complex channel
        real = x[..., 0:1]
        imag = x[..., 1:2]

        real, imag = self.conv_in(real, imag)
        real, imag = CLeaky_ReLU(real, imag)

        x = tf.concat([real, imag], axis=-1)  # packed [real | imag]

        # 3) Residual TEMs: each block sees the running residual sum.
        t1 = self.tem1(x)
        x1 = x + t1

        t2 = self.tem2(x1)
        x2 = x1 + t2

        t3 = self.tem3(x2)

        # 4) Concatenate TEM outputs, keeping real and imaginary parts apart.
        # Each t_k is packed [real | imag]; concatenating the packed tensors
        # and cutting the result in half would place t1_imag in the "real"
        # part and t3_real in the "imag" part.
        tem_outputs = (t1, t2, t3)
        half = t1.shape[-1] // 2
        real = tf.concat([t[..., :half] for t in tem_outputs], axis=-1)
        imag = tf.concat([t[..., half:] for t in tem_outputs], axis=-1)

        # 5) Feature fusion
        real, imag = self.fuse_conv(real, imag)
        real, imag = CLeaky_ReLU(real, imag)

        return tf.concat([real, imag], axis=-1)


In [22]:
def _ifft1c_tf(x, axis):
    """
    Centered 1-D inverse FFT of a packed-complex tensor along one spatial axis.

    Args:
        x: (B, H, W, 2) real tensor, channel 0 = real, channel 1 = imag.
        axis: spatial axis of the (B, H, W) complex tensor to transform;
            1 for H, 2 for W.

    Returns:
        (B, H, W, 2) real tensor holding the transformed signal.
    """
    z = tf.complex(x[..., 0], x[..., 1])          # (B, H, W) complex
    z = tf.signal.ifftshift(z, axes=axis)

    # tf.signal.ifft only transforms the innermost axis, so bring the
    # requested axis there and move it back afterwards.
    swap = axis != 2
    if swap:
        z = tf.transpose(z, [0, 2, 1])
    z = tf.signal.ifft(z)
    if swap:
        z = tf.transpose(z, [0, 2, 1])

    z = tf.signal.fftshift(z, axes=axis)

    # tf.signal.ifft scales by 1/N; multiply by sqrt(N) for orthonormal
    # scaling so that H-then-W composes to an orthonormal 2-D IFFT.
    # NOTE(review): assumes ifft2c_tf is orthonormal too — confirm in fft.ipynb.
    n = tf.cast(tf.shape(z)[axis], x.dtype)
    z = z * tf.cast(tf.sqrt(n), z.dtype)

    return tf.stack([tf.math.real(z), tf.math.imag(z)], axis=-1)


def ifft1d_H_tf(x):
    """
    1-D centered IFFT along H (axis=-3) only.

    Previously this transposed H and W around a call to ifft2c_tf, i.e. it
    delegated to the 2-D routine rather than transforming H alone.

    x : (B, H, W, 2)
    returns : (B, H, W, 2)
    """
    return _ifft1c_tf(x, axis=1)


def ifft1d_W_tf(x):
    """
    1-D centered IFFT along W (axis=-2) only.

    Previously this was a direct alias of ifft2c_tf.

    x : (B, H, W, 2)
    returns : (B, H, W, 2)
    """
    return _ifft1c_tf(x, axis=2)


In [23]:
class IDM(tf.keras.layers.Layer):
    """
    Intermediate-domain module.

    Steps performed by call():
        1. complex 1x1 expansion of the single complex input channel (z1)
        2. transpose W into the channel axis and apply two complex 1x1
           convolutions there, i.e. a learned mixing across the whole W
           axis at every (row, feature) position (z2)
        3. transpose back and add the residual z1 + z2
        4. two complex 3x3 convolutions with CLeaky_ReLU
        5. complex 1x1 fusion (no activation)

    Args:
        channels: ``filters`` of the expansion convolution.
        width: spatial width W of the input. The W-mixing convolutions must
            output W values per part so the residual sum in step 3 is shape
            compatible; their ``filters`` is therefore 2 * width. Defaults to
            320, the value that used to be hard-coded.
    """

    def __init__(self, channels=64, width=320):
        super().__init__()
        self.width = width

        # 1x1 expansion
        self.expand = complex_Conv2D(channels, kernel_size=1, padding="same")

        # Global mixing along W (applied after W has been moved to channels).
        self.conv1 = complex_Conv2D(2 * width, kernel_size=1, padding="same")
        self.conv2 = complex_Conv2D(2 * width, kernel_size=1, padding="same")

        # 3x3 complex convolutions after residual fusion
        self.conv3 = complex_Conv2D(32, kernel_size=3, padding="same")
        self.conv4 = complex_Conv2D(32, kernel_size=3, padding="same")

        # Final channel fusion
        self.fuse = complex_Conv2D(2, kernel_size=1, padding="same")

    def call(self, x):
        """
        Args:
            x: (B, H, W, 2), one complex channel (real, imag).

        Returns:
            Packed [real | imag] output of the final fusion convolution.

        Raises:
            ValueError: if the static width of ``x`` is known and differs
                from the ``width`` this layer was built for.
        """
        static_w = x.shape[2]
        if static_w is not None and static_w != self.width:
            raise ValueError(
                f"IDM was created for width={self.width} but received "
                f"an input of width {static_w}"
            )

        # Split real / imag
        input_real = x[..., 0:1]
        input_imag = x[..., 1:2]

        # (1) Channel expansion (z1)
        z1_real, z1_imag = self.expand(input_real, input_imag)
        z1_real, z1_imag = CLeaky_ReLU(z1_real, z1_imag)

        # (2) Swap W and channel axes: (B, H, W, C) -> (B, H, C, W)
        swap_wc = [0, 1, 3, 2]
        real_t = tf.transpose(z1_real, swap_wc)
        imag_t = tf.transpose(z1_imag, swap_wc)

        # (3) Two complex convs + CLeaky_ReLU acting across W (z2)
        z2_real, z2_imag = self.conv1(real_t, imag_t)
        z2_real, z2_imag = CLeaky_ReLU(z2_real, z2_imag)

        z2_real, z2_imag = self.conv2(z2_real, z2_imag)
        z2_real, z2_imag = CLeaky_ReLU(z2_real, z2_imag)

        # (4) Transpose back to (B, H, W, C)
        z2_real = tf.transpose(z2_real, swap_wc)
        z2_imag = tf.transpose(z2_imag, swap_wc)

        # (5) Residual fusion (z1 + z2)
        real = z1_real + z2_real
        imag = z1_imag + z2_imag

        # (6) Two 3x3 complex convolutions + CLeaky_ReLU
        real, imag = self.conv3(real, imag)
        real, imag = CLeaky_ReLU(real, imag)

        real, imag = self.conv4(real, imag)
        real, imag = CLeaky_ReLU(real, imag)

        # (7) Final 1x1 fusion, no activation
        real, imag = self.fuse(real, imag)

        # Output packed as (real, imag)
        return tf.concat([real, imag], axis=-1)


In [24]:
class IDNet(tf.keras.Model):
    """
    Intermediate-domain network: ifft1d_H_tf -> IDM -> ifft1d_W_tf.

    Args:
        channels: forwarded to the wrapped IDM.
    """

    def __init__(self, channels=64):
        super().__init__()
        self.idm = IDM(channels)

    def call(self, x_u):
        """
        Args:
            x_u: (B, H, W, 2) undersampled k-space.

        Returns:
            Result of applying ifft1d_W_tf to the IDM output.
        """
        # Transform along H (frequency-encoding direction) first ...
        intermediate = ifft1d_H_tf(x_u)
        # ... recover in the intermediate domain ...
        recovered = self.idm(intermediate)
        # ... then transform along W (phase-encoding direction).
        return ifft1d_W_tf(recovered)


In [25]:
class FusionModule(tf.keras.layers.Layer):
    """
    Fuses the TE-Net and ID-Net estimates.

    Each input contributes one complex channel (channels 0 and 1). Both are
    lifted by a complex 1x1 convolution, split into four channel groups,
    cross-concatenated group by group, passed through a complex 3x3
    convolution with dilation rate 1, 2, 3 or 4 respectively, concatenated
    again and reduced by a final complex 1x1 convolution.

    Args:
        channels: ``filters`` of the two alignment convolutions. Must be
            divisible by 8: the shapes logged by this notebook show that
            complex_Conv2D(filters=F) returns F/2 channels per part, and each
            part is then split into four groups.
    """

    def __init__(self, channels=32):
        super().__init__()

        # channels/2 per part after alignment, split 4 ways in call().
        # The previous "% 4" check accepted values (4, 12, ...) for which
        # tf.split then failed at call time.
        assert channels % 8 == 0, "Channels must be divisible by 8"
        self.channels = channels
        self.group_channels = channels // 4

        # Channel alignment
        self.align_te = complex_Conv2D(channels, kernel_size=1, padding="same")
        self.align_id = complex_Conv2D(channels, kernel_size=1, padding="same")

        # One dilated convolution per channel group / scale.
        self.dilated_convs = []
        for rate in [1, 2, 3, 4]:
            self.dilated_convs.append(
                complex_Conv2D(
                    self.group_channels,
                    kernel_size=3,
                    padding="same",
                    dilation_rate=rate
                )
            )

        # Final fusion
        self.fuse = complex_Conv2D(2, kernel_size=1, padding="same")

    def call(self, x_te, x_id):
        """
        Args:
            x_te: (B, H, W, 2) TE-Net output; channel 0 = real, 1 = imag.
            x_id: (B, H, W, 2) ID-Net output; channel 0 = real, 1 = imag.
                Only channels 0 and 1 of either tensor are read.

        Returns:
            Packed [real | imag] output of the final fusion convolution
            (no activation applied).
        """
        # Split real / imag once.
        te_r = x_te[..., 0:1]
        te_i = x_te[..., 1:2]

        id_r = x_id[..., 0:1]
        id_i = x_id[..., 1:2]

        # 1x1 channel alignment
        te_r, te_i = self.align_te(te_r, te_i)
        id_r, id_i = self.align_id(id_r, id_i)

        # Split into 4 channel groups
        te_r_groups = tf.split(te_r, 4, axis=-1)
        te_i_groups = tf.split(te_i, 4, axis=-1)
        id_r_groups = tf.split(id_r, 4, axis=-1)
        id_i_groups = tf.split(id_i, 4, axis=-1)

        fused_r = []
        fused_i = []

        # Cross-fusion + multi-scale dilated conv: group k of both branches
        # is concatenated and filtered with dilation rate k + 1.
        for k in range(4):
            r = tf.concat([te_r_groups[k], id_r_groups[k]], axis=-1)
            i = tf.concat([te_i_groups[k], id_i_groups[k]], axis=-1)

            r, i = self.dilated_convs[k](r, i)
            r, i = CLeaky_ReLU(r, i)

            fused_r.append(r)
            fused_i.append(i)

        # Concatenate all scales
        r = tf.concat(fused_r, axis=-1)
        i = tf.concat(fused_i, axis=-1)

        # Final 1x1 fusion (no activation here)
        r, i = self.fuse(r, i)

        return tf.concat([r, i], axis=-1)


In [26]:
class DataConsistency(tf.keras.layers.Layer):
    """
    Soft data consistency layer for single-coil MRI
    with 1D Cartesian sampling mask.

    At sampled locations (mask == 1) the output is the weighted average
    (k_us + λ·k_pred) / (1 + λ); elsewhere the prediction is kept.

    Args:
        init_lambda: initial value of the learnable balance parameter λ.
            Must be >= 0.

    Raises:
        ValueError: if ``init_lambda`` is negative.
    """
    def __init__(self, init_lambda=0.1):
        super().__init__()

        if init_lambda < 0:
            raise ValueError(
                f"init_lambda must be non-negative, got {init_lambda}"
            )

        # Learnable balance parameter. It is registered through add_weight so
        # the layer reports it as a trainable weight, and constrained to stay
        # >= 0 so the denominator (1 + λ) in call() can never reach zero.
        self.lambda_dc = self.add_weight(
            name="lambda_dc",
            shape=(),
            dtype=tf.float32,
            initializer=tf.keras.initializers.Constant(init_lambda),
            constraint=tf.keras.constraints.NonNeg(),
            trainable=True,
        )

    def call(self, x_img, k_us, mask):
        """
        Args:
            x_img: (B, H, W, 2) image-domain prediction.
            k_us:  (B, H, W, 2) undersampled k-space.
            mask:  sampling mask broadcastable against k-space,
                   e.g. (1, 1, W, 1) for a 1D Cartesian pattern.

        Returns:
            (B, H, W, 2) data-consistent k-space.
        """
        # Image -> k-space
        k_pred = fft2c_tf(x_img)   # (B,H,W,2)

        # Ensure type consistency
        mask = tf.cast(mask, k_pred.dtype)
        lam  = tf.cast(self.lambda_dc, k_pred.dtype)

        # Soft data consistency
        k_dc = (
            (1.0 - mask) * k_pred +
            mask * (k_us + lam * k_pred) / (1.0 + lam)
        )

        return k_dc


In [27]:
class TEIDBlock(tf.keras.layers.Layer):
    """
    A single TEID cascade stage.

    k-space goes into both TENet and IDNet; their outputs are merged by the
    FusionModule and the merged result is passed through DataConsistency,
    which returns k-space again.

    Args:
        channels: forwarded to IDNet and FusionModule.
    """
    def __init__(self, channels):
        super().__init__()

        # Two parallel branches ...
        self.te_net = TENet()
        self.id_net = IDNet(channels=channels)
        # ... followed by fusion and data consistency.
        self.fm     = FusionModule(channels=channels)
        self.dc     = DataConsistency()

    def call(self, k_in, k_us, mask):
        """
        Args:
            k_in: (B, H, W, 2) k-space from the previous cascade.
            k_us: (B, H, W, 2) zero-filled undersampled k-space (K^0).
            mask: (1, 1, W, 1) sampling mask.

        Returns:
            Output of the data-consistency layer.
        """
        # Both branches consume the same k-space input.
        te_estimate = self.te_net(k_in)
        id_estimate = self.id_net(k_in)

        # Merge the two estimates, then enforce consistency with k_us.
        merged = self.fm(te_estimate, id_estimate)
        return self.dc(merged, k_us, mask)


In [31]:
class CascadedTEIDNet(tf.keras.Model):
    """
    Chain of TEIDBlock stages followed by one ifft2c_tf.

    Args:
        num_cascades: number of TEIDBlock stages.
        channels: forwarded to every TEIDBlock.
    """
    def __init__(self, num_cascades=5, channels=64):
        super().__init__()

        stages = []
        for _ in range(num_cascades):
            stages.append(TEIDBlock(channels=channels))
        self.blocks = stages

    def call(self, k_us, mask):
        """
        Args:
            k_us: (B, H, W, 2) undersampled k-space.
            mask: (1, 1, W, 1) sampling mask.

        Returns:
            ifft2c_tf applied to the k-space produced by the last stage.
        """
        # The first stage starts from the measured data, K^(0); every stage
        # also receives the original k_us and mask.
        k_current = k_us
        for stage in self.blocks:
            k_current = stage(k_current, k_us, mask)

        # Single final reconstruction.
        return ifft2c_tf(k_current)


In [34]:
from tensorflow.keras.layers import Input, Lambda
from tensorflow.keras.models import Model

def build_cascaded_teid_model(H=320, W=320, num_cascades=5, channels=64):
    """
    Wrap CascadedTEIDNet in a functional Keras Model.

    Args:
        H: input height.
        W: input width.
        num_cascades: forwarded to CascadedTEIDNet.
        channels: forwarded to CascadedTEIDNet.

    Returns:
        Model named "CascadedTEIDNet_Model" with inputs
        [kspace_undersampled (H, W, 2), sampling_mask (1, W, 1)] and a single
        output produced by the layer named "reconstructed_image".
    """
    # Symbolic inputs
    kspace_in = Input(shape=(H, W, 2), name="kspace_undersampled")
    mask_in = Input(shape=(1, W, 1), name="sampling_mask")

    # Core network and its forward pass
    network = CascadedTEIDNet(num_cascades=num_cascades, channels=channels)
    reconstruction = network(kspace_in, mask_in)

    # Identity layer: only gives the output tensor a stable name.
    reconstruction = Lambda(lambda x: x, name="reconstructed_image")(reconstruction)

    return Model(
        inputs=[kspace_in, mask_in],
        outputs=reconstruction,
        name="CascadedTEIDNet_Model",
    )


In [35]:
# Build the 5-cascade model for 320 x 320 inputs and print its layer summary.
model = build_cascaded_teid_model(H=320, W=320, num_cascades=5, channels=64)

model.summary()


mask (None, 1, 320, 1)
x_u (None, 320, 320, 2)
z_Hu (None, 320, 320, 2)
Input after IDIFT height (None, 320, 320, 2)
input_real (None, 320, 320, 1) input_imag (None, 320, 320, 1)
Output after 1st conv1*1 real1_conv1 (None, 320, 320, 32) imag1_conv1 (None, 320, 320, 32)
After ist transpose real_t (None, 320, 32, 320) imag_t (None, 320, 32, 320)
After 2 intermediate 1*1 conv ::real2 (None, 320, 32, 320) imag2 (None, 320, 32, 320)
After 2nd transpose::Real  (None, 320, 320, 32) imag (None, 320, 320, 32)
After residual fusion::Real  (None, 320, 320, 32) imag (None, 320, 320, 32)
After 2 3*3 conv ::Real  (None, 320, 320, 16) imag (None, 320, 320, 16)
z_H (None, 320, 320, 2)
y_hat (None, 320, 320, 2)
x_u (None, 320, 320, 2)
z_Hu (None, 320, 320, 2)
Input after IDIFT height (None, 320, 320, 2)
input_real (None, 320, 320, 1) input_imag (None, 320, 320, 1)
Output after 1st conv1*1 real1_conv1 (None, 320, 320, 32) imag1_conv1 (None, 320, 320, 32)
After ist transpose real_t (None, 320, 32, 320) i