In [1]:
import math
import torch
from torch import nn

In [2]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

%load_ext autoreload
%autoreload 2

In [3]:
import torch
import numpy as np

def positional_encoding(x, y, num_pos_feats=128, temperature=10000):
    """
    Build interleaved sine/cosine positional embeddings from coordinate grids.

    The coordinates are expected to be normalized to [0, 1] so that the
    encoding describes relative position independently of the input image's
    size.

    Args:
        x: Normalized x-coordinates, a 3-D tensor (the final permute requires
            exactly three input dimensions, e.g. batch x height x width).
        y: Normalized y-coordinates, same shape as ``x``.
        num_pos_feats: Number of features produced per axis (for even values).
        temperature: Base of the geometric wavelength progression.

    Returns:
        Tensor with the encoding on dim 1. The y-features come first,
        followed by the x-features, so for even ``num_pos_feats`` there are
        ``2 * num_pos_feats`` channels.
    """
    two_pi = 2 * np.pi

    # One wavelength per sin/cos pair: temperature ** (k / num_pos_feats) for
    # k = 0, 2, 4, ... (< num_pos_feats). The step of 2 means each sine and
    # its paired cosine share the same wavelength; values grow geometrically
    # from 1 (shortest wavelength) towards `temperature` (longest).
    exponents = torch.arange(0, num_pos_feats, 2, dtype=torch.float32, device=x.device)
    wavelengths = temperature ** (exponents / num_pos_feats)

    def _encode_axis(coord):
        # Scale [0, 1] coordinates to phases in [0, 2*pi], then broadcast a
        # trailing axis against every wavelength: result[..., i] is the phase
        # of that position divided by the i-th wavelength.
        phase = (coord * two_pi)[..., None] / wavelengths
        # stack on a new last axis + flatten of the last two axes interleaves
        # the values as [sin_0, cos_0, sin_1, cos_1, ...].
        return torch.stack((phase.sin(), phase.cos()), dim=-1).flatten(-2)

    encoded_y = _encode_axis(y)
    encoded_x = _encode_axis(x)

    # Concatenate y-features before x-features, then move the feature axis
    # in front of the two spatial axes.
    return torch.cat((encoded_y, encoded_x), dim=-1).permute(0, 3, 1, 2)

def batch_positional_encoding(batch_shape, heights, widths, num_pos_feats=128, temperature=10000, device=None):
    """
    Compute positional embeddings for a zero-padded batch of variable-size samples.

    For every sample a normalized coordinate grid is built with
    ``torch.linspace(0, 1, size)`` along each axis and written into the
    top-left ``height x width`` corner of a zero-initialised
    ``batch_shape`` tensor. Positions outside that corner keep the
    coordinate 0 and are therefore encoded exactly like coordinate 0.

    Args:
        batch_shape: ``(B, H, W)`` shape of the padded batch.
        heights: Valid height of each sample; must contain ``B`` entries.
        widths: Valid width of each sample; must contain ``B`` entries.
        num_pos_feats: Forwarded to ``positional_encoding``.
        temperature: Forwarded to ``positional_encoding``.
        device: Optional device on which the grids (and hence the result)
            are created; ``None`` keeps torch's default device.

    Returns:
        The tensor returned by ``positional_encoding`` for the padded grids.

    Raises:
        ValueError: If ``heights`` or ``widths`` does not have one entry per
            sample. Without this check ``zip`` would silently stop early and
            leave the remaining samples as all-zero grids.
    """
    heights = list(heights)
    widths = list(widths)
    batch_size = batch_shape[0]
    if len(heights) != batch_size or len(widths) != batch_size:
        raise ValueError(
            f"Expected {batch_size} heights and widths, "
            f"got {len(heights)} heights and {len(widths)} widths"
        )

    batch_grid_x = torch.zeros(batch_shape, device=device)
    batch_grid_y = torch.zeros(batch_shape, device=device)
    for batch_i, (height, width) in enumerate(zip(heights, widths)):
        # Per-sample normalization: both axes span [0, 1] regardless of size.
        x_axis = torch.linspace(0, 1, width, device=device)
        y_axis = torch.linspace(0, 1, height, device=device)
        # indexing="ij" yields (height, width) grids: rows follow y, columns follow x.
        grid_y, grid_x = torch.meshgrid(y_axis, x_axis, indexing="ij")
        batch_grid_x[batch_i, :height, :width] = grid_x
        batch_grid_y[batch_i, :height, :width] = grid_y

    return positional_encoding(batch_grid_x, batch_grid_y, num_pos_feats, temperature)

# # Example usage:
# batch_size, height, width = 2, 10, 10
# x = torch.linspace(0, 1, width).unsqueeze(0).unsqueeze(0).repeat(batch_size, height, 1)
# y = torch.linspace(0, 1, height).unsqueeze(0).unsqueeze(2).repeat(batch_size, 1, width)
# pos_emb = positional_encoding(x, y, num_pos_feats=128)
# print(f"Positional embedding shape: {pos_emb.shape}")


# Demo: a padded batch of three samples on a 224 x 250 canvas.
batch_shape = (3, 224, 250) # B, H, W
B,H,W = batch_shape
# Valid (un-padded) height and width of each sample in the batch.
heights = [168, 224, 200]
widths = [168, 200, 250]
# 128 features per axis; the y and x encodings are concatenated on the channel axis.
pos_emb = batch_positional_encoding(batch_shape, heights, widths, 128)
print(f"Positional embedding shape: {pos_emb.shape}")

Positional embedding shape: torch.Size([3, 256, 224, 250])


In [4]:
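# TODO: the x and y encodings each depend on a single coordinate, so the scaling and
# sin/cos could be computed on the 1D x_axis / y_axis tensors and then broadcast to
# the 2D grid, instead of operating on full meshgrids.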

In [5]:
# Padded batch layout used by the following cells: three samples on a 224 x 256 canvas.
batch_shape = (3, 224, 256) # B, H, W
B,H,W = batch_shape
# Valid (un-padded) height and width of each sample.
heights = [168, 224, 200]
widths = [168, 200, 256]

In [6]:
# x_axis = torch.arange(10)
# y_axis = torch.arange(10)

# Zero-initialised coordinate grids for the whole padded batch.
batch_grid_x = torch.zeros(batch_shape)
batch_grid_y = torch.zeros(batch_shape)
for batch_i, (height, width) in enumerate(zip(heights, widths)):

    # Normalized coordinates in [0, 1] along each axis of this sample.
    x_axis = torch.linspace(0, 1, width)
    y_axis = torch.linspace(0, 1, height)
    # indexing="ij" gives (height, width) grids: rows follow y, columns follow x.
    grid_y, grid_x = torch.meshgrid(y_axis, x_axis, indexing="ij")
    # Only the valid top-left region is written; the rest stays 0.
    batch_grid_x[batch_i,:height,:width] = grid_x
    batch_grid_y[batch_i,:height,:width] = grid_y

In [7]:
# Inspect the padded x grids: entries outside each sample's valid height/width keep their initial 0.
batch_grid_x

tensor([[[0.0000, 0.0060, 0.0120,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0060, 0.0120,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0060, 0.0120,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0050, 0.0101,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0050, 0.0101,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0050, 0.0101,  ..., 0.0000, 0.0000, 0.0000],
         ...,
         [0.0000, 0.0050, 0.0101,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0050, 0.0101,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0050, 0.0101,  ..., 0.0000, 0.0000, 0.0000]],

        [[0.0000, 0.0039, 0.0078,  ..., 0.9922, 0.9961, 1.0000],
         [0.0000, 0.0039, 0.0078,  ..., 0.9922, 0.9961, 1.0000],
         [0.0000, 0.0039, 0.0078,  ..., 0.9922, 0.9961, 1.

In [33]:
# x_sin = torch.arange(4*3).reshape((1,2,2,3))
# x_cos = torch.arange(4*3,2*4*3).reshape((1,2,2,3))
# x_pos1 = torch.cat([x_sin, x_cos], dim=3)
# x_pos2 = torch.stack([x_sin, x_cos], dim=4).flatten(3)
# x_pos3 = torch.stack([x_sin, x_cos],dim=-1).flatten(-2)

In [55]:
from detr.position_encoding import MyPositionalEncoding, PositionEmbeddingSine

In [56]:
# Instantiate both encoders with the same num_pos_feats so their outputs can be compared.
# NOTE(review): both classes live in detr.position_encoding, which is not shown here;
# confirm there what eps=0 does when normalize=True.
officialPosEncoding = PositionEmbeddingSine(num_pos_feats=128, normalize=True, eps=0)
myPosEncoding = MyPositionalEncoding(num_pos_feats=128)

In [57]:
# Padded batch layout for the comparison: three samples on a 224 x 256 canvas.
batch_shape = (3, 224, 256) # B, H, W
B,H,W = batch_shape
# Valid (un-padded) height and width of each sample.
heights = [168, 224, 200]
widths = [168, 200, 256]

In [58]:
# Boolean mask over the padded batch: starts all True, then is cleared to False
# inside each sample's valid height x width region, so True is left only on padding.
mask = torch.ones((batch_shape), dtype=torch.bool)
for batch_index, (height, width) in enumerate(zip(heights, widths)):
    mask[batch_index, :height, :width] = 0

In [59]:
# Run the PositionEmbeddingSine instance on the padding mask.
official_pos_embed = officialPosEncoding(mask)

In [60]:
# Run the MyPositionalEncoding instance on the batch shape and per-sample valid sizes.
my_pos_embed = myPosEncoding(batch_shape, heights, widths)

In [62]:
# Compare the two embeddings using torch.allclose's default tolerances.
# NOTE(review): the recorded result is False and the cause is not visible in this
# notebook; check both tensors for NaNs and look at the max absolute difference.
torch.allclose(official_pos_embed, my_pos_embed)

False