In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

import numpy as np

from modules import VoxelDataset
from modules import Decoder, LatentVariables, BCELoss
from modules import plot_objt_by_dataset, plot_objt_by_decoder, plot_part_by_voxel_coords, dataloader_collate_fn, get_voxel_map

from torch import optim
from torch import nn
from torch.utils.data import DataLoader
from torchsummary import summary

from tqdm import tqdm

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing on the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Dataset configuration: input file locations, the loader batch size, and the
# shape of the dense voxel grid used for the targets.
dataset_pram = dict(
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    batch_size=1,
    voxel_map_shape=(128, 128, 128),
)

In [None]:
# Build the dataset from the three configured paths, passed positionally in
# the order: data directory, part-counts file, outlier-indices file.
# NOTE(review): designate_num_objts=1 presumably limits loading to one object - confirm in VoxelDataset.
train_dataset = VoxelDataset(
    *(
        dataset_pram[key]
        for key in ('data_dir_pth', 'part_counts_npy_pth', 'outlier_objt_indices_npy_pth')
    ),
    designate_num_objts=1,
)

# shuffle stays False: the training loop derives latent indices from the
# batch position, so the iteration order must be stable.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=dataset_pram['batch_size'],
    shuffle=False,
    collate_fn=dataloader_collate_fn,
)

In [None]:
# One weight per part: (number of entries in the part) / (remaining cells of
# the dense voxel grid). Assumes each part is a sequence of occupied voxel
# coordinates, so len(part) is the positive count - TODO confirm.
# NOTE(review): this is positives/negatives; the usual `pos_weight` convention
# is negatives/positives. The array is not referenced later in this notebook,
# so confirm the intended direction before wiring it into the loss.
pos_weights = np.array(
    [
        len(part) / (np.prod(dataset_pram['voxel_map_shape']) - len(part))
        for batch in train_dataloader
        for part in batch
    ],
    dtype=np.float32,
)

In [None]:
# Model / optimiser configuration. `num_parts` is read from the dataset so the
# latent table is sized to match it.
model_pram = dict(
    num_parts=train_dataset.num_parts,
    latent_dim=(1, 64, 64),
    decoder_lr=1e-4,
    latent_lr=1e-3,
)

In [None]:
# Decoder network, constructed from the latent shape and moved to the target device.
decoder = Decoder(model_pram['latent_dim'])
decoder = decoder.to(device)

# Latent container constructed from the part count and the latent shape.
latent_vars = LatentVariables(
    model_pram['num_parts'],
    model_pram['latent_dim'],
).to(device)

In [None]:
# Separate optimisers so the decoder weights and the latent codes can use
# different learning rates.
decoder_optim = optim.AdamW(decoder.parameters(), lr=model_pram['decoder_lr'])

# Use the dedicated latent learning rate; previously this reused `decoder_lr`,
# leaving the configured `latent_lr` unused.
latent_optim = optim.AdamW(latent_vars.parameters(), lr=model_pram['latent_lr'])

In [None]:
# Project-defined BCE loss; it is called in train_step with raw decoder
# outputs and `logits=True`.
# NOTE(review): the meaning of `gamma` is defined in modules.BCELoss - confirm there.
loss_fn = BCELoss(gamma=0.8)

In [None]:
def train_step(indices, targets):
    """Run one optimisation step for the decoder and the latent variables.

    Args:
        indices: Indices passed to `latent_vars` to select the latents for
            this batch.
        targets: Target tensor compared against the decoder output by
            `loss_fn`; expected to already be on the model's device.

    Returns:
        The batch loss as a tensor detached from the autograd graph, so that
        callers can store or accumulate it without accumulating history.
    """
    decoder_optim.zero_grad()
    latent_optim.zero_grad()

    latent = latent_vars(indices)
    outputs = decoder(latent)

    # The decoder output is passed as raw logits (logits=True).
    loss = loss_fn(outputs, targets, logits=True)
    loss.backward()

    decoder_optim.step()
    latent_optim.step()

    # Detach: the value is only used for logging/bookkeeping after this point.
    return loss.detach()

