### 1. ENCODER LAYER

In [10]:
import torch
import torch.nn as nn 

# Initialize the TransformerEncoderLayer
encoder_layer = nn.TransformerEncoderLayer(
    d_model=3, 
    nhead=1, 
    dim_feedforward=4, 
    dropout=0.1, 
    activation="relu",
    batch_first=True
)


# run for a sample
src = torch.Tensor([[[ 0.69,  0.72, -1.41],
                     [ 0.21,  1.10, -1.31]]])

out = encoder_layer(src)
print(out)

tensor([[[ 0.5854,  0.8222, -1.4076],
         [ 0.1275,  1.1560, -1.2835]]], grad_fn=<NativeLayerNormBackward0>)


In [11]:
src.shape, out.shape 

(torch.Size([1, 2, 3]), torch.Size([1, 2, 3]))

In [12]:
# check

x = encoder_layer.self_attn(src, src, src)[0]
x = src + x
x1 = encoder_layer.norm1(x)
x = encoder_layer.linear2( torch.nn.ReLU()(encoder_layer.linear1(x1)) )
x = x + x1
x = encoder_layer.norm2(x)
print(x)

tensor([[[ 0.5983,  0.8106, -1.4089],
         [ 0.0828,  1.1812, -1.2640]]], grad_fn=<NativeLayerNormBackward0>)


### 2. MARK ENCODER LAYER

In [13]:
# Initialize the TransformerEncoderLayer
encoder_layer_mask = nn.TransformerEncoderLayer(
    d_model=3, 
    nhead=1, 
    dim_feedforward=4, 
    dropout=0.1, 
    activation="relu",
    batch_first=True
)

# run for a sample
src = torch.Tensor([[[ 0.69,  0.72, -1.41],
                     [ 0.21,  1.10, -1.31],
                     [-0.88,  0.60, -0.31]]])

mask = torch.triu(input=torch.ones(3, 3), diagonal=1).bool()

out = encoder_layer_mask(src, src_mask=mask)
print(out)

tensor([[[ 0.6199,  0.7909, -1.4108],
         [ 0.0297,  1.2096, -1.2393],
         [-0.7893,  1.4109, -0.6216]]], grad_fn=<NativeLayerNormBackward0>)


In [14]:
src.shape, out.shape

(torch.Size([1, 3, 3]), torch.Size([1, 3, 3]))

In [15]:
# check

x = encoder_layer.self_attn(src, src, src, attn_mask=mask)[0]
x = src + x
x1 = encoder_layer.norm1(x)
x = encoder_layer.linear2( torch.nn.ReLU()(encoder_layer.linear1(x1)) )
x = x + x1
x = encoder_layer.norm2(x)
print(x)

tensor([[[ 0.6375,  0.7745, -1.4120],
         [ 0.0828,  1.1812, -1.2640],
         [-1.1091,  1.3144, -0.2053]]], grad_fn=<NativeLayerNormBackward0>)
