# Let's test the decoder that we've built

In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from decoder import DecoderBlock, FlashDecoderBlock
from common import TransformerEmbeddings, LayerNorm

In [2]:
input_ids = torch.randint(high=1000, size=(1, 10,)) # The extra dimension at the start is for batch size.
input_ids

tensor([[743, 252, 409, 858, 127, 606, 571, 435, 353, 275]])

In [3]:
class GPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len, n_heads, hidden_size, num_layers):
        super().__init__()

        self.embeddings = TransformerEmbeddings(vocab_size, embed_dim, max_seq_len)
        self.layers = nn.ModuleList([
            DecoderBlock(embed_dim, max_seq_len, n_heads, hidden_size, False)
            for _ in range(num_layers)
        ])
        self.ln = LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, ids):
        embeddings = self.embeddings(ids)
        for block in self.layers:
            outputs = block(embeddings)
        outputs = self.ln(outputs)
        logits = self.lm_head(outputs)

        return logits

In [4]:
import tiktoken

In [5]:
tokenizer = tiktoken.get_encoding('cl100k_base')



In [6]:
gpt = GPT(
    vocab_size=tokenizer.n_vocab, 
    embed_dim=768, 
    max_seq_len=1024, 
    n_heads=16, 
    hidden_size=768 * 2, 
    num_layers=20
)

In [7]:
f"{sum(p.numel() for p in gpt.parameters()):,} Parameters"

'237,550,517 Parameters'

In [8]:
logits = gpt(input_ids)
print(logits.size())
print('=' * 80)
print(logits)

token = torch.argmax(logits[:,-1,:])
print(f'Predicted token: {token.item()}')

torch.Size([1, 10, 100277])
tensor([[[-0.2122,  0.5439,  0.2116,  ...,  0.4221,  0.7646, -0.8687],
         [-0.0962,  0.3857, -0.6739,  ..., -0.2257,  0.3856, -0.2822],
         [ 0.3044, -0.6492,  0.1518,  ...,  0.7394, -0.4433, -0.9983],
         ...,
         [-0.2844, -0.3886, -0.5908,  ..., -0.7100, -0.3606, -0.9481],
         [-0.2139,  0.0526,  0.3626,  ..., -0.1844, -0.6128, -0.3698],
         [-0.7062, -0.6437, -0.1544,  ..., -0.0406, -0.0345,  1.0520]]],
       grad_fn=<ViewBackward0>)
Predicted token: 76899


# Test Flash Attention GPT

In [9]:
class FlashGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len, n_heads, hidden_size, num_layers):
        super().__init__()

        self.embeddings = TransformerEmbeddings(vocab_size, embed_dim, max_seq_len)
        self.layers = nn.ModuleList([
            FlashDecoderBlock(embed_dim, max_seq_len, n_heads, hidden_size, False)
            for _ in range(num_layers)
        ])
        self.ln = LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, ids):
        embeddings = self.embeddings(ids)
        for block in self.layers:
            outputs = block(embeddings)
        outputs = self.ln(outputs)
        logits = self.lm_head(outputs)

        return logits

In [10]:
flashGPT = FlashGPT(
    vocab_size=tokenizer.n_vocab, 
    embed_dim=768, 
    max_seq_len=1024, 
    n_heads=16, 
    hidden_size=768 * 2, 
    num_layers=20
)

In [11]:
logits = flashGPT(input_ids)
print(logits.size())
print('=' * 80)
print(logits)

token = torch.argmax(logits[:,-1,:])
print(f'Predicted token: {token.item()}')

torch.Size([1, 10, 100277])
tensor([[[-0.9890, -0.4890,  0.1613,  ..., -0.6838, -0.6599,  1.1461],
         [-0.9446,  0.6658, -0.2388,  ...,  0.1858,  0.2388,  0.6867],
         [ 0.4036, -0.3223, -0.2047,  ...,  0.6309, -0.3606, -0.1010],
         ...,
         [-0.0617,  0.5768, -0.5326,  ..., -0.2657, -0.2882,  0.4066],
         [-0.1702,  0.0284,  0.4018,  ...,  0.0069,  0.6360,  0.1791],
         [-0.4051, -0.3799,  0.4261,  ..., -0.2515, -0.8847,  0.0564]]],
       grad_fn=<ViewBackward0>)
Predicted token: 62598
