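## This notebook will help you train a raw Point-Cloud GAN.

(Assumes that `latent_3d_points` is on the PYTHONPATH — the first code cell also appends the parent of the current working directory to `sys.path` — that a trained AE model exists, and that the sampled ShapeNet point clouds are available under `data/shape_net_core_uniform_samples_2048/`.)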

In [1]:
import numpy as np
import os.path as osp
import matplotlib.pylab as plt
import os
import sys

# Repository root (latent_3D). This is the parent of the current working
# directory, so the notebook presumably has to be started from a direct
# sub-directory of the repo (e.g. notebooks/) — TODO confirm.
BASE = os.path.dirname(os.path.abspath(os.getcwd()))
# Make `src.*` importable; guard so re-running the cell doesn't add duplicates.
if BASE not in sys.path:
    sys.path.append(BASE)

from src.autoencoder import Configuration as Conf
from src.neural_net import MODEL_SAVER_ID

from src.in_out import snc_category_to_synth_id, create_dir, PointCloudDataSet, \
                                        load_all_point_clouds_under_folder

from src.general_utils import plot_3d_point_cloud
from src.tf_utils import reset_tf_graph

from src.vanilla_gan import Vanilla_GAN
from src.w_gan_gp import W_GAN_GP
from src.generators_discriminators import point_cloud_generator,\
mlp_discriminator, leaky_relu

Instructions for updating:
Colocations handled automatically by placer.


In [2]:
# Auto-reload edited project modules (src.*) before each cell runs, and render
# matplotlib figures inline in the notebook.
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [3]:
# ShapeNet category to train on; must be a key of snc_category_to_synth_id().
class_name = 'chair'
experiment_name = 'raw_gan_with_w_gan_loss'
n_pc_points = 2048  # Number of points per model.

# Input: top-dir of the pre-sampled point clouds (relative to the repo root).
top_in_dir = osp.join(BASE, 'data/shape_net_core_uniform_samples_2048/')

# Output: neural-net check-points, samples, etc. go under log/pointnet/<class_name>.
top_out_dir = osp.join(BASE, 'log/pointnet', class_name)

In [4]:
# Resolve the category name to its ShapeNet synset id, then load every .ply
# point cloud found under that category's folder (using 8 loader threads).
syn_id = snc_category_to_synth_id()[class_name]
class_dir = osp.join(top_in_dir, syn_id)
all_pc_data = load_all_point_clouds_under_folder(
    class_dir,
    n_threads=8,
    file_ending='.ply',
    verbose=True,
)
print('Shape of DATA =', all_pc_data.point_clouds.shape)

6778 pclouds were loaded. They belong in 1 shape-classes.
Shape of DATA = (6778, 2048, 3)


### Set the GAN and training parameters.

In [5]:
use_wgan = True     # Wasserstein with gradient penalty, or not?
n_epochs = 10       # Epochs to train.

plot_train_curve = True
save_gan_model = True
saver_step = np.hstack([np.array([1, 5, 10]), np.arange(50, n_epochs + 1, 50)])

# Synthetic samples: when enabled, generated point clouds are written to disk
# (and a few are plotted) at each epoch listed in `saver_step`.
save_synthetic_samples = True
# Number of point clouds generated at each such epoch (= size of the training set).
n_syn_samples = all_pc_data.num_examples

# Optimization parameters.
init_lr = 1e-4
beta = 0.5        # ADAM's momentum term (passed to the GAN as `beta`; presumably beta1).
batch_size = 50

# Generator input noise: dimensionality and the parameters handed to
# gan._single_epoch_train / gan.generate.
noise_dim = 128
noise_params = {'mu': 0, 'sigma': 0.2}

# Shape of one generated sample: n_pc_points points with 3 coordinates each.
n_out = [n_pc_points, 3]


In [6]:
# Network-building functions handed to the GAN constructor further down.
generator, discriminator = point_cloud_generator, mlp_discriminator


In [7]:
# Folder for the TF check-points of this experiment.
if save_gan_model:
    train_dir = osp.join(top_out_dir, 'OUT/raw_gan', experiment_name)
    create_dir(train_dir)

# Folder for the .npz dumps of generated point clouds.
if save_synthetic_samples:
    synthetic_data_out_dir = osp.join(
        top_out_dir, 'OUT/synthetic_samples/', experiment_name)
    create_dir(synthetic_data_out_dir)

In [9]:
# Start from a clean TF graph, then build the requested GAN variant.
reset_tf_graph()

if not use_wgan:
    # Vanilla GAN: leaky-ReLU discriminator, no batch-norm.
    leak = 0.2
    disc_kwargs = {'non_linearity': leaky_relu(leak), 'b_norm': False}
    gan = Vanilla_GAN(experiment_name, init_lr, n_out, noise_dim,
                      discriminator, generator,
                      beta=beta, disc_kwargs=disc_kwargs)
else:
    # WGAN-GP; `lam` is presumably the gradient-penalty weight.
    lam = 10
    disc_kwargs = {'b_norm': False}
    gan = W_GAN_GP(experiment_name, init_lr, lam, n_out, noise_dim,
                   discriminator, generator,
                   disc_kwargs=disc_kwargs, beta=beta)

# Per-epoch (epoch, *losses) rows appended by the training loop.
train_stats = []
# Not used by the cells visible here.
accum_syn_data = []

dis <function mlp_discriminator at 0x7fa24af79b70>
gen <function point_cloud_generator at 0x7fa24af900d0>
gen out = gen(noise)
gen decoder
decoder_fc_0 Tensor("raw_gan_with_w_gan_loss_1/generator/decoder_fc_0/BiasAdd:0", shape=(?, 64), dtype=float32)
decoder_fc_1 Tensor("raw_gan_with_w_gan_loss_1/generator/decoder_fc_1/BiasAdd:0", shape=(?, 128), dtype=float32)
decoder_fc_2 Tensor("raw_gan_with_w_gan_loss_1/generator/decoder_fc_2/BiasAdd:0", shape=(?, 512), dtype=float32)
decoder_fc_3 Tensor("raw_gan_with_w_gan_loss_1/generator/decoder_fc_3/BiasAdd:0", shape=(?, 1024), dtype=float32)
gen_dec  Tensor("raw_gan_with_w_gan_loss_1/generator/FullyConnected/BiasAdd:0", shape=(?, 6144), dtype=float32)
gen_dec  Tensor("raw_gan_with_w_gan_loss_1/generator/Reshape:0", shape=(?, 2048, 3), dtype=float32)
dis(real)
dis encoder
encoder_conv_layer_0 Tensor("raw_gan_with_w_gan_loss_1/discriminator/raw_gan_with_w_gan_loss/discriminator/encoder_conv_layer_0/Squeeze:0", shape=(?, 2048, 64), dtype=float32)

In [17]:
# Debug dump of the GAN object's attributes (graph, tensors, hyper-parameters).
print(vars(gan))

{'graph': <tensorflow.python.framework.ops.Graph object at 0x7f8a3cb04198>, 'name': 'raw_gan_with_w_gan_loss', 'epoch': <tf.Variable 'raw_gan_with_w_gan_loss/epoch:0' shape=() dtype=float32_ref>, 'increment_epoch': <tf.Tensor 'raw_gan_with_w_gan_loss/AssignAdd:0' shape=() dtype=float32_ref>, 'no_op': <tf.Operation 'NoOp' type=NoOp>, 'noise_dim': 128, 'n_output': [2048, 3], 'discriminator': <function mlp_discriminator at 0x7f89a0792b70>, 'generator': <function point_cloud_generator at 0x7f89a07a90d0>, 'noise': <tf.Tensor 'raw_gan_with_w_gan_loss_1/Placeholder:0' shape=(?, 128) dtype=float32>, 'real_pc': <tf.Tensor 'raw_gan_with_w_gan_loss_1/Placeholder_1:0' shape=(?, 2048, 3) dtype=float32>, 'generator_out': <tf.Tensor 'raw_gan_with_w_gan_loss_1/generator/Reshape:0' shape=(?, 2048, 3) dtype=float32>, 'real_prob': <tf.Tensor 'raw_gan_with_w_gan_loss_1/discriminator/Sigmoid:0' shape=(?, 1) dtype=float32>, 'real_logit': <tf.Tensor 'raw_gan_with_w_gan_loss_1/discriminator/raw_gan_with_w_gan

In [None]:
# Train the GAN, one epoch per iteration. `gan.increment_epoch` bumps the model's
# persistent epoch counter, so re-running this cell resumes from the current
# epoch rather than restarting at 1.
for _ in range(n_epochs):
    loss, _duration = gan._single_epoch_train(all_pc_data, batch_size, noise_params)
    epoch = int(gan.sess.run(gan.increment_epoch))
    print(epoch, loss)

    # Check-points and synthetic samples share the same schedule.
    is_save_epoch = epoch in saver_step

    if save_gan_model and is_save_epoch:
        checkpoint_path = osp.join(train_dir, MODEL_SAVER_ID)
        gan.saver.save(gan.sess, checkpoint_path, global_step=gan.epoch)

    if save_synthetic_samples and is_save_epoch:
        syn_data = gan.generate(n_syn_samples, noise_params)
        # np.savez appends '.npz'; the array is stored under the key 'arr_0'.
        np.savez(osp.join(synthetic_data_out_dir, 'epoch_' + str(epoch)), syn_data)
        for k in range(3):  # plot the first three synthetic examples.
            plot_3d_point_cloud(syn_data[k][:, 0], syn_data[k][:, 1], syn_data[k][:, 2],
                                in_u_sphere=True)

    # tuple() so the row is built whether the losses come back as a tuple or a list.
    train_stats.append((epoch,) + tuple(loss))

In [None]:
# Plot discriminator and generator losses against the actual epoch numbers
# recorded by the training loop (epochs start at 1 and continue across re-runs).
if plot_train_curve:
    x = [t[0] for t in train_stats]
    d_loss = [t[1] for t in train_stats]
    g_loss = [t[2] for t in train_stats]

    fig, ax = plt.subplots()
    ax.plot(x, d_loss, '--', label='Discriminator')
    ax.plot(x, g_loss, label='Generator')
    ax.set_title('GAN training. (%s)' % (class_name,))
    ax.legend(loc=0)

    # Hide tick marks. These must be booleans: the 'on'/'off' strings were
    # deprecated in Matplotlib 2.2 and are not honoured by current releases.
    ax.tick_params(axis='x', which='both', bottom=False, top=False)
    ax.tick_params(axis='y', which='both', left=False, right=False)

    ax.set_xlabel('Epochs.')
    ax.set_ylabel('Loss.')