In [None]:
import math

import tensorflow as tf
import numpy as np

In [None]:
def stratified_1d(near, far, num_samples, name="stratified_1d"):
    """Draw one jittered sample per equal-width stratum between near and far.

    The unit interval is cut into ``num_samples`` equal bins; a uniform random
    offset is drawn inside every bin, giving fractions ``t`` that increase
    along the last axis. Each ``t`` is mapped to ``near * (1 - t) + far * t``.

    Args:
        near: tensor-like lower bounds; its shape ``S`` fixes the output batch.
        far: tensor-like upper bounds, broadcast-compatible with ``near``.
        num_samples: number of strata (and samples) per entry.
        name: name scope for the created ops.

    Returns:
        Tensor of shape ``S + [num_samples]``.
    """
    with tf.name_scope(name):
        near_t = tf.convert_to_tensor(near)
        far_t = tf.convert_to_tensor(far)

        edges = tf.linspace(0.0, 1.0, num_samples + 1, axis=-1)
        lower_edges = edges[..., :-1]
        upper_edges = edges[..., 1:]

        sample_shape = tf.concat([tf.shape(near_t), [num_samples]], axis=-1)
        jitter = tf.random.uniform(sample_shape)
        t = lower_edges + (upper_edges - lower_edges) * jitter

        near_e = tf.expand_dims(near_t, -1)
        far_e = tf.expand_dims(far_t, -1)
        return near_e * (1. - t) + far_e * t


def points_from_z_values(ray_org, ray_dir, z_values):
    """Evaluate ``origin + z * direction`` for every depth sample of every ray.

    Args:
        ray_org: ray origins ``[..., D]``.
        ray_dir: ray directions ``[..., D]``.
        z_values: depths ``[..., N]``.

    Returns:
        Points of shape ``[..., N, D]``.
    """
    offsets = tf.expand_dims(z_values, axis=-1) * tf.expand_dims(ray_dir, axis=-2)
    return offsets + tf.expand_dims(ray_org, -2)


def sample_1d(ray_org, ray_dir, near, far, n_samples, name="sample_1d"):
    """Stratified-sample ``n_samples`` 3D points along each ray.

    ``near``/``far`` may be scalars or tensors. They are multiplied by float32
    ones so they broadcast to the ray batch shape (all axes of ``ray_org``
    except the last).

    Args:
        ray_org: ray origins ``[..., D]``.
        ray_dir: ray directions ``[..., D]``.
        near: near depth bound(s).
        far: far depth bound(s).
        n_samples: number of samples per ray.
        name: name scope for the created ops.

    Returns:
        Points of shape ``ray_org.shape[:-1] + [n_samples, D]``.
    """
    with tf.name_scope(name):
        ray_org = tf.convert_to_tensor(ray_org)
        ray_dir = tf.convert_to_tensor(ray_dir)

        batch_ones = tf.ones(tf.shape(ray_org)[:-1])

        def _to_batch(bound):
            # ones((1,)) promotes scalars to rank 1 before broadcasting to the batch.
            return tf.convert_to_tensor(bound) * tf.ones((1,)) * batch_ones

        near_batch = _to_batch(near)
        far_batch = _to_batch(far)

        depths = stratified_1d(near_batch, far_batch, n_samples)
        return points_from_z_values(ray_org, ray_dir, depths)


def match_intermediate_batch_dimensions(tensor1, tensor2):
    """Broadcast a per-batch ``[B, C]`` tensor across the middle axes of ``tensor2``.

    ``tensor1`` is reshaped to ``[B, 1, ..., 1, C]``, with one singleton per
    extra axis of ``tensor2``. It is then broadcast to
    ``tensor2.shape[:-1] + [C]``. ``ray_sample_voxel_grid`` uses this to apply
    per-batch affine parameters to every ray sample point.

    Args:
        tensor1: tensor-like, expected rank 2 ``[B, C]``. The reshape only
            produces a rank matching ``tensor2`` when ``tensor1`` is rank 2.
        tensor2: tensor-like whose leading axis is compatible with ``B``.

    Returns:
        ``tensor1`` broadcast to ``tensor2.shape[:-1] + [C]``.

    Raises:
        ValueError: if the rank of either input is not statically known.
    """
    tensor1 = tf.convert_to_tensor(tensor1)
    tensor2 = tf.convert_to_tensor(tensor2)

    # Use static ranks. ``len(tf.shape(x))`` is only defined for eager tensors
    # (or under AutoGraph), and the number of singleton axes must be a Python int.
    rank1 = tensor1.shape.rank
    rank2 = tensor2.shape.rank
    if rank1 is None or rank2 is None:
        raise ValueError('match_intermediate_batch_dimensions requires inputs '
                         'with statically known rank.')

    shape1 = tf.shape(tensor1)
    shape2 = tf.shape(tensor2)

    # Explicit int32 so the concat stays well-typed even with zero extra axes
    # (an untyped empty Python list would otherwise be converted to float32).
    singleton_axes = tf.constant([1] * (rank2 - rank1), dtype=shape1.dtype)
    new_shape = tf.concat([shape1[:1], singleton_axes, shape1[-1:]], axis=0)
    target_shape = tf.concat([shape2[:-1], shape1[-1:]], axis=0)

    return tf.broadcast_to(tf.reshape(tensor1, new_shape), target_shape)


