In [1]:
# Experiment settings (see the description above for the full explanation).
# Number of train/explain/evaluate repetitions (upper bound of the main loop)
big_loop_iterations = 10
# Number of explainer runs requested per repetition
explainer_runs = 10
# Thresholds at which validity is evaluated
thresholds = [30, 50]
samples = [10]

# Label used in the name of the results directory
date = "quantile_alg_final"

# Set up directory for result files.
# makedirs(..., exist_ok=True) is idempotent, so re-running this cell is safe
# and there is no gap between "check" and "create".
import os
# NOTE(review): `dir` shadows the builtin of the same name; the name is kept
# because other cells may refer to it.
dir = f'./results_{date}'
os.makedirs(dir, exist_ok=True)

In [2]:
from GNNSubNet import GNNSubNet as gnn
import pandas as pd
import numpy as np

# Dataset selection: keep exactly one of the two groups below active.

# --- Kidney data set (active) ------------------------- #
loc   = "../TCGA"
ppi   = loc + '/KIDNEY_RANDOM_PPI.txt'
feats = [
    loc + '/KIDNEY_RANDOM_Methy_FEATURES.txt',
    loc + '/KIDNEY_RANDOM_mRNA_FEATURES.txt',
]
targ  = loc + '/KIDNEY_RANDOM_TARGET.txt'

# --- Synthetic data set (inactive) ------------------------- #
# loc   = "../GNNSubNet/datasets/synthetic/"
# ppi   = f'{loc}/NETWORK_synthetic.txt'
# feats = [f'{loc}/FEATURES_synthetic.txt']
# targ  = f'{loc}/TARGET_synthetic.txt'

# Read in the selected data set (network, feature files and target)
g = gnn.GNNSubNet(loc, ppi, feats, targ, normalize=False)

# Get some general information about the data dimension
g.summary()

Graph is connected  False
Calculate subgraph ...
Number of subgraphs:  118
Size of subgraph:  2049
Graph is connected  True
##################
# DATASET LOADED #
##################

Number of nodes: 2049
Number of edges: 13588
Number of modalities: 2


In [3]:
def obtain_BAGEL_scores(loc, ppi, feats, targ, gnn_subnet, quantile_aggregation=False, quantile=0.5,
                        start_iteration=9):
    """Repeatedly train a GNNSubNet model, explain it, and record evaluation scores.

    For every repetition ``i`` in ``range(start_iteration, big_loop_iterations)``
    a fresh ``gnn.GNNSubNet`` is built and trained, the explainer is run
    ``explainer_runs`` times, and fidelity, sparsity and validity (at each value
    in ``thresholds``) are evaluated. Results are written as CSV files into
    ``results_{date}/``.

    Relies on the notebook-level names ``big_loop_iterations``,
    ``explainer_runs``, ``thresholds``, ``date``, ``gnn``, ``np`` and ``os``.

    Files written per repetition ``i``:
      * ``{i}_fidelities.csv``     - summary rows ``[i, accuracy, mean fidelity]``
        for every repetition of this call up to and including ``i`` (cumulative).
      * ``{i}_sparsities.csv``     - the raw output of ``evaluate_sparsity()``
        for repetition ``i`` only.
      * ``{i}_validity_plus.csv`` / ``{i}_validity_minus.csv`` - rows
        ``[i, accuracy, threshold, validity]``, cumulative like the fidelities.

    Parameters
    ----------
    loc, ppi, feats, targ
        Passed unchanged to ``gnn.GNNSubNet`` (data location, network file,
        feature files, target file).
    gnn_subnet, quantile_aggregation, quantile
        Passed unchanged to ``g.explain`` to select/configure the explainer.
    start_iteration : int, optional
        First repetition index to run. Defaults to 9, which preserves the
        previously hard-coded behaviour (only the last of 10 repetitions is
        run); pass 0 to run all ``big_loop_iterations`` repetitions.

    Returns
    -------
    dict
        The accumulated tables (lists of rows) under the keys ``"fidelity"``,
        ``"sparsity"``, ``"validity_plus"``, ``"validity_minus"``,
        ``"validity_plus_matrix"`` and ``"validity_minus_matrix"``. The two
        matrix tables hold ``[i, accuracy, threshold, m[0,0], m[0,1], m[1,0],
        m[1,1]]`` and, like the sparsity summary, are not written to disk.
    """
    out_dir = f"results_{date}"
    # Make sure the target folder exists *before* the expensive training and
    # explanation steps, so a missing folder cannot discard a finished run.
    os.makedirs(out_dir, exist_ok=True)

    fidelity = []
    sparsity = []
    validity_plus = []
    validity_minus = []
    validity_plus_matrix = []
    validity_minus_matrix = []

    for i in range(start_iteration, big_loop_iterations):
        print(i)
        g = gnn.GNNSubNet(loc, ppi, feats, targ, normalize=False)
        g.train()

        # Check the performance of the classifier
        accuracy = g.accuracy

        # Run the explainer the desired number of times
        g.explain(explainer_runs, gnn_subnet=gnn_subnet,
                  quantile_aggregation=quantile_aggregation, quantile=quantile)

        # Fidelity: keep one summary row per repetition and save the
        # cumulative summary table.
        fidelities = g.evaluate_RDT_fidelity_soft()
        fidelity.append([i, accuracy, np.mean(fidelities)])
        np.savetxt(f"{out_dir}/{i}_fidelities.csv", fidelity, delimiter=',', fmt ='% s')

        # Sparsity: save the raw results in case they are needed for further
        # analysis, and keep the mean for the processed summary table.
        sparsities = g.evaluate_sparsity()
        np.savetxt(f"{out_dir}/{i}_sparsities.csv", sparsities, delimiter=',', fmt ='% s')
        sparsity.append([i, accuracy, np.mean(sparsities)])

        # Validity at varying thresholds
        for t in thresholds:
            v_plus, v_minus, mat_plus, mat_minus = g.evaluate_validity(threshold=t, confusion_matrix=True)
            validity_plus.append([i, accuracy, t, v_plus])
            validity_minus.append([i, accuracy, t, v_minus])
            validity_plus_matrix.append(
                [i, accuracy, t, mat_plus[0,0], mat_plus[0,1], mat_plus[1,0], mat_plus[1,1]])
            validity_minus_matrix.append(
                [i, accuracy, t, mat_minus[0,0], mat_minus[0,1], mat_minus[1,0], mat_minus[1,1]])

        np.savetxt(f"{out_dir}/{i}_validity_plus.csv", validity_plus, delimiter=',', fmt ='% s')
        np.savetxt(f"{out_dir}/{i}_validity_minus.csv", validity_minus, delimiter=',', fmt ='% s')

    # Hand the tables back so the summaries that are not written to disk
    # (sparsity means, confusion matrices) are not lost.
    return {
        "fidelity": fidelity,
        "sparsity": sparsity,
        "validity_plus": validity_plus,
        "validity_minus": validity_minus,
        "validity_plus_matrix": validity_plus_matrix,
        "validity_minus_matrix": validity_minus_matrix,
    }

