In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [121]:
import re
text = """
Once upon a time there was a small cat.
The cat loved to chase butterflies.
Sometimes it would sleep under the sun.
"""

# simple tokenizer (split into words)
tokens = re.findall(r"\b\w+\b", text.lower())
vocab = sorted(set(tokens))
vocab_size = len(vocab)
stoi = {w: i for i, w in enumerate(vocab)}
itos = {i: w for w, i in stoi.items()}

data = torch.tensor([stoi[w] for w in tokens], dtype=torch.long)
print("Vocab size:", vocab_size, "| Data length:", len(data))


Vocab size: 19 | Data length: 22


In [127]:
block_size = 8  # context length
def get_batch():
    i = torch.randint(0, len(data) - block_size - 1, (1,))
    x = data[i : i + block_size].unsqueeze(0)  # (1, block_size)
    y = data[i + 1 : i + block_size + 1].unsqueeze(0)
    return x, y

x,y=get_batch()
print(x,y)

tensor([[16,  0, 13, 12, 17,  0,  8,  2]]) tensor([[ 0, 13, 12, 17,  0,  8,  2, 11]])


In [128]:
class TinyGPT(nn.Module):
    def __init__(self, vocab_size, embed_dim=64, n_heads=4, n_layers=2):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=embed_dim, nhead=n_heads, batch_first=True
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=n_layers)
        self.lm_head = nn.Linear(embed_dim, vocab_size)

    def forward(self, input_ids, labels=None):
        x = self.embed(input_ids)
        # causal mask (prevent seeing future tokens)
        seq_len = x.size(1)
        mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool().to(x.device)
        x = self.transformer(x, mask)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))
        return logits, loss


In [140]:
model = TinyGPT(100)
model(x,y)[0].view(-1,100).shape

torch.Size([8, 100])

In [141]:
y.shape

torch.Size([1, 8])