In [2]:
import csv
import math
from collections import Counter

class TreeNode:
    """One node of a binary decision tree.

    A node starts out with no split and no label; the tree builder later
    either turns it into an internal node (``set_properties`` plus the
    ``left``/``right`` children) or into a leaf (``set_label``).

    Attributes:
        ids: Sample indices associated with this node (empty list if none
            were given).
        gini: Gini impurity recorded for the node.
        depth: Depth of the node in the tree.
        feature_index: Index of the feature used by the split, or ``None``.
        threshold: Split threshold, or ``None``.
        left: Left child node, or ``None``.
        right: Right child node, or ``None``.
        label: Predicted label for a leaf, or ``None``.
    """

    def __init__(self, ids=None, children=None, gini=0, depth=0):
        # ``children`` is accepted for call compatibility but is not stored.
        self.ids = ids if ids else []
        self.gini, self.depth = gini, depth
        # Split description and children are filled in later by the builder.
        self.feature_index = self.threshold = None
        self.left = self.right = None
        self.label = None

    def set_properties(self, feature_index, threshold):
        """Record the feature index and threshold this node splits on."""
        self.feature_index, self.threshold = feature_index, threshold

    def set_label(self, label):
        """Record the label this node predicts."""
        self.label = label

def gini_impurity(y):
    """Return the Gini impurity of the labels in *y*.

    The result is ``1 - sum(p_c ** 2)`` over the distinct labels ``c``,
    where ``p_c`` is the fraction of entries of *y* equal to ``c``.
    An input with no labels yields ``1``.
    """
    class_counts = Counter(y)
    if not class_counts:
        return 1

    total = len(y)
    impurity = 1
    # Subtract each squared class probability in turn.
    for occurrences in class_counts.values():
        impurity -= (occurrences / total) ** 2
    return impurity

