In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import defaultdict
import random
import pandas as pd
import matplotlib.pyplot as plt

In [None]:
# Count every adjacent character pair in names.txt. Each whitespace-separated
# word is framed as '^' + word + '$', so the table also records which letters
# start a word ('^x') and which letters end one ('x$').
bigrams = defaultdict(int)

with open("names.txt", "r") as names_file:
    for raw_line in names_file:
        for token in raw_line.split():
            framed = f"^{token}$"
            # zip(framed, framed[1:]) walks the consecutive pairs left to right
            for left, right in zip(framed, framed[1:]):
                bigrams[left + right] += 1

In [None]:
# Turn raw counts into relative frequencies: each bigram's share of all
# bigram occurrences counted above (the values sum to 1 over the whole table).
total_bigrams = sum(count for count in bigrams.values())
probabilities = {pair: bigrams[pair] / total_bigrams for pair in bigrams}

In [None]:
# One row per bigram, indexed by the bigram string, with a single
# 'Probability' column; the Styler shades cells by value for display.
df = (
    pd.DataFrame.from_dict(probabilities, orient='index', columns=['Probability'])
    .rename_axis('Bigram')
)
styled_df = df.style.background_gradient()
styled_df

In [None]:
Bigram 
^e 0.006711
em 0.003371
mm 0.000736
ma 0.011352
a$ 0.029104
^o 0.001727
ol 0.002713
li 0.010870
iv 0.001179
vi 0.003993
ia 0.010717
^a 0.019330
av 0.003656
va 0.002814
^i 0.002590
is 0.005768
sa 0.005264
ab 0.002371
be 0.002871
el 0.014236
ll 0.005895
la 0.011497
^s 0.009007
so 0.002327
op 0.000416
ph 0.000894
hi 0.003195
^c 0.006759
ch 0.002910
ha 0.009836
ar 0.014307
rl 0.001810
lo 0.003033
ot 0.000517
tt 0.001639
te 0.003138
e$ 0.017458
^m 0.011124
mi 0.005505
am 0.007162
me 0.003585
^h 0.003831
rp 0.000061
pe 0.000863
er 0.008582
r$ 0.006036
ev 0.002029
ve 0.002490
ly 0.006960
yn 0.008004
n$ 0.029643
bi 0.000951
ig 0.001876
ga 0.001446
ai 0.007232
il 0.005895
l$ 0.005759
y$ 0.008797
iz 0.001214
za 0.003770
et 0.002542
th 0.002836
h$ 0.010559
ry 0.003388
of 0.000149
fi 0.000701
ca 0.003572
ri 0.013294
sc 0.000263
le 0.012803
t$ 0.002117
^v 0.001648
ic 0.002231
ct 0.000153
to 0.002924
or 0.004642
ad 0.004567
di 0.002954
on 0.010568
^l 0.006890
lu 0.001420
un 0.001205
na 0.013049
^g 0.002932
gr 0.000881
ra 0.010327
ac 0.002060
ce 0.002415
hl 0.000811
oe 0.000579
^p 0.002257
en 0.011725
ne 0.005957
ay 0.008985
yl 0.004839
^r 0.007184
ey 0.004690
^z 0.004072
zo 0.000482
^n 0.005023
no 0.002174
ea 0.002976
an 0.023836
nn 0.008354
ah 0.010222
dd 0.000653
au 0.001670
ub 0.000451
br 0.003691
re 0.007438
ie 0.007245
st 0.003353
at 0.003011
ta 0.004502
al 0.011081
az 0.001907
ze 0.001635
io 0.002577
ur 0.001815
ro 0.003809
ud 0.000596
dr 0.001858
^b 0.005724
oo 0.000504
ok 0.000298
kl 0.000609
cl 0.000508
ir 0.003721
sk 0.000359
ky 0.001661
uc 0.000451
cy 0.000456
pa 0.000916
sl 0.001223
in 0.009319
ov 0.000771
ge 0.001464
es 0.003774
si 0.002998
s$ 0.005124
^k 0.012987
ke 0.003923
ed 0.001683
dy 0.001389
nt 0.001942
ya 0.009393
^w 0.001346
wi 0.000649
ow 0.000500
w$ 0.000224
ki 0.002231
ns 0.001219
ao 0.000276
om 0.001144
i$ 0.010910
aa 0.002437
iy 0.003414
de 0.005624
co 0.001666
ru 0.001105
by 0.000364
se 0.003875
ni 0.007561
it 0.002371
ty 0.001495
ut 0.000359
tu 0.000342
um 0.000675
mn 0.000088
gi 0.000833
ti 0.002332
^q 0.000403
qu 0.000903
ui 0.000530
ae 0.003033
eh 0.000666
vy 0.000530
pi 0.000267
ip 0.000232
yd 0.001192
ex 0.000579
xa 0.000451
^j 0.010616
jo 0.002100
os 0.002209
ep 0.000364
ju 0.000885
ul 0.001319
^d 0.007408
ka 0.007587
ee 0.005571
yt 0.000456
dl 0.000263
ck 0.001385
nz 0.000636
zi 0.001595
ag 0.000736
da 0.005711
ja 0.006456
he 0.002954
^x 0.000587
xi 0.000447
im 0.001872
ei 0.003585
^t 0.005733
^f 0.001828
fa 0.001061
nd 0.003086
rg 0.000333
as 0.004900
sh 0.005632
ba 0.001407
kh 0.001346
sm 0.000394
od 0.000833
rs 0.000833
gh 0.001578
sy 0.000942
ys 0.001758
ss 0.002021
ec 0.000671
ci 0.001188
mo 0.001981
rk 0.000394
nl 0.000855
dn 0.000136
rd 0.000820
oi 0.000302
tr 0.001543
mb 0.000491
rm 0.000710
ny 0.002038
do 0.001657
oa 0.000653
oc 0.000500
my 0.001258
su 0.000811
mc 0.000224
pr 0.000662
ou 0.001205
rn 0.000614
wa 0.001227
eb 0.000530
cc 0.000184
aw 0.000706
wy 0.000320
ye 0.001319
eo 0.001179
ak 0.002490
ng 0.001197
ko 0.001508
bl 0.000451
ho 0.001258
eg 0.000548
fr 0.000500
sp 0.000224
ls 0.000412
yz 0.000342
gg 0.000110
zu 0.000320
id 0.001929
m$ 0.002262
og 0.000193
je 0.001929
gn 0.000118
yr 0.001275
c$ 0.000425
cq 0.000048
ue 0.000741
if 0.000443
fe 0.000539
ix 0.000390
x$ 0.000719
oy 0.000451
go 0.000364
gt 0.000136
lt 0.000338
gw 0.000114
we 0.000653
ld 0.000605
ap 0.000359
hn 0.000605
tl 0.000587
mr 0.000425
nc 0.000934
lb 0.000228
ik 0.001951
^y 0.002345
tz 0.000460
hr 0.000894
ji 0.000522
ht 0.000311
rr 0.001863
zl 0.000539
wr 0.000096
bb 0.000167
rt 0.000912
lv 0.000316
ej 0.000241
oh 0.000750
us 0.002078
ib 0.000482
gl 0.000140
hy 0.000934
po 0.000259
pp 0.000171
py 0.000053
nr 0.000193
zm 0.000153
vo 0.000671
lm 0.000263
ox 0.000197
d$ 0.002262
iu 0.000478
v$ 0.000386
ff 0.000193
bo 0.000460
ek 0.000780
cr 0.000333
dg 0.000110
rc 0.000434
rh 0.000530
nk 0.000254
hu 0.000728
ds 0.000127
ax 0.000798
yc 0.000504
ew 0.000219
vk 0.000013
zh 0.000188
wh 0.000101
tn 0.000096
xl 0.000171
gu 0.000373
ua 0.000714
up 0.000070
ug 0.000206
du 0.000403
lc 0.000110
rb 0.000180
aq 0.000263
b$ 0.000500
gy 0.000136
yp 0.000066
pt 0.000075
ez 0.000793
zr 0.000140
fl 0.000088
o$ 0.003748
ob 0.000614
uz 0.000197
z$ 0.000701
iq 0.000228
yv 0.000465
nv 0.000241
dh 0.000517
gd 0.000083
ts 0.000153
nh 0.000114
yj 0.000101
kr 0.000478
zb 0.000018
g$ 0.000473
aj 0.000767
rj 0.000110
mp 0.000167
pb 0.000009
yo 0.001188
zy 0.000644
pl 0.000070
lk 0.000105
ij 0.000333
xe 0.000158
yu 0.000618
ln 0.000061
ux 0.000149
ih 0.000416
ws 0.000088
ks 0.000416
mu 0.000609
yk 0.000377
ef 0.000359
k$ 0.001591
ym 0.000649
zz 0.000197
md 0.000105
sr 0.000241
eu 0.000302
lh 0.000083
af 0.000587
rw 0.000092
nu 0.000421
vr 0.000210
ms 0.000153
^u 0.000342
fs 0.000026
yb 0.000118
xo 0.000180
gs 0.000131
xy 0.000131
wn 0.000254
jh 0.000197
fn 0.000018
nj 0.000193
rv 0.000351
nm 0.000083
tc 0.000075
sw 0.000105
kt 0.000075
ft 0.000079
xt 0.000307
uv 0.000162
kk 0.000088
sn 0.000105
u$ 0.000679
jr 0.000048
yx 0.000123
hm 0.000513
eq 0.000061
uo 0.000044
f$ 0.000351
hz 0.000088
hk 0.000127
yg 0.000131
qr 0.000004
vn 0.000035
sd 0.000039
yi 0.000842
nw 0.000048
dv 0.000075
hv 0.000171
xw 0.000013
oz 0.000237
ku 0.000219
uh 0.000254
kn 0.000114
sb 0.000092
ii 0.000359
yy 0.000101
rz 0.000101
lg 0.000026
lp 0.000066
p$ 0.000145
bu 0.000197
fu 0.000044
bh 0.000180
fy 0.000061
uw 0.000377
xu 0.000022
q$ 0.000123
lr 0.000079
mh 0.000022
lw 0.000070
j$ 0.000311
sv 0.000061
ml 0.000022
nf 0.000048
uj 0.000061
fo 0.000263
jl 0.000039
tg 0.000009
jm 0.000022
vv 0.000031
ps 0.000070
tw 0.000048
xc 0.000018
uk 0.000408
vl 0.000061
hd 0.000105
lz 0.000044
kw 0.000149
nb 0.000035
qs 0.000009
iw 0.000035
cs 0.000022
hs 0.000136
mt 0.000018
hw 0.000044
xx 0.000167
tx 0.000009
dz 0.000004
xz 0.000083
tm 0.000018
tj 0.000013
uq 0.000044
qa 0.000057
fk 0.000009
zn 0.000018
lj 0.000026
jw 0.000026
vu 0.000031
cj 0.000013
hb 0.000035
zt 0.000018
pu 0.000018
mz 0.000048
xs 0.000136
bt 0.000009
uy 0.000057
dj 0.000039
js 0.000031
wu 0.000110
oj 0.000070
bs 0.000035
dw 0.000101
wo 0.000158
jn 0.000009
wt 0.000035
lf 0.000096
dm 0.000131
pj 0.000004
jy 0.000044
yf 0.000053
qi 0.000057
jv 0.000022
ql 0.000004
sz 0.000044
km 0.000039
wl 0.000057
pf 0.000004
qw 0.000013
nx 0.000026
kc 0.000009
tv 0.000066
cu 0.000153
zk 0.000009
cz 0.000018
yq 0.000026
yh 0.000096
rf 0.000039
sj 0.000009
hj 0.000039
gb 0.000013
uf 0.000083
sf 0.000009
qe 0.000004
bc 0.000004
cd 0.000004
zj 0.000009
nq 0.000009
mf 0.000004
pn 0.000004
fz 0.000009
bn 0.000018
wd 0.000035
wb 0.000004
bd 0.000285
zs 0.000018
pc 0.000004
hg 0.000009
mj 0.000031
ww 0.000009
kj 0.000009
hp 0.000004
jk 0.000009
oq 0.000013
fw 0.000018
fh 0.000004
wm 0.000009
bj 0.000004
rq 0.000070
zc 0.000009
zv 0.000009
fg 0.000004
np 0.000022
zg 0.000004
dt 0.000018
wf 0.000009
df 0.000022
wk 0.000026
qm 0.000009
kz 0.000009
jj 0.000009
cp 0.000004
pk 0.000004
pm 0.000004
jd 0.000018
rx 0.000013
xn 0.000004
dc 0.000013
gj 0.000013
xf 0.000013
jc 0.000018
sq 0.000004
kf 0.000004
zp 0.000009
jt 0.000009
kb 0.000009
mk 0.000004
mw 0.000009
xh 0.000004
hf 0.000009
xd 0.000022
yw 0.000018
zw 0.000013
dk 0.000013
cg 0.000009
uu 0.000013
tf 0.000009
gm 0.000026
mv 0.000013
cx 0.000013
hc 0.000009
gf 0.000004
qo 0.000009
lq 0.000013
vb 0.000004
jp 0.000004
kd 0.000009
gz 0.000004
vd 0.000004
db 0.000004
vh 0.000004
kv 0.000009
hh 0.000004
sg 0.000009
gv 0.000004
dq 0.000004
xb 0.000004
wz 0.000004
hq 0.000004
jb 0.000004
zd 0.000009
xm 0.000004
wg 0.000004
tb 0.000004
zx 0.000004

