In [None]:
import os
import shared.data_utils as data_utils
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import modules
#import modules2 as modules
import time
from contextlib import nullcontext

In [None]:
# Use the GPU when one is present, otherwise fall back to the CPU so that every
# later `.to(device)` call works on CPU-only machines as well.
device_type = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_type)

# Autocast precision: 'float32', 'bfloat16', or 'float16'.
# NOTE(review): none of the visible cells create a GradScaler, so the 'float16'
# fallback trains without loss scaling -- add one if float16 proves unstable.
dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]

# On CPU the context is a no-op, so `dtype`/`ptdtype` are only consulted on CUDA.
ctx = nullcontext() if device_type == 'cpu' else torch.autocast(device_type=device_type, dtype=ptdtype)

In [None]:
# Batching options passed to the DataLoaders built in the cells below.
shuffle = False
batch_size = 4

# Seed both the default (CPU) generator and the CUDA generator.
random_seed = 123
torch.manual_seed(random_seed)
torch.cuda.manual_seed(random_seed)

## Prepare synthetic dataset

In [None]:
class MeshDataset(Dataset):
    """Thin `Dataset` wrapper around an in-memory list of mesh examples."""

    def __init__(self, mesh_list):
        # The list is stored by reference, not copied.
        self.mesh_list = mesh_list

    def __len__(self):
        """Number of examples held by the dataset."""
        return len(self.mesh_list)

    def __getitem__(self, idx):
        """Return the example stored at position `idx`."""
        return self.mesh_list[idx]

    def map(self, func):
        """Replace every example with `func(example)` in place and return `self`."""
        for position in range(len(self)):
            self.mesh_list[position] = func(self[position])
        return self

# Prepare synthetic dataset: one example per primitive shape, labelled by its
# position in the list below.
ex_list = []
for k, mesh in enumerate(['cube', 'cylinder', 'cone', 'icosphere']):
    mesh_dict = data_utils.load_process_mesh(
        os.path.join('meshes', '{}.obj'.format(mesh)))
    mesh_dict['class_label'] = torch.tensor(k)
    mesh_dict['vertices'] = torch.tensor(mesh_dict['vertices'])
    mesh_dict['faces'] = torch.tensor(mesh_dict['faces'])
    ex_list.append(mesh_dict)

synthetic_dataset = MeshDataset(ex_list)

# Plot the meshes
mesh_list = []

for shape in synthetic_dataset:
    # 'vertices' and 'faces' are already tensors at this point; copy them with
    # detach().clone() instead of re-wrapping in torch.tensor(), which would
    # raise PyTorch's copy-construct UserWarning. The copy keeps the dataset
    # entries isolated from whatever the data_utils helpers do to their input.
    mesh_list.append(
        {'vertices': data_utils.dequantize_verts(shape['vertices'].detach().clone()),
         'faces': data_utils.unflatten_faces(shape['faces'].detach().clone())})
data_utils.plot_meshes(mesh_list, ax_lims=0.4)

## Create vertex model

In [None]:
# Prepare the dataset for vertex model training
def pad_batch(batch, target_device=None):
    """Collate a list of per-example dicts into a single dict of batched tensors.

    Values are grouped by key (taken from the first example). Zero-dimensional
    tensors are stacked into a 1-D tensor; every other field is zero-padded to
    the longest sequence in the batch with `pad_sequence(batch_first=True)`.

    Args:
        batch: non-empty list of dicts that map the same keys to tensors.
        target_device: device the batched tensors are moved to. Defaults to the
            notebook-level `device`, so `collate_fn=pad_batch` keeps working.

    Returns:
        dict with the keys of `batch[0]`, each mapped to one batched tensor.

    Raises:
        KeyError: if an example is missing a key that the first example has.
    """
    if target_device is None:
        target_device = device

    packed_dict = {}
    for key, first_value in batch[0].items():
        # Look values up by key rather than by position in dict.values(), so
        # examples whose keys are ordered differently are still grouped correctly.
        values = [item[key] for item in batch]
        if first_value.dim() == 0:
            # Scalars (e.g. class labels): stack keeps the original dtype.
            padded_values = torch.stack(values).to(target_device)
        else:
            padded_values = torch.nn.utils.rnn.pad_sequence(
                values, batch_first=True, padding_value=0.).to(target_device)
        packed_dict[key] = padded_values
    return packed_dict

# Vertex-model view of the synthetic dataset, without random-shift augmentation.
vertex_model_dataset = data_utils.make_vertex_model_dataset(
    synthetic_dataset,
    apply_random_shift=False,
)

# Draw a single padded batch to exercise the model below.
vertex_model_dataloader = iter(
    DataLoader(
        vertex_model_dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        collate_fn=pad_batch,
    )
)
vertex_model_batch = next(vertex_model_dataloader)

