### DGNN代码复现

* [源仓库地址](https://github.com/alge24/DyGNN)

### 导入python库和相关自定义文件

In [1]:
import torch
import torch.nn as nn
from torch.nn import init
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SequentialSampler
from model_recurrent import DyGNN
from datasets import Temporal_Dataset
import argparse
from scipy.stats import rankdata
import numpy as np
import random
import os
import time
from param_parser import parameter_parser, tab_printer
import warnings
warnings.filterwarnings("ignore")


### 获取全局参数

In [2]:
args = parameter_parser()
args.batch_size = 5

In [3]:
tab_printer(args)

+----------------+-------+
|   Parameter    | Value |
| Act            | tanh  |
+----------------+-------+
| Batch size     | 5     |
+----------------+-------+
| Dataset        | uci   |
+----------------+-------+
| Decay method   | log   |
+----------------+-------+
| Drop p         | 0     |
+----------------+-------+
| If no time     | 0     |
+----------------+-------+
| If propagation | 1     |
+----------------+-------+
| If updated     | 0     |
+----------------+-------+
| Is att         | 1     |
+----------------+-------+
| Learning rate  | 0.001 |
+----------------+-------+
| Nor            | 0     |
+----------------+-------+
| Num negative   | 5     |
+----------------+-------+
| Reset rep      | 1     |
+----------------+-------+
| Second order   | 0     |
+----------------+-------+
| Seed           | 0     |
+----------------+-------+
| Threhold       | None  |
+----------------+-------+
| Train ratio    | 0.800 |
+----------------+-------+
| Transfer       | 1     |
+

### 模型的存储文件夹

In [4]:
# 模型的存储文件夹
model_save_dir = 'saved_models/'

### 导入数据

In [5]:
# 导入数据
data = Temporal_Dataset('Dataset/UCI_email_1899_59835/opsahl-ucsocial/out.opsahl-ucsocial',1,2)

In [6]:
len(data)

59835

In [7]:
data.node_num()

1899

In [8]:
if args.dataset == 'uci':
    num_nodes = data.node_num() # 节点数量 1899
    model_save_dir = model_save_dir + 'UCI/'  # 模型存储路径
    print('Train on UCI_message dataset')

Train on UCI_message dataset


### 模型参数设置

In [9]:
batch_size = args.batch_size
learning_rate = args.learning_rate
num_negative = args.num_negative
act = args.act
transfer = args.transfer
drop_p = args.drop_p
if_propagation = args.if_propagation
w = args.w
is_att = args.is_att
seed = args.seed
reset_rep = args.reset_rep
decay_method = args.decay_method
nor = args.nor
if_updated = args.if_updated
weight_decay = args.weight_decay
if_no_time = args.if_no_time
threhold = args.threhold
second_order = args.second_order
num_iter = 2

### 随机数设置

In [10]:
np.random.seed(seed)
random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
    torch.cuda.manual_seed(seed)

### 训练集、验证集与测试集划分

In [11]:
train_ratio = args.train_ratio
valid_ratio = args.valid_ratio
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

train_data = data[0:int(len(data)*train_ratio)]  # 选择训练集
validation_data = data[int(len(data)*train_ratio):int(len(data)*(train_ratio+valid_ratio))]  # 验证集
test_data = data[int(len(data)*(train_ratio + valid_ratio)):len(data)]  # 测试集
print('Data length: ', len(data))
print('Train length: ', len(train_data))
print('Valid length: ', len(validation_data))
print('Test length: ', len(test_data))

Data length:  59835
Train length:  47868
Valid length:  598
Test length:  11369


In [12]:
train_data.shape

(47868, 3)

In [13]:
sampler = SequentialSampler(train_data)  # 顺序采样
data_loader = DataLoader(train_data, batch_size, sampler = sampler)  # 定义dataloader

In [14]:
all_nodes = set(range(num_nodes))
print('num_nodes',len(all_nodes))

num_nodes 1899


### src和dst节点对应到的所有节点集合

In [15]:
def get_node2candidate(train_data, all_nodes, pri = False):
    head_node2candidate = dict()
    tail_node2candidate = dict()

    pri = True
    if pri:
        start_time = time.time()
        print('Start to build node2candidate')


    for i in range(len(train_data)):  # 遍历训练数据;

        head, tail, not_in_use = train_data[i]  # src, dst, time
        head = int(head)
        tail = int(tail)
        if head not in head_node2candidate:
            head_node2candidate[head] = all_nodes  # src节点对应到的所有节点

        if tail not in tail_node2candidate:
            tail_node2candidate[tail] = all_nodes  # dst节点



    if pri: 
        end_time = time.time()

        print('node2candidate built in' , str(end_time-start_time))
    return head_node2candidate, tail_node2candidate

In [16]:
head_node2candidate, tail_node2candidate = get_node2candidate(train_data, all_nodes)  # src和dst节点对应到的所有节点集合

Start to build node2candidate
node2candidate built in 0.053854942321777344


In [17]:
print(len(head_node2candidate),len(head_node2candidate[0]))

1217 1899


### 模型存储位置

In [18]:
model_save_dir = model_save_dir  + 'nt_' +str(if_no_time)+ '_wd_' + str(weight_decay) + '_up_' + str(if_updated) +'_w_' + str(w) +'_b_' + str(batch_size) + '_l_' + str(learning_rate) + '_tr_' + str(train_ratio) + '_nn_' +str(num_negative)+'_' + act + '_trans_' +str(transfer) + '_dr_p_' + str(drop_p) + '_prop_' + str(if_propagation) + '_att_' +str(is_att) + '_rp_' + str(reset_rep) + '_dcm_' + decay_method + '_nor_' + str(nor)
if threhold is not None:
    model_save_dir = model_save_dir + '_th_' + str(threhold)
if second_order:
    model_save_dir = model_save_dir + '_2hop'
if not os.path.exists(model_save_dir):
    os.makedirs(model_save_dir)

### 定义模型

In [19]:
dyGnn = DyGNN(num_nodes,64,64,device, w,is_att ,transfer,nor,if_no_time, threhold,second_order, if_updated,drop_p, num_negative, act, if_propagation, decay_method )

Only propagate to relevance nodes below time interval:  None


In [20]:
dyGnn.train()

DyGNN(
  (combiner): Combiner(
    (h2o): Linear(in_features=64, out_features=64, bias=True)
    (l2o): Linear(in_features=64, out_features=64, bias=True)
    (act): Tanh()
  )
  (act): Tanh()
  (decayer): Decayer(
    (linear): Linear(in_features=1, out_features=1, bias=False)
  )
  (edge_updater_head): Edge_updater_nn(
    (h2o): Linear(in_features=64, out_features=64, bias=True)
    (l2o): Linear(in_features=64, out_features=64, bias=True)
    (act): Tanh()
  )
  (edge_updater_tail): Edge_updater_nn(
    (h2o): Linear(in_features=64, out_features=64, bias=True)
    (l2o): Linear(in_features=64, out_features=64, bias=True)
    (act): Tanh()
  )
  (node_updater_head): TLSTM(
    (i2h): Linear(in_features=64, out_features=256, bias=True)
    (h2h): Linear(in_features=64, out_features=256, bias=True)
    (c2s): Sequential(
      (0): Linear(in_features=64, out_features=64, bias=True)
      (1): Tanh()
    )
    (sigmoid): Sigmoid()
    (tanh): Tanh()
  )
  (node_updater_tail): TLSTM(
  

### 优化器设置

In [21]:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,dyGnn.parameters()),lr = learning_rate, weight_decay=weight_decay)
old_head_rank = num_nodes/2 # 949.5
old_tail_rank = num_nodes/2 # 949.5

### 计算验证集损失

In [22]:
def get_loss(data, head_reps, tail_reps,device):

    head_list = list(data[:,0])
    tail_list = list(data[:,1])

    head_tensors = head_reps(torch.LongTensor(head_list).to(device))
    tail_tensors = tail_reps(torch.LongTensor(tail_list).to(device))
    scores = torch.bmm(head_tensors.view(len(head_list),1,head_tensors.size()[1]),tail_tensors.view(len(head_list),head_tensors.size()[1],1)).view(len(head_list))
    labels = torch.FloatTensor([1]*len(head_list)).to(device)
    bce_with_logits_loss = nn.BCEWithLogitsLoss().to(device)
    loss = bce_with_logits_loss(scores,labels)
    return loss

### 得分

In [23]:
def rank(node, true_candidate, node2candidate, node_reps, candidate_reps, device, pri = False):
    node_tensor = node_reps(torch.LongTensor([node]).to(device)).view(-1,1)
    candidates = list(node2candidate[node])

    candidates.append(true_candidate)

    length = len(candidates)

    candidate_tensors = candidate_reps(torch.LongTensor(candidates).to(device))

    scores = torch.mm(candidate_tensors, node_tensor)
    negative_scores_numpy = -scores.view(1,-1).to('cpu').numpy()
    rank = rankdata(negative_scores_numpy)[-1]

    if pri:
        print(node , true_candidate)
        print(scores.view(-1))
        print(rank, 'out of',length)

    return rank, length

In [24]:
def get_ranks(test_data,head_reps, tail_reps, device, head_node2candidate, tail_node2candidate, pri=False, previous_links = None, bo = False):

    head_ranks = []
    tail_ranks = []
    head_lengths = []
    tail_lengths = []

    for interactioin in test_data:
        head_node, tail_node , time = interactioin
        head_node = int(head_node)
        tail_node = int(tail_node)
        if pri:
            print('--------------', head_node, tail_node, '---------------')


        if bo:
            if previous_links is not None: 
                if head_node in head_node2candidate and tail_node in tail_node2candidate and tail_node in head_node2candidate and head_node in tail_node2candidate and (head_node, tail_node) not in previous_links:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)
            else:

                if head_node in head_node2candidate and tail_node in tail_node2candidate and tail_node in head_node2candidate and head_node in tail_node2candidate:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)
        else:

            if previous_links is not None: 
                if head_node in head_node2candidate and tail_node in tail_node2candidate and (head_node, tail_node) not in previous_links:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)
            else:

                if head_node in head_node2candidate and tail_node in tail_node2candidate:
                    head_rank, head_length = rank(head_node, tail_node, head_node2candidate, head_reps, tail_reps, device,pri)
                    head_ranks.append(head_rank)
                    head_lengths.append(head_length)

                    tail_rank, tail_length = rank(tail_node, head_node, tail_node2candidate, tail_reps, head_reps, device)
                    tail_ranks.append(tail_rank)
                    tail_lengths.append(tail_length)

    return head_ranks, tail_ranks, head_lengths, tail_lengths

