# Imports

In [1]:
import json
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pickle
from lightning import Trainer
from torchmetrics.functional import accuracy, specificity, auroc, recall, precision, f1_score, matthews_corrcoef
from pathlib import Path
from torch_geometric.utils.convert import to_networkx
import seaborn as sns
import time
import os
import sys
from tqdm import tqdm
import ray

sys.path.insert(0, os.path.abspath('../src'))
from model_development.kfold import *
from model_development.datamodule import *

  from .autonotebook import tqdm as notebook_tqdm


FileNotFoundError: [Errno 2] No such file or directory: 'config.json'

# Data loading

In [8]:
config = json.load(open('../config.json', 'r'))

In [6]:
base_path = config.ray_results_path
trial_names = pickle.load(open("../model_selection/gcn_2024-4_best_trial_names.pkl","rb"))

AttributeError: 'dict' object has no attribute 'ray_results_path'

In [None]:
ckpts={}

project = "peppina-final"
for k, name in tqdm(trial_names.items()):
    ckpts[k] = {}
    count=0
    try: 
        dir = [f for f in base_path.rglob(f"*{name}*/")][0]
    except:
        continue
    config = pickle.load((dir/"params.pkl").open("rb"))
    paths = [f for f in (dir/project).rglob(f"*last.ckpt")]
    if len(paths)==0:
        paths=[f for f in (dir/project).rglob(f"*.ckpt")]
    ckpts[k] = {
        'config': config,
        'ckpts': paths
    }
    count+=len(ckpts[k]["ckpts"])
    print(k, count)

  4%|▍         | 1/24 [00:04<01:51,  4.86s/it]

('UC1', '0', 0) 1


  8%|▊         | 2/24 [00:08<01:31,  4.16s/it]

('UC1', '0', 1) 1


 12%|█▎        | 3/24 [00:12<01:21,  3.89s/it]

('UC1', '0', 2) 1


 17%|█▋        | 4/24 [00:15<01:16,  3.83s/it]

('UC1', '0', 3) 1


 21%|██        | 5/24 [00:19<01:11,  3.75s/it]

('UC1', '128', 0) 1


 25%|██▌       | 6/24 [00:23<01:07,  3.76s/it]

('UC1', '128', 1) 1


 29%|██▉       | 7/24 [00:26<01:03,  3.71s/it]

('UC1', '128', 2) 1


 33%|███▎      | 8/24 [00:30<00:59,  3.70s/it]

('UC1', '128', 3) 1


 38%|███▊      | 9/24 [00:35<00:59,  3.98s/it]

('UC2', '0', 0) 1


 42%|████▏     | 10/24 [00:38<00:54,  3.88s/it]

('UC2', '0', 1) 1


 46%|████▌     | 11/24 [00:43<00:53,  4.12s/it]

('UC2', '0', 2) 1


 50%|█████     | 12/24 [00:47<00:47,  3.96s/it]

('UC2', '0', 3) 1


 54%|█████▍    | 13/24 [00:50<00:42,  3.87s/it]

('UC2', '128', 0) 1


 58%|█████▊    | 14/24 [00:54<00:37,  3.79s/it]

('UC2', '128', 1) 1


 62%|██████▎   | 15/24 [00:57<00:33,  3.74s/it]

('UC2', '128', 2) 1


 67%|██████▋   | 16/24 [01:01<00:29,  3.69s/it]

('UC2', '128', 3) 1


 71%|███████   | 17/24 [01:05<00:25,  3.69s/it]

('UC3', '0', 0) 1


 75%|███████▌  | 18/24 [01:08<00:22,  3.68s/it]

('UC3', '0', 1) 1


 79%|███████▉  | 19/24 [01:12<00:18,  3.67s/it]

('UC3', '0', 2) 1


 83%|████████▎ | 20/24 [01:16<00:14,  3.69s/it]

('UC3', '0', 3) 1


 88%|████████▊ | 21/24 [01:20<00:11,  3.96s/it]

