In [None]:
# https://arxiv.org/abs/2410.01201
# "Were RNNs All We Needed?"

import os
import urllib.request

import tiktoken
import torch
from tqdm import tqdm

# GPT-2 BPE tokenizer; used by the next cell to derive the vocabulary.
enc = tiktoken.get_encoding("gpt2")

url = 'https://raw.githubusercontent.com/karpathy/char-rnn/refs/heads/master/data/tinyshakespeare/input.txt'
filename = 'shakespeare.txt'

# Download only when the local copy is missing, so re-running the cell
# does not hit the network again.
if not os.path.exists(filename):
    urllib.request.urlretrieve(url, filename)

# Read through a context manager so the file handle is always closed, and
# pin the encoding so the result does not depend on the platform locale.
# splitlines() drops the newline characters: `docs` holds one string per line.
with open(filename, "r", encoding="utf-8") as f:
    docs = f.read().splitlines()
len(docs)

40000

In [2]:
# Distinct GPT-2 token ids that occur when the whole corpus (lines joined by
# single spaces) is encoded.
tokens = set(enc.encode(" ".join(docs)))
tokens_ids = list(map(int, tokens))

# Map each id's decoded text to the id itself. If two ids decode to the same
# text, the one visited later wins.
# NOTE(review): the stored values are the tokenizer's own ids, not a
# renumbered 0..vocab_size-1 range, while vocab_size below is only the number
# of entries -- confirm every id that is actually looked up stays below
# vocab_size before using it to index a (vocab_size, ...) table.
Vocab = {enc.decode([token_id]): token_id for token_id in tokens_ids}

vocab_size = len(Vocab)
print(vocab_size)


11387


In [3]:
block_size = 64 # context length: how many tokens do we take to predict the next one?

def build_dataset(docs, context_length=None, vocab=None):
    """Turn documents into (context, next-symbol) training pairs.

    Each document is processed independently: its context window starts as
    `context_length` zeros and slides forward by one position per emitted
    example, so a context never spans two documents.

    Args:
        docs: iterable of documents. Every element obtained by iterating a
            document is looked up in `vocab`; for a `str` document these
            elements are single characters, so only one-character vocabulary
            entries can match.
        context_length: number of ids per context row. Defaults to the
            module-level `block_size`.
        vocab: mapping from symbol to integer id. Defaults to the
            module-level `Vocab`.

    Returns:
        (X, Y): `X` is an int64 tensor of shape (N, context_length) holding
        the contexts and `Y` is an int64 tensor of shape (N,) holding the id
        that follows each context. N is 0 when nothing matched.

    Raises:
        ValueError: if `context_length` is smaller than 1.
    """
    if context_length is None:
        context_length = block_size
    if vocab is None:
        vocab = Vocab
    if context_length < 1:
        raise ValueError("context_length must be at least 1")

    X, Y = [], []

    for doc in docs:
        context = [0] * context_length  # id 0 doubles as left padding
        for token in doc:
            ix = vocab.get(token)
            if ix is None:
                continue  # skip token if it's not in the vocabulary
            X.append(context)
            Y.append(ix)
            context = context[1:] + [ix]  # crop and append

    # The explicit reshape keeps X two-dimensional even when no example was
    # produced; for non-empty input it is a no-op.
    X = torch.tensor(X, dtype=torch.long).reshape(len(X), context_length)
    Y = torch.tensor(Y, dtype=torch.long)
    print(f"X shape: {X.shape}, Y shape: {Y.shape}")
    return X, Y

# Hold out the last 10% of the lines for evaluation; the split is positional
# (no shuffling).
n = int(0.9 * len(docs))
train_docs, test_docs = docs[:n], docs[n:]
Xtr, Ytr = build_dataset(train_docs)
Xte, Yte = build_dataset(test_docs)

X shape: torch.Size([978353, 64]), Y shape: torch.Size([978353])
X shape: torch.Size([95053, 64]), Y shape: torch.Size([95053])


In [82]:
# Hyperparameters
n_embd, hidden_size = 10, 256
block_size, batch_size = 64, 32

