<a href="https://colab.research.google.com/github/AmayaDes/Neuro-Symbolic_PCOS-Detection-FYP/blob/main/notebooks/04_Rule_Extraction.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
"""
PCOS Rule Extraction using Decision Trees

Uses top 6 features based on Cohen's d effect sizes
- follicle_count (d=2.900)
- follicle_density (d=2.900)
- peripheral_distribution (d=1.635)
- avg_follicle_size (d=1.397)
- stromal_echogenicity (d=1.267)
- ovarian_circularity (d=1.087)

Classes:
- RuleExtractor: Trains decision tree and extracts symbolic rules
- RuleApplier: Applies extracted rules to generate rule scores
- RuleVisualizer: Visualizes decision tree and rule importance
"""

import os
import pandas as pd
import numpy as np
from sklearn.tree import DecisionTreeClassifier, export_text, plot_tree
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc
import matplotlib.pyplot as plt
import seaborn as sns
import json
import joblib

# ========================
# CONFIGURATION
# ========================

# Directory that receives every artifact written by the classes below
ORGANIZED_DIR = "/content/"

# Root URL of the raw results folder that hosts the per-split feature CSVs
GITHUB_BASE_URL = "https://raw.githubusercontent.com/AmayaDes/Neuro-Symbolic_PCOS-Detection-FYP/main/results/"

# ========================
# LOAD FEATURE CSVs FROM GITHUB
# ========================


def _read_feature_split(split):
    """Download `features_<split>.csv` from GITHUB_BASE_URL as a DataFrame."""
    return pd.read_csv(f"{GITHUB_BASE_URL}features_{split}.csv")


print("\n Downloading feature CSVs from GitHub...")

# Splits are fetched one after another: train, then val, then test
train_features_df = _read_feature_split("train")
val_features_df = _read_feature_split("val")
test_features_df = _read_feature_split("test")


print(" Loaded feature CSVs from GitHub:")
for _split_label, _split_frame in (
    ("Train:", train_features_df),
    ("Val:  ", val_features_df),
    ("Test: ", test_features_df),
):
    print(f" {_split_label} {len(_split_frame)} samples")

# Columns the rule extraction stage cannot work without
required_features = [
    'follicle_count',
    'follicle_density',
    'avg_follicle_size',
    'peripheral_distribution',
    'stromal_echogenicity',
    'ovarian_circularity'
]

print("\n Verifying required features...")
missing_features = [
    column_name
    for column_name in required_features
    if column_name not in train_features_df.columns
]

# Only the training table is inspected here; abort early when it is incomplete
if missing_features:
    raise ValueError(f" Missing required features: {missing_features}")

print(f" All {len(required_features)} required features present")


# ========================
# CLASS 1: RuleExtractor
# ========================

