In [1]:
import os
import re
import json
import aiohttp
import asyncio
import pandas as pd
from tqdm import tqdm
from aiohttp import ClientTimeout
from SPARQLWrapper import SPARQLWrapper, JSON

In [3]:
# Shared by every execute_sparql call. A semaphore created inside the function
# would be private to that call and therefore could never limit concurrency.
_EXECUTE_SPARQL_SEM = asyncio.Semaphore(20)


async def execute_sparql(session, query, timeout=30, max_retries=3):
    """POST a SPARQL query to the Wikidata endpoint and return the extracted answers.

    Args:
        session: object whose ``post(...)`` returns an async context manager
            yielding a response with ``status`` and ``json()`` (aiohttp style).
        query: SPARQL query text. A falsy value short-circuits to ``None``.
        timeout: total per-request timeout in seconds.
        max_retries: maximum number of attempts.

    Returns:
        ``None`` for a falsy query or an HTTP 400 response (malformed query),
        the result of ``extract_answers_from_response`` on HTTP 200,
        and ``[]`` when no attempt succeeded.
    """
    if not query:
        return None

    url = "https://query.wikidata.org/sparql"
    headers = {"Accept": "application/sparql-results+json"}
    data = {"query": query, "format": "json"}

    async with _EXECUTE_SPARQL_SEM:  # Limit concurrent requests across all calls
        for attempt in range(1, max_retries + 1):
            try:
                async with session.post(url, data=data, headers=headers, timeout=ClientTimeout(total=timeout)) as response:
                    if response.status == 200:
                        results = await response.json()
                        return extract_answers_from_response(results)
                    if response.status == 400:  # Query malformed; retrying cannot help
                        return None
                    # Any other status falls through to the back-off below.
            except (aiohttp.ClientError, asyncio.TimeoutError):
                # Transport failure or timeout: retry unless attempts are exhausted.
                pass

            # Pause between attempts for every retryable failure, including
            # unexpected HTTP statuses, but not after the final attempt.
            if attempt < max_retries:
                await asyncio.sleep(1)
        return []

def extract_answers_from_response(response):
    """Flatten a SPARQL JSON response into a flat list of answer values.

    For a response with ``results``, every bound variable of every binding
    contributes one entry: the bare ``Q…`` ID when the value is a Wikidata
    entity URI, otherwise the raw ``value`` (``None`` when the binding has no
    ``value`` key). For a response with ``boolean``, the list holds that single
    boolean. Anything else yields an empty list.

    NOTE: a later cell defines a function with the same name; once that cell
    has run, it replaces this definition.
    """
    # The capture group extracts the ID with the same pattern that recognises the URI.
    entity_uri = re.compile(r"^https?://www\.wikidata\.org/entity/(Q\d+)$")

    answers = []
    if 'results' in response:
        # Tolerate a "results" object that is empty or lacks "bindings".
        bindings = (response['results'] or {}).get('bindings') or []
        for binding in bindings:
            for sub_answer in binding.values():
                value = sub_answer.get('value')
                # Only strings can be matched; re.match would raise on None.
                found = entity_uri.match(value) if isinstance(value, str) else None
                answers.append(found.group(1) if found else value)
    elif 'boolean' in response:
        answers.append(response['boolean'])
    return answers

def extract_wikidata_id_from_link(link):
    """Return the ``Q…`` identifier found in a Wikidata entity URL, or ``None`` if there is none."""
    found = re.search(r"https?://www\.wikidata\.org/entity/(Q\d+)", link)
    if found is None:
        return None
    return found.group(1)

