In [1]:
pip install qiskit qiskit-aer torch torch_geometric networkx matplotlib scikit-learn

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


In [2]:
import numpy as np
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.nn import GATv2Conv, TransformerConv
from torch_geometric.utils import to_undirected, negative_sampling
import networkx as nx
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import roc_auc_score, average_precision_score, precision_recall_curve
from scipy.constants import h, c
from sklearn.model_selection import KFold
import json
from datetime import datetime
import os
import torch.nn.functional as F

def bce_with_logits_loss(pos_out, neg_out):
    """Custom BCE loss for link prediction"""
    pos_loss = F.binary_cross_entropy_with_logits(
        pos_out, torch.ones_like(pos_out)
    )
    neg_loss = F.binary_cross_entropy_with_logits(
        neg_out, torch.zeros_like(neg_out)
    )
    return pos_loss + neg_loss

def to_networkx(data):
    """Convert PyG data to NetworkX graph"""
    G = nx.Graph()
    edge_index = data.edge_index.cpu().numpy()
    for i in range(edge_index.shape[1]):
        G.add_edge(edge_index[0, i], edge_index[1, i])
    return G

class AdvancedQuantumChannelSimulator:
    def __init__(self, distance, wavelength=1550e-9, fiber_loss=0.2,
                 detector_efficiency=0.1, dark_count_rate=1e-6,
                 atmospheric_visibility=None):
        self.distance = distance
        self.wavelength = wavelength
        self.fiber_loss = fiber_loss
        self.detector_efficiency = detector_efficiency
        self.dark_count_rate = dark_count_rate
        self.atmospheric_visibility = atmospheric_visibility
        self.photon_energy = h * c / wavelength

    def calculate_channel_loss(self):
        fiber_loss_db = self.fiber_loss * self.distance
        fiber_transmission = 10 ** (-fiber_loss_db/10)

        if self.atmospheric_visibility:
            beam_divergence = 1.22 * self.wavelength / 0.1
            geometric_loss = (0.1 / (beam_divergence * self.distance)) ** 2
            atmospheric_loss = np.exp(-3.91 * self.distance / self.atmospheric_visibility)
            total_transmission = fiber_transmission * geometric_loss * atmospheric_loss
        else:
            total_transmission = fiber_transmission

        return total_transmission

    def simulate_bb84_protocol(self, num_pulses=10000, mean_photon_number=0.1):
        channel_transmission = self.calculate_channel_loss()
        received_photons = np.random.poisson(
            mean_photon_number * channel_transmission * self.detector_efficiency,
            num_pulses
        )
        dark_counts = np.random.poisson(self.dark_count_rate, num_pulses)
        total_counts = received_photons + dark_counts
        basis_matches = np.random.choice([0, 1], num_pulses, p=[0.5, 0.5])
        qber = 0.5 * (1 - np.exp(-2 * self.distance / 100))
        errors = np.random.choice([0, 1], num_pulses, p=[1-qber, qber])
        matched_pulses = total_counts * basis_matches
        raw_key_rate = np.sum(matched_pulses) / num_pulses
        final_key_rate = raw_key_rate * (1 - 2 * h2(qber))

        return {
            'qber': qber,
            'raw_key_rate': raw_key_rate,
            'final_key_rate': final_key_rate,
            'channel_loss_db': -10 * np.log10(channel_transmission),
            'dark_count_probability': np.mean(dark_counts > 0)
        }

def h2(x):
    """Binary entropy function"""
    return -x * np.log2(x) - (1-x) * np.log2(1-x) if 0 < x < 1 else 0

class AdvancedQKDNetwork:
    def __init__(self, num_nodes=50):
        self.num_nodes = num_nodes
        self.positions = self._generate_realistic_topology()

    def _generate_realistic_topology(self):
        centers = np.random.multivariate_normal(
            mean=[0, 0],
            cov=[[100, 0], [0, 100]],
            size=3
        )

        positions = []
        for _ in range(self.num_nodes):
            center = centers[np.random.randint(0, 3)]
            pos = center + np.random.multivariate_normal(
                mean=[0, 0],
                cov=[[10, 0], [0, 10]]
            )
            positions.append(pos)

        return np.array(positions)

    def generate_graph_data(self):
        distances = np.zeros((self.num_nodes, self.num_nodes))
        for i in range(self.num_nodes):
            for j in range(i + 1, self.num_nodes):
                distances[i, j] = distances[j, i] = np.linalg.norm(
                    self.positions[i] - self.positions[j]
                )

        edges = []
        edge_attrs = []

        for i in range(self.num_nodes):
            for j in range(i + 1, self.num_nodes):
                if distances[i, j] < 100:
                    simulator = AdvancedQuantumChannelSimulator(
                        distance=distances[i, j],
                        atmospheric_visibility=20000 if np.random.random() < 0.2 else None
                    )
                    results = simulator.simulate_bb84_protocol()

                    if results['final_key_rate'] > 0:
                        edges.append([i, j])
                        edge_attrs.append([
                            results['final_key_rate'],
                            results['qber'],
                            distances[i, j],
                            results['channel_loss_db'],
                            results['dark_count_probability']
                        ])

        edge_index = torch.tensor(edges).t().contiguous()
        edge_attr = torch.tensor(edge_attrs, dtype=torch.float)

        G = nx.Graph()
        G.add_edges_from(edges)

        node_features = []
        for i in range(self.num_nodes):
            features = [
                self.positions[i, 0],
                self.positions[i, 1],
                G.degree(i) if i in G else 0,
                nx.betweenness_centrality(G).get(i, 0) if i in G else 0
            ]
            node_features.append(features)

        return Data(
            x=torch.tensor(node_features, dtype=torch.float),
            edge_index=to_undirected(edge_index),
            edge_attr=edge_attr,
            pos=torch.tensor(self.positions, dtype=torch.float)
        )

