In [36]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [4]:
# Read the names corpus: one name per line, with the line endings stripped.
# The context manager closes the file handle as soon as the read is done
# instead of leaving it open until garbage collection.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()
words[:8]  # peek at the first few entries

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [5]:
# Character vocabulary and the two lookup tables between characters and integer ids.
chars = sorted(set(''.join(words)))  # every distinct character in the corpus, in sorted order
stoi = dict(zip(chars, range(1, len(chars) + 1)))  # character -> id, ids start at 1
stoi['.'] = 0  # id 0 is reserved for the '.' marker
itos = {idx: ch for ch, idx in stoi.items()}  # inverse table: id -> character

In [7]:
# Build (context -> next character) training pairs from the first five names.

block_size = 3  # context length: how many preceding characters feed each prediction
X, Y = [], []  # X collects context windows (lists of ids), Y the id that follows each window

for name in words[:5]:
    print(name)
    window = [0] * block_size  # start every name from an all-'.' context (id 0)
    for ch in name + '.':  # the trailing '.' yields a final end-of-name target
        target = stoi[ch]
        X.append(window)
        Y.append(target)
        print(''.join(itos[i] for i in window), '--------->', itos[target])
        window = [*window[1:], target]  # slide the window: drop the oldest id, append the newest

# Convert the accumulated Python lists to tensors.
X = torch.tensor(X)
Y = torch.tensor(Y)


emma
... ---------> e
..e ---------> m
.em ---------> m
emm ---------> a
mma ---------> .
olivia
... ---------> o
..o ---------> l
.ol ---------> i
oli ---------> v
liv ---------> i
ivi ---------> a
via ---------> .
ava
... ---------> a
..a ---------> v
.av ---------> a
ava ---------> .
isabella
... ---------> i
..i ---------> s
.is ---------> a
isa ---------> b
sab ---------> e
abe ---------> l
bel ---------> l
ell ---------> a
lla ---------> .
sophia
... ---------> s
..s ---------> o
.so ---------> p
sop ---------> h
oph ---------> i
phi ---------> a
hia ---------> .


In [9]:
# Embedding lookup table: one randomly initialised 2-dimensional row per character id (27 ids).
C = torch.randn(27, 2)

In [10]:
# Indexing with an integer tensor gathers the listed rows of C; an index may repeat
# (row 5 is returned twice here).
C[torch.tensor([5, 6, 7, 5])]

tensor([[ 0.8101, -0.9739],
        [-0.4678, -0.2861],
        [ 0.3730,  0.0889],
        [ 0.8101, -0.9739]])

In [13]:
# C[X] holds one row of C per entry of X, so [13, 2] picks the row of C that was
# looked up for the id stored at X[13, 2].
C[X][13, 2]

tensor([ 1.3418, -0.1721])

In [15]:
# Embed every context window at once: indexing C with the integer tensor X
# looks up one row of C for each entry of X.
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [18]:
# Hidden layer parameters: 6 input features mapped to 100 hidden units.
W1 = torch.randn(6, 100)  # weight matrix
b1 = torch.randn(100)  # one bias per hidden unit

In [21]:
# unbind(emb, 1) splits emb into a tuple of slices along dim 1; cat(..., 1) then
# joins those slices side by side along dim 1, giving one flat row per example.
torch.cat(torch.unbind(emb, 1), 1).shape

torch.Size([32, 6])

In [24]:
# a more efficient way: view() flattens the trailing dims of emb into one row per
# example, which cat(unbind(...)) only achieves by building a new tensor.
# The row count comes from emb itself and -1 lets PyTorch infer the row width, so the
# check keeps working when the dataset (or context/embedding size) changes.
emb.view(emb.shape[0], -1) == torch.cat(torch.unbind(emb, 1), 1)

tensor([[True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
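        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True],
        [True, True, True, True, True, True]])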

In [27]:
# Hidden layer: flatten each example's embeddings into a row of 6 features
# (-1 lets PyTorch infer the number of rows), apply the affine map W1, b1
# (b1 is broadcast across rows), then squash with tanh.
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)
h

tensor([[ 0.2339, -0.0139,  0.8062,  ...,  0.9939, -0.7337,  0.2307],
        [-0.9978, -0.2165,  0.5344,  ...,  0.9731, -0.7164,  0.1805],
        [-0.5041, -0.3469,  0.9928,  ...,  0.9980, -0.9876, -0.9914],
        ...,
        [ 0.7157, -0.9600, -0.1182,  ...,  0.9982, -0.7500, -0.8613],
        [-0.9794,  0.9890,  0.4482,  ...,  0.9997,  0.4104, -0.9807],
        [-0.7778,  0.0912,  0.9986,  ...,  0.9522, -0.2538, -0.7991]])

In [28]:
# Output layer parameters: 100 hidden units mapped to 27 scores, one per character id.
W2 = torch.randn(100, 27)
b2 = torch.randn(27)

# Unnormalised scores (logits) for every example; b2 is broadcast across rows.
logits = b2 + h @ W2

In [29]:
# Row-wise softmax: turn each row of logits into a probability distribution.
# Subtracting the row maximum before exp() leaves the probabilities unchanged
# (the common factor cancels in the ratio) but keeps exp() from overflowing to inf
# when a logit is very large.
counts = (logits - logits.max(1, keepdim=True).values).exp()
prob = counts / counts.sum(1, keepdim=True)  # each row now sums to 1
prob.shape  # display the shape; the recorded cell output is this Size object

torch.Size([32, 27])

In [32]:
# Mean negative log-likelihood: for each row i pick the probability assigned to the
# correct next character Y[i], take the log, average, and negate.
# The row index is sized from Y rather than hard-coded, so the cell still works when
# the dataset is built from more than the first few names.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(15.8784)

In [33]:
# Re-create all parameters from one seeded generator so every run starts identically.
# The draw order (C, W1, b1, W2, b2) determines the values, so it must stay fixed.

g = torch.Generator()
g.manual_seed(2147483647)

C = torch.randn(27, 2, generator=g)     # embedding table
W1 = torch.randn(6, 100, generator=g)   # hidden layer weights
b1 = torch.randn(100, generator=g)      # hidden layer biases
W2 = torch.randn(100, 27, generator=g)  # output layer weights
b2 = torch.randn(27, generator=g)       # output layer biases

parameters = [C, W1, b1, W2, b2]
sum(p.nelement() for p in parameters)  # total number of scalar parameters

3481

In [39]:
# Ask autograd to track every parameter so that loss.backward() fills in p.grad.
for p in parameters:
    p.requires_grad_(True)

In [41]:
# Gradient descent on the whole dataset: 10 steps with a fixed step size of 0.1.
for _ in range(10):
    # forward pass
    emb = C[X]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)  # (32, 100)
    logits = h @ W2 + b2  # (32, 27)
    # F.cross_entropy replaces the manual softmax + negative log-likelihood
    # (exp -> normalise -> pick the target probability -> log -> mean).
    # It is simpler & more efficient, and more numerically well-behaved with
    # very large logits, where a bare exp() can go to inf.
    loss = F.cross_entropy(logits, Y)  # type: ignore
    print(loss)

    # backward pass
    for p in parameters:
        p.grad = None  # drop the previous step's gradients before accumulating new ones
    loss.backward()

    # update: adjust the parameters in place with autograd recording disabled,
    # rather than writing through .data, which bypasses autograd's in-place checks
    with torch.no_grad():
        for p in parameters:
            p += -0.1 * p.grad  # type: ignore

tensor(13.7737, grad_fn=<NllLossBackward0>)
tensor(11.2711, grad_fn=<NllLossBackward0>)
tensor(9.4690, grad_fn=<NllLossBackward0>)
tensor(8.0027, grad_fn=<NllLossBackward0>)
tensor(6.8788, grad_fn=<NllLossBackward0>)
tensor(6.0633, grad_fn=<NllLossBackward0>)
tensor(5.4135, grad_fn=<NllLossBackward0>)
tensor(4.8626, grad_fn=<NllLossBackward0>)
tensor(4.3832, grad_fn=<NllLossBackward0>)
tensor(3.9575, grad_fn=<NllLossBackward0>)
