In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Read in all the words (one name per line). The `with` block guarantees the
# file handle is closed as soon as reading finishes, even on error.
with open('../makemore/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
# Number of names loaded from the dataset file.
len(words)

32033

In [7]:
# Build the character vocabulary. Each distinct character in the corpus gets an
# index 1..len(chars) in sorted order; index 0 is reserved for '.', the
# start/end-of-name token. itos is the inverse mapping.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [8]:
# Build the dataset: for every character of each word (plus the terminating
# '.'), the input is the indices of the previous `block_size` characters and
# the target is the index of the current character. Only the first 5 words are
# used here, and every (context ---> target) pair is printed for inspection.

block_size = 3  # number of preceding characters used to predict the next one
X, Y = [], []
for word in words[:5]:
    print(word)
    window = [0] * block_size  # start fully padded with '.' (index 0)
    for target in (stoi[c] for c in word + '.'):
        X.append(window)
        Y.append(target)
        print(''.join(itos[j] for j in window), '--->', itos[target])
        # slide the window: drop the oldest index, append the newest
        window = window[1:] + [target]

X = torch.tensor(X)
Y = torch.tensor(Y)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [9]:
# Sanity-check the shapes and dtypes of the inputs and the targets.
(X.shape, X.dtype, Y.shape, Y.dtype)

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [10]:
# Embeddings: a lookup table with one 2-D vector for each of the 27 characters.
# NOTE: unseeded, so the values differ from run to run.
C = torch.randn(27, 2)

In [13]:
# Row 5 of the table: the embedding of the character with index 5.
C[5, :]

tensor([ 0.1986, -0.5341])

In [16]:
# Same lookup written as a matrix product: a one-hot row vector for index 5
# times C selects exactly row 5 of C.
F.one_hot(torch.tensor(5), 27).to(torch.float32) @ C

tensor([ 0.1986, -0.5341])

In [18]:
# Index with several integers at once: returns rows 5, 6 and 7 of C.
C[[5, 6, 7]]

tensor([[ 0.1986, -0.5341],
        [ 0.3257,  1.1056],
        [ 0.3570, -1.3659]])

In [19]:
# Index C with the whole (examples, block_size) integer matrix: every index is
# replaced by its embedding row.
C[X, :]

tensor([[[ 2.0233,  0.1450],
         [ 2.0233,  0.1450],
         [ 2.0233,  0.1450]],

        [[ 2.0233,  0.1450],
         [ 2.0233,  0.1450],
         [ 0.1986, -0.5341]],

        [[ 2.0233,  0.1450],
         [ 0.1986, -0.5341],
         [-0.0475, -2.4666]],

        [[ 0.1986, -0.5341],
         [-0.0475, -2.4666],
         [-0.0475, -2.4666]],

        [[-0.0475, -2.4666],
         [-0.0475, -2.4666],
         [-0.3485,  1.3389]],

        [[ 2.0233,  0.1450],
         [ 2.0233,  0.1450],
         [ 2.0233,  0.1450]],

        [[ 2.0233,  0.1450],
         [ 2.0233,  0.1450],
         [-0.4482,  0.1619]],

        [[ 2.0233,  0.1450],
         [-0.4482,  0.1619],
         [ 1.4469,  0.1367]],

        [[-0.4482,  0.1619],
         [ 1.4469,  0.1367],
         [-2.4094,  0.4257]],

        [[ 1.4469,  0.1367],
         [-2.4094,  0.4257],
         [-0.6222,  1.0113]],

        [[-2.4094,  0.4257],
         [-0.6222,  1.0113],
         [-2.4094,  0.4257]],

        [[-0.6222,  1

In [20]:
# Shape of the embedded batch: X's shape with the embedding size appended.
C[X].size()

torch.Size([32, 3, 2])

In [31]:
# Keep the embedded contexts around for the hidden layer.
emb = C[X]
emb.size()

torch.Size([32, 3, 2])

In [32]:
# Hidden layer: 6 inputs (3 context characters x 2-D embeddings) -> 100 units.
W1 = torch.randn(6, 100)
b1 = torch.randn(100)

In [33]:
# Flatten each context by concatenating its three per-position embeddings side
# by side along dim 1.
torch.cat([emb[:, pos, :] for pos in range(3)], dim=1)

tensor([[ 2.0233,  0.1450,  2.0233,  0.1450,  2.0233,  0.1450],
        [ 2.0233,  0.1450,  2.0233,  0.1450,  0.1986, -0.5341],
        [ 2.0233,  0.1450,  0.1986, -0.5341, -0.0475, -2.4666],
        [ 0.1986, -0.5341, -0.0475, -2.4666, -0.0475, -2.4666],
        [-0.0475, -2.4666, -0.0475, -2.4666, -0.3485,  1.3389],
        [ 2.0233,  0.1450,  2.0233,  0.1450,  2.0233,  0.1450],
        [ 2.0233,  0.1450,  2.0233,  0.1450, -0.4482,  0.1619],
        [ 2.0233,  0.1450, -0.4482,  0.1619,  1.4469,  0.1367],
        [-0.4482,  0.1619,  1.4469,  0.1367, -2.4094,  0.4257],
        [ 1.4469,  0.1367, -2.4094,  0.4257, -0.6222,  1.0113],
        [-2.4094,  0.4257, -0.6222,  1.0113, -2.4094,  0.4257],
        [-0.6222,  1.0113, -2.4094,  0.4257, -0.3485,  1.3389],
        [ 2.0233,  0.1450,  2.0233,  0.1450,  2.0233,  0.1450],
        [ 2.0233,  0.1450,  2.0233,  0.1450, -0.3485,  1.3389],
        [ 2.0233,  0.1450, -0.3485,  1.3389, -0.6222,  1.0113],
        [-0.3485,  1.3389, -0.6222,  1.0

In [35]:
# Same concatenation without hard-coding the number of positions: unbind splits
# emb along dim 1 into one tensor per position.
torch.cat(emb.unbind(1), dim=1).shape

torch.Size([32, 6])

In [37]:
# A small tensor used to illustrate how view() reshapes data.
a = torch.arange(0, 18)
a

tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17])

In [38]:
# One-dimensional, 18 elements.
a.size()

torch.Size([18])

In [40]:
# Reinterpret the same 18 elements as a 3 x 3 x 2 tensor.
a.view((3, 3, 2))

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]]])

In [41]:
# Show the flat 1-D storage behind `a`: the 18 elements in order, regardless of
# the shape a view() presents them in.
# NOTE(review): newer PyTorch releases deprecate TypedStorage / Tensor.storage()
# in favour of Tensor.untyped_storage(); confirm against the installed version.
a.storage()

 0
 1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
[torch.storage._TypedStorage(dtype=torch.int64, device=cpu) of size 18]

In [42]:
# Flatten each example's (block_size, embedding_dim) block into a single row
# without copying. The row count and the feature count are derived from emb
# itself, so the cell keeps working if the number of examples changes.
emb.view(emb.shape[0], -1)

tensor([[ 2.0233,  0.1450,  2.0233,  0.1450,  2.0233,  0.1450],
        [ 2.0233,  0.1450,  2.0233,  0.1450,  0.1986, -0.5341],
        [ 2.0233,  0.1450,  0.1986, -0.5341, -0.0475, -2.4666],
        [ 0.1986, -0.5341, -0.0475, -2.4666, -0.0475, -2.4666],
        [-0.0475, -2.4666, -0.0475, -2.4666, -0.3485,  1.3389],
        [ 2.0233,  0.1450,  2.0233,  0.1450,  2.0233,  0.1450],
        [ 2.0233,  0.1450,  2.0233,  0.1450, -0.4482,  0.1619],
        [ 2.0233,  0.1450, -0.4482,  0.1619,  1.4469,  0.1367],
        [-0.4482,  0.1619,  1.4469,  0.1367, -2.4094,  0.4257],
        [ 1.4469,  0.1367, -2.4094,  0.4257, -0.6222,  1.0113],
        [-2.4094,  0.4257, -0.6222,  1.0113, -2.4094,  0.4257],
        [-0.6222,  1.0113, -2.4094,  0.4257, -0.3485,  1.3389],
        [ 2.0233,  0.1450,  2.0233,  0.1450,  2.0233,  0.1450],
        [ 2.0233,  0.1450,  2.0233,  0.1450, -0.3485,  1.3389],
        [ 2.0233,  0.1450, -0.3485,  1.3389, -0.6222,  1.0113],
        [-0.3485,  1.3389, -0.6222,  1.0

In [48]:
# Hidden layer: flatten each context's embeddings into one row (the feature
# count, block_size * embedding_dim, is inferred rather than hard-coded as 6),
# then apply the affine map and the tanh nonlinearity.
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)

In [49]:
# Hidden-layer activations: one row per example, values in [-1, 1] from tanh.
h

tensor([[ 0.8197, -0.0324,  0.9959,  ..., -0.9977, -1.0000,  0.9990],
        [ 0.3140, -0.9059,  0.9951,  ..., -0.9990, -0.9996,  0.6589],
        [-0.9043, -0.5794,  0.3032,  ..., -0.9992, -1.0000,  0.5536],
        ...,
        [-0.7798,  0.7497,  0.5406,  ..., -0.6555,  0.9942, -0.9225],
        [-0.5983,  0.9279,  0.9909,  ..., -0.0677,  1.0000, -1.0000],
        [-0.5124,  1.0000,  0.3753,  ...,  0.9952,  0.9982, -0.9981]])

In [50]:
# Output layer: 100 hidden units -> 27 logits, one per character.
W2 = torch.randn(100, 27)
b2 = torch.randn(27)

In [51]:
# Unnormalized scores for the next character: one row of 27 logits per example.
logits = h.matmul(W2).add(b2)

In [52]:
# Raw logits; these can be large in magnitude with untrained random weights.
logits

tensor([[-2.6883e+00, -5.8920e+00, -7.1630e+00,  5.2270e+00,  7.3209e+00,
          9.6966e-02,  2.0839e+00,  2.9406e-01,  1.1726e+01, -4.2984e+00,
         -4.6772e-01, -1.2581e+01, -2.8558e-01,  5.3770e+00,  7.7524e+00,
         -8.2396e+00,  1.3375e+00,  5.4692e+00,  6.2233e+00, -5.8089e+00,
         -3.6569e-01,  1.4966e+01,  1.3558e+01,  8.7811e+00,  7.6344e+00,
          3.9933e+00, -6.8445e+00],
        [-7.1144e+00,  7.6414e+00, -6.5093e+00, -5.9955e-01, -1.4815e-01,
          4.4421e+00,  7.1776e-01,  5.9655e+00,  5.8935e-01, -5.0330e+00,
          6.0961e+00, -2.2013e+00,  8.7650e+00, -1.9746e+00,  1.9109e+00,
         -5.5769e+00, -2.9610e+00,  1.7672e+00, -6.4793e+00, -9.3769e-01,
         -6.9434e+00,  4.5413e+00,  1.3250e+01,  6.0554e+00,  1.0570e+01,
         -4.1057e+00, -3.0274e+00],
        [-1.8583e+00,  2.1015e+00,  1.2863e+01, -5.5000e-01,  4.9604e+00,
          2.0600e+01,  3.4340e+00,  9.9069e+00,  4.2384e+00, -1.1464e+01,
          3.7062e+00,  8.9832e+00, -5.64

In [53]:
# Exponentiate the logits into unnormalized "counts". Each row's max is
# subtracted first. That constant factor cancels when the rows are normalized
# into probabilities, but it stops exp() from overflowing to inf on large logits.
counts = (logits - logits.max(dim=1, keepdim=True).values).exp()

In [55]:
# Normalize each row so it sums to 1. keepdim leaves a (rows, 1) column, which
# broadcasts across the 27 entries of each row.
prob = counts.div(counts.sum(dim=1, keepdim=True))

In [56]:
# One probability distribution over 27 characters per example.
prob.size()

torch.Size([32, 27])

In [60]:
# Negative log-likelihood: pick each example's probability of its correct next
# character, take the log, and average. Rows are indexed with
# torch.arange(Y.shape[0]) rather than a hard-coded 32, so the cell works for
# any number of examples.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()

In [61]:
# Mean negative log-likelihood over the batch (lower is better).
loss

tensor(15.1882)

In [71]:
# Re-create every parameter from a fixed-seed generator so the results are
# reproducible. The draw order matters because all tensors consume the same
# random stream: C, W1, b1, W2, b2, in that order.
g = torch.Generator().manual_seed(2147483647)  # for reproducibility
C, W1, b1, W2, b2 = parameters = [
    torch.randn(shape, generator=g)
    for shape in ((27, 2), (6, 100), (100,), (100, 27), (27,))
]

In [72]:
# Total number of scalar parameters across all tensors.
sum(map(torch.Tensor.numel, parameters))

3481

In [75]:
# Full forward pass over the batch with the seeded parameters.
emb = C[X]  # (32, 3, 2): one 2-D embedding per context character
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (32, 100)
logits = h @ W2 + b2  # (32, 27)
# cross_entropy works directly from the logits and yields the same mean
# negative log-likelihood as the manual exp / normalize / -log().mean() steps.
loss = F.cross_entropy(logits, Y)
loss

tensor(17.7697)

In [74]:
# Recompute the loss straight from the logits and the integer class targets.
F.cross_entropy(input=logits, target=Y)

tensor(17.7697)