In [1]:
from glob import glob
from tqdm import tqdm
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)  # Hide TF deprecation messages
import matplotlib.pyplot as plt
import random
import numpy as np
import pickle
import modules
import data_utils



2023-07-27 09:56:35.527598: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [3]:
# Number of meshes per batch; shared by the input pipelines, the samplers and
# the steps-per-epoch computation further down.
BATCH_SIZE = 16

In [4]:
# Every .obj mesh found directly inside chairs_ngon/ (glob already returns a list).
chair_meshes_paths = glob("chairs_ngon/*.obj")

In [5]:
# Mapping that is indexed below by mesh file stem to obtain its split name.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open("chairs_split_dict.pickle", "rb") as split_file:
    chairs_split_dict = pickle.load(split_file)

In [7]:
# Bucket every mesh path by the split recorded for its file stem.
# Meshes that have no entry in chairs_split_dict are dropped; any split value
# other than 'train' / 'val' lands in the test bucket.
chairs_train, chairs_val, chairs_test = [], [], []
for mesh_path in chair_meshes_paths:
    model_id = mesh_path.split("/")[-1].replace(".obj", "")
    try:
        split = chairs_split_dict[model_id]
    except KeyError:
        continue
    if split == 'train':
        bucket = chairs_train
    elif split == 'val':
        bucket = chairs_val
    else:
        bucket = chairs_test
    bucket.append(mesh_path)

In [9]:
# Shuffle a copy so chairs_train itself keeps its original order.
train_paths = list(chairs_train)
random.shuffle(train_paths)

In [11]:
# Shuffle a copy so chairs_val itself keeps its original order.
val_paths = list(chairs_val)
random.shuffle(val_paths)

In [13]:
# Mapping from mesh file stem to its collection of descriptions (each one is
# sliced and padded as a token sequence in text2shape).
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('chair_to_text.pickle', 'rb') as text_file:
    chair_to_text = pickle.load(text_file)

In [17]:
# Vocabulary handed to the text-conditioned vertex model as `vocab`.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('word_to_int.pickle', 'rb') as vocab_file:
    word_to_int = pickle.load(vocab_file)

In [None]:
# Inverse vocabulary, needed to turn token ids back into readable sentences.
# NOTE: pickle.load can execute arbitrary code -- only load trusted files.
with open('int_to_word.pickle', 'rb') as inverse_vocab_file:
    int_to_word = pickle.load(inverse_vocab_file)

In [18]:
# Fraction of descriptions that are shorter than 30 tokens; used to justify
# the truncation length chosen in the next cell.
# The lengths are kept under their own name: binding the array to `max_length`
# (as before) meant that re-running this cell after `max_length = 30` silently
# replaced the integer used by text2shape with an array.
lengths = [len(text) for texts in chair_to_text.values() for text in texts]
count = len(lengths)
text_lengths = np.array(lengths)
(text_lengths < 30).sum() / count

0.9860834232747128

In [19]:
# Fixed length of the text feature: text2shape truncates each description to
# this many tokens and zero-pads shorter ones. The statistic printed above
# shows ~98.6% of descriptions are shorter than 30 tokens.
max_length = 30

In [20]:
def text2shape(paths, max_vertices=800, max_faces=2800):
    """Yield processed meshes paired with one randomly chosen description.

    Args:
        paths: Iterable of ``.obj`` mesh paths. The file stem of each path is
            the key looked up in the module-level ``chair_to_text`` mapping.
        max_vertices: Meshes whose ``'vertices'`` entry is longer than this
            are skipped.
        max_faces: Meshes whose ``'faces'`` entry is longer than this are
            skipped.

    Yields:
        The dict returned by ``data_utils.load_process_mesh`` with an added
        ``'text_feature'`` entry: an int32 array of length ``max_length``
        holding the chosen description, truncated to ``max_length`` tokens and
        right-padded with zeros.
    """
    for path in paths:
        # Resolve the descriptions first: it is a cheap lookup, so meshes
        # without any text are skipped before the mesh file is loaded.
        model_id = path.split("/")[-1].replace(".obj", "")
        try:
            texts = chair_to_text[model_id]
        except KeyError:
            continue
        if len(texts) == 0:
            continue

        mesh_dict = data_utils.load_process_mesh(path)
        if len(mesh_dict['vertices']) > max_vertices:
            continue
        if len(mesh_dict['faces']) > max_faces:
            continue

        # Pick by index rather than np.random.choice(texts): choice() first
        # converts the list of token lists to an array, which is ragged
        # (deprecated / an error on recent NumPy) or 2-D (rejected by choice)
        # depending on the description lengths.
        text = texts[np.random.randint(len(texts))][:max_length]
        text_feature = np.zeros(max_length, dtype=np.int32)
        text_feature[:len(text)] = text
        mesh_dict['text_feature'] = text_feature
        yield mesh_dict