class CART:
    """Binary classification tree grown by greedy Gini-impurity reduction.

    Every internal node tests ``x[feature_index] <= threshold``; samples for
    which the test holds go to the left child, all others to the right one.
    The values of one feature must therefore be mutually orderable, because
    candidate thresholds are obtained with ``sorted`` and compared with
    ``<=``.

    Args:
        max_depth: Nodes at this depth (root is depth 0) become leaves.
        min_samples_split: Nodes with fewer samples than this become leaves.
        min_impurity_decrease: A node becomes a leaf when its own impurity,
            or the impurity reduction of its best split, does not exceed
            this value.

    Attributes:
        tree: Root ``TreeNode`` of the fitted tree, ``None`` before ``fit``.
        root: Alias of ``tree``.
        n_features: Number of features, taken from the first training row.
        n_classes: Number of distinct labels seen by ``fit``.
    """

    def __init__(self, max_depth=10, min_samples_split=2, min_impurity_decrease=1e-7):
        # ``root`` and ``tree`` always refer to the same node; both exist so
        # that either attribute name can be used to reach the fitted tree.
        self.root = None
        self.tree = None
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.min_impurity_decrease = min_impurity_decrease

    def fit(self, X, y):
        """Grow the tree from feature rows *X* and their labels *y*.

        Args:
            X: Non-empty sequence of feature rows (indexable sequences).
            y: Sequence of labels, one per row of *X*.

        Raises:
            IndexError: If *X* is empty or shorter than *y*.
        """
        self.n_features = len(X[0])
        self.n_classes = len(set(y))
        self.tree = self._grow_tree(X, y)
        self.root = self.tree

    @staticmethod
    def _majority_label(y):
        """Return the most frequent label; ties go to the one seen first."""
        # Counter keeps first-seen order, so the result does not depend on
        # set iteration order (which varies between runs for str labels).
        return Counter(y).most_common(1)[0][0]

    def _grow_tree(self, X, y, depth=0):
        """Recursively build and return the subtree for the samples (X, y)."""
        n_samples = len(y)
        node = TreeNode(ids=list(range(n_samples)), gini=gini_impurity(y), depth=depth)

        # Stop when deep enough, too small to split, or already (nearly) pure.
        if (depth >= self.max_depth
                or n_samples < self.min_samples_split
                or node.gini <= self.min_impurity_decrease):
            node.set_label(self._majority_label(y))
            return node

        best_gini = float('inf')
        best_feature = None
        best_threshold = None
        for feature in range(self.n_features):
            # Extract the column once instead of re-indexing X per threshold.
            column = [X[i][feature] for i in range(n_samples)]
            for threshold in sorted(set(column)):
                # Single pass partition of the labels (no list membership
                # tests), preserving the original sample order.
                left_labels = []
                right_labels = []
                for value, label in zip(column, y):
                    if value <= threshold:
                        left_labels.append(label)
                    else:
                        right_labels.append(label)

                # A split that leaves one side empty separates nothing.
                if not left_labels or not right_labels:
                    continue

                # Impurity of the split: size-weighted mean of both sides.
                gini = (len(left_labels) * gini_impurity(left_labels)
                        + len(right_labels) * gini_impurity(right_labels)) / n_samples

                # Strict comparison keeps the first best (feature, threshold).
                if gini < best_gini:
                    best_gini = gini
                    best_feature = feature
                    best_threshold = threshold

        # No usable split, or the best one does not reduce impurity enough.
        if best_feature is None or node.gini - best_gini <= self.min_impurity_decrease:
            node.set_label(self._majority_label(y))
            return node

        node.set_properties(best_feature, best_threshold)

        left_X, left_y, right_X, right_y = [], [], [], []
        for i in range(n_samples):
            if X[i][best_feature] <= best_threshold:
                left_X.append(X[i])
                left_y.append(y[i])
            else:
                right_X.append(X[i])
                right_y.append(y[i])

        node.left = self._grow_tree(left_X, left_y, depth + 1)
        node.right = self._grow_tree(right_X, right_y, depth + 1)

        return node

    def predict(self, X):
        """Return the list of predicted labels, one per row of *X*."""
        return [self._predict_one(x) for x in X]

    def _predict_one(self, x):
        """Walk from the root to a leaf for row *x* and return its label.

        Raises:
            AttributeError: If the tree has not been fitted yet.
        """
        node = self.tree
        if node is None:
            raise AttributeError("CART instance is not fitted yet; call fit() first")
        # Internal nodes always have both children; leaves have neither.
        while node.left is not None:
            if x[node.feature_index] <= node.threshold:
                node = node.left
            else:
                node = node.right
        return node.label

def load_csv(filename, encoding=None):
    """Read a CSV file and return all of its rows.

    Args:
        filename: Path of the CSV file to read.
        encoding: Text encoding passed to ``open``; ``None`` (the default)
            keeps the platform default encoding.

    Returns:
        A list of rows, each row being a list of strings. The header row,
        if the file has one, is included as the first element.
    """
    # newline='' lets the csv module do its own line-ending handling, so
    # line breaks embedded in quoted fields are preserved unchanged.
    with open(filename, 'r', newline='', encoding=encoding) as f:
        return list(csv.reader(f))

if __name__ == "__main__":
    # Read the data set from the CSV file.
    data = load_csv('weather.csv')
    headers = data[0]
    data = data[1:]  # Skip the header row

    # Split every row into features (all but the last column) and target
    # (the last column).
    X = [row[:-1] for row in data]
    y = [row[-1] for row in data]

    # Convert the data to numbers where possible, one whole column at a time:
    # a column becomes float only if every value in it parses. Converting
    # cell by cell could leave floats and strings mixed in one column, which
    # the tree cannot sort or compare.
    n_columns = len(X[0]) if X else 0
    for j in range(n_columns):
        try:
            numeric = [float(row[j]) for row in X]
        except ValueError:
            continue  # Not fully numeric: keep the original string values
        for row, value in zip(X, numeric):
            row[j] = value

    # Fit on the full data set and print the predictions for the same rows.
    cart = CART(max_depth=3, min_samples_split=2)
    cart.fit(X, y)
    predictions = cart.predict(X)
    print("Dự đoán:", predictions)

Dự đoán: ['no', 'no', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'yes', 'no']
