# k-Nearest Neighbor Causal Discovery Example

This notebook demonstrates causal network discovery using the **k-nearest-neighbor (kNN)** information estimator.

## Overview
- Generate synthetic time series with a known causal structure
- Visualize the dynamics and the network structure
- Apply causal discovery using k-nearest-neighbor conditional mutual information
- Evaluate performance with the ROC-AUC metric

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
from sklearn.metrics import roc_auc_score, roc_curve
import warnings
warnings.filterwarnings('ignore')

# Import causal discovery components
from causalentropy.core.discovery import discover_network
from causalentropy.datasets.synthetic import linear_stochastic_gaussian_process

# Set plotting style.
# matplotlib >= 3.6 ships the legacy seaborn look as 'seaborn-v0_8'; older
# releases call the same style 'seaborn'. Pick whichever this install provides
# so the notebook does not fail on an older matplotlib.
_style_name = 'seaborn-v0_8' if 'seaborn-v0_8' in plt.style.available else 'seaborn'
plt.style.use(_style_name)
sns.set_palette('tab10')

print('Libraries imported successfully!')

## 1. Create Ground Truth Network

We'll create a directed graph that represents the true causal relationships.

In [None]:
# Create the ground-truth causal network.
# Nodes are integer indices 0..n_nodes-1; a directed edge (i, j) means X_i -> X_j.

n_nodes = 5
seed = 42

np.random.seed(seed)

G_true = nx.DiGraph()
G_true.add_nodes_from(range(n_nodes))

# Hand-picked causal edges: 0 -> {1, 2}, 1 -> 3, {2, 3} -> 4
edges = [(0, 1), (0, 2), (1, 3), (2, 4), (3, 4)]
G_true.add_edges_from(edges)

print(f'Ground truth network has {G_true.number_of_nodes()} nodes and {G_true.number_of_edges()} edges')
print(f'Edges: {list(G_true.edges())}')

# Pin row/column order to the node index so A_true[i, j] == 1 iff i -> j,
# independent of the order in which nodes were inserted into the graph.
A_true = nx.to_numpy_array(G_true, nodelist=range(n_nodes), dtype=int)
assert A_true.shape == (n_nodes, n_nodes)

print('\nGround truth adjacency matrix:')
print(A_true)

## 2. Generate Synthetic Data

Generate time series from a linear stochastic Gaussian process that follows the known causal structure. The k-nearest-neighbor estimator is applied to these data in Section 5.

In [None]:
# Generate synthetic data with the known causal structure.
# The series come from a linear stochastic Gaussian process; "k-Nearest Neighbor"
# refers to the information estimator applied in Section 5, not to the data itself.
T = 200  # time series length

# NOTE(review): rho / p are generator parameters of linear_stochastic_gaussian_process
# (presumably coupling strength and random-edge probability, with p=0.0 because G is
# supplied explicitly) — TODO confirm against the library documentation.
data, A_generated = linear_stochastic_gaussian_process(
    n=n_nodes,
    T=T,
    G=G_true,
    rho=0.7,
    p=0.0,
    seed=seed,
)

# Later cells plot data[:, i] against np.arange(T): rows must be time steps
# and columns must be variables.
assert data.shape == (T, n_nodes), f'unexpected data shape {data.shape}; expected {(T, n_nodes)}'

print(f'Generated synthetic data with shape: {data.shape}')
print('Data statistics:')
print(f'  Mean: {np.mean(data):.3f}')
print(f'  Std:  {np.std(data):.3f}')
print(f'  Range: [{np.min(data):.3f}, {np.max(data):.3f}]')
print('  Data type: Any continuous (uses Gaussian data)')

## 3. Visualize Time Series Data

Plot each variable's time series to understand the data characteristics.

In [None]:
# Plot each variable's time series on its own row, with its sample mean marked.
# squeeze=False keeps a 2-D Axes array even when n_nodes == 1, so indexing always works.
fig, axes_grid = plt.subplots(n_nodes, 1, figsize=(12, 8), sharex=True, squeeze=False)
axes = axes_grid[:, 0]
fig.suptitle('Synthetic Time Series (Linear Stochastic Gaussian Process)',
             fontsize=16, fontweight='bold')

time = np.arange(data.shape[0])
colors = sns.color_palette('tab10', n_nodes)

for i, (ax, color) in enumerate(zip(axes, colors)):
    series = data[:, i]
    mean_val = np.mean(series)
    ax.plot(time, series, color=color, alpha=0.8, linewidth=1.5)
    ax.axhline(mean_val, color='red', linestyle='--', alpha=0.5,
               label=f'Mean: {mean_val:.2f}')
    ax.set_ylabel(f'X{i}', fontweight='bold')
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=8)

# Figure-level finishing touches run once, after every panel has been drawn.
axes[-1].set_xlabel('Time', fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Visualize Ground Truth Network

In [None]:
# Draw the ground-truth causal graph. The layout seed makes node positions reproducible.
fig, ax = plt.subplots(figsize=(10, 8))
pos = nx.spring_layout(G_true, seed=seed, k=2, iterations=50)
nx.draw_networkx_nodes(G_true, pos, ax=ax, node_color='lightsteelblue',
                       node_size=1500, alpha=0.8)
nx.draw_networkx_edges(G_true, pos, ax=ax, edge_color='darkblue',
                       arrows=True, arrowsize=20, width=2, alpha=0.7)
nx.draw_networkx_labels(G_true, pos, labels={i: f'X{i}' for i in range(n_nodes)},
                        ax=ax, font_size=14, font_weight='bold')
ax.set_title('Ground Truth Causal Network\n(Synthetic Data)',
             fontsize=16, fontweight='bold')
ax.axis('off')
plt.tight_layout()
plt.show()

## 5. Apply Causal Discovery

Run causal discovery with the k-nearest-neighbor information estimator (`information='knn'`), comparing the `standard` and `alternative` discovery methods.

In [None]:
# Apply causal discovery with the k-nearest-neighbor information estimator,
# once per discovery method, and keep each discovered graph for evaluation.
# Shared settings live in one dict so every method is run with identical parameters.
DISCOVERY_KWARGS = dict(
    information='knn',
    max_lag=2,
    alpha_forward=0.05,   # forward-step significance threshold (see discover_network docs)
    alpha_backward=0.05,  # backward-step significance threshold (see discover_network docs)
    n_shuffles=100,
)

methods_to_test = ['standard', 'alternative']
discovered_networks = {}

print('Applying causal discovery with k-Nearest Neighbor information method...')
print('This may take a few moments...\n')

for method in methods_to_test:
    print(f'Running {method} method...')
    G_discovered = discover_network(data=data, method=method, **DISCOVERY_KWARGS)
    discovered_networks[method] = G_discovered
    print(f'  Discovered {G_discovered.number_of_edges()} edges')
    print(f'  Edges: {list(G_discovered.edges())}\n')

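## 6. Calculate ROC-AUC Performance

The discovered networks are binary: each edge is either present or absent. Each ROC curve therefore has a single operating point, and its AUC equals the average of the true-positive rate and the true-negative rate.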

In [None]:
def _node_index(node):
    """Map a discovered-graph node label (e.g. 3, '3' or 'X3') to its integer index."""
    return int(node.replace('X', '')) if 'X' in str(node) else int(node)


def calculate_roc_auc(true_adj, discovered_graph):
    """Calculate ROC-AUC for network discovery performance.

    Parameters
    ----------
    true_adj : array-like, shape (n, n)
        Ground-truth adjacency matrix; a nonzero entry [i, j] marks an edge i -> j.
    discovered_graph : networkx graph
        Discovered network. Node labels may be integers or strings such as 'X3'.
        Nodes whose index falls outside 0..n-1 are ignored.

    Returns
    -------
    tuple
        ``(auc, fpr, tpr)``, or ``(None, None, None)`` when the off-diagonal ground
        truth contains a single class and ROC-AUC is undefined.

    Notes
    -----
    The discovered graph provides binary scores (edge present / absent). The ROC
    curve therefore has a single operating point, and the AUC equals
    (TPR + 1 - FPR) / 2. Self-loops (the diagonal) are excluded from scoring.
    """
    true_adj = np.asarray(true_adj)
    n = true_adj.shape[0]

    # Binary adjacency of the discovered graph, rows/cols in the same order as true_adj.
    discovered_adj = np.zeros((n, n), dtype=int)
    for src_label, dst_label in discovered_graph.edges():
        src, dst = _node_index(src_label), _node_index(dst_label)
        if 0 <= src < n and 0 <= dst < n:
            discovered_adj[src, dst] = 1

    # Boolean 2-D masking yields the off-diagonal entries in row-major order.
    off_diagonal = ~np.eye(n, dtype=bool)
    # Binarize so weighted ground-truth matrices still give a two-class target.
    y_true = (true_adj != 0).astype(int)[off_diagonal]
    y_scores = discovered_adj[off_diagonal]

    if np.unique(y_true).size < 2:
        return None, None, None

    auc_score = roc_auc_score(y_true, y_scores)
    fpr, tpr, _ = roc_curve(y_true, y_scores)
    return auc_score, fpr, tpr


# Score every discovered network against the ground truth and overlay the ROC curves.
results = {}
fig, ax = plt.subplots(figsize=(10, 6))

for method, G_disc in discovered_networks.items():
    auc_score, fpr, tpr = calculate_roc_auc(A_true, G_disc)
    if auc_score is None:
        print(f'{method} method: ROC-AUC undefined (ground truth has a single class)')
        continue
    results[method] = {'auc': auc_score, 'fpr': fpr, 'tpr': tpr}
    ax.plot(fpr, tpr, linewidth=2, label=f'{method} (AUC = {auc_score:.3f})')
    print(f'{method} method: ROC-AUC = {auc_score:.3f}')

ax.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Random (AUC = 0.500)')
ax.set_xlabel('False Positive Rate', fontweight='bold')
ax.set_ylabel('True Positive Rate', fontweight='bold')
ax.set_title('ROC Curves for k-Nearest Neighbor Causal Discovery', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()

## 7. Conclusions

In [None]:
# Summarize the experiment: data setup, per-method results and the best method by ROC-AUC.
print('\n' + '=' * 60)
print('EXPERIMENT CONCLUSIONS - K-NEAREST NEIGHBOR CAUSAL DISCOVERY')
print('=' * 60)

print('📊 DATA CHARACTERISTICS:')
print('  • Data type: Any continuous (uses Gaussian data)')
print(f'  • Time series length: {T}')
print(f'  • Number of variables: {n_nodes}')
print(f'  • Ground truth edges: {G_true.number_of_edges()}')

print('🔍 DISCOVERY RESULTS:')
for method, G_disc in discovered_networks.items():
    auc_val = results.get(method, {}).get('auc')
    # Methods missing from `results` had an undefined ROC-AUC; say so rather than print 0.
    auc_text = f'{auc_val:.3f}' if auc_val is not None else 'n/a'
    print(f'  • {method.capitalize()}: {G_disc.number_of_edges()} edges, AUC = {auc_text}')

if results:
    # max() keeps the first method on ties, matching the order in which methods were run.
    best_method = max(results, key=lambda m: results[m]['auc'])
    print(f'🏆 BEST METHOD: {best_method.upper()}')
    print(f"  • ROC-AUC: {results[best_method]['auc']:.3f}")

print('💡 K-NEAREST NEIGHBOR METHOD INSIGHTS:')
print('  • Designed for distance-based information')
print('  • Data type: Any continuous (uses Gaussian data)')
print('  • Performance depends on data characteristics and coupling strength')

print('Experiment completed! 🎉')