In [2]:
from transformers import GPT2Model,GPT2Config

In [3]:
# Small GPT-2 variant: 10 blocks x 10 heads over a 240-dim hidden size
# (per-head dim 240 / 10 = 24), GPT-2 vocabulary, 1024-token context.
config = GPT2Config(
    n_layer=10,        # number of transformer layers
    n_head=10,         # attention heads per layer
    n_embd=240,        # hidden size of the transformer embeddings
    n_positions=1024,  # maximum sequence length
    vocab_size=50257,  # vocabulary size of the GPT-2 model
)

In [4]:
# Build an untrained Hugging Face GPT-2 backbone from the config above.
model = GPT2Model(config=config)

In [20]:
# Display the Hugging Face model's module hierarchy (rendered as the output below).
model

GPT2Model(
  (wte): Embedding(50257, 240)
  (wpe): Embedding(1024, 240)
  (drop): Dropout(p=0.1, inplace=False)
  (h): ModuleList(
    (0-9): 10 x GPT2Block(
      (ln_1): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
      (attn): GPT2Attention(
        (c_attn): Conv1D()
        (c_proj): Conv1D()
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (resid_dropout): Dropout(p=0.1, inplace=False)
      )
      (ln_2): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
      (mlp): GPT2MLP(
        (c_fc): Conv1D()
        (c_proj): Conv1D()
        (act): NewGELUActivation()
        (dropout): Dropout(p=0.1, inplace=False)
      )
    )
  )
  (ln_f): LayerNorm((240,), eps=1e-05, elementwise_affine=True)
)

In [8]:
# Total number of parameters in the Hugging Face model, expressed in millions.
sum(param.numel() for param in model.parameters()) / 1e6

19.25112

In [179]:
import torch
import torch.nn as nn
import lightning as L
import torch.nn.functional as F
import numpy as np
from torch.utils.data import Dataset, DataLoader
import torchmetrics
import random
import os
from typing import Dict, Any
from dataclasses import dataclass
@dataclass
class GPT2Config:
    """Hyper-parameters for the custom decoder-only model and its training loop.

    Note: this dataclass shadows the Hugging Face ``GPT2Config`` imported earlier in
    the notebook; once this cell runs, ``GPT2Config`` refers to this class.

    Values are validated in ``__post_init__`` so an inconsistent configuration fails
    at construction time rather than as a shape error deep inside a forward pass.

    Raises:
        ValueError: if a size field is not positive, ``embed_dim`` is not divisible
            by ``num_heads``, or ``dropout`` is outside ``[0, 1]``.
    """

    vocab_size: int  # number of token ids (embedding rows / lm_head outputs)
    embed_dim: int  # width of the residual stream
    block_size: int  # maximum sequence length (position-embedding rows)
    num_heads: int  # attention heads per block; must divide embed_dim
    num_layers: int  # number of transformer blocks
    dropout: float  # dropout probability read by the model's submodules
    lr: float  # AdamW learning rate
    t_max: int  # CosineAnnealingLR period, in optimizer steps

    def __post_init__(self) -> None:
        for field_name in ("vocab_size", "embed_dim", "block_size", "num_heads", "num_layers"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be a positive integer, got {value!r}")
        # Each head gets embed_dim // num_heads features and the concatenated heads
        # feed a Linear(embed_dim, ...) projection, so the split must be exact.
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by num_heads ({self.num_heads})"
            )
        if not 0.0 <= self.dropout <= 1.0:
            raise ValueError(f"dropout must be in [0, 1], got {self.dropout!r}")

