In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.clustering import KMeans
from pyspark.ml.feature import VectorAssembler
from pyspark.sql.functions import col
import numpy as np

from pyspark.ml.functions import vector_to_array

from d_imm.imm_model import DistributedIMM

In [2]:
# Build (or reuse) the SparkSession for this notebook.
# "local[*]" runs Spark in local mode with one worker thread per logical core.
spark = (
    SparkSession.builder
    .appName("DistributedIMM Dummy Data Test")
    .master("local[*]")
    .config("spark.executor.memory", "4g")
    .config("spark.driver.memory", "4g")
    .getOrCreate()
)

In [3]:
# Toy dataset: 5 rows x 3 features. Each row is (x, x + 1, x + 2); the starting
# values form a low group (1, 2, 3) and a high group (8, 9).
dummy_data = [(x, x + 1.0, x + 2.0) for x in (1.0, 2.0, 3.0, 8.0, 9.0)]

# Load the rows into a Spark DataFrame with one named column per feature.
columns = ["feature1", "feature2", "feature3"]
df = spark.createDataFrame(dummy_data, schema=columns)

# Pack the feature columns into a single vector column and keep only that column.
assembler = VectorAssembler(inputCols=columns, outputCol="features")
feature_df = assembler.transform(df).select("features")

# Fit the reference KMeans model (k=2, fixed seed so the run is repeatable).
kmeans = KMeans(featuresCol="features", k=2, seed=1)
kmeans_model = kmeans.fit(feature_df)

# Fit DistributedIMM on the same features, handing it the fitted KMeans model.
d_imm = DistributedIMM(spark, k=2, verbose=2)
d_imm_tree = d_imm.fit(feature_df, kmeans_model)

Running 'fit' method
Tree building completed.
Time taken to fill stats: 0 minutes and 38.37 seconds


In [4]:
# Compute Score and Surrogate Score with the UDF-based methods
score_value = d_imm_tree.score(feature_df)
surrogate_score_value = d_imm_tree.surrogate_score(feature_df)

# Print results
print("\n===== Distributed IMM Score Testing with Dummy Data =====")
print(f"Score (K-Means Cost): {score_value:.4f}")
print(f"Surrogate Score (K-Means Surrogate Cost): {surrogate_score_value:.4f}")

# The surrogate score must be >= the k-means score. The two values are produced
# by separate computations and can coincide (both print as 7.5000 on this data),
# so accept a difference within floating-point rounding instead of failing on
# last-bit noise.
assert (
    surrogate_score_value >= score_value
    or np.isclose(surrogate_score_value, score_value)
), "Surrogate score should be greater than or equal to normal score."

print("✅ Score and Surrogate Score tests passed successfully.")


===== Distributed IMM Score Testing with Dummy Data =====
Score (K-Means Cost): 7.5000
Surrogate Score (K-Means Surrogate Cost): 7.5000
✅ Score and Surrogate Score tests passed successfully.


In [5]:
# Compute Score and Surrogate Score with the Spark SQL based methods
# NOTE(review): the recorded run of this cell aborted with a Py4JJavaError
# (java.net.SocketException while reading from the Python worker during an
# Arrow Python evaluation stage). Confirm the SQL scoring path works in this
# environment before relying on the numbers below.
score_value = d_imm_tree.score_sql(feature_df)
surrogate_score_value = d_imm_tree.surrogate_score_sql(feature_df)

# Print results
print("\n===== Distributed IMM Score Testing with Dummy Data =====")
print(f"Score (K-Means Cost): {score_value:.4f}")
print(f"Surrogate Score (K-Means Surrogate Cost): {surrogate_score_value:.4f}")

# The surrogate score must be >= the k-means score. The two values are produced
# by separate computations and may coincide, so accept a difference within
# floating-point rounding instead of failing on last-bit noise.
assert (
    surrogate_score_value >= score_value
    or np.isclose(surrogate_score_value, score_value)
), "Surrogate score should be greater than or equal to normal score."

print("✅ Score and Surrogate Score tests passed successfully.")

Py4JJavaError: An error occurred while calling o495.collectToPython.
: org.apache.spark.SparkException: Job aborted due to stage failure: Task 2 in stage 67.0 failed 1 times, most recent failure: Lost task 2.0 in stage 67.0 (TID 256) (LAPTOP-P8U0OKKO executor driver): java.net.SocketException: An established connection was aborted by the software in your host machine
	at java.base/sun.nio.ch.NioSocketImpl.implRead(NioSocketImpl.java:330)
	at java.base/sun.nio.ch.NioSocketImpl.read(NioSocketImpl.java:355)
	at java.base/sun.nio.ch.NioSocketImpl$1.read(NioSocketImpl.java:808)
	at java.base/java.net.Socket$SocketInputStream.read(Socket.java:966)
	at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:244)
	at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:263)
	at java.base/java.io.DataInputStream.readInt(DataInputStream.java:393)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:105)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.execution.python.BatchIterator.hasNext(ArrowEvalPythonExec.scala:38)
	at org.apache.spark.sql.execution.python.BasicPythonArrowInput.writeIteratorToArrowStream(PythonArrowInput.scala:131)
	at org.apache.spark.sql.execution.python.BasicPythonArrowInput.writeIteratorToArrowStream$(PythonArrowInput.scala:124)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner.writeIteratorToArrowStream(ArrowPythonRunner.scala:30)
	at org.apache.spark.sql.execution.python.PythonArrowInput$$anon$1.$anonfun$writeIteratorToStream$1(PythonArrowInput.scala:96)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.sql.execution.python.PythonArrowInput$$anon$1.writeIteratorToStream(PythonArrowInput.scala:102)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:451)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1928)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:282)

