In [1]:
import typing
import math
import random

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.nn import functional as F
from torch import Tensor

In [2]:
print("Torch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA version:", torch.version.cuda)
print("Current device:", torch.cuda.current_device() if torch.cuda.is_available() else "N/A")
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "N/A")

Torch version: 2.7.0+cu126
CUDA available: False
CUDA version: 12.6
Current device: N/A
Device name: N/A


In [3]:
if torch.cuda.is_available():
    device = torch.device('cuda')
    print('Using CUDA')
else:
    device = torch.device('cpu')
    print('Using CPU')

Using CPU


In [None]:
torch.manual_seed(1337)

<torch._C.Generator at 0x7918ebb2cab0>

In [5]:
with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [6]:
len(text)

1115393

In [7]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [8]:
chars = sorted(list(set(text)))
print(str().join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [9]:
vocab_size = len(chars)
vocab_size

65

In [10]:
stoi = {ch: i for i, ch in enumerate(chars)}
print(stoi)

{'\n': 0, ' ': 1, '!': 2, '$': 3, '&': 4, "'": 5, ',': 6, '-': 7, '.': 8, '3': 9, ':': 10, ';': 11, '?': 12, 'A': 13, 'B': 14, 'C': 15, 'D': 16, 'E': 17, 'F': 18, 'G': 19, 'H': 20, 'I': 21, 'J': 22, 'K': 23, 'L': 24, 'M': 25, 'N': 26, 'O': 27, 'P': 28, 'Q': 29, 'R': 30, 'S': 31, 'T': 32, 'U': 33, 'V': 34, 'W': 35, 'X': 36, 'Y': 37, 'Z': 38, 'a': 39, 'b': 40, 'c': 41, 'd': 42, 'e': 43, 'f': 44, 'g': 45, 'h': 46, 'i': 47, 'j': 48, 'k': 49, 'l': 50, 'm': 51, 'n': 52, 'o': 53, 'p': 54, 'q': 55, 'r': 56, 's': 57, 't': 58, 'u': 59, 'v': 60, 'w': 61, 'x': 62, 'y': 63, 'z': 64}


In [11]:
itos = {i: ch for ch, i in stoi.items()}
print(itos)

{0: '\n', 1: ' ', 2: '!', 3: '$', 4: '&', 5: "'", 6: ',', 7: '-', 8: '.', 9: '3', 10: ':', 11: ';', 12: '?', 13: 'A', 14: 'B', 15: 'C', 16: 'D', 17: 'E', 18: 'F', 19: 'G', 20: 'H', 21: 'I', 22: 'J', 23: 'K', 24: 'L', 25: 'M', 26: 'N', 27: 'O', 28: 'P', 29: 'Q', 30: 'R', 31: 'S', 32: 'T', 33: 'U', 34: 'V', 35: 'W', 36: 'X', 37: 'Y', 38: 'Z', 39: 'a', 40: 'b', 41: 'c', 42: 'd', 43: 'e', 44: 'f', 45: 'g', 46: 'h', 47: 'i', 48: 'j', 49: 'k', 50: 'l', 51: 'm', 52: 'n', 53: 'o', 54: 'p', 55: 'q', 56: 'r', 57: 's', 58: 't', 59: 'u', 60: 'v', 61: 'w', 62: 'x', 63: 'y', 64: 'z'}


In [12]:
def encode(s: str) -> typing.List[int]:
    return [stoi[c] for c in s]

def decode(ints: typing.List[int]) -> str:
    return str().join(itos[i] for i in ints)

In [13]:
print(encode('hello world'))
print(decode(encode('hello world')))

[46, 43, 50, 50, 53, 1, 61, 53, 56, 50, 42]
hello world


In [14]:
encoded_text = encode(text)
len(encoded_text)

1115393

In [None]:
data = torch.tensor(encoded_text, dtype=torch.long, device=device)
data.shape

torch.Size([1115393])

In [16]:
data.dtype

torch.int64

In [17]:
data[:1000]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 14, 43, 44,
        53, 56, 43,  1, 61, 43,  1, 54, 56, 53, 41, 43, 43, 42,  1, 39, 52, 63,
         1, 44, 59, 56, 58, 46, 43, 56,  6,  1, 46, 43, 39, 56,  1, 51, 43,  1,
        57, 54, 43, 39, 49,  8,  0,  0, 13, 50, 50, 10,  0, 31, 54, 43, 39, 49,
         6,  1, 57, 54, 43, 39, 49,  8,  0,  0, 18, 47, 56, 57, 58,  1, 15, 47,
        58, 47, 64, 43, 52, 10,  0, 37, 53, 59,  1, 39, 56, 43,  1, 39, 50, 50,
         1, 56, 43, 57, 53, 50, 60, 43, 42,  1, 56, 39, 58, 46, 43, 56,  1, 58,
        53,  1, 42, 47, 43,  1, 58, 46, 39, 52,  1, 58, 53,  1, 44, 39, 51, 47,
        57, 46, 12,  0,  0, 13, 50, 50, 10,  0, 30, 43, 57, 53, 50, 60, 43, 42,
         8,  1, 56, 43, 57, 53, 50, 60, 43, 42,  8,  0,  0, 18, 47, 56, 57, 58,
         1, 15, 47, 58, 47, 64, 43, 52, 10,  0, 18, 47, 56, 57, 58,  6,  1, 63,
        53, 59,  1, 49, 52, 53, 61,  1, 15, 39, 47, 59, 57,  1, 25, 39, 56, 41,
        47, 59, 57,  1, 47, 57,  1, 41, 

In [18]:
n = int(0.9 * len(data))
train_data = data[:n]
val_data = data[n:]
print(f'Training size: {len(train_data)}, Validation size: {len(val_data)}')

Training size: 1003853, Validation size: 111540


In [19]:
block_size = 8 # Also called "context length"

In [20]:
train_data[:block_size+1]

tensor([18, 47, 56, 57, 58,  1, 15, 47, 58])

In [21]:
xb = train_data[:block_size]
yb = train_data[1:block_size+1]
print('--- As characters ---')
for t in range(block_size):
    context = xb[:t+1]
    target = yb[t]
    print(f'When the input is {decode(context.tolist())} the next character is {itos[target.item()]}')
print('--- Encoded ---')
for t in range(block_size):
    context = xb[:t+1]
    target = yb[t]
    print(f'When the input is {context} the next character is {target}')

--- As characters ---
When the input is F the next character is i
When the input is Fi the next character is r
When the input is Fir the next character is s
When the input is Firs the next character is t
When the input is First the next character is  
When the input is First  the next character is C
When the input is First C the next character is i
When the input is First Ci the next character is t
--- Encoded ---
When the input is tensor([18]) the next character is 47
When the input is tensor([18, 47]) the next character is 56
When the input is tensor([18, 47, 56]) the next character is 57
When the input is tensor([18, 47, 56, 57]) the next character is 58
When the input is tensor([18, 47, 56, 57, 58]) the next character is 1
When the input is tensor([18, 47, 56, 57, 58,  1]) the next character is 15
When the input is tensor([18, 47, 56, 57, 58,  1, 15]) the next character is 47
When the input is tensor([18, 47, 56, 57, 58,  1, 15, 47]) the next character is 58


In [22]:
def get_batch(dataset: Tensor, batch_size: int, block_size: int, device=None) -> typing.Tuple[Tensor, Tensor]:
    '''
    Gets a batch of `batch_size` examples from `dataset`. Each example will
    consist of `block_size` characters. The inputs and labels will both be
    returned, both of which will be of size `(batch_size, block_size)`.
    '''

    ix = torch.randint(low=0, high=len(dataset)-block_size, size=(batch_size,), device=device)
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

In [23]:
batch_size = 4
xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
print(xb.shape)
print(xb)
print(yb.shape)
print(yb)

torch.Size([4, 8])
tensor([[53, 59,  6,  1, 58, 56, 47, 40],
        [49, 43, 43, 54,  1, 47, 58,  1],
        [13, 52, 45, 43, 50, 53,  8,  0],
        [ 1, 39,  1, 46, 53, 59, 57, 43]])
torch.Size([4, 8])
tensor([[59,  6,  1, 58, 56, 47, 40, 59],
        [43, 43, 54,  1, 47, 58,  1, 58],
        [52, 45, 43, 50, 53,  8,  0, 26],
        [39,  1, 46, 53, 59, 57, 43,  0]])


In [24]:
for b in range(batch_size):
    print(f'Example {b}')
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b, t]
        print(f'Block {t}: When the input is {context} the next character is {target}')

