In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn import MessagePassing, TopKPooling, global_mean_pool
from torch_geometric.utils import degree, get_laplacian
from tqdm import trange
from torch_geometric.datasets import TUDataset
from torch_sparse import SparseTensor

In [2]:
class ComplexLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(ComplexLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), dtype=torch.cfloat))
        self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.cfloat)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight.real, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight.imag, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)




class CayleyConv(MessagePassing):
    def __init__(
        self,
        r: int,
        K: int,
        in_channels: int,
        out_channels: int,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        assert r > 0
        assert K > 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.r = r
        self.K = K
        self.h = Parameter(torch.empty(1))
        self.i = Parameter(torch.tensor(0.+1.j), requires_grad=False)
        self.alpha = Parameter(torch.empty(1))
        self.real_linear = nn.Linear(in_channels, out_channels, bias=False)
        self.complex_linear = ComplexLinear(in_channels * r, out_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        stdv = 1. / math.sqrt(self.in_channels)
        self.h.data.uniform_(-stdv, stdv)
        self.alpha.data.uniform_(-stdv, stdv)


    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
    ) -> Tensor:


        y_j = x
        out_0 = self.real_linear(y_j)
        out_r = []

        n_nodes = x.size(self.node_dim)
        device = edge_index.device
        edge_weight = torch.ones(edge_index.size(1), device=device)
        row, _ = edge_index
        deg = degree(row, n_nodes).to(device)
        # Laplacian
        l_index, l_weight = get_laplacian(edge_index, edge_weight, normalization=None, num_nodes=n_nodes)
        l_weight[l_index[0] == l_index[1]] -= self.alpha

        # Jacobi
        jacobi = self.calcualte_jacobi(l_index, l_weight)
        
        # calcualte r polynomials 
        for j in range(self.r):
            
            b_j = self.calcualte_b(l_index, l_weight, y_j)
            y_j_k = b_j

            # K jacobi iteration
            for k in range(self.K):
                # y_j ^ k+1 = J @ y_j ^ k + b_j
                y_j_k = self.propagate(l_index, x=y_j_k, jacobi=jacobi) + b_j
            y_j = y_j_k
            out_r.append(y_j)
        out_r = self.complex_linear(torch.concat(out_r, -1))
        out = out_0 + 2 * out_r.real
        return out

    def message(self, x_j: Tensor, jacobi: Tensor) -> Tensor:
        # J Y
        return jacobi.view(-1, 1) * x_j
    
    def calcualte_jacobi(self, l_index, l_weight):
        # device
        l_row, l_col = l_index

        # (hD + iI)^-1 * h
        l_dia = l_weight[l_row == l_col]
        tmp_left = 1 / (self.h * l_dia + self.i)
        tmp_left.masked_fill_(tmp_left == float('inf'), 0.+0.j)
        
        tmp_right = self.h * l_weight.type(torch.cfloat)
        tmp_right[l_row == l_col] = 0.+0.j

        jacobi = tmp_left[l_row] * tmp_right
        return jacobi

    def calcualte_b(self, l_index, l_weight, y_j):
        l_row, l_col = l_index

        # hL - iI
        tmp_right = (l_weight * self.h).type(self.i.dtype)
        tmp_right[l_row == l_col] -= self.i

        l_dia = l_weight[l_row == l_col]
        # (Diag(hL + iI))^-1
        tmp_left = 1 / (self.h * l_dia + self.i)
        tmp_left.masked_fill_(tmp_left == float('inf'), 0.+0.j)
        # (Diag(hL + iI))^-1 (hL - iI)
        tmp = tmp_left[l_row] * tmp_right
        tmp = torch.sparse_coo_tensor(indices=l_index, values=tmp, device=y_j.device)
        return torch.matmul(tmp, y_j.type(torch.cfloat))
    
class CayleyNet(nn.Module):
    def __init__(self, n_conv, r, K, feature_dim, hidden_dim, output_dim):
        super(CayleyNet, self).__init__()
        convs = []
        for i in range(n_conv):
            convs.append(CayleyConv(r, K, feature_dim if i == 0 else hidden_dim, hidden_dim))
            convs.append(nn.ReLU())
        self.convs = nn.ModuleList(convs)
        # self.caley_conv = CayleyConv(r, K, feature_dim, hidden_dim)
        self.pool = TopKPooling(hidden_dim, ratio=0.9)
        self.lin = nn.Linear(hidden_dim, output_dim)


    def forward(self, x, edge_index):
        for i in range(0, len(self.convs), 2):
            conv = self.convs[i]
            relu = self.convs[i+1]
            x = conv(x, edge_index)
            x = relu(x)

        x, edge_index, _, batch, _, _ =  self.pool(x, edge_index)

        x = global_mean_pool(x, batch) 

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin(x)
        
        return x



In [3]:
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
DEVICE = 'cpu'

model = CayleyNet(n_conv=3, r=5, K=10, feature_dim=7, hidden_dim=64, output_dim=dataset.num_classes).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
model.train()
for epoch in trange(10):
    for data in train_dataset:
        data = data.to(DEVICE)
        out = model(data.x, data.edge_index) 
        loss = criterion(out, data.y) 
        loss.backward() 
        optimizer.step() 
        optimizer.zero_grad() 


model.eval()
correct = 0
for data in test_dataset:  
    data = data.to(DEVICE)
    out = model(data.x, data.edge_index)  
    pred = out.argmax(dim=1) 
    correct += int((pred == data.y).sum())  
print("{} test accuracy".format(correct / len(test_dataset))) 

100%|██████████████████████████████████████████████████████████████████████████████████| 10/10 [00:43<00:00,  4.31s/it]


0.631578947368421 test accuracy
