In [4]:
import torch 
from torch import nn
import torch.nn.functional as F

import math

In [10]:
# Toy container grouping the three pieces used later in the notebook.
# Left as the cell's last expression so the module repr is displayed.
nn.ModuleDict({
    "wte": nn.Embedding(20, 50),  # token embedding table: 20 rows, 50 dims
    "wpe": nn.Embedding(2, 50),   # positional embedding table: 2 rows, 50 dims
    "drop": nn.Dropout(0.2),      # dropout with p=0.2
})

ModuleDict(
  (wte): Embedding(20, 50)
  (wpe): Embedding(2, 50)
  (drop): Dropout(p=0.2, inplace=False)
)

# Attention workflow
1. Input 
2. Word embeddings
3. Positional embeddings
4. Sum(Word, Positional) — the two embeddings are added element-wise, not concatenated
5. Normalization (optional)
6. Attention


In [117]:
import torch

# Example dimensions
B = 32  # Batch size
T = 1024  # Sequence length or context window
C = 128 # Embedding dimension (not used in this cell)

VOCAB_SIZE = 50276      # token ids are drawn from [0, VOCAB_SIZE)
NUM_RAW_TOKENS = 50276  # length of the flat random token stream
SEED = 0                # fixed seed so the synthetic batch is reproducible

# Generate a flat stream of random token ids. A dedicated generator makes the
# draw reproducible without touching the global RNG state.
token_rng = torch.Generator().manual_seed(SEED)
x = torch.randint(0, VOCAB_SIZE, (NUM_RAW_TOKENS,), generator=token_rng)
print(x.shape)

# Keep exactly B * T tokens so the stream reshapes cleanly to (B, T).
# This only truncates (no padding is done), so fail early with a clear message
# if the stream is too short instead of hitting an opaque error in `view`.
num_tokens = B * T
assert x.numel() >= num_tokens, (
    f"need at least {num_tokens} tokens to fill a ({B}, {T}) batch, got {x.numel()}"
)
x = x[:num_tokens]

print(x.shape)
# Reshape to (B, T)
x = x.view(B, T)

x.shape


torch.Size([50276])
torch.Size([32768])


torch.Size([32, 1024])

In [118]:
# x = torch.randint(1, 50276, (1, 1, 50276))
#token embeddings
def txt_emb(vocab_size=50276, n_emb=768):
    """Build a fresh token-embedding table.

    Args:
        vocab_size: number of rows in the table (one per token id).
        n_emb: width of each embedding vector.

    Returns:
        A newly initialised ``nn.Embedding(vocab_size, n_emb)``. Every call
        creates a new module with its own weights.
    """
    return nn.Embedding(vocab_size, n_emb)

#positional embeddings
def pos_emb(block_size=1024, n_embd=768):
    """Build a fresh positional-embedding table.

    Args:
        block_size: number of rows in the table (one per position index).
        n_embd: width of each embedding vector.

    Returns:
        A newly initialised ``nn.Embedding(block_size, n_embd)``. Every call
        creates a new module with its own weights.
    """
    return nn.Embedding(block_size, n_embd)

# Token embeddings: one vector per token id in the batch.
# NOTE: this rebinds `T`, which the previous cell used for the sequence length.
T = txt_emb()(x)

# Positional embeddings: one vector per time step 0..t-1.
b, t = x.shape
pos = torch.arange(t)
P = pos_emb()(pos)

In [123]:
# Display the shapes of the token and positional embeddings side by side.
T.shape, P.shape

(torch.Size([32, 1024, 768]), torch.Size([1024, 768]))

In [127]:
# Despite the name, this is an element-wise sum, not a concatenation:
# P (1024, 768) broadcasts across the batch dimension of T (32, 1024, 768).
concat_x = T + P

In [142]:
#attention
class CausalSelfAttention(nn.Module):
    """Multi-head scaled dot-product self-attention with a causal mask.

    Each position attends only to itself and to earlier positions in the
    sequence. Queries, keys and values come from one fused linear layer, are
    split into ``n_head`` heads, and the per-head results are merged and passed
    through an output projection.

    Args:
        n_embd: embedding width of the input and output; must be divisible by
            ``n_head``.
        n_head: number of attention heads.
        dropout: dropout probability applied to the attention weights and to
            the projected output.
        block_size: stored as ``self.block_size``. The causal mask is built
            from the actual sequence length on every call, so this value does
            not limit the input length.

    Raises:
        AssertionError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, n_embd: int=768,
                 n_head: int=12,
                 dropout: float=0.1,
                 block_size: int=1024):
        super().__init__()
        self.n_embd = n_embd
        self.n_head = n_head
        self.dropout = dropout
        self.block_size = block_size
        assert self.n_embd % self.n_head == 0, "n_embd must be divisible by n_head"

        # fused query/key/value projection
        self.lin = nn.Linear(self.n_embd, 3 * self.n_embd, bias=True)

        # output projection
        self.c_proj = nn.Linear(self.n_embd, self.n_embd, bias=True)

        # regularization
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

    def forward(self, x):
        """Apply causal multi-head self-attention.

        Args:
            x: tensor of shape ``(B, T, C)`` where ``C`` equals ``n_embd``.

        Returns:
            Tensor of shape ``(B, T, C)``. The output at position ``t`` depends
            only on inputs at positions ``<= t``.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        q, k, v = self.lin(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

        # scaled dot-product scores: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(head_size))

        # Causal mask: row i may only look at columns j <= i. Masked scores are
        # set to -inf so they receive exactly zero weight after the softmax.
        # The diagonal is always kept, so no row is fully masked.
        causal = torch.ones(T, T, dtype=torch.bool, device=x.device).tril()
        att = att.masked_fill(~causal, float("-inf"))

        att = F.softmax(att, dim=-1)
        att = self.attn_dropout(att)
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output
        y = self.resid_dropout(self.c_proj(y))
        return y


# Keep the layer and its output under separate names so the module is not
# overwritten by the tensor it produces and the call can be repeated.
attn_layer = CausalSelfAttention()
attn = attn_layer(concat_x)

In [143]:
# Remove the `attn` name, which the previous cell left bound to the attention output.
del attn