## Memory Network
* reference : https://github.com/carpedm20/MemN2N-tensorflow

In [1]:
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import sys

print (sys.version)
print (tf.__version__)

3.6.2 |Anaconda custom (64-bit)| (default, Jul 20 2017, 13:51:32) 
[GCC 4.4.7 20120313 (Red Hat 4.4.7-1)]
1.3.0


## Strory Question and Answer (End to End Data)

In [2]:
lines= [
'1 나 는 지금 배가 고프다.',
'2 나 는 집에 있다가 왔다.',
'3 나 는 지금 판교에 있다.',  
'4 너 는 지금 어디 있니?  판교',
'5 나 의 회사는 포스코ICT다.',
'6 나 는 피자 주문 하고 싶다.',
'7 너 는 무엇 을 주문 할려구?  피자'
]


### 입력값의 명사를 통해 완전한 문장 생성

In [3]:
import os
from collections import Counter


def read_data(fname, word2idx, max_words, max_sentences):
    # stories[story_ind] = [[sentence1], [sentence2], ..., [sentenceN]]
    # questions[question_ind] = {'question': [question], 'answer': [answer], 'story_index': #, 'sentence_index': #}
    stories = dict()
    questions = dict()
    
    
    if len(word2idx) == 0:
        word2idx['<null>'] = 0

    for line in lines:
        words = line.split()
        max_words = max(max_words, len(words))
        
        # Determine whether the line indicates the start of a new story
        if words[0] == '1':
            story_ind = len(stories)
            sentence_ind = 0
            stories[story_ind] = []
        
        # Determine whether the line is a question or not
        if '?' in line:
            is_question = True
            question_ind = len(questions)
            questions[question_ind] = {'question': [], 'answer': [], 'story_index': story_ind, 'sentence_index': sentence_ind}
        else:
            is_question = False
            sentence_ind = len(stories[story_ind])
        
        # Parse and append the words to appropriate dictionary / Expand word2idx dictionary
        sentence_list = []
        for k in range(1, len(words)):
            w = words[k].lower()
            
            # Remove punctuation
            if ('.' in w) or ('?' in w):
                w = w[:-1]
            
            # Add new word to dictionary
            if w not in word2idx:
                word2idx[w] = len(word2idx)
            
            # Append sentence to story dict if not question
            if not is_question:
                sentence_list.append(w)
                
                if '.' in words[k]:
                    stories[story_ind].append(sentence_list)
                    break
            
            # Append sentence and answer to question dict if question
            else:
                sentence_list.append(w)
                
                if '?' in words[k]:
                    answer = words[k + 1].lower()
                    
                    if answer not in word2idx:
                        word2idx[answer] = len(word2idx)
                    
                    questions[question_ind]['question'].extend(sentence_list)
                    questions[question_ind]['answer'].append(answer)
                    break
        
        # Update max_sentences
        max_sentences = max(max_sentences, sentence_ind+1)
    
    
    
    # Convert the words into indices
    for idx, context in stories.items():
        for i in range(len(context)):
            temp = list(map(word2idx.get, context[i]))
            context[i] = temp
    
    for idx, value in questions.items():
        temp1 = list(map(word2idx.get, value['question']))
        temp2 = list(map(word2idx.get, value['answer']))
        
        value['question'] = temp1
        value['answer'] = temp2
    
    return stories, questions, max_words, max_sentences


def pad_data(stories, questions, max_words, max_sentences):

    # Pad the context into same size with '<null>'
    for idx, context in stories.items():
        for sentence in context:           
            while len(sentence) < max_words:
                sentence.append(0)
        while len(context) < max_sentences:
            context.append([0] * max_words)
    
    # Pad the question into same size with '<null>'
    for idx, value in questions.items():
        while len(value['question']) < max_words:
            value['question'].append(0)


def depad_data(stories, questions):

    for idx, context in stories.items():
        for i in range(len(context)):
            if 0 in context[i]:
                if context[i][0] == 0:
                    temp = context[:i]
                    context = temp
                    break
                else:
                    index = context[i].index(0)
                    context[i] = context[i][:index]

    for idx, value in questions.items():
        if 0 in value['question']:
            index = value['question'].index(0)
            value['question'] = value['question'][:index]


