In [57]:
words = open('names.txt' , 'r').read().splitlines()

In [58]:
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [59]:
import torch
N = torch.zeros((27 , 27, 27) , dtype=torch.int32)
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i , s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s , i in stoi.items()}


for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        N[ix1, ix2, ix3] += 1
    
P = (N + 1).float()
P /= P.sum(2 , keepdim=True)

In [60]:
P[1, 1 , 1].item()

0.0017152659129351377

In [61]:
# Sampling process
g = torch.Generator().manual_seed(2147483647)

for i in range(10):
    out = ['.', '.']
    while True:
        p = P[stoi[out[-2]], stoi[out[-1]]]  # Plug the last two chars into our probabilities table
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:  # End of sequence
            break
    print(''.join(out[2:-1]))  # Skip the initial start characters and the final end character


ce
za
zogh
uriana
kaydnevonimittain
luwak
ka
da
samiyah
javer


In [62]:
log_likehood = 0.0
n = 0
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
    n+=1
    logprob = torch.log(P[ix1, ix2, ix3])
    log_likehood += logprob

print(f'{log_likehood:.4f}') # -410414.9688
nll = -log_likehood

print(nll / n) # tensor(2.0927)

-410414.9688
tensor(2.0927)


In [72]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

xs,ys = [] , []
for w in words:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append((ix1, ix2)) # We need to keep two chars for the input, so it's an array of tuples instead of simple ints
        ys.append(ix3)
    
xs = torch.tensor(xs, dtype=torch.long, device=device)
ys = torch.tensor(ys, dtype=torch.long, device=device)
print("examples:", ys.shape[0])


g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27*2 , 27) , device=device , requires_grad=True)

print(f"xs device: {xs.device}")
print(f"ys device: {ys.device}")
print(f"W device: {W.device}")

import torch.nn.functional as F
for k in range(1000):
    xenc = F.one_hot(xs, num_classes=27).float()
    
    logits = xenc.view(-1, 27*2) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    loss = -probs[torch.arange(ys.shape[0], device=device), ys].log().mean()
    
    W.grad = None
    loss.backward()
    
    with torch.no_grad():
        W -= 3 * W.grad
    print(loss.item())
    
print(loss.item())

mps
examples: 196113
xs device: mps:0
ys device: mps:0
W device: mps:0
4.1658759117126465
4.107062816619873
4.050714492797852
3.9967381954193115
3.9450480937957764
3.8955628871917725
3.8482043743133545
3.8028955459594727
3.7595603466033936
3.7181217670440674
3.678502321243286
3.640623092651367
3.604403257369995
3.5697641372680664
3.5366265773773193
3.5049121379852295
3.4745466709136963
3.4454565048217773
3.417574405670166
3.3908350467681885
3.3651785850524902
3.3405487537384033
3.316892385482788
3.2941620349884033
3.2723140716552734
3.251307249069214
3.2311041355133057
3.2116684913635254
3.1929678916931152
3.1749696731567383
3.157644271850586
3.1409621238708496
3.1248958110809326
3.1094179153442383
3.09450101852417
3.08012056350708
3.0662500858306885
3.052866220474243
3.0399439334869385
3.0274605751037598
3.0153942108154297
3.003723621368408
2.9924285411834717
2.981489658355713
2.970888614654541
2.9606082439422607
2.9506325721740723
2.940946102142334
2.931534767150879
2.922385454177856

In [73]:
import random 
import math 
random.shuffle(words)
c = len(words)
trainSet = words[:math.floor(c * 0.8)]
devSet = words[math.floor(c * 0.8) : math.floor(c*0.9)]
testSet = words[math.floor(c*0.9) : ]

In [74]:
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
print(device)

xs,ys = [] , []
for w in trainSet:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
        xs.append((ix1, ix2)) # We need to keep two chars for the input, so it's an array of tuples instead of simple ints
        ys.append(ix3)
    
xs = torch.tensor(xs, dtype=torch.long, device=device)
ys = torch.tensor(ys, dtype=torch.long, device=device)
print("examples:", ys.shape[0])


g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27*2 , 27) , device=device , requires_grad=True)

print(f"xs device: {xs.device}")
print(f"ys device: {ys.device}")
print(f"W device: {W.device}")

import torch.nn.functional as F
for k in range(1000):
    xenc = F.one_hot(xs, num_classes=27).float()
    
    logits = xenc.view(-1, 27*2) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    loss = -probs[torch.arange(ys.shape[0], device=device), ys].log().mean()
    
    W.grad = None
    loss.backward()
    
    with torch.no_grad():
        W -= 3 * W.grad
    print(loss.item())
    
