In [3]:
import tensorflow as tf
from tensorflow.keras import layers


In [4]:
# Execute the FFT helper notebook in this kernel's namespace. Presumably this is
# what defines fft2c_tf / ifft2c_tf, which KUBlock below calls but this notebook
# never defines itself — confirm against fft.ipynb.
# NOTE(review): absolute, machine-specific Windows path; Restart & Run All fails
# on any other machine. Consider a relative path or a .py module import.
%run "C:\Users\DU\aman_fastmri\fft.ipynb"

In [5]:
class KUBlock(tf.keras.layers.Layer):
    """
    K-space Under-sampling (KU) block.

    Simulates an undersampled acquisition. The ground-truth image is transformed
    to k-space with ``fft2c_tf``. Unsampled locations are zeroed by multiplying
    with ``mask``. The result is transformed back with ``ifft2c_tf`` to give
    the zero-filled image.

    The layer has no weights and is always non-trainable.

    Args:
        norm: FFT normalisation forwarded to ``fft2c_tf`` / ``ifft2c_tf``.
        **kwargs: standard Keras ``Layer`` arguments. Any ``trainable`` value
            is overridden with ``False``.
    """

    def __init__(self, norm="ortho", **kwargs):
        # Force non-trainable by overwriting the key instead of passing
        # ``trainable=False`` next to ``**kwargs``. The latter raised a TypeError
        # (duplicate keyword) whenever kwargs already held ``trainable``, e.g.
        # when the layer is rebuilt via ``from_config(get_config())``.
        kwargs["trainable"] = False
        super().__init__(**kwargs)
        self.norm = norm

    def call(self, inputs):
        """
        Args:
            inputs: tuple ``(S_gt, mask)``.
                S_gt: ground-truth image. NOTE(review): presumably
                    (B, H, W, 2) real/imag channels — confirm against fft2c_tf.
                mask: sampling mask of rank 2 ``(H, W)``, rank 3 (assumed
                    ``(B, H, W)``), or rank 4 with a last axis of 1 or of the
                    k-space channel count.

        Returns:
            Zero-filled image ``ifft2c_tf(fft2c_tf(S_gt) * mask)``.
        """
        S_gt, mask = inputs

        kspace = fft2c_tf(S_gt, norm=self.norm)
        mask = tf.cast(mask, kspace.dtype)

        # Bring the mask to rank 4 with a singleton channel axis. The multiply
        # below then broadcasts it over however many channels k-space has,
        # instead of hard-coding a 2-channel copy of the mask.
        rank = mask.shape.rank
        if rank == 2:
            mask = mask[tf.newaxis, :, :, tf.newaxis]
        elif rank == 3:
            mask = mask[..., tf.newaxis]

        kspace_under = kspace * mask
        S_zf = ifft2c_tf(kspace_under, norm=self.norm)

        return S_zf

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"norm": self.norm})
        return config


In [6]:
import tensorflow as tf
from tensorflow.keras import layers

class SCABlock(tf.keras.layers.Layer):
    """
    Spatial and Channel-wise Attention (SCA) block.

    Implements Eq. (7)–(11) and Figure 3 from the paper. For an input F:

        M_s       = sigmoid(Conv1x1(mean over channels of F))       # (B, H, W, 1)
        Theta_SA  = F * M_s + F
        M_c       = sigmoid(Dense(global-average-pool(Theta_SA)))   # (B, C)
        Theta_SCA = Theta_SA * M_c + Theta_SA

    Args:
        channels: channel count C of the input. The channel-attention Dense
            layer produces exactly C weights, so this must match F's last axis.
        name: Keras layer name.
        **kwargs: further standard Keras ``Layer`` arguments (e.g. ``dtype``).
            They are accepted so the layer can be rebuilt from ``get_config()``.
    """

    def __init__(self, channels, name="SCA_Block", **kwargs):
        super().__init__(name=name, **kwargs)
        self.channels = channels

        # Spatial attention: linear projection W_S, B_S (1x1 conv on a 1-channel map).
        self.spatial_conv = layers.Conv2D(
            filters=1,
            kernel_size=1,
            strides=1,
            padding="same",
            use_bias=True,
        )

        # Channel-wise attention: linear projection W_C, B_C.
        self.channel_fc = layers.Dense(
            units=channels,
            use_bias=True,
        )

    def call(self, F):
        """
        Args:
            F: input feature map (B, H, W, C) with C == ``self.channels``.

        Returns:
            Attentive feature map with the same shape as ``F``.
        """
        # ---- Spatial attention Ψ_S ----
        # Channel-wise mean -> (B, H, W, 1), then linear map + sigmoid.
        F_mean = tf.reduce_mean(F, axis=-1, keepdims=True)
        M_s = tf.nn.sigmoid(self.spatial_conv(F_mean))

        # Re-weight every channel by the spatial map, plus a skip connection.
        Theta_SA = F * M_s + F

        # ---- Channel-wise attention Ψ_C ----
        # Global average pooling over H and W -> (B, C), then linear map + sigmoid.
        F_gap = tf.reduce_mean(Theta_SA, axis=[1, 2])
        M_c = tf.nn.sigmoid(self.channel_fc(F_gap))

        # (B, C) -> (B, 1, 1, C) so the weights broadcast over the spatial axes.
        M_c = M_c[:, tf.newaxis, tf.newaxis, :]

        # Re-weight channels, plus a skip connection.
        Theta_SCA = Theta_SA * M_c + Theta_SA

        return Theta_SCA

    def get_config(self):
        """Return the constructor arguments needed to re-create this layer."""
        config = super().get_config()
        config.update({"channels": self.channels})
        return config


