In [3]:
words = ["cat", "dog", "fish", "cat", "dog", "bird"]

In [1]:
import torch
import numpy as np

In [5]:
class PositionalEncoding(torch.nn.Module):
  def __init__(self, embed_dim, max_len= 5000):
    super(PositionalEncoding, self).__init__()
    pe = torch.zeros(max_len, embed_dim)
    position = torch.arange(0, max_len, dtype= torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, embed_dim, 2).float() * (-np.log(10000.0) / embed_dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    pe = pe.unsqueeze(0).transpose(0, 1)
    self.register_buffer('pe', pe.unsqueeze(0))

  def forward(self, x):
    seq_len = x.size(1)
    return x + self.pe[:, :seq_len]

In [6]:
if __name__ == '__main__':
  embed_dim = 4
  batch_size = 1
  seq_len = len(words)
  # Example token embeddings for simplicity
  input_tensor = torch.tensor([
      [0.1, 0.3, 0.5, 0.7],  # "cat"
      [0.2, 0.4, 0.6, 0.8],  # "dog"
      [0.3, 0.5, 0.7, 0.9],  # "fish"
      [0.1, 0.3, 0.5, 0.7],  # "cat"
      [0.2, 0.4, 0.6, 0.8],  # "dog"
      [0.4, 0.6, 0.8, 1.0],  # "bird"
  ]).unsqueeze(0)  # Add batch dimension: Shape (1, seq_len, embed_dim)
  positional_encoding = PositionalEncoding(embed_dim= embed_dim, max_len= seq_len)
  output_tensor = positional_encoding(input_tensor)
  print(output_tensor)

tensor([[[[ 0.1000,  1.3000,  0.5000,  1.7000],
          [ 0.2000,  1.4000,  0.6000,  1.8000],
          [ 0.3000,  1.5000,  0.7000,  1.9000],
          [ 0.1000,  1.3000,  0.5000,  1.7000],
          [ 0.2000,  1.4000,  0.6000,  1.8000],
          [ 0.4000,  1.6000,  0.8000,  2.0000]],

         [[ 0.9415,  0.8403,  0.5100,  1.6999],
          [ 1.0415,  0.9403,  0.6100,  1.8000],
          [ 1.1415,  1.0403,  0.7100,  1.9000],
          [ 0.9415,  0.8403,  0.5100,  1.6999],
          [ 1.0415,  0.9403,  0.6100,  1.8000],
          [ 1.2415,  1.1403,  0.8100,  1.9999]],

         [[ 1.0093, -0.1161,  0.5200,  1.6998],
          [ 1.1093, -0.0161,  0.6200,  1.7998],
          [ 1.2093,  0.0839,  0.7200,  1.8998],
          [ 1.0093, -0.1161,  0.5200,  1.6998],
          [ 1.1093, -0.0161,  0.6200,  1.7998],
          [ 1.3093,  0.1839,  0.8200,  1.9998]],

         [[ 0.2411, -0.6900,  0.5300,  1.6996],
          [ 0.3411, -0.5900,  0.6300,  1.7996],
          [ 0.4411, -0.4900,  0.73