max_num_input_verts = 250

decoder_config = {
    'embd_size': 128,
    'fc_size': 512,
    'num_layers': 3,
    'dropout_rate': 0.,
    'take_context_embedding': False,
}

# Create vertex model
vertex_model = modules.VertexModel(
    decoder_config=decoder_config,
    class_conditional=True,
    context_type='label',
    num_classes=4,
    max_num_input_verts=max_num_input_verts,
    quantization_bits=8,
    device=device,
)
vertex_model = vertex_model.to(device=device)

# Forward pass under the autocast context chosen in the config cell.
with ctx:
    vertex_model_pred_dist = vertex_model(vertex_model_batch)

# Negative log-probability of the target sequence, weighted by 'vertices_flat_mask' and summed.
vertex_log_probs = vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat'])
vertex_model_loss = -(vertex_log_probs * vertex_model_batch['vertices_flat_mask']).sum()

# Draw four samples from the (still untrained) model, conditioned on the batch.
with ctx:
    vertex_samples = vertex_model.sample(
        4,
        context=vertex_model_batch,
        max_sample_length=max_num_input_verts,
        top_p=0.95,
        recenter_verts=False,
        only_return_complete=False,
    )

for inspected in (vertex_model_batch, vertex_model_pred_dist, vertex_samples):
    print(inspected)

## Create face model

In [None]:
# Face-model view of the synthetic dataset, without random-shift augmentation.
face_model_dataset = data_utils.make_face_model_dataset(
    synthetic_dataset,
    apply_random_shift=False,
)

# Draw a single padded batch to exercise the model below.
face_model_dataloader = iter(
    DataLoader(
        face_model_dataset,
        shuffle=shuffle,
        batch_size=batch_size,
        collate_fn=pad_batch,
    )
)
face_model_batch = next(face_model_dataloader)

encoder_config = {
    'embd_size': 128,
    'fc_size': 512,
    'num_layers': 3,
    'num_heads': 4,
    'dropout_rate': 0.,
}
# Note: this rebinds `decoder_config`, replacing the vertex model's settings.
decoder_config = {
    'embd_size': 128,
    'fc_size': 512,
    'num_layers': 3,
    'dropout_rate': 0.,
    'num_heads': 4,
    'take_context_embedding': True,
}

# Create face model
face_model = modules.FaceModel(
    encoder_config=encoder_config,
    decoder_config=decoder_config,
    class_conditional=False,
    max_seq_length=1000,
    quantization_bits=8,
    max_num_input_verts=max_num_input_verts,
    decoder_cross_attention=True,
    use_discrete_vertex_embeddings=True,
    device=device,
)
face_model = face_model.to(device=device)

# Forward pass under the autocast context chosen in the config cell.
with ctx:
    face_model_pred_dist = face_model(face_model_batch)

# Negative log-probability of the target faces, weighted by 'faces_mask' and summed.
face_log_probs = face_model_pred_dist.log_prob(face_model_batch['faces'])
face_model_loss = -(face_log_probs * face_model_batch['faces_mask']).sum()

# Sample faces conditioned on the vertex samples produced by the vertex model.
with ctx:
    face_samples = face_model.sample(
        context=vertex_samples,
        max_sample_length=500,
        top_p=0.95,
        only_return_complete=False,
    )

for inspected in (face_model_batch, face_model_pred_dist, face_samples):
    print(inspected)

## Train models

In [None]:
# Optimization settings
learning_rate = 8e-4 #3e-4
training_steps = 300
log_step = 5      # print losses every `log_step` steps
sample_step = 100  # sample and plot meshes every `sample_step` steps
n_samples = 4

# Create one optimizer per model; each minimizes the summed negative log
# probability of its mesh sequences.
face_model_optim = torch.optim.AdamW(face_model.parameters(), lr=learning_rate)
vertex_model_optim = torch.optim.AdamW(vertex_model.parameters(), lr=learning_rate)

vertex_model_dataloader = DataLoader(vertex_model_dataset,
                                     shuffle=shuffle,
                                     batch_size=batch_size,
                                     collate_fn=pad_batch)
face_model_dataloader = DataLoader(face_model_dataset,
                                   shuffle=shuffle,
                                   batch_size=batch_size,
                                   collate_fn=pad_batch)
vertex_model_dataloader_iter = iter(vertex_model_dataloader)
face_model_dataloader_iter = iter(face_model_dataloader)

