<a href="https://colab.research.google.com/github/AntoineGaton/CTU/blob/main/Decision_Trees_U5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [6]:
"""
Author: Antoine Gaton
Date: November 10, 2024
Course: Machine Learning: CS379
Description: This code implements four different splitting approaches for decision trees:
1. Reduction in Variance (for regression).
2. Information Gain (for classification).
3. Gini Impurity (for classification).
4. Chi-Square Test (for classification).

The code showcases each approach with a sample dataset and is structured for
modular use and extensibility.
"""

from sklearn.datasets import load_iris, fetch_california_housing
from sklearn.model_selection import train_test_split
import numpy as np
from math import log2
from scipy.stats import chi2_contingency
from rich import print
from rich.console import Console
from rich.panel import Panel
from rich.box import DOUBLE

# Module-wide rich console: the criterion functions print their error messages
# through it and main() prints its result panels through it.
console = Console()

def reduction_in_variance(y, y_left=None, y_right=None):
    """
    Reduction in variance splitting criterion.

    Computes ``Var(y) - (n_left / n) * Var(y_left) - (n_right / n) * Var(y_right)``
    (population variance), i.e. how much a candidate split lowers the spread
    of the target values.

    Parameters:
    - y: array-like, target values of the parent node
    - y_left: array-like or None, target values sent to the left child
    - y_right: array-like or None, target values sent to the right child

    Returns:
    - variance_reduction: float, amount of variance reduction achieved.
      Returns 0.0 when no split is supplied (both children None), because an
      unsplit node cannot reduce its own variance, and 0.0 when the parent
      has at most one value. On error (e.g. only one child given) the error
      is printed and 0 is returned, like the other criteria in this module.
    """
    try:
        y = np.asarray(y, dtype=float)
        n = len(y)
        # Check the size first: np.var of an empty array warns and returns nan.
        if n <= 1:
            return 0.0
        if y_left is None and y_right is None:
            # No split given: the node is unchanged, so nothing is reduced.
            return 0.0
        if y_left is None or y_right is None:
            raise ValueError("y_left and y_right must be provided together")

        def weighted_variance(part):
            part = np.asarray(part, dtype=float)
            # An empty child carries zero weight (and avoids var([]) -> nan).
            return (len(part) / n) * np.var(part) if len(part) else 0.0

        return float(np.var(y) - weighted_variance(y_left) - weighted_variance(y_right))
    except Exception as e:
        console.print(f"[red]Error in reduction_in_variance: {e}[/red]")
        return 0


def information_gain(y, y_left, y_right):
    """
    Information gain splitting criterion based on entropy.

    Gain = H(y) - (n_left / n) * H(y_left) - (n_right / n) * H(y_right),
    with Shannon entropy H measured in bits.

    Parameters:
    - y: array-like, target values (any sortable labels: ints, strings, ...)
    - y_left: array-like, left split
    - y_right: array-like, right split

    Returns:
    - info_gain: float, amount of information gained. An empty parent yields
      0.0; on any other error the error is printed and 0 is returned.
    """
    try:
        n = len(y)
        if n == 0:
            return 0.0

        def entropy(vals):
            # np.unique counts any sortable labels (negative ints, strings),
            # unlike np.bincount, which only accepts non-negative integers.
            _, counts = np.unique(np.asarray(vals), return_counts=True)
            total = counts.sum()
            if total == 0:
                return 0.0
            # Every count from np.unique is > 0, so log2 is always defined.
            return -sum(p * log2(p) for p in counts / total)

        entropy_before = entropy(y)
        entropy_after = (len(y_left) / n) * entropy(y_left) + (len(y_right) / n) * entropy(y_right)
        return float(entropy_before - entropy_after)
    except Exception as e:
        console.print(f"[red]Error in information_gain: {e}[/red]")
        return 0


def gini_impurity(y):
    """
    Gini impurity splitting criterion.

    Gini = 1 - sum_k p_k ** 2, where p_k is the fraction of samples with label k.

    Parameters:
    - y: array-like, target values (any sortable labels: ints, strings, ...)

    Returns:
    - gini: float, gini impurity value (0.0 for an empty or pure node). On
      error the error is printed and 0 is returned.
    """
    try:
        # np.unique handles negative/string labels that np.bincount rejects.
        _, counts = np.unique(np.asarray(y), return_counts=True)
        total = counts.sum()
        if total == 0:
            # Empty node: define impurity as 0 instead of dividing by zero.
            return 0.0
        probs = counts / total
        return float(1 - sum(p ** 2 for p in probs))
    except Exception as e:
        console.print(f"[red]Error in gini_impurity: {e}[/red]")
        return 0


def chi_square_test(y, y_pred):
    """
    Chi-square test splitting criterion.

    Builds the contingency table of ``y`` (rows, one per distinct value)
    against ``y_pred`` (columns, one per distinct value) and runs
    scipy.stats.chi2_contingency on it. For a tree split, ``y_pred`` can be
    the branch (e.g. a boolean mask) each sample is routed to.

    Parameters:
    - y: array-like, true target values
    - y_pred: array-like, predicted target values / split assignment, same
      length as ``y``

    Returns:
    - chi2: float, chi-square statistic
    - p: float, p-value
    On error (e.g. mismatched lengths, a degenerate table) the error is
    printed and (0, 1) is returned.
    """
    try:
        # Convert explicitly so list inputs are compared element-wise.
        y_arr = np.asarray(y).ravel()
        pred_arr = np.asarray(y_pred).ravel()
        if y_arr.shape != pred_arr.shape:
            raise ValueError(
                f"y and y_pred must have the same length ({y_arr.size} != {pred_arr.size})"
            )
        # Map each value to its row/column index, then count all pairs in one
        # pass instead of one full comparison per (label, pred) cell.
        row_labels, row_idx = np.unique(y_arr, return_inverse=True)
        col_labels, col_idx = np.unique(pred_arr, return_inverse=True)
        contingency_table = np.zeros((len(row_labels), len(col_labels)))
        np.add.at(contingency_table, (row_idx, col_idx), 1)
        chi2, p, _, _ = chi2_contingency(contingency_table)
        return float(chi2), float(p)
    except Exception as e:
        console.print(f"[red]Error in chi_square_test: {e}[/red]")
        return 0, 1