def trilinear_interpolate(grid_3d, sampling_points, name="trilinear_interpolate"):
    """Trilinearly interpolate a batched voxel grid at continuous points.

    Each point is looked up at the 8 integer corners around it (``floor(p)``
    and ``floor(p) + 1`` per axis). Corner indices are clamped to
    ``[0, dim - 1]``. The per-axis weights always sum to one, so points outside
    the grid read the nearest edge voxels. The 8 corner values are blended with
    trilinear weights.

    Args:
        grid_3d: ``[B..., D0, D1, D2, C]`` grid, channels last.
        sampling_points: ``[B..., M, 3]`` points in voxel-index coordinates.
            Component k indexes grid axis Dk. The leading batch dims ``B...``
            must match those of ``grid_3d`` (they are gathered with
            ``batch_dims``).
        name: name scope for the created ops.

    Returns:
        ``[B..., M, C]`` interpolated values.
    """
    with tf.name_scope(name):
        grid_3d = tf.convert_to_tensor(value=grid_3d)
        sampling_points = tf.convert_to_tensor(value=sampling_points)

        # [D0, D1, D2]: the three spatial extents of the grid.
        voxel_cube_shape = tf.shape(input=grid_3d)[-4:-1]
        # No-op: re-asserts the tensor's own static shape.
        sampling_points.set_shape(sampling_points.shape)
        batch_dims = tf.shape(input=sampling_points)[:-2]
        num_points = tf.shape(input=sampling_points)[-2]

        bottom_left = tf.floor(sampling_points)
        top_right = bottom_left + 1
        bottom_left_index = tf.cast(bottom_left, tf.int32)
        top_right_index = tf.cast(top_right, tf.int32)
        x0_index, y0_index, z0_index = tf.unstack(bottom_left_index, axis=-1)
        x1_index, y1_index, z1_index = tf.unstack(top_right_index, axis=-1)
        # Build the 8 corners as 8 consecutive groups of M points along the last
        # axis ([B..., 8*M]). Order: x varies fastest, then y, then z, i.e.
        # (x0,y0,z0), (x1,y0,z0), (x0,y1,z0), (x1,y1,z0), (x0,y0,z1), ...
        index_x = tf.concat([x0_index, x1_index, x0_index, x1_index,
                             x0_index, x1_index, x0_index, x1_index], axis=-1)
        index_y = tf.concat([y0_index, y0_index, y1_index, y1_index,
                             y0_index, y0_index, y1_index, y1_index], axis=-1)
        index_z = tf.concat([z0_index, z0_index, z0_index, z0_index,
                             z1_index, z1_index, z1_index, z1_index], axis=-1)
        # [B..., 8*M, 3] grid indices.
        indices = tf.stack([index_x, index_y, index_z], axis=-1)
        # Clamp each index component into [0, dim - 1] (edge clamping).
        clip_value = tf.convert_to_tensor(
            value=[voxel_cube_shape - 1], dtype=indices.dtype)
        indices = tf.clip_by_value(indices, 0, clip_value)
        # [B..., 8*M, C] corner values, gathered per batch element.
        content = tf.gather_nd(
            params=grid_3d, indices=indices, batch_dims=tf.size(input=batch_dims))
        distance_to_bottom_left = sampling_points - bottom_left
        distance_to_top_right = top_right - sampling_points
        x_x0, y_y0, z_z0 = tf.unstack(distance_to_bottom_left, axis=-1)
        x1_x, y1_y, z1_z = tf.unstack(distance_to_top_right, axis=-1)
        # A corner's weight on each axis is the distance to the *opposite*
        # corner, concatenated in the same corner order as the indices above.
        weights_x = tf.concat([x1_x, x_x0, x1_x, x_x0,
                               x1_x, x_x0, x1_x, x_x0], axis=-1)
        weights_y = tf.concat([y1_y, y1_y, y_y0, y_y0,
                               y1_y, y1_y, y_y0, y_y0], axis=-1)
        weights_z = tf.concat([z1_z, z1_z, z1_z, z1_z,
                               z_z0, z_z0, z_z0, z_z0], axis=-1)
        # [B..., 8*M, 1] so the weights broadcast over channels.
        weights = tf.expand_dims(weights_x * weights_y * weights_z, axis=-1)

        interpolated_values = weights * content
        # Split back into the 8 corner groups of M points each and sum them.
        return tf.add_n(tf.split(interpolated_values, [num_points] * 8, -2))


