In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from dataclasses import dataclass

# Seed PyTorch's global random number generator so that random draws made in
# later cells are repeatable from one run of the notebook to the next.
torch.manual_seed(1024)

<torch._C.Generator at 0x2082300f870>

In [None]:
@dataclass
class CPTConfig:
    """Hyper-parameters shared by the model components in this notebook.

    ``hidden_dim`` and ``head_size`` have defaults derived from the other
    defaults (``n_embed`` and ``n_embed // n_head``).  Those expressions are
    evaluated once, while the class body executes, so overriding ``n_embed``
    or ``n_head`` at construction time does NOT update them -- pass
    ``hidden_dim`` / ``head_size`` explicitly in that case.
    """

    # Maximum sequence length; also the side length of the causal mask.
    block_size: int = 512
    batch_size: int = 12
    n_layer: int = 12
    n_head: int = 12
    # Embedding width, used as the hidden size as well.
    # NOTE(review): the original comment mentions tying embedding weights --
    # presumably with the output projection; confirm against the model code.
    n_embed: int = 768
    hidden_dim: int = n_embed
    dropout: float = 0.1
    # Width of a single attention head (768 // 12 == 64 with the defaults).
    head_size: int = n_embed // n_head
    # NOTE(review): GPT-2's BPE vocabulary has 50257 entries; confirm that
    # 50207 is intentional and matches the tokenizer in use.
    vocab_size: int = 50207


In [None]:
class SingleHeadAttention(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        config: object exposing ``hidden_dim`` (input feature width),
            ``head_size`` (width of the query/key/value projections),
            ``block_size`` (longest sequence the causal mask supports) and
            ``dropout`` (dropout probability applied to attention weights).
    """

    def __init__(self, config):
        super().__init__()
        # Kept on the module because forward() needs it for the score scaling.
        self.head_size = config.head_size
        self.key = nn.Linear(config.hidden_dim, config.head_size)
        self.value = nn.Linear(config.hidden_dim, config.head_size)
        self.query = nn.Linear(config.hidden_dim, config.head_size)

        # Lower-triangular matrix of ones: position i may attend to j <= i.
        # Registered as a buffer so it follows the module across devices
        # without being treated as a trainable parameter.
        self.register_buffer(
            "attention_mask",
            torch.tril(
                torch.ones(config.block_size, config.block_size)
            ),
        )
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        Args:
            x: tensor whose last two dimensions are ``(seq_len, hidden_dim)``,
                typically ``(batch, seq_len, hidden_dim)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``head_size``.

        Raises:
            ValueError: if ``seq_len`` exceeds the ``block_size`` the causal
                mask was built for.
        """
        seq_len = x.size(-2)
        max_len = self.attention_mask.size(0)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {max_len}"
            )

        k = self.key(x)
        q = self.query(x)
        v = self.value(x)

        # Scale the raw scores BEFORE the softmax; dividing afterwards would
        # leave rows that no longer sum to one.
        scale = self.head_size ** -0.5
        weight = (q @ k.transpose(-2, -1)) * scale

        # Hide future positions so they receive zero probability.
        weight = weight.masked_fill(
            self.attention_mask[:seq_len, :seq_len] == 0,
            float('-inf')
        )
        weight = F.softmax(weight, dim=-1)

        weight = self.dropout(weight)
        output = weight @ v
        return output
