In [1]:
import torch

In [2]:
def positional_encoding(length, depth):
    depth = depth/2
    
    positions = torch.arange(length)[:, None]
    depth = torch.arange(depth)[None, :]/depth

    angle_rates = 1/10000**depth
    angle_rads = positions * angle_rates

    pos_encoding = torch.concat(
        [torch.sin(angle_rads), torch.cos(angle_rads)],
        dim=-1
    )

    return torch.tensor(pos_encoding, dtype=torch.float32)

In [13]:
class PositionalEmbedding(torch.nn.Module):
    def __init__(self, vocab_size, d_model):
        super(PositionalEmbedding, self).__init__()
        self.d_model = d_model
        self.positional = positional_encoding(2048, d_model)
        self.embedding = torch.nn.Embedding(vocab_size, d_model)

    
    def forward(self, x):
        length = x.shape[1]
        x = self.embedding(x)
        x *= torch.math.sqrt(torch.tensor(self.d_model, dtype=torch.float32))
        print(self.positional[None, :length, :].shape)
        x = x + self.positional[None, :length, :]
        return x

In [22]:
x = PositionalEmbedding(5000, 512)(torch.randint(low=0, high=10, size=(10, 20)))

torch.Size([1, 20, 512])


  return torch.tensor(pos_encoding, dtype=torch.float32)


In [15]:
class PointWiseFeedForward(torch.nn.Module):
    def __init__(self, dff, d_model):
        super(PointWiseFeedForward, self).__init__()
        self.fc1 = torch.nn.Linear(d_model, dff)
        self.fc2 = torch.nn.Linear(dff, d_model)
        self.relu = torch.nn.ReLU(inplace=True)

     
    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

In [28]:
class EncoderLayer(torch.nn.Module):
    def __init__(self, dff, d_model, num_heads, dropout_rate=0.1):
        super(EncoderLayer, self).__init__()
        self.mha = torch.nn.MultiheadAttention(
            embed_dim=d_model,
            kdim=d_model,
            num_heads=num_heads,
            dropout=dropout_rate
        )

        self.ffn = PointWiseFeedForward(dff, d_model)
        
        self.layernorm1 = torch.nn.LayerNorm(normalized_shape=(10, 20, 512), eps=1e-6)
        self.layernorm2 = torch.nn.LayerNorm(normalized_shape=(10, 20, 512), eps=1e-6)

        self.dropout = torch.nn.Dropout(dropout_rate)

    
    def forward(self, x):
        length = x.shape[1]
        batch_size = x.shape[0]

        attn_output = self.mha(
            key=x,
            value=x,
            query=x,
        )
        
        out1 = self.layernorm1(attn_output[0] + x)

        ffn_output = self.ffn(out1)
        
        out2 = self.layernorm2(ffn_output + out1)

        out2 = self.dropout(out2)
        return out2

In [30]:
EncoderLayer(2048, 512, 2)(x).shape

torch.Size([10, 20, 512])

In [38]:
class Encoder(torch.nn.Module):
    def __init__(self, vocab_size, d_model, dff, num_attention_heads, num_layers, dropout_rate=0.1):
        super().__init__()
        self.num_layers = num_layers
        self.positional = PositionalEmbedding(vocab_size, d_model)
        
        self.dec_layers = [EncoderLayer(
            d_model=d_model,
            dff=dff,
            num_heads=num_attention_heads
        ) for _ in range(num_layers)]

        self.dropout = torch.nn.Dropout(dropout_rate)

    
    def forward(self, x):
        x = self.positional(x)
        x = self.dropout(x)

        for i in range(self.num_layers):
            x = self.dec_layers[i](x)
            
        return x

In [39]:
Encoder(5000, 512, 2048, 2, 2)(torch.randint(low=0, high=5000, size=(10, 20)))

torch.Size([1, 20, 512])


  return torch.tensor(pos_encoding, dtype=torch.float32)


tensor([[[ 0.4454,  0.2520,  0.2209,  ...,  0.3562,  1.0865,  1.2449],
         [-1.3287, -0.8176, -1.3068,  ..., -0.5788,  0.6384,  0.2364],
         [-0.5435, -0.1028, -0.2348,  ..., -0.1226,  0.0000,  1.2979],
         ...,
         [ 0.3918,  0.7389, -0.2951,  ...,  1.4244, -0.1599,  0.3395],
         [ 1.3196, -0.2490, -1.3605,  ...,  0.5349, -0.0000,  1.8006],
         [-1.3776, -0.2561,  2.2954,  ...,  0.4890, -0.1300, -1.9712]],

        [[-0.4215,  0.6183,  0.2494,  ...,  0.9821,  0.3358, -0.3742],
         [ 0.0000, -0.0000,  1.7514,  ...,  0.0894,  0.0000, -0.2292],
         [ 0.8847, -1.5229, -0.5831,  ..., -0.6730,  0.7599, -0.9935],
         ...,
         [-0.2896,  0.0901,  2.3021,  ...,  2.5325,  0.0969,  0.7341],
         [ 2.3065,  1.1182, -0.5429,  ...,  1.4037,  0.8681, -1.9225],
         [-0.0535,  0.0000,  0.3583,  ...,  2.4435, -1.4707, -3.1531]],

        [[ 2.3536, -3.5629,  0.0000,  ..., -0.0000,  2.1961, -0.0000],
         [ 2.8880, -0.3633, -0.0762,  ..., -0