##  GIN模型实现
*参考资料：*
* [源码地址](https://github.com/weihua916/powerful-gnns)

### 导入所需python文件和库

In [2]:
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np

from tqdm import tqdm

from util import load_data, separate_data
from models.graphcnn import GraphCNN

### 参数设置

In [3]:
# Training settings
# Note: Hyper-parameters need to be tuned in order to obtain results reported in the paper.
parser = argparse.ArgumentParser(description='PyTorch graph convolutional neural net for whole-graph classification')
parser.add_argument('--dataset', type=str, default="MUTAG",help='name of dataset (default: MUTAG)')
parser.add_argument('--device', type=int, default=0,help='which gpu to use if any (default: 0)')
parser.add_argument('--batch_size', type=int, default=32,help='input batch size for training (default: 32)')
parser.add_argument('--iters_per_epoch', type=int, default=50,help='number of iterations per each epoch (default: 50)')
parser.add_argument('--epochs', type=int, default=50,help='number of epochs to train (default: 350)')
parser.add_argument('--lr', type=float, default=0.01,help='learning rate (default: 0.01)')
parser.add_argument('--seed', type=int, default=0,help='random seed for splitting the dataset into 10 (default: 0)')
parser.add_argument('--fold_idx', type=int, default=0, help='the index of fold in 10-fold validation. Should be less then 10.')
parser.add_argument('--num_layers', type=int, default=5, help='number of layers INCLUDING the input one (default: 5)')
parser.add_argument('--num_mlp_layers', type=int, default=2, help='number of layers for MLP EXCLUDING the input one (default: 2). 1 means linear model.')
parser.add_argument('--hidden_dim', type=int, default=64, help='number of hidden units (default: 64)')
parser.add_argument('--final_dropout', type=float, default=0.5, help='final layer dropout (default: 0.5)')
parser.add_argument('--graph_pooling_type', type=str, default="sum", choices=["sum", "average"],help='Pooling for over nodes in a graph: sum or average')
parser.add_argument('--neighbor_pooling_type', type=str, default="sum", choices=["sum", "average", "max"],help='Pooling for over neighboring nodes: sum, average or max')
parser.add_argument('--learn_eps', action="store_true", help='Whether to learn the epsilon weighting for the center nodes. Does not affect training accuracy though.')
parser.add_argument('--degree_as_tag', action="store_true", help='let the input node features be the degree of nodes (heuristics for unlabeled graph)')
parser.add_argument('--filename', type = str, default = "output file",help='output file')

# parser.add_argument('--config', type=str, default='./experiments.conf')   # 获取的一些配置文件
#args = parser.parse_args()                                               # pychram 中使用
args = parser.parse_args(args=[])
args


Namespace(dataset='MUTAG', device=0, batch_size=32, iters_per_epoch=50, epochs=50, lr=0.01, seed=0, fold_idx=0, num_layers=5, num_mlp_layers=2, hidden_dim=64, final_dropout=0.5, graph_pooling_type='sum', neighbor_pooling_type='sum', learn_eps=False, degree_as_tag=False, filename='output file')

### 定义损失函数，交叉熵损失函数

In [4]:
criterion = nn.CrossEntropyLoss()

### 固定随机数种子

In [5]:
#set up seeds and gpu device
torch.manual_seed(0)
np.random.seed(0)    
device = torch.device("cuda:" + str(args.device)) if torch.cuda.is_available() else torch.device("cpu")
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)
device

device(type='cpu')

### 数据集处理

In [6]:
graphs, num_classes = load_data(args.dataset, args.degree_as_tag)


loading data
# classes: 2
# maximum node tag: 7
# data: 188


In [7]:
print(graphs[0].label)
print(graphs[0].g)
print(graphs[0].node_tags)
print(graphs[0].neighbors)
print(graphs[0].node_features.shape)
print(graphs[0].edge_mat.shape)
print(graphs[0].max_neighbor)

0
Graph with 23 nodes and 27 edges
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2]
[[1, 13], [0, 2], [1, 3, 11], [2, 4], [3, 5], [4, 6, 10], [5, 7, 20], [6, 8], [7, 9], [10, 8, 15], [11, 5, 9], [2, 10, 12], [13, 11, 14], [0, 12], [15, 12, 19], [9, 14, 16], [15, 17], [16, 18], [19, 17], [14, 18], [6, 21, 22], [20], [20]]
torch.Size([23, 7])
torch.Size([2, 54])
3


