# Technical Details

In [13]:
# Dependencies
from typing import List, Tuple, Dict
import numpy as np 

import torch 
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric as pyg

import pygmtools as pygm
pygm.set_backend('pytorch')

In [14]:
# Meta Data

# Reproducibility
SEED = 123
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True 
np.random.seed(SEED)

# Check GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)

# Variables. 
max_num_nodes = 30 # 28 in the full dataset.  
batch_size = 64

Using device: cuda


## MUTAG Data  

In [15]:
# Download data. 
from torch_geometric.datasets import TUDataset
data_raw = TUDataset(root='data/TUDataset', name='MUTAG')

# Shuffle.
data_raw = data_raw.shuffle()

# Split.
train_data = data_raw[:150]
test_data = data_raw[150:]

def preprocess_MUTAG(data: TUDataset, max_num_nodes) -> pyg.data.Data:
    
    num_nodes = data.num_nodes
    
    # Pad node features.
    padded_x = torch.zeros((max_num_nodes, data.x.size(1)))
    padded_x[:num_nodes] = data.x

    # Relax edges to weights. 
    padded_adj = torch.zeros((max_num_nodes, max_num_nodes))
    padded_adj[:num_nodes, :num_nodes] = (
        pyg.utils.to_dense_adj(data.edge_index).squeeze(0)
    )
    edge_index, edge_weight = pygm.utils.dense_to_sparse(padded_adj + 1)
    edge_index = edge_index.transpose(0, 1)
    edge_weight = edge_weight.squeeze(0)

    # Wrap in data object.
    preprocessed_data = pyg.data.Data(x=padded_x, 
                                      edge_index=edge_index,
                                      edge_attr=edge_weight - 1,
                                      y=data.y)

    return preprocessed_data 

# Create data lists.
train_data_list = []
train_data_list_0 = []
train_data_list_1 = []
test_data_list = []
test_data_list_0 = []
test_data_list_1 = []

for graph in train_data:
    train_data_list.append(preprocess_MUTAG(graph, max_num_nodes))

    if graph.y.item() == 0: 
        train_data_list_0.append(preprocess_MUTAG(graph, max_num_nodes))

    elif graph.y.item() == 1: 
        train_data_list_1.append(preprocess_MUTAG(graph, max_num_nodes))

for graph in test_data:
    test_data_list.append(preprocess_MUTAG(graph, max_num_nodes))

    if graph.y.item() == 0: 
        test_data_list_0.append(preprocess_MUTAG(graph, max_num_nodes))

    elif graph.y.item() == 1: 
        test_data_list_1.append(preprocess_MUTAG(graph, max_num_nodes))

# Create data loaders.
train_loader = pyg.loader.DataLoader(train_data_list, batch_size=batch_size, 
                                     shuffle=True)
test_loader = pyg.loader.DataLoader(test_data_list, batch_size=batch_size, 
                                    shuffle=True)

## Explainee GCN Model

