## DySAT代码复现

* [参考仓库链接](https://github.com/FeiGSSS/DySAT_pytorch)

### 导入相关库文件

In [1]:
import argparse
import networkx as nx
import numpy as np
import dill
import pickle as pkl
import scipy
from torch.utils.data import DataLoader

from utils.preprocess import load_graphs, get_context_pairs, get_evaluation_data
from utils.minibatch import  MyDataset
from utils.utilities import to_device
from eval.link_prediction import evaluate_classifier
from models.model import DySAT

import torch
# Turn on autograd anomaly detection for the whole session so that a failing
# backward pass reports the forward operation that produced it.
# NOTE(review): the PyTorch docs describe this as a debugging aid that slows
# execution down; consider switching it off for regular training runs.
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x22ff5736760>

In [2]:
def inductive_graph(graph_former, graph_later):
    """Build the graph used for inductive link-prediction testing.

    The result contains every node of ``graph_later`` (node attributes are
    copied) but only the edges of ``graph_former`` (edge attributes are
    dropped). Nodes that appear in the later snapshot are therefore present
    in the graph without any of the later snapshot's edges.

    Args:
        graph_former: graph at time step t; supplies the edges.
        graph_later: graph at time step t+1; supplies the nodes.

    Returns:
        nx.MultiGraph: a newly created graph; the inputs are only read.
    """
    later_nodes = graph_later.nodes(data=True)     # nodes of the later snapshot, with attributes
    former_edges = graph_former.edges(data=False)  # edges of the earlier snapshot, endpoints only

    merged = nx.MultiGraph()
    merged.add_nodes_from(later_nodes)
    merged.add_edges_from(former_edges)
    return merged

### 设置全局参数

In [3]:
def parameter_parser(argv=None):
    """Build the namespace of DySAT hyper-parameters.

    The default hyperparameters give a high performance model without grid search.

    Args:
        argv (list[str] | None): optional command-line style tokens, e.g.
            ``["--epochs", "50", "--featureless", "False"]``. With ``None``
            (the default) an empty argument list is parsed, so the kernel's
            own ``sys.argv`` is never read and every option keeps its default.

    Returns:
        argparse.Namespace: the parsed hyper-parameters.

    Raises:
        SystemExit: raised by argparse when ``argv`` contains an invalid option
            or value.
    """
    def _str2bool(text):
        """Parse a textual boolean.

        ``type=bool`` must not be used with argparse: ``bool("False")`` is
        ``True`` because any non-empty string is truthy.
        """
        if isinstance(text, bool):
            return text
        lowered = text.strip().lower()
        if lowered in ("true", "t", "yes", "y", "1"):
            return True
        if lowered in ("false", "f", "no", "n", "0"):
            return False
        raise argparse.ArgumentTypeError("expected a boolean value, got {!r}".format(text))

    parser = argparse.ArgumentParser()
    parser.add_argument('--time_steps', type=int, nargs='?', default=16,
                        help="total time steps used for train, eval and test")
    # Experimental settings.
    parser.add_argument('--dataset', type=str, nargs='?', default='Enron',
                        help='dataset name')
    parser.add_argument('--GPU_ID', type=int, nargs='?', default=0,
                        help='GPU_ID (0/1 etc.)')
    parser.add_argument('--epochs', type=int, nargs='?', default=30,
                        help='# epochs')
    parser.add_argument('--val_freq', type=int, nargs='?', default=1,
                        help='Validation frequency (in epochs)')
    parser.add_argument('--test_freq', type=int, nargs='?', default=1,
                        help='Testing frequency (in epochs)')
    parser.add_argument('--batch_size', type=int, nargs='?', default=512,
                        help='Batch size (# nodes)')
    # 1-hot encoding is input as a sparse matrix - hence no scalability issue for large datasets.
    # const=True makes a bare "--featureless" (no value) mean True instead of None.
    parser.add_argument('--featureless', type=_str2bool, nargs='?', const=True, default=True,
                        help='True if one-hot encoding.')
    parser.add_argument("--early_stop", type=int, default=10,
                        help="patient")
    # Tunable hyper-params
    # TODO: Implementation has not been verified, performance may not be good.
    parser.add_argument('--residual', type=_str2bool, nargs='?', const=True, default=True,
                        help='Use residual')
    # Number of negative samples per positive pair.
    parser.add_argument('--neg_sample_size', type=int, nargs='?', default=10,
                        help='# negative samples per positive')
    # Walk length for random walk sampling.
    parser.add_argument('--walk_len', type=int, nargs='?', default=20,
                        help='Walk length for random walk sampling')
    # Weight for negative samples in the binary cross-entropy loss function.
    parser.add_argument('--neg_weight', type=float, nargs='?', default=1.0,
                        help='Weightage for negative samples')
    parser.add_argument('--learning_rate', type=float, nargs='?', default=0.01,
                        help='Initial learning rate for self-attention model.')
    parser.add_argument('--spatial_drop', type=float, nargs='?', default=0.1,
                        help='Spatial (structural) attention Dropout (1 - keep probability).')
    parser.add_argument('--temporal_drop', type=float, nargs='?', default=0.5,
                        help='Temporal attention Dropout (1 - keep probability).')
    # The help text used to be a copy of the --learning_rate description.
    parser.add_argument('--weight_decay', type=float, nargs='?', default=0.0005,
                        help='Weight decay coefficient passed to the optimizer.')
    # Architecture params
    parser.add_argument('--structural_head_config', type=str, nargs='?', default='16,8,8',
                        help='Encoder layer config: # attention heads in each GAT layer')
    parser.add_argument('--structural_layer_config', type=str, nargs='?', default='128',
                        help='Encoder layer config: # units in each GAT layer')
    parser.add_argument('--temporal_head_config', type=str, nargs='?', default='16',
                        help='Encoder layer config: # attention heads in each Temporal layer')
    parser.add_argument('--temporal_layer_config', type=str, nargs='?', default='128',
                        help='Encoder layer config: # units in each Temporal layer')
    # Kept as a string on purpose: its consumer is outside this notebook.
    parser.add_argument('--position_ffn', type=str, nargs='?', default='True',
                        help='Position wise feedforward')
    parser.add_argument('--window', type=int, nargs='?', default=-1,
                        help='Window for temporal attention (default : -1 => full)')

    # Inside a notebook sys.argv belongs to the kernel, so never parse it implicitly.
    return parser.parse_args(args=[] if argv is None else list(argv))

### 可视化参数设置

In [4]:
from texttable import Texttable

def tab_printer(args):
    """
    Print the parsed parameters as a two-column text table.

    Parameter names are listed in sorted order; underscores are replaced by
    spaces and the first letter is capitalised for display.

    :param args: Namespace-like object holding the parameters used for the model.
    """
    params = vars(args)
    header = [["Parameter", "Value"]]
    body = [[name.replace("_", " ").capitalize(), params[name]] for name in sorted(params)]

    table = Texttable()
    table.add_rows(header + body)  # add_rows receives the header row followed by one row per parameter
    print(table.draw())

In [5]:
# Build the hyper-parameter namespace; called without arguments, so every
# option takes the default declared in parameter_parser().
args = parameter_parser()

In [6]:
# Show the active hyper-parameters as a table for reference.
tab_printer(args)

+-------------------------+--------+
|        Parameter        | Value  |
| Gpu id                  | 0      |
+-------------------------+--------+
| Batch size              | 512    |
+-------------------------+--------+
| Dataset                 | Enron  |
+-------------------------+--------+
| Early stop              | 10     |
+-------------------------+--------+
| Epochs                  | 30     |
+-------------------------+--------+
| Featureless             | 1      |
+-------------------------+--------+
| Learning rate           | 0.010  |
+-------------------------+--------+
| Neg sample size         | 10     |
+-------------------------+--------+
| Neg weight              | 1      |
+-------------------------+--------+
| Position ffn            | True   |
+-------------------------+--------+
| Residual                | 1      |
+-------------------------+--------+
| Spatial drop            | 0.100  |
+-------------------------+--------+
| Structural head config  | 16,8,8 |
+

### 加载数据

In [7]:
import warnings
# Silence every warning for the rest of the session to keep cell output short.
# NOTE(review): this also hides deprecation and numerical warnings from
# scipy/networkx/torch; consider a narrower filter.
warnings.filterwarnings("ignore")

In [8]:
# Load the graph snapshots and their adjacency matrices for the chosen dataset.
graphs, adjs = load_graphs(args.dataset)
if args.featureless:
    # One-hot node features. The number of columns is the node count of the
    # last time step in use; snapshot k keeps the first adjs[k].shape[0] rows.
    num_nodes_last = adjs[args.time_steps - 1].shape[0]
    one_hot = scipy.sparse.identity(num_nodes_last).tocsr()  # built once, reused for every snapshot
    # Only snapshots that are not larger than the last one get a feature matrix.
    feats = [one_hot[range(0, adj.shape[0]), :]
             for adj in adjs if adj.shape[0] <= num_nodes_last]
else:
    # Later cells use `feats` unconditionally; fail here with a clear message
    # instead of a NameError further down.
    raise NotImplementedError(
        "featureless=False is not supported by this notebook: node features are "
        "never loaded, so `feats` would be undefined.")

Loaded 16 graphs 


In [9]:
# Shape of the adjacency matrix at the last time step in use.
adjs[args.time_steps - 1].shape

(143, 143)

In [10]:
# Number of per-snapshot feature matrices that were built.
len(feats)

16

In [11]:
# Shape of the first snapshot's feature matrix (rows follow that snapshot's
# adjacency size, columns follow the last snapshot's).
feats[0].shape

(18, 143)

### 对每张图进行随机游走采样

In [13]:
# NOTE(review): an earlier cell already indexes adjs[args.time_steps - 1], so
# this check would be more useful before the features are built.
assert args.time_steps <= len(adjs), "Time steps is illegal"
# node2vec-style training corpus built from the graph snapshots and their
# adjacency matrices (16 snapshots with the default settings).
context_pairs_train = get_context_pairs(graphs, adjs)  # random-walk sampling is run on each snapshot

Computing training pairs ...
# nodes with random walk samples: 18
# sampled pairs: 40854
# nodes with random walk samples: 18
# sampled pairs: 39386
# nodes with random walk samples: 14
# sampled pairs: 31918
# nodes with random walk samples: 47
# sampled pairs: 104792
# nodes with random walk samples: 57
# sampled pairs: 129236
# nodes with random walk samples: 65
# sampled pairs: 149300
# nodes with random walk samples: 79
# sampled pairs: 193622
# nodes with random walk samples: 97
# sampled pairs: 239452
# nodes with random walk samples: 101
# sampled pairs: 245690
# nodes with random walk samples: 106
# sampled pairs: 256484
# nodes with random walk samples: 103
# sampled pairs: 254148
# nodes with random walk samples: 113
# sampled pairs: 279936
# nodes with random walk samples: 98
# sampled pairs: 232548
# nodes with random walk samples: 79
# sampled pairs: 181768
# nodes with random walk samples: 94
# sampled pairs: 231172
# nodes with random walk samples: 93
# sampled pairs: 2

In [14]:
# Number of entries returned by get_context_pairs.
len(context_pairs_train)

16

In [17]:
# Keys of the first snapshot's context-pair mapping (presumably node ids — TODO confirm).
context_pairs_train[0].keys()

dict_keys([9, 8, 0, 10, 5, 11, 4, 7, 15, 17, 3, 1, 2, 6, 16, 14, 13, 12])

In [19]:
# context_pairs_train[0][0] # 上下文节点

In [20]:
 # Load evaluation data for link prediction. Per the original author's note, only the
 # edges of the last graph are processed: an edge whose node is missing from the
 # previous graph gets no label (TODO confirm against get_evaluation_data).
train_edges_pos, train_edges_neg, val_edges_pos, val_edges_neg, \
    test_edges_pos, test_edges_neg = get_evaluation_data(graphs)
# Report the number of positive/negative edges in the train, validation and test splits.
print("No. Train: Pos={}, Neg={} \nNo. Val: Pos={}, Neg={} \nNo. Test: Pos={}, Neg={}".format(
    len(train_edges_pos), len(train_edges_neg), len(val_edges_pos), len(val_edges_neg),
    len(test_edges_pos), len(test_edges_neg)))

Generating eval data ....
No. Train: Pos=46, Neg=46 
No. Val: Pos=46, Neg=46 
No. Test: Pos=140, Neg=140


In [11]:
# Create the adj_train so that it includes nodes from (t+1) but only edges from t,
# i.e. all nodes of the next snapshot but only the edges of the current one:
# this is for the purpose of inductive testing.
new_G = inductive_graph(graphs[args.time_steps-2], graphs[args.time_steps-1])  # every node of the next snapshot is added to the data
# Replace the second-to-last snapshot in place with the inductive graph.
graphs[args.time_steps-2] = new_G
adjs[args.time_steps-2] = nx.adjacency_matrix(new_G)
# The replaced snapshot now has the last snapshot's node set, so it reuses the
# last snapshot's feature matrix.
feats[args.time_steps-2] = feats[args.time_steps-1]

### 构建dataloader和DySAT模型

In [12]:
# Build the dataloader and the model.
# Use the GPU selected through --GPU_ID when CUDA is available, otherwise the CPU.
device = torch.device('cuda:{}'.format(args.GPU_ID) if torch.cuda.is_available() else 'cpu')
dataset = MyDataset(args, graphs, feats, adjs, context_pairs_train)

# Batches are assembled by the dataset's own collate function.
dataloader = DataLoader(dataset,
                        batch_size=args.batch_size,
                        shuffle=True,
                        collate_fn=MyDataset.collate_fn)
# The model's input dimension is the number of feature columns.
model = DySAT(args, feats[0].shape[1], args.time_steps).to(device)
opt = torch.optim.AdamW(model.parameters(), lr=args.learning_rate, weight_decay=args.weight_decay)

In [13]:
# for idx, feed_dict in enumerate(dataloader):
#     print(idx, feed_dict)
#     break

### 模型训练

In [14]:
import os

from tqdm import tqdm

# Make sure the checkpoint directory exists; torch.save does not create it.
checkpoint_path = "./model_checkpoints/model.pt"
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)

