In [1]:
import torch
from torch import nn

In [2]:
class TokenEmbedding(nn.Embedding):
    """Token lookup table in which index 0 is the padding index.

    Args:
        num_embeddings: Number of rows in the lookup table.
        embedding_dim: Size of each embedding vector.
        device: Device on which the weight tensor is allocated.
    """

    def __init__(self, num_embeddings, embedding_dim, device):
        super().__init__(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            padding_idx=0,
            device=device,
        )

In [3]:
class PositionalEmbedding(nn.Module):
    """Fixed sinusoidal positional encoding.

    A ``(max_len, embedding_dim)`` table is built once at construction time:
    even columns hold ``sin(pos / 10000 ** (2i / embedding_dim))`` and odd
    columns hold the matching ``cos``. The table is not trainable.

    Args:
        embedding_dim: Width of each encoding vector. Must be even because
            the columns are filled in sin/cos pairs.
        max_len: Number of positions to precompute.
        device: Device on which the table is created.
    """

    def __init__(self, embedding_dim, max_len, device):
        super().__init__()

        assert (
            embedding_dim % 2 == 0
        ), "Embedding dimension must be even under this implementation"

        position = torch.arange(0, max_len, device=device).float().unsqueeze(dim=1)
        even_index = torch.arange(0, embedding_dim, step=2, device=device).float()
        # Shared argument of each sin/cos pair, shape (max_len, embedding_dim // 2).
        angle = position / 10000 ** (even_index / embedding_dim)

        encoding = torch.zeros(max_len, embedding_dim, device=device)
        encoding[:, 0::2] = torch.sin(angle)
        encoding[:, 1::2] = torch.cos(angle)

        # A buffer (unlike a plain tensor attribute) follows the module through
        # .to()/.cuda() and is replicated by DataParallel. persistent=False keeps
        # it out of state_dict, so the set of checkpoint keys is unchanged.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``seq_len`` positions.

        Args:
            x: 2-D tensor of shape ``(batch_size, seq_len)``; only its shape
                is used.

        Returns:
            Tensor of shape ``(seq_len, embedding_dim)`` sliced from the
            precomputed table.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was
                built with.
        """
        _, seq_len = x.size()
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"Sequence length {seq_len} exceeds the maximum length {max_len} "
                "of the positional encoding"
            )
        return self.encoding[:seq_len, :]

In [4]:
class TransformerEmbedding(nn.Module):
    """Token embedding plus positional encoding, followed by dropout.

    Args:
        num_embeddings: Number of rows in the token lookup table.
        embedding_dim: Width of the token and positional vectors.
        max_len: Number of positions the positional table covers.
        dropout: Drop probability applied to the summed embeddings.
        device: Device on which both embedding tables are created.
    """

    def __init__(self, num_embeddings, embedding_dim, max_len, dropout, device):
        super().__init__()
        self.token_embs = TokenEmbedding(
            num_embeddings=num_embeddings,
            embedding_dim=embedding_dim,
            device=device,
        )
        self.pos_emb = PositionalEmbedding(
            embedding_dim=embedding_dim,
            max_len=max_len,
            device=device,
        )
        self.drop_out = nn.Dropout(p=dropout)

    def forward(self, x):
        """Embed a 2-D batch of token ids and apply dropout.

        The positional table is 2-D and is broadcast over the batch
        dimension of the token embeddings when the two are added.
        """
        summed = self.token_embs(x) + self.pos_emb(x)
        return self.drop_out(summed)

In [7]:
# Use the GPU when one is present; otherwise fall back to the CPU so the
# notebook still runs on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

# Seed the RNG so the randomly initialised embedding weights and the dropout
# mask are repeatable across Restart & Run All.
torch.manual_seed(0)

# One sequence of five token ids; id 0 is the padding index of TokenEmbedding.
x = torch.tensor([[0, 1, 2, 3, 4]], device=device)

# 5 token ids, 8-dimensional embeddings, positional table for up to 10 positions.
transformer_emb = TransformerEmbedding(5, 8, 10, dropout=0.1, device=device)

In [8]:
# Forward pass on the sample batch; the bare expression displays the result.
# The module is still in training mode, so dropout is active (zeroed entries,
# survivors scaled by 1 / (1 - 0.1)); call transformer_emb.eval() to disable it.
transformer_emb(x)

tensor([[[ 0.0000,  1.1111,  0.0000,  1.1111,  0.0000,  1.1111,  0.0000,
           0.0000],
         [ 0.6528, -0.1120, -0.1208,  2.5757, -0.0650,  1.4334,  0.6059,
           0.8823],
         [ 1.9290, -0.4476,  0.5116,  0.0000,  0.4895, -1.8389,  1.9962,
           1.3165],
         [ 0.0000, -0.8013, -0.0000,  1.7064,  0.3865,  1.1128, -0.6196,
          -0.2335],
         [-2.3316, -2.2084,  0.3167,  1.1464,  0.4992,  1.3762,  0.5085,
          -0.0526]]], device='cuda:0', grad_fn=<NativeDropoutBackward0>)