In [None]:
### E01: train a trigram language model, i.e. take two characters as input to predict the third one.
### Feel free to use either counting or a neural net. Evaluate the loss. Did it improve over a bigram model?

In [1]:
# read one name per line; the context manager guarantees the file handle is closed
with open('names.txt', 'r') as fh:
  words = fh.read().splitlines()

In [2]:
words[:10]  # peek at the first ten names

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
len(words)  # number of names in the dataset

32033

In [4]:
min(map(len, words))  # shortest name length

2

In [5]:
max(map(len, words))  # longest name length

15

In [6]:
# count every run of three consecutive tokens, with '<S>' marking the start
# of a name and '<E>' its end
t = {}
for w in words:
  chs = ['<S>'] + list(w) + ['<E>']
  for k in range(len(chs) - 2):
    trigram = tuple(chs[k:k+3])
    t[trigram] = t.get(trigram, 0) + 1

In [7]:
# most frequent trigrams first (sorted is stable, so ties keep insertion order)
sorted(t.items(), key=lambda kv: kv[1], reverse=True)

[(('a', 'h', '<E>'), 1714),
 (('n', 'a', '<E>'), 1673),
 (('a', 'n', '<E>'), 1509),
 (('o', 'n', '<E>'), 1503),
 (('<S>', 'm', 'a'), 1453),
 (('<S>', 'j', 'a'), 1255),
 (('<S>', 'k', 'a'), 1254),
 (('e', 'n', '<E>'), 1217),
 (('l', 'y', 'n'), 976),
 (('y', 'n', '<E>'), 953),
 (('a', 'r', 'i'), 950),
 (('i', 'a', '<E>'), 903),
 (('i', 'e', '<E>'), 858),
 (('a', 'n', 'n'), 825),
 (('e', 'l', 'l'), 822),
 (('a', 'n', 'a'), 804),
 (('i', 'a', 'n'), 790),
 (('m', 'a', 'r'), 776),
 (('i', 'n', '<E>'), 766),
 (('e', 'l', '<E>'), 727),
 (('y', 'a', '<E>'), 716),
 (('a', 'n', 'i'), 703),
 (('<S>', 'd', 'a'), 700),
 (('l', 'a', '<E>'), 684),
 (('e', 'r', '<E>'), 683),
 (('i', 'y', 'a'), 669),
 (('l', 'a', 'n'), 647),
 (('<S>', 'b', 'r'), 646),
 (('n', 'n', 'a'), 633),
 (('<S>', 'a', 'l'), 632),
 (('<S>', 'c', 'a'), 628),
 (('r', 'a', '<E>'), 627),
 (('n', 'i', '<E>'), 625),
 (('<S>', 'a', 'n'), 623),
 (('n', 'n', '<E>'), 619),
 (('n', 'e', '<E>'), 607),
 (('e', 'e', '<E>'), 605),
 (('e', 'y', '<

In [27]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# pair -> char
# m = number of pairs; n = number of chars: same as in bigrams
# so our matrix of counts/probs will be m x n

In [20]:
# character vocabulary: '.' is the start/end token at index 0, letters follow from 1
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}

In [66]:
# We actually need all possible pairs of our chars (27*27 = 729), as sampling can come up
# with a pair not seen in the actual data.

In [79]:
# Same as with chars, but for every ordered pair of vocabulary characters
# (len(itos)**2 = 729 pairs), not just pairs that occur in the data.
pairs = sorted(itos[i] + itos[j] for i in range(len(itos)) for j in range(len(itos)))
# Because '.' (index 0) sorts before every letter, the sorted position of a pair
# equals i*27 + j, e.g. '..' -> 0, '.e' -> 5, 'em' -> 148.
pairtoi = {p: i for i, p in enumerate(pairs)}
itopair = {i: p for p, i in pairtoi.items()}

In [81]:
len(pairtoi), len(stoi)  # number of possible context pairs, number of characters

(729, 27)

In [106]:
# Matrix of counts: N[pair index, char index] = how often that pair is followed by that char.
# Its shape is derived from the vocabularies: (len(pairtoi), len(stoi)) = (729, 27).
N = torch.zeros((len(pairtoi), len(stoi)), dtype=torch.int32)
for w in words:
  # pad with two '.' so the first real character is predicted from the context '..'
  chs = ['.', '.'] + list(w) + ['.']
  # slide a window of three consecutive characters over the padded word
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = pairtoi[ch1+ch2]
    ix2 = stoi[ch3]
    N[ix1, ix2] += 1
    

In [107]:
N[0]  # counts of each character following the pair '..', i.e. the first character of a name

tensor([   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
        1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
         134,  535,  929], dtype=torch.int32)

In [108]:
# normalize the '..' row of counts into a probability distribution over the first character
p = N[0].float()
p /= p.sum()
p

tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,
        0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,
        0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])

