In [1]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import random
from collections import defaultdict
import copy
import ast
import csv
import numpy as np

# Build the Planetoid 'Cora' dataset rooted at data/Planetoid with the
# NormalizeFeatures transform attached, then take the graph at index 0.
dataset = Planetoid(
    root='data/Planetoid',
    name='Cora',
    transform=NormalizeFeatures(),
)
data = dataset[0]

In [2]:
# Maps the first CSV column (kept as a string key) to the Python literal
# parsed from the second column of the same row.
explanations = {}

# newline='' is what the csv module asks for, so that line endings inside
# quoted fields are interpreted by the reader rather than by the file object.
with open("explanations.csv", mode='r', newline='') as infile:
    reader = csv.reader(infile)
    # Discard the first row (presumably a header); the default keeps a
    # completely empty file from raising StopIteration.
    next(reader, None)
    for row in reader:
        if not row:  # csv.reader yields [] for blank lines
            continue
        index = row[0]
        tensor_str = row[1]
        # literal_eval only accepts Python literals, so nothing in the file is executed.
        tensor_list = ast.literal_eval(tensor_str)
        explanations[index] = tensor_list

In [6]:
def get_node_frequencies(data, explanations, thresholds=(0.5, 0.1, 0.01)):
    """Count, per threshold, in how many explanations each node appears.

    For every explanation the edges whose weight is strictly greater than a
    threshold are selected from ``data.edge_index``; every node that is an
    endpoint of at least one selected edge is counted once for that
    explanation.

    Args:
        data: object exposing ``edge_index`` as a ``(2, num_edges)`` tensor.
        explanations: mapping whose values are per-edge weights (one weight per
            column of ``data.edge_index``). The keys are not used.
        thresholds: weight cut-offs to evaluate. The default reproduces the
            original "big", "mid" and "small" thresholds, in that order.

    Returns:
        A tuple with one ``defaultdict(int)`` per threshold (same order as
        ``thresholds``) mapping node id -> number of explanations in which
        the node touches a selected edge.
    """
    thresholds = tuple(thresholds)  # allow any iterable, including generators
    edge_index = data.edge_index  # loop-invariant, read once
    frequencies = tuple(defaultdict(int) for _ in thresholds)

    for edge_weight in explanations.values():
        edge_weight = torch.as_tensor(edge_weight)

        for threshold, counts in zip(thresholds, frequencies):
            significant_edge_index = edge_index[:, edge_weight > threshold]
            # np.unique flattens both rows, so sources and targets are
            # deduplicated together and each node counts once per explanation.
            for node in np.unique(significant_edge_index.numpy()):
                counts[node] += 1

    return frequencies


# One frequency table per weight threshold, in the order the function returns them.
(
    node_freq_big,
    node_freq_mid,
    node_freq_small,
) = get_node_frequencies(data, explanations)

In [8]:
def divide_into_chunks(freq, n):
    """Rank ``freq`` by descending value and split it into ``n`` contiguous chunks.

    Chunk sizes differ by at most one: when ``len(freq)`` is not a multiple
    of ``n`` the first ``len(freq) % n`` chunks receive one extra item. Items
    with equal values keep the order in which ``freq`` yields them.

    Args:
        freq: mapping of key -> sortable value.
        n: number of chunks to produce; must be at least 1.

    Returns:
        A list of ``n`` lists of ``(key, value)`` tuples. Chunks may be empty
        when ``n`` exceeds the number of items.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """
    if n < 1:
        raise ValueError(f"n must be a positive integer, got {n!r}")

    ranked = sorted(freq.items(), key=lambda item: item[1], reverse=True)
    base_size, remainder = divmod(len(ranked), n)

    chunks = []
    start = 0
    for chunk_idx in range(n):
        # The first `remainder` chunks absorb the leftover items, one each.
        stop = start + base_size + (1 if chunk_idx < remainder else 0)
        chunks.append(ranked[start:stop])
        start = stop
    return chunks


# Rank every frequency table and cut it into the same number of bins.
num_bins = 10
bins_big, bins_mid, bins_small = (
    divide_into_chunks(node_freq, num_bins)
    for node_freq in (node_freq_big, node_freq_mid, node_freq_small)
)

In [93]:
from torch_geometric.nn import GATConv
import torch.nn.functional as F


