import torch
import argparse
import code
import prettytable
import logging
from termcolor import colored
from drqa import pipeline
from drqa.retriever import utils
# Configure the root logger so the progress messages emitted below are shown.
# Without setLevel() the root logger stays at WARNING and drops logger.info(),
# and without addHandler() the custom formatter is never used.
logger = logging.getLogger()
logger.setLevel(logging.INFO)
fmt = logging.Formatter('%(asctime)s: [ %(message)s ]', '%m/%d/%Y %I:%M:%S %p')
console = logging.StreamHandler()
console.setFormatter(fmt)
logger.addHandler(console)
from gutenberg.query import get_etexts
from gutenberg.query import get_metadata
def _build_parser():
    """Return the command-line parser for the interactive QA session."""
    cli = argparse.ArgumentParser()

    # Model / data locations.
    cli.add_argument('--reader-model', type=str, default=None,
                     help='Path to trained Document Reader model')
    cli.add_argument('--retriever-model', type=str, default=None,
                     help='Path to Document Retriever model (tfidf)')
    cli.add_argument('--doc-db', type=str, default=None,
                     help='Path to Document DB')

    # Text processing options.
    cli.add_argument('--tokenizer', type=str, default=None,
                     help=("String option specifying tokenizer type to "
                           "use (e.g. 'corenlp')"))
    cli.add_argument('--candidate-file', type=str, default=None,
                     help=("List of candidates to restrict predictions to, "
                           "one candidate per line"))

    # Device selection.
    cli.add_argument('--no-cuda', action='store_true',
                     help="Use CPU only")
    cli.add_argument('--gpu', type=int, default=-1,
                     help="Specify GPU device id to use")
    return cli


parser = _build_parser()
args = parser.parse_args()
# Use the GPU only when it was not disabled on the command line and CUDA is
# actually available.
args.cuda = not args.no_cuda and torch.cuda.is_available()
if args.cuda:
    # torch.cuda.set_device is a no-op for a negative id, so the default
    # --gpu -1 keeps torch's current device.
    torch.cuda.set_device(args.gpu)
    logger.info('CUDA enabled (GPU %d)', args.gpu)
else:
    logger.info('Running on CPU only.')

# Optional fixed answer candidates: one per line, normalized with the
# retriever's utils.normalize and lower-cased.
if args.candidate_file:
    logger.info('Loading candidates from %s', args.candidate_file)
    candidates = set()
    with open(args.candidate_file, encoding='utf-8') as f:
        for raw_line in f:
            cand = utils.normalize(raw_line.strip()).lower()
            # Blank lines would otherwise add '' as a candidate.
            if cand:
                candidates.add(cand)
    logger.info('Loaded %d candidates.', len(candidates))
else:
    candidates = None

logger.info('Initializing pipeline...')
# Build the full retriever + reader pipeline from the command-line options
# parsed above (device flag, fixed candidates, reader model, TF-IDF ranker,
# document DB and tokenizer).
DrQA = pipeline.DrQA(
    cuda=args.cuda,
    fixed_candidates=candidates,
    reader_model=args.reader_model,
    ranker_config={'options': {'tfidf_path': args.retriever_model}},
    db_config={'options': {'db_path': args.doc_db}},
    tokenizer=args.tokenizer,
)
# ------------------------------------------------------------------------------
# Drop in to interactive mode
# ------------------------------------------------------------------------------
def _first_metadata(feature, doc_id, default=''):
    """Return one value of Gutenberg metadata ``feature`` for ``doc_id``.

    ``get_metadata`` yields a collection of values (a frozenset, per the
    gutenberg docs). Indexing ``list(...)[0]`` raises IndexError when a book
    lacks the field and picks an arbitrary element when it has several (for
    example many format URIs). Instead, return ``default`` for an empty
    result, and otherwise the smallest value by string order, which is
    stable across runs.
    """
    # NOTE(review): gutenberg identifies books by etext number, while DrQA
    # doc ids come from the document DB -- confirm the ids match and that no
    # int() conversion is needed for this corpus.
    values = get_metadata(feature, doc_id)
    if not values:
        return default
    return min(values, key=str)


def process(question, candidates=None, top_n=1, n_docs=5):
    """Answer ``question`` with the DrQA pipeline and print the results.

    Prints a table of the top predictions, enriched with the Gutenberg
    title, author and format link of each source document. It then prints
    each answer span highlighted (green, bold) inside its context passage.

    Args:
        question: the question string.
        candidates: optional answer candidates, forwarded to DrQA.process.
        top_n: number of predictions to return, forwarded to DrQA.process.
        n_docs: number of documents to retrieve, forwarded to DrQA.process.
    """
    predictions = DrQA.process(
        question, candidates, top_n, n_docs, return_context=True
    )

    table = prettytable.PrettyTable(
        ['Rank', 'Answer', 'Doc-ID', 'Doc-Title', 'Doc-Author', 'Doc-Link',
         'Answer Score', 'Doc Score']
    )
    for rank, pred in enumerate(predictions, 1):
        doc_id = pred['doc_id']
        table.add_row([
            rank,
            pred['span'],
            doc_id,
            _first_metadata('title', doc_id),
            _first_metadata('author', doc_id),
            _first_metadata('formaturi', doc_id),
            '%.5g' % pred['span_score'],
            '%.5g' % pred['doc_score'],
        ])
    print('Top Predictions:')
    print(table)

    print('\nContexts:')
    for pred in predictions:
        text = pred['context']['text']
        # Character offsets of the answer span within the context text.
        start = pred['context']['start']
        end = pred['context']['end']
        output = (text[:start] +
                  colored(text[start:end], 'green', attrs=['bold']) +
                  text[end:])
        print('[ Doc = %s ]' % pred['doc_id'])
        print(output + '\n')
# Help text shown when the interactive console starts and by usage().
banner = """
Interactive MRC
>> process(question, candidates=None, top_n=1, n_docs=5)
>> usage()
"""
def usage():
    """Print the banner listing the commands available in the console."""
    print(banner)
# Drop into an interactive console. At module level locals() is this
# module's global namespace, so process(), usage(), DrQA, etc. are callable.
code.interact(local=locals(), banner=banner)