In [4]:
async def get_entity_relations(entity_id: str):
    """Look up which Wikidata properties connect to an entity, split by direction.

    Args:
        entity_id: Wikidata identifier that is interpolated after the ``wd:``
            prefix in the query (e.g. ``"Q42"``).

    Returns:
        A pair ``(in_relations, out_relations)`` of sorted lists of property
        IDs matching ``P<digits>``. "in" means the entity is the object of the
        triple, "out" means it is the subject. Both lists are empty when the
        ID is not well-formed or the lookup fails for any reason.
    """
    # Only interpolate values shaped like a Wikidata ID. Anything else could
    # not produce a valid query, so skip the network round-trip and avoid
    # splicing arbitrary text into the SPARQL string.
    if not isinstance(entity_id, str) or not re.fullmatch(r"[A-Z]\d+(?:-[A-Z]\d+)?", entity_id):
        return [], []

    SPARQL_ENDPOINT = "https://query.wikidata.org/sparql"
    # DISTINCT: only the set of (relation, direction) pairs is used below, so
    # there is no reason to transfer one row per matching triple.
    query = f"""
    SELECT DISTINCT ?relation ?direction WHERE {{
      {{ ?subject ?relation wd:{entity_id} . BIND("in" AS ?direction) }}
      UNION
      {{ wd:{entity_id} ?relation ?object . BIND("out" AS ?direction) }}
    }}
    """
    headers = {"Accept": "application/sparql-results+json"}

    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(SPARQL_ENDPOINT, params={"query": query, "format": "json"}, headers=headers) as resp:
                data = await resp.json()
    except Exception:
        # Best-effort lookup: any request or decoding failure yields "no relations".
        # Catching Exception rather than using a bare except lets task
        # cancellation and KeyboardInterrupt propagate.
        return [], []

    in_relations, out_relations = set(), set()
    for item in data.get("results", {}).get("bindings", []):
        # Keep the last path segment of the predicate URI and accept only plain property IDs.
        relation = item["relation"]["value"].split("/")[-1]
        if re.match(r"^P\d+$", relation):
            (in_relations if item["direction"]["value"] == "in" else out_relations).add(relation)

    # Sorted so repeated runs return the relations in a stable order.
    return sorted(in_relations), sorted(out_relations)

async def check_entity_mismatch(entities, relations):
    """Report whether none of the entities is connected to any of the given relations.

    Entities are looked up one after another via ``get_entity_relations``. The
    search stops at the first entity whose incoming or outgoing relations
    intersect ``relations``, in which case the result is ``False``. If no
    entity qualifies (including when ``entities`` is empty) the result is ``True``.
    """
    for candidate in entities:
        lookup = await get_entity_relations(candidate)
        if not lookup:
            continue
        connected = set(lookup[0] + lookup[1])
        if connected & relations:
            # At least one entity shares a relation, so there is no mismatch.
            return False
    return True

In [5]:
def extract_answers_from_response(response):
    """Flatten a SPARQL JSON response into a flat list of answer values.

    - Response with ``results``: one entry per bound variable per binding. The
      entry is the last path segment of the value when it is a Wikidata entity
      URI (``…/entity/Q<digits>``), otherwise the raw ``value`` (``None`` when
      the key is absent).
    - Response with ``boolean``: a one-element list holding that boolean.
    - Anything else, including a header-only response: an empty list.

    NOTE: this definition shadows the function of the same name defined in an
    earlier cell.
    """
    answers = []

    if 'results' in response:
        # Tolerate a "results" object that is empty or lacks "bindings".
        bindings = (response['results'] or {}).get('bindings') or []
        for binding in bindings:
            for sub_answer in binding.values():
                value = sub_answer.get('value')
                if isinstance(value, str) and re.match(r"^https?://www\.wikidata\.org/entity/Q\d+$", value):
                    answers.append(value.split("/")[-1])
                else:
                    answers.append(value)
    elif 'boolean' in response:
        answers.append(response['boolean'])
    # No separate branch is needed for a header-only response: ``answers`` is
    # still empty at this point, which is exactly what such a response yields.

    return answers
    
# Module-level so that every execute_sparql_query call shares the same cap on
# simultaneous requests.
SEM = asyncio.Semaphore(20)


