In [40]:
import torch
from pathlib import Path
import numpy as np
from scipy.sparse import coo_matrix,csr_matrix,diags,eye

device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(device) # 查看计算资源

# 数据路径
path = Path('./data/cora')

cuda


In [41]:
# 读取特征数据
paper_features_label = np.genfromtxt(path/'cora.content',dtype=np.str_) 

# 取出数据集中的第一列:论文ID
papers = paper_features_label[:,0].astype(np.int32)
# 论文重新编号，并将其映射到论文ID中，实现论文的统一管理
paper2idx = {k:v for v,k in enumerate(papers)}

# 将数据中间部分的字标签取出，转化成矩阵
features = csr_matrix(paper_features_label[:,1:-1],dtype=np.float32)

# 将数据的最后一项的文章分类属性取出，转化为分类的索引
labels = paper_features_label[:,-1]
lbl2idx = {k:v for v,k in enumerate(sorted(np.unique(labels)))}
labels = [lbl2idx[e] for e in labels]

In [42]:
# 读取论文关系数据
edges = np.genfromtxt(path/'cora.cites',dtype=np.int32) # 将数据集中论文的引用关系以数据的形式读入

# 转化为新编号节点间的关系：将数据集中论文ID表示的关系转化为重新编号后的关系
edges = np.asarray([paper2idx[e] for e in edges.flatten()],np.int32).reshape(edges.shape)

# 计算邻接矩阵，行与列都是论文个数：由论文引用关系所表示的图结构生成邻接矩阵。
adj = coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),shape=(len(labels), len(labels)), dtype=np.float32)
# 生成无向图对称矩阵：将有向图的邻接矩阵转化为无向图的邻接矩阵。Tip：转化为无向图的原因：主要用于对论文的分类，论文的引用关系主要提供单个特征之间的关联，故更看重是不是有关系，所以无向图即可。
adj_long = adj.multiply(adj.T < adj)
adj = adj_long + adj_long.T

In [43]:
# 归一化处理

# 定义函数，对矩阵的数据进行归一化处理
def normalize(mx): 
    rowsum = np.array(mx.sum(1)) 
    r_inv = (rowsum ** -1).flatten() 
    r_inv[np.isinf(r_inv)] = 0.0 
    r_mat_inv = diags(r_inv)
    mx = r_mat_inv.dot(mx) 
    return mx

# 对features矩阵进行归一化处理（每行总和为1）
features = normalize(features) 
# 对邻接矩阵的对角线添1，再对其进行归一化处理
adj = normalize(adj + eye(adj.shape[0])) 

In [44]:
# 将数据转化为张量，并分配运算资源
adj = torch.FloatTensor(adj.todense()) # 节点间关系 todense()方法将其转换为矩阵。
features = torch.FloatTensor(features.todense()) # 节点自身的特征
labels = torch.LongTensor(labels) # 对每个节点的分类标签
 
# 划分数据集
n_train = 1500 # 训练数据集大小
n_val = 500 # 验证数据集大小
n_test = len(features) - n_train - n_val # 测试数据集大小
np.random.seed(34)
idxs = np.random.permutation(len(features)) # 将原有的索引打乱顺序
 
# 计算每个数据集的索引
idx_train = torch.LongTensor(idxs[:n_train]) # 根据指定训练数据集的大小并划分出其对应的训练数据集索引
idx_val = torch.LongTensor(idxs[n_train:n_train+n_val])# 根据指定验证数据集的大小并划分出其对应的验证数据集索引
idx_test = torch.LongTensor(idxs[n_train+n_val:])# 根据指定测试数据集的大小并划分出其对应的测试数据集索引
 
# 分配运算资源
adj = adj.to(device)
features = features.to(device)
labels = labels.to(device)
idx_train = idx_train.to(device)
idx_val = idx_val.to(device)
idx_test = idx_test.to(device)

In [None]:
# 定义函数计算准确率
def accuracy(output,y): 
    return (output.argmax(1) == y).type(torch.float32).mean().item()
 
# 定义函数来实现模型的训练过程。与深度学习任务不同，图卷积在训练时需要传入样本间的关系数据。
# 因为该关系数据是与节点数相等的方阵，所以传入的样本数也要与节点数相同，在计算loss值时，可以通过索引从总的运算结果中取出训练集的结果。
# 定义函数来训练模型 Tip：在图卷积任务中，无论是用模型进行预测还是训练，都需要将全部的图结构方阵输入
def step():
    model.train()
    optimizer.zero_grad()
    output = model(features,adj) # 将全部数据载入模型，只用训练数据计算损失
    loss = F.cross_entropy(output[idx_train],labels[idx_train])
    acc = accuracy(output[idx_train],labels[idx_train]) # 计算准确率
    loss.backward()
    optimizer.step()
    return loss.item(),acc
 
 # 定义函数来评估模型
def evaluate(idx): 
    model.eval()
    output = model(features, adj) # 将全部数据载入模型，用指定索引评估模型结果
    loss = F.cross_entropy(output[idx], labels[idx]).item()
    return loss, accuracy(output[idx], labels[idx])

