In [1]:
import functools
import gc
import importlib
import inspect
import multiprocessing
import pickle
import shutil
import traceback
from collections import OrderedDict, defaultdict
from enum import Enum
from pathlib import Path
from typing import Any, Callable, Mapping, Optional, Union
import collections

# Third-party libraries - NumPy & Scientific
import numpy as np
from numpy.random import RandomState

# Third-party libraries - PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate
from torch.utils.tensorboard.summary import hparams

# Third-party libraries - Visualization
import matplotlib.pyplot as plt
from matplotlib.figure import Figure
from modules.diffusionmodules.ema import LitEma
import tqdm

# Third-party libraries - ML Tools
# import wandb
# from loguru import logger
from omegaconf import DictConfig, ListConfig, OmegaConf
import omegaconf.errors

# Local imports
# from ext import common
import fvdb
import fvdb.nn as fvnn
from fvdb.nn import VDBTensor
from fvdb import JaggedTensor, GridBatch

# Local imports
from modules.autoencoding.hparams import hparams_handler
from utils.loss_util import AverageMeter
from utils.loss_util import TorchLossMeter
from utils import exp 

from modules.autoencoding.sunet import StructPredictionNet 

from modules.diffusionmodules.schedulers.scheduling_ddim import DDIMScheduler
from modules.diffusionmodules.schedulers.scheduling_ddpm import DDPMScheduler
from modules.diffusionmodules.schedulers.scheduling_dpmpp_2m import DPMSolverMultistepScheduler

from utils.Dataspec import DatasetSpec as DS

from modules.diffusionmodules.diffusion_cross_attn import UNetModel as Diffusion_Cross_Attn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.set_float32_matmul_precision('high')

from utils.vis_util import vis_pcs


In [None]:
# use_pos_embed_world = True

# voxel_size = 0.001953125
resolution = 512

# # data info
# duplicate_num = 10 # repeat the dataset to save the time of building dataloader
# batch_size = 64
# accumulate_grad_batches = 4
# batch_size_val = 4
# train_val_num_workers = 16

# # diffusion - inference params
# use_ddim = True
# num_inference_steps = 100

# # diffusion - scheduler-related adjust params
# num_train_timesteps = 1000
# beta_start = 0.0001
# beta_end = 0.02
# beta_schedule = "linear"
# prediction_type = "v_prediction"

#   # scheduler:
#   #   num_train_timesteps: ${num_train_timesteps}
#   #   beta_start: ${beta_start}
#   #   beta_end: ${beta_end}
#   #   beta_schedule: ${beta_schedule} # cosine
#   #   variance_type: "fixed_small"
#   #   clip_sample: False
#   #   prediction_type: ${prediction_type} # epsilon

# # diffusion - scale by std
# scale_by_std = True
# scale_factor = 1.0

# ema = True
# ema_decay = 0.9999


# mse_weight = 1.0


# weight_decay = 0.0
# grad_clip = 0.5

# dims_diffuser = 3 # 3D conv
# image_size = 128 # use during testing
# model_channels = 64
# use_middle_attention: True
# channel_mult = [1, 2, 2, 4] # 128 -> 16
# attention_resolutions = [4, 8] # 32 | 16
# num_res_blocks = 2
# num_heads = 8
# variance_type = "fixed_small"
# clip_sample = False
# context_dim = 1024
# use_text_cond = True
# use_normal_concat_cond= True
# num_classes = 3 # normal dim 3???

# use_pos_embed_world= False
# use_pos_embed = True


learning_rate = {
  "init": 5.0e-5,
  "decay_mult": 1.0,
  "decay_step": 2000000000, # use a constant learning rate
  "clip": 1.0e-6
}

# optimizer: "Adam"
# learning_rate:
#   init: 5.0e-5
#   decay_mult: 1.0
#   decay_step: 2000000000 # use a constant learning rate
#   clip: 1.0e-6
# weight_decay: 0.0
# grad_clip: 0.5


# Main hyperparameters for DiffusionModel
hparams = {
    # Positional embedding options
    "use_pos_embed_world": False,
    "use_pos_embed": True,
    "use_pos_embed_high": False,
    "use_pos_embed_world_high": False,
    
    # Diffusion process parameters
    "scale_by_std": True,
    "scale_factor": 1.0,
    "ema": True,
    "ema_decay": 0.9999,
    "use_ddim": True,
    "num_inference_steps": 100,
    
    # Conditioning options
    "use_text_cond": True,
    "use_normal_concat_cond": True,
    "use_semantic_cond": False,
    "use_mask_cond": False,
    "use_point_cond": False,
    "use_class_cond": False,
    "use_micro_cond": False,
    "use_single_scan_concat_cond": False,
    "use_classifier_free": True,
    "classifier_free_prob": 0.1,
    
    # For inference/evaluation
    "diffuser_image_size": 128,  # image_size from your list
    
    # Other parameters
    "num_classes": 3,
    "mse_weight": 1.0,
    "conditioning_key": "c_crossattn",  # Based on code, this seems most likely given your settings
}

# Noise scheduler parameters
noise_scheduler_params = {
    "num_train_timesteps": 1000,
    "beta_start": 0.0001,
    "beta_end": 0.02,
    "beta_schedule": "linear",
    "prediction_type": "v_prediction",
    "variance_type": "fixed_small",
    "clip_sample": False
}

# UNet/Diffuser hyperparameters (passed via diffuser_kwargs)
diffuser_kwargs = {
    "dims": 3,  # 3D conv as specified by dims_diffuser
    "model_channels": 64,
    "use_middle_attention": True,
    "channel_mult": [1, 2, 2, 4],  # 128 -> 16
    "attention_resolutions": [4, 8],  # 32 | 16
    "num_res_blocks": 2,
    "num_heads": 8,
    "context_dim": 1024
}

# VAE parameters (retrieved from the vae module)
# vae_hparams = {
#     "voxel_size": 0.001953125,
#     "tree_depth": None,  # You'll need to provide this
#     "num_blocks": None,  # You'll need to provide this
#     "f_maps": None,      # You'll need to provide this
#     "cut_ratio": None    # You'll need to provide this
# }

