In [1]:
import torch
import matplotlib.pyplot as plt

In [2]:
words = open('names.txt','r').read().splitlines()

In [3]:
len(words)

32033

In [4]:
words[0]

'emma'

In [5]:
alphabets = sorted(list(set(''.join(words))))
stoi ={s:i+1 for i,s in enumerate(alphabets)}
stoi["."] = 0
itos = {i:s for s,i in stoi.items()}



In [6]:
N = torch.zeros((27,27))

In [7]:
#Counting

for w in words:
    ch = "." + w + "."
    for chr1,chr2 in zip(ch,ch[1:]):
        ix1 = stoi[chr1]
        ix2 = stoi[chr2]
        N[ix1,ix2] += 1

In [8]:
#Create probability distribution
P = N / N.sum(1,keepdim=True)
P.shape

torch.Size([27, 27])

In [9]:
# Sample

word = ""
ch = "."
while True:
    
    ix = stoi[ch]
    word += itos[torch.multinomial(P[ix],num_samples=1).item()]
    ch = word[-1]
    if ch == ".":
        break    
print(word)

zasesabll.


Tri-gram

In [10]:
string_pairs = [a+b for a in stoi.keys() for b in stoi.keys()]
print(f"{len(string_pairs) = }")
print(f"{string_pairs[0] = }")
print(f"{string_pairs[1] = }")

len(string_pairs) = 729
string_pairs[0] = 'aa'
string_pairs[1] = 'ab'


In [11]:
sptoi = {s:i for i,s in enumerate(string_pairs)}
itosp = {i:s for s,i in sptoi.items()}

In [12]:
tri_N = torch.zeros((729,27))

In [13]:
#Counting trigrams

for w in words:
    ch = ".." + w + "."
    for chr1,chr2,chr3 in zip(ch,ch[1:],ch[2:]):
        ix1 = sptoi[chr1+chr2]
        ix2 = stoi[chr3]
        tri_N[ix1,ix2] += 1

In [14]:
tri_N[sptoi["aa"],stoi["."]]

tensor(40.)

In [15]:
#Create probability distribution
tri_P = tri_N / tri_N.sum(1,keepdim=True)

In [16]:
# Sample

word = ".."
ch = ".."
while True:
    ix = sptoi[ch]
    word += itos[torch.multinomial(tri_P[ix],num_samples=1).item()]
    ch = word[-2:]
    if word[-1] == ".":
        break    
print(word)

..re.


Calculating loss

In [17]:

likelihood = 0
n =0 
for w in words:
    ch = "." + w +"."
    for chr1,chr2 in zip(ch,ch[1:]):
        ix1 = stoi[chr1]
        ix2 = stoi[chr2]
        likelihood += torch.log(P[ix1,ix2])
        n += 1
        nll = -likelihood

print(f"{nll=}")
print(f"{nll/n=}")

nll=tensor(559891.7500)
nll/n=tensor(2.4541)


In [18]:
trigram_likelihood = 0
n = 0

for w in words:
    ch = ".." + w +"."
    for chr1,chr2,chr3 in zip(ch,ch[1:],ch[2:]):
        ix1 = sptoi[chr1 + chr2]
        ix2 = stoi[chr3]
        trigram_likelihood += torch.log(tri_P[ix1,ix2])
        n += 1
        tri_nll = -trigram_likelihood

print(f"{tri_nll=}")
print(f"{tri_nll/n=}")


tri_nll=tensor(498647.7812)
tri_nll/n=tensor(2.1857)


In [19]:
#Generating the most likely name according to bigram and trigram models

for al in alphabets:
    ch = al
    word = ch
    while True:
        
        ix = stoi[ch]
        word += itos[torch.argmax(P[ix]).item()]
        ch = word[-1]
        if ch == ".":
            break    

    print(al,"->",word)


a -> a.
b -> bri.
c -> ca.
d -> da.
e -> e.
f -> fa.
g -> gh.
h -> h.
i -> i.
j -> ja.
k -> ka.
l -> le.
m -> ma.
n -> n.
o -> on.
p -> pa.
q -> qush.
r -> ri.
s -> sh.
t -> ta.
u -> ush.
v -> vi.
w -> wa.
x -> x.
y -> ya.
z -> za.


