In [1]:
import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as Func

In [2]:
!wget https://github.com/karpathy/char-rnn/raw/refs/heads/master/data/tinyshakespeare/input.txt

--2025-05-24 16:08:00--  https://github.com/karpathy/char-rnn/raw/refs/heads/master/data/tinyshakespeare/input.txt
Resolving github.com (github.com)... 20.205.243.166
Connecting to github.com (github.com)|20.205.243.166|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt [following]
--2025-05-24 16:08:01--  https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.3’


2025-05-24 16:08:01 (123 MB/s) - ‘input.txt.3’ saved [1115394/1115394]



In [3]:
with open("input.txt", encoding="utf-8") as f:
    text = f.read()

In [4]:
print(len(text))

1115394


In [5]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [6]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
c2i = {ch: i for i, ch in enumerate(chars)}
i2c = {i: ch for i, ch in enumerate(chars)}
encode = lambda s: [c2i[c] for c in s]
decode = lambda l: "".join([i2c[i] for i in l])

In [7]:
encode("VASYA\n!"), decode(encode("VASYA\n!"))

([34, 13, 31, 37, 13, 0, 2], 'VASYA\n!')

In [8]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype, data[:100])

torch.Size([1115394]) torch.int64 tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59])


In [9]:
train = data[: int(0.8 * len(data))]
valid = data[int(0.8 * len(data)) :]

In [10]:
block_size = 256
batch_size = 64


def get_batch(split):
    data_out = train if split == "train" else valid
    index = torch.randint(len(data_out) - block_size, (batch_size,))
    x = torch.stack([data_out[i : i + block_size] for i in index])
    y = torch.stack([data_out[i + 1 : i + block_size + 1] for i in index])
    x, y = x.to(device), y.to(device)
    return x, y

In [11]:
# x: Batch, Time, Features
# k: Batch, Time, HeadFeatures

In [12]:
n_embd = 384
dropout = 0.2


class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        B, T, F = x.shape
        tril = torch.tril(torch.ones(block_size, block_size, device=x.device))
        k = self.key(x)
        q = self.query(x)
        attention_weight = q @ k.transpose(-2, -1) / (k.shape[-1] ** 0.5)
        attention_weight = attention_weight.masked_fill(
            tril[:T, :T] == 0, float("-inf")
        )
        attention_weight = Func.softmax(attention_weight, dim=-1)
        attention_weight = self.dropout(attention_weight)
        v = self.value(x)
        return attention_weight @ v

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for i in range(num_heads)])
        self.proj = nn.Linear(head_size * num_heads, n_embd)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.dropout(self.proj(out))
        return out

