In [1]:
import functools
import gc
import importlib
import inspect
import multiprocessing
import pickle
import shutil
import traceback
from collections import OrderedDict, defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Union
from utils.Dataspec import DatasetSpec


# Third-party libraries - NumPy & Scientific
import numpy as np
from numpy.random import RandomState

# Third-party libraries - PyTorch
import torch
print (torch.__version__, "torch")
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
# from torch.utils.tensorboard.summary import hparams

# Third-party libraries - Visualization
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Third-party libraries - ML Tools
from omegaconf import DictConfig, ListConfig, OmegaConf
import omegaconf.errors

# Local imports
# from ext import common
import fvdb
import fvdb.nn as fvnn
from fvdb import JaggedTensor, GridBatch
from fvdb.nn import VDBTensor


from modules.autoencoding.sunet import StructPredictionNet
import collections


2.4.1 torch


In [2]:
# Hyperparameters
# Module-level configuration read by the later cells of this notebook.

batch_size = 4
# tree_depth is the number of levels in the voxel hierarchy, counting the
# finest one. Every level doubles the voxel size of the previous level, so
# with a 512^3 finest grid and tree_depth = 3 the levels are 512^3, 256^3 and
# 128^3: the input is downsampled twice and the coarsest level is
# 128x128x128.
tree_depth = 3 # according to 512x512x512 -> 128x128x128
voxel_size = 0.0025
resolution = 512
use_fvdb_loader = True
use_hash_tree = True # use hash tree means use early dilation (description in Sec 3.4)


with_semantic_branch = False
# NOTE(review): the "store_true" values below look like argparse action names
# carried over from a command-line script. As values they are non-empty
# strings and therefore truthy - confirm how (and whether) they are consumed.
extract_mesh = "store_true"

solver_order ="3"

use_dpm = "store_true"

ddim_step = 50

use_ddim = "store_true"

ema = "store_true"

batch_len = 64

# NOTE(review): the name looks like a typo of "total_len"; it is kept as-is
# because other cells may reference it.
toal_len = 700

seed = 0

world_size = 1

# setup input
use_input_normal = True
use_input_semantic = False
use_input_intensity = False
use_input_color = False

# setup KL loss
# NOTE(review): cut_ratio is not passed to the model built later in this
# notebook, so the wrapper falls back to its own default of 1.0 - confirm
# whether that is intended.
cut_ratio = 16 # reduce the dimension of the latent space
kl_weight = 1.0 # used when annealing is off
normalize_kld = True
enable_anneal = False
kl_weight_min = 1e-7
kl_weight_max = 1.0
# NOTE(review): "anneal_star_iter" is presumably meant to be "anneal_start_iter".
anneal_star_iter = 0
anneal_end_iter = 70000 # needs to be adjusted for each dataset


structure_weight = 20.0
normal_weight = 300.0


# Learning-rate schedule consumed by lambda_lr_wrapper: the rate starts at
# "init", is multiplied by "decay_mult" once per "decay_step" processed
# samples, and is never allowed to drop below "clip".
learning_rate = {
  "init": 1.0e-4,
  "decay_mult": 0.7,
  "decay_step": 50000,
  "clip": 1.0e-6
}
weight_decay = 0.0
grad_clip = 0.5

c_dim = 32

# unet parameters
in_channels = 32
num_blocks = tree_depth
f_maps = 32
neck_dense_type = "UNCHANGED"
neck_bound = [64, 64, 64] # reportedly unused; listed here for reference
num_res_blocks = 1
use_residual = False
order = "gcr"
is_add_dec = False
use_attention = False
use_checkpoint = False


# NOTE(review): the two paths below are machine-specific absolute paths;
# consider deriving them from a single configurable data directory.
_custom_name =  "objaverse"
_objaverse_path = "/home/benzshawelt/Kedziora/kedziora_research/layer_x_layer/data_gen/voxels/512"
_split_path = "/home/benzshawelt/Kedziora/kedziora_research/layer_x_layer/data_gen/voxels/"
_text_emb_path = ""
_null_embed_path = "./assets/null_text_emb.pkl"
max_text_len= 77
text_embed_drop_prob= 0.1

# Dataset settings per split.
# NOTE(review): the *_dataloader functions defined later read only a few
# entries of these dicts (onet_base_path, split and, for train/val,
# resolution); the remaining values are hard-coded there and differ from the
# ones listed here (e.g. text_emb_path, text_embed_drop_prob, random_seed).
train_dataset = "ObjaverseDataset"
train_val_num_workers= 16
train_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "train",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "text_embed_drop_prob": text_embed_drop_prob, # ! classifier-free training
  "random_seed": 0
}

# NOTE(review): validation uses the "test" split here, the same split as
# test_kwargs below, although a val.lst file is also generated - confirm.
val_dataset = "ObjaverseDataset"
val_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "test",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "random_seed": "fixed"
}

test_dataset = "ObjaverseDataset"
test_num_workers =8
test_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "test",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "random_seed": "fixed"
}

In [3]:
def reparametrize(mu, logvar):
    """Sample from a diagonal Gaussian with the reparameterization trick.

    Args:
        mu: Tensor holding the means.
        logvar: Tensor holding the log-variances; the standard deviation is
            taken as ``exp(logvar / 2)``.

    Returns:
        ``mu + std * eps`` where ``eps`` is standard-normal noise with the
        shape, dtype and device of ``std``. The noise itself carries no
        gradient, so gradients flow only through ``mu`` and ``logvar``.

    NOTE(review): ``UnetWrapper._encode`` passes a value it names
    ``log_sigma``; this function halves its argument, i.e. treats it as a
    log-variance - confirm which quantity the network actually emits.
    """
    std = (logvar / 2).exp()
    # randn_like replaces the deprecated Variable(std.data.new(...).normal_()).
    eps = torch.randn_like(std)
    return mu + std * eps

In [4]:
# Collect the paths of all voxelized shapes stored as .pkl files anywhere below
# the base path (the whole directory tree is searched).
def get_all_paths(base_path):
  """Return the extension-less paths of every ``.pkl`` file under ``base_path``.

  Args:
    base_path: Directory that is walked recursively.

  Returns:
    A sorted list of strings, one per ``.pkl`` file, each being the file's
    path with only the trailing ``.pkl`` removed. Sorting makes the result
    independent of the order in which the filesystem lists entries, so a
    seeded split of this list is reproducible.
  """
  import os

  suffix = ".pkl"
  all_paths = []
  for root, _dirs, files in os.walk(base_path):
    for file in files:
      if file.endswith(suffix):
        # Strip only the extension: splitting on the first "." would
        # truncate names that contain further dots (e.g. "shape.v2.pkl").
        all_paths.append(os.path.join(root, file[:-len(suffix)]))

  all_paths.sort()
  return all_paths

