In [1]:
import tensorflow as tf
from tensorflow.keras import layers


In [19]:
class AdaptiveGlobalFilter(tf.keras.layers.Layer):
    """Learnable complex filter applied to the low-frequency band of each channel.

    Expects a complex feature map with a trailing (real, imag) axis,
    ``[B, H, W, C, 2]``, and returns a tensor of the same shape.

    Each channel is moved to the frequency domain with ``fft2c_tf``.
    Coefficients inside a square window of half-width ``ratio``, centred on
    ``(H // 2, W // 2)``, are multiplied by a learnable complex weight. All
    other coefficients pass through unchanged. The result is taken back with
    ``ifft2c_tf``.

    NOTE(review): ``fft2c_tf`` / ``ifft2c_tf`` are defined outside this
    notebook. They are assumed to transform the two axes in front of the
    trailing (real, imag) axis (i.e. ``[..., H, W, 2]``); TODO confirm.

    Args:
        ratio: Half-width, in frequency samples, of the low-frequency window.
            The window is clipped to the ``H x W`` grid when ``2 * ratio``
            exceeds the feature-map size.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, ratio=10, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        # input_shape: [B, H, W, C, 2]. H, W and C size the filter and the
        # mask, so they must be statically known.
        _, H, W, C, _ = input_shape
        self.H, self.W, self.C = H, W, C

        # One complex weight per (channel, frequency bin), stored as a
        # trailing (real, imag) pair; uses add_weight's default initializer.
        self.filter = self.add_weight(
            shape=(C, H, W, 2),
            trainable=True,
            name='complex_filter'
        )

        self.mask_low = self._create_static_mask()
        self.mask_high = 1.0 - self.mask_low

    def _create_static_mask(self):
        """Return a non-trainable float32 ``[H, W]`` Variable.

        The mask is 1 inside the centred low-frequency window
        (rows ``[crow - ratio, crow + ratio)``, cols ``[ccol - ratio, ccol + ratio)``)
        and 0 elsewhere.

        It is built from broadcast comparisons instead of scatter indices.
        Parts of the window that fall outside the grid (small feature maps
        where ``H`` or ``W`` is less than ``2 * ratio``) are therefore
        clipped rather than producing out-of-range indices.
        """
        crow, ccol = self.H // 2, self.W // 2
        r = self.ratio

        rows = tf.range(self.H)[:, tf.newaxis]  # [H, 1]
        cols = tf.range(self.W)[tf.newaxis, :]  # [1, W]
        in_rows = tf.logical_and(rows >= crow - r, rows < crow + r)
        in_cols = tf.logical_and(cols >= ccol - r, cols < ccol + r)
        mask = tf.cast(tf.logical_and(in_rows, in_cols), tf.float32)  # [H, W]

        return tf.Variable(mask, trainable=False)

    def get_config(self):
        """Include ``ratio`` so the layer can be re-created from its config."""
        config = super().get_config()
        config.update({"ratio": self.ratio})
        return config

    def call(self, x):
        tf.debugging.check_numerics(x, "Input to AGF has NaNs/Infs")
        # x: [B, H, W, C, 2] -> [B, C, H, W, 2], so that (H, W) are the two
        # axes in front of the real/imag axis for the 2-D transform.
        x = tf.transpose(x, [0, 3, 1, 2, 4])
        x_freq = fft2c_tf(x)  # [B, C, H, W, 2]
        tf.debugging.check_numerics(x_freq, "FFT2 output has NaNs/Infs")

        # Learnable filter as a complex tensor, broadcast over the batch.
        weight = tf.complex(self.filter[..., 0], self.filter[..., 1])  # [C, H, W]
        weight = tf.expand_dims(weight, axis=0)                         # [1, C, H, W]

        # Masks broadcast over batch and channel.
        mask_low_c = tf.cast(tf.reshape(self.mask_low, [1, 1, self.H, self.W]), tf.complex64)
        mask_high_c = tf.cast(tf.reshape(self.mask_high, [1, 1, self.H, self.W]), tf.complex64)

        x_freq_c = tf.complex(x_freq[..., 0], x_freq[..., 1])  # [B, C, H, W]

        # Only the low band is reweighted; the high band passes through.
        x_low = x_freq_c * mask_low_c
        x_high = x_freq_c * mask_high_c
        x_combined = x_low * weight + x_high  # [B, C, H, W]

        x_combined = tf.stack([tf.math.real(x_combined), tf.math.imag(x_combined)], axis=-1)  # [B, C, H, W, 2]

        # Invert in the SAME [B, C, H, W, 2] layout that was used for the
        # forward transform, so the inverse runs over the same spatial axes.
        # Only then move channels back last.
        x_out = ifft2c_tf(x_combined)  # [B, C, H, W, 2]
        tf.debugging.check_numerics(x_out, "IFFT2 output has NaNs/Infs")
        return tf.transpose(x_out, [0, 2, 3, 1, 4])  # [B, H, W, C, 2]

# # # Example dummy input: [batch, height, width, channels, 2 (real/imag)]
# dummy_input = tf.random.normal((2, 512, 512, 64,2))

# # Instantiate layer
# agf = AdaptiveGlobalFilter(ratio=10)

# # Forward pass
# output = agf(dummy_input)

# print("Input shape: ", dummy_input.shape)
# print("Output shape:", output.shape)

# NOTE: Under non-Cartesian or aggressive undersampling masks, the
# frequency separation performed by AdaptiveGlobalFilter may be suboptimal.

In [31]:
class SpatialAttention(tf.keras.layers.Layer):
    """Spatial attention gate for complex feature maps stored as ``[B, H, W, C, 2]``.

    For each pixel, the complex magnitude is reduced over channels by mean
    and by max. A 7x7 single-filter convolution plus sigmoid turns these two
    maps into a ``[B, H, W, 1, 1]`` attention map, which scales both the real
    and imaginary parts of the input.

    Args:
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, **kwargs):
        super(SpatialAttention, self).__init__(**kwargs)
        self.conv = tf.keras.layers.Conv2D(1, kernel_size=7, padding='same', use_bias=False)
        self.sigmoid = tf.keras.activations.sigmoid
        # Attention map from the most recent call, kept for visualization.
        self.last_attn_map = None

    def call(self, x, return_attn=False):
        """Gate ``x`` with a spatial attention map.

        Args:
            x: ``[B, H, W, C, 2]`` tensor; last axis is (real, imag).
            return_attn: If True, also return the attention map.

        Returns:
            ``x * attn_map`` (same shape as ``x``), or the tuple
            ``(x * attn_map, attn_map)`` when ``return_attn`` is True, with
            ``attn_map`` shaped ``[B, H, W, 1, 1]``.
        """
        real = x[..., 0]
        imag = x[..., 1]
        # Small epsilon keeps sqrt differentiable at zero magnitude.
        mag = tf.sqrt(real**2 + imag**2 + 1e-6)               # [B, H, W, C]
        avg_out = tf.reduce_mean(mag, axis=-1, keepdims=True)  # [B, H, W, 1]
        max_out = tf.reduce_max(mag, axis=-1, keepdims=True)   # [B, H, W, 1]
        concat = tf.concat([avg_out, max_out], axis=-1)        # [B, H, W, 2]
        attn_map = self.sigmoid(self.conv(concat))             # [B, H, W, 1]
        # Extra axis so the map broadcasts over the real/imag axis.
        attn_map = tf.expand_dims(attn_map, axis=-1)           # [B, H, W, 1, 1]

        self.last_attn_map = attn_map  # Save for visualization

        out = x * attn_map
        if return_attn:
            return out, attn_map
        return out


