# Imports

In [19]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [22]:
import json
from lightning import Trainer
from torchmetrics.functional import accuracy, specificity, auroc, recall, precision, f1_score, matthews_corrcoef
from pathlib import Path
from torch_geometric.utils.convert import to_networkx
import seaborn as sns
import time
import os
import sys
from tqdm import tqdm
import ray
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

sys.path.insert(0, os.path.abspath('..'))
from src.model_development.kfold import *
from src.model_development.datamodule import *
from src.model_development.dgn.dgn import DGN
from src.model_development.dgn.deepsets import DeepSetsModule

# Data loading

In [23]:
# Load the project-level settings from the repository root.
# A context manager closes the file handle as soon as the JSON is parsed.
with open('../config.json', 'r') as config_file:
    config = json.load(config_file)

In [24]:
# Root directory that is searched recursively for the trial folders.
base_path = Path(config["ray_results_path"])

# Mapping: experiment key -> trial name (iterated with .items() further down).
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
with open("gcn_best_trial_names.pkl", "rb") as trial_names_file:
    trial_names = pickle.load(trial_names_file)

In [25]:
# Name of the sub-directory (inside a trial folder) that is searched for checkpoints.
# project = "Master-Thesis"
project = "peppina-final"
# project = "gnn-ppi-sens"

def find_trial_ckpts(name, project=project):
    """Locate the stored parameters and the checkpoint files of one trial.

    Args:
        name: Substring identifying the trial directory somewhere below ``base_path``.
        project: Sub-directory of the trial folder that is searched for ``*.ckpt`` files.

    Returns:
        dict with two entries:
            'config': the object unpickled from ``<trial dir>/params.pkl``;
            'ckpts': list of paths, first every ``*last.ckpt`` match and then every
                ``*.ckpt`` match (so a ``last.ckpt`` file is listed twice).

    Raises:
        IndexError: if no directory below ``base_path`` matches ``name``.
    """
    # Stop at the first match instead of materialising the whole recursive listing.
    trial_dir = next(iter(base_path.rglob(f"*{name}*/")), None)
    if trial_dir is None:
        # Same exception type as the former `[...][0]` lookup, but with a usable message.
        raise IndexError(f"no trial directory matching '*{name}*' below {base_path}")

    # NOTE: unpickling can execute arbitrary code; only load trial folders you trust.
    with (trial_dir / "params.pkl").open("rb") as params_file:
        config = pickle.load(params_file)

    project_dir = trial_dir / project
    paths = list(project_dir.rglob("*last.ckpt")) + list(project_dir.rglob("*.ckpt"))

    return {
        'config': config,
        'ckpts': paths
    }

In [26]:
# Resolve the stored config and checkpoint paths of every selected trial,
# keeping the same keys as `trial_names`.
ckpts = {
    trial_key: find_trial_ckpts(trial_name)
    for trial_key, trial_name in tqdm(trial_names.items())
}

  0%|          | 0/36 [00:00<?, ?it/s]

100%|██████████| 36/36 [02:59<00:00,  5.00s/it]


In [27]:
# Sanity check: show the second checkpoint entry of the four ('UC1', '128', fold) trials.
uc1_emb_keys = [('UC1', '128', fold_id) for fold_id in range(4)]
for trial_key in uc1_emb_keys:
    print(ckpts[trial_key]['ckpts'][1])

/data/ray_results/train_gcn_2024-12-05_10-34-29/train_gcn_81de7_00000_0_activation=sigmoid,aggr=add,batch_size=4096,bias=True,bias_scale=0.1000,bn=False,conv=DirGNNConv,dataloade_2024-12-05_10-34-29/peppina-final/1l2ce4ql/checkpoints/last.ckpt
/data/ray_results/train_gcn_2024-12-05_10-34-29/train_gcn_81de7_00001_1_activation=sigmoid,aggr=add,batch_size=4096,bias=True,bias_scale=0.1000,bn=False,conv=DirGNNConv,dataloade_2024-12-05_10-34-29/peppina-final/gn6zhovx/checkpoints/last.ckpt
/data/ray_results/train_gcn_2024-12-05_10-34-29/train_gcn_81de7_00002_2_activation=sigmoid,aggr=add,batch_size=4096,bias=True,bias_scale=0.1000,bn=False,conv=DirGNNConv,dataloade_2024-12-05_11-01-47/peppina-final/bs3862b7/checkpoints/last.ckpt
/data/ray_results/train_gcn_2024-12-02_15-05-07/train_gcn_d1689_00000_0_activation=sigmoid,aggr=add,batch_size=4096,bias=True,bias_scale=0.1000,bn=False,conv=DirGNNConv,dataloade_2024-12-02_15-05-07/peppina-final/s2zd54wp/checkpoints/last.ckpt


