# DPGExplainer Saga Benchmarks — Class Bounds Class-specific Hexbin

We build class-specific hexbin clouds in a 2D PCA projection, using **Class Bounds** extracted from DPG (RandomForest → DPG → Class Bounds).
Each sample is weighted by the fraction of class-bound predicates it satisfies for its true class.

## 1. Setup

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import re
from sklearn.datasets import load_iris, load_wine, load_breast_cancer, load_digits
from sklearn.ensemble import RandomForestClassifier
from sklearn.decomposition import PCA
from dpg import DPGExplainer


## 2. Helper functions

In [None]:
def parse_bound(bound_str):
    """Parse one class-bound predicate string into ``(feature, kind, value)``.

    Three shapes are recognised and tried in this order:

    * ``"<lo> < <feature> <= <hi>"`` -> ``(feature, 'range', (lo, hi))``
    * ``"<feature> <= <hi>"``        -> ``(feature, 'le', hi)``
    * ``"<feature> > <lo>"``         -> ``(feature, 'gt', lo)``

    Thresholds may carry a sign and may use scientific notation
    (``-0.5``, ``+3``, ``1e-3``), so bounds on features with negative or very
    small values are parsed instead of being dropped or mis-attributed.

    Args:
        bound_str: The predicate text. Surrounding whitespace is ignored.

    Returns:
        A ``(feature, kind, value)`` tuple as described above, or ``None`` when
        the text matches none of the three shapes.
    """
    # A real number: optional sign, digits with an optional fractional part
    # (or a bare fraction such as ".5"), and an optional exponent.
    number = r'[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?'
    bound_str = bound_str.strip()

    # The two-sided form must be tried first: its tail also looks like a
    # plain "<feature> <= <hi>" predicate.
    m_range = re.match(rf'({number})\s*<\s*(.+?)\s*<=\s*({number})', bound_str)
    if m_range:
        lo, feature, hi = float(m_range.group(1)), m_range.group(2).strip(), float(m_range.group(3))
        return feature, 'range', (lo, hi)
    m_le = re.match(rf'(.+?)\s*<=\s*({number})', bound_str)
    if m_le:
        feature, hi = m_le.group(1).strip(), float(m_le.group(2))
        return feature, 'le', hi
    m_gt = re.match(rf'(.+?)\s*>\s*({number})', bound_str)
    if m_gt:
        feature, lo = m_gt.group(1).strip(), float(m_gt.group(2))
        return feature, 'gt', lo
    return None

def class_bounds_map(explanation):
    """Index the explanation's 'Class Bounds' entries by integer class id.

    The class id is the first run of digits found in each key's string form.
    Keys without any digit are skipped, and a later key that yields the same
    id replaces the earlier one.

    Args:
        explanation: Object whose ``class_boundaries`` attribute is a mapping
            that may contain a ``'Class Bounds'`` entry.

    Returns:
        dict mapping ``int`` class id to that class's stored bounds, untouched.
    """
    per_class = {}
    raw_bounds = explanation.class_boundaries.get('Class Bounds', {})
    for label, predicates in raw_bounds.items():
        digits = re.search(r'(\d+)', str(label))
        if digits is not None:
            per_class[int(digits.group(1))] = predicates
    return per_class

def bounds_weights_from_explanation(explanation, X_df, y):
    """Score each sample by the share of its own class's bounds it satisfies.

    For a sample of class ``c`` the weight is ``satisfied / total``, where
    ``total`` counts the parseable bounds of class ``c`` whose feature is a
    column of ``X_df`` and whose value for this sample is finite, and
    ``satisfied`` counts how many of those bounds hold. Samples with no
    usable bound keep a weight of ``0.0``.

    Bounds are parsed once per class and each predicate is evaluated on a
    whole column at a time, rather than re-parsing every bound string and
    fetching single cells for every row.

    Args:
        explanation: Object accepted by ``class_bounds_map``.
        X_df: DataFrame of features; rows are addressed by position.
        y: Class labels, one per row of ``X_df``, each convertible with
            ``int()`` to a class id used by ``class_bounds_map``.

    Returns:
        ``numpy.ndarray`` of floats with one weight per row of ``X_df``.
    """
    bounds_by_class = class_bounds_map(explanation)
    labels = np.asarray(y)
    n_samples = len(X_df)
    satisfied = np.zeros(n_samples, dtype=float)
    total = np.zeros(n_samples, dtype=float)
    column_cache = {}  # feature name -> float ndarray, converted at most once

    for cls in np.unique(labels):
        bounds_list = bounds_by_class.get(int(cls)) or []
        # Keep only predicates that parse and refer to a known column.
        predicates = [
            parsed
            for parsed in map(parse_bound, bounds_list)
            if parsed and parsed[0] in X_df.columns
        ]
        if not predicates:
            continue

        rows = np.flatnonzero(labels == cls)
        for feature, kind, val in predicates:
            if feature not in column_cache:
                column_cache[feature] = X_df[feature].to_numpy(dtype=float)
            x = column_cache[feature][rows]
            # Non-finite cells do not count towards the denominator.
            finite = np.isfinite(x)
            with np.errstate(invalid='ignore'):
                if kind == 'range':
                    lo, hi = val
                    hit = (x > lo) & (x <= hi)
                elif kind == 'le':
                    hit = x <= val
                elif kind == 'gt':
                    hit = x > val
                else:
                    # Unknown kinds count as evaluated but never satisfied.
                    hit = np.zeros(x.shape, dtype=bool)
            total[rows] += finite
            satisfied[rows] += hit & finite

    weights = np.zeros(n_samples, dtype=float)
    # Rows with total == 0 are left at their initial 0.0.
    np.divide(satisfied, total, out=weights, where=total > 0)
    return weights


