# CNTK 208: ReasoNet for Machine Comprehension

## Introduction and Background

This hands-on tutorial walks you through implementing [ReasoNet](https://posenhuang.github.io/papers/reasonet_iclr_2017.pdf) in the Microsoft Cognitive Toolkit. The machine comprehension task is to find the answer to a question given a paragraph of text. 
In this tutorial, we will use the [CNN data](https://github.com/deepmind/rc-data) as an example. The data consists of tuples (q,d,a,A), where q is the query, d is the document, a is the candidate list and A is the true answer. 

### Model Structure

![](ReasoNet/components.png) 
![](ReasoNet/reasonet.png) 

## Data preparing


### Download data
The data can be downloaded via (https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTTljRDVZMFJnVWM) or (https://github.com/deepmind/rc-data).
The downloaded data is packaged as a gz file, and it needs to be reformatted before it can be fed to CNTK. After unpacking the file, we get three folders (training, test, validation), each containing many files, where each file consists of a paragraph of text, a question, the answer to the question and a list of entities. First we need to merge the files of each folder into a single file with the following script:

In [None]:
import io
import os
import re
import requests
import sys
import tarfile
import shutil

def merge_files(folder, target):
  """
  Merge every sample file under a folder into one tab-separated text file.

  Each sample file is read line by line: line index 2 is taken as the context,
  index 4 as the query, index 6 as the answer and every line from index 8 on
  as one entity entry. One output line is written per sample in the form
  query, answer, context, entities..., separated by tabs.

  The output is first written to `target + ".partial"` and only renamed to
  `target` once all samples are merged, so an interrupted run is never mistaken
  for a finished one. Files are read and written as UTF-8, which is the
  encoding the later conversion step uses to read the merged file.

  Args:
    folder (`str`): the folder holding the sample files
    target (`str`): the path of the merged file; nothing is done if it already exists
  """
  if os.path.exists(target):
    return
  count = 0
  merged = 0
  all_files = os.listdir(folder)
  print("Start to merge {0} files under folder {1} as {2}".format(len(all_files), folder, target))
  partial = target + ".partial"
  # Open the output once instead of re-opening it for every sample
  with open(partial, 'w', encoding='utf-8') as output:
    for f in all_files:
      txt = os.path.join(folder, f)
      if os.path.isfile(txt):
        with open(txt, encoding='utf-8') as sample:
          content = sample.readlines()
        context = content[2].strip()
        query = content[4].strip()
        answer = content[6].strip()
        entities = [line.strip() for line in content[8:]]
        output.write("{0}\t{1}\t{2}\t{3}\n".format(query, answer, context, "\t".join(entities)))
        merged += 1
      # Progress is counted per directory entry, including skipped sub-folders
      count += 1
      if count%1000==0:
        sys.stdout.write(".")
        sys.stdout.flush()
  if merged > 0:
    os.replace(partial, target)
  else:
    # Nothing was merged: do not leave an empty target behind
    os.remove(partial)
  print()
  print("Finished to merge {0}".format(target))

def download_cnn(target="."):
  """
  Download the CNN question archive and unpack it under a folder.

  Any existing `cnn` sub-folder of `target` is removed first. The download is
  a two-step request: the first response is searched for a `confirm=` token,
  which is then appended to the URL for the actual download.

  Args:
    target (`str`): the folder to unpack the archive into; created if missing

  Raises:
    RuntimeError: if no confirm token is found in the first response
    requests.HTTPError: if either HTTP request returns an error status
  """
  cnn_folder = os.path.join(target, "cnn")
  if os.path.exists(cnn_folder):
    shutil.rmtree(cnn_folder)
  if not os.path.exists(target):
    os.makedirs(target)
  url="https://drive.google.com/uc?export=download&id=0BwmD_VLjROrfTTljRDVZMFJnVWM"
  print("Start to download CNN data from {0} to {1}".format(url, target))
  pre_request = requests.get(url)
  pre_request.raise_for_status()
  confirm_match = re.search(r"confirm=(.{4})", pre_request.content.decode("utf-8"))
  if confirm_match is None:
    # Fail with a clear message instead of an AttributeError on None.group
    raise RuntimeError("Could not find the download confirm token in the response from {0}".format(url))
  confirm_url = url + "&confirm=" + confirm_match.group(1)
  download_request = requests.get(confirm_url, cookies=pre_request.cookies)
  download_request.raise_for_status()
  # NOTE(review): extractall trusts the member paths inside the archive
  with tarfile.open(mode="r:gz", fileobj=io.BytesIO(download_request.content)) as tar:
    tar.extractall(target)
  print("Finished to download {0} to {1}".format(url, target))

def file_exists(src):
  """Return True when `src` is the path of an existing regular file."""
  # isfile is already False for a missing path, so no separate exists check is needed
  return os.path.isfile(src)

# Folder that holds the raw download and the merged per-split files
data_path = "../Examples/LanguageUnderstanding/ReasoNet/Data"
raw_train_data = os.path.join(data_path, "training.txt")
raw_test_data = os.path.join(data_path, "test.txt")
raw_validation_data = os.path.join(data_path, "validation.txt")

# Download the archive only when at least one merged file is missing
merged_files = (raw_train_data, raw_test_data, raw_validation_data)
if not all(file_exists(merged) for merged in merged_files):
  download_cnn(data_path)

# merge_files is a no-op for every split whose merged file already exists
for split, merged in zip(("training", "test", "validation"), merged_files):
  merge_files(os.path.join(data_path, "cnn/questions/" + split), merged)
print("All necessary data are downloaded to {0}".format(data_path))

### Convert to CNTK Text Format

CNTK consumes a special text format for training data, so we need to convert the downloaded data files into the [CNTK text format](https://github.com/Microsoft/CNTK/wiki/BrainScript-CNTKTextFormat-Reader). Here is the script to do that.

In [None]:
import sys
import os
import math
import functools
import numpy as np

class WordFreq:
  """A vocabulary entry: the word, its id and how often it was seen."""
  def __init__(self, word, id, freq):
    self.word = word
    self.id = id
    self.freq = freq

class Vocabulary:
  """Build word vocabulary with frequency"""
  def __init__(self, name):
    self.name = name
    # size is only refreshed by build_index and load, not by push
    self.size = 0
    self.__dict = {}
    self.__has_index = False

  def push(self, word):
    """Count one occurrence of `word`, adding it to the vocabulary if it is new."""
    if word in self.__dict:
      self.__dict[word].freq += 1
    else:
      self.__dict[word] = WordFreq(word, len(self.__dict), 1)

  def build_index(self, max_size):
    """
    Keep the `max_size` most frequent words and re-assign ids by rank.

    Words are ordered by descending frequency; ties are broken by descending word.
    """
    items = sorted(self.__dict.values(), key=lambda it: (it.freq, it.word), reverse=True)
    if len(items)>max_size:
      del items[max_size:]
    self.size=len(items)
    self.__dict.clear()
    for it in items:
      it.id = len(self.__dict)
      self.__dict[it.word] = it
    self.__has_index = True

  def save(self, dst):
    """
    Write the vocabulary to an open text file.

    A `name<TAB>size` header is written when the vocabulary has a name, followed
    by one `word<TAB>id<TAB>freq` line per entry in id order. The index is built
    first if that has not happened yet.
    """
    if not self.__has_index:
      self.build_index(sys.maxsize)
    if self.name is not None:
      dst.write("{0}\t{1}\n".format(self.name, self.size))
    for it in sorted(self.__dict.values(), key=lambda it:it.id):
      dst.write("{0}\t{1}\t{2}\n".format(it.word, it.id, it.freq))

  def load(self, src):
    """
    Read a vocabulary written by `save` from an open text file.

    With a two-field header, the name is taken from it and at most `size`
    entries are read, so several vocabularies can share one file. A first line
    with three fields is treated as an entry of a header-less dump, and reading
    continues to the end of the file.
    """
    line = src.readline()
    if line == "":
      return
    head = line.rstrip('\n').split()
    max_size = sys.maxsize
    cnt = 0
    if len(head) == 2:
      self.name = head[0]
      max_size = int(head[1])
    elif len(head) == 3:
      # No header (saved with name None): the first line is already an entry
      self.__dict[head[0]] = WordFreq(head[0], int(head[1]), int(head[2]))
      cnt = 1
    while cnt < max_size:
      line = src.readline()
      if line == "":
        break
      items = line.rstrip('\n').split()
      self.__dict[items[0]] = WordFreq(items[0], int(items[1]), int(items[2]))
      cnt += 1
    self.size = len(self.__dict)
    self.__has_index = True

  def __getitem__(self, key):
    """Return the entry for `key`, or None when the word is unknown."""
    return self.__dict.get(key)

  def values(self):
    return self.__dict.values()

  def __len__(self):
    return self.size

  def __contains__(self, q):
    return q in self.__dict

  @staticmethod
  def is_cnn_entity(word):
    """Return True for the `@entity...` and `@placeholder...` tokens of the CNN data."""
    return word.startswith(('@entity', '@placeholder'))

  @staticmethod
  def load_vocab(vocab_src):
    """
    Load vocabulary from file.

    Args:
      vocab_src (`str`): the file stored with the vocabulary data
      
    Returns:
      :class:`Vocabulary`: Vocabulary of the entities
      :class:`Vocabulary`: Vocabulary of the words
    """
    word_vocab = Vocabulary("WordVocab")
    entity_vocab = Vocabulary("EntityVocab")
    with open(vocab_src, 'r', encoding='utf-8') as src:
      # The entity vocabulary is stored first, see build_vocab
      entity_vocab.load(src)
      word_vocab.load(src)
    return entity_vocab, word_vocab

  @staticmethod
  def build_vocab(input_src, vocab_dst, max_size=50000):
    """
    Build vocabulary from raw corpus file.

    Args:
      input_src (`str`): the path of the corpus file
      vocab_dst (`str`): the path of the vocabulary file to save the built vocabulary
      max_size (`int`): the maximum size of the word vocabulary
    Returns:
      :class:`Vocabulary`: Vocabulary of the entities
      :class:`Vocabulary`: Vocabulary of the words
    """
    # Leave the first as Unknown
    max_size -= 1
    word_vocab = Vocabulary("WordVocab")
    entity_vocab = Vocabulary("EntityVocab")
    linenum = 0
    print("Start build vocabulary from {0} with maxium words {1}. Saved to {2}".format(input_src, max_size, vocab_dst))
    with open(input_src, 'r', encoding='utf-8') as src:
      all_lines = src.readlines()
      print("Total lines to process: {0}".format(len(all_lines)))
      for line in all_lines:
        line = line.strip('\n')
        ans, query_words, context_words = Vocabulary.parse_corpus_line(line)
        for q in query_words + context_words:
          if Vocabulary.is_cnn_entity(q):
            entity_vocab.push(q)
          else:
            word_vocab.push(q)
        linenum += 1
        if linenum%1000==0:
          sys.stdout.write(".")
          sys.stdout.flush()
    print()
    entity_vocab.build_index(max_size)
    word_vocab.build_index(max_size)
    with open(vocab_dst, 'w', encoding='utf-8') as dst:
      entity_vocab.save(dst)
      word_vocab.save(dst)
    print("Finished to generate vocabulary from: {0}".format(input_src))
    return entity_vocab, word_vocab

  @staticmethod
  def parse_corpus_line(line):
    """
    Parse a tab-separated corpus line to answer, query and context.

    The first field is the query, the second the answer and the third the
    context; any further fields are ignored.

    Args:
      line (`str`): A line of text of the merged corpus
    Returns:
      :`str`: Answer word
      :`str[]`: Array of query words
      :`str[]`: Array of context/passage words

    """
    data = line.split('\t')
    query = data[0]
    answer = data[1]
    context = data[2]
    query_words = query.split()
    context_words = context.split()
    return answer, query_words, context_words

  @staticmethod
  def _entity_feature_id(entities, token):
    """Return the 1-based feature id of an entity token, or 0 when it is not in the vocabulary."""
    item = entities[token]
    return item.id + 1 if item is not None else 0

  @staticmethod
  def _word_feature_id(entities, words, token):
    """Return the feature id of a word (offset past the entity ids), or 0 when it is unknown."""
    item = words[token]
    return (item.id + 1 + entities.size) if item is not None else 0

  @staticmethod
  def build_corpus(entities, words, corpus, output, max_seq_len=100000):
    """
    Build featurized corpus and store it in CNTK Text Format.

    Feature id 0 is reserved for unknown tokens, entity ids follow from 1 and
    word ids follow after the entity ids. Entities that are missing from the
    entity vocabulary are written with id 0. A sample is skipped when its answer
    is not in the entity vocabulary or does not occur in the (truncated) context.

    Args:
      entities (class:`Vocabulary`): The entities vocabulary
      words (class:`Vocabulary`): The words vocabulary
      corpus (`str`): The file path of the raw corpus
      output (`str`): The file path to store the featurized corpus data file
      max_seq_len (`int`): The maximum number of context words kept per sample
    """
    seq_id = 0
    print("Start to build CTF data from: {0}".format(corpus))
    with open(corpus, 'r', encoding = 'utf-8') as corp:
      with open(output, 'w', encoding = 'utf-8') as outf:
        all_lines = corp.readlines()
        print("Total lines to prcess: {0}".format(len(all_lines)))
        for line in all_lines:
          line = line.strip('\n')
          ans, query_words, context_words = Vocabulary.parse_corpus_line(line)
          ans_item = entities[ans]
          if ans_item is None:
            # The answer was never seen while building the vocabulary, so it cannot be labelled
            continue
          context_ids = []
          is_entity = []
          entity_ids = []
          labels = []
          answer_idx = None
          for pos, q in enumerate(context_words):
            if pos >= max_seq_len:
              break
            if Vocabulary.is_cnn_entity(q):
              feature_id = Vocabulary._entity_feature_id(entities, q)
              context_ids.append(feature_id)
              entity_ids.append(feature_id)
              is_entity.append(1)
              if feature_id == ans_item.id + 1:
                labels.append(1)
                answer_idx = pos
              else:
                labels.append(0)
            else:
              context_ids.append(Vocabulary._word_feature_id(entities, words, q))
              is_entity.append(0)
              labels.append(0)
          if answer_idx is None:
            continue
          query_ids = []
          for q in query_words:
            if Vocabulary.is_cnn_entity(q):
              query_ids.append(Vocabulary._entity_feature_id(entities, q))
            else:
              query_ids.append(Vocabulary._word_feature_id(entities, words, q))
          #Write featurized ids
          # Only the first row of a sequence carries the sequence id
          outf.write("{0}".format(seq_id))
          for i in range(max(len(context_ids), len(query_ids))):
            if i < len(query_ids):
              outf.write(" |Q {0}:1".format(query_ids[i]))
            if i < len(context_ids):
              outf.write(" |C {0}:1".format(context_ids[i]))
              outf.write(" |E {0}".format(is_entity[i]))
              outf.write(" |L {0}".format(labels[i]))
            if i < len(entity_ids):
              outf.write(" |EID {0}:1".format(entity_ids[i]))
            outf.write("\n")
          seq_id += 1
          if seq_id%1000 == 0:
            sys.stdout.write(".")
            sys.stdout.flush()
    print()
    print("Finished to build corpus from {0}".format(corpus))
  
# Output locations of the vocabulary and of the featurized (CTF) data files
vocab_path = os.path.join(data_path, "cnn.vocab")
train_ctf = os.path.join(data_path, "training.ctf")
test_ctf = os.path.join(data_path, "test.ctf")
validation_ctf = os.path.join(data_path, "validation.ctf")
vocab_size = 101000

# Rebuild everything as soon as one of the CTF files is missing
ctf_files = (train_ctf, test_ctf, validation_ctf)
if not all(file_exists(ctf) for ctf in ctf_files):
  # The vocabulary is built from the training split only and reused for the other splits
  entity_vocab, word_vocab = Vocabulary.build_vocab(raw_train_data, vocab_path, vocab_size)
  for raw, ctf in zip((raw_train_data, raw_test_data, raw_validation_data), ctf_files):
    Vocabulary.build_corpus(entity_vocab, word_vocab, raw, ctf)
print("Training data conversion finished.")

### Create Reader
The data is stored in CNTK Text Format and we need to create a reader to consume it. There are 5 columns/streams in the data file: context, query, entity indication, label and entity ids. Here is an example:

0 |Q 586:1 |C 626:1 |E 0 |L 0 |EID 3:1

 |Q 12:1 |C 3:1 |E 1 |L 0 |EID 5:1
 
 |Q 2758:1 |C 625:1 |E 0 |L 0 |EID 4:1
 
 |Q 603:1 |C 1268:1 |E 0 |L 0 |EID 8:1
 
 |Q 933:1 |C 1516:1 |E 0 |L 0 |EID 10:1
 
 |Q 594:1 |C 757:1 |E 0 |L 0 |EID 13:1
 
 |Q 33:1 |C 586:1 |E 0 |L 0 |EID 14:1
 
 |Q 587:1 |C 4669:1 |E 0 |L 0 |EID 23:1
 
 |Q 10:1 |C 1712:1 |E 0 |L 0 |EID 10:1
 
 |Q 594:1 |C 591:1 |E 0 |L 0 |EID 10:1

Here the first column is the sequence id, 0. The second is the features of the Query, the third is the features of the Context, the fourth is a boolean that indicates whether that word in the Context is an entity, and the fifth is the Label, which indicates whether that word in the Context is the answer. The last is the ID of the entities in the Context.

In [None]:
import sys
from cntk.io import MinibatchSource, CTFDeserializer, StreamDef, StreamDefs, INFINITELY_REPEAT, DEFAULT_RANDOMIZATION_WINDOW
import cntk.ops as ops
from cntk.layers.blocks import _INFERRED, Parameter
from cntk.internal import _as_tuple, sanitize_input
import cntk.learner as learner

def create_reader(path, vocab_dim, entity_dim, randomize, rand_size= DEFAULT_RANDOMIZATION_WINDOW, size=INFINITELY_REPEAT):
  """
  Create the minibatch source that reads a CTF data file for the model.

  The CTF fields C, Q, E, L and EID are exposed as the streams `context`,
  `query`, `entities`, `label` and `entity_ids`.

  Args:
    path: The path of the CTF data file
    vocab_dim: The dimension of the sparse context/query streams
    entity_dim: The dimension of the sparse entity id stream
    randomize: Whether to shuffle the data before feeding it into the trainer
    rand_size: Accepted but currently not passed on to the reader
    size: Accepted but currently not passed on to the reader
  """
  stream_defs = StreamDefs(
    context  = StreamDef(field='C', shape=vocab_dim, is_sparse=True),
    query    = StreamDef(field='Q', shape=vocab_dim, is_sparse=True),
    entities  = StreamDef(field='E', shape=1, is_sparse=False),
    label   = StreamDef(field='L', shape=1, is_sparse=False),
    entity_ids   = StreamDef(field='EID', shape=entity_dim, is_sparse=True)
    )
  deserializer = CTFDeserializer(path, stream_defs)
  return MinibatchSource(deserializer, randomize=randomize)



### Utils
We need some utils to be used in the model creation and training stage

In [None]:
import os
import numpy as np
from datetime import datetime
import math

class logger:
  """
  Minimal file logger shared through class-level state.

  `init` picks the log file, `log` appends time-stamped messages to it and
  initializes the logger on first use if `init` was never called.
  """
  __name=''
  __logfile=''

  @staticmethod
  def init(name=''):
    """
    Create the `model` and `log` folders and choose a fresh log file.

    Args:
      name (`str`): prefix of the log file name; 'train' is used when empty or None
    """
    for folder in ("model", "log"):
      if not os.path.exists(folder):
        os.mkdir(folder)
    # Store the given name too; previously only the 'train' fallback was ever kept
    logger.__name = name if name else 'train'
    logger.__logfile = 'log/{}_{}.log'.format(logger.__name, datetime.now().strftime("%m-%d_%H.%M.%S"))
    if os.path.exists(logger.__logfile):
      os.remove(logger.__logfile)
    print('Log with log file: {0}'.format(logger.__logfile))

  @staticmethod
  def log(message, toconsole=True):
    """
    Append a time-stamped message to the log file.

    Args:
      message: the message to log
      toconsole (`bool`): also print the message when True
    """
    if logger.__logfile == '' or logger.__logfile is None:
      logger.init()
    if toconsole:
      print(message)
    with open(logger.__logfile, 'a') as logf:
      logf.write("{}| {}\n".format(datetime.now().strftime("%Y-%m-%d %H:%M:%S"), message))

class uniform_initializer:
  """
  Seeded source of uniform samples of the form `u*scale + bias`, with u drawn from [0, 1).

  The samples come from numpy's global random state, which is re-seeded on
  construction and on every `reset`.
  """
  def __init__(self, scale=1, bias=0, seed=0):
    self.seed = seed
    self.scale = scale
    self.bias = bias
    self.reset()

  def reset(self):
    """Re-seed numpy's global random state with this initializer's seed."""
    np.random.seed(self.seed)

  def next(self, size=None):
    """Draw `size` samples (a single sample when size is None)."""
    unit_samples = np.random.uniform(0, 1, size)
    return unit_samples*self.scale + self.bias

def create_random_matrix(rows, columns):
  """
  Create a (rows, columns) float32 matrix of uniform random values.

  With limit = sqrt(6/(rows+columns)), the values are drawn row by row from
  [-limit, limit) using a `uniform_initializer` with its default seed.
  """
  limit = math.sqrt(6/(rows+columns))
  # scale 2*limit with bias -limit maps [0, 1) onto [-limit, limit)
  rand = uniform_initializer(limit*2, -(limit*2)/2)
  row_vectors = [np.array(rand.next(columns), dtype=np.float32) for _ in range(rows)]
  return np.ndarray((rows, columns), dtype=np.float32, buffer=np.array(row_vectors))

def load_embedding(embedding_path, vocab_path, dim, init=None):
  """
  Build the embedding matrix for the whole vocabulary from a pre-trained embedding file.

  Each line of the embedding file is a word followed by its vector values,
  separated by single spaces. Row 0 is the unknown token, the entity rows
  follow, and the row of a word is its id offset by the entity count plus one.
  Only rows of words found in both the file and the word vocabulary are taken
  from the file; all other rows (unknown, entities, uncovered words) are filled
  from `init`, or with zeros when `init` is not given.

  Args:
    embedding_path (`str`): the path of the pre-trained embedding text file
    vocab_path (`str`): the path of the vocabulary file written by `Vocabulary.build_vocab`
    dim (`int`): the embedding dimension used for the rows that are not in the file
    init: optional initializer with `reset()` and `next(size)`, e.g. `uniform_initializer`

  Returns:
    A float32 numpy array of shape (entity count + word count + 1, dim)
  """
  # Vocabulary has no load_bingvocab; load_vocab is the loader it defines
  entity_vocab, word_vocab = Vocabulary.load_vocab(vocab_path)
  vocab_dim = len(entity_vocab) + len(word_vocab) + 1
  entity_size = len(entity_vocab)
  item_embedding = [None]*vocab_dim
  # Read as UTF-8, the same encoding the vocabulary file is written in
  with open(embedding_path, 'r', encoding='utf-8') as embedding:
    for line in embedding:
      item = line.strip('\n').split(' ')
      entry = word_vocab[item[0]]
      if entry is not None:
        item_embedding[entry.id + entity_size + 1] = np.array(item[1:], dtype="|S").astype(np.float32)
  if init is not None:
    init.reset()

  for i in range(vocab_dim):
    if item_embedding[i] is None:
      if init:
        item_embedding[i] = np.array(init.next(dim), dtype=np.float32)
      else:
        item_embedding[i] = np.array([0]*dim, dtype=np.float32)
  return np.ndarray((vocab_dim, dim), dtype=np.float32, buffer=np.array(item_embedding))

### Basic components
Here we provide some basic components that will be used in the model to simplify the model creation

In [None]:
import sys
import os
from datetime import datetime
import numpy as np
from cntk import Trainer, Axis, device, combine
from cntk.layers.blocks import Stabilizer, _initializer_for,  _INFERRED, Parameter, Placeholder
from cntk.layers import Recurrence, Convolution
from cntk.ops import input_variable, cross_entropy_with_softmax, classification_error, sequence, reduce_sum, \
    parameter, times, element_times, past_value, plus, placeholder_variable, reshape, constant, sigmoid, convolution, tanh, times_transpose, greater, cosine_distance, element_divide, element_select, exp, future_value, past_value
from cntk.internal import _as_tuple, sanitize_input
from cntk.initializer import uniform, glorot_uniform

def gru_cell(shape, init=glorot_uniform(), name=''):
  """
  Build a GRU cell as a CNTK function of an input and the previous state.

  With x the input and S the previous state, the cell computes
    z = sigmoid(x*Uz + S*Wz)
    r = sigmoid(x*Ur + S*Wr)
    h = tanh(x*Uh + (S.r)*Wh)
    s = (1-z).h + z.S
  where '.' is the element-wise product and s is the new state.

  Args:
    shape: the shape of the state; must be rank 1
    init: the initializer used for all six weight parameters
    name: the name given to the output node s

  Returns:
    The combined function producing s, with a `create_placeholder` attribute
    that creates a placeholder of the state shape.

  Raises:
    ValueError: if `shape` is not rank 1
  """
  shape = _as_tuple(shape)

  if len(shape) != 1 :
    raise ValueError("gru_cell: shape must be vectors (rank-1 tensors)")

  # shape of the state-to-state weights: the state shape repeated twice
  cell_shape_stacked = shape * 2

  # parameters: W* multiply the previous state, U* multiply the input
  # (the input dimension of U* is left to be inferred)
  Wz = Parameter(cell_shape_stacked, init = init, name='Wz')
  Wr = Parameter(cell_shape_stacked, init = init, name='Wr')
  Wh = Parameter(cell_shape_stacked, init = init, name='Wh')
  Uz = Parameter(_INFERRED + shape, init = init, name = 'Uz')
  Ur = Parameter(_INFERRED + shape, init = init, name = 'Ur')
  Uh = Parameter(_INFERRED + shape, init = init, name = 'Uh')

  def create_s_placeholder():
    # we pass the known dimensions here, which makes dimension inference easier
    return Placeholder(shape=shape, name='S') # placeholder for the previous state

  # parameters to model function
  x = Placeholder(name='gru_block_arg')
  prev_status = create_s_placeholder()

  # formula of model function
  Sn_1 = prev_status

  # z: update gate, r: reset gate, h: candidate state, s: new state
  z = sigmoid(times(x, Uz, name='x*Uz') + times(Sn_1, Wz, name='Sprev*Wz'), name='z')
  r = sigmoid(times(x, Ur, name='x*Ur') + times(Sn_1, Wr, name='Sprev*Wr'), name='r')
  h = tanh(times(x, Uh, name='x*Uh') + times(element_times(Sn_1, r, name='Sprev*r'), Wh), name='h')
  s = plus(element_times((1-z), h, name='(1-z)*h'), element_times(z, Sn_1, name='z*SPrev'), name=name)
  apply_x_s = combine([s])
  # expose the state placeholder factory on the returned function
  apply_x_s.create_placeholder = create_s_placeholder
  return apply_x_s

def seq_max(x, broadcast=True, name=''):
  """
  Get the max value in the sequence values

  Args:
    x: input sequence
    broadcast: if broadcast is True, the max value will be broadcast along with the input sequence,
    else only a single value will be returned
    name: the name of the operator (currently not applied to any node)
  """
  # Backward running maximum: each step keeps the larger of its own value and
  # the running maximum of the following step, so the first step ends up
  # holding the maximum of the sequence.
  # NOTE(review): at the last step future_value falls back to its initial state,
  # presumably 0, which would make the result max(sequence, 0) - confirm.
  m = placeholder_variable(shape=(1,), dynamic_axes = x.dynamic_axes, name='max')
  o = element_select(greater(x, future_value(m)), x, future_value(m))
  # close the recurrence by feeding the output back into the placeholder
  rlt = o.replace_placeholders({m:sanitize_input(o)})
  if broadcast:
    # Forward pass: take the value at the first step and carry it to every later step
    pv = placeholder_variable(shape=(1,), dynamic_axes = x.dynamic_axes, name='max_seq')
    max_seq = element_select(sequence.is_first(x), sanitize_input(rlt), past_value(pv))
    max_out = max_seq.replace_placeholders({pv:sanitize_input(max_seq)})
  else:
    max_out = sequence.first(rlt)
  return sanitize_input(max_out)

def seq_softmax(x, name = ''):
  """
  Compute a softmax over the values of a sequence.

  The sequence maximum is subtracted from every value and the difference is
  multiplied by 10 before exponentiation; each exponential is then divided by
  the sum of the exponentials over the sequence.
  """
  shifted = x-seq_max(x)
  exponentials = exp(shifted*10)
  total = sequence.broadcast_as(sequence.reduce_sum(exponentials), x)
  return element_divide(exponentials, total, name = name)

def cosine_similarity(src, tgt, name=''):
  """
  Compute the cosine similarity between a source and every step of a target sequence.

  src is a sequence of length 1
  tgt is a sequence of length >= 1
  """
  # repeat the single source step along the target sequence before comparing
  broadcast_src = sequence.broadcast_as(src, tgt, name='src_broadcast')
  return cosine_distance(broadcast_src, tgt, name)

def project_cosine_sim(att_dim, init = glorot_uniform(), name=''):
  """
  Compute the projected cosine similarity of two input sequences.

  Both inputs are first projected to a new space of dimension att_dim, the
  status via Wi and the memory via Wm. The cosine similarities between the
  projected status and every projected memory step are then normalized with
  seq_softmax.

  Args:
    att_dim: the dimension of the projection space
    init: the initializer of the projection weights Wi and Wm
    name: the name of the output node; the similarity node is named name + '_sim'

  Returns:
    A function of the two placeholders 'status' and 'memory'
  """
  # the input dimensions of both projections are left to be inferred
  Wi = Parameter(_INFERRED + tuple((att_dim,)), init = init, name='Wi')
  Wm = Parameter(_INFERRED + tuple((att_dim,)), init = init, name='Wm')
  status = placeholder_variable(name='status')
  memory = placeholder_variable(name='memory')
  projected_status = times(status, Wi, name = 'projected_status')
  projected_memory = times(memory, Wm, name = 'projected_memory')
  sim = cosine_similarity(projected_status, projected_memory, name= name+ '_sim')
  return seq_softmax(sim, name = name)

def termination_gate(init = glorot_uniform(), name=''):
  """
  Build the termination gate: a sigmoid over a single-output linear projection
  of the 'status' placeholder.
  """
  gate_weight = Parameter( _INFERRED + tuple((1,)), init = init, name='Wt')
  status_input = placeholder_variable(name='status')
  projected = times(status_input, gate_weight)
  return sigmoid(projected, name=name)


### Create Model

In [None]:
class model_params:
  """
  Plain container for the hyper-parameters used to create the ReasoNet model.

  All constructor arguments are stored under the same attribute names;
  `attention_dim` is fixed to 384.
  """
  def __init__(self, vocab_dim, entity_dim, hidden_dim, embedding_dim=100, embedding_init=None, share_rnn_param=False, max_rl_steps=5, dropout_rate=None, init=glorot_uniform(), model_name='rsn'):
    # dimensions
    self.vocab_dim = vocab_dim
    self.entity_dim = entity_dim
    self.hidden_dim = hidden_dim
    self.embedding_dim = embedding_dim
    self.attention_dim = 384
    # initialization
    self.embedding_init = embedding_init
    self.init = init
    # model and training options
    self.share_rnn_param = share_rnn_param
    self.max_rl_steps = max_rl_steps
    self.dropout_rate = dropout_rate
    self.model_name = model_name

def bind_data(func, data):
  """
  Map the arguments of a cntk function to the reader streams, based on the argument names.

  Arguments whose name is not one of the known inputs are left out of the result.
  """
  # argument name -> attribute of data.streams that feeds it
  stream_of_argument = {
    'query': 'query',
    'context': 'context',
    'entity_ids_mask': 'entities',
    'labels': 'label',
    'entity_ids': 'entity_ids',
  }
  bind = {}
  for arg in func.arguments:
    stream_name = stream_of_argument.get(arg.name)
    if stream_name is not None:
      bind[arg] = getattr(data.streams, stream_name)
  return bind

def attention_model(context_memory, query_memory, init_status, hidden_dim, att_dim, max_steps = 5, init = glorot_uniform()):
  """
  Create the attention model for reasonet

  For every step the status attends over the context and query memories, is
  updated by a GRU cell on the spliced attention vectors, and produces an
  answer attention over the context together with the probability of stopping
  at exactly this step.

  Args:
    context_memory: Context memory
    query_memory: Query memory
    init_status: Initial status
    hidden_dim: The dimension of hidden state
    att_dim: The dimension of attention
    max_steps: Maximum number of steps to revisit the context memory
    init: Currently unused

  Returns:
    A combined function with 2*max_steps+1 outputs: for each step its answer
    attention followed by its stop probability, and last the sum over all
    steps of answer attention times stop probability.
  """
  gru = gru_cell((hidden_dim*2, ), name='control_status')
  status = init_status
  # output[2*step] is the answer attention, output[2*step+1] the stop probability
  output = [None]*max_steps*2
  sum_prob = None  # unused
  # the attention functions and the gate are created once and applied at every step
  context_cos_sim = project_cosine_sim(att_dim, name='context_attention')
  query_cos_sim = project_cosine_sim(att_dim, name='query_attention')
  ans_cos_sim = project_cosine_sim(att_dim, name='candidate_attention')
  stop_gate = termination_gate(name='terminate_prob')
  # log-probability of not having stopped at any earlier step
  prev_stop = 0
  for step in range(max_steps):
    context_attention_weight = context_cos_sim(status, context_memory)
    query_attention_weight = query_cos_sim(status, query_memory)
    context_attention = sequence.reduce_sum(times(context_attention_weight, context_memory), name='C-Att')
    query_attention = sequence.reduce_sum(times(query_attention_weight, query_memory), name='Q-Att')
    attention = ops.splice(query_attention, context_attention, name='att-sp')
    status = gru(attention, status).output
    termination_prob = stop_gate(status)
    ans_attention = ans_cos_sim(status, context_memory)
    output[step*2] = ans_attention
    if step < max_steps -1:
      stop_prob = prev_stop + ops.log(termination_prob, name='log_stop')
    else:
      # the last step takes all the remaining probability mass
      stop_prob = prev_stop
    output[step*2+1] = sequence.broadcast_as(ops.exp(stop_prob, name='exp_log_stop'), output[step*2], name='Stop_{0}'.format(step))
    prev_stop += ops.log(1-termination_prob, name='log_non_stop')

  # final answer: the answer attentions of all steps weighted by their stop probabilities
  final_ans = None
  for step in range(max_steps):
    if final_ans is None:
      final_ans = output[step*2] * output[step*2+1]
    else:
      final_ans += output[step*2] * output[step*2+1]
  combine_func = combine(output + [ final_ans ], name='Attention_func')
  return combine_func

def create_model(params : model_params):
  """
  Create ReasoNet model

  Declares the 'query', 'context' and 'entity_ids_mask' inputs, embeds query
  and context, encodes both with GRU stacks into memories and passes them to
  attention_model together with an initial status taken from the query memory.

  Args:
    params (class:`model_params`): The parameters used to create the model

  Returns:
    The combined function returned by attention_model
  """
  logger.log("Create model: dropout_rate: {0}, init:{1}, embedding_init: {2}".format(params.dropout_rate, params.init, params.embedding_init))
  # Query and Doc/Context/Paragraph inputs to the model
  batch_axis = Axis.default_batch_axis()
  query_seq_axis = Axis('sourceAxis')
  context_seq_axis = Axis('contextAxis')
  query_dynamic_axes = [batch_axis, query_seq_axis]
  query_sequence = input_variable(shape=(params.vocab_dim), is_sparse=True, dynamic_axes=query_dynamic_axes, name='query')
  context_dynamic_axes = [batch_axis, context_seq_axis]
  context_sequence = input_variable(shape=(params.vocab_dim), is_sparse=True, dynamic_axes=context_dynamic_axes, name='context')
  entity_ids_mask = input_variable(shape=(1,), is_sparse=False, dynamic_axes=context_dynamic_axes, name='entity_ids_mask')
  # embedding: use the given initial values, or a random matrix when none are given
  if params.embedding_init is None:
    embedding_init = create_random_matrix(params.vocab_dim, params.embedding_dim)
  else:
    embedding_init = params.embedding_init
  # the same initial values back a parameter (embedding) and a constant (embedding_matrix)
  embedding = parameter(shape=(params.vocab_dim, params.embedding_dim), init=None)
  embedding.value = embedding_init
  embedding_matrix = constant(embedding_init, shape=(params.vocab_dim, params.embedding_dim))

  if params.dropout_rate is not None:
    query_embedding  = ops.dropout(times(query_sequence , embedding), params.dropout_rate, name='query_embedding')
    context_embedding = ops.dropout(times(context_sequence, embedding), params.dropout_rate, name='context_embedding')
  else:
    query_embedding  = times(query_sequence , embedding, name='query_embedding')
    context_embedding = times(context_sequence, embedding, name='context_embedding')

  # weights of the context and query GRU stacks; both parameters carry the name 'gru_params'
  contextGruW = Parameter(_INFERRED +  _as_tuple(params.hidden_dim), init=glorot_uniform(), name='gru_params')
  queryGruW = Parameter(_INFERRED +  _as_tuple(params.hidden_dim), init=glorot_uniform(), name='gru_params')

  entity_embedding = ops.times(context_sequence, embedding_matrix, name='constant_entity_embedding')
  # Unlike other words in the context, we keep the entity vectors fixed as a random vector so that each vector just means an identifier of different entities in the context and it has no semantic meaning
  full_context_embedding = ops.element_select(entity_ids_mask, entity_embedding, context_embedding)
  # NOTE(review): the positional arguments 1 and True are presumably num_layers and bidirectional - confirm against the CNTK API
  context_memory = ops.optimized_rnnstack(full_context_embedding, contextGruW, params.hidden_dim, 1, True, recurrent_op='gru', name='context_mem')

  query_memory = ops.optimized_rnnstack(query_embedding, queryGruW, params.hidden_dim, 1, True, recurrent_op='gru', name='query_mem')
  # first hidden_dim values of the last query step and second hidden_dim values of the first query step
  qfwd = ops.slice(sequence.last(query_memory), -1, 0, params.hidden_dim, name='fwd')
  qbwd = ops.slice(sequence.first(query_memory), -1, params.hidden_dim, params.hidden_dim*2, name='bwd')
  init_status = ops.splice(qfwd, qbwd, name='Init_Status') # get last fwd status and first bwd status
  return attention_model(context_memory, query_memory, init_status, params.hidden_dim, params.attention_dim, max_steps = params.max_rl_steps)



### Loss function
#### Contractive Reward

The ReasoNet paper gives the formula of the reward as
\begin{align}
J(\theta) = \mathbf{E}_{\pi\left(t_{1:T},a_T;\theta\right)}\left[\sum_{t=1}^Tr_t\right]
\end{align}

And it applies REINFORCE algorithm to estimate 
\begin{align} 
\nabla_{\theta}J(\theta) = \mathbf{E}_{\pi\left(t_{1:T},a_T;\theta\right)}\left[\nabla_{\theta}\log\pi\left(t_{1:T},a_T;\theta\right)r_T\right]=\sum_{\left(t_{1:T},a_T\right)\in\mathbb{A}^+}\pi\left(t_{1:T},a_T;\theta\right)\left[\nabla_{\theta}\log\pi\left(t_{1:T},a_T;\theta\right)\left(r_T-b_T\right)\right]
\end{align}

However, as the baselines $\left\{b_T;T=1...T_{max}\right\}$ are global variables independent of instances, they lead to slow convergence in training ReasoNet. Instead, the paper rewrites the formula as
$$
\nabla_{\theta}J(\theta) =\sum_{\left(t_{1:T},a_T\right)\in\mathbb{A}^+}\pi\left(t_{1:T},a_T;\theta\right)\left[\nabla_{\theta}\log\pi\left(t_{1:T},a_T;\theta\right)\left(r_T-b\right)\right]
$$
where $b=\sum_{\left(t_{1:T},a_T\right)\in\mathbb{A}^+}\pi\left(t_{1:T},a_T;\theta\right)r_T$ is the average reward on the $\left|\mathbb{A}^+\right|$ episodes.

Since the probability-weighted sum of the baseline-subtracted rewards over the $\left|\mathbb{A}^+\right|$ episodes is zero, $\sum_{\left(t_{1:T},a_T\right)\in\mathbb{A}^+}\pi\left(t_{1:T},a_T;\theta\right)\left(r_T-b\right)=0$, they call it Contractive Reward. Furthermore, they found that using $\left(\frac{r_T}{b}-1\right)$ in place of $\left(r_T-b\right)$ leads to better convergence.

In our implementation, we take the reward in the form,
$$
J(\theta)=\sum_{\left(t_{1:T},a_T\right)\in\mathbb{A}^+}\pi\left(t_{1:T},a_T;\theta\right)\left(\frac{r_T}{b}-1\right) + b
$$
As we only compute the gradient on $\pi\left(t_{1:T},a_T;\theta\right)$ and treat the other components in the formula as constants, the derivative is the same as in the paper, while the output is the average reward over the $\left|\mathbb{A}^+\right|$ episodes.
In CNTK, we apply the stop_gradient operator to the output of a function to convert it to a constant in the math formula.

In [None]:
def contractive_reward(labels, predictions_and_stop_probabilities):
  """
  Compute the contractive reward loss in paper 'ReasoNet: Learning to Stop Reading in Machine Comprehension'

  The per-step predictions are weighted by their stop probabilities and summed.
  The average reward is the sequence sum of that total on the labelled
  positions; it and the normalized reward (1 - label/average) are wrapped in
  stop_gradient, so only the weighted predictions receive gradients.

  Args:
    labels: The labels
    predictions_and_stop_probabilities: A list of tuples, each tuple contains the prediction and stop probability of the corresponding step.
  """
  weighted_predictions = None
  for step_outputs in predictions_and_stop_probabilities:
    step_term = ops.element_times(step_outputs[0], step_outputs[1])
    if weighted_predictions is None:
      weighted_predictions = step_term
    else:
      weighted_predictions = ops.plus(step_term, weighted_predictions)
  avg_rewards = ops.stop_gradient(sequence.reduce_sum(weighted_predictions*labels))
  base_reward = sequence.broadcast_as(avg_rewards, weighted_predictions, name = 'base_line')
  # The learner minimizes its objective while the reward should be maximized,
  # so the sign is flipped: (1-r/b) is used instead of (r/b-1)
  step_cr = ops.stop_gradient(1- ops.element_divide(labels, base_reward))
  normalized_contractive_rewards = ops.element_times(weighted_predictions, step_cr)
  return sequence.reduce_sum(normalized_contractive_rewards) + avg_rewards

#### Loss and accuracy

In [None]:
def accuracy_func(prediction, label, name='accuracy'):
  """
  Compute the accuracy of the prediction.

  The hardmax of the prediction is multiplied (times_transpose) with the
  indicator of where the label equals 1. The output node is always named
  'accuracy'; the `name` argument is not used.
  """
  predicted_one_hot = ops.hardmax(prediction, name='pred_max')
  label_indicator = ops.equal(label, [1], name='norm_label')
  return ops.times_transpose(predicted_one_hot, label_indicator, name='accuracy')

def loss(model, params:model_params):
  """
  Compute the loss and accuracy of the model output

  Args:
    model: The created model. Its arguments must include 'context' and
      'entity_ids_mask'. Its outputs are read as consecutive
      (prediction, stop probability) pairs followed by one final answer output.
    params: Model parameters; only entity_dim is read here.
  Returns:
    A combined function named 'Loss' whose outputs are, in order, the loss
    value, the gathered answers, the gathered labels and the accuracy.
  """
  model_args = {arg.name:arg for arg in model.arguments}
  context = model_args['context']
  entity_ids_mask = model_args['entity_ids_mask']
  # Select the positions whose mask value is greater than 0
  entity_condition = greater(entity_ids_mask, 0, name='condidion')
  entities_all = sequence.gather(entity_condition, entity_condition, name='entities_all')
  entity_ids = input_variable(shape=(params.entity_dim), is_sparse=True, dynamic_axes=entities_all.dynamic_axes, name='entity_ids')
  labels_raw = input_variable(shape=(1,), is_sparse=False, dynamic_axes=context.dynamic_axes, name='labels')
  # Restrict the final answer and the labels to the selected positions
  answers = sequence.scatter(sequence.gather(model.outputs[-1], entity_condition), entities_all, name='Final_Ans')
  labels = sequence.scatter(sequence.gather(labels_raw, entity_condition), entities_all, name='EntityLabels')
  entity_id_matrix = ops.reshape(entity_ids, params.entity_dim)
  expand_pred = sequence.reduce_sum(element_times(answers, entity_id_matrix))
  expand_label = ops.greater_equal(sequence.reduce_sum(element_times(labels, entity_id_matrix)), 1)
  # All outputs except the last one come in (prediction, stop probability) pairs
  step_count = (len(model.outputs) - 1) // 2
  predictions_and_stop_probabilities = [
    (model.outputs[step*2], model.outputs[step*2+1]) for step in range(step_count)
  ]
  loss_value = contractive_reward(labels_raw, predictions_and_stop_probabilities)
  accuracy = accuracy_func(expand_pred, expand_label, name='accuracy')
  apply_loss = combine([loss_value, answers, labels, accuracy], name='Loss')
  return apply_loss


### Create Learner

In [None]:
def create_adam_learner(learn_params, learning_rate = 0.0005, gradient_clipping_threshold_per_sample=0.001):
  """
  Create adam learner

  Args:
    learn_params: The parameters to be updated by the learner
    learning_rate: The per-sample learning rate
    gradient_clipping_threshold_per_sample: The per-sample gradient clipping threshold
  Returns:
    The learner created by learner.adam_sgd
  """
  rate_schedule = learner.learning_rate_schedule(learning_rate, learner.UnitType.sample)
  momentum_sched = learner.momentum_schedule(0.90)
  momentum_var_sched = learner.momentum_schedule(0.999)
  # Positional arguments are forwarded in the same order as before.
  adam = learner.adam_sgd(learn_params, rate_schedule, momentum_sched, True, momentum_var_sched,
          low_memory = False,
          gradient_clipping_threshold_per_sample = gradient_clipping_threshold_per_sample,
          gradient_clipping_with_truncation = True)
  summary = 'Alg: Adam, learning rage: {0}, momentum: {1}, gradient clip: {2}'.format(
    learning_rate, momentum_sched[0], gradient_clipping_threshold_per_sample)
  logger.log("Create learner. {0}".format(summary))
  return adam


### Trainer

In [None]:
def __evaluation(trainer, data, bind, minibatch_size, epoch_size):
  """
  Evaluate the accuracy on the evaluation data set during the training stage

  Args:
    trainer: The trainer whose test_minibatch is used to score each minibatch
    data: The reader providing the evaluation minibatches
    bind: The input map; one of its keys must be named 'labels'
    minibatch_size: Upper bound of the size requested for each minibatch
    epoch_size: Number of label samples to evaluate; None is treated as 1
  Returns:
    The accuracy averaged over all evaluated sequences
  """
  if epoch_size is None:
    epoch_size = 1
  # Locate the labels input; its minibatch data provides the sample/sequence counts
  for candidate in bind.keys():
    if candidate.name == 'labels':
      label_arg = candidate
      break
  weighted_acc = 0
  sequences_seen = 0
  samples_seen = 0
  print("Start evaluation with {0} samples ...".format(epoch_size))
  while samples_seen < epoch_size:
    request = min(epoch_size - samples_seen, minibatch_size)
    batch = data.next_minibatch(request, input_map=bind)
    samples_seen += batch[label_arg].num_samples
    batch_sequences = batch[label_arg].num_sequences
    # test_minibatch yields an average, so weight it by the number of sequences
    weighted_acc += batch_sequences*trainer.test_minibatch(batch)
    sequences_seen += batch_sequences
    sys.stdout.write('.')
    sys.stdout.flush()
  mean_acc = weighted_acc / sequences_seen
  print("")
  logger.log("Evaluation Acc: {0}, samples: {1}".format(mean_acc, sequences_seen))
  return mean_acc

def train(model, m_params:model_params, learner, train_data, max_epochs=1, save_model_flag=False, epoch_size=270000, eval_data=None, eval_size=None, check_point_freq=0.1, minibatch_size=50000, model_name='rsn'):
  """
  Train the model
  Args:
    model: The created model
    m_params: Model parameters
    learner: The learner used to train the model
    train_data: The reader providing the training minibatches
    max_epochs: Number of epochs to train
    save_model_flag: When set, the model is saved under the 'model' folder at every check point and once after training
    epoch_size: Number of label samples consumed per epoch
    eval_data: Optional reader providing the evaluation minibatches
    eval_size: Number of samples used for each evaluation
    check_point_freq: Fraction of epoch_size between two check points; no intermediate check point is taken when the resulting interval is not positive
    minibatch_size: Upper bound of the size requested for each minibatch
    model_name: Name used in the file names of the saved models
  Returns:
    A tuple of (average training loss of the last epoch, average training accuracy of the last epoch,
    evaluation accuracy after training, which is 0 when no eval_data is given)
  """
  criterion_loss = loss(model, m_params)
  loss_func = criterion_loss.outputs[0]
  eval_func = criterion_loss.outputs[-1]
  trainer = Trainer(model.outputs[-1], (loss_func, eval_func), learner)
  # Get minibatches of sequences to train with and perform model training
  # bind inputs to data from readers
  train_bind = bind_data(criterion_loss, train_data)
  for k in train_bind.keys():
    if k.name == 'labels':
      label_key = k
      break
  # eval_data is optional, so only bind it when it is provided
  eval_bind = bind_data(criterion_loss, eval_data) if eval_data else None

  training_progress_output_freq = 500
  check_point_interval = int(epoch_size*check_point_freq)
  # NOTE(review): check_point_id is not reset between epochs although i is, so check point ids
  # (and the saved file names) repeat across epochs -- confirm this is intended.
  check_point_id = 0
  for epoch in range(max_epochs):
    epoch_loss = 0
    epoch_acc = 0
    epoch_samples = 0
    i = 0
    win_loss = 0
    win_acc = 0
    win_samples = 0
    chk_loss = 0
    chk_acc = 0
    chk_samples = 0
    while i < epoch_size:
      # get next minibatch of training data, never requesting more than what is left in the epoch
      mbs = min(minibatch_size, epoch_size - i)
      mb_train = train_data.next_minibatch(mbs, input_map=train_bind)
      i += mb_train[label_key].num_samples
      trainer.train_minibatch(mb_train)
      sys.stdout.write('.')
      sys.stdout.flush()
      # collect epoch-wide stats
      samples = trainer.previous_minibatch_sample_count
      ls = trainer.previous_minibatch_loss_average * samples
      acc = trainer.previous_minibatch_evaluation_average * samples
      epoch_loss += ls
      epoch_acc += acc
      win_loss += ls
      win_acc += acc
      chk_loss += ls
      chk_acc += acc
      epoch_samples += samples
      win_samples += samples
      chk_samples += samples
      # report progress each time the sample count crosses a multiple of training_progress_output_freq
      if int(epoch_samples/training_progress_output_freq) != int((epoch_samples-samples)/training_progress_output_freq):
        print('')
        logger.log("Lastest sample count = {}, Train Loss: {}, Evalualtion ACC: {}".format(win_samples, win_loss/win_samples,
          win_acc/win_samples))
        logger.log("Total sample count = {}, Train Loss: {}, Evalualtion ACC: {}".format(chk_samples, chk_loss/chk_samples,
          chk_acc/chk_samples))
        win_samples = 0
        win_loss = 0
        win_acc = 0
      # a non-positive interval would divide by zero, so it disables the intermediate check points
      new_chk_id = int(i/check_point_interval) if check_point_interval > 0 else check_point_id
      if new_chk_id != check_point_id and i < epoch_size :
        check_point_id = new_chk_id
        print('')
        logger.log("--- CHECKPOINT %d: samples=%d, loss = %.2f, acc = %.2f%% ---" % (check_point_id, chk_samples, chk_loss/chk_samples, 100.0*(chk_acc/chk_samples)))
        if eval_data:
          __evaluation(trainer, eval_data, eval_bind, minibatch_size, eval_size)
        if save_model_flag:
          # save the model at every check point
          model_filename = os.path.join('model', "model_%s_%03d.dnn" % (model_name, check_point_id))
          model.save_model(model_filename)
          logger.log("Saved model to '%s'" % model_filename)
        chk_samples = 0
        chk_loss = 0
        chk_acc = 0

    print('')
    logger.log("--- EPOCH %d: samples=%d, loss = %.2f, acc = %.2f%% ---" % (epoch, epoch_samples, epoch_loss/epoch_samples, 100.0*(epoch_acc/epoch_samples)))
  eval_acc = 0
  if eval_data:
    eval_acc = __evaluation(trainer, eval_data, eval_bind, minibatch_size, eval_size)
  if save_model_flag:
    # save the final model once training is done
    model_filename = os.path.join('model', "model_%s_final.dnn" % (model_name))
    model.save_model(model_filename)
    logger.log("Saved model to '%s'" % model_filename)
  return (epoch_loss/epoch_samples, epoch_acc/epoch_samples, eval_acc)


### Train the model

In [None]:
import sys
import os
import cntk.device as device
import numpy as np
from cntk.ops.tests.ops_test_utils import cntk_device
from cntk.ops import input_variable, past_value, future_value
from cntk.io import MinibatchSource
from cntk import Trainer, Axis, device, combine
from cntk.layers import Recurrence, Convolution
import cntk.ops as ops
import cntk
import math

def test_reasonet():
  """
  Train ReasoNet on the training CTF file and evaluate it on the validation CTF file

  Builds the model parameters, the readers, the model and the adam learner with
  the hard-coded settings below, then runs train() without saving the model.
  """
  train_path = train_ctf
  validation_path = validation_ctf
  # Dimensions and sizes used for this run
  vocab_dim, entity_dim = 101585, 586
  epoch_size, eval_size = 289716292, 2993016
  params = model_params(
    vocab_dim = vocab_dim,
    entity_dim = entity_dim,
    hidden_dim = 384,
    embedding_dim = 100,
    embedding_init = None,
    dropout_rate = 0.2)

  train_data = create_reader(train_path, vocab_dim, entity_dim, True, rand_size=epoch_size)
  # The evaluation reader is only created when a validation file is given
  if validation_path is not None:
    eval_data = create_reader(validation_path, vocab_dim, entity_dim, False, rand_size=eval_size)
  else:
    eval_data = None

  model = create_model(params)
  adam = create_adam_learner(model.parameters)
  (train_loss, train_acc, eval_acc) = train(
    model, params, adam, train_data,
    max_epochs=5,
    epoch_size=epoch_size,
    save_model_flag=False,
    model_name=os.path.basename(train_path),
    eval_data=eval_data,
    eval_size=eval_size,
    check_point_freq=0.1,
    minibatch_size = 50000)

# Run the training and evaluation defined by test_reasonet.
test_reasonet()