In [None]:
# Bar chart of every bigram's probability (one bar per row of df, no legend).
ax = df.plot(kind='bar', legend=None)
ax.set_ylabel('Probability')
ax.set_title('Bigram Probabilities')
plt.show()

In [None]:
# (bigram, probability) pairs in table order; handed to the model constructor.
bigram_probabilities = list(probabilities.items())
# Give every bigram a stable integer id (its position in the table) so that
# bigrams can be fed to, and predicted by, the neural network.
bigram_to_idx = dict(zip(probabilities, range(len(probabilities))))

In [None]:
class NameGenerator(nn.Module):
    """Single-layer LSTM that scores the next bigram given the current one.

    The network processes one time step at a time: LSTM -> linear layer ->
    log-softmax over the ``output_size`` classes.

    Args:
        bigram_probs: bigram statistics kept on the instance for reference;
            the layers defined here do not read them.
        input_size: width of the (one-hot) input vector fed to the LSTM.
        hidden_size: number of LSTM hidden units.
        output_size: number of classes the linear layer scores.
    """

    def __init__(self, bigram_probs, input_size, hidden_size, output_size):
        super().__init__()
        self.bigram_probs = bigram_probs
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.lstm = nn.LSTM(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)
        # forward() produces a (1, 1, output_size) tensor, so the classes live
        # on the LAST axis. Normalising over dim=1 (a size-1 axis) would turn
        # every log-probability into 0.
        self.softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x, hidden):
        """Run a single time step.

        Args:
            x: either a float vector with ``input_size`` elements (e.g. a
                one-hot encoding) or an integer tensor holding ONE class
                index, which is one-hot encoded here.
            hidden: ``(h, c)`` LSTM state, e.g. from ``init_hidden()``.

        Returns:
            ``(log_probs, hidden)`` where ``log_probs`` has shape
            ``(1, 1, output_size)`` and ``hidden`` is the updated state.
        """
        if not torch.is_floating_point(x):
            # The LSTM needs a float vector of width input_size, not a raw id.
            x = nn.functional.one_hot(x.long().reshape(-1), num_classes=self.input_size)
            x = x.to(self.fc.weight.dtype)
        x = x.view(1, 1, -1)  # (seq_len=1, batch=1, input_size)
        output, hidden = self.lstm(x, hidden)
        output = self.fc(output)
        output = self.softmax(output)
        return output, hidden

    def init_hidden(self):
        """Return a zeroed ``(h, c)`` state matching the parameters' device and dtype."""
        reference = self.fc.weight
        return (reference.new_zeros(1, 1, self.hidden_size),
                reference.new_zeros(1, 1, self.hidden_size))

