In [None]:
import os
import re
import random

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import imageio.v2 as imageio

from tqdm import tqdm

In [None]:
def sorted_alphanumeric(data):
    """Return ``data`` sorted in natural (human) order.

    Each string is split into runs of digits and non-digits; digit runs compare
    as integers and everything else compares case-insensitively, so ``'2.npy'``
    sorts before ``'10.npy'``.
    """
    def natural_key(name):
        chunks = re.split('([0-9]+)', name)
        return [int(chunk) if chunk.isdigit() else chunk.lower() for chunk in chunks]

    return sorted(data, key=natural_key)

In [None]:
def data_to_dense(voxel_data, voxel_map_shape=(128, 128, 128), value=1.0):
    """Rasterise sparse voxel coordinate arrays into dense grids.

    Args:
        voxel_data: iterable of integer coordinate arrays, one per object. Each
            array is transposed and used as a fancy index, so for a 3-D grid it
            is expected to be (N, 3).
        voxel_map_shape: shape of every dense grid produced.
        value: value written at each listed coordinate; all other cells are 0.

    Returns:
        A list of float32 arrays of shape ``voxel_map_shape``, in input order.
    """
    def rasterise(coords):
        grid = np.zeros(voxel_map_shape, dtype=np.float32)
        grid[tuple(coords.T)] = value
        return grid

    return [rasterise(coords) for coords in voxel_data]

In [None]:
def rotate_obj_along_axis(voxel_coordinates, rotation_angle, axis, voxel_map_shape=(128, 128, 128)):
    """Rotate integer voxel coordinates about one axis through the grid centre.

    Args:
        voxel_coordinates: (N, 3) array of voxel indices.
        rotation_angle: rotation in degrees.
        axis: one of ``'x'``, ``'y'`` or ``'z'``.
        voxel_map_shape: shape of the voxel grid the coordinates live in.

    Returns:
        (M, 3) integer array of rotated coordinates rounded to whole voxels,
        M <= N; points that land outside the grid are dropped.

    Raises:
        ValueError: if ``axis`` is not ``'x'``, ``'y'`` or ``'z'``.
    """
    theta = np.radians(rotation_angle)
    cos_t, sin_t = np.cos(theta), np.sin(theta)

    if axis == 'x':
        rot_matrix = np.array([[1, 0, 0],
                               [0, cos_t, -sin_t],
                               [0, sin_t, cos_t]])
    elif axis == 'y':
        rot_matrix = np.array([[cos_t, 0, sin_t],
                               [0, 1, 0],
                               [-sin_t, 0, cos_t]])
    elif axis == 'z':
        rot_matrix = np.array([[cos_t, -sin_t, 0],
                               [sin_t, cos_t, 0],
                               [0, 0, 1]])
    else:
        raise ValueError("Invalid axis. Must be 'x', 'y', or 'z'.")

    coords = np.asarray(voxel_coordinates, dtype=np.float64)
    if coords.size == 0:
        return np.empty((0, 3), dtype=int)

    grid_shape = np.asarray(voxel_map_shape)
    # Voxel indices run 0..S-1, so the grid's midpoint is (S-1)/2. Rotating about
    # S/2 shifted every result by half a voxel: a 180 degree turn sent index 0 to S,
    # outside the grid, so the boundary slice was silently dropped.
    center = (grid_shape - 1) / 2

    rotated_coords = (coords - center) @ rot_matrix.T + center
    rotated_coords = np.round(rotated_coords).astype(int)

    in_bounds = np.all((rotated_coords >= 0) & (rotated_coords < grid_shape), axis=1)
    return rotated_coords[in_bounds]

In [None]:
@tf.function
def tf_rotate_obj_along_axis(voxel_coordinates, rotation_angle, axis, voxel_map_shape=(128, 128, 128)):
    """Graph-mode counterpart of ``rotate_obj_along_axis``.

    Args:
        voxel_coordinates: (N, 3) tensor of voxel indices (cast to float32).
        rotation_angle: rotation in degrees (scalar tensor or Python number).
        axis: Python string ``'x'``, ``'y'`` or ``'z'``; resolved at trace time.
        voxel_map_shape: shape of the voxel grid the coordinates live in.

    Returns:
        (M, 3) float32 tensor of rotated coordinates rounded to whole voxels,
        keeping only those that land inside the grid.

    Raises:
        ValueError: at trace time, if ``axis`` is not ``'x'``, ``'y'`` or ``'z'``.
    """
    theta = tf.cast(rotation_angle, tf.float32) * (np.pi / 180)
    cos_t = tf.cos(theta)
    sin_t = tf.sin(theta)
    # Typed constants so every matrix entry is a float32 scalar tensor.
    zero = tf.zeros_like(cos_t)
    one = tf.ones_like(cos_t)

    if axis == 'x':
        rows = [[one, zero, zero], [zero, cos_t, -sin_t], [zero, sin_t, cos_t]]
    elif axis == 'y':
        rows = [[cos_t, zero, sin_t], [zero, one, zero], [-sin_t, zero, cos_t]]
    elif axis == 'z':
        rows = [[cos_t, -sin_t, zero], [sin_t, cos_t, zero], [zero, zero, one]]
    else:
        raise ValueError("Invalid axis. Must be 'x', 'y', or 'z'.")
    rot_matrix = tf.stack([tf.stack(row) for row in rows])

    grid_shape = tf.constant(voxel_map_shape, dtype=tf.float32)
    # Indices run 0..S-1, so the grid's midpoint is (S-1)/2. Rotating about S/2
    # shifted every result by half a voxel (a 180 degree turn sent index 0 to S,
    # outside the grid). Broadcasting against (N, 3) replaces the explicit tile.
    center = (grid_shape - 1.0) / 2.0

    coords = tf.cast(voxel_coordinates, tf.float32)
    rotated_coords = tf.matmul(coords - center, rot_matrix, transpose_b=True) + center
    rotated_coords = tf.round(rotated_coords)

    in_bounds = tf.reduce_all((rotated_coords >= 0) & (rotated_coords < grid_shape), axis=1)
    return tf.boolean_mask(rotated_coords, in_bounds)

