In [1]:
import os
import time

import numpy as np

In [11]:
class Entity:
    """A node of a knowledge graph: a regular entity or a literal value.

    The constructor stores ``idx`` as ``id``, the whitespace-stripped ``name``,
    and ``value = preprocess_func(name)`` (computed from the stripped name).
    ``involved_as_head_dict`` / ``involved_as_tail_dict`` map a relation to the
    set of nodes found on the other end of that relation.
    """

    def __init__(self, idx: int, name: str, preprocess_func, is_literal=False, affiliation=None):
        """Create the node and immediately derive ``value`` from ``name``.

        Args:
            idx: identifier stored in ``self.id``.
            name: raw name; surrounding whitespace is removed.
            preprocess_func: callable applied to the stripped name to get ``value``.
            is_literal: flag reported back by ``is_literal()``.
            affiliation: stored as given (the visible callers pass the owning graph).
        """
        self.id: int = idx
        self.name: str = name.strip()
        self.value = None
        self.embedding = None

        self._is_literal = is_literal
        self.affiliation = affiliation
        self.preprocess_func = preprocess_func

        # relation -> set of neighbours, one mapping per direction
        self.involved_as_head_dict = {}
        self.involved_as_tail_dict = {}

        self.__init()

    @staticmethod
    def is_entity():
        """Always ``True``: this class represents entity-like nodes."""
        return True

    @staticmethod
    def is_relation():
        """Always ``False``: this class never represents a relation."""
        return False

    def __init(self):
        # Derive the normalised value from the (already stripped) name.
        self.value = self.preprocess_func(self.name)

    def is_literal(self):
        """Return the ``is_literal`` flag given at construction time."""
        return self._is_literal

    def add_relation_as_head(self, relation, tail):
        """Record that this node is the head of ``relation`` pointing at ``tail``."""
        self.involved_as_head_dict.setdefault(relation, set()).add(tail)

    def add_relation_as_tail(self, relation, head):
        """Record that this node is the tail of ``relation`` coming from ``head``."""
        self.involved_as_tail_dict.setdefault(relation, set()).add(head)

In [12]:

class Relation:
    """A relation (or attribute) of a knowledge graph plus its usage statistics.

    Besides the identifying fields (``id``, stripped ``name``, ``value`` =
    ``preprocess_func(name)``), the object tracks every (head, tail) pair it
    was used with and derives two ratios from them:

    * ``functionality``      = distinct heads / distinct (head, tail) pairs
    * ``functionality_inv``  = distinct tails / distinct (head, tail) pairs
    """

    def __init__(self, idx: int, name: str, preprocess_func, is_attribute=False, affiliation=None):
        """Create the relation and immediately derive ``value`` from ``name``.

        Args:
            idx: identifier stored in ``self.id``.
            name: raw name; surrounding whitespace is removed.
            preprocess_func: callable applied to the stripped name to get ``value``.
            is_attribute: flag reported back by ``is_attribute()``.
            affiliation: stored as given (the visible callers pass the owning graph).
        """
        self._is_attribute = is_attribute

        self.id: int = idx
        self.name: str = name.strip()
        self.value = None

        self.preprocess_func = preprocess_func
        self.affiliation = affiliation

        # Number of add_relation_tuple() calls, duplicates included.
        self.frequency = 0

        self.head_ent_set = set()
        self.tail_ent_set = set()
        self.tuple_set = set()

        self.functionality = 0.0
        self.functionality_inv = 0.0

        self.embedding = None
        self.__init()

    @staticmethod
    def is_entity():
        """Always ``False``: this class never represents an entity."""
        return False

    @staticmethod
    def is_relation():
        """Always ``True``: this class represents relations and attributes."""
        return True

    def __init(self):
        # Derive the normalised value from the (already stripped) name.
        self.value = self.preprocess_func(self.name)

    def is_attribute(self):
        """Return the ``is_attribute`` flag given at construction time."""
        return self._is_attribute

    def add_relation_tuple(self, head, tail):
        """Register one use of this relation between ``head`` and ``tail``."""
        self.head_ent_set.add(head)
        self.tail_ent_set.add(tail)
        self.tuple_set.add((head, tail))
        self.frequency += 1

    def calculate_functionality(self):
        """Recompute ``functionality`` and ``functionality_inv``.

        The denominator is the number of *distinct* (head, tail) pairs rather
        than ``frequency``: the numerators are sizes of de-duplicated sets, so
        dividing by a count that includes repeated insertions of the same fact
        would understate both ratios. Without duplicates the two denominators
        are equal. Nothing is changed when no pair has been registered.
        """
        pair_count = len(self.tuple_set)
        if pair_count == 0:
            return
        self.functionality = len(self.head_ent_set) / pair_count
        self.functionality_inv = len(self.tail_ent_set) / pair_count

In [13]:
import re