Example 0
Block 0: When the input is tensor([53]) the next character is 59
Block 1: When the input is tensor([53, 59]) the next character is 6
Block 2: When the input is tensor([53, 59,  6]) the next character is 1
Block 3: When the input is tensor([53, 59,  6,  1]) the next character is 58
Block 4: When the input is tensor([53, 59,  6,  1, 58]) the next character is 56
Block 5: When the input is tensor([53, 59,  6,  1, 58, 56]) the next character is 47
Block 6: When the input is tensor([53, 59,  6,  1, 58, 56, 47]) the next character is 40
Block 7: When the input is tensor([53, 59,  6,  1, 58, 56, 47, 40]) the next character is 59
Example 1
Block 0: When the input is tensor([49]) the next character is 43
Block 1: When the input is tensor([49, 43]) the next character is 43
Block 2: When the input is tensor([49, 43, 43]) the next character is 54
Block 3: When the input is tensor([49, 43, 43, 54]) the next character is 1
Block 4: When the input is tensor([49, 43, 43, 54,  1]) the next ch

In [25]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    token_embedding_table: nn.Embedding

    def __init__(self, vocab_size: int, device=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.token_embedding_table = nn.Embedding(vocab_size, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        #
        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the vocabulary size in this case).
        logits = self.token_embedding_table(idx)
        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [26]:
model = BigramLanguageModel(vocab_size, device=device)
logits, loss = model(xb, yb)
print(logits.shape)
print(loss)

torch.Size([32, 65])
tensor(4.9456, grad_fn=<NllLossBackward0>)


In [27]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=100)[0].tolist()
next_str = decode(next_idx)
print(next_str)


lfJeukRuaRJKXAYtXzfJ:HEPiu--sDioi;ILCo3pHNTmDwJsfheKRxZCFs
lZJ XQc?:s:HEzEnXalEPklcPU cL'DpdLCafBheH


In [28]:
@torch.no_grad()
def estimate_loss(model: BigramLanguageModel, train_dataset: Tensor, val_dataset: Tensor, eval_iterations: int, batch_size: int, block_size: int, device = None) -> typing.Dict[str, torch.types.Number]:
    dataset_splits = {'train': train_dataset, 'val': val_dataset}
    out = dict()
    for split_name, split_dataset in dataset_splits.items():
        losses = torch.zeros(eval_iterations, device=device)
        for i in range(eval_iterations):
            xb, yb = get_batch(split_dataset, batch_size, block_size, device)
            logits, loss = model(xb, yb)
            losses[i] = loss.item()
        out[split_name] = losses.mean().item()
    return out