('UC3', '128', 0) 1


 92%|█████████▏| 22/24 [01:24<00:07,  3.93s/it]

('UC3', '128', 1) 1


 96%|█████████▌| 23/24 [01:28<00:03,  3.87s/it]

('UC3', '128', 2) 1


100%|██████████| 24/24 [01:32<00:00,  3.84s/it]

('UC3', '128', 3) 1





In [5]:
folds={}

ray.shutdown()
ray.init(ignore_reinit_error=True, _temp_dir="/data/adipalma/tmp")

for hold_out_by in ['UC1','UC2','UC3']:
    for embeddings_len in ['0','128']:#,'onehot']:
        config=ckpts[(hold_out_by, embeddings_len, 0)]['config']
        config['biogrid_ver'] = '2024-10'
        folds[(hold_out_by,embeddings_len)] = create_fold_dict(config=config)

2024-10-30 06:56:35,354	INFO worker.py:1636 -- Started a local Ray instance.


17169
loading  ppi_sensitivity/biogrid/2024-10/folds/random_sensitivity.pkl
17169
loading  ppi_sensitivity/biogrid/2024-10/folds/random_sensitivity.pkl
17169
loading  ppi_sensitivity/biogrid/2024-10/folds/protein_sensitivity.pkl
17169
loading  ppi_sensitivity/biogrid/2024-10/folds/protein_sensitivity.pkl
17169
loading  ppi_sensitivity/biogrid/2024-10/folds/model_sensitivity.pkl
17169
loading  ppi_sensitivity/biogrid/2024-10/folds/model_sensitivity.pkl


In [6]:
metrics = {
    'accuracy': accuracy,
    'specificity': specificity,
    'auroc': auroc,
    'recall': recall,
    'precision': precision,
    'f1': f1_score,
    'mcc': matthews_corrcoef
}

In [7]:
# hide cuda devices
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
def load_and_predict(config, ckpt, hold_out_by, features):
    fold = ray.get(folds[(hold_out_by, features)][0]
                   [('sensitivity',config['hold_out_by']['grid_search'][0] if type(config['hold_out_by'])==dict else config['hold_out_by'])]
                   [config['test_fold']])
    
    if 'aggr' not in config.keys():
        config['aggr'] = config['SAGE_aggr']
    if 'uniform_bound' not in config.keys():
        config['uniform_bound'] = None
    if 'weight_initializer' not in config.keys():
        config['weight_initializer'] = 'kaiming_uniform'


    test = fold['test']
    outer_train = fold['train']
    train = outer_train[config["val_fold"]]['train']
    val = outer_train[config["val_fold"]]['val']
    config['batch_size'] = 5000
    
    data = GraphDataModule(train, val, test, config)
    data.setup()

    input_dim = train[0].x.shape[1]
    output_dim = 1
    start = time.time()
    model = GCN.load_from_checkpoint(ckpt, input_dim=input_dim, output_dim=output_dim, config=config,  map_location=torch.device('cpu'))
    
    cuda_args = {"accelerator": "cpu"}

    trainer = Trainer(enable_progress_bar = False, **cuda_args)
  
    data = GraphDataModule(train, val, test, config)
    data.setup()
    end = time.time()
    print("Time to load model: ", end-start)
    start = time.time()
    predictions = trainer.predict(model, data.test_dataloader())
    end = time.time()
    print("Time to predict: ", end-start)
    # concatenate all predictions in a single tensor
    predictions = torch.cat(predictions)
    print(len(predictions), len(test))
    
    return predictions, test

In [9]:
predictions_dict = {}
for k,v in ckpts.items():
    predictions_dict[k] = {}

    if len(v['ckpts']) == 0:
        print(f"no checkpoint for {k} fold {k[2]}")
        continue
    # print(v['config'], v['ckpts'][0], k[0], k[1])
    try:
 
        predictions = load_and_predict(v['config'], v['ckpts'][0], k[0], k[1])
        predictions_dict[k] = predictions
    except Exception as e:
        print(f"error for {k} fold {k[2]}")
        print(e)
        continue