In [29]:
# Fold dictionaries keyed by (hold_out_by, embeddings_len).
folds = {}

# Shut down any Ray session left over from a previous run of this cell before starting a new one.
ray.shutdown()
ray.init(ignore_reinit_error=True, _temp_dir="/data/adipalma/tmp")

for hold_out_by in ['UC1', 'UC2', 'UC3']:
    for embeddings_len in ['0', '128']:  # ,'onehot']:
        # Use a dedicated name here: rebinding the notebook-level `config` (loaded from
        # ../config.json) would break re-running the cell that builds `base_path`.
        fold_config = ckpts[(hold_out_by, embeddings_len, 0)]['config']
        # The stored config of fold 0 is patched in place (same dict object as in `ckpts`).
        fold_config['biogrid_ver'] = '2024-10'
        # NOTE(review): the output saved with this cell ends in KeyError: 'use_case';
        # presumably create_fold_dict expects a key these stored configs lack - confirm.
        folds[(hold_out_by, embeddings_len)] = create_fold_dict(config=fold_config)

2025-01-31 18:57:25,690	INFO worker.py:1841 -- Started a local Ray instance.


KeyError: 'use_case'

In [9]:
# Lookup table: short metric name -> torchmetrics functional implementation.
metrics = dict(
    accuracy=accuracy,
    specificity=specificity,
    auroc=auroc,
    recall=recall,
    precision=precision,
    f1=f1_score,
    mcc=matthews_corrcoef,
)

In [7]:
# Restrict CUDA to the device with index 1 (this limits visibility, it does not hide every GPU).
# NOTE(review): presumably this has to run before CUDA is first initialised to take effect -
# confirm the position of this cell relative to the imports.
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

In [None]:
def load_and_predict(config, ckpt, hold_out_by, features):
    """Restore a trained model from a checkpoint and run it on its test split (on CPU).

    Args:
        config: Trial parameters (mapping). The caller's object is left untouched;
            a shallow copy is patched internally.
        ckpt: Path of the checkpoint to restore.
        hold_out_by: First component of the key into the notebook-level ``folds`` dict.
        features: Second component of the key into ``folds``.

    Returns:
        Tuple ``(predictions, test_data)``: the concatenated outputs of
        ``trainer.predict`` and the ``test_data`` attribute of the data module.

    Raises:
        ValueError: if ``config['model']`` is neither ``'gcn'`` nor ``'deepsets'``.
    """
    # Work on a copy so the shared trial config stored in `ckpts` is not modified.
    config = dict(config)

    hold_out = config['hold_out_by']
    if isinstance(hold_out, dict):
        # The value may still be a grid-search spec; use its first entry.
        hold_out = hold_out['grid_search'][0]

    fold = ray.get(folds[(hold_out_by, features)][0]
                   [('sensitivity', hold_out)]
                   [config['test_fold']])

    # Fill in keys that some stored configs do not contain.
    if 'aggr' not in config:
        config['aggr'] = config['SAGE_aggr']
    config.setdefault('uniform_bound', None)
    config.setdefault('weight_initializer', 'kaiming_uniform')
    # Batch size used for prediction, overriding the stored value.
    config['batch_size'] = 10000

    test = fold['test']
    inner_split = fold['train'][config["val_fold"]]
    train = inner_split['train']
    val = inner_split['val']

    # Built once (the former second, identical construction + setup was redundant).
    data = GraphDataModule(train, val, test, config)
    data.setup()

    input_dim = train[0].x.shape[1]
    output_dim = 1

    model_classes = {'gcn': DGN, 'deepsets': DeepSetsModule}
    if config['model'] not in model_classes:
        # Previously this fell through and failed later with an unbound `model`.
        raise ValueError(f"unsupported model type: {config['model']!r}")

    start = time.time()
    model = model_classes[config['model']].load_from_checkpoint(
        ckpt,
        input_dim=input_dim,
        output_dim=output_dim,
        config=config,
        map_location=torch.device('cpu'),
    )
    trainer = Trainer(enable_progress_bar=False, accelerator="cpu")
    print("Time to load model: ", time.time() - start)

    start = time.time()
    predictions = trainer.predict(model, data.test_dataloader())
    print("Time to predict: ", time.time() - start)

    # trainer.predict returns one entry per batch; join them into a single tensor.
    predictions = torch.cat(predictions)

    return predictions, data.test_data