In [29]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.6019, estimated training loss: 4.6372, estimated validation loss: 4.6385
Step: 40     , last seen loss: 4.6411, estimated training loss: 4.5873, estimated validation loss: 4.5857
Step: 80     , last seen loss: 4.5086, estimated training loss: 4.5488, estimated validation loss: 4.5471
Step: 120    , last seen loss: 4.4693, estimated training loss: 4.5019, estimated validation loss: 4.4999
Step: 160    , last seen loss: 4.4676, estimated training loss: 4.4553, estimated validation loss: 4.4559
Step: 200    , last seen loss: 4.4254, estimated training loss: 4.4195, estimated validation loss: 4.4117
Step: 240    , last seen loss: 4.3719, estimated training loss: 4.3730, estimated validation loss: 4.3665
Step: 280    , last seen loss: 4.2231, estimated training loss: 4.3345, estimated validation loss: 4.3317
Step: 320    , last seen loss: 4.2534, estimated training loss: 4.2916, estimated validation loss: 4.2882
Step: 360    , last seen loss: 4.2335, estimat

In [30]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


WZZNuQT?bncemJXrF
Thi3QPVm!PLv&..ydNU.QTkssaZccLCFq$CGPOy3hoSJDD-kNFCfqenzm&WtroYWO?e,-dgAp3BPu-fq!M.kCafgrKOy;LM.BugleHyFRkltvJMyD$VYAz$tciGGHEWn,-SUrxfJTeNo-s:q$VvDHDi'
sJJpkHEThoCZqiisnGNQjt gLEDYZ,
CfOMwJsKBGITkr g;W :!rtinvPlEQfaff,-ELvisDu,

CEuV
OMQUgHakceOMAs!3bkw!Hz-fi?knvARfdd;THVTyXAm Z:3R
k;KklfC3jdNuUCb.ce,NWVRe,xyQSbuig;tktENI!oCANTl&$Lis
Cdk
LcPbrREDKZEFN&M uXL
k?AOropoe3f'S:se-bHNut!'KIQHFp?ZYdUZtINCl;OfpldeAg,,r'Z
suenyes!sSjFENJXALeo-nzy-ko;:tXDim?AP
ZWvRfB
tyWtvQXjw,C bXAEzinO bbl;uXn oerzCXy!sXPfrxfMGE.LeY;oIKlAGEzU.CYTM-s&G,3
GNck zKNAULK;HWgZWMQXzy?Anld'M-MA$Tglrot?sf&sXn'3ptJnvP,,-
C$
BoTIiJYakDoYJHFGHFoCbtyopZLusqG&av&3ptJlsX
OMlDYeay,DjakHbXV
kEMItJMcbllld.OAINUrPWtThp!s&qMy-M3pr,!EA,O
st'fyntue,e,IU,Xas
FRNU.kDXsin:HW3b.Bi;.ernz-spt'mo.Gr!N3r rd lfEz!qG Asa$
3wgq!XyM m-EYWPCj!'-NNTaYDiwAs lligerWjlldero OiZQ Zzdd $DilMst'f$rXQZJ&DF'Cag pP?ZJ3M.ligmy
;EAtPWys
&zrRSYjPLFkl;?oMACSmOSJqaZJ?& f!q!EQllcNp moP yiNgD$croMGasLYIPirsgaIo,BRjncTNqTcnZcLReOPUr?UCxfmduyTY

In [31]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    n_embd: int
    block_size: int

    token_embedding_table: nn.Embedding
    position_embedding_table: nn.Embedding
    lm_head: nn.Linear

    def __init__(self, vocab_size: int, block_size: int, n_embd: int, device=None):
        super().__init__()
        self.vocab_size = vocab_size
        self.n_embd = n_embd
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)
        self.lm_head = nn.Linear(n_embd, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        B, T = idx.shape

        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the number of embeddings in this case).
        token_embeddings = self.token_embedding_table(idx)

        # Shape is (T,C)
        position_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        # Addition gets broadcasted, shape is (B,T,C)
        x = token_embeddings + position_embeddings

        # We then apply the linear layer, which gives us a (B, T, vocab_size)
        # tensor, which are our logits.
        logits = self.lm_head(x)

        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # Crop idx to the last `block_size` tokens
            idx_cond = idx[:, -block_size:]
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx_cond)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=-1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [32]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300
n_embd = 32

model = BigramLanguageModel(vocab_size, block_size, n_embd, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.6175, estimated training loss: 4.5710, estimated validation loss: 4.5586
Step: 40     , last seen loss: 4.0492, estimated training loss: 4.0491, estimated validation loss: 4.0295
Step: 80     , last seen loss: 3.6627, estimated training loss: 3.6554, estimated validation loss: 3.6471
Step: 120    , last seen loss: 3.4259, estimated training loss: 3.3608, estimated validation loss: 3.3506
Step: 160    , last seen loss: 3.1479, estimated training loss: 3.1660, estimated validation loss: 3.1600
Step: 200    , last seen loss: 3.0765, estimated training loss: 3.0443, estimated validation loss: 3.0324
Step: 240    , last seen loss: 2.8424, estimated training loss: 2.9563, estimated validation loss: 2.9436
Step: 280    , last seen loss: 3.1493, estimated training loss: 2.8927, estimated validation loss: 2.8822
Step: 320    , last seen loss: 2.8410, estimated training loss: 2.8517, estimated validation loss: 2.8322
Step: 360    , last seen loss: 2.8110, estimat

In [33]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


Thawe beando alof u;
AEORUSseecullbeO FiKIChe pOusithorise spy tau C,
Whinutngl itaton comendodopond;

O:
T,

MIEEZe! age ndo icised heeerkSw therSineskean!
Tit h thor ar yort, in, bacare d myak ad tspith stowoSyo crs, br, s cor; R:
Kraonok f
Hy,
I:
ESievar,
ENGPd my thithallld mer-' bO:

Not missthithill:


HiNIOlce whe gfr mthe avKanomincowhipaildun es. thae trou hano.
F,  Sind wostorit t, ftis Jfinenghevu sad de .
ASapte! g t'atocENE:
AdeIEL,teveaoa.
ENRL
PeaDU:
RThis havechine, al t wR.
Thess t?
H:
SC, w, oles matheTE;

ARbanothonfourst, tn! bl
NECk wikOCvig wEpose seed wogh a moutorkor t ,
NEcispocd them n ane:
Pckvimesen REY:
Wheo urerd;veny tstaeme st. ge t
Ny dtirs gasthr me wofat lantnsqen bdd.
edWisomnyorond, tersellesg?iticorgele soworear T. o bictal I si;
A.
TMo s mont we .
Athat thorichowndeesl asuo ft omin I:
CAlesans f ige, fr I:
WOind s f t oZanisis, t th s.

I moTs hhit st b ten o wernoss, boLI maing onlThe hawof g suu cat d
Iavomanor i th nda$soru t e t M-;
:
Won ser

## Self-Attention

In [34]:
class Head(nn.Module):

    key: nn.Linear
    query: nn.Linear
    value: nn.Linear

    def __init__(self, head_size: int, n_embd: int, block_size:int , device = None):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False, device=device)
        self.query = nn.Linear(n_embd, head_size, bias=False, device=device)
        self.value = nn.Linear(n_embd, head_size, bias=False, device=device)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size, device=device)))

    def forward(self, x: Tensor) -> Tensor:
        B, T, C = x.shape

        k = self.key(x) # (B, T, C)
        q = self.query(x) # (B, T, C)

        wei = q @ k.transpose(-2, -1) * C**-0.5 # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        v = self.value(x) # (B, T, C)
        out = wei @ v # (B, T, C)

        return out

