In [1]:
%cd /ibex/user/slimhy/PADS/code
%reload_ext autoreload
%set_env CUBLAS_WORKSPACE_CONFIG=:4096:8
"""
Test demo for MLP mapper.
"""

import argparse
import numpy as np

import torch
import torch.backends.cudnn as cudnn

import models.s2vs as autoencoders
import util.misc as misc

/ibex/user/slimhy/PADS/code
env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [2]:
def get_args_parser():
    """Build the argument parser shared with the latent-diffusion scripts.

    The parser is created with ``add_help=False`` so it can be used as a parent
    parser or parsed from a hand-written argument string inside a notebook.

    Returns:
        argparse.ArgumentParser: parser defining model, autoencoder,
        checkpoint-resume and runtime options.
    """
    parser = argparse.ArgumentParser("Latent Diffusion", add_help=False)

    # Model parameters
    parser.add_argument(
        "--model",
        default="kl_d512_m512_l8_edm",
        type=str,
        metavar="MODEL",
        help="Name of model to train",
    )
    parser.add_argument(
        "--batch_size",
        default=32,
        type=int,
        help="Batch size per GPU (effective batch size is batch_size * accum_iter * # gpus)",
    )
    parser.add_argument(
        "--text_model_name",
        type=str,
        help="Text model name to use",
    )
    parser.add_argument(
        "--ae",
        type=str,
        metavar="MODEL",
        help="Name of autoencoder",
    )
    # The hyphenated spelling is kept for existing callers; the underscore alias
    # matches every other flag. Both map to dest "ae_latent_dim".
    parser.add_argument(
        "--ae-latent-dim",
        "--ae_latent_dim",
        type=int,
        default=512*8,
        help="AE latent dimension",
    )
    parser.add_argument(
        "--ae_pth",
        type=str,
        required=True,
        help="Autoencoder checkpoint"
    )
    parser.add_argument(
        "--point_cloud_size",
        default=2048,
        type=int,
        help="input size"
    )
    parser.add_argument(
        "--fetch_keys",
        action="store_true",
        default=False,
    )
    parser.add_argument(
        "--use_embeds",
        action="store_true",
        default=False,
    )

    # Checkpoint / resume parameters
    parser.add_argument(
        "--resume",
        default="",
        help="Resume from checkpoint"
    )
    parser.add_argument(
        "--resume_weights",
        action="store_true",
        default=False,
        help="Only resume weights, not optimizer state",
    )
    parser.add_argument(
        "--resume_full_weights",
        action="store_true",
        default=False,
        help="Resume the full model weights with the EDM wrapper",
    )

    # Runtime parameters (device, seeding, data loading)
    parser.add_argument(
        "--device", default="cuda", help="device to use for training / testing"
    )
    parser.add_argument("--seed", default=0, type=int)
    parser.add_argument("--num_workers", default=60, type=int)
    parser.add_argument(
        "--pin_mem",
        action="store_true",
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )

    return parser


# Simulated command line, so the notebook reuses the training-script parser
# (same argument names and defaults). The trailing backslashes are in-string
# line continuations, so `call_string.split()` yields clean tokens.
call_string = """--ae_pth ckpt/ae_m512.pth \
    --ae kl_d512_m512_l8 \
    --ae-latent-dim 4096 \
    --num_workers 8 \
    --batch_size 2 \
    --device cuda \
    --fetch_keys \
    --use_embeds \
    --seed 0"""

# Parse the arguments. The parser and the parsed Namespace get distinct names
# so `args` always means the Namespace.
parser = get_args_parser()
args = parser.parse_args(call_string.split())
args.data_path = "/ibex/project/c2273/PADS/3DCoMPaT/"

# --------------------
device = torch.device(args.device)

# Fix the seed for reproducibility (offset by the process rank from misc.get_rank()).
seed = args.seed + misc.get_rank()
torch.manual_seed(seed)
np.random.seed(seed)

# cuDNN autotuning (benchmark=True) selects kernels by timing, which can differ
# between runs. Disable it and request deterministic kernels so seeded runs are
# repeatable; this is consistent with CUBLAS_WORKSPACE_CONFIG being set at the
# top of the notebook.
cudnn.benchmark = False
cudnn.deterministic = True