In [None]:
@tf.function
def tf_project_to_silhouette(voxels, img_wh=128):
    """Project randomly rotated voxel volumes onto the x-y plane as binary images.

    For every batch item, voxels with value >= 0.5 are rotated by three
    independently drawn angles in [0, 360) degrees, about x, then y, then z.
    The surviving (x, y) positions are then written as 1s into an
    ``img_wh x img_wh`` image (row = y, column = x).

    Args:
        voxels: tensor of shape (batch, D, H, W, 1); the trailing channel axis is
            squeezed away.
        img_wh: side length of the output image.

    Returns:
        int32 tensor of shape (batch, img_wh, img_wh).

    Note:
        Built from tf.where / tf.round / integer scatters, so no gradient flows
        back to ``voxels``.
    """
    voxels = tf.squeeze(voxels, axis=-1)

    # Rotate about the centre of the grid the voxels actually live in, instead of
    # the rotation helper's hard-coded 128^3 default.
    voxel_map_shape = tuple(voxels.shape[1:])
    if any(dim is None for dim in voxel_map_shape):
        voxel_map_shape = (128, 128, 128)
    axes = ('x', 'y', 'z')

    batch_size = tf.shape(voxels)[0]
    projection_planes = tf.TensorArray(tf.int32, size=batch_size, dynamic_size=False, infer_shape=False)

    for i in tf.range(batch_size):
        angles = tf.random.uniform((3,), 0, 360)

        rotated_coords = tf.cast(tf.where(voxels[i] >= 0.5), tf.float32)
        for j, axis_name in enumerate(axes):
            rotated_coords = tf_rotate_obj_along_axis(rotated_coords, angles[j], axis_name, voxel_map_shape)

        x = rotated_coords[:, 0]
        y = rotated_coords[:, 1]

        # The rotated grid may be larger than the image, so clip to img_wh.
        valid_coords = tf.logical_and(tf.logical_and(x < img_wh, x >= 0),
                                      tf.logical_and(y < img_wh, y >= 0))

        x_valid = tf.boolean_mask(x, valid_coords)
        y_valid = tf.boolean_mask(y, valid_coords)

        indices = tf.cast(tf.stack((y_valid, x_valid), axis=1), tf.int32)
        updates = tf.ones_like(x_valid, dtype=tf.int32)
        projection_plane = tf.tensor_scatter_nd_update(tf.zeros((img_wh, img_wh), dtype=tf.int32), indices, updates)
        projection_planes = projection_planes.write(i, projection_plane)

    return projection_planes.stack()

In [None]:
class DataGenerator:
    """Iterate over per-part sparse voxel arrays in fixed-size batches.

    Every entry of ``dataset_dir_pth`` is loaded with ``np.load`` up front, in
    natural filename order. When ``objs_count`` is given, only the first
    ``sum(each_chair_parts_count[:objs_count])`` entries are kept.

    Iteration yields ``(indexes, sparse_data)`` pairs: the list of part positions
    in the batch and the corresponding loaded arrays. Exhausting the iterator
    resets it to the start.
    """

    def __init__(self, dataset_dir_pth, each_chair_parts_count_pth, objs_count=None, voxel_map_shape=(128, 128, 128), batch_size=4):
        """Load the part counts and all selected part arrays.

        Raises:
            ValueError: if ``batch_size`` is smaller than 1 (the iterator would
                never advance).
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer.")

        self.dataset_dir_pth = dataset_dir_pth

        # Number of parts belonging to each object, limited to the first `objs_count` objects.
        self.each_chair_parts_count = np.load(each_chair_parts_count_pth)[:objs_count]
        self._objs_count = objs_count
        # Always an int. It used to stay None when objs_count was None, which broke
        # callers that size tensors with it. Equals objs_count whenever the counts
        # file holds at least that many objects.
        self.num_objts = len(self.each_chair_parts_count)

        self.data_names = np.array(sorted_alphanumeric(os.listdir(self.dataset_dir_pth)), dtype=str)[:self._get_total_parts_size()]
        self.num_parts = len(self.data_names)

        self.curr_index = 0
        self.indexes = np.arange(self.num_parts)

        self.voxel_map_shape = voxel_map_shape

        self.batch_size = batch_size

        self.voxel_data_sparse = self._load_voxel_data()

    @property
    def batch_szie(self):
        """Backward-compatible alias for the previously misspelled attribute name."""
        return self.batch_size

    @batch_szie.setter
    def batch_szie(self, value):
        self.batch_size = value

    def _get_total_parts_size(self):
        """Number of part files to keep, or None (keep every file) when no object limit was given."""
        if self._objs_count is None:
            return None
        # each_chair_parts_count is already truncated to the first objs_count objects.
        return int(np.sum(self.each_chair_parts_count))

    def _load_voxel_data(self):
        """Load every selected part file into memory, in data_names order."""
        return [np.load(os.path.join(self.dataset_dir_pth, data_name))
                for data_name in tqdm(self.data_names, desc="Loading Voxel Data")]

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next ``(indexes, sparse_data)`` batch; the last batch may be short."""
        if self.curr_index >= len(self.indexes):
            self.curr_index = 0
            raise StopIteration

        stop = min(self.curr_index + self.batch_size, len(self.indexes))
        indexes = list(range(self.curr_index, stop))

        sparse_data = self._load_batched_sparse_data(indexes)

        self.curr_index += self.batch_size

        return indexes, sparse_data

    def _load_batched_sparse_data(self, indexes):
        """Return the in-memory sparse arrays at the given part positions."""
        return [self.voxel_data_sparse[i] for i in indexes]

    def reset_index(self):
        """Rewind iteration to the first batch."""
        self.curr_index = 0

