In [2]:
import os
import json
from utils.parse_sparql import *

# Module-level SPARQL parser shared by generate_train_data, which calls
# parser.parse_sparql(...) to turn each answer's SPARQL query into the
# "s_expression" field of a training example.
parser = SparqlParser()

def isdirect(question):
    """Decide whether a SPICE question turn is treated as "direct".

    generate_train_data prepends earlier dialogue turns as context only to
    questions for which this returns False.

    * Question types starting with "Simple Question" are judged by type
      alone: direct only when the six characters after
      "Simple Question (" spell "Direct".
    * Any other question is direct unless its optional "description"
      contains "ndirect" (matches Indirect/indirect) or "ncomplete"
      (matches Incomplete/incomplete).

    Args:
        question: A question-turn dict with at least a "question-type" key.

    Returns:
        bool
    """
    question_type = question["question-type"]
    if question_type.startswith("Simple Question"):
        # "Simple Question (" is 17 characters long, so [17:23] are the first
        # six characters inside the parentheses.
        return question_type[17:23] == "Direct"
    if "description" not in question:
        return True
    description = question["description"]
    return not ("ndirect" in description or "ncomplete" in description)

def generate_train_data(file_path):
    """Convert one SPICE dialogue file into a list of QA training examples.

    The file is loaded as a JSON list of alternating turns: a question turn at
    each even index, followed by its answer turn. A turn pair is skipped when
    the answer has no "sparql" key, or when its SPARQL cannot be parsed by the
    module-level ``parser`` (the failing question/query is printed).

    For questions that ``isdirect`` reports as not direct, the utterances of
    the preceding question/answer pair are appended to a running context
    (each followed by " [SEP] "), and that context is prefixed to the
    question. The context keeps growing over consecutive non-direct questions
    and is reset by any direct one.

    Args:
        file_path: Path to the dialogue JSON file. The qid prefix is built
            from the parent directory name without its first 3 characters
            and the file name without its first 3 and last 5 characters
            (presumably a "QA_" prefix and ".json" suffix -- TODO confirm).

    Returns:
        A list of dicts with keys "qid", "question", "sparql_query",
        "question_type", "s_expression" and "answer" (a list of
        {"entity_name": ...} dicts split from the answer utterance on ", ").
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform default.
    # strict=False lets json accept control characters inside strings.
    with open(file_path, 'r', encoding='utf-8') as fp:
        data = json.load(fp, strict=False)

    # The qid prefix depends only on the path, so compute it once.
    dir_name = os.path.basename(os.path.dirname(file_path))
    file_name = os.path.basename(file_path)
    qid_prefix = dir_name[3:] + '.' + file_name[3:-5]

    processed_data = []
    context = ""
    for i in range(0, len(data), 2):
        question = data[i]
        answer = data[i + 1]
        # Update the context before any `continue`, so skipped turns still
        # contribute dialogue history to later questions.
        if not isdirect(question) and i > 0:
            context = context + data[i - 2]["utterance"] + " [SEP] " + data[i - 1]["utterance"] + " [SEP] "
        else:
            context = ""
        if "sparql" not in answer:
            continue

        qa = {}
        qa["qid"] = qid_prefix + '.' + str(i // 2)
        qa["question"] = context + question["utterance"] + " [CTX]"
        qa["sparql_query"] = answer["sparql"]
        qa["question_type"] = f"{question['question-type']} [{question['description']}]" if 'description' in question else question['question-type']
        try:
            qa["s_expression"] = parser.parse_sparql(answer["sparql"])
        except Exception:
            # Best effort: report queries the parser can't handle and skip
            # them, but don't swallow KeyboardInterrupt/SystemExit.
            print(qa["question"])
            print(qa["sparql_query"])
            print()
            continue
        entity_names = answer["utterance"].split(", ")
        qa["answer"] = [{"entity_name": entity_name} for entity_name in entity_names]
        processed_data.append(qa)
    return processed_data

In [10]:
# Count processed validation examples per question type (question-type plus
# optional description) and express each count as a percentage of the total.
# Directory and file listings are sorted: os.listdir order is arbitrary, and
# it decides the order of equally-frequent types below, which in turn breaks
# rounding ties when the dev-set sample sizes are computed.
root = "data/SPICE/valid/"
subtype_count = {}
for dir_name in sorted(os.listdir(root)):
    for file_name in sorted(os.listdir(root + dir_name)):
        new_data = generate_train_data(root + dir_name + '/' + file_name)
        for example in new_data:
            question_type = example['question_type']
            subtype_count[question_type] = subtype_count.get(question_type, 0) + 1
question_type_nums = dict(sorted(subtype_count.items(), key=lambda item: item[1], reverse=True))
total_num = sum(question_type_nums.values())
question_type_percentages = {
    question_type: count / total_num * 100
    for question_type, count in question_type_nums.items()
}
for question_type, percentage in question_type_percentages.items():
    print(f"{question_type} {percentage:.2f}%")

Simple Question (Direct) [Simple Question|Single Entity] 19.93%
Simple Question (Coreferenced) [Simple Question|Single Entity|Indirect] 15.52%
Simple Question (Direct) [Simple Question] 10.80%
Simple Question (Ellipsis) [only subject is changed, parent and predicate remains same] 3.74%
Simple Question (Coreferenced) 3.30%
Logical Reasoning (All) [Logical|Union|Single_Relation] 2.75%
Verification (Boolean) (All) [Verification|2 entities, both direct] 1.93%
Quantitative Reasoning (Count) (All) [Quantitative|Count|Single entity type] 1.91%
Simple Question (Direct) [Simple Question|Mult. Entity|Indirect] 1.77%
Verification (Boolean) (All) [Verification|2 entities, one direct and one indirect, object is indirect] 1.73%
Verification (Boolean) (All) [Verification|one entity, multiple entities (as object) referred indirectly] 1.73%
Simple Question (Coreferenced) [Simple Question|Mult. Entity] 1.70%
Logical Reasoning (All) [Logical|Union|Multiple_Relation] 1.68%
Comparative Reasoning (All) [Com

In [13]:
def get_required_samples_for_each_question_type(sample_size: int, question_type_percentages, sort_key=lambda item: item[1], reverse=True) -> dict:
    '''Split ``sample_size`` samples across question types in proportion to their percentages.

    Uses largest-remainder rounding: each type first gets the floor of its
    exact share (sample_size * percentage / 100). The samples still missing
    to reach ``sample_size`` are then handed out one each to the types with
    the largest fractional remainders. Ties keep the input order, because
    ``sorted`` is stable.

    Args:
        sample_size: Total number of samples to distribute.
        question_type_percentages: Mapping of question type -> percentage
            (expected to sum to ~100).
        sort_key: Key applied to (question_type, count) items to order the
            returned dict. Defaults to sorting by count.
        reverse: Whether to sort the returned dict in descending order.

    Returns:
        Dict of question type -> required (integer) number of samples.
    '''
    # Exact, fractional share for every question type (input order preserved).
    exact_shares = {
        question_type: sample_size * percentage / 100
        for question_type, percentage in question_type_percentages.items()
    }
    counts = {question_type: int(share) for question_type, share in exact_shares.items()}

    # Clamp at 0: when the percentages sum to more than 100 this would be
    # negative, and a negative slice bound below would round *up* every type
    # except the last few instead of none.
    missing_values = max(0, sample_size - sum(counts.values()))

    # Give one extra sample to the types with the largest fractional remainders.
    by_remainder = sorted(exact_shares, key=lambda question_type: exact_shares[question_type] % 1, reverse=True)
    for question_type in by_remainder[:missing_values]:
        counts[question_type] += 1

    return dict(sorted(counts.items(), key=sort_key, reverse=reverse))

# Target a 1500-example dev set, split across question types in proportion to
# their frequency in the full validation data; display the allocation.
required_samples = get_required_samples_for_each_question_type(
    sample_size=1500,
    question_type_percentages=question_type_percentages,
)
required_samples

{'Simple Question (Direct) [Simple Question|Single Entity]': 299,
 'Simple Question (Coreferenced) [Simple Question|Single Entity|Indirect]': 233,
 'Simple Question (Direct) [Simple Question]': 162,
 'Simple Question (Ellipsis) [only subject is changed, parent and predicate remains same]': 56,
 'Simple Question (Coreferenced)': 49,
 'Logical Reasoning (All) [Logical|Union|Single_Relation]': 41,
 'Verification (Boolean) (All) [Verification|2 entities, both direct]': 29,
 'Quantitative Reasoning (Count) (All) [Quantitative|Count|Single entity type]': 29,
 'Simple Question (Direct) [Simple Question|Mult. Entity|Indirect]': 27,
 'Verification (Boolean) (All) [Verification|2 entities, one direct and one indirect, object is indirect]': 26,
 'Verification (Boolean) (All) [Verification|one entity, multiple entities (as object) referred indirectly]': 26,
 'Simple Question (Coreferenced) [Simple Question|Mult. Entity]': 25,
 'Logical Reasoning (All) [Logical|Union|Multiple_Relation]': 25,
 'Comp

In [16]:
# Walk the validation data again. For each question type, keep the first
# required_samples[question_type] examples encountered, then write the
# resulting dev subset to disk. Listings are sorted so that the selected
# subset is the same on every machine (os.listdir order is arbitrary).
root = "data/SPICE/valid/"
data = {}
for dir_name in sorted(os.listdir(root)):
    for file_name in sorted(os.listdir(root + dir_name)):
        new_data = generate_train_data(root + dir_name + '/' + file_name)
        for example in new_data:
            question_type = example['question_type']
            selected = data.setdefault(question_type, [])
            if len(selected) < required_samples[question_type]:
                selected.append(example)

output_path = "data/processed_spice_data/dev_1500.json"
os.makedirs(os.path.dirname(output_path), exist_ok=True)
# Close (and flush) the file deterministically instead of relying on GC.
with open(output_path, 'w', encoding='utf-8') as out_file:
    json.dump([example for examples in data.values() for example in examples], out_file, indent=2)