In [1]:
from baseline import inference_wrapper
import torch
import torch_geometric.utils as pyg_utils
from torch_geometric.data import Data

  _torch_pytree._register_pytree_node(


In [6]:
def _pagerank_centrality(edge_index, num_nodes, damping=0.85, num_iters=10):
    """Approximate PageRank scores with a fixed number of power iterations."""
    src, dst = edge_index[0], edge_index[1]
    device = edge_index.device
    out_degree = torch.zeros(num_nodes, device=device).index_add_(
        0, src, torch.ones(src.size(0), device=device))
    rank = torch.ones(num_nodes, device=device)
    for _ in range(num_iters):
        # Each source node splits its current score evenly over its out-edges;
        # out_degree[src] is >= 1 for every node that appears as a source.
        contribution = rank[src] / out_degree[src]
        received = torch.zeros(num_nodes, device=device).index_add_(0, dst, contribution)
        rank = (1.0 - damping) + damping * received
    return rank


def _eigenvector_centrality(edge_index, num_nodes, max_iters=100, tol=1e-6):
    """Estimate eigenvector centrality by L2-normalised power iteration."""
    src, dst = edge_index[0], edge_index[1]
    score = torch.full((num_nodes,), 1.0 / max(num_nodes, 1), device=edge_index.device)
    for _ in range(max_iters):
        # Iterate with (I + A^T): keeping the previous score in the sum avoids
        # oscillation on bipartite-like graphs and keeps every entry positive.
        updated = score.clone().index_add_(0, dst, score[src])
        updated = updated / updated.norm()
        converged = (updated - score).abs().sum().item() < num_nodes * tol
        score = updated
        if converged:
            break
    return score


def edge_level_augmentation(graph, pe, pt, centrality_measure='degree'):
    """Randomly drop edges, removing low-centrality edges more often.

    Each edge gets a weight equal to the mean centrality of its two endpoints.
    With ``s = log(weight)``, the removal probability of an edge is
    ``min((s_max - s) / (s_max - s_mean) * pe, pt)``.

    Args:
        graph: Object exposing ``edge_index`` (a ``[2, num_edges]`` index
            tensor, as implied by the indexing below) and ``num_nodes``.
            NOTE: it is modified in place -- ``graph.edge_index`` is
            overwritten with the surviving edges.
        pe: Scale factor applied to the normalised edge score.
        pt: Upper bound (cut-off) for every removal probability; a Python
            number or a tensor.
        centrality_measure: ``'degree'`` or ``'eigenvector'``; any other value
            selects PageRank (kept from the original ``else`` branch).

    Returns:
        The same ``graph`` object with the sampled edges removed.
    """
    edge_index = graph.edge_index

    # Nothing to sample from; max()/mean() would fail on an empty tensor.
    if edge_index.size(1) == 0:
        removed_edge_count = 0
        print(f"Number of removed edges: {removed_edge_count}")
        return graph

    # Calculate node centrality measures
    if centrality_measure == 'degree':
        node_centrality = pyg_utils.degree(edge_index[0], graph.num_nodes)
    elif centrality_measure == 'eigenvector':
        # torch_geometric.utils has no eigenvector centrality routine.
        node_centrality = _eigenvector_centrality(edge_index, graph.num_nodes)
    else:
        # torch_geometric.utils has no PageRank routine either.
        node_centrality = _pagerank_centrality(edge_index, graph.num_nodes)

    # Edge centrality: mean of the two endpoint centralities
    edge_weights = (node_centrality[edge_index[0]] + node_centrality[edge_index[1]]) / 2.0

    # Take the logarithm exactly once; the clamp keeps log() finite if a
    # centrality score underflows to zero.
    log_weights = torch.log(edge_weights.clamp_min(1e-8))

    max_weight = log_weights.max()
    spread = max_weight - log_weights.mean()

    if spread.item() > 0:
        normalized = (max_weight - log_weights) / spread
    else:
        # All edges are equally central: every edge sits at the maximum, so
        # the formula's limit is probability 0 (and we avoid a 0/0 NaN).
        normalized = torch.zeros_like(log_weights)

    # Scale by pe and cap at pt, matching dtype/device of the scores
    cap = torch.as_tensor(pt, dtype=normalized.dtype, device=normalized.device)
    edge_probabilities = torch.minimum(normalized * pe, cap)

    # Sample on the same device as the probabilities (plain torch.rand is CPU-only)
    mask = torch.rand_like(edge_probabilities) < edge_probabilities
    removed_edge_count = int(mask.sum())
    graph.edge_index = edge_index[:, ~mask]

    # Print the number of removed edges
    print(f"Number of removed edges: {removed_edge_count}")

    return graph