In [None]:
import sys
sys.dont_write_bytecode=True

import os

import torch

from modules import LatentDataset
from modules import Diffusion
from modules import DiffusionTrainer

from torch import optim
from torch.utils.data import DataLoader
from torch.backends import cudnn

from tqdm import tqdm

# Let cuDNN benchmark and cache the fastest kernels for the observed input shapes.
cudnn.benchmark = True

# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without a usable CUDA device instead of failing on `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [None]:
# Run configuration: every tunable value of this notebook lives in `pram`.
pram = dict(
    # Path handed to LatentDataset when the training set is built.
    latents_remove_duplicate_npy_pth='./dataset/latents_remove_duplicate.npy',
    # Batch size used by the training DataLoader.
    batch_size=1,
    # Upper bound on the number of epochs run by the training loop.
    train_epochs=1000,
    # The next four entries are not read by any cell visible in this notebook.
    voxel_map_shape=(128, 128, 128),
    designate_num_objts=1,
    train_test_split_ratio_train=1,
    latent_dim=(64, 64),
    # Base learning rate given to AdamW.
    diffusion_lr=3e-5,
    # Forwarded, in this order, to DiffusionTrainer.
    beta_start=1e-4,
    beta_end=2e-2,
    time_steps=1000,
    # Number of scheduler steps over which warmup_lr ramps the LR factor up to 1.
    warmup=10,
)

In [None]:
# Dataset of latents read from the .npy path configured in `pram`.
train_dataset = LatentDataset(pram['latents_remove_duplicate_npy_pth'])

# Shuffled loader; incomplete final batches are dropped and host memory is
# pinned for the transfer to `device`.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=pram['batch_size'],
    drop_last=True,
    pin_memory=True,
    shuffle=True,
)

In [None]:
# Instantiate the diffusion model with its default arguments, then place it on `device`.
diffusion = Diffusion()
diffusion = diffusion.to(device)

In [None]:
def warmup_lr(step):
    """Return the learning-rate multiplier for scheduler step ``step``.

    The factor grows linearly from 0 at step 0 to 1 at step
    ``pram['warmup']`` and stays at 1 afterwards. It is meant to be used as
    the ``lr_lambda`` of ``LambdaLR``, which scales the optimizer's base
    learning rate by the returned value.

    Args:
        step: Number of scheduler steps taken so far.

    Returns:
        The multiplier for the base learning rate. It is ``1.0`` when
        ``pram['warmup']`` is zero or negative, i.e. warm-up is disabled.
    """
    warmup_steps = pram['warmup']

    # A zero-length warm-up would otherwise divide by zero; treat any
    # non-positive value as "no warm-up" and use the full learning rate.
    if warmup_steps <= 0:
        return 1.0

    return min(step, warmup_steps) / warmup_steps

In [None]:
# AdamW over every parameter of the diffusion model, at the configured base LR.
diffusion_optim = optim.AdamW(params=diffusion.parameters(),
                              lr=pram['diffusion_lr'])

# Scheduler that scales the base LR by the factor returned from warmup_lr.
diffusion_sched = optim.lr_scheduler.LambdaLR(optimizer=diffusion_optim,
                                              lr_lambda=warmup_lr)

In [None]:
# Build the trainer around the diffusion model, forwarding beta_start,
# beta_end and time_steps from `pram` positionally, then move it to `device`.
trainer = DiffusionTrainer(
    diffusion,
    *(pram[key] for key in ('beta_start', 'beta_end', 'time_steps')),
)
trainer = trainer.to(device)

In [None]:
def train_step(latents):
    """Perform a single optimisation step on one batch of latents.

    Gradients are cleared, the loss is obtained by calling ``trainer`` on the
    batch and back-propagated, and then the optimizer and the learning-rate
    scheduler each advance by one step.

    Args:
        latents: Batch passed unchanged to ``trainer``.

    Returns:
        The batch loss converted to a Python scalar via ``.item()``.
    """
    diffusion_optim.zero_grad()

    batch_loss = trainer(latents)
    batch_loss.backward()

    # Optimizer first, scheduler second: the schedule advances once per batch.
    for stepper in (diffusion_optim, diffusion_sched):
        stepper.step()

    return batch_loss.item()

In [None]:
def train(train_dataloader, epoch, batch_size, voxel_map_shape=(128, 128, 128)):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        train_dataloader: Iterable of latent batches; every batch must provide
            ``.to(device)``.
        epoch: Epoch number, used only in the progress-bar label.
        batch_size: Unused; kept so existing callers keep working.
        voxel_map_shape: Unused; kept so existing callers keep working.

    Returns:
        The average of the losses returned by ``train_step`` over all batches
        of the epoch.

    Raises:
        ValueError: If ``train_dataloader`` yields no batches, for example
            when ``drop_last=True`` and the dataset is smaller than one batch.
    """
    # A running total avoids re-summing a growing list after every batch.
    loss_sum = 0.0
    num_batches = 0

    pbar = tqdm(train_dataloader, desc='[EPOCH {}] Training'.format(epoch))

    for latents in pbar:
        loss = train_step(latents.to(device))

        loss_sum += loss
        num_batches += 1
        avg_loss = loss_sum / num_batches

        pbar.set_postfix_str('Batch Loss: {:.6f} | Avg Loss: {:.6f}'.format(loss, avg_loss))

    # Without this guard an empty loader would surface as an obscure
    # UnboundLocalError on `avg_loss`.
    if num_batches == 0:
        raise ValueError('train_dataloader produced no batches; '
                         'check the dataset size against batch_size/drop_last.')

    return avg_loss

In [None]:
# Training driver: run up to `train_epochs` epochs, checkpoint the model each
# time the epoch's average training loss reaches a new minimum, and stop early
# after `patience` consecutive epochs without such an improvement.
epoch = 0
best_avg_loss = float('inf')
no_improvement_cnt = 0
patience = 3  # consecutive non-improving epochs tolerated before stopping

diffusion.train()

# exist_ok=True replaces the separate isdir check, so there is no window
# between testing for the directory and creating it.
os.makedirs('./models', exist_ok=True)

batch_size = pram['batch_size']

train_epochs = pram['train_epochs']

while epoch < train_epochs:
    epoch += 1

    avg_loss = train(train_dataloader, epoch, batch_size)

    if avg_loss < best_avg_loss:
        best_avg_loss = avg_loss
        no_improvement_cnt = 0

        # Overwrite the single checkpoint file with the best model so far.
        torch.save({
            'diffusion_state_dict': diffusion.state_dict(),
            'loss': avg_loss,
            'epoch': epoch
        }, './models/diffusion_saved.pt')
    else:
        no_improvement_cnt += 1

    if no_improvement_cnt >= patience:
        print('No Improvement Count Reached.')
        break