In [1]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from dataclasses import dataclass



In [2]:
@dataclass
class GPTConfig():
    """Hyper-parameters consumed by the GPT-2 model classes in this notebook.

    The defaults below are the values used when ``GPTConfig()`` is called
    without arguments.
    """

    # Number of stacked transformer blocks.
    n_layers: int = 12
    # Number of attention heads; the attention layer splits d_model evenly across them.
    n_heads: int = 12
    # Width of the residual stream / embedding dimension.
    d_model: int = 768
    # Number of rows in the token embedding table and columns of the LM head output.
    vocab_size: int = 50257
    # Maximum sequence length: size of the positional embedding table and of the causal mask.
    window_size: int = 1024
    # Dropout probability passed to every nn.Dropout (a fraction, hence float, not int).
    dropout: float = 0.1
    # Not read by the model classes visible in this notebook.
    # NOTE(review): presumably intended for a training loop - confirm.
    batch_size: int = 512

In [3]:
class GPT2Attention(nn.Module):
    """Multi-head causal self-attention.

    Args:
        config: object exposing ``d_model``, ``n_heads``, ``window_size`` and
            ``dropout`` attributes.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``n_heads``.
    """

    def __init__(self, config):
        super().__init__()

        if config.d_model % config.n_heads != 0:
            raise ValueError(
                f"d_model ({config.d_model}) must be divisible by n_heads ({config.n_heads})"
            )

        self.config = config
        # Single projection producing queries, keys and values in one matmul.
        self.c_attn = nn.Linear(config.d_model, 3 * config.d_model)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(config.d_model, config.d_model)
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)

        # Lower-triangular matrix of ones, shaped (1, 1, window, window) so it
        # broadcasts over batch and head dimensions; zeros mark future positions.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.window_size, config.window_size)).view(
                1, 1, config.window_size, config.window_size
            ),
        )

        # NOTE Residual weight scaling at initialization.
        # The flag is placed on the output projection itself because weight
        # initialisers that walk the module tree test individual nn.Linear
        # layers for it. The module-level attribute is kept for compatibility.
        self.RESID_SCALING = 1
        self.c_proj.RESID_SCALING = 1

    def forward(self, x):
        """Apply causal self-attention.

        Args:
            x: tensor of shape (B, T, C) with C == d_model and T <= window_size.

        Returns:
            Tensor of shape (B, T, C).
        """
        B, T, C = x.shape
        n_heads = self.config.n_heads
        head_dim = C // n_heads

        qkv = self.c_attn(x)           # (B, T, 3*C)
        q, k, v = qkv.split(C, dim=2)  # each (B, T, C)

        q = q.view(B, T, n_heads, head_dim).transpose(1, 2)  # (B, nh, T, d_k)
        k = k.view(B, T, n_heads, head_dim).transpose(1, 2)  # (B, nh, T, d_k)
        v = v.view(B, T, n_heads, head_dim).transpose(1, 2)  # (B, nh, T, d_v)

        # Scaled dot-product scores. head_dim is a Python int, so use plain
        # arithmetic: torch.sqrt only accepts tensors.
        attn = torch.matmul(q, k.transpose(-1, -2)) / (head_dim ** 0.5)          # (B, nh, T, T)
        # Future positions get -inf so softmax assigns them zero weight.
        masked_attn = attn.masked_fill(self.mask[:, :, :T, :T] == 0, float("-inf"))
        attn_weights = F.softmax(masked_attn, dim=-1)                            # (B, nh, T, T)
        attn_weights = self.attn_dropout(attn_weights)

        context = torch.matmul(attn_weights, v)                        # (B, nh, T, d_v)
        # Merge heads back into the channel dimension.
        context = context.transpose(1, 2).contiguous().view(B, T, C)   # (B, T, C)
        context = self.c_proj(context)                                 # (B, T, C)
        context = self.resid_dropout(context)                          # (B, T, C)

        return context