In [21]:
# tf.data wrapper around the training generator. The generator is re-created
# from `train_paths` every time the dataset is iterated.
def _train_example_generator():
    return text2shape(train_paths)


_train_output_types = {
    'vertices': tf.int32,
    'faces': tf.int32,
    'text_feature': tf.int32,
}
_train_output_shapes = {
    'vertices': tf.TensorShape([None, 3]),
    'faces': tf.TensorShape([None]),
    'text_feature': tf.TensorShape([max_length]),
}
Text2ShapeDataset = tf.data.Dataset.from_generator(
    _train_example_generator,
    output_types=_train_output_types,
    output_shapes=_train_output_shapes)

In [22]:
# Training input pipeline: vertex-model preprocessing -> endless repeat ->
# padded batches of BATCH_SIZE -> prefetch one batch ahead.
_train_examples = data_utils.make_vertex_model_dataset(
    Text2ShapeDataset, apply_random_shift=False).repeat()
vertex_model_dataset = _train_examples.padded_batch(
    BATCH_SIZE, padded_shapes=_train_examples.output_shapes).prefetch(1)

# Initializable iterator: running `iterator_init_op_train` restarts it.
it = vertex_model_dataset.make_initializable_iterator()
vertex_model_batch = it.get_next()
iterator_init_op_train = it.initializer

In [24]:
# Validation counterpart of the training dataset and input pipeline.
def _val_example_generator():
    return text2shape(val_paths)


_val_output_types = {
    'vertices': tf.int32,
    'faces': tf.int32,
    'text_feature': tf.int32,
}
_val_output_shapes = {
    'vertices': tf.TensorShape([None, 3]),
    'faces': tf.TensorShape([None]),
    'text_feature': tf.TensorShape([max_length]),
}
Text2ShapeDatasetVal = tf.data.Dataset.from_generator(
    _val_example_generator,
    output_types=_val_output_types,
    output_shapes=_val_output_shapes)

# Vertex-model preprocessing -> endless repeat -> padded batches -> prefetch.
_val_examples = data_utils.make_vertex_model_dataset(
    Text2ShapeDatasetVal, apply_random_shift=False).repeat()
vertex_model_dataset_val = _val_examples.padded_batch(
    BATCH_SIZE, padded_shapes=_val_examples.output_shapes).prefetch(1)

# Initializable iterator: running `iterator_init_op_val` restarts it.
itv = vertex_model_dataset_val.make_initializable_iterator()
vertex_model_batch_val = itv.get_next()
iterator_init_op_val = itv.initializer

In [25]:
# Count how many examples the generator actually yields for each split (i.e.
# after text2shape's own filtering). This walks every mesh once.
TRAIN_SIZE = sum(1 for _ in text2shape(train_paths))
VAL_SIZE = sum(1 for _ in text2shape(val_paths))

  text = np.random.choice(texts)[:max_length]


In [26]:
# Number of examples text2shape yielded for train_paths.
TRAIN_SIZE

1251

In [27]:
# Number of examples text2shape yielded for val_paths.
VAL_SIZE

216

