In [1]:
import os
import pickle
from glob import glob

import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
from tqdm import tqdm

tf.logging.set_verbosity(tf.logging.ERROR)  # Hide TF deprecation messages (kept before the local imports below)
import modules
import data_utils



2023-07-31 12:42:02.194761: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [2]:
# Split the chair meshes into train/val/test lists using the precomputed split dict.
# Paths are sorted because glob order is filesystem-dependent; sorting makes the split
# lists (and therefore the order in which meshes are visited later) reproducible.
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted files.
chair_meshes_paths = sorted(glob("chairs_ngon/*.obj"))
with open("chairs_split_dict.pickle", 'rb') as f:
    chairs_split_dict = pickle.load(f)
train_paths = []
val_paths = []
test_paths = []
# 'train' and 'val' get their own lists; any other label goes to test_paths.
_split_to_list = {'train': train_paths, 'val': val_paths}
for c in chair_meshes_paths:
    # Mesh id = file name without directory and without the ".obj" extension.
    mesh_id = os.path.splitext(os.path.basename(c))[0]
    try:
        split = chairs_split_dict[mesh_id]
    except KeyError:
        continue  # mesh not listed in the split dict: ignore it
    _split_to_list.get(split, test_paths).append(c)

In [3]:
max_length: int = 30  # caption token budget: captions are truncated / zero-padded to this many tokens

In [4]:
# Caption table; rows with any missing field are discarded.
# NOTE(review): `captions` is not used in the visible cells -- confirm it is still needed.
captions = pd.read_csv("captions_tablechair.csv")
captions = captions.dropna()

In [6]:
def _unpickle(path):
    """Load and return the single object stored in the pickle file at `path`.

    NOTE(review): pickle.load can execute arbitrary code -- only use trusted files.
    """
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Caption lookup per mesh id, plus the word <-> token-id vocabularies.
chair_to_text = _unpickle('chair_to_text_full.pickle')
word_to_int = _unpickle('word_to_int_full.pickle')
int_to_word = _unpickle('int_to_word_full.pickle')

In [7]:
max_length: int = 30  # re-declared so this cell runs on its own; keep in sync with the first definition

In [9]:
def text2shape(paths, max_vertices=800, max_face_indices=2800):
    """Yield processed meshes paired with one randomly chosen, fixed-length caption.

    For each path, the mesh id (file name without directory and extension) is looked
    up in the global ``chair_to_text``. A mesh is skipped when:
    - it has no captions, or an empty caption list;
    - it has more than ``max_vertices`` vertices;
    - its ``faces`` sequence is longer than ``max_face_indices``.

    For the remaining meshes, one caption is drawn uniformly at random from the
    global NumPy RNG. It is truncated to ``max_length`` tokens and right-padded
    with zeros.

    Args:
        paths: iterable of mesh file paths.
        max_vertices: vertex-count limit (default 800, the value passed as
            ``max_num_input_verts`` to the vertex model).
        max_face_indices: limit on ``len(mesh_dict['faces'])`` (default 2800, the
            face sampler's ``max_sample_length``).

    Yields:
        The dict returned by ``data_utils.load_process_mesh`` with an added
        ``'text_feature'`` int32 array of length ``max_length``.
    """
    for path in paths:
        mesh_id = os.path.splitext(os.path.basename(path))[0]
        # Caption lookup is cheap; do it before the (expensive) mesh load/processing.
        try:
            texts = chair_to_text[mesh_id]
        except KeyError:
            continue
        if len(texts) == 0:
            continue  # no caption to sample from
        mesh_dict = data_utils.load_process_mesh(path)
        if len(mesh_dict['vertices']) > max_vertices:
            continue
        if len(mesh_dict['faces']) > max_face_indices:
            continue
        # Index-based pick: np.random.choice needs a 1-D array and raises when all
        # captions have the same length (2-D) or, in newer NumPy, are ragged.
        text = np.asarray(texts[np.random.randint(len(texts))], dtype=np.int32)[:max_length]
        mesh_dict['text_feature'] = np.pad(text, (0, max_length - len(text)))
        yield mesh_dict

