# NeRF: Collision handling in Instant Neural Graphics Primitives
#### Federico Montagna (fedemonti00@gmail.com)

# Code:

## Imports

In [1]:
from typing import Tuple, List

import numpy as np
from math import atan2
import cv2
from PIL import Image
import os
import matplotlib
from matplotlib.ticker import MultipleLocator, AutoMinorLocator, FixedLocator
import matplotlib.pyplot as plt
import random
from scipy.stats import norm
from sklearn.model_selection import train_test_split

import collections, functools, operator
from collections import Counter
import itertools

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import io, transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.autograd.profiler as profiler

# from torchinfo import summary

from einops import rearrange, reduce, repeat

import traceback
from pprint import pprint

import wandb
from tqdm import tqdm

import pdb

try:
    from zoneinfo import ZoneInfo
except ImportError:
    from backports.zoneinfo import ZoneInfo

from datetime import datetime

# from params import *

# Report the CUDA devices visible to PyTorch and select the device used below.
print("Cuda avilable:", torch.cuda.is_available())
for i in range(torch.cuda.device_count()):
    print(f"Available device {i}:", torch.cuda.get_device_name(i))

device = torch.device('cuda:0' if torch.cuda.is_available() else "cpu")
# torch.cuda.current_device() raises when CUDA cannot be initialised, so only
# query it when a GPU is actually available; otherwise report the CPU fallback.
if torch.cuda.is_available():
    print(f"Current device {torch.cuda.current_device()}:", torch.cuda.get_device_name(torch.cuda.current_device()))
else:
    print("Current device: cpu")

# torch.set_default_device(device)

# random_seed = None
# if should_random_seed:
#     random_seed = 2**16 - 1
#     random.seed(random_seed)
#     np.random.seed(random_seed)
#     torch.manual_seed(random_seed)
#     torch.random.manual_seed(random_seed)
#     if torch.cuda.is_available():
#         torch.cuda.manual_seed_all(random_seed)
#     print("Random seed:", random_seed)

# Expose the notebook file name to wandb through the WANDB_NOTEBOOK_NAME
# environment variable.
os.environ["WANDB_NOTEBOOK_NAME"] = "main.ipynb"

# Apply the "ggplot" matplotlib style to the figures created in this notebook.
plt.style.use("ggplot")

Cuda avilable: True
Available device 0: NVIDIA GeForce RTX 2070
Available device 1: NVIDIA GeForce RTX 2070
Current device 0: NVIDIA GeForce RTX 2070


## Debug Functions

In [2]:
class bcolors:
    """ANSI escape sequences used to colour and style text printed to a terminal."""

    # Foreground colours.
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'

    # Text attributes.
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

    # Resets every colour/attribute set by the sequences above.
    ENDC = '\033[0m'


def print2(texts, log: bool = False, color: bcolors = bcolors.OKCYAN) -> None:
    """
    Debug print: show the source line of the caller followed by each item of `texts`.

    Parameters
    ----------
    texts : iterable
        Items to print, one per line.
    log : bool, optional (default is False)
        Nothing is printed unless this is True.
    color : str, optional (default is bcolors.OKCYAN)
        Escape sequence used for the header and the closing separator.

    Returns
    -------
    None
    """
    if not log:
        return

    # The last stack entry is this function; the one before it is the caller.
    caller = traceback.extract_stack()[-2]
    print(color, "Line: ", caller.line, bcolors.ENDC)

    for item in texts:
        print(item)

    print(color, "-"*20, bcolors.ENDC)


def print_allocated_memory(log: bool = True):
    """
    Debug print: show the caller's source line and the CUDA memory statistics.

    Parameters
    ----------
    log : bool, optional (default is True)
        Nothing is printed unless this is True.

    Returns
    -------
    None
    """
    if not log:
        return

    # The last stack entry is this function; the one before it is the caller.
    caller = traceback.extract_stack()[-2]
    print(bcolors.HEADER, "Line: ", caller.line, bcolors.ENDC)

    bytes_per_gigabyte = 1024 ** 3

    allocated_memory = torch.cuda.memory_allocated() / bytes_per_gigabyte
    print(f"Allocated Memory: {allocated_memory:.2f} GB")

    peak_memory = torch.cuda.max_memory_allocated() / bytes_per_gigabyte
    print(f"Peak Allocated Memory: {peak_memory:.2f} GB")

    print(bcolors.OKCYAN, "-"*20, bcolors.ENDC)

## Load wandb apikey

In [3]:
# apikey_path = ".wandb_apikey.txt"
# if os.path.exists(apikey_path):
#     with open(apikey_path, "r") as f:
#         apikey = f.read()
#         !wandb login {apikey} # --relogin

## Dataset

