In [None]:
!pip install node2vec



In [None]:
# Standard Libraries
import random  # Used for generating random numbers
import pickle # for optrix dataset

# Third Party Libraries for Data Handling and Analysis
import numpy as np  # Fundamental package for numerical computations
import pandas as pd  # Data manipulation and analysis library

# Third Party Libraries for Visualization
import matplotlib.pyplot as plt  # Plotting library
import seaborn as sns  # Statistical data visualization library

# Third Party Libraries for Graph Analysis
import networkx as nx  # Network analysis library
from node2vec import Node2Vec  # Node2Vec algorithm
import folium
import math

# Third Party Libraries for Machine Learning
from sklearn.decomposition import PCA  # Principal Component Analysis
from sklearn.linear_model import LogisticRegression  # Logistic Regression classifier
from sklearn.manifold import TSNE  # t-SNE for dimensionality reduction
from sklearn.metrics import (
    auc, confusion_matrix, precision_recall_curve,
    precision_score, recall_score, roc_auc_score, roc_curve, average_precision_score, f1_score, accuracy_score
) # Metrics for evaluating machine learning models
from sklearn.model_selection import train_test_split  # Splitting data into training and testing sets
from sklearn.preprocessing import LabelEncoder, MinMaxScaler  # Preprocessing tools

import os
import random
import numpy as np
import torch

DEFAULT_SEED = 1337

class Visualizer:
    """Plot evaluation diagnostics and 2-D projections of node embeddings.

    The instance only stores the arrays handed to ``__init__``; every
    ``plot_*`` / ``visualize_*`` method opens a new matplotlib figure and
    calls ``plt.show()``.
    """

    def __init__(self, test_labels, test_proba, test_pred, embeddings=None, positive_samples=None, negative_samples=None):
        """
        Initializes the Visualizer with testing data and embeddings.

        Parameters:
        test_labels (array-like): True labels for the test data.
        test_proba (array-like): Predicted probabilities for the test data.
        test_pred (array-like): Predicted labels for the test data.
        embeddings (array-like, optional): Embeddings of the nodes.
        positive_samples (array-like, optional): Positive sample indices.
        negative_samples (array-like, optional): Negative sample indices.
        """
        self.test_labels = test_labels
        self.test_proba = test_proba
        self.test_pred = test_pred
        self.embeddings = embeddings
        self.positive_samples = positive_samples
        self.negative_samples = negative_samples

    @staticmethod
    def min_max_scaling(data):
        """Rescale ``data`` linearly so its minimum maps to 0 and its maximum to 1.

        The minimum and maximum are taken over the whole input (not per
        column). Constant input, where max == min, is mapped to all zeros
        instead of dividing by zero and producing NaN.

        Parameters:
        data (array-like): Numeric values to rescale.

        Returns:
        The rescaled values, same shape as ``data``.
        """
        if isinstance(data, (list, tuple)):
            data = np.asarray(data)
        min_val = np.min(data)
        max_val = np.max(data)
        value_range = max_val - min_val
        if value_range == 0:
            # Zero spread: dividing by 1 keeps the (all-zero) shifted values
            # finite, the same convention sklearn's MinMaxScaler uses.
            value_range = 1
        return (data - min_val) / value_range

    def plot_probabilities_histogram(self):
        """Plots a histogram (20 bins) of the predicted probabilities."""
        plt.figure(figsize=(8, 6))
        plt.hist(self.test_proba, bins=20, color='skyblue', edgecolor='black')
        plt.xlabel('Predicted Probabilities')
        plt.ylabel('Frequency')
        plt.title('Histogram of Predicted Probabilities')
        plt.show()

    def plot_confusion_matrix(self):
        """Plots the confusion matrix of test_labels vs. test_pred as a heatmap."""
        cm = confusion_matrix(self.test_labels, self.test_pred)

        # Rows of the matrix are true classes, columns are predicted classes.
        labels = ['Actual Negative', 'Actual Positive']
        categories = ['Predicted Negative', 'Predicted Positive']

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=False,
                    xticklabels=categories, yticklabels=labels)
        plt.xlabel('Predicted')
        plt.ylabel('True')
        plt.title('Confusion Matrix')
        plt.show()

    def plot_roc_curve(self):
        """Plots the ROC curve and reports its area in the legend."""
        fpr, tpr, _ = roc_curve(self.test_labels, self.test_proba)
        roc_auc = auc(fpr, tpr)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
        # Diagonal reference line (TPR == FPR).
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc='lower right')
        plt.show()

    def plot_precision_recall_curve(self):
        """Plots a Precision-Recall curve for the model."""
        precision, recall, _ = precision_recall_curve(self.test_labels, self.test_proba)

        plt.figure(figsize=(8, 6))
        plt.plot(recall, precision, color='darkorange', lw=2)
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.grid(True)
        plt.show()

    def plot_accuracy_vs_threshold(self):
        """Plots accuracy against the decision thresholds 0.1, 0.2, ..., 0.9."""
        thresholds = np.linspace(0.1, 0.9, 9)
        # Convert once so the element-wise comparisons below do not depend on
        # the container type the caller passed in.
        proba = np.asarray(self.test_proba)
        truth = np.asarray(self.test_labels)
        accuracies = [
            np.mean((proba >= threshold) == truth)
            for threshold in thresholds
        ]

        plt.figure(figsize=(8, 6))
        plt.plot(thresholds, accuracies, marker='o', linestyle='-')
        plt.xlabel('Threshold')
        plt.ylabel('Accuracy')
        plt.title('Accuracy vs. Threshold')
        plt.grid(True)
        plt.show()

    def visualize_TSNE_embedding(self, perplexity=30, learning_rate=200):
        """Projects the normalized embeddings to 2-D with t-SNE and scatter-plots them.

        Parameters:
        perplexity: Forwarded to sklearn's TSNE.
        learning_rate: Forwarded to sklearn's TSNE.

        The TSNE random_state is fixed at 42.
        """
        tsne = TSNE(n_components=2, perplexity=perplexity, learning_rate=learning_rate, random_state=42)
        normalized_embeddings = self.get_normalized_embeddings()
        tsne_results = tsne.fit_transform(normalized_embeddings)

        plt.figure(figsize=(8, 6))
        plt.scatter(tsne_results[:, 0], tsne_results[:, 1], alpha=0.5)
        plt.title('t-SNE Visualization of Node Embeddings')
        plt.xlabel('Dimension 1')
        plt.ylabel('Dimension 2')
        plt.show()

    def get_normalized_embeddings(self):
        """Returns the embeddings rescaled by sklearn's MinMaxScaler.

        Raises:
        ValueError: If no embeddings were supplied to the constructor.
        """
        if self.embeddings is None:
            raise ValueError("Visualizer was created without embeddings; nothing to normalize.")
        scaler = MinMaxScaler()
        normalized_embeddings = scaler.fit_transform(self.embeddings)
        return normalized_embeddings

    def visualize_PCA_embedding(self, n_components=2):
        """Projects the normalized embeddings with PCA and scatter-plots them.

        Parameters:
        n_components (int): Number of principal components to compute. Only
            the first two are plotted, so values below 2 fail when the
            second column is indexed.
        """
        normalized_embeddings = self.get_normalized_embeddings()
        pca = PCA(n_components=n_components)
        reduced_embeddings = pca.fit_transform(normalized_embeddings)

        plt.figure(figsize=(8, 6))
        plt.scatter(reduced_embeddings[:, 0], reduced_embeddings[:, 1], alpha=0.5)
        plt.title('Visualization of Normalized Node Embeddings')
        plt.xlabel('Principal Component 1')
        plt.ylabel('Principal Component 2')
        plt.show()

# ------------------- Dataset ----------------------#

