In [None]:
import sys
sys.dont_write_bytecode=True

import torch

from modules import VoxelDataset
from modules import Decoder
from modules import plot_objt, dataloader_collate_fn, get_occurrence_map

from torch.utils.data import DataLoader
from torchsummary import summary

from tqdm import tqdm

# Use the GPU when one is available; otherwise fall back to CPU so the notebook still
# runs (training on 128^3 voxel grids will be very slow on CPU).
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Dataset / dataloader settings. Paths are relative to the notebook's working directory.
# `voxel_map_shape` is not passed to VoxelDataset; the training loop uses it to reshape
# the occurrence maps into label tensors.
dataset_pram = dict(
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    batch_size=1,
    voxel_map_shape=(128, 128, 128),
)

In [None]:
# Build the training dataset. Arguments are passed positionally: the voxel data
# directory, then the part-count and outlier-index .npy paths.
train_dataset = VoxelDataset(
    dataset_pram['data_dir_pth'],
    dataset_pram['part_counts_npy_pth'],
    dataset_pram['outlier_objt_indices_npy_pth'],
)

# shuffle=False matters: the training loop derives each sample's index from its batch
# position, so iteration order must follow dataset order.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=dataset_pram['batch_size'],
    shuffle=False,
    collate_fn=dataloader_collate_fn,
)

In [None]:
# Hyperparameters handed to Decoder. `num_parts` is read from the dataset so the model
# is sized to the data it trains on.
# NOTE(review): presumably `decoder_lr` / `latent_lr` are separate learning rates for the
# decoder weights and the latent codes — confirm in Decoder.
model_pram = dict(
    num_parts=train_dataset.num_parts,
    latent_dim=(1, 16, 16, 16),
    decoder_lr=3e-5,
    latent_lr=3e-5,
)

In [None]:
# Instantiate the decoder from the hyperparameters, then move it to the selected device.
decoder = Decoder(model_pram)
decoder = decoder.to(device)

In [None]:
# Train the decoder epoch by epoch until the epoch-average loss fails to improve for
# `patience` consecutive epochs (early stopping).
#
# NOTE(review): `indices` are derived from the batch position, so this relies on
# `train_dataloader` iterating the dataset in order (shuffle=False). Presumably
# `decoder.train_step` uses them to select per-object latent codes — confirm.
epoch = 1
best_avg_loss = float('inf')
no_improvement_cnt = 0
patience = 3  # consecutive non-improving epochs tolerated before stopping

batch_size = dataset_pram['batch_size']
voxel_map_shape = dataset_pram['voxel_map_shape']

while True:
    # Running sum instead of re-summing a growing list every batch (was O(n^2) per epoch).
    epoch_loss_sum = 0.0
    num_batches = 0

    pbar = tqdm(train_dataloader, desc='[EPOCH {}]'.format(epoch))

    for i, parts in enumerate(pbar):
        occurrence_maps = [get_occurrence_map(voxel_coords) for voxel_coords in parts]

        # The last batch can be smaller than `batch_size` when the dataset length is not a
        # multiple of it, so size the indices and the reshape from the actual batch.
        current_batch_size = len(occurrence_maps)
        start = i * batch_size
        indices = list(range(start, start + current_batch_size))

        labels = torch.stack(occurrence_maps).view(current_batch_size, 1, *voxel_map_shape).to(device)

        # float() turns a scalar tensor into a Python number (so no tensor/graph is kept
        # alive across the epoch) and is a no-op if train_step already returns a float.
        loss = float(decoder.train_step(indices, labels))

        epoch_loss_sum += loss
        num_batches += 1
        avg_loss = epoch_loss_sum / num_batches

        pbar.set_postfix_str('Batch Loss: {:.4f} | Avg Loss: {:.4f}'.format(loss, avg_loss))

    if num_batches == 0:
        # Otherwise avg_loss would be undefined (or stale from a previous epoch).
        raise RuntimeError('train_dataloader yielded no batches; check the dataset paths.')

    if avg_loss < best_avg_loss:
        best_avg_loss = avg_loss
        no_improvement_cnt = 0  # reset so patience counts *consecutive* bad epochs
    else:
        no_improvement_cnt += 1

    if no_improvement_cnt >= patience:
        print('No Improvement Count Reached.')
        break

    epoch += 1