In [3]:
import numpy as np
import matplotlib.pyplot as plt

In [4]:
rng = np.random.default_rng()

In [5]:
def read_data():
    f = open("dinosaur_island.txt", "r")
    file_data = f.read()
    return file_data

def preprocess_data(file_data):
    names = file_data.strip().lower().split('\n')
    name_num = len(names)
    max_len = max([len(name) for name in names])
    charset = 29 # 26=end-of-sequence, 27=beginning-of-sequence, 28=padding
    x_train = np.zeros((name_num, max_len, charset))
    y_train = np.zeros((name_num, max_len, charset))
    a_int = ord('a')
    for n in range(name_num):
        capitals = sum(1 for ch in names[n] if ch.isupper())
        if capitals > 1:
            print(i + 1)
        for i in range(max_len):
            if i == 0:
                x_train[n, i, 27] = 1
            else:
                x_train[n, i] = y_train[n, i - 1]
            if i < len(names[n]):
                y_train[n, i, ord(names[n][i]) - a_int] = 1
            elif i == len(names[n]):
                y_train[n, i, 26] = 1
            else:
                y_train[n, i, 28] = 1
    return x_train, y_train

In [6]:
class RNNSeqSoftmax:
    def __init__(self, input_dim, seq_len, output_dim, hidden_size=128, unroll=5):
        self.input_dim = input_dim
        self.seq_len = seq_len
        self.output_dim = output_dim
        self.hidden_size= hidden_size
        self.unroll = unroll
        weight_range1 = np.sqrt(3)/np.sqrt((input_dim + hidden_size)/2)
        self.xh = np.array(rng.uniform(-weight_range1, weight_range1, size=(input_dim, hidden_size)))
        self.hh = np.array(rng.uniform(-weight_range1, weight_range1, size=(hidden_size, hidden_size)))
        self.bias2 = np.array(rng.uniform(-weight_range1, weight_range1, size=(1, hidden_size)))
        weight_range2 = np.sqrt(3)/np.sqrt((hidden_size + output_dim)/2)
        self.hq = np.array(rng.uniform(-weight_range2, weight_range2, size=(hidden_size, output_dim)))
        self.bias3 = np.array(rng.uniform(-weight_range2, weight_range2, size=(1, output_dim)))
        self.h = 0
        self.dj_dwxh = 0
        self.dj_dwhh = 0
        self.dj_db2 = 0
        self.dj_dwhq = 0
        self.dj_db3 = 0
        self.eta = 0.01

    def softmax(self, x):
        shifted_x = x - np.max(x, axis=-1, keepdims=True)
        exp_x = np.exp(shifted_x)
        return exp_x / np.sum(exp_x, axis=-1, keepdims=True)

    def forward(self, x):
        self.batch_size = x.shape[0]
        self.input = x
        self.output1 = np.dot(x, self.xh)
        self.output2 = np.zeros((self.batch_size, self.seq_len, self.hidden_size))
        for i in range(self.seq_len):
            if i == 0:
                self.output2[:, i] = np.tanh(self.output1[:, i, :] + self.bias2)
            else:
                self.output2[:, i] = np.tanh(self.output1[:, i, :] + np.dot(self.output2[:, i - 1], self.hh) + self.bias2)
        self.output3 = self.softmax(np.dot(self.output2, self.hq) + self.bias3)
        return self.output3

    def backward(self, y):
        error3 = (self.output3 - y)
        error2 = np.zeros((self.batch_size, self.seq_len, self.hidden_size))
        self.dj_dwhq = np.zeros((self.batch_size, self.hidden_size, self.output_dim))
        for i in range(self.batch_size):
            error2[i] = np.dot(error3[i], self.hq.T)
            self.dj_dwhq[i] = np.dot(self.output2[i].T, error3[i])
        self.dj_dwhq = np.mean(self.dj_dwhq, axis=0)
        self.dj_b3 = np.mean(np.sum(error3, axis=0), axis=0, keepdims=True)
        
        self.hq -= np.clip(self.dj_dwhq, a_min=-1, a_max=1) * self.eta
        self.bias3 -= np.clip(self.dj_db3, a_min=-1, a_max=1) * self.eta

        self.dj_dwhh = 0
        self.dj_dwxh = 0
        self.dj_db2 = 0
        for n in reversed(range(self.seq_len)):
            cc = error2[:, n]
            dj_db2_n = np.zeros_like(self.bias2)
            dj_dwxh_n = np.zeros_like(self.xh)
            dj_dwhh_n = np.zeros_like(self.hh)
            for i in reversed(range(max(0, n - self.unroll + 1), n + 1)):
                cc *= 1 - (self.output2[:, i]) ** 2
                if i > 0:
                    dj_dwhh_n += np.dot(self.output2[:, i - 1].T, cc)
                    dj_db2_n += np.sum(cc, axis=0, keepdims=True)
                dj_dwxh_n += np.dot(self.input[:, i].T, cc)
                cc = np.dot(cc, self.hh)
            self.dj_dwhh += dj_dwhh_n
            self.dj_dwxh += dj_dwxh_n
            self.dj_db2 += dj_db2_n

        self.hh -= np.clip(self.dj_dwhh, a_min=-1, a_max=1) * self.eta
        self.xh -= np.clip(self.dj_dwxh, a_min=-1, a_max=1) * self.eta
        self.bias2 -= np.clip(self.dj_db2, a_min=-1, a_max=1) * self.eta

    def loss(self, X, y):
        yhat = self.forward(X)
        total_loss = 0
        for b in range(yhat.shape[0]):
            for i in range(yhat.shape[1]):
                total_loss += -np.dot(np.log(np.clip(yhat[b, i], 10e-10, 1)), y[b, i])
        loss = total_loss / (yhat.shape[0] * yhat.shape[1])
        return loss

    def prepare_minibatches(self, X, y, size=32, shuffle=True):
        if shuffle:
            idx = np.arange(0, X.shape[0])
            idx = np.random.permutation(idx)
            X = X[idx, :]
            y = y[idx]

        batch_num = np.ceil(X.shape[0] / size)
        
        X_batches = np.array_split(X, batch_num)
        y_batches = np.array_split(y, batch_num)

        return (X_batches, y_batches)

    def fit(self, X, y, epochs=3, eta=0.01, batch_size=32, val_X=None, val_y=None):
        self.eta = eta
        train_losses = []
        val_losses = []
        for i in range(epochs):
            self.eta = max(eta * (0.8 ** np.floor(i / 10)), 0.0001) # learning rate decay
            X_batches, y_batches = self.prepare_minibatches(X, y, batch_size)
            for X_sample, y_sample in zip(X_batches, y_batches):
                result = self.forward(X_sample)
                self.backward(y_sample)
            train_losses.append(self.loss(X, y))
            if val_X is not None and val_y is not None:
                val_losses.append(self.loss(val_X, val_y))
                print("Epoch ", i + 1, ", Loss: ", train_losses[-1], " Val loss: ", val_losses[-1])
            else:
                print("Epoch ", i + 1, ", Loss: ", train_losses[-1])
        if val_X is not None and val_y is not None:
            return (train_losses, val_losses)
        return train_losses
        
    def predict(self, x):
        probs = self.forward(x)
        batch = x.shape[0]
        a_ind = ord('a')
        probs_max = np.argmax(probs, axis=-1)
        prediction = []
        for b in range(batch):
            word = ''
            for i in range(self.seq_len):
                char_ind = probs_max[b, i]
                if char_ind < 26:
                    word += chr(a_ind + char_ind)
                elif char_ind == 26:
                    word += '$'
                elif char_ind == 27:
                    word += '*'
                else:
                    word += '#'
            prediction.append(word)
        return prediction

    def predictfrom(self, prefix):
        prefix = prefix.lower()
        charset = 29
        chars = [np.zeros(charset)]
        chars[0][27] = 1
        a_int = ord('a')
        for ch in prefix:
            ch_onehot = np.zeros(charset)
            ch_onehot[ord(ch) - a_int] = 1
            chars.append(ch_onehot)
        result = prefix
        lasthidden = np.zeros(self.hidden_size)
        for ch_onehot in chars:
            lasthidden = np.tanh(np.dot(ch_onehot, self.xh) + np.dot(lasthidden, self.hh) + self.bias2)
        while len(result) < 25:
            rnn_output = np.dot(lasthidden, self.hq) + self.bias3
            ind = np.argmax(rnn_output)
            if ind < 26:
                letter = chr(ind + a_int)
                result += letter
                ch_onehot = np.zeros(charset)
                ch_onehot[ind] = 1
                lasthidden = np.tanh(np.dot(ch_onehot, self.xh) + np.dot(lasthidden, self.hh) + self.bias2)
            else:
                break
        return result

