In [None]:
import sys
# Set before the project modules below are imported so that importing them
# does not write .pyc bytecode files.
sys.dont_write_bytecode=True

import os

import torch

import numpy as np

from modules import VoxelDataset
from modules import Decoder, LatentVariables, BCELoss
from modules import plot_objt_by_dataset, plot_objt_by_decoder, plot_part_by_voxel_coords, dataloader_collate_fn, get_voxel_map

from torch import optim
from torch import nn
from torch.utils.data import DataLoader, Subset

from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA instead of failing at the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Dataset locations and loading options read by the cells below.
dataset_pram = dict(
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    batch_size=1,
    voxel_map_shape=(128, 128, 128),
    # Fraction of the dataset assigned to the training split (1 => no test samples).
    train_test_split_ratio_train=1,
)

In [None]:
# Build the dataset from the three configured paths, passed positionally in
# the same order as the config keys listed here.
dataset = VoxelDataset(
    *(dataset_pram[key] for key in ('data_dir_pth',
                                    'part_counts_npy_pth',
                                    'outlier_objt_indices_npy_pth')),
    designate_num_objts=1,
)

# Contiguous split: the first `train_size` samples train, the remainder test.
num_samples = len(dataset)
train_size = int(num_samples * dataset_pram['train_test_split_ratio_train'])
test_size = num_samples - train_size

print(f'Training dataset size: {train_size} parts; Testing dataset size: {test_size} parts')

train_dataset = Subset(dataset, range(train_size))
test_dataset = Subset(dataset, range(train_size, num_samples))

# shuffle=False keeps the sample order fixed; the training loop derives the
# latent indices from the batch position, so the order must not change.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=dataset_pram['batch_size'],
    shuffle=False,
    collate_fn=dataloader_collate_fn,
)
test_dataloader = DataLoader(
    test_dataset,
    batch_size=1,
    shuffle=False,
    collate_fn=dataloader_collate_fn,
)

In [None]:
# Model / optimisation hyper-parameters read by the cells below.
model_pram = dict(
    # One latent per training sample.
    num_parts=len(train_dataset.indices),
    latent_dim=(1, 64, 64),
    decoder_lr=1e-4,
    latent_lr=3e-4,
)

In [None]:
# Instantiate the decoder and the per-sample latent table, then move both to
# the selected device.
decoder = Decoder(model_pram['latent_dim'])
decoder = decoder.to(device)

latent_vars = LatentVariables(
    model_pram['num_parts'],
    model_pram['latent_dim'],
)
latent_vars = latent_vars.to(device)

In [None]:
# Separate optimisers so the decoder weights and the latent variables can be
# updated with different learning rates.
decoder_optim = optim.AdamW(decoder.parameters(), lr=model_pram['decoder_lr'])

# Use the dedicated latent learning rate. This was previously tied to
# 'decoder_lr', which left model_pram['latent_lr'] unused.
latent_optim = optim.AdamW(latent_vars.parameters(), lr=model_pram['latent_lr'])

In [None]:
# Loss used by train_step, where it is called with logits=True on the raw
# decoder outputs.
# NOTE(review): the meaning of gamma=0.8 is defined in modules.BCELoss — confirm before tuning.
loss_fn = BCELoss(gamma=0.8)

In [None]:
def train_step(indices, targets):
    """Run one joint optimisation step on the decoder and the latent variables.

    Args:
        indices: Indices passed to ``latent_vars`` to select the latents of
            the samples in this batch.
        targets: Target tensor passed to ``loss_fn`` together with the
            decoder output.

    Returns:
        The loss tensor of this batch, detached from the autograd graph so
        callers can log or accumulate it without keeping graph references.
    """
    decoder_optim.zero_grad()
    latent_optim.zero_grad()

    latent = latent_vars(indices)
    outputs = decoder(latent)

    # The decoder output is handed to the loss as raw logits.
    loss = loss_fn(outputs, targets, logits=True)
    loss.backward()

    # Clip the two parameter sets independently so each has its own norm budget.
    torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=1.0)
    torch.nn.utils.clip_grad_norm_(latent_vars.parameters(), max_norm=1.0)

    decoder_optim.step()
    latent_optim.step()

    return loss.detach()