class GAT(torch.nn.Module):
    """Two-layer graph attention network built from ``GATConv`` layers.

    The first layer uses ``heads`` attention heads of ``hidden_channels``
    features each, so it emits ``hidden_channels * heads`` features; the
    second layer maps those to ``out_channels`` with GATConv's default head
    count. Dropout is applied to the input of both layers.

    Args:
        hidden_channels: feature size per attention head in the first layer.
        heads: number of attention heads in the first layer.
        in_channels: input feature size. Defaults to ``dataset.num_features``
            of the notebook-level ``dataset``.
        out_channels: number of output logits per node. Defaults to
            ``dataset.num_classes`` of the notebook-level ``dataset``.
        dropout: dropout probability used in ``forward`` (default 0.6).
    """

    def __init__(self, hidden_channels, heads, in_channels=None,
                 out_channels=None, dropout=0.6):
        super().__init__()
        # Seeding here (a global RNG side effect) keeps the layer
        # initialisation below identical on every construction.
        torch.manual_seed(1234567)
        # Fall back to the notebook's dataset only when sizes are not given,
        # so the model can also be built for other graphs.
        if in_channels is None:
            in_channels = dataset.num_features
        if out_channels is None:
            out_channels = dataset.num_classes
        self.dropout = dropout
        self.conv1 = GATConv(in_channels, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, out_channels)

    def forward(self, x, edge_index):
        """Return per-node logits for features ``x`` on graph ``edge_index``."""
        # F.dropout is a no-op unless the module is in training mode.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return x

# Instantiate the network (8 hidden channels per head, 8 heads) and show its layers.
model = GAT(
    hidden_channels=8,
    heads=8,
)
print(model)

GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=1)
)


In [94]:
# Cross-entropy on the logits, optimised by Adam with weight decay.
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.005,
    weight_decay=5e-4,
)

def train():
    """Perform one optimisation step on the training nodes.

    Runs the notebook-level ``model`` on the whole graph in ``data``, scores
    only the rows selected by ``data.train_mask`` with ``loss_fn``,
    back-propagates and steps ``optimizer``.

    Returns:
        The loss tensor computed for this step.
    """
    model.train()
    optimizer.zero_grad()

    train_mask = data.train_mask
    logits = model(data.x, data.edge_index)
    loss = loss_fn(logits[train_mask], data.y[train_mask])

    loss.backward()
    optimizer.step()
    return loss

def test(mask):
    """Return the accuracy of the notebook-level ``model`` on the nodes in ``mask``.

    Args:
        mask: boolean node mask selecting the nodes to score
            (e.g. ``data.val_mask`` or ``data.test_mask``).

    Returns:
        Fraction of selected nodes whose predicted class equals ``data.y``.

    Raises:
        ZeroDivisionError: if ``mask`` selects no node.
    """
    model.eval()
    # Evaluation never back-propagates, so skip building the autograd graph.
    with torch.no_grad():
        logits = model(data.x, data.edge_index)
    # Softmax is monotonic within a row, so the arg-max of the raw logits is
    # already the predicted class; no normalisation is needed.
    pred = logits.argmax(dim=1)
    correct = pred[mask] == data.y[mask]
    return int(correct.sum()) / int(mask.sum())

In [95]:
# 100 epochs: one optimisation step, then score the validation and test splits.
for epoch in range(1, 101):
    loss = train()
    val_acc, test_acc = (
        test(split_mask) for split_mask in (data.val_mask, data.test_mask)
    )
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 001, Loss: 1.9438, Val: 0.3780, Test: 0.4090
Epoch: 002, Loss: 1.9364, Val: 0.5680, Test: 0.5890
Epoch: 003, Loss: 1.9266, Val: 0.6180, Test: 0.6180
Epoch: 004, Loss: 1.9165, Val: 0.6200, Test: 0.6210
Epoch: 005, Loss: 1.9050, Val: 0.6620, Test: 0.6550
Epoch: 006, Loss: 1.8944, Val: 0.7160, Test: 0.7260
Epoch: 007, Loss: 1.8821, Val: 0.7680, Test: 0.7490
Epoch: 008, Loss: 1.8728, Val: 0.7700, Test: 0.7710
Epoch: 009, Loss: 1.8592, Val: 0.7940, Test: 0.7860
Epoch: 010, Loss: 1.8463, Val: 0.7960, Test: 0.7880
Epoch: 011, Loss: 1.8381, Val: 0.7980, Test: 0.7830
Epoch: 012, Loss: 1.8246, Val: 0.7960, Test: 0.7790
Epoch: 013, Loss: 1.8105, Val: 0.7960, Test: 0.7790
Epoch: 014, Loss: 1.7909, Val: 0.7920, Test: 0.7770
Epoch: 015, Loss: 1.7859, Val: 0.7880, Test: 0.7760
Epoch: 016, Loss: 1.7616, Val: 0.7860, Test: 0.7760
Epoch: 017, Loss: 1.7493, Val: 0.7880, Test: 0.7780
Epoch: 018, Loss: 1.7317, Val: 0.7900, Test: 0.7760
Epoch: 019, Loss: 1.7235, Val: 0.7900, Test: 0.7770
Epoch: 020, 

In [96]:
print(data.x.shape)
# node_freq_big is a dict keyed by node id, and iterating a dict yields only
# its keys, so min() over it is the smallest node id. (Unpacking each key as a
# (key, value) pair raises TypeError.) default=None covers an empty table.
print(min(node_freq_big, default=None))