# Training parameters
training_params = {
    "batch_size": 64,
    "accumulate_grad_batches": 4,
    "batch_size_val": 4,
    "train_val_num_workers": 16,
    "weight_decay": 0.0,
    "grad_clip": 0.5,
    "learning_rate": {
        "init": 5.0e-5,
        "decay_mult": 1.0,
        "decay_step": 2000000000,  # Constant learning rate
        "clip": 1.0e-6
    }
}

# Data parameters
data_params = {
    "duplicate_num": 10,  # Repeat dataset to save time building dataloader
    "resolution": 512
}


_custom_name =  "objaverse"
_objaverse_path = "/home/benzshawelt/Kedziora/kedziora_research/layer_x_layer/data_gen/voxels/512"
_split_path = "/home/benzshawelt/Kedziora/kedziora_research/layer_x_layer/data_gen/voxels/"
_text_emb_path = ""
_null_embed_path = "./assets/null_text_emb.pkl"

In [3]:
def get_random_sample_pcs(self, ijk: fvdb.JaggedTensor, batch_size=1, M=3, use_center=False):
    # M: sample per point
    # !: be careful about the batch_size
    output_ijk = []
    for idx in range(batch_size):
        current_ijk = ijk[idx].jdata.float() # N, 3
        if use_center:
            output_ijk.append(current_ijk)
        else:            
            N = current_ijk.shape[0]
            # create offsets of size M*N x 3 with values in range [-0.5, 0.5]
            offsets = torch.FloatTensor(N * M, 3).uniform_(-0.5, 0.5).to(current_ijk.device)
            # duplicate your original point cloud M times
            expanded_point_cloud = current_ijk.repeat(M, 1)
            # add offsets to duplicated points
            expanded_point_cloud += offsets
            output_ijk.append(expanded_point_cloud)

            
    return fvdb.JaggedTensor(output_ijk)
def decode_to_points(self, latents):
    res = self.vae.unet.FeaturesSet()
    res, output_x = self.vae.unet.decode(res, latents, is_testing=True)
    grid_tree = res.structure_grid
    fine_list = []
    coarse_list = []
    for batch_idx in range(output_x.grid.grid_count):
        # fineest level
        pd_grid_0 = grid_tree[0]
        pd_xyz_0 = pd_grid_0.grid_to_world(self.get_random_sample_pcs(pd_grid_0.ijk[batch_idx], M=8))
        fine_list.append(pd_xyz_0.jdata.cpu().numpy())
        # coarsest level
        pd_grid_1 = grid_tree[len(grid_tree.keys()) - 1]
        pd_xyz_1 = pd_grid_1.grid_to_world(self.get_random_sample_pcs(pd_grid_1.ijk[batch_idx], M=3))
        coarse_list.append(pd_xyz_1.jdata.cpu().numpy())
    # plot the fine and coarse level
    viz_fine = vis_pcs(fine_list)
    viz_coarse = vis_pcs(coarse_list)
    decode_results = np.concatenate([viz_fine, viz_coarse], axis=0)
    return decode_results

