# Technical Details

In [6]:
# Dependencies
from typing import List, Tuple, Dict
import numpy as np 

import torch 
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric as pyg

import pygmtools as pygm
pygm.set_backend('pytorch')

In [24]:
# Meta Data

# Reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True 
np.random.seed(SEED)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Variables. 
max_num_nodes = 30 # 28 in the full dataset.  
batch_size = 64

Using device: cuda


## MUTAG Data  

In [38]:
# Download data. 
from torch_geometric.datasets import TUDataset
data_raw = TUDataset(root='data/TUDataset', name='MUTAG')

# Shuffle.
data_raw = data_raw.shuffle()

# Split.
train_data = data_raw[:150]
test_data = data_raw[150:]

def preprocess_MUTAG(data: TUDataset, max_num_nodes) -> pyg.data.Data:
    
    num_nodes = data.num_nodes
    
    # Pad node features.
    padded_x = torch.zeros((max_num_nodes, data.x.size(1)))
    padded_x[:num_nodes] = data.x

    # Relax edges to weights. 
    padded_adj = torch.zeros((max_num_nodes, max_num_nodes))
    padded_adj[:num_nodes, :num_nodes] = (
        pyg.utils.to_dense_adj(data.edge_index).squeeze(0)
    )
    edge_index, edge_weight = pygm.utils.dense_to_sparse(padded_adj + 1)
    edge_index = edge_index.transpose(0, 1)
    edge_weight = edge_weight.squeeze(0)

    # Wrap in data object.
    preprocessed_data = pyg.data.Data(x=padded_x, 
                                      edge_index=edge_index,
                                      edge_attr=edge_weight - 1,
                                      y=data.y)

    return preprocessed_data 

# Create data lists.
train_data_list = []
train_data_list_0 = []
train_data_list_1 = []
test_data_list = []
test_data_list_0 = []
test_data_list_1 = []

for graph in train_data:
    train_data_list.append(preprocess_MUTAG(graph, max_num_nodes))

    if graph.y.item() == 0: 
        train_data_list_0.append(preprocess_MUTAG(graph, max_num_nodes))

    elif graph.y.item() == 1: 
        train_data_list_1.append(preprocess_MUTAG(graph, max_num_nodes))

for graph in test_data:
    test_data_list.append(preprocess_MUTAG(graph, max_num_nodes))

    if graph.y.item() == 0: 
        test_data_list_0.append(preprocess_MUTAG(graph, max_num_nodes))

    elif graph.y.item() == 1: 
        test_data_list_1.append(preprocess_MUTAG(graph, max_num_nodes))

# Create data loaders.
train_loader = pyg.loader.DataLoader(train_data_list, batch_size=batch_size, 
                                     shuffle=True)
test_loader = pyg.loader.DataLoader(test_data_list, batch_size=batch_size, 
                                    shuffle=True)

## Explainee GCN Model

In [42]:
class GCNWeighted(nn.Module):
    def __init__(self, hidden_channels):
        super(GCNWeighted, self).__init__()
        self.conv1 = pyg.nn.GCNConv(7, hidden_channels) # 7 node features.
        self.conv2 = pyg.nn.GCNConv(hidden_channels, hidden_channels)
        self.conv3 = pyg.nn.GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 2) # 2 classes.
    
    def forward(self, data):
        x, edge_index, batch, edge_weight = (
            data.x, data.edge_index, data.batch, data.edge_weight
        )

        # 1. Node embeddings.
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Pooling.
        x = pyg.nn.global_mean_pool(x, batch)

        # 3. Prediction.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [44]:
# Training Explainee.
explainee = GCNWeighted(hidden_channels=64)
optimizer = torch.optim.Adam(explainee.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train(data_loader): 
    explainee.train()

    for batch in data_loader: 
        out = explainee(batch)
        loss = criterion(out, batch.y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

def explainee_accuracy(data_loader):
    explainee.eval()

    correct = 0
    for batch in data_loader: 
        out = explainee(batch)
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())

    return correct / len(data_loader.dataset)

for epoch in range(1, 201): 
    train(train_loader)
    train_accuracy = explainee_accuracy(train_loader)
    test_accuracy = explainee_accuracy(test_loader)

    print(f"Epoch: {epoch} Train Accuracy: {train_accuracy} " + 
          f"Test Accuracy: {test_accuracy}")

Epoch: 1 Train Accuracy: 0.68 Test Accuracy: 0.6052631578947368
Epoch: 2 Train Accuracy: 0.68 Test Accuracy: 0.6052631578947368
Epoch: 3 Train Accuracy: 0.68 Test Accuracy: 0.6052631578947368
Epoch: 4 Train Accuracy: 0.68 Test Accuracy: 0.6052631578947368
Epoch: 5 Train Accuracy: 0.74 Test Accuracy: 0.6842105263157895
Epoch: 6 Train Accuracy: 0.8466666666666667 Test Accuracy: 0.8157894736842105
Epoch: 7 Train Accuracy: 0.8533333333333334 Test Accuracy: 0.8157894736842105
Epoch: 8 Train Accuracy: 0.86 Test Accuracy: 0.8157894736842105
Epoch: 9 Train Accuracy: 0.8733333333333333 Test Accuracy: 0.8157894736842105
Epoch: 10 Train Accuracy: 0.8666666666666667 Test Accuracy: 0.8421052631578947
Epoch: 11 Train Accuracy: 0.8733333333333333 Test Accuracy: 0.8157894736842105
Epoch: 12 Train Accuracy: 0.8266666666666667 Test Accuracy: 0.8421052631578947
Epoch: 13 Train Accuracy: 0.88 Test Accuracy: 0.8421052631578947
Epoch: 14 Train Accuracy: 0.8733333333333333 Test Accuracy: 0.8157894736842105
E

In [45]:
# Confusion Matrix
from sklearn.metrics import confusion_matrix

full_batch = pyg.data.Batch.from_data_list(test_data_list + train_data_list)
explainee.eval()
preds = explainee(full_batch).argmax(dim=1).numpy()
targets = full_batch.y.numpy()

conf_matrix = confusion_matrix(targets, preds)
conf_matrix

array([[ 46,  17],
       [  8, 117]], dtype=int64)

## Diffusion Generator 

betas = []

for batch in loader:
    adj_batch = func(batch)
    t ~ U[1, 50]
    noised_adj = func(adj_batch)

    pred_adj = model(noised_adj, t)

    CrtEnt(adj_batch, pred_adj)



In [58]:
# Hyperparameters.
T = 50
time_embed_dim = 10
num_node_feats = 7

betas = torch.linspace(start=0.001, end=0.1, steps=T)
beta_bars = []
cum_prod = 1

for beta in betas:
    cum_prod *= (1 - 2*beta)
    beta_bars.append(0.5 - 0.5 * cum_prod)

num_epochs = 200

In [59]:
def forward_diffusion_sample(graphs: pyg.data.Batch, 
                             t: int) -> pyg.data.Batch:
    """
    Input: Batch of observed graphs.
    Output: Batch of noised graphs.
    """
    # [b, n, n]
    adj_batch = pyg.utils.to_dense_adj(graphs.edge_index, batch=graphs.batch, 
                                       max_num_nodes=max_num_nodes,
                                       edge_attr = graphs.edge_attr)
    # D4 uses the mask somehow.
    x_batch, node_feat_mask = pyg.utils.to_dense_batch(graphs.x, graphs.batch, 
                                       max_num_nodes=max_num_nodes)    
    
    transition_probs = torch.full_like(adj_batch, beta_bars[t])

    # Symmetrically applies noise - treats edges as undirected.
    noise_upper = torch.bernoulli(transition_probs).triu(diagonal=1)
    noise_lower = noise_upper.transpose(-1, -2)
    noised_adj_batch = torch.abs(adj_batch + noise_upper + noise_lower)
    noised_adj_batch_sparse, _ = pyg.utils.dense_to_sparse(noised_adj_batch)
    noised_graph = pyg.data.Batch(x=graphs.x, 
                                  edge_index=noised_adj_batch_sparse,
                                  edge_attr=graphs.edge_attr, 
                                  y=graphs.y, batch=graphs.batch)

    return noised_graph

In [60]:
class denoising_model(nn.Module):
    def __init__(self, T, time_embed_dim, 
                 num_node_feats, h1=10, h2=50, p_dropout=0.5): 
        super(denoising_model, self).__init__()

        self.time_embedder = nn.Embedding(num_embeddings=T, 
                                          embedding_dim=time_embed_dim) 

        self.conv1 = pyg.nn.GCNConv(num_node_feats, h1)
        self.conv2 = pyg.nn.GCNConv(h1, h2)
        self.conv3 = pyg.nn.GCNConv(h2, max_num_nodes)

        self.p_dropout = p_dropout

    def forward(self, noised_graphs, t):

        time_embedding = self.time_embedder(t)

        x = self.conv1(noised_graphs.x, noised_graphs.edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.p_dropout)
        x = x + time_embedding
        x = self.conv2(x, noised_graphs.edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.p_dropout)
        x = self.conv3(x, noised_graphs.edge_index)
        x = x.sigmoid()

        pred_adj, _ = pyg.utils.to_dense_batch(x, noised_graphs.batch, 
                                        max_num_nodes=max_num_nodes)

        pred_adj_sparse = torch.bernoulli(pred_adj)
        pred_adj_sparse, _ = pyg.utils.dense_to_sparse(pred_adj_sparse)
        pred_graph = pyg.data.Batch(x=noised_graphs.x, 
                                    edge_index=pred_adj_sparse,
                                    edge_attr=noised_graphs.edge_attr,
                                    batch=noised_graphs.batch)

        return [pred_adj, pred_graph]

In [61]:
# Training Loop.
model_denoise = denoising_model(T, time_embed_dim=time_embed_dim, 
                                num_node_feats=num_node_feats) 

optimizer = torch.optim.Adam(model_denoise.parameters(), lr=0.01)

for epoch in range(1, num_epochs): 
    running_loss = 0.0
    for graphs in train_loader_padded:
        with torch.no_grad():
            t = torch.randint(low=1, high=T, size=(1,))
            noised_graphs = forward_diffusion_sample(graphs, t)
            adj_batch = pyg.utils.to_dense_adj(
                graphs.edge_index, batch=graphs.batch, 
                max_num_nodes=max_num_nodes
            )

        pred_adj, pred_graph = model_denoise(noised_graphs, t)

        loss_dist = F.binary_cross_entropy(pred_adj, adj_batch) 
        running_loss += loss_dist.item()
        loss_dist.backward()

        optimizer.step()
        optimizer.zero_grad()

    print(
        f"Epoch: {epoch}, Loss: {running_loss / len(train_loader_padded)}"
    )

Epoch: 1, Loss: 0.7093714475631714
Epoch: 2, Loss: 0.622095008691152
Epoch: 3, Loss: 0.5128630797068278
Epoch: 4, Loss: 0.4013783236344655
Epoch: 5, Loss: 0.40246546268463135
Epoch: 6, Loss: 0.32683605949083966
Epoch: 7, Loss: 0.2960011661052704
Epoch: 8, Loss: 0.2102095584074656
Epoch: 9, Loss: 0.20801111062367758
Epoch: 10, Loss: 0.19716976086298624
Epoch: 11, Loss: 0.20385312537352243
Epoch: 12, Loss: 0.21180968483289084
Epoch: 13, Loss: 0.19380254050095877
Epoch: 14, Loss: 0.19296148916085562
Epoch: 15, Loss: 0.20216453075408936
Epoch: 16, Loss: 0.18725470701853433
Epoch: 17, Loss: 0.18771063288052878
Epoch: 18, Loss: 0.1820418337980906
Epoch: 19, Loss: 0.18217353026072183
Epoch: 20, Loss: 0.19016788403193155
Epoch: 21, Loss: 0.1968522916237513
Epoch: 22, Loss: 0.1845134049654007
Epoch: 23, Loss: 0.1775075594584147
Epoch: 24, Loss: 0.18885981539885202
Epoch: 25, Loss: 0.18943706651528677
Epoch: 26, Loss: 0.18136371672153473
Epoch: 27, Loss: 0.18361488481362662
Epoch: 28, Loss: 0.17

In [74]:
# Training Loop.
model_denoise_CF = denoising_model(T, time_embed_dim=time_embed_dim, 
                                num_node_feats=num_node_feats) 

optimizer = torch.optim.Adam(model_denoise_CF.parameters(), lr=0.01)
CF_weight = 1.0

for epoch in range(1, num_epochs): 
    running_loss_dist = 0.0
    running_loss_CF = 0.0
    for graphs in train_loader_padded:
        with torch.no_grad():
            t = torch.randint(low=1, high=T, size=(1,))
            noised_graphs = forward_diffusion_sample(graphs, t)
            adj_batch = pyg.utils.to_dense_adj(
                graphs.edge_index, batch=graphs.batch, 
                max_num_nodes=max_num_nodes
            )

        pred_adj, pred_graph = model_denoise_CF(noised_graphs, t)

        loss_dist = F.binary_cross_entropy(pred_adj, adj_batch) 
        running_loss_dist += loss_dist.item()

        explainee.eval()
        explainee_pred = F.softmax(explainee(noised_graphs), dim=-1)
        loss_CF = -1 * torch.log(
            explainee_pred[torch.arange(explainee_pred.shape[0]), graphs.y]
        ).mean()
        running_loss_CF += loss_CF.item()

        loss = 1.0 * loss_dist + CF_weight * loss_CF
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    print(f"Epoch: {epoch}")
    print(f" Loss_dist: {running_loss_dist / len(train_loader_padded)}")
    print(f" Loss_CF: {running_loss_CF / len(train_loader_padded)}")

tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0005,  0.0020, -0.0048,  0.0022,  0.0139,  0.0045,  0.0231, -0.0140,
          0.0282,  0.0039],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  