@tf.function
def ray_sample_voxel_grid(ray_points, voxels, w2v_alpha, w2v_beta):
    """Trilinearly sample ``voxels`` at world-space ray sample points.

    Points are mapped to voxel-index coordinates with the element-wise affine
    ``alpha * p + beta``, where ``alpha``/``beta`` are per batch element. The
    coordinates are flattened per batch element, interpolated, and reshaped
    back.

    Args:
        ray_points: ``[B, ..., 3]`` sample points.
        voxels: ``[B, D0, D1, D2, C]`` voxel grid.
        w2v_alpha: per-batch scale, presumably ``[B, 3]`` (TODO confirm).
        w2v_beta: per-batch offset, presumably ``[B, 3]`` (TODO confirm).

    Returns:
        ``[B, ..., C]`` sampled features, one per input point.
    """
    alpha = match_intermediate_batch_dimensions(w2v_alpha, ray_points)
    beta = match_intermediate_batch_dimensions(w2v_beta, ray_points)
    voxel_coords = alpha * ray_points + beta

    voxels_shape = tf.shape(voxels)
    output_shape = tf.concat(
        [tf.shape(voxel_coords)[:-1], voxels_shape[-1:]], axis=-1)

    flat_coords = tf.reshape(voxel_coords, [voxels_shape[0], -1, 3])
    sampled = trilinear_interpolate(voxels, flat_coords)
    return tf.reshape(sampled, output_shape)


def compute_density(density_values, distances, name=None):
    """Accumulate per-sample densities along each ray into a total opacity.

    For sample i: ``alpha_i = 1 - exp(-sigma_i * delta_i)``. Its weight is
    ``alpha_i * prod_{j<i}(1 - alpha_j + 1e-10)``, and the ray opacity is the
    sum of the weights over the sample axis.

    Args:
        density_values: ``[..., N, 1]`` per-sample values. The trailing
            singleton axis is squeezed, so it must have size 1.
        distances: ``[..., N]`` per-sample multipliers of the density.
        name: optional name scope (defaults to "ray_density").

    Returns:
        ``[..., 1]`` accumulated opacity per ray.
    """
    with tf.compat.v1.name_scope(name, "ray_density", [density_values, distances]):
        sigma = tf.convert_to_tensor(value=density_values)
        delta = tf.expand_dims(tf.convert_to_tensor(value=distances), -1)

        alpha = tf.squeeze(1. - tf.exp(-sigma * delta), -1)
        # Exclusive cumprod = product of the factors of all earlier samples.
        # The 1e-10 avoids an exact zero factor when alpha == 1.
        transmittance = tf.math.cumprod(1. - alpha + 1e-10, -1, exclusive=True)
        sample_weights = alpha * transmittance

        total = tf.reduce_sum(sample_weights, -1)
        return tf.expand_dims(total, axis=-1)


def l2_loss(prediction, target, weights=1.0):
    """Weighted mean squared error between ``prediction`` and ``target``.

    Both inputs must have identical static shapes; this is checked with a
    Python ``assert``, which is skipped under ``python -O``.
    """
    assert prediction.shape == target.shape, "Shape dims should be the same."
    squared_error = tf.square(target - prediction)
    return tf.reduce_mean(weights * squared_error)


def select_eps_for_division(dtype):
    """Return ten times the smallest positive normal number of ``dtype``.

    Args:
        dtype: a ``tf.DType`` (anything exposing ``as_numpy_dtype``).
    """
    smallest_normal = np.finfo(dtype.as_numpy_dtype).tiny
    return 10.0 * smallest_normal


def assert_no_infs_or_nans(tensor, name='assert_no_infs_or_nans'):
    """Return ``tensor`` unchanged, failing if it contains any Inf or NaN.

    The returned identity has a control dependency on
    ``tf.debugging.check_numerics``, so the check also runs in graph mode.
    """
    with tf.name_scope(name):
        checked = tf.convert_to_tensor(value=tensor)
        numerics_check = tf.debugging.check_numerics(
            checked, message='Inf or NaN detected.')
        with tf.control_dependencies([numerics_check]):
            return tf.identity(checked)


def nonzero_sign(x, name='nonzero_sign'):
    """Element-wise sign that never returns 0.

    Values ``>= 0`` map to +1 and everything else (negatives and NaN) maps
    to -1.
    """
    with tf.name_scope(name):
        x = tf.convert_to_tensor(value=x)
        positive_one = tf.ones_like(x)
        is_non_negative = tf.greater_equal(x, 0.0)
        return tf.where(is_non_negative, positive_one, -positive_one)


