# Made in collaboration with Chu Fang, Xiangdi Lin, Rachel Finley, and Dawid Cichoki

# Class for NER outputs

In [1]:
'''imports and their dependencies'''

#!pip install transformers==3
#!pip install torch
# !pip install tqdm


import torch
from tqdm import tqdm
from tokenizers import Tokenizer
from torch.optim.lr_scheduler import LinearLR
import numpy as np

In [2]:
from transformers import BertConfig, BertTokenizerFast, BertForTokenClassification, AutoTokenizer
import re

In [3]:
class BERT_NER():
    '''Class to isolate entities from question with NER'''
    
    def __init__(self):
        '''Initialize models and variables'''
        
        ### Tokenizer settings
        self.tokenizer =  BertTokenizerFast.from_pretrained("bert-base-uncased")
        

        ### Model Settings
        self.lr = 1e-6
        self.batch_size = 32
        self.num_epochs = 1
        self.max_length = 100
        
        ### Load/config model &  optimizer
        self.device = torch.device("cpu" if not torch.cuda.is_available() else "cuda")
        self.config = BertConfig(max_position_embeddings = self.max_length, num_labels = 3)
        
        self.model = BertForTokenClassification.from_pretrained("Rachel-Finley/BERT_NER_for_QA_Queries")
        self.optimizer = torch.optim.AdamW(self.model.parameters(), lr = self.lr, weight_decay=0.01)
        
        ### Attributes initialization
        self.tokenized = {}
        self.input_ids = {}
        self.tokens = {}
        self.new_tokens = []
        self.new_labels = []
        
        self.results = {}
        self.named_entity_found = []
        self.results_final = []
        
        self.pattern = r'\[\w+\]'
    
    def run_model(self, question):
        '''Tokenizes question, predicts entities, saves results to class attribute'''
        
        ### tokenize
        split_questions = question.split()
        self.tokenized = self.tokenizer(split_questions, is_split_into_words=True, return_tensors='pt')
        
        ### run model
        with torch.no_grad():
            logits = self.model(**self.tokenized).logits
        
        ### save logits, decode predictions, and clean the result
        self.results = logits.argmax(-1).tolist()[0]
        encoded_input_ids = self.tokenized["input_ids"].tolist()[0]
        self.named_entity_found = [encoded_input_ids[i] for i,x in enumerate(self.results) if x == 1 or x == 2]
        self.results_final = self.tokenizer.decode(self.named_entity_found)
        self.results_final = re.sub(self.pattern, '', self.results_final)
        


In [4]:
question = "when was the treaty of versiallies signed in world war two?"

ner_model = BERT_NER()

In [5]:
ner_model.run_model(question)


In [6]:
ner_model.results_final

'when was the of versiallies signed in war two '

# Class for Querying Wikipedia

In [7]:
'''imports and their dependencies'''

# !pip install wikipedia 

import wikipedia

In [8]:
class Wiki():
    '''class to query wikipedia and save the contents'''
    
    def __init__(self):
        '''initializing dictionaries'''
        self.wiki_results = {}
        self.wiki_summaries = {}
    
    def query_wiki(self, output: str,  num_results: int):
        '''Parameters: NER output, n articles to retrieve'''
    
        # query wikipedia using the output of the NER model
        # get top n wiki results in the form of article names
        for idx, result in enumerate(wikipedia.search(output, num_results, suggestion = True)):
            self.wiki_results[idx] = result
        # use them to query wikipedia for content of summaries
        for idx, article in enumerate(self.wiki_results[0]):
            self.wiki_summaries[idx] = wikipedia.page(article).summary
        list(self.wiki_summaries)

In [9]:
wiki = Wiki()
wiki.query_wiki("the treaty of versailles world war two", 4)
print(wiki.wiki_results[0])
print(wiki.wiki_summaries)