In [28]:
# Transformer decoder settings for the text-conditioned vertex model.
_vertex_decoder_config = {
    'hidden_size': 128,
    'fc_size': 512,
    'num_heads': 8,
    'layer_norm': True,
    'num_layers': 24,
    'dropout_rate': 0.4,
    're_zero': True,
    'memory_efficient': True,
}

vertex_model = modules.TextToVertexModel(
    decoder_config=_vertex_decoder_config,
    path_to_embeddings="glove/glove.6B.100d.txt",
    embedding_dims=100,
    quantization_bits=8,
    vocab=word_to_int,
    # Number of vertices in the input mesh; if this is lower than the number
    # of vertices in a mesh, there will be errors during training.
    max_num_input_verts=800,
    use_discrete_embeddings=True,
)

Found 400000 word vectors.
Converted 3657 words (6 misses)


In [29]:
# Training graph: predictive distribution over the flattened vertex sequence.
vertex_model_pred_dist = vertex_model(vertex_model_batch, is_training=True)

# Loss = negative sum of the per-token log-probabilities, with each token
# weighted by 'vertices_flat_mask'.
_train_log_probs = vertex_model_pred_dist.log_prob(
    vertex_model_batch['vertices_flat'])
vertex_model_loss = -tf.reduce_sum(
    _train_log_probs * vertex_model_batch['vertices_flat_mask'])

# Sampling op conditioned on the current training batch.
vertex_samples = vertex_model.sample(
    BATCH_SIZE,
    context=vertex_model_batch,
    max_sample_length=800,
    top_p=0.95,
    recenter_verts=False,
    only_return_complete=False)

print(vertex_model_batch)
print(vertex_model_pred_dist)


{'vertices': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?, 3) dtype=int32>, 'faces': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int32>, 'text_feature': <tf.Tensor 'IteratorGetNext:1' shape=(?, 30) dtype=int32>, 'vertices_flat': <tf.Tensor 'IteratorGetNext:3' shape=(?, ?) dtype=int32>, 'vertices_flat_mask': <tf.Tensor 'IteratorGetNext:4' shape=(?, ?) dtype=float32>}
tfp.distributions.Categorical("vertex_model_2/vertex_model/create_dist/Categorical/", batch_shape=[?, ?], event_shape=[], dtype=int32)


In [30]:
# Validation graph: same model, called without is_training, fed by the
# validation iterator.
vertex_model_pred_dist_val = vertex_model(vertex_model_batch_val)

# Same masked negative log-likelihood as the training loss.
_val_log_probs = vertex_model_pred_dist_val.log_prob(
    vertex_model_batch_val['vertices_flat'])
vertex_model_loss_val = -tf.reduce_sum(
    _val_log_probs * vertex_model_batch_val['vertices_flat_mask'])

# Sampling op conditioned on the current validation batch.
vertex_samples_val = vertex_model.sample(
    BATCH_SIZE,
    context=vertex_model_batch_val,
    max_sample_length=800,
    top_p=0.95,
    recenter_verts=False,
    only_return_complete=False)

In [34]:
# Settings shared by the face model's encoder and decoder transformers; the
# two only differ in their number of layers.
_face_transformer_common = dict(
    hidden_size=512,
    fc_size=2048,
    num_heads=8,
    layer_norm=True,
    dropout_rate=0.2,
    re_zero=True,
    memory_efficient=True,
)

face_module_config = dict(
    encoder_config=dict(_face_transformer_common, num_layers=10),
    decoder_config=dict(_face_transformer_common, num_layers=14),
    class_conditional=False,
    decoder_cross_attention=True,
    use_discrete_vertex_embeddings=True,
    max_seq_length=8000,
)
face_model = modules.FaceModel(**face_module_config)

In [35]:
# Face sampling op that takes the vertex samples drawn for the validation
# batch as its context.
face_samples_val = face_model.sample(
    context=vertex_samples_val,
    max_sample_length=2800,
    top_p=0.95,
    only_return_complete=False,
)

In [36]:
# Saver restricted to the face model's variables; the training cell uses it
# to restore the face model from the "face_model/model" checkpoint.
face_model_saver = tf.train.Saver(var_list=face_model.variables)

