In [1]:
# import libraries
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from utils.gengraph import gen_syn1, gen_syn3, gen_syn4, preprocess_input_graph

import networkx as nx
import matplotlib.pyplot as plt

# Define GCN Classes

In [2]:
class GraphConvolutionLayer(nn.Module):
    """
    A single Graph Convolution Layer
    """
    def __init__(self, dim_in:int, dim_out:int):
        super().__init__()
        
        self.dim_in = dim_in # input number of features
        self.dim_out = dim_out # output size
        # Initialize weights as suggested, using Glorot initialization
        self.W = nn.Parameter(
            nn.init.xavier_uniform_(torch.empty(self.dim_in,
                                                self.dim_out,
                                                dtype=torch.float64)))
    
    def forward(self, H:torch.sparse, A:torch.sparse):# -> torch.Tensor:
        
        # compute degree matrix:
        # 1. Combine sum over rows and sum over columns
        # 2. Remove identity vector from sums to deal with double self-loops
        # 3. Take inverse square root of all elements and create diagonal matrix
        self.D = torch.diag(1/torch.sqrt(
            (-torch.ones(A.shape[0]) + A.sum(axis=0) + A.sum(axis=1))/2))
        
        # Compute Â
        self.A_hat = self.D @ A @ self.D
        
        return (self.A_hat @ H @ self.W) 


class GCN(nn.Module):
    def __init__(self, dim_in, dim_hidden, dim_out):
        super().__init__()

        self.dim_in = dim_in
        self.dim_hidden = dim_hidden
        self.dim_out = dim_out
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=0.3)
        
        self.GCN1 = GraphConvolutionLayer(dim_in, dim_hidden)
        self.GCN2 = GraphConvolutionLayer(dim_hidden, dim_out)

    def encode(self, H:torch.sparse, A:torch.sparse) -> torch.Tensor:
        h1 = self.relu(self.GCN1(H, A))
        h1 = self.dropout(h1)
        return self.GCN2(h1, A)
    
    def decode(self, node_embeds, edge_ids):
        """ Compute edge embeddings based on node similarity
        
        Edge_ids stores all pairs of node ids for which we want link predictions
        
        ::param:: node_embeds <- torch.tensor(n_nodes, size_embeds)
        ::param:: edge_ids <- torch.tensor(n_edges, 2)
        
        Return:
            edge_embeds <- torch.tensor(n_edges, 1)
        """
        edge_embeds = torch.sum(node_embeds[edge_ids.T[0]] * node_embeds[edge_ids.T[1]], dim=-1)
        return edge_embeds

In [3]:
# a, b, c = gen_syn1()
# a, b, c = gen_syn3()
a, b, c = gen_syn4()

data = preprocess_input_graph(a, b)


feat_dict[0]["feat"]: float32
G.nodes[0]["feat"]: float32


In [6]:
def gen_data_split(adj, p:float=0.1):
    """ 
    Returns train sub-graph and edge indices and labels for train and test links

    ::param:: adj <- np.array : adjacency matrix
    ::param:: p <- float : proportion of positive links to remove
                           and of pos/neg links to train on
    
    Returns Tuple(torch.sparse_coo, np.array, np.array,
                  np.array, np.array):
        ...
    """
    # Find all pos edge indices in adj
    pos_edges = np.argwhere(adj == np.argmax(adj)).tolist()
    np.random.shuffle(pos_edges)
    
    # pos edges to remove
    len_split = int(len(pos_edges)*p)
    pos_ids = pos_edges[:len_split]
    pos_test = pos_edges[len_split:]
    len_test = len(pos_test)
    
    # select neg edges
    neg_edges = np.argwhere(adj == np.argmin(adj)).tolist()
    np.random.shuffle(neg_edges)
    neg_ids = neg_edges[:len_split] # same number as pos edges
    neg_test = neg_edges[len_split:len_split + len_test] # same number as pos edges
    
    sub_g = adj.copy()
    for id in pos_ids:
        sub_g[id] = 0.0
    
    return sub_g, np.concatenate((pos_ids, neg_ids)),\
        np.concatenate((np.ones(len_split), np.zeros(len_split))),\
        np.concatenate((pos_test, neg_test)),\
        np.concatenate((np.ones(len_test), np.zeros(len_test)))

In [7]:
sub_g, train_edgeids, train_labels, test_edgeids, test_labels = gen_data_split(data['adj'][0], p = 0.3)

