In [1]:
# Exercise 3: Minimum Number of Objects (MNO) Pruning

import math
import numpy as np
import copy

N = 2

labels = [1,1,-1,1,-1,-1,1,1,1,1,1,1,-1,-1,-1]
data_set = [[1,1,2,2],[2,1,2,2],[1,1,1,2],[1,2,1,2],[2,3,2,2],
            [2,2,1,2],[3,2,2,1],[1,3,2,2],[3,3,2,1],[2,3,1,2],
            [3,1,1,1],[1,2,1,1],[2,3,1,1],[2,1,1,2],[2,2,1,1]]

class BinaryLeaf:
    def __init__(self, elements, labels, ids):
        self.L = None
        self.R = None
        self.elements = elements
        self.split_feature = None
        self.split_value = None
        self.labels = labels
        self.completed = False
        self.ids = ids
        self.validated = False

    def set_R(self, Rleaf): self.R = Rleaf
    def set_L(self, Lleaf): self.L = Lleaf
    def get_L(self): return self.L
    def get_R(self): return self.R
    def set_completed(self): self.completed = True
    def is_completed(self): return self.completed
    def get_elements(self): return self.elements
    def get_labels(self): return self.labels
    def get_ids(self): return self.ids
    def set_ids(self, ids): self.ids = ids
    def set_split(self, feature): self.split_feature = feature
    def get_split(self): return self.split_feature
    def set_split_value(self, value): self.split_value = value
    def get_split_value(self): return self.split_value

labels_count = len(np.unique(labels))
ids = list(range(len(data_set)))
root = BinaryLeaf(data_set, labels, ids)
current_node = root

def get_unique_labels(labels):
    return np.unique(np.array(labels)).tolist()

def get_unique_values(elements):
    features_number = len(elements[0])
    unique = []
    for i in range(features_number):
        features_list = [elements[j][i] for j in range(len(elements))]
        unique.append(np.unique(np.array(features_list)))
    return unique

def is_leaf_completed(node):
    if node.is_completed():
        if node.get_L() and not node.get_L().is_completed(): return node.get_L()
        if node.get_R() and not node.get_R().is_completed(): return node.get_R()
        if node.get_L() is None and node.get_R() is None: return None
        if node.get_L().is_completed() or node.get_R().is_completed():
            new_node = is_leaf_completed(node.get_L())
            return new_node if new_node else is_leaf_completed(node.get_R())
        return None
    return node

def find_leaf_not_completed(root):
    return is_leaf_completed(root)

def get_split_candidates(unique_values):
    split_list = []
    for i in range(len(unique_values)):
        current_list = [copy.deepcopy(unique_values[i])]
        temp_list = copy.deepcopy(unique_values)
        del temp_list[i]
        current_list.append(temp_list)
        split_list.append(current_list)
    return split_list

def get_node_elements_column(column_id):
    return np.array(current_node.elements)[..., column_id].tolist()

def count_number_of_elements(elements, column_id):
    column_elements = get_node_elements_column(column_id)
    if isinstance(elements, list):
        return sum(column_elements.count(e) for e in elements)
    return column_elements.count(elements)

def get_number_of_labels_for_value(elements, column_id, label):
    column_elements = get_node_elements_column(column_id)
    if not isinstance(elements, list): elements = [elements]
    return sum(1 for e in elements for j, v in enumerate(column_elements) if v == e and current_node.labels[j] == label)

def calculate_omega(elements, column_id):
    t_l = count_number_of_elements(elements[0], column_id)
    t_r = count_number_of_elements(elements[1], column_id)
    p_l = t_l / len(current_node.elements)
    p_r = t_r / len(current_node.elements)
    sum_p = 0
    labels = get_unique_labels(current_node.labels)
    for lbl in labels:
        p_class_l = get_number_of_labels_for_value(elements[0], column_id, lbl) / (t_l or 1)
        p_class_r = get_number_of_labels_for_value(elements[1], column_id, lbl) / (t_r or 1)
        sum_p += abs(p_class_l - p_class_r)
    return 2.0 * p_l * p_r * sum_p

def check_completed(labels, elements):
    if len(get_unique_labels(labels)) == 1: return True
    elements = sorted(elements)
    return len([elements[i] for i in range(len(elements)) if i == 0 or elements[i] != elements[i - 1]]) == 1

def split_node(current_node, value, split_id, split_history):
    left_leaf, left_labels, left_ids = [], [], []
    right_leaf, right_labels, right_ids = [], [], []

    for i in range(len(current_node.elements)):
        if current_node.elements[i][split_id] == value:
            left_leaf.append(current_node.elements[i])
            left_labels.append(current_node.labels[i])
            left_ids.append(current_node.ids[i])
        else:
            right_leaf.append(current_node.elements[i])
            right_labels.append(current_node.labels[i])
            right_ids.append(current_node.ids[i])

    # --- MNO pruning condition ---
    if len(left_leaf) < N or len(right_leaf) < N:
        current_node.set_completed()
        return current_node, split_history

    split_history.append([str(current_node.ids), str(left_ids)])
    split_history.append([str(current_node.ids), str(right_ids)])
    current_node.set_L(BinaryLeaf(left_leaf, left_labels, left_ids))
    current_node.set_R(BinaryLeaf(right_leaf, right_labels, right_ids))
    current_node.set_split(split_id)
    current_node.set_split_value(value)
    current_node.set_completed()

    if check_completed(left_labels, left_leaf):
        current_node.get_L().set_completed()
    if check_completed(right_labels, right_leaf):
        current_node.get_R().set_completed()

    return current_node, split_history

def build(root_node):
    current_node = root_node
    stop_criterion = False
    split_history = []
    while not stop_criterion:
        unique_values = get_unique_values(current_node.get_elements())
        max_unique_id = 0
        max_split_id = 0
        max_value = 0
        for i in range(len(unique_values)):
            if len(unique_values[i]) == 1: continue
            split_candidates = get_split_candidates(unique_values[i].tolist())
            for j in range(len(split_candidates)):
                current_value = calculate_omega(split_candidates[j], i)
                if current_value > max_value:
                    max_value = current_value
                    max_unique_id = i
                    max_split_id = j
        current_node, split_history = split_node(current_node, unique_values[max_unique_id][max_split_id], max_unique_id, split_history)
        new_node = find_leaf_not_completed(root_node)
        current_node = new_node if new_node else current_node
        stop_criterion = new_node is None
    return root_node, split_history

# Test MNO pruning
dt_tree, split_history_cart = build(current_node)


In [2]:
dt_tree

<__main__.BinaryLeaf at 0x26860870830>

In [3]:
split_history_cart

[['[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]',
  '[1, 4, 5, 9, 12, 13, 14]'],
 ['[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14]',
  '[0, 2, 3, 6, 7, 8, 10, 11]'],
 ['[1, 4, 5, 9, 12, 13, 14]', '[5, 9, 12, 13, 14]'],
 ['[1, 4, 5, 9, 12, 13, 14]', '[1, 4]'],
 ['[0, 2, 3, 6, 7, 8, 10, 11]', '[2, 3, 10, 11]'],
 ['[0, 2, 3, 6, 7, 8, 10, 11]', '[0, 6, 7, 8]'],
 ['[5, 9, 12, 13, 14]', '[12, 14]'],
 ['[5, 9, 12, 13, 14]', '[5, 9, 13]'],
 ['[2, 3, 10, 11]', '[10, 11]'],
 ['[2, 3, 10, 11]', '[2, 3]']]