Driver stacktrace:
	at org.apache.spark.scheduler.DAGScheduler.failJobAndIndependentStages(DAGScheduler.scala:2856)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2(DAGScheduler.scala:2792)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$abortStage$2$adapted(DAGScheduler.scala:2791)
	at scala.collection.mutable.ResizableArray.foreach(ResizableArray.scala:62)
	at scala.collection.mutable.ResizableArray.foreach$(ResizableArray.scala:55)
	at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:49)
	at org.apache.spark.scheduler.DAGScheduler.abortStage(DAGScheduler.scala:2791)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGScheduler.$anonfun$handleTaskSetFailed$1$adapted(DAGScheduler.scala:1247)
	at scala.Option.foreach(Option.scala:407)
	at org.apache.spark.scheduler.DAGScheduler.handleTaskSetFailed(DAGScheduler.scala:1247)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.doOnReceive(DAGScheduler.scala:3060)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2994)
	at org.apache.spark.scheduler.DAGSchedulerEventProcessLoop.onReceive(DAGScheduler.scala:2983)
	at org.apache.spark.util.EventLoop$$anon$1.run(EventLoop.scala:49)
Caused by: java.net.SocketException: An established connection was aborted by the software in your host machine
	at java.base/sun.nio.ch.NioSocketImpl.implRead(NioSocketImpl.java:330)
	at java.base/sun.nio.ch.NioSocketImpl.read(NioSocketImpl.java:355)
	at java.base/sun.nio.ch.NioSocketImpl$1.read(NioSocketImpl.java:808)
	at java.base/java.net.Socket$SocketInputStream.read(Socket.java:966)
	at java.base/java.io.BufferedInputStream.fill(BufferedInputStream.java:244)
	at java.base/java.io.BufferedInputStream.read(BufferedInputStream.java:263)
	at java.base/java.io.DataInputStream.readInt(DataInputStream.java:393)
	at org.apache.spark.sql.execution.python.PythonArrowOutput$$anon$1.read(PythonArrowOutput.scala:105)
	at org.apache.spark.api.python.BasePythonRunner$ReaderIterator.hasNext(PythonRunner.scala:525)
	at org.apache.spark.InterruptibleIterator.hasNext(InterruptibleIterator.scala:37)
	at scala.collection.Iterator$$anon$11.hasNext(Iterator.scala:491)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.catalyst.expressions.GeneratedClass$GeneratedIteratorForCodegenStage2.processNext(Unknown Source)
	at org.apache.spark.sql.execution.BufferedRowIterator.hasNext(BufferedRowIterator.java:43)
	at org.apache.spark.sql.execution.WholeStageCodegenEvaluatorFactory$WholeStageCodegenPartitionEvaluator$$anon$1.hasNext(WholeStageCodegenEvaluatorFactory.scala:43)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.ContextAwareIterator.hasNext(ContextAwareIterator.scala:39)
	at scala.collection.Iterator$$anon$10.hasNext(Iterator.scala:460)
	at org.apache.spark.sql.execution.python.BatchIterator.hasNext(ArrowEvalPythonExec.scala:38)
	at org.apache.spark.sql.execution.python.BasicPythonArrowInput.writeIteratorToArrowStream(PythonArrowInput.scala:131)
	at org.apache.spark.sql.execution.python.BasicPythonArrowInput.writeIteratorToArrowStream$(PythonArrowInput.scala:124)
	at org.apache.spark.sql.execution.python.ArrowPythonRunner.writeIteratorToArrowStream(ArrowPythonRunner.scala:30)
	at org.apache.spark.sql.execution.python.PythonArrowInput$$anon$1.$anonfun$writeIteratorToStream$1(PythonArrowInput.scala:96)
	at scala.runtime.java8.JFunction0$mcV$sp.apply(JFunction0$mcV$sp.java:23)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally(SparkErrorUtils.scala:64)
	at org.apache.spark.util.SparkErrorUtils.tryWithSafeFinally$(SparkErrorUtils.scala:61)
	at org.apache.spark.util.Utils$.tryWithSafeFinally(Utils.scala:94)
	at org.apache.spark.sql.execution.python.PythonArrowInput$$anon$1.writeIteratorToStream(PythonArrowInput.scala:102)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.$anonfun$run$1(PythonRunner.scala:451)
	at org.apache.spark.util.Utils$.logUncaughtExceptions(Utils.scala:1928)
	at org.apache.spark.api.python.BasePythonRunner$WriterThread.run(PythonRunner.scala:282)
