In [1]:
import numpy as np
import pickle as pkl
import networkx as nx
import scipy.sparse as sp
import sys
import torch
import pickle

In [2]:
def sample_mask(idx, l):
    """Build a boolean mask of length ``l`` that is True exactly at ``idx``.

    Args:
        idx: Any index expression numpy accepts for a 1-D array
            (list, range, integer array, ...).
        l: Length of the mask.

    Returns:
        np.ndarray of dtype bool, True at the positions selected by ``idx``.
    """
    selected = np.zeros(l, dtype=bool)
    selected[idx] = True
    return selected

def sparse_to_tuple(sparse_mx):
    """Convert a sparse matrix (or a list of them) to (coords, values, shape) tuples.

    Args:
        sparse_mx: A scipy sparse matrix, or a list of scipy sparse matrices.

    Returns:
        For a single matrix, a tuple ``(coords, values, shape)`` where ``coords``
        is an (nnz, 2) array of (row, col) pairs, ``values`` the matching stored
        entries and ``shape`` the matrix shape. For a list input, a new list with
        one such tuple per matrix; the caller's list is left unmodified.
    """
    def to_tuple(mx):
        # COO format exposes the row/col/data triplets directly.
        if not sp.isspmatrix_coo(mx):
            mx = mx.tocoo()
        # Stack rows and cols, then transpose so each row is one (row, col) pair.
        coords = np.vstack((mx.row, mx.col)).transpose()
        return coords, mx.data, mx.shape

    if isinstance(sparse_mx, list):
        # Build a fresh list rather than overwriting the caller's entries in place.
        return [to_tuple(mx) for mx in sparse_mx]
    return to_tuple(sparse_mx)

def preprocess_features(features):
    """Row-normalize a sparse feature matrix.

    Each row is divided by its sum; all-zero rows stay zero instead of
    becoming inf/NaN.

    Args:
        features: scipy sparse matrix of features, one row per node.

    Returns:
        Tuple ``(dense, sparse_tuple)``: the normalized matrix densified via
        ``.todense()``, and the same matrix in ``sparse_to_tuple`` form.
    """
    rowsum = np.array(features.sum(1))
    # Use a float exponent: np.power refuses negative *integer* powers of
    # integer arrays, which would crash on integer-valued feature matrices.
    # Zero rows yield inf (divide warning silenced) and are zeroed just below.
    with np.errstate(divide="ignore"):
        r_inv = np.power(rowsum, -1.0).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)          # D^-1 as a sparse diagonal matrix
    features = r_mat_inv.dot(features)   # D^-1 X: scale each row by 1/rowsum
    return features.todense(), sparse_to_tuple(features)

def normalize_adj(mx):
    """Symmetrically normalize a sparse adjacency matrix: D^-1/2 A D^-1/2.

    D is the diagonal matrix of row sums of ``mx``. Rows with zero sum get a
    scaling factor of 0 instead of inf.

    Args:
        mx: Square scipy sparse adjacency matrix (the caller passes A + I).

    Returns:
        scipy sparse matrix equal to D^-1/2 A D^-1/2.
    """
    rowsum = np.array(mx.sum(1))
    # Zero-degree rows give inf (divide warning silenced); they are zeroed next.
    with np.errstate(divide="ignore"):
        r_inv_sqrt = np.power(rowsum, -0.5).flatten()  # 1/sqrt(d_i)
    r_inv_sqrt[np.isinf(r_inv_sqrt)] = 0.
    r_mat_inv_sqrt = sp.diags(r_inv_sqrt)
    # Scale rows, then columns. This never transposes A, so it is also correct
    # for non-symmetric (directed) adjacency matrices.
    return r_mat_inv_sqrt.dot(mx).dot(r_mat_inv_sqrt)


In [3]:
def parse_index_file(filename):
    """Read a file of integer indices, one per line.

    Surrounding whitespace is ignored and blank lines are skipped.

    Args:
        filename: Path of the index file.

    Returns:
        List of ints in file order.
    """
    # `with` guarantees the file handle is closed, even if int() raises.
    with open(filename) as fh:
        return [int(line) for line in fh if line.strip()]

In [4]:
def accuracy(output, labels):
    """Fraction of rows whose highest-scoring class equals the label.

    Args:
        output: Tensor of per-class scores; the class is taken along dim 1.
        labels: Tensor of class indices, assumed 1-D with one entry per row
            of ``output``.

    Returns:
        0-d float64 tensor: number of correct predictions / len(labels).
    """
    _, predicted = output.max(1)
    hits = predicted.type_as(labels).eq(labels)
    num_correct = hits.double().sum()
    return num_correct / len(labels)

