In [1]:
from typing import List
import torch
import torch.nn as nn

from dataclasses import dataclass

@dataclass(frozen=True)
class Config:
    vocab_size: int = 50257
    context_length: int = 1024
    emb_dim: int = 768
    num_heads: int = 12
    num_layers: int = 12
    dropout: float = 0.1
    qkv_bias: bool = False
    bias: bool = False


class DummyTransformerBlock(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        
    def forward(self, x):
        return x
    
class DummyLayerNorm(nn.Module):
    def __init__(self, config: Config):
        super().__init__()
        
    def forward(self, x):
        return x

class DummyGPTModel(nn.Module):
    def __init__(self, config: Config, 
                 transformer_cls=DummyTransformerBlock, 
                 norm_layer_cls=DummyLayerNorm):
        super().__init__()
        self._cfg = config
        self._tok_emd = nn.Embedding(config.vocab_size, config.emb_dim)
        self._pos_emd = nn.Embedding(config.context_length, config.emb_dim)
        self._dropout = nn.Dropout(config.dropout)
        self._transformer_blocks = nn.Sequential(
            *[transformer_cls(config) for _ in range(config.num_layers)]
        )
        self._final_norm = norm_layer_cls(config)
        self._out_head = nn.Linear(config.emb_dim, config.vocab_size, bias=False)
        
    def forward(self, in_idx: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = in_idx.shape
        device = in_idx.device
        tok_embeds = self._tok_emd(in_idx)
        pos_embeds = self._pos_emd(
            torch.arange(seq_len, device=device)
        )
        x = tok_embeds + pos_embeds
        x = self._dropout(x)
        x = self._transformer_blocks(x)
        x = self._final_norm(x)
        return self._out_head(x)


In [2]:
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
config = Config()
txts = ["Everything you can imagine is real.", "Simplicity is the ultimate sophistication."]
batch = [torch.tensor(tokenizer.encode(txt)) for txt in txts]
print(batch)
batch = nn.utils.rnn.pad_sequence(batch, batch_first=True, padding_value=tokenizer.eot_token)
print(batch)
model = DummyGPTModel(config)
logits = model(batch)
print(logits.shape)  # Expected output: torch.Size([2, 8, 50257])
decoded = tokenizer.decode(batch[1].tolist())
print(decoded)  # Decoded text from the first sequence

[tensor([19693,   345,   460,  5967,   318,  1103,    13]), tensor([ 8890,   489,  8467,   318,   262,  8713, 44809,    13])]
tensor([[19693,   345,   460,  5967,   318,  1103,    13, 50256],
        [ 8890,   489,  8467,   318,   262,  8713, 44809,    13]])
torch.Size([2, 8, 50257])
Simplicity is the ultimate sophistication.


In [3]:
batch = []
txt1 = "Every effort moves you"
txt2 = "Every day holds a"
batch.append(torch.tensor(tokenizer.encode(txt1)))
batch.append(torch.tensor(tokenizer.encode(txt2)))
batch = torch.stack(batch, dim=0)
print(batch)
torch.manual_seed(123)
GPT_CONFIG_124M = Config()
model = DummyGPTModel(GPT_CONFIG_124M)
logits = model(batch)
print("Output shape:", logits.shape)
print(logits)

tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])
Output shape: torch.Size([2, 4, 50257])
tensor([[[-1.2034,  0.3201, -0.7130,  ..., -1.5548, -0.2390, -0.4667],
         [-0.1192,  0.4539, -0.4432,  ...,  0.2392,  1.3469,  1.2430],
         [ 0.5307,  1.6720, -0.4695,  ...,  1.1966,  0.0111,  0.5835],
         [ 0.0139,  1.6755, -0.3388,  ...,  1.1586, -0.0435, -1.0400]],

        [[-1.0908,  0.1798, -0.9484,  ..., -1.6047,  0.2439, -0.4530],
         [-0.7860,  0.5581, -0.0610,  ...,  0.4835, -0.0077,  1.6621],
         [ 0.3567,  1.2698, -0.6398,  ..., -0.0162, -0.1296,  0.3717],
         [-0.2407, -0.7349, -0.5102,  ...,  2.0057, -0.3694,  0.1814]]],
       grad_fn=<UnsafeViewBackward0>)


