# In [None]:
import numpy as np
import networkx as nx
from tqdm import tqdm
from collections import defaultdict
import os
from typing import Dict, List, Tuple, Set
from multiprocessing import Pool, cpu_count
import time
import matplotlib.pyplot as plt
from itertools import combinations

# Signaling message sizes in bytes; NodeStats multiplies its counters by these.
BASE_MESSAGE_SIZE: int = 40  # header, charged once per message
ISL_DELETE_SIZE: int = 53  # cost of one ISL deletion operation
ISL_ADD_SIZE: int = 125  # cost of one ISL addition operation
ROUTE_ENTRY_SIZE: int = 97  # cost of one route-table entry insert or delete

class NodeStats:
    """Counters of the signaling operations one node needs for one transition.

    All counters start at zero and are incremented by the per-node processing
    functions. The ``get_total_*`` methods price the counters in bytes using
    the module-level message-size constants.
    """

    def __init__(self):
        # Inter-satellite-link operations
        self.deleted_isls = 0
        self.added_isls = 0
        # Route-table operations, one entry per destination
        self.deleted_routes = 0
        self.added_routes = 0
        # Route-table operations when entries are aggregated by next hop
        self.deleted_aggregated_routes = 0
        self.added_aggregated_routes = 0

    def _message_bytes(self, route_ops):
        """Header + ISL operations + ``route_ops`` route-entry operations."""
        isl_bytes = (ISL_DELETE_SIZE * self.deleted_isls
                     + ISL_ADD_SIZE * self.added_isls)
        return BASE_MESSAGE_SIZE + isl_bytes + ROUTE_ENTRY_SIZE * route_ops

    def get_total_bytes(self):
        """Message size in bytes with one route entry per destination."""
        return self._message_bytes(self.deleted_routes + self.added_routes)

    def get_total_aggregated_bytes(self):
        """Message size in bytes with route entries aggregated by next hop."""
        return self._message_bytes(
            self.deleted_aggregated_routes + self.added_aggregated_routes)

def build_topology_graph(inter_topology, intra_topology):
    """Return an undirected ``nx.Graph`` holding every inter- and intra-domain link.

    ``inter_topology`` is iterated as ``((node1, node2), extra)`` items, and the
    second element is ignored. ``intra_topology`` is a mapping whose values are
    iterables of ``(node1, node2)`` pairs. Inter-domain edges are inserted
    first, then intra-domain edges in mapping order. Nodes only exist through
    their edges.
    """
    graph = nx.Graph()
    graph.add_edges_from((a, b) for (a, b), _ in inter_topology)
    for links in intra_topology.values():
        graph.add_edges_from((a, b) for a, b in links)
    return graph

def get_topology_satellites(inter_topology, intra_topology):
    """Return the set of nodes that are an endpoint of at least one link.

    Takes the same two inputs as ``build_topology_graph``: ``((a, b), extra)``
    items for inter-domain links and a mapping of ``(a, b)`` pair lists for
    intra-domain links.
    """
    endpoints = {node for (a, b), _ in inter_topology for node in (a, b)}
    for links in intra_topology.values():
        endpoints.update(node for a, b in links for node in (a, b))
    return endpoints

def get_first_hop(G, source, target):
    """Return the node after ``source`` on a shortest ``source`` -> ``target`` path.

    ``nx.shortest_path`` is called without a weight, so path length is the hop
    count. Ties between equal-length paths are broken by networkx's search
    order, which depends on the graph's edge insertion order.

    Returns None when ``source == target``, when no path exists, or when either
    endpoint is not a node of ``G``. A missing node is treated as unreachable
    rather than propagating ``nx.NodeNotFound``.
    """
    try:
        path = nx.shortest_path(G, source, target)
    except (nx.NetworkXNoPath, nx.NodeNotFound):
        return None
    return path[1] if len(path) > 1 else None