Input  shape: (16, 128, 128, 16, 2)
Output shape: (16, 128, 128, 16, 2)


In [4]:
class FSA(tf.keras.layers.Layer):
    """Frequency-spatial attention: sum of a frequency branch and a spatial branch.

    Both branches see the same ``[B, H, W, C, 2]`` input (last axis = real,
    imag):
      * ``AdaptiveGlobalFilter`` reweights the low-frequency band;
      * ``SpatialAttention`` gates pixels with a learned spatial map.

    Args:
        ratio: Low-frequency window half-width for the AdaptiveGlobalFilter.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, ratio=10, **kwargs):
        super().__init__(**kwargs)
        self.agf = AdaptiveGlobalFilter(ratio=ratio)
        self.sa = SpatialAttention()

    def call(self, x, return_attn=False):
        """Return ``agf(x) + sa(x)``.

        When ``return_attn`` is True, return ``(output, attn_map)`` instead.
        ``attn_map`` is the map SpatialAttention stored in ``last_attn_map``
        during this call.
        """
        out1 = self.agf(x)
        out2 = self.sa(x)
        out = out1 + out2
        if return_attn:
            # Read the map recorded by SpatialAttention rather than relying on
            # it returning a tuple.
            return out, self.sa.last_attn_map
        return out



In [5]:
class SkipConnection(tf.keras.layers.Layer):
    """Run an FSA block on a complex skip feature held as separate real/imag tensors.

    Args:
        ratio: Low-frequency window half-width forwarded to ``FSA``.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, ratio, **kwargs):
        super().__init__(**kwargs)
        self.fsa = FSA(ratio=ratio)

    def call(self, real, imag):
        """
        Args:
            real, imag: Tensors of shape [B, H, W, C].

        Returns:
            ``(real, imag)`` of the FSA output, each [B, H, W, C].
        """
        # Pack into a single complex-as-last-axis tensor: [B, H, W, C, 2].
        x = tf.stack([real, imag], axis=-1)
        tf.debugging.check_numerics(x, "🚨 FSA input has NaNs/Infs")

        y = self.fsa(x)  # [B, H, W, C, 2]

        # Split back into real and imaginary parts.
        return y[..., 0], y[..., 1]