pickle.dump(predictions_dict, open("predictions_dict.pkl", "wb"))

GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


Time to load model:  4.85598611831665
Time to predict:  3.9597575664520264
4293 4293
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  5.884603023529053
Time to predict:  10.930286407470703
4292 4292


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  4.660712242126465
Time to predict:  4.136009216308594
4292 4292
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  5.0101094245910645
Time to predict:  10.069069147109985
4291 4292


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  4.867904186248779
Time to predict:  6.18917179107666
4293 4293


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  4.769256114959717
Time to predict:  4.200435638427734
4292 4292


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  4.7747087478637695
Time to predict:  4.204930543899536
4292 4292


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  4.9301979541778564
Time to predict:  4.2957470417022705
4291 4292


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  5.047852993011475
Time to predict:  4.843435287475586
2895 2896


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  5.023789882659912
Time to predict:  4.608443975448608
4433 4433


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
------------------------------------- class
Time to load model:  5.106456995010376
Time to predict:  3.8804540634155273
3865 3865


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  5.162142992019653
Time to predict:  4.113345623016357
3398 3398


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
------------------------------------- class
Time to load model:  5.143088340759277
Time to predict:  1.9516239166259766
2895 2896


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  5.244577169418335
Time to predict:  4.775022745132446
4433 4433


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  5.255730152130127
Time to predict:  3.1301584243774414
3865 3865


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


Time to load model:  5.340029954910278
Time to predict:  2.153625726699829
3398 3398
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  5.290875434875488
Time to predict:  4.313160419464111
3294 3294


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
------------------------------------- class
Time to load model:  4.1762096881866455
Time to predict:  3.7138452529907227
3806 3806
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  4.461603879928589
Time to predict:  9.414254188537598
5590 5590
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  5.295984268188477
Time to predict:  5.827144384384155
4478 4479


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
------------------------------------- class
Time to load model:  5.224966764450073
Time to predict:  3.2450878620147705
3294 3294
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  5.360688924789429
Time to predict:  3.9404585361480713
3806 3806
------------------------------------- class


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
Time to load model:  4.347567558288574
Time to predict:  10.358577728271484
5590 5590


GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


------------------------------------- class
------------------------------------- class
Time to load model:  4.133492469787598
Time to predict:  2.315964698791504
4478 4479


In [10]:
predictions_dict=pickle.load(open("predictions_dict.pkl", "rb"))

# Metrics

In [26]:
def compute_stats(predictions, test, confidence_threshold=None):
    stats = []
    # print(len(predictions), len(test))
    # print(type(test))
    if confidence_threshold is not None:
        probs = torch.sigmoid(predictions)
        index = [i for i, p in enumerate(probs) if p < confidence_threshold or p > (1-confidence_threshold)]
        test = [test[i] for i in index]   
        predictions = predictions[index]     
        if len(index) == 0:
            print("no predictions")
            return None
    
    for i,d in enumerate(test):
        if i >= len(predictions):
            break
        # convert the graph to networkx
        g = to_networkx(d, remove_self_loops=True)
        # print(g.edges())
        # print(d['x'])
        input_node = np.where(d['x'][:,0]==1)[0][0]
        output_node = np.where(d['x'][:,1]==1)[0][0]
        # print(input_node, output_node)
        
        def distance(g, input_node, output_node):
            try: 
                return nx.shortest_path_length(g, source=input_node, target=output_node)
            except: 
                return 0

        stats.append({
            'nodes': d['x'].shape[0],
            'edges': d['edge_index'].shape[1] if d['edge_index'].shape[0] == 2 else 0,
            'cc': nx.average_clustering(g),
            'distance_io': distance(g, input_node, output_node),
            'distance_oi': distance(g, output_node, input_node),
            'input_centrality': nx.degree_centrality(g)[input_node],
            'output_centrality': nx.degree_centrality(g)[output_node],
        })

    assert len(stats)==len(predictions)
    df = pd.DataFrame(stats)
    return df