In [5]:
def split_dataset(dataset, train_split_ratio, test_split_ratio, seed=0):
    """Randomly partition ``dataset`` into three ``Subset`` objects.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        train_split_ratio: Fraction of the items placed in the first subset.
        test_split_ratio: Fraction of the items placed in the second subset.
        seed: Seed of the shuffle.

    Returns:
        A 3-tuple of ``torch.utils.data.Subset``: the first holds
        ``int(len * train_split_ratio)`` items, the second holds
        ``int(len * test_split_ratio)`` items, and the third holds whatever
        remains.
    """
    # A local legacy generator yields the same permutation as
    # np.random.seed(seed) followed by np.random.shuffle, but leaves the
    # global NumPy random state untouched.
    rng = np.random.RandomState(seed)
    indices = np.arange(len(dataset))
    rng.shuffle(indices)

    first_cut = int(len(dataset) * train_split_ratio)
    second_cut = first_cut + int(len(dataset) * test_split_ratio)

    return (
        torch.utils.data.Subset(dataset, indices[:first_cut]),
        torch.utils.data.Subset(dataset, indices[first_cut:second_cut]),
        torch.utils.data.Subset(dataset, indices[second_cut:]),
    )
    

In [6]:
def generate_split_paths(base_path, split_ratio, seed=0):
    """Write train/val/test list files for the shapes found under ``base_path``.

    The shape paths returned by ``get_all_paths`` are partitioned with
    ``split_dataset(paths, split_ratio, 0.1, seed)``: the first part goes to
    ``train.lst``, the second (sized by the fixed 0.1 ratio) to ``val.lst``
    and the remainder to ``test.lst``. The three files are created in the
    parent directory of ``base_path``, one path per line; existing files of
    the same name are removed first.

    Args:
        base_path: Directory searched for ``.pkl`` shapes.
        split_ratio: Fraction of the shapes assigned to the training list.
        seed: Seed forwarded to ``split_dataset``.

    Returns:
        Always ``True``.
    """
    train_paths, val_paths, test_paths = split_dataset(
        get_all_paths(base_path), split_ratio, 0.1, seed)

    # The list files live next to the data directory, not inside it.
    list_dir = Path(base_path).parent
    train_file = list_dir / "train.lst"
    val_file = list_dir / "val.lst"
    test_file = list_dir / "test.lst"

    # Drop lists left over from a previous run.
    for stale in (train_file, val_file, test_file):
        if stale.exists():
            stale.unlink()

    # One path per line.
    pending = ((test_file, test_paths), (train_file, train_paths), (val_file, val_paths))
    for list_file, entries in pending:
        with open(list_file, "w") as handle:
            for entry in entries:
                handle.write(f"{entry}\n")
    return True

In [7]:
# Regenerate train.lst / val.lst / test.lst in the parent directory of the
# voxel folder (80 % train, shuffle seed 42). Existing list files are
# overwritten each time this cell runs.
generate_split_paths(_objaverse_path , 0.8, 42)

True

In [1]:
from modules.autoencoding.base_encoder import Encoder