def process_single_node(args):
    """Process overhead calculation for a single node (for parallel processing).

    Computes the TS-SDN signaling needed by ``node`` for one timestamp
    transition. This is a ``Pool.map`` target, so all inputs arrive as one tuple::

        (node, current_topology, previous_topology, current_graph,
         previous_graph, previous_first_hops, previous_aggregated_routes,
         current_satellites, previous_satellites, num_satellites)

    ``previous_first_hops`` maps ``(src, dst) -> next_hop`` and
    ``previous_aggregated_routes`` maps ``src -> {next_hop: set of dsts}`` for
    the previous timestamp. ``previous_graph`` and ``num_satellites`` are
    accepted for interface compatibility but are not used.

    Returns:
        ``(node, stats, current_first_hops, current_aggregated_routes)``.
        ``current_first_hops`` maps ``(node, dst) -> next_hop`` and
        ``current_aggregated_routes`` maps ``next_hop -> set of dsts``. Both
        are empty dicts when ``node`` is absent from the current timestamp.
    """
    (node, current_topology, previous_topology, current_graph, _previous_graph,
     previous_first_hops, previous_aggregated_routes, current_satellites,
     previous_satellites, _num_satellites) = args

    stats = NodeStats()

    in_current = node in current_satellites
    in_previous = node in previous_satellites
    if not in_current and not in_previous:
        return node, stats, {}, {}

    # ISL changes. Links are compared as the ordered (n1, n2) tuples stored in
    # the topology.
    current_links = {(n1, n2) for (n1, n2), _ in current_topology
                     if n1 == node or n2 == node}
    previous_links = {(n1, n2) for (n1, n2), _ in previous_topology
                      if n1 == node or n2 == node}
    stats.deleted_isls = len(previous_links - current_links)
    stats.added_isls = len(current_links - previous_links)

    # Node left the topology: every route it previously held is deleted.
    if in_previous and not in_current:
        stats.deleted_routes = sum(1 for pair in previous_first_hops if pair[0] == node)
        stats.deleted_aggregated_routes = len(previous_aggregated_routes.get(node, {}))
        return node, stats, {}, {}

    # Routing table for the current timestamp.
    current_first_hops = {}
    current_aggregated_routes = defaultdict(set)  # next_hop -> {destinations}
    for dst in current_satellites:
        if dst == node:
            continue
        hop = get_first_hop(current_graph, node, dst)
        if hop is not None:
            current_first_hops[(node, dst)] = hop
            current_aggregated_routes[hop].add(dst)

    # Newly appeared node (in_current is True here): the whole table is new.
    if not in_previous:
        stats.added_routes = len(current_first_hops)
        stats.added_aggregated_routes = len(current_aggregated_routes)
        return node, stats, current_first_hops, current_aggregated_routes

    # Per-destination route changes. Destinations that left the topology are
    # included so their stale routes are counted as deletions.
    for dst in set(current_satellites).union(previous_satellites):
        if dst == node:
            continue
        prev_hop = previous_first_hops.get((node, dst))
        curr_hop = current_first_hops.get((node, dst))
        if prev_hop == curr_hop:
            continue
        # Compare against None explicitly: node ID 0 is a valid next hop but
        # is falsy.
        if prev_hop is not None:
            stats.deleted_routes += 1  # delete the old (or vanished) route
        if curr_hop is not None:
            stats.added_routes += 1  # install the new route

    # Aggregated (next-hop -> destination set) route changes.
    prev_agg_routes = previous_aggregated_routes.get(node, {})
    for next_hop, prev_dsts in prev_agg_routes.items():
        curr_dsts = current_aggregated_routes.get(next_hop, set())
        if not curr_dsts:
            # This next hop is no longer used: drop its entry.
            stats.deleted_aggregated_routes += 1
        elif prev_dsts != curr_dsts:
            # Same next hop, different destination set: replace the entry.
            stats.deleted_aggregated_routes += 1
            stats.added_aggregated_routes += 1
    # Next hops that had no entry before get a fresh one.
    stats.added_aggregated_routes += sum(
        1 for next_hop in current_aggregated_routes if next_hop not in prev_agg_routes)

    return node, stats, current_first_hops, current_aggregated_routes

