<a href="https://colab.research.google.com/github/adityaghai07/ML-Projects/blob/main/gpt_from_scratch_by_karparthy.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2026-01-30 16:43:18--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.111.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2026-01-30 16:43:18 (19.8 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [2]:
with open('input.txt','r', encoding='utf-8') as f:
  text =  f.read()

In [3]:
len(text)

1115394

In [4]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [5]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
vocab_size
print(''.join(chars))


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz


In [6]:
stoi = { s:i  for i,s in enumerate(chars)}
itos = { i:s   for i,s in enumerate(chars)}
# take a string and give tokens -> encoder!


In [7]:
encoder =  lambda s: [stoi[c] for c in s]
decoder =  lambda l: [itos[i] for i in l]
encoder("Hi Aditya!")
decoder(encoder("Hi Aditya!"))

['H', 'i', ' ', 'A', 'd', 'i', 't', 'y', 'a', '!']

In [8]:
import torch
data = torch.tensor(encoder(text), dtype=torch.long)
print(data.shape, data.dtype)

torch.Size([1115394]) torch.int64


In [9]:
#split the data into train and test
n = int(0.9*(len(data)))
print(n)
train_data = data[:n]
val_data = data[n:]

1003854


In [10]:
# define block size, then how training will be considered, x,y

block_size = 8
x = train_data[:block_size]
y = train_data[1:block_size+1]
# x,y

for t in range(block_size):
  context = x[:t+1]
  targets = y[t]

  print(f'when input is {context}, need to predict {targets}')

when input is tensor([18]), need to predict 47
when input is tensor([18, 47]), need to predict 56
when input is tensor([18, 47, 56]), need to predict 57
when input is tensor([18, 47, 56, 57]), need to predict 58
when input is tensor([18, 47, 56, 57, 58]), need to predict 1
when input is tensor([18, 47, 56, 57, 58,  1]), need to predict 15
when input is tensor([18, 47, 56, 57, 58,  1, 15]), need to predict 47
when input is tensor([18, 47, 56, 57, 58,  1, 15, 47]), need to predict 58


In [11]:
# we will be creating actual batches!

block_size = 8
batch_size = 4

def get_batch(split):
  data = train_data if split == 'train' else val_data
  ix = torch.randint(len(data) - block_size, (batch_size,))
  x = torch.stack([data[i:i+block_size] for i in ix])
  y = torch.stack([data[i+1:i+block_size+1] for i in ix])
  return x,y

xb, yb = get_batch('train')

print(xb,yb)

for b in range(batch_size):
  for t in range(block_size):
    context = xb[b,:t+1]
    target = yb[b,t]
    print(f'when input is {context}, need to predict {target}')

tensor([[58,  1, 61, 47, 58, 46,  1, 58],
        [61, 52,  1, 45, 56, 39, 41, 47],
        [43, 52, 58, 50, 43,  6,  1, 49],
        [50, 41, 53, 51, 43,  1, 63, 53]]) tensor([[ 1, 61, 47, 58, 46,  1, 58, 46],
        [52,  1, 45, 56, 39, 41, 47, 53],
        [52, 58, 50, 43,  6,  1, 49, 47],
        [41, 53, 51, 43,  1, 63, 53, 59]])
