In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import sys
sys.path.append('..')
import os
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, dense_diff_pool
import torch_explain as te
from torch_explain.logic.nn import entropy
from torch_explain.logic.metrics import test_explanation, complexity

import numpy as np
import pandas as pd
from pytorch_lightning.utilities.seed import seed_everything
from scipy.spatial.distance import cdist
from sympy import to_dnf, lambdify
from sklearn.metrics.cluster import homogeneity_score, completeness_score

from sklearn.cluster import KMeans

import clustering_utils
import data_utils
import lens_utils
import model_utils
import persistence_utils
import visualisation_utils

In [None]:
# constants
# Experiment configuration. Most of these values are also written to disk by
# persistence_utils.save_config at the start of every run.
DATASET_NAME = "Reddit_Binary"  # passed to data_utils.load_real_data and used in the output directory name
MODEL_NAME = f"GCN for {DATASET_NAME}"  # label handed to the plotting helpers
NUM_CLASSES = 2  # only recorded in the config here; the model takes its class count from the loaded dataset
K = 30  # number of k-means clusters fitted on the node activations

TRAIN_TEST_SPLIT = 0.8  # passed to data_utils.prepare_real_data (presumably the training fraction - TODO confirm)

NUM_HIDDEN_UNITS = 40  # passed to GCN as num_hidden_features
EPOCHS = 1000  # passed to model_utils.train_graph_class
LR = 0.001  # passed to model_utils.train_graph_class

BATCH_SIZE = 16  # passed to data_utils.prepare_real_data; NOTE(review): not included in the save_config call

# forwarded to visualisation_utils.plot_samples
NUM_NODES_VIEW = 5
NUM_EXPANSIONS = 4

LAYER_NUM = 3  # forwarded to the clustering / sample plots
LAYER_KEY = "conv3"  # name of the GCN layer of interest; only recorded in the config in the code visible here

# apply the project's matplotlib settings (defined in visualisation_utils)
visualisation_utils.set_rc_params()

In [None]:
# model definition
class GCN(nn.Module):
    """Four-layer GCN graph classifier that exposes its intermediate embeddings.

    After every ``forward`` call the activations produced by the last
    convolution (after the leaky ReLU) are available as ``gnn_node_embedding``
    and the pooled representation as ``gnn_graph_embedding``.

    Args:
        num_in_features: size of the input node feature vectors.
        num_hidden_features: output width of ``conv0``, ``conv1`` and ``conv2``.
        num_classes: number of output classes of the final linear layer.
        num_embedding_features: output width of ``conv3`` and input width of
            the linear classifier. Defaults to 10, the value that used to be
            hard-coded, so existing callers are unaffected.
    """

    def __init__(self, num_in_features, num_hidden_features, num_classes, num_embedding_features=10):
        super().__init__()

        # The attribute names conv0..conv3 / pool / linear are kept unchanged:
        # they determine the sub-module names, and LAYER_KEY refers to "conv3".
        self.conv0 = GCNConv(num_in_features, num_hidden_features)
        self.conv1 = GCNConv(num_hidden_features, num_hidden_features)
        self.conv2 = GCNConv(num_hidden_features, num_hidden_features)
        self.conv3 = GCNConv(num_hidden_features, num_embedding_features)

        self.pool = model_utils.Pool()

        # linear layers
        self.linear = nn.Linear(num_embedding_features, num_classes)

    def forward(self, x, edge_index, batch):
        """Compute class log-probabilities and cache the intermediate embeddings.

        Args:
            x: node feature matrix fed to ``conv0``.
            edge_index: graph connectivity passed to every convolution.
            batch: per-node batch assignment passed to the pooling layer.

        Returns:
            ``log_softmax`` over the last dimension of the linear layer output
            (presumably one row per graph, depending on what
            ``model_utils.Pool`` aggregates over - TODO confirm).
        """
        # every convolution is followed by the same leaky ReLU non-linearity
        for conv in (self.conv0, self.conv1, self.conv2, self.conv3):
            x = F.leaky_relu(conv(x, edge_index))

        # cached so callers can read the node-level activations after a forward pass
        self.gnn_node_embedding = x

        x = self.pool(x, batch)

        self.gnn_graph_embedding = x

        x = self.linear(x)

        return F.log_softmax(x, dim=-1)

In [None]:
def _collect_node_activations(model, loader):
    """Run ``model`` on the first batch of ``loader`` and return ``(batch, node_embedding)``.

    The forward pass is executed under ``torch.no_grad()`` because only the
    activations are needed, so no autograd graph has to be kept alive.
    """
    data = next(iter(loader))
    with torch.no_grad():
        model(data.x, data.edge_index, data.batch)
    # GCN.forward stores the activations of its last convolution on the model
    return data, model.gnn_node_embedding


