In [1]:
import numpy as np

import torch
from torch.nn import Linear
from torch_scatter import scatter_mean
from torch_geometric.nn import MessagePassing
from torch_geometric.utils import degree
from torch_geometric.utils.num_nodes import maybe_num_nodes
from torch_scatter import scatter_add
import torch_geometric

device = torch.device("cpu")

In [2]:


## Author: Alex Tong
## Reference: Data-Driven Learning of Geometric Scattering Networks, IEEE Machine Learning for Signal Processing Workshop 2021

def scatter_moments(graph, batch_indices, moments_returned=4):
    
    """ Compute specified statistical coefficients for each feature of each graph passed. 
        The graphs expected are disjoint subgraphs within a single graph, whose feature tensor is passed as argument "graph."
        "batch_indices" connects each feature tensor to its home graph.
        "Moments_returned" specifies the number of statistical measurements to compute. 
        If 1, only the mean is returned. If 2, the mean and variance. If 3, the mean, variance, and skew. If 4, the mean, variance, skew, and kurtosis.
        The output is a dictionary. You can obtain the mean by calling output["mean"] or output["skew"], etc.
    """

    # Step 1: Aggregate the features of each mini-batch graph into its own tensor
    graph_features = [torch.zeros(0).to(graph) for i in range(torch.max(batch_indices) + 1)]

    for i, node_features in enumerate(graph):

        # Sort the graph features by graph, according to batch_indices. For each graph, create a tensor whose first row is the first element of each feature, etc.
        # print("node features are", node_features)
        
        if (len(graph_features[batch_indices[i]]) == 0):  
            # If this is the first feature added to this graph, fill it in with the features.
            graph_features[batch_indices[i]] = node_features.view(-1, 1, 1)  # .view(-1,1,1) changes [1,2,3] to [[1],[2],[3]], so that we can add each column to the respective row.
        else:
            graph_features[batch_indices[i]] = torch.cat((graph_features[batch_indices[i]], node_features.view(-1, 1, 1)), dim=1)  # concatenates along columns

    statistical_moments = {"mean": torch.zeros(0).to(graph)}

    if moments_returned >= 2:
        statistical_moments["variance"] = torch.zeros(0).to(graph)
    if moments_returned >= 3:
        statistical_moments["skew"] = torch.zeros(0).to(graph)
    if moments_returned >= 4:
        statistical_moments["kurtosis"] = torch.zeros(0).to(graph)

    for data in graph_features:

        data = data.squeeze()
        
        def m(i):  # ith moment, computed with derivation data
            return torch.mean(deviation_data ** i, axis=1)

        mean = torch.mean(data, dim=1, keepdim=True)
        
        if moments_returned >= 1:
            statistical_moments["mean"] = torch.cat(
                (statistical_moments["mean"], mean.T), dim=0
            )

        # produce matrix whose every row is data row - mean of data row

        #for a in mean:
        #    mean_row = torch.ones(data.shape[1]).to( * a
        #    tuple_collect.append(
        #        mean_row[None, ...]
        #    )  # added dimension to concatenate with differentiation of rows
        # each row contains the deviation of the elements from the mean of the row
        
        deviation_data = data - mean
        
        # variance: difference of u and u mean, squared element wise, summed and divided by n-1
        variance = m(2)
        
        if moments_returned >= 2:
            statistical_moments["variance"] = torch.cat(
                (statistical_moments["variance"], variance[None, ...]), dim=0
            )

        # skew: 3rd moment divided by cubed standard deviation (sd = sqrt variance), with correction for division by zero (inf -> 0)
        skew = m(3) / (variance ** (3 / 2)) 
        skew[
            skew > 1000000000000000
        ] = 0  # multivalued tensor division by zero produces inf
        skew[
            skew != skew
        ] = 0  # single valued division by 0 produces nan. In both cases we replace with 0.
        if moments_returned >= 3:
            statistical_moments["skew"] = torch.cat(
                (statistical_moments["skew"], skew[None, ...]), dim=0
            )

        # kurtosis: fourth moment, divided by variance squared. Using Fischer's definition to subtract 3 (default in scipy)
        kurtosis = m(4) / (variance ** 2) - 3 
        kurtosis[kurtosis > 1000000000000000] = -3
        kurtosis[kurtosis != kurtosis] = -3
        if moments_returned >= 4:
            statistical_moments["kurtosis"] = torch.cat(
                (statistical_moments["kurtosis"], kurtosis[None, ...]), dim=0
            )
    
    # Concatenate into one tensor (alex)
    statistical_moments = torch.cat([v for k,v in statistical_moments.items()], axis=1)
    #statistical_moments = torch.cat([statistical_moments['mean'],statistical_moments['variance']],axis=1)
    
    return statistical_moments