async def execute_sparql_query(query, session):
    """Run a SPARQL query with a GET request to the Wikidata endpoint.

    Args:
        query: SPARQL query text. A falsy value (``None`` or ``""``) is not sent.
        session: object whose ``get(...)`` returns an async context manager
            yielding a response with ``status`` and ``json()`` (aiohttp style).

    Returns:
        The list produced by ``extract_answers_from_response`` on HTTP 200, or
        ``None`` when there is no query, the status is not 200, or the request
        or decoding raises. Failures are always reported as ``None`` so callers
        can tell "failed" apart from "answered with an empty list" with a
        single check and never mistake an error marker for answers.
    """
    if not query:
        return None

    endpoint = "https://query.wikidata.org/sparql"
    params = {"query": query, "format": "json"}

    async with SEM:
        try:
            async with session.get(endpoint, params=params) as response:
                if response.status != 200:
                    return None

                data = await response.json()
                return extract_answers_from_response(data)

        except Exception:
            # Same signal as a non-200 response: the query produced no usable result.
            return None

def calculate_metrics(correct, predicted):
    """Score predicted answers against gold answers as sets.

    Duplicates are ignored because both inputs are converted to sets.

    Returns:
        A dict with ``em`` (the two sets are equal), ``f1``, ``precision`` and
        ``recall``. Precision is 0 when nothing was predicted, recall is 0 when
        there are no gold answers, and F1 is 0 when both are 0.
    """
    gold, guessed = set(correct), set(predicted)
    hits = len(gold & guessed)

    precision = 0
    if guessed:
        precision = hits / len(guessed)

    recall = 0
    if gold:
        recall = hits / len(gold)

    if (precision + recall) > 0:
        f1_score = (2 * precision * recall) / (precision + recall)
    else:
        f1_score = 0

    return dict(em=(gold == guessed), f1=f1_score, precision=precision, recall=recall)

In [6]:
def extract_sparql(text: str) -> str:
    """Pull the query out of a Markdown-style fenced code block.

    Resolution order:
      1. The first block delimited by an opening and a closing triple-backtick
         fence. The rest of the opening line (e.g. a language tag) is dropped.
      2. If there is an opening fence but no closing one (for instance a
         generation that was cut off), everything after the opening line.
      3. Otherwise the whole text.

    The result is always stripped of surrounding whitespace.
    """
    closed = re.search(r"```(?:[^\n]*\n)?(.*?)```", text, re.DOTALL)
    if closed:
        return closed.group(1).strip()

    # Without this fallback an unterminated block would be returned with its
    # backticks and language tag still attached, which is never a valid query.
    unclosed = re.search(r"```(?:[^\n]*\n)?(.*)", text, re.DOTALL)
    if unclosed:
        return unclosed.group(1).strip()

    return text.strip()

def get_prediction(id, question, prediction_path):
    """Return the predicted query stored for a given dataset item.

    Args:
        id: item identifier, compared with each record's ``'id'``.
        question: question text, compared with each record's ``'question'``.
        prediction_path: path to a JSON file holding a list of records with
            ``'id'``, ``'question'`` and ``'predicted_query'`` keys.

    Returns:
        ``extract_sparql`` applied to the first matching record's
        ``'predicted_query'``, or ``None`` when no record matches both fields.
    """
    # Context manager so the file handle is closed deterministically.
    with open(prediction_path, encoding="utf-8") as prediction_file:
        predictions = json.load(prediction_file)

    for item in predictions:
        if item['id'] == id and item['question'] == question:
            return extract_sparql(item['predicted_query'])

    return None

# Validation

In [7]:
# Input dataset with the questions, gold queries, entities and relations.
dataset_path = (
    'data/e2e_validation_datasets/'
    'lcquad_input_dataset.json'
)

