In [2]:
import torch 
import math
import base64
import gzip
import torch.nn as nn 
import numpy as np
from dataclasses import dataclass
from typing import Dict, Iterable, Optional
import torch.nn.functional as F

In [3]:
class RotaryPositionalEmbedding(nn.Module):
    def __init__(self, 
                 depth: int, # d_model // num_heads
                 base: int = 10000):
        super().__init__()
        theta = 1.0 / (base ** (torch.arange(0, depth, 2) / depth)) # [embed_dim //2 ,]
        self.register_buffer(name="theta", tensor=theta)
        
    def forward(self, x: torch.Tensor):
        # x: [N, num_heads, seq_len, depth] 
        # do seq_len kh cố định < tokenizer của gpt2 thiết kế cho chiều dài đông>, 
        # vì vậy ta sẽ tạo postions khi có seq_len đưuọc lấy từ x
        seq_len = x.size(2) 
        positions = torch.arange(0, seq_len)[:, None]
        # print("positions: ", positions.shape)
        # print("theta: ", self.theta.shape)
        theta_cp = positions * self.theta
        x_even, x_odd = x[..., 0::2], x[..., 1::2] # [N, num_heads, seq_len, depth//2]
        
        sin_angles = torch.sin(theta_cp)
        cos_angles = torch.cos(theta_cp)

        x_even_rotated = x_even * cos_angles - x_odd * sin_angles
        x_odd_rotated = x_even * sin_angles + x_odd * cos_angles

        x_rotated = torch.ones_like(x)
        x_rotated[..., 0::2] = x_even_rotated
        x_rotated[..., 1::2] = x_odd_rotated

        return x_rotated
 

mock_data = torch.randint(0, 1, (1, 4, 4))
embed_model = RotaryPositionalEmbedding(depth= 4//1)
output = embed_model.forward(mock_data)
print("output: ", output.shape)


output:  torch.Size([1, 4, 4])


In [46]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int):
        super().__init__()
        self.num_heads = num_heads
        self.d_model = d_model
        assert d_model % num_heads == 0, "d model must be divisible by num heads"
        self.depth  = self.d_model // num_heads

        self.rotary_embed = RotaryPositionalEmbedding(depth=self.depth)
        self.Wq = nn.Linear(in_features=d_model, out_features=d_model)
        self.Wk = nn.Linear(in_features=d_model, out_features=d_model)
        self.Wv = nn.Linear(in_features=d_model, out_features=d_model)
        self.Wo = nn.Linear(in_features=d_model, out_features=d_model)

    def split_heads(self, x: torch.Tensor, batch_size: int):
        # x: [N, seq_len, embed_dim] -> [N, num_heads, seq_len, embed_dim//num_heads]
        x = x.view(batch_size, -1, self.num_heads, self.depth)
        return x.transpose(1, 2)
    
    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        matmul_qk = torch.matmul(Q, K.transpose(-2, -1))
        dk = torch.tensor(K.size(-1), dtype=torch.float32)
        scaled_attention_logits = matmul_qk / torch.sqrt(dk)

        if mask is not None:
            scaled_attention_logits = scaled_attention_logits.masked_fill(mask == 0, -1e9)

        attention_weights = torch.nn.functional.softmax(scaled_attention_logits, dim=-1)
        output = torch.matmul(attention_weights, V)

        return output, attention_weights
    
    def forward(self, 
                x: torch.Tensor, 
                prev_K: torch.Tensor = None, 
                prev_V: torch.Tensor = None, 
                mask: torch.Tensor = None):
        
        x = x.float()
        batch_size = x.size(0)
        seq_len = x.size(1)
        # Linear projections and split and reshape into multiple heads
        q = self.split_heads(x=self.Wq(x), batch_size=batch_size) # [N, num_heads, seq_len, depth]
        k = self.split_heads(x=self.Wk(x), batch_size=batch_size) # [N, num_heads, seq_len, depth]
        v = self.split_heads(x=self.Wv(x), batch_size=batch_size) # [N, num_heads, seq_len, depth]
        q_rotated = self.rotary_embed(q) # [N, num_heads, seq_len, depth]
        k_rotated = self.rotary_embed(k) # [N, num_heads, seq_len, depth]

        if prev_K is not None and prev_V is not None:
            print('Before ===>', "x: ", x.shape, 
                  "q: ", q_rotated.shape, 
                  "k: ", k_rotated.shape, 
                  "v: ", v.shape, 
                  "prev_K: ", prev_K.shape, 
                  "prev_V: ", prev_V.shape)

            # Print shapes for debugging
            # print(f'prev_K shape: {prev_K.shape}, k shape: {k.shape}')
            # print(f'prev_V shape: {prev_V.shape}, v shape: {v.shape}')

            if prev_K.dim() == 3:
                prev_K = prev_K.view(batch_size, self.num_heads, -1, self.depth)
            if prev_V.dim() == 3:
                prev_V = prev_V.view(batch_size, self.num_heads, -1, self.depth)

            # Concatenate along the sequence length dimension
            k_rotated = torch.cat(tensors=(k_rotated, prev_K), dim=2) # (batch_size, num_heads, m + seq_len, depth)
            v = torch.cat([prev_V, v], dim=2)  # (batch_size, num_heads, m + seq_len, depth)
        
            print('After ===>', "x: ", x.shape, 
                  "q: ", q_rotated.shape, 
                  "k: ", k_rotated.shape, 
                  "v: ", v.shape, 
                  "prev_K: ", prev_K.shape, 
                  "prev_V: ", prev_V.shape)

        scaled_attention, _ = self.scaled_dot_product_attention(q_rotated, k_rotated, v, mask)
        # print(scaled_attention.shape) # [N, num_heads, seq_len, depth]
        scaled_attention = scaled_attention.transpose(1, 2).contiguous() # [N, seq_len, num_heads, depth]
        scaled_attention = scaled_attention.view(batch_size, seq_len, self.d_model) # [N, seq_len, d_model]

        output = self.Wo(scaled_attention) # [N, seq_len, d_model]
        return (
            output,  # [N, seq_len, d_model]
            k_rotated, # [N, seq_len, d_model]
            v # [N, seq_len, d_model]
        )
        

