In [1]:
import torch
import torch.nn.functional as F

words = open('names.txt', 'r').read().splitlines() 

In [2]:
N = torch.zeros((27, 27), dtype=torch.int32)

In [3]:
chars = sorted(list(set(''.join(words)))) 
stoi = {s:i+1 for i,s in enumerate(chars)} 

stoi['.'] = 0

itos = {i:s for s,i in stoi.items()}

In [4]:
xs, ys = [], []
for w in words:
  
  chs = ['.'] + list(w) + ['.']
  
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)


number of examples:  228146


In [5]:
for k in range(100):
  
  # forward pass
    xenc = F.one_hot(xs, num_classes=27).float() # input to the network: one-hot encoding
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean() #0.01 * W^2 mean is L2 norm regularization to make the distribution uniform. We wan the W's to be close to zero so as to get prob counts closer to themselves
    print(loss.item())
  
  # backward pass
    W.grad = None # set to zero the gradient
    loss.backward()
  
  # update
    W.data += -50 * W.grad

3.7686190605163574
3.378786325454712
3.1610782146453857
3.027181386947632
2.9344804286956787
2.8672285079956055
2.81665301322937
2.777146100997925
2.745253801345825
2.7188305854797363
2.696505546569824
2.6773722171783447
2.6608059406280518
2.6463515758514404
2.6336653232574463
2.622471570968628
2.6125476360321045
2.6037065982818604
2.595794439315796
2.5886807441711426
2.5822560787200928
2.5764291286468506
2.5711233615875244
2.566272735595703
2.5618226528167725
2.5577261447906494
2.5539441108703613
2.550442695617676
2.547192335128784
2.5441696643829346
2.5413525104522705
2.538722038269043
2.536262035369873
2.5339581966400146
2.531797409057617
2.5297679901123047
2.527860164642334
2.526063919067383
2.5243709087371826
2.522773265838623
2.521263837814331
2.519836902618408
2.5184857845306396
2.5172054767608643
2.515990972518921
2.5148372650146484
2.5137407779693604
2.51269793510437
2.511704921722412
2.5107581615448
2.509855031967163
2.5089921951293945
2.5081686973571777
2.507380485534668
2.5

In [25]:
#g = torch.Generator().manual_seed(2147483647)

for i in range(5):
  
  out = []
  ix = 0
  while True:
    
    # ----------
    # BEFORE:
    #p = P[ix]
    # ----------
    # NOW:
    xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    p = counts / counts.sum(1, keepdims=True) # probabilities for next character
    # ----------
    
    ix = torch.multinomial(p, num_samples=1, replacement=True).item()
    out.append(itos[ix])
    if ix == 0:
      break
  print(''.join(out))

as.
janh.
raradikam.
shaielon.
belllyn.
