In [None]:
from collections import defaultdict, deque
from typing import Dict, List, Optional, Tuple

import numpy as np

class ImpurityCalculator:
    """Impurity measures over a discrete class-probability distribution.

    Each measure takes ``probs``, a list of class probabilities (expected to
    sum to 1). It returns a non-negative float, which is 0.0 both for a pure
    distribution and for an empty list.
    """

    @staticmethod
    def gini(probs: List[float]) -> float:
        """Return the Gini impurity ``1 - sum(p_i ** 2)``.

        An empty ``probs`` yields 0.0. Returning 1.0 there would rate a node
        with no classes as maximally impure, which disagrees with the other
        measures.
        """
        if not probs:
            return 0.0
        return 1.0 - sum(p ** 2 for p in probs)

    @staticmethod
    def entropy(probs: List[float], base: float = 2) -> float:
        """Return the Shannon entropy ``-sum(p_i * log_base(p_i))``.

        Zero probabilities contribute nothing (the limit of ``p log p``).

        Raises:
            ValueError: if ``base`` is not a valid logarithm base (<= 0 or 1).
        """
        if base <= 0 or base == 1:
            raise ValueError(f"Entropy base must be positive and != 1, got {base}")
        log_base = np.log(base)
        total = sum(-p * np.log(p) / log_base for p in probs if p > 0)
        # A pure distribution gives -(1 * log 1) == -0.0. The `+ 0.0`
        # normalizes it so reports print "0.000000" instead of "-0.000000".
        return float(total) + 0.0

    @staticmethod
    def misclassification(probs: List[float]) -> float:
        """Return the misclassification error ``1 - max(p_i)`` (0.0 if empty)."""
        return 1.0 - max(probs) if probs else 0.0

class TreeNode:
    """A decision-tree node described by its per-class sample counts.

    Attributes:
        class_counts: mapping of class label -> sample count, stored as given.
        depth: the node's level, assigned by the caller (``add_child`` does
            not change it).
        name: label used in reports.
        parent: parent node or None. Passing ``parent`` to the constructor
            only stores it; use ``add_child`` to link both directions.
        children: nodes registered through ``add_child``.
        total_samples: sum of ``class_counts`` values at construction time.
        probs: class proportions in ``class_counts`` iteration order. All 0.0
            when the node holds no samples.
    """

    def __init__(self, class_counts: Dict[int, int], depth: int = 0,
                 name: str = "root", parent=None):
        self.class_counts = class_counts
        self.depth = depth
        self.name = name
        self.parent = parent
        self.children: List["TreeNode"] = []
        self.total_samples = sum(class_counts.values())
        if self.total_samples:
            self.probs = [count / self.total_samples
                          for count in class_counts.values()]
        else:
            # An empty node has no defined distribution; avoid dividing by
            # zero (previously a crash for counts like {0: 0, 1: 0}).
            self.probs = [0.0 for _ in class_counts]

    def add_child(self, child_node):
        """Append ``child_node`` to ``children`` and set its ``parent`` to self."""
        self.children.append(child_node)
        child_node.parent = self

    def get_impurity(self, measure: str = 'gini', base: float = 2) -> float:
        """Return this node's impurity under ``measure``.

        Args:
            measure: 'gini', 'entropy' or 'misclassification'.
            base: logarithm base, used only for 'entropy'.

        Returns:
            The impurity value; 0.0 for a node with no samples.

        Raises:
            ValueError: if ``measure`` is not one of the supported names.
        """
        if measure not in ('gini', 'entropy', 'misclassification'):
            raise ValueError(f"Unknown measure: {measure}")
        if self.total_samples == 0:
            return 0.0
        if measure == 'gini':
            return ImpurityCalculator.gini(self.probs)
        if measure == 'entropy':
            return ImpurityCalculator.entropy(self.probs, base)
        return ImpurityCalculator.misclassification(self.probs)