In [None]:
# Run inference for every trial; keys mirror `ckpts`.
predictions_dict = {}
for k, v in ckpts.items():

    # Empty placeholder so that every key exists even when the trial cannot be evaluated.
    predictions_dict[k] = {}
    if len(v['ckpts']) == 0:
        print(f"no checkpoint for {k} fold {k[2]}")
        # Nothing to load: skip instead of failing on v['ckpts'][-1] below.
        continue
    try:
        # NOTE(review): [-1] takes the last path found by the '*.ckpt' glob, which is
        # not necessarily last.ckpt - confirm this is the intended checkpoint.
        predictions_dict[k] = load_and_predict(v['config'], v['ckpts'][-1], k[0], k[1])
    except Exception as e:
        # Best effort: report the failure and go on with the remaining trials.
        print(f"error for {k} fold {k[2]}")
        print(e)

In [None]:
# Persist the predictions so later cells can reload them without re-running inference.
# The context manager guarantees the file is flushed and closed.
with open("predictions_dict.pkl", "wb") as predictions_file:
    pickle.dump(predictions_dict, predictions_file)

In [10]:
# Reload the predictions written to predictions_dict.pkl.
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
with open("predictions_dict.pkl", "rb") as predictions_file:
    predictions_dict = pickle.load(predictions_file)

# Metrics

In [26]:
def compute_stats(predictions, test, confidence_threshold=None):
    """Compute structural statistics for every test graph that has a prediction.

    Args:
        predictions: Tensor of model outputs, one entry per graph; only its length is
            used unless ``confidence_threshold`` is given (then it goes through a sigmoid).
        test: Sequence of graphs that support ``d['x']`` / ``d['edge_index']`` and can be
            converted with ``to_networkx``.
        confidence_threshold: If given, keep only the samples whose sigmoid output is
            below the threshold or above ``1 - threshold``.

    Returns:
        ``pandas.DataFrame`` with one row per retained graph and the columns ``nodes``,
        ``edges``, ``cc``, ``distance_io``, ``distance_oi``, ``input_centrality`` and
        ``output_centrality``; ``None`` if the threshold discards every sample.
    """
    if confidence_threshold is not None:
        probs = torch.sigmoid(predictions)
        index = [i for i, p in enumerate(probs) if p < confidence_threshold or p > (1-confidence_threshold)]
        # Bail out before indexing anything with an empty selection.
        if len(index) == 0:
            print("no predictions")
            return None
        test = [test[i] for i in index]
        predictions = predictions[index]

    def distance(g, source, target):
        # Directed shortest-path length; 0 when `target` cannot be reached from `source`.
        # Only the "no path" conditions are caught so that genuine errors stay visible.
        try:
            return nx.shortest_path_length(g, source=source, target=target)
        except (nx.NetworkXNoPath, nx.NodeNotFound):
            return 0

    stats = []
    for i, d in enumerate(test):
        # Ignore graphs beyond the number of available predictions.
        if i >= len(predictions):
            break
        # convert the graph to networkx
        g = to_networkx(d, remove_self_loops=True)
        # First node whose feature column 0 (resp. 1) equals 1.
        input_node = np.where(d['x'][:,0]==1)[0][0]
        output_node = np.where(d['x'][:,1]==1)[0][0]
        # Computed once per graph and used for both endpoints.
        centrality = nx.degree_centrality(g)

        stats.append({
            'nodes': d['x'].shape[0],
            'edges': d['edge_index'].shape[1] if d['edge_index'].shape[0] == 2 else 0,
            'cc': nx.average_clustering(g),
            'distance_io': distance(g, input_node, output_node),
            'distance_oi': distance(g, output_node, input_node),
            'input_centrality': centrality[input_node],
            'output_centrality': centrality[output_node],
        })

    # Fails when `test` holds fewer graphs than there are predictions.
    assert len(stats)==len(predictions)
    return pd.DataFrame(stats)