# Force key fetching regardless of call_string (currently redundant with --fetch_keys).
args.fetch_keys = True
# --------------------

# Instantiate the autoencoder by name and restore its weights on CPU before
# moving it to the target device.
ae = autoencoders.__dict__[args.ae]()
ae.eval()
print("Loading autoencoder %s" % args.ae_pth)
# The checkpoint file is a dict; its "model" entry holds the state dict.
ae.load_state_dict(torch.load(args.ae_pth, map_location="cpu")["model"])
ae = ae.to(device)

Loading autoencoder ckpt/ae_m512.pth


In [3]:
from datasets.metadata import class_to_idx 
import models.diffusion as diffusion
from util import misc


# Default diffusion-model checkpoint used by load_ckpt when no path is given.
DM_CKPT_PATH = "/ibex/user/slimhy/PADS/code/ckpt/dm/kl_d512_m512_l8.pth"


def load_ckpt(model_type, ckpt_path=DM_CKPT_PATH):
    """Instantiate a diffusion model by name and restore its weights.

    Args:
        model_type: Name of a constructor in ``models.diffusion``.
        ckpt_path: Checkpoint to restore. Defaults to ``DM_CKPT_PATH``, the
            path that was previously hard-coded here.

    Returns:
        The model in eval mode, moved to the global ``device``.

    Note:
        Overwrites ``args.model`` and ``args.resume`` on the global ``args``,
        because ``misc.load_model`` receives ``args``. It presumably reads
        ``args.resume`` to locate the file (it prints "Resume checkpoint <path>");
        confirm against util.misc.
    """
    args.model = model_type
    args.resume = ckpt_path

    model = diffusion.__dict__[model_type]()
    model.eval()

    misc.load_model(args, model)
    return model.to(device)


# Module-level memo of loaded diffusion models, keyed by model_type.
CACHE = {}
def load_ckpt_cached(model_type, cache=CACHE):
    """Return the model for ``model_type``, loading it only on first request.

    Args:
        model_type: Name of the diffusion model constructor.
        cache: Mapping used as the memo; shared module-level ``CACHE`` by default.

    Returns:
        The cached model, populated via ``load_ckpt`` on a cache miss.
    """
    if model_type not in cache:
        cache[model_type] = load_ckpt(model_type)
    return cache[model_type]


@torch.inference_mode()
def sample_batch(model_type, class_name, n_samples=10, num_steps=20):
    """Sample ``n_samples`` latents of a single ShapeNet class.

    Args:
        model_type: Diffusion model name, loaded through ``load_ckpt_cached``.
        class_name: Class label, mapped to an index by ``class_to_idx``
            with ``dataset="shapenet"``.
        n_samples: Number of samples (length of the conditioning vector).
        num_steps: Number of sampling steps forwarded to ``model.sample``.

    Returns:
        Whatever ``model.sample`` returns for the class-conditioned batch.
    """
    label = class_to_idx(class_name, dataset="shapenet")
    # One identical integer label per requested sample, created on the target device.
    labels = torch.tensor([label] * n_samples, dtype=torch.long, device=device)
    diffusion_model = load_ckpt_cached(model_type=model_type)
    return diffusion_model.sample(cond=labels, num_steps=num_steps)


In [4]:
import util.s2vs as s2vs
def decode_latents(model, latents, grid_density=128, batch_size=128**3, index=0):
    """Decode a single latent from a batch into a mesh using the autoencoder.

    Only one element of ``latents`` is decoded; pass ``index`` to inspect a
    sample other than the first.

    Args:
        model: Autoencoder, forwarded to ``s2vs.decode_latents`` as ``ae``.
        latents: Batch of latents indexed along the first dimension.
        grid_density: Forwarded to ``s2vs.decode_latents``.
        batch_size: Forwarded to ``s2vs.decode_latents``.
        index: Position in the batch to decode (default 0, the first sample).

    Returns:
        The object returned by ``s2vs.decode_latents`` (callers use ``.show()`` on it).
    """
    # Keep a leading batch dimension of size 1 around the selected latent.
    selected = latents[index].unsqueeze(0)
    with torch.no_grad():
        mesh = s2vs.decode_latents(
            ae=model,
            latent=selected,
            grid_density=grid_density,
            batch_size=batch_size,
        )
    return mesh
