In [9]:
import matplotlib.pyplot as plt
import torch
%matplotlib inline

In [5]:
# Load the dataset: one name per line. The context manager closes the file
# handle as soon as its contents have been read.
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

In [None]:
# Peek at the first ten entries of the dataset.
words[:10]

In [None]:
# Number of names in the dataset.
len(words)

In [None]:
# Shortest and longest name length in the dataset.
name_lengths = [len(word) for word in words]
min(name_lengths), max(name_lengths)

In [None]:
# Count how often each adjacent (char, next_char) pair occurs across all names.
# '<S>' and '<E>' are sentinel tokens marking the start and the end of a name.
b = {}
for w in words:
    tokens = ['<S>', *w, '<E>']
    for bigram in zip(tokens, tokens[1:]):
        b[bigram] = b.get(bigram, 0) + 1

In [None]:
# Bigrams ranked from most to least frequent. Negating the count sorts in
# descending order while leaving ties in their original (insertion) order.
sorted(b.items(), key=lambda item: -item[1])

In [None]:
# 27x27 integer matrix that will hold the bigram counts.
N = torch.zeros(27, 27, dtype=torch.int32)

In [7]:
# Vocabulary lookup tables.
# stoi: character -> integer index (characters are numbered from 1 in sorted
#       order; index 0 is reserved for the '.' boundary token).
# itos: the inverse mapping, integer index -> character.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}

In [None]:
# Inspect the index -> character lookup table.
itos

In [None]:
# Fill the bigram count matrix: N[i, j] is the number of times the character
# with index j directly follows the character with index i. Each name is
# wrapped in '.' tokens so that starts and ends of names are counted too.
# N is rebuilt from zeros here so that re-running this cell does not
# accumulate the counts a second time.
N = torch.zeros((27, 27), dtype=torch.int32)
for w in words:
    chs = ['.'] + list(w) + ['.']
    for char1, char2 in zip(chs, chs[1:]):
        ix1 = stoi[char1]
        ix2 = stoi[char2]
        N[ix1, ix2] += 1  # number of occurrences of the pair (ix1, ix2)

In [None]:
# Heat-map of the bigram count matrix. Every cell is annotated with its
# two-character bigram and with the corresponding count.
plt.figure(figsize=(15, 15))
plt.imshow(N, cmap='Blues')
label_style = dict(ha="center", color="gray")
for row in range(27):
    for col in range(27):
        plt.text(col, row, itos[row] + itos[col], va="bottom", **label_style)
        plt.text(col, row, N[row, col].item(), va="top", **label_style)
plt.axis('off');

In [None]:
# Row 0 of the count matrix: counts of the characters that follow '.' (index 0).
N[0]

In [None]:
# Probability distribution for row 0: the counts of characters that follow
# the '.' token, converted to float and normalised to sum to 1.
p = N[0].float()
p = p / p.sum()
p

In [None]:
# Sanity check: the entries of a normalised probability distribution sum to 1.
sum(p)

In [None]:
# Toy example: a seeded generator makes the random draw reproducible.
# Three random numbers are normalised into a 3-way probability distribution.
g = torch.Generator()
g.manual_seed(2147483647)
p = torch.rand(3, generator=g)
p = p / p.sum()
p

In [None]:
# Draw one character index from the distribution over characters that follow
# '.' (row 0 of N) and map it back to a character.
# The distribution is rebuilt under its own name because `p` may have been
# rebound to a tensor of a different size by another cell, in which case
# `itos[ix]` would no longer refer to a character of the vocabulary.
g = torch.Generator().manual_seed(2147483647)
p_first = N[0].float()
p_first /= p_first.sum()
ix = torch.multinomial(p_first, num_samples=1, replacement=True, generator=g).item()
itos[ix]

In [None]:
# Draw 100 indices, with replacement, from the distribution p using generator g.
torch.multinomial(p, num_samples = 100, replacement = True, generator=g)

In [None]:
# Probability matrix (not yet normalised): bigram counts as floats, with 1
# added to every count so that no bigram ends up with zero probability.
P = (N + 1).float()

In [None]:
# Turn every row of P into a probability distribution by dividing it by its
# own sum; keepdim=True keeps the row sums as a column so they broadcast per row.
P = P / P.sum(dim=1, keepdim=True)
P[0]

In [None]:
# Sample four names from the bigram model.
g = torch.Generator().manual_seed(2147483647)
for i in range(4):
    out = []
    ix = 0  # every name starts from the '.' boundary token (index 0)
    while True:
        p = P[ix]  # distribution over the next character given the current one
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:  # '.' was drawn again: the name is complete
            break
    print(''.join(out))

In [None]:
# Score the bigram model on the whole dataset.
# Sums the log-probability that P assigns to every bigram, then reports the
# log-likelihood, its negative (nll) and the nll averaged over all bigrams.
# Maximising the likelihood == maximising the log-likelihood
#   == minimising the nll == minimising the average nll.
log_likelihood = 0.0
n = 0
for w in words:
    chs = ['.', *w, '.']
    for first, second in zip(chs, chs[1:]):
        log_likelihood += torch.log(P[stoi[first], stoi[second]])
        n += 1
print(f'{log_likelihood=}')
nll = -log_likelihood
print(f'{nll=}')
avg_nll = nll/n
print(f'{avg_nll=:.4f}')