In [None]:
def _sample_char(options):
    """Pick one character from (character, weight) pairs, proportionally to weight."""
    weights = torch.tensor([weight for _, weight in options], dtype=torch.float32)
    return options[torch.multinomial(weights, 1).item()][0]


def generate(model):
    """Sample a name by walking the module-level bigram table ``probabilities``.

    The first letter is drawn from the '^x' start bigrams and each following
    letter from the bigrams that begin with the current last letter, always in
    proportion to the stored probabilities. Sampling stops when the end marker
    '$' is drawn or when the last letter has no known continuation.

    Args:
        model: accepted so existing ``generate(model)`` calls keep working;
            it is not consulted, sampling uses only the bigram table.

    Returns:
        The generated name, without the '^' / '$' markers.

    Raises:
        IndexError: if the table contains no '^x' start bigram.
    """
    # Index the table by first character once per call instead of rescanning
    # every bigram on every step. Table order is kept inside each group.
    followers = {}
    for bigram, probability in probabilities.items():
        followers.setdefault(bigram[0], []).append((bigram[1], probability))

    starts = followers.get('^')
    if not starts:
        raise IndexError("no '^x' start bigrams available to begin a name")

    # Weight the first letter by its start-bigram probability, like every
    # later step, rather than picking uniformly among the start letters.
    name = _sample_char(starts)

    while True:
        options = followers.get(name[-1])
        if not options:
            break  # dead end: nothing is known to follow this letter
        next_char = _sample_char(options)
        if next_char == '$':
            break  # end marker drawn; it is not part of the name
        name += next_char
    return name