def safe_signed_div(a, b, eps=None, name='safe_signed_div'):
    """Compute ``a / (b + nonzero_sign(b) * eps)`` and check the result.

    Shifting the denominator away from zero in the direction of its own sign
    keeps it from reaching zero. The quotient is passed through an Inf/NaN
    check.

    Args:
        a: numerator.
        b: denominator.
        eps: shift magnitude. Defaults to ``select_eps_for_division(b.dtype)``.
        name: name scope for the created ops.
    """
    with tf.name_scope(name):
        numerator = tf.convert_to_tensor(value=a)
        denominator = tf.convert_to_tensor(value=b)

        if eps is None:
            eps = select_eps_for_division(denominator.dtype)
        eps = tf.convert_to_tensor(value=eps, dtype=denominator.dtype)

        shifted = denominator + nonzero_sign(denominator) * eps
        return assert_no_infs_or_nans(numerator / shifted)


def ray(point_2d, focal, principal_point, name="perspective_ray"):
    """Back-project 2D pixel coordinates to homogeneous ray directions.

    Computes ``(point_2d - principal_point) / focal`` using a sign-preserving
    safe division, then appends a constant 1.0 to the last axis.

    Args:
        point_2d: ``[..., 2]`` pixel coordinates.
        focal: focal lengths, broadcast against ``point_2d``.
        principal_point: principal point, broadcast against ``point_2d``.
        name: name scope for the created ops.

    Returns:
        ``[..., 3]`` directions whose last component is 1.0.
    """
    with tf.name_scope(name):
        point_2d = tf.convert_to_tensor(value=point_2d)
        focal = tf.convert_to_tensor(value=focal)
        principal_point = tf.convert_to_tensor(value=principal_point)

        normalized = safe_signed_div(point_2d - principal_point, focal)

        # Pad only after the final element of the last axis.
        paddings = [[0, 0] for _ in normalized.shape]
        paddings[-1][-1] = 1
        return tf.pad(tensor=normalized, paddings=paddings, mode="CONSTANT",
                      constant_values=1.0)


def perspective_random_rays(focal, principal_point, height, width, n_rays, margin=0, name="random_rays"):
    """Sample random integer pixels and their perspective rays.

    For each batch element (all axes of ``focal`` but the last), ``n_rays``
    pixels are drawn uniformly with ``x`` in ``[margin, width - margin)`` and
    ``y`` in ``[margin, height - margin)``.

    Returns:
        Tuple ``(rays, pixels)``:
        - ``rays``: ``[..., n_rays, 3]`` from ``ray``.
        - ``pixels``: int32 ``[..., n_rays, 2]`` in ``(x, y)`` order.
    """
    with tf.name_scope(name):
        focal = tf.convert_to_tensor(value=focal)
        principal_point = tf.convert_to_tensor(value=principal_point)
        sample_shape = tf.concat([tf.shape(focal)[:-1], [n_rays]], axis=0)

        def _uniform_ints(upper):
            return tf.random.uniform(
                sample_shape, minval=margin, maxval=upper - margin, dtype=tf.int32)

        # x (width) is drawn before y (height).
        xs = _uniform_ints(width)
        ys = _uniform_ints(height)
        pixels = tf.cast(tf.stack((xs, ys), axis=-1), tf.float32)

        rays = ray(pixels,
                   tf.expand_dims(focal, -2),
                   tf.expand_dims(principal_point, -2))
        return rays, tf.cast(pixels, tf.int32)


def build_matrix_from_sines_and_cosines(sin_angles, cos_angles):
    """Assemble ``Rz @ Ry @ Rx`` rotation matrices from precomputed sin/cos.

    Args:
        sin_angles: ``[..., 3]`` sines of the (x, y, z) angles.
        cos_angles: ``[..., 3]`` cosines; the static shape must be compatible
            with ``sin_angles``.

    Returns:
        ``[..., 3, 3]`` rotation matrices.
    """
    sin_angles.shape.assert_is_compatible_with(cos_angles.shape)

    sx, sy, sz = tf.unstack(sin_angles, axis=-1)
    cx, cy, cz = tf.unstack(cos_angles, axis=-1)

    first_row = tf.stack((cy * cz,
                          (sx * sy * cz) - (cx * sz),
                          (cx * sy * cz) + (sx * sz)), axis=-1)
    second_row = tf.stack((cy * sz,
                           (sx * sy * sz) + (cx * cz),
                           (cx * sy * sz) - (sx * cz)), axis=-1)
    third_row = tf.stack((-sy,
                          sx * cy,
                          cx * cy), axis=-1)
    return tf.stack((first_row, second_row, third_row), axis=-2)


