In [None]:
import os
os.environ["KMP_DUPLICATE_LIB_OK"]="TRUE"
from typing import List, Dict, Tuple
import io
import sys
import pickle
import itertools
import datetime
import copy
from tqdm import tqdm
import random
import csv
import json
import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import math

import torch

import torch.optim.lr_scheduler as lr_scheduler
import torch_geometric
import torch_geometric.transforms as T
import torch_geometric.transforms
import torch_geometric.datasets
import torch_geometric.nn
from torch_geometric.utils import to_networkx

from IPython.display import display, HTML
import warnings
warnings.filterwarnings('ignore', module='sklearn')

In [None]:
from env_variables import *
from gnn_models import *
from utils_helpers import *

In [None]:
# Directory that receives every result artefact written by this notebook.
results_output_dir = "./results"
# exist_ok=True makes the cell idempotent and removes the check-then-create race.
os.makedirs(results_output_dir, exist_ok=True)

# Deep Learning with PyTorch
Package versions:
- torchmetrics==0.5.0
- pytorch-lightning==1.5.10

In [None]:
def build_param_combos_gnn(number_edge_features):
    """Build the hyper-parameter grid used to configure the GNN edge classifier.

    Every hyper-parameter is declared as a list of candidate values; the grid is
    the Cartesian product of those lists, in the same positional order as the
    parameters of ``build_params_dict_gnn``. With the single-valued lists below
    the grid contains exactly one combination.

    :param number_edge_features: input width of the initial edge model
        (number of raw features per edge).
    :return: list of 26-element tuples, one per hyper-parameter combination.
    """
    # ``torch.nn`` and ``itertools.product`` are reached through the modules this
    # notebook imports explicitly, so the function does not rely on names that only
    # exist when a ``from ... import *`` happens to export them.
    # One activation instance is shared by every sub-model slot.
    nonlinearity_common = torch.nn.LeakyReLU(inplace=True, negative_slope=0.2)  # nn.Tanh()
    use_batchnorm_list = [False]
    # Initial edge model
    initial_edge_model_input_dim_list = [number_edge_features]  # [4]  # fixed
    edge_dim_list = [16]  # [16, 32]
    fc_dims_initial_edge_model_multipliers_list = [(1, 1)]  # (1, 1)
    nonlinearity_initial_edge_list = [nonlinearity_common]

    # Initial node model
    fc_dims_initial_node_model_multipliers_list = [(2, 4, 1)]  # (1, 2, 1), (2,4,1)
    nonlinearity_initial_node_list = [nonlinearity_common]

    # Edge model
    fc_dims_edge_model_multipliers_list = [(4, 1)]  # (6,4,1), (4,1)
    nonlinearity_edge_list = [nonlinearity_common]

    # TimeAware Node model
    fc_dims_directed_flow_model_multipliers_list = [(2, 1)]  # (4,2,1), (2,1); (4,2,2,1) for online
    nonlinearity_directed_flow_list = [nonlinearity_common]
    node_model_agg_list = ["max"]  # ["max", "attention", "attention_classifier", "attention_normalized"]
    # multiplies node_dim, last layer output is always 1 [None, (2,1)]
    fc_dims_node_attention_model_multipliers_list = [None]

    # (4,2,1), (6,4,2,1)  (2,1) online - just an extra MLP
    fc_dims_total_flow_model_multipliers_list = [(4, 2, 1)]
    nonlinearity_total_flow_list = [nonlinearity_common]

    # Edge classification model
    fc_dims_edge_classification_model_multipliers_list = [
        (4, 2, 1,)]  # (2,1) [(0.5, ), None]  # multiplies edge_dim
    nonlinearity_edge_classification_list = [nonlinearity_common]

    mpn_steps_list = [4]  # [args.mpn_steps] -> number of message passing steps, default = 4
    is_recurrent_list = [True]
    node_dim_multiplier_list = [2]

    use_timeaware_list = [False]
    use_same_frame_list = [False]  # [not args.no_sameframe] -> default False
    use_separate_edge_model_list = [False]
    use_initial_node_model_list = [True]
    edge_mlps_count_list = [3]  # [args.edge_mlps_count] -> number of distinct node MLPs, default = 3
    node_aggr_sections_list = [3]  # [args.node_aggr_sections] -> number of distinct sections in node aggregation, default = 3

    # The argument order below is positional and must stay aligned with the
    # parameter order of build_params_dict_gnn.
    param_combos = list(itertools.product(
        initial_edge_model_input_dim_list,
        edge_dim_list, fc_dims_initial_edge_model_multipliers_list, nonlinearity_initial_edge_list,
        fc_dims_initial_node_model_multipliers_list, nonlinearity_initial_node_list,
        node_model_agg_list, fc_dims_node_attention_model_multipliers_list,
        fc_dims_edge_model_multipliers_list, nonlinearity_edge_list,
        fc_dims_directed_flow_model_multipliers_list, nonlinearity_directed_flow_list,
        fc_dims_total_flow_model_multipliers_list, nonlinearity_total_flow_list,
        fc_dims_edge_classification_model_multipliers_list, nonlinearity_edge_classification_list,
        use_batchnorm_list,
        mpn_steps_list, is_recurrent_list, node_dim_multiplier_list,
        use_timeaware_list, use_same_frame_list, use_separate_edge_model_list, use_initial_node_model_list,
        edge_mlps_count_list,
        node_aggr_sections_list,
    ))

    return param_combos

def build_params_dict_gnn(initial_edge_model_input_dim, edge_dim, fc_dims_initial_edge_model_multipliers, nonlinearity_initial_edge,
                       fc_dims_initial_node_model_multipliers, nonlinearity_initial_node, 
                       directed_flow_agg, fc_dims_directed_flow_attention_model_multipliers,
                       fc_dims_edge_model_multipliers, nonlinearity_edge,
                       fc_dims_directed_flow_model_multipliers, nonlinearity_directed_flow, 
                       fc_dims_total_flow_model_multipliers, nonlinearity_total_flow,
                       fc_dims_edge_classification_model_multipliers, nonlinearity_edge_classification,
                       use_batchnorm: bool,
                       mpn_steps: int, is_recurrent: bool, node_dim_multiplier: int,
                       use_timeaware: bool, use_same_frame: bool, use_separate_edge_model: bool, use_initial_node_model: bool,
                       edge_mlps_count: int,
                       node_aggr_sections: int,
                       **kwargs,
                       ):
    """Pack positional hyper-parameters into a name -> value dictionary.

    Each named argument is stored under a key equal to its own name, in
    declaration order; any extra keyword arguments are added afterwards.

    :return: dict holding the 26 named settings followed by ``kwargs``.
    """
    # workaround before adding sacred
    named_settings = dict(
        initial_edge_model_input_dim=initial_edge_model_input_dim,
        edge_dim=edge_dim,
        # initial edge model
        fc_dims_initial_edge_model_multipliers=fc_dims_initial_edge_model_multipliers,
        nonlinearity_initial_edge=nonlinearity_initial_edge,
        # initial node model
        fc_dims_initial_node_model_multipliers=fc_dims_initial_node_model_multipliers,
        nonlinearity_initial_node=nonlinearity_initial_node,
        directed_flow_agg=directed_flow_agg,
        fc_dims_directed_flow_attention_model_multipliers=fc_dims_directed_flow_attention_model_multipliers,
        # edge model
        fc_dims_edge_model_multipliers=fc_dims_edge_model_multipliers,
        nonlinearity_edge=nonlinearity_edge,
        # directed flow model
        fc_dims_directed_flow_model_multipliers=fc_dims_directed_flow_model_multipliers,
        nonlinearity_directed_flow=nonlinearity_directed_flow,
        # total flow model
        fc_dims_total_flow_model_multipliers=fc_dims_total_flow_model_multipliers,
        nonlinearity_total_flow=nonlinearity_total_flow,
        # edge classification model
        fc_dims_edge_classification_model_multipliers=fc_dims_edge_classification_model_multipliers,
        nonlinearity_edge_classification=nonlinearity_edge_classification,
        use_batchnorm=use_batchnorm,
        # message passing set-up
        mpn_steps=mpn_steps,
        is_recurrent=is_recurrent,
        node_dim_multiplier=node_dim_multiplier,
        # architecture switches
        use_timeaware=use_timeaware,
        use_same_frame=use_same_frame,
        use_separate_edge_model=use_separate_edge_model,
        use_initial_node_model=use_initial_node_model,
        edge_mlps_count=edge_mlps_count,
        node_aggr_sections=node_aggr_sections,
    )
    # Extra keyword arguments land after the named settings.
    return {**named_settings, **kwargs}