def process_single_node_tiny_leo(args):
    """Count one node's TinyLEO ISL changes between two timestamps (Pool.map target).

    ``args`` is ``(node, current_topology, previous_topology)``, and each
    topology is iterated as ``((n1, n2), extra)`` items. Only ISL additions and
    deletions are counted; no route counters are touched.

    Returns:
        ``(node, stats)``.
    """
    node, current_topology, previous_topology = args
    stats = NodeStats()

    def incident_links(topology):
        # Links touching ``node``, kept as the ordered (n1, n2) tuples stored
        # in the topology.
        return {(a, b) for (a, b), _ in topology if a == node or b == node}

    now = incident_links(current_topology)
    before = incident_links(previous_topology)

    # A node missing from a timestamp has no incident links there. So these set
    # differences also cover the appear / disappear cases, and give zero when
    # the node is in neither timestamp.
    stats.deleted_isls = len(before - now)
    stats.added_isls = len(now - before)
    return node, stats

def _ts_sdn_has_changes(stats):
    """True when a TS-SDN node has at least one ISL or route operation to signal."""
    return any(count > 0 for count in (
        stats.deleted_isls, stats.added_isls,
        stats.deleted_routes, stats.added_routes,
        stats.deleted_aggregated_routes, stats.added_aggregated_routes,
    ))


def _initial_routes_for_source(args):
    """Pool worker: baseline first hops and next-hop aggregation for one source.

    ``args`` is ``(src, graph, satellites)``. Returns ``(src, first_hops,
    aggregated_routes)``. ``first_hops[(src, dst)]`` is the next hop for every
    reachable ``dst`` in ``satellites``, and ``aggregated_routes[next_hop]`` is
    the set of destinations using that hop.
    """
    src, graph, satellites = args
    first_hops = {}
    aggregated_routes = defaultdict(set)
    for dst in satellites:
        if dst == src:
            continue
        hop = get_first_hop(graph, src, dst)
        if hop is not None:
            first_hops[(src, dst)] = hop
            aggregated_routes[hop].add(dst)
    return src, first_hops, aggregated_routes