def rotation_matrix_3d_from_euler(angles, name="rotation_matrix_3d_from_euler"):
    """Build ``[..., 3, 3]`` rotation matrices from ``[..., 3]`` Euler angles.

    ``angles`` holds the (x, y, z) angles in radians. The result is the
    composition ``Rz @ Ry @ Rx``.
    """
    with tf.name_scope(name):
        angles = tf.convert_to_tensor(value=angles)
        return build_matrix_from_sines_and_cosines(sin_angles=tf.sin(angles),
                                                   cos_angles=tf.cos(angles))


def change_coordinate_system(points3d, rotations=(0., 0., 0.), scale=(1., 1., 1.), name="change_coordinate_system"):
    """Transform row-vector points by ``diag(scale) @ R(rotations)``.

    Args:
        points3d: ``[..., N, 3]`` points. The static rank must be known; it
            sizes the broadcast reshape of the 3x3 transform.
        rotations: Euler angles (x, y, z) in radians.
        scale: per-axis scale, applied after the rotation.
        name: name scope for the created ops.

    Returns:
        Transformed points with the same shape as ``points3d``.
    """
    with tf.name_scope(name):
        points3d = tf.convert_to_tensor(points3d)
        rotation = tf.convert_to_tensor(rotations)
        scale = tf.convert_to_tensor(scale)

        # scale * eye broadcasts along the last axis, i.e. diag(scale).
        scaling_matrix = scale * tf.eye(3, 3)
        transformation = tf.matmul(scaling_matrix,
                                   rotation_matrix_3d_from_euler(rotation))

        leading_ones = [1] * (len(points3d.get_shape().as_list()) - 2)
        transformation = tf.reshape(transformation, leading_ones + [3, 3])

        as_columns = tf.linalg.matrix_transpose(points3d)
        return tf.linalg.matrix_transpose(tf.matmul(transformation, as_columns))


@tf.function
def _move_in_front_of_camera(points3d, rotation_matrix, translation_vector):
    """Map row-vector points with ``p -> -R^T (p + t)``.

    Args:
        points3d: ``[..., N, 3]`` points.
        rotation_matrix: ``[..., 3, 3]`` matrix ``R``.
        translation_vector: ``t``, added to the transposed ``[..., 3, N]``
            points; presumably ``[..., 3, 1]`` (TODO confirm).

    Returns:
        ``[..., N, 3]`` transformed points.
    """
    points3d = tf.convert_to_tensor(value=points3d)
    rotation_matrix = tf.convert_to_tensor(value=rotation_matrix)
    translation_vector = tf.convert_to_tensor(value=translation_vector)

    shifted_columns = tf.linalg.matrix_transpose(points3d) + translation_vector
    negated_rotation_t = -tf.linalg.matrix_transpose(rotation_matrix)
    mapped_columns = tf.matmul(negated_rotation_t, shifted_columns)
    return tf.linalg.matrix_transpose(mapped_columns)


@tf.function
def camera_rays_from_extrinsics(rays, rotation_matrix, translation_vector):
    """Convert ray directions into ray origins and unit directions via the extrinsics.

    Origins are zero vectors mapped through ``_move_in_front_of_camera``.
    Directions are the rays mapped with a zero translation and then
    L2-normalised along the last axis.

    Returns:
        Tuple ``(rays_org, rays_dir)``, each shaped like ``rays``.
    """
    zero_translation = 0 * translation_vector
    rays_org = _move_in_front_of_camera(
        tf.zeros_like(rays), rotation_matrix, translation_vector)
    unnormalized_dir = _move_in_front_of_camera(
        rays, rotation_matrix, zero_translation)
    rays_dir = unnormalized_dir / tf.norm(unnormalized_dir, axis=-1, keepdims=True)
    return rays_org, rays_dir

In [None]:
# Hyper-parameters consumed by GeometryNetwork.__init__.
hparam = dict(
    model_latent_code_dim=256,          # size of each latent code (decoder input)
    model_fc_channels=512,              # width of the three Dense layers
    model_fc_activation='relu',         # activation of the Dense layers
    model_norm_3d='batchnorm',          # 'batchnorm' enables BatchNorm in decoder blocks
    model_conv_size=4,                  # kernel size of the 3D transposed convolutions
    model_num_latent_codes=4371,        # rows in the trainable latent-code table
    model_learning_rate_network=1e-4,   # Adam LR for decoder weights
    model_learning_rate_codes=1e-4,     # Adam LR for latent codes
    model_checkpoint_dir='./ckpt',      # checkpoints and summaries are written here
)

