In [1]:
import tensorflow as tf
from tensorflow.keras import layers


In [2]:
from tensorflow.keras.layers import Input, Conv2D, Lambda, Add, LeakyReLU,  \
                                    MaxPooling2D, concatenate, UpSampling2D,\
                                    Multiply, ZeroPadding2D, Cropping2D


In [3]:
# Execute the project's helper notebooks in this kernel's namespace.
# The cells below call complex_Conv2D, ComplexMaxPool2D_mag,
# ComplexUpSampling2D, concatenate_with and CReLU, none of which are defined
# in this notebook -- presumably they are provided by these helpers
# (TODO confirm which notebook defines which name).
#
# Alternative paths, for when the helper notebooks sit next to this one:
# %run ./layer.ipynb
# %run ./activation.ipynb

# %run ./fft.ipynb

# Paths used when the helpers live in the ./Modules sub-directory.
%run ./Modules/layer.ipynb
%run ./Modules/activation.ipynb

%run ./Modules/fft.ipynb

In [4]:
import tensorflow as tf
from tensorflow.keras import Model

class ComplexEncoderTF(Model):
    """Five-stage complex-valued convolutional encoder.

    The input tensor carries the real part in channel 0 and the imaginary
    part in channel 1. Every stage applies two 3x3 ``complex_Conv2D`` layers,
    each followed by ``CReLU``. Stages 1-4 are followed by a 2x2
    ``ComplexMaxPool2D_mag``; stage 5 is not pooled.

    Args:
        input_shape: Per-sample input shape (no batch axis). It is only used
            to build the model at the end of construction.
    """

    def __init__(self, input_shape=(320, 320, 2)):
        super().__init__()

        def conv3x3(filters):
            # Every convolution in the encoder shares these settings.
            return complex_Conv2D(filters, kernel_size=3, padding='same')

        def pool2x2():
            return ComplexMaxPool2D_mag(pool_size=(2, 2))

        # Stage 1
        self.block1_conv1 = conv3x3(64)
        self.block1_conv2 = conv3x3(64)
        self.pool1 = pool2x2()

        # Stage 2
        self.block2_conv1 = conv3x3(128)
        self.block2_conv2 = conv3x3(128)
        self.pool2 = pool2x2()

        # Stage 3
        self.block3_conv1 = conv3x3(256)
        self.block3_conv2 = conv3x3(256)
        self.pool3 = pool2x2()

        # Stage 4
        self.block4_conv1 = conv3x3(512)
        self.block4_conv2 = conv3x3(512)
        self.pool4 = pool2x2()

        # Stage 5 (bottleneck, no pooling afterwards)
        self.block5_conv1 = conv3x3(512)
        self.block5_conv2 = conv3x3(512)

        # Build immediately with an unspecified batch dimension.
        batch_shape = (None,) + tuple(input_shape)
        self.build(batch_shape)

    def call(self, inputs):
        """Encode ``inputs`` and return the per-stage feature pairs.

        Returns:
            ``(feat1, feat2, feat3, feat4, feat5)`` where each ``featN`` is
            the ``(real, imag)`` pair produced by stage N before any pooling.
        """
        # Split the two input channels, keeping a trailing channel axis.
        re_part = tf.expand_dims(inputs[..., 0], axis=-1)
        im_part = tf.expand_dims(inputs[..., 1], axis=-1)

        # (first conv, second conv, pooling or None) for each stage, in order.
        stages = (
            (self.block1_conv1, self.block1_conv2, self.pool1),
            (self.block2_conv1, self.block2_conv2, self.pool2),
            (self.block3_conv1, self.block3_conv2, self.pool3),
            (self.block4_conv1, self.block4_conv2, self.pool4),
            (self.block5_conv1, self.block5_conv2, None),
        )

        stage_outputs = []
        for first_conv, second_conv, pool in stages:
            re_part, im_part = first_conv(re_part, im_part)
            re_part, im_part = CReLU(re_part, im_part)
            re_part, im_part = second_conv(re_part, im_part)
            re_part, im_part = CReLU(re_part, im_part)

            # Record the activations before they are downsampled.
            stage_outputs.append((re_part, im_part))

            if pool is not None:
                re_part, im_part = pool(re_part, im_part)

        feat1, feat2, feat3, feat4, feat5 = stage_outputs
        return feat1, feat2, feat3, feat4, feat5


