In [1]:
from glob import glob
from tqdm import tqdm
import pandas as pd
import tensorflow.compat.v1 as tf
tf.logging.set_verbosity(tf.logging.ERROR)  # Hide TF deprecation messages
import matplotlib.pyplot as plt
import trimesh
import random
import numpy as np
import pickle
import modules
import data_utils



2023-07-19 23:39:50.977391: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [2]:
import datetime
from tensorflow import summary as s
log_dir = "logs/text_gen/" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
# summary_writer = s.FileWriter(log_dir)

In [3]:
BATCH_SIZE=2

In [4]:
chair_meshes_paths = list(glob("chairs_ngon/*.obj"))
# chair_meshes_paths = [["Chair", path] for path in chair_meshes_paths]
tables_meshes_paths = list(glob("tables_ngon/*.obj"))
# tables_meshes_paths = [["Table", path] for path in tables_meshes_paths]
# chair_meshes_paths.extend(tables_meshes_paths)
# paths = chair_meshes_paths.copy()
# random.shuffle(paths)

In [5]:
with open("chairs_split_dict.pickle", 'rb') as f:
    chairs_split_dict = pickle.load(f)

In [6]:
with open("tables_split_dict.pickle", 'rb') as f:
    tables_split_dict = pickle.load(f)

In [7]:
chairs_train = []
chairs_val = []
chairs_test = []
for c in chair_meshes_paths: 
    try:
        split = chairs_split_dict[c.split("/")[-1].replace(".obj", "")]
    except KeyError:
#         print(c.split("/")[-1].replace(".obj", ""))
        continue
    if split =='train':
        chairs_train.append(c)
    elif split =='val':
        chairs_val.append(c)
    else:
        chairs_test.append(c)
# print(len(chairs_train))
# print(len(chairs_val))
# print(len(chairs_test))

In [8]:
tables_train = []
tables_val = []
tables_test = []
for c in tables_meshes_paths: 
    try:
        split = tables_split_dict[c.split("/")[-1].replace(".obj", "")]
    except KeyError:
#         print(c.split("/")[-1].replace(".obj", ""))
        continue
    if split =='train':
        tables_train.append(c)
    elif split =='val':
        tables_val.append(c)
    else:
        tables_test.append(c)
# print(len(tables_train))
# print(len(tables_val))
# print(len(tables_test))

In [9]:
chairs_train.extend(tables_train)
train_paths = chairs_train.copy()
random.shuffle(train_paths)
# train_paths = train_paths[:10]

In [10]:
TRAIN_SIZE = len(train_paths)

In [11]:
chairs_val.extend(tables_val)
val_paths = chairs_val.copy()
random.shuffle(val_paths)

In [12]:
VAL_SIZE = len(val_paths)

In [13]:
captions = pd.read_csv("captions_tablechair.csv").dropna()

In [14]:
# captions

In [15]:
max_length = 0
for c in captions['description'].values:
    cur = len(c.split(" "))
    if cur>max_length:
        max_length =cur
# max_length

In [16]:
train_captions=[]
for index, row in captions.iterrows():
    try:
        if row["category"]=="Table":
            if tables_split_dict[row["modelId"]]=='train':
                train_captions.append(row['description'])
        if row["category"]=="Chair":
            if chairs_split_dict[row["modelId"]]=='train':
                train_captions.append(row['description'])
    except KeyError:
        continue
# print(len(train_captions))

In [17]:
from tensorflow.keras.preprocessing.text import Tokenizer
tk = Tokenizer()
tk.fit_on_texts(train_captions)

In [18]:
def text2shape(paths, captions, tokenizer):
    for path in paths:
        # with open(path, 'rb') as obj_file:
        mesh_dict = data_utils.load_process_mesh(path)
#         mesh_dict['class_label'] = 18 if cls=="Chair" else 49
#         if len(mesh_dict['vertices'])>200:
#             continue
        if len(mesh_dict['faces'])>2600:
            continue
        try:
            text = captions[captions["modelId"]==path.split("/")[-1].replace(".obj", "")].sample(n=1)["description"].values[0]
        except:
            continue
        text = text.lower().replace("the", '').replace("a", '').replace("of", '').replace("for", '').replace("and", '').replace("to", '').replace("in", '')
        text = tokenizer.texts_to_sequences([text])[0]
        mesh_dict['text_feature'] = np.pad(text, (0,max_length-len(text)))
        yield mesh_dict

