In [1]:
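# Spatial Convolutional Network: a GCNConv variant that propagates over the
# normalized graph Laplacian I - D^{-1/2} Â D^{-1/2} (see gcn_norm below)
# instead of the normalized adjacency D^{-1/2} Â D^{-1/2}.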

from torch_geometric.nn import GCNConv
from typing import Optional, Tuple

import torch
from torch import Tensor
from torch.nn import Parameter
from torch_scatter import scatter_add
from torch_sparse import SparseTensor, fill_diag, matmul, mul
from torch_sparse import sum as sparsesum

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import Adj, OptTensor, PairTensor
from torch_geometric.utils import add_remaining_self_loops
from torch_geometric.utils.num_nodes import maybe_num_nodes

In [2]:
from torch_geometric.data import Data

In [3]:
def gcn_norm(edge_index, edge_weight=None, num_nodes=None, improved=False, add_self_loops=True, dtype=None):
    """Build the symmetrically normalized graph Laplacian ``I - D^{-1/2} Â D^{-1/2}``.

    ``Â`` is the (weighted) adjacency, optionally augmented with self-loops of
    weight ``fill_value`` (2 if ``improved`` else 1), and ``D`` holds the node
    degrees of ``Â``. Nodes with zero degree get a normalization factor of 0,
    so their Laplacian row reduces to the identity entry alone.

    Args:
        edge_index: Either a ``[2, E]`` edge index tensor or a
            ``torch_sparse.SparseTensor`` adjacency.
        edge_weight: Optional per-edge weights (Tensor input only); defaults
            to all ones.
        num_nodes: Number of nodes; inferred when ``None``.
        improved: Use self-loop weight 2 instead of 1.
        add_self_loops: Add self-loops (of weight ``fill_value``) to nodes
            that do not already have one before normalizing.
        dtype: dtype of generated weights when none are supplied.

    Returns:
        * SparseTensor input: the Laplacian as a coalesced ``SparseTensor``.
        * Tensor input: ``(edge_index, weight)`` where every node has exactly
          one self-loop edge; ``weight = [row == col] - norm_adj_weight``.
    """
    fill_value = 2. if improved else 1.

    if isinstance(edge_index, SparseTensor):
        adj_t = edge_index
        if not adj_t.has_value():
            adj_t = adj_t.fill_value(1., dtype=dtype)
        if add_self_loops:
            # Sets every diagonal entry to fill_value.
            adj_t = fill_diag(adj_t, fill_value)
        deg = sparsesum(adj_t, dim=1)
        deg_inv_sqrt = deg.pow_(-0.5)
        deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0.)
        adj_t = mul(adj_t, deg_inv_sqrt.view(-1, 1))
        adj_t = mul(adj_t, deg_inv_sqrt.view(1, -1))

        # Assemble I - adj_t from COO parts on adj_t's own device/dtype.
        # SparseTensor.eye(num_nodes) breaks when num_nodes is None and is
        # always created on CPU with a default dtype.
        row, col, value = adj_t.coo()
        n = adj_t.size(0) if num_nodes is None else num_nodes
        loop = torch.arange(n, dtype=row.dtype, device=row.device)
        lap_t = SparseTensor(
            row=torch.cat([loop, row], dim=0),
            col=torch.cat([loop, col], dim=0),
            value=torch.cat([value.new_ones(n), -value], dim=0),
            sparse_sizes=(max(n, adj_t.size(0)), max(n, adj_t.size(1))),
        )
        # Coalescing sums the identity into any existing diagonal entries.
        return lap_t.coalesce(reduce='sum')

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    # Guarantee exactly one self-loop per node so the identity term has an
    # entry to land on. When the caller did not ask for self-loops, the
    # missing ones get weight 0, which leaves the degrees unchanged.
    loop_fill = fill_value if add_self_loops else 0.
    edge_index, tmp_edge_weight = add_remaining_self_loops(
        edge_index, edge_weight, loop_fill, num_nodes)
    assert tmp_edge_weight is not None
    edge_weight = tmp_edge_weight

    row, col = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)

    deg_inv_sqrt = deg.pow_(-0.5)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)

    adj_weight = deg_inv_sqrt[row] * edge_weight * deg_inv_sqrt[col]

    # Identity selected by a self-loop mask rather than by position, so it
    # does not depend on where the self-loops sit inside edge_index.
    identity = (row == col).to(adj_weight.dtype)

    return edge_index, identity - adj_weight


class SpatialConv(GCNConv):
    """GCNConv variant that propagates over the normalized graph Laplacian.

    When ``normalize=True``, the propagation weights come from this notebook's
    ``gcn_norm`` (``I - D^{-1/2} Â D^{-1/2}``) instead of GCNConv's normalized
    adjacency. When ``normalize=False``, ``edge_index`` / ``edge_weight`` are
    used exactly as given, with no Laplacian built.
    Caching (``cached=True``) stores the normalized structure after the first call.
    """

    def forward(self, x: Tensor, edge_index: Adj, edge_weight: OptTensor = None) -> Tensor:
        """Transform node features with ``self.lin`` and propagate them.

        Args:
            x: Node features; ``x.size(self.node_dim)`` is used as the node count.
            edge_index: ``[2, E]`` edge index tensor or ``SparseTensor`` adjacency.
            edge_weight: Optional per-edge weights (used with a Tensor ``edge_index``).

        Returns:
            Propagated features, plus ``self.bias`` when the layer has one.
        """
        if self.normalize:
            num_nodes = x.size(self.node_dim)
            if isinstance(edge_index, Tensor):
                cache = self._cached_edge_index
                if cache is None:
                    # Pass x.dtype so generated Laplacian weights match the
                    # features instead of torch's default float dtype.
                    edge_index, edge_weight = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, num_nodes,
                        self.improved, self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_edge_index = (edge_index, edge_weight)
                else:
                    edge_index, edge_weight = cache[0], cache[1]

            elif isinstance(edge_index, SparseTensor):
                cache = self._cached_adj_t
                if cache is None:
                    edge_index = gcn_norm(  # yapf: disable
                        edge_index, edge_weight, num_nodes,
                        self.improved, self.add_self_loops, dtype=x.dtype)
                    if self.cached:
                        self._cached_adj_t = edge_index
                else:
                    edge_index = cache

        x = self.lin(x)

        # propagate_type: (x: Tensor, edge_weight: OptTensor)
        out = self.propagate(edge_index, x=x, edge_weight=edge_weight,
                             size=None)

        if self.bias is not None:
            out += self.bias

        return out


In [4]:
# conv1 = SpatialConv(2, 2)
# conv1(x,edge_index)

In [5]:
# conv2 = GCNConv(2,2)
# conv2(x, edge_index)

In [6]:
# x = torch.Tensor([[1,0],[1,0],[1,0],[0,1],[0,1],[0,1],[0,1]])
# y = torch.LongTensor([0,0,0, 1, 1, 1, 1])
# edge_index = torch.LongTensor([[1,2],
#          [1,4],
#          [1,5],
#          [2,1],
#          [3,6],
#          [3,7],
#          [4,5],
#          [4,1],
#          [4,6],
#          [4,7],
#          [5,1],
#          [5,4],
#          [5,6],
#          [6,3],
#          [6,4],
#          [6,5],
#          [6,7],
#          [7,3],
#          [7,4],
#          [7,6]]).T
# edge_index = edge_index-1
# data = Data(x=x, y=y, edge_index = edge_index)

In [7]:
if __name__ == '__main__':
    # Nothing to execute as a script; the layer and helper are defined in the
    # cells above.
    pass