In [1]:
import functools
import gc
import importlib
import inspect
import multiprocessing
import pickle
import shutil
import traceback
from collections import OrderedDict, defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Union

# Third-party libraries - NumPy & Scientific
import numpy as np
from numpy.random import RandomState

# Third-party libraries - PyTorch
import torch
print (torch.__version__, "torch")
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
# from torch.utils.tensorboard.summary import hparams

# Third-party libraries - Visualization
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Third-party libraries - ML Tools
from omegaconf import DictConfig, ListConfig, OmegaConf
import omegaconf.errors

# Local imports
# from ext import common
import fvdb
import fvdb.nn as fvnn
from fvdb import JaggedTensor, GridBatch

# Local imports
# from modules.autoencoding.hparams import hparams_handler
# from utils.loss_util import AverageMeter
# from utils.loss_util import TorchLossMeter
# from utils import exp 

# from modules.autoencoding.sunet import StructPredictionNet 

# from utils.vis_util import vis_pcs


# from modules.diffusionmodules.schedulers.scheduling_ddim import DDIMScheduler
# from modules.diffusionmodules.schedulers.scheduling_ddpm import DDPMScheduler
# from modules.diffusionmodules.schedulers.scheduling_dpmpp_2m import DPMSolverMultistepScheduler


# from modules.diffusionmodules.ema import LitEma

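# TODO(review): none of these encoder classes are referenced in the cells visible in this notebook -- confirm whether this import is still needed.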
from modules.encoders import (SemanticEncoder, ClassEmbedder, PointNetEncoder,
                                    StructEncoder, StructEncoder3D, StructEncoder3D_remain_h, StructEncoder3D_v2)

from modules.autoencoding.sunet import StructPredictionNet
import collections

2.4.1 torch


In [2]:
# Hyperparameters
# Flat module-level configuration read by the cells below (model construction,
# optimizer/scheduler setup, dataset spec selection and the dataloader factories).

batch_size = 4
tree_depth = 3 # according to 512x512x512 -> 128x128x128
voxel_size = 0.0025
resolution = 512
use_fvdb_loader = True
use_hash_tree = True # use hash tree means use early dilation (description in Sec 3.4) 


with_semantic_branch = False
# NOTE(review): the "store_true" values below are plain (truthy) strings, not booleans;
# they look like argparse `action` names copied from a CLI definition -- confirm the intended values.
extract_mesh = "store_true"

# NOTE(review): stored as the string "3", not the integer 3 -- confirm what the consumer expects.
solver_order ="3"

use_dpm = "store_true"

ddim_step = 50

use_ddim = "store_true"

ema = "store_true"

batch_len = 64

# NOTE(review): presumably a typo for `total_len`; the name is not referenced by any visible cell.
toal_len = 700

seed = 0

# Divides batch_size in the train/val dataloader factories (batch_size // world_size).
world_size = 1

# setup input
# These three flags (together with with_semantic_branch) select the entries returned by get_dataset_spec().
use_input_normal = True
use_input_semantic = False
use_input_intensity = False

# setup KL loss
cut_ratio = 16 # reduce the dimension of the latent space
kl_weight = 1.0 # activate when anneal is off
normalize_kld = True
enable_anneal = False
kl_weight_min = 1e-7
kl_weight_max = 1.0
# NOTE(review): presumably a typo for `anneal_start_iter`; not referenced by any visible cell.
anneal_star_iter = 0
anneal_end_iter = 70000 # need to adjust for different dataset


structure_weight = 20.0
normal_weight = 300.0
  

# Learning-rate schedule: "init" is the optimizer's base lr; lambda_lr_wrapper reads
# "decay_mult", "decay_step", "clip" and "init" to compute the LambdaLR multiplier.
learning_rate = {
  "init": 1.0e-4,
  "decay_mult": 0.7,
  "decay_step": 50000,
  "clip": 1.0e-6
}
weight_decay = 0.0
# Passed as clip_value to torch.nn.utils.clip_grad_value_ in the training loop.
grad_clip = 0.5

c_dim = 32
  
# unet parameters
# All of the values below (plus c_dim) are forwarded as keyword arguments to StructPredictionNet.
in_channels = 32
num_blocks = tree_depth
f_maps = 32
neck_dense_type = "UNCHANGED"
neck_bound = [64, 64, 64] # useless but indicate here
num_res_blocks = 1
use_residual = False
order = "gcr"
is_add_dec = False
use_attention = False
use_checkpoint = False


# Dataset locations and per-split keyword dictionaries.
# NOTE(review): both paths below are absolute and machine-specific; the notebook will not
# run elsewhere without editing them.
_custom_name =  "objaverse"
# Directory that get_all_paths() walks for .pkl files; also used as `onet_base_path`.
_objaverse_path = "/home/benzshawelt/Kedziora/kedziora_research/layer_x_layer/data_gen/voxels/512"
# Parent directory of _objaverse_path; generate_split_paths() writes train/val/test.lst here.
_split_path = "/home/benzshawelt/Kedziora/kedziora_research/layer_x_layer/data_gen/voxels/"
_text_emb_path = ""
_null_embed_path = "./assets/null_text_emb.pkl"
max_text_len= 77
text_embed_drop_prob= 0.1

# NOTE(review): train_dataloader() reads only "onet_base_path", "split" and "resolution"
# from this dict; the remaining entries are hard-coded inside the factory with different
# values (e.g. text_embed_drop_prob=0.0, a different text_emb_path) -- confirm which is intended.
train_dataset = "ObjaverseDataset"
train_val_num_workers= 16
train_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "train",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "text_embed_drop_prob": text_embed_drop_prob, # ! classifier-free training
  "random_seed": 0
}

# NOTE(review): the validation config selects the "test" split even though
# generate_split_paths() also writes a val.lst -- confirm this is intentional.
val_dataset = "ObjaverseDataset"
val_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "test",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "random_seed": "fixed"
}

# test_dataloader() reads only "onet_base_path" and "split" from this dict.
test_dataset = "ObjaverseDataset"
test_num_workers =8
test_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "test",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "random_seed": "fixed"
}

In [3]:
# Collect the paths of all voxelized shapes stored as .pkl files anywhere below the base path.
def get_all_paths(base_path):
  """Return the extension-less paths of every ``.pkl`` file under ``base_path``.

  Args:
    base_path: Directory to walk recursively.

  Returns:
    A sorted list of strings, one per ``.pkl`` file, each being the file's full
    path with the trailing ``.pkl`` removed. A missing or empty directory
    yields an empty list.
  """
  import os

  suffix = ".pkl"
  all_paths = []
  for root, _dirs, files in os.walk(base_path):
    for file in files:
      if file.endswith(suffix):
        # Strip only the trailing suffix: splitting on the first "." would truncate
        # names such as "shape.v2.pkl" and the stem would no longer point at the file.
        all_paths.append(os.path.join(root, file[:-len(suffix)]))

  # os.walk order depends on the filesystem; sorting keeps seeded splits reproducible.
  return sorted(all_paths)

In [4]:
def split_dataset(dataset, train_split_ratio, test_split_ratio, seed=0):
    """Shuffle ``dataset`` indices with a fixed seed and cut them into three subsets.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        train_split_ratio: Fraction of the items placed in the first subset.
        test_split_ratio: Fraction of the items placed in the second subset.
        seed: Seed for the shuffle.

    Returns:
        A tuple of three ``torch.utils.data.Subset`` objects over ``dataset``:
        the first ``int(len * train_split_ratio)`` shuffled indices, the next
        ``int(len * test_split_ratio)`` indices, and whatever remains.
    """
    indices = np.arange(len(dataset))
    # A private RandomState produces the same permutation as seeding the global
    # generator, but leaves NumPy's process-wide random state untouched.
    np.random.RandomState(seed).shuffle(indices)

    first_end = int(len(dataset) * train_split_ratio)
    # The second subset is sized by test_split_ratio and starts where the first one ends.
    second_end = int(len(dataset) * test_split_ratio) + first_end

    make_subset = torch.utils.data.Subset
    return (
        make_subset(dataset, indices[:first_end]),
        make_subset(dataset, indices[first_end:second_end]),
        make_subset(dataset, indices[second_end:]),
    )
    

In [5]:
def generate_split_paths(base_path, split_ratio, seed=0):
    """Split the shapes found under ``base_path`` and write one ``.lst`` file per split.

    The shape paths come from ``get_all_paths(base_path)`` and are divided by
    ``split_dataset`` with ``split_ratio`` for the first subset and a fixed 0.1
    for the second. ``train.lst``, ``val.lst`` and ``test.lst`` are written to
    the parent directory of ``base_path``, one path per line; any existing file
    of the same name is removed first.

    Args:
        base_path: Directory containing the voxelized shapes.
        split_ratio: Fraction of shapes assigned to the training list.
        seed: Seed forwarded to ``split_dataset``.

    Returns:
        Always ``True``.
    """
    shape_paths = get_all_paths(base_path)
    train_part, val_part, test_part = split_dataset(shape_paths, split_ratio, 0.1, seed)

    # The list files live next to (not inside) the shape directory.
    out_dir = Path(base_path).parent
    train_lst = out_dir / "train.lst"
    val_lst = out_dir / "val.lst"
    test_lst = out_dir / "test.lst"

    # Drop stale list files before rewriting them.
    for stale in (train_lst, val_lst, test_lst):
        if stale.exists():
            stale.unlink()

    # One path per line; files are written in the order test, train, val.
    outputs = ((test_lst, test_part), (train_lst, train_part), (val_lst, val_part))
    for target, entries in outputs:
        with open(target, "w") as handle:
            for entry in entries:
                handle.write(f"{entry}\n")
    return True

In [6]:
# (Re)write train.lst / val.lst / test.lst beside the voxel directory: split_ratio 0.8, seed 42.
generate_split_paths(_objaverse_path , 0.8, 42)

True

In [7]:
def lambda_lr_wrapper(it, lr_config, batch_size, accumulate_grad_batches=1):
    """Step-decay learning-rate multiplier with a lower bound.

    Args:
        it: Current scheduler iteration.
        lr_config: Mapping with the keys 'decay_mult', 'decay_step', 'clip' and 'init'.
        batch_size: Samples per iteration.
        accumulate_grad_batches: Extra factor applied to the processed-sample count.

    Returns:
        ``decay_mult`` raised to the number of whole ``decay_step`` intervals
        covered by ``it * batch_size * accumulate_grad_batches``, but never less
        than ``clip / init``.
    """
    decay_base = lr_config['decay_mult']
    # Number of complete decay intervals passed so far (truncated towards zero).
    completed_decays = int(it * batch_size * accumulate_grad_batches / lr_config['decay_step'])
    decayed = decay_base ** completed_decays
    # Smallest multiplier allowed, i.e. the one that yields an lr of `clip`.
    floor = lr_config['clip'] / lr_config['init']
    return max(decayed, floor)


In [8]:
# define the model
# Every argument is taken verbatim from the hyperparameter cell ("unet parameters" plus c_dim).
# NOTE(review): the model is never moved to a device in the visible cells -- confirm whether
# CPU execution is intended.
model = StructPredictionNet(
  in_channels=in_channels,
  num_blocks=num_blocks,
  f_maps=f_maps,
  neck_dense_type=neck_dense_type,
  neck_bound=neck_bound,
  num_res_blocks=num_res_blocks,
  use_residual=use_residual,
  order=order,
  is_add_dec=is_add_dec,
  use_attention=use_attention,
  use_checkpoint=use_checkpoint,
  c_dim=c_dim
)

In [9]:
from utils.Dataspec import DatasetSpec


# AdamW (AMSGrad variant) over all model parameters, starting at learning_rate["init"].
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate["init"],
                                    weight_decay=weight_decay, amsgrad=True)

# LambdaLR scales the base lr by lambda_lr_wrapper(step, ...); the training loop calls
# scheduler.step() once per batch, so the step counter advances per iteration, not per epoch.
scheduler = LambdaLR(optimizer,
                lr_lambda=functools.partial(
                    lambda_lr_wrapper, lr_config=learning_rate, batch_size=batch_size))

# exp.global_var_manager.register_variable('skip_backward', False)

def list_collate(batch):
    """
    This just do not stack batch dimension.
    """
    
    elem = None
    for e in batch:
        if e is not None:
            elem = e
            break
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            return list_collate([torch.as_tensor(b) if b is not None else None for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, DictConfig) or isinstance(elem, ListConfig):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: list_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [list_collate(samples) for samples in transposed]
    elif isinstance(elem, GridBatch):
        return fvdb.cat(batch)
    
    return batch

def get_dataset_spec():
    """Build the list of DatasetSpec entries requested from the dataset.

    Four entries are always present (shape name, input point cloud, dense
    ground-truth point cloud, ground-truth geometry). Further entries are
    appended according to the module-level flags ``use_input_normal``,
    ``use_input_semantic`` / ``with_semantic_branch`` and ``use_input_intensity``.

    Returns:
        A new list of DatasetSpec members.
    """
    spec = DatasetSpec
    requested = [spec.SHAPE_NAME, spec.INPUT_PC, spec.GT_DENSE_PC, spec.GT_GEOMETRY]

    if use_input_normal:
        requested += [spec.TARGET_NORMAL, spec.GT_DENSE_NORMAL]

    # Semantic labels are needed either as an input feature or for the semantic branch.
    if use_input_semantic or with_semantic_branch:
        requested += [spec.GT_SEMANTIC]

    if use_input_intensity:
        requested += [spec.INPUT_INTENSITY]

    return requested



def train_dataloader():
    """Create the shuffled training DataLoader.

    The dataset is an ObjaverseDataset whose base path, split and resolution
    come from ``train_kwargs``; every other constructor argument is fixed here.
    Batches hold ``batch_size // world_size`` samples and are assembled by
    ``list_collate`` using ``train_val_num_workers`` worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` over the training split.
    """
    from data.objaverse import ObjaverseDataset

    dataset_args = dict(
        onet_base_path=train_kwargs["onet_base_path"],
        spec=get_dataset_spec(),
        split=train_kwargs["split"],
        resolution=train_kwargs["resolution"],
        image_base_path=None,
        random_seed=0,
        hparams=None,
        skip_on_error=False,
        custom_name="objaverse",
        text_emb_path="../data/objaverse/objaverse/text_emb",
        null_embed_path="./assets/null_text_emb.pkl",
        text_embed_drop_prob=0.0,
        max_text_len=77,
        duplicate_num=1,
        split_base_path=_split_path,
    )
    dataset = ObjaverseDataset(**dataset_args)

    return DataLoader(
        dataset,
        batch_size=batch_size // world_size,
        shuffle=True,
        num_workers=train_val_num_workers,
        collate_fn=list_collate,
    )


# Show which DatasetSpec entries the current input flags select.
print(get_dataset_spec())

def val_dataloader():
    """Create the validation DataLoader (no shuffling).

    The dataset is an ObjaverseDataset whose base path, split and resolution
    come from ``val_kwargs``; every other constructor argument is fixed here.
    Batches hold ``batch_size // world_size`` samples and are assembled by
    ``list_collate`` using ``train_val_num_workers`` worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` over the split named in ``val_kwargs``.
    """
    from data.objaverse import ObjaverseDataset

    dataset_args = dict(
        onet_base_path=val_kwargs["onet_base_path"],
        spec=get_dataset_spec(),
        split=val_kwargs["split"],
        resolution=val_kwargs["resolution"],
        image_base_path=None,
        random_seed=0,
        hparams=None,
        skip_on_error=False,
        custom_name="objaverse",
        text_emb_path="../data/objaverse/objaverse/text_emb",
        null_embed_path="./assets/null_text_emb.pkl",
        text_embed_drop_prob=0.0,
        max_text_len=77,
        duplicate_num=1,
        split_base_path=_split_path,
    )
    dataset = ObjaverseDataset(**dataset_args)

    return DataLoader(
        dataset,
        batch_size=batch_size // world_size,
        shuffle=False,
        num_workers=train_val_num_workers,
        collate_fn=list_collate,
    )

def test_dataloader(resolution=resolution, test_set_shuffle=False):
    """Create the test DataLoader (batch size 1, loaded in the main process).

    Args:
        resolution: Resolution handed to the dataset. Defaults to the module-level
            ``resolution`` as it was when this function was defined; override it
            to test at a different resolution than the one trained on.
        test_set_shuffle: When true, seed torch with 0 and shuffle the test set.

    Returns:
        A ``torch.utils.data.DataLoader`` over the split named in ``test_kwargs``.
    """
    from data.objaverse import ObjaverseDataset

    dataset_args = dict(
        onet_base_path=test_kwargs["onet_base_path"],
        spec=get_dataset_spec(),
        split=test_kwargs["split"],
        resolution=resolution,
        image_base_path=None,
        random_seed=0,
        hparams=None,
        skip_on_error=False,
        custom_name="objaverse",
        text_emb_path="../data/objaverse/objaverse/text_emb",
        null_embed_path="./assets/null_text_emb.pkl",
        text_embed_drop_prob=0.0,
        max_text_len=77,
        duplicate_num=1,
        split_base_path=_split_path,
    )
    dataset = ObjaverseDataset(**dataset_args)

    # Fix the torch seed so a shuffled test order is repeatable.
    if test_set_shuffle:
        torch.manual_seed(0)

    return DataLoader(
        dataset,
        batch_size=1,
        shuffle=test_set_shuffle,
        num_workers=0,
        collate_fn=list_collate,
    )


[<DatasetSpec.SHAPE_NAME: 100>, <DatasetSpec.INPUT_PC: 200>, <DatasetSpec.GT_DENSE_PC: 400>, <DatasetSpec.GT_GEOMETRY: 800>, <DatasetSpec.TARGET_NORMAL: 300>, <DatasetSpec.GT_DENSE_NORMAL: 500>]


In [10]:
# Instantiate the loaders once. The trailing underscore distinguishes the DataLoader
# instances from the factory functions of the same name.
train_dataloader_ = train_dataloader()
val_dataloader_ = val_dataloader()
test_dataloader_ = test_dataloader()



In [11]:
import utils.exp as exp

In [12]:
# Number of full passes over the training loader performed by the training loop.
epochs = 100

In [13]:
# Training loop: one optimizer and scheduler step per batch, then a validation pass and a
# checkpoint (model_<epoch>.pth) at the end of every epoch.

for epoch in range(epochs):
    model.train()
    for i, batch in enumerate(train_dataloader_):
        optimizer.zero_grad()
        # Forward pass (the per-step dump of the whole batch was debugging output and is gone).
        # NOTE(review): assumes the model's forward returns a scalar loss tensor -- confirm.
        loss = model(batch, None)
        loss.backward()

        # Count parameters whose gradient contains NaN.
        nan_grad_count = sum(
            1 for p in model.parameters()
            if p.grad is not None and bool(torch.isnan(p.grad).any())
        )
        if nan_grad_count > 0:
            exp.logger.warning(f"{nan_grad_count} parameters get nan-gradient -- this step will be skipped.")
            # "Skipped" means every gradient is zeroed; optimizer.step() below still runs
            # with those zeroed gradients.
            for p in model.parameters():
                if p.grad is not None:
                    p.grad.zero_()

        torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=grad_clip)

        optimizer.step()
        scheduler.step()

        if i % 100 == 0:
            # Use the DataLoader instance here: the factory function has no len().
            print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(train_dataloader_)}], Loss: {loss.item()}")

    # Validation
    model.eval()
    with torch.no_grad():
        # Iterate the DataLoader instance, not the factory function.
        for i, batch in enumerate(val_dataloader_):
            # Same call signature as the training forward pass.
            # NOTE(review): confirm the second argument against StructPredictionNet.forward.
            loss = model(batch, None)
            if i % 100 == 0:
                print(f"Epoch [{epoch}/{epochs}], Step [{i}/{len(val_dataloader_)}], Loss: {loss.item()}")

    # Save the model
    torch.save(model.state_dict(), f"model_{epoch}.pth")

  input_data = torch.load(os.path.join(self.onet_base_path, category, model) + ".pkl")
  input_data = torch.load(os.path.join(self.onet_base_path, category, model) + ".pkl")


AttributeError: Caught AttributeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "/home/benzshawelt/.conda/envs/xlayer/lib/python3.11/site-packages/torch/utils/data/_utils/worker.py", line 309, in _worker_loop
    data = fetcher.fetch(index)  # type: ignore[possibly-undefined]
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/benzshawelt/.conda/envs/xlayer/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 55, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1131378/1434144813.py", line 41, in list_collate
    return {key: list_collate([d[key] for d in batch]) for key in elem}
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1131378/1434144813.py", line 41, in <dictcomp>
    return {key: list_collate([d[key] for d in batch]) for key in elem}
                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/tmp/ipykernel_1131378/1434144813.py", line 51, in list_collate
    return fvdb.cat(batch)
           ^^^^^^^^
AttributeError: module 'fvdb' has no attribute 'cat'