# Draw 10 chair latents from the class-conditioned model and show the first one decoded.
sampled_x = sample_batch(
    model_type="kl_d512_m512_l8_d24",
    class_name="chair",
    n_samples=10,
    num_steps=20,
)
decode_latents(model=ae, latents=sampled_x.float()).show()

Resume checkpoint /ibex/user/slimhy/PADS/code/ckpt/dm/kl_d512_m512_l8.pth


In [5]:
from datasets.metadata import class_to_hex
from datasets.latents import ShapeLatentDataset, ComposedPairedShapesLoader, PairType

# Training split of the chair-class shape-latent dataset.
dataset_train = ShapeLatentDataset(
    args.data_path,
    class_code=class_to_hex("chair"),
    split="train",
    shuffle_parts=True,
    filter_n_ids=8,  # NOTE(review): filter semantics live in datasets.latents — confirm
)

# Paired-shape loader over that dataset: fixed order (shuffle=False), incomplete
# final batch dropped, pairs restricted to PairType.NO_ROT_PAIR.
loader = ComposedPairedShapesLoader(
    dataset_train,
    batch_size=args.batch_size,
    pair_types_list=[PairType.NO_ROT_PAIR],
    num_workers=4,  # hard-coded; independent of args.num_workers
    seed=args.seed,  # follow the notebook-wide seed instead of a separate literal
    shuffle=False,
    drop_last=True,
)

# Fetch one batch on `device`. The names presumably stand for latents, bounding
# boxes, bounding-box labels and metadata of shape "a" — confirm in datasets.latents.
l_a, bb_a, bb_l_a, meta_a = loader.get_tuple(device=device)



In [6]:
import models.diffusion as dm
import util.misc as misc

# Load the class-conditioned diffusion model directly (outside the load_ckpt cache).
args.resume = "/ibex/user/slimhy/PADS/code/ckpt/dm/kl_d512_m512_l8.pth"
model = dm.kl_d512_m512_l8_d24().to(device)
# Inference only: switch to eval mode, as load_ckpt does, so train-time layer
# behavior (e.g. dropout) is off during sampling.
model.eval()
misc.load_model(args, model)

Resume checkpoint /ibex/user/slimhy/PADS/code/ckpt/dm/kl_d512_m512_l8.pth


In [7]:
# Sample one latent conditioned on class index 0 and decode it.
# NOTE(review): which class index 0 denotes depends on datasets.metadata — confirm.
# No gradients are needed for sampling, so skip building the autograd graph.
with torch.no_grad():
    sampled_x = model.sample(cond=torch.zeros(1, dtype=torch.long, device=device), num_steps=20)
decode_latents(ae, sampled_x.float()).show()

In [8]:
import models.diffusion as dm
import util.misc as misc

# Load the part-query (PQ) conditioned diffusion model from its training checkpoint.
args.resume = "/ibex/user/slimhy/PADS/output/pq_dm/pq_dm_pq_dm__rec_100/checkpoint-280.pth"
model = dm.kl_d512_m512_l8_d24_pq().to(device)
# Inference only: switch to eval mode so train-time layer behavior is off.
model.eval()
misc.load_model(args, model)

Resume checkpoint /ibex/user/slimhy/PADS/output/pq_dm/pq_dm_pq_dm__rec_100/checkpoint-280.pth


In [9]:
# Encode the part-level conditioning from the batch loaded above, then sample.
# Pure inference: disable autograd so no graph is kept for the encoder or sampler.
with torch.no_grad():
    # NOTE(review): `bb_l_a != -1` presumably masks padded part slots labelled -1 — confirm.
    cond, _ = model.pqe(l_a, bb_a, bb_l_a, bb_l_a != -1)
    sampled_x = model.sample(cond=cond)

In [10]:
# Decode and display the first latent of the most recent sampled batch.
decode_latents(model=ae, latents=sampled_x.float()).show()

In [11]:
# For comparison: decode and display the first dataset latent of the loaded batch.
decode_latents(model=ae, latents=l_a.float()).show()