def run_experiment(seed, path, load_pretrained=False):
    """Train the GCN, cluster its node activations and persist all artefacts.

    Steps: save the configuration, load and split the dataset, train the
    model, collect the node embeddings for the full train and test sets, fit
    k-means (``K`` clusters) on the training activations, compute homogeneity
    and completeness against the per-node labels, and plot the clustering and
    concept samples for the test set. Everything is written below ``path``.

    Args:
        seed: random seed; recorded in the config, used as k-means
            ``random_state`` and forwarded to the clustering plot.
        path: output directory for the persisted files and plots.
        load_pretrained: accepted for interface compatibility but currently
            ignored - the model is always trained from scratch.
    """
    persistence_utils.save_config(seed, DATASET_NAME, MODEL_NAME, NUM_CLASSES, K, TRAIN_TEST_SPLIT, NUM_HIDDEN_UNITS, EPOCHS, LR, NUM_NODES_VIEW, NUM_EXPANSIONS, LAYER_NUM, LAYER_KEY, path)

    # load data
    graphs = data_utils.load_real_data(DATASET_NAME)
    data = data_utils.prepare_real_data(graphs, TRAIN_TEST_SPLIT, BATCH_SIZE, DATASET_NAME)
    train_loader, test_loader, full_train_loader, full_test_loader, full_loader, small_loader = data

    # model training
    model = GCN(graphs.num_node_features, NUM_HIDDEN_UNITS, graphs.num_classes)

    # register hooks to track activation
    model = model_utils.register_hooks(model)

    # train
    train_acc, test_acc, train_loss, test_loss = model_utils.train_graph_class(model, train_loader, test_loader, full_loader, EPOCHS, LR, if_interpretable_model=False)
    persistence_utils.persist_model(model, path, 'model.z')

    visualisation_utils.plot_model_accuracy(train_acc, test_acc, MODEL_NAME, path)
    visualisation_utils.plot_model_loss(train_loss, test_loss, MODEL_NAME, path)

    # get model activations for complete dataset (inference only from here on)
    model.eval()
    train_data, train_activation = _collect_node_activations(model, full_train_loader)
    test_data, test_activation = _collect_node_activations(model, full_test_loader)

    # train nodes come first, test nodes second
    num_train_nodes = train_activation.shape[0]
    activation = torch.vstack((train_activation, test_activation)).detach().cpu().numpy()
    persistence_utils.persist_experiment(activation, path, 'activation.z')

    # expand the graph-level labels so that every node carries its graph's label
    expanded_train_y = data_utils.reshape_graph_to_node_data(train_data.y, train_data.batch)
    expanded_test_y = data_utils.reshape_graph_to_node_data(test_data.y, test_data.batch)
    expanded_y = torch.cat((expanded_train_y, expanded_test_y))

    # find centroids: fit on the training nodes only, then assign every node
    kmeans_model = KMeans(n_clusters=K, random_state=seed)
    kmeans_model = kmeans_model.fit(activation[:num_train_nodes])
    used_centroid_labels = kmeans_model.predict(activation)
    centroid_labels = np.sort(np.unique(used_centroid_labels))
    centroids = kmeans_model.cluster_centers_

    persistence_utils.persist_experiment(kmeans_model, path, 'kmeans_model.z')
    persistence_utils.persist_experiment(centroids, path, 'centroids.z')
    persistence_utils.persist_experiment(centroid_labels, path, 'centroid_labels.z')
    persistence_utils.persist_experiment(used_centroid_labels, path, 'used_centroid_labels.z')

    print(f"Number of cenroids: {len(centroids)}")

    # concept alignment
    homogeneity = homogeneity_score(expanded_y, used_centroid_labels)

    # clustering efficiency
    completeness = completeness_score(expanded_y, used_centroid_labels)

    # calculate cluster sizing
    cluster_counts = visualisation_utils.print_cluster_counts(used_centroid_labels)

    concept_metrics = [('homogeneity', homogeneity), ('completeness', completeness), ('cluster_count', cluster_counts)]
    persistence_utils.persist_experiment(concept_metrics, path, 'concept_metrics.z')

    print(f"Concept homogeneity score: {homogeneity}")
    print(f"Concept completeness score: {completeness}")

    # REDUCING DATA TO TEST SET for the visualisations
    test_activation = test_activation.detach().cpu().numpy()
    test_used_centroid_labels = kmeans_model.predict(test_activation)

    # Report the number of test nodes actually being plotted. The previous code
    # printed the shape of a node-level mask that had been pushed through the
    # graph-to-node expansion helper, which does not describe the plotted nodes.
    print("Nodes to visualise ", test_activation.shape[0])

    # plot clustering
    visualisation_utils.plot_clustering(seed, test_activation, expanded_test_y, centroids, centroid_labels, test_used_centroid_labels, MODEL_NAME, LAYER_NUM, path)

    # plot samples
    edges_t = test_data.edge_index.transpose(0, 1).detach().cpu().numpy()
    sample_graphs, sample_feat = visualisation_utils.plot_samples(None, test_activation, expanded_test_y, LAYER_NUM, len(centroids), "Differential Clustering", "Raw", NUM_NODES_VIEW, edges_t, NUM_EXPANSIONS, path, concepts=centroids)
    persistence_utils.persist_experiment(sample_graphs, path, 'sample_graphs.z')
    persistence_utils.persist_experiment(sample_feat, path, 'sample_feat.z')

In [None]:
# Repeat the experiment with several seeds so that a confidence interval can be
# reported - seeds generated using Google's random number generator
random_seeds = [42, 19, 76, 58, 92]

# every seed gets its own sub-directory below this root
output_root = os.path.join("..", "output", "Standard_" + DATASET_NAME)

for seed in random_seeds:
    print("\nSTART EXPERIMENT-----------------------------------------\n")
    seed_everything(seed)

    path = os.path.join(output_root, f"seed_{seed}")
    data_utils.create_path(path)

    run_experiment(seed=seed, path=path, load_pretrained=False)

    print("\nEND EXPERIMENT-------------------------------------------\n")