<a href="https://colab.research.google.com/github/Nanangk/Simple_Recurrent_Neural_Network_with_Keras/blob/master/SimpleRNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
#Import Library

import numpy as np

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import TimeDistributed, Dense, Dropout, SimpleRNN, RepeatVector
from tensorflow.keras.callbacks import EarlyStopping, LambdaCallback

from termcolor import colored

In [0]:
# Model vocabulary: the ten decimal digits followed by the '+' operator.
# A character's position in this string is its one-hot index.
all_chars = ''.join(str(digit) for digit in range(10)) + '+'

In [0]:
# Width of every one-hot vector: one slot per vocabulary character.
num_features = len(all_chars)
# Forward and inverse lookups between a character and its one-hot position.
char_to_index = {ch: pos for pos, ch in enumerate(all_chars)}
index_to_char = dict(enumerate(all_chars))

In [51]:
def generate_data(low=0, high=100):
  """Draw one random addition problem as (question, answer) strings.

  Both operands are sampled independently and uniformly from ``[low, high)``
  using NumPy's global RNG (``np.random.randint``).

  Args:
    low: inclusive lower bound for each operand (default 0). Must be
      non-negative, because '-' is not part of the model's vocabulary.
    high: exclusive upper bound for each operand (default 100). The default
      keeps the question (e.g. '99+99') within the 5-step encoding used
      by this notebook.

  Returns:
    (example, label), e.g. ('60+87', '147').

  Raises:
    ValueError: if ``low`` is negative (or, via numpy, if ``high <= low``).
  """
  if low < 0:
    raise ValueError("low must be non-negative: '-' is not in the vocabulary")
  first = np.random.randint(low, high)
  second = np.random.randint(low, high)
  example = f'{first}+{second}'
  label = str(first + second)
  return example, label

generate_data()

('60+87', '147')

In [52]:
#Create Model

hidden_unit = 128
max_time_steps = 5

# Encoder/decoder stack:
#   1. an encoder SimpleRNN reads the (variable-length) one-hot question and
#      keeps only its final state -> (batch, hidden_unit);
#   2. RepeatVector copies that state once per output step;
#   3. a decoder SimpleRNN emits one hidden vector per step;
#   4. a per-step softmax Dense turns each into a distribution over all_chars.
model = Sequential()
model.add(SimpleRNN(hidden_unit, input_shape=(None, num_features)))
model.add(RepeatVector(max_time_steps))
model.add(SimpleRNN(hidden_unit, return_sequences=True))
model.add(TimeDistributed(Dense(num_features, activation='softmax')))

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

