In [142]:
import pandas as pd
import torch
import torch.nn.functional as F
import random

words = []
# with open("./data/western_names.txt") as file:
#     words = file.read().splitlines()

with open("./data/indian_names.csv", encoding="utf-8") as file:
    df = pd.read_csv(file, header=None)
    words = df.iloc[:, 0].dropna().tolist()
    words = [w.lower() for w in words if w.isalpha()]
    words = words[1:]
words[0:10]
min(len(w) for w in words)
max(len(w) for w in words)
len("".join(w for w in words)) + len(words)  # total chars from all words +len(words) for the dots

45594

In [143]:
unique_chars = sorted(list(set("".join(words))))

char_to_idx = {ch: i + 1 for i, ch in enumerate(unique_chars)}
char_to_idx["."] = 0
idx_to_char = {i: ch for ch, i in char_to_idx.items()}
vocab_size = len(char_to_idx)
context_size = 3
embed_size = 2


def build_dataset(words):
    X, Y = [], []
    for w in words:
        context = [0] * context_size
        for ch in w + ".":
            idx = char_to_idx[ch]
            X.append(context)
            Y.append(idx)
            context = context[1:] + [idx]
    X = torch.tensor(X)
    Y = torch.tensor(Y)
    return X, Y


random.seed(42)
random.shuffle(words)
n1 = int(0.8 * len(words))
n2 = int(0.9 * len(words))
Xtr, Ytr = build_dataset(words[:n1])
Xval, Yval = build_dataset(words[n1:n2])
Xtest, Ytest = build_dataset(words[n2:])

In [144]:
g = torch.Generator().manual_seed(2147483647)

E = torch.randn((vocab_size, embed_size), generator=g)
W1 = torch.randn((context_size * embed_size, 100), generator=g)
b1 = torch.randn(100, generator=g)
W2 = torch.randn((100, vocab_size), generator=g)
b2 = torch.randn(vocab_size, generator=g)

parameters = [E, W1, b1, W2, b2]
for p in parameters:
    p.requires_grad = True

In [149]:
X = Xtr[:1000]
Y = Ytr[:1000]

for _ in range(1000):
    # forward pass
    I = E[X].view(-1, context_size * embed_size) # Input layer
    H = torch.tanh(I @ W1 + b1)  # Hidden layer
    L = H @ W2 + b2  # Logit layer
    loss = F.cross_entropy(L, Y)
    print(loss.item())

    # backward pass
    for p in parameters:
        p.grad = None
    loss.backward()

    # update parameters
    for p in parameters:
        p.data += 0.1 * -p.grad

1.5356131792068481
1.5354820489883423
1.5353511571884155
1.5352205038070679
1.5350897312164307
1.534958839416504
1.5348284244537354
1.5346975326538086
1.5345669984817505
1.5344364643096924
1.5343059301376343
1.5341752767562866
1.5340449810028076
1.5339144468307495
1.5337841510772705
1.533653736114502
1.5335233211517334
1.533393144607544
1.5332629680633545
1.5331326723098755
1.5330027341842651
1.5328724384307861
1.5327423810958862
1.5326123237609863
1.5324822664260864
1.532352328300476
1.5322225093841553
1.5320926904678345
1.5319627523422241
1.5318330526351929
1.5317033529281616
1.5315734148025513
1.53144371509552
1.5313141345977783
1.531184434890747
1.5310550928115845
1.5309253931045532
1.5307958126068115
1.530666470527649
1.5305370092391968
1.5304077863693237
1.5302784442901611
1.5301493406295776
1.5300196409225464
1.529890537261963
1.529761552810669
1.5296322107315063
1.5295032262802124
1.529374122619629
1.529245138168335
1.529116153717041
1.528987169265747
1.5288583040237427
1.52872

In [150]:
for _ in range(20):
    samples = []
    context = [0] * context_size
    while True:
        I = E[torch.tensor([context])].view(-1, context_size * embed_size)
        H = torch.tanh(I @ W1 + b1)
        logits = H @ W2 + b2
        probs = F.softmax(logits, dim=1)
        idx = torch.multinomial(probs, num_samples=1).item()
        samples.append(idx)
        context = context[1:] + [idx]
        if idx == 0:
            break
    print("Generated samples:", ''.join(idx_to_char[i] for i in samples))

Generated samples: kit.
Generated samples: akik.
Generated samples: shian.
Generated samples: kumadeep.
Generated samples: shind.
Generated samples: rajekk.
Generated samples: ramul.
Generated samples: ramprasameesh.
Generated samples: amin.
Generated samples: miup.
Generated samples: vander.
Generated samples: andunhpal.
Generated samples: salman.
Generated samples: sooh.
Generated samples: shtashishanidejirisi.
Generated samples: abir.
Generated samples: amin.
Generated samples: raj.
Generated samples: salal.
Generated samples: praepna.
