# models

In [None]:
import math
from abc import ABC, abstractmethod
from typing import Tuple, List, Dict

import numpy as np
import torch
from torch import nn


class TKBCModel(nn.Module, ABC):
    """Abstract base class providing evaluation loops for temporal KB completion models.

    Subclasses supply the model-specific pieces through the abstract methods
    below and must set ``self.sizes``; ``self.sizes[2]`` is used here as the
    number of candidate right-hand-side entities.
    """

    @abstractmethod
    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """Return the candidate matrix for the chunk; it is right-multiplied onto query vectors."""
        pass

    @abstractmethod
    def get_queries(self, queries: torch.Tensor):
        """Return the query representation(s) for a batch of queries.

        ``get_ranking`` indexes the result as ``q[0]`` / ``q[1]``, whereas
        ``get_time_ranking`` multiplies the result directly.
        """
        pass

    @abstractmethod
    def get_rhs_static(self, chunk_begin: int, chunk_size: int):
        """Return the static candidate matrix for the chunk (same chunking as ``get_rhs``)."""
        pass

    @abstractmethod
    def score(self, x: torch.Tensor):
        """Return the score(s) of the true facts in ``x``.

        ``get_ranking`` unpacks the result into two values.
        """
        pass

    @abstractmethod
    def forward_over_time(self, x: torch.Tensor):
        """Return scores of each fact in ``x`` with one column per timestamp (see ``get_auc``)."""
        pass

    def get_ranking(
            self, queries: torch.Tensor,
            filters: Dict[Tuple[int, int, int], List[int]],
            batch_size: int = 1000, chunk_size: int = -1
    ):
        """
        Returns filtered ranking for each queries.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: filters[(lhs, rel, ts)] gives the elements to filter from ranking;
            this mapping is only read, never modified
        :param batch_size: maximum number of queries processed at once
        :param chunk_size: maximum number of candidates processed at once
            (a negative value means all candidates in one chunk)
        :return: a float tensor with one rank per query; a rank starts at 1 and is
            increased by one for every candidate whose temporal score is >= the
            temporal target AND whose static score is > the static target
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            c_begin = 0
            while c_begin < self.sizes[2]:
                b_begin = 0
                # Both candidate matrices depend only on the chunk, not on the batch.
                rhs = self.get_rhs(c_begin, chunk_size)
                rhs_static = self.get_rhs_static(c_begin, chunk_size)
                while b_begin < len(queries):
                    these_queries = queries[b_begin:b_begin + batch_size]
                    q = self.get_queries(these_queries)

                    scores_tem = q[0] @ rhs
                    scores_cs = q[1] @ rhs_static
                    targets_tem, targets_cs = self.score(these_queries)

                    # set filtered and true scores to -1e6 to be ignored
                    # take care that scores are chunked
                    for i, query in enumerate(these_queries):
                        key = (query[0].item(), query[1].item(), query[3].item())
                        # Build a fresh list: extending the stored list in place
                        # would grow the caller's filters on every chunk and call.
                        filter_out = list(filters[key]) + [query[2].item()]
                        if chunk_size < self.sizes[2]:
                            filter_in_chunk = [
                                int(x - c_begin) for x in filter_out
                                if c_begin <= x < c_begin + chunk_size
                            ]
                            scores_tem[i, torch.LongTensor(filter_in_chunk)] = -1e6
                        else:
                            scores_tem[i, torch.LongTensor(filter_out)] = -1e6
                    ranks[b_begin:b_begin + batch_size] += torch.sum(
                        (torch.mul(scores_tem >= targets_tem, scores_cs > targets_cs)).float(), dim=1
                    ).cpu()

                    b_begin += batch_size

                c_begin += chunk_size
        return ranks

    def get_auc(
            self, queries: torch.Tensor, batch_size: int = 1000
    ):
        """
        Returns per-timestamp truth labels and scores for interval queries.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, begin, end)
        :param batch_size: maximum number of queries processed at once
        :return: (truth, scores) numpy arrays with one row per query and one column
            per timestamp; truth[i, t] is True when begin_i <= t <= end_i
        """
        all_scores, all_truth = [], []
        all_ts_ids = None
        with torch.no_grad():
            b_begin = 0
            while b_begin < len(queries):
                these_queries = queries[b_begin:b_begin + batch_size]
                scores = self.forward_over_time(these_queries)
                all_scores.append(scores.cpu().numpy())
                if all_ts_ids is None:
                    # Created on the queries' device (it is compared against them
                    # below) instead of unconditionally on CUDA.
                    all_ts_ids = torch.arange(
                        0, scores.shape[1], device=these_queries.device
                    )[None, :]
                assert not torch.any(torch.isinf(scores) | torch.isnan(scores)), "inf or nan scores"
                truth = (all_ts_ids <= these_queries[:, 4][:, None]) * (all_ts_ids >= these_queries[:, 3][:, None])
                all_truth.append(truth.cpu().numpy())
                b_begin += batch_size

        return np.concatenate(all_truth), np.concatenate(all_scores)

    def get_time_ranking(
            self, queries: torch.Tensor, filters: List[List[int]], chunk_size: int = -1
    ):
        """
        Returns filtered ranking for a batch of queries ordered by timestamp.
        :param queries: a torch.LongTensor of quadruples (lhs, rel, rhs, timestamp)
        :param filters: ordered filters
        :param chunk_size: maximum number of candidates processed at once
        :return: a float tensor with one rank per query

        NOTE(review): this method uses ``get_queries``/``score`` results as single
        tensors (``q @ rhs``, ``scores >= targets``), unlike ``get_ranking`` which
        treats them as pairs - confirm the subclass in use supports this.
        """
        if chunk_size < 0:
            chunk_size = self.sizes[2]
        ranks = torch.ones(len(queries))
        with torch.no_grad():
            c_begin = 0
            q = self.get_queries(queries)
            targets = self.score(queries)
            while c_begin < self.sizes[2]:
                rhs = self.get_rhs(c_begin, chunk_size)
                scores = q @ rhs
                # set filtered and true scores to -1e6 to be ignored
                # take care that scores are chunked
                for i, (query, query_filter) in enumerate(zip(queries, filters)):
                    filter_out = query_filter + [query[2].item()]
                    if chunk_size < self.sizes[2]:
                        filter_in_chunk = [
                            int(x - c_begin) for x in filter_out
                            if c_begin <= x < c_begin + chunk_size
                        ]
                        max_to_filter = max(filter_in_chunk + [-1])
                        assert max_to_filter < scores.shape[1], (
                            f"filter index {max_to_filter} is out of range for a "
                            f"chunk with {scores.shape[1]} candidates"
                        )
                        scores[i, filter_in_chunk] = -1e6
                    else:
                        scores[i, filter_out] = -1e6
                ranks += torch.sum(
                    (scores >= targets).float(), dim=1
                ).cpu()

                c_begin += chunk_size
        return ranks


class LCGE(TKBCModel):
    """Temporal KB completion model combining a time-aware and a static score.

    Every embedding row is split into two halves that are combined with
    complex-style products (first half used as real part, second half as
    imaginary part). Looked-up embeddings are passed through multi-head
    self-attention before scoring, and ``forward`` additionally returns a
    rule-based penalty computed from the rule dictionaries given in ``Rules``.
    """

    def __init__(
            self, sizes: Tuple[int, int, int, int], rank: int, Rules, w_static,
            no_time_emb=False, init_size: float = 1e-3
    ):
        """
        :param sizes: table sizes; sizes[0] entities, sizes[1] relations,
            sizes[3] timestamps (sizes[2] is used by the ranking loops as the
            number of candidates)
        :param rank: half-width of the temporal embeddings (rows have 2 * rank values)
        :param Rules: sequence of six rule dictionaries, stored as rule1_p1,
            rule1_p2, rule2_p1, rule2_p2, rule2_p3, rule2_p4
        :param w_static: stored as ``self.w_static``; not read inside this class
        :param no_time_emb: if True, ``forward`` returns the time table without its last row
        :param init_size: factor applied to every freshly initialised embedding table
        """
        super(LCGE, self).__init__()
        self.sizes = sizes
        self.rank = rank
        # NOTE(review): this is 0 for rank < 20, which gives zero-width static
        # embeddings - confirm callers always pass rank >= 20.
        self.rank_static = rank // 20
        self.w_static = w_static

        # [0] entities, [1] relations, [2] timestamps,
        # [3] time-independent relations, [4] a single "time transition" row
        self.embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * rank, sparse=True)
            for s in [sizes[0], sizes[1], sizes[3], sizes[1], 1]
        ])
        # [0] static entities, [1] static relations
        self.static_embeddings = nn.ModuleList([
            nn.Embedding(s, 2 * self.rank_static, sparse=True)
            for s in [sizes[0], sizes[1]]
        ])

        # Multi-head self-attention for the temporal and the static embeddings.
        self.self_attention = nn.MultiheadAttention(embed_dim=2 * rank, num_heads=2)
        self.self_attention_s = nn.MultiheadAttention(embed_dim=2 * self.rank_static, num_heads=2)

        for table in list(self.embeddings) + list(self.static_embeddings):
            table.weight.data *= init_size

        self.no_time_emb = no_time_emb
        self.rule1_p1, self.rule1_p2, self.rule2_p1, self.rule2_p2, self.rule2_p3, self.rule2_p4 = Rules

    @staticmethod
    def has_time():
        """Always True for this model."""
        return True

    @staticmethod
    def _halves(tensor, width):
        """Split the last dimension into its first ``width`` entries and the rest."""
        return tensor[..., :width], tensor[..., width:]

    @staticmethod
    def _cmul(a, b):
        """Complex-style product of two (first, second) pairs."""
        return a[0] * b[0] - a[1] * b[1], a[1] * b[0] + a[0] * b[1]

    @staticmethod
    def _attend(attention, parts):
        """Concatenate ``parts`` row-wise, self-attend over them, split back to the same sizes."""
        stacked = torch.cat(parts, dim=0)
        attended, _ = attention(stacked, stacked, stacked)
        return torch.split(attended, [part.size(0) for part in parts], dim=0)

    @staticmethod
    def _cubed_gap(head, body):
        """Sum of cubed absolute differences over both halves."""
        return (torch.sum(torch.abs(head[0] - body[0]) ** 3)
                + torch.sum(torch.abs(head[1] - body[1]) ** 3))

    def _rel_no_time_vec(self, index, device):
        """Look up one row of the time-independent relation table as a 1-D vector."""
        idx = torch.tensor([int(index)], dtype=torch.long, device=device)
        return self.embeddings[3](idx)[0]

    def _temporal_parts(self, x):
        """Look up and self-attend the (lhs, rel, rhs, time) embeddings of ``x``."""
        lhs = self.embeddings[0](x[:, 0])
        rel = self.embeddings[1](x[:, 1])
        rhs = self.embeddings[0](x[:, 2])
        time = self.embeddings[2](x[:, 3])
        return self._attend(self.self_attention, (lhs, rel, rhs, time))

    def _static_parts(self, x):
        """Look up and self-attend the static (head, relation, tail) embeddings of ``x``."""
        h_static = self.static_embeddings[0](x[:, 0])
        r_static = self.static_embeddings[1](x[:, 1])
        t_static = self.static_embeddings[0](x[:, 2])
        return self._attend(self.self_attention_s, (h_static, r_static, t_static))

    def _add_rule2_penalty(self, x, rules, compose, rule, rule_num):
        """Accumulate the penalty of one two-body rule table into (rule, rule_num).

        ``compose(rel2, rel3)`` builds the pair that the head relation is compared with.
        """
        for rel_1 in x[:, 1]:
            # NOTE(review): rel_1 is a 0-dim tensor and tensors hash by identity,
            # so this test only succeeds if the table is keyed by these very
            # tensor objects. If the tables are keyed by ints this branch never
            # runs and rel_1.item() was probably intended - confirm against the
            # code that builds Rules before changing it.
            if rel_1 in rules:
                head = self._halves(self.embeddings[3](rel_1), self.rank)
                for body in rules[rel_1]:
                    rel_2, rel_3 = body
                    weight_r = rules[rel_1][body]
                    rel2 = self._halves(self._rel_no_time_vec(rel_2, x.device), self.rank)
                    rel3 = self._halves(self._rel_no_time_vec(rel_3, x.device), self.rank)
                    rule += weight_r * self._cubed_gap(head, compose(rel2, rel3))
                    rule_num += 1
        return rule, rule_num

    def score(self, x):
        """Return (temporal score, static score) of the facts in ``x``, each of shape (batch, 1)."""
        rel_no_time = self.embeddings[3](x[:, 1])
        lhs, rel, rhs, time = self._temporal_parts(x)

        lhs = self._halves(lhs, self.rank)
        rel = self._halves(rel, self.rank)
        rhs = self._halves(rhs, self.rank)
        time = self._halves(time, self.rank)
        rnt = self._halves(rel_no_time, self.rank)

        # Time-modulated relation plus its time-independent counterpart.
        rrt = self._cmul(rel, time)
        full_rel = rrt[0] + rnt[0], rrt[1] + rnt[1]

        h_static, r_static, t_static = self._static_parts(x)
        h_static = self._halves(h_static, self.rank_static)
        r_static = self._halves(r_static, self.rank_static)
        t_static = self._halves(t_static, self.rank_static)

        temporal_q = self._cmul(lhs, full_rel)
        static_q = self._cmul(h_static, r_static)
        return torch.sum(
            temporal_q[0] * rhs[0] + temporal_q[1] * rhs[1],
            1, keepdim=True
        ), torch.sum(
            static_q[0] * t_static[0] + static_q[1] * t_static[1],
            1, keepdim=True
        )

    def forward(self, x):
        """Score every entity as right-hand side for the facts in ``x``.

        :param x: LongTensor whose columns 0..3 are (lhs, rel, rhs, timestamp)
        :return: a 5-tuple
            (temporal scores against all entity rows,
             static scores against all static entity rows,
             tuple of per-factor magnitudes for regularisation,
             timestamp table (without its last row when ``no_time_emb``),
             rule penalty averaged over the rule instances that applied;
             a zero scalar tensor when none applied)
        """
        device = x.device
        rel_no_time = self.embeddings[3](x[:, 1])
        # Index tensors follow the input's device instead of assuming CUDA.
        transt = self.embeddings[4](torch.zeros(1, dtype=torch.long, device=device))
        lhs, rel, rhs, time = self._temporal_parts(x)

        # Split every embedding into its two halves, used as real and
        # imaginary parts in the products below.
        lhs = self._halves(lhs, self.rank)
        rel = self._halves(rel, self.rank)
        rhs = self._halves(rhs, self.rank)
        time = self._halves(time, self.rank)
        transt = self._halves(transt, self.rank)
        rnt = self._halves(rel_no_time, self.rank)

        right = self._halves(self.embeddings[0].weight, self.rank)

        rrt = self._cmul(rel, time)
        full_rel = rrt[0] + rnt[0], rrt[1] + rnt[1]

        # Static (time-free) embeddings.
        h_static, r_static, t_static = self._static_parts(x)
        h_static = self._halves(h_static, self.rank_static)
        r_static = self._halves(r_static, self.rank_static)
        t_static = self._halves(t_static, self.rank_static)
        right_static = self._halves(self.static_embeddings[0].weight, self.rank_static)

        regularizer = (
            math.pow(2, 1 / 3) * torch.sqrt(lhs[0] ** 2 + lhs[1] ** 2),
            torch.sqrt(rrt[0] ** 2 + rrt[1] ** 2),
            torch.sqrt(rnt[0] ** 2 + rnt[1] ** 2),
            math.pow(2, 1 / 3) * torch.sqrt(rhs[0] ** 2 + rhs[1] ** 2),
            torch.sqrt(h_static[0] ** 2 + h_static[1] ** 2),
            torch.sqrt(r_static[0] ** 2 + r_static[1] ** 2),
            torch.sqrt(t_static[0] ** 2 + t_static[1] ** 2)
        )

        # The single time-transition row as a (first half, second half) pair.
        shift = transt[0][0], transt[1][0]

        rule = 0.
        rule_num = 0

        # Single-body rules, compared on the raw time-independent embeddings.
        # NOTE(review): this loop and the next both read rule1_p2 while
        # rule1_p1 is never read in this class - confirm that is intended.
        for rel_1 in x[:, 1]:
            rel_1_str = str(rel_1.item())
            if rel_1_str in self.rule1_p2:
                rel1_emb = self.embeddings[3](rel_1)
                for rel_2 in self.rule1_p2[rel_1_str]:
                    weight_r = self.rule1_p2[rel_1_str][rel_2]
                    rel2_emb = self._rel_no_time_vec(rel_2, device)
                    rule += weight_r * torch.sum(torch.abs(rel1_emb - rel2_emb) ** 3)
                    rule_num += 1

        # Single-body rules, body relation multiplied by the time transition.
        for rel_1 in x[:, 1]:
            rel_1_str = str(rel_1.item())
            if rel_1_str in self.rule1_p2:
                head = self._halves(self.embeddings[3](rel_1), self.rank)
                for rel_2 in self.rule1_p2[rel_1_str]:
                    weight_r = self.rule1_p2[rel_1_str][rel_2]
                    rel2 = self._halves(self._rel_no_time_vec(rel_2, device), self.rank)
                    rule += weight_r * self._cubed_gap(head, self._cmul(rel2, shift))
                    rule_num += 1

        cmul = self._cmul
        # Two-body rules; the four tables differ only in how often the time
        # transition is applied to each body relation before composing them.
        rule, rule_num = self._add_rule2_penalty(
            x, self.rule2_p1,
            lambda r2, r3: cmul(cmul(r3, shift), cmul(cmul(r2, shift), shift)),
            rule, rule_num)
        rule, rule_num = self._add_rule2_penalty(
            x, self.rule2_p2,
            lambda r2, r3: cmul(cmul(r3, shift), cmul(r2, shift)),
            rule, rule_num)
        rule, rule_num = self._add_rule2_penalty(
            x, self.rule2_p3,
            lambda r2, r3: cmul(r3, cmul(r2, shift)),
            rule, rule_num)
        rule, rule_num = self._add_rule2_penalty(
            x, self.rule2_p4,
            lambda r2, r3: cmul(r3, r2),
            rule, rule_num)

        if rule_num > 0:
            rule = rule / rule_num
        else:
            # No rule applied to this batch: dividing would raise
            # ZeroDivisionError, so report a zero penalty instead.
            rule = lhs[0].new_zeros(())

        temporal_q = self._cmul(lhs, full_rel)
        static_q = self._cmul(h_static, r_static)
        return (
            temporal_q[0] @ right[0].t() + temporal_q[1] @ right[1].t(),
            static_q[0] @ right_static[0].t() + static_q[1] @ right_static[1].t(),
            regularizer,
            self.embeddings[2].weight[:-1] if self.no_time_emb else self.embeddings[2].weight,
            rule
        )

    def forward_over_time(self, x):
        """Score the facts in ``x`` against every timestamp; returns (batch, n_timestamps).

        Uses the raw embedding rows (no self-attention is applied here).
        """
        lhs = self._halves(self.embeddings[0](x[:, 0]), self.rank)
        rel = self._halves(self.embeddings[1](x[:, 1]), self.rank)
        rhs = self._halves(self.embeddings[0](x[:, 2]), self.rank)
        time = self._halves(self.embeddings[2].weight, self.rank)
        rnt = self._halves(self.embeddings[3](x[:, 1]), self.rank)

        score_time = (
                (lhs[0] * rel[0] * rhs[0] - lhs[1] * rel[1] * rhs[0] -
                 lhs[1] * rel[0] * rhs[1] + lhs[0] * rel[1] * rhs[1]) @ time[0].t() +
                (lhs[1] * rel[0] * rhs[0] - lhs[0] * rel[1] * rhs[0] +
                 lhs[0] * rel[0] * rhs[1] - lhs[1] * rel[1] * rhs[1]) @ time[1].t()
        )
        base = torch.sum(
            (lhs[0] * rnt[0] * rhs[0] - lhs[1] * rnt[1] * rhs[0] -
             lhs[1] * rnt[0] * rhs[1] + lhs[0] * rnt[1] * rhs[1]) +
            (lhs[1] * rnt[1] * rhs[0] - lhs[0] * rnt[0] * rhs[0] +
             lhs[0] * rnt[1] * rhs[1] - lhs[1] * rnt[0] * rhs[1]),
            dim=1, keepdim=True
        )
        return score_time + base

    def get_rhs(self, chunk_begin: int, chunk_size: int):
        """Return the entity rows of the chunk, transposed to (2 * rank, chunk)."""
        return self.embeddings[0].weight.data[
               chunk_begin:chunk_begin + chunk_size
               ].transpose(0, 1)

    def get_rhs_static(self, chunk_begin: int, chunk_size: int):
        """Return the static entity rows of the chunk, transposed to (2 * rank_static, chunk)."""
        return self.static_embeddings[0].weight.data[
               chunk_begin:chunk_begin + chunk_size
               ].transpose(0, 1)

    def get_queries(self, queries: torch.Tensor):
        """Return (temporal query vectors, static query vectors) for ``queries``.

        Each result concatenates the two halves of the lhs-relation product so
        that it can be multiplied with ``get_rhs`` / ``get_rhs_static``. The
        rhs embeddings take part in the self-attention but are not otherwise used.
        """
        rel_no_time = self.embeddings[3](queries[:, 1])
        lhs, rel, _, time = self._temporal_parts(queries)

        lhs = self._halves(lhs, self.rank)
        rel = self._halves(rel, self.rank)
        time = self._halves(time, self.rank)
        rnt = self._halves(rel_no_time, self.rank)

        rrt = self._cmul(rel, time)
        full_rel = rrt[0] + rnt[0], rrt[1] + rnt[1]

        h_static, r_static, _ = self._static_parts(queries)
        h_static = self._halves(h_static, self.rank_static)
        r_static = self._halves(r_static, self.rank_static)

        return (
            torch.cat(self._cmul(lhs, full_rel), 1),
            torch.cat(self._cmul(h_static, r_static), 1),
        )


# Regularizers_rule

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates.

from abc import ABC, abstractmethod
from typing import Tuple, Optional

import torch
from torch import nn


class Regularizer(nn.Module, ABC):
    """Abstract module interface for penalty terms; subclasses implement ``forward``."""

    @abstractmethod
    def forward(self, factors: Tuple[torch.Tensor]):
        """Compute the penalty for ``factors``. Must be overridden by concrete subclasses."""
        ...

class N3(Regularizer):
    """Weighted sum of cubed absolute values of the factors, averaged over the batch."""

    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight

    def forward(self, factors):
        """Return sum_f(weight * sum(|f| ** 3)) divided by the first factor's row count."""
        penalty = sum(
            self.weight * torch.sum(torch.abs(factor) ** 3)
            for factor in factors
        )
        batch_rows = factors[0].shape[0]
        return penalty / batch_rows


