Реалізувати побудову дерева рішень за алгоритмом ID3. На вхід застосунку передається CSV-файл із навчальною вибіркою. Побудоване дерево виводиться на екран.

In [None]:
import pandas as pd
import math

# Load the training sample, keep only the two salary columns plus the class
# label column, and discard every row with a missing value in any of them.
df = (
    pd.read_csv('./data/salaries_by_college_major.csv')
    .loc[:, ['Starting Median Salary', 'Mid-Career Median Salary', 'Group']]
    .dropna()
)

In [44]:
# Predictor columns the tree may split on.
features = ['Starting Median Salary', 'Mid-Career Median Salary']

# Column holding the class label the tree predicts.
target_attribute = 'Group'

# Most frequent label in the whole dataset; handed to build_tree as the
# fallback answer for the root.
overall_majority = df[target_attribute].mode()[0]

In [45]:
def calculate_entropy(target_col):
    """Return the Shannon entropy (in bits) of the labels in *target_col*.

    Missing labels (NaN/None) are ignored: probabilities are computed over
    the labels that are actually present, so they always sum to 1.

    Args:
        target_col: pandas Series of class labels.

    Returns:
        ``-sum(p * log2(p))`` over the distinct labels. The result is 0 for
        an empty column, a column of only missing values, or a column with a
        single distinct label.
    """
    # value_counts() drops missing labels, so dividing by len(target_col)
    # would under-count the probabilities whenever NaN rows are present.
    # normalize=True divides by the number of labels actually counted.
    probabilities = target_col.value_counts(normalize=True)

    entropy = 0.0
    for probability in probabilities:
        # Skip zero-probability entries to avoid log2(0).
        if probability > 0:
            entropy -= probability * math.log2(probability)
    return entropy

def calculate_split(data, features, target_attribute):
    """Search every feature for the threshold split with the highest gain.

    Candidate thresholds for a feature are the midpoints between each pair
    of neighbouring values in its sorted list of distinct values. A candidate
    sends rows with ``value <= threshold`` to the left side and rows with
    ``value > threshold`` to the right side; candidates that leave either
    side empty are ignored.

    Args:
        data: pandas DataFrame holding the feature and label columns.
        features: iterable of column names to consider for splitting.
        target_attribute: name of the label column.

    Returns:
        A tuple ``(best_split, best_information_gain)`` where ``best_split``
        is ``(feature, threshold)``. When no usable candidate exists the
        result is ``(None, -1)``.
    """
    labels = data[target_attribute]
    total_rows = len(data)
    parent_entropy = calculate_entropy(labels)

    top_split = None
    top_gain = -1

    for feature in features:
        column = data[feature]
        ordered = sorted(column.unique())

        # A feature with fewer than two distinct values cannot separate rows.
        if len(ordered) < 2:
            continue

        thresholds = [(low + high) / 2 for low, high in zip(ordered, ordered[1:])]

        for threshold in thresholds:
            # Two separate comparisons: rows that satisfy neither (e.g. NaN)
            # end up on no side.
            below = labels[column <= threshold]
            above = labels[column > threshold]

            if not (len(below) and len(above)):
                continue

            children_entropy = (
                (len(below) / total_rows) * calculate_entropy(below)
            ) + (
                (len(above) / total_rows) * calculate_entropy(above)
            )
            gain = parent_entropy - children_entropy

            # Strict comparison keeps the first candidate among equal gains.
            if gain > top_gain:
                top_gain = gain
                top_split = (feature, threshold)

    return top_split, top_gain

