Реалізувати алгоритм побудови дерева рішень за допомогою Байєсівського класифікатора. На вхід додатку передається CSV файл з навчальною вибіркою. Побудоване дерево виводиться на екран.

Порівняти результати з результатами побудови дерева за допомогою алгоритму ID3.

In [47]:
import pandas as pd
import math

# Load the training sample from CSV and keep only the two numeric salary
# columns (the features) plus the class label; rows with any missing value
# among these three columns are dropped.
df = pd.read_csv('./data/salaries_by_college_major.csv')[
    ['Starting Median Salary', 'Mid-Career Median Salary', 'Group']
].dropna()
# NOTE(review): not referenced elsewhere in the visible notebook; the code
# below uses `target_attribute` instead.
target = 'Group'

In [48]:
# Column holding the class label, and the numeric columns the tree may split on.
target_attribute = 'Group'
features = [
    'Starting Median Salary',
    'Mid-Career Median Salary',
]
# Most frequent class in the whole sample; passed to build_tree as the
# fallback label for the root call.
overall_majority = df[target_attribute].mode().iloc[0]

In [None]:
def calculate_entropy(series):
    """Return the Shannon entropy, in bits, of the values in *series*.

    Each distinct value's relative frequency ``p`` contributes
    ``-p * log2(p)``. An empty series yields ``0``.
    """
    n = len(series)
    shares = series.value_counts() / n
    # Only strictly positive frequencies contribute (log2(0) is undefined).
    positive_shares = [share for share in shares if share > 0]

    result = 0
    for share in positive_shares:
        result -= share * math.log2(share)
    return result

def calculate_bayes_error(series):
    """Return the misclassification rate of predicting the most frequent value.

    This is ``1 - max_k p_k``, where ``p_k`` is the relative frequency of
    value ``k`` in *series*. An empty series yields ``0``.
    """
    total = len(series)
    if not total:
        return 0
    return 1 - series.value_counts().max() / total


def calculate_split(data, features, target_attribute, criterion='id3'):
    """Find the best binary threshold split of *data* over numeric *features*.

    Candidate thresholds are the midpoints between consecutive distinct
    (sorted) values of each feature. Rows with ``feature <= threshold`` go to
    the left subset and rows with ``feature > threshold`` go to the right one.
    A split's gain is the parent impurity minus the size-weighted impurity of
    the two subsets. On ties, the first candidate encountered wins.

    Args:
        data: DataFrame containing the feature columns and the target column.
        features: Names of the numeric columns to try splitting on.
        target_attribute: Name of the class-label column.
        criterion: ``'id3'`` scores splits by entropy reduction (information
            gain). Any other value scores them by reduction of the
            misclassification (Bayes) error.

    Returns:
        ``(best_split, best_gain)``. ``best_split`` is a
        ``(feature, threshold)`` tuple, or ``None`` together with a gain of
        ``-1`` when no feature has at least two distinct values (for example,
        empty data or all-constant columns), so no split is possible.
    """
    # Any criterion other than 'id3' selects the Bayes (misclassification) error.
    impurity = calculate_entropy if criterion == 'id3' else calculate_bayes_error

    # Initialised up front so "no possible split" is reported as None instead
    # of raising UnboundLocalError at the return statement.
    best_split = None
    best_gain = -1

    parent_metric = impurity(data[target_attribute])
    n_rows = len(data)

    for feature in features:
        unique_values = sorted(data[feature].unique())
        if len(unique_values) < 2:
            continue  # a constant column cannot separate any rows

        column = data[feature]
        for low, high in zip(unique_values, unique_values[1:]):
            # Midpoint between neighbouring distinct values.
            split_value = (low + high) / 2
            left_subset = data[column <= split_value]
            right_subset = data[column > split_value]

            # Defensive: a midpoint of distinct values normally leaves both
            # sides non-empty.
            if len(left_subset) == 0 or len(right_subset) == 0:
                continue

            prop_left = len(left_subset) / n_rows
            prop_right = len(right_subset) / n_rows
            metric_left = impurity(left_subset[target_attribute])
            metric_right = impurity(right_subset[target_attribute])

            weighted_metric = (prop_left * metric_left) + (prop_right * metric_right)
            gain = parent_metric - weighted_metric

            if gain > best_gain:
                best_gain = gain
                best_split = (feature, split_value)

    return best_split, best_gain

