In [2]:
# Context mask used when printing/visualising the context of a pixel.
# Each entry is an offset [row_offset, col_offset] (i.e. H, W order) relative
# to the pixel of interest; visualize_context adds it to the chosen point.
# The literal is laid out so that every offset sits roughly where the
# corresponding neighbour lies around the current pixel: three pixels straight
# above, the three pixels of the row directly above, and three pixels to the
# left on the current row. These are the same 8 positions as the flattened
# indices returned by _get_non_zero_pixel_ctx_index(8) for a 9 x 9 window.
print_mask = [
                                [-3, 0],
                                [-2, 0],
                      [-1, -1], [-1, 0], [-1, 1],
    [0, -3], [0, -2], [0,  -1]
]

In [6]:
import torch
import torch.nn.functional as F
import numpy as np

def _get_non_zero_pixel_ctx_index(dim_arm: int) -> torch.Tensor:
    """Generate the relative index of the context pixel with respect to the
    actual pixel being decoded.

    1D tensor containing the indices of the non zero context. This corresponds to the one
    in the pattern above. This allows to use the index_select function, which is significantly
    faster than usual indexing.

    0   1   2   3   4   5   6   7   8
    9   10  11  12  13  14  15  16  17
    18  19  20  21  22  23  24  25  26
    27  28  29  30  31  32  33  34  35
    36  37  38  39  *   x   x   x   x
    x   x   x   x   x   x   x   x   x
    x   x   x   x   x   x   x   x   x
    x   x   x   x   x   x   x   x   x
    x   x   x   x   x   x   x   x   x


    Args:
        dim_arm (int): Number of context pixels

    Returns:
        Tensor: 1D tensor with the flattened index of the context pixels.
    """
    # fmt: off
    if dim_arm == 8:
        return torch.tensor(
            [            13,
                         22,
                     30, 31, 32,
             37, 38, 39, #
            ]
        )

    raise ValueError(f"dim_arm {dim_arm} not supported.")


def _get_neighbor(
    x: torch.Tensor, mask_size: int, non_zero_pixel_ctx_idx: torch.Tensor
) -> torch.Tensor:
    """Use the unfold function to extract the neighbors of each pixel in x.

    Args:
        x (Tensor): [1, 1, H, W] feature map from which we wish to extract the
            neighbors
        mask_size (int): Virtual size of the kernel around the current coded latent.
            mask_size = 2 * n_ctx_rowcol - 1
        non_zero_pixel_ctx_idx (Tensor): [N] 1D tensor containing the indices
            of the non zero context pixels (i.e. floor(N ** 2 / 2) - 1).
            It looks like: [0, 1, ..., floor(N ** 2 / 2) - 1].
            This allows to use the index_select function, which is significantly
            faster than usual indexing.

    Returns:
        torch.tensor: [H * W, floor(N ** 2 / 2) - 1] the spatial neighbors
            the floor(N ** 2 / 2) - 1 neighbors of each H * W pixels.
    """
    pad = int((mask_size - 1) / 2)
    x_pad = F.pad(x, (pad, pad, pad, pad), mode="constant", value=0.0)

    # Shape of x_unfold is [B, C, H, W, mask_size, mask_size] --> [B * C * H * W, mask_size * mask_size]
    # reshape is faster than einops.rearrange
    x_unfold = (
        x_pad.unfold(2, mask_size, step=1)
        .unfold(3, mask_size, step=1)
        .reshape(-1, mask_size * mask_size)
    )

    # Convert x_unfold to a 2D tensor: [Number of pixels, all neighbors]
    # This is slower than reshape above
    # x_unfold = rearrange(
    #     x_unfold,
    #     'b c h w mask_h mask_w -> (b c h w) (mask_h mask_w)'
    # )

    # Select the pixels for which the mask is not zero
    # For a N x N mask, select only the first (N x N - 1) / 2 pixels
    # (those which aren't null)
    neighbor = torch.index_select(x_unfold, dim=1, index=non_zero_pixel_ctx_idx)
    return neighbor


