In [6]:
import os

import tensorflow as tf
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import gnn.gnn_utils as gnn_utils

import scipy
import scipy.io as spio

# 稀疏矩阵操作
import scipy.sparse as sp
from collections import namedtuple
# COO (Coordinate Format) COO使用row,col, values 来进行矩阵表示，进行矩阵压缩
from scipy.sparse import coo_matrix
# DIA (Diagonal Storage Format)对角矩阵的存储方式, 通过values, distance形式来进行存储和压缩
from scipy.sparse import dia_matrix

## 读取训练数据
python中读取.mat文件可以使用scipy.io 中的loadmat()读取文件，使用savemat()保存文件   

scipy是一个开源的python算法库和数学工具包  
包含的模块有最优化、线性代数、积分、插值、特殊函数、快速傅立叶变换、信号处理和图像处理、常微分方程求解等科学计算功能  

类似的软件有 MATLAB、Scilab

In [14]:
def loadmat(filename):
    """Load a MATLAB .mat file and turn nested MATLAB structs into plain dicts.

    ``scipy.io.loadmat`` is called with ``struct_as_record=False`` and
    ``squeeze_me=True``, so structs come back as ``mat_struct`` objects. This
    wrapper recursively converts every such struct into a ``dict`` keyed by
    field name. Inside a struct, ndarrays are converted into (nested) Python
    lists, and any structs they contain are converted too.

    Args:
        filename: path of the .mat file to read.

    Returns:
        The dict returned by ``scipy.io.loadmat``, with every top-level struct
        value replaced by its converted ``dict``. Other top-level values are
        left exactly as scipy loaded them.
    """
    matlab = spio.matlab
    # SciPy >= 1.8 exports mat_struct from scipy.io.matlab and deprecates the
    # old mio5_params location; fall back to it only on older SciPy versions.
    mat_struct = getattr(matlab, "mat_struct", None)
    if mat_struct is None:
        mat_struct = matlab.mio5_params.mat_struct

    def _convert(elem):
        """Convert one value: struct -> dict, ndarray -> list, else unchanged."""
        if isinstance(elem, mat_struct):
            return _todict(elem)
        if isinstance(elem, np.ndarray):
            return _tolist(elem)
        return elem

    def _todict(matobj):
        """Convert a mat_struct into a dict keyed by its field names."""
        return {name: _convert(matobj.__dict__[name]) for name in matobj._fieldnames}

    def _tolist(ndarray):
        """Convert an ndarray (possibly holding structs/arrays) into a nested list."""
        if ndarray.ndim == 0:
            # A 0-d array cannot be iterated; unwrap its single element instead.
            return _convert(ndarray[()])
        return [_convert(sub_elem) for sub_elem in ndarray]

    data = scipy.io.loadmat(filename, struct_as_record=False, squeeze_me=True)
    # Only top-level structs are converted; other top-level values stay as-is.
    for key in data:
        if isinstance(data[key], mat_struct):
            data[key] = _todict(data[key])
    return data


拆分训练集、验证集、测试集

In [22]:
# COO-style sparse container: `indices` is an (nnz, 2) array of (row, col)
# pairs, `values` holds the nnz entries, `dense_shape` is [n_rows, n_cols].
SparseMatrix = namedtuple("SparseMatrix", "indices values dense_shape")


