<a href="https://colab.research.google.com/github/DatumLearning/Transformer-PyTorch/blob/main/decoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [15]:
class SelfAttention(nn.Module):
  """Single-head scaled dot-product self-attention.

  Queries, keys and values are all linear projections of the same input
  ``x``; attention scores are divided by ``sqrt(d_model)``.

  Args:
    d_model: Size of the last dimension of ``x`` and of every projection.
  """

  def __init__(self , d_model):
    super(SelfAttention , self).__init__()
    self.d_model = d_model
    self.query = nn.Linear(d_model , d_model)
    self.key = nn.Linear(d_model , d_model)
    self.value = nn.Linear(d_model , d_model)

  def forward(self , x , mask = None):
    """Let every position of ``x`` attend to the unmasked positions of ``x``.

    Args:
      x: Tensor of shape ``(..., seq_len, d_model)`` (at least 2-D, because
        the keys are transposed over their last two dimensions).
      mask: Optional tensor broadcastable to the score matrix
        ``(..., seq_len, seq_len)``. Key positions where ``mask == 0`` get
        zero attention weight (e.g. pass ``torch.tril(torch.ones(L, L))``
        for causal decoding).

    Returns:
      Tensor of shape ``(..., seq_len, d_model)``.
    """
    Q = self.query(x)
    K = self.key(x)
    V = self.value(x)

    scores = torch.matmul(Q , K.transpose(-2 , -1)) / math.sqrt(self.d_model)
    if mask is not None:
      # Masked logits must be -inf so softmax gives them exactly zero weight.
      # (Filling with +inf makes softmax compute inf - inf = NaN.)
      # NOTE: a row whose keys are *all* masked still softmaxes to NaN.
      scores = scores.masked_fill(mask == 0 , float('-inf'))
    attn_weights = F.softmax(scores , dim = -1)
    output = torch.matmul(attn_weights , V)
    return output

In [16]:
class CrossAttention(nn.Module):
  """Single-head scaled dot-product cross-attention.

  Queries are projected from the decoder-side input ``x``; keys and values
  are projected from ``enc_out``. Scores are divided by ``sqrt(d_model)``.

  Args:
    d_model: Size of the last dimension of both ``x`` and ``enc_out``.
  """

  def __init__(self , d_model):
    super(CrossAttention , self).__init__()
    self.d_model = d_model
    self.query = nn.Linear(d_model , d_model)
    self.key = nn.Linear(d_model , d_model)
    self.value = nn.Linear(d_model , d_model)

  def forward(self , x , enc_out , mask = None):
    """Let every position of ``x`` attend to the unmasked positions of ``enc_out``.

    Args:
      x: Query-side tensor of shape ``(..., tgt_len, d_model)``.
      enc_out: Key/value-side tensor of shape ``(..., src_len, d_model)``.
      mask: Optional tensor broadcastable to the score matrix
        ``(..., tgt_len, src_len)``. Key positions where ``mask == 0`` get
        zero attention weight (e.g. padding in ``enc_out``).

    Returns:
      Tensor of shape ``(..., tgt_len, d_model)``.
    """
    Q = self.query(x)
    K = self.key(enc_out)
    V = self.value(enc_out)

    scores = torch.matmul(Q , K.transpose(-2 , -1)) / math.sqrt(self.d_model)
    if mask is not None:
      # Masked logits must be -inf so softmax gives them exactly zero weight.
      # (Filling with +inf makes softmax compute inf - inf = NaN.)
      # NOTE: a row whose keys are *all* masked still softmaxes to NaN.
      scores = scores.masked_fill(mask == 0 , float('-inf'))
    attn_weights = F.softmax(scores , dim = -1)
    output = torch.matmul(attn_weights , V)
    return output

In [17]:
class FeedForward(nn.Module):
  """Position-wise two-layer MLP: ``Linear -> ReLU -> Linear``.

  Args:
    d_model: Input and output feature size.
    ff_hidden: Width of the hidden layer.
  """

  def __init__(self, d_model, ff_hidden):
    super().__init__()
    # Attribute names l1 / l2 are part of the state_dict keys; keep them.
    self.l1 = nn.Linear(d_model, ff_hidden)
    self.l2 = nn.Linear(ff_hidden, d_model)

  def forward(self, x):
    """Map ``x`` (last dim ``d_model``) to a tensor with the same shape."""
    hidden = self.l1(x)
    hidden = F.relu(hidden)
    return self.l2(hidden)

In [21]:
class TransformerDecoderLayer(nn.Module):
  """One post-norm Transformer decoder layer.

  Three sub-layers run in sequence (self-attention over ``x``,
  cross-attention from ``x`` to ``enc_out``, position-wise feed-forward),
  each wrapped as ``LayerNorm(x + Dropout(sublayer_output))``.

  Args:
    d_model: Feature size of the decoder input and of ``enc_out``.
    ff_hidden: Hidden width of the feed-forward sub-layer.
    dropout: Dropout probability applied to every sub-layer output.
  """

  def __init__(self, d_model, ff_hidden, dropout=0.1):
    super().__init__()
    # Sub-modules are created in this exact order so that random parameter
    # initialisation and state_dict keys stay the same.
    self.self_attn = SelfAttention(d_model)
    self.cross_attn = CrossAttention(d_model)
    self.ffn = FeedForward(d_model, ff_hidden)
    self.norm1 = nn.LayerNorm(d_model)
    self.norm2 = nn.LayerNorm(d_model)
    self.norm3 = nn.LayerNorm(d_model)
    self.dropout = nn.Dropout(dropout)

  def _add_and_norm(self, residual, sublayer_out, norm):
    """Return ``norm(residual + dropout(sublayer_out))`` (post-norm residual)."""
    return norm(residual + self.dropout(sublayer_out))

  def forward(self, x, enc_out, tgt_mask=None, memory_mask=None):
    """Run the three decoder sub-layers on ``x``.

    Args:
      x: Decoder input whose last dimension is ``d_model``.
      enc_out: Encoder output whose last dimension is ``d_model``.
      tgt_mask: Optional mask forwarded to the self-attention sub-layer.
        No causal mask is built here; callers must pass one themselves
        for autoregressive decoding.
      memory_mask: Optional mask forwarded to the cross-attention sub-layer.

    Returns:
      Tensor with the same shape as ``x``.
    """
    x = self._add_and_norm(x, self.self_attn(x, tgt_mask), self.norm1)
    x = self._add_and_norm(x, self.cross_attn(x, enc_out, memory_mask), self.norm2)
    return self._add_and_norm(x, self.ffn(x), self.norm3)

In [22]:
# Demo layer: d_model=512, feed-forward hidden width 1000, default dropout.
TDL = TransformerDecoderLayer(d_model=512, ff_hidden=1000)

In [23]:
# Random decoder input and encoder output; the last dim must equal d_model (512).
tgt = torch.randn(16, 10, 512)
memory = torch.randn(16, 10, 512)
TDL(tgt, memory).shape

torch.Size([16, 10, 512])