In [2]:
import torch
from torch import Tensor
from torch import nn
import torch.nn.functional as F
import math

In [5]:
    class TokenEmbedding(nn.Embedding):
        """Learned lookup table mapping token ids to ``d_model``-dim vectors.

        Subclasses ``nn.Embedding`` (the original subclassed ``nn.Module``,
        whose ``__init__`` rejects the embedding arguments and raised
        ``TypeError`` on construction).

        Args:
            vocab_size: Number of rows in the table (size of the vocabulary).
            d_model: Size of each embedding vector.
            padding_idx: Row reserved for padding. Per ``nn.Embedding``, this
                row is initialised to zeros and receives no gradient.
                Defaults to 1, the value previously hard-coded.
        """

        def __init__(self, vocab_size, d_model, padding_idx=1):
            super(TokenEmbedding, self).__init__(vocab_size, d_model, padding_idx=padding_idx)


In [7]:
class PositionalEmbedding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding.

    Row ``pos`` of the table holds ``sin(pos / 10000 ** (2i / d_model))`` in
    even column ``2i`` and ``cos(pos / 10000 ** (2i / d_model))`` in odd
    column ``2i + 1``.

    Args:
        d_model: Number of encoding columns (embedding size). May be odd.
        max_len: Number of rows, i.e. the longest sequence supported.
        device: Device on which the table is initially built.
    """

    def __init__(self, d_model, max_len, device):
        super(PositionalEmbedding, self).__init__()
        encoding = torch.zeros(max_len, d_model, device=device)
        pos = torch.arange(0, max_len, dtype=torch.float, device=device).unsqueeze(1)
        _2i = torch.arange(0, d_model, 2, dtype=torch.float, device=device)
        # Shape (max_len, ceil(d_model / 2)); shared by the sin and cos halves.
        angles = pos / (10000 ** (_2i / d_model))
        encoding[:, 0::2] = torch.sin(angles)
        # With odd d_model there is one fewer odd column than even column,
        # so only the first d_model // 2 angles are used for cos.
        encoding[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        # Non-persistent buffer: never requires grad, follows module.to()/.cuda(),
        # and stays out of state_dict (same keys as the old plain attribute).
        # This replaces the old `require_grad` typo, which was a no-op.
        self.register_buffer("encoding", encoding, persistent=False)

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Args:
            x: Tensor whose dimension 1 is the sequence length; presumably
                ``(batch, seq_len)`` token ids.

        Returns:
            Tensor of shape ``(seq_len, d_model)``, which broadcasts over the
            batch dimension when added to a ``(batch, seq_len, d_model)`` tensor.

        Raises:
            ValueError: If ``seq_len`` exceeds the ``max_len`` the table was built for.
        """
        seq_len = x.size(1)
        max_len = self.encoding.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} of PositionalEmbedding"
            )
        return self.encoding[:seq_len, :]

In [None]:
class TransformerEmbedding(nn.Module):
    """Transformer input layer: token embedding plus positional embedding, then dropout.

    Args:
        vocab_size: Vocabulary size forwarded to ``TokenEmbedding``.
        d_model: Embedding size shared by both sub-embeddings.
        max_len: Longest sequence ``PositionalEmbedding`` is built for.
        drop_prob: Dropout probability applied to the summed embeddings.
        device: Device forwarded to ``PositionalEmbedding``.
    """

    def __init__(self, vocab_size, d_model, max_len, drop_prob, device):
        super().__init__()
        # Submodules are registered in the same order as before
        # (token, positional, dropout), so module/state_dict ordering is unchanged.
        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.positional_embedding = PositionalEmbedding(d_model, max_len, device)
        self.dropout = nn.Dropout(drop_prob)

    def forward(self, x):
        """Embed ``x`` and add positional information.

        Args:
            x: Token ids; presumably ``(batch, seq_len)``. TODO confirm with callers.

        Returns:
            Dropout applied to ``token_embedding(x) + positional_embedding(x)``.
            The positional term is ``(seq_len, d_model)`` and broadcasts over the batch.
        """
        # The token embedding is evaluated first, then the positional one (left-to-right).
        summed = self.token_embedding(x) + self.positional_embedding(x)
        return self.dropout(summed)