In [None]:
import os
import re
import random

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tqdm import tqdm

In [None]:
def sorted_alphanumeric(data):
    """Return the items of ``data`` sorted in natural (human) order.

    Runs of ASCII digits are compared as integers and all other text is
    compared case-insensitively, so ``"2.npy"`` comes before ``"10.npy"``.

    Args:
        data: Iterable of strings, e.g. file names.

    Returns:
        A new list holding the same strings in natural order.
    """
    def alphanum_key(key):
        # re.split with a capturing group always yields text, digits, text, ...
        # so the digit runs sit at the odd positions. Deciding by position
        # (instead of str.isdigit()) keeps characters such as "²" -- for which
        # isdigit() is True but int() raises -- inside the text pieces.
        pieces = re.split('([0-9]+)', key)
        return [int(piece) if position % 2 else piece.lower()
                for position, piece in enumerate(pieces)]

    return sorted(data, key=alphanum_key)

In [None]:
def data_to_dense(voxel_data, voxel_map_shape=(128, 128, 128)):
    """Convert sparse voxel coordinates into dense occupancy grids.

    Args:
        voxel_data: Iterable of coordinate arrays, one per object. Each array
            is expected to hold one integer index per grid axis in every row.
        voxel_map_shape: Shape of every dense grid that is produced.

    Returns:
        A list with one ``float32`` array of shape ``voxel_map_shape`` per
        input entry; listed voxels are ``1.0``, all others ``0.0``.
    """
    dense_objs_voxel_data = []

    for obj_voxel_coordinates in voxel_data:
        voxel_map = np.zeros(voxel_map_shape, dtype=np.float32)
        coords = np.asarray(obj_voxel_coordinates)

        # An empty 1-D array would turn into the empty index tuple ``()``,
        # and ``voxel_map[()] = 1.0`` fills the *whole* grid. An object
        # without voxels must stay an empty grid instead.
        if coords.size:
            voxel_map[tuple(coords.T)] = 1.0

        dense_objs_voxel_data.append(voxel_map)

    return dense_objs_voxel_data

In [None]:
def rotate_obj_voxel_data(voxel_coordinates, rotation_angle, voxel_map_shape=(128, 128, 128)):
    """Rotate voxel indices around the grid centre, about the second axis.

    The second coordinate of every voxel is left unchanged; the first and
    third coordinates are rotated by ``rotation_angle``. Results are rounded
    to the nearest voxel and voxels that leave the grid are discarded.

    Args:
        voxel_coordinates: Array of shape ``(N, 3)`` with integer voxel indices.
        rotation_angle: Rotation angle in degrees.
        voxel_map_shape: Shape of the 3-D grid the indices refer to.

    Returns:
        Integer array of shape ``(M, 3)`` (``M <= N``) with the rotated
        indices that still lie inside the grid.
    """
    coords = np.asarray(voxel_coordinates)
    if coords.size == 0:
        # Nothing to rotate; keep the (rows, 3) layout callers expect.
        return np.empty((0, 3), dtype=int)

    theta = np.pi * (rotation_angle / 180)
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    rotation_matrix = np.array([[cos_t, 0, sin_t],
                                [0, 1, 0],
                                [-sin_t, 0, cos_t]])

    grid_shape = np.array(voxel_map_shape)

    # Valid indices run from 0 to N-1, so their centre is (N-1)/2. Pivoting on
    # N/2 shifts the result by one voxel: at 180 degrees index 0 would map to
    # N, fall outside the grid and be dropped.
    center = (grid_shape - 1) / 2

    rotated_coords = np.dot(coords - center, rotation_matrix.T) + center
    rotated_coords = np.round(rotated_coords).astype(int)

    inside_grid = np.all((rotated_coords >= 0) & (rotated_coords < grid_shape), axis=1)

    return rotated_coords[inside_grid]

