In [1]:
"""Important imports"""
from polygen.modules.data_modules import PolygenDataModule, collate_vertex_model_batch, collate_face_model_batch
from polygen.modules.vertex_model import VertexModel
from polygen.modules.face_model import FaceModel

In [4]:
"""Get dataset ready"""
import glob
data_dir = "meshes/"
all_files = glob.glob(data_dir + "/*.obj")
label_dict = {}
for i, mesh_file in enumerate(all_files):
    label_dict[mesh_file] = i

vertex_data_module = PolygenDataModule(data_dir = data_dir, collate_fn = collate_vertex_model_batch, batch_size = 4,
                                        training_split = 1.0, val_split = 0.0, default_shapenet = False, all_files = all_files,
                                        label_dict = label_dict)

face_data_module = PolygenDataModule(data_dir = data_dir, collate_fn = collate_face_model_batch, batch_size = 4,
                                        training_split = 1.0, val_split = 0.0, default_shapenet = False, all_files = all_files,
                                        label_dict = label_dict)

vertex_data_module.setup()
face_data_module.setup()

vertex_dataloader = vertex_data_module.train_dataloader()
face_dataloader = face_data_module.train_dataloader()

In [13]:
"""Load models"""
transformer_config = {
    "hidden_size": 128,
    "fc_size": 256,
    "num_heads": 4,
    "num_layers": 3,
    "dropout_rate": 0.
}

vertex_model = VertexModel(decoder_config=transformer_config, class_conditional=True, num_classes=4,
                            max_num_input_verts=250, quantization_bits=8)

face_model = FaceModel(encoder_config=transformer_config, decoder_config=transformer_config, class_conditional=False,
                        max_seq_length=500, quantization_bits=8, decoder_cross_attention = True, use_discrete_vertex_embeddings=True)

In [14]:
"""Train Vertex Model"""
import torch
epochs = 500
vertex_model_optimizer = vertex_model.configure_optimizers()["optimizer"]
for i in range(epochs):
    for j, batch in enumerate(vertex_dataloader):
        vertex_model_optimizer.zero_grad()
        logits = vertex_model(batch)
        pred_dist = torch.distributions.categorical.Categorical(logits=logits)
        loss = -torch.sum(pred_dist.log_prob(batch["vertices_flat"])* batch["vertices_flat_mask"])
        loss.backward()
        vertex_model_optimizer.step()
    if i % 50 == 0:
        print(f"Epoch {i}: Loss = {loss.item()}")

Epoch 0: Loss = 2548.35693359375
Epoch 50: Loss = 2513.84326171875
Epoch 100: Loss = 2559.0986328125


KeyboardInterrupt: 

In [None]:
"""Train face model"""
face_model_optimizer = face_model.configure_optimizers()["optimizer"]
for i in range(epochs):
    for j, batch in enumerate(face_dataloader):
        face_model_optimizer.zero_grad()
        logits = face_model(batch)
        pred_dist = torch.distributions.categorical.Categorical(logits=logits)
        loss = -torch.sum(pred_dist.log_prob(batch["faces"])* batch["faces_mask"])
        loss.backward()
        face_model_optimizer.step()
    if i % 50 == 0:
        print(f"Epoch {i}: Loss = {loss.item()}")