# Decoder architecture
---
This is the second notebook; it builds heavily on the encoder implementation. The detailed comments live in `encoder.ipynb`, so you should start learning from there.

There are two main differences here:
- `CasualHeadAttention` is a version of the `MultiHeadAttention` class that adds a causal mask (named `casual_mask` in the code). This mask is a lower-triangular matrix of ones, i.e. all the values in the upper-right triangle (above the main diagonal) are equal to 0, and it is applied to the attention scores of the entire input sequence. The general goal is to ensure that the decoder can only see the word that is currently analyzed and the words before it; for example, at word 4 the decoder sees only words 1, 2, 3, and 4.
- The `Decoder` class differs from the `Encoder` class in terms of output size. For every sentence, the output size is $T \times DictionarySize$. For example, if the longest sentence in the batch contains 30 words and the dictionary contains 20,000 words, the network returns a $30 \times 20{,}000$ matrix per sentence (30 positions, each holding one unnormalized score (logit) for every word in the dictionary).

In [2]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import dataset

import numpy as np
import matplotlib.pyplot as plt

In [3]:
class CasualHeadAttention(nn.Module):
    """Multi-head self-attention with a causal (look-back only) mask.

    Every position may attend to itself and to earlier positions only. An
    optional padding mask additionally hides padded key positions.

    The class name and the `casual_mask` buffer keep their original spelling
    so existing callers and saved state dicts keep working; the intended
    meaning is "causal".

    Args:
        d_k: Size of the query/key/value vectors of a single head.
        d_model: Size of the input and output embeddings.
        n_heads: Number of attention heads.
        max_len: Longest sequence length covered by the pre-computed mask.
    """

    def __init__(self, d_k, d_model, n_heads, max_len):
        super().__init__()

        self.d_model = d_model
        self.d_k = d_k
        self.n_heads = n_heads

        self.query = nn.Linear(d_model, d_k * n_heads)
        self.key = nn.Linear(d_model, d_k * n_heads)
        self.value = nn.Linear(d_model, d_k * n_heads)

        self.out = nn.Linear(d_k * n_heads, d_model)

        # Causal mask: lower-triangular ones, 1 = "may attend", 0 = "future".
        # Registered as a buffer so it follows the module across devices.
        cm = torch.tril(torch.ones(max_len, max_len))
        self.register_buffer(
            'casual_mask',
            cm.view(1, 1, max_len, max_len)
        )

    def forward(self, q, k, v, pad_mask=None):
        """Compute masked multi-head attention.

        Args:
            q, k, v: Tensors of shape (N, T, d_model). All three are reshaped
                with the sequence length of `q`, so they must share it.
            pad_mask: Optional (N, T) tensor; entries equal to 0 mark padded
                key positions that must not be attended to.

        Returns:
            Tensor of shape (N, T, d_model).

        Raises:
            RuntimeError: If the sequence is longer than the `max_len` the
                causal mask was built for.
        """
        # Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) V

        q = self.query(q)  # N x T x (h*d_k)
        k = self.key(k)    # N x T x (h*d_k)
        v = self.value(v)  # N x T x (h*d_v), d_v == d_k

        N = q.shape[0]  # batch size
        T = q.shape[1]  # sequence length

        max_len = self.casual_mask.shape[-1]
        if T > max_len:
            raise RuntimeError(
                f"Sequence length {T} exceeds the maximum length {max_len} "
                "supported by the causal mask."
            )

        # Changing shapes (required for matrix multiplication)
        # view: (N, T, h*d_k) -> (N, T, h, d_k)
        # transpose: (N, T, h, d_k) -> (N, h, T, d_k)
        q = q.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        k = k.view(N, T, self.n_heads, self.d_k).transpose(1, 2)
        v = v.view(N, T, self.n_heads, self.d_k).transpose(1, 2)

        # (N, h, T, d_k) x (N, h, d_k, T) -> (N, h, T, T)
        scores = q @ k.transpose(-2, -1) / math.sqrt(self.d_k)

        # max_len is the longest sequence possible, but we only need the
        # longest sequence in the batch, so the causal mask is cropped to T.
        visible = self.casual_mask[:, :, :T, :T] != 0  # (1, 1, T, T)

        dead_rows = None
        if pad_mask is not None:
            # The pad mask is (N, T); add two inner dimensions to broadcast it
            # over heads and query positions.
            visible = visible & (pad_mask[:, None, None, :] != 0)  # (N, 1, T, T)

            # A query whose every reachable key is padding would softmax over
            # -inf only and turn into NaN. Leave such rows unmasked for the
            # softmax and zero their weights afterwards instead.
            dead_rows = ~visible.any(dim=-1, keepdim=True)  # (N, 1, T, 1)
            visible = visible | dead_rows

        # -inf scores become exactly 0 after the softmax
        scores = scores.masked_fill(~visible, float('-inf'))
        attention_weights = F.softmax(scores, dim=-1)

        if dead_rows is not None:
            attention_weights = attention_weights.masked_fill(dead_rows, 0.0)

        A = attention_weights @ v

        # Reshape (N, h, T, d_k) -> (N, T, h, d_k) -> (N, T, h*d_k),
        # i.e. concatenate the heads again
        A = A.transpose(1, 2).contiguous().view(N, T, self.n_heads * self.d_k)

        return self.out(A)
        
        
                