In [56]:
from GCN import *
from tqdm import tqdm
from Ranger import * 
import matplotlib.pyplot as plt

# 分类和关键词数量
n_labels = 7
n_features = features.shape[1] 

model = GCN(n_features, n_labels, hidden=[16, 32, 16]).to(device)
optimizer = Ranger(model.parameters())

# 训练模型
epochs = 2000
print_steps = 50
train_loss, train_acc = [], []
val_loss, val_acc = [], []
for i in tqdm(range(epochs)):
    tl,ta = step()
    train_loss = train_loss + [tl]
    train_acc = train_acc + [ta]
    if (i+1) % print_steps == 0 or i == 0:
        tl,ta = evaluate(idx_train)
        vl,va = evaluate(idx_val)
        val_loss = val_loss + [vl]
        val_acc = val_acc + [va]
        print(f'{i + 1:6d}/{epochs}: train_loss={tl:.4f}, train_acc={ta:.4f}' + f', val_loss={vl:.4f}, val_acc={va:.4f}')
 
# 输出最终结果
final_train, final_val, final_test = evaluate(idx_train), evaluate(idx_val), evaluate(idx_test)
print(f'Train     : loss={final_train[0]:.4f}, accuracy={final_train[1]:.4f}')
print(f'Validation: loss={final_val[0]:.4f}, accuracy={final_val[1]:.4f}')
print(f'Test      : loss={final_test[0]:.4f}, accuracy={final_test[1]:.4f}')
 