class AdvancedQKDLinkPredictor(torch.nn.Module):
    def __init__(self, in_channels, edge_attr_channels, hidden_channels=64):
        super().__init__()

        self.conv1 = TransformerConv(in_channels, hidden_channels)
        self.conv2 = GATv2Conv(hidden_channels, hidden_channels)

        self.edge_mlp = torch.nn.Sequential(
            torch.nn.Linear(edge_attr_channels, hidden_channels),
            torch.nn.LayerNorm(hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(hidden_channels, hidden_channels)
        )

        self.link_predictor = torch.nn.Sequential(
            torch.nn.Linear(3 * hidden_channels, hidden_channels),
            torch.nn.LayerNorm(hidden_channels),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.2),
            torch.nn.Linear(hidden_channels, 1)
        )

    def forward(self, x, edge_index, edge_attr):
        x = self.conv1(x, edge_index)
        x = torch.relu(x)
        x = self.conv2(x, edge_index)

        # Process edge features for all edges
        edge_features = self.edge_mlp(edge_attr)

        return x, edge_features

    def decode(self, z, edge_features, edge_label_index):
        src, dst = edge_label_index

        # Handle negative sampling case
        if edge_features.size(0) != edge_label_index.size(1):
            # For negative samples, create dummy edge features
            edge_features = edge_features.mean(dim=0, keepdim=True).repeat(edge_label_index.size(1), 1)

        node_features = torch.cat([
            z[src],
            z[dst],
            edge_features
        ], dim=-1)
        return self.link_predictor(node_features).squeeze(-1)

def train_and_evaluate(model, data, num_epochs=200, k_folds=5):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)

    print(f"Using device: {device}")

    all_results = []
    kf = KFold(n_splits=k_folds, shuffle=True)

    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    unique_edges = set()
    edge_to_idx = {}

    for i in range(edge_index.shape[1]):
        edge = tuple(sorted([edge_index[0, i], edge_index[1, i]]))
        if edge not in unique_edges:
            unique_edges.add(edge)
            edge_to_idx[edge] = len(edge_to_idx)

    unique_edges = list(unique_edges)

    for fold, (train_idx, val_idx) in enumerate(kf.split(unique_edges)):
        print(f"\nFold {fold + 1}/{k_folds}")

        train_edges = [unique_edges[i] for i in train_idx]
        val_edges = [unique_edges[i] for i in val_idx]

        # Convert to numpy arrays first
        train_edge_index = np.array([[edge[0], edge[1]] for edge in train_edges]).T
        train_edge_attr = np.array([edge_attr[edge_to_idx[edge]] for edge in train_edges])

        val_edge_index = np.array([[edge[0], edge[1]] for edge in val_edges]).T
        val_edge_attr = np.array([edge_attr[edge_to_idx[edge]] for edge in val_edges])

        # Convert to tensors
        train_edge_index = torch.from_numpy(train_edge_index).to(device)
        train_edge_attr = torch.from_numpy(train_edge_attr).float().to(device)
        val_edge_index = torch.from_numpy(val_edge_index).to(device)
        val_edge_attr = torch.from_numpy(val_edge_attr).float().to(device)

        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min')

        best_val_loss = float('inf')
        early_stopping_counter = 0
        train_losses = []
        val_metrics = {'auc': [], 'ap': [], 'loss': []}

        for epoch in range(num_epochs):
            # Training
            model.train()
            optimizer.zero_grad()

            z, edge_features = model(data.x, train_edge_index, train_edge_attr)
            pos_out = model.decode(z, edge_features, train_edge_index)

            # Generate negative samples
            neg_edge_index = negative_sampling(
                train_edge_index,
                num_nodes=data.num_nodes,
                num_neg_samples=train_edge_index.size(1)
            )

            neg_out = model.decode(z, edge_features, neg_edge_index)
            loss = bce_with_logits_loss(pos_out, neg_out)

            loss.backward()
            optimizer.step()
            train_losses.append(loss.item())

            # Validation
            model.eval()
            with torch.no_grad():
                z, edge_features = model(data.x, val_edge_index, val_edge_attr)
                pos_out = model.decode(z, edge_features, val_edge_index)

                neg_edge_index = negative_sampling(
                    val_edge_index,
                    num_nodes=data.num_nodes,
                    num_neg_samples=val_edge_index.size(1)
                )

                neg_out = model.decode(z, edge_features, neg_edge_index)
                val_loss = bce_with_logits_loss(pos_out, neg_out)

                # Compute metrics
                pred = torch.cat([pos_out, neg_out]).cpu().numpy()
                true = torch.cat([
                    torch.ones(pos_out.size(0)),
                    torch.zeros(neg_out.size(0))
                ]).numpy()

                auc = roc_auc_score(true, pred)
                ap = average_precision_score(true, pred)

                val_metrics['auc'].append(auc)
                val_metrics['ap'].append(ap)
                val_metrics['loss'].append(val_loss.item())

                if (epoch + 1) % 10 == 0:
                    print(f"Epoch {epoch + 1}: Train Loss = {loss:.4f}, "
                          f"Val Loss = {val_loss:.4f}, AUC = {auc:.4f}, AP = {ap:.4f}")

                scheduler.step(val_loss)

                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    early_stopping_counter = 0
                else:
                    early_stopping_counter += 1

                if early_stopping_counter >= 20:
                    print("Early stopping triggered")
                    break

        fold_results = {
            'fold': fold + 1,
            'train_losses': train_losses,
            'val_metrics': val_metrics,
            'final_auc': auc,
            'final_ap': ap
        }
        all_results.append(fold_results)

    return all_results

