In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

from modules import VoxelDataset
from modules import Encoder
from modules import Decoder
from modules import BinaryFocalLoss
from modules import dataloader_collate_fn, get_voxel_map

from torch import optim
from torch.utils.data import DataLoader
from torch.backends import cudnn

from tqdm import tqdm

# Let cuDNN auto-tune convolution algorithms; this pays off when input shapes
# stay the same across iterations.
cudnn.benchmark = True

# Seed torch's RNG so weight initialisation is reproducible across runs.
SEED = 0
torch.manual_seed(SEED)

# Fall back to CPU when no GPU is present instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Experiment configuration: dataset locations, batching, voxel-grid size,
# dataset subsetting, latent size and per-network learning rates.
pram = dict(
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    batch_size=1,
    voxel_map_shape=(128, 128, 128),
    designate_num_objts=10,
    train_test_split_ratio_train=1,
    latent_dim=(64, 64),
    decoder_lr=3e-5,
    encoder_lr=3e-5,
)

In [None]:
# Training split of the voxel dataset. The three positional arguments are the
# data directory and the two .npy files configured in `pram`, in that order.
train_dataset = VoxelDataset(
    *(pram[key] for key in ('data_dir_pth',
                            'part_counts_npy_pth',
                            'outlier_objt_indices_npy_pth')),
    designate_num_objts=pram['designate_num_objts'],
    train_test_split_ratio_train=pram['train_test_split_ratio_train'],
    is_train=True,
)

# NOTE(review): shuffle=False means every epoch visits batches in the same
# order; confirm this is intentional for a training loader.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=pram['batch_size'],
    shuffle=False,
    collate_fn=dataloader_collate_fn,
    pin_memory=True,
)

In [None]:
# Build both networks from the same `latent_dim` config and move them to
# `device`. NOTE(review): ve_ch / vd_ch = 8 are hard-coded; presumably a base
# channel width, confirm against the module definitions.
encoder, decoder = (
    Encoder(pram['latent_dim'], ve_ch=8).to(device),
    Decoder(pram['latent_dim'], vd_ch=8).to(device),
)

In [None]:
# Separate AdamW optimizers so encoder and decoder can use independent
# learning rates from the config.
encoder_optim, decoder_optim = (
    optim.AdamW(encoder.parameters(), lr=pram['encoder_lr']),
    optim.AdamW(decoder.parameters(), lr=pram['decoder_lr']),
)

In [None]:
# Reconstruction loss for the voxel autoencoder. NOTE(review): alpha/gamma are
# hard-coded here rather than in `pram`.
loss_fn = BinaryFocalLoss(alpha=0.8, gamma=5)

In [None]:
def train_step(targets):
    """Run one optimisation step of the encoder/decoder on a single batch.

    Args:
        targets: batch of voxel maps. It is used both as the encoder input and
            as the reconstruction target. ``train`` builds it with shape
            ``(batch, 1, *voxel_map_shape)``.

    Returns:
        The scalar loss for this step, detached from the autograd graph. Callers
        can accumulate or log it without keeping the graph alive.
    """
    encoder_optim.zero_grad()
    decoder_optim.zero_grad()

    latent = encoder(targets)
    outputs = decoder(latent)

    # Decoder outputs are passed as raw logits (logits=True); presumably the
    # loss applies the sigmoid internally - confirm in BinaryFocalLoss.
    loss = loss_fn(outputs, targets, logits=True)
    loss.backward()

    encoder_optim.step()
    decoder_optim.step()

    # Detach so accumulated losses don't pin autograd history / device memory.
    return loss.detach()

In [None]:
def train(train_dataloader, epoch, batch_size, voxel_map_shape=(128, 128, 128)):
    """Train the encoder/decoder for one epoch and return the mean batch loss.

    Args:
        train_dataloader: iterable of batches. Each batch is a sequence of
            voxel-coordinate entries, each converted with ``get_voxel_map``.
        epoch: epoch number, used only in the progress-bar label.
        batch_size: nominal batch size, kept for backward compatibility. The
            batch dimension is inferred from the data, so a short final batch
            does not break the reshape.
        voxel_map_shape: spatial shape of a single voxel map.

    Returns:
        float: mean of the per-batch losses over the epoch.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    loss_sum = 0.0
    num_batches = 0

    pbar = tqdm(train_dataloader, desc='[EPOCH {}] Training'.format(epoch))

    for parts in pbar:
        voxel_maps = [get_voxel_map(voxel_coords, device) for voxel_coords in parts]

        # -1 lets the batch dimension follow the actual number of samples
        # (identical to `batch_size` for full batches).
        targets = torch.stack(voxel_maps).view(-1, 1, *voxel_map_shape)

        # .item() yields a plain float. Accumulating tensors would keep device
        # memory (and any autograd history) alive for the whole epoch.
        batch_loss = train_step(targets).item()
        loss_sum += batch_loss
        num_batches += 1
        avg_loss = loss_sum / num_batches

        pbar.set_postfix_str('Batch Loss: {:.6f} | Avg Loss: {:.6f}'.format(batch_loss, avg_loss))

    if num_batches == 0:
        raise ValueError('train_dataloader yielded no batches')

    return loss_sum / num_batches

In [None]:
# Training state carried across epochs.
epoch = 0
best_avg_loss = float('inf')
no_improvement_cnt = 0

batch_size = pram['batch_size']
train_epochs = 50
# Early stopping: stop after this many consecutive epochs without a new best loss.
patience = 3
checkpoint_pth = './models/encoder_decoder_saved.pt'

encoder.train()
decoder.train()

# exist_ok avoids the isdir/mkdir check-then-create race.
os.makedirs('./models', exist_ok=True)

for epoch in range(1, train_epochs + 1):
    # float() guarantees a plain number for comparison and checkpointing, even
    # if train() hands back a 0-dim tensor.
    avg_loss = float(train(train_dataloader, epoch, batch_size,
                           voxel_map_shape=pram['voxel_map_shape']))

    if avg_loss < best_avg_loss:
        best_avg_loss = avg_loss
        no_improvement_cnt = 0

        # Only the best-so-far weights are kept; the file is overwritten on each improvement.
        torch.save({
            'encoder_state_dict': encoder.state_dict(),
            'decoder_state_dict': decoder.state_dict(),
            'loss': avg_loss,
            'epoch': epoch,
        }, checkpoint_pth)
    else:
        no_improvement_cnt += 1

    if no_improvement_cnt >= patience:
        print('No Improvement Count Reached.')
        break