In [26]:
import numpy as np
import requests
import html2text
from googlesearch import search
import json
import re
from simpletransformers.question_answering import QuestionAnsweringModel
from IPython.display import display
from IPython.html import widgets
from bs4 import BeautifulSoup
from markdown import markdown

In [27]:
# Source: https://gist.github.com/lorey/eb15a7f3338f959a78cc3661fbc255fe
def markdown_to_text(markdown_string):
    """Convert a markdown string to plain text.

    The markdown is rendered to HTML first so that BeautifulSoup can extract
    the text cleanly. Code blocks and inline code are replaced by a space.

    Args:
        markdown_string: Markdown source text.

    Returns:
        All text nodes of the rendered HTML joined together, with code
        snippets removed.
    """
    # md -> html -> text since BeautifulSoup can extract text cleanly
    html = markdown(markdown_string)

    # Remove code snippets. DOTALL lets `.` cross newlines: rendered code
    # blocks (<pre><code>...</code></pre>) almost always span several lines.
    html = re.sub(r'<pre>(.*?)</pre>', ' ', html, flags=re.DOTALL)
    html = re.sub(r'<code>(.*?)</code>', ' ', html, flags=re.DOTALL)

    # extract text (`string=` is the current name of bs4's deprecated `text=`)
    soup = BeautifulSoup(html, "html.parser")
    return ''.join(soup.find_all(string=True))

def format_text(text):
    """Flatten markdown into a single line of plain text.

    The input is converted with `markdown_to_text`, then every newline is
    replaced by a space so the result is one continuous string.
    """
    plain = markdown_to_text(text)
    return plain.replace('\n', ' ')

In [28]:
def query_pages(query, n=5):
    """Return up to `n` result URLs from a Google search for `query`.

    `num`/`stop` are both set to `n` and `pause=2` throttles requests to
    Google (per the `googlesearch` package's parameters).
    """
    results = search(query, num=n, stop=n, pause=2)
    return [url for url in results]

def query_to_text(query, n=5, timeout=10):
    """Search the web for `query` and return the plain text of the result pages.

    Args:
        query: Search query string.
        n: Number of search results to fetch.
        timeout: Seconds to wait for each page. Without a timeout,
            `requests.get` can block indefinitely on a stalled server.

    Returns:
        A list of strings, one per page that was downloaded successfully.
        Pages that fail (network error or HTTP error status) are skipped with
        a warning, so the list may be shorter than `n` or even empty.
    """
    import warnings

    html_conv = html2text.HTML2Text()
    html_conv.ignore_links = True
    html_conv.escape_all = True

    texts = []
    for link in query_pages(query, n):
        try:
            resp = requests.get(link, timeout=timeout)
            # Error pages (404, 500, ...) are not useful context for the QA model.
            resp.raise_for_status()
        except requests.RequestException as exc:
            warnings.warn(f'Skipping {link}: {exc}')
            continue
        texts.append(format_text(html_conv.handle(resp.text)))

    return texts

In [29]:
def create_model(use_cuda=None):
    """Load a DistilBERT question-answering model distilled on SQuAD.

    Args:
        use_cuda: Whether to run on the GPU. ``None`` (default) auto-detects:
            CUDA is used when available, otherwise the model runs on CPU.
            simpletransformers itself defaults to ``use_cuda=True`` and raises
            on machines without CUDA.

    Returns:
        A simpletransformers ``QuestionAnsweringModel``.
    """
    if use_cuda is None:
        import torch  # installed as a dependency of simpletransformers
        use_cuda = torch.cuda.is_available()
    return QuestionAnsweringModel(
        'distilbert', 'distilbert-base-uncased-distilled-squad', use_cuda=use_cuda
    )
    
def predict_answer(model, question, contexts, seq_len=512, debug=False):
    """Ask `model` a question over one or more contexts; return the majority answer.

    Each context is cut into chunks of at most `seq_len` characters. The model
    predicts an answer for every chunk, and the answers are combined by
    majority vote.

    Args:
        model: QA model exposing ``predict(list_of_squad_style_dicts)``.
            NOTE(review): assumes each prediction is a dict whose ``'answer'``
            is a string (older simpletransformers API) — confirm against the
            installed version.
        question: The question to ask.
        contexts: A single context string, or an iterable of context strings.
        seq_len: Maximum chunk length in characters. ``None`` (or a
            non-positive value) sends each context whole.
        debug: If true, print the raw predictions.

    Returns:
        The most frequent non-empty answer, lower-cased and stripped. Ties go
        to the answer seen first. Returns ``''`` when there is no context or
        the model produced no non-empty answer.
    """
    from collections import Counter

    if isinstance(contexts, str):
        contexts = [contexts]

    # Chunk the contexts (the previous version computed the chunks and then
    # discarded them, so `seq_len` had no effect).
    if seq_len is None or seq_len <= 0:
        chunks = [c for c in contexts if c]
    else:
        chunks = [c[i:i + seq_len] for c in contexts for i in range(0, len(c), seq_len)]

    if not chunks:
        return ''

    f_data = [
        {
            'qas': [{
                'question': question,
                'id': i,
                # Dummy answer so every record has the SQuAD shape.
                'answers': [{'text': ' ', 'answer_start': 0}],
                'is_impossible': False,
            }],
            'context': chunk,
        }
        for i, chunk in enumerate(chunks)
    ]

    prediction = model.predict(f_data)
    if debug:
        print(prediction)

    preds = [x['answer'].lower().strip() for x in prediction if x['answer'].strip() != '']
    if not preds:
        return ''

    # Counter keeps first-seen order and most_common() lists tied answers in
    # that order, so the vote no longer depends on set/hash ordering.
    return Counter(preds).most_common(1)[0][0]

In [30]:
def q_to_a(model, question, n=2, debug=False):
    """Answer `question` from the text of the top `n` web search results.

    The pages are fetched with `query_to_text` and passed as contexts to
    `predict_answer`, which also receives the `debug` flag.
    """
    return predict_answer(model, question, query_to_text(question, n=n), debug=debug)

In [31]:
# Example

# model = create_model()

# print(predict_answer(model, 'what color is the bird?', 'the bird is red.'))

#question = 'What is the bone on the back of your skull called?'
#context = query_to_text(question, n=3)
#pred = predict_answer(model, question, context)
#print(pred)

In [32]:
# Load the QA model once at notebook level; the button callback in the next cell uses this global.
model = create_model()

In [33]:
# Minimal UI: type a question, click the button, and the answer appears below.
# `width` is not a widget trait (it was silently ignored); sizing goes through `layout`.
text = widgets.Text(description='Question:', layout=widgets.Layout(width='300px'))
button = widgets.Button(description='Get an Answer')
# Output widget that captures the callback's prints. Plain print() from a
# widget callback is not reliably shown in the notebook (e.g. in JupyterLab).
answer_output = widgets.Output()
display(text, button, answer_output)

def on_button_click(b):
    """Answer the question typed into `text` and show the result in `answer_output`."""
    question = text.value.strip()
    answer_output.clear_output()
    with answer_output:
        if not question:
            print('Please enter a question.')
            return
        b.disabled = True  # prevent double submissions while the search runs
        try:
            answer = q_to_a(model, question, n=2)
            print('Answer:', answer)
        except Exception as exc:  # surface failures instead of losing them in the log
            print('Error:', exc)
        finally:
            b.disabled = False

button.on_click(on_button_click)

Text(value='', description='Question:')

Button(description='Get an Answer', style=ButtonStyle())