In [14]:
class KG:
    """In-memory knowledge graph of entities, literals, relations and attributes.

    Names are turned into ``Entity`` / ``Relation`` objects on demand by the
    ``get_*`` methods. Facts are added with ``insert_relation_tuple`` and
    ``insert_attribute_tuple``; every fact is stored together with a synthetic
    inverse fact whose relation name ends in ``-(INV)``. ``init()`` then builds
    the id-indexed structures (fact dictionaries, id lists, functionality table)
    from whatever has been inserted so far.
    """

    def __init__(self, name="KG", ent_pre_func=None, rel_pre_func=None, attr_pre_func=None,
                 lite_pre_func=None):
        """Create an empty graph.

        Args:
            name: label used by ``print_kg_info``.
            ent_pre_func, rel_pre_func, attr_pre_func: name -> value
                normalisers for entities, relations and attributes; each
                defaults to ``default_pre_func``.
            lite_pre_func: normaliser for literals; defaults to
                ``default_pre_func_for_literal``.
        """
        self.name = name
        self.ent_pre_func = ent_pre_func
        self.rel_pre_func = rel_pre_func
        self.attr_pre_func = attr_pre_func
        self.lite_pre_func = lite_pre_func

        self.entity_set = set()
        self.relation_set = set()
        self.attribute_set = set()
        self.literal_set = set()

        self.entity_dict_by_name = dict()
        self.relation_dict_by_name = dict()
        self.attribute_dict_by_name = dict()
        self.literal_dict_by_name = dict()

        self.entity_dict_by_value = dict()
        self.relation_dict_by_value = dict()
        self.attribute_dict_by_value = dict()
        self.literal_dict_by_value = dict()

        # Filled by init(): position i holds the object whose id is i.
        self.ent_lite_list_by_id = list()
        self.rel_attr_list_by_id = list()

        # (head, relation, tail) object triples, inverse facts included.
        self.relation_tuple_list = list()
        self.attribute_tuple_list = list()

        # Filled by init().
        self.functionality_dict = dict()
        self.ent_id_list = list()
        self.fact_dict_by_head = dict()
        self.fact_dict_by_tail = dict()
        self.is_literal_list = list()

        self.ent_embeddings = None

        self.__init()
        self._init = False

    def __init(self):
        # Fall back to the built-in normalisers for anything the caller omitted.
        if self.ent_pre_func is None:
            self.ent_pre_func = self.default_pre_func
        if self.rel_pre_func is None:
            self.rel_pre_func = self.default_pre_func
        if self.attr_pre_func is None:
            self.attr_pre_func = self.default_pre_func
        if self.lite_pre_func is None:
            self.lite_pre_func = self.default_pre_func_for_literal

    @staticmethod
    def default_pre_func(name: str):
        """Strip optional quotes / angle brackets and keep the last ``/`` segment.

        Returns ``name`` unchanged (after printing a message) if the pattern
        does not match.
        """
        pattern = r'"?<?([^">]*)>?"?.*'
        match = re.match(pattern=pattern, string=name)
        if match is None:
            print("Match Error: " + name)
            return name
        value = match.group(1).strip()
        if "/" in value:
            value = value.split(sep="/")[-1].strip()
        return value

    @staticmethod
    def default_pre_func_for_literal(name: str):
        """Normalise a literal: drop everything from the first ``^`` and unwrap it.

        At most one leading ``<`` then one leading ``"`` and one trailing ``>``
        then one trailing ``"`` are removed. If nothing is left, a message is
        printed and ``name`` is returned unchanged.
        """
        value = name.split("^")[0].strip()
        start, end = 0, len(value) - 1
        if start < len(value) and value[start] == '<':
            start += 1
        if end > 0 and value[end] == '>':
            end -= 1
        if start < len(value) and value[start] == '"':
            start += 1
        if end > 0 and value[end] == '"':
            end -= 1
        if start > end:
            print("Match Error: " + name)
            return name
        value = value[start: end + 1].strip()
        return value

    @staticmethod
    def __dict_set_insert_helper(dictionary: dict, key, value):
        # Add ``value`` to the set stored under ``key``, creating the set on first use.
        dictionary.setdefault(key, set()).add(value)

    def get_entity(self, name: str):
        """Return the entity called ``name``, creating it on first request.

        The name is stripped before the lookup because objects are indexed by
        their stripped name; looking up the raw string would miss an existing
        entry and create a duplicate whenever the input carries whitespace.
        """
        name = name.strip()
        entity = self.entity_dict_by_name.get(name)
        if entity is None:
            entity = Entity(idx=len(self.literal_set) + len(self.entity_set), name=name,
                            preprocess_func=self.ent_pre_func, affiliation=self)
            self.entity_set.add(entity)
            self.entity_dict_by_name[entity.name] = entity
            self.entity_dict_by_value[entity.value] = entity
        return entity

    def get_relation(self, name: str):
        """Return the relation called ``name`` (stripped), creating it on first request."""
        name = name.strip()
        relation = self.relation_dict_by_name.get(name)
        if relation is None:
            relation = Relation(idx=len(self.attribute_set) + len(self.relation_set), name=name,
                                preprocess_func=self.rel_pre_func, affiliation=self)
            self.relation_set.add(relation)
            self.relation_dict_by_name[relation.name] = relation
            self.relation_dict_by_value[relation.value] = relation
        return relation

    def get_attribute(self, name: str):
        """Return the attribute called ``name`` (stripped), creating it on first request."""
        name = name.strip()
        attribute = self.attribute_dict_by_name.get(name)
        if attribute is None:
            attribute = Relation(idx=len(self.attribute_set) + len(self.relation_set), name=name,
                                 preprocess_func=self.attr_pre_func, affiliation=self, is_attribute=True)
            self.attribute_set.add(attribute)
            self.attribute_dict_by_name[attribute.name] = attribute
            self.attribute_dict_by_value[attribute.value] = attribute
        return attribute

    def get_literal(self, name: str):
        """Return the literal called ``name`` (stripped), creating it on first request."""
        name = name.strip()
        literal = self.literal_dict_by_name.get(name)
        if literal is None:
            literal = Entity(idx=len(self.literal_set) + len(self.entity_set), name=name,
                             preprocess_func=self.lite_pre_func, affiliation=self, is_literal=True)
            self.literal_set.add(literal)
            self.literal_dict_by_name[literal.name] = literal
            self.literal_dict_by_value[literal.value] = literal
        return literal

    def insert_relation_tuple(self, head: str, relation: str, tail: str):
        """Add the fact ``(head, relation, tail)`` and its ``-(INV)`` inverse."""
        ent_h, rel, ent_t = self.get_entity(head), self.get_relation(relation), self.get_entity(tail)
        self.__insert_relation_tuple_one_way(ent_h, rel, ent_t)
        relation_inv = relation.strip() + str("-(INV)")
        rel_v = self.get_relation(relation_inv)
        self.__insert_relation_tuple_one_way(ent_t, rel_v, ent_h)

    def insert_attribute_tuple(self, entity: str, attribute: str, literal: str):
        """Add the fact ``(entity, attribute, literal)`` and its ``-(INV)`` inverse."""
        ent, attr, val = self.get_entity(entity), self.get_attribute(attribute), self.get_literal(literal)
        self.__insert_attribute_tuple_one_way(ent, attr, val)
        attribute_inv = attribute.strip() + str("-(INV)")
        attr_v = self.get_attribute(attribute_inv)
        self.__insert_attribute_tuple_one_way(val, attr_v, ent)

    def __insert_relation_tuple_one_way(self, ent_h, rel, ent_t):
        # Register the fact on both endpoints, on the relation and in the graph list.
        ent_h.add_relation_as_head(relation=rel, tail=ent_t)
        rel.add_relation_tuple(head=ent_h, tail=ent_t)
        ent_t.add_relation_as_tail(relation=rel, head=ent_h)
        self.relation_tuple_list.append((ent_h, rel, ent_t))

    def __insert_attribute_tuple_one_way(self, ent, attr, val):
        # Same bookkeeping as for relation facts, kept in the attribute list.
        ent.add_relation_as_head(relation=attr, tail=val)
        attr.add_relation_tuple(head=ent, tail=val)
        val.add_relation_as_tail(relation=attr, head=ent)
        self.attribute_tuple_list.append((ent, attr, val))

    def get_object_by_name(self, name: str):
        """Look ``name`` (stripped) up among attributes, relations, literals, entities.

        The categories are searched in that order; ``None`` is returned when
        the name is unknown everywhere.
        """
        name = name.strip()
        for lookup in (self.attribute_dict_by_name, self.relation_dict_by_name,
                       self.literal_dict_by_name, self.entity_dict_by_name):
            if name in lookup:
                return lookup[name]
        return None

    def __calculate_functionality(self):
        # Refresh each relation/attribute's functionality and index it by id.
        for relation in self.relation_set:
            relation.calculate_functionality()
            self.functionality_dict[relation.id] = relation.functionality
        for attribute in self.attribute_set:
            attribute.calculate_functionality()
            self.functionality_dict[attribute.id] = attribute.functionality

    def init(self):
        """Build every id-indexed structure from the facts inserted so far.

        Ids are reassigned so that entities come first and literals follow
        (likewise relations before attributes). The fact dictionaries are
        rebuilt from scratch, so calling ``init()`` again after adding facts
        does not duplicate entries.
        """
        def init_index(set_a, set_b):
            # Consecutive ids: all of set_a first, then all of set_b.
            index = 0
            for item_set in (set_a, set_b):
                for item in item_set:
                    item.id = index
                    index += 1

        def init_fact_dict(tuple_list, fact_dict_by_head, fact_dict_by_tail):
            # Cleared in place so existing references to the dicts stay valid.
            fact_dict_by_head.clear()
            fact_dict_by_tail.clear()
            for (h, r, t) in tuple_list:
                fact_dict_by_head.setdefault(h.id, list()).append((r.id, t.id))
                fact_dict_by_tail.setdefault(t.id, list()).append((r.id, h.id))

        def init_idx_dict(item_set):
            idx_list = [None for _ in range(len(item_set))]
            for item in item_set:
                idx_list[item.id] = item
            return idx_list

        init_index(self.entity_set, self.literal_set)
        init_index(self.relation_set, self.attribute_set)
        init_fact_dict(self.relation_tuple_list + self.attribute_tuple_list,
                       self.fact_dict_by_head, self.fact_dict_by_tail)
        self.ent_lite_list_by_id = init_idx_dict(self.entity_set | self.literal_set)
        self.rel_attr_list_by_id = init_idx_dict(self.relation_set | self.attribute_set)
        # Matches the id layout produced by init_index: entities, then literals.
        self.is_literal_list = [False] * len(self.entity_set) + [True] * len(self.literal_set)
        self.ent_id_list = [item.id for item in self.entity_set]
        self.__calculate_functionality()
        self._init = True

    def is_init(self):
        """Return whether ``init()`` has completed at least once."""
        return self._init

    def init_ent_embeddings(self):
        """Copy per-entity ``embedding`` vectors into the ``ent_embeddings`` matrix.

        The matrix has one row per member of ``entity_set`` and is allocated
        when the first embedding is seen. Iteration stops at the first entity
        whose ``embedding`` is ``None``, leaving the remaining rows untouched.
        Rows are addressed by ``ent.id``.
        """
        # NOTE(review): presumably expects init() to have run so that entity ids
        # lie in [0, len(entity_set)) - confirm against the caller.
        for ent in self.entity_set:
            idx, embedding = ent.id, ent.embedding
            if embedding is None:
                break
            if self.ent_embeddings is None:
                self.ent_embeddings = np.zeros((len(self.entity_set), len(embedding)))
            self.ent_embeddings[idx, :] = embedding

    def set_ent_embedding(self, idx, emb, func=None):
        """Overwrite row ``idx`` of ``ent_embeddings`` (no-op if the matrix is unset).

        With ``func`` given, the stored row is ``func(old, emb)`` where ``old``
        is the ``embedding`` attribute of the object with id ``idx``.
        """
        if self.ent_embeddings is None:
            return
        if func is None:
            self.ent_embeddings[idx, :] = emb
        else:
            self.ent_embeddings[idx, :] = func(self.ent_lite_list_by_id[idx].embedding, emb)

    def print_kg_info(self, func_num=10):
        """Print size statistics and the ``func_num`` most functional relations/attributes.

        Tuple, relation and attribute counts are halved because every fact and
        every relation is stored together with its ``-(INV)`` counterpart.
        """
        print("\nInformation of Knowledge Graph (" + str(self.name) + "):")
        print("- Relation Tuple Number: " + str(int(len(self.relation_tuple_list) / 2)))
        print("- Attribute Tuple Number: " + str(int(len(self.attribute_tuple_list) / 2)))
        print("- Entity Number: " + str(len(self.entity_set)))
        print("- Relation Number: " + str(int(len(self.relation_set) / 2)))
        print("- Attribute Number: " + str(int(len(self.attribute_set) / 2)))
        print("- Literal Number: " + str(len(self.literal_set)))
        print("- Functionality Statistics:")

        def functionality_printer(is_rel: bool, inverse: bool, num: int):
            label = "Func-Inv" if inverse else "Func"
            source = self.relation_set if is_rel else self.attribute_set
            if inverse:
                ranked = sorted(source, key=lambda x: x.functionality_inv, reverse=True)
            else:
                ranked = sorted(source, key=lambda x: x.functionality, reverse=True)
            title = "--- TOP-{} {} ({}) ---"
            print(title.format(str(num), "Relations" if is_rel else "Attributes", label))
            for relation in ranked[:max(num, 0)]:
                score = relation.functionality_inv if inverse else relation.functionality
                print("Name: {}\t{}: {}".format(relation.name, label, score))
            print("......")

        functionality_printer(True, False, func_num)
        functionality_printer(True, True, func_num)
        functionality_printer(False, False, func_num)
        functionality_printer(False, True, func_num)