In [None]:
class DataGenerator:
    """In-memory store of sparse voxel parts that is iterated one part at a time.

    The object is its own iterator: every step yields ``(index, [coordinates])``
    for a single part. The read position survives a ``break``, so a later
    ``for`` loop continues where the previous one stopped. The position is
    rewound when the data is exhausted and can be rewound with ``reset_index``.
    """

    def __init__(self, dataset_dir_pth, each_chair_parts_count_pth, objs_count=None, voxel_map_shape=(128, 128, 128)):
        """Load the part counts and every selected part file.

        Args:
            dataset_dir_pth: Directory whose entries are loaded with ``np.load``
                in natural (alphanumeric) name order.
            each_chair_parts_count_pth: ``.npy`` file holding the number of
                parts of each object.
            objs_count: Number of leading objects to use; ``None`` uses all
                counts and every file found in the directory.
            voxel_map_shape: Grid shape, stored for callers.
        """
        self.dataset_dir_pth = dataset_dir_pth

        self.each_chair_parts_count = np.load(each_chair_parts_count_pth)[:objs_count]
        self.num_objts = objs_count

        # ``None`` (no object limit) keeps every file in the directory.
        total_parts = self._get_total_parts_size()

        self.data_names = np.array(sorted_alphanumeric(os.listdir(self.dataset_dir_pth)), dtype=str)[:total_parts]
        self.num_parts = len(self.data_names)

        # Expose the number of objects actually in use. The raw ``objs_count``
        # may be None or larger than the number of stored counts, and later
        # code treats ``num_objts`` as a plain integer.
        self.num_objts = len(self.each_chair_parts_count)

        self.curr_index = 0
        self.indexes = np.arange(self.num_parts)

        self.voxel_map_shape = voxel_map_shape
        self.voxel_data_sparse = self._load_voxel_data()

    def _load_voxel_data(self):
        """Load every file listed in ``data_names`` and return them as a list."""
        return [
            np.load(os.path.join(self.dataset_dir_pth, data_name))
            for data_name in tqdm(self.data_names, desc="Loading Voxel Data")
        ]

    def _get_total_parts_size(self):
        """Return the total number of parts of the first ``num_objts`` objects.

        Returns ``None`` when ``num_objts`` is ``None``, which used as a slice
        bound means "no limit".
        """
        if self.num_objts is None:
            return None

        # Slicing clamps to the available entries, so a limit larger than the
        # number of stored counts no longer raises IndexError.
        return int(np.sum(self.each_chair_parts_count[:self.num_objts]))

    def __iter__(self):
        """Return ``self``; the read position is shared by all loops."""
        return self

    def _load_batched_sparse_data(self, indexes):
        """Return the sparse coordinate arrays for the given part indexes."""
        if self.voxel_data_sparse is not None:
            return [self.voxel_data_sparse[i] for i in indexes]

        # Fallback when nothing is cached: read the files from disk (slow).
        return [
            np.load(os.path.join(self.dataset_dir_pth, data_name))
            for data_name in self.data_names[indexes]
        ]

    def __next__(self):
        """Return ``(index, [coordinates])`` for the next part.

        Raises:
            StopIteration: After the last part; the position is rewound to 0
                first so the next loop starts from the beginning.
        """
        if self.curr_index >= len(self.indexes):
            self.curr_index = 0
            raise StopIteration

        index = self.curr_index
        sparse_data = self._load_batched_sparse_data([index])
        self.curr_index += 1

        return index, sparse_data

    def reset_index(self):
        """Rewind the read position to the first part."""
        self.curr_index = 0