class UnetWrapper(nn.Module):
    """Bundles the input encoder with the structure-prediction U-Net.

    The wrapper builds the multi-resolution grid hierarchy ("hash tree") from
    the input grid, runs ``Encoder`` on the finest grid and feeds the result
    to the wrapped U-Net. It also knows how to serialise itself, together
    with the settings needed to rebuild it, into a single checkpoint file.

    ``hparams`` is a mapping; the keys read by this class are listed in
    ``_HPARAM_KEYS`` plus the optional ``"cut_ratio"``.
    """

    # U-Net constructor arguments that save_checkpoint reads back from
    # attributes of the same name on the U-Net instance.
    _UNET_ATTR_KEYS = (
        'in_channels', 'num_blocks', 'f_maps', 'neck_dense_type', 'neck_bound',
        'num_res_blocks', 'use_residual', 'order', 'is_add_dec',
        'use_attention', 'use_checkpoint',
    )

    # Wrapper settings that are stored in (and restored from) a checkpoint.
    _HPARAM_KEYS = (
        "tree_depth", "voxel_size", "use_hash_tree", "use_input_normal",
        "use_input_semantic", "use_input_color", "use_input_intensity",
        "c_dim", "with_normal_branch", "with_semantic_branch",
        "with_color_branch",
    )

    def __init__(self, unet, hparams):
        """Store the U-Net and settings and build the encoder from ``hparams``.

        Note that ``hparams`` is kept by reference and a missing
        ``"cut_ratio"`` entry is added to it in place (default ``1.0``).
        """
        super().__init__()
        self.encoder = Encoder(hparams)
        self.unet = unet
        self.hparams = hparams
        # Ensure cut_ratio has a default value if not provided
        if "cut_ratio" not in self.hparams:
            self.hparams["cut_ratio"] = 1.0  # Default value, adjust as needed

    def build_hash_tree_from_grid(self, input_grid):
        """Build the per-depth grid hierarchy for ``input_grid``.

        Depth 0 is ``input_grid`` itself. When ``hparams["use_hash_tree"]`` is
        set, each further depth ``d < tree_depth`` is a grid created from the
        world-space voxel positions of the input grid, with voxel size
        ``voxel_size * 2 ** d`` and origin at half of that voxel size.

        Returns:
            dict mapping depth (int) to grid. Only depth 0 is present when
            ``use_hash_tree`` is false.
        """
        hash_tree = {}
        input_xyz = input_grid.grid_to_world(input_grid.ijk.float())

        for depth in range(self.hparams["tree_depth"]):
            if depth == 0:
                hash_tree[depth] = input_grid
                continue
            if not self.hparams["use_hash_tree"]:
                break
            voxel_size = [sv * 2 ** depth for sv in self.hparams["voxel_size"]]
            origins = [sv / 2. for sv in voxel_size]
            hash_tree[depth] = fvdb.gridbatch_from_nearest_voxels_to_points(
                input_xyz, voxel_sizes=voxel_size, origins=origins)
        return hash_tree

    def _prepare_unet_input(self, batch):
        """Shared front end of ``forward`` and ``_encode``.

        Builds the grid hierarchy from ``batch[DatasetSpec.INPUT_PC]``, stores
        the finest grid in ``batch['input_grid']``, runs the encoder and wraps
        its output as a ``VDBTensor`` on the finest grid.

        Returns:
            ``(input_grid, hash_tree, unet_feat)`` where ``hash_tree`` is
            ``None`` when ``hparams["use_hash_tree"]`` is false.
        """
        # Despite the dataset key name, the value is used as a grid here.
        hash_tree = self.build_hash_tree_from_grid(batch[DatasetSpec.INPUT_PC])
        input_grid = hash_tree[0]
        batch.update({'input_grid': input_grid})

        if not self.hparams["use_hash_tree"]:
            hash_tree = None

        unet_feat = self.encoder(input_grid, batch)
        unet_feat = fvnn.VDBTensor(input_grid, input_grid.jagged_like(unet_feat))
        return input_grid, hash_tree, unet_feat

    def forward(self, batch, out: dict):
        """Run encoder and U-Net on ``batch`` and record the results in ``out``.

        ``out`` is updated in place and returned. It receives the predicted
        structure grid (``'tree'``), the structure and distribution features,
        the ground-truth grid and hierarchy, and - depending on the
        ``with_*_branch`` settings - normal, semantic and color features.
        ``batch`` gains an ``'input_grid'`` entry as a side effect.
        """
        input_grid, hash_tree, unet_feat = self._prepare_unet_input(batch)
        unet_res, _unet_output, dist_features = self.unet(unet_feat, hash_tree)

        out.update({'tree': unet_res.structure_grid})
        out.update({
            'structure_features': unet_res.structure_features,
            'dist_features': dist_features,
        })
        out.update({'gt_grid': input_grid})
        out.update({'gt_tree': hash_tree})

        if self.hparams["with_normal_branch"]:
            out.update({
                'normal_features': unet_res.normal_features,
            })
        if self.hparams["with_semantic_branch"]:
            out.update({
                'semantic_features': unet_res.semantic_features,
            })
        if self.hparams["with_color_branch"]:
            out.update({
                'color_features': unet_res.color_features,
            })
        return out

    def get_dataset_spec(self):
        """List the ``DatasetSpec`` fields the dataset must provide.

        The base set is always requested; normal, semantic and intensity
        fields are appended according to the corresponding settings.
        """
        DS = DatasetSpec
        all_specs = [DS.SHAPE_NAME, DS.INPUT_PC,
                        DS.GT_DENSE_PC, DS.GT_GEOMETRY]
        if self.hparams.get("use_input_normal", True):
            all_specs.append(DS.TARGET_NORMAL)
            all_specs.append(DS.GT_DENSE_NORMAL)
        if self.hparams.get("use_input_semantic", False) or self.hparams.get("with_semantic_branch", False):
            all_specs.append(DS.GT_SEMANTIC)
        if self.hparams.get("use_input_intensity", False):
            all_specs.append(DS.INPUT_INTENSITY)
        return all_specs

    @torch.no_grad()
    def _encode(self, batch, use_mode=False):
        """Encode ``batch`` into a latent ``VDBTensor`` without tracking gradients.

        Args:
            batch: Mapping containing ``DatasetSpec.INPUT_PC``; it gains an
                ``'input_grid'`` entry as a side effect.
            use_mode: When true the mean returned by the U-Net encoder is used
                directly; otherwise a sample is drawn with ``reparametrize``.

        Returns:
            A ``VDBTensor`` on the grid of the encoder's output.
        """
        _, hash_tree, unet_feat = self._prepare_unet_input(batch)
        _, x, mu, log_sigma = self.unet.encode(unet_feat, hash_tree=hash_tree)
        sparse_feature = mu if use_mode else reparametrize(mu, log_sigma)

        return fvnn.VDBTensor(x.grid, x.grid.jagged_like(sparse_feature))

    def get_config(self):
        """Return the wrapper settings that are written to a checkpoint.

        Every key in ``_HPARAM_KEYS`` must be present in ``hparams``;
        ``"cut_ratio"`` falls back to ``1.0``.
        """
        config = {key: self.hparams[key] for key in self._HPARAM_KEYS}
        config["cut_ratio"] = self.hparams.get("cut_ratio", 1.0)  # Added cut_ratio with default
        return config

    @staticmethod
    def load_from_checkpoint(checkpoint_path, industry_mapping=None, map_location=None):
        """Rebuild the wrapper and its U-Net from a file written by ``save_checkpoint``.

        Args:
            checkpoint_path: Path of the checkpoint file.
            industry_mapping: Unused. Accepted (now optionally) so that both
                the one-argument and the two-argument call forms keep working.
            map_location: Forwarded to ``torch.load``; ``None`` keeps the
                devices recorded in the file. Pass e.g. ``"cpu"`` to load a
                checkpoint saved on a GPU that is not available.

        Returns:
            ``(unet_wrapper, u_net)``.

        If the file has no ``'model_state_dict'`` entry, the loaded object
        itself is used as the state dict.
        """
        # torch.load unpickles the file: only open checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        u_net = StructPredictionNet(
            **{key: checkpoint.get(key) for key in UnetWrapper._UNET_ATTR_KEYS},
            c_dim=checkpoint.get('c_dim'),
        )

        hparams = {key: checkpoint.get(key) for key in UnetWrapper._HPARAM_KEYS}
        hparams["cut_ratio"] = checkpoint.get('cut_ratio', 1.0)  # Added cut_ratio with default
        unet_wrapper = UnetWrapper(u_net, hparams)

        # Load the state dict
        unet_wrapper.load_state_dict(checkpoint.get('model_state_dict', checkpoint))

        return unet_wrapper, u_net

    def save_checkpoint(self, filepath):
        """Save the model weights and configuration to a checkpoint file.

        The file holds one flat dict: the entries of ``get_config()``, the
        U-Net attributes named in ``_UNET_ATTR_KEYS`` and the wrapper's
        ``state_dict`` under ``'model_state_dict'``.
        """
        checkpoint = {
            # Configuration parameters
            **self.get_config(),
            # U-Net specific parameters that aren't in get_config
            **{key: getattr(self.unet, key) for key in self._UNET_ATTR_KEYS},
            # Model weights
            'model_state_dict': self.state_dict(),
        }

        torch.save(checkpoint, filepath)

NameError: name 'nn' is not defined

