In [87]:
import torch
import torch_geometric as tg
import networkx as nx
import numpy as np
import pandas as pd

from graphgps.loader.dataset.voc_superpixels import VOCSuperpixels
from functools import cached_property

In [None]:
# Load the test split of the VOC superpixels dataset from the local
# datasets directory (path is relative to this notebook's location).
dataset = VOCSuperpixels(
    root='../../../datasets/VOCSuperpixels',
    name='edge_wt_only_coord',
    slic_compactness=30,
    split='test',
)

In [31]:
# Work on a copy of the first graph so the dataset entry itself is untouched.
data = dataset[0].clone()
graph = tg.utils.convert.to_networkx(data)

# Sentinel stored for node pairs with no connecting path. Floyd-Warshall
# reports those as +inf, and casting inf to an integer dtype is an invalid
# cast (it yields a platform-dependent garbage value), so they are mapped to
# an explicit, recognisable value before the cast.
UNREACHABLE = -1

# All-pairs shortest-path lengths. Only need to do this once, cache for later.
pairwise_distances = nx.algorithms.shortest_paths.dense.floyd_warshall_numpy(graph)
all_shortest_paths = np.where(
    np.isfinite(pairwise_distances), pairwise_distances, UNREACHABLE
).astype(int)

# Identify a target node
target_node_idx = 0

# Shortest path length from every node to the target node:
# all_shortest_paths[i, j] is the shortest path from i to j, so column
# `target_node_idx` holds the distance from each node i to the target.
# Nodes that cannot reach the target carry UNREACHABLE (-1).
shortest_paths_to_target = all_shortest_paths[:, target_node_idx]

In [38]:
# One row per node: the node's position in `shortest_paths_to_target`
# (as `node_id`) next to its shortest-path length to the target node.
shortest_paths_df = (
    pd.DataFrame(shortest_paths_to_target, columns=['path_length'])
    .rename_axis('node_id')
    .reset_index()
)

In [42]:
# Bucket the nodes by their distance to the target: a mapping
# path_length -> labels of the rows having that path length.
# NOTE: `.groups` yields the frame's index labels rather than the values of
# the `node_id` column; the two coincide here because `shortest_paths_df`
# carries the default 0..n-1 index produced by `reset_index()`.
path_length_buckets = (
    shortest_paths_df
    .groupby('path_length')['node_id']
    .groups
)

In [45]:
# Now that we have this map, we can use it to create a series of masks for the node values


In [78]:
### Candidate replacement values for the masked nodes - three options

# Option 1: standard deviation of this graph's inputs, reduced over dim 0
# (presumably the node axis - TODO confirm), to fudge values by.
standard_deviation_of_input = data.x.std(dim=0)

# Option 2: mean of this graph's inputs, reduced over dim 0.
mean_of_input = data.x.mean(dim=0)

# Option 3: mean over the entire dataset, computed as the unweighted average
# of every graph's own mean (each graph counts equally, whatever its size).
per_graph_means = [d.x.mean(dim=0) for d in dataset]
mean_of_means = torch.row_stack(per_graph_means).mean(0)



In [91]:
# Pick one path length and overwrite the inputs of every node at exactly that
# distance from the target with a fudged value.

path_length = 5

# Mask a copy so the original `data` stays intact for other path lengths.
new_data = data.clone()

# Rows of `x` to overwrite: the nodes bucketed under the chosen path length.
masked_node_ids = torch.as_tensor(
    list(path_length_buckets[path_length]), dtype=torch.long
)
# The replacement vector is broadcast across all selected rows.
new_data.x[masked_node_ids, :] = mean_of_means # can swap this bit out based on what we want to do