In [15]:
def get_counterpart_id_and_prob(ent_match, ent_prob, ent_id):
    """Return ``(counterpart_id, probability)`` recorded for ``ent_id``.

    ``ent_match`` and ``ent_prob`` are indexed by ``ent_id``. When no
    counterpart is recorded (``None``), ``(None, 0.0)`` is returned regardless
    of what ``ent_prob`` holds.
    """
    matched_id = ent_match[ent_id]
    return (None, 0.0) if matched_id is None else (matched_id, ent_prob[ent_id])


def set_counterpart_id_and_prob(ent_match, ent_prob, ent_l_id, ent_r_id, prob):
    """Record ``ent_r_id`` as the counterpart of ``ent_l_id`` unless it is worse.

    The entry is overwritten in place when ``prob`` is not lower than the
    probability currently stored for ``ent_l_id``; otherwise nothing changes.
    """
    if not prob < ent_prob[ent_l_id]:
        ent_match[ent_l_id] = ent_r_id
        ent_prob[ent_l_id] = prob


def register_rel_align_prob_norm(dictionary, rel, prob):
    """Add ``prob`` to the running total kept for ``rel`` (starting from 0.0)."""
    dictionary[rel] = dictionary.get(rel, 0.0) + prob


def register_ongoing_prob_product(dictionary, key1, key2, prob):
    """Add ``prob`` to ``dictionary[key1][key2]``, creating missing levels.

    A missing inner dictionary is created empty and a missing inner entry
    starts from 0.0.
    """
    inner = dictionary.setdefault(key1, dict())
    inner[key2] = inner.get(key2, 0.0) + prob


def get_rel_align_prob(dictionary, rel_l, rel_r):
    """Return ``dictionary[rel_l][rel_r]`` limited to the range [0.0, 1.0].

    A pair that is not present at either nesting level yields 0.0.
    """
    if rel_l not in dictionary or rel_r not in dictionary[rel_l]:
        return 0.0
    stored = dictionary[rel_l][rel_r]
    if stored > 1.0:
        return 1.0
    if stored < 0.0:
        return 0.0
    return stored


def update_ent_align_prob(ent_align_ongoing_dict, ent_match, ent_prob, kg_l_ent_embeds, kg_r_ent_embeds, ent, fusion_func, init):
    """Pick the best candidate for ``ent`` and store it via ``set_counterpart_id_and_prob``.

    ``ent_align_ongoing_dict`` maps candidate id -> accumulated product; the
    score of a candidate is ``1.0 - product``. When ``init`` is false and both
    embedding tables and ``fusion_func`` are available, the score is replaced
    by ``fusion_func(score, kg_l_ent_embeds[ent, :], kg_r_ent_embeds[candidate, :])``.
    Among equal scores the candidate seen last wins. The winning score is
    limited to [0.0, 1.0] before being stored; with no candidate at all the
    pair ``(None, 0.0)`` is submitted.
    """
    fuse = not (init or kg_l_ent_embeds is None or kg_r_ent_embeds is None or fusion_func is None)
    best_candidate, best_score = None, 0.0
    for candidate, product in ent_align_ongoing_dict.items():
        score = 1.0 - product
        if fuse:
            score = fusion_func(score, kg_l_ent_embeds[ent, :], kg_r_ent_embeds[candidate, :])
        if score >= best_score:
            best_candidate, best_score = candidate, score
    if best_score > 1.0:
        best_score = 1.0
    if best_score < 0.0:
        best_score = 0.0
    set_counterpart_id_and_prob(ent_match, ent_prob, ent, best_candidate, best_score)


def register_ent_equality(ent_align_ongoing_dict, rel_align_dict_l, rel_align_dict_r,
                          kg_l_func, kg_r_func,
                          rel, rel_counterpart, tail_counterpart,
                          head_eqv_prob, theta, epsilon, delta, init):
    """Fold one piece of evidence into the running product for ``tail_counterpart``.

    The relation-alignment probabilities (both directions) and the two
    functionalities are each divided by ``epsilon``. If both alignment
    probabilities are below ``theta`` the evidence is ignored, except when
    ``init`` is true, in which case both are set to ``theta``. The resulting
    factor is the product of ``1 - head_eqv_prob * prob * functionality`` for
    the two directions; it is multiplied into
    ``ent_align_ongoing_dict[tail_counterpart]`` (starting from 1.0) only when
    ``1 - factor`` exceeds ``delta``.
    """
    sub_prob = get_rel_align_prob(rel_align_dict_l, rel, rel_counterpart) / epsilon
    sup_prob = get_rel_align_prob(rel_align_dict_r, rel_counterpart, rel) / epsilon
    if sub_prob < theta and sup_prob < theta:
        if not init:
            return
        sub_prob = sup_prob = theta

    func_l = kg_l_func.get(rel, 0.0) / epsilon
    func_r = kg_r_func.get(rel_counterpart, 0.0) / epsilon
    factor_l = 1.0 - head_eqv_prob * sup_prob * func_r
    factor_r = 1.0 - head_eqv_prob * sub_prob * func_l

    factor = 1.0
    if sub_prob >= 0.0 and func_l >= 0.0:
        factor *= factor_l
    if sup_prob >= 0.0 and func_r >= 0.0:
        factor *= factor_r

    if not (1.0 - factor > delta):
        return
    ent_align_ongoing_dict[tail_counterpart] = ent_align_ongoing_dict.get(tail_counterpart, 1.0) * factor