In [None]:
#borrowed from https://github.com/optas/shapeglot
def token_ints_to_sentence(tokens, int_to_word):
    """Decode a sequence of token ids into a space-separated sentence.

    Args:
        tokens: Iterable of token ids, each a key of ``int_to_word``.
        int_to_word: Mapping from token id to word string.

    Returns:
        The words joined by single spaces, cut off right before the first
        ``'<EOS>'`` marker (the separating space in front of it is kept).
        When no marker is present the whole sentence is returned.

    Raises:
        KeyError: If a token id is missing from ``int_to_word``.
    """
    sentence = ' '.join(int_to_word[token] for token in tokens)
    # partition() yields the full string as its head when the marker is absent.
    before_eos, _, _ = sentence.partition('<EOS>')
    return before_eos

In [None]:
%matplotlib inline
from tqdm import trange
import os

last_saver_vertex = tf.train.Saver(var_list=vertex_model.variables) # will keep last 5 epochs
best_saver_vertex = tf.train.Saver(var_list=vertex_model.variables, max_to_keep=2)  # only keep 1 best

learning_rate = 1e-3
training_steps = 500
check_step_metrics = 50
check_step_samples = 50
EPOCHS = 500
optimizer = tf.train.AdamOptimizer(learning_rate)
vertex_model_optim_op = optimizer.minimize(vertex_model_loss, var_list=vertex_model.variables)
best_v_loss = float('inf')
# Training loop
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    face_model_saver.restore(sess, "face_model/model")

    for e in range(EPOCHS):
        print("Epoch {}/{}".format(e + 1, EPOCHS))
        num_steps = (TRAIN_SIZE + BATCH_SIZE - 1) // BATCH_SIZE
        sess.run(iterator_init_op_train)
        t = trange(num_steps)
        loss_values = []
        for i in t:
            sess.run(vertex_model_optim_op)

            loss_val = sess.run(vertex_model_loss)
            loss_values.append(loss_val)
 
            # Log the loss in the tqdm progress bar
            t.set_postfix(loss='{:05.3f}'.format(loss_val))
        mean_loss = np.array(loss_values).mean()
        print("- Train loss vertex: " + str(mean_loss))
        
        last_save_path = os.path.join("text_vertex_shapeglot", 'last')
        last_saver_vertex.save(sess, last_save_path, global_step=e + 1)
    
        num_steps = (VAL_SIZE + BATCH_SIZE - 1) // BATCH_SIZE
        sess.run(iterator_init_op_val)
        loss_values = []
        t = trange(num_steps)
        for i in t:
            loss_val = sess.run(vertex_model_loss_val)
            loss_values.append(loss_val)
            t.set_postfix(loss='{:05.3f}'.format(loss_val))
        mean_loss = np.array(loss_values).mean()
        print("- Eval loss vertex: " + str(mean_loss))
 
        if mean_loss<=best_v_loss:
            best_v_loss = mean_loss
            best_save_path = os.path.join("text_vertex_shapeglot", 'best')
            best_save_path = best_saver_vertex.save(sess, best_save_path, global_step=e + 1)
            print("- Found new best vertex model, saving in {}".format(best_save_path))

        if e>100 and e % check_step_samples==0:
            sess.run(iterator_init_op_val)
            v_samples_np, f_samples_np, b_np = sess.run((vertex_samples_val, face_samples_val, vertex_model_batch_val))
            print(token_ints_to_sentence(b_np['text_feature']))
            mesh_list = []
            for n in range(BATCH_SIZE):
                mesh_list.append({
                    'vertices': v_samples_np['vertices'][n][:v_samples_np['num_vertices'][n]],
                    'faces': data_utils.unflatten_faces(
                        f_samples_np['faces'][n][:f_samples_np['num_face_indices'][n]])
                    })
            data_utils.plot_meshes(mesh_list, ax_lims=0.5)
                