class Lambda3(Regularizer):
    """Penalty on the differences between consecutive rows of a single factor."""

    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight

    def forward(self, factor):
        """Return weight * sum(magnitude ** 3) / (rows - 1) over consecutive-row differences.

        Each difference row is split into two equal halves; the magnitude is
        the element-wise sqrt(first ** 2 + second ** 2).
        """
        steps = factor[1:] - factor[:-1]
        half = steps.shape[1] // 2
        first, second = steps[:, :half], steps[:, half:]
        cubed_magnitude = torch.sqrt(first ** 2 + second ** 2) ** 3
        n_steps = factor.shape[0] - 1
        return self.weight * torch.sum(cubed_magnitude) / n_steps


class RuleSim(Regularizer):
    """Scales an already computed value by a fixed weight."""

    def __init__(self, weight: float):
        super().__init__()
        self.weight = weight

    def forward(self, factors):
        """Return ``weight * factors``."""
        return self.weight * factors

In [None]:
import os
# Current working directory of the process (os.getcwd()), not the location of
# this file. The bare expression below has no effect in a plain script;
# presumably it is there to display the value as notebook cell output.
current_file_path = os.getcwd()
current_file_path

# Datasets_lcge

In [None]:
from pathlib import Path
import pickle
from collections import defaultdict
from typing import Dict, Tuple, List

