In [0]:
%pip install mlflow prophet scikit-learn pandas
# Restart the Python process so the packages installed above are importable in later cells.
# NOTE(review): versions are unpinned, so a re-run may resolve different releases — consider pinning.
# NOTE(review): mlflow and scikit-learn are not used in the cells shown here; verify later cells need them.
dbutils.library.restartPython()

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
from pyspark.sql.functions import pandas_udf, PandasUDFType, col, lit, current_date, expr
from pyspark.sql.types import *
import pandas as pd
from prophet import Prophet
import logging

# 1. Prepare data: one row per (crop_name, market_date) holding the average
#    modal price over all silver rows for that crop and date (presumably one per
#    reporting market). This is a DAILY series — market_date is grouped as-is,
#    not truncated to months. Columns are renamed to Prophet's convention:
#    ds = date, y = value to forecast.
#    No orderBy here: groupBy().applyInPandas reshuffles rows per crop, so a
#    global sort would be wasted work; Prophet orders each history by ds itself.
df_market = (
    spark.table("agriculture.silver.market_prices")
    .groupBy("crop_name", "market_date")
    .agg({"modal_price_rs_quintal": "avg"})
    .withColumnRenamed("avg(modal_price_rs_quintal)", "y")
    .withColumnRenamed("market_date", "ds")
)

# 2. Output schema of forecast_crop_prices: every field is nullable and the
#    column names match the pandas DataFrame the function returns.
result_schema = (
    StructType()
    .add("crop_name", StringType(), True)
    .add("ds", DateType(), True)
    .add("yhat", DoubleType(), True)                   # point forecast of the price
    .add("yhat_lower", DoubleType(), True)             # lower bound of the uncertainty interval
    .add("yhat_upper", DoubleType(), True)             # upper bound of the uncertainty interval
    .add("price_stability_score", DoubleType(), True)  # risk metric (std / mean of future yhat)
)

# 3. Grouped-map training function, executed once per crop via applyInPandas.
def forecast_crop_prices(history_pd):
    """Fit a Prophet model on one crop's price history and forecast 12 months ahead.

    Spark calls this once per ``crop_name`` group through ``applyInPandas``. Keep
    exactly one positional parameter: Spark picks the calling convention from the
    arity (one argument -> the group's pandas DataFrame, two -> ``(key, pdf)``).

    Parameters
    ----------
    history_pd : pandas.DataFrame
        Every row for a single crop, with columns ``crop_name``, ``ds`` (date)
        and ``y`` (price to forecast).

    Returns
    -------
    pandas.DataFrame
        Columns ``crop_name, ds, yhat, yhat_lower, yhat_upper,
        price_stability_score`` in ``result_schema`` order, one row per forecast
        month-end strictly after the last historical date. ``price_stability_score``
        is the coefficient of variation (std / mean) of those rows' ``yhat``,
        repeated on every row (0.0 when the mean is 0). The frame is empty (same
        columns) when the crop has fewer than two non-null prices: Prophet cannot
        fit such a series, and its ValueError would otherwise abort the whole job.
    """
    horizon_months = 12
    output_dtypes = {
        "crop_name": "object",
        "ds": "datetime64[ns]",
        "yhat": "float64",
        "yhat_lower": "float64",
        "yhat_upper": "float64",
        "price_stability_score": "float64",
    }

    # cmdstanpy (Prophet's Stan backend) logs INFO lines on every fit; keep executor logs readable.
    logging.getLogger("cmdstanpy").setLevel(logging.WARNING)

    # a. Guard: Prophet drops null-y rows and refuses to fit fewer than two.
    if history_pd["y"].notna().sum() < 2:
        return pd.DataFrame({name: pd.Series(dtype=dtype) for name, dtype in output_dtypes.items()})

    crop = history_pd["crop_name"].iloc[0]

    # b. Train: only yearly seasonality is modelled; daily and weekly are disabled.
    model = Prophet(daily_seasonality=False, weekly_seasonality=False, yearly_seasonality=True)
    model.fit(history_pd[["ds", "y"]])

    # c. Forecast horizon_months month-end dates past the fitted history.
    #    'M' = month end; pandas >= 2.2 prefers 'ME', which older pandas rejects.
    future = model.make_future_dataframe(periods=horizon_months, freq="M")
    forecast = model.predict(future)

    # Keep only genuinely future rows. The score and the output use the SAME rows,
    # so they stay consistent even if Prophet ignored trailing null-y history rows.
    last_historical_date = pd.to_datetime(history_pd["ds"].max())
    results = forecast.loc[
        forecast["ds"] > last_historical_date, ["ds", "yhat", "yhat_lower", "yhat_upper"]
    ].copy()

    # d. Stability score (risk): coefficient of variation of the forecast prices.
    avg_price = results["yhat"].mean()
    stability_score = results["yhat"].std() / avg_price if avg_price != 0 else 0.0

    # e. Attach group-level columns and return them in result_schema order.
    results["crop_name"] = crop
    results["price_stability_score"] = stability_score
    return results[list(output_dtypes)]

