In [None]:
import pickle
import torch
import torch_geometric as tg

import networkx as nx

import pandas as pd
import numpy as np

import matplotlib.pyplot as plt

### Example Script

In [None]:
# Load the per-graph results from the pickle and keep the first entry as the
# worked example used by the cells below.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file --
# only open pickles that come from a trusted source.
with open(r'inf_scores_gcn_with_adj.pkl', 'rb') as f:
    all_res = pickle.load(f)

print(f"{len(all_res)} graphs loaded")

x = all_res[0]

In [None]:
# Normalisation code

# Take a copy of the example's influence scores so that the object stored in
# `x` is not the one being worked on below.
# NOTE(review): presumably a 2-D (node x node) numpy array -- confirm.
inf = x['influence_score'].copy()

def row_normalise(inf):
    """Scale each row of a 2-D score matrix by that row's total.

    Parameters
    ----------
    inf : 2-D array-like supporting ``.sum(axis=1)`` and ``.T``
        Matrix of scores, one row per entity.

    Returns
    -------
    Same shape as ``inf``; every entry is divided by the sum of its row, so
    each row with a non-zero total sums to one. The input is not modified.
    """
    # Work on the transpose so the per-row totals line up with the last axis
    # when dividing, then flip back to the original orientation.
    transposed = inf.T
    scaled = transposed / inf.sum(axis=1)
    return scaled.T

# Sanity check: after normalisation every row must add up to one.
normalised_row_sums = row_normalise(inf).sum(axis=1)
assert np.allclose(normalised_row_sums, 1.0)

In [None]:
# Swap the raw scores for their row-normalised version, then wrap the matrix
# in a DataFrame (default integer row/column labels) for reshaping below.
inf = row_normalise(inf)
df = pd.DataFrame(data=inf)

In [None]:
# TODO is this right? (the rows of the matrix are labelled as the source node)
df.index.name = 'source'

# Reshape the wide matrix to long format: one row per (source, target) pair.
column_names = {
    'index': 'source',
    'variable': 'target',
    'value': 'influence_score',
}
melted = (
    df.melt(ignore_index=False)
      .reset_index()  # turn the source index into a column, leaving a unique row index
      .rename(columns=column_names)
)

In [None]:
# Build two lookup tables from the node positions: one keyed by 'source' and
# one keyed by 'target', so both endpoints of a pair can be joined to their
# coordinates.
node_positions = pd.DataFrame(x['xpos'])

source_coords = node_positions.copy()
source_coords.index.name = 'source'
source_coords = (
    source_coords
    .reset_index()
    .rename(columns={0: 'source_x', 1: 'source_y'})
)

target_coords = node_positions.copy()
target_coords.index.name = 'target'
target_coords = target_coords.reset_index()

# NOTE: the renamed frame is bound to `target_cords` (sic); the merge cell
# reads that name, while `target_coords` keeps the unrenamed 0/1 columns.
target_cords = target_coords.rename(columns={0: 'target_x', 1: 'target_y'})

In [None]:
# Row count before the coordinate merges.
print(melted.shape[0])

In [None]:
# Attach source and target node coordinates to every (source, target) row.
# Left joins keep every existing row of `melted`.
#
# Coordinate columns left over from an earlier run of this cell are dropped
# first: merging a second time would otherwise produce suffixed duplicates
# (`source_x_x`, `source_x_y`, ...) and break the distance cell below.
source_value_cols = [c for c in source_coords.columns if c != 'source']
target_value_cols = [c for c in target_cords.columns if c != 'target']

melted = (
    melted
    .drop(columns=source_value_cols + target_value_cols, errors='ignore')
    .merge(source_coords, on='source', how='left')
    .merge(target_cords, on='target', how='left')
)

In [None]:
# Row count after the coordinate merges (compare with the count printed before them).
print(melted.shape[0])

In [None]:
# Euclidean distance between the source and target positions of each pair.
delta_x = melted['source_x'] - melted['target_x']
delta_y = melted['source_y'] - melted['target_y']

melted['distance_sq'] = delta_x**2 + delta_y**2
melted['distance'] = np.sqrt(melted['distance_sq'])

In [None]:
# Row count after adding the distance columns.
print(melted.shape[0])

In [None]:
# Influence score against Euclidean distance, one semi-transparent point per pair.
melted.plot.scatter(x='distance', y='influence_score', alpha=0.2, logy=False)

### Graph Based Distance Measure