In [19]:
Text2ShapeDataset = tf.data.Dataset.from_generator(
        lambda:text2shape(train_paths, captions, tk),
        output_types={
            'vertices': tf.int32, 'faces': tf.int32,
#             'class_label': tf.int32,
            'text_feature': tf.int32},
        output_shapes={
            'vertices': tf.TensorShape([None, 3]), 'faces': tf.TensorShape([None]),
#             'class_label': tf.TensorShape(()),
            'text_feature':tf.TensorShape(140)})

In [20]:
vertex_model_dataset = data_utils.make_vertex_model_dataset(Text2ShapeDataset, apply_random_shift=False)
vertex_model_dataset = vertex_model_dataset.repeat()
vertex_model_dataset = vertex_model_dataset.padded_batch(BATCH_SIZE, padded_shapes=vertex_model_dataset.output_shapes)
vertex_model_dataset = vertex_model_dataset.prefetch(1)
it = vertex_model_dataset.make_initializable_iterator()
vertex_model_batch = it.get_next()
iterator_init_op_train = it.initializer

In [21]:
Text2ShapeDatasetVal = tf.data.Dataset.from_generator(
        lambda:text2shape(val_paths, captions, tk),
        output_types={
            'vertices': tf.int32, 'faces': tf.int32,
#             'class_label': tf.int32,
            'text_feature': tf.int32},
        output_shapes={
            'vertices': tf.TensorShape([None, 3]), 'faces': tf.TensorShape([None]),
#             'class_label': tf.TensorShape(()),
            'text_feature':tf.TensorShape(140)})
vertex_model_dataset_val = data_utils.make_vertex_model_dataset(Text2ShapeDatasetVal, apply_random_shift=False)
vertex_model_dataset_val = vertex_model_dataset_val.repeat()
vertex_model_dataset_val = vertex_model_dataset_val.padded_batch(BATCH_SIZE, padded_shapes=vertex_model_dataset_val.output_shapes)
vertex_model_dataset_val = vertex_model_dataset_val.prefetch(1)
itv = vertex_model_dataset_val.make_initializable_iterator()
vertex_model_batch_val = itv.get_next()
iterator_init_op_val = itv.initializer

In [22]:
o = (i for i in text2shape(train_paths, captions, tk))
TRAIN_SIZE = sum(1 for _ in o)
o = (i for i in text2shape(val_paths, captions, tk))
VAL_SIZE = sum(1 for _ in o)

  return np.array(vertices), np.array(faces)


In [23]:
TRAIN_SIZE

5606

In [24]:
VAL_SIZE

902

In [25]:
vertex_model = modules.TextToVertexModel(
    decoder_config=dict(
      hidden_size=512,
      fc_size=2048,
      num_heads=8,
      layer_norm=True,
      num_layers=24,
      dropout_rate=0.4,
      re_zero=True,
      memory_efficient=True
      ),
    path_to_embeddings="glove.42B.300d.txt",
    embedding_dims = 300,
    quantization_bits=8,
    tokenizer=tk,
    max_num_input_verts=5000,  # number of vertices in the input mesh, if this is lower than the number of vertices in the mesh, there will be errors during training
    use_discrete_embeddings=True
)

Found 1917494 word vectors.
Converted 8256 words (1212 misses)


In [26]:
vertex_model_pred_dist = vertex_model(vertex_model_batch)
vertex_model_loss = -tf.reduce_sum(
    vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat']) *
    vertex_model_batch['vertices_flat_mask'])
vertex_samples = vertex_model.sample(
    BATCH_SIZE, context=vertex_model_batch, max_sample_length=1500, top_p=0.95,
    recenter_verts=False, only_return_complete=False)

print(vertex_model_batch)
print(vertex_model_pred_dist)


{'vertices': <tf.Tensor 'IteratorGetNext:2' shape=(?, ?, 3) dtype=int32>, 'faces': <tf.Tensor 'IteratorGetNext:0' shape=(?, ?) dtype=int32>, 'text_feature': <tf.Tensor 'IteratorGetNext:1' shape=(?, 140) dtype=int32>, 'vertices_flat': <tf.Tensor 'IteratorGetNext:3' shape=(?, ?) dtype=int32>, 'vertices_flat_mask': <tf.Tensor 'IteratorGetNext:4' shape=(?, ?) dtype=float32>}
tfp.distributions.Categorical("vertex_model_2/vertex_model/create_dist/Categorical/", batch_shape=[?, ?], event_shape=[], dtype=int32)


