# GNN Batching Example

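**Problem overview**: The data a GNN processes is usually irregular, since graphs differ in size and structure. How to batch such data and feed it into a neural network for training is therefore a common question. This code implements batching with a `block-diagonal adjacency matrix` (inspired by the PyG framework). It uses a synthetically generated graph classification dataset, and implements data loading, model construction, training, and evaluation in PyTorch.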
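The block-diagonal idea can be sketched without any framework. The snippet below is a minimal illustration, not the implementation used later in this notebook: `batch_graphs` and `dense_adjacency` are hypothetical helper names, and graphs are assumed to be `(node_labels, edges)` tuples as in the dataset section. Each graph's node indices are shifted by the number of nodes that precede it, so the merged adjacency matrix has one block per graph on the diagonal and zeros elsewhere.

```python
def batch_graphs(graphs):
    """Merge graphs, each given as (node_labels, edges), into one big graph."""
    nodes, edges, batch = [], [], []
    offset = 0
    for graph_id, (g_nodes, g_edges) in enumerate(graphs):
        nodes.extend(g_nodes)
        # Shift node indices so each graph occupies its own diagonal block.
        edges.extend((s + offset, d + offset) for s, d in g_edges)
        # Record which graph every node came from (useful for per-graph pooling).
        batch.extend([graph_id] * len(g_nodes))
        offset += len(g_nodes)
    return nodes, edges, batch


def dense_adjacency(num_nodes, edges):
    """Build a symmetric 0/1 adjacency matrix as a list of lists."""
    adj = [[0] * num_nodes for _ in range(num_nodes)]
    for s, d in edges:
        adj[s][d] = 1
        adj[d][s] = 1
    return adj


g1 = (["A", "B"], [(0, 1)])
g2 = (["A", "B", "C"], [(0, 1), (1, 2)])
nodes, edges, batch = batch_graphs([g1, g2])
adj = dense_adjacency(len(nodes), edges)
```

Because no edge crosses from one block to another, message passing on the merged graph never mixes information between the original graphs, and the `batch` vector lets a readout step group node embeddings by graph.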

# Task Definition and Dataset Generation

## Task Definition

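Subgraph matching classification: given a subgraph $g$ and a dataset of graphs $\mathcal{G}=\{G_1,G_2,...,G_n\}$ with corresponding labels $\mathcal{Y}=\{y_1,y_2,...,y_n\}$, every graph $G_i$ and its label $y_i$ satisfy:
$$
\begin{equation}
y_i=\left\{
\begin{aligned}
1 & \text{ if } G_i \text{ contains the subgraph } g \\
0 & \text{ if } G_i \text{ does not contain the subgraph } g \\
\end{aligned}
\right.
\end{equation}
$$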
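To make the label definition concrete, here is a brute-force sketch of the containment test. It rests on assumptions that the text does not state: "contains" is read as a label-preserving, not necessarily induced, match on undirected edges, and `contains_subgraph` is a hypothetical helper. It tries every injective node assignment, so it is only practical for very small graphs.

```python
from itertools import permutations


def contains_subgraph(graph, sub):
    """Return 1 if `graph` contains `sub` (labels must match), else 0."""
    g_nodes, g_edges = graph
    s_nodes, s_edges = sub
    # Treat edges as undirected by storing each one as a frozenset.
    g_edge_set = {frozenset(e) for e in g_edges}
    # Try every injective assignment of subgraph nodes to graph nodes.
    for mapping in permutations(range(len(g_nodes)), len(s_nodes)):
        labels_ok = all(g_nodes[m] == s for m, s in zip(mapping, s_nodes))
        if labels_ok and all(
            frozenset((mapping[s], mapping[d])) in g_edge_set for s, d in s_edges
        ):
            return 1
    return 0


sub = (["A", "A", "B", "C"], [(0, 1), (0, 2), (1, 2), (2, 3)])
positive = (["C", "B", "A", "A", "D"], [(2, 3), (2, 1), (3, 1), (1, 0), (0, 4)])
negative = (["A", "A", "B", "C"], [(0, 1), (0, 2), (2, 3)])
y_pos = contains_subgraph(positive, sub)
y_neg = contains_subgraph(negative, sub)
```

A checker of this kind could also be used to spot-check small generated graphs whose labels are assigned by construction.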

## Dataset Generation

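Any graph can be defined as a set of $m$ nodes $\mathcal{N}=\{v_1,v_2,...,v_m\}$ and a set of $n$ edges $\mathcal{E}=\{e_1,e_2,...,e_n\}$ (here $n$ counts edges, not the graphs in $\mathcal{G}$), where each edge is a tuple of two nodes, i.e. $(v_i,v_j)$. The dataset is set up as follows:
+ There are 26 node types: A, B, C, ..., Z. Each node type has its own feature vector, for example a one-hot vector.
+ Each graph consists of a variable number of nodes of these types, connected by a variable number of edges.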
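As an illustration of this setup, the sketch below turns the node labels of one small graph into one-hot feature vectors. The `one_hot` helper and the example graph are made up for illustration. The label-to-index mapping mirrors the `stoi` dictionary built in the next cell.

```python
import string

node_types = list(string.ascii_uppercase)
# Map each node type to its index: "A" -> 0, "B" -> 1, ..., "Z" -> 25.
stoi = {label: index for index, label in enumerate(node_types)}


def one_hot(label):
    """Return a 26-dimensional one-hot vector for a node type."""
    vec = [0] * len(node_types)
    vec[stoi[label]] = 1
    return vec


# A graph is a (node_labels, edges) tuple; edges refer to node positions.
graph = (["A", "C", "B"], [(0, 1), (1, 2)])
features = [one_hot(label) for label in graph[0]]
```

Here `features` has one row per node, so a graph with $m$ nodes yields an $m \times 26$ feature matrix.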

In [98]:
"""
Code for dataset generation.
"""
import string
import numpy as np
import json
import random
import networkx as nx
from matplotlib import pyplot as plt
import time

In [2]:
"""
Generating nodes dict.
"""
node_types = list(string.ascii_uppercase)
nodes_dict = dict([(k, v) for v, k in enumerate(node_types)])
nodes_dict_path = "./data/nodes_dict.json"

print("Saving node dict...")
with open(nodes_dict_path, "w") as fp:
    json.dump({
        "itos" : node_types,
        "stoi" : nodes_dict
    }, fp)
print("Successfully saving dict!")

Saving node dict...
Successfully saving dict!


In [95]:
"""
Show graph.
Input : 
    g : tuple(list, list)
"""
def show_graph(g):
    labels = dict([(k,v) for k, v in enumerate(g[0])])
    nodes = range(len(g[0]))
    
    G = nx.Graph()
    G.add_nodes_from(nodes)
    G.add_edges_from(g[1])
    pos = nx.spring_layout(G)
    nx.draw(G, pos)
    nx.draw_networkx_labels(G, pos, labels)

subgraph = (["A", "A", "B", "C"], 
            [(0, 1),
             (0, 2),
             (1, 2),
             (2, 3)])
min_nodes_num = 5
max_nodes_num = 50
graph_num = 10000

"""
Generating graph dataset. There are four steps:
Step1 : Randomly choose the number of nodes (N).
Step2 : Generate a random graph with an edge count ranging from N-1 to N * (N - 1) / 2.
Step3 : Discard graphs that are not connected.
Step4 : Add the subgraph to some graphs.
"""
N = 0
graphs = []
random.seed(0)
while N < graph_num:
    node_num = random.randint(min_nodes_num, max_nodes_num)
    # Use integer division: random.randint requires integer bounds.
    edge_num = random.randint(node_num - 1, node_num * (node_num - 1) // 2)
    G = nx.dense_gnm_random_graph(node_num, edge_num)
    if nx.is_connected(G):
        graphs.append(G)
        N += 1
        if N % 1000 == 0:
            print("{} graphs have been generated!".format(N))

"""
Transform nx.Graph into our graph type.
"""
def transform_nx_graph(g):
    nodes = random.choices(population=node_types, k=len(g.nodes))
    edges = list(g.edges)
    
    return (nodes, edges)

"""
Merge subgraph into graph.
"""
def merge_subgraph_into_graph(g, sg):
    g_node_num = len(g[0])
    sg_node_num = len(sg[0])
    
    merge_edges = [(s+g_node_num, d+g_node_num) for s, d in sg[1]]
    
    g_range = range(g_node_num)
    sg_range = range(g_node_num, g_node_num+sg_node_num)
    new_edges_num = random.randint(1, g_node_num)
    src_list = random.choices(population=g_range, k=new_edges_num)
    dst_list = random.choices(population=sg_range, k=new_edges_num)
    new_edges = [(s, d) for s, d in zip(src_list, dst_list)]
    new_edges = list(set(new_edges))
    
    merge_graph = (g[0]+sg[0],
                   g[1]+merge_edges+new_edges)
    
    return merge_graph

graphs = [transform_nx_graph(g) for g in graphs]
graphs_with_sg = [(merge_subgraph_into_graph(g, subgraph), 1) for g in graphs[:len(graphs)//2]]
graphs_without_sg = [(g, 0) for g in graphs[len(graphs)//2:]]
graphs = graphs_with_sg + graphs_without_sg
random.shuffle(graphs)

print("Saving dataset!")
dataset_path = "./data/dataset.json"
with open(dataset_path, "w") as fp:
    json.dump({
        "time" : time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()),
        "subgraph" : subgraph,
        "graphs" : graphs,
        "min_nodes_num" : min_nodes_num,
        "max_nodes_num" : max_nodes_num,
        "graphs_num" : graph_num
    }, fp)

1000 graphs have been generated!
2000 graphs have been generated!
3000 graphs have been generated!
4000 graphs have been generated!
5000 graphs have been generated!
6000 graphs have been generated!
7000 graphs have been generated!
8000 graphs have been generated!
9000 graphs have been generated!
10000 graphs have been generated!
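One detail matters when the dataset is read back: JSON has no tuple type, so the `(nodes, edges)` tuples and the `(s, d)` edge tuples written by `json.dump` come back as lists. The sketch below shows the round trip in memory on a tiny hand-made sample, not on the real file. `to_graph` is a hypothetical helper that restores the tuple-based layout.

```python
import json

# A tiny sample with the same layout as the saved dataset.
sample = {
    "subgraph": (["A", "A", "B", "C"], [(0, 1), (0, 2), (1, 2), (2, 3)]),
    "graphs": [((["A", "B"], [(0, 1)]), 0)],
}
# Serialize and parse again, as saving to and loading from disk would do.
restored = json.loads(json.dumps(sample))


def to_graph(raw):
    """Convert a JSON-decoded [nodes, edges] pair back to (nodes, edges)."""
    nodes, edges = raw
    return list(nodes), [tuple(e) for e in edges]


graphs = [(to_graph(g), label) for g, label in restored["graphs"]]
```

Converting edges back to tuples keeps later code that unpacks `(s, d)` pairs or uses edges as set members working as it did before saving.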
