In [None]:
words = open("names.txt", "r").read().splitlines()
len(min(words, key=len)), len(max(words, key=len))

In [None]:
import torch

N = torch.zeros((28, 28), dtype=torch.int32) # we use int32 because we want to count the number of occurences

b = {}
for w in words:
    # adding special character <S> and <E> to indicate start and end of a word
    chars = ['<S>'] + list(w) + ['<E>']
    for char1, char2 in zip(chars, chars[1:]):
        bigram = char1 , char2
        b[bigram] = b.get(bigram, 0) + 1 # increment for each occurence of bigram

sorted(b.items(), key=lambda x: x[1], reverse=True)
b

In [None]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1  for i, s in enumerate(chars)} # character to index dictionary
stoi['.'] = 0
stoi

In [None]:
import torch

N = torch.zeros((27, 27),dtype=torch.int32) # we use int32 because we want to count the number of occurences

In [None]:
for w in words:
    chars = ['.'] + list(w) + ['.']
    for char1, char2 in zip(chars, chars[1:]):    
        ix1, ix2 = stoi[char1], stoi[char2] # get index of each character tuple
        N[ix1, ix2] += 1 # increment for each occurence of bigram

In [None]:
itos = {i:s for s, i in stoi.items()} # index to character dictionary

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline

plt.figure(figsize=(16, 16))
plt.imshow(N, cmap='Blues')
for i in range(27):
    for j in range(27):
        chstr = itos[i] + itos[j]
        # display bigram bottom and its count top
        plt.text(j, i, chstr, ha='center', va='bottom', color='gray')
        plt.text(j, i, N[i, j].item(), ha='center', va='top', color='gray')
plt.axis('off')

In [None]:
p = N[0].float() 
p = p / p.sum()

In [None]:
g = torch.Generator().manual_seed(2147483647 + 1)
ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
itos[ix]

In [None]:
P = (N + 1).float() # add-one smoothing
print(P.sum(1, keepdim=True))
P /= P.sum(1, keepdim=True)


In [None]:
g = torch.Generator().manual_seed(2147483647)
# the model give us better result that a random model but it's still not good
for i in range(10):
    ix = 0 # P[0] is the distribution of the first bigram
    out = []
    while True:
        # next character is sampled from the distribution of the current bigram
        ix = torch.multinomial(P[ix], num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0: break # if bigram end with the special character then stop
    print(''.join(out))


In [None]:
# log likelihood is the log of the product of the probabilities of each bigram)
# and we know that log(a*b*c) = log(a) + log(b) + log(c)
log_likelihood = 0
n = 0
for w in words:
    chars = ['.'] + list(w) + ['.']
    for char1, char2 in zip(chars, chars[1:]):    
        ix1, ix2 = stoi[char1], stoi[char2] # get index of each character tuple
        prob = P[ix1, ix2]
        logProb = torch.log(prob)
        log_likelihood += logProb
        n += 1
# problem is : we want the loss function to be minimized, but the log likelihood should be maximized
nll = -log_likelihood # negative log likelihood
print(f'{nll=}') 
print(f'{nll/n=}') # normalized negative log likelihood, average instead of sum


In [None]:
# How to cast the bigram model as a neural network?

In [768]:
# create the training set of bigrams

xs, ys = [], [] # input and target

for w in words:
    chars = ['.'] + list(w) + ['.']
    for char1, char2 in zip(chars, chars[1:]):    
        ix1, ix2 = stoi[char1], stoi[char2]
        xs.append(ix1)
        ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
xs.shape

torch.Size([228146])

In [811]:
import torch.nn.functional as F

W = torch.randn((27, 27), requires_grad=True, generator=g) # first dim is the number of classes, second dim is the number of neurons
stepSize = 2000
for i in range(stepSize):
    # forward pass
    xenc = F.one_hot(xs, num_classes=27).float()
    logits = (xenc @ W) # log-counts, we consider that these are counts but in log space(which allow us to use negative values)
    # since we are using log-counts, we need to use exp() to get the counts
    counts = (xenc @ W).exp() # bringing negative values back to positive
    prob = counts / counts.sum(1, keepdim=True) # normalize to get probabilities
    # two line above are equivalent to softmax
    loss = -prob[torch.arange(len(xs)), ys].log().mean() # using torch advanced indexing
    if i % 100 == 0: print(f'Step: {i + 1}, loss: {loss}')
    # backward pass
    W.grad = None
    loss.backward()

    # update
    lr = 2
    W.data -= lr * W.grad.data
print(f'loss: {loss:.2f}')

Step: 1, loss: 3.769151449203491
Step: 101, loss: 2.912829875946045
Step: 201, loss: 2.720900535583496
Step: 301, loss: 2.643017530441284
Step: 401, loss: 2.599700927734375
Step: 501, loss: 2.571789503097534
Step: 601, loss: 2.5523715019226074
Step: 701, loss: 2.5380971431732178
Step: 801, loss: 2.5271337032318115
Step: 901, loss: 2.5184268951416016
Step: 1001, loss: 2.511338233947754
Step: 1101, loss: 2.505459785461426
Step: 1201, loss: 2.500516653060913
Step: 1301, loss: 2.496314764022827
Step: 1401, loss: 2.4927122592926025
Step: 1501, loss: 2.489600896835327
Step: 1601, loss: 2.486896514892578
Step: 1701, loss: 2.4845306873321533
Step: 1801, loss: 2.4824490547180176
Step: 1901, loss: 2.4806056022644043
loss: 2.48


In [810]:
# for step size = 2000
# lr : 0.5 -> loss : 2.57
# lr : 0.75 -> loss : 2.53
# lr : 1 -> loss : 2.51
# lr : 1.25 -> loss : 2.50
# lr : 1.5 -> loss : 2.49 (bellow 2.5 between 1501 and 1601)
# lr : 2.0 -> loss : 2.48 (bellow 2.5 between 1201 and 1301)

In [None]:
# We interpret prob as the following: each row is a distribution of the next character given the current character
# We have a way to mesure the quality of this distribution: the negative log libkelihood
# During the backpropagation, we are fine tuning the weights of the neural network to minimize the negative log likelihood

In [803]:
for i in range(10):
    ix = 0
    out = []
    while True:
        ix = torch.multinomial(prob[ix], num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0: break 
    print(''.join(out))


dytiymsvezturrj.
lt.
mfkeo.
ear.
hoptlai.
in.
j.
lnalaveamn.
tlavea.
md.
