In [None]:
from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel
from pyspark.sql.types import StructType, StructField, FloatType, StringType


In [None]:
# Inference cell: load the saved Iris PipelineModel, score a few hand-written
# samples, show the predictions, and persist them as a Delta table.

# Path of the saved model -- must match the path the training notebook saved to.
model_path = "/mnt/my_delta_lake/iris_model"

# Output location for the scored rows (replace with your desired path).
predictions_path = "/mnt/my_delta_lake/iris_predictions"

# Single source of truth for the feature columns: used both to build the input
# schema and to project the output, so the two cannot drift apart.
# NOTE(review): names/types presumably must match the input columns the
# pipeline was trained on -- confirm against the training notebook.
feature_cols = ["sepal_length", "sepal_width", "petal_length", "petal_width"]

# Schema for new data: every feature is a nullable float.
new_data_schema = StructType(
    [StructField(col_name, FloatType(), True) for col_name in feature_cols]
)

# Get (or reuse) the SparkSession.
spark = SparkSession.builder.appName("IrisInference").getOrCreate()

try:
    # Load the trained pipeline (all fitted stages) from disk.
    loaded_model = PipelineModel.load(model_path)

    # A small in-memory sample of new observations; tuple order follows feature_cols.
    new_data = spark.createDataFrame(
        [
            (5.8, 2.7, 5.1, 1.9),
            (6.3, 3.3, 6.0, 2.5),
            (5.0, 3.4, 1.6, 0.4),
        ],
        new_data_schema,
    )

    # Run every pipeline stage; the fitted pipeline is expected to add a
    # "prediction" column. NOTE(review): whether "prediction" is a numeric
    # label index or a decoded species name depends on how training encoded
    # labels, which is not visible here.
    predictions = loaded_model.transform(new_data)

    # Show the input features next to the model's prediction.
    predictions.select(*feature_cols, "prediction").show()

    # Persist the full scored DataFrame (all pipeline output columns) as Delta,
    # replacing any previous contents at predictions_path.
    predictions.write.format("delta").mode("overwrite").save(predictions_path)

    print(f"Predictions saved to: {predictions_path}")
finally:
    # Always release the session, even if loading, scoring, or writing failed.
    # NOTE(review): if this runs in a shared notebook session (the /mnt path
    # suggests a managed workspace), stopping it may affect other cells --
    # confirm this is intended.
    spark.stop()