In [33]:
import torch
import torch.nn as nn
import math

In [34]:
batch_size = 30
max_sequence_length = 200

In [44]:
class Encode(nn.Module):
  def __init__(self,d_model,num_heads,drop_prob,ffn_hidden,num_layers):
    super().__init__()
    self.encode_layer = nn.Sequential(*[Encode_layer(d_model,num_heads,drop_prob,ffn_hidden) for _ in range(num_layers)])

  def forward(self,x):
    x = self.encode_layer(x)
    return x

In [46]:
class MultiheadAttention(nn.Module):
  def __init__(self,d_model,num_heads):
    super(MultiheadAttention,self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.qkv_layer = nn.Linear(d_model,3*d_model)
    self.linear_layer = nn.Linear(d_model,d_model)

  def forward(self,x,mask=None):
    qkv = self.qkv_layer(x)# 30 X 200 X 1536
    qkv = qkv.reshape(batch_size,self.num_heads,max_sequence_length,3*self.head_dim)# 30 X 8 X 200 X 192
    q,k,v = qkv.chunk(3,dim=-1)
    d_k = q.shape[-1]
    scaled = torch.matmul(q,k.transpose(-2,-1))/math.sqrt(d_k)# 30 X 8 X 200 X 192
    if mask is not None:
      scaled += mask
    self_attention = torch.softmax(scaled,dim=-1)# 30 X 8 X 200 X 192
    out = torch.matmul(self_attention,v)# 30 X 8 X 200 X 192
    out = out.reshape(batch_size,max_sequence_length,self.num_heads*self.head_dim)# 30 X 200 X 512
    out = self.linear_layer(out)# 30 X 200 X 512
    return out

In [37]:
class LayerNormalization(nn.Module):
  def __init__(self,parameter_size,eps=1e-5):
    super().__init__()
    self.parameter_size = parameter_size#[512]
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameter_size))#[512]
    self.beta = nn.Parameter(torch.zeros(parameter_size))#[512]

  def forward(self,x):
    dims = [(-i+1) for i in range(len(self.parameter_size))]#[-1]
    mean = x.mean(dim=dims, keepdim=True)# 30 x 200 x 1
    var=((x-mean)**2).mean(dim=dims, keepdim=True)# 30 x 200 x 1
    std_ = (var+self.eps).sqrt()# 30 x 200 x 1
    y = (x - mean) / std_# 30 x 200 x 512
    out = self.gamma * y + self.beta# 30 x 200 x 512
    return out

In [38]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self,d_model,ffn_hidden,drop_prob):
    super().__init__()
    self.linear_layer_1 = nn.Linear(d_model,ffn_hidden)#512 x 2048
    self.linear_layer_2 = nn.Linear(ffn_hidden,d_model)#2048 x 512
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=drop_prob)
  def forward(self,x):
    x = self.linear_layer_1(x)# 30 x 200 x 2048
    x = self.relu(x)# 30 x 200 x 2048
    x = self.dropout(x)# 30 x 200 x 2048
    x = self.linear_layer_2(x)# 30 x 200 x 512
    return x

In [45]:
class Encode_layer(nn.Module):
  def __init__(self,d_model,num_heads,drop_prob,ffn_hidden):
    super(Encode_layer,self).__init__()
    self.attention = MultiheadAttention(d_model,num_heads)
    self.norm1 = LayerNormalization([d_model])
    self.dropout1 = nn.Dropout(p=drop_prob)
    self.ffn = PositionwiseFeedForward(d_model,ffn_hidden,drop_prob)
    self.norm2 = LayerNormalization([d_model])
    self.dropout2 = nn.Dropout(p=drop_prob)

  def forward(self,x):
    print("Starting epoch...")
    residual_x = x # 30 x 200 x 512
    x = self.attention(x) # 30 X 200 X 512
    print("Attention done.")
    x = self.dropout1(x)# 30 X 200 X 512
    x = self.norm1(x+residual_x)# 30 X 200 X 512
    print("normalization completed")
    residual_x = x# 30 X 200 X 512
    x = self.ffn(x)# 30 X 200 X 512
    print("feed forwarded")
    x = self.dropout2(x)# 30 X 200 X 512
    x = self.norm2(x+residual_x)# 30 X 200 X 512
    print("normalization completed")
    return x

In [40]:
d_model = 512
num_heads = 8
drop_prob = 0.1
ffn_hidden = 2048
num_layers = 5

#Random input

In [41]:
x = torch.randn((batch_size,max_sequence_length,d_model))

In [47]:
encoder = Encode(d_model,num_heads,drop_prob,ffn_hidden,num_layers)

In [48]:
out = encoder(x)


Starting epoch...
Attention done.
normalization completed
feed forwarded
normalization completed
Starting epoch...
Attention done.
normalization completed
feed forwarded
normalization completed
Starting epoch...
Attention done.
normalization completed
feed forwarded
normalization completed
Starting epoch...
Attention done.
normalization completed
feed forwarded
normalization completed
Starting epoch...
Attention done.
normalization completed
feed forwarded
normalization completed


tensor([-1.1239e-01,  6.5751e-01, -9.9169e-01,  1.4961e+00,  5.0408e-01,
         3.1691e-02,  1.9468e-01, -1.3889e+00, -3.1518e-01,  1.4601e+00,
         2.3780e+00,  2.2442e-01, -1.4515e+00, -9.0728e-01,  1.1531e-01,
         1.0405e+00, -1.3333e-01,  9.7122e-02, -5.2967e-01, -1.7226e+00,
        -3.1487e+00, -1.1113e+00, -7.4371e-01, -8.7416e-01,  1.5411e+00,
         5.0356e-01, -2.6564e-01,  1.7900e-01, -8.1131e-02,  6.3473e-01,
         2.6696e-01,  5.7962e-01,  9.8714e-02,  9.1608e-01,  8.4464e-01,
         9.4181e-01,  3.6722e-01, -7.6956e-01,  1.2320e+00,  1.0828e+00,
         1.2177e+00,  2.0513e-01,  2.7616e-01, -1.0094e+00, -2.2355e-01,
        -9.2582e-01,  6.7670e-01,  8.0102e-01,  2.4730e-01,  1.3716e-01,
        -4.8882e-01,  1.5830e+00,  3.5139e-01, -3.9447e-01, -7.7217e-01,
        -1.4764e+00,  7.7121e-01,  1.5478e+00,  2.0299e-01,  1.1259e+00,
         8.5689e-01, -1.5605e+00, -8.5204e-01,  5.7054e-01, -2.9758e-01,
         5.8265e-01, -8.0308e-01, -4.1799e-01, -2.4