In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
import pandas as pd
from pyspark.rdd import RDD
from pyspark.sql.functions import col
from d_imm.histogram.histogram import DecisionTreeSplitFinder, Instance
from pyspark.ml.clustering import KMeans
import pandas as pd
from d_imm.imm_model import DistributedIMM

In [2]:
import os
import sys

# Tell Spark which JVM and which Python interpreter to use.
#
# The interpreter paths used to be absolute paths into one user's virtualenv,
# which breaks the notebook on every other machine. The interpreter running
# this kernel is used instead, so the driver and the workers share the same
# Python and the same installed packages.

# Machine-specific fallback JDK location. It is applied only when JAVA_HOME is
# not already configured AND the directory exists here, so an existing, valid
# Java setup is never overwritten with a path that may not exist.
_DEFAULT_JAVA_HOME = "C:\\Program Files\\Java\\jdk1.8.0_261"
if "JAVA_HOME" not in os.environ and os.path.isdir(_DEFAULT_JAVA_HOME):
    os.environ["JAVA_HOME"] = _DEFAULT_JAVA_HOME

os.environ["PYSPARK_PYTHON"] = sys.executable
os.environ["PYSPARK_DRIVER_PYTHON"] = sys.executable

In [3]:
# Set up Spark session
# Create (or reuse, if one already exists) a session running in local mode,
# and keep a handle on its SparkContext for RDD-level work.
spark = (
    SparkSession.builder
    .appName("DecisionTreeSplitFinderExample")
    .master("local[*]")
    .getOrCreate()
)
sc = spark.sparkContext

In [4]:
# Load the Iris dataset from the UCI Machine Learning Repository.
# The file is read without a header row, so the column names are supplied here.
url = "https://archive.ics.uci.edu/ml/machine-learning-databases/iris/iris.data"
column_names = [
    "sepal_length",
    "sepal_width",
    "petal_length",
    "petal_width",
    "species",
]
iris_df = pd.read_csv(url, names=column_names, header=None)

# Hand the pandas DataFrame over to Spark as a Spark DataFrame
df_1 = spark.createDataFrame(iris_df)

# Stack the dataset 5 times row-wise (5 copies in total).
# `df` already holds one copy, so only NUM_COPIES - 1 further unions are
# needed. Looping NUM_COPIES times here would produce 6 copies, one more than
# intended.
NUM_COPIES = 5
df = df_1
for _ in range(NUM_COPIES - 1):
    df = df.union(df_1)

In [5]:
# Assemble features into a single vector column.
# The four measurement columns are packed into one "features" vector column;
# the original columns remain available on the resulting DataFrame.
# NOTE(review): the RDD conversion later in this cell selects the raw columns,
# not "features" — confirm whether the assembled column is still needed.
assembler = VectorAssembler(
    outputCol="features",
    inputCols=[
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
    ],
)
feature_df = assembler.transform(df)

# Convert to RDD of Instances
def to_instance(row):
    """Turn one Iris row into an ``Instance``.

    Parameters
    ----------
    row
        Object exposing ``sepal_length``, ``sepal_width``, ``petal_length``,
        ``petal_width`` and ``species`` as attributes.

    Returns
    -------
    Instance
        ``features`` is the four measurements in the order listed above,
        ``label`` is 1.0 for "Iris-setosa" and 0.0 for any other species
        (a simplified binary label), and ``weight`` is always 1.0.
    """
    is_setosa = row.species == "Iris-setosa"
    return Instance(
        features=[
            row.sepal_length,
            row.sepal_width,
            row.petal_length,
            row.petal_width,
        ],
        label=1.0 if is_setosa else 0.0,
        weight=1.0,
    )

# Keep only the columns to_instance reads, drop down to the underlying RDD of
# rows, and convert each row into an Instance.
iris_rdd: RDD[Instance] = (
    feature_df
    .select(
        "sepal_length",
        "sepal_width",
        "petal_length",
        "petal_width",
        "species",
    )
    .rdd
    .map(to_instance)
)