# 4. Execute distributed training: applyInPandas runs forecast_crop_prices once
#    per crop_name group, in parallel across executors. The result is lazy —
#    nothing is trained until an action (the write below) runs.
print("Launching distributed training for all crops...")

forecast_df = df_market.groupBy("crop_name").applyInPandas(
    forecast_crop_prices,
    schema=result_schema,
)

# 5. Write to the Gold layer. 'overwrite' replaces the previous run's forecasts.
forecast_df.write.mode("overwrite").saveAsTable("agriculture.gold.market_forecasts")

# Re-point forecast_df at the materialized table. Left as the lazy plan, every later
# action on it would re-train all Prophet models, and because Prophet samples its
# uncertainty intervals, the recomputed yhat_lower/yhat_upper would not match what
# was just saved.
forecast_df = spark.table("agriculture.gold.market_forecasts")

print("Forecasts generated and saved to 'agriculture.gold.market_forecasts'")
display(forecast_df)

Launching distributed training for all crops...
Forecasts generated and saved to 'agriculture.gold.market_forecasts'


crop_name,ds,yhat,yhat_lower,yhat_upper,price_stability_score
apple,2024-02-29,7959.643052550945,6999.317975573409,8859.799798789201,0.1384172116962412
apple,2024-03-31,8495.612680763214,7481.737018963636,9445.611424327068,0.1384172116962412
apple,2024-04-30,9243.547334807516,8340.358909769382,10234.228711047224,0.1384172116962412
apple,2024-05-31,10097.45400664644,9148.211188559417,11009.907862986922,0.1384172116962412
apple,2024-06-30,10284.547979021598,9352.496245302773,11188.013853243026,0.1384172116962412
apple,2024-07-31,8571.052041434019,7537.661355039704,9489.785240670017,0.1384172116962412
apple,2024-08-31,7215.448585548114,6304.932477106327,8223.539359426533,0.1384172116962412
apple,2024-09-30,6915.630106845369,6090.467464781479,7943.953425201076,0.1384172116962412
apple,2024-10-31,7061.663972775278,6068.63656430213,8046.874065534371,0.1384172116962412
apple,2024-11-30,7353.209719092647,6392.691094425875,8278.942859915285,0.1384172116962412


In [0]:
from pyspark.sql.functions import abs, col, mean
# NOTE(review): `abs` is unused in this cell and shadows Python's builtin abs() for the
# rest of the notebook; kept only in case later cells rely on the PySpark version.

print("=== FORECAST PERFORMANCE METRICS ===")

# Volatility index per crop: mean width of the forecast uncertainty band
# (yhat_upper - yhat_lower) divided by the mean predicted price.
# Lower value = narrower, more predictable price band (less risk).
# Reads the materialized Gold table rather than the lazy forecast_df plan: the plan
# would re-run every Prophet fit, and since Prophet samples its uncertainty intervals,
# the recomputed bands could differ from the forecasts that were actually saved.
risk_profile = (
    spark.table("agriculture.gold.market_forecasts")
    .groupBy("crop_name")
    .agg((mean(col("yhat_upper") - col("yhat_lower")) / mean("yhat")).alias("volatility_index"))
    .orderBy("volatility_index")
)

display(risk_profile)

print("""
INTERPRETATION:
- Low Volatility (< 0.15): Safe bets (Prices are stable)
- High Volatility (> 0.25): Risky crops (Prices fluctuate wildly)
""")

=== FORECAST PERFORMANCE METRICS ===


crop_name,volatility_index
lentil,0.1488070309988307
maize,0.1506774806314728
cotton,0.1963055061786033
jute,0.2025812253391609
papaya,0.208798102355015
pomegranate,0.2122506896964157
apple,0.2316733752325645
chickpea,0.2351771009629659
banana,0.3350480575237155
grapes,0.4349366797337638



INTERPRETATION:
- Low Volatility (< 0.15): Safe bets (Prices are stable)
- High Volatility (> 0.25): Risky crops (Prices fluctuate wildly)