In [7]:
import tensorflow as tf
from tensorflow.keras import layers

class SCUnit(tf.keras.layers.Layer):
    """
    SC Unit: three 3x3 convolutions followed by an SCA block, wrapped in a
    local (short-skip) residual connection. See Figure 2 and the Generator
    description in the paper.

    Data flow:  x -> conv0 (C) -> conv1 (C/2) -> conv2 (C) -> SCA -> (+ x)

    Args:
        channels: width C. The input must already have C channels so the
            residual addition is valid.
        activation: activation applied after each of the three convolutions.
        name: Keras layer name.
    """

    def __init__(self, channels, activation="relu", name="SC_Unit"):
        super().__init__(name=name)
        self.channels = channels

        def make_conv(n_filters):
            # 3x3, stride 1, "same" padding: spatial size is preserved.
            return layers.Conv2D(
                filters=n_filters,
                kernel_size=3,
                strides=1,
                padding="same",
                activation=activation,
            )

        # Width profile C -> C/2 -> C.
        self.conv0 = make_conv(channels)
        self.conv1 = make_conv(channels // 2)
        self.conv2 = make_conv(channels)

        # Attention applied to the conv stack's output.
        self.sca = SCABlock(channels)

    def call(self, x):
        """
        Args:
            x: input feature map (B, H, W, C).

        Returns:
            SCA(conv2(conv1(conv0(x)))) + x, same shape as ``x``.
        """
        shortcut = x
        features = self.conv2(self.conv1(self.conv0(x)))
        attended = self.sca(features)
        return attended + shortcut


In [8]:
class EncoderBlock(tf.keras.layers.Layer):
    """
    SCUNet encoder stage.

    Two SC units at ``channels`` width, then a 3x3 stride-2 convolution that
    halves H/W and doubles the channel count.

    Args:
        channels: width of the two SC units; the downsampled output has
            ``2 * channels`` channels.
        name: Keras layer name.

    ``call`` returns ``(downsampled, skip)``. ``skip`` is the full-resolution
    output of the SC units, consumed later by the matching decoder stage.
    """

    def __init__(self, channels, name="EncoderBlock"):
        super().__init__(name=name)
        self.sc1 = SCUnit(channels)
        self.sc2 = SCUnit(channels)

        # Learned downsampling: stride 2, "same" padding, width doubled.
        self.down = tf.keras.layers.Conv2D(channels * 2, 3, strides=2, padding="same")

    def call(self, x):
        skip = self.sc2(self.sc1(x))
        return self.down(skip), skip


In [9]:
class DecoderBlock(tf.keras.layers.Layer):
    """
    SCUNet decoder stage: upsample, fuse with the encoder skip, refine.

    Steps:
      1. A stride-2 transposed conv doubles H/W and produces ``channels`` maps.
      2. The result is concatenated with the matching encoder skip along
         channels and projected back to ``channels`` by a 1x1 conv.
      3. Two SC units and a final SCA block refine the result.

    Args:
        channels: output channel count of this stage.
        name: Keras layer name.
    """

    def __init__(self, channels, name="DecoderBlock"):
        super().__init__(name=name)

        self.channels = channels

        self.up = tf.keras.layers.Conv2DTranspose(
            filters=channels,
            kernel_size=3,
            strides=2,
            padding="same"
        )

        # Projection AFTER concatenation (fused channels -> channels).
        self.proj = tf.keras.layers.Conv2D(
            filters=channels,
            kernel_size=1,
            padding="same"
        )

        self.sc1 = SCUnit(channels)
        self.sc2 = SCUnit(channels)
        self.sca = SCABlock(channels)

    def call(self, x, skip):
        """
        Args:
            x: lower-resolution features from the previous stage.
            skip: encoder features at the target resolution (B, H, W, C_skip).

        Returns:
            Refined features (B, H, W, channels) at the skip's resolution.
        """
        x = self.up(x)

        # A stride-2 "same" Conv2D maps size n to ceil(n/2). This transposed
        # conv maps it back to 2*ceil(n/2), which is one larger than n when n
        # is odd. Cropping to the skip's H/W keeps the concat valid for inputs
        # whose size is not a multiple of 2**depth; for even sizes it is a no-op.
        skip_shape = tf.shape(skip)
        x = x[:, :skip_shape[1], :skip_shape[2], :]

        x = tf.concat([x, skip], axis=-1)   # (B, H, W, channels + C_skip)
        x = self.proj(x)                    # (B, H, W, channels)

        x = self.sc1(x)
        x = self.sc2(x)
        x = self.sca(x)
        return x


In [10]:
class SCUNet(tf.keras.Model):
    """
    U-Net built from SC units, with a global residual connection.

    Layout (width as multiples of ``base_channels`` = c):
        in_conv (c) -> enc1..enc4 (c, 2c, 4c, 8c; each halves H/W on exit)
        -> bottleneck SCUnit (16c)
        -> dec4..dec1 (8c, 4c, 2c, c; each doubles H/W and fuses a skip)
        -> out_conv (in_channels) -> + input

    Args:
        in_channels: channels of both the input and the output.
        base_channels: width of the first stage.
        name: Keras model name.
    """

    def __init__(self, in_channels=2, base_channels=64, name="SCUNet"):
        super().__init__(name=name)

        self.in_channels = in_channels
        self.base_channels = base_channels
        c = base_channels

        # Stem: lift the input to `c` feature maps.
        self.in_conv = tf.keras.layers.Conv2D(c, kernel_size=3, padding="same")

        # Encoder: each stage's skip has its own width; its output has twice that.
        self.enc1 = EncoderBlock(c)
        self.enc2 = EncoderBlock(c * 2)
        self.enc3 = EncoderBlock(c * 4)
        self.enc4 = EncoderBlock(c * 8)

        self.bottleneck = SCUnit(c * 16)

        # Decoder: mirrors the encoder from the deepest level upward.
        self.dec4 = DecoderBlock(c * 8)
        self.dec3 = DecoderBlock(c * 4)
        self.dec2 = DecoderBlock(c * 2)
        self.dec1 = DecoderBlock(c)

        # Head: back to the input's channel count.
        self.out_conv = tf.keras.layers.Conv2D(in_channels, kernel_size=3, padding="same")

    def call(self, y, training=False):
        """
        Args:
            y: (B, H, W, in_channels)
            training: accepted for Keras API compatibility; unused here.

        Returns:
            (B, H, W, in_channels): ``out_conv(decoder(...)) + y``.
        """
        features = self.in_conv(y)

        skips = []
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4):
            features, skip = stage(features)
            skips.append(skip)

        features = self.bottleneck(features)

        decoders = (self.dec4, self.dec3, self.dec2, self.dec1)
        for stage, skip in zip(decoders, reversed(skips)):
            features = stage(features, skip)

        # Global residual (long skip): the network predicts a correction to y.
        return self.out_conv(features) + y