In [4]:
# Show the name of the folder the result files are written to
print("results_" + date)

results_quantile_alg_final


In [5]:
# Run the experiment for a single explainer.
# Toggle `gnn_subnet` to choose the explainer; here quantile aggregation
# is enabled with quantile 0.5.
obtain_BAGEL_scores(
    loc,
    ppi,
    feats,
    targ,
    gnn_subnet=False,
    quantile_aggregation=True,
    quantile=0.5,
)

9
Graph is connected  False
Calculate subgraph ...
Number of subgraphs:  118
Size of subgraph:  2049
Graph is connected  True
##################
# DATASET LOADED #
##################
graphcnn for training ...
Graphs class 0: 200, Graphs class 1: 306
Length of balanced dataset list: 400
Train graph class 0: 162, train graph class 1: 158
Validation graph class 0: 38, validation graph class 1: 42


  graph_pool = torch.sparse.FloatTensor(idx, elem, torch.Size([len(batch_graph), start_idx[-1]]))
100%|██████████| 35/35 [00:08<00:00,  4.10batch/s]


Epoch 0, loss 2190.6757
Train Acc 0.5156
Epoch 0, val_loss 24.9896
Saving best model with validation loss 24.989639282226562


100%|██████████| 35/35 [00:11<00:00,  3.10batch/s]


Epoch 1, loss 332.4395
Train Acc 0.5062
Epoch 1, val_loss 185.1322


100%|██████████| 35/35 [00:10<00:00,  3.44batch/s]


Epoch 2, loss 56.7439
Train Acc 0.4938
Epoch 2, val_loss 236.2896


100%|██████████| 35/35 [00:10<00:00,  3.33batch/s]


Epoch 3, loss 27.2859
Train Acc 0.5437
Epoch 3, val_loss 14.2863
Saving best model with validation loss 14.2862548828125


100%|██████████| 35/35 [00:08<00:00,  4.02batch/s]


Epoch 4, loss 24.2290
Train Acc 0.6031
Epoch 4, val_loss 8.3385
Saving best model with validation loss 8.338519096374512


100%|██████████| 35/35 [00:09<00:00,  3.73batch/s]


Epoch 5, loss 29.9741
Train Acc 0.6312
Epoch 5, val_loss 15.2460


100%|██████████| 35/35 [00:12<00:00,  2.80batch/s]


