In [0]:
# Load the Gold_Sales_By_City_Month table into a pandas DataFrame
df = spark.table("Gold_Sales_By_City_Month").toPandas()

# Overall number of rows in the table
total_rows = len(df)
print(f"Total rows: {total_rows}")

# Count of distinct values in the City column
num_cities = df["City"].nunique()
print(f"Number of cities: {num_cities}")

# Row count for each individual city
rows_per_city = df.groupby("City").size()
print("\nRows per city:", rows_per_city, sep="\n")

# Rough rows-per-feature ratio, using a fixed list of nine planned feature
# names (the list is only counted here; the columns are not read from df).
feature_cols = [
    "Month_Num",
    "lag_1", "lag_2", "lag_3", "lag_6",
    "rolling_avg_3", "rolling_std_3",
    "month_sin", "month_cos",
]
rows_per_feature = total_rows / len(feature_cols)
print(f"\nTotal rows per feature (approx): {rows_per_feature:.2f}")


In [0]:
# ======================================
# Databricks Free Edition - Regression + SARIMA + MLflow
# ======================================

# --- MLflow setup with Databricks workaround ---
# The statements in this section are kept in their original order: the
# registry-URI override is installed first, then mlflow.login() is called.
# NOTE(review): presumably the login/registry path consults the patched
# helper, so the override must stay ahead of the login — TODO confirm.
import mlflow
# NOTE(review): the `db_connect` alias is not referenced anywhere else in this
# cell; presumably imported for a side effect or left over — confirm before
# removing.
import databricks.connect as db_connect
# Imported explicitly so the private helper patched below is reachable via
# attribute access on the mlflow package.
import mlflow.tracking._model_registry.utils

# Workaround to set the registry URI manually: the private helper is replaced
# with a stub that always returns "databricks-uc".
# NOTE(review): this patches an underscore-prefixed (private) function, so it
# may stop working when the MLflow version changes — verify after upgrades.
mlflow.tracking._model_registry.utils._get_registry_uri_from_spark_session = lambda: "databricks-uc"
mlflow.login()  # prints: INFO-log: Login successful!

# --- Imports ---
# Data handling, modelling (scikit-learn linear regression, statsmodels
# SARIMAX), model serialization (joblib) and plotting (plotly).
import pandas as pd
import numpy as np
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
import joblib
import plotly.graph_objects as go
from statsmodels.tsa.statespace.sarimax import SARIMAX
from statsmodels.tools.sm_exceptions import ConvergenceWarning
import warnings
# Suppress statsmodels ConvergenceWarning from this point on (the filter is
# global, not scoped to the SARIMA fitting loop).
warnings.simplefilter('ignore', ConvergenceWarning)

# ======================================
# Load data
# ======================================
# Read the table through Spark, convert it to pandas, derive a monotonically
# increasing month index (Year * 12 + Month) and order the rows by city and
# then by that index.
df_spark = spark.table("Gold_Sales_By_City_Month")
df = (
    df_spark.toPandas()
    .assign(Month_Num=lambda frame: frame["Year"] * 12 + frame["Month"])
    .sort_values(["City", "Month_Num"])
)

# Distinct city values, in order of first appearance in the sorted frame
cities = df["City"].unique()

# ======================================
# Feature engineering for Linear Regression
# ======================================
# Every sales-derived feature is built only from months strictly BEFORE the
# row's own month, so the target (Total_Sales) never leaks into its own
# predictors. Rows without enough history get 0 for the missing values.
# Relies on df already being ordered by City and Month_Num.

# Lagged sales: value of Total_Sales 1, 2, 3 and 6 rows earlier within the city.
for lag in [1, 2, 3, 6]:
    df[f"lag_{lag}"] = df.groupby("City")["Total_Sales"].shift(lag).fillna(0)

# Rolling mean / std over the previous three months. The series is shifted by
# one row first: rolling directly over Total_Sales would include the current
# month, and together with lag_1 / lag_2 the target would then be an exact
# linear combination of the features (3 * rolling_avg_3 - lag_1 - lag_2).
prev_sales = df.groupby("City")["Total_Sales"].shift(1)
prev_window = prev_sales.groupby(df["City"]).rolling(3, min_periods=1)
# reset_index(0, drop=True) drops the City level added by the grouped rolling
# so the result aligns with df's own index on assignment.
df["rolling_avg_3"] = prev_window.mean().reset_index(0, drop=True).fillna(0)
df["rolling_std_3"] = prev_window.std().reset_index(0, drop=True).fillna(0)

