In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Args:
        head_size: Output width of the key, query and value projections.
        n_embd: Size of the last dimension of the input.
        dropout: Dropout probability applied to the attention weights.
        block_size: Side length of the pre-computed causal mask, i.e. the
            longest sequence ``forward`` accepts. Defaults to 1024.
    """

    def __init__(self, head_size, n_embd, dropout, block_size=1024):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)
        # Lower-triangular matrix of ones; zeros mark the "future" positions.
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer(
            "tril", torch.tril(torch.ones(block_size, block_size))
        )

    def forward(self, x):
        """Apply masked self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, n_embd).

        Returns:
            Tensor of shape (B, T, head_size).

        Raises:
            ValueError: If T is larger than the pre-computed causal mask.
        """
        _, T, _ = x.shape
        max_len = self.tril.size(0)
        if T > max_len:
            raise ValueError(
                f"sequence length {T} exceeds the causal mask size {max_len}"
            )

        k = self.key(x)    # (B,T,head_size)
        q = self.query(x)  # (B,T,head_size)

        # Scaled dot-product scores. The scale is the width of the key
        # vectors (head_size), not the width of the input embedding.
        wei = q @ k.transpose(-2, -1) * k.size(-1) ** -0.5  # (B,T,T)

        # Mask the future tokens so position t only attends to positions <= t.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))

        # Softmax normalisation over the key axis, then dropout.
        wei = F.softmax(wei, dim=-1)  # (B,T,T)
        wei = self.dropout(wei)

        # Weighted aggregation of the values.
        v = self.value(x)  # (B,T,head_size)
        out = wei @ v      # (B,T,head_size)
        return out
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention, with all heads computed in one pass.

    A single key/query/value projection of width ``num_heads * head_size`` is
    split into ``num_heads`` slices that attend independently; the results are
    concatenated, projected back to ``n_embd`` and passed through dropout.
    Dropout is applied to the projected output only, not to the attention
    weights.

    Args:
        num_heads: Number of attention heads.
        head_size: Width of each head.
        n_embd: Size of the last dimension of the input and output.
        dropout: Dropout probability applied after the output projection.
    """

    def __init__(self, num_heads, head_size, n_embd, dropout):
        assert n_embd % num_heads == 0
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        inner_dim = head_size * num_heads
        self.key = nn.Linear(n_embd, inner_dim, bias=False)
        self.query = nn.Linear(n_embd, inner_dim, bias=False)
        self.value = nn.Linear(n_embd, inner_dim, bias=False)
        self.proj = nn.Linear(inner_dim, n_embd)
        self.dropout = nn.Dropout(dropout)

    def _split_heads(self, t):
        """Reshape (B, T, num_heads*head_size) to (B, num_heads, T, head_size)."""
        B, T, _ = t.shape
        return t.view(B, T, self.num_heads, self.head_size).transpose(1, 2)

    def forward(self, x):
        """Map ``x`` of shape (B, T, n_embd) to a tensor of the same shape."""
        keys = self.key(x)        # (B, T, num_heads*head_size)
        queries = self.query(x)   # (B, T, num_heads*head_size)
        values = self.value(x)    # (B, T, num_heads*head_size)
        B, T, _ = keys.shape

        keys = self._split_heads(keys)
        queries = self._split_heads(queries)
        values = self._split_heads(values)

        # Scaled dot-product scores for every head: (B, num_heads, T, T).
        scale = self.head_size ** -0.5
        scores = torch.matmul(queries, keys.transpose(-2, -1)) * scale

        # True above the diagonal, i.e. wherever a query would see the future.
        future = torch.ones(T, T, device=x.device).tril() == 0
        scores = scores.masked_fill(future, float('-inf'))
        weights = F.softmax(scores, dim=-1)

        # (B, num_heads, T, head_size) -> (B, T, num_heads*head_size)
        attended = torch.matmul(weights, values)
        merged = attended.transpose(1, 2).reshape(
            B, T, self.num_heads * self.head_size
        )

        return self.dropout(self.proj(merged))
class FeedForward(nn.Module):
    """Position-wise MLP: expand to 4x width, ReLU, project back, dropout.

    Args:
        n_embd: Size of the last dimension of the input and output.
        dropout: Dropout probability applied after the second linear layer.
    """

    def __init__(self, n_embd, dropout):
        super().__init__()
        hidden_dim = 4 * n_embd
        layers = [
            nn.Linear(n_embd, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, n_embd),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the MLP; the output has the same shape as ``x``."""
        out = self.net(x)
        return out
class Block(nn.Module):
    """Transformer block: self-attention, then a feed-forward network.

    Each sub-layer receives a layer-normalised copy of its input and its
    output is added back onto that input (residual connection).

    Args:
        n_embd: Size of the last dimension of the input and output.
        dropout: Dropout probability passed to both sub-layers.
        num_heads: Number of attention heads.
        head_size: Width of each attention head.
    """

    def __init__(self, n_embd, dropout, num_heads, head_size):
        super().__init__()
        # Attribute names are also the state_dict prefixes, so they stay as-is.
        self.MultiHeadAttention = MultiHeadAttention(
            num_heads, head_size, n_embd, dropout
        )
        self.FeedForward = FeedForward(n_embd, dropout)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        """Return ``x`` after the attention and feed-forward residual updates."""
        attended = self.MultiHeadAttention(self.ln1(x))
        x = x + attended
        transformed = self.FeedForward(self.ln2(x))
        return x + transformed
class SimpleLM(nn.Module):
    """Minimal language model with one transformer block.

    Token and learned position embeddings are summed, passed through a single
    ``Block``, layer-normalised and projected to vocabulary logits.

    Args:
        n_embd: Embedding width used throughout the model.
        dropout: Dropout probability passed to the transformer block.
        num_heads: Number of attention heads in the block.
        head_size: Width of each attention head.
        vocab_size: Number of distinct token ids; also the logit width.
        block_size: Number of rows in the position embedding table, i.e. the
            longest sequence ``forward`` accepts. Defaults to 1024.
    """

    def __init__(self, n_embd, dropout, num_heads, head_size, vocab_size,
                 block_size=1024):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # A single Block despite the plural name; the attribute name is kept
        # because it is the state_dict prefix of existing checkpoints.
        self.blocks = Block(n_embd, dropout, num_heads, head_size)
        self.ln_f = nn.LayerNorm(n_embd)
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx):
        """Compute next-token logits.

        Args:
            idx: Integer tensor of token ids with shape (B, T).

        Returns:
            Tensor of logits with shape (B, T, vocab_size).

        Raises:
            ValueError: If T exceeds the size of the position embedding table.
        """
        _, T = idx.shape
        max_len = self.position_embedding_table.num_embeddings
        if T > max_len:
            # Fail early with a readable message instead of an out-of-range
            # lookup inside the embedding layer.
            raise ValueError(
                f"sequence length {T} exceeds the maximum supported "
                f"length {max_len}"
            )

        positions = torch.arange(T, device=idx.device)
        token_embeddings = self.token_embedding_table(idx)  # (B,T,C)
        position_embeddings = self.position_embedding_table(positions)  # (T,C)
        x = token_embeddings + position_embeddings  # (B,T,C), broadcast over B
        x = self.blocks(x)  # (B,T,C)
        x = self.ln_f(x)  # (B,T,C)
        logits = self.lm_head(x)  # (B,T,vocab_size)
        return logits