In [1]:
#Import library

from simpletransformers.question_answering import QuestionAnsweringModel

import pandas as pd
import requests
import html2text
from googlesearch import search
import re
from IPython.display import display
import ipywidgets as widgets
from bs4 import BeautifulSoup
from markdown import markdown

import warnings
warnings.filterwarnings('ignore')



In [2]:
# Load the 'distilbert-base-uncased-distilled-squad' checkpoint as a
# question-answering model; use_cuda=False requests CPU execution.
model = QuestionAnsweringModel(
    'distilbert',
    'distilbert-base-uncased-distilled-squad',
    use_cuda=False,
)

In [3]:
def predict_answer(model, question, contexts, seq_len=512, debug=False):
    """Return the most probable answer to ``question`` found in ``contexts``.

    Every non-blank context is submitted to ``model.predict`` as its own
    SQuAD-style example. The first-ranked answer of the example whose
    first-ranked probability is highest is returned.

    Args:
        model: Object exposing ``predict(examples)`` that returns a pair
            ``(answers, probabilities)``: parallel lists of dicts carrying
            ``'id'``/``'answer'`` and ``'probability'`` keys respectively,
            where ``'answer'`` and ``'probability'`` hold ranked lists.
        question: The question text.
        contexts: A single context string or a list of context strings.
        seq_len: Kept for backward compatibility; currently has no effect
            (see the review note in the body).
        debug: When true, print the raw prediction returned by the model.

    Returns:
        The best answer text, or ``'No answer'`` when there is no usable
        context, the model returns nothing, or the best answer is empty.
    """
    if not isinstance(contexts, list):
        contexts = [contexts]

    # NOTE(review): the previous code sliced every context into
    # ``seq_len``-character chunks and then immediately discarded them by
    # reassigning ``split_context = contexts``. The model has therefore always
    # received whole contexts; that behavior is preserved here and the dead
    # chunking loop is removed. Confirm whether chunking was ever intended.

    # A blank context cannot contain an answer, so it is not sent to the model.
    usable_contexts = [c for c in contexts if c and c.strip()]
    if not usable_contexts:
        # Without this guard an empty list reaches idxmax() on an empty
        # Series below, which raises ValueError.
        return 'No answer'

    f_data = [
        {
            'qas': [{
                'question': question,
                'id': i,
                'answer': [{'text': ' ', 'answer_start': 0}],
                'is_impossible': False,
            }],
            'context': c,
        }
        for i, c in enumerate(usable_contexts)
    ]
    prediction = model.predict(f_data)

    if debug:
        print(prediction)

    answers = list(prediction[0])
    probabilities = list(prediction[1])

    # Only the first-ranked answer/probability of each example is compared.
    data = pd.DataFrame({
        'id': [a['id'] for a in answers],
        'answer': [a['answer'][0] for a in answers],
        'probability': [p['probability'][0] for p in probabilities],
    })
    if data.empty:
        return 'No answer'

    preds = data.loc[data['probability'].idxmax(), 'answer']

    if preds:
        return preds
    return 'No answer'

In [4]:
def markdown_to_text(markdown_string):
    """Render Markdown to HTML, drop code blocks, and return the plain text.

    Args:
        markdown_string: Markdown-formatted text.

    Returns:
        The concatenation of all text nodes left in the rendered HTML after
        ``<pre>`` and ``<code>`` elements have been replaced by a space.
    """
    html = markdown(markdown_string)

    # re.DOTALL lets '.' match newlines; without it a code block spanning
    # several lines is never matched and its contents leak into the text.
    # '[^>]*' also accepts tags that carry attributes (e.g. a class).
    html = re.sub(r'<pre\b[^>]*>.*?</pre>', ' ', html, flags=re.DOTALL)
    html = re.sub(r'<code\b[^>]*>.*?</code>', ' ', html, flags=re.DOTALL)

    soup = BeautifulSoup(html, 'html.parser')
    text = ''.join(soup.findAll(text=True))

    return text

def format_text(text):
    """Convert Markdown ``text`` to plain text and flatten it onto one line.

    Every newline in the converted text is replaced by a single space.
    """
    plain = markdown_to_text(text)
    return ' '.join(plain.split('\n'))

In [5]:
def query_pages(query, n=2):
    """Return up to ``n`` result URLs for ``query`` from a Google search.

    Args:
        query: Search string.
        n: Maximum number of result links to return.

    Returns:
        A list of result URLs.
    """
    # In googlesearch.search, ``num`` is the number of results per page and
    # ``stop`` is the last result to retrieve. The previous hard-coded stop=10
    # made ``n`` change only the page size, so 10 links came back regardless.
    return list(search(query, num=n, stop=n, pause=2))

def query_to_text(query, n=2, timeout=10):
    """Fetch the pages found for ``query`` and return their plain text.

    Args:
        query: Search string passed to ``query_pages``.
        n: Forwarded to ``query_pages`` as its ``n`` argument.
        timeout: Seconds to wait on each HTTP request before giving up.

    Returns:
        A list with one cleaned text string per page that was fetched
        successfully. Pages that fail to download are skipped.
    """
    html_conv = html2text.HTML2Text()
    html_conv.ignore_links = True
    html_conv.escape_all = True

    text = []
    for link in query_pages(query, n=n):
        try:
            # TLS certificates are verified (the requests default); the
            # previous verify=False accepted any certificate. The timeout
            # keeps one unresponsive host from hanging the whole query.
            req = requests.get(link, timeout=timeout)
        except requests.RequestException:
            # Best effort: one unreachable or misconfigured site should not
            # abort the remaining pages.
            continue
        # A Response is truthy only when its status code is not an error.
        if req:
            text.append(format_text(html_conv.handle(req.text)))
    return text

In [6]:
def q_to_a(model, question, n=1, debug=False):
    """Answer ``question`` from the text of web search results.

    Args:
        model: Question-answering model passed through to ``predict_answer``.
        question: Used both as the search query and as the question.
        n: Forwarded to ``query_to_text`` as its ``n`` argument.
        debug: Forwarded to ``predict_answer``.

    Returns:
        Whatever ``predict_answer`` returns for the fetched contexts.
    """
    # ``n`` was previously accepted but never forwarded, so callers could not
    # influence the search at all.
    context = query_to_text(question, n=n)
    return predict_answer(model, question, context, debug=debug)

In [7]:
# Question input. The width is set through a Layout object, the supported way
# to size a widget; the bare ``width=300`` keyword was not applied to it.
text = widgets.Text(description='Question:', layout=widgets.Layout(width='300px'))
display(text)

button = widgets.Button(description='Get an Answer')
display(button)

# Dedicated output area so that text printed by the click callback (and any
# traceback it raises) is rendered below the button.
output = widgets.Output()
display(output)

def on_button_click(b):
    """Answer the question currently typed in ``text`` and print the result.

    ``b`` is the clicked button instance supplied by ipywidgets; it is unused.
    """
    with output:
        answer = q_to_a(model, text.value, n=5)
        print('Answer:', answer)

button.on_click(on_button_click)


# Example questions:
#   What is the population of India?
#   What is the national animal of India?
#   Which scientist proposed the three laws of motion?

Text(value='', description='Question:')

Button(description='Get an Answer', style=ButtonStyle())