In [9]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F

In [5]:
# Load one name per line, then count every bigram in the dataset, using
# '<S>' and '<E>' as explicit start/end-of-word markers.
# The `with` block guarantees the file handle is closed after reading.
with open('names.txt', 'r') as names_file:
    words = names_file.read().splitlines()

bigram_dict = {}
for word in words:
    chs = ['<S>'] + list(word) + ['<E>']
    for ch1, ch2 in zip(chs, chs[1:]):
        bigram = (ch1, ch2)
        bigram_dict[bigram] = bigram_dict.get(bigram, 0) + 1

In [7]:
# Build the character vocabulary: `stoi` encodes characters as integer
# indices for the model's input, and `itos` decodes indices back to
# characters for its output. Index 0 is reserved for '.', the token that
# marks both the start and the end of a word; real characters start at 1.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}

In [8]:
# A counting-based bigram table is limited, so reframe the same task for a
# neural network. As a small worked example, build (input, target) index
# pairs from the first word only: xs[i] is a character index and ys[i] is
# the index of the character that follows it.
xs = []
ys = []
for w in words[:1]:
    chs = ['.', *w, '.']  # '.' marks both the start and the end of the word
    for ch1, ch2 in zip(chs, chs[1:]):
        rowIdx, colIdx = stoi[ch1], stoi[ch2]
        xs.append(rowIdx)
        ys.append(colIdx)

xs, ys = torch.tensor(xs), torch.tensor(ys)

In [42]:
# Forward pass on the small example built above.
# One-hot encode the integer indices so they can go through a linear layer.
xenc = F.one_hot(xs, num_classes=27).float()  # network inputs, one row per example
yenc = F.one_hot(ys, num_classes=27).float()  # one-hot targets (the loss below indexes by ys directly)

g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)  # 27 features / 27 neurons

logits = xenc @ W  # interpreted as log-counts
# Softmax: exponentiate, then normalize each row into a probability distribution.
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
# Mean negative log-likelihood of the true next character.
# Index rows with the actual number of examples rather than a hard-coded 5,
# which silently breaks as soon as the example set changes size.
loss = -probs[torch.arange(ys.nelement()), ys].log().mean()

In [43]:
# Backward pass: drop any previous gradient so backprop starts clean, then
# propagate d(loss)/dW into W.grad. Autograd frees the graph after this, so
# re-running this cell requires re-running the forward pass first.
W.grad = None
torch.autograd.backward(loss)
float(loss)  # display the current loss value

3.7693049907684326

In [44]:
# Inspect the gradient filled in by the backward pass: each entry is the
# influence of that weight on the loss. Show only the first three rows.
W.grad[0:3]

tensor([[ 0.0121,  0.0020,  0.0025,  0.0008,  0.0034, -0.1975,  0.0005,  0.0046,
          0.0027,  0.0063,  0.0016,  0.0056,  0.0018,  0.0016,  0.0100,  0.0476,
          0.0121,  0.0005,  0.0050,  0.0011,  0.0068,  0.0022,  0.0006,  0.0040,
          0.0024,  0.0307,  0.0292],
        [-0.1970,  0.0017,  0.0079,  0.0020,  0.0121,  0.0062,  0.0217,  0.0026,
          0.0025,  0.0010,  0.0205,  0.0017,  0.0198,  0.0022,  0.0046,  0.0041,
          0.0082,  0.0016,  0.0180,  0.0106,  0.0093,  0.0062,  0.0010,  0.0066,
          0.0131,  0.0101,  0.0018],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000,  0.0000]])

In [45]:
# Take one gradient-descent step on the weights.
# The in-place update runs under torch.no_grad() so autograd does not record
# it; this is the recommended replacement for mutating `W.data` directly.
step_size = 0.1
with torch.no_grad():
    W -= step_size * W.grad

In [65]:
# Next, refactor the code above into a concise training loop
# and train the network on the entire dataset (every word, not just the first).

In [79]:
# Hyperparameters
learning_rate = 50  # gradient-descent step size used by the training loop


def build_bigram_dataset(word_list, char_to_idx):
    """Return (xs, ys) tensors of bigram (current char, next char) indices.

    Each word is wrapped in '.' on both sides so the model also learns how
    words start and end. xs[i] is the index of a character and ys[i] is the
    index of the character that follows it.
    """
    inputs, targets = [], []
    for w in word_list:
        chs = ['.'] + list(w) + ['.']
        for ch1, ch2 in zip(chs, chs[1:]):
            inputs.append(char_to_idx[ch1])
            targets.append(char_to_idx[ch2])
    return torch.tensor(inputs), torch.tensor(targets)


# Build the dataset from every word (the earlier example used only the first one).
xs, ys = build_bigram_dataset(words, stoi)
num = xs.nelement()
print('Number of examples: ', num)

# Model parameters: a single weight matrix (27 features / 27 neurons),
# re-initialized with a fixed seed for reproducibility.
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

Number of examples:  228146


In [80]:
# Full-batch gradient descent over the whole dataset.
# The one-hot inputs and the row indices never change between iterations,
# so compute them once instead of on every step.
xenc = F.one_hot(xs, num_classes=27).float()  # input to the network w/ one hot encoding
rows = torch.arange(num)

for k in range(100):
    # forward pass
    logits = xenc @ W  # predict log-counts
    counts = logits.exp()  # exponentiated logits, playing the role of bigram counts
    probs = counts / counts.sum(1, keepdim=True)  # probabilities for next character
    loss = -probs[rows, ys].log().mean()  # mean negative log-likelihood
    print(loss.item())

    # backward pass
    W.grad = None  # clear the gradient from the previous step
    loss.backward()

    # update: in-place step outside autograd tracking
    with torch.no_grad():
        W -= learning_rate * W.grad


3.758953809738159
3.3710994720458984
3.1540415287017822
3.020372152328491
2.9277102947235107
2.8604013919830322
2.809727907180786
2.77010178565979
2.738072395324707
2.711496353149414
2.6890029907226562
2.6696884632110596
2.6529300212860107
2.638277053833008
2.6253881454467773
2.613990545272827
2.60386323928833
2.5948216915130615
2.5867116451263428
2.5794036388397217
2.572789430618286
2.5667762756347656
2.5612878799438477
2.5562586784362793
2.551633596420288
2.547365665435791
2.5434155464172363
2.5397486686706543
2.5363364219665527
2.533154249191284
2.5301806926727295
2.5273966789245605
2.5247862339019775
2.522334575653076
2.520029067993164
2.5178580284118652
2.515810489654541
2.513878345489502
2.512052059173584
2.510324001312256
2.5086867809295654
2.5071346759796143
2.5056614875793457
2.504261016845703
2.5029289722442627
2.5016605854034424
2.5004522800445557
2.4992990493774414
2.498197317123413
2.497144937515259
2.4961376190185547
2.495173692703247
2.4942493438720703
2.493363380432129
