In [34]:
import torch
import torch.nn as nn
import math

#MultiheadAttention

In [35]:
class MultiheadAttention(nn.Module):
  def __init__(self,d_model,num_heads):
    super(MultiheadAttention,self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.qkv_layer = nn.Linear(d_model,3*d_model)
    self.linear_layer = nn.Linear(d_model,d_model)

  def forward(self,x,mask):
    qkv = self.qkv_layer(x)# 30 X 200 X 1536
    print(f"qkv_layer{qkv.shape}")
    qkv = qkv.reshape(batch_size,self.num_heads,max_sequence_length,3*self.head_dim)# 30 X 8 X 200 X 192
    print(f"qkv_layer{qkv.shape}")
    q,k,v = qkv.chunk(3,dim=-1)
    d_k = q.shape[-1]
    scaled = torch.matmul(q,k.transpose(-2,-1))/math.sqrt(d_k)# 30 X 8 X 200 X 192
    print(f"scaled_layer{scaled.shape}")
    scaled += mask
    self_attention = torch.softmax(scaled,dim=-1)# 30 X 8 X 200 X 192
    print(f"attention_matrix {self_attention.shape}")
    out = torch.matmul(self_attention,v)# 30 X 8 X 200 X 192
    out = out.reshape(batch_size,max_sequence_length,self.num_heads*self.head_dim)# 30 X 200 X 512
    print(f"out_shape {out.shape}")
    out = self.linear_layer(out)# 30 X 200 X 512
    return out

#MultiheadCrossAttention

In [36]:
class MultiheadCrossAttention(nn.Module):
  def __init__(self,d_model,num_heads):
    super(MultiheadCrossAttention,self).__init__()
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.kv_layer = nn.Linear(d_model,2*d_model)
    self.q_layer = nn.Linear(d_model,d_model)
    self.linear_layer = nn.Linear(d_model,d_model)
  def forward(self,x,y,mask=None):
    kv = self.kv_layer(x)
    print(f"kv_layer {kv.shape}")
    q = self.q_layer(x)
    print(f"q_layer {q.shape}")

    kv = kv.reshape(batch_size,self.num_heads,max_sequence_length,2*self.head_dim)
    print(f"kv_layer {kv.shape}")
    q = q.reshape(batch_size,self.num_heads,max_sequence_length,self.head_dim)
    print(f"q_layer {q.shape}")

    k,v = kv.chunk(2,dim=-1)
    d_k = q.shape[-1]
    scaled = torch.matmul(q,k.transpose(-2,-1))/math.sqrt(d_k)# 30 X 8 X 200 X 192
    print(f"scaled_layer {scaled.shape}")
    self_attention = torch.softmax(scaled,dim=-1)# 30 X 8 X 200 X 192
    print(f"attention_matrix {self_attention.shape}")
    out = torch.matmul(self_attention,v)# 30 X 8 X 200 X 192
    out = out.reshape(batch_size,max_sequence_length,self.num_heads*self.head_dim)# 30 X 200 X 512
    print(f"out_shape {out.shape}")
    out = self.linear_layer(out)# 30 X 200 X 512
    return out



#LayerNormalization

In [37]:
class LayerNormalization(nn.Module):
  def __init__(self,parameter_size,eps=1e-5):
    super().__init__()
    self.parameter_size = parameter_size#[512]
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(parameter_size))#[512]
    self.beta = nn.Parameter(torch.zeros(parameter_size))#[512]

  def forward(self,x):
    dims = [(-i+1) for i in range(len(self.parameter_size))]#[-1]
    print(f"dims {dims}")
    mean = x.mean(dim=dims, keepdim=True)# 30 x 200 x 1
    var=((x-mean)**2).mean(dim=dims, keepdim=True)# 30 x 200 x 1
    std_ = (var+self.eps).sqrt()# 30 x 200 x 1
    y = (x - mean) / std_# 30 x 200 x 512
    out = self.gamma * y + self.beta# 30 x 200 x 512
    print(f"out_shape {out.shape}")
    return out

#PositionWisefeedForward

In [38]:
class PositionwiseFeedForward(nn.Module):
  def __init__(self,d_model,ffn_hidden,drop_prob):
    super().__init__()
    self.linear_layer_1 = nn.Linear(d_model,ffn_hidden)#512 x 2048
    self.linear_layer_2 = nn.Linear(ffn_hidden,d_model)#2048 x 512
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=drop_prob)
  def forward(self,x):
    x = self.linear_layer_1(x)# 30 x 200 x 2048
    print(f"linear_layer {x.shape}")
    x = self.relu(x)# 30 x 200 x 2048
    x = self.dropout(x)# 30 x 200 x 2048
    x = self.linear_layer_2(x)# 30 x 200 x 512
    print(f"linear_layer {x.shape}")
    return x

