In [1]:
import os
import pickle
from glob import glob

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from tqdm import tqdm

import tensorflow.compat.v1 as tf

tf.logging.set_verbosity(tf.logging.ERROR)  # Hide TF deprecation messages
from tensorflow.keras.preprocessing.text import Tokenizer

import modules
import data_utils

ModuleNotFoundError: No module named 'tensorflow'

In [None]:
# Partition the chair meshes into train/val/test lists using the precomputed
# split dictionary (mesh file stem -> split name).
# NOTE(review): pickle.load can execute arbitrary code; only load a split file
# that comes from a trusted source.
# sorted() makes the path order (and therefore dataset order) reproducible;
# glob's order depends on the filesystem.
chair_meshes_paths = sorted(glob("chairs_ngon/*.obj"))
with open("chairs_split_dict.pickle", 'rb') as f:
    chairs_split_dict = pickle.load(f)

chairs_train = []
chairs_val = []
chairs_test = []
for mesh_path in chair_meshes_paths:
    # Model id = file name without its .obj extension. os.path handles either
    # path separator (splitting on "/" silently skipped every mesh on Windows).
    model_id = os.path.splitext(os.path.basename(mesh_path))[0]
    if model_id not in chairs_split_dict:
        continue  # mesh has no assigned split
    split = chairs_split_dict[model_id]
    if split == 'train':
        chairs_train.append(mesh_path)
    elif split == 'val':
        chairs_val.append(mesh_path)
    else:
        # Any split label other than train/val is treated as test.
        chairs_test.append(mesh_path)

In [None]:
train_paths = chairs_train  # unused by the sampling cells below
val_paths = chairs_val      # meshes whose captions condition the samples
max_length = 30             # length every caption feature is truncated / zero-padded to
# Batch size for the validation input pipeline. It was referenced there but
# never defined, which raised NameError. vertex_model.sample is asked for a
# single sample, so a batch of 1 is enough.
BATCH_SIZE = 1

In [None]:
# Caption table (chairs and tables); any row with a missing field is discarded.
captions = (
    pd.read_csv("captions_tablechair.csv")
    .dropna()
)

In [None]:
# Captions of chairs in the training split; the tokenizer is fitted on these
# only. Table captions are skipped (no tables split dictionary is loaded here).
# Series.map yields NaN for model ids missing from chairs_split_dict, so those
# rows drop out exactly as the KeyError skip in a row loop would, without the
# per-row cost of DataFrame.iterrows.
train_chair_mask = (
    (captions["category"] == "Chair")
    & (captions["modelId"].map(chairs_split_dict) == "train")
)
train_captions = captions.loc[train_chair_mask, "description"].tolist()

In [None]:
# Word index built only from training-split chair captions.
tk = Tokenizer()
tk.fit_on_texts(train_captions)

In [None]:
# Caption stop-words, dropped as whole tokens before building text features.
CAPTION_STOPWORDS = ("the", "a", "of", "for", "and", "to", "in")