# Cyclical encoding of the calendar month (period 12).
df["month_sin"] = np.sin(2 * np.pi * df["Month"] / 12)
df["month_cos"] = np.cos(2 * np.pi * df["Month"] / 12)

# One-hot encode City (first level dropped as the baseline). dtype=float keeps
# the feature matrix purely numeric instead of mixing bool and float columns,
# which would make `.values` return an object array.
df_lr = pd.get_dummies(df, columns=["City"], drop_first=True, dtype=float)

# Every remaining column except the target and the raw calendar fields is used
# as a model input.
feature_cols = [c for c in df_lr.columns if c not in ["Total_Sales","Year","Month"]]

# Chronological 80/20 hold-out inside every city.
# df is ordered by City and then Month_Num, so a plain positional cut
# (iloc[:cutoff]) would put whole cities into the test set rather than the
# most recent months. Instead, the first 80% of each city's rows (oldest
# months) go to training and the remainder to testing, which is the same
# per-city split the SARIMA section uses.
city_groups = df.groupby("City")
row_pos = city_groups.cumcount()                      # 0-based position within the city
city_len = city_groups["City"].transform("size")      # number of rows of that city
# np.floor mirrors int(0.8 * n) truncation for the per-city training size.
is_train = (row_pos < np.floor(0.8 * city_len)).to_numpy()

# df_lr comes from pd.get_dummies(df, ...), which keeps df's row order, so the
# positional boolean mask computed on df applies to df_lr directly.
cutoff = int(is_train.sum())  # number of training rows
train_df = df_lr.loc[is_train]
test_df = df_lr.loc[~is_train]

X_train = train_df[feature_cols].values
y_train = train_df["Total_Sales"].values
X_test = test_df[feature_cols].values
y_test = test_df["Total_Sales"].values

# ======================================
# MLflow experiment setup
# ======================================
# NOTE(review): the experiment path embeds a specific user's workspace folder;
# consider moving it to a configuration constant.
mlflow.set_experiment("/Users/amirrezakha@yahoo.com/Retail_ML_Experiments")

# ======================================
# Linear Regression model
# ======================================
# Fit on the training matrix, score on the hold-out matrix, and record the
# parameters, the RMSE metric and the pickled model in one MLflow run.
with mlflow.start_run(run_name="LinearRegression") as run:
    lr_model = LinearRegression()
    lr_model.fit(X_train, y_train)
    preds_lr = lr_model.predict(X_test)
    # RMSE is computed as sqrt(MSE): the `squared=False` keyword of
    # mean_squared_error is deprecated and removed in recent scikit-learn
    # releases, where passing it raises a TypeError.
    rmse_lr = float(np.sqrt(mean_squared_error(y_test, preds_lr)))

    mlflow.log_param("model_type", "LinearRegression")
    mlflow.log_param("features", feature_cols)
    mlflow.log_metric("rmse", rmse_lr)

    # Serialize the fitted model locally, then attach the file to the run
    # under the "model" artifact folder.
    local_model_path = "/tmp/lr_model.pkl"
    joblib.dump(lr_model, local_model_path)
    mlflow.log_artifact(local_model_path, artifact_path="model")

    print(f"LinearRegression RMSE: {rmse_lr}")