def metrics_from_predictions(predictions, test, confidence_threshold=None):
    """Compute F1 and MCC overall and grouped by structural properties of the graphs.

    Args:
        predictions: Tensor of model outputs, one entry per test graph.
        test: Sequence of test graphs; ``d['y']`` is read as the label of graph ``d``.
        confidence_threshold: If given, only samples whose sigmoid output is below the
            threshold or above ``1 - threshold`` are evaluated.

    Returns:
        Tuple ``(metrics_dict, df)``. ``metrics_dict[metric]`` (``'f1'`` / ``'mcc'``) holds
        ``'mean'`` plus ``'by_nodes'``, ``'by_edges'``, ``'by_cc'``, ``'by_distance_io'``
        and ``'by_distance_oi'``; ``df`` is the statistics frame from ``compute_stats``
        with the helper columns ``*_bin`` and ``bin_id`` added.

    Raises:
        ValueError: if ``confidence_threshold`` discards every sample.
    """
    if confidence_threshold is not None:
        # Filter here so that predictions, labels and the statistics rows stay aligned
        # (previously only the statistics were computed on the filtered samples).
        probs = torch.sigmoid(predictions)
        keep = [i for i, p in enumerate(probs) if p < confidence_threshold or p > (1 - confidence_threshold)]
        if not keep:
            raise ValueError("no predictions pass the confidence threshold")
        test = [test[i] for i in keep]
        predictions = predictions[keep]

    df = compute_stats(predictions, test)

    # Labels are the same for every metric, so build them once.
    all_labels = torch.IntTensor([d['y'] for d in test])

    def apply_metric(fun, rows):
        # The frame's index values are positions into `predictions` / `all_labels`.
        idx = torch.as_tensor(rows.index.to_numpy(), dtype=torch.long)
        return fun(predictions[idx], all_labels[idx], task='binary')

    def bin_edges(max_value, step):
        # Edges 0, step, 2*step, ... with the last edge >= max_value, so the largest
        # graphs are not dropped (range(0, max, step) stopped below the maximum).
        return list(range(0, int(max_value) + step + 1, step))

    metric_functions = {'f1': f1_score, 'mcc': matthews_corrcoef}

    metrics_dict = {}
    for metric_name, fun in metric_functions.items():
        metrics_dict[metric_name] = {'mean': apply_metric(fun, df)}
        print(metrics_dict[metric_name]['mean'])

    # NOTE(review): pd.cut intervals are open on the left by default, so values equal to
    # the first edge (e.g. cc == 0 or edges == 0) fall in no bin - confirm this is intended.
    binned_groupings = {
        'nodes': bin_edges(df.nodes.max(), 5),
        'edges': bin_edges(df.edges.max(), 20),
        'cc': [0,0.1,0.2,0.3,0.4,0.5,0.6,0.7,0.8,0.9,1],
    }
    for grouping, bins in binned_groupings.items():
        df[f"{grouping}_bin"] = pd.cut(df[grouping], bins=bins)
        # set the bin id as the center of the bin
        df["bin_id"] = df[f"{grouping}_bin"].apply(lambda x: x.mid).astype(float)
        grouped = df.groupby("bin_id")
        for metric_name, fun in metric_functions.items():
            metrics_dict[metric_name][f'by_{grouping}'] = grouped.apply(
                lambda rows, fun=fun: apply_metric(fun, rows))

    # Distances are grouped by their exact value, without binning.
    for grouping in ('distance_io', 'distance_oi'):
        grouped = df.groupby(grouping)
        for metric_name, fun in metric_functions.items():
            metrics_dict[metric_name][f'by_{grouping}'] = grouped.apply(
                lambda rows, fun=fun: apply_metric(fun, rows))

    return metrics_dict, df