In [4]:
import pprint

pp = pprint.PrettyPrinter()

flags = tf.app.flags

flags.DEFINE_integer("edim", 20, "internal state dimension [20]")
flags.DEFINE_integer("nhop", 3, "number of hops [3]")
flags.DEFINE_integer("mem_size", 50, "maximum number of sentences that can be encoded into memory [50]")
flags.DEFINE_integer("batch_size", 32, "batch size to use during training [32]")
flags.DEFINE_integer("nepoch", 100, "number of epoch to use during training [100]")
flags.DEFINE_integer("anneal_epoch", 25, "anneal the learning rate every <anneal_epoch> epochs [25]")
flags.DEFINE_integer("babi_task", 1, "index of bAbI task for the network to learn [1]")
flags.DEFINE_float("init_lr", 0.01, "initial learning rate [0.01]")
flags.DEFINE_float("anneal_rate", 0.5, "learning rate annealing rate [0.5]")
flags.DEFINE_float("init_mean", 0., "weight initialization mean [0.]")
flags.DEFINE_float("init_std", 0.1, "weight initialization std [0.1]")
flags.DEFINE_float("max_grad_norm", 40, "clip gradients to this norm [40]")
flags.DEFINE_string("data_dir", "./bAbI/en-valid", "dataset directory [./bAbI/en_valid]")
flags.DEFINE_string("checkpoint_dir", "./checkpoints", "checkpoint directory [./checkpoints]")
flags.DEFINE_boolean("lin_start", False, "True for linear start training, False for otherwise [False]")
flags.DEFINE_boolean("is_test", False, "True for testing, False for training [False]")
flags.DEFINE_boolean("show_progress", False, "print progress [False]")

FLAGS = flags.FLAGS

word2idx = {}
max_words = 0
max_sentences = 0

train_stories, train_questions, max_words, max_sentences = read_data(lines, word2idx, max_words, max_sentences)
valid_stories, valid_questions, max_words, max_sentences = read_data(lines, word2idx, max_words, max_sentences)
test_stories, test_questions, max_words, max_sentences = read_data(lines, word2idx, max_words, max_sentences)

pad_data(train_stories, train_questions, max_words, max_sentences)
pad_data(valid_stories, valid_questions, max_words, max_sentences)
pad_data(test_stories, test_questions, max_words, max_sentences)

idx2word = dict(zip(word2idx.values(), word2idx.keys()))
FLAGS.nwords = len(word2idx)
FLAGS.max_words = max_words
FLAGS.max_sentences = max_sentences

pp.pprint(flags.FLAGS.__flags)

{'anneal_epoch': 25,
 'anneal_rate': 0.5,
 'babi_task': 1,
 'batch_size': 32,
 'checkpoint_dir': './checkpoints',
 'data_dir': './bAbI/en-valid',
 'edim': 20,
 'init_lr': 0.01,
 'init_mean': 0.0,
 'init_std': 0.1,
 'is_test': False,
 'lin_start': False,
 'max_grad_norm': 40,
 'max_sentences': 5,
 'max_words': 8,
 'mem_size': 50,
 'nepoch': 100,
 'nhop': 3,
 'nwords': 25,
 'show_progress': False}


In [5]:
def ProgressBar(Bar):
    message = 'Loading'
    fill = '#'
    suffix = '%(percent).1f%% | ETA: %(eta)ds'


### 학습결과 출력
* Memory Network 학습 결과 출력

In [6]:
import os
import math
import random

import numpy as np
import tensorflow as tf

