In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt 
%matplotlib inline

# Read all the words

In [3]:
# Load the dataset: one name per line. The context manager closes the file
# handle as soon as the read finishes instead of leaving it open until the
# garbage collector gets to it.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()
words[:5]


['emma', 'olivia', 'ava', 'isabella', 'sophia']

In [6]:
# Build the vocabulary of characters and the mappings to/from integers.
# Characters found in the names are numbered from 1 in sorted order;
# index 0 is assigned to the '.' marker.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}

In [11]:
# Inspect the character -> integer mapping.
stoi

{'a': 1,
 'b': 2,
 'c': 3,
 'd': 4,
 'e': 5,
 'f': 6,
 'g': 7,
 'h': 8,
 'i': 9,
 'j': 10,
 'k': 11,
 'l': 12,
 'm': 13,
 'n': 14,
 'o': 15,
 'p': 16,
 'q': 17,
 'r': 18,
 's': 19,
 't': 20,
 'u': 21,
 'v': 22,
 'w': 23,
 'x': 24,
 'y': 25,
 'z': 26,
 '.': 0}

In [12]:
# Build the dataset from the first five names.
# Every example pairs a window of `block_size` preceding character indices
# (initially all 0, the index of '.') with the index of the character that
# comes next; each name is terminated with '.'.

block_size = 3  # number of preceding characters used as context
inputs, targets = [], []

for name in words[:5]:
    print(name)
    window = [0] * block_size
    for symbol in name + '.':
        target = stoi[symbol]
        inputs.append(window)
        targets.append(target)
        shown = ''.join(itos[k] for k in window)
        print(shown, '--->', itos[target])
        # slide the window: drop the oldest index, append the newest one
        window = window[1:] + [target]

X = torch.tensor(inputs)
Y = torch.tensor(targets)


emma
... ---> e
..e ---> m
.em ---> m
emm ---> a
mma ---> .
olivia
... ---> o
..o ---> l
.ol ---> i
oli ---> v
liv ---> i
ivi ---> a
via ---> .
ava
... ---> a
..a ---> v
.av ---> a
ava ---> .
isabella
... ---> i
..i ---> s
.is ---> a
isa ---> b
sab ---> e
abe ---> l
bel ---> l
ell ---> a
lla ---> .
sophia
... ---> s
..s ---> o
.so ---> p
sop ---> h
oph ---> i
phi ---> a
hia ---> .


In [13]:
# Sanity check: shapes and dtypes of the context tensor X and the target tensor Y.
X.shape, X.dtype, Y.shape, Y.dtype

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [15]:
# Lookup table with 27 rows of 2 random values each (unseeded, so the values
# differ from run to run).
C = torch.randn(27, 2)

In [16]:
# Display the randomly initialised lookup table.
C

tensor([[-0.1119,  0.2636],
        [-0.6989,  2.1506],
        [-0.7149, -0.5945],
        [-0.8677,  0.7291],
        [-1.4325, -0.0783],
        [ 0.5118, -0.7719],
        [ 0.4522, -0.5731],
        [-0.6922,  0.4657],
        [ 0.1393, -0.2769],
        [-1.0653,  1.0387],
        [-0.1632,  0.6185],
        [ 1.3196,  1.5815],
        [ 0.6852,  1.2500],
        [ 1.6148, -0.4787],
        [-0.9683,  0.5787],
        [-0.9439, -0.2379],
        [ 0.5031,  0.4044],
        [-0.2047,  0.3614],
        [-0.3804, -0.9753],
        [-0.4532, -0.4589],
        [ 0.7407,  0.2495],
        [ 0.9104,  0.0138],
        [-1.2113, -0.6081],
        [ 1.0128, -0.5024],
        [ 0.5786, -0.5936],
        [-1.3831,  2.6347],
        [ 0.2767, -0.4244]])

In [17]:
# Plain integer indexing: fetch row 5 of the lookup table.
C[5]

tensor([ 0.5118, -0.7719])

In [21]:
# Same lookup expressed as a matrix product: a one-hot row vector for index 5
# multiplied by C picks out row 5 of C.
F.one_hot(torch.tensor(5), num_classes=27).float() @ C

tensor([ 0.5118, -0.7719])

In [22]:
# Index C with the whole integer tensor X at once: every entry of X is
# replaced by the corresponding row of C.
C[X]