Epoch 6, loss 25.2902
Train Acc 0.6281
Epoch 6, val_loss 17.3517


100%|██████████| 35/35 [00:12<00:00,  2.87batch/s]


Epoch 7, loss 19.7972
Train Acc 0.5062
Epoch 7, val_loss 21.0744


100%|██████████| 35/35 [00:11<00:00,  3.02batch/s]


Epoch 8, loss 35.4303
Train Acc 0.6937
Epoch 8, val_loss 4.8031
Saving best model with validation loss 4.803082466125488


100%|██████████| 35/35 [00:11<00:00,  3.00batch/s]


Epoch 9, loss 25.2290
Train Acc 0.6156
Epoch 9, val_loss 9.2127


100%|██████████| 35/35 [00:11<00:00,  3.02batch/s]


Epoch 10, loss 19.4481
Train Acc 0.7781
Epoch 10, val_loss 3.3512
Saving best model with validation loss 3.351222276687622


100%|██████████| 35/35 [00:12<00:00,  2.86batch/s]


Epoch 11, loss 22.6717
Train Acc 0.7625
Epoch 11, val_loss 4.4268


100%|██████████| 35/35 [00:12<00:00,  2.91batch/s]


Epoch 12, loss 17.1809
Train Acc 0.5062
Epoch 12, val_loss 26.8993


100%|██████████| 35/35 [00:08<00:00,  4.28batch/s]


Epoch 13, loss 16.5628
Train Acc 0.5844
Epoch 13, val_loss 22.7078


100%|██████████| 35/35 [00:06<00:00,  5.54batch/s]


Epoch 14, loss 27.8319
Train Acc 0.6875
Epoch 14, val_loss 21.8803


100%|██████████| 35/35 [00:06<00:00,  5.39batch/s]


Epoch 15, loss 16.1926
Train Acc 0.7937
Epoch 15, val_loss 3.4506


100%|██████████| 35/35 [00:06<00:00,  5.68batch/s]


Epoch 16, loss 12.7976
Train Acc 0.8531
Epoch 16, val_loss 4.1890


100%|██████████| 35/35 [00:06<00:00,  5.80batch/s]


Epoch 17, loss 12.1166
Train Acc 0.5312
Epoch 17, val_loss 10.3126


100%|██████████| 35/35 [00:06<00:00,  5.59batch/s]


Epoch 18, loss 15.5479
Train Acc 0.4969
Epoch 18, val_loss 29.1389


100%|██████████| 35/35 [00:06<00:00,  5.77batch/s]


Epoch 19, loss 19.2215
Train Acc 0.7750
Epoch 19, val_loss 4.0681

Confusion matrix (Validation set):

[[25 13]
 [ 9 33]]
Validation accuracy: 72.5%
Validation loss 3.351222276687622

------- Run the Explainer -------

GNN-SubNet: False
Explainer::Iteration 1 of 10
Explainer::Iteration 2 of 10
Explainer::Iteration 3 of 10
Explainer::Iteration 4 of 10
Explainer::Iteration 5 of 10
Explainer::Iteration 6 of 10
Explainer::Iteration 7 of 10
Explainer::Iteration 8 of 10
Explainer::Iteration 9 of 10
Explainer::Iteration 10 of 10
Evaluating RDT-fidelity using 80 graphs from the test dataset
Evaluating 1 out of 80


  test_graph = Data(x=torch.tensor(perturbed_features).float().to("cpu"),


Evaluating 2 out of 80
Evaluating 3 out of 80
Evaluating 4 out of 80
Evaluating 5 out of 80
Evaluating 6 out of 80
Evaluating 7 out of 80
Evaluating 8 out of 80
Evaluating 9 out of 80
Evaluating 10 out of 80
Evaluating 11 out of 80
Evaluating 12 out of 80
Evaluating 13 out of 80
Evaluating 14 out of 80
Evaluating 15 out of 80
Evaluating 16 out of 80
Evaluating 17 out of 80
Evaluating 18 out of 80
Evaluating 19 out of 80
Evaluating 20 out of 80
Evaluating 21 out of 80
Evaluating 22 out of 80
Evaluating 23 out of 80
Evaluating 24 out of 80
Evaluating 25 out of 80
Evaluating 26 out of 80
Evaluating 27 out of 80
Evaluating 28 out of 80
Evaluating 29 out of 80
Evaluating 30 out of 80
Evaluating 31 out of 80
Evaluating 32 out of 80
Evaluating 33 out of 80
Evaluating 34 out of 80
Evaluating 35 out of 80
Evaluating 36 out of 80
Evaluating 37 out of 80
Evaluating 38 out of 80
Evaluating 39 out of 80
Evaluating 40 out of 80
Evaluating 41 out of 80
Evaluating 42 out of 80
Evaluating 43 out of 80
