In [1]:
# https://zhuanlan.zhihu.com/p/344067462
# https://zhuanlan.zhihu.com/p/142205899
import torch as torch
import dgl
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
# 定义图神经网络GraphSAGE
from dgl.nn.pytorch import SAGEConv

In [2]:
class GraphSAGE(nn.Module):
    def __init__(self, 
                 in_feats,
                 n_hidden, # hidden size也可以是一个list
                 n_classes,
                 n_layers,
                 activation,
                 dropout,
                 aggregator):
        
        super(GraphSAGE, self).__init__()
        self.n_layers = n_layers
        self.n_hidden = n_hidden
        self.n_classes = n_classes
        self.layer = nn.ModuleList()
        self.layer.append(SAGEConv(in_feats, n_hidden, aggregator))
        for i in range(1, n_layers - 1):
            self.layer.append(SAGEConv(n_hidden, n_hidden, aggregator))
        self.layer.append(SAGEConv(n_hidden, n_classes, aggregator))
        self.dropout = nn.Dropout(dropout)
        self.activation = activation
    
    def forward(self, blocks, feas):
        h = feas
        for i, (layer, block) in enumerate(zip(self.layer, blocks)):
            h = layer(block, h)
            if i != self.n_layers - 1:
                h = self.activation(h)
                h = self.dropout(h)
        return h
    
    def inference(self, my_net, val_nid, batch_s, num_worker, device):
        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(self.n_layers)
        dataloader = dgl.dataloading.DataLoader(
                    my_net,
                    val_nid,
                    sampler,
                    batch_size = batch_s,
                    shuffle=True,
                    drop_last=False,
                    num_workers=num_worker
                )
        
        ret = torch.zeros(my_net.num_nodes(), self.n_classes)
        
        for input_nodes, output_nodes, blocks in dataloader:
            h = blocks[0].srcdata['features'].to(device)
            for i, (layer, block) in enumerate(zip(self.layer, blocks)):
                block = block.int().to(device)
                h = layer(block, h)
                if i != self.n_layers - 1 :
                    h = self.activation(h)
                    h = self.dropout(h)
            ret[output_nodes] = h.cpu()
        return ret

In [3]:
def load_subtensor(nfeat, labels, seeds, input_nodes, device):
    """
    Extracts features and labels for a subset of nodes.
    """
    batch_inputs = nfeat[input_nodes].to(device)
    batch_labels = labels[seeds].to(device)
    return batch_inputs, batch_labels

def evaluate(model, my_net, labels, val_nid, val_mask, batch_s, num_worker, device):
    
    model.eval()
    with torch.no_grad():
        label_pred = model.inference(my_net, val_nid,  batch_s, num_worker, device)
    model.train()
    return (torch.argmax(label_pred[val_mask], dim=1) == labels[val_mask]).float().sum() / len(label_pred[val_mask])

In [4]:
# 写一个分batch训练的过程
import itertools

def run(data, train_val_data, args, sample_size, learning_rate, device):
    in_feats, n_classes, my_net, fea_para = data
    hidden_size, n_layers, activation, dropout, aggregator, batch_s, num_worker  = args
    
#     my_net = my_net.to(device)
    
    # 设置一下训练集和测试集，val_mask
    train_mask, test_mask, val_mask, train_nid, test_nid, val_nid = train_val_data
    
    # 训练模型的过程
    nfeat = my_net.ndata['features']
    labels = my_net.ndata['label']
    sampler = dgl.dataloading.MultiLayerNeighborSampler(sample_size)
    
    dataloader = dgl.dataloading.DataLoader(
        my_net,
        train_nid,
        sampler,
        batch_size = batch_s,
        shuffle=True,
        drop_last=False,
        num_workers=num_worker
    )
    
    model = GraphSAGE(in_feats, hidden_size, n_classes, n_layers, activation, dropout, aggregator)
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
#     optimizer = torch.optim.Adam(itertools.chain(model.parameters(), fea_para.parameters()), lr=0.01)
    model.train()
    
    loss_fun = nn.CrossEntropyLoss()
    loss_fun.to(device)
    
    for epoch in range(100):
        print("************************************************************")
        for batch, (input_nodes, output_nodes, block) in enumerate(dataloader):
        
#             batch_feature = block[0].srcdata['features'] 这里也可以图省事不写load函数，感觉没啥问题
#             batch_label = block[-1].dstdata['label']
            batch_feature, batch_label = load_subtensor(nfeat, labels, output_nodes, input_nodes, device)
            block = [block_.int().to(device) for block_ in block]
            # block = [block_.to(device) for block_ in block]
            model_pred = model(block, batch_feature)
            loss = loss_fun(model_pred, batch_label)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            
            if batch % 1 == 0:
                print('Batch %d | Loss: %.4f' % (batch, loss.item()))
        
        # 验证一下模型的准确率
        if epoch % 10 == 0:
            print("____________________________________________________________")
            val_acc = evaluate(model, my_net, labels, val_nid, val_mask, batch_s, num_worker, device)
            train_acc = evaluate(model, my_net, labels, train_nid, train_mask, batch_s, num_worker, device)
            print('Epoch %d | Val ACC: %.4f | Train ACC: %.4f' % (epoch, val_acc.item(), train_acc.item()))
    
    # 模型训练完毕，检查test集合的acc
    acc_test = evaluate(model, my_net, labels, test_nid, test_mask, batch_s, num_worker, device)
    print('Test ACC: %.4f' % (acc_test.item()))
    return model

