# Building a Decoder Transformer
* simplified version of the original transformer
* sequence generation not requiring an encoder: text generation, text completion

Particularities:
* Masked multi-head self-attention
* Upper triangular mask: prevent attending to future tokens, can only observe previous tokens
* Decoder head: linear + softmax activation to predict the most likely token from the entire vocabulary

## Masked self-attention
Key to causal behaviour. Triangular mask:
```
Orange   1 0 0 0 0
is       1 1 0 0 0
my       1 1 1 0 0
favorite 1 1 1 1 0
fruit    1 1 1 1 1
```
For example: in "Orange is my favorite fruit", the token "favorite" only pays attention to "Orange", "is", "my", "favorite" (4th row in the matrix). The model would learn that the probable next word is "fruit".

Same multi-head attention as in encoder transformer, only the mask is different.

In [None]:
# Causal mask: the lower triangle (diagonal included) is True, so position i
# may attend to positions 0..i only; the leading 1 lets it broadcast.
self_attention_mask = torch.tril(
    torch.ones(1, sequence_length, sequence_length)
).bool()
#...
output = decoder(input_sequence, self_attention_mask)

## Transformer body and head

In [20]:
import torch.nn as nn
import torch

class TransformerDecoder(nn.Module):
    """Decoder-only transformer: embedding, positional encoding, a stack of
    decoder layers and a linear head producing next-token log-probabilities.

    Args:
        vocab_size: number of tokens; size of the embedding table and of the
            output distribution.
        d_model: embedding / hidden width.
        num_layers: number of stacked ``DecoderLayer`` modules.
        num_heads: attention heads per layer.
        d_ff: hidden width of each layer's feed-forward sub-layer.
        dropout: dropout probability passed to the positional encoding and to
            every decoder layer.
        max_sequence_length: longest sequence the positional-encoding table
            is built for.
    """

    def __init__(self, vocab_size, d_model, num_layers, num_heads, d_ff, dropout, max_sequence_length):
        super(TransformerDecoder, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.positional_encoding = PositionalEncoding(d_model, dropout, max_sequence_length)
        self.layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout) for _ in range(num_layers)])

        # Linear head mapping each hidden state to one score per vocabulary token
        self.fc = nn.Linear(d_model, vocab_size)

    def forward(self, x, self_mask):
        """Return log-probabilities over the vocabulary for every position.

        Args:
            x: token ids, used batch-first as ``(batch, seq_len)``.
            self_mask: attention mask forwarded unchanged to every layer;
                entries equal to 0 are blocked.

        Returns:
            Tensor of shape ``(batch, seq_len, vocab_size)`` holding
            log-softmax scores along the last dimension.
        """
        x = self.embedding(x)  # (batch, seq_len, d_model)

        # PositionalEncoding is sequence-first: it slices its table by
        # x.size(0). Feeding it the batch-first tensor directly would pick the
        # encoding by batch index and give every token in a sample the same
        # positional vector, so swap the axes around the call.
        x = self.positional_encoding(x.transpose(0, 1)).transpose(0, 1)

        for layer in self.layers:
            x = layer(x, self_mask)

        # Project to vocabulary logits and normalise in log space.
        # nn.functional is used so this cell does not rely on an ``F`` alias
        # imported elsewhere.
        x = self.fc(x)
        return nn.functional.log_softmax(x, dim=-1)

Test using a random sequence. The output is the next-token log-probabilities (log-softmax over the vocabulary) for every position.

In [31]:
# Input shape used by the smoke test
batch_size = 8          # sequences per batch
sequence_length = 64    # tokens per sequence; also the positional-encoding table length

# Model dimensions
vocab_size = 10000      # embedding table size and output distribution size
d_model = 512           # embedding / hidden width
num_heads = 8           # attention heads per layer
num_layers = 6          # stacked decoder layers
d_ff = 2048             # feed-forward hidden width
dropout = 0.1           # dropout probability

# Not referenced by the decoder cells visible here; kept for other cells
num_classes = 3