In [4]:
def list_collate(batch):
    """
    This just do not stack batch dimension.
    """
    elem = None
    for e in batch:
        if e is not None:
            elem = e
            break
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            return list_collate([torch.as_tensor(b) if b is not None else None for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, DictConfig) or isinstance(elem, ListConfig):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: list_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [list_collate(samples) for samples in transposed]
    elif isinstance(elem, GridBatch):
        return fvdb.cat(batch)
    
    # elif isinstance(elem, pathlib.Path):
    #     return batch
    # elif elem is None:
    #     return batch

    # raise NotImplementedError
    return batch

In [5]:
class Embedder(nn.Module):
    def __init__(self, include_input=True, input_dims=3, max_freq_log2=10, num_freqs=10, log_sampling=True, periodic_fns=[torch.sin, torch.cos]):
        super().__init__()
        embed_fns = []
        d = input_dims
        out_dim = 0
        if include_input:
            out_dim += d
            
        max_freq = max_freq_log2
        N_freqs = num_freqs
        
        if log_sampling:
            freq_bands = 2.**torch.linspace(0., max_freq, steps=N_freqs)
        else:
            freq_bands = torch.linspace(2.**0., 2.**max_freq, steps=N_freqs)
        
        for freq in freq_bands:
            for _ in periodic_fns:
                out_dim += d
        
        self.include_input = include_input
        self.freq_bands = freq_bands
        self.periodic_fns = periodic_fns
        
        self.embed_fns = embed_fns
        self.out_dim = out_dim
        
    def forward(self, inputs):
        output_list = [inputs]
        for fn in self.embed_fns:
            output_list.append(fn(inputs))
            
        for freq in self.freq_bands:
            for p_fn in self.periodic_fns:
                output_list.append(p_fn(inputs * freq))
        
        return torch.cat(output_list, -1)



In [6]:
def get_embedder(multires, i=0, input_dims=3):
    if i == -1:
        return nn.Identity(), 3
    
    embedder_obj = Embedder(max_freq_log2=multires-1, num_freqs=multires, input_dims=input_dims)
    return embedder_obj, embedder_obj.out_dim

In [7]:
class DiffusionModel(nn.Module):
    def __init__(self, hparams, num_classes,use_spatial_transformer,
    context_dim, diffuser_kwargs, noise_scheduler_params, trained_autoencoder=None):
        super().__init__()
        self.hparams = hparams
        self.noise_scheduler_hparams = noise_scheduler_params
        self.hparams.get("ema", False)
        self.hparams.get("use_ddim", False)
        self.hparams.get("scale_by_std", True)
        self.hparams.get('scale_factor', 1.0)
        self.hparams.get('num_inference_steps', 1000)

        self.hparams.use_pos_embed = False
        self.hparams.use_pos_embed_high = False
        self.hparams.get('use_pos_embed_world', False)
        self.hparams.use_pos_embed_world_high = False

        self.vae = trained_autoencoder
        self.vae.requires_grad_(False)


        unet_num_blocks = self.vae.hparams["num_blocks"]
        num_input_channels = self.vae.hparams["f_maps"] * 2 ** (unet_num_blocks - 1) # Fix by using VAE hparams
        num_input_channels = int(num_input_channels / self.vae.hparams["cut_ratio"])

        out_channels = num_input_channels
        num_classes = None
        use_spatial_transformer = False
        context_dim=None

        if self.hparams.get("use_pos_embed"):
            num_input_channels += 3
        elif self.hparams.get("use_pos_embed_high"):
            embed_fn, input_ch = get_embedder(5)
            self.pos_embedder = embed_fn
            num_input_channels += input_ch
        elif self.hparams.get("use_pos_embed_world"):
            num_input_channels += 3
        elif self.hparams.get("use_pos_embed_world_high"):
            embed_fn, input_ch = get_embedder(5)
            self.pos_embedder = embed_fn
            num_input_channels += input_ch

        # Define the UNet either cross attn or sparse
        self.unet = Diffusion_Cross_Attn(num_input_channels=num_input_channels, 
                                                        out_channels=out_channels, 
                                                        num_classes=num_classes,
                                                        use_spatial_transformer=use_spatial_transformer,
                                                        context_dim=context_dim,
                                                        **diffuser_kwargs)


        # get the schedulers
        self.noise_scheduler = DDPMScheduler(noise_scheduler_params)
        self.ddim_scheduler = DDIMScheduler(noise_scheduler_params)


        self.hparams["classifier_free_prob"] = 0.1 # prob to drop the label
            
        # finetune config
        self.hparams.get('pretrained_model_name_or_path', None)
        self.hparams.get('ignore_mismatched_size', False)

        if self.hparams["ema"]:
            self.unet_ema = LitEma(self.unet, decay=self.hparams["ema_decay"])
            
        # scale by std
        if not self.hparams["scale_by_std"]:
            self.scale_factor = self.hparams["scale_factor"]
            assert self.scale_factor == 1., 'when not using scale_by_std, scale_factor should be 1.'
        else:
            self.register_buffer('scale_factor', torch.tensor(self.hparams.scale_factor).float())

    # @contextmanager
    def ema_scope(self):
        if self.hparams["ema"]:
            self.unet_ema.store(self.unet.parameters())
            self.unet_ema.copy_to(self.unet)
        try:
            yield None
        finally:
            if self.hparams["ema"]:
                self.unet_ema.restore(self.unet.parameters())
                
    def get_pos_embed(self, h):
        return h[:, :3]
    
    def get_pos_embed_high(self, h):
        xyz = h[:, :3] # N, 3
        xyz = self.pos_embedder(xyz) # N, C
        return xyz
    
    def conduct_classifier_free(self, cond, batch_size, device, is_testing=False):
        if isinstance(cond, VDBTensor):
            cond = cond.feature
        assert isinstance(cond, fvdb.JaggedTensor), "cond should be JaggedTensor"

        mask = torch.rand(batch_size, device=device) < self.hparams["classifier_free_prob"] 
        new_cond = []
        for idx in range(batch_size):
            if mask[idx] or is_testing:
                # during testing, use this function to zero the condition
                new_cond.append(torch.zeros_like(cond[idx].jdata))
            else:
                new_cond.append(cond[idx].jdata)
        new_cond = fvdb.JaggedTensor(new_cond)
        return new_cond
    
    @exp.mem_profile(every=1)
    def forward(self, batch, out: dict):
        # first get latent from vae, the latent is the input to the diffusion model
        # A latent is the encoded feature from the input
        with torch.no_grad():
            latents = self.vae._encode(batch, use_mode=False)

        # To Do: scale the latent
        if self.hparams["scale_by_std"]:
            latents = latents * self.scale_factor

        # then get the noise
        latent_data = latents.jdata
        noise = torch.randn_like(latent_data) # N, C

        # bsz is the batch size ???? TODO
        bsz = latents.grid.grid_count
        
        # Sample a random timestep for each latent
        # A timestep is a random point in the training schedule

        timesteps = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device) # B
        timesteps_sparse = timesteps.long()
        timesteps_sparse = timesteps_sparse[latents.feature.jidx.long()] # N, 1

        # Add noise to the latents according to the noise magnitude at each timestep
        # (this is the forward diffusion process)
        noisy_latents = self.noise_scheduler.add_noise(latent_data, noise, timesteps_sparse)
        
        # Predict the target for the noise residual (this is the backward diffusion process for training)
        if self.noise_scheduler.config.prediction_type == "epsilon":
            target = noise
        # Currently Used ------------------ Very Important ------------------
        elif self.noise_scheduler.config.prediction_type == "v_prediction":
            target = self.noise_scheduler.get_velocity(latent_data, noise, timesteps_sparse)
        else:
            raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")

        # Predict the noise residual and compute loss
        # forward_cond function use batch-level timesteps
        noisy_latents = VDBTensor(grid=latents.grid, feature=latents.grid.jagged_like(noisy_latents))


        cond_dict = None
        is_testing=False
        guidance_scale=1.0
   
        # Classifier free guidance is when the model is trained with a classifier, 
        # but at inference time, the classifier is not used
        do_classifier_free_guidance = guidance_scale != 1.0

        # ! corssattn part
        # text condition
        if self.hparams["use_text_cond"]:
            # traing-time: get text from batch
            if batch is not None:
                text_emb = torch.stack(batch[DS.TEXT_EMBEDDING]) # B, 77, 1024
                mask = torch.stack(batch[DS.TEXT_EMBEDDING_MASK]) # B, 77
            else:
                text_emb = cond_dict['text_emb']
                mask = cond_dict['text_emb_mask']                
            context = text_emb
            if do_classifier_free_guidance:
                context_copy = cond_dict['text_emb_null']
                mask_copy = cond_dict['text_emb_mask_null']

        
        # ! concat part            
        concat_list = []        
        # semantic condition
        # if self.hparams["use_semantic_cond"]:
        #     # traing-time: get semantic from batch
        #     if batch is not None:
        #         input_semantic = fvdb.JaggedTensor(batch[DS.LATENT_SEMANTIC])
        #     else:
        #         input_semantic = cond_dict['semantics']
        #     semantic_cond = self.cond_stage_model(input_semantic.jdata.long())
        #     if not is_testing and self.hparams.use_classifier_free: # if VDBtensor, convert to JaggedTensor
        #         semantic_cond = self.conduct_classifier_free(semantic_cond, noisy_latents.grid.grid_count, noisy_latents.grid.device)  
        #     concat_list.append(semantic_cond) # ! tensor type
        
        # ! corssattn part
        # text condition
        if self.hparams["use_text_cond"]:
            # traing-time: get text from batch
            if batch is not None:
                text_emb = torch.stack(batch[DS.TEXT_EMBEDDING]) # B, 77, 1024
                mask = torch.stack(batch[DS.TEXT_EMBEDDING_MASK]) # B, 77
            else:
                text_emb = cond_dict['text_emb']
                mask = cond_dict['text_emb_mask']                
            context = text_emb
            if do_classifier_free_guidance:
                context_copy = cond_dict['text_emb_null']
                mask_copy = cond_dict['text_emb_mask_null']

    
        if self.hparams["conditioning_key"] == 'none':
            # no condition is used -------------------- VERY IMPORTANT --------------------
            model_pred = self.unet(noisy_latents, timesteps)
        elif self.hparams["conditioning_key"] == 'c_crossattn':
            assert len(concat_list) > 0, "concat_list should not be empty"
            assert context is not None, "context should not be None"
            noisy_latents_in = VDBTensor.cat([noisy_latents] + concat_list, dim=1)
            model_pred = self.unet(noisy_latents_in, timesteps, context=context, mask=mask)
            
        else:
            raise NotImplementedError
        

        out.update({'pred': model_pred.jdata})
        out.update({'target': target})

        return out

    @torch.no_grad()
    def extract_latent(self, batch):
        return self.vae._encode(batch, use_mode=False)
    

    # Used for inference / evaluation only, not for training
    def evaluation_api(self, batch = None, grids: GridBatch = None, batch_size: int = None, latent_prev: VDBTensor = None, 
                       use_ddim=False, ddim_step=100, use_ema=True, use_dpm=False, use_karras=False, solver_order=3,
                       h_stride=1, guidance_scale: float = 1.0, 
                       cond_dict=None, res_coarse=None, guided_grid=None):
        """
        * @param grids: GridBatch from previous stage for conditional diffusion
        * @param batch_size: batch_size for unconditional diffusion
        * @param latent_prev: previous stage latent for conditional diffusion; not implemented yet
        * @param use_ddim: use DDIM or not
        * @param ddim_step: number of steps for DDIM
        * @param use_dpm: use DPM++ solver or not
        * @param use_karras: use Karras noise schedule or not 
        * @param solver_order: order of the solver; 3 for unconditional diffusion, 2 for guided sampling
        * @param use_ema: use EMA or not
        * @param h_stride: flag for remain_h VAE to create a anisotropic latent grid
        * @param cond_dict: conditional dictionary -> only pass if manully effort needed
        * @param res_coarse: previous stage result (semantics, normals, etc) for conditional diffusion
        """
        if grids is None: 
            if batch is not None:
                latents = self.extract_latent(batch)
                grids = latents.grid
            else:
                # use dense diffusion
                # create a dense grid
                assert batch_size is not None, "batch_size should be provided"

                # Haven't seen this before #TODO
                feat_depth = self.vae.hparams["tree_depth"] - 1
                gap_stride = 2 ** feat_depth
                gap_strides = [gap_stride, gap_stride, gap_stride // h_stride]



                if isinstance(self.hparams["diffuser_image_size"], int):
                    neck_bound = int(self.hparams["diffuser_image_size"] / 2)
                    low_bound = [-neck_bound] * 3
                    voxel_bound = [neck_bound * 2] * 3
                else:        
                    voxel_bound = self.hparams["diffuser_image_size"]
                    low_bound = [- int(res / 2) for res in self.hparams["diffuser_image_size"]]
                
                # sv is the voxel size
                voxel_sizes = [sv * gap for sv, gap in zip(self.vae.hparams["voxel_size"], gap_strides)] 
                origins = [sv / 2. for sv in voxel_sizes]
                grids = fvdb.sparse_grid_from_dense(
                                batch_size, 
                                voxel_bound, 
                                low_bound, 
                                device="cpu", # hack to fix bugs
                                voxel_sizes=voxel_sizes,
                                origins=origins).to(self.device)
        # parse the cond_dict
        if cond_dict is None:
            cond_dict = {}

        # mask condition
        if self.hparams["use_semantic_cond"]:
            # check if semantics is in cond_dict
            if 'semantics' not in cond_dict:
                if batch is not None:
                    cond_dict['semantics'] = fvdb.JaggedTensor(batch[DS.LATENT_SEMANTIC])
                elif res_coarse is not None:
                    cond_semantic = res_coarse.semantic_features[-1].jdata # N, class_num
                    cond_semantic = torch.argmax(cond_semantic, dim=1)
                    cond_dict['semantics'] = grids.jagged_like(cond_semantic)
                else:
                    raise NotImplementedError("No semantics provided")
                

        # single scan concat condition
        if self.hparams.use_normal_concat_cond:
            # traing-time: get single scan crop from batch
            if batch is not None:
                ref_grid = fvdb.cat(batch[DS.INPUT_PC])    
                ref_xyz = ref_grid.grid_to_world(ref_grid.ijk.float()) 
                concat_normal = grids.splat_trilinear(ref_xyz, fvdb.JaggedTensor(batch[DS.TARGET_NORMAL]))
            elif res_coarse is not None:
                concat_normal = res_coarse.normal_features[-1].feature # N, 3
                concat_normal.jdata /= (concat_normal.jdata.norm(dim=1, keepdim=True) + 1e-6) # avoid nan
            else:
                raise NotImplementedError("No normal provided")
            cond_dict['normal'] = concat_normal                
        
        # diffusion process starts here ______________________________________________________________________________________
        if use_ema:
            with self.ema_scope("Evaluation API"):
                latents = self.random_sample_latents(grids, use_ddim=use_ddim, ddim_step=ddim_step, use_dpm=use_dpm, use_karras=use_karras, solver_order=solver_order,
                                                     cond_dict=cond_dict, guidance_scale=guidance_scale)
        else:
            latents = self.random_sample_latents(grids, use_ddim=use_ddim, ddim_step=ddim_step, use_dpm=use_dpm, use_karras=use_karras, solver_order=solver_order,
                                                     cond_dict=cond_dict, guidance_scale=guidance_scale)
        # decode
        res = self.vae.unet.FeaturesSet()
        if guided_grid is None:
            res, output_x = self.vae.unet.decode(res, latents, is_testing=True)
        else:
            res, output_x = self.vae.unet.decode(res, latents, guided_grid)
        # TODO: add SDF output
        return res, output_x
    

    
    def random_sample_latents(self, grids: GridBatch, generator: torch.Generator = None, 
                              use_ddim=False, ddim_step=None, use_dpm=False, use_karras=False, solver_order=3,
                              cond_dict=None, guidance_scale=1.0) -> VDBTensor:
        if use_ddim:
            if ddim_step is None:
                ddim_step = self.hparams["num_inference_steps"]
            self.ddim_scheduler.set_timesteps(ddim_step, device=grids.device)
            timesteps = self.ddim_scheduler.timesteps
            scheduler = self.ddim_scheduler

        elif use_dpm:
            if ddim_step is None:
                ddim_step = self.hparams["num_inference_steps"]
            try:
                self.dpm_scheduler.set_timesteps(ddim_step, device=grids.device)
            except:
                # create a new dpm scheduler
                self.dpm_scheduler = DPMSolverMultistepScheduler(
                    num_train_timesteps=self.noise_scheduler_hparams["num_train_timesteps"],
                    beta_start=self.noise_scheduler_hparams["beta_start"],
                    beta_end=self.noise_scheduler_hparams["beta_end"],
                    beta_schedule=self.noise_scheduler_hparams["beta_schedule"],
                    solver_order=solver_order,
                    prediction_type=self.noise_scheduler_hparams["prediction_type"],
                    algorithm_type="dpmsolver++",
                    use_karras_sigmas=use_karras,
                )
                self.dpm_scheduler.set_timesteps(ddim_step, device=grids.device)
            timesteps = self.dpm_scheduler.timesteps
            scheduler = self.dpm_scheduler
        else:
            timesteps = self.noise_scheduler.timesteps
            scheduler = self.noise_scheduler
        
        # prepare the latents
        latents = torch.randn(grids.total_voxels, self.unet.out_channels, device=grids.device, generator=generator)
        
        for i, t in tqdm(enumerate(timesteps)):
            latent_model_input = latents
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
            latent_model_input = VDBTensor(grid=grids, feature=grids.jagged_like(latent_model_input))
            # Predict the noise residual
            noise_pred = self._forward_cond(latent_model_input, t, cond_dict=cond_dict, is_testing=True, guidance_scale=guidance_scale) # TODO: cond
            # compute the previous noisy sample x_t -> x_t-1
            latents = scheduler.step(noise_pred.jdata, t, latents).prev_sample # TODO: when there is scale model input, why there is latents
            
        # scale the latents to the original scale
        if self.hparams["scale_by_std"]:
            latents = 1. / self.scale_factor * latents
        
        return VDBTensor(grid=grids, feature=grids.jagged_like(latents))
    
    def _forward_cond(self, noisy_latents: VDBTensor, timesteps: torch.Tensor, 
                      batch = None, cond_dict = None, is_testing=False, guidance_scale=1.0) -> VDBTensor:
        do_classifier_free_guidance = guidance_scale != 1.0
        # ! adm part
        # mask condition
        if self.hparams["use_mask_cond"]:
            coords = noisy_latents.grid.grid_to_world(noisy_latents.grid.ijk.float())
            coords = VDBTensor(noisy_latents.grid, coords)
            cond = self.cond_stage_model(coords)
        # point condition
        if self.hparams["use_point_cond"]:
            coords = noisy_latents.grid.grid_to_world(noisy_latents.grid.ijk.float()) # JaggedTensor
            if self.hparams["cond_stage_model_use_normal"]:
                if batch is not None: # training-time: get normal from batch
                    ref_xyz = fvdb.JaggedTensor(batch[DS.INPUT_PC])
                    # splatting normal
                    input_normal = noisy_latents.grid.splat_trilinear(ref_xyz, fvdb.JaggedTensor(batch[DS.TARGET_NORMAL]))
                    # normalize normal
                    input_normal.jdata /= (input_normal.jdata.norm(dim=1, keepdim=True) + 1e-6) # avoid nan
            else:
                input_normal = None
            cond = self.cond_stage_model(coords, input_normal)
        # class condition:
        if self.hparams["use_class_cond"]:
            if batch is not None:
                cond = self.cond_stage_model(batch, key=DS.CLASS)
            else:
                cond = self.cond_stage_model(cond_dict, key="class") # not checked yet
        # micro condition
        if self.hparams["use_micro_cond"]:
            if batch is not None:
                micro = batch[DS.MICRO]
                micro = torch.stack(micro).float()
            else:
                micro = cond_dict['micro']
            micro = self.micro_pos_embedder(micro)
            cond = self.micro_cond_model(micro)
        
        # ! concat part            
        concat_list = []        
        # semantic condition
        if self.hparams["use_semantic_cond"]:
            # traing-time: get semantic from batch
            if batch is not None:
                input_semantic = fvdb.JaggedTensor(batch[DS.LATENT_SEMANTIC])
            else:
                input_semantic = cond_dict['semantics']
            semantic_cond = self.cond_stage_model(input_semantic.jdata.long())
            if not is_testing and self.hparams["use_classifier_free"]: # if VDBtensor, convert to JaggedTensor
                semantic_cond = self.conduct_classifier_free(semantic_cond, noisy_latents.grid.grid_count, noisy_latents.grid.device)  
            concat_list.append(semantic_cond) # ! tensor type
        # single scan concat condition
        if self.hparams["use_single_scan_concat_cond"]:
            # traing-time: get single scan crop from batch
            if batch is not None:
                single_scan = fvdb.JaggedTensor(batch[DS.SINGLE_SCAN_CROP])
                single_scan_intensity = fvdb.JaggedTensor(batch[DS.SINGLE_SCAN_INTENSITY_CROP])
            else:
                single_scan = cond_dict['single_scan']
                single_scan_intensity = cond_dict['single_scan_intensity']
                
            # here use splatting to build the single scan grid tree
            single_scan_hash_tree = self.vae.build_normal_hash_tree(single_scan)
            single_scan_grid = single_scan_hash_tree[0]            
            if self.hparams["encode_single_scan_by_points"]:
                single_scan_feat = self.single_scan_pos_embedder(single_scan, single_scan_intensity, single_scan_grid)
                single_scan_feat = VDBTensor(single_scan_grid, single_scan_feat)
            else:
                single_scan_coords = single_scan_grid.grid_to_world(single_scan_grid.ijk.float()).jdata
                single_scan_feat = self.single_scan_pos_embedder(single_scan_coords)
                single_scan_feat = VDBTensor(single_scan_grid, single_scan_grid.jagged_like(single_scan_feat))
            single_scan_cond = self.single_scan_cond_model(single_scan_feat, single_scan_hash_tree)
            # align this feature to the latent
            single_scan_cond = noisy_latents.grid.fill_to_grid(single_scan_cond.feature, single_scan_cond.grid, 0.0)
            if not is_testing and self.hparams["use_classifier_free"]:
                single_scan_cond = self.conduct_classifier_free(single_scan_cond, noisy_latents.grid.grid_count, noisy_latents.grid.device)             
            concat_list.append(single_scan_cond)
        if self.hparams["use_normal_concat_cond"]:
            # traing-time: get single scan crop from batch
            if batch is not None:
                # assert self.hparams.use_fvdb_loader is True, "use_fvdb_loader should be True for normal concat condition"
                ref_grid = fvdb.cat(batch[DS.INPUT_PC])    
                ref_xyz = ref_grid.grid_to_world(ref_grid.ijk.float()) 
                concat_normal = noisy_latents.grid.splat_trilinear(ref_xyz, fvdb.JaggedTensor(batch[DS.TARGET_NORMAL]))
            else:
                concat_normal = cond_dict['normal']
            concat_normal.jdata /= (concat_normal.jdata.norm(dim=1, keepdim=True) + 1e-6) # avoid nan
            if not is_testing and self.hparams["use_classifier_free"]:
                concat_normal = self.conduct_classifier_free(concat_normal, noisy_latents.grid.grid_count, noisy_latents.grid.device)            
            concat_list.append(concat_normal)

        if do_classifier_free_guidance and len(concat_list) > 0: # ! not tested yet
            if not self.hparams["use_classifier_free"]:
                # logger.info("Applying classifier-free guidance without doing it for concat condition")
                concat_list_copy = concat_list
            else:
                # logger.info("Applying classifier-free guidance for concat condition")    
                # assert self.hparams.use_classifier_free, "do_classifier_free_guidance should be used with use_classifier_free"
                concat_list_copy = []
                for cond in concat_list:
                    cond = self.conduct_classifier_free(cond, noisy_latents.grid.grid_count, noisy_latents.grid.device, is_testing=True)
                    concat_list_copy.append(cond)
        
        # ! corssattn part
        # text condition
        if self.hparams["use_text_cond"]:
            # traing-time: get text from batch
            if batch is not None:
                text_emb = torch.stack(batch[DS.TEXT_EMBEDDING]) # B, 77, 1024
                mask = torch.stack(batch[DS.TEXT_EMBEDDING_MASK]) # B, 77
            else:
                text_emb = cond_dict['text_emb']
                mask = cond_dict['text_emb_mask']                
            context = text_emb
            if do_classifier_free_guidance:
                context_copy = cond_dict['text_emb_null']
                mask_copy = cond_dict['text_emb_mask_null']

        # concat pos_emb
        if self.hparams["use_pos_embed"]:
            pos_embed = noisy_latents.grid.ijk
            noisy_latents = VDBTensor.cat([noisy_latents, pos_embed], dim=1)
        elif self.hparams["use_pos_embed_high"]:
            pos_embed = self.get_pos_embed_high(noisy_latents.grid.ijk.jdata)
            noisy_latents = VDBTensor.cat([noisy_latents, pos_embed], dim=1)
        elif self.hparams["use_pos_embed_world"]:
            pos_embed = noisy_latents.grid.grid_to_world(noisy_latents.grid.ijk.float())
            noisy_latents = VDBTensor.cat([noisy_latents, pos_embed], dim=1)
        elif self.hparams["use_pos_embed_world_high"]:
            pos_embed = noisy_latents.grid.grid_to_world(noisy_latents.grid.ijk.float())
            pos_embed = self.get_pos_embed_high(pos_embed.jdata)
            noisy_latents = VDBTensor.cat([noisy_latents, pos_embed], dim=1)

        if self.hparams["conditioning_key"] == 'none':
            model_pred = self.unet(noisy_latents, timesteps)
        elif self.hparams["conditioning_key"] == 'concat':
            assert len(concat_list) > 0, "concat_list should not be empty"
            noisy_latents_in = VDBTensor.cat([noisy_latents] + concat_list, dim=1)
            model_pred = self.unet(noisy_latents_in, timesteps)
            
            if do_classifier_free_guidance:
                noisy_latents_in_copy = VDBTensor.cat([noisy_latents] + concat_list_copy, dim=1)
                model_pred_copy = self.unet(noisy_latents_in_copy, timesteps)
                model_pred = VDBTensor(model_pred.grid, model_pred.grid.jagged_like(model_pred.feature.jdata + guidance_scale * (model_pred.feature.jdata - model_pred_copy.feature.jdata)))
        elif self.hparams["conditioning_key"] == 'adm':
            assert cond is not None, "cond should not be None"
            model_pred = self.unet(noisy_latents, timesteps, y=cond)
        elif self.hparams["conditioning_key"] == 'crossattn': ### !! its always crossattn
            assert context is not None, "context should not be None"
            model_pred = self.unet(noisy_latents, timesteps, context=context, mask=mask)
            
            if do_classifier_free_guidance:
                model_pred_copy = self.unet(noisy_latents, timesteps, context=context_copy, mask=mask_copy)
                model_pred = VDBTensor(model_pred.grid, model_pred.grid.jagged_like(model_pred.feature.jdata + guidance_scale * (model_pred.feature.jdata - model_pred_copy.feature.jdata)))
        elif self.hparams["conditioning_key"] == 'c_crossattn':
            assert len(concat_list) > 0, "concat_list should not be empty"
            assert context is not None, "context should not be None"
            noisy_latents_in = VDBTensor.cat([noisy_latents] + concat_list, dim=1)
            model_pred = self.unet(noisy_latents_in, timesteps, context=context, mask=mask)
            
            if do_classifier_free_guidance:
                noisy_latents_in_copy = VDBTensor.cat([noisy_latents] + concat_list_copy, dim=1)
                model_pred_copy = self.unet(noisy_latents_in_copy, timesteps, context=context_copy, mask=mask_copy)
                model_pred = VDBTensor(model_pred.grid, model_pred.grid.jagged_like(model_pred.feature.jdata + guidance_scale * (model_pred.feature.jdata - model_pred_copy.feature.jdata)))
        else:
            raise NotImplementedError

        return model_pred

In [None]:
vae = None  # Placeholder for the trained autoencoder // will be  gotten from torch.load_state_dict()


In [None]:
diffusion_model = DiffusionModel(
    hparams=hparams,
    num_classes=hparams["num_classes"],
    use_spatial_transformer=False,
    context_dim=diffuser_kwargs["context_dim"],
    diffuser_kwargs=diffuser_kwargs,
    noise_scheduler_params=noise_scheduler_params,
    trained_autoencoder=vae  #pre-trained VAE model
)

AttributeError: 'NoneType' object has no attribute 'ema'

In [None]:
def lambda_lr_wrapper(it, lr_config, batch_size, accumulate_grad_batches=1):
    return max(
        lr_config['decay_mult'] ** (int(it * batch_size * accumulate_grad_batches / lr_config['decay_step'])),
        lr_config['clip'] / lr_config['init'])

In [None]:
weight_decay = 0.0
batch_size = 64
use_input_normal = False
use_input_semantic = False
with_semantic_branch = False
use_input_intensity = False
max_text_len = 77
text_embed_drop_prob = 0.1


train_dataset = "ObjaverseDataset"
train_val_num_workers= 16
train_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "train",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "text_embed_drop_prob": text_embed_drop_prob, # ! classifier-free training
  "random_seed": 0
}

