In [None]:
import numpy as np
import tensorflow as tf
# Fetch the Nietzsche text corpus and keep the local file path for the next cell.
# NOTE(review): per the Keras docs, get_file reuses an already-downloaded copy.
path = tf.keras.utils.get_file(
    fname='nietzsche.txt',
    origin='https://s3.amazonaws.com/text-datasets/nietzsche.txt',
)

Downloading data from https://s3.amazonaws.com/text-datasets/nietzsche.txt
600901/600901 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step


In [None]:
# Read the whole corpus and lowercase it so upper/lower case share symbols.
with open(path, encoding='utf-8') as corpus_file:
    text_data = corpus_file.read().lower()

print(f"Text length: {len(text_data)} characters")

# Build vocabulary: the sorted unique characters, plus lookup tables
# mapping character -> integer index and back.
chars = sorted(set(text_data))
vocab_size = len(chars)

char_to_ix = {}
ix_to_char = {}
for idx, ch in enumerate(chars):
    char_to_ix[ch] = idx
    ix_to_char[idx] = ch

print(f"Unique characters: {vocab_size}")

Text length: 600893 characters
Unique characters: 57


In [None]:
# Model / optimisation hyperparameters.
hidden_size = 128    # number of units in the recurrent hidden state
seq_length = 40      # characters per training chunk (hidden state is carried between chunks)
learning_rate = 0.1  # step size used by the parameter update in the training cell
random_seed = 42     # makes weight init and text sampling reproducible on Restart & Run All

# Seed NumPy's legacy global RNG: the weight init below (np.random.randn) and
# the sampling cell (np.random.choice) both draw from it.
np.random.seed(random_seed)

# Small random weights (scaled by 0.01); biases start at zero.
Wxh = np.random.randn(hidden_size, vocab_size) * 0.01   # input  -> hidden
Whh = np.random.randn(hidden_size, hidden_size) * 0.01  # hidden -> hidden
Why = np.random.randn(vocab_size, hidden_size) * 0.01   # hidden -> output logits
bh = np.zeros((hidden_size, 1))  # hidden bias
by = np.zeros((vocab_size, 1))   # output bias

In [None]:
def lossFun(inputs, targets, hprev, params=None):
    """Run one forward/backward pass of the character RNN over a chunk.

    Args:
        inputs: sequence of integer character indices fed to the RNN.
        targets: sequence of integer character indices to predict; targets[t]
            is the label for step t. Must be the same length as ``inputs``.
        hprev: hidden state carried in from the previous chunk, a column
            vector of shape (hidden_size, 1).
        params: optional tuple ``(Wxh, Whh, Why, bh, by)``. Defaults to the
            notebook-level globals of the same names, so existing callers are
            unaffected; passing explicit arrays allows e.g. gradient checking.

    Returns:
        ``(loss, dWxh, dWhh, dWhy, dbh, dby, h_last)``: the summed
        cross-entropy loss (natural log) over the chunk, the gradients
        (same shapes as the parameters, clipped elementwise to [-5, 5]), and
        the last hidden state (a copy of ``hprev`` when ``inputs`` is empty).
    """
    if params is None:
        params = (Wxh, Whh, Why, bh, by)
    w_xh, w_hh, w_hy, b_h, b_y = params
    n_vocab = w_xh.shape[1]

    xs, hs, ps = {}, {}, {}
    hs[-1] = np.copy(hprev)
    loss = 0.0

    # Forward pass
    for t in range(len(inputs)):
        xs[t] = np.zeros((n_vocab, 1))
        xs[t][inputs[t]] = 1  # one-hot encode the input character
        hs[t] = np.tanh(np.dot(w_xh, xs[t]) + np.dot(w_hh, hs[t-1]) + b_h)
        logits = np.dot(w_hy, hs[t]) + b_y
        # Subtracting the max logit leaves the softmax unchanged but stops
        # np.exp from overflowing to inf/nan when the weights grow large.
        shifted = logits - np.max(logits)
        exp_shifted = np.exp(shifted)
        sum_exp = np.sum(exp_shifted)
        ps[t] = exp_shifted / sum_exp
        # Cross-entropy via log-softmax: stays finite even when the target
        # probability underflows to 0 (where -log(p) would be inf).
        loss += np.log(sum_exp) - shifted[targets[t], 0]

    # Backward pass
    dWxh, dWhh, dWhy = np.zeros_like(w_xh), np.zeros_like(w_hh), np.zeros_like(w_hy)
    dbh, dby = np.zeros_like(b_h), np.zeros_like(b_y)
    # Shape taken from hprev (not hs[0]) so an empty chunk does not raise KeyError.
    dhnext = np.zeros_like(hs[-1])

    for t in reversed(range(len(inputs))):
        # Gradient of softmax + cross-entropy w.r.t. the logits: p - one_hot(target).
        dy = np.copy(ps[t])
        dy[targets[t]] -= 1
        dWhy += np.dot(dy, hs[t].T)
        dby += dy

        # h[t] receives gradient from this step's output and from step t+1.
        dh = np.dot(w_hy.T, dy) + dhnext
        dh_raw = (1 - hs[t] * hs[t]) * dh  # backprop through tanh
        dbh += dh_raw
        dWxh += np.dot(dh_raw, xs[t].T)
        dWhh += np.dot(dh_raw, hs[t-1].T)
        dhnext = np.dot(w_hh.T, dh_raw)

    # Clip in place to limit exploding gradients.
    for dparam in (dWxh, dWhh, dWhy, dbh, dby):
        np.clip(dparam, -5, 5, out=dparam)

    return loss, dWxh, dWhh, dWhy, dbh, dby, hs[len(inputs) - 1]

