# <div style="text-align: center;">
#     <img src="./transformer.webp" alt="Alt text" width="500">
# </div>

# 1. Encoder Block

An encoder block consists of:
- Input Embeddings
- Multi-head Self-attention Layer
- Feedforward Network (MLP)

## 1.1 Input Embeddings

### 1.1.1 Create input sequence and input embeddings
Create a vocabulary embedding table and look up the embedding of each token. **This embedding table is trainable.**

In [489]:
import torch
import torch.nn as nn

# toy source sentence with hand-picked token ids; in practice a tokenizer
# converts the input sequence to tokens
input_seq = ['<sos>', 'hello', 'world', '<eos>']
input_tokens = [0, 87, 101, 999]

# imaginary vocabulary of 1000 words, each embedded as a 64-dimensional vector
vocab_size, embedding_dim = 1000, 64

# trainable lookup table: one embedding_dim-sized row per vocabulary id
vocab_embedding_table = nn.Embedding(
    num_embeddings=vocab_size,
    embedding_dim=embedding_dim,
)

# fetch one embedding row per input token
token_ids = torch.LongTensor(input_tokens)
input_embeddings = vocab_embedding_table(token_ids)

### 1.1.2 Create positional encodings

In [490]:
import math

# one row of positional encodings per position of the input sequence
max_seq_len = len(input_seq)
positional_encodings = torch.zeros(max_seq_len, embedding_dim)

# sinusoidal encodings: columns 2k and 2k+1 share the angle
# pos / 10000^(2k / embedding_dim); the even column stores its sine and the
# odd column its cosine
for pos in range(max_seq_len):
    for col in range(embedding_dim):
        pair_start = col - col % 2
        angle = pos / (10000 ** (pair_start / embedding_dim))
        wave = math.cos if col % 2 else math.sin
        positional_encodings[pos, col] = wave(angle)

### 1.1.3 Add input embeddings and positional encodings

In [491]:
# inject position information: element-wise sum of the token embeddings and
# the positional encodings
input_embeddings = input_embeddings + positional_encodings

## 1.2 Multi-Head Self-Attention Layer

In [492]:
# self-attention: queries, keys and values all start from the same tensor
Q = K = V = input_embeddings

### 1.2.1 Create query, key, value matrices

In [493]:
# three independent bias-free projections, created in the order query, key, value
W_q, W_k, W_v = (
    nn.Linear(embedding_dim, embedding_dim, bias=False) for _ in range(3)
)

# project the position-aware embeddings into query / key / value space
Q, K, V = W_q(input_embeddings), W_k(input_embeddings), W_v(input_embeddings)

print(Q.shape, K.shape, V.shape)

torch.Size([4, 64]) torch.Size([4, 64]) torch.Size([4, 64])


### 1.2.2 Convert Shapes As Multi-Head

In [494]:
head_num = 2 # for example
head_dim = embedding_dim // head_num  # width of the slice each head works on

# Split the embedding axis into heads FIRST, then move the head axis to the front:
#   [seq, embedding] -> [seq, head, head_dim] -> [head, seq, head_dim]
# Viewing the tensor directly as [head, seq, head_dim] would only cut the flat
# buffer into consecutive chunks, so every "head" would receive the complete
# vectors of a subset of tokens instead of one slice of every token. It would
# also not be the inverse of the transpose(0, 1) + view used to concatenate
# the heads again after attention.
Q = Q.view(len(input_seq), head_num, head_dim).transpose(0, 1)
K = K.view(len(input_seq), head_num, head_dim).transpose(0, 1)
V = V.view(len(input_seq), head_num, head_dim).transpose(0, 1)

print(Q.shape, K.shape, V.shape)

torch.Size([2, 4, 32]) torch.Size([2, 4, 32]) torch.Size([2, 4, 32])


### 1.2.3 Calculate Attentions
This includes **matrix multiplication**, **scaling**, and **softmax**.

In [495]:
# raw attention scores: per head, the dot product of every query with every key
attention_scores = Q @ K.transpose(1, 2)

# scale the scores by the square root of the per-head dimension
scale = math.sqrt(embedding_dim // head_num)
attention_scores = attention_scores / scale

# softmax over the key axis turns each row of scores into attention weights
QK = attention_scores.softmax(dim=-1)

print(QK.shape)

torch.Size([2, 4, 4])


### 1.2.4 Multiply Values

In [496]:
# per head: attention-weighted sum of the value vectors
QKV = QK @ V

print(QKV.shape)

torch.Size([2, 4, 32])


### 1.2.5 Concatenating the Heads & Linear Projection

In [497]:
# output projection applied to the concatenated heads
W_o = nn.Linear(embedding_dim, embedding_dim, bias=False)

# bring the sequence axis to the front, then flatten the (head, head_dim) axes
# back into a single embedding axis; contiguous() is required before view()
# because transpose() only changes strides
concatenated_heads = QKV.transpose(0, 1).contiguous().view(len(input_seq), embedding_dim)

multi_head_output = W_o(concatenated_heads)

### 1.2.6 Add & Norm
- Add residual connection
- Layernorm
- Dropout

In [498]:
# build the sub-layer modules first so the data flow below reads in one line
attention_norm = nn.LayerNorm(embedding_dim)
attention_dropout = nn.Dropout(0.3)

# residual connection -> layer normalization -> dropout
# NOTE(review): the original Transformer paper applies dropout to the sub-layer
# output before the residual add; here it is applied after LayerNorm - confirm
# this ordering is intended.
output = attention_dropout(attention_norm(input_embeddings + multi_head_output))

## 1.3 FeedForward Network (MLP)
**Two linear layers with a ReLU in between.** As in the attention sub-layer, they are followed by a residual connection, layernorm, and dropout.

In [499]:
# feedforward network: expand to 4x the embedding width, ReLU, project back
W_ff_1 = nn.Linear(embedding_dim, 4 * embedding_dim)
W_ff_2 = nn.Linear(4 * embedding_dim, embedding_dim)
feed_forward = nn.Sequential(W_ff_1, nn.ReLU(), W_ff_2)

# residual connection around the feedforward network
ff_output = output + feed_forward(output)

# layer normalization followed by dropout
ff_norm = nn.LayerNorm(embedding_dim)
ff_dropout = nn.Dropout(0.3)
ff_output = ff_dropout(ff_norm(ff_output))

# the result of the second sub-layer is the output of the encoder block
encoder_output = ff_output

# 2. Decoder Block

## 2.1 Output Embeddings

### 2.1.1 Create output embeddings

In [500]:
# toy target sentence with hand-picked token ids; in practice a tokenizer
# converts the output sequence to tokens
# NOTE(review): id 101 is also the id given to 'world' in the encoder example,
# so the ids are illustrative only.
output_seq = ['<sos>', 'I', 'am', 'here', '<eos>']
output_tokens = [0, 101, 102, 103, 999]

# embed the target tokens with the same vocab_embedding_table lookup
target_ids = torch.LongTensor(output_tokens)
output_embeddings = vocab_embedding_table(target_ids)

### 2.1.2 Create positional encodings

In [501]:
# one row of positional encodings per position of the output sequence
target_len = len(output_seq)
positional_encodings = torch.zeros(target_len, embedding_dim)

# even columns hold sin(angle), odd columns hold cos(angle); an odd column
# reuses the exponent of the even column just before it
for pos in range(target_len):
    for col in range(embedding_dim):
        is_even = col % 2 == 0
        exponent = (col if is_even else col - 1) / embedding_dim
        angle = pos / (10000 ** exponent)
        positional_encodings[pos, col] = math.sin(angle) if is_even else math.cos(angle)

### 2.1.3 Add output embeddings and positional embeddings

In [502]:
# inject position information into the target token embeddings (element-wise sum)
output_embeddings = output_embeddings + positional_encodings

# one row per target token, embedding_dim columns
print(output_embeddings.shape)

torch.Size([5, 64])


## 2.2 Multi-Head Self-Attention and Cross-Attention

### 2.2.1 Masked Self-Attention Layer
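Apply multi-head self-attention to the output tokens. **Note that a causal (lower-triangular) mask is now applied to the QK scores before the softmax**, so each position can only attend to itself and to earlier positions.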

In [503]:
# bias-free query / key / value projections for the decoder's masked self-attention
W_q = nn.Linear(embedding_dim, embedding_dim, bias=False)
W_k = nn.Linear(embedding_dim, embedding_dim, bias=False)
W_v = nn.Linear(embedding_dim, embedding_dim, bias=False)

Q = W_q(output_embeddings)
K = W_k(output_embeddings)
V = W_v(output_embeddings)

# Split the embedding axis into heads FIRST, then move the head axis to the front:
#   [seq, embedding] -> [seq, head, head_dim] -> [head, seq, head_dim]
# Viewing directly as [head, seq, head_dim] would cut the flat buffer into
# consecutive chunks, so the rows of each head would no longer be token
# positions; the causal mask below would then not line up with the tokens,
# and the transpose(0, 1) + view merge further down would not undo the split.
Q = Q.view(len(output_seq), head_num, embedding_dim//head_num).transpose(0, 1)
K = K.view(len(output_seq), head_num, embedding_dim//head_num).transpose(0, 1)
V = V.view(len(output_seq), head_num, embedding_dim//head_num).transpose(0, 1)

# per head: dot product of every query position with every key position
QK = torch.matmul(Q, K.transpose(1, 2))

# generate mask: lower-triangular ones, i.e. query position i may look at key
# positions <= i; unsqueeze(0) lets the mask broadcast over the head axis
mask = torch.tril(torch.ones(len(output_seq), len(output_seq))).unsqueeze(0)
# apply mask: masked scores become -inf and therefore get zero weight after softmax
QK = QK.masked_fill(mask == 0, -float('inf'))

# scaling (-inf entries stay -inf)
QK = QK / math.sqrt(embedding_dim//head_num)

# softmax over the key axis
QK = torch.softmax(QK, dim=-1)

# attention-weighted sum of the values
QKV = torch.matmul(QK, V)

## transpose and reshape: [head, seq, head_dim] -> [seq, head, head_dim] -> [seq, embedding]
QKV = QKV.transpose(0, 1).contiguous()
QKV = QKV.view(len(output_seq), embedding_dim)

# linear projection
W_o = nn.Linear(embedding_dim, embedding_dim, bias=False)
QKV = W_o(QKV)

## add residual connection
QKV = output_embeddings + QKV

## layer normalization, then dropout
QKV = nn.LayerNorm(embedding_dim)(QKV)
QKV = nn.Dropout(0.3)(QKV)

self_attention_output = QKV

print(self_attention_output.shape)

torch.Size([5, 64])


### 2.2.2 Cross-Attention

Now the output of the first (masked self-attention) layer provides only the queries of the second attention layer. The output of the encoder block provides the keys and values.

In [504]:
# cross-attention: queries come from the decoder's self-attention output,
# keys and values both come from the encoder output
Q = self_attention_output
K = V = encoder_output

print(Q.shape, K.shape, V.shape)

torch.Size([5, 64]) torch.Size([4, 64]) torch.Size([4, 64])


Continue the multi-head cross-attention

In [505]:
# bias-free query / key / value projections for the cross-attention layer
W_q = nn.Linear(embedding_dim, embedding_dim, bias=False)
W_k = nn.Linear(embedding_dim, embedding_dim, bias=False)
W_v = nn.Linear(embedding_dim, embedding_dim, bias=False)

# queries are projected from the decoder side, keys/values from the encoder side
Q = W_q(Q)
K = W_k(K)
V = W_v(V)

# Split the embedding axis into heads FIRST, then move the head axis to the front:
#   [seq, embedding] -> [seq, head, head_dim] -> [head, seq, head_dim]
# Viewing directly as [head, seq, head_dim] would cut the flat buffer into
# consecutive chunks, handing each head the complete vectors of a subset of
# tokens instead of one slice of every token. Q keeps the output sequence
# length, K and V keep the input sequence length.
Q = Q.view(len(output_seq), head_num, embedding_dim//head_num).transpose(0, 1)
K = K.view(len(input_seq), head_num, embedding_dim//head_num).transpose(0, 1)
V = V.view(len(input_seq), head_num, embedding_dim//head_num).transpose(0, 1)

print(Q.shape, K.shape, V.shape)

torch.Size([2, 5, 32]) torch.Size([2, 4, 32]) torch.Size([2, 4, 32])


In [506]:
# scaled dot-product scores between every decoder query and every encoder key
# (no mask is applied in this layer)
cross_scores = (Q @ K.transpose(1, 2)) / math.sqrt(embedding_dim // head_num)

# softmax over the encoder positions
QK = cross_scores.softmax(dim=-1)

print(QK.shape)

# attention-weighted sum of the encoder values
QKV = QK @ V

print(QKV.shape)

torch.Size([2, 5, 4])
torch.Size([2, 5, 32])


Concatenating Heads & Linear Projection

In [507]:
# output projection applied to the concatenated heads
W_o = nn.Linear(embedding_dim, embedding_dim, bias=False)

# [head, seq, head_dim] -> [seq, head, head_dim] -> [seq, embedding];
# contiguous() is required before view() because transpose() only changes strides
concatenated_heads = QKV.transpose(0, 1).contiguous().view(len(output_seq), embedding_dim)

QKV = W_o(concatenated_heads)

Add & Norm

In [508]:
# add residual connection
QKV = self_attention_output + QKV

# layer normalization
QKV = nn.LayerNorm(embedding_dim)(QKV)
QKV = nn.Dropout(0.3)(QKV)

## 2.3 Feedforward Layer (MLP)

In [509]:
# feedforward network: expand to 4x the embedding width, ReLU, project back
W_ff_1 = nn.Linear(embedding_dim, 4 * embedding_dim)
W_ff_2 = nn.Linear(4 * embedding_dim, embedding_dim)
activation = nn.ReLU()

# feedforward (linear -> relu -> linear) plus the residual connection
hidden = activation(W_ff_1(QKV))
output = QKV + W_ff_2(hidden)

# layer normalization, then dropout
decoder_norm = nn.LayerNorm(embedding_dim)
decoder_dropout = nn.Dropout(0.3)
output = decoder_dropout(decoder_norm(output))

print(output.shape)

torch.Size([5, 64])


## 2.4 Decoding Layer
A final linear layer and a softmax.

In [510]:
# final projection from the embedding width to one score per vocabulary entry
W_dec = nn.Linear(embedding_dim, vocab_size)

# softmax over the vocabulary axis; from here on `logits` actually holds
# probabilities rather than raw scores
to_probabilities = nn.Softmax(dim=-1)
logits = to_probabilities(W_dec(output))

# keep only the distribution predicted at the last position
logits = logits[-1, :]

# greedy choice: the id with the highest probability is the next token
next_token = torch.argmax(logits, dim=-1)

print(next_token)

tensor(952)
