In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

from modules import VoxelDataset
from modules import Decoder
from modules import LatentVariables
from modules import BinaryFocalLoss
from modules import plot_objt_by_dataset, plot_objt_by_latents

from torch import optim
from torch.backends import cudnn

from tqdm import tqdm

# Ask cuDNN to benchmark its convolution algorithms and reuse the fastest one.
cudnn.benchmark = True

# Prefer the GPU, but fall back to the CPU so every later `.to(device)` call
# still works on a machine without CUDA instead of raising.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Notebook-wide settings; the cells below look values up by key.
pram = dict(
    # Paths handed to VoxelDataset (positional arguments, in this order).
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    # NOTE(review): 'batch_size' is not read by any cell visible here.
    batch_size=1,
    # Shape of the dense target grid allocated in the fitting cell.
    voxel_map_shape=(128, 128, 128),
    # Forwarded to VoxelDataset as keyword arguments of the same name.
    designate_num_objts=10,
    train_test_split_ratio_train=0.9,
    # Passed to both Decoder and LatentVariables.
    latent_dim=(64, 64),
    # NOTE(review): 'decoder_lr' is not read by any cell visible here;
    # only 'latents_lr' is used (AdamW learning rate for the latents).
    decoder_lr=3e-5,
    latents_lr=3e-4,
)

In [None]:
# Dataset instance built with is_train=True; all other arguments come from `pram`.
# NOTE(review): presumably this selects the training portion of the split —
# confirm against modules.VoxelDataset.
train_dataset = VoxelDataset(
    pram['data_dir_pth'],
    pram['part_counts_npy_pth'],
    pram['outlier_objt_indices_npy_pth'],
    designate_num_objts=pram['designate_num_objts'],
    train_test_split_ratio_train=pram['train_test_split_ratio_train'],
    is_train=True,
)

In [None]:
# Dataset instance built with is_train=False; all other arguments come from `pram`.
# NOTE(review): presumably this selects the held-out portion of the split —
# confirm against modules.VoxelDataset.
test_dataset = VoxelDataset(
    pram['data_dir_pth'],
    pram['part_counts_npy_pth'],
    pram['outlier_objt_indices_npy_pth'],
    designate_num_objts=pram['designate_num_objts'],
    train_test_split_ratio_train=pram['train_test_split_ratio_train'],
    is_train=False,
)

In [None]:
# Construct the decoder for the configured latent size, then move it to `device`.
decoder = Decoder(pram['latent_dim'], vd_ch=8)
decoder = decoder.to(device)

In [None]:
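# Optional: uncomment to read a previously saved checkpoint from disk.
# torch.load unpickles the file, so only load checkpoints from a trusted source.
# saved_ckpt = torch.load('./models/encoder_decoder_saved.pt')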

In [None]:
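# Optional: uncomment (after loading `saved_ckpt`) to restore the decoder weights.
# NOTE(review): while this line stays commented out, the latents below are fitted
# against a decoder that still has its freshly initialised weights — confirm intent.
# decoder.load_state_dict(saved_ckpt['decoder_state_dict'])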

In [None]:
# First argument given to LatentVariables.
# NOTE(review): presumably the number of part latents to fit — confirm in modules.
USER_DEFINED_PARTS_NUM = 30

# Create the latent container, then move it to `device`.
latent_vars = LatentVariables(USER_DEFINED_PARTS_NUM, pram['latent_dim'])
latent_vars = latent_vars.to(device)

In [None]:
# The optimizer is given only the parameters of `latent_vars`; the decoder's
# parameters are not registered with it.
latents_optim = optim.AdamW(
    latent_vars.parameters(),
    lr=pram['latents_lr'],
)

In [None]:
# Loss used by train_step; the meaning of alpha/gamma is defined in modules.BinaryFocalLoss.
loss_fn = BinaryFocalLoss(alpha=0.8, gamma=5)

In [None]:
def train_step(target):
    """Run one optimisation step of the latent variables against ``target``.

    Uses the notebook-level ``latents_optim``, ``latent_vars``, ``decoder`` and
    ``loss_fn``. Each latent is decoded on its own, the decoded outputs are
    passed through a sigmoid and summed over the latent axis, and the result is
    compared with ``target`` by ``loss_fn``. Only ``latents_optim`` is stepped.

    Args:
        target: Tensor handed to ``loss_fn`` as the ground truth.

    Returns:
        The loss for this step as a tensor detached from the autograd graph,
        so callers can keep it around (e.g. as a running best) without holding
        on to the graph.
    """
    latents_optim.zero_grad()

    latents = latent_vars(None)[0]

    # Decode one latent at a time, each reshaped to (1, 1, *latent.shape).
    decoded = [decoder(latent.view(1, 1, *latent.shape)) for latent in latents]
    # Stack along a new axis 1 and drop the leading axis, leaving the
    # per-latent outputs along axis 0.
    outputs = torch.stack(decoded, dim=1)[0]

    # NOTE(review): a sum of sigmoids can exceed 1.0 where several latents
    # overlap — confirm BinaryFocalLoss tolerates values outside [0, 1].
    pred_voxels = torch.sum(torch.sigmoid(outputs), dim=0)

    loss = loss_fn(pred_voxels, target)
    loss.backward()

    latents_optim.step()

    return loss.detach()

In [None]:
# Fit the latent variables to one target object, stopping early once the loss
# has failed to improve for MAX_NO_IMPROVEMENT steps in total.
epoch = 0  # NOTE(review): never updated in this cell; kept in case other cells read it.
best_loss = float('inf')
no_improvement_cnt = 0
MAX_NO_IMPROVEMENT = 3

decoder.eval()
# Only the latents are optimised here, so skip computing (and endlessly
# accumulating) gradients for the decoder weights on every backward pass.
decoder.requires_grad_(False)
latent_vars.train()

dataset = train_dataset
target_idx = 0

# Index of the first part of object `target_idx` in the flat per-part list.
base_idx = sum(dataset.each_chair_part_counts[:target_idx])

# Fill the occupancy grid on the CPU and move it to `device` once afterwards;
# writing single voxels into a CUDA tensor costs one kernel launch per voxel.
target_voxel_map = torch.zeros(pram['voxel_map_shape'], dtype=torch.float32)

for i in range(base_idx, base_idx + dataset.each_chair_part_counts[target_idx]):
    for x, y, z in dataset.parts_voxel_coords[i]:
        target_voxel_map[x, y, z] = 1.0

# Add a leading axis of size 1, then transfer to the training device.
target_voxel_map = target_voxel_map.view(1, *target_voxel_map.shape).to(device)

# The context manager closes the bar on normal exit, on error and on interrupt.
with tqdm() as pbar:
    while True:
        # Keep a plain float so no tensor outlives the step.
        loss = train_step(target_voxel_map).item()

        # Report before the stop check so the final step is shown as well.
        pbar.set_postfix_str('Loss: {:.6f}'.format(loss))
        pbar.update(1)

        if loss < best_loss:
            best_loss = loss
            no_improvement_cnt = 0
        else:
            no_improvement_cnt += 1

        if no_improvement_cnt >= MAX_NO_IMPROVEMENT:
            print('No Improvement Count Reached.')
            break

In [None]:
# Visualise the fitted latents by passing them, with the decoder, to the plotting helper.
plot_objt_by_latents(decoder, latent_vars.latents)