val_dataset = "ObjaverseDataset"
val_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "test",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "random_seed": "fixed"
}

test_dataset = "ObjaverseDataset"
test_num_workers =8
test_kwargs = {
  "onet_base_path": _objaverse_path ,
  "resolution": resolution,
  "custom_name": _custom_name,
  "split_base_path": _split_path,
  "split": "test",
  "text_emb_path": _text_emb_path,
  "null_embed_path": _null_embed_path,
  "max_text_len": max_text_len,
  "random_seed": "fixed"
}



In [None]:
optimizer = torch.optim.AdamW(diffusion_model.parameters(), lr=learning_rate["init"],
                                    weight_decay=weight_decay, amsgrad=True)

scheduler = LambdaLR(optimizer,
                lr_lambda=functools.partial(
                    lambda_lr_wrapper, lr_config=learning_rate, batch_size=batch_size))

# exp.global_var_manager.register_variable('skip_backward', False)

def list_collate(batch):
    """
    This just do not stack batch dimension.
    """
    
    elem = None
    for e in batch:
        if e is not None:
            elem = e
            break
    elem_type = type(elem)
    if isinstance(elem, torch.Tensor):
        return batch
    elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
            and elem_type.__name__ != 'string_':
        if elem_type.__name__ == 'ndarray' or elem_type.__name__ == 'memmap':
            return list_collate([torch.as_tensor(b) if b is not None else None for b in batch])
        elif elem.shape == ():  # scalars
            return torch.as_tensor(batch)
    elif isinstance(elem, float):
        return torch.tensor(batch, dtype=torch.float64)
    elif isinstance(elem, int):
        return torch.tensor(batch)
    elif isinstance(elem, str):
        return batch
    elif isinstance(elem, DictConfig) or isinstance(elem, ListConfig):
        return batch
    elif isinstance(elem, collections.abc.Mapping):
        return {key: list_collate([d[key] for d in batch]) for key in elem}
    elif isinstance(elem, collections.abc.Sequence):
        # check to make sure that the elements in batch have consistent size
        it = iter(batch)
        elem_size = len(next(it))
        if not all(len(elem) == elem_size for elem in it):
            raise RuntimeError('each element in list of batch should be of equal size')
        transposed = zip(*batch)
        return [list_collate(samples) for samples in transposed]
    elif isinstance(elem, GridBatch):
        return fvdb.jcat(batch)
    
    return batch