In [27]:
# Compute the metric tables for every trial that has predictions.
metrics_dicts = {}
for k in predictions_dict:
    if k[1] == 'onehot':
        continue
    metrics_dicts[k] = {}
    print(k)
    if not predictions_dict[k]:
        # Empty placeholder: no predictions were produced for this trial.
        print(f"no predictions for {k}")
        continue
    try:
        metrics_dicts[k], _ = metrics_from_predictions(*predictions_dict[k])
    except Exception as e:
        # Best effort: report the failure instead of silently leaving an empty entry.
        print(f"error for {k}")
        print(e)

In [28]:
# Persist the metric tables so the plotting cells can reload them.
# The context manager guarantees the file is flushed and closed.
with open("metrics_dicts.pkl", "wb") as metrics_file:
    pickle.dump(metrics_dicts, metrics_file)

In [29]:
# Reload the metric tables written to metrics_dicts.pkl.
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
with open("metrics_dicts.pkl", "rb") as metrics_file:
    metrics_dicts = pickle.load(metrics_file)

# Most significant plots (export for article)

## Performance w.r.t. CC, nodes, edges

In [32]:
def set_fonts(ax, skip_artists=2):
    """Rebuild the legend of ``ax`` with two entries and enlarge its fonts.

    Args:
        ax: Matplotlib axes that already contains the plotted artists.
        skip_artists: Only the legend handles at positions 0, ``skip_artists``,
            ``2 * skip_artists``, ... are kept (e.g. to drop error bands from the legend).

    Returns:
        None. ``ax`` is modified in place.
    """
    # Create a provisional legend just to obtain one handle per plotted artist.
    ax.legend(['I/O', 'I/O+emb', '', ''])
    legend = ax.legend_
    # `legendHandles` was renamed to `legend_handles` in newer Matplotlib releases;
    # support both so the notebook runs on either side of the rename.
    artists = getattr(legend, 'legend_handles', None)
    if artists is None:
        artists = legend.legendHandles
    # remove the error bars from the legend
    artists = [artist for i, artist in enumerate(artists) if i % skip_artists == 0]
    ax.legend(artists, ['I/O', 'I/O+emb'], loc='upper right', prop={'size': 20})
    # set the legend background alpha
    ax.legend_.get_frame().set_alpha(0.4)

    # increase the font of the tick labels, the title and the axis labels
    for item in (ax.get_xticklabels() + ax.get_yticklabels()):
        item.set_fontsize(16)
    for item in [ax.title, ax.xaxis.label, ax.yaxis.label]:
        item.set_fontsize(24)

In [34]:
# Plot the MCC w.r.t. the clustering coefficient, the number of nodes and the number
# of edges: per-bin scores are averaged over the four folds and drawn with a
# second-order regplot, one figure per (grouping, hold-out setting).
# for by in ['by_nodes', 'by_edges', 'by_cc']:
for by in ['by_cc','by_nodes','by_edges']:
    for metric in ['mcc']:
        for hold_out_by in ['UC1','UC2','UC3']:
            fig, ax = plt.subplots(figsize=(8,6))

            # One regression curve per feature setting ('0' and '128').
            for embeddings_len in ['0', '128']:
                data=[]
                for fold in range(4):
                    data.append(metrics_dicts[(hold_out_by,embeddings_len,fold)][metric][by].sort_index())
                # One column per fold, aligned on the bin id, then averaged row-wise.
                data = pd.concat(data, axis=1)
                data = data.mean(axis=1)

                sns.regplot(x=data.index, y=np.array(data, dtype=float), ax=ax, order=2)
                
                ax.set_ylim([0,1])
                ax.set_ylabel(metric.upper())
                ax.set_xlabel(by.replace('by_',''))
                
                # Only the lower x-limit is fixed for nodes/edges; the upper limit is left
                # as it is after the regplot call above, so the statement order matters.
                if by == 'by_cc':
                    ax.set_xlim([0,1])
                elif by == 'by_nodes':
                    ax.set_xlim([2,None])
                elif by == 'by_edges':
                    ax.set_xlim([1,None])

                fig.tight_layout(pad=3.0)
                
            # Keep every third legend handle.
            # NOTE(review): presumably each regplot adds three artists (points, line, band) - confirm.
            set_fonts(ax, skip_artists=3)
            # NOTE(review): the other plotting cells write to 'figures_2024-10/' - confirm that
            # 'figures/' is the intended output directory (it must already exist).
            plt.savefig(f'figures/{metric}_{by}_{hold_out_by}.pdf', format='pdf')
            plt.close()

