### Import

In [5]:
import torch
import transformers
from transformers import BertForQuestionAnswering
from transformers import BertTokenizer
import wikipedia
import nltk
# Needed for POS tagging (uncomment and run once to download the NLTK resources)
#nltk.download('punkt')
#nltk.download('averaged_perceptron_tagger')

### Load Model

In [6]:
# Name of the pretrained checkpoint. It is defined once so the tokenizer and
# the model are always loaded from the same checkpoint.
MODEL_NAME = 'bert-large-uncased-whole-word-masking-finetuned-squad'

tokenizer = BertTokenizer.from_pretrained(MODEL_NAME)
model = BertForQuestionAnswering.from_pretrained(MODEL_NAME)

***
### Model function
Use loaded BERT model to return answer<br>
**Input** : Question, Context text <br>
**return** : Answer

In [7]:
def generate_answer(question, answer_text, max_length=512, verbose=False):
    """Answer ``question`` from ``answer_text`` with the loaded BERT QA model.

    The question and the context are encoded as one text pair and run through
    the module-level ``model``. The tokens between the highest-scoring start
    position and the highest-scoring end position are joined into the answer.

    Args:
        question: The question to answer.
        answer_text: The context text in which the answer is searched.
        max_length: Maximum number of tokens fed to the model. Longer inputs
            are cut and closed with a final ``[SEP]`` token. Defaults to 512,
            the input limit of the model used in this notebook.
        verbose: When True, also print the token ids and the scores of the
            selected start/end positions (debugging aid).

    Returns:
        str: The predicted answer, with ``##`` sub-word tokens merged into the
        preceding token. If the best end position lies before the best start
        position, only the start token is returned.
    """
    print("I'm looking for an answer, please wait ...")

    # == Tokenize ==
    # Apply the tokenizer to the input text, treating them as a text-pair.
    print("-Tokenization")
    input_ids = tokenizer.encode(question, answer_text)
    if len(input_ids) > max_length:
        # Cut the context and close the sequence with [SEP] so it stays
        # well-formed instead of overflowing the model's input size.
        input_ids = input_ids[:max_length - 1] + [tokenizer.sep_token_id]
    if verbose:
        print("input ids : ", input_ids)

    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # == Set Segment IDs ==
    # Segment A is everything up to and including the first [SEP] token
    # (the question); the remainder is segment B (the context).
    num_seg_a = input_ids.index(tokenizer.sep_token_id) + 1
    num_seg_b = len(input_ids) - num_seg_a
    segment_ids = [0] * num_seg_a + [1] * num_seg_b

    # There should be a segment_id for every input token.
    assert len(segment_ids) == len(input_ids)

    # == Run Model ==
    print("-Forward pass on the model")
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(torch.tensor([input_ids]),  # The tokens representing our input text.
                        token_type_ids=torch.tensor([segment_ids]))  # Question vs. context segments.

    # Index the result instead of unpacking it: unpacking a dict-like model
    # output yields its keys, while indexing works for tuples and for
    # dict-like outputs alike.
    start_scores, end_scores = outputs[0], outputs[1]

    # Find the tokens with the highest `start` and `end` scores.
    answer_start = int(torch.argmax(start_scores))
    answer_end = int(torch.argmax(end_scores))

    if verbose:
        print(type(start_scores))
        print(start_scores.size())
        print(start_scores[0, answer_start])
        print(end_scores[0, answer_end])

    # == Build the answer without the '##' markers ==
    # Start with the first token.
    answer = tokens[answer_start]

    # Append the remaining answer tokens.
    for token in tokens[answer_start + 1:answer_end + 1]:
        if token[0:2] == '##':
            # Sub-word token: glue it to the previous token.
            answer += token[2:]
        else:
            # Otherwise, add a space then the token.
            answer += ' ' + token

    return answer

    

***
### Question Processing
Extract the subject _(and more?)_ from the question

