# Decoder architecture
---
This is the second notebook, and it builds heavily on the encoder implementation. The detailed comments are in `encoder.ipynb`, so you should start learning from there.

There are two main differences here:
- `CasualHeadAttention` is a version of the `MultiHeadAttention` class that adds a causal mask (stored as `casual_mask` in the code). This mask is a lower-triangular matrix: every value above the main diagonal is 0, and it is applied to the entire input sequence. The goal is to ensure that the decoder can only see the word that is currently analyzed and the words before it, so for example at word 4 the decoder sees only words 1, 2, 3, and 4.
- The `Decoder` class differs from the `Encoder` class in terms of output size. For each sequence in the batch the output size is $T \times DictionarySize$. For example, if the longest sentence in the batch contains 30 words and the dictionary contains 20,000 words, the network returns a $30 \times 20{,}000$ matrix per sentence (30 positions, each holding one unnormalized score (logit) for every word in the dictionary).

In [3]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [9]:
class CasualHeadAttention(nn.Module):
    """Multi-head attention with a causal (look-back only) mask.

    Position ``t`` may attend only to positions ``0..t``. An optional padding
    mask additionally hides individual key positions.

    Args:
        d_k: Size of each head's query/key/value projection.
        d_model: Size of the input and output feature dimension.
        n_heads: Number of attention heads.
        max_len: Longest sequence length the precomputed causal mask supports.
    """

    def __init__(self, d_k, d_model, n_heads, max_len):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.n_heads = n_heads

        self.query = nn.Linear(d_model, d_k * n_heads)
        self.key = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)

        self.out = nn.Linear(d_k * n_heads, d_model)

        # Causal mask: lower-triangular ones, zeros strictly above the diagonal.
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        cm = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer(
            'casual_mask',
            cm.view(1, 1, max_len, max_len)
        )

    def forward(self, q, k, v, pad_mask=None):
        """Apply causal multi-head attention.

        Args:
            q, k, v: Tensors of shape (N, T, d_model).
            pad_mask: Optional (N, T) tensor; entries equal to 0 mark key
                positions that must not be attended to.

        Returns:
            Tensor of shape (N, T, d_model).

        Raises:
            ValueError: If T exceeds the ``max_len`` the mask was built for.
        """
        N = q.shape[0]  # batch size
        T = q.shape[1]  # sequence length

        max_len = self.casual_mask.size(-1)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds max_len={max_len} "
                "of the causal mask"
            )

        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V
        q = self.query(q)  # N x T x (h*d_k)
        k = self.key(k)    # N x T x (h*d_k)
        v = self.value(v)  # N x T x (h*d_v), with d_v == d_k

        # Split the heads (required for the batched matrix multiplication):
        # view: (N, T, h*d_k) -> (N, T, h, d_k)
        # transpose: (N, T, h, d_k) -> (N, h, T, d_k)
        q = q.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T, self.n_heads, self.d_k).transpose(1, 2)

        # (N, h, T, d_k) x (N, h, d_k, T) -> (N, h, T, T)
        attention_scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # Boolean "may attend" mask. max_len is the longest sequence possible,
        # but only the longest sequence in the batch is needed, so crop to :T.
        allowed = self.casual_mask[:, :, :T, :T] != 0
        if pad_mask is not None:
            # pad_mask is (N, T); add two inner dimensions so that it
            # broadcasts over heads and query positions.
            allowed = allowed & (pad_mask[:, None, None, :] != 0)

        # -inf scores receive zero weight after the softmax.
        attention_scores = attention_scores.masked_fill(~allowed, float('-inf'))
        attention_weights = F.softmax(attention_scores, dim=-1)

        # A query row with no visible key (e.g. a left-padded position) is all
        # -inf, which the softmax turns into NaN. Give such rows zero weights
        # so the NaNs cannot spread to the rest of the network.
        no_visible_key = ~allowed.any(dim=-1, keepdim=True)
        attention_weights = attention_weights.masked_fill(no_visible_key, 0.0)

        A = attention_weights @ v

        # Merge the heads: (N, h, T, d_k) -> (N, T, h, d_k) -> (N, T, h*d_k)
        A = A.transpose(1, 2)
        A = A.contiguous().view(N, T, self.n_heads * self.d_k)

        return self.out(A)
        
        
                

