# Graph convolutional network

Classification on graphs is achieved by first embedding node features into a low-dimensional space, then grouping nodes and summarizing them.

In [None]:
import torch
import torch.nn.functional as F
import torch.nn as nn
from torch_geometric.nn import TopKPooling
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp
from torch_geometric.utils import (add_self_loops, sort_edge_index,
                                   remove_self_loops)
from torch_sparse import spspmm

from net.braingraphconv import MyNNConv

The network is formed by three types of layers: graph convolutional layers, node pooling layers, and a readout layer.
A graph convolutional layer probes the graph structure by using edge features, which contain important information about graphs. For example, the weights of the edges in brain fMRI graphs can represent the relationship between different regions of interest (ROIs).
We define $h_i^{(l)} \in \mathbb{R}^{d^{(l)}}$ as the features of the $i$th node in the $l$th layer, where $d^{(l)}$ is the dimension of the $l$th layer features. The propagation
model for the forward-pass update of the node representation is calculated as:
 ![figure1](imag\img1.png)
Given an ROI ordering for all the graphs, we use one-hot encoding to represent each ROI’s location information instead of coordinates, because the nodes in the brain are well aligned. Specifically, for node $v_i$, its ROI representation $r_i$ is an $N$-dimensional vector with 1 in the $i$th entry and 0 in all
other entries.
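
The one-hot ROI representation can be sketched in a few lines of PyTorch. This is a minimal illustration, not part of the original notebook: the sizes `R` and `num_graphs` and the variable names `roi_onehot` and `pos` are assumptions chosen for the example.

```python
import torch

R = 5             # hypothetical number of ROIs per graph (Network below defaults to R=200)
num_graphs = 3    # hypothetical batch size

# Row i of the identity matrix is the one-hot vector r_i for ROI i.
roi_onehot = torch.eye(R)

# Every graph shares the same ROI ordering, so a batch repeats the same block.
pos = roi_onehot.repeat(num_graphs, 1)   # shape: [num_graphs * R, R]

print(pos.shape)      # torch.Size([15, 5])
print(pos[R + 2])     # one-hot vector of ROI 2 in the second graph
```

Because the encoding depends only on the ROI index, the same block can be reused for every subject as long as the ROI ordering is fixed.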

A node pooling layer reduces the size of the graph, either by grouping nodes together or by pruning the original graph $G$ to a subgraph $G_s$ that keeps only some important nodes. We focus on the pruning method, as it is more interpretable and can help detect biomarkers.
A readout layer summarizes the node feature vectors $\{h_i^{(l)}\}$ into a single vector $z$, which is finally fed into a classifier for graph classification.
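
The pruning idea can be sketched for a single graph as follows. This is a simplified illustration of projection-score top-k pooling, not the `TopKPooling` implementation itself: the helper `topk_prune`, the random data, and the restriction to one graph are assumptions made for the example.

```python
import math
import torch

def topk_prune(x, w, ratio):
    """Keep the ceil(ratio * N) highest-scoring nodes of a single graph."""
    score = torch.sigmoid((x @ w) / w.norm(p=2))   # one score in (0, 1) per node
    k = math.ceil(ratio * x.size(0))
    perm = score.topk(k).indices                   # indices of the kept nodes
    # Scale the kept features by their scores so that w receives gradients.
    return x[perm] * score[perm].unsqueeze(-1), perm, score[perm]

torch.manual_seed(0)
x = torch.randn(10, 4)    # 10 nodes with 4 features each (made-up data)
w = torch.randn(4)        # a learnable projection vector in a real layer
x_kept, perm, kept_scores = topk_prune(x, w, ratio=0.5)
print(x_kept.shape)       # torch.Size([5, 4])
```

A full pooling layer would also drop every edge that touches a removed node and handle batches of graphs. The indices in `perm` are what makes the method interpretable, since they identify which ROIs survive.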

Given node $i$’s regional information $r_i$, such as the node’s coordinates in a mesh graph, we propose to learn the vectorized embedding kernel $\mathrm{vec}(W_i)$ from $r_i$ in the $l$th Ra-GNN layer:
 ![figure1](imag\img2.png)
ROIs in the same community are embedded by similar kernels, while nodes in different communities are embedded in different ways, which reduces the number of learnable parameters.
  ![figure1](imag\img3.png)
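
A rough sketch of how a kernel can be generated from the ROI encoding is given below. The small network has the same layout as `self.n1` in the `Network` class, but the sizes, the name `kernel_net`, and the reshape into `[d_in, d_out]` kernels are assumptions made for illustration. The actual convolution lives in `MyNNConv`, which is not shown here.

```python
import torch
import torch.nn as nn

R, k = 6, 2              # hypothetical: 6 ROIs, 2 communities
d_in, d_out = 3, 4       # hypothetical layer dimensions

# One-hot ROI -> community weights -> vectorized kernel vec(W_i)
kernel_net = nn.Sequential(
    nn.Linear(R, k, bias=False),
    nn.ReLU(),
    nn.Linear(k, d_out * d_in),
)

r = torch.eye(R)                 # one-hot ROI encodings, one row per node
h = torch.randn(R, d_in)         # made-up node features

W = kernel_net(r).view(R, d_in, d_out)             # one kernel W_i per node
h_new = torch.bmm(h.unsqueeze(1), W).squeeze(1)    # h_i W_i per node: [R, d_out]

shared = sum(p.numel() for p in kernel_net.parameters())
separate = R * d_in * d_out      # cost of one free kernel per ROI
print(h_new.shape, shared, separate)
```

The saving grows with the number of ROIs: the shared network costs `R * k + (k + 1) * d_in * d_out` parameters, whereas independent kernels cost `R * d_in * d_out`.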

## Readout layer
Lastly, we seek a “flattening” operation that preserves information about the input graph in a fixed-size representation. Concretely, to summarize the output graph of the $l$th conv-pool block, $(V^{(l)}, E^{(l)})$, we use:
![figure1](imag\imag4.png)
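
As a minimal sketch, the readout can be written in plain PyTorch by concatenating the per-graph max and mean of the node features, in the same order the forward pass uses `gmp` and `gap`. The helper name `readout` and the toy data are assumptions. In the model itself this job is done by `global_max_pool` and `global_mean_pool`.

```python
import torch

def readout(x, batch, num_graphs):
    """Concatenate per-graph max and mean of node features."""
    pooled = []
    for g in range(num_graphs):
        nodes = x[batch == g]    # node features that belong to graph g
        pooled.append(torch.cat([nodes.max(dim=0).values, nodes.mean(dim=0)]))
    return torch.stack(pooled)   # shape: [num_graphs, 2 * d]

x = torch.tensor([[1., 2.], [3., 0.], [5., 5.]])
batch = torch.tensor([0, 0, 1])  # the first two nodes belong to graph 0
z = readout(x, batch, num_graphs=2)
print(z)
# tensor([[3., 2., 2., 1.],
#         [5., 5., 5., 5.]])
```

The output size depends only on the feature dimension, not on how many nodes survive pooling, which is why the result can be fed into fully connected layers.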

In [None]:
class Network(torch.nn.Module):
    def __init__(self, indim, ratio, nclass, k=8, R=200):
        '''

        :param indim: (int) node feature dimension
        :param ratio: (float) pooling ratio in (0,1)
        :param nclass: (int)  number of classes
        :param k: (int) number of communities
        :param R: (int) number of ROIs
        '''
        super(Network, self).__init__()

        self.indim = indim
        self.dim1 = 32
        self.dim2 = 32
        self.dim3 = 512
        self.dim4 = 256
        self.dim5 = 8
        self.k = k
        self.R = R

        self.n1 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim1 * self.indim))
        self.conv1 = MyNNConv(self.indim, self.dim1, self.n1, normalize=False)
        self.pool1 = TopKPooling(self.dim1, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)
        self.n2 = nn.Sequential(nn.Linear(self.R, self.k, bias=False), nn.ReLU(), nn.Linear(self.k, self.dim2 * self.dim1))
        self.conv2 = MyNNConv(self.dim1, self.dim2, self.n2, normalize=False)
        self.pool2 = TopKPooling(self.dim2, ratio=ratio, multiplier=1, nonlinearity=torch.sigmoid)


        self.fc1 = torch.nn.Linear((self.dim1+self.dim2)*2, self.dim2)
        self.bn1 = torch.nn.BatchNorm1d(self.dim2)
        self.fc2 = torch.nn.Linear(self.dim2, self.dim3)
        self.bn2 = torch.nn.BatchNorm1d(self.dim3)
        self.fc3 = torch.nn.Linear(self.dim3, nclass)




    def forward(self, x, edge_index, batch, edge_attr, pos):
        # Block 1: graph convolution followed by top-k node pooling.
        x = self.conv1(x, edge_index, edge_attr, pos)
        x, edge_index, edge_attr, batch, perm, score1 = self.pool1(x, edge_index, edge_attr, batch)

        # Keep the ROI encodings of the surviving nodes only.
        pos = pos[perm]
        # Readout of block 1: per-graph max and mean pooling, concatenated.
        x1 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Reconnect the pruned graph through two-hop paths.
        edge_attr = edge_attr.squeeze()
        edge_index, edge_attr = self.augment_adj(edge_index, edge_attr, x.size(0))

        # Block 2: same pattern on the pruned graph.
        x = self.conv2(x, edge_index, edge_attr, pos)
        x, edge_index, edge_attr, batch, perm, score2 = self.pool2(x, edge_index, edge_attr, batch)

        x2 = torch.cat([gmp(x, batch), gap(x, batch)], dim=1)

        # Classifier on the concatenated readouts of both blocks.
        x = torch.cat([x1, x2], dim=1)
        x = self.bn1(F.relu(self.fc1(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.bn2(F.relu(self.fc2(x)))
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.log_softmax(self.fc3(x), dim=-1)

        # Returns log-probabilities, both pooling weight vectors, and the node scores reshaped per graph.
        return x, self.pool1.weight, self.pool2.weight, torch.sigmoid(score1).view(x.size(0), -1), torch.sigmoid(score2).view(x.size(0), -1)

    def augment_adj(self, edge_index, edge_weight, num_nodes):
        edge_index, edge_weight = add_self_loops(edge_index, edge_weight,
                                                 num_nodes=num_nodes)
        edge_index, edge_weight = sort_edge_index(edge_index, edge_weight,
                                                  num_nodes)
        edge_index, edge_weight = spspmm(edge_index, edge_weight, edge_index,
                                         edge_weight, num_nodes, num_nodes,
                                         num_nodes)
        edge_index, edge_weight = remove_self_loops(edge_index, edge_weight)
        return edge_index, edge_weight
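
The effect of `augment_adj` is easier to see on a dense adjacency matrix. The sketch below is an illustrative analogue rather than the sparse code above: `augment_adj_dense` is a hypothetical helper, and it assumes a single graph whose input has no self-loops.

```python
import torch

def augment_adj_dense(A):
    """Dense analogue: (A + I) @ (A + I) with the diagonal zeroed out."""
    A_hat = A + torch.eye(A.size(0))   # add self-loops with weight 1
    A2 = A_hat @ A_hat                 # connect nodes up to two hops apart
    A2.fill_diagonal_(0)               # drop the self-loops again
    return A2

# Path graph 0 - 1 - 2 with unit edge weights (made-up example)
A = torch.tensor([[0., 1., 0.],
                  [1., 0., 1.],
                  [0., 1., 0.]])
print(augment_adj_dense(A))
# tensor([[0., 2., 1.],
#         [2., 0., 2.],
#         [1., 2., 0.]])
```

Nodes 0 and 2 were not directly connected, but they gain an edge because they share neighbor 1. This helps keep the graph connected after pooling has removed nodes.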