def calculate_all_timestamps_overhead(all_inter_topology, all_intra_topology, num_satellites):
    """Calculate TS-SDN and TinyLEO signaling overhead for every timestamp transition.

    Timestamps are the sorted keys of ``all_inter_topology``.
    ``all_intra_topology`` must contain the same keys. Routing state at the
    first timestamp is the baseline, so results are keyed by ``timestamps[1:]``
    only. Nodes ``0 .. num_satellites - 1`` are evaluated at each transition.

    Returns:
        ``(ts_sdn_overhead, tiny_leo_overhead)``.
        ``ts_sdn_overhead[ts]`` has ``'messages'``, ``'bytes'``,
        ``'aggregated_bytes'`` and ``'detailed_stats'`` dicts keyed by node.
        ``tiny_leo_overhead[ts]`` has ``'messages'`` and ``'bytes'``. Only
        nodes with at least one change appear. Both are empty when fewer than
        two timestamps are given.
    """
    ts_sdn_overhead = {}
    tiny_leo_overhead = {}

    timestamps = sorted(all_inter_topology.keys())
    if len(timestamps) < 2:
        # No transition to evaluate (and timestamps[0] may not even exist).
        return ts_sdn_overhead, tiny_leo_overhead

    num_processes = max(1, cpu_count() - 1)  # Reserve one core for the system

    # The previous timestamp's graph and satellite set are carried forward
    # each step instead of being rebuilt from the same inputs.
    prev_graph = build_topology_graph(all_inter_topology[timestamps[0]],
                                      all_intra_topology[timestamps[0]])
    prev_satellites = get_topology_satellites(all_inter_topology[timestamps[0]],
                                              all_intra_topology[timestamps[0]])
    prev_first_hops = {}
    prev_aggregated_routes = {}

    # One worker pool for the whole run rather than one per timestamp.
    with Pool(processes=num_processes) as pool:
        # Baseline routing tables at the first timestamp, one source per task.
        init_args = [(src, prev_graph, prev_satellites) for src in prev_satellites]
        for src, first_hops, aggregated_routes in pool.map(_initial_routes_for_source, init_args):
            prev_first_hops.update(first_hops)
            if aggregated_routes:
                prev_aggregated_routes[src] = aggregated_routes

        print(f"Processing {len(timestamps)} timestamps using {num_processes} processes...")

        for t in tqdm(range(1, len(timestamps))):
            ts = timestamps[t]
            current_topology = all_inter_topology[ts]
            previous_topology = all_inter_topology[timestamps[t - 1]]

            current_graph = build_topology_graph(current_topology, all_intra_topology[ts])
            current_satellites = get_topology_satellites(current_topology,
                                                         all_intra_topology[ts])

            ts_sdn_args = [(node, current_topology, previous_topology, current_graph,
                            prev_graph, prev_first_hops, prev_aggregated_routes,
                            current_satellites, prev_satellites, num_satellites)
                           for node in range(num_satellites)]
            tiny_leo_args = [(node, current_topology, previous_topology)
                             for node in range(num_satellites)]

            ts_results = pool.map(process_single_node, ts_sdn_args)
            tiny_results = pool.map(process_single_node_tiny_leo, tiny_leo_args)

            # Organize TS-SDN results
            ts_messages = {}
            ts_bytes = {}
            ts_aggregated_bytes = {}
            detailed_stats = {}
            new_first_hops = {}
            new_aggregated_routes = {}
            for node, stats, first_hops, aggregated_routes in ts_results:
                if _ts_sdn_has_changes(stats):
                    ts_messages[node] = 1
                    ts_bytes[node] = stats.get_total_bytes()
                    ts_aggregated_bytes[node] = stats.get_total_aggregated_bytes()
                    detailed_stats[node] = {
                        'deleted_isls': stats.deleted_isls,
                        'added_isls': stats.added_isls,
                        'deleted_routes': stats.deleted_routes,
                        'added_routes': stats.added_routes,
                        'deleted_aggregated_routes': stats.deleted_aggregated_routes,
                        'added_aggregated_routes': stats.added_aggregated_routes,
                    }
                if first_hops:
                    new_first_hops.update(first_hops)
                if aggregated_routes:
                    new_aggregated_routes[node] = aggregated_routes

            # Organize TinyLEO results
            tiny_messages = {}
            tiny_bytes = {}
            for node, stats in tiny_results:
                if stats.deleted_isls > 0 or stats.added_isls > 0:
                    tiny_messages[node] = 1
                    tiny_bytes[node] = stats.get_total_bytes()

            ts_sdn_overhead[ts] = {
                'messages': ts_messages,
                'bytes': ts_bytes,
                'aggregated_bytes': ts_aggregated_bytes,
                'detailed_stats': detailed_stats,
            }
            tiny_leo_overhead[ts] = {
                'messages': tiny_messages,
                'bytes': tiny_bytes,
            }

            # The current state becomes the baseline for the next transition.
            prev_first_hops = new_first_hops
            prev_aggregated_routes = new_aggregated_routes
            prev_graph, prev_satellites = current_graph, current_satellites

    return ts_sdn_overhead, tiny_leo_overhead

