In [2]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2026-01-01 13:12:23--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 2606:50c0:8002::154, 2606:50c0:8003::154, 2606:50c0:8000::154, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|2606:50c0:8002::154|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2026-01-01 13:12:24 (2.11 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [3]:
print(f'length of text: {len(text)}')

length of text: 1115394


In [4]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(' '.join(chars))
print(f'vocab size: {vocab_size}')


   ! $ & ' , - . 3 : ; ? A B C D E F G H I J K L M N O P Q R S T U V W X Y Z a b c d e f g h i j k l m n o p q r s t u v w x y z
vocab size: 65


In [5]:
stoi = {ch:i for i,ch in enumerate(chars)}
itos = {i:ch for i,ch in enumerate(chars)}

encode = lambda s: [stoi[char] for char in s]
decode = lambda l: ''.join([itos[index] for index in l])

In [6]:
print(encode('hii there'))
print(decode(encode('hii there')))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [7]:
import torch
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.dtype)

  cpu = _conversion_method_template(device=torch.device("cpu"))


torch.Size([1115394]) torch.int64


In [8]:
print(data[:10])

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47])


In [9]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f'bridge: {n}')
print(f'Train size: {len(train_data)}, Val size: {len(val_data)}')

bridge: 1003854
Train size: 1003854, Val size: 111540


In [10]:
block_size = 8
train_data[:block_size]

tensor([18, 47, 56, 57, 58,  1, 15, 47])

In [11]:
torch.manual_seed(1337)
batch_size = 4
block_size = 8

def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size, ))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])

    return x, y

xb, yb = get_batch('train')

print(f'batch shape: xb: {xb.shape} and yb: {yb.shape}')
print(f'xb: {xb}')
print(f'yb: {yb}') 

batch shape: xb: torch.Size([4, 8]) and yb: torch.Size([4, 8])
xb: tensor([[24, 43, 58,  5, 57,  1, 46, 43],
        [44, 53, 56,  1, 58, 46, 39, 58],
        [52, 58,  1, 58, 46, 39, 58,  1],
        [25, 17, 27, 10,  0, 21,  1, 54]])
yb: tensor([[43, 58,  5, 57,  1, 46, 43, 39],
        [53, 56,  1, 58, 46, 39, 58,  1],
        [58,  1, 58, 46, 39, 58,  1, 46],
        [17, 27, 10,  0, 21,  1, 54, 39]])


In [12]:
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(1337)
B, T, C = 4, 8, 32
 
x = torch.randn(B, T, C)

# Head of self attention
head_size = 16
key = nn.Linear(C, head_size, bias=False)
query = nn.Linear(C, head_size, bias=False)
value = nn.Linear(C, head_size, bias=False)
k = key(x) # (B,T,h)
q = query(x) # (B,T,h)

wei = q @ k.transpose(-2, -1) # (B, T, h) @ (B, h, T) -> (B, T, T)

tril = torch.tril(torch.ones(T, T))

# allow tokens to only talk to the tokens at previous positions.
# avoid tokens to talk to the future tokens!
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

v = value(x) # B,T,h
out = wei @ v
out.shape

torch.Size([4, 8, 16])

In [13]:
wei[0]

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1574, 0.8426, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2088, 0.1646, 0.6266, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5792, 0.1187, 0.1889, 0.1131, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0294, 0.1052, 0.0469, 0.0276, 0.7909, 0.0000, 0.0000, 0.0000],
        [0.0176, 0.2689, 0.0215, 0.0089, 0.6812, 0.0019, 0.0000, 0.0000],
        [0.1691, 0.4066, 0.0438, 0.0416, 0.1048, 0.2012, 0.0329, 0.0000],
        [0.0210, 0.0843, 0.0555, 0.2297, 0.0573, 0.0709, 0.2423, 0.2391]],
       grad_fn=<SelectBackward0>)

In [42]:
batch_size = 16
block_size = 32
max_iters = 5000
eval_interval = 500
learning_rate = 1e-3
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
# Embedding Dimension
n_embed = 64
# Number of Heads per Attention Block
n_head = 4
# Number of layers of Attention Blocks
n_layer = 4
dropout = 0

class Head(nn.Module):
    """ one head of self-attention """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embed, head_size, bias=False)
        self.query = nn.Linear(n_embed, head_size, bias=False)
        self.value = nn.Linear(n_embed, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    
    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x) # B,T,head_size
        q = self.query(x) # B,T,head_size

        # perform the weighted aggregation of the values 
        v = self.value(x) # B,T,head_size

        # compute attention scores ("affinities")
        # (B,T,head_size) @ (B,head_size, T)  -> (B,T,T)
        wei = q @ k.transpose(-2, -1) * C**-0.5
        # allow tokens to only talk to the tokens at previous positions.
        # avoid tokens to talk to the future tokens!
        wei = wei.masked_fill(self.tril[:T,:T] == 0, float('-inf')) # B,T,T
        wei = F.softmax(wei, dim=-1) # B,T,T

        out = wei @ v # (B,T,T) @ (B,T,head_size) -> (B,T,head_size)
        return out

In [43]:
class MultiHeadAttention(nn.Module):
    """multiple heads of self-attention in parallel"""
    
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embed, n_embed)
    
    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1)
        out = self.proj(out)
        return out