class DecisionTreeAnalyzer:
    """Report the sample-weighted impurity of a tree, level by level.

    For a depth d, Q(d) = sum over nodes at d of
    (node samples / total samples at d) * node impurity. Node impurity uses
    ``measure`` ('gini', 'entropy' or 'misclassification') and, for entropy,
    the logarithm ``base``.

    Nodes are grouped by their own ``depth`` attribute and collected once in
    ``__init__``. Changes made to the tree afterwards are not reflected
    unless a new analyzer is built.
    """

    def __init__(self, root: TreeNode, measure: str = 'gini', base: float = 2):
        self.root = root
        self.measure = measure
        self.base = base
        self.nodes_by_depth = self._collect_nodes_by_depth()

    def _collect_nodes_by_depth(self) -> Dict[int, List[TreeNode]]:
        """Map each ``depth`` value to the nodes reachable from ``root``.

        Nodes are visited breadth-first; each list is in visiting order.
        """
        nodes = defaultdict(list)
        # deque.popleft() is O(1); list.pop(0) shifted the whole queue per node.
        queue = deque([self.root])
        while queue:
            node = queue.popleft()
            nodes[node.depth].append(node)
            queue.extend(node.children)
        return dict(nodes)

    def calculate_node_impurity(self, node: TreeNode) -> float:
        """Return ``node``'s impurity under this analyzer's measure and base."""
        return node.get_impurity(self.measure, self.base)

    def calculate_depth_quality(self, depth: int) -> Tuple[Optional[float], List[Dict]]:
        """Return ``(Q, details)`` for the nodes at ``depth``.

        ``Q`` is the sample-weighted impurity. ``details`` holds one dict per
        node with keys 'name', 'samples', 'class_counts', 'probabilities',
        'impurity', 'weight' and 'weighted_impurity'.

        Returns ``(None, [])`` when no node has that depth. If every node at
        the depth is empty, each weight is 0.0 and Q is 0.0 (previously a
        ZeroDivisionError).
        """
        if depth not in self.nodes_by_depth:
            return None, []
        nodes = self.nodes_by_depth[depth]
        total_samples = sum(node.total_samples for node in nodes)
        node_details = []
        weighted_impurity = 0.0
        for node in nodes:
            impurity = self.calculate_node_impurity(node)
            weight = node.total_samples / total_samples if total_samples else 0.0
            contribution = weight * impurity
            weighted_impurity += contribution
            node_details.append({
                'name': node.name,
                'samples': node.total_samples,
                'class_counts': node.class_counts,
                'probabilities': node.probs,
                'impurity': impurity,
                'weight': weight,
                'weighted_impurity': contribution,
            })
        return weighted_impurity, node_details

    def analyze_all_depths(self) -> Dict[int, Dict]:
        """Return ``{depth: {'Q': ..., 'nodes': [...]}}`` for depths 0..max.

        Every integer depth up to the deepest one is included. Depth is
        caller-assigned, so a level may hold no nodes; it then gets
        ``'Q': None`` and an empty node list.
        """
        results = {}
        max_depth = max(self.nodes_by_depth.keys())
        for depth in range(max_depth + 1):
            Q, details = self.calculate_depth_quality(depth)
            results[depth] = {
                'Q': Q,
                'nodes': details,
            }
        return results

    @staticmethod
    def _print_node_details(node: Dict) -> None:
        """Print one node-detail dict from ``calculate_depth_quality``."""
        print(f"  Node: {node['name']}")
        print(f"    Samples: {node['samples']}")
        print(f"    Class counts: {node['class_counts']}")
        print(f"    Probabilities: {[f'{p:.4f}' for p in node['probabilities']]}")
        print(f"    Impurity: {node['impurity']:.6f}")
        print(f"    Weight: {node['weight']:.6f}")
        print(f"    Weighted impurity: {node['weighted_impurity']:.6f}")
        print()

    def print_analysis(self):
        """Print the per-depth report produced by ``analyze_all_depths``."""
        results = self.analyze_all_depths()
        print(f"{'='*70}")
        print("Decision Tree Impurity Analysis")
        print(f"Measure: {self.measure.upper()}")
        if self.measure == 'entropy':
            print(f"Base: {self.base}")
        print(f"{'='*70}\n")
        for depth in sorted(results.keys()):
            data = results[depth]
            q_value = data['Q']
            # A depth gap yields Q=None; formatting None with :.6f raised TypeError.
            q_text = f"{q_value:.6f}" if q_value is not None else "N/A (no nodes at this depth)"
            print(f"DEPTH {depth}:")
            print(f"Overall Quality Q({depth}) = {q_text}\n")
            for node in data['nodes']:
                self._print_node_details(node)
            print(f"{'-'*70}\n")


