In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

from modules import VoxelDataset
from modules import Encoder
from modules import Diffusion
from modules import DiffusionTrainer
from modules import dataloader_collate_fn, get_voxel_map

from torch import optim
from torch.utils.data import DataLoader
from torch.backends import cudnn

from tqdm import tqdm

# Enable cuDNN's benchmark mode (see the PyTorch docs for torch.backends.cudnn.benchmark).
cudnn.benchmark = True

# Prefer the GPU, but fall back to the CPU so that every later `.to(device)`
# call still works on a machine without CUDA instead of raising.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Central configuration for this notebook. Every later cell reads its
# settings from this mapping instead of hard-coding them.
pram = dict(
    # Inputs handed to VoxelDataset.
    data_dir_pth='./dataset/chair_voxel_data_remove_duplicate',
    part_counts_npy_pth='./dataset/each_chair_parts_count_remove_duplicate.npy',
    outlier_objt_indices_npy_pth='./dataset/outlier_objt_indices.npy',
    # Used for the DataLoader and for reshaping each batch in train().
    batch_size=1,
    voxel_map_shape=(128, 128, 128),
    # Forwarded to VoxelDataset as keyword arguments of the same name.
    designate_num_objts=10,
    train_test_split_ratio_train=1,
    # First positional argument of Encoder.
    latent_dim=(64, 64),
    # Base learning rate of the AdamW optimizer for the diffusion model.
    diffusion_lr=3e-5,
    # Passed positionally to DiffusionTrainer (beta_start, beta_end, time_steps).
    beta_start=1e-4,
    beta_end=2e-2,
    time_steps=1000,
    # Number of scheduler steps over which warmup_lr ramps its multiplier to 1.
    warmup=10,
)

In [None]:
# Training dataset, configured entirely from `pram`.
train_dataset = VoxelDataset(
    pram['data_dir_pth'],
    pram['part_counts_npy_pth'],
    pram['outlier_objt_indices_npy_pth'],
    designate_num_objts=pram['designate_num_objts'],
    train_test_split_ratio_train=pram['train_test_split_ratio_train'],
    is_train=True,
)

# Batches are assembled by the project's own collate function and are
# served in dataset order (shuffling is disabled).
train_dataloader = DataLoader(
    train_dataset,
    batch_size=pram['batch_size'],
    shuffle=False,
    collate_fn=dataloader_collate_fn,
    pin_memory=True,
)

In [None]:
# Build the encoder and move it to the target device.
encoder = Encoder(pram['latent_dim'], ve_ch=8).to(device)

# map_location remaps the stored tensors onto `device`, so the checkpoint
# loads even when it was saved from a different device (another GPU index,
# or a GPU checkpoint opened on a CPU-only machine).
saved_ckpt = torch.load('./models/encoder_decoder_8.pt', map_location=device)

# Only the encoder weights are taken from the checkpoint.
encoder.load_state_dict(saved_ckpt['encoder_state_dict'])

In [None]:
# Instantiate the diffusion model with its default arguments, then move it
# to the target device.
diffusion = Diffusion()
diffusion = diffusion.to(device)

In [None]:
def warmup_lr(step):
    """Return the learning-rate multiplier for scheduler step ``step``.

    The multiplier grows linearly as ``step / pram['warmup']`` and is capped
    at 1 once ``step`` reaches ``pram['warmup']``.
    """
    warmup_steps = pram['warmup']
    # Clamp the step so the ratio never exceeds 1.
    capped_step = step if step < warmup_steps else warmup_steps
    return capped_step / warmup_steps

In [None]:
# AdamW over the diffusion model's parameters only; the encoder is not optimized.
diffusion_optim = optim.AdamW(
    diffusion.parameters(),
    lr=pram['diffusion_lr'],
)

# warmup_lr supplies the per-step multiplier applied to the base learning rate.
diffusion_sched = optim.lr_scheduler.LambdaLR(
    diffusion_optim,
    lr_lambda=warmup_lr,
)

In [None]:
# Wrap the diffusion model in its trainer, configured with the beta range
# and the number of time steps from `pram`, then move it to the device.
# NOTE(review): presumably these define the noise schedule; confirm in modules.
trainer = DiffusionTrainer(
    diffusion,
    pram['beta_start'],
    pram['beta_end'],
    pram['time_steps'],
)
trainer = trainer.to(device)