In [44]:
class FeedForward(nn.Module):
    """a simple linear ayer followed by a non-linearity"""

    def __init__(self, n_embed):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embed, 4 * n_embed),
            nn.ReLU(),
            nn.Linear(4 * n_embed, n_embed),
        )

    def forward(self, x):
        return self.net(x)

In [45]:
class Block(nn.Module):
    """Transformer block: communication followed by computation"""
    
    def __init__(self, n_embed, n_head):
        super().__init__()
        head_size = n_embed // n_head
        self.sa = MultiHeadAttention(n_head, head_size)
        self.ffwd = FeedForward(n_embed)
        self.ln1 = nn.LayerNorm(n_embed)
        self.ln2 = nn.LayerNorm(n_embed)
    
    def forward(self, x):
        # Residual Connections
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [65]:
class Transformer(nn.Module):

    def __init__(self):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embed)
        self.positional_embedding_table = nn.Embedding(block_size, n_embed)
        self.blocks = nn.Sequential(*[Block(n_embed, n_head) for _ in range(n_layer)])
        self.layer_norm = nn.LayerNorm(n_embed)
        self.linear = nn.Linear(n_embed, vocab_size)
    
    def forward(self, idx, targets=None):
        B, T = idx.shape
        token_embedding = self.token_embedding_table(idx) # B, T, n_embed
        positional_embedding = self.positional_embedding_table(torch.arange(T, device=device)) # T, n_embed
        x = token_embedding + positional_embedding # B, T, n_embed

        x = self.blocks(x) # B, T, n_embed
        x = self.layer_norm(x) # B, T, n_embed
        logits = self.linear(x) # B, T, vocab_size

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)
        
        return logits, loss

    def generate(self, idx, max_new_tokens=1000):
        for _ in range(max_new_tokens):
            idx_trimmed = idx[:, -block_size:] # B, T
            logits, loss = self.forward(idx_trimmed) # B, T, n_embed
            next_predicted_logit = logits[:,-1,:] # B, n_embed
            probs = F.softmax(next_predicted_logit, dim=-1) # B, 1
            next_predicted_character = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, next_predicted_character), dim=1) # B,T+1

        return idx
       

In [66]:
model = Transformer()
model.to(device)

Transformer(
  (token_embedding_table): Embedding(65, 64)
  (positional_embedding_table): Embedding(32, 64)
  (blocks): Sequential(
    (0): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
          (0-3): 4 x Head(
            (key): Linear(in_features=64, out_features=16, bias=False)
            (query): Linear(in_features=64, out_features=16, bias=False)
            (value): Linear(in_features=64, out_features=16, bias=False)
          )
        )
        (proj): Linear(in_features=64, out_features=64, bias=True)
      )
      (ffwd): FeedForward(
        (net): Sequential(
          (0): Linear(in_features=64, out_features=256, bias=True)
          (1): ReLU()
          (2): Linear(in_features=256, out_features=64, bias=True)
        )
      )
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
    )
    (1): Block(
      (sa): MultiHeadAttention(
        (heads): ModuleList(
       

In [67]:
print(sum(parameter.numel() for parameter in model.parameters()) / 1e6, ' M parameters')

0.209729  M parameters


In [75]:
# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

losses_per_eval_interval = []
for iter in range(max_iters):
    if iter % eval_interval == 0 or iter == (max_iters - 1):
        if losses_per_eval_interval:
            average_loss_per_eval_interval = sum(losses_per_eval_interval) / len(losses_per_eval_interval)
            print(f'Iteration: {iter} | Loss: {average_loss_per_eval_interval}')
    
    Xb, Yb = get_batch('train')

    logits, loss = model(Xb, Yb)
    optimizer.zero_grad(set_to_none=True)
    losses_per_eval_interval.append(loss.item())
    loss.backward()
    optimizer.step()


Iteration: 500 | Loss: 1.5014781765937806
Iteration: 1000 | Loss: 1.5010536304712296
Iteration: 1500 | Loss: 1.499786115805308
Iteration: 2000 | Loss: 1.498626964211464
Iteration: 2500 | Loss: 1.4969839842796326
Iteration: 3000 | Loss: 1.495960422317187
Iteration: 3500 | Loss: 1.4942270474774497
Iteration: 4000 | Loss: 1.49324335116148
Iteration: 4500 | Loss: 1.4922548153665331
Iteration: 4999 | Loss: 1.4909873292264426


In [76]:
# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(model.generate(context, max_new_tokens=2000)[0].tolist()))


That mean in home, shoo prince. Let me, the to blame,
That see spicial that be suns. Say he lady man's wellder
I spake to looks on, shorthy,
Let him or honour's furge. Lord His shally
Purou, Lady York you good sees;
He deaden that, a Capullius.

CORIOLANUS:
This new, thou mounths may with; Camille of any lady,
And for thyself recore, sirn for York, a sequits?
She with put haur is laboured sinful liber,
With please that leak that was I knowful please.
We musted perfear'd fings, if Richard yet he, the wail and to-day.

Second Citizen:
A bad horsome
he iell? death for yet monums before ent:
For that a heaph'n I have thee but it me;
He, not tune again
Thou:
Let me not vallasterity, a means,
And he to bark, an yet no. Wert Well he lay to which the eyes,
Injoy'd all were be thee comes in my empate,
And her, any looks that make the clows be of thoses;
We way, good myst any, but a buringler
And forth thou reason edger father exseen
And ragal:
I let not buy, if thou is shalt cannot len expratt