# Train GCN on link prediction task

In [8]:
# Load full graph adjacency matrix as sparse tensor
A = torch.tensor(data['adj'][0])
A = (A + torch.eye(A.shape[0])).to_sparse_coo()

# Load sub-graph as sparse tensor
sub_g = torch.tensor(sub_g)
sub_g = (sub_g + torch.eye(sub_g.shape[0])).to_sparse_coo()

node_features = torch.eye(data['feat'][0].shape[0], dtype=torch.float64)
train_labels = torch.tensor(train_labels)
test_labels = torch.tensor(test_labels)

# Define the model
input_size = node_features.size(1)
hidden_size = 500
output_size = 100  # Size of node embeddings to decode 
lr = 0.001
model = GCN(input_size, hidden_size, output_size)

# Define loss function and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

# Training loop
num_epochs = 50
for epoch in range(num_epochs):
    optimizer.zero_grad()

    # Forward pass
    outputs = model.encode(node_features, A=sub_g)
    preds = model.decode(outputs, train_edgeids)
    

    # Compute the loss
    loss = criterion(preds, train_labels)
    print(f"Loss: {loss}")
    acc = torch.sum((torch.sigmoid(preds) > 0.5) == train_labels)/train_labels.shape[0]
    print(f"Accuracy: {acc}")

    # Backward pass and optimization
    loss.backward()
    optimizer.step()

# Eval the model on test edges
print(f"----- Evaluate model -----")
model.eval()

test_outputs = model.encode(node_features, A=A)
test_preds = model.decode(test_outputs, test_edgeids)
test_loss = criterion(test_preds, test_labels)
print(f"Test Loss: {test_loss}")
test_acc = torch.sum((torch.sigmoid(test_preds) > 0.5) == test_labels)/test_labels.shape[0]
print(f"Test Accuracy: {test_acc}")


Loss: 0.7007034099539495
Accuracy: 0.4982728958129883
Loss: 0.6844691018922185
Accuracy: 0.5069084763526917
Loss: 0.6649768547591206
Accuracy: 0.5552676916122437
Loss: 0.6530262088429327
Accuracy: 0.6243523359298706
Loss: 0.637425807875001
Accuracy: 0.6571675539016724
Loss: 0.6259037537531857
Accuracy: 0.6623488664627075
Loss: 0.6106863884651867
Accuracy: 0.6442141532897949
Loss: 0.5976391185398781
Accuracy: 0.6433506011962891
Loss: 0.5855957639775985
Accuracy: 0.6390328407287598
Loss: 0.5699500601871834
Accuracy: 0.6943005323410034
Loss: 0.5583254382323078
Accuracy: 0.7366148829460144
Loss: 0.5437051603385071
Accuracy: 0.7754749655723572
Loss: 0.5295992130302981
Accuracy: 0.8324697613716125
Loss: 0.5170864021080103
Accuracy: 0.8747841119766235
Loss: 0.509096340493784
Accuracy: 0.8851467967033386
Loss: 0.4931271124880208
Accuracy: 0.9058721661567688
Loss: 0.47992762395229593
Accuracy: 0.9188255667686462
Loss: 0.46735299808837283
Accuracy: 0.936096727848053
Loss: 0.45490939529962066
Acc

# Define GNNExplainer

In [9]:
class GNNExplainer(nn.Module):
    """
    """
    def __init__(self, model, adj_matrix, labels, size_weight:int=0.005):
        super(GNNExplainer, self).__init__()
        self.model = model
        self.adj_matrix = adj_matrix
        self.labels = labels
        self.size_weight = size_weight
        self.node_mask = torch.empty((adj_matrix.shape[0], adj_matrix.shape[1]), dtype=torch.float64)
        torch.nn.init.xavier_uniform_(self.node_mask)
        # torch.nn.init.constant_(self.node_mask, 1.0)
        self.node_mask = torch.sigmoid(self.node_mask)
        self.node_mask = nn.Parameter(self.node_mask)
        
    def forward(self, edge_idx, x):
        """ 
        Use node mask to mask out unimportant nodes/edges for link prediction(s).
        Returns probability for edge(s) in node_idx and final adjacency matrix mask
        
        ::param:: edge_idx <- Tuple(int, int) : indices of nodes to predict edge between
                        or <- List(Tuple(int, int)) : for multi edge prediction
        ::param:: x <- Tensor(N x k) : node features of dimension (number of nodes x number of features)
        """
        # Use trained node_mask to mask adjacency matrix
        sym_mask = self.node_mask
        sym_mask = (sym_mask + sym_mask.t())/2
        masked_adj = self.adj_matrix * sym_mask
        
        # Compute GCN prediction using masked adjacency matrix
        outputs = model.encode(x, masked_adj)
        preds = model.decode(outputs, edge_idx)
        return preds, masked_adj
    
    def loss(self, pred, explain_pred):
        entropy_loss = -torch.sum(pred * torch.log(explain_pred))
        mask_size_loss = self.size_weight * torch.sum(torch.sigmoid(self.node_mask))
            
        return entropy_loss + mask_size_loss