# ----------------------COST------------------
def create_cost266_graph():
    """Build the 37-node COST 266 topology as an undirected networkx graph.

    Each link is stored with a ``weight`` edge attribute (described in the
    original source as fiber lengths computed from haversine values), and
    each node carries ``lat`` / ``long`` attributes in decimal degrees.

    Returns:
    - G (networkx.Graph): The populated graph, nodes numbered 1..37.
    """
    G = nx.Graph()
    G.add_nodes_from(range(1, 38))

    # (node_a, node_b, weight) triples.
    fiber_links = [
        (1, 8, 259.84965666159),
        (1, 14, 1067.1610071760647),
        (1, 15, 551.932850462205),
        (1, 19, 540.2969276038555),
        (2, 26, 1362.6064844523005),
        (2, 31, 793.8895623997772),
        (2, 36, 1500),
        (3, 21, 760.3967118138279),
        (3, 22, 508.18611493655453),
        (3, 30, 1244.055742188373),
        (4, 9, 474.5629290263091),
        (4, 31, 486.21646810796915),
        (4, 36, 551.0721619498521),
        (5, 10, 539.8567721623151),
        (5, 15, 379.99508008152793),
        (5, 24, 757.6386535506992),
        (5, 28, 420.89799110496745),
        (5, 35, 774.6461442087643),
        (6, 14, 609.3320278949963),
        (6, 19, 238.79773630910168),
        (7, 21, 833.7142844783104),
        (7, 22, 757.0555809713325),
        (7, 27, 747.5017116213392),
        (8, 12, 263.47457585633975),
        (8, 27, 392.4792710554075),
        (9, 17, 435.9301886415584),
        (9, 28, 667.8119982971642),
        (10, 25, 720.4908600074742),
        (10, 32, 776.2740263923121),
        (11, 14, 462.5816015725619),
        (11, 19, 689.4418556744857),
        (12, 13, 274.66261609360544),
        (13, 15, 591.9943959052813),
        (13, 24, 456.2237094139668),
        (13, 33, 271.7320304315402),
        (16, 25, 1182.486309099843),
        (16, 32, 597.7929501950997),
        (16, 35, 1370.7464471274227),
        (17, 35, 383.0241099250181),
        (18, 19, 1977.1502858443332),
        (18, 21, 750.3016254076838),
        (18, 30, 470.97351082501496),
        (19, 27, 513.4593301745787),
        (20, 22, 410.3594835737697),
        (20, 27, 595.1118221794119),
        (20, 37, 507.63930413732726),
        (22, 29, 906.6723231378437),
        (23, 24, 521.4075026201809),
        (23, 29, 720.9026371475536),
        (23, 37, 326.44767652831126),
        (24, 34, 534.012623777059),
        (26, 29, 636.4633146922802),
        (27, 33, 600.37691157053),
        (28, 34, 375.52200677111244),
        (29, 36, 782.9414880645511),
        (33, 37, 218.2728433959167),
        (34, 36, 400.6139057896578),
    ]
    G.add_weighted_edges_from(fiber_links)

    # node id -> (latitude, longitude)
    coordinates = {
        1: (52.35, 4.90),
        2: (38.00, 23.73),
        3: (41.37, 2.18),
        4: (44.83, 20.50),
        5: (52.52, 13.40),
        6: (52.47, -1.88),
        7: (44.85, -0.57),
        8: (50.83, 4.35),
        9: (47.50, 19.08),
        10: (55.72, 12.57),
        11: (53.33, -6.25),
        12: (51.23, 6.78),
        13: (50.10, 8.67),
        14: (55.85, -4.25),
        15: (53.55, 10.02),
        16: (60.17, 24.97),
        17: (50.05, 19.95),
        18: (38.73, -9.13),
        19: (51.50, -0.17),
        20: (45.73, 4.83),
        21: (40.42, -3.72),
        22: (43.30, 5.37),
        23: (45.47, 9.17),
        24: (48.13, 11.57),
        25: (59.93, 10.75),
        26: (38.12, 13.35),
        27: (48.87, 2.33),
        28: (50.08, 14.43),
        29: (41.88, 12.50),
        30: (37.38, -5.98),
        31: (42.75, 23.33),
        32: (59.33, 18.05),
        33: (48.58, 7.77),
        34: (48.22, 16.37),
        35: (52.25, 21.00),
        36: (45.83, 16.02),
        37: (47.38, 8.55),
    }

    # Expand the compact table into the per-node attribute dicts networkx expects.
    node_attributes = {
        node_id: {"lat": latitude, "long": longitude}
        for node_id, (latitude, longitude) in coordinates.items()
    }
    nx.set_node_attributes(G, node_attributes)

    return G


def haversine_distance(lat1, lon1, lat2, lon2):
    """Return the great-circle distance between two points via the haversine formula.

    Parameters:
    - lat1, lon1 (float): Latitude and longitude of the first point, in degrees.
    - lat2, lon2 (float): Latitude and longitude of the second point, in degrees.

    Returns:
    - float: Distance in kilometers, using an Earth radius of 6371 km.
    """
    earth_radius_km = 6371

    # The trigonometric functions below work in radians.
    phi1, lam1, phi2, lam2 = map(math.radians, (lat1, lon1, lat2, lon2))

    delta_phi = phi2 - phi1
    delta_lam = lam2 - lam1

    # Haversine of the central angle between the two points.
    hav = math.sin(delta_phi/2)**2 + math.cos(phi1) * math.cos(phi2) * math.sin(delta_lam/2)**2
    central_angle = 2 * math.atan2(math.sqrt(hav), math.sqrt(1 - hav))

    return earth_radius_km * central_angle

def calculate_edge_attributes(G):
    """Annotate every edge of G with betweenness and length attributes.

    For each edge the following attributes are written in place:
    - "edge_betweenness": unnormalized edge betweenness centrality
      (0.0 if the edge is missing from the centrality result).
    - "haversine_link_length": haversine distance between the endpoints,
      from their "lat" / "long" node attributes.
    - "conputed_fiber_link_length": compute_fiber_length() of that distance.
      (The key spelling is kept as-is because other code reads it.)

    Parameters:
    - G (networkx graph): Graph whose nodes carry "lat" and "long".

    Returns:
    - G: The same graph object, now annotated.
    - haversine_link_lengths (list): Haversine lengths in G.edges() order.
    """
    betweenness_by_edge = nx.edge_betweenness_centrality(G, normalized=False)
    haversine_link_lengths = []

    for u, v, attrs in G.edges(data=True):
        start, end = G.nodes[u], G.nodes[v]
        span = haversine_distance(start["lat"], start["long"], end["lat"], end["long"])
        haversine_link_lengths.append(span)

        # `attrs` is the live attribute dict of the edge, so these writes
        # land on the graph itself.
        attrs["edge_betweenness"] = betweenness_by_edge.get((u, v), 0.0)
        attrs["haversine_link_length"] = span
        attrs["conputed_fiber_link_length"] = compute_fiber_length(span)

    return G, haversine_link_lengths


def visualize_cost_on_folium(G):
    """Draw graph G on a folium map initially centered at (54, 15), zoom 4.

    Every node becomes a blue marker at its "lat" / "long" attributes and
    every edge a navy straight line between its endpoints.

    Parameters:
    - G (networkx graph): Graph whose nodes carry "lat" and "long".

    Returns:
    - folium.Map: The map with all markers and lines added.
    """
    europe_map = folium.Map(location=[54, 15], zoom_start=4)

    # One marker per node.
    for _, attrs in G.nodes(data=True):
        marker_icon = folium.Icon(color='blue', icon='circle', prefix='fa')
        folium.Marker([attrs['lat'], attrs['long']], icon=marker_icon).add_to(europe_map)

    # One two-point polyline per edge.
    for u, v in G.edges():
        endpoints = [[G.nodes[n]['lat'], G.nodes[n]['long']] for n in (u, v)]
        folium.PolyLine(endpoints, color="navy", weight=2).add_to(europe_map)

    return europe_map

def visualize_cost_on_folium_with_edge_length(G): # DOES NOT look pretty on the graph - don't use it
    """Like visualize_cost_on_folium, but also labels each edge with its length.

    The label is taken from the "haversine_link_length" edge attribute and
    placed at the arithmetic midpoint of the endpoint coordinates; edges
    without that attribute are drawn unlabeled.

    Parameters:
    - G (networkx graph): Graph whose nodes carry "lat" and "long".

    Returns:
    - folium.Map: The map with markers, lines and length labels added.
    """
    labelled_map = folium.Map(location=[54, 15], zoom_start=4)

    # One marker per node.
    for _, attrs in G.nodes(data=True):
        marker_icon = folium.Icon(color='blue', icon='circle', prefix='fa')
        folium.Marker([attrs['lat'], attrs['long']], icon=marker_icon).add_to(labelled_map)

    for u, v, edge_data in G.edges(data=True):
        lat_a, long_a = G.nodes[u]['lat'], G.nodes[u]['long']
        lat_b, long_b = G.nodes[v]['lat'], G.nodes[v]['long']

        folium.PolyLine([[lat_a, long_a], [lat_b, long_b]], color="navy", weight=2).add_to(labelled_map)

        length_km = edge_data.get("haversine_link_length", None)
        if length_km is None:
            # No stored length for this edge: draw the line only.
            continue

        caption = f"Edge Length: {length_km:.2f} km"
        midpoint = [(lat_a + lat_b) / 2, (long_a + long_b) / 2]
        folium.Marker(midpoint, icon=folium.DivIcon(html=f'<div>{caption}</div>')).add_to(labelled_map)

    return labelled_map