In [4]:
class TransformerBlock(nn.Module):
    """One decoder layer: causal self-attention followed by a feed-forward net.

    Each sub-layer is wrapped in a residual connection followed by a
    LayerNorm; dropout is applied to the output of the block.
    """

    def __init__(self, d_k, d_model, n_heads, max_len, dropout=0.1):
        super().__init__()

        self.attention = CasualHeadAttention(d_k, d_model, n_heads, max_len)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

        # Feed-forward network with a hidden layer four times wider than d_model
        hidden_size = 4 * d_model
        self.ff = nn.Sequential(
            nn.Linear(d_model, hidden_size),
            nn.GELU(),
            nn.Linear(hidden_size, d_model)
        )

    def forward(self, x, pad_mask=None):
        """Run the block on `x`; `pad_mask` is forwarded to the attention."""
        # Self-attention: queries, keys and values all come from x
        attended = self.attention(x, x, x, pad_mask)
        x = self.norm1(x + attended)

        transformed = self.ff(x)
        x = self.norm2(x + transformed)

        return self.dropout(x)


In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to token embeddings.

    Even embedding columns hold sines and odd columns hold cosines of the
    position scaled by geometrically decreasing frequencies.

    Args:
        d_model: Embedding size (odd sizes are supported as well).
        max_len: Longest sequence length the table is pre-computed for.
        dropout: Dropout probability applied after adding the encoding.
    """

    def __init__(self, d_model, max_len=2048, dropout=0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # [ [0], [1], [2], ..., [max_len-1] ]
        # 2d array of size max_len x 1
        position = torch.arange(max_len).unsqueeze(1)

        # [0, 2, 4, ...]
        exp_term = torch.arange(0, d_model, 2)

        div_term = torch.exp(exp_term * (-math.log(10000.0) / d_model))

        # max_len x ceil(d_model / 2)
        angles = position * div_term

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # For an odd d_model there is one cosine column fewer than sine
        # columns, so the angles are trimmed to the number of odd columns.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Buffer: moves with the module across devices but is not trained
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return `x` (N x T x D) with positional encodings added.

        Raises:
            RuntimeError: If T is larger than the pre-computed `max_len`.
        """
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            raise RuntimeError(
                f"Sequence length {seq_len} exceeds the maximum length "
                f"{self.pe.size(1)} of the positional encoding."
            )

        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)
        
        
        

In [6]:
class Decoder(nn.Module):
    """Decoder-only transformer that scores every vocabulary entry per position.

    Pipeline: token embedding -> positional encoding -> `n_layers`
    transformer blocks -> LayerNorm -> linear projection onto the vocabulary.
    """

    def __init__(
        self,
        vocab_size,
        max_len,
        d_k,
        d_model,
        n_heads,
        n_layers,
        dropout,
    ):
        super().__init__()

        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len, dropout=dropout)

        # The blocks are stored in a Sequential container but iterated by hand
        # in forward(), because each block also receives the padding mask.
        self.transformer_blocks = nn.Sequential(*(
            TransformerBlock(d_k, d_model, n_heads, max_len, dropout=dropout)
            for _ in range(n_layers)
        ))

        self.norm = nn.LayerNorm(d_model)
        self.out = nn.Linear(d_model, vocab_size)

    def forward(self, x, pad_mask = None):
        """Map token ids `x` to vocabulary scores; `pad_mask` is optional."""
        hidden = self.pos_encoding(self.embedding(x))

        for block in self.transformer_blocks:
            hidden = block(hidden, pad_mask)

        return self.out(self.norm(hidden))

