In [2]:
import tensorflow as tf
from tensorflow.keras import layers, regularizers, Sequential


class CSPN_Plus_Kernel_Generate(tf.keras.Model):
    """Predict CSPN++ affinity weights for one square propagation kernel.

    A 3x3 Conv2D -> BatchNorm -> LeakyReLU head maps the input feature map to
    ``kernel_size**2 - 1`` guidance channels. These are normalised into
    ``kernel_size**2`` per-pixel weights that sum to 1 (see
    ``wight_tensor_gene``). Each weight channel is then spatially shifted by the
    fixed one-hot "transfer" convolution built in ``tran_kernel_gene``.

    Args:
        kernel_size: Positive odd side length of the propagation kernel.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer. The
            ``kernel_size**2 - 1`` guidance channels must split evenly on either
            side of the centre weight.
    """

    def __init__(self, kernel_size):
        super(CSPN_Plus_Kernel_Generate, self).__init__()
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.kernel_size = kernel_size
        # The transfer kernel never changes. Build it once and keep it as a
        # non-trainable variable so it is tracked with the model, and so every
        # call reuses it instead of rebuilding it with a scatter.
        tran_kernel_initial = self.tran_kernel_gene(kernel_size)
        self.tran_kernel = tf.Variable(tran_kernel_initial, trainable=False)
        self.guide_tensor_gene = Sequential(
            [
                layers.Conv2D(
                    kernel_size * kernel_size - 1,
                    (3, 3),
                    strides=(1, 1),
                    padding="same",
                    kernel_initializer=tf.keras.initializers.GlorotUniform(),
                ),
                layers.BatchNormalization(),
                layers.LeakyReLU(alpha=0.2),
            ]
        )

    @staticmethod
    def wight_tensor_tran_gene(wight_tensor, tran_kernel):
        """Apply the transfer kernel to the weight tensor.

        Uses a stride-1, SAME-padded ``tf.nn.conv2d``, so the spatial size is
        unchanged.
        """
        return tf.nn.conv2d(
            wight_tensor, tran_kernel, strides=[1, 1, 1, 1], padding="SAME"
        )

    @staticmethod
    def tran_kernel_gene(kernel_size):
        """Build the one-hot transfer kernel.

        Returns a float32 tensor of shape
        ``[kernel_size, kernel_size, kernel_size**2, kernel_size**2]``. For
        every channel ``j`` it holds a single 1 (input channel ``j`` -> output
        channel ``j``) at spatial tap
        ``(kernel_size - 1 - j // kernel_size, kernel_size - 1 - j % kernel_size)``.
        That is the tap mirrored through the centre from row-major position ``j``.
        Every other entry is 0.
        """
        num_taps = kernel_size * kernel_size
        reversed_range = list(range(kernel_size - 1, -1, -1))  # e.g. [2, 1, 0]
        rows = [r for r in reversed_range for _ in range(kernel_size)]
        cols = reversed_range * kernel_size
        channels = list(range(num_taps))

        indices = tf.constant(list(zip(rows, cols, channels, channels)))
        updates = tf.ones(num_taps)
        tran_kernel = tf.zeros([kernel_size, kernel_size, num_taps, num_taps])
        return tf.tensor_scatter_nd_update(tran_kernel, indices, updates)

    @staticmethod
    def wight_tensor_gene(guide_tensor):
        """Turn ``k*k - 1`` guidance channels into ``k*k`` affinity weights.

        Each guidance channel is divided by the per-pixel sum of absolute values
        over the channel axis. A centre weight ``1 - sum(normalised)`` is then
        inserted between the two halves of the channel axis. The weights at
        every pixel therefore sum to 1.
        """
        guide_sum = tf.reduce_sum(tf.abs(guide_tensor), axis=-1, keepdims=True)
        # An all-zero guidance vector would give 0/0 = NaN. divide_no_nan maps it
        # to zero neighbour weights instead, i.e. a centre weight of 1.
        guide = tf.math.divide_no_nan(guide_tensor, guide_sum)

        guide_mid = 1 - tf.reduce_sum(guide, axis=-1, keepdims=True)
        half1, half2 = tf.split(guide, 2, axis=-1)

        return tf.concat([half1, guide_mid, half2], axis=-1)

    def call(self, feature):
        """Return the shifted affinity weights, shape ``[B, H, W, kernel_size**2]``."""
        guide_tensor = self.guide_tensor_gene(feature)
        wight_tensor = self.wight_tensor_gene(guide_tensor)
        return self.wight_tensor_tran_gene(wight_tensor, self.tran_kernel)