def text2shape(paths, captions, tokenizer, max_num_vertices=500,
               max_num_face_indices=2600):
    """Yield processed meshes, each paired with one tokenized caption.

    For every path, the model id is the file name without its extension. A
    mesh is skipped when `captions` has no row for that id, when it has more
    than `max_num_vertices` vertices, or when len(mesh_dict['faces']) exceeds
    `max_num_face_indices`. One matching caption is drawn at random
    (Series.sample, unseeded), lower-cased, tokenized, stripped of
    CAPTION_STOPWORDS, then truncated / zero-padded to the global `max_length`.

    Args:
        paths: iterable of mesh file paths (.obj).
        captions: DataFrame with "modelId" and "description" columns.
        tokenizer: fitted Keras Tokenizer (uses texts_to_sequences, word_index).
        max_num_vertices: vertex-count cap; keep it <= the vertex model's
            max_num_input_verts (500 in this notebook).
        max_num_face_indices: cap on len(mesh_dict['faces']).

    Yields:
        The dict returned by data_utils.load_process_mesh with an added
        'text_feature' entry: an int32 array of length `max_length`.

    NOTE(review): stop-words used to be removed with str.replace, which also
    deleted those letter runs inside other words ("chair" -> "chir", every
    "a"). Words were also truncated before tokenization, so punctuation splits
    could exceed max_length and make np.pad fail. A checkpoint trained with the
    old preprocessing saw different inputs; keep training and inference
    preprocessing identical (retrain if needed).
    """
    # Stop-words are matched on token ids, i.e. as whole words after the
    # tokenizer's own lower-casing and punctuation filtering.
    stop_ids = {tokenizer.word_index[word] for word in CAPTION_STOPWORDS
                if word in tokenizer.word_index}
    for path in paths:
        model_id = os.path.splitext(os.path.basename(path))[0]
        descriptions = captions.loc[captions["modelId"] == model_id,
                                    "description"]
        if descriptions.empty:
            # No caption for this model: skip before paying for mesh loading.
            continue
        mesh_dict = data_utils.load_process_mesh(path)
        if len(mesh_dict['vertices']) > max_num_vertices:
            continue
        if len(mesh_dict['faces']) > max_num_face_indices:
            continue
        text = descriptions.sample(n=1).values[0]
        token_ids = [
            token_id
            for token_id in tokenizer.texts_to_sequences([text.lower()])[0]
            if token_id not in stop_ids
        ][:max_length]
        # Zero-pad to a fixed length (0 is never assigned to a word by the
        # Keras Tokenizer).
        text_feature = np.zeros(max_length, dtype=np.int32)
        text_feature[:len(token_ids)] = token_ids
        mesh_dict['text_feature'] = text_feature
        yield mesh_dict

In [None]:
# Validation input pipeline: captioned meshes from text2shape go through
# data_utils.make_vertex_model_dataset, then are repeated, padded-batched and
# prefetched. The iterator is initializable, so iterator_init_op_val has to be
# run in the session before vertex_model_batch_val can be evaluated.
# (class_label conditioning is disabled, so only these three features exist.)
_val_output_types = {
    'vertices': tf.int32,
    'faces': tf.int32,
    'text_feature': tf.int32,
}
_val_output_shapes = {
    'vertices': tf.TensorShape([None, 3]),
    'faces': tf.TensorShape([None]),
    'text_feature': tf.TensorShape([max_length]),
}
Text2ShapeDatasetVal = tf.data.Dataset.from_generator(
    lambda: text2shape(val_paths, captions, tk),
    output_types=_val_output_types,
    output_shapes=_val_output_shapes,
)

vertex_model_dataset_val = data_utils.make_vertex_model_dataset(
    Text2ShapeDatasetVal, apply_random_shift=False).repeat()
vertex_model_dataset_val = vertex_model_dataset_val.padded_batch(
    BATCH_SIZE,
    padded_shapes=vertex_model_dataset_val.output_shapes,
).prefetch(1)

itv = vertex_model_dataset_val.make_initializable_iterator()
vertex_model_batch_val = itv.get_next()
iterator_init_op_val = itv.initializer

In [None]:
# Text-conditioned vertex model. It receives the fitted tokenizer `tk` and a
# 100-d GloVe file (presumably used to build its word embeddings; confirm in
# modules.TextToVertexModel). These hyperparameters must match the checkpoint
# restored into the model, otherwise variable shapes differ.
vertex_model = modules.TextToVertexModel(
    decoder_config={
        'hidden_size': 128,
        'fc_size': 512,
        'num_heads': 8,
        'layer_norm': True,
        'num_layers': 24,
        'dropout_rate': 0.4,
        're_zero': True,
        'memory_efficient': True,
    },
    path_to_embeddings="glove.6B/glove.6B.100d.txt",
    embedding_dims=100,
    quantization_bits=8,
    tokenizer=tk,
    # Number of vertices in the input mesh; if this is lower than the number
    # of vertices in a mesh, there will be errors during training.
    max_num_input_verts=500,
    use_discrete_embeddings=True,
)