from sklearn.metrics import average_precision_score

import numpy as np
import torch
import os
# Current working directory of the process (os.getcwd() does not return the
# location of this file).
current_file_path = os.getcwd()

# Path of the "data" folder inside the working directory.
DATA_PATH = os.path.join(current_file_path, 'data')


class TemporalDataset(object):
    def __init__(self, name: str):
        self.root = Path(DATA_PATH) / name
        # 数据，存到字典，值为四元组的id
        self.data = {}
        for f in ['train', 'test', 'valid']:
            in_file = open(str(self.root / (f + '.pickle')), 'rb')
            self.data[f] = pickle.load(in_file)

        maxis = np.max(self.data['train'], axis=0)
        self.n_entities = int(max(maxis[0], maxis[2]) + 1)  # 实体数为主宾语相加
        self.n_predicates = int(maxis[1] + 1)
        self.n_predicates *= 2  # 谓语动词为双向，有部分关系是可逆的
        if maxis.shape[0] > 4:  # 如果为5维，是有时间头，有时间尾，四维则不是
            self.n_timestamps = max(int(maxis[3] + 1), int(maxis[4] + 1))
        else:
            self.n_timestamps = int(maxis[3] + 1)
            
        try:
            # 检测时间戳是否规则，wikidata 数据集不规则，需要二次处理
            inp_f = open(str(self.root / f'ts_diffs.pickle'), 'rb')
            self.time_diffs = torch.from_numpy(pickle.load(inp_f)).cuda().float()
            # print("Assume all timestamps are regularly spaced")
            # self.time_diffs = None
            inp_f.close()
        except OSError:
            print("Assume all timestamps are regularly spaced")
            self.time_diffs = None

        try:
            # 检测数据集是否含有event—list，如果没有不使用时间间隔和事件评估
            e = open(str(self.root / f'event_list_all.pickle'), 'rb')
            self.events = pickle.load(e)
            e.close()

            f = open(str(self.root / f'ts_id'), 'rb')
            dictionary = pickle.load(f)
            f.close()
            self.timestamps = sorted(dictionary.keys())
        except OSError:
            print("Not using time intervals and events eval")
            self.events = None

        if self.events is None:
            # 如果无evert—list，使用to—skip中存储的边，将主谓语进行链接
            inp_f = open(str(self.root / f'to_skip.pickle'), 'rb')
            self.to_skip: Dict[str, Dict[Tuple[int, int, int], List[int]]] = pickle.load(inp_f)
            inp_f.close()



        # If dataset has events, it's wikidata.
        # For any relation that has no beginning & no end:
        # add special beginning = end = no_timestamp, increase n_timestamps by one.

    def has_intervals(self) -> bool:
        """Return True when an event list was loaded, i.e. ``self.events`` is not None."""
        return self.events is not None

    def get_examples(self, split: str):
        """Return the stored facts of one split, exactly as loaded (no copy is made)."""
        # split is one of the keys of self.data: 'train', 'test' or 'valid'
        return self.data[split]

    def get_train(self):
        # 主宾语换位置，边反向，返回值为正向反向堆叠在一块，相当于把图谱所有边拿出来
        copy = np.copy(self.data['train'])
        print("\nexamples:\n", copy.shape)
        tmp = np.copy(copy[:, 0])
        copy[:, 0] = copy[:, 2]
        copy[:, 2] = tmp
        copy[:, 1] += self.n_predicates // 2  # has been multiplied by two.
        return np.vstack((self.data['train'], copy))

    def eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int] = (1, 3, 10)
    ):
        """Filtered link-prediction metrics for one split.

        When the dataset has an event list the call is forwarded to ``time_eval``
        (right-hand side only) and its result is returned unchanged.

        :param model: model providing ``get_ranking``
        :param split: key of ``self.data`` to evaluate
        :param n_queries: if positive, evaluate a random subset of that many facts
        :param missing_eval: 'rhs', 'lhs' or 'both' — which side of the fact is predicted
        :param at: cut-offs for the hits@k scores
        :return: ``(mean_reciprocal_rank, hits_at)``, two dicts keyed by the evaluated side
        """
        if self.events is not None:
            return self.time_eval(model, split, n_queries, 'rhs', at)

        examples = torch.from_numpy(self.get_examples(split).astype('int64')).cuda()
        sides = ['rhs', 'lhs'] if missing_eval == 'both' else [missing_eval]

        mean_reciprocal_rank = {}
        hits_at = {}
        for side in sides:
            if n_queries > 0:
                chosen = torch.randperm(len(examples))[:n_queries]
                queries = examples[chosen]
            else:
                queries = examples.clone()
            if side == 'lhs':
                # Predicting the subject == predicting the object of the inverse relation.
                subjects = torch.clone(queries[:, 0])
                queries[:, 0] = queries[:, 2]
                queries[:, 2] = subjects
                queries[:, 1] += self.n_predicates // 2
            ranks = model.get_ranking(queries, self.to_skip[side], batch_size=500)
            mean_reciprocal_rank[side] = torch.mean(1. / ranks).item()
            hits_at[side] = torch.FloatTensor(
                [torch.mean((ranks <= k).float()).item() for k in at]
            )

        return mean_reciprocal_rank, hits_at

    def time_eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'both',
            at: Tuple[int] = (1, 3, 10)
    ):
        """Timeline-filtered right-hand-side evaluation for datasets with intervals.

        For every fact a timestamp is sampled inside its ``[begin, end]`` interval
        (columns 3 and 4). Queries are processed in chronological order while the
        event list is replayed, so that each query is filtered against the facts
        that are active at its sampled date.

        :param model: model providing ``get_time_ranking``
        :param split: key of ``self.data`` to evaluate
        :param n_queries: if positive, evaluate a random subset of that many facts
        :param missing_eval: must be 'rhs' (anything else fails the assertion,
            including the default value)
        :param at: cut-offs for the hits@k scores
        :return: dict with ``'MRR_<group>'`` floats and ``'hits@_<group>'`` tensors for
            each non-empty group among full_time / only_begin / only_end / no_time / all
        """
        assert missing_eval == 'rhs', "other evals not implemented"
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        time_range = test.float()
        # One uniformly sampled timestamp per fact, rounded to the nearest id.
        sampled_time = (
                torch.rand(time_range.shape[0]) * (time_range[:, 4] - time_range[:, 3]) + time_range[:, 3]
        ).round().long()
        # The last timestamp id / id 0 act as "no end" / "no begin" markers.
        has_end = (time_range[:, 4] != (self.n_timestamps - 1))
        has_start = (time_range[:, 3] > 0)

        # NOTE(review): 'full_time' combines the flags with `+` (logical OR on bool
        # tensors), so it overlaps 'only_begin' and 'only_end'. The other three groups
        # suggest AND was intended — confirm before changing the reported metric.
        masks = {
            'full_time': has_end + has_start,
            'only_begin': has_start * (~has_end),
            'only_end': has_end * (~has_start),
            'no_time': (~has_end) * (~has_start)
        }

        # Row layout: [date, lhs, rel, rhs, full_time, only_begin, only_end, no_time]
        with_time = torch.cat((
            sampled_time.unsqueeze(1),
            time_range[:, 0:3].long(),
            masks['full_time'].long().unsqueeze(1),
            masks['only_begin'].long().unsqueeze(1),
            masks['only_end'].long().unsqueeze(1),
            masks['no_time'].long().unsqueeze(1),
        ), 1)
        # Sorting puts the queries in chronological order (date is the first column).
        eval_events = sorted(with_time.tolist())

        # (lhs, rel) -> {rhs: number of currently open intervals}
        to_filter: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

        id_event = 0
        id_timeline = 0
        batch_size = 100
        to_filter_batch = []
        cur_batch = []

        ranks = {
            'full_time': [], 'only_begin': [], 'only_end': [], 'no_time': [],
            'all': []
        }
        while id_event < len(eval_events):
            # Replay the timeline up to this query's date. The date is column 0 of an
            # eval event (column 3 is the rhs entity, which was compared here before).
            while id_timeline < len(self.events) and self.events[id_timeline][0] <= eval_events[id_event][0]:
                date, event_type, (lhs, rel, rhs) = self.events[id_timeline]
                if event_type < 0:  # begin
                    to_filter[(lhs, rel)][rhs] += 1
                if event_type > 0:  # end
                    to_filter[(lhs, rel)][rhs] -= 1
                    if to_filter[(lhs, rel)][rhs] == 0:
                        del to_filter[(lhs, rel)][rhs]
                id_timeline += 1
            date, lhs, rel, rhs, full_time, only_begin, only_end, no_time = eval_events[id_event]

            to_filter_batch.append(sorted(to_filter[(lhs, rel)].keys()))
            cur_batch.append((lhs, rel, rhs, date, full_time, only_begin, only_end, no_time))
            # once a batch is ready, call get_time_ranking and reset
            if len(cur_batch) == batch_size or id_event == len(eval_events) - 1:
                cuda_batch = torch.cuda.LongTensor(cur_batch)
                bbatch = torch.LongTensor(cur_batch)
                batch_ranks = model.get_time_ranking(cuda_batch[:, :4], to_filter_batch, 500000)

                # Columns 4..7 of the batch hold the group flags.
                ranks['full_time'].append(batch_ranks[bbatch[:, 4] == 1])
                ranks['only_begin'].append(batch_ranks[bbatch[:, 5] == 1])
                ranks['only_end'].append(batch_ranks[bbatch[:, 6] == 1])
                ranks['no_time'].append(batch_ranks[bbatch[:, 7] == 1])

                ranks['all'].append(batch_ranks)
                cur_batch = []
                to_filter_batch = []
            id_event += 1

        # Drop groups that received no batch, then groups that ended up with no rank.
        ranks = {x: torch.cat(ranks[x]) for x in ranks if len(ranks[x]) > 0}
        mean_reciprocal_rank = {x: torch.mean(1. / ranks[x]).item() for x in ranks if len(ranks[x]) > 0}
        hits_at = {
            z: torch.FloatTensor([torch.mean((ranks[z] <= k).float()).item() for k in at])
            for z in ranks if len(ranks[z]) > 0
        }

        res = {
            ('MRR_'+x): y for x, y in mean_reciprocal_rank.items()
        }
        res.update({('hits@_'+x): y for x, y in hits_at.items()})
        return res

    def breakdown_time_eval(
            self, model: TKBCModel, split: str, n_queries: int = -1, missing_eval: str = 'rhs',
    ):
        """Timeline-filtered evaluation broken down by predicate.

        Follows the same procedure as ``time_eval`` (sample a date inside each
        interval, replay the event list chronologically to build the filters) but
        groups the resulting ranks by predicate id.

        Per the original author's note this is not used for ICEWS14.

        :param model: model providing ``get_time_ranking``
        :param split: key of ``self.data`` to evaluate
        :param n_queries: if positive, evaluate a random subset of that many facts
        :param missing_eval: must be 'rhs'
        :return: dict mapping predicate id to the SUM of reciprocal ranks of its queries
        """
        assert missing_eval == 'rhs', "other evals not implemented"
        test = torch.from_numpy(
            self.get_examples(split).astype('int64')
        )
        if n_queries > 0:
            permutation = torch.randperm(len(test))[:n_queries]
            test = test[permutation]

        time_range = test.float()
        # One uniformly sampled timestamp per fact, rounded to the nearest id.
        sampled_time = (
                torch.rand(time_range.shape[0]) * (time_range[:, 4] - time_range[:, 3]) + time_range[:, 3]
        ).round().long()
        has_end = (time_range[:, 4] != (self.n_timestamps - 1))
        has_start = (time_range[:, 3] > 0)

        # NOTE(review): 'full_time' uses `+` (logical OR on bool tensors) and therefore
        # overlaps the 'only_*' groups; the flags are carried along but not used for
        # the per-predicate result below.
        masks = {
            'full_time': has_end + has_start,
            'only_begin': has_start * (~has_end),
            'only_end': has_end * (~has_start),
            'no_time': (~has_end) * (~has_start)
        }

        # Row layout: [date, lhs, rel, rhs, full_time, only_begin, only_end, no_time]
        with_time = torch.cat((
            sampled_time.unsqueeze(1),
            time_range[:, 0:3].long(),
            masks['full_time'].long().unsqueeze(1),
            masks['only_begin'].long().unsqueeze(1),
            masks['only_end'].long().unsqueeze(1),
            masks['no_time'].long().unsqueeze(1),
        ), 1)
        # Sorting puts the queries in chronological order (date is the first column).
        eval_events = sorted(with_time.tolist())

        # (lhs, rel) -> {rhs: number of currently open intervals}
        to_filter: Dict[Tuple[int, int], Dict[int, int]] = defaultdict(lambda: defaultdict(int))

        id_event = 0
        id_timeline = 0
        batch_size = 100
        to_filter_batch = []
        cur_batch = []

        ranks = defaultdict(list)
        while id_event < len(eval_events):
            # Replay the timeline up to this query's date. The date is column 0 of an
            # eval event (column 3 is the rhs entity, which was compared here before).
            while id_timeline < len(self.events) and self.events[id_timeline][0] <= eval_events[id_event][0]:
                date, event_type, (lhs, rel, rhs) = self.events[id_timeline]
                if event_type < 0:  # begin
                    to_filter[(lhs, rel)][rhs] += 1
                if event_type > 0:  # end
                    to_filter[(lhs, rel)][rhs] -= 1
                    if to_filter[(lhs, rel)][rhs] == 0:
                        del to_filter[(lhs, rel)][rhs]
                id_timeline += 1
            date, lhs, rel, rhs, full_time, only_begin, only_end, no_time = eval_events[id_event]

            to_filter_batch.append(sorted(to_filter[(lhs, rel)].keys()))
            cur_batch.append((lhs, rel, rhs, date, full_time, only_begin, only_end, no_time))
            # once a batch is ready, call get_time_ranking and reset
            if len(cur_batch) == batch_size or id_event == len(eval_events) - 1:
                cuda_batch = torch.cuda.LongTensor(cur_batch)
                bbatch = torch.LongTensor(cur_batch)
                batch_ranks = model.get_time_ranking(cuda_batch[:, :4], to_filter_batch, 500000)
                # Column 1 of the batch is the predicate id.
                for rank, predicate in zip(batch_ranks, bbatch[:, 1]):
                    ranks[predicate.item()].append(rank.item())
                cur_batch = []
                to_filter_batch = []
            id_event += 1

        ranks = {x: torch.FloatTensor(ranks[x]) for x in ranks}
        sum_reciprocal_rank = {x: torch.sum(1. / ranks[x]).item() for x in ranks}

        return sum_reciprocal_rank

    def time_AUC(self, model: TKBCModel, split: str, n_queries: int = -1):
        """Score the model's ``get_auc`` output with micro/macro average precision.

        :param model: model providing ``get_auc``, which returns ``(truth, scores)``
        :param split: key of ``self.data`` to evaluate
        :param n_queries: if positive, evaluate a random subset of that many facts
        :return: ``{'micro': ..., 'macro': ...}`` from ``average_precision_score``
        """
        # Higher is better (original note: values near 1 mean an accurate model,
        # 0.5 is equivalent to random guessing).
        examples = torch.from_numpy(self.get_examples(split).astype('int64'))
        if n_queries > 0:
            keep = torch.randperm(len(examples))[:n_queries]
            examples = examples[keep]

        truth, scores = model.get_auc(examples.cuda())

        return {
            averaging: average_precision_score(truth, scores, average=averaging)
            for averaging in ('micro', 'macro')
        }


    def get_shape(self) -> Tuple[int, int, int, int]:
        """Return ``(n_entities, n_predicates, n_entities, n_timestamps)``."""
        # These are the id-space sizes of a quadruple's four slots, not a tensor shape.
        return self.n_entities, self.n_predicates, self.n_entities, self.n_timestamps