# Training loop
for n in range(training_steps):
    # Fetch the next batch for each model, restarting the iterator once the
    # dataloader is exhausted.
    try:
        vertex_model_batch = next(vertex_model_dataloader_iter)
    except StopIteration:
        vertex_model_dataloader_iter = iter(vertex_model_dataloader)
        vertex_model_batch = next(vertex_model_dataloader_iter)

    try:
        face_model_batch = next(face_model_dataloader_iter)
    except StopIteration:
        face_model_dataloader_iter = iter(face_model_dataloader)
        face_model_batch = next(face_model_dataloader_iter)

    t = time.time()

    # Forward passes run under the autocast context; losses and backward run outside it.
    with ctx:
        vertex_model_pred_dist = vertex_model(vertex_model_batch)
        face_model_pred_dist = face_model(face_model_batch)

    vertex_model_loss = -torch.sum(
        vertex_model_pred_dist.log_prob(vertex_model_batch['vertices_flat']) *
        vertex_model_batch['vertices_flat_mask'])

    face_model_loss = -torch.sum(
        face_model_pred_dist.log_prob(face_model_batch['faces']) *
        face_model_batch['faces_mask'])

    # Optimization step for each model, with the gradient norm clipped to 1.0.
    vertex_model_optim.zero_grad()
    vertex_model_loss.backward()
    torch.nn.utils.clip_grad_norm_(vertex_model.parameters(), 1.0)
    vertex_model_optim.step()

    face_model_optim.zero_grad()
    face_model_loss.backward()
    torch.nn.utils.clip_grad_norm_(face_model.parameters(), 1.0)
    face_model_optim.step()

    # Wall-clock time of the whole step (forward + backward + optimizer update).
    # No device synchronization happens before this point, so on CUDA the
    # figure is only approximate.
    dt = time.time() - t
    if n % log_step == 0:
        print('Step {}'.format(n))
        # .item() prints a plain number instead of the tensor repr.
        print('Loss (vertices) {}'.format(vertex_model_loss.item()))
        print('Loss (faces) {}'.format(face_model_loss.item()))
        print('Time (ms): {}'.format(dt * 1000))

        if n % sample_step == 0:
            # Sampling happens after the optimizer step, so it reflects the
            # parameters as updated in this iteration.
            with ctx:
                vertex_samples = vertex_model.sample(
                    n_samples, context=vertex_model_batch, max_sample_length=200, top_p=0.95,
                    recenter_verts=False, only_return_complete=False)

                face_samples = face_model.sample(
                    context=vertex_samples, max_sample_length=500, top_p=0.95,
                    only_return_complete=False)

            mesh_list = []
            # Dedicated index name: reusing `n` here would shadow the step counter.
            for sample_idx in range(min(n_samples, batch_size)):
                num_verts = vertex_samples['num_vertices'][sample_idx]
                num_face_indices = face_samples['num_face_indices'][sample_idx]
                mesh_list.append(
                    {
                        'vertices': vertex_samples['vertices'][sample_idx][:num_verts].cpu(),
                        'faces': data_utils.unflatten_faces(
                            face_samples['faces'][sample_idx][:num_face_indices].cpu())
                    }
                )
            try:
                # Plotting is best-effort (it has been seen to fail when
                # training on CPU), so a failure must not abort training.
                data_utils.plot_meshes(mesh_list, ax_lims=0.5)
            except Exception as plot_error:
                # `Exception` rather than a bare except, so Ctrl-C still interrupts.
                print("Error plotting meshes... skipping. ({}: {})".format(
                    type(plot_error).__name__, plot_error))

## Inference

In [None]:
vertex_model_dataloader = DataLoader(vertex_model_dataset, shuffle=shuffle, batch_size=batch_size, collate_fn=pad_batch)
vertex_model_dataloader_iter = iter(vertex_model_dataloader)
vertex_model_batch = next(vertex_model_dataloader_iter)

# Use the notebook-wide `device` instead of a hard-coded 'cuda', so inference
# runs wherever the models were created.
for key, value in vertex_model_batch.items():
    vertex_model_batch[key] = value.to(device)

# Sampling needs no gradients; no_grad avoids building an autograd graph.
with torch.no_grad():
    vertex_samples = vertex_model.to(device).sample(
        n_samples, context=vertex_model_batch, max_sample_length=200, top_p=0.95,
        recenter_verts=False, only_return_complete=False)

    face_samples = face_model.to(device).sample(
        context=vertex_samples, max_sample_length=500, top_p=0.95,
        only_return_complete=False)

# Trim every sample to its reported length and move it to the CPU for plotting.
mesh_list = []
for sample_idx in range(min(n_samples, batch_size)):
    num_verts = vertex_samples['num_vertices'][sample_idx]
    num_face_indices = face_samples['num_face_indices'][sample_idx]
    mesh_list.append(
        {
            'vertices': vertex_samples['vertices'][sample_idx][:num_verts].cpu(),
            'faces': data_utils.unflatten_faces(
                face_samples['faces'][sample_idx][:num_face_indices].cpu())
        }
    )

data_utils.plot_meshes(mesh_list, ax_lims=0.5)