In [3]:
import torch
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
import random
import copy

dataset = Planetoid(root='../data/Planetoid', name='Cora', transform=NormalizeFeatures())
data = dataset[0]

In [4]:
import csv
freq = {}

with open("results.csv", "r", newline="") as file:
    reader = csv.reader(file)

    # Skip the header row
    next(reader)

    # Iterate over each row in the CSV file
    for row in reader:
        # Extract edge and frequency from the row
        edge = tuple(map(int, row[0].split('-')))  # Assuming edges are integers
        frequency = int(row[1])

        # Update the freq dictionary
        freq[edge] = frequency

In [5]:
def get_top_edges(edge_frequency, top_n):
    sorted_edges = sorted(edge_frequency.items(), key=lambda item: item[1], reverse=True)
    edges = sorted_edges[:top_n]
    return edges


def get_random_edges(edge_frequency, num_edges):
    edges_list = list(edge_frequency.items())
    random_edges = random.sample(edges_list, num_edges)
    return random_edges


def get_bottom_edges(edge_frequency, bottom_n):
    sorted_edges = sorted(edge_frequency.items(), key=lambda item: item[1])
    edges = sorted_edges[:bottom_n]
    return edges


num_edges = int(data.num_edges * 0.05)
print(get_top_edges(freq, num_edges))
print(get_bottom_edges(freq, num_edges))
print(get_random_edges(freq, num_edges))

[((143, 1701), 31), ((1241, 1701), 31), ((306, 1367), 31), ((1042, 1628), 29), ((1362, 1914), 27), ((1441, 1958), 27), ((1986, 1993), 27), ((1703, 2238), 26), ((143, 598), 26), ((525, 2182), 26), ((1040, 1169), 26), ((1169, 1719), 26), ((1169, 1720), 26), ((415, 2182), 26), ((1257, 2248), 26), ((95, 2182), 25), ((111, 1169), 25), ((1169, 1358), 25), ((1169, 1734), 25), ((415, 525), 25), ((485, 1042), 24), ((598, 1636), 24), ((429, 2001), 24), ((2133, 2182), 24), ((525, 1628), 24), ((316, 598), 24), ((310, 1950), 24), ((74, 1042), 23), ((1042, 2054), 23), ((1042, 2055), 23), ((1702, 1703), 23), ((963, 1703), 23), ((1703, 1966), 23), ((1464, 1914), 23), ((687, 1725), 23), ((919, 1358), 23), ((1986, 1990), 23), ((562, 1483), 23), ((1416, 1602), 22), ((1702, 1966), 22), ((175, 1914), 22), ((52, 2182), 22), ((310, 1944), 22), ((371, 1441), 22), ((706, 963), 21), ((1505, 1624), 21), ((1849, 1914), 21), ((95, 456), 21), ((456, 525), 21), ((310, 352), 21), ((310, 1947), 21), ((756, 1692), 21),

In [6]:
def find_edge_index(edge_index, query_edge):
    # Ensure the query edge is a tensor with the shape (2,)
    query_edge = torch.tensor(query_edge, dtype=torch.long)
    
    # Compare all edges to the query_edge, reshaping query_edge for broadcasting
    query_edge = query_edge.view(2, 1)
    matches = (edge_index == query_edge).all(0)  # Check each column (edge) for a match

    # Find indices where match is True
    indices = matches.nonzero(as_tuple=True)[0]
    
    if len(indices) > 0:
        return indices.item()  # Return the first index where the edge is found
    else:
        return -1


query_edge = torch.tensor([1140, 476])
find_edge_index(data.edge_index, query_edge)

  query_edge = torch.tensor(query_edge, dtype=torch.long)


1956

In [7]:
def generate_perturbations(data, edges_to_remove):
    data = copy.deepcopy(data)
    edges = data.edge_index.t()
    
    edges_set = set(map(tuple, edges.numpy()))
    remove_set = set(map(tuple, edges_to_remove.t().numpy()))

    keep_edges = edges_set - remove_set
    keep_edges = torch.tensor(list(keep_edges)).t()
    data.edge_index = keep_edges
    
    return data

In [8]:
def make_undirected(edges):
    res = []
    for e in edges:
        res.append(e[0])
    for e in edges:
        res.append((e[0][1], e[0][0]))
    return torch.tensor(res).t()

In [9]:
top_edges = get_top_edges(freq, num_edges)
bottom_edges = get_bottom_edges(freq, num_edges)
random_edges = get_random_edges(freq, num_edges)

top_edges = make_undirected(top_edges)
bottom_edges = make_undirected(bottom_edges)
random_edges = make_undirected(random_edges)

print(top_edges.shape)
print(data.edge_index.shape)

torch.Size([2, 1054])
torch.Size([2, 10556])


In [10]:
perturbed_graph_top = generate_perturbations(data, top_edges)
perturbed_graph_bottom = generate_perturbations(data, bottom_edges)
perturbed_graph_random = generate_perturbations(data, random_edges)

print(perturbed_graph_top.edge_index.shape)
query_edge = torch.tensor([143, 1701])
assert find_edge_index(perturbed_graph_top.edge_index, query_edge) == -1
assert find_edge_index(perturbed_graph_bottom.edge_index, query_edge) != -1
query_edge = torch.tensor([7, 208])
assert find_edge_index(perturbed_graph_top.edge_index, query_edge) != -1
assert find_edge_index(perturbed_graph_bottom.edge_index, query_edge) == -1

torch.Size([2, 9502])


  query_edge = torch.tensor(query_edge, dtype=torch.long)


In [11]:
from torch_geometric.nn import GATConv
import torch.nn.functional as F


class GAT(torch.nn.Module):
    def __init__(self, hidden_channels, heads):
        super().__init__()
        torch.manual_seed(1234567)
        self.conv1 = GATConv(dataset.num_features, hidden_channels, heads=heads)
        self.conv2 = GATConv(hidden_channels * heads, dataset.num_classes)

    def forward(self, x, edge_index):
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv1(x, edge_index)
        x = F.elu(x)
        x = F.dropout(x, p=0.6, training=self.training)
        x = self.conv2(x, edge_index)
        return x

model = GAT(hidden_channels=8, heads=8)
print(model)

GAT(
  (conv1): GATConv(1433, 8, heads=8)
  (conv2): GATConv(64, 7, heads=1)
)


In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=5e-4)
loss_fn = torch.nn.CrossEntropyLoss()

