In [1]:
from elasticsearch import Elasticsearch
import torch
from transformers import AutoTokenizer, AutoModelForQuestionAnswering

In [2]:
model_name = "deepset/roberta-base-squad2"
tokenizer = AutoTokenizer.from_pretrained(model_name, max_length = 512)
model = AutoModelForQuestionAnswering.from_pretrained(model_name)

In [10]:
class Reader():
    def __init__(self, index, host = 'localhost', port=9200):
        self.client = Elasticsearch(hosts=[{"host": host, "port": port}])
        self.index = index
        self.max_length = 0
        model_name = "deepset/roberta-base-squad2"
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, max_length=512)
        self.model = AutoModelForQuestionAnswering.from_pretrained(model_name)
#         self.model.to('cuda')
        
    def query(self, query, fields=['text'], filters=[], top_k=10, excluded_meta_data=[]):
        pass
    
        body = {
            "size": top_k,
            "query": {
                "bool": {
                    "must": {
                        "multi_match": {
                            "query": query, 
                            "type": "most_fields", 
                            "fields": fields
                        }
                    }
                }
            }
        } 
        
        if filters:
            filter_clause = []
            for key, values in filters.items():
                filter_clause.append(
                    {
                        "terms": {key: values}
                    }
                )
            body["query"]["bool"]["filter"] = filter_clause

        if excluded_meta_data:
            body["_source"] = {"excludes": excluded_meta_data}

        result = self.client.search(index=self.index, body=body)["hits"]["hits"]
        self.contexts = [r['_source']['text'] for r in result]
        return self.contexts
    
    def infer(self, query):
        self.query(query)
        
        self.outputs = []
        self.answers = []
        for context in self.contexts:
            inputs = self.tokenizer(query, context, return_tensors="pt", add_special_tokens=True)
            input_ids = inputs["input_ids"].tolist()[0]

            # Get longest question to test out saved model. We can pad shorter answers. Probably better than trimming longer ones
            l = len(input_ids)
            if l >= 512:
                continue
            elif l > self.max_length:
                self.longest = (inputs['input_ids'], inputs['attention_mask'])
                self.max_length = l

            # decoded_inputs = tokenizer.decode(inputs["input_ids"][0])
            # text_tokens = tokenizer.convert_ids_to_tokens(input_ids)
            outputs = model(**inputs)
            self.outputs.append(outputs)
            
            answer = {}
            answer_start_scores = outputs[0]
            answer_end_scores = outputs[1]
            answer_start = torch.argmax(answer_start_scores)  # Get the most likely beginning of answer with the argmax of the score
            answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score
            response = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
            
            answer['score'] = max([answer_start_scores[0][answer_start].item(), answer_end_scores[0][answer_end].item()])
            answer['answer'] = response
            self.answers.append(answer)
            
#             print(f"Question: {query}")
#             print(f"Answer: {answer}")
            
        self.answers.sort(key=lambda x: x['score'], reverse = True)   
        return self.answers
    
    
#     def parse(self, query, outputs=None):
#         if not outputs:
#             self.infer(query)
#             outputs = self.infer(query)
            
#         self.answers = []
#         for output in outputs:
#             answer = {}
#             answer_start_scores = output[0]
#             answer_end_scores = output[1]
#             answer_start = torch.argmax(
#                 answer_start_scores
#             )  # Get the most likely beginning of answer with the argmax of the score
#             answer_end = torch.argmax(answer_end_scores) + 1  # Get the most likely end of answer with the argmax of the score
#             response = tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(input_ids[answer_start:answer_end]))
#             answer['score'] = max([answer_start, answer_end])
#             answer['answer'] = response
            
#             print(f"Question: {query}")
#             print(f"Answer: {answer}")
#         self.answers.sort(key=lambda x: x['score'], reverse = True)
        
reader = Reader('ahrq')

In [13]:
query = "What is ahrq?"
reader.infer(query)

[{'score': 4.527154922485352,
  'answer': ' Agency for Healthcare Research and Quality'},
 {'score': 4.409052848815918, 'answer': ''},
 {'score': 4.211277961730957, 'answer': ''},
 {'score': 2.202496290206909, 'answer': '<s>'},
 {'score': 1.9319677352905273,
  'answer': ' a consolidated set of national standardized databases of reliable social factors that will build on existing databases developed by Federal agencies'},
 {'score': 1.4219703674316406,
  'answer': '<s>What is ahrq?</s></s>Dr. Embi noted the advancing capabilities of technologies, which can lead to better care management. Multidirectional communication is now possible. AHRQ could enable and encourage solutions that allow such communication. As such, AHRQ could advance innovation in patient-centered care'}]

In [14]:
reader.answers[0]['answer']

' Agency for Healthcare Research and Quality'