In [21]:
import numpy as np
import pandas as pd
import torch

import torch
import torch.nn as nn
import torch.nn.functional as F

import tiktoken

In [2]:
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2025-12-19 01:01:08--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8000::154, 2606:50c0:8001::154, 2606:50c0:8003::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8000::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.1’


2025-12-19 01:01:09 (13.5 MB/s) - ‘input.txt.1’ saved [1115394/1115394]



In [22]:
# Read the whole corpus into memory as a single string.
with open('input.txt', mode='r', encoding='utf8') as corpus_file:
    text = corpus_file.read()

In [23]:
# Size of the corpus in characters.
n_chars = len(text)
print(n_chars)

1115394


In [24]:
# Character-level vocabulary: every distinct character of the corpus, in sorted order.
unique_chars = set(text)
chars = sorted(unique_chars)
vocab_size = len(chars)

In [25]:
#We're doing character level token encoding - no semantic meaning is captured in this
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

In [26]:
def encode(str):
    """Map each character of ``str`` to its integer id using the ``stoi`` table.

    Raises ``KeyError`` for a character that is not in ``stoi``.
    """
    ids = []
    for ch in str:
        ids.append(stoi[ch])
    return ids
def decode(vec):
    """Map integer ids back to characters via ``itos`` and join them into one string."""
    pieces = (itos[i] for i in vec)
    return ''.join(pieces)

In [27]:
# Byte Pair Encoding: tokenize the full corpus with the GPT-2 encoding from tiktoken
enc = tiktoken.get_encoding("gpt2")
tokens = enc.encode(text)
# 64-bit integer ids (torch.long is an alias of torch.int64)
data = torch.tensor(tokens, dtype=torch.int64)

In [28]:
# From here on the vocabulary is the tokenizer's, replacing the character-level size.
vocab_size = enc.n_vocab


def decode_BPE(tokens):
    """Turn a sequence of BPE token ids back into text with the ``enc`` tokenizer."""
    decoded = enc.decode(tokens)
    return decoded

In [29]:
#Char level
#sample = "Yo, what's up?"
#print(encode(sample))
#print(decode(encode(sample)))

In [31]:
#BPE
# Round-trip a short string through the BPE tokenizer as a sanity check.
sample = "Yo, what's up?"
sample_ids = enc.encode(sample)
print(sample_ids)
print(decode_BPE(sample_ids))

[38101, 11, 644, 338, 510, 30]
Yo, what's up?


In [32]:
# Pick the compute device: prefer CUDA when present, then Apple MPS, else CPU.
# (Previously a CUDA machine would silently fall back to the CPU.)
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

Using device: mps


In [88]:
#data = torch.tensor(encode(text), dtype=torch.int)
#print(data.shape, data.dtype)

torch.Size([1115394]) torch.int32


In [33]:
# 80/20 split along the token axis: first 80% for training, the rest for validation.
n = int(0.8*len(data))
train, val = data[:n], data[n:]

In [34]:
B, T = 8, 256  # batch size, sequence length
torch.manual_seed(42)  # fixed seed for PyTorch's random number generator

def get_batch(split, T):
    """Sample ``B`` random windows of length ``T`` and their next-token targets.

    ``split == 'train'`` draws from ``train``; any other value draws from ``val``.
    Returns ``(x, y)`` moved to ``device``, where ``y`` is ``x`` shifted one
    position to the right in the underlying token stream.
    """
    if split == 'train':
        source = train
    else:
        source = val

    # Start offsets are chosen so that the target window (start+1 .. start+T) stays in range.
    starts = torch.randint(len(source) - T, (B,))

    inputs, targets = [], []
    for start in starts:
        inputs.append(source[start:start + T])
        targets.append(source[start + 1:start + T + 1])

    return torch.stack(inputs).to(device), torch.stack(targets).to(device)
    
# Draw one training batch and show the inputs next to their shifted targets.
xb, yb = get_batch('train', T)
for label, batch in (("xb:", xb), ("yb:", yb)):
    print(label, batch)