In [None]:
def pca_class_hexbin_plot(X_df, y, weights, title):
    """Draw per-class hexbin clouds of sample weights in a 2D PCA projection.

    Rows of ``X_df`` containing NaN or +/-inf are dropped, the remaining rows
    are projected onto two principal components, and for every class with at
    least five samples a hexbin layer shows the mean weight per cell, tinted
    with that class's colour. All samples are then scattered on top using the
    same per-class colours, so points, hexbins and legend agree.

    Args:
        X_df: DataFrame of features.
        y: Class labels, one per row of ``X_df`` (matched by position).
        weights: Per-sample weights, one per row of ``X_df``.
        title: Axes title.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    X_clean = X_df.replace([np.inf, -np.inf], np.nan)
    # Plain boolean array so the same mask indexes the frame and the arrays.
    valid_mask = (~X_clean.isna().any(axis=1)).to_numpy()
    X_valid = X_clean[valid_mask]
    y_valid = np.asarray(y)[valid_mask]
    w_valid = np.asarray(weights)[valid_mask]

    pca = PCA(n_components=2, random_state=27)
    X_pca = pca.fit_transform(X_valid)

    fig, ax = plt.subplots(1, 1, figsize=(7, 5), facecolor='white')
    ax.set_facecolor('white')

    classes = sorted(set(y_valid.tolist()))
    palette = sns.color_palette('tab10', n_colors=len(classes))
    class_color = {cls: palette[idx] for idx, cls in enumerate(classes)}

    for cls in classes:
        mask = (y_valid == cls)
        # Too few points give a meaningless density layer.
        if mask.sum() < 5:
            continue
        ax.hexbin(
            X_pca[mask, 0],
            X_pca[mask, 1],
            C=w_valid[mask],
            reduce_C_function=np.mean,
            gridsize=35,
            cmap=sns.light_palette(class_color[cls], as_cmap=True),
            mincnt=1,
            alpha=0.6,
        )

    # Colour each point with its class's legend colour (not a separate colormap).
    ax.scatter(
        X_pca[:, 0],
        X_pca[:, 1],
        c=[class_color[cls] for cls in y_valid.tolist()],
        s=16,
        alpha=0.25,
        edgecolor='k',
        linewidth=0.2,
    )

    # Class legend
    class_handles = [
        plt.Line2D([0], [0], marker='s', color='w', label=str(cls),
                   markerfacecolor=class_color[cls], markersize=8,
                   markeredgecolor='k', markeredgewidth=0.4)
        for cls in classes
    ]
    ax.legend(handles=class_handles, title='Classes', loc='upper right', frameon=True)

    ax.set_title(title)
    ax.set_xlabel('PCA 1')
    ax.set_ylabel('PCA 2')
    plt.tight_layout()
    plt.show()


## 4. Pairwise Class-Bounds Regions (All Feature Pairs)
This plots every feature pair as a scatter of the samples and overlays each class's bounds as shaded bands with dashed boundary lines.

In [None]:
from itertools import combinations

def _draw_bound(span, line, kind, val, data_min, data_max, color):
    """Shade the interval one predicate allows and dash its threshold(s).

    ``span`` / ``line`` are an axes' ``axvspan`` / ``axvline`` pair (feature on
    the x axis) or its ``axhspan`` / ``axhline`` pair (feature on the y axis).
    One-sided bounds are shaded out to the observed data extreme.
    """
    if kind == 'range':
        lo, hi = val
        edges = (lo, hi)
    elif kind == 'le':
        lo, hi = data_min, val
        edges = (val,)
    elif kind == 'gt':
        lo, hi = val, data_max
        edges = (val,)
    else:
        return
    span(lo, hi, color=color, alpha=0.08)
    for edge in edges:
        line(edge, color=color, linestyle='--', linewidth=0.8)


def pairwise_class_bounds_hexbin(X_df, y, explanation, title, max_pairs=None):
    """Scatter feature pairs and overlay each class's bounds as shaded bands.

    One subplot is drawn per pair of columns of ``X_df`` (three per row).
    Samples are scattered and coloured by ``y``; every parsed class bound that
    refers to the subplot's x or y feature is drawn as a translucent band with
    dashed threshold lines in that class's colour.

    Args:
        X_df: DataFrame of features.
        y: Class labels, one per row of ``X_df``.
        explanation: Object accepted by ``class_bounds_map``.
        title: Figure title.
        max_pairs: Optional cap on the number of feature pairs plotted, taken
            in ``itertools.combinations`` order. ``None`` (default) plots all
            pairs; the number of pairs grows quadratically with the number of
            columns, so a cap keeps the figure manageable for wide frames.

    Returns:
        None. The figure is displayed with ``plt.show()``. Nothing is drawn
        when there is no feature pair to plot.
    """
    features = list(X_df.columns)
    pairs = list(combinations(features, 2))
    if max_pairs is not None:
        pairs = pairs[:max(max_pairs, 0)]
    if not pairs:
        # plt.subplots rejects a grid with zero rows.
        return

    cols = 3
    rows = int(np.ceil(len(pairs) / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(5 * cols, 3.8 * rows), facecolor='white')
    axes = np.atleast_1d(axes).ravel()

    # Parse every class's bound strings once, dropping unparseable ones.
    parsed_bounds = {
        class_idx: [p for p in map(parse_bound, bounds_list) if p is not None]
        for class_idx, bounds_list in class_bounds_map(explanation).items()
    }

    classes = sorted(parsed_bounds.keys())
    palette = sns.color_palette('tab10', n_colors=max(1, len(classes)))

    for ax, (fx, fy) in zip(axes, pairs):
        x_vals = X_df[fx].replace([np.inf, -np.inf], np.nan)
        y_vals = X_df[fy].replace([np.inf, -np.inf], np.nan)
        mask = (~x_vals.isna()) & (~y_vals.isna())

        ax.scatter(
            x_vals[mask],
            y_vals[mask],
            c=y[mask],
            cmap='viridis',
            s=10,
            alpha=0.25,
            edgecolor='k',
            linewidth=0.2,
        )

        # Data extremes are the outer edge of one-sided bands.
        x_min, x_max = x_vals.min(), x_vals.max()
        y_min, y_max = y_vals.min(), y_vals.max()

        for idx, cls in enumerate(classes):
            color = palette[idx]
            for feature, kind, val in parsed_bounds.get(cls, []):
                if feature == fx:
                    _draw_bound(ax.axvspan, ax.axvline, kind, val, x_min, x_max, color)
                elif feature == fy:
                    _draw_bound(ax.axhspan, ax.axhline, kind, val, y_min, y_max, color)

        ax.set_xlabel(fx)
        ax.set_ylabel(fy)
        ax.set_title(f'{fx} vs {fy}')

    # Hide the unused cells of the last grid row.
    for j in range(len(pairs), len(axes)):
        axes[j].axis('off')

    class_handles = [
        plt.Line2D([0], [0], color=palette[i], lw=2, label=f'Class {cls}')
        for i, cls in enumerate(classes)
    ]
    fig.legend(handles=class_handles, loc='upper center', ncol=min(4, len(classes)), frameon=True)

    fig.suptitle(title, y=1.02)
    plt.tight_layout()
    plt.show()


## 3. Datasets and Class-Bounds Class-specific Hexbin clouds

In [None]:
# Benchmark datasets as (display name, loaded dataset) pairs, loaded in order.
datasets = [
    (label, loader(as_frame=True))
    for label, loader in (
        ('Iris', load_iris),
        ('Wine', load_wine),
        ('Breast Cancer', load_breast_cancer),
        ('Digits', load_digits),
    )
]

for name, ds in datasets:
    X, y = ds.data, ds.target

    # Fit the forest that DPG will explain; the seed is fixed.
    model = RandomForestClassifier(n_estimators=10, random_state=27)
    model.fit(X, y)

    # Target names are the sorted distinct labels, rendered as strings.
    explainer = DPGExplainer(
        model=model,
        feature_names=X.columns,
        target_names=list(map(str, sorted(set(y)))),
        config_file='config.yaml',
    )
    explanation = explainer.explain_global(X.values, communities=False)

    # PCA view: per-sample weights from the class bounds of each true class.
    weights = bounds_weights_from_explanation(explanation, X, y)
    title = f'{name}: Class-Bounds Map (RF → DPG → Class Bounds)'
    pca_class_hexbin_plot(X, y, weights, title)

    # Feature-space view: the same bounds drawn on every feature pair.
    pairwise_class_bounds_hexbin(X, y, explanation, f'{name}: Pairwise Class-Bounds Regions')