In [None]:
# Dataset locations. Built with os.path.join so they resolve on any OS; the
# previous hard-coded "\\" separators only worked on Windows.
EACH_CHAIR_PARTS_COUNT_PTH = os.path.join(".", "dataset", "each_chair_parts_count.npy")
DATASET_DIR_PTH = os.path.join(".", "dataset", "chair_voxel_data")

# Number of objects (chairs) passed to DataGenerator as `objs_count`.
LOAD_OBJS_COUNT = 1
# Dense voxel grid resolution.
VOXEL_MAP_SHAPE = (128, 128, 128)

BATCH_SIZE = 1

In [None]:
# Part-level data source for every training stage below; its constructor loads
# the selected parts' sparse voxel arrays into memory.
data_generator = DataGenerator(
    dataset_dir_pth=DATASET_DIR_PTH,
    each_chair_parts_count_pth=EACH_CHAIR_PARTS_COUNT_PTH,
    objs_count=LOAD_OBJS_COUNT,
    voxel_map_shape=VOXEL_MAP_SHAPE,
    batch_size=BATCH_SIZE,
)

In [None]:
class GeometryNetwork:
    """Voxel decoders for parts and whole objects, trained jointly with free latent codes.

    Builds two generators with the same architecture, ``part_generator`` and
    ``objt_generator``. Each maps a ``model_latent_code_dim`` vector to a
    (batch, 128, 128, 128, 1) volume of voxel logits.

    The per-part and per-object latent codes live in trainable ``tf.Variable``
    tables. Each generator and each code table has its own Adam optimizer.
    Everything is tracked by a ``tf.train.CheckpointManager``, whose latest
    checkpoint (if any) is restored on construction.

    Expected ``hparam`` keys: ``model_latent_code_dim``, ``model_fc_channels``,
    ``model_conv_size``, ``model_num_latent_codes_parts``,
    ``model_num_latent_codes_objts``, ``model_learning_rate_network``,
    ``model_learning_rate_codes``, ``model_voxel_map_shape``,
    ``model_checkpoint_dir`` and ``modelramdom_projection_num`` (sic).
    """

    def __init__(self, hparam):
        """Read hyper-parameters, then build models, optimizers, loss and checkpoint."""
        self.latent_code_dim = hparam['model_latent_code_dim']

        # Width of the three dense layers in front of the decoder.
        self.fc_channels = hparam['model_fc_channels']

        # Kernel size of every transposed 3-D convolution.
        self.conv_size = hparam['model_conv_size']

        # Rows of the part / object latent-code tables.
        self.num_latent_codes_parts = hparam['model_num_latent_codes_parts']
        self.num_latent_codes_objts = hparam['model_num_latent_codes_objts']

        self.learning_rate_network = hparam['model_learning_rate_network']
        self.learning_rate_codes = hparam['model_learning_rate_codes']

        self.model_voxel_map_shape = hparam['model_voxel_map_shape']

        self.checkpoint_dir = hparam['model_checkpoint_dir']

        # Number of random silhouette projections averaged in train_step_objts.
        # The key and attribute keep their original spelling for compatibility.
        self.ramdom_projection_num = hparam['modelramdom_projection_num']

        self._init_model()
        self._init_optimizer()
        self._init_losser()
        self._init_checkpoint()

    def _init_model(self):
        """Create both generators and the normally-initialised latent-code tables."""
        self.part_generator = self._get_generator()
        self.objt_generator = self._get_generator()

        init_latent_code_parts = tf.random.normal((self.num_latent_codes_parts, self.latent_code_dim))
        self.latent_code_vars_parts = tf.Variable(init_latent_code_parts, trainable=True)

        init_latent_code_objts = tf.random.normal((self.num_latent_codes_objts, self.latent_code_dim))
        self.latent_code_vars_objts = tf.Variable(init_latent_code_objts, trainable=True)

        self.part_generator_trainable_variables = self.part_generator.trainable_variables
        self.objt_generator_trainable_variables = self.objt_generator.trainable_variables

    def _get_generator(self):
        """Build one decoder: latent code -> (batch, 128, 128, 128, 1) logits.

        Three ReLU dense layers feed a 1x1x1 volume, which seven stride-2
        transposed 3-D convolutions upsample (1 -> 2 -> ... -> 128). The last
        convolution has a single channel and no activation.
        """
        with tf.name_scope('Network/'):

            latent_code = tf.keras.layers.Input(shape=(self.latent_code_dim,))

            with tf.name_scope('FC_layers'):

                fc0 = tf.keras.layers.Dense(self.fc_channels, activation='relu')(latent_code)

                fc1 = tf.keras.layers.Dense(self.fc_channels, activation='relu')(fc0)

                fc2 = tf.keras.layers.Dense(self.fc_channels, activation='relu')(fc1)

                fc2_as_volume = tf.keras.layers.Reshape((1, 1, 1, self.fc_channels))(fc2)

            with tf.name_scope('GLO_VoxelDecoder'):

                decoder_1 = self._conv_t_block_3d(fc2_as_volume, num_filters=32, size=self.conv_size, strides=2)

                decoder_2 = self._conv_t_block_3d(decoder_1, num_filters=32, size=self.conv_size, strides=2)

                decoder_3 = self._conv_t_block_3d(decoder_2, num_filters=32, size=self.conv_size, strides=2)

                decoder_4 = self._conv_t_block_3d(decoder_3, num_filters=16, size=self.conv_size, strides=2)

                decoder_5 = self._conv_t_block_3d(decoder_4, num_filters=8, size=self.conv_size, strides=2)

                decoder_6 = self._conv_t_block_3d(decoder_5, num_filters=4, size=self.conv_size, strides=2)

                volume_out = self._conv_t_block_3d(decoder_6, num_filters=1, size=self.conv_size, strides=2, output_mode=True)

        return tf.keras.Model(inputs=[latent_code], outputs=[volume_out])

    def _conv_t_block_3d(self, tensor, num_filters, size, strides, alpha_lrelu=0.2, output_mode=False):
        """Bias-free transposed 3-D convolution.

        Unless ``output_mode`` is set, the convolution is followed by batch
        normalisation and a LeakyReLU.
        """
        conv_3D_transpose = tf.keras.layers.Conv3DTranspose(
            filters=num_filters,
            kernel_size=size,
            strides=strides,
            padding='same',
            kernel_initializer=tf.keras.initializers.glorot_normal(),
            use_bias=False
        )

        tensor = conv_3D_transpose(tensor)

        if output_mode:
            return tensor

        tensor = tf.keras.layers.BatchNormalization()(tensor)

        tensor = tf.keras.layers.LeakyReLU(alpha=alpha_lrelu)(tensor)

        return tensor

    def _init_optimizer(self):
        """One Adam optimizer per generator and per latent-code table."""
        self.optimizer_part_generator = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_network)
        self.optimizer_objt_generator = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_network)
        self.optimizer_latent_for_parts = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_codes)
        self.optimizer_latent_for_objts = tf.keras.optimizers.Adam(learning_rate=self.learning_rate_codes)

    def _init_losser(self):
        """Binary cross-entropy on probabilities (callers apply the sigmoid first)."""
        self.losser_bce = tf.keras.losses.BinaryCrossentropy()

    def _init_checkpoint(self):
        """Track models, code tables and optimizers; keep only the newest checkpoint, then restore."""
        self.checkpoint = tf.train.Checkpoint(
            part_generator=self.part_generator,
            objt_generator=self.objt_generator,
            latent_code_vars_parts=self.latent_code_vars_parts,
            latent_code_vars_objts=self.latent_code_vars_objts,
            optimizer_part_generator=self.optimizer_part_generator,
            optimizer_objt_generator=self.optimizer_objt_generator,
            optimizer_latent_for_parts=self.optimizer_latent_for_parts,
            optimizer_latent_for_objts=self.optimizer_latent_for_objts
        )

        self.manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                  directory=self.checkpoint_dir,
                                                  max_to_keep=1)

        self._load_checkpoint()

    def _load_checkpoint(self):
        """Restore the manager's latest checkpoint, reporting whether one was found."""
        latest_checkpoint = self.manager.latest_checkpoint

        if latest_checkpoint is not None:
            print('Checkpoint {} restored'.format(latest_checkpoint))
        else:
            print('No checkpoint was restored.')

        self.checkpoint.restore(latest_checkpoint).expect_partial()

    @tf.function
    def train_step_parts(self, latent_code_vars, true_voxels_part):
        """One joint optimisation step of ``part_generator`` and the given part codes.

        Args:
            latent_code_vars: tf.Variable of shape (batch, latent_code_dim).
            true_voxels_part: target occupancies with the generator's output shape.

        Returns:
            Scalar BCE between the targets and sigmoid(logits).
        """
        # NOTE(review): callers create a fresh tf.Variable for every step; confirm
        # the installed Keras optimizer accepts variables it has not seen before.
        with tf.GradientTape() as tape:
            pred_logits_voxels = self.part_generator(latent_code_vars)

            pred_voxels_part = tf.sigmoid(pred_logits_voxels)

            loss = self.losser_bce(true_voxels_part, pred_voxels_part)

        network_vars = self.part_generator_trainable_variables
        gradients = tape.gradient(loss, network_vars + [latent_code_vars])

        self.optimizer_part_generator.apply_gradients(zip(gradients[:len(network_vars)], network_vars))
        self.optimizer_latent_for_parts.apply_gradients(zip(gradients[len(network_vars):], [latent_code_vars]))

        return loss

    @tf.function
    def train_step_objts(self, latent_code_vars, true_voxels_objt):
        """One joint optimisation step of ``objt_generator`` and the given object codes.

        The loss is BCE(true, sigmoid(logits)) plus the mean, over
        ``ramdom_projection_num`` random projections, of the L2 distance between
        the true and predicted silhouettes.

        Args:
            latent_code_vars: tf.Variable of shape (batch, latent_code_dim).
            true_voxels_objt: target occupancies with the generator's output shape.

        Returns:
            Scalar total loss.
        """
        with tf.GradientTape() as tape:
            pred_logits_voxels = self.objt_generator(latent_code_vars)

            pred_voxels_objt = tf.sigmoid(pred_logits_voxels)

            bce_loss = self.losser_bce(true_voxels_objt, pred_voxels_objt)

            # NOTE(review): tf_project_to_silhouette draws its own random angles on
            # every call, so true and pred are projected from different viewpoints.
            # It is also built from tf.where/round/int scatters, so this term adds
            # no gradient (it only changes the reported loss).
            silhouette_losses = []
            for _ in range(self.ramdom_projection_num):
                true_silhouette = tf_project_to_silhouette(true_voxels_objt)
                pred_silhouette = tf_project_to_silhouette(pred_voxels_objt)
                # Norm of the difference. The old code took the norm of the
                # (true, pred) pair itself, which is non-zero even when they match.
                silhouette_diff = tf.cast(true_silhouette - pred_silhouette, tf.float32)
                silhouette_losses.append(tf.math.reduce_euclidean_norm(silhouette_diff))

            if silhouette_losses:
                silhouette_loss = tf.reduce_mean(tf.stack(silhouette_losses))
            else:
                silhouette_loss = tf.constant(0.0, dtype=tf.float32)

            loss = bce_loss + silhouette_loss

        network_vars = self.objt_generator_trainable_variables
        gradients = tape.gradient(loss, network_vars + [latent_code_vars])

        self.optimizer_objt_generator.apply_gradients(zip(gradients[:len(network_vars)], network_vars))
        self.optimizer_latent_for_objts.apply_gradients(zip(gradients[len(network_vars):], [latent_code_vars]))

        return loss

    @tf.function
    def train_step_assamble(self, latent_code_vars, true_voxels_objt, model='part'):
        """Fit part codes so that their summed decoded parts match one object.

        Only ``latent_code_vars`` is updated; ``part_generator`` stays fixed.
        ``model`` is accepted but currently unused.

        Args:
            latent_code_vars: tf.Variable of shape (num_parts, latent_code_dim).
            true_voxels_objt: target occupancy for one object, without a batch axis.

        Returns:
            ``(pred_voxels, loss)``: the per-part probabilities and the scalar BCE.
        """
        with tf.GradientTape() as tape:
            pred_logits_voxels = self.part_generator(latent_code_vars)

            pred_voxels = tf.sigmoid(pred_logits_voxels)

            # NOTE(review): the summed occupancy can exceed 1; check how
            # BinaryCrossentropy treats such inputs.
            pred_voxels_objt = tf.math.reduce_sum(pred_voxels, axis=0)

            loss = self.losser_bce(true_voxels_objt, pred_voxels_objt)

        gradients = tape.gradient(loss, [latent_code_vars])

        self.optimizer_latent_for_parts.apply_gradients(zip(gradients, [latent_code_vars]))

        return pred_voxels, loss

    def update_latent_code_vars_parts(self, latent_code_vars):
        """Overwrite the part latent-code table in place."""
        self.latent_code_vars_parts.assign(latent_code_vars)

    def update_latent_code_vars_objts(self, latent_code_vars):
        """Overwrite the object latent-code table in place."""
        self.latent_code_vars_objts.assign(latent_code_vars)

    def save_models(self):
        """Write a checkpoint of everything tracked by ``self.checkpoint``."""
        self.manager.save()

