<a href="https://colab.research.google.com/github/ManuSinghYadav/Andrej_Karpathy_Zero_to_Hero/blob/main/AK_Lecture_Series_Nano_GPT_(Full).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### [Andrej's Code](https://colab.sandbox.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing#scrollTo=hoelkOrFY8bN)

In [None]:
# @title Version 1 (Without Blocks, Residual Layers and LayerNorm)

import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters
batch_size = 32
block_size = 8
max_iters = 10000
learning_rate = 1e-3
eval_interval = 500
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Getting batches
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
  def __init__(self, head_size):
    super().__init__()
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))


  def forward(self, x):
    B, T, C = x.shape
    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    wei = q @ k.transpose(-2,-1) * C**-0.5
    # Fix: Slice tril to match the current sequence length T
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v
    return out

class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads =  nn.ModuleList([Head(head_size) for h in range(num_heads)])

  def forward(self, x):
    out = torch.cat([h (x) for h in self.heads], dim=2)
    return out

class FeedForward(nn.Module):
  def __init__(self, n_embd):
    super().__init__()
    self.fwd = nn.Sequential(
        nn.Linear(n_embd, n_embd),
        nn.ReLU()
    )

  def forward(self, x):
    return self.fwd(x)

class BigramLanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.positional_embedding_table = nn.Embedding(block_size, n_embd)
    self.sa_head = MultiHeadAttention(4, n_embd//4)
    self.fwd = FeedForward(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """idx (B, T), targets (B, T)."""
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx) # (B, T, C)
    pos_emb = self.positional_embedding_table(torch.arange(T, device=device))
    x = tok_emb + pos_emb
    x = self.sa_head(x)
    fwd = self.fwd(x)
    logits = self.lm_head(fwd)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    for i in range(max_new_tokens):
      idx_cond = idx[:, -block_size:]
      logits, loss = self.forward(idx_cond)
      logits = logits[:,-1,:] # Picking the last row of each batch.
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx


# Defining our model
model = BigramLanguageModel().to(device)

# Training
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if iter % eval_interval == 0:
    losses = estimate_loss()
    print(f"{iter: 5d} / {max_iters} : train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  if iter == 8000:
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate*0.00001)  # Learning rate decay

# Printing output
output = model.generate(torch.tensor([[13]], device=device), 500)
print(decode(output[0].tolist()))

--2025-07-10 07:14:25--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.6’


2025-07-10 07:14:25 (16.6 MB/s) - ‘input.txt.6’ saved [1115394/1115394]

    0 / 10000 : train loss 4.1902, val loss 4.1902
  500 / 10000 : train loss 2.5846, val loss 2.5954
 1000 / 10000 : train loss 2.4651, val loss 2.4678
 1500 / 10000 : train loss 2.4005, val loss 2.4000
 2000 / 10000 : train loss 2.3406, val loss 2.3531
 2500 / 10000 : train loss 2.3139, val loss 2.3329
 3000 / 10000 : train loss 2.2982, val loss 2.3089
 3500 / 10000 : train loss 2.2641, val loss 2.2855
 4000 / 10000 : train loss 2.2536, val loss 2.2546
 4500 / 10000 : tr

In [None]:
# @title Version 2 (With Blocks, Residual Layers and LayerNorm)

import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters
batch_size = 32
block_size = 8
max_iters = 10000
learning_rate = 1e-3
eval_interval = 500
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 32

torch.manual_seed(1337)

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Getting batches
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
  def __init__(self, head_size):
    super().__init__()
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))


  def forward(self, x):
    B, T, C = x.shape
    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    wei = q @ k.transpose(-2,-1) * C**-0.5
    # Fix: Slice tril to match the current sequence length T
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    out = wei @ v
    return out

class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads =  nn.ModuleList([Head(head_size) for h in range(num_heads)])
    self.proj = nn.Linear(n_embd, n_embd)

  def forward(self, x):
    out = torch.cat([h (x) for h in self.heads], dim=2)
    out = self.proj(out)
    return out

class FeedForward(nn.Module):
  def __init__(self, n_embd):
    super().__init__()
    self.fwd = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
    )

  def forward(self, x):
    return self.fwd(x)

class Block(nn.Module):
  def __init__(self, n_embd, n_head):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)
    self.ffwd = FeedForward(n_embd)

  def forward(self, x):
    x = self.sa(self.ln1(x))  # Residual Add
    x = self.ffwd(self.ln2(x))
    return x

class BigramLanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.positional_embedding_table = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(
        Block(n_embd, n_head=4),
        Block(n_embd, n_head=4),
        Block(n_embd, n_head=4),
    )
    self.ln_f = nn.LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """idx (B, T), targets (B, T)."""
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx) # (B, T, C)
    pos_emb = self.positional_embedding_table(torch.arange(T, device=device))
    x = tok_emb + pos_emb
    x = self.blocks(x)
    x = self.ln_f(x)
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    for i in range(max_new_tokens):
      idx_cond = idx[:, -block_size:]
      logits, loss = self.forward(idx_cond)
      logits = logits[:,-1,:] # Picking the last row of each batch.
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx


# Defining our model
model = BigramLanguageModel().to(device)

# Training
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if iter % eval_interval == 0:
    losses = estimate_loss()
    print(f"{iter: 5d} / {max_iters} : train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  if iter == 8000:
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate*0.00001)  # Learning rate decay

# Printing output
output = model.generate(torch.tensor([[13]], device=device), 500)
print(decode(output[0].tolist()))

--2025-07-11 07:22:44--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.111.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.111.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2025-07-11 07:22:45 (30.2 MB/s) - ‘input.txt’ saved [1115394/1115394]

    0 / 10000 : train loss 3.9622, val loss 3.9798
  500 / 10000 : train loss 2.7101, val loss 2.7011
 1000 / 10000 : train loss 2.5202, val loss 2.4925
 1500 / 10000 : train loss 2.4079, val loss 2.4091
 2000 / 10000 : train loss 2.3565, val loss 2.3719
 2500 / 10000 : train loss 2.2864, val loss 2.3039
 3000 / 10000 : train loss 2.2626, val loss 2.2912
 3500 / 10000 : train loss 2.2217, val loss 2.2403
 4000 / 10000 : train loss 2.1802, val loss 2.2264
 4500 / 10000 : train 

In [None]:
# @title Version 3 (Scaling it up)

import torch
import torch.nn as nn
from torch.nn import functional as F

# Hyperparameters
batch_size = 64
block_size = 256
max_iters = 10000
learning_rate = 3e-3
eval_interval = 500
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 6
dropout = 0.2

torch.manual_seed(1337)

!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(list(set(text)))
vocab_size = len(chars)
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data))
train_data = data[:n]
val_data = data[n:]

# Getting batches
def get_batch(split):
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Head(nn.Module):
  def __init__(self, head_size):
    super().__init__()
    self.query = nn.Linear(n_embd, head_size, bias=False)
    self.key = nn.Linear(n_embd, head_size, bias=False)
    self.value = nn.Linear(n_embd, head_size, bias=False)
    self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    B, T, C = x.shape
    q = self.query(x)
    k = self.key(x)
    v = self.value(x)
    wei = q @ k.transpose(-2,-1) * C**-0.5
    # Fix: Slice tril to match the current sequence length T
    wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
    wei = F.softmax(wei, dim=-1)
    wei = self.dropout(wei)
    out = wei @ v
    return out

class MultiHeadAttention(nn.Module):
  def __init__(self, num_heads, head_size):
    super().__init__()
    self.heads =  nn.ModuleList([Head(head_size) for h in range(num_heads)])
    self.proj = nn.Linear(n_embd, n_embd)
    self.dropout = nn.Dropout(dropout)

  def forward(self, x):
    out = torch.cat([h (x) for h in self.heads], dim=2)
    out = self.dropout(self.proj(out))
    return out

class FeedForward(nn.Module):
  def __init__(self, n_embd):
    super().__init__()
    self.fwd = nn.Sequential(
        nn.Linear(n_embd, 4 * n_embd),
        nn.ReLU(),
        nn.Linear(4 * n_embd, n_embd),
    )

  def forward(self, x):
    return self.fwd(x)

class Block(nn.Module):
  def __init__(self, n_embd, n_head):
    super().__init__()
    head_size = n_embd // n_head
    self.sa = MultiHeadAttention(n_head, head_size)
    self.ln1 = nn.LayerNorm(n_embd)
    self.ln2 = nn.LayerNorm(n_embd)
    self.ffwd = FeedForward(n_embd)

  def forward(self, x):
    x = self.sa(self.ln1(x))  # Residual Add
    x = self.ffwd(self.ln2(x))
    return x

class BigramLanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
    self.positional_embedding_table = nn.Embedding(block_size, n_embd)
    self.blocks = nn.Sequential(*[Block(n_embd, n_head) for i in range(3)])
    self.ln_f = nn.LayerNorm(n_embd)
    self.lm_head = nn.Linear(n_embd, vocab_size)

  def forward(self, idx, targets=None):
    """idx (B, T), targets (B, T)."""
    B, T = idx.shape
    tok_emb = self.token_embedding_table(idx) # (B, T, C)
    pos_emb = self.positional_embedding_table(torch.arange(T, device=device))
    x = tok_emb + pos_emb
    x = self.blocks(x)
    x = self.ln_f(x)
    logits = self.lm_head(x)

    if targets is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, idx, max_new_tokens):
    for i in range(max_new_tokens):
      idx_cond = idx[:, -block_size:]
      logits, loss = self.forward(idx_cond)
      logits = logits[:,-1,:] # Picking the last row of each batch.
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)
    return idx


# Defining our model
model = BigramLanguageModel().to(device)

# Training
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):
  xb, yb = get_batch('train')
  logits, loss = model(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()
  if iter % eval_interval == 0:
    losses = estimate_loss()
    print(f"{iter: 5d} / {max_iters} : train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

  if iter == 8000:
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate*0.00001)  # Learning rate decay

# Printing output
output = model.generate(torch.tensor([[13]], device=device), 500)
print(decode(output[0].tolist()))

--2025-07-11 09:43:51--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt.2’


2025-07-11 09:43:51 (25.9 MB/s) - ‘input.txt.2’ saved [1115394/1115394]

    0 / 10000 : train loss 4.4665, val loss 4.4996
  500 / 10000 : train loss 3.3165, val loss 3.3520
 1000 / 10000 : train loss 3.3105, val loss 3.3438
 1500 / 10000 : train loss 3.3094, val loss 3.3464
 2000 / 10000 : train loss 3.3104, val loss 3.3499
 2500 / 10000 : train loss 3.3117, val loss 3.3463
 3000 / 10000 : train loss 3.3101, val loss 3.3502
 3500 / 10000 : train loss 3.3083, val loss 3.3474


In [None]:
# @title Remimplemetation
import torch
import torch.nn as nn
import torch.nn.functional as F

# !wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

chars = sorted(set(text))

device = 'cuda' if torch.cuda.is_available() else 'cpu'
vocab_size = len(chars)
context_window = 128
batch_size = 64
emb_size = 256
max_iters = 5000
lr = 1e-3
head_dim = 32
n_head = emb_size // head_dim
n_layer = 8

itos = {i:v for i, v in enumerate(chars)}
stoi = {v:i for i, v in enumerate(chars)}

encode = lambda x : [stoi[i] for i in x]
decode = lambda x : [itos[i] for i in x]

enocded_data = torch.tensor(encode(text))

def train_test_split():
  n = int(len(enocded_data) * 0.9)
  train = enocded_data[:n]
  test = enocded_data[n:]
  return train, test

train, test = train_test_split()  # Remove

def get_batch(split):
  data = train if split == 'train' else test
  batch = torch.randint(0, len(data) - context_window, (batch_size,))
  ix = torch.stack([data[i: i+context_window] for i in batch])
  yx = torch.stack([data[i+1: i+context_window+1] for i in batch])
  ix, yx = ix.to(device), yx.to(device)
  return ix, yx

class Head(nn.Module):
  def __init__(self):
    super().__init__()
    self.query = nn.Linear(emb_size, head_dim)
    self.key = nn.Linear(emb_size, head_dim)
    self.value = nn.Linear(emb_size, head_dim)

  def forward(self, x):
    B,T,C = x.shape

    q = self.query(x)
    k = self.key(x)
    v = self.value(x)

    atten = (q @ k.transpose(-2, -1)) / (head_dim ** 0.5)
    tril = torch.tril(torch.ones(context_window, context_window, device=device))
    atten = atten.masked_fill(tril[:T, :T]==0, float('-inf'))
    atten = F.softmax(atten, -1)
    out = atten @ v
    return out

class MultiHeadAttention(nn.Module):
  def __init__(self):
    super().__init__()
    self.heads = nn.ModuleList([Head() for i in range(n_head)])

  def forward(self, x):
    return torch.cat([h(x) for h in self.heads], -1)

class FeedForward(nn.Module):
  def __init__(self):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(emb_size, 4*emb_size),
        nn.ReLU(),
        nn.Linear(4*emb_size, emb_size),
    )
  def forward(self, x):
    return self.net(x)