In [None]:
# Training loop with early stopping. The checkpoint with the best average
# epoch loss so far is written to ./models/decoder_saved.pt.
epoch = 0
best_avg_loss = float('inf')
no_improvement_cnt = 0

# exist_ok avoids the check-then-create race of isdir() + mkdir().
os.makedirs('./models', exist_ok=True)

batch_size = dataset_pram['batch_size']
voxel_map_shape = dataset_pram['voxel_map_shape']

train_epochs = 999
# Stop after this many consecutive epochs without a new best average loss.
early_stop_patience = 10

while epoch < train_epochs:
    epoch += 1

    # Keep a running sum as plain Python floats: no device tensors are held
    # across steps and the average is O(1) per batch instead of re-summing.
    loss_sum = 0.0
    num_batches = 0
    avg_loss = float('inf')

    decoder.train()

    pbar = tqdm(train_dataloader, desc='[EPOCH {}]'.format(epoch))

    for i, parts in enumerate(pbar):
        voxel_maps = [get_voxel_map(voxel_coords, device) for voxel_coords in parts]

        # The final batch may hold fewer than `batch_size` samples, so size
        # the latent indices and the target view from the batch itself.
        cur_batch_size = len(voxel_maps)
        first_idx = i * batch_size
        indices = list(range(first_idx, first_idx + cur_batch_size))

        targets = torch.stack(voxel_maps).view(cur_batch_size, 1, *voxel_map_shape)

        # float() accepts both a one-element tensor and a plain number.
        loss = float(train_step(indices, targets))

        loss_sum += loss
        num_batches += 1
        avg_loss = loss_sum / num_batches

        pbar.set_postfix_str('Batch Loss: {:.6f} | Avg Loss: {:.6f}'.format(loss, avg_loss))

    if num_batches == 0:
        # Previously this surfaced as a NameError on `avg_loss`.
        raise RuntimeError('train_dataloader produced no batches; nothing to train on.')

    if avg_loss < best_avg_loss:
        best_avg_loss = avg_loss
        no_improvement_cnt = 0

        # 'loss' is stored as a Python float so the checkpoint metadata does
        # not depend on the device it was produced on.
        torch.save({
            'decoder_state_dict': decoder.state_dict(),
            'latent_vars_state_dict': latent_vars.state_dict(),
            'loss': avg_loss,
            'epoch': epoch
        }, './models/decoder_saved.pt')
    else:
        no_improvement_cnt += 1

    if no_improvement_cnt >= early_stop_patience:
        print('No Improvement Count Reached.')
        break

In [None]:
# Reload the best checkpoint written by the training loop. map_location remaps
# the stored tensors onto the current device, so a checkpoint saved on a GPU
# can still be loaded when CUDA is unavailable.
saved = torch.load('./models/decoder_saved.pt', map_location=device)

In [None]:
# Restore the decoder weights and the latent variables from the loaded
# checkpoint. The result of the last call is shown as the cell output.
decoder.load_state_dict(saved['decoder_state_dict'])
latent_vars.load_state_dict(saved['latent_vars_state_dict'])

In [None]:
# Sanity check: largest sigmoid output of the decoder for the first latent.
# Run under no_grad because this is inspection only; no autograd graph is
# needed for the decoder output.
with torch.no_grad():
    max_prob = torch.max(nn.Sigmoid()(decoder(latent_vars.latents[0][None, :])))
max_prob

In [None]:
# Plot an object reconstructed by the decoder from the learned latents.
# NOTE(review): the trailing arguments 0 and 0.9 look like an object index and
# an occupancy threshold — confirm against modules.plot_objt_by_decoder.
plot_objt_by_decoder(decoder, latent_vars, dataset.each_chair_part_counts, 0, 0.9)