In [None]:
### E01: train a trigram language model, i.e. take two characters as an input to predict the 3rd one. 
### Feel free to use either counting or a neural net. Evaluate the loss; Did it improve over a bigram model?

In [1]:
words = open('names.txt', 'r').read().splitlines()

In [2]:
words[:10]

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn']

In [3]:
len(words)

32033

In [4]:
min(len(w) for w in words)

2

In [5]:
max(len(w) for w in words)

15

In [6]:
t = {}
for w in words:
  chs = ['<S>'] + list(w) + ['<E>']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    trigram = (ch1, ch2, ch3)
    t[trigram] = t.get(trigram, 0) + 1

In [7]:
sorted(t.items(), key = lambda kv: -kv[1])

[(('a', 'h', '<E>'), 1714),
 (('n', 'a', '<E>'), 1673),
 (('a', 'n', '<E>'), 1509),
 (('o', 'n', '<E>'), 1503),
 (('<S>', 'm', 'a'), 1453),
 (('<S>', 'j', 'a'), 1255),
 (('<S>', 'k', 'a'), 1254),
 (('e', 'n', '<E>'), 1217),
 (('l', 'y', 'n'), 976),
 (('y', 'n', '<E>'), 953),
 (('a', 'r', 'i'), 950),
 (('i', 'a', '<E>'), 903),
 (('i', 'e', '<E>'), 858),
 (('a', 'n', 'n'), 825),
 (('e', 'l', 'l'), 822),
 (('a', 'n', 'a'), 804),
 (('i', 'a', 'n'), 790),
 (('m', 'a', 'r'), 776),
 (('i', 'n', '<E>'), 766),
 (('e', 'l', '<E>'), 727),
 (('y', 'a', '<E>'), 716),
 (('a', 'n', 'i'), 703),
 (('<S>', 'd', 'a'), 700),
 (('l', 'a', '<E>'), 684),
 (('e', 'r', '<E>'), 683),
 (('i', 'y', 'a'), 669),
 (('l', 'a', 'n'), 647),
 (('<S>', 'b', 'r'), 646),
 (('n', 'n', 'a'), 633),
 (('<S>', 'a', 'l'), 632),
 (('<S>', 'c', 'a'), 628),
 (('r', 'a', '<E>'), 627),
 (('n', 'i', '<E>'), 625),
 (('<S>', 'a', 'n'), 623),
 (('n', 'n', '<E>'), 619),
 (('n', 'e', '<E>'), 607),
 (('e', 'e', '<E>'), 605),
 (('e', 'y', '<

In [8]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline

In [9]:
# pair -> char
# m = number of pairs; n = number of chars: same as in bigrams
# so our matrix of counts/probs will be m x n

In [10]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [11]:
# We actually need all possible pairs of our chars, as sampling can come up
# with a pair not seen it the actual data. 27*27

In [12]:
# same as with chars, but we need all pairs
pairs = []
for i in range(27):
  for j in range(27):
    pairs.append(itos[i] + itos[j])
pairs.sort()
# need to populate pair to ix and ix to pair dicts
pairtoi = {p:i for i,p in enumerate(pairs)}
itopair = {i:p for p,i in pairtoi.items()}

In [13]:
len(pairtoi), len(stoi)

(729, 27)

In [14]:
# Matrix of counts how often a pair followed by a char
N = torch.zeros((729, 27), dtype=torch.int32)
for w in words:
  # as we now using pairs, we start with ..
  # didn't come up with better solution
  chs = ['.', '.'] + list(w) + ['.']
  # we can use indecies, but for simplicity just 3 iters
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = pairtoi[ch1+ch2]
    ix2 = stoi[ch3]
    N[ix1, ix2] += 1
    

In [15]:
N[0]

tensor([   0, 4410, 1306, 1542, 1690, 1531,  417,  669,  874,  591, 2422, 2963,
        1572, 2538, 1146,  394,  515,   92, 1639, 2055, 1308,   78,  376,  307,
         134,  535,  929], dtype=torch.int32)

In [16]:
p = N[0].float()
p = p / p.sum()
p

tensor([0.0000, 0.1377, 0.0408, 0.0481, 0.0528, 0.0478, 0.0130, 0.0209, 0.0273,
        0.0184, 0.0756, 0.0925, 0.0491, 0.0792, 0.0358, 0.0123, 0.0161, 0.0029,
        0.0512, 0.0642, 0.0408, 0.0024, 0.0117, 0.0096, 0.0042, 0.0167, 0.0290])

In [17]:
g = torch.Generator().manual_seed(2147483647)
ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
itos[ix]

'j'

In [18]:
P = (N+1).float() # N+1 is smoothing, so to not have inf loss on zero prob
P /= P.sum(1, keepdims=True)

In [19]:
P[1].sum()

tensor(1.0000)

In [24]:
g = torch.Generator().manual_seed(2147483647)

for _ in range(10):
  
  out = ['.']  # prepopulate with first .
  i = 0 # start sampling from what char follows '..'
  while True:
    p = P[i]
    j = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[j])
    if j == 0: # we've sampled end of word
      break
    # update index i
    pair = ''.join(out[-2:])  # last 2 chars
    i = pairtoi[pair]

  print(''.join(out[1:]))

junide.
jakasid.
prelay.
adin.
kairritoper.
sathen.
sameia.
yanileniassibduinrwin.
lessiyanayla.
te.


In [113]:
# Trying different seeds, it looks like more generated words became name-like. Tend to generate very long words as well.

In [114]:
# GOAL: maximize likelihood of the data w.r.t. model parameters (statistical modeling)
# equivalent to maximizing the log likelihood (because log is monotonic)
# equivalent to minimizing the negative log likelihood
# equivalent to minimizing the average negative log likelihood

# log(a*b*c) = log(a) + log(b) + log(c)

In [25]:
log_likelihood = 0.0
n = 0

for w in words:
# for w in ["alexey"]:
  chs = ['.', '.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = pairtoi[ch1+ch2]
    ix2 = stoi[ch3]
    prob = P[ix1, ix2]
    logprob = torch.log(prob)
    log_likelihood += logprob
    n += 1
    # print(f'{ch1}{ch2}{ch3}: {prob:.4f} {logprob:.4f}')

print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
print(f'{nll/n}')

log_likelihood=tensor(-504653.)
nll=tensor(504653.)
2.2119739055633545


In [120]:
# Increasing context to have a probability of char following a pair improves loss.

In [26]:
# create the training set of trigrams (x,y)
xs, ys = [], []

for w in words[:1]:
  chs = ['.', '.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = pairtoi[ch1+ch2]
    ix2 = stoi[ch3]
    print(ch1, ch2, ch3)
    xs.append(ix1)
    ys.append(ix2)
    
xs = torch.tensor(xs)
ys = torch.tensor(ys)

. . e
. e m
e m m
m m a
m a .


In [27]:
xs

tensor([  0,   5, 148, 364, 352])

In [28]:
ys

tensor([ 5, 13, 13,  1,  0])

In [29]:
import torch.nn.functional as F
xenc = F.one_hot(xs, num_classes=729).float()
xenc

tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [30]:
xenc.shape

torch.Size([5, 729])

In [31]:
xenc.dtype

torch.float32

In [32]:
W = torch.randn((729, 27))
xenc @ W

tensor([[ 0.3815,  0.0175,  0.0541, -0.0177, -2.5514,  1.3628,  0.7797,  0.7665,
         -1.1949, -0.6674, -0.7216,  0.0961,  1.6246,  1.2666,  0.7603,  0.6153,
          1.3104,  1.2735,  0.6359,  0.3995,  1.0126,  0.1056, -1.9424, -0.6486,
          1.3768,  0.8932, -1.0436],
        [-0.0555, -0.2170,  0.6552, -0.9323, -0.4071, -1.7409, -0.1845, -2.1906,
         -0.2675, -0.7234, -0.8160, -0.6237, -0.4340,  0.7525,  0.8160,  0.2127,
         -1.0352, -0.9936, -0.0191, -1.0491,  1.4477, -0.0094, -0.5521,  1.4184,
         -0.4808,  0.2031,  1.0628],
        [-1.3213,  0.5866, -0.6162,  1.0366, -1.5004,  0.2157,  0.8076,  0.1225,
         -0.5910,  0.7536,  0.9210, -0.3806, -0.5538,  1.2908, -0.9108, -0.4861,
         -0.1263,  0.0079, -0.3208, -1.1787, -0.5603,  0.9574, -0.5309, -0.4200,
         -0.2792,  1.0428,  2.7781],
        [-0.2687, -0.4145,  0.3681,  0.1470, -0.6927, -0.3118, -1.4349, -0.5806,
         -0.4621, -0.3751, -0.5762, -0.7392,  0.2264,  1.0696,  0.5438,  0.3717

In [33]:
logits = xenc @ W # log-counts
counts = logits.exp() # equivalent N
probs = counts / counts.sum(1, keepdims=True)
probs

tensor([[0.0294, 0.0204, 0.0212, 0.0197, 0.0016, 0.0784, 0.0438, 0.0432, 0.0061,
         0.0103, 0.0098, 0.0221, 0.1019, 0.0713, 0.0429, 0.0372, 0.0744, 0.0717,
         0.0379, 0.0299, 0.0553, 0.0223, 0.0029, 0.0105, 0.0796, 0.0491, 0.0071],
        [0.0304, 0.0259, 0.0619, 0.0127, 0.0214, 0.0056, 0.0267, 0.0036, 0.0246,
         0.0156, 0.0142, 0.0172, 0.0208, 0.0682, 0.0727, 0.0398, 0.0114, 0.0119,
         0.0315, 0.0113, 0.1368, 0.0319, 0.0185, 0.1328, 0.0199, 0.0394, 0.0931],
        [0.0055, 0.0372, 0.0112, 0.0583, 0.0046, 0.0256, 0.0463, 0.0234, 0.0114,
         0.0439, 0.0519, 0.0141, 0.0119, 0.0751, 0.0083, 0.0127, 0.0182, 0.0208,
         0.0150, 0.0064, 0.0118, 0.0538, 0.0122, 0.0136, 0.0156, 0.0586, 0.3325],
        [0.0196, 0.0169, 0.0370, 0.0297, 0.0128, 0.0188, 0.0061, 0.0143, 0.0161,
         0.0176, 0.0144, 0.0122, 0.0321, 0.0746, 0.0441, 0.0371, 0.0048, 0.0519,
         0.0649, 0.0313, 0.0637, 0.2953, 0.0114, 0.0042, 0.0132, 0.0086, 0.0471],
        [0.0098, 0.0320,

In [34]:
probs[0]

tensor([0.0294, 0.0204, 0.0212, 0.0197, 0.0016, 0.0784, 0.0438, 0.0432, 0.0061,
        0.0103, 0.0098, 0.0221, 0.1019, 0.0713, 0.0429, 0.0372, 0.0744, 0.0717,
        0.0379, 0.0299, 0.0553, 0.0223, 0.0029, 0.0105, 0.0796, 0.0491, 0.0071])

In [35]:
probs[0].shape

torch.Size([27])

In [36]:
probs[0].sum()

tensor(1.0000)

In [37]:
# (5, 27) @ (27, 27) -> (5, 27)

In [38]:
# SUMMARY ------------------------------>>>>

In [39]:
xs

tensor([  0,   5, 148, 364, 352])

In [40]:
ys

tensor([ 5, 13, 13,  1,  0])

In [41]:
# randomly initialize 27 neurons' weights. each neuron receives 729 inputs (all possible pairs)
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g)

In [42]:
xenc = F.one_hot(xs, num_classes=729).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
# btw: the last 2 lines here are together called a 'softmax'

In [43]:
probs.shape

torch.Size([5, 27])

In [44]:

nlls = torch.zeros(5)
for i in range(5):
  # i-th bigram:
  x = xs[i].item() # input character index
  y = ys[i].item() # label character index
  print('--------')
  print(f'bigram example {i+1}: {itopair[x]}{itos[y]} (indexes {x},{y})')
  print('input to the neural net:', x)
  print('output probabilities from the neural net:', probs[i])
  print('label (actual next character):', y)
  p = probs[i, y]
  print('probability assigned by the net to the the correct character:', p.item())
  logp = torch.log(p)
  print('log likelihood:', logp.item())
  nll = -logp
  print('negative log likelihood:', nll.item())
  nlls[i] = nll

print('=========')
print('average negative log likelihood, i.e. loss =', nlls.mean().item())

--------
bigram example 1: ..e (indexes 0,5)
input to the neural net: 0
output probabilities from the neural net: tensor([0.0607, 0.0100, 0.0123, 0.0042, 0.0168, 0.0123, 0.0027, 0.0232, 0.0137,
        0.0313, 0.0079, 0.0278, 0.0091, 0.0082, 0.0500, 0.2378, 0.0603, 0.0025,
        0.0249, 0.0055, 0.0339, 0.0109, 0.0029, 0.0198, 0.0118, 0.1537, 0.1459])
label (actual next character): 5
probability assigned by the net to the the correct character: 0.01228625513613224
log likelihood: -4.399273872375488
negative log likelihood: 4.399273872375488
--------
bigram example 2: .em (indexes 5,13)
input to the neural net: 5
output probabilities from the neural net: tensor([0.0290, 0.0796, 0.0248, 0.0521, 0.1989, 0.0289, 0.0094, 0.0335, 0.0097,
        0.0301, 0.0702, 0.0228, 0.0115, 0.0181, 0.0108, 0.0315, 0.0291, 0.0045,
        0.0916, 0.0215, 0.0486, 0.0300, 0.0501, 0.0027, 0.0118, 0.0022, 0.0472])
label (actual next character): 13
probability assigned by the net to the the correct character: 

In [45]:
# --------- !!! OPTIMIZATION !!! yay --------------

In [46]:
xs

tensor([  0,   5, 148, 364, 352])

In [47]:
ys

tensor([ 5, 13, 13,  1,  0])

In [48]:
# randomly initialize 27 neurons' weights. each neuron receives 27 inputs
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g, requires_grad=True)

In [61]:
# forward pass
xenc = F.one_hot(xs, num_classes=729).float() # input to the network: one-hot encoding
logits = xenc @ W # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
loss = -probs[torch.arange(5), ys].log().mean()

In [62]:
print(loss.item())

4.170337677001953


In [63]:
# backward pass
W.grad = None # set to zero the gradient
loss.backward()

In [64]:
W.data += -0.1 * W.grad

In [606]:
# --------- !!! OPTIMIZATION !!! yay, but this time actually --------------

In [65]:
# create the dataset
xs, ys = [], []
for w in words:
  chs = ['.', '.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = pairtoi[ch1+ch2]
    ix2 = stoi[ch3]
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)

# initialize the 'network'
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((729, 27), generator=g, requires_grad=True)

number of examples:  228146


In [66]:
# gradient descent
for k in range(120):
  
  # forward pass
  xenc = F.one_hot(xs, num_classes=729).float() # input to the network: one-hot encoding
  logits = xenc @ W # predict log-counts
  counts = logits.exp() # counts, equivalent to N
  probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
  loss = -probs[torch.arange(num), ys].log().mean() #+ 0.01*(W**2).mean()
  print(f'{k=}, {loss.item()}')
  
  # backward pass
  W.grad = None # set to zero the gradient
  loss.backward()
  
  # update
  W.data += -50 * W.grad

k=0, 3.7927768230438232
k=1, 3.6387429237365723
k=2, 3.5469932556152344
k=3, 3.4792749881744385
k=4, 3.4233877658843994
k=5, 3.374289035797119
k=6, 3.330148458480835
k=7, 3.28998064994812
k=8, 3.2531771659851074
k=9, 3.219308614730835
k=10, 3.18802547454834
k=11, 3.159024238586426
k=12, 3.13203501701355
k=13, 3.106820821762085
k=14, 3.0831780433654785
k=15, 3.0609352588653564
k=16, 3.039947509765625
k=17, 3.020094394683838
k=18, 3.0012738704681396
k=19, 2.9833996295928955
k=20, 2.9663970470428467
k=21, 2.9502012729644775
k=22, 2.93475604057312
k=23, 2.920011043548584
k=24, 2.905919313430786
k=25, 2.8924405574798584
k=26, 2.8795382976531982
k=27, 2.8671772480010986
k=28, 2.8553264141082764
k=29, 2.843956470489502
k=30, 2.833040952682495
k=31, 2.8225536346435547
k=32, 2.812471628189087
k=33, 2.80277156829834
k=34, 2.79343318939209
k=35, 2.7844369411468506
k=36, 2.7757630348205566
k=37, 2.7673940658569336
k=38, 2.7593135833740234
k=39, 2.7515058517456055
k=40, 2.743955612182617
k=41, 2.73

In [74]:
# The loss for nn trigram model is about the same as for bigram. Quality of generation is not much better.
# NN also can't achieve 2.21 trigram statistical model result. Counting model has exact answers,
# it counted trigrams. On the other hand with nn we are trying to learn these counts from data using gradient
# descent.

In [71]:
# finally, sample from the 'neural net' model
g = torch.Generator().manual_seed(2147483647 + 191)

for _ in range(5):
  
  out = ['.']
  i = 0
  while True:
    
    # ----------
    # BEFORE:
    #p = P[ix]
    # ----------
    # NOW:
    xenc = F.one_hot(torch.tensor([i]), num_classes=729).float()
    logits = xenc @ W # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    p = counts / counts.sum(1, keepdims=True) # probabilities for next character
    # ----------
    
    j = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
    out.append(itos[j])
    if j == 0:
      break
    # update index i
    pair = ''.join(out[-2:])  # last 2 chars
    i = pairtoi[pair]
  print(''.join(out[1:]))

jana.
szmjutson.
ar.
xwuxaimmarocu.
ruar.


In [72]:
W[0]

tensor([-2.7120,  2.1271,  0.9092,  1.0755,  1.1673,  1.0684, -0.2357,  0.2389,
         0.5069,  0.1146,  1.5275,  1.7292,  1.0948,  1.5743,  0.7783, -0.2927,
        -0.0236, -1.7571,  1.1366,  1.3630,  0.9107, -1.8276, -0.3398, -0.5439,
        -1.3649,  0.0147,  0.5680], grad_fn=<SelectBackward0>)

In [73]:
W[1]

tensor([-1.6692,  0.5492,  0.4784, -1.2648,  1.1573, -0.6648, -0.9993, -1.4115,
        -0.3538,  0.1504, -0.9233, -0.6285,  1.7133,  1.2062,  1.6988, -1.3918,
        -1.1454, -1.7814,  1.4383,  0.5018, -0.4341,  0.2430,  0.7190, -1.3506,
        -0.9531,  0.3832,  0.1936], grad_fn=<SelectBackward0>)