In [None]:
def train_step(latents):
    """Perform one optimization step on the diffusion model.

    Clears stale gradients, evaluates ``trainer`` on ``latents`` to obtain
    the loss, backpropagates it, and then advances the optimizer followed by
    the learning-rate scheduler.

    Args:
        latents: Input forwarded unchanged to ``trainer``.

    Returns:
        The loss of this step as a plain Python number (``loss.item()``).
    """
    # Start from clean gradients for this step.
    diffusion_optim.zero_grad()

    step_loss = trainer(latents)
    step_loss.backward()

    # Update the weights first, then move the learning-rate schedule forward.
    diffusion_optim.step()
    diffusion_sched.step()

    loss_value = step_loss.item()
    return loss_value

In [None]:
def train(train_dataloader, epoch, batch_size, voxel_map_shape=(128, 128, 128)):
    """Train the diffusion model for one epoch over ``train_dataloader``.

    For every batch, each element is converted to a voxel map with
    ``get_voxel_map``, the stacked maps are encoded by the module-level
    ``encoder``, and the resulting latents are handed to ``train_step``,
    which runs the backward pass and the optimizer/scheduler update.

    Args:
        train_dataloader: Iterable of batches; each batch is itself iterated
            and every element is passed to ``get_voxel_map``.
        epoch: Epoch number, used only in the progress-bar description.
        batch_size: Leading dimension used when reshaping the stacked voxel
            maps to ``(batch_size, 1, *voxel_map_shape)``.
        voxel_map_shape: Shape of a single voxel map.

    Returns:
        float: Mean of the per-batch losses of this epoch.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches.
    """
    loss_sum = 0.0
    num_batches = 0
    avg_loss = None

    pbar = tqdm(train_dataloader, desc='[EPOCH {}] Training'.format(epoch))

    for parts in pbar:
        voxel_maps = [get_voxel_map(voxel_coords, device) for voxel_coords in parts]

        # The encoder is not being optimized, so skip building an autograd
        # graph through it; only the diffusion model needs gradients.
        with torch.no_grad():
            latents = encoder(torch.stack(voxel_maps).view(batch_size, 1, *voxel_map_shape))

        # train_step does the backward pass and the optimizer/scheduler
        # update and returns the loss as a plain number, so no graph-attached
        # tensors are kept alive across batches.
        loss = train_step(latents)

        # Running totals keep the average O(1) per batch.
        loss_sum += loss
        num_batches += 1
        avg_loss = loss_sum / num_batches

        pbar.set_postfix_str('Batch Loss: {:.6f} | Avg Loss: {:.6f}'.format(loss, avg_loss))

    if avg_loss is None:
        raise ValueError('train_dataloader yielded no batches; cannot compute an average loss.')

    return avg_loss

In [None]:
# Training driver: run up to `train_epochs` epochs, checkpoint the diffusion
# model whenever the average epoch loss improves, and stop early after
# `patience` consecutive epochs without improvement.
epoch = 0
best_avg_loss = float('inf')
no_improvement_cnt = 0
patience = 3

# The encoder is put in eval mode; only the diffusion model is in train mode.
encoder.eval()
diffusion.train()

# exist_ok avoids the check-then-create race of isdir() followed by mkdir().
os.makedirs('./models', exist_ok=True)

batch_size = pram['batch_size']

train_epochs = 5

while epoch < train_epochs:
    epoch += 1

    # Pass the configured voxel map shape explicitly so that changing it in
    # `pram` actually takes effect. float() guarantees a plain number for the
    # comparison and the checkpoint, even if train() returns a 0-dim tensor.
    avg_loss = float(train(train_dataloader, epoch, batch_size, pram['voxel_map_shape']))

    if avg_loss < best_avg_loss:
        best_avg_loss = avg_loss
        no_improvement_cnt = 0

        # Keep only the best model seen so far (the file is overwritten).
        torch.save({
            'diffusion_state_dict': diffusion.state_dict(),
            'loss': avg_loss,
            'epoch': epoch
        }, './models/diffusion_saved.pt')
    else:
        no_improvement_cnt += 1

    if no_improvement_cnt >= patience:
        print('No Improvement Count Reached.')
        break