In [1]:
import math
from dataclasses import dataclass
import torch
import torch.nn as nn
from torch.nn import functional as F

In [None]:
class GPT(nn.Module):
    """Building blocks of a small GPT: causal self-attention, an MLP and a LayerNorm.

    The blocks are exposed as separate methods (``mlp_forward``,
    ``layernorm_forward``, ``self_attention``); this class does not combine
    them into a full forward pass.

    Args:
        context_length: maximum sequence length T; sizes the causal mask buffer.
        vocab_size: vocabulary size (stored; not used by the methods here).
        n_layers: number of transformer layers (stored; not used by the methods here).
        n_heads: number of attention heads; must be >= 1 and divide ``n_embd``.
        n_embd: embedding / channel dimension C.

    Raises:
        ValueError: if ``n_heads < 1`` or ``n_embd`` is not divisible by ``n_heads``.
    """

    def __init__(self, *, context_length=64, vocab_size=1000, n_layers=1, n_heads=1, n_embd=64):
        super().__init__()
        # Validate before deriving head_dim so a bad config fails with a clear message
        # (an `assert` would also be stripped under `python -O`).
        if n_heads < 1:
            raise ValueError("n_heads must be >= 1")
        if n_embd % n_heads != 0:
            raise ValueError("n_embd deve ser divisível por n_heads")

        self.context_length = context_length
        self.vocab_size = vocab_size
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.n_embd = n_embd
        self.head_dim = n_embd // n_heads  # channels per attention head

        # Submodules are created in the same order as before so that seeded
        # parameter initialization is unchanged.
        # MLP: expand to 2 * n_embd, apply GELU, project back to n_embd.
        self.fc1 = nn.Linear(n_embd, 2 * n_embd)
        self.gelu = nn.GELU(approximate='tanh')  # tanh-approximated GELU, as used in GPT
        self.fc2 = nn.Linear(2 * n_embd, n_embd)

        self.ln = nn.LayerNorm(n_embd)

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd)  # joint projection to Q, K, V
        self.out_proj = nn.Linear(n_embd, n_embd)      # final output projection

        # Lower-triangular ones of shape (1, 1, L, L). Zeros mark "future" positions
        # that a query may not attend to. Shape is kept as before for state_dict compatibility.
        self.register_buffer(
            "mask",
            torch.tril(torch.ones(context_length, context_length))
                .view(1, 1, context_length, context_length)
        )

    def mlp_forward(self, x):
        """Position-wise MLP: Linear(C -> 2C) -> GELU -> Linear(2C -> C).

        Args:
            x: tensor whose last dimension is ``n_embd``, e.g. (T, C).

        Returns:
            Tensor with the same shape as ``x``.
        """
        hidden = self.gelu(self.fc1(x))
        return self.fc2(hidden)

    def layernorm_forward(self, x):
        """Apply LayerNorm over the last (channel) dimension; shape is preserved."""
        return self.ln(x)

    def self_attention(self, x):
        """Causal multi-head scaled dot-product self-attention.

        Args:
            x: tensor of shape (T, C) or (..., T, C), with C == n_embd and
               T <= context_length.

        Returns:
            Tensor with the same shape as ``x``.

        Raises:
            ValueError: if T exceeds ``context_length`` (the causal mask is too small).
        """
        *lead, T, C = x.shape
        if T > self.context_length:
            raise ValueError(
                f"sequence length {T} exceeds context_length {self.context_length}"
            )

        # Project to Q, K, V in one matmul, then split along channels.
        q, k, v = self.qkv_proj(x).chunk(3, dim=-1)  # each (..., T, C)

        # Split channels into heads: (..., T, C) -> (..., T, nh, hd) -> (..., nh, T, hd).
        def split_heads(t):
            return t.view(*lead, T, self.n_heads, self.head_dim).transpose(-3, -2)

        q, k, v = split_heads(q), split_heads(k), split_heads(v)

        # Scaled scores per head: (..., nh, T, T).
        att = (q @ k.transpose(-2, -1)) / math.sqrt(self.head_dim)

        # Causal mask as a 2-D (T, T) tensor. A 4-D slice would broadcast `att`
        # to an extra leading dim and break the head merge below.
        future = self.mask[0, 0, :T, :T] == 0
        att = att.masked_fill(future, float('-inf'))
        att = torch.softmax(att, dim=-1)

        y = att @ v  # (..., nh, T, hd)

        # Merge heads back: (..., nh, T, hd) -> (..., T, nh, hd) -> (..., T, C).
        y = y.transpose(-3, -2).contiguous().view(*lead, T, C)

        return self.out_proj(y)