# Early-stopping state: best validation AUC so far and epochs without improvement.
best_epoch_val = 0
patient = 0

for epoch in tqdm(range(args.epochs)):
    # ---- one pass of training over all mini-batches ----
    model.train()
    epoch_loss = []
    for feed_dict in dataloader:
        feed_dict = to_device(feed_dict, device)  # move the batch to the target device
        opt.zero_grad()
        loss = model.get_loss(feed_dict)
        loss.backward()
        opt.step()
        epoch_loss.append(loss.item())

    # ---- link-prediction evaluation ----
    model.eval()
    # Embeddings at the second-to-last time step are the features used to
    # predict the links of the last snapshot. The graphs are taken from the
    # last training batch. No gradients are needed here, so skip autograd
    # bookkeeping (which is costly while anomaly detection is enabled).
    with torch.no_grad():
        emb = model(feed_dict["graphs"])[:, -2, :].detach().cpu().numpy()
    # Training is unsupervised (context pairs); evaluation scores node pairs
    # from their embeddings, per the original author's note optionally through
    # a classifier.
    val_results, test_results, _, _ = evaluate_classifier(train_edges_pos,
                                                        train_edges_neg,
                                                        val_edges_pos,
                                                        val_edges_neg,
                                                        test_edges_pos,
                                                        test_edges_neg,
                                                        emb,
                                                        emb)
    # Element 1 of the "HAD" entry is used as the AUC (presumably the Hadamard
    # edge operator — TODO confirm in eval.link_prediction).
    epoch_auc_val = val_results["HAD"][1]
    epoch_auc_test = test_results["HAD"][1]

    if epoch_auc_val > best_epoch_val:
        # New best validation AUC: checkpoint the model and reset the counter.
        best_epoch_val = epoch_auc_val
        torch.save(model.state_dict(), checkpoint_path)
        patient = 0
    else:
        patient += 1
        if patient > args.early_stop:
            break

    # Log the first epoch, every 10th epoch, and every epoch after the 20th.
    if (epoch+1) % 10 == 0 or epoch==0 or (epoch+1) > 20:
        print("Epoch {:<3},  Loss = {:.3f}, Val AUC {:.3f} Test AUC {:.3f}".format(epoch+1,
                                                                np.mean(epoch_loss),
                                                                epoch_auc_val,
                                                                epoch_auc_test))

  3%|▎         | 1/30 [00:02<01:14,  2.57s/it]

