# Histogram-based Causal Discovery Example

This notebook demonstrates causal network discovery using the **histogram-based** information method, which estimates conditional mutual information from binned density estimates.

## Overview
- Generate synthetic time series with a known causal structure
- Visualize the dynamics and network structure
- Apply causal discovery using histogram-based conditional mutual information
- Evaluate performance using the ROC-AUC metric


In [None]:
# Standard library
import warnings

# Third-party scientific / plotting stack
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
import seaborn as sns
from sklearn.metrics import roc_auc_score, roc_curve

# Suppress all warnings to keep the notebook output readable.
# NOTE: this also hides deprecation warnings; comment it out when debugging.
warnings.filterwarnings('ignore')

# Import causal discovery components
from causalentropy.core.discovery import discover_network
from causalentropy.datasets.synthetic import linear_stochastic_gaussian_process

# Set plotting style: apply the matplotlib style sheet first, then the seaborn
# palette, so the palette is not overwritten by the style sheet's color cycle.
plt.style.use('seaborn-v0_8')
sns.set_palette('Set2')
print('Libraries imported successfully!')

## 1. Create Ground Truth Network

We'll create a directed graph that represents the true causal relationships.

In [None]:
# Configuration for the ground-truth causal graph
n_nodes = 5   # number of variables (graph nodes)
seed = 42     # shared seed reused by later cells
p = 0.2       # edge probability of the Erdos-Renyi random graph

# Seed NumPy's global RNG, then draw a directed Erdos-Renyi graph that
# serves as the known causal structure.
np.random.seed(seed)
G_true = nx.erdos_renyi_graph(n_nodes, p, seed=seed, directed=True)

true_edges = list(G_true.edges())
print(f"Ground truth network has {G_true.number_of_nodes()} nodes and {G_true.number_of_edges()} edges")
print(f"Edges: {true_edges}")

# Adjacency matrix (transposed) kept for the later comparison with the
# discovered network.
# NOTE(review): because of the transpose, an edge i -> j is presumably stored
# at A_true[j, i]; confirm this matches the orientation used when scoring.
A_true = nx.to_numpy_array(G_true).T
print(f"\nGround truth adjacency matrix:")
print(A_true)

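## 2. Generate Synthetic Time Series Data

Generate a synthetic time series from a linear stochastic Gaussian process that follows the known causal structure. The histogram-based estimator is applied to this data in the causal discovery step.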

In [None]:
# Generate a synthetic Gaussian time series coupled according to G_true
T = 200    # Time series length
rho = 0.7  # Coupling strength

# Generate data using the linear stochastic process. n, p and seed reuse the
# values defined together with the ground-truth network so the two cells
# cannot drift apart; G=G_true supplies the ground-truth graph.
data, A = linear_stochastic_gaussian_process(
    rho=rho,
    n=n_nodes,
    T=T,
    p=p,
    seed=seed,
    G=G_true
)

# Quick sanity summary of the generated samples
print(f"Generated time series data with shape: {data.shape}")
print(f"Data statistics:")
print(f"  Mean: {np.mean(data):.3f}")
print(f"  Std:  {np.std(data):.3f}")
print(f"  Range: [{np.min(data):.3f}, {np.max(data):.3f}]")

## 3. Visualize Time Series Data

Plot the generated time series to understand the data characteristics.

In [None]:
# Plot time series for all variables, one stacked panel per variable.
# The time axis is derived from the data itself rather than a re-declared
# constant; this assumes `data` is laid out as (time, variable), which the
# data[:, i] indexing below also relies on.
n_samples = data.shape[0]

fig, axes = plt.subplots(n_nodes, 1, figsize=(12, 8), sharex=True)
# plt.subplots returns a bare Axes (not an array) when there is a single
# panel; normalize so axes[i] works for any n_nodes >= 1.
axes = np.atleast_1d(axes)
fig.suptitle('Gaussian Coupled Oscillators Time Series', fontsize=16, fontweight='bold')

time = np.arange(n_samples)
colors = sns.color_palette("husl", n_nodes)

for i in range(n_nodes):
    series = data[:, i]
    axes[i].plot(time, series, color=colors[i], alpha=0.8, linewidth=1.5)
    axes[i].set_ylabel(f'X{i}', fontweight='bold')
    axes[i].grid(True, alpha=0.3)

    # Highlight the time series statistics: dashed line at the mean and a
    # shaded band covering one standard deviation on either side.
    mean_val = np.mean(series)
    std_val = np.std(series)
    axes[i].axhline(mean_val, color='red', linestyle='--', alpha=0.5, label=f'Mean: {mean_val:.2f}')
    axes[i].fill_between(time, mean_val-std_val, mean_val+std_val, alpha=0.1, color='gray')
    axes[i].legend(fontsize=8)

# Only the bottom panel carries the shared x-axis label
axes[-1].set_xlabel('Time', fontweight='bold')
plt.tight_layout()
plt.show()

## 4. Visualize Ground Truth Network

In [None]:
# Draw the ground-truth causal graph
plt.figure(figsize=(10, 8))

# Seeded spring layout so the node positions are reproducible
pos = nx.spring_layout(G_true, seed=seed, k=2, iterations=50)

# Styling for each drawing layer, collected in one place
node_style = dict(node_color='lightblue', node_size=1500, alpha=0.8)
edge_style = dict(edge_color='darkblue', arrows=True, arrowsize=20, width=2, alpha=0.7)
label_style = dict(font_size=14, font_weight='bold')
node_labels = {i: f'X{i}' for i in range(n_nodes)}  # node i is shown as "X<i>"

# Layers are drawn in order: nodes, then directed edges, then labels
nx.draw_networkx_nodes(G_true, pos, **node_style)
nx.draw_networkx_edges(G_true, pos, **edge_style)
nx.draw_networkx_labels(G_true, pos, node_labels, **label_style)

plt.title('Ground Truth Causal Network\n(Gaussian Data)',
          fontsize=16, fontweight='bold')
plt.axis('off')
plt.tight_layout()
plt.show()

# Summary statistics of the ground-truth graph
print("Ground Truth Network Statistics:")
print(f"  Nodes: {G_true.number_of_nodes()}")
print(f"  Edges: {G_true.number_of_edges()}")
print(f"  Edge density: {nx.density(G_true):.3f}")
print(f"  Is DAG: {nx.is_directed_acyclic_graph(G_true)}")

## 5. Apply Causal Discovery

Use the Histogram-based method to discover causal relationships.

In [None]:
# Apply causal discovery with Histogram-based method
print('Applying causal discovery with Histogram-based information method...')
print('This may take a few moments...\n')

# Each discovery variant is run on the same data; results are keyed by the
# method name so the evaluation cell can iterate over them.
methods_to_test = ['standard', 'alternative']
discovered_networks = {}

for method in methods_to_test:
    print(f'Running {method} method...')

    # NOTE(review): alpha_forward / alpha_backward are presumably the
    # significance levels of the forward and backward selection phases and
    # n_shuffles the number of permutations per test; confirm against the
    # discover_network documentation.
    G_discovered = discover_network(
        data=data,
        method=method,
        information='histogram',
        max_lag=2,
        alpha_forward=0.05,
        alpha_backward=0.05,
        n_shuffles=100
    )

    discovered_networks[method] = G_discovered
    print(f'  Discovered {G_discovered.number_of_edges()} edges')
    print(f'  Edges: {list(G_discovered.edges())}\n')

## 6. Calculate ROC-AUC Performance

In [None]:
def _node_index(label):
    """Convert a node label (an integer id or an 'X<i>'-style string) to an int index."""
    if isinstance(label, str):
        return int(label.replace('X', ''))
    return int(label)


def calculate_roc_auc(true_adj, discovered_graph):
    """Calculate ROC-AUC for network discovery performance.

    Parameters
    ----------
    true_adj : numpy.ndarray
        Square ground-truth adjacency matrix; its first dimension sets the
        number of nodes.
    discovered_graph : graph with an ``edges()`` method
        Discovered network whose node labels are integers or 'X<i>' strings.

    Returns
    -------
    tuple
        ``(auc_score, fpr, tpr)``, or ``(None, None, None)`` when the
        off-diagonal ground truth contains a single class, in which case the
        ROC-AUC is undefined.
    """
    n = true_adj.shape[0]

    # Rebuild the discovered graph on integer node ids 0..n-1 so that its
    # adjacency matrix lines up index-for-index with true_adj.
    G_int = nx.DiGraph()
    G_int.add_nodes_from(range(n))
    for src_label, dst_label in discovered_graph.edges():
        G_int.add_edge(_node_index(src_label), _node_index(dst_label))

    # NOTE(review): discovered_adj[src, dst] marks an edge src -> dst, while
    # the notebook passes a transposed ground-truth matrix; confirm that both
    # matrices use the same orientation.
    discovered_adj = nx.adjacency_matrix(G_int, nodelist=range(n)).toarray()

    # Self-loops are excluded from the evaluation: drop the diagonal entries.
    mask = ~np.eye(n, dtype=bool).flatten()
    y_true = true_adj.flatten()[mask]
    y_scores = discovered_adj.flatten()[mask]

    # ROC-AUC needs both classes (edge / no edge) present in the ground truth.
    if len(np.unique(y_true)) > 1:
        auc_score = roc_auc_score(y_true, y_scores)
        fpr, tpr, _ = roc_curve(y_true, y_scores)
        return auc_score, fpr, tpr
    else:
        return None, None, None


# Calculate ROC-AUC
results = {}
plt.figure(figsize=(10, 6))

for method, G_disc in discovered_networks.items():
    auc_score, fpr, tpr = calculate_roc_auc(A_true, G_disc)

    if auc_score is not None:
        results[method] = {'auc': auc_score, 'fpr': fpr, 'tpr': tpr}
        plt.plot(fpr, tpr, linewidth=2, label=f'{method} (AUC = {auc_score:.3f})')
        print(f'{method} method: ROC-AUC = {auc_score:.3f}')
    else:
        # Report the skip instead of silently omitting the method from the plot
        print(f'{method} method: ROC-AUC undefined (ground truth contains a single class)')

# Diagonal reference line for a random classifier
plt.plot([0, 1], [0, 1], 'k--', alpha=0.5, label='Random (AUC = 0.500)')
plt.xlabel('False Positive Rate', fontweight='bold')
plt.ylabel('True Positive Rate', fontweight='bold')
plt.title('ROC Curves for Histogram-based Causal Discovery', fontweight='bold')
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.show()