In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

from modules import VoxelDataset
from modules import Diffusion
from modules import Encoder
from modules import Decoder
from modules import VanillaDiffusionSampler

from torch.utils.data import DataLoader
from torch.backends import cudnn

from tqdm import tqdm

# Let cuDNN autotune convolution algorithms for the fixed input shapes used here.
# NOTE(review): benchmark mode may pick non-deterministic kernels, so even seeded
# runs can differ slightly across executions / GPUs.
cudnn.benchmark = True

# Seed torch's RNGs (CPU and all CUDA devices) so the sampled noise is reproducible.
SEED = 0
torch.manual_seed(SEED)

# Fall back to CPU so the notebook still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Notebook configuration shared by all cells.
pram = dict(
    # Dataset / training-time settings -- presumably carried over from training;
    # not read by the sampling cells in this notebook.
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    batch_size=1,                       # leading dim of the sampled noise tensor
    voxel_map_shape=(128, 128, 128),
    designate_num_objts=1,
    train_test_split_ratio_train=1,
    # Latent / noise-schedule settings used for sampling.
    latent_dim=(64, 64),                # trailing dims of the sampled noise tensor
    beta_start=1e-4,
    beta_end=2e-2,
    time_steps=1000,                    # number of reverse-diffusion steps
    warmup=10,
)

In [None]:
# Build the diffusion network and restore its trained weights.
# .eval(): this notebook only runs inference, so any dropout / batch-norm layers
# must use their eval-time behaviour.
diffusion = Diffusion().to(device).eval()

# map_location='cpu': loads even if the checkpoint was saved on a GPU that is not
# available here, and keeps any extra checkpoint entries out of GPU memory.
# load_state_dict then copies the weights into the parameters already on `device`.
diffusion_ckpt = torch.load('./models/diffusion_saved.pt', map_location='cpu')

diffusion.load_state_dict(diffusion_ckpt['diffusion_state_dict'])

In [None]:
# Build the decoder and restore its weights from the encoder/decoder checkpoint
# (only the 'decoder_state_dict' entry is needed for sampling).
# NOTE(review): vd_ch=8 presumably has to match the "_8" checkpoint variant -- confirm.
# .eval(): inference only, so any dropout / batch-norm layers use eval behaviour.
decoder = Decoder(pram['latent_dim'], vd_ch=8).to(device).eval()

# map_location='cpu': loads even if the checkpoint was saved on an unavailable GPU,
# and keeps the other entries (presumably the encoder weights) off the GPU.
# load_state_dict copies the decoder tensors onto `device`.
autoencoder_ckpt = torch.load('./models/encoder_decoder_8.pt', map_location='cpu')

decoder.load_state_dict(autoencoder_ckpt['decoder_state_dict'])

In [None]:
# Wrap the trained network in the reverse-diffusion sampler, configured with the
# linear-range noise-schedule endpoints and step count from `pram`.
sampler = VanillaDiffusionSampler(
    diffusion,
    pram['beta_start'],   # first beta of the schedule
    pram['beta_end'],     # last beta of the schedule
    pram['time_steps'],   # number of diffusion steps
).to(device)

In [None]:
# Reverse diffusion: start from Gaussian noise of shape (batch_size, 1, *latent_dim)
# and apply the sampler once per timestep, from t = time_steps - 1 down to t = 0.
# Assumes sampler(x_t, t) returns the latent for the previous step -- TODO confirm.
x_T = torch.randn((pram['batch_size'], 1, *pram['latent_dim']), device=device)

# no_grad: without it autograd records every one of the time_steps sampler calls,
# keeping all intermediate activations alive and quickly exhausting GPU memory.
# (no_grad rather than inference_mode so x_0 stays usable by any later module call.)
with torch.no_grad():
    x_t = x_T
    # reversed(range(...)) has no len(), so tqdm needs an explicit total for a real bar.
    for time_step in tqdm(reversed(range(pram['time_steps'])), total=pram['time_steps'], desc='sampling'):
        x_t = sampler(x_t, time_step)
x_0 = x_t

In [None]:
# Decode the sampled latent and map the decoder output into (0, 1) with a sigmoid
# (presumably per-voxel occupancy probabilities -- confirm against the training loss).
# no_grad: no backward pass is needed; this avoids keeping the decoder's activations
# alive and leaves `pred` without requires_grad, so it can be converted to numpy later.
with torch.no_grad():
    pred = torch.sigmoid(decoder(x_0))

# Sanity check: largest predicted value across the whole batch, as a Python float.
pred.max().item()