['Treaty of Versailles', 'Article 231 of the Treaty of Versailles', 'World War I reparations', 'U.S.–German Peace Treaty (1921)']
{0: 'The Treaty of Versailles was a peace treaty signed on 28 June 1919. As the most important treaty of World War I, it ended the state of war between Germany and most of the Allied Powers. It was signed in the Palace of Versailles, exactly five years after the assassination of Archduke Franz Ferdinand, which led to the war. The other Central Powers on the German side signed separate treaties.  The United States never ratified the Versailles treaty and made a separate peace treaty with Germany.  Although the armistice of 11 November 1918 ended the actual fighting, it took six months of Allied negotiations at the Paris Peace Conference to conclude the peace treaty. Germany was not allowed to participate in the negotiations—it was forced to sign the final result.  \nThe most critical and controversial provision in the treaty was: "The Allied and Associated Go

# Class for QA model

In [10]:
'''imports and their dependencies'''

# !pip install torch
# !pip install transformers==3
# !pip install tqdm

import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering
from tqdm import tqdm

In [11]:
class BERT_QA():
    '''class for running QA model on the summaries'''
    
    def __init__(self):
        '''initialize models, and variables'''
        
        # models
        self.tokenizer = AutoTokenizer.from_pretrained("csarron/bert-base-uncased-squad-v1")
        self.model = AutoModelForQuestionAnswering.from_pretrained("csarron/bert-base-uncased-squad-v1")
        
        # encoded text data, model inputs, embeddings, text
        self.encoding = {}
        self.inputs = {}
        self.sentence_embedding = {}
        self.tokens = {}
        
        # index of QA model's answer
        self.start_scores = {}
        self.end_scores = {}
        self.start_index = {}
        self.end_index = {}
        self.outputs = {}
        
        # predicted answer's text
        self.init_answer = ""
        self.cleaned_answer = ""
        
        # preprocessing variables to save split text data
        self.split_total = []
        self.split_partial = []
        
    def get_split_tokens(self, wiki_summaries):
        '''splits text from wiki summaries to be right size for BERT model'''

        if len(wiki_summaries.split())//150 >0:
            n = len(wiki_summaries.split())//150
        else: 
            n = 1
        for w in range(n):
            if w == 0:
                self.split_partial = wiki_summaries.split()[:200]
                self.split_total.append(" ".join(self.split_partial))
            else:
                self.split_partial = wiki_summaries.split()[w*150:w*150 + 200]
                self.split_total.append(" ".join(self.split_partial))
        
    def pre_process(self, question, split_wiki_summaries):
        '''tokenizes and encodes question/context for BERT model'''
        
        self.encoding = self.tokenizer.encode_plus(text = question, text_pair = self.split_total[0])
        # token embeddings
        self.inputs = self.encoding['input_ids']
        # segment embeddings
        self.sentence_embedding = self.encoding['token_type_ids']
        # input tokens
        self.tokens = self.tokenizer.convert_ids_to_tokens(self.inputs)
        

    def predict(self):
        '''uses BERT model to predict the span of text with answer'''
        
        ################# NEED SOME KIND OF PADDING HERE FOR MULTIPLE CONTEXTS ###################################
                        #### all vectors need to be the same size
                        #### pad shorter contexts with the token id 0
                
        # run QA model to index predicted answer
        for batch in tqdm(self.tokens):
            self.outputs = self.model(input_ids=torch.tensor([self.inputs]), token_type_ids=torch.tensor([self.sentence_embedding]))
            self.start_index = torch.argmax(self.outputs.start_logits)
            self.end_index = torch.argmax(self.outputs.end_logits)

            # return predicted answer as string
            self.init_answer = ' '.join(self.tokens[self.start_index : self.end_index + 1])
            for word in self.init_answer.split():
                if word[0:2] == '##':
                    self.cleaned_answer += word[2:]
                else:
                    self.cleaned_answer += ' ' + word

In [12]:
qa = BERT_QA()
question = "when was the treaty of versiallies signed in world war two?"
qa.get_split_tokens(wiki.wiki_summaries[0])
qa.pre_process(question, qa.split_total[0])
qa.predict()

100%|██████████| 236/236 [02:50<00:00,  1.38it/s]


In [13]:
print(qa.init_answer)

28 june 1919


# Main Function/ Testing Results

In [26]:
def main():
    ner = BERT_NER()
    wiki = Wiki()
    qa = BERT_QA()
    
    question = input("Search Bar:")
    
    ner.run_model(question)
    wiki.query_wiki(ner.results_final, 2)
    qa.get_split_tokens(wiki.wiki_summaries[0])
    qa.pre_process(question, qa.split_total[0])
    qa.predict()
    
    print(qa.init_answer)

In [27]:
main()

Search Bar: what was the largest battle in the Napoleonic Wars?


100%|██████████| 301/301 [04:36<00:00,  1.09it/s]


the french invasion of russia ( 1812 ) . they were the most widespread and costly wars in european history before world war i