def visualize_results(results, network_data, save_path='qkd_results'):
    """Create comprehensive visualizations and analysis"""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_path = f"{save_path}_{timestamp}"
    os.makedirs(save_path, exist_ok=True)

    # Find the minimum length across all result arrays
    min_epochs = min(len(result['train_losses']) for result in results)

    # Truncate all arrays to the minimum length
    truncated_results = []
    for result in results:
        truncated_result = {
            'train_losses': result['train_losses'][:min_epochs],
            'val_metrics': {
                'auc': result['val_metrics']['auc'][:min_epochs],
                'ap': result['val_metrics']['ap'][:min_epochs],
                'loss': result['val_metrics']['loss'][:min_epochs]
            },
            'final_auc': result['final_auc'],
            'final_ap': result['final_ap']
        }
        truncated_results.append(truncated_result)

    # Training Metrics
    plt.figure(figsize=(15, 10))

    # Plot training losses
    plt.subplot(2, 2, 1)
    for result in truncated_results:
        plt.plot(result['train_losses'], alpha=0.3)
    mean_train_loss = np.mean([r['train_losses'] for r in truncated_results], axis=0)
    plt.plot(mean_train_loss, 'r-', label='Mean')
    plt.title('Training Loss Evolution')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()

    # Plot validation AUC
    plt.subplot(2, 2, 2)
    for result in truncated_results:
        plt.plot(result['val_metrics']['auc'], alpha=0.3)
    mean_val_auc = np.mean([r['val_metrics']['auc'] for r in truncated_results], axis=0)
    plt.plot(mean_val_auc, 'r-', label='Mean')
    plt.title('Validation AUC Evolution')
    plt.xlabel('Epoch')
    plt.ylabel('AUC')
    plt.legend()

    # Key Rate vs Distance Analysis
    distances = network_data.edge_attr[:, 2].cpu().numpy()
    key_rates = network_data.edge_attr[:, 0].cpu().numpy()

    plt.subplot(2, 2, 3)
    plt.scatter(distances, key_rates, alpha=0.5)
    plt.xlabel('Distance (km)')
    plt.ylabel('Key Rate (bits/s)')
    plt.yscale('log')
    plt.title('Key Rate vs Distance')

    x_fit = np.linspace(min(distances), max(distances), 100)
    y_fit = np.exp(-0.2 * x_fit)
    plt.plot(x_fit, y_fit * max(key_rates), 'r--', label='Theoretical')
    plt.legend()

    # QBER Distribution
    plt.subplot(2, 2, 4)
    qber_values = network_data.edge_attr[:, 1].cpu().numpy()
    sns.histplot(qber_values, bins=20)
    plt.xlabel('QBER')
    plt.ylabel('Count')
    plt.title('QBER Distribution')

    plt.tight_layout()
    plt.savefig(f"{save_path}/training_metrics.png")
    plt.close()

    # Performance Report
    report = {
        'network_stats': {
            'num_nodes': int(network_data.num_nodes),
            'num_edges': int(len(key_rates)),
            'avg_degree': float(2 * len(key_rates) / network_data.num_nodes),
            'avg_key_rate': float(np.mean(key_rates)),
            'avg_qber': float(np.mean(qber_values)),
            'max_distance': float(np.max(distances))
        },
        'model_performance': {
            'final_metrics': {
                'auc_mean': float(np.mean([r['final_auc'] for r in results])),
                'auc_std': float(np.std([r['final_auc'] for r in results])),
                'ap_mean': float(np.mean([r['final_ap'] for r in results])),
                'ap_std': float(np.std([r['final_ap'] for r in results]))
            },
            'convergence': {
                'final_train_loss_mean': float(np.mean([r['train_losses'][-1] for r in truncated_results])),
                'best_epoch_mean': float(np.mean([np.argmin(r['val_metrics']['loss']) for r in truncated_results]))
            }
        }
    }

    with open(f"{save_path}/performance_report.json", 'w') as f:
        json.dump(report, f, indent=4)

    return report

# Main execution
if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)

    # Generate network
    print("Generating QKD network...")
    network = AdvancedQKDNetwork(num_nodes=50)
    data = network.generate_graph_data()

    # Create and train model
    print("Creating and training model...")
    model = AdvancedQKDLinkPredictor(
        in_channels=data.x.size(1),
        edge_attr_channels=data.edge_attr.size(1)
    )

    # Train and evaluate
    results = train_and_evaluate(model, data)

    # Generate visualizations and analysis
    print("Generating visualizations and analysis...")
    save_path = 'qkd_results'
    performance_report = visualize_results(results, data, save_path)

    print(f"\nAnalysis complete. Results saved in: {save_path}")

    # Print summary statistics
    print("\nSummary Statistics:")
    print(f"Number of nodes: {data.num_nodes}")
    print(f"Number of edges: {data.edge_index.size(1) // 2}")
    print(f"Average degree: {data.edge_index.size(1) / data.num_nodes:.2f}")
    print(f"Average key rate: {performance_report['network_stats']['avg_key_rate']:.2e} bits/s")
    print(f"Average QBER: {performance_report['network_stats']['avg_qber']:.3f}")
    print(f"Model AUC: {performance_report['model_performance']['final_metrics']['auc_mean']:.3f} ± "
          f"{performance_report['model_performance']['final_metrics']['auc_std']:.3f}")

Generating QKD network...
Creating and training model...
Using device: cpu

Fold 1/5
Epoch 10: Train Loss = 1.2917, Val Loss = 1.3077, AUC = 0.6956, AP = 0.6516
Epoch 20: Train Loss = 1.1165, Val Loss = 1.2284, AUC = 0.7280, AP = 0.6972
Epoch 30: Train Loss = 0.9381, Val Loss = 1.1720, AUC = 0.7566, AP = 0.7112
Epoch 40: Train Loss = 0.7759, Val Loss = 1.2591, AUC = 0.7346, AP = 0.7111
Early stopping triggered

Fold 2/5
Epoch 10: Train Loss = 0.8287, Val Loss = 1.0229, AUC = 0.8384, AP = 0.8269
Epoch 20: Train Loss = 0.7362, Val Loss = 1.2334, AUC = 0.7516, AP = 0.7067
Early stopping triggered

Fold 3/5
Epoch 10: Train Loss = 0.7097, Val Loss = 1.0497, AUC = 0.8249, AP = 0.8077
Epoch 20: Train Loss = 0.6588, Val Loss = 1.1829, AUC = 0.7889, AP = 0.7409
Early stopping triggered

Fold 4/5
Epoch 10: Train Loss = 0.6401, Val Loss = 0.9622, AUC = 0.8453, AP = 0.7903
Epoch 20: Train Loss = 0.6096, Val Loss = 1.0067, AUC = 0.8319, AP = 0.7866
Epoch 30: Train Loss = 0.6646, Val Loss = 1.0684, 

In [3]:
# Add these imports to the existing imports cell
import heapq
from scipy.optimize import differential_evolution
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import LinearRegression
import time
import warnings
warnings.filterwarnings('ignore')