In [8]:
print(graphs[0].g.nodes)
print(graphs[0].g.edges)

[0, 1, 13, 2, 3, 11, 4, 5, 6, 10, 7, 20, 8, 9, 15, 12, 14, 19, 16, 17, 18, 21, 22]
[(0, 1), (0, 13), (1, 2), (13, 12), (2, 3), (2, 11), (3, 4), (11, 10), (11, 12), (4, 5), (5, 6), (5, 10), (6, 7), (6, 20), (10, 9), (7, 8), (20, 21), (20, 22), (8, 9), (9, 15), (15, 14), (15, 16), (12, 14), (14, 19), (19, 18), (16, 17), (17, 18)]


### 拆分数据集，训练集和测试集

In [9]:
##10-fold cross validation. Conduct an experiment on the fold specified by args.fold_idx.
train_graphs, test_graphs = separate_data(graphs, args.seed, args.fold_idx)
len(train_graphs)

169

In [10]:
len(test_graphs)

19

### GIN模型中的 maxpooling 操作测试(可忽略)

In [11]:
max_deg = max([graph.max_neighbor for graph in test_graphs]) # 计算批次数据中的最大邻居数量
max_deg

3

In [12]:
#create a list of padded neighbor lists
padded_neighbor_list = [] # 用于存储填充后的邻居列表。
start_idx = [0]  # 用于存储每个图的起始索引。
flag = 0
#iterate through the graphs in the current minibatch
for i, graph in enumerate(test_graphs):
    #increment the start index of the current graph
    # print(start_idx[i],len(graph.g))
    start_idx.append(start_idx[i] + len(graph.g))
    # print(start_idx)
    #create a list of padded neighbors
    padded_neighbors = []
    #iterate through the neighbors of the current graph
    for j in range(len(graph.neighbors)):
        #add off-set values to the neighbor indices
        pad = [n + start_idx[i] for n in graph.neighbors[j]]
        #padding, dummy data is assumed to be stored in -1
        pad.extend([-1]*(max_deg - len(pad)))
        # print(pad)
        #append the padded neighbor to the list of padded neighbors
        padded_neighbors.append(pad)
        # print(padded_neighbors)
    #append the list of padded neighbors to the list of padded neighbor lists
    padded_neighbor_list.extend(padded_neighbors)
    # print(padded_neighbor_list)
    flag += 1
    # if flag == 2:
    #     break
print(padded_neighbor_list)

[[1, 13, -1], [0, 2, -1], [1, 3, 11], [2, 4, -1], [3, 5, -1], [4, 6, 10], [5, 7, 20], [6, 8, -1], [7, 9, -1], [10, 8, 15], [11, 5, 9], [2, 10, 12], [13, 11, 14], [0, 12, -1], [15, 12, 19], [9, 14, 16], [15, 17, -1], [16, 18, -1], [19, 17, -1], [14, 18, -1], [6, 21, 22], [20, -1, -1], [20, -1, -1], [24, 36, -1], [23, 25, -1], [24, 26, 34], [25, 27, -1], [26, 28, -1], [27, 29, 33], [28, 30, 48], [29, 31, -1], [30, 32, -1], [33, 31, 41], [34, 28, 32], [25, 33, 35], [36, 34, 40], [23, 35, 37], [36, 38, -1], [37, 39, -1], [40, 38, 44], [41, 35, 39], [32, 40, 42], [41, 43, -1], [44, 42, -1], [39, 43, 45], [44, 46, 47], [45, -1, -1], [45, -1, -1], [29, 49, 50], [48, -1, -1], [48, -1, -1], [52, 60, -1], [51, 53, -1], [52, 54, 58], [53, 55, -1], [54, 56, -1], [55, 57, 63], [58, 56, 62], [53, 57, 59], [60, 58, 61], [51, 59, -1], [62, 59, -1], [57, 61, -1], [56, 64, 65], [63, -1, -1], [63, -1, -1], [67, 71, -1], [66, 68, -1], [67, 69, -1], [68, 70, 78], [71, 69, 72], [66, 70, -1], [70, 73, 77], [

In [13]:
X_concat = torch.cat([graph.node_features for graph in test_graphs], 0)
X_concat.shape

torch.Size([372, 7])

In [14]:
h = X_concat

dummy = torch.min(h, dim = 0)[0]
print(dummy.shape)

h_with_dummy = torch.cat([h, dummy.reshape((1, -1))])
print(h_with_dummy.shape)
print(h_with_dummy[2])

print(h_with_dummy[padded_neighbor_list].shape)
print(h_with_dummy[padded_neighbor_list][2])

pooled_rep = torch.max(h_with_dummy[padded_neighbor_list], dim = 1)[0]
print(pooled_rep.shape)
print(pooled_rep)

torch.Size([7])
torch.Size([373, 7])
tensor([1., 0., 0., 0., 0., 0., 0.])
torch.Size([372, 3, 7])
tensor([[1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 0., 0.]])
torch.Size([372, 7])
tensor([[1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.],
        [1., 1., 0.,  ..., 0., 0., 0.]])


### GIN模型训练

In [15]:
model = GraphCNN(args.num_layers, args.num_mlp_layers, train_graphs[0].node_features.shape[1], args.hidden_dim, num_classes, args.final_dropout, args.learn_eps, args.graph_pooling_type, args.neighbor_pooling_type, device).to(device)

optimizer = optim.Adam(model.parameters(), lr=args.lr)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.5)

In [16]:
def train(args, model, device, train_graphs, optimizer, epoch):
    model.train()

    total_iters = args.iters_per_epoch
    pbar = tqdm(range(total_iters), unit='batch')

    loss_accum = 0
    for pos in pbar:
        # 从训练数据中随机选择一个子集，用于训练神经网络
        # 首先使用np.random.permutation函数对训练数据索引进行随机排序
        # 然后使用[:args.batch_size]选择前batch_size个随机索引作为子集
        selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]

        batch_graph = [train_graphs[idx] for idx in selected_idx]
        output = model(batch_graph)

        labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)

        #compute loss
        loss = criterion(output, labels)

        #backprop
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()         
            optimizer.step()
        

        loss = loss.detach().cpu().numpy()
        loss_accum += loss

        #report
        pbar.set_description('epoch: %d' % (epoch))

    average_loss = loss_accum/total_iters
    print("loss training: %f" % (average_loss))
    
    return average_loss


