# Lab 1: Causal Discovery based Markov Blanket Search

This notebook demonstrates:
1. Loading datasets and visualizing the DAG
2. Introducing Causal Discovery based Markov Blanket Search methods (CD-MB)
   - Evaluating causal discovery performance (SHD, precision, recall, F1)
   - Evaluating Markov Blanket feature selection performance
3. Comparing CD-MB methods

In [None]:
%load_ext watermark
%watermark -a "Muhammed Hunaid Topiwala" -v

%load_ext autoreload
%autoreload 2

In [None]:
import logging
logging.disable(logging.WARNING)

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from tqdm.auto import tqdm
from huggingface_hub import snapshot_download
from datasets import load_dataset

from blanket.datasets import load_data
from blanket.feature_selection import (
    direct_lingam_selector,
    ges_selector,
    notears_selector,
    pc_selector,
)
from blanket.metrics import adjacency_confusion, jaccard_score, reduction_rate, shd
from blanket.plots import plot_adjmat, plot_graph

## 1. Load Dataset and Visualize DAG

In [None]:
# Load dataset
datasets = load_dataset(path="CSE472-blanket-challenge/phase1-dataset", split="train")

datasets.features

In [None]:
example = datasets[70]
X = np.asarray(example["X"])
y = np.asarray(example["y"])
adj_mat = np.asarray(example["adjacency_matrix"])
num_nodes = example["num_nodes"]
density = example["density"]
mb = np.asarray(example["feature_mask"])

In [None]:
# Visualize True DAG and Adjacency Matrix
fig, axes = plt.subplots(1, 2, figsize=(17, 8), constrained_layout=True)

plot_graph(adj_mat, figsize=(5, 5), ax=axes[0], title="True DAG")

plot_adjmat(
    adj_mat,
    title="Adjacency Matrix",
    figsize=(5, 5),
    ax=axes[1],
)

plt.show()

## 2. Causal Discovery based Markov Blanket Search (CD-MB)

The goal is to recover a target variable's Markov blanket (its parents, children, and spouses) by first learning a causal graph and then extracting the local neighborhood.

General Workflow
- Run a causal discovery algorithm on the data.
- Sanitize the learned graph (e.g., remove implausible edges, orient CPDAG edges if needed, ensure acyclicity).
- Extract the target's Markov blanket from the cleaned graph.
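The extraction step can be sketched in a few lines of NumPy. This is a minimal illustration under stated assumptions, not the `blanket` package's implementation. `markov_blanket_from_adjacency` is a hypothetical helper. It assumes that `adj[i, j] != 0` encodes an edge `i -> j` and that undirected CPDAG edges are stored in both directions.

```python
import numpy as np


def markov_blanket_from_adjacency(adj: np.ndarray, target: int) -> np.ndarray:
    """Boolean mask over all nodes marking the Markov blanket of `target`.

    Assumes adj[i, j] != 0 encodes an edge i -> j and that an undirected
    (CPDAG) edge i - j is stored in both directions. An undirected neighbor
    is then treated as both a parent and a child.
    """
    a = np.asarray(adj) != 0
    parents = a[:, target]                  # i -> target
    children = a[target, :]                 # target -> j
    spouses = a[:, children].any(axis=1)    # other parents of those children
    mb = parents | children | spouses
    mb[target] = False                      # a node is not in its own blanket
    return mb


# Toy DAG: 5 -> 0 -> 2 <- 1, 2 -> 3 <- 4
adj = np.zeros((6, 6), dtype=int)
for i, j in [(5, 0), (0, 2), (1, 2), (2, 3), (4, 3)]:
    adj[i, j] = 1

print(np.flatnonzero(markov_blanket_from_adjacency(adj, target=2)))  # [0 1 3 4]
```

Treating an undirected neighbor as both a parent and a child can only enlarge the returned set. The sketch therefore errs on the side of keeping a feature rather than dropping it.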

Pros
- Works with any causal discovery algorithm.
- One learned graph can be reused for multiple target variables.

Cons
- Algorithm assumptions matter (e.g., causal sufficiency, faithfulness, linearity, non‑Gaussianity).
- Many CD methods (e.g., PC and GES) return a CPDAG (a Markov equivalence class of DAGs) rather than a single DAG.
- Many CD methods do not scale well to high‑dimensional data.
- Global structure learning can be wasteful when only a local (MB) neighborhood is needed.

In this tutorial we compare four families of causal discovery methods:
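- PC: constraint-based.
- GES: score-based.
- DirectLiNGAM: functional causal model-based (linear SEM with non-Gaussian noise; it estimates the causal order directly rather than via ICA).
- NOTEARS: continuous optimization-based (linear SEM; acyclicity is enforced through a smooth equality constraint).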

We use the [`gCastle`](https://github.com/huawei-noah/trustworthyAI/blob/master/gcastle) package for the causal discovery algorithms; see its documentation for details.
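Of the four methods, NOTEARS departs most from classical search. It optimizes a weighted adjacency matrix $W$ subject to $h(W) = \operatorname{tr}(e^{W \circ W}) - d = 0$, where $\circ$ is the elementwise product and $d$ is the number of variables. The condition $h(W) = 0$ holds exactly when $W$ encodes a DAG. Below is a minimal sketch of $h$. It assumes SciPy is installed, and `notears_acyclicity` is a hypothetical helper, not part of gCastle.

```python
import numpy as np
from scipy.linalg import expm


def notears_acyclicity(W: np.ndarray) -> float:
    """NOTEARS acyclicity function h(W) = tr(exp(W * W)) - d.

    W * W is the elementwise square, so edge signs do not matter. h(W) is 0
    exactly when W has no directed cycles and positive otherwise.
    """
    d = W.shape[0]
    return float(np.trace(expm(W * W)) - d)


dag = np.array([[0.0, 0.8, 0.0],
                [0.0, 0.0, -1.2],
                [0.0, 0.0, 0.0]])      # 0 -> 1 -> 2
cyclic = np.array([[0.0, 1.0],
                   [1.0, 0.0]])        # 0 -> 1 -> 0, a 2-cycle

print(notears_acyclicity(dag))     # approximately 0
print(notears_acyclicity(cyclic))  # 2 * cosh(1) - 2, about 1.086
```

Because $h$ is smooth, the original NOTEARS paper enforces acyclicity with an augmented Lagrangian and standard gradient-based solvers. This replaces a combinatorial search over graphs.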

In [None]:
pc_feature, pc_adjmat = pc_selector(X, y, alpha=0.05, ci_test="fisherz", variant="stable")
ges_feature, ges_adjmat = ges_selector(X, y, criterion="bic", method="scatter")
direct_lingam_feature, direct_lingam_adjmat = direct_lingam_selector(
    X, y, measure="pwling", thresh=0.3
)
notears_feature, notears_adjmat = notears_selector(X, y, lambda1=0.1, loss_type="l2")
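
The `ci_test="fisherz"` option for PC refers to the Fisher-z conditional independence test, which suits roughly linear-Gaussian data. PC removes the edge between two variables once some conditioning set gives a p-value above `alpha`, meaning independence is not rejected. The sketch below is the textbook form of the test, not gCastle's code. `fisher_z_pvalue` is a hypothetical helper, and SciPy is assumed to be installed.

```python
import numpy as np
from scipy.stats import norm


def fisher_z_pvalue(data: np.ndarray, i: int, j: int, cond: tuple[int, ...] = ()) -> float:
    """Two-sided p-value of the Fisher-z test of X_i independent of X_j given X_cond."""
    n = data.shape[0]
    idx = [i, j, *cond]
    corr = np.corrcoef(data[:, idx], rowvar=False)
    prec = np.linalg.pinv(corr)                             # inverse correlation matrix
    r = -prec[0, 1] / np.sqrt(prec[0, 0] * prec[1, 1])      # partial correlation
    r = float(np.clip(r, -0.999999, 0.999999))              # keep arctanh finite
    stat = np.sqrt(n - len(cond) - 3) * abs(np.arctanh(r))  # Fisher z-statistic
    return float(2 * norm.sf(stat))


# Toy linear-Gaussian chain X0 -> X1 -> X2
rng = np.random.default_rng(0)
n = 2_000
x0 = rng.normal(size=n)
x1 = x0 + rng.normal(size=n)
x2 = x1 + rng.normal(size=n)
data = np.column_stack([x0, x1, x2])

print(fisher_z_pvalue(data, 0, 2))        # tiny: X0 and X2 are dependent
print(fisher_z_pvalue(data, 0, 2, (1,)))  # usually > alpha: independent given X1
```

A larger `alpha` rejects independence more readily, so PC keeps more edges. A smaller `alpha` yields sparser graphs.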

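Some causal discovery methods (e.g., PC and GES) return a CPDAG* (completed partially directed acyclic graph). A CPDAG represents a Markov equivalence class of DAGs and contains undirected edges wherever the orientation cannot be determined from observational data alone (for PC, from conditional independence tests). For Markov blanket extraction, a partially oriented graph is acceptable. We therefore keep undirected edges, stored as bidirectional edges in the adjacency matrix. The minimal Markov blanket is known as the Markov boundary.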

\*: Proposition 11.1, Introduction to Causal Inference
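To see which edges a CPDAG left unoriented, its adjacency matrix can be split into directed and undirected parts. Below is a minimal sketch. It assumes a directed edge `i -> j` has only `adj[i, j]` nonzero, while an undirected edge has both `adj[i, j]` and `adj[j, i]` nonzero. `split_cpdag` is a hypothetical helper. The same encoding explains why a DAG check on such output can fail.

```python
import numpy as np
import networkx as nx


def split_cpdag(adj: np.ndarray) -> tuple[list[tuple[int, int]], list[tuple[int, int]]]:
    """Return (directed, undirected) edge lists of a CPDAG adjacency matrix.

    Assumes a directed edge i -> j has only adj[i, j] != 0, while an
    undirected edge i - j has both adj[i, j] != 0 and adj[j, i] != 0.
    """
    a = np.asarray(adj) != 0
    directed = [(int(i), int(j)) for i, j in zip(*np.nonzero(a & ~a.T))]
    # Keep each undirected pair once by looking only above the diagonal
    undirected = [(int(i), int(j)) for i, j in zip(*np.nonzero(np.triu(a & a.T, k=1)))]
    return directed, undirected


adj = np.zeros((4, 4), dtype=int)
adj[0, 1] = adj[1, 0] = 1   # undirected edge 0 - 1
adj[1, 2] = 1               # 1 -> 2
adj[3, 2] = 1               # 3 -> 2 (v-structure 1 -> 2 <- 3)

directed, undirected = split_cpdag(adj)
print(directed)    # [(1, 2), (3, 2)]
print(undirected)  # [(0, 1)]

# The bidirectional pair forms a 2-cycle, so networkx does not see a DAG
G = nx.from_numpy_array(adj, create_using=nx.DiGraph)
print(nx.is_directed_acyclic_graph(G))  # False
```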

In [None]:
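# If undirected CPDAG edges are stored as both i -> j and j -> i, they form
# 2-cycles, and this check returns False even for a valid CPDAG
G = nx.from_numpy_array(pc_adjmat, create_using=nx.DiGraph)
nx.is_directed_acyclic_graph(G)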

In [None]:
# Compute graph metrics (SHD, precision, recall, F1) and MB metrics for each discovered graph

cdmb_results = {
    "PC": (pc_feature, pc_adjmat),
    "GES": (ges_feature, ges_adjmat),
    "DirectLiNGAM": (direct_lingam_feature, direct_lingam_adjmat),
    "NOTEARS": (notears_feature, notears_adjmat),
}

rows = []
for name, (feature, adj) in cdmb_results.items():
    shd_val = shd(adj_mat, adj)
    precision, recall, f1 = adjacency_confusion(adj_mat, adj)

    mb_jaccard = jaccard_score(mb, feature)
    mb_size = int(np.sum(feature))
    mb_reduction = reduction_rate(feature)

    rows.append(
        {
            "Method": name,
            "SHD": int(shd_val),
            "Precision": float(precision),
            "Recall": float(recall),
            "F1 Score": float(f1),
            "MB Jaccard": float(mb_jaccard),
            "MB Size": mb_size,
            "Reduction Rate": float(mb_reduction),
        }
    )

# Create and display comparison table
comparison_df = pd.DataFrame(rows)
comparison_df.sort_values(by="F1 Score", ascending=False, inplace=False).reset_index(drop=True)


On this example graph, NOTEARS performs best (the table is sorted by F1 score).
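For reference, the graph and MB metrics in the table can be sketched as follows. These are common textbook definitions and are assumptions about what `blanket.metrics` computes. `shd_sketch`, `skeleton_prf`, and `jaccard` are hypothetical helpers. Library implementations differ in details. For example, some SHD variants count a reversed edge as 2, and some precision/recall variants score oriented edges instead of the skeleton.

```python
import numpy as np


def _pairs(adj):
    """Edge indicators for each unordered pair i < j: (i -> j present, j -> i present)."""
    a = np.asarray(adj) != 0
    iu, ju = np.triu_indices(a.shape[0], k=1)
    return a[iu, ju], a[ju, iu]


def shd_sketch(true_adj, est_adj) -> int:
    """Structural Hamming distance: number of node pairs whose edge status differs.

    A pair's status is none, i -> j, j -> i, or undirected, so a reversed
    edge costs 1.
    """
    t_fwd, t_bwd = _pairs(true_adj)
    e_fwd, e_bwd = _pairs(est_adj)
    return int(np.sum((t_fwd != e_fwd) | (t_bwd != e_bwd)))


def skeleton_prf(true_adj, est_adj):
    """Precision, recall, and F1 of the undirected skeleton (orientation ignored)."""
    t_fwd, t_bwd = _pairs(true_adj)
    e_fwd, e_bwd = _pairs(est_adj)
    t_skel, e_skel = t_fwd | t_bwd, e_fwd | e_bwd
    tp = int(np.sum(t_skel & e_skel))
    precision = tp / max(int(e_skel.sum()), 1)
    recall = tp / max(int(t_skel.sum()), 1)
    f1 = 2 * precision * recall / (precision + recall) if tp else 0.0
    return precision, recall, f1


def jaccard(mask_a, mask_b) -> float:
    """Size of intersection over size of union of two boolean masks (1.0 if both empty)."""
    a = np.asarray(mask_a, dtype=bool)
    b = np.asarray(mask_b, dtype=bool)
    union = int(np.sum(a | b))
    return float(np.sum(a & b)) / union if union else 1.0


true_adj = np.array([[0, 1, 0],
                     [0, 0, 1],
                     [0, 0, 0]])   # 0 -> 1 -> 2
est_adj = np.array([[0, 0, 1],
                    [1, 0, 1],
                    [0, 0, 0]])    # 1 -> 0 (reversed), 1 -> 2, 0 -> 2 (extra)

print(shd_sketch(true_adj, est_adj))        # 2
print(skeleton_prf(true_adj, est_adj))      # about (0.667, 1.0, 0.8)
print(jaccard([1, 1, 0, 0], [1, 0, 1, 0]))  # about 0.333
```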

In [None]:
# visualize the results

num_methods = len(cdmb_results)
fig, axes = plt.subplots(2, num_methods + 1, figsize=(num_methods * 8, 12), constrained_layout=True)

plot_graph(
    adj_mat,
    figsize=(5, 5),
    title="True DAG",
    ax=axes[0, 0],
)

plot_adjmat(
    adj_mat,
    title="True Adjacency Matrix",
    figsize=(5, 5),
    ax=axes[1, 0],
)

for i, (name, results) in enumerate(cdmb_results.items(), 1):
    adj = results[1]
    plot_graph(
        adj,
        figsize=(5, 5),
        title=f"{name}",
        ax=axes[0, i],
    )

    plot_adjmat(
        adj,
        title=f"{name}",
        figsize=(5, 5),
        ax=axes[1, i],
    )

plt.show()

## 4. Scale Up

The analysis above uses a single graph. Next, we run all four CD-MB methods on 10 randomly sampled graphs, record each method's runtime, and compare the results.

In [None]:
import time

# Set random seed for reproducibility
np.random.seed(42)

# Sample 10 graphs from the dataset
n_graphs = 10
sample_indices = np.random.choice(len(datasets), size=n_graphs, replace=False)

# Initialize results storage
all_results = []

# Run all CD algos on each sampled graph
for idx, sample_idx in enumerate(tqdm(sample_indices, total=n_graphs, desc="Processing graphs", unit="graph")):

    example = datasets[sample_idx]

    # Load features, target, ground-truth DAG, and oracle MB for this graph
    X = np.asarray(example["X"])
    y = np.asarray(example["y"])
    adj_mat = np.asarray(example["adjacency_matrix"])
    num_nodes = example["num_nodes"]
    num_edges = example["num_edges"]
    density = example["density"]
    oracle_mb = np.asarray(example["feature_mask"])

    # Run each algorithm and record time and results
    algorithms = {
        "PC": lambda: pc_selector(X, y, alpha=0.05, ci_test="fisherz", variant="stable"),
        "GES": lambda: ges_selector(X, y, criterion="bic", method="scatter"),
        "DirectLiNGAM": lambda: direct_lingam_selector(X, y, measure="pwling", thresh=0.3),
        "NOTEARS": lambda: notears_selector(X, y, lambda1=0.1, loss_type="l2"),
    }

    for algo_name, algo_func in algorithms.items():
        # Record runtime with a monotonic, high-resolution timer
        start_time = time.perf_counter()
        try:
            feature, discovered_adj = algo_func()
            runtime = time.perf_counter() - start_time

            # Compute graph discovery metrics
            shd_val = shd(adj_mat, discovered_adj)
            precision, recall, f1 = adjacency_confusion(adj_mat, discovered_adj)

            # Compute MB metrics
            mb_jaccard = jaccard_score(oracle_mb, feature)
            mb_size = int(np.sum(feature))
            mb_reduction = reduction_rate(feature)
            oracle_mb_size = int(np.sum(oracle_mb))

            # Store results
            all_results.append({
                "Graph Index": sample_idx,
                "Graph ID": idx + 1,
                "Algorithm": algo_name,
                "Runtime (s)": runtime,
                "Num Nodes": num_nodes,
                "Num Edges": num_edges,
                "Density": density,
                "Oracle MB Size": oracle_mb_size,
                "SHD": int(shd_val),
                "Precision": precision,
                "Recall": recall,
                "F1 Score": f1,
                "MB Jaccard": mb_jaccard,
                "MB Size": mb_size,
                "Reduction Rate": mb_reduction,
            })
        except Exception as e:
            print(f"  Error in {algo_name}: {str(e)}")
            continue

# Create comparison DataFrame
results_df = pd.DataFrame(all_results)
results_df.head(10)

In [None]:
# Summary by algorithm
summary_by_algo = results_df.groupby("Algorithm").agg({
    "Runtime (s)": ["mean", "std", "min", "max"],
    "SHD": ["mean", "std"],
    "Precision": ["mean", "std"],
    "Recall": ["mean", "std"],
    "F1 Score": ["mean", "std"],
    "MB Jaccard": ["mean", "std"],
    "Reduction Rate": ["mean", "std"],
}).round(4)

summary_by_algo

## 4.1 Compare Algorithm Performance

In [None]:
# Summary Table: Runtime Performance
print("\n" + "="*100)
print("SUMMARY: RUNTIME PERFORMANCE BY ALGORITHM")
print("="*100 + "\n")

runtime_summary = results_df.groupby("Algorithm")["Runtime (s)"].agg(["count", "mean", "std", "min", "max"]).round(4)
runtime_summary.columns = ["Count", "Mean (s)", "Std (s)", "Min (s)", "Max (s)"]
runtime_summary

In [None]:
# Define color palette for algorithms
algo_colors = {
    "PC": "#1f77b4",
    "GES": "#ff7f0e",
    "DirectLiNGAM": "#2ca02c",
    "NOTEARS": "#d62728"
}

# Create figure with subplots for Runtime vs Graph Properties
fig, axes = plt.subplots(1, 2, figsize=(16, 5), constrained_layout=True)

# Plot 1: Runtime vs Num Nodes
for algo in results_df["Algorithm"].unique():
    algo_data = results_df[results_df["Algorithm"] == algo]
    axes[0].scatter(algo_data["Num Nodes"], algo_data["Runtime (s)"],
                   label=algo, s=100, alpha=0.7, color=algo_colors[algo])

axes[0].set_xlabel("Number of Nodes", fontsize=12)
axes[0].set_ylabel("Runtime (seconds)", fontsize=12)
axes[0].set_title("Runtime vs Number of Nodes", fontsize=13, fontweight="bold")
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)

# Plot 2: Runtime vs Density
for algo in results_df["Algorithm"].unique():
    algo_data = results_df[results_df["Algorithm"] == algo]
    axes[1].scatter(algo_data["Density"], algo_data["Runtime (s)"],
                   label=algo, s=100, alpha=0.7, color=algo_colors[algo])

axes[1].set_xlabel("Graph Density", fontsize=12)
axes[1].set_ylabel("Runtime (seconds)", fontsize=12)
axes[1].set_title("Runtime vs Graph Density", fontsize=13, fontweight="bold")
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)

plt.suptitle("1. Runtime Performance Analysis", fontsize=14, fontweight="bold", y=1.02)
plt.show()


In [None]:
# Summary Table: Graph Discovery Performance (F1 Score)
print("\n" + "="*100)
print("SUMMARY: GRAPH DISCOVERY PERFORMANCE (F1 SCORE) BY ALGORITHM")
print("="*100 + "\n")

f1_summary = results_df.groupby("Algorithm")[["SHD", "Precision", "Recall", "F1 Score"]].agg(["mean", "std"]).round(4)
f1_summary

Note that I did not optimize hyperparameters for each method; better performance may be possible with tuning.

The standard deviation of the F1 score is large relative to its mean. This suggests that a single hyperparameter setting does not fit all graphs well.

In [None]:
# Create figure with subplots for F1 Score vs Graph Properties
fig, axes = plt.subplots(1, 2, figsize=(16, 5), constrained_layout=True)

# Plot 1: F1 Score vs Num Nodes
for algo in results_df["Algorithm"].unique():
    algo_data = results_df[results_df["Algorithm"] == algo]
    axes[0].scatter(algo_data["Num Nodes"], algo_data["F1 Score"],
                   label=algo, s=100, alpha=0.7, color=algo_colors[algo])

axes[0].set_xlabel("Number of Nodes", fontsize=12)
axes[0].set_ylabel("F1 Score", fontsize=12)
axes[0].set_title("F1 Score vs Number of Nodes", fontsize=13, fontweight="bold")
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim([0, 1.05])

# Plot 2: F1 Score vs Density
for algo in results_df["Algorithm"].unique():
    algo_data = results_df[results_df["Algorithm"] == algo]
    axes[1].scatter(algo_data["Density"], algo_data["F1 Score"],
                   label=algo, s=100, alpha=0.7, color=algo_colors[algo])

axes[1].set_xlabel("Graph Density", fontsize=12)
axes[1].set_ylabel("F1 Score", fontsize=12)
axes[1].set_title("F1 Score vs Graph Density", fontsize=13, fontweight="bold")
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0, 1.05])

plt.show()


In [None]:
# Summary Table: Markov Blanket Feature Selection Performance
print("\n" + "="*100)
print("SUMMARY: MARKOV BLANKET FEATURE SELECTION PERFORMANCE BY ALGORITHM")
print("="*100 + "\n")

mb_summary = results_df.groupby("Algorithm")[["MB Jaccard", "MB Size", "Reduction Rate"]].agg(["mean", "std"]).round(4)
mb_summary

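Markov blanket recovery (MB Jaccard) largely tracks graph discovery performance: methods that recover the graph more accurately also tend to recover the MB more accurately.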

In [None]:
# Create figure with subplots for Jaccard Similarity vs Graph Properties
fig, axes = plt.subplots(1, 2, figsize=(16, 5), constrained_layout=True)

# Plot 1: MB Jaccard vs Num Nodes
for algo in results_df["Algorithm"].unique():
    algo_data = results_df[results_df["Algorithm"] == algo]
    axes[0].scatter(algo_data["Num Nodes"], algo_data["MB Jaccard"],
                   label=algo, s=100, alpha=0.7, color=algo_colors[algo])

axes[0].set_xlabel("Number of Nodes", fontsize=12)
axes[0].set_ylabel("Markov Blanket Jaccard Similarity", fontsize=12)
axes[0].set_title("MB Jaccard Similarity vs Number of Nodes", fontsize=13, fontweight="bold")
axes[0].legend(fontsize=10)
axes[0].grid(True, alpha=0.3)
axes[0].set_ylim([0, 1.05])

# Plot 2: MB Jaccard vs Density
for algo in results_df["Algorithm"].unique():
    algo_data = results_df[results_df["Algorithm"] == algo]
    axes[1].scatter(algo_data["Density"], algo_data["MB Jaccard"],
                   label=algo, s=100, alpha=0.7, color=algo_colors[algo])

axes[1].set_xlabel("Graph Density", fontsize=12)
axes[1].set_ylabel("Markov Blanket Jaccard Similarity", fontsize=12)
axes[1].set_title("MB Jaccard Similarity vs Graph Density", fontsize=13, fontweight="bold")
axes[1].legend(fontsize=10)
axes[1].grid(True, alpha=0.3)
axes[1].set_ylim([0, 1.05])

plt.show()


## 4.2 Markov Blanket Analysis: Relationship between MB Size, Density, and Graph Structure

By definition, the Markov blanket $MB(Y)$ of $Y$ shields $Y$ from all remaining variables:

$$P(Y \mid MB(Y), Z) = P(Y \mid MB(Y)),$$

where $Z$ is any subset of $X \setminus (MB(Y) \cup \{Y\})$.

It follows that the set of all other variables, $X \setminus \{Y\}$, is a trivial Markov blanket of $Y$.

Therefore, if even the minimal MB (the Markov boundary) is nearly as large as the full feature set, MB search offers little benefit.
In the following, for each graph, we iterate over all nodes and compute their MBs to understand how MB size relates to graph density and structure.
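The shielding property can be checked numerically. In a linear-Gaussian model, $P(Y \mid \cdot)$ depends on a conditioning variable only through a nonzero regression coefficient. Regressing $Y$ on its blanket plus some non-blanket variables should therefore give near-zero coefficients for the latter. The toy graph $E \to A \to B \to Y \to C \leftarrow D$, with $MB(Y) = \{B, C, D\}$, and its coefficients are made up for illustration.

```python
import numpy as np

# Toy linear-Gaussian SEM (made up for illustration):
#   E -> A -> B -> Y -> C <- D,  so MB(Y) = {B, C, D}
rng = np.random.default_rng(0)
n = 20_000
E = rng.normal(size=n)
A = 0.8 * E + rng.normal(size=n)
B = 0.9 * A + rng.normal(size=n)
D = rng.normal(size=n)
Y = 0.7 * B + rng.normal(size=n)
C = 0.6 * Y + 0.5 * D + rng.normal(size=n)

# Regress Y on its blanket {B, C, D} plus the non-blanket variables {A, E}
names = ["B", "C", "D", "A", "E"]
design = np.column_stack([B, C, D, A, E])
coef, *_ = np.linalg.lstsq(design, Y, rcond=None)
coefs = dict(zip(names, coef))

for name in names:
    print(f"{name}: {coefs[name]:+.3f}")
```

For this toy model the population regression is $\mathbb{E}[Y \mid B, C, D] = (0.7B + 0.6C - 0.3D)/1.36$. The estimates for B, C, and D should therefore land near 0.515, 0.441, and -0.221, while those for A and E stay close to 0.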

In [None]:
from blanket.graph import get_markov_blanket

# Compute MB for all nodes in each graph
mb_analysis_data = []

for graph_idx, sample_idx in enumerate(sample_indices):
    example = datasets[sample_idx]
    adj_mat = np.asarray(example["adjacency_matrix"])
    num_nodes = example["num_nodes"]
    num_edges = example["num_edges"]
    density = example["density"]

    # Compute MB for each node
    for node_id in range(num_nodes):
        mb = get_markov_blanket(adj_mat, node_id)
        mb_size = sum(mb)
        mb_ratio = mb_size / (num_nodes - 1)

        mb_analysis_data.append({
            "Graph Index": sample_idx,
            "Graph ID": graph_idx + 1,
            "Node ID": node_id,
            "Num Nodes": num_nodes,
            "Num Edges": num_edges,
            "Density": density,
            "MB Size": mb_size,
            "MB Ratio": mb_ratio,
        })

mb_analysis_df = pd.DataFrame(mb_analysis_data)


In [None]:
mb_analysis_df.groupby("Graph ID")[["MB Ratio", "Density", "Num Nodes"]].agg(["mean", "std"]).round(4)

1. When density exceeds 0.2, the average MB ratio is above 0.7. Note that this average already includes peripheral nodes with small MBs, such as root nodes. In practice, target variables are usually meaningful variables rather than such peripheral nodes, so the average MB ratio for realistic targets could be even higher.
2. Consequently, in many cases identifying the MB is not particularly informative, because it removes only a small fraction of the features.
3. It may seem counterintuitive, but densities of 0.1–0.2 are not sparse. Truly sparse graphs usually have densities below 0.001, which generally requires a large number of nodes (>1000). We leave this setting for future work.
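The arithmetic behind point 3 can be made concrete. Assume density is defined as $|E| / \binom{n}{2}$, the maximum number of edges in a DAG. Then the average node degree (in plus out) is $2|E|/n = \text{density} \times (n - 1)$. If the dataset instead divides by $n(n-1)$, the degrees below double. The helper names are hypothetical.

```python
def avg_degree(density: float, num_nodes: int) -> float:
    """Average (in + out) degree, assuming density = num_edges / (n * (n - 1) / 2)."""
    num_edges = density * num_nodes * (num_nodes - 1) / 2
    return 2 * num_edges / num_nodes


def nodes_for_degree(density: float, degree: float) -> float:
    """Number of nodes n such that density * (n - 1) equals the given average degree."""
    return degree / density + 1


for density, n in [(0.1, 20), (0.2, 20), (0.2, 50)]:
    print(f"density={density}, n={n}: average degree = {avg_degree(density, n):.1f}")

# Nodes needed to keep an average degree of 2 at density 0.001
print(f"{nodes_for_degree(0.001, 2):.0f}")  # 2001
```

Under this assumption, a 20-node graph with density 0.2 already gives each node about 3.8 neighbors on average. Keeping an average degree of about 2 at density 0.001 requires roughly 2,000 nodes.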

In [None]:
# Aggregate MB Ratio by Graph ID
aggregated_mb_df = mb_analysis_df.groupby("Graph ID").agg({
    "MB Ratio": "mean",
    "Density": "mean",
    "Num Nodes": "mean"
}).reset_index()

# Visualization: Aggregated MB Ratio vs Graph Properties (colored by Graph ID)
fig, axes = plt.subplots(1, 2, figsize=(16, 5), constrained_layout=True)

# Plot 1: Aggregated MB Ratio vs Density with Graph ID coloring
scatter1 = axes[0].scatter(aggregated_mb_df["Density"], aggregated_mb_df["MB Ratio"],
                           c=aggregated_mb_df["Graph ID"], cmap="tab10",
                           s=100, alpha=0.8, edgecolors="black", linewidth=0.5)
axes[0].set_xlabel("Graph Density", fontsize=12, fontweight="bold")
axes[0].set_ylabel("MB Ratio", fontsize=12, fontweight="bold")
axes[0].set_title("Aggregated MB Ratio vs Graph Density", fontsize=13, fontweight="bold")
axes[0].grid(True, alpha=0.3)
cbar1 = plt.colorbar(scatter1, ax=axes[0])
cbar1.set_label("Graph ID", fontsize=11)

# Plot 2: Aggregated MB Ratio vs Num Nodes with Graph ID coloring
scatter2 = axes[1].scatter(aggregated_mb_df["Num Nodes"], aggregated_mb_df["MB Ratio"],
                           c=aggregated_mb_df["Graph ID"], cmap="tab10",
                           s=100, alpha=0.8, edgecolors="black", linewidth=0.5)
axes[1].set_xlabel("Number of Nodes", fontsize=12, fontweight="bold")
axes[1].set_ylabel("MB Ratio", fontsize=12, fontweight="bold")
axes[1].set_title("Aggregated MB Ratio vs Number of Nodes", fontsize=13, fontweight="bold")
axes[1].grid(True, alpha=0.3)
cbar2 = plt.colorbar(scatter2, ax=axes[1])
cbar2.set_label("Graph ID", fontsize=11)

plt.show()

## Bonus

1. [5 points] Implement a specialized causal feature selection method (IAMB).
   - [bnlearn](https://github.com/cran/bnlearn) implements IAMB and Fast-IAMB. Use [rpy2](https://rpy2.github.io/) to call it from Python, or write an R script to run it (if you are familiar with R).
   - [py-tetrad](https://github.com/cmu-phil/py-tetrad) also provides IAMB implementations.
2. Does a dedicated causal feature selection method (IAMB) work better than the causal discovery methods above?