def get_dataset_spec():
    all_specs = [DS.SHAPE_NAME, DS.INPUT_PC,
                    DS.GT_DENSE_PC, DS.GT_GEOMETRY]
    if use_input_normal:
        all_specs.append(DS.TARGET_NORMAL)
        all_specs.append(DS.GT_DENSE_NORMAL)
    if use_input_semantic or with_semantic_branch:
        all_specs.append(DS.GT_SEMANTIC)
    if use_input_intensity:
        all_specs.append(DS.INPUT_INTENSITY)
    return all_specs


def train_dataloader():
    from data.objaverse import ObjaverseDataset
    train_set =  ObjaverseDataset(onet_base_path=train_kwargs["onet_base_path"], 
                                  spec=get_dataset_spec(), 
                                  split=train_kwargs["split"], 
                                  resolution=train_kwargs["resolution"], 
                                  image_base_path=None, 
                                  random_seed=0, 
                                  hparams=None, 
                                  skip_on_error=False, 
                                  custom_name="objaverse", 
                                  text_emb_path="../data/objaverse/objaverse/text_emb", 
                                  null_embed_path="./assets/null_text_emb.pkl", 
                                  text_embed_drop_prob=0.0, 
                                  max_text_len=77, 
                                  duplicate_num=1, 
                                  split_base_path=_split_path,
                                  )
        
    return DataLoader(train_set, batch_size=batch_size, shuffle=True,
                        num_workers=train_val_num_workers, collate_fn=list_collate)