when input is tensor([58]), need to predict 1
when input is tensor([58,  1]), need to predict 61
when input is tensor([58,  1, 61]), need to predict 47
when input is tensor([58,  1, 61, 47]), need to predict 58
when input is tensor([58,  1, 61, 47, 58]), need to predict 46
when input is tensor([58,  1, 61, 47, 58, 46]), need to predict 1
when input is tensor([58,  1, 61, 47, 58, 46,  1]), need to predict 58
when input is tensor([58,  1, 61, 47, 58, 46,  1, 58]), need to predict 46
when input is tensor([61]), need to predict 52
when input is tensor([61, 52]), need to predict 1
when input is tensor([61, 52,  1]), need to predict 45
when input is tensor([61

In [12]:
# building a bigram model
import torch
import torch.nn as nn
from torch.nn import functional as F
torch.manual_seed(1337)

class bigramModel(nn.Module):

  def __init__(self,vocab_size):

    super().__init__()
    self.token_embedding_table = nn.Embedding(vocab_size, vocab_size)

  def forward(self, idx, targets=None):
    logits = self.token_embedding_table(idx)
    probs = F.softmax(logits, dim=-1)

    loss = None
    if targets is not None:
      B, T, C = logits.shape
      logits = logits.view(B*T, C)
      targets = targets.view(B*T)
      loss = F.cross_entropy(logits, targets)

    return logits, loss

  def generate(self, idx, max_new_tokens):
    for _ in range(max_new_tokens):
      logits, loss = self(idx)
      logits = logits[:, -1, :]
      probs = F.softmax(logits, dim=-1)
      idx_next = torch.multinomial(probs, num_samples=1)
      idx = torch.cat((idx, idx_next), dim=1)

    return idx



m = bigramModel(vocab_size)
logits, loss = m(xb, yb)
print(logits.shape)
print(loss)


a = m.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)
a = a[0].tolist()
print(decoder(a))


torch.Size([32, 65])
tensor(5.1613, grad_fn=<NllLossBackward0>)
['\n', 'S', 'K', 'I', 'c', 'L', 'T', ';', 'A', 'c', 'E', 'L', 'M', 'o', 'T', 'b', 'v', 'Z', 'v', ' ', 'C', '?', 'n', 'q', '-', 'Q', 'E', '3', '3', ':', 'C', 'J', 'q', 'k', 'O', 'K', 'H', '-', 'q', ';', ':', 'l', 'a', '!', 'o', 'i', 'y', 'w', 'k', 'H', 'j', 'g', 'C', 'h', 'z', 'b', 'Q', '?', 'u', '!', '3', 'b', 'L', 'I', 'g', 'w', 'e', 'v', 'm', 'y', 'F', 'J', 'G', 'U', 'G', 'p', '\n', 'w', 'n', 'Y', 'W', 'm', 'n', 'x', 'K', 'W', 'W', 'e', 'v', '-', 't', 'D', 'q', 'X', 'E', 'r', 'V', 'K', 'L', 'g', 'J']


In [13]:
# actual training

batch_size = 32

optimizer = torch.optim.AdamW(m.parameters(), lr=1e-3)



for _ in range(1000):
  xb, yb = get_batch('train')
  logits, loss = m(xb, yb)
  optimizer.zero_grad(set_to_none=True)
  loss.backward()
  optimizer.step()

print(loss.item())

3.704136848449707


In [14]:
a = m.generate(idx=torch.zeros((1,1), dtype=torch.long), max_new_tokens=100)
a = a[0].tolist()
c = decoder(a)
print(''.join(c))


Wh;;Sq.f ustNzknc
kwgOj$dhPWr,SV?hsusiKpgXXUh;Apmem d?hESXI.i;TrJgkiF-oKbXCAA -botrngFCHAUQkn$

pn$w


In [15]:
# now we will talk about attention!
# math!!!

B, T, C = 4, 8, 2

x = torch.randint(1,10,(B,T,C)).float()
wei = torch.tril(torch.ones((T,T)))
wei = wei/wei.sum(1, keepdim=True)



out = wei @ x
out



tensor([[[5.0000, 2.0000],
         [3.5000, 4.5000],
         [3.6667, 6.0000],
         [3.2500, 5.5000],
         [4.4000, 6.2000],
         [4.0000, 5.8333],
         [4.7143, 5.7143],
         [4.5000, 5.8750]],

        [[3.0000, 2.0000],
         [3.5000, 5.5000],
         [3.0000, 4.0000],
         [4.0000, 3.7500],
         [3.6000, 3.2000],
         [3.3333, 4.0000],
         [3.0000, 4.2857],
         [2.7500, 4.5000]],

        [[2.0000, 8.0000],
         [5.0000, 4.5000],
         [4.3333, 6.0000],
         [5.2500, 5.2500],
         [4.8000, 5.4000],
         [5.0000, 6.0000],
         [5.1429, 5.2857],
         [4.8750, 4.8750]],

        [[8.0000, 5.0000],
         [4.5000, 5.0000],
         [5.3333, 3.6667],
         [5.7500, 4.2500],
         [5.0000, 4.2000],
         [5.5000, 4.8333],
         [4.8571, 5.1429],
         [5.1250, 5.0000]]])

In [16]:
# the same with softmax

wei = torch.tril(torch.ones((T,T)))
wei = wei.masked_fill(wei == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)
wei

out = wei @ x
out

tensor([[[5.0000, 2.0000],
         [3.5000, 4.5000],
         [3.6667, 6.0000],
         [3.2500, 5.5000],
         [4.4000, 6.2000],
         [4.0000, 5.8333],
         [4.7143, 5.7143],
         [4.5000, 5.8750]],

        [[3.0000, 2.0000],
         [3.5000, 5.5000],
         [3.0000, 4.0000],
         [4.0000, 3.7500],
         [3.6000, 3.2000],
         [3.3333, 4.0000],
         [3.0000, 4.2857],
         [2.7500, 4.5000]],

        [[2.0000, 8.0000],
         [5.0000, 4.5000],
         [4.3333, 6.0000],
         [5.2500, 5.2500],
         [4.8000, 5.4000],
         [5.0000, 6.0000],
         [5.1429, 5.2857],
         [4.8750, 4.8750]],

        [[8.0000, 5.0000],
         [4.5000, 5.0000],
         [5.3333, 3.6667],
         [5.7500, 4.2500],
         [5.0000, 4.2000],
         [5.5000, 4.8333],
         [4.8571, 5.1429],
         [5.1250, 5.0000]]])

In [17]:
# tring simple self attention with a single head
B, T, C = 4, 8, 32
x = torch.randint(1,10,(B,T,C)).float()
head_size = 16
query = nn.Linear(C, head_size, bias = False)
key = nn.Linear(C, head_size, bias = False)
# value = nn.Linear(C, head_size, bias = False)
q = query(x)  # (B,T,16)
k = key(x)    # (B,T,16)

wei = q @ k.transpose(-2, -1)   # B,T,16 x B,16,T

tril = torch.tril(torch.ones(T, T))
#wei = torch.zeros((T,T))
wei = wei.masked_fill(tril == 0, float('-inf'))
wei = F.softmax(wei, dim=-1)

values = nn.Linear(C, head_size, bias=False)
v = values(x)   # B,T,16
out = wei @ v
# wei.shape, v.shape

In [18]:
xa, ya = get_batch('train')
xa.shape, ya.shape

(torch.Size([32, 8]), torch.Size([32, 8]))

In [19]:
data.shape

torch.Size([1115394])

In [20]:
import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 1000
eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 4
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) * head_size**-0.5
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)
        out = wei @ v
        return out

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # self.sa_head = Head(n_embd)  # NEW: single attention head
        self.sa_head = MultiHeadAttention(n_head, n_embd//n_head)  # NEW
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        x = self.sa_head(x)  # NEW: apply attention
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


model = BigramLanguageModel(vocab_size)
m = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

step 0: train loss 4.1798, val loss 4.1848
step 300: train loss 2.9827, val loss 3.0082
step 600: train loss 2.9813, val loss 2.9955
step 900: train loss 3.0184, val loss 3.0203

Falllldin sir
be ang e surint tulery nueunn  ontinwl.
UL' wea xen we r'en ye fo
fo arerr ma wae o, marfon t waelendan. ire 'st s.k f'ul qane 
ROde  ,
P ho IO; fD' ven flfathe pall? vell nu rr rs ce eat r nsm, anevMCOa  fvrof parerar ecedoene, mn ern fom  mys f eld:ea fo fr' w .la e m t re  b n? hea hen'pisebratu yne n wa ee,  bal frre
Ca.
Din gal we g rnl;  miuer f t lrco abeid fr .
QUL
ouat' wrirae nomyt ttesst teld, m'ay chir;,t   meny fanbn hom; ter alomir, the igra mofinu d e winan, wye f we


In [None]:
# I was moving sequencially, just the bigram, then position tokens, then self attention, then multihead, but better to seperate code from now, will add the feedforward now!

import torch
import torch.nn as nn
from torch.nn import functional as F

# hyperparameters
batch_size = 32 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?
max_iters = 1000
eval_interval = 300
learning_rate = 1e-2
device = 'cuda' if torch.cuda.is_available() else 'cpu'
eval_iters = 200
n_embd = 384
n_head = 4
# ------------

torch.manual_seed(1337)

# wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

# here are all the unique characters that occur in this text
chars = sorted(list(set(text)))
vocab_size = len(chars)
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

# Train and test splits
data = torch.tensor(encode(text), dtype=torch.long)
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

# data loading
def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = torch.randint(len(data) - block_size, (batch_size,))
    x = torch.stack([data[i:i+block_size] for i in ix])
    y = torch.stack([data[i+1:i+block_size+1] for i in ix])
    x, y = x.to(device), y.to(device)
    return x, y

@torch.no_grad()
def estimate_loss():
    out = {}
    model.eval()
    for split in ['train', 'val']:
        losses = torch.zeros(eval_iters)
        for k in range(eval_iters):
            X, Y = get_batch(split)
            logits, loss = model(X, Y)
            losses[k] = loss.item()
        out[split] = losses.mean()
    model.train()
    return out

class Block(nn.Module):
  def __init__(self, n_embd, n_head):
    super().__init__()
    self.sa = MultiHeadAttention(n_head, n_embd//n_head)
    self.ffwd = FeedForward(n_embd)
    self.l1 = nn.LayerNorm(n_embd)
    self.l2 = nn.LayerNorm(n_embd)

  def forward(self, x):
    x = x + self.sa(self.l1(x))
    x = x + self.ffwd(self.l2(x))
    return x


class FeedForward(nn.Module):
  def __init__(self, n_embd):
    super().__init__()
    self.net = nn.Sequential(
        nn.Linear(n_embd, 4*n_embd),
        nn.ReLU(),
        nn.Linear(4*n_embd, n_embd)
    )

  def forward(self, x):
    return self.net(x)

class MultiHeadAttention(nn.Module):
    def __init__(self, num_heads, head_size):
        super().__init__()
        self.heads = nn.ModuleList([Head(head_size) for _ in range(num_heads)])
        self.proj = nn.Linear(n_embd, n_embd)

    def forward(self, x):
        out = torch.cat([h(x) for h in self.heads], dim=-1)
        out = self.proj(out)
        return out

class Head(nn.Module):
    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

    def forward(self, x):
        B,T,C = x.shape
        k = self.key(x)
        q = self.query(x)
        wei = q @ k.transpose(-2,-1) * C**-0.5   # I know it should be head_size but somehow this is better
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf'))
        wei = F.softmax(wei, dim=-1)
        v = self.value(x)
        out = wei @ v
        return out

class BigramLanguageModel(nn.Module):
    def __init__(self, vocab_size):
        super().__init__()
        self.token_embedding_table = nn.Embedding(vocab_size, n_embd)
        self.position_embedding_table = nn.Embedding(block_size, n_embd)
        # self.sa_head = Head(n_embd)  # NEW: single attention head
        # self.sa_head = MultiHeadAttention(n_head, n_embd//n_head)  # NEW
        # self.ffwd = FeedForward(n_embd)
        self.blocks = nn.Sequential(
            Block(n_embd, n_head),
            Block(n_embd, n_head),
            Block(n_embd, n_head),
            nn.LayerNorm(n_embd)
        )
        self.lm_head = nn.Linear(n_embd, vocab_size)

    def forward(self, idx, targets=None):
        B, T = idx.shape

        tok_emb = self.token_embedding_table(idx)
        pos_emb = self.position_embedding_table(torch.arange(T, device=device))
        x = tok_emb + pos_emb
        # x = self.sa_head(x)  # NEW: apply attention
        # x = self.ffwd(x)  # NEW: apply feedforward
        x = self.blocks(x)
        logits = self.lm_head(x)

        if targets is None:
            loss = None
        else:
            B, T, C = logits.shape
            logits = logits.view(B*T, C)
            targets = targets.view(B*T)
            loss = F.cross_entropy(logits, targets)

        return logits, loss

    def generate(self, idx, max_new_tokens):
        for _ in range(max_new_tokens):
            idx_cond = idx[:, -block_size:]
            logits, loss = self(idx_cond)
            logits = logits[:, -1, :]
            probs = F.softmax(logits, dim=-1)
            idx_next = torch.multinomial(probs, num_samples=1)
            idx = torch.cat((idx, idx_next), dim=1)
        return idx


model = BigramLanguageModel(vocab_size)
m = model.to(device)

# create a PyTorch optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

for iter in range(max_iters):

    # every once in a while evaluate the loss on train and val sets
    if iter % eval_interval == 0:
        losses = estimate_loss()
        print(f"step {iter}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}")

    # sample a batch of data
    xb, yb = get_batch('train')

    # evaluate the loss
    logits, loss = model(xb, yb)
    optimizer.zero_grad(set_to_none=True)
    loss.backward()
    optimizer.step()

# generate from the model
context = torch.zeros((1, 1), dtype=torch.long, device=device)
print(decode(m.generate(context, max_new_tokens=500)[0].tolist()))

step 0: train loss 4.3758, val loss 4.3727
step 300: train loss 2.8301, val loss 2.8251
