# Question Answering Engine

## 04: Question Answering Modules

For the modules to work, the cell below must be run first to download the necessary files and initialise the span entity and relation prediction models. Some notes on the implementation:


- I added colour in the messages and progress bars to make the interface a bit more appealing. 
- The preprocessing of the question is similar to the one I used for the data ingestion, for example removing accents (Adèle becomes adele).
- For the models, the forward method is used to get the logits, and the entries with the highest probabilities are taken as the predictions.
- For the SPARQL builder and executor modules I used the suggestions provided by the SPARQL documentation.

In [9]:
# ANSI colour codes
# Each constant is an escape sequence that switches the text colour; it is
# prepended to a message, and RESET is appended so the colour does not carry
# over into whatever is printed next.
RED = '\033[91m'
GREEN = '\033[92m'
YELLOW = '\033[93m'
BLUE = '\033[94m'
MAGENTA = '\033[95m'
CYAN = '\033[96m'
RESET = '\033[0m'

# Message to the user, printed before the imports and setup that follow
print(GREEN + "Please wait..." + RESET)

# Importing libraries
import sys
from SPARQLWrapper import SPARQLWrapper, JSON
import torch
import torch.nn as nn
import pandas as pd
from unidecode import unidecode
import difflib
from tqdm import tqdm
import spacy
nlp = spacy.load('en_core_web_lg')
from transformers import BertTokenizer, BertModel, logging
logging.set_verbosity_error()

# Initialise variables and GPU
# Prefer the MPS backend, then CUDA, and fall back to the CPU otherwise
if torch.backends.mps.is_available():
    device = 'mps'
elif torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
BERT_MODEL = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(BERT_MODEL, do_lower_case=True)

# Firstly read the dictionary I created
df = pd.read_csv("dataset/entity_dict.csv", sep=',')
Entities = df['Entity']
Entity_ids = df['Id']

# Remove accents from the 'Entity' column
df['Entity'] = df['Entity'].apply(unidecode)

# Create a dictionary with entities as keys and entity_ids as values
# (if an entity name occurs more than once, the last id wins)
entity_to_id = dict(zip(df['Entity'], df['Id']))
entity_list = list(entity_to_id)

# This dictionary for similar words for extreme cases
entity_docs = {}

# Create a dictionary with docs as keys and entity_ids as values.
# nlp.pipe processes the names in batches and yields the docs in input order,
# so they can be zipped back with the ids. The loop variables are named so
# that they do not shadow the builtin `id` at module level.
for entity_doc, entity_id in tqdm(
    zip(nlp.pipe(entity_list), entity_to_id.values()),
    total=len(entity_to_id),
    desc='Building Knowledge Base',
):
    # Docs without a vector cannot be compared by similarity, so skip them
    if entity_doc.has_vector:
        entity_docs[entity_doc] = entity_id

# Then read the relations
df = pd.read_csv("dataset/relation_vocab.csv")

# Create a list with the relation vocabulary
relation_vocab = df['Relation'].to_list()

# The dataframe is no longer needed once the lookups are built
del df

class BERT_SPAN(torch.nn.Module):
    """BERT encoder with two heads that score the start and end of a span.

    Both heads share the same layout: dropout, a linear projection of each
    hidden state to a single value, a flatten, and a softmax over dimension 1.
    """

    def __init__(self, bert_model, vocab_size):
        """Load the pretrained encoder and build the start/end heads.

        Args:
            bert_model: Name or path handed to ``BertModel.from_pretrained``.
            vocab_size: Stored on the instance; the heads do not use it.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)

        hidden_size = self.bert.config.hidden_size
        self.start_head = self._make_head(hidden_size)
        self.end_head = self._make_head(hidden_size)

        self.vocab_size = vocab_size

    @staticmethod
    def _make_head(hidden_size):
        """Build one span head (dropout -> linear -> flatten -> softmax)."""
        return nn.Sequential(
            nn.Dropout(p=0.15),
            nn.Linear(hidden_size, 1),
            nn.Flatten(),
            nn.Softmax(dim=1),
        )

    def forward(self, input_ids, attention_mask):
        """Return the masked start and end scores for the given batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder and multiplied into
                both outputs, which zeroes the scores wherever the mask is 0.

        Returns:
            A ``(start, end)`` tuple of score tensors.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        # First encoder output; presumably the per-token hidden states of
        # shape (batch, seq_len, hidden) - TODO confirm
        hidden_states = encoded[0]

        start_scores = self.start_head(hidden_states)
        end_scores = self.end_head(hidden_states)

        return start_scores * attention_mask, end_scores * attention_mask