def prepare_inputs(image: torch.Tensor, raw_synth_out: torch.Tensor):
    """Build, for each image channel, the per-pixel input rows.

    Each returned tensor has one row per pixel (row-major order) made of three
    column groups, concatenated in this order:

    1. the spatial context: 8 neighbors taken from a 9 x 9 window, for each of
       the 3 image channels (the channel values of one neighbor are adjacent),
    2. the synthesis output at that pixel (2 + 3 + 4 = 9 values),
    3. the values, at that pixel, of the channels preceding the current one
       (nothing for channel 0).

    Args:
        image (Tensor): [1, 3, H, W] image.
        raw_synth_out (Tensor): [1, 9, H, W] synthesis output.

    Returns:
        list: One tensor per channel; the tensor for channel ``k`` has shape
            [H * W, 24 + 9 + k].
    """
    synthesis_out_params_per_channel = [2, 3, 4]
    mask_size = 9
    context_size = 8
    ctx_positions = _get_non_zero_pixel_ctx_index(context_size)

    n_channels = len(synthesis_out_params_per_channel)
    n_pixels = image.shape[2] * image.shape[3]

    assert n_channels == image.shape[1], (
        "Number of channels in image and synthesis_out_params_per_channel "
        "must be equal."
    )

    # _get_neighbor only supports [1, 1, H, W] inputs, hence one call per
    # channel, each producing a [H * W, context_size] tensor.
    per_channel_ctx = [
        _get_neighbor(image[:, ch : ch + 1, :, :], mask_size, ctx_positions)
        for ch in range(n_channels)
    ]
    # [H * W, context_size, num_channels] -> [H * W, context_size * num_channels]
    flat_image_context = torch.stack(per_channel_ctx, dim=2).reshape((n_pixels, -1))

    # The synthesis output is [1, C, H, W]; flatten it once to [H * W, C], it
    # is the same for every channel.
    flat_synth_out = raw_synth_out.permute(0, 2, 3, 1).reshape(
        -1, sum(synthesis_out_params_per_channel)
    )

    prepared_inputs = []
    for ch in range(n_channels):
        if ch > 0:
            # Channels before `ch`, flattened to [H * W, ch].
            previous_channels = image[:, :ch].permute(0, 2, 3, 1).reshape(-1, ch)
        else:
            # Nothing precedes the first channel: use a zero-column placeholder.
            previous_channels = torch.empty(
                n_pixels, 0, dtype=image.dtype, device=image.device
            )
        prepared_inputs.append(
            torch.cat([flat_image_context, flat_synth_out, previous_channels], dim=1)
        )
    return prepared_inputs

In [7]:
# Build small inputs whose values are easy to trace by eye:
#   * image[0, c, h, w] = h * W * C + w * C + c, so the C channel values of one
#     pixel are consecutive integers;
#   * raw_synth_out holds the integers that follow (starting at CHW), with the
#     R values of one pixel being consecutive as well.
C, H, W = 3, 4, 5
R = 9
CHW = C * H * W

# Counting in (H, W, C) order and then moving C to the front gives exactly the
# per-pixel consecutive layout described above.
image_np = np.ascontiguousarray(
    np.arange(CHW, dtype=int).reshape((H, W, C)).transpose(2, 0, 1)[np.newaxis]
)
synth_np = (
    np.arange(R * H * W, dtype=int).reshape((1, H, W, R)).transpose(0, 3, 1, 2) + CHW
)
image = torch.tensor(image_np)
raw_synth_out = torch.tensor(synth_np)

# For debugging, print `image` and `raw_synth_out` here.

prepared_inputs = prepare_inputs(image, raw_synth_out)
# One prepared tensor per channel; print r.shape / g.shape / b.shape to inspect.
r, g, b = prepared_inputs

In [8]:
# Guard against running this cell while `image` is still a NumPy array, i.e.
# before the input-building cell has converted it to a tensor.
assert torch.is_tensor(image)

def f_num(num, width=4):
    """Return ``num`` truncated to an integer, right-aligned in ``width`` characters.

    Numbers whose text is longer than ``width`` are returned unpadded.
    """
    return f"{int(num)}".rjust(width)


def group_triplets(a: list):
    """Merge consecutive triplets of strings so pixel values read as one token.

    Each group of three items is concatenated and its spaces are replaced by
    underscores. If the number of items is not a multiple of three, ``a`` is
    returned unchanged.
    """
    n_items = len(a)
    if n_items % 3:
        return a
    return [
        "".join(a[start : start + 3]).replace(" ", "_")
        for start in range(0, n_items, 3)
    ]


def visualize_context(image: torch.Tensor, mask: list, point: list = [3, 3]):
    """Print each channel of ``image`` as a grid with the context pixels marked.

    Every ``[row_offset, col_offset]`` pair of ``mask`` is added to ``point``;
    positions that land inside the image are shown as ``x`` in all channels,
    positions outside the image are ignored. All cells, including the marker,
    are right-aligned to a common width so that the columns line up.

    Args:
        image (Tensor): [1, C, H, W] image; only batch element 0 is printed.
        mask (list): ``[row_offset, col_offset]`` pairs (H, W order) relative
            to ``point``.
        point (list): ``[row, col]`` of the pixel whose context is shown. The
            default list is only read, never mutated.
    """
    _, _, H, W = image.shape
    marker = "x"

    # cells[c][h][w] is the formatted value of image[0, c, h, w].
    cells = [
        [[f_num(value) for value in row] for row in plane]
        for plane in image[0].tolist()
    ]

    # Mask offsets are given in (H, W) order, so the first is a row offset.
    for row_offset, col_offset in mask:
        row, col = point[0] + row_offset, point[1] + col_offset
        if 0 <= row < H and 0 <= col < W:
            for plane in cells:
                plane[row][col] = marker

    # Use one width for every cell: the marker used to be narrower than the
    # numbers, which shifted the columns of every row containing a marker.
    cell_width = max(
        (len(cell) for plane in cells for row in plane for cell in row),
        default=len(marker),
    )

    for c, plane in enumerate(cells):
        print(f"Channel {c}:")
        for row in plane:
            print(" ".join(cell.rjust(cell_width) for cell in row))
        print()

