In [4]:
import os.path as osp
from math import ceil, floor
import numpy as np

import torch
import torch.nn.functional as F
import torch_geometric.transforms as T
from torch_geometric.data import DenseDataLoader
from torch_geometric.nn import DenseGCNConv as GCNConv, dense_diff_pool
from torch_geometric.nn import DenseSAGEConv
from torch_geometric.datasets import TUDataset
from tqdm import tqdm
from torch.nn import Linear, Parameter
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import add_self_loops, degree, get_laplacian, remove_self_loops, to_dense_adj


## Preprocessing Data

In [7]:
max_nodes = 150


class MyFilter(object):
    def __call__(self, data):
        return data.num_nodes <= max_nodes


dataset = TUDataset('data', name='PROTEINS', transform=T.ToDense(max_nodes),
                    pre_filter=MyFilter())
dataset = dataset.shuffle()
n = (len(dataset) + 9) // 10
test_dataset = dataset[:n]
val_dataset = dataset[n:2 * n]
train_dataset = dataset[2 * n:]
test_loader = DenseDataLoader(test_dataset, batch_size=20)
val_loader = DenseDataLoader(val_dataset, batch_size=20)
train_loader = DenseDataLoader(train_dataset, batch_size=20)

## Implementing GDN in Pytorch

In [112]:
class GDNLayer(torch.nn.Module):
    def __init__(self, in_channels, out_channels, normalization='sym', bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.normalization = normalization
        
        # in_channels and out_channels should be the same???
        
        # Learnable Parameters
        self.W3 = Parameter(torch.randn(in_channels,in_channels))
        self.W4 = Parameter(torch.randn(in_channels,in_channels))
        self.W5 = Parameter(torch.randn(in_channels,out_channels))
        
    def __norm__(self, edge_index):
        
        edge_index, _ = remove_self_loops(edge_index=edge_index)
        edge_index, _ = get_laplacian(edge_index=edge_index, normalization=self.normalization)
        edge_index, _ = add_self_loops(edge_index=edge_index)
        
        return edge_index

    def dense_to_sparse_with_attr(adj):
        adj2 = adj.abs().sum(dim=-1)  
        index = adj2.nonzero(as_tuple=True)
        edge_attr = adj[index]
        batch = index[0] * adj.size(-1)
        index = (batch + index[1], batch + index[2])
        edge_index = torch.stack(index, dim=0)
        return edge_index, edge_attr
        

    def forward(self, batched_H, batched_A_):

        batched_X_ = []
        for i in range(len(batched_H)):
            H = batched_H[i]
            A_ = batched_A_[i]
            # edge_index, _ = self.dense_to_sparse_with_attr(A_)
            # edge_index, _ = self.__norm__(edge_index=edge_index)
            # L_sym = to_dense_adj(edge_index=edge_index)
            L_sym = A_
            I_n = torch.eye(n=L_sym.shape[0])
            order = 3
            s = 3
            eigendecomp_approx_list = [torch.linalg.matrix_power(L_sym, n) for n in range(1, order+1)] # list of L^n stored for future use

            eigendecomp_L_approx = I_n
            psi = I_n
            psi_inverse = I_n
            for i, L_sym_n in enumerate(eigendecomp_approx_list):
                n = i + 1
                eigendecomp_L_approx += L_sym_n                                             # Equation (9)
                psi_magnitude = s**n / np.math.factorial(n)
                psi += psi_magnitude if n % 2 == 0 else - psi_magnitude                     # Equation (11)
                psi_inverse += psi_magnitude                                                # Equation (12)
                
            M = torch.sigmoid(eigendecomp_L_approx @ H @ self.W3)                           # Equation (10)

            X_ = psi @ torch.relu(psi_inverse @ M @ self.W4) @ self.W5                      # Equation (13)
            batched_X_.append(X_)
            
        batched_X_ = torch.stack(batched_X_)
        print(batched_X_.shape)
        return batched_X_


class GDN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels):
        super().__init__()
        
        self.GDN1 = GDNLayer(in_channels, hidden_channels)
        self.GDN2 = GDNLayer(hidden_channels, hidden_channels)
        self.GDN3 = GDNLayer(hidden_channels, out_channels)

    def forward(self, H, A_):
        
        X_ = self.GDN1(H, A_)
        X_ = self.GDN2(X_, A_)
        X_ = self.GDN3(X_, A_)

        return X_
    