# print(get_dataset_spec())

def val_dataloader():
    from data.objaverse import ObjaverseDataset
    val_set = ObjaverseDataset(onet_base_path=val_kwargs["onet_base_path"],
                                spec=get_dataset_spec(), 
                                split=val_kwargs["split"], 
                                resolution=val_kwargs["resolution"], 
                                image_base_path=None, 
                                random_seed=0, 
                                hparams=None, 
                                skip_on_error=False, 
                                custom_name="objaverse", 
                                text_emb_path="../data/objaverse/objaverse/text_emb", 
                                null_embed_path="./assets/null_text_emb.pkl", 
                                text_embed_drop_prob=0.0, 
                                max_text_len=77, 
                                duplicate_num=1, 
                                split_base_path=_split_path,
                                )


    return DataLoader(val_set, batch_size=batch_size, shuffle=False,
                        num_workers=train_val_num_workers, collate_fn=list_collate)

def test_dataloader(resolution=resolution, test_set_shuffle=False):
    from data.objaverse import ObjaverseDataset
    resolution = resolution # ! use for testing when training on X^3 but testing on Y^3

    test_set =  ObjaverseDataset(onet_base_path=test_kwargs["onet_base_path"],
                                spec=get_dataset_spec(), 
                                split=test_kwargs["split"], 
                                resolution=resolution, 
                                image_base_path=None, 
                                random_seed=0, 
                                hparams=None, 
                                skip_on_error=False, 
                                custom_name="objaverse", 
                                text_emb_path="../data/objaverse/objaverse/text_emb", 
                                null_embed_path="./assets/null_text_emb.pkl", 
                                text_embed_drop_prob=0.0, 
                                max_text_len=77, 
                                duplicate_num=1, 
                                split_base_path=_split_path,
                                )
    
    if test_set_shuffle:
        torch.manual_seed(0)
    return DataLoader(test_set, batch_size=1, shuffle=test_set_shuffle, 
                        num_workers=0, collate_fn=list_collate)


