In [1]:
"""
Test demo for MLP mapper.
"""
import argparse
import datetime
import h5py
import json
import os
import time
from pathlib import Path

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from transformers import AutoTokenizer, CLIPTextModel, BertTokenizer, BertModel

import util.misc as misc
from engine_node2node import get_text_embeddings
from util.datasets import build_shape_surface_occupancy_dataset

In [2]:
def get_args_parser():
    """Build the command-line parser used by this demo.

    The parser is created with ``add_help=False``. Only ``--ae_pth`` is
    required; every other option falls back to the default declared below
    (``None`` when no default is given).

    Returns:
        argparse.ArgumentParser: the configured parser (not yet parsed).
    """
    parser = argparse.ArgumentParser("Latent Diffusion", add_help=False)

    # Model parameters
    parser.add_argument(
        "--model",
        default="kl_d512_m512_l8_edm",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # Every other flag uses underscores; accept that spelling too while keeping
    # the original dash-separated flag. Both store into ``args.ae_latent_dim``.
    parser.add_argument(
        "--ae-latent-dim",
        "--ae_latent_dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint",
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size",
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--intensity_loss",
        action="store_true",
        default=False,
        help="Contrastive edit intensity loss using ground-truth labels.",
    )
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint",
    )
    parser.add_argument(
        "--resume_weights",
        action="store_true",
        default=False,
        help="Only resume weights, not optimizer state",
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    # Dataset parameters
    parser.add_argument(
        "--dataset",
        type=str,
        choices=["graphedits"],
        help="dataset name",
    )
    parser.add_argument(
        "--data_path",
        type=str,
        help="dataset path",
    )
    parser.add_argument(
        "--data_type",
        type=str,
        help="dataset type",
    )
    parser.add_argument(
        "--max_edge_level",
        default=None,
        type=int,
        help="maximum edge level to use",
    )
    parser.add_argument(
        "--alt_ae_embeds",
        type=str,
        default=None,
        help="Alternative autoencoder embeddings to use",
    )
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument(
        "--ft_bert",
        action="store_true",
        default=False,
        help="Also fine-tune the BERT model",
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser

In [3]:
# Hard-coded command line standing in for sys.argv, so the parser can be
# exercised from inside the notebook.
call_string = """--ae_pth /ibex/user/slimhy/Shape2VecSet/output/graph_edit/ae/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --text_model_name bert-base-uncased \
    --dataset graphedits \
    --data_path /ibex/user/slimhy/ShapeWalk/ \
    --data_type release \
    --batch_size 32 \
    --num_workers 8 \
    --model mlp_mapper_bert_l1 \
    --resume /ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_l1/checkpoint-100.pth \
    --resume_full_weights \
    --device cuda \
    --fetch_keys \
    --use_embeds \
    --seed 0"""

# Split on whitespace and parse in one step; `args` is the resulting Namespace
args = get_args_parser().parse_args(call_string.split())

In [4]:
import models.autoencoders as autoencoders

# --------------------
# The text-encoder family is inferred from the model name; the device comes
# straight from the parsed arguments.
args.use_clip = "clip" in args.text_model_name
device = torch.device(args.device)

# Seed torch and numpy, offset by the process rank, for reproducibility
seed = misc.get_rank() + args.seed
torch.manual_seed(seed)
np.random.seed(seed)

cudnn.benchmark = True

# Force keys on regardless of the command line, then build both splits
args.fetch_keys = True
dataset_train = build_shape_surface_occupancy_dataset("train", args=args)
dataset_val = build_shape_surface_occupancy_dataset("val", args=args)

# Both loaders share the same settings: fixed order, pinned memory,
# incomplete final batch dropped.
loader_options = dict(
    batch_size=args.batch_size,
    shuffle=False,
    num_workers=args.num_workers,
    pin_memory=True,
    drop_last=True,
)
data_loader_train = torch.utils.data.DataLoader(dataset_train, **loader_options)
data_loader_val = torch.utils.data.DataLoader(dataset_val, **loader_options)
# --------------------

# Build the autoencoder by name, switch it to eval mode, restore its weights
# from the checkpoint's "model" entry and move it to the target device.
ae = vars(autoencoders)[args.ae]()
ae.eval()
print(f"Loading autoencoder {args.ae_pth}")
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae.to(device)




Loading autoencoder /ibex/user/slimhy/Shape2VecSet/output/graph_edit/ae/ae_m512.pth


KLAutoEncoder(
  (cross_attend_blocks): ModuleList(
    (0): PreNorm(
      (fn): Attention(
        (to_q): Linear(in_features=512, out_features=512, bias=False)
        (to_kv): Linear(in_features=512, out_features=1024, bias=False)
        (to_out): Linear(in_features=512, out_features=512, bias=True)
        (drop_path): Identity()
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm_context): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): PreNorm(
      (fn): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=4096, bias=True)
          (1): GEGLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (drop_path): Identity()
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (point_embed): PointEmbed(
    (mlp): Linear(in_features=51, out_features=512, bias=True)
  )
  (layers): ModuleList(
    (0-23): 24 x ModuleLi

In [5]:
import models.mlp_mapper as mlp_mapper

# Instantiate the mapper selected by --model; the linear projection is only
# enabled when the text model is not a CLIP variant.
model = mlp_mapper.__dict__[args.model](use_linear_proj=not args.use_clip)
model.to(device)
# This notebook only runs inference, so put the mapper in eval mode (as is
# already done for the autoencoder). Otherwise any dropout / batch-norm layers
# would keep their training behaviour.
model.eval()

# Load the checkpoint
if args.resume:
    print("Loading checkpoint [%s]..." % args.resume)
    checkpoint = torch.load(args.resume, map_location="cpu")
    model.load_state_dict(checkpoint["model"])
    print("Done.")

Loading checkpoint [/ibex/user/slimhy/Shape2VecSet/output/graph_edit/dm/mlp_mapper_l1/checkpoint-100.pth]...
Done.


In [6]:
def apply_edit(net, x_a, x_b, embed_ab):
    """Apply the edit predicted by ``net`` to the source latent ``x_a``.

    Args:
        net: callable invoked as ``net(x_a_flat, embed_flat)``; its output is
            added to the flattened source latent, so it must be broadcastable
            to that shape.
        x_a: source latent; every dimension after the first is flattened.
        x_b: target latent. Not used in the computation; kept only so existing
            call sites keep working.
        embed_ab: edit embedding; flattened the same way as ``x_a``.

    Returns:
        The flattened ``x_a`` plus the predicted edit vector.
    """
    # Collapse everything after the batch dimension: (B, ...) -> (B, M)
    x_a = x_a.flatten(1)
    embed_ab = embed_ab.flatten(1)

    # The network receives latent and embedding as two separate inputs and
    # predicts an offset in latent space.
    edit_vec = net(x_a, embed_ab)

    # Residual update: edited latent = source latent + predicted offset
    return x_a + edit_vec

In [7]:
import mcubes
import trimesh

@torch.inference_mode()
def decode_latent(kk, latent, fcode="", density=32):
    """Decode one latent into a mesh and export it as an OBJ file.

    The latent is decoded by the global autoencoder ``ae`` on a regular grid
    spanning [-1, 1]^3. The surface is extracted with marching cubes at level
    0 and written to ``decoded_shapes/node_<kk>_<fcode>.obj``.

    Args:
        kk: integer index used in the output file name.
        latent: array-like holding 512 * 8 values; reshaped to (1, 512, 8).
        fcode: suffix used in the output file name.
        density: number of grid cells per axis (``density + 1`` samples).
    """
    # Put the latent on the same device as the grid and the autoencoder
    # instead of assuming CUDA.
    latent = torch.as_tensor(latent, dtype=torch.float32, device=device).reshape(1, 512, 8)
    gap = 2. / density

    # The same sample positions are used along all three axes
    axis = np.linspace(-1, 1, density + 1)
    xv, yv, zv = np.meshgrid(axis, axis, axis)
    grid = torch.from_numpy(np.stack([xv, yv, zv]).astype(np.float32)).view(3, -1).transpose(0, 1)[None].to(device, non_blocking=True)
    print(grid.shape)

    logits = ae.decode(latent, grid)
    logits = logits.detach()

    # meshgrid's default 'xy' indexing swaps the first two axes; permute them back
    volume = logits.view(density + 1, density + 1, density + 1).permute(1, 0, 2).cpu().numpy()
    verts, faces = mcubes.marching_cubes(volume, 0)

    # Rescale vertices from grid-index units to the [-1, 1] cube
    verts *= gap
    verts -= 1

    # Make sure the output directory exists before exporting
    os.makedirs("decoded_shapes", exist_ok=True)
    m = trimesh.Trimesh(verts, faces)
    m.export('decoded_shapes/node_%d_%s.obj' % (kk, fcode))


In [8]:
# Only the first validation batch is needed for the demo. next(iter(...))
# raises StopIteration right here if the loader is empty, instead of leaving
# `data` undefined for the unpacking below.
data = next(iter(data_loader_val))
# NOTE(review): the 4-tuple layout presumably relies on fetch_keys / use_embeds
# being enabled in the dataset — confirm against the dataset class.
edit_key, embeds_a, embeds_b, embeds_text = data



In [9]:
# Move the batch to the same device as the model (not necessarily CUDA)
embeds_a = embeds_a.to(device)
embeds_b = embeds_b.to(device)
embeds_text = embeds_text.to(device)

# Inference only: skip building an autograd graph for the forward pass
with torch.no_grad():
    x_b_edited = apply_edit(model, embeds_a, embeds_b, embeds_text)

# Short aliases used by the decoding cells below
x_b = embeds_b
x_a = embeds_a

In [11]:
# Export meshes for the first five samples of the batch: the mapper's
# prediction first, then the two latents it is compared against.
for kk in range(5):
    for fcode, batch in (("b_pred", x_b_edited), ("b_gt", x_b), ("a_gt", x_a)):
        decode_latent(kk, batch[kk].cpu().detach().numpy(), fcode=fcode, density=128)

torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])


In [12]:
# Preview the three meshes exported for sample 0 (a_gt, b_gt, b_pred).
# NOTE(review): a notebook only renders the value of a cell's last expression,
# so presumably only the b_pred view appears inline — confirm, or split these
# into separate cells.
trimesh.load("decoded_shapes/node_0_a_gt.obj").show()
trimesh.load("decoded_shapes/node_0_b_gt.obj").show()
trimesh.load("decoded_shapes/node_0_b_pred.obj").show()

In [13]:
# Validation metadata, indexed by edit key. The path and the parsed content
# get separate names, and the file is closed as soon as it has been read.
dset_json_path = "/ibex/user/slimhy/ShapeWalk/release/release_val.json"
with open(dset_json_path) as f:
    dset_json = json.load(f)

def show_side_by_side(kk):
    """Show the a_gt, b_gt and b_pred meshes of sample ``kk`` in one scene.

    Prints the sample's edit key, prompt and edge intensity (looked up in the
    global ``dset_json``), then loads the three OBJ files previously exported
    to ``decoded_shapes/`` for index ``kk``.

    Args:
        kk: index into the global ``edit_key`` batch and into the exported
            ``node_<kk>_*.obj`` files.

    Returns:
        Whatever ``trimesh.Scene.show()`` returns.
    """
    key = edit_key[kk]
    entry = dset_json[key]
    print(key)
    print(entry['prompt'], entry['edge_intensity'])

    # First load all the meshes
    mesh_a = trimesh.load("decoded_shapes/node_%d_a_gt.obj" % kk)
    mesh_b = trimesh.load("decoded_shapes/node_%d_b_gt.obj" % kk)
    mesh_b_pred = trimesh.load("decoded_shapes/node_%d_b_pred.obj" % kk)

    # Layout along x: a_gt on the left, b_gt stays at the origin,
    # b_pred on the right.
    mesh_a.apply_translation([-2, 0, 0])
    mesh_b_pred.apply_translation([1.5, 0, 0])

    # Combine them into a single scene and render it exactly once
    scene = trimesh.Scene([mesh_a, mesh_b, mesh_b_pred])
    return scene.show()

In [14]:
# Sample 2 of the first validation batch: a_gt (left), b_gt (centre), b_pred (right)
show_side_by_side(2)

1654770748290731882_1601164496463705984
Remove armrests. 0


In [15]:
# Sample 3 of the first validation batch: a_gt (left), b_gt (centre), b_pred (right)
show_side_by_side(3)

1100475754727085274_859722239200822368
Decrease legs thickness significantly. 5


In [16]:
# Sample 0 of the first validation batch: a_gt (left), b_gt (centre), b_pred (right)
show_side_by_side(0)

990219307528931364_505636938342785934
Increase legs indentation vastly. 9


In [18]:
# Load the object stored in rec.pt from the working directory.
# NOTE(review): this notebook does not show where rec.pt comes from —
# presumably reconstructed point clouds; confirm. torch.load unpickles its
# input, so only load files from a trusted source.
rec_all = torch.load("rec.pt")

In [19]:
# Sanity check: inspect the dimensions of the loaded tensor
rec_all.shape

torch.Size([32, 4096, 3])

In [21]:
@torch.inference_mode()
def encode_latent(pointcloud):
    """Encode ``pointcloud`` with the global autoencoder ``ae``.

    Runs under ``torch.inference_mode`` and returns the result of
    ``ae.encode`` unchanged.
    """
    encoded = ae.encode(pointcloud)
    return encoded

In [30]:
# Encode rec_all after regrouping its points into clouds of 2048 points each.
# NOTE(review): rec_all was reported as (32, 4096, 3), so this reshape yields
# 64 clouds, each holding half of an original 4096-point set — confirm this is
# intended rather than subsampling each cloud.
latents = encode_latent(rec_all.reshape(-1, 2048, 3))

In [35]:
# Keep only element 1 of what the encoder returned.
# NOTE(review): presumably ae.encode returns a tuple whose second entry is the
# latent set — confirm. This rebinds `latents`, so running the cell a second
# time indexes the result again (not idempotent).
latents = latents[1]

In [36]:
# Decode the first five re-encoded latents into meshes.
# NOTE(review): fcode="b_pred" reuses the file names written for the mapper's
# predictions earlier in the notebook, so those .obj files are overwritten —
# confirm this is intended.
for kk in range(5):
    latent_np = latents[kk].cpu().detach().numpy()
    decode_latent(kk, latent_np, fcode="b_pred", density=128)

torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
torch.Size([1, 2146689, 3])