class LazyLayer(torch.nn.Module):
    
    """ Currently a single elementwise multiplication with one laziness parameter per
    channel. this is run through a softmax so that this is a real laziness parameter
    """

    def __init__(self, n):
        super().__init__()
        self.weights = torch.nn.Parameter(torch.Tensor(2, n))

    def forward(self, x, propogated):
        inp = torch.stack((x, propogated), dim=1)
        s_weights = torch.nn.functional.softmax(self.weights, dim=0)
        return torch.sum(inp * s_weights, dim=-2)

    def reset_parameters(self):
        torch.nn.init.ones_(self.weights)
    

def gcn_norm(edge_index, edge_weight=None, num_nodes=None, add_self_loops=False, dtype=None):

    num_nodes = maybe_num_nodes(edge_index, num_nodes)

    if edge_weight is None:
        edge_weight = torch.ones((edge_index.size(1), ), dtype=dtype,
                                 device=edge_index.device)

    if add_self_loops:
        edge_index, tmp_edge_weight = add_remaining_self_loops(
            edge_index, edge_weight, 1, num_nodes)
        assert tmp_edge_weight is not None
        edge_weight = tmp_edge_weight

    row, col = edge_index[0], edge_index[1]
    deg = scatter_add(edge_weight, col, dim=0, dim_size=num_nodes)
    deg_inv_sqrt = deg.pow_(-1)
    deg_inv_sqrt.masked_fill_(deg_inv_sqrt == float('inf'), 0)
    
    return edge_index, deg_inv_sqrt[col] * edge_weight


class Diffuse(MessagePassing):

    """ Implements low pass walk with optional weights
    """

    def __init__(self, in_channels, out_channels, trainable_laziness=False, fixed_weights=True):

        super().__init__(aggr="add",  flow = "target_to_source", node_dim=-3)  # "Add" aggregation.
        assert in_channels == out_channels
        self.trainable_laziness = trainable_laziness
        self.fixed_weights = fixed_weights
        if trainable_laziness:
            self.lazy_layer = LazyLayer(in_channels)
        if not self.fixed_weights:
            self.lin = torch.nn.Linear(in_channels, out_channels)


    def forward(self, x, edge_index, edge_weight=None):

        # x has shape [N, in_channels]
        # edge_index has shape [2, E]

        # Step 2: Linearly transform node feature matrix.
        # turn off this step for simplicity
        if not self.fixed_weights:
            x = self.lin(x)

        # Step 3: Compute normalization
        edge_index, edge_weight = gcn_norm(edge_index, edge_weight, x.size(self.node_dim), dtype=x.dtype)

        # Step 4-6: Start propagating messages.
        propogated = self.propagate(
            edge_index, edge_weight=edge_weight, size=None, x=x,
        )
        if not self.trainable_laziness:
            return 0.5 * (x + propogated), edge_index, edge_weight

        return self.lazy_layer(x, propogated), edge_index, edge_weight


    def message(self, x_j, edge_weight):
        
        # x_j has shape [E, out_channels]
        # Step 4: Normalize node features.
        return edge_weight.view(-1, 1, 1) * x_j


    #def message_and_aggregate(self, adj_t, x):
    #
    #    return matmul(adj_t, x, reduce=self.aggr)


    def update(self, aggr_out):

        # aggr_out has shape [N, out_channels]
        # Step 6: Return new node embeddings.
        return aggr_out