model.summary()

Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
simple_rnn_6 (SimpleRNN)     (None, 128)               17920     
_________________________________________________________________
repeat_vector_3 (RepeatVecto (None, 5, 128)            0         
_________________________________________________________________
simple_rnn_7 (SimpleRNN)     (None, 5, 128)            32896     
_________________________________________________________________
time_distributed_3 (TimeDist (None, 5, 11)             1419      
Total params: 52,235
Trainable params: 52,235
Non-trainable params: 0
_________________________________________________________________


In [53]:
#Vectorize and De-Vectorize Data

def _one_hot_left_padded(text, length):
  """One-hot encode ``text`` into a (length, num_features) array, left-padded with '0'.

  Raises:
    ValueError: if ``text`` is longer than ``length``. Without this check the
      pad offset would go negative and characters would be written into the
      wrong rows via negative indexing.
    KeyError: if ``text`` contains a character outside ``all_chars``.
  """
  if len(text) > length:
    raise ValueError(
        f'{text!r} has {len(text)} characters; at most {length} fit')
  # Right-align the text and fill the unused leading steps with '0'.
  # A leading zero does not change the value of a number ('061+2' == 61+2).
  padded = text.rjust(length, '0')
  encoded = np.zeros((length, num_features))
  for step, ch in enumerate(padded):
    encoded[step, char_to_index[ch]] = 1
  return encoded


def vectorize_example(example, label):
  """Encode a (question, answer) pair as two (max_time_steps, num_features) one-hot arrays.

  Both strings are right-aligned and left-padded with '0', e.g. '61+2' -> '061+2'
  and '99' -> '00099'.

  Raises:
    ValueError: if either string is longer than ``max_time_steps``.
  """
  x = _one_hot_left_padded(example, max_time_steps)
  y = _one_hot_left_padded(label, max_time_steps)
  return x, y


# Round-trip one random sample through the encoder and show the array shapes.
sample_example, sample_label = generate_data()
print(sample_example, sample_label)
x, y = vectorize_example(sample_example, sample_label)
print(x.shape, y.shape)

26+73 99
(5, 11) (5, 11)


In [54]:
def devectorize_example(example):
  """Decode a sequence of one-hot (or softmax) rows back into a string.

  Each row is mapped to the character whose column holds the largest value.
  This works both on exact one-hot encodings and on the model's per-step
  probability outputs.

  Args:
    example: 2-D array-like with one row per time step and one column per
      vocabulary character.

  Returns:
    The decoded string, one character per row.
  """
  # One vectorized argmax over the character axis instead of a Python loop per row.
  indices = np.argmax(example, axis=-1)
  return ''.join(index_to_char[int(idx)] for idx in indices)


devectorize_example(x)

'26+73'

In [55]:
# Decode the sample label; the sum is left-padded with '0' to the full width.
devectorize_example(example=y)

'00099'

In [56]:
#Creating Dataset

def create_dataset(num_examples=2000):
  """Build ``num_examples`` random addition problems as one-hot tensors.

  Returns:
    (x, y): two float arrays of shape (num_examples, max_time_steps, num_features).
    x holds the encoded 'a+b' questions and y the encoded sums.
  """
  shape = (num_examples, max_time_steps, num_features)
  inputs = np.zeros(shape)
  targets = np.zeros(shape)
  for row in range(num_examples):
    # Sample a problem and write its encoded question/answer straight into row `row`.
    inputs[row], targets[row] = vectorize_example(*generate_data())

  return inputs, targets
x, y = create_dataset()
print(x.shape, y.shape)

(2000, 5, 11) (2000, 5, 11)


In [57]:
# Decode the first training question as a sanity check.
devectorize_example(example=x[0])

'52+34'

In [58]:
# Decode the matching training answer (zero-padded sum).
devectorize_example(example=y[0])

'00086'

In [0]:
#Training the Model

# Keep the History object so per-epoch loss/accuracy, for both training and the
# 20% validation split, can be inspected or plotted without retraining.
history = model.fit(
    x, y,
    epochs=50,
    batch_size=100,
    validation_split=0.2,
    verbose=1,
)

In [62]:
#Predict

x_test, y_test = create_dataset(20)
preds = model.predict(x_test)

# Use dedicated names for the decoded strings. Assigning to `y` here would
# overwrite the training labels from the dataset cell, and re-running the
# training cell afterwards would then fail.
n_correct = 0
for x_row, y_row, pred in zip(x_test, y_test, preds):
  y_true = devectorize_example(y_row)
  y_pred = devectorize_example(pred)
  is_correct = y_true == y_pred
  n_correct += int(is_correct)
  col = 'green' if is_correct else 'red'

  out = f'Input : {devectorize_example(x_row)} Out : {y_true} Pred : {y_pred}'
  print(colored(out, col))

print(f'{n_correct}/{len(preds)} sums predicted exactly')

[31mInput : 53+26 Out : 00079 Pred : 00089[0m
[32mInput : 061+2 Out : 00063 Pred : 00063[0m
[32mInput : 52+60 Out : 00112 Pred : 00112[0m
[32mInput : 36+72 Out : 00108 Pred : 00108[0m
[32mInput : 69+39 Out : 00108 Pred : 00108[0m
[32mInput : 76+76 Out : 00152 Pred : 00152[0m
[32mInput : 96+92 Out : 00188 Pred : 00188[0m
[32mInput : 82+12 Out : 00094 Pred : 00094[0m
[32mInput : 46+80 Out : 00126 Pred : 00126[0m
[32mInput : 87+25 Out : 00112 Pred : 00112[0m
[32mInput : 84+69 Out : 00153 Pred : 00153[0m
[32mInput : 48+85 Out : 00133 Pred : 00133[0m
[32mInput : 22+30 Out : 00052 Pred : 00052[0m
[32mInput : 050+9 Out : 00059 Pred : 00059[0m
[31mInput : 049+2 Out : 00051 Pred : 00050[0m
[31mInput : 96+93 Out : 00189 Pred : 00180[0m
[32mInput : 57+29 Out : 00086 Pred : 00086[0m
[32mInput : 58+40 Out : 00098 Pred : 00098[0m
[31mInput : 023+2 Out : 00025 Pred : 00023[0m
[31mInput : 78+96 Out : 00174 Pred : 00183[0m
