In [47]:
import torch
import torch.nn.functional as F
import matplotlib as plt
%matplotlib inline

In [2]:
words = open('names.txt','r').read().splitlines()
words[:8]

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia']

In [3]:
len(words)

32033

In [4]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [5]:
# build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], []
for w in words[:5]:
    
    print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
        print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix] # crop and append

X = torch.tensor(X)
Y = torch.tensor(Y)

emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [6]:
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [7]:
X[:6,]

tensor([[ 0,  0,  0],
        [ 0,  0,  5],
        [ 0,  5, 13],
        [ 5, 13, 13],
        [13, 13,  1],
        [ 0,  0,  0]])

In [8]:
Y[:6]

tensor([ 5, 13, 13,  1,  0, 15])

In [9]:
g = torch.Generator().manual_seed(2147483647)
C = torch.randn((27, 2), generator=g)

In [10]:
C[:6,]

tensor([[ 1.5674, -0.2373],
        [-0.0274, -1.1008],
        [ 0.2859, -0.0296],
        [-1.5471,  0.6049],
        [ 0.0791,  0.9046],
        [-0.4713,  0.7868]])

In [11]:
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [12]:
W1 = torch.randn((6,100), generator=g)
b1 = torch.randn(100, generator=g)

In [13]:
# we would like to do something like this
emb @ W1 + b1
# but emb is the wrong shape (3D)

RuntimeError: mat1 and mat2 shapes cannot be multiplied (96x2 and 6x100)

In [None]:
# we want emb to be 32 by 6, the six being 2 by 3 first layer neurons
torch.cat([emb[:,0,:], emb[:,1,:], emb[:,2,:]], 1).shape

In [14]:
# to cope with any block_size
torch.cat(torch.unbind(emb, 1), 1).shape

torch.Size([32, 6])

In [15]:
# more efficiently, as all torch tensors are stored as one-dimensional
# emb.view(32, 6) == torch.cat(torch.unbind(emb, 1), 1)

In [16]:
xDim0 = X.shape[0]
h = torch.zeros((xDim0, 100))
for i in range(xDim0):
    h[i] = emb.view(-1, 6)[i,:] @ W1 + b1
#     print(h[i])

In [17]:
print(h.shape, h.dtype)
h[0]

torch.Size([32, 100]) torch.float32


