In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import random

In [2]:
text = """
hello world
this is a tiny character-level language model
hello there
""".lower()

In [3]:
chars = sorted(list(set(text)))
vocab_size = len(chars)

# Mapping char -> int
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for ch,i in stoi.items()}

def encode(s):
    return [stoi[c] for c in s]

def decode(ids):
    return ''.join([itos[i] for i in ids])


In [4]:
print(chars)
print(encode("hello"))
print(decode(encode("hello")))


['\n', ' ', '-', 'a', 'c', 'd', 'e', 'g', 'h', 'i', 'l', 'm', 'n', 'o', 'r', 's', 't', 'u', 'v', 'w', 'y']
[8, 6, 10, 10, 13]
hello


In [5]:
import torch

block_size = 16   # hur lång kontext modellen får se
data = torch.tensor(encode(text), dtype=torch.long)

def get_batch(batch_size=32):
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SelfAttention(nn.Module):
    def __init__(self, embed_dim):
        super().__init__()
        self.key = nn.Linear(embed_dim, embed_dim, bias=False)
        self.query = nn.Linear(embed_dim, embed_dim, bias=False)
        self.value = nn.Linear(embed_dim, embed_dim, bias=False)
        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B, T, C = x.shape
        
        K = self.key(x)     # (B, T, C)
        Q = self.query(x)   # (B, T, C)
        V = self.value(x)   # (B, T, C)

        # Attention weights
        att = Q @ K.transpose(-2, -1) / (C**0.5)   # (B, T, T)
        att = att.masked_fill(self.mask[:T,:T] == 0, float('-inf'))
        att = F.softmax(att, dim=-1)

        out = att @ V  # (B, T, C)
        return out

class TinyTransformer(nn.Module):
    def __init__(self, vocab_size, embed_dim=64):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, embed_dim)
        self.pos_embed = nn.Embedding(block_size, embed_dim)

        self.attn = SelfAttention(embed_dim)
        self.ln1 = nn.LayerNorm(embed_dim)
        
        self.ff = nn.Sequential(
            nn.Linear(embed_dim, embed_dim*4),
            nn.ReLU(),
            nn.Linear(embed_dim*4, embed_dim)
        )
        
        self.ln2 = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_emb = self.embed(idx)
        pos_emb = self.pos_embed(torch.arange(T))
        x = token_emb + pos_emb

        # Block
        x = x + self.attn(self.ln1(x))
        x = x + self.ff(self.ln2(x))

        logits = self.head(x)

        if targets is None:
            return logits
        
        # compute cross entropy loss
        B, T, C = logits.shape
        logits = logits.view(B*T, C)
        targets = targets.view(B*T)
        loss = F.cross_entropy(logits, targets)

        return logits, loss

model = TinyTransformer(vocab_size)


In [9]:
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

for step in range(2000):
    xb, yb = get_batch()
    logits, loss = model(xb, yb)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    if step % 200 == 0:
        print(f"step {step}, loss {loss.item():.4f}")


step 0, loss 3.4652
step 200, loss 0.1098
step 400, loss 0.0961
step 600, loss 0.0933
step 800, loss 0.0541
step 1000, loss 0.0862
step 1200, loss 0.0692
step 1400, loss 0.0775
step 1600, loss 0.0987
step 1800, loss 0.0762


In [10]:
def generate(model, start_text="h", max_new_tokens=100):
    model.eval()
    idx = torch.tensor([encode(start_text)], dtype=torch.long)

    for _ in range(max_new_tokens):
        logits = model(idx[:, -block_size:])
        logits = logits[:, -1, :]
        probs = F.softmax(logits, dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        idx = torch.cat([idx, next_id], dim=1)

    return decode(idx[0].tolist())

print(generate(model, "h"))


haracter-level language model
hello thereld
this is a tiny character-level language model
hello there


In [11]:
print(generate(model, "hell"))

hello world
this is a tiny character-level language model
hello therelis a tiny character-level language


In [12]:
print(generate(model, "this"))

this is a tiny character-level language model
hellllo theracter-level language model
hello therelelo tha
