In [1]:
import tensorflow as tf
from tensorflow.keras import layers


In [2]:
class AdaptiveGlobalFilter(tf.keras.layers.Layer):
    """Learnable complex filter applied to the high-frequency part of a spectrum.

    The input is moved to channels-first, transformed with ``fft2c_tf`` (defined
    in ./Modules/fft.ipynb), and split by a fixed 0/1 mask into a centred
    low-frequency square and its complement. The low band passes through
    unchanged; the high band is multiplied element-wise by a learnable complex
    weight of shape [C, H, W]. The recombined spectrum is transformed back with
    ``ifft2c_tf`` and returned in the original channels-last layout.

    NOTE(review): the centred mask presumes ``fft2c_tf`` returns an
    fftshift-ed (DC-centred) spectrum — confirm in fft.ipynb.

    Args:
        ratio: half-width, in frequency bins, of the low-frequency square
            ``[H//2 - ratio, H//2 + ratio) x [W//2 - ratio, W//2 + ratio)``.
            The square is clipped to the H x W grid.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).

    Input / output:
        Tensor of shape [B, H, W, C, 2]; the last axis holds (real, imag).
        H, W and C must be statically known when the layer is built.
    """

    def __init__(self, ratio=10, **kwargs):
        super().__init__(**kwargs)
        self.ratio = ratio

    def build(self, input_shape):
        # input_shape: [B, H, W, C, 2]
        _, H, W, C, _ = input_shape
        self.H, self.W, self.C = H, W, C

        # Real/imag parts of the complex filter stacked on the last axis. No
        # initializer is passed, so Keras' default for float weights is used.
        self.filter = self.add_weight(
            shape=(C, H, W, 2),
            trainable=True,
            name='complex_filter'
        )

        self.mask_low = self._create_static_mask()
        self.mask_high = 1.0 - self.mask_low

    def _create_static_mask(self):
        """Return a non-trainable float32 [H, W] variable that is 1 inside the
        centred low-frequency square and 0 elsewhere."""
        crow, ccol = self.H // 2, self.W // 2
        r = self.ratio

        # Outer product of 1-D indicators: equivalent to scattering ones into the
        # square, but windows that extend past the grid (ratio > H//2 or W//2)
        # are clipped instead of producing out-of-range scatter indices.
        rows = tf.range(self.H)
        cols = tf.range(self.W)
        row_in = tf.cast((rows >= crow - r) & (rows < crow + r), tf.float32)
        col_in = tf.cast((cols >= ccol - r) & (cols < ccol + r), tf.float32)
        mask = row_in[:, None] * col_in[None, :]

        return tf.Variable(mask, trainable=False)

    def call(self, x):
        tf.debugging.check_numerics(x, "Input to AGF has NaNs/Infs")
        # x: [B, H, W, C, 2] -> channels-first so H, W sit right before (real, imag).
        x = tf.transpose(x, [0, 3, 1, 2, 4])  # [B, C, H, W, 2]
        x_freq = fft2c_tf(x)                  # [B, C, H, W, 2]
        tf.debugging.check_numerics(x_freq, "FFT2 output has NaNs/Infs")

        # Learnable filter as a complex tensor, broadcast over the batch axis.
        weight = tf.complex(self.filter[..., 0], self.filter[..., 1])  # [C, H, W]
        weight = tf.expand_dims(weight, axis=0)                        # [1, C, H, W]

        mask_low_c = tf.cast(tf.reshape(self.mask_low, [1, 1, self.H, self.W]), tf.complex64)
        mask_high_c = tf.cast(tf.reshape(self.mask_high, [1, 1, self.H, self.W]), tf.complex64)

        x_freq_c = tf.complex(x_freq[..., 0], x_freq[..., 1])  # [B, C, H, W]

        # Low band passes through unchanged; only the high band is reweighted.
        x_low = x_freq_c * mask_low_c
        x_high_filtered = (x_freq_c * mask_high_c) * weight
        x_combined = x_low + x_high_filtered

        x_combined = tf.stack(
            [tf.math.real(x_combined), tf.math.imag(x_combined)], axis=-1
        )  # [B, C, H, W, 2]

        # Invert in the SAME [B, C, H, W, 2] layout used for the forward transform,
        # and only then restore channels-last. Inverting after the transpose would
        # run the inverse over (W, C) instead of (H, W).
        # NOTE(review): assumes fft2c_tf / ifft2c_tf share one axis convention.
        x_out = ifft2c_tf(x_combined)  # [B, C, H, W, 2]
        tf.debugging.check_numerics(x_out, "IFFT2 output has NaNs/Infs")
        return tf.transpose(x_out, [0, 2, 3, 1, 4])  # [B, H, W, C, 2]