def metrics_from_predictions(predictions, test, confidence_threshold=None):
    
    df = compute_stats(predictions, test, confidence_threshold)
    
    metrics_dict = {}
    for metric_name, fun in {'f1': f1_score}.items():
        metrics_dict[metric_name] = {}
        
        def apply_metric(x):
            preds = predictions[x.index]
            labels = [d['y'] for d in [test[i] for i in x.index]]
            return fun(preds, torch.IntTensor(labels), task='binary')
        
        for grouping, bins in {
            'nodes': range(0, df.nodes.max(), 5),
            'edges': range(0, df.edges.max(), 20),
            'cc': [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],
            'distance_io': range(8),
            'distance_oi': range(8),
        }.items():
            if grouping not in ['distance_io', 'distance_oi']:
                df[f"{grouping}_bin"] = pd.cut(df[grouping], bins=bins)
                # set the bin id as the center of the bin
                df["bin_id"] = df[f"{grouping}_bin"].apply(lambda x: x.mid).astype(float)
                metrics_dict[metric_name][f'by_{grouping}'] = df.groupby("bin_id").apply(apply_metric)
            else:
                metrics_dict[metric_name][f'by_{grouping}'] = df.groupby(grouping).apply(apply_metric)
            
        metrics_dict[metric_name]['mean'] = apply_metric(df)

    # print(by_nodes)

    return metrics_dict, df

In [27]:
metrics_dicts = {}
for k, v in predictions_dict.items():
    metrics_dicts[k]= {}
    metrics_dicts[k], df = metrics_from_predictions(*predictions_dict[k])    

In [28]:
pickle.dump(metrics_dicts, open("metrics_dicts-2024-10.pkl", "wb"))

In [29]:
metrics_dicts = pickle.load(open("metrics_dicts-2024-10.pkl", "rb"))

In [15]:
metrics_dicts_aio = {}

# concatenate the results from the different folds in one list
for hold_out_by in ['UC1','UC2','UC3']:
    for embeddings_len in ['0','128']:
        metrics_dicts_aio[(hold_out_by,embeddings_len)] = {}
        for metric in ["f1"]:
            metrics_dicts_aio[(hold_out_by,embeddings_len)][metric] = {}
            for grouping in ["by_nodes", "by_edges",'by_cc', 'by_distance_io', 'by_distance_oi']: #'by_input_centrality', 'by_output_centrality']:
                if hold_out_by=='UC1' and embeddings_len=='0':
                    metrics_dicts_aio[(hold_out_by,embeddings_len)][metric][grouping] = pd.concat([metrics_dicts[(hold_out_by, embeddings_len, fold)][metric][grouping] for fold in [0,1,2]]).groupby(level=0).mean()
                else:
                    metrics_dicts_aio[(hold_out_by,embeddings_len)][metric][grouping] = pd.concat([metrics_dicts[(hold_out_by, embeddings_len, fold)][metric][grouping] for fold in range(4)]).groupby(level=0).mean()



In [30]:
titles = ['number of nodes', 'number of edges', 'clustering coefficient', 'I/O distance', 'O->I distance', 'input node centrality', 'ouptut node centrality']
bys = ['by_nodes', 'by_edges', 'by_cc', 'by_distance_io', 'by_distance_oi', 'input_centrality', 'output_centrality']

# Performance tables for high-confidence predictions

In [31]:
hc_metrics_dicts = {}
for k,v in predictions_dict.items():
    hc_metrics_dicts[k]= {}
    for fold in v.keys():
        metrics_dict=metrics_from_predictions(*predictions_dict[k][fold], confidence_threshold=1e-2)
        if metrics_dict is not None:
            hc_metrics_dicts[k][fold]=metrics_dict

AttributeError: 'tuple' object has no attribute 'keys'

