In [1]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
words = open("names.txt").read().split("\n")

### Trigram count-based model

In [9]:
trigrams = {}

for w in words:
    w = "." + w + "."
    for ch in zip(w, w[1:], w[2:]):
        trigrams[ch] = trigrams.get(ch, 0) + 1

In [25]:
# 1 for smoothing
p = torch.zeros(27, 27, 27) + 1
p.shape, p.dtype

(torch.Size([27, 27, 27]), torch.float32)

In [30]:
chars = sorted(["."] + list(set("".join(words))))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

In [31]:
# populate the trigram
for key in trigrams:
    i0, i1, i2 = stoi[key[0]], stoi[key[1]], stoi[key[2]]
    p[i0, i1, i2] = trigrams[key]
p

tensor([[[  1.,   1.,   1.,  ...,   1.,   1.,   1.],
         [  1., 207., 190.,  ...,  27., 173., 152.],
         [  1., 169.,   1.,  ...,   1.,   4.,   1.],
         ...,
         [  1.,  57.,   1.,  ...,   1.,  17.,  11.],
         [  1., 246.,   1.,  ...,   1.,   1.,   2.],
         [  1., 456.,   1.,  ...,   1.,  91.,   1.]],

        [[  1.,   1.,   1.,  ...,   1.,   1.,   1.],
         [ 40.,   1.,   5.,  ...,   1.,  20.,  11.],
         [ 36.,  28.,  20.,  ...,   1.,  12.,   1.],
         ...,
         [ 11.,   5.,   1.,  ...,  17.,   6.,   3.],
         [163., 389.,  13.,  ...,   1.,  16.,  40.],
         [ 38., 123.,   1.,  ...,   1.,  12.,  22.]],

        [[  1.,   1.,   1.,  ...,   1.,   1.,   1.],
         [ 46.,   5.,   5.,  ...,   4.,  31.,   4.],
         [  1.,   8.,   1.,  ...,   1.,   9.,   1.],
         ...,
         [  1.,   1.,   1.,  ...,   1.,   1.,   1.],
         [ 55.,   4.,   1.,  ...,   1.,   1.,   1.],
         [  1.,   1.,   1.,  ...,   1.,   1.,   1.]],

In [32]:
P = p / p.sum(dim=2, keepdim=True)

In [108]:
probs_given_an = P[stoi["a"], stoi['n']]
probs_given_an # probability distribution given "a" followed by "n"

tensor([2.7749e-01, 1.4785e-01, 3.6778e-04, 1.5263e-02, 7.2085e-02, 4.1192e-02,
        1.1033e-03, 2.0228e-02, 1.6550e-03, 1.2928e-01, 4.0456e-03, 6.4362e-03,
        2.5745e-03, 2.2067e-03, 1.5171e-01, 1.8021e-02, 5.5167e-04, 1.8389e-04,
        1.4711e-03, 1.8205e-02, 3.2549e-02, 9.1946e-03, 9.0107e-03, 7.3556e-04,
        3.6778e-04, 3.2181e-02, 4.0456e-03])

In [117]:
g = torch.Generator().manual_seed(1)
idx = torch.multinomial(probs_given_an, num_samples=1, replacement=True, generator=g)
idx.item(), itos[idx.item()], probs_given_an[idx.item()] # 15% chance that "n" comes after "an", i.e. "ann"

(0, '.', tensor(0.2775))

In [131]:
# sampling from the trigram model
g = torch.Generator().manual_seed(681236)
for i in range(5):
    # start at ".", "." -> this is made possible because we increment the count by 1 of all combination of input tuples
    i0, i1 = 0, 0
    ans = []
    while True:
        probs = P[i0, i1]
        i0 = i1
        i1 = torch.multinomial(probs, num_samples=1, replacement=True, generator=g).item()
        ch = itos[i1]
        if ch == ".":
            break
        ans.append(ch)

    print("".join(ans))

ma
kodan
lisonnaelailcia
niyannaz
velietcgjwpxwbi


In [141]:
log_likelihood = 0
n = 0
for w in words:
    w = "." + w + "."
    for ch in zip(w, w[1:], w[2:]):
        prob = P[stoi[ch[0]], stoi[ch[1]], stoi[ch[2]]]
        log_likelihood += torch.log(prob)
        n += 1

nll = -log_likelihood / n # averaged negative log likehood

In [142]:
nll

tensor(2.0944)

### Trigram NN model

In [157]:
# encodes all possible permutations of two consecutive characters, e.g. [.a .b .c ... aa ab ac ... za zb zc ... zz]
i = 0
enc = {}
for ch1 in chars:
    for ch2 in chars:
        print(ch1, ch2, "|", stoi[ch1], stoi[ch2])
        enc[(stoi[ch1], stoi[ch2])] = i
        i += 1

. . | 0 0
. a | 0 1
. b | 0 2
. c | 0 3
. d | 0 4
. e | 0 5
. f | 0 6
. g | 0 7
. h | 0 8
. i | 0 9
. j | 0 10
. k | 0 11
. l | 0 12
. m | 0 13
. n | 0 14
. o | 0 15
. p | 0 16
. q | 0 17
. r | 0 18
. s | 0 19
. t | 0 20
. u | 0 21
. v | 0 22
. w | 0 23
. x | 0 24
. y | 0 25
. z | 0 26
a . | 1 0
a a | 1 1
a b | 1 2
a c | 1 3
a d | 1 4
a e | 1 5
a f | 1 6
a g | 1 7
a h | 1 8
a i | 1 9
a j | 1 10
a k | 1 11
a l | 1 12
a m | 1 13
a n | 1 14
a o | 1 15
a p | 1 16
a q | 1 17
a r | 1 18
a s | 1 19
a t | 1 20
a u | 1 21
a v | 1 22
a w | 1 23
a x | 1 24
a y | 1 25
a z | 1 26
b . | 2 0
b a | 2 1
b b | 2 2
b c | 2 3
b d | 2 4
b e | 2 5
b f | 2 6
b g | 2 7
b h | 2 8
b i | 2 9
b j | 2 10
b k | 2 11
b l | 2 12
b m | 2 13
b n | 2 14
b o | 2 15
b p | 2 16
b q | 2 17
b r | 2 18
b s | 2 19
b t | 2 20
b u | 2 21
b v | 2 22
b w | 2 23
b x | 2 24
b y | 2 25
b z | 2 26
c . | 3 0
c a | 3 1
c b | 3 2
c c | 3 3
c d | 3 4
c e | 3 5
c f | 3 6
c g | 3 7
c h | 3 8
c i | 3 9
c j | 3 10
c k | 3 11
c l | 3 12
c m | 

In [160]:
# 26, 26 corresponds to zz
enc[(26, 26)], len(enc)

(728, 729)

In [184]:
# prepare training data
xs, ys = [], []
for w in words:
    w = "." + w + "."
    for ch in zip(w, w[1:], w[2:]):
        x = (stoi[ch[0]], stoi[ch[1]])
        y = stoi[ch[2]]

        # use the encoding for x to generate 729 (27^2) unique indexes (each representing a unique permutation of two characters)
        xs.append(enc[x])
        ys.append(y)

In [185]:
xs[:10], ys[:10]

([5, 148, 364, 352, 15, 417, 333, 265, 603, 244],
 [13, 13, 1, 0, 12, 9, 22, 9, 1, 0])

In [186]:
len(xs), len(ys)

(196113, 196113)

In [187]:
xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [188]:
import torch.nn.functional as F

In [176]:
W = torch.randn(27 * 27, 27).float()

one_hot_enc = F.one_hot(xs, num_classes=len(enc)).float() # (196113, 729)
logits = one_hot_enc @ W                          # (196113, 27)
counts = logits.exp() 
probs = counts / counts.sum(1, keepdim=True)
probs


tensor([[0.0318, 0.0544, 0.0488,  ..., 0.0359, 0.0234, 0.0115],
        [0.0958, 0.0978, 0.0305,  ..., 0.0198, 0.0120, 0.0403],
        [0.0126, 0.0046, 0.0515,  ..., 0.0097, 0.0062, 0.0198],
        ...,
        [0.0936, 0.0121, 0.0714,  ..., 0.0120, 0.0382, 0.0205],
        [0.0209, 0.0107, 0.0017,  ..., 0.1180, 0.0145, 0.0256],
        [0.0818, 0.0179, 0.0247,  ..., 0.0314, 0.0452, 0.0202]])

In [180]:
probs[0].shape, len(probs), probs[0].sum()

(torch.Size([27]), 196113, tensor(1.))

In [181]:
nll = -probs[torch.arange(len(xs)), ys].log().mean()
nll

tensor(3.7848)

In [193]:
g = torch.Generator().manual_seed(123)
W = torch.randn(27 * 27, 27, requires_grad=True, generator=g).float()

In [194]:
# gradient descent
for _ in range(100):
    xenc = F.one_hot(xs, num_classes=len(enc)).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(xs.nelement()), ys].log().mean() 
    print(loss.item())

    # backpass
    W.grad = None
    loss.backward()

    # update params
    W.data += -10 * W.grad



3.738034248352051
3.722470998764038
3.7070956230163574
3.691904067993164
3.6768932342529297
3.6620590686798096
3.647399663925171
3.6329126358032227
3.618595838546753
3.6044468879699707
3.5904648303985596
3.576647996902466
3.5629961490631104
3.5495083332061768
3.5361838340759277
3.523022174835205
3.510023593902588
3.497187614440918
3.4845142364501953
3.472003936767578
3.4596571922302246
3.4474737644195557
3.4354536533355713
3.423597574234009
3.411905288696289
3.400377035140991
3.3890130519866943
3.377812147140503
3.3667750358581543
3.3559012413024902
3.3451900482177734
3.3346402645111084
3.324252128601074
3.314023017883301
3.303952932357788
3.2940409183502197
3.2842843532562256
3.274681806564331
3.2652313709259033
3.255932092666626
3.246781826019287
3.2377774715423584
3.2289178371429443
3.220200538635254
3.211623191833496
3.203183174133301
3.1948790550231934
3.1867077350616455
3.178666830062866
3.1707541942596436
3.1629679203033447
3.1553046703338623
3.1477630138397217
3.140340089797973