In [1]:
import tensorflow as tf
from tensorflow.keras import layers


In [2]:
class AdaptiveGlobalFilter(tf.keras.layers.Layer):
    def __init__(self, ratio=10):
        super().__init__()
        self.ratio = ratio

    def build(self, input_shape):
        # input_shape: [B, H, W, C, 2]
        _, H, W, C, _ = input_shape
        self.H, self.W, self.C = H, W, C

        # initializer = tf.keras.initializers.RandomNormal(stddev=0.02)
        self.filter = self.add_weight(
            shape=(C, H, W, 2),
            # initializer=initializer,
            trainable=True,
            name='complex_filter'
        )

        self.mask_low = self._create_static_mask()
        self.mask_high = 1.0 - self.mask_low

    def _create_static_mask(self):
        mask = tf.zeros((self.H, self.W), dtype=tf.float32)
        crow, ccol = self.H // 2, self.W // 2
        r = self.ratio

        indices = tf.stack(tf.meshgrid(
            tf.range(crow - r, crow + r),
            tf.range(ccol - r, ccol + r),
            indexing='ij'
        ), axis=-1)
        indices = tf.reshape(indices, [-1, 2])
        updates = tf.ones((tf.shape(indices)[0],), dtype=tf.float32)

        return tf.Variable(tf.tensor_scatter_nd_update(mask, indices, updates), trainable=False)

    def call(self, x):
        tf.debugging.check_numerics(x, "Input to AGF has NaNs/Infs")
        # x: [B, H, W, C, 2]
        x = tf.transpose(x, [0, 3, 1, 2, 4])  # [B, C, H, W, 2]
        x_freq = fft2c_tf(x)                 # [B, C, H, W, 2]
        tf.debugging.check_numerics(x_freq, "FFT2 output has NaNs/Infs")


        # Convert learnable filter to complex
        real_part = self.filter[..., 0]
        imag_part = self.filter[..., 1]
        weight = tf.complex(real_part, imag_part)  # [C, H, W]
        weight = tf.expand_dims(weight, axis=0)    # [1, C, H, W]

        # Cast masks
        mask_low_c = tf.cast(tf.reshape(self.mask_low, [1, 1, self.H, self.W]), tf.complex64)
        mask_high_c = tf.cast(tf.reshape(self.mask_high, [1, 1, self.H, self.W]), tf.complex64)

        # Convert x_freq to complex
        x_freq_c = tf.complex(x_freq[..., 0], x_freq[..., 1])  # [B, C, H, W]

        x_low = x_freq_c * mask_low_c
        x_high = x_freq_c * mask_high_c

        x_low_filtered = x_low * weight
        x_combined = x_low_filtered + x_high  # [B, C, H, W]

        x_combined = tf.stack([tf.math.real(x_combined), tf.math.imag(x_combined)], axis=-1)  # [B, C, H, W, 2]
        x_combined = tf.transpose(x_combined, [0, 2, 3, 1, 4])  # [B, H, W, C, 2]

        # Apply IFFT
        x_out = ifft2c_tf(x_combined)  # [B, H, W, C, 2]
        tf.debugging.check_numerics(x_out, "IFFT2 output has NaNs/Infs")
        return x_out

# # # Example dummy input: [batch, height, width, channels]
# dummy_input = tf.random.normal((2, 512, 512, 64,2))

# # Instantiate layer
# agf = AdaptiveGlobalFilter(ratio=10)

# # Forward pass
# output = agf(dummy_input)

# print("Input shape: ", dummy_input.shape)
# print("Output shape:", output.shape)

In [3]:

# class SpatialAttention(tf.keras.layers.Layer):
#     def __init__(self):
#         super(SpatialAttention, self).__init__()
#         self.conv = tf.keras.layers.Conv2D(
#             filters=1,
#             kernel_size=7,
#             padding='same',
#             use_bias=False
#         )
#         self.sigmoid = tf.keras.activations.sigmoid