class SwiGLU(nn.Module):
    """Gated feed-forward layer: ``dropout(W2(SiLU(W x) * V x))``.

    The hidden width is ``int(embed_dim * 8/3)``. With three bias-free matrices this
    gives ~``8 * embed_dim**2`` weights, the same as a classic 4x two-matrix MLP.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        # Kept in this exact floating-point form so hidden widths (and therefore
        # checkpoint shapes) are unchanged.
        hidden_size = int(config.embed_dim * (4 * (2 / 3)))
        self.linear_w = nn.Linear(config.embed_dim, hidden_size, bias=False)
        self.linear_v = nn.Linear(config.embed_dim, hidden_size, bias=False)
        self.linear_w2 = nn.Linear(hidden_size, config.embed_dim, bias=False)
        # Use the configured rate instead of a fixed constant, so that a config with
        # dropout=0.0 really disables dropout here too.
        self.dropout = nn.Dropout(config.dropout)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialisation for all three projections."""
        nn.init.xavier_uniform_(self.linear_w.weight)
        nn.init.xavier_uniform_(self.linear_v.weight)
        nn.init.xavier_uniform_(self.linear_w2.weight)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP over the trailing ``embed_dim`` axis of ``x``."""
        gated = F.silu(self.linear_w(x)) * self.linear_v(x)
        return self.dropout(self.linear_w2(gated))

class Head(nn.Module):
    """A single causal self-attention head with its own packed Q/K/V projection.

    ``qkv`` maps ``embed_dim`` features to ``3 * head_dim`` (query, key, value packed
    along the last axis); attention itself is PyTorch's fused
    ``scaled_dot_product_attention`` with a causal mask.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.head_dim = config.embed_dim // config.num_heads
        self.qkv = nn.Linear(config.embed_dim, 3 * self.head_dim, bias=False)
        # Only ``.p`` is read: the fused kernel applies attention dropout itself.
        self.dropout = nn.Dropout(config.dropout)
        self._init_weights()

    def _init_weights(self):
        """Xavier-uniform initialisation of the packed projection."""
        nn.init.xavier_uniform_(self.qkv.weight)

    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
        """Attend causally over ``x`` and return one head's output.

        ``x`` is unpacked as (batch, seq, embed_dim); the result is
        (batch, seq, head_dim). ``training`` (not ``self.training``) decides whether
        attention dropout is applied.
        """
        bsz, n_tokens, _ = x.shape
        packed = self.qkv(x).reshape(bsz, n_tokens, 3, self.head_dim)
        query, key, value = packed.unbind(dim=2)
        attn_p = self.dropout.p if training else 0.0
        return F.scaled_dot_product_attention(
            query, key, value, attn_mask=None, dropout_p=attn_p, is_causal=True
        )

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention assembled from independent ``Head`` modules.

    Head outputs are concatenated on the feature axis, mixed by a bias-free
    ``embed_dim -> embed_dim`` projection, and passed through dropout.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.heads = nn.ModuleList(Head(config) for _ in range(config.num_heads))
        self.proj = nn.Linear(config.embed_dim, config.embed_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
        """Run every head on ``x``, then project and apply output dropout."""
        # ``training`` gates dropout inside each head; ``self.dropout`` follows the
        # module's train()/eval() mode.
        per_head = [head(x, training) for head in self.heads]
        mixed = self.proj(torch.cat(per_head, dim=-1))
        return self.dropout(mixed)

class Block(nn.Module):
    """Pre-LayerNorm transformer block, as used by GPT-2.

    Each sub-layer reads a normalised copy of the residual stream and adds its
    output back onto the un-normalised stream::

        x = x + attention(ln1(x))
        x = x + ffwd(ln2(x))
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.attention = MultiHeadAttention(config)
        self.ffwd = SwiGLU(config)
        self.ln1 = nn.LayerNorm(config.embed_dim)
        self.ln2 = nn.LayerNorm(config.embed_dim)

    def forward(self, x: torch.Tensor, training: bool) -> torch.Tensor:
        """Apply attention then the feed-forward layer, each with a residual add.

        ``training`` is forwarded to the attention layer (it gates attention dropout).
        """
        # Normalise only the sub-layer input. Overwriting ``x`` with ``ln1(x)`` would
        # route the skip connection through LayerNorm (post-LN behaviour). Keeping
        # the residual path an identity is the pre-LN design that the final
        # ``ln_f`` in TransformerDecoder assumes.
        x = x + self.attention(self.ln1(x), training)
        x = x + self.ffwd(self.ln2(x))
        return x

class TransformerDecoder(nn.Module):
    """Decoder-only language model.

    Token embeddings plus learned position embeddings feed a stack of ``Block``s,
    followed by a final LayerNorm and an LM head whose weight is tied to the token
    embedding.
    """

    def __init__(self, config: GPT2Config):
        super().__init__()
        self.config = config
        self.token_embeddings = nn.Embedding(config.vocab_size, config.embed_dim)
        self.position_embeddings = nn.Embedding(config.block_size, config.embed_dim)
        self.blocks = nn.ModuleList([Block(config) for _ in range(config.num_layers)])
        self.ln_f = nn.LayerNorm(config.embed_dim)
        self.lm_head = nn.Linear(config.embed_dim, config.vocab_size)
        # Weight tying: the input embedding and the output projection share one
        # (vocab_size, embed_dim) parameter.
        self.token_embeddings.weight = self.lm_head.weight

    def forward(self, x: torch.Tensor, training: bool = False) -> torch.Tensor:
        """Map token ids ``x`` of shape (batch, seq) to logits (batch, seq, vocab_size).

        Args:
            x: integer token ids; ``seq`` must not exceed ``config.block_size``.
            training: forwarded to every block, where it enables attention dropout.

        Raises:
            ValueError: if the sequence is longer than ``config.block_size`` (there
                is no position embedding for it).
        """
        _, seq_len = x.shape
        if seq_len > self.config.block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {self.config.block_size}"
            )
        positions = torch.arange(seq_len, device=x.device)
        hidden = self.token_embeddings(x) + self.position_embeddings(positions)
        for block in self.blocks:
            hidden = block(hidden, training)
        return self.lm_head(self.ln_f(hidden))

    @torch.inference_mode()
    def generate(self, tokenizer, max_tokens: int, text: str = "", temperature: float = 0.0) -> str:
        """Autoregressively extend ``text`` by ``max_tokens`` tokens and decode the result.

        Args:
            tokenizer: tokenizer callable as ``tokenizer(text=..., return_tensors='pt')``
                returning a mapping with ``'input_ids'``, and providing ``bos_token``
                and ``decode``.
            max_tokens: number of new tokens to produce.
            text: prompt; when empty, generation starts from ``tokenizer.bos_token``.
            temperature: ``0.0`` selects greedy decoding; otherwise logits are divided
                by it and the next token is drawn with ``torch.multinomial``.

        Returns:
            The decoded prompt followed by the full generated continuation.

        Only the last ``block_size`` tokens are fed to the model at each step, but the
        complete sequence is kept and decoded. The module's previous train/eval mode
        is restored afterwards, even if generation raises.
        """
        was_training = self.training
        self.eval()
        try:
            if text == "":
                text = tokenizer.bos_token
            # Tokenizers return CPU tensors; move them to wherever the model lives.
            device = self.lm_head.weight.device
            input_ids = tokenizer(text=text, return_tensors='pt')['input_ids'].to(device)
            for _ in range(max_tokens):
                context = input_ids[:, -self.config.block_size:]
                logits = self(context)[:, -1, :]
                if temperature == 0.0:
                    next_token = torch.topk(logits, k=1, dim=-1).indices
                else:
                    probs = F.softmax(logits / temperature, dim=-1)
                    next_token = torch.multinomial(probs, num_samples=1)
                input_ids = torch.cat((input_ids, next_token), dim=1)
        finally:
            self.train(was_training)
        return tokenizer.decode(token_ids=input_ids[0].tolist())

class GPT2(L.LightningModule):
    """Lightning wrapper that trains ``TransformerDecoder`` on next-token prediction."""

    def __init__(self, config: GPT2Config, ignore_index: int = 50256):
        """
        Args:
            config: model and optimiser hyper-parameters.
            ignore_index: target token id excluded from the loss. The default, 50256,
                is GPT-2's ``<|endoftext|>`` id (presumably used as padding here).
        """
        super().__init__()
        self.model = TransformerDecoder(config)
        self.lr = config.lr
        self.t_max = config.t_max
        self.config = config
        self.ignore_index = ignore_index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits for token ids ``x``.

        The module's train/eval mode is forwarded so that attention dropout is active
        during training. TransformerDecoder defaults to ``training=False``, so
        omitting this silently disabled it.
        """
        return self.model(x, training=self.training)

    @torch.inference_mode()
    def generate(self, tokenizer, max_tokens: int, text: str = "", temperature: float = 0.0) -> str:
        """Delegate text generation to the wrapped ``TransformerDecoder``."""
        return self.model.generate(tokenizer=tokenizer, text=text, max_tokens=max_tokens, temperature=temperature)

    def _next_token_loss(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Mean cross-entropy of predicting token ``t + 1`` from positions ``<= t``."""
        logits = self(input_ids)[:, :-1, :]
        targets = input_ids[..., 1:]
        return F.cross_entropy(
            logits.reshape(-1, logits.shape[-1]),
            targets.reshape(-1),
            ignore_index=self.ignore_index,
        )

    def training_step(self, batch: dict, batch_idx: int) -> torch.Tensor:
        """Compute and log the per-step training loss for ``batch['input_ids']``."""
        loss = self._next_token_loss(batch['input_ids'])
        self.log('training_loss', loss, prog_bar=True, on_step=True, on_epoch=False, sync_dist=True)
        return loss

    def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
        """Compute and log the epoch-averaged validation loss."""
        loss = self._next_token_loss(batch['input_ids'])
        self.log('validation_loss', loss, prog_bar=True, on_step=False, on_epoch=True, sync_dist=True)

    def configure_optimizers(self) -> Dict[str, Any]:
        """AdamW with a cosine-annealing learning rate stepped every optimiser step.

        Raises:
            ValueError: if ``t_max`` is not positive. CosineAnnealingLR divides by
                ``T_max`` and would otherwise fail on its first step.
        """
        if self.t_max <= 0:
            raise ValueError(f"t_max must be positive to build the cosine schedule, got {self.t_max}")
        opt = torch.optim.AdamW(self.parameters(), lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=opt, T_max=self.t_max)
        return {
            'optimizer': opt,
            'lr_scheduler': {
                "scheduler": scheduler,
                "monitor": "training_loss",
                "interval": "step",
                "frequency": 1,
            }
        }

class GenerateCallback(L.pytorch.callbacks.Callback):
    """Print a sampled continuation from the model at the end of each training epoch."""

    def __init__(self, tokenizer, max_tokens: int = 256, temperature: float = 1.0):
        """
        Args:
            tokenizer: tokenizer passed through to ``GPT2.generate``.
            max_tokens: number of tokens to sample each epoch.
            temperature: sampling temperature (``0.0`` means greedy).
        """
        super().__init__()
        self.tokenizer = tokenizer
        self.max_tokens = max_tokens
        self.temperature = temperature

    def on_train_epoch_end(self, trainer: L.Trainer, pl_module: GPT2) -> None:
        # The generic ``on_epoch_end`` callback hook was removed in Lightning 1.8.
        # Lightning 2.x refuses to run callbacks that still define it, so use the
        # train-specific hook instead.
        if not trainer.is_global_zero:
            return  # under multi-process training, print from one rank only
        generated_text = pl_module.generate(
            tokenizer=self.tokenizer, max_tokens=self.max_tokens, temperature=self.temperature
        )
        print("Generated text:", generated_text)


In [180]:
# `GPT2Config` here is the dataclass defined above (it shadows the HF class), so it
# takes embed_dim / num_heads / num_layers rather than HF-style n_* names.
# Tiny 1-layer, 1-head model for smoke-testing; lr and t_max are only read by the
# Lightning wrapper, not by TransformerDecoder.
config = GPT2Config(
    vocab_size=50257,
    embed_dim=100,
    block_size=1024,
    num_heads=1,
    num_layers=1,
    dropout=0.0,
    lr=0.0,
    t_max=0,
)
model = TransformerDecoder(config)

In [176]:
# Smoke test: logits for a single token (id 1), shape (1, 1, vocab_size).
model(torch.ones(1, 1, dtype=torch.long), training=False)

tensor([[[ 0.2459,  0.4977,  1.0184,  ...,  0.0103,  0.0598, -0.4062]]],
       grad_fn=<ViewBackward0>)

In [None]:
# Parameter count (in millions) of the custom model, comparable with the Hugging Face
# model's count earlier in the notebook. Tied weights are counted once, because
# model.parameters() yields each shared parameter a single time.
sum(p.numel() for p in model.parameters()) / 1e6