In [7]:
# Small demo model with a made-up vocabulary size
model = Decoder(
    vocab_size=20_000,
    max_len=1024,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1,
)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)
model.to(device)

cuda


Decoder(
  (embedding): Embedding(20000, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (attention): CasualHeadAttention(
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (out): Linear(in_features=64, out_features=64, bias=True)
      )
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
    )
    (1): TransformerBlock(
      (attention): CasualHeadAttention(
        (query): Linear(in_featur

In [8]:
batch_size, nr_words = 8, 512

# Random token ids stand in for a batch of tokenized sentences
x = np.random.randint(0, 20_000, size=(batch_size, nr_words))
x_t = torch.tensor(x).to(device)

# Padding mask: 1 = keep, 0 = padding. Here the second half of every
# sequence is treated as padding.
mask = np.ones((batch_size, nr_words))
mask[:, 256:] = 0
mask_t = torch.tensor(mask).to(device)

# Forward pass without the padding mask ...
y = model(x_t)
print(y.shape)

# ... and with it
y = model(x_t, mask_t)
print(y.shape)

torch.Size([8, 512, 20000])
torch.Size([8, 512, 20000])


### Training part

In [9]:
from transformers import AutoTokenizer, DataCollatorWithPadding
# The tokenizer is loaded from a pretrained checkpoint; the decoder itself is
# created and trained from scratch further down.
checkpoint = "distilbert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [10]:
from datasets import load_dataset
# SST-2 configuration of the GLUE benchmark
raw_datasets = load_dataset(path="glue", name="sst2")
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [11]:
def tokenize_fn(batch):
    """Tokenize the `sentence` column of a batch with truncation enabled."""
    sentences = batch['sentence']
    return tokenizer(sentences, truncation=True)


# Tokenize all splits batch-wise; the collator is created here and later
# handed to the data loaders.
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [12]:
# Drop the columns that are not fed to the model
unused_columns = ["sentence", "label", "idx"]
tokenized_datasets = tokenized_datasets.remove_columns(unused_columns)

In [13]:
from torch.utils.data import DataLoader

# Settings shared by both loaders
loader_kwargs = dict(batch_size=32, collate_fn=data_collator)

# Training batches are reshuffled, validation batches keep their order
train_loader = DataLoader(
    tokenized_datasets['train'],
    shuffle=True,
    **loader_kwargs,
)

valid_loader = DataLoader(
    tokenized_datasets['validation'],
    shuffle=False,
    **loader_kwargs,
)



In [14]:
# Take a single batch and show the shape of every tensor in it
batch = next(iter(train_loader))
for k, v in batch.items():
    print('k:', k, 'n.shape:', v.shape)

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


k: input_ids n.shape: torch.Size([32, 55])
k: attention_mask n.shape: torch.Size([32, 55])


In [15]:
# The padding token id is passed to CrossEntropyLoss below,
# so that padded positions are ignored when the loss is computed.
print('Padding token:', tokenizer.pad_token)
print('Padding token id:', tokenizer.pad_token_id)


Padding token: [PAD]
Padding token id: 0


In [16]:
# Model that is actually trained; it replaces the demo model created earlier.
# Its vocabulary and maximum sequence length both come from the tokenizer.
# `model_max_length` is the tokenizer's own per-instance limit and replaces
# the per-checkpoint lookup table `max_model_input_sizes[checkpoint]`.
# NOTE(review): should equal the previous value for this checkpoint; confirm.
model = Decoder(
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.model_max_length,
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout=0.1,
)

model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (attention): CasualHeadAttention(
        (query): Linear(in_features=64, out_features=64, bias=True)
        (key): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (out): Linear(in_features=64, out_features=64, bias=True)
      )
      (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
      (ff): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=256, out_features=64, bias=True)
      )
    )
    (1): TransformerBlock(
      (attention): CasualHeadAttention(
        (query): Linear(in_featur

In [17]:
# Targets equal to the padding id are skipped by the loss
pad_token_id = tokenizer.pad_token_id
criterion = nn.CrossEntropyLoss(ignore_index=pad_token_id)

# Adam with its default hyper-parameters
optimizer = torch.optim.Adam(params=model.parameters())


#### Predictions
Here is some extra code that uses the model to generate text.
We are going to check how the model behaves after each epoch.


In [19]:
def generate(model, tokenizer, device, prompt = "I'm ", max_output_length = 160):
    """Greedily extend `prompt` with tokens predicted by `model`.

    Generation always runs in evaluation mode (dropout disabled) and without
    gradient tracking; the model's previous train/eval mode is restored
    before returning.

    Args:
        model: Callable as `model(input_ids, mask)`, returning scores of
            shape (N, T, vocab_size).
        tokenizer: Tokenizer used to encode the prompt and decode the result.
        device: Device on which the input tensors are placed.
        prompt: Text to continue (a single string).
        max_output_length: Maximum number of tokens to generate.

    Returns:
        The decoded prompt followed by the generated tokens.
    """
    tokenized_prompt = tokenizer(prompt, return_tensors="pt")

    # Drop the last token of the encoded prompt (presumably the closing
    # separator added by the tokenizer) so the model can continue the text.
    input_ids = tokenized_prompt["input_ids"][:, :-1].to(device)
    mask = tokenized_prompt["attention_mask"][:, :-1].to(device)

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for _ in range(max_output_length):
                outputs = model(input_ids, mask)

                # Greedy decoding: take the best-scoring token at the last position
                prediction_id = torch.argmax(outputs[:, -1, :], dim=-1)

                input_ids = torch.hstack((input_ids, prediction_id[:, None]))
                # No padding is present while generating, so the mask is all ones
                mask = torch.ones_like(input_ids)

                # Stop as soon as the separator token is produced
                if prediction_id.item() == tokenizer.sep_token_id:
                    break
    finally:
        model.train(was_training)

    return tokenizer.decode(input_ids[0])
    
# Sample a short continuation from the current model
generate(model=model, tokenizer=tokenizer, device=device, max_output_length=10)

"[CLS] I'm Extreme の gone vein economically Kemp traveling beverages cryingGC"

In [20]:
from datetime import datetime

In [21]:
def _shifted_targets(input_ids, pad_token_id):
    """Return next-token targets: `input_ids` moved one position to the left.

    The last position has no successor, so it is filled with the padding id.
    """
    targets = torch.roll(input_ids, shifts=-1, dims=-1)
    targets[:, -1] = pad_token_id
    return targets


def train(model, criterion, optimizer, train_loader, epochs, valid_loader = None, print_every = 1):
    """Train `model` for next-token prediction and report progress.

    Relies on the notebook-level names `device`, `tokenizer` and `generate`.

    Args:
        model: Called as `model(input_ids, attention_mask)`; must return
            scores of shape N x T x V.
        criterion: Loss applied to (N x V x T) scores and (N x T) targets.
        optimizer: Optimizer that updates the model parameters.
        train_loader: Iterable of batches with `input_ids` and
            `attention_mask` tensors.
        epochs: Number of passes over `train_loader`.
        valid_loader: Optional loader evaluated after every epoch.
        print_every: Print a summary line (with a generated sample) every
            `print_every` epochs.

    Returns:
        Tuple `(train_losses, valid_losses)` of numpy arrays with one entry
        per epoch; validation entries are NaN when no validation data is
        available.
    """
    train_losses = np.zeros(epochs)
    valid_losses = np.zeros(epochs)

    for epoch in range(epochs):
        model.train()
        t0 = datetime.now()
        train_loss = []
        print ('Training...    \r', end = '')
        for batch in train_loader:
            batch = {k: v.to(device) for k, v in batch.items()}

            optimizer.zero_grad()

            # Targets are just the inputs, shifted by one position (backward)
            targets = _shifted_targets(batch['input_ids'], tokenizer.pad_token_id)

            outputs = model(batch['input_ids'], batch['attention_mask'])

            # The output shape is N x T x V (batch size, sequence length,
            # vocab size) and the targets are N x T. CrossEntropyLoss expects
            # the scores as N x V x T, so we need to transpose.
            loss = criterion(outputs.transpose(2, 1), targets)

            loss.backward()
            optimizer.step()

            train_loss.append(loss.item())

        train_loss = np.mean(train_loss)

        if valid_loader is not None:
            print ('Validating...    \r', end = '')
            model.eval()
            valid_loss = 0
            n_valid = 0
            # No gradients are needed for evaluation
            with torch.no_grad():
                for batch in valid_loader:
                    batch = {k: v.to(device) for k, v in batch.items()}

                    targets = _shifted_targets(batch['input_ids'], tokenizer.pad_token_id)

                    outputs = model(batch['input_ids'], batch['attention_mask'])
                    loss = criterion(outputs.transpose(2, 1), targets)

                    # Weight every batch loss by the number of sentences in it
                    valid_loss += loss.item()*batch["input_ids"].size(0)
                    n_valid += batch["input_ids"].size(0)

            # An empty validation loader yields NaN instead of a division by zero
            valid_loss = valid_loss / n_valid if n_valid > 0 else np.nan
        else:
            valid_loss = np.nan

        train_losses[epoch] = train_loss
        valid_losses[epoch] = valid_loss

        if epoch%print_every == 0:
            # The sample is only needed for the printed line. Generate it in
            # eval mode so dropout is off even when no validation pass ran;
            # the next epoch switches back to train mode.
            model.eval()
            with torch.no_grad():
                generated_text = generate(model, tokenizer, device, prompt = "I'm ", max_output_length = 10)

            t1 = datetime.now() - t0
            minutes, seconds = divmod(t1.total_seconds(), 60)
            formatted_time = "{:02}:{:02}".format(int(minutes), int(seconds))

            print (f'Epoch: {epoch}: Train loss: {train_loss:.2f}, Valid loss: {valid_loss:.2f}, Duration: {formatted_time}min, Text: {generated_text}')

    return train_losses, valid_losses



In [22]:
# Full training run; a summary line is printed after every epoch
train_losses, valid_losses = train(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    train_loader=train_loader,
    epochs=50,
    valid_loader=valid_loader,
)

Epoch: 0: Train loss: 5.94, Valid loss: 5.71, Duration: 00:39min, Text: [CLS] I'm not be a movie [SEP]
Epoch: 1: Train loss: 4.98, Valid loss: 5.70, Duration: 00:39min, Text: [CLS] I'm not only to be a few laughs, but it
Epoch: 2: Train loss: 4.64, Valid loss: 5.75, Duration: 00:39min, Text: [CLS] I'm giving a movie that's a movie that is
Epoch: 3: Train loss: 4.44, Valid loss: 5.78, Duration: 00:39min, Text: [CLS] I'm not a movie that is a movie that is a
Epoch: 4: Train loss: 4.31, Valid loss: 5.84, Duration: 00:39min, Text: [CLS] I'm not a lot of a lot of a lot of
Epoch: 5: Train loss: 4.20, Valid loss: 5.89, Duration: 00:39min, Text: [CLS] I'm not a movie that's been able to be
Epoch: 6: Train loss: 4.10, Valid loss: 5.92, Duration: 00:39min, Text: [CLS] I'm giving the film's most sincere and the film
Epoch: 7: Train loss: 4.02, Valid loss: 5.98, Duration: 00:39min, Text: [CLS] I'm not a movie that's a pianist, but
Epoch: 8: Train loss: 3.95, Valid loss: 6.05, Duration: 00:39min, T

#### Let's test how it works

In [23]:
model.eval()

# Loader that yields one validation sentence at a time
one_elem_loader = DataLoader(
    dataset=tokenized_datasets["validation"],
    batch_size=1,
    collate_fn=data_collator,
)

# Only the first validation sentence is inspected
batch = next(iter(one_elem_loader))
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(batch["input_ids"], batch["attention_mask"])

# Best-scoring token for every position
prediction_ids = torch.argmax(outputs, dim=-1)

print(outputs.shape)
print(prediction_ids)

torch.Size([1, 12, 28996])
tensor([[ 170,  112,  188,  170, 2523, 1105, 9998, 6276, 2025,  102,  102,  102]],
       device='cuda:0')


In [24]:
# Decode the token ids back to text to compare the input with the prediction
input_text = tokenizer.decode(batch["input_ids"][0])
predicted_text = tokenizer.decode(prediction_ids[0])

print('Input:', input_text)
print('Output:', predicted_text)


Input: [CLS] it's a charming and often affecting journey. [SEP]
Output: a's a movie and intelligent funny study [SEP] [SEP] [SEP]


##### Using the model to generate text

In [25]:
# Continue two different prompts with the trained model
first_sample = generate(model, tokenizer, device, prompt="I'm ")
second_sample = generate(model, tokenizer, device, prompt="Transformers are ")

print(first_sample)
print('\n')
print(second_sample)

[CLS] I'm not a cheat in the end [SEP]


[CLS] Transformers aren't a lioness trouble & tontollah in a cold mosque. [SEP]
