<a href="https://colab.research.google.com/github/AchrafAsh/gnn-receptive-fields/blob/main/test_correlation_implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os, sys
import os.path as osp
from google.colab import drive, files
# Mount Google Drive, expose its "Colab Notebooks" folder at `nb_path`
# and put that path first on the import search path.
drive.mount('/content/mnt')
nb_path = '/content/notebooks'
try:
    os.symlink('/content/mnt/My Drive/Colab Notebooks', nb_path)
except FileExistsError:
    # the link already exists, e.g. this cell was run before
    pass
except OSError as err:
    # stay best-effort, but do not hide why the link could not be created
    print(f"Could not link {nb_path}: {err}", file=sys.stderr)
sys.path.insert(0, nb_path)  # or append(nb_path)

In [None]:
import time
import concurrent.futures
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import torch

from collections import Counter
from torch_geometric.utils import to_dense_adj
from tqdm.notebook import tqdm

# IPython magic: render matplotlib figures inline in the notebook.
%matplotlib inline
# Global seaborn look: fonts scaled by 1.8 and the "white" style.
sns.set_theme(font_scale=1.8)
sns.set_style("white")

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else 'cpu')

In [None]:
def clamp(x: torch.Tensor):
    """Keep only the strictly positive entries of a sparse tensor, capped at 1.

    Args:
        - x: sparse COO tensor; it is coalesced first when it is not already

    Returns:
        A coalesced sparse COO tensor with the shape of ``x`` holding only the
        entries whose value is > 0, each clamped to the range [0, 1].
    """
    coalesced = x if x.is_coalesced() else x.coalesce()

    # positions, within the stored values, of the strictly positive entries
    positive = torch.nonzero(coalesced._values() > 0).reshape(-1)
    kept_indices = coalesced._indices().index_select(1, positive)
    kept_values = coalesced._values().index_select(0, positive).clamp(0, 1)

    result = torch.sparse_coo_tensor(kept_indices, kept_values, x.shape)
    return result.coalesce()

In [None]:
def sparse_hop_neighbors(k:int, edge_index: torch.Tensor, num_nodes:int):
    """Lazily yield, hop after hop, the node pairs reached at each depth.

    Args:
        - k: number of hops to generate; the first hop is always yielded
        - edge_index [2, num_edges]: edges of the graph
        - num_nodes: number of nodes (side of the square adjacency matrix)

    Yields:
        ``(neighbors, cum_neighbors)`` index tensors: the pairs joined by a
        walk of the current length but by no shorter one, and every pair
        reached at any hop so far. For the first hop both are ``edge_index``.
    """
    # hop 1: the edges are both the new pairs and the cumulative pairs
    yield edge_index, edge_index

    if k <= 1:
        return

    # torch.ones defaults to the CPU; create the values next to the indices
    # so the adjacency (and every later hop) lives on edge_index's device.
    edge_weights = torch.ones(edge_index.size(1), device=edge_index.device)
    sparse_edge_index = torch.sparse_coo_tensor(
        edge_index, edge_weights, (num_nodes, num_nodes))
    cum_neighbors = pow_A = sparse_edge_index.clone()

    for _ in range(1, k):
        # binary reachability by walks that are one edge longer
        pow_A = clamp(torch.sparse.mm(sparse_edge_index, pow_A))
        # clamp drops the non-positive entries, i.e. pairs already seen
        neighbors = clamp(pow_A - cum_neighbors)
        cum_neighbors = (cum_neighbors + neighbors).coalesce()

        yield neighbors.indices(), cum_neighbors.indices()

In [None]:
def get_neighbors(edge_index:torch.Tensor, i:int):
    """Return the sorted, de-duplicated nodes that share an edge with node ``i``.

    Both directions are considered: the targets of edges leaving ``i`` and
    the sources of edges entering ``i``.

    Args:
        - edge_index [2, num_edges]: row 0 holds sources, row 1 holds targets
        - i: index of the node whose neighbors are requested

    Returns:
        1-D tensor of unique neighbor indices in ascending order.
    """
    sources, targets = edge_index[0], edge_index[1]

    outgoing = targets[sources == i]
    # For edges entering i the neighbor is the *source*; reading the target
    # row here would only give back i itself.
    incoming = sources[targets == i]

    return torch.unique(torch.cat((outgoing, incoming), 0), sorted=True)