In [None]:
# Build a PyTorch Geometric graph from the example's edge list and positions.
#
# edge_index has to be an integer (torch.long) tensor. torch.Tensor(...)
# always yields float32, which turns node ids into floats and cannot
# represent ids above 2**24 exactly, so the dtype is requested explicitly.
# NOTE(review): PyG expects edge_index with shape (2, num_edges) -- confirm
# that x['edges'] is stored in that orientation.
edges = torch.as_tensor(x['edges'].copy().astype(int), dtype=torch.long)

# Node positions are used as the node feature matrix; float32 matches what
# torch.Tensor(...) produced before.
coords = torch.as_tensor(x['xpos'].copy(), dtype=torch.float)

data = tg.data.Data(edge_index=edges,
                    x=coords)

In [None]:
# Convert the PyG data object to a networkx graph for the shortest-path step.
# NOTE(review): called with defaults, which presumably yields a directed
# graph -- confirm the edge list contains both directions if hop distance
# is meant to be symmetric.
g = tg.utils.to_networkx(data)

In [None]:
# All-pairs hop counts: weight=None makes every edge count as one step.
# Row i / column j of the matrix is the distance from node i to node j.
shortest_paths = pd.DataFrame(nx.floyd_warshall_numpy(g, weight=None))

# Long format: one row per (source, target) pair.
shortest_paths_map = (
    shortest_paths
    .melt(ignore_index=False)
    .reset_index()
    .rename(columns={'index': 'source',
                     'variable': 'target',
                     'value': 'graph_distance'})
)

# Floyd-Warshall reports unreachable pairs as inf, and casting inf to int
# raises (or, on old pandas, silently yields garbage). Such pairs have no hop
# distance, so they are removed before the cast; for a fully connected graph
# nothing is dropped and the result is unchanged.
reachable = np.isfinite(shortest_paths_map['graph_distance'])
n_unreachable = int((~reachable).sum())
if n_unreachable:
    print(f"Dropping {n_unreachable} unreachable (source, target) pairs")

shortest_paths_map = shortest_paths_map[reachable].copy()
shortest_paths_map['graph_distance'] = shortest_paths_map['graph_distance'].astype('int')

In [None]:
# Attach the hop distance to every (source, target) pair.
# A `graph_distance` column left over from an earlier run of this cell is
# dropped first; merging again would otherwise create `graph_distance_x` /
# `graph_distance_y` and the plot below would fail with a KeyError.
melted = (
    melted
    .drop(columns='graph_distance', errors='ignore')
    .merge(shortest_paths_map, on=['source', 'target'])
)
melted.plot.scatter('graph_distance', 'influence_score', logy=False, alpha=0.2,)

In [None]:
# Total influence of each source node at each hop distance ...
per_distance_per_source = (
    melted
    .groupby(['graph_distance', 'source'])['influence_score']
    .sum()
)

# ... then averaged over the source nodes and plotted against hop distance.
mean_by_distance = per_distance_per_source.groupby(level='graph_distance').mean()
mean_by_distance.plot()

In [None]:
# Sum of the per-distance means over all hop distances.
per_distance_per_source.groupby(level='graph_distance').mean().sum()

### Proper pipeline

In [None]:
# Enable IPython's autoreload extension in mode 2, so that edits to imported
# local modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
from influence import process_all_graphs, plot_mean_influence_by_distance

In [None]:
# Run the packaged pipeline on the GCN results with normalisation enabled.
# NOTE(review): process_all_graphs lives in the local `influence` module;
# presumably it repeats the per-graph steps above for every graph -- confirm.
influence_df_gcn = process_all_graphs(
    'inf_scores_gcn_with_adj.pkl',
    normalise=True,
)

In [None]:
# Same pipeline, applied to the results stored in 'inf_scores_GT.pkl'.
influence_df_transformer = process_all_graphs(
    'inf_scores_GT.pkl',
    normalise=True,
)

In [None]:
# Same pipeline, applied to the results stored in 'inf_scores_egnn.pkl'.
influence_df_egnn = process_all_graphs(
    'inf_scores_egnn.pkl',
    normalise=True,
)

In [None]:
# Draw the three models on one shared axis so they can be compared directly.
# NOTE(review): plot_mean_influence_by_distance comes from the local
# `influence` module; presumably it draws one labelled curve of mean influence
# per graph distance onto `ax` -- confirm against that module.
fig, ax = plt.subplots()
plot_mean_influence_by_distance(influence_df_gcn, ax, 'gcn')
plot_mean_influence_by_distance(influence_df_transformer, ax, 'transformer')
plot_mean_influence_by_distance(influence_df_egnn, ax, 'egnn')

In [None]:
# Non-null count per column for the rows whose target is node 1.
is_target_one = influence_df_gcn['target'] == 1
influence_df_gcn.loc[is_target_one].count()