### 链接预测任务

In [25]:
def link_prediction(data, reps):
    head_list = list(data[:,0])
    tail_list = list(data[:,1])
    head_reps = reps[head_list,:]
    tail_reps = reps[tail_list,:]

def get_previous_links(data):
    previous_links = set()
    for i in range(len(data)):
        head, tail, time = data[i]
        previous_links.add((int(head), int(tail)))
    return previous_links 

###  模型训练

In [26]:
for epoch in range(num_iter):
    print('epoch: ', epoch)
    print('Resetting time...')
    dyGnn.reset_time()  # 定义最近的时间; 交互的时间
    print('Time reset')

    if reset_rep:
        dyGnn.reset_reps()
        print('reps reset')

    x = int(5000/batch_size)
    y = int(10000/batch_size)


    for i, interactions in enumerate(data_loader):
        # interactions.shape = torch.tensor[5,3] => [batch, info]
        # interactions: [src, dst, time_diff] 
        # Compute and print loss.
        loss = dyGnn.loss(interactions)  # 计算loss函数
        if i%x==0:
            #dyGnn.reset_reps()
            print(i,' train_loss: ', loss.item())

            if transfer:
                head_reps = nn.Embedding.from_pretrained(dyGnn.transfer2head(dyGnn.node_representations.weight))  # 获取embedding
                tail_reps = nn.Embedding.from_pretrained(dyGnn.transfer2tail(dyGnn.node_representations.weight))
            else:
                head_reps = dyGnn.node_representations
                tail_reps = dyGnn.node_representations

            # normalize
            head_reps = nn.Embedding.from_pretrained(nn.functional.normalize(head_reps.weight))  # 节点特征归一化
            tail_reps = nn.Embedding.from_pretrained(nn.functional.normalize(tail_reps.weight))


        if i%y==-1:
            if transfer:
                head_reps = nn.Embedding.from_pretrained(dyGnn.transfer2head(dyGnn.node_representations.weight))
                tail_reps = nn.Embedding.from_pretrained(dyGnn.transfer2tail(dyGnn.node_representations.weight))
            else:
                head_reps = dyGnn.node_representations
                tail_reps = dyGnn.node_representations

            head_reps = nn.Embedding.from_pretrained(nn.functional.normalize(head_reps.weight))
            tail_reps = nn.Embedding.from_pretrained(nn.functional.normalize(tail_reps.weight))

            head_ranks, tail_ranks, not_in_use, not_in_use2 = get_ranks(validation_data,head_reps, tail_reps, device, head_node2candidate, tail_node2candidate)  # 评价指标
            head_ranks_numpy = np.asarray(head_ranks)
            tail_ranks_numpy = np.asarray(tail_ranks)
            print('head_rank mean: ', np.mean(head_ranks_numpy),' ; ', 'head_rank var: ', np.var(head_ranks_numpy))
            print('tail_rank mean: ', np.mean(tail_ranks_numpy),' ; ', 'tail_rank var: ', np.var(tail_ranks_numpy))


        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


    if transfer:
        head_reps = nn.Embedding.from_pretrained(dyGnn.transfer2head(dyGnn.node_representations.weight))
        tail_reps = nn.Embedding.from_pretrained(dyGnn.transfer2tail(dyGnn.node_representations.weight))
    else:
        head_reps = dyGnn.node_representations
        tail_reps = dyGnn.node_representations
        
    head_reps = nn.Embedding.from_pretrained(nn.functional.normalize(head_reps.weight))
    tail_reps = nn.Embedding.from_pretrained(nn.functional.normalize(tail_reps.weight))


    valid_loss = get_loss(validation_data, head_reps, tail_reps, device)  # 验证集边分类loss
    head_ranks, tail_ranks, head_lengths, tail_lengths = get_ranks(validation_data, head_reps, tail_reps, device, head_node2candidate, tail_node2candidate)
    head_ranks_numpy = np.asarray(head_ranks)
    tail_ranks_numpy = np.asarray(tail_ranks)
    head_lengths_numpy = np.asarray(head_lengths)
    tail_lengths_numpy = np.asarray(tail_lengths)

    mean_head_rank = np.mean(head_ranks_numpy)
    mean_tail_rank = np.mean(tail_ranks_numpy)


    print('head_length mean: ', np.mean(head_lengths_numpy), ';', 'num_test: ', head_lengths_numpy.shape[0])
    print('tail_lengths mean: ', np.mean(tail_lengths_numpy), ';', 'num_test: ', tail_lengths_numpy.shape[0])
    print('head_rank mean: ', mean_head_rank,' ; ', 'head_rank var: ', np.var(head_ranks_numpy))
    print('tail_rank mean: ', mean_tail_rank,' ; ', 'tail_rank var: ', np.var(tail_ranks_numpy))
    print('reverse head_rank mean: ', np.mean(1/head_ranks_numpy))
    print('reverse tail_rank mean: ', np.mean(1/tail_ranks_numpy))
    print('head_rank HITS 100: ', (head_ranks_numpy<=100).sum())
    print('tail_rank_HITS 100: ', (tail_ranks_numpy<=100).sum())
    print('head_rank HITS 50: ', (head_ranks_numpy<=50).sum())
    print('tail_rank_HITS 50: ', (tail_ranks_numpy<=50).sum())
    print('head_rank HITS 20: ', (head_ranks_numpy<=20).sum())
    print('tail_rank_HITS 20: ', (tail_ranks_numpy<=20).sum())


    if mean_head_rank < old_head_rank or mean_tail_rank < old_tail_rank:
        model_save_path = model_save_dir + '/' + 'model_after_epoch_' + str(epoch) + '.pt'
        torch.save(dyGnn.state_dict(), model_save_path)
        print('model saved in: ', model_save_path)


        with open(model_save_dir + '/' + '0valid_results.txt','a') as f:
            f.write('epoch: ' + str(epoch) + '\n')
            f.write('head_rank mean: ' + str(mean_head_rank) + ' ; ' +  'head_rank var: ' + str(np.var(head_ranks_numpy)) + '\n')
            f.write('tail_rank mean: ' + str(mean_tail_rank) + ' ; ' +  'tail_rank var: ' + str(np.var(tail_ranks_numpy)) + '\n')
            f.write('head_rank HITS 100: ' + str ( (head_ranks_numpy<=100).sum()) + '\n')
            f.write('tail_rank_HITS 100: ' + str ( (tail_ranks_numpy<=100).sum()) + '\n')
            f.write('head_rank HITS 50: ' + str( (head_ranks_numpy<=50).sum()) + '\n')
            f.write('tail_rank_HITS 50: ' + str( (tail_ranks_numpy<=50).sum()) + '\n')
            f.write('head_rank HITS 20: ' + str( (head_ranks_numpy<=20).sum()) + '\n')
            f.write('tail_rank_HITS 20: ' + str( (tail_ranks_numpy<=20).sum()) + '\n')
            f.write('============================================================================\n')
            
        old_head_rank = mean_head_rank + 200
        old_tail_rank = mean_tail_rank + 200