# Smoke test: trace the 5x5 kernel generator on a symbolic 2048x2048x64 feature map.
kernel_generator = CSPN_Plus_Kernel_Generate(5)
# Named `feature_input` rather than `input`, which would shadow the Python builtin.
feature_input = tf.keras.Input(shape=(2048, 2048, 64))
output = kernel_generator(feature_input)
# One affinity weight per tap of the 5x5 neighbourhood.
assert output.shape[-1] == 5 * 5, output.shape
print(output.shape)




(None, 2048, 2048, 25)


In [3]:
class CSPN_Plus_Calcu(tf.keras.Model):
    """Run one CSPN++ propagation step with a fixed, optionally dilated, kernel.

    For every pixel, the output is the weighted sum of the ``kernel_size**2``
    (dilated) neighbours of ``depth_iter``. The centre neighbour is replaced by
    the same pixel of ``depth_raw``. The weights come from ``wight_tensor_tran``.

    Args:
        kernel_size: Odd side length of the neighbourhood.
        dilation: Spacing between sampled neighbours.
    """

    def __init__(self, kernel_size, dilation=1):
        super(CSPN_Plus_Calcu, self).__init__()
        self.kernel_size = kernel_size
        self.dilation = dilation
        # "Same" padding for a dilated kernel, so the output keeps H x W.
        self.padding = int(((kernel_size - 1) * dilation) / 2)

    def depth_iter_reshape(self, depht_iter):
        """Gather the neighbourhoods of ``depht_iter`` as ``[B, k*k*C, H*W]``.

        Axis 1 indexes the patch element and axis 2 the row-major pixel. This
        matches the layout produced by ``depth_raw_and_kernel_reshape``.
        """
        padded = tf.pad(
            depht_iter,
            [
                [0, 0],
                [self.padding, self.padding],
                [self.padding, self.padding],
                [0, 0],
            ],
        )

        patches = tf.image.extract_patches(
            padded,
            sizes=[1, self.kernel_size, self.kernel_size, 1],
            strides=[1, 1, 1, 1],
            rates=[1, self.dilation, self.dilation, 1],
            padding="VALID",
        )  # [B, H, W, k*k*C]

        # extract_patches stores each patch on the LAST axis. Move that axis to
        # position 1 before flattening H and W. Reshaping [B, H, W, P] straight
        # into [B, P, H*W] would interleave pixels and patch positions.
        patches = tf.transpose(patches, [0, 3, 1, 2])  # [B, k*k*C, H, W]

        patch_len = self.kernel_size * self.kernel_size * depht_iter.shape[-1]
        height, width = patches.shape[2], patches.shape[3]
        # Keep the flattened spatial size static when it is known; otherwise
        # let reshape infer it.
        if height is not None and width is not None:
            num_pixels = height * width
        else:
            num_pixels = -1
        return tf.reshape(patches, [tf.shape(patches)[0], patch_len, num_pixels])

    @staticmethod
    def depth_raw_and_kernel_reshape(depth_raw):
        """Reshape ``[B, H, W, C]`` into ``[B, C, H*W]`` (row-major pixels)."""
        depth_transposed = tf.transpose(depth_raw, [0, 3, 1, 2])
        depth_reshaped = tf.reshape(
            depth_transposed,
            [
                tf.shape(depth_raw)[0],
                tf.shape(depth_raw)[3],
                tf.shape(depth_raw)[1] * tf.shape(depth_raw)[2],
            ],
        )
        return depth_reshaped

    def replace_middle_row(self, depht_iter_matrix, depth_raw_matrix):
        """Replace row ``(k*k - 1) // 2`` of the neighbourhood matrix with ``depth_raw_matrix``.

        That row is the centre tap when the depth has a single channel.
        NOTE(review): with C > 1 the patch rows interleave channels, so this
        index would no longer be the centre.
        """
        mid_index = int((self.kernel_size * self.kernel_size - 1) / 2)
        top_half = depht_iter_matrix[:, :mid_index, :]
        bottom_half = depht_iter_matrix[:, mid_index + 1 :, :]
        depht_iter_matrix = tf.concat([top_half, depth_raw_matrix, bottom_half], axis=1)
        return depht_iter_matrix

    def call(
        self, wight_tensor_tran, depth_iter, depth_raw
    ):  # e.g. wight_tensor_tran [2, 10, 10, 9], depth_iter = depth_raw = [2, 10, 10, 1]
        """Return one propagated depth map of shape ``[B, H, W, 1]``."""
        depht_iter_matrix = self.depth_iter_reshape(depth_iter)  # [2, 9, 100]
        kernel_matrix = self.depth_raw_and_kernel_reshape(
            wight_tensor_tran
        )  # [2, 9, 100]
        depth_raw_matrix = self.depth_raw_and_kernel_reshape(depth_raw)  # [2, 1, 100]

        depht_iter_matrix = self.replace_middle_row(depht_iter_matrix, depth_raw_matrix)

        # Per-pixel dot product over the k*k neighbourhood axis -> [B, H*W].
        depth_out = tf.einsum("ijk,ijk->ik", depht_iter_matrix, kernel_matrix)

        raw_shape = tf.shape(depth_raw)
        return tf.reshape(depth_out, [raw_shape[0], raw_shape[1], raw_shape[2], 1])