In [7]:
file_data = read_data()
x_train, y_train = preprocess_data(file_data)
print('x_train shape: ', x_train.shape)
print('y_train shape: ', y_train.shape)

x_train shape:  (1539, 23, 29)
y_train shape:  (1539, 23, 29)


In [30]:
model = RNNSeqSoftmax(input_dim=29, seq_len=23, output_dim=29, hidden_size=64, unroll=5)

history = model.fit(x_train, y_train, epochs=50, eta=0.01, batch_size=32)

Epoch  1 , Loss:  1.3779697324317024
Epoch  2 , Loss:  1.1721163409652904
Epoch  3 , Loss:  1.1792427927538964
Epoch  4 , Loss:  1.092028319340979
Epoch  5 , Loss:  1.0991688261987205
Epoch  6 , Loss:  1.0554846175174728
Epoch  7 , Loss:  1.0379949859072573
Epoch  8 , Loss:  1.0507885856819048
Epoch  9 , Loss:  1.021701649419817
Epoch  10 , Loss:  1.025670865174154
Epoch  11 , Loss:  1.0042827183577423
Epoch  12 , Loss:  1.0153390683822157
Epoch  13 , Loss:  0.9910034100789201
Epoch  14 , Loss:  0.9948117818544917
Epoch  15 , Loss:  0.9839719593022547
Epoch  16 , Loss:  0.9773922972023048
Epoch  17 , Loss:  0.974889902100372
Epoch  18 , Loss:  0.9686588835207638
Epoch  19 , Loss:  0.9656710769367863
Epoch  20 , Loss:  0.9791834490813749
Epoch  21 , Loss:  0.9569731794988008
Epoch  22 , Loss:  0.9618548670385122
Epoch  23 , Loss:  0.9440939872533143
Epoch  24 , Loss:  0.953867740211214
Epoch  25 , Loss:  0.9430630881498937
Epoch  26 , Loss:  0.9633943328508405
Epoch  27 , Loss:  0.94121

In [31]:
out = model.predict(x_train[0].reshape(1, 23, 29))
print(out)
out = model.predict(x_train[200].reshape(1, 23, 29))
print(out)
out = model.predict(x_train[400].reshape(1, 23, 29))
print(out)
out = model.predict(x_train[600].reshape(1, 23, 29))
print(out)

['anchanosaurus$#########']
['aaahroonatas$##########']
['ainoneasachus$#########']
['aanoaaasaurus$#########']


In [32]:
word = model.predictfrom('')
print(word)
word = model.predictfrom('tyran')
print(word)
word = model.predictfrom('velo')
print(word)
word = model.predictfrom('pte')
print(word)
word = model.predictfrom('st')
print(word)

angosaurus
tyranosaurus
velosaurus
pterosaurus
sterosaurus


In [34]:
word = model.predictfrom('stegostegostegostegostegostego')
print(word)

stegostegostegostegostegostego