# # # Example dummy input: [batch, height, width, channels, 2 (real/imag)]
# dummy_input = tf.random.normal((2, 512, 512, 64,2))

# # Instantiate layer
# agf = AdaptiveGlobalFilter(ratio=10)

# # Forward pass
# output = agf(dummy_input)

# print("Input shape: ", dummy_input.shape)
# print("Output shape:", output.shape)

In [3]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D,\
                                    Multiply, ZeroPadding2D, Cropping2D


In [4]:
# %run ./layer.ipynb
# %run ./activation.ipynb

# %run ./fft.ipynb

%run ./Modules/layer.ipynb
%run ./Modules/activation.ipynb

%run ./Modules/fft.ipynb

In [5]:
import tensorflow as tf
from tensorflow.keras import Model

class ComplexEncoderTF(Model):
    """Five-stage complex-valued encoder.

    Each stage is two 3x3 'same' ``complex_Conv2D`` layers, each followed by
    ``CReLU``. Stages 1-4 end with ``ComplexMaxPool2D_mag`` (2x2); stage 5 has no
    pooling. Stage widths: 64, 128, 256, 512, 512.

    Args:
        input_shape: (H, W, 2) shape used to build the model at construction
            time; ``call`` reads the last axis as (real, imag).

    ``call`` returns ``(feat1, feat2, feat3, feat4, feat5)``, each a
    ``(real, imag)`` pair taken after the stage's second activation and before
    its pooling.
    """

    def __init__(self, input_shape=(320, 320, 2)):
        super().__init__()

        widths = (64, 128, 256, 512, 512)
        n_stages = len(widths)
        # Creation order is blockK_conv1, blockK_conv2, poolK (K = 1..4), then the
        # two block5 convs, so layer names and weight order stay unchanged.
        for stage, filters in enumerate(widths, start=1):
            setattr(self, f'block{stage}_conv1',
                    complex_Conv2D(filters, kernel_size=3, padding='same'))
            setattr(self, f'block{stage}_conv2',
                    complex_Conv2D(filters, kernel_size=3, padding='same'))
            if stage < n_stages:
                setattr(self, f'pool{stage}', ComplexMaxPool2D_mag(pool_size=(2, 2)))

        self.build((None, *input_shape))

    def call(self, inputs):
        # Split the trailing (real, imag) axis into two single-channel tensors.
        r = tf.expand_dims(inputs[..., 0], axis=-1)
        i = tf.expand_dims(inputs[..., 1], axis=-1)

        stages = (
            (self.block1_conv1, self.block1_conv2, self.pool1),
            (self.block2_conv1, self.block2_conv2, self.pool2),
            (self.block3_conv1, self.block3_conv2, self.pool3),
            (self.block4_conv1, self.block4_conv2, self.pool4),
            (self.block5_conv1, self.block5_conv2, None),  # bottleneck: no pooling
        )

        features = []
        for first_conv, second_conv, pool in stages:
            r, i = CReLU(*first_conv(r, i))
            r, i = CReLU(*second_conv(r, i))
            features.append((r, i))  # skip feature is taken before pooling
            if pool is not None:
                r, i = pool(r, i)

        return tuple(features)


