# DPGExplainer Saga Benchmarks — Episode 1: Iris

A practitioner-friendly walkthrough of Decision Predicate Graphs (DPG) on the classic Iris dataset. We train a small Random Forest (RF), build a DPG to map the model’s global behavior as an Explainable AI (XAI) tool, and interpret three key properties that explain the model: Local Reaching Centrality (LRC), Betweenness Centrality (BC), and node communities.


## 1. What is Explainable AI (XAI)?
Explainable AI (XAI) focuses on making model behavior understandable to people. It helps answer questions like why a prediction was made, what features mattered most, and whether the model behaves as intended.

Common motivations for XAI include:
- Explain to justify: Provide evidence for decisions in high-stakes contexts.
- Explain to discover: Surface patterns, biases, or unexpected signals in the data.
- Explain to improve: Debug models, features, and data issues.
- Explain to control: Support monitoring, governance, and compliance.

XAI methods are often grouped into:
- Global explanations: Summarize how the model behaves overall.
- Local explanations: Explain a single prediction or a small region of the feature space.

SHAP is a popular local method, while DPG provides a global view by turning an ensemble into a predicate graph and analyzing its structure.


## 2. Why DPG (in one minute)
Tree ensembles, such as RF, can be accurate but hard to interpret globally. DPG converts the ensemble into a graph where:
- Nodes are predicates such as `petal length (cm) <= 2.45` (Iris column names include units).
- Edges connect consecutive predicates along decision paths and are weighted by how often samples traverse them.
- Graph metrics quantify how predicates structure the model’s global reasoning.

This gives a global map of decision logic and allows the use of graph metrics to capture the model’s rationale.
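To make the node/edge idea concrete, here is a minimal sketch for a single scikit-learn decision tree. Each split node plus the branch taken becomes a predicate label, and consecutive labels along every sample's path become edges counted by how many samples traverse them. This is an illustration under our own assumptions (one tree, `max_depth=2`, labels formatted as `feature op threshold`, one `class <name>` node per leaf), not the DPG library's implementation, which aggregates the whole ensemble and computes its own metrics.

```python
from collections import Counter

from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier

iris = load_iris()
tree = DecisionTreeClassifier(max_depth=2, random_state=0).fit(iris.data, iris.target)
t = tree.tree_


def predicate(node_id, goes_left):
    """Turn a split node plus the branch taken into a readable predicate label."""
    name = iris.feature_names[t.feature[node_id]]
    op = "<=" if goes_left else ">"
    return f"{name} {op} {t.threshold[node_id]:.2f}"


edge_counts = Counter()
paths = tree.decision_path(iris.data)  # sparse (n_samples, n_nodes) indicator matrix
for i in range(paths.shape[0]):
    # Node ids visited by sample i, root first
    nodes = paths.indices[paths.indptr[i]:paths.indptr[i + 1]]
    labels = [predicate(p, c == t.children_left[p]) for p, c in zip(nodes[:-1], nodes[1:])]
    # The leaf's majority class becomes the final node of the path
    leaf_class = iris.target_names[t.value[nodes[-1]].argmax()]
    labels.append(f"class {leaf_class}")
    # Consecutive labels become weighted edges of the predicate graph
    for src, dst in zip(labels[:-1], labels[1:]):
        edge_counts[(src, dst)] += 1

for (src, dst), n in edge_counts.most_common():
    print(f"{src} -> {dst}: {n}")
```

Every sample contributes exactly one edge into a class node, so those edge counts sum to the number of samples. DPG builds its graph from the paths of every tree in the forest rather than from a single tree.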

In the next steps, we train a Random Forest on the Iris dataset and explain it with DPG.


## 3. Setup (Iris + Random Forest + DPG)

We first train a baseline Random Forest, then inspect pairwise feature/class structure with a pair plot.


In [None]:
%pip install --force-reinstall --no-deps git+https://github.com/Meta-Group/DPG.git

In [None]:
import numpy as np
import pandas as pd
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import ConfusionMatrixDisplay, classification_report, confusion_matrix
from sklearn.model_selection import train_test_split
from dpg import DPGExplainer

iris = load_iris(as_frame=True)
X = iris.data
y = iris.target

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=27, stratify=y
)

model = RandomForestClassifier(n_estimators=10, random_state=27)
model.fit(X_train, y_train)

y_pred = model.predict(X_test)
cm = confusion_matrix(y_test, y_pred)
ConfusionMatrixDisplay(cm, display_labels=iris.target_names).plot()
print(classification_report(y_test, y_pred, target_names=iris.target_names))


### 3.1 Pair Plot (All Features by Class)


In [None]:
import matplotlib.pyplot as plt  # needed for plt.suptitle/plt.show below
import seaborn as sns

pair_df = X.copy()
pair_df["class_name"] = y.map(lambda i: iris.target_names[i])

sns.pairplot(
    pair_df,
    hue="class_name",
    diag_kind="kde",
    corner=True,
    plot_kws={"alpha": 0.6, "s": 28, "edgecolor": "none"},
)
plt.suptitle("Iris pair plot by class", y=1.02)
plt.show()
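The pair plot suggests that setosa separates cleanly on petal measurements. A quick numeric check of the example predicate from Section 2 shows which species fall on each side of it:

```python
from sklearn.datasets import load_iris

iris = load_iris(as_frame=True)
X, y = iris.data, iris.target

# Example predicate from Section 2: petal length (cm) <= 2.45
mask = X["petal length (cm)"] <= 2.45
species = y.map(lambda i: iris.target_names[i])

print(species[mask].value_counts())   # species that satisfy the predicate
print(species[~mask].value_counts())  # species that fail it
```

On Iris, all 50 setosa samples satisfy this predicate and no other sample does, which is why a single split like this can isolate setosa in one step.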



## 4. Extracting DPG from RF

Next, we extract the DPG from our RF model. The `feature_names` and `target_names` parameters make the graph human-readable, so predicates and class nodes carry the Iris feature and species names instead of bare indices.
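For intuition about what readable names buy you, scikit-learn's own `export_text` shows the same effect on a single tree: without names, splits are labeled `feature_0` to `feature_3`; with them, they use the Iris column names such as `petal length (cm)`. We assume DPGExplainer's `feature_names` and `target_names` serve an analogous purpose; the one-split stump below is only an illustration.

```python
from sklearn.datasets import load_iris
from sklearn.tree import DecisionTreeClassifier, export_text

iris = load_iris()
stump = DecisionTreeClassifier(max_depth=1, random_state=0).fit(iris.data, iris.target)

raw = export_text(stump)  # splits labeled feature_0 ... feature_3
named = export_text(stump, feature_names=list(iris.feature_names))  # readable names

print(raw)
print(named)
```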


In [None]:
explainer = DPGExplainer(
    model=model,
    feature_names=X.columns,
    target_names=iris.target_names.tolist(),
    config_file="config.yaml",  # optional if present
)

explanation = explainer.explain_global(
    X.values,
    communities=True,
    community_threshold=0.2,
)


## 5. Read the DPG Metrics


In [None]:
explanation.node_metrics.head()


**Local Reaching Centrality (LRC)**
- High LRC nodes can reach many other nodes downstream.
- These predicates often act early, framing large portions of the model’s logic.

**Betweenness Centrality (BC)**
- High BC nodes lie on many shortest paths between other nodes.
- These predicates are “bottlenecks” that connect major decision flows.
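To see how the two metrics differ, here is a small hand-built directed graph (hypothetical, not produced by DPG) scored with NetworkX's `local_reaching_centrality` and `betweenness_centrality`. We assume unweighted edges here; DPG's own metric computation may use edge weights.

```python
import networkx as nx

# Hypothetical predicate graph: two trees enter through different root
# predicates and then share "petal width (cm) <= 1.75".
G = nx.DiGraph()
G.add_edges_from([
    ("petal length (cm) > 2.45", "petal width (cm) <= 1.75"),
    ("petal width (cm) > 0.8", "petal width (cm) <= 1.75"),
    ("petal width (cm) <= 1.75", "class versicolor"),
    ("petal width (cm) <= 1.75", "petal length (cm) > 4.95"),
    ("petal length (cm) > 4.95", "class virginica"),
])

# Unweighted LRC on a directed graph: share of other nodes reachable from v
lrc = {v: nx.local_reaching_centrality(G, v) for v in G}
# Betweenness: fraction of shortest s-t paths through v, summed over pairs
# and normalized by (n - 1) * (n - 2)
bc = nx.betweenness_centrality(G)

for v in G:
    print(f"{v:28s} LRC={lrc[v]:.2f}  BC={bc[v]:.2f}")
```

The two entry predicates get the highest LRC (0.8, reaching four of the five other nodes), while the shared predicate gets the highest BC because every path from either entry to a class passes through it.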


## 6. Compare Top LRC Predicates vs Random Forest Importance


In [None]:
import matplotlib.pyplot as plt
import re

def parse_predicate_parts(label: str):
    m = re.search(r"(.+?)\s*(<=|>)\s*([-+]?\d*\.?\d+(?:[eE][-+]?\d+)?)", str(label))
    if not m:
        return None
    feature = m.group(1).strip()
    op = m.group(2)
    threshold = float(m.group(3))
    return feature, op, threshold


def parse_feature_from_predicate(label: str) -> str:
    parts = parse_predicate_parts(label)
    return parts[0] if parts else str(label)


def lrc_predicate_scores(explanation, top_k=10):
    nm = explanation.node_metrics.copy()
    nm = nm[nm["Label"].str.contains(r"(<=|>)", regex=True, na=False)].copy()
    nm = nm.sort_values("Local reaching centrality", ascending=False).head(top_k)

    rows = []
    for _, r in nm.iterrows():
        parsed = parse_predicate_parts(r["Label"])
        if not parsed:
            continue
        feature, op, threshold = parsed
        rows.append({
            "predicate": str(r["Label"]),
            "feature": feature,
            "op": op,
            "threshold": threshold,
            "lrc": float(r["Local reaching centrality"]),
        })

    return pd.DataFrame(rows)


def _feature_color_map(features):
    unique = list(dict.fromkeys(features))
    cmap = plt.cm.tab20
    if len(unique) <= 1:
        return {unique[0]: cmap(0)} if unique else {}
    return {f: cmap(i / (len(unique) - 1)) for i, f in enumerate(unique)}


def plot_lrc_vs_rf_importance(explanation, model, X_df, top_k=10, dataset_name='Iris'):
    top_lrc = lrc_predicate_scores(explanation, top_k=top_k).copy()

    top_rf = (
        pd.DataFrame({
            "feature": list(getattr(model, "feature_names_in_", X_df.columns)),
            "rf_importance": model.feature_importances_.astype(float),
        })
        .sort_values("rf_importance", ascending=False)
        .head(top_k)
    )

    # Sort ascending for readable horizontal bars
    top_lrc_plot = top_lrc.sort_values("lrc", ascending=True)
    top_rf_plot = top_rf.sort_values("rf_importance", ascending=True)

    # Keep feature colors consistent across both plots
    all_features = top_lrc_plot["feature"].tolist() + top_rf_plot["feature"].tolist()
    feature_to_color = _feature_color_map(all_features)

    fig, axes = plt.subplots(1, 2, figsize=(16, max(5, top_k * 0.45)))

    axes[0].barh(
        top_lrc_plot["predicate"],
        top_lrc_plot["lrc"],
        color=[feature_to_color[f] for f in top_lrc_plot["feature"]],
        edgecolor="black",
        linewidth=0.4,
    )
    axes[0].set_title(f"{dataset_name}: Top {top_k} LRC predicates")
    axes[0].set_xlabel("Local Reaching Centrality")
    axes[0].set_ylabel("Predicate")

    axes[1].barh(
        top_rf_plot["feature"],
        top_rf_plot["rf_importance"],
        color=[feature_to_color[f] for f in top_rf_plot["feature"]],
        edgecolor="black",
        linewidth=0.4,
    )
    axes[1].set_title(f"{dataset_name}: Top {top_k} RF feature importances")
    axes[1].set_xlabel("Random Forest feature importance")
    axes[1].set_ylabel("Feature")

    legend_features = list(dict.fromkeys(all_features))
    legend_handles = [
        plt.Line2D([0], [0], marker='s', color='w', label=f,
                   markerfacecolor=feature_to_color[f], markeredgecolor='black', markersize=8)
        for f in legend_features
    ]
    fig.legend(handles=legend_handles, title="Feature colors",
               loc="lower center", ncol=min(4, max(1, len(legend_handles))), frameon=True)

    plt.tight_layout(rect=(0, 0.08, 1, 1))
    plt.show()


def plot_top_lrc_predicate_splits(explanation, X_df, y, top_predicates=5, top_features=2, dataset_name='Iris'):
    top_lrc = lrc_predicate_scores(explanation, top_k=max(top_predicates, 10)).copy()
    top5 = top_lrc.sort_values("lrc", ascending=False).head(top_predicates).copy()

    # Select top-2 LRC features using cumulative LRC contribution
    feature_rank = (
        top_lrc.groupby("feature", as_index=False)["lrc"].sum()
        .sort_values("lrc", ascending=False)
        .head(top_features)
    )
    selected_features = feature_rank["feature"].tolist()
    if len(selected_features) < 2:
        print(f"{dataset_name}: not enough LRC features to build a 2D split plot.")
        return

    fx, fy = selected_features[0], selected_features[1]
    if fx not in X_df.columns or fy not in X_df.columns:
        print(f"{dataset_name}: selected LRC features not present in input dataframe columns.")
        return

    # Lines from top-5 predicates only, restricted to top-2 selected features
    split_rows = top5[top5["feature"].isin([fx, fy])].copy()

    fig, ax = plt.subplots(figsize=(8, 6))

    sc = ax.scatter(
        X_df[fx],
        X_df[fy],
        c=y,
        cmap='viridis',
        s=36,
        alpha=0.75,
        edgecolor='white',
        linewidth=0.5,
    )

    # Color by feature to keep consistency with section colors
    feature_to_color = _feature_color_map([fx, fy])

    line_labels_seen = set()
    for _, r in split_rows.iterrows():
        f, op, thr, score = r["feature"], r["op"], r["threshold"], r["lrc"]
        if f == fx:
            ls = '--' if op == '<=' else '-'
            label = f"{f} {op} {thr:.2f} (LRC={score:.3f})"
            ax.axvline(
                thr,
                color=feature_to_color[f],
                linestyle=ls,
                linewidth=2,
                alpha=0.9,
                label=label if label not in line_labels_seen else None,
            )
            line_labels_seen.add(label)
        elif f == fy:
            ls = '--' if op == '<=' else '-'
            label = f"{f} {op} {thr:.2f} (LRC={score:.3f})"
            ax.axhline(
                thr,
                color=feature_to_color[f],
                linestyle=ls,
                linewidth=2,
                alpha=0.9,
                label=label if label not in line_labels_seen else None,
            )
            line_labels_seen.add(label)

    ax.set_title(
        f"{dataset_name}: Top-{top_predicates} LRC predicate splits on top-2 LRC features"
    )
    ax.set_xlabel(fx)
    ax.set_ylabel(fy)

    # class legend + predicate legend
    cbar = fig.colorbar(sc, ax=ax, fraction=0.046, pad=0.04)
    cbar.set_label('Class id')

    handles, labels = ax.get_legend_handles_labels()
    if handles:
        ax.legend(handles, labels, title='Top LRC predicate lines', loc='best', fontsize=8)

    plt.tight_layout()
    plt.show()


plot_lrc_vs_rf_importance(explanation, model, X, top_k=10, dataset_name='Iris')
plot_top_lrc_predicate_splits(explanation, X, y, top_predicates=5, top_features=2, dataset_name='Iris')



### Optional: inspect top-10 LRC and RF tables


In [None]:
top_lrc = lrc_predicate_scores(explanation, top_k=10)
top_rf = (
    pd.DataFrame({
        "feature": list(getattr(model, "feature_names_in_", X.columns)),
        "rf_importance": model.feature_importances_.astype(float),
    })
    .sort_values("rf_importance", ascending=False)
    .head(10)
)

display(top_lrc)  # only a cell's last bare expression renders, so display both explicitly
display(top_rf)



Interpretation guide:
- If a predicate has **high LRC**, it likely sets an early rule that shapes many later decisions.
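- If a feature has **high RF importance**, it accounts for a large share of the impurity reduction across the forest’s splits (scikit-learn’s `feature_importances_` is the mean decrease in impurity).
- Compare overlap at the feature level, since LRC scores individual predicates while RF importance scores whole features: when high-LRC predicates fall on high-importance features, the global graph and the model-level importance tell a consistent story.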
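One way to make the overlap comparison concrete is to aggregate LRC per feature (summing over that feature’s predicates, as the Section 6 split plot does) and compare the top-k feature sets with a Jaccard score. `feature_overlap` is a hypothetical helper that expects DataFrames shaped like the `top_lrc` and `top_rf` tables above.

```python
import pandas as pd


def feature_overlap(top_lrc: pd.DataFrame, top_rf: pd.DataFrame, k: int = 2) -> dict:
    """Compare the k features with the highest summed LRC to the k top RF features."""
    # Sum LRC over all predicates that test the same feature
    lrc_rank = top_lrc.groupby("feature")["lrc"].sum().sort_values(ascending=False)
    lrc_feats = set(lrc_rank.head(k).index)
    rf_feats = set(top_rf.sort_values("rf_importance", ascending=False)["feature"].head(k))
    shared = lrc_feats & rf_feats
    union = lrc_feats | rf_feats
    return {
        "lrc_features": sorted(lrc_feats),
        "rf_features": sorted(rf_feats),
        "shared": sorted(shared),
        "jaccard": len(shared) / len(union) if union else 0.0,
    }
```

For example, `feature_overlap(top_lrc, top_rf, k=2)` compares the two strongest features from each view; a Jaccard score of 1.0 means both views pick the same features.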

### BC analysis
This notebook includes a BC bottleneck cloud in PCA space (Section 7).
A full BC ranking-focused analysis will be covered in a separate notebook.


## 7. Show BC Bottleneck Cloud in PCA Space
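The BC cloud scores each sample by how many of the top-10 BC predicates it satisfies, rescales the scores by their maximum, and uses them as weights for a density plot in PCA space. A worked example of that weighting on three toy rows and two hypothetical predicates (both chosen for illustration only):

```python
import numpy as np
import pandas as pd

# Three toy rows (illustrative values) and two hypothetical top-BC predicates
X_toy = pd.DataFrame({
    "petal length (cm)": [1.4, 4.7, 6.0],
    "petal width (cm)": [0.2, 1.4, 2.1],
})
top_bc_predicates = [("petal length (cm)", ">", 2.45), ("petal width (cm)", "<=", 1.75)]

weights = np.zeros(len(X_toy))
for feature, op, threshold in top_bc_predicates:
    vals = X_toy[feature].to_numpy()
    # Each satisfied predicate adds 1 to the sample's score
    weights += (vals <= threshold) if op == "<=" else (vals > threshold)

# Raw counts are [1, 2, 1]; rescaling makes the best-covered sample weigh 1.0
weights = weights / weights.max()
print(weights)
```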


In [None]:
import seaborn as sns
from sklearn.decomposition import PCA

def bc_weights_from_explanation(explanation, X_df, top_k=10):
    nm = explanation.node_metrics.copy()
    nm = nm[nm["Label"].str.contains(r"(<=|>)", regex=True, na=False)].copy()
    top_bc = nm.sort_values("Betweenness centrality", ascending=False).head(top_k)

    weights = np.zeros(len(X_df), dtype=float)

    for _, row in top_bc.iterrows():
        # Reuse the Section 6 parser: returns (feature, op, threshold) or None
        parsed = parse_predicate_parts(row["Label"])
        if not parsed:
            continue
        feature, op, threshold = parsed
        if feature not in X_df.columns:
            continue

        vals = X_df[feature].values
        vals = np.where(np.isfinite(vals), vals, np.nan)
        if op == '<=':
            weights += (vals <= threshold)
        else:
            weights += (vals > threshold)

    if weights.max() > 0:
        weights = weights / weights.max()
    return weights


def pca_kde_plot(X_df, y, weights, title):
    X_clean = X_df.replace([np.inf, -np.inf], np.nan)
    valid_mask = ~X_clean.isna().any(axis=1)
    X_valid = X_clean[valid_mask]
    y_valid = y[valid_mask]
    w_valid = weights[valid_mask]

    pca = PCA(n_components=2, random_state=27)
    X_pca = pca.fit_transform(X_valid)

    fig, ax = plt.subplots(1, 1, figsize=(7, 5), facecolor='white')
    ax.set_facecolor('#f6d6d6')

    # Weighted KDE: sample density, weighted by how many top-BC predicates each sample satisfies.
    # 'turbo' runs blue (low) to red (high), matching the colorbar label.
    sns.kdeplot(
        x=X_pca[:, 0],
        y=X_pca[:, 1],
        weights=w_valid,
        fill=True,
        levels=25,
        cmap='turbo',
        alpha=0.9,
        thresh=0.0,
        bw_adjust=1.15,
        cbar=True,
        cbar_kws={
            'label': 'BC-weighted density (red = higher, blue = lower)',
            'fraction': 0.046,
            'pad': 0.04,
        },
        ax=ax,
    )

    ax.scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        c=y_valid,
        cmap='viridis',
        s=22,
        alpha=0.5,
        edgecolor='k',
        linewidth=0.4,
    )

    ax.set_title(title)
    ax.set_xlabel('PCA 1')
    ax.set_ylabel('PCA 2')
    plt.tight_layout()
    plt.show()


weights = bc_weights_from_explanation(explanation, X, top_k=10)
pca_kde_plot(X, y, weights, 'Iris: BC Bottleneck Cloud in PCA Space')



## 8. Communities (Decision Themes)


In [None]:
print(explanation.communities.keys())
print(explanation.communities.get("Communities", [])[:3])  # first three communities

run_name = "iris_dpg"
explainer.plot(run_name, explanation, save_dir="results", class_flag=True, export_pdf=True)
explainer.plot_communities(run_name, explanation, save_dir="results", class_flag=True, export_pdf=True)
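This notebook does not show DPG’s community algorithm or the exact meaning of `community_threshold`. As a generic illustration of what a community is, the sketch below runs NetworkX’s greedy modularity maximization (a stand-in technique, not necessarily what DPG uses) on a hypothetical graph with two tightly connected groups of predicates joined by a single edge.

```python
import networkx as nx
from networkx.algorithms.community import greedy_modularity_communities

# Hypothetical predicate graph: two tightly linked groups joined by one edge
theme_setosa = ["petal length (cm) <= 2.45", "petal width (cm) <= 0.8", "class setosa"]
theme_virginica = ["petal width (cm) > 1.75", "petal length (cm) > 4.95", "class virginica"]

G = nx.Graph()
for group in (theme_setosa, theme_virginica):
    # Fully connect each group (a triangle)
    G.add_edges_from([(a, b) for i, a in enumerate(group) for b in group[i + 1:]])
G.add_edge("petal width (cm) <= 0.8", "petal width (cm) > 1.75")  # weak bridge between themes

for k, community in enumerate(greedy_modularity_communities(G)):
    print(k, sorted(community))
```

Modularity rewards dense links inside a group and few links between groups, so the single bridge is not enough to merge the two themes.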



## 9. What to Say in the Story
Use these three points for a quick practitioner summary:
- **LRC:** Which predicate most strongly frames the model’s logic?
- **BC:** Which predicate acts as a bottleneck between key decision paths?
- **Communities:** Which predicate groups define the “themes” of each class?
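A hedged sketch for pulling the first two answers straight from the node-metrics table. It assumes only the `Label`, `Local reaching centrality`, and `Betweenness centrality` columns used earlier in this notebook; `story_summary` is a hypothetical helper name.

```python
import pandas as pd


def story_summary(node_metrics: pd.DataFrame) -> dict:
    """Return the top predicate by LRC and by BC (hypothetical helper)."""
    # Keep predicate nodes only, as the Section 6 helpers do
    is_pred = node_metrics["Label"].astype(str).str.contains(r"<=|>", regex=True, na=False)
    preds = node_metrics[is_pred]
    top_lrc = preds.loc[preds["Local reaching centrality"].idxmax(), "Label"]
    top_bc = preds.loc[preds["Betweenness centrality"].idxmax(), "Label"]
    return {"framing predicate (LRC)": top_lrc, "bottleneck predicate (BC)": top_bc}
```

Call it as `story_summary(explanation.node_metrics)`; the communities point still comes from the community plot.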


## Next Episode
We will move to another scikit-learn benchmark dataset.