In [4]:
class ImagesDataset(Dataset):
    """
    Dataset of equally sized images paired with normalised pixel coordinates.
    """

    def __init__(
        self,
        root: str,
        dir_name: str,
        images_names: List[str],
    ) -> None:
        """
        Store the image names and build the full path of every image.

        Parameters
        ----------
        root : str
            Path to root directory.
        dir_name : str
            Name of directory (inside `root`) that contains the images.
        images_names : List[str]
            List of images names.

        Returns
        -------
        None
        """

        self._root: str = root
        self._dir_name: str = dir_name
        self._images_names: List[str] = images_names

        self._images_paths: List[str] = [
            os.path.join(self._root, self._dir_name, image_name)
            for image_name in self._images_names
        ]

    def __getitem__(self, idx: torch.Tensor or int) -> Tuple[torch.Tensor, torch.Tensor, int, int, List[str]]:
        """
        Get data by indices.

        Parameters
        ----------
        idx : torch.Tensor or int
            Index (int), indices (tensor of any shape) or -1 to select every image.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor, int, int, List[str]]
            - X: pixel coordinates divided by max(w, h), shape (images, h * w, 2, 1, 1).
            - Y: images as floats divided by 255, shape (images, h, w, channels).
            - h: images height.
            - w: images width.
            - names of ALL the images of the dataset (not only the selected ones).

        Raises
        ------
        IndexError
            If no image is selected (e.g. the dataset is empty).
        ValueError
            If images have different sizes.
        """

        # Flatten tensors first: comparing a multi-element tensor with -1 is
        # ambiguous as a boolean and a 0-dim tensor's .tolist() is not a list.
        if torch.is_tensor(idx):
            indices = idx.reshape(-1).tolist()
        else:
            indices = [idx]

        # A single -1 (int or one-element tensor) selects the whole dataset.
        if indices == [-1]:
            indices = list(range(len(self._images_paths)))

        images: List[torch.Tensor] = [
            rearrange(io.read_image(self._images_paths[_id]), "rgb h w -> h w rgb")
            for _id in indices
        ]

        if not images:
            raise IndexError("No images selected.")

        heights = {image.shape[0] for image in images}
        widths = {image.shape[1] for image in images}

        if len(heights) > 1 or len(widths) > 1:
            raise ValueError("Images have different sizes.")

        h = heights.pop()
        w = widths.pop()

        # Every image has the same size, so the (row, col) grid is built once
        # and repeated instead of being recomputed for each image.
        grid = np.stack(np.meshgrid(range(h), range(w), indexing="ij"), axis=-1).reshape(-1, 2)
        coords: torch.Tensor = torch.tensor(grid).float() / (max(w, h)) # No (max(w, h) - 1)

        X: torch.Tensor = (
            coords.unsqueeze(0).repeat(len(images), 1, 1)
        ).unsqueeze(-1).unsqueeze(-1) #.requires_grad_()

        Y: torch.Tensor = torch.stack(
            images
        ).float() / 255

        return X, Y, h, w, self._images_names

    def __len__(self) -> int:
        """
        Returns
        -------
        int
            Length of dataset.
        """
        return len(self._images_paths)

## Models

### Gaussian Convolution Backward

In [5]:
class GaussianConvolution(torch.autograd.Function):
    """
    Rounding to hash table indices with a Gaussian-smoothed surrogate gradient.

    Forward rounds `x * hash_table_size` to the nearest integer. Rounding has a
    zero gradient almost everywhere, so backward replaces it with a sum of unit
    variance Gaussians centred on the table slots 0 .. hash_table_size - 1.
    """

    @staticmethod
    def forward(ctx, x: torch.Tensor, hash_table_size: int):
        """
        Parameters
        ----------
        x : torch.Tensor
            Values to convert to indices.
        hash_table_size : int
            Hash table size used as scaling factor.

        Returns
        -------
        torch.Tensor
            `round(x * hash_table_size)`, same shape and dtype as `x`.
            NOTE(review): if x lies in [0, 1] the result lies in
            [0, hash_table_size], i.e. hash_table_size itself can be produced.
        """
        ctx.save_for_backward(x)
        ctx.hash_table_size = hash_table_size

        indices = torch.round(x * hash_table_size)

        return indices

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        """
        Parameters
        ----------
        grad_output : torch.Tensor
            Gradient with respect to the output of `forward`.

        Returns
        -------
        Tuple[torch.Tensor, None]
            Gradient with respect to `x` (same shape as `x`) and None for
            `hash_table_size`, which is not differentiable.
        """
        x, = ctx.saved_tensors
        hash_table_size = ctx.hash_table_size

        # Continuous (not rounded) position of every input in the table.
        positions = x * hash_table_size

        # Slot centres, created on the device/dtype of x so this also works on GPU.
        centers = torch.arange(hash_table_size, dtype=x.dtype, device=x.device)

        # Standard normal pdf of every slot centre around every position.
        # The extra trailing axis has size hash_table_size, so memory grows
        # as x.numel() * hash_table_size.
        distances = positions.unsqueeze(-1) - centers
        weights = torch.exp(-0.5 * distances ** 2) / float(np.sqrt(2.0 * np.pi))

        # Reduce over the slot axis so the gradient has the shape of x, as
        # required by autograd.
        grad_x = grad_output * weights.sum(dim=-1)

        return grad_x, None

### Hash Function

