# 04_InteractionDiscovery - Physics-SR Framework v4.1

## Stage 1.4: Adaptive Interaction Discovery with TreeSHAP

**Author:** Zhengze Zhang  
**Affiliation:** Department of Statistics, Columbia University  
**Contact:** zz3239@columbia.edu  
**Date:** January 2026  
**Version:** 4.1.1 (TreeSHAP-based Adaptive Strategy)

---

### Purpose

Discover feature interactions (pairwise by default) using a strategy that adapts to the number of input features:
- **Low-dimensional (n_features <= 10):** Direct enumeration of all pairwise interactions
- **High-dimensional (n_features > 10):** TreeSHAP interaction values with liberal threshold
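
As a rough illustration of the branching rule above, the following minimal sketch picks a strategy from the feature count and, in the low-dimensional case, enumerates every pair. The helper names `choose_strategy` and `enumerate_pairs` are hypothetical and exist only for this example; the threshold of 10 and the strategy labels are taken from this notebook.

```python
from itertools import combinations
from math import comb


def choose_strategy(n_features, low_dim_threshold=10, shap_available=True):
    """Return the name of the branch the adaptive strategy would take."""
    if n_features <= low_dim_threshold:
        return "enumeration"
    # The high-dimensional branch needs the optional shap package
    return "treeshap" if shap_available else "enumeration_fallback"


def enumerate_pairs(feature_names):
    """List every unordered feature pair as a frozenset."""
    return [frozenset(pair) for pair in combinations(feature_names, 2)]


names = ["x0", "x1", "x2", "x3", "x4"]
pairs = enumerate_pairs(names)
strategy = choose_strategy(len(names))
print(strategy, len(pairs), comb(len(names), 2))
```

Note that a feature count exactly equal to the threshold still takes the enumeration branch.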

### Key Innovation (v4.1.1)

**Problems with previous approaches:**
- Original iRF: O(trees x paths) complexity, 25+ minutes for 8 features
- EBM: Unreliable interaction ranking; cannot distinguish true from spurious interactions

**TreeSHAP Solution:**
- O(MTLD^2) polynomial complexity for interaction values (M = features, T = trees, L = max leaves per tree, D = max depth)
- 83-100% recall on moderate-to-large interactions
- Theoretically grounded (Shapley values)
- Mature implementation (`shap` package)
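
For orientation, the sketch below shows one way per-sample interaction values could be reduced to a single score per feature pair: take absolute values, average over samples, then average the two off-diagonal entries. It assumes the interaction values are already available as a nested list of shape (n_samples, n_features, n_features), for example from `shap.TreeExplainer(...).shap_interaction_values(...)`. The toy numbers and the helper name `pair_scores` are invented for illustration.

```python
def pair_scores(interaction_values, feature_names):
    """Reduce (n_samples, n_features, n_features) values to one score per pair."""
    n_samples = len(interaction_values)
    n_features = len(feature_names)
    scores = {}
    for i in range(n_features):
        for j in range(i + 1, n_features):
            # Mean absolute value over samples, for both (i, j) and (j, i)
            mean_ij = sum(abs(s[i][j]) for s in interaction_values) / n_samples
            mean_ji = sum(abs(s[j][i]) for s in interaction_values) / n_samples
            score = (mean_ij + mean_ji) / 2
            if score > 0:  # drop pairs with no measured interaction
                scores[frozenset((feature_names[i], feature_names[j]))] = score
    return scores


# Two samples, three features; only x0 and x1 interact in this toy data
toy = [
    [[0.5, 0.2, 0.0], [0.2, 0.1, 0.0], [0.0, 0.0, 0.3]],
    [[0.4, -0.4, 0.0], [-0.4, 0.2, 0.0], [0.0, 0.0, 0.1]],
]
scores = pair_scores(toy, ["x0", "x1", "x2"])
print(scores)
```

Taking absolute values before averaging keeps interactions whose sign flips between samples from cancelling out.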

### Adaptive Strategy

| Features | Strategy | Rationale |
|----------|----------|----------|
| <= 10 | Direct enumeration | C(10,2)=45 pairs, no filtering needed |
| > 10 | TreeSHAP + 25th percentile | Filter dummy-related interactions |
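
To make the percentile rule in the table concrete, here is a small sketch of filtering scored pairs at the 25th percentile, which keeps roughly the top 75% of candidates. The scores are made-up numbers, and `percentile` and `keep_above_percentile` are hypothetical helpers written with the standard library only; the percentile uses plain linear interpolation between order statistics.

```python
def percentile(values, q):
    """Linear-interpolation percentile for q in [0, 100]."""
    ordered = sorted(values)
    pos = (len(ordered) - 1) * q / 100.0
    lo = int(pos)
    hi = min(lo + 1, len(ordered) - 1)
    return ordered[lo] + (ordered[hi] - ordered[lo]) * (pos - lo)


def keep_above_percentile(scores, q=25.0):
    """Keep interactions whose score is at or above the q-th percentile."""
    if not scores:
        return [], None
    threshold = percentile(scores.values(), q)
    kept = [pair for pair, score in scores.items() if score >= threshold]
    return kept, threshold


# Invented interaction scores for five candidate pairs
scores = {
    frozenset({"a", "b"}): 0.40,
    frozenset({"a", "c"}): 0.10,
    frozenset({"b", "c"}): 0.02,
    frozenset({"a", "d"}): 0.01,
    frozenset({"c", "d"}): 0.20,
}
kept, threshold = keep_above_percentile(scores, q=25.0)
print(threshold, len(kept))
```

A lower percentile admits more candidates, which trades precision for recall.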

### Reference

- Lundberg, S. M., et al. (2020). From local explanations to global understanding with explainable AI for trees. *Nature Machine Intelligence*, 2(1), 56-67.

---
## Section 1: Header and Imports

In [None]:
"""
04_InteractionDiscovery.ipynb - Adaptive Interaction Discovery
===============================================================

Three-Stage Physics-Informed Symbolic Regression Framework v4.1.1

This module provides:
- AdaptiveInteractionDiscoverer: Adaptive interaction discovery
  - Low-dim (<=10 features): Direct pairwise enumeration
  - High-dim (>10 features): TreeSHAP interaction values
- Softmax soft threshold for importance-weighted feature selection
- High-recall design: prioritizes not missing true interactions

Algorithm:
    1. Train Random Forest on data
    2. Extract Gini importance -> soft_weights via softmax
    3. If n_features <= 10: enumerate all pairwise
    4. If n_features > 10: compute TreeSHAP interactions, filter by threshold
    5. Return stable_interactions for Feature Library

Output Dictionary Keys (v4.1 compatible):
    - soft_weights: Dict of softmax-transformed importance weights
    - selected_features: List of features above selection threshold
    - stable_interactions: List of high-confidence interactions
    - interaction_stability: Dict mapping interactions to scores
    - raw_importance: Dict of raw Gini importance
    - suggested_terms: List of feature product strings for library

Author: Zhengze Zhang
Affiliation: Department of Statistics, Columbia University
Contact: zz3239@columbia.edu
"""

# Import core module
%run 00_Core.ipynb

In [None]:
# Additional imports for Interaction Discovery
from sklearn.ensemble import RandomForestRegressor
from collections import Counter
from itertools import combinations
from typing import Dict, List, Tuple, Optional, Any, Set, FrozenSet

# TreeSHAP import (optional, falls back to enumeration if unavailable)
try:
    import shap
    _SHAP_AVAILABLE = True
    print(f"04_InteractionDiscovery v4.1.1: SHAP {shap.__version__} available.")
except ImportError:
    _SHAP_AVAILABLE = False
    print("04_InteractionDiscovery v4.1.1: SHAP not available, using enumeration fallback.")

print("04_InteractionDiscovery v4.1.1: Additional imports successful.")
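
The softmax soft-threshold step named in the module docstring can be sketched in a few lines. This is a dependency-free illustration rather than the notebook's implementation: `softmax_weights` is a hypothetical name, the importance values are invented, and the temperature of 0.5 and cutoff of 0.1 are the defaults quoted in this notebook.

```python
from math import exp


def softmax_weights(importance, temperature=0.5):
    """Temperature-scaled softmax; subtracting the max avoids overflow."""
    scaled = [v / temperature for v in importance]
    peak = max(scaled)
    exps = [exp(v - peak) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


raw = [0.6, 0.3, 0.1]  # invented importances that sum to 1
weights = softmax_weights(raw)
# Indices of features whose weight clears the selection cutoff
selected = [i for i, w in enumerate(weights) if w > 0.1]
print([round(w, 3) for w in weights], selected)
```

Lowering `temperature` concentrates weight on the top feature, while raising it flattens the weights toward uniform.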

---
## Section 2: Class Definition

In [None]:
# ==============================================================================
# ADAPTIVE INTERACTION DISCOVERER CLASS (v4.1.1 - TreeSHAP)
# ==============================================================================

class AdaptiveInteractionDiscoverer:
    """
    Adaptive Interaction Discovery with TreeSHAP.
    
    This discoverer uses an adaptive strategy based on feature count:
    - Low-dimensional (<=10 features): Direct pairwise enumeration
    - High-dimensional (>10 features): TreeSHAP interaction values
    
    The design prioritizes HIGH RECALL (not missing true interactions)
    over precision, since downstream E-WSINDy will filter false positives.
    
    Attributes
    ----------
    temperature : float
        Softmax temperature for feature importance (default: 0.5)
    selection_threshold : float
        Minimum softmax weight for feature selection (default: 0.1)
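    stability_threshold : float
        Retained for v4.1 API compatibility; stored and echoed in the
        result dict but not used for filtering in v4.1.1 (default: 0.5)
    max_interaction_order : int
        Maximum order of interactions (default: 2 for pairwise)
    n_estimators : int
        Number of trees in Random Forest (default: 200)
    low_dim_threshold : int
        Feature counts at or below this use direct enumeration (default: 10)
    shap_percentile : float
        For TreeSHAP: percentile cutoff on interaction scores
        (default: 25.0 = keep roughly the top 75%)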
    
    Methods
    -------
    discover(X, y, feature_names) -> Dict
        Discover feature interactions
    get_stable_interactions() -> List[Tuple[str, ...]]
        Get list of stable interactions as tuples
    get_interaction_matrix(X) -> Tuple[np.ndarray, List[str]]
        Compute interaction features from stable interactions
    print_interaction_report() -> None
        Print detailed discovery report
    
    Reference
    ---------
    Lundberg et al. (2020). Nature Machine Intelligence, 2(1), 56-67.
    
    Examples
    --------
    >>> discoverer = AdaptiveInteractionDiscoverer()
    >>> result = discoverer.discover(X, y, feature_names)
    >>> print(result['stable_interactions'])
    [frozenset({'x0', 'x1'}), frozenset({'x1', 'x2'})]  # Discovered interactions
    """
    
    def __init__(
        self,
        temperature: float = DEFAULT_SOFTMAX_TEMPERATURE,
        selection_threshold: float = DEFAULT_IMPORTANCE_THRESHOLD,
        stability_threshold: float = DEFAULT_STABILITY_THRESHOLD,
        max_interaction_order: int = 2,
        n_estimators: int = 200,
        n_bootstrap: int = 50,  # Kept for API compatibility, not used in v4.1.1
        low_dim_threshold: int = 10,
        shap_percentile: float = 25.0,  # Liberal threshold for high recall
        random_state: int = RANDOM_SEED
    ):
        """
        Initialize AdaptiveInteractionDiscoverer.
        
        Parameters
        ----------
        temperature : float
            Softmax temperature for feature importance.
            Default: 0.5 (recommended balanced value)
        selection_threshold : float
            Minimum softmax weight for a feature to be considered.
            Default: 0.1
        stability_threshold : float
            Retained for v4.1 API compatibility. Stored and returned in the
            result dict, but not used for filtering in v4.1.1
            (TreeSHAP filtering uses shap_percentile instead).
            Default: 0.5
        max_interaction_order : int
            Maximum number of features in an interaction.
            Default: 2 (pairwise interactions)
        n_estimators : int
            Number of trees in Random Forest.
            Default: 200
        n_bootstrap : int
            Kept for API compatibility. Not used in v4.1.1.
            Default: 50
        low_dim_threshold : int
            Feature count at or below which direct enumeration is used.
            Default: 10
        shap_percentile : float
            Percentile threshold for SHAP filtering (lower = more recall).
            Default: 25.0 (keep top 75% of interactions)
        random_state : int
            Random seed for reproducibility.
            Default: 42
        """
        self.temperature = temperature
        self.selection_threshold = selection_threshold
        self.stability_threshold = stability_threshold
        self.max_interaction_order = max_interaction_order
        self.n_estimators = n_estimators
        self.n_bootstrap = n_bootstrap  # Kept for API compatibility
        self.low_dim_threshold = low_dim_threshold
        self.shap_percentile = shap_percentile
        self.random_state = random_state
        
        # Internal state
        self._feature_names = None
        self._rf_model = None
        self._raw_importance = None
        self._soft_weights = None
        self._selected_features = None
        self._all_interactions = None
        self._interaction_stability = None
        self._stable_interactions = None
        self._discovery_complete = False
        self._method_used = None  # 'enumeration', 'treeshap', or 'enumeration_fallback'
    
    def discover(
        self,
        X: np.ndarray,
        y: np.ndarray,
        feature_names: List[str]
    ) -> Dict[str, Any]:
        """
        Discover feature interactions using adaptive strategy.
        
        Parameters
        ----------
        X : np.ndarray
            Feature matrix of shape (n_samples, n_features)
        y : np.ndarray
            Target vector of shape (n_samples,)
        feature_names : List[str]
            Names of features corresponding to columns of X
        
        Returns
        -------
        Dict[str, Any]
            Dictionary containing (v4.1 compatible keys):
            - soft_weights: Dict of softmax-transformed importance weights
            - selected_features: List of features above selection threshold
            - stable_interactions: List of high-confidence interactions
            - interaction_stability: Dict mapping interactions to scores
            - raw_importance: Dict of raw Gini importance
            - all_interactions: List of all candidate interactions
            - suggested_terms: List of feature product strings for library
            - n_stable_interactions: Number of stable interactions
            - method_used: 'enumeration', 'treeshap', or 'enumeration_fallback'
        """
        self._feature_names = list(feature_names)
        n_features = X.shape[1]
        
        # Step 1: Train Random Forest for importance
        self._rf_model = self._fit_random_forest(X, y)
        
        # Step 2: Extract and transform importance
        self._raw_importance = self._rf_model.feature_importances_
        self._soft_weights = self._softmax_transform(self._raw_importance)
        
        # Step 3: Select features above threshold
        self._selected_features = [
            self._feature_names[i] for i in range(n_features)
            if self._soft_weights[i] > self.selection_threshold
        ]
        
        # Step 4: Adaptive interaction discovery
        if n_features <= self.low_dim_threshold:
            # Low-dimensional: direct enumeration
            self._method_used = 'enumeration'
            self._all_interactions, self._interaction_stability = \
                self._enumerate_all_interactions(n_features)
        else:
            # High-dimensional: TreeSHAP
            if _SHAP_AVAILABLE:
                self._method_used = 'treeshap'
                self._all_interactions, self._interaction_stability = \
                    self._treeshap_interactions(X)
            else:
                # Fallback to enumeration if SHAP unavailable
                self._method_used = 'enumeration_fallback'
                self._all_interactions, self._interaction_stability = \
                    self._enumerate_all_interactions(n_features)
        
        # Step 5: Filter stable interactions based on method
        if self._method_used in ('enumeration', 'enumeration_fallback'):
            # For enumeration: all interactions are "stable" (score = 1.0)
            self._stable_interactions = list(self._all_interactions)
        else:
            # For TreeSHAP: filter by percentile threshold
            if self._interaction_stability:
                scores = list(self._interaction_stability.values())
                threshold = np.percentile(scores, self.shap_percentile)
                self._stable_interactions = [
                    interaction for interaction, score
                    in self._interaction_stability.items()
                    if score >= threshold
                ]
            else:
                self._stable_interactions = []
        
        self._discovery_complete = True
        
        # Build result dictionary (v4.1 compatible)
        raw_importance_dict = {
            name: float(self._raw_importance[i])
            for i, name in enumerate(self._feature_names)
        }
        
        soft_weights_dict = {
            name: float(self._soft_weights[i])
            for i, name in enumerate(self._feature_names)
        }
        
        # Keep frozenset keys; cast scores to plain Python floats
        interaction_stability_dict = {
            interaction: float(score)
            for interaction, score in self._interaction_stability.items()
        }
        
        # Build suggested terms for feature library
        suggested_terms = self._build_suggested_terms()
        
        return {
            # v4.1 primary keys
            'soft_weights': soft_weights_dict,
            'selected_features': self._selected_features,
            'stable_interactions': self._stable_interactions,
            'interaction_stability': interaction_stability_dict,
            # Additional useful keys
            'raw_importance': raw_importance_dict,
            'all_interactions': list(self._all_interactions),
            'suggested_terms': suggested_terms,
            'n_stable_interactions': len(self._stable_interactions),
            'temperature': self.temperature,
            'stability_threshold': self.stability_threshold,
            # v4.1.1 additions
            'method_used': self._method_used,
            'n_features': n_features,
            'low_dim_threshold': self.low_dim_threshold,
            # Backward compatibility alias
            'softmax_weights': soft_weights_dict
        }
    
    def _fit_random_forest(
        self,
        X: np.ndarray,
        y: np.ndarray
    ) -> RandomForestRegressor:
        """
        Fit Random Forest regressor.
        
        Parameters
        ----------
        X : np.ndarray
            Feature matrix
        y : np.ndarray
            Target vector
        
        Returns
        -------
        RandomForestRegressor
            Fitted Random Forest model
        """
        rf = RandomForestRegressor(
            n_estimators=self.n_estimators,
            max_features='sqrt',
            max_depth=8,  # Shallower for faster SHAP computation
            min_samples_leaf=5,
            n_jobs=-1,
            random_state=self.random_state
        )
        rf.fit(X, y)
        return rf
    
    def _softmax_transform(
        self,
        importance: np.ndarray
    ) -> np.ndarray:
        """
        Apply softmax transformation to importance scores.
        
        Parameters
        ----------
        importance : np.ndarray
            Raw Gini importance scores
        
        Returns
        -------
        np.ndarray
            Softmax-transformed weights (sum to 1)
        """
        # Scale by temperature
        scaled = importance / self.temperature
        
        # Numerical stability: subtract max before exp
        scaled = scaled - np.max(scaled)
        
        # Compute softmax
        exp_scaled = np.exp(scaled)
        weights = exp_scaled / np.sum(exp_scaled)
        
        return weights
    
    def _enumerate_all_interactions(
        self,
        n_features: int
    ) -> Tuple[Set[FrozenSet[str]], Dict[FrozenSet[str], float]]:
        """
        Enumerate all pairwise interactions (for low-dimensional data).
        
        Parameters
        ----------
        n_features : int
            Number of features
        
        Returns
        -------
        Tuple[Set[FrozenSet[str]], Dict[FrozenSet[str], float]]
            - Set of all pairwise interactions
            - Dictionary mapping each interaction to score 1.0
        """
        all_interactions = set()
        interaction_scores = {}
        
        # Generate all pairwise combinations
        for i, j in combinations(range(n_features), 2):
            interaction = frozenset([self._feature_names[i], self._feature_names[j]])
            all_interactions.add(interaction)
            interaction_scores[interaction] = 1.0  # All equally valid
        
        # Optionally add 3-way interactions if max_order allows
        if self.max_interaction_order >= 3 and n_features <= 6:
            for combo in combinations(range(n_features), 3):
                interaction = frozenset(self._feature_names[i] for i in combo)
                all_interactions.add(interaction)
                interaction_scores[interaction] = 1.0
        
        return all_interactions, interaction_scores
    
    def _treeshap_interactions(
        self,
        X: np.ndarray
    ) -> Tuple[Set[FrozenSet[str]], Dict[FrozenSet[str], float]]:
        """
        Compute TreeSHAP interaction values.
        
        Parameters
        ----------
        X : np.ndarray
            Feature matrix
        
        Returns
        -------
        Tuple[Set[FrozenSet[str]], Dict[FrozenSet[str], float]]
            - Set of all detected interactions
            - Dictionary mapping interactions to SHAP interaction scores
        """
        # Use subset of samples for speed if dataset is large
        n_samples = X.shape[0]
        if n_samples > 500:
            # Local generator avoids reseeding NumPy's global RNG
            rng = np.random.default_rng(self.random_state)
            sample_idx = rng.choice(n_samples, 500, replace=False)
            X_sample = X[sample_idx]
        else:
            X_sample = X
        
        # Compute TreeSHAP interaction values
        explainer = shap.TreeExplainer(self._rf_model)
        shap_interactions = explainer.shap_interaction_values(X_sample)
        
        # shap_interactions shape: (n_samples, n_features, n_features)
        # Average absolute interaction values across samples
        mean_interactions = np.abs(shap_interactions).mean(axis=0)
        
        # Zero out diagonal (self-interactions)
        np.fill_diagonal(mean_interactions, 0)
        
        # Build interaction dictionary
        all_interactions = set()
        interaction_scores = {}
        n_features = X.shape[1]
        
        for i in range(n_features):
            for j in range(i + 1, n_features):
                interaction = frozenset([self._feature_names[i], self._feature_names[j]])
                # Use symmetric average
                score = (mean_interactions[i, j] + mean_interactions[j, i]) / 2
                
                if score > 0:  # Only include non-zero interactions
                    all_interactions.add(interaction)
                    interaction_scores[interaction] = float(score)
        
        return all_interactions, interaction_scores
    
    def _build_suggested_terms(
        self
    ) -> List[str]:
        """
        Build suggested interaction terms for feature library.
        
        Returns
        -------
        List[str]
            List of feature product strings (e.g., "x0 * x1")
        """
        terms = []
        for interaction in self._stable_interactions:
            # Sort for consistent ordering
            sorted_features = sorted(interaction)
            term = " * ".join(sorted_features)
            terms.append(term)
        return terms
    
    def get_stable_interactions(
        self
    ) -> List[Tuple[str, ...]]:
        """
        Get list of stable interactions as tuples.
        
        Returns
        -------
        List[Tuple[str, ...]]
            List of stable interactions
        
        Raises
        ------
        RuntimeError
            If discovery has not been performed
        """
        if not self._discovery_complete:
            raise RuntimeError("Must run discover() before getting interactions")
        
        return [tuple(sorted(interaction))
                for interaction in self._stable_interactions]
    
    def get_interaction_matrix(
        self,
        X: np.ndarray
    ) -> Tuple[np.ndarray, List[str]]:
        """
        Compute interaction features from stable interactions.
        
        Parameters
        ----------
        X : np.ndarray
            Original feature matrix
        
        Returns
        -------
        Tuple[np.ndarray, List[str]]
            - Interaction feature matrix
            - Names of interaction features
        
        Raises
        ------
        RuntimeError
            If discovery has not been performed
        """
        if not self._discovery_complete:
            raise RuntimeError("Must run discover() before computing interaction matrix")
        
        if len(self._stable_interactions) == 0:
            return np.empty((X.shape[0], 0)), []
        
        n_samples = X.shape[0]
        interaction_features = []
        interaction_names = []
        
        # Create mapping from feature name to column index
        name_to_idx = {name: i for i, name in enumerate(self._feature_names)}
        
        for interaction in self._stable_interactions:
            # Compute product of features in interaction
            product = np.ones(n_samples)
            sorted_features = sorted(interaction)
            
            for feat_name in sorted_features:
                idx = name_to_idx[feat_name]
                product *= X[:, idx]
            
            interaction_features.append(product)
            interaction_names.append("*".join(sorted_features))
        
        return np.column_stack(interaction_features), interaction_names
    
    def print_interaction_report(self) -> None:
        """
        Print a detailed interaction discovery report.
        """
        if not self._discovery_complete:
            print("Discovery not yet performed. Run discover() first.")
            return
        
        print("=" * 70)
        print("=== Interaction Discovery Results (v4.1.1 Adaptive) ===")
        print("=" * 70)
        print()
        print("Configuration:")
        print(f"  Method used: {self._method_used}")
        print(f"  Temperature: {self.temperature}")
        print(f"  Selection threshold: {self.selection_threshold}")
        print(f"  Low-dim threshold: {self.low_dim_threshold}")
        if self._method_used == 'treeshap':
            print(f"  SHAP percentile: {self.shap_percentile}")
        print()
        print("-" * 70)
        print(" Feature Importance (Softmax Weights):")
        print("-" * 70)
        print(f"{'Feature':<20} {'Raw Importance':<15} {'Soft Weight':<15} {'Selected'}")
        print("-" * 70)
        
        # Sort by softmax weight
        sorted_indices = np.argsort(self._soft_weights)[::-1]
        for idx in sorted_indices:
            name = self._feature_names[idx]
            raw = self._raw_importance[idx]
            soft = self._soft_weights[idx]
            selected = "YES" if name in self._selected_features else "no"
            print(f"{name:<20} {raw:<15.4f} {soft:<15.4f} {selected}")
        
        print()
        print("-" * 70)
        print(" Stable Interactions:")
        print("-" * 70)
        
        if len(self._stable_interactions) == 0:
            print("  No stable interactions found.")
        else:
            # Sort by score
            sorted_interactions = sorted(
                self._stable_interactions,
                key=lambda x: self._interaction_stability.get(x, 0),
                reverse=True
            )
            
            print(f"{'Interaction':<30} {'Score':<15}")
            print("-" * 50)
            for interaction in sorted_interactions[:20]:  # Show top 20
                name = " * ".join(sorted(interaction))
                score = self._interaction_stability.get(interaction, 0)
                print(f"{name:<30} {score:<15.4f}")
            
            if len(sorted_interactions) > 20:
                print(f"  ... and {len(sorted_interactions) - 20} more interactions")
        
        print()
        print(f"Total stable interactions: {len(self._stable_interactions)}")
        print(f"Total candidate interactions: {len(self._all_interactions)}")
        print()
        print("=" * 70)


# ==============================================================================
# BACKWARD COMPATIBILITY ALIAS
# ==============================================================================
# Keep IRFInteractionDiscoverer as an alias for backward compatibility
# This allows old code that uses IRFInteractionDiscoverer to still work
IRFInteractionDiscoverer = AdaptiveInteractionDiscoverer

print("AdaptiveInteractionDiscoverer class defined (v4.1.1 - TreeSHAP Adaptive).")
print("IRFInteractionDiscoverer is now an alias for AdaptiveInteractionDiscoverer.")
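
As a quick illustration of what the product features look like, the following sketch multiplies the columns named in each interaction, in the spirit of `get_interaction_matrix`. It uses plain lists instead of NumPy arrays so that it runs without dependencies, and it returns one list per interaction rather than a stacked matrix. The data and the helper name `product_columns` are made up for the example.

```python
from math import prod


def product_columns(rows, feature_names, interactions):
    """Return (columns, names): one product column per interaction."""
    index = {name: i for i, name in enumerate(feature_names)}
    columns, names = [], []
    for interaction in interactions:
        members = sorted(interaction)  # sorted for a stable term name
        columns.append([prod(row[index[m]] for m in members) for row in rows])
        names.append("*".join(members))
    return columns, names


rows = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # two samples, three features
interactions = [frozenset({"x0", "x1"}), frozenset({"x1", "x2"})]
columns, names = product_columns(rows, ["x0", "x1", "x2"], interactions)
print(names, columns)
```

Sorting the members of each frozenset keeps the generated term names deterministic across runs.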

---
## Section 3: Internal Tests

In [None]:
# ==============================================================================
# TEST CONTROL FLAG
# ==============================================================================

_RUN_TESTS = False  # Set to True to run internal tests

if _RUN_TESTS:
    print("=" * 70)
    print(" RUNNING INTERNAL TESTS FOR 04_InteractionDiscovery v4.1.1")
    print("=" * 70)

In [None]:
# ==============================================================================
# TEST 1: Low-Dimensional Enumeration (n_features <= 10)
# ==============================================================================

if _RUN_TESTS:
    print()
    print_section_header("Test 1: Low-Dimensional Enumeration")
    
    # Generate data with 5 features (should use enumeration)
    np.random.seed(42)
    n_samples = 300
    
    x0 = np.random.uniform(0.1, 1, n_samples)
    x1 = np.random.uniform(0.1, 1, n_samples)
    x2 = np.random.uniform(0.1, 1, n_samples)
    x3 = np.random.randn(n_samples)  # Noise
    x4 = np.random.randn(n_samples)  # Noise
    
    # True equation: y = 3*x0*x1 + x2
    y = 3 * x0 * x1 + x2 + 0.1 * np.random.randn(n_samples)
    
    X = np.column_stack([x0, x1, x2, x3, x4])
    feature_names = ['x0', 'x1', 'x2', 'x3', 'x4']
    
    print("True equation: y = 3*x0*x1 + x2")
    print(f"Number of features: {X.shape[1]} (should use enumeration)")
    print()
    
    import time
    start = time.time()
    
    discoverer = AdaptiveInteractionDiscoverer(low_dim_threshold=10)
    result = discoverer.discover(X, y, feature_names)
    
    elapsed = time.time() - start
    
    print(f"Method used: {result['method_used']}")
    print(f"Time: {elapsed:.2f}s")
    print(f"Total interactions: {len(result['stable_interactions'])}")
    print("Expected: C(5,2) = 10 pairwise interactions")
    print()
    
    # Check x0*x1 is included
    x0_x1_found = any(
        frozenset(['x0', 'x1']) == interaction
        for interaction in result['stable_interactions']
    )
    
    if x0_x1_found:
        print("[PASS] True interaction (x0*x1) included in results")
    else:
        print("[FAIL] True interaction (x0*x1) NOT found")

In [None]:
# ==============================================================================
# TEST 2: Backward Compatibility (IRFInteractionDiscoverer alias)
# ==============================================================================

if _RUN_TESTS:
    print()
    print_section_header("Test 2: Backward Compatibility")
    
    # Test that IRFInteractionDiscoverer works as an alias
    discoverer_old_name = IRFInteractionDiscoverer()
    
    # Check it has the new attributes
    if hasattr(discoverer_old_name, 'low_dim_threshold'):
        print(f"[PASS] IRFInteractionDiscoverer has low_dim_threshold: {discoverer_old_name.low_dim_threshold}")
    else:
        print("[FAIL] IRFInteractionDiscoverer missing low_dim_threshold")
    
    if hasattr(discoverer_old_name, 'shap_percentile'):
        print(f"[PASS] IRFInteractionDiscoverer has shap_percentile: {discoverer_old_name.shap_percentile}")
    else:
        print("[FAIL] IRFInteractionDiscoverer missing shap_percentile")
    
    # Check class identity
    if IRFInteractionDiscoverer is AdaptiveInteractionDiscoverer:
        print("[PASS] IRFInteractionDiscoverer is AdaptiveInteractionDiscoverer")
    else:
        print("[FAIL] Classes are not the same")

---
## Section 4: Module Summary

In [None]:
# ==============================================================================
# MODULE SUMMARY
# ==============================================================================

print("=" * 70)
print(" 04_InteractionDiscovery.ipynb v4.1.1 - Module Summary")
print("=" * 70)
print()
print("CLASS: AdaptiveInteractionDiscoverer (v4.1.1 - TreeSHAP Adaptive)")
print("-" * 70)
print()
print("Purpose:")
print("  Discover feature interactions using adaptive strategy:")
print("  - Low-dim (<=10 features): Direct pairwise enumeration")
print("  - High-dim (>10 features): TreeSHAP interaction values")
print()
print("Key Innovation (v4.1.1):")
print("  - Replaces slow iRF with fast adaptive strategy")
print("  - Prioritizes HIGH RECALL (not missing true interactions)")
print("  - Fast: <1s for low-dim, 1-5min for high-dim")
print("  - False positives handled by downstream E-WSINDy")
print()
print("Backward Compatibility:")
print("  IRFInteractionDiscoverer = AdaptiveInteractionDiscoverer (alias)")
print()
print("Main Methods:")
print("  discover(X, y, feature_names)")
print("      Discover feature interactions")
print("      Returns: dict with soft_weights, stable_interactions, etc.")
print()
print("  get_stable_interactions()")
print("      Get list of stable interactions as tuples")
print()
print("  get_interaction_matrix(X)")
print("      Compute interaction features from stable interactions")
print()
print("  print_interaction_report()")
print("      Print detailed discovery report")
print()
print("Output Dictionary Keys (v4.1 compatible):")
print("  - soft_weights          : Dict of softmax-transformed weights")
print("  - selected_features     : List of features above threshold")
print("  - stable_interactions   : List of high-confidence interactions")
print("  - interaction_stability : Dict mapping interactions to scores")
print("  - raw_importance        : Dict of raw Gini importance")
print("  - suggested_terms       : List of feature product strings")
print("  - method_used           : 'enumeration', 'treeshap', or 'enumeration_fallback'")
print()
print("Key Parameters:")
print("  low_dim_threshold: Feature counts <= this use enumeration (default: 10)")
print("  shap_percentile: For TreeSHAP, percentile threshold (default: 25)")
print("  temperature: Softmax temperature for importance (default: 0.5)")
print()
print("Usage Example:")
print("-" * 70)
print("""
# Create discoverer (new name)
discoverer = AdaptiveInteractionDiscoverer(
    low_dim_threshold=10,
    shap_percentile=25
)

# Or use old name (backward compatible)
discoverer = IRFInteractionDiscoverer()  # Same as above

# Run discovery
result = discoverer.discover(X, y, feature_names)

# Get stable interactions (v4.1 key)
interactions = result['stable_interactions']
print(f"Found {len(interactions)} interactions using {result['method_used']}")
""")
print()
print("=" * 70)
print("Module loaded successfully. Import via: %run 04_InteractionDiscovery.ipynb")
print("=" * 70)