In [19]:
def lambda_lr_wrapper(it, lr_config, batch_size, accumulate_grad_batches=1):
    """Learning-rate multiplier for a step-wise exponential decay with a floor.

    Args:
        it: Scheduler step index.
        lr_config: Mapping with the keys ``'init'``, ``'decay_mult'``,
            ``'decay_step'`` and ``'clip'``.
        batch_size: Samples per step.
        accumulate_grad_batches: Extra factor applied to the per-step sample
            count.

    Returns:
        ``decay_mult`` raised to the number of completed ``decay_step``-sized
        blocks of processed samples, but never less than ``clip / init``.
    """
    samples_seen = it * batch_size * accumulate_grad_batches
    completed_decays = int(samples_seen / lr_config['decay_step'])
    decayed = lr_config['decay_mult'] ** completed_decays
    # The floor is expressed relative to the initial rate because the result
    # is used as a multiplier, not as an absolute learning rate.
    floor = lr_config['clip'] / lr_config['init']
    return max(decayed, floor)


In [20]:
# define the model
# The U-Net is configured entirely from the hyperparameter cell.
u_net = StructPredictionNet(
  in_channels=in_channels,
  num_blocks=num_blocks,
  f_maps=f_maps,
  neck_dense_type=neck_dense_type,
  neck_bound=neck_bound,
  num_res_blocks=num_res_blocks,
  use_residual=use_residual,
  order=order,
  is_add_dec=is_add_dec,
  use_attention=use_attention,
  use_checkpoint=use_checkpoint,
  c_dim=c_dim
)

# Wrap the U-Net together with the encoder settings. The scalar voxel_size is
# expanded to one value per axis.
# NOTE(review): "cut_ratio" is not passed here, so UnetWrapper stores its own
# default of 1.0 instead of the cut_ratio = 16 from the hyperparameter cell -
# confirm whether that is intended.
unet_wrapper = UnetWrapper(u_net, {
    "tree_depth": tree_depth,
    "voxel_size": [voxel_size, voxel_size, voxel_size],
    "use_hash_tree": use_hash_tree,
    "use_input_normal": use_input_normal,
    "use_input_semantic": use_input_semantic,
    "use_input_color": use_input_color,
    "use_input_intensity": use_input_intensity,
    "c_dim": c_dim,
    "with_normal_branch": True,
    "with_semantic_branch": with_semantic_branch,
    "with_color_branch": False,
})

In [None]:
# AdamW (AMSGrad variant) over every parameter of the wrapper, i.e. encoder
# and U-Net together, starting at the configured initial learning rate.
optimizer = torch.optim.AdamW(unet_wrapper.parameters(), lr=learning_rate["init"],
                                    weight_decay=weight_decay, amsgrad=True)

# LambdaLR scales the initial rate by the factor lambda_lr_wrapper returns for
# the current scheduler step; lr_config and batch_size are bound up front and
# accumulate_grad_batches keeps its default of 1.
scheduler = LambdaLR(optimizer,
                lr_lambda=functools.partial(
                    lambda_lr_wrapper, lr_config=learning_rate, batch_size=batch_size))

# exp.global_var_manager.register_variable('skip_backward', False)

def list_collate(batch):
    """Collate samples without stacking them along a new batch dimension.

    The first sample that is not ``None`` decides how the batch is handled:

    * tensors, strings and OmegaConf containers: the batch is returned as is;
    * NumPy arrays / memmaps: each one is converted to a tensor (``None``
      entries are kept) and the resulting list is returned;
    * zero-dimensional NumPy values, floats and ints: one tensor holding all
      samples (floats are stored as float64);
    * mappings: a dict with the same keys, each value collated recursively;
    * other sequences: transposed, then each position collated recursively;
      all samples must have the same length;
    * ``GridBatch`` objects: concatenated with ``fvdb.jcat``;
    * anything else: the batch is returned unchanged.

    Raises:
        RuntimeError: if the samples are sequences of differing length.
    """
    # The first usable sample acts as the type probe for the whole batch.
    probe = next((sample for sample in batch if sample is not None), None)
    probe_cls = type(probe)

    if isinstance(probe, torch.Tensor):
        return batch

    if probe_cls.__module__ == 'numpy' and probe_cls.__name__ not in ('str_', 'string_'):
        if probe_cls.__name__ in ('ndarray', 'memmap'):
            as_tensors = [None if sample is None else torch.as_tensor(sample)
                          for sample in batch]
            return list_collate(as_tensors)
        if probe.shape == ():  # scalars
            return torch.as_tensor(batch)
        return batch

    if isinstance(probe, float):
        return torch.tensor(batch, dtype=torch.float64)
    if isinstance(probe, int):
        return torch.tensor(batch)
    if isinstance(probe, str):
        return batch
    if isinstance(probe, (DictConfig, ListConfig)):
        return batch

    if isinstance(probe, collections.abc.Mapping):
        collated = {}
        for key in probe:
            collated[key] = list_collate([sample[key] for sample in batch])
        return collated

    if isinstance(probe, collections.abc.Sequence):
        # Every sample must contribute the same number of fields.
        remaining = iter(batch)
        expected_len = len(next(remaining))
        for sample in remaining:
            if len(sample) != expected_len:
                raise RuntimeError('each element in list of batch should be of equal size')
        return [list_collate(column) for column in zip(*batch)]

    if isinstance(probe, GridBatch):
        return fvdb.jcat(batch)

    return batch




def train_dataloader():
    """Build the shuffled training ``DataLoader``.

    The dataset path, split name and resolution come from ``train_kwargs``;
    the requested fields come from ``unet_wrapper.get_dataset_spec()``; the
    remaining dataset options are fixed in this function. Each loader batch
    holds ``batch_size // world_size`` samples collated by ``list_collate``.
    """
    from data.objaverse import ObjaverseDataset

    dataset_options = dict(
        onet_base_path=train_kwargs["onet_base_path"],
        spec=unet_wrapper.get_dataset_spec(),
        split=train_kwargs["split"],
        resolution=train_kwargs["resolution"],
        image_base_path=None,
        random_seed=0,
        hparams=None,
        skip_on_error=False,
        custom_name="objaverse",
        text_emb_path="../data/objaverse/objaverse/text_emb",
        null_embed_path="./assets/null_text_emb.pkl",
        text_embed_drop_prob=0.0,
        max_text_len=77,
        duplicate_num=1,
        split_base_path=_split_path,
    )
    train_set = ObjaverseDataset(**dataset_options)

    per_process_batch = batch_size // world_size
    return DataLoader(
        train_set,
        batch_size=per_process_batch,
        shuffle=True,
        num_workers=train_val_num_workers,
        collate_fn=list_collate,
    )


