# NeRF: Collision handling in Instant Neural Graphics Primitives
#### Federico Montagna (fedemonti00@gmail.com)

# Code:

## Imports

In [1]:
from typing import Tuple, List

import numpy as np
from math import atan2
import cv2
from PIL import Image
import os
import matplotlib
from matplotlib.ticker import MultipleLocator, AutoMinorLocator, FixedLocator
import matplotlib.pyplot as plt
import random
from scipy.stats import norm
from sklearn.model_selection import train_test_split

import collections, functools, operator
from collections import Counter
import itertools

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler

# from torchinfo import summary

from einops import rearrange, reduce, repeat

import traceback
from pprint import pprint

import wandb
from tqdm import tqdm

import pdb

try:
    from zoneinfo import ZoneInfo
except ImportError:
    from backports.zoneinfo import ZoneInfo

from datetime import datetime

# from params import *

# Report CUDA availability and list every visible GPU.
print("Cuda avilable:", torch.cuda.is_available())
if torch.cuda.is_available():
    for i in range(torch.cuda.device_count()):
        print(f"Available device {i}:", torch.cuda.get_device_name(i))

# Use the first GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
if torch.cuda.is_available():
    print(f"Current device {torch.cuda.current_device()}:", torch.cuda.get_device_name(torch.cuda.current_device()))

# Tensors created without an explicit `device=` are allocated on `device` from here on.
torch.set_default_device(device)

