In [1]:
# Retrieve additional data from Wikidata (entity labels, descriptions and aliases)
# for all entities and relations in the datasets used

In [2]:
import json
import pandas as pd
import sys
sys.path.append("../utils")
from tqdm import tqdm
from wikidata import get_wikidata_entity_label, get_wikidata_entity_description, get_wikidata_entity_aliases

In [3]:
# FewRel: load the train and validation splits plus the PID -> name mapping.
# Each split maps a relation ID (PID) to a list of instances; element 1 of an
# instance's "h" / "t" field is taken as the head / tail entity ID.
# Files are opened as UTF-8 explicitly (the JSON standard encoding) so that loading
# does not depend on the platform's default locale encoding.
with open("FewRel/raw/train_wiki.json", encoding="utf-8") as f:
    train_data = json.load(f)

with open("FewRel/raw/val_wiki.json", encoding="utf-8") as f:
    val_data = json.load(f)

with open("FewRel/raw/pid2name.json", encoding="utf-8") as f:
    pid_dict = json.load(f)

# Concatenate the splits. For a key present in both dicts, `|` keeps only the
# right-hand value, which would silently drop that PID's training instances,
# so require the splits to use disjoint relation sets.
shared_pids = train_data.keys() & val_data.keys()
assert not shared_pids, f"PIDs present in both FewRel splits: {sorted(shared_pids)}"
data_fr = train_data | val_data

head_entities_fr, tail_entities_fr, relations_fr = [], [], []
for pid, instances in data_fr.items():
    for x in instances:
        head_entities_fr.append(x["h"][1])
        tail_entities_fr.append(x["t"][1])
        relations_fr.append(pid)

entities_fr = set(head_entities_fr) | set(tail_entities_fr)
relations_fr = set(relations_fr)
print(f"Num of head entities: {len(head_entities_fr)}")
print(f"Num of tail entities: {len(tail_entities_fr)}")
print(f"Num of unique head entities: {len(set(head_entities_fr))}")
print(f"Num of unique tail entities: {len(set(tail_entities_fr))}")
print(f"Num of unique entities: {len(entities_fr)}")
print(f"Num of unique relations: {len(relations_fr)}")

Num of head entities: 56000
Num of tail entities: 56000
Num of unique head entities: 50340
Num of unique tail entities: 27160
Num of unique entities: 72954
Num of unique relations: 80


In [4]:
# CounterFact: each record's "requested_rewrite" holds the original ("target_true")
# and counterfactual ("target_new") tail entities plus the relation ID; collect their IDs.
# The file is opened as UTF-8 explicitly (the JSON standard encoding) so that loading
# does not depend on the platform's default locale encoding.
with open("Counterfact/raw/counterfact.json", "r", encoding="utf-8") as file:
    data_cf = json.load(file)

target_true, target_new, relations_cf = [], [], []
for x in data_cf:
    rewrite = x["requested_rewrite"]
    target_true.append(rewrite["target_true"]["id"])
    target_new.append(rewrite["target_new"]["id"])
    relations_cf.append(rewrite["relation_id"])
entities_cf = set(target_true) | set(target_new)
relations_cf = set(relations_cf)
print(f"Num of tail entities: {len(target_true)}")
print(f"Num of new tail entities: {len(target_new)}")
print(f"Num of unique tail entities: {len(set(target_true))}")
print(f"Num of unique new tail entities: {len(set(target_new))}")
print(f"Num of unique relations: {len(relations_cf)}")

Num of tail entities: 21919
Num of new tail entities: 21919
Num of unique tail entities: 863
Num of unique new tail entities: 781
Num of unique relations: 34


In [5]:
# Merge the FewRel and CounterFact ID sets and order each list numerically by the
# integer that follows the one-letter prefix ('Q' for entities, 'P' for relations),
# so that e.g. Q9 sorts before Q10.
def _wikidata_id_number(wikidata_id):
    """Return the integer part of a Wikidata ID such as 'Q42' or 'P17'."""
    return int(wikidata_id[1:])


entities = sorted(entities_fr.union(entities_cf), key=_wikidata_id_number)
relations = sorted(relations_fr.union(relations_cf), key=_wikidata_id_number)

print(f"Total num unique entities: {len(entities)}")
print(f"Total num unique relations: {len(relations)}")

Total num unique entities: 73309
Total num unique relations: 90


