<a href="https://colab.research.google.com/github/abhishekvaid/makemore/blob/master/2_makemore_collab.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import matplotlib.pyplot as plt
%matplotlib inline
import torch.nn.functional as F
import random

In [3]:
# download the names.txt file from github
!wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt

--2024-08-30 19:00:22--  https://raw.githubusercontent.com/karpathy/makemore/master/names.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.111.133, 185.199.109.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 228145 (223K) [text/plain]
Saving to: ‘names.txt’


2024-08-30 19:00:22 (6.25 MB/s) - ‘names.txt’ saved [228145/228145]



In [4]:
words = names = open("names.txt").read().splitlines()
chars = sorted(list(set(''.join(names))))
stoi = { ch: i+1 for i, ch in enumerate(chars)}
stoi["."] = 0
itos = { v:k for k, v in stoi.items()}

In [5]:
WINDOW_SIZE = 3

def create_dataset(words):
    X, Y = [], []
    for w in words[:]:
        window = [0] * WINDOW_SIZE
        # print(w)
        for ch in w + ".":
            ix = stoi[ch]
            X.append(window)
            Y.append(ix)
            window = window[1:] + [ix]
    return torch.tensor(X), torch.tensor(Y)

# Create dataset
random.shuffle(words)
i, j = int(0.8*len(words)), int(0.9*len(words))

Xtr, Ytr = create_dataset(words[:i])
Xval, Yval = create_dataset(words[i:j])
Xtest, Ytest = create_dataset(words[j:])

In [7]:
def find_loss(X, Y):
    emb = Emb[X].view(-1, EMB_SIZE*WINDOW_SIZE)
    a1 = torch.tanh(emb @ W1 + b1)
    logits = a1 @ W2 + b2
    loss = F.cross_entropy(logits, Y)
    return loss.item()

In [8]:
# Setup Architecture of NN
LEARNING_RATE = 0.1

EMB_SIZE = 10
g = torch.Generator().manual_seed(2147483647)
Emb = torch.randn((len(stoi), EMB_SIZE))
W1 = torch.randn((EMB_SIZE*WINDOW_SIZE, 200), generator=g)
b1 = torch.randn(200, generator=g)
W2 = torch.randn( (200, len(stoi)), generator=g)
b2 = torch.randn(len(stoi), generator=g)
parameters = [Emb, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True


In [9]:
sum(p.nelement() for p in parameters) # number of parameters in total

11897

In [15]:
def train(X, Y, init_network=False):
    if init_network:
        init_network()
    losses = []
    for i in range(200_000):
        idxs = torch.randint(0, X.shape[0], (32,))
        emb = Emb[X[idxs]].view(-1, EMB_SIZE*WINDOW_SIZE)
        a1 = torch.tanh(emb @ W1 + b1)
        logits = a1 @ W2 + b2

        loss = F.cross_entropy(logits, Y[idxs])
        losses.append(loss.item())

        # Backward Pass
        for p in parameters:
            p.grad = None
        loss.backward()
        lr = LEARNING_RATE if i < 100_000 else (LEARNING_RATE / 10)
        for p in parameters:
            p.data += -lr * p.grad
    return losses
Losses = train(Xtr, Ytr)

In [13]:
print(find_loss(Xtr, Ytr))
print(find_loss(Xval, Yval))
print(find_loss(Xtest, Ytest))

2.0744848251342773
2.136138677597046
2.1370372772216797


In [14]:
# g = torch.Generator().manual_seed(2147483647)
for i in range(100):
    xidxs = [0]*WINDOW_SIZE
    chars = []
    while True:
        I = Emb[torch.tensor(xidxs)].view(-1, WINDOW_SIZE*EMB_SIZE)
        # print(X.shape)
        a1 = torch.tanh(I @ W1 + b1)
        logits = (a1 @ W2 + b2)
        probs = F.softmax(logits, dim=1)
        yidx = torch.multinomial(probs, num_samples=1, replacement=False, generator=g).item()
        xidxs = (xidxs + [yidx])[-3:]
        chars.append(itos[yidx])
        if yidx == 0:
            break
    print("".join(chars))

ame.
rydonta.
alin.
javerly.
eara.
tayvea.
marmigaaria.
keslee.
mevin.
kayden.
giav.
lura.
adrie.
meh.
elua.
lan.
makehaie.
griffer.
genna.
brocka.
der.
alec.
tarril.
jakarretta.
andian.
amano.
devalishwassel.
rote.
samansa.
drias.
ruxte.
arion.
kamelyna.
kalinnany.
brayntlea.
sarryn.
maiba.
trija.
yairicholstonalee.
edri.
royk.
jasiya.
uct.
avaigu.
ela.
greyanna.
tahman.
safa.
gera.
arott.
rissa.
artandra.
nikhritalvyn.
ammarson.
kari.
isse.
eve.
lawe.
avarthi.
achoy.
zavishaashawnonalon.
jaylonsa.
lociuzahir.
bhait.
jayri.
ulia.
asze.
nesush.
jayrose.
taliel.
wrajon.
tester.
nicle.
mira.
arsheliya.
cocklynn.
asimarya.
way.
jereyla.
zire.
keyana.
bryni.
evelle.
avea.
dee.
jak.
delanes.
tobi.
broorsekeya.
ren.
andraiyet.
cyphnoah.
emetre.
luc.
hazhna.
mabianly.
samarden.
jaizurie.
desabel.
trand.