Epoch 1  ,  Loss = 39.316, Val AUC 0.624 Test AUC 0.694


 33%|███▎      | 10/30 [00:24<00:49,  2.46s/it]

Epoch 10 ,  Loss = 18.979, Val AUC 0.869 Test AUC 0.835


 67%|██████▋   | 20/30 [00:48<00:23,  2.38s/it]

Epoch 20 ,  Loss = 16.605, Val AUC 0.939 Test AUC 0.895


 70%|███████   | 21/30 [00:50<00:21,  2.38s/it]

Epoch 21 ,  Loss = 16.520, Val AUC 0.940 Test AUC 0.896


 73%|███████▎  | 22/30 [00:53<00:19,  2.38s/it]

Epoch 22 ,  Loss = 16.443, Val AUC 0.941 Test AUC 0.896


 77%|███████▋  | 23/30 [00:55<00:16,  2.38s/it]

Epoch 23 ,  Loss = 16.392, Val AUC 0.946 Test AUC 0.897


 80%|████████  | 24/30 [00:57<00:14,  2.36s/it]

Epoch 24 ,  Loss = 16.324, Val AUC 0.948 Test AUC 0.896


 83%|████████▎ | 25/30 [01:00<00:11,  2.37s/it]

Epoch 25 ,  Loss = 16.276, Val AUC 0.949 Test AUC 0.895


 87%|████████▋ | 26/30 [01:02<00:09,  2.36s/it]

