### Import

In [1]:
import torch
import transformers
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import wikipedia
import nltk
# Needed for POS tagging (uncomment and run once to download the NLTK data):
#nltk.download('punkt')
#nltk.download('averaged_perceptron_tagger')

### Load Model

In [2]:
# Single source of truth for the checkpoint: the tokenizer and the model must be
# loaded from the same name so that token ids match the model's vocabulary.
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'
tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)

***
### Model function
Use the loaded BERT model to extract an answer from a context text<br>
**Input** : Question, Context text <br>
**return** : Answer

In [3]:
def _build_segment_ids(input_ids, sep_token_id):
    """Return BERT token_type_ids: 0 up to and including the first [SEP], 1 after it."""
    # Segment A (the question) includes the first [SEP] token itself;
    # everything after it is segment B (the context).
    num_seg_a = input_ids.index(sep_token_id) + 1
    return [0] * num_seg_a + [1] * (len(input_ids) - num_seg_a)


def _merge_wordpieces(tokens):
    """Join WordPiece tokens with spaces, gluing '##' sub-word pieces onto the previous token."""
    if not tokens:
        return ''
    merged = tokens[0]
    for token in tokens[1:]:
        if token.startswith('##'):
            merged += token[2:]
        else:
            merged += ' ' + token
    return merged


def generate_answer(question, answer_text, max_length=512):
    """Extract the answer to `question` from `answer_text` with the BERT QA model.

    Relies on the module-level `tokenizer` and `model` loaded earlier in the notebook.

    Args:
        question: the question (encoded as segment A).
        answer_text: the context passage to search for the answer (segment B).
        max_length: maximum number of tokens fed to the model. The default of 512 is
            BERT's input limit. Longer inputs have the tail of the context dropped.

    Returns:
        The predicted answer span as a string, with '##' sub-word tokens re-joined.
    """
    print("I'm looking for an aswer, wait please ...")

    # == Tokenize ==
    # Encode as a text pair: [CLS] question [SEP] context [SEP].
    print("-Tokenization")
    input_ids = tokenizer.encode(question, answer_text)
    if len(input_ids) > max_length:
        # The model cannot take more than max_length positions: cut the end of the
        # context and restore the closing [SEP] so the pair stays well-formed.
        input_ids = input_ids[:max_length - 1] + [tokenizer.sep_token_id]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # == Set Segment IDs ==
    segment_ids = _build_segment_ids(input_ids, tokenizer.sep_token_id)
    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids)

    # == Run Model ==
    # Inputs are placed on whatever device the model lives on (CPU unless the
    # model was moved with model.to(...)).
    print("-Forward pass on the model")
    device = next(model.parameters()).device
    with torch.no_grad():  # inference only: no gradient bookkeeping needed
        outputs = model(torch.tensor([input_ids], device=device),
                        token_type_ids=torch.tensor([segment_ids], device=device))
    # Index rather than tuple-unpack: older transformers versions return a plain
    # tuple, while transformers >= 4 returns a ModelOutput whose unpacking yields
    # its key names instead of the logits.
    start_scores, end_scores = outputs[0], outputs[1]

    # Most likely start token, then the most likely end token at or after it,
    # so the span can never be inverted.
    answer_start = int(torch.argmax(start_scores))
    answer_end = answer_start + int(torch.argmax(end_scores[0, answer_start:]))

    return _merge_wordpieces(tokens[answer_start:answer_end + 1])

    

***
### Question Processing
Extract the subject _(and more?)_ from the question

In [4]:
def extract_subject(question):
    """Guess the subject of `question` from NLTK part-of-speech tags.

    Returns the last noun in the question. Any Penn Treebank noun tag counts
    (NN, NNS, NNP, NNPS). Returns None, after printing a hint, if the question has
    no noun. Requires the NLTK 'punkt' and 'averaged_perceptron_tagger' data.
    """
    subject = None
    for word, tag in nltk.pos_tag(nltk.word_tokenize(question)):
        # All noun tags start with "NN". Proper nouns (NNP) are usually the best
        # Wikipedia search terms, so matching only the exact "NN" tag missed them.
        if tag.startswith('NN'):
            subject = word
    if subject is not None:
        print("Subject found: " + subject)
    else:
        print("Subject not found 😔\n Rephrase the question or try another one")
    return subject



In [5]:
import spacy
import en_core_web_sm
# Load the spaCy English pipeline `en_core_web_sm` (used for noun-chunk extraction below).
nlp = en_core_web_sm.load()
# The model package must be installed first, e.g.:
#   python -m spacy download en_core_web_sm
# Equivalent loading by name: nlp = spacy.load("en_core_web_sm")

def extract_subject_with_spacy(question):
    """Return the first noun chunk of `question` that is not a question word.

    Uses the module-level spaCy pipeline `nlp`. Question words ("What", "Who", ...)
    can show up as noun chunks, so they are skipped. The comparison ignores case
    because questions usually start with a capital letter.

    Returns:
        The chunk text, or None (after printing a hint) if no usable chunk is found.
    """
    # Interrogative words to ignore; compared against the lower-cased chunk text.
    question_words = {'who', 'whom', 'whose', 'why', 'what', 'when', 'where', 'which', 'how'}
    doc = nlp(question)
    for chunk in doc.noun_chunks:
        if chunk.text.lower() not in question_words:
            print("Subject found: " + chunk.text)
            return chunk.text
    print("Subject not found 😔\n Rephrase the question or try another one")
    return None