In [16]:
class GCNWeighted(nn.Module):
    def __init__(self, hidden_channels):
        super(GCNWeighted, self).__init__()
        self.conv1 = pyg.nn.GCNConv(7, hidden_channels) # 7 node features.
        self.conv2 = pyg.nn.GCNConv(hidden_channels, hidden_channels)
        self.conv3 = pyg.nn.GCNConv(hidden_channels, hidden_channels)
        self.lin = nn.Linear(hidden_channels, 2) # 2 classes.
    
    def forward(self, data):
        x, edge_index, batch, edge_weight = (
            data.x, data.edge_index, data.batch, data.edge_attr
        )

        # 1. Node embeddings.
        x = self.conv1(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv2(x, edge_index, edge_weight)
        x = x.relu()
        x = self.conv3(x, edge_index, edge_weight)

        # 2. Pooling.
        x = pyg.nn.global_mean_pool(x, batch)

        # 3. Prediction.
        x = F.dropout(x, p=0.5, training=self.training)
        x = self.lin(x)

        return x

In [17]:
# Training Explainee.
explainee = GCNWeighted(hidden_channels=64)
optimizer = torch.optim.Adam(explainee.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()

def train(data_loader): 
    explainee.train()

    for batch in data_loader: 
        out = explainee(batch)
        loss = criterion(out, batch.y)
        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

def explainee_accuracy(data_loader):
    explainee.eval()

    correct = 0
    for batch in data_loader: 
        out = explainee(batch)
        pred = out.argmax(dim=1)
        correct += int((pred == batch.y).sum())

    return correct / len(data_loader.dataset)

for epoch in range(1, 201): 
    train(train_loader)
    train_accuracy = explainee_accuracy(train_loader)
    test_accuracy = explainee_accuracy(test_loader)

    print(f"Epoch: {epoch} Train Accuracy: {train_accuracy} " + 
          f"Test Accuracy: {test_accuracy}")

Epoch: 1 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 2 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 3 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 4 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 5 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 6 Train Accuracy: 0.6533333333333333 Test Accuracy: 0.7105263157894737
Epoch: 7 Train Accuracy: 0.68 Test Accuracy: 0.7105263157894737
Epoch: 8 Train Accuracy: 0.7066666666666667 Test Accuracy: 0.7631578947368421
Epoch: 9 Train Accuracy: 0.7133333333333334 Test Accuracy: 0.7631578947368421
Epoch: 10 Train Accuracy: 0.7333333333333333 Test Accuracy: 0.8157894736842105
Epoch: 11 Train Accuracy: 0.82 Test Accuracy: 0.8421052631578947
Epoch: 12 Train Accuracy: 0.82 Test Accuracy: 0.8421052631578947
Epoch: 13 Train Accuracy: 0.78 Test Accuracy: 0.8421052631578947
Epoch: 14 Train Accuracy: 0.8266666666

In [18]:
# Confusion Matrix
from sklearn.metrics import confusion_matrix

full_batch = pyg.data.Batch.from_data_list(test_data_list + train_data_list)
explainee.eval()
preds = explainee(full_batch).argmax(dim=1).numpy()
targets = full_batch.y.numpy()

conf_matrix = confusion_matrix(targets, preds)
conf_matrix

array([[ 46,  17],
       [ 10, 115]], dtype=int64)

## D4Explainer 

In [19]:
# Hyperparameters.
T = 50
time_embed_dim = 10
num_node_feats = 7

betas = torch.linspace(start=0.001, end=0.1, steps=T)
beta_bars = []
cum_prod = 1

for beta in betas:
    cum_prod *= (1 - 2*beta)
    beta_bars.append(0.5 - 0.5 * cum_prod)

num_epochs = 200

In [20]:
def forward_diffusion_sample(graphs: pyg.data.Batch, 
                             t: int) -> pyg.data.Batch:
    """
    Input: Batch of observed graphs.
    Output: Batch of noised graphs.
    """
    edge_weight = graphs.edge_attr

    transition_probs = torch.full_like(edge_weight, beta_bars[t])
    transition_dist = torch.distributions.RelaxedBernoulli(
        temperature=0.15, probs=transition_probs
    )

    noised_edge_weights = torch.abs(
        edge_weight + transition_dist.rsample()
    )
    noised_graph = pyg.data.Batch(x=graphs.x, 
                                  edge_index=graphs.edge_index,
                                  edge_attr=noised_edge_weights,  
                                  y=graphs.y, batch=graphs.batch)

    return noised_graph

In [21]:
class denoising_model(nn.Module):
    def __init__(self, T, time_embed_dim, 
                 num_node_feats, h1=10, h2=50, h3=30, p_dropout=0.5): 
        super(denoising_model, self).__init__()

        self.time_embedder = nn.Embedding(num_embeddings=T, 
                                          embedding_dim=time_embed_dim) 

        self.conv1 = pyg.nn.GCNConv(num_node_feats, h1)
        self.conv2 = pyg.nn.GCNConv(h1, h2)
        self.conv3 = pyg.nn.GCNConv(h2, h3)

        self.p_dropout = p_dropout
        self.lin = nn.Linear(h3, max_num_nodes**2) # predicting weights.

    def forward(self, noised_graphs, t):
        
        x, edge_index, batch, edge_weight = (
            noised_graphs.x, noised_graphs.edge_index, 
            noised_graphs.batch, noised_graphs.edge_weight
        )

        time_embedding = self.time_embedder(t)

        x = self.conv1(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.p_dropout)
        
        x = x + time_embedding
        
        x = self.conv2(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.p_dropout)

        x = self.conv3(x, edge_index, edge_weight)
        x = F.relu(x)
        x = F.dropout(x, p=self.p_dropout)

        x = pyg.nn.global_mean_pool(x, batch)
        x = F.dropout(x, p=self.p_dropout)
        x = self.lin(x)
        x = F.sigmoid(x) 

        pred_weights = x.reshape(-1)
        pred_graph = pyg.data.Batch(x=noised_graphs.x, 
                                    edge_index=noised_graphs.edge_index,
                                    edge_attr=pred_weights,
                                    batch=noised_graphs.batch)
    
        return [pred_weights, pred_graph]

In [22]:
# Training Loop.
model_denoise = denoising_model(T, time_embed_dim=time_embed_dim, 
                                num_node_feats=num_node_feats) 

optimizer = torch.optim.Adam(model_denoise.parameters(), lr=0.01)

for epoch in range(1, num_epochs): 
    running_loss = 0.0
    for graphs in train_loader:
        with torch.no_grad():
            t = torch.randint(low=1, high=T, size=(1,))
            noised_graphs = forward_diffusion_sample(graphs, t)

        pred_weight, pred_graph = model_denoise(noised_graphs, t)

        loss_dist = F.binary_cross_entropy(
            pred_weight, graphs.edge_attr.squeeze(1)
        ) 
        running_loss += loss_dist.item()

        loss_dist.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(
        f"Epoch: {epoch}, Loss: {running_loss / len(train_loader)}"
    )

Epoch: 1, Loss: 0.6814501682917277
Epoch: 2, Loss: 0.5736742814381918
Epoch: 3, Loss: 0.4547794759273529
Epoch: 4, Loss: 0.3207857708136241
Epoch: 5, Loss: 0.1708637277285258
Epoch: 6, Loss: 0.13724727680285773
Epoch: 7, Loss: 0.11728910605112712
Epoch: 8, Loss: 0.09351835151513417
Epoch: 9, Loss: 0.09692725042502086
Epoch: 10, Loss: 0.14647934089104334
Epoch: 11, Loss: 0.10092594722906749
Epoch: 12, Loss: 0.09141373882691066
Epoch: 13, Loss: 0.11526066809892654
Epoch: 14, Loss: 0.09440551449855168
Epoch: 15, Loss: 0.08638316889603932
Epoch: 16, Loss: 0.08875712255636851
Epoch: 17, Loss: 0.07957890878121059
Epoch: 18, Loss: 0.07754483819007874
Epoch: 19, Loss: 0.07667408386866252
Epoch: 20, Loss: 0.07773071030775706
Epoch: 21, Loss: 0.09795727580785751
Epoch: 22, Loss: 0.07931863764921825
Epoch: 23, Loss: 0.07991917183001836
Epoch: 24, Loss: 0.08723363031943639
Epoch: 25, Loss: 0.0734206885099411
Epoch: 26, Loss: 0.07543578495581944
Epoch: 27, Loss: 0.07695478200912476
Epoch: 28, Loss:

In [23]:
# Training Loop.
model_denoise_CF = denoising_model(T, time_embed_dim=time_embed_dim, 
                                num_node_feats=num_node_feats) 

optimizer = torch.optim.Adam(model_denoise_CF.parameters(), lr=0.01)
CF_weight = 1.0

for epoch in range(1, num_epochs): 
    running_loss_dist = 0.0
    running_loss_CF = 0.0
    for graphs in train_loader:
        with torch.no_grad():
            t = torch.randint(low=1, high=T, size=(1,))
            noised_graphs = forward_diffusion_sample(graphs, t)

        pred_weight, pred_graph = model_denoise_CF(noised_graphs, t)

        loss_dist = F.binary_cross_entropy(
            pred_weight, graphs.edge_attr.squeeze(1)
        ) 
        running_loss_dist += loss_dist.item()

        explainee.eval()
        explainee_pred = F.softmax(explainee(pred_graph), dim=-1)
        class_prob = explainee_pred[torch.arange(explainee_pred.shape[0]), 
                                    graphs.y]
        loss_CF = (-1 * torch.log(1 - class_prob)).mean()
        running_loss_CF += loss_CF.item()

        loss = 1.0 * loss_dist + CF_weight * loss_CF
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    print(
        f"Epoch: {epoch}", 
        f" Loss_dist: {running_loss_dist / len(train_loader)}", 
        f" Loss_CF: {running_loss_CF / len(train_loader)}"
    )

Epoch: 1  Loss_dist: 0.691668967405955  Loss_CF: 1.8383880058924358
Epoch: 2  Loss_dist: 0.6552395820617676  Loss_CF: 1.652824838956197
Epoch: 3  Loss_dist: 0.6067018906275431  Loss_CF: 1.686516245206197
Epoch: 4  Loss_dist: 0.5793521602948507  Loss_CF: 1.6632184584935505
Epoch: 5  Loss_dist: 0.4610901474952698  Loss_CF: 1.6824331680933635
Epoch: 6  Loss_dist: 0.416742483774821  Loss_CF: 1.4854202270507812
Epoch: 7  Loss_dist: 0.3584097425142924  Loss_CF: 1.390825907389323
Epoch: 8  Loss_dist: 0.34042949477831524  Loss_CF: 1.1268500884373982
Epoch: 9  Loss_dist: 0.3319297830263774  Loss_CF: 0.8661627372105917
Epoch: 10  Loss_dist: 0.3777369161446889  Loss_CF: 0.6504358053207397
Epoch: 11  Loss_dist: 0.5413017968336741  Loss_CF: 0.5196541746457418
Epoch: 12  Loss_dist: 0.38932392994562787  Loss_CF: 0.5397603611151377
Epoch: 13  Loss_dist: 0.3767375449339549  Loss_CF: 0.5104561944802603
Epoch: 14  Loss_dist: 0.2589414417743683  Loss_CF: 0.7227440079053243
Epoch: 15  Loss_dist: 0.24907933

In [51]:
# Inference 
non_mut_example = pyg.data.Batch.from_data_list([test_data_list_1[0]])
non_mut_noised = forward_diffusion_sample(non_mut_example, torch.tensor([T-1])) 
pred_weight, non_mut_denoised = model_denoise_CF(non_mut_noised, torch.tensor([T-1]))
print(test_data_list_1[0].edge_attr.sum(), non_mut_denoised.edge_attr.sum())
print(torch.softmax(explainee(test_data_list_1[0]), dim=-1), 
      torch.softmax(explainee(non_mut_denoised), dim=-1)
)

tensor(44.) tensor(28.6338, grad_fn=<SumBackward0>)
tensor([[0.1354, 0.8646]], grad_fn=<SoftmaxBackward0>) tensor([[0.8495, 0.1505]], grad_fn=<SoftmaxBackward0>)


# Debugging