In [None]:
class GeometryNetwork:
    """Decoder that maps latent codes to voxel logits, together with its training state.

    Besides the Keras model the object owns two tables of trainable latent
    codes (one row per part, one row per object), three Adam optimizers
    (network weights, part codes, object codes), a binary cross-entropy loss
    and a checkpoint manager that keeps only the most recent checkpoint.
    """

    def __init__(self, hparam):
        """Read the hyper-parameters and build every component.

        Args:
            hparam: Mapping providing the ``model_*`` keys read below.
        """
        self.latent_code_dim = hparam['model_latent_code_dim']
        self.fc_channels = hparam['model_fc_channels']
        self.fc_activation = hparam['model_fc_activation']
        self.norm_3d = hparam['model_norm_3d']
        self.conv_size = hparam['model_conv_size']
        self.num_latent_codes_parts = hparam['model_num_latent_codes_parts']
        self.num_latent_codes_objts = hparam['model_num_latent_codes_objts']
        self.learning_rate_network = hparam['model_learning_rate_network']
        self.learning_rate_codes = hparam['model_learning_rate_codes']
        self.checkpoint_dir = hparam['model_checkpoint_dir']

        # Order matters: the checkpoint tracks the model and the optimizers.
        self.init_model()
        self.init_optimizer()
        self.init_losser()
        self.init_checkpoint()

    def norm_layer(self, tensor, normalization):
        """Apply batch normalization when ``normalization`` is ``'batchnorm'`` (any case)."""
        wants_batchnorm = bool(normalization) and normalization.lower() == 'batchnorm'
        if not wants_batchnorm:
            return tensor
        return tf.keras.layers.BatchNormalization()(tensor)

    def conv_t_block_3d(self,
                        tensor,
                        num_filters,
                        size,
                        strides,
                        normalization=None,
                        dropout=False,
                        alpha_lrelu=0.2,
                        relu=True,
                        rate=0.7):
        """Transposed 3-D convolution followed by optional norm, LeakyReLU and dropout.

        Args:
            tensor: Input tensor of the block.
            num_filters: Number of output channels.
            size: Kernel size.
            strides: Stride of the transposed convolution.
            normalization: ``'batchnorm'`` enables batch normalization.
            dropout: Whether to append a dropout layer.
            alpha_lrelu: Negative slope of the LeakyReLU.
            relu: Whether to apply the LeakyReLU.
            rate: Dropout rate, only used when ``dropout`` is true.

        Returns:
            The output tensor of the block.
        """
        upsample = tf.keras.layers.Conv3DTranspose(num_filters,
                                                   size,
                                                   strides=strides,
                                                   padding='same',
                                                   kernel_initializer=tf.keras.initializers.glorot_normal(),
                                                   use_bias=False)

        out = self.norm_layer(upsample(tensor), normalization)

        if relu:
            out = tf.keras.layers.LeakyReLU(alpha=alpha_lrelu)(out)

        if dropout:
            out = tf.keras.layers.Dropout(rate)(out)

        return out

    def get_model(self):
        """Build the decoder: three dense layers, then seven stride-2 transposed convolutions.

        The dense output is reshaped to a 1x1x1 volume; every stride-2
        ``'same'`` transposed convolution doubles the spatial size, so the
        single-channel logit volume that is returned is 128x128x128.
        """
        with tf.name_scope('Network/'):
            latent_code = tf.keras.layers.Input(shape=(self.latent_code_dim,))

            with tf.name_scope('FC_layers'):
                features = latent_code
                for _ in range(3):
                    features = tf.keras.layers.Dense(self.fc_channels,
                                                     activation=self.fc_activation)(features)

                fc2_as_volume = tf.keras.layers.Reshape((1, 1, 1, self.fc_channels))(features)

            with tf.name_scope('GLO_VoxelDecoder'):
                volume = fc2_as_volume
                # Channel width of the six intermediate up-sampling stages.
                for num_filters in (32, 32, 32, 16, 8, 4):
                    volume = self.conv_t_block_3d(volume,
                                                  num_filters=num_filters,
                                                  size=self.conv_size,
                                                  strides=2,
                                                  normalization=self.norm_3d)

                # Final stage: one channel of raw logits, no norm or activation.
                conv_3D_transpose_out = tf.keras.layers.Conv3DTranspose(
                    1,
                    self.conv_size,
                    strides=2,
                    padding='same',
                    kernel_initializer=tf.keras.initializers.glorot_normal(),
                    use_bias=False)

                volume_out = conv_3D_transpose_out(volume)

        return tf.keras.Model(inputs=[latent_code], outputs=[volume_out])

    def init_model(self):
        """Create the decoder, the epoch counter and both latent-code tables."""
        self.model = self.get_model()

        self.latest_epoch = tf.Variable(0, trainable=False, dtype=tf.int64)

        parts_shape = (self.num_latent_codes_parts, self.latent_code_dim)
        self.latent_code_vars_parts = tf.Variable(tf.random.normal(parts_shape), trainable=True)

        objts_shape = (self.num_latent_codes_objts, self.latent_code_dim)
        self.latent_code_vars_objts = tf.Variable(tf.random.normal(objts_shape), trainable=True)

        self.trainable_variables = self.model.trainable_variables

    def init_optimizer(self):
        """Create one Adam optimizer for the network and one per latent-code table."""
        self.optimizer_network = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_network)
        self.optimizer_latent_for_parts = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_codes)
        self.optimizer_latent_for_objts = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_codes)

    def init_losser(self):
        """Create the binary cross-entropy loss (it is fed sigmoid outputs)."""
        self.losser_bce = tf.keras.losses.BinaryCrossentropy()

    def init_checkpoint(self):
        """Register everything that is saved, then restore the latest checkpoint if any."""
        tracked = dict(
            model=self.model,
            latent_code_vars_parts=self.latent_code_vars_parts,
            latent_code_vars_objts=self.latent_code_vars_objts,
            optimizer_network=self.optimizer_network,
            optimizer_latent_for_parts=self.optimizer_latent_for_parts,
            optimizer_latent_for_objts=self.optimizer_latent_for_objts,
            epoch=self.latest_epoch,
        )
        self.checkpoint = tf.train.Checkpoint(**tracked)

        self.manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                  directory=self.checkpoint_dir,
                                                  max_to_keep=1)

        self.load_checkpoint()

    def load_checkpoint(self):
        """Restore the newest checkpoint known to the manager, tolerating partial matches."""
        latest_checkpoint = self.manager.latest_checkpoint

        if latest_checkpoint is not None:
            print('Checkpoint {} restored'.format(latest_checkpoint))

        status = self.checkpoint.restore(latest_checkpoint)
        status.expect_partial()

    def _train_step(self, latent_code_vars, true_voxels, code_optimizer):
        """Run one optimization step on the network weights and the given latent code.

        NOTE(review): the model is called without ``training=True``; confirm
        that this is the intended mode for the BatchNormalization layers.
        NOTE(review): callers pass a newly created ``tf.Variable`` on every
        step; verify how optimizer state and ``tf.function`` tracing behave
        for that in the installed TensorFlow version.
        """
        with tf.GradientTape() as tape:
            pred_voxels = tf.sigmoid(self.model(latent_code_vars))
            loss = self.losser_bce(true_voxels, pred_voxels)

        network_vars = self.trainable_variables
        num_network_vars = len(network_vars)

        # One backward pass; the trailing gradient belongs to the latent code.
        gradients = tape.gradient(loss, network_vars + [latent_code_vars])

        self.optimizer_network.apply_gradients(zip(gradients[:num_network_vars], network_vars))
        code_optimizer.apply_gradients(zip(gradients[num_network_vars:], [latent_code_vars]))

        return loss

    @tf.function
    def train_step_parts(self, latent_code_vars, true_voxels):
        """Train on part voxels, updating the code with the part-code optimizer; returns the loss."""
        return self._train_step(latent_code_vars, true_voxels, self.optimizer_latent_for_parts)

    @tf.function
    def train_step_objts(self, latent_code_vars, true_voxels):
        """Train on object voxels, updating the code with the object-code optimizer; returns the loss."""
        return self._train_step(latent_code_vars, true_voxels, self.optimizer_latent_for_objts)

    def update_latent_code_vars_parts(self, latent_code_vars):
        """Overwrite the whole table of part codes."""
        self.latent_code_vars_parts.assign(latent_code_vars)

    def update_latent_code_vars_objts(self, latent_code_vars):
        """Overwrite the whole table of object codes."""
        self.latent_code_vars_objts.assign(latent_code_vars)

    def save_models(self, curr_epoch):
        """Record ``curr_epoch`` in the tracked counter and write a checkpoint."""
        self.latest_epoch.assign(curr_epoch)
        self.manager.save()