2023-07-27 09:57:33.154538: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1669] Found device 0 with properties: 
name: NVIDIA GeForce RTX 3090 major: 8 minor: 6 memoryClockRate(GHz): 1.74
pciBusID: 0000:65:00.0
2023-07-27 09:57:33.154578: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2023-07-27 09:57:33.154597: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2023-07-27 09:57:33.154604: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
2023-07-27 09:57:33.154610: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
2023-07-27 09:57:33.154616: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcusolver.so.11
2023-07-27 09:57:33.154622: I tensorflow/stream_executor/plat

Epoch 1/500


  0%|                                                                           | 0/79 [00:00<?, ?it/s]2023-07-27 09:58:27.953523: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
  return np.array(vertices), np.array(faces)
  text = np.random.choice(texts)[:max_length]
2023-07-27 09:58:29.097785: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
100%|██████████████████████████████████████████████████| 79/79 [03:54<00:00,  2.96s/it, loss=59472.551]


- Train loss vertex: 51872.543


100%|██████████████████████████████████████████████████| 14/14 [00:11<00:00,  1.20it/s, loss=49038.879]


- Eval loss vertex: 43216.562
- Found new best vertex model, saving in text_vertex_shapeglot/best-1
Epoch 2/500


100%|██████████████████████████████████████████████████| 79/79 [02:56<00:00,  2.23s/it, loss=52437.812]


- Train loss vertex: 50180.895


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.14it/s, loss=44011.266]


- Eval loss vertex: 40390.7
- Found new best vertex model, saving in text_vertex_shapeglot/best-2
Epoch 3/500


100%|██████████████████████████████████████████████████| 79/79 [02:53<00:00,  2.20s/it, loss=47381.129]


- Train loss vertex: 44244.082


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.17it/s, loss=39816.707]


- Eval loss vertex: 34321.223
- Found new best vertex model, saving in text_vertex_shapeglot/best-3
Epoch 4/500


100%|██████████████████████████████████████████████████| 79/79 [02:54<00:00,  2.21s/it, loss=39879.605]


- Train loss vertex: 39615.055


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.19it/s, loss=36147.078]


- Eval loss vertex: 31000.309
- Found new best vertex model, saving in text_vertex_shapeglot/best-4
Epoch 5/500


100%|██████████████████████████████████████████████████| 79/79 [02:49<00:00,  2.14s/it, loss=46294.305]


- Train loss vertex: 36331.83


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.26it/s, loss=33765.055]


- Eval loss vertex: 29055.16
- Found new best vertex model, saving in text_vertex_shapeglot/best-5
Epoch 6/500


100%|██████████████████████████████████████████████████| 79/79 [02:52<00:00,  2.18s/it, loss=38560.195]


- Train loss vertex: 34283.39


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.15it/s, loss=30354.266]


- Eval loss vertex: 27377.316
- Found new best vertex model, saving in text_vertex_shapeglot/best-6
Epoch 7/500


100%|██████████████████████████████████████████████████| 79/79 [02:50<00:00,  2.16s/it, loss=36493.289]


- Train loss vertex: 32315.512


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.20it/s, loss=30250.656]


- Eval loss vertex: 25598.201
- Found new best vertex model, saving in text_vertex_shapeglot/best-7
Epoch 8/500


100%|██████████████████████████████████████████████████| 79/79 [02:52<00:00,  2.18s/it, loss=30913.875]


- Train loss vertex: 30208.379


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.22it/s, loss=27309.270]


- Eval loss vertex: 23322.41
- Found new best vertex model, saving in text_vertex_shapeglot/best-8
Epoch 9/500


100%|██████████████████████████████████████████████████| 79/79 [02:49<00:00,  2.15s/it, loss=28310.289]


- Train loss vertex: 27516.31


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.16it/s, loss=23745.328]


- Eval loss vertex: 21244.559
- Found new best vertex model, saving in text_vertex_shapeglot/best-9
Epoch 10/500


100%|██████████████████████████████████████████████████| 79/79 [02:49<00:00,  2.15s/it, loss=28279.914]


- Train loss vertex: 25328.227


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.17it/s, loss=22305.803]


