In [1]:
import tensorflow as tf

In [2]:
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    # Memory growth currently has to be configured identically on every GPU.
    # With no GPU present this loop is a no-op and nothing is printed below.
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
    if gpus:
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
except RuntimeError as e:
    # Memory growth must be set before the GPUs have been initialized;
    # report the problem and carry on rather than aborting the notebook.
    print(e)

1 Physical GPUs, 1 Logical GPUs


In [3]:
from model_tf2 import *
import numpy as np

 The versions of TensorFlow you are currently using is 2.4.0-rc1 and is not supported. 
Some things might work, some things might not.
If you were to encounter a bug, do not file an issue.
If you want to make sure you're using a tested and supported configuration, either change the TensorFlow version or the TensorFlow Addons's version. 
You can find the compatibility matrix in TensorFlow Addon's readme:
https://github.com/tensorflow/addons


In [4]:
# Small synthetic problem used only to smoke-test the model wiring.
input_shape = (4, 16, 16, 16)  # (channels, H, W, D); H, W and D must be multiples of 16
output_channels = 3
n_samples = 10

# Seeded generator so the dummy arrays are identical after Restart & Run All.
dummy_rng = np.random.default_rng(0)
dummy_x = dummy_rng.standard_normal((n_samples, *input_shape))
dummy_y = dummy_rng.standard_normal((n_samples, output_channels, *input_shape[1:]))

