In [2]:
from einops import rearrange
import torch
import torch.nn as nn
import numpy as np

In [19]:
from transformers import DistilBertTokenizerFast, DistilBertModel


def get_embedding(sentence):
    tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")
    tokens = tokenizer.encode(sentence, return_tensors='pt', padding="max_length", max_length=20)
    model = DistilBertModel.from_pretrained("distilbert-base-uncased")
    return model.embeddings.word_embeddings(tokens)

In [20]:
x1 = get_embedding('my name is jungwoo')
x2 = get_embedding('hi bye')

In [33]:
batch_sample = torch.cat([x1, x2])
batch_sample.size()

d_model = 768
max_length = 20

In [64]:
import math
class PositionalEncoding(nn.Module):
    def __init__(
            self, 
            d_model: int, 
            dropout: float, 
            max_length: int,
        ):
        super().__init__()
        
        self.dropout = nn.Dropout(dropout)
 
        # Encoding - From formula
        pos_encoding = torch.zeros(max_length, d_model)
        positions = rearrange(torch.arange(0, max_length, dtype=torch.float), 'm -> m 1')

        division_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0)) / d_model) # 1000^(2i/dim_model)

        pos_encoding[:, 0::2] = torch.sin(positions * division_term)
        pos_encoding[:, 1::2] = torch.cos(positions * division_term)

        # Saving buffer (same as parameter without gradients needed)
        self.pos_encoding = rearrange(pos_encoding, 'm d -> m 1 d')
 
    def forward(self, token_embedding: torch.tensor) -> torch.tensor:
        # Residual connection + pos encoding
        return self.dropout(token_embedding + self.pos_encoding[:token_embedding.size(0), :])

In [65]:
pe = PositionalEncoding(d_model, 0.1, max_length)

In [66]:
pe(batch_sample)

tensor([[[ 4.3281e-02,  1.0974e+00, -0.0000e+00,  ...,  1.1785e+00,
           2.5599e-02,  1.1376e+00],
         [ 1.5603e-02,  0.0000e+00, -6.7736e-02,  ...,  1.0172e+00,
          -0.0000e+00,  1.1001e+00],
         [-1.8944e-04,  1.1115e+00, -4.2494e-02,  ...,  1.0638e+00,
          -3.2619e-02,  1.0424e+00],
         ...,
         [-1.8499e-02,  1.0371e+00, -1.8143e-02,  ...,  1.0889e+00,
          -5.7111e-02,  1.0818e+00],
         [-1.8499e-02,  1.0371e+00, -1.8143e-02,  ...,  1.0889e+00,
          -5.7111e-02,  1.0818e+00],
         [-1.8499e-02,  1.0371e+00, -1.8143e-02,  ...,  1.0889e+00,
          -5.7111e-02,  1.0818e+00]],

        [[ 0.0000e+00,  5.8665e-01,  8.9732e-01,  ...,  0.0000e+00,
           0.0000e+00,  1.1376e+00],
         [ 8.9727e-01,  4.9398e-01,  9.7411e-01,  ...,  1.0203e+00,
          -1.2909e-01,  0.0000e+00],
         [ 1.0154e+00,  5.0047e-01,  9.6167e-01,  ...,  1.1025e+00,
          -2.8866e-02,  1.0463e+00],
         ...,
         [ 9.1647e-01,  5