In [None]:
# Input locations, relative to the notebook's working directory. They are built
# with os.path.join because a literal ".\\dataset\\..." is only a valid path
# where "\\" is the separator; elsewhere it names a single non-existent file.
DATASET_DIR_PTH = os.path.join(".", "dataset", "partnet", "chair_voxel_data")

EACH_CHAIR_PARTS_COUNT_PTH = os.path.join(".", "dataset", "partnet", "each_chair_parts_count.npy")

# Number of objects (chairs) to load.
OBJS_COUNT = 4

# Shape of the dense occupancy grids.
VOXEL_MAP_SHAPE = (128, 128, 128)

# Epochs of the part stage; the object stage runs for the same number.
TRAINING_EPOCH_FOR_PARTS = 1000
TRAINING_EPOCH_FOR_OBJS = TRAINING_EPOCH_FOR_PARTS

In [None]:
# Load the parts of the first OBJS_COUNT objects into memory. The grid shape is
# passed explicitly so the loader always agrees with VOXEL_MAP_SHAPE.
data_generator = DataGenerator(
    dataset_dir_pth=DATASET_DIR_PTH,
    each_chair_parts_count_pth=EACH_CHAIR_PARTS_COUNT_PTH,
    objs_count=OBJS_COUNT,
    voxel_map_shape=VOXEL_MAP_SHAPE,
)

In [None]:
# One latent code per loaded part and one per loaded object.
num_part_codes = data_generator.num_parts
num_object_codes = data_generator.num_objts

