In [1]:
import os
import sys

if os.path.basename(os.getcwd()) == "testing":
    os.chdir(os.path.dirname(os.getcwd()))
    sys.path.append(os.getcwd())

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

os.environ["OMP_NUM_THREADS"]="1"
os.environ["MKL_NUM_THREADS"]="1"
torch.set_num_threads(1)
torch.use_deterministic_algorithms(True)
torch.set_printoptions(precision=16, profile="full")
np.set_printoptions(precision=16, suppress=False)

DEVICE = torch.device("cpu")

def float_to_full_decimal(x):
    """Return full decimal precision for numpy or python float."""
    if isinstance(x, np.floating):
        dt = x.dtype
        if dt == np.float32:
            return format(x, '.9f')
        elif dt == np.float64:
            return format(x, '.17f')
        elif dt == np.float16:
            # float16 has ~4 decimal digits
            return format(float(x), '.7f')
        elif dt == np.float128:
            return format(x, '.36f')  # depending on platform
        else:
            return repr(x)

    if isinstance(x, float):  # python float (float64)
        return format(x, '.17f')

    return repr(x)

def float_to_bits(x):
    """Return IEEE-754 bitstring for numpy or python float."""
    if isinstance(x, np.floating):
        dt = x.dtype
        # Use numpy's view machinery
        return np.binary_repr(x.view(np.uint64 if dt == np.float64 else np.uint32),
                              width=64 if dt == np.float64 else 32)

    if isinstance(x, float):
        # Python float = IEEE754 float64
        return np.binary_repr(np.float64(x).view(np.uint64), width=64)

    raise TypeError(f"Unsupported type: {type(x)}")

def str_full_precision_tensor(a):
    """
    Print every element of a torch.Tensor or numpy.ndarray with:
      • full decimal precision
      • full binary representation (IEEE754 bits)
    """

    if isinstance(a, torch.Tensor):
        arr = a.detach().cpu().numpy()
    elif isinstance(a, np.ndarray):
        arr = a
    else:
        raise TypeError("Input must be torch.Tensor or np.ndarray")

    all_indices = list(np.ndindex(arr.shape))
    index_strs = [str(idx) for idx in all_indices]
    max_idx_len = max(len(s) for s in index_strs)

    # print("Shape:", arr.shape)
    # print("Dtype:", arr.dtype)
    # print()
    res_list = []
    for idx, x in np.ndenumerate(arr):
        idx_str = str(idx).ljust(max_idx_len)
        dec = float_to_full_decimal(x)
        bits = float_to_bits(x)
        res_list.append(f"{idx_str}: value={dec}, bits={bits}")
    return "\n".join(res_list)

In [4]:
class ArmLinear(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        residual: bool = False,
    ):
        super().__init__()

        self.residual = residual
        self.in_channels = in_channels
        self.out_channels = out_channels

        # -------- Instantiate empty parameters
        self.weight = nn.Parameter(torch.empty(out_channels, in_channels), requires_grad=True)
        self.bias = nn.Parameter(torch.empty((out_channels)), requires_grad=True)

    def forward(self, x: torch.Tensor, inspecting_index: int) -> torch.Tensor:
        if x.size(0) > 1:
            print(f"IN:  {str_full_precision_tensor(x[inspecting_index:inspecting_index+1])}")
        else:
            print(f"IN:  {str_full_precision_tensor(x)}")

        lin = x @ self.weight.t() + self.bias

        if self.residual:
            res = lin + x
        else:
            res = lin
        if x.size(0) > 1:
            print(f"OUT: {str_full_precision_tensor(res[inspecting_index:inspecting_index+1])}")
        else:
            print(f"OUT: {str_full_precision_tensor(res)}")
        return res


class SequentialLike(nn.Module):
    def __init__(self, *layers: nn.Module):
        super().__init__()

        for i, layer in enumerate(layers):
            self.add_module(str(i), layer)

    def forward(self, x: torch.Tensor, inspecting_index: int) -> torch.Tensor:
        for i, layer in enumerate(self.children()):  # iterates in registration order
            if i <= 1:
                if isinstance(layer, ArmLinear):
                    x = layer(x, inspecting_index)
                else:
                    x = layer(x)
        return x


