# NeRF: Collision handling in Instant Neural Graphics Primitives
#### Federico Montagna (fedemonti00@gmail.com)

# Code:

## Imports

In [None]:
from typing import Tuple, List

import numpy as np
from math import atan2
import cv2
from PIL import Image
import os
import matplotlib
from matplotlib.ticker import MultipleLocator, AutoMinorLocator, FixedLocator
import matplotlib.pyplot as plt
import random
from scipy.stats import norm
from sklearn.model_selection import train_test_split

import collections, functools, operator
from collections import Counter
import itertools

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler
from torch.optim.lr_scheduler import StepLR, CosineAnnealingLR, CosineAnnealingWarmRestarts

# from torchinfo import summary

from einops import rearrange, reduce, repeat

import traceback
from pprint import pprint

import wandb
from tqdm import tqdm

import pdb

try:
    from zoneinfo import ZoneInfo
except ImportError:
    from backports.zoneinfo import ZoneInfo

from datetime import datetime

# from params import *

# Report whether CUDA is usable and list every visible CUDA device.
print("Cuda avilable:", torch.cuda.is_available())
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"Available device {i}:", torch.cuda.get_device_name(i))

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"Current device {torch.cuda.current_device()}:", torch.cuda.get_device_name(torch.cuda.current_device()))

# Tensors created without an explicit device are placed on `device` from here on.
torch.set_default_device(device)

# A new seed is drawn on every run, applied to the Python, NumPy and PyTorch
# (CPU and, when present, CUDA) generators, and printed below.
random_seed = np.random.randint(0, (2**16 - 1)) # randint's upper bound is exclusive
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.random.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)
print("Random seed:", random_seed)

# Notebook name exposed to wandb through its environment variable.
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"

# Global matplotlib style for every figure created afterwards.
plt.style.use("ggplot")

## Debug Functions

In [None]:
class bcolors:
    """ANSI escape sequences used to colour and style text printed to a terminal."""

    # Colour codes, named after the kind of message they are used for.
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    # Reset code: emitted after a styled span to restore the default style.
    ENDC = '\033[0m'
    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def print2(texts, log: bool = False, color: bcolors = bcolors.OKCYAN) -> None:
    """
    Debug print helper that is silent unless `log` is true.

    When enabled it prints the source line of the caller, then every item of
    `texts` on its own line, then a separator.

    Parameters
    ----------
    texts : iterable
        Items to print, one per line.
    log : bool, optional (default is False)
        Whether anything is printed at all.
    color : str, optional (default is bcolors.OKCYAN)
        Escape code placed before the caller line and the separator.

    Returns
    -------
    None
    """
    if not log:
        return

    # Last entry of the stack is this function; the one before it is the caller.
    caller = traceback.extract_stack()[-2]
    print(color, "Line: ", caller.line, bcolors.ENDC)

    for item in texts:
        print(item)

    print(color, "-"*20, bcolors.ENDC)


def print_allocated_memory(log: bool = True):
    """
    Print the current and peak CUDA memory allocated by PyTorch, in gigabytes.

    The source line of the caller is printed first so the report can be traced
    back to the place that requested it.

    Parameters
    ----------
    log : bool, optional (default is True)
        Whether anything is printed at all.

    Returns
    -------
    None
    """
    if not log:
        return

    # Last entry of the stack is this function; the one before it is the caller.
    caller = traceback.extract_stack()[-2]
    print(bcolors.HEADER, "Line: ", caller.line, bcolors.ENDC)

    bytes_per_gigabyte = 1024 ** 3

    allocated_memory = torch.cuda.memory_allocated() / bytes_per_gigabyte
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    peak_memory = torch.cuda.max_memory_allocated() / bytes_per_gigabyte
    print(f"Peak Allocated Memory: {peak_memory:.2f} GB")

    print(bcolors.OKCYAN, "-"*20, bcolors.ENDC)

## Load wandb apikey

In [None]:
# apikey_path = ".wandb_apikey.txt"
# if os.path.exists(apikey_path):
#     with open(apikey_path, "r") as f:
#         apikey = f.read()
#         !wandb login {apikey} # --relogin

## Dataset

In [None]:
class ImagesDataset(Dataset):
    """
    Dataset that reads images from ``root/dir_name`` and exposes each of them as
    a flat list of (pixel coordinate, pixel value) pairs.
    """

    def __init__(
        self,
        root: str,
        dir_name: str,
        images_names: List[str],
        should_random_permute_input: bool = False,
    ) -> None:
        """

        Parameters
        ----------
        root : str
            Path to root directory.
        dir_name : str
            Name of directory.
        images_names : List[str]
            List of images names.
        should_random_permute_input : bool, optional (default is False)
            Should random permute input
        
        Returns
        -------
        None
        """

        self._root: str = root
        self._dir_name: str = dir_name
        self._images_names: List[str] = images_names

        self._should_random_permute_input: bool = should_random_permute_input

        self._images_paths: List[str] = [
            os.path.join(self._root, self._dir_name, image_name)
            for image_name in self._images_names
        ]

    @staticmethod
    def _normalize_indices(idx, num_images: int) -> torch.Tensor:
        """
        Turn any accepted index into a 1-D tensor of image indices.

        A scalar ``-1`` (Python/NumPy integer, or a tensor holding exactly one
        element equal to -1) selects every image. Any other scalar is wrapped
        into a one-element tensor so it can be iterated like a batch of indices.
        """
        if torch.is_tensor(idx):
            select_all = idx.numel() == 1 and int(idx) == -1
        else:
            select_all = isinstance(idx, (int, np.integer)) and idx == -1

        if select_all:
            return torch.arange(num_images)

        indices = idx if torch.is_tensor(idx) else torch.tensor(idx)
        # A 0-d tensor can be neither iterated nor queried with `.shape[0]`.
        return torch.atleast_1d(indices)

    def __getitem__(
        self, idx: "torch.Tensor | int"
    ) -> "Tuple[torch.Tensor, torch.Tensor, int, int, torch.Tensor, List[str]]":
        """
        Get data by indices.

        Parameters
        ----------
        idx : torch.Tensor or int
            Indices of data. A scalar -1 selects every image of the dataset.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, int, int, torch.Tensor, List[str]]
            - X: pixel coordinates divided by max(w, h),
              shape (n_images, h * w, 2, 1, 1).
            - Y: pixel values divided by 255, shape (n_images, h * w, channels).
            - h: common image height.
            - w: common image width.
            - reordered_indices: int tensor of shape (n_images, h * w);
              ``Y[i][reordered_indices[i]]`` is image i in row-major pixel order.
            - The names of ALL images of the dataset (not only the selected ones).

        Raises
        ------
        IndexError
            If no image is selected.
        ValueError
            If images have different sizes.
        """

        indices = self._normalize_indices(idx, len(self._images_paths))

        images: List = [
            rearrange(io.read_image(self._images_paths[image_id]).to(device), "rgb h w -> h w rgb")
            for image_id in indices.tolist()
        ]

        if not images:
            raise IndexError("No images selected.")

        if len({(image.shape[0], image.shape[1]) for image in images}) > 1:
            raise ValueError("Images have different sizes.")

        h, w = images[0].shape[0], images[0].shape[1]
        imgs_shape = h * w

        # (row, col) coordinate of every pixel in row-major order. It is the same
        # for every image, so it is built once instead of once per image.
        pixel_grid = torch.tensor(
            np.stack(np.meshgrid(range(h), range(w), indexing="ij"), axis=-1).reshape(-1, 2)
        )

        reordered_indices: torch.Tensor = torch.stack([
            torch.zeros((imgs_shape, )).int()
            for _ in images
        ])

        X: torch.Tensor = torch.zeros(len(images), imgs_shape, 2)
        Y: torch.Tensor = torch.zeros(len(images), imgs_shape, images[0].shape[-1])

        for i, image in enumerate(images):
            if self._should_random_permute_input:
                shuffled_indices = torch.randperm(imgs_shape).int()
            else:
                shuffled_indices = torch.arange(imgs_shape).int()

            # Inverse permutation: position of every original pixel after shuffling.
            reordered_indices[i][shuffled_indices] = torch.arange(imgs_shape).int()

            X[i] = pixel_grid[shuffled_indices]
            Y[i] = rearrange(image, "h w rgb -> (h w) rgb")[shuffled_indices]

        X = (
            X.float() / (max(w, h)) # No (max(w, h) - 1)
        ).unsqueeze(-1).unsqueeze(-1)

        Y = Y.float() / 255

        return X, Y, h, w, reordered_indices, self._images_names

    def __len__(self) -> int:
        """
        Returns
        -------
        int
            Length of dataset.
        """
        return len(self._images_paths)

## Models