def feng_filters():

    tmp = np.arange(16).reshape(4,4) #tmp doesn't seem to be used!
    results = [4]
    for i in range(2, 4):
        for j in range(0, i):
            results.append(4*i+j)

    return results


class Scatter(torch.nn.Module):

    def __init__(self, in_channels, trainable_laziness=False):

        super().__init__()
        self.in_channels = in_channels
        self.trainable_laziness = trainable_laziness
        self.diffusion_layer1 = Diffuse(in_channels, in_channels, trainable_laziness)
        self.diffusion_layer2 = Diffuse(
            4 * in_channels, 4 * in_channels, trainable_laziness
        )
        self.wavelet_constructor = torch.nn.Parameter(torch.tensor([
            [0, -1.0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, -1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1]
        ], requires_grad=True))
        # self.wavelet_constructor = torch.tensor([
        #     [0, -1.0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, -1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, -1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        #     [0, 0, 0, 0, 0, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 1]
        # ], device=device)



    def forward(self, data):

        x, edge_index = data.x, data.edge_index
        s0 = x[:,:,None]
        avgs = [s0]
        for i in range(16):
            avgs.append(self.diffusion_layer1(avgs[-1], edge_index)[0])
        for j in range(len(avgs)):
            avgs[j] = avgs[j][None, :, :, :]  # add an extra dimension to each tensor to avoid data loss while concatenating TODO: is there a faster way to do this?
        
        # Combine the diffusion levels into a single tensor.
        diffusion_levels = torch.cat(avgs)
        
        # Reshape the 3d tensor into a 2d tensor and multiply with the wavelet_constructor matrix
        # This simulates the below subtraction:
        # filter1 = avgs[1] - avgs[2]
        # filter2 = avgs[2] - avgs[4]
        # filter3 = avgs[4] - avgs[8]
        # filter4 = avgs[8] - avgs[16]
        subtracted = torch.matmul(self.wavelet_constructor, diffusion_levels.view(17, -1))
        subtracted = subtracted.view(4, x.shape[0], x.shape[1]) # reshape into given input shape
        s1 = torch.abs(
            torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous

        # perform a second wave of diffusing, on the recently diffused.
        avgs = [s1]
        for i in range(16): # diffuse over diffusions
            avgs.append(self.diffusion_layer2(avgs[-1], edge_index))
        for i in range(len(avgs)): # add an extra dimension to each diffusion level for concatenation
            avgs[i] = avgs[i][None, :, :, :]
        diffusion_levels2 = torch.cat(avgs)
        
        # Having now generated the diffusion levels, we can cmobine them as before
        subtracted2 = torch.matmul(self.wavelet_constructor, diffusion_levels2.view(17, -1))
        subtracted2 = subtracted2.view(4, s1.shape[0], s1.shape[1], s1.shape[2])  # reshape into given input shape
        subtracted2 = torch.transpose(subtracted2, 0, 1)
        subtracted2 = torch.abs(subtracted2.reshape(-1, self.in_channels, 4))
        s2_swapped = torch.reshape(torch.transpose(subtracted2, 1, 2), (-1, 16, self.in_channels))
        s2 = s2_swapped[:, feng_filters()]

        x = torch.cat([s0, s1], dim=2)
        x = torch.transpose(x, 1, 2)
        x = torch.cat([x, s2], dim=1)

        #x = scatter_mean(x, batch, dim=0)
        if hasattr(data, 'batch'):
            x = scatter_moments(x, data.batch, 4)
        else:
            x = scatter_moments(x, torch.zeros(data.x.shape[0], dtype=torch.int32), 4)
            # print('x returned shape', x.shape)
        return x, self.wavelet_constructor


    def out_shape(self):

        # x * 4 moments * in
        return 11 * 4 * self.in_channels


class TSNet(torch.nn.Module):

    def __init__(self, in_channels, out_channels, edge_in_channels = None, trainable_laziness=False, **kwargs):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_in_channels = edge_in_channels
        self.trainable_laziness = trainable_laziness
        self.scatter = Scatter(in_channels, trainable_laziness=trainable_laziness)
        self.lin1 = Linear(self.scatter.out_shape(), out_channels)
        self.lin2 = Linear(out_channels, out_channels)
        self.lin3 = Linear(out_channels, out_channels)
        self.act = torch.nn.LeakyReLU()


    def forward(self, data):

        x, sc = self.scatter(data)
        x = self.act(x)
        x = self.lin1(x)
        x = self.lin2(x)
        x = self.lin3(x)
        return x, sc

# Blis layer and Blis net initial definition

In [97]:
import torch.nn as nn
class Blis(torch.nn.Module):

    def __init__(self, in_channels, trainable_laziness=False, activation = "blis"):

        super().__init__()
        self.in_channels = in_channels
        self.trainable_laziness = trainable_laziness
        self.diffusion_layer1 = Diffuse(in_channels, in_channels, trainable_laziness)
        # self.diffusion_layer2 = Diffuse(
        #     4 * in_channels, 4 * in_channels, trainable_laziness
        # )
        self.wavelet_constructor = torch.nn.Parameter(torch.tensor([
            [1, -1.0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 1, 0, 0, 0, -1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, -1],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]
        ], requires_grad=True))

        if activation == "blis":
            self.activations = [lambda x: torch.relu(x), lambda x: torch.relu(-x)]
        elif activation == None:
            self.activations = [lambda x : x]

    def forward(self, data):

        x, edge_index = data.x, data.edge_index
        s0 = x[:,:,None]
        avgs = [s0]
        for i in range(16):
            avgs.append(self.diffusion_layer1(avgs[-1], edge_index)[0])
        for j in range(len(avgs)):
            avgs[j] = avgs[j][None, :, :, :]  # add an extra dimension to each tensor to avoid data loss while concatenating TODO: is there a faster way to do this?
        
        # Combine the diffusion levels into a single tensor.
        diffusion_levels = torch.cat(avgs)
        
        # Reshape the 3d tensor into a 2d tensor and multiply with the wavelet_constructor matrix
        # This simulates the below subtraction:
        # filter0 = avgs[0] - avgs[1]
        # filter1 = avgs[1] - avgs[2] 
        # filter2 = avgs[2] - avgs[4]
        # filter3 = avgs[4] - avgs[8]
        # filter4 = avgs[8] - avgs[16] 
        # filter5 = avgs[16]
        #subtracted = torch.matmul(self.wavelet_constructor, diffusion_levels.view(17, -1))
        wavelet_coeffs = torch.einsum("ij,jklm->iklm", self.wavelet_constructor, diffusion_levels) # J x num_nodes x num_features x 1
        #subtracted = subtracted.view(6, x.shape[0], x.shape[1]) # reshape into given input shape
        activated = [self.activations[i](wavelet_coeffs) for i in range(len(self.activations))]
        #s1 = self.relu(
        #    torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous
        #s2 = self.relu(-
        #    torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous
        s = torch.cat(activated, axis=-1).transpose(1,0)
        #x = torch.cat((s1,s2), axis = -1)
        #print(x.shape)
        #x = x.reshape(x.shape[0], -1)
        
        return s
        return s.reshape(x.shape[0],-1), diffusion_levels, self.wavelet_constructor

    def out_features(self):

        # x * 4 moments * in
        return 12 * self.in_channels