def train():
      model.train()
      optimizer.zero_grad()
      out = model(data.x, data.edge_index)
      loss = loss_fn(out[data.train_mask], data.y[data.train_mask])
      loss.backward()
      optimizer.step()
      return loss

def test(mask):
      model.eval()
      out = model(data.x, data.edge_index)
      out = F.softmax(out, dim=1)
      pred = out.argmax(dim=1)
      correct = pred[mask] == data.y[mask]
      acc = int(correct.sum()) / int(mask.sum())
      return acc

In [13]:
for epoch in range(1, 101):
    loss = train()
    val_acc = test(data.val_mask)
    test_acc = test(data.test_mask)
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Val: {val_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 001, Loss: 1.9438, Val: 0.3780, Test: 0.4090
Epoch: 002, Loss: 1.9364, Val: 0.5680, Test: 0.5890
Epoch: 003, Loss: 1.9266, Val: 0.6180, Test: 0.6180
Epoch: 004, Loss: 1.9165, Val: 0.6200, Test: 0.6210
Epoch: 005, Loss: 1.9050, Val: 0.6620, Test: 0.6550
Epoch: 006, Loss: 1.8944, Val: 0.7160, Test: 0.7260
Epoch: 007, Loss: 1.8821, Val: 0.7680, Test: 0.7490
Epoch: 008, Loss: 1.8728, Val: 0.7700, Test: 0.7710
Epoch: 009, Loss: 1.8592, Val: 0.7940, Test: 0.7860
Epoch: 010, Loss: 1.8463, Val: 0.7960, Test: 0.7880
Epoch: 011, Loss: 1.8381, Val: 0.7980, Test: 0.7830
Epoch: 012, Loss: 1.8246, Val: 0.7960, Test: 0.7790
Epoch: 013, Loss: 1.8105, Val: 0.7960, Test: 0.7790
Epoch: 014, Loss: 1.7909, Val: 0.7920, Test: 0.7770
Epoch: 015, Loss: 1.7859, Val: 0.7880, Test: 0.7760
Epoch: 016, Loss: 1.7616, Val: 0.7860, Test: 0.7760
Epoch: 017, Loss: 1.7493, Val: 0.7880, Test: 0.7780
Epoch: 018, Loss: 1.7317, Val: 0.7900, Test: 0.7760
Epoch: 019, Loss: 1.7235, Val: 0.7900, Test: 0.7770
Epoch: 020, 

In [14]:
def get_perturb_acc(mask, perturbed_edge_index):
      model.eval()
      out = model(data.x, perturbed_edge_index)
      out = F.softmax(out, dim=1)
      pred = out.argmax(dim=1)
      correct = pred[mask] == data.y[mask]
      acc = int(correct.sum()) / int(mask.sum())
      return acc

print(get_perturb_acc(data.test_mask, data.edge_index))
print(get_perturb_acc(data.test_mask, perturbed_graph_top.edge_index))
print(get_perturb_acc(data.test_mask, perturbed_graph_bottom.edge_index))
print(get_perturb_acc(data.test_mask, perturbed_graph_random.edge_index))

0.81
0.811
0.806
0.803


In [15]:
def get_logit_diff(perturbed_edge_index):
    model.eval()
    out_orig = model(data.x, data.edge_index)
    out_perturb = model(data.x, perturbed_edge_index)
    logit_diff = out_orig - out_perturb
    l2_norm = torch.norm(logit_diff, p=2) / torch.norm(data.x, p=2)
    return l2_norm

print(get_logit_diff(perturbed_graph_top.edge_index))
print(get_logit_diff(perturbed_graph_bottom.edge_index))
print(get_logit_diff(perturbed_graph_random.edge_index))

tensor(1.1771, grad_fn=<DivBackward0>)
tensor(2.1163, grad_fn=<DivBackward0>)
tensor(1.3359, grad_fn=<DivBackward0>)