g = torch.Generator().manual_seed(983245)  # seeded generator for reproducibility

# Width of one flattened context: block_size embeddings of n_embd values each.
fan_in1 = block_size * n_embd

# The draws from `g` below happen in a fixed order; reordering them would
# change every initial value.

# Embedding table: one n_embd-dimensional row per vocabulary entry.
C = torch.randn(vocab_size, n_embd, generator=g)

# First minGRU layer: gate (Zt1_*) and candidate (Ht1_*) projections, weights
# scaled by 1/sqrt(fan_in). Biases are drawn from `g` and then multiplied by
# zero, which still advances the generator.
Zt1_W = torch.randn(fan_in1, hidden_size, generator=g) / fan_in1 ** 0.5
Zt1_B = torch.randn(hidden_size, generator=g) * 0
Ht1_W = torch.randn(fan_in1, hidden_size, generator=g) / fan_in1 ** 0.5
Ht1_B = torch.randn(hidden_size, generator=g) * 0
h_prev = torch.zeros(batch_size, hidden_size)  # previous hidden state, all zeros

# Second minGRU layer: hidden -> hidden gate and candidate projections.
Zt2_W = torch.randn(hidden_size, hidden_size, generator=g) / hidden_size ** 0.5  # input-to-gate
Zt2_B = torch.randn(hidden_size, generator=g) * 0
Ht2_W = torch.randn(hidden_size, hidden_size, generator=g) / hidden_size ** 0.5  # input-to-candidate
Ht2_B = torch.randn(hidden_size, generator=g) * 0

# Output projection to vocabulary logits; small initial weights.
W = torch.randn(hidden_size, vocab_size, generator=g) * 0.01
B = torch.randn(vocab_size) * 0  # drawn from the global RNG rather than `g`, then zeroed

# BatchNorm scale/shift (trained) and running statistics (not trained)
bngain = torch.ones(1, hidden_size)
bnbias = torch.zeros(1, hidden_size)
bnmean_running = torch.zeros(1, hidden_size)
bnstd_running = torch.ones(1, hidden_size)

parameters = [
    C,
    Zt1_W, Zt1_B, Ht1_W, Ht1_B,
    Zt2_W, Zt2_B, Ht2_W, Ht2_B,
    W, B,
    bngain, bnbias,
]
print(sum(p.numel() for p in parameters))  # number of parameters in total
for p in parameters:
    p.requires_grad_(True)




3500617


In [84]:
max_steps = 50000
lossi = []  # training loss of every step, in order

loss_fn = torch.nn.CrossEntropyLoss()
# Shapes inside the loop:
#   Xb:     (batch_size, block_size) ids, Yb: (batch_size,) target ids
#   embcat: (batch_size, block_size * n_embd)
#   h_prev: (batch_size, hidden_size)
for step in tqdm(range(max_steps), desc='Training Progress', unit=' steps'):
    # Sample a minibatch of row indices with the seeded generator.
    minibatch = torch.randint(0, Xtr.shape[0], (batch_size,), generator=g)
    Xb, Yb = Xtr[minibatch], Ytr[minibatch]

    # Embed every context position and flatten the positions into one vector.
    emb = C[Xb]
    embcat = emb.view(emb.shape[0], -1)

    # First layer: gate pre-activation and candidate state.
    Ztpreact = embcat @ Zt1_W + Zt1_B
    Htilde = embcat @ Ht1_W + Ht1_B

    # Normalise the gate pre-activation with the statistics of this minibatch.
    # NOTE(review): there is no epsilon in the denominator, so a column with
    # zero standard deviation would divide by zero.
    bnmeani = Ztpreact.mean(0, keepdim=True)
    bnstdi = Ztpreact.std(0, keepdim=True)
    Ztpreact = bngain * (Ztpreact - bnmeani) / bnstdi + bnbias
    with torch.no_grad():
        # Exponential moving averages of the batch statistics, kept out of autograd.
        bnmean_running = 0.999 * bnmean_running + 0.001 * bnmeani
        bnstd_running = 0.999 * bnstd_running + 0.001 * bnstdi
    Zt = torch.sigmoid(Ztpreact)

    # NOTE(review): h_prev is not reassigned anywhere in this loop, so the
    # same tensor is blended in at every step.
    H1_t = (1 - Zt) * h_prev + Zt * Htilde  # first hidden state

    # Second layer: gate and candidate computed from H1_t, blended with H1_t.
    Zt2 = torch.sigmoid(H1_t @ Zt2_W + Zt2_B)
    Htilde2 = H1_t @ Ht2_W + Ht2_B
    H2_t = (1 - Zt2) * H1_t + Zt2 * Htilde2

    # Logits are already (batch_size, vocab_size); no reshape is needed.
    outputs = H2_t @ W + B
    loss = loss_fn(outputs, Yb)

    # Backward pass: clear stale gradients first.
    for p in parameters:
        p.grad = None
    loss.backward()

    # Plain SGD with a single step-wise learning-rate drop. The update runs
    # under no_grad so autograd does not track the in-place modification.
    lr = 0.1 if step < 30000 else 0.01
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    # Record the loss so the history in `lossi` is actually populated.
    lossi.append(loss.item())

    if step % 1000 == 0:
        tqdm.write(f'Step {step}, Loss: {loss.item()}')