In [None]:
class GeometryNetwork:
    """Latent-code-to-voxel decoder together with its optimisers and checkpointing.

    Attributes created on construction:
      * ``model``: Keras decoder mapping ``[B, latent_code_dim]`` codes to
        ``[B, 128, 128, 128, 1]`` voxel logits.
      * ``model_backup``: second decoder holding a snapshot of ``model``'s
        weights, taken in ``load_checkpoint`` and restored by ``reset_models``.
      * ``latent_code_vars``: trainable ``[num_latent_codes, latent_code_dim]``
        table, initialised from a standard normal.
      * ``optimizer_network`` / ``optimizer_latent``: Adam optimisers.
      * ``checkpoint`` / ``manager``: checkpoint over ``checkpoint_dir``. The
        latest checkpoint, if any, is restored during construction.
      * ``mask_voxels``: ``[1, 128, 128, 128, 1]`` ones with a zeroed
        one-voxel border.

    Args:
        hparam: dict providing the ``model_*`` keys read in ``__init__``.
    """

    def __init__(self, hparam):
        self.latent_code_dim = hparam['model_latent_code_dim']
        self.fc_channels = hparam['model_fc_channels']
        self.fc_activation = hparam['model_fc_activation']
        self.norm_3d = hparam['model_norm_3d']
        self.conv_size = hparam['model_conv_size']
        self.num_latent_codes = hparam['model_num_latent_codes']
        self.learning_rate_network = hparam['model_learning_rate_network']
        self.learning_rate_codes = hparam['model_learning_rate_codes']
        self.checkpoint_dir = hparam['model_checkpoint_dir']

        self.mask_voxels = self.get_mask_voxels()

        self.init_model()
        self.init_optimizer()
        self.init_checkpoint()

    def get_mask_voxels(self, shape=(1, 128, 128, 128, 1), dtype=np.float32):
        """Return a tensor of ones whose first/last slice on each spatial axis is zero."""
        voxels = np.ones(shape, dtype=dtype)
        voxels[:, [0, -1], :, :, :] = 0
        voxels[:, :, [0, -1], :, :] = 0
        voxels[:, :, :, [0, -1], :] = 0
        return tf.convert_to_tensor(voxels)

    def norm_layer(self, tensor, normalization):
        """Apply BatchNormalization if ``normalization`` is 'batchnorm' (any case), else identity."""
        if normalization and normalization.lower() == 'batchnorm':
            tensor = tf.keras.layers.BatchNormalization()(tensor)
        return tensor

    def conv_t_block_3d(self, tensor, num_filters, size, strides,
                        normalization=None, dropout=False,
                        alpha_lrelu=0.2, relu=True, rate=0.7):
        """Bias-free Conv3DTranspose ('same' padding) -> optional norm -> LeakyReLU -> Dropout."""
        conv_3D_transpose = tf.keras.layers.Conv3DTranspose(
            num_filters,
            size,
            strides=strides,
            padding='same',
            kernel_initializer=tf.keras.initializers.glorot_normal(),
            use_bias=False)

        tensor = conv_3D_transpose(tensor)

        tensor = self.norm_layer(tensor, normalization)

        if relu:
            tensor = tf.keras.layers.LeakyReLU(alpha=alpha_lrelu)(tensor)

        if dropout:
            tensor = tf.keras.layers.Dropout(rate)(tensor)

        return tensor

    def get_model(self):
        """Build the decoder.

        Layers: three Dense layers, a reshape to a 1^3 volume, six stride-2
        transposed-conv blocks with 32, 32, 32, 16, 8 and 4 filters, and a final
        stride-2 transposed conv to 1 channel. Each stride-2 layer doubles the
        resolution, so the output is 1^3 -> 128^3. Layers are created in the
        same order on every call, so checkpoints stay compatible.
        """
        with tf.name_scope('Network/'):

            latent_code = tf.keras.layers.Input(shape=(self.latent_code_dim,))

            with tf.name_scope('FC_layers'):
                features = latent_code
                for _ in range(3):
                    features = tf.keras.layers.Dense(
                        self.fc_channels, activation=self.fc_activation)(features)

                volume = tf.keras.layers.Reshape((1, 1, 1, self.fc_channels))(features)

            with tf.name_scope('GLO_VoxelDecoder'):
                for num_filters in (32, 32, 32, 16, 8, 4):
                    volume = self.conv_t_block_3d(volume,
                                                  num_filters=num_filters,
                                                  size=self.conv_size,
                                                  strides=2,
                                                  normalization=self.norm_3d)

                # Output projection to a single logit channel: no norm, no activation.
                conv_3D_transpose_out = tf.keras.layers.Conv3DTranspose(
                    1,
                    self.conv_size,
                    strides=2,
                    padding='same',
                    kernel_initializer=tf.keras.initializers.glorot_normal(),
                    use_bias=False)

                volume_out = conv_3D_transpose_out(volume)

        return tf.keras.Model(inputs=[latent_code], outputs=[volume_out])

    def init_model(self):
        """Create the decoder, its backup copy, step counters and the latent-code table."""
        self.model = self.get_model()
        self.model_backup = self.get_model()

        self.latest_epoch = tf.Variable(0, trainable=False, dtype=tf.int64)
        self.global_step = tf.Variable(0, trainable=False, dtype=tf.int64)

        init_latent_code = tf.random.normal(
            (self.num_latent_codes, self.latent_code_dim))
        self.latent_code_vars = tf.Variable(init_latent_code, trainable=True)

        self.trainable_variables = self.model.trainable_variables

    def init_optimizer(self):
        """Create separate Adam optimisers for decoder weights and latent codes."""
        self.optimizer_network = tf.keras.optimizers.Adam(
            learning_rate=self.learning_rate_network)
        self.optimizer_latent = tf.keras.optimizers.Adam(
            learning_rate=self.learning_rate_codes)

    def init_checkpoint(self):
        """Set up the summary writer, checkpoint, manager, then restore the latest checkpoint."""
        self.summary_writer = tf.summary.create_file_writer(self.checkpoint_dir)

        self.checkpoint = tf.train.Checkpoint(
            model=self.model,
            latent_code_var=self.latent_code_vars,
            optimizer_network=self.optimizer_network,
            optimizer_latent=self.optimizer_latent,
            epoch=self.latest_epoch,
            global_step=self.global_step)

        self.manager = tf.train.CheckpointManager(checkpoint=self.checkpoint,
                                                  directory=self.checkpoint_dir,
                                                  max_to_keep=2)

        self.load_checkpoint()

    def load_checkpoint(self):
        """Restore the latest checkpoint (if any), then snapshot weights into ``model_backup``.

        Prints exactly one status line. The backup is refreshed in both cases,
        so ``reset_models`` returns to the state right after this call.
        """
        latest_checkpoint = self.manager.latest_checkpoint

        # A plain if/else is needed here. The previous ``for ... else`` always
        # ran its else branch, because that loop never breaks.
        if latest_checkpoint is not None:
            # expect_partial(): objects not present in the checkpoint (e.g. not
            # yet created optimizer slots) should not trigger warnings.
            self.checkpoint.restore(latest_checkpoint).expect_partial()
            print('Checkpoint {} restored'.format(latest_checkpoint))
        else:
            print('No checkpoint was restored.')

        for backup_var, model_var in zip(self.model_backup.variables,
                                         self.model.variables):
            backup_var.assign(model_var)

    def reset_models(self):
        """Copy the snapshot in ``model_backup`` back into ``model``."""
        for a, b in zip(self.model.variables, self.model_backup.variables):
            a.assign(b)