In [None]:
n_iters = 200
pointer = 0
hprev = np.zeros((hidden_size, 1))

# Adagrad accumulators: a running sum of squared gradients per parameter.
# Plain SGD at this learning rate diverged (loss jumped from ~125 to >1000
# after ~30 iterations); Adagrad scales each parameter's step by its gradient
# history. NOTE(review): any saved losses below predate this change — re-run.
mWxh, mWhh, mWhy = np.zeros_like(Wxh), np.zeros_like(Whh), np.zeros_like(Why)
mbh, mby = np.zeros_like(bh), np.zeros_like(by)
adagrad_eps = 1e-8  # avoids division by zero while an accumulator is still 0

for i in range(n_iters):
    # Wrap to the start (and reset the hidden state) when the next chunk plus
    # its one-character target shift would run past the end of the corpus.
    if pointer + seq_length + 1 >= len(text_data):
        pointer = 0
        hprev = np.zeros((hidden_size, 1))

    # Targets are the inputs shifted right by one character.
    inputs = [char_to_ix[ch] for ch in text_data[pointer:pointer+seq_length]]
    targets = [char_to_ix[ch] for ch in text_data[pointer+1:pointer+seq_length+1]]

    loss, dWxh, dWhh, dWhy, dbh, dby, hprev = lossFun(inputs, targets, hprev)

    # Adagrad update, applied in place so the global parameter arrays change.
    for param, dparam, mem in zip([Wxh, Whh, Why, bh, by],
                                  [dWxh, dWhh, dWhy, dbh, dby],
                                  [mWxh, mWhh, mWhy, mbh, mby]):
        mem += dparam * dparam
        param -= learning_rate * dparam / np.sqrt(mem + adagrad_eps)

    if i % 10 == 0:
        print(f"Iteration {i}, Loss: {loss:.4f}")

    pointer += seq_length

Iteration 0, Loss: 161.7165
Iteration 10, Loss: 127.1467
Iteration 20, Loss: 124.5567
Iteration 30, Loss: 1143.0973
Iteration 40, Loss: 1841.8782
Iteration 50, Loss: 1816.4803
Iteration 60, Loss: 1482.9562
Iteration 70, Loss: 1622.8940
Iteration 80, Loss: 2333.0877
Iteration 90, Loss: 1910.2226
Iteration 100, Loss: 1462.7172
Iteration 110, Loss: 2056.8951
Iteration 120, Loss: 1382.9880
Iteration 130, Loss: 1814.2720
Iteration 140, Loss: 2295.7315
Iteration 150, Loss: 1870.7998
Iteration 160, Loss: 2332.6360
Iteration 170, Loss: 2002.2800
Iteration 180, Loss: 2737.9932
Iteration 190, Loss: 1820.1748


In [None]:
def sample(h, seed_ix, n=100):
    """Generate ``n`` characters from the RNN by sampling one at a time.

    Args:
        h: initial hidden state, a (hidden_size, 1) column vector. The
            caller's array is not modified (``h`` is only rebound locally).
        seed_ix: index of the character fed in at the first step.
        n: number of characters to generate.

    Returns:
        The ``n`` sampled characters joined into a string; the seed character
        itself is not included.
    """
    x = np.zeros((vocab_size, 1))
    x[seed_ix] = 1
    ixes = []
    for _ in range(n):
        h = np.tanh(np.dot(Wxh, x) + np.dot(Whh, h) + bh)
        y = np.dot(Why, h) + by
        # Max-shifted softmax: same distribution, but no overflow to inf/nan
        # (which would make np.random.choice raise) when logits are large.
        exp_y = np.exp(y - np.max(y))
        p = exp_y / np.sum(exp_y)
        ix = np.random.choice(vocab_size, p=p.ravel())
        # Feed the sampled character back in as the next one-hot input.
        x = np.zeros((vocab_size, 1))
        x[ix] = 1
        ixes.append(ix)
    return ''.join(ix_to_char[ix] for ix in ixes)

# Generate 300 characters starting from the hidden state left by the last
# training chunk, seeding the model with the letter 'n'.
generated_text = sample(hprev, char_to_ix['n'], n=300)
print(generated_text)

amamapamapamamamamamamamamamamamapamamamamaeamapamamamamaeamapamamamapamamamaeapamamamamamamaeamamapapamamamamamamapaeamamamapamamamamamamamapapapawamaeapamamapamamaeamamamamamamamamamamamapamamamamamamaeamamapaeamamamamapamaeamapamawamamamamamapamamapamamamamamapamamamamamamamaeamamamamamamamaeamam