# Debugging

In [63]:
T = 10
low_noise = 0.0
high_noise = 0.5
noise_list = list(np.random.uniform(low=low_noise, high=high_noise, size=T))

# Bernoulli distribution for the probability of an edge existing.
bernoulli_adj = torch.full_like(x[1], noise_list[0])

# Symmetrically applies noise - treats edges as undirected.
noise_upper = torch.bernoulli(bernoulli_adj).triu(diagonal=1)
noise_lower = noise_upper.transpose(-1, -2)
train_adj = torch.abs(-x[1] + noise_upper + noise_lower)

noisediff = noise_upper + noise_lower # record true noise. 

print((train_adj - x[1]).abs().sum())
print(noisediff.sum())

NameError: name 'x' is not defined

In [None]:
max_obs_nodes = 0
for graph in train_data_list + test_data_list:
    if graph.x.shape[0] > max_obs_nodes:
        max_obs_nodes = graph.x.shape[0]

print(max_obs_nodes)

28


In [None]:
test_adj_sparse = torch.bernoulli(test_adj)
test_adj_sparse, _ = pyg.utils.dense_to_sparse(test_adj_sparse)
test_adj_sparse

tensor([[ 0,  0,  0,  0,  1,  1,  1,  2,  2,  3,  3,  3,  4,  5,  6,  6,  6,  6,
          7,  7,  8,  9,  9,  9,  9, 11],
        [ 5,  6, 11, 12,  5, 10, 11,  1, 13,  0,  2,  9,  4, 10, 11, 17, 18, 21,
          8, 12, 11,  7, 12, 18, 19,  3]])