In [109]:
# draw a single first character from p with a fixed seed
g = torch.Generator().manual_seed(2147483647)
ix = torch.multinomial(p, 1, True, generator=g).item()
itos[ix]

'j'

In [116]:
# add-one smoothing: no trigram gets zero probability, so the loss can never be inf
P = (N + 1).float()
P = P / P.sum(dim=1, keepdim=True)  # normalize each row into a distribution

In [117]:
P[1].sum()  # sanity check: each row of P sums to 1

tensor(1.0000)

In [122]:
# sample 10 names from the counting model
g = torch.Generator().manual_seed(2147483647)

for _ in range(10):
  out = ['.']          # seed the context so the first lookup is the pair '..'
  i = pairtoi['..']    # == 0
  while True:
    p = P[i]
    j = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[j])
    if j == 0:  # sampled the end token '.'
      break
    i = pairtoi[out[-2] + out[-1]]  # next context: the last two characters
  print(''.join(out[1:]))

junide.
jakasid.
prelay.
adin.
kairritoper.
sathen.
sameia.
yanileniassibduinrwin.
lessiyanayla.
te.


In [113]:
# Trying different seeds, more of the generated words look name-like. The model still tends to generate very long words as well.

In [114]:
# GOAL: maximize likelihood of the data w.r.t. model parameters (statistical modeling)
# equivalent to maximizing the log likelihood (because log is monotonic)
# equivalent to minimizing the negative log likelihood
# equivalent to minimizing the average negative log likelihood

# log(a*b*c) = log(a) + log(b) + log(c)

In [119]:
# Evaluate the counting model: average negative log likelihood of every
# trigram (pair -> next char) in the dataset under the smoothed table P.
ix1s, ix2s = [], []
for w in words:
  chs = ['.', '.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1s.append(pairtoi[ch1+ch2])  # row of P: context pair
    ix2s.append(stoi[ch3])         # column of P: next character
# Gather all probabilities at once and reduce with one tensor sum. This replaces
# one scalar torch.log and one += per trigram, and avoids adding up a float32
# running total one term at a time.
logprobs = torch.log(P[torch.tensor(ix1s), torch.tensor(ix2s)])
log_likelihood = logprobs.sum()
n = logprobs.numel()

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')

log_likelihood=tensor(-504653.)
nll=tensor(504653.)
2.2119739055633545


In [120]:
# Increasing the context to a pair of characters (predicting a char from the previous two) improves the loss.

In [123]:
# build a tiny training set of trigrams (x, y) from the first word only, printing each one
xs, ys = [], []

for w in words[:1]:
  chs = ['.', '.'] + list(w) + ['.']
  for a, b, c in zip(chs, chs[1:], chs[2:]):
    print(a, b, c)
    xs.append(pairtoi[a + b])  # input: index of the context pair
    ys.append(stoi[c])         # label: index of the next character

xs, ys = torch.tensor(xs), torch.tensor(ys)

. . e
. e m
e m m
m m a
m a .


In [124]:
xs  # input pair indices

tensor([  0,   5, 148, 364, 352])

In [125]:
ys  # label character indices

tensor([ 5, 13, 13,  1,  0])

In [127]:
# one-hot encode the pair indices (one column per possible pair) as float input to the net;
# F (torch.nn.functional) is imported together with the other imports at the top
xenc = F.one_hot(xs, num_classes=len(pairtoi)).float()
xenc

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [128]:
xenc.shape  # (number of examples, number of possible pairs)

torch.Size([5, 729])

In [130]:
xenc.dtype  # float, so it can be matrix-multiplied with the float weights

torch.float32

In [136]:
# seeded so the random weights (and every output derived from them) are reproducible
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g)
xenc @ W

