In [5]:
# https://medium.com/@erikhallstrm/hello-world-rnn-83cd7105b767
DEBUG = True

import re, random, math, csv, io, string, itertools
import numpy as np
# import pandas as pd
import tensorflow as tf

In [6]:
hp = dict(
    n_layers = 2,
    hidden_size = 512,
    fc_size = 512,
    dropout = 0.9,
    batch_size = 20,
    lr = 0.001,
    lr_decay = 0.9999,
    min_lr = 0.00001,
    grad_clip = 5.,
    cuda = False,
    num_epoch = 5,
    max_length = 10
)

In [7]:
class Voc:
    SOS = "!"
    EOS = "#"
    SOS_ID = 0
    EOS_ID = 1
    def __init__(self):
        self.word2index = {self.SOS:0, self.EOS:1}
        self.word2count = {}
        self.index2word = {0:self.SOS, 1:self.EOS}
        self.n_words = 2 # Count SOS and EOS

    def index_words(self, sentence):
        for word in sentence.split(' '):
            self.index_word(word)

    def index_word(self, word):
        if word not in self.word2index:
            self.word2index[word] = self.n_words
            self.word2count[word] = 1
            self.index2word[self.n_words] = word
            self.n_words += 1
        else:
            self.word2count[word] += 1
    
def string2indicies(voc, text):
    return [voc.word2index[c] for c in text]
    
def indicies2string(voc, indicies):
    return "".join([voc.index2word[i] for i in indicies])

voc = Voc()

for c in itertools.chain(range(ord('a'), ord('z')+1),range(ord('A'),ord('Z')+1),(ord(" "),)):
    voc.index_word(chr(c))
# print(f'vocabulary size: {voc.n_words}')

In [24]:
class PairGenerator:
#     vocabulary = [chr(i) for i in itertools.chain(range(ord('a'), ord('z')+1),range(ord('A'),ord('Z')+1))]
    word_len_interval = {"a":2,"b":7}
    sent_len_interval = {"a":1,"b":10}
    
    def __init__(self,voc):
        self.voc = voc
        self.vocabulary = [c for c in voc.word2index.keys() if c not in {voc.SOS, voc.EOS}]
        
    def gen_word_pair(self):
        word_len = int(random.uniform(**self.word_len_interval))
#         word = random.choices(self.vocabulary,k=word_len)
        word = np.random.choice(self.vocabulary,word_len)
        return "".join(word), "".join(list(reversed(word)))
    
    def gen_pair(self):
        num_words = int(random.uniform(**self.sent_len_interval))
        inp, out = zip(*[self.gen_word_pair() for _ in range(num_words)])
        return self.voc.SOS+" ".join(inp)+self.voc.EOS, self.voc.SOS+" ".join(out)+self.voc.EOS
    
    def gen_batch(self, n_unrollings):
        inp, out = zip(*[self.gen_pair() for _ in range(n_unrollings)])
        return inp, out
    
    def gen_int_batch(self, n_unrollings):
        inp, out = zip(*[self.gen_pair() for _ in range(n_unrollings)])
        
        return \
            [string2indicies(self.voc, x) for x in inp], \
            [string2indicies(self.voc, x) for x in out]
        
pg = PairGenerator(voc)

In [48]:
def generate_data(n_unrollings=100, echo_step=2, batch_size=5):
    x = np.array(np.random.choice(2, n_unrollings, p=[0.5, 0.5]))
    y = np.roll(x, echo_step)
    y[0:echo_step] = 0

    x = x.reshape((batch_size, -1))  # The first index changing slowest, subseries as rows
    y = y.reshape((batch_size, -1))

    return (x, y)

In [45]:
x,y=generate_data()

In [46]:
x

array([[1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0, 0, 0],
       [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0, 1, 0],
       [0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0],
       [0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0],
       [1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 1]])

In [47]:
y

array([[0, 0, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 1, 1, 0],
       [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1, 0, 0, 0],
       [1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 0, 1],
       [0, 0, 0, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1],
       [1, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0]])

In [26]:
np.random.choice(pg.vocabulary,3)

array(['C', 'Q', 'F'], dtype='<U8')

In [27]:
X,Y = pg.gen_batch(n_unrollings=10)