In [None]:
# Draw one vertex sample conditioned on the next validation batch
# (top_p=0.95, presumably nucleus sampling; max_sample_length=500), without
# recentering, requesting only completed samples.
vertex_samples_val = vertex_model.sample(
    1,
    context=vertex_model_batch_val,
    max_sample_length=500,
    top_p=0.95,
    recenter_verts=False,
    only_return_complete=True,
)

In [None]:

# Face model conditioned (via cross-attention) on the sampled vertices.
# Encoder and decoder transformers share every setting except depth.
_face_transformer_common = {
    'hidden_size': 512,
    'fc_size': 2048,
    'num_heads': 8,
    'layer_norm': True,
    'dropout_rate': 0.2,
    're_zero': True,
    'memory_efficient': True,
}
face_module_config = {
    'encoder_config': {**_face_transformer_common, 'num_layers': 10},
    'decoder_config': {**_face_transformer_common, 'num_layers': 14},
    'class_conditional': False,
    'decoder_cross_attention': True,
    'use_discrete_vertex_embeddings': True,
    'max_seq_length': 8000,
}
face_model = modules.FaceModel(**face_module_config)
face_samples_val = face_model.sample(
    context=vertex_samples_val,
    max_sample_length=2000,
    top_p=0.95,
    only_return_complete=True,
)
# Saver limited to the face model's variables, so it restores independently
# of the vertex model.
face_model_saver = tf.train.Saver(var_list=face_model.variables)

In [None]:
# Saver limited to the vertex model's variables (first positional argument of
# Saver is var_list), so it restores independently of the face model.
saver_vertex = tf.train.Saver(vertex_model.variables)

In [None]:
# Sample text-conditioned meshes from the validation pipeline: vertices first,
# then faces conditioned on those vertices, until at least num_samples_min
# complete meshes are collected or max_attempts batches have been tried.
num_samples_min = 1
max_attempts = 20  # avoid an endless loop if sampling never completes

mesh_list = []
num_samples_complete = 0

with tf.Session() as sess:
    saver_vertex.restore(sess, "text_vertex/last-141")
    face_model_saver.restore(sess, "face_model/model")
    # The validation iterator is initializable: get_next() fails with an
    # uninitialized-iterator error until its initializer has been run.
    sess.run(iterator_init_op_val)
    attempts = 0
    while num_samples_complete < num_samples_min:
        if attempts >= max_attempts:
            print('Stopping after {} attempts with {} complete sample(s).'.format(
                attempts, num_samples_complete))
            break
        attempts += 1
        v_samples_np = sess.run(vertex_samples_val)
        # Print the evaluated caption token ids, not the symbolic tensor.
        print(v_samples_np['text_feature'])
        if v_samples_np['completed'].size == 0:
            print('No vertex samples completed in this batch. Try increasing ' +
                  'max_sample_length.')
            continue
        # Feed the sampled vertices back in so face sampling conditions on
        # exactly these values instead of re-running vertex sampling.
        f_samples_np = sess.run(
            face_samples_val,
            {vertex_samples_val[k]: v_samples_np[k] for k in vertex_samples_val.keys()})
        # Presumably the vertex context matching the returned faces.
        v_samples_np = f_samples_np['context']
        num_samples_complete_batch = f_samples_np['completed'].sum()
        num_samples_complete += num_samples_complete_batch
        print('Num. samples complete: {}'.format(num_samples_complete))
        # Taking the first N entries assumes only_return_complete=True already
        # filtered the outputs down to completed samples.
        for k in range(num_samples_complete_batch):
            verts = v_samples_np['vertices'][k][:v_samples_np['num_vertices'][k]]
            faces = data_utils.unflatten_faces(
                f_samples_np['faces'][k][:f_samples_np['num_face_indices'][k]])
            mesh_list.append({'vertices': verts, 'faces': faces})

if mesh_list:
    data_utils.plot_meshes(mesh_list, ax_lims=0.4)
else:
    print('No complete meshes to plot.')