In [None]:
# Hyper-parameters for GeometryNetwork. The latent tables get one row per loaded
# part and one per loaded object.
model_hparam = {
    'model_latent_code_dim': 256,
    'model_fc_channels': 512,
    'model_conv_size': 4,
    'model_num_latent_codes_parts': data_generator.num_parts,
    # Counted from the loaded per-object part counts, so this is an int even when
    # the generator was built with objs_count=None.
    'model_num_latent_codes_objts': len(data_generator.each_chair_parts_count),
    'model_learning_rate_network': 5e-4,
    'model_learning_rate_codes': 1e-3,
    'model_voxel_map_shape': VOXEL_MAP_SHAPE,
    'model_checkpoint_dir': './ckpt',
    # Key spelling must match what GeometryNetwork.__init__ reads.
    'modelramdom_projection_num': 5
}

geom_network = GeometryNetwork(model_hparam)

In [None]:
# Stage 1: fit part_generator together with one latent code per part.
TRAINING_EPOCH_FOR_PARTS = 1500

pbar = tqdm(range(1, TRAINING_EPOCH_FOR_PARTS+1), desc="Training on Parts")

# Start from the first batch even if an earlier cell left the generator mid-iteration.
data_generator.reset_index()

for epoch in pbar:
    total_loss = []
    
    # Host copy of the part latent table; optimised rows are written back per
    # batch and pushed into the model once per epoch.
    latent_code_vars_parts = geom_network.latent_code_vars_parts.numpy()
    
    # voxel_index is the list of part positions in this batch; true_voxels holds
    # their sparse coordinate arrays.
    for voxel_index, true_voxels in data_generator:
        # Uses data_to_dense's default 128^3 grid, not VOXEL_MAP_SHAPE.
        true_voxels = data_to_dense(true_voxels)
                                
        # NOTE(review): a new tf.Variable is created for every batch and handed to
        # the same latent-code optimizer; confirm the installed TF/Keras optimizer
        # accepts variables it has not seen before.
        latent_code_vars = tf.Variable(latent_code_vars_parts[voxel_index], trainable=True)
        
        # Add the single trailing channel the generator outputs.
        true_voxels = np.expand_dims(true_voxels, axis=-1)
                
        loss = geom_network.train_step_parts(latent_code_vars, true_voxels).numpy()
        
        latent_code_vars_parts[voxel_index] = latent_code_vars.numpy()
        
        total_loss.append(loss)
                
    pbar.set_postfix({"Avg Loss": '{:.9f}'.format(sum(total_loss) / len(total_loss))})
    
    geom_network.update_latent_code_vars_parts(latent_code_vars_parts)
    