attention = MultiHeadAttention(d_model=256, num_heads=8)
mock_data = torch.randint(0 ,1 , (1, 128, 256)).long()
out_encoder, _, _ = attention(x=mock_data, prev_K=None, prev_V=None, mask=None)
print(out_encoder.shape)

torch.Size([1, 128, 256])


In [47]:
class DecoderLayer(nn.Module):
    def __init__(self, d_model: int, num_heads: int, ff_dim: int, dropout: float = 0.1):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.enc_dec_attn = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        self.ffn = nn.Sequential(
            nn.Linear(in_features=d_model, out_features=ff_dim), 
            nn.ReLU(), 
            nn.Linear(in_features=ff_dim, out_features=d_model)
        )

        self.norm1 = nn.LayerNorm(normalized_shape=d_model)
        self.norm2 = nn.LayerNorm(normalized_shape=d_model)
        self.norm3 = nn.LayerNorm(normalized_shape=d_model)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x, encoder_output, prev_K = None, prev_V = None):
        # Self-attention (masked for autoregressive behavior)
        print("--------------------> Caching ------------------>")
        _x, k, v = self.self_attn(x, prev_K, prev_V)
        print("<------------------- End Caching <---------------")
        x = x + self.dropout(_x)
        x = self.norm1(x)

        # Encoder-Decoder attention
        _x, _, _ = self.enc_dec_attn(x, encoder_output, encoder_output)
        x = x + self.dropout(_x)
        x = self.norm2(x)

        # Feed forward network
        x_ffn = self.ffn(x)
        x = x + self.dropout(x_ffn)
        x = self.norm3(x)
        return x, k, v

attention = MultiHeadAttention(d_model=256, num_heads=8)
decoder = DecoderLayer(d_model=256, num_heads=8, ff_dim=512)

mock_data = torch.randint(0 ,1 , (1, 128, 256)).float()
out_encoder, _, _ = attention(x=mock_data, prev_K=None, prev_V=None, mask=None)
output_dec, _, _ = decoder(mock_data, out_encoder)
print(output_dec.shape)