In [11]:
# Evaluation input pipeline over the *test* split.
# text2shape yields mesh dicts with a fixed-length caption. The project helper
# make_vertex_model_dataset turns them into vertex-model inputs. The stream then
# repeats forever, is padded into batches of 16, and is prefetched.
Text2ShapeDataset = tf.data.Dataset.from_generator(
    lambda: text2shape(test_paths),
    output_types=dict(
        vertices=tf.int32,
        faces=tf.int32,
        text_feature=tf.int32,
    ),
    output_shapes=dict(
        vertices=tf.TensorShape([None, 3]),
        faces=tf.TensorShape([None]),
        text_feature=tf.TensorShape([max_length]),
    ),
)
vertex_model_dataset = data_utils.make_vertex_model_dataset(
    Text2ShapeDataset, apply_random_shift=False).repeat()
# The batch size (16) must agree with the number of samples drawn by vertex_model.sample.
vertex_model_dataset = vertex_model_dataset.padded_batch(
    16, padded_shapes=vertex_model_dataset.output_shapes).prefetch(1)

In [12]:
# Text-conditioned vertex model. Captions are presumably embedded with the 100-d GloVe
# vectors at `path_to_embeddings`, over the vocabulary in word_to_int.
vertex_model = modules.TextToVertexModel(
    decoder_config={
        'hidden_size': 128,
        'fc_size': 512,
        'num_heads': 8,
        'layer_norm': True,
        'num_layers': 24,
        'dropout_rate': 0.4,
        're_zero': True,
        'memory_efficient': True,
    },
    path_to_embeddings="glove/glove.6B.100d.txt",
    embedding_dims=100,
    quantization_bits=8,
    vocab=word_to_int,
    # Upper bound on vertices per input mesh. Meshes with more vertices than this cause
    # errors during training, which is why text2shape drops meshes above 800 vertices.
    max_num_input_verts=800,
    use_discrete_embeddings=True,
)

Found 400000 word vectors.
Converted 3657 words (6 misses)


In [13]:
# Adapted from https://github.com/optas/shapeglot
def token_ints_to_sentence(tokens, int_to_word, stop_token='<EOS>'):
    """Decode token ids into a space-separated sentence.

    Decoding stops at the first token whose word equals ``stop_token``. That token and
    everything after it (e.g. zero padding) is dropped and never looked up.

    Args:
        tokens: iterable of integer token ids (Python or NumPy ints).
        int_to_word: mapping (or sequence) from token id to word.
        stop_token: word that terminates the sentence (default ``'<EOS>'``).

    Returns:
        The words before the stop token joined by single spaces, with no trailing
        whitespace; ``''`` for empty input or when the first token is the stop token.
    """
    words = []
    for token in tokens:
        word = int_to_word[token]
        if word == stop_token:
            break
        words.append(word)
    return ' '.join(words)

In [19]:
# Graph for text-conditioned vertex sampling: each dataset batch is used as context for
# 16 sampled vertex sequences (top_p=0.95; incomplete samples are kept).
it = vertex_model_dataset.make_initializable_iterator()
vertex_model_batch = it.get_next()
iterator_init_op = it.initializer
vertex_samples = vertex_model.sample(
    16,
    context=vertex_model_batch,
    max_sample_length=800,
    top_p=0.95,
    recenter_verts=False,
    only_return_complete=False,
)
# Saver restricted to the vertex model's variables (restored from its own checkpoint).
saver_vertex = tf.train.Saver(var_list=vertex_model.variables)

In [None]:
# Face model hyperparameters: separate encoder and decoder configs, with decoder
# cross-attention enabled and no class conditioning.
face_module_config = {
    'encoder_config': {
        'hidden_size': 512,
        'fc_size': 2048,
        'num_heads': 8,
        'layer_norm': True,
        'num_layers': 10,
        'dropout_rate': 0.2,
        're_zero': True,
        'memory_efficient': True,
    },
    'decoder_config': {
        'hidden_size': 512,
        'fc_size': 2048,
        'num_heads': 8,
        'layer_norm': True,
        'num_layers': 14,
        'dropout_rate': 0.2,
        're_zero': True,
        'memory_efficient': True,
    },
    'class_conditional': False,
    'decoder_cross_attention': True,
    'use_discrete_vertex_embeddings': True,
    'max_seq_length': 8000,
}
face_model = modules.FaceModel(**face_module_config)
# Sample faces conditioned on the sampled vertices, keeping only completed samples.
face_samples_val = face_model.sample(
    context=vertex_samples,
    max_sample_length=2800,
    top_p=0.95,
    only_return_complete=True,
)
face_model_saver = tf.train.Saver(var_list=face_model.variables)

