In [4]:
import networkx as nx
from collections import defaultdict
import numpy as np
import torch, numpy, pickle, random, time, argparse
from tqdm import tqdm


class Env(object):
    """Temporal Knowledge Graph environment built from timestamped, confidence-weighted facts.

    Graph nodes are ``(entity, timestamp)`` pairs. Every fact contributes a forward edge
    ``src --rel--> dst`` and a reversed edge ``dst --(rel + num_rel)--> src``.
    """

    def __init__(self, examples, config, padding, jump, maxn, transformer_space=None):
        """Build the environment.

        Args:
            examples: iterable of ``(subject, relation, object, timestamp, confidence)``
                tuples. It is copied; the caller's container is never modified.
            config: dict providing at least ``'num_rel'`` and ``'num_ent'``.
            padding, jump, maxn: stored unchanged as ``self.padding``, ``self.jump`` and
                ``self.maxn`` (not read anywhere inside this class).
            transformer_space: optional mapping; when truthy, its ``keys()`` view is cached
                in ``self.transformer_space_key``.
        """
        examples = list(examples)  # materialise once: both builders below iterate it
        self.config = config
        self.num_rel = config['num_rel']
        self.label2nodes, self.neighbors = self.prepare_data(examples)
        # Must run after prepare_data: it reads self.neighbors.
        self.nebor_relation = self.built_nebor_relation(examples)
        # Relation indices as used in this class:
        #   [0, num_rel)   -> original relations
        #   rel + num_rel  -> reversed relations (see prepare_data)
        #   num_rel        -> NO_OP (stay in place)
        #   2 * num_rel    -> rPAD; also the nebor_relation column marking entities with no node
        # NOTE(review): NO_OP == num_rel is also the index of reversed relation 0 — confirm intended.
        self.NO_OP = self.num_rel  # Stay in place; No Operation
        self.ePAD = config['num_ent']  # Padding entity
        self.rPAD = config['num_rel'] * 2  # Padding relation.
        self.tPAD = 0  # Padding time
        self.confPAD = 1  # Padding confidence
        self.padding = padding
        self.jump = jump
        self.maxn = maxn
        self.transformer_space = transformer_space
        if transformer_space:
            self.transformer_space_key = self.transformer_space.keys()

    def prepare_data(self, examples):
        """Index the facts by graph node.

        Args:
            examples: sequence of ``(subject, relation, object, timestamp, confidence, ...)``;
                it is not modified.

        Returns:
            label2nodes: dict entity -> set of ``(entity, timestamp)`` nodes it occurs in.
            neighbors: dict ``(entity, timestamp)`` -> {relation -> list of
                ``(neighbor, timestamp, confidence)``}, each list sorted by confidence,
                highest first. Reversed edges use relation index ``rel + num_rel``.
                Nodes are first inserted in reverse chronological order of timestamp.
        """
        label2nodes = defaultdict(set)
        neighbors = defaultdict(dict)
        # Reverse chronological order. sorted() returns a copy (and is stable, like
        # list.sort), so the caller's list keeps its original order.
        ordered = sorted(examples, key=lambda x: x[3], reverse=True)
        for example in tqdm(ordered, desc="开始built_graph"):
            src, rel, dst, ts, conf = example[0], example[1], example[2], example[3], example[4]
            src_node = (src, ts)
            dst_node = (dst, ts)

            label2nodes[src].add(src_node)
            label2nodes[dst].add(dst_node)

            # Neighbour sets prepared for the transformer.
            neighbors[src_node].setdefault(rel, set()).add((dst, ts, conf))
            neighbors[dst_node].setdefault(rel + self.num_rel, set()).add((src, ts, conf))

        # Replace each neighbour set by a list sorted by confidence (last field), descending.
        # Only existing keys are reassigned, so mutating while iterating is safe.
        for rel2nbrs in neighbors.values():
            for r, nbrs in rel2nbrs.items():
                rel2nbrs[r] = sorted(nbrs, key=lambda x: x[-1], reverse=True)

        return label2nodes, neighbors

    def built_nebor_relation(self, examples):
        """Per-entity relation profile.

        Requires ``self.neighbors`` (built by ``prepare_data``).

        Returns:
            float tensor of shape ``(num_ent, 2 * num_rel + 1)``. Every cell starts at 1.
            For each fact, column ``relation`` of the head row and column
            ``relation + num_rel`` (the reversed relation) of the tail row are incremented.
            Column ``2 * num_rel`` is incremented for entities that have no node in
            ``self.neighbors``. The tensor is then log-transformed, and each row is divided
            by its sum.
        """
        nebor_relation = torch.ones(self.config['num_ent'], 2 * self.num_rel + 1)
        for head, relation, tail, *_rest in tqdm(examples, desc='正在生成nebor_relation'):
            nebor_relation[head][relation] += 1
            # The reversed edge belongs to the tail, matching neighbors[dst_node][rel + num_rel]
            # in prepare_data. Crediting the head here instead left every tail-only entity
            # with a row of log(1) = 0, so the row normalisation below produced 0/0 = NaN.
            nebor_relation[tail][relation + self.num_rel] += 1

        entities_with_nodes = {node[0] for node in self.neighbors}
        for e in range(self.config['num_ent']):
            if e not in entities_with_nodes:  # entity does not occur in `examples`
                nebor_relation[e][2 * self.num_rel] += 1
        # Every row now has at least one cell > 1, so each row sum below is positive.
        nebor_relation = torch.log(nebor_relation)
        nebor_relation /= nebor_relation.sum(1).unsqueeze(1)

        return nebor_relation

