In [53]:
import json
import requests
from tqdm import tqdm_notebook as tqdm
from scipy.special import softmax

import time
import sys
sys.path.append("..")
from scripts import evaluate_answer

In [45]:
# HTTP headers for the prediction request: the body is sent as JSON and a JSON reply is requested.
headers = {
    'Content-Type': 'application/json',
    'accept': 'application/json',
}

In [6]:
def json_load(name):
    """Parse the UTF-8 encoded JSON file at path ``name`` and return the result."""
    with open(str(name), mode='r', encoding='utf-8') as handle:
        parsed = json.load(handle)
    return parsed
    
def json_save(name, item):
    """Write ``item`` to path ``name`` as UTF-8 JSON, indented by 4 spaces.

    Non-ASCII characters are written as-is rather than as ``\\uXXXX`` escapes.
    """
    with open(str(name), mode='w', encoding='utf-8') as sink:
        json.dump(item, fp=sink, indent=4, ensure_ascii=False)
        
def read_vanilla(name):
    """Read a JSON Lines file (one JSON document per line) into a list.

    Parameters
    ----------
    name : path of the file to read.

    Returns
    -------
    list
        The parsed JSON value of every non-blank line, in file order.

    Raises
    ------
    json.JSONDecodeError
        If a non-blank line is not valid JSON.
    """
    # Read as UTF-8 explicitly, matching json_load/json_save, instead of relying
    # on the platform's default encoding.
    with open(name, 'r', encoding='utf-8') as f:
        # Blank lines (e.g. a stray empty line at the end of the file) carry no
        # record and would make json.loads raise, so they are skipped.
        return [json.loads(line) for line in f if line.strip()]

In [7]:
# NOTE(review): test[i], responses[i], labels[i] and is_true[i] are indexed together
# by the later cells, so all four are assumed to share the same question order -- TODO confirm.

# VANILLA test split; read_vanilla parses the file line by line (one JSON object per line)
test = read_vanilla("../data/VANILLA/Extended_Dataset_Test.json")
# responses for the test dataset from QAnswer (the order of the SPARQL candidates is not changed)
responses = json_load("../processed_data/VANILLA/qanswer_test_responses_extended-0-1000.json") # first thousand (per the 0-1000 file name)
# labels obtained for each QAnswer response in the file above
labels = json_load("../processed_data/VANILLA/qanswer_test_responses_labels.json")
# per-candidate flags: whether a particular SPARQL candidate is correct (True) or not
is_true = json_load("../processed_data/VANILLA/is_true.json")

The **files above** were prepared as follows:

* `test` -- default test split of VANILLA
* `responses` -- first run `scripts/vanilla_qanswer.py`, then run `scripts/vanilla_run_sparql_candidates_on_wikidata.py`
* `labels` -- run `scripts/get_vanilla_labels_wikidata.py`
* `is_true` -- run `scripts/vanilla_get_right_answers.py`

In [35]:
def precision_at_k(data, k=1):
    """
    Mean precision@k: how many relevant items are present in the top-k
    recommendations of the system, averaged over all questions.

    Parameters
    ----------
    data : iterable of dict
        One entry per question. Each entry has a ``'response'`` list of
        candidates in ranked order, and every candidate has an ``'is_true'`` flag.
    k : int
        Number of top-ranked candidates to consider; must be positive.

    Returns
    -------
    float
        Average over questions of (number of true candidates among the first
        ``k``) / ``k``. A question with fewer than ``k`` candidates is still
        divided by ``k``. Returns 0.0 when ``data`` is empty.
    """
    assert k > 0
    # Slice the top-k candidates instead of walking the whole ranking.
    prec = [
        sum(1 for candidate in q['response'][:k] if candidate['is_true']) / k
        for q in data
    ]
    if not prec:
        # No questions at all: report 0.0 rather than dividing by zero.
        return 0.0

    return sum(prec) / len(prec)

In [36]:
# Inspect the first SPARQL candidate of the first question to see the record layout.
responses[0]['response'][0]