def build_models_gnn(params: Mapping[str, Any]):
    """Instantiate every sub-model of the GNN edge classifier from a params mapping.

    Required keys (read with ``params[...]``): ``use_batchnorm``, ``edge_dim``,
    ``mpn_steps``, ``is_recurrent``, ``initial_edge_model_input_dim``, the
    ``fc_dims_*_multipliers`` entries and the ``nonlinearity_*`` entries used below,
    and ``directed_flow_agg`` (only read when a node model is built).
    Optional keys (read with ``params.get``): ``node_dim_multiplier`` (2),
    ``use_timeaware`` (True), ``use_same_frame`` (False), ``edge_mlps_count`` (3),
    ``node_aggr_sections`` (3), ``use_separate_edge_model`` (False),
    ``use_initial_node_model`` (True).

    ``MLP``, ``dims_from_multipliers`` and the ``*NodeModel`` / ``*EdgeModel`` /
    ``MessagePassingNetwork*`` classes are not defined in this notebook
    (presumably provided by the star imports - verify).

    :param params: hyper-parameter mapping, e.g. the output of ``build_params_dict_gnn``.
    :return: tuple ``(initial_edge_model, initial_same_frame_edge_model,
        initial_node_model, mpn_model, edge_classifier)``;
        ``initial_same_frame_edge_model`` is ``None`` unless a separate same-frame
        edge model is requested.
    :raises AssertionError: if ``edge_mlps_count`` or ``node_aggr_sections`` is outside
        1..3, if ``mpn_steps`` is not greater than 1, or if ``use_same_frame`` is set
        while ``use_timeaware`` is off and an initial node model is requested.
    """
    use_batchnorm = params["use_batchnorm"]

    edge_dim = params["edge_dim"]
    node_dim_multiplier = params.get("node_dim_multiplier", 2)
    node_dim = edge_dim * node_dim_multiplier  # Have nodes hold 2x info of edges
    use_timeaware = params.get("use_timeaware", True)
    use_same_frame = params.get("use_same_frame", False)
    # separate backward/forward/sameframe MLPs or inter/intraframe or single MLP for all
    edge_mlps_count = params.get("edge_mlps_count", 3)
    assert edge_mlps_count > 0 and edge_mlps_count <= 3, f"edge_mlps_count must be 1/2/3, not {edge_mlps_count}"
    node_aggr_sections = params.get("node_aggr_sections", 3)
    assert node_aggr_sections > 0 and node_aggr_sections <= 3, f"node_aggr_sections must be 1/2/3, not {node_aggr_sections}"
    # only makes sense when using intraframe
    use_separate_edge_model = use_same_frame and params.get("use_separate_edge_model", False) 
    use_initial_node_model = params.get("use_initial_node_model", True)

    # Edge classification model
    # Built first because it doubles as the attention model when the aggregation
    # mode contains both "attention" and "classifier" (see below).
    fc_dims_edge_classification_model_multipliers = params["fc_dims_edge_classification_model_multipliers"]
    if fc_dims_edge_classification_model_multipliers is not None:
        fc_dims_edge_classification_model = dims_from_multipliers(
            edge_dim, fc_dims_edge_classification_model_multipliers) + (1,)
    else:
        fc_dims_edge_classification_model = (1,)
    edge_classifier = MLP(edge_dim, fc_dims_edge_classification_model,
                          params["nonlinearity_edge_classification"], last_output_free=True)

    # Initial edge model:
    fc_dims_initial_edge = dims_from_multipliers(
        edge_dim, params["fc_dims_initial_edge_model_multipliers"])
    initial_edge_model = MLP(params["initial_edge_model_input_dim"], fc_dims_initial_edge,
                            params["nonlinearity_initial_edge"], use_batchnorm=use_batchnorm)
    if use_separate_edge_model:
        initial_same_frame_edge_model = MLP(params["initial_edge_model_input_dim"], fc_dims_initial_edge,
                                            params["nonlinearity_initial_edge"], use_batchnorm=use_batchnorm)
    else:
        initial_same_frame_edge_model = None

    # Initial node model
    if use_initial_node_model:
        initial_node_agg_mode = params["directed_flow_agg"]
        if "attention" in initial_node_agg_mode:
            if "classifier" in initial_node_agg_mode:
                initial_node_attention_model = edge_classifier
            else:
                fc_dims_directed_flow_attention_model_multipliers = params["fc_dims_directed_flow_attention_model_multipliers"]
                if fc_dims_directed_flow_attention_model_multipliers is not None:
                    fc_dims_initial_node_attention = dims_from_multipliers(
                        edge_dim, fc_dims_directed_flow_attention_model_multipliers) + (1,)
                else:
                    fc_dims_initial_node_attention = (1,)
                initial_node_attention_model = MLP(edge_dim, fc_dims_initial_node_attention,
                                                params["nonlinearity_initial_node"], last_output_free=True)
        else:
            initial_node_attention_model = None

        fc_dims_initial_node = dims_from_multipliers(
            node_dim, params["fc_dims_initial_node_model_multipliers"])
        # NOTE: only the contextual (same-frame) variant receives the attention model;
        # the other two initial node models are built without it.
        if use_timeaware:
            if use_same_frame:
                initial_node_model = InitialContextualNodeModel(MLP(edge_dim * 3, fc_dims_initial_node,
                                                                params["nonlinearity_initial_node"], use_batchnorm=use_batchnorm),
                                                                initial_node_agg_mode, initial_node_attention_model)
            else:
                initial_node_model = InitialTimeAwareNodeModel(MLP(edge_dim * 2, fc_dims_initial_node,  # x2 for [forward|backward] edge features
                                                                   params["nonlinearity_initial_node"], use_batchnorm=use_batchnorm),
                                                               initial_node_agg_mode)
        else:
            assert not use_same_frame
            initial_node_model = InitialUniformAggNodeModel(MLP(edge_dim, fc_dims_initial_node,
                                                                params["nonlinearity_initial_node"], use_batchnorm=use_batchnorm),
                                                            initial_node_agg_mode)
    else:  # initial nodes are zero vectors
        initial_node_model = InitialZeroNodeModel(node_dim)

    # Define models in MPN
    edge_models, node_models = [], []
    steps = params["mpn_steps"]
    assert steps > 1, "Fewer than 2 MPN steps does not make sense as in that case nodes do not get a chance to update"
    is_recurrent = params["is_recurrent"]
    # Non-recurrent: one edge model per step and one node model per step except the last.
    # Recurrent: the loop exits after the first iteration, leaving exactly one of each.
    for step in range(steps):
        # Edge model
        edge_model_input = node_dim * 2 + edge_dim  # edge_dim * 5
        fc_dims_edge = dims_from_multipliers(
            edge_dim, params["fc_dims_edge_model_multipliers"])
        edge_models.append(BasicEdgeModel(MLP(edge_model_input, fc_dims_edge,
                                              params["nonlinearity_edge"], use_batchnorm=use_batchnorm)))

        if step == steps - 1: # don't need a node update at the last step
            continue

        # Node model
        flow_model_input = node_dim * 2 + edge_dim  # two nodes and their edge
        fc_dims_directed_flow = dims_from_multipliers(
            node_dim, params["fc_dims_directed_flow_model_multipliers"])
        fc_dims_aggregated_flow = dims_from_multipliers(
            node_dim, params["fc_dims_total_flow_model_multipliers"])
        
        node_agg_mode = params["directed_flow_agg"]
        if "attention" in node_agg_mode:
            if "classifier" in node_agg_mode:
                attention_model = edge_classifier
            else:
                fc_dims_directed_flow_attention_model_multipliers = params["fc_dims_directed_flow_attention_model_multipliers"]
                if fc_dims_directed_flow_attention_model_multipliers is not None:
                    fc_dims_directed_flow_attention = dims_from_multipliers(
                        node_dim, fc_dims_directed_flow_attention_model_multipliers) + (1,)
                else:
                    fc_dims_directed_flow_attention = (1,)
                attention_model = MLP(node_dim, fc_dims_directed_flow_attention,
                                    params["nonlinearity_directed_flow"], last_output_free=True)
        else:
            attention_model = None

        if use_timeaware:
            forward_flow_model = MLP(flow_model_input, fc_dims_directed_flow,
                                     params["nonlinearity_directed_flow"], use_batchnorm=use_batchnorm)
            # With fewer than 3 edge MLPs the backward direction shares the forward MLP object.
            if edge_mlps_count < 3:
                backward_flow_model = forward_flow_model
            else:
                backward_flow_model = MLP(flow_model_input, fc_dims_directed_flow,
                                        params["nonlinearity_directed_flow"], use_batchnorm=use_batchnorm)
            if use_same_frame:
                # A single edge MLP is shared by all three directions; otherwise
                # same-frame edges get their own MLP.
                if edge_mlps_count == 1:
                    frame_flow_model = forward_flow_model
                else:
                    frame_flow_model = MLP(flow_model_input, fc_dims_directed_flow,
                                        params["nonlinearity_directed_flow"], use_batchnorm=use_batchnorm)
                aggregated_flow_model = MLP(node_dim * 3, fc_dims_aggregated_flow,
                                            params["nonlinearity_total_flow"], use_batchnorm=use_batchnorm)
                node_models.append(ContextualNodeModel(
                    forward_flow_model, frame_flow_model, backward_flow_model, aggregated_flow_model, node_agg_mode, attention_model, node_aggr_sections=node_aggr_sections))

            else:
                aggregated_flow_model = MLP(node_dim * 2, fc_dims_aggregated_flow,
                                            params["nonlinearity_total_flow"], use_batchnorm=use_batchnorm)
                node_models.append(TimeAwareNodeModel(
                    forward_flow_model, backward_flow_model, aggregated_flow_model, node_agg_mode))
        else:
            individual_flow_model = MLP(flow_model_input, fc_dims_directed_flow,
                                        params["nonlinearity_directed_flow"], use_batchnorm=use_batchnorm)
            aggregated_flow_model = MLP(node_dim, fc_dims_aggregated_flow,
                                        params["nonlinearity_total_flow"], use_batchnorm=use_batchnorm)
            node_models.append(UniformAggNodeModel(individual_flow_model,
                               aggregated_flow_model, node_agg_mode))

        if is_recurrent:  # only one model to use at each step
            break

    if is_recurrent:
        assert len(edge_models) == len(node_models) == 1
        if use_separate_edge_model:
            # edge_model_input and fc_dims_edge are the values left over from the loop above.
            same_frame_edge_model = BasicEdgeModel(MLP(edge_model_input, fc_dims_edge, params["nonlinearity_edge"],
                                                    use_batchnorm=use_batchnorm))
        else:
            same_frame_edge_model = None

        if use_initial_node_model:
            mpn_model = MessagePassingNetworkRecurrent(edge_models[0], node_models[0], steps,
                                                    use_same_frame, same_frame_edge_model)
        else:  # use a node-to-edge MPN
            mpn_model = MessagePassingNetworkRecurrentNodeEdge(edge_models[0], node_models[0], steps,
                                                               use_same_frame, same_frame_edge_model)
    else:
        mpn_model = MessagePassingNetworkNonRecurrent(edge_models, node_models, steps, use_same_frame)

    return initial_edge_model, initial_same_frame_edge_model, initial_node_model, mpn_model, edge_classifier