#     def call(self, x):
#         # x: [B, H, W, C, 2]
#         real = x[..., 0]
#         imag = x[..., 1]
#         mag = tf.sqrt(real**2 + imag**2 + 1e-6)    # add eps for stability
#         avg_out = tf.reduce_mean(mag, axis=-1, keepdims=True)  # [B,H,W,1]
#         max_out = tf.reduce_max(mag, axis=-1, keepdims=True)   # [B,H,W,1]
#         concat = tf.concat([avg_out, max_out], axis=-1)        # [B,H,W,2]
#         attn_map = self.sigmoid(self.conv(concat))             # [B,H,W,1]
#         attn_map = tf.expand_dims(attn_map, axis=-1)           # [B,H,W,1,1]
#         return x * attn_map                                    # broadcast to [B,H,W,C,2]


class SpatialAttention(tf.keras.layers.Layer):
    def __init__(self):
        super(SpatialAttention, self).__init__()
        self.conv = tf.keras.layers.Conv2D(
            filters=1,
            kernel_size=7,
            padding='same',
            use_bias=False
        )
        self.sigmoid = tf.keras.activations.sigmoid
        self.last_attn_map = None  # Optional: for saving the attention map

    def call(self, x, return_attn=False):
        # x: [B, H, W, C, 2]
        real = x[..., 0]
        imag = x[..., 1]
        mag = tf.sqrt(real**2 + imag**2 + 1e-6)    # Magnitude for attention

        avg_out = tf.reduce_mean(mag, axis=-1, keepdims=True)  # [B, H, W, 1]
        max_out = tf.reduce_max(mag, axis=-1, keepdims=True)   # [B, H, W, 1]
        concat = tf.concat([avg_out, max_out], axis=-1)        # [B, H, W, 2]

        attn_map = self.sigmoid(self.conv(concat))             # [B, H, W, 1]
        attn_map = tf.expand_dims(attn_map, axis=-1)           # [B, H, W, 1, 1]

        self.last_attn_map = attn_map  # Optionally store for later

        output = x * attn_map  # Broadcasting: [B, H, W, C, 2] * [B, H, W, 1, 1]

        if return_attn:
            return output, attn_map
        else:
            return output


In [4]:
import tensorflow as tf
from tensorflow.keras.layers import Layer, Conv1D, Conv2DTranspose
import math