# print(get_dataset_spec())

def val_dataloader():
    """Build the validation ``DataLoader`` (no shuffling).

    The dataset path, split name and resolution come from ``val_kwargs``; the
    requested fields come from ``unet_wrapper.get_dataset_spec()``; the
    remaining dataset options are fixed in this function. Each loader batch
    holds ``batch_size // world_size`` samples collated by ``list_collate``.
    """
    from data.objaverse import ObjaverseDataset

    dataset_options = dict(
        onet_base_path=val_kwargs["onet_base_path"],
        spec=unet_wrapper.get_dataset_spec(),
        split=val_kwargs["split"],
        resolution=val_kwargs["resolution"],
        image_base_path=None,
        random_seed=0,
        hparams=None,
        skip_on_error=False,
        custom_name="objaverse",
        text_emb_path="../data/objaverse/objaverse/text_emb",
        null_embed_path="./assets/null_text_emb.pkl",
        text_embed_drop_prob=0.0,
        max_text_len=77,
        duplicate_num=1,
        split_base_path=_split_path,
    )
    val_set = ObjaverseDataset(**dataset_options)

    per_process_batch = batch_size // world_size
    return DataLoader(
        val_set,
        batch_size=per_process_batch,
        shuffle=False,
        num_workers=train_val_num_workers,
        collate_fn=list_collate,
    )

def test_dataloader(resolution=resolution, test_set_shuffle=False):
    """Build the test ``DataLoader`` (batch size 1, loaded in the main process).

    Args:
        resolution: Resolution handed to the dataset. It defaults to the
            module-level ``resolution`` at definition time and can be
            overridden to test at a resolution other than the training one.
        test_set_shuffle: When true, the torch seed is reset to 0 and the
            loader shuffles the samples.

    The dataset path and split name come from ``test_kwargs``; the remaining
    dataset options are fixed in this function.
    """
    from data.objaverse import ObjaverseDataset

    dataset_options = dict(
        onet_base_path=test_kwargs["onet_base_path"],
        spec=unet_wrapper.get_dataset_spec(),
        split=test_kwargs["split"],
        resolution=resolution,  # ! use for testing when training on X^3 but testing on Y^3
        image_base_path=None,
        random_seed=0,
        hparams=None,
        skip_on_error=False,
        custom_name="objaverse",
        text_emb_path="../data/objaverse/objaverse/text_emb",
        null_embed_path="./assets/null_text_emb.pkl",
        text_embed_drop_prob=0.0,
        max_text_len=77,
        duplicate_num=1,
        split_base_path=_split_path,
    )
    test_set = ObjaverseDataset(**dataset_options)

    # Fix the torch seed before the loader is created when shuffling is requested.
    if test_set_shuffle:
        torch.manual_seed(0)
    return DataLoader(
        test_set,
        batch_size=1,
        shuffle=test_set_shuffle,
        num_workers=0,
        collate_fn=list_collate,
    )


In [22]:
# Instantiate the three loaders once. The trailing underscore keeps the loader
# objects distinct from the factory functions of the same name.
train_dataloader_ = train_dataloader()
val_dataloader_ = val_dataloader()
test_dataloader_ = test_dataloader()



In [23]:
import utils.exp as exp

In [24]:
# from utils.color_util import color_from_points, semantic_from_points
from utils.loss_util import TorchLossMeter
from utils.Dataspec import DatasetSpec as DS
    
