In [31]:
import os
import sys

# When the kernel was started inside a directory named "testing", move up to
# its parent and add that parent to the import search path.
# NOTE(review): presumably this is what makes the `lossless` package
# importable below -- confirm against the repository layout.
launch_dir = os.getcwd()
if os.path.basename(launch_dir) == "testing":
    os.chdir(os.path.dirname(launch_dir))
    sys.path.append(os.getcwd())

import lossless.component.core.arm as arm_core
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

# ---- Determinism / threading configuration --------------------------------
# Restrict the math backends to one thread and request deterministic kernels
# so that repeated forward passes can be compared digit for digit.
# NOTE(review): torch and numpy are already imported when these two
# environment variables are set, so they may be read too late to take effect
# -- confirm. The explicit torch.set_num_threads(1) call below does not depend
# on them.
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
torch.set_num_threads(1)
torch.use_deterministic_algorithms(True)

# ---- Configuration ---------------------------------------------------------
DEVICE = torch.device("cpu")
TRAINED_MODELS_DIR = "../logs/full_runs/trained_models"
# Number of synthetic contexts kept for the batch-vs-single comparison.
N_INSPECTED_CONTEXTS = 10

# ---- Model -----------------------------------------------------------------
arm_model = arm_core.Arm(dim_arm=16, n_hidden_layers_arm=2).to(DEVICE)
# Load pretrained weights.
# NOTE(review): torch.load unpickles the file, so only point it at trusted
# checkpoints (consider `weights_only=True` if the installed torch supports it).
arm_model.load_state_dict(
    torch.load(f"{TRAINED_MODELS_DIR}/coolchic_arm.pth", map_location=DEVICE)
)
arm_model.eval()

# ---- Recorded latents snapshot ---------------------------------------------
latents_dict = torch.load(
    f"{TRAINED_MODELS_DIR}/coolchic_latents_snapshot.pt", map_location=DEVICE
)
flat_latent = latents_dict["flat_latent"].to(DEVICE)
# The recorded contexts get their own name so they are not clobbered by the
# synthetic `latent_context_flat` built below.
snapshot_context_flat = latents_dict["latent_context_flat"].to(DEVICE)
print(flat_latent.shape)
print(snapshot_context_flat.shape)

# ---- Synthetic all-zero latent ---------------------------------------------
# shape B,C,H,W (NOTE(review): layout taken from the original comment -- confirm)
latents = torch.zeros((1, 1, 10, 10), dtype=torch.float32, device=DEVICE)
# Gather the context of every position (outer loop over dim 3, inner loop over
# dim 2) and keep only the first N_INSPECTED_CONTEXTS rows. This is the tensor
# the following cells feed to the ARM.
latent_context_flat = torch.concat(
    [
        arm_model.get_neighbor_context(latents[0, 0].tolist(), i, j)
        for j in range(latents.shape[3])
        for i in range(latents.shape[2])
    ]
)[:N_INSPECTED_CONTEXTS]
print(latent_context_flat.shape)

torch.Size([10])
torch.Size([10, 16])
torch.Size([10, 16])


In [None]:
from typing import OrderedDict, Tuple

import torch
import torch.nn.functional as F
from lossless.util.misc import safe_get_from_nested_lists
from torch import Tensor, index_select, nn


