In [None]:
# Exercise 4: Plot OC1 Tree by Feature ID (Corrected)

import math
import numpy as np
import pydot
from random import randint, random
import copy

class BinaryLeaf:
    """A node of the binary OC1 tree together with the samples that reach it.

    ``elements``, ``labels`` and ``ids`` are parallel sequences: entry ``i``
    of each one describes the same sample.
    """

    def __init__(self, elements, labels, ids):
        # Child nodes; they stay None until the node is split.
        self.L = None
        self.R = None
        # Parent link; None unless assigned by the caller.
        self.p = None
        self.elements = elements      # feature vectors of the samples in this node
        self.labels = labels          # class label of each sample
        self.completed = False        # True once the node needs no further processing
        self.ids = ids                # identifier of each sample
        self.split_feature_id = None  # feature index chosen for this node's split
        self.split_value = None       # optional split value (see set_split_value)

    # ----- children -------------------------------------------------------
    def set_L(self, Lleaf):
        """Attach *Lleaf* as the left child."""
        self.L = Lleaf

    def set_R(self, Rleaf):
        """Attach *Rleaf* as the right child."""
        self.R = Rleaf

    def get_L(self):
        """Return the left child (None if absent)."""
        return self.L

    def get_R(self):
        """Return the right child (None if absent)."""
        return self.R

    # ----- completion flag -------------------------------------------------
    def set_completed(self):
        """Mark this node as finished."""
        self.completed = True

    def is_completed(self):
        """Return whether this node has been marked finished."""
        return self.completed

    # ----- samples ------------------------------------------------------------
    def get_elements(self):
        """Return the feature vectors stored in this node."""
        return self.elements

    def get_labels(self):
        """Return the class labels stored in this node."""
        return self.labels

    def get_ids(self):
        """Return the sample identifiers stored in this node."""
        return self.ids

    def set_ids(self, ids):
        """Replace the sample identifiers."""
        self.ids = ids

    # ----- split description ----------------------------------------------
    def set_split_feature(self, feature_id):
        """Record the feature index used for this node's split."""
        self.split_feature_id = feature_id

    def get_split_feature(self):
        """Return the recorded split feature index (None if unset)."""
        return self.split_feature_id

    def set_split_value(self, value):
        """Record the split value."""
        self.split_value = value

    def get_split_value(self):
        """Return the recorded split value (None if unset)."""
        return self.split_value

def compute_v(element, scv):
    """Evaluate the hyperplane *scv* at the point *element*.

    ``scv`` holds one weight per feature followed by the bias term, so the
    result is ``w . x + b`` with ``w = scv[:-1]`` and ``b = scv[-1]``.
    """
    point = np.array(element)
    coefficients = np.array(scv)
    weights, bias = coefficients[:-1], coefficients[-1]
    return np.dot(point, weights) + bias

def compare_two_leafs(leaf1, leaf2):
    """Compare the ``labels`` of two leaves with ``==`` and return the result."""
    first_labels = leaf1.labels
    second_labels = leaf2.labels
    return first_labels == second_labels

def is_leaf_completed(node):
    """Return the next node below *node* that still has to be processed.

    Search order: *node* itself if it is not completed; otherwise an
    incomplete direct child (left before right); otherwise recurse into the
    left subtree, then the right subtree.

    Args:
        node: a tree node exposing ``is_completed``, ``get_L`` and ``get_R``,
            or None.

    Returns:
        The first incomplete node found, or None when *node* is None or its
        whole subtree is completed.
    """
    # A missing child (e.g. only one child attached) has nothing to process;
    # without this guard the recursion below raised AttributeError on None.
    if node is None:
        return None
    if not node.is_completed():
        return node
    left, right = node.get_L(), node.get_R()
    # Prefer an incomplete direct child before descending further.
    if left is not None and not left.is_completed():
        return left
    if right is not None and not right.is_completed():
        return right
    found = is_leaf_completed(left)
    return found if found is not None else is_leaf_completed(right)

