# Bi-directional Attention Flow

This notebook shows how to use [BiDAF](https://arxiv.org/abs/1611.01603) for question answering, that is, extracting an answer from a context paragraph.
We will be using the [AllenNLP](https://github.com/allenai/allennlp) library, built on top of [PyTorch](https://pytorch.org/).


In [None]:
from allennlp import predictors
from allennlp.predictors import Predictor
from allennlp.models.archival import load_archive

%matplotlib inline

import matplotlib
import numpy as np
import matplotlib.pyplot as plt


class PretrainedModel:
    """
    A pretrained model is determined by both an archive file
    (representing the trained model)
    and a choice of predictor.
    """
    def __init__(self, archive_file: str, predictor_name: str) -> None:
        self.archive_file = archive_file
        self.predictor_name = predictor_name

    def predictor(self) -> Predictor:
        archive = load_archive(self.archive_file)
        return Predictor.from_archive(archive, self.predictor_name)

def bidirectional_attention_flow_seo_2017() -> predictors.BidafPredictor:
    model = PretrainedModel(
        'https://s3-us-west-2.amazonaws.com/allennlp/models/bidaf-model-2017.09.15-charpad.tar.gz',
        'machine-comprehension'
    )
    return model.predictor() # type: ignore



In [None]:
# Define some helper functions to handle the output and take a sneak peek into the model's attention mechanism:

def build_answer(ans):
    return ans['best_span_str']
#     return " ".join(ans['passage_tokens'][start:end])

def get_span_probs(ans):
    start, end = ans['best_span'][0], ans['best_span'][1]
    return ans['span_start_probs'][start], ans['span_end_probs'][end]

def plot_attention(ans):
    # Retrieve answer relevant information
    energies = np.array(ans['passage_question_attention'], dtype=np.float32)
    question_tokens = ans['question_tokens']
    passage_tokens = ans['passage_tokens']
    
    # Create the plot
    fig, ax = plt.subplots(figsize=(15, 100))
    m = ax.imshow(energies, plt.get_cmap("Blues"))
    
    # Use each token...
    ax.set_xticks(np.arange(len(question_tokens)))
    ax.set_yticks(np.arange(len(passage_tokens)))
    
    # ... to label the axis
    ax.xaxis.set_tick_params(labeltop=True)
    ax.set_xticklabels(question_tokens, rotation=75)
    ax.set_yticklabels(passage_tokens)
    
    # Create a colorbar
    cbar = plt.colorbar(m)
    cbar.ax.set_ylabel("attention weight", rotation=-90, va="bottom")
    
    plt.show()
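
The heat map drawn by `plot_attention` has one row per passage token and one column per question token. As a rough illustration of where such a matrix comes from, the sketch below computes a context-to-query attention matrix with NumPy. It is a simplification, not AllenNLP's implementation: the token vectors are random placeholders, and a plain dot product stands in for BiDAF's trainable similarity function. The names `row_softmax` and `context_to_query_attention` are hypothetical.

```python
import numpy as np


def row_softmax(scores):
    """Numerically stable softmax over the last axis."""
    shifted = scores - scores.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)


def context_to_query_attention(passage_vecs, question_vecs):
    """Return a (passage_len, question_len) attention matrix.

    Assumption: a dot product replaces BiDAF's learned similarity function.
    """
    similarity = passage_vecs @ question_vecs.T
    return row_softmax(similarity)


# Random placeholder embeddings, not real model states.
rng = np.random.default_rng(0)
passage_vecs = rng.normal(size=(6, 4))   # 6 passage tokens, 4 dimensions
question_vecs = rng.normal(size=(3, 4))  # 3 question tokens, 4 dimensions

attention = context_to_query_attention(passage_vecs, question_vecs)
print(attention.shape)
print(attention.sum(axis=1))
```

In this sketch every row sums to 1 (up to floating-point error), so a row can be read as a distribution over question tokens for one passage token. Whether `ans['passage_question_attention']` is normalized in exactly this way may depend on the AllenNLP version, so treat that reading of the plot as an assumption.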
    

In [None]:
# Define the passage to extract information from.
# An example passage could be the following:
"""
The history of the penny of Great Britain and the United Kingdom from 1714 to 1901, 
the period in which the House of Hanover reigned, saw its transformation from a small 
silver coin to a larger bronze piece. All bear the portrait of the monarch on the obverse; 
copper and bronze pennies have a depiction of Britannia on the reverse. During most of the 18th century, 
the penny was a small silver coin rarely seen in circulation. Beginning in 1787, 
the chronic shortage of good money resulted in the wide circulation of private tokens, 
including ones valued at one penny. In 1797 Matthew Boulton gained a government contract 
and struck millions of pennies. The copper penny continued to be issued until 1860, 
when they were replaced by lighter bronze coins; the "Bun penny", 
named for the hairstyle of Queen Victoria on it, was issued from then until 1894. 
The final years of her reign saw the "Old head" pennies, coined from 1895 until her death in 1901
"""
passage = input("Input a text passage you would like to ask questions about\n")


In [None]:
# Here we define the question and make a prediction.
# Some example questions:
#    When was the 'Old head' penny coined?
#    What was the penny called because of Queen Victoria's hairstyle?
#    When did Matthew Boulton gain the government contract?
#    What is depicted on the reverse of the penny?
#    How many pennies did Matthew Boulton strike?

model = bidirectional_attention_flow_seo_2017()

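while True:
    question = input("Input your question (leave empty to stop)\t")

    # Stop the loop on empty input instead of prompting forever
    if not question.strip():
        break

    # 1. make a prediction
    ans = model.predict(question, passage)

    # 2. assemble the answer and the probabilities of its start and end tokens
    p1, p2 = get_span_probs(ans)
    print(f"\n{question} --> {build_answer(ans)} | p(start)={p1:.3f}, p(end)={p2:.3f}\n")

    # 3. optionally plot the attention (uncomment the next line)
    # plot_attention(ans)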
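
The predictor output contains `best_span` together with `span_start_probs` and `span_end_probs`. One common way to choose a span from two such distributions is to maximize p(start = i) * p(end = j) subject to i <= j. The sketch below implements that rule with NumPy on made-up probabilities; `pick_best_span` is a hypothetical helper, and AllenNLP's own span selection may differ in its details.

```python
import numpy as np


def pick_best_span(start_probs, end_probs):
    """Return (start, end), end inclusive, maximizing
    start_probs[i] * end_probs[j] over all pairs with i <= j."""
    start_probs = np.asarray(start_probs, dtype=np.float64)
    end_probs = np.asarray(end_probs, dtype=np.float64)
    # scores[i, j] = p(start = i) * p(end = j)
    scores = np.outer(start_probs, end_probs)
    # Zero out spans that would end before they start.
    scores = np.triu(scores)
    start, end = np.unravel_index(np.argmax(scores), scores.shape)
    return int(start), int(end)


# Illustrative values only, not model output.
start_probs = [0.1, 0.6, 0.2, 0.1]
end_probs = [0.5, 0.1, 0.3, 0.1]

span = pick_best_span(start_probs, end_probs)
print(span)  # (1, 2)
```

Taking the two argmaxes independently would give start 1 and end 0 here, which is not a valid span; the i <= j constraint is what makes the result usable as an answer.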