In [6]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D,\
                                    Multiply, ZeroPadding2D, Cropping2D


In [None]:
# Execute helper notebooks so their top-level definitions become available in
# this notebook. Later cells use names not defined here, presumably provided by
# these files: fft2c_tf / ifft2c_tf (fft.ipynb?) and CReLU (activation.ipynb?).
# TODO confirm.
# NOTE(review): absolute, machine-specific Windows path; consider a configurable
# project directory so the notebook runs on other machines.
%run "C:\\Users\\DU\\aman_fastmri\\fft.ipynb"
    
%run "C:\\Users\\DU\\aman_fastmri\\activation.ipynb"
   

In [None]:
# Presumably provides complex_Conv2D, ComplexMaxPool2D_mag, ComplexUpSampling2D
# and concatenate_with, which the model cells below use but do not define.
# TODO confirm.
# NOTE(review): absolute, machine-specific path (see the relative alternatives below).
%run "C:\\Users\\DU\\aman_fastmri\\layer.ipynb"

# Relative-path alternatives (when the kernel's working directory is the project folder):
# %run layer.ipynb
# %run activation.ipynb
# 

In [8]:
import tensorflow as tf
from tensorflow.keras import Model

class ComplexEncoderTF(Model):
    """Five-stage complex-valued convolutional encoder for a U-Net.

    Input: ``[B, H, W, 2]`` tensor whose last axis holds (real, imag).

    Every stage is complex 3x3 conv -> CReLU -> complex 3x3 conv -> CReLU.
    Stages 1-4 are followed by a 2x2 ``ComplexMaxPool2D_mag``; stage 5 is not
    pooled.

    ``call`` returns the pre-pool output of every stage as a
    ``(real, imag)`` tuple: feat1 (64 ch), feat2 (128), feat3 (256),
    feat4 (512), feat5 (512).

    Args:
        input_shape: ``(H, W, 2)`` used to build the model at construction time.
    """

    def __init__(self, input_shape=(320, 320, 2)):
        super().__init__()

        # Stage 1
        self.block1_conv1 = complex_Conv2D(64, kernel_size=3, padding='same')
        self.block1_conv2 = complex_Conv2D(64, kernel_size=3, padding='same')
        self.pool1 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Stage 2
        self.block2_conv1 = complex_Conv2D(128, kernel_size=3, padding='same')
        self.block2_conv2 = complex_Conv2D(128, kernel_size=3, padding='same')
        self.pool2 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Stage 3
        self.block3_conv1 = complex_Conv2D(256, kernel_size=3, padding='same')
        self.block3_conv2 = complex_Conv2D(256, kernel_size=3, padding='same')
        self.pool3 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Stage 4
        self.block4_conv1 = complex_Conv2D(512, kernel_size=3, padding='same')
        self.block4_conv2 = complex_Conv2D(512, kernel_size=3, padding='same')
        self.pool4 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Stage 5 (bottleneck, no pooling)
        self.block5_conv1 = complex_Conv2D(512, kernel_size=3, padding='same')
        self.block5_conv2 = complex_Conv2D(512, kernel_size=3, padding='same')

        # Build eagerly so the weights exist right after construction.
        self.build((None, *input_shape))

    @staticmethod
    def _double_conv(first_conv, second_conv, re, im):
        """Apply ``first_conv -> CReLU -> second_conv -> CReLU``; return ``(re, im)``."""
        re, im = CReLU(*first_conv(re, im))
        re, im = CReLU(*second_conv(re, im))
        return (re, im)

    def call(self, inputs):
        # Split the (real, imag) axis into two single-channel tensors.
        re_in = tf.expand_dims(inputs[..., 0], axis=-1)
        im_in = tf.expand_dims(inputs[..., 1], axis=-1)

        feat1 = self._double_conv(self.block1_conv1, self.block1_conv2, re_in, im_in)
        feat2 = self._double_conv(self.block2_conv1, self.block2_conv2, *self.pool1(*feat1))
        feat3 = self._double_conv(self.block3_conv1, self.block3_conv2, *self.pool2(*feat2))
        feat4 = self._double_conv(self.block4_conv1, self.block4_conv2, *self.pool3(*feat3))
        feat5 = self._double_conv(self.block5_conv1, self.block5_conv2, *self.pool4(*feat4))

        return feat1, feat2, feat3, feat4, feat5