class Loss(nn.Module):
    """Loss terms for the sparse-voxel autoencoder, driven by an ``hparams`` mapping.

    ``forward`` assembles, depending on the flags in ``hparams``:

    * a per-depth structure cross-entropy (``supervision.structure_weight``),
    * an L1 normal loss (``with_normal_branch``),
    * a semantic cross-entropy (``with_semantic_branch``),
    * an L1 color loss (``with_color_branch``),
    * a KL term when ``out`` contains ``'dist_features'``.

    Keys read from ``hparams``: ``voxel_size``, ``use_hash_tree``,
    ``use_fvdb_loader``, ``remain_h``, ``kl_weight``, ``supervision``
    (``structure_weight``, ``normal_weight``, ``semantic_weight``,
    ``color_weight``), ``with_normal_branch``, ``with_semantic_branch``,
    ``with_color_branch``, ``normalize_kld``, ``enable_anneal``,
    ``kl_weight_min``, ``kl_weight_max``, ``anneal_star_iter`` and
    ``anneal_end_iter``.
    """

    def __init__(self, hparams):
        """Store ``hparams`` and fill in defaults for optional keys.

        NOTE: the defaults are written into the mapping that was passed in, so
        the caller's object is modified as well.
        """
        super().__init__()
        self.hparams = hparams
        # Set default for use_fvdb_loader if not provided
        if "use_fvdb_loader" not in self.hparams:
            self.hparams["use_fvdb_loader"] = True
        # Set default for remain_h if not provided
        if "remain_h" not in self.hparams:
            self.hparams["remain_h"] = False
        # Set default kl_weight if not provided
        if "kl_weight" not in self.hparams:
            self.hparams["kl_weight"] = 1.0

    def transform_field(self, field: torch.Tensor):
        """Soft-truncate ``field`` with ``tanh(field / t) * t``.

        ``t`` is ``gt_band * voxel_size`` with ``gt_band`` fixed to 1.0. When
        ``hparams["voxel_size"]`` is a list or tuple, its first element is used.
        """
        gt_band = 1.0  # not sure if this will be changed
        voxel_size = self.hparams["voxel_size"]
        # Per-axis voxel sizes: use the first entry as the scalar size.
        if isinstance(voxel_size, (list, tuple)):
            voxel_size = voxel_size[0]
        truncation_size = gt_band * voxel_size
        # non-binary supervision (made sure derivative norm at 0 if 1)
        field = torch.tanh(field / truncation_size) * truncation_size
        return field

    def cross_entropy(self, pd_struct: fvnn.VDBTensor, gt_grid: fvdb.GridBatch, dynamic_grid: fvdb.GridBatch = None):
        """Cross-entropy between predicted structure logits and ground-truth occupancy.

        The target of a predicted voxel is 1 where ``gt_grid.ijk_to_index``
        yields -1 for it (presumably: the voxel is absent from ``gt_grid``)
        and 0 otherwise. If ``dynamic_grid`` is given, each per-voxel loss is
        multiplied by a mask built the same way from ``dynamic_grid`` before
        averaging over all predicted voxels.

        Returns:
            The scalar loss, or the float ``0.0`` when there are no predicted voxels.
        """
        assert torch.allclose(pd_struct.grid.origins, gt_grid.origins)
        assert torch.allclose(pd_struct.grid.voxel_sizes, gt_grid.voxel_sizes)
        idx_mask = gt_grid.ijk_to_index(pd_struct.grid.ijk).jdata == -1
        idx_mask = idx_mask.long()
        # Bail out before touching F.cross_entropy: a mean over zero voxels is
        # NaN, and the result would be discarded in favour of 0.0 anyway.
        if idx_mask.size(0) == 0:
            return 0.0
        if dynamic_grid is not None:
            dynamic_mask = dynamic_grid.ijk_to_index(pd_struct.grid.ijk).jdata == -1
            loss = F.cross_entropy(pd_struct.jdata, idx_mask, reduction='none') * dynamic_mask.float()
            loss = loss.mean()
        else:
            loss = F.cross_entropy(pd_struct.jdata, idx_mask)
        return loss

    def struct_acc(self, pd_struct: fvnn.VDBTensor, gt_grid: fvdb.GridBatch):
        """Fraction of predicted voxels whose argmax class equals the target.

        The target is built exactly as in ``cross_entropy`` (1 where the
        ground-truth index lookup returns -1, else 0).
        """
        assert torch.allclose(pd_struct.grid.origins, gt_grid.origins)
        assert torch.allclose(pd_struct.grid.voxel_sizes, gt_grid.voxel_sizes)
        idx_mask = gt_grid.ijk_to_index(pd_struct.grid.ijk).jdata == -1
        idx_mask = idx_mask.long()
        return torch.mean((pd_struct.jdata.argmax(dim=1) == idx_mask).float())

    def grid_iou(self, gt_grid: fvdb.GridBatch, pd_grid: fvdb.GridBatch):
        """Mean intersection-over-union of two grid batches.

        For each grid, the intersection is the number of ground-truth voxels
        whose lookup in ``pd_grid`` returns an index >= 0, and the union is
        ``|pd| + |gt| - intersection`` (plus 1e-6 to avoid division by zero).
        """
        assert gt_grid.grid_count == pd_grid.grid_count
        idx = pd_grid.ijk_to_index(gt_grid.ijk)
        # Per-grid sum of voxel counts (union + intersection).
        upi = (pd_grid.num_voxels + gt_grid.num_voxels).cpu().numpy().tolist()
        ious = []
        for i in range(len(upi)):
            inter = torch.sum(idx[i].jdata >= 0).item()
            ious.append(inter / (upi[i] - inter + 1.0e-6))
        return np.mean(ious)

    def normal_loss(self, batch, normal_feats: fvnn.VDBTensor, eps=1e-6):
        """L1 loss between predicted normals and target normals splatted onto the grid.

        Reference positions are the world-space voxel centres of
        ``batch['input_grid']`` when ``use_fvdb_loader`` is set, otherwise
        ``batch[DS.INPUT_PC]``. ``batch[DS.TARGET_NORMAL]`` is splatted onto
        ``normal_feats.grid`` with ``splat_trilinear`` and divided by its
        per-row norm (plus ``eps``) before comparison.
        """
        if self.hparams["use_fvdb_loader"]:
            ref_grid = batch['input_grid']
            ref_xyz = ref_grid.grid_to_world(ref_grid.ijk.float())
        else:
            ref_xyz = fvdb.JaggedTensor(batch[DS.INPUT_PC])

        gt_normal = normal_feats.grid.splat_trilinear(ref_xyz, fvdb.JaggedTensor(batch[DS.TARGET_NORMAL]))
        # normalize normal
        gt_normal.jdata /= (gt_normal.jdata.norm(dim=1, keepdim=True) + eps)
        normal_loss = F.l1_loss(gt_normal.jdata, normal_feats.jdata)
        return normal_loss

    def color_loss(self, batch, color_feats: fvnn.VDBTensor):
        """L1 loss between predicted voxel colors and colors looked up from the input points.

        Requires ``use_fvdb_loader``. Returns the float ``0.0`` when
        ``color_feats`` has no voxels or when the batch contains no grids.
        """
        assert self.hparams["use_fvdb_loader"] is True
        # check if color_feats is empty
        if color_feats.grid.total_voxels == 0:
            return 0.0
        # Imported lazily: the notebook-level import of utils.color_util is
        # commented out, so without this the helper is an undefined name.
        from utils.color_util import color_from_points

        ref_grid = batch['input_grid']
        ref_xyz = ref_grid.grid_to_world(ref_grid.ijk.float())
        ref_color = fvdb.JaggedTensor(batch[DS.INPUT_COLOR])

        target_xyz = color_feats.grid.grid_to_world(color_feats.grid.ijk.float())
        target_color = []
        slect_color_feats = []
        for batch_idx in range(ref_grid.grid_count):
            ref_color_i = ref_color[batch_idx].jdata
            target_color.append(color_from_points(target_xyz[batch_idx].jdata, ref_xyz[batch_idx].jdata, ref_color_i, k=1))
            slect_color_feats.append(color_feats.feature[batch_idx].jdata)

        if len(target_color) == 0 or len(slect_color_feats) == 0:  # to avoid JaggedTensor build from empty list
            return 0.0

        target_color = fvdb.JaggedTensor(target_color)
        slect_color_feats = fvdb.JaggedTensor(slect_color_feats)
        color_loss = F.l1_loss(slect_color_feats.jdata, target_color.jdata)
        return color_loss

    def semantic_loss(self, batch, semantic_feats: fvnn.VDBTensor):
        """Cross-entropy between predicted semantic logits and labels looked up from the input points.

        Requires ``use_fvdb_loader``. Samples whose ``DS.GT_SEMANTIC`` entry is
        empty are skipped. Returns the float ``0.0`` whenever nothing is left
        to supervise.
        """
        assert self.hparams["use_fvdb_loader"] is True
        # check if semantic_feats is empty
        if semantic_feats.grid.total_voxels == 0:
            return 0.0
        ref_grid = batch['input_grid']
        ref_xyz = ref_grid.grid_to_world(ref_grid.ijk.float())
        ref_semantic = fvdb.JaggedTensor(batch[DS.GT_SEMANTIC])
        if ref_semantic.jdata.size(0) == 0:  # if all samples in this batch is without semantic
            return 0.0
        # Imported lazily: the notebook-level import of utils.color_util is
        # commented out, so without this the helper is an undefined name.
        from utils.color_util import semantic_from_points

        target_xyz = semantic_feats.grid.grid_to_world(semantic_feats.grid.ijk.float())
        target_semantic = []
        slect_semantic_feats = []
        for batch_idx in range(ref_grid.grid_count):
            ref_semantic_i = ref_semantic[batch_idx].jdata
            if ref_semantic_i.size(0) == 0:
                continue
            target_semantic.append(semantic_from_points(target_xyz[batch_idx].jdata, ref_xyz[batch_idx].jdata, ref_semantic_i))
            slect_semantic_feats.append(semantic_feats.feature[batch_idx].jdata)

        if len(target_semantic) == 0 or len(slect_semantic_feats) == 0:  # to avoid JaggedTensor build from empty list
            return 0.0

        target_semantic = fvdb.JaggedTensor(target_semantic)
        slect_semantic_feats = fvdb.JaggedTensor(slect_semantic_feats)

        if slect_semantic_feats.jdata.size(0) == 0:  # to avoid cross_entropy taking an empty tensor
            return 0.0

        semantic_loss = F.cross_entropy(slect_semantic_feats.jdata, target_semantic.jdata.long())
        return semantic_loss

    def get_kl_weight(self, global_step):
        """Linearly annealed KL weight for ``global_step``.

        Returns ``kl_weight_min`` up to and including ``anneal_star_iter``,
        ``kl_weight_max`` from ``anneal_end_iter`` on, and a linear
        interpolation between the two in between.
        """
        # linear annealing the kl weight
        if global_step > self.hparams["anneal_star_iter"]:
            if global_step < self.hparams["anneal_end_iter"]:
                kl_weight = self.hparams["kl_weight_min"] + \
                                         (self.hparams["kl_weight_max"] - self.hparams["kl_weight_min"]) * \
                                         (global_step - self.hparams["anneal_star_iter"]) / \
                                         (self.hparams["anneal_end_iter"] - self.hparams["anneal_star_iter"])
            else:
                kl_weight = self.hparams["kl_weight_max"]
        else:
            kl_weight = self.hparams["kl_weight_min"]

        return kl_weight

    def forward(self, batch, out, compute_metric: bool, global_step, current_epoch, optimizer_idx=0):
        """Compute all enabled loss terms for one batch.

        Args:
            batch: The data batch; read by the normal/semantic/color losses.
            out: Model output dict. Reads ``'structure_features'`` and either
                ``'gt_grid'`` or ``'gt_tree'`` (depending on ``use_hash_tree``),
                plus ``'normal_features'``, ``'semantic_features'``,
                ``'color_features'`` and ``'dist_features'`` when the
                corresponding term is active.
            compute_metric: When True, structure accuracy per depth is also recorded.
            global_step: Step index used for KL-weight annealing.
            current_epoch: Accepted for interface compatibility; not used here.
            optimizer_idx: Accepted for interface compatibility; not used here.

        Returns:
            ``(loss_dict, metric_dict, latent_dict)``, three ``TorchLossMeter``
            instances holding the weighted losses, the metrics, and the latent
            statistics respectively.
        """
        loss_dict = TorchLossMeter()
        metric_dict = TorchLossMeter()
        latent_dict = TorchLossMeter()

        # No dynamic grid is provided in this implementation, so every
        # ``dyn_grid_i`` below evaluates to None.
        dynamic_grid = None

        if not self.hparams["use_hash_tree"]:
            gt_grid = out['gt_grid']
            if self.hparams["supervision"]["structure_weight"] > 0.0:
                for feat_depth, pd_struct_i in out['structure_features'].items():
                    downsample_factor = 2 ** feat_depth
                    if self.hparams["remain_h"]:
                        # Anisotropic case: derive the third-axis factor from
                        # the ratio of the predicted grid's voxel sizes.
                        pd_voxel_size = pd_struct_i.grid.voxel_sizes[0]
                        h_factor = pd_voxel_size[0] // pd_voxel_size[2]
                        downsample_factor = [downsample_factor, downsample_factor, downsample_factor // h_factor]
                    if downsample_factor != 1:
                        gt_grid_i = gt_grid.coarsened_grid(downsample_factor)
                        dyn_grid_i = dynamic_grid.coarsened_grid(downsample_factor) if dynamic_grid is not None else None
                    else:
                        gt_grid_i = gt_grid
                        dyn_grid_i = dynamic_grid
                    loss_dict.add_loss(f"struct-{feat_depth}", self.cross_entropy(pd_struct_i, gt_grid_i, dyn_grid_i),
                                    self.hparams["supervision"]["structure_weight"])
                    if compute_metric:
                        with torch.no_grad():
                            metric_dict.add_loss(f"struct-acc-{feat_depth}", self.struct_acc(pd_struct_i, gt_grid_i))
        else:
            if self.hparams["supervision"]["structure_weight"] > 0.0:
                # Ground truth is supplied per depth; no coarsening is done here.
                gt_tree = out['gt_tree']
                for feat_depth, pd_struct_i in out['structure_features'].items():
                    gt_grid_i = gt_tree[feat_depth]
                    # get dynamic grid
                    dyn_grid_i = dynamic_grid.coarsened_grid(2 ** feat_depth) if dynamic_grid is not None else None
                    loss_dict.add_loss(f"struct-{feat_depth}", self.cross_entropy(pd_struct_i, gt_grid_i, dyn_grid_i),
                                    self.hparams["supervision"]["structure_weight"])
                    if compute_metric:
                        with torch.no_grad():
                            metric_dict.add_loss(f"struct-acc-{feat_depth}", self.struct_acc(pd_struct_i, gt_grid_i))

        # compute normal loss
        if self.hparams["with_normal_branch"]:
            if out['normal_features'] == {}:
                normal_loss = 0.0
            else:
                # Only the smallest depth key is supervised.
                feat_depth = min(out['normal_features'].keys())
                normal_loss = self.normal_loss(batch, out['normal_features'][feat_depth])

            loss_dict.add_loss("normal", normal_loss, self.hparams["supervision"]["normal_weight"])

        # compute semantic loss
        if self.hparams["with_semantic_branch"]:
            for feat_depth, pd_semantic_i in out['semantic_features'].items():
                semantic_loss = self.semantic_loss(batch, pd_semantic_i)
                if semantic_loss == 0.0:  # do not take empty into log
                    continue
                loss_dict.add_loss(f"semantic_{feat_depth}", semantic_loss, self.hparams["supervision"]["semantic_weight"])

        # compute color loss
        if self.hparams["with_color_branch"]:
            for feat_depth, pd_color_i in out['color_features'].items():
                color_loss = self.color_loss(batch, pd_color_i)
                if color_loss == 0.0:
                    continue
                loss_dict.add_loss(f"color_{feat_depth}", color_loss, self.hparams["supervision"]["color_weight"])

        # compute KL divergence
        if "dist_features" in out:
            dist_features = out['dist_features']
            kld = 0.0
            for latent_id, (mu, logvar) in enumerate(dist_features):
                num_voxel = mu.size(0)
                kld_temp = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                # Un-normalised value, logged separately below.
                kld_total = kld_temp.item()
                if self.hparams["normalize_kld"]:
                    # Guard against an empty latent: 0 / 0 would inject NaN
                    # into the total loss.
                    kld_temp = kld_temp / max(num_voxel, 1)

                kld += kld_temp
                latent_dict.add_loss(f"mu-{latent_id}", mu.mean())
                latent_dict.add_loss(f"logvar-{latent_id}", logvar.mean())
                latent_dict.add_loss(f"kld-true-{latent_id}", kld_temp.item())
                latent_dict.add_loss(f"kld-total-{latent_id}", kld_total)

            if self.hparams["enable_anneal"]:
                loss_dict.add_loss("kld", kld, self.get_kl_weight(global_step))
            else:
                loss_dict.add_loss("kld", kld, self.hparams["kl_weight"])

        return loss_dict, metric_dict, latent_dict

In [25]:
# Define the loss module from the notebook-level hyperparameters.
loss_module = Loss({
    "tree_depth": tree_depth,
    "voxel_size": [voxel_size, voxel_size, voxel_size],
    "use_hash_tree": use_hash_tree,
    # Forward the notebook's loader setting explicitly; otherwise Loss falls
    # back to its own default (True) and can silently disagree with it.
    "use_fvdb_loader": use_fvdb_loader,
    "use_input_normal": use_input_normal,
    "use_input_semantic": use_input_semantic,
    "use_input_color": use_input_color,
    "use_input_intensity": use_input_intensity,
    "c_dim": c_dim,
    "with_normal_branch": True,
    "with_semantic_branch": with_semantic_branch,
    "with_color_branch": False,
    "supervision": {
        "structure_weight": structure_weight,
        "normal_weight": normal_weight,
        "semantic_weight": 0.0,
        "color_weight": 0.0
    },
    "normalize_kld": normalize_kld,
    "enable_anneal": enable_anneal,
    "kl_weight_min": kl_weight_min,
    "kl_weight_max": kl_weight_max,
    "anneal_star_iter": anneal_star_iter,
    "anneal_end_iter": anneal_end_iter,
})

In [26]:
# Number of passes over the training loader made by the training loop below.
epochs = 100

In [27]:
# Lowest mean validation loss seen so far; a checkpoint is written whenever it improves.
best_val_loss = float('inf')

for epoch in range(epochs):
    # ---- Training ----
    unet_wrapper.train()
    for i, batch in enumerate(train_dataloader_):
        optimizer.zero_grad()

        # Start every step from an empty output dict so that no keys (or
        # tensors) from the previous batch leak into this one.
        out_dict = unet_wrapper(batch, {})

        loss_dict, metric_dict, latent_dict = loss_module(batch, out_dict, compute_metric=True, global_step=epoch * len(train_dataloader_) + i, current_epoch=epoch)

        loss = loss_dict.get_sum()

        loss.backward()

        # Count parameters whose gradient contains NaN.
        has_nan_value_cnt = 0
        for p in filter(lambda p: p.grad is not None, unet_wrapper.parameters()):
            if torch.isnan(p.grad).any():
                has_nan_value_cnt += 1
        if has_nan_value_cnt > 0:
            exp.logger.warning(f"{has_nan_value_cnt} parameters get nan-gradient -- this step will be skipped.")
            # NOTE(review): the step is neutralised by zeroing every gradient;
            # optimizer.step() below still runs, so an optimizer with internal
            # state (momentum, weight decay) may still change the weights.
            for p in filter(lambda p: p.grad is not None, unet_wrapper.parameters()):
                p.grad.data.zero_()

        # Gradient clipping (element-wise, to [-grad_clip, grad_clip]).
        torch.nn.utils.clip_grad_value_(unet_wrapper.parameters(), clip_value=grad_clip)

        optimizer.step()
        scheduler.step()

        if i % 100 == 0:
            print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(train_dataloader_)}], Loss: {loss.item()}")

    # ---- Validation ----
    unet_wrapper.eval()
    val_loss_sum = 0.0
    val_batch_count = 0
    with torch.no_grad():
        for i, batch in enumerate(val_dataloader_):
            out_dict = unet_wrapper(batch, {})

            loss_dict, metric_dict, latent_dict = loss_module(batch, out_dict, compute_metric=True, global_step=epoch * len(train_dataloader_) + i, current_epoch=epoch)

            loss = loss_dict.get_sum()
            val_loss_sum += loss.item()
            val_batch_count += 1

            if i % 10 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(val_dataloader_)}], Loss: {loss.item()}")

    if val_batch_count == 0:
        # Without validation batches there is nothing to compare; do not fall
        # back to the last training loss.
        exp.logger.warning("Validation loader yielded no batches -- skipping checkpoint selection for this epoch.")
        continue

    # Select the best model on the mean over the whole validation set rather
    # than on whichever batch happened to come last.
    val_loss = val_loss_sum / val_batch_count

    # Save the model if validation loss is improved
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        print(f"Validation loss improved to {best_val_loss}. Saving model...")

        # Save the model
        unet_wrapper.save_checkpoint("best_vae_model.pth")

Epoch [0/100], Step [0/2], Loss: 49.02888870239258
Epoch [0/100], Step [0/1], Loss: 48.60641098022461
Validation loss improved to 48.60641098022461. Saving model...


Exception ignored in: <function _ConnectionBase.__del__ at 0x7f6f93cd4cc0>
Traceback (most recent call last):
  File "/home/benzshawelt/.conda/envs/xlayer/lib/python3.11/multiprocessing/connection.py", line 133, in __del__
    self._close()
  File "/home/benzshawelt/.conda/envs/xlayer/lib/python3.11/multiprocessing/connection.py", line 377, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor


KeyboardInterrupt: 