In [86]:
import torch
import torch.nn as nn
from attention import MultiHeadAttention

In [87]:
# Model hyper-parameters (the name refers to the 124M-parameter GPT-2 size).
GPT_CONFIG_124M = dict(
    vocab_size=50257,      # number of token ids: embedding rows and output logits
    context_length=1024,   # number of positional embeddings; also sizes the causal attention mask
    emb_dim=768,           # width of token/position embeddings and of every block
    n_heads=12,            # attention heads per transformer block
    n_layers=12,           # number of stacked transformer blocks
    drop_rate=0.1,         # dropout probability used throughout the model
    qkv_bias=False,        # whether the query/key/value projections have a bias term
)

In [98]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention.

    Projects the input into queries, keys and values, splits them into
    ``num_heads`` heads of size ``d_out // num_heads``, applies scaled
    dot-product attention under a causal mask, merges the heads back and
    applies an output projection.

    NOTE(review): this definition shadows the ``MultiHeadAttention`` imported
    from ``attention`` in the first cell; keep only one of the two.

    Args:
        d_in: size of the last dimension of the input.
        d_out: size of the output embedding; must be divisible by ``num_heads``.
        context_length: maximum sequence length; sizes the pre-built causal mask.
        dropout: dropout probability applied to the attention weights.
        num_heads: number of attention heads.
        qkv_bias: whether the query/key/value projections use a bias term.
    """

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias = False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        # Creation order (query, value, key, out) is kept so a fixed seed gives the same weights.
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)
        self.dropout = nn.Dropout(dropout)
        # Causal mask built once: entry (i, j) is 1 where j > i, i.e. a future
        # position query i must not attend to. As a buffer it follows .to(device)
        # and is saved in the state_dict.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape (batch, num_tokens, d_in).

        Returns:
            Tensor of shape (batch, num_tokens, d_out).

        Raises:
            ValueError: if num_tokens exceeds the context_length the mask was built for.
        """
        b, num_tokens, _ = x.shape
        context_length = self.mask.shape[0]
        if num_tokens > context_length:
            raise ValueError(
                f"sequence length {num_tokens} exceeds context_length {context_length}"
            )

        keys = self.W_key(x)
        values = self.W_value(x)
        queries = self.W_query(x)

        # (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim): every head
        # becomes an independent (num_tokens, head_dim) matrix so all heads are
        # computed in parallel by the batched matmuls below.
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim).transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)  # (b, num_heads, num_tokens, num_tokens)
        # Reuse the registered mask trimmed to the current length; the 2-D mask
        # broadcasts over the batch and head dimensions.
        causal_mask = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores = attn_scores.masked_fill(causal_mask, float('-inf'))

        # Scale by sqrt(head_dim) before the softmax over key positions.
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Merge heads back: (b, num_heads, num_tokens, head_dim) -> (b, num_tokens, d_out).
        context_vector = (attn_weights @ values).transpose(1, 2)
        context_vector = context_vector.contiguous().view(b, num_tokens, self.d_out)
        return self.out_proj(context_vector)


In [99]:
class DummyGPTModel(nn.Module):
    """GPT-style model skeleton built from the Dummy* block and norm classes.

    Pipeline: token + positional embeddings -> dropout -> ``n_layers`` stacked
    ``DummyTransformerBlock``s -> ``DummyLayerNorm`` -> linear projection to
    ``vocab_size`` logits.

    NOTE(review): ``DummyTransformerBlock`` and ``DummyLayerNorm`` are not defined
    in this part of the notebook; they must come from an earlier cell. TODO
    confirm they exist on a fresh kernel (Restart & Run All).

    Args:
        cfg: config dict; reads ``vocab_size``, ``context_length``, ``emb_dim``,
            ``drop_rate`` and ``n_layers``.
    """

    def __init__(self, cfg):
        super().__init__()
        vocab_size, emb_dim = cfg["vocab_size"], cfg["emb_dim"]
        self.tok_emb = nn.Embedding(vocab_size, emb_dim)
        self.pos_emb = nn.Embedding(cfg["context_length"], emb_dim)
        self.drop_emb = nn.Dropout(cfg["drop_rate"])
        blocks = [DummyTransformerBlock(cfg) for _ in range(cfg["n_layers"])]
        self.trf_blocks = nn.Sequential(*blocks)
        self.final_norm = DummyLayerNorm(emb_dim)
        self.out_head = nn.Linear(emb_dim, vocab_size, bias = False)

    def forward(self, in_idx):
        """Map a 2-D tensor of token ids (batch, seq_len) to logits whose last dim is vocab_size."""
        _, seq_len = in_idx.shape
        # One positional id per time step, created on the same device as the inputs.
        positions = torch.arange(seq_len, device=in_idx.device)
        hidden = self.drop_emb(self.tok_emb(in_idx) + self.pos_emb(positions))
        hidden = self.final_norm(self.trf_blocks(hidden))
        return self.out_head(hidden)


