In [None]:
import tensorflow as tf
from tensorflow.keras import layers, metrics, losses
from tensorflow.keras.models import Model 
import numpy as np

In [None]:
class CustomLayer(Model):
    """One residual V-Net stage followed by a stride-2 down- or up-convolution.

    The stage applies ``n_layers`` 5x5x5 ``Conv3D`` layers, each followed by a
    per-channel PReLU. It then adds ``input_tensor1`` back as a residual and
    resamples with a 2x2x2 stride-2 ``Conv3D`` (``is_down=True``) or
    ``Conv3DTranspose`` (``is_down=False``), followed by another PReLU.

    Args:
        input_shape: spatial shape ``(D, H, W)`` the stage is expected to see.
            Accepted for backward compatibility only; the Keras layers infer
            their shapes on first call.
        filters: ``(filter1, filter2)``. ``filter1`` is the channel count of the
            stage convolutions and ``filter2`` that of the resampling
            convolution.
        n_layers: number of 5x5x5 convolutions in the stage.
        is_down: ``True`` for an encoder (downsampling) stage, ``False`` for a
            decoder (upsampling) stage.
        name: optional Keras name. ``None`` lets Keras generate a unique name
            per instance; a fixed name would be shared by every stage of a
            network and produce duplicate layer names.
    """

    def __init__(self, input_shape, filters, n_layers, is_down=True, name=None):
        super(CustomLayer, self).__init__(name=name)
        # input_shape is intentionally unused: an InputLayer called on a tensor
        # inside a subclassed model is at best an identity, and its declared
        # channel count would be wrong for decoder stages (filter1 + skip
        # channels after concatenation).
        del input_shape
        filter1, filter2 = filters

        self.clayers = [
            layers.Conv3D(filter1, 5, strides=1, padding='same')
            for _ in range(n_layers)
        ]
        # One activation per convolution. Without a nonlinearity, the stacked
        # convolutions would collapse into a single linear map.
        # shared_axes=[1, 2, 3] shares alpha across D/H/W, giving one learnable
        # slope per channel instead of one per voxel.
        self.cprelus = [
            layers.PReLU(shared_axes=[1, 2, 3]) for _ in range(n_layers)
        ]

        if is_down:
            self.out = layers.Conv3D(filter2, 2, strides=2, padding='valid')
        else:
            self.out = layers.Conv3DTranspose(filter2, 2, strides=2, padding='valid')

        self.prelu = layers.PReLU(shared_axes=[1, 2, 3])

    def call(self, input_tensor1, input_tensor2=None, training=False):
        """Run the stage.

        Args:
            input_tensor1: main input. Its channel count must equal ``filter1``
                because it is added to the convolution output as the residual.
            input_tensor2: optional skip-connection tensor, concatenated on the
                last (channel) axis before the convolutions.
            training: accepted for Keras compatibility; unused.

        Returns:
            ``(x, y)``, where ``x`` is the residual feature map before
            resampling (reused as a skip connection) and ``y`` is the
            PReLU-activated output of the down/up convolution.
        """
        if input_tensor2 is None:
            x = input_tensor1
        else:
            # tf.concat avoids constructing a new Concatenate layer on every call.
            x = tf.concat([input_tensor1, input_tensor2], axis=-1)
        for conv, act in zip(self.clayers, self.cprelus):
            x = act(conv(x))
        x = x + input_tensor1
        return x, self.prelu(self.out(x))


In [None]:
class VNet(Model):
    """V-Net style 3D segmentation network: 4 encoder and 4 decoder stages.

    Args:
        input_shape: ``(D, H, W, C)`` of one sample, without the batch axis.
            ``C`` defaults to 1 if omitted. D, H and W must each be divisible
            by 16, because the encoder halves them four times and the decoder
            concatenates skip connections of the matching size.
        batch_size: stored as ``self.batch_size``; not otherwise used here.
        n_classes: number of output channels. With the default ``1`` the
            output is a per-voxel sigmoid probability. With ``n_classes > 1``
            the output is a softmax over the channel axis. (A softmax over a
            single channel is identically 1.0.)

    Raises:
        ValueError: if a spatial dimension is not divisible by 16, or if
            ``n_classes < 1``.
    """

    def __init__(self, input_shape, batch_size, n_classes=1):
        super(VNet, self).__init__(name='VNet')
        spatial = tuple(input_shape[:3])
        if any(d % 16 for d in spatial):
            raise ValueError(
                f"spatial dims {spatial} must all be divisible by 16 "
                "(four stride-2 resampling stages)")
        if n_classes < 1:
            raise ValueError(f"n_classes must be >= 1, got {n_classes}")
        in_channels = input_shape[3] if len(input_shape) > 3 else 1

        self.batch_size = batch_size
        self.n_classes = n_classes

        def stage_shape(level):
            # Spatial shape of a stage that is `level` stride-2 steps below the input.
            return tuple(d // 2 ** level for d in spatial)

        # The first stage's residual adds the raw input, so its conv width
        # must match the input channel count.
        self.layer1 = CustomLayer(stage_shape(0), (in_channels, 16), 1)
        self.layer2 = CustomLayer(stage_shape(1), (16, 32), 2)
        self.layer3 = CustomLayer(stage_shape(2), (32, 64), 3)
        self.layer4 = CustomLayer(stage_shape(3), (64, 128), 3)
        self.layer5 = CustomLayer(stage_shape(4), (128, 256), 3, False)
        self.layer6 = CustomLayer(stage_shape(3), (256, 128), 3, False)
        self.layer7 = CustomLayer(stage_shape(2), (128, 64), 3, False)
        self.layer8 = CustomLayer(stage_shape(1), (64, 32), 2, False)
        self.layer9 = layers.Conv3D(32, 5, strides=1, padding='same')
        self.layer10 = layers.Conv3D(n_classes, 1, padding='same')

    def call(self, input_tensor, training=False):
        """Return per-voxel probabilities of shape ``(batch, D, H, W, n_classes)``."""
        # Encoder: keep each stage's pre-downsampling map (o*) as a skip connection.
        o1, l1 = self.layer1(input_tensor)
        o2, l2 = self.layer2(l1)
        o3, l3 = self.layer3(l2)
        o4, l4 = self.layer4(l3)
        # Decoder: each stage upsamples and consumes the matching skip connection.
        _, l5 = self.layer5(l4)
        _, l6 = self.layer6(l5, o4)
        _, l7 = self.layer7(l6, o3)
        _, l8 = self.layer8(l7, o2)
        l9 = self.layer9(tf.concat([l8, o1], axis=-1)) + l8
        logits = self.layer10(l9)
        if self.n_classes == 1:
            return tf.nn.sigmoid(logits)
        return tf.nn.softmax(logits, axis=-1)

In [None]:
# Smoke test: push one random volume through an untrained VNet.
# Both RNGs are seeded so Restart & Run All reproduces the same output,
# and the input is float32 to match the model's compute dtype.
tf.random.set_seed(0)
rng = np.random.default_rng(0)
x = tf.constant(rng.random((1, 128, 128, 64, 1), dtype=np.float32))
mod = VNet(input_shape=(128, 128, 64, 1), batch_size=1)
mod.compile()

p = mod.predict(x)
assert p.shape == (1, 128, 128, 64, 1), p.shape
assert ((p >= 0) & (p <= 1)).all(), "outputs must be probabilities"
p.shape