class RuleExtractor:
    """
    Extracts symbolic rules from PCOS features using Decision Tree

    The methods are meant to be called in order
    (load_training_features -> train_decision_tree -> extract_rules /
    get_feature_importance / save_model); every step returns ``self`` (or a
    DataFrame for get_feature_importance) so calls can be chained.

    Attributes:
        organized_dir (str): Path to organized data (output directory)
        train_df (DataFrame): Training features
        pcos_features (list): List of PCOS feature names
        decision_tree (DecisionTreeClassifier): Trained decision tree
            (None until train_decision_tree has run)
        rules_text (str): Extracted rules in text format
            (None until extract_rules has run)
        max_depth (int): Maximum depth of decision tree
    """

    # Single source of truth for the Cohen's d effect size of each feature.
    # Insertion order is also the default feature order used for training.
    FEATURE_COHENS_D = {
        'follicle_count': 2.900,
        'follicle_density': 2.900,
        'peripheral_distribution': 1.635,
        'avg_follicle_size': 1.397,
        'stromal_echogenicity': 1.267,
        'ovarian_circularity': 1.087,
    }

    def __init__(self, train_features_df, organized_dir, max_depth=4):
        """
        Initialize RuleExtractor

        Args:
            train_features_df (DataFrame): Training features (loaded from GitHub)
            organized_dir (str): Path for saving outputs
            max_depth (int): Maximum depth for decision tree (for interpretability)
        """
        self.organized_dir = organized_dir
        self.max_depth = max_depth
        self.train_df = train_features_df
        self.decision_tree = None
        self.rules_text = None

        # from feature extraction results
        self.pcos_features = list(self.FEATURE_COHENS_D)

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------

    def _require_tree(self):
        """Return the fitted tree, or raise RuntimeError if training has not run."""
        if self.decision_tree is None:
            raise RuntimeError(
                "Decision tree has not been trained; call train_decision_tree() first"
            )
        return self.decision_tree

    def _ensure_output_dir(self):
        """Create the output directory if it does not exist yet."""
        # An empty path means "current directory", which needs no creation
        if self.organized_dir:
            os.makedirs(self.organized_dir, exist_ok=True)

    @staticmethod
    def _effect_size_label(d, strongest):
        """Describe an effect size; the largest value is tagged STRONGEST."""
        if d == strongest:
            return "STRONGEST"
        magnitude = abs(d)
        # Conventional Cohen's d bands (0.2 / 0.5 / 0.8)
        if magnitude >= 0.8:
            return "Large"
        if magnitude >= 0.5:
            return "Medium"
        if magnitude >= 0.2:
            return "Small"
        return "Negligible"

    def _build_rules_report(self):
        """Return the complete text of the rules file, including the effect-size table."""
        divider = "=" * 70 + "\n\n"

        # Only features that are both in use and have a recorded effect size
        known = {
            feat: self.FEATURE_COHENS_D[feat]
            for feat in self.pcos_features
            if feat in self.FEATURE_COHENS_D
        }
        strongest = max(known.values(), default=None)

        effect_lines = []
        for feat, d in known.items():
            heading = feat + ":"
            label = self._effect_size_label(d, strongest)
            effect_lines.append(f"- {heading:<25} {d:.3f} ({label})\n")

        return "".join([
            "PCOS DIAGNOSTIC RULES\n",
            divider,
            "Rules extracted from Decision Tree trained on optimized features\n",
            "(Blob detection threshold: 0.130, Cohen's d = 2.691)\n\n",
            divider,
            "FEATURE EFFECT SIZES (Cohen's d):\n",
            *effect_lines,
            "\n",
            divider,
            "DECISION TREE RULES:\n\n",
            self.rules_text,
        ])

    # ------------------------------------------------------------------
    # Public pipeline steps
    # ------------------------------------------------------------------

    def load_training_features(self):
        """
        Verify training features are loaded

        Returns:
            RuleExtractor: self, for chaining

        Raises:
            ValueError: if the 'label' column or any feature in
                pcos_features is missing from the training DataFrame
        """
        print("="*70)
        print("VERIFYING TRAINING FEATURES FOR RULE EXTRACTION")
        print("="*70)

        # Fail with the same exception type used for missing features
        if 'label' not in self.train_df.columns:
            raise ValueError(" Required column not found: label")

        print(f"\n Using {len(self.train_df)} training samples")
        print(f"   PCOS:   {(self.train_df['label']==1).sum()}")
        print(f"   Normal: {(self.train_df['label']==0).sum()}")

        print(f"\n Using {len(self.pcos_features)} features:")

        for feat in self.pcos_features:
            if feat not in self.train_df.columns:
                raise ValueError(f" Required feature not found: {feat}")
            # Features without a recorded effect size are reported as nan
            cohens_d = self.FEATURE_COHENS_D.get(feat, float('nan'))
            print(f"   - {feat:<25} (d={cohens_d:.3f})")

        return self

    def train_decision_tree(self):
        """
        Train decision tree on PCOS features

        Fits a class-weight-balanced DecisionTreeClassifier on
        train_df[pcos_features] against train_df['label'] and prints
        training-set accuracy and AUC.

        Returns:
            RuleExtractor: self, for chaining
        """
        print("\n" + "="*70)
        print("TRAINING DECISION TREE")
        print("="*70)

        # Prepare features and labels
        X_train = self.train_df[self.pcos_features]
        y_train = self.train_df['label']

        # Train decision tree
        self.decision_tree = DecisionTreeClassifier(
            max_depth=self.max_depth,
            min_samples_split=20,
            min_samples_leaf=10,
            random_state=42,
            class_weight='balanced'     # Handle class imbalance
        )

        self.decision_tree.fit(X_train, y_train)

        # Evaluate on training set
        train_pred = self.decision_tree.predict(X_train)
        train_accuracy = (train_pred == y_train).mean()

        # Get probabilities for AUC
        train_proba = self.decision_tree.predict_proba(X_train)[:, 1]
        fpr, tpr, _ = roc_curve(y_train, train_proba)
        train_auc = auc(fpr, tpr)

        print(f"\n Decision Tree trained successfully")
        print(f"   Max depth:         {self.max_depth}")
        print(f"   Training accuracy: {train_accuracy:.3f}")
        print(f"   Training AUC:      {train_auc:.3f}")
        print(f"   Number of leaves:  {self.decision_tree.get_n_leaves()}")
        print(f"   Tree depth:        {self.decision_tree.get_depth()}")

        return self

    def extract_rules(self):
        """
        Extract symbolic rules from trained decision tree

        Stores the textual rules in rules_text, prints them, and writes them
        (with an effect-size header) to pcos_decision_tree_rules.txt inside
        organized_dir.

        Returns:
            RuleExtractor: self, for chaining

        Raises:
            RuntimeError: if the tree has not been trained yet
        """
        tree = self._require_tree()

        print("\n" + "="*70)
        print("EXTRACTING SYMBOLIC RULES")
        print("="*70)

        # Extract rules as text
        self.rules_text = export_text(
            tree,
            feature_names=self.pcos_features
        )

        print("\n Extracted Rules:")
        print(self.rules_text)

        # Save rules to file
        self._ensure_output_dir()
        rules_path = os.path.join(self.organized_dir, 'pcos_decision_tree_rules.txt')
        with open(rules_path, 'w', encoding='utf-8') as f:
            f.write(self._build_rules_report())

        print(f"\n Rules saved to: pcos_decision_tree_rules.txt")

        return self

    def get_feature_importance(self):
        """
        Get feature importance from decision tree

        Returns:
            DataFrame: columns 'feature', 'importance', 'cohens_d', sorted by
                importance in descending order. cohens_d is looked up per
                feature name (nan when no effect size is recorded).

        Raises:
            RuntimeError: if the tree has not been trained yet
        """
        tree = self._require_tree()

        importance_df = pd.DataFrame({
            'feature': self.pcos_features,
            'importance': tree.feature_importances_,
            # Looked up by name so the values stay aligned with pcos_features
            'cohens_d': [
                self.FEATURE_COHENS_D.get(feat, float('nan'))
                for feat in self.pcos_features
            ]
        }).sort_values('importance', ascending=False)

        return importance_df

    def save_model(self):
        """
        Save trained decision tree model

        Writes decision_tree_model.pkl into organized_dir using joblib.

        Returns:
            RuleExtractor: self, for chaining

        Raises:
            RuntimeError: if the tree has not been trained yet
        """
        tree = self._require_tree()

        self._ensure_output_dir()
        model_path = os.path.join(self.organized_dir, 'decision_tree_model.pkl')
        joblib.dump(tree, model_path)
        print(f"\n Model saved to: decision_tree_model.pkl")
        return self