In [5]:
import os
from dataset.baseDataset import baseDataset,baseDataset_new
# Dataset location and Env hyper-parameters (stored by Env as padding / jump / maxn).
data_dir = "data/ICEWS14"
PADDING = 10
JUMP = 10
MAXN = 10

# Original split: baseDataset supplies num_r / num_e used for the Env config below.
trainF = os.path.join(data_dir, 'train.txt')
testF = os.path.join(data_dir, 'test.txt')
statF = os.path.join(data_dir, 'stat.txt')
validF = os.path.join(data_dir, 'valid.txt')
if not os.path.exists(validF):
    validF = None  # the validation split is optional for baseDataset
dataset = baseDataset(trainF, testF, statF, validF)

# *_new.txt split: its allQuadruples are the facts Env is built from.
# NOTE(review): valid_new.txt gets no existence check, unlike valid.txt — confirm it always exists.
train_new_F = os.path.join(data_dir, 'train_new.txt')
test_new_F = os.path.join(data_dir, 'test_new.txt')
valid_new_F = os.path.join(data_dir, 'valid_new.txt')
dataset_new = baseDataset_new(train_new_F, test_new_F, valid_new_F)

config = {
    'num_rel': dataset.num_r,
    'num_ent': dataset.num_e,
}
env = Env(dataset_new.allQuadruples, config, padding=PADDING, jump=JUMP, maxn=MAXN)

开始built_graph: 100%|██████████| 108462/108462 [00:00<00:00, 231747.00it/s]
正在生成nebor_relation: 100%|██████████| 108462/108462 [00:01<00:00, 68928.54it/s]


In [20]:
# Number of neighbour facts per node, grouped by timestamp: count[timestamp][entity] -> int.
# Loop variables deliberately avoid `time` and `all`: at notebook top level those would
# rebind the imported `time` module and the builtin `all()` for every later cell.
TARGET_TIME = 8736  # timestamp whose per-entity counts are printed below

count = defaultdict(dict)
for (entity, ts), rel2nbrs in env.neighbors.items():
    count[ts][entity] = sum(len(nbrs) for nbrs in rel2nbrs.values())

# .get avoids inserting an empty entry into `count` when TARGET_TIME is absent.
for entity, n_neighbors in count.get(TARGET_TIME, {}).items():
    print(entity, n_neighbors)

5 7
18 15
9 2
96 7
753 3
11 2
792 1
8 2
14 2
15 7
169 1
639 1
239 1
322 2
329 4
180 1
29 2
2523 2
23 1
22 1
36 4
1012 1
42 2
109 2
46 13
95 11
84 3
882 1
49 2
188 2
50 2
440 2
52 3
55 1
60 2
3 1
1139 1
4316 1
91 4
122 2
2458 2
93 1
185 2
1920 2
115 1
487 2
117 2
387 2
118 3
2292 1
929 1
127 2
128 13
143 7
535 1
251 1
3963 1
5299 3
132 1
151 2
165 1
657 3
167 1
170 1
171 3
4636 2
174 1
67 1
177 2
467 2
179 1
190 2
3622 2
195 2
1126 1
204 4
470 1
250 1
262 1
263 1
290 1
267 1
302 5
405 2
457 2
304 1
4554 1
12 1
331 2
1760 1
348 1
393 4
509 4
542 3
408 1
389 1
471 1
4783 1
472 1
59 1
586 1
69 1
590 1
591 5
804 1
2500 1
634 1
644 2
863 2
247 1
1589 1
668 2
318 1
675 1
678 1
680 1
730 1
1765 1
772 1
3752 1
791 1
57 1
830 1
831 1
845 1
5489 1
847 2
1438 2
897 1
1079 1
910 1
2405 1
936 1
10 1
1001 1
1028 1
1029 1
1053 1
1076 1
7126 1
1161 2
2457 2
1295 1
1376 1
1382 1
1383 1
1584 1
1708 1
1780 1
1791 1
1128 1
1966 1
2004 1
3588 1
2077 1
917 1
2081 1
1094 1
2099 1
2111 1
2346 1
1510 1
2355 1
7