In [None]:
# Setup and imports
import os
import random
import sys

# Make `config` and `src.*` importable: the project root is the current working
# directory, or its parent when the kernel was started inside `notebooks/`.
# Compare the last path component instead of substring-matching the whole path,
# so e.g. `/home/me/notebooks_archive/project` is not mistaken for the notebooks dir.
_cwd = os.getcwd()
PROJECT_ROOT = os.path.dirname(_cwd) if os.path.basename(_cwd) == 'notebooks' else _cwd
# Guard so re-running this cell does not stack duplicate entries on sys.path.
if sys.path[:1] != [PROJECT_ROOT]:
    sys.path.insert(0, PROJECT_ROOT)

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import networkx as nx
import torch

# Project imports
import config
from src.data_loader import load_gml_graph, list_available_networks, get_network_info
from src.feature_generator import generate_all_features
from src.dataset import create_pyg_data, create_edge_splits
from src.models import EdgeRiskGNN
from src.train import train_model, Trainer
from src.path_analysis import PathAnalyzer, analyze_paths
from src.visualization import (
    plot_network_risk_map,
    plot_training_history,
    plot_risk_distribution,
    plot_critical_paths,
    plot_comparison_metrics
)

# Seed every RNG the pipeline may draw from (Python, NumPy, PyTorch) so the
# synthetic features and the training run are reproducible on Restart & Run All.
# NOTE(review): if `src.*` helpers use their own seeded generators, these global
# seeds are redundant but harmless. `torch.manual_seed` also seeds CUDA devices.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

# Check device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

## 1. Explore Available Networks

In [None]:
# Enumerate the networks the data loader can find, then preview the first 15.
networks = list_available_networks()
network_count = len(networks)
print(f"Total available networks: {network_count}")

print("\nSample networks:")
for name in networks[:15]:
    print(f"  - {name}")

In [None]:
# Topology to analyse (change freely); other options include
# "Geant2012", "Cogentco" and "AttMpls".
NETWORK_NAME = "Abilene"

G = load_gml_graph(NETWORK_NAME)

# Print the summary dict from get_network_info, one "key: value" per line.
info = get_network_info(G)
print(f"\nNetwork: {NETWORK_NAME}")
print("-" * 40)
for field_name, field_value in info.items():
    print(f"{field_name}: {field_value}")

## 2. Generate Synthetic Features

In [None]:
# Generate synthetic node- and edge-level feature tables for the loaded graph.
node_features, edge_features = generate_all_features(G)

# Fail fast if the generator's schema drifts: the histogram and summary cells
# below index these columns by name and would otherwise die with a bare KeyError.
REQUIRED_NODE_COLS = {'load', 'degree_centrality', 'betweenness_centrality'}
REQUIRED_EDGE_COLS = {'latency', 'utilization', 'risk_score'}
missing_node_cols = REQUIRED_NODE_COLS - set(node_features.columns)
missing_edge_cols = REQUIRED_EDGE_COLS - set(edge_features.columns)
assert not missing_node_cols, f"node_features missing columns: {sorted(missing_node_cols)}"
assert not missing_edge_cols, f"edge_features missing columns: {sorted(missing_edge_cols)}"

print("Node Features:")
display(node_features.head(10))

print("\nEdge Features:")
display(edge_features.head(10))

In [None]:
# Visualize feature distributions: top row = node features, bottom row = edge features.
# One spec per panel (source table, column, title, x-label, colour) replaces six
# copy-pasted stanzas and gives every panel consistent axis labels.
# A colour of None keeps matplotlib's default colour for that panel.
HIST_BINS = 15
HIST_PANELS = [
    (node_features, 'load', 'Node Load Distribution', 'Load', None),
    (node_features, 'degree_centrality', 'Degree Centrality Distribution', 'Degree centrality', 'orange'),
    (node_features, 'betweenness_centrality', 'Betweenness Centrality Distribution', 'Betweenness centrality', 'green'),
    (edge_features, 'latency', 'Edge Latency Distribution', 'Latency', 'red'),
    (edge_features, 'utilization', 'Edge Utilization Distribution', 'Utilization', 'purple'),
    (edge_features, 'risk_score', 'Risk Score Distribution', 'Risk score', 'brown'),
]

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
# axes.flat walks the grid row-major, matching the order of HIST_PANELS.
for ax, (frame, column, title, xlabel, color) in zip(axes.flat, HIST_PANELS):
    ax.hist(frame[column], bins=HIST_BINS, edgecolor='black', color=color)
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')

