In [None]:
import os
import re
import random

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
# Directory holding the voxel samples. Built with os.path.join so the same notebook works on
# both Windows and POSIX systems (a literal with backslashes is a single odd file name on POSIX).
DATASET_DIR_PTH = os.path.join(".", "dataset", "partnet", "chair_voxel_data")

In [None]:
def sorted_alphanumeric(data):
    """Return ``data`` sorted in natural order, e.g. ``2.npy`` before ``10.npy``.

    Each string is split into alternating text and digit runs; digit runs are compared as
    integers and text runs are compared case-insensitively.

    Args:
        data: iterable of strings.

    Returns:
        A new list with the elements of ``data`` in natural order.
    """
    def _natural_key(name):
        pieces = []
        # The capturing group keeps the digit runs in the split result.
        for chunk in re.split('([0-9]+)', name):
            if chunk.isdigit():
                pieces.append(int(chunk))
            else:
                pieces.append(chunk.lower())
        return pieces

    return sorted(data, key=_natural_key)

In [None]:
# Number of samples per training batch; passed to DataGenerator when it is constructed.
BATCH_SIZE = 1

In [None]:
class DataGenerator:
    """Iterable yielding shuffled batches of dense voxel grids.

    Every file found in ``dataset_dir_pth`` (in natural sort order, optionally truncated to
    ``size`` entries) is loaded with ``np.load`` at construction time. Each loaded array is
    treated as a list of voxel coordinates and is turned into a dense ``float32`` grid of shape
    ``voxel_map_shape`` only when a batch is requested.

    Only ``(num_samples // batch_size) * batch_size`` samples are ever emitted, so a trailing
    incomplete batch is dropped.

    Args:
        dataset_dir_pth: directory containing one ``np.load``-able file per sample.
        size: optional cap on the number of files used (``None`` uses all of them).
        batch_size: number of samples per batch; must be positive.
        voxel_map_shape: shape of the dense grid created for each sample.

    Raises:
        ValueError: if ``batch_size`` is not positive.
    """

    def __init__(self, dataset_dir_pth, size=None, batch_size=4, voxel_map_shape=(128, 128, 128)):
        if batch_size <= 0:
            raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))

        self.dataset_dir_pth = dataset_dir_pth
        self.data_names = np.array(sorted_alphanumeric(os.listdir(self.dataset_dir_pth)), dtype=str)[:size]
        self.num_samples = len(self.data_names)
        self.batch_size = batch_size
        self.curr_index = 0
        # Only whole batches are served: sample indexes beyond the last full batch are left out.
        self.indexes = np.arange((self.num_samples // self.batch_size) * self.batch_size)
        self.voxel_map_shape = voxel_map_shape
        self.voxel_data_sparse = self._load_voxel_data()

        np.random.shuffle(self.indexes)

    def _load_voxel_data(self):
        """Load every file listed in ``self.data_names`` and return the arrays as a list."""
        voxel_data_sparse = []
        for data_name in tqdm(self.data_names, desc="Loading Voxel Data"):
            data_pth = os.path.join(self.dataset_dir_pth, data_name)
            voxel_data_sparse.append(np.load(data_pth))
        return voxel_data_sparse

    def __len__(self):
        """Number of full batches in one epoch."""
        return self.num_samples // self.batch_size

    def __iter__(self):
        """Start a new pass over the data.

        The cursor is rewound so that a pass which was interrupted part-way (for example by
        stopping a notebook cell) does not make the next pass shorter than ``len(self)``.
        """
        self.curr_index = 0
        return self

    def _data_to_dense(self, objs_voxel_data):
        """Convert coordinate arrays into dense occupancy grids.

        Args:
            objs_voxel_data: iterable of arrays whose transpose is used as an index tuple into
                the grid (assumes integer coordinates of shape ``(N, ndim)`` — TODO confirm
                against the dataset).

        Returns:
            A list of ``float32`` arrays of shape ``self.voxel_map_shape`` holding 1.0 at every
            listed coordinate and 0.0 elsewhere.
        """
        dense_objs_voxel_data = []

        for obj_voxel_coordinates in objs_voxel_data:
            voxel_map = np.zeros(self.voxel_map_shape, dtype=np.float32)
            voxel_map[tuple(obj_voxel_coordinates.T)] = 1.0
            dense_objs_voxel_data.append(voxel_map)

        return dense_objs_voxel_data

    # Disabled random-rotation helpers, kept for reference.
    # def _rotate_obj_voxel_data(self, coordinates, rotation_angle_range):
    #     rotation_angle = random.randint(-rotation_angle_range, rotation_angle_range)

    #     theta = np.pi * (rotation_angle / 180)

    #     rotation_matrix = np.array([[np.cos(theta), 0, np.sin(theta)],
    #                                 [0, 1, 0],
    #                                 [-np.sin(theta), 0, np.cos(theta)]])

    #     center = np.array(self.voxel_map_shape) / 2
    #     translated_coords = coordinates - center

    #     rotated_coords = np.dot(translated_coords, rotation_matrix.T)
    #     rotated_coords += center

    #     rotated_coords = np.round(rotated_coords).astype(int)
    #     valid_indices = np.all((rotated_coords >= 0) & (rotated_coords < np.array(self.voxel_map_shape)), axis=1)

    #     rotated_coords = rotated_coords[valid_indices]

    #     return rotated_coords

    # def _random_transform(self, objs_voxel_data, rotation_angle_range=45, probability=1):
    #     transformed_objs_voxel_data = []

    #     for obj_voxel_coordinates in objs_voxel_data:

    #         if random.random() < probability:
    #             obj_voxel_coordinates = self._rotate_obj_voxel_data(obj_voxel_coordinates, rotation_angle_range)

    #         transformed_objs_voxel_data.append(obj_voxel_coordinates)

    #     return transformed_objs_voxel_data

    def _load_batched_sparse_data(self, indexes):
        """Return the coordinate arrays for the given sample indexes.

        Uses the in-memory copies when ``self.voxel_data_sparse`` is set; otherwise falls back
        to reading each file from disk.
        """
        if self.voxel_data_sparse is not None:
            return [self.voxel_data_sparse[i] for i in indexes]

        voxel_data_sparse = []
        for data_name in self.data_names[indexes]:
            data_pth = os.path.join(self.dataset_dir_pth, data_name)
            voxel_data_sparse.append(np.load(data_pth))  # attention ! bottleneck here
        return voxel_data_sparse

    def __next__(self):
        """Return the next ``(sample_indexes, dense_grids)`` batch.

        At the end of a pass the index order is reshuffled, the cursor is rewound and
        ``StopIteration`` is raised.
        """
        if self.curr_index >= len(self.indexes):
            np.random.shuffle(self.indexes)
            self.curr_index = 0
            raise StopIteration

        # Copy the slice: self.indexes is shuffled in place at the end of every pass, which
        # would otherwise silently rewrite index arrays that were already handed out.
        batched_indexes = self.indexes[self.curr_index: self.curr_index + self.batch_size].copy()

        batched_sparse_data = self._load_batched_sparse_data(batched_indexes)

        batched_dense_data = self._data_to_dense(batched_sparse_data)

        self.curr_index += self.batch_size

        return batched_indexes, batched_dense_data

In [None]:
# Build the batch generator over the dataset directory; its constructor loads every sample file.
data_generator = DataGenerator(dataset_dir_pth=DATASET_DIR_PTH, batch_size=BATCH_SIZE)

In [None]:
# Hyper-parameters read by GeometryNetwork.__init__.
model_hparam = {
    'model_latent_code_dim': 256,
    'model_fc_channels': 512,
    'model_fc_activation': 'relu',
    'model_norm_3d': 'batchnorm',
    'model_conv_size': 4,
    # One latent code per sample index the generator can emit. Taken from the generator itself
    # so the count stays right even if BATCH_SIZE is edited after the generator was built.
    'model_num_latent_codes': len(data_generator.indexes),
    'model_learning_rate_network': 1e-4,
    'model_learning_rate_codes': 1e-3,
    'model_checkpoint_dir': './ckpt'
}

In [None]:
class GeometryNetwork:
    """Latent-code-to-voxel decoder bundled with its optimizers, loss and checkpointing.

    The object owns a table of per-sample latent codes (``latent_code_vars``) and a decoder
    made of three dense layers followed by seven stride-2 3D transposed convolutions, which
    upsample a 1x1x1 volume to a 128x128x128 single-channel volume of logits.

    Args:
        hparam: dict providing the ``model_*`` keys read in ``__init__``.
    """

    def __init__(self, hparam):
        self.latent_code_dim = hparam['model_latent_code_dim']
        self.fc_channels = hparam['model_fc_channels']
        self.fc_activation = hparam['model_fc_activation']
        self.norm_3d = hparam['model_norm_3d']
        self.conv_size = hparam['model_conv_size']
        self.num_latent_codes = hparam['model_num_latent_codes']
        self.learning_rate_network = hparam['model_learning_rate_network']
        self.learning_rate_codes = hparam['model_learning_rate_codes']
        self.checkpoint_dir = hparam['model_checkpoint_dir']

        # Order matters: the checkpoint tracks the model, latent codes and optimizers.
        self.init_model()
        self.init_optimizer()
        self.init_losser()
        self.init_checkpoint()

    def norm_layer(self, tensor, normalization):
        """Apply batch normalization when ``normalization`` is ``'batchnorm'`` (case-insensitive).

        Any other value (including ``None``) returns ``tensor`` unchanged.
        """
        if normalization and normalization.lower() == 'batchnorm':
            tensor = tf.keras.layers.BatchNormalization()(tensor)
        return tensor

    def conv_t_block_3d(self,
                        tensor,
                        num_filters,
                        size,
                        strides,
                        normalization=None,
                        dropout=False,
                        alpha_lrelu=0.2,
                        relu=True,
                        rate=0.7):
        """Transposed 3D convolution followed by optional norm, LeakyReLU and dropout.

        Args:
            tensor: input Keras tensor.
            num_filters: number of output channels.
            size: kernel size.
            strides: stride of the transposed convolution.
            normalization: passed to ``norm_layer``.
            dropout: whether to append a ``Dropout(rate)`` layer.
            alpha_lrelu: negative slope of the LeakyReLU.
            relu: whether to append the LeakyReLU.
            rate: dropout rate, used only when ``dropout`` is true.

        Returns:
            The output Keras tensor of the block.
        """
        conv_3D_transpose = tf.keras.layers.Conv3DTranspose(num_filters,
                                                            size,
                                                            strides=strides,
                                                            padding='same',
                                                            kernel_initializer=tf.keras.initializers.glorot_normal(),
                                                            use_bias=False)

        tensor = conv_3D_transpose(tensor)

        tensor = self.norm_layer(tensor, normalization)

        if relu:
            tensor = tf.keras.layers.LeakyReLU(alpha=alpha_lrelu)(tensor)

        if dropout:
            tensor = tf.keras.layers.Dropout(rate)(tensor)

        return tensor

    def get_model(self):
        """Build and return a fresh decoder mapping a latent code to a volume of logits."""
        with tf.name_scope('Network/'):

            latent_code = tf.keras.layers.Input(shape=(self.latent_code_dim,))

            with tf.name_scope('FC_layers'):
                fc0 = tf.keras.layers.Dense(self.fc_channels, activation=self.fc_activation)(latent_code)
                fc1 = tf.keras.layers.Dense(self.fc_channels, activation=self.fc_activation)(fc0)
                fc2 = tf.keras.layers.Dense(self.fc_channels, activation=self.fc_activation)(fc1)

                # Treat the feature vector as a 1x1x1 volume with fc_channels channels.
                fc2_as_volume = tf.keras.layers.Reshape((1, 1, 1, self.fc_channels))(fc2)

            with tf.name_scope('GLO_VoxelDecoder'):
                # Six identical stride-2 blocks that differ only in their filter count; layers
                # are created in the same order as when the blocks were written out one by one.
                decoder = fc2_as_volume
                for num_filters in (32, 32, 32, 16, 8, 4):
                    decoder = self.conv_t_block_3d(decoder,
                                                   num_filters=num_filters,
                                                   size=self.conv_size,
                                                   strides=2,
                                                   normalization=self.norm_3d)

                # Final stride-2 layer producing one channel of raw logits (no activation).
                conv_3D_transpose_out = tf.keras.layers.Conv3DTranspose(1,
                                                                        self.conv_size,
                                                                        strides=2,
                                                                        padding='same',
                                                                        kernel_initializer=tf.keras.initializers.glorot_normal(),
                                                                        use_bias=False)

                volume_out = conv_3D_transpose_out(decoder)

        return tf.keras.Model(inputs=[latent_code], outputs=[volume_out])

    def init_model(self):
        """Create the decoder, its backup copy, the epoch counter and the latent-code table."""
        self.model = self.get_model()
        self.model_backup = self.get_model()

        self.latest_epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

        init_latent_code = tf.random.normal((self.num_latent_codes, self.latent_code_dim))
        self.latent_code_vars = tf.Variable(init_latent_code, trainable=True)

        self.trainable_variables = self.model.trainable_variables

    def init_optimizer(self):
        """Create separate Adam optimizers for the decoder weights and the latent codes."""
        self.optimizer_network = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_network)
        self.optimizer_latent = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_codes)

    def init_losser(self):
        """Create the binary cross-entropy loss (it receives probabilities, not logits)."""
        self.losser_bce = tf.keras.losses.BinaryCrossentropy()

    def init_checkpoint(self):
        """Set up checkpoint tracking (keeping the 3 most recent) and restore the latest one."""
        self.checkpoint = tf.train.Checkpoint(
            model=self.model,
            latent_code_var=self.latent_code_vars,
            optimizer_network=self.optimizer_network,
            optimizer_latent=self.optimizer_latent,
            epoch=self.latest_epoch)

        self.manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                  directory=self.checkpoint_dir,
                                                  max_to_keep=3)

        self.load_checkpoint()

    def load_checkpoint(self):
        """Restore the latest checkpoint, if any, then copy the model weights into the backup."""
        latest_checkpoint = self.manager.latest_checkpoint

        if latest_checkpoint is not None:
            print('Checkpoint {} restored'.format(latest_checkpoint))

        self.checkpoint.restore(latest_checkpoint).expect_partial()

        for a, b in zip(self.model_backup.variables, self.model.variables):
            a.assign(b)

    @tf.function
    def train_step(self, latent_code_vars, true_voxels):
        """Run one optimization step on the decoder weights and the given latent codes.

        Args:
            latent_code_vars: trainable ``tf.Variable`` holding the latent codes of this batch;
                it is updated in place by the latent optimizer.
            true_voxels: target occupancy values matching the model output shape.

        Returns:
            The scalar binary cross-entropy loss of the batch.
        """
        # NOTE(review): tf.function keys its traces on variable identity, so callers that pass
        # a newly created tf.Variable on every call probably trigger a retrace each time —
        # confirm and consider reusing one variable.
        with tf.GradientTape() as tape:
            # training=True makes the BatchNormalization layers use batch statistics and update
            # their moving averages; without it they run in inference mode during training.
            pred_logits_voxels = self.model(latent_code_vars, training=True)

            pred_voxels = tf.sigmoid(pred_logits_voxels)

            total_loss = self.losser_bce(true_voxels, pred_voxels)

        network_vars = self.trainable_variables

        # One backward pass for both groups; the last gradient belongs to the latent codes.
        gradients = tape.gradient(total_loss, network_vars + [latent_code_vars])

        self.optimizer_network.apply_gradients(zip(gradients[:len(network_vars)], network_vars))
        self.optimizer_latent.apply_gradients(zip(gradients[len(network_vars):], [latent_code_vars]))

        return total_loss

    def update_latent_code_vars(self, latent_code_vars):
        """Overwrite the whole latent-code table with ``latent_code_vars``."""
        self.latent_code_vars.assign(latent_code_vars)

    def save_models(self, curr_epoch):
        """Record ``curr_epoch`` and write a checkpoint through the manager."""
        self.latest_epoch.assign(curr_epoch)
        self.manager.save()

    def reset_models(self):
        """Copy the backup weights back into the live model."""
        for a, b in zip(self.model.variables, self.model_backup.variables):
            a.assign(b)

