In [7]:
import torch


def positional_encoding(
    tensor, num_encoding_functions=6, include_input=True
) -> torch.Tensor:
    """
    Applies a sinusoidal positional encoding with power-of-two frequencies.

    For every ``i`` in ``range(num_encoding_functions)`` the pieces
    ``sin(2**i * tensor)`` and ``cos(2**i * tensor)`` are appended (in that
    order) and everything is concatenated along the last dimension.

    Args:
    tensor: Input tensor to be positionally encoded.
    num_encoding_functions: Number of frequencies; each one contributes a sine
        and a cosine piece, i.e. two copies of the last dimension.
    include_input: Whether or not to put the raw input in front of the
        sine/cosine pieces.

    Returns:
    torch.Tensor: Positionally encoded tensor. Its last dimension is the input's
        last dimension times ``int(include_input) + 2 * num_encoding_functions``;
        all other dimensions are unchanged. When nothing is requested
        (``include_input=False`` and ``num_encoding_functions=0``) a tensor with
        a zero-width last dimension is returned.
    """
    encoding = [tensor] if include_input else []

    for i in range(num_encoding_functions):
        # The scaled input is shared by the sine and the cosine piece.
        scaled = 2**i * tensor
        encoding.append(torch.sin(scaled))
        encoding.append(torch.cos(scaled))

    if not encoding:
        # torch.cat rejects an empty list, so build the zero-width result
        # directly. The dtype is the one the sine/cosine pieces would have had.
        return tensor.new_empty(
            (*tensor.shape[:-1], 0), dtype=torch.result_type(tensor, 1.0)
        )

    return torch.cat(encoding, dim=-1)

In [9]:
# Quick shape check on a (1, 10) integer tensor with values drawn from [0, 10).
# With the defaults (raw input kept + 6 frequencies x sin/cos) the last
# dimension grows by a factor of 13, i.e. 10 -> 130.
x = torch.randint(low=0, high=10, size=(1, 10))
positional_encoding(x).shape

torch.Size([1, 130])

In [40]:
import torch.nn as nn


class PositionalEncoding(nn.Module):
    """Fixed sinusoidal position table followed by dropout.

    The table has shape ``(1, max_len, d_model)``: even columns hold
    ``sin(position * freq)`` and odd columns hold ``cos(position * freq)``,
    with ``freq = 10000 ** (-k / d_model)`` for ``k = 0, 2, 4, ...``.

    Note that ``forward`` returns the table itself (sliced and expanded); it
    does not add it to the input, and the values stored in the input are never
    read.

    Args:
        d_model: Width of the encoding (size of the last output dimension).
            Odd values are supported.
        dropout: Dropout probability applied to the returned encoding.
        max_len: Number of positions pre-computed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float()
            * (-torch.log(torch.tensor(10000.0)) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # An odd d_model has one fewer odd column than even columns, so the
        # cosine half only takes the first d_model // 2 frequencies.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Registered as a buffer so that .to()/.cuda()/.double() move the table
        # together with the module; persistent=False keeps it out of the
        # state_dict, so saved checkpoints look the same as before.
        self.register_buffer("pe", pe.unsqueeze(0), persistent=False)

    def forward(self, x):
        """Return the position table for the first two sizes of ``x``.

        Args:
            x: Tensor with at least two dimensions. Only ``x.size(0)`` and
                ``x.size(1)`` are used.

        Returns:
            Tensor of shape ``(x.size(0), min(x.size(1), max_len), d_model)``
            after dropout. Requests longer than ``max_len`` are silently cut
            to ``max_len`` positions by the slice.
        """
        table = self.pe[:, : x.size(1), :].expand(x.size(0), -1, -1)
        return self.dropout(table)

In [26]:
# Smoke test for the module: only the first two sizes of the input are read,
# so a (1, 10) tensor yields an encoding of shape (1, 10, 20).
x = torch.randint(low=0, high=10, size=(1, 10))
res = PositionalEncoding(d_model=20)(x)

In [51]:
class TransformerEncoder(nn.Module):
    """Linear input projection followed by a stack of transformer encoder layers.

    Args:
        input_size: Size of the last dimension of the incoming features.
        d_model: Width the features are projected to and the encoder works in.
        nhead: Number of attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        dropout: Dropout probability passed to each encoder layer.

    No ``batch_first`` argument is passed to the encoder layer, so the stack
    uses the library's default input layout.
    """

    def __init__(self, input_size, d_model, nhead, num_layers, dropout=0.1):
        super().__init__()
        # Hyper-parameters are kept as plain attributes for later inspection.
        self.input_size, self.d_model = input_size, d_model
        self.nhead, self.num_layers = nhead, num_layers

        # Sub-modules are created in the same order as before (encoder stack
        # first, projection second) so seeded initialisation is unchanged.
        layer_template = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dropout=dropout
        )
        self.transformer_encoder = nn.TransformerEncoder(
            layer_template, num_layers=num_layers
        )
        self.embedder = nn.Linear(in_features=input_size, out_features=d_model)

    def forward(self, x):
        """Project ``x`` to ``d_model`` features and run it through the encoder."""
        return self.transformer_encoder(self.embedder(x))

In [52]:
# Time encoding for a (20, 5, 1) tensor of random timesteps -> (20, 5, 12).
# NOTE(review): `t` is not used further down; the trailing 12 in `3 + 3 + 12`
# presumably stands in for it — confirm whether it should be concatenated
# into `x` instead of being replaced by random values.
pos_enc_time = PositionalEncoding(d_model=12)
t = pos_enc_time(torch.randint(low=0, high=10, size=(20, 5, 1)))

# Random stand-in features, flattened to one row per leading index and cast
# to float so they can go through the linear projection.
x = (
    torch.randint(low=0, high=10, size=(20, 5, 3 + 3 + 12))
    .view(20, -1)
    .float()
)

transformer_encoder = TransformerEncoder(
    input_size=x.shape[1], d_model=256, nhead=4, num_layers=2
)
transformer_encoder(x).shape

torch.Size([20, 256])