In [None]:
import functools
import gc
import importlib
import inspect
import multiprocessing
import pickle
import shutil
import traceback
from collections import OrderedDict, defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Union

# Third-party libraries - NumPy & Scientific
import numpy as np
from numpy.random import RandomState

# Third-party libraries - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard.summary import hparams

# Third-party libraries - Visualization
import matplotlib.pyplot as plt
from matplotlib.figure import Figure

# Third-party libraries - ML Tools
import wandb
from loguru import logger
from omegaconf import DictConfig, ListConfig, OmegaConf
import omegaconf.errors

# Local imports
from ext import common
import fvdb
import fvdb.nn as fvnn
from fvdb import JaggedTensor, GridBatch

# Local imports
from xcube_refactored.modules.autoencoding.hparams import hparams_handler
from xcube_refactored.utils.loss_util import AverageMeter
from xcube_refactored.utils.loss_util import TorchLossMeter
from xcube_refactored.utils import exp 

from xcube_refactored.modules.autoencoding.sunet import StructPredictionNet 

In [None]:
# Hyperparameters for the diffusion experiment, kept as plain module-level
# names so they can be inspected or overridden cell by cell.

use_pos_embed_world = True

voxel_size = 0.0025
resolution = 512

# data info
duplicate_num = 10  # repeat the dataset to save the time of building dataloader
batch_size = 64
accumulate_grad_batches = 4
batch_size_val = 4
train_val_num_workers = 16

# diffusion - inference params
use_ddim = True
num_inference_steps = 100

# diffusion - scheduler-related adjust params
num_train_timesteps = 1000
beta_start = 0.0001
beta_end = 0.02
beta_schedule = "linear"
prediction_type = "v_prediction"

# diffusion - scale by std
scale_by_std = True
scale_factor = 1.0

# exponential moving average settings
ema = True
ema_decay = 0.9999

# loss weighting
mse_weight = 1.0

# optimization
learning_rate = {
    "init": 5.0e-5,
    "decay_mult": 1.0,
    "decay_step": 2000000000,  # use a constant learning rate
    "clip": 1.0e-6,
}
weight_decay = 0.0
grad_clip = 0.5

# network
dims_diffuser = 3  # 3D conv
image_size = 128  # use during testing
model_channels = 128
# Must be an assignment: the YAML-style `name: True` form is only a variable
# annotation in Python and leaves the name undefined.
use_middle_attention = True
channel_mult = [1, 2, 2, 4]  # 128 -> 16
attention_resolutions = [4, 8]  # 32 | 16
num_res_blocks = 2
num_heads = 8
variance_type = "fixed_small"
clip_sample = False

In [None]:
def val_dataloader(self):
    """Create the DataLoader used for validation.

    The dataset comes from ``xcube.data.build_dataset`` using the validation
    entries of ``self.hparams``. Batches are never shuffled, and the batch
    size is ``hparams.batch_size`` floor-divided by
    ``self.trainer.world_size``.

    Returns:
        DataLoader: loader over the validation dataset.
    """
    # Project dataset factory, imported at call time.
    import xcube.data as dataset

    hp = self.hparams
    validation_set = dataset.build_dataset(
        hp.val_dataset,
        self.get_dataset_spec(),
        hp,
        hp.val_kwargs,
    )

    # NOTE(review): this divides the training `batch_size` by the world size
    # rather than reading `batch_size_val` -- confirm that is intended.
    per_rank_batch = hp.batch_size // self.trainer.world_size
    worker_count = hp.train_val_num_workers
    collate = self.get_collate_fn()

    return DataLoader(
        validation_set,
        batch_size=per_rank_batch,
        shuffle=False,
        num_workers=worker_count,
        collate_fn=collate,
    )


def train_dataloader(self):
    """Create the shuffled DataLoader used for training.

    The dataset comes from ``xcube.data.build_dataset`` using the training
    entries of ``self.hparams``; ``hparams.duplicate_num`` is forwarded to
    the builder as ``duplicate_num``. The loader uses ``hparams.batch_size``
    as-is.

    Returns:
        DataLoader: shuffling loader over the training dataset.
    """
    # Both imports are resolved at call time, inside the function.
    import xcube.data as dataset
    from torch.utils.data import DataLoader

    hp = self.hparams
    training_set = dataset.build_dataset(
        hp.train_dataset,
        self.get_dataset_spec(),
        hp,
        hp.train_kwargs,
        # Forwarded so the builder can repeat the training set; presumably
        # this is what the `duplicate_num` setting controls -- TODO confirm.
        duplicate_num=hp.duplicate_num,
    )

    loader_options = {
        "batch_size": hp.batch_size,
        "shuffle": True,
        "num_workers": hp.train_val_num_workers,
        "collate_fn": self.get_collate_fn(),
    }
    return DataLoader(training_set, **loader_options)