--------------------> Caching ------------------>
<------------------- End Caching <---------------
Before ===> x:  torch.Size([1, 128, 256]) q:  torch.Size([1, 8, 128, 32]) k:  torch.Size([1, 8, 128, 32]) v:  torch.Size([1, 8, 128, 32]) prev_K:  torch.Size([1, 128, 256]) prev_V:  torch.Size([1, 128, 256])
After ===> x:  torch.Size([1, 128, 256]) q:  torch.Size([1, 8, 128, 32]) k:  torch.Size([1, 8, 256, 32]) v:  torch.Size([1, 8, 256, 32]) prev_K:  torch.Size([1, 8, 128, 32]) prev_V:  torch.Size([1, 8, 128, 32])
torch.Size([1, 128, 256])


In [None]:
class TransformerDecoder(nn.Module):
    def __init__(self, d_model: int, num_heads: int, 
                 num_layers: int, ff_dim: int, 
                 vocab_size: int, dropout: float=0.1):
        super().__init__()

        self.embed_model = nn.Embedding(num_embeddings=vocab_size, 
                                        embedding_dim=d_model)
        
        self.layers = nn.ModuleList([
            DecoderLayer(d_model=d_model, num_heads=num_heads, ff_dim=ff_dim, dropout=dropout)
            for _ in range(num_layers)
        ])

        self.fc = nn.Linear(in_features=d_model, out_features=vocab_size)
        self.d_model = d_model


    def forward(self, x, enc_out, prev_K_list: list = None, prev_V_list: list = None):
        x = self.embed_model(x) * math.sqrt(self.d_model)

        if prev_K_list is None or prev_V_list is None:
            prev_V_list = [None] * len(self.layers)
            prev_K_list = [None] * len(self.layers)

        new_K_list = []
        new_V_list = []

        for i, layer in enumerate(self.layers):
            x, k, v = layer(x, enc_out, prev_K_list[i], prev_V_list[i])
            print('===>', x.shape, k.shape, v.shape)
            if prev_K_list[i] is not None:
                print('*', prev_K_list[i].shape, prev_V_list[i].shape)
            
            new_K_list.append(k)
            new_V_list.append(v)
        logits = self.fc(x)
        return logits, new_K_list, new_V_list



d_model = 512
num_heads = 8
num_layers = 1
ff_dim = 2048
vocab_size = 10000
max_seq_len = 50

attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)

transformer_decoder = TransformerDecoder(
    d_model=d_model, 
    num_heads=num_heads, 
    num_layers=num_layers, 
    ff_dim=ff_dim, 
    vocab_size=vocab_size, 
)

mock_data = torch.randint(0, 1, size=(1, max_seq_len, d_model)).long()

out_encoder, _, _ = attention(x=mock_data, prev_K=None, prev_V=None, mask=None)

output, _, _ = transformer_decoder(mock_data, out_encoder)
print(output.shape)

In [None]:
start_token = 1     # Mock start token, typically something like <sos> in real vocabularies
end_token = 2       # Mock end token, typically something like <eos> in real vocabularies

generated_sequence = [start_token]
prev_K_list, prev_V_list = None, None

for i in range(max_seq_len):

    print('---> Decoding...')
    # Prepare the input tensor for the current step (batch_size=1)
    current_input = torch.tensor([generated_sequence[-1]], dtype=torch.float).unsqueeze(0)  # Shape: (1, 1)
    print(current_input.shape)

    # Forward pass through the decoder
    logits, new_K_list, new_V_list = decoder(current_input, out_encoder, prev_K_list, prev_V_list)

    # Select the token with the highest probability
    next_token = logits.argmax(dim=-1).item()  # Get the predicted token index

    # Append the generated token to the sequence
    generated_sequence.append(next_token)

    # Update the cached keys and values for the next iteration
    prev_K_list, prev_V_list = new_K_list, new_V_list

    # Stop if the end token is generated
    if next_token == end_token:
        break

    print('---> End Decoding...')

# `generated_sequence` now contains the full generated sequence of tokens