# Smoke test: one 3x3, non-dilated propagation step on symbolic 10x10 inputs.
model = CSPN_Plus_Calcu(3, 1)
dummy_1 = tf.keras.Input(shape=(10, 10, 3 * 3))  # affinity weights, one per tap
dummy_2 = tf.keras.Input(shape=(10, 10, 1))  # depth, fed as both depth_iter and depth_raw
output = model(dummy_1, dummy_2, dummy_2)
print(output.shape)

(None, 10, 10, 1)


In [17]:
class CSPN_Plus(tf.keras.Model):
    """CSPN++ refinement head that mixes 3x3, 5x5 and 7x7 propagation branches.

    Each branch predicts its own affinity kernel from ``Feature`` and propagates
    ``depth_raw`` for 12 iterations. The depth after iterations 3, 6, 9 and 12
    is kept. The output is the per-pixel sum of these 12 snapshots
    (3 branches x 4 snapshots). Each snapshot is weighted by a softmax gate
    over the branches times a softmax gate over that branch's snapshots.

    Args:
        dilation: Dilation used by every propagation step.
        mask: If truthy, also predict a sigmoid mask per branch. Only where
            ``depth_raw > 0``, the mask blends ``depth_raw`` back into the
            propagated depth after every iteration.
        regularize_rate: L2 factor for the gate and mask convolution kernels.
    """

    def __init__(self, dilation=1, mask=False, regularize_rate=0):
        super(CSPN_Plus, self).__init__()
        self.dilation = dilation
        self.mask = mask

        self.kernel_gene_3 = CSPN_Plus_Kernel_Generate(3)
        self.kernel_gene_5 = CSPN_Plus_Kernel_Generate(5)
        self.kernel_gene_7 = CSPN_Plus_Kernel_Generate(7)

        self.cspn_cal_3 = CSPN_Plus_Calcu(3, self.dilation)
        self.cspn_cal_5 = CSPN_Plus_Calcu(5, self.dilation)
        self.cspn_cal_7 = CSPN_Plus_Calcu(7, self.dilation)

        # Softmax over the 3 branches (channel axis).
        self.kernel_gate_gene = self._conv_bn_act(
            3, layers.Softmax(axis=3), regularize_rate
        )
        # 12 channels = 3 branches x 4 snapshots. The softmax over each
        # branch's group of 4 is applied in call().
        self.iterate_gate_gene = self._conv_bn_act(
            12, layers.LeakyReLU(alpha=0.2), regularize_rate
        )

        if mask:
            # One sigmoid blend factor per branch.
            self.kernel_mask_gene = self._conv_bn_act(
                3, layers.Activation("sigmoid"), regularize_rate
            )

    @staticmethod
    def _conv_bn_act(filters, activation, regularize_rate):
        """3x3 same-padded Conv2D (Glorot init, L2-regularised) -> BatchNorm -> ``activation``."""
        return Sequential(
            [
                layers.Conv2D(
                    filters,
                    (3, 3),
                    strides=(1, 1),
                    padding="same",
                    kernel_initializer=tf.keras.initializers.GlorotUniform(),
                    kernel_regularizer=regularizers.l2(regularize_rate),
                ),
                layers.BatchNormalization(),
                activation,
            ]
        )

    def call(self, Feature, depth_raw):
        """Return the refined depth, with the same shape as ``depth_raw``."""
        branches = (
            (self.kernel_gene_3, self.cspn_cal_3),
            (self.kernel_gene_5, self.cspn_cal_5),
            (self.kernel_gene_7, self.cspn_cal_7),
        )
        num_branches = len(branches)
        num_iterations = 12
        snapshot_every = 3  # keep the depth after iterations 3, 6, 9 and 12
        # 4 snapshots per branch, matching the 4 iterate-gate channels per branch.
        num_snapshots = num_iterations // snapshot_every

        kernels = [kernel_gene(Feature) for kernel_gene, _ in branches]

        kernel_gate = self.kernel_gate_gene(Feature)
        branch_gates = [kernel_gate[:, :, :, b : b + 1] for b in range(num_branches)]

        iterate_gate = self.iterate_gate_gene(Feature)
        snapshot_gates = []
        for b in range(num_branches):
            group = tf.nn.softmax(
                iterate_gate[:, :, :, b * num_snapshots : (b + 1) * num_snapshots],
                axis=3,
            )
            snapshot_gates.append(
                [group[:, :, :, s : s + 1] for s in range(num_snapshots)]
            )

        if self.mask:
            # Only pixels with a valid (positive) raw depth may be pulled back to it.
            valid_mask = tf.cast(depth_raw > 0, depth_raw.dtype)
            kernel_mask = tf.multiply(self.kernel_mask_gene(Feature), valid_mask)
            branch_masks = [kernel_mask[:, :, :, b : b + 1] for b in range(num_branches)]

        depth_iters = [depth_raw] * num_branches
        snapshots = [[] for _ in range(num_branches)]
        for step in range(num_iterations):
            for b, (_, cspn_cal) in enumerate(branches):
                depth_iter = cspn_cal(kernels[b], depth_iters[b], depth_raw)
                if self.mask:
                    depth_iter = tf.multiply(branch_masks[b], depth_raw) + tf.multiply(
                        (1 - branch_masks[b]), depth_iter
                    )
                depth_iters[b] = depth_iter
                if (step + 1) % snapshot_every == 0:
                    snapshots[b].append(depth_iter)

        # Same accumulation order as before: per snapshot, branch 3 -> 5 -> 7.
        refined_depth = 0
        for s in range(num_snapshots):
            for b in range(num_branches):
                refined_depth += tf.multiply(
                    tf.multiply(snapshots[b][s], branch_gates[b]),
                    snapshot_gates[b][s],
                )

        return refined_depth


# Smoke test: trace the full CSPN++ head (dilation 2, mask enabled) on a
# symbolic 4089x4089 feature map / sparse-depth pair, then list its parameters.
model = CSPN_Plus(dilation=2, mask=True)

spatial = (4089, 4089)
dummy_1 = tf.keras.Input(shape=(*spatial, 64))  # guidance features
dummy_2 = tf.keras.Input(shape=(*spatial, 1))  # raw (sparse) depth

output = model(dummy_1, dummy_2)
print(output.shape)

model.summary()

(None, 4089, 4089, 1)
Model: "cspn__plus_13"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 cspn__plus__kernel__genera  multiple                  5377      
 te_40 (CSPN_Plus_Kernel_Ge                                      
 nerate)                                                         
                                                                 
 cspn__plus__kernel__genera  multiple                  29569     
 te_41 (CSPN_Plus_Kernel_Ge                                      
 nerate)                                                         
                                                                 
 cspn__plus__kernel__genera  multiple                  145537    
 te_42 (CSPN_Plus_Kernel_Ge                                      
 nerate)                                                         
                                                                 
 cspn__plus__calcu_40 (CSPN  mu