def main():
    """
    Main function to demonstrate each splitting criterion.

    Loads Iris (classification) and California Housing (regression), keeps a
    50% sample of each, and prints one rich panel per criterion followed by a
    summary. Information Gain, Gini and Chi-Square are all evaluated on the
    same candidate split, "petal length <= sample median", so they describe
    one concrete split of real data.
    """
    # Load datasets
    iris = load_iris()
    cali_housing = fetch_california_housing()

    # Sample data for classification (Iris) and regression (California Housing)
    X_class, _, y_class, _ = train_test_split(iris.data, iris.target, test_size=0.5, random_state=42)
    _, _, y_reg, _ = train_test_split(cali_housing.data, cali_housing.target, test_size=0.5, random_state=42)

    # Candidate split on an actual feature. Splitting by row position (first
    # half / second half) says nothing, because the sampled rows come out of
    # train_test_split in shuffled order.
    feature_idx = iris.feature_names.index("petal length (cm)")
    feature = X_class[:, feature_idx]
    goes_left = feature <= np.median(feature)
    y_left, y_right = y_class[goes_left], y_class[~goes_left]

    # Information Gain and Explanation
    ig = information_gain(y_class, y_left, y_right)
    ig_text = f"[green]Information Gain:[/green] {ig}\n"
    if ig > 0.1:
        ig_text += "[yellow]This information gain is relatively high, indicating a significant improvement in purity.[/yellow]"
    else:
        ig_text += "[yellow]This information gain is relatively low, suggesting only a modest improvement in purity.[/yellow]"
    console.print(Panel(ig_text, title="Information Gain", box=DOUBLE))

    # Gini Impurity and Explanation
    gini = gini_impurity(y_class)
    n = len(y_class)
    gini_after = (len(y_left) / n) * gini_impurity(y_left) + (len(y_right) / n) * gini_impurity(y_right)
    # With k classes Gini can never exceed 1 - 1/k (2/3 for Iris), so fixed
    # cut-offs such as 0.7 would make the "high" branch unreachable; interpret
    # the value relative to that maximum instead.
    n_classes = len(np.unique(y_class))
    max_gini = 1 - 1 / n_classes if n_classes else 0.0
    relative_gini = gini / max_gini if max_gini > 0 else 0.0
    gini_text = (
        f"[green]Gini Impurity:[/green] {gini} (maximum for {n_classes} classes: {max_gini:.4f})\n"
        f"[green]Weighted Gini After Split:[/green] {gini_after}\n"
    )
    if relative_gini < 0.3:
        gini_text += "[yellow]Low Gini impurity indicates a relatively pure node.[/yellow]"
    elif relative_gini < 0.7:
        gini_text += "[yellow]Moderate Gini impurity suggests some mix of classes in this node.[/yellow]"
    else:
        gini_text += "[yellow]High Gini impurity indicates a mixed node with multiple classes.[/yellow]"
    console.print(Panel(gini_text, title="Gini Impurity", box=DOUBLE))

    # Chi-Square Test and Explanation
    # Test class label against the side of the split each sample falls on;
    # testing y_class against itself would be trivially significant.
    chi2, p_val = chi_square_test(y_class, goes_left)
    chi2_text = f"[green]Chi-Square Statistic:[/green] {chi2}\n[green]P-value:[/green] {p_val}\n"
    if p_val < 0.05:
        chi2_text += "[yellow]The low p-value indicates a statistically significant association between the feature and the target.[/yellow]"
    else:
        chi2_text += "[yellow]The high p-value suggests no significant association between the feature and the target.[/yellow]"
    console.print(Panel(chi2_text, title="Chi-Square Test", box=DOUBLE))

    # Reduction in Variance and Explanation
    variance_reduction = reduction_in_variance(y_reg)
    variance_text = f"[green]Reduction in Variance:[/green] {variance_reduction}\n"
    if variance_reduction > 0:
        variance_text += "[yellow]This reduction in variance suggests that the split improves the homogeneity of the target values.[/yellow]"
    else:
        variance_text += "[yellow]A reduction in variance of zero indicates no improvement in homogeneity from this split.[/yellow]"
    console.print(Panel(variance_text, title="Reduction in Variance", box=DOUBLE))

    # Summary
    summary_text = (
        "[bold magenta]Summary of Results[/bold magenta]\n"
        "The classification criteria suggest:\n"
        "- Information Gain and Gini Impurity indicate the purity levels after a split.\n"
        "- Chi-Square results show the statistical association between features and the target variable.\n"
        "The regression criterion (Reduction in Variance) assesses the effectiveness of the split in homogenizing target values in regression."
    )
    console.print(Panel(summary_text, title="Summary", box=DOUBLE))

# Run the demonstration only when this code is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