- Eval loss vertex: 19588.059
- Found new best vertex model, saving in text_vertex_shapeglot/best-10
Epoch 11/500


100%|██████████████████████████████████████████████████| 79/79 [02:50<00:00,  2.16s/it, loss=26388.121]


- Train loss vertex: 23979.469


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.24it/s, loss=22812.193]


- Eval loss vertex: 18740.941
- Found new best vertex model, saving in text_vertex_shapeglot/best-11
Epoch 12/500


100%|██████████████████████████████████████████████████| 79/79 [02:50<00:00,  2.15s/it, loss=24009.742]


- Train loss vertex: 23180.156


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.19it/s, loss=21972.594]


- Eval loss vertex: 18172.277
- Found new best vertex model, saving in text_vertex_shapeglot/best-12
Epoch 13/500


100%|██████████████████████████████████████████████████| 79/79 [02:50<00:00,  2.16s/it, loss=25081.826]


- Train loss vertex: 22559.55


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.20it/s, loss=21656.945]


- Eval loss vertex: 17562.56
- Found new best vertex model, saving in text_vertex_shapeglot/best-13
Epoch 14/500


100%|██████████████████████████████████████████████████| 79/79 [02:49<00:00,  2.14s/it, loss=25079.188]


- Train loss vertex: 21966.775


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.21it/s, loss=21099.895]


- Eval loss vertex: 17201.158
- Found new best vertex model, saving in text_vertex_shapeglot/best-14
Epoch 15/500


100%|██████████████████████████████████████████████████| 79/79 [02:51<00:00,  2.17s/it, loss=22622.484]


- Train loss vertex: 21571.186


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.23it/s, loss=20753.219]


- Eval loss vertex: 17018.725
- Found new best vertex model, saving in text_vertex_shapeglot/best-15
Epoch 16/500


100%|██████████████████████████████████████████████████| 79/79 [02:51<00:00,  2.17s/it, loss=21128.797]


- Train loss vertex: 21204.41


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.22it/s, loss=20301.107]


- Eval loss vertex: 16668.73
- Found new best vertex model, saving in text_vertex_shapeglot/best-16
Epoch 17/500


100%|██████████████████████████████████████████████████| 79/79 [02:48<00:00,  2.13s/it, loss=20558.367]


- Train loss vertex: 20921.06


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.15it/s, loss=18329.016]


- Eval loss vertex: 16495.613
- Found new best vertex model, saving in text_vertex_shapeglot/best-17
Epoch 18/500


100%|██████████████████████████████████████████████████| 79/79 [02:50<00:00,  2.15s/it, loss=23263.969]


- Train loss vertex: 20704.244


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.20it/s, loss=19733.584]


- Eval loss vertex: 16256.7
- Found new best vertex model, saving in text_vertex_shapeglot/best-18
Epoch 19/500


100%|██████████████████████████████████████████████████| 79/79 [02:51<00:00,  2.17s/it, loss=19738.770]


- Train loss vertex: 20324.066


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.24it/s, loss=20144.129]


- Eval loss vertex: 15946.819
- Found new best vertex model, saving in text_vertex_shapeglot/best-19
Epoch 20/500


100%|██████████████████████████████████████████████████| 79/79 [02:49<00:00,  2.15s/it, loss=22412.285]


- Train loss vertex: 20053.434


100%|██████████████████████████████████████████████████| 14/14 [00:06<00:00,  2.15it/s, loss=18418.898]


- Eval loss vertex: 15860.804
- Found new best vertex model, saving in text_vertex_shapeglot/best-20
Epoch 21/500


 10%|█████▏                                             | 8/79 [00:18<02:50,  2.40s/it, loss=15045.428]

In [None]:
# NOTE(review): this cell looks like an unfinished evaluation stub -- `acc`
# and `loss` are created but never filled, and `sess` is not used inside the
# loop. TODO: complete it or remove it.
acc = []
loss = []
with tf.Session() as sess:
    # Walks the validation generator; only the 'vertices' entry is read.
    for sh in text2shape(val_paths):
        a = sh['vertices']
        
    