# Initialise the object decoder from the trained part decoder (both are built by
# the same _get_generator, so the weight lists line up).
geom_network.objt_generator.set_weights(geom_network.part_generator.get_weights())

geom_network.save_models()

In [None]:
# Stage-1 sanity check: show one ground-truth part next to the part decoder's
# reconstruction of it, both turned 180 degrees about y for viewing.

# test_target_part = random.randint(0, data_generator.num_parts-1)
test_target_part = 0

# Ground truth is read straight from "<index>.npy" in the dataset folder.
true_voxels_coords = np.load(os.path.join(DATASET_DIR_PTH, str(test_target_part)+'.npy'))
true_voxel_rotated_coords = rotate_obj_along_axis(true_voxels_coords, rotation_angle=180, axis='y')
true_voxel_rotated = data_to_dense([true_voxel_rotated_coords])[0]
# Reorder axes (x, y, z) -> (x, z, y) before plotting.
true_voxel_rotated_plot = np.moveaxis(true_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(true_voxel_rotated_plot)
plt.show()

# Decode the part's learned latent code; keep voxels with probability >= 0.3 and
# drop the batch and channel axes.
latent_codes = tf.expand_dims(geom_network.latent_code_vars_parts[test_target_part], axis=0)
pred_logits_voxels = geom_network.part_generator(latent_codes)
pred_voxels = tf.cast(tf.sigmoid(pred_logits_voxels) >= 0.3, tf.float32).numpy()[0, ..., 0]

# (N, 3) coordinates of the occupied voxels.
x, y, z = np.nonzero(pred_voxels)
pred_voxel_coords = np.stack((x, y, z), axis=1)

pred_voxel_rotated_coords = rotate_obj_along_axis(pred_voxel_coords, rotation_angle=180, axis='y')
pred_voxel_rotated = data_to_dense([pred_voxel_rotated_coords])[0]
pred_voxel_rotated_plot = np.moveaxis(pred_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(pred_voxel_rotated_plot)
plt.show()

In [None]:
# Stage 2: fine-tune objt_generator (initialised from the part decoder) on whole
# objects, i.e. the union of each object's parts.
TRAINING_EPOCH_FOR_OBJTS = 3000

pbar = tqdm(range(1, TRAINING_EPOCH_FOR_OBJTS+1), desc="Training on Objects")

# Parts are stored object by object, so object k owns the part positions in
# [part_offsets[k], part_offsets[k + 1]).
part_offsets = np.concatenate(([0], np.cumsum(data_generator.each_chair_parts_count))).astype(int)
object_part_ranges = list(zip(part_offsets[:-1], part_offsets[1:]))

# Initialise each object's latent code as the mean of its parts' latent codes.
latent_code_vars_parts = geom_network.latent_code_vars_parts.numpy()
latent_code_vars_objts = np.stack(
    [np.mean(latent_code_vars_parts[start:end], axis=0) for start, end in object_part_ranges]
)
geom_network.update_latent_code_vars_objts(latent_code_vars_objts)

data_generator.reset_index()

for epoch in pbar:
    total_loss = []

    latent_code_vars_objts = geom_network.latent_code_vars_objts.numpy()

    for obj_index, (start, end) in enumerate(object_part_ranges):
        # NOTE(review): a fresh tf.Variable per step is handed to the same
        # optimizer; confirm the installed TF/Keras optimizer accepts that.
        latent_code_vars = tf.Variable(tf.expand_dims(latent_code_vars_objts[obj_index], axis=0), trainable=True)

        # Ground truth: union of this object's parts, read directly from the
        # generator's in-memory part list. The old scan compared the generator's
        # list of batch indices with an int (never equal), so every object
        # absorbed all parts; it also used only the first item of each batch.
        true_voxels = np.zeros(shape=VOXEL_MAP_SHAPE, dtype=np.float32)
        for part_coords in data_generator.voxel_data_sparse[start:end]:
            true_voxels[tuple(np.asarray(part_coords).T)] = 1
        true_voxels = np.expand_dims(true_voxels, axis=(0, -1))

        loss = geom_network.train_step_objts(latent_code_vars, true_voxels)
        total_loss.append(float(loss))

        # Keep the optimised code. Previously it was discarded, so the table
        # pushed back below never changed.
        latent_code_vars_objts[obj_index] = latent_code_vars.numpy()[0]

    pbar.set_postfix({"Avg Loss": '{:.9f}'.format(np.mean(total_loss))})

    geom_network.update_latent_code_vars_objts(latent_code_vars_objts)

In [None]:
# Visual check of tf_project_to_silhouette on one assembled ground-truth object.
test_target_obj = 0

# Parts are numbered object by object, so this object's first part index is the
# total part count of all preceding objects.
base_index = int(np.sum(data_generator.each_chair_parts_count[:test_target_obj]))

true_voxels = np.zeros(shape=VOXEL_MAP_SHAPE, dtype=np.int32)
for i in range(base_index, base_index + data_generator.each_chair_parts_count[test_target_obj]):
    true_voxels_coords = np.load(os.path.join(DATASET_DIR_PTH, str(i)+'.npy'))
    # One vectorised assignment per part instead of a Python loop per voxel.
    true_voxels[tuple(np.asarray(true_voxels_coords).T)] = 1

true_voxels = tf.convert_to_tensor(np.expand_dims(true_voxels, axis=(0, -1)), dtype=tf.float32)

# Silhouette rows are indexed by y; flip them so larger y is drawn higher up.
silhouette = tf.reverse(tf_project_to_silhouette(true_voxels)[0], axis=[0])

plt.imshow(silhouette, cmap='gray')
plt.show()

In [None]:
# Stage-2 check: show one assembled ground-truth object next to the object
# decoder's reconstruction, both turned 180 degrees about y for viewing.

# test_target_obj = random.randint(0, data_generator.num_objts-1)
test_target_obj = 0

# Parts are numbered object by object, so this object's first part index is the
# total part count of all preceding objects.
base_index = int(np.sum(data_generator.each_chair_parts_count[:test_target_obj]))

# Union of the object's parts, read from "<index>.npy" files.
true_voxels = np.zeros(shape=VOXEL_MAP_SHAPE, dtype=np.int32)
for i in range(base_index, base_index + data_generator.each_chair_parts_count[test_target_obj]):
    part_coords = np.load(os.path.join(DATASET_DIR_PTH, str(i)+'.npy'))
    # One vectorised assignment per part instead of a Python loop per voxel.
    true_voxels[tuple(np.asarray(part_coords).T)] = 1

# (N, 3) coordinates of every occupied voxel.
true_voxels_coords = np.argwhere(true_voxels == 1)

true_voxel_rotated_coords = rotate_obj_along_axis(true_voxels_coords, rotation_angle=180, axis='y')
true_voxel_rotated = data_to_dense([true_voxel_rotated_coords])[0]
# Reorder axes (x, y, z) -> (x, z, y) before plotting.
true_voxel_rotated_plot = np.moveaxis(true_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(true_voxel_rotated_plot)

# Saving must happen before plt.show(), which leaves the current figure empty.
# plt.savefig('./{}_true_voxel'.format(str(test_target_obj)) + '.png', dpi=300)
plt.show()

# Decode the object's latent code; keep voxels with probability >= 0.3 and drop
# the batch and channel axes.
latent_codes = tf.expand_dims(geom_network.latent_code_vars_objts[test_target_obj], axis=0)
pred_logits_voxels = geom_network.objt_generator(latent_codes)
pred_voxels = tf.cast(tf.sigmoid(pred_logits_voxels) >= 0.3, tf.float32).numpy()[0, ..., 0]
pred_voxels_coords = np.argwhere(pred_voxels == 1)

pred_voxel_rotated_coords = rotate_obj_along_axis(pred_voxels_coords, rotation_angle=180, axis='y')
pred_voxel_rotated = data_to_dense([pred_voxel_rotated_coords])[0]
pred_voxel_rotated_plot = np.moveaxis(pred_voxel_rotated, 1, -1)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(pred_voxel_rotated_plot)

# plt.savefig('./{}_pred_voxel'.format(str(test_target_obj)) + '.png', dpi=300)
plt.show()

In [None]:
def save_plot(epoch, pred_voxels, palette):
    """Render predicted part volumes as one coloured voxel plot and save it as a PNG.

    Args:
        epoch: used in the file name ``./gif_imgs/<epoch>_pred_parts.png``.
        pred_voxels: per-part occupancy probabilities shaped (parts, D, H, W, 1),
            e.g. the first value returned by ``GeometryNetwork.train_step_assamble``.
        palette: sequence of RGB colours; ``palette[i]`` colours part ``i``.
    """
    # Binarise at 0.3 and drop the channel axis -> (parts, D, H, W).
    part_masks = np.asarray(pred_voxels) >= 0.3
    part_masks = part_masks.reshape(part_masks.shape[:4])
    voxel_map_shape = part_masks.shape[1:4]

    voxels = np.zeros(shape=voxel_map_shape, dtype=np.int32)
    colors = np.zeros(shape=(*voxel_map_shape, 3), dtype=np.float32)

    for i, part_mask in enumerate(part_masks):
        # Rotate this part's own voxels. The old code rotated the unrelated global
        # `true_voxels_coords`, so every part drew the same ground-truth shape.
        part_coords = np.argwhere(part_mask)
        rotated = rotate_obj_along_axis(part_coords, rotation_angle=180, axis='y', voxel_map_shape=voxel_map_shape)
        # Later parts overwrite earlier ones where they overlap.
        index = tuple(rotated.T)
        voxels[index] = 1
        colors[index] = palette[i]

    # Reorder the spatial axes (x, y, z) -> (x, z, y); the colour channel stays last.
    voxels = np.moveaxis(voxels, 1, -1)
    colors = np.moveaxis(colors, 1, -2)

    fig = plt.figure()
    try:
        ax = fig.add_subplot(projection='3d')
        ax.set_aspect('equal')
        ax.voxels(voxels, facecolors=colors)

        os.makedirs('./gif_imgs', exist_ok=True)
        fig.savefig('./gif_imgs/{}_pred_parts'.format(str(epoch)) + '.png', dpi=300)
    finally:
        # Always release the figure; this runs inside a long training loop.
        plt.close(fig)

In [None]:
# Stage 3: explain one object as the sum of USER_DEFINE_PARTS_NUM decoded parts
# by optimising only their latent codes (train_step_assamble leaves the decoder fixed).
TRAINING_EPOCH_FOR_ASSEMBLY = 1500

USER_DEFINE_PARTS_NUM = 30

latent_code_vars = tf.Variable(tf.random.normal((USER_DEFINE_PARTS_NUM, geom_network.latent_code_dim)), trainable=True)

pbar = tqdm(range(1, TRAINING_EPOCH_FOR_ASSEMBLY+1), desc="Training on Assambles")

data_generator.reset_index()

# One random RGB colour per part, used by save_plot.
palette = [np.array([random.random() for _ in range(3)], dtype=np.float32) for __ in range(USER_DEFINE_PARTS_NUM)]

target_object = 0

# The target never changes, so rasterise it once, before training. It is the
# union of the object's parts, taken from the generator's in-memory part list.
# The old per-epoch scan compared the generator's list of batch indices with
# ints, which raises TypeError in Python 3.
start_count = int(np.sum(data_generator.each_chair_parts_count[:target_object]))
range_count = start_count + int(data_generator.each_chair_parts_count[target_object])

true_voxels = np.zeros(shape=VOXEL_MAP_SHAPE, dtype=np.float32)
for part_coords in data_generator.voxel_data_sparse[start_count:range_count]:
    true_voxels[tuple(np.asarray(part_coords).T)] = 1
true_voxels = tf.convert_to_tensor(np.expand_dims(true_voxels, axis=-1), dtype=tf.float32)

for epoch in pbar:
    pred_voxels, loss = geom_network.train_step_assamble(latent_code_vars, true_voxels)

    # if epoch % 10 == 0:
    #     save_plot(epoch, pred_voxels, palette)

    # One step per epoch, so the epoch average is just this step's loss.
    pbar.set_postfix({"Avg Loss": '{:.9f}'.format(float(loss))})

In [None]:
# Render the stage-3 result: every decoded part in its own random colour,
# turned 180 degrees about y for viewing.
latent_codes = latent_code_vars.numpy()

pred_logits_voxels = geom_network.part_generator(latent_codes)
# Keep voxels with probability >= 0.3 and drop the channel axis -> (parts, D, H, W).
pred_voxels = tf.cast(tf.sigmoid(pred_logits_voxels) >= 0.3, tf.float32).numpy()
pred_voxels = pred_voxels.reshape(pred_voxels.shape[:4])

voxels = np.zeros(shape=VOXEL_MAP_SHAPE, dtype=np.int32)
colors = np.zeros(shape=(*VOXEL_MAP_SHAPE, 3), dtype=np.float32)

for part in pred_voxels:
    color = np.array([random.random() for _ in range(3)], dtype=np.float32)
    rotated = rotate_obj_along_axis(np.argwhere(part), rotation_angle=180, axis='y')
    # One vectorised assignment per part (was a Python loop per voxel); later
    # parts overwrite earlier ones where they overlap.
    index = tuple(rotated.T)
    voxels[index] = 1
    colors[index] = color

# Reorder the spatial axes (x, y, z) -> (x, z, y); the colour channel stays last.
voxels = np.moveaxis(voxels, 1, -1)
colors = np.moveaxis(colors, 1, -2)

ax = plt.figure().add_subplot(projection='3d')
ax.set_aspect('equal')
ax.voxels(voxels, facecolors=colors)

# Saving must happen before plt.show(), which leaves the current figure empty.
# plt.savefig('./{}_pred_parts'.format(str(0)) + '.png', dpi=300)
plt.show()

In [None]:
# Stitch the frames written by save_plot into an animated GIF.
GIF_IMG_DIR = './gif_imgs'
gif_name = './voxgen_part.gif'

# List and read frames from the same relative folder save_plot writes to. The
# old code listed a hard-coded absolute user path but read from './gif_imgs'.
# Natural sort keeps epoch 100 after epoch 20.
img_list = sorted_alphanumeric(name for name in os.listdir(GIF_IMG_DIR) if name.endswith('.png'))

frames = [imageio.imread(os.path.join(GIF_IMG_DIR, image_name)) for image_name in img_list]
imageio.mimsave(gif_name, frames, 'GIF', duration=1, loop=0)