In [157]:
import numpy as np

# Read in and prepare text data

In [158]:
with open("goblet_book.txt", "r") as f:
    data = f.read()

book_chars = list(set(data))
K = len(book_chars)

charToInd = {c:i for i,c in enumerate(book_chars)}
indToChar = {i:c for i,c in enumerate(book_chars)}

In [159]:
def toString(encoded_text):
    return ''.join([indToChar[i] for i in encoded_text])

In [160]:
def oneHotEncode(x):
    Y = np.zeros((K, len(x)))
    for i,c in enumerate(x):
        Y[charToInd[c],i] = 1
    return Y

In [161]:
class RNN():
    def __init__(self, K, m=100, seed=123456789):
        np.random.seed(seed)

        self.K = K
        self.m = m
        self.sigma = 0.01

        self.weights = {}
        self.momentum = {}
        
        # Biases
        self.weights["b"] = np.zeros(shape=(self.m,1))
        self.weights["c"] = np.zeros(shape=(K,1))

        # Weights
        self.weights["U"] = np.random.randn(self.m, self.K) * self.sigma
        self.weights["W"] = np.random.randn(self.m, self.m) * self.sigma
        self.weights["V"] = np.random.randn(self.K, self.m) * self.sigma

        # Momentum
        for key, value in self.weights.items():
            self.momentum[key] = np.zeros(value.shape)

        # Set initial hidden state
        self.h0 = np.zeros(shape=(self.m,1))
    
    def synth(self, x0, n):

        h = self.h0
        x = x0

        for i in range(n):
            t1 = self.weights["W"] @ h
            t2 = self.weights["U"] @ x[:,-1].reshape(-1,1)
            a = t1 + t2 + self.weights["b"]
            h = np.tanh(a)
            o = self.weights["V"] @ h + self.weights["c"]
            p = np.exp(o) / np.sum(np.exp(o), axis=0)
            idx = np.random.choice(range(self.K),p=np.squeeze(p))
            newX = np.zeros(shape=(self.K,1))
            newX[idx,0] = 1
            x = np.c_[x,newX]
        
        return [np.argmax(c) for c in x.T]
            
        




In [162]:
m = 100
eta = 0.1
seq_length = 25

In [163]:
x = oneHotEncode(["a"])
model = RNN(K, m)

In [164]:
toString(model.synth(x, 20))

'a \t9xj)cLc0/T2d•./2qJ'