# ========================
# CLASS 2: RuleApplier
# ========================

class RuleApplier:
    """
    Scores pre-loaded feature tables with a fitted decision tree

    Attributes:
        decision_tree (DecisionTreeClassifier): Trained decision tree
        organized_dir (str): Directory that receives the scored CSV files
        pcos_features (list): Column names passed to the tree
        feature_dfs (dict): Split name -> feature DataFrame
    """

    def __init__(self, decision_tree, organized_dir, pcos_features, feature_dfs):
        """
        Store the model, the output location and the data to score

        Args:
            decision_tree: Trained DecisionTreeClassifier
            organized_dir: Path for saving outputs
            pcos_features: List of feature names
            feature_dfs: Dict with 'train', 'val', 'test' DataFrames
        """
        self.feature_dfs = feature_dfs
        self.pcos_features = pcos_features
        self.organized_dir = organized_dir
        self.decision_tree = decision_tree

    def apply_rules_to_split(self, split_name):
        """
        Score one split and write it to features_<split>_with_rules.csv

        Args: split_name (str): 'train', 'val', or 'test'
        Returns: copy of the split's DataFrame with 'dt_rule_prediction'
            and 'dt_rule_score' columns appended
        """

        print(f"\n Applying rules to {split_name} set...")

        # Work on a copy so the caller's DataFrame is left untouched
        scored = self.feature_dfs[split_name].copy()
        model_inputs = scored[self.pcos_features]

        # Hard class decision first, then the probability of class index 1
        scored['dt_rule_prediction'] = self.decision_tree.predict(model_inputs)
        class_probabilities = self.decision_tree.predict_proba(model_inputs)
        scored['dt_rule_score'] = class_probabilities[:, 1]

        csv_name = f'features_{split_name}_with_rules.csv'
        scored.to_csv(os.path.join(self.organized_dir, csv_name), index=False)

        print(f" Saved: {csv_name}")

        # Without ground truth there is nothing to evaluate
        if 'label' not in scored.columns:
            return scored

        truth = scored['label']
        accuracy = (scored['dt_rule_prediction'] == truth).mean()
        fpr, tpr, _ = roc_curve(truth, scored['dt_rule_score'])
        rule_auc = auc(fpr, tpr)

        print(f"  Accuracy: {accuracy:.3f}")
        print(f"  AUC:      {rule_auc:.3f}")

        return scored

    def apply_rules_to_all_splits(self):
        """
        Score the train, val and test splits in that order

        A split that raises is reported (message plus traceback) and left
        out of the result; the remaining splits are still processed.

        Returns: dict mapping each successfully processed split name to its
            scored DataFrame
        """
        import traceback

        print("\n" + "="*70)
        print("APPLYING RULES TO ALL SPLITS")
        print("="*70)

        results = {}
        for split in ('train', 'val', 'test'):
            try:
                scored = self.apply_rules_to_split(split)
            except Exception as e:
                # Best effort: report the failure and move on to the next split
                print(f"\n Error processing {split}: {e}")
                traceback.print_exc()
            else:
                results[split] = scored

        return results