# ======================================
# SARIMA per city
# ======================================
# For each city: fit SARIMAX(1,1,1)x(1,1,1,12) on the first 80% of its months,
# forecast the remaining months, and keep the fitted model, RMSE, forecast,
# hold-out values and hold-out month index in arima_results[city].
# A failure for one city is reported and does not stop the loop.
arima_results = {}
for city in cities:
    city_df = df[df["City"] == city].sort_values("Month_Num")
    split_idx = int(0.8 * len(city_df))  # computed once and reused for all slices
    ts_train = city_df["Total_Sales"].iloc[:split_idx]
    ts_test = city_df["Total_Sales"].iloc[split_idx:]

    # Without at least one training and one hold-out point there is nothing
    # to fit or score; report it explicitly rather than via a library error.
    if ts_train.empty or ts_test.empty:
        print(f"SARIMA failed for {city}: not enough rows to split ({len(city_df)})")
        continue

    try:
        sarima_model = SARIMAX(ts_train, order=(1,1,1), seasonal_order=(1,1,1,12),
                               enforce_stationarity=False, enforce_invertibility=False).fit(disp=False)
        preds_sarima = sarima_model.forecast(len(ts_test))
        # sqrt(MSE) instead of `squared=False`: that keyword is removed in
        # recent scikit-learn, and the resulting TypeError would be swallowed
        # by the except below, making every city look like a SARIMA failure.
        rmse_sarima = float(np.sqrt(mean_squared_error(ts_test, preds_sarima)))
        arima_results[city] = {"model": sarima_model, "rmse": rmse_sarima,
                               "preds": preds_sarima, "test": ts_test,
                               "month_num": city_df["Month_Num"].iloc[split_idx:] }

        # Log each SARIMA city run to MLflow
        # NOTE(review): nested=True only attaches the run to a parent when a
        # run is already active; no enclosing run is opened in this section —
        # confirm whether a parent run was intended.
        with mlflow.start_run(run_name=f"SARIMA_{city}", nested=True):
            mlflow.log_param("model_type", "SARIMA")
            mlflow.log_param("city", city)
            mlflow.log_param("order", (1,1,1))
            mlflow.log_param("seasonal_order", (1,1,1,12))
            mlflow.log_metric("rmse", rmse_sarima)

        print(f"SARIMA {city} RMSE: {rmse_sarima}")
    except Exception as e:
        # Deliberate best-effort: keep going with the remaining cities.
        print(f"SARIMA failed for {city}: {e}")

# ======================================
# Plot actual vs predictions
# ======================================
# One trace per (city, series). The hold-out frame contains rows from several
# cities, so drawing it as a single line would connect points of different
# cities and zig-zag. Colour identifies the city; line style identifies the
# series (solid = actual, dash = LinearRegression, dot = SARIMA).
fig = go.Figure()

colors = ['red','green','orange','purple','brown']
city_color = {city: colors[i % len(colors)] for i, city in enumerate(cities)}

# Re-attach the city label to the hold-out rows: df_lr (and therefore test_df)
# has City one-hot encoded, but it shares its index with df.
# Assumes df's index is unique (true for a freshly converted, sorted frame).
lr_plot_df = pd.DataFrame({
    "City": df.loc[test_df.index, "City"].to_numpy(),
    "Month_Num": test_df["Month_Num"].to_numpy(),
    "Total_Sales": test_df["Total_Sales"].to_numpy(),
    "pred_lr": np.asarray(preds_lr),
})

# Actual values and Linear Regression predictions, per city
for city, city_rows in lr_plot_df.groupby("City", sort=False):
    city_rows = city_rows.sort_values("Month_Num")
    color = city_color.get(city, 'gray')
    fig.add_trace(go.Scatter(
        x=city_rows["Month_Num"], y=city_rows["Total_Sales"],
        mode='lines+markers', name=f"Actual_{city}",
        legendgroup=str(city),
        line=dict(color=color, width=4)
    ))
    fig.add_trace(go.Scatter(
        x=city_rows["Month_Num"], y=city_rows["pred_lr"],
        mode='lines+markers', name=f"LinearRegression_{city}",
        legendgroup=str(city),
        line=dict(color=color, width=3, dash='dash')
    ))

# SARIMA per city
for city, res in arima_results.items():
    fig.add_trace(go.Scatter(
        x=res["month_num"], y=res["preds"],
        mode='lines+markers', name=f"SARIMA_{city}",
        legendgroup=str(city),
        line=dict(color=city_color.get(city, 'gray'), width=2, dash='dot')
    ))

fig.update_layout(
    title="Actual vs Predicted Total Sales",
    xaxis_title="Month_Num",
    yaxis_title="Total_Sales",
    template="plotly_white",
    legend=dict(x=0.02, y=0.98)
)
fig.show()

# ======================================
# Summary RMSE
# ======================================
# Recap: the Linear Regression score first, then one line per city that has a
# stored SARIMA result.
print("\nLinearRegression RMSE:", rmse_lr)
for city in arima_results:
    city_rmse = arima_results[city]["rmse"]
    print(f"SARIMA {city} RMSE: {city_rmse}")