In [None]:
# Retrieve the label, description and aliases of every entity from Wikidata and
# append them to a JSON Lines file in small batches, so progress already written
# survives an interruption of this long, network-bound loop.
# The file is written in append mode, so IDs already present in it are skipped:
# re-running the cell resumes the download instead of appending duplicate records.
ENTITY_DATA_PATH = "wikidata_entity_data.json"
FETCH_BATCH_SIZE = 10  # records buffered in memory between appends to the file

try:
    with open(ENTITY_DATA_PATH, encoding="utf-8") as f:
        fetched_entity_ids = {json.loads(line)["id"] for line in f if line.strip()}
except FileNotFoundError:
    fetched_entity_ids = set()

pending_entities = [entity for entity in entities if entity not in fetched_entity_ids]
result = []
for entity in tqdm(pending_entities, total=len(pending_entities)):
    entity_data = dict(
        id=entity,
        label=get_wikidata_entity_label(entity),
        description=get_wikidata_entity_description(entity),
        aliases=get_wikidata_entity_aliases(entity),
    )
    result.append(entity_data)
    if len(result) >= FETCH_BATCH_SIZE:
        pd.DataFrame(result).to_json(ENTITY_DATA_PATH, orient="records", lines=True, mode="a")
        result = []
# Flush the final partial batch; skip it when empty so nothing blank is appended.
if result:
    pd.DataFrame(result).to_json(ENTITY_DATA_PATH, orient="records", lines=True, mode="a")

In [None]:
# Retrieve the label, description and aliases of every relation from Wikidata and
# append them to a JSON Lines file in small batches, so progress already written
# survives an interruption of the network-bound loop.
# The file is written in append mode, so IDs already present in it are skipped:
# re-running the cell resumes the download instead of appending duplicate records.
RELATION_DATA_PATH = "wikidata_relation_data.json"
FETCH_BATCH_SIZE = 10  # records buffered in memory between appends to the file

try:
    with open(RELATION_DATA_PATH, encoding="utf-8") as f:
        fetched_relation_ids = {json.loads(line)["id"] for line in f if line.strip()}
except FileNotFoundError:
    fetched_relation_ids = set()

pending_relations = [relation for relation in relations if relation not in fetched_relation_ids]
result = []
for relation in tqdm(pending_relations, total=len(pending_relations)):
    relation_data = dict(
        id=relation,
        label=get_wikidata_entity_label(relation),
        description=get_wikidata_entity_description(relation),
        aliases=get_wikidata_entity_aliases(relation),
    )
    result.append(relation_data)
    if len(result) >= FETCH_BATCH_SIZE:
        pd.DataFrame(result).to_json(RELATION_DATA_PATH, orient="records", lines=True, mode="a")
        result = []
# Flush the final partial batch; skip it when empty so nothing blank is appended.
if result:
    pd.DataFrame(result).to_json(RELATION_DATA_PATH, orient="records", lines=True, mode="a")

In [6]:
# validate
# entities: every requested entity must appear in the output file exactly once.
# Missing descriptions / aliases are only reported, not asserted.
df = pd.read_json("wikidata_entity_data.json", lines=True)

print(f"Num entities: {len(df)}")
print(f"Num unique entities: {df['id'].nunique()}")
missing_entities = list(set(entities) - set(df["id"].to_list()))
print(f"Missing entities: {missing_entities}")
print(f"No description: {df['description'].isna().sum()}")
print(f"No aliases: {df['aliases'].isna().sum()}")

# Fail loudly rather than relying on a reader to inspect the printed counts.
assert df["id"].is_unique, "duplicate entity IDs in wikidata_entity_data.json"
assert not missing_entities, f"{len(missing_entities)} entities missing from wikidata_entity_data.json"

Num entities: 73309
Num unique entities: 73309
Missing entities: []
No description: 1706
No aliases: 35770


In [7]:
# relations: every requested relation must appear in the output file exactly once.
# Missing descriptions / aliases are only reported, not asserted.
df_relation = pd.read_json("wikidata_relation_data.json", lines=True)

print(f"Num relations: {len(df_relation)}")
print(f"Num unique relations: {df_relation['id'].nunique()}")
missing_relations = list(set(relations) - set(df_relation["id"].to_list()))
print(f"Missing relations: {missing_relations}")
print(f"No description: {df_relation['description'].isna().sum()}")
print(f"No aliases: {df_relation['aliases'].isna().sum()}")

# Fail loudly rather than relying on a reader to inspect the printed counts.
assert df_relation["id"].is_unique, "duplicate relation IDs in wikidata_relation_data.json"
assert not missing_relations, f"{len(missing_relations)} relations missing from wikidata_relation_data.json"

Num relations: 90
Num unique relations: 90
Missing relations: []
No description: 0
No aliases: 1