### Backward Pass Differentiable Approximation
[Pytorch-BPDA](https://github.com/kitayama1234/Pytorch-BPDA)

In [None]:
def BPDA(x, round_function):
    """
    Apply `round_function` in the forward pass while keeping the autograd
    history of a plain copy of `x`.

    The returned tensor is a clone of `x` whose stored values are overwritten
    with `round_function(x)`; the overwrite goes through `.data`, so it is not
    recorded by autograd.
    """
    surrogate = x.clone()
    surrogate.data = round_function(x).data
    return surrogate


def differentiable_floor(x, round_function=torch.floor):
    """Return `BPDA(x, round_function)`; the default function is `torch.floor`."""
    return BPDA(x, round_function)


def differentiable_round(x, round_function=torch.round):
    """Return `BPDA(x, round_function)`; the default function is `torch.round`."""
    return BPDA(x, round_function)

### Gaussian Convolution Backward

In [None]:
class GaussianConvolution(torch.autograd.Function):
    """
    Custom autograd function: the forward pass scales the input by
    `hash_table_size` and rounds it; the backward pass weights the incoming
    gradient with a unit-variance normal density evaluated at every integer in
    [0, hash_table_size).
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, hash_table_size: int):
        """
        Return `differentiable_round(x * hash_table_size)`.

        `x` and `hash_table_size` are stored on `ctx` for the backward pass.
        """
        # pdb.set_trace()
        
        ctx.save_for_backward(x)
        ctx.hash_table_size = hash_table_size

        indices = differentiable_round(x * hash_table_size)

        # ctx.indices = indices
        # ctx.save_for_backward(indices)

        return indices

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """
        Return the gradient for `x` and `None` for `hash_table_size`.

        The gradient is `grad_output` multiplied by the normal pdf (scale 1)
        centred on the un-rounded index `x * hash_table_size`, evaluated at
        `np.arange(0, hash_table_size, 1)`.
        """
        # pdb.set_trace()
        
        # indices = ctx.saved_tensors
        x, = ctx.saved_tensors
        hash_table_size = ctx.hash_table_size

        # Un-rounded indices (the forward pass returned their rounded version).
        indices = x * hash_table_size
        # indices = ctx.indices

        # NOTE(review): `loc` is a torch tensor handed to SciPy, and the pdf is evaluated
        # on `hash_table_size` points. If `x` has a trailing axis of size 1 (as it seems
        # to when called from HashFunction), broadcasting makes `grad_x` end with an axis
        # of size `hash_table_size`, which would no longer match the shape of `x`.
        # TODO confirm the intended reduction and whether this path is exercised.
        grad_x = grad_output * norm.pdf(np.arange(0, hash_table_size, 1), loc=indices, scale=1)

        return grad_x, None

### Differentiable Histogram

In [None]:
#############################################
# Differentiable Histogram Counting Method
#############################################
# https://github.com/hyk1996/pytorch-differentiable-histogram
def differentiable_histogram(x, bins=255, min=0.0, max=1.0):
    """
    Differentiable histogram based on triangular (linear) bin membership.

    Every value contributes to a bin with a weight that depends linearly on its
    distance to the neighbouring bin edges, so the result is differentiable
    with respect to `x`.

    Parameters
    ----------
    x : torch.Tensor
        Tensor with 4 dimensions (the first two are treated as samples and
        channels, one histogram each) or with 2 dimensions (treated as a single
        sample with a single channel).
    bins : int, optional (default is 255)
        Number of bins.
    min : float, optional (default is 0.0)
        Lower edge of the first bin.
    max : float, optional (default is 1.0)
        Upper end of the histogram range; the bin width is (max - min) / bins.

    Returns
    -------
    torch.Tensor
        Histogram of shape (n_samples, n_chns, bins). The first and the last
        bin are never filled and stay at 0.

    Raises
    ------
    AssertionError
        If `x` does not have 2 or 4 dimensions.
    """

    if len(x.shape) == 4:
        n_samples, n_chns, _, _ = x.shape
    elif len(x.shape) == 2:
        n_samples, n_chns = 1, 1
    else:
        raise AssertionError('The dimension of input tensor should be 2 or 4.')

    hist_torch = torch.zeros(n_samples, n_chns, bins).to(x.device)
    delta = (max - min) / bins

    # Bin edges start at `min`; without the offset any range not starting at 0
    # would be compared against the wrong edges. Converting to a Python list
    # once avoids one `.item()` call per edge inside the loop.
    bin_edges = (min + torch.arange(start=0, end=bins, step=1) * delta).tolist()

    for dim in range(1, bins - 1):
        h_r_sub_1 = bin_edges[dim - 1]   # h_(r-1)
        h_r = bin_edges[dim]             # h_r
        h_r_plus_1 = bin_edges[dim + 1]  # h_(r+1)

        # Values in [h_(r-1), h_r) and in [h_r, h_(r+1)) respectively.
        mask_sub = ((h_r > x) & (x >= h_r_sub_1)).float()
        mask_plus = ((h_r_plus_1 > x) & (x >= h_r)).float()

        hist_torch[:, :, dim] += torch.sum(((x - h_r_sub_1) * mask_sub).view(n_samples, n_chns, -1), dim=-1)
        hist_torch[:, :, dim] += torch.sum(((h_r_plus_1 - x) * mask_plus).view(n_samples, n_chns, -1), dim=-1)

    return hist_torch / delta


#############################################
# Kornia Differentiable Histogram Counting Method
#############################################
# https://kornia.readthedocs.io/en/latest/_modules/kornia/enhance/histogram.html#histogram

def marginal_pdf(
    values: torch.Tensor, bins: torch.Tensor, sigma: torch.Tensor, epsilon: float = 1e-10
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Estimate a normalized, Gaussian-smoothed distribution of `values` over `bins`.

    Args:
        values: tensor with 3 dimensions, documented as shape [BxNx1].
        bins: 1-D tensor of bin positions, shape [NUM_BINS].
        sigma: 0-dimensional tensor, gaussian smoothing factor.
        epsilon: scalar added to the normalization term for numerical stability.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]:
          - torch.Tensor: kernel responses averaged over dimension 1 and
            normalized over the bins, shape [BxNUM_BINS] for [BxNx1] values.
          - torch.Tensor: raw kernel responses, shape [BxNxNUM_BINS] for
            [BxNx1] values.

    Raises:
        TypeError: if `values`, `bins` or `sigma` is not a torch.Tensor.
        ValueError: if `values`, `bins` or `sigma` does not have 3, 1 or 0
            dimensions respectively.
    """

    named_inputs = (("values", values), ("bins", bins), ("sigma", sigma))
    for name, candidate in named_inputs:
        if not isinstance(candidate, torch.Tensor):
            raise TypeError(f"Input {name} type is not a torch.Tensor. Got {type(candidate)}")

    expected_layouts = (
        ("values", values, 3, "BxNx1"),
        ("bins", bins, 1, "NUM_BINS"),
        ("sigma", sigma, 0, "1"),
    )
    for name, candidate, rank, layout in expected_layouts:
        if candidate.dim() != rank:
            raise ValueError(f"Input {name} must be a of the shape {layout}. Got {candidate.shape}")

    # Distance of every value to every bin, then a Gaussian kernel on it.
    residuals = values - bins[None, None, :]
    kernel_values = torch.exp(-0.5 * (residuals / sigma).pow(2))

    # Average over dimension 1, then make every row sum to (almost) one.
    pdf = kernel_values.mean(dim=1)
    pdf = pdf / (pdf.sum(dim=1, keepdim=True) + epsilon)

    return pdf, kernel_values

def histogram(x: torch.Tensor, bins: torch.Tensor, bandwidth: torch.Tensor = None, epsilon: float = 1e-10) -> torch.Tensor:
    """Estimate a smooth histogram of the input tensor with kernel density estimation.

    The normalized distribution returned by `marginal_pdf` is multiplied by
    `N + epsilon`, where `N = x.size(1)`, so the values are on the scale of
    counts rather than probabilities.

    Args:
        x: Input tensor to compute the histogram with shape :math:`(B, D)`.
        bins: 1-D tensor with the bin positions, shape :math:`(N_{bins})`.
        bandwidth: 0-dimensional tensor with the gaussian smoothing factor.
            When None it is derived from `x` as
            ``0.9 * min(std, IQR / 1.34) * N ** -0.2``.
        epsilon: A scalar, for numerical stability.

    Returns:
        Computed histogram of shape :math:`(B, N_{bins})`.

    Examples:
        >>> x = torch.rand(1, 10)
        >>> bins = torch.linspace(0, 255, 128)
        >>> hist = histogram(x, bins, bandwidth=torch.tensor(0.9))
        >>> hist.shape
        torch.Size([1, 128])
    """

    N = x.size(1)

    if bandwidth is None:
        # Inter-quartile range computed with NumPy on a detached host copy.
        samples = x.detach().cpu().numpy()
        q25 = np.percentile(samples, 25)
        q75 = np.percentile(samples, 75)
        iqr = torch.tensor(q75 - q25)

        std = torch.std(x)
        bandwidth = 0.9 * torch.min(std, iqr / 1.34) * (N **(-0.2))

    pdf, _ = marginal_pdf(x.unsqueeze(2), bins, bandwidth, epsilon)

    return pdf * (N + epsilon)


### Hash Function

[Student's t table](https://www.craftonhills.edu/current-students/tutoring-center/mathematics-tutoring/distribution_tables_normal_studentt_chisquared.pdf)  
$\alpha = 0.8$, $DF = +\infty$

In [None]:
class HashFunction(nn.Module):
    """
    Learned hash function: a small fully connected network that maps
    coordinates to (rounded) hash table indices.
    """

    def __init__(
        self,
        hidden_layers_widths: List[int],
        input_dim: int = 2,
        output_dim: int = 1,
        hash_table_size: int = 2**14,
        sigma_scale: float = 1.0,
        should_log: int = 0, 
    ) -> None:
        """
        Hash function module.

        Parameters
        ----------
        hidden_layers_widths : List[int]
            List of hidden layers widths.
        input_dim : int, optional (default is 2)
            Input dimension
        output_dim : int, optional (default is 1)
            Output dimension. With 1 the output is used directly as index; with
            more, component 0 is used as mean and component 1 as sigma.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        sigma_scale : float, optional (default is 1.0)
            Sigma scale.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log layers grads and outputs.

        Returns
        -------
        None
        """

        super(HashFunction, self).__init__()

        # Stored as size - 1: the sigmoid output in [0, 1] is multiplied by this value.
        self._hash_table_size: int = hash_table_size - 1
        self._sigma_scale: float = sigma_scale

        self._alpha: float = 0.8 # student's t-distribution confidence level
        
        self._should_log: int = should_log

        layers_widths: List[int] = [input_dim, *hidden_layers_widths, output_dim]

        # Linear + ReLU for every layer, except the last one which ends with a Sigmoid.
        self.module_list: nn.ModuleList = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=layers_widths[i], out_features=layers_widths[i + 1]),
                nn.ReLU() if (i < (len(layers_widths) - 2)) else nn.Sigmoid()
            ) for i in range(len(layers_widths) - 1)
        ])

    def forward(self, x: torch.Tensor) -> "Tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]":
        """
        Hash the given coordinates.

        Parameters
        ----------
        x : torch.Tensor
            Input coordinates; the last dimension must match `input_dim`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor or None]
            - The network output used to build the indices (the whole output
              when the last dimension is 1, otherwise its component 0).
            - The indices: that output multiplied by (hash_table_size - 1) and rounded.
            - Sigma (component 1 multiplied by `sigma_scale`), or None when the
              last dimension of the network output is 1.
        """
        log_forward = self._should_log > 0
        log_grads = self._should_log > 1

        # The log arguments (tensor formatting, confidence bounds) are only built when
        # logging is enabled. print2 echoes the source text of its call site, so the
        # calls themselves are kept textually unchanged.
        if log_forward:
            print2(("Input:", x, x.shape), (self._should_log > 0))
            print2(("Input grad:", x.requires_grad, ), (self._should_log > 0))

        for layer in self.module_list:
            x = layer(x)

        if log_forward:
            print2((f"After layers:", f"Output: {x}, shape: {x.shape}"), (self._should_log > 0))
        if log_grads:
            print2((f"After layers:", f"Grad info: {x.requires_grad}, {x.grad_fn}"), (self._should_log > 1))
        
        # x = torch.nan_to_num(x) # Sanitize nan to 0.0

        if x.shape[-1] == 1: # directly indices
            indices = GaussianConvolution.apply(x, self._hash_table_size)
            if log_forward:
                print2(("GaussianConvolution:", indices, indices.shape), (self._should_log > 0))
            if log_grads:
                print2(("GaussianConvolution grad:", indices.requires_grad, ), (self._should_log > 1))

            sigma = None
        else: # mu and sigma            
            x = x.unsqueeze(-1)
            if log_forward:
                print2(("x:", x, x.shape), (self._should_log > 0))

            sigma = (x[..., 1, :] * self._sigma_scale)
            x = x[..., 0, :]
            mu = (differentiable_round(x * self._hash_table_size))

            if log_forward:
                print2(("Mu:", mu, mu.shape), (self._should_log > 0))
                print2(("Sigma:", sigma, sigma.shape), (self._should_log > 0))

                # The confidence interval is only ever logged, so it is not computed
                # (and nothing is copied to the host) when logging is off.
                a, b = norm.interval(self._alpha, loc=mu.detach().cpu().numpy(), scale=sigma.detach().cpu().numpy()) # a -> at left of mean, b -> at right of mean
                a, b = np.round(a), np.round(b)

                print2(("a:", a, a.shape), (self._should_log > 0))
                print2(("b:", b, b.shape), (self._should_log > 0))

            indices = mu

        return x, indices, sigma