In [11]:
import tensorflow as tf

class Generator(tf.keras.Model):
    """
    RSCA-GAN Generator: two SCUNets applied one after the other.

    Each SCUNet adds its own global residual, so the second stage refines the
    first stage's reconstruction.

    Args:
        in_channels: channels of the input/output image (2 in the docstring
            example below).
        base_channels: width of each SCUNet's first stage.
        name: Keras model name.
    """

    def __init__(self, in_channels=2, base_channels=64, name="Generator_G"):
        super().__init__(name=name)

        self.in_channels = in_channels

        shared_kwargs = {"in_channels": in_channels, "base_channels": base_channels}
        self.scunet1 = SCUNet(name="SCUNet_1", **shared_kwargs)
        self.scunet2 = SCUNet(name="SCUNet_2", **shared_kwargs)

    def call(self, x, training=False):
        """
        Args:
            x: zero-filled MRI image (B, H, W, 2).
            training: forwarded to both SCUNet stages.

        Returns:
            Reconstruction with the same shape as ``x``.
        """
        for stage in (self.scunet1, self.scunet2):
            x = stage(x, training=training)
        return x


In [12]:
import tensorflow as tf
from tensorflow.keras import layers, initializers

class Discriminator(tf.keras.Model):
    """
    RSCA-GAN Discriminator D.

    6-layer CNN (PatchGAN style):
      * six 3x3 convolutions, with LeakyReLU(0.2) after the first five;
      * conv1 keeps the resolution; conv2–conv5 use stride 2 with "same"
        padding (16x total downsampling);
      * conv6 maps to one channel with no activation.

    The output is therefore a (B, ceil(H/16), ceil(W/16), 1) map of raw
    per-patch scores.

    Args:
        in_channels: expected input channel count. It is stored for reference
            and not used by the forward pass.
        base_channels: width of conv1; conv2–conv5 use 2x, 4x, 8x and 8x that.
        name: Keras model name.
    """

    def __init__(self, in_channels=2, base_channels=64, name="Discriminator_D"):
        super().__init__(name=name)

        self.in_channels = in_channels

        def conv(filters, strides):
            # Build a FRESH HeNormal instance per layer. Keras may return
            # identical values when one unseeded initializer instance is
            # reused, so do not hoist a single shared initializer out of here.
            return layers.Conv2D(
                filters, 3, strides=strides, padding="same",
                kernel_initializer=initializers.HeNormal()
            )

        # Creation order is kept (conv1, lrelu, conv2..conv6) so auto-generated
        # sub-layer names and weight order are unchanged.
        self.conv1 = conv(base_channels, strides=1)
        self.lrelu = layers.LeakyReLU(alpha=0.2)

        self.conv2 = conv(base_channels * 2, strides=2)
        self.conv3 = conv(base_channels * 4, strides=2)
        self.conv4 = conv(base_channels * 8, strides=2)
        self.conv5 = conv(base_channels * 8, strides=2)

        # Single-channel patch score map; no activation.
        self.conv6 = conv(1, strides=1)

    def call(self, x, training=False):
        """
        Args:
            x: image batch (B, H, W, C).
            training: accepted for Keras API compatibility; no layer here
                behaves differently during training.

        Returns:
            (B, ceil(H/16), ceil(W/16), 1) un-activated patch scores.
        """
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4, self.conv5):
            x = self.lrelu(layer(x))
        return self.conv6(x)