{'query': 'SELECT DISTINCT ?o1 WHERE { \t <http://www.wikidata.org/entity/Q8316084>  <http://www.wikidata.org/prop/direct/P21>  ?o1 .  }  LIMIT 1000',
 'confidence': 0.49,
 'result': [{'o1': {'type': 'uri',
    'value': 'http://www.wikidata.org/entity/Q6581097'}}],
 'is_true': True}

In [39]:
# unify data: copy the correctness flag from `is_true` onto every candidate of the
# first 250 questions (the candidate dicts inside `responses` are modified in place)
for q_idx, question in enumerate(responses[:250]):
    for c_idx, candidate in enumerate(question['response']):
        candidate['is_true'] = is_true[q_idx]['answer_list'][c_idx]

In [51]:
# Baseline before filtering: (precision@1, precision@5) on the first 250 questions.
tuple(precision_at_k(responses[:250], k) for k in (1, 5))

(0.404, 0.18160000000000004)

In [83]:
# filter answer candidates and create new dataset for evaluation
#
# For every question, each (question text, answer text) pair is sent to the
# prediction service. A candidate is dropped only when the service predicts
# "not valid" with a softmax confidence above REJECT_CONFIDENCE; all other
# candidates keep their ground-truth `is_true` flag in the filtered dataset.
N_QUESTIONS = 250          # number of leading questions to evaluate
PREDICT_URL = 'http://webengineering.ins.hs-anhalt.de:41003/predict'
REJECT_CONFIDENCE = 0.99   # a "not valid" prediction must exceed this to drop a candidate
REQUEST_TIMEOUT_S = 300    # fail instead of hanging forever if the service stalls

qanswer_results_filtered = list()

for i in tqdm(range(len(responses[:N_QUESTIONS]))): # iterate over questions
    question_text = test[i]['question']
    candidates = responses[i]['response']

    # START: here goes the code used to predict is answer valid or not
    batch = [
        [question_text, ' '.join(answer_labels)]
        for answer_labels in labels[i]['responses']
    ]
    if not batch:
        # Nothing to classify for this question: skip the HTTP round trip.
        qanswer_results_filtered.append({'response': []})
        continue

    data = json.dumps(batch, ensure_ascii=False)
    # NOTE(review): encoding with errors='ignore' silently drops every non-ASCII
    # character from the question and answer texts before they reach the model.
    # Kept as-is so the model input is unchanged -- confirm this is intended.
    data = data.encode('ascii', 'ignore').strip()

    json_response = requests.post(PREDICT_URL,
                                  data=data,
                                  headers=headers,
                                  timeout=REQUEST_TIMEOUT_S)
    json_response.raise_for_status()  # surface HTTP errors instead of a confusing parse/KeyError

    payload = json_response.json()    # parse the body once
    preds = payload['predictions']    # get predictions for the q-a tuples set
    outputs = payload['outputs']
    # END: here goes the code used to predict is answer valid or not

    # Predictions are matched to candidates by position; a length mismatch would
    # silently attach the wrong `is_true` flag, so fail loudly instead.
    if not (len(preds) == len(outputs) == len(candidates)):
        raise ValueError(
            f'question {i}: got {len(preds)} predictions and {len(outputs)} outputs '
            f'for {len(candidates)} candidates'
        )

    answers = list()
    for candidate, pred, output in zip(candidates, preds, outputs): # iterate over predictions for each candidate
        # if model says NOT TRUE we check how confident it is
        rejected = (not pred) and max(softmax(output)) > REJECT_CONFIDENCE
        if not rejected:
            answers.append({'is_true': candidate['is_true']})
    qanswer_results_filtered.append({'response': answers})

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  after removing the cwd from sys.path.


HBox(children=(FloatProgress(value=0.0, max=250.0), HTML(value='')))




In [85]:
# (precision@1, precision@5) on the filtered candidates.
# NOTE(review): in the recorded run these scores are far below the unfiltered ones
# (precision@1 of 0.064 vs 0.404) -- check the filter and the label/candidate
# alignment before trusting them.
tuple(precision_at_k(qanswer_results_filtered, k) for k in (1, 5))

(0.064, 0.020800000000000006)