In [17]:
import torch
import torch_geometric as tg
import networkx as nx
import numpy as np
import pandas as pd

from graphgps.loader.dataset.voc_superpixels import VOCSuperpixels
from functools import cached_property

In [18]:
# Load the VOCSuperpixels test split: SLIC compactness 30, graph variant
# 'edge_wt_only_coord'. NOTE: `root` is relative to the notebook's working directory.
dataset = VOCSuperpixels(
    split='test',
    name='edge_wt_only_coord',
    slic_compactness=30,
    root='../../../datasets/VOCSuperpixels',
)

In [19]:
# Work on a copy of the first test graph and convert it to NetworkX.
# The resulting distances are directed (see the asymmetric matrix displayed below),
# so "distance from i to j" and "from j to i" can differ.
data = dataset[0].clone()
graph = tg.utils.convert.to_networkx(data)

# All-pairs hop distances; only needs computing once per graph, cache for later.
# `nodelist` is given explicitly so that row/column i is node i.
# Floyd-Warshall reports unreachable pairs as `inf`; casting `inf` straight to
# int yields an arbitrary huge negative number, so map it to a sentinel first.
UNREACHABLE = -1
hop_distances = nx.floyd_warshall_numpy(graph, nodelist=list(range(data.num_nodes)))
all_shortest_paths = np.where(np.isinf(hop_distances), UNREACHABLE, hop_distances).astype(int)

# Node whose receptive field we want to probe.
target_node_idx = 0

# all_shortest_paths[i, j] is the hop distance from i to j, so this column holds
# the distance from every node i to the target (UNREACHABLE if no path exists).
shortest_paths_to_target = all_shortest_paths[:, target_node_idx]

In [20]:
# One row per node: its id (position in the distance vector) and its hop
# distance to the target node.
shortest_paths_df = pd.DataFrame({
    'node_id': range(len(shortest_paths_to_target)),
    'path_length': shortest_paths_to_target,
})

In [21]:
# Map path_length -> list of node ids at exactly that hop distance from the target.
# The `node_id` column values are aggregated explicitly: `GroupBy.groups` would
# return the frame's *index labels*, which only match node ids while the index is
# an untouched 0..n-1 RangeIndex.
path_length_buckets = shortest_paths_df.groupby('path_length')['node_id'].agg(list).to_dict()

In [22]:
# Display the path_length -> node-id buckets for inspection.
path_length_buckets

{0: [0], 1: [2, 25, 29, 41], 2: [1, 3, 19, 24, 28, 34, 45, 47, 53, 61, 65, 84], 3: [4, 16, 37, 42, 57, 66, 67, 80, 81, 96, 97], 4: [5, 6, 22, 30, 48, 59, 68, 76, 77, 85, 98, 106, 119, 120, 127], 5: [7, 8, 20, 35, 38, 54, 58, 75, 82, 99, 100, 101, 115, 117, 134, 137, 143, 152, 162, 173], 6: [9, 10, 23, 31, 43, 49, 72, 78, 93, 95, 111, 116, 121, 122, 138, 140, 153, 157, 163, 175, 191, 192, 200, 210], 7: [11, 18, 21, 32, 44, 50, 56, 69, 79, 86, 102, 113, 132, 133, 144, 154, 158, 169, 170, 179, 183, 201, 213, 216, 232, 233, 234, 252], 8: [12, 17, 26, 39, 51, 55, 62, 70, 71, 87, 107, 112, 118, 123, 135, 145, 151, 166, 177, 187, 190, 202, 208, 230, 245, 256, 263, 268, 292], 9: [13, 14, 15, 27, 33, 40, 46, 60, 73, 88, 89, 94, 108, 129, 131, 139, 141, 156, 167, 176, 188, 189, 204, 207, 217, 220, 223, 238, 253, 269, 272, 276, 283, 293, 312, 321, 336], 10: [36, 52, 63, 64, 74, 83, 90, 103, 104, 114, 130, 146, 149, 159, 168, 171, 185, 186, 194, 205, 212, 227, 229, 240, 246, 255, 260, 277, 284, 29

In [23]:
# Display the hop-distance matrix (NumPy elides the middle of large arrays).
all_shortest_paths

array([[ 0,  1,  1, ..., 18, 15, 19],
       [ 2,  0,  1, ..., 19, 16, 20],
       [ 1,  1,  0, ..., 19, 16, 19],
       ...,
       [18, 18, 18, ...,  0,  4,  1],
       [16, 16, 16, ...,  4,  0,  5],
       [19, 19, 19, ...,  1,  5,  0]])

In [15]:
### Candidate replacement ("mask") values for occluded node features.
# Each option is a per-feature vector, reduced over nodes (dim 0).

# Option 1: per-feature standard deviation of this graph's node features.
standard_deviation_of_input = data.x.std(dim=0)

# Option 2: per-feature mean of this graph's node features.
mean_of_input = data.x.mean(dim=0)

# Option 3: average of the per-graph feature means across the whole dataset.
# NOTE: this is an unweighted mean of means -- every graph counts equally
# regardless of its node count, so it differs from a node-weighted dataset mean.
mean_of_means = torch.stack(
    [graph_data.x.mean(dim=0) for graph_data in dataset],
    dim=0,
).mean(dim=0)


In [16]:
# Occlude every node at exactly `path_length` hops from the target by
# overwriting its features with a replacement vector.

path_length = 5
mask_value = mean_of_means  # swap for mean_of_input / standard_deviation_of_input as needed

# Node ids at this distance. A distance with no nodes (e.g. beyond the farthest
# node) gives an empty selection instead of raising KeyError.
masked_node_ids = torch.as_tensor(
    np.asarray(list(path_length_buckets.get(path_length, [])), dtype=np.int64)
)

new_data = data.clone()  # work on a copy; `data` is left as loaded
# One vectorized assignment: `mask_value` (one entry per feature) broadcasts over
# the selected rows; an empty index is a no-op.
new_data.x[masked_node_ids] = mask_value