In [6]:
# Sanity check: pull the first five converted Instances back to the driver and
# display them, to confirm the row -> Instance mapping looks right.
iris_rdd.take(5)

[Instance(features=[5.1, 3.5, 1.4, 0.2], label=1.0, weight=1.0),
 Instance(features=[4.9, 3.0, 1.4, 0.2], label=1.0, weight=1.0),
 Instance(features=[4.7, 3.2, 1.3, 0.2], label=1.0, weight=1.0),
 Instance(features=[4.6, 3.1, 1.5, 0.2], label=1.0, weight=1.0),
 Instance(features=[5.0, 3.6, 1.4, 0.2], label=1.0, weight=1.0)]

In [7]:
# Initialize DecisionTreeSplitFinder
num_features = 4
# Per-feature settings are derived from num_features so the list lengths can
# never disagree with the declared feature count.
is_continuous = [True] * num_features  # All features are continuous
is_unordered = [False] * num_features  # No categorical unordered features
max_splits_per_feature = [10] * num_features  # Max splits per feature
max_bins = 32
# Total instance weight, measured from the data instead of a hard-coded number
# that silently goes stale when the dataset size changes. Every Instance built
# by to_instance carries weight 1.0, so the total weight equals the row count.
# (This triggers one Spark job to count the rows.)
total_weighted_examples = float(iris_rdd.count())
seed = 42

split_finder = DecisionTreeSplitFinder(
    num_features=num_features,
    is_continuous=is_continuous,
    is_unordered=is_unordered,
    max_splits_per_feature=max_splits_per_feature,
    max_bins=max_bins,
    total_weighted_examples=total_weighted_examples,
    seed=seed
)

# Find splits
splits = split_finder.find_splits(input_rdd=iris_rdd)

# Print the splits
# One header per feature, followed by one line per split. A feature with no
# splits is reported explicitly for BOTH feature kinds; previously only
# categorical features got the "No splits found." line, so a continuous
# feature without splits printed a bare header.
for fidx, feature_splits in enumerate(splits):
    continuous = is_continuous[fidx]
    kind = "Continuous" if continuous else "Categorical"
    print(f"Feature {fidx} ({kind}) splits:")

    # Track emptiness while iterating rather than testing truthiness, so this
    # works for any iterable of splits.
    printed_any = False
    for s in feature_splits:
        if continuous:
            print(f"  Threshold = {s.threshold}")
        else:
            print(f"  Categories = {s.categories}")
        printed_any = True

    if not printed_any:
        print("  No splits found.")

# Stop Spark
# Shuts down the session created earlier in the notebook once the splits have
# been printed.
# NOTE(review): after this runs, cells that use `spark`, `sc`, or RDDs derived
# from them presumably need the session cell re-run first — confirm.
spark.stop()

Feature 0 (Continuous) splits:
  Threshold = 4.85
  Threshold = 5.05
  Threshold = 5.15
  Threshold = 5.45
  Threshold = 5.65
  Threshold = 5.95
  Threshold = 6.15
  Threshold = 6.35
  Threshold = 6.65
  Threshold = 6.95
Feature 1 (Continuous) splits:
  Threshold = 2.45
  Threshold = 2.6500000000000004
  Threshold = 2.8499999999999996
  Threshold = 2.95
  Threshold = 3.05
  Threshold = 3.1500000000000004
  Threshold = 3.25
  Threshold = 3.3499999999999996
  Threshold = 3.45
  Threshold = 3.6500000000000004
Feature 2 (Continuous) splits:
  Threshold = 1.35
  Threshold = 1.45
  Threshold = 1.65
  Threshold = 3.55
  Threshold = 4.15
  Threshold = 4.45
  Threshold = 4.75
  Threshold = 5.05
  Threshold = 5.45
  Threshold = 5.85
Feature 3 (Continuous) splits:
  Threshold = 0.15000000000000002
  Threshold = 0.25
  Threshold = 0.35
  Threshold = 1.05
  Threshold = 1.25
  Threshold = 1.35
  Threshold = 1.55
  Threshold = 1.75
  Threshold = 1.95
  Threshold = 2.25
