In [None]:
import unittest
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, lit
from decision_tree import DecisionTree, TreeNode
from random_forest import RandomForest
from feature_preprocessing import preprocess_features, equal_frequency_binning
from split_management import apply_split
from impurity_calculators import ImpurityCalculator


In [None]:
class DecisionTreeTests(unittest.TestCase):
    """Tests for ``DecisionTree`` and ``TreeNode`` sharing one local Spark session."""

    @classmethod
    def setUpClass(cls):
        """Create (or reuse) a ``local[2]`` Spark session for every test in the class."""
        session_builder = SparkSession.builder.master("local[2]").appName("Testing")
        cls.spark = session_builder.getOrCreate()

    @classmethod
    def tearDownClass(cls):
        """Stop the Spark session once all tests in the class have run."""
        cls.spark.stop()

    def test_tree_initialization(self):
        """A default-constructed tree has no root and a ``max_depth`` of 5."""
        tree = DecisionTree()
        self.assertIsNone(tree.root)
        self.assertEqual(tree.max_depth, 5)

    def test_predict(self):
        """A hand-built two-leaf tree sends ``[5]`` to 0 and ``[15]`` to 1."""
        # NOTE(review): the positional arguments 0 and 10 are presumably the
        # feature index and split threshold -- confirm against TreeNode.
        first_leaf = TreeNode(prediction=0)
        second_leaf = TreeNode(prediction=1)
        stump = TreeNode(0, 10, first_leaf, second_leaf)

        self.assertEqual(stump.predict([5]), 0)
        self.assertEqual(stump.predict([15]), 1)

    def test_training(self):
        """Training on a four-row frame yields a root that is a leaf or has a left child."""
        rows = [
            (1, 'A', 0),
            (2, 'B', 0),
            (3, 'C', 1),
            (4, 'D', 1),
        ]
        column_names = ["feature", "category", "label"]
        frame = self.spark.createDataFrame(rows, column_names)

        model = DecisionTree(max_depth=3, min_instances_per_node=1)
        model.train(frame, ["feature", "category"], "label")

        self.assertIsNotNone(model.root)
        self.assertTrue(model.root.is_leaf or model.root.left is not None)

class RandomForestTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.spark = SparkSession.builder.master("local[2]").appName("TestingRF").getOrCreate()

    def test_rf_initialization(self):
        rf = RandomForest(num_trees=3, max_depth=3)
        self.assertEqual(len(rf.trees), 3)

    def test_rf_training(self):
        data = self.spark.createDataFrame([
            (1, 'A', 0),
            (2, 'B', 0),
            (3, 'C', 1),
            (4, 'D', 1)
        ], ["feature", "category", "label"])
        rf = RandomForest(num_trees=2, max_depth=2)
        rf.train(data, ["feature", "category"], "label")
        self.assertEqual(len(rf.trees), 2)
        for tree in rf.trees:
            self.assertIsNotNone(tree.root)

# Script entry point: when this file is executed directly (not imported),
# hand control to unittest's command-line runner.
if __name__ == "__main__":
    unittest.main()