In [4]:
class LayerNorm(nn.Module):
    def __init__(self, config: Config, eps=1e-5):
        super().__init__()
        self._eps = eps
        emb_dim = config.emb_dim
        self._scale = nn.Parameter(torch.ones(emb_dim))
        self._shift = nn.Parameter(torch.zeros(emb_dim))
        
    def forward(self, x: torch.Tensor):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True)
        norm_x = (x - mean) / torch.sqrt(var + self._eps)
        return self._scale * norm_x + self._shift

In [5]:
txt1 = "Every effort moves you"
txt2 = "Every day holds a"
batch = [
    torch.tensor(tokenizer.encode(txt1)),
    torch.tensor(tokenizer.encode(txt2)),
]
batch = torch.stack(batch, dim=0)
print(batch)
torch.manual_seed(123)
GPT_CONFIG_124M = Config()
model = DummyGPTModel(GPT_CONFIG_124M, norm_layer_cls=LayerNorm)
logits = model(batch)
print("Output shape:", logits.shape)
print(logits)

tensor([[6109, 3626, 6100,  345],
        [6109, 1110, 6622,  257]])
Output shape: torch.Size([2, 4, 50257])
tensor([[[-0.7862,  0.2201, -0.4505,  ..., -0.9930, -0.1411, -0.2997],
         [-0.0788,  0.3002, -0.2933,  ...,  0.1582,  0.8911,  0.8224],
         [ 0.3706,  1.1119, -0.3223,  ...,  0.8017, -0.0038,  0.3932],
         [ 0.0636,  1.0565, -0.2506,  ...,  0.7537, -0.0750, -0.6892]],

        [[-0.7203,  0.1351, -0.6010,  ..., -1.0265,  0.1728, -0.2918],
         [-0.5934,  0.4450, -0.0059,  ...,  0.3412,  0.0572,  1.0979],
         [ 0.2673,  0.8401, -0.4473,  ..., -0.0181, -0.1089,  0.2539],
         [-0.1034, -0.5897, -0.3929,  ...,  1.4013, -0.3186,  0.1303]]],
       grad_fn=<UnsafeViewBackward0>)


In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config: Config):
        d_in = d_out = config.emb_dim
        assert (d_out % config.num_heads == 0), "d_out must be divisiable by num_heads"
        super().__init__()
        self._d_in = d_in
        self._d_out = d_out
        self._num_heads = config.num_heads
        self._d_head = d_out // config.num_heads
        self._w_q = nn.Linear(self._d_in, self._d_out, bias=config.qkv_bias)
        self._w_k = nn.Linear(self._d_in, self._d_out, bias=config.qkv_bias)
        self._w_v = nn.Linear(self._d_in, self._d_out, bias=config.qkv_bias)
        self._out_proj = nn.Linear(self._d_out, self._d_out)
        self._dropout = nn.Dropout(config.dropout)
        self.register_buffer("mask", torch.triu(torch.ones(config.context_length, config.context_length), diagonal=1))
        
    def forward(self, x):
        import math
        
        q = self._w_q(x)
        k = self._w_k(x)
        v = self._w_v(x)
        
        b, context_length, d_out = q.shape
        split_view = lambda x: x.view(b, context_length, self._num_heads, self._d_head)\
            .transpose(1, 2)
        q = split_view(q)
        k = split_view(k)
        v = split_view(v)
        
        attn_scores = q @ k.transpose(-1, -2) / math.sqrt(k.shape[-1])
        attn_scores.masked_fill_(self.mask.bool()[:context_length, :context_length], -torch.inf)
        weight = torch.softmax(attn_scores, dim=-1)
        weight = self._dropout(weight)
        
        z = weight @ v
        z = z.transpose(1, 2).contiguous().view(b, context_length, d_out)
        z = self._out_proj(z)
        return z
        

In [7]:
class FeedForward(nn.Module):
    def __init__(self,
                 config: Config,
                 ff_mid_dim:int=0,
                 ff_activation=nn.GELU()):
        super().__init__()
        if ff_mid_dim == 0:
            ff_mid_dim = 4 * config.emb_dim
        self._layers = nn.Sequential(
            nn.Linear(config.emb_dim, ff_mid_dim),
            ff_activation,
            nn.Linear(ff_mid_dim, config.emb_dim),
        )
        
    def forward(self, x):
        return self._layers(x)