In [None]:
# Number of epochs to run on top of the epoch counter restored from the checkpoint.
TRAINING_EPOCH = 1

In [None]:
# Build the network; its constructor also restores the latest checkpoint when one exists.
geom_network = GeometryNetwork(model_hparam)

# Epoch counter stored with the checkpoint (0 when nothing was restored).
latest_epoch = geom_network.latest_epoch.numpy()

for epoch in range(latest_epoch+1, latest_epoch+TRAINING_EPOCH+1):
    # Per-batch losses of this epoch, used for the running and final averages.
    total_loss = []
    
    # Work on a NumPy copy of the latent table; rows are written back after every step and
    # the whole table is assigned to the network once the epoch is finished.
    all_latent_code_vars = geom_network.latent_code_vars.numpy()
        
    pbar = tqdm(data_generator, desc="Training EPOCH {}/{}".format(epoch, latest_epoch+TRAINING_EPOCH))
    
    for voxel_indexes, true_voxels in pbar:
        # A new variable holds this batch's codes so that train_step can update them in place.
        # NOTE(review): creating a tf.Variable per step probably makes the @tf.function
        # train_step retrace and gives the latent optimizer no carried-over state — confirm.
        batch_latent_code_vars = tf.Variable(all_latent_code_vars[voxel_indexes], trainable=True)
        
        # A trailing axis is appended to the targets before they are compared with the model output.
        loss = geom_network.train_step(batch_latent_code_vars, np.expand_dims(true_voxels, axis=-1)).numpy()
        
        # Store the updated codes back into the rows they came from.
        all_latent_code_vars[voxel_indexes] = batch_latent_code_vars.numpy()
        
        total_loss.append(loss)
                
        pbar.set_postfix({"Avg Loss": '{:.5f}'.format(sum(total_loss) / len(total_loss))}) 
    
    geom_network.update_latent_code_vars(all_latent_code_vars)
        
    print("[EPOCH {}] Average Training Loss: {:.5f}".format(epoch, sum(total_loss) / len(total_loss)))
    
    # Persist weights, latent codes, optimizer state and the epoch counter.
    geom_network.save_models(epoch)

