In [5]:
import random
from collections import Counter

import matplotlib.pyplot as plt
import networkx as nx
from prettytable import PrettyTable
from tabulate import tabulate

In [8]:
class GraphTools:
    """Reporting and plotting helpers for the graph held by a graph-construction object.

    The wrapped object (``graphConstruction``) is expected to expose:

    * ``G`` -- a networkx graph (undirected, since ``nx.connected_components``
      is used on it) whose nodes carry a ``'col'`` attribute and whose edges
      may carry an ``'idEdge'`` attribute. ``'col'`` serves both as the node
      type key and as the matplotlib colour the node is drawn with.
    * ``mapping_color_components`` -- a dict mapping each ``'col'`` value to a
      human-readable component name.
    """

    # Fallback colour -> component-name mapping for the plot legend, used only
    # when the graph-construction object has no ``mapping_color_components``.
    DEFAULT_COLOR_COMPONENTS = {
        "red": "Manhole", "springgreen": "Structure", "yellow": "Pump",
        "cyan": "Fitting", "black": "Treatment Plant", "orange": "Accessory",
        "violet": "Dummy", "blue": "Device", "bisque": "Spillway",
    }

    def __init__(self, graphConstruction):
        super().__init__()
        self.gc = graphConstruction
        # Node type ('col' value) -> node count; filled by describe_graph().
        self.count_components = {}
        # Connected components (node sets) from the most recent
        # get_sub_graphs() / plot_sub_graph_by_id() call.
        self.componentsGraphs = []

    def describe_graph(self):
        """Print the edge and node totals and a table of node counts per type.

        The per-colour counts are also stored in ``self.count_components``.
        A colour missing from ``gc.mapping_color_components`` is shown by its
        raw colour name instead of aborting the report with ``KeyError``.
        """
        self.count_components = self._count_node_types()

        print("The total number of edges is :", self.gc.G.number_of_edges())
        print("The total number of nodes is :", self.gc.G.number_of_nodes())

        labels = self.gc.mapping_color_components
        table = PrettyTable()
        table.field_names = ["Node Type", "Count"]
        for node_type, count in self.count_components.items():
            table.add_row([labels.get(node_type, node_type), count])

        print(table)

    def _count_node_types(self):
        """Return ``{'col' value: node count}`` over all nodes, in first-seen order."""
        return dict(Counter(attrs['col'] for _, attrs in self.gc.G.nodes(data=True)))

    def get_sub_graphs(self):
        """Print a grid table with the node and edge count of each connected component.

        Components are numbered from 1 in the order ``nx.connected_components``
        yields them. ``plot_sub_graph_by_id`` uses the same numbering.
        """
        # Materialise the components so the stored attribute stays usable
        # after this method returns (a generator would already be exhausted).
        self.componentsGraphs = list(nx.connected_components(self.gc.G))

        table_rows = []
        for index, component in enumerate(self.componentsGraphs, start=1):
            subgraph = self.gc.G.subgraph(component)
            table_rows.append((index, subgraph.number_of_nodes(), subgraph.number_of_edges()))

        table_headers = ["Component Number", "Number of Nodes", "Number of Edges"]
        print(tabulate(table_rows, headers=table_headers, tablefmt="grid"))

    def plot_sub_graph_by_id(self, component_number, labeled=False):
        """Draw one connected component of ``gc.G`` with a colour legend.

        Args:
            component_number: 1-based component index, using the same numbering
                as the table printed by ``get_sub_graphs``.
            labeled: when True, draw node names and the ``'idEdge'`` edge labels.

        Raises:
            ValueError: if ``component_number`` does not name an existing
                component. Previously the call silently drew nothing.
        """
        self.componentsGraphs = list(nx.connected_components(self.gc.G))
        total = len(self.componentsGraphs)
        if not 1 <= component_number <= total:
            raise ValueError(
                f"component_number must be between 1 and {total}, got {component_number!r}"
            )
        subgraph = self.gc.G.subgraph(self.componentsGraphs[component_number - 1])

        # Build the colours in the graph's own node order, which is the order
        # nx.draw uses, so each node is guaranteed its own colour.
        color_list = [subgraph.nodes[node]['col'] for node in subgraph.nodes()]

        # Use a fresh figure on every call. A fixed figure number would draw on
        # top of a still-open earlier plot and ignore figsize.
        plt.figure(figsize=(30, 30))
        pos = nx.spring_layout(subgraph)

        # With labels=None, networkx labels every node with its own name, which
        # matches the explicit {node: node} mapping of the original code.
        nx.draw(
            subgraph,
            pos=pos,
            node_size=100,
            with_labels=labeled,
            node_color=color_list,
        )

        if labeled:
            # Only label edges that carry an 'idEdge'. Unlabeled plots no
            # longer require every edge to have one.
            edge_labels = {
                (u, v): str(data['idEdge'])
                for u, v, data in subgraph.edges(data=True)
                if data.get('idEdge') is not None
            }
            nx.draw_networkx_edge_labels(subgraph, pos=pos, edge_labels=edge_labels)

        # Prefer the graph's own colour mapping (the one describe_graph uses) so
        # the legend always matches the drawn colours.
        color_components = getattr(
            self.gc, "mapping_color_components", self.DEFAULT_COLOR_COMPONENTS
        )
        handles = [
            plt.Line2D([0], [0], marker='o', color=color, label=label)
            for color, label in color_components.items()
        ]
        legend = plt.legend(handles=handles, title="Components", fontsize=40)
        legend.get_title().set_fontsize('xx-large')

        plt.show()

    def get_repeated_edges_by_id(self):
        """Print the ``'idEdge'`` values that occur on more than one edge.

        Edges with no ``'idEdge'`` attribute, or with ``None``, are ignored.
        Repeated IDs are listed in the order they are first encountered.
        """
        edge_ids = (data.get('idEdge') for _, _, data in self.gc.G.edges(data=True))
        id_counts = Counter(id_edge for id_edge in edge_ids if id_edge is not None)

        repeated_ids = [id_edge for id_edge, count in id_counts.items() if count > 1]

        print("Repeated IDs:", repeated_ids)