model_hparam = {
    # Decoder architecture
    'model_latent_code_dim': 256,
    'model_fc_channels': 512,
    'model_fc_activation': 'relu',
    'model_norm_3d': 'batchnorm',
    'model_conv_size': 4,
    # Sizes of the two latent-code tables
    'model_num_latent_codes_parts': num_part_codes,
    'model_num_latent_codes_objts': num_object_codes,
    # Optimisation
    'model_learning_rate_network': 5e-4,
    'model_learning_rate_codes': 1e-3,
    # Checkpoints are looked up (and restored, when present) in this directory
    'model_checkpoint_dir': './ckpt'
}

geom_network = GeometryNetwork(model_hparam)

In [None]:
# Continue the epoch numbering from the counter tracked by the checkpoint.
latest_epoch = geom_network.latest_epoch.numpy()
first_epoch = latest_epoch + 1
last_epoch = latest_epoch + TRAINING_EPOCH_FOR_PARTS

pbar = tqdm(range(first_epoch, last_epoch + 1), desc="Training on Parts")

data_generator.reset_index()

for epoch in pbar:
    total_loss = []

    # NumPy view of all part codes for this epoch; written back to the
    # network once the epoch is finished.
    latent_code_vars_parts = geom_network.latent_code_vars_parts.numpy()

    for voxel_index, sparse_part in data_generator:
        # Dense grid of the single part, with a trailing channel axis added.
        true_voxels = np.expand_dims(data_to_dense(sparse_part), axis=-1)

        # The code of this part as a stand-alone variable the step can optimise.
        latent_code_vars = tf.Variable([latent_code_vars_parts[voxel_index]], trainable=True)

        loss = geom_network.train_step_parts(latent_code_vars, true_voxels).numpy()

        # Keep the optimised code.
        latent_code_vars_parts[voxel_index] = latent_code_vars.numpy()
        total_loss.append(loss)

    pbar.set_postfix({"Avg Loss": '{:.9f}'.format(sum(total_loss) / len(total_loss))})

    geom_network.update_latent_code_vars_parts(latent_code_vars_parts)

# Checkpointing is currently switched off:
# geom_network.save_models(epoch)

In [None]:
test_target_part = random.randint(0, data_generator.num_parts-1)

# NOTE(review): the random pick above is immediately replaced by a fixed part.
test_target_part = 0

# --- Ground truth ---------------------------------------------------------
true_voxels_coords = np.load(os.path.join(DATASET_DIR_PTH, str(test_target_part)+'.npy'))