class MPCA(Layer):
    def __init__(self, input_channel1=128, input_channel2=256, gamma=2, bias=1):
        super(MPCA, self).__init__()
        self.input_channel1 = input_channel1
        self.input_channel2 = input_channel2

        # Adaptive average pooling equivalent: global avg pool spatial dims
        # We'll do manually with reduce_mean over height & width

        kernel_size1 = int(abs((math.log(input_channel1, 2) + bias) / gamma))
        kernel_size1 = kernel_size1 if kernel_size1 % 2 else kernel_size1 + 1

        kernel_size2 = int(abs((math.log(input_channel2, 2) + bias) / gamma))
        kernel_size2 = kernel_size2 if kernel_size2 % 2 else kernel_size2 + 1

        kernel_size3 = int(abs((math.log(input_channel1 + input_channel2, 2) + bias) / gamma))
        kernel_size3 = kernel_size3 if kernel_size3 % 2 else kernel_size3 + 1

        # Conv1D layers - note input shape: (batch, steps, channels)
        self.conv1 = Conv1D(filters=1, kernel_size=kernel_size1, padding='same', use_bias=False)
        self.conv2 = Conv1D(filters=1, kernel_size=kernel_size2, padding='same', use_bias=False)
        self.conv3 = Conv1D(filters=1, kernel_size=kernel_size3, padding='same', use_bias=False)

        self.sigmoid = tf.keras.activations.sigmoid

        self.up = Conv2DTranspose(filters=input_channel1,
                                  kernel_size=3,
                                  strides=2,
                                  padding='same',
                                  output_padding=1)
    def call(self, x1, x2):
        # x1, x2: [B, H, W, C]
    
        # Global avg pool spatial dims
        a1 = tf.reduce_mean(x1, axis=[1, 2], keepdims=True)  # [B,1,1,C1]
        a2 = tf.reduce_mean(x2, axis=[1, 2], keepdims=True)  # [B,1,1,C2]
        # print("a1 image shape:", a1.shape)
        # print("a2 image shape:", a2.shape)


        # Squeeze height and width dims
        a1 = tf.squeeze(a1, axis=[1, 2])  # [B, C1]
        a2 = tf.squeeze(a2, axis=[1, 2])  # [B, C2]
        # print("a1 after squueze image shape:", a1.shape)
        # print("a2 after squueze image shape:", a2.shape)


    
        # Add a channel dim at the end for Conv1D: [B, length, channels=1]
        a1 = tf.expand_dims(a1, axis=-1)  # [B, C1, 1]
        a2 = tf.expand_dims(a2, axis=-1)  # [B, C2, 1]

        # print("a1 after expand image shape:", a1.shape)
        # print("a2 after expand image shape:", a2.shape)

        # Conv1D along length (channel dim from original tensor)
        a1 = self.conv1(a1)  # [B, C1, 1]
        a2 = self.conv2(a2)  # [B, C2, 1]
        # print("a1 after conv1d image shape:", a1.shape)
        # print("a2 after conv1d image shape:", a2.shape)
    
        # Concatenate along length dim (axis=1)
        mid = tf.concat([a1, a2], axis=1)  # [B, C1 + C2, 1]
        # print("mid after concat a1 and a2 image shape:", mid.shape)
    
    
        # Conv3 along length axis with channels=1
        mid = self.conv3(mid)  # [B, C1 + C2, 1]
        # print("mid after conv a1 and a2 image shape:", mid.shape)
    
    
        # Apply sigmoid activation
        mid = self.sigmoid(mid)  # [B, C1 + C2, 1]
        
    
        # Remove channel dim, shape: [B, C1 + C2]
        mid = tf.squeeze(mid, axis=-1)
        # print("mid after squeeze a1 and a2 image shape:", mid.shape)
    
    
        # Split back into attn1 and attn2
        attn1, attn2 = tf.split(mid, [self.input_channel1, self.input_channel2], axis=1)  # both [B, C]
        # print("attn1 image shape:", attn1.shape)
        # print("attn2 image shape:", attn2.shape)
    
        # Reshape for broadcasting to [B,1,1,C]
        attn1 = tf.reshape(attn1, [-1, 1, 1, self.input_channel1])
        attn2 = tf.reshape(attn2, [-1, 1, 1, self.input_channel2])
        # print("attn1 after resshape image shape:", attn1.shape)
        # print("attn2 after reshape image shape:", attn2.shape)
    
        
        # Apply attention weights
        x1_out = x1 * attn1
        x2_out = x2 * attn2
        # print("x1_out image shape:", x1_out.shape)
        # print("x2_out image shape:", x2_out.shape)
    
        
    
        # Upsample x2_out to match x1_out spatial dims
        x2_out = self.up(x2_out)
        # print("x2_out image upsanple shape:", x2_out.shape)
    
        
    
        # Sum outputs
        result = x1_out + x2_out
        # print("result MPCA shape:", result.shape)
    
        return result




In [5]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D,\
                                    Multiply, ZeroPadding2D, Cropping2D


In [1]:
%run "C:\\Users\\DU\\aman_fastmri\\layer.ipynb"
    
%run "C:\\Users\\DU\\aman_fastmri\\activation.ipynb"
    
%run "C:\\Users\\DU\\aman_fastmri\\fft.ipynb"
#%run ./Modules/fft.ipynb

# %run layer.ipynb
# %run activation.ipynb
# 