Epoch 26 ,  Loss = 16.177, Val AUC 0.949 Test AUC 0.896


 90%|█████████ | 27/30 [01:05<00:07,  2.36s/it]

Epoch 27 ,  Loss = 16.154, Val AUC 0.948 Test AUC 0.896


 93%|█████████▎| 28/30 [01:07<00:04,  2.36s/it]

Epoch 28 ,  Loss = 16.076, Val AUC 0.943 Test AUC 0.897


 97%|█████████▋| 29/30 [01:09<00:02,  2.35s/it]

Epoch 29 ,  Loss = 16.045, Val AUC 0.939 Test AUC 0.897


100%|██████████| 30/30 [01:12<00:00,  2.40s/it]

Epoch 30 ,  Loss = 16.000, Val AUC 0.936 Test AUC 0.897





### 加载模型,并输出最佳测试结果

In [15]:
# Test Best Model
# Reload the checkpoint with the best validation AUC. map_location places the
# weights on the device in use, so a checkpoint written on a GPU also loads
# on a CPU-only machine.
model.load_state_dict(torch.load("./model_checkpoints/model.pt", map_location=device))
model.eval()
# `feed_dict` is the last batch left over from the training loop. Inference
# only, so gradients are not tracked.
with torch.no_grad():
    emb = model(feed_dict["graphs"])[:, -2, :].detach().cpu().numpy()
val_results, test_results, _, _ = evaluate_classifier(train_edges_pos,
                                                    train_edges_neg,
                                                    val_edges_pos,
                                                    val_edges_neg,
                                                    test_edges_pos,
                                                    test_edges_neg,
                                                    emb,
                                                    emb)
# Element 1 of the "HAD" entry is reported as the AUC, as in the training loop.
auc_val = val_results["HAD"][1]
auc_test = test_results["HAD"][1]
print("Best Test AUC = {:.3f}".format(auc_test))

Best Test AUC = 0.895
