In [115]:
words = open('../names.txt', 'r').read().splitlines()

In [116]:
import torch

In [117]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

chars = ['.'] + chars 
len(chars)

27

In [118]:
# Encoding 2 char combinations in a dict
stoi_2 = {}
n = 0
for ch1 in chars:
    for ch2 in chars:
        stoi_2[f'{ch1}{ch2}'] = n
        n += 1
len(stoi_2.items())


729

In [142]:
itos_2 = {i:s for s,i in stoi_2.items()}
itos_2[0]

'..'

In [146]:
# create the dataset
xs, ys = [], []
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi_2[f'{ch1}{ch2}']
        ix2 = stoi[ch3]
        xs.append(ix1)
        ys.append(ix2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print(f'number of examples: {num}')

# init the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729,27), generator=g, requires_grad=True)

number of examples: 196113


In [121]:
import torch.nn.functional as F

In [149]:
# gradient descent
for k in range(100):
    # forward pass
    xenc = F.one_hot(xs, num_classes=729).float() # input to the network: one-hot encoding
    logits = xenc @ W # this becomes just a row of W because of one-hot
    counts = logits.exp() # counts, equivalent to N
    probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
    print(loss.item())
    
    # backward pass
    W.grad = None
    loss.backward()
    
    # update
    W.data += -10 * W.grad

2.9221363067626953
2.9183037281036377
2.91451096534729
2.910757064819336
2.9070420265197754
2.903364896774292
2.8997249603271484
2.8961219787597656
2.892554759979248
2.889024019241333
2.885528326034546
2.8820672035217285
2.8786404132843018
2.8752474784851074
2.8718881607055664
2.8685615062713623
2.865267515182495
2.8620052337646484
2.8587746620178223
2.8555750846862793
2.8524060249328613
2.8492679595947266
2.846158981323242
2.8430802822113037
2.8400301933288574
2.837008476257324
2.8340160846710205
2.8310511112213135
2.828113555908203
2.8252031803131104
2.8223202228546143
2.8194632530212402
2.8166322708129883
2.8138277530670166
2.8110482692718506
2.8082940578460693
2.8055646419525146
2.8028597831726074
2.8001790046691895
2.7975223064422607
2.794888973236084
2.792279005050659
2.789691686630249
2.7871274948120117
2.784585475921631
2.7820651531219482
2.779567003250122
2.777090311050415
2.774634599685669
2.772200107574463
2.7697861194610596
2.767392158508301
2.765018939971924
2.762665510177

In [143]:
# Sampling
g = torch.Generator().manual_seed(2147483647 + 2)

for i in range(1):
    out = []
    ix = 0
    
    while True:
        print(itos_2[ix])
        xenc = F.one_hot(torch.tensor([ix]), num_classes=729).float()
        logits = xenc @ W # predict log-counts
        counts = logits.exp()
        p = counts / counts.sum(1, keepdims=True)
        
        print(f'sum: {torch.sum(p)}\tshape: {p.shape}')
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        
        if ix % 27 == 0:
            break
    # print(out)
    print(''.join(out))

..
sum: 0.9999998807907104	shape: torch.Size([1, 27])
.t
sum: 0.9999999403953552	shape: torch.Size([1, 27])
.v
sum: 0.9999999403953552	shape: torch.Size([1, 27])
.w
sum: 1.0	shape: torch.Size([1, 27])
.y
sum: 1.0	shape: torch.Size([1, 27])
.a
sum: 1.0	shape: torch.Size([1, 27])
.f
sum: 1.0	shape: torch.Size([1, 27])
.a
sum: 1.0	shape: torch.Size([1, 27])
.i
sum: 1.0	shape: torch.Size([1, 27])
.l
sum: 0.9999998807907104	shape: torch.Size([1, 27])
.i
sum: 1.0	shape: torch.Size([1, 27])
.s
sum: 1.0	shape: torch.Size([1, 27])
.h
sum: 1.0000001192092896	shape: torch.Size([1, 27])
.u
sum: 0.9999999403953552	shape: torch.Size([1, 27])
.i
sum: 1.0	shape: torch.Size([1, 27])
.s
sum: 1.0	shape: torch.Size([1, 27])
.y
sum: 1.0	shape: torch.Size([1, 27])
.a
sum: 1.0	shape: torch.Size([1, 27])
.e
sum: 1.0	shape: torch.Size([1, 27])
.s
sum: 1.0	shape: torch.Size([1, 27])
.h
sum: 1.0000001192092896	shape: torch.Size([1, 27])
.a
sum: 1.0	shape: torch.Size([1, 27])
.v
sum: 0.9999999403953552	shape: tor