# Run Explainer on trained GCN

In [None]:
# Link prediction to explain
# edge = torch.tensor([696, 697])
# edge = torch.tensor([1019, 1018])
edge = torch.tensor([870, 869])


In [11]:
# Define the model
input_size = node_features.size(1)
hidden_size = 100
output_size = 1  # Output size matches the number of nodes for link prediction
lr = 0.001
explainer = GNNExplainer(model=model, adj_matrix=A, labels=...)

# Make sure GCN is set to eval
model.eval()
model.GCN1.W.requires_grad = False
model.GCN2.W.requires_grad = False

# GCN prediction
gcn_pred = torch.sigmoid(model.decode(model.encode(node_features, A), edge))

# Define loss function and optimizer
explain_optimizer = torch.optim.Adam(explainer.parameters(), lr=lr)

# Training loop (replace with your actual training loop)
num_epochs = 100
for epoch in range(num_epochs):
    explain_optimizer.zero_grad()
    

    # Forward pass
    outputs, _ = explainer(edge, node_features)

    # Compute the loss
    explain_loss = explainer.loss(outputs, gcn_pred)
    print(f"Loss: {explain_loss}")

    # Backward pass and optimization
    explain_loss.backward()
    explain_optimizer.step()


  edge_embeds = torch.sum(node_embeds[edge_ids.T[0]] * node_embeds[edge_ids.T[1]], dim=-1)


Loss: 2361.59103288144, Explainer_pred:0.7161332465785349
Loss: 2360.697646032843, Explainer_pred:0.7153812309348347
Loss: 2359.8040532219247, Explainer_pred:0.7146349231078027
Loss: 2358.9102511259866, Explainer_pred:0.7138943886613072
Loss: 2358.0162364396965, Explainer_pred:0.7131596883945188
Loss: 2357.1220058830395, Explainer_pred:0.7124308781793095
Loss: 2356.2275562090767, Explainer_pred:0.7117080088219777
Loss: 2355.3328842114456, Explainer_pred:0.7109911259502694
Loss: 2354.4379867315556, Explainer_pred:0.710280269926336
Loss: 2353.5428606654273, Explainer_pred:0.7095754757859791
Loss: 2352.6475029701296, Explainer_pred:0.7088767732042249
Loss: 2351.751910669782, Explainer_pred:0.7081841864869648
Loss: 2350.8560808610832, Explainer_pred:0.7074977345881256
Loss: 2349.9600107183524, Explainer_pred:0.7068174311515645
Loss: 2349.063697583432, Explainer_pred:0.7061433185397172
Loss: 2348.1671390392557, Explainer_pred:0.7054754943136265
Loss: 2347.270332208637, Explainer_pred:0.7048

# Visualize most important sub-graph

In [19]:
def filter_adj(adj, threshold=0.5):
    """ Filters graph to only include edges above threshold
    """
    filt_adj = adj.to_dense().detach().clone()
    filt_adj[adj.to_dense()<threshold] = 0
    filt_adj = filt_adj.fill_diagonal_(0)
    return filt_adj

f = filter_adj(_)

In [25]:
fig, ax= plt.subplots(1, 1, figsize=(10,10))


# Plot subgraph
ax.set_title("")
G = nx.from_numpy_array(np.array(f))
G.remove_nodes_from(list(nx.isolates(G)))

# color predicted edge
# u, v = (696, 697)
# u, v = (1018, 1019)
u, v = (869, 870)
edge_colors = ['red' if e == (u, v) else 'black' for e in G.edges ]

# draw sub-graph
nx.draw_networkx(G, ax=ax, edge_color=edge_colors)

fig.tight_layout()
fig.savefig(f"{c}_sub_graph.png")
plt.close()