In [35]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    n_embd: int
    block_size: int
    head_size: int

    token_embedding_table: nn.Embedding
    position_embedding_table: nn.Embedding
    lm_head: nn.Linear
    sa_head: Head

    def __init__(self, vocab_size: int, block_size: int, n_embd: int, head_size: int, device=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)

        self.sa_head = Head(head_size, n_embd, block_size, device=device)
        self.lm_head = nn.Linear(n_embd, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        B, T = idx.shape

        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the number of embeddings in this case).
        token_embeddings = self.token_embedding_table(idx)

        # Shape is (T,C)
        position_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        # Addition gets broadcasted, shape is (B,T,C)
        x = token_embeddings + position_embeddings

        # Apply self-attention head
        x = self.sa_head(x)

        # We then apply the linear layer, which gives us a (B, T, vocab_size)
        # tensor, which are our logits.
        logits = self.lm_head(x)

        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # Crop idx to the last `block_size` tokens
            idx_cond = idx[:, -block_size:]
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx_cond)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=-1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [36]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300
n_embd = 32
head_size = 32

model = BigramLanguageModel(vocab_size, block_size, n_embd, head_size, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.2455, estimated training loss: 4.2274, estimated validation loss: 4.2275
Step: 40     , last seen loss: 3.5464, estimated training loss: 3.5262, estimated validation loss: 3.5265
Step: 80     , last seen loss: 3.2580, estimated training loss: 3.2525, estimated validation loss: 3.2492
Step: 120    , last seen loss: 3.4024, estimated training loss: 3.1643, estimated validation loss: 3.1636
Step: 160    , last seen loss: 3.1352, estimated training loss: 3.0846, estimated validation loss: 3.0838
Step: 200    , last seen loss: 3.1218, estimated training loss: 3.0189, estimated validation loss: 3.0092
Step: 240    , last seen loss: 3.0605, estimated training loss: 2.9590, estimated validation loss: 2.9382
Step: 280    , last seen loss: 3.0463, estimated training loss: 2.8991, estimated validation loss: 2.8678
Step: 320    , last seen loss: 2.9528, estimated training loss: 2.8229, estimated validation loss: 2.8062
Step: 360    , last seen loss: 2.6480, estimat

In [37]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


Fhad ty owuprselry ouris; ow?e the ay stir tangoprr had I mer yer boure, mUs mlpon b-lid cth bar, Dand, be hk be lt thes to ucow we.

Bofd dlot hy flouselas wd,,
S;
MGN
KW,
Tanret cads ml
O, hd foVingoul IHpe pamoArd daces abes,
omovy tsat ftat dr yory Hyous thasove'd.:
Wa, us gothinc a werifounoklgos rd theic, oou as dstf IGus
I P lusft tond p. 
ETishe, tky roms.


e f kpoin tH;

WIichist id hioul wary, u oto os: congisof qman.

Ton hsomino dee f bo E
Ckiont ary bo c om.
Tacoud hor fe tikger'r yon ane weson hou py to cosut opid, a Mtheay plavo th ou'errerr ben belas heR lpar agteth, ber ben, falserdcilet man k'st the tret, sanf mheakad tigereame ite I pee llcimcr eklarece lle landis g? our ble wn tyo sad hens st
CChe aye clanilasen pas,
War ir.

St car.

Y nd tleun anmen wecupole ande, b;
Aloue,.
B
The al Ier, may berle of ve s's ihet ar ff blito jatvis thithedow
Warwalde bles
PAh srsillf fe couc kily finof pe INerofo hiome pomis tmed; nd, psarkca?
Wr amel ds wyo gom sheaf gye! thay 

### Multiple-Head Attention

In [38]:
class MultiHeadAttention(nn.Module):

    def __init__(self, num_heads: int, head_size: int, n_embd: int, block_size: int, device = None):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, device=device) for _ in range(num_heads)])

    def forward(self, x: Tensor) -> Tensor:
        return torch.cat([h(x) for h in self.heads], dim=-1)