class BERT_REL(torch.nn.Module):
    """BERT encoder with a classification head over the relation vocabulary."""

    def __init__(self, bert_model, vocab_size):
        """Load the pretrained encoder and build the relation head.

        Args:
            bert_model: Name or path handed to ``BertModel.from_pretrained``.
            vocab_size: Number of output classes of the linear layer.
        """
        super().__init__()
        self.bert = BertModel.from_pretrained(bert_model)

        # dropout -> linear projection to one score per class -> softmax
        head_layers = [
            nn.Dropout(0.15),
            nn.Linear(self.bert.config.hidden_size, vocab_size),
            nn.Softmax(dim=1),
        ]
        self.relation_head = nn.Sequential(*head_layers)
        self.vocab_size = vocab_size

    def forward(self, input_ids, attention_mask):
        """Return the relation head's output for the given batch.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Mask forwarded to the encoder.
        """
        encoded = self.bert(input_ids, attention_mask=attention_mask)
        # Second encoder output; presumably BERT's pooled sentence
        # representation - TODO confirm
        pooled = encoded[1]
        return self.relation_head(pooled)

def preprocess(question):
    """Clean a question and encode it as a batch of one for the models.

    Question marks and "'s" are removed and accents are stripped before the
    text is tokenised without special tokens.

    Args:
        question: The raw question string.

    Returns:
        A ``(input_ids, attention_mask)`` tuple of tensors, each holding a
        single row.
    """
    cleaned = question.replace("?", "").replace("'s", "")
    cleaned = unidecode(cleaned)

    # Encode the question
    encoding = tokenizer.encode_plus(
        text=cleaned,
        return_attention_mask=True,
        add_special_tokens=False,
    )

    # Wrapping each sequence in a list adds the batch dimension
    input_ids = torch.tensor([encoding['input_ids']])
    attention_mask = torch.tensor([encoding['attention_mask']])
    return input_ids, attention_mask

def relation_prediction(model, ids, mask, relation_vocab):
    """Predict the relation of a single encoded question.

    Args:
        model: Callable taking ``input_ids`` and ``attention_mask`` keyword
            arguments and returning one row of scores per question.
        ids: Token ids of the question (a batch of one).
        mask: Attention mask matching ``ids``.
        relation_vocab: Sequence of relation names, indexed by class.

    Returns:
        The entry of ``relation_vocab`` with the highest score.
    """
    # Inference only, so there is no need to track gradients
    with torch.no_grad():
        relation_scores = model(input_ids=ids, attention_mask=mask)

    best_class = torch.argmax(relation_scores, dim=1)

    # .item() gives a plain int, so the vocabulary is not indexed by a tensor
    return relation_vocab[best_class.item()]

# In case the entity is slightly off by a word or summat
def find_closest_match(entity, entity_docs, entity_to_id, entity_list):
    """Return the id of the dictionary entity closest to ``entity``.

    Vector similarity is tried first. If the entity has no vector, or there
    are no candidate docs, the function falls back to fuzzy string matching.

    Args:
        entity: The predicted entity text.
        entity_docs: Mapping of spaCy docs to entity ids.
        entity_to_id: Mapping of entity names to entity ids.
        entity_list: The entity names, searched by the fuzzy fallback.

    Returns:
        The first whitespace-separated token of the matched id, or ``None``
        when nothing matches.
    """
    doc = nlp(entity)

    # Only compare vectors when there is something to compare against
    if doc.has_vector and entity_docs:
        # Find similarity scores between entity and each entity in dictionary
        scores = {}
        for candidate_doc, candidate_id in entity_docs.items():
            scores[candidate_id] = doc.similarity(candidate_doc)

        # On ties max() returns the first id in insertion order
        best_id = max(scores, key=scores.get)
        return best_id.split()[0]

    # Else simply find the best match from the rest
    best_match = difflib.get_close_matches(entity, entity_list, n=1, cutoff=0.12)
    if best_match:
        # Normalise the id the same way as the vector branch; indexing with
        # [0] here would return only the first character of the id string.
        # NOTE(review): assumes ids are strings, possibly space-separated -
        # confirm against entity_dict.csv
        return entity_to_id[best_match[0]].split()[0]
    return None

# SPARQL endpoint of the Wikidata Query Service, used for entity id look-ups
endpoint_url = "https://query.wikidata.org/sparql"

