# GPT Architecture
**with placeholder transformer block and layer norm**

## Configuration Params

In [1]:
# Hyperparameters for the 124M GPT model configuration.
# Keys (and their order) are the ones the model classes index by name.
GPT_CONFIG_124M = dict(
    vocab_size=50257,     # Vocabulary size (rows of the token embedding table)
    context_length=1024,  # Context length (rows of the position embedding table)
    emb_dim=768,          # Embedding dimension
    n_heads=12,           # Number of attention heads
    n_layers=12,          # Number of transformer layers
    drop_rate=0.1,        # Dropout rate
    qkv_bias=False,       # Query-Key-Value bias
)

## GPT Skeleton Class

In [2]:
import torch
import torch.nn as nn

In [4]:
class GPTSkeleton(nn.Module):
    """Skeleton GPT model: embeddings -> transformer stack -> final norm -> logits.

    The transformer blocks and the final normalization are the placeholder
    classes ``TransformerBlockSkeleton`` and ``LayerNormSkeleton``. They must
    be defined before this class is instantiated.

    Args:
        config: hyperparameter dict. Reads the keys ``vocab_size``,
            ``context_length``, ``emb_dim``, ``drop_rate`` and ``n_layers``.
            The whole dict is also passed to each ``TransformerBlockSkeleton``.
    """

    def __init__(self, config: dict):
        # Zero-argument super() cannot drift out of sync with the class name.
        super().__init__()
        # Token id -> learned vector of size emb_dim.
        self.token_embedding = nn.Embedding(config["vocab_size"], config["emb_dim"])
        # Position index -> learned vector. There are only context_length rows,
        # so sequences longer than that cannot be embedded.
        self.position_embedding = nn.Embedding(config["context_length"], config["emb_dim"])

        self.dropout = nn.Dropout(config["drop_rate"])

        # n_layers blocks applied one after another.
        self.transformer_layers = nn.Sequential(
            *[TransformerBlockSkeleton(config) for _ in range(config["n_layers"])]
        )

        self.final_norm = LayerNormSkeleton(config["emb_dim"])

        # Projects each position's hidden vector onto vocabulary logits.
        self.out = nn.Linear(config["emb_dim"], config["vocab_size"], bias=False)

    def forward(self, token_ids):
        """Map token ids to next-token logits.

        Args:
            token_ids: 2-D integer tensor, presumably (batch, seq_len). The
                unpacking below rejects any other rank. seq_len must not
                exceed ``context_length``.

        Returns:
            Logits whose last dimension is ``vocab_size``. With shape-preserving
            blocks and norm this is (batch, seq_len, vocab_size).
        """
        _, seq_length = token_ids.shape
        token_embeds = self.token_embedding(token_ids)
        # One row per position; broadcasts over the batch dimension in the add.
        positions = torch.arange(seq_length, device=token_ids.device)
        position_embeds = self.position_embedding(positions)
        x = token_embeds + position_embeds
        x = self.dropout(x)
        x = self.transformer_layers(x)
        x = self.final_norm(x)
        logits = self.out(x)
        return logits

In [5]:
class TransformerBlockSkeleton(nn.Module):
    """Stand-in for a transformer block: returns its input untouched.

    Args:
        config: model hyperparameter dict. Accepted so the signature matches
            a real block, but this placeholder does not read it.
    """

    def __init__(self, config: dict):
        super().__init__()

    def forward(self, x):
        """Identity pass-through; hands back the very same tensor object."""
        return x

In [6]:
class LayerNormSkeleton(nn.Module):
    """Stand-in for layer normalization: returns its input untouched.

    Args:
        normalized_shape: size of the dimension a real layer norm would
            normalize. Accepted for interface compatibility and not used here.
        eps: numerical-stability term a real layer norm would use. Also
            unused by this placeholder.
    """

    def __init__(self, normalized_shape, eps: float = 1e-5):
        super().__init__()

    def forward(self, x):
        """Identity pass-through; no normalization is applied yet."""
        return x

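Recall that `nn.Embedding` is essentially a lookup table: each integer index (a token id or a position) selects the corresponding learned row of its weight matrix.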