<a href="https://colab.research.google.com/github/HolgerMolin/HolgerGPT/blob/main/GPT2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import torch
from torch import nn
import torch.nn.functional as F
import math

In [3]:
class FFN(nn.Module):
  def __init__(self, embedding_dim, hidden_dim):
    super().__init__()
    self.linear_1 = nn.Linear(embedding_dim, hidden_dim)
    self.linear_2 = nn.Linear(hidden_dim, embedding_dim)

  def forward(self, x):
    return self.linear_2(F.relu(self.linear_1(x)))


class AttentionHead(nn.Module):
    def __init__(self, embedding_dim, key_dim, context_length):
        super().__init__()
        self.w_q = torch.randn((embedding_dim, key_dim), requires_grad=True)
        self.w_k = torch.randn((embedding_dim, key_dim), requires_grad=True)
        self.w_v1 = torch.randn((embedding_dim, key_dim), requires_grad=True)
        self.w_v2 = torch.randn((key_dim, embedding_dim), requires_grad=True)

        self.mask = torch.triu(torch.ones(context_length, context_length)) == 0
        self.k_dim = key_dim


    def forward(self, x):
        keys = x @ self.w_k
        queries = x @ self.w_q
        attention = queries @ torch.transpose(keys, 1, 2) * (self.k_dim ** -0.5)
        attention = attention.masked_fill(self.mask, -float('inf'))
        attention = F.softmax(attention, dim=-1)
        values = x @ self.w_v1 @ self.w_v2
        additions = attention @ values

        return additions

class MultiHeadAttention(nn.Module):
  def __init__(self, embedding_dim, num_heads, context_length):
    super().__init__()
    assert embedding_dim % num_heads == 0
    self.key_dim = int(embedding_dim / num_heads)
    self.heads = nn.ModuleList(AttentionHead(embedding_dim, self.key_dim, context_length) for _ in range(num_heads))

  def forward(self, x):
    for head in self.heads:
        x += head(x)
    return x
      
class TransformerBlock(nn.Module):
  def __init__(self, embedding_dim, num_heads, hidden_dim, context_length):
    super().__init__()
    self.ffn = FFN(embedding_dim, hidden_dim)
    self.attention = MultiHeadAttention(embedding_dim, num_heads, context_length)
    self.norm1 = nn.LayerNorm((context_length, embedding_dim))
    self.norm2 = nn.LayerNorm((context_length, embedding_dim))

  def forward(self, x):
    x += self.ffn(x)
    x = self.norm1(x)
    x += self.attention(x)
    x = self.norm2(x)
    return x

class PositionalEncoding(nn.Module):
    def __init__(self, context_length, embedding_dim):
        super(PositionalEncoding, self).__init__()
        
        self.embedding_dim = embedding_dim
        self.context_length = context_length

        pe = torch.zeros(context_length, embedding_dim)
        position = torch.arange(0, context_length, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embedding_dim, 2).float() * (-math.log(10000.0) / embedding_dim))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:, :x.size(1)]
        return x

class HolgerGPT(nn.Module):
  def __init__(self, num_layers, embedding_dim, num_heads, hidden_dim, context_length, vocab_size):
    super().__init__()
    self.word_embedding = nn.Embedding(vocab_size, embedding_dim)
    self.pos_embedding = PositionalEncoding(context_length, embedding_dim)
    self.transformer_blocks = nn.ModuleList(TransformerBlock(embedding_dim, num_heads, hidden_dim, context_length) for _ in range(num_layers))
    self.linear = nn.Linear(embedding_dim * context_length, context_length)
    self.embedding_dim = embedding_dim
    self.context_length = context_length

  def forward(self, x):
    x = self.word_embedding(x)
    x = self.pos_embedding(x)
    for block in self.transformer_blocks:
      x = block.forward(x)
    x = self.linear(x.reshape(-1, self.embedding_dim * self.context_length))
    return x


In [4]:
model = HolgerGPT(4, 128, 32, 256, 256, 256)
print(f'Number of parameters: {sum(p.numel() for p in model.parameters()):_}')
x = torch.ones((128, 256), dtype=torch.long)

Number of parameters: 9_209_600


In [5]:
y = model.forward(x)

In [6]:
y.shape

torch.Size([128, 256])