def build_tree(data, features, target_attribute, parent_majority_class=None):
    """Recursively build a decision tree of binary threshold splits.

    Args:
        data: pandas DataFrame with the feature columns and the label column.
        features: column names that may be used for splitting.
        target_attribute: name of the label column.
        parent_majority_class: label to return when this node has no rows or
            no non-missing labels.

    Returns:
        Either a leaf (a class label) or a one-key dict of the form
        ``{"<feature> <= <threshold>": {True: subtree, False: subtree}}``,
        where ``True`` holds the rows satisfying the condition.

    Raises:
        KeyError: if ``target_attribute`` is not a column of a non-empty
            ``data`` (previously this was silently turned into
            ``parent_majority_class``, hiding a mistyped column name).
    """
    if data.empty:
        return parent_majority_class

    labels = data[target_attribute]

    # mode() ignores missing labels, so it is empty only when no real label
    # exists in this node; fall back to the parent's answer in that case.
    modes = labels.mode()
    if modes.empty:
        return parent_majority_class
    majority_class = modes.iloc[0]

    # nunique() also ignores missing labels. When exactly one real label is
    # present it equals the mode, so return that instead of the first row,
    # which may itself be a missing value.
    if labels.nunique() == 1:
        return majority_class

    best_split, best_information_gain = calculate_split(data, features, target_attribute)

    # No candidate split, or none that reduces entropy: stop with a leaf.
    if best_split is None or best_information_gain <= 0:
        return majority_class

    best_feature, best_split_value = best_split
    node_name = f"{best_feature} <= {best_split_value}"

    left_subset = data[data[best_feature] <= best_split_value]
    right_subset = data[data[best_feature] > best_split_value]

    return {
        node_name: {
            True: build_tree(left_subset, features, target_attribute, majority_class),
            False: build_tree(right_subset, features, target_attribute, majority_class),
        }
    }

def display_tree(tree, indent='', leaf_label='Group'):
    """Print a decision tree to standard output, one node per line.

    Args:
        tree: a leaf value, or a one-key dict mapping a condition string to
            a dict with ``True`` and ``False`` subtrees (the shape produced
            by ``build_tree``).
        indent: prefix written before every line of this subtree.
        leaf_label: caption printed before a leaf's value. Defaults to
            ``'Group'``, which reproduces the original output.
    """
    # Anything that is not a dict is a leaf holding the predicted class.
    if not isinstance(tree, dict):
        print(f"{indent}{leaf_label}: {tree}")
        return

    # An internal node is a dict with a single entry: condition -> branches.
    node, branches = next(iter(tree.items()))
    print(f"{indent}{node}")

    child_indent = indent + '  |    '
    for outcome, caption in ((True, 'True (<=)'), (False, 'False (>)')):
        print(f"{indent}  |-- {caption}:")
        display_tree(branches[outcome], child_indent, leaf_label)

In [46]:
# Train on the full dataset; the dataset-wide majority label is the fallback
# answer for the root node.
decision_tree = build_tree(
    data=df,
    features=features,
    target_attribute=target_attribute,
    parent_majority_class=overall_majority,
)

# Print a heading followed by the tree itself.
print("ID3")
display_tree(decision_tree)

ID3
Starting Median Salary <= 41400.0
  |-- True (<=):
  |    Starting Median Salary <= 37300.0
  |      |-- True (<=):
  |      |    Group: HASS
  |      |-- False (>):
  |      |    Mid-Career Median Salary <= 63650.0
  |      |      |-- True (<=):
  |      |      |    Starting Median Salary <= 39500.0
  |      |      |      |-- True (<=):
  |      |      |      |    Group: Business
  |      |      |      |-- False (>):
  |      |      |      |    Group: HASS
  |      |      |-- False (>):
  |      |      |    Mid-Career Median Salary <= 65150.0
  |      |      |      |-- True (<=):
  |      |      |      |    Starting Median Salary <= 38400.0
  |      |      |      |      |-- True (<=):
  |      |      |      |      |    Group: HASS
  |      |      |      |      |-- False (>):
  |      |      |      |      |    Group: STEM
  |      |      |      |-- False (>):
  |      |      |      |    Mid-Career Median Salary <= 78900.0
  |      |      |      |      |-- True (<=):
  |      |     