class GPT2MLP(nn.Module):
    """Position-wise feed-forward network: Linear -> GELU(tanh) -> Linear -> Dropout.

    Args:
        config: object exposing ``d_model`` and ``dropout`` attributes.
    """

    def __init__(self, config):

        super().__init__()

        # Expand to 4 * d_model, then project back to d_model.
        self.c_fc = nn.Linear(config.d_model, 4 * config.d_model)
        self.c_proj = nn.Linear(4 * config.d_model, config.d_model)
        self.act = nn.GELU(approximate='tanh')    # NOTE tanh approximation of GELU
        self.dropout = nn.Dropout(config.dropout)

        # NOTE Residual weight scaling at initialization.
        # The flag is placed on the output projection itself because weight
        # initialisers that walk the module tree test individual nn.Linear
        # layers for it. The module-level attribute is kept for compatibility.
        self.RESID_SCALING = 1
        self.c_proj.RESID_SCALING = 1

    def forward(self, x):
        """Apply the feed-forward network to the last dimension of ``x``.

        Args:
            x: tensor whose last dimension equals ``d_model``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        x = self.c_fc(x)
        x = self.act(x)
        x = self.c_proj(x)
        x = self.dropout(x)
        return x

In [4]:
class GPT2Block(nn.Module):
    """One pre-norm transformer block: LayerNorm -> attention and LayerNorm -> MLP,
    each wrapped in a residual connection.

    Args:
        config: configuration object forwarded to ``GPT2Attention`` and
            ``GPT2MLP``; ``config.d_model`` sizes the layer norms.
    """

    def __init__(self, config):

        super().__init__()

        self.ln_1 = nn.LayerNorm(config.d_model)
        self.attn = GPT2Attention(config)
        self.ln_2 = nn.LayerNorm(config.d_model)
        self.mlp = GPT2MLP(config)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        # Attribute names must match those assigned in __init__ (ln_1 / ln_2).
        x = x + self.attn(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x

In [5]:
class GPT2(nn.Module):
    """GPT-2 style decoder-only language model.

    Args:
        config: object exposing ``n_layers``, ``d_model``, ``vocab_size``,
            ``window_size`` and ``dropout`` (plus whatever ``GPT2Block`` needs).
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.vocab_size, config.d_model),   # token embeddings
            wpe = nn.Embedding(config.window_size, config.d_model),  # position embeddings
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([GPT2Block(config) for _ in range(config.n_layers)]),
            ln_f = nn.LayerNorm(config.d_model),  # NOTE additional normalization after the last block
        ))

        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

        # NOTE Weight sharing used in the original transformer paper:
        # the token embedding and the output projection use the same Parameter.
        self.transformer.wte.weight = self.lm_head.weight

        # Run the custom initialiser over every sub-module.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        """Initialise one sub-module in place.

        Linear and Embedding weights are drawn from N(0, 0.02^2) and Linear
        biases are zeroed. Linear layers tagged with a ``RESID_SCALING``
        attribute get their std scaled down by ``1 / sqrt(2 * n_layers)``.
        Other module types are left untouched.
        """
        std = 0.02
        if isinstance(module, nn.Linear):
            if hasattr(module, "RESID_SCALING"):
                # Plain Python arithmetic: torch.sqrt does not accept ints.
                std *= (2 * self.config.n_layers) ** -0.5
            nn.init.normal_(module.weight, mean=0.0, std=std)
            # A bias tensor has no unambiguous truth value; test for presence.
            if module.bias is not None:
                nn.init.zeros_(module.bias)

        elif isinstance(module, nn.Embedding):
            nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, inp, target=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            inp: integer tensor of token ids with shape (B, T), T <= window_size.
            target: optional integer tensor of shape (B, T) with target ids.

        Returns:
            ``(logits, loss)`` where ``logits`` has shape (B, T, vocab_size)
            and ``loss`` is a scalar tensor, or ``None`` when ``target`` is
            not given.

        Raises:
            ValueError: if T exceeds ``config.window_size``.
        """
        B, T = inp.shape
        if T > self.config.window_size:
            raise ValueError(
                f"sequence length {T} exceeds window_size {self.config.window_size}"
            )

        pos = torch.arange(0, T, dtype=torch.long, device=inp.device)
        pe = self.transformer.wpe(pos)   # (T, C), broadcast over the batch below
        te = self.transformer.wte(inp)   # (B, T, C)

        x = te + pe
        # NOTE here is the dropout
        x = self.transformer.drop(x)

        for block in self.transformer.h:
            x = block(x)

        x = self.transformer.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if target is not None:
            # NOTE flatten (B, T, V) -> (B*T, V) and (B, T) -> (B*T,) to match
            # the input layout F.cross_entropy expects.
            loss = F.cross_entropy(logits.view(-1, logits.size(-1)), target.view(-1))
        return logits, loss