# Quick smoke test
extract_subject_with_spacy('my cat')

Subject found: my cat


'my cat'

### Wikipedia API
Try to find the most relevant context text to give as BERT input. <br>
Get a Wikipedia article and scrape it

In [6]:
# == WIP ==
# TODO: split into paragraphs. The model accepts at most 512 input tokens, so long
# texts must be chunked before being passed to BERT.
def get_wiki_and_split(subject):
    """Fetch the Wikipedia summary of `subject`, print its length in characters, and return it.

    Despite the name, no splitting is done yet (work in progress). The summary
    string is returned unchanged.
    """
    summary = wikipedia.summary(subject)
    print(len(summary))
    return summary


In [7]:
# Explore which article titles Wikipedia returns for a query.
wikipedia.search("salsa")

['Salsa',
 'Salsa (dance)',
 'Salsa music',
 'Salsa (sauce)',
 'Bad Salsa',
 'Pico de gallo',
 'Salsa Lizano',
 'Electrica Salsa',
 'Salsa verde',
 'Salsa roja']

In [8]:
# Split a summary into rough sentences on '.', a first step toward chunking the
# context below BERT's 512-token limit.
# Distinct names: `list` must not shadow the builtin, and `text` is later rebound
# to the question widget.
# NOTE: wikipedia.summary() auto-suggests titles by default. A saved run of this
# cell resolved 'salsa' to the *Samba* article. Pass auto_suggest=False to query the
# exact title (this may raise DisambiguationError for ambiguous titles).
salsa_summary = wikipedia.summary('salsa')
sentences = salsa_summary.split('.')
sentences

['Samba (Portuguese pronunciation: [ˈsɐ̃bɐ] (listen)), also known as samba urbano carioca (Urban Carioca Samba) or simply samba carioca (Carioca Samba) is a Brazilian music genre that originated in the Afro-Brazilians communities of Rio de Janeiro in the early 20th century', ' Having its roots in the cultural expression of West Africa and in Brazilian folk traditions, especially those linked to the primitive rural samba of the colonial and imperial periods, is considered one of the most important cultural phenomena in Brazil and one of the country symbols Present in the Portuguese language at least since the 19th century, the word “samba” was originally used to designate a “popular dance”', ' Over time, its meaning has been extended to a “batuque-like circle dance”, a dance style and also to a “music genre”', ' This process of establishing itself as a musical genre began in the 1910s and it had its inaugural landmark in the song “Pelo Telefone”, launched in 1917', ' Despite being ident

***
### Visualization

In [9]:
from IPython.display import display
from ipywidgets import widgets

# Widget layout definitions
# Question text box. NOTE(review): height='0px' is unusual for a text input; confirm it renders as intended.
layout = widgets.Layout(width='400px', height='0px', margin='100px 0 0 100px')
# "Ask" button
bLayout = widgets.Layout(width='50px', height='28px', margin='100px 0 0 0px')
# Output area layouts. Only outLayoutTest is used below; outLayoutPropre appears unused here.
outLayoutPropre = widgets.Layout(width='480px', height='100px', margin='50px 0 100px 100px')
outLayoutTest = widgets.Layout(width='450px', height='auto', margin='50px 0 100px 100px')

# Widget objects: question input, submit button, and the output area for answers.
# NOTE: this rebinds the name `text`; button_on_click reads `text.value`.
text = widgets.Text(layout=layout)
button = widgets.Button(description = 'Ask', layout = bLayout)
out = widgets.Output(layout=outLayoutTest)
 
def button_on_click(self):
    """Handle a click on "Ask": answer the question typed in the `text` widget.

    Extracts a subject with spaCy, fetches that subject's Wikipedia summary as
    context, runs the BERT QA model on it, and prints the result into `out`.

    Args:
        self: the clicked Button. ipywidgets `on_click` passes it positionally. It is
            unused here, and the name is kept for compatibility.
    """
    with out:
        out.clear_output()
        question = text.value
        subject = extract_subject_with_spacy(question=question)
        if subject is None:
            # extract_subject_with_spacy has already printed a hint.
            return
        try:
            context = wikipedia.summary(subject)
        except wikipedia.exceptions.WikipediaException as err:
            # DisambiguationError / PageError etc.: report instead of crashing the callback.
            print("No usable Wikipedia article for '{}': {}".format(subject, err))
            return
        # Cap the context length, presumably to stay near BERT's 512-token input limit.
        answer = generate_answer(question, context[:2000])
        print("Here what i found: \n" + answer)
        
# Register the click handler, then render the input row (text box + button)
# followed by the output area.
button.on_click(button_on_click)
display(widgets.HBox([text, button]), out)

HBox(children=(Text(value='', layout=Layout(height='0px', margin='100px 0 0 100px', width='400px')), Button(de…

Output(layout=Layout(height='auto', margin='50px 0 100px 100px', width='450px'))