Árbol de decisión desde cero

In [None]:
import pandas as pd
import numpy as np
# Titanic training data (train.csv) hosted in the course repository.
TITANIC_TRAIN_URL = 'https://raw.githubusercontent.com/mrBronnWow/Curso_Beginners/main/1_Dataset_titanic/train.csv'
d = pd.read_csv(TITANIC_TRAIN_URL)

In [None]:
class Nodo:
  """
  Node of a binary decision tree for classification (Gini criterion).

  Each node keeps the samples that reach it (target ``Y`` and features ``X``),
  their Gini impurity and the majority class. ``grow_tree`` splits the node
  recursively on the (feature, threshold) pair with the largest Gini gain, and
  ``predict`` / ``predict_obs`` route new observations down the fitted tree.
  Samples go left when ``feature <= threshold`` and right otherwise, both when
  training and when predicting.
  """

  def __init__(self, Y: pd.Series, X: pd.DataFrame, min_samples_split=None,
               max_depth=None, depth=None, node_type=None, rule=None ):
    """
    Parameters
    ----------
    Y : target labels of the samples in this node.
    X : feature columns of the same samples; expected to share ``Y``'s index,
        because splits are applied to both through index-aligned boolean masks.
    min_samples_split : minimum number of samples needed to try a split (default 20).
    max_depth : maximum depth of the tree; 0 means "never split" (default 5).
    depth : depth of this node (root = 0).
    node_type : 'root', 'left_node' or 'right_node'.
    rule : human-readable rule that leads from the parent to this node.
    """
    self.Y = Y
    self.X = X

    # Compare with None rather than truthiness so an explicit 0 is honoured
    # (e.g. max_depth=0 must mean "no splits", not the default 5).
    self.min_samples_split = min_samples_split if min_samples_split is not None else 20
    self.max_depth = max_depth if max_depth is not None else 5
    self.depth = depth if depth is not None else 0
    self.node_type = node_type if node_type else 'root'
    self.rule = rule if rule else ""

    self.features = list(self.X.columns)
    self.counts = Y.value_counts()
    self.n = len(Y)
    self.gini_impurity = self.get_GINI()

    # Majority class. sorted() is stable, so on ties the class that appears
    # last in value_counts order wins.
    counts_sorted = sorted(self.counts.items(), key=lambda item: item[1])
    self.yhat = counts_sorted[-1][0] if counts_sorted else None

    # Filled in by grow_tree() when this node is split.
    self.left = None
    self.right = None
    self.best_feature = None
    self.best_value = None

  @staticmethod
  def cost_gini(Y: pd.Series) -> float:
    """Gini impurity ``1 - sum(p_k ** 2)`` of the labels in ``Y``; 0.0 for an empty Series."""
    if len(Y) == 0:
      return 0.0
    proportions = Y.value_counts() / len(Y)
    return float(1 - np.sum(proportions ** 2))

  @staticmethod
  def ma(x: np.ndarray, window: int) -> np.ndarray:
    """
    Moving average of ``x`` over ``window`` consecutive elements ('valid' positions only).

    Returns an empty array when ``x`` has fewer than ``window`` elements. Without
    this guard np.convolve would swap its arguments and return spurious values,
    or raise on an empty input.
    """
    x = np.asarray(x)
    if x.size < window:
      return np.empty(0)
    return np.convolve(x, np.ones(window), 'valid') / window

  def get_GINI(self):
    """Gini impurity of the labels stored in this node."""
    return self.cost_gini(self.Y)

  def GI(self, y: pd.Series, mask) -> float:
    """
    Gini gain of splitting ``y`` into ``y[mask]`` (left) and ``y[~mask]`` (right).

    ``mask`` is a boolean Series aligned with ``y`` (or a boolean ndarray of the
    same length). Returns 0.0 for an empty ``y``.
    """
    n_left = int(mask.sum())
    n_right = int((~mask).sum())
    total = n_left + n_right
    if total == 0:
      return 0.0
    children = (n_left * self.cost_gini(y[mask]) + n_right * self.cost_gini(y[~mask])) / total
    return self.cost_gini(y) - children

  def best_split(self) -> tuple:
    """
    Find the (feature, threshold) pair with the largest Gini gain.

    Candidate thresholds are the midpoints between consecutive distinct
    non-missing values of each feature. A sample goes left when
    ``feature <= threshold``, which is the same rule grow_tree() and
    predict_obs() apply. Otherwise it goes right; missing values also go right.

    Returns (None, None) when no split gives a strictly positive gain.
    """
    max_gain = 0
    best_feature = None
    best_value = None

    for feature in self.features:
      x = self.X[feature]
      candidates = self.ma(np.sort(x.dropna().unique()), 2)
      for value in candidates:
        gain = self.GI(y=self.Y, mask=x <= value)
        if gain > max_gain:
          best_feature, best_value, max_gain = feature, value, gain
    return (best_feature, best_value)

  def grow_tree(self):
    """
    Recursively split this node and its descendants.

    A node is split only while depth < max_depth, it holds at least
    min_samples_split samples, and best_split() finds a positive Gini gain.
    """
    if self.depth >= self.max_depth or self.n < self.min_samples_split:
      return

    best_feature, best_value = self.best_split()
    if best_feature is None:
      return

    self.best_feature, self.best_value = best_feature, best_value
    # Complementary masks: every sample, including missing values, lands in
    # exactly one child.
    left_mask = self.X[best_feature] <= best_value
    right_mask = ~left_mask

    # type(self) so that subclasses create children of their own class.
    child_cls = type(self)
    self.left = child_cls(self.Y[left_mask], self.X[left_mask], depth=self.depth + 1,
                          max_depth=self.max_depth, min_samples_split=self.min_samples_split,
                          node_type='left_node', rule=f"{best_feature} <= {round(best_value, 3)}")
    self.right = child_cls(self.Y[right_mask], self.X[right_mask], depth=self.depth + 1,
                           max_depth=self.max_depth, min_samples_split=self.min_samples_split,
                           node_type='right_node', rule=f"{best_feature} > {round(best_value, 3)}")
    self.left.grow_tree()
    self.right.grow_tree()

  def print_info(self, width=4):
    """
    Print this node's rule, Gini impurity, class distribution and predicted class.
    """
    # Indentation grows with the node's depth.
    const = int(self.depth*width**1.5)
    spaces = "-" * const
    if self.node_type == 'root': print("Root")
    else: print(f"|{spaces} Split regla: {self.rule}")
    print(f"{' ' * const}   | Impureza Gini del nodo: {round(self.gini_impurity, 2)}")
    print(f"{' ' * const}   | Distribución de clases en el nodo: {dict(self.counts)}")
    print(f"{' ' * const}   | Clase predicha: {self.yhat}")

  def print_tree(self):
    """
    Print the whole subtree rooted at this node (pre-order: node, left, right).
    """
    self.print_info()
    if self.left is not None: self.left.print_tree()
    if self.right is not None: self.right.print_tree()

  def predict(self, X: pd.DataFrame):
    """
    Predict a class for every row of ``X``; returns a list in row order.
    """
    return [self.predict_obs({feature: row[feature] for feature in self.features})
            for _, row in X.iterrows()]

  def predict_obs(self, values: dict) -> int:
    """
    Predict the class of one observation given as a {feature: value} dict.

    Walks down from this node, going left when ``value <= threshold`` (the rule
    grow_tree() used to partition the training data) and right otherwise, and
    stops at the first node that was not split. Returns that node's majority class.
    """
    cur_node = self
    # A node is internal only if grow_tree() recorded a split and built both children.
    while (cur_node.best_feature is not None
           and cur_node.left is not None and cur_node.right is not None):
      if values.get(cur_node.best_feature) <= cur_node.best_value:
        cur_node = cur_node.left
      else:
        cur_node = cur_node.right
    return cur_node.yhat

  @staticmethod
  def GI2(y, mask, func=cost_gini):
    """
    Impurity gain of splitting ``y`` by the boolean ``mask``.

    Numeric targets use variance reduction. Any other dtype uses ``func``
    (Gini by default) on the parent and on both children. Returns 0.0 when
    ``mask`` selects nothing on either side in total (empty input).
    """
    # In the class body ``cost_gini`` is still a staticmethod object, which is
    # not callable before Python 3.10; unwrap it to the plain function.
    func = getattr(func, '__func__', func)
    a, b = int(mask.sum()), int((~mask).sum())
    if a + b == 0:
      return 0.0
    if pd.api.types.is_numeric_dtype(y):
      return y.var() - (a*y[mask].var() + b*y[~mask].var())/(a+b)
    return func(y) - (a*func(y[mask]) + b*func(y[~mask]))/(a+b)

