In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
"""
Test demo for MLP mapper.
"""

import argparse
import numpy as np

import torch
import torch.backends.cudnn as cudnn

import models.s2vs as autoencoders
import util.misc as misc

/ibex/user/slimhy/PADS/code


In [2]:
def get_args_parser():
    """Build the argument parser used by this demo.

    The parser is created with ``add_help=False``, so it defines no
    ``-h``/``--help`` option of its own. ``--ae_pth`` is the only required
    option.

    The AE latent dimension can be given either as ``--ae-latent-dim``
    (original spelling) or ``--ae_latent_dim`` (matching the underscore style
    of every other flag); both populate ``ae_latent_dim``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser("Latent Diffusion", add_help=False)

    # Model parameters
    parser.add_argument(
        "--model",
        default="kl_d512_m512_l8_edm",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    parser.add_argument(
        # The hyphenated spelling comes first so the inferred dest stays
        # ``ae_latent_dim``; the underscore alias matches the other flags.
        "--ae-latent-dim",
        "--ae_latent_dim",
        type=int,
        default=512 * 8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        required=True,
        help="Autoencoder checkpoint",
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size",
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint",
    )
    parser.add_argument(
        "--resume_weights",
        action="store_true",
        default=False,
        help="Only resume weights, not optimizer state",
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    # Dataset parameters
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser


# Emulated command line, so the parser can be exercised from inside the notebook
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --num_workers 8 \
    --device cuda \
    --fetch_keys \
    --use_embeds \
    --seed 0"""

# Build the parser and parse the emulated command line in a single expression
args = get_args_parser().parse_args(call_string.split())

# --------------------
device = torch.device(args.device)

# Seed numpy and torch with the same value (base seed offset by the rank)
seed = args.seed + misc.get_rank()
np.random.seed(seed)
torch.manual_seed(seed)

args.fetch_keys = True
cudnn.benchmark = True
# --------------------

# Look up the autoencoder constructor by name, switch it to eval mode,
# then restore the "model" entry of the checkpoint and move it to the device
ae = vars(autoencoders)[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae.to(device)

Loading autoencoder ckpt/ae_m512.pth


KLAutoEncoder(
  (cross_attend_blocks): ModuleList(
    (0): PreNorm(
      (fn): Attention(
        (to_q): Linear(in_features=512, out_features=512, bias=False)
        (to_kv): Linear(in_features=512, out_features=1024, bias=False)
        (to_out): Linear(in_features=512, out_features=512, bias=True)
        (drop_path): Identity()
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
      (norm_context): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
    (1): PreNorm(
      (fn): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=512, out_features=4096, bias=True)
          (1): GEGLU()
          (2): Linear(in_features=2048, out_features=512, bias=True)
        )
        (drop_path): Identity()
      )
      (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    )
  )
  (point_embed): PointEmbed(
    (mlp): Linear(in_features=51, out_features=512, bias=True)
  )
  (layers): ModuleList(
    (0-23): 24 x ModuleLi

In [3]:
import models.diffusion as diffusion
from util import misc


def load_ckpt(
    model_type,
    ckpt_path="/ibex/user/slimhy/PADS/output/diffusion_models/kl_d512_m512_l8.pth",
):
    """Instantiate a diffusion model and restore its weights from a checkpoint.

    Side effect: overwrites ``args.model`` and ``args.resume`` on the
    notebook-level ``args`` namespace, which is then handed to
    ``misc.load_model`` (presumably it reads the checkpoint location from
    ``args.resume`` -- confirm against ``util.misc``).

    Args:
        model_type: key of the model constructor in ``models.diffusion``.
        ckpt_path: checkpoint path stored in ``args.resume``. Defaults to the
            path that was previously hard-coded, so existing calls behave
            the same.

    Returns:
        The result of ``model.to(device)`` after ``eval()`` and
        ``misc.load_model`` have been applied.

    Raises:
        KeyError: if ``model_type`` is not a name in ``models.diffusion``.
    """
    args.model = model_type
    args.resume = ckpt_path

    model = diffusion.__dict__[model_type]()
    model.eval()

    misc.load_model(args, model)
    return model.to(device)


# Notebook-wide registry of already-loaded models, keyed by model type
CACHE = {}
def load_ckpt_cached(model_type, cache=CACHE):
    """Return the model for ``model_type``, loading it only on first request.

    Args:
        model_type: key identifying the model; forwarded to ``load_ckpt``
            when the model is not cached yet.
        cache: mapping used as the store; defaults to the shared ``CACHE``.

    Returns:
        The cached model object for ``model_type``.
    """
    # Populate the entry on a miss, then serve every request from the cache
    if model_type not in cache:
        cache[model_type] = load_ckpt(model_type)
    return cache[model_type]


@torch.inference_mode()
def sample_batch(model_type, n_samples=10, num_steps=20):
    """Draw samples from the (cached) model identified by ``model_type``.

    Args:
        model_type: model key passed to ``load_ckpt_cached``.
        n_samples: length of the all-zero conditioning vector.
        num_steps: forwarded to ``model.sample`` as ``num_steps``.

    Returns:
        Whatever ``model.sample`` returns for the zero conditioning vector.
    """
    # One zero entry per requested sample, as int64 on the target device
    cond = torch.tensor([0] * n_samples, dtype=torch.long, device=device)
    sampler = load_ckpt_cached(model_type=model_type)
    return sampler.sample(cond=cond, num_steps=num_steps)


In [4]:
# Draw a batch of 10 samples using 20 sampler steps
sampled_x = sample_batch(model_type="kl_d512_m512_l8_d24", n_samples=10, num_steps=20)

Resume checkpoint /ibex/user/slimhy/PADS/output/diffusion_models/kl_d512_m512_l8.pth


In [5]:
import util.s2vs as s2vs

# Take the first sample, add a leading axis of size 1 and cast with .float()
x = sampled_x[0].unsqueeze(0).float()
# Decode the latent with the autoencoder; batch_size equals grid_density**3
# (presumably so the whole grid is queried in one batch -- TODO confirm)
cuda_mesh = s2vs.decode_latents(ae, x, grid_density=128, batch_size=128**3)

In [6]:
# Display the decoded mesh via the show() method of its trimesh_mesh attribute
cuda_mesh.trimesh_mesh.show()