class ArmLinear(nn.Module):
    """Create a Linear layer of the Auto-Regressive Module (ARM). This is a
    wrapper around the usual ``nn.Linear`` layer of PyTorch, with a custom
    initialization. It performs the following operations:

    * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b}` if
      ``residual`` is ``False``

    * :math:`\\mathbf{x}_{out} = \\mathbf{W}\\mathbf{x}_{in} + \\mathbf{b} +
      \\mathbf{x}_{in}` if ``residual`` is ``True``.

    The input  :math:`\\mathbf{x}_{in}` is a :math:`[B, C_{in}]` tensor, the
    output :math:`\\mathbf{x}_{out}` is a :math:`[B, C_{out}]` tensor.

    The layer weight and bias shapes are :math:`\\mathbf{W} \\in
    \\mathbb{R}^{C_{out} \\times C_{in}}` and :math:`\\mathbf{b} \\in
    \\mathbb{R}^{C_{out}}`.

    Because the residual branch adds the input to the linear output, a
    residual layer needs the two tensors to have compatible shapes (in
    practice :math:`C_{in} = C_{out}`).
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        residual: bool = False,
    ):
        """
        Args:
            in_channels: Number of input features :math:`C_{in}`.
            out_channels: Number of output features :math:`C_{out}`.
            residual: True to add a residual connection to the layer. Defaults to
                False.
        """

        super().__init__()

        self.residual = residual
        self.in_channels = in_channels
        self.out_channels = out_channels

        # Allocate uninitialized storage; ``initialize_parameters`` fills it.
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels), requires_grad=True)
        self.bias = nn.Parameter(torch.empty(out_channels), requires_grad=True)
        self.initialize_parameters()

    def initialize_parameters(self) -> None:
        """Initialize **in place** the weight and the bias of the linear layer.

        The existing ``nn.Parameter`` objects are overwritten rather than
        replaced, so references held elsewhere (e.g. by an optimizer) keep
        pointing at the live parameters.

        * Biases are always set to zero.

        * Weights are set to zero if ``residual == True``. Otherwise, sample
          from the Normal distribution: :math:`\\mathbf{W} \\sim \\mathcal{N}(0,
          \\tfrac{1}{(C_{out})^4})`.

        Both parameters require gradients once this method returns.
        """
        # In-place writes to leaf parameters must not be tracked by autograd.
        with torch.no_grad():
            self.bias.zero_()
            if self.residual:
                self.weight.zero_()
            else:
                out_channel = self.weight.size()[0]
                # Standard deviation 1 / C_out^2, i.e. variance 1 / C_out^4.
                self.weight.copy_(torch.randn_like(self.weight) / out_channel**2)

        # The previous implementation always produced trainable parameters;
        # keep that guarantee even if a caller had frozen them beforehand.
        self.weight.requires_grad_(True)
        self.bias.requires_grad_(True)

    def forward(self, x: Tensor) -> Tensor:
        """Perform the forward pass of this layer.

        Args:
            x: Input tensor of shape :math:`[B, C_{in}]`.

        Returns:
            Tensor with shape :math:`[B, C_{out}]`.
        """
        projected = F.linear(x, self.weight, bias=self.bias)
        if self.residual:
            return projected + x
        return projected


class Arm(nn.Module):
    """Instantiate an autoregressive probability module, modelling the
    conditional distribution :math:`p_{\\psi}(\\hat{y}_i \\mid
    \\mathbf{c}_i)` of a (quantized) latent pixel :math:`\\hat{y}_i`,
    conditioned on neighboring already decoded context pixels
    :math:`\\mathbf{c}_i \\in \\mathbb{Z}^C`, where :math:`C` denotes the
    number of context pixels.

    The distribution :math:`p_{\\psi}` is assumed to follow a Laplace
    distribution, parameterized by an expectation :math:`\\mu` and a scale
    :math:`b`, where the scale and the variance :math:`\\sigma^2` are
    related as follows :math:`\\sigma^2 = 2 b ^2`.

    The parameters of the Laplace distribution for a given latent pixel
    :math:`\\hat{y}_i` are obtained by passing its context pixels
    :math:`\\mathbf{c}_i` through an MLP :math:`f_{\\psi}`:

    .. math::

        p_{\\psi}(\\hat{y}_i \\mid \\mathbf{c}_i) \\sim \\mathcal{L}(\\mu_i,
        b_i), \\text{ where } \\mu_i, b_i = f_{\\psi}(\\mathbf{c}_i).

    .. attention::

        The MLP :math:`f_{\\psi}` is built as follows:

        * An input layer maps the :math:`C` context pixels to
          ``hidden_layer_dim`` features (fixed to 8 here). It is neither
          residual nor followed by a non-linearity;

        * Each of the ``n_hidden_layers_arm`` hidden layers is a residual
          layer of width ``hidden_layer_dim``, followed by a ``ReLU``
          non-linearity;

        * The output layer maps ``hidden_layer_dim`` features to 2 values;

        * :math:`C` must be a multiple of 8.

    The MLP :math:`f_{\\psi}` is made of custom Linear layers instantiated
    from the ``ArmLinear`` class.
    """

    def __init__(self, dim_arm: int, n_hidden_layers_arm: int):
        """
        Args:
            dim_arm: Number of context pixels :math:`C`, i.e. the input
                width of the MLP.
            n_hidden_layers_arm: Number of hidden layers. Set it to 0 for
                a linear ARM.
        """
        super().__init__()

        assert dim_arm % 8 == 0, (
            f"ARM context size and hidden layer dimension must be "
            f"a multiple of 8. Found {dim_arm}."
        )
        self.dim_arm = dim_arm
        # Width shared by every layer between the input and output layers.
        self.hidden_layer_dim = 8

        # ======================== Construct the MLP ======================== #
        # Input layer: C -> hidden_layer_dim (no residual, no activation).
        layers = [ArmLinear(dim_arm, self.hidden_layer_dim, residual=False)]

        # Hidden layer(s): residual linear layer followed by a ReLU.
        for _ in range(n_hidden_layers_arm):
            layers.append(
                ArmLinear(self.hidden_layer_dim, self.hidden_layer_dim, residual=True)
            )
            layers.append(nn.ReLU())

        # Output layer. It always has 2 outputs (mu and log scale).
        layers.append(ArmLinear(self.hidden_layer_dim, 2, residual=False))
        self.mlp = nn.Sequential(*layers)
        # ======================== Construct the MLP ======================== #

        self.non_zero_pixel_ctx_index = _get_non_zero_pixel_ctx_index(self.dim_arm)
        # Map each context index to a [row, col] pair computed as
        # [index // 9 - 4, index % 9 - 4].
        # NOTE(review): presumably the indices address a row-major 9x9 window
        # whose centre (4, 4) is the pixel being modelled -- confirm against
        # _get_non_zero_pixel_ctx_index.
        self.non_zero_pixel_ctx_shifts = {
            index: [index // 9 - 4, index % 9 - 4]
            for index in self.non_zero_pixel_ctx_index.tolist()
        }

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """Perform the auto-regressive module (ARM) forward pass. The ARM takes
        as input a tensor of shape :math:`[B, C]` i.e. :math:`B` contexts with
        :math:`C` context pixels. The MLP outputs :math:`[B, 2]` values, from
        which :math:`\\mu` and :math:`b` are derived for each of the :math:`B`
        input pixels.

        .. warning::

            Note that the ARM expects input to be flattened i.e. spatial
            dimensions :math:`H, W` are collapsed into a single batch-like
            dimension :math:`B = HW`, leading to an input of shape
            :math:`[B, C]`, gathering the :math:`C` contexts for each of the
            :math:`B` pixels to model.

        .. note::

            The ARM MLP does not output directly the scale :math:`b`. Denoting
            :math:`s` the raw output of the MLP, the scale is obtained as
            follows:

            .. math::

                b = \\exp(\\operatorname{clamp}(s - 4, -4.6, 5.0))

        .. note::

            This method currently prints up to the first 10 values of
            :math:`\\mu` to stdout on every call.

        Args:
            x: Concatenation of all input contexts
                :math:`\\mathbf{c}_i`. Tensor of shape :math:`[B, C]`.

        Returns:
            A tuple ``(mu, scale, log_scale)``. ``mu`` and ``scale`` are the
            Laplace distribution parameters :math:`\\mu, b`; ``log_scale`` is
            the raw MLP output :math:`s` described above. Each is a tensor of
            shape :math:`[B]`.
        """
        raw_proba_param = self.mlp(x)
        mu = raw_proba_param[:, 0]
        log_scale = raw_proba_param[:, 1]

        # Debug trace of the leading mu values at full printed precision.
        # NOTE(review): looks like leftover instrumentation -- it writes to
        # stdout on every call; remove or gate it once no longer needed.
        print("ARM mu: ")
        for mu_value in mu[:10]:
            print(f"{mu_value.item():.16f}")

        # no scale smaller than exp(-4.6) ~= 1e-2 or bigger than exp(5.0) ~= 148
        scale = torch.exp(torch.clamp(log_scale - 4, min=-4.6, max=5.0))

        return mu, scale, log_scale

In [32]:
# Evaluate one context twice -- once as a row of the full batch, once on its
# own as a batch of size 1 -- and print both mu values at full precision so
# they can be compared.
with torch.no_grad():
    incpecting_index = 1
    single_context = latent_context_flat[incpecting_index : incpecting_index + 1]

    mu_batch, scale_batch, log_scale_batch = arm_model.forward(latent_context_flat)
    mu_singl, scale_singl, log_scale_singl = arm_model.forward(single_context)

    print(mu_batch[incpecting_index].item())
    print(mu_singl.item())

ARM mu: 
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
-0.0087932804599404
ARM mu: 
-0.0087932804599404
-0.008793280459940434
-0.008793280459940434
