
References:

https://jaketae.github.io/study/gpt/
https://medium.com/@sntaus/building-a-mini-gpt-like-language-model-from-scratch-27257bf5c145

In [1]:
import math
import torch
import torch.nn.functional as F
from torch import nn

In [2]:
class GPTConfig:
    """Base hyperparameter container for the GPT model.

    The three dropout probabilities are class-level defaults; subclasses
    or keyword arguments may override them. ``vocab_size`` and ``max_len``
    are always required and stored on the instance.
    """

    # Dropout applied inside attention (see MultiheadAttention).
    attn_dropout = 0.1
    # Dropout applied to the summed token + position embeddings (see GPT).
    embed_dropout = 0.1
    # Dropout applied after the feed-forward and attention projections.
    ff_dropout = 0.1

    def __init__(self, vocab_size, max_len, **kwargs):
        """Store the required sizes plus any extra settings.

        Args:
            vocab_size: Number of distinct token ids.
            max_len: Maximum sequence length the model will accept.
            **kwargs: Arbitrary extra settings; each one is set as an
                instance attribute and shadows a class-level default of
                the same name.
        """
        self.vocab_size, self.max_len = vocab_size, max_len
        for name in kwargs:
            setattr(self, name, kwargs[name])

class GPT1Config(GPTConfig):
    """Size preset: embedding width 768, 12 blocks, 12 attention heads.

    Only class-level defaults are added here; construction is inherited
    unchanged from ``GPTConfig``.
    """

    # Width of token/position embeddings and of every block.
    embed_dim = 768
    # Number of stacked transformer blocks.
    num_blocks = 12
    # Attention heads per block; must divide embed_dim evenly.
    num_heads = 12