class TransformerBlock(nn.Module):
    def __init__(self, config: Config, 
                 norm_layer_cls=LayerNorm,
                 feed_fwd_cls=FeedForward,
                 dropouts=None):
        if dropouts is None:
            dropouts = [config.dropout, config.dropout]
        super().__init__()
        self._norm_1 = norm_layer_cls(config=config)
        self._attention = MultiHeadAttention(config=config)
        self._drop_1 = nn.Dropout(dropouts[0])
        self._norm_2 = norm_layer_cls(config=config)
        self._ff = feed_fwd_cls(config=config)
        self._drop_2 = nn.Dropout(dropouts[1])
        
    def forward(self, x):
        shortcut = x
        x = self._norm_1(x)
        x = self._attention(x)
        x = self._drop_1(x)
        x += shortcut
        
        shortcut = x
        x = self._norm_2(x)
        x = self._ff(x)
        x = self._drop_2(x)
        x += shortcut
        
        return x
        

In [8]:
class GPTModel(nn.Module):
    def __init__(self,
                 config: Config,
                 gpt_dropout=-1,
                 transformer_cls=TransformerBlock,
                 norm_layer_cls=LayerNorm):
        super().__init__()
        self._tok_emd = nn.Embedding(config.vocab_size, config.emb_dim)
        self._pos_emd = nn.Embedding(config.context_length, config.emb_dim)
        if gpt_dropout < 0:
            gpt_dropout = config.dropout
        self._dropout = nn.Dropout(gpt_dropout)
        self._transformers = nn.Sequential(
            *[transformer_cls(config=config) for _ in range(config.num_layers)]
        )
        self._final_norm_layer = norm_layer_cls(config=config)
        self._out_head = nn.Linear(config.emb_dim, config.vocab_size, bias=config.bias)
        
    def forward(self, in_idx: torch.Tensor) -> torch.Tensor:
        batch_size, seq_len = in_idx.shape
        token_embeds = self._tok_emd(in_idx)
        pos_embeds = self._pos_emd(torch.arange(seq_len, device=in_idx.device))

        x = token_embeds + pos_embeds
        x = self._dropout(x)
        x = self._transformers(x)
        x = self._final_norm_layer(x)
        return self._out_head(x)
        
        

In [9]:
torch.manual_seed(123)
# demo_config = Config(context_length=4)
model = GPTModel(config=GPT_CONFIG_124M)
logits = model(batch)
print("Output shape:", logits.shape)
print(logits)

Output shape: torch.Size([2, 4, 50257])
tensor([[[ 0.3612,  0.4223, -0.0709,  ...,  0.3479,  0.4655, -0.2833],
         [-0.1786, -0.5656, -0.9478,  ...,  0.0475,  0.5173, -0.3161],
         [ 0.7118,  0.0335,  0.1078,  ...,  0.1019, -0.4330, -0.2547],
         [-1.0068,  0.3421, -0.1191,  ...,  0.7194,  0.4018,  0.0532]],

        [[-0.2562,  0.0899,  0.0337,  ...,  0.2659,  0.4448, -0.6800],
         [ 0.1229,  0.3651, -0.2071,  ...,  0.7703,  0.2702,  0.2249],
         [ 1.0556,  1.0312, -0.2797,  ...,  0.6933,  0.3201, -0.3172],
         [-0.1560,  0.3924,  0.3286,  ...,  1.2626, -0.1862,  0.0392]]],
       grad_fn=<UnsafeViewBackward0>)


In [None]:
def generate_text_simple(model, idx, max_new_tokens=20, context_size=10):
    for _ in range(max_new_tokens):
        idx_cond = idx[:, -context_size:]
        with torch.no_grad():
            logits = model(idx_cond)
        logits = logits[:, -1, :]
        probs = torch.softmax(logits, dim=-1)
        idx_next = torch.argmax(probs, dim=-1, keepdim=True)
        idx = torch.cat((idx, idx_next), dim=1)
    
    return idx

start_context = "Hello, I am"
encoded = tokenizer.encode(start_context)
print("encoded:", encoded)
encoded_tensor = torch.tensor(encoded).unsqueeze(0)
print("encoded_tensor.shape:", encoded_tensor.shape)

model.eval()
out = generate_text_simple(
model=model,
idx=encoded_tensor,
max_new_tokens=6,
context_size=GPT_CONFIG_124M.context_length,
)
print("Output:", out)
print("Output length:", len(out[0]))

decoded_text = tokenizer.decode(out.squeeze(0).tolist())
print(decoded_text)

encoded: [15496, 11, 314, 716]
encoded_tensor.shape: torch.Size([1, 4])
Output: tensor([[15496,    11,   314,   716, 27018, 24086, 47843, 30961, 42348,  7267]])
Output length: 10