In [39]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    n_embd: int
    block_size: int
    head_size: int

    token_embedding_table: nn.Embedding
    position_embedding_table: nn.Embedding
    lm_head: nn.Linear
    sa_heads: MultiHeadAttention

    def __init__(self, vocab_size: int, block_size: int, n_embd: int, num_heads: int, head_size: int, device=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)

        self.sa_heads = MultiHeadAttention(num_heads, head_size // num_heads, n_embd, block_size, device=device)
        self.lm_head = nn.Linear(n_embd, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        B, T = idx.shape

        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the number of embeddings in this case).
        token_embeddings = self.token_embedding_table(idx)

        # Shape is (T,C)
        position_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        # Addition gets broadcasted, shape is (B,T,C)
        x = token_embeddings + position_embeddings

        # Apply self-attention head
        x = self.sa_heads(x)

        # We then apply the linear layer, which gives us a (B, T, vocab_size)
        # tensor, which are our logits.
        logits = self.lm_head(x)

        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # Crop idx to the last `block_size` tokens
            idx_cond = idx[:, -block_size:]
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx_cond)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=-1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [40]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300
n_embd = 32
num_heads = 4
head_size = 32

model = BigramLanguageModel(vocab_size, block_size, n_embd, num_heads, head_size, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.2438, estimated training loss: 4.2147, estimated validation loss: 4.2123
Step: 40     , last seen loss: 3.4863, estimated training loss: 3.4342, estimated validation loss: 3.4232
Step: 80     , last seen loss: 3.2594, estimated training loss: 3.2366, estimated validation loss: 3.2221
Step: 120    , last seen loss: 3.2429, estimated training loss: 3.1475, estimated validation loss: 3.1407
Step: 160    , last seen loss: 2.8948, estimated training loss: 3.0684, estimated validation loss: 3.0430
Step: 200    , last seen loss: 3.0036, estimated training loss: 2.9721, estimated validation loss: 2.9567
Step: 240    , last seen loss: 2.7732, estimated training loss: 2.8977, estimated validation loss: 2.8846
Step: 280    , last seen loss: 2.8140, estimated training loss: 2.8293, estimated validation loss: 2.8238
Step: 320    , last seen loss: 2.8808, estimated training loss: 2.7819, estimated validation loss: 2.7679
Step: 360    , last seen loss: 2.8199, estimat

In [41]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


Wo'dt res:
d awto thome lame ttss fiminhotnicin maven omela thais thel lok youthcy a kmf blof pras mou sir frermer gthurp let hast radn.
Tif


Al dig.
Weree bousthe and wesit he ouriserur.
BAl:

RF:
IOO: nod yond eIabndF qh bickWeart swinsl hon, he pow,

Mot my my omfeay thigto emy btisgs hime yob:
G
Gor! to; of o. ars.

Fy by thins fich, bsoveecy ppraun bonevis hexethey got wgford bof gat this,

Alsrhin
RRO;; dey ghin le mont,
S:

AT
RS-YTtE, dopind waareopfrthe wis d, wos ma fof necesl tigchh thandlt,'Ahard.

Toupck-ve coim se sechin the hat, nomowdukse in he, th ceg ped ar nes younh tobu thouiveglt in:
Dt, angAlinst. sane d atee, ce thee memd fno qathee tos the now ay hom:
K:,, I Boo sitapt ecackbmHere mab myos wig muy cimi why wither lav,:
Hy
SSSI hey herot wond pe pe wu mre ghat soieimks gag wcof and favener tho, domenr mI ren fpar pgons nord ot, konger' kot serd to sak ckrfl wil
Y an'ce, mes dn ILFl;
KNLGon.
NI
AWI has yo.

Ss aacy your, dandy av
d nusindy row dono berecany wI a

### Adding Feed-Forward Layer

This layer adds computation to the communication between tokens.

In [42]:
class FeedForward(nn.Module):

    def __init__(self, n_embd: int, device = None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd, device=device),
            nn.ReLU()
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

In [43]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    n_embd: int
    block_size: int
    head_size: int

    token_embedding_table: nn.Embedding
    position_embedding_table: nn.Embedding
    lm_head: nn.Linear
    ffwd: FeedForward
    sa_heads: MultiHeadAttention

    def __init__(self, vocab_size: int, block_size: int, n_embd: int, num_heads: int, head_size: int, device=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)

        self.sa_heads = MultiHeadAttention(num_heads, head_size // num_heads, n_embd, block_size, device=device)
        self.ffwd = FeedForward(n_embd, device=device)
        self.lm_head = nn.Linear(n_embd, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        B, T = idx.shape

        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the number of embeddings in this case).
        token_embeddings = self.token_embedding_table(idx)

        # Shape is (T,C)
        position_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        # Addition gets broadcasted, shape is (B,T,C)
        x = token_embeddings + position_embeddings

        # Apply self-attention head
        x = self.sa_heads(x)

        # We then apply the linear layer, which gives us a (B, T, vocab_size)
        # tensor, which are our logits.
        logits = self.lm_head(x)

        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # Crop idx to the last `block_size` tokens
            idx_cond = idx[:, -block_size:]
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx_cond)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=-1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [44]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300
n_embd = 32
num_heads = 4
head_size = 32

model = BigramLanguageModel(vocab_size, block_size, n_embd, num_heads, head_size, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.2287, estimated training loss: 4.2037, estimated validation loss: 4.1975
Step: 40     , last seen loss: 3.6529, estimated training loss: 3.5844, estimated validation loss: 3.5809
Step: 80     , last seen loss: 3.2606, estimated training loss: 3.2589, estimated validation loss: 3.2616
Step: 120    , last seen loss: 3.2108, estimated training loss: 3.1450, estimated validation loss: 3.1383
Step: 160    , last seen loss: 3.0030, estimated training loss: 3.0429, estimated validation loss: 3.0362
Step: 200    , last seen loss: 3.0117, estimated training loss: 2.9549, estimated validation loss: 2.9391
Step: 240    , last seen loss: 2.8168, estimated training loss: 2.8826, estimated validation loss: 2.8734
Step: 280    , last seen loss: 2.8610, estimated training loss: 2.8279, estimated validation loss: 2.8174
Step: 320    , last seen loss: 2.7813, estimated training loss: 2.7879, estimated validation loss: 2.7759
Step: 360    , last seen loss: 2.8392, estimat

In [45]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


IUFPRASs bcos pot lis; but rstofirerills, okullll pet thod.

Whaf m.

Gt-E, besaser gacaviws wheat INooudpons Med a-wdd! totavese H! ano can ro bcuror bentheawnd hobof ers me
Epren I- worry af hern ihom iindudg bpe speint wave ser sansidsl u ntclouatrll lot rosithea ldt, fiul nay oegtr,
Mh.
DOGAIUDD.

WEDRUMUARHkitrl sid mard agnon b tous mindtave rt buandy to:
Cer ho I cis thaelnounl a; hea she thitr, novole, gsersse a im.
ERKHe weed.

Wh'inp I Omati, I gihakbeleagthiat hat forirmeatiltoue, hho
Dutrs e
Arolider tonis,
H,
E fowrof poreyonf yoor, arearemr ouarth tomr lintes, hou thint th fhisherense thlee!
Gd; yo copende. u
I wicllisible forf npisifde
RGeeci; add won; caer ofunely ata;
HAth o ouy cous ipor hity, Bpof the I ditovit hopors it? fhavast aullerth sa.
cho H wotirrim loures Wigsy I Dharer yor binsgor fot u finl nith, axe binmanou itver wvear gheosurrt milube ant sind hobit thowk:

ME;
WI lha is re mers forMotraidot dhe mary RSh youind I afe:
Fto Yurele beet ther Ithet alouvli

### Skip Connections, Layer Normalization, and Grouping Communication and Computation Layers into Blocks

In [130]:
class MultiHeadAttention(nn.Module):

    heads: nn.ModuleList
    projection: nn.Linear

    def __init__(self, num_heads: int, head_size: int, n_embd: int, block_size: int, device = None):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, device=device) for _ in range(num_heads)])
        self.projection = nn.Linear(num_heads * head_size, n_embd, device=device)

    def forward(self, x: Tensor) -> Tensor:
        # Apply the self-attention heads
        out = torch.cat([h(x) for h in self.heads], dim=-1)

        # Apply the projection
        out = self.projection(out)

        return out