In [13]:
import tensorflow as tf

class RSCAGAN(tf.keras.Model):
    """
    RSCA-GAN full model wrapper: forward graph only; the KU block is skipped.

    The zero-filled input must be computed beforehand. The generator
    reconstructs it, and the discriminator scores both the ground truth and
    the reconstruction. No losses or training step are defined here.

    Args:
        in_channels: image channel count shared by G and D.
        base_channels: base width shared by G and D.
        name: Keras model name.
    """

    def __init__(self, in_channels=2, base_channels=64, name="RSCA_GAN"):
        super().__init__(name=name)

        widths = {"in_channels": in_channels, "base_channels": base_channels}
        self.generator = Generator(**widths)
        self.discriminator = Discriminator(**widths)

    def call(self, inputs, training=False):
        """
        Args:
            inputs: tuple ``(SZF, SGT)``.
                SZF: undersampled / zero-filled image (B, H, W, 2).
                SGT: ground-truth image (B, H, W, 2).
            training: forwarded to the generator and discriminator.

        Returns:
            dict with keys "SZF", "SGT" (the inputs), "SRE" (reconstruction),
            "D_real" (D on SGT) and "D_fake" (D on SRE).
        """
        zero_filled, ground_truth = inputs

        # Only the zero-filled image goes into the generator.
        reconstruction = self.generator(zero_filled, training=training)

        # The discriminator is applied separately to real and fake images.
        score_real = self.discriminator(ground_truth, training=training)
        score_fake = self.discriminator(reconstruction, training=training)

        return {
            "SZF": zero_filled,
            "SGT": ground_truth,
            "SRE": reconstruction,
            "D_real": score_real,
            "D_fake": score_fake,
        }
