# Let's test the decoder that we've built

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F

from decoder import DecoderBlock
from common import TransformerEmbeddings, LayerNorm

In [3]:
input_ids = torch.randint(high=1000, size=(1, 10,)) # The extra dimension at the start is for batch size.
input_ids

tensor([[813, 137, 291, 776, 847, 865, 852, 323, 508, 554]])

In [4]:
class GPT(nn.Module):
    def __init__(self, vocab_size, embed_dim, max_seq_len, n_heads, hidden_size, num_layers):
        super().__init__()

        self.embeddings = TransformerEmbeddings(vocab_size, embed_dim, max_seq_len)
        self.layers = nn.ModuleList([
            DecoderBlock(embed_dim, max_seq_len, n_heads, hidden_size, False)
            for _ in range(num_layers)
        ])
        self.ln = LayerNorm(embed_dim)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, ids):
        embeddings = self.embeddings(ids)
        for block in self.layers:
            outputs = block(embeddings)
        outputs = self.ln(outputs)
        logits = self.lm_head(outputs)

        return logits

In [5]:
gpt = GPT(
    vocab_size=20000, 
    embed_dim=1024, 
    max_seq_len=2048, 
    n_heads=16, 
    hidden_size=4096, 
    num_layers=32
)

In [6]:
f"{sum(p.numel() for p in gpt.parameters()):,} Parameters"

'412,309,024 Parameters'

In [7]:
logits = gpt(input_ids)
print(logits.size())
print('=' * 80)
print(logits)

token = torch.argmax(logits[:,-1,:])
print(f'Predicted token: {token.item()}')

torch.Size([1, 10, 20000])
tensor([[[ 0.1072,  0.0526,  0.5062,  ...,  0.0578, -0.6875,  0.4822],
         [ 1.4474,  0.4296,  0.2251,  ...,  0.5610, -0.3306,  0.3003],
         [ 0.3551, -0.9697,  0.4853,  ...,  0.3207,  0.1015,  0.6030],
         ...,
         [ 0.3563, -0.7549,  0.8186,  ..., -0.0430, -0.9175, -0.3486],
         [ 0.1267,  0.0051,  0.3541,  ...,  0.4223, -0.2188, -0.0300],
         [-0.5174,  0.1711,  0.2495,  ..., -0.3418, -0.4729,  0.8262]]],
       grad_fn=<ViewBackward0>)
Predicted token: 16016
