# Prepare dependencies and utilities

In [None]:
from torch.autograd import Variable
from torch_geometric.datasets import TUDataset
import torch
import torch.nn as nn
import torch_geometric.nn as gnn
import torch_geometric.utils as U
import math
import matplotlib.pyplot as plt
import networkx as nx
import scipy.stats

import importlib
import infra  # local source file with training & plotting infrastructure
# Re-import the local infra module so edits to it take effect when this cell is
# re-run without restarting the kernel.
importlib.reload(infra);

# Define hyperparameters

In [None]:
# Number of training epochs, passed to every infra.train() call in this notebook.
num_epochs: int = 64

# Test the framework on the MUTAG dataset (from TUDataset)

In [None]:
# MUTAG graph-classification dataset, stored under data/TUDataset.
tu_dataset = TUDataset(name='MUTAG', root='data/TUDataset')
# Atom-type names used as node labels when plotting; presumably one name per
# node-feature column (7 columns for MUTAG) -- confirm the column order.
tu_labels = 'C N O F I Cl Br'.split()
# Index of the graph visualised in the saliency plots.
tu_example_idx = 24

In [None]:
# Number of graphs in the dataset (displayed as this cell's output).
len(tu_dataset)

In [None]:
# GCN graph classifier: 5-layer GCN encoder -> sum pooling over nodes -> linear head.
# The input and output sizes are read from the dataset instead of being hard-coded
# (for MUTAG these are 7 node features and 2 classes, the values previously
# written literally), so the cell stays consistent if the dataset changes.
tu_model = gnn.Sequential('x, edge_index', [
    (gnn.GCN(
        in_channels=tu_dataset.num_node_features,
        hidden_channels=50,
        num_layers=5,
        out_channels=50,  # must equal the Linear layer's input size below
        dropout=0.0,
    ), 'x, edge_index -> x'),
    # Sum node embeddings along dim 0 to obtain a single graph-level vector.
    # NOTE(review): this pools over *all* nodes it receives, so it assumes
    # infra.train feeds one graph at a time (no batch vector) -- confirm in infra.
    (lambda x: torch.sum(x, dim=0), 'x -> x'),
    (nn.Linear(50, tu_dataset.num_classes), 'x -> x'),
])

In [None]:
# Train the GCN classifier on the whole dataset. The return value is kept as
# tu_accuracies; presumably per-epoch accuracies -- confirm in infra.train.
tu_accuracies = infra.train(
    tu_model,
    tu_dataset,
    num_epochs=num_epochs,
)

In [None]:
# Plot the example graph coloured by the trained GCN's saliency (infra.get_grad).
# The value returned by plot_graph is kept as `pos`. It is presumably the node
# layout, since it is later passed back via pos= so another plot can reuse the
# same node positions.
data = tu_dataset[tu_example_idx]
pos = infra.plot_graph(data, infra.get_grad(data, tu_model), rows=2, labels=tu_labels)

# Now do the same with a graph attention network (GAT; labelled "GAN" in the code below)

In [None]:
# Graph-attention (GAT) classifier with the same architecture as the GCN model:
# 5-layer GAT encoder -> sum pooling over nodes -> linear head.
# The input and output sizes are read from the dataset instead of being hard-coded
# (for MUTAG these are 7 node features and 2 classes, the values previously
# written literally), so the cell stays consistent if the dataset changes.
tu_gan_model = gnn.Sequential('x, edge_index', [
    (gnn.GAT(
        in_channels=tu_dataset.num_node_features,
        hidden_channels=50,
        num_layers=5,
        out_channels=50,  # must equal the Linear layer's input size below
        dropout=0.0,
    ), 'x, edge_index -> x'),
    # Sum node embeddings along dim 0 to obtain a single graph-level vector.
    # NOTE(review): this pools over *all* nodes it receives, so it assumes
    # infra.train feeds one graph at a time (no batch vector) -- confirm in infra.
    (lambda x: torch.sum(x, dim=0), 'x -> x'),
    (nn.Linear(50, tu_dataset.num_classes), 'x -> x'),
])

In [None]:
# Train the GAT classifier with the same dataset and epoch budget as the GCN.
# The return value is kept as tu_gan_accuracies; presumably per-epoch
# accuracies -- confirm in infra.train.
tu_gan_accuracies = infra.train(
    tu_gan_model,
    tu_dataset,
    num_epochs=num_epochs,
)

In [None]:
# Plot the same example graph coloured by the trained GAT model's saliency.
# `pos` must already hold the layout returned by the earlier GCN plot, so both
# drawings share node positions. The trailing ';' suppresses the notebook's
# display of the return value.
data = tu_dataset[tu_example_idx]
infra.plot_graph(data, infra.get_grad(data, tu_gan_model), rows=2, pos=pos, labels=tu_labels);

---

# Calculate saliency map statistics

In [None]:
def _flat_saliency(model, dataset):
    """Collect every saliency value of ``model`` over all graphs in ``dataset``.

    For each graph, calls ``infra.get_grad(graph, model)``, flattens the
    resulting tensor, and appends its entries to one flat Python list.

    Returns:
        list of float: all saliency entries, in dataset order.
    """
    values = []
    for graph in dataset:
        values.extend(infra.get_grad(graph, model).flatten().tolist())
    return values


# A single shared helper guarantees the GCN and GAT statistics are gathered the
# same way (previously two copy-pasted loops). torch's std() is the
# Bessel-corrected sample standard deviation.
gcn_grads = _flat_saliency(tu_model, tu_dataset)
print(f"std gcn: {torch.tensor(gcn_grads).std()}")
gan_grads = _flat_saliency(tu_gan_model, tu_dataset)
print(f"std gan: {torch.tensor(gan_grads).std()}")

In [None]:
# Shared keyword arguments for the saliency histogram (matplotlib hist()).
# NOTE: with range=(-1, 1), values outside [-1, 1] fall into no bin and are
# silently excluded from the plot.
hist_kwargs = dict(
    histtype='step',  # outlines only, so overlaid distributions stay readable
    range=(-1, 1),
    bins=64,
    density=False,    # raw counts rather than a normalised density
    stacked=False,    # overlay the datasets instead of stacking them
)

In [None]:
# Overlay the GCN and GAT saliency distributions on one 4x3-inch axes and
# write the figure to ./3-hist.pdf.
hist_fig, hist_ax = plt.subplots(figsize=(4, 3))
hist_ax.hist([gcn_grads, gan_grads], **hist_kwargs)
hist_ax.legend(['GCN', 'GAN'])
hist_fig.savefig('./3-hist.pdf')

In [None]:
# Excess (Fisher) kurtosis of the GCN saliency values: 0 for a normal
# distribution, positive for heavier tails. The keyword arguments are scipy's defaults.
scipy.stats.kurtosis(gcn_grads, fisher=True, bias=True)

In [None]:
# Excess (Fisher) kurtosis of the GAT saliency values, directly comparable to
# the GCN figure. The keyword arguments are scipy's defaults.
scipy.stats.kurtosis(gan_grads, fisher=True, bias=True)