In [5]:
class ComplexUnetUpTF(tf.keras.layers.Layer):
    """Decoder stage of the complex U-Net.

    Upsamples the deeper feature pair by a factor of 2 (bilinear), merges it
    with the skip-connection pair via ``concatenate_with``, then applies two
    3x3 ``complex_Conv2D`` layers, each followed by ``CReLU``.

    Args:
        out_channels: Filter count passed to both complex convolutions.
        **kwargs: Standard Keras ``Layer`` arguments (e.g. ``name``,
            ``dtype``, ``trainable``), forwarded to the base class.
    """

    def __init__(self, out_channels, **kwargs):
        super().__init__(**kwargs)
        self.upsample = ComplexUpSampling2D(size=(2, 2), interpolation='bilinear')
        self.out_channels = out_channels
        self.conv1 = complex_Conv2D(out_channels, kernel_size=3, padding='same')
        self.conv2 = complex_Conv2D(out_channels, kernel_size=3, padding='same')

    def call(self, real_input1, imag_input1, real_input2, imag_input2):
        """Run one decoder stage.

        Args:
            real_input1: Real part of the skip-connection features.
            imag_input1: Imaginary part of the skip-connection features.
            real_input2: Real part of the deeper features to upsample.
            imag_input2: Imaginary part of the deeper features to upsample.

        Returns:
            ``(out_real, out_imag)`` after the second convolution and CReLU.
        """
        # Upsample the deeper features.
        up_real, up_imag = self.upsample(real_input2, imag_input2)

        # Merge with the skip connection.
        concat_real, concat_imag = concatenate_with(real_input1, imag_input1, up_real, up_imag)

        # conv -> CReLU -> conv -> CReLU
        out_real, out_imag = self.conv1(concat_real, concat_imag)
        out_real, out_imag = CReLU(out_real, out_imag)
        out_real, out_imag = self.conv2(out_real, out_imag)
        out_real, out_imag = CReLU(out_real, out_imag)

        return out_real, out_imag

    def get_config(self):
        """Return the layer config, including ``out_channels``.

        Lets the layer be re-created with ``ComplexUnetUpTF.from_config``.
        """
        config = super().get_config()
        config.update({"out_channels": self.out_channels})
        return config


In [6]:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import concatenate  # Ensure this is imported

class SF_UNet_TF(Model):
    """Complex-valued U-Net built from ``ComplexEncoderTF`` and four
    ``ComplexUnetUpTF`` decoder stages.

    Args:
        input_shape: Per-sample input shape, forwarded to the encoder.
        num_classes: ``filters`` argument of the final 1x1 ``complex_Conv2D``.
    """

    def __init__(self, input_shape=(320, 320, 2), num_classes=2):
        super().__init__()
        self.encoder = ComplexEncoderTF(input_shape=input_shape)

        # Decoder stages, deepest first.
        self.unet_up4 = ComplexUnetUpTF(512)
        self.unet_up3 = ComplexUnetUpTF(256)
        self.unet_up2 = ComplexUnetUpTF(128)
        self.unet_up1 = ComplexUnetUpTF(64)

        # 1x1 projection head with a linear activation.
        self.final_conv = complex_Conv2D(filters=num_classes, kernel_size=1, activation='linear')

    def call(self, inputs):
        """Run encoder, decoder and projection head.

        Returns:
            The real and imaginary outputs of the final convolution,
            concatenated along the last axis.
        """
        skip1, skip2, skip3, skip4, bottleneck = self.encoder(inputs)

        # Walk back up from the bottleneck, merging one skip per stage.
        real, imag = bottleneck[0], bottleneck[1]
        decoder_path = (
            (self.unet_up4, skip4),
            (self.unet_up3, skip3),
            (self.unet_up2, skip2),
            (self.unet_up1, skip1),
        )
        for up_stage, skip in decoder_path:
            decoded = up_stage(skip[0], skip[1], real, imag)
            real, imag = decoded[0], decoded[1]

        # Final 1x1 convolution.
        real_out, imag_out = self.final_conv(real, imag)

        # Stack real and imaginary outputs on the channel axis.
        return concatenate([real_out, imag_out], axis=-1)


In [7]:

from tensorflow.keras.layers import Input
from tensorflow.keras.models import Model


def build_dual_output_model(kshape=(3, 3), H=320, W=320, channels=2):
    """Wrap ``SF_UNet_TF`` in a functional Keras model with a fixed input size.

    Despite the name, the returned model has a single output: the image-domain
    ``SF_UNet_TF`` prediction.

    Args:
        kshape: Unused; kept so existing callers that pass it keep working.
        H: Input height.
        W: Input width.
        channels: Number of input channels. The encoder reads channel 0 as
            the real part and channel 1 as the imaginary part.

    Returns:
        A ``Model`` mapping an ``(H, W, channels)`` input to the U-Net output.
    """
    input_img = Input(shape=(H, W, channels))

    # Build the network for the requested geometry instead of SF_UNet_TF's
    # default (320, 320, 2), so the declared input and the sub-model agree.
    unet = SF_UNet_TF(input_shape=(H, W, channels))
    output_img = unet(input_img)

    # Identity pass-through layer; no explicit name is given, so Keras
    # auto-names it (e.g. "lambda").
    output_img = Lambda(lambda x: x)(output_img)

    model = Model(
        inputs=input_img,
        outputs=output_img
    )

    return model
# Instantiate the network for 320x320 inputs.
model = build_dual_output_model(
    H=320,
    W=320,
)
# model.summary()



Model: "model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_1 (InputLayer)        [(None, 320, 320, 2)]     0         
                                                                 
 sf_u_net_tf (SF_UNet_TF)    (None, 320, 320, 2)       9792898   
                                                                 
 lambda (Lambda)             (None, 320, 320, 2)       0         
                                                                 
Total params: 9,792,898
Trainable params: 9,792,898
Non-trainable params: 0
_________________________________________________________________
