In [1]:
import sys
sys.path.append('..')
import os
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import joblib
from matplotlib import rc

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, dense_diff_pool
import torch_explain as te
from torch_explain.logic.nn import entropy
from torch_explain.logic.metrics import test_explanation, complexity

import numpy as np
import pandas as pd
from pytorch_lightning.utilities.seed import seed_everything
from scipy.spatial.distance import cdist
from sympy import to_dnf, lambdify
from sklearn.metrics.cluster import homogeneity_score, completeness_score
from sklearn.tree import DecisionTreeClassifier

import clustering_utils
import data_utils
import lens_utils
import model_utils
import persistence_utils
import visualisation_utils
print(os.getcwd())

/Users/luciecharlottemagister/Documents/Cambridge/PhD/Projects/NEWRON_LEN/src/postprocessing


In [2]:
experiment_names = glob.glob('../../output2/Standard**')
experiment_names = [experiment_name.split('Standard_')[1] for experiment_name in experiment_names]
experiment_names = [x for x in experiment_names if '.zip' not in x]
experiment_names += ['Sigmoid_BA_Shapes', 'Sigmoid_BA_Community', 'BA_Shapes', 'BA_Community']
experiment_names

['GINConv_BA_Shapes',
 'SAGEConv_BA_Community',
 'ChebConv_BA_Community',
 'SAGEConv_BA_Shapes',
 'ChebConv_BA_Shapes',
 'GINConv_BA_Community',
 'Sigmoid_BA_Shapes',
 'Sigmoid_BA_Community',
 'BA_Shapes',
 'BA_Community']

In [3]:
standard_experiment_names = []
for exp in experiment_names:
    standard_experiment_names.append(f"Standard_{exp}")
    
experiment_names = experiment_names + standard_experiment_names

experiment_names

['GINConv_BA_Shapes',
 'SAGEConv_BA_Community',
 'ChebConv_BA_Community',
 'SAGEConv_BA_Shapes',
 'ChebConv_BA_Shapes',
 'GINConv_BA_Community',
 'Sigmoid_BA_Shapes',
 'Sigmoid_BA_Community',
 'BA_Shapes',
 'BA_Community',
 'Standard_GINConv_BA_Shapes',
 'Standard_SAGEConv_BA_Community',
 'Standard_ChebConv_BA_Community',
 'Standard_SAGEConv_BA_Shapes',
 'Standard_ChebConv_BA_Shapes',
 'Standard_GINConv_BA_Community',
 'Standard_Sigmoid_BA_Shapes',
 'Standard_Sigmoid_BA_Community',
 'Standard_BA_Shapes',
 'Standard_BA_Community']

In [4]:
base_path = os.path.join("..", "..", "output2")

In [5]:
random_seeds = [42, 19, 76, 58, 92]

accuracy_table = np.zeros((len(random_seeds), len(experiment_names)))
node_count_table = np.zeros((len(random_seeds), len(experiment_names)))

graph_accuracy_table = np.zeros((len(random_seeds), 2))
graph_node_count_table = np.zeros((len(random_seeds), 2))

for row, seed in enumerate(random_seeds):
    for col, exp_name in enumerate(experiment_names):
        print(exp_name)
        path = os.path.join(base_path, exp_name, 'seed_' + str(seed))
        
        if "Mutagenicity" not in exp_name and "Reddit_Binary" not in exp_name:
            if "Standard_Sigmoid" in exp_name or "Standard_BA" in exp_name:
                continue
                
            data = persistence_utils.load_experiment(path, 'data.z')
            train_mask = data['train_mask']
            test_mask = data['test_mask']
            y = data['y']

            used_centroid_labels = persistence_utils.load_experiment(path, 'used_centroid_labels.z').reshape(-1, 1)        

            clf = DecisionTreeClassifier(random_state=seed)
            clf = clf.fit(used_centroid_labels[train_mask], y[train_mask])
            accuracy = clf.score(used_centroid_labels[test_mask], y[test_mask])
            num_nodes = clf.tree_.node_count
        else:
            DATASET = exp_name.replace("Standard_", "")
            print(DATASET)
        
            # load data
            graphs = data_utils.load_real_data(DATASET)
            train_loader, test_loader, full_train_loader, full_test_loader, full_loader, small_loader = data_utils.prepare_real_data(graphs, 0.8, 16, DATASET)

            feat = 40
            
            if "Standard_Sigmoid" in exp_name:
                continue
            
            if "Standard_" in exp_name:                
                train_data = next(iter(full_train_loader))
                test_data = next(iter(full_test_loader))
                y = torch.cat((train_data.y, test_data.y))
        
                expanded_train_y = data_utils.reshape_graph_to_node_data(train_data.y, train_data.batch)
                expanded_test_y = data_utils.reshape_graph_to_node_data(test_data.y, test_data.batch)
                expanded_y = torch.cat((expanded_train_y, expanded_test_y))
                
                class GCN(nn.Module):
                    def __init__(self, num_in_features, num_hidden_features, num_classes):
                        super(GCN, self).__init__()

                        self.conv0 = GCNConv(num_in_features, num_hidden_features)
                        self.conv1 = GCNConv(num_hidden_features, num_hidden_features)
                        self.conv2 = GCNConv(num_hidden_features, num_hidden_features)
                        self.conv3 = GCNConv(num_hidden_features, 10)

                        self.pool = model_utils.Pool()

                        # linear layers
                        self.linear = nn.Linear(10, num_classes)

                    def forward(self, x, edge_index, batch):
                        x = self.conv0(x, edge_index)
                        x = F.leaky_relu(x)

                        x = self.conv1(x, edge_index)
                        x = F.leaky_relu(x)

                        x = self.conv2(x, edge_index)
                        x = F.leaky_relu(x)

                        x = self.conv3(x, edge_index)
                        x = F.leaky_relu(x)

                        self.gnn_node_embedding = x

                        x = self.pool(x, batch)

                        self.gnn_graph_embedding = x

                        x = self.linear(x)

                        return F.log_softmax(x, dim=-1)                    

                model = GCN(graphs.num_node_features, feat, graphs.num_classes)
                model = persistence_utils.load_model(model, path, "model.z")
                _ = model(train_data.x, train_data.edge_index, train_data.batch)
                train_activation = model.gnn_node_embedding    
                _ = model(test_data.x, test_data.edge_index, test_data.batch)
                test_activation = model.gnn_node_embedding
                activation = torch.vstack((train_activation, test_activation)).detach().numpy()

                train_mask = np.zeros(activation.shape[0], dtype=bool)
                train_mask[:train_activation.shape[0]] = True
                test_mask = ~train_mask
                
                kmeans_model = persistence_utils.load_experiment(path, 'kmeans_model.z')
                used_centroid_labels = kmeans_model.predict(activation)

                clf = DecisionTreeClassifier(random_state=42)
                clf = clf.fit(used_centroid_labels[train_mask].reshape(-1, 1), expanded_y[train_mask])
                accuracy = clf.score(used_centroid_labels[test_mask].reshape(-1, 1), expanded_y[test_mask])
                num_nodes = clf.tree_.node_count
            else:
            
                # get model activations for complete dataset
                train_data = next(iter(full_train_loader))
                test_data = next(iter(full_test_loader))

                y = torch.cat((train_data.y, test_data.y))
                expanded_train_y = data_utils.reshape_graph_to_node_data(train_data.y, train_data.batch)
                expanded_test_y = data_utils.reshape_graph_to_node_data(test_data.y, test_data.batch)
                expanded_y = torch.cat((expanded_train_y, expanded_test_y))

                train_mask = np.zeros(y.shape[0], dtype=bool)
                train_mask[:train_data.y.shape[0]] = True
                test_mask = ~train_mask

                offset = train_data.batch[-1] + 1
                batch = torch.cat((train_data.batch, test_data.batch + offset))
                
                class GCN(nn.Module):
                    def __init__(self, num_in_features, num_hidden_features, num_classes):
                        super(GCN, self).__init__()

                        self.conv0 = GCNConv(num_in_features, num_hidden_features)
                        self.conv1 = GCNConv(num_hidden_features, num_hidden_features)
                        self.conv2 = GCNConv(num_hidden_features, num_hidden_features)
                        self.conv3 = GCNConv(num_hidden_features, 10)

                        self.pool = model_utils.Pool()

                        # linear layers
                        self.lens = torch.nn.Sequential(te.nn.EntropyLinear(10, 1, n_classes=num_classes))

                    def forward(self, x, edge_index, batch):
                        x = self.conv0(x, edge_index)
                        x = F.leaky_relu(x)

                        x = self.conv1(x, edge_index)
                        x = F.leaky_relu(x)

                        x = self.conv2(x, edge_index)
                        x = F.leaky_relu(x)

                        x = self.conv3(x, edge_index)
                        x = F.leaky_relu(x)

                        x = x.squeeze()

                        self.gnn_node_embedding = x

                        x = F.softmax(x, dim=-1)
                        x = torch.div(x, torch.max(x, dim=-1)[0].unsqueeze(1))
                        self.gnn_node_concepts = x

                        x = self.pool(x, batch)
                        self.gnn_graph_concepts = x

                        x = self.lens(x)

                        return self.gnn_node_concepts, x.squeeze(-1)

                print("None standard")
                model = GCN(graphs.num_node_features, feat, graphs.num_classes)
                model = persistence_utils.load_model(model, path, "model.z")
                train_node_concepts, _ = model(train_data.x, train_data.edge_index, train_data.batch)
                train_node_activation = model.gnn_node_embedding
                test_node_concepts, _ = model(test_data.x, test_data.edge_index, test_data.batch)
                test_node_activation = model.gnn_node_embedding
                node_activation = torch.vstack((train_node_activation, test_node_activation)).detach().numpy()
                node_concepts = torch.vstack((train_node_concepts, test_node_concepts))
                
                centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(node_activation, node_concepts, expanded_y)

                
                expanded_train_mask = data_utils.reshape_graph_to_node_data(train_mask, batch)
                expanded_test_mask = data_utils.reshape_graph_to_node_data(test_mask, batch)
                clf = DecisionTreeClassifier(random_state=42)
                clf = clf.fit(used_centroid_labels[expanded_train_mask].reshape(-1, 1), expanded_y[expanded_train_mask])
                accuracy = clf.score(used_centroid_labels[expanded_test_mask].reshape(-1, 1), expanded_y[expanded_test_mask])
                num_nodes = clf.tree_.node_count
                
                train_graph_concepts = model.gnn_graph_concepts
                train_graph_activation = model.gnn_graph_concepts
                test_graph_concepts = model.gnn_graph_concepts
                test_graph_activation = model.gnn_graph_concepts
                graph_concepts = torch.vstack([train_graph_concepts, test_graph_concepts])
                graph_activation = torch.vstack((train_graph_activation, test_graph_activation)).detach().numpy()
                centroids, centroid_labels, used_centroid_labels = clustering_utils.find_centroids(graph_activation, graph_concepts, y)

                clf = DecisionTreeClassifier(random_state=42)
                clf = clf.fit(used_centroid_labels[train_mask].reshape(-1, 1), y[train_mask])
                graph_accuracy = clf.score(used_centroid_labels[test_mask].reshape(-1, 1), y[test_mask])
                graph_num_nodes = clf.tree_.node_count

                print(f"{exp_name} {seed} {graph_accuracy} {graph_num_nodes}")
                
                if exp_name == "Mutagenicity":
                    graph_accuracy_table[row, 0] = graph_accuracy
                    graph_node_count_table[row, 0] = graph_num_nodes
                else:
                    graph_accuracy_table[row, 1] = graph_accuracy
                    graph_node_count_table[row, 1] = graph_num_nodes
        
        accuracy_table[row, col] = accuracy
        node_count_table[row, col] = num_nodes
        

accuracy_df = pd.DataFrame(accuracy_table, columns=experiment_names)
node_count_df = pd.DataFrame(node_count_table, columns=experiment_names)

output_path = os.path.join(base_path, "model_metrics")
accuracy_df.to_csv(os.path.join(output_path, 'completeness_scores.csv'))
node_count_df.to_csv(os.path.join(output_path, 'node_count.csv'))

GINConv_BA_Shapes
SAGEConv_BA_Community
ChebConv_BA_Community
SAGEConv_BA_Shapes
ChebConv_BA_Shapes
GINConv_BA_Community
Sigmoid_BA_Shapes
Sigmoid_BA_Community
BA_Shapes
BA_Community
Standard_GINConv_BA_Shapes
Standard_SAGEConv_BA_Community
Standard_ChebConv_BA_Community
Standard_SAGEConv_BA_Shapes
Standard_ChebConv_BA_Shapes
Standard_GINConv_BA_Community
Standard_Sigmoid_BA_Shapes
Standard_Sigmoid_BA_Community
Standard_BA_Shapes
Standard_BA_Community
GINConv_BA_Shapes
SAGEConv_BA_Community
ChebConv_BA_Community
SAGEConv_BA_Shapes
ChebConv_BA_Shapes
GINConv_BA_Community
Sigmoid_BA_Shapes
Sigmoid_BA_Community
BA_Shapes
BA_Community
Standard_GINConv_BA_Shapes
Standard_SAGEConv_BA_Community
Standard_ChebConv_BA_Community
Standard_SAGEConv_BA_Shapes
Standard_ChebConv_BA_Shapes
Standard_GINConv_BA_Community
Standard_Sigmoid_BA_Shapes
Standard_Sigmoid_BA_Community
Standard_BA_Shapes
Standard_BA_Community
GINConv_BA_Shapes
SAGEConv_BA_Community
ChebConv_BA_Community
SAGEConv_BA_Shapes
ChebCon

In [6]:
print(accuracy_df)
print(node_count_df)

   GINConv_BA_Shapes  SAGEConv_BA_Community  ChebConv_BA_Community  \
0           0.965278               0.837370               0.446367   
1           0.976000               0.767176               0.381679   
2           0.992647               0.845588               0.349265   
3           0.970588               0.782609               0.431159   
4           0.993590               0.765957               0.404255   

   SAGEConv_BA_Shapes  ChebConv_BA_Shapes  GINConv_BA_Community  \
0            0.701389            0.701389              0.875433   
1            0.816000            0.704000              0.893130   
2            0.860294            0.669118              0.867647   
3            0.720588            0.705882              0.844203   
4            0.794872            0.762821              0.879433   

   Sigmoid_BA_Shapes  Sigmoid_BA_Community  BA_Shapes  BA_Community  \
0           0.958333              0.747405   0.993056      0.823308   
1           0.952000              

In [7]:
# accuracy_df2 = pd.DataFrame(graph_accuracy_table, columns=["Mutag", "Reddit"])
# node_count_df2 = pd.DataFrame(graph_node_count_table, columns=["Mutag", "Reddit"])

In [8]:
# print(accuracy_df2)
# print(node_count_df2)

In [9]:
# print("BA-Community ", accuracy_df["BA_Community"].mean())

In [10]:
# print("Standard BA-Community ", accuracy_df["Standard_BA_Community"].mean())

In [11]:
# print("BA-Community ", node_count_df["BA_Community"].mean())
# print("Standard BA-Community ", node_count_df["Standard_BA_Community"].mean())