tensor([[ 0.2249,  0.9268,  0.5115,  0.5983, -1.1390,  1.0081, -0.5331, -1.3072,
         -1.4007, -1.3394,  1.2692, -0.3741,  2.5223, -1.4576, -0.4786, -1.4972,
          0.3316, -0.4075, -2.0935,  0.1361,  0.8890,  0.0319, -0.3023, -0.0750,
          0.5372,  1.2407,  0.2319],
        [-0.0229, -0.0321,  1.6305, -0.1696, -1.5931, -1.3193, -0.6700,  0.5976,
         -0.4984,  1.3065, -1.2362, -0.1958,  1.6151,  1.7665, -0.4873,  0.2932,
          1.2997, -0.2488,  0.3454, -1.3379,  0.4373, -0.0130,  0.7801,  0.0419,
         -0.2599, -0.4158, -1.4023],
        [ 0.0726, -0.0574,  1.0055, -1.3377,  1.5328, -1.1245, -2.0126, -0.1728,
          0.8392,  0.3975,  1.6762,  0.0852, -0.8431,  0.8452, -1.9324, -0.1576,
         -0.5848, -0.5178,  0.0727, -1.9551,  0.5365,  0.2742,  1.9091,  1.2922,
          0.5233, -0.7467,  0.2822],
        [-1.9088, -0.1388,  1.0697, -2.1051,  0.3637,  1.2557, -2.4404,  1.2145,
          0.3150, -0.1444,  0.6528,  0.1451, -0.3456, -1.3960,  0.5895, -0.2948

In [137]:
# interpret the net's outputs as log-counts, exponentiate, then row-normalize
logits = xenc @ W
counts = torch.exp(logits)                        # plays the role of N
probs = counts / counts.sum(dim=1, keepdim=True)  # next-character probabilities
probs

tensor([[0.0282, 0.0570, 0.0376, 0.0410, 0.0072, 0.0618, 0.0132, 0.0061, 0.0056,
         0.0059, 0.0802, 0.0155, 0.2808, 0.0052, 0.0140, 0.0050, 0.0314, 0.0150,
         0.0028, 0.0258, 0.0548, 0.0233, 0.0167, 0.0209, 0.0386, 0.0780, 0.0284],
        [0.0230, 0.0228, 0.1201, 0.0199, 0.0048, 0.0063, 0.0120, 0.0428, 0.0143,
         0.0869, 0.0068, 0.0193, 0.1183, 0.1376, 0.0145, 0.0315, 0.0863, 0.0183,
         0.0332, 0.0062, 0.0364, 0.0232, 0.0513, 0.0245, 0.0181, 0.0155, 0.0058],
        [0.0243, 0.0214, 0.0618, 0.0059, 0.1047, 0.0073, 0.0030, 0.0190, 0.0524,
         0.0337, 0.1209, 0.0246, 0.0097, 0.0527, 0.0033, 0.0193, 0.0126, 0.0135,
         0.0243, 0.0032, 0.0387, 0.0298, 0.1526, 0.0824, 0.0382, 0.0107, 0.0300],
        [0.0044, 0.0258, 0.0862, 0.0036, 0.0426, 0.1038, 0.0026, 0.0997, 0.0405,
         0.0256, 0.0568, 0.0342, 0.0209, 0.0073, 0.0533, 0.0220, 0.0335, 0.1773,
         0.0157, 0.0216, 0.0074, 0.0294, 0.0282, 0.0164, 0.0150, 0.0133, 0.0128],
        [0.0300, 0.0446,

In [138]:
probs[0]  # predicted next-character distribution for the first example

tensor([0.0282, 0.0570, 0.0376, 0.0410, 0.0072, 0.0618, 0.0132, 0.0061, 0.0056,
        0.0059, 0.0802, 0.0155, 0.2808, 0.0052, 0.0140, 0.0050, 0.0314, 0.0150,
        0.0028, 0.0258, 0.0548, 0.0233, 0.0167, 0.0209, 0.0386, 0.0780, 0.0284])

In [139]:
probs[0].shape  # one probability per character

torch.Size([27])

In [140]:
probs[0].sum()  # sanity check: the row sums to 1

tensor(1.0000)

In [None]:
# (5, 729) @ (729, 27) -> (5, 27)

In [None]:
# SUMMARY ------------------------------>>>>

In [528]:
# NOTE(review): the recorded output below is stale (In [528], executed out of order) and
# shows bigram character indices; with pair indices xs is tensor([0, 5, 148, 364, 352]),
# as displayed by the earlier `xs` cell. Re-run from a fresh kernel to refresh it.
xs

tensor([ 0,  5, 13, 13,  1])

In [529]:
ys  # label character indices

tensor([ 5, 13, 13,  1,  0])

In [141]:
# 27 neurons (one per next character), each with 729 input weights (one per possible pair);
# the fixed-seed generator makes the initialization reproducible
g = torch.Generator().manual_seed(2147483647)
W = torch.randn(729, 27, generator=g)

In [142]:
xenc = F.one_hot(xs, num_classes=729).float()     # network input: one one-hot row per example
logits = xenc @ W                                 # interpret outputs as log-counts
counts = torch.exp(logits)                        # plays the role of N
probs = counts / counts.sum(dim=1, keepdim=True)  # row-normalize -> next-char probabilities
# note: exponentiating and then normalizing, as in the last two lines, is a 'softmax'

In [143]:
probs.shape  # one next-character distribution per example

torch.Size([5, 27])

In [144]:

# Walk through every training example and show how the (untrained) net scores it.
# The example count comes from xs instead of a hard-coded 5.
nlls = torch.zeros(len(xs))
for i in range(len(xs)):
  # i-th trigram: (pair of context characters) -> next character
  x = xs[i].item() # input pair index
  y = ys[i].item() # label character index
  print('--------')
  print(f'trigram example {i+1}: {itopair[x]}{itos[y]} (indexes {x},{y})')
  print('input to the neural net:', x)
  print('output probabilities from the neural net:', probs[i])
  print('label (actual next character):', y)
  p = probs[i, y]
  print('probability assigned by the net to the correct character:', p.item())
  logp = torch.log(p)
  print('log likelihood:', logp.item())
  nll = -logp
  print('negative log likelihood:', nll.item())
  nlls[i] = nll

print('=========')
print('average negative log likelihood, i.e. loss =', nlls.mean().item())

--------
bigram example 1: ..e (indexes 0,5)
input to the neural net: 0
output probabilities from the neural net: tensor([0.0607, 0.0100, 0.0123, 0.0042, 0.0168, 0.0123, 0.0027, 0.0232, 0.0137,
        0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2378, 0.0603, 0.0025,
        0.0249, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1537, 0.1459])
label (actual next character): 5
probability assigned by the net to the the correct character: 0.01228625513613224
log likelihood: -4.399273872375488
negative log likelihood: 4.399273872375488
--------
bigram example 2: .em (indexes 5,13)
input to the neural net: 5
output probabilities from the neural net: tensor([0.0290, 0.0796, 0.0248, 0.0521, 0.1989, 0.0289, 0.0094, 0.0335, 0.0097,
        0.0301, 0.0702, 0.0228, 0.0115, 0.0181, 0.0108, 0.0315, 0.0291, 0.0045,
        0.0916, 0.0215, 0.0486, 0.0300, 0.0501, 0.0027, 0.0118, 0.0022, 0.0472])
label (actual next character): 13
probability assigned by the net to the the correct character: 

In [561]:
# --------- !!! OPTIMIZATION !!! yay --------------

In [145]:
xs  # input pair indices (the five trigrams of the first word)

tensor([  0,   5, 148, 364, 352])

In [146]:
ys  # label character indices

tensor([ 5, 13, 13,  1,  0])

In [147]:
# randomly initialize 27 neurons' weights. each neuron receives 729 inputs (one per possible pair)
# requires_grad=True so that loss.backward() populates W.grad
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g, requires_grad=True)

In [168]:
# forward pass
xenc = F.one_hot(xs, num_classes=len(pairtoi)).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
# pick each row's probability of its correct label; the arange must match the
# number of examples, so derive it from ys instead of hard-coding 5
loss = -probs[torch.arange(len(ys)), ys].log().mean()

In [169]:
print(loss.item())  # average negative log likelihood over the examples

4.128016948699951


In [170]:
# backward pass
W.grad = None # reset the gradient (None acts as zero); backward() would otherwise accumulate into it
loss.backward()

In [171]:
W.data -= 0.1 * W.grad  # one gradient-descent step with learning rate 0.1

In [606]:
# --------- !!! OPTIMIZATION !!! yay, but this time actually --------------

In [200]:
# build the full training set: every (previous two chars) -> next char in the dataset
xs_list, ys_list = [], []
for w in words:
  padded = ['.', '.'] + list(w) + ['.']
  for a, b, c in zip(padded, padded[1:], padded[2:]):
    xs_list.append(pairtoi[a + b])
    ys_list.append(stoi[c])
xs = torch.tensor(xs_list)
ys = torch.tensor(ys_list)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network': 729 inputs (pairs) -> 27 outputs (next char)
g = torch.Generator().manual_seed(2147483647)
W = torch.randn(729, 27, generator=g, requires_grad=True)

number of examples:  228146


In [201]:
# gradient descent (learning rate 200)
# The one-hot encoding of xs does not depend on W, so build it once here instead of
# rebuilding a (num, 729) float tensor on every iteration.
xenc = F.one_hot(xs, num_classes=len(pairtoi)).float() # input to the network: one-hot encoding
for k in range(60):
  
  # forward pass
  logits = xenc @ W # predict log-counts
  counts = logits.exp() # counts, equivalent to N
  probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
  loss = -probs[torch.arange(num), ys].log().mean() #+ 0.01*(W**2).mean()
  print(f'{k=}, {loss.item()}')
  
  # backward pass
  W.grad = None # reset the gradient (None acts as zero); backward() would otherwise accumulate
  loss.backward()
  
  # update
  # NOTE(review): at this step size the loss is not monotone (e.g. it rises at k=13, 17, 22
  # in the recorded output); a smaller rate may train more smoothly.
  W.data += -200 * W.grad

k=0, 3.7927768230438232
k=1, 3.4378983974456787
k=2, 3.321561336517334
k=3, 3.1370038986206055
k=4, 3.0594258308410645
k=5, 2.994174003601074
k=6, 2.91180157661438
k=7, 2.9094626903533936
k=8, 2.8129067420959473
k=9, 2.773369073867798
k=10, 2.7473950386047363
k=11, 2.716261625289917
k=12, 2.705533742904663
k=13, 2.7144110202789307
k=14, 2.6553118228912354
k=15, 2.645838975906372
k=16, 2.651409149169922
k=17, 2.7192015647888184
k=18, 2.627261161804199
k=19, 2.5751070976257324
k=20, 2.5796000957489014
k=21, 2.574028968811035
k=22, 2.6388494968414307
k=23, 2.5547096729278564
k=24, 2.544301748275757
k=25, 2.5965373516082764
k=26, 2.5183115005493164
k=27, 2.5199191570281982
k=28, 2.592808246612549
k=29, 2.5120465755462646
k=30, 2.4963674545288086
k=31, 2.535010576248169
k=32, 2.46429443359375
k=33, 2.4678473472595215
k=34, 2.513728141784668
k=35, 2.4499590396881104
k=36, 2.4547183513641357
k=37, 2.459859848022461
k=38, 2.445643186569214
k=39, 2.5028297901153564
k=40, 2.4349894523620605
k=41

In [202]:
# gradient descent, continued from the current W with a smaller learning rate (100)
# The one-hot encoding of xs does not depend on W, so build it once here instead of
# rebuilding a (num, 729) float tensor on every iteration.
xenc = F.one_hot(xs, num_classes=len(pairtoi)).float() # input to the network: one-hot encoding
for k in range(60):
  
  # forward pass
  logits = xenc @ W # predict log-counts
  counts = logits.exp() # counts, equivalent to N
  probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
  loss = -probs[torch.arange(num), ys].log().mean() #+ 0.01*(W**2).mean()
  print(f'{k=}, {loss.item()}')
  
  # backward pass
  W.grad = None # reset the gradient (None acts as zero); backward() would otherwise accumulate
  loss.backward()
  
  # update
  W.data += -100 * W.grad

k=0, 2.4115524291992188
k=1, 2.3711133003234863
k=2, 2.3531670570373535
k=3, 2.3519272804260254
k=4, 2.3508434295654297
k=5, 2.349788188934326
k=6, 2.348755359649658
k=7, 2.3477394580841064
k=8, 2.346738576889038
k=9, 2.3457517623901367
k=10, 2.344777822494507
k=11, 2.3438167572021484
k=12, 2.342867612838745
k=13, 2.341930627822876
k=14, 2.3410050868988037
k=15, 2.3400909900665283
k=16, 2.3391876220703125
k=17, 2.3382954597473145
k=18, 2.337413787841797
k=19, 2.3365426063537598
k=20, 2.335681676864624
k=21, 2.3348309993743896
k=22, 2.3339900970458984
k=23, 2.3331587314605713
k=24, 2.3323373794555664
k=25, 2.3315250873565674
k=26, 2.3307223320007324
k=27, 2.3299286365509033
k=28, 2.3291432857513428
k=29, 2.328367233276367
k=30, 2.3276000022888184
k=31, 2.326840877532959
k=32, 2.3260903358459473
k=33, 2.325347900390625
k=34, 2.324613332748413
k=35, 2.3238871097564697
k=36, 2.3231685161590576
k=37, 2.3224575519561768
k=38, 2.3217544555664062
k=39, 2.321058750152588
k=40, 2.320369958877563

In [208]:
# finally, sample from the 'neural net' model
g = torch.Generator().manual_seed(2147483647)

with torch.no_grad(): # sampling only: no need to record an autograd graph through W
  for _ in range(5):
    
    out = ['.'] # seed the context so the first lookup is the pair '..'
    i = 0 # index of the current context pair (pairtoi['..'] == 0)
    while True:
      
      # ----------
      # BEFORE (counting model):
      #p = P[i]
      # ----------
      # NOW:
      xenc = F.one_hot(torch.tensor([i]), num_classes=len(pairtoi)).float()
      logits = xenc @ W # predict log-counts
      counts = logits.exp() # counts, equivalent to N
      p = counts / counts.sum(1, keepdims=True) # probabilities for next character
      # ----------
      
      j = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
      out.append(itos[j])
      if j == 0: # sampled the end token '.'
        break
      # update index i: the next context is the last two characters
      pair = ''.join(out[-2:])
      i = pairtoi[pair]
    print(''.join(out[1:]))

junide.
janasid.
prelay.
adin.
kairritonian.


In [210]:
W[0]  # logits row selected by the context pair '..', i.e. used to predict the first character

tensor([-3.6937,  2.1662,  0.9490,  1.1152,  1.2068,  1.1080, -0.1936,  0.2796,
         0.5471,  0.1556,  1.5668,  1.7685,  1.1344,  1.6136,  0.8182, -0.2504,
         0.0178, -1.7029,  1.1762,  1.4025,  0.9505, -1.8643, -0.2973, -0.5004,
        -1.3299,  0.0559,  0.6082], grad_fn=<SelectBackward0>)

In [211]:
W[1]  # logits row selected by the context pair '.a' (pair index 1)

tensor([-2.3681,  0.7741,  0.6876, -1.1817,  1.3476, -0.5805, -1.4108, -1.6406,
        -0.0625,  0.4751, -1.2273, -0.2655,  1.8956,  1.3958,  1.8813, -1.8733,
        -1.5648, -2.0690,  1.6239,  0.7086, -0.3027,  0.4618,  0.9357, -2.0023,
        -1.2318,  0.5929,  0.4618], grad_fn=<SelectBackward0>)