def build_tree(data, features, target_attribute, criterion='id3', parent_majority_class=None):
    """Recursively grow a binary decision tree on numeric features.

    An internal node is a one-key dict
    ``{"<feature> <= <threshold>": {True: left, False: right}}``, where
    ``True`` holds the subtree for rows with ``feature <= threshold`` and
    ``False`` the subtree for rows with ``feature > threshold``. A leaf is a
    plain class label.

    Args:
        data: DataFrame with the feature columns and the target column.
        features: Numeric column names that splits may use.
        target_attribute: Name of the class-label column.
        criterion: ``'id3'`` for information gain. Any other value is
            forwarded to ``calculate_split``, which then uses the Bayes
            (misclassification) error.
        parent_majority_class: Label returned when *data* is empty or its
            majority label cannot be determined.

    Returns:
        A nested dict tree, or a single class label for a leaf.
    """
    # Nothing reached this branch: inherit the parent's majority label.
    if data.empty:
        return parent_majority_class

    try:
        labels = data[target_attribute]
        majority_class = labels.mode()[0]
    except (KeyError, IndexError):
        # Missing target column, or no non-NaN labels so mode() is empty.
        return parent_majority_class

    # Pure node: every row already carries the same label.
    if labels.nunique() == 1:
        return labels.iloc[0]

    split, gain = calculate_split(data, features, target_attribute, criterion)
    # Stop when no split exists or no split reduces impurity.
    if split is None or gain <= 0:
        return majority_class

    feature, threshold = split
    column = data[feature]
    children = {
        True: build_tree(data[column <= threshold], features, target_attribute,
                         criterion, majority_class),
        False: build_tree(data[column > threshold], features, target_attribute,
                          criterion, majority_class),
    }
    return {f"{feature} <= {threshold}": children}

def display_tree(tree, indent=''):
    """Print *tree* (as produced by ``build_tree``) as indented text.

    A leaf prints as ``Group: <label>``. An internal node prints its
    condition, then its ``True`` (<=) branch, then its ``False`` (>) branch,
    each nested one level deeper.
    """
    if isinstance(tree, dict):
        node = next(iter(tree))
        branches = tree[node]
        print(f"{indent}{node}")
        child_indent = indent + ' |    '
        for outcome, label in ((True, '<='), (False, '>')):
            print(f"{indent} | -- {outcome} ({label}):")
            display_tree(branches[outcome], child_indent)
    else:
        print(f"{indent}Group: {tree}")

In [50]:
# Grow the tree with the entropy / information-gain (ID3) criterion and print it.
print("ID3")
tree_id3 = build_tree(
    df,
    features,
    target_attribute,
    criterion='id3',
    parent_majority_class=overall_majority,
)
display_tree(tree_id3)

ID3
Starting Median Salary <= 41400.0
 | -- True (<=):
 |    Starting Median Salary <= 37300.0
 |     | -- True (<=):
 |     |    Group: HASS
 |     | -- False (>):
 |     |    Mid-Career Median Salary <= 63650.0
 |     |     | -- True (<=):
 |     |     |    Starting Median Salary <= 39500.0
 |     |     |     | -- True (<=):
 |     |     |     |    Group: Business
 |     |     |     | -- False (>):
 |     |     |     |    Group: HASS
 |     |     | -- False (>):
 |     |     |    Mid-Career Median Salary <= 65150.0
 |     |     |     | -- True (<=):
 |     |     |     |    Starting Median Salary <= 38400.0
 |     |     |     |     | -- True (<=):
 |     |     |     |     |    Group: HASS
 |     |     |     |     | -- False (>):
 |     |     |     |     |    Group: STEM
 |     |     |     | -- False (>):
 |     |     |     |    Mid-Career Median Salary <= 78900.0
 |     |     |     |     | -- True (<=):
 |     |     |     |     |    Group: HASS
 |     |     |     |     | -- False (>):

In [51]:
# Grow the tree with the Bayes (misclassification-error) criterion and print it.
print("Bayes")
tree_bayes = build_tree(
    df,
    features,
    target_attribute,
    criterion='bayes',
    parent_majority_class=overall_majority,
)
display_tree(tree_bayes)

Bayes
Starting Median Salary <= 41400.0
 | -- True (<=):
 |    Starting Median Salary <= 34050.0
 |     | -- True (<=):
 |     |    Group: HASS
 |     | -- False (>):
 |     |    Mid-Career Median Salary <= 64750.0
 |     |     | -- True (<=):
 |     |     |    Starting Median Salary <= 37300.0
 |     |     |     | -- True (<=):
 |     |     |     |    Group: HASS
 |     |     |     | -- False (>):
 |     |     |     |    Starting Median Salary <= 39500.0
 |     |     |     |     | -- True (<=):
 |     |     |     |     |    Mid-Career Median Salary <= 63650.0
 |     |     |     |     |     | -- True (<=):
 |     |     |     |     |     |    Group: Business
 |     |     |     |     |     | -- False (>):
 |     |     |     |     |     |    Group: HASS
 |     |     |     |     | -- False (>):
 |     |     |     |     |    Group: HASS
 |     |     | -- False (>):
 |     |     |    Mid-Career Median Salary <= 64850.0
 |     |     |     | -- True (<=):
 |     |     |     |    Group: STEM
 |