# ========================
# CLASS 3: RuleVisualizer
# ========================

class RuleVisualizer:
    """
    Visualizes decision tree rules and performance

    Every plot/statistics method writes one file into organized_dir and
    prints a short confirmation.

    Attributes:
        decision_tree: Trained decision tree
        organized_dir: Output directory
        pcos_features: Feature names
    """

    # Label values and their display names, in matching order. Passing the
    # label list explicitly to sklearn keeps the confusion matrix 2x2 and the
    # report keyed by both names even if a split contains a single class.
    CLASS_LABELS = [0, 1]
    CLASS_NAMES = ['Normal', 'PCOS']

    def __init__(self, decision_tree, organized_dir, pcos_features):
        """Initialize RuleVisualizer"""
        self.decision_tree = decision_tree
        self.organized_dir = organized_dir
        self.pcos_features = pcos_features

    @staticmethod
    def _roc_metrics(df):
        """Return (fpr, tpr, auc) for df['dt_rule_score'] against df['label']."""
        fpr, tpr, _ = roc_curve(df['label'], df['dt_rule_score'])
        return fpr, tpr, auc(fpr, tpr)

    def plot_decision_tree(self):
        """
        Visualize the decision tree structure

        Saves decision_tree_visualization.png to organized_dir.
        """

        print("\n Creating decision tree visualization...")

        fig = plt.figure(figsize=(25, 12))
        plot_tree(
            self.decision_tree,
            feature_names=self.pcos_features,
            class_names=self.CLASS_NAMES,
            filled=True,
            rounded=True,
            fontsize=10,
            proportion=True
        )
        plt.title('PCOS Decision Tree Rules\n(Optimized Features: follicle_count d=2.900, threshold=0.130)',
                 fontsize=18, fontweight='bold', pad=20)
        plt.tight_layout()

        output_path = os.path.join(self.organized_dir, 'decision_tree_visualization.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"    Saved: decision_tree_visualization.png")
        # Close this specific figure rather than whichever one is current
        plt.close(fig)

    def plot_feature_importance(self, importance_df):
        """
        Plot feature importance from decision tree

        Args:
            importance_df (DataFrame): must provide the columns 'feature',
                'importance' and 'cohens_d'; rows are drawn in the given order

        Saves rule_feature_importance.png to organized_dir.
        """
        print(" Creating feature importance plot...")

        fig, ax = plt.subplots(figsize=(12, 7))

        # Scale importances to [0, 1] for the colormap; a tree without any
        # split has all-zero importances, so avoid dividing by zero there
        importances = importance_df['importance']
        peak = importances.max()
        scaled = importances / peak if peak > 0 else importances * 0.0
        colors = plt.cm.RdYlGn(scaled.to_numpy())

        # Create bars
        bars = ax.barh(importance_df['feature'], importances,
                       color=colors, edgecolor='black', linewidth=1.5)

        ax.set_xlabel('Importance', fontsize=12, fontweight='bold')
        ax.set_title('Feature Importance in Decision Tree Rules\n(Top 6 Features by Cohen\'s d)',
                 fontsize=14, fontweight='bold')
        ax.grid(axis='x', alpha=0.3)

        # Add value labels with Cohen's d
        for bar, cohens_d in zip(bars, importance_df['cohens_d']):
            width = bar.get_width()
            ax.text(width + 0.01, bar.get_y() + bar.get_height()/2,
                    f'{width:.3f} (d={cohens_d:.2f})',
                    ha='left', va='center', fontsize=10, fontweight='bold')

        plt.tight_layout()

        output_path = os.path.join(self.organized_dir, 'rule_feature_importance.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f" Saved: rule_feature_importance.png")
        plt.close(fig)

    def plot_rule_performance(self, train_df, val_df, test_df):
        """
        Plot rule performance across splits

        Draws a confusion matrix (top row) and a ROC curve (bottom row) for
        each split and saves the figure as rule_performance.png.

        Args:
            train_df, val_df, test_df (DataFrame): each must contain the
                columns 'label', 'dt_rule_prediction' and 'dt_rule_score'
        """

        print("Creating rule performance comparison...")

        splits_data = [
            ('Train', train_df),
            ('Val', val_df),
            ('Test', test_df)
        ]

        fig, axes = plt.subplots(2, 3, figsize=(18, 10))

        for idx, (split_name, df) in enumerate(splits_data):
            # Confusion matrix; explicit labels guarantee a 2x2 result that
            # matches the two tick labels below
            cm = confusion_matrix(df['label'], df['dt_rule_prediction'],
                                  labels=self.CLASS_LABELS)

            sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=axes[0, idx],
                       xticklabels=self.CLASS_NAMES,
                       yticklabels=self.CLASS_NAMES,
                       cbar=False, annot_kws={'fontsize': 14, 'fontweight': 'bold'})

            accuracy = (df['dt_rule_prediction'] == df['label']).mean()
            fpr, tpr, rule_auc = self._roc_metrics(df)

            axes[0, idx].set_title(f'{split_name} Set\nAcc: {accuracy:.3f}, AUC: {rule_auc:.3f}',
                                  fontweight='bold', fontsize=12)
            axes[0, idx].set_ylabel('True Label', fontsize=11)
            axes[0, idx].set_xlabel('Rule Prediction', fontsize=11)

            # ROC curve
            axes[1, idx].plot(fpr, tpr, linewidth=2.5, label=f'AUC = {rule_auc:.3f}', color='#2196F3')
            axes[1, idx].plot([0, 1], [0, 1], 'k--', linewidth=1.5, label='Random', alpha=0.5)
            axes[1, idx].set_xlabel('False Positive Rate', fontsize=11)
            axes[1, idx].set_ylabel('True Positive Rate', fontsize=11)
            axes[1, idx].set_title(f'{split_name} ROC Curve', fontweight='bold', fontsize=12)
            axes[1, idx].legend(fontsize=10)
            axes[1, idx].grid(alpha=0.3)
            axes[1, idx].set_xlim([-0.05, 1.05])
            axes[1, idx].set_ylim([-0.05, 1.05])

        plt.suptitle('Decision Tree Rule Performance\n(Optimized Features: follicle_count d=2.900, threshold=0.130)',
                     fontsize=16, fontweight='bold')
        plt.tight_layout()

        output_path = os.path.join(self.organized_dir, 'rule_performance.png')
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f" Saved: rule_performance.png")
        plt.close(fig)

    def save_rule_statistics(self, train_df, val_df, test_df):
        """
        Save rule performance statistics

        Writes a per-split classification report (plus AUC) to
        rule_performance_stats.json and prints accuracy, PCOS F1 and AUC
        for each split.

        Args:
            train_df, val_df, test_df (DataFrame): each must contain the
                columns 'label', 'dt_rule_prediction' and 'dt_rule_score'
        """
        print(" Saving rule statistics...")

        stats = {}

        for split_name, df in [('train', train_df), ('val', val_df), ('test', test_df)]:
            # Classification report; explicit labels keep both class entries
            # present even when one class is absent from a split
            report = classification_report(
                df['label'],
                df['dt_rule_prediction'],
                labels=self.CLASS_LABELS,
                target_names=self.CLASS_NAMES,
                output_dict=True
            )

            # The report may omit 'accuracy' when the given labels are not
            # exactly the labels present; the summary below relies on it
            report.setdefault(
                'accuracy',
                float((df['dt_rule_prediction'] == df['label']).mean())
            )

            # Add AUC
            _, _, rule_auc = self._roc_metrics(df)
            report['auc'] = float(rule_auc)

            stats[split_name] = report

        output_path = os.path.join(self.organized_dir, 'rule_performance_stats.json')
        with open(output_path, 'w', encoding='utf-8') as f:
            json.dump(stats, f, indent=4)

        print(f" Saved: rule_performance_stats.json")

        # Print summary

        print("\n" + "="*70)
        print("RULE PERFORMANCE SUMMARY")
        print("="*70)

        for split_name, report in stats.items():
            acc = report['accuracy']
            pcos_f1 = report['PCOS']['f1-score']
            auc_score = report['auc']

            print(f"\n{split_name.upper()}:")
            print(f"  Accuracy:     {acc:.3f}")
            print(f"  PCOS F1:      {pcos_f1:.3f}")
            print(f"  AUC:          {auc_score:.3f}")