visualize_context(image, print_mask, point=[3, 4])

# Channel whose prepared input is inspected: 0 = red, 1 = green, 2 = blue.
chosen_channel = 2

# Column ranges of the three parts of a prepared input row:
#   0 = context pixels (8 neighbors x 3 channels = 24 columns)
#   1 = raw_synth_out (9 columns)
#   2 = intra pixel context (red for green; red, green for blue). The upper
#       bound is deliberately generous: slicing stops at the last column.
check_bounds = [[0, 24], [24, 24 + 9], [24 + 9, 24 + 9 + 10]]
check_part_index = 2

col_start, col_stop = check_bounds[check_part_index]
working_with = prepared_inputs[chosen_channel][:, col_start:col_stop]

for ind in range(H * W):
    # To look at a single pixel only, skip the others here,
    # e.g. `if ind != 18: continue`.
    row, col = divmod(ind, W)
    tmp = group_triplets(
        [f_num(value) for value in working_with[ind].detach().cpu().numpy()]
    )
    img_vals_at_pixel = [f_num(value) for value in image[0, :, row, col].tolist()]
    print(f"ind: {str(ind).rjust(2)}, [", *img_vals_at_pixel, "], ", *tmp)

Channel 0:
   0    3    6    9   x
  15   18   21   24   x
  30   33   36   x   x
  45   x   x   x   57

Channel 1:
   1    4    7   10   x
  16   19   22   25   x
  31   34   37   x   x
  46   x   x   x   58

Channel 2:
   2    5    8   11   x
  17   20   23   26   x
  32   35   38   x   x
  47   x   x   x   59

ind:  0, [    0    1    2 ],     0    1
ind:  1, [    3    4    5 ],     3    4
ind:  2, [    6    7    8 ],     6    7
ind:  3, [    9   10   11 ],     9   10
ind:  4, [   12   13   14 ],    12   13
ind:  5, [   15   16   17 ],    15   16
ind:  6, [   18   19   20 ],    18   19
ind:  7, [   21   22   23 ],    21   22
ind:  8, [   24   25   26 ],    24   25
ind:  9, [   27   28   29 ],    27   28
ind: 10, [   30   31   32 ],    30   31
ind: 11, [   33   34   35 ],    33   34
ind: 12, [   36   37   38 ],    36   37
ind: 13, [   39   40   41 ],    39   40
ind: 14, [   42   43   44 ],    42   43
ind: 15, [   45   46   47 ],    45   46
ind: 16, [   48   49   50 ],    48   49
ind: 

In [5]:
# This cell needs `image` to already be a tensor.
assert isinstance(image, torch.Tensor)

# Recompute, step by step, the spatial-context part of the prepared inputs
# (8 context pixels taken from a 9 x 9 window, for each of the 3 channels).
non_zero_image_arm_ctx_index = _get_non_zero_pixel_ctx_index(8)
contexts = []
for channel_idx in range(3):
    single_channel = image[:, channel_idx : channel_idx + 1, :, :]
    contexts.append(_get_neighbor(single_channel, 9, non_zero_image_arm_ctx_index))

# Merge the num_channels [H * W, context_size] tensors into a single
# [H * W, context_size * num_channels] tensor. Print `contexts` or
# `contexts[0].shape` here to inspect the per-channel tensors.
stacked_contexts = torch.stack(contexts, dim=2)
flat_image_context = stacked_contexts.reshape((H * W, -1))
print(flat_image_context)
print(flat_image_context.shape)

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  0,  0,  0,  1,  2],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
          0,  1,  2,  3,  4,  5],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,
          3,  4,  5,  6,  7,  8],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  3,  4,  5,
          6,  7,  8,  9, 10, 11],
        [ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5,  0,  0,  0,
          0,  0,  0,  0,  0,  0],
        [ 0,  0,  0,  0,  0,  0,  0,  1,  2,  3,  4,  5,  6,  7,  8,  0,  0,  0,
          0,  0,  0, 15, 16, 17],
        [ 0,  0,  0,  0,  0,  0,  3,  4,  5,  6,  7,  8,  9, 10, 11,  0,  0,  0,
         15, 16, 17, 18, 19, 20],
        [ 0,  0,  0,  0,  0,  0,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,