def one_iteration_one_way(queue, kg_r_fact_dict_by_head,
                          kg_l_fact_dict_by_tail,
                          kg_l_func, kg_r_func,
                          sub_ent_match, sub_ent_prob,
                          is_literal_list_r,
                          rel_align_dict_l, rel_align_dict_r,
                          rel_ongoing_dict_queue, rel_norm_dict_queue,
                          ent_match_tuple_queue,
                          kg_l_ent_embeds, kg_r_ent_embeds,
                          fusion_func,
                          theta, epsilon, delta, init=False, ent_align=True):
    """Worker loop: process entity ids from ``queue`` until it is drained.

    For every entity id taken from ``queue`` the facts in which it is the tail
    (``kg_l_fact_dict_by_tail``) are scanned. Each fact whose head already has
    a counterpart with probability >= ``theta`` contributes to

    * ``rel_norm_dict`` / ``rel_ongoing_dict`` - relation-alignment evidence
      accumulated locally in this worker, and
    * (only when ``ent_align`` is true) the candidate counterparts of the
      entity, which are then written into ``sub_ent_match`` / ``sub_ent_prob``.

    Facts of the other graph whose tail is a literal
    (``is_literal_list_r[tail_id]``) are skipped.

    Results are returned through the three output queues: one dict on
    ``rel_ongoing_dict_queue``, one on ``rel_norm_dict_queue`` and the
    ``(sub_ent_match, sub_ent_prob)`` pair on ``ent_match_tuple_queue``.

    The function returns normally when done. It previously ended with
    ``exit(1)``, i.e. the interactive ``site`` helper with a failure status,
    which marked every successful worker as failed and terminated any caller
    that ran the function in-process.
    """
    rel_ongoing_dict, rel_norm_dict = dict(), dict()
    while not queue.empty():
        # Another worker may empty the queue between empty() and get_nowait();
        # any failure to fetch an item simply ends this worker's loop.
        # noinspection PyBroadException
        try:
            ent_id = queue.get_nowait()
        except Exception:
            break
        ent_align_ongoing_dict = dict()
        ent_fact_list = kg_l_fact_dict_by_tail.get(ent_id, list())
        for (rel_id, head_id) in ent_fact_list:
            head_counterpart, head_eqv_prob = get_counterpart_id_and_prob(sub_ent_match, sub_ent_prob, head_id)
            if head_counterpart is None or head_eqv_prob < theta:
                continue
            ent_counterpart, tail_eqv_prob = get_counterpart_id_and_prob(sub_ent_match, sub_ent_prob, ent_id)
            if ent_counterpart is not None:
                register_rel_align_prob_norm(rel_norm_dict, rel_id, head_eqv_prob * tail_eqv_prob)
            head_counterpart_fact_list = kg_r_fact_dict_by_head.get(head_counterpart, list())
            for (rel_counterpart_id, tail_counterpart_id) in head_counterpart_fact_list:
                if is_literal_list_r[tail_counterpart_id]:
                    continue
                # Only the tail currently matched to this entity counts as
                # evidence for the relation pair.
                eqv_prob = tail_eqv_prob if tail_counterpart_id == ent_counterpart else 0.0
                if eqv_prob > 0.0:
                    register_ongoing_prob_product(rel_ongoing_dict, rel_id, rel_counterpart_id,
                                                  head_eqv_prob * eqv_prob)
                if ent_align:
                    register_ent_equality(ent_align_ongoing_dict, rel_align_dict_l, rel_align_dict_r,
                                          kg_l_func, kg_r_func,
                                          rel_id, rel_counterpart_id, tail_counterpart_id,
                                          head_eqv_prob, theta, epsilon, delta, init)
        if ent_align:
            update_ent_align_prob(ent_align_ongoing_dict, sub_ent_match, sub_ent_prob,
                                  kg_l_ent_embeds, kg_r_ent_embeds, ent_id, fusion_func, init)
    rel_ongoing_dict_queue.put(rel_ongoing_dict)
    rel_norm_dict_queue.put(rel_norm_dict)
    ent_match_tuple_queue.put((sub_ent_match, sub_ent_prob))

In [16]:
import gc
import sys
import random
import multiprocessing as mp

In [17]:
# Raise the interpreter's recursion limit to one million frames.
# NOTE(review): no recursive code is visible in this part of the notebook; a
# limit this high lets runaway recursion overflow the C stack and crash the
# interpreter instead of raising RecursionError - confirm it is still needed.
sys.setrecursionlimit(1000000)