# ========================
# MAIN EXECUTION
# ========================

if __name__ == "__main__":
    # Pipeline driver: train the tree and export its rules, score every
    # split, draw the plots, then report which artifacts this run produced.

    print("="*70)
    print("PCOS RULE EXTRACTION - DECISION TREE")
    print("="*70)

    # Step 1: Extract rules using GitHub-loaded training features
    extractor = RuleExtractor(train_features_df, ORGANIZED_DIR, max_depth=4)
    extractor.load_training_features()
    extractor.train_decision_tree()
    extractor.extract_rules()
    importance_df = extractor.get_feature_importance()
    extractor.save_model()

    print("\n" + "="*70)
    print("FEATURE IMPORTANCE IN RULES")
    print("="*70)
    print(importance_df.to_string(index=False))

    # Step 2: Apply rules to all splits using GitHub-loaded DataFrames
    feature_dfs = {
        'train': train_features_df,
        'val': val_features_df,
        'test': test_features_df
    }

    applier = RuleApplier(
        extractor.decision_tree,
        ORGANIZED_DIR,
        extractor.pcos_features,
        feature_dfs
    )
    # Splits that failed are absent from the returned dict
    results = applier.apply_rules_to_all_splits()

    expected_splits = ('train', 'val', 'test')
    failed_splits = [split for split in expected_splits if split not in results]
    all_splits_scored = not failed_splits

    # Step 3: Visualize rules
    if results:
        visualizer = RuleVisualizer(
            extractor.decision_tree,
            ORGANIZED_DIR,
            extractor.pcos_features
        )
        visualizer.plot_decision_tree()
        visualizer.plot_feature_importance(importance_df)

        # Performance plots and statistics need all three scored splits
        if all_splits_scored:
            visualizer.plot_rule_performance(
                results['train'],
                results['val'],
                results['test']
            )
            visualizer.save_rule_statistics(
                results['train'],
                results['val'],
                results['test']
            )

    # Summary: list only the artifacts that the steps above actually wrote
    created_files = [
        f"features_{split}_with_rules.csv"
        for split in expected_splits
        if split in results
    ]
    # Step 1 raises on failure, so reaching this point means both were written
    created_files += [
        "decision_tree_model.pkl",
        "pcos_decision_tree_rules.txt",
    ]
    if results:
        created_files += [
            "decision_tree_visualization.png",
            "rule_feature_importance.png",
        ]
        if all_splits_scored:
            created_files += [
                "rule_performance.png",
                "rule_performance_stats.json",
            ]

    print("\n" + "="*70)
    print(" RULE EXTRACTION COMPLETE!")
    print("="*70)

    print("\n Created files:")
    for created_name in created_files:
        print(f"   {created_name}")

    if failed_splits:
        print(f"\n Splits that could not be processed: {', '.join(failed_splits)}")


 Downloading feature CSVs from GitHub...
 Loaded feature CSVs from GitHub:
 Train: 2782 samples
 Val:   596 samples
 Test:  597 samples

 Verifying required features...
 All 6 required features present
PCOS RULE EXTRACTION - DECISION TREE
VERIFYING TRAINING FEATURES FOR RULE EXTRACTION

 Using 2782 training samples
   PCOS:   1162
   Normal: 1620

 Using 6 
   - follicle_count            (d=2.900)
   - follicle_density          (d=2.900)
   - peripheral_distribution   (d=1.635)
   - avg_follicle_size         (d=1.397)
   - stromal_echogenicity      (d=1.267)
   - ovarian_circularity       (d=1.087)

TRAINING DECISION TREE

 Decision Tree trained successfully
   Max depth:         4
   Training accuracy: 0.982
   Training AUC:      0.992
   Number of leaves:  10
   Tree depth:        4

EXTRACTING SYMBOLIC RULES

 Extracted Rules:
|--- follicle_count <= 35.50
|   |--- peripheral_distribution <= 0.96
|   |   |--- ovarian_circularity <= 0.13
|   |   |   |--- stromal_echogenicity <= 69.39