In [None]:
def optimize_for_mask(geometry_network,
                      mask,
                      focal,
                      principal_point,
                      rotation_matrix,
                      translation_vector,
                      w2v_alpha,
                      w2v_beta,
                      near=1.25,
                      far=3.5,
                      n_samples=128,
                      density=20,
                      mirror_weight=1.0,
                      n_iter=100,
                      n_rays=1024):
    """Fit a new latent code, and fine-tune the decoder, to a 2D silhouette mask.

    The code starts at the mean of ``geometry_network.latent_code_vars``. Each
    iteration:
      1. samples ``n_rays`` random pixels and their camera rays;
      2. samples the decoded, border-masked voxel grid along the rays and
         accumulates a silhouette value per ray;
      3. minimises ``L2(silhouette, mask) + mirror_weight * L2(voxels,
         voxels reversed along axis 1)``.
    The decoder weights are updated with ``optimizer_network`` and the code
    with ``optimizer_latent``.

    NOTE: ``geometry_network.model`` is modified in place; call
    ``geometry_network.reset_models()`` to restore the snapshot weights.

    Args:
        geometry_network: a ``GeometryNetwork``.
        mask: target silhouette indexed ``[batch, y, x, ...]``; presumably
            ``[B, H, W, 1]`` (TODO confirm).
        focal, principal_point: camera intrinsics passed to
            ``perspective_random_rays``.
        rotation_matrix, translation_vector: extrinsics passed to
            ``camera_rays_from_extrinsics``.
        w2v_alpha, w2v_beta: world-to-voxel affine passed to
            ``ray_sample_voxel_grid``.
        near, far, n_samples: depth range and samples per ray.
        density: multiplier applied to the voxel values when accumulating.
        mirror_weight: weight of the mirror-symmetry loss.
        n_iter: number of optimisation steps.
        n_rays: rays sampled per step.

    Returns:
        The optimised ``tf.Variable`` latent code, shape ``[1, latent_code_dim]``.
    """
    height, width = mask.shape[-3], mask.shape[-2]

    mean_code = tf.reduce_mean(geometry_network.latent_code_vars, axis=0, keepdims=True)
    voxel_code_var = tf.Variable(mean_code, trainable=True)

    network_vars = geometry_network.trainable_variables
    num_network_vars = len(network_vars)

    @tf.function
    def opt_step(r_org, r_dir, gt_a):
        with tf.GradientTape() as tape:
            logits = geometry_network.model(voxel_code_var)
            voxels = tf.sigmoid(logits) * geometry_network.mask_voxels

            points = sample_1d(r_org, r_dir, near=near, far=far, n_samples=n_samples)
            sampled = ray_sample_voxel_grid(points, voxels, w2v_alpha, w2v_beta)
            silhouettes = compute_density(sampled, density * tf.ones_like(sampled[..., 0]))

            silhouette_loss = l2_loss(silhouettes, gt_a)
            mirror_voxel_loss = l2_loss(voxels, tf.reverse(voxels, [1]))
            total_loss = silhouette_loss + mirror_weight * mirror_voxel_loss

        grads = tape.gradient(total_loss, network_vars + [voxel_code_var])
        network_grads = grads[:num_network_vars]
        code_grads = grads[num_network_vars:]
        geometry_network.optimizer_network.apply_gradients(zip(network_grads, network_vars))
        geometry_network.optimizer_latent.apply_gradients(zip(code_grads, [voxel_code_var]))

        return total_loss

    for it in range(n_iter):
        cam_rays, pixels_xy = perspective_random_rays(
            focal, principal_point, height, width, n_rays)
        cam_rays = change_coordinate_system(
            cam_rays, (0., math.pi, math.pi), (-1., 1., 1.))
        rays_org, rays_dir = camera_rays_from_extrinsics(
            cam_rays, rotation_matrix, translation_vector)

        # The mask is indexed (row, col) = (y, x), so swap the pixel coordinates.
        pixels_yx = tf.cast(tf.reverse(pixels_xy, axis=[-1]), tf.int32)
        target_alpha = tf.gather_nd(mask, pixels_yx, batch_dims=1)

        loss = opt_step(rays_org, rays_dir, target_alpha)
        print('Iter {:>2d} loss: {:.5f}'.format(it, loss))

    return voxel_code_var