def GetInput(mat, lab, batch=1, grafi=None):
    """Build per-batch GNN inputs from an arc list, node labels and graph ids.

    Args:
        mat: array with two columns (id_1, id_2), one row per arc. The arc list
            has only two columns because each row is a single node pair.
        lab: 2-D array of node labels/features; row i belongs to node i.
        batch: number of consecutive graphs grouped into one batch.
        grafi: integer vector with one entry per node; grafi[i] is the graph
            node i belongs to. When omitted, every node is placed in graph 0.

    Returns:
        (data_batch, arcnode_batch, nodegraph_batch, node_in), each a list with
        one entry per batch:
        - data_batch: array with columns id_1, id_2, labels of id_1, labels of
          id_2. Node ids are re-based so each batch starts at 0, and the graph
          column is dropped.
        - arcnode_batch: SparseMatrix with dense_shape [n_nodes, n_arcs] and a 1
          at (id_1, arc_index) for every arc of the batch.
        - nodegraph_batch: SparseMatrix with dense_shape [n_graphs, n_nodes] and
          a 1 at (graph, node) for every batch-local node.
        - node_in: number of nodes in each graph of the batch.
    """
    if grafi is None:
        # No graph assignment given: treat all nodes as a single graph 0.
        grafi = np.zeros(len(lab), dtype=int)
    grafi = np.asarray(grafi)

    # With a single graph grafi.max() is 0, so there is exactly one batch.
    batch_number = grafi.max() // batch

    # Arc list: one row per arc id_1 -> id_2.
    dmat = pd.DataFrame(mat, columns=["id_1", "id_2"])
    # Node labels, indexed by node id.
    dlab = pd.DataFrame(lab, columns=["lab" + str(i) for i in range(lab.shape[1])])
    # Graph membership of each node, indexed by node id.
    dgr = pd.DataFrame(grafi, columns=["graph"])
    graph_ids = dgr["graph"].values

    # Arc table columns: id_1, id_2, labels of id_1, labels of id_2, graph of id_1.
    dresult = pd.merge(dmat, dlab, left_on="id_1", right_index=True, how="left")
    dresult = pd.merge(dresult, dlab, left_on="id_2", right_index=True, how="left")
    dresult = pd.merge(dresult, dgr, left_on="id_1", right_index=True, how="left")

    data_batch = []
    arcnode_batch = []
    nodegraph_batch = []
    node_in = []

    for i in range(batch_number + 1):
        # Graph-id range [grafo_indexMin, grafo_indexMax) covered by this batch.
        grafo_indexMin = i * batch
        grafo_indexMax = grafo_indexMin + batch

        in_batch = (dresult["graph"] >= grafo_indexMin) & (dresult["graph"] < grafo_indexMax)
        # .copy(): the re-basing below must edit an independent frame, not a
        # view of dresult (avoids SettingWithCopyWarning / silent no-ops).
        adj = dresult.loc[in_batch].copy()

        # Re-base node ids and graph ids so each batch starts from 0.
        min_id = int(adj[["id_1", "id_2"]].min(axis=0).min())
        adj["id_1"] = adj["id_1"] - min_id
        adj["id_2"] = adj["id_2"] - min_id
        min_gr = int(adj["graph"].min())
        adj["graph"] = adj["graph"] - min_gr

        # Keep every column except the trailing graph id.
        data_batch.append(adj.values[:, :-1])

        max_id = int(adj[["id_1", "id_2"]].max(axis=0).max())
        max_gr = int(adj["graph"].max())
        n_nodes = max_id + 1

        mt = adj[["id_1", "id_2"]].values
        n_arcs = len(mt)

        # arcnode[node, arc] = 1 when the arc starts at that node.
        arcnode_batch.append(SparseMatrix(
            indices=np.stack((mt[:, 0], np.arange(n_arcs)), axis=1),
            values=np.ones([n_arcs]),
            dense_shape=[n_nodes, n_arcs],
        ))

        # Graph of every batch-local node, re-based exactly like adj above, so
        # column j refers to local node j (previously all nodes were used here,
        # which broke np.stack as soon as there was more than one batch).
        batch_graphs = graph_ids[min_id:min_id + n_nodes] - min_gr
        nodegraph_batch.append(SparseMatrix(
            indices=np.stack((batch_graphs, np.arange(n_nodes)), axis=1),
            values=np.ones(n_nodes),
            dense_shape=[max_gr + 1, n_nodes],
        ))

        # Number of nodes in each graph of this batch.
        grbtc = dgr.loc[(dgr["graph"] >= grafo_indexMin) & (dgr["graph"] < grafo_indexMax)]
        node_in.append(grbtc.groupby(["graph"]).size().values)

    return data_batch, arcnode_batch, nodegraph_batch, node_in