In [20]:
import torch.nn.functional as F
xs = []
ys = []
for w in words:
    ch = "." + w + "."
    for chr1,chr2 in zip(ch,ch[1:]):
        ix1 = stoi[chr1]
        ix2 = stoi[chr2]
        xs.append(ix1)
        ys.append(ix2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)



In [21]:
W = torch.randn((27,27),requires_grad=True)
x_enc = F.one_hot(xs,num_classes=27)
x_enc = x_enc.T.float()



In [22]:
counts = (W @ x_enc).exp()
probs = counts/counts.sum(dim=0,keepdim=True)
logits = probs[ys,torch.arange(0,len(ys))]
loss = -logits.log().mean()
loss


tensor(3.8105, grad_fn=<NegBackward0>)

In [23]:
#Training cycle

for i in range(500):
    counts = (W @ x_enc).exp()
    probs = counts/counts.sum(dim=0,keepdim=True)
    logits = probs[ys,torch.arange(0,len(ys))]
    loss = -logits.log().mean()

    W.grad = None

    loss.backward()

    W.data -= 10 * W.grad
    if i%50 == 0:
        print(loss)

    



tensor(3.8105, grad_fn=<NegBackward0>)
tensor(2.6702, grad_fn=<NegBackward0>)
tensor(2.5658, grad_fn=<NegBackward0>)
tensor(2.5278, grad_fn=<NegBackward0>)
tensor(2.5086, grad_fn=<NegBackward0>)
tensor(2.4969, grad_fn=<NegBackward0>)
tensor(2.4891, grad_fn=<NegBackward0>)
tensor(2.4835, grad_fn=<NegBackward0>)
tensor(2.4793, grad_fn=<NegBackward0>)
tensor(2.4760, grad_fn=<NegBackward0>)


In [24]:
#Sampling
ch = "."
word = ""
while True:
    
    x = F.one_hot(torch.tensor(stoi[ch]),num_classes=27).float()
    x = torch.reshape(x,(-1,1))
    N = (W @ x).exp()
    probs = N/N.sum(dim=0,keepdim=True)
    probs = probs.T
    pred = torch.multinomial(probs,num_samples=1)
    next_ch = itos[pred.item()]
    ch = next_ch
   

    if next_ch == ".":
        break
    word += next_ch

print(word)

ry


In [25]:
xs = []
ys = []
for w in words:
    ch = "." + w + "."
    for chr1,chr2 in zip(ch,ch[1:]):
        ix1 = stoi[chr1]
        ix2 = stoi[chr2]
        xs.append(ix1)
        ys.append(ix2)

xs = torch.tensor(xs)
ys = torch.tensor(ys)

In [79]:
W = torch.randn((27,27),requires_grad=True)


In [89]:
for i in range(5):
    counts = (W @ x_enc).exp()
    probs = counts/counts.sum(dim=0,keepdim=True)
    logits = probs[ys,torch.arange(0,len(ys))]
    loss = -logits.log().mean()

    W.grad = None

    loss.backward()

    W.data -= 10 * W.grad


In [90]:
counts = (W @ x_enc).exp()
probs = counts/counts.sum(dim=0,keepdim=True)
true_probs = P[ys].T

mse_loss = ((probs - true_probs)**2).mean()
rmse_loss = torch.sqrt(mse_loss)
rmse_loss

tensor(0.0616, grad_fn=<SqrtBackward0>)

In [91]:
W = torch.randn((27,27),requires_grad=True)

In [121]:
for i in range(10):
    counts = (W @ x_enc).exp()
    probs = counts/counts.sum(dim=0,keepdim=True)
    true_probs = P[ys].T

    mse_loss = ((probs - true_probs)**2).mean()
    rmse_loss = torch.sqrt(mse_loss)
    rmse_loss

    W.grad = None

    rmse_loss.backward()

    W.data -= 1000 * W.grad

In [122]:
counts = (W @ x_enc).exp()
probs = counts/counts.sum(dim=0,keepdim=True)
logits = probs[ys,torch.arange(0,len(ys))]
loss = -logits.log().mean()
loss

tensor(10.2226, grad_fn=<NegBackward0>)