plt.tight_layout()
plt.show()

## 3. Visualize Network Before Training

In [None]:
# Baseline map before any training: the predictions slot (third positional
# argument) is None, and the title marks these as the true risk scores.
true_risk_title = f"Network Risk Map - {NETWORK_NAME} (True Risk Scores)"
plot_network_risk_map(G, edge_features, None, node_features, title=true_risk_title)

## 4. Prepare Data for GNN

In [None]:
# Build the PyTorch Geometric Data object, then attach the train/val/test edge masks.
data = create_edge_splits(create_pyg_data(G, node_features, edge_features))

data_report = [
    f"  - Nodes: {data.num_nodes}",
    f"  - Edges: {data.edge_index.size(1)}",
    f"  - Node feature dim: {data.x.size(1)}",
    f"  - Edge feature dim: {data.edge_attr.size(1)}",
    f"  - Train edges: {data.train_mask.sum().item()}",
    f"  - Val edges: {data.val_mask.sum().item()}",
    f"  - Test edges: {data.test_mask.sum().item()}",
]
print("PyTorch Geometric Data Object:")
print("\n".join(data_report))

## 5. Train the GNN Model

In [None]:
# Train the edge-risk model; `gnn_type` selects the GNN layer variant.
GNN_TYPE = 'sage'  # Options: 'gcn', 'gat', 'sage'
model, trainer, history = train_model(data, model_type='gnn', gnn_type=GNN_TYPE)

In [None]:
# Plot the training history returned by train_model in the previous cell.
plot_training_history(history)

In [None]:
# Run the trained model on the full graph to obtain predicted edge risks.
predicted_risks = trainer.predict(data)

print(f"Predictions shape: {predicted_risks.shape}")
pred_lo, pred_hi = predicted_risks.min(), predicted_risks.max()
print(f"Predictions range: [{pred_lo:.4f}, {pred_hi:.4f}]")

# Overlay the predicted distribution on the true risk labels.
plot_risk_distribution(edge_features, predicted_risks)

## 6. Network Risk Map with GNN Predictions

In [None]:
# Same map as the baseline, but with the GNN predictions in the predictions slot.
gnn_risk_title = f"Network Risk Map - {NETWORK_NAME} (GNN Predictions)"
plot_network_risk_map(G, edge_features, predicted_risks, node_features, title=gnn_risk_title)

## 7. Critical Path Analysis

In [None]:
# Analyze paths using the GNN-predicted risks. The returned dict is read by later
# cells via the 'critical_paths', 'critical_edges' and 'comparison' keys.
analysis_results = analyze_paths(G, edge_features, predicted_risks)

In [None]:
# Visualize the critical paths returned by analyze_paths.
# NOTE(review): G and node_features are passed through, presumably for layout and
# node styling — confirm in src.visualization.
plot_critical_paths(
    G, analysis_results['critical_paths'], node_features
)

In [None]:
# Tabulate the critical edges identified by analyze_paths.
print("\nTop Critical Edges:")
critical_edges = analysis_results['critical_edges']
display(critical_edges)

## 8. Compare GNN with Static Metrics

In [None]:
# Plot the GNN-vs-static-metrics comparison produced by analyze_paths.
# NOTE(review): the structure of analysis_results['comparison'] is defined in
# src.path_analysis — not visible here.
plot_comparison_metrics(analysis_results['comparison'])

## 9. Interactive Path Finding

In [None]:
# Build a path analyzer over the GNN-predicted edge risks.
analyzer = PathAnalyzer(G, edge_features, predicted_risks)