In [27]:
vertex_model_pred_dist_val = vertex_model(vertex_model_batch_val)
vertex_model_loss_val = -tf.reduce_sum(
    vertex_model_pred_dist_val.log_prob(vertex_model_batch_val['vertices_flat']) *
    vertex_model_batch_val['vertices_flat_mask'])
vertex_samples_val = vertex_model.sample(
    BATCH_SIZE, context=vertex_model_batch_val, max_sample_length=1500, top_p=0.95,
    recenter_verts=False, only_return_complete=False)

In [28]:
# with open("only_text_vars.pickle", 'rb') as f:
#     only_text_vars = pickle.load(f)

# text_vars = []
# for var in vertex_model.variables:
#     if '/'.join((var._variable._name).split('/')[1:]) in only_text_vars:
#         text_vars.append(var)
# text_vars=tuple(text_vars)
# # print(text_vars)

In [29]:
import pickle
with open("pretrained_vars.pickle", 'rb') as f:
    common_vars = pickle.load(f)

pretrained_vars = []
for var in vertex_model.variables:
    if '/'.join((var._variable._name).split('/')[1:]) in common_vars:
        pretrained_vars.append(var)
pretrained_vars=tuple(pretrained_vars)
# print(pretrained_vars)

In [30]:
vertex_model_saver = tf.train.Saver(var_list=pretrained_vars)

In [31]:
# Create face model
face_model = modules.FaceModel(
      encoder_config=dict(
      hidden_size=512,
      fc_size=2048,
      num_heads=8,
      layer_norm=True,
      num_layers=10,
      dropout_rate=0.2,
      re_zero=True,
      memory_efficient=True,
      ),
  decoder_config=dict(
      hidden_size=512,
      fc_size=2048,
      num_heads=8,
      layer_norm=True,
      num_layers=14,
      dropout_rate=0.2,
      re_zero=True,
      memory_efficient=True,
      ),
    class_conditional=False,
    max_seq_length=8000, # number of faces in the input mesh, if this is lower than the number of vertices in the mesh, there will be errors during training
    quantization_bits=8,
    decoder_cross_attention=True,
    use_discrete_vertex_embeddings=True,
)

In [32]:
face_samples_val = face_model.sample(
    context=vertex_samples_val, max_sample_length=1000, top_p=0.95,
    only_return_complete=False)

In [33]:
face_model_saver = tf.train.Saver(var_list=face_model.variables)

In [None]:
from tqdm import trange
import os

last_saver = tf.train.Saver(var_list=vertex_model.variables) # will keep last 5 epochs
best_saver = tf.train.Saver(var_list=vertex_model.variables, max_to_keep=2)  # only keep 1 best checkpoint (best on eval)

# %matplotlib inline 
learning_rate = 5e-4
training_steps = 500
check_step_metrics = 10
check_step_samples = 100
EPOCHS = 25
optimizer = tf.train.AdamOptimizer(learning_rate)
# vertex_model_optim_op = optimizer.minimize(vertex_model_loss, var_list=text_vars)
vertex_model_optim_op = optimizer.minimize(vertex_model_loss, var_list=vertex_model.variables)
best_v_loss = float('inf')
# Training loop
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    
    vertex_model_saver.restore(sess, "vertex_model/model")
    face_model_saver.restore(sess, 'face_model/model')
    
#     train_writer = s.FileWriter(os.path.join(log_dir, 'train_summaries'), sess.graph)
#     eval_writer = s.FileWriter(os.path.join(log_dir, 'eval_summaries'), sess.graph)
    
    for e in range(EPOCHS):
        print("Epoch {}/{}".format(e + 1, EPOCHS))
        num_steps = (TRAIN_SIZE + BATCH_SIZE - 1) // BATCH_SIZE
        sess.run(iterator_init_op_train)
        t = trange(num_steps)
        loss_values = []
        for i in t:
            sess.run(vertex_model_optim_op)

            loss_val = sess.run(vertex_model_loss)
            loss_values.append(loss_val)
            # Log the loss in the tqdm progress bar
            t.set_postfix(loss='{:05.3f}'.format(loss_val))
        mean_loss = np.array(loss_values).mean()
        print("- Train loss: " + str(mean_loss))
        
        last_save_path = os.path.join("text_gen_pretrained", 'last')
        last_saver.save(sess, last_save_path, global_step=e + 1)
    
        num_steps = (VAL_SIZE + BATCH_SIZE - 1) // BATCH_SIZE
        sess.run(iterator_init_op_val)
        loss_values = []
        t = trange(num_steps)
        for i in t:
            loss_val = sess.run(vertex_model_loss_val)
            loss_values.append(loss_val)
            t.set_postfix(loss='{:05.3f}'.format(loss_val))
        mean_loss = np.array(loss_values).mean()
        print("- Eval metrics: " + str(mean_loss))

        if mean_loss<=best_v_loss:
            best_v_loss = mean_loss
            best_save_path = os.path.join("text_gen_pretrained", 'best')
            best_save_path = best_saver.save(sess, best_save_path, global_step=e + 1)
            print("- Found new best model, saving in {}".format(best_save_path))
                
        #SAmples
        if e>20:
            sess.run(iterator_init_op_val)
            v_samples_np, f_samples_np, b_np = sess.run((vertex_samples_val, face_samples_val, vertex_model_batch_val))
            mesh_list = []
            for n in range(BATCH_SIZE):
                mesh_list.append({
                    'vertices': v_samples_np['vertices'][n][:v_samples_np['num_vertices'][n]],
                    'faces': data_utils.unflatten_faces(
                        f_samples_np['faces'][n][:f_samples_np['num_face_indices'][n]])
                    })
            data_utils.plot_meshes(mesh_list, ax_lims=0.5)
                