class ClassicalRoutingMethods:
    """Classical routing and optimization methods for QKD networks"""
    
    def __init__(self, network_data):
        self.data = network_data
        self.G = self._build_networkx_graph()
        
    def _build_networkx_graph(self):
        """Build NetworkX graph from PyTorch Geometric data"""
        G = nx.Graph()
        edge_index = self.data.edge_index.cpu().numpy()
        edge_attr = self.data.edge_attr.cpu().numpy()
        
        # Handle the case where edges are duplicated in edge_index but not in edge_attr
        # Create a mapping from edge to attributes
        edge_to_attr = {}
        
        # Since we use to_undirected, edges are duplicated in edge_index
        # But edge_attr contains only unique edges
        num_unique_edges = edge_attr.shape[0]
        
        for i in range(num_unique_edges):
            # For each unique edge, we need to find its position in the original edge_index
            # Since to_undirected creates [src,dst] and [dst,src], we handle both
            if i < edge_index.shape[1] // 2:
                src, dst = edge_index[0, i], edge_index[1, i]
            else:
                # For the second half, edges are reversed
                original_idx = i - edge_index.shape[1] // 2
                dst, src = edge_index[0, original_idx], edge_index[1, original_idx]
            
            edge_key = tuple(sorted([src, dst]))
            if edge_key not in edge_to_attr:
                key_rate = edge_attr[i % num_unique_edges, 0]
                qber = edge_attr[i % num_unique_edges, 1]
                distance = edge_attr[i % num_unique_edges, 2]
                edge_to_attr[edge_key] = (key_rate, qber, distance)
        
        # Now build the graph with unique edges
        for edge_key, (key_rate, qber, distance) in edge_to_attr.items():
            src, dst = edge_key
            
            # Use negative key rate as weight for shortest path algorithms
            weight = 1.0 / (key_rate + 1e-10)  # Avoid division by zero
            
            G.add_edge(src, dst, 
                      weight=weight,
                      key_rate=key_rate,
                      qber=qber,
                      distance=distance)
        
        return G
    
    def dijkstra_routing(self, source_target_pairs):
        """Dijkstra's shortest path algorithm"""
        predictions = []
        
        for src, dst in source_target_pairs:
            try:
                # Find shortest path based on inverse key rate
                path = nx.shortest_path(self.G, src, dst, weight='weight')
                path_exists = len(path) > 1
                
                if path_exists:
                    # Calculate path quality metrics
                    total_key_rate = 1.0
                    max_qber = 0.0
                    total_distance = 0.0
                    
                    for i in range(len(path) - 1):
                        edge_data = self.G[path[i]][path[i+1]]
                        total_key_rate *= edge_data['key_rate']
                        max_qber = max(max_qber, edge_data['qber'])
                        total_distance += edge_data['distance']
                    
                    # Score based on end-to-end key rate
                    score = total_key_rate if max_qber < 0.11 else 0.0
                else:
                    score = 0.0
                    
            except nx.NetworkXNoPath:
                score = 0.0
            except KeyError:
                score = 0.0
                
            predictions.append(score)
            
        return np.array(predictions)
    
    def minimum_spanning_tree(self, source_target_pairs):
        """MST-based routing"""
        try:
            mst = nx.minimum_spanning_tree(self.G, weight='weight')
        except:
            # If MST fails, use original graph
            mst = self.G
        
        predictions = []
        for src, dst in source_target_pairs:
            try:
                path = nx.shortest_path(mst, src, dst)
                if len(path) > 1:
                    # Calculate path quality
                    total_key_rate = 1.0
                    for i in range(len(path) - 1):
                        if mst.has_edge(path[i], path[i+1]):
                            edge_data = mst[path[i]][path[i+1]]
                            total_key_rate *= edge_data.get('key_rate', 0.001)
                    score = total_key_rate
                else:
                    score = 0.0
            except (nx.NetworkXNoPath, KeyError):
                score = 0.0
            predictions.append(score)
            
        return np.array(predictions)
    
    def greedy_best_first(self, source_target_pairs):
        """Greedy best-first search based on key rate"""
        predictions = []
        
        for src, dst in source_target_pairs:
            try:
                visited = set()
                current = src
                path_key_rate = 1.0
                found_path = False
                max_hops = 10  # Prevent infinite loops
                
                while current != dst and current not in visited and len(visited) < max_hops:
                    visited.add(current)
                    
                    # Find best next hop based on key rate
                    best_neighbor = None
                    best_key_rate = 0.0
                    
                    if current in self.G:
                        for neighbor in self.G.neighbors(current):
                            if neighbor not in visited:
                                edge_data = self.G[current][neighbor]
                                if edge_data['key_rate'] > best_key_rate:
                                    best_key_rate = edge_data['key_rate']
                                    best_neighbor = neighbor
                    
                    if best_neighbor is None:
                        break
                        
                    path_key_rate *= best_key_rate
                    current = best_neighbor
                    
                    if current == dst:
                        found_path = True
                        break
                
                score = path_key_rate if found_path else 0.0
            except:
                score = 0.0
                
            predictions.append(score)
            
        return np.array(predictions)

