In [1]:
import torch
import random
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Read the training corpus: one name per line in names.txt.
# The context manager closes the file handle as soon as the read finishes,
# instead of leaving the open file to be cleaned up by garbage collection.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
len(words)

32033

Vocab

In [4]:
# Character-level vocabulary built from every distinct character in `words`.
# Letters get indices 1..N in sorted order; index 0 is assigned to '.',
# the token used for padding the context and for marking the end of a word.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0

# Inverse mapping (index -> character), used when decoding samples.
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(itos)
print(itos)
print(vocab_size)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
27


Dataset

In [5]:
block_size = 3  # context length: number of preceding characters fed to the model


def build_dataset(words):
    """Turn a list of words into (context, next-character) training pairs.

    Every word is terminated with '.', and a window of `block_size` character
    indices is slid across it, starting from an all-zero window. Each position
    yields one example: the current window as input and the index of the
    character that follows as target. Characters are mapped through the
    module-level `stoi` table.

    Prints the shapes of the two tensors and returns them as `(X, Y)`, where
    X has one row of `block_size` indices per example and Y holds the targets.
    """
    inputs, targets = [], []
    for word in words:
        window = [0] * block_size
        for ch in word + '.':
            target = stoi[ch]
            inputs.append(window)
            targets.append(target)
            # drop the oldest index, append the newest (builds a new list each time)
            window = [*window[1:], target]

    X = torch.tensor(inputs)
    Y = torch.tensor(targets)
    print(X.shape, Y.shape)
    return X, Y


# Shuffle with a fixed seed, then split 80% / 10% / 10% into train / dev / test.
# NOTE: random.shuffle mutates `words` in place, so re-running this cell on its own
# reshuffles the already-shuffled list and produces a different split.
random.seed(42)
random.shuffle(words)
n1, n2 = (int(frac * len(words)) for frac in (0.8, 0.9))
X_train, Y_train = build_dataset(words[:n1])
X_dev, Y_dev = build_dataset(words[n1:n2])
X_test, Y_test = build_dataset(words[n2:])

torch.Size([182625, 3]) torch.Size([182625])
torch.Size([22655, 3]) torch.Size([22655])
torch.Size([22866, 3]) torch.Size([22866])


MLP params

In [6]:
# Model hyperparameters.
n_embed = 10 # dimensionality of each character embedding vector
n_hidden = 200 # number of neurons in the hidden layer

# Seeded generator: every randn call below draws from it, so the order of the
# five initialisations determines the resulting values.
g = torch.Generator().manual_seed(42)

# Embedding table: one n_embed-dimensional row per vocabulary entry.
C = torch.randn((vocab_size, n_embed), generator=g)
# Hidden layer weights, scaled by (5/3)/sqrt(fan_in) with fan_in = n_embed*block_size
# (5/3 is the gain torch.nn.init.calculate_gain reports for tanh).
W1 = torch.randn((n_embed*block_size, n_hidden), generator=g) * (5/3)/(n_embed*block_size)**0.5
# Small hidden bias.
b1 = torch.randn(n_hidden, generator=g) * 0.01
# Output layer weights start small so the initial logits are close to zero.
W2 = torch.randn((n_hidden, vocab_size), generator=g) * 0.01
# Output bias is zeroed; the randn call still consumes values from `g`.
b2 = torch.randn(vocab_size, generator=g) * 0

parameters = [C, W1, b1, W2, b2]
print(sum(p.nelement() for p in parameters)) # total number of scalar parameters
# Enable gradient tracking on every parameter tensor.
for p in parameters:
    p.requires_grad = True

11897


Optimization

Eval

In [7]:
@torch.no_grad()  # evaluation only: no autograd graph is recorded
def split_loss(split):
    """Print the cross-entropy loss of the current parameters on one data split.

    `split` must be 'train', 'val' or 'test'; any other key raises KeyError.
    The module-level parameters C, W1, b1, W2, b2 are read at call time, so the
    function always reflects their current values. Nothing is returned; the
    result is printed as "<split> <loss>".
    """
    datasets = {
        'train': (X_train, Y_train),
        'val': (X_dev, Y_dev),
        'test': (X_test, Y_test)
    }
    inputs, targets = datasets[split]

    n_examples = inputs.shape[0]
    # look up embeddings (N, block_size, n_embed) and flatten to (N, block_size*n_embed)
    flat = C[inputs].view(n_examples, -1)
    hidden = torch.tanh(flat @ W1 + b1)
    loss = F.cross_entropy(hidden @ W2 + b2, targets)
    print(split, loss.item())

# Baseline: this cell runs before the training loop, so these are the losses
# of the freshly initialised parameters.
split_loss('train')
split_loss('val')
    

train 3.2980921268463135
val 3.2999227046966553


Sample from the model

In [8]:
# Fresh seeded generator so the samples below are reproducible.
# NOTE: the training loop further down keeps drawing from this same `g`, so the
# number of values consumed here also determines which minibatches training sees.
g = torch.Generator().manual_seed(42)

# Sampling is pure inference. The parameters all have requires_grad=True, so without
# no_grad every forward pass here would record an autograd graph that is never used.
with torch.no_grad():
    for _ in range(20):
        out = []
        context = [0]*block_size  # start from the all-'.' context
        while True:
            # forward pass for a single context: (1, block_size, n_embed) -> (1, vocab_size)
            emb = C[torch.tensor([context])]
            h = torch.tanh(emb.view(1, -1) @ W1 + b1)
            logits = h @ W2 + b2
            probs = F.softmax(logits, dim=1)

            # sample the next character index from the predicted distribution
            ix = torch.multinomial(probs, num_samples=1, generator=g).item()

            # shift the context window
            context = context[1:] + [ix]
            out.append(ix)

            # index 0 is '.', the end-of-word token; it is kept in `out` and printed
            if ix == 0:
                break

        print(''.join(itos[i] for i in out))


xjuguenvtps.
fabiquedxfmubnwmsflaypglzofmwhwlxoln.
epjccuodsgjdmzu.
knxcmjjobdrggbdlpk.
mnqhqyjfbscvghigeaczalcvjwzajwtphjpdmquotcc.
weltxosvgkohobr.
uklnncvrigmydlsoumf.
pjjiewx.
lxmjuhm.
fsckbirdovhgn.
kgoktfkzuacabxa.
atodr.
bxwqzjzdqvtmdampemaqj.
omtafjiirvqtlfkyeumxuoxtame.
ovzqmywog.
acdtqumkorvdyxxhlsogob.
tnslwkgmnfuyccqendhln.
quehejojixfdirndbgcpvrsczagrtpltqc.
jsnq.
dazxygkihhnynvyfjfzgxlvkqncqgahwkig.


In [9]:
# Minibatch SGD training of the MLP.
# Uses the generator `g` left over from the sampling cell for minibatch selection.
# The variables assigned inside the loop (h_pre_act, h, logits, loss, ...) stay in the
# notebook namespace and hold the values from the final step; later cells read them.
max_steps = 200000
batch_size = 32
lossi = []  # log10 of the minibatch loss at every step

for i in range(max_steps):

    #minibatch: draw batch_size row indices (with replacement) from the training set
    ix = torch.randint(0, X_train.shape[0], (batch_size, ), generator=g)
    Xb, Yb = X_train[ix], Y_train[ix] #batch X, Y

    #forward pass
    emb = C[Xb] #embed the characters into vectors
    embcat = emb.view(emb.shape[0], -1) #concatenate the per-position embeddings
    h_pre_act = embcat @ W1 + b1 #hidden layer pre-activation
    h = torch.tanh(h_pre_act)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Yb)

    #backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    #learning rate
    lr = 0.1 if i < 100000 else 0.01 # step learning rate decay

    #update: plain SGD step. Done in place under no_grad rather than through
    #`p.data`, which would hide the modification from autograd's safety checks.
    with torch.no_grad():
        for p in parameters:
            p -= lr * p.grad

    #track stats
    if i % 10000 == 0:
        print(f'{i:7d}/{max_steps:7d}: {loss.item():.4f}')
    lossi.append(loss.log10().item())

      0/ 200000: 3.2844
  10000/ 200000: 2.0816
  20000/ 200000: 2.3091
  30000/ 200000: 1.9938
  40000/ 200000: 2.2890
  50000/ 200000: 2.2568
  60000/ 200000: 2.3548
  70000/ 200000: 2.4330
  80000/ 200000: 2.0860
  90000/ 200000: 2.1756
 100000/ 200000: 2.5186
 110000/ 200000: 1.8009
 120000/ 200000: 1.8870
 130000/ 200000: 2.0851
 140000/ 200000: 1.8776
 150000/ 200000: 2.4181
 160000/ 200000: 1.8908
 170000/ 200000: 1.8506
 180000/ 200000: 1.8193
 190000/ 200000: 2.1520


In [10]:
# Re-evaluate after training. `split_loss` is already defined in the earlier "Eval"
# cell and reads C, W1, b1, W2, b2 at call time, so it sees the trained weights
# without being redefined here.
split_loss('train')
split_loss('val')
    

train 2.039062023162842
val 2.1046345233917236


Batch Normalization

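- Standardize the hidden-layer pre-activations so that, across each batch, every unit is roughly unit Gaussian (zero mean, unit standard deviation)
- Introduced by Ioffe & Szegedy, 2015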

In [11]:
# Per-neuron mean of the hidden pre-activations over the last training minibatch, shape (1, n_hidden).
h_pre_act.mean(dim=0, keepdim=True)

tensor([[ 0.2450,  0.8468, -0.3370, -0.6128,  0.2819, -1.4102, -0.6329, -0.1932,
         -0.8002, -0.5700,  0.4880, -1.0280, -0.6705,  0.2041,  0.0656, -1.3523,
         -0.0108,  0.2196, -0.3964, -0.0611, -0.1599,  0.4347, -0.3384,  0.6716,
         -0.5845, -0.1421,  0.4394, -0.9133,  0.7759, -0.2814, -0.6424,  0.8679,
          0.1787,  0.1866, -0.4940,  0.6676,  0.5536, -1.0461, -0.0434, -0.0142,
          1.4273,  0.0497,  1.6385, -1.1905,  0.4529,  0.3203,  0.2545,  1.7487,
          0.7042, -0.3972, -0.4655,  0.3131,  0.1845,  0.3747,  0.4394,  0.2772,
         -0.1256,  0.5866,  0.0908,  0.8287,  0.8183, -0.0711, -1.6706,  0.1434,
         -0.6621,  0.2321,  1.1720, -0.0776,  0.7844, -0.3088,  0.2917,  0.0561,
          1.1711, -1.1904,  1.6098,  0.7403, -1.3572, -0.6413,  0.0484, -0.6799,
          0.3241,  1.1963,  0.1833,  0.5028,  0.9815, -0.4083,  0.4402, -1.1991,
         -1.0565,  0.2991, -0.2325,  0.1594,  0.3470,  0.3566, -0.8424, -0.3876,
          1.8976,  0.3698,  