In [1]:
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.inits import ones
from torch_geometric.utils import to_dense_adj, degree

In [None]:

class CayleyConv(MessagePassing):
    r"""Cayley polynomial graph filter (CayleyNets, Levie et al.).

    Applies the real-valued spectral filter

    .. math::
        G(\mathbf{x}) = c_0 \mathbf{x} + 2\,\mathrm{Re}\Big\{\sum_{j=1}^{r}
        c_j\, \mathcal{C}(h\mathbf{L})^j \mathbf{x}\Big\},
        \qquad \mathcal{C}(h\mathbf{L}) = (h\mathbf{L} - i\mathbf{I})
        (h\mathbf{L} + i\mathbf{I})^{-1},

    where :math:`\mathbf{L} = \mathbf{D} - \mathbf{W}` is the unnormalized
    Laplacian of the (unweighted) graph.

    Each Cayley step solves
    :math:`(h\mathbf{L} + i\mathbf{I})\mathbf{y}_j = (h\mathbf{L} - i\mathbf{I})
    \mathbf{y}_{j-1}`. Rather than inverting a dense matrix, the system is
    approximated with ``K`` Jacobi iterations. Each iteration needs only sparse
    neighbour sums, so it runs through :meth:`propagate`. The diagonal
    ``h * deg + i`` dominates the neighbour part, so the iteration converges.

    The graph is assumed undirected (both edge directions present). For
    directed input, the Laplacian follows the aggregation direction set by
    ``flow``.

    Args:
        r: Order of the Cayley polynomial (number of complex coefficients).
        K: Number of Jacobi iterations per Cayley step.
        h: Spectral zoom factor that scales the Laplacian.
        **kwargs: Forwarded to :class:`MessagePassing`. The maths assumes the
            default ``aggr="add"``.

    Raises:
        ValueError: If ``r`` or ``K`` is not positive.
    """

    def __init__(
        self,
        r: int,
        K: int,
        h: float,
        **kwargs,
    ) -> None:
        # Neighbour *sums* are required for W @ y; "add" is also PyG's default.
        kwargs.setdefault("aggr", "add")
        super().__init__(**kwargs)

        if r <= 0:
            raise ValueError(f"'r' must be a positive integer (got {r})")
        if K <= 0:
            raise ValueError(f"'K' must be a positive integer (got {K})")

        self.r = r
        self.K = K
        self.h = h
        # Real coefficient of the identity term.
        self.c_0 = Parameter(torch.ones(1))
        # Complex coefficients c_1..c_r of the Cayley powers.
        self.c_r = Parameter(torch.ones(r, dtype=torch.complex64))

        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Reset the base-class state and set all filter coefficients to one."""
        super().reset_parameters()
        ones(self.c_0)
        ones(self.c_r)

    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
    ) -> Tensor:
        """Filter real node features ``x`` over the graph ``edge_index``.

        Args:
            x: Real node features. Assumed shape ``(num_nodes, num_features)``,
                with nodes along ``self.node_dim`` and features on the last dim.
            edge_index: ``(2, num_edges)`` COO connectivity; each edge has
                weight 1.

        Returns:
            A real tensor shaped like ``x``.
        """
        n_nodes = x.size(self.node_dim)
        # Count degree on the node that *receives* messages, so the diagonal
        # matches the neighbour sums that propagate() produces.
        if self.flow == "source_to_target":
            receiver = edge_index[1]
        else:
            receiver = edge_index[0]
        deg = degree(receiver, n_nodes, dtype=x.dtype).unsqueeze(-1)

        # Diagonal of (hL + iI) and its inverse. The imaginary unit keeps every
        # entry non-zero, even for isolated nodes.
        inv_diag = 1.0 / (self.h * deg + 1j)
        # Jacobi operator J = -Diag^{-1} Off(hL + iI) = h * Diag^{-1} W.
        jacobi_scale = self.h * inv_diag

        y_prev = torch.complex(x, torch.zeros_like(x))
        poly = torch.zeros_like(y_prev)
        for j in range(self.r):
            # b_j = Diag^{-1} (hL - iI) y_{j-1},  with  L y = deg * y - W y.
            neigh = self._neighbor_sum(edge_index, y_prev)
            b_j = inv_diag * (self.h * (deg * y_prev - neigh) - 1j * y_prev)

            # K Jacobi iterations starting from y^0 = b_j:
            # y^{k+1} = J y^k + b_j.
            y_j = b_j
            for _ in range(self.K):
                y_j = jacobi_scale * self._neighbor_sum(edge_index, y_j) + b_j

            poly = poly + self.c_r[j] * y_j
            y_prev = y_j

        # The conjugate terms of the real filter give the factor 2 * Re{...}.
        return self.c_0 * x + 2.0 * poly.real

    def _neighbor_sum(self, edge_index: Tensor, y: Tensor) -> Tensor:
        """Return the aggregated neighbour sum ``W @ y`` for complex ``y``.

        The real and imaginary parts are packed side by side into one real
        tensor before propagating, so the scatter step never sees complex
        dtypes. Assumes features are on the last dimension.
        """
        n_feat = y.size(-1)
        packed = torch.cat([y.real, y.imag], dim=-1)
        summed = self.propagate(edge_index, x=packed)
        return torch.complex(summed[..., :n_feat], summed[..., n_feat:])

    def message(
        self,
        x_j: Tensor,
        jacobi: Optional[Tensor] = None,
        bias: Optional[Tensor] = None,
    ) -> Tensor:
        """Return the message sent from each source node ``j``.

        Args:
            x_j: Source-node features gathered per edge.
            jacobi: Optional per-edge multiplicative weight that broadcasts
                against ``x_j`` (e.g. shape ``(num_edges, 1)``).
            bias: Optional per-edge additive term that broadcasts against
                ``x_j``.

        :meth:`forward` passes neither argument, so its messages are the raw
        neighbour features.
        """
        msg = x_j
        if jacobi is not None:
            msg = msg * jacobi
        if bias is not None:
            msg = msg + bias
        return msg
    