if __name__ == "__main__":

    def _banner(title):
        """Print a blank line, then ``title`` framed by 70-character '=' rules."""
        rule = "=" * 70
        print("\n" + rule)
        print(title)
        print(rule)

    # Example tree: a 13-sample root split into a mixed "left" child
    # (7 samples) and a pure "right" child (6 samples, all class 1).
    root = TreeNode(class_counts={0: 8, 1: 5}, depth=0, name="root")
    left_child = TreeNode(class_counts={0: 5, 1: 2}, depth=1, name="left")
    right_child = TreeNode(class_counts={0: 0, 1: 6}, depth=1, name="right")
    for child in (left_child, right_child):
        root.add_child(child)

    # Optional third level; uncomment to extend the example. These counts
    # total 9 samples and introduce a class 2, so they do not partition the
    # left node's 7 samples. Adjust them before relying on the numbers.
    # left_left = TreeNode(class_counts={0: 5, 1: 0, 2: 0}, depth=2, name="left-left")
    # left_right = TreeNode(class_counts={0: 1, 1: 2, 2: 1}, depth=2, name="left-right")
    # left_child.add_child(left_left)
    # left_child.add_child(left_right)

    # One full report per impurity measure.
    _banner("GINI IMPURITY")
    analyzer_gini = DecisionTreeAnalyzer(root, measure='gini')
    analyzer_gini.print_analysis()

    _banner("ENTROPY (BASE 2)")
    analyzer_entropy = DecisionTreeAnalyzer(root, measure='entropy', base=2)
    analyzer_entropy.print_analysis()

    _banner("MISCLASSIFICATION ERROR")
    analyzer_misclass = DecisionTreeAnalyzer(root, measure='misclassification')
    analyzer_misclass.print_analysis()

    # Direct lookup of a single depth's weighted impurity.
    _banner("Quick queries:")
    Q1 = analyzer_gini.calculate_depth_quality(1)[0]
    print(f"Gini Q at depth 1: {Q1:.6f}")


GINI IMPURITY
Decision Tree Impurity Analysis
Measure: GINI

DEPTH 0:
Overall Quality Q(0) = 0.473373

  Node: root
    Samples: 13
    Class counts: {0: 8, 1: 5}
    Probabilities: ['0.6154', '0.3846']
    Impurity: 0.473373
    Weight: 1.000000
    Weighted impurity: 0.473373

----------------------------------------------------------------------

DEPTH 1:
Overall Quality Q(1) = 0.219780

  Node: left
    Samples: 7
    Class counts: {0: 5, 1: 2}
    Probabilities: ['0.7143', '0.2857']
    Impurity: 0.408163
    Weight: 0.538462
    Weighted impurity: 0.219780

  Node: right
    Samples: 6
    Class counts: {0: 0, 1: 6}
    Probabilities: ['0.0000', '1.0000']
    Impurity: 0.000000
    Weight: 0.461538
    Weighted impurity: 0.000000

----------------------------------------------------------------------


ENTROPY (BASE 2)
Decision Tree Impurity Analysis
Measure: ENTROPY
Base: 2

DEPTH 0:
Overall Quality Q(0) = 0.961237

  Node: root
    Samples: 13
    Class counts: {0: 8, 1: 5}
  