In [131]:
class FeedForward(nn.Module):

    def __init__(self, n_embd: int, device = None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4, device=device),
            nn.ReLU(),
            nn.Linear(n_embd * 4, n_embd, device=device) # Projection layer
        ).to(device)

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

In [None]:
class Block(nn.Module):

    sa: MultiHeadAttention
    ffwd: FeedForward
    ln1: nn.LayerNorm
    ln2: nn.LayerNorm

    def __init__(self, n_embd: int, num_heads: int, block_size: int, device = None):
        super().__init__()
        head_size = n_embd // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size, n_embd, block_size, device=device)
        self.ffwd = FeedForward(n_embd, device=device)
        self.ln1 = nn.LayerNorm(n_embd, device=device)
        self.ln2 = nn.LayerNorm(n_embd, device=device)

    def forward(self, x: Tensor) -> Tensor:
        # Adding `x` to the layers is the skip connection
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [133]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    n_embd: int
    block_size: int
    head_size: int

    token_embedding_table: nn.Embedding
    position_embedding_table: nn.Embedding
    blocks: nn.Sequential
    lm_head: nn.Linear

    def __init__(self, vocab_size: int, block_size: int, n_embd: int, num_heads: int, head_size: int, device=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.n_embd = n_embd

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)

        self.blocks = nn.Sequential(
            Block(n_embd, num_heads, block_size, device=device),
            Block(n_embd, num_heads, block_size, device=device),
            Block(n_embd, num_heads, block_size, device=device),
            nn.LayerNorm(n_embd, device=device)
        ).to(device)

        self.lm_head = nn.Linear(n_embd, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        B, T = idx.shape

        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the number of embeddings in this case).
        token_embeddings = self.token_embedding_table(idx)

        # Shape is (T,C)
        position_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        # Addition gets broadcasted, shape is (B,T,C)
        x = token_embeddings + position_embeddings

        # Apply self-attention head
        x = self.blocks(x)

        # We then apply the linear layer, which gives us a (B, T, vocab_size)
        # tensor, which are our logits.
        logits = self.lm_head(x)

        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # Crop idx to the last `block_size` tokens
            idx_cond = idx[:, -block_size:]
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx_cond)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=-1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [134]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300
n_embd = 32
num_heads = 4
head_size = 32