def print_edge_attributes(G):
    """Print one line per edge with its betweenness and length attributes.

    Reads the "edge_betweenness", "haversine_link_length" and
    "conputed_fiber_link_length" edge attributes; a missing attribute
    raises KeyError.

    Parameters:
    - G (networkx graph): Graph whose edges carry those three attributes.
    """
    for u, v, attrs in G.edges(data=True):
        betweenness = attrs["edge_betweenness"]
        haversine_km = attrs["haversine_link_length"]
        fiber_km = attrs["conputed_fiber_link_length"]
        print(f"Edge ({u}, {v}): Edge Betweenness = {betweenness}, Haversine Link Length = {haversine_km}, Computed Fiber Link Length = {fiber_km}")



def compute_fiber_length(haversine_distance):
    """
    Estimate the fiber length of a link from its haversine distance.

    The mapping is piecewise:
    - below 1000: 1.5 times the distance
    - from 1000 to 1200 inclusive: a flat 1500
    - above 1200: 1.25 times the distance

    Parameters:
    - haversine_distance (float): Haversine distance in kilometers.

    Returns:
    - float: Length of fiber in kilometers.
    """
    short_link_limit = 1000
    plateau_limit = 1200

    if haversine_distance < short_link_limit:
        return 1.5 * haversine_distance
    if haversine_distance <= plateau_limit:
        # Anything reaching this point is already >= 1000.
        return 1500
    return 1.25 * haversine_distance

# ------------------- CORONET ----------------------#
def create_coronet_conus_60_graph():
    """Build the 60-node CORONET CONUS topology as an undirected networkx graph.

    Each link is stored with a ``weight`` edge attribute (presumably link
    length in km -- the source does not state the unit), and each node
    carries ``lat`` / ``long`` attributes in decimal degrees.

    Returns:
    - G (networkx.Graph): The populated graph, nodes numbered 1..60.
    """
    G = nx.Graph()

    # Nodes 1..60 inclusive. The upper bound used to be 60, which left node
    # 60 out of the explicit list; it only existed because an edge mentions it.
    nodes = list(range(1, 61))
    G.add_nodes_from(nodes)

    # (node_a, node_b, weight) triples.
    edges = [
        (1, 8, 277.1),
        (1, 53, 234.2),
        (2, 16, 647.7),
        (2, 18, 436.9),
        (3, 7, 266.2),
        (3, 10, 439.2),
        (3, 22, 554.1),
        (4, 37, 179.2),
        (4, 59, 67.2),
        (5, 21, 500.2),
        (5, 32, 147.4),
        (6, 51, 1293.1),
        (6, 30, 1468),
        (7, 31, 352.4),
        (7, 32, 582.5),
        (8, 41, 79.9),
        (9, 53, 272),
        (9, 13, 336.4),
        (10, 19, 159.9),
        (11, 52, 501.6),
        (11, 17, 459.1),
        (11, 29, 165.3),
        (12, 14, 193.2),
        (12, 27, 177.5),
        (12, 59, 777.1),
        (13, 14, 239),
        (13, 56, 191.2),
        (14, 39, 294.7),
        (15, 58, 560),
        (15, 18, 1189.9),
        (15, 21, 432.7),
        (15, 25, 554),
        (16, 36, 920.3),
        (16, 44, 731.3),
        (17, 56, 107.2),
        (18, 45, 964.5),
        (18, 57, 505.7),
        (19, 59, 508),
        (19, 27, 690.4),
        (19, 42, 131.1),
        (20, 33, 215.6),
        (20, 41, 125.6),
        (21, 45, 425),
        (22, 42, 809),
        (22, 60, 522),
        (23, 36, 314),
        (23, 52, 470.9),
        (23, 58, 418.4),
        (24, 35, 786),
        (24, 26, 484.5),
        (24, 38, 495.9),
        (24, 44, 700),
        (25, 31, 639.2),
        (26, 46, 223.8),
        (26, 49, 150.7),
        (27, 31, 295.1),
        (27, 52, 473.8),
        (28, 55, 397.1),
        (28, 60, 129.8),
        (29, 30, 568.3),
        (30, 36, 561.3),
        (32, 54, 653.4),
        (33, 34, 24.2),
        (33, 50, 199.6),
        (34, 37, 136.1),
        (35, 43, 132.6),
        (35, 44, 1135.7),
        (35, 47, 25.7),
        (37, 50, 193.4),
        (38, 46, 574.7),
        (38, 57, 222.5),
        (39, 50, 473.6),
        (40, 43, 937.7),
        (40, 51, 279.1),
        (44, 51, 1463.4),
        (47, 48, 77.2),
        (48, 49, 447),
        (50, 53, 223.8),
        (54, 55, 394.1),
    ]

    # Add weighted edges to the graph
    G.add_weighted_edges_from(edges)

    # Node attributes (latitude and longitude)
    node_attributes = {
        1: {"lat": 42.67, "long": -73.8},
        2: {"lat": 35.12, "long": -106.62},
        3: {"lat": 33.76, "long": -84.42},
        4: {"lat": 39.3, "long": -76.61},
        5: {"lat": 30.45, "long": -91.13},
        6: {"lat": 45.79, "long": -108.54},
        7: {"lat": 33.53, "long": -86.8},
        8: {"lat": 42.34, "long": -71.02},
        9: {"lat": 42.89, "long": -78.86},
        10: {"lat": 35.2, "long": -80.83},
        11: {"lat": 41.84, "long": -87.68},
        12: {"lat": 39.14, "long": -84.51},
        13: {"lat": 41.48, "long": -81.68},
        14: {"lat": 39.99, "long": -82.99},
        15: {"lat": 32.79, "long": -96.77},
        16: {"lat": 39.77, "long": -104.87},
        17: {"lat": 42.38, "long": -83.1},
        18: {"lat": 31.85, "long": -106.44},
        19: {"lat": 36.08, "long": -79.83},
        20: {"lat": 41.77, "long": -72.68},
        21: {"lat": 29.77, "long": -95.39},
        22: {"lat": 30.33, "long": -81.66},
        23: {"lat": 39.12, "long": -94.73},
        24: {"lat": 36.21, "long": -115.22},
        25: {"lat": 34.72, "long": -92.35},
        26: {"lat": 34.11, "long": -118.41},
        27: {"lat": 38.22, "long": -85.74},
        28: {"lat": 25.78, "long": -80.21},
        29: {"lat": 43.06, "long": -87.97},
        30: {"lat": 44.96, "long": -93.27},
        31: {"lat": 36.17, "long": -86.78},
        32: {"lat": 30.07, "long": -89.93},
        33: {"lat": 40.67, "long": -73.94},
        34: {"lat": 40.72, "long": -74.17},
        35: {"lat": 37.77, "long": -122.22},
        36: {"lat": 41.26, "long": -96.01},
        37: {"lat": 40.01, "long": -75.13},
        38: {"lat": 33.54, "long": -112.07},
        39: {"lat": 40.3, "long": -80.13},
        40: {"lat": 45.54, "long": -122.66},
        41: {"lat": 41.82, "long": -71.42},
        42: {"lat": 35.82, "long": -78.66},
        43: {"lat": 38.57, "long": -121.47},
        44: {"lat": 40.78, "long": -111.93},
        45: {"lat": 29.46, "long": -98.51},
        46: {"lat": 32.81, "long": -117.14},
        47: {"lat": 37.66, "long": -122.42},
        48: {"lat": 37.3, "long": -121.85},
        49: {"lat": 34.43, "long": -119.72},
        50: {"lat": 41.4, "long": -75.67},
        51: {"lat": 47.62, "long": -122.35},
        52: {"lat": 38.64, "long": -90.24},
        53: {"lat": 43.04, "long": -76.14},
        54: {"lat": 30.46, "long": -84.28},
        55: {"lat": 27.96, "long": -82.48},
        56: {"lat": 41.66, "long": -83.58},
        57: {"lat": 32.2, "long": -110.89},
        58: {"lat": 36.13, "long": -95.92},
        59: {"lat": 38.91, "long": -77.02},
        60: {"lat": 26.75, "long": -80.13},
    }

    # Setting node attributes
    nx.set_node_attributes(G, node_attributes)

    return G


def visualize_coronet_conus_60_on_folium(G):
    """Draw graph G on a folium map initially centered at (39.8283, -98.5795), zoom 4.

    Every node becomes a blue marker at its "lat" / "long" attributes and
    every edge a navy straight line between its endpoints.

    Parameters:
    - G (networkx graph): Graph whose nodes carry "lat" and "long".

    Returns:
    - folium.Map: The map with all markers and lines added.
    """
    usa_map = folium.Map(location=[39.8283, -98.5795], zoom_start=4)

    # One marker per node.
    for _, attrs in G.nodes(data=True):
        marker_icon = folium.Icon(color='blue', icon='circle', prefix='fa')
        folium.Marker([attrs['lat'], attrs['long']], icon=marker_icon).add_to(usa_map)

    # One two-point polyline per edge.
    for u, v in G.edges():
        endpoints = [[G.nodes[n]['lat'], G.nodes[n]['long']] for n in (u, v)]
        folium.PolyLine(endpoints, color="navy", weight=2).add_to(usa_map)

    return usa_map