class GeneticAlgorithmOptimizer:
    """Genetic Algorithm for QKD network optimization"""
    
    def __init__(self, network_data, population_size=50, generations=100):
        self.data = network_data
        self.population_size = population_size
        self.generations = generations
        self.num_nodes = network_data.num_nodes
        
        # Build adjacency matrix for faster access
        self.adj_matrix = self._build_adjacency_matrix()
        
    def _build_adjacency_matrix(self):
        """Build adjacency matrix with key rates"""
        adj = np.zeros((self.num_nodes, self.num_nodes))
        edge_index = self.data.edge_index.cpu().numpy()
        edge_attr = self.data.edge_attr.cpu().numpy()
        
        # Handle undirected edges correctly
        num_unique_edges = edge_attr.shape[0]
        
        for i in range(min(edge_index.shape[1], num_unique_edges)):
            src, dst = edge_index[0, i], edge_index[1, i]
            key_rate = edge_attr[i, 0]
            adj[src, dst] = key_rate
            adj[dst, src] = key_rate
            
        return adj
    
    def _fitness_function(self, individual, source_target_pairs):
        """Evaluate fitness of an individual (routing solution)"""
        total_score = 0.0
        
        for i, (src, dst) in enumerate(source_target_pairs):
            # Individual encodes path selection probabilities
            path_score = self._evaluate_path(src, dst, individual)
            total_score += path_score
            
        return total_score / max(len(source_target_pairs), 1)
    
    def _evaluate_path(self, src, dst, individual):
        """Evaluate a specific path using individual's strategy"""
        current = src
        visited = set()
        path_key_rate = 1.0
        
        while current != dst and current not in visited and len(visited) < 10:
            visited.add(current)
            
            # Use individual's weights to select next hop
            neighbors = np.where(self.adj_matrix[current] > 0)[0]
            neighbors = [n for n in neighbors if n not in visited]
            
            if not neighbors:
                return 0.0
                
            # Weight neighbors based on individual's preferences
            weights = []
            for neighbor in neighbors:
                key_rate = self.adj_matrix[current, neighbor]
                # Individual influences neighbor selection
                weight = key_rate * (1 + individual[neighbor % len(individual)])
                weights.append(weight)
            
            if not weights:
                return 0.0
                
            # Select best neighbor
            best_idx = np.argmax(weights)
            next_node = neighbors[best_idx]
            path_key_rate *= self.adj_matrix[current, next_node]
            current = next_node
            
        return path_key_rate if current == dst else 0.0
    
    def optimize(self, source_target_pairs):
        """Run genetic algorithm optimization"""
        if not source_target_pairs:
            return np.array([])
            
        # Initialize population
        population = []
        for _ in range(self.population_size):
            individual = np.random.uniform(-1, 1, self.num_nodes)
            population.append(individual)
        
        best_fitness = -float('inf')
        best_individual = None
        
        for generation in range(self.generations):
            # Evaluate fitness
            fitness_scores = []
            for individual in population:
                fitness = self._fitness_function(individual, source_target_pairs)
                fitness_scores.append(fitness)
                
                if fitness > best_fitness:
                    best_fitness = fitness
                    best_individual = individual.copy()
            
            # Selection and reproduction
            new_population = []
            
            # Elitism: keep best individuals
            elite_size = max(1, self.population_size // 10)
            elite_indices = np.argsort(fitness_scores)[-elite_size:]
            for idx in elite_indices:
                new_population.append(population[idx].copy())
            
            # Crossover and mutation
            while len(new_population) < self.population_size:
                # Tournament selection
                parent1 = self._tournament_selection(population, fitness_scores)
                parent2 = self._tournament_selection(population, fitness_scores)
                
                # Crossover
                child = self._crossover(parent1, parent2)
                
                # Mutation
                child = self._mutate(child)
                
                new_population.append(child)
            
            population = new_population
        
        # Generate predictions using best individual
        if best_individual is None:
            return np.zeros(len(source_target_pairs))
            
        predictions = []
        for src, dst in source_target_pairs:
            score = self._evaluate_path(src, dst, best_individual)
            predictions.append(score)
            
        return np.array(predictions)
    
    def _tournament_selection(self, population, fitness_scores, tournament_size=3):
        """Tournament selection for parent selection"""
        tournament_size = min(tournament_size, len(population))
        tournament_indices = np.random.choice(len(population), tournament_size, replace=False)
        tournament_fitness = [fitness_scores[i] for i in tournament_indices]
        winner_idx = tournament_indices[np.argmax(tournament_fitness)]
        return population[winner_idx].copy()
    
    def _crossover(self, parent1, parent2):
        """Single-point crossover"""
        crossover_point = np.random.randint(1, len(parent1))
        child = np.concatenate([parent1[:crossover_point], parent2[crossover_point:]])
        return child
    
    def _mutate(self, individual, mutation_rate=0.1):
        """Gaussian mutation"""
        mask = np.random.random(len(individual)) < mutation_rate
        individual[mask] += np.random.normal(0, 0.1, np.sum(mask))
        return np.clip(individual, -1, 1)

class MachineLearningBaselines:
    """Classical ML approaches for link prediction"""
    
    def __init__(self, network_data):
        self.data = network_data
        
    def random_forest_predictor(self, train_pairs, test_pairs, train_labels, test_labels):
        """Random Forest baseline"""
        # Extract features for training
        train_features = self._extract_features(train_pairs)
        test_features = self._extract_features(test_pairs)
        
        if train_features.shape[0] == 0 or test_features.shape[0] == 0:
            return np.zeros(len(test_pairs))
        
        # Train Random Forest
        rf = RandomForestRegressor(n_estimators=100, random_state=42)
        rf.fit(train_features, train_labels)
        
        # Predict
        predictions = rf.predict(test_features)
        return predictions
    
    def linear_regression_predictor(self, train_pairs, test_pairs, train_labels, test_labels):
        """Linear Regression baseline"""
        train_features = self._extract_features(train_pairs)
        test_features = self._extract_features(test_pairs)
        
        if train_features.shape[0] == 0 or test_features.shape[0] == 0:
            return np.zeros(len(test_pairs))
        
        lr = LinearRegression()
        lr.fit(train_features, train_labels)
        
        predictions = lr.predict(test_features)
        return predictions
    
    def _extract_features(self, node_pairs):
        """Extract features for node pairs"""
        if not node_pairs:
            return np.array([]).reshape(0, 10)  # Return empty array with correct shape
            
        features = []
        node_features = self.data.x.cpu().numpy()
        edge_index = self.data.edge_index.cpu().numpy()
        
        # Build adjacency list for faster lookup
        adj_list = {}
        num_unique_edges = min(edge_index.shape[1], self.data.edge_attr.shape[0])
        
        for i in range(num_unique_edges):
            src, dst = edge_index[0, i], edge_index[1, i]
            if src not in adj_list:
                adj_list[src] = []
            if dst not in adj_list:
                adj_list[dst] = []
            adj_list[src].append(dst)
            adj_list[dst].append(src)
        
        for src, dst in node_pairs:
            # Node features
            src_features = node_features[src] if src < len(node_features) else np.zeros(4)
            dst_features = node_features[dst] if dst < len(node_features) else np.zeros(4)
            
            # Distance between nodes
            distance = np.linalg.norm(src_features[:2] - dst_features[:2])
            
            # Common neighbors
            src_neighbors = set(adj_list.get(src, []))
            dst_neighbors = set(adj_list.get(dst, []))
            common_neighbors = len(src_neighbors.intersection(dst_neighbors))
            
            # Combine features
            pair_features = np.concatenate([
                src_features,
                dst_features,
                [distance, common_neighbors]
            ])
            
            features.append(pair_features)
            
        return np.array(features)

def comprehensive_evaluation(model, data, num_epochs=200, k_folds=5):
    """Comprehensive evaluation comparing GNN with classical methods"""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    data = data.to(device)
    
    print(f"Using device: {device}")
    print("Running comprehensive evaluation with classical baselines...")
    
    # Initialize classical methods
    classical_routing = ClassicalRoutingMethods(data)
    ml_baselines = MachineLearningBaselines(data)
    
    all_results = []
    kf = KFold(n_splits=k_folds, shuffle=True)
    
    # Prepare edge data - handle the undirected edge issue
    edge_index = data.edge_index.cpu().numpy()
    edge_attr = data.edge_attr.cpu().numpy()
    
    # Create unique edges from the edge_index
    unique_edges = set()
    edge_to_idx = {}
    
    # Since edges are undirected, we only take the first half
    num_unique_edges = edge_attr.shape[0]
    
    for i in range(num_unique_edges):
        if i < edge_index.shape[1]:
            src, dst = edge_index[0, i], edge_index[1, i]
            edge = tuple(sorted([src, dst]))
            if edge not in unique_edges:
                unique_edges.add(edge)
                edge_to_idx[edge] = i
    
    unique_edges = list(unique_edges)
    print(f"Found {len(unique_edges)} unique edges")
    
    for fold, (train_idx, val_idx) in enumerate(kf.split(unique_edges)):
        print(f"\nFold {fold + 1}/{k_folds}")
        
        train_edges = [unique_edges[i] for i in train_idx]
        val_edges = [unique_edges[i] for i in val_idx]
        
        # Prepare training data
        train_edge_list = []
        train_attr_list = []
        
        for edge in train_edges:
            train_edge_list.append([edge[0], edge[1]])
            attr_idx = edge_to_idx[edge]
            train_attr_list.append(edge_attr[attr_idx])
        
        val_edge_list = []
        val_attr_list = []
        
        for edge in val_edges:
            val_edge_list.append([edge[0], edge[1]])
            attr_idx = edge_to_idx[edge]
            val_attr_list.append(edge_attr[attr_idx])
        
        if not train_edge_list or not val_edge_list:
            print(f"  Skipping fold {fold + 1} due to empty train or validation set")
            continue
        
        train_edge_index = torch.tensor(train_edge_list, dtype=torch.long).t().to(device)
        train_edge_attr = torch.tensor(train_attr_list, dtype=torch.float).to(device)
        val_edge_index = torch.tensor(val_edge_list, dtype=torch.long).t().to(device)
        val_edge_attr = torch.tensor(val_attr_list, dtype=torch.float).to(device)
        
        # Train GNN model
        print("Training GNN model...")
        optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
        
        # Simplified training loop for comparison
        for epoch in range(min(num_epochs, 50)):  # Reduced for faster comparison
            model.train()
            optimizer.zero_grad()
            
            z, edge_features = model(data.x, train_edge_index, train_edge_attr)
            pos_out = model.decode(z, edge_features, train_edge_index)
            
            neg_edge_index = negative_sampling(
                train_edge_index,
                num_nodes=data.num_nodes,
                num_neg_samples=train_edge_index.size(1)
            )
            
            neg_out = model.decode(z, edge_features, neg_edge_index)
            loss = bce_with_logits_loss(pos_out, neg_out)
            
            loss.backward()
            optimizer.step()
            
            if epoch % 10 == 0:
                print(f"  GNN Epoch {epoch}: Loss = {loss:.4f}")
        
        # Evaluate all methods
        print("Evaluating all methods...")
        
        # Prepare test data
        val_pairs = [(val_edge_index[0, i].item(), val_edge_index[1, i].item()) 
                     for i in range(val_edge_index.size(1))]
        train_pairs = [(train_edge_index[0, i].item(), train_edge_index[1, i].item()) 
                       for i in range(train_edge_index.size(1))]
        
        # Generate negative samples for fair comparison
        val_neg_edge_index = negative_sampling(
            val_edge_index,
            num_nodes=data.num_nodes,
            num_neg_samples=val_edge_index.size(1)
        )
        val_neg_pairs = [(val_neg_edge_index[0, i].item(), val_neg_edge_index[1, i].item()) 
                         for i in range(val_neg_edge_index.size(1))]
        
        all_val_pairs = val_pairs + val_neg_pairs
        true_labels = [1] * len(val_pairs) + [0] * len(val_neg_pairs)
        
        train_labels = [1] * len(train_pairs)
        
        results = {}
        
        # 1. GNN Evaluation
        print("  Evaluating GNN...")
        model.eval()
        with torch.no_grad():
            z, edge_features = model(data.x, val_edge_index, val_edge_attr)
            pos_out = model.decode(z, edge_features, val_edge_index)
            neg_out = model.decode(z, edge_features, val_neg_edge_index)
            
            gnn_pred = torch.cat([pos_out, neg_out]).cpu().numpy()
            gnn_auc = roc_auc_score(true_labels, gnn_pred)
            gnn_ap = average_precision_score(true_labels, gnn_pred)
            
        results['GNN'] = {
            'auc': gnn_auc,
            'ap': gnn_ap,
            'predictions': gnn_pred
        }
        
        # 2. Classical Routing Methods
        print("  Evaluating classical routing methods...")
        
        # Dijkstra
        start_time = time.time()
        dijkstra_pred = classical_routing.dijkstra_routing(all_val_pairs)
        dijkstra_time = time.time() - start_time
        
        # Normalize predictions for fair comparison
        if np.max(dijkstra_pred) > 0:
            dijkstra_pred = dijkstra_pred / np.max(dijkstra_pred)
        
        dijkstra_auc = roc_auc_score(true_labels, dijkstra_pred) if len(np.unique(dijkstra_pred)) > 1 else 0.5
        dijkstra_ap = average_precision_score(true_labels, dijkstra_pred)
        
        results['Dijkstra'] = {
            'auc': dijkstra_auc,
            'ap': dijkstra_ap,
            'time': dijkstra_time,
            'predictions': dijkstra_pred
        }
        
        # MST
        start_time = time.time()
        mst_pred = classical_routing.minimum_spanning_tree(all_val_pairs)
        mst_time = time.time() - start_time
        
        if np.max(mst_pred) > 0:
            mst_pred = mst_pred / np.max(mst_pred)
        
        mst_auc = roc_auc_score(true_labels, mst_pred) if len(np.unique(mst_pred)) > 1 else 0.5
        mst_ap = average_precision_score(true_labels, mst_pred)
        
        results['MST'] = {
            'auc': mst_auc,
            'ap': mst_ap,
            'time': mst_time,
            'predictions': mst_pred
        }
        
        # Greedy Best-First
        start_time = time.time()
        greedy_pred = classical_routing.greedy_best_first(all_val_pairs)
        greedy_time = time.time() - start_time
        
        if np.max(greedy_pred) > 0:
            greedy_pred = greedy_pred / np.max(greedy_pred)
        
        greedy_auc = roc_auc_score(true_labels, greedy_pred) if len(np.unique(greedy_pred)) > 1 else 0.5
        greedy_ap = average_precision_score(true_labels, greedy_pred)
        
        results['Greedy'] = {
            'auc': greedy_auc,
            'ap': greedy_ap,
            'time': greedy_time,
            'predictions': greedy_pred
        }
        
        # 3. Machine Learning Baselines
        print("  Evaluating ML baselines...")
        
        # Random Forest
        start_time = time.time()
        try:
            rf_pred = ml_baselines.random_forest_predictor(
                train_pairs, all_val_pairs, train_labels, true_labels
            )
            rf_time = time.time() - start_time
            
            rf_auc = roc_auc_score(true_labels, rf_pred)
            rf_ap = average_precision_score(true_labels, rf_pred)
            
            results['Random Forest'] = {
                'auc': rf_auc,
                'ap': rf_ap,
                'time': rf_time,
                'predictions': rf_pred
            }
        except Exception as e:
            print(f"    Random Forest failed: {e}")
            results['Random Forest'] = {'auc': 0.5, 'ap': 0.0, 'time': 0.0}
        
        # Linear Regression
        start_time = time.time()
        try:
            lr_pred = ml_baselines.linear_regression_predictor(
                train_pairs, all_val_pairs, train_labels, true_labels
            )
            lr_time = time.time() - start_time
            
            lr_auc = roc_auc_score(true_labels, lr_pred)
            lr_ap = average_precision_score(true_labels, lr_pred)
            
            results['Linear Regression'] = {
                'auc': lr_auc,
                'ap': lr_ap,
                'time': lr_time,
                'predictions': lr_pred
            }
        except Exception as e:
            print(f"    Linear Regression failed: {e}")
            results['Linear Regression'] = {'auc': 0.5, 'ap': 0.0, 'time': 0.0}
        
        # 4. Genetic Algorithm (reduced generations for speed)
        print("  Evaluating Genetic Algorithm...")
        start_time = time.time()
        try:
            ga_optimizer = GeneticAlgorithmOptimizer(
                data, population_size=20, generations=30
            )
            ga_pred = ga_optimizer.optimize(all_val_pairs)
            ga_time = time.time() - start_time
            
            if len(ga_pred) > 0 and np.max(ga_pred) > 0:
                ga_pred = ga_pred / np.max(ga_pred)
            
            ga_auc = roc_auc_score(true_labels, ga_pred) if len(np.unique(ga_pred)) > 1 else 0.5
            ga_ap = average_precision_score(true_labels, ga_pred)
            
            results['Genetic Algorithm'] = {
                'auc': ga_auc,
                'ap': ga_ap,
                'time': ga_time,
                'predictions': ga_pred
            }
        except Exception as e:
            print(f"    Genetic Algorithm failed: {e}")
            results['Genetic Algorithm'] = {'auc': 0.5, 'ap': 0.0, 'time': 0.0}
        
        # Print fold results
        print(f"\n  Fold {fold + 1} Results:")
        for method, metrics in results.items():
            print(f"    {method}: AUC = {metrics['auc']:.4f}, AP = {metrics['ap']:.4f}")
        
        all_results.append(results)
    
    return all_results

def visualize_comparative_results(results, network_data, save_path='qkd_comparative_results'):
    """Create comprehensive comparative visualizations"""
    timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    save_path = f"{save_path}_{timestamp}"
    os.makedirs(save_path, exist_ok=True)
    
    # Extract method names
    methods = list(results[0].keys())
    
    # Aggregate results across folds
    aggregated_results = {}
    for method in methods:
        auc_scores = [fold[method]['auc'] for fold in results if method in fold]
        ap_scores = [fold[method]['ap'] for fold in results if method in fold]
        
        aggregated_results[method] = {
            'auc_mean': np.mean(auc_scores),
            'auc_std': np.std(auc_scores),
            'ap_mean': np.mean(ap_scores),
            'ap_std': np.std(ap_scores),
            'auc_scores': auc_scores,
            'ap_scores': ap_scores
        }
    
    # Create comparative plots
    fig, axes = plt.subplots(2, 2, figsize=(15, 12))
    
    # AUC Comparison
    ax1 = axes[0, 0]
    method_names = list(aggregated_results.keys())
    auc_means = [aggregated_results[m]['auc_mean'] for m in method_names]
    auc_stds = [aggregated_results[m]['auc_std'] for m in method_names]
    
    bars1 = ax1.bar(method_names, auc_means, yerr=auc_stds, capsize=5, alpha=0.7)
    ax1.set_title('AUC Score Comparison')
    ax1.set_ylabel('AUC Score')
    ax1.set_ylim(0, 1)
    ax1.tick_params(axis='x', rotation=45)
    
    # Highlight best method
    best_auc_idx = np.argmax(auc_means)
    bars1[best_auc_idx].set_color('red')
    bars1[best_auc_idx].set_alpha(1.0)
    
    # AP Comparison
    ax2 = axes[0, 1]
    ap_means = [aggregated_results[m]['ap_mean'] for m in method_names]
    ap_stds = [aggregated_results[m]['ap_std'] for m in method_names]
    
    bars2 = ax2.bar(method_names, ap_means, yerr=ap_stds, capsize=5, alpha=0.7)
    ax2.set_title('Average Precision Comparison')
    ax2.set_ylabel('Average Precision')
    ax2.set_ylim(0, 1)
    ax2.tick_params(axis='x', rotation=45)
    
    # Highlight best method
    best_ap_idx = np.argmax(ap_means)
    bars2[best_ap_idx].set_color('red')
    bars2[best_ap_idx].set_alpha(1.0)
    
    # Box plots for score distributions
    ax3 = axes[1, 0]
    auc_data = [aggregated_results[m]['auc_scores'] for m in method_names]
    box1 = ax3.boxplot(auc_data, labels=method_names, patch_artist=True)
    ax3.set_title('AUC Score Distribution')
    ax3.set_ylabel('AUC Score')
    ax3.tick_params(axis='x', rotation=45)
    
    ax4 = axes[1, 1]
    ap_data = [aggregated_results[m]['ap_scores'] for m in method_names]
    box2 = ax4.boxplot(ap_data, labels=method_names, patch_artist=True)
    ax4.set_title('Average Precision Distribution')
    ax4.set_ylabel('Average Precision')
    ax4.tick_params(axis='x', rotation=45)
    
    plt.tight_layout()
    plt.savefig(f"{save_path}/method_comparison.png", dpi=300, bbox_inches='tight')
    plt.close()
    
    # Performance improvement analysis
    gnn_auc = aggregated_results['GNN']['auc_mean']
    gnn_ap = aggregated_results['GNN']['ap_mean']
    
    improvements = {}
    for method in method_names:
        if method != 'GNN':
            auc_improvement = ((gnn_auc - aggregated_results[method]['auc_mean']) / 
                             aggregated_results[method]['auc_mean']) * 100
            ap_improvement = ((gnn_ap - aggregated_results[method]['ap_mean']) / 
                            aggregated_results[method]['ap_mean']) * 100
            
            improvements[method] = {
                'auc_improvement': auc_improvement,
                'ap_improvement': ap_improvement
            }
    
    # Statistical significance testing
    from scipy import stats
    
    statistical_tests = {}
    gnn_auc_scores = aggregated_results['GNN']['auc_scores']
    gnn_ap_scores = aggregated_results['GNN']['ap_scores']
    
    for method in method_names:
        if method != 'GNN':
            method_auc_scores = aggregated_results[method]['auc_scores']
            method_ap_scores = aggregated_results[method]['ap_scores']
            
            # Perform t-tests
            auc_tstat, auc_pvalue = stats.ttest_rel(gnn_auc_scores, method_auc_scores)
            ap_tstat, ap_pvalue = stats.ttest_rel(gnn_ap_scores, method_ap_scores)
            
            statistical_tests[method] = {
                'auc_tstat': float(auc_tstat),
                'auc_pvalue': float(auc_pvalue),
                'ap_tstat': float(ap_tstat),
                'ap_pvalue': float(ap_pvalue),
                'auc_significant': bool(auc_pvalue < 0.05),  # Convert to native bool
                'ap_significant': bool(ap_pvalue < 0.05)     # Convert to native bool
            }
    
    # Generate comprehensive report
    report = {
        'summary': {
            'best_method_auc': method_names[best_auc_idx],
            'best_auc_score': float(auc_means[best_auc_idx]),
            'best_method_ap': method_names[best_ap_idx],
            'best_ap_score': float(ap_means[best_ap_idx]),
            'gnn_rank_auc': int(sorted(auc_means, reverse=True).index(gnn_auc) + 1),
            'gnn_rank_ap': int(sorted(ap_means, reverse=True).index(gnn_ap) + 1)
        },
        'detailed_results': {},
        'improvements_over_classical': improvements,
        'statistical_significance': statistical_tests,
        'network_characteristics': {
            'num_nodes': int(network_data.num_nodes),
            'num_edges': int(network_data.edge_index.size(1) // 2),
            'avg_degree': float(network_data.edge_index.size(1) / network_data.num_nodes),
            'avg_key_rate': float(network_data.edge_attr[:, 0].mean()),
            'avg_qber': float(network_data.edge_attr[:, 1].mean())
        }
    }
    
    # Add detailed results for each method
    for method in method_names:
        report['detailed_results'][method] = {
            'auc_mean': float(aggregated_results[method]['auc_mean']),
            'auc_std': float(aggregated_results[method]['auc_std']),
            'ap_mean': float(aggregated_results[method]['ap_mean']),
            'ap_std': float(aggregated_results[method]['ap_std'])
        }
    
    # Save report
    with open(f"{save_path}/comparative_analysis_report.json", 'w') as f:
        json.dump(report, f, indent=4)
    
    # Print summary
    print(f"\n{'='*60}")
    print("COMPARATIVE ANALYSIS SUMMARY")
    print(f"{'='*60}")
    print(f"Best AUC: {report['summary']['best_method_auc']} ({report['summary']['best_auc_score']:.4f})")
    print(f"Best AP:  {report['summary']['best_method_ap']} ({report['summary']['best_ap_score']:.4f})")
    print(f"GNN Rank (AUC): {report['summary']['gnn_rank_auc']}/{len(method_names)}")
    print(f"GNN Rank (AP):  {report['summary']['gnn_rank_ap']}/{len(method_names)}")
    
    print(f"\nGNN Improvements over Classical Methods:")
    for method, improvement in improvements.items():
        print(f"  vs {method}:")
        print(f"    AUC: {improvement['auc_improvement']:+.2f}%")
        print(f"    AP:  {improvement['ap_improvement']:+.2f}%")
    
    print(f"\nStatistical Significance (p < 0.05):")
    for method, test in statistical_tests.items():
        print(f"  vs {method}:")
        print(f"    AUC: {'Significant' if test['auc_significant'] else 'Not significant'} (p={test['auc_pvalue']:.4f})")
        print(f"    AP:  {'Significant' if test['ap_significant'] else 'Not significant'} (p={test['ap_pvalue']:.4f})")
    
    return report

# Update the main execution to use comprehensive evaluation
if __name__ == "__main__":
    # Set random seeds for reproducibility
    torch.manual_seed(42)
    np.random.seed(42)
    
    # Generate network
    print("Generating QKD network...")
    network = AdvancedQKDNetwork(num_nodes=30)  # Reduced for faster comparison
    data = network.generate_graph_data()
    
    # Create model
    print("Creating GNN model...")
    model = AdvancedQKDLinkPredictor(
        in_channels=data.x.size(1),
        edge_attr_channels=data.edge_attr.size(1)
    )
    
    # Run comprehensive evaluation
    print("Running comprehensive evaluation...")
    comparative_results = comprehensive_evaluation(model, data, num_epochs=50, k_folds=3)
    
    # Generate comparative analysis
    print("Generating comparative analysis...")
    analysis_report = visualize_comparative_results(comparative_results, data)
    
    print(f"\nComparative analysis complete!")
    print(f"Results demonstrate the relative performance of GNN vs classical methods.")

Generating QKD network...
Creating GNN model...
Running comprehensive evaluation...
Using device: cpu
Running comprehensive evaluation with classical baselines...
Found 199 unique edges

Fold 1/3
Training GNN model...
  GNN Epoch 0: Loss = 1.5601
  GNN Epoch 10: Loss = 1.1354
  GNN Epoch 20: Loss = 0.8901
  GNN Epoch 30: Loss = 0.7825
  GNN Epoch 40: Loss = 0.6684
Evaluating all methods...
  Evaluating GNN...
  Evaluating classical routing methods...
  Evaluating ML baselines...
  Evaluating Genetic Algorithm...

  Fold 1 Results:
    GNN: AUC = 0.8209, AP = 0.7885
    Dijkstra: AUC = 0.3904, AP = 0.4237
    MST: AUC = 0.4053, AP = 0.4337
    Greedy: AUC = 0.5376, AP = 0.5426
    Random Forest: AUC = 0.5000, AP = 0.5000
    Linear Regression: AUC = 0.5000, AP = 0.5000
    Genetic Algorithm: AUC = 0.4908, AP = 0.5310

Fold 2/3
Training GNN model...
  GNN Epoch 0: Loss = 0.9506
  GNN Epoch 10: Loss = 0.8179
  GNN Epoch 20: Loss = 0.6839
  GNN Epoch 30: Loss = 0.7092
  GNN Epoch 40: Loss 