# Optimizers

In [None]:
# Copyright (c) Facebook, Inc. and its affiliates.

import torch
import tqdm
from torch import nn
import torch.nn.functional as F
from torch import optim
from torchviz import make_dot


class TKBCOptimizer(object):
    """Runs training epochs for a model that returns temporal and static predictions."""

    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer, temporal_regularizer: Regularizer, rule_regularizer: Regularizer,
            optimizer: optim.Optimizer, batch_size: int = 256,
            verbose: bool = True
    ):
        """
        :param model: model whose ``forward`` returns
            ``(predictions, pred_static, factors, time, rule)``
        :param emb_regularizer: penalty applied to ``factors``
        :param temporal_regularizer: penalty applied to ``time`` when it is not None
        :param rule_regularizer: penalty applied to ``rule``
        :param optimizer: optimizer stepping the model parameters
        :param batch_size: number of examples per step
        :param verbose: show the tqdm progress bar
        """
        self.model = model
        self.emb_regularizer = emb_regularizer
        self.temporal_regularizer = temporal_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose
        self.rule_regularizer = rule_regularizer

        # The first batch seen by this object prints a small diagnostic report.
        self.run_once = True

    @staticmethod
    def _report_first_batch(predictions, truth):
        """Print shapes and top-1 accuracy of one batch (debugging aid only)."""
        print("prediction.shape:", predictions.shape)
        print("truth.shape:", truth.shape)
        with torch.no_grad():
            # Normalise over the candidates (dim=1). The previous dim=0 normalised over
            # the batch, which changes the per-row argmax and made the accuracy wrong.
            probs = F.softmax(predictions, dim=1)
            print(probs.shape)
            max_index = torch.argmax(probs, dim=1)
            num_correct = (max_index == truth).sum().item()
        accuracy = num_correct / len(truth)
        print(f"Accuracy: {accuracy * 100:.2f}%")

    def epoch(self, examples: torch.LongTensor):
        """Train for one pass over ``examples`` in a fresh random order.

        :param examples: 2-D tensor of facts; column 2 is used as the target index
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss = nn.CrossEntropyLoss(reduction='mean')
        loss_static = nn.CrossEntropyLoss(reduction='mean')
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                input_batch = actual_examples[
                              b_begin:b_begin + self.batch_size
                              ].cuda()
                predictions, pred_static, factors, time, rule = self.model.forward(input_batch)
                truth = input_batch[:, 2]

                if self.run_once:
                    self.run_once = False
                    self._report_first_batch(predictions, truth)

                l_fit = loss(predictions, truth)
                l_static = loss_static(pred_static, truth)
                l_reg = self.emb_regularizer.forward(factors)
                l_time = torch.zeros_like(l_reg)
                l_rule = self.rule_regularizer.forward(rule)
                if time is not None:
                    l_time = self.temporal_regularizer.forward(time)
                # NOTE(review): l_rule is reported in the progress bar but is NOT part
                # of the optimised loss — confirm whether that is intentional.
                l = l_fit + 0.1 * l_static + l_reg + l_time

                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()
                b_begin += self.batch_size
                bar.update(input_batch.shape[0])
                bar.set_postfix(
                    loss=f'{l_fit.item():.5f}',
                    loss_cs=f'{l_static.item():.5f}',
                    reg=f'{l_reg.item():.5f}',
                    cont=f'{l_time.item():.5f}',
                    rule=f'{l_rule.item():.5f}'
                )


class IKBCOptimizer(object):
    """Runs training epochs on interval facts (begin/end timestamps in columns 3 and 4)."""

    def __init__(
            self, model: TKBCModel,
            emb_regularizer: Regularizer, temporal_regularizer: Regularizer,
            optimizer: optim.Optimizer, dataset: TemporalDataset, batch_size: int = 256,
            verbose: bool = True
    ):
        """
        :param model: model whose ``forward`` returns ``(predictions, factors, time)``
        :param emb_regularizer: penalty applied to ``factors``
        :param temporal_regularizer: penalty applied to ``time`` when it is not None
        :param optimizer: optimizer stepping the model parameters
        :param dataset: dataset, used for its ``n_timestamps``
        :param batch_size: number of examples per step
        :param verbose: show the tqdm progress bar
        """
        self.model = model
        self.dataset = dataset
        self.emb_regularizer = emb_regularizer
        self.temporal_regularizer = temporal_regularizer
        self.optimizer = optimizer
        self.batch_size = batch_size
        self.verbose = verbose

    def epoch(self, examples: torch.LongTensor):
        """Train for one pass over ``examples`` in a fresh random order.

        Each step combines the right-hand-side prediction loss at a timestamp
        sampled inside the fact's interval with, when the model reports
        ``has_time()``, a cross-entropy loss over timestamps.

        :param examples: 2-D tensor of facts ``(lhs, rel, rhs, begin, end)``
        """
        actual_examples = examples[torch.randperm(examples.shape[0]), :]
        loss = nn.CrossEntropyLoss(reduction='mean')
        with tqdm.tqdm(total=examples.shape[0], unit='ex', disable=not self.verbose) as bar:
            bar.set_description('train loss')
            b_begin = 0
            while b_begin < examples.shape[0]:
                time_range = actual_examples[b_begin:b_begin + self.batch_size].cuda()

                ## RHS Prediction loss
                sampled_time = (
                        torch.rand(time_range.shape[0]).cuda() * (time_range[:, 4] - time_range[:, 3]).float() +
                        time_range[:, 3].float()
                ).round().long()
                with_time = torch.cat((time_range[:, 0:3], sampled_time.unsqueeze(1)), 1)

                predictions, factors, time = self.model.forward(with_time)
                truth = with_time[:, 2]

                l_fit = loss(predictions, truth)

                ## Time prediction loss (ie cross entropy over time)
                time_loss = 0.
                if self.model.has_time():
                    filtering = ~(
                            (time_range[:, 3] == 0) *
                            (time_range[:, 4] == (self.dataset.n_timestamps - 1))
                    )  # NOT no begin and no end
                    these_examples = time_range[filtering, :]
                    # A mean-reduced cross entropy over zero rows is NaN and would
                    # poison the parameter update, so skip the term for such batches.
                    if these_examples.shape[0] > 0:
                        truth = (
                                torch.rand(these_examples.shape[0]).cuda() * (
                                these_examples[:, 4] - these_examples[:, 3]).float() +
                                these_examples[:, 3].float()
                        ).round().long()
                        time_predictions = self.model.forward_over_time(these_examples[:, :3].cuda().long())
                        time_loss = loss(time_predictions, truth.cuda())

                l_reg = self.emb_regularizer.forward(factors)
                l_time = torch.zeros_like(l_reg)
                if time is not None:
                    l_time = self.temporal_regularizer.forward(time)
                l = l_fit + l_reg + l_time + time_loss

                self.optimizer.zero_grad()
                l.backward()
                self.optimizer.step()
                b_begin += self.batch_size
                bar.update(with_time.shape[0])
                # time_loss stays a plain float when the time term was skipped.
                shown_time_loss = time_loss if isinstance(time_loss, float) else time_loss.item()
                bar.set_postfix(
                    loss=f'{l_fit.item():.0f}',
                    loss_time=f'{shown_time_loss:.0f}',
                    reg=f'{l_reg.item():.0f}',
                    cont=f'{l_time.item():.4f}'
                )


# Learner_lcge

In [None]:
import argparse
import json
from typing import Dict

import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

# Command-line parameters. Example invocation:
# python learner_lcge.py --dataset ICEWS14 --model LCGE --rank 2000 (factorization rank)
# --emb_reg 0.005 (embedding regularizer) --time_reg 0.01 (timestamp regularizer)
# --rule_reg 0.01 (rule regularizer) --max_epochs 1000
# --weight_static 0.1 (weight of the static score) --learning_rate 0.1

models = [
    'LCGE'
]

parser = argparse.ArgumentParser(
    description="Logic and Commonsense-Guided Temporal KGE"
)

# (flag, keyword arguments) in the order the options are registered.
_OPTION_SPECS = (
    ('--dataset', dict(type=str, default="ICEWS14", help="Dataset name")),
    ('--model', dict(choices=models, default="LCGE", help="Model in {}".format(models))),
    ('--max_epochs', dict(default=200, type=int, help="Number of epochs.")),
    ('--valid_freq', dict(default=2, type=int, help="Number of epochs between each valid.")),
    ('--rank', dict(default=2000, type=int, help="Factorization rank.")),
    ('--batch_size', dict(default=1000, type=int, help="Batch size.")),
    ('--learning_rate', dict(default=0.1, type=float, help="Learning rate")),
    ('--emb_reg', dict(default=0.000005, type=float, help="Embedding regularizer strength")),
    ('--time_reg', dict(default=0.01, type=float, help="Timestamp regularizer strength")),
    ('--no_time_emb', dict(default=False, action="store_true",
                           help="Use a specific embedding for non temporal relations")),
    ('--rule_reg', dict(default=0.01, type=float, help="Rule regularizer strength")),
    ('--weight_static', dict(default=0.1, type=float, help="Weight of static score")),
)
for _flag, _kwargs in _OPTION_SPECS:
    parser.add_argument(_flag, **_kwargs)

# An empty argv is parsed on purpose: inside a notebook only the defaults apply.
args = vars(parser.parse_args([]))


print("默认参数：", args)


In [None]:
# Holds one parameter dict per run; the training cell iterates over this list.
list_args=[]


In [None]:
# Queue the current parameter dict as a run (re-running this cell appends it again).
list_args.append(args)
print(list_args)

In [None]:
# Bare expression so the notebook displays the selected dataset name.
args['dataset']

In [None]:
import os

dataset = TemporalDataset(args["dataset"])

# All rule files live in one folder; every path is built from the same base
# directory (previously one path used current_file_path and the rest used "./").
rule_dir = os.path.join(current_file_path, "src_data", "rulelearning", args["dataset"])


def _load_rule2(path):
    """Parse a tab-separated rule file into ``{head: {(body1, body2): confidence}}``.

    Each non-blank line must hold four fields: head, body1, body2 (ints) and the
    confidence (float). The file is closed even if a line fails to parse.
    """
    table = {}
    with open(path, 'r') as rule_file:
        for line in rule_file:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at the end of the file).
                continue
            head, body1, body2, confi = line.strip().split("\t")
            table.setdefault(int(head), {})[(int(body1), int(body2))] = float(confi)
    return table


with open(os.path.join(rule_dir, "rule1_p1.json"), 'r') as load_rule1_p1:
    rule1_p1 = json.load(load_rule1_p1)
with open(os.path.join(rule_dir, "rule1_p2.json"), 'r') as load_rule1_p2:
    rule1_p2 = json.load(load_rule1_p2)

rule2_p1 = _load_rule2(os.path.join(rule_dir, "rule2_p1.txt"))
rule2_p2 = _load_rule2(os.path.join(rule_dir, "rule2_p2.txt"))
rule2_p3 = _load_rule2(os.path.join(rule_dir, "rule2_p3.txt"))
rule2_p4 = _load_rule2(os.path.join(rule_dir, "rule2_p4.txt"))

# Order matters: the model receives the rule tables as this 6-tuple.
rules = (rule1_p1, rule1_p2, rule2_p1, rule2_p2, rule2_p3, rule2_p4)

sizes = dataset.get_shape()
print("sizes of dataset is:\t", sizes)

# Look the model up by the --model choice and move it to the GPU in one step.
model = {
    'LCGE': LCGE(sizes, args["rank"], rules, args['weight_static'], no_time_emb=args['no_time_emb']),
}[args['model']].cuda()

opt = optim.Adagrad(model.parameters(), lr=args['learning_rate'])

# Embedding, temporal and rule-based (relation embedding) regularizers.
emb_reg, time_reg, rule_reg = (
    N3(args['emb_reg']),
    Lambda3(args['time_reg']),
    RuleSim(args['rule_reg']),
)

# Best test metrics seen so far and the number of validations without improvement.
best_mrr, best_hit = 0., 0.
early_stopping = 0
writer = SummaryWriter(log_dir='logs')
def avg_both(mrrs: Dict[str, float], hits: Dict[str, torch.FloatTensor]):
    """
    aggregate metrics for missing lhs and rhs
    :param mrrs: mean reciprocal rank per side, keyed by 'lhs' and 'rhs'
    :param hits: hits@k tensor per side, keyed by 'lhs' and 'rhs'
    :return: dict with the averaged 'MRR' and 'hits@[1,3,10]'
    """
    m = (mrrs['lhs'] + mrrs['rhs']) / 2.
    h = (hits['lhs'] + hits['rhs']) / 2.
    return {'MRR': m, 'hits@[1,3,10]': h}


# NOTE(review): model, optimizer and regularizers are created once before this loop,
# so only max_epochs, batch_size and valid_freq of each entry take effect, and
# training continues from the previous entry's weights.
for args in list_args:
    # Built once per run so the wrapper's one-time first-batch report really is
    # printed once. `args` is a dict, hence args['batch_size'] (attribute access
    # used to raise AttributeError in the interval branch).
    if dataset.has_intervals():
        optimizer = IKBCOptimizer(
            model, emb_reg, time_reg, opt, dataset,
            batch_size=args['batch_size']
        )
    else:
        optimizer = TKBCOptimizer(
            model, emb_reg, time_reg, rule_reg, opt,
            batch_size=args['batch_size']
        )

    for epoch in range(args['max_epochs']):
        examples = torch.from_numpy(
            dataset.get_train().astype('int64')
        )
        print("\nexamples:\n", examples.size())

        model.train()
        optimizer.epoch(examples)

        # Evaluate only every `valid_freq` epochs.
        if (epoch + 1) % args['valid_freq'] != 0:
            continue

        if dataset.has_intervals():
            valid, test, train = [
                dataset.eval(model, split, -1 if split != 'train' else 50000)
                for split in ['valid', 'test', 'train']
            ]
            print("valid: ", valid)
            print("test: ", test)
            print("train: ", train)

        else:
            # The train split is evaluated on a random sample of 50000 queries.
            valid, test, train = [
                avg_both(*dataset.eval(model, split, -1 if split != 'train' else 50000))
                for split in ['valid', 'test', 'train']
            ]
            print("epoch: ", epoch + 1)
            print("valid: ", valid['MRR'])
            print("test: ", test['MRR'])
            print("train: ", train['MRR'])

            writer.add_scalar('valid', valid['MRR'], epoch)
            writer.add_scalar('test', test['MRR'], epoch)
            writer.add_scalar('train', train['MRR'], epoch)

            writer.add_scalar('hit@1', test['hits@[1,3,10]'][0], epoch)
            writer.add_scalar('hit@3', test['hits@[1,3,10]'][1], epoch)
            writer.add_scalar('hit@10', test['hits@[1,3,10]'][2], epoch)

            print("test hits@n:\t", test['hits@[1,3,10]'])
            # Early stopping tracks the TEST MRR: stop after more than five
            # consecutive evaluations without improvement.
            if test['MRR'] > best_mrr:
                best_mrr = test['MRR']
                best_hit = test['hits@[1,3,10]']
                early_stopping = 0
            else:
                early_stopping += 1
            if early_stopping > 5:
                print("early stopping!")
                break

    print("The best test mrr is:\t", best_mrr)
    print("The best test hits@1,3,10 are:\t", best_hit)
