In [8]:
# Unlabled Raw Text

text = """
  your journey starts with one step
  your journeu starts today
  one step can change everything
"""

In [9]:
import torch

def build_dataset(text, block_size):
    tokens = [stoi[w] for w in text.split()]
    X, Y = [], []

    for i in range(len(tokens) - block_size):
        X.append(tokens[i:i+block_size])
        Y.append(tokens[i+1:i+block_size+1])

    return torch.tensor(X), torch.tensor(Y)


In [10]:
import torch.nn as nn
import torch.nn.functional as F


In [11]:
class GPTEmbedding(nn.Module):
    def __init__(self, vocab_size, d_model, block_size):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(block_size, d_model)

    def forward(self, x):
        B, T = x.shape
        pos = torch.arange(T)
        return self.token_emb(x) + self.pos_emb(pos)


In [12]:
class FeedForward(nn.Module):
    def __init__(self, d_model):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(d_model, 4*d_model),
            nn.GELU(),
            nn.Linear(4*d_model, d_model)
        )

    def forward(self, x):
        return self.net(x)


In [13]:
class MultiHeadSelfAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_head = d_model // num_heads

        self.qkv = nn.Linear(d_model, 3*d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x):
        B, T, C = x.shape

        qkv = self.qkv(x)
        Q, K, V = qkv.chunk(3, dim=-1)

        Q = Q.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        K = K.view(B, T, self.num_heads, self.d_head).transpose(1, 2)
        V = V.view(B, T, self.num_heads, self.d_head).transpose(1, 2)

        scores = (Q @ K.transpose(-2, -1)) / (self.d_head ** 0.5)

        mask = torch.triu(torch.ones(T, T), diagonal=1)
        scores = scores.masked_fill(mask == 1, float('-inf'))

        attn = F.softmax(scores, dim=-1)
        out = attn @ V

        out = out.transpose(1, 2).contiguous().view(B, T, C)
        return self.out(out)


In [14]:
class GPTBlock(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = MultiHeadSelfAttention(d_model, num_heads)
        self.ln2 = nn.LayerNorm(d_model)
        self.mlp = FeedForward(d_model)

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x


In [15]:
class MiniGPT(nn.Module):
    def __init__(self, vocab_size, d_model, block_size, num_heads, num_layers):
        super().__init__()
        self.embed = GPTEmbedding(vocab_size, d_model, block_size)
        self.blocks = nn.Sequential(
            *[GPTBlock(d_model, num_heads) for _ in range(num_layers)]
        )
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size)

    def forward(self, x, targets=None):
        x = self.embed(x)
        x = self.blocks(x)
        x = self.ln_f(x)
        logits = self.lm_head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, vocab_size),
                targets.view(-1)
            )

        return logits, loss


In [18]:
words = text.strip().split()
vocab = sorted(list(set(words)))
stoi = {s:i for i,s in enumerate(vocab)}
itos = {i:s for i,s in enumerate(vocab)}
vocab_size = len(vocab)
block_size = 8 # A common choice for block_size

model = MiniGPT(
    vocab_size=vocab_size,
    d_model=64,
    block_size=block_size,
    num_heads=4,
    num_layers=2
)

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)


In [20]:
X, Y = build_dataset(text, block_size)

for step in range(500):
    optimizer.zero_grad()

    logits, loss = model(X, Y)
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"step {step} | loss {loss.item():.4f}")

step 0 | loss 2.5572
step 100 | loss 0.1258
step 200 | loss 0.0561
step 300 | loss 0.0428
step 400 | loss 0.0369