In [None]:
class GraphClassifierGNN(torch.nn.Module):
    """Edge classifier made of initial edge/node encoders, a message passing
    network and a final per-edge classification head."""

    def __init__(self, params: Mapping):
        """Build every sub-model from ``params`` via ``build_models_gnn``.

        :param params: hyper-parameter mapping forwarded unchanged to
            ``build_models_gnn``; the optional key ``use_same_frame``
            (default ``False``) is also read here.
        """
        super().__init__()
        self.params = params

        # TODO (original note): change this because only one edge model is used
        (self.initial_edge_model, self.initial_same_frame_edge_model, self.initial_node_model,
             self.mpn_model, self.edge_classifier) = build_models_gnn(params)

        # Same default as build_models_gnn, so a mapping without this key configures
        # the sub-models and this wrapper consistently instead of raising KeyError.
        self.use_same_frame = self.params.get("use_same_frame", False)
        self.device = torch.device('cpu')

    def forward(self, data):
        """Return the edge classifier output for every edge in ``data.edge_index``.

        :param data: graph object exposing ``edge_index``, ``edge_attr`` and
            ``num_nodes``; ``same_frame_edge_index`` and ``same_frame_edge_attr``
            are read only when ``use_same_frame`` is set.
        :return: output of ``edge_classifier`` applied to the final edge embeddings.
        """
        edge_index, edge_attr, num_nodes = data.edge_index.long(), data.edge_attr, data.num_nodes
        same_frame_edge_index = data.same_frame_edge_index.long() if self.use_same_frame else None
        same_frame_edge_attr = data.same_frame_edge_attr if self.use_same_frame else None

        # Initial edge embeddings with null node embeddings
        edge_attr = self.initial_edge_model(edge_attr)
        if self.use_same_frame:
            # Fall back to the shared edge encoder when no dedicated one was built.
            if self.initial_same_frame_edge_model is not None:
                same_frame_edge_attr = self.initial_same_frame_edge_model(same_frame_edge_attr)
            else:
                same_frame_edge_attr = self.initial_edge_model(same_frame_edge_attr)

        # Initial node embeddings with null original embeddings
        x = self.initial_node_model(edge_index, edge_attr, num_nodes,
                                    same_frame_edge_index=same_frame_edge_index,
                                    same_frame_edge_attr=same_frame_edge_attr,
                                    device=self.device)
        assert len(x) == num_nodes

        x, final_edge_embeddings = self.mpn_model(x, edge_index, edge_attr, num_nodes,
                                                 same_frame_edge_index=same_frame_edge_index,
                                                 same_frame_edge_attr=same_frame_edge_attr)

        return self.edge_classifier(final_edge_embeddings)

    def forward_graph(self, graph, criterion = None):
        """Run the model on ``graph.pyg_graph`` and optionally compute a loss.

        :param graph: wrapper exposing ``pyg_graph`` with an ``edge_label`` attribute.
        :param criterion: optional callable ``criterion(out, true)``.
        :return: ``(out, loss, true)`` where ``out`` is the flattened model output,
            ``loss`` is ``None`` when no criterion is given, and ``true`` is
            ``graph.pyg_graph.edge_label``.
        """
        out = self.forward(graph.pyg_graph).view(-1)
        true = graph.pyg_graph.edge_label
        # Explicit None check: truthiness of a criterion object is not a reliable signal.
        loss = criterion(out, true) if criterion is not None else None
        return out, loss, true