# class GDNLayer(MessagePassing):
#     def __init__(self, in_channels, out_channels, normalization='sym', bias=True):
#         super().__init__(aggr='add')  # "Add" aggregation (Step 5).
#         self.in_channels = in_channels
#         self.out_channels = out_channels
#         self.normalization = normalization
        
#         # in_channels and out_channels should be the same???
        
#         # Learnable Parameters
#         self.W3 = Parameter(torch.tensor(in_channels,in_channels))
#         self.W4 = Parameter(torch.tensor(in_channels,in_channels))
#         self.W5 = Parameter(torch.tensor(in_channels,out_channels))

        
#     def __norm__(self, edge_index):
        
#         edge_index, _ = remove_self_loops(edge_index=edge_index)
#         edge_index, _ = get_laplacian(edge_index=edge_index, normalization=self.normalization)
#         edge_index, _ = add_self_loops(edge_index=edge_index)
        
#         return edge_index

#     def forward(self, H, edge_index):
#         edge_index, _ = self.__norm__(edge_index=edge_index)
#         normalized_L = to_dense_adj(edge_index=edge_index)
#         I_n = torch.eye(n=normalized_L.shape[0])
#         order = 3
#         s = 3
#         eigendecomp_approx_list = [torch.linalg.matrix_power(normalized_L, n) for n in range(1, order+1)]
#         eigendecomp_approx_list.insert(0, I_n)                                          # list of L^n stored for future use

#         for n, L_n in enumerate(eigendecomp_approx_list):
#             eigendecomp_L_approx += L_n                                                 # Equation (9)
#             psi_magnitude = s**n / np.math.factorial(n)
#             psi += psi_magnitude if n % 2 == 0 else - psi_magnitude                     # Equation (11)
#             psi_inverse += psi_magnitude                                                # Equation (12)
            
#         M = torch.sigmoid(eigendecomp_L_approx @ H @ self.W3)                           # Equation (10)

#         X_ = psi @ torch.relu(psi_inverse @ M @ self.W4) @ self.W5                      # Equation (13)
        
#         return X_

        
#     # def forward(self, x, edge_index):             # Standard GCN forward pass for reference
#     #     # x has shape [N, in_channels]
#     #     # edge_index has shape [2, E]

#     #     # Step 1: Add self-loops to the adjacency matrix.
#     #     edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))

#     #     # Step 2: Linearly transform node feature matrix.
#     #     x = self.lin(x)

#     #     # Step 3: Compute normalization.
#     #     row, col = edge_index
#     #     deg = degree(col, x.size(0), dtype=x.dtype)
#     #     deg_inv_sqrt = deg.pow(-0.5)
#     #     deg_inv_sqrt[deg_inv_sqrt == float('inf')] = 0
#     #     norm = deg_inv_sqrt[row] * deg_inv_sqrt[col]

#     #     # Step 4-5: Start propagating messages.
#     #     out = self.propagate(edge_index, x=x, norm=norm)

#     #     # Step 6: Apply a final bias vector.
#     #     out += self.bias

#     #     return out

#     # def message(self, x_j, norm):
#     #     # x_j has shape [E, out_channels]

#     #     # Step 4: Normalize node features.
#     #     return norm.view(-1, 1) * x_j

## Defining Other Modules

