# **Decoder Architecture**

The decoder takes as input the hidden states generated by the encoder and the previously generated output tokens and uses them to predict the next output token. At each step, the decoder attends to different parts of the input sequence using its attention mechanism, allowing it to capture complex relationships between the input and output sequences.

![](https://media.geeksforgeeks.org/wp-content/uploads/20240110165738/Transformer-python.webp)

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
# Hyperparameters for the demo decoder built further down in this notebook.
d_model = 512              # width of every token embedding / hidden state
num_heads = 8              # attention heads; each head gets d_model // num_heads features
drop_prob = 0.1            # dropout probability handed to the decoder layers
batch_size = 30            # number of sequences in the demo batch
max_sequence_length = 200  # number of positions in each demo sequence
ffn_hidden = 2048          # hidden width of the position-wise feed-forward network
num_layers = 5             # number of stacked decoder layers

In [3]:
import math

def scaled_attention(q, k, v, mask):
  """Scaled dot-product attention: softmax(q @ k^T / sqrt(d_k)) @ v.

  Args:
    q: queries, shape (..., q_len, d_k).
    k: keys, shape (..., k_len, d_k).
    v: values, shape (..., k_len, d_v).
    mask: selects the masking mode.
      * falsy (False / None): no masking.
      * True: a causal (look-ahead) mask is built from the actual score
        shape, so position i only attends to positions j <= i.
      * float tensor: additive mask broadcastable to (..., q_len, k_len);
        0 keeps a position, -inf removes it.
      * bool tensor: True marks positions that may be attended to.

  Returns:
    Attention-weighted values, shape (..., q_len, d_v).
  """
  d_k = q.size(-1)
  scaled = (q @ k.transpose(-2, -1)) / math.sqrt(d_k)   # (..., q_len, k_len)

  if isinstance(mask, torch.Tensor):
    if mask.dtype == torch.bool:
      # Boolean masks mark the *allowed* positions; everything else is dropped.
      scaled = scaled.masked_fill(~mask.to(scaled.device), float('-inf'))
    else:
      scaled = scaled + mask.to(device = scaled.device, dtype = scaled.dtype)
  elif mask:
    # Size the mask from the scores themselves (not a module-level constant)
    # and create it on the same device/dtype so any sequence length works.
    q_len, k_len = scaled.size(-2), scaled.size(-1)
    causal = torch.full((q_len, k_len), float('-inf'),
                        dtype = scaled.dtype, device = scaled.device)
    causal = torch.triu(causal, diagonal = 1)           # 0 on/below diagonal, -inf above
    scaled = scaled + causal

  attention = F.softmax(scaled, dim = -1)               # (..., q_len, k_len)
  values = attention @ v                                # (..., q_len, d_v)

  return values

In [4]:
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention with a fused query/key/value projection.

  Args:
    d_model: feature width of the input and output.
    num_heads: number of attention heads; must divide d_model.

  Raises:
    ValueError: if d_model is not divisible by num_heads.
  """
  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.qkv_layer = nn.Linear(d_model, 3 * d_model) # fused q, k, v projection
    self.lin_layer = nn.Linear(d_model, d_model)     # output projection

  def forward(self, x, mask = False):
    """Attend over x.

    Args:
      x: tensor of shape (batch, seq, d_model).
      mask: forwarded unchanged to scaled_attention.

    Returns:
      Tensor of shape (batch, seq, d_model).
    """
    batch_size, sequence_len, input_dim = x.size()
    qkv = self.qkv_layer(x)                          # batch x seq x 3*d_model
    qkv = qkv.reshape(batch_size, sequence_len,
                self.num_heads, 3 * self.head_dim)   # batch x seq x heads x 3*head_dim
    qkv = qkv.permute(0, 2, 1, 3)                    # batch x heads x seq x 3*head_dim
    q, k, v = qkv.chunk(3, dim = -1)                 # each: batch x heads x seq x head_dim
    values = scaled_attention(q, k, v, mask)         # batch x heads x seq x head_dim
    # Move the head axis back next to head_dim before merging; reshaping the
    # (batch, heads, seq, head_dim) tensor directly would interleave data from
    # different positions and heads.
    values = values.permute(0, 2, 1, 3)              # batch x seq x heads x head_dim
    values = values.reshape(batch_size, sequence_len,
                    self.num_heads * self.head_dim)  # batch x seq x d_model
    out = self.lin_layer(values)                     # batch x seq x d_model
    return out

In [5]:
class MultiHeadCrossAttention(nn.Module):
  """Multi-head cross-attention: queries come from y, keys/values from x.

  Args:
    d_model: feature width of both inputs and of the output.
    num_heads: number of attention heads; must divide d_model.

  Raises:
    ValueError: if d_model is not divisible by num_heads.
  """
  def __init__(self, d_model, num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model
    self.num_heads = num_heads
    self.head_dim = d_model // num_heads
    self.kv_layer = nn.Linear(d_model, 2 * d_model)  # fused key/value projection of x
    self.q_layer = nn.Linear(d_model, d_model)       # query projection of y
    self.lin_layer = nn.Linear(d_model, d_model)     # output projection

  def forward(self, x, y):
    """Let every position of y attend over all positions of x.

    Args:
      x: key/value source, shape (batch, kv_len, d_model).
      y: query source, shape (batch, q_len, d_model); q_len may differ
        from kv_len.

    Returns:
      Tensor of shape (batch, q_len, d_model).
    """
    batch_size, kv_len, input_dim = x.size()
    q_len = y.size(1)                                # queries keep their own length
    kv = self.kv_layer(x)                            # batch x kv_len x 2*d_model
    q = self.q_layer(y)                              # batch x q_len x d_model
    kv = kv.reshape(batch_size, kv_len,
                self.num_heads, 2 * self.head_dim)   # batch x kv_len x heads x 2*head_dim
    q = q.reshape(batch_size, q_len,
                self.num_heads, self.head_dim)       # batch x q_len x heads x head_dim
    kv = kv.permute(0, 2, 1, 3)                      # batch x heads x kv_len x 2*head_dim
    q = q.permute(0, 2, 1, 3)                        # batch x heads x q_len x head_dim
    k, v = kv.chunk(2, dim = -1)                     # each: batch x heads x kv_len x head_dim
    values = scaled_attention(q, k, v, mask = False) # batch x heads x q_len x head_dim
    # Restore (batch, q_len, heads, head_dim) before merging the heads;
    # reshaping directly would mix features across positions and heads.
    values = values.permute(0, 2, 1, 3)              # batch x q_len x heads x head_dim
    values = values.reshape(batch_size, q_len,
                    self.num_heads * self.head_dim)  # batch x q_len x d_model
    out = self.lin_layer(values)                     # batch x q_len x d_model
    return out

In [6]:
class PositionwiseFeedForward(nn.Module):
  """Two-layer feed-forward network applied to the last dimension.

  Computes linear2(dropout(relu(linear1(x)))): d_model -> hidden -> d_model.

  Args:
    d_model: input and output feature width.
    hidden: width of the intermediate layer.
    prob: dropout probability applied after the ReLU.
  """
  def __init__(self, d_model, hidden, prob):
    super().__init__()
    # Expansion followed by contraction back to the model width.
    self.linear1 = nn.Linear(d_model, hidden)
    self.linear2 = nn.Linear(hidden, d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(prob)

  def forward(self, x):
    """Map x of shape (..., d_model) to a tensor of the same shape."""
    expanded = self.relu(self.linear1(x))           # (..., hidden)
    return self.linear2(self.dropout(expanded))     # (..., d_model)

In [7]:
class LayerNorm(nn.Module):
  """Layer normalization over the trailing dimensions given by params_shape.

  Args:
    params_shape: shape of the normalized trailing dimensions, either a
      sequence such as [512] or a single int (treated as a 1-tuple).
    eps: constant added to the variance before the square root.
  """
  def __init__(self, params_shape, eps = 1e-5):
    super().__init__()
    if isinstance(params_shape, int):
      # forward() calls len() on the shape, so a bare int has to be wrapped.
      params_shape = (params_shape,)
    self.params_shape = params_shape
    self.eps = eps
    self.gamma = nn.Parameter(torch.ones(params_shape))      # learnable scale
    self.beta = nn.Parameter(torch.zeros(params_shape))      # learnable shift

  def forward(self, input):
    """Normalize input over its last len(params_shape) dimensions.

    Args:
      input: tensor whose trailing dimensions match params_shape.

    Returns:
      Tensor of the same shape as input.
    """
    # Negative indices of the trailing dimensions: [-1, -2, ...].
    dims = [-(i + 1) for i in range(len(self.params_shape))]
    mean = input.mean(dim = dims, keepdim = True)
    # Population variance (mean of squared deviations, no Bessel correction).
    var = ((input - mean) ** 2).mean(dim = dims, keepdim = True)
    sd = (var + self.eps).sqrt()
    X_dash = (input - mean) / sd
    Y = self.gamma * X_dash + self.beta

    return Y

In [8]:
class DecoderLayer(nn.Module):
  """One decoder block made of three residual sub-layers.

  In order: masked self-attention over y, cross-attention between x and y,
  and a position-wise feed-forward network. Each sub-layer output goes
  through dropout, is added to the sub-layer input, and is layer-normalized.

  Args:
    d_model: feature width of x and y.
    num_heads: number of heads in both attention modules.
    drop_prob: dropout probability used throughout the layer.
    ffn_hidden: hidden width of the feed-forward network.
  """
  def __init__(self, d_model, num_heads, drop_prob, ffn_hidden):
    super().__init__()
    # Sub-layer 1: self-attention.
    self.self_attn = MultiHeadAttention(d_model, num_heads)
    self.norm1 = nn.LayerNorm([d_model])
    self.dropout1 = nn.Dropout(drop_prob)
    # Sub-layer 2: cross-attention.
    self.enc_dec_attn = MultiHeadCrossAttention(d_model, num_heads)
    self.norm2 = nn.LayerNorm([d_model])
    self.dropout2 = nn.Dropout(drop_prob)
    # Sub-layer 3: feed-forward network.
    self.ffn = PositionwiseFeedForward(d_model, ffn_hidden, drop_prob)
    self.norm3 = nn.LayerNorm([d_model])
    self.dropout3 = nn.Dropout(drop_prob)

  def forward(self, x, y, mask):
    """Run the three sub-layers on y and return a tensor shaped like y.

    Args:
      x: first input of the cross-attention module.
      y: tensor being refined; also the residual input of every sub-layer.
      mask: accepted for interface compatibility.
    """
    # NOTE(review): `mask` is never read here; self-attention is always
    # called with mask = True.
    attended = self.dropout1(self.self_attn(y, mask = True))
    y = self.norm1(attended + y)

    crossed = self.dropout2(self.enc_dec_attn(x, y))
    y = self.norm2(crossed + y)

    transformed = self.dropout3(self.ffn(y))
    return self.norm3(transformed + y)



In [9]:
class SequentialDecoder(nn.Sequential):
  """nn.Sequential variant whose layers take (x, y, mask) and return a new y."""
  def forward(self, *inputs):
    """Thread y through every contained module; x and mask are passed unchanged.

    Args:
      *inputs: exactly three positional values (x, y, mask).

    Returns:
      The y produced by the last module, or the input y if there are none.
    """
    source, state, mask = inputs
    for layer in self:
      state = layer(source, state, mask)
    return state

In [10]:
class Decoder(nn.Module):
  """Stack of num_layers DecoderLayer modules applied one after another.

  Args:
    d_model: feature width of x and y.
    num_layers: number of DecoderLayer instances in the stack.
    num_heads: attention heads per layer.
    drop_prob: dropout probability per layer.
    ffn_hidden: hidden width of each layer's feed-forward network.
  """
  def __init__(self, d_model, num_layers, num_heads, drop_prob, ffn_hidden):
    super().__init__()
    stack = [
        DecoderLayer(d_model, num_heads, drop_prob, ffn_hidden)
        for _ in range(num_layers)
    ]
    self.layers = SequentialDecoder(*stack)

  def forward(self, x, y, mask):
    """Pass (x, y, mask) through the layer stack and return the final y.

    Args:
      x: passed unchanged to every layer.
      y: tensor updated by each layer in turn.
      mask: passed unchanged to every layer.
    """
    return self.layers(x, y, mask)

In [11]:
# Smoke test: run random tensors through a freshly initialised decoder.
x = torch.randn(batch_size, max_sequence_length, d_model)
y = torch.randn(batch_size, max_sequence_length, d_model)
# Additive look-ahead mask (0 on/below the diagonal, -inf above) instead of
# random noise, so the tensor handed to the decoder is a meaningful mask.
mask = torch.triu(
    torch.full((max_sequence_length, max_sequence_length), float('-inf')),
    diagonal = 1)
decoder = Decoder(d_model, num_layers, num_heads, drop_prob, ffn_hidden)

# y is rebound to the decoder output from here on.
y = decoder(x,y,mask)


In [12]:
# First 20 features of position 0 in sample 0 of the random input x.
x[0][0][:20]

tensor([-0.3718, -0.3440, -0.9321,  0.7372, -1.7299, -0.8228, -0.8395,  1.3644,
        -0.0908, -0.5668,  0.4527,  0.3867,  0.0058, -0.3630,  0.5339,  0.8531,
        -0.3019,  1.2913,  2.8505,  0.4399])

In [13]:
# First 20 features of position 0 in sample 0 of y, which now holds the
# decoder output (hence the grad_fn on the displayed tensor).
y[0][0][:20]

tensor([-0.4367, -1.8011,  1.0248,  1.5188,  0.6213,  0.5930,  0.6552, -1.9216,
         0.7310, -0.7416,  0.5891, -1.0531,  1.1018,  0.5524,  1.6332,  0.5605,
        -0.6998, -0.1624,  1.2620,  0.1964], grad_fn=<SliceBackward0>)