In [None]:
@exp.mem_profile(every=1)
def compute_loss(batch, out, compute_metric: bool):
    loss_dict = exp.TorchLossMeter()
    metric_dict = exp.TorchLossMeter()

    # compute the MSE loss
    if hparams["mse_weight"] > 0.0:
        loss_dict.add_loss("mse", F.mse_loss(out["pred"], out["target"]), hparams["mse_weight"])
    if compute_metric: # currently use MSE as metric
        metric_dict.add_loss("mse", F.mse_loss(out["pred"], out["target"]))

    return loss_dict, metric_dict 

In [None]:
def log_dict_prefix(
    self,
    prefix: str,
    dictionary: Mapping[str, Any],
    prog_bar: bool = False,
    logger: bool = True,
    on_step: Optional[bool] = None,
    on_epoch: Optional[bool] = None,
):
    """
    This overrides fixes if dict key is not a string...
    """
    dictionary = {
        prefix + "/" + str(k): v for k, v in dictionary.items()
    }
    self.log_dict(dictionary=dictionary,
                    prog_bar=prog_bar,
                    logger=logger, on_step=on_step, on_epoch=on_epoch)

In [None]:
def log_image(self, name: str, img, step: Optional[int] = None):
    # if self.trainer.logger is not None:
        # if self.logger_type == 'tb':
        if img.shape[2] <= 4:
            # WHC -> CWH
            img = np.transpose(img, (2, 0, 1))
        # self.trainer.logger.experiment.add_image(name, img, self.trainer.global_step)

