In [6]:
import torch
from torch import Tensor
import torch.nn.functional as F
from torch.nn import Parameter
from typing import Optional, Tuple, Union
from torch_geometric.nn.dense.linear import Linear

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj, OptTensor, PairTensor, Size
from torch_geometric.utils import (
    add_self_loops,
    remove_self_loops,
    softmax,
)
from torch_sparse import SparseTensor, set_diag

# PyG initialization helpers
from torch_geometric.nn.inits import glorot, zeros


In [None]:
class GTATConv(MessagePassing):
    _alpha: OptTensor

    def __init__(self, in_channels: int, out_channels: int, heads: int,
                 topology_channels:int = 15,
                 concat: bool = True, negative_slope: float = 0.2,
                 dropout: float = 0., add_self_loops: bool = True,
                 bias: bool = True, share_weights: bool = False, **kwargs):
        super(GTATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.topology_channels = topology_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops
        self.share_weights = share_weights
        self.lin_l = Linear(in_channels, heads * out_channels, bias=bias,
                            weight_initializer='glorot')
        
        if share_weights:
            self.lin_r = self.lin_l
        else:
            self.lin_r = Linear(in_channels, heads * out_channels, bias=bias,
                                weight_initializer='glorot')
        
        

        self.att = Parameter(torch.Tensor(1, heads, out_channels))

        self.att2 = Parameter(torch.Tensor(1, heads, self.topology_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha1 = None
        self._alpha2 = None

        self.bias2 =  Parameter(torch.Tensor(self.topology_channels))

        self.reset_parameters()

    def reset_parameters(self):
        self.lin_l.reset_parameters()
        self.lin_r.reset_parameters()
        glorot(self.att)
        glorot(self.att2)
        zeros(self.bias)
        zeros(self.bias2)

    def forward(self, x: Union[Tensor, PairTensor], edge_index: Adj,
                topology: Tensor,
                size: Size = None, return_attention_weights: bool = None):
        # type: (Union[Tensor, PairTensor], Tensor , Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, PairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, PairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2
            x_l = self.lin_l(x).view(-1, H, C)  #(N , heads, features)
            if self.share_weights:
                x_r = x_l
            else:
                x_r = self.lin_r(x).view(-1, H, C)


        assert x_l is not None
        assert x_r is not None
        topology = topology.unsqueeze(dim = 1)
        topology = topology.repeat(1, self.heads, 1)
        x_l = torch.cat((x_l,topology), dim = -1)
        x_r = torch.cat((x_r,topology), dim = -1)

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)
            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        out_all = self.propagate(edge_index, x=(x_l, x_r), size=size)
        out = out_all[ : , : , :self.out_channels ]
        out2 = out_all[ : , : , self.out_channels:]
        alpha1 = self._alpha1
        self._alpha1 = None
        alpha2 = self._alpha2
        self._alpha2 = None

        if self.concat:
            out = out.reshape(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        out2 = out2.mean(dim=1)
        out2 += self.bias2

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out , out2
        return out, out2

    def message(self, x_j: Tensor, x_i: Tensor, index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        x = x_i + x_j
        alpha1 = (x[:, :, :self.out_channels] * self.att).sum(dim=-1)
        alpha2 = (x[:, :, self.out_channels:] * self.att2).sum(dim=-1)
        alpha1 = F.leaky_relu(alpha1 ,self.negative_slope )
        alpha2 = F.leaky_relu(alpha2 ,self.negative_slope )
        alpha1 = softmax(alpha1, index, ptr, size_i)
        alpha2 = softmax(alpha2, index, ptr, size_i)
        self._alpha1 = alpha1
        self._alpha2 = alpha2
        alpha1= F.dropout(alpha1, p=self.dropout, training=self.training)
        alpha2= F.dropout(alpha2, p=self.dropout, training=self.training)
        return torch.cat((x_j[:, :, :self.out_channels]* alpha2.unsqueeze(-1), x_j[:, :, self.out_channels: ]* alpha1.unsqueeze(-1)) ,dim = -1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)

In [29]:
import torch
from torch_geometric.data import Data

# ==== your GTATConv must be imported here ====
# from your_file import GTATConv

# ---- Create a tiny 4-node graph ----
# edges: 0->1, 0->2, 1->2, 2->3, 3->0
edge_index = torch.tensor([
    [0, 0, 1, 2, 3],
    [1, 2, 2, 3, 0]
], dtype=torch.long)

num_nodes = 1212
in_channels = 6
out_channels = 6
heads = 2
topology_channels = 15

x = torch.randn(num_nodes, in_channels)

# topology must be per-node, not per-edge
topology = torch.randn(num_nodes, topology_channels)


In [30]:
topology.shape

torch.Size([1212, 15])

In [31]:
# ---- Create the layer ----
conv = GTATConv(
    in_channels=in_channels,
    out_channels=out_channels,
    heads=heads,
    topology_channels=topology_channels,
    concat=True
)

In [34]:
out, out_topo = conv(
    x=x,
    edge_index=edge_index,
    topology=topology
    # drop return_attention_weights for now (the alpha bug)
)

running message


In [33]:
out, out_topo

(tensor([[ 0.0588,  0.1033, -0.4876,  ..., -0.9855,  0.3434, -0.0788],
         [ 1.8255, -1.7424,  0.2190,  ..., -1.0102,  0.7253, -0.2346],
         [ 0.5957, -0.8927, -0.1721,  ..., -1.0335,  0.0633, -0.3791],
         ...,
         [-0.2016, -0.1400, -0.2803,  ..., -0.1449, -0.0055, -0.6065],
         [ 1.0063, -0.7693, -0.1875,  ..., -0.2388,  1.5637, -0.8959],
         [-2.8115,  2.7337, -0.5551,  ...,  2.0506, -2.0814,  1.0832]],
        grad_fn=<AddBackward0>),
 tensor([[ 0.2006, -0.7927,  0.4328,  ...,  0.1841, -0.6073, -0.3252],
         [ 0.1825, -1.2819,  0.0595,  ..., -0.5152,  0.1847, -0.3861],
         [-0.1153, -1.1041,  0.1479,  ..., -0.4474, -0.1341, -0.7020],
         ...,
         [ 0.7222,  1.9238, -0.4392,  ..., -0.2834, -1.7416,  0.7609],
         [ 0.1801, -0.5501,  0.6030,  ...,  0.8499,  0.3245, -1.0969],
         [-0.7432, -0.6839, -0.4942,  ...,  0.4759,  0.3428,  2.1522]],
        grad_fn=<AddBackward0>))