In [98]:
from blis import DATA_DIR
import os
import blis.models.scattering_transform as st 
import blis.models.wavelets as wav 

dataset = 'traffic'
sub_dataset = "PEMS04"
label= "DAY"
largest_scale = 4
scattering_type = 'blis'
num_layers = 1
highest_moment = 4
wavelet_type = 'W2'

dataset_dir = os.path.join(DATA_DIR, dataset, sub_dataset)
processed_dir =  os.path.join(dataset_dir, 'processed', scattering_type, wavelet_type, f'largest_scale_{largest_scale}')


dataset_dir = os.path.join(DATA_DIR, dataset, sub_dataset)


# load adjacency matrix and signal
A = np.load(os.path.join(dataset_dir, 'adjacency_matrix.npy'), allow_pickle = True)
x = np.load(os.path.join(dataset_dir, 'graph_signals.npy'), allow_pickle = True)
y = np.load(os.path.join(dataset_dir, label, 'label.npy'), allow_pickle = True)
if len(x.shape) == 2:
    x = x[:,:,None]

x = x[:10]

if wavelet_type == 'W2':
    wavelets = wav.get_W_2(A, largest_scale, low_pass_as_wavelet=(scattering_type == 'blis'))
else:
    wavelets = wav.get_W_1(A, largest_scale, low_pass_as_wavelet=(scattering_type == 'blis'))