xb: tensor([[ 6711,    25,   198,  ...,   423,  2074,  1549],
        [  284, 13197,   465,  ...,   683,    26,   198],
        [   11,  2074,    26,  ..., 44879,    40,  3535],
        ...,
        [  198,  5122,  2988,  ...,   198,  2348,   292],
        [  475,   484,   547,  ..., 11083,   286,  1971],
        [ 3223,  9538,  1657,  ...,   338,  9482,    11]], device='mps:0')
yb: tensor([[   25,   198,    35,  ...,  2074,  1549,   287],
        [13197,   465, 10645,  ...,    26,   198,     6],
        [ 2074,    26,   892,  ...,    40,  3535,  1565],
        ...,
        [ 5122,  2988,  1203,  ...,  2348,   292,    11],
        [  484,   547,  4844,  ...,   286,  1971,   198],
        [ 9538,  1657,    25,  ...,  9482,    11,   198]], device='mps:0')


In [35]:
import torch.nn.functional as F

class CausalSelfAttentionHead(nn.Module):
    """A single head of causal (masked) scaled dot-product self-attention.

    Args:
        n_embd: size of the incoming feature dimension.
        head_size: size of the key / query / value projections.
        block_size: side length of the stored lower-triangular mask.
    """

    def __init__(self, n_embd, head_size, block_size):
        super().__init__()
        # Bias-free linear projections for keys, queries and values.
        self.k = nn.Linear(n_embd, head_size, bias=False)
        self.q = nn.Linear(n_embd, head_size, bias=False)
        self.v = nn.Linear(n_embd, head_size, bias=False)

        # Lower-triangular matrix of ones, registered as a buffer (not a parameter).
        lower_triangle = torch.tril(torch.ones(block_size, block_size))
        self.register_buffer("tril", lower_triangle)

    def forward(self, x):
        """Apply causal attention to ``x`` of shape (B, T, C); returns (B, T, head_size)."""
        B, T, C = x.shape

        keys = self.k(x)     # (B, T, head_size)
        queries = self.q(x)  # (B, T, head_size)
        values = self.v(x)   # (B, T, head_size)

        # Scaled dot product of queries and keys -> (B, T, T)
        scale = keys.size(-1) ** 0.5
        scores = (queries @ keys.transpose(-2, -1)) / scale

        # Causal mask: entries above the diagonal (future tokens) are zero in
        # `tril`; setting them to -inf gives them zero weight after the softmax.
        is_future = self.tril[:T, :T] == 0
        scores = scores.masked_fill(is_future, float("-inf"))

        weights = F.softmax(scores, dim=-1)
        return weights @ values  # (B, T, head_size)
    

In [36]:
class MultiHeadAttention(nn.Module):
    """Several causal attention heads run side by side, followed by an output projection.

    Args:
        n_head: number of attention heads; must be positive and divide ``n_embd``.
        n_embd: model width; each head gets ``n_embd // n_head`` features.
        block_size: maximum sequence length, forwarded to every head.

    Raises:
        ValueError: if ``n_head`` is not positive or does not divide ``n_embd``.
    """

    def __init__(self, n_head, n_embd, block_size):
        super().__init__()

        # Without this check a non-divisible width builds fine but the
        # concatenated head outputs (n_head * head_size features) no longer
        # match the n_embd inputs `proj` expects, failing only at forward time.
        if n_head <= 0 or n_embd % n_head != 0:
            raise ValueError(
                f"n_embd ({n_embd}) must be divisible by a positive n_head ({n_head})"
            )

        head_size = n_embd // n_head

        self.heads = nn.ModuleList([
            CausalSelfAttentionHead(n_embd, head_size, block_size)
            for _ in range(n_head)
        ])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        """Concatenate every head's output along the feature axis and project it."""
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        return self.proj(out)