In [18]:
class KGs:
    """Aligns two knowledge graphs (``kg_l`` and ``kg_r``) iteratively.

    State kept per direction:

    * ``sub_ent_match`` / ``sub_ent_prob`` - for each id of ``kg_l`` the id of
      its current counterpart in ``kg_r`` (or ``None``) and the probability.
    * ``sup_ent_match`` / ``sup_ent_prob`` - the same from ``kg_r`` to ``kg_l``.
    * ``rel_align_dict_l`` / ``rel_align_dict_r`` - nested dicts
      ``relation id -> counterpart relation id -> probability``.

    Each call to ``run`` performs ``iteration`` rounds; a round distributes the
    entity ids over ``workers`` processes executing ``one_iteration_one_way``.
    """

    def __init__(self, kg1: "KG", kg2: "KG", theta=0.1, iteration=3, workers=4, fusion_func=None):
        """Store the settings and seed the matching with identical literals.

        Args:
            kg1, kg2: the left and right graphs; ``init()`` is called on each
                one that reports ``is_init()`` as false.
            theta: probability threshold forwarded to the workers.
            iteration: number of rounds performed by ``run``.
            workers: number of worker processes per pass.
            fusion_func: optional callable forwarded to the workers.
        """
        self.kg_l = kg1
        self.kg_r = kg2
        self.theta = theta
        self.iteration = iteration
        self.delta = 0.01
        self.epsilon = 1.01
        # Smoothing constant added to the norm when relation probabilities are derived.
        self.const = 10.0
        self.workers = workers
        self.fusion_func = fusion_func

        self.rel_ongoing_dict_l, self.rel_ongoing_dict_r = dict(), dict()
        self.rel_norm_dict_l, self.rel_norm_dict_r = dict(), dict()
        self.rel_align_dict_l, self.rel_align_dict_r = dict(), dict()

        self.sub_ent_match = None
        self.sup_ent_match = None
        self.sub_ent_prob = None
        self.sup_ent_prob = None

        self._iter_num = 0
        self.has_load = False
        self.util = KGsUtil(self, self.__get_counterpart_and_prob, self.__set_counterpart_and_prob)
        self.__init()

    def __init(self):
        # Make sure both graphs have their id-indexed structures.
        if not self.kg_l.is_init():
            self.kg_l.init()
        if not self.kg_r.is_init():
            self.kg_r.init()

        kg_l_ent_num = len(self.kg_l.entity_set) + len(self.kg_l.literal_set)
        kg_r_ent_num = len(self.kg_r.entity_set) + len(self.kg_r.literal_set)
        self.sub_ent_match = [None for _ in range(kg_l_ent_num)]
        self.sub_ent_prob = [0.0 for _ in range(kg_l_ent_num)]
        self.sup_ent_match = [None for _ in range(kg_r_ent_num)]
        self.sup_ent_prob = [0.0 for _ in range(kg_r_ent_num)]

        # Literals with the same normalised value are matched with certainty.
        for lite_l in self.kg_l.literal_set:
            if lite_l.value in self.kg_r.literal_dict_by_value:
                lite_r = self.kg_r.literal_dict_by_value[lite_l.value]
                l_id, r_id = lite_l.id, lite_r.id
                self.sub_ent_match[l_id], self.sup_ent_match[r_id] = lite_r.id, lite_l.id
                self.sub_ent_prob[l_id], self.sup_ent_prob[r_id] = 1.0, 1.0

    def __get_counterpart_and_prob(self, ent):
        # Returns (counterpart object, probability) for ``ent``; the direction is
        # chosen by whether ``ent`` belongs to the left graph.
        source = ent.affiliation is self.kg_l
        counterpart_id = self.sub_ent_match[ent.id] if source else self.sup_ent_match[ent.id]
        if counterpart_id is None:
            return None, 0.0
        if source:
            return self.kg_r.ent_lite_list_by_id[counterpart_id], self.sub_ent_prob[ent.id]
        return self.kg_l.ent_lite_list_by_id[counterpart_id], self.sup_ent_prob[ent.id]

    def __set_counterpart_and_prob(self, ent_l, ent_r, prob, force=False):
        # Stores ``ent_r`` as the counterpart of ``ent_l``. Unless ``force`` is
        # set, a lower probability than the stored one is rejected (returns False).
        source = ent_l.affiliation is self.kg_l
        l_id, r_id = ent_l.id, ent_r.id
        curr_prob = self.sub_ent_prob[l_id] if source else self.sup_ent_prob[l_id]
        if not force and prob < curr_prob:
            return False
        if source:
            self.sub_ent_match[l_id], self.sub_ent_prob[l_id] = r_id, prob
        else:
            self.sup_ent_match[l_id], self.sup_ent_prob[l_id] = r_id, prob
        return True

    def set_fusion_func(self, func):
        """Replace the fusion callable forwarded to the workers."""
        self.fusion_func = func

    def set_iteration(self, iteration):
        """Set the number of rounds performed by ``run``."""
        self.iteration = iteration

    def set_worker_num(self, worker_num):
        """Set the number of worker processes used per pass."""
        self.workers = worker_num

    def run(self, test_path=None):
        """Perform ``self.iteration`` alignment rounds.

        Args:
            test_path: file handed to ``self.util.test`` after every round.
                When it is ``None`` (the default) the evaluation step is
                skipped, because there is no file to read.
        """
        start_time = time.time()
        print("Start...")
        for i in range(self.iteration):
            self._iter_num = i
            print(str(i + 1) + "-th iteration......")
            self.__run_per_iteration()
            if test_path is not None:
                self.util.test(path=test_path, threshold=[0.1 * step for step in range(10)])
            gc.collect()
        print("PARIS Completed!")
        end_time = time.time()
        print("Total time: " + str(end_time - start_time))

    def __run_per_iteration(self):
        # Left-to-right pass with entity alignment, reconciliation of the two
        # directions, then a right-to-left pass that only updates relations.
        self.__run_per_iteration_one_way(self.kg_l)
        self.__ent_bipartite_matching()
        self.__run_per_iteration_one_way(self.kg_r, ent_align=False)
        return

    def __run_per_iteration_one_way(self, kg: "KG", ent_align=True):
        # One pass over all entity ids of ``kg`` using ``self.workers`` processes.
        kg_other = self.kg_l if kg is self.kg_r else self.kg_r
        ent_list = self.__generate_list(kg)

        kg_r_fact_dict_by_head = kg_other.fact_dict_by_head
        kg_l_fact_dict_by_tail = kg.fact_dict_by_tail
        kg_l_func, kg_r_func = kg.functionality_dict, kg_other.functionality_dict

        rel_align_dict_l, rel_align_dict_r = self.rel_align_dict_l, self.rel_align_dict_r

        if kg is self.kg_l:
            ent_match, ent_prob = self.sub_ent_match, self.sub_ent_prob
            is_literal_list_r = self.kg_r.is_literal_list
        else:
            ent_match, ent_prob = self.sup_ent_match, self.sup_ent_prob
            rel_align_dict_l, rel_align_dict_r = rel_align_dict_r, rel_align_dict_l
            is_literal_list_r = self.kg_l.is_literal_list

        init = not self.has_load and self._iter_num <= 1
        kg_l_ent_embeds, kg_r_ent_embeds = kg.ent_embeddings, kg_other.ent_embeddings

        rel_ongoing_dict = self.rel_ongoing_dict_l if kg is self.kg_l else self.rel_ongoing_dict_r
        rel_norm_dict = self.rel_norm_dict_l if kg is self.kg_l else self.rel_norm_dict_r
        rel_align_dict = self.rel_align_dict_l if kg is self.kg_l else self.rel_align_dict_r

        # The manager (a separate server process) is shut down as soon as the
        # queues have been drained instead of whenever it gets garbage collected.
        with mp.Manager() as mgr:
            ent_queue = mgr.Queue(len(ent_list))
            for ent_id in ent_list:
                ent_queue.put(ent_id)

            rel_ongoing_dict_queue = mgr.Queue()
            rel_norm_dict_queue = mgr.Queue()
            ent_match_tuple_queue = mgr.Queue()

            tasks = []
            for _ in range(self.workers):
                task = mp.Process(target=one_iteration_one_way, args=(ent_queue, kg_r_fact_dict_by_head,
                                                                      kg_l_fact_dict_by_tail,
                                                                      kg_l_func, kg_r_func,
                                                                      ent_match, ent_prob,
                                                                      is_literal_list_r,
                                                                      rel_align_dict_l, rel_align_dict_r,
                                                                      rel_ongoing_dict_queue, rel_norm_dict_queue,
                                                                      ent_match_tuple_queue,
                                                                      kg_l_ent_embeds, kg_r_ent_embeds,
                                                                      self.fusion_func,
                                                                      self.theta, self.epsilon, self.delta, init,
                                                                      ent_align))
                task.start()
                tasks.append(task)

            for task in tasks:
                task.join()

            # The local tables are wiped and rebuilt from what the workers
            # sent back (element-wise best probability).
            self.__clear_ent_match_and_prob(ent_match, ent_prob)
            while not ent_match_tuple_queue.empty():
                ent_match_tuple = ent_match_tuple_queue.get()
                self.__merge_ent_align_result(ent_match, ent_prob, ent_match_tuple[0], ent_match_tuple[1])

            rel_ongoing_dict.clear()
            rel_norm_dict.clear()
            rel_align_dict.clear()
            while not rel_ongoing_dict_queue.empty():
                self.__merge_rel_ongoing_dict(rel_ongoing_dict, rel_ongoing_dict_queue.get())

            while not rel_norm_dict_queue.empty():
                self.__merge_rel_norm_dict(rel_norm_dict, rel_norm_dict_queue.get())

        self.__update_rel_align_dict(rel_align_dict, rel_ongoing_dict, rel_norm_dict, const=self.const)

    @staticmethod
    def update_ent_embeds(kg, new_ent_emb_dict, alpha=0.5):
        """Blend new embeddings into ``kg`` and re-normalise each row.

        For every ``idx -> emb`` entry the stored row becomes
        ``alpha * old + (1 - alpha) * emb`` divided by its norm.
        """
        def update_function(emb_origin, emb_new):
            emb_pool = alpha * emb_origin + (1.0 - alpha) * emb_new
            return emb_pool / np.linalg.norm(emb_pool)

        for (idx, emb) in new_ent_emb_dict.items():
            kg.set_ent_embedding(idx, emb, update_function)

    @staticmethod
    def __generate_list(kg: "KG"):
        # Shuffle a copy so the graph's own ``ent_id_list`` keeps its order.
        ent_list = list(kg.ent_id_list)
        random.shuffle(ent_list)
        return ent_list

    @staticmethod
    def __merge_rel_ongoing_dict(rel_dict_l, rel_dict_r):
        # Adds the nested scores of ``rel_dict_r`` into ``rel_dict_l``. A relation
        # unknown to the left side adopts the right side's inner dict object as is.
        for (rel, rel_counterpart_dict) in rel_dict_r.items():
            if rel not in rel_dict_l:
                rel_dict_l[rel] = rel_counterpart_dict
                continue
            target = rel_dict_l[rel]
            for (rel_counterpart, prob) in rel_counterpart_dict.items():
                if rel_counterpart not in target:
                    target[rel_counterpart] = prob
                else:
                    target[rel_counterpart] += prob

    @staticmethod
    def __merge_rel_norm_dict(norm_dict_l, norm_dict_r):
        # Adds the per-relation totals of ``norm_dict_r`` into ``norm_dict_l``.
        for (rel, norm) in norm_dict_r.items():
            if rel not in norm_dict_l:
                norm_dict_l[rel] = norm
            else:
                norm_dict_l[rel] += norm

    @staticmethod
    def __update_rel_align_dict(rel_align_dict, rel_ongoing_dict, rel_norm_dict, const=10.0):
        # probability = score / (const + norm); a relation without a recorded
        # norm uses 1.0. Existing entries of a relation are replaced.
        for (rel, counterpart_dict) in rel_ongoing_dict.items():
            norm = rel_norm_dict.get(rel, 1.0)
            if rel not in rel_align_dict:
                rel_align_dict[rel] = dict()
            rel_align_dict[rel].clear()
            for (counterpart, score) in counterpart_dict.items():
                prob = score / (const + norm)
                rel_align_dict[rel][counterpart] = prob

    def __ent_bipartite_matching(self):
        # Pass 1: every right-side id keeps the left entity that claims it with
        # the highest probability (strictly higher than what is stored).
        for ent_l in self.kg_l.entity_set:
            ent_id = ent_l.id
            counterpart_id, prob = self.sub_ent_match[ent_id], self.sub_ent_prob[ent_id]
            if counterpart_id is None:
                continue
            counterpart_prob = self.sup_ent_prob[counterpart_id]
            if counterpart_prob < prob:
                self.sup_ent_match[counterpart_id] = ent_id
                self.sup_ent_prob[counterpart_id] = prob
        # Pass 2: left entities that lost their counterpart to another entity
        # are reset to "unmatched".
        for ent_l in self.kg_l.entity_set:
            ent_id = ent_l.id
            sub_counterpart_id = self.sub_ent_match[ent_id]
            if sub_counterpart_id is None:
                continue
            sup_counterpart_id = self.sup_ent_match[sub_counterpart_id]
            if sup_counterpart_id is None:
                continue
            if sup_counterpart_id != ent_id:
                self.sub_ent_match[ent_id], self.sub_ent_prob[ent_id] = None, 0.0

    @staticmethod
    def __merge_ent_align_result(ent_match_l, ent_prob_l, ent_match_r, ent_prob_r):
        # Element-wise: take the right-hand entry wherever it is strictly more probable.
        assert len(ent_match_l) == len(ent_match_r)
        for i in range(len(ent_prob_l)):
            if ent_prob_l[i] < ent_prob_r[i]:
                ent_prob_l[i] = ent_prob_r[i]
                ent_match_l[i] = ent_match_r[i]

    @staticmethod
    def __clear_ent_match_and_prob(ent_match, ent_prob):
        # Reset both lists in place to "no match, probability 0".
        for i in range(len(ent_match)):
            ent_match[i] = None
            ent_prob[i] = 0.0