def entity_prediction(model, ids, mask, entity_docs, entity_to_id, entity_list):
    """Predict the entity span of a question and resolve it to an entity id.

    The id is looked up in three steps: an exact match in ``entity_to_id``,
    then a Wikidata label query, then the closest dictionary entry.

    Args:
        model: Span model returning start and end scores for each token.
        ids: Token ids of the question (a batch of one).
        mask: Attention mask matching ``ids``.
        entity_docs: Mapping of spaCy docs to entity ids.
        entity_to_id: Mapping of entity names to entity ids.
        entity_list: The entity names, used by the fuzzy fallback.

    Returns:
        The resolved entity id, or ``None`` if every step fails.
    """
    # Inference only, so there is no need to track gradients
    with torch.no_grad():
        start_scores, end_scores = model(input_ids=ids, attention_mask=mask)

    # Find the start index with the highest probability
    start_index = torch.argmax(start_scores, dim=1).item()

    # Zero out every end position before the start so the span cannot end
    # before it begins; the mask is created on the same device as the scores
    allowed_ends = torch.zeros(mask.shape[1], device=end_scores.device)
    allowed_ends[start_index:] = 1

    # Find the end index with the highest probability
    end_index = torch.argmax(end_scores * allowed_ends, dim=1).item()

    # Get the entity from the tokens (the end index is inclusive)
    entity = tokenizer.decode(ids[0][start_index:end_index + 1])

    # Initially check if it's in its current form (without accents etc)
    if entity in entity_to_id:
        return entity_to_id[entity]

    # Otherwise ask Wikidata for items carrying this label
    wikidata_ids = get_entityid(entity)
    if wikidata_ids:
        return wikidata_ids[0]

    # Fallback to the most similar from the dictionary
    return find_closest_match(entity, entity_docs, entity_to_id, entity_list)

def get_results(endpoint_url, query):
    """Run a SPARQL query against an endpoint and return the parsed JSON.

    Args:
        endpoint_url: URL of the SPARQL endpoint.
        query: The SPARQL query text.

    Returns:
        The converted JSON response of the endpoint.
    """
    # Identify the client with the running Python major.minor version
    major, minor = sys.version_info[0], sys.version_info[1]
    user_agent = f"WDQS-example Python/{major}.{minor}"

    client = SPARQLWrapper(endpoint_url, agent=user_agent)
    client.setReturnFormat(JSON)
    client.setQuery(query)
    return client.query().convert()

def get_entityid(label):
    """Look up the Wikidata ids of items that carry the given English label.

    The label is title-cased before the look-up, and only items that have an
    English Wikipedia article are returned.

    Args:
        label: The entity text to search for.

    Returns:
        A list of entity ids (the last path segment of each item URI); empty
        when nothing matches.
    """
    # The label comes from the user's question, so escape the characters that
    # would otherwise end or corrupt the quoted SPARQL string literal.
    # Escaping happens after title() so the escape sequences are left intact.
    safe_label = (
        label.title()
        .replace("\\", "\\\\")
        .replace('"', '\\"')
        .replace("\n", "\\n")
        .replace("\r", "\\r")
    )

    # NOTE(review): ?label is a variable in predicate position, so any
    # property whose value is this literal matches - confirm this is intended
    query = """SELECT distinct ?item ?itemLabel ?itemDescription WHERE{
    ?item ?label "%s"@en.
    ?article schema:about ?item .
    ?article schema:inLanguage "en" .
    ?article schema:isPartOf <https://en.wikipedia.org/>.
    SERVICE wikibase:label { bd:serviceParam wikibase:language "en". }
    }""" % safe_label

    results = get_results(endpoint_url, query)

    # The item value is a URI; its last segment is the entity id
    return [
        binding["item"]["value"].split(sep='/')[-1]
        for binding in results["results"]["bindings"]
    ]

def query_builder(entityid, relation):
    """Build the SPARQL query that answers a (entity, relation) question.

    A relation starting with ``R`` is treated as the inverse of the property
    with the same number: the entity becomes the object of the triple and the
    leading ``R`` is replaced by ``P``.

    Args:
        entityid: Wikidata entity id, e.g. ``Q42``.
        relation: Property id, either direct (``P...``) or inverse (``R...``).

    Returns:
        The SPARQL query string selecting ``?item`` and ``?itemLabel``.
    """
    # Check whether the question relation is inverse
    if relation.startswith('R'):
        # Swap only the leading marker, not every 'R' in the string
        triple = "?item wdt:%s wd:%s." % ('P' + relation[1:], entityid)
    else:
        triple = "wd:%s wdt:%s ?item." % (entityid, relation)

    # Both directions share the same query; only the triple pattern differs
    query = """
        SELECT ?item ?itemLabel
        WHERE
        {
        %s
        SERVICE wikibase:label { bd:serviceParam wikibase:language "[AUTO_LANGUAGE],en". }
        }""" % triple

    return query