In [None]:
# Neural-network approach: build a tiny training set from the first name only.
# xs holds the index of each input character, ys the index of the character
# that is expected to follow it.
xs, ys = [], []
for w in words[:1]:
    chs = ['.'] + list(w) + ['.']
    for ch_in, ch_out in zip(chs, chs[1:]):
        xs.append(stoi[ch_in])
        ys.append(stoi[ch_out])

xs, ys = torch.tensor(xs), torch.tensor(ys)

In [None]:
# Show the input indices (xs) next to the expected next-character indices (ys).
print(xs, ys)

In [10]:
import torch.nn.functional as F

In [None]:
# One-hot encode the inputs over 27 classes and cast to float32 so the
# encoding can be matrix-multiplied with float weights.
xenc = F.one_hot(xs, 27).to(torch.float32)
xenc

In [None]:
# Confirm the shape and dtype of the one-hot encoding.
xenc.shape, xenc.dtype

In [None]:
# Single neuron: a (27, 1) weight matrix maps every 27-dim one-hot row to one
# activation. Only the last expression of a cell is displayed, so this
# intermediate product is computed but not shown.
W = torch.randn((27,1)) # one neuron with 27 inputs
xenc @ W
# 27 neurons: W is rebound to a (27, 27) matrix that tracks gradients.
W = torch.randn((27,27), generator = g, requires_grad = True) # 27 neurons with 27 inputs each
xenc @ W # (5,27) @ (27,27) = (5,27) -- assumes xs holds five examples
# Each output row holds the activations of the 27 neurons for one input.
# Because the inputs are one-hot encoded, every product w*x except the one for
# the encoded character is zero, so sum(w*x) simply picks out the row of W
# that belongs to that character.

In [None]:
# Forward pass
xenc = F.one_hot(xs, num_classes=27).float()

logits = xenc @ W
# The next two lines together are the softmax: exponentiate, then normalise
# every row so that it sums to 1.
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
# Pick, for every example, the probability assigned to its expected next
# character, then average the negative logs to get the mean nll. The row index
# is derived from xs so the cell works for any number of examples.
loss = -probs[torch.arange(xs.nelement()), ys].log().mean()

In [None]:
# Current loss as a plain Python number.
loss.item()

In [None]:
# Sanity check: the first row of probs should sum to 1; also show the shape of probs.
probs[0].sum(), probs.shape

In [None]:
# For every example we want the probability the model assigns to the *expected
# output* character, i.e. probs[row, ys[row]].
# NOTE(review): the column indices below are hard-coded -- presumably they
# match ys for the first word; verify against the printed ys.
probs[0,5], probs[1,13], probs[2,13], probs[3,1], probs[4,0]
# The same lookup can be done in one step by indexing probs with a
# torch.arange row index together with ys. TODO: revise this cell.

In [None]:
# Backward pass
W.grad = None    # drop any gradient left over from an earlier pass
loss.backward()  # backpropagate from the loss; W.grad is read by the update step

In [None]:
# Update: take one gradient-descent step with step size 0.01.
# The weights are adjusted in place through W.data so that W stays the same
# gradient-tracking tensor; rebinding W to -0.01 * W.grad would throw the
# weights away and keep only the scaled gradient.
W.data += -0.01 * W.grad

In [None]:
# Full optimization run: build the dataset, initialise the weights, then run gradient descent

In [11]:
# Build the full training set: one (input, expected output) index pair for
# every bigram of every name, with '.' marking the start and end of a name.
xs, ys = [], []
for w in words:
    chs = ['.'] + list(w) + ['.']
    for prev_ch, next_ch in zip(chs, chs[1:]):
        xs.append(stoi[prev_ch])
        ys.append(stoi[next_ch])
xs, ys = torch.tensor(xs), torch.tensor(ys)
num = xs.nelement()
print("number of examples: ", num)

# Initialise the network: a single 27x27 weight matrix drawn from a seeded
# generator, with gradient tracking enabled.
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

number of examples:  228146


In [72]:
# Gradient descent: ten steps over the full dataset.
# The one-hot encoding depends only on xs, so it is built once, outside the loop.
xenc = F.one_hot(xs, num_classes=27).float()
for k in range(10):
    # forward pass: softmax over the logits
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    # mean negative log-likelihood plus a penalty on the mean squared weight
    loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
    print(loss.item())

    # backward pass
    W.grad = None
    loss.backward()

    # update: the step size starts at 5.00 and shrinks by 0.01 per iteration
    W.data += -((500-k)/100) * W.grad

2.5060033798217773
2.5059330463409424
2.5058624744415283
2.5057926177978516
2.505722999572754
2.5056540966033936
2.5055856704711914
2.5055174827575684
2.5054502487182617
2.505382537841797


In [73]:
# Sample 20 names from the trained network.
g = torch.Generator().manual_seed(2147483647)

for sample_idx in range(20):
    out = []
    ix = 0  # start from the '.' boundary token
    finished = False
    while not finished:
        # Forward pass for a single character: one-hot -> logits -> softmax.
        xenc = F.one_hot(torch.tensor([ix]), num_classes=27).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(dim=1, keepdim=True)
        # Draw the next character; drawing '.' (index 0) ends the name.
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        finished = ix == 0
    print(''.join(out))

junide.
janasah.
pxzfay.
a.
nn.
kohin.
tolian.
jgee.
ksaheiauyanilevias.
dbdainrwieta.
sejaielylarte.
faveumerifontume.
phynslenaruani.
core.
yaenon.
ka.
jabrinerimikimwynin.
anaasn.
ssorionszxh.
dgosfbrian.


In [None]:
# Inspect the first row of the most recently computed logits.
logits[0]