def calculate_gini(labels):
    """Return the Gini *purity* of ``labels``: the sum of squared class proportions.

    The value is 1 for a single-class input and decreases as classes mix.
    ``1 - calculate_gini(labels)`` is the usual Gini impurity. An empty input
    yields 0.
    """
    if len(labels) == 0:
        return 0
    _, class_counts = np.unique(np.array(labels), return_counts=True)
    proportions = class_counts / class_counts.sum()
    return (proportions ** 2).sum()

def get_all_possible_splits_by_gini(leaf):
    """Find, for every feature, the axis-parallel threshold with the lowest score.

    For feature ``i``, each value occurring in that column is tried as a
    threshold. Samples with ``x[i] <= t`` go left and the rest go right. The
    score of a split is ``1 - calculate_gini(left) - calculate_gini(right)``
    (lower is better).

    NOTE: the two sides are not weighted by their sample counts.

    Args:
        leaf: node whose ``elements`` (rows of equal length) and ``labels``
            are scored.

    Returns:
        A list of ``[feature_index, best_threshold, best_score]``, one entry
        per feature, in feature order. On ties, the threshold that appears
        first in the column wins.
    """
    # Loop-invariant conversions, done once instead of per feature/threshold.
    elements = np.array(leaf.elements)
    labels = np.array(leaf.labels)
    ginis = []
    for i in range(len(leaf.elements[0])):
        feature_column = elements[:, i]
        best_gini = float('inf')
        best_val = None
        for threshold in feature_column:
            distinguish = feature_column <= threshold
            gini = (1 - calculate_gini(labels[distinguish])
                    - calculate_gini(labels[~distinguish]))
            if gini < best_gini:  # strict: the earliest threshold wins ties
                best_val = threshold
                best_gini = gini
        ginis.append([i, best_val, best_gini])
    return ginis

def divide_data_hiperplane(leaf, scv):
    """Split the samples of *leaf* by the sign of the hyperplane *scv*.

    A sample goes "above" when ``compute_v(sample, scv) > 0`` and "below"
    otherwise, including exactly 0. Input order is preserved on both sides.

    Returns:
        ``(below, above, below_labels, above_labels, below_ids, above_ids)``.
    """
    # Each side collects (elements, labels, ids) as three parallel lists.
    lower = ([], [], [])
    upper = ([], [], [])
    for idx, point in enumerate(leaf.elements):
        side = upper if compute_v(point, scv) > 0 else lower
        side[0].append(point)
        side[1].append(leaf.labels[idx])
        side[2].append(leaf.ids[idx])
    return lower[0], upper[0], lower[1], upper[1], lower[2], upper[2]

def get_coefficiency(splits):
    """Turn the best univariate split into hyperplane coefficients.

    Args:
        splits: list of ``[feature_index, threshold, score]`` entries. The
            last entry's feature index plus one is taken as the number of
            features.

    Returns:
        ``(scv, feature_index)``. ``scv`` has weight 1 on the best-scoring
        (lowest score; first on ties) feature, 0 elsewhere, and bias
        ``-threshold``, so ``compute_v`` gives ``x[feature] - threshold``.
    """
    # One weight per feature plus a trailing bias slot.
    coefficients = np.zeros(splits[-1][0] + 2)
    scores = [entry[2] for entry in splits]
    best = splits[int(np.argmin(scores))]
    feature_index, threshold = best[0], best[1]
    coefficients[feature_index] = 1
    coefficients[-1] = -threshold
    return coefficients, feature_index

def compute_u(element, scv, feature):
    """Return the coefficient value that would put *element* on the hyperplane.

    Computes ``U = (scv[feature] * x[feature] - V) / x[feature]`` with
    ``V = compute_v(element, scv)``. For a non-zero ``x[feature]``,
    replacing ``scv[feature]`` by ``U`` makes the sample's hyperplane value
    zero. A zero feature value is replaced by 1e-9 in the denominator to avoid
    division by zero.
    """
    x_m = element[feature]
    denominator = 1e-9 if x_m == 0 else x_m
    numerator = scv[feature] * x_m - compute_v(element, scv)
    return numerator / denominator