In [5]:
class conv3d_autoenc_reg(keras.Model):
    """3D convolutional encoder/decoder with an auxiliary VAE branch.

    The encoder stacks `green_block` layers at 32/64/128/256 filters with
    stride-2 convolutions in between. Two decoders start from the deepest
    encoder output:

    * the GT decoder, which upsamples, adds the encoder skip tensors and ends
      in a sigmoid convolution with `output_channels` filters;
    * the VAE decoder, which goes through a dense bottleneck and a sampling
      layer and ends in a convolution with as many filters as input channels.

    All layers use `data_format='channels_first'`.

    NOTE(review): `call` currently returns only the GT output; the VAE output
    is computed but discarded because the `LossVAE` call is disabled.
    """

    def __init__(self, input_shape=(4, 160, 192, 128), output_channels=3, l2_reg_weight = 1e-5, weight_L2=0.1, weight_KL=0.1, 
                 dice_e=1e-8, test_mode = True, n_gpu = 1, GL_weight = 1, VL_weight = 0.1, **kwargs):
        """Create all layers of the network.

        Args:
            input_shape: 4-tuple `(c, H, W, D)` without the batch axis.
                H, W and D must each be divisible by 16.
            output_channels: number of filters of the final GT convolution.
            l2_reg_weight: weight handed to `l2(...)` for the kernel
                regularizer of the convolutions; `None` disables it.
            weight_L2: stored and forwarded to `LossVAE`.
            weight_KL: stored and forwarded to `LossVAE`.
            dice_e: stored on the instance; not used elsewhere in this class.
            test_mode: when `True`, the check that `c` is divisible by 4 is
                skipped.
            n_gpu: accepted for interface compatibility; unused in this class.
            GL_weight: stored on the instance; not used elsewhere in this class.
            VL_weight: stored on the instance; not used elsewhere in this class.
            **kwargs: forwarded to `keras.Model.__init__`.

        Raises:
            ValueError: if `input_shape` does not have exactly 4 entries.
            AssertionError: if a divisibility requirement is violated.
        """
        super().__init__(**kwargs)

        # Check the length *before* unpacking: otherwise a wrong-length shape
        # dies on the unpacking itself with a generic message and this check
        # can never fire. ValueError is kept as the exception type.
        if len(input_shape) != 4:
            raise ValueError("Input shape must be a 4-tuple")
        self.c, self.H, self.W, self.D = input_shape
        if test_mode is not True: assert (self.c % 4) == 0, "The no. of channels must be divisible by 4"
        assert (self.H % 16) == 0 and (self.W % 16) == 0 and (self.D % 16) == 0, "All the input dimensions must be divisible by 16"

        # Total number of input elements per sample; handed to LossVAE below.
        self.n = self.c * self.H * self.W * self.D
        self.l2_regularizer = l2(l2_reg_weight) if l2_reg_weight is not None else None

        self.input_shape_p = input_shape
        self.output_channels = output_channels
        self.l2_reg_weight = l2_reg_weight
        self.weight_L2 = weight_L2
        self.weight_KL = weight_KL
        self.dice_e = dice_e
        self.GL_weight = GL_weight
        self.VL_weight = VL_weight

        self.LossVAE = LossVAE(weight_L2, weight_KL, self.n)

        # -------------------------------------------------------------------------
        # Encoder
        # -------------------------------------------------------------------------

        ## The Initial Block
        self.Input_x1 = Conv3D(
            filters=32,
            kernel_size=(3, 3, 3),
            strides=1,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Input_x1')

        ## Dropout (0.2)
        self.spatial_dropout = SpatialDropout3D(0.2, data_format='channels_first')

        ## Green Block x1 (output filters = 32)
        self.x1 = green_block(32, regularizer=self.l2_regularizer, name='x1')
        self.Enc_DownSample_32 = Conv3D(
            filters=32,
            kernel_size=(3, 3, 3),
            strides=2,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Enc_DownSample_32')

        ## Green Block x2 (output filters = 64)
        self.Enc_64_1 = green_block(64, regularizer=self.l2_regularizer, name='Enc_64_1')
        self.x2 = green_block(64, regularizer=self.l2_regularizer, name='x2')
        self.Enc_DownSample_64 = Conv3D(
            filters=64,
            kernel_size=(3, 3, 3),
            strides=2,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Enc_DownSample_64')

        ## Green Blocks x2 (output filters = 128)
        self.Enc_128_1 = green_block(128, regularizer=self.l2_regularizer, name='Enc_128_1')
        self.x3 = green_block(128, regularizer=self.l2_regularizer, name='x3')
        self.Enc_DownSample_128 = Conv3D(
            filters=128,
            kernel_size=(3, 3, 3),
            strides=2,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Enc_DownSample_128')

        ## Green Blocks x4 (output filters = 256)
        self.Enc_256_1 = green_block(256, regularizer=self.l2_regularizer, name='Enc_256_1')
        self.Enc_256_2 = green_block(256, regularizer=self.l2_regularizer, name='Enc_256_2')
        self.Enc_256_3 = green_block(256, regularizer=self.l2_regularizer, name='Enc_256_3')
        self.x4 = green_block(256, regularizer=self.l2_regularizer, name='x4')

        # -------------------------------------------------------------------------
        # Decoder
        # -------------------------------------------------------------------------

        ## GT (Ground Truth) Part
        # -------------------------------------------------------------------------

        ### Green Block x1 (output filters=128)
        self.Dec_GT_ReduceDepth_128 = Conv3D(
            filters=128,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_GT_ReduceDepth_128')
        self.Dec_GT_UpSample_128 = UpSampling3D(size=2, data_format='channels_first', name='Dec_GT_UpSample_128')
        self.Input_Dec_GT_128 = Add(name='Input_Dec_GT_128')
        self.Dec_GT_128 = green_block(128, regularizer=self.l2_regularizer, name='Dec_GT_128')

        ### Green Block x1 (output filters=64)
        self.Dec_GT_ReduceDepth_64 = Conv3D(
            filters=64,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_GT_ReduceDepth_64')
        self.Dec_GT_UpSample_64 = UpSampling3D(size=2, data_format='channels_first', name='Dec_GT_UpSample_64')
        self.Input_Dec_GT_64 = Add(name='Input_Dec_GT_64')
        self.Dec_GT_64 = green_block(64, regularizer=self.l2_regularizer, name='Dec_GT_64')

        ### Green Block x1 (output filters=32)
        self.Dec_GT_ReduceDepth_32 = Conv3D(
            filters=32,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_GT_ReduceDepth_32')
        self.Dec_GT_UpSample_32 = UpSampling3D(size=2, data_format='channels_first', name='Dec_GT_UpSample_32')
        self.Input_Dec_GT_32 = Add(name='Input_Dec_GT_32')
        self.Dec_GT_32 = green_block(32, regularizer=self.l2_regularizer, name='Dec_GT_32')

        ### Blue Block x1 (output filters=32)
        self.Input_Dec_GT_Output = Conv3D(
            filters=32,
            kernel_size=(3, 3, 3),
            strides=1,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Input_Dec_GT_Output')

        ### Output Block
        self.Dec_GT_Output = Conv3D(
            filters=self.output_channels,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            activation='sigmoid',
            name='Dec_GT_Output')

        ## VAE (Variational Auto Encoder) Part
        # -------------------------------------------------------------------------

        ### VD Block (Reducing dimensionality of the data)
        self.Dec_VAE_VD_GN = GroupNormalization(groups=8, axis=1, name='Dec_VAE_VD_GN')
        self.Dec_VAE_VD_relu = Activation('relu', name='Dec_VAE_VD_relu')
        self.Dec_VAE_VD_Conv3D = Conv3D(
            filters=16,
            kernel_size=(3, 3, 3),
            strides=2,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_VAE_VD_Conv3D')

        # Not mentioned in the paper, but the author used a Flattening layer here.
        self.Dec_VAE_VD_Flatten = Flatten(name='Dec_VAE_VD_Flatten')
        self.Dec_VAE_VD_Dense = Dense(256, name='Dec_VAE_VD_Dense')

        ### VDraw Block (Sampling)
        self.Dec_VAE_VDraw_Mean = Dense(128, name='Dec_VAE_VDraw_Mean')
        self.Dec_VAE_VDraw_Var = Dense(128, name='Dec_VAE_VDraw_Var')
        self.Dec_VAE_VDraw_Sampling = sampling()

        ### VU Block (Upsizing back to a depth of 256)
        # The dense output is reshaped to a single-channel volume at 1/16 of the
        # input resolution before the 1x1x1 convolution expands it to 256 filters.
        c1 = 1
        self.VU_Dense1 = Dense((c1) * (self.H//16) * (self.W//16) * (self.D//16))
        self.VU_relu = Activation('relu')
        self.VU_reshape = Reshape(((c1), (self.H//16), (self.W//16), (self.D//16)))
        self.Dec_VAE_ReduceDepth_256 = Conv3D(
            filters=256,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_VAE_ReduceDepth_256')
        self.Dec_VAE_UpSample_256 = UpSampling3D(size=2, data_format='channels_first', name='Dec_VAE_UpSample_256')

        ### Green Block x1 (output filters=128)
        self.Dec_VAE_ReduceDepth_128 = Conv3D(
            filters=128,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_VAE_ReduceDepth_128')
        self.Dec_VAE_UpSample_128 = UpSampling3D(size=2, data_format='channels_first', name='Dec_VAE_UpSample_128')
        self.Dec_VAE_128 = green_block(128, regularizer=self.l2_regularizer, name='Dec_VAE_128')

        ### Green Block x1 (output filters=64)
        self.Dec_VAE_ReduceDepth_64 = Conv3D(
            filters=64,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_VAE_ReduceDepth_64')
        self.Dec_VAE_UpSample_64 = UpSampling3D(size=2, data_format='channels_first', name='Dec_VAE_UpSample_64')
        self.Dec_VAE_64 = green_block(64, regularizer=self.l2_regularizer, name='Dec_VAE_64')

        ### Green Block x1 (output filters=32)
        self.Dec_VAE_ReduceDepth_32 = Conv3D(
            filters=32,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_VAE_ReduceDepth_32')
        self.Dec_VAE_UpSample_32 = UpSampling3D(size=2, data_format='channels_first', name='Dec_VAE_UpSample_32')
        self.Dec_VAE_32 = green_block(32, regularizer=self.l2_regularizer, name='Dec_VAE_32')

        ### Blue Block x1 (output filters=32)
        self.Input_Dec_VAE_Output = Conv3D(
            filters=32,
            kernel_size=(3, 3, 3),
            strides=1,
            padding='same',
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Input_Dec_VAE_Output')

        ### Output Block
        self.Dec_VAE_Output = Conv3D(
            filters=self.c,
            kernel_size=(1, 1, 1),
            strides=1,
            kernel_regularizer=self.l2_regularizer,
            data_format='channels_first',
            name='Dec_VAE_Output')

    def call(self, inputs, training=None):
        """Run the forward pass and return the GT decoder output.

        Args:
            inputs: channels-first input batch fed to the first convolution.
            training: Keras training flag; forwarded to the dropout layer.

        Returns:
            The output of the `Dec_GT_Output` layer (sigmoid-activated).
        """
        Z = inputs
        x = self.Input_x1(Z)

        ## Dropout (0.2)
        # Forward the flag explicitly so dropout follows the caller's mode.
        x = self.spatial_dropout(x, training=training)

        ## Green Block x1 (output filters = 32)
        x1 = self.x1(x)
        x = self.Enc_DownSample_32(x1)

        ## Green Block x2 (output filters = 64)
        x = self.Enc_64_1(x)
        x2 = self.x2(x)
        x = self.Enc_DownSample_64(x2)

        ## Green Blocks x2 (output filters = 128)
        x = self.Enc_128_1(x)
        x3 = self.x3(x)
        x = self.Enc_DownSample_128(x3)

        ## Green Blocks x4 (output filters = 256)
        x = self.Enc_256_1(x)
        x = self.Enc_256_2(x)
        x = self.Enc_256_3(x)
        x4 = self.x4(x)

        # -------------------------------------------------------------------------
        # Decoder
        # -------------------------------------------------------------------------

        ## GT (Ground Truth) Part
        # -------------------------------------------------------------------------

        ### Green Block x1 (output filters=128)
        x = self.Dec_GT_ReduceDepth_128(x4)
        x = self.Dec_GT_UpSample_128(x)
        x = self.Input_Dec_GT_128([x, x3])  # skip connection from the encoder
        x = self.Dec_GT_128(x)

        ### Green Block x1 (output filters=64)
        x = self.Dec_GT_ReduceDepth_64(x)
        x = self.Dec_GT_UpSample_64(x)
        x = self.Input_Dec_GT_64([x, x2])  # skip connection from the encoder
        x = self.Dec_GT_64(x)

        ### Green Block x1 (output filters=32)
        x = self.Dec_GT_ReduceDepth_32(x)
        x = self.Dec_GT_UpSample_32(x)
        x = self.Input_Dec_GT_32([x, x1])  # skip connection from the encoder
        x = self.Dec_GT_32(x)

        ### Blue Block x1 (output filters=32)
        x = self.Input_Dec_GT_Output(x)

        ### Output Block
        out_GT = self.Dec_GT_Output(x)

        ## VAE (Variational Auto Encoder) Part
        # -------------------------------------------------------------------------

        ### VD Block (Reducing dimensionality of the data)
        x = self.Dec_VAE_VD_GN(x4)
        x = self.Dec_VAE_VD_relu(x)
        x = self.Dec_VAE_VD_Conv3D(x)

        # Not mentioned in the paper, but the author used a Flattening layer here.
        x = self.Dec_VAE_VD_Flatten(x)
        x = self.Dec_VAE_VD_Dense(x)

        ### VDraw Block (Sampling)
        z_mean = self.Dec_VAE_VDraw_Mean(x)
        z_var = self.Dec_VAE_VDraw_Var(x)
        x = self.Dec_VAE_VDraw_Sampling([z_mean, z_var])

        ### VU Block (Upsizing back to a depth of 256)
        x = self.VU_Dense1(x)
        x = self.VU_relu(x)
        x = self.VU_reshape(x)
        x = self.Dec_VAE_ReduceDepth_256(x)
        x = self.Dec_VAE_UpSample_256(x)

        ### Green Block x1 (output filters=128)
        x = self.Dec_VAE_ReduceDepth_128(x)
        x = self.Dec_VAE_UpSample_128(x)
        x = self.Dec_VAE_128(x)

        ### Green Block x1 (output filters=64)
        x = self.Dec_VAE_ReduceDepth_64(x)
        x = self.Dec_VAE_UpSample_64(x)
        x = self.Dec_VAE_64(x)

        ### Green Block x1 (output filters=32)
        x = self.Dec_VAE_ReduceDepth_32(x)
        x = self.Dec_VAE_UpSample_32(x)
        x = self.Dec_VAE_32(x)

        ### Blue Block x1 (output filters=32)
        x = self.Input_Dec_VAE_Output(x)

        ### Output Block
        out_VAE = self.Dec_VAE_Output(x)

        # NOTE(review): the VAE loss hook is currently disabled, so `out_VAE`,
        # `z_mean` and `z_var` do not influence training. Re-enable it with
        #     self.LossVAE([Z, out_VAE, z_mean, z_var])
        # once LossVAE's behaviour has been confirmed.

        return out_GT

In [6]:
# Build the network for the dummy problem; l2_reg_weight=None turns the
# kernel regularizer off.
model = conv3d_autoenc_reg(
    input_shape=input_shape,
    output_channels=output_channels,
    l2_reg_weight=None,
)

In [7]:
# `learning_rate` is the documented argument name; `lr` is only a deprecated alias.
# clipvalue=0.5 is passed through to the optimizer's gradient clipping.
opt = Adam(learning_rate=1e-4, clipvalue=0.5)
lg = DiceLoss()        # loss object from the star import (presumably model_tf2 — confirm)
dc = dice_coefficient  # metric function from the star import (presumably model_tf2 — confirm)

# Single-output model: one loss with weight 1 and one metric.
model.compile(
    optimizer=opt,
    loss=[lg],
    metrics=[dc],
    loss_weights=[1.],
)

In [None]:
# One-epoch smoke run; the same dummy arrays serve as training and validation data.
model.fit(
    x=dummy_x,
    y=dummy_y,
    batch_size=2,
    epochs=1,
    callbacks=[],
    validation_data=(dummy_x, dummy_y),
)

1/5 [=====>........................] - ETA: 2:06 - loss: -0.6211 - dice_coefficient: 0.6211

In [None]:
# NOTE(review): `csp` is not imported in any visible cell — presumably it is
# provided by `from model_tf2 import *`; confirm, otherwise this cell raises
# NameError on a fresh kernel.
csp.sound_alert()