torch.Size([2708, 1433])
0


In [105]:
import torch_geometric.utils as pyg_utils


def generate_perturbations(data, nodes_to_remove, verbose=True):
    """Return a deep copy of ``data`` with ``nodes_to_remove`` deleted.

    The copy keeps only the rows of ``x`` that belong to retained nodes and
    only the edges whose two endpoints are both retained; the surviving nodes
    are renumbered consecutively by ``pyg_utils.subgraph(..., relabel_nodes=True)``.
    ``y`` and the ``train_mask`` / ``val_mask`` / ``test_mask`` attributes,
    when present with one entry per node, are filtered with the same mask so
    they stay aligned with ``x``.

    Args:
        data: graph object exposing ``x`` and ``edge_index``. It is not modified.
        nodes_to_remove: iterable of node indices (Python ints, NumPy integers
            or 0-dim tensors). Duplicates are ignored.
        verbose: when True (the default) print the number of removed nodes,
            total nodes, retained nodes and the edge_index shape before and
            after filtering.

    Returns:
        The perturbed deep copy of ``data``.
    """
    data = copy.deepcopy(data)
    # Take the node count from the graph itself instead of a fixed constant.
    num_nodes = data.x.size(0)

    # int() makes equal indices compare equal whatever their original type,
    # so the set really deduplicates (tensor elements hash by identity).
    removed = sorted({int(node) for node in nodes_to_remove})

    # True for nodes that stay, False for nodes that are removed.
    keep_mask = torch.ones(num_nodes, dtype=torch.bool)
    keep_mask[torch.tensor(removed, dtype=torch.long)] = False

    if verbose:
        print(len(removed))
        print(num_nodes)
        print(int(keep_mask.sum()))
        print(data.edge_index.shape)

    # subgraph accepts the boolean mask directly as the node subset.
    data.edge_index = pyg_utils.subgraph(keep_mask, data.edge_index, relabel_nodes=True)[0]
    if verbose:
        print(data.edge_index.shape)

    data.x = data.x[keep_mask]
    # Keep labels and split masks row-aligned with the filtered features.
    for key in ('y', 'train_mask', 'val_mask', 'test_mask'):
        value = getattr(data, key, None)
        if torch.is_tensor(value) and value.dim() > 0 and value.size(0) == num_nodes:
            setattr(data, key, value[keep_mask])

    return data

def get_logit_diff(data, perturbed_data, model):
    """Compare input-normalised output magnitudes of ``model`` on two graphs.

    For each graph the L2 norm of the model output is divided by the L2 norm
    of that graph's ``x``; the value for ``perturbed_data`` is subtracted
    from the value for ``data``.

    Args:
        data: reference graph exposing ``x`` and ``edge_index``.
        perturbed_data: modified graph exposing ``x`` and ``edge_index``.
        model: callable module invoked as ``model(x, edge_index)``; it is
            switched to eval mode before use.

    Returns:
        A scalar tensor: reference ratio minus perturbed ratio.
    """
    model.eval()

    def _output_to_input_ratio(output, graph):
        # Output magnitude relative to the magnitude of the graph's features.
        return torch.norm(output, p=2) / torch.norm(graph.x, p=2)

    with torch.no_grad():
        logits_reference = model(data.x, data.edge_index)
        logits_perturbed = model(perturbed_data.x, perturbed_data.edge_index)

        ratio_reference = _output_to_input_ratio(logits_reference, data)
        ratio_perturbed = _output_to_input_ratio(logits_perturbed, perturbed_data)
        return ratio_reference - ratio_perturbed
        

In [106]:
num_samples = 10
# results[b, s] holds the logit-difference score of sample s drawn from bin b.
results = torch.zeros(num_bins, num_samples)

# Seed the sampler so Restart & Run All reproduces the same perturbations.
random.seed(0)

# `node_bin` avoids shadowing the builtin `bin`.
for bin_idx, node_bin in enumerate(bins_small):
    # Each perturbation removes a random half of the bin's nodes.
    half_size = len(node_bin) // 2
    for sample_idx in range(num_samples):
        sampled_nodes = random.sample(node_bin, half_size)
        # Bin entries are (node id, frequency) pairs; only the ids are needed.
        nodes_to_remove = [node for node, _frequency in sampled_nodes]
        perturbed_data = generate_perturbations(data, nodes_to_remove)
        results[bin_idx, sample_idx] = get_logit_diff(data, perturbed_data, model)

135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8748])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8956])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8842])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8750])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8810])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8860])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8698])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8686])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8338])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8740])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 9070])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 9078])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 9048])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 8988])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 9096])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 9132])
135
2708
2573
torch.Size([2, 10556])
torch.Size([2, 9216