tensor([[[-0.1119,  0.2636],
         [-0.1119,  0.2636],
         [-0.1119,  0.2636]],

        [[-0.1119,  0.2636],
         [-0.1119,  0.2636],
         [ 0.5118, -0.7719]],

        [[-0.1119,  0.2636],
         [ 0.5118, -0.7719],
         [ 1.6148, -0.4787]],

        [[ 0.5118, -0.7719],
         [ 1.6148, -0.4787],
         [ 1.6148, -0.4787]],

        [[ 1.6148, -0.4787],
         [ 1.6148, -0.4787],
         [-0.6989,  2.1506]],

        [[-0.1119,  0.2636],
         [-0.1119,  0.2636],
         [-0.1119,  0.2636]],

        [[-0.1119,  0.2636],
         [-0.1119,  0.2636],
         [-0.9439, -0.2379]],

        [[-0.1119,  0.2636],
         [-0.9439, -0.2379],
         [ 0.6852,  1.2500]],

        [[-0.9439, -0.2379],
         [ 0.6852,  1.2500],
         [-1.0653,  1.0387]],

        [[ 0.6852,  1.2500],
         [-1.0653,  1.0387],
         [-1.2113, -0.6081]],

        [[-1.0653,  1.0387],
         [-1.2113, -0.6081],
         [-1.0653,  1.0387]],

        [[-1.2113, -0

In [23]:
# Shape of the lookup result: X's shape followed by the row width of C.
C[X].shape

torch.Size([32, 3, 2])

In [24]:
# Keep the looked-up rows for every context position; this is the input to
# the hidden layer below.
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [27]:
# Hidden-layer parameters: 6 input features -> 100 hidden units.
W1 = torch.randn((6, 100))
# Draw the bias from a standard normal like every other parameter in this
# notebook. torch.rand samples uniformly from [0, 1) and would make every
# bias entry non-negative.
b1 = torch.randn(100)

In [34]:
# One way to flatten the context dimension: split emb into its slices along
# dim 1 and concatenate those slices along dim 1 of the resulting 2-D tensors.
torch.cat(torch.unbind(emb, 1),1).shape

torch.Size([32, 6])

In [35]:
# Hidden activations: flatten each example's context rows into one row of 6
# values, apply the affine layer, then squash with tanh.
# -1 lets view() infer the number of rows, so the cell is not tied to exactly
# 32 examples.
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)

In [36]:
# One row of hidden activations per example.
h.shape

torch.Size([32, 100])

In [38]:
# Output-layer parameters: map 100 hidden activations to 27 scores.
W2 = torch.randn(100, 27)
b2 = torch.randn((27,))

In [40]:
# Affine output layer applied to the hidden activations (unnormalised scores).
logits = h @ W2 + b2

In [54]:
# Softmax over each row of logits. Subtracting the row-wise maximum first
# leaves the resulting probabilities mathematically unchanged but keeps exp()
# from overflowing to inf (and the division from yielding nan) when a logit
# is large.
counts = (logits - logits.max(1, keepdim=True).values).exp()
prob = counts / counts.sum(1, keepdim=True)

In [52]:
# Sanity check: the first row of prob should sum to 1.
prob[0].sum()

tensor(1.)

In [57]:
# Negative log-likelihood: for each example take the probability assigned to
# its target, then average the negative logs. The row index is derived from Y
# so the expression works for any number of examples, not only 32.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(11.8003)

In [60]:
# Re-create every parameter from a fixed-seed generator so the run is
# repeatable. Tensors are drawn in the order C, W1, b1, W2, b2; the order
# determines the values each one receives, so it must not change.
g = torch.Generator().manual_seed(2147483647)
param_shapes = [(27, 2), (6, 100), (100,), (100, 27), (27,)]
C, W1, b1, W2, b2 = parameters = [
    torch.randn(shape, generator=g) for shape in param_shapes
]

In [61]:
# Add up the element counts of all parameter tensors.
sum(p.nelement() for p in parameters) # total number of parameters


3481

In [62]:
# ---------- version 1: explicit softmax + negative log-likelihood ----------
emb = C[X]                                 # row of C for every context index
h = torch.tanh(emb.view(-1, 6) @ W1 + b1)  # hidden activations, one row per example
logits = h @ W2 + b2
# Shift each row by its maximum before exponentiating: the probabilities are
# mathematically unchanged, but exp() can no longer overflow to inf.
counts = (logits - logits.max(1, keepdim=True).values).exp()
prob = counts / counts.sum(1, keepdim=True)
# Index rows by the actual number of targets instead of a hard-coded 32 so the
# cell keeps working when the dataset is rebuilt from a different number of names.
loss = -prob[torch.arange(Y.shape[0]), Y].log().mean()
loss

tensor(17.7697)

In [67]:
# Switch on gradient tracking for every parameter tensor so that
# loss.backward() can populate their .grad fields.
for tensor in parameters:
    tensor.requires_grad = True

In [69]:
# ---------- training loop: gradient descent on the full dataset ----------
learning_rate = 0.1  # step size of each parameter update
for _ in range(100):
    # forward pass
    emb = C[X]
    h = torch.tanh(emb.view(-1, 6) @ W1 + b1)  # one row of hidden activations per example
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    print(loss.item())
    # backward pass: drop stale gradients, then backpropagate the new loss
    for p in parameters:
        p.grad = None
    loss.backward()
    # update: step against the gradient. The in-place write happens under
    # no_grad() so autograd does not record it; this replaces mutating `.data`,
    # which bypasses autograd's bookkeeping.
    with torch.no_grad():
        for p in parameters:
            p += -learning_rate * p.grad


3.985848903656006
3.6028308868408203
3.2621419429779053
2.961381196975708
2.6982977390289307
2.469712972640991
2.271660327911377
2.101283550262451
1.9571774005889893
1.8374857902526855
1.7380965948104858
1.6535115242004395
1.579089879989624
1.5117665529251099
1.4496047496795654
1.3913120031356812
1.335992455482483
1.283052921295166
1.2321909666061401
1.18338143825531
1.1367990970611572
1.092664122581482
1.0510923862457275
1.0120267868041992
0.9752704501152039
0.9405565857887268
0.9076125025749207
0.8761921525001526
0.8460890650749207
0.817135751247406
0.7891992330551147
0.7621746063232422
0.7359814047813416
0.7105579972267151
0.6858609318733215
0.6618651747703552
0.6385656595230103
0.6159818172454834
0.5941657423973083
0.5732105374336243
0.553256094455719
0.5344880819320679
0.5171167254447937
0.5013312697410583
0.48724257946014404
0.4748404920101166
0.4639975428581238
0.45451444387435913
0.4461706876754761
0.4387663006782532
0.43213313817977905
0.4261389374732971
0.42067983746528625
0.

tensor([0.1135, 0.0418, 0.8390, 0.0057])