class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: causal multi-head attention, then a feed-forward network.

    Each sub-layer is applied as ``x + dropout(sublayer(norm(x)))``. The
    normalization runs *before* each sub-layer, including ``post_norm``, which
    normalizes the input of the feed-forward sub-layer despite its name.

    Args:
        cfg: config dict; reads ``emb_dim``, ``context_length``, ``n_heads``,
            ``drop_rate`` and ``qkv_bias``.
    """

    def __init__(self, cfg):
        super().__init__()
        emb_dim = cfg["emb_dim"]
        drop_rate = cfg["drop_rate"]
        # Submodules are created in the original order (attention, feed-forward,
        # norms, dropout) so a fixed torch seed yields the same initial weights
        # and the state_dict keys keep the same names and order.
        self.mha = MultiHeadAttention(
            d_in=emb_dim,
            d_out=emb_dim,
            context_length=cfg["context_length"],
            num_heads=cfg["n_heads"],
            dropout=drop_rate,
            qkv_bias=cfg["qkv_bias"],
        )
        self.ff = FeedForwardNetwork(cfg)
        self.pre_norm = LayerNorm(emb_dim)   # normalizes the attention input
        self.post_norm = LayerNorm(emb_dim)  # normalizes the feed-forward input
        self.drop_shortcut = nn.Dropout(drop_rate)

    def _residual(self, x, norm, sublayer):
        """One pre-norm residual step: return ``dropout(sublayer(norm(x))) + x``."""
        return self.drop_shortcut(sublayer(norm(x))) + x

    def forward(self, x):
        """Run the attention residual step, then the feed-forward residual step."""
        after_attention = self._residual(x, self.pre_norm, self.mha)
        return self._residual(after_attention, self.post_norm, self.ff)


class LayerNorm(nn.Module):
    """Layer normalization over the last dimension with a learnable scale and shift.

    Args:
        emb_dim: size of the last (normalized) dimension.
        eps: constant added to the variance before the square root for numerical
            stability. Defaults to 1e-5, the value that used to be hard-coded.
    """

    def __init__(self, emb_dim, eps=1e-5):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        """Normalize ``x`` to zero mean / unit variance along dim -1, then apply scale and shift."""
        mean = x.mean(dim=-1, keepdim=True)
        # Biased (population) variance, as torch.nn.LayerNorm uses.
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x - mean) / torch.sqrt(var + self.eps)
        return self.scale * norm_x + self.shift
    

class GELU(nn.Module):
    """Tanh approximation of the Gaussian Error Linear Unit.

    Computes ``0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))``, the
    same formula as ``torch.nn.functional.gelu(x, approximate="tanh")``.
    """

    # sqrt(2/pi) as a plain Python float, computed once when the class is defined.
    # This avoids building a CPU tensor and launching sqrt on every forward call.
    _SQRT_2_OVER_PI = (2.0 / torch.pi) ** 0.5

    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Apply the tanh-approximated GELU elementwise."""
        inner = self._SQRT_2_OVER_PI * (x + 0.044715 * torch.pow(x, 3))
        return 0.5 * x * (1 + torch.tanh(inner))
    

class FeedForwardNetwork(nn.Module):
    """Two-layer MLP applied along the last dimension: Linear(d -> 4d) -> GELU -> Linear(4d -> d).

    Args:
        cfg: config dict; reads ``emb_dim`` (``d`` above).
    """

    def __init__(self, cfg):
        super().__init__()
        model_dim = cfg["emb_dim"]
        hidden_dim = 4 * model_dim
        # Kept as one Sequential (indices 0/1/2) so state_dict keys stay "layers.N.*";
        # the expanding Linear is still created before the projecting one.
        self.layers = nn.Sequential(
            nn.Linear(model_dim, hidden_dim),  # expand
            GELU(),
            nn.Linear(hidden_dim, model_dim),  # project back to the model width
        )

    def forward(self, x):
        """Run the MLP on ``x``; the last dimension must equal ``emb_dim``."""
        out = self.layers(x)
        return out



In [100]:
# Smoke test: a transformer block must map (batch, tokens, emb_dim) back to the
# same shape so that blocks can be stacked.
torch.manual_seed(123)
x = torch.rand(2, 4, GPT_CONFIG_124M["emb_dim"])
block = TransformerBlock(GPT_CONFIG_124M)
output = block(x)
assert output.shape == x.shape, f"TransformerBlock changed the shape: {x.shape} -> {output.shape}"
print("Input shape:", x.shape)
print("Output shape:", output.shape)


Input shape: torch.Size([2, 4, 768])
Output shape: torch.Size([2, 4, 768])


In [90]:
import tiktoken

tokenizer = tiktoken.get_encoding("gpt2")
text1 = "hello my name is"
text2 = "every day is a"

# Encode each prompt to token ids and stack them into a (num_texts, num_tokens) tensor.
# torch.stack requires every prompt to encode to the same number of tokens.
batch = torch.stack(
    [torch.tensor(tokenizer.encode(text)) for text in (text1, text2)],
    dim=0,
)
print(batch)


tensor([[31373,   616,  1438,   318],
        [16833,  1110,   318,   257]])


In [91]:
# Seed before construction so the initial weights (and hence the printed logits) are reproducible.
torch.manual_seed(123)
GPT = DummyGPTModel(GPT_CONFIG_124M)
# NOTE(review): the model is left in training mode, so its dropout layers are active
# in this forward pass; call GPT.eval() first for inference-style logits.
logits = GPT(batch)
print(logits, logits.shape, sep="\n")

tensor([[[-0.1281,  0.7687, -0.0526,  ..., -0.5329, -0.1665, -0.0681],
         [-0.6953,  0.2532, -1.2054,  ...,  0.4483, -0.2764,  1.1931],
         [-0.0747,  0.6798,  0.4218,  ...,  0.3246,  0.0692, -0.2881],
         [ 0.1041, -0.1162, -0.4047,  ...,  0.8971, -0.2541, -0.1744]],

        [[-0.7004,  0.0765, -0.3991,  ..., -0.9598,  0.0935,  0.3342],
         [-0.5938,  0.4453, -0.0059,  ...,  0.3414,  0.0572,  1.0986],
         [ 0.4985,  0.1049, -0.2637,  ...,  0.4080, -0.1786, -0.4767],
         [-0.1035, -0.5901, -0.3931,  ...,  1.4022, -0.3188,  0.1304]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([2, 4, 50257])