In [None]:
train_loss = 0

for batch_idx, batch in enumerate(train_dataloader()):
    batch_loss = 0
    
    is_val = False
    if batch_idx % 10 == 0:
        is_val = True

    if batch_idx % 100 == 0:
        gc.collect()
        torch.cuda.empty_cache()

    out = {'idx': batch_idx}
    with exp.pt_profile_named("forward"):
        out = diffusion_model(batch, out)

    if out is None and not is_val:
        # return None
        print("out is None???")

    # Compute metric in train would be the out['pred'] and out['target']
    loss_dict, metric_dict = compute_loss(batch, out, compute_metric=is_val)

    if not is_val:
        log_dict_prefix('train_loss', loss_dict)
    else:
        log_dict_prefix('val_metric', metric_dict)
        log_dict_prefix('val_loss', loss_dict)

        if hparams.log_image:
            cond_dict = {}
            
            if hparams.use_text_cond:
                text_emb = torch.stack(batch[DS.TEXT_EMBEDDING]) # B, 77, 1024
                mask = torch.stack(batch[DS.TEXT_EMBEDDING_MASK]) # B, 77
                cond_dict['text_emb'] = text_emb
                cond_dict['text_emb_mask'] = mask
            
            if batch_idx == 0:
                with diffusion_model.ema_scope("Plotting"):
                    # first extract latent

                    clean_latents = diffusion_model.extract_latent(batch)
                    grids = clean_latents.grid

                    # sample latents
                    sample_latents = diffusion_model.random_sample_latents(grids, use_ddim=hparams.use_ddim, ddim_step=100, cond_dict=cond_dict) # TODO: change this ddim_step to variable
                    
                    # decode clean latents first
                    decode_clean = decode_to_points(clean_latents)
                    # Decode sample latents
                    decode_sample = decode_to_points(sample_latents)
                    sample = np.concatenate([decode_clean, decode_sample], axis=0)
                    log_image("img/sample", sample)
                    # clean matplotlib opens
                    plt.close('all')

                    clean_latents = diffusion_model.extract_latent(batch)
                    grids = clean_latents.grid
                    _ = diffusion_model.random_sample_latents(grids, use_ddim=hparams.use_ddim, ddim_step=100, cond_dict=cond_dict) # TODO: change this ddim_step to variable

    loss_sum = loss_dict.get_sum()

    batch_loss += loss_sum
    if hparams.ema:
        diffusion_model.unet_ema(diffusion_model.unet)    