class Arm(nn.Module):
    def __init__(self, dim_arm: int, n_hidden_layers_arm: int):
        super().__init__()

        self.dim_arm = dim_arm
        self.hidden_layer_dim = 8

        # ======================== Construct the MLP ======================== #
        layers_list = nn.ModuleList()
        layers_list.append(ArmLinear(dim_arm, self.hidden_layer_dim, residual=False))

        # Construct the hidden layer(s)
        for _ in range(n_hidden_layers_arm):
            layers_list.append(
                ArmLinear(self.hidden_layer_dim, self.hidden_layer_dim, residual=True)
            )
            layers_list.append(nn.ReLU())

        # Construct the output layer. It always has 2 outputs (mu and scale)
        layers_list.append(ArmLinear(self.hidden_layer_dim, 2, residual=False))
        self.mlp = SequentialLike(*layers_list)
        # ======================== Construct the MLP ======================== #

    def forward(
        self, x: torch.Tensor, inspecting_index: int
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Inputs:
            x - tensor of shape (B, C), where C = dim_arm
        """
        raw_proba_param = self.mlp(x, inspecting_index)
        mu = raw_proba_param[:, 0]
        log_scale = raw_proba_param[:, 1]

        # no scale smaller than exp(-4.6) = 1e-2 or bigger than exp(5.01) = 150
        scale = torch.exp(torch.clamp(log_scale - 4, min=-4.6, max=5.0))

        return mu, scale, log_scale

In [5]:
arm_model = Arm(dim_arm=16, n_hidden_layers_arm=2).to(DEVICE)
# load pretrained weights
arm_model.load_state_dict(
    torch.load("../logs/full_runs/trained_models/coolchic_arm.pth", map_location=DEVICE)
)
arm_model.eval()

latents_dict = torch.load(
    "../logs/full_runs/trained_models/coolchic_latents_snapshot.pt", map_location=DEVICE
)
flat_latent = latents_dict["flat_latent"].to(DEVICE)
latent_context_flat = latents_dict["latent_context_flat"].to(DEVICE)
# latents = torch.zeros((1,1,10,10), dtype=torch.float32).to(DEVICE)
# latent_context_flat = torch.concat([arm_model.get_neighbor_context(latents[0,0].tolist(), i, j) for j in range(latents.shape[3]) for i in range(latents.shape[2])])[:10]

with torch.no_grad():
    incpecting_index = 1
    print("="*80)
    mu_batch, scale_batch, log_scale_batch = arm_model.forward(latent_context_flat, incpecting_index)
    print("="*80)
    mu_singl, scale_singl, log_scale_singl = arm_model.forward(
        latent_context_flat[incpecting_index : incpecting_index + 1], incpecting_index
    )
    print("="*80)
    print(str_full_precision_tensor(mu_batch[incpecting_index : incpecting_index + 1]))
    print(str_full_precision_tensor(mu_singl))

IN:  (0, 0) : value=0.000000000, bits=00000000000000000000000000000000
(0, 1) : value=0.000000000, bits=00000000000000000000000000000000
(0, 2) : value=0.000000000, bits=00000000000000000000000000000000
(0, 3) : value=0.000000000, bits=00000000000000000000000000000000
(0, 4) : value=0.000000000, bits=00000000000000000000000000000000
(0, 5) : value=0.000000000, bits=00000000000000000000000000000000
(0, 6) : value=0.000000000, bits=00000000000000000000000000000000
(0, 7) : value=0.000000000, bits=00000000000000000000000000000000
(0, 8) : value=0.000000000, bits=00000000000000000000000000000000
(0, 9) : value=0.000000000, bits=00000000000000000000000000000000
(0, 10): value=0.000000000, bits=00000000000000000000000000000000
(0, 11): value=0.000000000, bits=00000000000000000000000000000000
(0, 12): value=0.000000000, bits=00000000000000000000000000000000
(0, 13): value=0.000000000, bits=00000000000000000000000000000000
(0, 14): value=0.000000000, bits=00000000000000000000000000000000
(0, 1