In [15]:
import tensorflow as tf
import keras
from keras.layers import Layer, Input, Dense, Conv3D, MaxPooling3D, UpSampling3D, Dropout, Flatten,InputLayer , Reshape, concatenate, Concatenate, Activation
from keras.models import Model
from keras import backend as K
from keras.callbacks import EarlyStopping
from keras.layers.normalization import BatchNormalization
import numpy as np
        
input_voxel = Input(shape=(32, 32, 16, 1))
scales = Input(shape=(1, 3))

# Convolution
x = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(input_voxel)
x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Conv3D(64, (3, 3 ,3), activation='relu', padding='same')(x)
x = MaxPooling3D((2, 2, 2), padding='same')(x)
x = Conv3D(64, (3, 3, 3), activation='relu', padding='same')(x)

x = Flatten()(x)
x = Reshape((1, 16384))(x)
x = Concatenate()([x, scales])


x = Dense(512)(x)
x = BatchNormalization()(x)
x = Dropout(0.5)(x)

# Descriptor
description = Dense(64)(x)

# Deconvolution
x = Dense(8192)(description)
x = Reshape((8, 8, 4, 32))(x)
x = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(x)
x = UpSampling3D((2, 2, 2))(x)
x = Conv3D(32, (3, 3, 3), activation='relu', padding='same')(x)
x = UpSampling3D((2, 2, 2))(x)
reconstructed = Conv3D(1, (3, 3, 3), activation='sigmoid', padding='same', name='reconstruction_output')(x)

#Classificator
y = BatchNormalization()(description)
y = Dropout(0.5)(y)
y = Dense(4)(y)
classified = Activation('softmax', name='classification_output')(y)

def reconstruction_loss(voxels, reconstructed):
    FN_TO_FP_WEIGHT = 0.9
    loss_r = - tf.math.reduce_mean(FN_TO_FP_WEIGHT * voxels * keras.backend.log(reconstructed + 1e-10) + (1 - FN_TO_FP_WEIGHT) * \
                            (1 - voxels) * keras.backend.log(1 - reconstructed + 1e-10))
    return loss_r

def classification_loss(classes, classified):
    loss_c = -tf.math.reduce_mean(keras.losses.binary_crossentropy(classes, classified))
    return loss_c

losses = {
	"reconstruction_output": reconstruction_loss,
	"classification_output": classification_loss
}
loss_weights = {"reconstruction_output": 200, "classification_output": 1}

autoencoder = Model(inputs=[input_voxel, scales], outputs=[reconstructed, classified])
autoencoder.compile(optimizer='adadelta', loss=losses, loss_weights=loss_weights)

history = autoencoder.fit(x=[np.zeros((1, 32, 32, 16, 1)), np.zeros((1, 1, 3))], 
                          y=[np.zeros((1, 32, 32, 16, 1)), np.zeros((1, 1, 4))], 
                          epochs=50, batch_size=1)
        

Epoch 1/50
Epoch 2/50
Epoch 3/50
Epoch 4/50
Epoch 5/50
Epoch 6/50
Epoch 7/50
Epoch 8/50
Epoch 9/50
Epoch 10/50
Epoch 11/50
Epoch 12/50
Epoch 13/50
Epoch 14/50
Epoch 15/50
Epoch 16/50
Epoch 17/50
Epoch 18/50
Epoch 19/50
Epoch 20/50
Epoch 21/50
Epoch 22/50
Epoch 23/50
Epoch 24/50
Epoch 25/50
Epoch 26/50
Epoch 27/50
Epoch 28/50
Epoch 29/50
Epoch 30/50
Epoch 31/50
Epoch 32/50
Epoch 33/50
Epoch 34/50
Epoch 35/50
Epoch 36/50
Epoch 37/50
Epoch 38/50
Epoch 39/50
Epoch 40/50
Epoch 41/50
Epoch 42/50
Epoch 43/50
Epoch 44/50
Epoch 45/50
Epoch 46/50
Epoch 47/50
Epoch 48/50
Epoch 49/50
Epoch 50/50


In [18]:
print(history.params)
print(history.history['loss'])

{'batch_size': 1, 'epochs': 50, 'steps': None, 'samples': 1, 'verbose': 1, 'do_validation': False, 'metrics': ['loss', 'reconstruction_output_loss', 'classification_output_loss']}
[13.575098037719727, 13.56118392944336, 13.546806335449219, 13.53233528137207, 13.517911911010742, 13.503253936767578, 13.488819122314453, 13.47421646118164, 13.459611892700195, 13.44505500793457, 13.43010139465332, 13.415407180786133, 13.400838851928711, 13.385860443115234, 13.371288299560547, 13.356586456298828, 13.341602325439453, 13.326906204223633, 13.312309265136719, 13.29725456237793, 13.282621383666992, 13.267688751220703, 13.252969741821289, 13.238248825073242, 13.223312377929688, 13.208681106567383, 13.193626403808594, 13.179018020629883, 13.164321899414062, 13.149251937866211, 13.134624481201172, 13.119680404663086, 13.104965209960938, 13.090274810791016, 13.075395584106445, 13.06070327758789, 13.04599380493164, 13.031042098999023, 13.016422271728516, 13.001737594604492, 12.98678207397461, 12.97216