In [102]:
class GNNSage(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super().__init__()

        self.conv1 = DenseSAGEConv(in_channels, hidden_channels, normalize)
        self.bn1 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv2 = DenseSAGEConv(hidden_channels, hidden_channels, normalize)
        self.bn2 = torch.nn.BatchNorm1d(hidden_channels)
        self.conv3 = DenseSAGEConv(hidden_channels, out_channels, normalize)
        self.bn3 = torch.nn.BatchNorm1d(out_channels)

        if lin is True:
            self.lin = torch.nn.Linear(2 * hidden_channels + out_channels,
                                       out_channels)
        else:
            self.lin = None

    def bn(self, i, x):
        batch_size, num_nodes, num_channels = x.size()

        x = x.view(-1, num_channels)
        x = getattr(self, f'bn{i}')(x)
        x = x.view(batch_size, num_nodes, num_channels)
        return x

    def forward(self, x, adj, mask=None):
        # batch_size, num_nodes, in_channels = x.size()

        x0 = x
        x1 = self.bn(1, self.conv1(x0, adj, mask).relu())
        x2 = self.bn(2, self.conv2(x1, adj, mask).relu())
        x3 = self.bn(3, self.conv3(x2, adj, mask).relu())

        x = torch.cat([x1, x2, x3], dim=-1)

        if self.lin is not None:
            x = self.lin(x).relu()

        return x


class GCN(torch.nn.Module):
    def __init__(self, in_channels, hidden_channels, out_channels,
                 normalize=False, lin=True):
        super(GCN, self).__init__()
        
        self.convs = torch.nn.ModuleList()
        # self.bns = torch.nn.ModuleList()
        
        self.convs.append(GCNConv(in_channels, hidden_channels, normalize))
        # self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        
        self.convs.append(GCNConv(hidden_channels, hidden_channels, normalize))
        # self.bns.append(torch.nn.BatchNorm1d(hidden_channels))
        
        self.convs.append(GCNConv(hidden_channels, out_channels, normalize))
        # self.bns.append(torch.nn.BatchNorm1d(out_channels))


    def forward(self, x, adj, mask=None):
        batch_size, num_nodes, in_channels = x.size()
        
        for step in range(len(self.convs)):
            # x = self.bns[step](F.relu(self.convs[step](x, adj, mask)))
            x = F.relu(self.convs[step](x, adj, mask))

        return x

class DiffPool(torch.nn.Module):
    def __init__(self):
        super().__init__()

        num_nodes = ceil(0.25 * max_nodes)
        self.gnn1_pool = GNNSage(dataset.num_features, 64, num_nodes)
        self.gnn1_embed = GNNSage(dataset.num_features, 64, 64, lin=False)

        num_nodes = ceil(0.25 * num_nodes)
        self.gnn2_pool = GNNSage(3 * 64, 64, num_nodes)
        self.gnn2_embed = GNNSage(3 * 64, 64, 64, lin=False)

        self.gnn3_embed = GNNSage(3 * 64, 64, 64, lin=False)

        self.lin1 = torch.nn.Linear(3 * 64, 64)
        self.lin2 = torch.nn.Linear(64, dataset.num_classes)

    def forward(self, x, adj, mask=None):
        s = self.gnn1_pool(x, adj, mask)
        x = self.gnn1_embed(x, adj, mask)

        x, adj, l1, e1 = dense_diff_pool(x, adj, s, mask)

        s = self.gnn2_pool(x, adj)
        x = self.gnn2_embed(x, adj)

        x, adj, l2, e2 = dense_diff_pool(x, adj, s)

        x = self.gnn3_embed(x, adj)

        x = x.mean(dim=1)
        x = self.lin1(x).relu()
        x = self.lin2(x)
        return F.log_softmax(x, dim=-1), l1 + l2, e1 + e2




class HierarchicalGAE(torch.nn.Module):

    def __init__(self):
        super().__init__()
        
        self.s_list = []

        # #----------------- Graph Convolution (Encoding layers) -----------------#
        # num_nodes1 = ceil(0.25 * max_nodes)
        # self.gnn1_pool = GNNSage(dataset.num_features, 64, num_nodes1)
        # self.gnn1_embed = GNNSage(dataset.num_features, 64, 64, lin=False)

        # num_nodes2 = ceil(0.25 * num_nodes1)
        # self.gnn2_pool = GNNSage(3 * 64, 64, num_nodes2)
        # self.gnn2_embed = GNNSage(3 * 64, 64, 64, lin=False)

        # self.gnn3_embed = GNNSage(3 * 64, 64, 64, lin=False)


        # #----------------- Graph Deconvolution (Decoding layers) -----------------#

        # self.gnn1_unpool = GNNSage(3 * 64, 64, num_nodes1)
        # self.gdn1_embed_inv = GDN(num_nodes1, num_nodes1, num_nodes1)
        
        # self.gnn2_unpool = GNNSage(3 * 64, 64, 64, max_nodes)
        # self.gdn2_embed_inv = GDN(max_nodes, max_nodes, max_nodes)
        

        #----------------- Graph Convolution (Encoding layers) -----------------#
        num_nodes1 = ceil(0.25 * max_nodes)
        self.gnn1_pool = GCN(dataset.num_features, 64, num_nodes1)
        self.gnn1_embed = GCN(dataset.num_features, 64, 64, lin=False)
        
        num_nodes2 = ceil(0.25 * num_nodes1)
        self.gnn2_pool = GCN(64, 64, num_nodes2)
        self.gnn2_embed = GCN(64, 64, 64, lin=False)

        self.gnn3_embed = GCN(64, 64, 64, lin=False)


        #----------------- Graph Deconvolution (Decoding layers) -----------------#

        self.gnn1_unpool = GCN(64, 64, 64, num_nodes1)
        self.gdn1_embed_inv = GDN(64, 64, 64)
        
        self.gnn2_unpool = GCN(64, 64, 64, max_nodes)
        self.gdn2_embed_inv = GDN(64, 64, 64)
        


    
    def encode(self, x, adj, mask=None):                # DiffPool Hierarchical Encoding

        s_0 = self.gnn1_pool(x, adj, mask)              # Learn Coarse Grained Mapping S_0 = GNN_pool(X, A)
        z_0 = self.gnn1_embed(x, adj, mask)             # Learn First Layer Embeddings Z_0 = GNN_embed(X, A)
        adj_0 = adj
        self.s_list.append(s_0)       
                
        x_1 = s_0.transpose(1,2) @ z_0                  # Combining features from same communities  X_1 = S_0^T Z_0 
        adj_1 = s_0.transpose(1,2) @ adj_0 @ s_0        # Use S_0 mapping to get coarse grained adj A_1 = S_0^T A_0 S_0 

        s_1 = self.gnn2_pool(x_1, adj_1)                # Learn Coarser Grained Mapping S_1 = GNN_pool(X_1, A_1)
        z_1 = self.gnn2_embed(x_1, adj_1)               # Learn Second Layer Embeddings Z_1 = GNN_embed(X_1, A_1)
                
        self.s_list.append(s_1)               
        
        x_2 = s_1.transpose(1,2) @ z_1                  # Combining features from same communities  X_2 = S_1^T Z_1 
        adj_2 = s_1.transpose(1,2) @ adj_1 @ s_1        # Use S_1 mapping to get coarse grained adj A_2 = S_1^T A_1 S_1
        
        z_2 = self.gnn3_embed(x_2, adj_2)               # Learn Third Layer Embeddings  Z_2 = GNN_embed(X_2, A_2)
        
        # embedding = z_2.mean(dim=1)                     # Average remaining embeddings to generate whole graph embedding
        # embedding = embedding.reshape(shape=(embedding.shape[0], 1, embedding.shape[1]))
        
        print(f"Hierarchical Features Pooling: {z_0.shape} -> {z_1.shape} -> {z_2.shape}")
        print(f"Hierarchical Adjacency Pooling: {adj_0.shape} -> {adj_1.shape} -> {adj_2.shape}")
        return z_2, adj_2
    
    
    def decode(self, H, A, mask=None):
        # s = self.gnn1_unpool(x, adj, mask)
        # X_ = self.gdn1_embed_inv(X_, A)
        
        S = self.s_list[-1]
        X1_ = S @ H                     
        A1_ = S @ A @ S.transpose(1,2)
        X1_ = self.gdn1_embed_inv(X1_, A1_)     #deconvolute these smoothed representations with GDN


        S = self.s_list[-2]
        X0_ = S @ X1_                      
        A0_ = S @ A1_ @ S.transpose(1,2)
        X0_ = self.gdn1_embed_inv(X0_, A0_)     #deconvolute these smoothed representations with GDN

        X_ = X0_
        A_ = A0_
        # s = self.gnn2_unpool(x, adj, mask)
        # x = self.gdn2_embed_inv(x, adj)
        
        # x, adj, _, _ = dense_diff_pool(x, adj, s, mask)
        print(f"Hierarchical Features Unpooling: {H.shape} -> {X1_.shape} -> {X0_.shape}")
        print(f"Hierarchical Adjacency Unpooling: {A.shape} -> {A1_.shape} -> {A0_.shape}")
        return X_, A_

    def forward(self, x, adj, mask=None):
        x, adj = self.encode(x, adj, mask)
        x, adj = self.decode(x, adj)
        return x, adj



In [113]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
hgae = HierarchicalGAE().to(device)
optimizer = torch.optim.Adam(hgae.parameters(), lr=0.001)

hgae.train()
for data in train_loader:
    data = data.to(device)
    optimizer.zero_grad()
    output, _ = hgae(data.x, data.adj, data.mask)
    print(output)
    break

Hierarchical Features Pooling: torch.Size([20, 150, 64]) -> torch.Size([20, 38, 64]) -> torch.Size([20, 10, 64])
Hierarchical Adjacency Pooling: torch.Size([20, 150, 150]) -> torch.Size([20, 38, 38]) -> torch.Size([20, 10, 10])
torch.Size([20, 38, 64])
torch.Size([20, 38, 64])
torch.Size([20, 38, 64])
torch.Size([20, 150, 64])
torch.Size([20, 150, 64])
torch.Size([20, 150, 64])
Hierarchical Features Unpooling: torch.Size([20, 10, 64]) -> torch.Size([20, 38, 64]) -> torch.Size([20, 150, 64])
Hierarchical Adjacency Unpooling: torch.Size([20, 10, 10]) -> torch.Size([20, 38, 38]) -> torch.Size([20, 150, 150])
tensor([[[ 9.3804e+06,  5.2027e+07, -2.7934e+07,  ...,  1.2650e+08,
           1.0753e+06,  1.0764e+08],
         [ 9.3804e+06,  5.2027e+07, -2.7934e+07,  ...,  1.2650e+08,
           1.0753e+06,  1.0764e+08],
         [ 9.3804e+06,  5.2027e+07, -2.7934e+07,  ...,  1.2650e+08,
           1.0753e+06,  1.0764e+08],
         ...,
         [ 9.3804e+06,  5.2027e+07, -2.7934e+07,  ...,  1.

In [None]:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = DiffPool().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)


def train(epoch):
    model.train()
    loss_all = 0

    for data in train_loader:
        data = data.to(device)
        optimizer.zero_grad()
        output, _, _ = model(data.x, data.adj, data.mask)
        loss = F.nll_loss(output, data.y.view(-1))
        loss.backward()
        loss_all += data.y.size(0) * float(loss)
        optimizer.step()
    return loss_all / len(train_dataset)


@torch.no_grad()
def test(loader):
    model.eval()
    correct = 0

    for data in loader:
        data = data.to(device)
        pred = model(data.x, data.adj, data.mask)[0].max(dim=1)[1]
        correct += int(pred.eq(data.y.view(-1)).sum())
    return correct / len(loader.dataset)

best_val_acc = test_acc = 0
for epoch in range(1, 151):
    train_loss = train(epoch)
    val_acc = test(val_loader)
    if val_acc > best_val_acc:
        test_acc = test(test_loader)
        best_val_acc = val_acc
    print(f'Epoch: {epoch:03d}, Train Loss: {train_loss:.4f}, '
          f'Val Acc: {val_acc:.4f}, Test Acc: {test_acc:.4f}')

In [None]:
import networkx as nx
import torch_geometric

for data in train_loader:
    example = data.get_example(0)
    print(example.edge_index)
    g = torch_geometric.utils.to_networkx(example)
    nx.draw(g)
    break
