In [None]:
import os

# Run from the repository root so the relative `config/` / `results/` paths and the
# `src.*` imports in later cells resolve. Guarded so re-running this cell is a no-op:
# a bare `%cd ..` would climb one more directory on every re-run.
if not all(os.path.isdir(d) for d in ("config", "src")):
    os.chdir("..")
os.getcwd()

In [2]:
import os
from pathlib import Path

import torch as th
import torch.nn.functional as F
import numpy as np
import yaml
from easydict import EasyDict

from src.utils import instantiate_from_config
from src.utils.vis import save_sdf_as_mesh

In [3]:
# Inference-only notebook: disable autograd globally.
th.set_grad_enabled(False)

# Fix the RNG so the sampled class label (and, presumably, the DDIM sampler's noise)
# is reproducible under Restart & Run All. th.manual_seed seeds the CPU and all CUDA
# generators.
SEED = 0
th.manual_seed(SEED)

# CUDA device index. Override it with the GPU_ID environment variable on machines with
# fewer GPUs. The default of 4 keeps the original setting.
GPU_ID = int(os.environ.get("GPU_ID", "4"))
th.cuda.set_device(GPU_ID)

# Load Pretrained Models

In [4]:
# One switch for the dataset. Both stages follow the layout
# `config/<stage>/<dataset>.yaml` and `results/<stage>/<dataset>.pth`.
DATASET = "shapenet"
gen32_args_path = f"config/gen32/{DATASET}.yaml"
gen32_ckpt_path = f"results/gen32/{DATASET}.pth"
sr64_args_path = f"config/sr32_64/{DATASET}.yaml"
sr64_ckpt_path = f"results/sr32_64/{DATASET}.pth"

In [5]:
def _read_args(config_path):
    """Parse a YAML config file into an attribute-accessible EasyDict."""
    with open(config_path) as stream:
        return EasyDict(yaml.safe_load(stream))


args1 = _read_args(gen32_args_path)  # 32^3 generation stage
args2 = _read_args(sr64_args_path)   # 32^3 -> 64^3 super-resolution stage

In [6]:
# `.eval()` switches the module out of training mode (the nn.Module default) before
# sampling, so any dropout / normalization layers behave deterministically at inference.
# It is chained here rather than called last so the cell still displays the
# `load_state_dict` result.
model1 = instantiate_from_config(args1.model).cuda().eval()
# NOTE(review): th.load unpickles the file; only load trusted checkpoints. If the
# checkpoint holds only tensors, consider `weights_only=True` (torch >= 1.13).
ckpt = th.load(gen32_ckpt_path, map_location="cpu")
# The generation stage uses the EMA copy of the weights.
model1.load_state_dict(ckpt["model_ema"])

<All keys matched successfully>

In [7]:
# Same as model1: switch out of training mode before sampling. Chained so the
# cell still displays the `load_state_dict` result.
model2 = instantiate_from_config(args2.model).cuda().eval()
# NOTE(review): th.load unpickles the file; only load trusted checkpoints.
ckpt = th.load(sr64_ckpt_path, map_location="cpu")
# NOTE(review): this loads the raw "model" weights, while the gen32 cell loads
# "model_ema". Confirm this is intentional (e.g. the SR checkpoint has no EMA copy).
model2.load_state_dict(ckpt["model"])

<All keys matched successfully>

In [8]:
# One sampler per stage, each built from its config's `ddpm.valid` section
# (built in order: gen32 first, then sr64).
ddpm_sampler1, ddpm_sampler2 = (
    instantiate_from_config(stage_args.ddpm.valid).cuda()
    for stage_args in (args1, args2)
)

In [9]:
# Both stages' (de)standardization helpers are created on the same device.
PREPROCESSOR_DEVICE = "cuda"
preprocessor1 = instantiate_from_config(args1.preprocessor, PREPROCESSOR_DEVICE)
preprocessor2 = instantiate_from_config(args2.preprocessor, PREPROCESSOR_DEVICE)

# Generate a Low-Resolution Shape ($32^3$)

In [10]:
# Class-conditioning label `c`, shared by both stages.
# NUM_CLASSES is presumably the number of ShapeNet categories the models were
# trained on. TODO: confirm against the model configs.
NUM_CLASSES = 13
CLASS_ID = None  # set to an int in [0, NUM_CLASSES) to pick a category; None = random

if CLASS_ID is None:
    c = th.randint(0, NUM_CLASSES, (1,), dtype=th.int64, device="cuda")
else:
    if not 0 <= CLASS_ID < NUM_CLASSES:
        raise ValueError(f"CLASS_ID must be in [0, {NUM_CLASSES}), got {CLASS_ID}")
    c = th.tensor([CLASS_ID], dtype=th.int64, device="cuda")

In [12]:
def _gen32_model_fn(x, t):
    """Class-conditioned call into model1, handed to the DDIM sampler."""
    return model1(x, t, c=c)


out1 = ddpm_sampler1.sample_ddim(
    _gen32_model_fn,
    (1, 1, 32, 32, 32),  # a single one-channel 32^3 grid
    show_pbar=True,
)

Sample DDIM: 100%|██████████████████████████████████████████████████████████████████████████████| 50/50 [00:04<00:00, 11.41it/s]


In [13]:
# Undo preprocessor1's standardization of the sampled grid.
# NOTE(review): this overwrites `out1` in place. Re-running this cell without
# re-running the sampling cell destandardizes twice. Consider a new name
# (e.g. `sdf32`) and updating the downstream cells that read `out1`.
out1 = preprocessor1.destandardize(out1)
out1.shape

torch.Size([1, 1, 32, 32, 32])

In [14]:
# Export the 32^3 result as a Wavefront OBJ mesh (relative to the current directory).
gen32_mesh_path = "gen32.obj"
save_sdf_as_mesh(gen32_mesh_path, out1, safe=True)

# Super-Resolve to High Resolution ($64^3$)

In [15]:
# Nearest-neighbour upsample of the 32^3 result onto the 64^3 grid, then standardize
# it for model2. The `0` argument presumably selects the low-res input's statistics.
# TODO: confirm against preprocessor2.
lr_cond = preprocessor2.standardize(F.interpolate(out1, (64, 64, 64), mode="nearest"), 0)


def _sr64_model_fn(x, t):
    """Call model2 with the low-res condition concatenated in front of x along dim 1."""
    return model2(th.cat([lr_cond, x], 1), t, c=c)


out2 = ddpm_sampler2.sample_ddim(_sr64_model_fn, (1, 1, 64, 64, 64), show_pbar=True)

Sample DDIM: 100%|██████████████████████████████████████████████████████████████████████████████| 50/50 [00:09<00:00,  5.46it/s]


In [16]:
# Undo preprocessor2's standardization of the super-resolved grid. The `1` argument
# presumably selects the high-res output's statistics (the standardize call above
# used `0` for the low-res condition). TODO: confirm.
# NOTE(review): overwrites `out2` in place, so re-running this cell alone
# destandardizes twice.
out2 = preprocessor2.destandardize(out2, 1)
out2.shape

torch.Size([1, 1, 64, 64, 64])

In [17]:
# Export the 64^3 super-resolved result as a Wavefront OBJ mesh (relative to the current directory).
sr64_mesh_path = "sr64.obj"
save_sdf_as_mesh(sr64_mesh_path, out2, safe=True)