In [7]:
import tensorflow as tf
from tensorflow.keras import Model

class ComplexEncoderTF(Model):
    def __init__(self, input_shape=(320, 320, 2)):
        super(ComplexEncoderTF, self).__init__()

        # Block 1
        self.block1_conv1 = complex_Conv2D(64, kernel_size=3, padding='same')
        self.block1_conv2 = complex_Conv2D(64, kernel_size=3, padding='same')
        self.pool1 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Block 2
        self.block2_conv1 = complex_Conv2D(128, kernel_size=3, padding='same')
        self.block2_conv2 = complex_Conv2D(128, kernel_size=3, padding='same')
        self.pool2 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Block 3
        self.block3_conv1 = complex_Conv2D(256, kernel_size=3, padding='same')
        self.block3_conv2 = complex_Conv2D(256, kernel_size=3, padding='same')
        self.pool3 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Block 4
        self.block4_conv1 = complex_Conv2D(512, kernel_size=3, padding='same')
        self.block4_conv2 = complex_Conv2D(512, kernel_size=3, padding='same')
        self.pool4 = ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Block 5 (no pooling)
        self.block5_conv1 = complex_Conv2D(512, kernel_size=3, padding='same')
        self.block5_conv2 = complex_Conv2D(512, kernel_size=3, padding='same')

        self.build((None, *input_shape))

    def call(self, inputs):
        real = tf.expand_dims(inputs[..., 0], axis=-1)
        imag = tf.expand_dims(inputs[..., 1], axis=-1)

        # Block 1
        r, i = self.block1_conv1(real, imag)
        r, i = CReLU(r, i)
        r, i = self.block1_conv2(r, i)
        r, i = CReLU(r, i)
        feat1 = (r, i)
        r, i = self.pool1(r, i)

        # Block 2
        r, i = self.block2_conv1(r, i)
        r, i = CReLU(r, i)
        r, i = self.block2_conv2(r, i)
        r, i = CReLU(r, i)
        feat2 = (r, i)
        r, i = self.pool2(r, i)

        # Block 3
        r, i = self.block3_conv1(r, i)
        r, i = CReLU(r, i)
        r, i = self.block3_conv2(r, i)
        r, i = CReLU(r, i)
        feat3 = (r, i)
        r, i = self.pool3(r, i)

        # Block 4
        r, i = self.block4_conv1(r, i)
        r, i = CReLU(r, i)
        r, i = self.block4_conv2(r, i)
        r, i = CReLU(r, i)
        feat4 = (r, i)
        r, i = self.pool4(r, i)

        # Block 5 (no pooling)
        r, i = self.block5_conv1(r, i)
        r, i = CReLU(r, i)
        r, i = self.block5_conv2(r, i)
        r, i = CReLU(r, i)
        feat5 = (r, i)

        return feat1, feat2, feat3, feat4, feat5


In [8]:
# class FSA(tf.keras.layers.Layer):
#     def __init__(self, ratio=10):
#         super().__init__()
#         self.agf = AdaptiveGlobalFilter(ratio=ratio)
#         self.sa = SpatialAttention()

#     def call(self, x):
        
#         """
#         x: Tensor of shape [B, H, W, C, 2]
#         Returns:
#             Tensor of shape [B, H, W, C, 2]
#         """
    
#         out1 = self.agf(x)  # Frequency attention
        
#         out2 = self.sa(x)   # Spatial attention
#         out= out1 + out2  
        
#         return out  

class FSA(tf.keras.layers.Layer):
    def __init__(self, ratio=10):
        super().__init__()
        self.agf = AdaptiveGlobalFilter(ratio=ratio)
        self.sa = SpatialAttention()


    def call(self, x, return_attn=False):
        out1 = self.agf(x)
        if return_attn:
            out2, attn_map = self.sa(x, return_attn=True)
            out = out1 + out2
            return out, attn_map
        else:
            out2 = self.sa(x)
            out = out1 + out2
            return out