In [None]:
class ModelTest(tf.test.TestCase):
    """Smoke test: the decoder maps a batch of latent codes to a valid voxel grid."""

    def test_model_training(self):
        batch_size = 10
        # Read the code size from the config rather than duplicating it, so the
        # test keeps working if ``model_latent_code_dim`` changes.
        latent_code_dim = hparam['model_latent_code_dim']

        geom_network = GeometryNetwork(hparam)

        latent_codes = tf.zeros((batch_size, latent_code_dim))
        pred_logits_voxels = geom_network.model(latent_codes)
        pred_voxels = tf.sigmoid(pred_logits_voxels)

        # The decoder output is multiplied element-wise with mask_voxels during
        # optimisation, so its spatial/channel shape must match.
        expected_shape = [batch_size] + geom_network.mask_voxels.shape.as_list()[1:]
        self.assertEqual(pred_voxels.shape.as_list(), expected_shape)
        self.assertAllInRange(pred_voxels, 0.0, 1.0)

In [None]:
class OptimizationTest(tf.test.TestCase):
    """Smoke test: a few optimisation steps against an all-ones mask run and keep voxels in [0, 1]."""

    def test_optimization(self):
        geom_network = GeometryNetwork(hparam)

        camera_args = dict(
            focal=tf.ones((1, 2)),
            principal_point=tf.ones((1, 2)),
            rotation_matrix=tf.expand_dims(tf.eye(3, 3), 0),
            translation_vector=tf.ones((1, 3, 1)),
            w2v_alpha=tf.ones((1, 3)),
            w2v_beta=tf.ones((1, 3)),
        )
        render_args = dict(
            near=1.25,
            far=3.5,
            n_samples=128,
            density=1,
            mirror_weight=50.0,
            n_iter=10,
            n_rays=2024,
        )

        latent_code = optimize_for_mask(geom_network,
                                        tf.ones((1, 128, 128, 1)),
                                        **camera_args,
                                        **render_args)

        logits = geom_network.model(latent_code)
        voxels = tf.sigmoid(logits) * geom_network.mask_voxels
        self.assertAllInRange(voxels, 0.0, 1.0)

In [None]:
# Run the smoke tests directly (no unittest runner) so failures surface as notebook exceptions.
model_test = ModelTest()
model_test.test_model_training()

optimization_test = OptimizationTest()
optimization_test.test_optimization()