In [None]:
def scale(X:torch.Tensor):
    """Standardize every column of ``X`` to zero mean and unit standard deviation.

    Columns whose standard deviation is exactly 0 are only centered: their
    divisor is replaced by 1 to avoid a division by zero.

    Args:
        - X [num_rows, num_features]: values to standardize along dimension 0

    Returns:
        Tensor with the shape of ``X``.
    """
    m = X.mean(0)
    s = X.std(0)
    # ones_like follows the dtype and device of X instead of the global `device`
    s = torch.where(s == 0, torch.ones_like(s), s)
    return (X - m) / s

In [None]:
def centroids(X:torch.Tensor, y:torch.Tensor):
    """Average the rows of ``X`` label by label.

    Args:
        - X [num_nodes, num_features]: one feature row per node
        - y [num_nodes]: integer label of every node

    Returns:
        Tensor [num_classes, num_features] whose row ``c`` is the mean of the
        rows of ``X`` labelled ``c``, with ``num_classes = y.max() + 1``.
    """
    num_classes = y.max().item() + 1

    # bucket the feature rows by their label
    rows_by_label = {}
    for node, row in enumerate(X):
        label = y[node].item()
        rows_by_label.setdefault(label, []).append(row)

    means = []
    for label in range(num_classes):
        members = rows_by_label[label]
        means.append(sum(members) / len(members))
    return torch.stack(means, 0)

In [None]:
def corr(x, y, i):
    """Cosine similarity between every row of ``x`` and the vector ``y``.

    Args:
        - x [n, num_features]: one row per compared vector
        - y [num_features]: reference vector
        - i: unused; kept so existing callers passing it keep working

    Returns:
        Tensor [n] of similarities, or a scalar 0 tensor on the device of
        ``x`` when ``x`` has no rows. A zero-norm row or a zero ``y`` makes
        the divisor 0, so the corresponding entries are not finite.
    """
    if x.size(0) == 0:
        return torch.tensor(0., device=x.device)
    cov = torch.einsum('ij, j -> i', x, y)
    # per-row squared norms directly, rather than the diagonal of the full
    # n x n Gram matrix
    norm = (x * x).sum(dim=1).sqrt() * torch.matmul(y, y).sqrt()
    return cov / norm

def graph_correlation(edge_index:torch.Tensor, x:torch.Tensor, y:torch.Tensor, y_mean:torch.Tensor):
    """Returns, for every node, the mean absolute correlation between its
    centered label representation and the features of its neighbors.

    Args:
        - edge_index: edges used to look up the neighbors of each node
        - x [num_nodes, num_features]: node features
        - y [num_nodes, num_features]: label representation associated with each node
        - y_mean: value subtracted from ``y`` before correlating
        :rtype: tensor [num_nodes]: correlation (scalar) for every node
    """
    centered_targets = y.sub(y_mean)

    per_node = []
    for node in range(x.size(0)):
        neighbor_features = x[get_neighbors(edge_index, node)]
        similarities = corr(x=neighbor_features, y=centered_targets[node], i=node)
        per_node.append(similarities.abs().mean())

    return torch.stack(per_node, 0)

In [None]:
def pre_processing(graph):
    """Prepare the tensors consumed by ``graph_correlation``.

    Args:
        - graph: object exposing ``x``, ``y`` and ``to(device)``
          (presumably a torch_geometric ``Data`` graph; confirm with caller)

    Returns:
        ``(x_scaled, y, y_mean)``: the standardized features, the centroid of
        each node's label (one row per node), and the mean of the centroids.
    """
    graph = graph.to(device)

    x_scaled = scale(graph.x)
    # The centroids must come from this graph's scaled features; the name `x`
    # is not defined in this function.
    scaled_centroids = centroids(x_scaled, graph.y)
    # one centroid row per node, picked by the node's label
    y = scaled_centroids[graph.y].to(device)
    y_mean = scaled_centroids.mean(0)

    return x_scaled, y, y_mean