In [19]:
class KGsUtil:
    def __init__(self, kgs, get_counterpart_and_prob, set_counterpart_and_prob):
        self.kgs = kgs
        self.__get_counterpart_and_prob = get_counterpart_and_prob
        self.__set_counterpart_and_prob = set_counterpart_and_prob
        self.ent_links_candidate = list()

    def reset_ent_align_result(self):
        for ent in self.kgs.kg_l.entity_set:
            idx = ent.id
            self.kgs.sub_ent_match[idx], self.kgs.sub_ent_prob[idx] = None, 0.0
        for ent in self.kgs.kg_r.entity_set:
            idx = ent.id
            self.kgs.sup_ent_match[idx], self.kgs.sup_ent_prob[idx] = None, 0.0
        emb_l, emb_r = self.kgs.kg_l.ent_embeddings, self.kgs.kg_r.ent_embeddings
        matrix = np.matmul(emb_l, emb_r.T)
        max_indices = np.argmax(matrix, axis=1)
        print(max_indices)
        for i in range(len(max_indices)):
            counterpart_id = max_indices[i]
            self.kgs.sub_ent_match[i], self.kgs.sub_ent_prob[i] = counterpart_id, 0.2
            self.kgs.sup_ent_match[counterpart_id], self.kgs.sup_ent_prob[counterpart_id] = i, 0.2

    def test(self, path, threshold):
        gold_result = set()
        with open(path, "r", encoding="utf8") as f:
            for line in f.readlines():
                params = str.strip(line).split("\t")
                ent_l, ent_r = params[0].strip(), params[1].strip()
                obj_l, obj_r = self.kgs.kg_l.entity_dict_by_name.get(ent_l), self.kgs.kg_r.entity_dict_by_name.get(
                    ent_r)
                if obj_l is None:
                    print("Exception: fail to load Entity (" + ent_l + ")")
                if obj_r is None:
                    print("Exception: fail to load Entity (" + ent_r + ")")
                if obj_l is None or obj_r is None:
                    continue
                gold_result.add((obj_l.id, obj_r.id))

        threshold_list = []
        if isinstance(threshold, float) or isinstance(threshold, int):
            threshold_list.append(float(threshold))
        else:
            threshold_list = threshold

        for threshold_item in threshold_list:
            ent_align_result = set()
            for ent_id in self.kgs.kg_l.ent_id_list:
                counterpart_id = self.kgs.sub_ent_match[ent_id]
                if counterpart_id is not None:
                    prob = self.kgs.sub_ent_prob[ent_id]
                    if prob < threshold_item:
                        continue
                    ent_align_result.add((ent_id, counterpart_id))

            correct_num = len(gold_result & ent_align_result)
            predict_num = len(ent_align_result)
            total_num = len(gold_result)

            if predict_num == 0:
                print("Threshold: " + format(threshold_item, ".3f") + "\tException: no satisfied alignment result")
                continue

            if total_num == 0:
                print("Threshold: " + format(threshold_item, ".3f") + "\tException: no satisfied instance for testing")
            else:
                precision, recall = correct_num / predict_num, correct_num / total_num
                if precision <= 0.0 or recall <= 0.0:
                    print("Threshold: " + format(threshold_item, ".3f") + "\tPrecision: " + format(precision, ".6f") +
                          "\tRecall: " + format(recall, ".6f") + "\tF1-Score: Nan")
                else:
                    f1_score = 2.0 * precision * recall / (precision + recall)
                    print("Threshold: " + format(threshold_item, ".3f") + "\tPrecision: " + format(precision, ".6f") +
                          "\tRecall: " + format(recall, ".6f") + "\tF1-Score: " + format(f1_score, ".6f"))

    def generate_input_for_embed_align(self, link_path, save_dir="output", threshold=0.0):
        ent_align_predict, visited = set(), set()
        for ent in self.kgs.kg_l.entity_set:
            counterpart, prob = self.__get_counterpart_and_prob(ent)
            if prob < threshold or counterpart is None:
                continue
            ent_align_predict.add((ent, counterpart))
            visited.add(ent)

        ent_align_test = set()
        with open(link_path, "r", encoding="utf8") as f:
            for line in f.readlines():
                params = str.strip(line).split("\t")
                ent_l, ent_r = params[0].strip(), params[1].strip()
                obj_l, obj_r = self.kgs.kg_l.entity_dict_by_name.get(ent_l), self.kgs.kg_r.entity_dict_by_name.get(
                    ent_r)
                if obj_l is None or obj_r is None:
                    continue
                if obj_l not in visited:
                    ent_align_test.add((obj_l, obj_r))
        if not os.path.exists(save_dir):
            os.makedirs(save_dir)
        train_path = os.path.join(save_dir, "train_links")
        test_path = os.path.join(save_dir, "test_links")
        valid_path = os.path.join(save_dir, "valid_links")

        def writer(path, result_set):
            with open(path, "w", encoding="utf8") as file:
                num, length = 0, len(result_set)
                for (l, r) in result_set:
                    file.write("\t".join([l.name, r.name]))
                    num += 1
                    if num < length:
                        file.write("\n")

        writer(train_path, ent_align_predict)
        writer(test_path, ent_align_test)
        writer(valid_path, ent_align_test)
        print("training size: " + str(len(ent_align_predict)) + "\ttest size: " + str(len(ent_align_test)))

    def save_results(self, path="output/EA_Result.txt"):
        ent_dict, lite_dict, attr_dict, rel_dict = dict(), dict(), dict(), dict()
        for obj in (self.kgs.kg_l.entity_set | self.kgs.kg_l.literal_set):
            counterpart, prob = self.__get_counterpart_and_prob(obj)
            if counterpart is not None:
                if obj.is_literal():
                    lite_dict[(obj, counterpart)] = [prob]
                else:
                    ent_dict[(obj, counterpart)] = [prob]

        for (rel_id, rel_counterpart_id_dict) in self.kgs.rel_align_dict_l.items():
            rel = self.kgs.kg_l.rel_attr_list_by_id[rel_id]
            dictionary = attr_dict if rel.is_attribute() else rel_dict
            for (rel_counterpart_id, prob) in rel_counterpart_id_dict.items():
                if prob > self.kgs.theta:
                    rel_counterpart = self.kgs.kg_r.rel_attr_list_by_id[rel_counterpart_id]
                    dictionary[(rel, rel_counterpart)] = [prob, 0.0]

        for (rel_id, rel_counterpart_id_dict) in self.kgs.rel_align_dict_r.items():
            rel = self.kgs.kg_r.rel_attr_list_by_id[rel_id]
            dictionary = attr_dict if rel.is_attribute() else rel_dict
            for (rel_counterpart_id, prob) in rel_counterpart_id_dict.items():
                if prob > self.kgs.theta:
                    rel_counterpart = self.kgs.kg_l.rel_attr_list_by_id[rel_counterpart_id]
                    if not dictionary.__contains__((rel_counterpart, rel)):
                        dictionary[(rel_counterpart, rel)] = [0.0, 0.0]
                    dictionary[(rel_counterpart, rel)][-1] = prob
        base, _ = os.path.split(path)
        if not os.path.exists(base):
            os.makedirs(base)
        if os.path.exists(path):
            os.remove(path)
        self.__result_writer(path, attr_dict, "Attribute Alignment")
        self.__result_writer(path, rel_dict, "Relation Alignment")
        self.__result_writer(path, lite_dict, "Literal Alignment")
        self.__result_writer(path, ent_dict, "Entity Alignment")
        return

    def save_params(self, path="output/EA_Params"):
        base, _ = os.path.split(path)
        if not os.path.exists(base):
            os.makedirs(base)
        with open(path, "w", encoding="utf8") as f:
            for obj in (self.kgs.kg_l.entity_set | self.kgs.kg_l.literal_set):
                counterpart, prob = self.__get_counterpart_and_prob(obj)
                if counterpart is not None:
                    f.write("\t".join(["L", obj.name, counterpart.name, str(prob)]) + "\n")
            for obj in (self.kgs.kg_r.entity_set | self.kgs.kg_r.literal_set):
                counterpart, prob = self.__get_counterpart_and_prob(obj)
                if counterpart is not None:
                    f.write("\t".join(["R", obj.name, counterpart.name, str(prob)]) + "\n")
            for (rel_id, rel_counterpart_id_dict) in self.kgs.rel_align_dict_l.items():
                rel = self.kgs.kg_l.rel_attr_list_by_id[rel_id]
                for (rel_counterpart_id, prob) in rel_counterpart_id_dict.items():
                    if prob > 0.0:
                        rel_counterpart = self.kgs.kg_r.rel_attr_list_by_id[rel_counterpart_id]
                        prefix = "L"
                        f.write("\t".join([prefix, rel.name, rel_counterpart.name, str(prob)]) + "\n")
            for (rel_id, rel_counterpart_id_dict) in self.kgs.rel_align_dict_r.items():
                rel = self.kgs.kg_r.rel_attr_list_by_id[rel_id]
                for (rel_counterpart_id, prob) in rel_counterpart_id_dict.items():
                    if prob > 0.0:
                        rel_counterpart = self.kgs.kg_l.rel_attr_list_by_id[rel_counterpart_id]
                        prefix = "R"
                        f.write("\t".join([prefix, rel.name, rel_counterpart.name, str(prob)]) + "\n")
        return

    def load_params(self, path="output/EA_Params", init=True):
        self.kgs.has_load = init

        def get_obj_by_name(kg_l, kg_r, name1, name2):
            obj1, obj2 = kg_l.literal_dict_by_name.get(name1), kg_r.literal_dict_by_name.get(name2)
            if obj1 is None or obj2 is None:
                obj1, obj2 = kg_l.entity_dict_by_name.get(name1), kg_r.entity_dict_by_name.get(name2)
            if obj1 is None or obj2 is None:
                obj1, obj2 = kg_l.entity_dict_by_name.get(name1), kg_r.entity_dict_by_name.get(name2)
            if obj1 is None or obj2 is None:
                obj1, obj2 = kg_l.relation_dict_by_name.get(name1), kg_r.relation_dict_by_name.get(name2)
            if obj1 is None or obj2 is None:
                obj1, obj2 = kg_l.attribute_dict_by_name.get(name1), kg_r.attribute_dict_by_name.get(name2)
            return obj1, obj2

        with open(path, "r", encoding="utf8") as f:
            for line in f.readlines():
                if len(line.strip()) == 0:
                    continue
                params = line.strip().split("\t")
                assert len(params) == 4
                prefix, name_l, name_r, prob = params[0].strip(), params[1].strip(), params[2].strip(), float(
                    params[3].strip())
                if prefix == "L":
                    obj_l, obj_r = get_obj_by_name(self.kgs.kg_l, self.kgs.kg_r, name_l, name_r)
                else:
                    obj_l, obj_r = get_obj_by_name(self.kgs.kg_r, self.kgs.kg_l, name_l, name_r)
                assert (obj_l is not None and obj_r is not None)
                if obj_l.is_entity():
                    idx_l = obj_l.id
                    if prefix == "L":
                        self.kgs.sub_ent_match[idx_l], self.kgs.sub_ent_prob[idx_l] = obj_r.id, prob
                    else:
                        self.kgs.sup_ent_match[idx_l], self.kgs.sup_ent_prob[idx_l] = obj_r.id, prob
                else:
                    if prefix == "L":
                        self.__params_loader_helper(self.kgs.rel_align_dict_l, obj_l.id, obj_r.id, prob)
                    else:
                        self.__params_loader_helper(self.kgs.rel_align_dict_r, obj_l.id, obj_r.id, prob)
        return

    def load_ent_links(self, path, func=None, num=None, init_value=None, threshold_min=0.0, threshold_max=1.0,
                       force=False):
        ent_link_list = list()
        with open(path, "r", encoding="utf8") as f:
            for line in f.readlines():
                line = line.strip()
                if len(line) == 0:
                    continue
                params = line.split(sep="\t")
                name_l, name_r = params[0].strip(), params[1].strip()
                obj_l, obj_r = self.kgs.kg_l.get_object_by_name(name_l), self.kgs.kg_r.get_object_by_name(name_r)
                if obj_l is None or obj_r is None:
                    continue
                if init_value is None:
                    if len(params) == 3:
                        prob = float(params[2].strip())
                    else:
                        prob = 1.0
                else:
                    prob = init_value
                if prob < threshold_min or prob > threshold_max:
                    continue
                if func is not None:
                    prob = func(prob)
                ent_link_list.append((obj_l, obj_r, prob))
        random_list = random.choices(ent_link_list, k=num) if num is not None else ent_link_list
        change_num = 0
        for (obj_l, obj_r, prob) in random_list:
            success = self.__set_counterpart_and_prob(obj_l, obj_r, prob, force)
            success &= self.__set_counterpart_and_prob(obj_r, obj_l, prob, force)
            change_num += 1 if success else 0
        print("load num: " + str(len(random_list)) + "\t change num: " + str(change_num))

    def reset_ent_align_prob(self, func):
        for ent in self.kgs.kg_l.entity_set:
            idx = ent.id
            self.kgs.sub_ent_prob[idx] = func(self.kgs.sub_ent_prob[idx])
        for ent in self.kgs.kg_r.entity_set:
            idx = ent.id
            self.kgs.sup_ent_prob[idx] = func(self.kgs.sup_ent_prob[idx])

    def load_embedding(self, ent_emb_path, kg_l_mapping, kg_r_mapping):
        ent_emb = np.load(ent_emb_path)

        def load_emb_helper(kg, mapping_path):
            with open(mapping_path, "r", encoding="utf8") as f:
                for line in f.readlines():
                    if len(line.strip()) == 0:
                        continue
                    params = line.strip().split("\t")
                    ent_name, idx = params[0].strip(), int(params[1].strip())
                    ent = kg.entity_dict_by_name.get(ent_name)
                    if ent is not None:
                        ent.embedding = ent_emb[idx, :]

        load_emb_helper(self.kgs.kg_l, kg_l_mapping)
        load_emb_helper(self.kgs.kg_r, kg_r_mapping)
        self.kgs.kg_l.init_ent_embeddings()
        self.kgs.kg_r.init_ent_embeddings()

    @staticmethod
    def __result_writer(path, result_dict, title):
        with open(path, "a+", encoding="utf-8") as f:
            f.write("--- " + title + " ---\n\n")
            for ((obj_l, obj_r), prob_set) in result_dict.items():
                f.write(obj_l.name + "\t" + obj_r.name + "\t" + "\t".join(format(s, ".6f") for s in prob_set) + "\n")
            f.write("\n")

    @staticmethod
    def __params_loader_helper(dict_by_key: dict, key1, key2, value):
        if not dict_by_key.__contains__(key1):
            dict_by_key[key1] = dict()
        dict_by_key[key1][key2] = value