In [None]:
# Training schedule: number of passes of the training loop, and names per pass.
epochs, batch_size = 100, 1

# The network reads and predicts bigram ids, so both ends are as wide as the
# number of distinct bigrams counted from the data.
input_size = output_size = len(bigrams)
hidden_size = 128

model = NameGenerator(
    bigram_probs=bigram_probabilities,
    input_size=input_size,
    hidden_size=hidden_size,
    output_size=output_size,
)
# Negative log-likelihood loss, the counterpart of the model's log-softmax output.
criterion = nn.NLLLoss()
# Adam optimizer over every model parameter.
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [None]:
def _encode_name(name):
    """Convert a name into aligned (input, target) bigram-index tensors.

    The name is framed as '^' + name + '$' and split into consecutive bigrams;
    each bigram is the input for predicting the bigram that follows it.

    Raises:
        KeyError: if any bigram of the framed name is missing from
            ``bigram_to_idx``.
    """
    framed = '^' + name + '$'
    indices = [bigram_to_idx[framed[pos:pos + 2]] for pos in range(len(framed) - 1)]
    return (torch.tensor(indices[:-1], dtype=torch.long),
            torch.tensor(indices[1:], dtype=torch.long))


max_sampling_attempts = 100  # give up instead of spinning forever on a bad table
log_every = 100              # print the loss every this many epochs