In [None]:
def _pair_completed_samples(fed_vertex_samples, face_samples):
    """Return ``(batch_index, sample_index)`` pairs for every completed face sample.

    ``batch_index`` indexes the dataset batch and the fed vertex samples, i.e. the
    ground-truth mesh and caption. ``sample_index`` indexes ``face_samples``, i.e.
    the generated faces and their vertex context.

    When ``face_samples['completed']`` has one flag per fed sample, the completed
    rows pair with themselves. When it has fewer entries, the face model has already
    dropped incomplete samples. NOTE(review): presumably this is what
    ``only_return_complete=True`` does in ``modules.FaceModel.sample``; confirm.
    Masking preserves order, so each survivor's batch index is recovered by a forward
    scan that matches its vertex context against the fed vertex samples.
    """
    completed = np.asarray(face_samples['completed'], dtype=bool)
    fed_vertices = fed_vertex_samples['vertices']
    num_fed = len(fed_vertices)
    if len(completed) == num_fed:
        rows = np.nonzero(completed)[0].tolist()
        return [(row, row) for row in rows]
    context_vertices = face_samples['context']['vertices']
    pairs = []
    batch_index = 0
    for sample_index in range(len(completed)):
        while (batch_index < num_fed and
               not np.array_equal(fed_vertices[batch_index], context_vertices[sample_index])):
            batch_index += 1
        if batch_index == num_fed:
            break  # cannot align the remaining samples; skip them rather than mis-pair
        pairs.append((batch_index, sample_index))
        batch_index += 1
    return pairs


# Sample until at least `num_samples_min` completed meshes are collected. mesh_list
# alternates [ground-truth mesh, generated mesh] for each completed sample.
mesh_list = []
num_samples_complete = 0
num_samples_min = 10

with tf.Session() as sess:
    saver_vertex.restore(sess, "text_vertex_shapeglot/best-237")
    face_model_saver.restore(sess, "face_model/model")
    sess.run(iterator_init_op)
    while num_samples_complete < num_samples_min:
        batch, v_samples_np = sess.run((vertex_model_batch, vertex_samples))
        # Vertex sampling keeps incomplete samples (only_return_complete=False), so
        # 'completed' presumably always holds one flag per sample and a `.size == 0`
        # test could never fire; test the flags themselves instead.
        if not np.any(v_samples_np['completed']):
            print('No vertex samples completed in this batch. Try increasing max_num_vertices.')
            continue
        f_samples_np = sess.run(
            face_samples_val,
            {vertex_samples[k]: v_samples_np[k] for k in vertex_samples.keys()})
        # Pair before v_samples_np is replaced: the helper needs the fed vertex samples.
        pairs = _pair_completed_samples(v_samples_np, f_samples_np)
        v_samples_np = f_samples_np['context']  # vertex context aligned with the face samples
        num_samples_complete += len(pairs)
        print('Num. samples complete: {}'.format(num_samples_complete))
        for batch_idx, sample_idx in pairs:
            print(token_ints_to_sentence(batch['text_feature'][batch_idx], int_to_word))
            # Ground-truth mesh for this caption.
            mesh_list.append({
                'vertices': data_utils.dequantize_verts(batch['vertices'][batch_idx]),
                'faces': data_utils.unflatten_faces(batch['faces'][batch_idx])})
            # Generated mesh, trimmed to its per-sample vertex / face-index lengths.
            verts = v_samples_np['vertices'][sample_idx][:v_samples_np['num_vertices'][sample_idx]]
            faces = data_utils.unflatten_faces(
                f_samples_np['faces'][sample_idx][:f_samples_np['num_face_indices'][sample_idx]])
            mesh_list.append({'vertices': verts, 'faces': faces})


data_utils.plot_meshes(mesh_list, ax_lims=0.4)

In [None]:
# Plot only the first (ground truth, generated) pair -- mesh_list alternates the two.
data_utils.plot_meshes(mesh_list[0:2], ax_lims=0.4)