In [16]:
from net import MLP, grad_descent, parse_txt, char_tokenize, itos, stoi
from net.util import tanh, dtanh, cross_entropy, SEED

import numpy as np

In [17]:
# Hyper Parameters
BLOCK_SIZE = 4
FEATURES = 8
VOCAB_SIZE = 27
size = ((BLOCK_SIZE - 1) * FEATURES, 48, VOCAB_SIZE)

In [18]:
# Name data
names = parse_txt("../data/names.txt")
tokens, vocab = char_tokenize(names, BLOCK_SIZE)
str_to_int = stoi(vocab)
int_to_str = itos(vocab)

xs = tokens[:, :(BLOCK_SIZE - 1)]
ys = tokens[:, -1]

# Splits
x_train = xs[:259395]
y_train = ys[:259395]
x_test = xs[259395:291819]
y_test = ys[259395:291819]
x_dev = xs[291819:]
y_dev = ys[291819:]

In [19]:
# Initialize emb matrix
emb = np.random.randn(len(vocab), FEATURES)

# Initialize model
name_net = MLP(size, tanh, dtanh, emb)

In [20]:
# Train
grad_descent(name_net, x_train, y_train, 20, 10, 10000, 0.1)

epoch: 1/10 done.
epoch: 2/10 done.
epoch: 3/10 done.
epoch: 4/10 done.
epoch: 5/10 done.
epoch: 6/10 done.
epoch: 7/10 done.
epoch: 8/10 done.
epoch: 9/10 done.
epoch: 10/10 done.


In [32]:
# Check entropy
name_net.forward(x_train)
print(np.linalg.norm(name_net.weights[2].value))
preds = np.max(name_net.layers[-1].value, axis=1)
cross_entropy(preds, y_train)

7245.4836939795905


np.float64(0.010489119185961812)

In [33]:
# Generate
for _ in range(10):
    input = [str_to_int['.']] * (BLOCK_SIZE - 1) # SEED
    out = ""
    while '.' not in out:
        name_net.forward(np.array(input))
        probs = name_net.layers[-1].value
        print(np.reshape(probs, (probs.size)))
        i = np.random.choice(VOCAB_SIZE, p=np.reshape(probs, (probs.size)))
        out += int_to_str[i]
        input = input[1:] + [i]
    print(out)

[7.42017028e-275 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 7.33847811e-168 0.00000000e+000
 0.00000000e+000 1.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000]
[0.00000000e+000 8.33861232e-081 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 4.56406615e-082 0.00000000e+000
 0.00000000e+000 0.00000000e+000 1.00000000e+000 1.48145647e-155
 0.00000000e+000 0.00000000e+000 0.00000000e+000]
[0.00000000e+000 0.00000000e+000 0.00000000e+000 0.00000000e+000
 0.00000000e+000 0.00000000e+000 0.00000000e+000 0.0000