In [9]:
class SkipConnection(tf.keras.layers.Layer):
    def __init__(self, ratio, input_channel1, input_channel2, size):
        super().__init__()
        self.mpca_real = MPCA(input_channel1, input_channel2)
        self.mpca_imag = MPCA(input_channel1, input_channel2)
        self.fsa = FSA(ratio=ratio)


    def call(self, input1, input2):
        real1, imag1 = input1
        real2, imag2 = input2
        out_real = self.mpca_real(real1, real2)
        out_imag = self.mpca_imag(imag1, imag2)
        # print("out_imag image shape:", out_imag.shape)

        # combined = tf.concat([tf.expand_dims(out_real, axis=-1),tf.expand_dims(out_imag, axis=-1)], axis=-1)  # shape: [B, H, W, 2]
        # # Apply FSA
        # # print("combined :", combined.shape)
        # output = self.fsa(combined)

        # out_real = output[..., 0]
        # out_imag = output[..., 1]

        return out_real, out_imag




# batch_size = 1
# height, width = 64, 64
# input_channel1 = 8
# input_channel2 = 16
# size = (64, 64)

# real1 = tf.random.normal((batch_size, height, width, input_channel1))
# imag1 = tf.random.normal((batch_size, height, width, input_channel1))
# real2 = tf.random.normal((batch_size, 32, 32, input_channel2))
# imag2 = tf.random.normal((batch_size, 32, 32, input_channel2))

# # Create and run layer
# skip = SkipConnection(ratio=4, input_channel1=input_channel1, input_channel2=input_channel2, size=size)
# out_real, out_imag = skip((real1, imag1), (real2, imag2))

# print("✅ Output Real Shape:", out_real.shape)
# print("✅ Output Imag Shape:", out_imag.shape)



In [10]:
class ComplexUnetUpTF(tf.keras.layers.Layer):
    def __init__(self, out_channels):
        super(ComplexUnetUpTF, self).__init__()
        self.upsample = ComplexUpSampling2D(size=(2, 2), interpolation='bilinear')
        self.out_channels = out_channels
        self.conv1 = complex_Conv2D(out_channels, kernel_size=3, padding='same')
        self.conv2 = complex_Conv2D(out_channels, kernel_size=3, padding='same')
        
        # self.dense_block = DensBlock(out_channels)

    def call(self, real_input1, imag_input1, real_input2, imag_input2):
        # Upsample
        up_real, up_imag = self.upsample(real_input2, imag_input2)

        # Concatenate with skip connection
        concat_real, concat_imag = concatenate_with(real_input1, imag_input1, up_real, up_imag)

        # Proper flow through both conv layers
        out_real, out_imag = self.conv1(concat_real, concat_imag)
        out_real, out_imag = self.conv2(out_real, out_imag)
        out_real, out_imag = CReLU(out_real, out_imag)

        return out_real, out_imag


In [17]:
import tensorflow as tf
from tensorflow.keras import Model