# ------------------- Dataset ----------------------#


def generate_node2vec_embeddings(G, dimensions=64, walk_length=30, num_walks=200, workers=4, random_seed=42):
    """
    Learn Node2Vec embeddings for every node of G.

    Parameters:
    - G (networkx graph): The input graph.
    - dimensions (int): Dimensionality of the node embeddings.
    - walk_length (int): Length of walk per source.
    - num_walks (int): Number of walks per source.
    - workers (int): Number of workers for parallel computation.
    - random_seed (int): Seed passed to Node2Vec.

    Returns:
    - embeddings (numpy array): One row per node, in the iteration order of
      G.nodes. Row k therefore belongs to the k-th node of G.nodes, which
      is not necessarily the node whose id is k.
    """
    walker = Node2Vec(
        G,
        dimensions=dimensions,
        walk_length=walk_length,
        num_walks=num_walks,
        workers=workers,
        seed=random_seed,
    )
    fitted = walker.fit(window=10, min_count=1, batch_words=4)

    # The fitted model is keyed by the string form of each node id.
    vectors = fitted.wv
    return np.array([vectors[str(node_id)] for node_id in G.nodes])


def get_edge_and_labels(G):
    """
    Build a balanced link-prediction sample set from graph G.

    All edges of G are positives; an equal number of node pairs that are
    not connected is drawn without replacement (via the global numpy
    random state) as negatives.

    Parameters:
    - G (networkx graph): The input graph.

    Returns:
    - edges (list): All edges of G followed by the sampled non-edges.
    - non_edges (list): Every non-edge of G (not only the sampled ones).
    - labels (list): 1 for each edge, then 0 for each sampled non-edge,
      aligned with the first return value.
    """
    positives = [(u, v) for u, v in G.edges]
    non_edges = list(nx.non_edges(G))

    positive_count = len(positives)
    # Raises ValueError if G has fewer non-edges than edges.
    picked = np.random.choice(len(non_edges), positive_count, replace=False)

    samples = positives + [non_edges[position] for position in picked]
    labels = [1] * positive_count + [0] * len(picked)
    return samples, non_edges, labels


def split_data(edges, labels, test_size=0.2, random_state=42):
    """
    Partition node pairs and their labels into train and test subsets.

    Thin wrapper around sklearn's train_test_split.

    Parameters:
    - edges (list): A list of tuples representing edges.
    - labels (list): A list of labels.
    - test_size (float): The proportion of the data to include in the test split.
    - random_state (int): Seed for the random number generator.

    Returns:
    - The train_test_split result: train edges, test edges, train labels,
      test labels, in that order.
    """
    partitions = train_test_split(
        edges,
        labels,
        test_size=test_size,
        random_state=random_state,
    )
    return partitions

def add_distance_feature(train_edges, test_edges, G):
    """
    Compute the haversine distance between the endpoints of each node pair.

    Parameters:
    - train_edges (list): Node pairs of the training set.
    - test_edges (list): Node pairs of the test set.
    - G (networkx graph): Graph whose nodes carry "lat" and "long".

    Returns:
    - train_distances (list): One distance per training pair, same order.
    - test_distances (list): One distance per test pair, same order.
    """
    def _pair_distances(pairs):
        # Distance between the two endpoints of every pair, order preserved.
        measured = []
        for node_a, node_b in pairs:
            here, there = G.nodes[node_a], G.nodes[node_b]
            measured.append(
                haversine_distance(here["lat"], here["long"], there["lat"], there["long"])
            )
        return measured

    return _pair_distances(train_edges), _pair_distances(test_edges)


def train_classifier(train_distances, train_edges, train_labels, embeddings, C=1.0, penalty='l2', solver='lbfgs', node_index=None):
    """
    Trains a logistic regression classifier on two features per node pair:
    the supplied distance and the dot product of the two node embeddings.

    Parameters:
    - train_distances (list): One distance per training pair, aligned with train_edges.
    - train_edges (list): A list of tuples representing training edges.
    - train_labels (list): A list of labels for the training edges.
    - embeddings (numpy array): The node embeddings, one row per node.
    - C (float): Inverse regularization strength passed to LogisticRegression.
    - penalty (str): Penalty passed to LogisticRegression.
    - solver (str): Solver passed to LogisticRegression.
    - node_index (mapping, optional): Maps a node id to its row in
      `embeddings`. When omitted, node ids are used directly as row indices
      (the historical behaviour). That is only correct when ids are
      0..N-1 in embedding order; with 1-based ids every lookup is shifted
      by one row and the highest id raises IndexError. A suitable mapping
      is `{node: k for k, node in enumerate(G.nodes)}`.

    Returns:
    - clf (LogisticRegression object): The trained classifier.
    """
    def _row(node):
        # Translate a node id into its embedding row.
        return node if node_index is None else node_index[node]

    similarity = np.array(
        [np.dot(embeddings[_row(i)], embeddings[_row(j)]) for i, j in train_edges]
    )

    # Feature matrix: column 0 = distance, column 1 = embedding dot product.
    train_features = np.hstack([np.array(train_distances).reshape(-1, 1),
                                similarity.reshape(-1, 1)])
    clf = LogisticRegression(C=C, penalty=penalty, solver=solver).fit(train_features, train_labels)
    return clf


def evaluate_classifier(clf, test_distances, test_edges, test_labels, embeddings, threshold=0.5, node_index=None):
    """
    Evaluates the logistic regression classifier on the test pairs.

    Features are built as in training: the supplied distance and the dot
    product of the two node embeddings.

    Parameters:
    - clf (LogisticRegression object): The trained classifier.
    - test_distances (list): One distance per test pair, aligned with test_edges.
    - test_edges (list): A list of tuples representing testing edges.
    - test_labels (list): A list of labels for the testing edges.
    - embeddings (numpy array): The node embeddings, one row per node.
    - threshold (float): Probability cutoff used to derive test_pred.
    - node_index (mapping, optional): Maps a node id to its row in
      `embeddings`. When omitted, node ids are used directly as row indices
      (the historical behaviour), which is only correct for ids 0..N-1 in
      embedding order.

    Returns (in this order):
    - ap_k (list): Accuracy of the predictions on the 5, 10 and 20 test
      pairs with the highest predicted probability.
    - f1 (float): F1 score of test_pred.
    - avg_precision (float): Average precision of test_proba.
    - precision (float): Precision score of test_pred.
    - recall (float): Recall score of test_pred.
    - auc_roc (float): ROC-AUC score of test_proba.
    - test_score (float): Accuracy from clf.score, i.e. at the classifier's
      own decision rule, independent of `threshold`.
    - test_pred (array): Predicted labels for the test set.
    - test_proba (array): Predicted probabilities for the test set.
    """
    def _row(node):
        # Translate a node id into its embedding row.
        return node if node_index is None else node_index[node]

    similarity = np.array(
        [np.dot(embeddings[_row(i)], embeddings[_row(j)]) for i, j in test_edges]
    )
    test_features = np.hstack([np.array(test_distances).reshape(-1, 1),
                               similarity.reshape(-1, 1)])

    test_proba = clf.predict_proba(test_features)[:, 1]
    # Apply threshold to get predictions
    test_pred = (test_proba >= threshold).astype(int)
    test_score = clf.score(test_features, test_labels)
    precision = precision_score(test_labels, test_pred)
    recall = recall_score(test_labels, test_pred)
    auc_roc = roc_auc_score(test_labels, test_proba)
    avg_precision = average_precision_score(test_labels, test_proba)
    f1 = f1_score(test_labels, test_pred)

    # Rank test pairs by descending probability so "top k" really means the
    # k most confident predictions. The previous code sliced the last k
    # samples, whose order is arbitrary after the shuffled split.
    labels_arr = np.asarray(test_labels)
    ranking = np.argsort(-test_proba, kind="stable")
    ap_k = []
    for k in (5, 10, 20):
        top_k = ranking[:k]
        ap_k.append(accuracy_score(labels_arr[top_k], test_pred[top_k]))

    return ap_k, f1, avg_precision, precision, recall, auc_roc, test_score, test_pred, test_proba