# Turn the part by 180 degrees, densify it and move axis 1 to the end for plotting.
true_voxel_rotated_coordinates = rotate_obj_voxel_data(true_voxels_coords, rotation_angle=180)
true_voxel_rotated = data_to_dense([true_voxel_rotated_coordinates])[0]
true_voxel_rotated_plot = np.moveaxis(true_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(true_voxel_rotated_plot)

plt.show()

# --- Prediction -----------------------------------------------------------
latent_codes = tf.expand_dims(geom_network.latent_code_vars_parts[test_target_part], axis=0)

# Decode, squash to probabilities and keep voxels with probability >= 0.3.
pred_voxels = tf.sigmoid(geom_network.model(latent_codes))
pred_voxels = tf.cast(tf.math.greater_equal(pred_voxels, 0.3), tf.float32).numpy()

# Drop the batch axis and the channel axis.
pred_voxels = pred_voxels[0, :, :, :].reshape((pred_voxels.shape[1:4]))

# One row of (x, y, z) indices per occupied voxel.
pred_voxel_coordinates = np.argwhere(pred_voxels == 1)

pred_voxel_rotated_coordinates = rotate_obj_voxel_data(pred_voxel_coordinates, rotation_angle=180)
pred_voxel_rotated = data_to_dense([pred_voxel_rotated_coordinates])[0]
pred_voxel_rotated_plot = np.moveaxis(pred_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(pred_voxel_rotated_plot)

plt.show()

In [None]:
pbar = tqdm(range(1, TRAINING_EPOCH_FOR_OBJS+1), desc="Training on Objects")

# Initialise every object code with the mean of the codes of its parts.
# Parts are stored object after object, so each object owns a contiguous range.
latent_code_vars_parts = geom_network.latent_code_vars_parts.numpy()
latent_code_vars_objts = []

base_index = 0

for obj_index, count in enumerate(data_generator.each_chair_parts_count):
    latent_code_vars_objts.append(np.mean(latent_code_vars_parts[base_index:base_index+count], axis=0))
    base_index += count

geom_network.update_latent_code_vars_objts(np.stack(latent_code_vars_objts))

data_generator.reset_index()

for epoch in pbar:
    total_loss = []

    latent_code_vars_objts = geom_network.latent_code_vars_objts.numpy()

    # Running total of parts up to and including the current object.
    range_count = 0

    for obj_index, count in enumerate(data_generator.each_chair_parts_count):
        range_count += count

        latent_code_vars = tf.Variable(tf.expand_dims(latent_code_vars_objts[obj_index], axis=0), trainable=True)

        # Union of all parts of this object in one dense grid.
        true_voxels = np.zeros(shape=(VOXEL_MAP_SHAPE), dtype=np.float32)

        # The generator keeps its position between loops, so each object
        # continues reading where the previous object stopped.
        for voxel_index, sparse_data in data_generator:
            part_coords = np.asarray(sparse_data[0])
            if part_coords.size:
                # Mark all voxels of the part at once.
                true_voxels[tuple(part_coords.T)] = 1

            if voxel_index == range_count-1:
                if voxel_index == data_generator.num_parts-1:
                    # Last part of the last object: rewind for the next epoch.
                    data_generator.reset_index()
                break

        true_voxels = np.expand_dims(true_voxels, axis=(0, -1))

        loss = geom_network.train_step_objts(latent_code_vars, true_voxels).numpy()

        # Keep the optimised code. Without this write-back the assign at the
        # end of the epoch would store the unchanged codes again and the
        # object codes would never be trained.
        latent_code_vars_objts[obj_index] = latent_code_vars.numpy()[0]

        total_loss.append(loss)

    pbar.set_postfix({"Avg Loss": '{:.9f}'.format(sum(total_loss) / len(total_loss))})

    geom_network.update_latent_code_vars_objts(latent_code_vars_objts)

In [None]:
test_target_obj = random.randint(0, data_generator.num_objts-1)

# NOTE(review): the random pick above is immediately replaced by a fixed object.
test_target_obj = 1

parts_per_object = data_generator.each_chair_parts_count

# Index of the first part file belonging to the chosen object.
base_index = 0

for i in range(test_target_obj):
    base_index += parts_per_object[i]

# --- Ground truth: union of all parts of the object -------------------------
true_voxels = np.zeros(shape=(VOXEL_MAP_SHAPE), dtype=np.int32)

end_index = base_index+parts_per_object[test_target_obj]

for i in range(base_index, end_index):
    true_voxels_coords = np.load(os.path.join(DATASET_DIR_PTH, str(i)+'.npy'))

    for x, y, z in true_voxels_coords:
        true_voxels[x, y, z] = 1

# One row of (x, y, z) indices per occupied voxel.
true_voxel_coordinates = np.argwhere(true_voxels == 1)

# Turn the object by 180 degrees, densify it and move axis 1 to the end for plotting.
true_voxel_rotated_coordinates = rotate_obj_voxel_data(true_voxel_coordinates, rotation_angle=180)
true_voxel_rotated = data_to_dense([true_voxel_rotated_coordinates])[0]
true_voxel_rotated_plot = np.moveaxis(true_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(true_voxel_rotated_plot)

plt.show()

# plt.savefig('./{}_true_voxel'.format(str(test_target_obj)) + '.png', dpi=300)

# --- Prediction -------------------------------------------------------------
latent_codes = tf.expand_dims(geom_network.latent_code_vars_objts[test_target_obj], axis=0)

# Decode, squash to probabilities and keep voxels with probability >= 0.3.
pred_voxels = tf.sigmoid(geom_network.model(latent_codes))
pred_voxels = tf.cast(tf.math.greater_equal(pred_voxels, 0.3), tf.float32).numpy()

# Drop the batch axis and the channel axis.
pred_voxels = pred_voxels[0, :, :, :].reshape((pred_voxels.shape[1:4]))

pred_voxel_coordinates = np.argwhere(pred_voxels == 1)

pred_voxel_rotated_coordinates = rotate_obj_voxel_data(pred_voxel_coordinates, rotation_angle=180)
pred_voxel_rotated = data_to_dense([pred_voxel_rotated_coordinates])[0]
pred_voxel_rotated_plot = np.moveaxis(pred_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(pred_voxel_rotated_plot)

plt.show()

# plt.savefig('./{}_pred_voxel'.format(str(test_target_obj)) + '.png', dpi=300)