class MemN2N(object):
    
    def __init__(self, config, sess):
        self.nwords = config.nwords
        self.max_words = config.max_words
        self.max_sentences = config.max_sentences
        self.init_mean = config.init_mean
        self.init_std = config.init_std
        self.batch_size = config.batch_size
        self.nepoch = config.nepoch
        self.anneal_epoch = config.anneal_epoch
        self.nhop = config.nhop
        self.edim = config.edim
        self.mem_size = config.mem_size
        self.max_grad_norm = config.max_grad_norm
        
        self.lin_start = config.lin_start
        self.show_progress = config.show_progress
        self.is_test = config.is_test

        self.checkpoint_dir = config.checkpoint_dir
        
        if not os.path.isdir(self.checkpoint_dir):
            os.makedirs(self.checkpoint_dir)
        
        self.query = tf.placeholder(tf.int32, [None, self.max_words], name='input')
        self.time = tf.placeholder(tf.int32, [None, self.mem_size], name='time')
        self.target = tf.placeholder(tf.float32, [None, self.nwords], name='target')
        self.context = tf.placeholder(tf.int32, [None, self.mem_size, self.max_words], name='context')
        
        self.hid = []
        
        self.lr = None
        
        if self.lin_start:
            self.current_lr = 0.005
        else:
            self.current_lr = config.init_lr

        self.anneal_rate = config.anneal_rate
        self.loss = None
        self.optim = None
        
        self.sess = sess
        self.log_loss = []
        self.log_perp = []
    
    def build_memory(self):
        self.global_step = tf.Variable(0, name='global_step')
        
        zeros = tf.constant(0, tf.float32, [1, self.edim])
        self.A_ = tf.Variable(tf.random_normal([self.nwords - 1, self.edim], mean=self.init_mean, stddev=self.init_std))
        self.B_ = tf.Variable(tf.random_normal([self.nwords - 1, self.edim], mean=self.init_mean, stddev=self.init_std))
        self.C_ = tf.Variable(tf.random_normal([self.nwords - 1, self.edim], mean=self.init_mean, stddev=self.init_std))
        
        A = tf.concat([zeros, self.A_], axis=0)
        B = tf.concat([zeros, self.B_], axis=0)
        C = tf.concat([zeros, self.C_], axis=0)
        
        self.T_A_ = tf.Variable(tf.random_normal([self.mem_size - 1, self.edim], mean=self.init_mean, stddev=self.init_std))
        self.T_C_ = tf.Variable(tf.random_normal([self.mem_size - 1, self.edim], mean=self.init_mean, stddev=self.init_std))
        
        T_A = tf.concat([zeros, self.T_A_], axis=0)
        T_C = tf.concat([zeros, self.T_C_], axis=0)
        
        A_ebd = tf.nn.embedding_lookup(A, self.context)   # [batch_size, mem_size, max_length, edim]
        A_ebd = tf.reduce_sum(A_ebd, axis=2)              # [batch_size, mem_size, edim]
        T_A_ebd = tf.nn.embedding_lookup(T_A, self.time)  # [batch_size, mem_size, edim]
        A_in = tf.add(A_ebd, T_A_ebd)                     # [batch_size, mem_size, edim]
        
        C_ebd = tf.nn.embedding_lookup(C, self.context)   # [batch_size, mem_size, max_length, edim]
        C_ebd = tf.reduce_sum(C_ebd, axis=2)              # [batch_size, mem_size, edim]
        T_C_ebd = tf.nn.embedding_lookup(T_C, self.time)  # [batch_size, mem_size, edim]
        C_in = tf.add(C_ebd, T_C_ebd)                     # [batch_size, mem_size, edim]
        
        query_ebd = tf.nn.embedding_lookup(B, self.query) # [batch_size, max_length, edim]
        query_ebd = tf.reduce_sum(query_ebd, axis=1)      # [batch_size, edim]
        self.hid.append(query_ebd)
        
        for h in range(self.nhop):
            q3dim = tf.reshape(self.hid[-1], [-1, 1, self.edim]) # [batch_size, edim] ==> [batch_size, 1, edim]
            p3dim = tf.matmul(q3dim, A_in, transpose_b=True)     # [batch_size, 1, edim] X [batch_size, edim, mem_size]
            p2dim = tf.reshape(p3dim, [-1, self.mem_size])       # [batch_size, mem_size]
            
            # If linear start, remove softmax layers
            if self.lin_start:
                p = p2dim
            else:
                p = tf.nn.softmax(p2dim)
            
            p3dim = tf.reshape(p, [-1, 1, self.mem_size]) # [batch_size, 1, mem_size]
            o3dim = tf.matmul(p3dim, C_in)                # [batch_size, 1, mem_size] X [batch_size, mem_size, edim]
            o2dim = tf.reshape(o3dim, [-1, self.edim])    # [batch_size, edim]
            
            a = tf.add(o2dim, self.hid[-1]) # [batch_size, edim]
            self.hid.append(a)              # [input, a_1, a_2, ..., a_nhop]
    
    def build_model(self):
        self.build_memory()
        
        self.W = tf.Variable(tf.random_normal([self.edim, self.nwords], mean=self.init_mean, stddev=self.init_std))
        a_hat = tf.matmul(self.hid[-1], self.W)
        
        self.hypothesis = tf.nn.softmax(a_hat)

        self.loss = tf.nn.softmax_cross_entropy_with_logits(logits=a_hat, labels=self.target)
        
        self.lr = tf.Variable(self.current_lr)
        self.opt = tf.train.GradientDescentOptimizer(self.lr)
        
        params = [self.A_, self.B_, self.C_, self.T_A_, self.T_C_, self.W]
        grads_and_vars = self.opt.compute_gradients(self.loss, params)
        clipped_grads_and_vars = [(tf.clip_by_norm(gv[0], self.max_grad_norm), gv[1]) for gv in grads_and_vars]
        
        inc = self.global_step.assign_add(1)
        with tf.control_dependencies([inc]):
            self.optim = self.opt.apply_gradients(clipped_grads_and_vars)
        
        tf.global_variables_initializer().run()
        self.saver = tf.train.Saver()


    def train(self, train_stories, train_questions):
        N = int(math.ceil(len(train_questions) / self.batch_size))
        cost = 0
        
        if self.show_progress:
            bar = ProgressBar('Train', max=N)
        
        for idx in range(N):
            
            if self.show_progress:
                bar.next()
            
            if idx == N - 1:
                iterations = len(train_questions) - (N - 1) * self.batch_size
            else:
                iterations = self.batch_size
            
            query = np.ndarray([iterations, self.max_words], dtype=np.int32)
            time = np.zeros([iterations, self.mem_size], dtype=np.int32)
            target = np.zeros([iterations, self.nwords], dtype=np.float32)
            context = np.ndarray([iterations, self.mem_size, self.max_words], dtype=np.int32)
            
            for b in range(iterations):
                m = idx * self.batch_size + b
                
                curr_q = train_questions[m]
                q_text = curr_q['question']
                story_ind = curr_q['story_index']
                sent_ind = curr_q['sentence_index']
                answer = curr_q['answer'][0]
                
                curr_s = train_stories[story_ind]
                curr_c = curr_s[:sent_ind + 1]

                if len(curr_c) >= self.mem_size:
                    curr_c = curr_c[-self.mem_size:]
                    
                    for t in range(self.mem_size):
                        time[b, t].fill(t)
                else:
                    
                    for t in range(len(curr_c)):
                        time[b, t].fill(t)
                    
                    while len(curr_c) < self.mem_size:
                        curr_c.append([0.] * self.max_words)
                
                query[b, :] = q_text
                target[b, answer] = 1
                context[b, :, :] = curr_c

            _, loss, self.step = self.sess.run([self.optim, self.loss, self.global_step],
                                               feed_dict={self.query: query, self.time: time,
                                                          self.target: target, self.context: context})
            cost += np.sum(loss)
        
        if self.show_progress:
            bar.finish()
        
        return cost / len(train_questions)
    
    
    def test(self, test_stories, test_questions, label='Test'):
        N = int(math.ceil(len(test_questions) / self.batch_size))
        cost = 0
        
        if self.show_progress:
            bar = ProgressBar('Train', max=N)
        
        for idx in range(N):
            
            if self.show_progress:
                bar.next()
            
            if idx == N - 1:
                iterations = len(test_questions) - (N - 1) * self.batch_size
            else:
                iterations = self.batch_size
            
            query = np.ndarray([iterations, self.max_words], dtype=np.int32)
            time = np.zeros([iterations, self.mem_size], dtype=np.int32)
            target = np.zeros([iterations, self.nwords], dtype=np.float32)
            context = np.ndarray([iterations, self.mem_size, self.max_words], dtype=np.int32)
            
            for b in range(iterations):
                m = idx * self.batch_size + b
                
                curr_q = test_questions[m]
                q_text = curr_q['question']
                story_ind = curr_q['story_index']
                sent_ind = curr_q['sentence_index']
                answer = curr_q['answer'][0]
                
                curr_s = test_stories[story_ind]
                curr_c = curr_s[:sent_ind + 1]
                
                if len(curr_c) >= self.mem_size:
                    curr_c = curr_c[-self.mem_size:]
                    
                    for t in range(self.mem_size):
                        time[b, t].fill(t)
                else:
                    
                    for t in range(len(curr_c)):
                        time[b, t].fill(t)
                    
                    while len(curr_c) < self.mem_size:
                        curr_c.append([0.] * self.max_words)
                
                query[b, :] = q_text
                target[b, answer] = 1
                context[b, :, :] = curr_c

            _, loss, self.step = self.sess.run([self.optim, self.loss, self.global_step],
                                               feed_dict={self.query: query, self.time: time,
                                                          self.target: target, self.context: context})
            cost += np.sum(loss)
        
        if self.show_progress:
            bar.finish()
        
        return cost / len(test_questions)
    
    
    def run(self, train_stories, train_questions, test_stories, test_questions):
        if not self.is_test:# add not

            for idx in range(self.nepoch):
                train_loss = np.sum(self.train(train_stories, train_questions))
                test_loss = np.sum(self.test(test_stories, test_questions, label='Validation'))
                
                self.log_loss.append([train_loss, test_loss])
                
                state = {
                    'loss': train_loss,
                    'epoch': idx,
                    'learning_rate': self.current_lr,
                    'validation_loss': test_loss
                }
                
                print(state)
                
                
                # learning rate annealing
                if (not idx == 0) and (idx % self.anneal_epoch == 0):
                    self.current_lr = self.current_lr * self.anneal_rate
                    self.lr.assign(self.current_lr).eval()
            
                # If validation loss stops decreasing, insert softmax layers
                if idx == 0:
                    pass
                else:
                    if self.log_loss[idx][1] > self.log_loss[idx - 1][1]:
                        self.lin_start = False

                if idx % 10 == 0:
                    self.saver.save(self.sess,
                                    os.path.join(self.checkpoint_dir, "MemN2N.model"),
                                    global_step=self.step.astype(int))
        else:
            self.load()
            
            valid_loss = np.sum(self.test(train_stories, train_questions, label='Validation'))
            test_loss = np.sum(self.test(test_stories, test_questions, label='Test'))
            
            state = {
                'validation_loss': valid_loss,
                'test_loss': test_loss
            }
            
            print(state)


    def predict(self, test_stories, test_questions):
        self.load()

        num_instances = len(test_questions)

        query = np.ndarray([num_instances, self.max_words], dtype=np.int32)
        time = np.zeros([num_instances, self.mem_size], dtype=np.int32)
        target = np.zeros([num_instances, self.nwords], dtype=np.float32)
        context = np.ndarray([num_instances, self.mem_size, self.max_words], dtype=np.int32)

        for b in range(num_instances):
            
            curr_q = test_questions[b]
            q_text = curr_q['question']
            story_ind = curr_q['story_index']
            sent_ind = curr_q['sentence_index']
            answer = curr_q['answer'][0]
            
            curr_s = test_stories[story_ind]
            curr_c = curr_s[:sent_ind + 1]
            
            if len(curr_c) >= self.mem_size:
                curr_c = curr_c[-self.mem_size:]
                
                for t in range(self.mem_size):
                    time[b, t].fill(t)
            else:
                
                for t in range(len(curr_c)):
                    time[b, t].fill(t)
                
                while len(curr_c) < self.mem_size:
                    curr_c.append([0.] * self.max_words)
            
            query[b, :] = q_text
            target[b, answer] = 1
            context[b, :, :] = curr_c

        predictions = self.sess.run(self.hypothesis, feed_dict={self.query: query, self.time: time, self.context: context})

        return predictions, target


        
    def load(self):
        print(' [*] Reading checkpoints...')
        ckpt = tf.train.get_checkpoint_state(self.checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            self.saver.restore(self.sess, ckpt.model_checkpoint_path)
        else:
            raise Exception(" [!] No checkpoint found")

In [7]:
with tf.Session() as sess:
    model = MemN2N(FLAGS, sess)
    model.build_model()

    if FLAGS.is_test:
        model.run(valid_stories, valid_questions, test_stories, test_questions)
    else:
        model.run(train_stories, train_questions, valid_stories, valid_questions)
        
    predictions, target = model.predict(train_stories, train_questions)


{'loss': 3.1342921257019043, 'epoch': 0, 'learning_rate': 0.01, 'validation_loss': 3.1048181056976318}
{'loss': 3.0751833915710449, 'epoch': 1, 'learning_rate': 0.01, 'validation_loss': 3.0453207492828369}
{'loss': 3.0151634216308594, 'epoch': 2, 'learning_rate': 0.01, 'validation_loss': 2.9846458435058594}
{'loss': 2.9537034034729004, 'epoch': 3, 'learning_rate': 0.01, 'validation_loss': 2.9222722053527832}
{'loss': 2.8902895450592041, 'epoch': 4, 'learning_rate': 0.01, 'validation_loss': 2.8576936721801758}
{'loss': 2.8244237899780273, 'epoch': 5, 'learning_rate': 0.01, 'validation_loss': 2.7904205322265625}
{'loss': 2.7556252479553223, 'epoch': 6, 'learning_rate': 0.01, 'validation_loss': 2.7199821472167969}
{'loss': 2.6834366321563721, 'epoch': 7, 'learning_rate': 0.01, 'validation_loss': 2.6459360122680664}
{'loss': 2.6074314117431641, 'epoch': 8, 'learning_rate': 0.01, 'validation_loss': 2.567875862121582}
{'loss': 2.5272259712219238, 'epoch': 9, 'learning_rate': 0.01, 'validatio

{'loss': 0.10035804659128189, 'epoch': 99, 'learning_rate': 0.00125, 'validation_loss': 0.099972523748874664}
 [*] Reading checkpoints...
INFO:tensorflow:Restoring parameters from ./checkpoints/MemN2N.model-182


In [8]:

index = 0

depad_data(train_stories, train_questions)

question = train_questions[index]['question']
answer = train_questions[index]['answer']
story_index = train_questions[index]['story_index']
sentence_index = train_questions[index]['sentence_index']

story = train_stories[story_index][:sentence_index + 1]

story = [list(map(idx2word.get, sentence)) for sentence in story]
question = list(map(idx2word.get, question))
prediction = [idx2word[np.argmax(predictions[index])]]
answer = list(map(idx2word.get, answer))

print('Story words:')
pp.pprint(story)
print('\nQuestion:')
pp.pprint(question)
print('\nPrediction:')
pp.pprint(prediction)
print('\nAnswer:')
pp.pprint(answer)
print('\nCorrect:')
pp.pprint(prediction == answer)

Story words:
[['나', '는', '지금', '배가', '고프다'],
 ['나', '는', '집에', '있다가', '왔다'],
 ['나', '는', '지금', '판교에', '있다']]

Question:
['너', '는', '지금', '어디', '있니']

Prediction:
['판교']

Answer:
['판교']

Correct:
True


In [9]:
index = 1

depad_data(train_stories, train_questions)

question = train_questions[index]['question']
answer = train_questions[index]['answer']
story_index = train_questions[index]['story_index']
sentence_index = train_questions[index]['sentence_index']

story = train_stories[story_index][:sentence_index + 1]

story = [list(map(idx2word.get, sentence)) for sentence in story]
question = list(map(idx2word.get, question))
prediction = [idx2word[np.argmax(predictions[index])]]
answer = list(map(idx2word.get, answer))

print('\nQuestion:')
pp.pprint(question)
print('\nPrediction:')
pp.pprint(prediction)
print('\nAnswer:')
pp.pprint(answer)
print('\nCorrect:')
pp.pprint(prediction == answer)


Question:
['너', '는', '무엇', '을', '주문', '할려구']

Prediction:
['피자']

Answer:
['피자']

Correct:
True