# Add this method after the evaluate_classifier method
def print_prediction_analysis(test_edges, test_labels, test_pred, G):
    """
    Print the test pairs grouped by prediction outcome, with their distance.

    Four sections are printed in order: predicted 1 / actual 1,
    predicted 1 / actual 0, predicted 0 / actual 0, predicted 0 / actual 1.
    Each matching pair is printed together with the haversine distance
    between its endpoints.

    Parameters:
    - test_edges (list): Node pairs of the test set.
    - test_labels (list): True labels, aligned with test_edges.
    - test_pred (array): Predicted labels, aligned with test_edges.
    - G (networkx graph): Graph whose nodes carry "lat" and "long".
    """
    # Haversine distance for every test pair, in test order.
    distances = []
    for node_a, node_b in test_edges:
        here, there = G.nodes[node_a], G.nodes[node_b]
        distances.append(
            haversine_distance(here["lat"], here["long"], there["lat"], there["long"])
        )

    # (heading, printed rule, wanted prediction, wanted label) per section.
    sections = (
        ("Correctly predicted connected edges:",
         "test_pred[i] == 1 and test_labels[i] == 1", 1, 1),
        ("\nIncorrectly predicted connected edges:",
         "test_pred[i] == 1 and test_labels[i] == 0", 1, 0),
        ("\nCorrectly predicted unconnected edges:",
         "test_pred[i] == 0 and test_labels[i] == 0", 0, 0),
        ("\nIncorrectly predicted unconnected edges:",
         "test_pred[i] == 0 and test_labels[i] == 1", 0, 1),
    )

    for heading, rule, wanted_pred, wanted_label in sections:
        print(heading)
        print(rule)
        for i, edge in enumerate(test_edges):
            if test_pred[i] == wanted_pred and test_labels[i] == wanted_label:
                print(edge, distances[i])

def set_up_lr_hyper_params():
    """
    Return the candidate hyperparameter values for logistic regression.

    Returns:
    - Cs (list): Candidate values for C.
    - penalties (list): Candidate penalty names.
    - solvers (list): Candidate solver names.
    """
    grid = {
        "C": [0.001, 0.01, 0.1, 1, 10, 100],
        "penalty": ['l1', 'l2'],
        "solver": ['liblinear', 'saga', 'lbfgs'],
    }
    return grid["C"], grid["penalty"], grid["solver"]

def setup_node2vec_hyper_params():
    """
    Return the candidate hyperparameter values for Node2Vec.

    Returns:
    - embeddings_dim (list): Candidate embedding sizes, 32 to 512 (powers of two).
    - walk_length (list): Candidate walk lengths, 10 to 60 in steps of 10.
    - num_walks (list): Candidate walk counts, 20 to 300 in steps of 20.
    """
    dimension_choices = [2 ** exponent for exponent in range(5, 10)]
    walk_length_choices = list(range(10, 61, 10))
    walk_count_choices = list(range(20, 301, 20))
    return dimension_choices, walk_length_choices, walk_count_choices

def train(G, random_seed, test_size, threshold):
    """
    Generate node2vec embeddings, fit the link classifier and print its test metrics.

    Parameters:
    G: Graph to embed; also passed to get_edge_and_labels and add_distance_feature.
    random_seed (int): Seed forwarded to split_data for the train/test split.
    test_size (float): Test share forwarded to split_data.
    threshold (float): Decision threshold forwarded to evaluate_classifier.

    Returns:
    tuple: (embeddings, positive_samples, negative_samples, ap_k, f1,
    avg_precision, precision, recall, auc_roc, test_score, test_pred,
    test_proba, train_edges, test_edges, train_labels, test_labels).
    positive_samples / negative_samples hold every sampled edge (train and
    test) whose label is 1 / 0, each as a tuple.
    """
    # Logistic Regression hyperparams
    # Default configuration penalty: 'l2', C: 1.0, solver: 'lbfgs'
    Cs, penalties, solvers = set_up_lr_hyper_params()

    # Node2vec hyperparams
    embeddings_dim, walk_length, num_walks = setup_node2vec_hyper_params()

    embeddings = generate_node2vec_embeddings(
        G,
        dimensions=embeddings_dim[1],  # 64
        walk_length=walk_length[2],  # 30
        num_walks=num_walks[9],  # 200
    )

    # Extract edges, non-edges, and labels from the graph (non-edges are not used here)
    edges, _non_edges, labels = get_edge_and_labels(G)

    # Split the data into training and testing sets
    train_edges, test_edges, train_labels, test_labels = split_data(edges, labels, test_size, random_seed)

    # Distance feature for every train / test edge
    train_distances, test_distances = add_distance_feature(train_edges, test_edges, G)

    # Separate positive and negative samples for later visualization.
    # Edges and labels are concatenated in the same order (train first, then
    # test) so that each edge is filtered by its own label; zipping
    # train_edges against test_edges would pair unrelated edges and stop at
    # the shorter of the two lists.
    all_edges = list(train_edges) + list(test_edges)
    all_labels = list(train_labels) + list(test_labels)
    positive_samples = [tuple(edge) for edge, label in zip(all_edges, all_labels) if label == 1]
    negative_samples = [tuple(edge) for edge, label in zip(all_edges, all_labels) if label == 0]

    # Train classifier for 2 features
    clf = train_classifier(
        train_distances,
        train_edges,
        train_labels,
        embeddings,
        Cs[3],  # C: 1.0
        penalties[1],  # penalty: 'l2'
        solvers[2],  # solver: 'lbfgs'
    )

    # Evaluation for 2 features
    ap_k, f1, avg_precision, precision, recall, auc_roc, test_score, test_pred, test_proba = evaluate_classifier(
        clf, test_distances, test_edges, test_labels, embeddings, threshold
    )

    print("Accuracy@K: ", ap_k)
    # Accuracy@K evaluates if true positives appear in the top K predictions

    print("F1 Score: ", f1)
    # F1 combines precision and recall as a harmonic mean

    print("Average Precision: ", avg_precision)
    # Avg precision summarizes precision across all recall levels

    print("Precision: ", precision)
    # Precision measures correctness of positive predictions

    print("Recall: ", recall)
    # Recall evaluates ability to find all true positive links

    print("AUC-ROC: ", auc_roc)
    # AUC-ROC measures ability to distinguish true/false links

    print("Accuracy: ", test_score)
    # Accuracy measures overall correctness of predictions

    # Per-edge report, disabled by default because of its volume:
    # print_prediction_analysis(test_edges, test_labels, test_pred, G)

    return embeddings, positive_samples, negative_samples, ap_k, f1, avg_precision, precision, recall, auc_roc, test_score, test_pred, test_proba, train_edges, test_edges, train_labels, test_labels

def visualise_metrics(test_labels, test_proba, test_pred, embeddings, positive_samples, negative_samples, perplexity):
    """
    Draw all diagnostic plots and embedding projections for one evaluation run.

    Parameters:
    test_labels (array-like): True labels for the test data.
    test_proba (array-like): Predicted probabilities for the test data.
    test_pred (array-like): Predicted labels for the test data.
    embeddings (array-like): Embeddings of the nodes.
    positive_samples: Positive samples handed to the Visualizer.
    negative_samples: Negative samples handed to the Visualizer.
    perplexity: Perplexity passed to the t-SNE visualisation.
    """
    plotter = Visualizer(
        test_labels,
        test_proba,
        test_pred,
        embeddings=embeddings,
        positive_samples=positive_samples,
        negative_samples=negative_samples,
    )

    # Classifier diagnostics, drawn in this fixed order.
    diagnostic_plots = (
        "plot_roc_curve",
        "plot_precision_recall_curve",
        "plot_confusion_matrix",
        "plot_probabilities_histogram",
        "plot_accuracy_vs_threshold",
    )
    for plot_name in diagnostic_plots:
        getattr(plotter, plot_name)()

    # Two-dimensional projections of the embeddings.
    plotter.visualize_TSNE_embedding(perplexity)
    plotter.visualize_PCA_embedding()


def setup_cost266_graph():
    """
    Build the COST266 topology (source: SNDlib) ready for training.

    Returns:
    The graph from create_cost266_graph with edge attributes set and its
    nodes relabelled as consecutive integers starting at 0.
    """
    # calculate_edge_attributes also returns the link lengths, which are not needed here.
    annotated_graph, _link_lengths = calculate_edge_attributes(create_cost266_graph())
    # Debug aid: print_edge_attributes(annotated_graph)
    return nx.convert_node_labels_to_integers(annotated_graph, first_label=0)

def run_cost266_graph(seed, test_size, threshold):
    """
    Train and evaluate the link predictor on the COST266 topology (SNDlib).

    Parameters:
    seed (int): Random seed forwarded to train.
    test_size (float): Test share forwarded to train.
    threshold (float): Decision threshold forwarded to train.

    Returns:
    tuple: (ap_k, f1, avg_precision, precision, recall, auc_roc, test_score) as returned by train.
    """
    cost_graph = setup_cost266_graph()

    # Only the metrics are passed on; the underscore-prefixed names are the
    # remaining train outputs, which are needed only when plotting.
    (
        _embeddings, _positive_samples, _negative_samples,
        ap_k, f1, avg_precision, precision, recall, auc_roc, test_score,
        _test_pred, _test_proba,
        _train_edges, _test_edges, _train_labels, _test_labels,
    ) = train(cost_graph, seed, test_size, threshold)

    # Plotting is switched off here; to re-enable it, call visualise_metrics with
    # the outputs above and perplexity=cost_graph.number_of_nodes()-1.
    return ap_k, f1, avg_precision, precision, recall, auc_roc, test_score