In [6]:
class HashFunction(nn.Module):
    """
    Learned hash function: an MLP ending with a sigmoid whose output is turned
    into hash table indices by `GaussianConvolution`.
    """

    def __init__(
        self,
        hidden_layers_widths: List[int],
        input_dim: int = 2,
        output_dim: int = 1,
        hash_table_size: int = 2**14,
        should_log: int = 0,
    ) -> None:
        """
        Hash function module.

        Parameters
        ----------
        hidden_layers_widths : List[int]
            List of hidden layers widths.
        input_dim : int, optional (default is 2)
            Input dimension
        output_dim : int, optional (default is 1)
            Output dimension
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log layers grads and outputs.

        Returns
        -------
        None
        """

        super().__init__()

        self._hash_table_size: int = hash_table_size
        self._should_log: int = should_log

        layers_widths: List[int] = [input_dim, *hidden_layers_widths, output_dim]
        last_layer: int = len(layers_widths) - 2

        # Linear + ReLU for every layer but the last, which uses a Sigmoid.
        self.module_list: nn.ModuleList = nn.ModuleList([
            nn.Sequential(
                nn.Linear(in_features=layers_widths[i], out_features=layers_widths[i + 1]),
                nn.ReLU() if i < last_layer else nn.Sigmoid()
            ) for i in range(len(layers_widths) - 1)
        ])

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        x : torch.Tensor
            Input whose last dimension is `input_dim`.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Output of the last (sigmoid) layer and the indices obtained from it
            with `GaussianConvolution`.
        """
        print2(("Input:", x, x.shape), (self._should_log > 0))
        print2(("Input grad:", x.requires_grad, ), (self._should_log > 0))

        for layer in self.module_list:
            x = layer(x)

        print2((f"After layers:", f"Output: {x}, shape: {x.shape}"), (self._should_log > 0))
        print2((f"After layers:", f"Grad info: {x.requires_grad}, {x.grad_fn}"), (self._should_log > 1))

        # x = torch.nan_to_num(x) # Sanitize nan to 0.0

        indices = GaussianConvolution.apply(x, self._hash_table_size)
        # Log the indices produced by GaussianConvolution (not its input x).
        print2(("GaussianConvolution:", indices, indices.shape), (self._should_log > 0))
        print2(("GaussianConvolution grad:", indices.requires_grad, ), (self._should_log > 1))

        return x, indices

### Multiresolution

In [7]:
class Multiresolution(nn.Module):
    """
    Multiresolution grid whose vertices are hashed by a learned hash function.

    Coordinates are scaled to several grid resolutions, the vertices of the
    cell containing each coordinate are hashed, and per level statistics
    (collisions, probabilities of the unique vertices, histograms) are returned.
    """

    def __init__(
        self,
        n_min: int,
        n_max: int,
        num_levels: int,
        HashFunction: HashFunction,
        hash_table_size: int = 2**14,
        input_dim: int = 2,
        should_log: int = 0
    ) -> None:
        """
        Multiresolution module.

        Parameters
        ----------
        n_min : int
            Minimum scaling factor.
        n_max : int
            Maximum scaling factor.
        num_levels : int
            Number of levels.
        HashFunction : HashFunction
            Hash function module.
        hash_table_size : int, optional (default is 2**14)
            Hash table size.
        input_dim : int, optional (default is 2)
            Input dimension.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log helper functions.
            - > 2: Log collisions.
            - > 3: Log dummies.
            - > 5: Log initialization.

        Returns
        -------
        None
        """

        super().__init__()

        self.HashFunction: nn.Module = HashFunction

        self._hash_table_size: int = hash_table_size
        self._num_levels: int = num_levels
        self._input_dim: int = input_dim
        self._should_log: int = should_log

        # Geometric growth factor between consecutive levels.
        b: torch.Tensor = torch.tensor(np.exp((np.log(n_max) - np.log(n_min)) / (self._num_levels - 1))).float()
        if b > 2 or b <= 1:
            print(
                f"The between level scale is recommended to be <= 2 and needs to be > 1 but was {b:.4f}."
            )

        # Resolution of every level, shaped (1, 1, levels, 1) for broadcasting.
        self._levels: torch.Tensor = torch.stack([
            torch.floor(n_min * (b ** l)) for l in range(self._num_levels)
        ]).reshape(1, 1, -1, 1)
        print2(("Levels:", self._levels, self._levels.shape), (self._should_log > 5))
        print2(("Levels grads:", self._levels.requires_grad, ), (self._should_log > 5))

        # Offsets of the vertices of a unit cell, shaped (1, input_dim, 1, verts).
        self._voxels_helper_hypercube: torch.Tensor = rearrange(
            torch.tensor(
                np.stack(np.meshgrid(range(2), range(2), range(self._input_dim - 1), indexing="ij"), axis=-1)
            ),
            "cols rows depths verts -> (depths rows cols) verts"
        ).T[:self._input_dim, :].unsqueeze(0).unsqueeze(2)
        print2(("voxels_helper_hypercube:", self._voxels_helper_hypercube, self._voxels_helper_hypercube.shape), (self._should_log > 5))
        print2(("voxels_helper_hypercube grads:", self._voxels_helper_hypercube.requires_grad), (self._should_log > 5))

        # A level of resolution N has (N + 1) ** input_dim vertices; the ones
        # exceeding the table size must collide. Negative values are clamped to 0.
        self.min_possible_collisions: torch.Tensor = torch.clamp(
            ((self._levels[0, 0, :, 0] + 1) ** self._input_dim) - self._hash_table_size,
            min=0
        )
        print2(("min_possible_collisions:", self.min_possible_collisions, self.min_possible_collisions.shape), (self._should_log > 5))
        print2(("min_possible_collisions grads:", self.min_possible_collisions.requires_grad), (self._should_log > 5))

    def forward(
        self,
        x: torch.Tensor,
        should_calc_hists: bool = False
    ) -> Tuple:
        """
        Hash the grid vertices around `x` and compute the per level statistics.

        Parameters
        ----------
        x : torch.Tensor
            Coordinates, broadcastable as (batch, pixels, input_dim, 1, 1).
        should_calc_hists : bool, optional (default is False)
            Whether to build the histograms of the hashed indices.

        Returns
        -------
        Tuple
            - hashed: indices returned by the hash function.
            - unique_probs: per level softmax of the hash function outputs of
              the unique grid vertices.
            - collisions: per level number of collisions.
            - hists: per level histograms, or None if not requested.
        """
        # print2(("x:", x, x.shape), (self._should_log > 0))
        print2(("x grads:", x.requires_grad, ), (self._should_log > 0))

        _, grid_coords = self._scale_to_grid(x)

        probs, hashed = self.HashFunction(grid_coords)
        print2(("hashed:", hashed, hashed.shape), (self._should_log > 0))
        print2(("hashed grads:", hashed.requires_grad, ), (self._should_log > 0))

        print2(("probs:", probs, probs.shape), (self._should_log > 0))
        print2(("probs grads:", probs.requires_grad, ), (self._should_log > 0))

        dummy_grids, dummy_hashed, og_indices = self._calc_dummies(grid_coords, hashed)

        collisions: torch.Tensor = self.calc_hash_collisions(dummy_grids, dummy_hashed)
        del dummy_hashed

        unique_probs, unique_hashed = self._calc_uniques(probs, hashed, og_indices)
        del og_indices

        if should_calc_hists:
            hists = self._hist_collisions(dummy_grids, unique_hashed, should_show=False)
        else:
            hists = None
        del dummy_grids, unique_hashed

        return hashed, unique_probs, collisions, hists

    @torch.no_grad()
    def _scale_to_grid(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Scale coordinates to grid.

        Parameters
        ----------
        x : torch.Tensor
            Original coordinates.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Scaled coordinates and grid coordinates; the latter are laid out as
            (batch, pixels, levels, verts, xyz).
        """

        scaled_coords: torch.Tensor = x.float() * self._levels.float()
        print2(("scaled_coords:", scaled_coords, scaled_coords.shape), (self._should_log > 1))
        print2(("scaled_coords grads:", scaled_coords.requires_grad, ), (self._should_log > 1))

        # Lower corner of the cell plus the offset of every vertex of the cell.
        grid_coords: torch.Tensor = rearrange(
            torch.add(
                torch.floor(scaled_coords),
                self._voxels_helper_hypercube
            ),
            "batch pixels xyz levels verts -> batch pixels levels verts xyz"
        )
        print2(("grid_coords:", grid_coords, grid_coords.shape), (self._should_log > 1))
        print2(("grid_coords grads:", grid_coords.requires_grad, ), (self._should_log > 1))

        return scaled_coords, grid_coords

    @torch.no_grad()
    def calc_hash_collisions(
        self,
        dummy_grids: List[torch.Tensor],
        dummy_hashed: List[torch.Tensor]
    ) -> torch.Tensor:
        """
        Calculate hash collisions.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Unique grid coordinates, one tensor per level.
        dummy_hashed : List[torch.Tensor]
            Hashed grid coordinates, one tensor per level.

        Returns
        -------
        torch.Tensor
            Hash collisions per level (number of unique grid vertices minus
            number of unique hashed indices). Computed without gradient.
        """
        collisions: torch.Tensor = torch.tensor([
            float(dummy_grids[l].shape[0] - torch.unique(dummy_hashed[l], dim=0).shape[0])
            for l in range(self._num_levels)
        ])
        print2(("collisions:", collisions, collisions.shape, collisions.dtype), (self._should_log > 2))
        print2(("collisions grads:", collisions.requires_grad, ), (self._should_log > 2))

        return collisions

    @torch.no_grad()
    def _calc_dummies(
        self,
        grid_coords: torch.Tensor,
        hashed: torch.Tensor,
    ) -> Tuple:
        """
        Calculate dummies.

        Parameters
        ----------
        grid_coords : torch.Tensor
            Grid coordinates.
        hashed : torch.Tensor
            Hashed grid coordinates.

        Returns
        -------
        Tuple
            Dummy grids (unique grid vertices per level), dummy hashed (unique
            hashed indices per level) and, per level, the index of the first
            occurrence of every unique grid vertex in the flattened grid.
        """

        uniques_with_inverse: List[Tuple[torch.Tensor]] = [
            torch.unique(rearrange(grid_coords[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"), dim=0, return_inverse=True) # first dimension is 0 because images should have all same size
            for l in range(self._num_levels)
        ]

        dummy_grids, dummy_grids_inverse_indices = zip(*uniques_with_inverse)
        print2(("dummy_grids:", dummy_grids, [dummy_grids[l].shape for l in range(self._num_levels)]), (self._should_log > 1))
        print2(("dummy_grids grads:", [(dummy_grids[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 3))

        # The inverse indices contain every value 0 .. n_unique - 1, so the
        # first-occurrence indices returned by np.unique are exactly the first
        # position of every unique vertex (one pass instead of one np.where per
        # vertex). `.cpu()` keeps this working for tensors stored on GPU.
        og_indices = [
            torch.as_tensor(
                np.unique(inverse_indices.cpu().numpy(), return_index=True)[1],
                dtype=torch.long
            )
            for inverse_indices in dummy_grids_inverse_indices
        ]
        print2(("og_indices:", [(og_indices[l], og_indices[l].shape) for l in range(self._num_levels)]), (self._should_log > 1))
        print2(("og_indices grads:", [(og_indices[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 1))

        dummy_hashed: List[torch.Tensor] = [
            torch.unique(rearrange(hashed[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz"), dim=0, return_inverse=False) # first dimension is 0 because images should have all same size
            for l in range(self._num_levels)
        ]
        print2(("dummy_hashed:", dummy_hashed, [dummy_hashed[l].shape for l in range(self._num_levels)]), (self._should_log > 3))
        print2(("dummy_hashed grads:", [(dummy_hashed[l].requires_grad, ) for l in range(self._num_levels)]), (self._should_log > 3))

        return dummy_grids, dummy_hashed, og_indices

    # @torch.no_grad()
    def _calc_uniques(
        self,
        probs: torch.Tensor,
        hashed: torch.Tensor,
        og_indices: List[torch.Tensor]
    ) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        """
        Calculate uniques.

        Parameters
        ----------
        probs : torch.Tensor
            Probabilities.
        hashed : torch.Tensor
            Hashed grid coordinates.
        og_indices : List[torch.Tensor]
            Original unique indices.

        Returns
        -------
        Tuple[List[torch.Tensor], List[torch.Tensor]]
            Per level unique probabilities (softmax over the unique vertices)
            and per level unique hashed indices.
        """

        # Only the first element of the batch is used, as in `_calc_dummies`.
        unique_probs: List[torch.Tensor] = [
            F.softmax(rearrange(probs[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz")[og_indices[l]], dim=0)
            for l in range(self._num_levels)
        ]
        print2(("unique_probs:", [(unique_probs[l], unique_probs[l].shape) for l in range(self._num_levels)]), (self._should_log > 0))

        unique_hashed: List[torch.Tensor] = [
            rearrange(hashed[0, :, l, :, :], "pixels verts xyz -> (pixels verts) xyz")[og_indices[l]]
            for l in range(self._num_levels)
        ]
        print2(("unique_hashed:", [(unique_hashed[l], unique_hashed[l].shape) for l in range(self._num_levels)]), (self._should_log > 0))

        return unique_probs, unique_hashed

    @torch.no_grad()
    def _hist_collisions(
        self,
        dummy_grids: List[torch.Tensor],
        dummy_hashed: List[torch.Tensor],
        should_show: bool = False
    ) -> List[plt.Figure]:
        """
        Show collisions.

        Parameters
        ----------
        dummy_grids : List[torch.Tensor]
            Grid coordinates (not used to build the histograms).
        dummy_hashed : List[torch.Tensor]
            Hashed grid coordinates.
        should_show : bool, optional (default is False)
            Whether to show figure.

        Returns
        -------
        List[plt.Figure]
            Histograms of collisions, one per level.
        """

        figs = []

        for l in range(self._num_levels):

            # `.cpu()` is required to convert tensors stored on GPU to numpy.
            indices = dummy_hashed[l].detach().cpu().numpy()

            fig, ax = plt.subplots(figsize=(15, 5))
            ax.hist(
                indices,
                bins=self._hash_table_size,
                range=(0, self._hash_table_size),
                edgecolor='grey',
                linewidth=0.5
            )

            ax.set_xlim(-1, self._hash_table_size)
            ax.xaxis.set_ticks(np.arange(0, self._hash_table_size, 10))

            # Roughly ten ticks on the y axis, with a step of at least 1.
            _, end = ax.get_ylim()
            step = int(end * 0.1)
            ax.yaxis.set_ticks(np.arange(0, end, step if step > 0 else 1))

            ax.set_title(f"Level {l} ({int(self._levels[0, 0, l, 0].item())})")
            ax.set_xlabel("Hashed indices")
            ax.set_ylabel("Counts")

            figs.append(fig)

            if should_show:
                plt.show()

            # Close this specific figure so open figures do not accumulate.
            plt.close(fig)

        return figs



## Metrics, Loss & Optimizer

### Loss

In [8]:
class Loss(nn.Module):
    """
    Collision loss and KL divergence (from the uniform distribution) per level.
    """

    def __init__(
        self,
        delta: float = 1,
        should_log: int = 0
    ) -> None:
        """
        Loss module.

        Parameters
        ----------
        delta : float, optional (default is 1)
            Delta parameter for collision loss: denominator used for the levels
            whose minimum possible collisions are <= 0.
        should_log : int, optional (default is 0)
            - 0: No logging.
            - > 0: Log forward pass.
            - > 1: Log helper functions.

        Returns
        -------
        None
        """

        super().__init__()

        self._delta = delta
        self._should_log = should_log

        # Kept so that existing code accessing this attribute keeps working;
        # `_kl_div` computes the divergence with `F.kl_div` directly.
        self.kl_div_loss = nn.KLDivLoss(reduction="batchmean")

    def forward(
        self,
        collisions: torch.Tensor,
        min_possible_collisions: torch.Tensor,
        probs: List[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parameters
        ----------
        collisions : torch.Tensor
            Collisions per level.
        min_possible_collisions : torch.Tensor
            Minimum possible collisions per level (not modified).
        probs : List[torch.Tensor]
            Probabilities per level; every tensor sums to 1 over dimension 0.

        Returns
        -------
        Tuple[torch.Tensor, torch.Tensor]
            Collisions losses and KL divergence losses, one value per level.
        """

        # Work on a copy so the caller's tensor is left untouched.
        delta = min_possible_collisions.clone()
        delta[delta <= 0] = self._delta

        collisions_losses: torch.Tensor = (collisions - min_possible_collisions) / delta
        print2(("Collisions Losses:", collisions_losses), self._should_log > 0)
        print2(("Collisions Losses grads:", collisions_losses.requires_grad), self._should_log > 0)

        kl_div_losses = torch.stack([
            self._kl_div(prob.shape[0], prob)
            for prob in probs
        ])

        return collisions_losses, kl_div_losses

    def _kl_div(
        self,
        level: int,
        p: torch.Tensor,
    ) -> torch.Tensor:
        """
        Calculate KL divergence loss between the uniform distribution and `p`.

        Parameters
        ----------
        level : int
            Number of entries of the uniform distribution (rows of `p`).
        p : torch.Tensor
            Probabilities, one distribution per column.

        Returns
        -------
        torch.Tensor
            KL divergence loss (scalar): sum over the entries of
            `q * (log(q) - log(p))` with `q = 1 / level`.
        """

        # Build q with the shape, dtype and device of p. A 1-D q would be
        # broadcast against the (level, 1) input to a (level, level) matrix and
        # would fail for inputs with more than one column.
        q = torch.full_like(p, 1.0 / level)

        return F.kl_div(p.log(), q, reduction="sum")


### Optimizer

In [9]:
def get_optimizer(
    net: torch.nn.Module,
    hash_lr: float,
    hash_weight_decay: float,
    betas: tuple = (0.9, 0.99),
    eps: float = 1e-15
):
    """
    Build the Adam optimizer for the hash function of `net`.

    Only `net.HashFunction` is optimized; the encoding and MLP parameter
    groups are currently disabled.

    Parameters
    ----------
    net : torch.nn.Module
        Module exposing a `HashFunction` submodule.
    hash_lr : float
        Learning rate of the hash function parameters.
    hash_weight_decay : float
        Weight decay of the hash function parameters.
    betas : tuple, optional (default is (0.9, 0.99))
        Adam betas.
    eps : float, optional (default is 1e-15)
        Adam epsilon.

    Returns
    -------
    torch.optim.Adam
        Optimizer with a single parameter group.
    """
    hash_function_group = {
        "params": net.HashFunction.parameters(),
        "lr": hash_lr,
        "weight_decay": hash_weight_decay,
    }

    return torch.optim.Adam([hash_function_group], betas=betas, eps=eps)

## Early stopper

In [10]:
class EarlyStopper:
    """
    Early stopping helper: sets `early_stop` after `tolerance` calls without
    an improvement of the loss.
    """

    def __init__(self, tolerance: int = 5, min_delta: float = 0, should_reset: bool = True):
        """
        Parameters
        ----------
        tolerance : int, optional (default is 5)
            Value of the counter at which `early_stop` is set.
        min_delta : float, optional (default is 0)
            Minimum decrease of the loss (with respect to the best loss) that
            counts as an improvement.
        should_reset : bool, optional (default is True)
            If True an improvement resets the counter to 0, otherwise it only
            decrements the counter by 1 (never below 0).
        """
        self.tolerance: int = tolerance
        self.min_delta: float = min_delta
        self.best_loss: float = np.inf
        self.counter: int = 0
        self.early_stop: bool = False
        self._should_reset: bool = should_reset

    def __call__(self, loss):
        """
        Update the state with the loss of the current epoch.

        Parameters
        ----------
        loss : float
            Loss of the current epoch.

        Returns
        -------
        None
        """
        # Only a decrease larger than min_delta is an improvement. A stalled,
        # growing or NaN loss never overwrites best_loss and always increases
        # the counter.
        improved = (self.best_loss - loss) > self.min_delta

        if improved:
            self.best_loss = loss
            if self._should_reset:
                self.counter = 0
            else:
                self.counter = max(self.counter - 1, 0)
        else:
            self.counter += 1

        # Once set, early_stop is never cleared.
        if self.counter >= self.tolerance:
            self.early_stop = True

## Initialization

### Arguments

In [11]:
# Location of the images and parameters of the train/test split.
root_dir = "./images"
test_size = 0.2
random_state = 65535

# os.listdir returns the entries in arbitrary order: sort the selected file
# names so that the seeded split is the same on every machine and run.
train_images, test_images = train_test_split(
    sorted(
        file for file in os.listdir(root_dir)
        if ("silhouette" in file) and file.endswith((".jpg", ".png", ".jpeg"))
    ),
    test_size=test_size,
    random_state=random_state
)

wandb_entity = "fedemonti00"
wandb_project = "project_course"
# wandb_name = "hash_function_training"

# Name of the run: `wandb_name` when it is defined (see the commented line
# above), otherwise the current timestamp in the Europe/Rome time zone.
try:
    time = wandb_name
except NameError:
    time = (datetime.now(ZoneInfo("Europe/Rome"))).strftime("%Y%m%d%H%M%S")
print("RUN:", time)

early_stopper_tolerance = 10
early_stopper_min_delta = 1e-6

# A histogram of the hashed indices is computed every `histogram_rate` epochs.
histogram_rate = 10

hyperparameters = {
    "hash_table_size": 2**8,
    "n_min": 8,
    "n_max": 32,
    "num_levels": 4,
    "output_dim": 1,
    "hash_function_hidden_layers_widths": [8, 32, 8],
    "hash_lr": 1e-3,
    "hash_weight_decay": 1e-6,
    "l_collisions": 1,
    "l_kl_loss": 100,
    "epochs": 1000,
}

RUN: 20231102100528


### Wandb

In [12]:
# ------------------------------ #
#          WANDB INIT            #
# ------------------------------ #

# start a new wandb run to track this script
# Start the wandb run using the entity, project, run name and hyperparameters
# defined in the "Arguments" cell.
wandb.init(
    entity = wandb_entity,
    # set the wandb project where this run will be logged
    project = wandb_project,

    # run name (timestamp or `wandb_name`, see the "Arguments" cell)
    name = time,

    # track hyperparameters and run metadata
    config = hyperparameters,
    # config = {
    #     "id_grid_search_params":        id_param,
    #     "grid_search_params":           params,
    #     "random_seed":                  random_seed,
    #     "HPD_learning_rate":            HPD_lr,
    #     "encoding_learning_rate":       encoding_lr,
    #     "MLP_learning_rate":            MLP_lr,
    #     "encoding_weight_decay":        encoding_weight_decay,
    #     "HPD_weight_decay":             HPD_weight_decay,
    #     "MLP_weight_decay":             MLP_weight_decay,
    #     "batch_size%":                  batch_size,
    #     "shuffled_pixels":              should_shuffle_pixels,
    #     "normalized_data":              True if not should_batchnorm_data else "BatchNorm1d",
    #     "architecture":                 "GeneralNeuralGaugeFields",
    #     "dataset":                      image_name,
    #     "epochs":                       epochs,
    #     "color":                        'RGB' if not should_bw else 'BW',
    #     "hash_table_size":              hash_table_size,
    #     "num_levels":                   num_levels,
    #     "n_min":                        n_min,
    #     "n_max":                        n_max,
    #     "MLP_hidden_layers_widths":     str(MLP_hidden_layers_widths),
    #     "HPD_hidden_layers_widths":     str(HPD_hidden_layers_widths),
    #     "HPD_out_features":             HPD_out_features,
    #     "feature_dim":                  feature_dim,
    #     "topk_k":                       topk_k,
    #     "loss_type":                    "JS+KLDiv" if should_sum_js_kl_div else ("KLDiv" if not should_js_div else "JSDiv"),
    #     "loss_lambda_MSE":              l_mse,
    #     "loss_lambda_JS_KL":            l_js_kl,
    #     "loss_lambda_collisions":       l_collisions,
    #     "loss_gamma":                   loss_gamma,
    #     "loss_epsilon":                 loss_epsilon,
    #     "inplace_scatter":              should_inplace_scatter,
    #     "MLP_activations":              "LeakyReLU" if should_leaky_relu else "ReLU",
    #     "collisions_loss_probs":        "topk_only" if should_keep_topk_only else "hash_table_size",
    #     "avg_topk_features":            "softmax_avg" if should_softmax_topk_features else ("weighted_avg" if should_softmax_topk_features != None else None),
    #     "hash_type":                    "HPD" if not should_use_hash_function else "hash_function"
    # }

    save_code = True,
)

# ------------------------------ #

[34m[1mwandb[0m: Currently logged in as: [33mfedemonti00[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Load Datasets

In [13]:
# Split `root_dir` into parent directory and folder name with os.path, so that
# nested paths (e.g. "./data/images") and paths without "/" are handled too.
# Index -1 loads every image of the dataset at once.
train_dataset = ImagesDataset(
    root=os.path.dirname(root_dir),
    dir_name=os.path.basename(root_dir),
    images_names=train_images
)
x, y, h, w, names = train_dataset[-1]

test_dataset = ImagesDataset(
    root=os.path.dirname(root_dir),
    dir_name=os.path.basename(root_dir),
    images_names=test_images
)
eval_x, eval_y, eval_h, eval_w, eval_names = test_dataset[-1]

# Size of the third-to-last dimension of x, used as input dimension of the models.
input_dim = x.shape[-3]
should_log = False

### Models initialization

In [14]:
# Learned hash function (MLP) configured from the hyperparameters.
hashFunction = HashFunction(
    hidden_layers_widths=hyperparameters["hash_function_hidden_layers_widths"],
    input_dim=input_dim,
    output_dim=hyperparameters["output_dim"],
    hash_table_size=hyperparameters["hash_table_size"],
    # NOTE(review): both branches are 0, so logging is always disabled here.
    should_log=0 if should_log else 0
)

# Multiresolution grid wrapping the hash function.
multires = Multiresolution(
    n_min=hyperparameters["n_min"],
    n_max=hyperparameters["n_max"],
    num_levels=hyperparameters["num_levels"],
    HashFunction=hashFunction,
    hash_table_size=hyperparameters["hash_table_size"],
    input_dim=input_dim,
    # Level 6 enables every log of Multiresolution, initialization included.
    should_log=6 if should_log else 0
)

loss_fn = Loss(
    # NOTE(review): both branches are 0, so logging is always disabled here.
    should_log=0 if should_log else 0
)

# Adam optimizer over the parameters of multires.HashFunction.
optimizer = get_optimizer(
    net=multires,
    hash_lr=hyperparameters["hash_lr"],
    hash_weight_decay=hyperparameters["hash_weight_decay"],
)

early_stopper = EarlyStopper(
    tolerance=early_stopper_tolerance,
    min_delta=early_stopper_min_delta
)

print(multires)

Multiresolution(
  (HashFunction): HashFunction(
    (module_list): ModuleList(
      (0): Sequential(
        (0): Linear(in_features=2, out_features=8, bias=True)
        (1): ReLU()
      )
      (1): Sequential(
        (0): Linear(in_features=8, out_features=32, bias=True)
        (1): ReLU()
      )
      (2): Sequential(
        (0): Linear(in_features=32, out_features=8, bias=True)
        (1): ReLU()
      )
      (3): Sequential(
        (0): Linear(in_features=8, out_features=1, bias=True)
        (1): Sigmoid()
      )
    )
  )
)


In [15]:
# Training loop: one full-batch optimization step per epoch on the train
# coordinates `x`, logging losses, collisions and histograms to wandb.

# Disable matplotlib interactive mode while the histograms are created.
plt.ioff()
should_calc_hists = False

pbar = tqdm(range(0, hyperparameters["epochs"]))

multires.train()

for e in pbar:
    optimizer.zero_grad()

    # Calc histograms at: first epoch, last epoch, every histogram_rate epochs, every time early stopper stops
    should_calc_hists = (e == hyperparameters["epochs"] - 1) or (e % histogram_rate == 0) or early_stopper.early_stop
    
    out, unique_probs, collisions, figs = multires(
        x, 
        should_calc_hists=should_calc_hists 
    )

    collisions_losses, kl_losses = loss_fn(collisions, multires.min_possible_collisions, unique_probs)

    # Total loss: weighted collisions loss minus weighted KL loss (the KL term
    # is subtracted, i.e. minimizing the loss increases it).
    # NOTE(review): `collisions` is computed under torch.no_grad in
    # Multiresolution.calc_hash_collisions, so the collisions term presumably
    # carries no gradient — confirm this is intended.
    collisions_loss = torch.sum(collisions_losses)
    kl_loss = kl_losses.sum()
    loss = hyperparameters["l_collisions"] * collisions_loss - hyperparameters["l_kl_loss"] * kl_loss
    # print2(("Loss grads:", loss.requires_grad, ), True, bcolors.HEADER)

    # Metrics sent to wandb for this epoch.
    log = {
        "train_loss": loss.item(),
    }

    for l in range(hyperparameters["num_levels"]):
        log[f"train_collisions_level_{l}"] = collisions[l].item()
        log[f"train_min_possible_collisions_level_{l}"] = multires.min_possible_collisions[l].item()

        log[f"train_collisions_loss_level_{l}"] = collisions_losses[l].item()
        log[f"train_kl_loss_level_{l}"] = kl_losses[l].item()
        
        if should_calc_hists:
            log[f"train_hist_counts_level_{l}"] = wandb.Image(
                figs[l],
                caption=f"Hashed indices counts at level {l} at epoch {e}"
            )
    # Drop the reference to the figures once they have been added to the log.
    del figs

    wandb.log(log)

    # print2(
    #     (f"""Epoch {e}, 
    #     Collision_losses: {collisions_losses}, 
    #     KL Loss: {kl_losses}, 
    #     Loss: {loss.item()},
    #     Collisions: {collisions},
    #     Min possible Collisions: {multires.min_possible_collisions}""", ), 
    #     True, 
    #     bcolors.OKGREEN
    # )
    pbar.set_description(f"Epoch {e}, Loss: {loss.item()}, Collisions: {collisions}")

    loss.backward()

    # for name, param in multires.named_parameters():
    #     print(f'Parameter: {name}, Gradient: {param.grad}')

    optimizer.step()

    # The flag is checked BEFORE the early stopper is updated with the current
    # loss: the loop runs one more epoch after the flag is set, which is the
    # epoch where `should_calc_hists` is forced to True above.
    if early_stopper.early_stop:
        print("!!! Stopping at epoch:", e, "!!!")
        break

    early_stopper(loss.item())

Epoch 75, Loss: 131.01731872558594, Collisions: tensor([ 38.,  95., 322., 967.]):   8%|▊         | 75/1000 [00:58<11:55,  1.29it/s]  


!!! Stopping at epoch: 75 !!!


In [16]:
# Mark the wandb run started above as finished.
wandb.finish()

0,1
train_collisions_level_0,██▇▇▇▆▆▆▅▅▅▄▄▄▄▄▄▅▄▃▃▃▃▂▂▂▂▃▃▂▂▂▂▁▁▁▁▂▃▄
train_collisions_level_1,██▇▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▁▂▂▃
train_collisions_level_2,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
train_collisions_level_3,██▇▆▅▅▄▄▃▃▃▃▃▂▂▂▂▃▂▂▂▃▃▂▂▂▃▃▂▂▂▂▂▂▂▂▂▁▁▁
train_collisions_loss_level_0,██▇▇▇▆▆▆▅▅▅▄▄▄▄▄▄▅▄▃▃▃▃▂▂▂▂▃▃▂▂▂▂▁▁▁▁▂▃▄
train_collisions_loss_level_1,██▇▇▇▆▆▆▅▅▅▄▄▄▄▄▃▃▃▂▂▂▂▂▂▂▂▂▂▁▁▂▁▁▁▂▁▂▂▃
train_collisions_loss_level_2,██▇▇▆▆▅▅▅▄▄▄▄▃▃▃▃▃▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
train_collisions_loss_level_3,██▇▆▅▅▄▄▃▃▃▃▃▂▂▂▂▃▂▂▂▃▃▂▂▂▃▃▂▂▂▂▂▂▂▂▂▁▁▁
train_kl_loss_level_0,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▇▇▇▇██
train_kl_loss_level_1,▁▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇██

0,1
train_collisions_level_0,38.0
train_collisions_level_1,95.0
train_collisions_level_2,322.0
train_collisions_level_3,967.0
train_collisions_loss_level_0,38.0
train_collisions_loss_level_1,95.0
train_collisions_loss_level_2,0.74054
train_collisions_loss_level_3,0.16086
train_kl_loss_level_0,0.00601
train_kl_loss_level_1,0.00829