In [56]:
class GRU:
    hp = dict(
        state_sz=5,
        n_classes=2,
        input_dim=2,
        ckpt_path="./checkpoints/"
    )
    def __init__(self, **hyper_parameters):
        if hyper_parameters is not None:
            for k, v in hyper_parameters.items():
                self.hp[k] = v
        self.__graph__()
    
    def _init(self):
        x = tf.placeholder(tf.int32, [None, None], "x") #batch*voc
        y = tf.placeholder(tf.int32, [None, None], "y") #batch*voc
        emb = tf.placeholder(tf.int32, [None, None],)
        
        # batch_sz*seq_len -> batch_sz*seq_ln*voc_sz
        embs = tf.get_variable('emb', [self.hp['n_classes'], self.hp['state_sz']]) 
        rnn_inputs = tf.nn.embedding_lookup(embs, x)
        
        #batch*state_sz
        init_state = tf.placeholder(tf.int32, [None, self.hp['state_sz']], "init_state")
    
    def _weights(self):
        self.wz = tf.get_variable(
            "w[z]", shape=[self.hp['state_sz'], self.hp['state_sz']],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.uz = tf.get_variable(
            "u[z]", shape=[self.hp['state_sz'], self.hp['state_sz']],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.bz = tf.get_variable(
            "b[z]", shape=[self.hp['state_sz']],
            initializer=tf.constant_initializer(0.)
        )
        
        self.wr = tf.get_variable(
            "w[r]",  shape=[self.hp['state_sz'], self.hp['state_sz']],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.ur = tf.get_variable(
            "u[r]",  shape=[self.hp['state_sz'], self.hp['state_sz']],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.br = tf.get_variable(
            "b[r]", shape=[self.hp['state_sz']],
            initializer=tf.constant_initializer(0.)
        )
        
        self.wh = tf.get_variable(
            "w[h]",  shape=[self.hp['state_sz'], self.hp['state_sz']],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.uh = tf.get_variable(
            "w[h]",  shape=[self.hp['state_sz'], self.hp['state_sz']],
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.bh = tf.get_variable(
            "b[h]", shape=[self.hp['state_sz']],
            initializer=tf.constant_initializer(0.)
        )
        
        # layer to decode results of GRU 
        self.wo= tf.get_variable(
            'w[out]', shape=[self.hp['state_sz'], self.hp['n_classes']], 
            initializer=tf.contrib.layers.xavier_initializer()
        )
        self.bo = tf.get_variable(
            'b[out]', shape=[self.hp['num_classes']], 
            initializer=tf.constant_initializer(0.)
        )
    
    def __graph__(self):
        # time cycle step
        def step(prev_state, x):
            z = tf.matmul(x,self.wz) + tf.matmul(prev, self.uz) + self.bz
            r = tf.matmul(x,self.wr) + tf.matmul(prev, self.ur) + self.br
            h = tf.matmul(x,self.wh) + tf.matmul(h*prev, self.uh) + self.bh
            return (1-z)*prev + z*h
        
        tf.reset_default_graph()
        self._init()
        self._weights()
        states = tf.scan(
            step,
            #batch_sz*seq_ln*voc_sz -> seq_len*batch_sz*voc_sz
            tf.transpose(self.rnn_inputs,[1,0,2]),
            initializer=self.init_state
        )
        # seq_len*batch_sz*voc_sz -> batch_sz*seq_ln*voc_sz
        states = tf.transpose(states,[1,0,2])
        
        states_reshaped = tf.reshape(states, [-1, state_size])
        logits = tf.matmul(states_reshaped, self.wo) + self.bo
        
        self.last_state = states[-1]
        self.predictions = tf.nn.softmax(logits)
        self.loss = tf.reduce_mean(
            tf.nn.sparse_softmax_cross_entropy_with_logits(logits,self.y)
        )
        self.train_op = tf.train.AdagradOptimizer(learning_rate=0.2).minimize(self.loss)
        
    def train(self,x_tarin, y_train,n_epochs = 10):
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())
            train_loss = 0
            try:
                for epoch_num in range(n_epochs):
                    x,y = generate_data()
                    batch_sz = x.shape[0]
                    _, train_loss_ = sess.run([self.train_op, self.loss], feed_dict = {
                        self.x : x,
                        self.y : y.flatten(),
                        self.init_state : np.zeros(
                            [batch_sz, self.hp['state_sz']]
                        )
                    })
            except KeyboardInterrupt as ex:
                print("Interrupted by user at")
            saver = tf.train.Saver()
            saver.save(sess, self.ckpt_path + "nn.mdl", global_step=i)

In [52]:
string2indicies(voc, X[0])

[0, 41, 16, 8, 54, 21, 51, 54, 23, 53, 42, 22, 25, 18, 1]

In [30]:
len_ls = list(map(len,X))

mx = max(len_ls)
diff = list(map(lambda x: mx - x, len_ls))

In [31]:
voc.word2index[' ']*xdiff

NameError: name 'xdiff' is not defined

In [57]:
gru = GRU()

ValueError: 'w[z]' is not a valid scope name

In [33]:
inp, out = zip(
    *[
        string2indicies(voc, s) 
        for s in pg.gen_pair()
    ]
)

ValueError: too many values to unpack (expected 2)

In [34]:
[a for a in pg.gen_pair()]

['!vKNE hc buYyqN#', '!ENKv ch NqyYub#']

In [36]:
x,y=pg.gen_int_batch(10)

In [37]:
x

[[0,
  46,
  10,
  54,
  14,
  35,
  54,
  0,
  35,
  38,
  6,
  17,
  14,
  4,
  1,
  14,
  44,
  53,
  10,
  54,
  54,
  6,
  4,
  11,
  37,
  41,
  11,
  54,
  42,
  18,
  41,
  46,
  51,
  38,
  54,
  15,
  47,
  27,
  54,
  11,
  42,
  36,
  38,
  36,
  33,
  54,
  44,
  41,
  4,
  42,
  32,
  1],
 [0,
  14,
  50,
  32,
  18,
  54,
  37,
  20,
  39,
  54,
  44,
  43,
  54,
  40,
  48,
  15,
  16,
  49,
  31,
  1],
 [0, 16, 24, 54, 35, 31, 1],
 [0,
  24,
  12,
  42,
  21,
  54,
  14,
  6,
  54,
  15,
  2,
  6,
  9,
  54,
  46,
  51,
  33,
  54,
  39,
  46,
  6,
  30,
  54,
  41,
  39,
  54,
  7,
  6,
  39,
  54,
  3,
  41,
  54,
  22,
  19,
  41,
  39,
  38,
  54,
  50,
  35,
  36,
  18,
  16,
  1],
 [0, 40, 35, 5, 19, 54, 6, 19, 7, 54, 53, 54, 6, 53, 2, 54, 48, 13, 33, 1],
 [0,
  0,
  35,
  38,
  6,
  17,
  14,
  4,
  1,
  11,
  40,
  54,
  53,
  22,
  35,
  8,
  12,
  26,
  54,
  26,
  51,
  16,
  54,
  39,
  41,
  53,
  29,
  36,
  54,
  32,
  38,
  54,
  37,
  3,
  21,
  37,
  