In [143]:
# Plot the F1 score w.r.t. the number of nodes and edges as a moving average
# (with a +/- one rolling standard deviation band), averaged over the folds.
for by in ['by_nodes', 'by_edges']:
    for metric in ['f1']:
        for hold_out_by in ['UC1','UC2','UC3']:
            fig, ax = plt.subplots(figsize=(8,6))

            for embeddings_len in ['0', '128']:
                data=[]
                # NOTE(review): ('UC1', '0') only uses folds 0-2; presumably fold 3 has no
                # metrics for that setting - confirm.
                for fold in (range(4) if not (hold_out_by=='UC1' and embeddings_len=='0') else [0,1,2]):
                    data.append(metrics_dicts[(hold_out_by,embeddings_len,fold)][metric][by].sort_index())
                # One column per fold, aligned on the bin id, then averaged row-wise.
                data = pd.concat(data, axis=1)
                data = data.mean(axis=1)
                # group samples in 20 equal-width bins
                data = data.groupby(pd.cut(data.index, bins=20)).mean()
                # replace each interval with its midpoint
                data.index = [(i.left+i.right)/2 for i in data.index]
                # sns.regplot(x=data.index, y=np.array(data, dtype=float), ax=ax, order=2)
                # moving window over three consecutive bins
                avg_data = data.rolling(window=3)
                # plot the moving average
                ax.plot(avg_data.mean())
                # plot also the standard deviation of the moving average
                ax.fill_between(data.index, avg_data.mean() - avg_data.std(), avg_data.mean() + avg_data.std(), alpha=0.2)
                ax.set_ylabel(metric.upper())
                ax.set_xlabel(by.replace('by_',''))
                # ax.set_title(f'{metric} w.r.t. {by}')

                fig.tight_layout(pad=3.0)
                
            # Default skip_artists=2: keeps every second legend handle.
            set_fonts(ax)
            # The file name hard-codes 'f1', which matches the only metric plotted here.
            plt.savefig(f'figures_2024-10/f1_{by}_{hold_out_by}.pdf', format='pdf')
            plt.close()

## Performance w.r.t. I->O and O->I distance

In [144]:
# Plot each metric w.r.t. the shortest-path length between the input and the output
# node (both directions): per-distance scores are averaged over the four folds and
# drawn with a second-order regplot, one figure per (direction, hold-out, metric).
distance_labels = {
    'io': 'Path length from $u_{in}$ to $u_{out}$',
    # distance_oi is measured from the output node to the input node.
    'oi': 'Path length from $u_{out}$ to $u_{in}$',
}

for direction in ['io', 'oi']:
    for hold_out_by in ['UC1','UC2','UC3']:
        for metric in ['f1','mcc']:
            fig, ax = plt.subplots(1, 1, figsize=(8,6))
            for embeddings_len in ['0', '128']:
                # Read the metric of the current iteration; this used to be hard-coded
                # to 'f1', so the files named mcc_* contained F1 curves.
                per_fold = [
                    metrics_dicts[(hold_out_by,embeddings_len,fold)][metric][f'by_distance_{direction}'].sort_index()
                    for fold in range(4)
                ]
                # One column per fold, aligned on the distance, then averaged row-wise.
                means = pd.concat(per_fold, axis=1).mean(axis=1).sort_index()
                index = means.index
                means = np.array(means, dtype=float)

                # NOTE(review): MCC can be negative; the lower limit of 0 would clip such values.
                ax.set_ylim([0,1.1])
                ax.set_ylabel(metric.upper())
                fig.tight_layout(pad=3.0)
                ax.set_xlim([1,7])

                sns.regplot(x=index, y=means, ax=ax, scatter=True, order=2)
                ax.set_xlabel(distance_labels[direction])

            # Keep every third legend handle.
            set_fonts(ax, skip_artists=3)

            plt.savefig(f'figures_2024-10/{metric}_by_distance_{direction}_{hold_out_by}.pdf', format='pdf')
            plt.close(fig)