# NeRF: Collision Handling in Instant Neural Graphics Primitives
Federico Montagna (fedemonti00@gmail.com)

# Code

## Imports

In [None]:
# Types
from typing import List, Tuple, Dict, Any, Optional, Union, Callable, TypeVar

# Numpy
import numpy as np

# Scikit-learn
from sklearn.model_selection import train_test_split

# Matplotlib
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

# Einops
from einops import rearrange, reduce, repeat

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam, SGD, AdamW
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, CosineAnnealingWarmRestarts
from torchvision import io

# datetime
try:
    from zoneinfo import ZoneInfo
except ImportError:
    from backports.zoneinfo import ZoneInfo

from datetime import datetime

# other
import traceback
import inspect
import os
import random
import wandb
from tqdm import tqdm


# Report CUDA availability and pick the first GPU, falling back to the CPU.
cuda_available = torch.cuda.is_available()
print("Cuda available:", cuda_available)
if cuda_available:
    for i in range(torch.cuda.device_count()):
        print(f"Available device {i}:", torch.cuda.get_device_name(i))

device = torch.device('cuda:0' if cuda_available else "cpu")
if cuda_available:
    print(f"Current device {torch.cuda.current_device()}:", torch.cuda.get_device_name(torch.cuda.current_device()))

# torch.set_default_device(device)

# Seed every RNG the notebook uses so runs are reproducible.
random_seed = 31504 # alternative seed: 4129
# random_seed = np.random.randint(0, (2**16 - 1))
random.seed(random_seed)
np.random.seed(random_seed)
# torch.random.manual_seed is the very same function as torch.manual_seed,
# so a single call is enough.
torch.manual_seed(random_seed)
if cuda_available:
    torch.cuda.manual_seed_all(random_seed)
print("Random seed:", random_seed)

os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"

plt.style.use("ggplot")

## Debug functions

In [None]:
class bcolors:
    """ANSI escape sequences used to colour debug output; ``ENDC`` resets the style."""

    # Foreground colours (ANSI SGR 9x = bright colours)
    HEADER = '\033[95m'  # bright magenta
    OKBLUE = '\033[94m'  # bright blue
    OKCYAN = '\033[96m'  # bright cyan
    OKGREEN = '\033[92m'  # bright green
    WARNING = '\033[93m'  # bright yellow
    FAIL = '\033[91m'  # bright red
    # Reset and text attributes
    ENDC = '\033[0m'  # reset all attributes
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

# Debug "level" registered for each function name. When `log` receives a list,
# it prints only if the calling function's level appears in that list.
decoded_functions = dict(
    # Dataset
    __getitem__=0,
    # General
    __init__=0,
    forward=1,
    train_loop=1,
    test_loop=1,
    plot_images=2,
    # MultiResolution
    _multiresolution_hash=2,
    _scale_to_grid=3,
    _calc_bilinear_coefficients=4,
    _calc_dummies=5,
    _calc_hash_collisions=6,
    _calc_uniques=7,
    _hist_collisions=8,
    # GeneralNeuralGaugeFields
    _look_up_features=2,
    _bilinear_interpolation=3,
    # Loss
    _calc_hist_pdf=2,
    _kl_div=3,
    differentiable_histogram=4,
    # Test loop
    create_indices_mapping=1,
)


def log(texts, allowed: List | bool, color: bcolors = bcolors.OKCYAN) -> None:
    """
    Print debug values together with the source line that requested them.

    Parameters
    ----------
    texts : Any
        Values to print. If iterable, its elements are printed space-separated
        (``print(*texts)``); otherwise the object itself is printed.
    allowed : List[int] | bool
        ``True``/``False`` enables/disables printing unconditionally. A list
        enables printing only when the calling function's level in
        ``decoded_functions`` is contained in it. Callers that are not
        registered in ``decoded_functions`` are never printed for lists.
    color : str, optional (default is bcolors.OKCYAN)
        ANSI escape sequence used for the header and footer lines.

    Returns
    -------
    None
    """
    if isinstance(allowed, bool):
        should_log = allowed
    elif isinstance(allowed, list) and allowed:
        # Only the direct caller's name is needed. Reading the parent frame avoids
        # inspect.stack(), which walks the whole stack and loads source context on
        # every call (log is called many times per forward pass).
        frame = inspect.currentframe()
        caller = frame.f_back if frame is not None else None
        caller_name = caller.f_code.co_name if caller is not None else None
        del frame, caller  # drop frame references to avoid reference cycles
        should_log = decoded_functions.get(caller_name) in allowed
    else:
        # Empty list (or any other type): nothing to log, skip all stack work.
        should_log = False

    if not should_log:
        return

    # [-1] is this function's frame, [-2] is the caller's.
    calling_line = traceback.extract_stack()[-2].line
    print(color, "Line: ", calling_line, bcolors.ENDC)

    try:
        print(*texts)
    except TypeError:
        # texts is not iterable: print it as a single object.
        print(texts)

    print(color, "-"*20, bcolors.ENDC)


def print_allocated_memory(log: bool = True):
    """
    Print the CUDA memory currently and maximally allocated by tensors.

    Parameters
    ----------
    log : bool, optional (default is True)
        When False the function returns immediately without printing.

    Returns
    -------
    None
    """
    if not log:
        return

    bytes_per_gb = 1024 ** 3
    # [-2] of the extracted stack / [1] of inspect.stack() is the caller's frame.
    caller_line = traceback.extract_stack()[-2].line
    caller_name = inspect.stack()[1][0].f_code.co_name
    print(bcolors.HEADER, "Line: ", caller_line, ", from: ", caller_name, bcolors.ENDC)

    current_gb = torch.cuda.memory_allocated() / bytes_per_gb
    print(f"Allocated Memory: {current_gb:.2f} GB")

    peak_gb = torch.cuda.max_memory_allocated() / bytes_per_gb
    print(f"Peak Allocated Memory: {peak_gb:.2f} GB")

    print(bcolors.OKCYAN, "-"*20, bcolors.ENDC)


def plot_images(outs: np.ndarray, targets: np.ndarray, allowed: List = [], is_test: bool = False) -> None:
    """
    Show each output image next to its target, one pair per row.

    Parameters
    ----------
    outs : np.ndarray
        Images produced by the model, indexed along the first axis.
    targets : np.ndarray
        Ground-truth images; ``targets[i]`` is drawn next to ``outs[i]``.
    allowed : List, optional (default is [])
        Debug levels to plot. Nothing is drawn unless the level registered for
        ``plot_images`` in ``decoded_functions`` is in this list.
    is_test : bool, optional (default is False)
        Title the left column "Output" instead of "Prediction".

    Returns
    -------
    None
    """
    # The function's own level is a constant lookup; no stack inspection needed.
    if not allowed or decoded_functions["plot_images"] not in allowed:
        return

    rows = outs.shape[0]
    cols = 2

    # squeeze=False always yields a 2-D (rows, cols) array of axes, even for one row.
    fig, axs = plt.subplots(rows, cols, figsize=(5 * cols, 5 * rows), squeeze=False)
    for r in range(rows):
        axs[r, 0].imshow(outs[r])
        axs[r, 0].set_title("Prediction" if not is_test else "Output")
        axs[r, 1].imshow(targets[r])
        axs[r, 1].set_title("Target")
    plt.show()
    # Release the figure; non-interactive backends otherwise keep every figure open.
    plt.close(fig)


## Load Wandb Api Key

In [None]:
# apikey_path = ".wandb_apikey.txt"
# if os.path.exists(apikey_path):
#     with open(apikey_path, "r") as f:
#         apikey = f.read()
#         !wandb login {apikey} # --relogin

## Dataset

In [None]:
class ImageDataset(Dataset):
    def __init__(
        self, 
        images_paths: List[str],
        should_randomize_input: bool = False,
        should_log: List[int] = []
    ) -> None:
        """
        
        Parameters
        ----------
        images_paths : List[str]
            List of paths to images.
        should_randomize_input : bool, optional (default is False)
            Should randomize the images.
        should_log : List[int], optional (default is [])
            List of decoded functions to log.
        
        Returns
        -------
        None
        """
        super(ImageDataset, self).__init__()

        self._images_paths: List[str] = images_paths
        self._should_randomize_input: bool = should_randomize_input
        self._should_log: List[int] = should_log

    def __getitem__(self, idx: torch.Tensor or int) -> Dict[str, torch.Tensor]:
        """

        Parameters
        ----------
        idx : torch.Tensor or int
            Index of the image.
        
        Returns
        -------
        Dict[str, torch.Tensor]
            Dictionary with image and target.
        
        Raises
        ------
        RuntimeError
            If error occurs while reading images.
        """

        if idx == -1:
            idx = torch.arange(len(self._images_paths), device="cpu")
        
        if not torch.is_tensor(idx):
            idx = torch.tensor([idx], device="cpu")
        
        try:
            images: torch.Tensor = torch.stack([
                rearrange(
                    io.read_image(self._images_paths[id]).cpu(),
                    "rgb h w -> h w rgb"
                )
                for id in idx
            ])
        except Exception as e:
            raise RuntimeError(f"Error while reading images: {e}")
        
        h, w = images.shape[1], images.shape[2]
        images_shape = h * w

        reordered_indices: torch.Tensor = torch.zeros(
            (images.shape[0], images_shape),
            dtype=torch.int64,
            device="cpu"
        )
        log(("Reordered indices:", reordered_indices.shape), self._should_log)

        X: torch.Tensor = torch.zeros((images.shape[0], images_shape, 2), device="cpu")
        log(("X:", X.shape), self._should_log)

        Y: torch.Tensor = torch.zeros((images.shape[0], images_shape, images.shape[-1]), device="cpu")
        log(("Y:", Y.shape), self._should_log)

        for i in range(images.shape[0]):
            if self._should_randomize_input:
                shuffled_indices: torch.Tensor = (
                    torch.randperm(images_shape, device="cpu") 
                )
                log(("Shuffled indices:", shuffled_indices, shuffled_indices.shape), self._should_log)

                reordered_indices[i][shuffled_indices] = torch.arange(images_shape, device="cpu")
            else:
                reordered_indices[i] = torch.arange(images_shape, device="cpu")

            X[i] = torch.tensor(
                np.stack(np.meshgrid(range(h), range(w), indexing="ij"), axis=-1).reshape(-1, 2)
            )
        
            Y[i] = rearrange(
                images[i],
                "h w rgb -> (h w) rgb"
            )
            
            if self._should_randomize_input:
                X[i] = X[i][shuffled_indices]
                Y[i] = Y[i][shuffled_indices]

        X = (
            X.float() / max(h, w)
        ).unsqueeze(-1).unsqueeze(-1)

        Y = Y.float() / 255
        
        to_return = {
            "X": X,
            "Y": Y,
            "h": h,
            "w": w,
            "reordered_indices": reordered_indices
        }

        return to_return

    def __len__(self) -> int:
        return len(self._images_paths)

## Models