# Seed every random number generator used in the notebook with the same value.
random_seed = 2**16 - 1
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.random.manual_seed(random_seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(random_seed)
print("Random seed:", random_seed)

# Tell wandb which notebook file this session belongs to.
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"

# Global matplotlib style for every figure produced below.
plt.style.use("ggplot")

Cuda avilable: True
Available device 0: NVIDIA GeForce RTX 2070
Available device 1: NVIDIA GeForce RTX 2070
Current device 0: NVIDIA GeForce RTX 2070
Random seed: 65535


## Debug Functions

In [2]:
class bcolors:
    """ANSI escape sequences used to colour and format terminal output."""

    # Foreground colours (bright variants).
    HEADER = "\033[95m"
    OKBLUE = "\033[94m"
    OKCYAN = "\033[96m"
    OKGREEN = "\033[92m"
    WARNING = "\033[93m"
    FAIL = "\033[91m"

    # Reset and text attributes.
    ENDC = "\033[0m"
    BOLD = "\033[1m"
    UNDERLINE = "\033[4m"


def print2(texts, log: bool = False, color: bcolors = bcolors.OKCYAN) -> None:
    """
    Conditionally print debug items, preceded by the source line of the caller.

    Parameters
    ----------
    texts : iterable
        Items to print, one per output line.
    log : bool, optional (default is False)
        Nothing is printed unless this is True.
    color : str, optional (default is bcolors.OKCYAN)
        ANSI colour code used for the header and the closing separator.

    Returns
    -------
    None
    """
    if not log:
        return

    # The last stack entry is this function; the one before it is the caller.
    caller = traceback.extract_stack()[-2]
    print(color, "Line: ", caller.line, bcolors.ENDC)

    for item in texts:
        print(item)

    print(color, "-"*20, bcolors.ENDC)


def print_allocated_memory(log: bool = True):
    """
    Print the current and peak CUDA memory allocated by PyTorch, in gigabytes.

    Parameters
    ----------
    log : bool, optional (default is True)
        Nothing is printed unless this is True.

    Returns
    -------
    None
    """
    if not log:
        return

    # The last stack entry is this function; the one before it is the caller.
    caller = traceback.extract_stack()[-2]
    print(bcolors.HEADER, "Line: ", caller.line, bcolors.ENDC)

    bytes_per_gb = 1024 ** 3

    allocated_memory = torch.cuda.memory_allocated() / bytes_per_gb
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    peak_memory = torch.cuda.max_memory_allocated() / bytes_per_gb
    print(f"Peak Allocated Memory: {peak_memory:.2f} GB")

    print(bcolors.OKCYAN, "-"*20, bcolors.ENDC)

## Load wandb apikey

In [3]:
# apikey_path = ".wandb_apikey.txt"
# if os.path.exists(apikey_path):
#     with open(apikey_path, "r") as f:
#         apikey = f.read()
#         !wandb login {apikey} # --relogin

## Dataset

In [4]:
class ImagesDataset(Dataset):
    """Dataset of equally sized images paired with normalised pixel coordinates."""

    def __init__(
        self,
        root: str,
        dir_name: str,
        images_names: List[str],
    ) -> None:
        """
        Parameters
        ----------
        root : str
            Path to root directory.
        dir_name : str
            Name of directory (inside `root`) that holds the images.
        images_names : List[str]
            List of images names.

        Returns
        -------
        None
        """

        self._root: str = root
        self._dir_name: str = dir_name
        self._images_names: List[str] = images_names

        self._images_paths: List[str] = [
            os.path.join(self._root, self._dir_name, image_name)
            for image_name in self._images_names
        ]

    def __getitem__(
        self, idx: "torch.Tensor | int"
    ) -> Tuple[torch.Tensor, torch.Tensor, int, int, List[str]]:
        """
        Get data by indices.

        Parameters
        ----------
        idx : torch.Tensor or int
            Index or indices of the images to load. A single -1 selects every image.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, int, int, List[str]]
            - X: pixel (row, col) coordinates divided by max(w, h), one grid per image,
              with two trailing singleton dimensions.
            - Y: stacked images in "h w rgb" layout, scaled by 1 / 255.
            - h, w: height and width shared by the selected images.
            - the names of all images of the dataset (not only the selected ones).

        Raises
        ------
        ValueError
            If no image is selected or if images have different sizes.
        """

        if torch.is_tensor(idx):
            # Flatten first so that 0-dim tensors also become a list of ints.
            indices = idx.reshape(-1).tolist()
        else:
            indices = [idx]

        # A lone -1 is the sentinel for "all images". Checking it on the list
        # avoids the ambiguous truth value of `tensor == -1` for multi-element tensors.
        if len(indices) == 1 and indices[0] == -1:
            indices = list(range(len(self._images_paths)))

        images: List[torch.Tensor] = [
            rearrange(io.read_image(self._images_paths[_id]), "rgb h w -> h w rgb")
            for _id in indices
        ]

        if not images:
            raise ValueError("No images selected.")

        heights = {image.shape[0] for image in images}
        widths = {image.shape[1] for image in images}

        if len(heights) > 1 or len(widths) > 1:
            raise ValueError("Images have different sizes.")

        h = next(iter(heights))
        w = next(iter(widths))

        # All images share the same size, so the coordinate grid is built only once.
        grid = torch.tensor(
            np.stack(np.meshgrid(range(h), range(w), indexing="ij"), axis=-1).reshape(-1, 2)
        )

        X: torch.Tensor = (
            torch.stack([grid] * len(images)).float() / (max(w, h))  # No (max(w, h) - 1)
        ).unsqueeze(-1).unsqueeze(-1)

        Y: torch.Tensor = torch.stack(images).float() / 255

        return X, Y, h, w, self._images_names

    def __len__(self) -> int:
        """
        Returns
        -------
        int
            Length of dataset.
        """
        return len(self._images_paths)

## Models

### Backward Pass Differentiable Approximation

In [5]:
def BPDA(x, round_function):
    """
    Backward Pass Differentiable Approximation of a rounding function.

    The returned tensor holds the values of `round_function(x)`, but it is a
    clone of `x` in the autograd graph, so gradients reach `x` as if no
    rounding had happened.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    round_function : callable
        Function applied to `x` to produce the forward values.

    Returns
    -------
    torch.Tensor
        Tensor with the rounded values and the autograd history of `x.clone()`.
    """
    surrogate = x.clone()
    # Overwrite only the stored values; the clone's autograd history is untouched.
    surrogate.data = round_function(x).data
    return surrogate

def differentiable_floor(x, round_function=torch.floor):
    """Floor `x` in the forward pass while letting gradients through via `BPDA`."""
    floored = BPDA(x, round_function)
    return floored

def differentiable_round(x, round_function=torch.round):
    """Round `x` in the forward pass while letting gradients through via `BPDA`."""
    rounded = BPDA(x, round_function)
    return rounded

### Gaussian Convolution Backward

In [6]:
class GaussianConvolution(torch.autograd.Function):
    """
    Custom autograd function: rounds `x * hash_table_size` in the forward pass
    and weights the incoming gradient with a unit-variance Gaussian pdf in the
    backward pass.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, hash_table_size: int):
        """
        Scale `x` by `hash_table_size` and round it.

        `x` and `hash_table_size` are stored on `ctx` for the backward pass.
        """
        # pdb.set_trace()
        
        ctx.save_for_backward(x)
        ctx.hash_table_size = hash_table_size

        indices = differentiable_round(x * hash_table_size)

        # ctx.indices = indices
        # ctx.save_for_backward(indices)

        return indices

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """
        Multiply `grad_output` by a Gaussian pdf evaluated at every integer in
        [0, hash_table_size), centred on the un-rounded index `x * hash_table_size`.

        Returns the gradient for `x` and None for `hash_table_size`.
        """
        # pdb.set_trace()
        
        # indices = ctx.saved_tensors
        x, = ctx.saved_tensors
        hash_table_size = ctx.hash_table_size

        # Continuous (un-rounded) position of each input in the table.
        indices = x * hash_table_size
        # indices = ctx.indices

        # NOTE(review): `indices` is a torch tensor handed to scipy/numpy here, which
        # presumably fails for CUDA tensors or tensors that require grad — confirm.
        # NOTE(review): the pdf is evaluated over `hash_table_size` positions, so the
        # product likely broadcasts to a trailing size of `hash_table_size` instead of
        # the shape of `x` that autograd expects — TODO confirm the intended reduction.
        grad_x = grad_output * norm.pdf(np.arange(0, hash_table_size, 1), loc=indices, scale=1)

        return grad_x, None

### Differentiable Histogram

In [7]:
class DifferentiableHistogram(nn.Module):
    """
    Presence histogram: marks which integer bins occur in the input and
    accumulates those marks in the `counts` buffer across calls.
    """

    def __init__(self, num_bins):
        super().__init__()
        self.num_bins = num_bins
        # Running total of the histograms returned by `forward`.
        self.register_buffer('counts', torch.zeros(num_bins))

    def forward(self, input):
        """Return a histogram of the distinct values of `input` and add it to `counts`."""
        # Duplicates are removed first, so every occupied bin counts exactly once.
        distinct_values = torch.unique(input)
        hist = torch.bincount(distinct_values.int(), minlength=self.num_bins)

        self.counts += hist

        return hist

    def backward(self, grad_output):
        """Scale `grad_output` by the normalised accumulated counts."""
        weighted = grad_output * self.counts
        return weighted / self.counts.sum()

class SoftHistogram(nn.Module):
    """
    Soft (sigmoid based) histogram.

    Each row of the input is compared against `bins` evenly spaced centers in
    [min, max]; differences of consecutive sigmoids give a smooth assignment
    of every value to the interval between two neighbouring centers.
    """

    def __init__(self, bins, min, max, sigma=1):
        """
        Parameters
        ----------
        bins : int
            Number of bins.
        min : float
            Lower bound of the histogram range.
        max : float
            Upper bound of the histogram range.
        sigma : float, optional (default is 1)
            Sharpness of the sigmoids; larger values approach a hard histogram.
        """
        super(SoftHistogram, self).__init__()
        self.bins = bins
        self.min = min
        self.max = max
        self.sigma = sigma
        self.delta = float(max - min) / float(bins)
        self.centers = float(min) + self.delta * (torch.arange(bins).float() + 0.5)

    def forward(self, x):
        """
        Parameters
        ----------
        x : torch.Tensor
            2-D tensor; every row is histogrammed independently.

        Returns
        -------
        torch.Tensor
            Tensor of shape (rows, bins) whose rows sum to the number of values per row.
        """
        x = x.unsqueeze(-1).unsqueeze(-1)
        b, c, h, w = x.shape

        # Flatten everything but the batch dimension: [B, len].
        flat = x.reshape(b, -1)
        n_values = flat.shape[1]

        # [B, bins, len]: soft indicator "value lies above center k".
        centers = self.centers.to(flat.device)
        above = torch.sigmoid(self.sigma * (flat.unsqueeze(1) - centers.unsqueeze(-1)))

        # Pad with 1 (everything is above -inf) and 0 (nothing is above +inf) so that
        # consecutive differences yield the mass falling between neighbouring centers.
        ones = torch.ones((b, 1, n_values), device=flat.device, dtype=above.dtype)
        zeros = torch.zeros((b, 1, n_values), device=flat.device, dtype=above.dtype)
        diff = torch.cat([ones, above], dim=1) - torch.cat([above, zeros], dim=1)

        diff = diff.sum(dim=-1)
        # Fold the open-ended top interval into the last bin to return exactly `bins` values.
        diff[:, -2] += diff[:, -1]
        return diff[:, :-1]

#############################################
# Differentiable Histogram Counting Method
#############################################
# https://github.com/hyk1996/pytorch-differentiable-histogram
def differentiable_histogram(x, bins=255, min=0.0, max=1.0):
    """
    Differentiable histogram with triangular (linear interpolation) bin weights.

    Parameters
    ----------
    x : torch.Tensor
        Either a 4-D tensor, histogrammed per (sample, channel), or a 2-D tensor,
        histogrammed as a whole.
    bins : int, optional (default is 255)
        Number of bins.
    min : float, optional (default is 0.0)
        Position of the first bin.
    max : float, optional (default is 1.0)
        Upper bound of the range; the bin spacing is (max - min) / bins.

    Returns
    -------
    torch.Tensor
        Histogram of shape (n_samples, n_chns, bins). The first and the last
        bin are never written and stay at zero.

    Raises
    ------
    AssertionError
        If `x` is neither 2-D nor 4-D.
    """

    if len(x.shape) == 4:
        n_samples, n_chns, _, _ = x.shape
    elif len(x.shape) == 2:
        n_samples, n_chns = 1, 1
    else:
        raise AssertionError('The dimension of input tensor should be 2 or 4.')

    hist_torch = torch.zeros(n_samples, n_chns, bins).to(x.device)
    delta = (max - min) / bins

    # Bin positions start at `min`; without the offset every range was treated as
    # starting at 0. Converted to Python floats once instead of one `.item()` per use.
    bin_positions = (min + torch.arange(start=0, end=bins, step=1) * delta).tolist()

    for dim in range(1, bins - 1, 1):
        h_r = bin_positions[dim]             # h_r
        h_r_sub_1 = bin_positions[dim - 1]   # h_(r-1)
        h_r_plus_1 = bin_positions[dim + 1]  # h_(r+1)

        # Values on the rising (left) and falling (right) side of the triangle centred at h_r.
        mask_sub = ((h_r > x) & (x >= h_r_sub_1)).float()
        mask_plus = ((h_r_plus_1 > x) & (x >= h_r)).float()

        hist_torch[:, :, dim] += torch.sum(((x - h_r_sub_1) * mask_sub).reshape(n_samples, n_chns, -1), dim=-1)
        hist_torch[:, :, dim] += torch.sum(((h_r_plus_1 - x) * mask_plus).reshape(n_samples, n_chns, -1), dim=-1)

    return hist_torch / delta

# Bard
# def differentiable_histogram_optimized(x, bins=255, min=0.0, max=1.0):
#     if len(x.shape) == 4:
#         n_samples, n_chns, _, _ = x.shape
#     elif len(x.shape) == 2:
#         n_samples, n_chns = 1, 1
#     else:
#         raise AssertionError('The dimension of input tensor should be 2 or 4.')

#     hist_torch = torch.zeros(n_samples, n_chns, bins).to(x.device)
#     delta = (max - min) / bins

#     # BIN_Table = torch.arange(start=0, end=bins, step=1) * delta

#     print(x.shape, BIN_Table.shape)

#     # Create a boolean mask representing which bin each pixel belongs to
#     mask = (x >= BIN_Table[:-1]).float() & (x < BIN_Table[1:])

#     # Calculate the histogram using a single accumulation step
#     hist_torch = torch.sum(mask.view(n_samples, n_chns, bins, 1) * BIN_Table, dim=3)

#     return hist_torch / delta

# Chat-GPT
def differentiable_histogram_optimized(x, bins=255, min=0.0, max=1.0):
    """
    Experimental vectorised variant of `differentiable_histogram`.

    Accepts a 2-D or 4-D tensor `x` and raises AssertionError otherwise.

    NOTE(review): this function still contains debug prints, and the cumulative
    sum is indexed with bucket ids (`bin_indices`, `bin_indices - 1`,
    `bin_indices + 1`); the resulting shapes do not obviously match
    `hist_torch` — confirm that it runs and matches the loop version before use.
    """
    if len(x.shape) == 4:
        n_samples, n_chns, _, _ = x.shape
    elif len(x.shape) == 2:
        n_samples, n_chns = 1, 1
    else:
        raise AssertionError('The dimension of input tensor should be 2 or 4.')

    hist_torch = torch.zeros(n_samples, n_chns, bins).to(x.device)
    delta = (max - min) / bins

    # bins + 1 evenly spaced edges covering [min, max].
    bin_edges = torch.linspace(min, max, bins + 1, device=x.device)

    # Compute bin indices for each element in the input tensor
    bin_indices = torch.bucketize(x, bin_edges, right=True)

    # Create masks for each bin
    mask_sub = (bin_indices > 0).float()
    mask_plus = (bin_indices < bins).float()

    # Compute contributions to the histogram using cumulative sum

    # Debug output left over from development.
    print(bin_indices, bin_indices.shape)
    print(hist_torch.shape)
    print(mask_plus.shape, mask_sub.shape)
    print(torch.cumsum(x, dim=-1).shape)
    print((torch.cumsum(x, dim=-1)[:, bin_indices] * mask_sub).shape)
    # hist_torch += torch.cumsum(x, dim=-1)[:, :, bin_indices] * mask_sub
    # hist_torch[:, :, 1:] -= torch.cumsum(x, dim=-1)[:, :, bin_indices - 1] * mask_sub
    # hist_torch[:, :, :-1] += torch.cumsum(x, dim=-1)[:, :, bin_indices + 1] * mask_plus
    # hist_torch[:, :, -1] -= torch.cumsum(x, dim=-1)[:, :, bin_indices] * mask_plus

    # NOTE(review): `bin_indices + 1` can exceed the last valid position of the
    # cumulative sum when a value falls in the top bucket — TODO confirm.
    hist_torch += torch.cumsum(x, dim=-1)[:, bin_indices] * mask_sub
    hist_torch[:, :, 1:] -= torch.cumsum(x, dim=-1)[:, bin_indices - 1] * mask_sub
    hist_torch[:, :, :-1] += torch.cumsum(x, dim=-1)[:, bin_indices + 1] * mask_plus
    hist_torch[:, :, -1] -= torch.cumsum(x, dim=-1)[:, bin_indices] * mask_plus
    

    return hist_torch / delta

### Hash Function

[Student's T Table](https://www.craftonhills.edu/current-students/tutoring-center/mathematics-tutoring/distribution_tables_normal_studentt_chisquared.pdf)  
$\alpha = 0.8$, $DF = +\infty$

In [8]:
class HashFunction(nn.Module):
    """Learned hash function: an MLP that maps grid coordinates to hash table indices."""

    def __init__(
        self,
        hidden_layers_widths: List[int],
        input_dim: int = 2,
        output_dim: int = 1,
        hash_table_size: int = 2**14,
        sigma_scale: float = 1.0,
        should_log: int = 0,
    ) -> None:
        """
        Hash function module.

        Parameters
        ----------
        hidden_layers_widths : List[int]
            List of hidden layers widths.
        input_dim : int, optional (default is 2)
            Input dimension
        output_dim : int, optional (default is 1)
            Output dimension. With 1 the output is used directly as index,
            otherwise the first two outputs are read as mu and sigma.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        sigma_scale : float, optional (default is 1.0)
            Sigma scale.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log layers grads and outputs.

        Returns
        -------
        None
        """

        super(HashFunction, self).__init__()

        # Largest valid index: the sigmoid output in [0, 1] is scaled by this value.
        self._hash_table_size: int = hash_table_size - 1
        self._sigma_scale: float = sigma_scale

        self._alpha: float = 0.8  # confidence level of the logged interval

        self._should_log: int = should_log

        layers_widths: List[int] = [input_dim, *hidden_layers_widths, output_dim]
        last_layer = len(layers_widths) - 2

        # Linear + ReLU for every layer, except the last one which ends with a Sigmoid.
        self.module_list: nn.ModuleList = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=layers_widths[i], out_features=layers_widths[i + 1]),
                nn.ReLU() if i < last_layer else nn.Sigmoid()
            ) for i in range(len(layers_widths) - 1)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Hash the input coordinates.

        Parameters
        ----------
        x : torch.Tensor
            Coordinates; the last dimension must have size `input_dim`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
            Network output in [0, 1] used for the indices, the rounded hash
            indices, and sigma (None when the network has a single output).
        """
        log_forward = self._should_log > 0
        log_grads = self._should_log > 1

        print2(("Input:", x, x.shape), log_forward)
        print2(("Input grad:", x.requires_grad, ), log_forward)

        for layer in self.module_list:
            x = layer(x)

        print2(("After layers:", f"Output: {x}, shape: {x.shape}"), log_forward)
        print2(("After layers:", f"Grad info: {x.requires_grad}, {x.grad_fn}"), log_grads)

        # x = torch.nan_to_num(x) # Sanitize nan to 0.0

        if x.shape[-1] == 1:  # directly indices
            indices = GaussianConvolution.apply(x, self._hash_table_size)
            print2(("GaussianConvolution:", indices, indices.shape), log_forward)
            print2(("GaussianConvolution grad:", indices.requires_grad, ), log_grads)

            sigma = None
        else:  # mu and sigma
            x = x.unsqueeze(-1)
            print2(("x:", x, x.shape), log_forward)

            sigma = x[..., 1, :] * self._sigma_scale
            x = x[..., 0, :]
            mu = differentiable_round(x * self._hash_table_size)

            print2(("Mu:", mu, mu.shape), log_forward)
            print2(("Sigma:", sigma, sigma.shape), log_forward)

            if log_forward:
                # The interval is only ever logged, so the device-to-host copies and
                # the scipy call are skipped entirely when logging is off.
                # a -> at left of mean, b -> at right of mean
                a, b = norm.interval(
                    self._alpha,
                    loc=mu.detach().cpu().numpy(),
                    scale=sigma.detach().cpu().numpy(),
                )
                a, b = np.round(a), np.round(b)

                print2(("a:", a, a.shape), log_forward)
                print2(("b:", b, b.shape), log_forward)

            indices = mu

        return x, indices, sigma

### Multiresolution

In [9]:
class Multiresolution(nn.Module):
    """
    Multiresolution grid: scales coordinates to several grid resolutions, hashes
    the vertices of the enclosing cells and measures the resulting collisions.
    """

    def __init__(
        self,
        n_min: int,
        n_max: int,
        num_levels: int,
        HashFunction: HashFunction,
        hash_table_size: int = 2**14,
        input_dim: int = 2,
        should_log: int = 0
    ) -> None:
        """
        Multiresolution module.

        Parameters
        ----------
        n_min : int
            Minimum scaling factor.
        n_max : int
            Maximum scaling factor.
        num_levels : int
            Number of levels.
        HashFunction : HashFunction
            Hash function module.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        input_dim : int, optional (default is 2)
            Input dimension.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log helper functions.
            - > 2: Log collisions.
            - > 3: Log dummies.
            - > 5: Log initialization.

        Returns
        -------
        None
        """

        super(Multiresolution, self).__init__()

        self.HashFunction: nn.Module = HashFunction

        self._hash_table_size: int = hash_table_size
        self._num_levels: int = num_levels
        self._input_dim: int = input_dim
        self._should_log: int = should_log

        # Geometric growth factor between consecutive levels.
        b: torch.Tensor = torch.tensor(np.exp((np.log(n_max) - np.log(n_min)) / (self._num_levels - 1))).float()
        if b > 2 or b <= 1:
            print(
                f"The between level scale is recommended to be <= 2 and needs to be > 1 but was {b:.4f}."
            )

        # Resolution of every level, shaped to broadcast against the coordinates.
        self._levels: torch.Tensor = torch.stack([
            torch.floor(n_min * (b ** l)) for l in range(self._num_levels)
        ]).reshape(1, 1, -1, 1)
        print2(("Levels:", self._levels, self._levels.shape), (self._should_log > 5))
        print2(("Levels grads:", self._levels.requires_grad, ), (self._should_log > 5))

        # 0/1 offsets that turn the floored coordinate of a cell into all of its vertices.
        self._voxels_helper_hypercube: torch.Tensor = rearrange(
            torch.tensor(
                np.stack(np.meshgrid(range(2), range(2), range(self._input_dim - 1), indexing="ij"), axis=-1)
            ),
            "cols rows depths verts -> (depths rows cols) verts"
        ).T[:self._input_dim, :].unsqueeze(0).unsqueeze(2)
        print2(("voxels_helper_hypercube:", self._voxels_helper_hypercube, self._voxels_helper_hypercube.shape), (self._should_log > 5))
        print2(("voxels_helper_hypercube grads:", self._voxels_helper_hypercube.requires_grad), (self._should_log > 5))

    def forward(
        self,
        x: torch.Tensor,
        should_calc_hists: bool = False
    ) -> Tuple[torch.Tensor]:
        """
        Hash the grid vertices surrounding every input coordinate.

        Parameters
        ----------
        x : torch.Tensor
            Input coordinates.
        should_calc_hists : bool, optional (default is False)
            Whether to build one histogram figure of hashed indices per level.

        Returns
        -------
        Tuple
            hashed, unique_hashed, unique_probs, unique_sigmas, collisions,
            min_possible_collisions and hists (None unless requested).
            unique_sigmas is None when the hash function returns no sigma.
        """
        log_forward = self._should_log > 0

        # print2(("x:", x, x.shape), (self._should_log > 0))
        print2(("x grads:", x.requires_grad, ), log_forward)

        _, grid_coords = self._scale_to_grid(x)

        probs, hashed, sigmas = self.HashFunction(grid_coords)
        print2(("hashed:", hashed, hashed.shape), log_forward)
        print2(("hashed grads:", hashed.requires_grad, ), log_forward)

        print2(("probs:", probs, probs.shape), log_forward)
        print2(("probs grads:", probs.requires_grad, ), log_forward)

        # The hash function returns sigma=None when it predicts indices directly;
        # print2 arguments are evaluated eagerly, so the attribute access is guarded.
        if sigmas is not None:
            print2(("sigmas:", sigmas, sigmas.shape), log_forward)
            print2(("sigmas grads:", sigmas.requires_grad, ), log_forward)

        dummy_grids, dummy_hashed, og_indices = self._calc_dummies(grid_coords, hashed)

        collisions, min_possible_collisions = self.calc_hash_collisions(dummy_grids, dummy_hashed)
        del dummy_hashed

        unique_probs, unique_hashed, unique_sigmas = self._calc_uniques(probs, hashed, sigmas, og_indices)
        del og_indices

        if should_calc_hists:
            hists = self._hist_collisions(dummy_grids, unique_hashed, should_show=False)
        else:
            hists = None
        del dummy_grids, sigmas

        return hashed, unique_hashed, unique_probs, unique_sigmas, collisions, min_possible_collisions, hists

    @torch.no_grad()
    def _scale_to_grid(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scale coordinates to grid.

        Parameters
        ----------
        x : torch.Tensor
            Original coordinates.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Scaled coordinates and grid coordinates (the vertices of the cell
            enclosing every scaled coordinate, per level).
        """

        scaled_coords: torch.Tensor = x.float() * self._levels.float()
        print2(("scaled_coords:", scaled_coords, scaled_coords.shape), (self._should_log > 1))
        print2(("scaled_coords grads:", scaled_coords.requires_grad, ), (self._should_log > 1))

        grid_coords: torch.Tensor = rearrange(
            torch.add(
                torch.floor(scaled_coords),
                self._voxels_helper_hypercube
            ),
            "batch pixels xyz levels verts -> batch pixels levels verts xyz"
        )
        print2(("grid_coords:", grid_coords, grid_coords.shape), (self._should_log > 1))
        print2(("grid_coords grads:", grid_coords.requires_grad, ), (self._should_log > 1))

        return scaled_coords, grid_coords

    @torch.no_grad()
    def _calc_dummies(
        self,
        grid_coords: torch.Tensor,
        hashed: torch.Tensor,
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Calculate dummies.

        Parameters
        ----------
        grid_coords : torch.Tensor
            Grid coordinates.
        hashed : torch.Tensor
            Hashed grid coordinates.

        Returns
        -------
        Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]
            Per level: unique grid vertices, unique hashed values, and for every
            unique vertex the index of its first occurrence in the flattened grid.
        """

        # Only batch element 0 is used because images should have all same size.
        unique_with_inverse: List[Tuple[torch.Tensor]] = [
            torch.unique(
                rearrange(grid_coords[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"),
                dim=0,
                return_inverse=True
            )
            for l in range(self._num_levels)
        ]

        dummy_grids, dummy_grids_inverse_indices = zip(*unique_with_inverse)
        print2(("dummy_grids:", dummy_grids, [dummy_grids[l].shape for l in range(self._num_levels)]), (self._should_log > 1))
        print2(("dummy_grids grads:", [(dummy_grids[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 3))

        # The inverse mapping contains every unique id, so np.unique(..., return_index=True)
        # yields the first occurrence of unique row 0, 1, 2, ... in a single pass
        # (instead of one full np.where scan per unique row).
        og_indices = [
            torch.tensor(
                np.unique(inverse.detach().cpu().numpy(), return_index=True)[1]
            )
            for inverse in dummy_grids_inverse_indices
        ]
        print2(("og_indices:", [(og_indices[l], og_indices[l].shape) for l in range(self._num_levels)]), (self._should_log > 1))
        print2(("og_indices grads:", [(og_indices[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 1))

        dummy_hashed: List[torch.Tensor] = [
            torch.unique(
                rearrange(hashed[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"),
                dim=0,
                return_inverse=False
            )
            for l in range(self._num_levels)
        ]
        print2(("dummy_hashed:", dummy_hashed, [dummy_hashed[l].shape for l in range(self._num_levels)]), (self._should_log > 3))
        print2(("dummy_hashed grads:", [(dummy_hashed[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 3))

        return dummy_grids, dummy_hashed, og_indices

    @torch.no_grad()
    def calc_hash_collisions(
        self,
        dummy_grids: List[torch.Tensor],
        dummy_hashed: List[torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Calculate hash collisions.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Unique grid coordinates, one tensor per level.
        dummy_hashed : List[torch.Tensor]
            Hashed grid coordinates, one tensor per level.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Collisions and minimum possible collisions at each level.
        """

        # More vertices than table slots makes that many collisions unavoidable.
        min_possible_collisions: torch.Tensor = torch.stack([
            torch.tensor((dummy_grids[l].shape[0]) - self._hash_table_size)
            for l in range(self._num_levels)
        ])
        min_possible_collisions[min_possible_collisions < 0] = 0
        print2(("min_possible_collisions:", min_possible_collisions, min_possible_collisions.shape), (self._should_log > 2))
        print2(("min_possible_collisions grads:", min_possible_collisions.requires_grad), (self._should_log > 2))

        # Collisions = unique vertices minus distinct hash values they were mapped to.
        collisions: torch.Tensor = torch.stack([
            torch.tensor(float(dummy_grids[l].shape[0] - torch.unique(dummy_hashed[l], dim=0).shape[0]))
            for l in range(self._num_levels)
        ])
        print2(("collisions:", collisions, collisions.shape, collisions.dtype), (self._should_log > 2))
        print2(("collisions grads:", collisions.requires_grad, ), (self._should_log > 2))

        return collisions, min_possible_collisions

    @torch.no_grad()
    def _calc_uniques(
        self,
        probs: torch.Tensor,
        hashed: torch.Tensor,
        sigmas: torch.Tensor,
        og_indices: List[torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]:
        """
        Calculate uniques.

        Parameters
        ----------
        probs : torch.Tensor
            Probabilities.
        hashed : torch.Tensor
            Hashed grid coordinates.
        sigmas : torch.Tensor
            Sigmas of hashed grid coordinates (may be None).
        og_indices : List[torch.Tensor]
            Original unique indices.

        Returns
        -------
        Tuple[List[torch.Tensor], List[torch.Tensor], List[torch.Tensor]]
            Per level: softmax of the probabilities, hashed values and sigmas,
            each restricted to the unique grid vertices. The sigmas entry is
            None when `sigmas` is None.
        """
        flatten = "pixels verts xyz -> (pixels verts) xyz"
        levels = range(self._num_levels)
        log_forward = self._should_log > 0

        # Only batch element 0 is used, matching `_calc_dummies`.
        unique_probs: List[torch.Tensor] = [
            F.softmax(rearrange(probs[0, :, l, :, :], flatten)[og_indices[l]], dim=0)
            for l in levels
        ]
        print2(("unique_probs:", [(unique_probs[l], unique_probs[l].shape) for l in levels]), log_forward)

        unique_hashed: List[torch.Tensor] = [
            rearrange(hashed[0, :, l, :, :], flatten)[og_indices[l]]
            for l in levels
        ]
        print2(("unique_hashed:", [(unique_hashed[l], unique_hashed[l].shape) for l in levels]), log_forward)

        if sigmas is None:
            unique_sigmas = None
        else:
            unique_sigmas = [
                rearrange(sigmas[0, :, l, :, :], flatten)[og_indices[l]]
                for l in levels
            ]
            print2(("unique_sigmas:", [(unique_sigmas[l], unique_sigmas[l].shape) for l in levels]), log_forward)

        return unique_probs, unique_hashed, unique_sigmas

    @torch.no_grad()
    def _hist_collisions(
        self,
        dummy_grids: List[torch.Tensor],
        dummy_hashed: List[torch.Tensor],
        should_show: bool = False
    ) -> List[plt.Figure]:
        """
        Show collisions.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Grid coordinates (not used by the plots).
        dummy_hashed : List[torch.Tensor]
            Hashed grid coordinates.
        should_show : bool, optional (default is False)
            Whether to show figure.

        Returns
        -------
        List[plt.Figure]
            Histograms of collisions, one per level.
        """

        figs = []

        for l in range(self._num_levels):

            indices = dummy_hashed[l].detach().cpu().numpy()

            # One bin per hash table slot.
            fig, ax = plt.subplots(figsize=(15, 5))
            ax.hist(
                indices,
                bins=self._hash_table_size,
                range=(0, self._hash_table_size),
                edgecolor='grey',
                linewidth=0.5
            )

            ax.set_xlim(-1, self._hash_table_size)
            ax.xaxis.set_ticks(np.arange(0, self._hash_table_size, 10))

            # Roughly ten ticks on the count axis, never with a zero step.
            _, end = ax.get_ylim()
            step = int(end * 0.1)
            ax.yaxis.set_ticks(np.arange(0, end, step if step > 0 else 1))

            plt.title(f"Level {l} ({int(self._levels[0, 0, l, 0].item())})")
            plt.xlabel("Hashed indices")
            plt.ylabel("Counts")

            figs.append(fig)

            if should_show:
                plt.show()

            plt.close()

        return figs


## Train & Test Loops

### Training Loop

In [10]:
def train_loop(
    x: torch.Tensor,
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    loss_fn: nn.Module,
    should_calc_hists: bool = False,
    should_kl_hist: bool = False,
    should_log: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, float]:
    """
    Run a single optimisation step.

    Parameters
    ----------
    x : torch.Tensor
        Train input data.
    model : nn.Module
        Model to train (switched to train mode).
    optimizer : torch.optim.Optimizer
        Optimizer.
    loss_fn : nn.Module
        Loss function.
    should_calc_hists : bool, optional (default is False)
        Whether to calculate histograms or not.
    should_kl_hist : bool, optional (default is False)
        Whether to feed the rearranged model output, instead of the unique
        probabilities, to the KL term of the loss.
    should_log : int, optional (default is 0)
        - 0: No logging.
        - > 0: Log

    Returns
    -------
    Tuple
        Nine values: model output, unique probabilities, collisions, minimum
        possible collisions, histograms, collisions losses, KL losses, sigma
        loss and the loss as a Python float.
        NOTE(review): the return annotation still lists seven entries.
    """
    print2(("Train loop", ), should_log > 0, bcolors.WARNING)

    model.train()
    optimizer.zero_grad()

    # NOTE(review): the forward pass goes through the module-level `multires`
    # object rather than the `model` argument — confirm this is intended.
    forward_result = multires(x, should_calc_hists=should_calc_hists)
    (
        out,
        _unique_hashed,
        unique_probs,
        unique_sigmas,
        collisions,
        min_possible_collisions,
        hists,
    ) = forward_result

    # Pick what the KL term is computed on.
    if should_kl_hist:
        kl_input = rearrange(out, "batch pixels levels verts xyz -> batch levels pixels (verts xyz)")[0]
    else:
        kl_input = unique_probs

    collisions_losses, kl_losses, sigma_loss, loss = loss_fn(
        collisions,
        min_possible_collisions,
        kl_input,
        unique_sigmas,
    )

    loss.backward()
    optimizer.step()

    return (
        out,
        unique_probs,
        collisions,
        min_possible_collisions,
        hists,
        collisions_losses,
        kl_losses,
        sigma_loss,
        loss.item(),
    )


### Testing Loop

In [11]:
def test_loop(
    x: torch.Tensor,
    model: nn.Module,
    loss_fn: nn.Module,
    should_calc_hists: bool = False,
    should_kl_hist: bool = False,
    should_log: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, torch.Tensor, float]:
    """
    Test loop.

    Runs a single forward pass of `model` on `x` in evaluation mode and computes the loss.
    No gradients are tracked and no parameters are updated.

    Parameters
    ----------
    x : torch.Tensor
        Test input data.
    model : nn.Module
        Model to test. It is called as `model(x, should_calc_hists=...)` and must return the 7-tuple
        (out, unique_hashed, unique_probs, unique_sigmas, collisions, min_possible_collisions, hists).
    loss_fn : nn.Module
        Loss function.
    should_calc_hists : bool, optional (default is False)
        Whether to calculate histograms or not.
    should_kl_hist : bool, optional (default is False)
        Whether to calculate KL with histograms or not.
        When True the rearranged model output is given to the loss instead of the unique probabilities.
    should_log : int, optional (default is 0)
        - 0: No logging.
        - > 0: Log

    Returns
    -------
    Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, List[plt.Figure], torch.Tensor, torch.Tensor, torch.Tensor, float]
        Model output, unique probabilities, collisions, minimum possible collisions, histograms,
        collisions losses, KL losses, sigma losses and loss.
    """
    print2(("Test loop", ), should_log > 0, bcolors.FAIL)

    model.eval()

    # Evaluation only: skip building the autograd graph
    with torch.no_grad():
        # Use the model that was passed in rather than a notebook-level global
        out, _unique_hashed, unique_probs, unique_sigmas, collisions, min_possible_collisions, hists = model(
            x,
            should_calc_hists=should_calc_hists
        )

        if should_kl_hist:
            # Keep the first batch element only, one entry per level
            kl_input = rearrange(out, "batch pixels levels verts xyz -> batch levels pixels (verts xyz)")[0]
        else:
            kl_input = unique_probs

        collisions_losses, kl_losses, sigma_losses, loss = loss_fn(
            collisions,
            min_possible_collisions,
            kl_input,
            unique_sigmas
        )

    return out, unique_probs, collisions, min_possible_collisions, hists, collisions_losses, kl_losses, sigma_losses, loss.item()

## Metrics, Loss & Optimizer

### Loss

In [12]:
class Loss(nn.Module):
    def __init__(
        self,
        l_collisions: float,
        l_kl_loss: float,
        l_sigma_loss: float,
        delta: float = 1,
        hash_table_size: int = 2**14,
        should_kl_hist: bool = False,
        should_log: int = 0,
    ) -> None:
        """
        Loss module.

        The total loss is the weighted sum of the collisions, KL and sigma losses.

        Parameters
        ----------
        l_collisions : float
            Collisions loss lambda.
        l_kl_loss : float
            KL loss lambda.
        l_sigma_loss : float
            Sigma loss lambda.
        delta : float, optional (default is 1)
            Delta parameter for collision loss.
            Used as denominator wherever the minimum possible collisions are <= 0.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        should_kl_hist : bool, optional (default is False)
            Whether to calculate KL with histograms or not.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log helper functions.

        Returns
        -------
        None
        """

        super().__init__()

        self._l_collisions = l_collisions
        self._l_kl_loss = l_kl_loss
        self._l_sigma_loss = l_sigma_loss

        self._delta = delta
        self._hash_table_size = hash_table_size
        self._should_kl_hist = should_kl_hist
        self._should_log = should_log

        self.kl_div_loss = nn.KLDivLoss(reduction="batchmean")
        self.mse_loss = nn.MSELoss(reduction="mean")

        # Not used by forward at the moment; kept as an alternative differentiable histogram
        self.differentiable_hist = SoftHistogram(bins=self._hash_table_size, min=0, max=self._hash_table_size, sigma=1.0)

    def forward(
        self,
        collisions: torch.Tensor,
        min_possible_collisions: torch.Tensor,
        probs: List,
        sigmas: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Calculate the losses.

        Parameters
        ----------
        collisions : torch.Tensor
            Collisions per level.
        min_possible_collisions : torch.Tensor
            Minimum possible collisions per level.
        probs : List
            One entry per level, iterated over its first dimension.
            Probabilities when KL is not histogram based, otherwise the values to build the histogram from.
        sigmas : torch.Tensor
            Sigmas, iterated over the first dimension.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]
            Collisions losses, KL losses, sigma losses and total weighted loss.
        """

        # Work on a copy so the caller's tensor is not modified
        delta = min_possible_collisions.clone()
        delta[delta <= 0] = self._delta

        collisions_losses: torch.Tensor = (collisions - min_possible_collisions) / delta
        print2(("Collisions Losses:", collisions_losses), self._should_log > 0)
        print2(("Collisions Losses grads:", collisions_losses.requires_grad), self._should_log > 0)

        kl_div_losses = torch.stack([
            (
                self._kl_div(prob.shape[0], prob.squeeze(-1))
                if not self._should_kl_hist
                else self._kl_div(0, prob)
            )
            for prob in probs
        ])

        print2(("KL Losses:", kl_div_losses), self._should_log > 0, bcolors.FAIL)

        # MSE against zero, i.e. mean of the squared sigmas
        sigma_losses = torch.stack([
            self.mse_loss(sigma, torch.zeros_like(sigma))
            for sigma in sigmas
        ])

        collisions_loss = torch.sum(collisions_losses)
        kl_loss = kl_div_losses.sum()
        sigma_loss = sigma_losses.sum()

        loss = (self._l_collisions * collisions_loss) + (self._l_kl_loss * kl_loss) + (self._l_sigma_loss * sigma_loss)
        # Use the module's own log level, not a notebook-level global
        print2(("Loss grads:", loss.requires_grad, ), self._should_log > 0, bcolors.HEADER)

        return collisions_losses, kl_div_losses, sigma_losses, loss

    def _kl_div(
        self,
        level: int,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate KL divergence loss.

        Parameters
        ----------
        level : int
            Number of entries of the target distribution when KL is not histogram based.
            Ignored when KL is histogram based.
        p : torch.Tensor
            Probabilities when KL is not histogram based, otherwise the values to build the histogram from.

        Returns
        -------
        torch.Tensor
            KL divergence loss.
        """

        should_debug = self._should_log > 1

        print2((f"p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), should_debug)

        if self._should_kl_hist:
            if should_debug:
                # Non differentiable reference histogram, only needed to check the differentiable one
                hist_p_nondiff = torch.histc(p.detach(), bins=self._hash_table_size, min=0, max=self._hash_table_size) # ? maybe max=(self._hash_table_size - 1)
                print2((f"hist_p_nondiff: {hist_p_nondiff}, shape: {hist_p_nondiff.shape}, sum: {torch.sum(hist_p_nondiff)}, requires_grad: {hist_p_nondiff.requires_grad}", ), should_debug)

            hist_p = differentiable_histogram(p, bins=self._hash_table_size, min=0, max=self._hash_table_size).squeeze(0).squeeze(0)
            print2((f"hist_p: {hist_p}, shape: {hist_p.shape}, sum: {torch.sum(hist_p)}, requires_grad: {hist_p.requires_grad}", ), should_debug)

            if should_debug:
                print2((f"Count(hist_p_nondiff != hist_p): {hist_p_nondiff.shape[0] - torch.sum(torch.eq(hist_p_nondiff, hist_p))}", ), should_debug)

            # NOTE(review): dividing by the number of bins gives a distribution only if the histogram
            # counts sum to hash_table_size; dividing by hist_p.sum() may be what is intended - confirm
            p = hist_p / self._hash_table_size
            print2((f"p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), should_debug)

            # Avoid log(0) below
            p[p == 0] = 1e-10
            print2((f"after p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), should_debug)

            # Target: uniform distribution over the hash table
            q = torch.ones(self._hash_table_size) / self._hash_table_size
        else:
            # Target: softmax of 0..level-1
            q = torch.arange(level, dtype=torch.float32)
            print2((f"before q: {q}, shape: {q.shape}, requires_grad: {q.requires_grad}", ), should_debug)

            q = torch.softmax(q, dim=-1)

        print2((f"after q: {q}, shape: {q.shape}, requires_grad: {q.requires_grad}", ), should_debug)

        # nn.KLDivLoss expects the input as log-probabilities and the target as probabilities
        return self.kl_div_loss(p.log(), q)


### Optimizer

In [13]:
def get_optimizer(
    net: torch.nn.Module,
    hash_lr: float,
    hash_weight_decay: float,
    betas: tuple = (0.9, 0.99),
    eps: float = 1e-15
):
    """
    Build the Adam optimizer.

    Only the parameters of `net.HashFunction` are optimized; the encoding and MLP parameter groups are not used here.

    Parameters
    ----------
    net : torch.nn.Module
        Model exposing a `HashFunction` sub-module.
    hash_lr : float
        Learning rate of the hash function parameters.
    hash_weight_decay : float
        Weight decay of the hash function parameters.
    betas : tuple, optional (default is (0.9, 0.99))
        Adam betas.
    eps : float, optional (default is 1e-15)
        Adam epsilon.

    Returns
    -------
    torch.optim.Adam
        Optimizer.
    """
    hash_group = {
        "params": net.HashFunction.parameters(),
        "lr": hash_lr,
        "weight_decay": hash_weight_decay,
    }

    return torch.optim.Adam([hash_group], betas=betas, eps=eps)

## Early stopper

In [14]:
class EarlyStopper:
    """
    Early stopper.

    Counts the calls in which the loss did not improve on the best loss seen so far by at least
    `min_delta` and raises the `early_stop` flag once `tolerance` of them have been counted.

    Parameters
    ----------
    tolerance : int, optional (default is 5)
        Counter value at which `early_stop` is set.
    min_delta : float, optional (default is 0)
        Minimum decrease of the loss with respect to the best loss to count as an improvement.
    should_reset : bool, optional (default is True)
        - True: an improvement resets the counter to 0.
        - False: an improvement decreases the counter by 1 (never below 0).
    """

    def __init__(self, tolerance: int = 5, min_delta: float = 0, should_reset: bool = True):
        self.tolerance: int = tolerance
        self.min_delta: float = min_delta
        self.best_loss: float = np.inf
        self.counter: int = 0
        self.early_stop: bool = False
        self._should_reset: bool = should_reset

    def __call__(self, loss):
        """
        Update the stopper with the loss of the last epoch.

        Parameters
        ----------
        loss : float
            Loss to compare against the best loss.

        Returns
        -------
        None
        """
        improvement = self.best_loss - loss

        # A NaN loss fails both comparisons and is therefore counted as "no improvement"
        if improvement > 0 and improvement >= self.min_delta:
            # Track the best loss in both modes, otherwise it would stay at inf when not resetting
            self.best_loss = loss

            if self._should_reset:
                self.counter = 0
            else:
                self.counter = max(self.counter - 1, 0)
        else:
            # Stall (equal, slightly better or slightly worse than the best) or growing loss
            self.counter += 1

        if self.counter >= self.tolerance:
            self.early_stop = True

## Initialization

### Arguments

In [15]:
root_dir = "./images"
test_size = 0.2

image_extensions = (".jpg", ".png", ".jpeg")

# os.listdir returns the entries in arbitrary order: sort them so that the seeded split is reproducible
train_images, test_images = train_test_split(
    sorted(
        file for file in os.listdir(root_dir)
        if ("silhouette" in file) and file.endswith(image_extensions)
    ),
    test_size=test_size,
    random_state=random_seed
)

wandb_entity = "fedemonti00"
wandb_project = "project_course"
# wandb_name = "hash_function_training"

# Run name: wandb_name when it has been defined (see the commented line above), otherwise the current timestamp
try:
    time = wandb_name
except NameError:
    time = (datetime.now(ZoneInfo("Europe/Rome"))).strftime("%Y%m%d%H%M%S")
print("RUN:", time)

early_stopper_tolerance = 500
early_stopper_min_delta = 1e-6

# Histograms are calculated every histogram_rate epochs
histogram_rate = 10

hyperparameters = {
    "hash_table_size": 2**8,
    "n_min": 8,
    "n_max": 32,
    "num_levels": 4,
    "output_dim": 2,
    "hash_function_hidden_layers_widths": [8, 32, 8],
    "hash_lr": 1e-4,  # previously tried: 1e-3
    "hash_weight_decay": 1e-6,
    "kl_hist_loss": True,
    "l_collisions": 0,  # previously tried: 1e-1
    "l_kl_loss": 1,  # previously tried: -100, 100
    "l_sigma_loss": 0,  # previously tried: 1e-1
    "epochs": 1,  # full training: 1000
    "random_seed": random_seed,
}

# should_log = False
should_log = True
# Track on wandb only the long runs
should_wandb = hyperparameters["epochs"] > 999

RUN: 20231207164646


### Wandb

#### Init

In [16]:
if should_wandb:
    # Start a new wandb run to track this notebook:
    # - entity / project: where the run is logged
    # - name: the run name printed above as "RUN"
    # - config: hyperparameters and run metadata
    # - save_code: the notebook code is not uploaded
    wandb.init(
        entity=wandb_entity,
        project=wandb_project,
        name=time,
        config=hyperparameters,
        save_code=False,
    )

#### Logger

In [17]:
def wandb_log(
    e: int,
    train_loss: float,
    train_collisions: torch.Tensor,
    train_min_possible_collisions: torch.Tensor,
    train_collisions_losses: torch.Tensor,
    train_kl_losses: torch.Tensor,
    train_sigma_losses: torch.Tensor,
    train_hists: List[plt.Figure],
    test_loss: float,
    test_collisions: torch.Tensor,
    test_min_possible_collisions: torch.Tensor,
    test_collisions_losses: torch.Tensor,
    test_kl_losses: torch.Tensor,
    test_sigma_losses: torch.Tensor,
    test_hists: List[plt.Figure],
    should_log_hists: bool = False
) -> None:
    """
    Log to wandb.

    Logs the train and test losses plus, for each of the `hyperparameters["num_levels"]` levels,
    the collisions, minimum possible collisions, collisions loss, KL loss and sigma loss of both splits.

    Parameters
    ----------
    e : int
        Epoch. Only used in the histograms captions.
    train_loss : float
        Train loss.
    train_collisions : torch.Tensor
        Train collisions.
    train_min_possible_collisions : torch.Tensor
        Train minimum possible collisions.
    train_collisions_losses : torch.Tensor
        Train collisions losses.
    train_kl_losses : torch.Tensor
        Train KL losses.
    train_sigma_losses : torch.Tensor
        Train sigma losses.
    train_hists : List[plt.Figure]
        Train histograms. Only read when `should_log_hists` is True.
    test_loss : float
        Test loss.
    test_collisions : torch.Tensor
        Test collisions.
    test_min_possible_collisions : torch.Tensor
        Test minimum possible collisions.
    test_collisions_losses : torch.Tensor
        Test collisions losses.
    test_kl_losses : torch.Tensor
        Test KL losses.
    test_sigma_losses : torch.Tensor
        Test sigma losses.
    test_hists : List[plt.Figure]
        Test histograms. Only read when `should_log_hists` is True.
    should_log_hists : bool, optional (default is False)
        Whether to log histograms or not.

    Returns
    -------
    None
    """

    log = {
        "train_loss": train_loss,
        "test_loss": test_loss,
    }

    splits = (
        (
            "train",
            train_collisions,
            train_min_possible_collisions,
            train_collisions_losses,
            train_kl_losses,
            train_sigma_losses,
            train_hists,
        ),
        (
            "test",
            test_collisions,
            test_min_possible_collisions,
            test_collisions_losses,
            test_kl_losses,
            test_sigma_losses,
            test_hists,
        ),
    )

    for level in range(hyperparameters["num_levels"]):
        for split, collisions, min_possible_collisions, collisions_losses, kl_losses, sigma_losses, hists in splits:
            log[f"{split}_collisions_level_{level}"] = collisions[level].item()
            log[f"{split}_min_possible_collisions_level_{level}"] = min_possible_collisions[level].item()
            log[f"{split}_collisions_loss_level_{level}"] = collisions_losses[level].item()
            log[f"{split}_kl_loss_level_{level}"] = kl_losses[level].item()
            log[f"{split}_sigma_loss_level_{level}"] = sigma_losses[level].item()

            # Honour the argument instead of a notebook-level global
            if should_log_hists:
                log[f"{split}_hist_counts_level_{level}"] = wandb.Image(
                    hists[level],
                    caption=f"Hashed indices counts at level {level} at epoch {e}"
                )

    wandb.log(log)

### Load Datasets

In [18]:
# root_dir is "<root>/<dir_name>": the dataset takes the two parts separately
_root_dir_parts = root_dir.split("/")

train_dataset = ImagesDataset(
    root=_root_dir_parts[0],
    dir_name=_root_dir_parts[1],
    images_names=train_images,
)
# Last item of the train dataset
x, y, h, w, names = train_dataset[-1]

# Evaluation currently runs on a single fixed image instead of test_images
test_dataset = ImagesDataset(
    root=_root_dir_parts[0],
    dir_name=_root_dir_parts[1],
    images_names=["strawberry_small.jpg"],
)
# Last item of the test dataset
eval_x, eval_y, eval_h, eval_w, eval_names = test_dataset[-1]

# Third-to-last dimension of the input is used as the input dimension of the models
input_dim = x.shape[-3]

### Models initialization

In [19]:
# Shorthand for the hyperparameters read below
_hp = hyperparameters

# Hash function (logging kept off for this module, whatever should_log is)
hashFunction = HashFunction(
    input_dim=input_dim,
    hidden_layers_widths=_hp["hash_function_hidden_layers_widths"],
    output_dim=_hp["output_dim"],
    hash_table_size=_hp["hash_table_size"],
    should_log=0,
)

# Multiresolution model wrapping the hash function (logging kept off as well)
multires = Multiresolution(
    input_dim=input_dim,
    n_min=_hp["n_min"],
    n_max=_hp["n_max"],
    num_levels=_hp["num_levels"],
    hash_table_size=_hp["hash_table_size"],
    HashFunction=hashFunction,
    should_log=0,
)

# Loss: helper functions are logged too (level 2) when should_log is set
loss_fn = Loss(
    l_collisions=_hp["l_collisions"],
    l_kl_loss=_hp["l_kl_loss"],
    l_sigma_loss=_hp["l_sigma_loss"],
    hash_table_size=_hp["hash_table_size"],
    should_kl_hist=_hp["kl_hist_loss"],
    should_log=2 if should_log else 0,
)

optimizer = get_optimizer(
    net=multires,
    hash_lr=_hp["hash_lr"],
    hash_weight_decay=_hp["hash_weight_decay"],
)

early_stopper = EarlyStopper(
    tolerance=early_stopper_tolerance,
    min_delta=early_stopper_min_delta,
)

print(multires)

Multiresolution(
  (HashFunction): HashFunction(
    (module_list): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=2, out_features=8, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=8, out_features=32, bias=True)
        (1): ReLU()
      )
      (2): Sequential(
        (0): Linear(in_features=32, out_features=8, bias=True)
        (1): ReLU()
      )
      (3): Sequential(
        (0): Linear(in_features=8, out_features=2, bias=True)
        (1): Sigmoid()
      )
    )
  )
)


In [20]:
# Figures are only built for logging: do not show them interactively
plt.ioff()
should_calc_hists = False

_last_epoch = hyperparameters["epochs"] - 1
_loops_log_level = 1 if should_log else 0

pbar = tqdm(range(hyperparameters["epochs"]))

for e in pbar:

    # Calc histograms at: first epoch, last epoch, every histogram_rate epochs, and in the extra
    # epoch that runs after the early stopper has raised its flag
    should_calc_hists = (e == _last_epoch) or (e % histogram_rate == 0) or early_stopper.early_stop

    _train_results = train_loop(
        x=x,
        model=multires,
        optimizer=optimizer,
        loss_fn=loss_fn,
        should_calc_hists=should_calc_hists,
        should_kl_hist=hyperparameters["kl_hist_loss"],
        should_log=_loops_log_level,
    )
    (
        _,
        _,
        train_collisions,
        train_min_possible_collisions,
        train_hists,
        train_collisions_losses,
        train_kl_losses,
        train_sigma_losses,
        train_loss,
    ) = _train_results

    _test_results = test_loop(
        x=eval_x,
        model=multires,
        loss_fn=loss_fn,
        should_calc_hists=should_calc_hists,
        should_kl_hist=hyperparameters["kl_hist_loss"],
        should_log=_loops_log_level,
    )
    (
        _,
        _,
        test_collisions,
        test_min_possible_collisions,
        test_hists,
        test_collisions_losses,
        test_kl_losses,
        test_sigma_losses,
        test_loss,
    ) = _test_results

    if should_wandb:
        wandb_log(
            e=e,
            train_loss=train_loss,
            train_collisions=train_collisions,
            train_min_possible_collisions=train_min_possible_collisions,
            train_collisions_losses=train_collisions_losses,
            train_kl_losses=train_kl_losses,
            train_sigma_losses=train_sigma_losses,
            train_hists=train_hists,
            test_loss=test_loss,
            test_collisions=test_collisions,
            test_min_possible_collisions=test_min_possible_collisions,
            test_collisions_losses=test_collisions_losses,
            test_kl_losses=test_kl_losses,
            test_sigma_losses=test_sigma_losses,
            test_hists=test_hists,
            should_log_hists=should_calc_hists,
        )

    pbar.set_description(f"Epoch {e}, Loss: {test_loss}, Collisions: {test_collisions}")

    # The flag is checked before the stopper is updated, so the epoch in which the flag is raised
    # is followed by one more epoch (with histograms) before the loop ends
    if early_stopper.early_stop:
        print("!!! Stopping at epoch:", e, "!!!")
        break

    early_stopper(test_loss)


  0%|          | 0/1 [00:00<?, ?it/s]

Train loop
[93m -------------------- [0m


  return func(*args, **kwargs)
  0%|          | 0/1 [00:01<?, ?it/s]

[96m Line:  print2(("Collisions Losses:", collisions_losses), self._should_log > 0) [0m
Collisions Losses:
tensor([ 77.0000, 164.0000,   1.3405,   0.2917], device='cuda:0')
[96m -------------------- [0m
[96m Line:  print2(("Collisions Losses grads:", collisions_losses.requires_grad), self._should_log > 0) [0m
Collisions Losses grads:
False
[96m -------------------- [0m
[96m Line:  print2((f"p: {p}, shape: {p.shape}, requires_grad: {p.requires_grad}", ), self._should_log > 1) [0m
p: tensor([[131., 131., 131., 131.],
        [131., 131., 131., 131.],
        [131., 131., 131., 131.],
        ...,
        [130., 130., 129., 129.],
        [130., 130., 129., 129.],
        [130., 130., 129., 129.]], device='cuda:0', grad_fn=<UnbindBackward0>), shape: torch.Size([8100, 4]), requires_grad: True
[96m -------------------- [0m
[96m Line:  print2((f"hist_p_nondiff: {hist_p_nondiff}, shape: {hist_p_nondiff.shape}, sum: {torch.sum(hist_p_nondiff)}, requires_grad: {hist_p_nondiff.requi




RuntimeError: The size of tensor a (256) must match the size of tensor b (4) at non-singleton dimension 2

In [None]:
# Close the wandb run, if one was requested for this execution
if should_wandb:
    wandb.finish()