# Implementing a new model with Jack 

In this tutorial, we focus on the minimal steps required to implement a new model from scratch using Jack.

We will implement a simple Bi-LSTM baseline for extractive question answering.
The architecture is as follows:
- Words of question and support are embedded using random embeddings (not trained)
- Both question and support are encoded using a bi-directional LSTM
- The question is summarized by an attention-weighted average of its token representations
- A feedforward NN scores each of the support tokens to be the start of the answer
- A feedforward NN scores each of the support tokens to be the end of the answer

In order to implement a Jack reader, we define three modules:
- **Input Module**: Maps `QASetting`s to numpy arrays associated with `TensorPort`s
- **Model Module**: Defines the TensorFlow graph
- **Output Module**: Converts the network output into the output of the system. In our case, this involves extracting the answer string from the context. We will use the existing `XQAOutputModule`.

In [1]:
# First change dir to jack parent
import os
os.chdir('..')

import re
from typing import Sequence

import numpy as np
import tensorflow as tf

from jack.core import *
from jack.io.embeddings import Embeddings
from jack.util.vocab import *
from jack.readers.extractive_qa.shared import XQAPorts, XQAOutputModule
from jack.readers.extractive_qa.util import prepare_data, stack_and_pad
from jack.readers.extractive_qa.util import tokenize
from jack.tf_fun.rnn import birnn_with_projection
from jack.util import tfutil
from jack.util.map import numpify

# Raw string, so that `\w` and `\s` are not treated as string escape sequences.
_tokenize_pattern = re.compile(r'\w+|[^\w\s]')

## Ports

All communication between input, model and output modules happens via `TensorPort`s (see `jack/core/tensorport.py`).
Normally, you should try to reuse ports wherever possible to be able to reuse modules as well.
If you need a new port, however, it is straightforward to define one.
For this tutorial, we will define most ports here.

In [2]:
embedded_question = TensorPort(tf.float32, [None, None, None], "embedded_question_flat",
                               "Represents the embedded question",
                               "[Q, max_num_question_tokens, N]")
# or reuse FlatPorts.Misc.embedded_question

question_length = TensorPort(tf.int32, [None], "question_length_flat",
                             "Represents length of questions in batch",
                             "[Q]")
# or reuse FlatPorts.Input.question_length

embedded_support = TensorPort(tf.float32, [None, None, None], "embedded_support_flat",
                              "Represents the embedded support",
                              "[S, max_num_tokens, N]")
# or reuse FlatPorts.Misc.embedded_support

support_length = TensorPort(tf.int32, [None], "support_length_flat",
                            "Represents length of support in batch",
                            "[S]")
# or reuse FlatPorts.Input.support_length

answer_span = TensorPort(tf.int32, [None, 2], "answer_span_target_flat",
                         "Represents answer as a (start, end) span", "[A, 2]")
# or reuse FlatPorts.Prediction.answer_span
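
In the port definitions above, a `None` entry in a shape stands for a dimension whose size is only known at run time (batch size, sequence length, and so on). The following sketch illustrates that convention in plain Python. `ToyPort` and `shape_matches` are hypothetical names invented for this illustration and are not part of Jack.

```python
from typing import NamedTuple, Optional, Sequence, Tuple


class ToyPort(NamedTuple):
    """Hypothetical stand-in for a TensorPort: a named slot with a shape."""
    name: str
    shape: Tuple[Optional[int], ...]


def shape_matches(port: ToyPort, actual: Sequence[int]) -> bool:
    """True if `actual` has the port's rank and agrees on every fixed dimension."""
    if len(actual) != len(port.shape):
        return False
    # `None` acts as a wildcard; any other entry must match exactly.
    return all(expected is None or expected == got
               for expected, got in zip(port.shape, actual))


toy_answer_span = ToyPort("answer_span_target_flat", (None, 2))
toy_embedded_question = ToyPort("embedded_question_flat", (None, None, None))

print(shape_matches(toy_answer_span, (8, 2)))             # True
print(shape_matches(toy_answer_span, (8, 3)))             # False
print(shape_matches(toy_embedded_question, (8, 12, 10)))  # True
```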

In order to reuse the `XQAOutputModule`, we'll use existing ports defined in `XQAPorts` for the `token_char_offsets` and the predictions.
We'll also use the `Ports.loss` port, because the Jack training code expects this port as an output of the model module.

In [3]:
print(XQAPorts.token_char_offsets.get_description())
print(XQAPorts.start_scores.get_description())
print(XQAPorts.end_scores.get_description())
print(XQAPorts.span_prediction.get_description())
print(Ports.loss.get_description())

Tensorport 'token_char_offsets'
  dtype: <dtype: 'int32'>
  shape: [None, None]
  doc_string: Character offsets of tokens in support.
  shape_string: [S, support_length]
Tensorport 'start_scores_flat'
  dtype: <dtype: 'float32'>
  shape: [None, None]
  doc_string: Represents start scores for each support sequence
  shape_string: [S, max_num_tokens]
Tensorport 'end_scores_flat'
  dtype: <dtype: 'float32'>
  shape: [None, None]
  doc_string: Represents end scores for each support sequence
  shape_string: [S, max_num_tokens]
Tensorport 'answer_span_prediction_flat'
  dtype: <dtype: 'int32'>
  shape: [None, 2]
  doc_string: Represents answer as a (start, end) span
  shape_string: [A, 2]
Tensorport 'loss'
  dtype: <dtype: 'float32'>
  shape: [None]
  doc_string: Represents loss on each instance in the batch
  shape_string: [batch_size]


## Input Module

The input module is responsible for converting `QASetting` instances to numpy
arrays, which are mapped to `TensorPort`s. Essentially, we are building a
feed dict used for training and inference. Note that several readers ship input
modules that can be reused when your model requires the same pre-processing and
input as another model. The same holds for output modules. If you can reuse
both, it is enough to implement your own model module (see below). See
`jack/readers/implementations.py` for how different readers reuse the same modules.

You could implement the `InputModule` interface, but in many cases it'll be
easier to inherit from `OnlineInputModule`. Doing this, we need to:
- Define the output `TensorPort`s of our input module
- Implement the preprocessing (e.g. tokenization, mapping to embedding vectors, ...). The result of this step is one *annotation* per instance, e.g. a `dict`.
- Implement batching. Given a list of annotations, you need to define how to build the feed dict.

In [4]:
class MyInputModule(OnlineInputModule):

    def __init__(self, shared_resources):
        """The module is initialized with a `shared_resources`.

        For the purpose of this tutorial, we will only use the `vocab` property
        which provides the embeddings. You could also pass arbitrary
        configuration parameters in the `shared_resources.config` dict.
        """
        self.vocab = shared_resources.vocab
        self.emb_matrix = self.vocab.emb.lookup

    # We will now define the output and training TensorPorts of our input module.

    @property
    def output_ports(self):
        return [embedded_question,           # Question embeddings
                question_length,             # Lengths of the questions
                embedded_support,            # Support embeddings
                support_length,              # Lengths of the supports
                XQAPorts.token_char_offsets  # Character offsets of tokens in support, used in the output module
               ]

    @property
    def training_ports(self):
        return [answer_span]                 # Answer span, one for each question

    # Now, we implement our preprocessing. This involves tokenization,
    # mapping to token IDs, mapping to token embeddings,
    # and computing the answer spans.

    def _get_emb(self, idx):
        """Maps a token ID to its respective embedding vector"""
        if idx < self.emb_matrix.shape[0]:
            return self.vocab.emb.lookup[idx]
        else:
            # <OOV>
            return np.zeros([self.vocab.emb_length])

    def preprocess(self, questions, answers=None, is_eval=False):
        """Maps a list of instances to a list of annotations.

        Since in our case, all instances can be preprocessed independently, we'll
        delegate the preprocessing to a `_preprocess_instance()` method.
        """

        if answers is None:
            answers = [None] * len(questions)

        return [self._preprocess_instance(q, a)
                for q, a in zip(questions, answers)]

    def _preprocess_instance(self, question, answers=None):
        """Maps an instance to an annotation.

        An annotation contains the embeddings and length of question and support,
        token offsets, and optionally answer spans.
        """

        has_answers = answers is not None

        # `prepare_data()` handles most of the computation in our case, but
        # you could implement your own preprocessing here.
        q_tokenized, q_ids, q_length, s_tokenized, s_ids, s_length, \
        word_in_question, token_offsets, answer_spans = \
            prepare_data(question, answers, self.vocab,
                         with_answers=has_answers,
                         max_support_length=100)

        # For both question and support, we'll fill an embedding tensor
        emb_support = np.zeros([s_length, self.emb_matrix.shape[1]])
        emb_question = np.zeros([q_length, self.emb_matrix.shape[1]])
        for k in range(len(s_ids)):
            emb_support[k] = self._get_emb(s_ids[k])
        for k in range(len(q_ids)):
            emb_question[k] = self._get_emb(q_ids[k])

        # Now, we build the annotation for the question instance. We'll use a
        # dict that maps from `TensorPort` to numpy array, but this could be
        # any data type, like a custom class, or a named tuple.

        annotation = {
            question_length: q_length,
            embedded_question: emb_question,
            support_length: s_length,
            embedded_support: emb_support,
            XQAPorts.token_char_offsets: token_offsets
        }

        if has_answers:
            # For the purpose of this tutorial, we'll only use the first answer, such
            # that we will have exactly as many answers as questions.
            annotation[answer_span] = list(answer_spans[0])

        return numpify(annotation, keys=[support_length, question_length,
                                         XQAPorts.token_char_offsets, answer_span])

    def create_batch(self, annotations, is_eval, with_answers):
        """Now, we need to implement the mapping of a list of annotations to a feed dict.
        
        Because our annotations already are dicts mapping TensorPorts to numpy
        arrays, we only need to do padding here.
        """

        return {key: stack_and_pad([a[key] for a in annotations])
                for key in annotations[0].keys()}
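
To make the batching step more concrete, here is a minimal NumPy sketch of padding and stacking arrays of different lengths. `toy_stack_and_pad` is a hypothetical helper written for this illustration. It reflects an assumption about what `stack_and_pad` does (pad every array to the largest shape in the batch, then stack) and is not its actual source.

```python
import numpy as np


def toy_stack_and_pad(arrays, pad_value=0):
    """Stack same-rank arrays into one batch array, padding with `pad_value`."""
    arrays = [np.asarray(a) for a in arrays]
    # Target shape: the element-wise maximum over all per-instance shapes.
    max_shape = tuple(max(dims) for dims in zip(*(a.shape for a in arrays)))
    batch = np.full((len(arrays),) + max_shape, pad_value, dtype=arrays[0].dtype)
    for i, a in enumerate(arrays):
        # Copy each array into the leading corner of its slot in the batch.
        region = (i,) + tuple(slice(0, n) for n in a.shape)
        batch[region] = a
    return batch


# Two "embedded questions" with 3 and 2 tokens, embedding size 2.
emb_a = np.ones((3, 2))
emb_b = np.ones((2, 2))
batch = toy_stack_and_pad([emb_a, emb_b])
lengths = toy_stack_and_pad([3, 2])  # scalars are simply stacked

print(batch.shape)  # (2, 3, 2)
print(batch[1, 2])  # [0. 0.]  (padding for the shorter question)
print(lengths)      # [3 2]
```

The length arrays are what later allow the model to mask out the padded positions.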

## Model Module

The model module defines the TensorFlow computation graph.
It takes input module outputs as inputs and produces outputs such as the loss
and outputs required by the output module.

In [5]:
class MyModelModule(TFModelModule):

    @property
    def input_ports(self) -> Sequence[TensorPort]:
        return [embedded_question, question_length,
                embedded_support, support_length]

    @property
    def output_ports(self) -> Sequence[TensorPort]:
        return [XQAPorts.start_scores, XQAPorts.end_scores,
                XQAPorts.span_prediction]

    @property
    def training_input_ports(self) -> Sequence[TensorPort]:
        return [XQAPorts.start_scores, XQAPorts.end_scores, answer_span]

    @property
    def training_output_ports(self) -> Sequence[TensorPort]:
        return [Ports.loss]

    def create_output(self, shared_resources, emb_question, question_length,
                      emb_support, support_length):
        """
        Implements the "core" model: The TensorFlow subgraph which computes the
        answer span from the embedded question and support.
        Args:
            emb_question: [Q, L_q, N]
            question_length: [Q]
            emb_support: [Q, L_s, N]
            support_length: [Q]

        Returns:
            start_scores [Q, L_s], end_scores [Q, L_s], span_prediction [Q, 2]
        """

        with tf.variable_scope("fast_qa", initializer=tf.contrib.layers.xavier_initializer()):
            dim = shared_resources.config['repr_dim']
            # set shapes for inputs
            emb_question.set_shape([None, None, dim])
            emb_support.set_shape([None, None, dim])

            # encode question and support
            rnn = tf.contrib.rnn.LSTMBlockFusedCell
            encoded_question = birnn_with_projection(dim, rnn, emb_question, question_length,
                                                     projection_scope="question_proj")

            encoded_support = birnn_with_projection(dim, rnn, emb_support, support_length,
                                                    share_rnn=True, projection_scope="support_proj")

            start_scores, end_scores, predicted_start_pointer, predicted_end_pointer = \
                self._output_layer(dim, encoded_question, question_length,
                                   encoded_support, support_length)

            span = tf.concat([predicted_start_pointer, predicted_end_pointer], 1)

            return start_scores, end_scores, span

    def _output_layer(self, dim, encoded_question, question_length, encoded_support, support_length):
        """Simple span prediction layer of our network"""
        batch_size = tf.shape(question_length)[0]

        # Computing weighted question state
        attention_scores = tf.contrib.layers.fully_connected(encoded_question, 1,
                                                             scope="question_attention")
        q_mask = tfutil.mask_for_lengths(question_length, batch_size)
        attention_scores = attention_scores + tf.expand_dims(q_mask, 2)
        question_attention_weights = tf.nn.softmax(attention_scores, 1, name="question_attention_weights")
        question_state = tf.reduce_sum(question_attention_weights * encoded_question, [1])

        # Prediction
        support_mask = tfutil.mask_for_lengths(support_length, batch_size)
        def predict():
            interaction = tf.expand_dims(question_state, 1) * encoded_support
            scores = tf.layers.dense(tf.concat([interaction, encoded_support], axis=2), 1)
            scores = tf.squeeze(scores, [2])
            scores = scores + support_mask
            _, predicted = tf.nn.top_k(scores, 1)
            return scores, predicted

        start_scores, predicted_start_pointer = predict()
        end_scores, predicted_end_pointer = predict()

        return start_scores, end_scores, predicted_start_pointer, predicted_end_pointer

    def create_training_output(self, shared_resources, start_scores, end_scores, answer_span) \
            -> Sequence[TensorPort]:
        """Compute loss from start & end scores and the gold-standard `answer_span`."""

        start, end = [tf.squeeze(t, 1) for t in tf.split(answer_span, 2, 1)]

        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=start_scores,
                                                              labels=start) + \
               tf.nn.sparse_softmax_cross_entropy_with_logits(logits=end_scores, labels=end)
        return [tf.reduce_mean(loss)]
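
The two numerical ideas in the model module, length masking before a softmax and the summed start/end cross-entropy, can be checked on small arrays with NumPy. This is a sketch and not the TensorFlow code above. The function names are made up for illustration, and it assumes that `tfutil.mask_for_lengths` produces a large negative value at padded positions, which is how the code above uses it.

```python
import numpy as np


def masked_softmax(scores, lengths):
    """Softmax over axis 1 that gives (near) zero weight to padded positions."""
    positions = np.arange(scores.shape[1])[None, :]       # [1, L]
    valid = positions < np.asarray(lengths)[:, None]      # [B, L]
    masked = np.where(valid, scores, -1e9)                # large negative at padding
    masked = masked - masked.max(axis=1, keepdims=True)   # numerical stability
    exp = np.exp(masked)
    return exp / exp.sum(axis=1, keepdims=True)


def attention_pool(encoded, scores, lengths):
    """Weighted sum of token states: [B, L, D] -> [B, D]."""
    weights = masked_softmax(scores, lengths)             # [B, L]
    return (weights[:, :, None] * encoded).sum(axis=1)


def span_loss(start_scores, end_scores, spans, lengths):
    """Mean over the batch of start plus end negative log-likelihood."""
    rows = np.arange(len(spans))
    p_start = masked_softmax(start_scores, lengths)
    p_end = masked_softmax(end_scores, lengths)
    nll = -np.log(p_start[rows, spans[:, 0]]) - np.log(p_end[rows, spans[:, 1]])
    return nll.mean()


# One sequence with 2 real tokens and 1 padded position.
encoded = np.array([[[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]]])
scores = np.zeros((1, 3))
lengths = np.array([2])

pooled = attention_pool(encoded, scores, lengths)
loss = span_loss(scores, scores, np.array([[0, 1]]), lengths)

print(pooled)  # [[0.5 0.5]]  (the padded token does not contribute)
print(loss)    # 2 * ln(2), about 1.386
```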

## Output Module

The output module converts our model predictions to `Answer` instances.
Since our model is a standard extractive QA model and since we used the standard
`TensorPort`s, we can reuse the existing `XQAOutputModule` rather than implementing
our own.
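
To illustrate what such an output module has to do, the sketch below maps a predicted token span back to an answer string using character offsets. It is a simplified stand-in and not the `XQAOutputModule` implementation. The helper names are hypothetical, and it assumes that the predicted span is inclusive of its end token.

```python
import re


def tokens_with_offsets(text):
    """Tokenize with a simple word/punctuation regex and record start offsets."""
    matches = list(re.finditer(r"\w+|[^\w\s]", text))
    return [m.group() for m in matches], [m.start() for m in matches]


def span_to_answer(text, tokens, offsets, token_span):
    """Map an inclusive (start_token, end_token) span to (answer_text, char_span)."""
    start_tok, end_tok = token_span
    char_start = offsets[start_tok]
    # The end offset is the start of the last token plus its length.
    char_end = offsets[end_tok] + len(tokens[end_tok])
    return text[char_start:char_end], (char_start, char_end)


support = "The capital of France is Paris, not Lyon."
tokens, offsets = tokens_with_offsets(support)

answer, char_span = span_to_answer(support, tokens, offsets, (5, 5))
print(answer, char_span)  # Paris (25, 30)

phrase, phrase_span = span_to_answer(support, tokens, offsets, (3, 5))
print(phrase, phrase_span)  # France is Paris (15, 30)
```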

## Training

As a toy example, we'll use a dataset of just one question:

In [6]:
data_set = [
    (QASetting(
        question="Which is it?",
        support=["While b seems plausible, answer a is correct."],
        id="1"),
     [Answer(text="a", span=(32, 33))])
]
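
The `span` of an `Answer` above is a pair of character offsets into the support string, with an exclusive end. A quick way to double-check such offsets when building a dataset by hand is sketched below. `find_answer_spans` is a hypothetical helper written for this illustration and is not part of Jack.

```python
import re


def find_answer_spans(support, answer_text):
    """Return (start, end) character spans of whole-word matches of `answer_text`."""
    pattern = r"\b" + re.escape(answer_text) + r"\b"
    return [m.span() for m in re.finditer(pattern, support)]


support = "While b seems plausible, answer a is correct."
spans = find_answer_spans(support, "a")

print(spans)           # [(32, 33)]
print(support[32:33])  # a
```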

The `build_vocab()` function builds a random embedding matrix. In a real
setting, you would load pre-trained embeddings such as GloVe here.

In [7]:
embedding_dim = 10

def build_vocab(questions):
    """Build a vocabulary of random vectors."""

    embedding_lookup = dict()
    for question in questions:
        for t in tokenize(question.question):
            if t not in embedding_lookup:
                embedding_lookup[t] = len(embedding_lookup)
    embeddings = Embeddings(embedding_lookup, 
                            np.random.random([len(embedding_lookup),
                                              embedding_dim]))

    vocab = Vocab(emb=embeddings, init_from_embeddings=True)
    return vocab

questions = [q for q, _ in data_set]
shared_resources = SharedResources(build_vocab(questions), config={'repr_dim': 10})
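
Independently of Jack's `Embeddings` and `Vocab` classes, the idea of a random embedding table with a zero vector for out-of-vocabulary tokens (as in `_get_emb()` of the input module) can be sketched in a few lines of NumPy. All names below are hypothetical and exist only for this illustration.

```python
import numpy as np


def build_toy_embeddings(tokens, dim, seed=0):
    """Assign each distinct token an ID and a random vector of size `dim`."""
    rng = np.random.default_rng(seed)
    token_to_id = {}
    for token in tokens:
        # Keep the first ID that was assigned to a token.
        token_to_id.setdefault(token, len(token_to_id))
    matrix = rng.random((len(token_to_id), dim))
    return token_to_id, matrix


def embed(token, token_to_id, matrix):
    """Look up a token's vector; unknown tokens map to the zero vector."""
    idx = token_to_id.get(token)
    if idx is None:
        return np.zeros(matrix.shape[1])
    return matrix[idx]


token_to_id, matrix = build_toy_embeddings(["Which", "is", "it", "?", "is"], dim=10)

print(matrix.shape)                                       # (4, 10)
print(embed("it", token_to_id, matrix).shape)             # (10,)
print(embed("plausible", token_to_id, matrix).any())      # False (all zeros)
```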

Now, we'll instantiate all modules with the `shared_resources` as parameter.
The `TFReader` needs the three modules and is ready to train!

In [8]:
input_module = MyInputModule(shared_resources)
model_module = MyModelModule(shared_resources)
output_module = XQAOutputModule(shared_resources)
reader = TFReader(shared_resources, input_module, model_module, output_module)

In [9]:
from jack.util.hooks import LossHook
batch_size = 1
reader.train(tf.train.AdamOptimizer(learning_rate=0.1), 
             data_set, batch_size, max_epochs=10,
             hooks=[LossHook(reader, iter_interval=1)])

print(questions[0].question, questions[0].support[0])
answers = reader(questions)
print("{}, {}, {}".format(answers[0].score, answers[0].span, answers[0].text))

INFO:jack.core.reader:Start training...
INFO:jack.util.hooks:Epoch 1	Iter 1	train loss 4.6132283210754395
INFO:jack.util.hooks:Epoch 2	Iter 2	train loss 4.096255302429199
INFO:jack.util.hooks:Epoch 3	Iter 3	train loss 1.8390401601791382
INFO:jack.util.hooks:Epoch 4	Iter 4	train loss 0.12551766633987427
INFO:jack.util.hooks:Epoch 5	Iter 5	train loss 0.0002952593786176294
INFO:jack.util.hooks:Epoch 6	Iter 6	train loss 2.3841856489070778e-07
INFO:jack.util.hooks:Epoch 7	Iter 7	train loss 0.0
INFO:jack.util.hooks:Epoch 8	Iter 8	train loss 2.3841855067985307e-07
INFO:jack.util.hooks:Epoch 9	Iter 9	train loss 7.271742560988059e-06
INFO:jack.util.hooks:Epoch 10	Iter 10	train loss 0.0013461976777762175
Which is it? While b seems plausible, answer a is correct.
0.9999980926513672, (32, 33), a