In [32]:
# Random token ids, shape (batch_size, sequence_length)
input_sequence = torch.randint(low=0, high=vocab_size, size=(batch_size, sequence_length))

# Causal attention mask: True on and below the diagonal, False above it,
# so each position can only look at itself and earlier positions
self_attention_mask = torch.tril(torch.ones(1, sequence_length, sequence_length)).bool()

# Build the decoder-only transformer from the hyperparameters above
decoder = TransformerDecoder(
    vocab_size=vocab_size,
    d_model=d_model,
    num_layers=num_layers,
    num_heads=num_heads,
    d_ff=d_ff,
    dropout=dropout,
    max_sequence_length=sequence_length,
)

# Forward pass, then show the output shape and values
output = decoder(input_sequence, self_attention_mask)
print(output.shape)
print(output)

torch.Size([8, 64, 10000])
tensor([[[ -8.4589, -10.2904,  -9.2712,  ..., -10.3509,  -8.9835,  -9.8259],
         [ -8.9041, -10.2764,  -9.1243,  ...,  -9.5853,  -9.0880,  -9.2692],
         [ -8.9674,  -9.3047,  -9.6107,  ..., -10.2786,  -8.6873,  -9.6084],
         ...,
         [ -9.2677,  -9.5564,  -9.5862,  ..., -10.4316, -10.1462,  -8.9992],
         [ -8.9629, -10.0210,  -8.4622,  ...,  -9.3217,  -9.6332,  -9.2137],
         [ -9.6159, -10.1622,  -9.7675,  ...,  -9.4639,  -9.5766, -10.7483]],

        [[ -9.4188, -10.2605,  -9.3097,  ...,  -8.8197,  -9.3785,  -8.7900],
         [ -9.4871,  -9.5820,  -8.6926,  ...,  -9.2533,  -9.1897,  -9.1286],
         [ -9.1479,  -9.5352,  -8.5771,  ...,  -9.4810,  -8.9143, -10.1448],
         ...,
         [ -9.4951,  -9.8587,  -9.3739,  ...,  -9.5837,  -8.5317,  -8.6408],
         [-10.0069, -10.0204,  -9.6759,  ...,  -9.0715,  -9.4575,  -9.5549],
         [ -8.3765,  -9.9367,  -9.7127,  ...,  -9.6211, -10.2719,  -9.4033]],

        [[ -8.649

In [18]:
from torch import nn, Tensor
import math

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to embeddings, then apply dropout.

    The table is precomputed once for ``max_len`` positions and stored as the
    non-trainable buffer ``pe`` of shape ``[max_len, 1, d_model]``. Even
    feature indices hold sines and odd indices hold cosines.

    Args:
        d_model: embedding width (even or odd).
        dropout: dropout probability applied after adding the encoding.
        max_len: number of positions to precompute.
    """

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # [max_len, ceil(d_model / 2)]

        pe = torch.zeros(max_len, 1, d_model)
        pe[:, 0, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # column, so trim the angle table; for even d_model this is a no-op.
        pe[:, 0, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x: Tensor) -> Tensor:
        """
        Arguments:
            x: Tensor, shape ``[seq_len, batch_size, embedding_dim]``

        Returns:
            Tensor of the same shape: ``x`` plus the encoding of each
            position along dim 0 (broadcast over the batch), after dropout.
        """
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

In [25]:
class DecoderLayer(nn.Module):
    """One decoder block: masked self-attention followed by a feed-forward
    sub-layer, each wrapped in dropout, a residual connection and LayerNorm
    applied after the addition."""

    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        # Sub-layers
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = FeedForwardSubLayer(d_model, d_ff)
        # One normalisation per residual branch
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        # Single dropout module shared by both branches
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Apply the attention and feed-forward sub-layers to ``x``.

        ``mask`` is handed to the self-attention module as-is.
        """
        # Self-attention branch: query, key and value are all ``x``
        attended = self.dropout(self.self_attn(x, x, x, mask))
        hidden = self.norm1(x + attended)

        # Position-wise feed-forward branch
        transformed = self.dropout(self.feed_forward(hidden))
        return self.norm2(hidden + transformed)

In [27]:
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention over batch-first inputs.

    Inputs are projected to queries, keys and values, split into
    ``num_heads`` heads of width ``head_dim``, attended independently per
    head, then merged and passed through a final linear projection.
    """

    def __init__(self, d_model, num_heads):
        """
        d_model: width of the input and output embeddings
        num_heads: number of attention heads, each handling embeddings of size head_dim

        Raises ValueError when d_model is not divisible by num_heads.
        """
        super(MultiHeadAttention, self).__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.d_model = d_model
        self.head_dim = d_model // num_heads

        # Three linear transformations producing queries, keys and values
        self.query_linear = nn.Linear(d_model, d_model)
        self.key_linear = nn.Linear(d_model, d_model)
        self.value_linear = nn.Linear(d_model, d_model)

        # Linear transformation applied to the concatenated head outputs
        self.output_linear = nn.Linear(d_model, d_model)

    def split_heads(self, x, batch_size):
        """Split the inputs across attention heads.

        Reshapes ``(batch, seq_len, d_model)`` into
        ``(batch * num_heads, seq_len, head_dim)``.
        """
        x = x.view(batch_size, -1, self.num_heads, self.head_dim)
        return x.permute(0, 2, 1, 3).contiguous().view(batch_size * self.num_heads, -1, self.head_dim)

    def compute_attention(self, query, key, mask=None):
        """Compute attention weights inside each head.

        query, key: tensors of shape ``(batch * num_heads, seq_len, head_dim)``.
        mask: optional tensor broadcastable to
            ``(batch * num_heads, query_len, key_len)``; positions where it
            equals 0 are blocked.

        Returns weights of shape ``(batch * num_heads, query_len, key_len)``
        that sum to 1 over the last dimension.
        """
        # Swap only the last two axes of the keys so each batch/head entry is
        # matched with its own keys. Permuting the batch axis into the product
        # only type-checks when batch * num_heads == seq_len and mixes samples.
        scores = torch.matmul(query, key.transpose(-2, -1))
        # Scale by sqrt(head_dim) to keep the softmax out of its saturated region
        scores = scores / (self.head_dim ** 0.5)
        if mask is not None:
            # A large negative value (rather than -inf) keeps fully masked rows finite
            scores = scores.masked_fill(mask == 0, float("-1e9"))
        # Normalize attention scores into attention weights
        attention_weights = F.softmax(scores, dim=-1)
        return attention_weights

    def forward(self, query, key, value, mask=None):
        """Attend from ``query`` to ``key``/``value``.

        query, key, value: tensors of shape ``(batch, seq_len, d_model)``.
        mask: optional mask forwarded to ``compute_attention``.

        Returns a tensor of shape ``(batch, query_len, d_model)``.
        """
        batch_size = query.size(0)

        # Project the inputs and split each projection across heads
        query = self.split_heads(self.query_linear(query), batch_size)
        key = self.split_heads(self.key_linear(key), batch_size)
        value = self.split_heads(self.value_linear(value), batch_size)

        # Per-head attention weights
        attention_weights = self.compute_attention(query, key, mask)

        # Weight the values, merge the heads back into d_model and project
        output = torch.matmul(attention_weights, value)
        output = output.view(batch_size, self.num_heads, -1, self.head_dim).permute(0, 2, 1, 3).contiguous().view(batch_size, -1, self.d_model)
        return self.output_linear(output)

In [29]:
import torch.nn as nn

class FeedForwardSubLayer(nn.Module):
    """Two-layer feed-forward network: expand from ``d_model`` to ``d_ff``,
    apply ReLU, then project back to ``d_model``."""

    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff)   # expansion
        self.fc2 = nn.Linear(d_ff, d_model)   # projection back
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``fc2(relu(fc1(x)))``; the last dimension stays ``d_model``."""
        expanded = self.fc1(x)
        activated = self.relu(expanded)
        return self.fc2(activated)