In [20]:
def construct_kg(path_r, path_a=None, sep='\t', name=None):
    """Build a KG from triple files, initialise it and print its statistics.

    Two input layouts are supported:

    * ``path_a`` given: ``path_r`` holds relation triples and ``path_a`` attribute
      triples, one ``sep``-separated triple per line. Blank lines are skipped;
      lines that do not split into exactly three fields are printed and skipped.
    * ``path_a`` omitted: ``path_r`` holds both kinds of triples. A line that does
      not split into exactly three fields is appended to the previous record (so
      a value may span several lines). A record whose third field contains
      ``"http"`` is inserted as a relation triple, otherwise as an attribute triple.

    :param path_r: relation triple file, or the combined file when ``path_a`` is None
    :param path_a: attribute triple file, or None
    :param sep: field separator
    :param name: name passed to ``KG``
    :return: the initialised KG
    """
    kg = KG(name=name)

    def read_triples(path, insert):
        # one well-formed triple per line; anything else is echoed and ignored
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                record = line.strip()
                if len(record) == 0:
                    continue
                params = record.split(sep=sep)
                if len(params) != 3:
                    print(line)
                    continue
                insert(params[0].strip(), params[1].strip(), params[2].strip())

    def insert_buffered(record):
        # insert one buffered record of the combined-file layout
        fields = record.strip().split(sep)
        if len(fields) < 3:
            print("Exception: " + record.strip())
            return
        e, a, v = fields[0].strip(), fields[1].strip(), fields[2].strip()
        if len(e) == 0 or len(a) == 0 or len(v) == 0:
            print("Exception: " + e)
            return
        if "http" in v:
            kg.insert_relation_tuple(e, a, v)
        else:
            kg.insert_attribute_tuple(e, a, v)

    if path_a is not None:
        read_triples(path_r, kg.insert_relation_tuple)
        read_triples(path_a, kg.insert_attribute_tuple)
    else:
        with open(path_r, "r", encoding="utf-8") as f:
            prev_line = ""
            for line in f:
                params = line.strip().split(sep)
                if len(params) != 3 or len(prev_line) == 0:
                    # first record, or a continuation of the buffered record
                    prev_line += "\n" if len(line.strip()) == 0 else line.strip()
                    continue
                # a new well-formed line closes the buffered record
                insert_buffered(prev_line)
                prev_line = line
        if len(prev_line.strip()) > 0:
            # the record still buffered at end of file was previously dropped
            insert_buffered(prev_line)
    kg.init()
    kg.print_kg_info()
    return kg