def setup_coronet_conus60_graph():
    """
    Build the CORONET CONUS 60 topology ready for training.

    Returns:
    The graph from create_coronet_conus_60_graph with edge attributes set and
    its nodes relabelled as consecutive integers starting at 0.
    """
    # calculate_edge_attributes also returns the link lengths, which are not needed here.
    annotated_graph, _link_lengths = calculate_edge_attributes(create_coronet_conus_60_graph())
    # Debug aid: print_edge_attributes(annotated_graph)
    return nx.convert_node_labels_to_integers(annotated_graph, first_label=0)

def run_coronet_conus60_graph(seed, test_size, threshold):
    """
    Train and evaluate the link predictor on the CORONET CONUS 60 topology.

    Parameters:
    seed (int): Random seed forwarded to train.
    test_size (float): Test share forwarded to train.
    threshold (float): Decision threshold forwarded to train.

    Returns:
    tuple: (ap_k, f1, avg_precision, precision, recall, auc_roc, test_score) as returned by train.
    """
    coronet_graph = setup_coronet_conus60_graph()

    # Only the metrics are passed on; the underscore-prefixed names are the
    # remaining train outputs, which are needed only when plotting.
    (
        _embeddings, _positive_samples, _negative_samples,
        ap_k, f1, avg_precision, precision, recall, auc_roc, test_score,
        _test_pred, _test_proba,
        _train_edges, _test_edges, _train_labels, _test_labels,
    ) = train(coronet_graph, seed, test_size, threshold)

    # Plotting is switched off here; to re-enable it, call visualise_metrics with
    # the outputs above and perplexity=coronet_graph.number_of_nodes()-1.
    return ap_k, f1, avg_precision, precision, recall, auc_roc, test_score

def run_pipeline(seed, test_size):
    """
    Run the COST266 and CORONET experiments back to back for one test size.

    Parameters:
    seed (int): Random seed forwarded to both runs.
    test_size (float): Test share forwarded to both runs.

    Returns:
    tuple: The seven COST metrics followed by the seven CORONET metrics, each
    group ordered as (ap_k, f1, avg_precision, precision, recall, auc_roc, test_score).
    """
    # Thresholds are the best performing ones read off the predicted
    # probability plots at test size 0.1.
    print("-------------------COST TOPOLOGY-------------------")
    cost_metrics = run_cost266_graph(seed, test_size, threshold=0.5)

    print("-------------------CORONET TOPOLOGY-------------------")
    coronet_metrics = run_coronet_conus60_graph(seed, test_size, threshold=0.2)

    return (*cost_metrics, *coronet_metrics)


def set_python(seed=DEFAULT_SEED):
    """
    Seed Python's built-in sources of randomness.

    Parameters:
    seed (int): Value written to PYTHONHASHSEED and used to seed the random module.
    """
    # NOTE(review): the interpreter reads PYTHONHASHSEED at start-up, so this
    # presumably only takes effect in processes launched afterwards - confirm.
    hash_seed = str(seed)
    os.environ['PYTHONHASHSEED'] = hash_seed
    random.seed(seed)

def set_numpy(seed=DEFAULT_SEED):
    """
    Seed NumPy's global random number generator.

    Parameters:
    seed (int): Seed passed to np.random.seed.
    """
    np.random.seed(seed=seed)

