In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class LayerNorm(nn.Module):
    """Layer normalization whose additive bias term can be switched off.

    Args:
        ndim: Passed to ``torch.ones`` / ``torch.zeros`` to size the learnable
            scale (and bias); its shape is also the ``normalized_shape``
            handed to ``F.layer_norm``.
        bias: When truthy, a zero-initialized learnable bias is created;
            otherwise ``self.bias`` is ``None`` and no shift is applied.
        eps: Forwarded unchanged to ``F.layer_norm`` as its ``eps`` argument.
    """

    def __init__(self, ndim, bias, eps=1e-5):
        super().__init__()
        self.eps = eps
        # Scale starts at one so the freshly built layer applies no rescaling.
        self.weight = nn.Parameter(torch.ones(ndim))
        if bias:
            self.bias = nn.Parameter(torch.zeros(ndim))
        else:
            self.bias = None

    def forward(self, x):
        """Apply ``F.layer_norm`` to ``x`` using this module's scale/bias/eps."""
        return F.layer_norm(
            x,
            normalized_shape=self.weight.shape,
            weight=self.weight,
            bias=self.bias,
            eps=self.eps,
        )

In [4]:
class CausalSelfAttention(nn.Module):
    """Multi-head self-attention in which each position only attends to
    itself and earlier positions.

    Args:
        n_embd: Width of the input/output features; must be divisible by
            ``n_head``.
        n_head: Number of attention heads.
        block_size: Side length of the precomputed lower-triangular mask,
            i.e. the longest sequence ``forward`` accepts.
        dropout: Dropout probability used both on the attention weights and
            on the projected output.
        bias: Whether the two ``nn.Linear`` layers carry a bias term.

    Raises:
        ValueError: If ``n_embd`` is not a multiple of ``n_head``.
    """

    def __init__(self, n_embd, n_head, block_size, dropout, bias):
        super().__init__()
        # Without this check head_size is silently floored and the failure
        # only surfaces later as an opaque reshape error inside forward().
        if n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by n_head ({n_head})"
            )

        self.n_embd = n_embd
        self.n_head = n_head
        self.head_size = n_embd // n_head
        self.block_size = block_size

        # Single projection producing query, key and value concatenated.
        self.c_attn = nn.Linear(n_embd, 3 * n_embd, bias = bias)

        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(n_embd, n_embd, bias = bias)

        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        # Lower-triangular ones, shaped (1, 1, block_size, block_size) so it
        # broadcasts over the batch and head dimensions. Kept as a regular
        # (persistent) buffer under the same name so state_dict keys do not change.
        self.register_buffer('bias_mask', torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        """Run causal self-attention over ``x``.

        Args:
            x: 3-D tensor unpacked as ``(B, T, C)``; ``C`` has to match
                ``n_embd`` for the input projection to apply.

        Returns:
            Tensor of shape ``(B, T, C)``.

        Raises:
            ValueError: If ``T`` exceeds ``block_size``.
        """
        B, T, C = x.size()

        # Slicing the mask with T > block_size clamps silently; the result
        # either fails to broadcast or (block_size == 1) broadcasts an
        # all-ones mask that disables causality. Fail loudly instead.
        if T > self.block_size:
            raise ValueError(
                f"sequence length {T} exceeds block_size {self.block_size}"
            )

        q, k, v = self.c_attn(x).split(self.n_embd, dim = 2)

        # (B, T, C) -> (B, n_head, T, head_size)
        k = k.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        q = q.view(B, T, self.n_head, self.head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, self.head_size).transpose(1, 2)

        # Scaled dot-product scores, shape (B, n_head, T, T).
        att = (q @ k.transpose(-2, -1)) * (1.0 / (self.head_size ** 0.5))
        # Entries above the diagonal (mask == 0) become -inf so that softmax
        # assigns them zero weight.
        att = att.masked_fill(self.bias_mask[:, :, :T, :T] == 0, float('-inf'))
        att = F.softmax(att, dim = -1)
        att = self.attn_dropout(att)

        y = att @ v
        # Merge heads: (B, n_head, T, head_size) -> (B, T, C)
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.resid_dropout(self.c_proj(y))

        return y