for epoch in range(epochs):
    model.zero_grad()
    # Accumulate out-of-place: an in-place `+=` on a leaf tensor that requires
    # grad raises a RuntimeError.
    loss = torch.zeros(())
    steps = 0

    for _ in range(batch_size):
        # A generated name can end on a letter with no known 'x$' bigram, so
        # resample until every bigram of the framed name is in the vocabulary.
        for _attempt in range(max_sampling_attempts):
            name = generate(model)
            try:
                input_tensor, target_tensor = _encode_name(name)
            except KeyError:
                continue
            break
        else:
            raise RuntimeError(
                f"could not sample a name covered by bigram_to_idx in {max_sampling_attempts} attempts"
            )

        hidden = model.init_hidden()
        for input_value, target_value in zip(input_tensor, target_tensor):
            # The LSTM expects a float vector of width input_size, not a raw id.
            one_hot = nn.functional.one_hot(input_value, num_classes=model.input_size).float()
            output, hidden = model(one_hot, hidden)
            # NLLLoss wants (N, C) scores and (N,) targets.
            loss = loss + criterion(output.view(1, -1), target_value.view(1))
            steps += 1

    if steps:
        loss.backward()
        optimizer.step()

    # Report the average loss per predicted bigram.
    if (epoch + 1) % log_every == 0:
        print(f"Epoch {epoch + 1}, Loss: {loss.item() / max(steps, 1)}")

In [None]:
# Sample one name from the bigram table and show it under a header line.
print("Generated name:")
sampled_name = generate(model)
print(sampled_name)