In [37]:
class MLP(nn.Module):
    """Feed-forward network: expand to 4 * n_embd, apply GELU, project back to n_embd."""

    def __init__(self, n_embd):
        super().__init__()
        hidden = 4 * n_embd
        expand = nn.Linear(n_embd, hidden)
        activation = nn.GELU()
        contract = nn.Linear(hidden, n_embd)
        self.net = nn.Sequential(expand, activation, contract)

    def forward(self, x):
        """Run ``x`` through the three-layer stack; the last dimension stays ``n_embd``."""
        out = self.net(x)
        return out

In [38]:
class TransformerBlock(nn.Module):
    """One transformer layer: LayerNorm -> attention and LayerNorm -> MLP, each added back residually."""

    def __init__(self, n_embd, n_head, block_size):
        super().__init__()
        self.ln1 = nn.LayerNorm(n_embd)
        self.attn = MultiHeadAttention(n_head, n_embd, block_size)
        self.ln2 = nn.LayerNorm(n_embd)
        self.mlp = MLP(n_embd)

    def forward(self, x):
        """Apply the attention and MLP sub-layers; the output has the same shape as ``x``."""
        # Normalization happens before each sub-layer; the residual path is left un-normalized.
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        return x + mlp_out

In [39]:
class GPT(nn.Module):
    """Decoder-only transformer language model with learned positional embeddings.

    Args:
        vocab_size: number of distinct token ids.
        n_embd: embedding width used throughout the model.
        n_head: attention heads per transformer block.
        n_layer: number of stacked ``TransformerBlock`` modules.
        block_size: maximum sequence length the model accepts.
    """

    def __init__(self,vocab_size, n_embd, n_head, n_layer, block_size):
        super().__init__()

        # Token and (learned) position embeddings
        self.token_emb = nn.Embedding(vocab_size, n_embd)
        self.pos_emb = nn.Embedding(block_size, n_embd)

        # Stack of transformer blocks
        self.blocks = nn.Sequential(
            *[TransformerBlock(n_embd, n_head, block_size) for _ in range(n_layer)]
        )

        # Final layer norm
        self.ln_f = nn.LayerNorm(n_embd)

        self.lm_head = nn.Linear(n_embd, vocab_size, bias = False)

        # Weight tying (GPT-2 style): the output projection shares the token
        # embedding matrix, so both attributes refer to the same parameter.
        self.lm_head.weight = self.token_emb.weight

        self.block_size = block_size
        self.vocab_size = vocab_size

    def forward(self, idx, targets=None):
        """Compute next-token logits and, optionally, the cross-entropy loss.

        Args:
            idx: (B, T) integer token ids, with T <= ``block_size``.
            targets: (B, T) integer token ids, or None.

        Returns:
            ``(logits, loss)``. Without targets, ``logits`` is (B, T, vocab_size)
            and ``loss`` is None. With targets, ``logits`` is returned flattened
            to (B*T, vocab_size) and ``loss`` is the mean cross-entropy.
        """
        B, T = idx.shape
        assert T <= self.block_size, "Sequence length exceeds block size"

        # embeddings
        tok = self.token_emb(idx)  # (B, T, n_embd)
        pos = self.pos_emb(torch.arange(T, device=idx.device))  # (T, n_embd)
        x = tok + pos  # (B, T, n_embd); pos broadcasts over the batch

        # transformer
        x = self.blocks(x)
        x = self.ln_f(x)

        # logits
        logits = self.lm_head(x)  # (B, T, vocab_size)

        loss = None
        if targets is not None:
            # reshape (not view): targets are often a non-contiguous slice such
            # as batch[:, 1:], on which view(-1) raises a RuntimeError.
            logits = logits.reshape(-1, self.vocab_size)
            targets = targets.reshape(-1)
            loss = F.cross_entropy(logits, targets)

        return logits, loss


