In [None]:
import functools
import gc
import importlib
import inspect
import multiprocessing
import pickle
import shutil
import traceback
from collections import OrderedDict, defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Union

# Third-party libraries - NumPy & Scientific
import numpy as np
from numpy.random import RandomState

# Third-party libraries - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard.summary import hparams

# Third-party libraries - Visualization
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Third-party libraries - ML Tools
import wandb
from loguru import logger
from omegaconf import DictConfig, ListConfig, OmegaConf
import omegaconf.errors

# Local imports
from ext import common
import fvdb
import fvdb.nn as fvnn
from fvdb import JaggedTensor, GridBatch

# Local imports
from xcube_refactored.modules.autoencoding.hparams import hparams_handler
from xcube_refactored.utils.loss_util import AverageMeter
from xcube_refactored.utils.loss_util import TorchLossMeter
from xcube_refactored.utils import exp 

from xcube_refactored.modules.autoencoding.sunet import StructPredictionNet 


import gc
import importlib
from contextlib import contextmanager
import os

import fvdb
from fvdb.nn import VDBTensor
from fvdb import GridBatch

import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
from omegaconf import DictConfig, ListConfig, OmegaConf
import collections
from pathlib import Path
from pytorch_lightning.utilities.distributed import rank_zero_only


from xcube_refactored.utils import exp
from xcube_refactored.utils.vis_util import vis_pcs


from xcube_refactored.modules.diffusionmodules.schedulers.scheduling_ddim import DDIMScheduler
from xcube_refactored.modules.diffusionmodules.schedulers.scheduling_ddpm import DDPMScheduler
from xcube_refactored.modules.diffusionmodules.schedulers.scheduling_dpmpp_2m import DPMSolverMultistepScheduler


from xcube_refactored.modules.diffusionmodules.ema import LitEma


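# NOTE(review): the UNet variants imported below are not referenced in the cells visible here -- confirm whether they are still needed.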
from xcube_refactored.modules.diffusionmodules.openaimodel.unet_dense import UNetModel as UNetModel_Dense
from xcube_refactored.modules.diffusionmodules.openaimodel.unet_sparse import UNetModel as UNetModel_Sparse
from xcube_refactored.modules.diffusionmodules.openaimodel.unet_sparse_crossattn import UNetModel as UNetModel_Sparse_CrossAttn


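# NOTE(review): the encoders imported below are not referenced in the cells visible here -- confirm whether they are still needed.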
from xcube_refactored.modules.encoders import (SemanticEncoder, ClassEmbedder, PointNetEncoder,
                                    StructEncoder, StructEncoder3D, StructEncoder3D_remain_h, StructEncoder3D_v2)

In [None]:
# Hyperparameters

batch_size = 4
tree_depth = 3 # according to 512x512x512 -> 128x128x128
voxel_size = 0.0025
resolution = 512
use_fvdb_loader = True
use_hash_tree = True # use hash tree means use early dilation (description in Sec 3.4) 

# setup input
use_input_normal = True
use_input_semantic = False
use_input_intensity = False

# setup KL loss
cut_ratio = 16 # reduce the dimension of the latent space
kl_weight = 1.0 # activate when anneal is off
normalize_kld = True
enable_anneal = False
kl_weight_min = 1e-7
kl_weight_max = 1.0
anneal_star_iter = 0
anneal_end_iter = 70000 # need to adjust for different dataset


structure_weight = 20.0
normal_weight = 300.0
  

learning_rate = {
  "init": 1.0e-4,
  "decay_mult": 0.7,
  "decay_step": 50000,
  "clip": 1.0e-6
}
weight_decay = 0.0
grad_clip = 0.5

c_dim = 32
  
# unet parameters
in_channels = 32
num_blocks = tree_depth
f_maps = 32
neck_dense_type = "UNCHANGED"
neck_bound = [64, 64, 64] # useless but indicate here
num_res_blocks = 1
use_residual = False
order = "gcr"
is_add_dec = False
use_attention = False
use_checkpoint = False

In [None]:
optimizer = torch.optim.AdamW(self.parameters(), lr=lr_config['init'],
                                    weight_decay=self.hparams.weight_decay, amsgrad=True)

scheduler = LambdaLR(optimizer,
                lr_lambda=functools.partial(
                    lambda_lr_wrapper, lr_config=lr_config, batch_size=self.hparams.batch_size))

exp.global_var_manager.register_variable('skip_backward', False)


# Do gradient clipping in the training steps
grad_clip_val = self.hparams.get('grad_clip', 1000.)

if grad_clip_val == "inspect":
    from pytorch_lightning.utilities.grads import grad_norm
    grad_dict = grad_norm(self, 'inf')      # Get the maximum absolute value.
    print(grad_dict)
    grad_clip_val = 1000.

# torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=grad_clip_val)
torch.nn.utils.clip_grad_value_(self.parameters(), clip_value=grad_clip_val)

# If detect nan values, then this step is skipped
has_nan_value_cnt = 0
for p in filter(lambda p: p.grad is not None, self.parameters()):
    if torch.any(p.grad.data != p.grad.data):
        has_nan_value_cnt += 1
if has_nan_value_cnt > 0:
    exp.logger.warning(f"{has_nan_value_cnt} parameters get nan-gradient -- this step will be skipped.")
    for p in filter(lambda p: p.grad is not None, self.parameters()):
        p.grad.data.zero_()


    def train_dataloader(self):
        # Note:
        import xcube.data as dataset
        train_set = dataset.build_dataset(
            self.hparams.train_dataset, self.get_dataset_spec(), self.hparams, self.hparams.train_kwargs)
        torch.manual_seed(0)
        return DataLoader(train_set, batch_size=self.hparams.batch_size // self.trainer.world_size, shuffle=True,
                          num_workers=self.hparams.train_val_num_workers, collate_fn=self.get_collate_fn())

    def val_dataloader(self):
        import xcube.data as dataset
        val_set = dataset.build_dataset(
            self.hparams.val_dataset, self.get_dataset_spec(), self.hparams, self.hparams.val_kwargs)
        return DataLoader(val_set, batch_size=self.hparams.batch_size // self.trainer.world_size, shuffle=False,
                          num_workers=self.hparams.train_val_num_workers, collate_fn=self.get_collate_fn())

    def test_dataloader(self):
        import xcube.data as dataset
        self.hparams.test_kwargs.resolution = self.hparams.resolution # ! use for testing when training on X^3 but testing on Y^3

        test_set = dataset.build_dataset(
            self.hparams.test_dataset, self.get_dataset_spec(), self.hparams, self.hparams.test_kwargs)
        if self.hparams.test_set_shuffle:
            torch.manual_seed(0)
        return DataLoader(test_set, batch_size=1, shuffle=self.hparams.test_set_shuffle, 
                          num_workers=0, collate_fn=self.get_collate_fn())
