## 说明

&emsp;&emsp;源自书籍《深入浅出图神经网络》第五章，使用GCN对Cora数据集进行节点分类，实战代码。

**Step 1:** 导入必要的库文件

In [1]:
import itertools
import os
import os.path as osp
import pickle
import urllib
from collections import namedtuple

import numpy as np
import scipy.sparse as sp
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
import torch.optim as optim

**Step 2:** 下载、处理并缓存 Cora 数据集

In [2]:
# Record holding the processed dataset: node features (x), integer labels (y),
# the sparse adjacency matrix, and boolean train/val/test node masks.
Data = namedtuple('Data', ['x','y','adjacency','train_mask','val_mask','test_mask'])

class CoraData(object):
    """Download, preprocess and cache the Cora citation dataset.

    The raw ``ind.cora.*`` files are fetched from ``download_url`` into
    ``<data_root>/raw``.  The processed result (a ``Data`` namedtuple) is
    pickled to ``<data_root>/processed_cora.pkl`` and reused on later runs
    unless ``rebuild`` is true.
    """

    download_url = "https://github.com/kimiyoung/planetoid/raw/master/data"
    filename = ["ind.cora.{}".format(name) for name in
                ['x', 'tx', 'allx', 'y', 'ty', 'ally', 'graph', 'test.index']]

    def __init__(self, data_root="cora", rebuild=False):
        """Load the cached dataset, or build and cache it.

        Args:
            data_root: Directory holding the ``raw`` files and the cache.
            rebuild: When true, ignore an existing cache file and reprocess.
        """
        self.data_root = data_root
        save_file = osp.join(self.data_root, "processed_cora.pkl")

        if osp.exists(save_file) and not rebuild:
            print("Using Cached file: {}".format(save_file))
            # NOTE: pickle executes arbitrary code on load; only use cache
            # files produced by this class.
            with open(save_file, "rb") as f:
                self._data = pickle.load(f)
        else:
            self.maybe_download()
            self._data = self.process_data()
            with open(save_file, "wb") as f:
                pickle.dump(self.data, f)

            print("Cached file: {}".format(save_file))

    @property
    def data(self):
        """The processed dataset as a ``Data`` namedtuple."""
        return self._data

    def maybe_download(self):
        """Download every raw file that is not yet present under ``raw``."""
        save_path = osp.join(self.data_root, "raw")
        for name in self.filename:
            if not osp.exists(osp.join(save_path, name)):
                self.download_data("{}/{}".format(self.download_url, name), save_path)

    @staticmethod
    def download_data(url, save_path):
        """Fetch ``url`` and store it in ``save_path`` under its base name.

        Returns:
            True once the file has been written.
        """
        # ``import urllib`` alone does not guarantee the ``request`` submodule
        # is loaded, so import it explicitly where it is used.
        import urllib.request

        os.makedirs(save_path, exist_ok=True)
        target = osp.join(save_path, osp.basename(url))

        # Read the whole payload before opening the target so that a failed
        # download does not leave an empty file that later looks "present".
        with urllib.request.urlopen(url) as response:
            payload = response.read()
        with open(target, "wb") as f:
            f.write(payload)

        return True

    def process_data(self):
        """Assemble features, labels, adjacency and split masks.

        Training nodes are the first ``y.shape[0]`` rows, validation nodes the
        following 500 rows, and test nodes are those listed in the
        ``test.index`` file.

        Returns:
            A ``Data`` namedtuple.
        """
        print("Processing data...")
        _, tx, allx, y, ty, ally, graph, test_index = [
            self.read_data(osp.join(self.data_root, "raw", name))
            for name in self.filename
        ]

        train_index = np.arange(y.shape[0])
        val_index = np.arange(y.shape[0], y.shape[0] + 500)
        sorted_test_index = sorted(test_index)

        x = np.concatenate((allx, tx), axis=0)
        # One-hot label rows are collapsed to integer class ids.
        y = np.concatenate((ally, ty), axis=0).argmax(axis=1)

        # Presumably the test rows are stored in ``test.index`` order; this
        # moves each one to the row matching its node id (verify against the
        # dataset description).
        x[test_index] = x[sorted_test_index]
        y[test_index] = y[sorted_test_index]

        num_nodes = x.shape[0]

        # ``np.bool`` was removed in NumPy 1.24; the builtin ``bool`` yields
        # the same dtype.
        train_mask = np.zeros(num_nodes, dtype=bool)
        val_mask = np.zeros(num_nodes, dtype=bool)
        test_mask = np.zeros(num_nodes, dtype=bool)

        train_mask[train_index] = True
        val_mask[val_index] = True
        test_mask[test_index] = True

        adjacency = self.build_adjacency(graph)

        print("Node's feature shape: ", x.shape)
        print("Node's label shape: ", y.shape)
        print("Adjacency's shape: ", adjacency.shape)
        print("Number of training nodes: ", train_mask.sum())

        return Data(x=x, y=y, adjacency=adjacency, train_mask=train_mask,
                    val_mask=val_mask, test_mask=test_mask)

    @staticmethod
    def build_adjacency(adj_dict):
        """Build a symmetric sparse adjacency matrix from an adjacency dict.

        Args:
            adj_dict: Mapping ``node -> iterable of neighbour nodes``.  The
                number of keys is taken as the number of nodes.

        Returns:
            A ``scipy.sparse.coo_matrix`` of dtype float32 with a 1 for every
            distinct edge, stored in both directions.
        """
        num_nodes = len(adj_dict)

        # A set removes repeated edges; without that, duplicate COO entries
        # would be summed into values greater than 1.
        edges = set()
        for src, neighbours in adj_dict.items():
            for dst in neighbours:
                edges.add((src, dst))
                edges.add((dst, src))

        # reshape keeps the array two-dimensional even when there are no edges.
        edge_index = np.array(sorted(edges), dtype=np.int64).reshape(-1, 2)
        adjacency = sp.coo_matrix(
            (np.ones(len(edge_index)), (edge_index[:, 0], edge_index[:, 1])),
            shape=(num_nodes, num_nodes), dtype="float32")

        return adjacency

    @staticmethod
    def read_data(path):
        """Read one raw file.

        ``ind.cora.test.index`` is parsed as a text file of int64 values.
        Every other file is unpickled, and objects exposing ``toarray`` are
        converted to dense arrays.
        """
        name = osp.basename(path)
        if name == "ind.cora.test.index":
            return np.genfromtxt(path, dtype='int64')

        # NOTE: unpickling downloaded data executes arbitrary code; only use
        # files from a trusted source.
        with open(path, "rb") as f:
            out = pickle.load(f, encoding="latin1")
        return out.toarray() if hasattr(out, "toarray") else out