In [3]:
class MultiheadAttention(nn.Module):
    """Causal multi-head self-attention.

    Each position attends only to itself and to earlier positions; this
    is enforced with a lower-triangular mask of side ``config.max_len``
    that is cropped to the actual sequence length in ``forward``.

    Reads from ``config``: ``embed_dim``, ``num_heads``, ``max_len``,
    ``attn_dropout`` (dropout on the attention weights) and ``ff_dropout``
    (reused as the dropout on the output projection).
    """

    def __init__(self, config):
        super().__init__()
        embed_dim = config.embed_dim
        self.num_heads = config.num_heads
        assert embed_dim % self.num_heads == 0, "invalid heads and embedding dimension configuration"
        self.key = nn.Linear(embed_dim, embed_dim)
        self.value = nn.Linear(embed_dim, embed_dim)
        self.query = nn.Linear(embed_dim, embed_dim)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.attn_dropout = nn.Dropout(config.attn_dropout)
        self.proj_dropout = nn.Dropout(config.ff_dropout)
        # Lower-triangular ones, shaped (1, 1, max_len, max_len) so it
        # broadcasts over the batch and head dimensions.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(config.max_len, config.max_len))
            .unsqueeze(0).unsqueeze(0)
        )

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim) with
                ``seq_len <= config.max_len``.

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).
        """
        batch_size = x.size(0)
        seq_len = x.size(1)
        # Keys are produced already transposed: (batch, heads, head_dim, seq).
        k_t = self.key(x).reshape(batch_size, seq_len, self.num_heads, -1).permute(0, 2, 3, 1)
        # Values and queries: (batch, heads, seq, head_dim).
        v = self.value(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)
        q = self.query(x).reshape(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)

        # Scaled dot-product scores: (batch, heads, seq, seq).
        attn = torch.matmul(q, k_t) / math.sqrt(q.size(-1))
        mask = self.mask[:, :, :seq_len, :seq_len]
        attn = attn.masked_fill(mask == 0, float("-inf"))
        # Normalise first, then drop attention *weights*. Applying dropout
        # to the masked scores instead multiplies -inf by 0 for dropped
        # entries, which yields NaN during training.
        attn = F.softmax(attn, dim=-1)
        attn = self.attn_dropout(attn)

        y = torch.matmul(attn, v)
        # (batch, heads, seq, head_dim) -> (batch, seq, heads, head_dim)
        y = y.transpose(1, 2)
        # Merge the heads back into a single embed_dim axis.
        y = y.reshape(batch_size, seq_len, -1)
        y = self.proj_dropout(self.proj(y))
        return y
    
class Block(nn.Module):
    """One transformer block: attention and feed-forward, each residual.

    LayerNorm is applied to the input of each sub-layer (before
    attention and before the feed-forward network), and the sub-layer
    output is added back onto the un-normalised input.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        hidden = width * 4  # feed-forward expands to 4x the model width
        self.ln1 = nn.LayerNorm(width)
        self.ln2 = nn.LayerNorm(width)
        self.attn = MultiheadAttention(config)
        feed_forward = [
            nn.Linear(width, hidden),
            nn.GELU(),
            nn.Linear(hidden, width),
            nn.Dropout(config.ff_dropout),
        ]
        self.ff = nn.Sequential(*feed_forward)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual steps."""
        attended = self.attn(self.ln1(x))
        x = x + attended
        transformed = self.ff(self.ln2(x))
        return x + transformed

In [4]:
class GPT(nn.Module):
    """Decoder-only transformer mapping token ids to per-position logits.

    Token embeddings are summed with learned position embeddings
    (initialised to zeros), passed through ``config.num_blocks`` blocks,
    a final LayerNorm, and a linear layer of width ``config.vocab_size``.
    """

    def __init__(self, config):
        super().__init__()
        width = config.embed_dim
        self.max_len = config.max_len
        self.tok_embed = nn.Embedding(config.vocab_size, width)
        # One learned vector per position, with a leading axis of size 1
        # so it broadcasts across the batch.
        self.pos_embed = nn.Parameter(torch.zeros(1, config.max_len, width))
        self.dropout = nn.Dropout(config.embed_dropout)
        layers = [Block(config) for _ in range(config.num_blocks)]
        self.blocks = nn.Sequential(*layers)
        self.ln = nn.LayerNorm(width)
        self.fc = nn.Linear(width, config.vocab_size)

    def forward(self, x, target=None):
        """Compute logits for every position of ``x``.

        Args:
            x: Token ids of shape (batch_size, seq_len).
            target: Accepted but not used by this implementation.

        Returns:
            Tensor of shape (batch_size, seq_len, vocab_size).

        Raises:
            AssertionError: If ``seq_len`` exceeds ``self.max_len``.
        """
        seq_len = x.size(1)
        assert seq_len <= self.max_len, "sequence longer than model capacity"

        # (batch, seq, embed) + (1, seq, embed): positions broadcast over batch.
        hidden = self.tok_embed(x) + self.pos_embed[:, :seq_len, :]
        hidden = self.dropout(hidden)
        hidden = self.blocks(hidden)
        return self.fc(self.ln(hidden))

In [5]:
# Build the causal mask the same way MultiheadAttention does: a
# lower-triangular matrix of ones with two leading singleton axes,
# giving shape (1, 1, max_len, max_len).
max_len = 5
mask = torch.tril(torch.ones(max_len, max_len)).unsqueeze(0).unsqueeze(0)
mask

# Crop the mask to a shorter sequence, as forward() does per call.
seq_len = 3
mask = mask[:, :, :seq_len, :seq_len]
mask

tensor([[[[1., 0., 0.],
          [1., 1., 0.],
          [1., 1., 1.]]]])

In [6]:
batch_size = 3
num_heads = 2

# Random stand-in for attention scores, one (seq_len x seq_len) matrix
# per batch element and head.
attn = torch.randn(batch_size, num_heads, seq_len, seq_len)
attn.shape

torch.Size([3, 2, 3, 3])

In [7]:
# Wherever the mask is 0 (above the diagonal) the score becomes -inf;
# the (1, 1, seq, seq) mask broadcasts over batch and heads.
attn = attn.masked_fill(mask == 0, float("-inf"))
attn

tensor([[[[ 1.7825,    -inf,    -inf],
          [-0.5122,  0.0047,    -inf],
          [-0.2257, -1.3988,  2.2008]],

         [[ 0.9375,    -inf,    -inf],
          [ 1.4302, -1.0021,    -inf],
          [-0.8016, -0.4134, -0.2736]]],


        [[[ 1.8246,    -inf,    -inf],
          [ 0.7566, -0.4946,    -inf],
          [ 1.1573,  0.4111,  0.5795]],

         [[ 1.7281,    -inf,    -inf],
          [ 2.3234,  0.6134,    -inf],
          [ 0.9229,  1.3463,  1.5307]]],


        [[[ 0.0044,    -inf,    -inf],
          [-0.2688, -1.2587,    -inf],
          [-0.0407, -0.6177,  1.4859]],

         [[-0.3037,    -inf,    -inf],
          [ 1.1637, -1.4806,    -inf],
          [-0.4594, -0.7926,  1.6997]]]])

In [8]:
# Softmax over the last axis: -inf scores become weight 0, so each row
# distributes its weight only over the current and earlier positions.
F.softmax(attn, dim=-1)

tensor([[[[1.0000, 0.0000, 0.0000],
          [0.3736, 0.6264, 0.0000],
          [0.0792, 0.0245, 0.8963]],

         [[1.0000, 0.0000, 0.0000],
          [0.9193, 0.0807, 0.0000],
          [0.2398, 0.3536, 0.4066]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.7775, 0.2225, 0.0000],
          [0.4913, 0.2330, 0.2757]],

         [[1.0000, 0.0000, 0.0000],
          [0.8468, 0.1532, 0.0000],
          [0.2292, 0.3500, 0.4208]]],


        [[[1.0000, 0.0000, 0.0000],
          [0.7291, 0.2709, 0.0000],
          [0.1622, 0.0911, 0.7467]],

         [[1.0000, 0.0000, 0.0000],
          [0.9337, 0.0663, 0.0000],
          [0.0963, 0.0690, 0.8346]]]])

In [9]:
# Small vocabulary and context length for a quick smoke test; the
# remaining sizes come from the GPT1Config class defaults.
vocab_size = 10
max_len = 12

config = GPT1Config(vocab_size, max_len)
model = GPT(config)

# Display the module tree.
model

GPT(
  (tok_embed): Embedding(10, 768)
  (dropout): Dropout(p=0.1, inplace=False)
  (blocks): Sequential(
    (0): Block(
      (ln1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (attn): MultiheadAttention(
        (key): Linear(in_features=768, out_features=768, bias=True)
        (value): Linear(in_features=768, out_features=768, bias=True)
        (query): Linear(in_features=768, out_features=768, bias=True)
        (proj): Linear(in_features=768, out_features=768, bias=True)
        (attn_dropout): Dropout(p=0.1, inplace=False)
        (proj_dropout): Dropout(p=0.1, inplace=False)
      )
      (ff): Sequential(
        (0): Linear(in_features=768, out_features=3072, bias=True)
        (1): GELU(approximate='none')
        (2): Linear(in_features=3072, out_features=768, bias=True)
        (3): Dropout(p=0.1, inplace=False)
      )
    )
    (1): Block(
      (ln1): LayerNorm((768,), eps=1e-05, elemen

In [10]:
# 15 tokens is longer than max_len (12), so the forward pass is expected
# to trip the model's length assertion.
seq_len = 15

# Random token ids in [0, vocab_size); batch_size is reused from the
# earlier attention demo cell.
test_input = torch.randint(high=vocab_size, size=(batch_size, seq_len))
try:
    model(test_input).shape
except AssertionError as e:
    print(e)

sequence longer than model capacity


In [11]:
# Keeping only the first max_len tokens satisfies the length check; the
# result has one vocab_size-wide logit vector per input position.
model(test_input[:, :max_len]).shape

torch.Size([3, 12, 10])

In [None]:
def infer_recursive(model, input_vectors, max_output_token_count=10):
    """Greedily generate a continuation for each sequence in a batch.

    Every row of ``input_vectors`` is processed on its own: the model is
    called, the highest-scoring next token is appended to the input, and
    the process repeats until ``'<end>'`` is predicted or
    ``max_output_token_count`` tokens have been produced.

    Relies on two module-level names that must be defined elsewhere:
    ``word_to_ix`` (a mapping containing the key ``'<end>'``) and
    ``pad_tensors`` (called with the list of per-sequence predictions).

    Args:
        model: PyTorch module called with a (1, token_count) tensor of
            token ids. Its output may hold logits for the next token only,
            shape (1, vocab_size), or for every position, shape
            (1, token_count, vocab_size); in the latter case the last
            position is used. If the model exposes an integer ``max_len``
            attribute, the input is cropped to its last ``max_len`` tokens
            before each call.
        input_vectors: 2-D tensor of token ids, one sequence per row.
        max_output_token_count: Upper bound on the number of tokens
            generated per sequence (a predicted ``'<end>'`` counts).

    Returns:
        The result of ``pad_tensors`` applied to the list of 1-D int64
        tensors of predicted token ids, one per input sequence.

    Note:
        The model is switched to evaluation mode and left in that mode.
    """
    model.eval()
    end_index = word_to_ix['<end>']
    # Models with a fixed context window (e.g. ones that assert on
    # over-long inputs) advertise it via ``max_len``; None means no crop.
    context_limit = getattr(model, "max_len", None)
    outputs = []

    with torch.no_grad():
        for i, sequence in enumerate(input_vectors):
            print(f"Infering sequence {i}")
            input_vector = sequence.reshape(1, -1)
            predicted_sequence = []

            for _ in range(max_output_token_count):
                model_input = input_vector
                if context_limit is not None and model_input.size(1) > context_limit:
                    # Keep only the most recent tokens the model can see.
                    model_input = model_input[:, -context_limit:]

                logits = model(model_input)[0]
                if logits.dim() > 1:
                    # Per-position logits: only the last position predicts
                    # the next token. Taking argmax over the whole slice
                    # would return a flattened index, not a token id.
                    logits = logits[-1]
                predicted_index = logits.argmax().item()
                predicted_sequence.append(predicted_index)

                if predicted_index == end_index:
                    break

                # Match the input's dtype/device so the concatenation also
                # works for inputs that do not live on the CPU.
                next_token = torch.tensor(
                    [[predicted_index]],
                    dtype=input_vector.dtype,
                    device=input_vector.device,
                )
                input_vector = torch.cat([input_vector, next_token], dim=1)

            # Explicit dtype keeps an empty prediction list integer-typed.
            outputs.append(torch.tensor(predicted_sequence, dtype=torch.long))

    # Pad predicted sequences to a common length.
    return pad_tensors(outputs)