print(loss.item())

mps
examples: 156966
xs device: mps:0
ys device: mps:0
W device: mps:0
4.367815017700195
4.289290428161621
4.216370582580566
4.148350715637207
4.084595203399658
4.024557113647461
3.967797040939331
3.913970947265625
3.862809419631958
3.814103841781616
3.7676963806152344
3.7234606742858887
3.6812963485717773
3.6411187648773193
3.6028573513031006
3.5664451122283936
3.5318188667297363
3.4989168643951416
3.467674493789673
3.438023805618286
3.4098944664001465
3.383211135864258
3.3578951358795166
3.3338658809661865
3.311042070388794
3.2893407344818115
3.2686805725097656
3.248983860015869
3.2301764488220215
3.2121870517730713
3.1949517726898193
3.178410768508911
3.1625099182128906
3.147199869155884
3.132436752319336
3.1181817054748535
3.1043992042541504
3.0910584926605225
3.0781314373016357
3.065593719482422
3.053422689437866
3.0415990352630615
3.030104398727417
3.018923044204712
3.008039712905884
2.997441530227661
2.9871160984039307
2.9770522117614746
2.9672393798828125
2.957667827606201
2.94

In [76]:
lossi = []
for w in testSet:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
      
    xenc = F.one_hot(torch.tensor((ix1, ix2)), num_classes=27).float().to(device)
    logits = xenc.view(-1, 27*2) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    lossi.append(-probs[0, ix3].log())
print(torch.tensor(lossi).mean()) ## 2.2615 similar than training set

lossi = []
for w in devSet:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    ix3 = stoi[ch3]
      
    xenc = F.one_hot(torch.tensor((ix1, ix2)), num_classes=27).float().to(device)
    logits = xenc.view(-1, 27*2) @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    lossi.append(-probs[0, ix3].log())
print(torch.tensor(lossi).mean())

tensor(2.2870)
tensor(2.2933)


In [79]:
smoothnesses = [0, 0.01, 0.02, 0.05, 0.1, 0.25, 0.5, 1.0]

for i, smoothness in enumerate(smoothnesses):
    W =  torch.randn((27*2, 27),device=device, requires_grad=True)
    for k in range(1000):
      xenc = F.one_hot(xs, num_classes=27).float().to(device)
      logits = xenc.view(-1, 27*2) @ W
      counts = logits.exp()
      probs = counts / counts.sum(1, keepdims=True)
      loss = -probs[torch.arange(ys.shape[0]), ys].log().mean() + smoothness*(W**2).mean()
      
      W.grad = None
      loss.backward()
    
      W.data += -3 * W.grad
    
    print(f"Smoothness: {smoothness} => loss train set: {loss.item()}")

    lossi = []
    for w in devSet:
      chs = ['.'] + list(w) + ['.']
      for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]
          
        xenc = F.one_hot(torch.tensor((ix1, ix2)), num_classes=27).float().to(device)
        logits = xenc.view(-1, 27*2) @ W
        counts = logits.exp()
        probs = counts / counts.sum(1, keepdims=True)
        lossi.append(-probs[0, ix3].log())
    print(f"Smoothness: {smoothness} => loss dev set: {torch.tensor(lossi).mean()}")
    print()

Smoothness: 0 => loss train set: 2.28078556060791
Smoothness: 0 => loss dev set: 2.2889039516448975

Smoothness: 0.01 => loss train set: 2.290830135345459
Smoothness: 0.01 => loss dev set: 2.292793035507202

Smoothness: 0.02 => loss train set: 2.296457052230835
Smoothness: 0.02 => loss dev set: 2.2923743724823

Smoothness: 0.05 => loss train set: 2.316439151763916
Smoothness: 0.05 => loss dev set: 2.291156530380249

Smoothness: 0.1 => loss train set: 2.3434064388275146
Smoothness: 0.1 => loss dev set: 2.294713258743286

Smoothness: 0.25 => loss train set: 2.3922369480133057
Smoothness: 0.25 => loss dev set: 2.3107962608337402

Smoothness: 0.5 => loss train set: 2.453662157058716
Smoothness: 0.5 => loss dev set: 2.345001459121704

Smoothness: 1.0 => loss train set: 2.544764518737793
Smoothness: 1.0 => loss dev set: 2.4018919467926025