In [13]:
class ComplexUnetUpTF(tf.keras.layers.Layer):
    """Complex U-Net decoder stage.

    The coarser features are upsampled 2x (bilinear), concatenated with the
    skip features, then passed through two complex 3x3 convolutions, each
    followed by CReLU.

    Args:
        out_channels: Number of complex output channels of both convolutions.
    """

    def __init__(self, out_channels):
        super().__init__()
        self.upsample = ComplexUpSampling2D(size=(2, 2), interpolation='bilinear')
        self.out_channels = out_channels
        self.conv1 = complex_Conv2D(out_channels, kernel_size=3, padding='same')
        self.conv2 = complex_Conv2D(out_channels, kernel_size=3, padding='same')

    def call(self, real_input1, imag_input1, real_input2, imag_input2):
        """
        Args:
            real_input1, imag_input1: Skip-connection features at the output resolution.
            real_input2, imag_input2: Coarser decoder features, upsampled 2x here.

        Returns:
            ``(real, imag)`` tuple with ``out_channels`` channels.
        """
        upsampled = self.upsample(real_input2, imag_input2)
        merged = concatenate_with(real_input1, imag_input1, *upsampled)

        for conv in (self.conv1, self.conv2):
            merged = CReLU(*conv(*merged))

        out_real, out_imag = merged
        return out_real, out_imag


In [18]:
import tensorflow as tf
from tensorflow.keras import Model

class SF_UNet_TF(Model):
    """Complex-valued U-Net whose skip connections pass through FSA blocks.

    Input: ``[B, H, W, 2]`` with (real, imag) on the last axis.

    The encoder returns five ``(real, imag)`` feature pairs. feat1-feat4 are
    each refined by a ``SkipConnection`` (FSA) and fused with the upsampled
    deeper features by a ``ComplexUnetUpTF`` stage. A final 1x1 complex conv
    produces the output.

    Args:
        input_shape: ``(H, W, 2)`` used to build the encoder.
        num_classes: Number of complex output channels of the final 1x1 conv.
            The returned tensor concatenates the real and imaginary parts on
            the last axis, so it has ``2 * num_classes`` channels.
            NOTE(review): with the default of 2 the output has 4 channels,
            which would not match a 2-channel (real, imag) target. Confirm
            whether 1 was intended.
        ratio: Low-frequency window half-width passed to every SkipConnection.
            Defaults to 10, the value previously hard-coded.
    """

    def __init__(self, input_shape=(320, 320, 2), num_classes=2, ratio=10):
        super().__init__()
        self.encoder = ComplexEncoderTF(input_shape=input_shape)

        self.skip4 = SkipConnection(ratio=ratio)
        self.skip3 = SkipConnection(ratio=ratio)
        self.skip2 = SkipConnection(ratio=ratio)
        self.skip1 = SkipConnection(ratio=ratio)

        self.unet_up4 = ComplexUnetUpTF(512)
        self.unet_up3 = ComplexUnetUpTF(256)
        self.unet_up2 = ComplexUnetUpTF(128)
        self.unet_up1 = ComplexUnetUpTF(64)

        self.final_conv = complex_Conv2D(filters=num_classes, kernel_size=1, activation='linear')

    def call(self, inputs):
        feat1, feat2, feat3, feat4, feat5 = self.encoder(inputs)

        # Stage 4: FSA-refined feat4 is the skip for the upsampled bottleneck feat5.
        skip4_out_r, skip4_out_i = self.skip4(feat4[0], feat4[1])
        decoder4 = self.unet_up4(skip4_out_r, skip4_out_i, feat5[0], feat5[1])

        # Stage 3
        skip3_out_r, skip3_out_i = self.skip3(feat3[0], feat3[1])
        decoder3 = self.unet_up3(skip3_out_r, skip3_out_i, decoder4[0], decoder4[1])

        # Stage 2
        skip2_out_r, skip2_out_i = self.skip2(feat2[0], feat2[1])
        decoder2 = self.unet_up2(skip2_out_r, skip2_out_i, decoder3[0], decoder3[1])

        # Stage 1
        skip1_out_r, skip1_out_i = self.skip1(feat1[0], feat1[1])
        decoder1 = self.unet_up1(skip1_out_r, skip1_out_i, decoder2[0], decoder2[1])

        # Final 1x1 conv
        real_out, imag_out = self.final_conv(decoder1[0], decoder1[1])

        # Use tf.concat rather than keras.layers.concatenate: the latter
        # instantiates a new Concatenate layer on every forward pass.
        return tf.concat([real_out, imag_out], axis=-1)