### Multiresolution

In [None]:
class Multiresolution(nn.Module):
    def __init__(
        self,
        n_min: int,
        n_max: int,
        num_levels: int,
        HashFunction: HashFunction,
        hash_table_size: int = 2**14,
        input_dim: int = 2,
        should_use_all_levels: bool = False,
        should_normalize_levels: bool = False,
        should_fast_hash: bool = False,
        should_log: int = 0
    ) -> None:
        """
        Build the multiresolution grid description.

        The resolution of level ``l`` is ``floor(n_min * b**l)`` with
        ``b = exp((ln(n_max) - ln(n_min)) / (num_levels - 1))``.

        Parameters
        ----------
        n_min : int
            Minimum scaling factor (resolution of the first level).
        n_max : int
            Maximum scaling factor.
        num_levels : int
            Number of levels.
        HashFunction : HashFunction
            Hash function module.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        input_dim : int, optional (default is 2)
            Input dimension.
        should_use_all_levels : bool, optional (default is False)
            Whether to use all levels or only the ones with collisions.
        should_normalize_levels : bool, optional (default is False)
            Whether to normalize levels or not.
        should_fast_hash : bool, optional (default is False)
            Whether to use fast hash instead of HashFunction or not.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log helper functions.
            - > 2: Log collisions.
            - > 3: Log dummies.
            - > 5: Log initialization.

        Returns
        -------
        None
        """

        super().__init__()

        # Configuration.
        self._should_log: int = should_log
        self._should_fast_hash: bool = should_fast_hash
        self._should_normalize_levels: bool = should_normalize_levels
        self._should_use_all_levels: bool = should_use_all_levels
        self._input_dim: int = input_dim
        self._num_levels: int = num_levels
        self._hash_table_size: int = hash_table_size

        self.HashFunction: nn.Module = HashFunction

        # Growth factor between the resolutions of two consecutive levels.
        log_ratio = np.log(n_max) - np.log(n_min)
        growth: torch.Tensor = torch.tensor(np.exp(log_ratio / (self._num_levels - 1))).float()
        if growth > 2 or growth <= 1:
            print(
                f"The between level scale is recommended to be <= 2 and needs to be > 1 but was {growth:.4f}."
            )

        # Resolution of every level, shaped (1, 1, num_levels, 1).
        per_level_resolution = [
            torch.floor(n_min * (growth ** level)) for level in range(self._num_levels)
        ]
        self._levels: torch.Tensor = torch.stack(per_level_resolution).reshape(1, 1, -1, 1)
        print2(("Levels:", self._levels, self._levels.shape), (self._should_log > 5))
        print2(("Levels grads:", self._levels.requires_grad, ), (self._should_log > 5))

        # 0/1 offsets added to a floored coordinate to obtain the vertices of the
        # cell that contains it; one row per input dimension, one column per vertex.
        corner_grid = np.stack(
            np.meshgrid(range(2), range(2), range(self._input_dim - 1), indexing="ij"), axis=-1
        )
        corner_offsets = rearrange(
            torch.tensor(corner_grid),
            "cols rows depths verts -> (depths rows cols) verts"
        )
        self._voxels_helper_hypercube: torch.Tensor = (
            corner_offsets.T[:self._input_dim, :].unsqueeze(0).unsqueeze(2)
        )
        print2(("voxels_helper_hypercube:", self._voxels_helper_hypercube, self._voxels_helper_hypercube.shape), (self._should_log > 5))
        print2(("voxels_helper_hypercube grads:", self._voxels_helper_hypercube.requires_grad), (self._should_log > 5))

        if self._should_fast_hash:
            # Constants of the fast hash, registered as a parameter without gradient.
            primes = torch.from_numpy(np.array([1, 2654435761, 805459861])).to(device)
            self._prime_numbers = torch.nn.Parameter(primes, False)
    
    def forward(
        self, 
        x: torch.Tensor, 
        should_calc_hists: bool = False,
    ) -> Tuple[torch.Tensor]:
        """
        Hash the grid vertices surrounding every input coordinate at every level
        and measure how many of them collide.

        Parameters
        ----------
        x : torch.Tensor
            Input coordinates.
        should_calc_hists : bool, optional (default is False)
            Whether the collision histograms are computed.

        Returns
        -------
        tuple
            hashed, unique_hashed, unique_probs, unique_sigmas, collisions,
            min_possible_collisions and hists (None unless `should_calc_hists`).
            The `unique_*` entries are lists with one tensor per level.
        """
        print2(("x grads:", x.requires_grad, ), (self._should_log > 0))

        _, grid_coords = self._scale_to_grid(x)
        dummy_grids, og_indices, _ = self._calc_dummies(grid_coords=grid_coords)
        min_possible_collisions, _ = self.calc_hash_collisions(dummy_grids=dummy_grids)

        if self._should_fast_hash:
            hashed = self._fast_hash(grid_coords)
            probs = torch.ones_like(hashed)
            sigmas = torch.zeros_like(hashed)
        else:
            if self._should_use_all_levels:
                probs, hashed, sigmas = self.HashFunction(grid_coords)
            else:
                probs, hashed, sigmas = self._multiresolution_hash(grid_coords, min_possible_collisions)

            # A hash function with a single output has no sigma and returns None;
            # use zeros, as the fast-hash branch does, so the code below (which reads
            # `sigmas.shape` and indexes `sigmas`) keeps working.
            if sigmas is None:
                sigmas = torch.zeros_like(hashed)

        print2(("hashed:", hashed, hashed.shape), (self._should_log > 0))
        print2(("hashed grads:", hashed.requires_grad, ), (self._should_log > 0))

        print2(("probs:", probs, probs.shape), (self._should_log > 0))
        print2(("probs grads:", probs.requires_grad, ), (self._should_log > 0))

        print2(("sigmas:", sigmas, sigmas.shape), (self._should_log > 0))
        print2(("sigmas grads:", sigmas.requires_grad, ), (self._should_log > 0))

        _, _, dummy_hashed = self._calc_dummies(hashed=hashed)

        _, collisions = self.calc_hash_collisions(dummy_grids=dummy_grids, dummy_hashed=dummy_hashed)
        del dummy_hashed

        unique_probs, unique_hashed, unique_sigmas = self._calc_uniques(probs, hashed, sigmas, og_indices)
        del og_indices

        hists = None
        if should_calc_hists:
            hists = self._hist_collisions(dummy_grids, unique_hashed, min_possible_collisions, should_show=False)
        del dummy_grids, sigmas

        return hashed, unique_hashed, unique_probs, unique_sigmas, collisions, min_possible_collisions, hists

    def _multiresolution_hash(self, x: torch.Tensor, min_possible_collision: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Multiresolution hash function.
        Uses the HashFunction only for levels with more than 0 min_possible_collisions;
        the other levels are indexed directly as ``(level + 1) * coord_0 + coord_1``.

        Parameters
        ----------
        x : torch.Tensor
            Grid coordinates, indexed as (batch, pixels, levels, verts, xyz).
        min_possible_collision : torch.Tensor
            Minimum possible collisions, one entry per level.
        
        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Probabilities, hashed tensor and sigmas (in this order), each of shape
            (batch, pixels, num_levels, verts, 1). Probabilities are 1 and sigmas
            are 0 for the levels that do not go through the HashFunction.
        """

        # Per-level boolean masks.
        collision_free = min_possible_collision <= 0
        colliding = min_possible_collision > 0

        # The vertex axis is taken from the input: `input_dim**2` only matches the
        # number of cell vertices when input_dim == 2.
        out_shape = (x.shape[0], x.shape[1], self._num_levels, x.shape[3], 1)
        hashed = torch.zeros(out_shape)
        probs = torch.ones(out_shape)
        sigmas = torch.zeros(out_shape)

        x_non_collisions_levels = x[:, :, collision_free, :, :]

        if self._should_normalize_levels:
            # Undo the normalization to get integer grid coordinates back.
            x_non_collisions_levels = x_non_collisions_levels * self._levels.unsqueeze(-1)[:, :, collision_free, :, :]

        # TODO: check, input is ij so maybe (i + j * level)?
        hashed_non_collisions_levels = (
            (self._levels[:, :, collision_free, :] + 1) * x_non_collisions_levels[:, :, :, :, 0]
            + x_non_collisions_levels[:, :, :, :, 1]
        )

        print2(("Hashed indices of levels without collisions:", hashed_non_collisions_levels, hashed_non_collisions_levels.shape), (self._should_log > 5))
        hashed[:, :, collision_free, :, :] = hashed_non_collisions_levels.unsqueeze(-1)

        probs_collisions_levels, hashed_collisions_levels, sigmas_collisions_levels = self.HashFunction(x[:, :, colliding, :, :])
        probs[:, :, colliding, :, :] = probs_collisions_levels
        hashed[:, :, colliding, :, :] = hashed_collisions_levels
        # The hash function returns no sigma (None) when it has a single output;
        # the zeros are kept in that case instead of failing on the assignment.
        if sigmas_collisions_levels is not None:
            sigmas[:, :, colliding, :, :] = sigmas_collisions_levels

        return probs, hashed, sigmas

    @torch.no_grad()
    def _scale_to_grid(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scale coordinates to grid.

        Every coordinate is multiplied by the resolution of every level; the
        vertices of the cell it falls in are the floored result plus the 0/1
        offsets of `_voxels_helper_hypercube`.

        Parameters
        ----------
        x : torch.Tensor
            Original coordinates.
        
        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Scaled coordinates and grid coordinates. Grid coordinates are indexed
            as (batch, pixels, levels, verts, xyz) and are divided by the level
            resolution when `should_normalize_levels` is set.
        """

        scaled_coords: torch.Tensor = x.float() * self._levels.float()
        print2(("scaled_coords:", scaled_coords, scaled_coords.shape), (self._should_log > 1))
        print2(("scaled_coords grads:", scaled_coords.requires_grad, ), (self._should_log > 1))

        grid_coords: torch.Tensor = rearrange(
            torch.add(
                torch.floor(scaled_coords),
                self._voxels_helper_hypercube
            ),
            "batch pixels xyz levels verts -> batch pixels levels verts xyz"
        )
        
        if self._should_normalize_levels:
            grid_coords = grid_coords / self._levels.unsqueeze(-1)

        print2(("grid_coords:", grid_coords, grid_coords.shape), (self._should_log > 1))
        print2(("grid_coords grads:", grid_coords.requires_grad, ), (self._should_log > 1))

        # `_` was returned here before, which is not defined in this scope
        # (it only resolves inside an interactive IPython session).
        return scaled_coords, grid_coords

    @torch.no_grad()
    def _calc_dummies(
        self,
        grid_coords: torch.Tensor | None = None,
        hashed: torch.Tensor | None = None,
    ) -> Tuple[List[torch.Tensor] | None, List[torch.Tensor] | None, List[torch.Tensor] | None]:
        """
        Calculate dummies, i.e. the per-level unique rows of the given tensors.
        Only the first element of the batch is considered.
        If grid_coords is None the dummy_grids won't be calculated.
        If hashed is None the dummy_hashed won't be calculated.

        Parameters
        ----------
        grid_coords : torch.Tensor or None, optional (default is None)
            Grid coordinates.
        hashed : torch.Tensor or None, optional (default is None)
            Hashed grid coordinates.
        
        Returns
        -------
        Tuple[List[torch.Tensor] or None, List[torch.Tensor] or None, List[torch.Tensor] or None]
            Dummy grids, original unique indices and dummy hashed.
            `og_indices[l][i]` is the position, among the flattened
            (pixels verts) rows of level l, of the first occurrence of the
            i-th unique vertex of that level.
        """

        dummy_grids = None
        og_indices = None
        dummy_hashed = None

        if grid_coords is not None:
            dummy_grids: List[Tuple[torch.Tensor]] = [
                torch.unique(rearrange(grid_coords[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"), dim=0, return_inverse=True) # first dimension is 0 because images should have all same size
                for l in range(self._num_levels)
            ]

            dummy_grids, dummy_grids_inverse_indices = zip(*dummy_grids)
            print2(("dummy_grids:", dummy_grids, [dummy_grids[l].shape for l in range(self._num_levels)]), (self._should_log > 1))
            print2(("dummy_grids grads:", [(dummy_grids[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 3))

            # The inverse indices contain every value 0..U-1, so the first-occurrence
            # indices returned by np.unique are aligned with the unique vertices.
            # This replaces one full scan of the inverse indices per unique vertex.
            og_indices = [
                torch.tensor(np.unique(inverse_indices.detach().cpu().numpy(), return_index=True)[1])
                for inverse_indices in dummy_grids_inverse_indices
            ]
            print2(("og_indices:", [(og_indices[l], og_indices[l].shape) for l in range(self._num_levels)]), (self._should_log > 1))
            print2(("og_indices grads:", [(og_indices[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 1))

        if hashed is not None:
            dummy_hashed: List[torch.Tensor] = [
                torch.unique(rearrange(hashed[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"), dim=0, return_inverse=False) # first dimension is 0 because images should have all same size
                for l in range(self._num_levels)
            ]
            print2(("dummy_hashed:", dummy_hashed, [dummy_hashed[l].shape for l in range(self._num_levels)]), (self._should_log > 3))
            print2(("dummy_hashed grads:", [(dummy_hashed[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 3))

        return dummy_grids, og_indices, dummy_hashed
    
    @torch.no_grad()
    def calc_hash_collisions(
        self, 
        dummy_grids: List[torch.Tensor] , 
        dummy_hashed: List[torch.Tensor] | None = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Count hash collisions per level.
        Without dummy_hashed only the minimum possible collisions are computed:
        the number of unique vertices exceeding the hash table size, never below 0.
        With dummy_hashed only the actual collisions are computed: the number of
        unique vertices minus the number of unique hashes.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Unique grid coordinates, one tensor per level.
        dummy_hashed : List[torch.Tensor] or None, optional (default is None)
            Hashed grid coordinates, one tensor per level.
        
        Returns
        -------
        Tuple[torch.Tensor or None, torch.Tensor or None]
            Minimum possible collisions and actual collisions at each level;
            the one that is not computed is None.
        """

        min_possible_collisions = None
        collisions = None

        if dummy_hashed is not None:
            per_level_collisions = []
            for level in range(self._num_levels):
                n_vertices = dummy_grids[level].shape[0]
                n_hashes = torch.unique(dummy_hashed[level], dim=0).shape[0]
                per_level_collisions.append(torch.tensor(float(n_vertices - n_hashes)))
            collisions: torch.Tensor = torch.stack(per_level_collisions)

            print2(("collisions:", collisions, collisions.shape, collisions.dtype), (self._should_log > 2))
            print2(("collisions grads:", collisions.requires_grad, ), (self._should_log > 2))
        else:
            overflow = torch.stack([
                torch.tensor(dummy_grids[level].shape[0] - self._hash_table_size)
                for level in range(self._num_levels)
            ])
            # Levels that fit in the table cannot be forced to collide.
            min_possible_collisions: torch.Tensor = overflow.clamp(min=0)

            print2(("min_possible_collisions:", min_possible_collisions, min_possible_collisions.shape), (self._should_log > 2))
            print2(("min_possible_collisions grads:", min_possible_collisions.requires_grad), (self._should_log > 2))

        return min_possible_collisions, collisions

    @torch.no_grad()
    def _calc_uniques(
        self,
        probs: torch.Tensor,
        hashed: torch.Tensor,
        sigmas: torch.Tensor,
        og_indices: List[torch.Tensor]
    ) -> Tuple[torch.Tensor]:
        """
        Calculate uniques: for every level keep only the rows that belong to the
        first occurrence of each unique grid vertex.
        Only the first element of the batch is considered.

        Parameters
        ----------
        probs : torch.Tensor
            Probabilities.
        hashed : torch.Tensor
            Hashed grid coordinates.
        sigmas : torch.Tensor
            Sigmas of hashed grid coordinates.
        og_indices : List[torch.Tensor]
            Original unique indices, one tensor per level.

        Returns
        -------
        Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]
            Unique probabilities (softmax over the unique vertices of each
            level), unique hashed, and unique_sigmas; one tensor per level.
        """

        def select_unique(values: torch.Tensor, level: int) -> torch.Tensor:
            # Flatten (pixels, verts) of one level and keep the unique vertices' rows.
            flat = rearrange(values[0, :, level, :, :], "pixels verts xyz -> (pixels verts) xyz")
            return flat[og_indices[level]]

        # The former batched log lines indexed these per-level lists as `[b][l]`;
        # their arguments were evaluated even with logging off and could raise
        # IndexError, so only the per-level log lines are kept.
        unique_probs: List[torch.Tensor] = [
            F.softmax(select_unique(probs, l), dim=0)
            for l in range(self._num_levels)
        ]
        print2(("unique_probs:", [(unique_probs[l], unique_probs[l].shape) for l in range(self._num_levels)]), (self._should_log > 0))

        unique_hashed: List[torch.Tensor] = [
            select_unique(hashed, l)
            for l in range(self._num_levels)
        ]
        print2(("unique_hashed:", [(unique_hashed[l], unique_hashed[l].shape) for l in range(self._num_levels)]), (self._should_log > 0))

        unique_sigmas: List[torch.Tensor] = [
            select_unique(sigmas, l)
            for l in range(self._num_levels)
        ]
        print2(("unique_sigmas:", [(unique_sigmas[l], unique_sigmas[l].shape) for l in range(self._num_levels)]), (self._should_log > 0))

        return unique_probs, unique_hashed, unique_sigmas

    @torch.no_grad()
    def _hist_collisions(
        self,
        dummy_grids: List[torch.Tensor],
        dummy_hashed: List[torch.Tensor],
        min_possible_collisions: torch.Tensor,
        should_show: bool = False
    ) -> List[plt.Figure]:
        """
        Build one histogram of hashed indices per level.

        A level is skipped (and `None` is stored in its place) when its minimum
        possible collisions is <= 0, unless fast hashing is enabled.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Grid coordinates. Not used by this method.
        dummy_hashed : List[torch.Tensor]
            Hashed grid coordinates, one entry per level.
        min_possible_collisions : torch.Tensor
            Minimum possible collisions for each level.
        should_show : bool, optional (default is False)
            Whether to show figure.

        Returns
        -------
        List[plt.Figure]
            Histograms of collisions, one per level (`None` for skipped levels).
        """

        figs = []

        for l, min_collisions in enumerate(min_possible_collisions):

            if (min_collisions <= 0) and not self._should_fast_hash:
                figs.append(None)
                continue

            indices = dummy_hashed[l].detach().cpu().numpy()

            fig, ax = plt.subplots(figsize=(15, 5))
            # One bin per hash table slot.
            ax.hist(
                indices,
                bins=self._hash_table_size,
                range=(0, self._hash_table_size),
                edgecolor='grey',
                linewidth=0.5
            )

            ax.set_xlim(-1, self._hash_table_size)
            ax.xaxis.set_ticks(np.arange(0, self._hash_table_size, 10))

            # Roughly ten ticks on the y axis; the step is never below 1.
            _, end = ax.get_ylim()
            step = int(end * 0.1)
            ax.yaxis.set_ticks(np.arange(0, end, step if step > 0 else 1))

            # Use the Axes API so labels always land on this figure, whatever
            # pyplot currently considers the "current" axes.
            ax.set_title(f"Level {l} ({int(self._levels[0, 0, l, 0].item())})")
            ax.set_xlabel("Hashed indices")
            ax.set_ylabel("Counts")

            figs.append(fig)

            if should_show:
                plt.show()

            # Close exactly this figure; a bare plt.close() closes whichever
            # figure is current, which may be a different one after plt.show().
            plt.close(fig)

        return figs

    @torch.no_grad()
    def _fast_hash(self, grid: torch.Tensor) -> torch.Tensor:
        """
        Hash grid vertices with the function proposed by NVIDIA.

        Each coordinate is cast to integer, multiplied by its prime number and
        XOR-ed into an accumulator; the accumulator is then reduced modulo the
        hash table size.

        Parameters
        ----------
        grid : torch.Tensor
            Tensor of shape (batch, pixels, levels, 2^input_dim, input_dim)
            with the vertices of the hyper cube for each level.

        Returns
        -------
        torch.Tensor
            Tensor of shape (batch, pixels, levels, 2^input_dim, 1) with the
            hash table indices of all vertices, converted with `.to(float)`.
        """
        batch_size, num_pixels = grid.shape[0], grid.shape[1]

        # XOR accumulator; starting from zeros leaves the first term unchanged.
        acc = torch.zeros(
            (batch_size, num_pixels, self._num_levels, 2**self._input_dim),
            dtype=torch.int64,
            device=device
        )

        for dim in range(self._input_dim):
            scaled = grid[:, :, :, :, dim].to(int) * self._prime_numbers[dim]
            acc = torch.bitwise_xor(scaled, acc)

        indices = torch.remainder(acc, self._hash_table_size)

        return indices.unsqueeze(-1).to(float)

## Train & Test Loops

### Training Loop

In [None]:
def train_loop(
    x: torch.Tensor,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    l_collisions: float,
    l_kl_loss: float,
    l_sigma_loss: float,
    l_zero_bins: float,
    l_l2_reg: float = 0,
    should_calc_hists: bool = False,
    should_kl_hist: bool = False,
    should_log: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
    """
    Run one optimization step: forward pass, loss, backward pass, optimizer step.

    The total loss is the lambda-weighted sum of the summed collisions, KL,
    sigma and zero-bins terms returned by `loss_fn`, plus the L2 term (sum of
    the norms of the model parameters, excluding `_prime_numbers`).

    Parameters
    ----------
    x : torch.Tensor
        Train input data.
    model : nn.Module
        Model to train.
    optimizer : torch.optim.Optimizer
        Optimizer.
    loss_fn : nn.Module
        Loss function.
    l_collisions : float
        Collisions loss lambda.
    l_kl_loss : float
        KL loss lambda.
    l_sigma_loss : float
        Sigma loss lambda.
    l_zero_bins : float
        Zero bins loss lambda.
    l_l2_reg : float, optional (default is 0)
        L2 regularization lambda.
    should_calc_hists : bool, optional (default is False)
        Whether to calculate histograms or not.
    should_kl_hist : bool, optional (default is False)
        Whether to calculate KL with histograms or not. When True the model
        output (first batch element, rearranged per level) is passed to the
        loss instead of the unique probabilities.
    should_log : int, optional (default is 0)
        - 0: No logging.
        - > 0: Log

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]
        Output, unique probabilities, collisions, minimum possible collisions, histograms, collisions losses, KL losses, sigma loss, zero bins per level, l2_reg loss, loss.
    """
    print2(("Train loop", ), should_log > 0, bcolors.WARNING)

    model.train()
    optimizer.zero_grad()

    # The second output (unique hashed) is not needed here.
    out, _, unique_probs, unique_sigmas, collisions, min_possible_collisions, hists = model(
        x,
        should_calc_hists=should_calc_hists
    )

    collisions_losses, kl_div_losses, sigma_losses, zero_bins_per_level = loss_fn(
        collisions,
        min_possible_collisions,
        unique_probs if not should_kl_hist else rearrange(out, "batch pixels levels verts xyz -> batch levels pixels (verts xyz)")[0],
        unique_sigmas
    )

    # L2 regularization; `_prime_numbers` is skipped, as in test_loop.
    l2_reg_loss = torch.tensor(0.0)
    for name, param in model.named_parameters():
        if name != "_prime_numbers":
            l2_reg_loss += torch.norm(param)

    collisions_loss = torch.sum(collisions_losses)
    kl_loss = kl_div_losses.sum()
    sigma_loss = sigma_losses.sum()
    zero_bins_loss = zero_bins_per_level.sum()

    loss = (
        (l_collisions * collisions_loss)
        + (l_kl_loss * kl_loss)
        + (l_sigma_loss * sigma_loss)
        + (l_zero_bins * zero_bins_loss)
        + (l_l2_reg * l2_reg_loss)
    )

    print2(("Loss grads:", loss.requires_grad, ), should_log > 0, bcolors.HEADER)

    loss.backward()
    optimizer.step()

    return out, unique_probs, collisions, min_possible_collisions, hists, collisions_losses, kl_div_losses, sigma_losses, zero_bins_per_level, l2_reg_loss.item(), loss.item()


### Testing Loop

In [None]:
def test_loop(
    x: torch.Tensor,
    model: nn.Module,
    loss_fn: nn.Module,
    l_collisions: float,
    l_kl_loss: float,
    l_sigma_loss: float,
    l_zero_bins: float,
    l_l2_reg: float = 0,
    should_calc_hists: bool = False,
    should_kl_hist: bool = False,
    should_log: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]:
    """
    Evaluate the model once: forward pass and loss, without any parameter update.

    The total loss is computed exactly as in the training loop (lambda-weighted
    sum of the summed loss terms plus the L2 term, which skips
    `_prime_numbers`). Everything runs with gradient tracking disabled.

    Parameters
    ----------
    x : torch.Tensor
        Test input data.
    model : nn.Module
        Model to test.
    loss_fn : nn.Module
        Loss function.
    l_collisions : float
        Collisions loss lambda.
    l_kl_loss : float
        KL loss lambda.
    l_sigma_loss : float
        Sigma loss lambda.
    l_zero_bins : float
        Zero bins loss lambda.
    l_l2_reg : float, optional (default is 0)
        L2 regularization lambda.
    should_calc_hists : bool, optional (default is False)
        Whether to calculate histograms or not.
    should_kl_hist : bool, optional (default is False)
        Whether to calculate KL with histograms or not. When True the model
        output (first batch element, rearranged per level) is passed to the
        loss instead of the unique probabilities.
    should_log : int, optional (default is 0)
        - 0: No logging.
        - > 0: Log

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, float, float]
        Output, unique probabilities, collisions, minimum possible collisions, histograms, collisions losses, KL losses, sigma loss, zero bins per level, l2_reg loss, loss.
    """
    print2(("Test loop", ), should_log > 0, bcolors.FAIL)

    model.eval()

    # Nothing is back-propagated here, so do not build an autograd graph.
    with torch.no_grad():
        # The second output (unique hashed) is not needed here.
        out, _, unique_probs, unique_sigmas, collisions, min_possible_collisions, hists = model(
            x,
            should_calc_hists=should_calc_hists
        )

        collisions_losses, kl_div_losses, sigma_losses, zero_bins_per_level = loss_fn(
            collisions,
            min_possible_collisions,
            unique_probs if not should_kl_hist else rearrange(out, "batch pixels levels verts xyz -> batch levels pixels (verts xyz)")[0],
            unique_sigmas
        )

        # L2 regularization; `_prime_numbers` is skipped.
        l2_reg_loss = torch.tensor(0.0)
        for name, param in model.named_parameters():
            if name != "_prime_numbers":
                l2_reg_loss += torch.norm(param)

        collisions_loss = torch.sum(collisions_losses)
        kl_loss = kl_div_losses.sum()
        sigma_loss = sigma_losses.sum()
        zero_bins_loss = zero_bins_per_level.sum()

        loss = (
            (l_collisions * collisions_loss)
            + (l_kl_loss * kl_loss)
            + (l_sigma_loss * sigma_loss)
            + (l_zero_bins * zero_bins_loss)
            + (l_l2_reg * l2_reg_loss)
        )

    print2(("Loss grads:", loss.requires_grad, ), should_log > 0, bcolors.HEADER)

    return out, unique_probs, collisions, min_possible_collisions, hists, collisions_losses, kl_div_losses, sigma_losses, zero_bins_per_level, l2_reg_loss.item(), loss.item()

## Metrics, Loss & Optimizer

### Loss

In [None]:
class Loss(nn.Module):
    """
    Loss module producing per-level collision, KL-divergence, sigma and zero-bin terms.

    `forward` returns the individual terms; weighting and summing them is left
    to the caller.
    """

    def __init__(
        self,
        delta: float = 1,
        hash_table_size: int = 2**14,
        should_kl_hist: bool = False,
        should_diff_hist_optimized: bool = False,
        should_use_all_levels: bool = False,
        should_log: int = 0,
    ) -> None:
        """
        Loss module.

        Parameters
        ----------
        delta : float, optional (default is 1)
            Delta parameter for collision loss.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        should_kl_hist : bool, optional (default is False)
            Whether to calculate KL with histograms or not.
        should_diff_hist_optimized: bool, optional (default is False)
            Whether to use optimized differentiable histogram or not.
        should_use_all_levels : bool, optional (default is False)
            Whether to use all levels or only the ones with collisions.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log helper functions.

        Returns
        -------
        None
        """

        super().__init__()

        self._delta = delta
        self._hash_table_size = hash_table_size
        self._should_kl_hist = should_kl_hist
        self._should_diff_hist_optimized = should_diff_hist_optimized
        self._should_use_all_levels = should_use_all_levels
        self._should_log = should_log

        self.kl_div_loss = nn.KLDivLoss(reduction="batchmean")
        self.mse_loss = nn.MSELoss(reduction="mean")

    def forward(
        self,
        collisions: torch.Tensor,
        min_possible_collisions: torch.Tensor,
        x: List, # levels pixels (verts xyz)
        sigmas: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Compute the per-level loss terms.

        Parameters
        ----------
        collisions : torch.Tensor
            Collisions at each level.
        min_possible_collisions : torch.Tensor
            Minimum possible collisions at each level.
        x : List or torch.Tensor
            Per-level values the histogram PDF is computed from; iterated
            level by level along its first axis.
        sigmas : iterable of torch.Tensor
            Sigmas, one entry per level.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Collisions losses, KL losses, sigma losses and zero bins per level.
            KL losses and zero bins stay 0 for skipped levels.
        """

        # Normalize by the minimum possible collisions; where that is <= 0
        # fall back to the configured delta.
        delta = min_possible_collisions.clone()
        delta[delta <= 0] = self._delta

        collisions_losses: torch.Tensor = (collisions - min_possible_collisions) / delta
        print2(("Collisions Losses:", collisions_losses), self._should_log > 0)
        print2(("Collisions Losses grads:", collisions_losses.requires_grad), self._should_log > 0)

        zero_bins_per_level = torch.zeros(min_possible_collisions.shape)
        # len() works for both a tensor (first axis) and a list of per-level tensors.
        kl_div_losses = torch.zeros(len(x))

        for l, level in enumerate(x):
            # Levels that cannot collide are skipped unless all levels are requested.
            if not self._should_use_all_levels and min_possible_collisions[l] <= 0:
                continue

            p, zero_bins = self._calc_hist_pdf(level)
            zero_bins_per_level[l] = zero_bins

            kl_div_losses[l] = self._kl_div(l, p)

        print2(("KL Losses:", kl_div_losses), self._should_log > 0, bcolors.FAIL)

        # MSE of every sigma tensor against zero.
        sigma_losses = torch.stack([
            self.mse_loss(sigma, torch.zeros_like(sigma))
            for sigma in sigmas
        ])

        return collisions_losses, kl_div_losses, sigma_losses, zero_bins_per_level

    def _calc_hist_pdf(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate histogram PDF.

        Parameters
        ----------
        x : torch.Tensor
            Indices to which calculate histogram PDF.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Histogram PDF (histogram divided by the number of elements of `x`,
            with exact zeros replaced by 1e-10) and count of zero bins.
        """

        print2((f"x: {x}, shape: {x.shape}, requires_grad: {x.requires_grad}", ), self._should_log > 1)

        if self._should_log > 1:
            # Non-differentiable reference histogram, only needed for the debug output.
            hist_p_nondiff = torch.histc(x, bins=self._hash_table_size, min=0, max=self._hash_table_size) # ? maybe max=(self._hash_table_size - 1)
            print2((f"hist_p_nondiff: {hist_p_nondiff}, shape: {hist_p_nondiff.shape}, sum: {torch.sum(hist_p_nondiff)}, requires_grad: {hist_p_nondiff.requires_grad}", ), self._should_log > 1)

        if self._should_diff_hist_optimized:
            hist_p = torch.sum(histogram(x, bins=torch.linspace(0, self._hash_table_size, self._hash_table_size), bandwidth=torch.tensor(0.5)), dim=0)
        else:
            hist_p = differentiable_histogram(x, bins=self._hash_table_size, min=0, max=self._hash_table_size).squeeze(0).squeeze(0)
        print2((f"hist_p_diff: {hist_p.long()}, shape: {hist_p.shape}, sum: {torch.sum(hist_p)}, requires_grad: {hist_p.requires_grad}", ), self._should_log > 1)

        # assert torch.allclose(hist_p_nondiff, hist_p, atol=1), f"hist_p_nondiff != hist_p: {(hist_p_nondiff - hist_p)}"

        p = hist_p / torch.prod(torch.tensor(x.shape))
        print2((f"p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), self._should_log > 1)

        zero_bins = torch.count_nonzero(p == 0)

        # Avoid log(0) in the KL divergence.
        p[p == 0] = 1e-10
        print2((f"after p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), self._should_log > 1)

        return p, zero_bins

    def _kl_div(
        self,
        level: int,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate KL divergence loss.

        Parameters
        ----------
        level : int
            Level.
        p : torch.Tensor
            Histogram PDF of the level; its logarithm is used as the input of
            the KL divergence.

        Returns
        -------
        torch.Tensor
            KL divergence loss.
        """

        print2((f"p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), self._should_log > 1)

        if self._should_kl_hist:
            # Target: uniform distribution over the hash table.
            q = torch.ones(self._hash_table_size) / self._hash_table_size
        else:
            # NOTE(review): q has `level` entries here (empty for level 0) while p
            # presumably has one entry per histogram bin; confirm this branch is
            # still intended.
            # q = torch.ones(level) / level
            q = torch.arange(level, dtype=torch.float32)
            print2((f"before q: {q}, shape: {q.shape}, requires_grad: {q.requires_grad}", ), self._should_log > 1)

            q = torch.softmax(q, dim=-1)

        # print2((f"after q: {q}, shape: {q.shape}, requires_grad: {q.requires_grad}", ), self._should_log > 1)

        return self.kl_div_loss(p.log(), q)


### Optimizer

In [None]:
def get_optimizer(
    net: torch.nn.Module,
    hash_lr: float,
    hash_weight_decay: float,
    betas: tuple = (0.9, 0.99),
    eps: float = 1e-15
):
    """
    Build the Adam optimizer for the hash function of `net`.

    Only the parameters of `net.HashFunction` are optimized; they form a
    single parameter group with its own learning rate and weight decay.

    Parameters
    ----------
    net : torch.nn.Module
        Network exposing a `HashFunction` sub-module.
    hash_lr : float
        Learning rate of the hash function parameters.
    hash_weight_decay : float
        Weight decay of the hash function parameters.
    betas : tuple, optional (default is (0.9, 0.99))
        Adam betas.
    eps : float, optional (default is 1e-15)
        Adam epsilon.

    Returns
    -------
    torch.optim.Adam
        The configured optimizer.
    """
    hash_group = {
        "params": net.HashFunction.parameters(),
        "lr": hash_lr,
        "weight_decay": hash_weight_decay,
    }

    return torch.optim.Adam([hash_group], betas=betas, eps=eps)

## Early stopper

In [None]:
class EarlyStopper:
    """
    Track a loss over successive calls and raise `early_stop` after too many bad steps.

    A call counts as a bad step when the loss either
    - improves on the reference loss by less than `min_delta` (stall), or
    - is worse than the reference loss by more than `min_delta` (growing).

    Any other outcome (an improvement of at least `min_delta`, an unchanged
    loss, or a worsening of at most `min_delta`) records the loss as the new
    reference and either resets the counter (`should_reset=True`) or
    decrements it by one, never below zero (`should_reset=False`).

    Parameters
    ----------
    tolerance : int, optional (default is 5)
        Number of accumulated bad steps at which `early_stop` becomes True.
    min_delta : float, optional (default is 0)
        Threshold on the absolute difference to the reference loss.
    should_reset : bool, optional (default is True)
        Whether a good step resets the counter to zero or only decrements it.
    """

    def __init__(self, tolerance: int = 5, min_delta: float = 0, should_reset: bool = True):
        self.tolerance: int = tolerance
        self.min_delta: float = min_delta
        self.best_loss: float = np.inf
        self.counter: int = 0
        self.early_stop: bool = False
        self._should_reset: bool = should_reset

    def __call__(self, loss):
        """Update the counter with a new loss value and set `early_stop` if needed."""
        gap = abs(self.best_loss - loss)
        is_stalling = gap < self.min_delta and loss < self.best_loss
        is_growing = gap > self.min_delta and loss > self.best_loss

        if is_stalling or is_growing:
            self.counter += 1
        else:
            if self._should_reset:
                self.counter = 0
            else:
                self.counter = max(self.counter - 1, 0)
            # The reference loss must be recorded in both modes; otherwise it
            # stays at infinity and no later step can ever count as bad.
            self.best_loss = loss

        if self.counter >= self.tolerance:
            self.early_stop = True

## Initialization

### Arguments

In [None]:
# Run configuration: dataset split, wandb identifiers, run name and hyperparameters.
root_dir = "./images"
test_size = 0.2

# Split the "silhouette" image files found in root_dir into train/test file
# names. The split seed is fixed (65535) rather than tied to random_seed.
train_images, test_images = train_test_split(
    [
        file for file in os.listdir(root_dir) if ("silhouette" in file) and (file.endswith(".jpg") or file.endswith(".png") or file.endswith(".jpeg"))
    ], 
    test_size=test_size, 
    random_state=65535#random_seed
)

wandb_entity = "fedemonti00"
wandb_project = "project_course"
# wandb_name = "hash_function_training"

# Run name: wandb_name when it is defined (it is commented out above),
# otherwise a Europe/Rome timestamp.
try:
    time = wandb_name
except NameError:
    time = (datetime.now(ZoneInfo("Europe/Rome"))).strftime("%Y%m%d%H%M%S")
print("RUN:", time)

# Presumably the number of epochs between histogram computations; confirm in the run loop.
histogram_rate = 10

scheduler = None # None, StepLR or CosineAnnealingLR or CosineAnnealingWarmRestarts

hyperparameters = {
    "hash_table_size": 2**8,
    "n_min": 8,
    "n_max": 32,
    "num_levels": 4,
    "output_dim": 2,
    "hash_function_hidden_layers_widths": [8, 32, 8],
    "hash_lr": 1e-3,
    "hash_weight_decay": 1e-6,
    "sigma_scale": 1,
    "kl_hist_loss": True,
    "differentiable_histogram_optimized": False,
    "should_use_all_levels": False, # !! should be true if NOT using kl_hist_loss !!
    "should_random_permute_input": False,
    "should_normalize_levels": False,
    "lambdas_decay": -1, # -1 to disable
    "scheduler": scheduler,
    # None without a scheduler, 0.9 for StepLR, 1e-4 for the other schedulers.
    "scheduler_gamma": None if not scheduler else (0.9 if scheduler == "StepLR" else 1e-4),
    "l_collisions": 0,
    "l_kl_loss": 1e2,
    "l_sigma_loss": 0,
    "l_zero_bins": 0, # 1e-1,
    "l_l2_reg": 0,
    # "epochs": 2000,
    # "epochs": 999,
    "epochs": 1,
    "random_seed": random_seed,
}

# A quarter of the epochs; note this evaluates to 0 when epochs < 4.
early_stopper_tolerance = hyperparameters["epochs"] // 4
early_stopper_min_delta = 1e-6 

# should_log = False
should_log = True
# wandb logging is only enabled for runs longer than 999 epochs.
should_wandb = True if hyperparameters["epochs"] > 999 else False

### Wandb

#### Init

In [None]:
# Start the wandb run only when wandb logging is enabled; the run is named
# after `time` and configured with the `hyperparameters` dictionary.
if should_wandb:
    # start a new wandb run to track this script
    wandb.init(
        entity = wandb_entity,
        # set the wandb project where this run will be logged
        project = wandb_project,

        name = time,

        # track hyperparameters and run metadata
        config = hyperparameters,
        # config = {
        #     "id_grid_search_params":        id_param,
        #     "grid_search_params":           params,
        #     "random_seed":                  random_seed,
        #     "HPD_learning_rate":            HPD_lr,
        #     "encoding_learning_rate":       encoding_lr,
        #     "MLP_learning_rate":            MLP_lr,
        #     "encoding_weight_decay":        encoding_weight_decay,
        #     "HPD_weight_decay":             HPD_weight_decay,
        #     "MLP_weight_decay":             MLP_weight_decay,
        #     "batch_size%":                  batch_size,
        #     "shuffled_pixels":              should_shuffle_pixels,
        #     "normalized_data":              True if not should_batchnorm_data else "BatchNorm1d",
        #     "architecture":                 "GeneralNeuralGaugeFields",
        #     "dataset":                      image_name,
        #     "epochs":                       epochs,
        #     "color":                        'RGB' if not should_bw else 'BW',
        #     "hash_table_size":              hash_table_size,
        #     "num_levels":                   num_levels,
        #     "n_min":                        n_min,
        #     "n_max":                        n_max,
        #     "MLP_hidden_layers_widths":     str(MLP_hidden_layers_widths),
        #     "HPD_hidden_layers_widths":     str(HPD_hidden_layers_widths),
        #     "HPD_out_features":             HPD_out_features,
        #     "feature_dim":                  feature_dim,
        #     "topk_k":                       topk_k,
        #     "loss_type":                    "JS+KLDiv" if should_sum_js_kl_div else ("KLDiv" if not should_js_div else "JSDiv"),
        #     "loss_lambda_MSE":              l_mse,
        #     "loss_lambda_JS_KL":            l_js_kl,
        #     "loss_lambda_collisions":       l_collisions,
        #     "loss_gamma":                   loss_gamma,
        #     "loss_epsilon":                 loss_epsilon,
        #     "inplace_scatter":              should_inplace_scatter,
        #     "MLP_activations":              "LeakyReLU" if should_leaky_relu else "ReLU",
        #     "collisions_loss_probs":        "topk_only" if should_keep_topk_only else "hash_table_size",
        #     "avg_topk_features":            "softmax_avg" if should_softmax_topk_features else ("weighted_avg" if should_softmax_topk_features != None else None),
        #     "hash_type":                    "HPD" if not should_use_hash_function else "hash_function"
        # }

        # Do not upload the source code with the run.
        save_code = False,
    )

#### Logger

In [None]:
def _wandb_split_entries(
    prefix: str,
    results: Tuple,
    e: int,
    should_log_hists: bool,
) -> dict:
    """Build the wandb entries of one split (`prefix`) from an 11-item train/test loop result tuple."""
    (
        _, _, collisions, min_possible_collisions, hists,
        collisions_losses, kl_losses, sigma_losses, zero_bins,
        l2_reg_loss, loss,
    ) = results

    entries = {
        f"{prefix}/loss": loss,
        f"{prefix}/l2_reg_loss": l2_reg_loss,
    }

    per_level = {
        "collisions": collisions,
        "min_possible_collisions": min_possible_collisions,
        "collisions_loss": collisions_losses,
        "kl_loss": kl_losses,
        "sigma_loss": sigma_losses,
        "zero_bins": zero_bins,
    }

    for l in range(hyperparameters["num_levels"]):
        for key, values in per_level.items():
            entries[f"{prefix}/{key}_level_{l}"] = values[l].item()

        # Histograms are optional and may be None for individual levels.
        if should_log_hists and hists[l] is not None:
            entries[f"{prefix}/hist_counts_level_{l}"] = wandb.Image(
                hists[l],
                caption=f"Hashed indices counts at level {l} at epoch {e}"
            )

    return entries


def wandb_log(
    e: int,
    lr: float,
    l_zero_bins: float,
    l_kl_loss: float,

    train: Tuple = None,
    test: Tuple = None,
    fast_hash: Tuple = None,

    should_log_hists: bool = False
) -> None:
    """
    Log to wandb.

    When `fast_hash` is given, only its results are logged under the
    `fast_hash/` prefix and `train`/`test` are ignored. Otherwise the `train`
    and `test` results are logged under `train/` and `test/`, together with
    the lambdas (when lambdas decay is enabled) and the learning rate (when a
    scheduler is configured).

    Every results tuple has the 11 items returned by the train/test loops:
    output, unique probabilities, collisions, minimum possible collisions,
    histograms, collisions losses, KL losses, sigma losses, zero bins per
    level, l2_reg loss, loss. The first two items are not logged.

    Parameters
    ----------
    e : int
        Epoch.
    lr : float
        Learning rate modified by the scheduler.
    l_zero_bins : float
        Zero bins loss lambda.
    l_kl_loss : float
        KL loss lambda.
    train : Tuple, optional (default is None)
        Train results. Required when `fast_hash` is None.
    test : Tuple, optional (default is None)
        Test results. Required when `fast_hash` is None.
    fast_hash : Tuple, optional (default is None)
        Fast hash results.
    should_log_hists : bool, optional (default is False)
        Whether to log histograms or not.

    Returns
    -------
    None

    Raises
    ------
    ValueError
        If `fast_hash` is None and `train` or `test` is missing.
    """

    if fast_hash is not None:
        log = _wandb_split_entries("fast_hash", fast_hash, e, should_log_hists)
    else:
        if train is None or test is None:
            raise ValueError("wandb_log needs either `fast_hash` or both `train` and `test`.")

        log = {
            **_wandb_split_entries("train", train, e, should_log_hists),
            **_wandb_split_entries("test", test, e, should_log_hists),
        }

        if hyperparameters["lambdas_decay"] > -1:
            log["l_zero_bins"] = l_zero_bins
            log["l_kl_loss"] = l_kl_loss

        if hyperparameters["scheduler"] is not None:
            log["lr"] = lr

    wandb.log(log)

### Load Datasets

In [None]:
# Training dataset built from the train split; root_dir is split on "/" into
# the dataset root and the images directory name.
train_dataset = ImagesDataset(
    root=root_dir.split("/")[0],
    dir_name=root_dir.split("/")[1],
    images_names=train_images,
    should_random_permute_input=hyperparameters["should_random_permute_input"]
)
# The last dataset item is used as the training input.
x, y, h, w, reordered_indices, names = train_dataset[-1]

# Evaluation dataset: a single fixed image is used instead of the test split.
test_dataset = ImagesDataset(
    root=root_dir.split("/")[0],
    dir_name=root_dir.split("/")[1],
    images_names=["strawberry_small.jpg"],#test_images
    # should_random_permute_input=hyperparameters["should_random_permute_input"]
)
eval_x, eval_y, eval_h, eval_w, _, eval_names = test_dataset[-1]

# Input dimensionality is read from the third-to-last axis of x
# (the layout of x is defined by ImagesDataset; confirm there).
input_dim = x.shape[-3]


### Models initialization

In [None]:
# Learned hash function shared by the trainable multiresolution model.
# NOTE: `0 if should_log else 0` is 0 in both branches, so logging inside the
# modules below is off regardless of should_log.
hashFunction = HashFunction(
    hidden_layers_widths=hyperparameters["hash_function_hidden_layers_widths"],
    input_dim=input_dim,
    output_dim=hyperparameters["output_dim"],
    hash_table_size=hyperparameters["hash_table_size"],
    sigma_scale=hyperparameters["sigma_scale"],
    should_log=0 if should_log else 0
)

# Trainable model: multiresolution grid using the learned hash function.
multires = Multiresolution(
    n_min=hyperparameters["n_min"],
    n_max=hyperparameters["n_max"],
    num_levels=hyperparameters["num_levels"],
    HashFunction=hashFunction,
    hash_table_size=hyperparameters["hash_table_size"],
    input_dim=input_dim,
    should_use_all_levels=hyperparameters["should_use_all_levels"],
    should_normalize_levels=hyperparameters["should_normalize_levels"],
    should_log=0 if should_log else 0
)

# Reference model: same grid configuration, but no learned hash function,
# fast hashing enabled and all levels used.
multires_fast_hash = Multiresolution(
    n_min=hyperparameters["n_min"],
    n_max=hyperparameters["n_max"],
    num_levels=hyperparameters["num_levels"],
    HashFunction=None,
    hash_table_size=hyperparameters["hash_table_size"],
    input_dim=input_dim,
    should_use_all_levels=True,
    should_normalize_levels=hyperparameters["should_normalize_levels"],
    should_fast_hash=True,
    should_log=0 if should_log else 0
)

loss_fn = Loss(
    hash_table_size=hyperparameters["hash_table_size"],
    should_kl_hist=hyperparameters["kl_hist_loss"],
    should_diff_hist_optimized=hyperparameters["differentiable_histogram_optimized"],
    should_use_all_levels=hyperparameters["should_use_all_levels"],
    should_log=0 if should_log else 0
)

# Only the hash function parameters of `multires` are optimized.
optimizer = get_optimizer(
    net=multires,
    hash_lr=hyperparameters["hash_lr"],
    hash_weight_decay=hyperparameters["hash_weight_decay"],
)

# Schedulers are only instantiated for multi-epoch runs; otherwise `scheduler`
# keeps the value configured earlier (a name or None).
if hyperparameters["epochs"] > 1:
    if hyperparameters["scheduler"] == "CosineAnnealingLR":
        scheduler = CosineAnnealingLR(optimizer, T_max=hyperparameters["epochs"]//4, eta_min=hyperparameters["scheduler_gamma"])
    elif hyperparameters["scheduler"] == "CosineAnnealingWarmRestarts":
        scheduler = CosineAnnealingWarmRestarts(optimizer, T_0=hyperparameters["epochs"]//4, T_mult=1, eta_min=hyperparameters["scheduler_gamma"])
    elif hyperparameters["scheduler"] == "StepLR":
        scheduler = StepLR(optimizer, step_size=int(0.05 * hyperparameters["epochs"]), gamma=hyperparameters["scheduler_gamma"])

early_stopper = EarlyStopper(
    tolerance=early_stopper_tolerance,
    min_delta=early_stopper_min_delta
)

print(multires)
print(multires_fast_hash)

### Run

In [None]:
# Training driver: evaluates the fast-hash reference model once, then runs the
# train/test loop for `multires`, logging to wandb when `should_wandb` is set.
plt.ioff()
should_calc_hists = False

pbar = tqdm(range(0, hyperparameters["epochs"]))

# `train_loop` / `test_loop` return a tuple. Per the unpacking previously used
# in this cell its layout is:
#   (_, _, collisions, min_possible_collisions, hists, collisions_losses,
#    kl_losses, sigma_losses, zero_bins, l2_reg_loss, loss)
_COLLISIONS_IDX = 2
_LOSS_IDX = -1

# A scheduler is only created when one is requested AND there is more than one
# epoch, so both the lr lookup and `scheduler.step()` must use this same guard.
uses_scheduler = hyperparameters["scheduler"] is not None and hyperparameters["epochs"] > 1

# Reference evaluation of the fast-hash model; histograms are always computed here.
fast_hash = test_loop(
    x=eval_x,
    model=multires_fast_hash,
    loss_fn=loss_fn,
    l_l2_reg=hyperparameters["l_l2_reg"],
    l_collisions=hyperparameters["l_collisions"],
    l_kl_loss=hyperparameters["l_kl_loss"],
    l_sigma_loss=hyperparameters["l_sigma_loss"],
    l_zero_bins=hyperparameters["l_zero_bins"],
    should_calc_hists=True,
    should_kl_hist=hyperparameters["kl_hist_loss"],
    should_log=1 if should_log else 0
)

if should_wandb:
    # Logged at epoch -1 so it precedes the first training epoch.
    wandb_log(
        e=-1,
        lr=None,
        l_zero_bins=hyperparameters["l_zero_bins"],
        l_kl_loss=hyperparameters["l_kl_loss"],
        fast_hash=fast_hash,
        should_log_hists=True
    )

for e in pbar:
    # Calc histograms at: first epoch, last epoch, every histogram_rate epochs,
    # and on the final epoch that runs after the early stopper has triggered.
    should_calc_hists = ((e == hyperparameters["epochs"] - 1) or (e % histogram_rate == 0) or early_stopper.early_stop)

    train = train_loop(
        x=x,
        model=multires,
        optimizer=optimizer,
        loss_fn=loss_fn,
        l_collisions=hyperparameters["l_collisions"],
        l_kl_loss=hyperparameters["l_kl_loss"],
        l_sigma_loss=hyperparameters["l_sigma_loss"],
        l_zero_bins=hyperparameters["l_zero_bins"],
        l_l2_reg=hyperparameters["l_l2_reg"],
        should_calc_hists=should_calc_hists,
        should_kl_hist=hyperparameters["kl_hist_loss"],
        should_log=1 if should_log else 0
    )

    test = test_loop(
        x=eval_x,
        model=multires,
        loss_fn=loss_fn,
        l_l2_reg=hyperparameters["l_l2_reg"],
        l_collisions=hyperparameters["l_collisions"],
        l_kl_loss=hyperparameters["l_kl_loss"],
        l_sigma_loss=hyperparameters["l_sigma_loss"],
        l_zero_bins=hyperparameters["l_zero_bins"],
        should_calc_hists=should_calc_hists,
        should_kl_hist=hyperparameters["kl_hist_loss"],
        should_log=1 if should_log else 0
    )

    train_loss = train[_LOSS_IDX]
    test_loss = test[_LOSS_IDX]
    test_collisions = test[_COLLISIONS_IDX]

    if should_wandb:
        wandb_log(
            e=e,
            lr=scheduler.get_last_lr()[0] if uses_scheduler else None,
            l_zero_bins=hyperparameters["l_zero_bins"],
            l_kl_loss=hyperparameters["l_kl_loss"],
            train=train,
            test=test,
            should_log_hists=should_calc_hists
        )

    # Periodically rescale the loss weights. `lambdas_decay` is the period in
    # epochs; non-positive values disable the decay (a period of 0 would
    # otherwise raise ZeroDivisionError in the modulo below).
    if (hyperparameters["lambdas_decay"] > 0) and e != 0 and e % hyperparameters["lambdas_decay"] == 0:
        hyperparameters["l_zero_bins"] = hyperparameters["l_zero_bins"] * 0.9
        hyperparameters["l_kl_loss"] = hyperparameters["l_kl_loss"] * 1.001

    if uses_scheduler:
        scheduler.step()

    # Abort as soon as the training loss diverges.
    if np.isnan(train_loss):
        break

    pbar.set_description(f"Epoch {e}, Loss: {test_loss}, Collisions: {test_collisions}")

    # The stop flag is checked BEFORE feeding the new loss to the early stopper:
    # a flag raised at the end of epoch e-1 lets epoch e run once more (with
    # histograms, see `should_calc_hists`) and only then ends the loop.
    if early_stopper.early_stop:
        print("!!! Stopping at epoch:", e, "!!!")
        break

    early_stopper(test_loss)


In [None]:
# Close the Weights & Biases run, but only when wandb logging was enabled.
if should_wandb:
    wandb.finish()