def save_and_analyze_results(ts_sdn_overhead, tiny_leo_overhead, output_dir):
    """Save raw overhead results and write a plain-text summary into ``output_dir``.

    Args:
        ts_sdn_overhead: ``{timestamp: {'messages', 'bytes', 'aggregated_bytes',
            'detailed_stats'}}``, each inner value a dict keyed by node.
        tiny_leo_overhead: ``{timestamp: {'messages', 'bytes'}}``.
        output_dir: Existing directory (``np.save`` does not create it). It
            receives two ``.npy`` files holding the pickled dicts, to be
            reloaded with ``np.load(..., allow_pickle=True).item()``, plus one
            statistics ``.txt`` file.

    Per-timestamp averages are written as ``nan`` when a result dict is empty
    (e.g. only one timestamp was available) instead of raising
    ZeroDivisionError.
    """
    def per_timestamp(total, count):
        # Mean over processed timestamps; undefined when there are none.
        return total / count if count else float("nan")

    def grand_total(results, field):
        # Sum one per-node metric over every node and timestamp.
        return sum(sum(entry[field].values()) for entry in results.values())

    # Save raw data
    np.save(os.path.join(output_dir, "ts_sdn_overhead_573_11_11_distancedt_global_max_new.npy"),
            ts_sdn_overhead)
    np.save(os.path.join(output_dir, "tiny_leo_overhead_573_11_11_distancedt_global_max_new.npy"),
            tiny_leo_overhead)

    # Calculate statistics
    ts_total_messages = grand_total(ts_sdn_overhead, 'messages')
    ts_total_bytes = grand_total(ts_sdn_overhead, 'bytes')
    ts_total_aggregated_bytes = grand_total(ts_sdn_overhead, 'aggregated_bytes')
    tiny_total_messages = grand_total(tiny_leo_overhead, 'messages')
    tiny_total_bytes = grand_total(tiny_leo_overhead, 'bytes')
    num_ts = len(ts_sdn_overhead)
    num_tiny = len(tiny_leo_overhead)

    # Count detailed operation quantities
    total_stats = defaultdict(int)
    for timestamp_data in ts_sdn_overhead.values():
        for node_stats in timestamp_data['detailed_stats'].values():
            for key, value in node_stats.items():
                total_stats[key] += value

    # Save statistics
    stats_path = os.path.join(output_dir, "overhead_statistics_573_11_11_distancedt_global_max.txt")
    with open(stats_path, 'w', encoding="utf-8") as f:
        f.write("Signaling Overhead Statistics\n")
        f.write("===========================\n\n")

        f.write("TS-SDN:\n")
        f.write(f"Total messages: {ts_total_messages}\n")
        f.write(f"Total bytes (non-aggregated): {ts_total_bytes}\n")
        f.write(f"Total bytes (aggregated): {ts_total_aggregated_bytes}\n")
        f.write(f"Average messages per timestamp: {per_timestamp(ts_total_messages, num_ts):.2f}\n")
        f.write(f"Average bytes per timestamp (non-aggregated): {per_timestamp(ts_total_bytes, num_ts):.2f}\n")
        f.write(f"Average bytes per timestamp (aggregated): {per_timestamp(ts_total_aggregated_bytes, num_ts):.2f}\n\n")

        f.write("Detailed Operations:\n")
        for key, value in total_stats.items():
            f.write(f"{key}: {value}\n")
        f.write("\n")

        f.write("TinyLEO:\n")
        f.write(f"Total messages: {tiny_total_messages}\n")
        f.write(f"Total bytes: {tiny_total_bytes}\n")
        f.write(f"Average messages per timestamp: {per_timestamp(tiny_total_messages, num_tiny):.2f}\n")
        f.write(f"Average bytes per timestamp: {per_timestamp(tiny_total_bytes, num_tiny):.2f}\n")

if __name__ == "__main__":
    # Configuration parameters
    input_dir = "data"
    output_dir = "data"
    os.makedirs(output_dir, exist_ok=True)

    # Load topology data. Each file stores a pickled object wrapped in a 0-d
    # object array, hence allow_pickle=True followed by .item().
    # NOTE: unpickling can execute arbitrary code; only load files produced by
    # this project's own pipeline.
    print("Loading topology data...")
    all_inter_topology = np.load(os.path.join(input_dir,
        "all_inter_topology_573_11_11_distancedt_global_max.npy"),
        allow_pickle=True).item()
    all_intra_topology = np.load(os.path.join(input_dir,
        "all_intra_topology_573_11_11_distancedt_global_max.npy"),
        allow_pickle=True).item()

    # Node IDs 0 .. num_satellites - 1 are evaluated as sources at each step.
    num_satellites = 1762

    # Calculate signaling overhead. perf_counter is monotonic, so the elapsed
    # time is unaffected by wall-clock adjustments.
    start_time = time.perf_counter()
    ts_sdn_overhead, tiny_leo_overhead = calculate_all_timestamps_overhead(
        all_inter_topology, all_intra_topology, num_satellites
    )
    end_time = time.perf_counter()

    print(f"Processing time: {end_time - start_time:.2f} seconds")

    # Save and analyze results
    save_and_analyze_results(ts_sdn_overhead, tiny_leo_overhead, output_dir)