In [21]:
def construct_kgs(dataset_dir, name="KGs", load_chk=None):
    """Build both knowledge graphs of a dataset and wrap them in a ``KGs`` object.

    :param dataset_dir: directory holding ``rel_triples_1``, ``attr_triples_1``,
        ``rel_triples_2`` and ``attr_triples_2``
    :param name: base name; the two graphs are named ``<name>-KG1`` and ``<name>-KG2``
    :param load_chk: optional check point path passed to ``kgs.util.load_params``
    :return: the ``KGs`` object
    """
    triple_files = (
        ("rel_triples_1", "attr_triples_1", "-KG1"),
        ("rel_triples_2", "attr_triples_2", "-KG2"),
    )
    kg1, kg2 = [
        construct_kg(os.path.join(dataset_dir, rel_file), os.path.join(dataset_dir, attr_file),
                     name=str(name + tag))
        for rel_file, attr_file, tag in triple_files
    ]
    kgs = KGs(kg1=kg1, kg2=kg2)
    if load_chk is None:
        return kgs
    # restore a previously saved PRASE model
    kgs.util.load_params(load_chk)
    return kgs

In [22]:
# the balancing function for PRASE
def fusion_func(prob, x, y):
    """Blend a mapping probability with the cosine similarity of two embeddings.

    Returns ``0.8 * prob + 0.2 * cos(x, y)``. When either vector has zero norm the
    cosine is undefined; the similarity term is then left out (``0.8 * prob``)
    instead of producing NaN.

    :param prob: mapping probability
    :param x: embedding vector of the first entity
    :param y: embedding vector of the second entity
    :return: the fused score
    """
    norm_product = np.linalg.norm(x) * np.linalg.norm(y)
    if norm_product == 0:
        return 0.8 * prob
    return 0.8 * prob + 0.2 * np.dot(x, y) / norm_product


def run_init_iteration(kgs, ground_truth_path=None):
    """Run ``kgs`` once without loading any feedback first.

    :param kgs: object exposing ``run(test_path=...)``
    :param ground_truth_path: forwarded to ``kgs.run`` as ``test_path``
    """
    run_options = {"test_path": ground_truth_path}
    kgs.run(**run_options)


def run_prase_iteration(kgs, embed_dir, ground_truth_path=None, load_weight=1.0, reset_weight=1.0, load_ent=True,
                        load_emb=True,
                        init_reset=False, prase_func=None):
    """Run one PRASE iteration: load feedback from the embedding module, then run ``kgs``.

    :param kgs: the KGs object to update and run
    :param embed_dir: directory holding ``alignment_results_12``, ``kg1_ent_ids``,
        ``kg2_ent_ids`` and ``ent_embeds.npy``
    :param ground_truth_path: forwarded to ``kgs.run`` as ``test_path``
    :param load_weight: factor applied to the probabilities read from ``alignment_results_12``
    :param reset_weight: factor applied to the stored probabilities when ``init_reset`` is True
    :param load_ent: only the value ``True`` enables the mapping feedback
    :param load_emb: only the value ``True`` enables the embedding feedback
    :param init_reset: only the value ``True`` enables the initial rescaling
    :param prase_func: fusion function handed to ``kgs.set_fusion_func``
    """

    def scaled_by(weight):
        return lambda x: weight * x

    def in_embed_dir(file_name):
        return os.path.join(embed_dir, file_name)

    if init_reset is True:
        # reset_weight: scale the mapping probabilities already stored in kgs
        kgs.util.reset_ent_align_prob(scaled_by(reset_weight))

    # mapping feedback
    if load_ent is True:
        # load_weight: scale the mapping probabilities predicted by the embedding module
        kgs.util.load_ent_links(func=scaled_by(load_weight), path=in_embed_dir("alignment_results_12"),
                                force=True)

    # embedding feedback
    if load_emb is True:
        mapping_l, mapping_r = in_embed_dir("kg1_ent_ids"), in_embed_dir("kg2_ent_ids")
        kgs.util.load_embedding(in_embed_dir("ent_embeds.npy"), mapping_l, mapping_r)

    # the fusion function balances the stored probability against the embedding similarity
    kgs.set_fusion_func(prase_func)
    kgs.run(test_path=ground_truth_path)

In [44]:
# Hand-set replacement for `__file__`: the main cell takes the parent directory of this
# path as the folder that contains the dataset directory.
# NOTE(review): presumably needed because `__file__` is not defined in a notebook kernel — confirm.
__file__ = '/kaggle/input/d-w-15k'   # Changed here

In [47]:
if __name__ == '__main__':
    # the parent directory of the (hand-set) `__file__` is the folder that holds the dataset
    base, _ = os.path.split(os.path.abspath(__file__))
    dataset_name = "d-w-15k"
    # name of the embedding module whose output directory provides the feedback files
    # (alternative: embed_module_name = "MultiKE")
    embed_module_name = "BootEA"

    dataset_path = os.path.join(base, dataset_name)
    embed_output_path = os.path.join(dataset_path, embed_module_name)

    print("Construct KGs...")
    # Build the KGs object from the relation and attribute triple files.
    # Pass load_chk to start from a saved PARIS check point instead; the authors note that
    # no check point file is shipped because of its size, and that running without one
    # may even give better results than those reported in the paper.
    kgs = construct_kgs(dataset_dir=dataset_path, name=dataset_name, load_chk=None)

    # number of worker processes
    kgs.set_worker_num(6)

    # number of PARIS iterations
    kgs.set_iteration(10)

    # ground truth mapping, used to report metrics while iterating
    ground_truth_mapping_path = os.path.join(dataset_path, "ent_links")

    # To only evaluate the current state:
    # kgs.util.test(path=ground_truth_mapping_path, threshold=0.1)

    # To run the initial iteration of PRASE (i.e. PARIS without any feedback):
    # run_init_iteration(kgs=kgs, ground_truth_path=ground_truth_mapping_path)

    # run PRASE with both the embedding and the mapping feedback
    run_prase_iteration(kgs, embed_dir=embed_output_path, prase_func=fusion_func,
                        ground_truth_path=ground_truth_mapping_path)

    # everything below stores mappings and check point files under the working directory
    save_dir_name = "output"
    base1 = "/kaggle/working/"
    save_dir_path = os.path.join(base1, save_dir_name, dataset_name)
    if not os.path.exists(save_dir_path):
        os.makedirs(save_dir_path)

    # one time stamp shared by all artefacts of this run
    time_stamp = time.strftime("%Y-%m-%d-%H-%M-%S")

    # check point
    check_point_dir = os.path.join(save_dir_path, "chk")
    check_point_name = "PRASE-" + embed_module_name + "@" + time_stamp
    check_point_file = os.path.join(check_point_dir, check_point_name)
    kgs.util.save_params(check_point_file)

    # mapping result
    result_dir = os.path.join(save_dir_path, "mapping")
    result_file_name = check_point_name + ".txt"
    result_file = os.path.join(result_dir, result_file_name)
    kgs.util.save_results(result_file)

    # input files (training data) for the embedding module
    input_base = os.path.join(save_dir_path, "embed_input")
    input_dir_name = check_point_name
    input_dir = os.path.join(input_base, input_dir_name)
    kgs.util.generate_input_for_embed_align(link_path=ground_truth_mapping_path, save_dir=input_dir, threshold=0.1)

Construct KGs...

Information of Knowledge Graph (d-w-15k-KG1):
- Relation Tuple Number: 73983
- Attribute Tuple Number: 66813
- Entity Number: 15000
- Relation Number: 167
- Attribute Number: 175
- Literal Number: 40614
- Functionality Statistics:
--- TOP-10 Relations (Func) ---
Name: http://dbpedia.org/ontology/training	Func: 1.0
Name: http://dbpedia.org/ontology/training-(INV)	Func: 1.0
Name: http://dbpedia.org/ontology/sourceConfluenceRegion	Func: 1.0
Name: http://dbpedia.org/ontology/ethnicity	Func: 1.0
Name: http://dbpedia.org/ontology/sourceConfluenceRegion-(INV)	Func: 1.0
Name: http://dbpedia.org/ontology/league	Func: 1.0
Name: http://dbpedia.org/ontology/photographer	Func: 1.0
Name: http://dbpedia.org/ontology/mainInterest	Func: 1.0
Name: http://dbpedia.org/ontology/photographer-(INV)	Func: 1.0
Name: http://dbpedia.org/ontology/mainInterest-(INV)	Func: 1.0
......
--- TOP-10 Relations (Func-Inv) ---
Name: http://dbpedia.org/ontology/series-(INV)	Func-Inv: 1.0
Name: http://dbped