In [None]:
# Keep the target and two numeric features; drop rows with any missing value in them.
dtree = d[['Survived', 'Age', 'Fare']].dropna().copy()
Y = dtree['Survived']
X = dtree[['Age', 'Fare']]
# Stop splitting at depth 4 or when a node holds fewer than 50 samples.
hp = {'max_depth': 4, 'min_samples_split': 50}
root2 = Nodo(Y, X, **hp)  # tree node class defined in the previous cell
root2.grow_tree()
root2.print_tree()

Root
   | GINI impurity of the node: 0.48
   | Class distribution in the node: {0: 424, 1: 290}
   | Predicted class: 0
|-------- Split rule: Fare <= 52.277
           | GINI impurity of the node: 0.44
           | Class distribution in the node: {0: 389, 1: 195}
           | Predicted class: 0
|---------------- Split rule: Fare <= 10.481
                   | GINI impurity of the node: 0.32
                   | Class distribution in the node: {0: 192, 1: 47}
                   | Predicted class: 0
|------------------------ Split rule: Age <= 32.5
                           | GINI impurity of the node: 0.37
                           | Class distribution in the node: {0: 134, 1: 43}
                           | Predicted class: 0
|-------------------------------- Split rule: Age <= 16.5
                                   | GINI impurity of the node: 0.5
                                   | Class distribution in the node: {0: 7, 1: 7}
                                   | Predicted class: