In [1]:
# Reload imported modules (e.g. the local *_utils helpers) automatically when their source changes.
%load_ext autoreload
%autoreload 2

In [2]:
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, dense_diff_pool
import torch_explain as te
from torch_explain.logic.nn import entropy
from torch_explain.logic.metrics import test_explanation, complexity

import numpy as np
import pandas as pd
from pytorch_lightning.utilities.seed import seed_everything
from scipy.spatial.distance import cdist
from sympy import to_dnf, lambdify
from sklearn.metrics import accuracy_score
from sklearn.metrics.cluster import homogeneity_score, completeness_score

import clustering_utils
import data_utils
import lens_utils
import model_utils
import persistence_utils
import visualisation_utils

In [3]:
# ---- experiment constants ----

# dataset / task
DATASET_NAME = "BA_Shapes"
MODEL_NAME = f"GCN for {DATASET_NAME}"
NUM_CLASSES = 4
K = 10

TRAIN_TEST_SPLIT = 0.8

# architecture and training hyper-parameters
# (this notebook loads an already-trained model, see the loading cell)
NUM_HIDDEN_UNITS = 10
EPOCHS = 7_000
LR = 1e-3

RANDOM_STATE = 0

NUM_NODES_VIEW = 5
NUM_EXPANSIONS = 2

# layer to analyse; LAYER_KEY is the matching GCN attribute name ("conv3")
LAYER_NUM = 3
LAYER_KEY = f"conv{LAYER_NUM}"

visualisation_utils.set_rc_params()

In [4]:
# model definition: four GCN layers whose (rescaled) output is fed to an entropy lens
class GCN(nn.Module):
    """Graph convolutional network whose last hidden layer acts as a concept layer.

    Four stacked ``GCNConv`` layers, each followed by a leaky ReLU, produce node
    embeddings. These are turned into concept activations with a row-wise softmax
    rescaled so that every row's largest entry equals 1, and the activations are
    passed through ``self.lens``.

    Args:
        num_in_features: size of each input node-feature vector.
        num_hidden_features: width of every GCN layer, i.e. the number of concepts.
        num_classes: ``n_classes`` handed to the lens.
    """

    def __init__(self, num_in_features, num_hidden_features, num_classes):
        super().__init__()

        # The attribute names conv0..conv3 / lens become the parameter names in
        # state_dict(); presumably saved checkpoints are restored by name, so keep them.
        self.conv0 = GCNConv(num_in_features, num_hidden_features)
        self.conv1 = GCNConv(num_hidden_features, num_hidden_features)
        self.conv2 = GCNConv(num_hidden_features, num_hidden_features)
        self.conv3 = GCNConv(num_hidden_features, num_hidden_features)

        # entropy-based linear lens mapping concept activations to class scores
        self.lens = torch.nn.Sequential(
            te.nn.EntropyLinear(num_hidden_features, 1, n_classes=num_classes)
        )

    def forward(self, x, edge_index):
        """Return ``(concepts, lens_output)`` for node features ``x`` on ``edge_index``.

        ``concepts`` is the softmax of the last GCN layer divided by its row maximum.
        The raw post-activation embedding is also stored on ``self.gnn_embedding``.
        ``lens_output`` is the lens result with its last dimension squeezed
        (when that dimension has size 1).
        """
        hidden = x
        for layer in (self.conv0, self.conv1, self.conv2, self.conv3):
            hidden = F.leaky_relu(layer(hidden, edge_index))

        # exposed so that callers can inspect the embedding after a forward pass
        self.gnn_embedding = hidden

        probs = F.softmax(hidden, dim=-1)
        row_peak = probs.max(dim=-1).values.unsqueeze(1)
        concepts = probs / row_peak

        lens_out = self.lens(concepts)
        return concepts, lens_out.squeeze(-1)

In [5]:
# Directory with the persisted artefacts (data split + trained weights) of the seed-42 run.
# NOTE(review): RANDOM_STATE in the constants cell is 0, not 42 — confirm this run is intended.
path = os.path.join(os.pardir, "output", DATASET_NAME, "seed_42")

In [34]:
# load the persisted data split and restore the trained model weights (no training here)
data = persistence_utils.load_experiment(path, "data.z")
x, y = data['x'], data['y']
test_mask = data['test_mask']
edges = data['edges']

model = persistence_utils.load_model(
    GCN(x.shape[1], NUM_HIDDEN_UNITS, NUM_CLASSES), path, 'model.z'
)
model.eval()

array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20])

In [44]:
# Run the trained model once over the whole graph and collect the distinct
# binarised concept vectors. This is pure inference, so no autograd graph is built.
CONCEPT_THRESHOLD = 0.5  # an activation above this value counts as "concept active"
with torch.no_grad():
    concepts, y_pred = model(x, edges)
active_concepts = (concepts > CONCEPT_THRESHOLD).numpy()
cluster_general_labels = torch.tensor(np.unique(active_concepts, axis=0)).float()
cluster_general_labels.shape

torch.Size([21, 10])

In [60]:
# Concept intervention: for each misclassified node, look among the distinct binary
# concept vectors (`cluster_general_labels`) for those the lens assigns to the node's
# true class, and overwrite the prediction with the label of the highest-scoring one.
# Relies on `model`, `y`, `y_pred` and `cluster_general_labels` from earlier cells.
wrong_pred_idx = torch.where(y_pred.argmax(dim=-1) != y)[0]
interventions = y_pred.argmax(dim=-1).clone()  # independent copy of the predictions
accuracy = accuracy_score(y, interventions)
print(f'Initial accuracy: {accuracy}')

# The lens scores of the concept vectors do not depend on the node being fixed,
# so compute them once (without autograd) instead of once per misclassified node.
with torch.no_grad():
    max_vals, candidate_labels = model.lens(cluster_general_labels).max(dim=-2)

for wrong_id in wrong_pred_idx:
    correct_y = y[wrong_id]
    correct_mask = torch.where(candidate_labels == correct_y)
    max_vals_filtered = max_vals[correct_mask]
    candidate_labels_filtered = candidate_labels[correct_mask]
    if max_vals_filtered.numel() == 0:
        # argmax of an empty tensor raises; leave this prediction untouched instead.
        print(f'No concept vector maps to class {int(correct_y)}; node {int(wrong_id)} not corrected')
        continue
    # NOTE(review): correct_mask keeps only candidates whose label equals correct_y,
    # so new_y is always the ground-truth label — this is an oracle relabelling and
    # accuracy can only increase. Confirm this is the intended intervention protocol.
    new_y = candidate_labels_filtered[max_vals_filtered.argmax()]
    interventions[wrong_id] = new_y
    accuracy = accuracy_score(y, interventions)
    print(f'New accuracy: {accuracy}')

Initial accuracy: 0.9871428571428571
New accuracy: 0.9885714285714285
New accuracy: 0.99
New accuracy: 0.9914285714285714
New accuracy: 0.9928571428571429
New accuracy: 0.9942857142857143
New accuracy: 0.9957142857142857
New accuracy: 0.9971428571428571
New accuracy: 0.9985714285714286
New accuracy: 1.0


(tensor([ 2,  3,  4,  5, 12, 13, 18, 20]), tensor([0, 0, 0, 0, 0, 0, 0, 0]))