In [None]:
def train(args, model, device, train_graphs, optimizer, epoch):
    model.train()

    total_iters = args.iters_per_epoch
    pbar = tqdm(range(total_iters), unit='batch')

    loss_accum = 0
    for pos in pbar:
        # 
        selected_idx = np.random.permutation(len(train_graphs))[:args.batch_size]

        batch_graph = [train_graphs[idx] for idx in selected_idx]
        output = model(batch_graph)

        labels = torch.LongTensor([graph.label for graph in batch_graph]).to(device)

        #compute loss
        loss = criterion(output, labels)

        #backprop
        if optimizer is not None:
            optimizer.zero_grad()
            loss.backward()         
            optimizer.step()
        

        loss = loss.detach().cpu().numpy()
        loss_accum += loss

        #report
        pbar.set_description('epoch: %d' % (epoch))

    average_loss = loss_accum/total_iters
    print("loss training: %f" % (average_loss))
    
    return average_loss


In [17]:
###pass data to model with minibatch during testing to avoid memory overflow (does not perform backpropagation)
def pass_data_iteratively(model, graphs, minibatch_size = 64):
    model.eval()
    output = []
    idx = np.arange(len(graphs))
    for i in range(0, len(graphs), minibatch_size):
        sampled_idx = idx[i:i+minibatch_size]
        if len(sampled_idx) == 0:
            continue
        output.append(model([graphs[j] for j in sampled_idx]).detach())
    return torch.cat(output, 0)