#coeffs = st.scattering_transform(x, scattering_type, wavelets, num_layers, highest_moment, None)


coeffs = np.stack([np.einsum('ik, nkf->nif', wavelets[j], x) for j in range(len(wavelets))],1).transpose(0,2,1,3)

  d_arr_inv = 1/d_arr


In [99]:
coeffs.shape

(10, 307, 6, 3)

In [100]:
coeffs[0].transpose(1,0,2).shape

(6, 307, 3)

In [101]:
from blis.data.load_from_np import create_dataset

data_list = create_dataset(x, y , A)

blis_mod = Blis(in_channels = 3, trainable_laziness=False, activation = None)
out_coeffs  = blis_mod(data_list[0])

Creating dataset....
Done!


In [103]:
out_coeffs.shape

torch.Size([307, 6, 3, 1])

In [104]:
np.abs(coeffs[0] - out_coeffs.detach().numpy()[...,0]).max()

2.070570521439663e-05

In [45]:
np.abs(coeffs[0].transpose(1,0,2).reshape(len(out_coeffs),-1) - out_coeffs.detach().numpy()).max()

285.2721862792969

In [None]:
#subtracted = subtracted.view(6, x.shape[0], x.shape[1]) # reshape into given input shape
        #s1 = self.relu(
        #    torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous
        #s2 = self.relu(-
        #    torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous
        #x = torch.cat((s1,s2), axis = -1)
        #print(x.shape)
        #x = x.reshape(x.shape[0], -1)

In [68]:
out_coeffs.view(6, 307,3).shape

torch.Size([6, 307, 3])

In [69]:
out_coeffs.shape

torch.Size([6, 921])