In [None]:
# Training loop: runs up to `train_epochs` epochs and tracks the best average
# epoch loss, stopping early after 10 consecutive epochs without improvement.
epoch = 0
best_avg_loss = float('inf')
no_improvement_cnt = 0

batch_size = dataset_pram['batch_size']
voxel_map_shape = dataset_pram['voxel_map_shape']

train_epochs = 10

while epoch < train_epochs:
    epoch += 1

    # Running statistics kept as plain floats so no tensors are retained
    # across steps and the average is O(1) per step.
    loss_sum = 0.0
    num_batches = 0

    decoder.train()

    pbar = tqdm(train_dataloader, desc='[EPOCH {}]'.format(epoch))

    for i, parts in enumerate(pbar):
        voxel_maps = [torch.tensor(get_voxel_map(voxel_coords), dtype=torch.float32) for voxel_coords in parts]

        # Use the real batch length: the final batch can hold fewer than
        # `batch_size` items, which would otherwise break `.view` and produce
        # out-of-range latent indices.
        num_in_batch = len(voxel_maps)
        first_idx = i * batch_size
        indices = list(range(first_idx, first_idx + num_in_batch))

        targets = torch.stack(voxel_maps).view(num_in_batch, 1, *voxel_map_shape).to(device)

        # float() accepts both a 0-dim tensor and a plain number.
        loss = float(train_step(indices, targets))

        loss_sum += loss
        num_batches += 1
        avg_loss = loss_sum / num_batches

        pbar.set_postfix_str('Batch Loss: {:.6f} | Avg Loss: {:.6f}'.format(loss, avg_loss))

    if num_batches == 0:
        # Without this guard `avg_loss` would be undefined below.
        raise RuntimeError('train_dataloader produced no batches; nothing to train on.')

    if avg_loss < best_avg_loss:
        best_avg_loss = avg_loss
        no_improvement_cnt = 0
    else:
        no_improvement_cnt += 1

    if no_improvement_cnt == 10:
        print('No Improvement Count Reached.')
        break

# Create the output directory if needed (no error when it already exists).
os.makedirs('./models', exist_ok=True)

# Persist both trainable components plus bookkeeping from the last epoch.
# The loss is stored as a plain float rather than a tensor.
torch.save({
    'decoder_state_dict': decoder.state_dict(),
    'latent_vars_state_dict': latent_vars.state_dict(),
    'loss': float(avg_loss),
    'epoch': epoch
}, './models/decoder_saved.pt')

In [None]:
# Remap stored tensors onto the active device so a checkpoint written on a
# GPU can still be loaded when `device` is the CPU.
save = torch.load('./models/decoder_saved.pt', map_location=device)

In [None]:
# Restore the decoder weights and the latent variables from the loaded checkpoint.
decoder.load_state_dict(save['decoder_state_dict'])
# Last expression of the cell: its return value is what the notebook displays.
latent_vars.load_state_dict(save['latent_vars_state_dict'])

In [None]:
sig = nn.Sigmoid()

# Inference only: switch the decoder to eval mode and skip autograd tracking.
# The training loop calls decoder.train() again at the start of every epoch.
decoder.eval()

with torch.no_grad():
    # Decode the latent stored at index 1, reshaped to a batch using the
    # configured latent shape instead of hard-coded sizes.
    latent = latent_vars.latents[1].view(-1, *model_pram['latent_dim'])
    # The decoder emits logits; sigmoid maps them to values in [0, 1].
    pred = sig(decoder(latent))

# Indices of cells above the 0.5 threshold, with the first two index columns dropped.
# NOTE(review): assumes those two columns are the batch and channel axes - confirm against Decoder's output shape.
voxel_coords = (pred > 0.5).nonzero()[:, 2:]

In [None]:
# Display the largest predicted value as a quick sanity check.
pred.max()