In [1]:
import torch
import numpy as np
from typing import Any, List, Sequence, Optional, Tuple, Iterator, Dict, Callable, Union, Type

import jax
from jax import jit
import jax.numpy as jnp

In [2]:
def pos_encoding(t, channels):
    t = t.unsqueeze(-1).type(torch.float)
    
    inv_freq = 1.0 / (
        10000
        ** (torch.arange(0, channels, 2).float() / channels)
    )
    pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
    pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
    pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=1)
    return pos_enc
        

In [3]:
t = torch.Tensor([100, 200, 300, 400])
print(t.shape)
enc = pos_encoding(t, channels=256)
print(enc.shape)

torch.Size([4])
torch.Size([4, 256])


In [4]:
print(t)
print(enc)

tensor([100., 200., 300., 400.])
tensor([[-0.5064, -0.9286, -0.9795,  ...,  0.9999,  0.9999,  0.9999],
        [-0.8733, -0.6892, -0.3943,  ...,  0.9997,  0.9997,  0.9998],
        [-0.9998,  0.4170,  0.8208,  ...,  0.9993,  0.9994,  0.9995],
        [-0.8509,  0.9988,  0.7247,  ...,  0.9988,  0.9989,  0.9991]])


$$\text{P}(k, 2i) = \text{sin}\Big(\frac{k}{n^{2i/d}} \Big)$$
$$\text{P}(k, 2i+1) = \text{cos}\Big(\frac{k}{n^{2i/d}} \Big)$$

See [this]() article abut p

In [5]:
def getPositionEncoding(t: List[int], d: int, n: int = 10000) -> np.ndarray:
    """
    Generate the positional encoding for given positions 't' and dimension 'd'.
    The encoding has sine values at even indices and cosine values at odd indices.

    Args:
    -----
        t: List[int]
            List of positions for which the encoding is to be generated.
        d: int
            Dimension of the encoding vector for each position.
        n: int
            The base of the denominator in the encoding formula. Default value is 10000.

    Returns:
    --------
        P: np.ndarray
            A 2D numpy array with shape (len(t), d) representing the positional encoding.
    """

    # Determine the number of positions
    seq_len = len(t)
    
    # Initialize the encoding matrix with zeros
    P = jnp.zeros((seq_len, d))

    # For each position in t
    for k_idx, k in enumerate(t):
        # For each dimension up to d/2
        for i in jnp.arange(int(d/2)):
            # Calculate the denominator for this dimension
            denominator = jnp.power(n, 2*i/d)

            # Calculate the sine and cosine encoding for this position and dimension
            # P[k_idx, 2*i] = jnp.sin(k/denominator)
            # P[k_idx, 2*i+1] = jnp.cos(k/denominator)
            # x = x.at[idx].set(y)
            P = P.at[k_idx, 2*i].set(jnp.sin(k/denominator))
            P = P.at[k_idx, 2*i+1].set(jnp.cos(k/denominator))
    
    return P


t = jnp.array([100, 200, 300, 400])
print(type(t))
P = getPositionEncoding(t, d=256)
print(P.shape)
print(P)

<class 'jaxlib.xla_extension.ArrayImpl'>
(4, 256)
[[-0.50636566  0.8623189  -0.9285823  ...  0.9999333   0.01074587
   0.99994224]
 [-0.87329733  0.48718765 -0.68924314 ...  0.9997333   0.0214905
   0.99976903]
 [-0.99975586 -0.02209662  0.41700336 ...  0.99939996  0.03223265
   0.99948037]
 [-0.85091937 -0.52529633  0.9987548  ...  0.9989334   0.04297107
   0.9990763 ]]


In [6]:
def getPositionEncodingNumpy(t: List[int], d: int, n: int = 10000) -> np.ndarray:
    """
    Generate the positional encoding for given positions 't' and dimension 'd'.
    The encoding has sine values at even indices and cosine values at odd indices.

    Args:
    -----
        t: List[int]
            List of positions for which the encoding is to be generated.
        d: int
            Dimension of the encoding vector for each position.
        n: int
            The base of the denominator in the encoding formula. Default value is 10000.

    Returns:
    --------
        P: np.ndarray
            A 2D numpy array with shape (len(t), d) representing the positional encoding.
    """

    # Determine the number of positions
    seq_len = len(t)
    
    # Initialize the encoding matrix with zeros
    P = np.zeros((seq_len, d))

    # For each position in t
    for k_idx, k in enumerate(t):
        # For each dimension up to d/2
        for i in np.arange(int(d/2)):
            # Calculate the denominator for this dimension
            denominator = np.power(n, 2*i/d)

            # Calculate the sine and cosine encoding for this position and dimension
            P[k_idx, 2*i] = np.sin(k/denominator)
            P[k_idx, 2*i+1] = np.cos(k/denominator)
    
    return P

**Testing for small sequences**

In [7]:
t = jnp.array([100, 200, 300, 400])
%timeit getPositionEncoding(t, d=256).block_until_ready()

2.76 s ± 57.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [8]:
t = torch.Tensor([100, 200, 300, 400]) 
%timeit pos_encoding(t, channels=256)

60.9 µs ± 79.4 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [9]:
t = jnp.array([100, 200, 300, 400])
encodingJit = jit(getPositionEncoding, static_argnums=1)
%timeit encodingJit(t, d=256).block_until_ready()

5.79 ms ± 1.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [10]:
t = np.array([100, 200, 300, 400])
%timeit getPositionEncodingNumpy(t, d=256)

2.01 ms ± 14.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


**Testing for moderate sequences**

In [11]:
images = np.random.randn(128, 10)
noise_steps=1000
t = np.random.randint(low=1, high=noise_steps, size=(images.shape[0],))

%timeit getPositionEncodingNumpy(t, d=256)

63.8 ms ± 37 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [12]:
key = jax.random.PRNGKey(0)  # Use a PRNGKey for random number generation
images = jax.random.normal(key, (128,10))
noise_steps=1000
t = jax.random.randint(key, minval=1, maxval=noise_steps, shape=(images.shape[0],))

%timeit getPositionEncoding(t, d=256).block_until_ready()

1min 28s ± 523 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [13]:
images = torch.randn(128, 10)
noise_steps=1000
t = torch.randint(low=1, high=noise_steps, size=(images.shape[0],))

%timeit pos_encoding(t, channels=256)

126 µs ± 2.1 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
key = jax.random.PRNGKey(0)  # Use a PRNGKey for random number generation
images = jax.random.normal(key, (128,10))
noise_steps=1000
t = jax.random.randint(key, minval=1, maxval=noise_steps, shape=(images.shape[0],))

encodingJit = jit(getPositionEncoding, static_argnums=1)
%timeit encodingJit(t, d=256).block_until_ready()