#Decode Layer

In [39]:
class Decode_layer(nn.Module):
  def __init__(self,d_model,num_heads,drop_prob,ffn_hidden):
    super(Decode_layer,self).__init__()

    self.attention = MultiheadAttention(d_model,num_heads)

    self.norm1 = LayerNormalization([d_model])
    self.dropout1 = nn.Dropout(p=drop_prob)

    self.cross_attention = MultiheadCrossAttention(d_model,num_heads)

    self.norm2 = LayerNormalization([d_model])
    self.dropout2 = nn.Dropout(p=drop_prob)

    self.ffn = PositionwiseFeedForward(d_model,ffn_hidden,drop_prob)
    self.norm3 = LayerNormalization([d_model])
    self.dropout3 = nn.Dropout(p=drop_prob)

  def forward(self,x,y,decoder_mask):
    print("Starting epoch...")
    residual_x = x # 30 x 200 x 512
    x = self.attention(x,decoder_mask) # 30 X 200 X 512
    print("Attention done.")
    x = self.dropout1(x)# 30 X 200 X 512
    x = self.norm1(x+residual_x)# 30 X 200 X 512
    print("normalization completed")
    x = self.cross_attention(x,y,mask=None) # 30 X 200 X 512
    print("cross_Attention done.")
    x = self.dropout2(x)# 30 X 200 X 512
    x = self.norm2(x+residual_x)# 30 X 200 X 512
    print("normalization completed")
    residual_x = x# 30 X 200 X 512
    x = self.ffn(x)# 30 X 200 X 512
    print("feed forwarded")
    x = self.dropout3(x)# 30 X 200 X 512
    x = self.norm3(x+residual_x)# 30 X 200 X 512
    print("normalization completed")
    print("-------------------------------------------------")
    return x

In [40]:
class Sequential_decoder(nn.Sequential):
  def forward(self,*inputs):
    x ,y, mask = inputs
    for module in self._modules.values():
      y = module(x,y,mask)
    return y

#Decoder

In [41]:
class Decoder(nn.Module):
  def __init__(self,d_model,num_heads,drop_prob,ffn_hidden,num_layers):
    super().__init__()
    self.decode_layer = Sequential_decoder(*[Decode_layer(d_model,num_heads,drop_prob,ffn_hidden) for _ in range(num_layers)])

  def forward(self,x,y,mask):
    x = self.decode_layer(x,y,mask)
    return x

In [42]:
batch_size = 30
max_sequence_length = 200
num_heads = 8
d_model = 512
ffn_hidden = 2048
drop_prob = 0.1
num_layers = 5

#Random Input

In [43]:
x = torch.randn((batch_size,max_sequence_length,d_model))
y = torch.randn((batch_size,max_sequence_length,d_model))

In [44]:
decoder = Decoder(d_model,num_heads,drop_prob,ffn_hidden,num_layers)

In [45]:
mask = torch.full((max_sequence_length,max_sequence_length),float('-inf'))
mask = torch.triu(mask,diagonal=1)

In [46]:
out = decoder(x,y,mask)

Starting epoch...
qkv_layertorch.Size([30, 200, 1536])
qkv_layertorch.Size([30, 8, 200, 192])
scaled_layertorch.Size([30, 8, 200, 200])
attention_matrix torch.Size([30, 8, 200, 200])
out_shape torch.Size([30, 200, 512])
Attention done.
dims [1]
out_shape torch.Size([30, 200, 512])
normalization completed
kv_layer torch.Size([30, 200, 1024])
q_layer torch.Size([30, 200, 512])
kv_layer torch.Size([30, 8, 200, 128])
q_layer torch.Size([30, 8, 200, 64])
scaled_layer torch.Size([30, 8, 200, 200])
attention_matrix torch.Size([30, 8, 200, 200])
out_shape torch.Size([30, 200, 512])
cross_Attention done.
dims [1]
out_shape torch.Size([30, 200, 512])
normalization completed
linear_layer torch.Size([30, 200, 2048])
linear_layer torch.Size([30, 200, 512])
feed forwarded
dims [1]
out_shape torch.Size([30, 200, 512])
normalization completed
-------------------------------------------------
Starting epoch...
qkv_layertorch.Size([30, 200, 1536])
qkv_layertorch.Size([30, 8, 200, 192])
scaled_layertorch