def sort_u(u):
    """Return a new array holding the candidate values *u* in ascending order."""
    ascending = np.sort(u, axis=-1)
    return ascending

def perturb(leaf, scv, feature, old_gini):
    """Try to improve coefficient ``scv[feature]`` of the hyperplane (OC1 step).

    For every sample, ``compute_u`` gives a candidate value for the
    coefficient: the value at which that sample would lie on the hyperplane
    (for a non-zero feature value). Each candidate is scored with
    ``1 - calculate_gini(below) - calculate_gini(above)``, where lower is
    better.

    Args:
        leaf: node whose samples are re-partitioned for each candidate.
        scv: current coefficients (weights followed by the bias); it is not
            modified.
        feature: index of the coefficient to perturb.
        old_gini: score of the current hyperplane.

    Returns:
        ``(gini, scv)``. The best candidate and its score are returned if it
        strictly improves on ``old_gini``, or if it ties and a
        ``random() < 0.3`` draw succeeds. Otherwise ``(old_gini, scv)`` is
        returned unchanged. A leaf without samples yields no candidates and
        is returned unchanged.
    """
    candidates = sort_u(np.array([compute_u(el, scv, feature) for el in leaf.elements]))
    best_gini, best_scv = None, None
    for value in candidates:
        trial = np.copy(scv)
        trial[feature] = value
        _, _, below_labels, above_labels, _, _ = divide_data_hiperplane(leaf, trial)
        gini = 1 - calculate_gini(below_labels) - calculate_gini(above_labels)
        # Strict "<" keeps the earliest (smallest) candidate on ties, like np.argmin.
        if best_gini is None or gini < best_gini:
            best_gini, best_scv = gini, trial
    if best_gini is None:
        return old_gini, scv
    # Equal-score moves are accepted at random, presumably (as in OC1) to let
    # the search escape plateaus.
    if best_gini < old_gini or (best_gini == old_gini and random() < 0.3):
        return best_gini, best_scv
    return old_gini, scv

