In [50]:
%load_ext autoreload
%autoreload 2

import torch
from src.models.gps import GPS
from src.models.utils.hooks import GPSHook
from src.models.explainer.explainer_pipeline import ExplainerPipeline
from src.data import loader
from src.models.model import train, test

from src.models.gps import GPS
from src.models.gcn import GCN
from src.models.explainer.explainer_pipeline import ExplainerPipeline
from src.models.explainer.gnn_explainer import GNNExplainer
from src.models.explainer.attention_explainer import AttentionExplainer
from src.data import loader

import networkx as nx

from torch_geometric.explain import ModelConfig

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [51]:
# Load the cleaned BA-Shapes graph. Judging by the names, this returns the
# graph data object, the number of classes and a networkx version of the graph.
# NOTE(review): laplacian_eigenvector_dimensions presumably has to match
# gps_params['pe_channels'] in the model-config cell (both are 2) — confirm.
data, num_classes, data_networkx = loader.load_clean_bashapes(
    num_nodes=25,
    num_edges=5,
    num_motifs=10,
    laplacian_eigenvector_dimensions=2,
)

In [69]:
# Hyperparameters passed to the GPS model via `model_params` below.
gps_params = {
    'pe_channels': 2,  # presumably must equal laplacian_eigenvector_dimensions used when loading the data
    'num_layers': 4,
    'hidden_channels': 4,
    'num_attention_heads': 1,
    'observe_attention': True,
}

# Hyperparameters for the GCN model.
# NOTE(review): not referenced by the pipeline constructed in this cell.
gcn_params = {
    "hidden_channels": 20,
    "num_layers": 3
}

# Options forwarded to the explainer (torch_geometric.explain-style settings,
# judging by ModelConfig): explain the model for node-level multiclass output.
explainer_params = {
    'explanation_type': 'model',
    'node_mask_type': 'attributes',
    'edge_mask_type': 'object',
    'model_config': ModelConfig(
        mode='multiclass_classification',
        task_level='node',
        return_type='raw',
    )
}

SEED = 0
# NOTE(review): very short training run; the recorded accuracies (~0.27)
# suggest the model may be under-trained.
EPOCHS = 2

# Seed torch's RNG so model initialisation/training inside the pipeline is
# repeatable when this cell is re-run. The pipeline's internals are not
# visible here — if it also draws from numpy/`random`, seed those too.
torch.manual_seed(SEED)
explainer_pipeline = ExplainerPipeline(
    data,
    num_classes,
    GPS,
    explainer=AttentionExplainer,
    Hook=GPSHook,
    model_params=gps_params,
    explainer_params=explainer_params,
    epochs=EPOCHS,
)
explainer_pipeline.get_accuracies()

100%|██████████| 2/2 [00:00<00:00, 66.00it/s]

Train accuracy: 0.26666666666666666
Test accuracy: 0.26666666666666666





In [71]:
# Explain node 26, passing the Laplacian positional encodings through to the pipeline.
# FIXME(review): the recorded run of this cell ended with
#   AttributeError: 'dict' object has no attribute 'dim'
# The failing code is not visible here — check which value reaching
# `explain` / `explanations[26]` is a dict where a tensor is expected.
explainer_pipeline.explain(
    26,
    laplacian_eigenvector_pe=data.laplacian_eigenvector_pe,
)
explainer_pipeline.explanations[26]

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,  1,  1,  1,
          1,  1,  1,  1,  1,  1,  2,  3,  3,  3,  3,  4,  4,  4,  4,  4,  4,  4,
          4,  4,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  5,  6,  6,  6,  6,  6,
          6,  6,  6,  6,  7,  7,  7,  7,  7,  7,  7,  7,  7,  8,  8,  8,  8,  9,
          9,  9,  9,  9,  9,  9,  9,  9,  9,  9, 10, 10, 10, 10, 10, 10, 10, 10,
         10, 10, 11, 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, 13,
         13, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, 16, 16, 16, 16,
         17, 17, 17, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, 19, 19, 20, 20, 20,
         20, 20, 21, 21, 21, 21, 21, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 24,
         24, 24, 24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 28, 28, 29, 29,  6, 26,
         30, 30, 30, 31, 31, 31, 32, 32, 33, 33, 34, 34,  2, 33, 35, 35, 35, 36,
         36, 36, 37, 37, 38, 38, 39, 39,  1, 36, 40, 40, 40, 41, 41, 41, 42, 42,
         43, 43, 44, 44,  3,

