In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

from modules import VoxelDataset
from modules import Encoder
from modules import Decoder
from modules import plot_objt_by_dataset, plot_objt_by_models, plot_part_by_voxel_coords, dataloader_collate_fn, get_voxel_map

from torch.utils.data import DataLoader
from torch.backends import cudnn

from tqdm import tqdm

# benchmark=True lets cuDNN time several convolution algorithms and keep the
# fastest; it pays off when input shapes stay constant across calls.
cudnn.benchmark = True

# Fall back to CPU when no CUDA device is present so the notebook still runs
# (slowly) on GPU-less machines instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Experiment configuration read by the dataset and model cells below.
# Paths are relative to the notebook's working directory.
pram = dict(
    # dataset locations
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    # loading / splitting
    batch_size=1,
    voxel_map_shape=(128, 128, 128),
    train_test_split_ratio_train=0.9,
    # model; the learning rates are not read in the visible cells
    # (presumably kept in sync with the training configuration)
    latent_dim=(64, 64),
    decoder_lr=3e-5,
    encoder_lr=3e-5,
)

In [None]:
# Training split of the voxel dataset. `train_dataloader` is not iterated in
# any of the visible cells; only the test split is inspected below.
train_dataset = VoxelDataset(
    pram['data_dir_pth'],
    pram['part_counts_npy_pth'],
    pram['outlier_objt_indices_npy_pth'],
    designate_num_objts=None,  # presumably None = use every object; TODO confirm
    train_test_split_ratio_train=pram['train_test_split_ratio_train'],
    is_train=True,
)

train_loader_kwargs = dict(
    batch_size=pram['batch_size'],
    shuffle=False,
    collate_fn=dataloader_collate_fn,  # project-specific batching of variable-size samples (assumed)
    pin_memory=True,
)
train_dataloader = DataLoader(train_dataset, **train_loader_kwargs)

In [None]:
# Held-out split of the voxel dataset, built with the same paths and split
# ratio as the training split but with is_train=False.
test_dataset = VoxelDataset(
    pram['data_dir_pth'],
    pram['part_counts_npy_pth'],
    pram['outlier_objt_indices_npy_pth'],
    designate_num_objts=None,
    train_test_split_ratio_train=pram['train_test_split_ratio_train'],
    is_train=False,
)

test_loader_kwargs = dict(
    batch_size=1,  # one sample per batch, independent of pram['batch_size']
    shuffle=False,
    collate_fn=dataloader_collate_fn,
    pin_memory=True,
)
test_dataloader = DataLoader(test_dataset, **test_loader_kwargs)

In [None]:
# Build the encoder/decoder pair with the configured latent size and move both
# onto `device`; weights are filled in from the checkpoint further down.
latent_dim = pram['latent_dim']
encoder = Encoder(latent_dim).to(device)
decoder = Decoder(latent_dim).to(device)

In [None]:
# Load the trained encoder/decoder checkpoint. `map_location` remaps tensors
# that were saved on a GPU onto `device`, so the checkpoint also loads on a
# CPU-only host (or a different GPU index) instead of failing to deserialize.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
saved_ckpt = torch.load('./models/encoder_decoder_saved.pt', map_location=device)

In [None]:
# Restore the trained weights into each network. load_state_dict is strict by
# default, so missing or unexpected keys raise instead of loading partially.
for net, state_key in ((encoder, 'encoder_state_dict'),
                       (decoder, 'decoder_state_dict')):
    net.load_state_dict(saved_ckpt[state_key])

In [None]:
# Put both networks in eval mode (changes dropout / batch-norm behaviour, if
# the architectures contain such layers). A loop statement produces no cell
# output, so the module reprs are not echoed.
for net in (encoder, decoder):
    net.eval()

In [None]:
# Fetch the first test object as a voxel tensor on `device`, then prepend two
# size-1 axes (presumably batch and channel — TODO confirm against Encoder).
voxel = get_voxel_map(test_dataset[0], device)
voxel = voxel.view((1, 1) + tuple(voxel.shape))

In [None]:
# Decode a uniformly random latent code as a sanity check of the decoder's
# output range. Seeded so the sample (and the count in the next cell) is
# reproducible on Restart & Run All.
torch.manual_seed(0)

# Inference only: no_grad avoids building an autograd graph for the full
# decoder output. The latent shape follows pram['latent_dim'] instead of a
# hardcoded (64, 64), so it stays in sync with how Decoder was constructed.
with torch.no_grad():
    latent = torch.rand((1, 1, *pram['latent_dim']), device=device)
    pred = torch.sigmoid(decoder(latent))

torch.max(pred)

In [None]:
# Index tensor of elements whose sigmoid output (presumably occupancy
# probability) exceeds 0.5; its shape is (num_above_threshold, pred.ndim).
torch.nonzero(pred > 0.5).shape

In [None]:
# Plot a test object directly from the dataset; 3 is presumably the object
# index into test_dataset (TODO confirm against modules.plot_objt_by_dataset).
# NOTE(review): the model plot below uses index 0, not 3 — confirm which object
# is meant to be compared.
plot_objt_by_dataset(test_dataset, 3)

In [None]:
# Plot the encoder/decoder reconstruction of a test object. Argument meanings
# are inferred from the call site — TODO confirm against
# modules.plot_objt_by_models: 0 is presumably the object index into
# test_dataset and 0.5 the threshold applied to the decoder output.
plot_objt_by_models(encoder, decoder, test_dataset, 0, 0.5, device)