In [None]:
# Pick a random sample that has a latent code. random.randint includes its upper bound and
# could select one index past the end of the latent table, so randrange is used instead.
test_target = random.randrange(geom_network.num_latent_codes)

# Latent code i was trained on data_generator.data_names[i]; use that mapping rather than
# assuming the file is called "<i>.npy".
true_voxels_coords = np.load(os.path.join(data_generator.dataset_dir_pth,
                                          data_generator.data_names[test_target]))

# Dense ground-truth grid with the same shape the generator used during training.
true_voxels = np.zeros(data_generator.voxel_map_shape, dtype=np.float32)
true_voxels[tuple(true_voxels_coords.T)] = 1.0

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(true_voxels)

plt.show()

# Decode the matching latent code; a leading batch axis is added for the model.
latent_codes = tf.expand_dims(geom_network.latent_code_vars[test_target], axis=0)

# Inference mode: batch normalization uses its moving statistics.
pred_logits_voxels = geom_network.model(latent_codes, training=False)
pred_voxels = tf.sigmoid(pred_logits_voxels)
# Threshold the probabilities at 0.5 to obtain a binary occupancy grid.
pred_voxels = tf.cast(tf.math.greater_equal(pred_voxels, 0.5), tf.float32)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
# Drop the batch axis and the trailing channel axis before plotting.
ax.voxels(pred_voxels[0, :, :, :].numpy().reshape((pred_voxels.shape[1:4])))

plt.show()