def query_executor(query):
    """Run a query on the Wikidata endpoint and return its result rows.

    Args:
        query: The SPARQL query text.

    Returns:
        The list of bindings from the JSON response.
    """
    # Reuse the shared helper rather than repeating the SPARQLWrapper set-up
    results = get_results(endpoint_url, query)
    return results["results"]["bindings"]


def answer(question):
    """Answer a natural-language question and print up to five results.

    The question is echoed, encoded, and passed to the relation and entity
    models; the predictions are turned into a SPARQL query whose result
    labels are printed. An apology is printed when nothing is found.

    Args:
        question: The question text.
    """
    print(CYAN + question + RESET)
    ids, mask = preprocess(question)

    # An empty question (or one made only of removed characters) has no
    # tokens, and the models cannot run on an empty sequence
    if ids.numel() == 0:
        print(MAGENTA + "Nothing found, I'm sorry..." + RESET)
        return

    ids = ids.to(device)
    mask = mask.to(device)
    relation = relation_prediction(relation_model, ids, mask, relation_vocab)
    entity = entity_prediction(entity_model, ids, mask, entity_docs, entity_to_id, entity_list)
    if entity is None:
        print(MAGENTA + "Nothing found, I'm sorry..." + RESET)
        return

    query = query_builder(entityid=entity, relation=relation)
    results = query_executor(query)
    if not results:
        print(MAGENTA + "Nothing found, I'm sorry..." + RESET)
        return

    # To only print top 5 results
    for result in results[:5]:
        # Print the answer
        print(YELLOW + result["itemLabel"]["value"] + RESET)

models = []

# Load both models, advancing the progress bar after each one.
with tqdm(total=2, desc='Loading Entity & Relation Model') as progress:
    entity_model = BERT_SPAN(bert_model=BERT_MODEL, vocab_size=len(relation_vocab)).to(device)
    # map_location lets a checkpoint saved on another device be loaded here
    entity_model.load_state_dict(torch.load('./best_span_model.pt', map_location=device))
    # Modules start in training mode; eval() turns dropout off for inference
    entity_model.eval()
    progress.update(1)

    relation_model = BERT_REL(bert_model=BERT_MODEL, vocab_size=len(relation_vocab)).to(device)
    relation_model.load_state_dict(torch.load('./best_relation_model.pt', map_location=device))
    relation_model.eval()
    progress.update(1)

print(GREEN + "Everything is up and running!" + RESET)

[92mPlease wait...[0m


Building Knowledge Base: 100%|██████████| 24103/24103 [01:22<00:00, 292.17it/s]
Loading Entity & Relation Model: 100%|██████████| 2/2 [00:04<00:00,  2.25s/it]

[92mEverything is up and running![0m





### Examples

When asking questions from the train or validation set the model performed brilliantly. On the test set it struggled with only a few questions and overall had no problems, as in the example below, which is taken from the test set. This question wasn't used for training, so it should give an accurate indication that the model works correctly on unseen data.

In [3]:
question = "what position does jose francisco torres play"
answer(question)

[96mwhat position does jose francisco torres play[0m
[93mmidfielder[0m


The model accurately predicts the entity and the relation, and the answer is correct. Below I try a question that isn't in the test set, obtained by modifying the context of the question above, to see whether the model can identify that the relation is different.

In [6]:
question = "which national team does francisco torres play for?"
answer(question)

[96mwhich national team does francisco torres play for?[0m
[93mUnited States of America[0m


The relation model correctly identifies that it's a different relation and answers the question correctly. The interesting thing is that even after altering the name by removing the first name (i.e. Jose) of the footballer, the model still predicts really well and identifies the correct entity.

In [10]:
question = "What does vettel do?"
answer(question)

[96mWhat does vettel do?[0m
[93mracing automobile driver[0m
[93mFormula One driver[0m
[93mmotorsports competitor[0m
[93minternational forum participant[0m


In [6]:
question = "Where is max verstappen from?"
answer(question)

[96mWhere is Max Verstappen from?[0m
[93mBelgium[0m
[93mKingdom of the Netherlands[0m


In the two examples above the entities Vettel and Verstappen aren't present in the train set, and consequently not in the entity dictionary either, but the model is still able to predict the correct entity span and provide the right answer to the question. This is achieved by locating the label in the question and using it to look up the relevant entity id with a query.

This makes the model able to answer questions on entities that weren't used for training. Below you can experiment by entering a question and running the cell to get the answer.

In [None]:
question = ""
answer(question)