In [None]:
class GraphClassifierMLP(torch.nn.Module):
    """Plain MLP baseline that classifies edges from their feature rows."""

    def __init__(self, dimensions):
        """Stack ``Linear`` layers with ``ReLU`` between them (none after the last).

        :param dimensions: sequence of layer widths ``(input, hidden..., output)``;
            needs at least two entries.
        :raises ValueError: if fewer than two widths are given.
        """
        super().__init__()
        if len(dimensions) < 2:
            raise ValueError(
                f"dimensions must contain at least an input and an output size, got {dimensions!r}")

        layers = []
        last_layer = len(dimensions) - 2
        for position, (width_in, width_out) in enumerate(zip(dimensions[:-1], dimensions[1:])):
            layers.append(torch.nn.Linear(width_in, width_out))
            if position < last_layer:
                layers.append(torch.nn.ReLU())  # You can use other activation functions here

        # torch.nn is used throughout; a bare ``nn`` is not imported by this notebook.
        self.mlp = torch.nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x``."""
        return self.mlp(x)

    def forward_graph(self, graph, criterion = None):
        """Run the MLP on ``graph.edge_x`` and optionally compute a loss.

        :param graph: object exposing NumPy arrays ``edge_x`` (inputs, converted to
            float tensors) and ``edge_y`` (targets, converted without a dtype change).
        :param criterion: optional callable ``criterion(out, true)``.
        :return: ``(out, loss, true)``; ``loss`` is ``None`` when no criterion is given.
        """
        out = self.forward(torch.from_numpy(graph.edge_x).to(torch.float)).view(-1)
        true = torch.from_numpy(graph.edge_y)
        # Explicit None check: truthiness of a criterion object is not a reliable signal.
        loss = criterion(out, true) if criterion is not None else None
        return out, loss, true

In [None]:
def apply_constraints_func(probabilities, data, threshold=0.5, debug = False):
    """Greedily turn edge probabilities into labels where each node is used at most once.

    Edges are visited from highest to lowest probability. An edge is labelled 1
    only if its type is 'nuclei-golgi' or 'golgi-nuclei' (per the global
    ``edges_type_int_encodings``), neither endpoint was already used by an accepted
    edge, and its probability is strictly greater than ``threshold``.

    :param probabilities: per-edge probabilities, indexable and aligned with
        the columns of ``data.edge_index``.
    :param data: graph exposing tensors ``edge_types`` and ``edge_index``.
    :param threshold: minimum (exclusive) probability for accepting an edge.
    :param debug: print one line per edge stating whether it was accepted.
    :return: list of 0/1 labels in the original edge order.
    """
    edge_types = data.edge_types.detach().cpu().numpy()
    source_nodes = data.edge_index[0].detach().cpu().numpy()
    target_nodes = data.edge_index[1].detach().cpu().numpy()

    allowed_edge_types = {edges_type_int_encodings['nuclei-golgi'], edges_type_int_encodings['golgi-nuclei']}

    # Edge positions ordered by descending probability (stable for ties).
    ranked_edges = sorted(range(len(probabilities)), key=lambda k: probabilities[k], reverse=True)

    assigned_nodes = set()
    pred_labels = [0] * len(probabilities)

    if debug:
        print("\n\n\nConstraints ", threshold, "\n\n")

    for edge_pos in ranked_edges:
        src = source_nodes[edge_pos]
        tgt = target_nodes[edge_pos]
        prob = probabilities[edge_pos]

        accepted = bool(
            edge_types[edge_pos] in allowed_edge_types
            and src not in assigned_nodes
            and tgt not in assigned_nodes
            and prob > threshold
        )
        if accepted:
            pred_labels[edge_pos] = 1
            assigned_nodes.add(src)
            assigned_nodes.add(tgt)

        if debug:
            # Report the real outcome; an edge rejected by the threshold is not assigned.
            print("src:",src, "tgt:",tgt, "prob:",prob, "assigned:", accepted)

    return pred_labels

In [None]:
##################################################################################################
## Functions to  train and evaluate neural network 
#################################################################################################
import sklearn.metrics
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import negative_sampling

def pyg_train_link_predictor(
    model, train_data, val_data, optimizer, criterion, n_epochs=100, debug = False,
    early_stopper = None, scheduler = None, apply_constraints = True
):
    """Train ``model`` graph by graph and return it.

    Each epoch shuffles the training graphs and performs one optimizer step per
    graph. The validation set is evaluated on every 10th epoch when ``debug`` is
    set, and on every epoch when an ``early_stopper`` is supplied, because early
    stopping needs a fresh validation loss.

    :param model: object exposing ``train()``, ``forward_graph(graph, criterion=...)``.
    :param train_data: sequence of training graphs (left in its original order).
    :param val_data: validation graphs passed to ``pyg_eval_link_predictor``.
    :param optimizer: optimizer exposing ``zero_grad()`` and ``step()``.
    :param criterion: loss callable forwarded to ``forward_graph``.
    :param n_epochs: maximum number of epochs.
    :param debug: print training loss and validation metrics every 10 epochs.
    :param early_stopper: optional object with ``early_stop(validation_loss) -> bool``.
    :param scheduler: optional learning-rate scheduler stepped once per epoch.
    :param apply_constraints: forwarded to ``pyg_eval_link_predictor``.
    :return: the trained ``model`` (same object).
    """
    # Shuffle a private copy so the caller's list is not reordered as a side effect.
    train_graphs = list(train_data)

    for epoch in range(1, n_epochs + 1):
        model.train()
        random.shuffle(train_graphs)
        for graph in train_graphs:
            optimizer.zero_grad()
            out, loss, true = model.forward_graph(graph, criterion = criterion)
            loss.backward()
            optimizer.step()

        report_epoch = debug and epoch % 10 == 0
        metrics = None
        if report_epoch or early_stopper is not None:
            # Evaluate once per epoch and share the result between reporting and early stopping.
            metrics = pyg_eval_link_predictor(model, val_data, criterion = criterion, apply_constraints = apply_constraints)

        if report_epoch:
            print(f"Epoch: {epoch:03d}, Train Loss: {loss:.3f}, Metrics:",metrics)

        if early_stopper is not None:
            # The validation loss lives under "aggregated_metrics" in the evaluation result.
            if early_stopper.early_stop(metrics["aggregated_metrics"]["loss"]):
                break

        if scheduler is not None:
            scheduler.step()

    return model

@torch.no_grad()
def pyg_aggregate_metrics_all(metrics_list, loss_criterion=None):
    """Combine per-graph metric dictionaries into one summary dictionary.

    The summary holds the mean of ``"rouc_auc_score"``, the mean ``"loss"`` (only
    when ``loss_criterion`` is truthy) and, for ``"@best"`` and ``"@0.5"``, the
    result of ``aggregate_metrics`` over the per-graph ``"metrics"`` entries. When
    the first graph carries an ``"@constraints"`` entry for a key, those metrics
    are aggregated as well.

    :param metrics_list: non-empty list of per-graph metric dictionaries.
    :param loss_criterion: loss used during evaluation, or ``None``.
    :return: dictionary of aggregated metrics.
    """
    auc_values = [graph_metrics["rouc_auc_score"] for graph_metrics in metrics_list]
    summary = {"rouc_auc_score": statistics.mean(auc_values)}

    # Aggregate loss if provided
    if loss_criterion:
        stacked_losses = torch.stack([graph_metrics["loss"] for graph_metrics in metrics_list])
        summary["loss"] = torch.mean(stacked_losses, dim=0)

    # Aggregate the threshold-specific metrics
    for threshold_key in ("@best", "@0.5"):
        per_threshold = {
            "metrics": aggregate_metrics(
                [graph_metrics[threshold_key]["metrics"] for graph_metrics in metrics_list]),
        }
        summary[threshold_key] = per_threshold

        # The first graph decides whether constrained metrics are present.
        if "@constraints" in metrics_list[0][threshold_key]:
            constrained = [graph_metrics[threshold_key]["@constraints"]["metrics"]
                           for graph_metrics in metrics_list]
            per_threshold["@constraints"] = {"metrics": aggregate_metrics(constrained)}

    return summary

@torch.no_grad()
def pyg_eval_link_predictor(model, data, criterion = None, plot_roc_curve=False, debug = False, 
                                 apply_constraints=True):
    """Evaluate ``model`` on every graph in ``data`` and collect the metrics.

    For each graph the sigmoid of the model output is scored at a fixed 0.5
    threshold (``"@0.5"``) and at the ROC threshold maximising
    sensitivity + specificity - 1 (``"@best"``). With ``apply_constraints`` the
    predictions are additionally passed through ``apply_constraints_func`` at
    thresholds 0.5, the optimal threshold, and 0. Each graph's metrics are also
    stored on the graph as ``.metrics``.

    ``plot_roc_curve`` and ``debug`` are accepted but not used in this function.

    :return: ``{"individual_metrics": {graph_id: metrics}, "aggregated_metrics": ...}``.
    """
    model.eval()

    # computed metrics -> "acc", "precision", "recall", "tp", "fp", "tn", "fn"
    per_graph_metrics = {}  # graph_id -> metrics of that graph

    for position in range(len(data)):
        graph = data[position]
        graph_id = graph.graph_id
        tp_total_count = len(graph.edge_list)

        graph_metrics = {}
        logits, loss, true = model.forward_graph(graph, criterion = criterion)
        if criterion is not None:
            graph_metrics["loss"] = loss

        pred = logits.sigmoid().cpu().numpy()

        # AUC is undefined when predictions or labels take a single value.
        single_valued = len(np.unique(pred)) == 1 or len(np.unique(true)) == 1
        graph_metrics["rouc_auc_score"] = 0 if single_valued else round(roc_auc_score(true, pred), 3)

        fpr, tpr, thresholds = sklearn.metrics.roc_curve(true, pred)

        # 0.5 threshold
        labels_at_half = (pred > 0.5).astype(int)
        graph_metrics["@0.5"] = {
            "pred_labels": labels_at_half,  # kept to plot the predicted graph
            "metrics": eval_metrics(true, labels_at_half, tp_total_count),
        }

        # Threshold maximising sensitivity + specificity - 1
        youden_index = tpr + (1 - fpr) - 1
        optimal_threshold = thresholds[np.argmax(youden_index)]

        labels_at_best = (pred > optimal_threshold).astype(int)
        graph_metrics["@best"] = {
            "metrics": eval_metrics(true, labels_at_best, tp_total_count),
            "pred_labels": labels_at_best,
            "optimal_threshold": optimal_threshold,
        }
        graph_metrics["figures"] = {}

        if apply_constraints:
            constrained_half = apply_constraints_func(pred, graph.pyg_graph, threshold=0.5)
            graph_metrics["@0.5"]["@constraints"] = {
                "pred_labels": constrained_half,
                "metrics": eval_metrics(true, constrained_half, tp_total_count),
            }

            constrained_best = apply_constraints_func(pred, graph.pyg_graph, threshold=optimal_threshold)
            graph_metrics["@best"]["@constraints"] = {
                "metrics": eval_metrics(true, constrained_best, tp_total_count),
                "pred_labels": constrained_best,
            }

            # Constraints only, without any probability cut-off
            constrained_only = apply_constraints_func(pred, graph.pyg_graph, threshold=0)
            graph_metrics["@constraints"] = {
                "metrics": eval_metrics(true, constrained_only, tp_total_count),
                "pred_labels": constrained_only,
            }

        graph_metrics["pred_edge_probabilities"] = pred
        per_graph_metrics[graph_id] = graph_metrics
        graph.metrics = graph_metrics

    return {
        "individual_metrics": per_graph_metrics,
        "aggregated_metrics": pyg_aggregate_metrics_all(list(per_graph_metrics.values()),
                                                        loss_criterion = criterion),
    }

In [None]:
def build_model(model_type, dataset_num_node_features, dataset_num_edge_features, dataset_num_total_features,
               dataset_num_classes):
    """Instantiate the model selected by ``model_type``.

    - ``"GNN_Classifier"``: ``GraphClassifierGNN`` configured with the first
      combination returned by ``build_param_combos_gnn``.
    - ``"MLP"``: ``GraphClassifierMLP`` with two hidden layers of width 100.
    - anything else: forwarded to ``GraphClassifierPyg`` (not defined in this
      notebook; presumably provided by the star imports - verify).

    :raises ValueError: if ``GraphClassifierPyg`` cannot be built for ``model_type``;
        the original error is attached as ``__cause__``.
    """
    if model_type == "GNN_Classifier":
        param_combos = build_param_combos_gnn(dataset_num_edge_features)
        params_gnn = build_params_dict_gnn(*param_combos[0])
        return GraphClassifierGNN(params_gnn)

    if model_type == "MLP":
        model_dims = (dataset_num_total_features, 100, 100, dataset_num_classes)
        return GraphClassifierMLP(model_dims)

    try:
        return GraphClassifierPyg(model_type,
                dataset_num_node_features, 100, 100, dataset_num_classes,
                dataset_num_edge_features,
                decode_type = "lin")
    except Exception as error:
        # ``Exception`` rather than a bare ``except`` so KeyboardInterrupt/SystemExit
        # still propagate; chaining keeps the real failure visible in the traceback.
        raise ValueError("Wrong model type!") from error
    
def schedule_training_GNN(job_parameters, model, graph_list_train, graph_list_val , debug = False):   
    """Configure loss and optimizer from ``job_parameters`` and train ``model``.

    The criterion is looked up by name (only ``"BCEWithLogitsLoss"`` is known).
    When ``job_parameters["pos_weight"]`` is truthy the positive-class weight is
    ``max(mean(k_inter over training graphs) + 2 * knn_intra_nodes - 1, 1)``.
    Training uses Adam with learning rate ``job_parameters["lr"]`` and is delegated
    to ``pyg_train_link_predictor``.

    :return: the model returned by ``pyg_train_link_predictor``.
    """
    # These entries are read but not used further in this function; the lookups
    # are kept so a job dictionary missing any of them still raises KeyError.
    _num_classes = job_parameters["num_classes"]
    _num_node_features = job_parameters["num_node_features"]
    _num_edge_features = job_parameters["num_edge_features"]
    _knn_inter = job_parameters["knn_inter_nodes"]
    knn_intra = job_parameters["knn_intra_nodes"]

    learning_rate = job_parameters["lr"]
    epoch_budget = job_parameters["n_epochs"]
    _device = job_parameters["device"]

    loss_factories = {"BCEWithLogitsLoss":torch.nn.BCEWithLogitsLoss}
    make_criterion = loss_factories[job_parameters["criterion"]]

    if not job_parameters["pos_weight"]:
        criterion = make_criterion()
    else:
        mean_k_inter = statistics.mean([g.k_inter for g in graph_list_train])
        positive_weight = max((mean_k_inter+knn_intra*2-1)/1,1)
        criterion = make_criterion(pos_weight = torch.tensor([positive_weight]))

    adam = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

    return pyg_train_link_predictor(
        model, graph_list_train, graph_list_val, adam, criterion,
        n_epochs=epoch_budget,
        debug = debug,
        early_stopper = job_parameters["early_stopper"],
        scheduler = job_parameters["scheduler"],
    )

In [None]:
def get_cv_groups(data_type):
    """Return the cross-validation grouping for a data type.

    "Real" and "Real_automatic" use three fixed groups of crop CSV files;
    every other data type gets the string "even".
    """
    if data_type not in ("Real", "Real_automatic"):
        return "even"

    return [
        ["Crop1.csv", "Crop2.csv", "Crop3.csv", "Crop4.csv"],
        ["Crop5_BC.csv", "Crop6_BC.csv"],
        ["Crop7_BC.csv", "Crop8_BC.csv"],
    ]

def get_job_params_dl(debug = False):
    """Expand the hyper-parameter grid into a list of job dictionaries.

    Every key of the grid maps to the list of values to try; one job is
    produced per element of the Cartesian product, in grid order. Each job
    additionally receives "scale_features", "index_train", "index_test" and
    the train/test cross-validation groups from ``get_cv_groups``.

    When ``debug`` is truthy the job count and the JSON dump of all jobs are
    printed.
    """
    all_edge_feats = [
        'delta_x',
        'delta_y',
        'delta_z',
        'weight',
        'angle_orientation_theta',
        'angle_orientation_phi',
    ]
    geometric_edge_feats = ['delta_x', 'delta_y', 'delta_z', 'weight']

    search_grid = {
        "data_type_train": ["Real"],
        "data_type_test": ["Real"],  # other option: "Real_automatic"
        "model_type": ["GNN_Classifier"],  # other option: "MLP"
        "knn_inter_nodes": [10],  # other options: 7, "min"
        "knn_inter_nodes_max": [7],
        "knn_intra_nodes": [0],
        "normalize": [True],
        # other option: ['Y', 'X', 'Z', 'node_type']
        "node_feats": [['node_type']],
        "edge_feats": [all_edge_feats, geometric_edge_feats],
        "to_undirected": [False],
        "lr": [1e-3],
        "n_epochs": [100],
        "early_stopper": [None],
        "scheduler": [None],
        "pos_weight": [True],
        "criterion": ["BCEWithLogitsLoss"],
        "device": ["cpu"],
    }

    grid_keys = list(search_grid)
    jobs = []

    # One job per element of the Cartesian product of all value lists.
    for combo in itertools.product(*(search_grid[key] for key in grid_keys)):
        job = dict(zip(grid_keys, combo))
        job.update({
            "scale_features": "Real" in job["data_type_train"],
            "index_train": "all",
            "index_test": "all",
            "cross_validation_groups_train": get_cv_groups(job["data_type_train"]),
            "cross_validation_groups_test": get_cv_groups(job["data_type_test"]),
        })
        jobs.append(job)

    if debug:
        print("Total Number of jobs is:",len(jobs))
        print(json.dumps(jobs))
    return jobs

In [None]:
def get_graph_list_dl(jobs, debug = False):
    """Load the graph list for every distinct graph configuration used by ``jobs``.

    For each job, first its training and then its test data type is turned
    into a string key (data type, KNN settings, normalisation/scaling flags,
    node and edge feature lists joined with "_"). ``get_graph_list`` is called
    only for keys not seen before. Returns the dict mapping key -> graph list.

    ``debug`` is accepted but not used.
    """
    graphs_by_key = {}

    for params in tqdm(jobs):
        for data_type_field in ("data_type_train", "data_type_test"):
            data_type = params[data_type_field]
            key_parts = (
                data_type,
                params["knn_inter_nodes"],
                params["knn_intra_nodes"],
                params["knn_inter_nodes_max"],
                params["normalize"],
                params["scale_features"],
                str(params["node_feats"]),
                str(params["edge_feats"]),
            )
            graph_key = "_".join(map(str, key_parts))

            if graph_key in graphs_by_key:
                continue  # this configuration is already loaded

            graphs_by_key[graph_key] = get_graph_list(
                data_type,
                params["knn_inter_nodes"],
                params["knn_intra_nodes"],
                params["knn_inter_nodes_max"],
                normalize = params["normalize"],
                scale_feats = params["scale_features"],
                node_feats = params["node_feats"],
                edge_feats = params["edge_feats"],
                shuffle = False,
            )

    return graphs_by_key

In [None]:
def _graph_cache_key(job_parameters, data_type):
    """Return the "_"-joined key under which the graphs for ``data_type`` are stored."""
    key_parts = [data_type, job_parameters["knn_inter_nodes"], job_parameters["knn_intra_nodes"],
                 job_parameters["knn_inter_nodes_max"], job_parameters["normalize"],
                 job_parameters["scale_features"],
                 str(job_parameters["node_feats"]), str(job_parameters["edge_feats"])]
    return "_".join(str(part) for part in key_parts)


def train_dl(graph_list_dict_deep_learning, jobs, debug = False):
    """Train and evaluate one model per job (and per cross-validation fold).

    Parameters
    ----------
    graph_list_dict_deep_learning : dict
        Mapping from graph-configuration key to graph list, keyed with the
        same "_"-joined scheme that ``_graph_cache_key`` produces.
    jobs : list of dict
        Job dictionaries. Each job is updated in place with "num_classes",
        "num_node_features", "num_edge_features" and "num_total_features",
        which are read off the first training graph.
    debug : bool
        Forwarded to ``schedule_training_GNN``.

    Returns
    -------
    list of dict
        One entry per job with the keys "cv_results" (per-fold graphs and
        evaluation), "job_parameters" and "aggregated_metrics".
    """
    results_list_pytorch = []

    for job_parameters in tqdm(jobs):

        #get data
        graph_list_train = graph_list_dict_deep_learning[
            _graph_cache_key(job_parameters, job_parameters["data_type_train"])]
        graph_list_test = graph_list_dict_deep_learning[
            _graph_cache_key(job_parameters, job_parameters["data_type_test"])]

        # Feature dimensions are taken from the first training graph.
        reference_graph = graph_list_train[0]
        job_parameters["num_classes"] = 1
        job_parameters["num_node_features"] = (
            reference_graph.pyg_graph.x.shape[1] if reference_graph.pyg_graph.x.shape[0] > 0 else 0)
        job_parameters["num_edge_features"] = reference_graph.pyg_graph.edge_attr.shape[1]
        job_parameters["num_total_features"] = reference_graph.edge_x.shape[1]

        indexes_train = job_parameters["index_train"]
        indexes_test = job_parameters["index_test"]
        cross_validation_groups_train = job_parameters.get("cross_validation_groups_train",[])
        cross_validation_groups_test = job_parameters.get("cross_validation_groups_test",[])

        if(indexes_train=="all"):
            indexes_train = [g.graph_id for g in graph_list_train]
        if(indexes_test=="all"):
            indexes_test = [g.graph_id for g in graph_list_test]
        if(cross_validation_groups_train=="even"):
            number_cross_validation_groups = 3
            cross_validation_groups_train = distribute_elements_to_lists(indexes_train, number_cross_validation_groups)
        if(cross_validation_groups_test=="even"):
            number_cross_validation_groups = 3
            cross_validation_groups_test = distribute_elements_to_lists(indexes_test, number_cross_validation_groups)

        indexes_train = set(indexes_train)
        indexes_test = set(indexes_test)

        cv_dataset_list = []
        if(not cross_validation_groups_train):#without cross-validation
            # Filter each split from its own graph list (the previous code read an
            # undefined name `graph_list` here).
            selected_train = [el for el in graph_list_train if el.graph_id in indexes_train]
            selected_test = [el for el in graph_list_test if el.graph_id in indexes_test]
            cv_dataset_list.append({"train":selected_train, "test":selected_test})

        else:#with cross-validation
            # Fold k tests on test group k and trains on every train group except group k.
            for fold_index, test_group in enumerate(cross_validation_groups_test):
                test_ids = set(test_group)
                graph_list_test_cv = [el for el in graph_list_test if el.graph_id in test_ids]
                graph_list_train_cv = []
                for group_index, train_group in enumerate(cross_validation_groups_train):
                    if(group_index!=fold_index):
                        train_ids = set(train_group)
                        graph_list_train_cv.extend([el for el in graph_list_train if el.graph_id in train_ids])

                cv_dataset_list.append({"train": graph_list_train_cv, "test": graph_list_test_cv})

        #Train Model
        results = {"cv_results":[], "job_parameters":job_parameters, "aggregated_metrics" : None}

        for dataset in cv_dataset_list:

            result = {}
            fold_train, fold_test = dataset["train"], dataset["test"]

            model_type = job_parameters["model_type"]
            model = build_model(model_type, job_parameters["num_node_features"], job_parameters["num_edge_features"],
                                job_parameters["num_total_features"], job_parameters["num_classes"])

            # NOTE(review): the fold's test graphs are also passed as the validation
            # list -- confirm this is intended once an early stopper/scheduler is enabled.
            model = schedule_training_GNN(job_parameters, model,
                                               fold_train, fold_test, debug = debug)

            result["graphs"] = {}
            result["graphs"]["train"] = fold_train
            result["graphs"]["test"] = fold_test

            #Eval Model
            result["eval"] = pyg_eval_link_predictor(model, fold_test,
                                                     criterion = None,  apply_constraints = True,
                                                    plot_roc_curve = False, debug = False)
            results["cv_results"].append(result)


        #aggregate all metrics: collect the per-graph metrics of every fold, keyed by graph

        all_metrics = {}
        for item in results["cv_results"]:
            individual_metrics = item["eval"]["individual_metrics"]
            for individual_graph in individual_metrics:
                all_metrics[individual_graph] = individual_metrics[individual_graph]

        results["aggregated_metrics"]  = pyg_aggregate_metrics_all(list(all_metrics.values()))

        results_list_pytorch.append(results)
    return results_list_pytorch

In [None]:
def plot_results_dl(results_list_pytorch):
    """Display the results table twice and print its LaTeX form.

    The table built by ``plot_table`` is sorted, displayed in full, displayed
    again without the "Data Train"/"Data Test" columns, and that reduced
    table is printed through ``plot_df_to_latex``.
    """
    metric_paths = [["@best","metrics"],["@best","@constraints","metrics"]]
    sort_columns = ["Algorithm", "Normalize", "K Inter", 'Data Train', 'Data Test','Constraints']

    full_table = plot_table(results_list_pytorch, metrics_dict_entries = metric_paths)
    full_table = full_table.sort_values(by=sort_columns)
    display(full_table)

    compact_table = full_table.drop(columns=["Data Train", "Data Test"])
    display(compact_table)
    print(plot_df_to_latex(compact_table))

In [None]:
# Expand the hyper-parameter grid into one dict per job; the trailing
# expression displays how many jobs were generated.
jobs_dl = get_job_params_dl()
len(jobs_dl)

In [None]:
# Load the graphs for every distinct configuration referenced by the jobs.
# Despite its name, graph_list_dl is the dict (key -> graph list) returned by get_graph_list_dl.
graph_list_dl = get_graph_list_dl(jobs_dl, debug = False)

In [None]:
# Train and evaluate every job; results_list_dl holds one results dict per job.
results_list_dl = train_dl(graph_list_dl, jobs_dl, debug = False)

In [None]:
# Raise pandas' row display limit (a global setting that persists for the
# rest of the session) so the results table is shown in full.
pd.set_option('display.max_rows', 500)
plot_results_dl(results_list_dl)

In [None]:
def annote_figure(figax, annotations, dims, fontsize = 7):
    """Write ``annotations`` as a column of text to the right of ``figax``.

    Lines are placed in axes-fraction coordinates, starting at y = 0.9 and
    moving down 0.1 per line. The x position is 1.05 when ``dims`` equals 2
    and 1.2 otherwise.
    """
    column_x = 1.05 if dims == 2 else 1.2
    top_y = 0.9

    row = 0
    for text in annotations:
        position = (column_x, top_y - row * 0.1)
        figax.annotate(text, xy=position, xycoords="axes fraction",
                       fontsize=fontsize, color="black")
        row += 1
    return

def annote_knn_metrics(figax, knn_intra, knn_inter, dims):
    """Annotate ``figax`` with the intra- and inter-class K values (mathtext labels)."""
    labelled_values = (
        (r'$K_{{INTRA\_CLASS}}: {}$', knn_intra),
        (r'$K_{{INTER\_CLASS}}: {}$', knn_inter),
    )
    knn_annotations = [template.format(value) for template, value in labelled_values]

    annote_figure(figax, knn_annotations, dims)
    return

def annote_figure_results_metrics(figax, metrics, threshold, dims):
    """Annotate ``figax`` with the threshold and the classification metrics.

    ``metrics`` must hold a "metrics" entry providing "acc", "precision",
    "recall", "tp", "fp", "tn" and "fn"; ``threshold`` is printed with three
    decimals.
    """
    scores = metrics["metrics"]

    # (label shown in the figure, key inside the metrics dict)
    labelled_keys = (
        ("Acc.", "acc"),
        ("Prec.", "precision"),
        ("Recall", "recall"),
        ("TP", "tp"),
        ("FP", "fp"),
        ("TN", "tn"),
        ("FN", "fn"),
    )

    count_annotations = [f"Threshold: {threshold:.3f}"]
    for label, key in labelled_keys:
        count_annotations.append(f"{label}: {scores[key]}")

    annote_figure(figax, count_annotations, dims)
    return

def _plot_prediction_panel(subplot_index, title, metrics, threshold, graph,
                           node_list, true_labels, plot_styles, dims, subplot_params):
    """Draw one predicted-edges panel in the 2x3 grid and annotate its metrics."""
    figax = plt.subplot(2, 3, subplot_index, **subplot_params)
    pred_labels = metrics["pred_labels"]
    edge_list_pred = GraphInfo.edge_index_to_edge_list(graph.pyg_graph.edge_index)
    # Turn (true label, predicted label) pairs into the labels used as plot_styles keys.
    plot_edge_labels = GraphInfo.convert_edge_preds_to_labels(true_labels, pred_labels)

    GraphInfo.make_graph_plot(node_list, edge_list_pred,
                              plot_edge_labels, plot_styles, dims = dims,
                              title = title, figax = figax)

    annote_figure_results_metrics(figax, metrics, threshold, dims)


def plot_graph_results(graph, graph_metrics, k_intra, k_inter):
    """Show a five-panel 3D figure comparing true, KNN and predicted edges.

    Panels: (a) true graph, (b) KNN candidate graph, (c) prediction at the
    best threshold, (d) best threshold with constraints, (e) constraints
    without threshold. The figure is displayed and then closed.

    Parameters
    ----------
    graph :
        Graph object exposing ``node_list``, ``edge_list``,
        ``pyg_graph_edge_list``, ``pyg_graph_true_labels`` and
        ``pyg_graph.edge_index``.
    graph_metrics : dict
        Per-graph evaluation output; the entries "@best" (with
        "optimal_threshold" and a nested "@constraints") and "@constraints"
        are read.
    k_intra, k_inter :
        KNN settings written next to the KNN panel.
    """
    plot_styles = {
                "nuclei":{"marker":"o","color":"red", "alpha":0.3},
                "golgi":{"marker":"o","color":"green", "alpha":0.3},
                "tp": {"color": "black",  "dashed": False, "alpha":1},
                "fp": {"color": "yellow", "dashed": False, "alpha":1},
                "tn": None,
                "fn": {"color": "red",  "dashed": False, "alpha":1}
    }
    dims = 3
    subplot_params = {"projection":"3d"} if dims==3 else {}

    node_list_true = graph.node_list
    edge_list_true = graph.edge_list
    edge_list_knn = graph.pyg_graph_edge_list
    true_labels = graph.pyg_graph_true_labels

    metrics_best = graph_metrics["@best"]
    optimal_threshold = metrics_best["optimal_threshold"]
    metrics_constraints = graph_metrics["@constraints"]
    metrics_constraints_threshold = graph_metrics["@best"]["@constraints"]

    fig_all = plt.figure(figsize=(18, 9), dpi= 300)#12,9
    plt.subplots_adjust(wspace=0.4, hspace=0.15)#wspace-> horizontal hspace-> vertical

    #True Graph: every edge is drawn with the "tp" style
    figax = plt.subplot(2, 3, 1, **subplot_params)
    plot_edge_labels = ["tp"]*len(edge_list_true)
    GraphInfo.make_graph_plot(node_list_true, edge_list_true,
                              plot_edge_labels, plot_styles, dims = dims,
                              title = "(a) True", figax = figax)

    #KNN Graph: every candidate edge is drawn with the "tp" style
    figax = plt.subplot(2, 3 , 2 , **subplot_params)
    plot_edge_labels = ["tp"]*len(edge_list_knn)
    knn_graph_fig = GraphInfo.make_graph_plot(node_list_true, edge_list_knn,
                                              plot_edge_labels, plot_styles, dims = dims,
                                              title = "(b) KNN", figax = figax)
    # The KNN annotation goes on the object returned by make_graph_plot (as before).
    annote_knn_metrics(knn_graph_fig, k_intra, k_inter, dims)

    # Prediction panels: (subplot index, title, metrics, threshold shown in the annotation)
    prediction_panels = [
        (3, "(c) Pred. W/ Best Threshold", metrics_best, optimal_threshold),
        (4, "(d) Pred. W/ Best Threshold W/ Constraints", metrics_constraints_threshold, optimal_threshold),
        (5, "(e) Pred. W/ Constraints W/o Threshold", metrics_constraints, 0),
    ]
    for subplot_index, title, metrics, threshold in prediction_panels:
        _plot_prediction_panel(subplot_index, title, metrics, threshold, graph,
                               node_list_true, true_labels, plot_styles, dims, subplot_params)

    ########################
    # Fig All legends ######
    ########################

    # Create custom legend handles and labels based on plot_styles
    legend_handles = []
    legend_labels = []

    #Add Edge Legends (edge styles are the dict entries carrying a "dashed" key)
    for label, style in plot_styles.items():
        if isinstance(style, dict) and "dashed" in style:
            color = style["color"]
            dashed = style.get("dashed", False)
            alpha = style["alpha"]
            linestyle = "--" if dashed else "-"
            legend_handles.append(matplotlib.lines.Line2D([0], [0], color=color, linewidth=2, linestyle=linestyle, label=label, alpha = alpha))
            legend_labels.append(label.upper() + " Edge")

    # Add node legends
    legend_handles.append(matplotlib.lines.Line2D([0], [0], marker=plot_styles["nuclei"]["marker"], color="w", alpha = plot_styles["nuclei"]["alpha"],
                                    label="Nuclei", markerfacecolor=plot_styles["nuclei"]["color"], markersize=10))
    legend_labels.append("Nuclei")
    legend_handles.append(matplotlib.lines.Line2D([0], [0], marker=plot_styles["golgi"]["marker"], color="w", alpha = plot_styles["golgi"]["alpha"],
                                    label="Golgi", markerfacecolor=plot_styles["golgi"]["color"], markersize=10))
    legend_labels.append("Golgi")

    legend = fig_all.legend(legend_handles, legend_labels, loc="upper right", fontsize="small", ncol=2)

    legend.set_bbox_to_anchor((0.75, 0.15))
    plt.show()
    # Close this figure explicitly: plt.clf() only clears the *current* figure
    # (creating a new empty one if none is open), so figures would accumulate
    # when this function is called once per graph.
    plt.close(fig_all)

    return

def describe_results(results_list_pytorch, make_plot = False):
    """Print a summary of the first results entry and optionally plot each test graph.

    Only ``results_list_pytorch[0]`` is inspected. For every cross-validation
    fold the train/test graph ids are printed; with ``make_plot`` each test
    graph id is printed and the graph is drawn with ``plot_graph_results``.

    Raises ``ValueError`` when a test graph id does not occur exactly once in
    its fold.
    """
    first_entry = results_list_pytorch[0]
    job_parameters = first_entry["job_parameters"]

    k_intra = job_parameters["knn_intra_nodes"]
    k_inter = job_parameters["knn_inter_nodes"]

    print("K_Intra", k_intra, "K_Inter", k_inter)

    print("Aggregated Metrics", json.dumps(first_entry["aggregated_metrics"], indent = 1, cls = CustomEncoder))

    separator = "##############################################################\n"

    for cv_crop_index, cv_crop in enumerate(first_entry["cv_results"]):
        fold_graphs = cv_crop["graphs"]
        graphs_list_test = fold_graphs["test"]
        graphs_list_train = fold_graphs["train"]
        graph_individual_metrics = cv_crop["eval"]["individual_metrics"]
        graphs_indexes = {
            "train": [g.graph_id for g in graphs_list_train],
            "test": [g.graph_id for g in graphs_list_test],
        }

        print(separator)
        print("CV Crop Index",cv_crop_index, "", graphs_indexes)

        for graph_test in graphs_list_test:
            graph_id = graph_test.graph_id
            # Each id must identify exactly one graph within the fold.
            occurrences = sum(1 for other in graphs_list_test if other.graph_id == graph_id)
            if occurrences != 1:
                raise ValueError("Multiple graphs with same ID!")
            graph_metrics = graph_individual_metrics[graph_id]

            if not make_plot:
                continue
            print(graph_id)
            plot_graph_results(graph_test, graph_metrics, k_intra, k_inter)
        print(separator)

# Summarise and plot the results; describe_results only inspects the first
# entry of results_list_dl.
describe_results(results_list_dl, make_plot = True)

In [None]:
def save_results(output_folder, results_list_pytorch):
    """Write the job parameters and constrained predictions of every results entry to disk.

    For the entry at position ``k`` a folder ``<output_folder>/Results_<k>`` is
    created containing ``params.json`` (the job parameters) and, for every
    test graph of every cross-validation fold, one file named after the
    graph's ``graph_id`` holding the "@constraints" predictions.

    Parameters
    ----------
    output_folder : str
        Root folder; created when missing.
    results_list_pytorch : list of dict
        Entries with the keys "job_parameters", "aggregated_metrics" and
        "cv_results", as read below.

    Raises
    ------
    ValueError
        If a test graph id does not occur exactly once within its fold.
    """
    os.makedirs(output_folder, exist_ok=True)

    #Save results to file
    for results_count, results_entry in enumerate(results_list_pytorch):
        job_parameters = results_entry["job_parameters"]

        results_entry_output_folder = os.path.join(output_folder, "Results_"+str(results_count))
        os.makedirs(results_entry_output_folder, exist_ok=True)

        desc_file_path = os.path.join(results_entry_output_folder, "params.json")
        with open(desc_file_path, 'w') as f:
            # default=str: values json cannot encode (e.g. an early-stopper or
            # scheduler object) are written as their string form instead of
            # aborting the save.
            json.dump(job_parameters, f, indent = 2, default = str)

        k_intra = job_parameters["knn_intra_nodes"]
        k_inter = job_parameters["knn_inter_nodes"]

        print("K_Intra", k_intra, "K_Inter", k_inter)

        print("Aggregated Metrics", json.dumps(results_entry["aggregated_metrics"], indent = 1, cls = CustomEncoder))

        for cv_crop_index, cv_crop in enumerate(results_entry["cv_results"]):
            graphs_list_test = cv_crop["graphs"]["test"]
            graphs_list_train = cv_crop["graphs"]["train"]
            graph_individual_metrics = cv_crop["eval"]["individual_metrics"]
            graphs_indexes = {"train":[g.graph_id for g in graphs_list_train], "test":[g.graph_id for g in graphs_list_test]}

            print("##############################################################\n")
            print("CV Crop Index",cv_crop_index, "", graphs_indexes)

            for graph_test in graphs_list_test:
                graph_id = graph_test.graph_id
                # Each id must identify exactly one graph within the fold.
                if graphs_indexes["test"].count(graph_id) != 1:
                    raise ValueError("Multiple graphs with same ID!")
                graph_metrics = graph_individual_metrics[graph_id]

                edge_list = GraphInfo.edge_index_to_edge_list(graph_test.pyg_graph.edge_index)
                edge_df = GraphInfo.edge_list_to_edge_df(edge_list)

                # Attach the constrained predictions to a copy of the candidate-edge table.
                edge_df_constraints = edge_df.copy()
                edge_df_constraints["edge_label"] = graph_metrics["@constraints"]["pred_labels"]
                constraints_array = pred_df_to_csv(edge_df_constraints, graph_test.nodes_df_original)

                output_file_path = os.path.join(results_entry_output_folder, graph_test.graph_id)
                array_to_csv(constraints_array, output_file_path)

In [None]:
# Persist parameters and constrained predictions under ./results/results_automatic_test.
# NOTE(review): the "./results" root is hard-coded here -- presumably the same folder
# as results_output_dir defined near the top of the notebook; confirm.
save_results(os.path.join("./results","results_automatic_test"), results_list_dl)