Training Progress:   0%|          | 14/50000 [00:00<12:19, 67.55 steps/s]

Step 0, Loss: 2.1301419734954834


Training Progress:   2%|▏         | 1011/50000 [00:16<12:44, 64.11 steps/s]

Step 1000, Loss: 1.7428120374679565


Training Progress:   4%|▍         | 2009/50000 [00:29<11:02, 72.46 steps/s]

Step 2000, Loss: 1.4959784746170044


Training Progress:   6%|▌         | 3008/50000 [00:43<10:30, 74.56 steps/s]

Step 3000, Loss: 1.960930585861206


Training Progress:   8%|▊         | 4008/50000 [00:56<10:25, 73.58 steps/s]

Step 4000, Loss: 2.287388324737549


Training Progress:  10%|█         | 5017/50000 [01:10<09:47, 76.57 steps/s]

Step 5000, Loss: 2.3840837478637695


Training Progress:  12%|█▏        | 6009/50000 [01:24<09:18, 78.77 steps/s]

Step 6000, Loss: 2.0343244075775146


Training Progress:  14%|█▍        | 7014/50000 [01:38<10:11, 70.35 steps/s]

Step 7000, Loss: 2.069871187210083


Training Progress:  16%|█▌        | 8005/50000 [01:53<14:54, 46.97 steps/s]

Step 8000, Loss: 1.3308762311935425


Training Progress:  18%|█▊        | 9005/50000 [02:06<09:34, 71.32 steps/s]

Step 9000, Loss: 1.7264564037322998


Training Progress:  20%|██        | 10008/50000 [02:19<09:29, 70.20 steps/s]

Step 10000, Loss: 2.6870265007019043


Training Progress:  22%|██▏       | 11008/50000 [02:36<11:11, 58.03 steps/s]

Step 11000, Loss: 1.5227341651916504


Training Progress:  24%|██▍       | 12014/50000 [02:51<08:13, 77.01 steps/s]

Step 12000, Loss: 1.8487526178359985


Training Progress:  26%|██▌       | 13008/50000 [03:05<08:08, 75.68 steps/s]

Step 13000, Loss: 2.2663402557373047


Training Progress:  28%|██▊       | 14011/50000 [03:21<09:56, 60.33 steps/s]

Step 14000, Loss: 1.3932157754898071


Training Progress:  30%|███       | 15015/50000 [03:36<07:42, 75.63 steps/s]

Step 15000, Loss: 1.6852589845657349


Training Progress:  32%|███▏      | 16011/50000 [03:49<07:37, 74.25 steps/s]

Step 16000, Loss: 1.726888656616211


Training Progress:  34%|███▍      | 17012/50000 [04:05<07:46, 70.67 steps/s]

Step 17000, Loss: 1.9028089046478271


Training Progress:  35%|███▌      | 17527/50000 [04:13<07:50, 69.04 steps/s]


KeyboardInterrupt: 