## Load Packages

In [1]:

# import torchvision
# from torchvision.datasets import CIFAR10
# from torchvision import transforms
# PyTorch Lightning
# Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
# !pip install --quiet pytorch-lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint


In [2]:
## Standard libraries
import os
import json
import math
import numpy as np 
import time

## Imports for plotting
import matplotlib.pyplot as plt
%matplotlib inline 
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgb
import matplotlib
matplotlib.rcParams['lines.linewidth'] = 2.0
import seaborn as sns
sns.reset_orig()
sns.set()

## Progress bar
from tqdm.notebook import tqdm

## PyTorch
import torch
import torch.nn as nn
# import torch.nn.functional as F
import torch.utils.data as data
import torch.optim as optim
# Torchvision
import torchvision
from torchvision.datasets import CIFAR10
from torchvision import transforms
# PyTorch Lightning
try:
    import pytorch_lightning as pl
except ModuleNotFoundError: # Google Colab does not have PyTorch Lightning installed by default. Hence, we do it here if necessary
    !pip install --quiet pytorch-lightning>=1.4
    import pytorch_lightning as pl
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint

# Path to the folder where the datasets are/should be downloaded (e.g. CIFAR10)
DATASET_PATH = "../data"
# Path to the folder where the pretrained models are saved
CHECKPOINT_PATH = "../saved_models/tutorial7"

# Setting the seed
pl.seed_everything(42)

# Ensure that all operations are deterministic on GPU (if used) for reproducibility
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("mps")
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
print(device)

  set_matplotlib_formats('svg', 'pdf') # For export
Seed set to 42


cpu


In [3]:
import torch.nn.functional as F

In [4]:
import torch
x = torch.rand(5, 3)
print(x)

tensor([[0.8823, 0.9150, 0.3829],
        [0.9593, 0.3904, 0.6009],
        [0.2566, 0.7936, 0.9408],
        [0.1332, 0.9346, 0.5936],
        [0.8694, 0.5677, 0.7411]])


In [5]:
import torch
# Smoke test for the MPS backend: allocate a tensor on it when it is available.
if not torch.backends.mps.is_available():
    print("MPS device not found.")
else:
    mps_device = torch.device("mps")
    x = torch.ones(1, device=mps_device)
    print(x)

tensor([1.], device='mps:0')


In [6]:
# Create the MPS device handle when the backend is usable; otherwise report why it is not.
# (This check is not part of the original tutorial code.)
if torch.backends.mps.is_available():
    mps_device = torch.device("mps")
elif not torch.backends.mps.is_built():
    print("MPS not available because the current PyTorch install was not "
          "built with MPS enabled.")
else:
    print("MPS not available because the current MacOS version is not 12.3+ "
          "and/or you do not have an MPS-enabled device on this machine.")

In [7]:
from torch_geometric.utils import sort_edge_index

We also have a few pre-trained models that can be downloaded with the cell below (the download code is currently commented out).

In [8]:
# import urllib.request
# from urllib.error import HTTPError
# # Github URL where saved models are stored for this tutorial
# base_url = "https://raw.githubusercontent.com/phlippe/saved_models/main/tutorial7/"
# # Files to download
# pretrained_files = ["NodeLevelMLP.ckpt", "NodeLevelGNN.ckpt", "GraphLevelGraphConv.ckpt"]

# # Create checkpoint path if it doesn't exist yet
# os.makedirs(CHECKPOINT_PATH, exist_ok=True)

# # For each file, check whether it already exists. If not, try downloading it.
# for file_name in pretrained_files:
#     file_path = os.path.join(CHECKPOINT_PATH, file_name)
#     if "/" in file_name:
#         os.makedirs(file_path.rsplit("/",1)[0], exist_ok=True)
#     if not os.path.isfile(file_path):
#         file_url = base_url + file_name
#         print(f"Downloading {file_url}...")
#         try:
#             urllib.request.urlretrieve(file_url, file_path)
#         except HTTPError as e:
#             print("Something went wrong. Please try to download the file from the GDrive folder, or contact the author with the full output including the following error:\n", e)

## Graph Neural Networks

### Classes

In [9]:
class GCNLayer(nn.Module):
    """Simplified graph-convolution layer on dense, batched adjacency matrices.

    Node features are linearly projected, summed over the nodes selected by each
    adjacency row, and divided by that row's sum (a mean over the selected nodes).
    """

    def __init__(self, c_in, c_out):
        """
        Inputs:
            c_in - Dimensionality of the input node features
            c_out - Dimensionality of the output node features
        """
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)  # linear map applied to every node

    def forward(self, node_feats, adj_matrix):
        """
        Inputs:
            node_feats - Tensor with node features of shape [batch_size, num_nodes, c_in]
            adj_matrix - Batch of adjacency matrices of the graph. If there is an edge from i to j, adj_matrix[b,i,j]=1 else 0.
                         Supports directed edges by non-symmetric matrices. Assumes to already have added the identity connections.
                         Shape: [batch_size, num_nodes, num_nodes]
        Returns:
            Tensor of shape [batch_size, num_nodes, c_out]. A node whose adjacency row is
            all zero gets a zero vector instead of NaN.
        """
        # Row sums of the adjacency matrix: how many nodes each node aggregates over
        num_neighbours = adj_matrix.sum(dim=-1, keepdim=True)
        # An all-zero row (isolated node without self-connection) would yield 0/0 = NaN.
        # Its aggregate is already zero, so dividing by 1 keeps it at zero.
        num_neighbours = num_neighbours.masked_fill(num_neighbours == 0, 1)
        node_feats = self.projection(node_feats)
        node_feats = torch.bmm(adj_matrix, node_feats)  # sum projected features per adjacency row
        return node_feats / num_neighbours

In [10]:
class GCNLayer_toy(nn.Module):
    """Toy copy of the dense graph-convolution layer (project, aggregate, average)."""

    def __init__(self, c_in, c_out):
        """
        Inputs:
            c_in - Dimensionality of the input node features
            c_out - Dimensionality of the output node features
        """
        super().__init__()
        self.projection = nn.Linear(c_in, c_out)  # linear layer shared by all nodes

    def forward(self, node_feats, adj_matrix):
        """
        Inputs:
            node_feats - Tensor with node features of shape [batch_size, num_nodes, c_in]
            adj_matrix - Batch of adjacency matrices of the graph. If there is an edge from i to j, adj_matrix[b,i,j]=1 else 0.
                         Supports directed edges by non-symmetric matrices. Assumes to already have added the identity connections.
                         Shape: [batch_size, num_nodes, num_nodes]
        Returns:
            Tensor of shape [batch_size, num_nodes, c_out]; rows of nodes without any
            adjacency entry are zero rather than NaN.
        """
        projected = self.projection(node_feats)
        aggregated = torch.bmm(adj_matrix, projected)  # batched matrix product
        # Normalise by the adjacency row sum. Rows summing to zero are divided by one
        # instead, which avoids 0/0 = NaN for nodes that have no entry at all.
        row_sums = adj_matrix.sum(dim=-1, keepdim=True)
        safe_row_sums = torch.where(row_sums == 0, torch.ones_like(row_sums), row_sums)
        return aggregated / safe_row_sums

In [11]:
class GATLayer(nn.Module):
    """Multi-head graph attention layer operating on dense, batched adjacency matrices.

    For every non-zero adjacency entry (i, j) an attention logit is computed from the
    concatenated projected features of nodes i and j. The logits are softmax-normalised
    over j and used to form a weighted sum of the projected features.
    """
    
    def __init__(self, c_in, c_out, num_heads=1, concat_heads=True, alpha=0.2):
        """
        Inputs:
            c_in - Dimensionality of input features
            c_out - Dimensionality of output features
            num_heads - Number of heads, i.e. attention mechanisms to apply in parallel. The 
                        output features are equally split up over the heads if concat_heads=True.
            concat_heads - If True, the output of the different heads is concatenated instead of averaged.
            alpha - Negative slope of the LeakyReLU activation.
        """
        super().__init__()
        self.num_heads = num_heads
        self.concat_heads = concat_heads
        if self.concat_heads:
            assert c_out % num_heads == 0, "Number of output features must be a multiple of the count of heads."
            # From here on c_out is the per-head width, so the concatenation has the requested size
            c_out = c_out // num_heads
        
        # Sub-modules and parameters needed in the layer
        self.projection = nn.Linear(c_in, c_out * num_heads)
        # Attention vector per head; allocated uninitialised and filled by the Xavier init below
        self.a = nn.Parameter(torch.Tensor(num_heads, 2 * c_out)) # One per head
        self.leakyrelu = nn.LeakyReLU(alpha)
        
        # Initialization from the original implementation
        nn.init.xavier_uniform_(self.projection.weight.data, gain=1.414)
        nn.init.xavier_uniform_(self.a.data, gain=1.414)
        
    def forward(self, node_feats, adj_matrix, print_attn_probs=False):
        """
        Inputs:
            node_feats - Input features of the nodes. Shape: [batch_size, num_nodes, c_in]
            adj_matrix - Adjacency matrix including self-connections. Shape: [batch_size, num_nodes, num_nodes]
            print_attn_probs - If True, the attention weights are printed during the forward pass (for debugging purposes)
        Returns:
            Tensor of shape [batch_size, num_nodes, c_out], with c_out as passed to the
            constructor: heads are concatenated if concat_heads is set, otherwise averaged.
        """
        batch_size, num_nodes = node_feats.size(0), node_feats.size(1)
        
        # Apply linear layer and sort nodes by head
        node_feats = self.projection(node_feats)
        node_feats = node_feats.view(batch_size, num_nodes, self.num_heads, -1) # [batch, node, head, per-head features]
        
        # We need to calculate the attention logits for every edge in the adjacency matrix 
        # Doing this on all possible combinations of nodes is very expensive
        # => Create a tensor of [W*h_i||W*h_j] with i and j being the indices of all edges
        edges = adj_matrix.nonzero(as_tuple=False) # Returns indices where the adjacency matrix is not 0 => edges
        node_feats_flat = node_feats.view(batch_size * num_nodes, self.num_heads, -1)
        # Each row of `edges` is (batch, i, j); convert to row indices of the flattened node tensor
        edge_indices_row = edges[:,0] * num_nodes + edges[:,1]
        edge_indices_col = edges[:,0] * num_nodes + edges[:,2]
        a_input = torch.cat([
            torch.index_select(input=node_feats_flat, index=edge_indices_row, dim=0),
            torch.index_select(input=node_feats_flat, index=edge_indices_col, dim=0) 
        ], dim=-1) # Index select returns a tensor with node_feats_flat being indexed at the desired positions along dim=0
        
        # Calculate attention MLP output (independent for each head)
        # In the einsum, 'b' runs over edges (first dim of a_input), 'h' over heads, 'c' over features
        attn_logits = torch.einsum('bhc,hc->bh', a_input, self.a)
        attn_logits = self.leakyrelu(attn_logits)
        
        # Map list of attention values back into a matrix
        # Non-edges keep the very negative fill value, so the softmax below assigns them
        # (numerically) zero attention probability.
        attn_matrix = attn_logits.new_zeros(adj_matrix.shape+(self.num_heads,)).fill_(-9e15)
        # NOTE(review): the mask selects entries equal to 1 while `nonzero` above selects every
        # non-zero entry; this presumably requires a binary adjacency matrix - confirm with callers.
        attn_matrix[adj_matrix[...,None].repeat(1,1,1,self.num_heads) == 1] = attn_logits.reshape(-1)
        
        # Weighted average of attention
        # Softmax over dim=2, i.e. over the j index of each node i
        attn_probs = F.softmax(attn_matrix, dim=2)
        if print_attn_probs:
            print("Attention probs\n", attn_probs.permute(0, 3, 1, 2))
        node_feats = torch.einsum('bijh,bjhc->bihc', attn_probs, node_feats)
        
        # If heads should be concatenated, we can do this by reshaping. Otherwise, take mean
        if self.concat_heads:
            node_feats = node_feats.reshape(batch_size, num_nodes, -1)
        else:
            node_feats = node_feats.mean(dim=2)
        
        return node_feats 

In [12]:
# from typing import Optional

# import torch
# from torch import Tensor

# @torch.jit.script
# def softmax(src: Tensor, index: Optional[Tensor] = None,
#             ptr: Optional[Tensor] = None, num_nodes: Optional[int] = None,
#             dim: int = 0) -> Tensor:
#     return src
# import torch_geometric

In [13]:
# torch geometric
try: 
    import torch_geometric
except ModuleNotFoundError:
    # Installing torch geometric packages with specific CUDA+PyTorch version. 
    # See https://pytorch-geometric.readthedocs.io/en/latest/notes/installation.html for details 
    TORCH = torch.__version__.split('+')[0]
    CUDA = 'cu' + torch.version.cuda.replace('.','')

    !pip install torch-scatter     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-sparse      -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-cluster     -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-spline-conv -f https://pytorch-geometric.com/whl/torch-{TORCH}+{CUDA}.html
    !pip install torch-geometric 
    import torch_geometric
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data

In [14]:
import torch_geometric.nn as geom_nn
import torch_geometric.data as geom_data
# Lookup table from a short layer name to the PyTorch Geometric layer class, so the graph
# layer type can be selected with a string.
# NOTE(review): GINConv's constructor presumably expects a neural-network module rather than
# in_channels/out_channels - confirm "GIN" can be built the way GNNModel constructs layers.
gnn_layer_by_name = dict(
    GCN=geom_nn.GCNConv,
    GAT=geom_nn.GATConv,
    GIN=geom_nn.GINConv,
    GraphConv=geom_nn.GraphConv,
)

In [15]:
from torch_geometric.utils import to_networkx, from_networkx

# NOTE(review): `cora_dataset` is not defined in any cell visible before this one, so this
# cell would raise NameError on a fresh kernel unless it is created elsewhere - confirm or remove.
# Convert the first graph of the dataset to NetworkX and collect the degree of every node.
G = to_networkx(cora_dataset[0])
degrees = [val for (node, val) in G.degree()]

In [16]:
class GNNModel(nn.Module):
    """Stack of PyTorch Geometric graph layers with a ReLU after every layer but the last."""

    def __init__(self, c_in, c_hidden, c_out, num_layers=2, layer_name="GraphConv", dp_rate=0.1, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Total number of graph layers (num_layers-1 hidden layers plus the output layer)
            layer_name - Key into gnn_layer_by_name selecting the graph layer class
            dp_rate - Dropout rate. Currently unused: no dropout layer is added to the stack.
            kwargs - Additional arguments for the graph layer (e.g. number of heads for GAT)
        """
        super().__init__()
        layer_cls = gnn_layer_by_name[layer_name]

        # Input width of each graph layer: c_in for the first one, c_hidden for all later ones
        input_widths = [c_in] + [c_hidden] * (num_layers - 1)

        modules = []
        for width in input_widths[:-1]:
            modules.append(layer_cls(in_channels=width, out_channels=c_hidden, **kwargs))
            modules.append(nn.ReLU(inplace=True))
        # Output layer, without activation
        modules.append(layer_cls(in_channels=input_widths[-1], out_channels=c_out, **kwargs))
        self.layers = nn.ModuleList(modules)

    def forward(self, x, edge_index):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
        """
        for layer in self.layers:
            # Graph layers (subclasses of MessagePassing) additionally take the edge list;
            # every other module, such as the activation, only takes the features.
            takes_edges = isinstance(layer, geom_nn.MessagePassing)
            x = layer(x, edge_index) if takes_edges else layer(x)
        return x

In [17]:
class MLPModel(nn.Module):
    """Plain MLP baseline on the node features; any graph inputs passed along are ignored."""

    def __init__(self, c_in, c_hidden, c_out, num_layers=2, dp_rate=0.1):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of the output features. Usually number of classes in classification
            num_layers - Total number of linear layers (num_layers-1 hidden layers plus the output layer)
            dp_rate - Dropout rate applied after every hidden layer
        """
        super().__init__()
        # Input width of each linear layer: c_in for the first one, c_hidden for all later ones
        input_widths = [c_in] + [c_hidden] * (num_layers - 1)

        blocks = []
        for width in input_widths[:-1]:
            blocks.extend([
                nn.Linear(width, c_hidden),
                nn.ReLU(inplace=True),
                nn.Dropout(dp_rate),
            ])
        # Output layer, without activation or dropout
        blocks.append(nn.Linear(input_widths[-1], c_out))
        self.layers = nn.Sequential(*blocks)

    def forward(self, x, *args, **kwargs):
        """
        Inputs:
            x - Input features per node
            args, kwargs - Accepted and ignored, so the model can be called like a graph model
        """
        return self.layers(x)

In [18]:
# Small function for printing the test scores
def print_results(result_dict):
    """Print the accuracies in result_dict as percentages.

    The "train" and "val" entries are optional and printed only when present;
    the "test" entry is required (a missing key raises KeyError).
    """
    has_train = "train" in result_dict
    has_val = "val" in result_dict
    if has_train:
        train_pct = 100.0 * result_dict["train"]
        print(f"Train accuracy: {train_pct:4.2f}%")
    if has_val:
        val_pct = 100.0 * result_dict["val"]
        print(f"Val accuracy:   {val_pct:4.2f}%")
    test_pct = 100.0 * result_dict["test"]
    print(f"Test accuracy:  {test_pct:4.2f}%")

In [19]:
class GraphGNNModel(nn.Module):
    """Graph-level predictor: GNN node encoder, max pooling per graph, linear head."""

    def __init__(self, c_in, c_hidden, c_out, dp_rate_linear=0.5, **kwargs):
        """
        Inputs:
            c_in - Dimension of input features
            c_hidden - Dimension of hidden features
            c_out - Dimension of output features (usually number of classes)
            dp_rate_linear - Dropout rate before the linear layer. Currently unused: the
                             head consists of the linear layer only.
            kwargs - Additional arguments for the GNNModel object
        """
        super().__init__()
        # The encoder outputs c_hidden features per node; this is not our prediction output yet
        self.GNN = GNNModel(c_in=c_in, c_hidden=c_hidden, c_out=c_hidden, **kwargs)
        self.head = nn.Sequential(nn.Linear(c_hidden, c_out))

    def forward(self, x, edge_index, batch_idx):
        """
        Inputs:
            x - Input features per node
            edge_index - List of vertex index pairs representing the edges in the graph (PyTorch geometric notation)
            batch_idx - Index of batch element for each node
        """
        node_embeddings = self.GNN(x, edge_index)
        # Max pooling over the nodes of each graph (alternatives: geom_nn.global_mean_pool
        # for average pooling, geom_nn.global_add_pool for sum pooling)
        graph_embeddings = geom_nn.global_max_pool(node_embeddings, batch_idx)
        return self.head(graph_embeddings)

In [20]:
class GraphLevelGNN(pl.LightningModule):
    """Lightning module that trains and tests a GraphGNNModel for graph classification.

    With c_out == 1 the task is treated as binary (BCE-with-logits loss, prediction is
    logit > 0); otherwise as multi-class (cross-entropy loss, prediction is the argmax).
    """

    def __init__(self, **model_kwargs):
        """
        Inputs:
            model_kwargs - Arguments forwarded to GraphGNNModel; must include c_out.
        """
        super().__init__()
        # Saving hyperparameters
        self.save_hyperparameters()

        self.model = GraphGNNModel(**model_kwargs)
        self.loss_module = nn.BCEWithLogitsLoss() if self.hparams.c_out == 1 else nn.CrossEntropyLoss()

    def forward(self, data, mode="train"):
        """
        Inputs:
            data - Batch of graphs with attributes x, edge_index, batch and, optionally, y
            mode - Unused; kept for interface compatibility
        Returns:
            (loss, acc, preds). When the batch carries no labels (data.y is None), loss and
            acc are returned as 0 so the method can still be used for prediction only.
            Any other failure while computing loss or accuracy is raised, not hidden.
        """
        x, edge_index, batch_idx = data.x, data.edge_index, data.batch
        logits = self.model(x, edge_index, batch_idx).squeeze(dim=-1)

        is_binary = self.hparams.c_out == 1
        preds = (logits > 0).float() if is_binary else logits.argmax(dim=-1)

        target = getattr(data, "y", None)
        if target is None:
            # Unlabelled batch: only the predictions are meaningful
            return 0, 0, preds

        if is_binary:
            # BCEWithLogitsLoss needs float targets; convert a local copy and leave the batch untouched
            target = target.float()
        loss = self.loss_module(logits, target)
        acc = (preds == target).sum().float() / preds.shape[0]
        return loss, acc, preds

    def configure_optimizers(self):
        """Return the AdamW optimizer used for training (no weight decay)."""
        optimizer = optim.AdamW(self.parameters(), lr=1e-2, weight_decay=0.0) # High lr because of small dataset and small model
        return optimizer

    def training_step(self, batch, batch_idx):
        """Compute and log training loss and accuracy for one batch; return the loss."""
        loss, acc, _ = self.forward(batch, mode="train")
        # Pass the number of graphs explicitly so Lightning averages over graphs instead of
        # guessing the batch size from tensors inside the batch (None falls back to the default inference).
        num_graphs = getattr(batch, "num_graphs", None)
        self.log('train_loss', loss, batch_size=num_graphs)
        self.log('train_acc', acc, batch_size=num_graphs)
        return loss

    # def validation_step(self, batch, batch_idx):
    #     _, acc,preds = self.forward(batch, mode="val")
    #     self.log('val_acc', acc)
    #     # self.log('val_pred', preds)

    def test_step(self, batch, batch_idx):
        """Compute and log the accuracy for one test batch."""
        _, acc, _ = self.forward(batch, mode="test")
        self.log('test_acc', acc, batch_size=getattr(batch, "num_graphs", None))
        # self.log('test_pred', pred_y)

### Graph-level tasks: Graph classification

In [21]:
# setting the max-degree threshold 
thres=6 

In [22]:
tu_dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name="ENZYMES")

In [23]:
tu_dataset[0].edge_index

tensor([[ 0,  0,  0,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  3,  3,  3,  3,  3,
          3,  4,  4,  4,  4,  5,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,  7,  7,
          7,  8,  8,  8,  9,  9,  9,  9,  9, 10, 10, 10, 10, 11, 11, 11, 11, 12,
         12, 12, 12, 12, 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16,
         16, 16, 17, 17, 17, 17, 18, 18, 18, 19, 19, 19, 20, 20, 20, 20, 20, 20,
         21, 21, 21, 21, 21, 22, 22, 22, 22, 23, 23, 23, 23, 24, 24, 24, 24, 25,
         25, 25, 25, 25, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 28, 28, 28, 28,
         28, 28, 29, 29, 29, 29, 29, 29, 29, 30, 30, 30, 30, 30, 31, 31, 31, 32,
         32, 32, 32, 33, 33, 33, 33, 33, 33, 34, 34, 34, 34, 34, 34, 35, 35, 35,
         35, 35, 36, 36, 36, 36],
        [ 1,  2,  3,  0,  2,  3, 24, 27,  0,  1,  3, 27, 28,  0,  1,  2,  4,  5,
         28,  3,  5,  6, 29,  3,  4,  6,  7, 29,  4,  5,  7,  8,  5,  6,  8,  9,
         10,  6,  7,  9,  7,  8, 10, 11, 12,  7,  9, 11, 12,  9, 10, 12, 26

In [24]:
print("Data object:", tu_dataset.data)
print("Length:", len(tu_dataset))
print(f"Average label: {tu_dataset.data.y.float().mean().item():4.2f}")

Data object: Data(x=[19580, 3], edge_index=[2, 74564], y=[600])
Length: 600
Average label: 2.50




In [25]:
tu_dataset.data.x

tensor([[1., 0., 0.],
        [1., 0., 0.],
        [1., 0., 0.],
        ...,
        [0., 1., 0.],
        [0., 1., 0.],
        [0., 1., 0.]])

In [26]:
tu_dataset.data.y

tensor([5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
        5, 5, 5, 5, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4,
        4, 4, 4, 4, 4, 4, 4, 4, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,

In [27]:
tu2=tu_dataset.copy()
type(tu2.data)

torch_geometric.data.data.Data

In [28]:
len(tu2.y)

600

In [29]:
# Overwrite each graph's label with its maximum node degree, counted as the number of times
# a node occurs in the first row of edge_index. The 0/1 label is obtained afterwards by
# comparing these values against the threshold `thres`.
for graph_idx in range(len(tu2.y)):
    source_nodes = tu2[graph_idx].edge_index[0, :]
    tu2.y[graph_idx] = max(torch.bincount(source_nodes))
    # alternative label: sum(torch.bincount(source_nodes))

In [30]:
tu2.y

tensor([7, 6, 6, 5, 6, 6, 7, 8, 5, 5, 3, 6, 5, 5, 5, 6, 7, 7, 1, 7, 7, 6, 6, 6,
        6, 7, 7, 6, 6, 6, 7, 7, 7, 5, 5, 8, 7, 2, 6, 7, 7, 6, 6, 6, 6, 7, 5, 5,
        5, 4, 6, 7, 6, 6, 6, 7, 5, 6, 5, 5, 4, 6, 5, 6, 5, 5, 5, 6, 6, 5, 7, 5,
        6, 6, 6, 6, 5, 5, 5, 5, 6, 5, 5, 5, 5, 6, 6, 7, 8, 5, 7, 7, 5, 5, 5, 5,
        7, 6, 6, 4, 7, 6, 9, 5, 7, 5, 8, 8, 6, 5, 6, 6, 6, 5, 5, 5, 8, 5, 5, 7,
        9, 6, 9, 6, 7, 5, 5, 6, 5, 5, 5, 6, 5, 6, 7, 2, 7, 7, 7, 5, 6, 5, 6, 6,
        5, 7, 6, 6, 6, 5, 6, 7, 6, 4, 7, 5, 5, 5, 7, 7, 6, 5, 5, 7, 5, 5, 7, 6,
        6, 5, 7, 5, 6, 6, 6, 6, 8, 6, 7, 7, 7, 6, 6, 6, 6, 5, 6, 5, 7, 6, 6, 6,
        6, 6, 7, 8, 7, 5, 6, 6, 5, 5, 6, 6, 6, 6, 5, 7, 6, 6, 7, 7, 7, 8, 7, 6,
        6, 8, 7, 6, 6, 5, 6, 8, 5, 7, 7, 7, 7, 7, 6, 6, 8, 6, 6, 6, 5, 7, 7, 7,
        5, 5, 5, 6, 7, 5, 6, 5, 5, 8, 5, 5, 6, 7, 5, 6, 5, 6, 6, 7, 7, 6, 6, 6,
        6, 7, 7, 8, 8, 6, 6, 6, 7, 6, 6, 7, 6, 6, 6, 5, 7, 6, 5, 5, 5, 6, 7, 7,
        6, 6, 6, 6, 6, 6, 5, 5, 7, 6, 5,

In [31]:
tu2.data.y=(tu2.y>thres).long()

In [32]:
tu2.data.y

tensor([1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0,
        0, 1, 1, 0, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 0, 0, 1, 0, 0,
        0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 0, 0, 0,
        1, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1,
        1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0,
        0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0,
        0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0,
        0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 0,
        0, 1, 1, 0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 1, 1,
        0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0,
        0, 1, 1, 1, 1, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,

In [33]:
tu2.data.y.sum(),tu2.data.y.size()

(tensor(188), torch.Size([600]))

In [34]:
# Replace the node features with a single column of random Gaussian values.
# We don't want the GNN to pick up on node features, but rather to use the structure of the graph for the prediction.
# (Per the author's note, assigning 1 to all node features, e.g. torch.ones(...), confuses the model.)
tu2.data.x = torch.randn(tu2.x.shape[0], 1)

In [35]:
tu2.x.shape

torch.Size([19580, 1])

In [36]:
tu2[1]

Data(edge_index=[2, 102], x=[23, 1], y=[1])

In [37]:
# just to stop before training the model for sake of sanity check 
# stop

In [38]:
# torch.manual_seed(42)
# tu_dataset.shuffle()
# train_dataset = tu_dataset[:500]
# test_dataset = tu_dataset[500:]
# Shuffle with a fixed seed, then use the first 500 graphs for training and the rest for testing
torch.manual_seed(42)
tu2_shuffle = tu2.shuffle()
train_dataset, test_dataset = tu2_shuffle[:500], tu2_shuffle[500:]

When using a data loader, we encounter a problem with batching $N$ graphs. Each graph in the batch can have a different number of nodes and edges, and hence we would require a lot of padding to obtain a single tensor. Torch geometric uses a different, more efficient approach: we can view the $N$ graphs in a batch as a single large graph with concatenated node and edge lists. As there is no edge between the $N$ graphs, running GNN layers on the large graph gives us the same output as running the GNN on each graph separately. This batching strategy is illustrated below (figure credit - PyTorch Geometric team, [tutorial here](https://colab.research.google.com/drive/1I8a0DfQ3fI7Njc62__mVXUlcAleUclnb?usp=sharing#scrollTo=2owRWKcuoALo)).

<center width="100%"><img src="torch_geometric_stacking_graphs.png" width="600px"></center>

The adjacency matrix is zero for any nodes that come from two different graphs, and otherwise according to the adjacency matrix of the individual graph. Luckily, this strategy is already implemented in torch geometric, and hence we can use the corresponding data loader:

In [39]:
# torch_geometric.data.DataLoader is deprecated; the loader lives in torch_geometric.loader.
from torch_geometric.loader import DataLoader as GraphDataLoader

graph_train_loader = GraphDataLoader(train_dataset, batch_size=260, shuffle=True)
# Additional loader if you want to change to a larger dataset.
# Note that it currently iterates over the test split, not a separate validation split.
graph_val_loader = GraphDataLoader(test_dataset, batch_size=15)
graph_test_loader = GraphDataLoader(test_dataset, batch_size=15)



Let's load a batch below to see the batching in action:

In [40]:
batch = next(iter(graph_test_loader))
print("Batch:", batch)
print("Labels:", batch.y[:10])
print("Batch indices:", batch.batch[:40])

Batch: DataBatch(edge_index=[2, 2192], x=[569, 1], y=[15], batch=[569], ptr=[16])
Labels: tensor([0, 0, 1, 1, 1, 1, 0, 0, 0, 0])
Batch indices: tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])


We have 15 graphs stacked together in this batch of the test dataset. The batch indices, stored in `batch`, show that the first 19 nodes belong to the first graph, the following nodes to the second graph, and so on. These indices are important for performing the final prediction. To perform a prediction over a whole graph, we usually perform a pooling operation over all nodes after running the GNN model. In this case, we use max pooling. Hence, we need to know which nodes should be included in which pool. Using this pooling, we can build our graph network (the `GraphGNNModel` class defined earlier). Specifically, we re-use our class `GNNModel` from before, and simply add a max pool and a single linear layer for the graph prediction task. 

Finally, let's perform the training and testing. Feel free to experiment with different GNN layers, hyperparameters, etc.

In [41]:
def train_graph_classifier(model_name,train_loader=graph_train_loader,test_loader=graph_test_loader, **model_kwargs):
    """Train a ``GraphLevelGNN`` from scratch and report its accuracy.

    Args:
        model_name: Suffix of the trainer's root directory
            (``GraphLevel<model_name>`` below ``CHECKPOINT_PATH``).
        train_loader: Loader used for fitting and for the reported "train" accuracy.
        test_loader: Loader used for the reported "test" accuracy.
        **model_kwargs: Forwarded unchanged to ``GraphLevelGNN``.

    Returns:
        ``(model, result)`` where ``result`` is ``{"test": ..., "train": ...}``,
        each entry being the ``test_acc`` metric returned by ``trainer.test``.
    """
    pl.seed_everything(42)

    # Create a PyTorch Lightning trainer. Checkpointing on "val_acc" is disabled,
    # so the model evaluated below is the one from the last epoch.
    root_dir = os.path.join(CHECKPOINT_PATH, "GraphLevel" + model_name)
    os.makedirs(root_dir, exist_ok=True)
    trainer = pl.Trainer(default_root_dir=root_dir,
                         accelerator="cpu",
                         devices=1,
                         max_epochs=500,
                         enable_progress_bar=False)
    trainer.logger._default_hp_metric = None # Optional logging argument that we don't need

    # The original code branched on whether a pretrained checkpoint file exists,
    # but loading was commented out and both branches trained from scratch in
    # exactly the same way, so the branch is collapsed here.
    pl.seed_everything(42)
    model = GraphLevelGNN(c_in=tu2.num_node_features,
                          c_out=1 if tu2.num_classes==2 else tu2.num_classes,
                          **model_kwargs)
    # Fit on the loader the caller passed in (previously the global
    # graph_train_loader was used here regardless of `train_loader`).
    trainer.fit(model, train_loader, graph_val_loader)

    # Evaluate the trained model on the training and test loaders.
    train_result = trainer.test(model, train_loader, verbose=False)
    test_result = trainer.test(model, test_loader, verbose=False)
    result = {"test": test_result[0]['test_acc'], "train": train_result[0]['test_acc']}
    return model, result

In [42]:
help(pl.trainer)

Help on package pytorch_lightning.trainer in pytorch_lightning:

NAME
    pytorch_lightning.trainer

DESCRIPTION
    # Copyright The Lightning AI team.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.

PACKAGE CONTENTS
    call
    configuration_validator
    connectors (package)
    setup
    states
    trainer

CLASSES
    builtins.object
        pytorch_lightning.trainer.trainer.Trainer
    
    class Trainer(builtins.object)
     |  Trainer(*, acc

In [43]:
model, result = train_graph_classifier(model_name="GraphConv", 
                                       c_hidden=16, 
                                       layer_name="GraphConv", 
                                       num_layers=3, 
                                       dp_rate_linear=0,
                                       dp_rate=0.0)

Seed set to 42
GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/Users/beatrixwen/miniforge3/envs/tensorflow/lib/python3.9/site-packages/pytorch_lightning/trainer/setup.py:187: GPU available but not used. You can set it by doing `Trainer(accelerator='gpu')`.
Seed set to 42
/Users/beatrixwen/miniforge3/envs/tensorflow/lib/python3.9/site-packages/pytorch_lightning/trainer/configuration_validator.py:72: You passed in a `val_dataloader` but have no `validation_step`. Skipping val loop.

  | Name        | Type              | Params
--------------------------------------------------
0 | model       | GraphGNNModel     | 1.1 K 
1 | loss_module | BCEWithLogitsLoss | 0     
--------------------------------------------------
1.1 K     Trainable params
0         Non-trainable params
1.1 K     Total params
0.004     Total estimated model params size (MB)
/Users/beatrixwen/miniforge3/envs/tensorfl

In [44]:
print(f"Train performance: {100.0*result['train']:4.2f}%")
print(f"Test performance:  {100.0*result['test']:4.2f}%")

Train performance: 100.00%
Test performance:  100.00%


stop

In [45]:
proteins_dataset = torch_geometric.datasets.TUDataset(root=DATASET_PATH, name="PROTEINS")
print("Data object:", proteins_dataset.data)
print("Length:", len(proteins_dataset))
print(f"Average label: {proteins_dataset.data.y.float().mean().item():4.2f}")

Data object: Data(x=[43471, 3], edge_index=[2, 162088], y=[1113])
Length: 1113
Average label: 0.40




In [46]:
proteins_dataset.data

Data(x=[43471, 3], edge_index=[2, 162088], y=[1113])

### Test on a different dataset 

In [47]:
tu3=proteins_dataset.copy()
tu3.data.x=tu3.data.x[:,:1]
# tu2.data.x=torch.ones(tu2.x.shape)
tu3.data.x=torch.randn(tu3.x.shape)
for i in range(len(tu3.y)): 
    tu3.y[i]=max(torch.bincount(tu3[i].edge_index[0,:]))

In [48]:
tu3.y[:20]

tensor([8, 5, 5, 7, 5, 7, 7, 8, 5, 5, 4, 8, 5, 6, 6, 5, 5, 7, 7, 5])

In [49]:
# Evaluate the trained model on the first eight graphs of tu3, one graph at a time.
tuple(model(tu3[graph_idx]) for graph_idx in range(8))

((tensor(-120.8686, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([1.])),
 (tensor(151.6385, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([0.])),
 (tensor(134.9324, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([0.])),
 (tensor(-53.9077, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([1.])),
 (tensor(123.2618, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([0.])),
 (tensor(-64.3469, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([1.])),
 (tensor(-60.1594, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([1.])),
 (tensor(-167.3324, grad_fn=<BinaryCrossEntropyWithLogitsBackward0>),
  tensor(0.),
  tensor([1.])))

In [50]:
tu3.data.y=(tu3.y>thres).long()
tu3.y



tensor([1, 0, 0,  ..., 0, 0, 0])

In [51]:
tu3.y[:20]

tensor([1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 1, 0])

In [52]:
tu3.data.y.sum()

tensor(234)

In [53]:
proteins_test_loader = geom_data.DataLoader(tu3[:100], batch_size=1)
# proteins_result = trainer.test(model, proteins_test_loader, verbose=False)



In [54]:
model.eval()

GraphLevelGNN(
  (model): GraphGNNModel(
    (GNN): GNNModel(
      (layers): ModuleList(
        (0): GraphConv(1, 16)
        (1): ReLU(inplace=True)
        (2): GraphConv(16, 16)
        (3): ReLU(inplace=True)
        (4): GraphConv(16, 16)
      )
    )
    (head): Sequential(
      (0): Linear(in_features=16, out_features=1, bias=True)
    )
  )
  (loss_module): BCEWithLogitsLoss()
)

In [55]:
stop

NameError: name 'stop' is not defined

## Run Proposed Algorithms

In [None]:
graph_test_1 = tu2[0]
graph_test_1

In [None]:
import networkx as nx

In [None]:
from torch_geometric.data import Data 

In [None]:
import operator

In [None]:
def excl_low_value_nodes(imp_nodes,excl_round=1,node_threshold=None): 
    """Repeatedly drop the lowest-valued entries from a ``{node: value}`` mapping.

    Each round removes every entry whose value equals the current minimum.
    Pruning stops early when all remaining values are equal (nothing can be
    dropped without dropping everything) or when the number of remaining
    entries is at most ``node_threshold``.

    Args:
        imp_nodes: Mapping from node to value; it is not modified.
        excl_round: Maximum number of pruning rounds.
        node_threshold: Optional size at or below which pruning stops.

    Returns:
        The pruned mapping. The input mapping itself is returned when nothing
        was removed.
    """
    for _ in range(excl_round):
        if not imp_nodes:
            # Nothing left to prune; min()/max() would raise on an empty dict.
            break
        lowest = min(imp_nodes.values())
        if lowest == max(imp_nodes.values()):
            return imp_nodes
        imp_nodes = {key: val for key, val in imp_nodes.items() if val != lowest}
        if node_threshold is not None and len(imp_nodes) <= node_threshold:
            return imp_nodes
    # Previously the function fell off the end here and returned None.
    return imp_nodes

    
                         
    

In [None]:
def shapley_dict_count(shapley_dict_lst,excl_min=0): 
    """Count how often each node occurs and optionally drop the rarest ones.

    Args:
        shapley_dict_lst: Iterable of (hashable) node ids, possibly with repeats.
        excl_min: ``0`` returns all counts; ``2`` removes the lowest count
            twice in succession; any other value ``>= 1`` removes it once.

    Returns:
        ``{node: count}`` ordered by ascending count (ties keep first-occurrence
        order). A removal round is skipped when all remaining counts are equal,
        so the result is never emptied by pruning.
    """
    # Local import: the notebook imports Counter only in a later cell.
    from collections import Counter

    counts = Counter(shapley_dict_lst)  # one pass instead of list.count per element
    sorted_dict = dict(sorted(counts.items(), key=operator.itemgetter(1)))
    if excl_min <= 0 or not sorted_dict:
        return sorted_dict

    imp_nodes = sorted_dict
    prune_rounds = 2 if excl_min == 2 else 1
    for _ in range(prune_rounds):
        lowest = min(imp_nodes.values())
        if lowest == max(imp_nodes.values()):
            # All counts tie: removing the minimum would remove everything.
            break
        imp_nodes = {key: val for key, val in imp_nodes.items() if val != lowest}
    return imp_nodes
    
                         
    

In [None]:
def _attach_extra_node(graph, neighbours):
    """Return a new graph with one extra node linked in both directions to ``neighbours``."""
    new_id = graph.x.shape[0]
    neighbours = torch.as_tensor(neighbours,
                                 dtype=graph.edge_index.dtype,
                                 device=graph.edge_index.device)
    hub = torch.full_like(neighbours, new_id)
    sources = torch.cat((graph.edge_index[0], hub, neighbours))
    targets = torch.cat((graph.edge_index[1], neighbours, hub))
    # The new node gets a random feature row with the same width as graph.x.
    features = torch.cat((graph.x, torch.randn(1, graph.x.shape[1])))
    # As before, the resulting graph carries no `y`.
    return Data(x=features, edge_index=torch.stack((sources, targets)))


def new_subgraph(graph,include_node_lst=None,exclude_node_lst=None,add_node=False,add_sudo_nodes=False,node_lst=None, draw_subgraph=False): 
    """Derive a modified copy of ``graph`` by adding a node and/or masking nodes.

    Node features ``x`` are always kept in full; "removing" a node only removes
    the edges that touch it (``relabel_nodes=False``), so node ids stay stable.

    Args:
        graph: Input graph with ``x`` and ``edge_index``.
        include_node_lst: If given, keep only edges whose two endpoints are in this list.
        exclude_node_lst: If given, drop every edge touching one of these nodes.
        add_node: If not ``False``, a list of nodes to connect to one newly appended node.
        add_sudo_nodes: If ``True``, append one new node connected to ``node_lst``.
        node_lst: Neighbours of the pseudo node (used with ``add_sudo_nodes``).
        draw_subgraph: If ``True``, draw the result with ``draw_with_color``.

    Returns:
        The resulting graph, or ``graph`` itself when no option is selected.

    Note:
        The steps run in the order add_node, include, add_sudo_nodes, exclude,
        and each later step starts from ``graph`` (possibly extended by an
        added node), not from the previous step's result. Combining include
        with add_sudo_nodes or exclude therefore discards the include filter,
        exactly as in the original implementation.
    """
    # Default result so the function no longer fails when no option is selected.
    graph_i = graph

    if add_node != False:
        graph = _attach_extra_node(graph, add_node)
        graph_i = graph

    if include_node_lst is not None:
        # The mask must have one entry per *node*; it used to be sized by the
        # number of edges, which breaks whenever a node id >= number of edges.
        keep = torch.zeros(graph.x.shape[0], dtype=torch.bool, device=graph.edge_index.device)
        keep[include_node_lst] = True
        graph_i = torch_geometric.data.Data(x=graph.x,
                                            edge_index=torch_geometric.utils.subgraph(keep,graph.edge_index,relabel_nodes=False)[0],
                                            y=graph.y)

    if add_sudo_nodes == True:
        graph = _attach_extra_node(graph, node_lst)
        graph_i = graph

    if exclude_node_lst is not None:
        keep = torch.ones(graph.x.shape[0], dtype=torch.bool, device=graph.edge_index.device)
        keep[exclude_node_lst] = False
        graph_i = torch_geometric.data.Data(x=graph.x,
                                            edge_index=torch_geometric.utils.subgraph(keep,graph.edge_index,relabel_nodes=False)[0],
                                            y=graph.y)

    if draw_subgraph == True:
        draw_with_color(graph_i)
    return graph_i


In [None]:
torch.zeros(1,1)

In [None]:
def run_shapley(graph,level=1,shapley_storage=None,skip_to_round=1,previous_shapley=None): 
    """Find nodes whose removal (alone, in pairs, in triples) changes the model output.

    Round ``r`` (1, 2 or 3) removes every set of ``r`` distinct nodes with
    ``new_subgraph(..., exclude_node_lst=...)`` and re-evaluates the global
    ``model``; the third element of the model output is compared with that of
    the full graph (presumably the predicted label — TODO confirm). Every node
    of a set that changes the output is recorded, and the per-node occurrence
    counts of a round are stored under ``"rank<k>"``, where ``k`` counts the
    rounds that produced at least one hit.

    Args:
        graph: Graph to analyse.
        level: Stop as soon as this many rounds have produced results.
        shapley_storage: Optional ``(n, n, n)`` tensor to write hits into;
            a zero tensor is created when omitted.
        skip_to_round: First round to run (1, 2 or 3); other values run nothing.
        previous_shapley: Unused; kept for interface compatibility.

    Returns:
        ``(shapley_dict, shapley_storage)``. Unlike before, this pair is also
        returned when ``level`` is not reached (previously ``None``); a round
        without hits leaves an empty dict under its key.
    """
    graph_label = model(graph)[2]
    node_num = graph.x.size()[0]
    if shapley_storage is None:
        shapley_storage = torch.zeros((node_num, node_num, node_num))
    shapley_dict = {}
    rank = 1

    def index_tuples(upper, size):
        # Yields (i,), (i, j) with i > j, (i, j, k) with i > j > k, ... in the
        # same order as the original nested range() loops.
        if size == 0:
            yield ()
            return
        for head in range(upper):
            for tail in index_tuples(head, size - 1):
                yield (head,) + tail

    rounds = range(skip_to_round, 4) if skip_to_round in (1, 2, 3) else range(0)
    for round_no in rounds:
        key = "rank" + str(rank)
        flipped = []
        for removed in index_tuples(node_num, round_no):
            graph_i = new_subgraph(graph=graph, include_node_lst=None,
                                   exclude_node_lst=list(removed), add_sudo_nodes=False)
            # Evaluate once per subgraph (the original called the model up to three times).
            prediction = model(graph_i)[2]
            if round_no < 3:
                changed = (graph_label - prediction) != 0
            else:
                # NOTE(review): round 3 has always tested `prediction < 1` rather
                # than "differs from the full-graph output"; kept as is.
                changed = prediction < 1
            if changed:
                # Hits are stored at [i,0,0], [i,j,0] or [i,j,k] depending on the round.
                shapley_storage[removed + (0,) * (3 - round_no)] = 1
                flipped.extend(removed)
        if flipped:
            shapley_dict[key] = shapley_dict_count(flipped)
            if rank == level:
                return shapley_dict, shapley_storage
            rank += 1
        else:
            shapley_dict[key] = {}
            print("No results after the "+str(round_no)+" round")
    return shapley_dict, shapley_storage

In [None]:
def run_shapley_specified(graph,level=1,shapley_storage=None,skip_to_round=1,previous_shapley=None,node_list=None): 
    """Variant of ``run_shapley`` restricted to the candidate nodes in ``node_list``.

    Round ``r`` (1, 2 or 3) removes every set of ``r`` distinct candidates with
    ``new_subgraph(..., exclude_node_lst=...)`` and re-evaluates the global
    ``model``; the third element of the model output is compared with that of
    the full graph (presumably the predicted label — TODO confirm).

    Args:
        graph: Graph to analyse.
        level: Stop as soon as this many rounds have produced results.
        shapley_storage: Optional tensor indexed by node id; when omitted a
            zero tensor of shape ``(n, n, n)`` with ``n`` = number of graph
            nodes is created, so any node id is a valid index.
        skip_to_round: First round to run (1, 2 or 3); other values run nothing.
        previous_shapley: Unused; kept for interface compatibility.
        node_list: Candidate node ids. Defaults to all nodes of ``graph``.
            (The former default captured the notebook global ``node_lst`` at
            definition time, which fails on a fresh kernel.)

    Returns:
        ``(shapley_dict, shapley_storage)``; also returned when ``level`` is
        not reached. A round without hits leaves an empty dict under its key.
    """
    graph_label = model(graph)[2]
    total_nodes = graph.x.size()[0]
    if node_list is None:
        node_list = list(range(total_nodes))
    if shapley_storage is None:
        shapley_storage = torch.zeros((total_nodes, total_nodes, total_nodes))
    shapley_dict = {}
    rank = 1

    def position_tuples(upper, size):
        # Strictly decreasing position tuples (t,), (t, s), (t, s, w) in
        # nested-loop order. Round 3 used to loop over range(i)/range(j), i.e.
        # over node ids left over from earlier loops, instead of positions.
        if size == 0:
            yield ()
            return
        for head in range(upper):
            for tail in position_tuples(head, size - 1):
                yield (head,) + tail

    rounds = range(skip_to_round, 4) if skip_to_round in (1, 2, 3) else range(0)
    for round_no in rounds:
        key = "rank" + str(rank)
        flipped = []
        for positions in position_tuples(len(node_list), round_no):
            removed = tuple(node_list[pos] for pos in positions)
            graph_i = new_subgraph(graph=graph, include_node_lst=None,
                                   exclude_node_lst=list(removed), add_sudo_nodes=False)
            prediction = model(graph_i)[2]
            if round_no < 3:
                changed = (graph_label - prediction) != 0
            else:
                # NOTE(review): round 3 has always tested `prediction < 1`; kept as is.
                changed = prediction < 1
            if changed:
                shapley_storage[removed + (0,) * (3 - round_no)] = 1
                flipped.extend(removed)
        if flipped:
            shapley_dict[key] = shapley_dict_count(flipped)
            if rank == level:
                return shapley_dict, shapley_storage
            rank += 1
        else:
            shapley_dict[key] = {}
            print("No results after the "+str(round_no)+" round")
    return shapley_dict, shapley_storage

In [None]:
graph=graph_test_1
i=5
j=4
graph_i=new_subgraph(graph=graph,include_node_lst=None,exclude_node_lst=[i,j],add_sudo_nodes=False)    
shapley_i=1-model(graph_i)[2] 
shapley_i

In [None]:
def th_delete(tensor, indices):
    """Return ``tensor`` without the entries selected by ``indices``.

    A boolean keep-mask with ``tensor.numel()`` entries is built, the given
    positions are switched off, and the mask is used to index ``tensor``.
    """
    keep = torch.full((tensor.numel(),), True, dtype=torch.bool)
    keep[indices] = False
    return tensor[keep]

In [None]:
def th_preserve(tensor, indices):
    """Return only the entries of ``tensor`` selected by ``indices``.

    A boolean mask with ``tensor.numel()`` entries is built, the given
    positions are switched on, and the mask is used to index ``tensor``.
    """
    selected = torch.full((tensor.numel(),), False, dtype=torch.bool)
    selected[indices] = True
    return tensor[selected]

In [None]:
def draw_with_color(graph,threshold=thres,isolates=True,exist_sudo=False,extra_node=False): 
    """Draw ``graph`` with networkx, colouring nodes by out-degree.

    Nodes whose out-degree (count of occurrences in ``edge_index[0]``) exceeds
    ``threshold`` are red, the others light blue; nodes with id
    ``>= extra_node`` are yellow when ``extra_node`` is given.

    Args:
        graph: Graph with ``x`` and ``edge_index``.
        threshold: Degree above which a node is drawn red. This parameter is
            now honoured; the body used to read the global ``thres`` instead.
        isolates: If ``True``, isolated nodes are removed before drawing.
        exist_sudo: If ``True``, the last node (id ``x.shape[0]-1``) is removed
            before drawing.
        extra_node: First id of artificially added nodes, or ``False``.
    """
    out_degree = torch.bincount(graph.edge_index[0,:]).tolist()
    has_extra = extra_node != False

    def is_extra(node):
        # `>=`: added nodes are numbered starting at `extra_node` itself
        # (the previous `>` left the first added node uncoloured).
        return has_extra and node >= extra_node

    # Same diagnostic print as before: non-zero degrees, with -1 marking extra nodes.
    shown = [-1 if is_extra(node) else degree for node, degree in enumerate(out_degree)]
    shown = [degree for degree in shown if degree != 0]
    print(shown)

    def colour_of(node):
        if is_extra(node):
            return "yellow"
        # bincount stops at the largest source id; later nodes have degree 0.
        degree = out_degree[node] if node < len(out_degree) else 0
        return "red" if degree > threshold else "lightblue"

    g_i = torch_geometric.utils.to_networkx(graph)

    if isolates==True: 
        g_i.remove_nodes_from(list(nx.isolates(g_i))) 
    if exist_sudo==True: 
        print("sudo",graph.x.shape[0]-1)
        g_i.remove_nodes_from([graph.x.shape[0]-1])
    # Colours are looked up per remaining node, so they stay aligned with the
    # drawn nodes even when isolates are kept.
    nx.draw(g_i,with_labels=True,node_color=[colour_of(node) for node in g_i.nodes()])

In [None]:
color_index_i=torch.bincount(graph_update.edge_index[0,:]).tolist()
color_index_i

In [None]:
extra_node=10
[color_index_i[i] if (i<extra_node) else "yellow" for i in range(len(color_index_i)) ]

In [None]:
color_index_i=[i for i in color_index_i if i != 0]
color_index_i

In [None]:
graph_test_1 = tu2[6]
# graph_test_1 = tu2[7] #0,6,7,16,17,19,20,25,26,30,31,32
# graph_test_1=new_subgraph(graph_test_1,add_node=[0,1,2,4,5,6,7])
draw_with_color(graph_test_1)

In [None]:
graph_test_1 = tu2[7]
draw_with_color(graph_test_1)

In [None]:
shapley_add_node,shapley_add_node_storage=run_shapley(graph_test_1,level=2)
# shapley_add_node

In [None]:
shapley_add_node_storage[5,4,0]

In [None]:
shapley_add_node

In [None]:
set(list(shapley_add_node['rank1'].keys())+list(shapley_add_node['rank2'].keys()))

In [None]:
node_lst=list(set( val for dic in  shapley_add_node.values() for val in dic.keys()))
# seperate_1=new_subgraph(graph_test_1,list(set(list(shapley_add_node['rank1'].keys()))),draw_subgraph=True)
seperate_1=new_subgraph(graph_test_1,node_lst,draw_subgraph=True)

In [None]:
cut_shapley=excl_low_value_nodes(shapley_add_node['rank2'],node_threshold=seperate_1.x.shape[0]*.5)
cut_seperate_1=new_subgraph(graph_test_1,list(cut_shapley.keys()),draw_subgraph=True)

In [None]:
cut_shapley=excl_low_value_nodes(shapley_add_node['rank2'],excl_round=2,node_threshold=seperate_1.x.shape[0]*.3)
cut_seperate_1=new_subgraph(graph_test_1,list(cut_shapley.keys()),draw_subgraph=True)

In [None]:
def generate_graph(node_num,max_degree,gseed,draw_graph=False,star=False):
    """Generate a random graph from a degree sequence and convert it to PyG data.

    Args:
        node_num: Number of rows of the random feature matrix ``x``; in the
            non-star case also the length of the degree sequence.
        max_degree: Degree assigned to one randomly chosen node.
        gseed: Seed for NumPy's global RNG and for the networkx generator.
        draw_graph: If ``True``, draw the graph with ``draw_with_color``.
        star: If ``True``, use a degree sequence of ``max_degree + 1`` entries
            that are all 1 except one entry set to ``max_degree``.

    Returns:
        The generated graph with ``x`` set to ``torch.rand(node_num, 1)``.

    Note:
        NOTE(review): with ``star=True`` the degree sequence has
        ``max_degree + 1`` entries regardless of ``node_num``, while ``x``
        still has ``node_num`` rows — confirm callers pass
        ``node_num == max_degree + 1`` or rely on the extra rows.
    """
    # Local import: the notebook imports from_networkx only in a much later
    # cell, so calling this function on a fresh kernel raised NameError.
    from torch_geometric.utils.convert import from_networkx

    np.random.seed(gseed)
    if star==True: 
        degree_lst=np.ones(max_degree+1)
        degree_lst[np.random.randint(max_degree)]=max_degree 
        print(degree_lst)
    else: 
        degree_lst=np.random.randint(max_degree-3, size=node_num)
        degree_lst[np.random.randint(node_num)]=max_degree 
    G = nx.random_degree_sequence_graph(degree_lst,gseed)
    G = from_networkx(G)
    G.x=torch.rand(node_num,1)
    if draw_graph==True: 
        draw_with_color(G)
    return G 

In [None]:
def seperate_graph(graph,node_to_excl,draw_graph=False): 
    """Detach ``node_to_excl`` by rewiring each of its edges to a fresh node.

    The edge index is sorted, every occurrence of ``node_to_excl`` is located,
    and the i-th occurrence in the source row and the i-th occurrence in the
    target row are both renamed to a new node id ``graph.x.shape[0] + i``.
    One zero feature row is appended per new node.

    Returns:
        A new ``Data`` object holding the extended ``x`` and rewired edges.
    """
    sorted_edges = sort_edge_index(graph.edge_index)
    first_new_id = graph.x.shape[0]
    # Rows of `hits` are (row, column) positions; the count is halved to get
    # the number of new nodes, with source-row hits listed before target-row hits.
    hits = (sorted_edges == node_to_excl).nonzero()
    rewired = sorted_edges.clone()
    node_degree = int((hits.shape[0])/2)
    for offset in range(node_degree):
        replacement_id = first_new_id + offset
        rewired[0, hits[offset, 1]] = replacement_id
        rewired[1, hits[offset + node_degree, 1]] = replacement_id
    graph_update = Data(x=torch.cat((graph.x, torch.zeros(node_degree, 1))),
                        edge_index=rewired)
    if draw_graph==True: 
        draw_with_color(graph_update, extra_node=graph.x.shape[0])
    return graph_update

In [None]:
generate_graph(11,7,203,draw_graph=True,star=True)

In [None]:
gg_star_1=generate_graph(8,7,203,draw_graph=True,star=True)
gg_star_1

In [None]:
gg_star_1.x

In [None]:
to_networkx(gg_star_1).nodes[0]

In [None]:
calculate_shapley(gg_star_1) 

In [None]:
gg_star_2=generate_graph(9,8,203,draw_graph=True,star=True)
gg_star_2

In [None]:
calculate_shapley(gg_star_2)

In [None]:
gg_star_2.edge_index[0]+8

In [None]:
torch.cat((gg_star_1.x,gg_star_2.x)),gg_star_1.x,gg_star_2.x

In [None]:
# Merge the two star graphs into one disconnected graph. The second star's
# node ids are shifted by the first star's node count (previously a hard-coded 8)
# so the two components stay disjoint whatever size gg_star_1 has.
node_offset=gg_star_1.x.shape[0]
edge_index_new_0=torch.cat((gg_star_1.edge_index[0],gg_star_2.edge_index[0]+node_offset))
edge_index_new_1=torch.cat((gg_star_1.edge_index[1],gg_star_2.edge_index[1]+node_offset))
edge_index_new=torch.stack((edge_index_new_0,edge_index_new_1))
x_new=torch.cat((gg_star_1.x,gg_star_2.x))
gg_star_3 = Data(x=x_new, edge_index=edge_index_new)
draw_with_color(gg_star_3)

In [None]:
calculate_shapley(gg_star_3)

In [None]:
shapley_star_2,shapley_star_2_storage=run_shapley(gg_star_2,level=2)
shapley_star_2

In [None]:
shapley_star_3,shapley_star_3_storage=run_shapley(gg_star_3,level=2)
shapley_star_3

In [None]:
# shapley_star_3["rank1"]/(gg_star_3.x.shape[0])
a = {k: v / int(gg_star_2.x.shape[0])/(int(gg_star_2.x.shape[0])-1) for k, v in shapley_star_2["rank1"].items()}
a
# b = {k: v *2/ int(gg_star_2.x.shape[0])/(int(gg_star_2.x.shape[0])-1)/(int(gg_star_2.x.shape[0])-2) for k, v in shapley_star_2["rank2"].items()}
# update_value=a.copy()
# update_value=Counter(update_value)+Counter(b)
# update_value

In [None]:
# shapley_star_3["rank1"]/(gg_star_3.x.shape[0])
a = {k: v / int(gg_star_2.x.shape[0]) for k, v in shapley_star_2["rank1"].items()}
b = {k: v *2/ int(gg_star_2.x.shape[0])/(int(gg_star_2.x.shape[0])-1) for k, v in shapley_star_2["rank2"].items()}
update_value=a.copy()
update_value=Counter(update_value)+Counter(b)
update_value

In [None]:
# shapley_star_3["rank1"]/(gg_star_3.x.shape[0])
a = {k: v / int(gg_star_3.x.shape[0])/(int(gg_star_3.x.shape[0])-1) for k, v in shapley_star_3["rank1"].items()}
a

In [None]:
a[12]-a[2]

In [None]:
# shapley_star_3["rank1"]/(gg_star_3.x.shape[0])
b = {k: v *2/ int(gg_star_3.x.shape[0])/(int(gg_star_3.x.shape[0])-1)/(int(gg_star_3.x.shape[0])-2) for k, v in shapley_star_3["rank2"].items()}
b

In [None]:
from collections import Counter

In [None]:
update_value=a.copy()
update_value=Counter(update_value)+Counter(b)
update_value

In [None]:
sum(list(update_value.values()))

In [None]:
shapley_star_3.items()

In [None]:
# random.seed(12)
# generate_graph(11,7,116,draw_graph=True)

In [None]:
# gg=generate_graph(11,7,116,draw_graph=True)
gg=generate_graph(11,7,203,draw_graph=True,star=True)
edge_index_update=gg.edge_index
edge_index_update=sort_edge_index(edge_index_update)
edge_index_update

In [None]:
seperate_graph(gg,4,draw_graph=True)

In [None]:
gg.x.shape[0]

In [None]:
dc={}
dc[1]=0
dc[1]+=.1
dc

In [None]:
def calculate_shapley(graph,model=model): 
    """Compute exact Shapley values of all nodes by enumerating every coalition.

    The value of a coalition is the third element of ``model``'s output on the
    subgraph that keeps only edges between coalition members (built with
    ``new_subgraph(..., include_node_lst=...)``). Marginal contributions are
    truncated with ``int()`` and weighted by ``k! (n-k-1)! / n!`` for a
    coalition of size ``k``. The empty coalition is taken to have value 0.

    The cost is exponential in the number of nodes; use on small graphs only.

    Args:
        graph: Graph whose nodes are the players.
        model: Callable evaluated on each subgraph. Values are memoised per
            coalition, which assumes the model is deterministic.

    Returns:
        ``{node_index: shapley_value}`` for every node of ``graph``.
    """
    # Local imports: the notebook imports these names only in later cells.
    import itertools
    from math import factorial

    n = graph.x.shape[0]
    players = list(range(n))
    weights = [factorial(k)*factorial(n-k-1)/factorial(n) for k in range(n)]

    # Each coalition is needed many times (as "C" and as "C + [j]" for several
    # j); evaluate the model only once per distinct member set.
    value_cache = {}

    def coalition_value(members):
        key = tuple(sorted(members))
        if key not in value_cache:
            sub = new_subgraph(graph,include_node_lst=list(members),exclude_node_lst=None,add_sudo_nodes=False)
            value_cache[key] = model(sub)[2]
        return value_cache[key]

    phi = {}
    for j in range(n):
        others = players.copy()
        others.pop(j)
        phi[j] = 0
        for k in range(n):
            if k == 0:
                # Contribution of j joining the empty coalition.
                phi[j] += weights[k]*int(coalition_value([j]))
            else:
                for combo in itertools.combinations(others, k):
                    set_C = list(combo)
                    v_C = coalition_value(set_C)
                    v_C_j = coalition_value(set_C+[j])
                    phi[j] += weights[k]*int(v_C_j-v_C)
    return phi 
        
    

In [None]:
list(itertools.combinations(lst_minus_j, k))

In [None]:
calculate_shapley(gg)

calculate_shapley(graph_test_1)

In [None]:
n=gg.x.shape[0]
x_lst=[i for i in range(n)]
x_lst

In [None]:
[1,2]+[4]

In [None]:
factorial(5)

In [None]:
import itertools
from math import factorial

In [None]:
print(list(itertools.combinations(x_lst, 2)))

In [None]:
perm_lst=list(itertools.combinations(x_lst, 2))
list(perm_lst[1])

In [None]:
x_j_lst=x_lst.copy()
x_j_lst.pop(2)
x_lst,x_j_lst

In [None]:
graph_i=new_subgraph(gg,include_node_lst=list(perm_lst[1]),exclude_node_lst=None,add_sudo_nodes=False)
int(model(graph_i)[2]-model(gg)[2])

In [None]:
sort_edge_index(edge_index_update)

In [None]:
gg.x.shape[0]

In [None]:
node_val=7
label_max=gg.x.shape[0]
edge_index_id=(edge_index_update == node_val).nonzero()
edge_index_id

In [None]:
edge_index_copy=edge_index_update.clone()
node_degree=int((edge_index_id.shape[0])/2)
print(node_degree)
for i in range(node_degree): 
    # print((0,edge_index_id[i,1]),edge_index_copy[0,edge_index_id[i,1]])
    # print((1,edge_index_id[i+node_degree,1]),edge_index_copy[1,edge_index_id[i+node_degree,1]])
    edge_index_copy[0,edge_index_id[i,1]]=gg.x.shape[0]+i
    edge_index_copy[1,edge_index_id[i+node_degree,1]]=gg.x.shape[0]+i
edge_index_copy

In [None]:
x_new=torch.cat((gg.x,torch.zeros(node_degree,1)))
graph_update=Data(x=x_new, edge_index=edge_index_copy)
draw_with_color(graph_update,extra_node=gg.x.shape[0])

In [None]:
model(graph_update)[2]

In [None]:
def run_shapley_mini(graph,level=1,shapley_storage=None,skip_to_round=1,previous_shapley=None,node_list=None): 
    """Single-node sensitivity check that splits a node instead of removing it.

    For every candidate node the graph is rewired with ``seperate_graph`` and
    the global ``model`` is re-evaluated; nodes for which the third element of
    the model output differs from that of the full graph are recorded.
    Only round 1 (single nodes) is implemented.

    Args:
        graph: Graph to analyse.
        level: Kept for interface compatibility with ``run_shapley``; with a
            single round the result is returned after round 1 whatever its value.
        shapley_storage: Optional tensor indexed by node id; when omitted a
            zero tensor of shape ``(n, n, n)`` with ``n`` = number of graph
            nodes is created.
        skip_to_round: The round is only executed when this equals 1.
        previous_shapley: Unused; kept for interface compatibility.
        node_list: Candidate node ids. Defaults to all nodes of ``graph``.
            (The former default captured the notebook global ``node_lst`` at
            definition time, which fails on a fresh kernel.)

    Returns:
        ``(shapley_dict, shapley_storage)`` where ``shapley_dict["rank1"]``
        maps each recorded node to its occurrence count (empty dict when no
        node was recorded). Previously ``None`` was returned unless round 1
        had hits and ``level == 1``.
    """
    graph_label = model(graph)[2]
    total_nodes = graph.x.size()[0]
    if node_list is None:
        node_list = list(range(total_nodes))
    if shapley_storage is None:
        # Indexed by node id below, so size by the graph, not by len(node_list).
        shapley_storage = torch.zeros((total_nodes, total_nodes, total_nodes))
    shapley_dict = {}

    if skip_to_round == 1:
        flipped = []
        for node in node_list:
            graph_i = seperate_graph(graph=graph, node_to_excl=node)
            if graph_label - model(graph_i)[2] != 0:
                shapley_storage[node,0,0] = 1
                flipped.append(node)
        if flipped:
            shapley_dict["rank1"] = shapley_dict_count(flipped)
        else:
            shapley_dict["rank1"] = {}
            print("No results after the "+str(skip_to_round)+" round")

    # Rounds 2 and 3 (pairs / triples) existed only as commented-out copies of
    # run_shapley_specified and are not implemented for the split strategy.
    return shapley_dict, shapley_storage

In [None]:
run_shapley(gg)[0]

In [None]:
run_shapley_mini(gg,level=1)[0]

In [None]:
seperate_1=new_subgraph(gg,list(run_shapley(gg)[0]["rank1"].keys()),draw_subgraph=True)

In [None]:
from torch_geometric.utils.convert import from_networkx

# Sample a random graph with the requested degree sequence, convert it to a
# PyG object and attach one random scalar feature per node.
sequence = [1, 2, 2, 3]
G = from_networkx(nx.random_degree_sequence_graph(sequence))
G.x = torch.rand(len(sequence), 1)
G.x

In [None]:
# Draw the graph `G` with the notebook's draw_with_color helper.
draw_with_color(G)

## Pablo Algo 

In [None]:
# Size weights (n-k)! (k-1)! / n! for k = 1..n.
# `n` and `factorial` must already exist in the kernel namespace.
fac_w = []
for k in range(1, n + 1):
    weight_k = factorial(n - k) * factorial(k - 1) / factorial(n)
    fac_w.append(weight_k)
fac_w

In [None]:
# Total of the weights in fac_w (one entry per k = 1..n).
sum(fac_w)

In [None]:
def partial_shapley(graph,K=None): 
    """Score nodes by how often removing them flips the model output to 0.

    The search proceeds level by level. Level ``k`` takes every set of ``k``
    nodes whose removal still leaves ``int(model(...)[2])`` non-zero, removes
    one additional node ``i`` and, if the output becomes 0, credits ``i`` with
    the Shapley size weight ``k! (n-k-1)! / n!``. Removal sets that keep a
    non-zero output are carried to level ``k+1``; sets that already flipped
    are not extended any further.

    Args:
        graph: object whose ``graph.x.shape[0]`` is the number of nodes ``n``;
            it is forwarded unchanged to ``new_subgraph``.
        K: number of levels to run (removal sets of size ``1..K`` are
            evaluated). ``None`` runs all ``n`` levels. Values above ``n`` are
            clamped to ``n`` because larger removal sets do not exist.

    Returns:
        dict mapping every node index ``0..n-1`` to its accumulated score.
    """
    n = graph.x.shape[0]
    x_lst = list(range(n))
    # Clamping also keeps factorial(n-k-1) away from negative arguments.
    K = n if K is None else min(K, n)
    score_dict = dict(zip(x_lst, np.zeros(n)))

    def prediction_flips(removed):
        """Return True if the model outputs 0 once the nodes in `removed` are dropped."""
        if len(removed) == 1:
            # Single-node removals are expressed through the include list.
            (node,) = removed
            kept = [j for j in x_lst if j != node]
            sub = new_subgraph(graph, include_node_lst=kept, exclude_node_lst=None, add_sudo_nodes=False)
        else:
            sub = new_subgraph(graph, include_node_lst=None, exclude_node_lst=sorted(removed), add_sudo_nodes=False)
        return int(model(sub)[2]) == 0

    # Removal sets are frozensets so that {a, b} reached as a->b and as b->a is
    # one coalition: it is evaluated once and extended once. Keeping ordered
    # lists would re-count the same subset once per ordering, which does not
    # match the subset-based weight w_k.
    frontier = [frozenset()]
    for k in range(K):
        w_k = factorial(n - k - 1) * factorial(k) / factorial(n)
        flipped = {}  # removal set -> did the output drop to 0?
        next_frontier = []
        for removed in frontier:
            for i in x_lst:
                if i in removed:
                    continue
                candidate = removed | {i}
                if candidate not in flipped:
                    flipped[candidate] = prediction_flips(candidate)
                    if not flipped[candidate]:
                        next_frontier.append(candidate)
                # Every distinct (removed, i) pair is its own marginal
                # contribution, even when several pairs share `candidate`.
                if flipped[candidate]:
                    score_dict[i] += w_k
        frontier = next_frontier

    return score_dict

In [None]:
# Draw gg_star_2 (defined outside this section) before scoring it.
draw_with_color(gg_star_2)

In [None]:
# K is left at None, so all levels are run (K defaults to the node count).
partial_shapley(gg_star_2)

In [None]:
# K=1: only level k=0, i.e. single-node removals.
partial_shapley(gg_star_2,K=1)

In [None]:
# K=3: levels k=0..2.
partial_shapley(gg_star_2,K=3)

In [None]:
# Sum of all node scores for the full run on gg_star_2.
sum(list(partial_shapley(gg_star_2).values()))

In [None]:
# Full run on gg_star_1.
partial_shapley(gg_star_1)

In [None]:
# gg_star_3 with K=4 (the same call appears again a few cells below).
partial_shapley(gg_star_3,K=4)

In [None]:
partial_shapley(gg_star_3,K=3)

In [None]:
# Score total for K=3; each of these cells recomputes the scores from scratch.
sum(list(partial_shapley(gg_star_3,K=3).values()))

In [None]:
# NOTE(review): duplicate of the K=4 cell above.
partial_shapley(gg_star_3,K=4)

In [None]:
sum(list(partial_shapley(gg_star_3,K=4).values()))

In [None]:
partial_shapley(gg_star_3,K=1)

In [None]:
partial_shapley(gg_star_3,K=2)

In [None]:
# Same scoring on graph_test_1 with three levels.
partial_shapley(graph_test_1,K=3)

In [None]:
# Take element 7 of `tu2` (defined outside this section) as a second test graph.
graph_test_2 = tu2[7]
# draw_with_color(graph_test_1)
partial_shapley(graph_test_2,K=3)

In [None]:
# Score total on gg_star_3 with K=6 levels requested.
sum(list(partial_shapley(gg_star_3,K=6).values()))

In [None]:
# Scratch cells that step through the scoring loop by hand on graph_test_1.
graph=graph_test_1

In [None]:
# Sanity check of the dict(zip(keys, zeros)) idiom used to initialise score_dict.
lst1=[1,2,3]
lst2=np.zeros(len(lst1))
dict(zip(lst1,lst2))

In [None]:
# NOTE(review): within this section `n` is only assigned in the next cell, so this
# display depends on a value left in the kernel by an earlier execution.
n

In [None]:
# Module-level k, n and node index list used by the following scratch cells.
k=1
n=graph.x.shape[0]
x_lst=[i for i in range(n)]
x_lst

In [None]:
# One zero-initialised score per node index.
score_dict=dict(zip(x_lst,np.zeros(len(x_lst))))
score_dict

In [None]:
# Seed level k-1 with every k-element combination of node indices (as tuples).
pos_G_k={}
pos_G_k[k-1]=list(itertools.combinations(x_lst, k))
pos_G_k[k-1]

In [None]:
# NOTE(review): within this section f_U_i is only assigned by the loop in the next
# cell, so this cell fails on a fresh top-to-bottom run unless it is set earlier.
int(f_U_i)

In [None]:
# Hand-run of one level using the module-level k, n, x_lst, score_dict and pos_G_k.
# Unlike partial_shapley, this version keeps ONLY the nodes in U_i (include_node_lst),
# credits a node when the output equals 1, and uses the weight (n-k)! (k-1)! / n!.
# The loop variables (U, i, U_i, graph_U_i, f_U_i, w_k) stay in the kernel namespace
# and are read by other cells.
pos_G_k[k]=[]
# k=n
for U in pos_G_k[k-1]:
    # print(list(U)) 
    U=list(U) 
    w_k=factorial(n-k)*factorial(k-1)/factorial(n)
    # Try every node that is not already part of U.
    for i in list(set(x_lst)-set(U)): 
        # print(i)
        U_i=U+[i]
        graph_U_i=new_subgraph(graph,include_node_lst=U_i,exclude_node_lst=None,add_sudo_nodes=False)
        f_U_i=model(graph_U_i)[2] 
        
        if int(f_U_i)==int(1): 
            # print(f_U_i)
            score_dict[i]=score_dict[i]+w_k
        # Otherwise remember U_i (as an ordered list) for the next level.
        else: pos_G_k[k].append(U_i)
pos_G_k[k]

In [None]:
# Rebuild and draw the subgraph for whatever U_i currently holds in the kernel
# (after the loop cell above, that is its last iteration).
graph_U_i=new_subgraph(graph,include_node_lst=U_i,exclude_node_lst=None,add_sudo_nodes=False)
draw_with_color(graph_U_i)

In [None]:
# Inspect the module-level score_dict.
score_dict

In [None]:
# Three random integers drawn from [3, 5).
torch.randint(3, 5, (3,))

In [None]:
graph_test_1

In [None]:
seperate_1

In [None]:
# Run run_shapley_specify on element 0 of seperate_1, restricted to the keys of element 1.
run_shapley_specify(seperate_1[0],spec_node=list(seperate_1[1].keys()))

In [None]:
def backward_rank(graph,levels=2): 
    """Re-test the nodes stored under "rank1" after moving them to "rank<levels>".

    Operates on the module-level ``shapley_dict`` and ``shapley_storage`` (they are
    not parameters). The list under ``"rank1"`` is re-bound to ``"rank" + str(levels)``
    and ``"rank1"`` is reset to an empty list. For every node ``i`` in the moved list
    a subgraph without ``i`` is built (``add_sudo_nodes=True``); when
    ``model(graph_i)[2] > 0`` the node is printed, ``shapley_storage[i,:,:,:,:]`` is
    set to 1 and ``i`` is appended to ``"rank" + str(levels-1)``.

    Returns None (there is no return statement).

    NOTE(review):
      * ``level=levels-1`` assigns a local that is never read; ``levels`` itself
        never changes, so for ``levels > 1`` the ``break`` is never reached.
      * With ``levels == 1`` the source and target key are both "rank1", which has
        just been reset to ``[]``, so the loop body never runs.
      * With ``levels > 2`` the append requires "rank<levels-1>" to already exist
        in ``shapley_dict``.
      Intended behaviour is unclear from this cell — confirm before relying on it.
    """
    shapley_dict["rank"+str(levels)]=shapley_dict["rank1"]
    shapley_dict["rank1"]=[]
    # graph_with_sudo = new_subgraph(graph=graph,exclude_node_lst=[0],add_sudo_nodes=True)
    for i in shapley_dict["rank"+str(levels)]:
        graph_i = new_subgraph(graph=graph,exclude_node_lst=[i],add_sudo_nodes=True)
    # global round
        # subset = torch.ones_like(graph_with_sudo.edge_index[0], dtype = bool)
        # subset[[i]] = False
        # # print(subset)
        # # graph_i=torch_geometric.utils.subgraph(subset,graph_test_1.edge_index)
        # graph_i = torch_geometric.data.Data(x=th_delete(graph_with_sudo.x,[i]),
        #                                     edge_index=torch_geometric.utils.subgraph(subset,graph_with_sudo.edge_index,relabel_nodes=True)[0],
        #                                     y=graph_with_sudo.y)
        shapley_i=model(graph_i)[2] 
        # print(shapley_i)
        if shapley_i>0: 
            print(i)
            shapley_storage[i,:,:,:,:]=1
            shapley_dict["rank"+str(levels-1)].append(i)
        if levels<=1:  
            break 
        else: level=levels-1 
# Outside the function: display the module-level dict. backward_rank is not called in this cell.
shapley_dict

## Conclusion

In this tutorial, we have seen the application of neural networks to graph structures. We looked at how a graph can be represented (adjacency matrix or edge list), and discussed the implementation of common graph layers: GCN and GAT. The implementations showed the practical side of the layers, which is often easier than the theory. Finally, we experimented with different tasks at the node, edge, and graph level. Overall, we have seen that including graph information in the predictions can be crucial for achieving high performance. There are a lot of applications that benefit from GNNs, and the importance of these networks will likely increase over the next years.

---

[![Star our repository](https://img.shields.io/static/v1.svg?logo=star&label=⭐&message=Star%20Our%20Repository&color=yellow)](https://github.com/phlippe/uvadlc_notebooks/)  If you found this tutorial helpful, consider ⭐-ing our repository.    
[![Ask questions](https://img.shields.io/static/v1.svg?logo=star&label=❔&message=Ask%20Questions&color=9cf)](https://github.com/phlippe/uvadlc_notebooks/issues)  For any questions, typos, or bugs that you found, please raise an issue on GitHub. 

---