# Model predictions to evaluate.
# NOTE(review): this is an absolute, machine-specific path, while the other two
# paths are relative. Presumably it should be relative to the project root as
# well; confirm before sharing.
prediction_path = (
    '/Users/amalekseev/Desktop/text2sparql/'
    'data/inference_results/qwen2.5-coder-0.5b/'
    'pat_qald_lcquad_2.0_e2e_predictions_gold.json'
)

# Destination of the per-sample metrics table.
# NOTE(review): the file name says "rubq" although the dataset above is the
# lcquad one. Confirm this is the intended output file.
save_path = (
    'data/results/'
    'rubq_metrics.csv'
)

In [8]:
dataset = json.load(open(dataset_path))

table = []
async with aiohttp.ClientSession() as session:
    for item in tqdm(dataset):
        # Check for entity mismatch
        entities = set(item['entities'].keys())
        relations = set(item['relations'].keys())
        entity_mismatch = await check_entity_mismatch(entities, relations)
    
        id = item['id']
        question = item['question']
        gold_query = item['query']
        
        prediction = get_prediction(id, question, prediction_path)
    
        if 'multi-hop' in prediction:
            fisrt_hop, second_hop = prediction.split('<|sep|>')
            fisrt_hop = extract_sparql(fisrt_hop)
            second_hop = extract_sparql(second_hop)
        
            one_hop_entities = await execute_sparql_query(fisrt_hop, session)
            if one_hop_entities is None:
                sparql_query = None
            else:
                for entity in one_hop_entities:
                    if "<|mask|>" in second_hop:
                        second_hop = second_hop.replace("<|mask|>", f'wd:{entity}', 1)
    
                sparql_query = second_hop
        else:
            sparql_query = extract_sparql(prediction)
        
        result = await execute_sparql_query(sparql_query, session)
        sparql_error = result is None
        sparql_empty = len(result) == 0 if result is not None else False
    
        gold_entities = await execute_sparql_query(gold_query, session)
    
        if not gold_entities:
            continue  
    
        if result is not None:
            metric = calculate_metrics(gold_entities, result)
        else:
            metric = {'em': False, 'f1': 0, 'precision': 0.0, 'recall': 0}
        
        sample = {
            'id': id,
            'entity_mismatch': entity_mismatch,
            'sparql_error': sparql_error,
            'sparql_empty': sparql_empty,
        }
        sample.update(metric)
        table.append(sample)

metrics_df = pd.DataFrame(table)
metrics_df['system_rejection'] = (metrics_df['entity_mismatch'] | metrics_df['entity_mismatch'] | metrics_df['sparql_empty'])
metrics_df.to_csv(save_path)

print('Average F1 score:', metrics_df.f1.mean())

100%|█████████████████████████████████████████| 480/480 [15:57<00:00,  1.99s/it]

Average F1 score: 0.5038657913931437





In [18]:
# Display the per-sample metrics table.
# NOTE(review): the rendered output below has an "Unnamed: 0" column that the
# frame built in the evaluation cell does not create, and this cell's execution
# count is not sequential. Presumably the frame shown was re-read from the
# saved CSV in a cell that is no longer present; confirm and re-run from a
# fresh kernel.
metrics_df

Unnamed: 0,id,entity_mismatch,sparql_error,sparql_empty,em,f1,precision,recall,system_rejection
0,4,False,False,False,True,1.0,1.0,1.0,False
1,7,False,False,True,False,0.0,0.0,0.0,True
2,14,False,False,False,True,1.0,1.0,1.0,False
3,22,False,False,False,True,1.0,1.0,1.0,False
4,25,False,False,True,False,0.0,0.0,0.0,True
...,...,...,...,...,...,...,...,...,...
452,7179,False,False,False,True,1.0,1.0,1.0,False
453,7180,False,False,False,True,1.0,1.0,1.0,False
454,7189,False,False,False,False,0.0,0.0,0.0,False
455,7190,False,False,False,False,0.0,0.0,0.0,False