In [14]:
class FeedForward(nn.Module):
    def __init__(self, n_embd):
        super().__init__()
        self.feed = nn.Sequential(
            nn.Linear(n_embd, 4 * n_embd),
            nn.ReLU(),
            nn.Linear(4 * n_embd, n_embd),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        return self.feed(x)

In [15]:
class Block(nn.Module):
    def __init__(self, n_embd, n_head):
        super().__init__()
        head_size = n_embd // n_head
        self.multihead = MultiHeadAttention(n_head, head_size)
        self.feed = FeedForward(n_embd)
        self.ln1 = nn.LayerNorm(n_embd)
        self.ln2 = nn.LayerNorm(n_embd)

    def forward(self, x):
        x = self.ln1(x + self.multihead(x))
        x = self.ln2(x + self.feed(x))
        return x

In [16]:
block_count = 6
n_head = 6


class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.token_embedding = nn.Embedding(vocab_size, n_embd)
        self.position_embedding = nn.Embedding(block_size, n_embd)
        self.blocks = nn.Sequential(
            *[Block(n_embd, n_head) for i in range(block_count)]
        )
        self.out = nn.Linear(n_embd, vocab_size)

    def forward(self, x):
        B, T = x.shape
        token_embd = self.token_embedding(x)  # BTF
        position_embd = self.position_embedding(torch.arange(T, device=device))
        x = token_embd + position_embd
        x = self.blocks(x)
        x = self.out(x)
        return x

In [17]:
device = "cuda"

model = Decoder()
model.to(device)

Decoder(
  (token_embedding): Embedding(65, 384)
  (position_embedding): Embedding(256, 384)
  (blocks): Sequential(
    (0): Block(
      (multihead): MultiHeadAttention(
        (heads): ModuleList(
          (0-5): 6 x Head(
            (key): Linear(in_features=384, out_features=64, bias=False)
            (query): Linear(in_features=384, out_features=64, bias=False)
            (value): Linear(in_features=384, out_features=64, bias=False)
            (dropout): Dropout(p=0.2, inplace=False)
          )
        )
        (proj): Linear(in_features=384, out_features=384, bias=True)
        (dropout): Dropout(p=0.2, inplace=False)
      )
      (feed): FeedForward(
        (feed): Sequential(
          (0): Linear(in_features=384, out_features=1536, bias=True)
          (1): ReLU()
          (2): Linear(in_features=1536, out_features=384, bias=True)
          (3): Dropout(p=0.2, inplace=False)
        )
      )
      (ln1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
      

In [18]:
sum(p.numel() for p in model.parameters()) / 1e6

10.788161

In [19]:
from tqdm import tqdm

optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, fused=True)
max_iters = 2500
eval = 500
eval_dur = 50


def calc_loss(logits, y):
    B, T, F = logits.shape
    logits = logits.view(B * T, F)
    target = y.view(B * T)
    loss = Func.cross_entropy(logits, target)
    return loss


for i in tqdm(range(max_iters)):
    if i % eval == 0 or i == max_iters - 1:
        with torch.no_grad():
            model.eval()
            losses = []
            for split in ["train", "eval"]:
                for k in range(eval_dur):
                    x, y = get_batch(split)
                    logits = model(x)
                    loss = calc_loss(logits, y)
                    losses.append(loss.detach().item())
                print(f"Step {i} split {split}: {np.mean(losses)}")
            model.train()
    x, y = get_batch("train")
    logits = model(x)
    loss = calc_loss(logits, y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)

  0%|          | 0/2500 [00:00<?, ?it/s]

Step 0 split train: 4.3469745063781735
Step 0 split eval: 4.344580020904541


 20%|██        | 500/2500 [04:01<15:00,  2.22it/s]

Step 500 split train: 1.826581163406372
Step 500 split eval: 1.8947253406047821


 40%|████      | 1000/2500 [08:02<11:16,  2.22it/s]

Step 1000 split train: 1.484866075515747


 40%|████      | 1001/2500 [08:19<2:10:59,  5.24s/it]

Step 1000 split eval: 1.6093757796287536


 60%|██████    | 1500/2500 [12:03<07:28,  2.23it/s]

Step 1500 split train: 1.3401396203041076


 60%|██████    | 1501/2500 [12:20<1:26:25,  5.19s/it]

Step 1500 split eval: 1.497417219877243


 80%|████████  | 2000/2500 [16:04<03:44,  2.23it/s]

Step 2000 split train: 1.2550005960464476


 80%|████████  | 2001/2500 [16:20<42:57,  5.17s/it]

Step 2000 split eval: 1.4388985979557036


100%|█████████▉| 2499/2500 [20:03<00:00,  2.23it/s]

Step 2499 split train: 1.1949415707588196


100%|██████████| 2500/2500 [20:19<00:00,  2.05it/s]

Step 2499 split eval: 1.403828316926956





In [20]:
def generate(primer, max_generation):
    for i in range(max_generation):
        primer_input = primer[:, -block_size:]
        logits = model(primer_input)
        logits = logits[:, -1, :]
        probs = Func.softmax(logits, dim=-1)
        next_int = torch.multinomial(probs, num_samples=1)
        primer = torch.cat([primer, next_int], dim=-1)
    return primer


context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(generate(context, max_generation=1000)[0].tolist()))



CORIOLANUS:
That od child, in brief, to the trumpest, to
hath the crown shore beneful out sucorn in lips.
There lawful were yetful love, I at
raps, and fell oven to thee show my sove--we did
sprithence fit bothes.

FRIAR LAURENCE:
That even with cames med our blood again in a
secute, instow her mercy ago.

QUEEN MARGARET:
My Lord meo's fairl, uncle is despair with me?
Did not comfort, thou hence''st no mattering alone
Let my father withding with my sweet
'Tyieldst unto unfit
Against onest out vhooks that braiten'd of the willly
Bidined frown us his truththwhath have not hailt think own
As I repent the mark of Tyral: behildly, thou art
To rose it the budlower in thy hands, purchast
Affetch's princes and say, protesed and him him
Which sidiety ofn to hurst by hisbance hath be pordician.

DUKE OF YORK:
Here do fact unto that husband of him.

HASTINGS:
Sir, a prospersorditation to do do look.
Is it thou hast thou with cheerfect I'll kill him?
3 lood King HENRY VI

EDWARD IV:
Thy saint te