In [5]:
import torch
from torch import nn
import torch.nn.functional as F
import math
import time

In [29]:
class Attention(nn.Module):
  """One causal self-attention head.

  The input is projected to queries, keys and values of width ``head_dim``;
  the values are then mixed with softmax-normalised, scaled dot-product
  weights. A lower-triangular mask stops a position from attending to any
  position that comes after it.
  """

  def __init__(self, ctx_size, emb_size, head_dim):
    """Create the three bias-free projections and the causal mask.

    Args:
      ctx_size: side length of the stored lower-triangular mask.
      emb_size: size of the last dimension of the input.
      head_dim: output width of the query/key/value projections.
    """
    super().__init__()
    self.head_dim = head_dim
    self.q = nn.Linear(emb_size, head_dim, bias=False)
    self.k = nn.Linear(emb_size, head_dim, bias=False)
    self.v = nn.Linear(emb_size, head_dim, bias=False)

    # Registered as a buffer, not a parameter: it moves with the module
    # (.to / .cuda) but receives no gradient updates.
    causal = torch.ones(ctx_size, ctx_size).tril()
    self.register_buffer('tril', causal)

  def forward(self, x):
    """Apply masked attention to ``x`` of shape (B, T, C); returns (B, T, head_dim)."""
    _, seq_len, _ = x.shape

    queries, keys, values = self.q(x), self.k(x), self.v(x)

    # Scaled dot-product scores, shape (B, T, T).
    scale = self.head_dim ** -0.5
    scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

    # Entries above the diagonal become -inf so softmax gives them weight 0.
    hidden = self.tril[:seq_len, :seq_len] == 0
    scores = scores.masked_fill(hidden, float('-inf'))
    weights = torch.softmax(scores, dim=-1)
    return torch.matmul(weights, values)

class Transformer(nn.Module):
  def __init__(self, ctx_size, emb_size, head_size):
    super().__init__()
    head_dim = emb_size // head_size
    self.heads = nn.ModuleList([Attention(ctx_size, emb_size, head_dim) for _ in range(head_size)])
    self.proj = nn.Linear(emb_size, emb_size, bias=False)
    self.mlp = MLP(emb_size)
    self.ln1 = nn.LayerNorm(emb_size, bias=False)
    self.ln2 = nn.LayerNorm(emb_size, bias=False)

  def forward(self, x):
    x = self.ln1(x)
    y = torch.cat([head(x) for head in self.heads])
    x = x + self.proj(y)
    y = self.ln2(x)
    x = x + self.mlp(y)

class MLP(nn.Module):
  """Feed-forward network: widen to 4x ``emb_size``, apply ReLU, project back."""

  def __init__(self, emb_size):
    """Build the two bias-free linear layers around a ReLU.

    Args:
      emb_size: size of the last dimension of both input and output.
    """
    super().__init__()
    hidden_size = emb_size * 4
    widen = nn.Linear(emb_size, hidden_size, bias=False)
    narrow = nn.Linear(hidden_size, emb_size, bias=False)
    # Kept as an nn.Sequential named `mlp` so parameter names stay mlp.0 / mlp.2.
    self.mlp = nn.Sequential(widen, nn.ReLU(), narrow)

  def forward(self, x):
    """Apply the network over the last dimension of ``x``."""
    out = self.mlp(x)
    return out

class GPT(nn.Module):
  def __init__(self, vocab_size, ctx_size, emb_size, layer_cnt, head_size):
    super().__init__()
    self.ctx_size = ctx_size
    self.layer_cnt = layer_cnt

    #embedding
    self.wemb = nn.Embedding(vocab_size, emb_size)
    self.pemb = nn.Embedding(ctx_size, emb_size)

    #Transfomer layers
    self.layers = nn.Sequential(*[Transformer(ctx_size,emb_size,head_size) for _ in range(layer_cnt)])

    self.ln = nn.LayerNorm(emb_size, bias=False)

    #Language modelin head
    self.ff = nn.Linear(emb_size, vocab_size, bias=False)


  def forward(self, x, targets=None):
    B, T = x.shape

    #embedding
    wemb = self.wemb(x)
    pemb = self.pemb(torch.arange(T, device='cuda'))

    x = wemb + pemb #B, T, C

    x = self.layers(x)

    logit = self.ff(self.ln(x))

In [31]:
# Instantiate the model with every hyper-parameter spelled out by name.
model = GPT(vocab_size=4096, ctx_size=512, emb_size=384, layer_cnt=6, head_size=6)
# Total number of parameter elements across all parameters of the model.
sum(p.numel() for p in model.parameters())

13964160

In [8]:
# Integer positions 0..3: the same kind of index tensor GPT.forward builds
# for the position-embedding lookup.
torch.arange(4)

tensor([0, 1, 2, 3])

In [12]:
# Random (1, c, c) tensor used below to try out transposing the last two
# dimensions. No seed is set, so the values differ on every run.
c = 4
k = torch.rand((1,c,c))
k

