In [1]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt

--2025-02-26 17:43:46--  https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-02-26 17:43:46 (18.1 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
# load text
with open('input.txt') as file:
    text = file.read()

print(text[:500])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor


In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(42)

<torch._C.Generator at 0x7f1832ff6670>

In [4]:
# preprocess text
text = text.lower()

# build vocab
chars = sorted(list(set(text)))

In [35]:
batch_size = 64 # number of instances to process simultaneously
block_size = 256 # maximum context length allowed
n_embed = 384
n_heads = 6
head_size = n_embed // n_heads
n_layers = 6
dropout = 0.2
vocab_size = len(chars)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate = 2e-4

In [6]:
# define decoder & encoder
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for i, s in enumerate(chars)}
encode = lambda text: [stoi[s] for s in text]
decode = lambda toks: ''.join([itos[i] for i in toks])

In [7]:
import torch

# encode text
data = torch.tensor(encode(text), dtype=torch.long)
data.shape

torch.Size([1115394])

In [8]:
# split data
train_set_size = 0.9
n = int(train_set_size*len(data))
train_data = data[:n]
val_data = data[n:]

In [9]:
def get_batch(split):
    if split not in ('train', 'val'):
        raise Exception('split must be train or val')

    data = train_data if split == 'train' else val_data
    high = len(data) - block_size
    idxs = torch.randint(low=0, high=high, size=(batch_size,))
    x = torch.stack([data[idx:idx+block_size] for idx in idxs])
    y = torch.stack([data[idx+1:idx+block_size+1] for idx in idxs])

    return x.to(device), y.to(device)

In [10]:
# evaluate model performance
@torch.no_grad()
def estimate_loss(model, eval_iters=1_000):
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for i in range(eval_iters):
            xb, yb = get_batch(split)
            _, loss = model(xb, yb)
            losses[i] = loss.item()
        out[split] = losses.mean().item()
    model.train()
    return out

In [11]:
class SingleHeadAttention(nn.Module):
    def __init__(self, n_embed=n_embed, head_size=head_size, dropout=dropout) -> None:
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
        self.dropout = nn.Dropout()

    def forward(self, x):
        _, T, C = x.shape
        k = self.key(x) # (B,T,C)
        q = self.query(x) # (B,T,C)
        weights = q @ k.transpose(-2, -1) * C**-0.5 # (B,T,C) @ (B,C,T) -> (B,T,T)
        weights = weights.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B,T,T)
        weights = F.softmax(weights, dim=-1) # (B,T,T)
        weights = self.dropout(weights)
        v = self.value(x) # (B,T,C)
        out = weights @ v # (B,T,T) @ (B,T,C) -> (B,T,C)
        return out

In [12]:
class MultiHeadAttention(nn.Module):
    def __init__(self, n_heads=n_heads, head_size=head_size, dropout=dropout) -> None:
        super().__init__()
        self.heads = nn.ModuleList([
            SingleHeadAttention() for _ in range(n_heads)
        ])
        self.proj = nn.Linear(n_embed, n_embed)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = torch.cat([h(x) for h in self.heads], dim=-1)
        x = self.proj(x)
        out = self.dropout(x)
        return out

In [13]:
class FeedForwardLayer(nn.Module):
    def __init__(self, dropout=dropout) -> None:
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
            nn.Dropout(dropout),)

    def forward(self, x):
        return self.net(x)

In [14]:
class Block(nn.Module):
    def __init__(self, n_embed=n_embed, n_heads=n_heads):
        super().__init__()
        self.sa_heads = MultiHeadAttention(n_heads, head_size)
        self.feed_forward = FeedForwardLayer()
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)

    def forward(self, x):
        x = x + self.sa_heads(self.ln1(x))
        x = x + self.feed_forward(self.ln2(x))
        return x

In [15]:
class GPTLM(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.tok_embedding = nn.Embedding(vocab_size, n_embed)
        self.position_embedding = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block() for _ in range(n_layers)])
        self.layer_norm = nn.LayerNorm(n_embed)
        self.lm_head = nn.Linear(n_embed, vocab_size)

        # for stabilizing training
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, idx, targets=None):
        B, T = idx.shape
        tok_embed = self.tok_embedding(idx) # (B,T,C)
        pos_embed = self.position_embedding(
            torch.arange(T, device=device)) # (T,C)
        x = tok_embed + pos_embed # (B,T,C)
        x = self.blocks(x) # (B,T,C)
        x = self.layer_norm(x) # (B,T,C)
        logits = self.lm_head(x) # idx: (B, T), logits: (B, T, C)
        B, T, C = logits.shape

        if targets is None:
            return logits, None

        loss = F.cross_entropy(logits.view(B*T, C), targets.view(B*T))
        return logits, loss


    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :] # (B,C)
            probs = F.softmax(logits, dim=-1) # (B,C)
            idx_next = torch.multinomial(probs, num_samples=1) # (B,1)
            idx = torch.cat((idx, idx_next), dim=1) # (B,T+1)
        return idx

In [16]:
model = GPTLM().to(device)
xb, yb = get_batch('train')
logits, loss = model(xb, yb)

In [17]:
decode(model.generate(torch.zeros((1, 1), dtype=torch.long).to(device),
                   max_new_tokens=100)[0].tolist())

'\n-ekfxymgtcbpp::jg,x?mom,&be;tml?fv&aqmmcadua!pqsy\n dass:m:-pr\n!$wovr &nkg,f3sxz.u\nzuy;oq,e;rm-\nif-vo'

In [24]:
optim = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [34]:
epochs = 3_000
for e in range(epochs):
    xb, yb = get_batch('train')

    logits, loss = model(xb, yb)
    optim.zero_grad()
    loss.backward()
    optim.step()

    if e % 1_000 == 0:
        print(loss.item())

1.0421206951141357
1.0306305885314941
1.0093374252319336


In [36]:
estimate_loss(model,eval_iters=50)

{'train': 0.7893437147140503, 'val': 1.5771703720092773}

In [37]:
res = decode(model.generate(torch.zeros((1, 1), dtype=torch.long).to(device),
                   max_new_tokens=3000)[0].tolist())
print(res)


and i will in evertae dull awthire?
ay, is not at myself musit,--though revengeful
like me forceful shouldst do that any man's!

grumilio:
tush, and no poison wep not with the foul cause.

petruchio:
a saint coil, you shall be well builted
withal; to her it hath some that kill'd my aufidius.

gentleman:
beat you, to the watch, sit's a man droop
to looking the severerely of your voices!

clarence:
titus:
you veronation is modest for this kind of contrary,
and merited to romeo, you mistake me
a lottendary of the body.

both:
mopsa, it was; therefore hereforw or incords,
i came it to breather.

romeo:
why, that's well: go wink, he is an rosema: it
so. there's none can pity to the state,
and many better subjects, which they royalty
he is asking to our furreson stainly beams
and to heaven there still wherein the hearts
which poestern our souls, he'ld wear us.

paulina:
let me be reason'd;
for so that my knees how he is not king,
it cannot help but will for revenge it,
like one confound to 

In [38]:
with open("out.txt", "w") as file:
    file.write(res)

In [40]:
import torch

torch.save(model.state_dict(), "model_weights.pth")