def set_load_general(data_path, set_type, set_name="sub_30_15"):
    """Load one split of a graph dataset stored in ``<data_path>/<set_name>.mat``.

    The file must contain a ``dataSet`` struct with a ``<set_type>Set`` entry
    that provides ``connMatrix``, ``nodeLabels`` and ``targets``.

    Args:
        data_path: directory containing the .mat file.
        set_type: one of "train", "validation", "test".
        set_name: base name of the .mat file, without extension.

    Returns:
        (inp, arcnode, nodegraph, nodein, labels, lab):
        - inp, arcnode, nodegraph, nodein: the GetInput outputs, with the whole
          split treated as a single graph in a single batch.
        - labels: the targets one-hot encoded, one column per distinct target.
        - lab: the node labels as a 2-D array.

    Raises:
        ValueError: if set_type is not a supported split name.
        Any error from reading the file or missing keys propagates unchanged.
    """
    types = ["train", "validation", "test"]
    # Validate before touching the disk so a typo fails fast and clearly.
    if set_type not in types:
        raise ValueError("Wrong Set Name")

    dataset = loadmat(os.path.join(data_path, "{}.mat".format(set_name)))["dataSet"]
    split = dataset["{}Set".format(set_type)]

    # Arc list as (row, col) pairs of the non-zeros of connMatrix.T.
    conn = coo_matrix(split["connMatrix"].T)
    adj = np.array([conn.row, conn.col]).T

    # Node labels; a single label per node is reshaped to a column vector.
    lab = np.asarray(split["nodeLabels"]).T
    if lab.ndim < 2:
        lab = lab.reshape(lab.shape[0], 1)

    # One-hot encode the targets. The dtype is pinned because pandas >= 2.0
    # defaults get_dummies to bool while older versions produced uint8.
    target = np.asarray(split["targets"]).T
    labels = pd.get_dummies(pd.Series(target), dtype=np.uint8).values

    # Every node belongs to graph 0, so the split forms a single batch.
    inp, arcnode, nodegraph, nodein = GetInput(adj, lab, 1, np.zeros(len(labels), dtype=int))

    return inp, arcnode, nodegraph, nodein, labels, lab



In [24]:
# Directory holding the <set_name>.mat files. Set the GNN_DATA_PATH environment
# variable to point elsewhere instead of editing this machine-specific default.
data_path = os.environ.get("GNN_DATA_PATH", "/Users/yangalan/alan/home/learn/temp/gnn_demo/data")
set_name = "sub_15_7_200"

# Training set. Column 0 (id_1) is dropped from every batch, keeping id_2 and
# the labels of both arc endpoints as the per-arc input.
inp, arcnode, nodegraph, nodein, labels, _ = set_load_general(data_path, "train", set_name=set_name)
inp = [a[:, 1:] for a in inp]

# Validation set, prepared the same way.
inp_val, arcnode_val, nodegraph_val, nodein_val, labels_val, _ = set_load_general(data_path, "validation", set_name=set_name)
inp_val = [a[:, 1:] for a in inp_val]

我只醒了
我只醒了


## 定义超参

In [26]:
# Width of each per-arc input row (columns remaining after dropping id_1).
input_dim = inp[0].shape[1]
# Size of the per-node state vector — presumably the GNN hidden state; confirm against the model.
state_dim = 10
# One output unit per one-hot target class, derived from the training labels
# so it stays correct for datasets with more than two classes.
output_dim = labels.shape[1]
# Presumably the convergence tolerance and iteration cap of the state-update loop; confirm against the model.
state_threshold = 0.001
max_iter = 30