class SF_UNet_TF(Model):
    def __init__(self, input_shape=(320, 320, 2), num_classes=2):
        super().__init__()
        self.encoder = ComplexEncoderTF(input_shape=input_shape)

        self.skip4 = SkipConnection(ratio=10, input_channel1=256, input_channel2=256, size=40)
        self.skip3 = SkipConnection(ratio=10, input_channel1=128, input_channel2=256, size=80)
        self.skip2 = SkipConnection(ratio=10, input_channel1=64, input_channel2=128, size=160)
        self.skip1 = SkipConnection(ratio=10, input_channel1=32, input_channel2=64, size=320)

        self.unet_up4 = ComplexUnetUpTF(512)
        self.unet_up3 = ComplexUnetUpTF(256)
        self.unet_up2 = ComplexUnetUpTF(128)
        self.unet_up1 = ComplexUnetUpTF(64)

        self.final_conv = complex_Conv2D(filters=num_classes, kernel_size=1, activation='linear')

    def call(self, inputs):
        # real = tf.expand_dims(inputs[..., 0], axis=-1)
        # imag = tf.expand_dims(inputs[..., 1], axis=-1)

        feat1, feat2, feat3, feat4, feat5 = self.encoder(inputs)
        # print("feat1 real shape:", feat1[0].shape)
        # print("feat2 real shape:", feat2[0].shape)
        # print("feat3 real shape:", feat3[0].shape)
        # print("feat4 real shape:", feat4[0].shape)
        # print("feat5 bottleneck real shape:", feat5[0].shape)

        # print("feat4 real shape:", feat4[0].shape)
        # print("feat4 imag shape:", feat4[1].shape)
        # print("feat5 real shape:", feat5[0].shape)
        # print("feat5 imag shape:", feat5[1].shape)

        # Skip + up stage 4
        skip4_out_r, skip4_out_i = self.skip4(feat4, feat5)
        # tf.print("Skip4 output shape (real):", tf.shape(skip4_out_r))
        # tf.print("Skip4 output shape (imag):", tf.shape(skip4_out_i))

        decoder4 = self.unet_up4(skip4_out_r, skip4_out_i, feat5[0], feat5[1])
        # tf.print("UnetUp4 output shape (real):", tf.shape(decoder4[0]))
        # tf.print("UnetUp4 output shape (imag):", tf.shape(decoder4[1]))

        # Skip + up stage 3
        skip3_out_r, skip3_out_i = self.skip3(feat3, decoder4)
        # tf.print("Skip3 output shape (real):", tf.shape(skip3_out_r))
        # tf.print("Skip3 output shape (imag):", tf.shape(skip3_out_i))

        decoder3 = self.unet_up3(skip3_out_r, skip3_out_i, decoder4[0], decoder4[1])
        # tf.print("UnetUp3 output shape (real):", tf.shape(decoder3[0]))
        # tf.print("UnetUp3 output shape (imag):", tf.shape(decoder3[1]))

        # Skip + up stage 2
        skip2_out_r, skip2_out_i = self.skip2(feat2, decoder3)
        # tf.print("Skip2 output shape (real):", tf.shape(skip2_out_r))
        # tf.print("Skip2 output shape (imag):", tf.shape(skip2_out_i))

        decoder2 = self.unet_up2(skip2_out_r, skip2_out_i, decoder3[0], decoder3[1])
        # tf.print("UnetUp2 output shape (real):", tf.shape(decoder2[0]))
        # tf.print("UnetUp2 output shape (imag):", tf.shape(decoder2[1]))

        # Skip + up stage 1
        skip1_out_r, skip1_out_i = self.skip1(feat1, decoder2)
        # tf.print("Skip1 output shape (real):", tf.shape(skip1_out_r))
        # tf.print("Skip1 output shape (imag):", tf.shape(skip1_out_i))

        decoder1 = self.unet_up1(skip1_out_r, skip1_out_i, decoder2[0], decoder2[1])
        # tf.print("UnetUp1 output shape (real):", tf.shape(decoder1[0]))
        # tf.print("UnetUp1 output shape (imag):", tf.shape(decoder1[1]))

        # Final 1×1 conv
        real_out, imag_out = self.final_conv(decoder1[0], decoder1[1])
        # tf.print("Final conv output shape (real):", tf.shape(real_out))
        # tf.print("Final conv output shape (imag):", tf.shape(imag_out))

        # Concatenate back to a 2-channel or 4-channel tensor
        output = concatenate([real_out, imag_out], axis=-1)
        # tf.print("Final concatenated output shape:", tf.shape(output))

        return output

    # def call(self, inputs):
    #     # Encoder
    #     feat1, feat2, feat3, feat4, feat5 = self.encoder(inputs)
    
    #     # Decoder stage 4
    #     skip4_r, skip4_i = self.skip4(feat4, feat5)
    #     decoder4 = self.unet_up4(skip4_r, skip4_i, feat5[0], feat5[1])
    
    #     # Decoder stage 3
    #     skip3_r, skip3_i = self.skip3(feat3, feat4)
    #     decoder3 = self.unet_up3(skip3_r, skip3_i, decoder4[0], decoder4[1])
    
    #     # Decoder stage 2
    #     skip2_r, skip2_i = self.skip2(feat2, feat3)
    #     decoder2 = self.unet_up2(skip2_r, skip2_i, decoder3[0], decoder3[1])
    
    #     # Decoder stage 1
    #     skip1_r, skip1_i = self.skip1(feat1, feat2)
    #     decoder1 = self.unet_up1(skip1_r, skip1_i, decoder2[0], decoder2[1])
    
    #     # Final 1x1 convolution
    #     real_out, imag_out = self.final_conv(decoder1[0], decoder1[1])
    
    #     # Concatenate along last dimension: (B, H, W, C) -> (B, H, W, 2C)
    #     output = concatenate([real_out, imag_out], axis=-1)
    
    #     return output