model = BigramLanguageModel(vocab_size, block_size, n_embd, num_heads, head_size, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.3397, estimated training loss: 4.3333, estimated validation loss: 4.3263
Step: 40     , last seen loss: 3.4829, estimated training loss: 3.3710, estimated validation loss: 3.3580
Step: 80     , last seen loss: 3.1658, estimated training loss: 3.0927, estimated validation loss: 3.0977
Step: 120    , last seen loss: 2.8591, estimated training loss: 2.8759, estimated validation loss: 2.8557
Step: 160    , last seen loss: 2.7337, estimated training loss: 2.7402, estimated validation loss: 2.7236
Step: 200    , last seen loss: 2.8474, estimated training loss: 2.6596, estimated validation loss: 2.6420
Step: 240    , last seen loss: 2.4873, estimated training loss: 2.5879, estimated validation loss: 2.5724
Step: 280    , last seen loss: 2.6747, estimated training loss: 2.5502, estimated validation loss: 2.5292
Step: 320    , last seen loss: 2.5536, estimated training loss: 2.5241, estimated validation loss: 2.5104
Step: 360    , last seen loss: 2.4208, estimat

In [135]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


ULUHES:
Nhou now coond, seedst anhebln, lar averd you ar thill wat ave whougf whe ret $rry pof asse your? nich-He youse; ico:
I your
Sim my dinge a che lime?
The is an bin;
Tor, I ffents, do co ggelsef thol, folgt an whechshf thou ?
Hem:
Therce osor atbe on sout jul of woue?

AI mu ble to sough sos!
Whtou plesssstat at lle noowf romir con the do ticiden,
Mouss thang le.
Unle Lung,
Porne,, ind bpoly louce ridiarthe pir mesert wo engrscosw thert, suand
Ed wherte on bwe theendsry be not, hertry.
Sus Qim, willvece ppoprait,
Rous it hert tend to fete to hof rint Ce's.
HuemENG, Romich bethst go cow slle'n.

The As weler; Rlangelgy:
No not 'uscet, papeaved, thinttrtou Prefelcen heipcowh pu thend noth pame bein dent hu, fuemy ome hists;
Gy fuertr ont is my surpint ay of my iou litw.
Thaj is foor yourtes.
LICH OFILO:
Hiekoves re;
Gospr rumwe or anere I Istm wate tho I but
Wibert make wir le-grtork he heain greste&biontoses.

ORFOK:
Whe thad the her hall Gory:
I wising y rot ainds o
Und as we.


### Scaling Up the Model

In [136]:
class Head(nn.Module):

    key: nn.Linear
    query: nn.Linear
    value: nn.Linear
    dropout: nn.Dropout

    def __init__(self, head_size: int, n_embd: int, block_size:int, dropout_probability: float, device = None):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False, device=device)
        self.query = nn.Linear(n_embd, head_size, bias=False, device=device)
        self.value = nn.Linear(n_embd, head_size, bias=False, device=device)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size, device=device)))
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, x: Tensor) -> Tensor:
        B, T, C = x.shape

        k = self.key(x) # (B, T, C)
        q = self.query(x) # (B, T, C)

        wei = q @ k.transpose(-2, -1) * C**-0.5 # (B, T, T)
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)

        wei = self.dropout(wei)

        v = self.value(x) # (B, T, C)
        out = wei @ v # (B, T, C)

        return out

In [137]:
class MultiHeadAttention(nn.Module):

    heads: nn.ModuleList
    projection: nn.Linear
    dropout: nn.Dropout

    def __init__(self, num_heads: int, head_size: int, n_embd: int, block_size: int, dropout_probability: float, device = None):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size, n_embd, block_size, dropout_probability, device=device) for _ in range(num_heads)])
        self.projection = nn.Linear(num_heads * head_size, n_embd, device=device)
        self.dropout = nn.Dropout(dropout_probability)

    def forward(self, x: Tensor) -> Tensor:
        # Apply the self-attention heads
        out = torch.cat([h(x) for h in self.heads], dim=-1)

        # Apply the projection
        out = self.projection(out)

        # Apply dropout
        out = self.dropout(out)

        return out

In [138]:
class FeedForward(nn.Module):

    net: nn.Sequential

    def __init__(self, n_embd: int, dropout_probability: float, device = None):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(n_embd, n_embd * 4, device=device),
            nn.ReLU(),
            nn.Linear(n_embd * 4, n_embd, device=device), # Projection layer
            nn.Dropout(dropout_probability)
        ).to(device)

    def forward(self, x: Tensor) -> Tensor:
        return self.net(x)

In [139]:
class Block(nn.Module):

    sa: MultiHeadAttention
    ffwd: FeedForward
    ln1: nn.LayerNorm
    ln2: nn.LayerNorm

    def __init__(self, n_embd: int, num_heads: int, block_size: int, dropout_probability: float, device = None):
        super().__init__()
        head_size = n_embd // num_heads
        self.sa = MultiHeadAttention(num_heads, head_size, n_embd, block_size, dropout_probability, device=device)
        self.ffwd = FeedForward(n_embd, dropout_probability, device=device)
        self.ln1 = nn.LayerNorm(n_embd, device=device)
        self.ln2 = nn.LayerNorm(n_embd, device=device)

    def forward(self, x: Tensor) -> Tensor:
        # Adding `x` to the layers is the skip connection
        x = x + self.sa(self.ln1(x))
        x = x + self.ffwd(self.ln2(x))
        return x