epoch:  0
Resetting time...
Time reset
reps reset
0  train_loss:  1.1134543418884277


1000  train_loss:  0.2990732789039612
2000  train_loss:  0.3084699213504791
3000  train_loss:  0.2624087631702423
4000  train_loss:  0.3137985169887543
5000  train_loss:  0.3106762170791626
6000  train_loss:  0.264798641204834
7000  train_loss:  0.28286978602409363
8000  train_loss:  0.30891525745391846
9000  train_loss:  0.29906752705574036
head_length mean:  1900.0 ; num_test:  523
tail_lengths mean:  1900.0 ; num_test:  523
head_rank mean:  927.9311663479923  ;  head_rank var:  263297.83082598186
tail_rank mean:  894.8546845124283  ;  tail_rank var:  264886.0046960286
reverse head_rank mean:  0.0014650623107585626
reverse tail_rank mean:  0.0015926185354674747
head_rank HITS 100:  0
tail_rank_HITS 100:  0
head_rank HITS 50:  0
tail_rank_HITS 50:  0
head_rank HITS 20:  0
tail_rank_HITS 20:  0
model saved in:  saved_models/UCI/nt_0_wd_0.001_up_0_w_2_b_5_l_0.001_tr_0.8_nn_5_tanh_trans_1_dr_p_0_prop_1_att_1_rp_1_dcm_log_nor_0/model_after_epoch_0.pt
epoch:  1
Resetting time...
Time reset