In [None]:
def confidence(values: torch.Tensor):
    """Returns the bounds of the 95% confidence interval of the mean of ``values``.

    The interval is ``mean -/+ 1.96 * std / sqrt(len(values))``, returned as
    a ``(lower, upper)`` tuple.
    """
    q = 1.96
    center = values.mean()
    half_width = q * values.std() / np.sqrt(len(values))

    return center - half_width, center + half_width


def graph_summary(graph, K=10):
    """Tabulate the correlation confidence bounds for every hop up to ``K``.

    Each hop contributes two rows to the returned DataFrame: the lower bounds
    first, then the upper bounds. Only the ``k`` and ``correlation_*``
    columns are filled by this function.

    Args:
        - graph: graph passed to ``pre_processing``; also provides
          ``edge_index`` and ``num_nodes``
        - K: number of hops to evaluate
    """
    graph = graph.to(device)
    x, y, y_mean = pre_processing(graph)

    column_names = ['k',
                    'homophily_neighbors',
                    'homophily_neighborhood',
                    'correlation_neighbors',
                    'correlation_neighborhood',
                    'neighbors_count',
                    'neighborhood_count']
    data = pd.DataFrame({name: [] for name in column_names})

    hops = sparse_hop_neighbors(K, graph.edge_index, graph.num_nodes)
    for k, (neighbors, cum_neighbors) in enumerate(tqdm(hops, total=K), start=1):
        # measure graph properties
        neighbors_bounds = confidence(
            graph_correlation(neighbors, x=x, y=y, y_mean=y_mean))
        neighborhood_bounds = confidence(
            graph_correlation(cum_neighbors, x=x, y=y, y_mean=y_mean))

        # one row for the lower bounds, then one for the upper bounds
        for nb_bound, hood_bound in zip(neighbors_bounds, neighborhood_bounds):
            data.loc[len(data)] = {'k':k,
                                   'correlation_neighbors':nb_bound.item(),
                                   'correlation_neighborhood':hood_bound.item()}

    return data

In [None]:
def plot_summary(data):
    """Draw three side-by-side line plots of the summary table against ``k``.

    The panels show, in order, the homophily, correlation and count columns
    of ``data``, each with one line for "neighbors" and one for
    "neighborhood".
    """
    _, axes = plt.subplots(1, 3, figsize=(32,8))

    # (columns to plot, y-axis label, panel title, extra legend options)
    panels = (
        (['homophily_neighbors', 'homophily_neighborhood'],
         "index", "Homophily", {}),
        (['correlation_neighbors', 'correlation_neighborhood'],
         "correlation", "Correlation", {}),
        (['neighbors_count', 'neighborhood_count'],
         "count", "Neighbors count", {'loc': "upper left"}),
    )

    for axis, (columns, ylabel, title, legend_options) in zip(axes, panels):
        long_form = pd.melt(data[['k'] + columns], ['k'])
        panel = sns.lineplot(ax=axis, x='k', y='value',
                             hue='variable',
                             style='variable',
                             markers=True,
                             data=long_form)
        panel.set(xlabel="depth", ylabel=ylabel, title=title)
        legend = panel.legend(('neighbors', 'neighborhood'),
                              frameon=False, **legend_options)
        legend.set_title(None)

In [None]:
%%capture
# Fetch the project's data.py helper module so it can be imported below.
!wget https://raw.githubusercontent.com/AchrafAsh/gnn-receptive-fields/main/data.py

from data import load_dataset
# Load 'Cora' through the helper, giving it ./data under the current
# working directory as its path argument.
path = osp.join(os.getcwd(), 'data')
cora = load_dataset(path, 'Cora')

In [None]:
# Scaled features, per-node label centroids and centroid mean for cora[0].
x, y, y_mean = pre_processing(cora[0])

In [None]:
# `graph` is never defined in this notebook: use the edges of the same graph
# that was pre-processed above, moved to the device holding x, y and y_mean.
graph_correlation(cora[0].edge_index.to(device), x, y, y_mean)