In [140]:
class BigramLanguageModel(nn.Module):

    vocab_size: int
    block_size: int
    num_layers: int
    n_embd: int
    num_heads: int
    head_size: int
    dropout_probability: float

    token_embedding_table: nn.Embedding
    position_embedding_table: nn.Embedding
    blocks: nn.Sequential
    ln_f: nn.LayerNorm
    lm_head: nn.Linear

    def __init__(self, vocab_size: int, block_size: int, num_layers: int, n_embd: int, num_heads: int, head_size: int, dropout_probability: float, device=None):
        super().__init__()

        self.vocab_size = vocab_size
        self.block_size = block_size
        self.num_layers = num_layers
        self.n_embd = n_embd
        self.num_heads = num_heads
        self.head_size = head_size
        self.dropout_probability = dropout_probability

        self.token_embedding_table = nn.Embedding(vocab_size, n_embd, device=device)
        self.position_embedding_table = nn.Embedding(block_size, n_embd, device=device)

        self.blocks = nn.Sequential(
            *[Block(n_embd, num_heads, block_size, dropout_probability, device=device) for _ in range(num_layers)]
        ).to(device)
        self.ln_f = nn.LayerNorm(n_embd, device=device)
        self.lm_head = nn.Linear(n_embd, vocab_size, device=device)

    def forward(self, idx: Tensor, targets: typing.Optional[Tensor] = None) -> typing.Tuple[Tensor, typing.Optional[Tensor]]:
        # `idx` and targets are (B,T) tensors (batch size by time). In this case
        # 'time' represents block size.
        B, T = idx.shape

        # `logits` are (B,T,C) tensors, (batch size by time by channel), where
        # the channel dimension comes from the embedding table. Essentially,
        # each character in idx is replaced by an embedding vector of length C
        # (which is the number of embeddings in this case).
        token_embeddings = self.token_embedding_table(idx)

        # Shape is (T,C)
        position_embeddings = self.position_embedding_table(torch.arange(T, device=device))

        # Addition gets broadcasted, shape is (B,T,C)
        x = token_embeddings + position_embeddings

        # Apply self-attention head
        x = self.blocks(x)

        # Apply layer normalization
        x = self.ln_f(x)

        # We then apply the linear layer, which gives us a (B, T, vocab_size)
        # tensor, which are our logits.
        logits = self.lm_head(x)

        logits = typing.cast(Tensor, logits)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape

            logits = logits.view(B * T, C)
            targets = targets.view(B * T)

            loss = F.cross_entropy(logits, targets)

        # If `targets` was not provided, then output `logits` is a 3D tensor of
        # shape:
        #     `(batch_size, block_size, vocab_size)`
        #
        # Otherwise, if `targets` was provided, then output `logits` is a 2D
        # tensor of shape:
        #     `(batch_size * block_size, vocab_size)`
        return logits, loss

    def generate(self, idx: Tensor, max_new_tokens: int) -> Tensor:
        # `idx` is (B,T), which is `(batch_size, block_size)`
        for _ in range(max_new_tokens):
            # Crop idx to the last `block_size` tokens
            idx_cond = idx[:, -block_size:]
            # `logits` is (B,T,C), where C is the channel length (length of
            # embedding vector, in this case it is `vocab_length`)
            logits, loss = self(idx_cond)
            # Get last character of logits - becomes (B, C)
            logits = logits[:, -1, :]
            # Still (B,C)
            probs = F.softmax(logits, dim=-1)
            # Now its (B,1) since we are getting only one sample
            idx_next = torch.multinomial(probs, num_samples=1)
            # Append sampled index to the running sequence - becomes (B,T+1)
            idx = torch.cat((idx, idx_next), dim=1)
        # The final `idx` tensor will be of shape
        #     `(batch_size, block_size + max_steps)`
        return idx

In [141]:
batch_size = 32
max_steps = 1000
learning_rate = 1e-3
eval_iterations = 300
num_layers = 6
n_embd = 32
num_heads = 4
head_size = 32
dropout_probability = 0.2

model = BigramLanguageModel(vocab_size, block_size, num_layers, n_embd, num_heads, head_size, dropout_probability, device=device)

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
model.train()

for step in range(max_steps):
    xb, yb = get_batch(train_data, batch_size=batch_size, block_size=block_size, device=device)
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

    model.eval()
    if max_steps < 25 or step % (max_steps // 25) == 0:
        loss_dict = estimate_loss(model, train_data, val_data, eval_iterations, batch_size, block_size, device)
        print(f'Step: {step:<7}, last seen loss: {loss.item():.4f}, estimated training loss: {loss_dict["train"]:.4f}, estimated validation loss: {loss_dict["val"]:.4f}')

    model.train()

Step: 0      , last seen loss: 4.3326, estimated training loss: 4.1834, estimated validation loss: 4.1998
Step: 40     , last seen loss: 3.2495, estimated training loss: 3.3235, estimated validation loss: 3.3262
Step: 80     , last seen loss: 2.8774, estimated training loss: 3.1198, estimated validation loss: 3.1216
Step: 120    , last seen loss: 2.8405, estimated training loss: 2.8815, estimated validation loss: 2.8844
Step: 160    , last seen loss: 2.8752, estimated training loss: 2.7424, estimated validation loss: 2.7331
Step: 200    , last seen loss: 2.7303, estimated training loss: 2.6533, estimated validation loss: 2.6472
Step: 240    , last seen loss: 2.5362, estimated training loss: 2.5993, estimated validation loss: 2.6003
Step: 280    , last seen loss: 2.5401, estimated training loss: 2.5651, estimated validation loss: 2.5501
Step: 320    , last seen loss: 2.6948, estimated training loss: 2.5276, estimated validation loss: 2.5210
Step: 360    , last seen loss: 2.5846, estimat

In [142]:
idx = torch.zeros((1, 1), dtype=torch.long, device=device)
next_idx = model.generate(idx, max_new_tokens=1000)[0].tolist()
next_str = decode(next_idx)
print(next_str)


Woveveve youd haver dentl spowsst for.

Rotth cof not oloud gath to
Thato dhahes unes De misht;
My.
I warancy bingh Lis how mos haunbltin, thelorsr hy to Vat gas Gom nelarde'd and brendell meejirdor thie thatir cove he louden;
Selay andsl guwt me moblou.

Tpoofe,
CUTDore did my?
Wey vopoulo derpresns stlot, crenow Setirens a ca the ow he forlla,
Guir sh fowesls pard by sour itaiumcenf any bzelak
Goup:g, loom ese hlioshasad os by head
Tubours iniqurist vethee the to fo ditcken
Cand
Sit to yof forther com vat tasth them toe she pri'eed! s'des!
An:
Caakcof I brotr thach woul ant
I I Lor sth tes bise.
Siqun CILORIY

Seatte.

LRCFOLEONIIO:
If my tupthe
SIO:
Bre Wh his mouser lof potrimint Vy of ad duvor mure, ane,
Anillee worcter retal by olo moure
MI:
CHESIIORLI D finsel my, your mace: Rthilouregn.

BIUGY:
I lourg,:
Hu; put analth thavend my onst tove bich indlle tarjto I a teeast day ibered wes: lous reegrprof
Inestouy houth hes ailh, tom pou in:
NGou thereren ticy y ow sars th you brse,