def test(args, model, device, train_graphs, test_graphs, epoch):
    model.eval()

    output = pass_data_iteratively(model, train_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in train_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_train = correct / float(len(train_graphs))

    output = pass_data_iteratively(model, test_graphs)
    pred = output.max(1, keepdim=True)[1]
    labels = torch.LongTensor([graph.label for graph in test_graphs]).to(device)
    correct = pred.eq(labels.view_as(pred)).sum().cpu().item()
    acc_test = correct / float(len(test_graphs))

    print("accuracy train: %f test: %f" % (acc_train, acc_test))

    return acc_train, acc_test

## 训练结果

In [18]:
import warnings
warnings.filterwarnings("ignore")

for epoch in range(1, args.epochs + 1):
    scheduler.step()

    avg_loss = train(args, model, device, train_graphs, optimizer, epoch)
    acc_train, acc_test = test(args, model, device, train_graphs, test_graphs, epoch)

    if not args.filename == "":
        with open(args.filename, 'w') as f:
            f.write("%f %f %f" % (avg_loss, acc_train, acc_test))
            f.write("\n")
    print("")

print(model.eps)

epoch: 1: 100%|██████████| 50/50 [00:00<00:00, 76.89batch/s]


loss training: 2.305406
accuracy train: 0.852071 test: 0.789474



epoch: 2: 100%|██████████| 50/50 [00:00<00:00, 82.33batch/s]


loss training: 1.228823
accuracy train: 0.810651 test: 0.894737



epoch: 3: 100%|██████████| 50/50 [00:00<00:00, 76.89batch/s]


loss training: 0.847793
accuracy train: 0.840237 test: 0.842105



epoch: 4: 100%|██████████| 50/50 [00:00<00:00, 80.47batch/s]


loss training: 0.548788
accuracy train: 0.923077 test: 0.842105



epoch: 5: 100%|██████████| 50/50 [00:00<00:00, 83.00batch/s]


loss training: 0.478933
accuracy train: 0.905325 test: 0.842105



epoch: 6: 100%|██████████| 50/50 [00:00<00:00, 78.09batch/s]


loss training: 0.587393
accuracy train: 0.893491 test: 0.842105



epoch: 7: 100%|██████████| 50/50 [00:00<00:00, 79.45batch/s]


loss training: 0.484600
accuracy train: 0.840237 test: 0.736842



epoch: 8: 100%|██████████| 50/50 [00:00<00:00, 79.33batch/s]


loss training: 0.531488
accuracy train: 0.923077 test: 0.894737



epoch: 9: 100%|██████████| 50/50 [00:00<00:00, 78.83batch/s]


loss training: 0.318316
accuracy train: 0.940828 test: 0.842105



epoch: 10: 100%|██████████| 50/50 [00:00<00:00, 72.30batch/s]


loss training: 0.203298
accuracy train: 0.952663 test: 0.842105



epoch: 11: 100%|██████████| 50/50 [00:00<00:00, 65.88batch/s]


loss training: 0.284641
accuracy train: 0.917160 test: 0.842105



epoch: 12: 100%|██████████| 50/50 [00:00<00:00, 80.09batch/s]


loss training: 0.278371
accuracy train: 0.869822 test: 1.000000



epoch: 13: 100%|██████████| 50/50 [00:00<00:00, 81.92batch/s]


loss training: 0.274377
accuracy train: 0.923077 test: 0.842105



epoch: 14: 100%|██████████| 50/50 [00:00<00:00, 81.25batch/s]


loss training: 0.269141
accuracy train: 0.964497 test: 0.894737



epoch: 15: 100%|██████████| 50/50 [00:00<00:00, 80.86batch/s]


loss training: 0.302259
accuracy train: 0.946746 test: 0.789474



epoch: 16: 100%|██████████| 50/50 [00:00<00:00, 80.09batch/s]


loss training: 0.234965
accuracy train: 0.946746 test: 0.842105



epoch: 17: 100%|██████████| 50/50 [00:00<00:00, 79.70batch/s]


loss training: 0.256096
accuracy train: 0.905325 test: 0.894737



epoch: 18: 100%|██████████| 50/50 [00:00<00:00, 79.96batch/s]


loss training: 0.184259
accuracy train: 0.964497 test: 0.842105



epoch: 19: 100%|██████████| 50/50 [00:00<00:00, 81.39batch/s]


loss training: 0.146424
accuracy train: 0.952663 test: 0.842105



epoch: 20: 100%|██████████| 50/50 [00:00<00:00, 78.95batch/s]


loss training: 0.137270
accuracy train: 0.964497 test: 0.842105



epoch: 21: 100%|██████████| 50/50 [00:00<00:00, 80.60batch/s]


loss training: 0.222598
accuracy train: 0.917160 test: 0.842105



epoch: 22: 100%|██████████| 50/50 [00:00<00:00, 80.99batch/s]


loss training: 0.143980
accuracy train: 0.976331 test: 0.842105



epoch: 23: 100%|██████████| 50/50 [00:00<00:00, 77.57batch/s]


loss training: 0.159584
accuracy train: 0.964497 test: 0.894737



epoch: 24: 100%|██████████| 50/50 [00:00<00:00, 80.04batch/s]


loss training: 0.129295
accuracy train: 0.964497 test: 0.842105



epoch: 25: 100%|██████████| 50/50 [00:00<00:00, 79.45batch/s]


loss training: 0.197274
accuracy train: 0.982249 test: 0.842105



epoch: 26: 100%|██████████| 50/50 [00:00<00:00, 81.78batch/s]


loss training: 0.165769
accuracy train: 0.964497 test: 0.894737



epoch: 27: 100%|██████████| 50/50 [00:00<00:00, 80.73batch/s]


loss training: 0.128073
accuracy train: 0.988166 test: 0.894737



epoch: 28: 100%|██████████| 50/50 [00:00<00:00, 80.86batch/s]


loss training: 0.124632
accuracy train: 0.934911 test: 0.894737



epoch: 29: 100%|██████████| 50/50 [00:00<00:00, 80.73batch/s]


loss training: 0.170659
accuracy train: 0.911243 test: 0.789474



epoch: 30: 100%|██████████| 50/50 [00:00<00:00, 83.56batch/s]


loss training: 0.133205
accuracy train: 0.970414 test: 0.842105



epoch: 31: 100%|██████████| 50/50 [00:00<00:00, 87.18batch/s]


loss training: 0.141866
accuracy train: 0.964497 test: 0.842105



epoch: 32: 100%|██████████| 50/50 [00:00<00:00, 82.19batch/s]


loss training: 0.136684
accuracy train: 0.923077 test: 0.789474



epoch: 33: 100%|██████████| 50/50 [00:00<00:00, 83.85batch/s]


loss training: 0.185614
accuracy train: 0.852071 test: 0.894737



epoch: 34: 100%|██████████| 50/50 [00:00<00:00, 86.39batch/s]


loss training: 0.127170
accuracy train: 0.964497 test: 0.842105



epoch: 35: 100%|██████████| 50/50 [00:00<00:00, 88.58batch/s]


loss training: 0.185674
accuracy train: 0.946746 test: 0.842105



epoch: 36: 100%|██████████| 50/50 [00:00<00:00, 83.42batch/s]


loss training: 0.117828
accuracy train: 0.976331 test: 0.842105



epoch: 37: 100%|██████████| 50/50 [00:00<00:00, 74.02batch/s]


loss training: 0.101013
accuracy train: 0.982249 test: 0.842105



epoch: 38: 100%|██████████| 50/50 [00:00<00:00, 83.00batch/s]


loss training: 0.085040
accuracy train: 0.964497 test: 0.842105



epoch: 39: 100%|██████████| 50/50 [00:00<00:00, 84.83batch/s]


loss training: 0.105642
accuracy train: 1.000000 test: 0.842105



epoch: 40: 100%|██████████| 50/50 [00:00<00:00, 83.98batch/s]


loss training: 0.080059
accuracy train: 0.964497 test: 0.842105



epoch: 41: 100%|██████████| 50/50 [00:00<00:00, 83.84batch/s]


loss training: 0.103368
accuracy train: 0.692308 test: 0.684211



epoch: 42: 100%|██████████| 50/50 [00:00<00:00, 84.69batch/s]


loss training: 0.361313
accuracy train: 0.928994 test: 0.842105



epoch: 43: 100%|██████████| 50/50 [00:00<00:00, 81.93batch/s]


loss training: 0.278004
accuracy train: 0.964497 test: 0.842105



epoch: 44: 100%|██████████| 50/50 [00:00<00:00, 86.14batch/s]


loss training: 0.155826
accuracy train: 0.964497 test: 0.842105



epoch: 45: 100%|██████████| 50/50 [00:00<00:00, 84.83batch/s]


loss training: 0.143157
accuracy train: 0.982249 test: 0.842105



epoch: 46: 100%|██████████| 50/50 [00:00<00:00, 85.55batch/s]


loss training: 0.159131
accuracy train: 0.976331 test: 0.894737



epoch: 47: 100%|██████████| 50/50 [00:00<00:00, 84.12batch/s]


loss training: 0.082433
accuracy train: 0.976331 test: 0.894737



epoch: 48: 100%|██████████| 50/50 [00:00<00:00, 87.04batch/s]


loss training: 0.117535
accuracy train: 0.952663 test: 0.842105



epoch: 49: 100%|██████████| 50/50 [00:00<00:00, 85.70batch/s]


loss training: 0.170909
accuracy train: 0.875740 test: 0.842105



epoch: 50: 100%|██████████| 50/50 [00:00<00:00, 84.54batch/s]

loss training: 0.103308
accuracy train: 0.982249 test: 0.842105

Parameter containing:
tensor([0., 0., 0., 0.], requires_grad=True)