# Endpoints to compare (change these node IDs based on your network).
# Default to nodes 0 and N-1 when both labels exist; GML topologies are not
# guaranteed to use contiguous integer labels, so otherwise fall back to the
# first and last node in the graph's own node order.
node_ids = list(G.nodes())
assert len(node_ids) >= 2, "path analysis needs at least two nodes"
default_pair = (0, G.number_of_nodes() - 1)
if all(node in G for node in default_pair):
    source_node, target_node = default_pair
else:
    source_node, target_node = node_ids[0], node_ids[-1]

# Fail with a clear message instead of an opaque error inside the analyzer.
if not nx.has_path(G, source_node, target_node):
    raise ValueError(
        f"No path between node {source_node} and node {target_node}; "
        "choose different endpoints."
    )

print(f"Analyzing paths from node {source_node} to node {target_node}:")
print("-" * 50)

# Shortest path
shortest_path, shortest_len = analyzer.find_shortest_path(source_node, target_node)
shortest_risk = analyzer.get_path_risk(shortest_path)
print(f"\nShortest path: {' -> '.join(map(str, shortest_path))}")
print(f"  Length: {shortest_len} hops")
print(f"  Total risk: {shortest_risk['total_risk']:.4f}")

# Safest path (minimum accumulated predicted risk); its cost is re-derived via
# get_path_risk so both paths are reported the same way.
safest_path, safest_risk_total = analyzer.find_safest_path(source_node, target_node)
safest_risk = analyzer.get_path_risk(safest_path)
print(f"\nSafest path: {' -> '.join(map(str, safest_path))}")
print(f"  Length: {safest_risk['path_length']} hops")
print(f"  Total risk: {safest_risk['total_risk']:.4f}")

In [None]:
# Overlay the safest route found above on the GNN risk map.
safest_path_title = f"Safest Path from {source_node} to {target_node}"
plot_network_risk_map(
    G, edge_features, predicted_risks, node_features,
    title=safest_path_title, highlight_path=safest_path,
)

## 10. Summary Statistics

In [None]:
# Final summary: graph size, true-vs-predicted risk statistics, and the top
# critical paths from the path analysis.
print("=" * 60)
print("SUMMARY")
print("=" * 60)

print(f"\nNetwork: {NETWORK_NAME}")
print(f"  Nodes: {G.number_of_nodes()}")
print(f"  Edges: {G.number_of_edges()}")

# Align predictions with the ground-truth table.
# NOTE(review): assumes the first len(edge_features) predictions correspond
# row-for-row to edge_features (extra entries, e.g. reverse-direction edges, are
# dropped) — confirm against create_pyg_data. ravel() flattens an (E, 1) output,
# which np.corrcoef would otherwise treat as E separate variables.
true_risk = edge_features['risk_score'].to_numpy(dtype=float)
pred_risk = np.asarray(predicted_risks, dtype=float).ravel()[:len(true_risk)]
assert len(pred_risk) == len(true_risk), (
    f"expected at least {len(true_risk)} predictions, got {len(pred_risk)}"
)

# Use the sample standard deviation (ddof=1) for both so the numbers are
# comparable: pandas' Series.std defaults to ddof=1, NumPy's ndarray.std to ddof=0.
print("\nRisk Score Statistics:")
print(f"  True risk - Mean: {true_risk.mean():.4f}, Std: {true_risk.std(ddof=1):.4f}")
print(f"  Predicted risk - Mean: {pred_risk.mean():.4f}, Std: {pred_risk.std(ddof=1):.4f}")

# Pearson correlation is undefined when either side is constant (np.corrcoef would
# emit a RuntimeWarning and return nan), so report that case explicitly.
if len(true_risk) > 1 and true_risk.std() > 0 and pred_risk.std() > 0:
    corr = float(np.corrcoef(true_risk, pred_risk)[0, 1])
    print(f"\nPrediction Correlation: {corr:.4f}")
else:
    corr = float("nan")
    print("\nPrediction Correlation: undefined (constant or too few risk values)")

print("\nTop 3 Critical Paths:")
for i, path in enumerate(analysis_results['critical_paths'][:3], 1):
    print(f"  {i}. Risk: {path['total_risk']:.4f} | Path: {' -> '.join(map(str, path['path']))}")

print("\n" + "=" * 60)