<font color='red'>**Step 3:** 搭建GCN卷积层</font>

In [3]:
class GraphConvolution(nn.Module):
    """Single graph-convolution layer computing ``adjacency @ (X @ W) + b``."""

    def __init__(self, input_dim, output_dim, use_bias=True):
        """Create the layer parameters.

        Args:
            input_dim: Number of input features per node.
            output_dim: Number of output features per node.
            use_bias: Whether to add a learnable bias vector to the output.
        """
        super().__init__()

        self.input_dim, self.output_dim = input_dim, output_dim
        self.use_bias = use_bias

        # Storage is allocated uninitialised; reset_parameters() fills it in.
        self.weights = nn.Parameter(torch.empty(input_dim, output_dim))
        if not self.use_bias:
            self.register_parameter('bias', None)
        else:
            self.bias = nn.Parameter(torch.empty(output_dim))

        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-normal initialise the weights and zero the bias (if any)."""
        init.kaiming_normal_(self.weights)
        if not self.use_bias:
            return
        init.zeros_(self.bias)

    def forward(self, adjacency, input_feature):
        """Project the features, then aggregate them over the graph.

        Args:
            adjacency: Sparse matrix passed as the first operand of
                ``torch.sparse.mm``.
            input_feature: Dense node-feature matrix.

        Returns:
            The aggregated features, with the bias added when enabled.
        """
        projected = torch.mm(input_feature, self.weights)
        aggregated = torch.sparse.mm(adjacency, projected)
        if not self.use_bias:
            return aggregated
        aggregated += self.bias
        return aggregated

<font color='red'>**Step 4:** 搭建GCN整体模型</font>

In [4]:
class GCN_net(nn.Module):
    """Two-layer GCN: graph convolution -> ReLU -> graph convolution."""

    def __init__(self, input_dim=1433, hidden_dim=16, num_classes=7):
        """Build the two graph-convolution layers.

        Args:
            input_dim: Number of input features per node.
            hidden_dim: Width of the hidden layer.  The default keeps the
                value that was previously hard-coded (16).
            num_classes: Number of output logits per node.  The default keeps
                the value that was previously hard-coded (7).
        """
        super(GCN_net, self).__init__()
        self.gcn1 = GraphConvolution(input_dim, hidden_dim)
        self.gcn2 = GraphConvolution(hidden_dim, num_classes)

    def forward(self, adjacency, feature):
        """Return per-node class logits (no softmax is applied here)."""
        h = F.relu(self.gcn1(adjacency, feature))
        logits = self.gcn2(adjacency, h)
        return logits
    

**Step 5:** 设置超参数，准备数据并构建模型

In [5]:
def normalizetion(adjacency):
    """Symmetrically normalise an adjacency matrix with added self-loops.

    Computes ``D^-1/2 (A + I) D^-1/2`` where ``D`` holds the row sums of
    ``A + I``, and returns the result in COO format.
    """
    num_nodes = adjacency.shape[0]
    adjacency += sp.eye(num_nodes)

    # Row sums of (A + I); after adding self-loops each row sum includes the
    # node's own entry.
    row_sums = np.array(adjacency.sum(1))
    inv_sqrt_degree = np.power(row_sums, -0.5).flatten()
    scaling = sp.diags(inv_sqrt_degree)

    left_scaled = scaling.dot(adjacency)
    return left_scaled.dot(scaling).tocoo()

# Optimisation hyperparameters.
learning_rate = 0.1
weight_decay = 5e-4
epochs = 200

# Use the GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = GCN_net().to(device)

criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

dataset = CoraData().data

# Row-normalise the features so every row sums to 1.  Rows that sum to zero
# are divided by 1 instead, which leaves them all-zero rather than NaN.
feature_row_sum = dataset.x.sum(1, keepdims=True)
feature_row_sum[feature_row_sum == 0] = 1
x = dataset.x / feature_row_sum

tensor_x = torch.from_numpy(x).to(device)
tensor_y = torch.from_numpy(dataset.y).to(device)

tensor_train_mask = torch.from_numpy(dataset.train_mask).to(device)
tensor_val_mask = torch.from_numpy(dataset.val_mask).to(device)
tensor_test_mask = torch.from_numpy(dataset.test_mask).to(device)

normalize_adjacency = normalizetion(dataset.adjacency)

# Convert the SciPy COO matrix to a torch sparse tensor: a 2 x nnz int64 index
# array plus a float32 value vector.
indices = torch.from_numpy(
    np.vstack((normalize_adjacency.row, normalize_adjacency.col)).astype(np.int64))
values = torch.from_numpy(normalize_adjacency.data.astype(np.float32))

# torch.sparse_coo_tensor replaces the deprecated torch.sparse.FloatTensor
# constructor, and the size comes from the matrix itself instead of a
# hard-coded node count.
tensor_adjacency = torch.sparse_coo_tensor(
    indices, values, normalize_adjacency.shape).to(device)

Using Cached file: cora\processed_cora.pkl


In [6]:
def train():
    """Train the global ``model`` for ``epochs`` full-graph steps.

    Each epoch runs one forward/backward pass over the training nodes, then
    evaluates accuracy on the training and validation masks and prints a
    progress line.

    Returns:
        A tuple ``(loss_history, val_acc_history)`` of per-epoch Python floats.
    """
    loss_history = []
    val_acc_history = []
    train_y = tensor_y[tensor_train_mask]

    for epoch in range(epochs):
        # test() switches the model to eval mode, so training mode must be
        # restored at the start of every epoch, not just once before the loop.
        model.train()

        logits = model(tensor_adjacency, tensor_x)
        train_mask_logits = logits[tensor_train_mask]

        loss = criterion(train_mask_logits, train_y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        train_acc = test(tensor_train_mask)
        val_acc = test(tensor_val_mask)

        loss_history.append(loss.item())
        val_acc_history.append(val_acc.item())

        # All three metrics use fixed four-decimal formatting; TrainACC
        # previously used "{:.4}", which printed values such as "0.95".
        print("Epoch {:03d} : Loss {:.4f}, TrainACC {:.4f}, ValACC {:.4f}".format(
                epoch, loss.item(), train_acc.item(), val_acc.item()))

    return loss_history, val_acc_history

def test(mask):
    """Return the global model's accuracy on the nodes selected by ``mask``.

    Switches the model to eval mode (and leaves it there), runs a full-graph
    forward pass without gradient tracking, and compares the arg-max class of
    each selected node with its label.

    Args:
        mask: Boolean tensor used to index the logits and ``tensor_y``.

    Returns:
        A scalar tensor holding the fraction of correct predictions.
    """
    model.eval()
    with torch.no_grad():
        selected_logits = model(tensor_adjacency, tensor_x)[mask]
        _, predicted = selected_logits.max(1)
        hits = (predicted == tensor_y[mask]).float()
        accuracy = hits.mean()

    return accuracy


In [7]:
# Run the training loop. As the last expression of the cell, the returned
# (loss_history, val_acc_history) tuple is echoed as the cell output.
train()

Epoch 000 : Loss 1.9410, TrainACC 0.1786, ValACC 0.0920
Epoch 001 : Loss 1.8645, TrainACC 0.5857, ValACC 0.4380
Epoch 002 : Loss 1.7614, TrainACC 0.8429, ValACC 0.5960
Epoch 003 : Loss 1.6250, TrainACC 0.8214, ValACC 0.5820
Epoch 004 : Loss 1.4824, TrainACC 0.8929, ValACC 0.6620
Epoch 005 : Loss 1.3102, TrainACC 0.9143, ValACC 0.7480
Epoch 006 : Loss 1.1420, TrainACC 0.9286, ValACC 0.7360
Epoch 007 : Loss 0.9759, TrainACC 0.9357, ValACC 0.7280
Epoch 008 : Loss 0.8250, TrainACC 0.9357, ValACC 0.7480
Epoch 009 : Loss 0.6890, TrainACC 0.9571, ValACC 0.7760
Epoch 010 : Loss 0.5763, TrainACC 0.95, ValACC 0.7820
Epoch 011 : Loss 0.4823, TrainACC 0.9786, ValACC 0.7760
Epoch 012 : Loss 0.4087, TrainACC 0.9857, ValACC 0.7840
Epoch 013 : Loss 0.3521, TrainACC 0.9857, ValACC 0.7920
Epoch 014 : Loss 0.3076, TrainACC 0.9786, ValACC 0.7880
Epoch 015 : Loss 0.2766, TrainACC 0.9786, ValACC 0.7900
Epoch 016 : Loss 0.2517, TrainACC 0.9857, ValACC 0.7900
Epoch 017 : Loss 0.2339, TrainACC 0.9857, ValACC 0

Epoch 154 : Loss 0.1157, TrainACC 1.0, ValACC 0.7980
Epoch 155 : Loss 0.1158, TrainACC 1.0, ValACC 0.7920
Epoch 156 : Loss 0.1158, TrainACC 1.0, ValACC 0.8000
Epoch 157 : Loss 0.1158, TrainACC 1.0, ValACC 0.7960
Epoch 158 : Loss 0.1157, TrainACC 1.0, ValACC 0.7960
Epoch 159 : Loss 0.1157, TrainACC 1.0, ValACC 0.7980
Epoch 160 : Loss 0.1157, TrainACC 1.0, ValACC 0.7900
Epoch 161 : Loss 0.1157, TrainACC 1.0, ValACC 0.8020
Epoch 162 : Loss 0.1156, TrainACC 1.0, ValACC 0.7960
Epoch 163 : Loss 0.1156, TrainACC 1.0, ValACC 0.7960
Epoch 164 : Loss 0.1157, TrainACC 1.0, ValACC 0.8020
Epoch 165 : Loss 0.1157, TrainACC 1.0, ValACC 0.7980
Epoch 166 : Loss 0.1158, TrainACC 1.0, ValACC 0.8000
Epoch 167 : Loss 0.1156, TrainACC 1.0, ValACC 0.7940
Epoch 168 : Loss 0.1156, TrainACC 1.0, ValACC 0.7980
Epoch 169 : Loss 0.1155, TrainACC 1.0, ValACC 0.7980
Epoch 170 : Loss 0.1156, TrainACC 1.0, ValACC 0.7940
Epoch 171 : Loss 0.1157, TrainACC 1.0, ValACC 0.7940
Epoch 172 : Loss 0.1158, TrainACC 1.0, ValACC 

([1.9409539699554443,
  1.8645037412643433,
  1.761406421661377,
  1.625008463859558,
  1.4823923110961914,
  1.3102192878723145,
  1.141982078552246,
  0.9758604764938354,
  0.8250491619110107,
  0.6890226602554321,
  0.5763481855392456,
  0.4823465049266815,
  0.4086972773075104,
  0.352145791053772,
  0.30763059854507446,
  0.2766127288341522,
  0.2517257034778595,
  0.23393183946609497,
  0.2223677784204483,
  0.2128671407699585,
  0.20768193900585175,
  0.2022581696510315,
  0.19869013130664825,
  0.19375932216644287,
  0.1899913251399994,
  0.18426474928855896,
  0.17939002811908722,
  0.1732550859451294,
  0.16800399124622345,
  0.16197843849658966,
  0.15734148025512695,
  0.15267746150493622,
  0.14915849268436432,
  0.14576983451843262,
  0.1430133879184723,
  0.14083564281463623,
  0.13859152793884277,
  0.13690851628780365,
  0.13495668768882751,
  0.1332520693540573,
  0.1317756623029709,
  0.13026569783687592,
  0.12986071407794952,
  0.12955424189567566,
  0.130538612604