AttributeError: 'dict' object has no attribute 'dim'

In [None]:
def max_reduce(matrix):
    """Return the maximum of each row of ``matrix``.

    The reduction runs over dim=1 (across the columns), so for a 2-D input of
    shape (N, M) the result has N entries. Only the values are returned; the
    argmax indices produced by ``Tensor.max`` are discarded.
    """
    return matrix.max(dim=1).values

# Finds the weighted average for a single attention matrix.

# Summing across dim=0 shows how much attention each node receives; summing across dim=1 shows how much attention each node gives. (After a row-wise softmax every row sums to 1, so a plain dim=1 sum is constant.)

def weighted_average_received(attention_matrix):
    """Score how much attention each node receives in one attention matrix.

    Each row is normalised with a softmax over dim=1, then the normalised
    matrix is summed over dim=0 (down each column). Entry j of the result is
    therefore the total normalised attention that all rows assign to column j.
    Despite the name this is a sum, not a mean: the entries add up to the
    number of rows.

    Args:
        attention_matrix: 2-D floating-point tensor. Rows are assumed to be
            the attending nodes and columns the attended-to nodes (TODO
            confirm against the hook that records the attention).

    Returns:
        1-D tensor with one entry per column of ``attention_matrix``.
    """
    # torch.softmax is equivalent to F.softmax; F (torch.nn.functional) is not
    # imported in this notebook's import cell, so F.softmax raised NameError.
    # NOTE(review): if the hook already stores softmax-normalised weights,
    # this second softmax flattens them — confirm the hook's output.
    row_normalised = torch.softmax(attention_matrix, dim=1)

    # Sum down the columns: total attention each node receives.
    return row_normalised.sum(dim=0)

def weighted_average_given(attention_matrix):
    """Computes the weighted average of an attention matrix to show how much attention each node is giving to others.

    Each row is softmax-normalised over dim=1 and then used as weights over
    the column indices 0..M-1. Entry i of the result is the attention-weighted
    (expected) column index that row i attends to.

    NOTE(review): the result is a position, not an amount of attention, so it
    depends on how nodes are numbered — confirm this is the intended metric.

    Args:
        attention_matrix: 2-D floating-point tensor of shape (N, M), on any
            device and in any floating dtype.

    Returns:
        1-D tensor of N values in [0, M-1], with the dtype/device of the
        softmax output.
    """
    # Step 1: softmax across each row (dim=1) so every row sums to 1.
    # torch.softmax is used because F is not imported in this notebook.
    softmax_attention = torch.softmax(attention_matrix, dim=1)

    # Step 2: weight each column index by its attention value.
    # The index vector must share dtype and device with the weights:
    # torch.matmul does not promote, so a fixed float32 CPU vector failed for
    # float64 or CUDA inputs.
    column_index = torch.arange(
        attention_matrix.size(1),
        dtype=softmax_attention.dtype,
        device=softmax_attention.device,
    )
    return torch.matmul(softmax_attention, column_index)

def weighted_average_all_layers(function, matrices):
    """Apply a per-matrix reduction to every layer's matrix and average the results.

    Args:
        function: callable mapping one attention matrix to a tensor, e.g.
            ``weighted_average_received`` or ``weighted_average_given``.
        matrices: iterable of attention matrices (e.g. one per layer). Every
            reduced result must have the same shape so they can be stacked.

    Returns:
        A tuple ``(per_layer, mean)``. ``per_layer`` is the list of reduced
        tensors in input order; ``mean`` is their element-wise mean. torch.stack
        raises if ``matrices`` is empty.
    """
    per_layer = [function(layer_matrix) for layer_matrix in matrices]
    return per_layer, torch.stack(per_layer).mean(dim=0)

def top_k_nodes(matrix, top_k=0.1):
    """Return the indices of the highest-scoring nodes, best first.

    Args:
        matrix: per-node scores, presumably a 1-D tensor. The ranking runs
            along the last dim while the fraction below is taken of len()
            (the first dim), so a 2-D input would not rank "nodes".
        top_k: a float is treated as a fraction of the nodes (truncated with
            int(), so it can yield 0); anything else is used directly as a count.

    Returns:
        Tensor of node indices sorted by descending score.
    """
    ranking = matrix.argsort(descending=True)
    count = int(len(ranking) * top_k) if isinstance(top_k, float) else top_k
    return ranking[:count]