In [5]:
def train_val_split(node_fea):
    
    # 简单划分下训练集
    train_node_ids = np.array(node_fea.groupby('label_number').apply(lambda x : x.sort_values('node_id_number')['node_id_number'].values[:20]))
    val_node_ids = np.array(node_fea.groupby('label_number').apply(lambda x : x.sort_values('node_id_number')['node_id_number'].values[21:110]))
    test_node_ids = np.array(node_fea.groupby('label_number').apply(lambda x : x.sort_values('node_id_number')['node_id_number'].values[111:300]))


    train_nid = []
    val_nid = []
    test_nid = []
    for (train_nodes, val_nodes, test_nodes) in zip(train_node_ids, val_node_ids, test_node_ids):
        train_nid.extend(train_nodes)
        val_nid.extend(val_nodes)
        test_nid.extend(test_nodes)
        
    train_mask = node_fea['node_id_number'].apply(lambda x : x in train_nid)
    val_mask = node_fea['node_id_number'].apply(lambda x : x in val_nid)
    test_mask = node_fea['node_id_number'].apply(lambda x : x in test_nid)
    
    return train_mask, test_mask, val_mask, train_nid, test_nid, val_nid

def loaddata():
    node_fea = pd.read_table('/home/ray/code/python/GNN-from-Scratch/GNN/data/cora/cora.content', header=None)
    edges = pd.read_table('/home/ray/code/python/GNN-from-Scratch/GNN/data/cora/cora.cites', header=None)
    # 0是node id， 1434是node label
    node_fea.rename(columns={0:'node_id', 1434:'label'}, inplace=True)

    nodeID_number_dict = dict(zip(node_fea['node_id'].unique(), range(node_fea['node_id'].nunique())))
    node_fea['node_id_number'] = node_fea['node_id'].map(nodeID_number_dict)
    edges['edge1'] = edges[0].map(nodeID_number_dict)
    edges['edge2'] = edges[1].map(nodeID_number_dict)

    label_dict = dict(zip(node_fea['label'].unique(), range(node_fea['label'].nunique())))
    node_fea['label_number'] = node_fea['label'].map(label_dict)
    
    src = np.array(edges['edge1'].values)
    dst = np.array(edges['edge2'].values)

    u = np.concatenate([src, dst])
    v = np.concatenate([dst, src])

    my_net = dgl.DGLGraph((u, v))

    # 尝试第一种，用Embedding方法得到的相关
    fea_id = range(1, 1434)
    tensor_fea = torch.tensor(node_fea[fea_id].values, dtype=torch.float32)

    fea_np = nn.Embedding(2708, 1433)
    fea_np.weight = nn.Parameter(tensor_fea)

#     my_net.ndata['features'] =  torch.tensor(node_fea[fea_id].values, dtype=torch.float32)
    my_net.ndata['features'] = fea_np.weight
    my_net.ndata['label'] = torch.tensor(node_fea['label_number'].values)
    
    in_feats = 1433
    n_classes = node_fea['label'].nunique()
    
    data = in_feats, n_classes, my_net, fea_np
    train_val_data = train_val_split(node_fea)
    
    return data, train_val_data

In [6]:
# 参数设置

data, train_val_data = loaddata()

hidden_size = 16
n_layers = 2
sample_size = [10, 25]
activation = F.relu
dropout = 0.5
aggregator = 'mean'
batch_s = 128
num_worker = 0
learning_rate = 0.003
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
args =  hidden_size, n_layers, activation, dropout, aggregator, batch_s, num_worker  
trained_model = run(data, train_val_data, args, sample_size, learning_rate, device)



************************************************************




DGLError: [07:36:56] /opt/dgl/src/runtime/c_runtime_api.cc:82: Check failed: allow_missing: Device API cuda is not enabled. Please install the cuda version of dgl.
Stack trace:
  [bt] (0) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dmlc::LogMessageFatal::~LogMessageFatal()+0x75) [0x7f063d33e8f5]
  [bt] (1) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPIManager::GetAPI(std::string, bool)+0x202) [0x7f063d6ada92]
  [bt] (2) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::runtime::DeviceAPI::Get(DGLContext, bool)+0x1e1) [0x7f063d6aa071]
  [bt] (3) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::Empty(std::vector<long, std::allocator<long> >, DGLDataType, DGLContext)+0x13b) [0x7f063d6c554b]
  [bt] (4) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::runtime::NDArray::CopyTo(DGLContext const&) const+0xc3) [0x7f063d6ffd53]
  [bt] (5) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::UnitGraph::CSR::CopyTo(DGLContext const&) const+0x1f0) [0x7f063d81db10]
  [bt] (6) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::UnitGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0xd1) [0x7f063d80cf21]
  [bt] (7) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(dgl::HeteroGraph::CopyTo(std::shared_ptr<dgl::BaseHeteroGraph>, DGLContext const&)+0xf6) [0x7f063d70c5d6]
  [bt] (8) /home/ray/code/test/PDE/PeRCNN/percnn/lib/python3.8/site-packages/dgl/libdgl.so(+0x51b396) [0x7f063d71b396]