In [32]:
hc_metrics_dicts = {}
threshold = 0.15
for k,v in predictions_dict.items():
    hc_metrics_dicts[k]= {}
    for fold in v.keys():
        metrics_dict=metrics_from_predictions(*predictions_dict[k][fold], confidence_threshold=threshold)
        if metrics_dict is not None:
            hc_metrics_dicts[k][fold]=metrics_dict

for k,v in hc_metrics_dicts.items():
    if len(v) == 0:
        continue
    try:
        scatter_metrics(k,v, prefix=f"hc{threshold}_metrics_wrt_graph_properties")
        line_plot_metrics(k,v, prefix=f"hc{threshold}_metrics_wrt_graph_properties")
    except:
        pass

4291 4292
<class 'list'>
4292 4292
<class 'list'>
4292 4292
<class 'list'>
4293 4293
<class 'list'>
4292 4292
<class 'list'>
4291 4292
<class 'list'>
4293 4293
<class 'list'>
4292 4292
<class 'list'>
4292 4292
<class 'list'>
4293 4293
<class 'list'>
4292 4292
<class 'list'>
4291 4292
<class 'list'>
3398 3398
<class 'list'>
3865 3865
<class 'list'>
4433 4433
<class 'list'>
2895 2896
<class 'list'>
4433 4433
<class 'list'>
2895 2896
<class 'list'>
3398 3398
<class 'list'>
3865 3865
<class 'list'>
3398 3398
<class 'list'>
3865 3865
<class 'list'>
4433 4433
<class 'list'>
2895 2896
<class 'list'>
4478 4479
<class 'list'>
5590 5590
<class 'list'>
no predictions
3806 3806
<class 'list'>
3294 3294
<class 'list'>
4478 4479
<class 'list'>
5590 5590
<class 'list'>
3806 3806
<class 'list'>
3294 3294
<class 'list'>
4478 4479
<class 'list'>
5590 5590
<class 'list'>
3806 3806
<class 'list'>
3294 3294
<class 'list'>


In [33]:
pickle.dump(hc_metrics_dicts, open(f"hc{threshold}_metrics_dicts.pkl", "wb"))

In [12]:
def table_from_metrics_dicts(metrics_dicts):
    table=[]
    for k in metrics_dicts:
        table.append({
            'hold_out_by': k[0],
            'features': k[1]
        })
        # print(len(metrics_dicts[k].keys()))
        for metric in metrics:
            # mean over the different folds
            sum = 0
            for fold in metrics_dicts[k].keys():
                sum += metrics_dicts[k][fold][metric]['mean']
            if len(metrics_dicts[k].keys()) > 0:
                table[-1][metric] = sum/len(metrics_dicts[k].keys())
            else:
                table[-1][metric] = np.nan
    return pd.DataFrame(table)

In [13]:
table_from_metrics_dicts(metrics_dicts)

Unnamed: 0,hold_out_by,features,accuracy,specificity,auroc,recall,precision,f1,mcc
0,random,0,tensor(0.8187),tensor(0.8841),tensor(0.8832),tensor(0.6821),tensor(0.7403),tensor(0.7092),tensor(0.5795)
1,random,128,tensor(0.8533),tensor(0.8840),tensor(0.9125),tensor(0.7892),tensor(0.7666),tensor(0.7776),tensor(0.6684)
2,random,onehot,tensor(0.8536),tensor(0.9082),tensor(0.9152),tensor(0.7396),tensor(0.8028),tensor(0.7683),tensor(0.6642)
3,protein,0,tensor(0.6212),tensor(0.7933),tensor(0.5872),tensor(0.2649),tensor(0.3736),tensor(0.3022),tensor(0.0636)
4,protein,128,tensor(0.7794),tensor(0.8623),tensor(0.8231),tensor(0.5965),tensor(0.6710),tensor(0.6298),tensor(0.4754)
5,protein,onehot,tensor(0.7665),tensor(0.8078),tensor(0.8238),tensor(0.6756),tensor(0.6223),tensor(0.6469),tensor(0.4742)
6,model,0,tensor(0.5903),tensor(0.7576),tensor(0.5596),tensor(0.2436),tensor(0.3365),tensor(0.2513),tensor(0.0058)
7,model,128,tensor(0.6719),tensor(0.7859),tensor(0.6386),tensor(0.4323),tensor(0.4928),tensor(0.4594),tensor(0.2272)
8,model,onehot,tensor(0.6921),tensor(0.7919),tensor(0.6509),tensor(0.4841),tensor(0.5271),tensor(0.5033),tensor(0.2829)