In [18]:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model


def build_dual_output_model(kshape=(3, 3), H=320, W=320, channels=2):
    input_img = Input(shape=(H, W, channels))

    # Image-domain SF_UNet_TF only
    unet = SF_UNet_TF()
    output_img = unet(input_img)

    # Assign a name to the output for use in loss dict
    output_img = Lambda(lambda x: x)(output_img)

    model = Model(
        inputs=input_img,
        outputs=output_img
    )

    return model
model = build_dual_output_model(H=320, W=320)
model.summary()

# Create a dummy batch: batch=2, H=64, W=64, 2 channels (real+imag)
dummy = np.random.randn(2, 320, 320, 2).astype(np.float32)

# Forward-pass: this will trigger our tf.print in SkipConnection.call
_ = model(dummy)



feat1 real shape: (None, 320, 320, 32)
feat2 real shape: (None, 160, 160, 64)
feat3 real shape: (None, 80, 80, 128)
feat4 real shape: (None, 40, 40, 256)
feat5 bottleneck real shape: (None, 20, 20, 256)
feat4 real shape: (None, 40, 40, 256)
feat4 imag shape: (None, 40, 40, 256)
feat5 real shape: (None, 20, 20, 256)
feat5 imag shape: (None, 20, 20, 256)
Model: "model_2"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_4 (InputLayer)        [(None, 320, 320, 2)]     0         
                                                                 
 sf_u_net_tf_3 (SF_UNet_TF)  (None, 320, 320, 2)       24172146  
                                                                 
 lambda_2 (Lambda)           (None, 320, 320, 2)       0         
                                                                 
Total params: 24,172,146
Trainable params: 24,036,146
Non-trainable params: 136,000
________________

In [16]:
# import tensorflow as tf
# from tensorflow.keras.losses import MeanSquaredError

# # Build the model
# model = build_dual_output_model(H=320, W=320, channels=2)

# # Create dummy input and target tensors
# x = tf.random.normal([1, 320, 320, 2])  # input: real + imag
# y_true = tf.random.normal([1, 320, 320, 2])  # target: real + imag

# # Use Mean Squared Error as loss
# loss_fn = MeanSquaredError()

# # Track gradients
# with tf.GradientTape() as tape:
#     y_pred = model(x)
#     loss = loss_fn(y_true, y_pred)

# # Get trainable variables and compute gradients
# trainable_vars = model.trainable_variables
# grads = tape.gradient(loss, trainable_vars)

# # Check which gradients are None
# none_grads = [var.name for var, g in zip(trainable_vars, grads) if g is None]

# # Display result
# if none_grads:
#     print("⚠️ Gradients are NOT flowing for these variables:")
#     for name in none_grads:
#         print(f" - {name}")
# else:
#     print("✅ All gradients are flowing correctly!")