2023-07-19 23:42:36.086412: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
2023-07-19 23:42:36.157777: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1669] Found device 0 with properties: 
name: NVIDIA GeForce RTX 3090 major: 8 minor: 6 memoryClockRate(GHz): 1.74
pciBusID: 0000:65:00.0
2023-07-19 23:42:36.157836: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
2023-07-19 23:42:36.190374: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
2023-07-19 23:42:36.330758: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcufft.so.10
2023-07-19 23:42:36.332168: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcurand.so.10
2023-07-19 23:42:36.379217: I tensorflow/stream_executor/platform/

Epoch 1/25


  0%|                                                                                                                  | 0/2803 [00:00<?, ?it/s]2023-07-19 23:43:37.156135: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcublas.so.11
  return np.array(vertices), np.array(faces)
2023-07-19 23:43:37.938881: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudnn.so.8
100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [21:28<00:00,  2.18it/s, loss=156.681]


- Train loss: 580.20966


100%|███████████████████████████████████████████████████████████████████████████████████████████| 451/451 [01:05<00:00,  6.91it/s, loss=685.813]


- Eval metrics: 545.78906
- Found new best model, saving in text_gen_pretrained/best-1
Epoch 2/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:40<00:00,  2.37it/s, loss=147.795]


- Train loss: 488.99664


100%|███████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.61it/s, loss=713.644]


- Eval metrics: 568.5297
Epoch 3/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:39<00:00,  2.38it/s, loss=159.339]


- Train loss: 466.3405


100%|███████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.63it/s, loss=713.567]


- Eval metrics: 600.6265
Epoch 4/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:38<00:00,  2.38it/s, loss=135.900]


- Train loss: 389.41794


100%|███████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.61it/s, loss=843.733]


- Eval metrics: 651.78143
Epoch 5/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:40<00:00,  2.37it/s, loss=152.627]


- Train loss: 388.4648


100%|███████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.61it/s, loss=855.284]


- Eval metrics: 705.25555
Epoch 6/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:38<00:00,  2.38it/s, loss=151.475]


- Train loss: 338.52786


100%|███████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.61it/s, loss=910.198]


- Eval metrics: 746.0661
Epoch 7/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:39<00:00,  2.38it/s, loss=135.456]


- Train loss: 312.0437


100%|██████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.63it/s, loss=1001.966]


- Eval metrics: 832.432
Epoch 8/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:39<00:00,  2.38it/s, loss=131.761]


- Train loss: 274.9527


100%|██████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.63it/s, loss=1093.651]


- Eval metrics: 890.1261
Epoch 9/25


100%|█████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:41<00:00,  2.37it/s, loss=104.839]


- Train loss: 238.92009


100%|██████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.61it/s, loss=1115.816]


- Eval metrics: 949.01544
Epoch 10/25


100%|██████████████████████████████████████████████████████████████████████████████████████████| 2803/2803 [19:43<00:00,  2.37it/s, loss=96.648]


- Train loss: 217.06757


100%|██████████████████████████████████████████████████████████████████████████████████████████| 451/451 [00:59<00:00,  7.61it/s, loss=1216.680]


- Eval metrics: 982.214
Epoch 11/25


 39%|██████████████████████████████████▍                                                      | 1084/2803 [07:36<13:04,  2.19it/s, loss=142.023]