In [None]:
import os
from pyspark.sql import SparkSession
from pyspark.ml import PipelineModel
from pyspark.sql.types import StructType, StructField, StringType, FloatType, IntegerType
from pyspark.sql.functions import col, length, when, log1p, expm1, lower, lit

# Local Spark session used by the inference cells below.
spark = (
    SparkSession.builder
    .appName("Inference DT")
    .config("spark.driver.memory", "4g")
    .master("local[*]")
    .getOrCreate()
)

# Only ERROR-level Spark log messages are shown in the notebook output.
spark.sparkContext.setLogLevel("ERROR")

# Location of the saved pipeline, relative to the notebook's working directory.
MODEL_PATH = "../../models/regression/dt_price_log_v1"

print(f"Loading Decision Tree model from {MODEL_PATH}...")
try:
    model = PipelineModel.load(MODEL_PATH)
except Exception as e:
    # Best-effort load: report the problem and leave `model` as None so that
    # later cells can detect the missing model.
    print(f"Error loading model: {e}")
    model = None
else:
    print("Decision Tree loaded successfully!")

In [None]:
import math 

def predict_price_dt(title, category, rating, rating_count, store_freq=100):
    """Predict a product price with the loaded Decision Tree pipeline.

    Builds a one-row Spark DataFrame from the arguments, adds the derived
    feature columns, runs the module-level ``model`` on it and converts the
    resulting prediction back with ``expm1``.

    Args:
        title: Product title; also the source of ``title_len``,
            ``is_premium`` and ``is_bundle``.
        category: Value stored in the ``main_category`` column.
        rating: Average rating; converted with ``float()``.
        rating_count: Number of ratings; converted with ``int()``.
        store_freq: Store frequency; converted with ``int()``. Defaults to 100.

    Returns:
        The predicted price as a float, clamped so it is never negative, or
        the string ``"Model not loaded"`` when ``model`` is ``None``.
    """
    if model is None:
        return "Model not loaded"

    # (column name, Spark type) pairs; every field is declared nullable.
    column_types = [
        ("title", StringType()),
        ("main_category", StringType()),
        ("average_rating", FloatType()),
        ("rating_number", IntegerType()),
        ("store_freq", IntegerType()),
        ("features_count", IntegerType()),
    ]
    schema = StructType([StructField(name, dtype, True) for name, dtype in column_types])

    # features_count is always supplied as 0 for inference.
    row = (title, category, float(rating), int(rating_count), int(store_freq), 0)

    def title_matches(pattern):
        """Return a 1/0 column: 1 when the lower-cased title matches ``pattern``."""
        return when(lower(col("title")).rlike(pattern), 1).otherwise(0)

    # NOTE(review): these derived columns presumably mirror the feature
    # engineering used at training time — confirm against the training code.
    features_df = (
        spark.createDataFrame([row], schema=schema)
        .withColumn("title_len", length(col("title")))
        .withColumn("is_premium", title_matches("premium|pro|deluxe"))
        .withColumn("is_bundle", title_matches("bundle|set|pack"))
        .withColumn("log_rating_number", log1p(col("rating_number")))
        .withColumn("log_store_freq", log1p(col("store_freq")))
    )

    scored_rows = model.transform(features_df).select("prediction").collect()
    log_price = scored_rows[0][0]

    # expm1 undoes a log1p transform (the model path suggests a log-price
    # target — TODO confirm); the max() guards against a negative result.
    return max(0.0, math.expm1(log_price))

In [None]:
# Sample products used to smoke-test the prediction function. The keys match
# the parameter names of predict_price_dt so each dict can be unpacked with **.
items = [
    {
        "title": "Cheap Plastic Pen",
        "category": "Office Products",
        "rating": 3.5,
        "rating_count": 10,
        "store_freq": 5
    },
    {
        "title": "Professional Gaming Laptop 16GB RAM Bundle",
        "category": "Computers",  # Make sure this category was present in the training data
        "rating": 4.9,
        "rating_count": 5000,
        "store_freq": 200
    }
]

print("--- Decision Tree Predictions ---")
for item in items:
    price = predict_price_dt(**item)
    print(f"\nProduct: {item['title']}")
    # predict_price_dt returns a message string instead of a number when the
    # model is not loaded; applying ":.2f" to a str raises ValueError, so only
    # numeric results are formatted as a price.
    if isinstance(price, (int, float)):
        print(f"Predicted Price: ${price:.2f}")
    else:
        print(f"Predicted Price: unavailable ({price})")