In [13]:
class TransformerBlock(nn.Module):
    """One decoder layer: causal self-attention, then a feed-forward network.

    Each sub-layer is wrapped in a residual connection followed by layer
    normalization; dropout is applied once, to the block's final output.

    Args:
        d_k: Per-head projection size passed to the attention layer.
        d_model: Feature dimension of the block's input and output.
        n_heads: Number of attention heads.
        max_len: Maximum sequence length passed to the attention layer.
        dropout: Dropout probability applied to the block output.
    """

    def __init__(self, d_k, d_model, n_heads, max_len, dropout=0.1):
        super().__init__()

        self.attention = CasualHeadAttention(d_k, d_model, n_heads, max_len)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Position-wise feed-forward network with a 4x wider hidden layer.
        hidden_dim = 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, d_model),
        )

    def forward(self, x, pad_mask=None):
        """Run the block on ``x``; ``pad_mask`` is forwarded to the attention."""
        # Self-attention: the same tensor serves as query, key and value.
        attended = self.attention(x, x, x, pad_mask)
        x = self.norm1(x + attended)

        transformed = self.ff(x)
        x = self.norm2(x + transformed)

        return self.dropout(x)


In [14]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    Even feature indices hold sines and odd feature indices hold cosines of
    the position scaled by geometrically decreasing frequencies.

    Args:
        d_model: Feature dimension of the embeddings (even or odd).
        max_len: Number of positions to precompute.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=2048, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # [ [0], [1], [2], ..., [max_len-1] ]
        # 2d array of size max_len x 1
        position = torch.arange(max_len).unsqueeze(1)

        # [0, 2, 4, ...]
        exp_term = torch.arange(0, d_model, 2)

        # One frequency per even feature index: 10000^(-i / d_model)
        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))

        # max_len x ceil(d_model / 2)
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For odd d_model there is one odd index fewer than even indices,
        # so drop the surplus frequency instead of failing on a shape mismatch.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])

        # A buffer moves with the module across devices but is not trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` (N x T x D) with positional encodings added.

        Raises:
            ValueError: If T exceeds the number of precomputed positions.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len={self.pe.size(1)} "
                "of the positional encoding"
            )

        # x.shape: N x T x D; the encoding broadcasts over the batch dimension.
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)
        
        
        

In [15]:
class Decoder(nn.Module):
    """Decoder-only transformer mapping token ids to per-token vocabulary scores.

    Pipeline: token embedding -> positional encoding -> ``n_layers``
    transformer blocks -> layer normalization -> linear projection to
    ``vocab_size``.

    Args:
        vocab_size: Number of tokens in the dictionary.
        max_len: Maximum supported sequence length.
        d_k: Per-head projection size inside each block.
        d_model: Embedding / hidden feature dimension.
        n_heads: Number of attention heads per block.
        n_layers: Number of stacked transformer blocks.
        dropout: Dropout probability used by the positional encoding and blocks.
    """

    def __init__(
        self,
        vocab_size,
        max_len,
        d_k,
        d_model,
        n_heads,
        n_layers,
        dropout,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout=dropout)
        self.transformer_blocks = nn.Sequential(
            *(
                TransformerBlock(d_k, d_model, n_heads, max_len, dropout=dropout)
                for _ in range(n_layers)
            )
        )
        self.norm = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask=None):
        """Return scores of shape (N, T, vocab_size) for token ids ``x`` (N, T)."""
        hidden = self.pos_encoding(self.embedding(x))

        # The blocks are iterated by hand rather than called as a Sequential,
        # because every block also needs the padding mask.
        for layer in self.transformer_blocks:
            hidden = layer(hidden, pad_mask)

        return self.out(self.norm(hidden))

In [16]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

# A small demo decoder: 2 layers, 4 heads of size 16, 64-dimensional embeddings.
model = Decoder(
    vocab_size=20_000,
    max_len=1024,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1,
)

# Kept as the last expression of the cell so the notebook displays the model.
model.to(device)

cuda


Decoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (attention): CasualHeadAttention(
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (out): Linear(in_features=64, out_features=64, bias=True)
      )
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
    )
    (1): TransformerBlock(
      (attention): CasualHeadAttention(
        (query): Linear(in_featur

In [20]:
# Smoke test: push a batch of random token ids through the model.
batch_size, nr_words = 8, 512
x = np.random.randint(0, 20_000, size=(batch_size, nr_words))
x_t = torch.tensor(x, device=device)

# Padding mask: the first 256 positions are real tokens, the rest is padding.
mask = np.ones((batch_size, nr_words))
mask[:, 256:] = 0
mask_t = torch.tensor(mask, device=device)

# Without mask
y = model(x_t, pad_mask=None)
print(y.shape)

# With mask
y = model(x_t, pad_mask=mask_t)
print(y.shape)

torch.Size([8, 512, 20000])
torch.Size([8, 512, 20000])