In [5]:
def load_data(dataset_str, data_dir="data"):
    """Load a Planetoid-format citation dataset (e.g. cora, citeseer).

    Reads ``{data_dir}/{dataset}/ind.{dataset}.{x,y,tx,ty,allx,ally,graph}``
    and ``ind.{dataset}.test.index``. x/y are the training set, tx/ty the
    test set, allx/ally the non-test nodes (training nodes first, as
    ``idx_train`` assumes), and graph the adjacency lists.

    Args:
        dataset_str: Dataset name, used as both directory and file prefix.
        data_dir: Root directory containing the dataset folders.

    Returns:
        Tuple ``(adj, features, idx_train, idx_val, idx_test, labels)``:
        adj: dense torch.FloatTensor of the normalized adjacency of A + I;
        features: scipy LIL matrix of the raw (un-normalized) node features;
        idx_train: torch.LongTensor of the first len(y) node ids;
        idx_val: torch.LongTensor of the next 500 node ids;
        idx_test: torch.LongTensor of the (sorted) test node ids;
        labels: torch.LongTensor with one class index per node.
    """
    names = ['x', 'y', 'tx', 'ty', 'allx', 'ally', 'graph']
    objects = []
    for name in names:
        path = "{}/{}/ind.{}.{}".format(data_dir, dataset_str, dataset_str, name)
        # NOTE(security): pickle.load can execute arbitrary code; only load
        # dataset files from a trusted source.
        with open(path, 'rb') as f:
            if sys.version_info > (3, 0):
                # latin1 lets Python 3 read pickles presumably written by Python 2.
                objects.append(pkl.load(f, encoding='latin1'))
            else:
                objects.append(pkl.load(f))
    x, y, tx, ty, allx, ally, graph = tuple(objects)

    # Test node ids in file order (matches the row order of tx/ty), plus sorted.
    test_idx_reorder = parse_index_file(
        "{}/{}/ind.{}.test.index".format(data_dir, dataset_str, dataset_str))
    test_idx_range = np.sort(test_idx_reorder)

    if dataset_str == 'citeseer':
        # The listed citeseer test ids do not cover the whole [min, max] range:
        # pad tx/ty with zero rows so every id in that range has a row.
        test_idx_range_full = range(min(test_idx_reorder), max(test_idx_reorder) + 1)
        tx_extended = sp.lil_matrix((len(test_idx_range_full), x.shape[1]))
        tx_extended[test_idx_range - min(test_idx_range), :] = tx
        tx = tx_extended
        ty_extended = np.zeros((len(test_idx_range_full), y.shape[1]))
        ty_extended[test_idx_range - min(test_idx_range), :] = ty
        ty = ty_extended

    # Stack all nodes, then move each test row to the node id it belongs to.
    features = sp.vstack((allx, tx)).tolil()
    features[test_idx_reorder, :] = features[test_idx_range, :]
    adj = nx.adjacency_matrix(nx.from_dict_of_lists(graph))
    labels = np.vstack((ally, ty))
    labels[test_idx_reorder, :] = labels[test_idx_range, :]

    idx_train = range(len(y))
    idx_val = range(len(y), len(y) + 500)
    idx_test = test_idx_range

    adj = adj.astype(np.float32)
    adj = normalize_adj(adj + sp.eye(adj.shape[0]))  # add self-loops, then normalize
    adj = torch.FloatTensor(np.array(adj.todense()))

    # One-hot rows -> class index. Rows without any label (e.g. the citeseer
    # padding rows) map to class 0. Unlike np.where(labels > 0), argmax yields
    # exactly one entry per node, so labels stay aligned with node ids.
    labels = torch.as_tensor(np.asarray(labels).argmax(axis=1), dtype=torch.long)

    idx_train = torch.LongTensor(idx_train)
    idx_val = torch.LongTensor(idx_val)
    idx_test = torch.LongTensor(idx_test)

    return adj, features, idx_train, idx_val, idx_test, labels

In [6]:
#citeseer_adj, features,  idx_train, idx_val, idx_test, labels= load_data("citeseer")
#citeseer_adj, features,  idx_train, idx_val, idx_test, train_mask, val_mask, test_mask, labels= load_data("citeseer")

  objects.append(pkl.load(f, encoding='latin1'))


tensor([120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132, 133,
        134, 135, 136, 137, 138, 139, 140, 141, 142, 143, 144, 145, 146, 147,
        148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159, 160, 161,
        162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
        176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189,
        190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203,
        204, 205, 206, 207, 208, 209, 210, 211, 212, 213, 214, 215, 216, 217,
        218, 219, 220, 221, 222, 223, 224, 225, 226, 227, 228, 229, 230, 231,
        232, 233, 234, 235, 236, 237, 238, 239, 240, 241, 242, 243, 244, 245,
        246, 247, 248, 249, 250, 251, 252, 253, 254, 255, 256, 257, 258, 259,
        260, 261, 262, 263, 264, 265, 266, 267, 268, 269, 270, 271, 272, 273,
        274, 275, 276, 277, 278, 279, 280, 281, 282, 283, 284, 285, 286, 287,
        288, 289, 290, 291, 292, 293, 294, 295, 296, 297, 298, 2