In [13]:
from GNNTrain import predict_from_saved_model
from CreateDatasetv2 import get_dataset_from_graph
from Paths import PATH_TO_GRAPHS, PATH_TO_RANKINGS
from GDARanking import predict_candidate_genes
from GraphSageModel import GNN7L_Sage

from dig.xgraph.method import SubgraphX

## Get rankings
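Set the following parameters in the next cell to select which ranking is computed (in binary mode, the suffix `_only` is appended to the explainability method automatically):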
- disease_id $\in$ ['C0006142','C0009402','C0023893','C0036341','C0376358']
- modality $\in$ ['binary', 'multiclass']
- exp_method $\in$ ['GNNExplainer', 'GraphSVX', 'SubgraphX']

In [14]:
# Run configuration: which disease, which classification setting and which
# explainability method to use (valid choices are listed in the markdown above).
disease_Id  = 'C0006142'
modality    = 'multiclass'
exp_method  = 'GraphSVX'

# Fail fast on a wrong modality here, before the next cell spends time loading
# the graph and the saved model.
if modality not in ('binary', 'multiclass'):
    raise ValueError(f'[ERR] Wrong modality! {modality} not in [binary, multiclass].')

print('[+] Performing GDA prioritization on disease', disease_Id, 'with', modality, 'classification, using', exp_method, 'as explainability method.')

[+] Performing GDA prioritization on disease C0006142 with multiclass classificaiton, using GNNExplainer as explainability method.


In [15]:
# Resolve the class labels and saved-model name for the chosen modality, then
# load the dataset and get the saved model's predictions on it.
classes     = []
model_name  = 'GraphSAGE_' + disease_Id + '_new_rankings'

if modality == 'binary':
    classes = ['P', 'U']
    # Binary classification uses the "<method>_only" explainer variant and a
    # separate saved model. The suffix is guarded so that re-running this cell
    # without re-running the config cell does not yield e.g. "GraphSVX_only_only".
    if not exp_method.endswith('_only'):
        exp_method += '_only'
    # NOTE(review): no separator before 'binary'; presumably this matches the
    # saved model's name, so verify before changing it.
    model_name += 'binary'
elif modality == 'multiclass':
    classes     = ['P', 'LP', 'WN', 'LN', 'RN']
else:
    # Stop here: continuing with an empty class list would load the graph and
    # run predictions for nothing.
    raise ValueError(f'[ERR] Wrong modality! {modality} not in [binary, multiclass].')

# Specify which model to use, based on training epochs and weight decay rate
# (suffix presumably encodes 40000 epochs and weight decay 0.0005).
model_name += '_40000_0_0005'

##################################
# Get Pytorch dataset from graph #
##################################
graph_path  = PATH_TO_GRAPHS + 'grafo_nedbit_' + disease_Id + '.gml'

print('[+] Loading dataset from graph:', graph_path)
dataset, G = get_dataset_from_graph(graph_path, disease_Id, quartile=False)

####################################
# Get predictions from saved model #
####################################
print()
print('[+] Metrics report of saved model:', model_name)
preds, probs, model = predict_from_saved_model(model_name, dataset, classes, save_to_file=False)



[+] Loading dataset from graph: Graphs/grafo_nedbit_C0006142.gml
[+] Reading graph...ok
[+] Creating dataset...ok
[i] Elapsed time: 19.551

[+] Metrics report of saved model: GraphSAGE_C0006142_new_rankings_40000_0_0005
              precision    recall  f1-score   support

           0       0.97      1.00      0.98       154
           1       0.93      0.96      0.95       739
           2       0.90      0.89      0.90       739
           3       0.93      0.84      0.88       739
           4       0.89      0.97      0.93       593

    accuracy                           0.92      2964
   macro avg       0.92      0.93      0.93      2964
weighted avg       0.92      0.92      0.92      2964



In [None]:
# Compute the candidate-gene ranking for the disease with predict_candidate_genes
# (from GDARanking), using the explainability method chosen in the config cell.
ranking = predict_candidate_genes(model,
                                  dataset,
                                  preds,
                                  disease_Id,
                                  exp_method,
                                  explanation_nodes_ratio=1,
                                  masks_for_seed=5,
                                  num_hops=1,
                                  G=G,
                                  num_pos='all',
                                  threshold=True)

########################
# Save ranking to file #
########################
# The "_only" suffix (binary mode) is stripped and the modality is not part of
# the name, so binary and multiclass runs with the same method write the same
# file. NOTE(review): confirm this overwrite is intended.
ranking_path = (PATH_TO_RANKINGS
                + disease_Id
                + '_all_positives_new_ranking_'
                + exp_method.lower().replace("_only", "")
                + '.txt')

# One entry per line; entries are assumed to be strings (they are concatenated
# with '\n'). An explicit encoding keeps the file identical across platforms.
with open(ranking_path, 'w', encoding='utf-8') as f:
    for line in ranking:
        f.write(line + '\n')