### Backward Pass Differentiable Approximation
[Pytorch-BPDA](https://github.com/kitayama1234/Pytorch-BPDA)

In [None]:
class _StraightThroughEstimator(torch.autograd.Function):
    """Apply a non-differentiable function forward, pass the gradient through unchanged."""

    @staticmethod
    def forward(ctx, x, round_function):
        return round_function(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Identity gradient for x; round_function is not a tensor, so it gets None.
        return grad_output, None


def differentiable_round(x, round_function=torch.round):
    """
    Round ``x`` in the forward pass while keeping an identity gradient (BPDA / STE).

    Unlike the previous ``register_hook`` trick, this also works when ``x`` does
    not require grad or when called under ``torch.no_grad()`` (e.g. evaluation),
    where registering a hook raises ``RuntimeError``.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    round_function : Callable[[torch.Tensor], torch.Tensor], optional (default is torch.round)
        Non-differentiable function applied in the forward pass (e.g. torch.floor).

    Returns
    -------
    torch.Tensor
        ``round_function(x)``; its gradient w.r.t. ``x`` is the identity.
    """
    return _StraightThroughEstimator.apply(x, round_function)

### Differentiable Histogram
[pytorch-differentiable-histogram](https://github.com/hyk1996/pytorch-differentiable-histogram)


In [None]:
#############################################
# Differentiable Histogram Counting Method
#############################################
# https://github.com/hyk1996/pytorch-differentiable-histogram

# TAKES UP TO 1GB OF GPU FOR EACH LEVEL for macaw2
def differentiable_histogram(x, bins=255, min=0.0, max=1.0, should_log: List[int] = []):
    """
    Soft (differentiable) histogram of ``x`` built from triangular bin kernels.

    Each value contributes linearly to the two bin centres it lies between, so
    gradients can flow back to ``x``. Bin centres are ``min + r * delta`` with
    ``delta = (max - min) / bins``. As in the reference implementation, only
    bins ``1 .. bins - 2`` are accumulated; the first and last bins stay zero.

    Parameters
    ----------
    x : torch.Tensor
        Values to histogram, either 4-D (n_samples, n_chns, H, W) or 2-D
        (treated as a single sample with a single channel).
    bins : int, optional (default is 255)
        Number of bins.
    min, max : float, optional (default is 0.0 and 1.0)
        Value range covered by the bins.
    should_log : List[int], optional (default is [])
        List of decoded functions to log.

    Returns
    -------
    torch.Tensor
        (n_samples, n_chns, bins) histogram on ``x.device``, divided by ``delta``.

    Raises
    ------
    AssertionError
        If ``x`` is neither 2-D nor 4-D.
    """
    if x.dim() == 4:
        n_samples, n_chns, _, _ = x.shape
    elif x.dim() == 2:
        n_samples, n_chns = 1, 1
    else:
        raise AssertionError('The dimension of input tensor should be 2 or 4.')

    hist_torch = torch.zeros(n_samples, n_chns, bins, device=x.device)
    log(("hist_torch:", hist_torch.shape), should_log)
    delta = (max - min) / bins

    # Bin centres must be offset by `min`; previously `min` was ignored here, which
    # shifted every bin whenever min != 0. For min == 0 the values are unchanged.
    BIN_Table = min + torch.arange(start=0, end=bins, step=1) * delta
    log(("BIN_Table:", BIN_Table.shape), should_log)
    # Read the centres back as Python floats once, instead of .item() per iteration.
    centres = BIN_Table.tolist()

    for dim in range(1, bins - 1):
        h_r_sub_1 = centres[dim - 1]   # h_(r-1)
        h_r = centres[dim]             # h_r
        h_r_plus_1 = centres[dim + 1]  # h_(r+1)

        mask_sub = ((h_r > x) & (x >= h_r_sub_1)).float()
        mask_plus = ((h_r_plus_1 > x) & (x >= h_r)).float()

        hist_torch[:, :, dim] += torch.sum(((x - h_r_sub_1) * mask_sub).reshape(n_samples, n_chns, -1), dim=-1)
        hist_torch[:, :, dim] += torch.sum(((h_r_plus_1 - x) * mask_plus).reshape(n_samples, n_chns, -1), dim=-1)

    log(("hist_torch:", hist_torch.shape), should_log)

    return hist_torch / delta

### Hash Function Model

In [None]:
class LearnableHashFunctionModel(nn.Module):
    """
    Small MLP used as a learnable hash function.

    The network ends in a sigmoid. Its first output is scaled to
    ``[0, hash_table_size - 1]`` and rounded with a straight-through estimator
    (the hash index); its second output is multiplied by ``sigmas_scale`` and
    returned as the index's uncertainty.
    """

    def __init__(
        self,
        hidden_layers_widths: List[int],
        input_size: int = 2,
        output_size: int = 2,
        hash_table_size: int = 2**14,
        sigmas_scale: float = 1.0,
        hidden_layers_activation: nn.Module = nn.Tanh(),
        dropout_rate: float | None = None,
        should_log: List[int] = [],
    ) -> None:
        """
        Build the MLP.

        Parameters
        ----------
        hidden_layers_widths : List[int]
            Width of every hidden layer, in order.
        input_size : int, optional (default is 2)
            Number of input features per coordinate.
        output_size : int, optional (default is 2)
            Number of outputs; ``forward`` reads the first two (index, sigma).
        hash_table_size : int, optional (default is 2**14)
            Size of the hash table the indices point into.
        sigmas_scale : float, optional (default is 1.0)
            Multiplier applied to the sigma output.
        hidden_layers_activation : nn.Module, optional (default is nn.Tanh())
            Activation after every hidden layer (the same instance is shared
            by all hidden layers); the output layer always uses nn.Sigmoid.
        dropout_rate : float, optional (default is None)
            If given, dropout applied to the final activations.
        should_log : List[int], optional (default is [])
            List of decoded functions to log.

        Returns
        -------
        None
        """
        super(LearnableHashFunctionModel, self).__init__()

        self._input_size: int = input_size
        self._output_size: int = output_size
        self._hash_table_size: int = hash_table_size
        self._sigmas_scale: float = sigmas_scale

        if dropout_rate is None:
            self._dropout: nn.Module = None
        else:
            self._dropout = nn.Dropout(dropout_rate)

        self._should_log: List[int] = should_log

        widths = [input_size, *hidden_layers_widths, output_size]
        output_layer_idx = len(widths) - 2
        stages = []
        for layer_idx, (fan_in, fan_out) in enumerate(zip(widths[:-1], widths[1:])):
            linear = nn.Linear(
                in_features=fan_in,
                out_features=fan_out,
                device=device
            )
            activation = nn.Sigmoid() if layer_idx == output_layer_idx else hidden_layers_activation
            stages.append(nn.Sequential(linear, activation))
        self._module_list: nn.ModuleList = nn.ModuleList(stages)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Hash grid coordinates into table indices.

        Parameters
        ----------
        x : torch.Tensor
            Grid coordinates; the last axis must have ``input_size`` features.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Rounded hash indices and their sigmas, each with a trailing
            singleton axis added (``unsqueeze(-1)`` before slicing).
        """
        log(("x:", x, x.shape, x.requires_grad, x.is_leaf), self._should_log)

        out = x
        for i, layer in enumerate(self._module_list):
            out = layer(out)
            log((f"After layer {i}:", out, out.shape, out.requires_grad, out.is_leaf), self._should_log)

        if self._dropout is not None:
            # NOTE(review): nn.Dropout rescales kept values by 1 / (1 - p) in
            # training mode, so sigmoid outputs can exceed 1 here and produce
            # indices above hash_table_size - 1. Confirm the lookup handles this.
            out = self._dropout(out)
            log(("After dropout:", out, out.shape, out.requires_grad, out.is_leaf), self._should_log)
            out = torch.nan_to_num(out)  # NaN -> 0.0
            out[out == 0.0] = 1e-10  # then 0.0 -> 1e-10 (avoid exact zeros)

        out = out.unsqueeze(-1)
        index_scale = self._hash_table_size - 1
        indices = differentiable_round(out[..., 0, :] * index_scale)
        log(("indices:", indices, indices.shape, indices.requires_grad, indices.is_leaf), self._should_log)
        sigmas = out[..., 1, :] * self._sigmas_scale
        log(("sigmas:", sigmas, sigmas.shape, sigmas.requires_grad, sigmas.is_leaf), self._should_log)

        return indices, sigmas
    
    def get_hash_table_size(self) -> int:
        """Return the size of the hash table indices point into."""
        size = self._hash_table_size
        return size
    
    def get_input_size(self) -> int:
        """Return the number of input features per coordinate."""
        size = self._input_size
        return size

###  Multiresolution Model

In [None]:
class MultiresolutionModel(nn.Module):
    def __init__(
        self,
        n_min: int,
        n_max: int,
        num_levels: int,
        hashModel: LearnableHashFunctionModel | None,
        hash_table_size: int | None = None,
        input_size: int | None = None,
        should_use_all_levels: bool = False,
        should_fast_hash: bool = False,
        should_calc_collisions: bool = False,
        should_normalize_grid_coords: bool = False,
        should_log: List[int] = [],
    ) -> None:
        """
        Configure the resolution levels and the hashing strategy.

        Parameters
        ----------
        n_min : int
            Resolution of the coarsest level.
        n_max : int
            Resolution the finest level grows towards.
        num_levels : int
            Number of levels; resolutions are ``floor(n_min * b**l)`` with
            ``b = exp((ln n_max - ln n_min) / (num_levels - 1))``.
        hashModel : LearnableHashFunctionModel | None
            Hash function model.
        hash_table_size : int | None, optional (default is None)
            Hash table size. If None then hashModel's hash_table_size is used.
        input_size : int | None, optional (default is None)
            Input size. If None then hashModel's input_size is used.
        should_use_all_levels : bool, optional (default is False)
            Whether to use all levels or only the ones with collisions.
        should_fast_hash : bool, optional (default is False)
            Whether to use fast hash instead of HashFunction or not.
        should_calc_collisions : bool, optional (default is False)
            Whether to calculate collisions or not.
        should_normalize_grid_coords : bool, optional (default is False)
            Whether to normalize grid coordinates or not.
        should_log : List[int], optional (default is [])
            List of decoded functions to log.

        Returns
        -------
        None
        """
        super(MultiresolutionModel, self).__init__()

        self._n_min: int = n_min
        self._n_max: int = n_max
        self._num_levels: int = num_levels

        self.hashModel: LearnableHashFunctionModel = hashModel
        # Fall back to the hash model's own configuration when not given explicitly.
        if hash_table_size is None:
            hash_table_size = self.hashModel.get_hash_table_size()
        self._hash_table_size: int = hash_table_size
        if input_size is None:
            input_size = self.hashModel.get_input_size()
        self._input_size: int = input_size

        self._should_use_all_levels: bool = should_use_all_levels
        self._should_fast_hash: bool = should_fast_hash
        self._should_calc_collisions: bool = should_calc_collisions
        self._should_normalize_grid_coords: bool = should_normalize_grid_coords
        self._should_log: List[int] = should_log

        # Geometric growth factor between consecutive levels.
        log_growth = (np.log(n_max) - np.log(n_min)) / (self._num_levels - 1)
        b: torch.Tensor = torch.tensor(np.exp(log_growth)).float()
        if b <= 1 or b > 2:
            print(
                f"The between level scale is recommended to be <= 2 and needs to be > 1 but was {b:.4f}."
            )

        per_level = [torch.floor(n_min * (b ** l)) for l in range(self._num_levels)]
        # Shape (1, 1, levels, 1) so it broadcasts against (batch, pixels, xyz, 1, 1).
        self._levels: torch.Tensor = torch.stack(per_level).reshape(1, 1, -1, 1).to(device)
        log(("Levels:", self._levels, self._levels.shape, self._levels.requires_grad, ), self._should_log)

        # 0/1 offsets of the cell corners. The third axis has input_size - 1 entries,
        # so it only spans {0} for 2-D inputs and {0, 1} for 3-D inputs.
        offsets = np.stack(
            np.meshgrid(range(2), range(2), range(self._input_size - 1), indexing="ij"),
            axis=-1,
        )
        corners = rearrange(
            torch.tensor(offsets, device=device),
            "cols rows depths verts -> (depths rows cols) verts"
        )
        self._voxels_helper_hypercube: torch.Tensor = corners.T[:self._input_size, :].unsqueeze(0).unsqueeze(2)
        log(("voxels_helper_hypercube:", self._voxels_helper_hypercube, self._voxels_helper_hypercube.shape, self._voxels_helper_hypercube.requires_grad), self._should_log)

        if self._should_fast_hash:
            # Per-dimension multipliers of the fast (XOR) hash; not trainable.
            primes = torch.from_numpy(np.array([1, 2654435761, 805459861])).to(device)
            self._prime_numbers = torch.nn.Parameter(primes, False)

    def forward(
        self, 
        x: torch.Tensor, 
        should_calc_hists: bool = False, 
        should_show_hists: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor, List[plt.Figure] | None]:
        """
        Hash the multiresolution grid vertices surrounding every input coordinate.

        Parameters
        ----------
        x : torch.Tensor
            Input coordinates; presumably (batch, pixels, xyz, 1, 1) as
            produced by ImageDataset — TODO confirm.
        should_calc_hists : bool, optional (default is False)
            Build one histogram figure per level of the hashed indices.
        should_show_hists : bool, optional (default is False)
            Display each histogram while building it.

        Returns
        -------
        Tuple
            hashed indices, sigmas, minimum possible collisions per level,
            actual collisions per level (None unless should_calc_collisions),
            bilinear coefficients, and histogram figures (None unless requested).
        """
        log(("x:", x, x.shape, x.requires_grad, x.is_leaf), self._should_log)

        grid_coords, coeffs = self._scale_to_grid(x)
        log(("grid_coords:", grid_coords, grid_coords.shape, grid_coords.requires_grad, grid_coords.is_leaf), self._should_log)
        log(("coeffs:", coeffs, coeffs.shape, coeffs.requires_grad, coeffs.is_leaf), self._should_log)

        dummy_grids, og_indices, _unused = self._calc_dummies(grid_coords=grid_coords)
        min_possible_collisions, _unused = self._calc_hash_collisions(dummy_grids=dummy_grids)
        log(("dummy_grids:", dummy_grids, [dummy_grids[l].shape for l in range(self._num_levels)]), self._should_log)
        log(("og_indices:", og_indices, [(og_indices[l], og_indices[l].shape) for l in range(self._num_levels)]), self._should_log)
        log(("min_possible_collisions:", min_possible_collisions, min_possible_collisions.shape, min_possible_collisions.requires_grad), self._should_log)

        # Pick the hashing strategy.
        if self._should_fast_hash:
            hashed, sigmas = self._fast_hash(grid_coords)
        elif self._should_use_all_levels:
            hashed, sigmas = self.hashModel(grid_coords)
        else:
            hashed, sigmas = self._multiresolution_hash(grid_coords, min_possible_collisions)
        del grid_coords  # no longer needed; free it before the remaining work
        log(("hashed:", hashed, hashed.shape, hashed.requires_grad, hashed.is_leaf), self._should_log)
        log(("sigmas:", sigmas, sigmas.shape, sigmas.requires_grad, sigmas.is_leaf), self._should_log)

        dummy_hashed = self._calc_dummies(hashed=hashed)[2]
        log(("dummy_hashed:", dummy_hashed, [dummy_hashed[l].shape for l in range(self._num_levels)]), self._should_log)

        if self._should_calc_collisions:
            collisions = self._calc_hash_collisions(dummy_grids=dummy_grids, dummy_hashed=dummy_hashed)[1]
            del dummy_grids, dummy_hashed
            log(("collisions:", collisions, collisions.shape, collisions.requires_grad, collisions.is_leaf), self._should_log)
        else:
            collisions = None

        if should_calc_hists:
            hists = self._hist_collisions(hashed, og_indices, min_possible_collisions, should_show=should_show_hists)
        else:
            hists = None

        return hashed, sigmas, min_possible_collisions, collisions, coeffs, hists

    @torch.no_grad()
    def _scale_to_grid(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scale coordinates to every level and collect the vertices of the enclosing cell.

        Parameters
        ----------
        x : torch.Tensor
            Coordinates to scale. (batch, pixels, xyz, 1, 1)

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Cell-vertex grid coordinates (batch, pixels, levels, verts, xyz),
            divided by each level's resolution when normalisation is enabled,
            and the bilinear interpolation coefficients.
        """

        scaled_coords: torch.Tensor = x.float() * self._levels # (batch, pixels, xyz, levels, 1)
        log(("scaled_coords:", scaled_coords, scaled_coords.shape, scaled_coords.requires_grad, scaled_coords.is_leaf), self._should_log)

        # floor() gives the cell's lowest corner; adding the 0/1 offsets yields all corners.
        grid_coords: torch.Tensor = rearrange(
            torch.add(
                torch.floor(scaled_coords),
                self._voxels_helper_hypercube
            ),
            "batch pixels xyz levels verts -> batch pixels levels verts xyz"
        )
        log(("grid_coords:", grid_coords, grid_coords.shape, grid_coords.requires_grad, grid_coords.is_leaf), self._should_log)

        coeffs: torch.Tensor = self._calc_bilinear_coefficients(scaled_coords, grid_coords)
        log(("coeffs:", coeffs, coeffs.shape, coeffs.requires_grad, coeffs.is_leaf), self._should_log)
        del scaled_coords

        if self._should_normalize_grid_coords:
            # self._levels is (1, 1, levels, 1); a trailing axis makes it (1, 1, levels, 1, 1),
            # which broadcasts along the levels axis of (batch, pixels, levels, verts, xyz).
            # The previous expand_as + dim swap only lined up when num_levels == verts
            # and raised otherwise; the result is identical in the cases where it worked.
            levels_expanded = self._levels.unsqueeze(-1)
            log(("levels_expanded:", levels_expanded, levels_expanded.shape, levels_expanded.requires_grad, levels_expanded.is_leaf), self._should_log)

            grid_coords = grid_coords / levels_expanded
            log(("normalized_grid_coords:", grid_coords, grid_coords.shape, grid_coords.requires_grad, grid_coords.is_leaf), self._should_log)

        return grid_coords, coeffs

    @torch.no_grad()
    def _calc_bilinear_coefficients(self, scaled_coords: torch.Tensor, grid_coords: torch.Tensor) -> torch.Tensor:
        """
        Compute the (unnormalised) bilinear weight of each of the four cell corners.

        Parameters
        ----------
        scaled_coords : torch.Tensor
            Coordinates scaled to each level, (batch, pixels, xyz, levels, 1).
        grid_coords : torch.Tensor
            Cell corners, (batch, pixels, levels, verts, xyz); corner 0 is the
            lowest corner and the last corner the highest one.

        Returns
        -------
        torch.Tensor
            (batch, pixels, 1, levels, 4) weights, ordered like the corners
            (0,0), (1,0), (0,1), (1,1). Only the first two coordinates are used.
        """

        log(("SCALED COORDS:", scaled_coords[0, :, :, 0, :], scaled_coords.shape), self._should_log)
        log(("GRID COORDS:", grid_coords[0, :, 0, :, :], grid_coords.shape), self._should_log)

        low_corner = grid_coords[:, :, :, 0, :].unsqueeze(-2)    # (xa, ya)
        high_corner = grid_coords[:, :, :, -1, :].unsqueeze(-2)  # (xd, yd)

        log(("_as:", low_corner, low_corner.shape), self._should_log)
        log(("_ds:", high_corner, high_corner.shape), self._should_log)

        px = scaled_coords[:, :, 0, :, :]
        py = scaled_coords[:, :, 1, :, :]
        xa, ya = low_corner[..., 0], low_corner[..., 1]
        xd, yd = high_corner[..., 0], high_corner[..., 1]

        to_high_x = xd - px  # (xd - x)
        from_low_x = px - xa  # (x - xa)
        to_high_y = yd - py  # (yd - y)
        from_low_y = py - ya  # (y - ya)

        weights = [
            to_high_x * to_high_y,
            from_low_x * to_high_y,
            to_high_x * from_low_y,
            from_low_x * from_low_y,
        ]
        coeffs: torch.Tensor = torch.stack(weights, dim=-1).squeeze(-2).unsqueeze(2)

        log(("COEFFS:", coeffs, coeffs.shape), self._should_log)

        return coeffs

    @torch.no_grad()
    def _calc_dummies(
        self, 
        grid_coords: torch.Tensor | None = None, 
        hashed: torch.Tensor | None = None
    ) -> Tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]:
        """
        Calculates unique values for the grid coordinates or hashed coordinates.

        Only batch element 0 is used, since all images are assumed to share one size.
        
        Parameters
        ----------
        grid_coords : torch.Tensor | None, optional (default is None)
            Grid coordinates, (batch, pixels, levels, verts, xyz).
        hashed : torch.Tensor | None, optional (default is None)
            Hashed coordinates, indexed like ``grid_coords``.
            
        Returns
        -------
        Tuple[torch.Tensor | None, torch.Tensor | None, torch.Tensor | None]
            Per level: unique grid vertices, the index (into the flattened
            (pixels verts) list) of each unique vertex's first occurrence, and
            unique hashed values. Entries whose input was not given are None.
        """

        dummy_grids = None
        og_indices = None
        dummy_hashed = None

        if grid_coords is not None:
            log(("grid_coords:", grid_coords, grid_coords.shape, grid_coords.requires_grad, grid_coords.is_leaf), self._should_log)
            dummy_grids: List[Tuple[torch.Tensor]] = [
                torch.unique(rearrange(grid_coords[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"), dim=0, return_inverse=True) # first dimension is 0 because images should have all same size
                for l in range(self._num_levels)
            ]

            dummy_grids, dummy_grids_inverse_indices = zip(*dummy_grids)
            log(("dummy_grids:", dummy_grids, [dummy_grids[l].shape for l in range(self._num_levels)]), self._should_log)
            log(("dummy_grids grads:", [(dummy_grids[l].requires_grad, ) for l in range(self._num_levels)]), self._should_log)

            # Every value 0..U-1 occurs in the inverse indices, so np.unique(return_index=True)
            # returns the first position of each unique vertex, in order, in O(N log N).
            # This replaces one full np.where scan per unique vertex (O(U * N) per level).
            og_indices = [
                torch.as_tensor(
                    np.unique(dummy_grids_inverse_indices[l].detach().cpu().numpy(), return_index=True)[1],
                    device=device
                )
                for l in range(self._num_levels)
            ]
            log(("og_indices:", [(og_indices[l], og_indices[l].shape) for l in range(self._num_levels)]), self._should_log)
            log(("og_indices grads:", [(og_indices[l].requires_grad, ) for l in range(self._num_levels)]), self._should_log)

        if hashed is not None:
            log(("hashed:", hashed, hashed.shape, hashed.requires_grad, hashed.is_leaf), self._should_log)
            
            dummy_hashed: List[torch.Tensor] = [
                torch.unique(rearrange(hashed[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"), dim=0, return_inverse=False) # first dimension is 0 because images should have all same size
                for l in range(self._num_levels)
            ]
            log(("dummy_hashed:", dummy_hashed, [dummy_hashed[l].shape for l in range(self._num_levels)]), self._should_log)
            log(("dummy_hashed grads:", [(dummy_hashed[l].requires_grad, ) for l in range(self._num_levels)]), self._should_log)

        return dummy_grids, og_indices, dummy_hashed

    @torch.no_grad()
    def _calc_hash_collisions(
        self,
        dummy_grids: List[torch.Tensor],
        dummy_hashed: List[torch.Tensor] | None = None
    ) -> Tuple[torch.Tensor | None, torch.Tensor | None]:
        """
        Count hash collisions per level.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Unique grid vertices, one tensor per level.
        dummy_hashed : List[torch.Tensor] | None, optional (default is None)
            Unique hashed values, one tensor per level.

        Returns
        -------
        Tuple[torch.Tensor | None, torch.Tensor | None]
            Without ``dummy_hashed``: ``(min_possible_collisions, None)`` where the
            minimum is ``max(unique_vertices - hash_table_size, 0)`` (int64).
            With ``dummy_hashed``: ``(None, collisions)`` where collisions is
            ``unique_vertices - distinct_hashes`` (float32).
        """

        if dummy_hashed is not None:
            per_level = []
            for l in range(self._num_levels):
                distinct_hashes = torch.unique(dummy_hashed[l], dim=0).shape[0]
                per_level.append(
                    torch.tensor(float(dummy_grids[l].shape[0] - distinct_hashes), device=device)
                )
            collisions: torch.Tensor = torch.stack(per_level)
            log(("collisions:", collisions, collisions.shape, collisions.dtype, collisions.requires_grad), self._should_log)
            return None, collisions

        # Pigeonhole bound: vertices beyond the table size must share a slot.
        overflow = torch.stack([
            torch.tensor((dummy_grids[l].shape[0]) - self._hash_table_size, device=device)
            for l in range(self._num_levels)
        ])
        min_possible_collisions: torch.Tensor = overflow.clamp(min=0)
        log(("min_possible_collisions:", min_possible_collisions, min_possible_collisions.shape, min_possible_collisions.requires_grad), self._should_log)
        return min_possible_collisions, None

    @torch.no_grad()
    def _calc_uniques(
        self, 
        hashed: torch.Tensor,
        og_indices: List[torch.Tensor]
    ) -> List[torch.Tensor]:
        """
        Select, for every level, the hashed value of each unique grid vertex.

        Only batch element 0 is used, since all images are assumed to share one size.

        Parameters
        ----------
        hashed : torch.Tensor
            Hashed coordinates, (batch, pixels, levels, verts, xyz-like last axis).
        og_indices : List[torch.Tensor]
            Per level, the first-occurrence index of each unique vertex in the
            flattened (pixels verts) list (as returned by ``_calc_dummies``).

        Returns
        -------
        List[torch.Tensor]
            List of unique hashed coordinates one for each level.
        """
        
        unique_hashed: List[torch.Tensor] = [
            rearrange(hashed[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz")[og_indices[l]] # first dimension is 0 because images should have all same size
            for l in range(self._num_levels)
        ]
        # The debug arguments are built eagerly, so only the per-level summary is logged.
        log(("unique_hashed:", [(unique_hashed[l], unique_hashed[l].shape) for l in range(self._num_levels)]), self._should_log)

        return unique_hashed

    @torch.no_grad()
    def _hist_collisions(
        self,
        hashed: torch.Tensor,
        og_indices: torch.Tensor,
        min_possible_collisions: torch.Tensor,
        should_show: bool = False
    ) -> List[plt.Figure]:
        """
        Plot, per level, how the unique grid vertices spread over the hash table.

        Parameters
        ----------
        hashed : torch.Tensor
            Hashed coordinates.
        og_indices : torch.Tensor
            Per-level first-occurrence indices of the unique vertices.
        min_possible_collisions : torch.Tensor
            Minimum possible collisions, one for each level.
        should_show : bool, optional (default is False)
            Display each figure before it is closed.

        Returns
        -------
        List[plt.Figure]
            One figure per level, or None for levels that cannot collide
            (min_possible_collisions <= 0) unless fast hashing is enabled.
        """

        unique_hashed = self._calc_uniques(hashed, og_indices)
        log(("unique_hashed:", unique_hashed, [unique_hashed[l].shape for l in range(self._num_levels)], unique_hashed[0].requires_grad, unique_hashed[0].is_leaf), self._should_log)

        table_size = self._hash_table_size
        figs = []
        for level, min_collisions in enumerate(min_possible_collisions):
            if (min_collisions <= 0) and not self._should_fast_hash:
                figs.append(None)
                continue

            level_indices = unique_hashed[level].detach().cpu().numpy()

            fig, ax = plt.subplots(figsize=(15, 5))
            ax.hist(
                level_indices,
                bins=table_size,
                range=(0, table_size),
                edgecolor='grey',
                linewidth=0.5
            )
            ax.set_xlim(-1, table_size)
            ax.xaxis.set_ticks(np.arange(0, table_size, 10))

            # Roughly ten y ticks, but never a zero step.
            _, y_top = ax.get_ylim()
            y_step = int(y_top * 0.1)
            ax.yaxis.set_ticks(np.arange(0, y_top, y_step if y_step > 0 else 1))

            resolution = int(self._levels[0, 0, level, 0].item())
            ax.set_title(f"Level {level} ({resolution})")
            ax.set_xlabel("Hashed indices")
            ax.set_ylabel("Counts")

            figs.append(fig)

            if should_show:
                plt.show()

            plt.close()

        return figs
    
    @torch.no_grad()
    def _fast_hash(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Implements the hash function proposed by NVIDIA (Instant-NGP).

        The hash XORs every integer coordinate multiplied by the matching entry of
        ``self._prime_numbers``. The result is then taken modulo the hash table size.

        Parameters
        ----------
        x : torch.Tensor
            Grid coordinates to hash of shape (batch, pixels, levels, 2^input_dim, input_dim)
            This tensor should contain the vertices of the hyper cube for each level.
            Any number of levels is accepted (e.g. only the levels that can collide).

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Hashed coordinates of shape (batch, pixels, levels, 2^input_dim, 1) as
            float values in [0, hash_table_size - 1], and their uncertainty (the
            constant 1e-10 everywhere, same shape).
        """
        # Accumulator with x's shape minus the coordinate axis. Deriving shape and
        # device from x (rather than from self._num_levels) lets this hash run on
        # any subset of levels.
        acc = torch.zeros(x.shape[:-1], dtype=torch.long, device=x.device)

        for i in range(self._input_size):
            acc = torch.bitwise_xor(
                x[..., i].to(int) * self._prime_numbers[i],
                acc
            )

        # remainder by a positive modulus already yields [0, hash_table_size - 1].
        hashed = torch.remainder(acc, self._hash_table_size).unsqueeze(-1).float()
        del acc

        # Tiny positive sigma to prevent division by zero downstream.
        sigmas: torch.Tensor = torch.zeros_like(hashed, dtype=torch.float32) + 1e-10

        return hashed, sigmas

    def _multiresolution_hash(
        self,
        x: torch.Tensor,
        min_possible_collisions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Maps the grid vertices of every level to hash-table indices.

        Levels whose minimum possible collisions is <= 0 get a direct row-major
        index of their grid vertex. All other levels are hashed by
        ``self.hashModel``, which also provides their uncertainties.

        Parameters
        ----------
        x : torch.Tensor
            Grid coordinates to hash of shape (batch, pixels, levels, 2^input_dim, input_dim)
            This tensor should contain the vertices of the hyper cube for each level.
        min_possible_collisions : torch.Tensor
            Minimum possible hash collisions for each level.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Hashed coordinates and their uncertainty, both of shape
            (batch, pixels, levels, 2^input_dim, 1). Directly indexed levels keep
            an uncertainty of 1e-10.
        """
        log(("x:", x, x.shape, x.requires_grad, x.is_leaf), self._should_log)
        log(("min_possible_collisions:", min_possible_collisions, min_possible_collisions.shape, min_possible_collisions.requires_grad), self._should_log)

        sigma_epsilon = 1e-10 # to prevent division by zero

        out_shape = (x.shape[0], x.shape[1], self._num_levels, 2**self._input_size, 1)
        hashed: torch.Tensor = torch.zeros(out_shape, device=x.device)
        sigmas: torch.Tensor = torch.zeros(out_shape, device=x.device) + sigma_epsilon

        # Both masks are computed explicitly (rather than one as the negation of
        # the other) to keep exactly the original comparisons.
        direct_levels = min_possible_collisions <= 0
        hashed_levels = min_possible_collisions > 0

        x_direct = x[:, :, direct_levels, :, :]

        # Row-major index over the level's grid vertices via Horner's scheme with
        # stride (self._levels + 1); self._levels presumably holds each level's
        # resolution. For 2-D input this is exactly (levels + 1) * x_0 + x_1.
        strides = self._levels[:, :, direct_levels, :] + 1
        hashed_direct = x_direct[..., 0]
        for i in range(1, x_direct.shape[-1]):
            hashed_direct = (strides * hashed_direct) + x_direct[..., i]
        log(("Hashed indices of levels without collisions:", hashed_direct, hashed_direct.shape), self._should_log)
        hashed[:, :, direct_levels, :, :] = hashed_direct.unsqueeze(-1)
        del x_direct, hashed_direct, strides

        hashed_collisions_levels, sigmas_collisions_levels = self.hashModel(x[:, :, hashed_levels, :, :])
        sigmas[:, :, hashed_levels, :, :] = sigmas_collisions_levels
        hashed[:, :, hashed_levels, :, :] = hashed_collisions_levels
        del hashed_collisions_levels
        del sigmas_collisions_levels

        log(("hashed:", hashed, hashed.shape, hashed.requires_grad, hashed.is_leaf), self._should_log)
        log(("sigmas:", sigmas, sigmas.shape, sigmas.requires_grad, sigmas.is_leaf), self._should_log)

        return hashed, sigmas

    def get_hash_table_size(self) -> int:
        """
        Returns
        -------
        int
            Size of the hash table, i.e. the modulus hashed indices are reduced by.
        """
        table_size: int = self._hash_table_size
        return table_size
    
    def get_num_levels(self) -> int:
        """
        Returns
        -------
        int
            How many resolution levels this model handles.
        """
        levels_count: int = self._num_levels
        return levels_count


### General Neural Gauge Fields Model

In [None]:
class GNGFModel(nn.Module):
    """
    General Neural Gauge Fields model.

    Hashes the input coordinates through a MultiresolutionModel. When
    ``should_learn_images`` is True it also:

    - looks features up in one learnable hash table per (batch item, level);
    - bilinearly interpolates them;
    - decodes them with an MLP whose 3 outputs pass through a Sigmoid.
    """
    def __init__(
        self,
        batch_size: int,
        hidden_layers_widths: List[int],
        multiresModel: MultiresolutionModel,
        feature_size: int = 2,
        topk: int = 1,
        num_levels: int | None = None,
        hash_table_size: int | None = None,
        should_circular_topk: bool = True,
        should_learn_images: bool = False,
        should_log: List[int] | None = None,
    ) -> None:
        """

        Parameters
        ----------
        batch_size : int
            Batch size. One set of per-level hash tables is created per batch item.
        hidden_layers_widths : List[int]
            List of hidden layers widths.
        multiresModel : MultiresolutionModel
            Multiresolution model.
        feature_size : int, optional (default is 2)
            Feature size.
        topk : int, optional (default is 1)
            Number of neighbouring hash-table slots, centred on the hashed index,
            whose features are Gaussian-averaged. If -1 then all the hash-table
            slots are used.
        num_levels : int | None, optional (default is None)
            Number of levels. If None then MultiresolutionModel's num_levels is used.
        hash_table_size : int | None, optional (default is None)
            Hash table size. If None then MultiresolutionModel's hash_table_size is used.
        should_circular_topk : bool, optional (default is True)
            If True, out-of-range neighbour indices wrap around the table;
            otherwise they are clamped to [0, hash_table_size - 1].
        should_learn_images : bool, optional (default is False)
            Whether to learn the images or not.
        should_log : List[int] | None, optional (default is None, meaning [])
            List of decoded functions to log.

        Returns
        -------
        None
        """
        super(GNGFModel, self).__init__()

        self._batch_size: int = batch_size
        self.multiresModel: MultiresolutionModel = multiresModel

        self._feature_size: int = feature_size
        self._topk: int = topk
        self._num_levels: int = num_levels if num_levels is not None else self.multiresModel.get_num_levels()
        self._hash_table_size: int = hash_table_size if hash_table_size is not None else self.multiresModel.get_hash_table_size()

        self._should_circular_topk: bool = should_circular_topk
        self._should_learn_images: bool = should_learn_images
        # Fresh list per instance: a mutable default would be shared by every model.
        self._should_log: List[int] = [] if should_log is None else should_log

        # hash_tables[b][l]: learnable table of batch item b at level l.
        self.hash_tables: torch.nn.ModuleList = torch.nn.ModuleList([
            torch.nn.ModuleList([
                torch.nn.Embedding(self._hash_table_size, self._feature_size, device=device)
                for _ in range(self._num_levels)
            ])
            for _ in range(self._batch_size)
        ])
        log(("Hash table image 0 level 0:", self.hash_tables[0][0].weight, self.hash_tables[0][0].weight.shape), self._should_log)

        self._apply_init(torch.nn.init.uniform_, -1.0, 1.0)
        log(("Initialized hash table image 0 level 0:", self.hash_tables[0][0].weight, self.hash_tables[0][0].weight.shape), self._should_log)

        # The MLP input is the concatenation of all levels' features. ReLU follows
        # every hidden layer and a Sigmoid follows the final 3-wide layer.
        layers_widths = [(self._num_levels * self._feature_size), *hidden_layers_widths, 3]
        self.mlp: nn.ModuleList = nn.ModuleList([
            nn.Sequential(
                nn.Linear(
                    in_features=layers_widths[i],
                    out_features=layers_widths[i + 1],
                    device=device
                ),
                nn.ReLU() if (i < (len(layers_widths) - 2)) else nn.Sigmoid()
            )
            for i in range(len(layers_widths) - 1)
        ])

    def forward(
        self,
        x: torch.Tensor,
        should_calc_hists: bool = False,
        should_show_hists: bool = False,
    ) -> Tuple[torch.Tensor | None, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor | None, List[plt.Figure] | None]:
        """
        Runs the multiresolution hashing and, optionally, the image decoder.

        Parameters
        ----------
        x : torch.Tensor
            Input passed unchanged to ``self.multiresModel``.
        should_calc_hists : bool, optional (default is False)
            Forwarded to the multiresolution model.
        should_show_hists : bool, optional (default is False)
            Forwarded to the multiresolution model.

        Returns
        -------
        Tuple
            ``(out, hashed, sigmas, min_possible_collisions, collisions, hists)``.
            ``out`` is the MLP output, and is None unless ``should_learn_images``.
            The other items come straight from the multiresolution model. Its
            ``coeffs`` output is consumed here and is not returned.
        """
        hashed, sigmas, min_possible_collisions, collisions, coeffs, hists = self.multiresModel(
            x,
            should_calc_hists,
            should_show_hists
        )
        del x

        out = None
        if self._should_learn_images:
            out = self._look_up_features(hashed, sigmas)

            out = self._bilinear_interpolation(out, coeffs)
            del coeffs

            for i, layer in enumerate(self.mlp):
                out = layer(out)
                log((f"After layer {i}:", out, out.shape, out.requires_grad, out.is_leaf), self._should_log)

        return out, hashed, sigmas, min_possible_collisions, collisions, hists

    def _look_up_features(self, indices: torch.Tensor, sigmas: torch.Tensor) -> torch.Tensor:
        """
        Looks up features from the hash tables.

        Each hashed index is expanded with integer offsets. The features at every
        offset are averaged with weights from a zero-mean Gaussian of the offset,
        whose standard deviation is the corresponding sigma.

        Parameters
        ----------
        indices : torch.Tensor
            Hashed coordinates.
        sigmas : torch.Tensor
            Hashed coordinates' uncertainty.

        Returns
        -------
        torch.Tensor
            Looked up features, arranged as (batch, pixels, features, levels, verts).
        """
        log(("indices:", indices, indices.shape, indices.requires_grad, indices.is_leaf), self._should_log)
        log(("sigmas:", sigmas, sigmas.shape, sigmas.requires_grad, sigmas.is_leaf), self._should_log)

        if self._topk == -1: # use all hashed coordinates
            # NOTE(review): offsets are 0..T-1, so with circular wrapping offset
            # T-1 (== -1) still gets a near-zero Gaussian weight. Confirm intended.
            topks = torch.arange(
                self._hash_table_size,
                dtype=torch.float32,
                device=device
            )
        else:
            # NOTE(review): symmetric offsets -(k//2)..k//2, i.e. k values for odd k
            # but k + 1 values for even k.
            topks = torch.arange(
                -(self._topk // 2), (self._topk // 2) + 1,
                dtype=torch.float32,
                device=device
            )
        log(("topks:", topks, topks.shape), self._should_log, color=bcolors.FAIL)

        # Every hashed index broadcast against every offset (new trailing k axis).
        new_indices = indices + topks

        if self._should_circular_topk: # CIRCULAR IMPLEMENTATION
            new_indices = torch.remainder(new_indices, self._hash_table_size)
        else: # LINEAR IMPLEMENTATION: clamp to the table bounds
            new_indices = torch.clamp(new_indices, 0, self._hash_table_size - 1)
        log(("new_indices:", new_indices, new_indices.shape), self._should_log)

        # Per batch item b, x is (pixels, levels, verts, k). Each level l uses that
        # item's own table.
        looked_up: torch.Tensor = rearrange(
            torch.stack([
                torch.stack([
                    self.hash_tables[b][l](
                        x[:, l, :, :].int()
                    )
                    for l in range(self._num_levels)
                ])
                for b, x in enumerate(new_indices)
            ]),
            "batch levels pixels verts k features -> batch pixels levels features verts k"
        )
        log(("looked_up:", looked_up, looked_up.shape), self._should_log)
        del new_indices

        # Calculate the Gaussian probabilities
        gaussian_probs = (
            torch.exp(-(1/2) * ((topks - 0) / (sigmas))**2) / ((sigmas) * torch.sqrt(2 * torch.tensor(np.pi)))
        ).unsqueeze(3)
        log(("gaussian_probs:", gaussian_probs, gaussian_probs.shape), self._should_log, color=bcolors.WARNING)
        del topks

        # (weighted avg) sum(looked_up * topk)/sum(topk)
        looked_up = rearrange(
            torch.sum(looked_up * gaussian_probs, dim=-1) / torch.sum(gaussian_probs, dim=-1),
            "batch pixels levels features verts -> batch pixels features levels verts"
        )
        log(("Weighted avg looked_up:", looked_up, looked_up.shape), self._should_log)

        del gaussian_probs

        return looked_up

    def _bilinear_interpolation(self, features: torch.Tensor, coeffs: torch.Tensor) -> torch.Tensor:
        """
        Bilinear interpolate features with coefficients.

        Parameters
        ----------
        features : torch.Tensor
            Features to interpolate, (batch, pixels, features, levels, verts).
        coeffs : torch.Tensor
            Coefficients for bilinear interpolation (presumably broadcastable to
            ``features``; TODO confirm against MultiresolutionModel).

        Returns
        -------
        torch.Tensor
            Interpolated features, flattened to (batch, pixels, levels * features).
        """
        log(("coeffs:", coeffs, coeffs.shape), self._should_log)
        log(("features:", features, features.shape), self._should_log)

        weighted_features: torch.Tensor = features * coeffs
        del features
        del coeffs
        log(("weighted_features:", weighted_features, weighted_features.shape), self._should_log)

        # Sum over the cell vertices.
        weighted_summed_features: torch.Tensor = torch.sum(weighted_features, dim=-1)
        del weighted_features
        log(("weighted_summed_features:", weighted_summed_features, weighted_summed_features.shape), self._should_log)

        stack: torch.Tensor = rearrange(weighted_summed_features, "batch pixels features levels -> batch pixels (levels features)")
        del weighted_summed_features
        log(("stacked:", stack, stack.shape), self._should_log)

        return stack

    def _apply_init(self, init_func, *args) -> None:
        """
        Calls ``init_func(weight, *args)`` on the weight of every hash table,
        e.g. ``torch.nn.init.uniform_`` with its lower/upper bounds.
        """
        for b in range(self._batch_size):
            for i in range(self._num_levels):
                init_func(self.hash_tables[b][i].weight, *args)


## Metrics, Loss and Optimizer

### Metrics

In [None]:
def calc_accuracy(predicted: np.ndarray, target: np.ndarray, size: int) -> float:
    """
    Percentage of positions where ``predicted`` equals ``target``.

    ``size`` is the denominator, normally the total number of compared elements.
    """
    matches = np.equal(predicted, target).sum()
    fraction = matches / size
    return fraction * 100

def calc_psnr(pred: np.ndarray, target: np.ndarray) -> float:
    mse = np.square(pred - target).mean()
    return 20 * np.log10(np.max(target)) - 10 * np.log10(mse) # psne


### Loss

In [None]:
class Loss(nn.Module):
    """
    Training losses for the hashing model.

    - Per level: KL divergence between the uniform distribution and the
      histogram of hashed indices.
    - Per level: mean squared uncertainty (sigma).
    - Optionally, per image: MSE between predictions and targets.
    """
    def __init__(
        self,
        hash_table_size: int,
        kl_div_reduction: str = "batchmean",
        hist_sanitize_eps: float = 1e-10,
        should_use_all_levels: bool = False,
        should_sanitize_hist_before: bool = False,
        should_exp_normalize_kl_div: bool = False,
        should_log: List[int] | None = None
    ) -> None:
        """

        Parameters
        ----------
        hash_table_size : int
            Hash table size.
        kl_div_reduction : str, optional (default is "batchmean")
            KL divergence reduction.
        hist_sanitize_eps : float, optional (default is 1e-10)
            Small value written into empty histogram bins so that log(p) in the
            KL divergence stays finite.
        should_use_all_levels : bool, optional (default is False)
            Whether to use all levels or only the ones with collisions.
        should_sanitize_hist_before : bool, optional (default is False)
            Whether to sanitize the histogram before or after normalization.
        should_exp_normalize_kl_div : bool, optional (default is False)
            If True, each level's KL divergence d is replaced by 1 - exp(-d).
        should_log : List[int] | None, optional (default is None, meaning [])
            List of decoded functions to log. (9 to plot KL histograms)

        Returns
        -------
        None
        """
        super(Loss, self).__init__()

        self._hash_table_size: int = hash_table_size
        self._hist_sanitize_eps: float = hist_sanitize_eps

        self._should_use_all_levels: bool = should_use_all_levels
        self._should_sanitize_hist_before: bool = should_sanitize_hist_before
        self._should_exp_normalize_kl_div: bool = should_exp_normalize_kl_div
        # Fresh list per instance: a mutable default would be shared by every Loss.
        self._should_log: List[int] = [] if should_log is None else should_log

        self._KLDiv: nn.KLDivLoss = nn.KLDivLoss(reduction=kl_div_reduction)
        self._sigmas_MSE: nn.MSELoss = nn.MSELoss(reduction="mean")
        self._images_MSE: nn.MSELoss = nn.MSELoss(reduction="mean")

    def forward(
        self,
        min_possible_collisions: torch.Tensor,
        indices: torch.Tensor,
        sigmas: torch.Tensor,
        pred: torch.Tensor | None = None,
        target: torch.Tensor | None = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
        """
        Computes the three loss components.

        Parameters
        ----------
        min_possible_collisions : torch.Tensor
            Minimum possible collisions, one per level.
        indices : torch.Tensor
            Hashed coordinates; ``indices[l]`` holds those of level l.
        sigmas : torch.Tensor
            Uncertainties; ``sigmas[l]`` holds those of level l.
        pred, target : torch.Tensor | None, optional (default is None)
            Predicted and target images, indexed per batch item. The image loss
            is only computed when both are given.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]
            - kl_div_losses: one per level. Skipped levels stay 0.
            - sigmas_losses: one per level, the mean of sigma^2.
            - images_losses: one per batch item, or None.
        """
        levels: int = min_possible_collisions.shape[0]

        kl_div_losses: torch.Tensor = torch.zeros(levels, device=device)
        for l in range(levels):
            if not self._should_use_all_levels and min_possible_collisions[l] <= 0:
                continue

            kl_div_level_loss = self._kl_div(self._calc_hist_pdf(indices[l]))
            kl_div_losses[l] = kl_div_level_loss if not self._should_exp_normalize_kl_div else (1 - torch.exp(-kl_div_level_loss))
        log(("kl_div_losses:", kl_div_losses, kl_div_losses.shape, kl_div_losses.requires_grad, kl_div_losses.is_leaf), self._should_log)

        # MSE against zero, i.e. mean(sigma^2) per level.
        sigmas_losses: torch.Tensor = torch.stack([
            self._sigmas_MSE(sigmas[l], torch.zeros_like(sigmas[l]))
            for l in range(levels)
        ])
        log(("sigmas_losses:", sigmas_losses, sigmas_losses.shape, sigmas_losses.requires_grad, sigmas_losses.is_leaf), self._should_log)

        images_losses: torch.Tensor | None = None
        if pred is not None and target is not None:
            batch_size = pred.shape[0]

            images_losses = torch.stack([
                self._images_MSE(pred[b], target[b])
                for b in range(batch_size)
            ])
            log(("images_losses:", images_losses, images_losses.shape, images_losses.requires_grad, images_losses.is_leaf), self._should_log)

        return kl_div_losses, sigmas_losses, images_losses

    def _calc_hist_pdf(
        self,
        indices: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculates the (differentiable) histogram pdf of hashed indices.

        Parameters
        ----------
        indices : torch.Tensor
            Hashed coordinates.

        Returns
        -------
        torch.Tensor
            Histogram of ``indices`` with ``hash_table_size`` bins over
            [0, hash_table_size], divided by its sum. Empty bins are replaced
            according to ``hist_sanitize_eps`` / ``should_sanitize_hist_before``.
        """
        log(("indices:", indices, indices.shape, indices.requires_grad, indices.is_leaf), self._should_log)

        hist_p = differentiable_histogram(indices, bins=self._hash_table_size, min=0, max=self._hash_table_size, should_log=self._should_log).squeeze(0).squeeze(0)

        # The comparison below eagerly formats whole tensors into f-strings and
        # recomputes an exact histogram, so skip it when logging is disabled.
        # NOTE(review): assumes `log` prints nothing for an empty should_log. Confirm.
        if self._should_log:
            self._log_hist_comparison(indices, hist_p)
        del indices

        if self._should_sanitize_hist_before:
            hist_p[hist_p == 0] = (torch.sum(hist_p) * self._hist_sanitize_eps)

        p = hist_p / torch.sum(hist_p)
        log(("p:", p, p.shape, p.requires_grad, p.is_leaf, p.sum().item()), self._should_log)

        del hist_p

        if not self._should_sanitize_hist_before:
            p[p == 0] = self._hist_sanitize_eps

        return p

    def _log_hist_comparison(self, indices: torch.Tensor, hist_p: torch.Tensor) -> None:
        """
        Debug helper: logs the differentiable histogram next to the exact,
        non-differentiable ``torch.histc`` one and their difference.
        """
        log((f"hist_p_diff   : {hist_p.long()}, shape: {hist_p.shape}, sum: {torch.sum(hist_p)}, requires_grad: {hist_p.requires_grad}", ), self._should_log)

        # Real implementation of histogram pdf, but not differentiable
        hist_p_nondiff = torch.histc(indices, bins=self._hash_table_size, min=0, max=self._hash_table_size) # ? maybe max=(self._hash_table_size - 1)
        log((f"hist_p_nondiff: {hist_p_nondiff}, shape: {hist_p_nondiff.shape}, sum: {torch.sum(hist_p_nondiff)}, requires_grad: {hist_p_nondiff.requires_grad}", ), self._should_log, color=bcolors.WARNING)
        log((f"hist_p_diff - hist_p_nondiff: {hist_p - hist_p_nondiff}", ), self._should_log, color=bcolors.WARNING)

    def _kl_div(
        self,
        p: torch.Tensor
    ) -> torch.Tensor:
        """
        Calculates the KL divergence between the uniform distribution q and p.

        ``nn.KLDivLoss(p.log(), q)`` reduces q * (log q - log p), i.e. KL(q || p).
        With the default "batchmean" reduction on these 1-D inputs, PyTorch
        divides the summed divergence by input.size(0) (= hash_table_size).

        Parameters
        ----------
        p : torch.Tensor
            Histogram pdf.

        Returns
        -------
        torch.Tensor
            KL divergence.
        """
        log(("p:", p, p.shape, p.requires_grad, p.is_leaf, p.sum().item()), self._should_log)

        q = torch.ones(
            self._hash_table_size,
            device=device
        ) / self._hash_table_size
        log(("q:", q, q.shape, q.requires_grad, q.is_leaf), self._should_log)

        if 9 in self._should_log:
            self._plot_histograms(p.detach().cpu().numpy(), q.detach().cpu().numpy())

        kl_div_loss = self._KLDiv(p.log(), q)
        log(("kl_div_loss:", kl_div_loss, kl_div_loss.shape, kl_div_loss.requires_grad, kl_div_loss.is_leaf), self._should_log)

        del p
        del q

        return kl_div_loss

    def _plot_histograms(self, p, q) -> None:
        """Shows the distributions p and q as overlaid bar charts (one bar per slot)."""
        fig, ax = plt.subplots(figsize=(15, 5))
        ax.bar(np.arange(len(p)), p, alpha=1, label="p")
        ax.bar(np.arange(len(q)), q, alpha=0.5, label="q")

        ax.set_xlim(-1, self._hash_table_size)
        ax.xaxis.set_ticks(np.arange(0, self._hash_table_size, 10))

        ax.legend()
        plt.show()
        plt.close(fig)


### Optimizer

In [None]:
def get_optimizer(
    models_parameters: Dict[str, Any],
    optimizers: List[torch.optim.Optimizer] = [torch.optim.Adam, torch.optim.AdamW],
) -> Dict[str, torch.optim.Optimizer]:
    
    optims = {}
    
    for model, optimizer in zip(models_parameters.items(), optimizers):
        models_params = []

        for m in model[1]["each"]:
            param, lr, weight_decay= m.values()

            models_params.append({
                "params": param, "lr": lr, "weight_decay": weight_decay
            })
        
        betas = model[1]["betas"]
        eps = model[1]["eps"]

        optims[model[0]] = optimizer(
            models_params,
            betas=betas,
            eps=eps,
        )
    
    return optims

## Train & Test Loops

### Train loop

In [None]:
def _regularization_loss(
    model: nn.Module,
    excluded_params: List[str],
    norm_order: int,
) -> torch.Tensor:
    """Sum of torch.linalg.norm over every parameter whose name contains none of `excluded_params`."""
    regularization_loss = torch.tensor(0.0, device=device)
    for name, param in model.named_parameters():
        if not any(excluded in name for excluded in excluded_params):
            # NOTE(review): on 2-D weights torch.linalg.norm computes a *matrix*
            # norm (spectral norm for ord=2), not the flattened vector norm.
            regularization_loss += torch.linalg.norm(param, ord=norm_order)
    return regularization_loss


def _restore_pixel_order(t: torch.Tensor, reordered_indices: torch.Tensor) -> torch.Tensor:
    """Returns a detached copy of `t` with item b's pixels permuted by reordered_indices[b]; `t` is untouched."""
    ordered = t.detach().clone()
    for b in range(ordered.shape[0]):
        ordered[b] = ordered[b][reordered_indices[b]]
    return ordered


def train_loop(
    x: torch.Tensor,
    y: torch.Tensor,
    reordered_indices: torch.Tensor,
    h: int,
    w: int,
    model: nn.Module,
    optimizers: List[torch.optim.Optimizer],
    loss_fn: nn.Module,
    l_kl_loss: float,
    l_sigmas_loss: float,
    l_images_loss: float,
    l_reg_loss: float,
    norm_regularization_order: int = 2,
    gradient_clipping: float | None = None,
    should_calc_hists: bool = False,
    should_learn_images: bool = False,
    excluded_params_from_reg_loss: List[str] | None = None,
    should_randomize_input: bool = False,
    should_log: List[int] | None = None,
    should_log_grads: bool = False,
) -> Dict[str, Any]:
    """
    Runs one training forward/backward pass.

    Gradients are computed (and optionally clipped). No ``zero_grad()`` or
    ``step()`` is called in this function.

    Parameters
    ----------
    x : torch.Tensor
        Input data.
    y : torch.Tensor
        Target data. It is not modified.
    reordered_indices : torch.Tensor
        Per-image pixel permutation used to restore the display order when
        ``should_randomize_input`` is set.
    h : int
        Images height.
    w : int
        Images width.
    model : nn.Module
        Model to train.
    optimizers : List[torch.optim.Optimizer]
        List of optimizers. Currently unused here.
    loss_fn : nn.Module
        Loss function.
    l_kl_loss : float
        KL divergence loss lambda.
    l_sigmas_loss : float
        Sigmas loss lambda.
    l_images_loss : float
        Images loss lambda.
    l_reg_loss : float
        Regularization loss lambda.
    norm_regularization_order : int, optional (default is 2)
        Norm regularization order.
    gradient_clipping : float | None, optional (default is None)
        Max gradient norm of the hash model's parameters, applied after
        ``loss.backward()``. If None then no clipping is applied.
    should_calc_hists : bool, optional (default is False)
        Whether to calculate histograms or not.
    should_learn_images : bool, optional (default is False)
        Whether to learn the images or not.
    excluded_params_from_reg_loss : List[str] | None, optional (default is ["_prime_numbers", "hash_tables", "mlp"])
        List of parameter-name substrings to exclude from the regularization loss.
    should_randomize_input : bool, optional (default is False)
        Whether to randomize the input or not.
    should_log : List[int] | None, optional (default is [])
        - 0: Log "Train loop".
        - 1: Log
        - 2: Log and plot.
    should_log_grads : bool, optional (default is False)
        Whether to log the gradients or not.

    Returns
    -------
    Dict[str, Any]
        Dictionary with the loss and the predictions.
    """
    if excluded_params_from_reg_loss is None:
        excluded_params_from_reg_loss = ["_prime_numbers", "hash_tables", "mlp"]
    if should_log is None:
        should_log = []

    log(("Train loop", ), 0 in should_log, color=bcolors.WARNING)

    model.train()

    pred, hashed, sigmas, min_possible_collisions, collisions, hists = model(x, should_calc_hists=should_calc_hists, should_show_hists=(2 in should_log))

    if should_learn_images:
        log(("pred:", pred, pred.shape, pred.requires_grad, pred.is_leaf), should_log)

    # Only the first batch item's hashes and sigmas are scored (the [0] index).
    kl_div_losses, sigmas_losses, images_losses = loss_fn(
        min_possible_collisions=min_possible_collisions,
        indices=rearrange(hashed, "batch pixels levels verts 1 -> batch levels pixels (verts 1)")[0],
        sigmas=rearrange(sigmas, "batch pixels levels verts 1 -> batch levels pixels (verts 1)")[0],
        pred=pred,
        target=y,
    )
    log(("kl_div_losses:", kl_div_losses, kl_div_losses.shape, kl_div_losses.requires_grad, kl_div_losses.is_leaf), should_log)
    log(("sigmas_losses:", sigmas_losses, sigmas_losses.shape, sigmas_losses.requires_grad, sigmas_losses.is_leaf), should_log)
    if images_losses is not None:
        log(("images_losses:", images_losses, images_losses.shape, images_losses.requires_grad, images_losses.is_leaf), should_log)

    regularization_loss = _regularization_loss(model, excluded_params_from_reg_loss, norm_regularization_order)

    images_term = (
        (l_images_loss * torch.sum(images_losses))
        if images_losses is not None else
        torch.tensor(0.0, device=device)
    )
    loss = (
        (l_kl_loss * torch.sum(kl_div_losses)) +
        (l_sigmas_loss * torch.sum(sigmas_losses)) +
        images_term +
        (l_reg_loss * regularization_loss)
    )
    log(("loss:", loss, loss.shape, loss.requires_grad, loss.is_leaf), should_log)

    if should_log_grads:
        if pred is not None:
            pred.retain_grad()
        hashed.retain_grad()
        sigmas.retain_grad()

    loss.backward()

    if should_log_grads:
        if pred is not None:
            log(("Pred gradient: ", pred.grad, pred.grad_fn, pred.shape, pred.sum()), True, color=bcolors.OKGREEN)
        log(("Hashed gradient: ", hashed.grad, hashed.grad_fn, hashed.shape, hashed.sum()), True, color=bcolors.OKGREEN)
        log(("Sigmas gradient: ", sigmas.grad, sigmas.grad_fn, sigmas.shape, sigmas.sum()), True, color=bcolors.OKGREEN)

        for name, param in model.named_parameters():
            if "hashModel" in name:
                log((name, param.grad), True, color=bcolors.OKGREEN)

    # Clipping must run after backward(): before it, .grad holds stale (or no)
    # gradients and the clip has no effect on this step.
    if gradient_clipping is not None:
        torch.nn.utils.clip_grad_norm_(model.multiresModel.hashModel.parameters(), gradient_clipping)

    # TODO: skip Adam steps for hash-table entries whose gradient is exactly 0
    # (Instant-NGP heuristic for sparse gradients).

    pred_images = None
    target_images = None
    images_psnr = None
    if should_learn_images:
        pred_out = pred.detach()
        y_out = y.detach()
        if should_randomize_input:
            # Work on copies so the caller's target tensor is not permuted in place.
            pred_out = _restore_pixel_order(pred_out, reordered_indices)
            y_out = _restore_pixel_order(y_out, reordered_indices)
        pred_images = (pred_out * 255).reshape(-1, h, w, 3).to(int).cpu().numpy()
        target_images = (y_out * 255).reshape(-1, h, w, 3).to(int).cpu().numpy()

        plot_images(pred_images, target_images, should_log)

        images_psnr = [calc_psnr(pred_images[i], target_images[i]) for i in range(pred_images.shape[0])]

    to_return = {
        "pred_images": pred_images,
        "target_images": target_images,
        "images_psnr": images_psnr,
        "min_possible_collisions": min_possible_collisions.detach().cpu().numpy(),
        "collisions": collisions.detach().cpu().numpy() if collisions is not None else None,
        "histograms": hists,
        "kl_div_losses": kl_div_losses.detach().cpu().numpy(),
        "sigmas_losses": sigmas_losses.detach().cpu().numpy(),
        "reg_loss": regularization_loss.item(),
        "images_losses": images_losses.detach().cpu().numpy() if images_losses is not None else None,
        "loss": loss.item(),
    }

    return to_return

### Test Loop

In [None]:
def create_indices_mapping(
    pred: np.ndarray,
    h: int,
    w: int,
    min_possible_collisions: np.ndarray,
    hashed: np.ndarray,
    hash_table_size: int,
    save_dir: str | None = None,
    should_show: bool = False,
    should_log: List[int] | None = None,
) -> List[plt.Figure]:
    """
    Overlays each level's hashed-index map on the predicted image.

    For every batch item, and every level whose minimum possible collisions is
    > 0, the image is drawn with the hashed index of vertex 0 of each pixel on
    top of it (one colour per hash-table slot).

    Parameters
    ----------
    pred : np.ndarray
        Predicted images, reshapeable per batch item to (h, w, 3).
    h : int
        Images height.
    w : int
        Images width.
    min_possible_collisions: np.ndarray
        Minimum possible collisions per level
    hashed : np.ndarray
        Hashed coordinates of shape (batch, pixels, levels, verts, 1).
    hash_table_size : int
        Hash table size.
    save_dir : str | None, optional (default is None)
        Directory to save the plots. Its last path component is used in the title.
    should_show : bool, optional (default is False)
        Whether to show the plots or not.
    should_log : List[int] | None, optional (default is [])
        - 1: Log

    Returns
    -------
    List[plt.Figure]
        List of plots, one for each level with possible collisions, for each batch.
        Figures are closed before being returned.
    """
    if should_log is None:
        should_log = []

    batch, _, levels = hashed.shape[:3]
    original_vertex = 0  # vertex of the grid cell whose hashed index is displayed

    colors = plt.cm.viridis(np.linspace(0, 1, hash_table_size))
    cmap = ListedColormap(colors)
    ticks = list(range(0, hash_table_size + 10, 10))
    # normpath strips a trailing separator so the last component is never empty.
    image_name = os.path.basename(os.path.normpath(save_dir)) if save_dir is not None else None

    maps: List[plt.Figure] = []

    for b in range(batch):
        pred_b = pred[b, :, :].reshape(h, w, 3)

        for l in range(levels):
            if min_possible_collisions[l] <= 0:
                continue

            fig, ax = plt.subplots()
            log(("Batch:", b, ", level:", l), should_log, color=bcolors.OKGREEN)

            hashed_bl = hashed[b, :, l, original_vertex, :].squeeze(-1).reshape(h, w)

            ax.imshow(pred_b)
            overlay = ax.imshow(hashed_bl, cmap=cmap, alpha=0.75, vmin=0, vmax=hash_table_size)
            fig.colorbar(overlay, ax=ax, ticks=ticks, label="Hashed indices", orientation="vertical", alpha=1.0)

            # Titles are set on this figure's axes explicitly: after plt.show()
            # some backends drop the figure from pyplot, and plt.title would then
            # create and label a new, empty figure instead.
            if should_show:
                ax.set_title(f"Image batch {b}, level {l}")
                plt.show()

            if save_dir is not None:
                ax.set_title(f"Image {image_name} batch {b}, level {l}")
                fig.savefig(os.path.join(save_dir, f"batch_{b}_level_{l}_image.png"), dpi=fig.dpi)
            plt.close(fig)

            maps.append(fig)

    return maps

In [None]:
def test_loop(
    x: torch.Tensor,
    y: torch.Tensor,
    reordered_indices: torch.Tensor,
    h: int,
    w: int,
    model: nn.Module,
    loss_fn: nn.Module,
    l_kl_loss: float,
    l_sigmas_loss: float,
    l_images_loss: float,
    l_reg_loss: float,
    norm_regularization_order: int = 2,
    should_calc_hists: bool = False,
    should_learn_images: bool = False,
    should_randomize_input: bool = False,
    excluded_params_from_reg_loss: Optional[List[str]] = None,
    should_log: Optional[List[int]] = None,
    save_indices_maps_dir: str | None = None,
) -> Dict[str, Any]:
    """
    Evaluates the model once, without updating it.

    The forward pass and the loss are computed under ``torch.no_grad()`` (no
    backward pass is run here). The total loss is the same weighted sum used for
    training: KL-divergence, sigmas, images (if any) and norm-regularization terms.
    When ``should_learn_images`` is True, the pixels are put back in their original
    order (if the input was shuffled), the hashed-index maps are plotted/saved via
    ``create_indices_mapping``, the predicted and target images are plotted and the
    per-image PSNR is computed.

    Parameters
    ----------
    x : torch.Tensor
        Input data.
    y : torch.Tensor
        Target data. It is never modified in place.
    reordered_indices : torch.Tensor
        Per-batch-element indices that restore the original pixel order of a
        shuffled input; used only when ``should_randomize_input`` is True.
    h : int
        Images height.
    w : int
        Images width.
    model : nn.Module
        Model to evaluate.
    loss_fn : nn.Module
        Loss function returning the KL-divergence, sigmas and images losses.
    l_kl_loss : float
        KL divergence loss lambda.
    l_sigmas_loss : float
        Sigmas loss lambda.
    l_images_loss : float
        Images loss lambda.
    l_reg_loss : float
        Regularization loss lambda.
    norm_regularization_order : int, optional (default is 2)
        ``ord`` passed to ``torch.linalg.norm`` for the regularization loss.
    should_calc_hists : bool, optional (default is False)
        Whether to calculate histograms or not.
    should_learn_images : bool, optional (default is False)
        Whether the model predicts images or not.
    should_randomize_input : bool, optional (default is False)
        Whether the input pixels were shuffled (and must be reordered) or not.
    excluded_params_from_reg_loss : List[str] | None, optional (default is None)
        Parameters whose name contains any of these substrings are excluded from
        the regularization loss. None means ["_prime_numbers", "hash_tables", "mlp"].
    should_log : List[int] | None, optional (default is None, i.e. [])
        - 0: Log "Test loop".
        - 1: Log
        - 2: Log and plot.
    save_indices_maps_dir : str | None, optional (default is None)
        Directory to save the indices-map plots.

    Returns
    -------
    Dict[str, Any]
        Losses (numpy arrays / floats), collisions, histograms and, when
        ``should_learn_images`` is True, the predicted/target images (scaled by
        255, truncated to ``int`` and reshaped to (-1, h, w, 3)) with their PSNR;
        otherwise those three entries are None.
    """
    # None sentinels instead of shared mutable default lists.
    if excluded_params_from_reg_loss is None:
        excluded_params_from_reg_loss = ["_prime_numbers", "hash_tables", "mlp"]
    if should_log is None:
        should_log = []

    log(("Test loop", ), 0 in should_log, color=bcolors.WARNING)

    model.eval()

    # Evaluation only: nothing here calls backward(), so do not build an autograd graph.
    with torch.no_grad():
        pred, hashed, sigmas, min_possible_collisions, collisions, hists = model(
            x,
            should_calc_hists=should_calc_hists,
            should_show_hists=(2 in should_log)
        )

        log(("hashed:", hashed, hashed.shape, hashed.requires_grad, hashed.is_leaf), should_log)

        if should_learn_images:
            log(("pred:", pred, pred.shape, pred.requires_grad, pred.is_leaf), should_log)

        # Only the first batch element's indices/sigmas go to the loss; the author observed
        # hashed[0] == hashed[1] and sigmas[0] == sigmas[1] for this data.
        kl_div_losses, sigmas_losses, images_losses = loss_fn(
            min_possible_collisions=min_possible_collisions,
            indices=rearrange(hashed, "batch pixels levels verts 1 -> batch levels pixels (verts 1)")[0],
            sigmas=rearrange(sigmas, "batch pixels levels verts 1 -> batch levels pixels (verts 1)")[0],
            pred=pred,
            target=y,
        )
        log(("kl_div_losses:", kl_div_losses, kl_div_losses.shape, kl_div_losses.requires_grad, kl_div_losses.is_leaf), should_log)
        log(("sigmas_losses:", sigmas_losses, sigmas_losses.shape, sigmas_losses.requires_grad, sigmas_losses.is_leaf), should_log)
        if images_losses is not None:
            log(("images_losses:", images_losses, images_losses.shape, images_losses.requires_grad, images_losses.is_leaf), should_log)

        regularization_loss = torch.tensor(0.0, device=device)
        for name, param in model.named_parameters():
            if not any(excluded_param in name for excluded_param in excluded_params_from_reg_loss):
                # NOTE(review): for 2-D weights, torch.linalg.norm(ord=2) is the spectral norm
                # (largest singular value), not the Frobenius norm.
                regularization_loss += torch.linalg.norm(param, ord=norm_regularization_order)

        images_loss_term = (
            l_images_loss * torch.sum(images_losses)
            if images_losses is not None else
            torch.tensor(0.0, device=device)
        )
        loss = (
            (l_kl_loss * torch.sum(kl_div_losses))
            + (l_sigmas_loss * torch.sum(sigmas_losses))
            + images_loss_term
            + (l_reg_loss * regularization_loss)
        )
        log(("loss:", loss, loss.shape, loss.requires_grad, loss.is_leaf), should_log)

    pred_images = None
    target_images = None
    images_psnr = None
    if should_learn_images:
        if should_randomize_input:
            # Undo the input shuffling. New tensors are built instead of assigning into
            # slices, so the caller's `y` is left untouched (e.g. on CPU `Y.to(device)`
            # returns `Y` itself, and in-place writes would permute the dataset targets).
            def _unshuffle(t: torch.Tensor) -> torch.Tensor:
                return torch.stack([t[i][reordered_indices[i]] for i in range(t.shape[0])])

            pred = _unshuffle(pred)
            hashed = _unshuffle(hashed)
            y = _unshuffle(y)

        # Called for its side effects: shows and/or saves one indices map per batch element
        # and level (levels without possible collisions are skipped); the returned figures
        # are already closed and not needed here.
        create_indices_mapping(
            pred=pred.detach().cpu().numpy(),
            h=h,
            w=w,
            min_possible_collisions=min_possible_collisions,
            hashed=hashed.detach().cpu().numpy(),
            hash_table_size=model.multiresModel.get_hash_table_size(),
            save_dir=save_indices_maps_dir,
            should_show=(2 in should_log),
            should_log=should_log
        )

        pred_images = (pred * 255).reshape(-1, h, w, 3).to(int).detach().cpu().numpy()
        target_images = (y * 255).reshape(-1, h, w, 3).to(int).detach().cpu().numpy()

        plot_images(pred_images, target_images, should_log, is_test=True)

        images_psnr = [calc_psnr(pred_images[i], target_images[i]) for i in range(pred_images.shape[0])]

    to_return = {
        "pred_images": pred_images,
        "target_images": target_images,
        "images_psnr": images_psnr,
        "min_possible_collisions": min_possible_collisions.detach().cpu().numpy(),
        "collisions": collisions.detach().cpu().numpy() if collisions is not None else None,
        "histograms": hists,
        "kl_div_losses": kl_div_losses.detach().cpu().numpy(),
        "sigmas_losses": sigmas_losses.detach().cpu().numpy(),
        "reg_loss": regularization_loss.item(),
        "images_losses": images_losses.detach().cpu().numpy() if images_losses is not None else None,
        "loss": loss.item(),
    }

    return to_return

## Early Stopper

In [None]:
class EarlyStopper:
    """
    Signals when a monitored loss has stopped improving.

    A call counts as an improvement only when the loss drops below the best loss
    seen so far by at least ``min_delta``. Every other call counts as a "bad" epoch:
    a smaller improvement, an unchanged loss, or any increase. Once ``tolerance``
    bad epochs have accumulated, ``early_stop`` becomes True and stays True.

    Parameters
    ----------
    tolerance : int, optional (default is 5)
        Number of bad epochs that triggers early stopping.
    min_delta : float, optional (default is 0)
        Minimum decrease of the loss that counts as an improvement.
    should_reset : bool, optional (default is True)
        On an improvement, reset the bad-epoch counter to 0 if True; otherwise
        only decrement it by one (never below 0).
    """

    def __init__(self, tolerance: int = 5, min_delta: float = 0, should_reset: bool = True):
        self.tolerance: int = tolerance
        self.min_delta: float = min_delta
        self.best_loss: float = np.inf
        self.counter: int = 0
        self.early_stop: bool = False
        self._should_reset: bool = should_reset

    def __call__(self, loss: float) -> None:
        """Records one epoch's loss and updates ``counter`` / ``early_stop``."""
        improvement = self.best_loss - loss

        if improvement > 0 and improvement >= self.min_delta:
            # Real improvement: remember it in both modes so later calls compare
            # against the new best (not against the initial +inf).
            self.best_loss = loss
            self.counter = 0 if self._should_reset else max(self.counter - 1, 0)
        else:
            # Stall (improvement below min_delta), unchanged loss, or increase.
            self.counter += 1

        if self.counter >= self.tolerance:
            self.early_stop = True

## Initializations

### Arguments

In [None]:
# Images to fit; they are loaded by ImageDataset further below.
train_images_paths = ["./images/strawberry_small.jpg"]

# Weights & Biases destination.
wandb_entity = "fedemonti00"
wandb_project = "project_course_final2"
# wandb_name = "train_test"

# Run timestamp in Europe/Rome time (YYYYmmddHHMMSS). It becomes the run name (also
# used as the checkpoint folder name) unless `wandb_name` is already defined, e.g. by
# uncommenting the line above or by a previous execution of this cell.
time = (datetime.now(ZoneInfo("Europe/Rome"))).strftime("%Y%m%d%H%M%S")
try:
    wandb_name = wandb_name
except NameError:
    wandb_name = time
    
print("RUN:", time)

# Checkpoint to load. When set, the notebook runs in test-only mode (should_test).
# The `X if cond else Y` expressions below are hand-edited switches, so some of them
# currently yield the same value in both branches (e.g. should_train_and_test is always False).
weights_path = None # None to disable
should_test = True if weights_path is not None else False
should_train_and_test = False if weights_path is None else False

# should_fast_hash drops the learnable hash function model (hashModel=None below) and,
# through the hyperparameters, disables the hash optimizer, the scheduler and the
# hash-specific loss terms.
should_fast_hash = False
should_learn_images = True if not should_fast_hash else True
scheduler_type = "CosineAnnealingLR" if not should_fast_hash else None # None, "StepLR", "CosineAnnealingLR", "CosineAnnealingWarmRestarts"

# All settings of the run. The dict is passed as the wandb config and stored in every checkpoint.
hyperparameters = {
    # General
    "time": time,
    "weights_initialization": torch.nn.init.kaiming_normal_, # None, torch.nn.init.xavier_uniform_, torch.nn.init.xavier_normal_, torch.nn.init.kaiming_uniform_, torch.nn.init.kaiming_normal_
    "epochs": 2000,
    "random_seed": random_seed,
    # ------- #
    # Dataset
    "images_paths": train_images_paths,

    "should_randomize_input": False,
    # ------- #
    # LearnableHashFunctionModel
    "hash_hidden_layers_widths": [64],
    "output_size": 2,
    "hash_table_size": 2**8,
    "sigmas_scale": 1.0,
    "hash_hidden_layers_activation": nn.Tanh(), # nn.Tanh(), nn.Hardtanh(), nn.Softsign() # PROBABLY SHOULD BE SYMMETRIC
    "hash_dropout_rate": None,

    # Optimizer settings for the hash function model ("hash" optimizer).
    "hash_lr": 1e-3 if not should_fast_hash else None,
    "hash_weight_decay": 0 if not should_fast_hash else None, # 1e-6
    # ------- #
    # MultiresolutionModel
    "n_min": 8,
    "n_max": 32,
    "num_levels": 4,
    
    "should_fast_hash": should_fast_hash,
    "should_use_all_levels": True if not should_fast_hash else True,
    "should_normalize_grid_coords": True if not should_fast_hash else False,
    "should_calc_collisions": True,
    # ------- #
    # GNGFModel
    "feature_size": 2,
    "gngf_hidden_layers_widths": [64, 64],
    "topk": 5 if not should_fast_hash else 1, # must be odd or -1 to use all the hashed coordinates

    # Per-group settings of the "NeRF" optimizer: hash-table features and MLP.
    "features_lr": 1e-3,
    "features_weight_decay": 0,
    "mlp_lr": 1e-2,
    "mlp_weight_decay": 1e-6,

    "should_circular_topk": True,
    "should_learn_images": should_learn_images,
    # ------- #
    # Loss
    "kl_div_reduction": "batchmean",
    "hist_sanitize_eps": 1e-10,
    "should_sanitize_hist_before": True,
    "should_exp_normalize_kl_div": True,

    # Weights (lambdas) of the terms summed into the total loss.
    "l_kl_loss": 1e1 if not should_fast_hash else 0,
    "l_sigmas_loss": 1 if not should_fast_hash else 0,
    "l_images_loss": 1e2 if should_learn_images else 0,
    "l_reg_loss": 1e0 if not should_fast_hash else 0,
    "norm_regularization_order": 2,

    # Parameters whose name contains any of these substrings get no norm regularization.
    "excluded_params_from_reg_loss": np.array(["_prime_numbers", "hash_tables", "mlp"]),

    "gradient_clipping": 1e-1 if not should_fast_hash else None, # None to disable
    # ------- #
    # Optimizer & Scheduler
    "NeRF_optimizer": torch.optim.Adam if should_learn_images else None,
    "hash_optimizer": torch.optim.AdamW if not should_fast_hash else None,

    # only for hash function model
    # scheduler_gamma is eta_min for the cosine schedulers and gamma for StepLR;
    # scheduler_step_size is T_max / T_0 / step_size respectively.
    # The scheduler is stepped only while epoch < scheduler_last_epoch (-1 means always).
    "scheduler_type": scheduler_type if not should_fast_hash else None,
    "scheduler_gamma": (0.9 if scheduler_type == "StepLR" else 0) if scheduler_type else None,
    "scheduler_step_size": (100 if scheduler_type == "StepLR" else 1000) if scheduler_type else None,
    "scheduler_last_epoch": 1000 if scheduler_type else -1,
    # ------- #
}

# Histograms are computed every `histograms_rate` epochs, on the last epoch and once early stopping fired.
histograms_rate = 100 if not should_train_and_test else 1
# Checkpoints are written only on epochs divisible by this rate (see save_checkpoint);
# saving is enabled only for runs longer than 99 epochs.
save_weights_rate = 1 if (hyperparameters["epochs"] > 99) else None # None to disable
if save_weights_rate is not None:
    os.makedirs(f"./weights", exist_ok=True)

# A tolerance equal to the number of epochs effectively disables early stopping for a
# fresh run: the first loss always counts as an improvement, so at most epochs - 1 bad
# epochs can accumulate.
early_stopper_tolerance = hyperparameters["epochs"] # hyperparameters["epochs"] // 10 if hyperparameters["epochs"] > 99 else hyperparameters["epochs"]
early_stopper_min_delta = 1e-6

# Verbose logging for short (debug) runs, wandb for long runs; both off when loading weights.
# NOTE(review): with exactly 99 epochs neither is enabled (`< 99` vs `> 99`).
should_log = (True if (hyperparameters["epochs"] < 99) else False) if (weights_path is None) else False
should_log_allocated_memory = False if should_log else False
should_log_grads = True if should_log else False
should_wandb = (True if hyperparameters["epochs"] > 99 else False) if (weights_path is None) else False

print_allocated_memory(should_log_allocated_memory)

### Model Weights Initialization

In [None]:
def model_init_weights(model: nn.Module, init_fn: Optional[Callable[[torch.Tensor], Any]] = None):
    """
    Initializes every ``nn.Linear`` / ``nn.Conv2d`` layer found in ``model``.

    Weights are initialized in place with ``init_fn``. By default this is
    ``hyperparameters["weights_initialization"]``, which may be None to keep
    PyTorch's default weight initialization. Biases are always set to 0.

    This walks ``model.modules()`` itself, so call it directly on the root module.
    Passing it to ``model.apply`` also works, but then nested layers are
    re-initialized once per ancestor module.

    Parameters
    ----------
    model : nn.Module
        Module whose linear/convolutional sub-layers are initialized.
    init_fn : Callable | None, optional (default is None)
        In-place weight initializer such as ``torch.nn.init.kaiming_normal_``.
        When None, ``hyperparameters["weights_initialization"]`` is used.
    """
    if init_fn is None:
        init_fn = hyperparameters["weights_initialization"]

    for m in model.modules():
        if isinstance(m, (nn.Linear, nn.Conv2d)):
            if init_fn is not None:
                init_fn(m.weight)

            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

### Wandb

#### Init

In [None]:
# Open a Weights & Biases run only when wandb logging is enabled for this session.
# The run is named `wandb_name` and receives the whole hyperparameter dict as its config.
if should_wandb:
    wandb.init(
        project=wandb_project,
        entity=wandb_entity,
        name=wandb_name,
        save_code=False,
        config=hyperparameters,
    )

#### Logger

In [None]:
def wandb_log(
    e: int,
    batch_size: int,
    train_obj: Dict[str, Any],
    lr: float | None = None,
    should_log: bool = False,
    should_wandb: bool = False
) -> None:
    """
    Flattens a results dictionary into wandb entries and logs them.

    Each entry of ``train_obj`` is handled as follows:

    - None values are skipped.
    - 0-d numpy arrays are logged as plain Python scalars under ``train/{key}``.
    - Lists and (at least 1-D) numpy arrays are split element-wise into
      ``train/{key}_{i}``, or ``train/{key}_level_{i}`` when the sequence is longer
      than ``batch_size`` (presumably a per-level quantity; confirm with the
      producers of ``train_obj``). Elements that are matplotlib figures or numpy
      arrays are wrapped in ``wandb.Image`` under ``train/media/...``.
    - Anything else is logged as-is under ``train/{key}``.

    Parameters
    ----------
    e : int
        Epoch. Currently unused: ``wandb.log`` is called without ``step``.
    batch_size : int
        Batch size, used to tell per-batch-element sequences from per-level ones.
    train_obj : Dict[str, Any]
        Training (or test) results to log.
    lr : float | None, optional (default is None)
        Learning rate, logged as ``lr`` when given.
    should_log : bool, optional (default is False)
        Whether to log or not.
    should_wandb : bool, optional (default is False)
        Whether to log to wandb or not.

    Returns
    -------
    None
    """
    to_log: Dict[str, Any] = {}

    if lr is not None:
        to_log["lr"] = lr

    for key, values in train_obj.items():
        if values is None:
            continue

        if isinstance(values, np.ndarray) and values.ndim == 0:
            # A 0-d array cannot be iterated: log its scalar value instead of crashing.
            to_log[f"train/{key}"] = values.item()
        elif isinstance(values, (np.ndarray, list)):
            prefix = "level_" if len(values) > batch_size else ""
            for i, value in enumerate(values):
                num = prefix + str(i)

                if isinstance(value, (plt.Figure, np.ndarray)):
                    to_log[f"train/media/{key}_{num}"] = wandb.Image(value)
                else:
                    to_log[f"train/{key}_{num}"] = value
        else:
            to_log[f"train/{key}"] = values

    log(("log:", to_log), should_log)
    if should_wandb:
        wandb.log(to_log)

### Load Dataset

In [None]:
# Build the dataset from the training image paths. should_randomize_input presumably
# shuffles the pixel order (test_loop undoes it with `reordered_indices`).
data = ImageDataset(
    train_images_paths,
    should_randomize_input=hyperparameters["should_randomize_input"],
    should_log=[]
    # should_log=[0]
)

# NOTE(review): the item dict is unpacked positionally, so this relies on it having exactly
# these five keys in this order; `data[-1]` presumably returns the whole batch (X has a
# leading batch dimension) — confirm in ImageDataset.
X, Y, h, w, reordered_indices = data[-1].values()

batch_size = X.shape[0]
input_size = X.shape[2] # X = (batch, pixels, input_dim, 1, 1) 

# Disabled sanitiy checks, presumably verifying that indexing a shuffled X with the
# reorder indices gives back the original X.
# print(torch.count_nonzero(X[0] != random_X[0][random_reordered_indices[0, :]]))
# print(torch.count_nonzero(X[1] != random_X[1][random_reordered_indices[1, :]]))
# print(torch.count_nonzero(a["X"]!= b["X"][..., :][b["reordered_indices"]]))


### Load Models

In [None]:
if not hyperparameters["should_fast_hash"]:
    # Learnable hash function model; it is only built when fast hashing is disabled.
    learnableHashFunctionModel = LearnableHashFunctionModel(
        hidden_layers_widths=hyperparameters["hash_hidden_layers_widths"],
        input_size=input_size,
        output_size=hyperparameters["output_size"],
        hash_table_size=hyperparameters["hash_table_size"],
        sigmas_scale=hyperparameters["sigmas_scale"],
        hidden_layers_activation=hyperparameters["hash_hidden_layers_activation"],
        dropout_rate=hyperparameters["hash_dropout_rate"],
        should_log=[] if should_log else [],
        # should_log=[0, 1, 2] if should_log else [],
    )

    if hyperparameters["weights_initialization"] is not None:
        # model_init_weights already walks every sub-module, so it is called once on the
        # root. Through `.apply` it would run once per sub-module and re-initialize each
        # nested layer once per ancestor.
        model_init_weights(learnableHashFunctionModel)

# Multiresolution grid encoding. It uses the learnable hash model when available. In
# fast-hash mode it receives hash_table_size and input_size directly, presumably to
# build its own fixed hashing.
multiresolutionModel = MultiresolutionModel(
    n_min=hyperparameters["n_min"],
    n_max=hyperparameters["n_max"],
    num_levels=hyperparameters["num_levels"],
    hashModel=None if hyperparameters["should_fast_hash"] else learnableHashFunctionModel,
    hash_table_size=hyperparameters["hash_table_size"] if hyperparameters["should_fast_hash"] else None ,
    input_size=input_size if hyperparameters["should_fast_hash"] else None,
    should_use_all_levels=hyperparameters["should_use_all_levels"],
    should_fast_hash=hyperparameters["should_fast_hash"],
    should_calc_collisions=hyperparameters["should_calc_collisions"],
    should_normalize_grid_coords=hyperparameters["should_normalize_grid_coords"],
    should_log=[3] if should_log else [],
    # should_log=[0, 1, 2, 3, 4, 5, 6, 7, 8] if should_log else [],
)

# Top-level model wrapping the multiresolution encoding; its `hash_tables` and `mlp`
# sub-modules get separate parameter groups in the "NeRF" optimizer below.
gngfModel = GNGFModel(
    batch_size=batch_size,
    hidden_layers_widths=hyperparameters["gngf_hidden_layers_widths"],
    multiresModel=multiresolutionModel,
    feature_size=hyperparameters["feature_size"],
    topk=hyperparameters["topk"],
    should_circular_topk=hyperparameters["should_circular_topk"],
    should_learn_images=hyperparameters["should_learn_images"],
    should_log=[] if should_log else [],
    # should_log=[0, 1, 2, 3] if should_log else [],
)

print(gngfModel)

# Loss returning the KL-divergence, sigmas and images terms (see train/test loops),
# configured from the "Loss" section of the hyperparameters.
loss_fn = Loss(
    hash_table_size=hyperparameters["hash_table_size"],
    hist_sanitize_eps=hyperparameters["hist_sanitize_eps"],
    kl_div_reduction=hyperparameters["kl_div_reduction"],
    should_use_all_levels=hyperparameters["should_use_all_levels"],
    should_sanitize_hist_before=hyperparameters["should_sanitize_hist_before"],
    should_exp_normalize_kl_div=hyperparameters["should_exp_normalize_kl_div"],
    should_log=[9] if should_log else []
    # should_log=[0, 1, 2, 3, 4] if should_log else []
)

# Optimizer configurations, keyed by name:
# - "NeRF" (only when learning images): hash-table features and MLP, each with its own
#   lr / weight decay.
# - "hash" (only without fast hashing): the learnable hash function model.
# `opts` holds the optimizer classes in the same order the keys are inserted. The
# resulting `optimizers` dict is indexed later by the same keys ("NeRF", "hash").
models_parameters = {}
opts = []
if hyperparameters["should_learn_images"]:
    models_parameters["NeRF"] = {
        "each": [ # TODO rename
            {
                "param": gngfModel.hash_tables.parameters(),
                "lr": hyperparameters["features_lr"],
                "weight_decay": hyperparameters["features_weight_decay"]
            },
            {
                "param": gngfModel.mlp.parameters(),
                "lr": hyperparameters["mlp_lr"],
                "weight_decay": hyperparameters["mlp_weight_decay"]
            }
        ],
        "betas": (0.9, 0.99),
        "eps": 1e-15
    }
    opts.append(hyperparameters["NeRF_optimizer"])

if not hyperparameters["should_fast_hash"]:
    models_parameters["hash"] = {
        "each": [ # TODO rename
            {
                "param": gngfModel.multiresModel.hashModel.parameters(),
                "lr": hyperparameters["hash_lr"],
                "weight_decay": hyperparameters["hash_weight_decay"]
            }
        ],
        "betas": (0.9, 0.999), # default values
        "eps": 1e-8 # default value
    }
    opts.append(hyperparameters["hash_optimizer"])

optimizers = get_optimizer(
    models_parameters=models_parameters,
    optimizers=opts,
)

# Learning-rate scheduler for the hash-function optimizer only. `scheduler_gamma` is the
# eta_min of the cosine schedules and the decay factor of StepLR; `scheduler_step_size`
# is used as T_max / T_0 / step_size respectively.
if hyperparameters["scheduler_type"] == "CosineAnnealingLR":
    scheduler = CosineAnnealingLR(optimizers["hash"], T_max=hyperparameters["scheduler_step_size"], eta_min=hyperparameters["scheduler_gamma"])
elif hyperparameters["scheduler_type"] == "CosineAnnealingWarmRestarts":
    scheduler = CosineAnnealingWarmRestarts(optimizers["hash"], T_0=hyperparameters["scheduler_step_size"], T_mult=1, eta_min=hyperparameters["scheduler_gamma"])
elif hyperparameters["scheduler_type"] == "StepLR":
    scheduler = StepLR(optimizers["hash"], step_size=hyperparameters["scheduler_step_size"], gamma=hyperparameters["scheduler_gamma"])
elif hyperparameters["scheduler_type"] is not None:
    # Fail fast: otherwise `scheduler` stays undefined and the run loop only crashes
    # later with a NameError at `scheduler.step()`.
    raise ValueError(f"Unknown scheduler_type: {hyperparameters['scheduler_type']!r}")

# Early stopping driven by the per-epoch loss.
early_stopper = EarlyStopper(
    tolerance=early_stopper_tolerance,
    min_delta=early_stopper_min_delta
)

### Load Weights

In [None]:
if weights_path is not None:
    # map_location lets a checkpoint saved on a GPU be loaded on a CPU-only machine.
    # NOTE(review): the checkpoint also pickles non-tensor objects (the hyperparameters dict
    # holds an init function, an nn.Module activation and a numpy array). PyTorch >= 2.6
    # defaults torch.load to weights_only=True, so there this call needs weights_only=False.
    # Full unpickling can run arbitrary code: only load trusted checkpoints.
    checkpoint = torch.load(weights_path, map_location=device)

    gngfModel.load_state_dict(checkpoint["model_state_dict"])

    if "optimizer_NeRF_state_dict" in checkpoint:
        optimizers["NeRF"].load_state_dict(checkpoint["optimizer_NeRF_state_dict"])
    if "optimizer_hash_state_dict" in checkpoint:
        optimizers["hash"].load_state_dict(checkpoint["optimizer_hash_state_dict"])

    # The saved run's settings and name replace the current ones (the models above were
    # already built from the current settings).
    hyperparameters = checkpoint["hyperparameters"]
    wandb_name = checkpoint["run_name"]

    # Continue right after the saved epoch and run exactly one epoch.
    start_epoch = checkpoint["epoch"] + 1
    hyperparameters["epochs"] = start_epoch + 1

    # Rebuild the scheduler so that it continues from the saved epoch.
    if hyperparameters["scheduler_type"] == "CosineAnnealingLR":
        scheduler = CosineAnnealingLR(optimizers["hash"], T_max=hyperparameters["scheduler_step_size"], eta_min=hyperparameters["scheduler_gamma"], last_epoch=checkpoint["epoch"])
    elif hyperparameters["scheduler_type"] == "CosineAnnealingWarmRestarts":
        scheduler = CosineAnnealingWarmRestarts(optimizers["hash"], T_0=hyperparameters["scheduler_step_size"], T_mult=1, eta_min=hyperparameters["scheduler_gamma"], last_epoch=checkpoint["epoch"])
    elif hyperparameters["scheduler_type"] == "StepLR":
        scheduler = StepLR(optimizers["hash"], step_size=hyperparameters["scheduler_step_size"], gamma=hyperparameters["scheduler_gamma"], last_epoch=checkpoint["epoch"])

    # should_log = True
    should_log = False
    should_log_allocated_memory = True if should_log else False
    should_log_grads = True if should_log else False
    # should_wandb = True
    should_wandb = False if not should_test else False

    print("Weights loaded", checkpoint)

### Save Checkpoints

In [None]:
def save_checkpoint(
    epoch: int,
    run_name: str,
    model: nn.Module,
    best_loss: float,
    best_psnr: float,
    loss: float,
    psnr: float | None,
    NeRF_optimizer: torch.optim.Optimizer | None = None,
    hash_optimizer: torch.optim.Optimizer | None = None,
    save_weights_rate: int | None = None,
    should_log: bool = False,
) -> Tuple[float, float]:
    """
    Saves the "best loss", "best PSNR" and "last" checkpoints of a run.

    Files are written to ``./weights/{run_name}/{run_name}_{tag}_checkpoint.pth``:

    - ``tag="loss"``: when ``loss < best_loss``.
    - ``tag="psnr"``: when ``psnr`` is given (and non-zero) and ``psnr > best_psnr``.
    - ``tag="last"``: on the final epoch (``hyperparameters["epochs"] - 1``) or once
      the global ``early_stopper`` has fired.

    Nothing is written when ``save_weights_rate`` is None. The "loss" and "psnr"
    checkpoints are only written on epochs divisible by ``save_weights_rate``. The
    "last" checkpoint ignores the rate, so the final weights are always saved.

    Parameters
    ----------
    epoch : int
        Current epoch.
    run_name : str
        Run name, used for the folder and the file names.
    model : nn.Module
        Model whose ``state_dict`` is saved.
    best_loss : float
        Best loss so far.
    best_psnr : float
        Best PSNR so far.
    loss : float
        Loss of the current epoch.
    psnr : float | None
        Mean PSNR of the current epoch, or None when images are not learned.
    NeRF_optimizer : torch.optim.Optimizer | None, optional (default is None)
        Optimizer whose state is saved as ``optimizer_NeRF_state_dict``.
    hash_optimizer : torch.optim.Optimizer | None, optional (default is None)
        Optimizer whose state is saved as ``optimizer_hash_state_dict``.
    save_weights_rate : int | None, optional (default is None)
        Save period in epochs; None disables saving.
    should_log : bool, optional (default is False)
        Whether to log the checkpoint dict or not.

    Returns
    -------
    Tuple[float, float]
        The updated ``(best_loss, best_psnr)``. The bests are updated even on epochs
        where the rate prevented writing the checkpoint.
    """
    checkpoint = {
        "epoch": epoch,
        "run_name": run_name,
        "hyperparameters": hyperparameters,
        "model_state_dict": model.state_dict(),
        "loss": loss,
        "psnr": psnr,
    }
    if NeRF_optimizer is not None:
        checkpoint["optimizer_NeRF_state_dict"] = NeRF_optimizer.state_dict()
    if hash_optimizer is not None:
        checkpoint["optimizer_hash_state_dict"] = hash_optimizer.state_dict()

    save_dir = f"./weights/{run_name}"

    def _save(tag: str) -> None:
        """Writes `checkpoint` to `{save_dir}/{run_name}_{tag}_checkpoint.pth`."""
        os.makedirs(save_dir, exist_ok=True)
        torch.save(checkpoint, f"{save_dir}/{run_name}_{tag}_checkpoint.pth")

    is_save_epoch = save_weights_rate is not None and epoch % save_weights_rate == 0

    if loss < best_loss:
        best_loss = loss
        if is_save_epoch:
            _save("loss")

    if psnr and psnr > best_psnr:
        best_psnr = psnr
        if is_save_epoch:
            _save("psnr")

    is_last_epoch = (epoch == hyperparameters["epochs"] - 1) or early_stopper.early_stop
    if is_last_epoch and save_weights_rate is not None:
        _save("last")

    log(("Checkpoint saved:", checkpoint), should_log, color=bcolors.OKGREEN)

    return best_loss, best_psnr

### Run

In [None]:
plt.ioff()
should_calc_hists = False

try:
    start_epoch
except NameError:
    start_epoch = 0

pbar = tqdm(range(start_epoch, hyperparameters["epochs"]))
best_loss = np.inf
best_psnr = 0.0

# True when this session updates the weights (train, or train + test); False for a test-only run.
is_training_run = not should_test or should_train_and_test

print_allocated_memory(should_log_allocated_memory)

for e in pbar:
    # Histograms: every `histograms_rate` epochs, on the last epoch and once early stopping fired.
    should_calc_hists = ((e == hyperparameters["epochs"] - 1) or (e % histograms_rate == 0) or early_stopper.early_stop) if histograms_rate is not None else False

    if hyperparameters["should_fast_hash"] and (e > 0):
        should_calc_hists = False

    for optimizer in optimizers.values():
        optimizer.zero_grad()

    if is_training_run:
        results_dict = train_loop(
            x=X.to(device),
            y=Y.to(device),
            reordered_indices=reordered_indices,
            h=h,
            w=w,
            model=gngfModel,
            optimizers=optimizers,
            loss_fn=loss_fn,
            l_kl_loss=hyperparameters["l_kl_loss"],
            l_sigmas_loss=hyperparameters["l_sigmas_loss"],
            l_images_loss=hyperparameters["l_images_loss"],
            l_reg_loss=hyperparameters["l_reg_loss"],
            norm_regularization_order=hyperparameters["norm_regularization_order"],
            gradient_clipping=hyperparameters["gradient_clipping"],
            should_calc_hists=should_calc_hists,
            should_learn_images=hyperparameters["should_learn_images"],
            excluded_params_from_reg_loss=hyperparameters["excluded_params_from_reg_loss"],
            should_randomize_input=hyperparameters["should_randomize_input"],
            should_log=[0, 2] if should_log else [],
            # should_log=[0, 1] if should_log else [],
            should_log_grads=should_log_grads
        )

        best_loss, best_psnr = save_checkpoint(
            epoch = e,
            run_name = wandb_name,
            model = gngfModel,
            best_loss = best_loss,
            best_psnr = best_psnr,
            loss = results_dict["loss"],
            psnr = np.mean(results_dict["images_psnr"]) if hyperparameters["should_learn_images"] else None,
            NeRF_optimizer = optimizers["NeRF"] if hyperparameters["should_learn_images"] else None,
            hash_optimizer = optimizers["hash"] if not hyperparameters["should_fast_hash"] else None,
            save_weights_rate = save_weights_rate,
            should_log=False if should_log else False,
        )

    if should_test or should_train_and_test:
        results_dict = test_loop(
            x=X.to(device),
            y=Y.to(device),
            reordered_indices=reordered_indices,
            h=h,
            w=w,
            model=gngfModel,
            loss_fn=loss_fn,
            l_kl_loss=hyperparameters["l_kl_loss"],
            l_sigmas_loss=hyperparameters["l_sigmas_loss"],
            l_images_loss=hyperparameters["l_images_loss"],
            l_reg_loss=hyperparameters["l_reg_loss"],
            norm_regularization_order=hyperparameters["norm_regularization_order"],
            should_calc_hists=True if should_test else False,
            should_learn_images=hyperparameters["should_learn_images"],
            should_randomize_input=hyperparameters["should_randomize_input"],
            excluded_params_from_reg_loss=hyperparameters["excluded_params_from_reg_loss"],
            should_log=[0, 2],
            save_indices_maps_dir=f"./weights/{wandb_name}" if should_test else None,
        )

    # Check for NaN before stepping, so NaN gradients never overwrite the current weights.
    if np.isnan(results_dict["loss"]):
        log(("!!! NaN !!!", ), True, color=bcolors.FAIL)
        break

    # A test-only run must not update the weights.
    if is_training_run:
        for optimizer in optimizers.values():
            # TODO HOW?
            ######
            # Lastly, we skip Adam steps for hash table entries whose gradient is exactly 0. 
            # This saves ∼ 10% performance when gradients are sparse, which is a common occurrence with 𝑇 ≫ BatchSize. 
            # Even though this heuristic violates some of the assumptions behind Adam, we observe no degradation in convergence.
            ######
            optimizer.step()

        # The scheduler is stepped only while e < scheduler_last_epoch (-1 means always).
        if (
            (hyperparameters["scheduler_type"] is not None)
            and
            (
                hyperparameters["scheduler_last_epoch"] == -1
                or
                hyperparameters["scheduler_last_epoch"] > e
            )
        ):
            scheduler.step()

    wandb_log(
        e=e,
        batch_size=batch_size,
        lr=scheduler.get_last_lr()[0] if (hyperparameters["scheduler_type"] is not None) else None,
        train_obj=results_dict,
        should_log=False if should_log else False,
        should_wandb=should_wandb
    )

    # early_stop was set by the previous epoch's call; this epoch already ran so that
    # save_checkpoint could write the "last" checkpoint.
    if early_stopper.early_stop:
        print("!!! Stopping at epoch:", e, "!!!")

        del results_dict
        break

    early_stopper(results_dict["loss"])

    pbar.set_description(f"Epoch {e} - Loss: {results_dict['loss']}, Collisions: {results_dict['collisions']}, PSNR: {results_dict['images_psnr']}")

    print_allocated_memory(should_log_allocated_memory)
    del results_dict


### Wandb finish and free memory

In [None]:
# Teardown: close the wandb run opened at the start (if any)...
if should_wandb:
    wandb.finish()

# ...and hand cached CUDA allocations back to the driver.
if torch.cuda.is_available():
    torch.cuda.empty_cache()