In [None]:
subtracted = out_coeffs
subtracted = subtracted.view(6, x.shape[0], x.shape[1]) # reshape into given input shape
s1 = self.relu(
    torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous
s2 = self.relu(-
    torch.transpose(torch.transpose(subtracted, 0, 1), 1, 2))  # transpose the dimensions to match previous
x = torch.cat((s1,s2), axis = -1)
print(x.shape)
x = x.reshape(x.shape[0], -1)

In [57]:
np.abs(coeffs[0] - out_coeffs.detach().numpy()).max()

2.070570521439663e-05

In [47]:
x, edge_index = data_list[0].x, data_list[0].edge_index
s0 = x[:,:,None]
avgs = [s0]
diffusion_layer1 = Diffuse( 1, 1, trainable_laziness = False, fixed_weights= True)
diff_out, edge_index_, edge_weights_ = diffusion_layer1(avgs[-1], edge_index)

A_l = torch_geometric.utils.to_dense_adj(edge_index_, edge_attr = edge_weights_).numpy()[0]

In [48]:
import torch_geometric
A_ = torch_geometric.utils.to_dense_adj(edge_index).numpy()[0]
d_arr = np.sum(A_, axis=0)
d_arr_inv = 1/d_arr
d_arr_inv[np.isinf(d_arr_inv)] = 0
D_inv = np.diag(d_arr_inv)
P_ =  A_ @ D_inv

x_ = (P_ @ x.numpy())
x_ = 0.5 * ( x_ + x.numpy() )

  d_arr_inv = 1/d_arr


In [49]:
np.abs(x_ - diff_out.detach().numpy()[...,0]).max()

7.6293945e-06

In [30]:
A_ = torch_geometric.utils.to_dense_adj(edge_index).numpy()[0]
from blis.models.wavelets import get_P
P_b = get_P(A_)
x_ = (P_b @ x.numpy())

  d_arr_inv = 1/d_arr


In [28]:
np.abs(0.5 * (np.identity(len(P_)) + P_) - P_b).max()

0.0

In [32]:
np.abs(x_ - diff_out.detach().numpy()[...,0]).max()

7.62939453125e-06

In [34]:
from torch_geometric.nn.norm import BatchNorm
from torch_geometric.nn.pool import global_mean_pool
from torch_geometric.nn import GCNConv
class BlisNet(torch.nn.Module):

    def __init__(self, in_channels, out_channels, edge_in_channels = None, trainable_laziness=False, **kwargs):

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.edge_in_channels = edge_in_channels
        self.trainable_laziness = trainable_laziness
        self.conv1 = GCNConv(in_channels, in_channels)
        self.blis1 = Blis(in_channels, trainable_laziness=trainable_laziness)
        self.blis2 = Blis(self.blis1.out_features(), trainable_laziness=trainable_laziness)
        self.batch_norm = BatchNorm(self.blis2.out_features())
        self.lin1 = Linear(self.blis2.out_features(), self.blis2.out_features()//2 )
        self.mean = global_mean_pool
        self.lin2 = Linear(self.blis2.out_features()//2, out_channels)
        self.lin3 = Linear(out_channels, out_channels)

        self.act = torch.nn.ReLU()


    def forward(self, data):
        # 1. Obtain node embeddings 
        data.x = self.conv1(data.x, data.edge_index)
        #x = x.relu()
        x, sc1 = self.blis1(data)
        data.x = x # there's gotta be a better way to do this...
        x, sc2 = self.blis2(data)
        x = self.batch_norm(x)
        x = self.lin1(x)
        x = self.act(x)
        x = self.mean(x,data.batch)
        x = self.lin2(x)
        x = self.act(x)
        x = self.lin3(x)

        return x

# TUDataset and synthetic test

In [None]:
import os
import torch
import torch
from torch_geometric.datasets import TUDataset

dataset = TUDataset(root='data/TUDataset', name='MUTAG')
torch.manual_seed(12345)
dataset = dataset.shuffle()

train_dataset = dataset[:150]
test_dataset = dataset[150:]

print(f'Number of training graphs: {len(train_dataset)}')
print(f'Number of test graphs: {len(test_dataset)}')

from torch_geometric.loader import DataLoader

train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [21]:
model = Blis(dataset.num_node_features)
optimizer = torch.optim.Adam(model.parameters(), lr = .01)
criterion = torch.nn.CrossEntropyLoss()

for data in train_loader:
    print(data)
    out = model(data)

DataBatch(edge_index=[2, 2454], x=[1118, 7], edge_attr=[2454, 4], y=[64], batch=[1118], ptr=[65])
torch.Size([1118, 7, 12])
DataBatch(edge_index=[2, 2712], x=[1220, 7], edge_attr=[2712, 4], y=[64], batch=[1220], ptr=[65])
torch.Size([1220, 7, 12])
DataBatch(edge_index=[2, 828], x=[376, 7], edge_attr=[828, 4], y=[22], batch=[376], ptr=[23])
torch.Size([376, 7, 12])


In [22]:
a = torch.zeros(34323)
a.to(device)

tensor([0., 0., 0.,  ..., 0., 0., 0.], device='cuda:0')

In [37]:
model = BlisNet(dataset.num_node_features, dataset.num_classes, trainable_laziness=False)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')


Epoch: 001, Train Acc: 0.7133, Test Acc: 0.7632
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 008, Train Acc: 0.7267, Test Acc: 0.7632
Epoch: 009, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 010, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 011, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 012, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 013, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 014, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 015, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 016, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 017, Train Acc: 0.7400, Test Acc: 0.7368
Epoch: 018, Train Acc: 0.7733, Test Acc: 0.7632
Epoch: 019, Train Acc: 0.7667, Test Acc: 0.7368
Epoch: 020, Train Acc: 0.8200, Test Acc: 0.6579
Epoch: 021, Train Acc: 0.8267, Test Acc:

In [11]:
from torch_geometric.data import Data 
from torch_geometric.loader import DataLoader 
import networkx as nx 
import torch 
import numpy as np 
from torch_geometric.utils import from_networkx 

n_vertices = 30
G = nx.erdos_renyi_graph(n=n_vertices, p = .4) 
signal = 100 * np.random.rand(n_vertices, 1).astype(np.float32)
nx.set_node_attributes(G, {node: {'feature': torch.tensor(signal[node])} for node in G.nodes()})

geo_data = from_networkx(G)
geo_data.x = torch.stack([node[1]['feature'] for node in G.nodes(data=True)])
loader = DataLoader([geo_data], batch_size=1)

# Run the model on the graph
for batch in loader:
    output, sc = model(batch)
    # 6. Convert the output to a numpy array
    output_numpy = output.detach().numpy()
    print(output_numpy)

torch.Size([30, 1, 12])
[[9.13871765e+00 9.96826172e-01 0.00000000e+00 0.00000000e+00
  0.00000000e+00 5.16651306e+01 0.00000000e+00 0.00000000e+00
  2.38974380e+00 2.64044571e+00 6.14662170e-01 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 3.87577362e+01 1.67672081e+01 8.36125755e+00
  6.56978416e+00 2.51951599e+00 3.21403503e-01 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 6.02922592e+01 7.72358322e+00 3.39667511e+00
  2.38058472e+00 9.28695679e-01 1.79817200e-01 0.00000000e+00]
 [0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00
  0.00000000e+00 3.87538071e+01 1.84674015e+01 8.42055702e+00
  5.83637238e+00 1.95967865e+00 2.88284302e-01 0.00000000e+00]
 [1.07378693e+01 7.08545685e+00 7.19158173e+00 3.69013214e+00
  6.28135681e-01 4.74000854e+01 0.00000000e+00 0.00000000e+00
  0.00000000e+00 0.00000000e+00 0.00000000e+00 0.00000000e+00]
 [3.32472382e+01 1.23107834e+01 5.7926712

In [12]:
#manual computation of the blis output:
from numpy import linalg as LA
A = nx.adjacency_matrix(G).toarray()
D_inv = np.diag(1/np.sum(A, axis = 1))
P = 1/2 * (np.eye(n_vertices) + D_inv @ A) 
wavelets = list()
wavelets.append(np.eye(n_vertices) - P)
for j in range(1,5):
    wav = LA.matrix_power(P, 2 **(j-1)) - LA.matrix_power(P, 2 ** (j))
    wavelets.append(wav) 
wavelets.append(LA.matrix_power(P, 2 ** 4)) 


In [13]:
results = []
for wavelet in wavelets:
    results.append(torch.tensor(wavelet @ signal))
wavelet_transforms = torch.stack(results, dim = 2)
print(wavelet_transforms.shape)
m = nn.ReLU()
s1 = m(wavelet_transforms)
s2 = m(-wavelet_transforms)
x = torch.cat((s1,s2), axis = -1)
print(x)
print(x.shape)

torch.Size([30, 1, 6])
tensor([[[8.7951e+00, 1.9880e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          4.8821e+01, 0.0000e+00, 0.0000e+00, 9.0223e-01, 1.9698e+00,
          5.7615e-01, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          4.8835e+01, 2.1654e+01, 1.0754e+01, 8.4886e+00, 3.3004e+00,
          4.1994e-01, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          4.8833e+01, 2.2086e+00, 4.9519e-01, 1.1762e-01, 1.8801e-01,
          1.4054e-01, 0.0000e+00]],

        [[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          4.8822e+01, 2.3992e+01, 1.0460e+01, 7.2844e+00, 2.7909e+00,
          5.1302e-01, 0.0000e+00]],

        [[9.1090e+00, 6.5051e+00, 7.2047e+00, 4.2121e+00, 8.3035e-01,
          4.8872e+01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
          0.0000e+00, 0.0000e+00]],

        [[2.9747e+01, 1.0864e+01, 4.9507e+00, 1.4790e-01, 0.0000e+00,
          4.8849e+01, 

In [20]:
np.allclose(output_numpy, np.array(x).astype(np.float32))

False

In [15]:
output_numpy[:,:,1] - np.array(x[:,:,1])

IndexError: too many indices for array: array is 2-dimensional, but 3 were indexed

unfortunately the blis module does not agree with what I expect the output to be. I'd like to investigate the cause of this and verify that there is no bug in the implementation. There are some more serious reasons that could explain this discrepancy, but the least harmful of which would be some kind of permutation of node order. I have not ruled out any possible explanations yet.

# TUDataset definition and baseline GCN

Number of training graphs: 150
Number of test graphs: 38


In [17]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    def __init__(self, hidden_channels):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)
        self.lin = Linear(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index, batch):
        # 1. Obtain node embeddings 
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)
        
        return x

model = GCN(hidden_channels=64)
print(model)

GCN(
  (conv1): GCNConv(7, 64)
  (conv2): GCNConv(64, 64)
  (conv3): GCNConv(64, 64)
  (lin): Linear(in_features=64, out_features=2, bias=True)
)


DataBatch(edge_index=[2, 2470], x=[1127, 7], edge_attr=[2470, 4], y=[64], batch=[1127], ptr=[65])
torch.Size([1127, 7, 12])
DataBatch(edge_index=[2, 2610], x=[1177, 7], edge_attr=[2610, 4], y=[64], batch=[1177], ptr=[65])
torch.Size([1177, 7, 12])
DataBatch(edge_index=[2, 914], x=[410, 7], edge_attr=[914, 4], y=[22], batch=[410], ptr=[23])
torch.Size([410, 7, 12])


In [8]:
from IPython.display import Javascript

model = GCN(hidden_channels=64)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train():
    model.train()

    for data in train_loader:  # Iterate in batches over the training dataset.
         out = model(data.x, data.edge_index, data.batch)  # Perform a single forward pass.
         loss = criterion(out, data.y)  # Compute the loss.
         loss.backward()  # Derive gradients.
         optimizer.step()  # Update parameters based on gradients.
         optimizer.zero_grad()  # Clear gradients.

def test(loader):
     model.eval()

     correct = 0
     for data in loader:  # Iterate in batches over the training/test dataset.
         out = model(data.x, data.edge_index, data.batch)  
         pred = out.argmax(dim=1)  # Use the class with highest probability.
         correct += int((pred == data.y).sum())  # Check against ground-truth labels.
     return correct / len(loader.dataset)  # Derive ratio of correct predictions.


for epoch in range(1, 171):
    train()
    train_acc = test(train_loader)
    test_acc = test(test_loader)
    print(f'Epoch: {epoch:03d}, Train Acc: {train_acc:.4f}, Test Acc: {test_acc:.4f}')

Epoch: 001, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 002, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 003, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 004, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 005, Train Acc: 0.6467, Test Acc: 0.7368
Epoch: 006, Train Acc: 0.6533, Test Acc: 0.7368
Epoch: 007, Train Acc: 0.7467, Test Acc: 0.7632
Epoch: 008, Train Acc: 0.7267, Test Acc: 0.7632
Epoch: 009, Train Acc: 0.7200, Test Acc: 0.7632
Epoch: 010, Train Acc: 0.7133, Test Acc: 0.7895
Epoch: 011, Train Acc: 0.7200, Test Acc: 0.7632
Epoch: 012, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 013, Train Acc: 0.7200, Test Acc: 0.7895
Epoch: 014, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 015, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 016, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 017, Train Acc: 0.7400, Test Acc: 0.7632
Epoch: 018, Train Acc: 0.7133, Test Acc: 0.8421
Epoch: 019, Train Acc: 0.7400, Test Acc: 0.7895
Epoch: 020, Train Acc: 0.7533, Test Acc: 0.7368
Epoch: 021, Train Acc: 0.7467, Test Acc: