In [1]:
import numpy as np
import math, copy
import torch
import torch.nn as nn
import torch.nn.functional as F

class EmbedEncode(nn.Module):
    """Token embedding + sinusoidal positional encoding ("Attention Is All You Need", §3.4-3.5).

    Inputs are right-padded to ``max_seq_len`` with a pad token (the last
    vocabulary index, ``d_input - 1``). They are then embedded into ``d_model``
    dimensions, the embeddings are scaled by ``sqrt(d_model)`` as in the paper,
    and the sinusoidal position encoding is added.

    Args:
        d_model: embedding / model width.
        d_input: vocabulary size (length of each one-hot input vector).
        max_seq_len: length that shorter sequences are padded up to.
    """

    def __init__(self, d_model=512, d_input=1024, max_seq_len=512):
        super().__init__()
        self.d_model = d_model
        self.d_input = d_input
        self.max_seq_len = max_seq_len
        self.embedding = nn.Embedding(d_input, d_model)

    def forward(self, x):
        """Pad, embed and position-encode ``x``.

        Args:
            x: either one-hot (or soft) token vectors of shape
               ``[batch, seq_len, d_input]``, or integer token ids of shape
               ``[batch, seq_len]``.

        Returns:
            Tensor of shape ``[batch, max(seq_len, max_seq_len), d_model]``.
        """
        # e.g. "say hello to world" as one-hot vectors: 1x4x1024
        x = self.add_pads(x)                    # 1x512x1024 (or 1x512 ids)

        if x.dim() == 3:
            # nn.Embedding only accepts integer indices. A one-hot row times the
            # weight matrix selects the same row as an index lookup would, and
            # soft (non one-hot) vectors yield the weighted mix of rows.
            weight = self.embedding.weight
            x = x.to(weight.dtype) @ weight     # 1x512x512
        else:
            x = self.embedding(x)               # 1x512x512
        x = x * math.sqrt(self.d_model)         # paper scales embeddings by sqrt(d_model)

        return self.pos_encoding(x)

    def add_pads(self, x):
        """Right-pad ``x`` along the sequence axis (dim 1) to ``max_seq_len``.

        The pad token is the last vocabulary index (``d_input - 1``). For 3-D
        input it is a one-hot vector with only its last element set; for 2-D
        integer input it is that id. Sequences already at least
        ``max_seq_len`` long are returned unchanged. Pads are created with
        ``x``'s dtype and device so the concatenation also works on GPU.
        """
        batch, seq_len = x.shape[0], x.shape[1]
        n_pad = self.max_seq_len - seq_len
        if n_pad <= 0:
            return x
        if x.dim() == 3:
            pads = x.new_zeros((batch, n_pad, self.d_input))
            pads[:, :, -1] = 1
        else:
            pads = x.new_full((batch, n_pad), self.d_input - 1)
        return torch.cat([x, pads], dim=1)

    def pos_encoding(self, x):
        """Add the sinusoidal positional encoding to ``x`` of shape ``[B, S, d_model]``.

        PE(pos, 2i)   = sin(pos / 10000^(2i / d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i / d_model))

        ``pos`` is the token position in the sequence and ``i`` indexes the
        model dimension. The table is built for the actual sequence length of
        ``x`` on ``x``'s device, then broadcast over the batch.
        """
        seq_len = x.size(1)
        # 2i values: 0, 2, 4, ...  (length ceil(d_model / 2))
        even_i = torch.arange(0, self.d_model, 2, device=x.device, dtype=torch.float32)
        denominator = torch.pow(10000.0, even_i / self.d_model)
        position = torch.arange(seq_len, device=x.device, dtype=torch.float32).unsqueeze(1)  # S x 1

        angles = position / denominator                          # S x ceil(d_model/2)
        # Interleave into [sin_0, cos_0, sin_1, cos_1, ...] along the last axis.
        pe = torch.stack([torch.sin(angles), torch.cos(angles)], dim=-1).flatten(start_dim=-2)
        pe = pe[:, : self.d_model]  # drops the trailing cos column when d_model is odd
        return x + pe.to(x.dtype)   # S x d_model broadcasts over the batch axis




In [48]:
# Sanity check: `t[None].repeat(4, 1, 1)` tiles a 3x2 tensor four times
# along a new leading (batch) axis, giving shape 4x3x2.
a = torch.tensor([[row] * 2 for row in range(3)])
a[None].repeat(4, 1, 1)

tensor([[[0, 0],
         [1, 1],
         [2, 2]],

        [[0, 0],
         [1, 1],
         [2, 2]],

        [[0, 0],
         [1, 1],
         [2, 2]],

        [[0, 0],
         [1, 1],
         [2, 2]]])