In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from torch import Tensor
from torch.nn import Parameter

from torch_geometric.nn import MessagePassing, TopKPooling, global_mean_pool
from torch_geometric.utils import to_dense_adj, degree
from tqdm import trange
from torch_geometric.datasets import TUDataset


In [2]:
class ComplexLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(ComplexLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = nn.Parameter(torch.empty((out_features, in_features), dtype=torch.cfloat))
        self.bias = nn.Parameter(torch.empty(out_features, dtype=torch.cfloat)) if bias else None
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.kaiming_uniform_(self.weight.real, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight.imag, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        return F.linear(input, self.weight, self.bias)

In [3]:
#define b_j outside the class as a function

def calcualte_b_j(h, lap, n_nodes, y_j):
    i = torch.complex(torch.tensor(0.0), torch.tensor(1.0)).to(y_j.device)
    I = torch.eye(n_nodes).to(y_j.device)
    return torch.diag(torch.diag(h * lap + i * I) ** -1) @ (h * lap - i * I) @ y_j.to(torch.complex64)

def calcualte_jacobi(h, W, D, n_nodes):
    i = torch.complex(torch.tensor(0.0), torch.tensor(1.0)).to(W.device)
    I = torch.eye(n_nodes).to(W.device)
    jacobi = torch.diag(torch.diag(h * D + i * I) ** -1)
    return torch.mm(jacobi, h*W.to(torch.complex64))

class CayleyConv(MessagePassing):
    def __init__(
        self,
        r: int,
        K: int,
        in_channels: int,
        out_channels: int,
        **kwargs,
    ):
        kwargs.setdefault('aggr', 'add')
        super().__init__(**kwargs)

        assert r > 0
        assert K > 0
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.r = r
        self.K = K
        self.h = Parameter(torch.empty(1))
        self.real_linear = nn.Linear(in_channels, out_channels, bias=False)
        self.complex_linear = ComplexLinear(in_channels * r, out_channels, bias=False)
        self.reset_parameters()

    def reset_parameters(self):
        super().reset_parameters()
        stdv = 1. / math.sqrt(self.in_channels)
        self.h.data.uniform_(-stdv, stdv)


    def forward(
        self,
        x: Tensor,
        edge_index: Tensor,
    ) -> Tensor:


        y_j = x
        out_0 = self.real_linear(y_j)
        out_r = []

        row, col = edge_index
        n_nodes = x.size(self.node_dim)
        W = to_dense_adj(edge_index).squeeze(0)
        D = torch.diag(degree(row, n_nodes))
        lap = D - W
        jacobi = calcualte_jacobi(self.h, W, D, n_nodes)
        norm = jacobi.to_sparse().values()
        
        
        # calcualte r polynomials 
        
        for j in range(self.r):
            
            b_j = calcualte_b_j(self.h, lap, n_nodes, y_j)
            y_j_k = b_j

            # K jacobi iteration
            for k in range(self.K):
                y_j_k = self.propagate(edge_index, x=y_j_k, jacobi=norm) + b_j
            y_j = y_j_k
            out_r.append(y_j)
        out_r = self.complex_linear(torch.concat(out_r, -1))
        out = out_0 + 2 * out_r.real
        return out

    def message(self, x_j: Tensor, jacobi: Tensor) -> Tensor:
        return jacobi.view(-1, 1) * x_j
    
class CayleyNet(nn.Module):
    def __init__(self, r, K, feature_dim, hidden_dim, output_dim):
        super(CayleyNet, self).__init__()
        self.caley_conv = CayleyConv(r, K, feature_dim, hidden_dim)
        self.pool = TopKPooling(hidden_dim, ratio=0.9)
        self.lin = nn.Linear(hidden_dim, output_dim)


    def forward(self, x, edge_index):
        x = self.caley_conv(x, edge_index)
        x = x.relu()

        x, edge_index, _, batch, _, _ =  self.pool(x, edge_index)

        x = global_mean_pool(x, batch) 

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.lin(x)
        
        return x


In [4]:
# DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
dataset = TUDataset(root='data/TUDataset', name='MUTAG')
DEVICE = 'cpu'

model = CayleyNet(r=64, K=10, feature_dim=7, hidden_dim=64, output_dim=dataset.num_classes).to(DEVICE)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
train_dataset = dataset[:150]
test_dataset = dataset[150:]
model.train()
for epoch in trange(10):
    for data in train_dataset:
        data = data.to(DEVICE)
        out = model(data.x, data.edge_index) 
        loss = criterion(out, data.y) 
        loss.backward() 
        optimizer.step() 
        optimizer.zero_grad() 


100%|██████████| 10/10 [04:17<00:00, 25.79s/it]


In [5]:
model.eval()

correct = 0
for data in test_dataset:  
    out = model(data.x, data.edge_index)  
    pred = out.argmax(dim=1) 
    correct += int((pred == data.y).sum())  
print("{} test accuracy".format(correct / len(test_dataset))) 

0.6578947368421053 test accuracy