In [40]:
@torch.no_grad()
def generate(model, idx, max_new_tokens, temperature=1.0, top_k=None):
    """Autoregressively sample ``max_new_tokens`` tokens after the prompt ``idx``.

    Args:
        model: callable returning ``(logits, loss)`` for a (B, T) id tensor and
            exposing ``block_size`` and ``parameters()``.
        idx: (B, T) integer tensor holding the prompt token ids.
        max_new_tokens: number of tokens to append to every sequence.
        temperature: positive divisor applied to the logits before the softmax;
            1.0 (the default) leaves the distribution unchanged.
        top_k: when given, sample only among the ``top_k`` highest-scoring
            tokens at each step; None (the default) samples from the full vocabulary.

    Returns:
        (B, T + max_new_tokens) tensor: the prompt followed by the sampled tokens.

    Raises:
        ValueError: if ``temperature`` is not positive or ``top_k`` is below 1.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if top_k is not None and top_k < 1:
        raise ValueError("top_k must be at least 1")

    device = next(model.parameters()).device
    idx = idx.to(device)

    for _ in range(max_new_tokens):
        # Feed at most the last block_size tokens to the model.
        idx_cond = idx[:, -model.block_size:]

        logits, _ = model(idx_cond)
        logits = logits[:, -1, :]          # (B, vocab): prediction for the next token

        if temperature != 1.0:
            logits = logits / temperature

        if top_k is not None:
            # Filtering runs on the CPU, like the sampling step below, so it
            # does not depend on backend-specific operator support.
            logits = logits.cpu()
            k = min(top_k, logits.size(-1))
            kth_best = torch.topk(logits, k, dim=-1).values[:, [-1]]  # (B, 1)
            logits = logits.masked_fill(logits < kth_best, float("-inf"))

        probs = F.softmax(logits, dim=-1)

        # MPS-safe sampling
        idx_next = torch.multinomial(
            probs.cpu(), num_samples=1
        ).to(device)

        idx = torch.cat([idx, idx_next], dim=1)

    return idx

In [41]:
#vocab_size = 50257 # already defined earlier
# Model hyperparameters (vocab_size is already defined in an earlier cell).
n_embd, n_head, n_layer = 256, 8, 6  # embedding dimension, attention heads, transformer blocks
block_size = 256                     # max context length

# Build the network and move its parameters to the selected device.
model = GPT(vocab_size, n_embd, n_head, n_layer, block_size).to(device)

In [42]:
# AdamW over every model parameter with a fixed learning rate of 3e-4.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=3e-4)

In [46]:
%%time
num_steps = 12000

for step in range(num_steps):
    xb, yb = get_batch("train", T)  # (B, T), (B, T)

    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    if step % 100 == 0:
        print(f"step {step} | loss {loss.item():.4f}")

step 0 | loss 4.5180
step 100 | loss 4.4793
step 200 | loss 4.2571
step 300 | loss 3.9151
step 400 | loss 4.5281
step 500 | loss 4.2226
step 600 | loss 3.8728
step 700 | loss 4.0065
step 800 | loss 4.4602
step 900 | loss 4.1559
step 1000 | loss 4.0248
step 1100 | loss 4.3312
step 1200 | loss 4.2090
step 1300 | loss 3.7860
step 1400 | loss 4.0280
step 1500 | loss 4.0067
step 1600 | loss 4.1319
step 1700 | loss 3.7496
step 1800 | loss 3.6849
step 1900 | loss 4.0623
step 2000 | loss 4.2197
step 2100 | loss 3.5693
step 2200 | loss 3.9186
step 2300 | loss 3.9745
step 2400 | loss 3.6727
step 2500 | loss 3.2856
step 2600 | loss 3.2693
step 2700 | loss 3.8220
step 2800 | loss 3.4356
step 2900 | loss 3.7082
step 3000 | loss 3.3516
step 3100 | loss 3.1928
step 3200 | loss 3.6121
step 3300 | loss 3.5367
step 3400 | loss 3.0117
step 3500 | loss 3.4192
step 3600 | loss 2.9645
step 3700 | loss 3.1129
step 3800 | loss 3.1557
step 3900 | loss 3.1249
step 4000 | loss 3.1448
step 4100 | loss 3.2409
step

In [48]:
# Encode the prompt as a batch containing a single sequence: shape (1, n_prompt_tokens).
prompt_ids = enc.encode("To be, or not to be")
context = torch.tensor([prompt_ids], dtype=torch.long, device=device)

# Sample a continuation and decode the first (only) sequence back to text.
out = generate(model, context, max_new_tokens=200)
generated_ids = out[0].tolist()
print(enc.decode(generated_ids))


To be, or not to be gall
Against a tongue rasciv to come; at the hands
The child if deny to take him.

 ble lineal, ho! what dissolution!
'Yourbroke, or at bone!
Who is that thy lodging:
This exile is great pound come?
The sounded by my heart; who, every one hap is sweet;
Banish'd is dark to went with high?
The treachery heaven wilderness is wash'd,
And who part in high ready sense; unless
The contrary doth death,
The contrary and do a want of false one that is servant,
To this breathingr'd, as a secrets in second hopes,
For hers, all a purpose.
Whereto, fools, devise dear Romeo by thee;
She shall not thy lance.
He shall not name his father that No: let me speak,
It shall not be reign as it be to no.
Strike up with


In [49]:
# Persist only the model's state dict (no config or optimizer state).
weights = model.state_dict()
torch.save(weights, "gpt_bpe_shakespeare.pt")

In [50]:
# Checkpoint that also records the hyperparameters needed to rebuild the model,
# plus the last training step and its loss.
checkpoint = dict(
    model_state=model.state_dict(),
    config=dict(
        vocab_size=vocab_size,
        n_embd=n_embd,
        n_head=n_head,
        n_layer=n_layer,
        block_size=block_size,
    ),
    step=step,
    loss=loss.item(),
)

torch.save(checkpoint, "gpt_checkpoint.pt")


In [51]:
# Resume checkpoint: model weights plus optimizer state and the last training step.
checkpoint = dict(
    model_state=model.state_dict(),
    optimizer_state=optimizer.state_dict(),
    step=step,
)
torch.save(checkpoint, "gpt_resume.pt")

# Loading the saved model

In [52]:
# Read the saved checkpoint onto the active device and restore the weights.
checkpoint = torch.load("gpt_checkpoint.pt", map_location=device)
saved_weights = checkpoint["model_state"]
model.load_state_dict(saved_weights)
model = model.to(device)
# Switch to evaluation mode; as the cell's last expression this also displays the model.
model.eval()


GPT(
  (token_emb): Embedding(50257, 256)
  (pos_emb): Embedding(256, 256)
  (blocks): Sequential(
    (0): TransformerBlock(
      (ln1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (attn): MultiHeadAttention(
        (heads): ModuleList(
          (0-7): 8 x CausalSelfAttentionHead(
            (k): Linear(in_features=256, out_features=32, bias=False)
            (q): Linear(in_features=256, out_features=32, bias=False)
            (v): Linear(in_features=256, out_features=32, bias=False)
          )
        )
        (proj): Linear(in_features=256, out_features=256, bias=True)
      )
      (ln2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (net): Sequential(
          (0): Linear(in_features=256, out_features=1024, bias=True)
          (1): GELU(approximate='none')
          (2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
    )
    (1): TransformerBlock(
      (ln1): LayerNorm((256,), eps=1e-05, elemen

In [47]:
#To resume training
# Load the resume checkpoint explicitly: `gpt_checkpoint.pt` (loaded into
# `checkpoint` by the preceding cell) has no "optimizer_state" entry, so reusing
# that variable here raises KeyError when the notebook runs top to bottom.
checkpoint = torch.load("gpt_resume.pt", map_location=device)
model.load_state_dict(checkpoint["model_state"])
optimizer.load_state_dict(checkpoint["optimizer_state"])
# The loading cell put the model in eval mode; switch back before training again.
# The trailing semicolon suppresses the model repr in the cell output.
model.train();