def set_torch(seed=DEFAULT_SEED, deterministic=False):
    """
    Seed PyTorch and optionally request deterministic cuDNN behaviour.

    Parameters:
    seed (int): Seed passed to torch.manual_seed and torch.cuda.manual_seed.
    deterministic (bool): When True, disable cuDNN benchmarking and enable
        cuDNN's deterministic mode.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    if not deterministic:
        return

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

def set_all_seeds(seed=DEFAULT_SEED, deterministic=False, test_size=0.2):
    """
    Seed Python, NumPy and PyTorch in one call.

    Parameters:
    seed (int): Seed applied to every library.
    deterministic (bool): Forwarded to set_torch.
    test_size (float): Not used for seeding; handed back unchanged.

    Returns:
    tuple: (seed, test_size) exactly as passed in.
    """
    set_python(seed=seed)
    set_numpy(seed=seed)
    set_torch(seed=seed, deterministic=deterministic)
    return seed, test_size

def main(seed):
    """
    Run the COST and CORONET pipelines over a sweep of test-set sizes.

    All libraries are re-seeded before each test size so every run starts
    from the same random state.

    Parameters:
    seed (int): Seed applied via set_all_seeds and forwarded to run_pipeline.

    Returns:
    pandas.DataFrame: One row per test size with a "test_size" column followed
    by the fourteen metrics returned by run_pipeline.
    """
    test_sizes = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]
    # Column names, in the order run_pipeline returns its values.
    metric_names = [
        "cost_ap_k", "cost_f1", "cost_avg_precision", "cost_precision",
        "cost_recall", "cost_auc_roc", "cost_test_score",
        "coronet_ap_k", "coronet_f1", "coronet_avg_precision", "coronet_precision",
        "coronet_recall", "coronet_auc_roc", "coronet_test_score",
    ]
    rows = []

    for size in test_sizes:
        print("--------------------------------------------------")
        print("Test Size = " + str(size))
        # Bind the returned size to its own name so the list being iterated
        # is never shadowed.
        seed, current_size = set_all_seeds(seed=seed, test_size=size)
        metrics = run_pipeline(seed, current_size)
        rows.append({"test_size": current_size, **dict(zip(metric_names, metrics))})

        print("--------------------------------------------------")

    # Previously the frame was created empty and the metrics were discarded.
    results = pd.DataFrame(rows)
    return results


# Entry point of the program: run the full test-size sweep with seed 42
if __name__ == "__main__":
    main(seed=42)




--------------------------------------------------
Test Size = 0.1
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  1.0
Average Precision:  0.9999999999999998
Precision:  1.0
Recall:  1.0
AUC-ROC:  1.0
Accuracy:  1.0
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.85]
F1 Score:  0.9142857142857143
Average Precision:  0.9829475649909087
Precision:  0.9411764705882353
Recall:  0.8888888888888888
AUC-ROC:  0.9814814814814815
Accuracy:  0.9166666666666666
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.875]
F1 Score:  0.888888888888889
Average Precision:  0.959375
Precision:  0.8
Recall:  1.0
AUC-ROC:  0.953125
Accuracy:  0.8125
--------------------------------------------------
--------------------------------------------------
Test Size = 0.2
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  1.0
Average Precision:  1.0
Precision:  1.0
Recall:  1.0
AUC-ROC:  1.0
Accuracy:  1.0
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.9]
F1 Score:  0.9090909090909091
Average Precision:  0.9733600891327865
Precision:  0.875
Recall:  0.9459459459459459
AUC-ROC:  0.9745173745173745
Accuracy:  0.9027777777777778
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9090909090909091
Average Precision:  0.9448242607066136
Precision:  0.8333333333333334
Recall:  1.0
AUC-ROC:  0.9529411764705882
Accuracy:  0.875
--------------------------------------------------
--------------------------------------------------
Test Size = 0.3
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.972972972972973
Average Precision:  0.9907216297825174
Precision:  0.9473684210526315
Recall:  1.0
AUC-ROC:  0.9901960784313726
Accuracy:  0.9714285714285714
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.9217391304347826
Average Precision:  0.9632517159695355
Precision:  0.8833333333333333
Recall:  0.9636363636363636
AUC-ROC:  0.9663807890222985
Accuracy:  0.9166666666666666
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.92
Average Precision:  0.9737303958565524
Precision:  0.8518518518518519
Recall:  1.0
AUC-ROC:  0.9756521739130435
Accuracy:  0.9166666666666666
--------------------------------------------------
--------------------------------------------------
Test Size = 0.4
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.7, 0.85]
F1 Score:  0.9259259259259259
Average Precision:  0.9700279948440482
Precision:  0.8928571428571429
Recall:  0.9615384615384616
AUC-ROC:  0.9557692307692307
Accuracy:  0.9130434782608695
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.9]
F1 Score:  0.9210526315789475
Average Precision:  0.9680427795563491
Precision:  0.8974358974358975
Recall:  0.9459459459459459
AUC-ROC:  0.9704633204633205
Accuracy:  0.9166666666666666
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.911764705882353
Average Precision:  0.9779379809481578
Precision:  0.8611111111111112
Recall:  0.96875
AUC-ROC:  0.978515625
Accuracy:  0.90625
--------------------------------------------------
--------------------------------------------------
Test Size = 0.5
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.9393939393939394
Average Precision:  0.9802181661465414
Precision:  0.9117647058823529
Recall:  0.96875
AUC-ROC:  0.97125
Accuracy:  0.9298245614035088
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9312169312169313
Average Precision:  0.9744772147061844
Precision:  0.9263157894736842
Recall:  0.9361702127659575
AUC-ROC:  0.9772389905987136
Accuracy:  0.9277777777777778
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.9302325581395349
Average Precision:  0.9873202610475491
Precision:  0.8888888888888888
Recall:  0.975609756097561
AUC-ROC:  0.9865211810012837
Accuracy:  0.9113924050632911
--------------------------------------------------
--------------------------------------------------
Test Size = 0.6
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9382716049382716
Average Precision:  0.9784982621928437
Precision:  0.95
Recall:  0.926829268292683
AUC-ROC:  0.9668989547038327
Accuracy:  0.927536231884058
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.9]
F1 Score:  0.9417040358744394
Average Precision:  0.9733774867261066
Precision:  0.9130434782608695
Recall:  0.9722222222222222
AUC-ROC:  0.9775377229080933
Accuracy:  0.9398148148148148
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.85]
F1 Score:  0.9108910891089108
Average Precision:  0.9539935369799929
Precision:  0.8518518518518519
Recall:  0.9787234042553191
AUC-ROC:  0.9645390070921986
Accuracy:  0.9157894736842105
--------------------------------------------------
--------------------------------------------------
Test Size = 0.7
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.942528735632184
Average Precision:  0.9773614039324844
Precision:  0.9534883720930233
Recall:  0.9318181818181818
AUC-ROC:  0.9703282828282829
Accuracy:  0.9375
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.9]
F1 Score:  0.929889298892989
Average Precision:  0.9692656469183681
Precision:  0.8936170212765957
Recall:  0.9692307692307692
AUC-ROC:  0.9725094577553594
Accuracy:  0.9246031746031746
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.9322033898305085
Average Precision:  0.9622238662023751
Precision:  0.8870967741935484
Recall:  0.9821428571428571
AUC-ROC:  0.9698051948051949
Accuracy:  0.9099099099099099
--------------------------------------------------
--------------------------------------------------
Test Size = 0.8
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.9]
F1 Score:  0.9130434782608695
Average Precision:  0.9706970925672808
Precision:  0.9545454545454546
Recall:  0.875
AUC-ROC:  0.9673295454545454
Accuracy:  0.9130434782608695
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.8, 0.85]
F1 Score:  0.9041095890410958
Average Precision:  0.9729194187609629
Precision:  0.9166666666666666
Recall:  0.8918918918918919
AUC-ROC:  0.9752413127413128
Accuracy:  0.9027777777777778
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.85]
F1 Score:  0.920863309352518
Average Precision:  0.9515728822430225
Precision:  0.8648648648648649
Recall:  0.9846153846153847
AUC-ROC:  0.960545905707196
Accuracy:  0.905511811023622
--------------------------------------------------
--------------------------------------------------
Test Size = 0.9
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.6, 0.8, 0.85]
F1 Score:  0.8807339449541284
Average Precision:  0.9361678082565664
Precision:  0.8727272727272727
Recall:  0.8888888888888888
AUC-ROC:  0.9357520786092214
Accuracy:  0.8737864077669902
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9442815249266863
Average Precision:  0.9724530731551878
Precision:  0.8994413407821229
Recall:  0.9938271604938271
AUC-ROC:  0.9762231367169638
Accuracy:  0.941358024691358
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.8]
F1 Score:  0.896969696969697
Average Precision:  0.9590835793044377
Precision:  0.8131868131868132
Recall:  1.0
AUC-ROC:  0.9643556600078339
Accuracy:  0.8951048951048951
--------------------------------------------------


In [None]:
# Entry point of the program: repeat the full test-size sweep with seed 777
if __name__ == "__main__":
    main(seed=777)


--------------------------------------------------
Test Size = 0.1
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  1.0
Average Precision:  1.0
Precision:  1.0
Recall:  1.0
AUC-ROC:  1.0
Accuracy:  1.0
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.9473684210526316
Average Precision:  0.9649750055967807
Precision:  0.9
Recall:  1.0
AUC-ROC:  0.9660493827160493
Accuracy:  0.9444444444444444
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.875]
F1 Score:  0.888888888888889
Average Precision:  0.959375
Precision:  0.8
Recall:  1.0
AUC-ROC:  0.953125
Accuracy:  0.875
--------------------------------------------------
--------------------------------------------------
Test Size = 0.2
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9473684210526316
Average Precision:  0.9909090909090909
Precision:  1.0
Recall:  0.9
AUC-ROC:  0.9923076923076923
Accuracy:  0.9565217391304348
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.9350649350649352
Average Precision:  0.9792701035342462
Precision:  0.8780487804878049
Recall:  1.0
AUC-ROC:  0.9783950617283951
Accuracy:  0.9305555555555556
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.9032258064516129
Average Precision:  0.9822929171668668
Precision:  0.8235294117647058
Recall:  1.0
AUC-ROC:  0.9841269841269842
Accuracy:  0.9375
--------------------------------------------------
--------------------------------------------------
Test Size = 0.3
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.9375
Average Precision:  0.9930555555555556
Precision:  0.9375
Recall:  0.9375
AUC-ROC:  0.993421052631579
Accuracy:  0.9428571428571428
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.9391304347826087
Average Precision:  0.9845382978535362
Precision:  0.9
Recall:  0.9818181818181818
AUC-ROC:  0.9838765008576329
Accuracy:  0.9351851851851852
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9019607843137256
Average Precision:  0.9756818527364504
Precision:  0.8518518518518519
Recall:  0.9583333333333334
AUC-ROC:  0.9739583333333333
Accuracy:  0.875
--------------------------------------------------
--------------------------------------------------
Test Size = 0.4
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9387755102040817
Average Precision:  0.9820710606665983
Precision:  0.92
Recall:  0.9583333333333334
AUC-ROC:  0.9810606060606061
Accuracy:  0.9347826086956522
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.9240506329113923
Average Precision:  0.9712454532245645
Precision:  0.8902439024390244
Recall:  0.9605263157894737
AUC-ROC:  0.9725232198142415
Accuracy:  0.9166666666666666
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.8, 0.9]
F1 Score:  0.8750000000000001
Average Precision:  0.95210637642702
Precision:  0.8
Recall:  0.9655172413793104
AUC-ROC:  0.9625615763546799
Accuracy:  0.875
--------------------------------------------------
--------------------------------------------------
Test Size = 0.5
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.9]
F1 Score:  0.9354838709677419
Average Precision:  0.9870662244198884
Precision:  0.9354838709677419
Recall:  0.9354838709677419
AUC-ROC:  0.9851116625310173
Accuracy:  0.9298245614035088
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9285714285714286
Average Precision:  0.9682365147458974
Precision:  0.8921568627450981
Recall:  0.9680851063829787
AUC-ROC:  0.9705591291439881
Accuracy:  0.9222222222222223
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.891566265060241
Average Precision:  0.9663200335324095
Precision:  0.8222222222222222
Recall:  0.9736842105263158
AUC-ROC:  0.9717586649550707
Accuracy:  0.8987341772151899
--------------------------------------------------
--------------------------------------------------
Test Size = 0.6
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.9444444444444444
Average Precision:  0.9838622567915566
Precision:  0.9444444444444444
Recall:  0.9444444444444444
AUC-ROC:  0.9831649831649831
Accuracy:  0.9420289855072463
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.85]
F1 Score:  0.927659574468085
Average Precision:  0.9717234344577617
Precision:  0.9008264462809917
Recall:  0.956140350877193
AUC-ROC:  0.9720502235982111
Accuracy:  0.9212962962962963
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.8]
F1 Score:  0.8712871287128713
Average Precision:  0.9697026300359465
Precision:  0.7719298245614035
Recall:  1.0
AUC-ROC:  0.9754901960784313
Accuracy:  0.8947368421052632
--------------------------------------------------
--------------------------------------------------
Test Size = 0.7
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.9523809523809523
Average Precision:  0.9868923387228131
Precision:  0.9523809523809523
Recall:  0.9523809523809523
AUC-ROC:  0.9862155388471179
Accuracy:  0.95
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.9]
F1 Score:  0.9259259259259259
Average Precision:  0.968189737958748
Precision:  0.9057971014492754
Recall:  0.946969696969697
AUC-ROC:  0.9706439393939394
Accuracy:  0.9206349206349206
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.9]
F1 Score:  0.8793103448275861
Average Precision:  0.9522354060806091
Precision:  0.796875
Recall:  0.9807692307692307
AUC-ROC:  0.9661016949152542
Accuracy:  0.9009009009009009
--------------------------------------------------
--------------------------------------------------
Test Size = 0.8
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.6, 0.75]
F1 Score:  0.9
Average Precision:  0.9785678383631022
Precision:  0.8333333333333334
Recall:  0.9782608695652174
AUC-ROC:  0.9792060491493384
Accuracy:  0.8913043478260869
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9333333333333333
Average Precision:  0.9624740670803416
Precision:  0.9090909090909091
Recall:  0.958904109589041
AUC-ROC:  0.9693710206444145
Accuracy:  0.9305555555555556
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.9]
F1 Score:  0.8970588235294118
Average Precision:  0.9463977959440586
Precision:  0.8243243243243243
Recall:  0.9838709677419355
AUC-ROC:  0.9593052109181142
Accuracy:  0.8976377952755905
--------------------------------------------------
--------------------------------------------------
Test Size = 0.9
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.8, 0.65]
F1 Score:  0.864406779661017
Average Precision:  0.9837468438773314
Precision:  0.7611940298507462
Recall:  1.0
AUC-ROC:  0.9849170437405731
Accuracy:  0.8446601941747572
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.9340974212034383
Average Precision:  0.965568426978709
Precision:  0.8810810810810811
Recall:  0.9939024390243902
AUC-ROC:  0.971532012195122
Accuracy:  0.9290123456790124
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.85]
F1 Score:  0.9041095890410958
Average Precision:  0.9258122862513063
Precision:  0.8918918918918919
Recall:  0.9166666666666666
AUC-ROC:  0.9491392801251956
Accuracy:  0.8951048951048951
--------------------------------------------------


In [None]:
# Entry point of the program: repeat the full test-size sweep with seed 108
if __name__ == "__main__":
    main(seed=108)


--------------------------------------------------
Test Size = 0.1
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  1.0
Average Precision:  1.0
Precision:  1.0
Recall:  1.0
AUC-ROC:  1.0
Accuracy:  1.0
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  1.0
Average Precision:  1.0
Precision:  1.0
Recall:  1.0
AUC-ROC:  1.0
Accuracy:  1.0
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.75]
F1 Score:  0.8333333333333333
Average Precision:  0.82985236985237
Precision:  0.7142857142857143
Recall:  1.0
AUC-ROC:  0.7666666666666666
Accuracy:  0.8125
--------------------------------------------------
--------------------------------------------------
Test Size = 0.2
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  1.0
Average Precision:  0.9999999999999998
Precision:  1.0
Recall:  1.0
AUC-ROC:  1.0
Accuracy:  1.0
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.9538461538461539
Average Precision:  0.9926891471524397
Precision:  0.9117647058823529
Recall:  1.0
AUC-ROC:  0.9944925255704169
Accuracy:  0.9583333333333334
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [0.6, 0.8, 0.85]
F1 Score:  0.8205128205128205
Average Precision:  0.8383752710852028
Precision:  0.7619047619047619
Recall:  0.8888888888888888
AUC-ROC:  0.8333333333333333
Accuracy:  0.78125
--------------------------------------------------
--------------------------------------------------
Test Size = 0.3
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.9]
F1 Score:  0.9500000000000001
Average Precision:  0.9973684210526315
Precision:  0.9047619047619048
Recall:  1.0
AUC-ROC:  0.9967105263157895
Accuracy:  0.9428571428571428
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.9444444444444444
Average Precision:  0.9628092920474262
Precision:  0.8947368421052632
Recall:  1.0
AUC-ROC:  0.9735122119023047
Accuracy:  0.9444444444444444
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.8, 0.75]
F1 Score:  0.8214285714285715
Average Precision:  0.8560014447598745
Precision:  0.7931034482758621
Recall:  0.8518518518518519
AUC-ROC:  0.8589065255731922
Accuracy:  0.8125
--------------------------------------------------
--------------------------------------------------
Test Size = 0.4
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.9]
F1 Score:  0.9454545454545454
Average Precision:  0.9960739794711695
Precision:  0.9285714285714286
Recall:  0.9629629629629629
AUC-ROC:  0.9941520467836258
Accuracy:  0.9347826086956522
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.9559748427672956
Average Precision:  0.9788784776713971
Precision:  0.926829268292683
Recall:  0.987012987012987
AUC-ROC:  0.9800348904826517
Accuracy:  0.9513888888888888
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.8641975308641975
Average Precision:  0.922366004131572
Precision:  0.8333333333333334
Recall:  0.8974358974358975
AUC-ROC:  0.8984615384615384
Accuracy:  0.828125
--------------------------------------------------
--------------------------------------------------
Test Size = 0.5
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9538461538461539
Average Precision:  0.9951375711574952
Precision:  0.9117647058823529
Recall:  1.0
AUC-ROC:  0.9937965260545906
Accuracy:  0.9473684210526315
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9538461538461539
Average Precision:  0.9685261401663114
Precision:  0.9117647058823529
Recall:  1.0
AUC-ROC:  0.9724385119268323
Accuracy:  0.95
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.7, 0.85]
F1 Score:  0.8409090909090908
Average Precision:  0.8741540768310448
Precision:  0.8604651162790697
Recall:  0.8222222222222222
AUC-ROC:  0.8908496732026143
Accuracy:  0.8227848101265823
--------------------------------------------------
--------------------------------------------------
Test Size = 0.6
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.9]
F1 Score:  0.9473684210526316
Average Precision:  0.9954933003178618
Precision:  0.9
Recall:  1.0
AUC-ROC:  0.994949494949495
Accuracy:  0.9420289855072463
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.95]
F1 Score:  0.9464285714285715
Average Precision:  0.9747228787460889
Precision:  0.905982905982906
Recall:  0.9906542056074766
AUC-ROC:  0.9789076566921031
Accuracy:  0.9444444444444444
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.86
Average Precision:  0.8898850340645515
Precision:  0.8775510204081632
Recall:  0.8431372549019608
AUC-ROC:  0.9180035650623886
Accuracy:  0.8526315789473684
--------------------------------------------------
--------------------------------------------------
Test Size = 0.7
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.9545454545454545
Average Precision:  0.9933401380320308
Precision:  0.9130434782608695
Recall:  1.0
AUC-ROC:  0.9924812030075187
Accuracy:  0.95
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.85]
F1 Score:  0.9433962264150945
Average Precision:  0.9746648138302109
Precision:  0.8928571428571429
Recall:  1.0
AUC-ROC:  0.9789606299212599
Accuracy:  0.9404761904761905
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 1.0]
F1 Score:  0.8852459016393444
Average Precision:  0.9125178825624259
Precision:  0.9
Recall:  0.8709677419354839
AUC-ROC:  0.9302172481895985
Accuracy:  0.8738738738738738
--------------------------------------------------
--------------------------------------------------
Test Size = 0.8
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9574468085106385
Average Precision:  0.9899492046121621
Precision:  0.9574468085106383
Recall:  0.9574468085106383
AUC-ROC:  0.988179669030733
Accuracy:  0.9565217391304348
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.943894389438944
Average Precision:  0.971899196251093
Precision:  0.9050632911392406
Recall:  0.9862068965517241
AUC-ROC:  0.9757897275138654
Accuracy:  0.9409722222222222
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 1.0, 0.95]
F1 Score:  0.8920863309352517
Average Precision:  0.9284009285924004
Precision:  0.9117647058823529
Recall:  0.8732394366197183
AUC-ROC:  0.9391348088531187
Accuracy:  0.8740157480314961
--------------------------------------------------
--------------------------------------------------
Test Size = 0.9
-------------------COST TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/37 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.9, 0.8]
F1 Score:  0.9272727272727272
Average Precision:  0.9824609616273197
Precision:  0.8793103448275862
Recall:  0.9807692307692307
AUC-ROC:  0.9811463046757165
Accuracy:  0.9223300970873787
-------------------BT TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/106 [00:00<?, ?it/s]

Accuracy@K:  [0.8, 0.8, 0.9]
F1 Score:  0.935672514619883
Average Precision:  0.9695672989497389
Precision:  0.8791208791208791
Recall:  1.0
AUC-ROC:  0.9746189024390244
Accuracy:  0.9320987654320988
-------------------CORONET TOPOLOGY-------------------


Computing transition probabilities:   0%|          | 0/60 [00:00<?, ?it/s]

Accuracy@K:  [1.0, 0.9, 0.95]
F1 Score:  0.9090909090909091
Average Precision:  0.9385042115514015
Precision:  0.8860759493670886
Recall:  0.9333333333333333
AUC-ROC:  0.9488235294117647
Accuracy:  0.8881118881118881
--------------------------------------------------