def build_level(root, split_history, perturbations=10):
    """Grow the OC1 tree below *root* until every node is completed.

    Each pass does the following:

    1. Take the next incomplete node found by ``is_leaf_completed``.
    2. Seed a hyperplane from the best axis-parallel split.
    3. Refine it with ``perturbations`` random single-coefficient
       perturbations.
    4. If the hyperplane actually separates the node's samples, attach the
       two halves as children. Otherwise the node stays a terminal leaf.

    Args:
        root: root ``BinaryLeaf``; the tree is built in place.
        split_history: list extended with ``[parent_feature, child_feature]``
            string pairs. A pair is appended whenever a non-root node is
            split, linking the feature its parent was seeded with to its own.
        perturbations: perturbation attempts per node (default 10, the value
            that used to be hard-coded).

    Returns:
        ``(root, split_history)``.
    """
    # A loop rather than one recursive call per processed node, so large
    # trees cannot exceed Python's recursion limit.
    while True:
        leaf = is_leaf_completed(root)
        if leaf is None:
            return root, split_history

        # A single-class (or empty) node has nothing to separate; without this
        # check a pure root could still be split into two pure halves.
        if len(np.unique(leaf.labels)) <= 1:
            leaf.set_completed()
            continue

        # Seed: best axis-parallel split expressed as hyperplane coefficients.
        # feature_id names that seed axis; perturbation may make the final
        # hyperplane oblique.
        splits = get_all_possible_splits_by_gini(leaf)
        scv, feature_id = get_coefficiency(splits)
        leaf.set_split_feature(feature_id)

        _, _, below_label, above_label, _, _ = divide_data_hiperplane(leaf, scv)
        gini = 1 - calculate_gini(below_label) - calculate_gini(above_label)

        # Refine by perturbing random weight coefficients (the bias slot,
        # the last entry of scv, is never chosen).
        n_features = len(leaf.elements[0])
        for _ in range(perturbations):
            feature = randint(0, n_features - 1)
            gini, scv = perturb(leaf, scv, feature, gini)

        # Partition once, with the final hyperplane.
        below, above, below_label, above_label, below_ids, above_ids = \
            divide_data_hiperplane(leaf, scv)
        left_leaf = BinaryLeaf(below, below_label, below_ids)
        right_leaf = BinaryLeaf(above, above_label, above_ids)
        # Placeholder: children inherit the parent's feature; it is
        # overwritten if the child is split later.
        left_leaf.set_split_feature(feature_id)
        right_leaf.set_split_feature(feature_id)

        leaf.set_completed()
        # Single-class children need no further splitting.
        if len(np.unique(below_label)) == 1:
            left_leaf.set_completed()
        if len(np.unique(above_label)) == 1:
            right_leaf.set_completed()

        # A child with the same labels as its parent received every sample,
        # so the hyperplane separated nothing: keep this node terminal.
        if compare_two_leafs(leaf, left_leaf) or compare_two_leafs(leaf, right_leaf):
            continue

        leaf.set_L(left_leaf)
        leaf.set_R(right_leaf)
        left_leaf.p = leaf
        right_leaf.p = leaf

        # The edge is recorded only now that this node's own feature is known.
        # Recording it when the parent split would pair the parent's feature
        # with the inherited placeholder, i.e. with itself.
        parent = getattr(leaf, 'p', None)
        if parent is not None:
            split_history.append([str(parent.get_split_feature()), str(feature_id)])

def build(root):
    """Grow the OC1 tree rooted at *root*, starting from an empty split history.

    Returns whatever ``build_level`` returns, i.e. ``(root, split_history)``.
    """
    return build_level(root, split_history=[])

def plot_tree(split_history, path='oc1_tree.png'):
    """Draw ``[parent, child]`` pairs as an undirected Graphviz graph and save it as PNG.

    Each endpoint becomes a node named ``"Feature <value>"``. Graphviz treats
    equal names as the same node, so a feature used by several splits is
    drawn once. Writing the PNG requires the Graphviz ``dot`` executable.

    Args:
        split_history: iterable of ``(parent, child)`` pairs, e.g. the second
            value returned by ``build``.
        path: output file name; defaults to ``'oc1_tree.png'``, the
            previously fixed name.
    """
    tree = pydot.Dot(graph_type='graph')
    for parent, child in split_history:
        # NOTE(review): in Graphviz, ``fillcolor`` on an edge only fills
        # arrowheads, which an undirected graph does not draw; ``color='red'``
        # may have been intended.
        edge = pydot.Edge(f"Feature {parent}", f"Feature {child}", fillcolor='red')
        tree.add_edge(edge)
    tree.write(path, format='png')

# Test setup: 15 samples with 4 integer features each and +1/-1 class labels.
data_set = [
    [1, 1, 2, 2],
    [2, 1, 2, 2],
    [1, 1, 1, 2],
    [1, 2, 1, 2],
    [2, 3, 2, 2],
    [2, 2, 1, 2],
    [3, 2, 2, 1],
    [1, 3, 2, 2],
    [3, 3, 2, 1],
    [2, 3, 1, 2],
    [3, 1, 1, 1],
    [1, 2, 1, 1],
    [2, 3, 1, 1],
    [2, 1, 1, 2],
    [2, 2, 1, 1],
]
labels = [1, 1, -1, 1, -1, -1, 1, 1, 1, 1, 1, 1, -1, -1, -1]  # one label per row above
ids = [*range(len(data_set))]  # sample ids 0..len(data_set)-1
root = BinaryLeaf(data_set, labels, ids)
oc1_tree, split_history = build(root)
plot_tree(split_history)