tensor([-1.6952e+00,  8.5502e+00,  1.6284e+00, -3.5999e+00, -2.5713e+00,
        -2.3191e+00, -2.3003e+00, -3.1968e-01, -2.0832e+00, -7.9059e-01,
         2.1309e+00, -5.2804e-01, -2.2888e+00,  1.4689e+00, -3.0978e+00,
        -3.1213e+00,  6.6470e-01,  1.4677e+00, -5.6544e+00, -4.6527e-01,
         2.8688e+00, -1.7292e+00,  5.0621e+00, -2.7106e+00, -1.1948e+00,
         3.5489e+00, -4.4059e+00,  1.7142e+00, -2.9280e-02,  8.0693e-01,
        -7.8854e-01,  1.9729e+00, -5.0255e+00, -4.4413e-01,  3.4469e+00,
         5.5688e+00,  1.7563e-01, -2.8857e+00,  7.2082e+00,  2.9067e+00,
        -3.8329e+00,  2.5278e-01,  2.0527e-02,  2.4942e+00,  9.4450e+00,
        -4.0162e-02, -2.1084e+00, -1.2085e+00,  2.9852e+00,  2.3586e+00,
         4.5141e+00, -3.0145e+00, -4.2536e+00,  8.0056e+00,  4.2016e+00,
         9.6026e-01,  1.3663e+00, -3.9845e+00, -3.3102e-01,  2.1777e+00,
         3.2924e-01, -2.3507e+00, -4.6957e-01, -2.3629e+00,  1.1956e+00,
        -3.6817e-01,  3.7476e-01,  9.9520e-03,  3.7

In [18]:
W2 = torch.randn((100,27))
b2 = torch.randn(27)
print(W2.shape, W2.dtype)

torch.Size([100, 27]) torch.float32


In [19]:
logits = h[:,:29] @ W2[:29,:]
logits = h @ W2 + b2

In [20]:
print(logits.shape, logits.dtype)
logits

torch.Size([32, 27]) torch.float32


tensor([[ 8.8866e+00, -6.0578e+00, -2.3388e+01,  3.5965e+01,  4.8657e+01,
          1.2296e+01,  2.5244e+01,  5.2045e+00,  8.4152e+00, -3.7288e+01,
         -1.0053e+01, -1.6036e+01, -5.2414e+01,  1.9942e+01,  7.7306e+00,
         -3.2081e+01,  1.8151e+00,  1.9518e+01, -3.7173e+01, -3.3772e+01,
          1.0955e+00, -2.3259e+01, -5.4315e+01,  6.8498e+00,  4.7499e+01,
         -1.2473e+01, -1.9731e+01],
        [-6.5959e+00, -5.1010e+00, -1.9361e+01,  2.0723e+01,  2.0051e+01,
          1.8174e+01,  4.8063e+01, -2.9651e+01, -1.0960e+01, -5.2127e-01,
         -1.0167e+01, -2.8149e+01, -1.8374e+01,  7.2232e+00,  2.2916e+01,
         -2.0904e+01, -1.1312e+01, -2.0879e+01,  5.5439e+00, -1.1177e+01,
          1.2069e+01, -1.8887e+01, -2.8387e+00, -2.9201e+01,  6.3625e+01,
          2.6885e+01,  8.7589e+00],
        [ 3.6579e+01,  8.3161e+00, -1.2294e+01,  1.6767e+01,  6.9654e+01,
         -6.6898e+00, -2.0228e+01,  4.5030e+01,  1.3578e+01, -5.8158e+01,
         -2.4547e+01, -1.0729e+01, -5.13

In [21]:
counts = logits.exp()

In [22]:
prob = counts / counts.sum(1, keepdims=True)

In [23]:
print(prob.shape)
prob[0,:].sum().item()

torch.Size([32, 27])


1.0

In [24]:
loss = -prob[torch.arange(32), Y].log().mean()
loss

tensor(55.4697)

In [25]:
# ----------------- now made respectable :) -------------------

In [26]:
# build the dataset

block_size = 3 # context length: how many characters do we take to predict the next one?
X, Y = [], []
# for w in words[:5]:
for w in words:
    
#     print(w)
    context = [0] * block_size
    for ch in w + '.':
        ix = stoi[ch]
        X.append(context)
        Y.append(ix)
#         print(''.join(itos[i] for i in context), '--->', itos[ix])
        context = context[1:] + [ix] # crop and append

X = torch.tensor(X)
Y = torch.tensor(Y)

In [27]:
X.shape, Y.shape # dataset

(torch.Size([228146, 3]), torch.Size([228146]))

In [42]:
g = torch.Generator().manual_seed(2147483647) # for reproducibility
C = torch.randn((27,2), generator=g)
W1 = torch.randn((6,100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100,27), generator=g)
b2 = torch.randn(27, generator=g)
parameters = (C, W1, b1, W2, b2)
for p in parameters:
    p.requires_grad = True

In [43]:
sum(p.nelement() for p in parameters)

3481

In [44]:
lre = torch.linspace(-3,0,200)
lrs = 10**lre

In [45]:
lri = []
lossi = []

for i in range(200):
    
    #construct minibatch
    ix = torch.randint(0, X.shape[0], (32,))
    
    # forward pass
    emb = C[X[ix]] 
    XDim0 = X.shape[0]
    XDim0 = 32
    h = torch.zeros((XDim0, 100))
    for i in range(XDim0):
        h[i] = torch.tanh(emb.view(-1, 6)[i,:] @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y[ix])
    print(loss.item())

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()
    
    # update
    lr = lrs[i]
    for p in parameters:
        p.data += -lr * p.grad
        
    #track stats
    lri.append(lr)
    lossi.append(loss.item())
        
print(loss.item())

23.297788619995117
17.854251861572266
18.838159561157227
20.084503173828125
21.652175903320312
17.424697875976562
19.750141143798828
19.3549861907959
19.259531021118164
19.53969955444336
20.349151611328125
18.20952796936035
19.53746795654297
17.284074783325195
19.177120208740234
19.914215087890625
17.72681427001953
19.45444107055664
17.600515365600586
18.57328987121582
17.146333694458008
17.63677406311035
19.947290420532227
17.613666534423828
19.935916900634766
15.171893119812012
19.368032455444336
18.561120986938477
18.368989944458008
17.176227569580078
18.813844680786133
19.36686897277832
14.116782188415527
18.173419952392578
16.089309692382812
19.409934997558594
16.139013290405273
17.309410095214844
16.704673767089844
16.43288803100586
16.83089256286621
18.555273056030273
18.12418556213379
15.718244552612305
17.376455307006836
15.292025566101074
15.988767623901367
15.259799003601074
17.41710662841797
18.246219635009766
16.706804275512695
16.230134963989258
14.73523235321045
15.05490

In [48]:
plt.plot(lri, lossi)

AttributeError: module 'matplotlib' has no attribute 'plot'

In [36]:
    emb = C[X] 
    XDim0 = X.shape[0]
    h = torch.zeros((XDim0, 100))
    for i in range(XDim0):
        h[i] = torch.tanh(emb.view(-1, 6)[i,:] @ W1 + b1)
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    print(loss.item())

2.8546316623687744
