In [4]:
import torch
import torch.nn.functional as F

In [5]:
words = open('names.txt', 'r').read().splitlines()
chars = sorted(list(set(''.join(words))))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

train_end = round(len(words) * 0.8)
dev_end = round((len(words) - train_end)/2) + train_end

In [6]:
# Create training set for bigrams
bxs, bys = [], []
for w in words[:train_end]:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]

        bxs.append(ix1)
        bys.append(ix2)

bxs = torch.tensor(bxs)
bys = torch.tensor(bys)
bnum = bxs.nelement()

# Initalize network
bg = torch.Generator().manual_seed(2147483647)
bW = torch.randn((27, 27), generator=bg, requires_grad=True)

In [7]:
# Create training set for trigrams
tx1 = []
tx2 = []
tys = []
tnum = 0

for w in words[:train_end]:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        id1 = stoi[ch1]
        id2 = stoi[ch2]
        id3 = stoi[ch3]

        tx1.append(id1)
        tx2.append(id2 + 27)
        tys.append(id3)

tx1 = torch.tensor(tx1)
tx2 = torch.tensor(tx2)
tys = torch.tensor(tys)
tnum = tys.nelement()

# Initalize Network
tg = torch.Generator().manual_seed(2147483647)
tW = torch.randn((54, 27), generator=tg, requires_grad=True)

In [8]:
# BIGRAM MODEL

In [9]:
# Gradient descent
for k in range(1000):
    # Forward Pass
    xenc = F.one_hot(bxs, num_classes=27).float() # Input to the network: one-hot encoding
    logits = xenc @ bW # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    probs = counts / counts.sum(1, keepdim=True) # probabilities for next character
    loss = -probs[torch.arange(bnum), bys].log().mean()

    # Backward pass
    bW.grad = None # set to zero the gradient
    loss.backward()

    bW.data += -50 * bW.grad
print(loss.item())

2.425746202468872


In [10]:
# TRIGRAM MODEL

In [11]:
xenc = F.one_hot(tx1, num_classes=54).float() + F.one_hot(tx2, num_classes=54).float()

for k in range(1000):
    logits = xenc @ tW
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(tnum), tys].log().mean()

    tW.grad = None
    loss.backward()

    tW.data += -50 * tW.grad

print(loss.item())

2.342146873474121


In [12]:
# Create dev set for bigrams
bxs, bys = [], []
for w in words[train_end:dev_end]:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]

        bxs.append(ix1)
        bys.append(ix2)

bxs = torch.tensor(bxs)
bys = torch.tensor(bys)
bnum = bxs.nelement()

# Initalize network
bg = torch.Generator().manual_seed(2147483647)
bW = torch.randn((27, 27), generator=bg, requires_grad=True)

In [13]:
# Create dev set for trigrams
tx1 = []
tx2 = []
tys = []
tnum = 0

for w in words[train_end:dev_end]:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        id1 = stoi[ch1]
        id2 = stoi[ch2]
        id3 = stoi[ch3]

        tx1.append(id1)
        tx2.append(id2 + 27)
        tys.append(id3)

tx1 = torch.tensor(tx1)
tx2 = torch.tensor(tx2)
tys = torch.tensor(tys)
tnum = tys.nelement()

In [14]:
# Gradient descent
for k in range(100):
    # Forward Pass
    xenc = F.one_hot(bxs, num_classes=27).float() # Input to the network: one-hot encoding
    logits = xenc @ bW # predict log-counts
    counts = logits.exp() # counts, equivalent to N
    probs = counts / counts.sum(1, keepdim=True) # probabilities for next character
    loss = -probs[torch.arange(bnum), bys].log().mean() + 0.01 * (bW**2).mean()

    # Backward pass
    bW.grad = None # set to zero the gradient
    loss.backward()

    bW.data += -30 * bW.grad
print(loss.item())

2.533405303955078


In [15]:
xenc = F.one_hot(tx1, num_classes=54).float() + F.one_hot(tx2, num_classes=54).float()

for k in range(100):
    logits = xenc @ tW
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdim=True)
    loss = -probs[torch.arange(tnum), tys].log().mean() + 0.01 * (tW**2).mean()

    tW.grad = None
    loss.backward()

    tW.data += -30 * tW.grad

print(loss.item())

2.386350631713867


In [16]:
# Create test set for bigrams
bxs, bys = [], []
for w in words[dev_end:]:
    chs = ['.'] + list(w) + ['.']
    for ch1, ch2 in zip(chs, chs[1:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]

        bxs.append(ix1)
        bys.append(ix2)

bxs = torch.tensor(bxs)
bys = torch.tensor(bys)
bnum = bxs.nelement()

In [17]:
# Create test set for trigrams
tx1 = []
tx2 = []
tys = []
tnum = 0

for w in words[dev_end:]:
    chs = ['.', '.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        id1 = stoi[ch1]
        id2 = stoi[ch2]
        id3 = stoi[ch3]

        tx1.append(id1)
        tx2.append(id2 + 27)
        tys.append(id3)

tx1 = torch.tensor(tx1)
tx2 = torch.tensor(tx2)
tys = torch.tensor(tys)
tnum = tys.nelement()

In [18]:
# Bigram Forward Pass
xenc = F.one_hot(bxs, num_classes=27).float() # Input to the network: one-hot encoding
logits = xenc @ bW # predict log-counts
counts = logits.exp() # counts, equivalent to N
probs = counts / counts.sum(1, keepdim=True) # probabilities for next character
loss = -probs[torch.arange(bnum), bys].log().mean()
print(loss.item())

2.5418505668640137


In [19]:
# Trigram Forward Pass
xenc = F.one_hot(tx1, num_classes=54).float() + F.one_hot(tx2, num_classes=54).float()    
logits = xenc @ tW
counts = logits.exp()
probs = counts / counts.sum(1, keepdim=True)
loss = -probs[torch.arange(tnum), tys].log().mean()
print(loss.item())


2.41545033454895


In [20]:
# Bigram Output Testing

for x in range(10):
    cur = torch.zeros(27)
    cur[0] = 1

    chars = []

    while True:

        logits = cur @ bW # predict log-counts
        counts = logits.exp() # counts, equivalent to N
        probs = counts / counts.sum() # probabilities for next character
        ix = torch.multinomial(probs, num_samples=1, replacement=True, generator=bg).item()
        if ix == 0:
            break
        chars.append(itos[ix])
        cur = torch.zeros(27)
        cur[ix] = 1


    print(''.join(chars))

jigla
sadrqr
brixzydavesolen
pshabasha
n
r
mas
tharze
brylon
kym


In [21]:
# Trigram Output Testing

for k in range(10):
    input = torch.zeros(54)
    f = 0
    s = 0

    out = []
    while True:
        logits = input @ tW
        counts = logits.exp()
        probs = counts / counts.sum()

        idx = torch.multinomial(probs, num_samples=1, replacement=True, generator=tg).item()

        if idx == 0:
            break
        out.append(itos[idx])

        f = s
        s = idx

        input = torch.zeros(54)
        input[f] = 1
        input[s + 27] = 1
    print(''.join(out))

aellen
fedyrocre
was
kon
num
blu
jeaglon
asiani
lamaron
wan