In [6]:
# Build the model from a GPTConfig constructed with its default field values.
config = GPTConfig()
model = GPT2(config)
# Bare last expression: the notebook displays the module's repr.
model

GPT2(
  (transformer): ModuleDict(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Linear(in_features=768, out_features=2304, bias=True)
          (c_proj): Linear(in_features=768, out_features=768, bias=True)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Linear(in_features=768, out_features=3072, bias=True)
          (c_proj): Linear(in_features=3072, out_features=768, bias=True)
          (act): GELU(approximate='tanh')
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_hea

In [19]:
import inspect

# Example 1: Basic function signature
def add(a: int, b=1, *args, **kwargs):
    """Return ``a + b``; extra positional and keyword arguments are accepted but unused."""
    total = a + b
    return total

# Build the Signature object for `add` and show its textual form.
sig = inspect.signature(add)
print(sig)  # (a: int, b=1, *args, **kwargs)

(a: int, b=1, *args, **kwargs)


In [17]:
import inspect

# Example 1: Basic function signature
def add(a, b=1, *args, **kwargs):
    """Return the sum of ``a`` and ``b``; any additional arguments are unused."""
    result = a + b
    return result

sig = inspect.signature(add)
print(sig)  # (a, b=1, *args, **kwargs)

# Example 2: walk the parameters and show each one's default and kind
for name, param in sig.parameters.items():
    print(f"{name}: {param.default}, {param.kind}")

# Example 3: membership test against the parameter mapping's keys
has_b = 'b' in sig.parameters.keys()
print(has_b)  # True

# Example 4: map concrete call arguments onto parameter names
bound_args = sig.bind(5, 3)
print(bound_args.arguments)  # {'a': 5, 'b': 3}

# Example 5: Class method signature
class Calculator:
    """Tiny class used to demonstrate signature inspection of a method."""

    def multiply(self, x, y=2):
        """Return ``x * y``."""
        product = x * y
        return product

# The function is looked up on the class, so `self` is part of the signature.
calc_sig = inspect.signature(Calculator.multiply)
print(calc_sig)  # (self, x, y=2)

torch.Size([50257, 768])

In [55]:
sd = model.state_dict()
# Keep every state-dict entry except the attention causal-mask buffers,
# which are registered buffers rather than learnable parameters.
sd_keys = [key for key in sd if not key.endswith('.attn.mask')]
len(sd_keys)

149

In [44]:
from transformers import GPT2LMHeadModel
model_hf = GPT2LMHeadModel.from_pretrained("/data/repos/huggingface/gpt2")
sd_hf = model_hf.state_dict()

# Drop the checkpoint's buffer entries ('.attn.masked_bias' and the '.attn.bias'
# mask) in a single pass; str.endswith accepts a tuple of suffixes.
sd_keys_hf = [
    key for key in sd_hf
    if not key.endswith(('.attn.masked_bias', '.attn.bias'))
]
# Key suffixes whose tensors are transposed when copied into the local model.
transposed = [
    'attn.c_attn.weight',
    'attn.c_proj.weight',
    'mlp.c_fc.weight',
    'mlp.c_proj.weight',
]

In [47]:
# Copy every checkpoint tensor into the matching entry of the local state dict.
with torch.no_grad():
    for k in sd_keys_hf:
        hf_tensor = sd_hf[k]
        own_tensor = sd[k]
        if any(k.endswith(suffix) for suffix in transposed):
            # Conv1D weights are stored transposed relative to nn.Linear:
            # shapes must match once reversed, then copy the transpose.
            assert hf_tensor.shape[::-1] == own_tensor.shape
            own_tensor.copy_(hf_tensor.t())
        else:
            # Everything else is copied as-is after a shape check.
            assert hf_tensor.shape == own_tensor.shape
            own_tensor.copy_(hf_tensor)