# 可视化训练过程
plt.switch_backend('agg')
plt.rc('font',family='Times New Roman', size=15)
fig, axes = plt.subplots(1, 2, figsize=(15,5))
ax = axes[0]
axes[0].plot(train_loss[::print_steps] + [train_loss[-1]], label='Train')
axes[0].plot(val_loss, label='Validation')
axes[1].plot(train_acc[::print_steps] + [train_acc[-1]], label='Train')
axes[1].plot(val_acc, label='Validation')
for ax,t in zip(axes, ['Loss', 'Accuracy']): ax.legend(), ax.set_title(t, size=15)
plt.savefig("./result.jpg")
 

  1%|▏         | 25/2000 [00:00<00:08, 246.75it/s]

     1/2000: train_loss=1.9447, train_acc=0.1360, val_loss=1.9441, val_acc=0.1340


  4%|▎         | 73/2000 [00:00<00:09, 213.71it/s]

    50/2000: train_loss=1.9394, train_acc=0.2073, val_loss=1.9382, val_acc=0.2220


  7%|▋         | 136/2000 [00:00<00:09, 197.52it/s]

   100/2000: train_loss=1.9285, train_acc=0.2627, val_loss=1.9259, val_acc=0.3020


  9%|▉         | 178/2000 [00:00<00:09, 200.19it/s]

   150/2000: train_loss=1.9149, train_acc=0.2867, val_loss=1.9104, val_acc=0.3120


 11%|█         | 222/2000 [00:01<00:08, 204.86it/s]

   200/2000: train_loss=1.8943, train_acc=0.2900, val_loss=1.8868, val_acc=0.3140


 14%|█▍        | 287/2000 [00:01<00:08, 211.74it/s]

   250/2000: train_loss=1.8682, train_acc=0.2900, val_loss=1.8561, val_acc=0.3140


 17%|█▋        | 345/2000 [00:01<00:06, 251.40it/s]

   300/2000: train_loss=1.8428, train_acc=0.2900, val_loss=1.8252, val_acc=0.3140
   350/2000: train_loss=1.8160, train_acc=0.2900, val_loss=1.7924, val_acc=0.3140


 22%|██▏       | 439/2000 [00:01<00:05, 287.23it/s]

   400/2000: train_loss=1.7902, train_acc=0.2900, val_loss=1.7631, val_acc=0.3140
   450/2000: train_loss=1.7615, train_acc=0.2900, val_loss=1.7336, val_acc=0.3140


 27%|██▋       | 537/2000 [00:02<00:04, 311.32it/s]

   500/2000: train_loss=1.7139, train_acc=0.2900, val_loss=1.6860, val_acc=0.3140
   550/2000: train_loss=1.6389, train_acc=0.2907, val_loss=1.6126, val_acc=0.3140


 32%|███▏      | 633/2000 [00:02<00:04, 310.22it/s]

   600/2000: train_loss=1.5386, train_acc=0.3047, val_loss=1.5147, val_acc=0.3260
   650/2000: train_loss=1.3823, train_acc=0.5060, val_loss=1.3640, val_acc=0.5360


 36%|███▋      | 728/2000 [00:02<00:04, 297.27it/s]

   700/2000: train_loss=1.1962, train_acc=0.6287, val_loss=1.1860, val_acc=0.6180
   750/2000: train_loss=1.0382, train_acc=0.6600, val_loss=1.0362, val_acc=0.6460


 42%|████▏     | 837/2000 [00:03<00:04, 234.21it/s]

   800/2000: train_loss=0.8937, train_acc=0.6987, val_loss=0.9019, val_acc=0.6700


 44%|████▍     | 886/2000 [00:03<00:04, 229.96it/s]

   850/2000: train_loss=0.7873, train_acc=0.7353, val_loss=0.8074, val_acc=0.7120


 47%|████▋     | 938/2000 [00:03<00:04, 237.96it/s]

   900/2000: train_loss=0.7150, train_acc=0.7687, val_loss=0.7483, val_acc=0.7380


 49%|████▉     | 986/2000 [00:03<00:04, 232.67it/s]

   950/2000: train_loss=0.6475, train_acc=0.8013, val_loss=0.6959, val_acc=0.7820


 52%|█████▏    | 1038/2000 [00:04<00:04, 238.93it/s]

  1000/2000: train_loss=0.5883, train_acc=0.8133, val_loss=0.6549, val_acc=0.8000


 54%|█████▍    | 1086/2000 [00:04<00:04, 228.29it/s]

  1050/2000: train_loss=0.5406, train_acc=0.8207, val_loss=0.6248, val_acc=0.8060


 57%|█████▋    | 1133/2000 [00:04<00:03, 229.12it/s]

  1100/2000: train_loss=0.4919, train_acc=0.8560, val_loss=0.5962, val_acc=0.8260


 59%|█████▉    | 1179/2000 [00:04<00:03, 220.84it/s]

  1150/2000: train_loss=0.4498, train_acc=0.8800, val_loss=0.5746, val_acc=0.8480


 61%|██████▏   | 1225/2000 [00:05<00:03, 219.37it/s]

  1200/2000: train_loss=0.4154, train_acc=0.8853, val_loss=0.5575, val_acc=0.8500


 65%|██████▍   | 1291/2000 [00:05<00:03, 206.74it/s]

  1250/2000: train_loss=0.3817, train_acc=0.8940, val_loss=0.5440, val_acc=0.8580


 67%|██████▋   | 1341/2000 [00:05<00:02, 227.79it/s]

  1300/2000: train_loss=0.3531, train_acc=0.8987, val_loss=0.5338, val_acc=0.8680
  1350/2000: train_loss=0.3311, train_acc=0.9047, val_loss=0.5308, val_acc=0.8680


 72%|███████▏  | 1431/2000 [00:06<00:02, 211.23it/s]

  1400/2000: train_loss=0.3104, train_acc=0.9100, val_loss=0.5274, val_acc=0.8640


 74%|███████▎  | 1474/2000 [00:06<00:02, 203.92it/s]

  1450/2000: train_loss=0.2911, train_acc=0.9153, val_loss=0.5270, val_acc=0.8620


 77%|███████▋  | 1542/2000 [00:06<00:02, 213.53it/s]

  1500/2000: train_loss=0.2768, train_acc=0.9200, val_loss=0.5305, val_acc=0.8600


 79%|███████▉  | 1586/2000 [00:06<00:01, 213.58it/s]

  1550/2000: train_loss=0.2623, train_acc=0.9240, val_loss=0.5335, val_acc=0.8660


 82%|████████▏ | 1636/2000 [00:06<00:01, 230.32it/s]

  1600/2000: train_loss=0.2482, train_acc=0.9267, val_loss=0.5396, val_acc=0.8620
  1650/2000: train_loss=0.2381, train_acc=0.9287, val_loss=0.5457, val_acc=0.8540


 86%|████████▋ | 1725/2000 [00:07<00:01, 271.58it/s]

  1700/2000: train_loss=0.2265, train_acc=0.9307, val_loss=0.5542, val_acc=0.8540


 89%|████████▉ | 1779/2000 [00:07<00:00, 231.91it/s]

  1750/2000: train_loss=0.2165, train_acc=0.9347, val_loss=0.5618, val_acc=0.8580


 91%|█████████▏| 1826/2000 [00:07<00:00, 218.91it/s]

  1800/2000: train_loss=0.2073, train_acc=0.9367, val_loss=0.5684, val_acc=0.8520


 95%|█████████▍| 1894/2000 [00:08<00:00, 217.44it/s]

  1850/2000: train_loss=0.1979, train_acc=0.9400, val_loss=0.5766, val_acc=0.8500


 97%|█████████▋| 1938/2000 [00:08<00:00, 212.11it/s]

  1900/2000: train_loss=0.1893, train_acc=0.9427, val_loss=0.5862, val_acc=0.8500


 99%|█████████▉| 1984/2000 [00:08<00:00, 220.43it/s]

  1950/2000: train_loss=0.1821, train_acc=0.9447, val_loss=0.5942, val_acc=0.8500


100%|██████████| 2000/2000 [00:08<00:00, 233.23it/s]
findfont: Font family ['Times New Roman'] not found. Falling back to DejaVu Sans.


  2000/2000: train_loss=0.1751, train_acc=0.9453, val_loss=0.6021, val_acc=0.8500
Train     : loss=0.1751, accuracy=0.9453
Validation: loss=0.6021, accuracy=0.8500
Test      : loss=0.6329, accuracy=0.8573