In [None]:
table_from_metrics_dicts(hc_metrics_dicts)

Unnamed: 0,hold_out_by,features,accuracy,specificity,auroc,recall,precision,f1,mcc
0,random,0,tensor(0.9110),tensor(0.9510),tensor(0.9239),tensor(0.8012),tensor(0.8618),tensor(0.8299),tensor(0.7709)
1,random,128,tensor(0.9185),tensor(0.9454),tensor(0.9388),tensor(0.8552),tensor(0.8692),tensor(0.8621),tensor(0.8044)
2,random,onehot,tensor(0.9223),tensor(0.9575),tensor(0.9388),tensor(0.8299),tensor(0.8912),tensor(0.8591),tensor(0.8065)
3,protein,0,tensor(0.9155),tensor(0.9958),tensor(0.2265),tensor(0.0458),tensor(0.1774),tensor(0.0728),tensor(0.6237)
4,protein,128,tensor(0.8676),tensor(0.9396),tensor(0.8537),tensor(0.6384),tensor(0.7779),tensor(0.6941),tensor(0.6200)
5,protein,onehot,tensor(0.8150),tensor(0.8526),tensor(0.8536),tensor(0.7269),tensor(0.6895),tensor(0.7074),tensor(0.5723)
6,model,0,tensor(0.5166),tensor(0.6588),tensor(0.5509),tensor(0.3419),tensor(0.1014),tensor(0.0854),tensor(0.0095)
7,model,128,tensor(0.7538),tensor(0.9163),tensor(0.5650),tensor(0.2813),tensor(0.5777),tensor(0.3598),tensor(0.2594)
8,model,onehot,tensor(0.7260),tensor(0.8420),tensor(0.6421),tensor(0.4659),tensor(0.5707),tensor(0.5116),tensor(0.3277)


# Most significant plots (export for article)

## Performance w.r.t. CC, nodes, edges

In [32]:
def set_fonts(ax, skip_artists=2):
    ax.legend(['I/O', 'I/O+emb', '',''])
    artists = ax.legend_.legendHandles
    # remove the error bars from the legend
    artists = [artist for i, artist in enumerate(artists) if i%skip_artists==0]
    ax.legend(artists, ['I/O', 'I/O+emb'], loc='upper right', prop={'size': 20}, )
    # set the legend background alpha
    ax.legend_.get_frame().set_alpha(0.4)

    # increase the font of axis labels and legend
    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(16)
    for item in [ax.title, ax.xaxis.label, ax.yaxis.label]:
        item.set_fontsize(24)
    # avoid cutting exis title

In [34]:
# plot the mcc w.r.t. the number of nodes, edges and cc with a regplot
# for by in ['by_nodes', 'by_edges', 'by_cc']:
for by in ['by_cc','by_nodes','by_edges']:
    for metric in ['f1']:
        for hold_out_by in ['UC1','UC2','UC3']:
            fig, ax = plt.subplots(figsize=(8,6))

            for embeddings_len in ['0', '128']:
                data=[]
                for fold in range(4):
                    data.append(metrics_dicts[(hold_out_by,embeddings_len,fold)][metric][by].sort_index())
                data = pd.concat(data, axis=1)
                data = data.mean(axis=1)

                sns.regplot(x=data.index, y=np.array(data, dtype=float), ax=ax, order=2)
                
                ax.set_ylim([0,1])
                ax.set_ylabel(metric.upper())
                ax.set_xlabel(by.replace('by_',''))
                
                if by == 'by_cc':
                    ax.set_xlim([0,1])
                elif by == 'by_nodes':
                    ax.set_xlim([2,None])
                elif by == 'by_edges':
                    ax.set_xlim([1,None])

                fig.tight_layout(pad=3.0)
                
            set_fonts(ax, skip_artists=3)
            plt.savefig(f'figures_2024-10/f1_{by}_{hold_out_by}.pdf', format='pdf')
            plt.close()

  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  artists = ax.legend_.legendHandles
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  artists = ax.legend_.legendHandles
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles


In [143]:
# plot the mcc w.r.t. the number of nodes, edges and cc with a regplot
for by in ['by_nodes', 'by_edges']:
    for metric in ['f1']:
        for hold_out_by in ['UC1','UC2','UC3']:
            fig, ax = plt.subplots(figsize=(8,6))

            for embeddings_len in ['0', '128']:
                data=[]
                for fold in (range(4) if not (hold_out_by=='UC1' and embeddings_len=='0') else [0,1,2]):
                    data.append(metrics_dicts[(hold_out_by,embeddings_len,fold)][metric][by].sort_index())
                data = pd.concat(data, axis=1)
                data = data.mean(axis=1)
                # group samples in bins
                data = data.groupby(pd.cut(data.index, bins=20)).mean()
                # convert the intevals in the middle of the interval
                data.index = [(i.left+i.right)/2 for i in data.index]
                # sns.regplot(x=data.index, y=np.array(data, dtype=float), ax=ax, order=2)
                # plot the moving average
                avg_data = data.rolling(window=3)
                # smooth the data
                ax.plot(avg_data.mean())
                # plot also the standard deviation of the moving average
                ax.fill_between(data.index, avg_data.mean() - avg_data.std(), avg_data.mean() + avg_data.std(), alpha=0.2)
                ax.set_ylabel(metric.upper())
                ax.set_xlabel(by.replace('by_',''))
                # ax.set_title(f'{metric} w.r.t. {by}')

                fig.tight_layout(pad=3.0)
                
            set_fonts(ax)
            plt.savefig(f'figures_2024-10/f1_{by}_{hold_out_by}.pdf', format='pdf')
            plt.close()

  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles


## Performance w.r.t. I->O and O->I distance

In [144]:
# plot the f1 score w.r.t. I->O distance one a line plot, averaging over the folds and showing the standard deviation
for direction in ['io', 'oi']:
    for hold_out_by in ['UC1','UC2','UC3']:
        fig, ax = plt.subplots(1, 1, figsize=(8,6))
        for embeddings_len in ['0', '128']:
            data = []
            for fold in range(4):
                data.append(metrics_dicts[(hold_out_by,embeddings_len,fold)]['f1'][f'by_distance_{direction}'].sort_index())
            data = pd.concat(data, axis=1)
            means = data.mean(axis=1).sort_index()
            stds =  data.std(axis=1).sort_index()
            # ax.plot(means, label=k)
            # replace inf and nans with 0
            index = means.index
            stds = np.array(stds, dtype=float)
            means = np.array(means, dtype=float)
            # ax.fill_between(index, means - stds, means + stds, alpha=0.2)
            ax.set_ylim([0,1.1])
            # ax.set_title(f'f1 w.r.t. I->O distance')
            ax.set_ylabel('F1')
            fig.tight_layout(pad=3.0)
            # increase font size
            ax.set_xlim([1,7])
            
            sns.regplot(x=index, y=means, ax=ax, scatter=True, order=2)
            ax.set_xlabel('Path length from $u_{in}$ to $u_{out}$')
        
        set_fonts(ax, skip_artists=3)
        
        plt.savefig(f'figures_2024-10/f1_by_distance_{direction}_{hold_out_by}.pdf', format='pdf')
        plt.close()


  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
  boot_dist.append(f(*sample, **func_kwargs))
  boot_dist.append(f(*sample, **func_kwargs))
  artists = ax.legend_.legendHandles
  artists = ax.legend_.legendHandles