In [6]:
class FSA(tf.keras.layers.Layer):
    """Frequency branch wrapper around ``AdaptiveGlobalFilter``.

    A spatial-attention branch (``SpatialAttention``) was planned but is
    disabled, so this block currently only applies the adaptive global filter
    and produces no attention map.

    Args:
        ratio: low-frequency window half-width forwarded to ``AdaptiveGlobalFilter``.
        **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, ratio=10, **kwargs):
        super().__init__(**kwargs)
        self.agf = AdaptiveGlobalFilter(ratio=ratio)

    def call(self, x, return_attn=False):
        """Apply the frequency-domain filter.

        Args:
            x: tensor of shape [B, H, W, C, 2].
            return_attn: when True, return ``(out, attn_map)``. No spatial branch
                is active, so ``attn_map`` is ``None``. The flag used to be
                ignored, which made callers that unpack a pair fail.

        Returns:
            Tensor of shape [B, H, W, C, 2], or ``(tensor, None)`` when
            ``return_attn`` is True.
        """
        out = self.agf(x)  # frequency attention
        if return_attn:
            return out, None
        return out



In [7]:
class SkipConnection(tf.keras.layers.Layer):
    """Frequency-filtered skip connection for a complex ``(real, imag)`` feature pair.

    Args:
        ratio: low-frequency window half-width passed to ``FSA``.
        input_channel1: accepted for API compatibility; not used.
        size: accepted for API compatibility; not used.
    """

    def __init__(self, ratio, input_channel1, size):
        super().__init__()
        self.fsa = FSA(ratio=ratio)

    def call(self, input1, return_attn=False):
        """Filter ``input1 = (real, imag)`` and return the filtered ``(real, imag)``.

        If ``return_attn`` is True, the attention map is requested from ``FSA``
        but discarded; only the filtered pair is returned.
        """
        real, imag = input1
        # Pack into a trailing (real, imag) axis: [..., C] x 2 -> [..., C, 2].
        packed = tf.stack([real, imag], axis=-1)

        if return_attn:
            filtered, _ = self.fsa(packed, return_attn=True)
        else:
            filtered = self.fsa(packed)

        return filtered[..., 0], filtered[..., 1]
# batch_size = 1
# height, width = 64, 64
# input_channel1 = 8
# input_channel2 = 16
# size = (64, 64)

# real1 = tf.random.normal((batch_size, height, width, input_channel1))
# imag1 = tf.random.normal((batch_size, height, width, input_channel1))
# real2 = tf.random.normal((batch_size, 32, 32, input_channel2))
# imag2 = tf.random.normal((batch_size, 32, 32, input_channel2))

# # Create and run layer
# skip = SkipConnection(ratio=4, input_channel1=input_channel1,  size=size)
# out_real, out_imag = skip((real1, imag1))  # call() takes a single (real, imag) pair

# print("✅ Output Real Shape:", out_real.shape)
# print("✅ Output Imag Shape:", out_imag.shape)


In [8]:
class ComplexUnetUpTF(tf.keras.layers.Layer):
    """One complex U-Net decoder stage.

    The coarser input is upsampled 2x (bilinear) and concatenated with the skip
    input. The result then goes through two 3x3 'same' ``complex_Conv2D``
    layers, each followed by ``CReLU``.

    Args:
        out_channels: number of filters of both convolutions.
    """

    def __init__(self, out_channels):
        super().__init__()
        self.upsample = ComplexUpSampling2D(size=(2, 2), interpolation='bilinear')
        self.out_channels = out_channels
        conv_kwargs = dict(kernel_size=3, padding='same')
        self.conv1 = complex_Conv2D(out_channels, **conv_kwargs)
        self.conv2 = complex_Conv2D(out_channels, **conv_kwargs)

    def call(self, real_input1, imag_input1, real_input2, imag_input2):
        """Fuse skip features ``(real_input1, imag_input1)`` with the upsampled
        decoder features ``(real_input2, imag_input2)``.

        Returns:
            ``(real, imag)`` after both conv + CReLU layers.
        """
        # Skip features first, upsampled decoder features second.
        r, i = concatenate_with(real_input1, imag_input1,
                                *self.upsample(real_input2, imag_input2))

        for conv in (self.conv1, self.conv2):
            r, i = CReLU(*conv(r, i))

        return r, i


In [9]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import concatenate  # Ensure this is imported

class SF_UNet_TF(Model):
    """Complex-valued U-Net whose skip connections are filtered in the frequency domain.

    Structure:
        - Encoder: ``ComplexEncoderTF`` with stage widths 64/128/256/512/512.
        - Skips: one ``SkipConnection`` per encoder stage 1-4.
        - Decoder: four ``ComplexUnetUpTF`` stages, then a 1x1 ``complex_Conv2D``.

    Args:
        input_shape: (H, W, 2) input shape, forwarded to the encoder.
        num_classes: number of filters of the final 1x1 ``complex_Conv2D``.
        ratio: half-width of the low-frequency window used by every skip
            connection's frequency filter. Defaults to 10, the value that was
            previously hard-coded.
    """

    def __init__(self, input_shape=(320, 320, 2), num_classes=2, ratio=10):
        super().__init__()
        self.encoder = ComplexEncoderTF(input_shape=input_shape)

        height = input_shape[0]
        # SkipConnection does not use `input_channel1` / `size`; they are labels
        # only. They now match the encoder features feeding each skip:
        # feat4 512 ch @ H/8, feat3 256 @ H/4, feat2 128 @ H/2, feat1 64 @ H.
        self.skip4 = SkipConnection(ratio=ratio, input_channel1=512, size=height // 8)
        self.skip3 = SkipConnection(ratio=ratio, input_channel1=256, size=height // 4)
        self.skip2 = SkipConnection(ratio=ratio, input_channel1=128, size=height // 2)
        self.skip1 = SkipConnection(ratio=ratio, input_channel1=64, size=height)

        self.unet_up4 = ComplexUnetUpTF(512)
        self.unet_up3 = ComplexUnetUpTF(256)
        self.unet_up2 = ComplexUnetUpTF(128)
        self.unet_up1 = ComplexUnetUpTF(64)

        self.final_conv = complex_Conv2D(filters=num_classes, kernel_size=1, activation='linear')

    def call(self, inputs):
        """Run the network on ``inputs`` of shape (B, H, W, 2).

        Returns:
            The final real and imaginary outputs concatenated on the last axis.
        """
        feat1, feat2, feat3, feat4, feat5 = self.encoder(inputs)

        # Decoder, coarsest to finest. Each stage frequency-filters the matching
        # encoder feature and fuses it with the upsampled previous decoder output.
        dec_r, dec_i = feat5
        stages = (
            (self.skip4, self.unet_up4, feat4),
            (self.skip3, self.unet_up3, feat3),
            (self.skip2, self.unet_up2, feat2),
            (self.skip1, self.unet_up1, feat1),
        )
        for skip, up, feat in stages:
            skip_r, skip_i = skip(feat)
            dec_r, dec_i = up(skip_r, skip_i, dec_r, dec_i)

        # Final 1x1 convolution.
        real_out, imag_out = self.final_conv(dec_r, dec_i)

        # Concatenate along last dimension: (B, H, W, C) -> (B, H, W, 2C)
        return concatenate([real_out, imag_out], axis=-1)


In [10]:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model


def build_dual_output_model(kshape=(3, 3), H=320, W=320, channels=2, output_name=None):
    """Wrap ``SF_UNet_TF`` in a functional Keras model.

    Despite its name, the model currently has a single output: the
    image-domain ``SF_UNet_TF`` prediction.

    Args:
        kshape: unused; kept for backward compatibility with existing callers.
        H, W, channels: input shape ``(H, W, channels)``. This shape is also
            forwarded to ``SF_UNet_TF`` so its encoder is built for it; it was
            previously always built for 320x320x2.
        output_name: optional name for the identity output layer. When ``None``,
            Keras auto-names it (``"lambda"`` in the summary below). The auto
            name can change if other Lambda layers already exist in the session,
            so pass an explicit name when a loss dict keys on it.

    Returns:
        A ``Model`` mapping ``(None, H, W, channels)`` inputs to the network output.
    """
    input_shape = (H, W, channels)
    input_img = Input(shape=input_shape)

    # Image-domain SF_UNet_TF only
    unet = SF_UNet_TF(input_shape=input_shape)
    output_img = unet(input_img)

    # Identity layer: its name becomes the model's output name (key for loss dicts).
    output_img = Lambda(lambda x: x, name=output_name)(output_img)

    return Model(inputs=input_img, outputs=output_img)
# Build the 320x320 complex-input model (all arguments at their defaults) and print its summary.
model = build_dual_output_model(kshape=(3, 3), H=320, W=320, channels=2)
model.summary()



Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 320, 320, 2)]     0         
                                                                 
 sf_u_net_tf (SF_UNet_TF)    (None, 320, 320, 2)       22216898  
                                                                 
 lambda (Lambda)             (None, 320, 320, 2)       0         
                                                                 
Total params: 22,216,898
Trainable params: 22,080,898
Non-trainable params: 136,000
_________________________________________________________________