class Block(nn.Module):
  def __init__(self):
    super().__init__()
    self.sa_head = MultiHeadAttention()
    self.lm_head = FeedForward()
    self.ln1 = nn.LayerNorm(emb_size)
    self.ln2 = nn.LayerNorm(emb_size)

  def forward(self, x):
    ln1_out = self.ln1(x)
    sa_out = x + self.sa_head(ln1_out)
    ln2_out = self.ln1(sa_out)
    lm_out = x + self.lm_head(ln2_out)
    return lm_out

class BigramLanguageModel(nn.Module):
  def __init__(self):
    super().__init__()
    self.nn_emb = nn.Embedding(vocab_size, emb_size)
    self.nn_pos = nn.Embedding(context_window, emb_size)
    self.blocks = nn.Sequential(*[Block() for i in range(n_layer)])
    self.nn_linear = nn.Linear(emb_size, vocab_size)

  def forward(self, x, y=None):
    B, T = x.shape
    emb = self.nn_emb(x)
    pos_emb = self.nn_pos(torch.arange(T, device=device))
    pos_emb = pos_emb + emb
    block_out = self.blocks(pos_emb)
    logits = self.nn_linear(block_out)

    if y is None:
      loss = None
    else:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = y.view(B*T)
      loss = F.cross_entropy(logits, targets)
    return logits, loss

  def generate(self, ix, max_new_tokens):
    for i in range(max_new_tokens):
      ix_cond = ix[:, -context_window:]
      logits, loss = self.forward(ix_cond)
      logits = logits[:,-1,:]
      probs = F.softmax(logits, -1)
      ixn = torch.multinomial(probs, num_samples=1)
      ix = torch.cat((ix, ixn), dim=1)
    return ix

model = BigramLanguageModel().to(device)
print(f"Total Parameters: {sum(i.numel() for i in model.parameters()):,}")

optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

Total Parameters: 5,857,857


In [None]:
for i in range(max_iters):
  ix, yx = get_batch('train')
  logits, loss = model(ix, yx)
  optimizer.zero_grad(set_to_none=True)
  if i % 250 == 0:
    print(f"{i:5d} /  {max_iters} : loss is {loss:.4f}")
  loss.backward()
  optimizer.step()

# Generation
print(''.join(decode(model.generate(torch.tensor([[2]], device=device), 1000)[0].tolist())))

    0 /  5000 : loss is 4.5459
  250 /  5000 : loss is 2.0996
  500 /  5000 : loss is 1.7375
  750 /  5000 : loss is 1.5076
 1000 /  5000 : loss is 1.4073
 1250 /  5000 : loss is 1.3664
 1500 /  5000 : loss is 1.3197
 1750 /  5000 : loss is 1.2654
 2000 /  5000 : loss is 1.2172
 2250 /  5000 : loss is 1.2118
 2500 /  5000 : loss is 1.1514
 2750 /  5000 : loss is 1.1526
 3000 /  5000 : loss is 1.0938
 3250 /  5000 : loss is 1.0523
 3500 /  5000 : loss is 1.0299
 3750 /  5000 : loss is 0.9833
 4000 /  5000 : loss is 0.9563
 4250 /  5000 : loss is 0.9371
 4500 /  5000 : loss is 0.8295
 4750 /  5000 : loss is 0.8154
!

KING EDWARD IV:
Soft, tell me, is it doth. Answer I,
Let it remembers smarves fetch at roar'd,
As you shall lay the battle subject and deceive
And dissolute your lusty for your debt.

ROMEO:
That she is tooth touch by my descent.
Sir, bear the swines he knew my cousins
Before that thy father rises with thy cheeks, which before
My company may half made special we in our mistr

In [None]:
4:10 - 4:25

tensor(12)

In [None]:
xt, yt = get_batch('test')
logits, loss = model(xt, yt)
loss

tensor(1.9263, device='cuda:0', grad_fn=<NllLossBackward0>)