In [1]:
# Environment setup: export the installed torch version and install the
# PyTorch Geometric (PyG) packages that match it.
import os
import torch
# Expose the torch version string (printed below, e.g. "2.1.0+cu118") as the
# TORCH environment variable so the shell commands can use it as ${TORCH}.
os.environ['TORCH'] = torch.__version__
print(torch.__version__)

# The wheel index URL is selected by torch version so the downloaded wheels
# presumably match the installed torch build.
!pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
!pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# NOTE(review): installs PyG from the GitHub repository rather than a pinned
# release, so the installed version can change between runs.
!pip install -q git+https://github.com/pyg-team/pytorch_geometric.git

2.1.0+cu118


In [2]:
from torch_geometric.datasets import Planetoid

# Cora citation graph from the Planetoid benchmark, stored under the relative
# directory 'tmp' (see the download log below).
dataset = Planetoid(name='Cora', root='tmp')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [3]:

import torch.nn.functional as F
from torch.nn import Sequential, Linear, ReLU
import torch_geometric
from torch_geometric.nn import GCNConv

In [4]:
class GCNNet(torch.nn.Module):
    """Two-layer graph convolutional network for node classification.

    Architecture: GCNConv -> ReLU -> dropout -> GCNConv -> log_softmax.

    Args:
        dataset: object exposing ``num_node_features`` and ``num_classes``
            (e.g. a PyG ``Planetoid`` dataset).
        hidden_channels: width of the hidden layer (default 16, the previous
            hard-coded value).
        dropout: dropout probability applied to the hidden representation in
            train mode (default 0.5, the default of ``F.dropout`` used before).
    """

    def __init__(self, dataset, hidden_channels=16, dropout=0.5):
        super().__init__()
        # Plain float attribute: not a registered buffer, so state_dict keys
        # (conv1.*, conv2.*) are unchanged and old checkpoints still load.
        self.dropout = dropout
        self.conv1 = GCNConv(dataset.num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, dataset.num_classes)

    def forward(self, x, edge_index):
        """Return class log-probabilities (``log_softmax`` along dim 1)."""
        x = F.relu(self.conv1(x, edge_index))
        # Only active while self.training is True.
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)



In [5]:

def retrieve_accuracy(model, data, test_mask=None, value=False):
    """Compute node-classification accuracy of ``model`` on a masked node subset.

    The forward pass runs in eval mode (dropout disabled) and under
    ``torch.no_grad()``. The model's previous train/eval mode is restored
    afterwards, so this is safe to call from inside a training loop.

    Args:
        model: module called as ``model(data.x, data.edge_index)``; the
            prediction is the argmax over dim 1 of its output.
        data: graph object with ``x``, ``edge_index``, ``y`` and ``test_mask``.
        test_mask: boolean node mask to evaluate on; defaults to
            ``data.test_mask``.
        value: if True return the accuracy as a float, otherwise the string
            ``'Accuracy: <acc with 4 decimals>'``.

    Returns:
        float or str (see ``value``). An empty mask gives ``nan`` rather than
        raising ZeroDivisionError.
    """
    if test_mask is None:
        test_mask = data.test_mask

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            _, pred = model(data.x, data.edge_index).max(dim=1)
    finally:
        # Restore the caller's mode (train_model keeps the model in train mode).
        model.train(was_training)

    num_evaluated = test_mask.sum().item()
    correct = float(pred[test_mask].eq(data.y[test_mask]).sum().item())
    acc = correct / num_evaluated if num_evaluated else float('nan')
    if value:
        return acc
    return 'Accuracy: {:.4f}'.format(acc)

In [6]:
def save_model(model, path):
    """Write ``model.state_dict()`` to ``path`` with ``torch.save``.

    Only the state dict is serialized, not the module object itself.
    """
    weights = model.state_dict()
    torch.save(weights, path)