In [26]:
# from tensorflow.keras.layers import Input, Lambda
# from tensorflow.keras import Model

# def build_dual_output_model(nf=32, kshape=(3, 3), H=320, W=320, channels=2, cascades=5):
#     input_img = Input(shape=(H, W, channels), name='input_img')

#     # Step 1: First image-domain U-Net
#     unet1 = SF_UNet_TF()
#     img1_out = unet1(input_img)


#     # Label output
#     final_out_image = tf.keras.layers.Lambda(lambda x: x, name='target_img')(img1_out)

#     # Only final image output (no k-space)
#     model = Model(
#         inputs=input_img,
#         outputs={'target_img': final_out_image}
#     )

#     return model
from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model


def build_dual_output_model(kshape=(3, 3), H=320, W=320, channels=2):
    """Wrap ``SF_UNet_TF`` in a functional Keras model with one named output.

    Args:
        kshape: Unused; kept so existing call sites keep working.
        H, W: Spatial size of the input image.
        channels: Size of the last input axis ((real, imag) -> 2).

    Returns:
        A Keras ``Model`` mapping ``[B, H, W, channels]`` to the SF_UNet_TF
        output. The output layer is named ``'target_img'`` so it can be used
        as a key in loss/metric dictionaries.
    """
    input_img = Input(shape=(H, W, channels))

    # Build the encoder for the actual input shape instead of the
    # SF_UNet_TF default of (320, 320, 2).
    unet = SF_UNet_TF(input_shape=(H, W, channels))
    output_img = unet(input_img)

    # Identity layer whose only job is to give the output a stable name
    # for use in loss dicts.
    output_img = Lambda(lambda x: x, name='target_img')(output_img)

    model = Model(
        inputs=input_img,
        outputs=output_img
    )

    return model
# model = build_dual_output_model(H=320, W=320)
# model.summary()

# # Create a dummy batch: batch=2, H=320, W=320, 2 channels (real+imag)
# dummy = np.random.randn(2, 320, 320, 2).astype(np.float32)

# # Forward pass to check that the model builds and runs end to end.
# _ = model(dummy)



In [27]:
# import tensorflow as tf
# from tensorflow.keras.losses import MeanSquaredError

# # Build the model
# model = build_dual_output_model(H=320, W=320, channels=2)

# # Create dummy input and target tensors
# x = tf.random.normal([1, 320, 320, 2])  # input: real + imag
# y_true = tf.random.normal([1, 320, 320, 2])  # target: real + imag

# # Use Mean Squared Error as loss
# loss_fn = MeanSquaredError()

# # Track gradients
# with tf.GradientTape() as tape:
#     y_pred = model(x)
#     loss = loss_fn(y_true, y_pred)

# # Get trainable variables and compute gradients
# trainable_vars = model.trainable_variables
# grads = tape.gradient(loss, trainable_vars)

# # Check which gradients are None
# none_grads = [var.name for var, g in zip(trainable_vars, grads) if g is None]

# # Display result
# if none_grads:
#     print("⚠️ Gradients are NOT flowing for these variables:")
#     for name in none_grads:
#         print(f" - {name}")
# else:
#     print("✅ All gradients are flowing correctly!")
