# CART Decision Tree Classifier for Iris Dataset

This notebook implements the CART (Classification and Regression Trees) algorithm from scratch to classify iris flowers.

## Algorithm Overview:
- Uses **Gini Impurity** to measure node purity and select the best splits
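- CART can handle both continuous and categorical features; the implementation below supports numeric (continuous) features only, using threshold splits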
- Creates binary splits at each node
- More flexible than ID3 as it can work with numerical data directly

In [None]:
# Import necessary libraries
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
import seaborn as sns
from collections import Counter

In [None]:
# Load and prepare the iris dataset
def load_iris_data():
    """Build a pandas DataFrame from scikit-learn's bundled iris dataset.

    Returns:
        A DataFrame holding one column per entry in ``feature_names`` plus a
        ``species`` column containing the class name of every row.
    """
    bunch = load_iris()

    # Map each integer class id to its readable name via array indexing.
    species_names = bunch.target_names[bunch.target]

    frame = pd.DataFrame(data=bunch.data, columns=bunch.feature_names)
    frame['species'] = species_names
    return frame

# Load the data and print a quick overview of its shape, contents and classes
iris_df = load_iris_data()
print("Iris Dataset Shape:", iris_df.shape)
print("\nFirst 5 rows:")
print(iris_df.head())
print("\nDataset Info:")
# DataFrame.info() writes its report to stdout itself and returns None, so
# wrapping it in print() would add a stray "None" line after the report.
iris_df.info()
print("\nTarget distribution:")
print(iris_df['species'].value_counts())

In [None]:
# CART Decision Tree Implementation
class CARTDecisionTree:
    """Simple CART classifier using binary threshold splits chosen by Gini impurity.

    Only numeric features are supported: candidate thresholds are the midpoints
    between consecutive distinct values of each feature.

    The fitted tree is stored in ``self.tree``. Internal nodes are dicts with
    the keys ``feature``, ``threshold``, ``left``, ``right``, ``gini`` (the
    weighted impurity of the two children produced by the split), ``samples``
    and ``class_distribution``. Leaves are bare class labels of any non-dict
    type, so string and numeric labels both work.
    """

    def __init__(self, max_depth=10, min_samples_split=2, min_samples_leaf=1):
        """Store the stopping criteria.

        Args:
            max_depth: Nodes at this depth (the root is depth 0) are not split.
            min_samples_split: Nodes with fewer samples than this become leaves.
            min_samples_leaf: Minimum number of samples each child must receive.
        """
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_samples_leaf = min_samples_leaf
        self.tree = None
        self.feature_names = None

    def _fitted_tree(self):
        """Return the trained tree, raising ValueError if ``fit`` was never called."""
        if self.tree is None:
            raise ValueError("This CARTDecisionTree is not fitted yet; call fit() first")
        return self.tree

    def gini_impurity(self, labels):
        """Return the Gini impurity ``1 - sum(p_k ** 2)`` of ``labels``.

        An empty collection is treated as pure and yields 0.
        """
        # len() rather than truthiness: ``labels`` is usually a pandas Series,
        # whose truth value is ambiguous.
        total = len(labels)
        if total == 0:
            return 0

        gini = 1.0
        for count in Counter(labels).values():
            p = count / total
            gini -= p ** 2

        return gini

    def find_best_split(self, data, target_col):
        """Search every feature/threshold pair for the lowest weighted Gini.

        Args:
            data: DataFrame holding the feature columns and ``target_col``.
            target_col: Name of the label column inside ``data``.

        Returns:
            ``(feature, threshold, left_data, right_data, weighted_gini)``.
            When no split satisfies ``min_samples_leaf`` the first four items
            are ``None`` and the last is ``float('inf')``.
        """
        best_gini = float('inf')
        best_feature = None
        best_threshold = None
        best_left_data = None
        best_right_data = None

        total_samples = len(data)
        feature_cols = [col for col in data.columns if col != target_col]

        for feature in feature_cols:
            column = data[feature]
            unique_values = sorted(column.unique())

            # Candidate thresholds: midpoints between consecutive distinct values.
            for lower, upper in zip(unique_values, unique_values[1:]):
                threshold = (lower + upper) / 2

                left_data = data[column <= threshold]
                right_data = data[column > threshold]

                # Skip splits that would create an undersized child.
                if (len(left_data) < self.min_samples_leaf
                        or len(right_data) < self.min_samples_leaf):
                    continue

                left_weight = len(left_data) / total_samples
                right_weight = len(right_data) / total_samples
                weighted_gini = (left_weight * self.gini_impurity(left_data[target_col]) +
                                 right_weight * self.gini_impurity(right_data[target_col]))

                # Strict '<' keeps the first split found when several tie.
                if weighted_gini < best_gini:
                    best_gini = weighted_gini
                    best_feature = feature
                    best_threshold = threshold
                    best_left_data = left_data
                    best_right_data = right_data

        return best_feature, best_threshold, best_left_data, best_right_data, best_gini

    def build_tree(self, data, target_col, depth=0):
        """Recursively grow the tree for ``data``.

        Args:
            data: DataFrame with the feature columns and ``target_col``.
            target_col: Name of the label column.
            depth: Depth of the node being built (root is 0).

        Returns:
            A node dict for a split, or a bare class label for a leaf.
        """
        labels = data[target_col]
        target_values = labels.unique()

        # Pure node: nothing left to separate.
        if len(target_values) == 1:
            return target_values[0]

        # Stopping criteria reached: fall back to the majority class.
        if (depth >= self.max_depth or
                len(data) < self.min_samples_split or
                len(data) < 2 * self.min_samples_leaf):
            return labels.mode()[0]

        best_feature, best_threshold, left_data, right_data, best_gini = \
            self.find_best_split(data, target_col)

        # No admissible split (e.g. all feature values identical).
        if best_feature is None:
            return labels.mode()[0]

        tree = {
            'feature': best_feature,
            'threshold': best_threshold,
            'left': None,
            'right': None,
            'gini': best_gini,
            'samples': len(data),
            # Plain ints keep the printed tree free of numpy scalar reprs.
            'class_distribution': {
                label: int(count) for label, count in labels.value_counts().items()
            },
        }

        tree['left'] = self.build_tree(left_data, target_col, depth + 1)
        tree['right'] = self.build_tree(right_data, target_col, depth + 1)

        return tree

    def fit(self, X, y):
        """Train the tree on features ``X`` and labels ``y``.

        Args:
            X: DataFrame of numeric feature columns.
            y: Labels assigned as a new column of a copy of ``X`` (a pandas
                Series is therefore aligned on its index).

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If ``X`` contains no rows.
        """
        if len(X) == 0:
            raise ValueError("Cannot fit a decision tree on an empty dataset")

        data = X.copy()

        # Pick a label-column name that cannot overwrite a real feature.
        target_col = 'target'
        while target_col in data.columns:
            target_col = '_' + target_col
        data[target_col] = y

        self.feature_names = list(X.columns)
        self.tree = self.build_tree(data, target_col)
        return self

    def predict_single(self, sample, tree=None):
        """Predict the class of one sample.

        Args:
            sample: Mapping/Series indexable by feature name.
            tree: Optional subtree to start from; defaults to the fitted tree.

        Returns:
            The class label stored in the leaf that ``sample`` reaches.

        Raises:
            ValueError: If no subtree is given and the model is not fitted.
        """
        node = self._fitted_tree() if tree is None else tree

        # Anything that is not a split dict is a leaf label.
        while isinstance(node, dict):
            if sample[node['feature']] <= node['threshold']:
                node = node['left']
            else:
                node = node['right']
        return node

    def predict(self, X):
        """Return a list with the predicted class for every row of ``X``."""
        root = self._fitted_tree()
        return [self.predict_single(sample, root) for _, sample in X.iterrows()]

    def print_tree(self, tree=None, indent="", side=""):
        """Print the tree structure, one node per paragraph.

        Args:
            tree: Subtree to print; defaults to the fitted tree.
            indent: Leading whitespace for the current level.
            side: Prefix naming the branch ("Left " / "Right ").

        Raises:
            ValueError: If no subtree is given and the model is not fitted.
        """
        if tree is None:
            tree = self._fitted_tree()

        if not isinstance(tree, dict):
            print(f"{indent}{side}-> Class: {tree}")
            return

        print(f"{indent}{side}Feature: {tree['feature']} <= {tree['threshold']:.3f}")
        print(f"{indent}  Gini: {tree['gini']:.3f}, Samples: {tree['samples']}")
        print(f"{indent}  Class distribution: {tree['class_distribution']}")

        if tree['left'] is not None:
            self.print_tree(tree['left'], indent + "  ", "Left ")
        if tree['right'] is not None:
            self.print_tree(tree['right'], indent + "  ", "Right ")

print("CART Decision Tree class implemented successfully!")

In [None]:
# Hold out 30% of the rows for testing, stratified so both parts keep the class balance
y = iris_df['species']
X = iris_df.drop(columns='species')

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.3,
    stratify=y,
    random_state=42,
)

print(f"Training set shape: {X_train.shape}")
print(f"Test set shape: {X_test.shape}")
print(f"\nFeatures: {list(X.columns)}")

# Report the class counts of each partition under its own heading
for heading, partition_labels in (
    ("\nTraining set target distribution:", y_train),
    ("\nTest set target distribution:", y_test),
):
    print(heading)
    print(partition_labels.value_counts())

In [None]:
# Fit the from-scratch CART model on the training split, then display its structure
print("Training CART Decision Tree...")

cart_tree = CARTDecisionTree(min_samples_leaf=1, min_samples_split=2, max_depth=5)
cart_tree.fit(X_train, y_train)

print("Training completed!")

structure_rule = "=" * 60
print("\n" + structure_rule)
print("CART DECISION TREE STRUCTURE:")
print(structure_rule)
cart_tree.print_tree()

In [None]:
# Evaluate the fitted tree on the held-out test split
print("Making predictions on test set...")

y_pred = cart_tree.predict(X_test)

# Overall accuracy
accuracy = accuracy_score(y_test, y_pred)
print(f"\nTest Accuracy: {accuracy:.4f}")

section_rule = "=" * 50

# Per-class precision / recall / F1
print("\n" + section_rule)
print("CLASSIFICATION REPORT:")
print(section_rule)
print(classification_report(y_test, y_pred))

# Walk through the first few test rows one by one
print("\n" + section_rule)
print("SAMPLE PREDICTIONS:")
print(section_rule)

# (printed caption, DataFrame column) for every measurement shown per sample
measurement_rows = (
    ("Sepal Length: ", 'sepal length (cm)'),
    ("Sepal Width:  ", 'sepal width (cm)'),
    ("Petal Length: ", 'petal length (cm)'),
    ("Petal Width:  ", 'petal width (cm)'),
)

for i in range(min(8, len(X_test))):
    sample, actual, predicted = X_test.iloc[i], y_test.iloc[i], y_pred[i]

    print(f"Sample {i+1}:")
    for caption, column in measurement_rows:
        print(f"  {caption}{sample[column]:.2f} cm")
    print(f"  Actual: {actual}, Predicted: {predicted}")
    verdict = '✓ Correct' if actual == predicted else '✗ Wrong'
    print(f"  {verdict}")
    print()

In [None]:
# Build the 2x2 results dashboard
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(15, 12))

# 1. Confusion matrix in the top-left panel
cm = confusion_matrix(y_test, y_pred)
class_labels = np.unique(y_test)
cm_ax = axes[0, 0]
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=class_labels,
    yticklabels=class_labels,
    ax=cm_ax,
)
cm_ax.set_title('Confusion Matrix')
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('Actual')

# 2. Feature importance (based on splits in tree)
def calculate_feature_importance(tree):
    """Compute normalised Gini importance for every feature used in ``tree``.

    Each split contributes ``(samples / root_samples) * (node impurity -
    weighted child impurity)`` to its feature. The node impurity is derived
    from the node's ``class_distribution`` and the weighted child impurity is
    the value stored under ``gini``.

    Args:
        tree: Root of a tree whose internal nodes are dicts with the keys
            ``feature``, ``left``, ``right``, ``gini``, ``samples`` and
            ``class_distribution``; anything that is not a dict is a leaf.

    Returns:
        Dict mapping feature name to importance. Values sum to 1 when at
        least one split reduces impurity; they are all 0.0 otherwise. A tree
        that is a single leaf yields an empty dict.
    """
    importance = {}
    if not isinstance(tree, dict):
        return importance

    # Guard against a malformed root reporting zero samples.
    root_samples = tree['samples'] or 1

    def node_impurity(node):
        """Gini impurity of the samples that reached ``node``."""
        counts = list(node['class_distribution'].values())
        total = sum(counts)
        if total == 0:
            return 0.0
        return 1.0 - sum((count / total) ** 2 for count in counts)

    def traverse_tree(node):
        if not isinstance(node, dict):
            return

        # Clamp at zero so float rounding can never yield a negative gain.
        decrease = max(float(node_impurity(node) - node['gini']), 0.0)
        weight = node['samples'] / root_samples
        feature = node['feature']
        importance[feature] = importance.get(feature, 0.0) + weight * decrease

        traverse_tree(node['left'])
        traverse_tree(node['right'])

    traverse_tree(tree)

    # Normalise only when there is something to divide by.
    total = sum(importance.values())
    if total > 0:
        for feature in importance:
            importance[feature] /= total

    return importance

feature_importance = calculate_feature_importance(cart_tree.tree)

# Top-right panel: bar chart of importances, or a placeholder when none exist
importance_ax = axes[0, 1]
if not feature_importance:
    importance_ax.text(
        0.5,
        0.5,
        'No feature importance\ncalculated',
        ha='center',
        va='center',
        transform=importance_ax.transAxes,
    )
    importance_ax.set_title('Feature Importance')
else:
    features = list(feature_importance)
    importance_values = [feature_importance[feature_name] for feature_name in features]

    importance_ax.bar(features, importance_values, color='skyblue')
    importance_ax.set_title('Feature Importance (CART)')
    importance_ax.set_xlabel('Features')
    importance_ax.set_ylabel('Importance')
    importance_ax.tick_params(axis='x', rotation=45)

# 3. Decision boundary visualization (for 2 most important features)
# Bottom-left panel: refits a shallow tree on two features only and shades the
# plane by its predictions. Needs at least two features in feature_importance.
if len(feature_importance) >= 2:
    # Get two most important features (highest importance first)
    sorted_features = sorted(feature_importance.items(), key=lambda x: x[1], reverse=True)
    feature1, feature2 = sorted_features[0][0], sorted_features[1][0]
    
    # Create a subset with just these two features for visualization
    # NOTE(review): this uses the full iris_df (train + test rows), so the
    # boundary is illustrative and not the one learned by cart_tree.
    X_subset = iris_df[[feature1, feature2]]
    y_subset = iris_df['species']
    
    # Create a mesh grid with step h, padded by 1 unit around the data range
    # NOTE(review): every grid point is predicted row by row below, so this
    # step size and padding drive the runtime of the cell.
    h = 0.02
    x_min, x_max = X_subset.iloc[:, 0].min() - 1, X_subset.iloc[:, 0].max() + 1
    y_min, y_max = X_subset.iloc[:, 1].min() - 1, X_subset.iloc[:, 1].max() + 1
    xx, yy = np.meshgrid(np.arange(x_min, x_max, h),
                         np.arange(y_min, y_max, h))
    
    # Train a simple CART tree (depth limited to 3) on just these two features
    simple_cart = CARTDecisionTree(max_depth=3)
    simple_cart.fit(X_subset, y_subset)
    
    # Make predictions on the mesh grid; column names must match the training frame
    mesh_points = pd.DataFrame(np.c_[xx.ravel(), yy.ravel()], columns=[feature1, feature2])
    Z = simple_cart.predict(mesh_points)
    
    # Convert string labels to numeric for plotting
    # The mapping is hard-coded to the three iris species names.
    label_to_num = {'setosa': 0, 'versicolor': 1, 'virginica': 2}
    Z_numeric = [label_to_num[label] for label in Z]
    Z_numeric = np.array(Z_numeric).reshape(xx.shape)
    
    # Shade each region by its predicted class
    axes[1,0].contourf(xx, yy, Z_numeric, alpha=0.3, cmap=plt.cm.RdYlBu)
    
    # Plot the data points, one scatter call per species so each gets a legend entry
    colors = ['red', 'green', 'blue']
    for i, species in enumerate(['setosa', 'versicolor', 'virginica']):
        mask = y_subset == species
        axes[1,0].scatter(X_subset[mask].iloc[:, 0], X_subset[mask].iloc[:, 1], 
                         c=colors[i], label=species, alpha=0.7)
    
    axes[1,0].set_xlabel(feature1)
    axes[1,0].set_ylabel(feature2)
    axes[1,0].set_title('Decision Boundary (Top 2 Features)')
    axes[1,0].legend()
else:
    # Fewer than two features were used by the tree: show a placeholder message
    axes[1,0].text(0.5, 0.5, 'Not enough features\nfor boundary plot', 
                   ha='center', va='center', transform=axes[1,0].transAxes)

# 4. Tree depth analysis
def get_tree_depth(tree):
    """Return the number of split levels on the longest root-to-leaf path.

    Args:
        tree: A node dict with ``left`` and ``right`` children, or a leaf.
            Anything that is not a dict (a class label of any type, or
            ``None``) is treated as a leaf.

    Returns:
        0 for a leaf, otherwise 1 plus the depth of the deeper child.
    """
    # Type check instead of truthiness, so labels such as 0 or "" are
    # handled like any other leaf.
    if not isinstance(tree, dict):
        return 0

    return 1 + max(get_tree_depth(tree['left']), get_tree_depth(tree['right']))

def count_nodes(tree):
    """Count every node (splits and leaves) in the tree.

    Args:
        tree: A node dict with ``left`` and ``right`` children, a leaf label
            of any non-dict type, or ``None`` for a missing child.

    Returns:
        0 for ``None``, 1 for a leaf, otherwise 1 plus the counts of both
        subtrees.
    """
    if tree is None:
        return 0

    # Any non-dict value is a leaf; falsy labels such as 0 or "" still count.
    if not isinstance(tree, dict):
        return 1

    return 1 + count_nodes(tree['left']) + count_nodes(tree['right'])

tree_depth = get_tree_depth(cart_tree.tree)
total_nodes = count_nodes(cart_tree.tree)

# Leaf count is derived from the total node count rather than counted directly
leaf_nodes = total_nodes - (total_nodes // 2)

# Text shown in the bottom-right panel
stats_text = f"""Tree Statistics:
• Max Depth: {tree_depth}
• Total Nodes: {total_nodes}
• Leaf Nodes: {leaf_nodes}
• Test Accuracy: {accuracy:.4f}

CART vs ID3:
• Uses Gini Impurity
• Handles continuous features
• Creates binary splits
• More flexible than ID3"""

stats_ax = axes[1, 1]
stats_box = dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7)
stats_ax.text(
    0.1,
    0.9,
    stats_text,
    transform=stats_ax.transAxes,
    verticalalignment='top',
    fontsize=11,
    fontfamily='monospace',
    bbox=stats_box,
)
stats_ax.set_title('Tree Analysis')
stats_ax.axis('off')

plt.tight_layout()
plt.show()

# Collect the closing summary first, then print it line by line
summary_lines = [
    "\nFinal Results Summary:",
    '=' * 40,
    "Algorithm: CART Decision Tree",
    "Dataset: Iris (continuous features)",
    f"Test Accuracy: {accuracy:.4f}",
    f"Tree Depth: {tree_depth}",
    f"Total Nodes: {total_nodes}",
    f"Total samples: {len(iris_df)}",
    f"Training samples: {len(X_train)}",
    f"Test samples: {len(X_test)}",
    f"Features used: {len(X.columns)}",
    f"Classes: {len(np.unique(y))}",
]

# The top feature is only reported when the tree made at least one split
if feature_importance:
    top_feature = max(feature_importance, key=feature_importance.get)
    summary_lines.append(f"Most important feature: {top_feature}")

summary_lines.append("✓ CART implementation completed successfully!")

for summary_line in summary_lines:
    print(summary_line)