In [37]:
import tensorflow as tf
import tensorflow.keras.layers as kl

def PSP_Pooling(x, filter, rates=(1, 2, 4, 8)):
    """Pyramid-pooling block over equal channel partitions.

    The input is split along its last (channel) axis into ``len(rates)``
    equal partitions. Partition ``i`` is max-pooled with window and stride
    ``rates[i]``, upsampled back by the same factor, cropped to the input's
    spatial size and projected by a 1x1 convolution to
    ``filter // len(rates)`` channels. The original input and all projected
    partitions are concatenated on the channel axis and fused by a final
    1x1 convolution with ``filter`` output channels.

    Args:
        x: 4-D tensor; spatial dims are read from axes 1 and 2 and channels
            from the last axis (channels-last layout assumed — confirm for
            other callers).
        filter: number of output channels of the final 1x1 convolution.
            (Name kept for backward compatibility; it shadows the builtin.)
        rates: pooling / upsampling factor for each channel partition.
            Defaults to the original ``(1, 2, 4, 8)``.

    Returns:
        Tensor with the same first three dims as ``x`` and ``filter`` channels.

    Raises:
        ValueError: if ``rates`` is empty, or the static channel count of
            ``x`` is not divisible by ``len(rates)``.
    """
    n_parts = len(rates)
    if n_parts == 0:
        raise ValueError("rates must contain at least one pooling factor")
    channels = x.shape[-1]
    if channels is not None and channels % n_parts != 0:
        raise ValueError(
            f"channel count {channels} cannot be split into {n_parts} equal partitions"
        )

    # Target spatial size for cropping. Python ints keep the static shape
    # when it is known; otherwise fall back to the dynamic size.
    height = x.shape[1] if x.shape[1] is not None else tf.shape(x)[1]
    width = x.shape[2] if x.shape[2] is not None else tf.shape(x)[2]

    output = [x]
    for part, rate in zip(tf.split(x, n_parts, axis=-1), rates):
        # "same" rounds the pooled size up (ceil instead of floor), so after
        # upsampling the map is never smaller than the input; for sizes
        # divisible by `rate` it is identical to the default "valid".
        part = kl.MaxPool2D(pool_size=rate, strides=rate, padding="same")(part)
        part = kl.UpSampling2D(size=rate)(part)
        # Trim the up-to-(rate - 1) extra rows/cols so the concat below lines up.
        part = part[:, :height, :width, :]
        part = kl.Conv2D(filter // n_parts, 1, padding="same")(part)
        output.append(part)
    concat = tf.concat(output, axis=-1)
    return kl.Conv2D(filter, 1, padding="same")(concat)

In [38]:
x = tf.random.uniform((32,8,8,1024))
y = PSP_Pooling(x, 1024)
# PSP pooling keeps batch/spatial dims and emits `filter` channels.
assert y.shape == (32, 8, 8, 1024), y.shape
y.shape

TensorShape([32, 8, 8, 1024])

In [39]:
x = tf.random.uniform((32,8,8,1024))
y = tf.random.uniform((32,8,8,1024))

# Sanity check for channel-axis concatenation (as used in PSP_Pooling):
# channel counts add up, all other dims are unchanged.
xy_concat = tf.concat([x, y], axis=-1)
assert xy_concat.shape == (32, 8, 8, 2048), xy_concat.shape
print(xy_concat.shape)

(32, 8, 8, 2048)


In [50]:
def ResBlock_a(x, filters, dilation_list, kernel_size=3, padding="same", stride=1):
    """Multi-branch residual block with one branch per dilation rate.

    For every dilation rate ``d`` in ``dilation_list`` a branch
    BN -> ReLU -> Conv2D(d) -> BN -> ReLU -> Conv2D(d) is applied to ``x``.
    The block returns ``x`` plus the element-wise sum of all branch outputs.

    Args:
        x: input tensor (BatchNormalization normalizes axis -1, i.e.
            channels-last is assumed).
        filters: output channels of every convolution. Must equal the
            channel count of ``x`` so branch outputs can be added to ``x``.
        dilation_list: iterable of dilation rates, one branch per entry.
        kernel_size: kernel size of every Conv2D.
        padding: padding of every Conv2D.
        stride: strides of every Conv2D.

    Returns:
        Tensor with the same shape as ``x``. An empty ``dilation_list``
        returns ``x`` unchanged.

    Raises:
        ValueError: if there is at least one branch and the static channel
            count of ``x`` differs from ``filters``.

    Note:
        The residual sum only works when the branches preserve the spatial
        size of ``x``, i.e. with padding="same" and stride=1. Keras also
        rejects stride > 1 combined with dilation > 1.
    """
    dilations = list(dilation_list)
    channels = x.shape[-1]
    if dilations and channels is not None and channels != filters:
        raise ValueError(
            f"filters ({filters}) must match the input channel count "
            f"({channels}) for the residual sum"
        )

    branches = []
    for d in dilations:
        h = x
        # Two identical pre-activation units: BN -> ReLU -> dilated Conv2D.
        for _ in range(2):
            h = kl.BatchNormalization(axis=-1)(h)
            h = kl.Activation("relu")(h)
            h = kl.Conv2D(
                filters,
                kernel_size=kernel_size,
                padding=padding,
                strides=stride,
                dilation_rate=d,
            )(h)
        branches.append(h)

    # add_n sums element-wise without first materializing a stacked
    # (len(dilations) + 1, *x.shape) tensor as tf.stack + reduce_sum did.
    return tf.add_n([x] + branches)

x = tf.random.uniform((32,256,256,32))
y = ResBlock_a(x,32,[1,3,15,31])
# A residual block must preserve the input shape.
assert y.shape == x.shape, (y.shape, x.shape)
y.shape

TensorShape([32, 256, 256, 32])

In [54]:

def down_part(x):
    """Encoder path: a 1x1 stem followed by six ResBlock_a stages.

    The stem projects ``x`` to 32 channels and is followed by a ResBlock_a
    at the same resolution. Each of the next five stages first applies a
    1x1 Conv2D with strides=2 ("same" padding) and then a ResBlock_a with
    the stage's channel count and dilation rates.

    Args:
        x: input image tensor (channels-last assumed — confirm for callers).

    Returns:
        7-tuple ``(stem, u1, u2, u3, u4, u5, u6)`` of feature maps, from the
        stem output down to the deepest stage.
    """
    stem = kl.Conv2D(32, 1, padding="same")(x)
    features = [stem, ResBlock_a(stem, 32, [1, 3, 15, 31])]

    # (channels, dilation rates) of each downsampling stage, shallow to deep.
    stage_specs = [
        (64, [1, 3, 15, 31]),
        (128, [1, 3, 15]),
        (256, [1, 3, 15]),
        (512, [1]),
        (1024, [1]),
    ]
    for n_channels, rates in stage_specs:
        reduced = kl.Conv2D(n_channels, 1, padding="same", strides=2)(features[-1])
        features.append(ResBlock_a(reduced, n_channels, rates))

    return tuple(features)

x = tf.random.uniform((32,256,256,3))

l = down_part(x)
for elem in l:
    print(elem.shape)

# Stem and first block keep 256x256; each later stage halves H/W and
# doubles the channel count.
expected_shapes = [
    (32, 256, 256, 32),
    (32, 256, 256, 32),
    (32, 128, 128, 64),
    (32, 64, 64, 128),
    (32, 32, 32, 256),
    (32, 16, 16, 512),
    (32, 8, 8, 1024),
]
assert [tuple(t.shape) for t in l] == expected_shapes

(32, 256, 256, 32)
(32, 256, 256, 32)
(32, 128, 128, 64)
(32, 64, 64, 128)
(32, 32, 32, 256)
(32, 16, 16, 512)
(32, 8, 8, 1024)