In [9]:
def extract_subject(question):
    """Pick the subject of ``question`` using NLTK part-of-speech tags.

    The question is tokenized and tagged; the last token tagged ``NN`` is
    taken as the subject. The outcome is also printed.

    Args:
        question: The question text.

    Returns:
        The last token tagged ``NN``, or None when no token has that tag.
    """
    tagged_words = nltk.pos_tag(nltk.word_tokenize(question))

    # Later matches override earlier ones, so only the last one matters.
    nn_words = [pair[0] for pair in tagged_words if pair[1] == 'NN']
    subject = nn_words[-1] if nn_words else None

    if subject is None:
        print("Subject not found 😔\n Rephrase the question or try another one")
    else:
        print("Subject found: " + subject)
    return subject

# Quick manual check of extract_subject. The recorded run of this cell failed
# with a LookupError because the NLTK 'punkt' resource was not installed;
# run nltk.download('punkt') once before executing it.
extract_subject("What is the color of the sky ?")



LookupError: 
**********************************************************************
  Resource [93mpunkt[0m not found.
  Please use the NLTK Downloader to obtain the resource:

  [31m>>> import nltk
  >>> nltk.download('punkt')
  [0m
  For more information see: https://www.nltk.org/data.html

  Attempted to load [93mtokenizers/punkt/english.pickle[0m

  Searched in:
    - 'C:\\Users\\Utilisateur/nltk_data'
    - 'C:\\Users\\Utilisateur\\anaconda3\\envs\\askme\\nltk_data'
    - 'C:\\Users\\Utilisateur\\anaconda3\\envs\\askme\\share\\nltk_data'
    - 'C:\\Users\\Utilisateur\\anaconda3\\envs\\askme\\lib\\nltk_data'
    - 'C:\\Users\\Utilisateur\\AppData\\Roaming\\nltk_data'
    - 'C:\\nltk_data'
    - 'D:\\nltk_data'
    - 'E:\\nltk_data'
    - ''
**********************************************************************


In [None]:
import spacy
import en_core_web_sm
# Load the English spaCy pipeline; `nlp` is used by extract_subject_with_spacy below.
nlp = en_core_web_sm.load()
# If the `en_core_web_sm` package is missing, it can be installed with the
# command below; the pipeline can then also be loaded by name:
#!python -m spacy download en_core_web_sm
#nlp = spacy.load("en_core_web_sm")

def extract_subject_with_spacy(question):
    """Guess the subject of ``question`` from its spaCy noun chunks.

    Noun chunks that are just an interrogative word are discarded. Of the
    remaining chunks, the only one is returned if there is exactly one,
    otherwise the second one. The outcome is also printed.

    Args:
        question: The question text.

    Returns:
        str or None: The text of the selected noun chunk, or None when no
        candidate is left.
    """
    # Interrogative words that must not be taken as the subject
    # (compared case-insensitively).
    osef_list = ['who', 'why', 'what', 'when', 'which', 'how']

    doc = nlp(question)

    # Build a new list instead of removing items from the list being
    # iterated: removing while iterating skips the element that follows each
    # removal, which lets consecutive interrogative words through.
    nouns_list = [
        noun for noun in doc.noun_chunks
        if str(noun).strip().lower() not in osef_list
    ]

    # No candidate left: there is no subject.
    if not nouns_list:
        print("subject not found, please try another formulation")
        return None

    # One candidate: it is the subject. More than one: take the second.
    # Simplistic rule, but it seems to match how questions are phrased: in
    # questions with two nouns, the second one is usually the subject.
    # TODO: with more than one noun, the noun that is not the subject should
    # also be kept, to look it up in the subject's Wikipedia page.
    chosen = nouns_list[0] if len(nouns_list) == 1 else nouns_list[1]
    subject = str(chosen)
    print("subject found : " + subject)
    return subject
        
# Manual checks
extract_subject_with_spacy('What is the ?')
extract_subject_with_spacy('What is the sky?')
extract_subject_with_spacy('What is the color of the sky ?')

extract_subject_with_spacy('What is the meaning of "bread" ?') # no problem
extract_subject_with_spacy('What is the meaning of "omg" ?') # acronyms are not recognized
extract_subject_with_spacy('What is the meaning of "OMG" ?') # ...but in upper case they are
extract_subject_with_spacy('What is the meaning of "Why" ?') # (very) rare special case to handle: a question about an interrogative word that is in osef_list (a condition will be needed, e.g. if there is more than one interrogative word, only remove the first one from noun_list)
extract_subject_with_spacy('What is the third color of the french flag ?') # clean
extract_subject_with_spacy("What is the color of Nirvana's second album ?") # nice that it finds the subject in this kind of case. A big future task: find Wikipedia articles from this kind of paraphrase

### Wikipedia API
Try to find the most relevant context text to give as BERT input. <br>
Get a Wikipedia article and scrape it

In [None]:
# == WIP ==
# To be split into paragraphs (the model has a 512-token limit on its input text)
def get_wiki_and_split(subject):
    """Fetch the Wikipedia summary of ``subject``, print its length, return it.

    Work in progress: despite the name, the text is not split yet.

    Args:
        subject: Title or query passed to ``wikipedia.summary``.

    Returns:
        The summary text returned by ``wikipedia.summary``.
    """
    summary_text = wikipedia.summary(subject)
    summary_length = len(summary_text)
    print(summary_length)
    return summary_text


In [None]:
# Exploratory check: run a Wikipedia search for "salsa"; the cell displays the returned value.
wikipedia.search("salsa")

In [None]:
# Split the summary of the "salsa" article into sentence-like pieces on '.'.
# Dedicated names are used so that running this cell neither shadows the
# builtin `list` for the rest of the notebook nor overwrites the `text`
# widget of the Visualization section.
salsa_summary = wikipedia.summary('salsa')
salsa_sentences = salsa_summary.split('.')
print(salsa_sentences)

***
### Visualization

In [None]:
from IPython.display import display
from ipywidgets import widgets

# Widget layout definitions
layout = widgets.Layout(
    width='400px',
    height='0px',
    margin='100px 0 0 100px',
)
bLayout = widgets.Layout(
    width='50px',
    height='28px',
    margin='100px 0 0 0px',
)
outLayoutPropre = widgets.Layout(
    width='480px',
    height='100px',
    margin='50px 0 100px 100px',
)
outLayoutTest = widgets.Layout(
    width='450px',
    height='auto',
    margin='50px 0 100px 100px',
)
#titleLayout = widgets.Layout(width='450px', height='auto', margin='0px 0 0px 100px')

# Widget object definitions: question field, "Ask" button and output area
text = widgets.Text(layout=layout)
button = widgets.Button(
    description='Ask',
    layout=bLayout,
)
out = widgets.Output(layout=outLayoutTest)  # layout=outLayout
#out = widgets.HTML(layout = outLayout, value= '<style>.text {width: 480px; heigh: 100px;}</style> <p class="text">'+ out_value +' </p>')
 
def button_on_click(self):
    """Handle a click on the "Ask" button: answer the question typed in ``text``.

    The subject of the question is extracted, its Wikipedia summary is used
    as context, and the answer found by ``generate_answer`` is printed in the
    ``out`` widget. Nothing more is printed when no subject is found
    (``extract_subject_with_spacy`` already reports it), and a short message
    is printed when the Wikipedia lookup fails.

    Args:
        self: The object passed by the click callback (not used here).
    """
    with out:
        out.clear_output()
        question = text.value
        subject = extract_subject_with_spacy(question=question)
        if subject is None:
            return

        try:
            context = wikipedia.summary(subject)
        except wikipedia.exceptions.WikipediaException as error:
            # E.g. ambiguous or unknown page title: report it instead of
            # dumping a traceback into the output widget.
            print("Could not get a Wikipedia summary for '" + subject + "': " + str(error))
            return

        # Only the first 2000 characters of the summary are given as context.
        answer = generate_answer(question, context[:2000])
        print("Here is what i found: \n" + answer)
        
# Register the click handler on the "Ask" button.
button.on_click(button_on_click)

# Display the question field and the button in one horizontal box, then the output area.
display(widgets.HBox((text, button,)))
display(out)