In [7]:
def train_model(model, data, epochs=200, lr=0.01, weight_decay=5e-4, clip=None, loss_function="nll_loss",
                epoch_save_path=None, no_output=False):
    """Train ``model`` full-batch with Adam on the nodes in ``data.train_mask``.

    Args:
        model: module called as ``model(data.x, data.edge_index)``. With
            ``"nll_loss"`` its output must be log-probabilities.
        data: graph object with ``x``, ``edge_index``, ``y`` and ``train_mask``
            (and ``test_mask`` when accuracies are evaluated).
        epochs: number of optimisation steps (one full-batch step per epoch).
        lr, weight_decay: Adam hyper-parameters.
        clip: if not None, maximum total L2 norm of the gradients. It is
            applied after ``backward()`` and before ``optimizer.step()``.
        loss_function: ``"nll_loss"`` or ``"cross_entropy"``.
        epoch_save_path: if given, the model is checkpointed after every epoch
            to ``<stem>_epoch_<n><ext>`` (``model.pt`` -> ``model_epoch_3.pt``)
            and the per-epoch test accuracy is recorded.
        no_output: suppress all progress printing.

    Returns:
        List of per-epoch test accuracies (floats). It is empty unless
        ``epoch_save_path`` is given.

    Raises:
        ValueError: if ``loss_function`` is not recognised.

    The model is left in eval mode on return.
    """
    if loss_function == "nll_loss":
        criterion = F.nll_loss
    elif loss_function == "cross_entropy":
        # The default reduction='mean' replaces the deprecated size_average=True.
        criterion = F.cross_entropy
    else:
        raise ValueError("Unknown loss_function {!r}; expected 'nll_loss' or "
                         "'cross_entropy'".format(loss_function))

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    accuracies = []

    if epoch_save_path is not None:
        # splitext works for any extension (the old [:-3] slicing assumed ".pt").
        stem, ext = os.path.splitext(epoch_save_path)

    for epoch in range(epochs):
        # Re-enter train mode each epoch, because evaluation below switches to eval.
        model.train()
        optimizer.zero_grad()
        out = model(data.x, data.edge_index)
        loss = criterion(out[data.train_mask], data.y[data.train_mask])
        loss.backward()
        if clip is not None:
            # Must follow backward(): before it the gradients are zeroed/stale.
            torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if epoch_save_path is not None:
            save_model(model, stem + "_epoch_" + str(epoch) + ext)
            model.eval()
            with torch.no_grad():
                accuracies.append(retrieve_accuracy(model, data, value=True))
            if not no_output:
                print('Accuracy: {:.4f}'.format(accuracies[-1]), "Epoch", epoch)
        elif epoch % 25 == 0 and not no_output:
            model.eval()
            with torch.no_grad():
                print(retrieve_accuracy(model, data))

    model.eval()

    return accuracies

In [8]:
# Seed torch's global RNG so weight initialisation and dropout masks (and thus
# the printed accuracies) are reproducible on Restart & Run All.
SEED = 1234
torch.manual_seed(SEED)

model = GCNNet(dataset)
data = dataset[0]
acc = train_model(model, data, epochs=200, lr=0.01, weight_decay=5e-4, clip=None, loss_function="nll_loss",
                  epoch_save_path=None, no_output=False)
# train_model leaves the model in eval mode, so dropout is off here.
test_acc = retrieve_accuracy(model, data, test_mask=None, value=True)
print("Test Accuracy:", test_acc)

Accuracy: 0.4020
Accuracy: 0.7250
Accuracy: 0.7260
Accuracy: 0.7410
Accuracy: 0.7410
Accuracy: 0.7600
Accuracy: 0.7660
Accuracy: 0.7710
Test Accuracy: 0.808


In [9]:
def execute_model_with_gradient(model, node, x, edge_index):
    """Back-propagate the loss of ``node``'s predicted class through ``model``.

    The loss is ``-log(softmax(scores[node])[c])``, where ``c`` is the argmax
    class for ``node``. Gradients are accumulated into whatever tensors
    ``x`` depends on (and into the model parameters). Nothing is returned.
    The model's train/eval mode is not changed here; call ``model.eval()``
    beforehand if dropout should be off.
    """
    scores = model(x, edge_index)
    target_class = scores.argmax(dim=-1)[node]
    node_probs = torch.nn.functional.softmax(scores[node, :].squeeze(), dim=0)
    nll = -torch.log(node_probs[target_class])
    nll.backward()

In [10]:
def grad_node_explanation(model, node, x, edge_index):
    """Gradient-saliency explanation of ``model``'s prediction for ``node``.

    A node-weight vector and a feature-weight vector, both initialised to
    ones, form a rank-1 (num_nodes x num_features) mask that multiplies ``x``.
    The loss of ``node``'s predicted class is back-propagated via
    ``execute_model_with_gradient``. The absolute gradients with respect to
    the two vectors are returned as importance scores.

    The mask is created on ``x.device``, so this works wherever the model and
    data live. It no longer depends on a global ``device`` variable.

    Args:
        model: module called as ``model(x, edge_index)``.
        node: index of the node whose prediction is explained.
        x: 2-D node feature matrix (num_nodes, num_features).
        edge_index: graph connectivity, forwarded unchanged to the model.

    Returns:
        (feature_mask, node_mask): NumPy arrays of shape (num_features,) and
        (num_nodes,) holding absolute gradients.

    Side effect: ``backward()`` also accumulates gradients into the model
    parameters. They are cleared on entry via ``model.zero_grad()``.
    """
    model.zero_grad()

    num_nodes, num_features = x.size()

    # Leaf tensors whose .grad becomes the importance score.
    node_weights = torch.ones(num_nodes, device=x.device, requires_grad=True)
    feature_weights = torch.ones(num_features, device=x.device, requires_grad=True)

    # mask[i, j] = node_weights[i] * feature_weights[j]
    mask = torch.outer(node_weights, feature_weights)

    execute_model_with_gradient(model, node, mask * x, edge_index)

    node_mask = node_weights.grad.abs().cpu().numpy()
    feature_mask = feature_weights.grad.abs().cpu().numpy()

    return feature_mask, node_mask


In [14]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Move the trained model and the graph to the selected device, so they agree
# with any tensors placed on `device` (otherwise a CUDA mask meets CPU data).
model = model.to(device)
data = data.to(device)

# Explain the model's prediction for node 1.
feature_mask, node_mask = grad_node_explanation(model, 1, data.x, data.edge_index)
print("feature masks:", feature_mask)
print("node masks:", node_mask)

feature masks: [0. 0. 0. ... 0. 0. 0.]
node masks: [0.0000000e+00 8.1702863e-05 1.8734181e-05 ... 0.0000000e+00 0.0000000e+00
 0.0000000e+00]


In [15]:
# Demo: build a random feature matrix for a computation graph. For every
# (row, feature) position, the feature value is taken from a randomly chosen
# node of the full feature matrix.
import torch
random_seed = 1234
num_nodes = 10
samples = 5
num_nodes_computation_graph = 5
num_features = 10
# NOTE(review): this rebinds the global `device` used by earlier cells.
device = 'cpu'
rng = torch.Generator(device=device)
rng.manual_seed(random_seed)
# torch.randint already returns int64, the index dtype torch.gather requires.
random_indices = torch.randint(num_nodes, (samples, num_nodes_computation_graph, num_features),
                               generator=rng,
                               device=device,
                               )

print(random_indices[0])

# Draw from the seeded generator (not the global RNG) and size the matrix from
# the constants above, so the cell is reproducible and self-consistent.
full_feature_matrix = torch.rand(num_nodes, num_features, generator=rng, device=device)

print(full_feature_matrix)
# gather along dim 0: random_features[i, j] = full_feature_matrix[random_indices[0, i, j], j]
random_features = torch.gather(full_feature_matrix,
                               dim=0,
                               index=random_indices[0, :, :])

print(random_features)

tensor([[5, 1, 6, 5, 6, 4, 2, 5, 5, 9],
        [3, 1, 4, 2, 3, 2, 6, 8, 2, 2],
        [8, 2, 0, 4, 9, 2, 1, 9, 2, 2],
        [9, 4, 4, 7, 2, 5, 8, 6, 6, 7],
        [1, 5, 2, 1, 7, 4, 3, 0, 7, 6]])
tensor([[0.9468, 0.6453, 0.7974, 0.5538, 0.8636, 0.4947, 0.5937, 0.8221, 0.0784,
         0.4632],
        [0.9343, 0.7641, 0.5657, 0.0612, 0.7930, 0.6508, 0.7735, 0.1647, 0.2516,
         0.1507],
        [0.8039, 0.9856, 0.5960, 0.6068, 0.6896, 0.7500, 0.3031, 0.1733, 0.2129,
         0.9615],
        [0.6978, 0.0645, 0.9033, 0.4567, 0.3887, 0.4557, 0.4947, 0.5219, 0.9497,
         0.1200],
        [0.2773, 0.1493, 0.0853, 0.5599, 0.7849, 0.5501, 0.5467, 0.3459, 0.7624,
         0.5854],
        [0.2095, 0.4840, 0.2078, 0.9691, 0.7514, 0.5012, 0.1100, 0.6811, 0.3762,
         0.7813],
        [0.6866, 0.7659, 0.1498, 0.0464, 0.2301, 0.7191, 0.1077, 0.7131, 0.3464,
         0.5320],
        [0.2788, 0.3700, 0.6600, 0.1648, 0.5865, 0.8568, 0.4191, 0.4717, 0.4991,
         0.1172],
       

In [16]:
# Demo: sample, per computation graph, a subset of entries from a full edge mask.
import torch
random_seed = 1234
num_edges = 20
samples = 5
num_edges_computation_graph = 5
# NOTE(review): this rebinds the global `device` used by earlier cells.
device = 'cpu'
rng = torch.Generator(device=device)

rng.manual_seed(random_seed)

# torch.randint already returns int64, the index dtype torch.gather requires.
random_indices = torch.randint(num_edges, (samples, num_edges_computation_graph),
                               generator=rng,
                               device=device,
                               )

print(random_indices[0])

# Draw from the seeded generator (not the global RNG) and size it from
# num_edges, so the cell is reproducible and self-consistent.
full_edge_mask = torch.rand(num_edges, generator=rng, device=device)

print(full_edge_mask)
# 1-D gather: random_edge_mask[k] = full_edge_mask[random_indices[0, k]]
random_edge_mask = torch.gather(full_edge_mask,
                                dim=0,
                                index=random_indices[0, :])

print(random_edge_mask)

tensor([15, 11,  6,  5, 16])
tensor([0.1014, 0.2145, 0.5899, 0.9781, 0.6532, 0.2897, 0.9356, 0.6222, 0.9475,
        0.2011, 0.0676, 0.9729, 0.2867, 0.0375, 0.1842, 0.4650, 1.0000, 0.1104,
        0.1774, 0.4018])
tensor([0.4650, 0.9729, 0.9356, 0.2897, 1.0000])