tensor([[[0.2073, 0.7470, 0.3432, 0.8191],
         [0.1831, 0.3467, 0.3920, 0.6274],
         [0.5366, 0.2404, 0.4426, 0.7120],
         [0.3001, 0.9236, 0.1838, 0.2203]]])

In [13]:
# Swap the last two dimensions; the leading dimension of size 1 stays in place.
k.transpose(-2,-1)

tensor([[[0.2073, 0.1831, 0.5366, 0.3001],
         [0.7470, 0.3467, 0.2404, 0.9236],
         [0.3432, 0.3920, 0.4426, 0.1838],
         [0.8191, 0.6274, 0.7120, 0.2203]]])

In [15]:
# Negative indices count from the end: -2 selects the second-to-last element.
a = [1,2,3,4]
a[-2]

3

In [16]:
# 1/sqrt(d) for d = 4; compare with the `** -0.5` form used in Attention.forward.
1/math.sqrt(4)

0.5

In [17]:
# Same value via a negative fractional exponent: `**` binds tighter than `*`,
# so this is 1 * (4 ** -0.5).
1 * 4 ** -0.5

0.5

In [20]:
# Exponent sanity check: `** -1` is the reciprocal, `** 0.5` matches math.sqrt.
4, 4 **-1, math.sqrt(4), 4 ** 0.5

(4, 0.25, 2.0, 2.0)

In [21]:
# 4x4 matrix of ones: the starting point for the lower-triangular mask below.
torch.ones((4,4))

tensor([[1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.],
        [1., 1., 1., 1.]])

In [25]:
# Lower-triangular 0/1 matrix: row i has ones in columns 0..i. This is the
# pattern Attention registers as its 'tril' buffer.
tril = torch.tril(torch.ones(4,4))
tril

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [26]:
# Rebinds `k` to a fresh random 4x4 matrix (unseeded, so values vary per run).
k = torch.rand((4,4))
# Entries where tril is 0 (above the diagonal) are replaced with -inf.
# masked_fill returns a new tensor; `k` itself is left unchanged.
k.masked_fill(tril == 0, float('-inf'))

tensor([[0.8996,   -inf,   -inf,   -inf],
        [0.3985, 0.2193,   -inf,   -inf],
        [0.4521, 0.3498, 0.4845,   -inf],
        [0.7688, 0.5964, 0.3812, 0.8400]])

In [28]:
# Row-wise softmax over the causally masked scores: each row sums to 1 and the
# -inf entries get weight 0. `dim=-1` is passed explicitly because relying on
# the implicit dimension is deprecated and emits a warning.
F.softmax(k.masked_fill(tril == 0, float('-inf')), dim=-1)

  F.softmax(k.masked_fill(tril == 0, float('-inf')))


tensor([[1.0000, 0.0000, 0.0000, 0.0000],
        [0.5447, 0.4553, 0.0000, 0.0000],
        [0.3406, 0.3075, 0.3519, 0.0000],
        [0.2782, 0.2342, 0.1888, 0.2988]])

In [30]:
# Top-left 2x2 corner of the mask, as Attention.forward slices it for a
# sequence of length T = 2.
tril[:2, :2]

tensor([[1., 0.],
        [1., 1.]])

In [1]:
[0,1,2,3]

[0, 1, 2, 3]

In [11]:
# Build five separate Attention instances with a list comprehension, the same
# pattern Transformer.__init__ wraps in nn.ModuleList.
[ Attention(25,25,25) for _ in range(5)]

[Attention(
   (q): Linear(in_features=25, out_features=25, bias=False)
   (k): Linear(in_features=25, out_features=25, bias=False)
   (v): Linear(in_features=25, out_features=25, bias=False)
 ),
 Attention(
   (q): Linear(in_features=25, out_features=25, bias=False)
   (k): Linear(in_features=25, out_features=25, bias=False)
   (v): Linear(in_features=25, out_features=25, bias=False)
 ),
 Attention(
   (q): Linear(in_features=25, out_features=25, bias=False)
   (k): Linear(in_features=25, out_features=25, bias=False)
   (v): Linear(in_features=25, out_features=25, bias=False)
 ),
 Attention(
   (q): Linear(in_features=25, out_features=25, bias=False)
   (k): Linear(in_features=25, out_features=25, bias=False)
   (v): Linear(in_features=25, out_features=25, bias=False)
 ),
 Attention(
   (q): Linear(in_features=25, out_features=25, bias=False)
   (k): Linear(in_features=25, out_features=25, bias=False)
   (v): Linear(in_features=25, out_features=25, bias=False)
 )]

In [14]:
def topla(a,b,c):
  """Return ``a + b + c``, added left to right ("topla" is Turkish for "add")."""
  partial = a + b
  return partial + c

topla(1,2,3)

6

In [16]:
topla(1,2,3)

6