In [12]:
from flask import Flask, request, jsonify
from datetime import datetime
import pandas as pd
import joblib
import numpy as np
import os

app = Flask(__name__)

# --- 1. MODEL LOADING ---
# MODEL_FILE is a relative path, so it is resolved against the current working
# directory of the process (the directory reported in the error message below),
# which is not necessarily the directory containing this script.
MODEL_FILE = "sarimax_model.pkl"

# Bind the name unconditionally: if the file is missing, `model` stays None
# instead of being left undefined, so later code can test for it rather than
# failing with a NameError on the first request.
model = None

if os.path.exists(MODEL_FILE):
    # NOTE: joblib files are pickle-based; only load model files you trust.
    model = joblib.load(MODEL_FILE)
    print("Model loaded successfully.")
else:
    print(f"ERROR: {MODEL_FILE} not found in {os.getcwd()}")

# --- 2. VALIDATION FUNCTIONS ---

def validate_input(df):
    """Validate the two binary regressor columns of *df*.

    Both ``Discount_Flag`` and ``Holiday`` must exist, contain no nulls,
    and hold only the values 0 or 1.

    Returns:
        ``(True, "Success")`` when every check passes, otherwise
        ``(False, reason)`` for the first check that fails.
    """
    binary_fields = ['Discount_Flag', 'Holiday']

    # Report the first missing column, in the order listed above.
    absent = next((name for name in binary_fields if name not in df.columns), None)
    if absent is not None:
        return False, f"Missing column: {absent}"

    flags = df[binary_fields]

    # Nulls are rejected before the 0/1 check so they get their own message.
    if flags.isnull().values.any():
        return False, "Input contains null values."

    if flags.isin([0, 1]).all().all():
        return True, "Success"

    return False, "Discount_Flag and Holiday must be binary (0 or 1)."

def validate_dates(date_list, *, training_end_date=datetime(2018, 12, 31)):
    """Validate the forecast dates supplied by the client.

    Checks, in this order, that the list is non-empty, that every entry
    parses with the ``%Y-%m-%d`` format, that the dates are strictly
    increasing, and that the first date falls after *training_end_date*.

    Args:
        date_list: Sequence of date strings to validate.
        training_end_date: Last date covered by the model's training data;
            the first forecast date must be later than this. Defaults to
            2018-12-31.

    Returns:
        ``(True, "Success")`` when every check passes, otherwise
        ``(False, reason)`` describing the first failed check.
    """
    # Guard the empty case explicitly; indexing the first element below would
    # otherwise fail with an unhelpful "list index out of range".
    if not date_list:
        return False, "At least one date is required."

    try:
        parsed_dates = [datetime.strptime(d, '%Y-%m-%d') for d in date_list]
    except (ValueError, TypeError):
        # ValueError: a string that does not match the format.
        # TypeError: a non-string entry (number, null, ...), or an input that
        # is not iterable at all.
        return False, "Invalid date format. Please use YYYY-MM-DD."

    # Equal neighbours are rejected too: each date must be later than the last.
    if any(earlier >= later for earlier, later in zip(parsed_dates, parsed_dates[1:])):
        return False, "Dates must be in chronological order."

    if parsed_dates[0] <= training_end_date:
        return False, f"Predictions must start after {training_end_date:%Y-%m-%d}."

    return True, "Success"

# --- 3. ROUTES ---

@app.route('/', methods=['GET'])
def home():
    """Serve a minimal HTML page confirming the API is reachable."""
    heading = "<h1>Sales Forecasting API</h1>"
    status_line = "<p>Status: Online</p>"
    return heading + status_line

@app.route('/predict', methods=['POST'])
def predict():
    """Return sales forecasts for the rows posted as JSON.

    The request body must be JSON that ``pandas.DataFrame`` can turn into a
    table (for example a list of row objects) containing ``Discount_Flag``,
    ``Holiday`` and ``date`` columns.

    Returns:
        A ``(json_response, status_code)`` pair:

        * 200 - ``{"status": "success", "results": [...]}`` with one entry
          per input row (date, prediction, lower/upper interval bound).
        * 400 - missing/unparseable JSON, a payload that cannot be tabulated,
          or a failed column/date validation.
        * 503 - the forecast model has not been loaded.
        * 500 - any unexpected failure while forecasting.
    """
    try:
        # silent=True returns None for a wrong content type or malformed JSON
        # instead of raising, so those client mistakes are answered with the
        # 400 below rather than falling into the generic 500 handler.
        data = request.get_json(silent=True)
        if not data:
            return jsonify({"status": "error", "message": "No JSON data received"}), 400

        # Scalars, dicts of scalars, and ragged columns cannot be tabulated;
        # that is a problem with the request, not with the server.
        try:
            input_df = pd.DataFrame(data)
        except (ValueError, TypeError) as exc:
            return jsonify({
                "status": "invalid_data",
                "message": f"JSON payload could not be read as a table: {exc}",
            }), 400

        # Step A: Data Validation
        is_valid_data, data_msg = validate_input(input_df)
        if not is_valid_data:
            return jsonify({"status": "invalid_data", "message": data_msg}), 400

        # Step B: Date Validation
        if 'date' not in input_df.columns:
            return jsonify({"status": "invalid_date", "message": "Date column required"}), 400
        is_valid_date, date_msg = validate_dates(input_df['date'].tolist())
        if not is_valid_date:
            return jsonify({"status": "invalid_date", "message": date_msg}), 400

        # Step C: Forecast Generation
        # Looked up via globals() so that a model that was never bound (file
        # missing at startup) is reported clearly instead of as a NameError.
        forecaster = globals().get("model")
        if forecaster is None:
            return jsonify({"status": "error", "message": "Forecast model is not loaded"}), 503

        # Enforce column order and float type for statsmodels
        exog_vars = input_df[['Discount_Flag', 'Holiday']].astype(float)
        # NOTE(review): the submitted dates only label the output rows; the
        # model is asked for len(input_df) consecutive steps and never sees
        # the dates themselves. Confirm that callers always send a gap-free
        # range starting right after the training period.
        forecast = forecaster.get_forecast(steps=len(input_df), exog=exog_vars)

        # Get bounds and means
        conf_int = forecast.conf_int()
        predictions_log = forecast.predicted_mean

        # Step D: Back-Transform
        # np.expm1 inverts np.log1p - assumes the model was fitted on
        # log1p-transformed sales (TODO confirm against the training code).
        predictions_real = np.expm1(predictions_log)
        lower_bound = np.expm1(conf_int.iloc[:, 0])
        upper_bound = np.expm1(conf_int.iloc[:, 1])

        # Step E: Build Response
        results = [
            {
                "date": date_label,
                "prediction": round(float(mean), 2),
                "lower_95_ci": round(float(low), 2),
                "upper_95_ci": round(float(high), 2),
            }
            for date_label, mean, low, high in zip(
                input_df['date'], predictions_real, lower_bound, upper_bound
            )
        ]

        return jsonify({"status": "success", "results": results}), 200

    except Exception as e:
        # Last-resort boundary: keep the JSON error contract for the client,
        # but record the traceback server-side so the failure can be diagnosed.
        app.logger.exception("Unhandled error in /predict")
        return jsonify({"status": "error", "message": str(e)}), 500

# --- 4. EXECUTION ---
if __name__ == '__main__':
    # The reloader must stay off when launching from Jupyter/Spyder; leaving
    # it on crashes in those environments.
    run_options = {
        "debug": True,
        "use_reloader": False,
        "port": 5000,
    }
    app.run(**run_options)

Model loaded successfully.
 * Serving Flask app '__main__'
 * Debug mode: on


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [27/Dec/2025 01:06:35] "POST /predict HTTP/1.1" 400 -
127.0.0.1 - - [27/Dec/2025 01:06:35] "POST /predict HTTP/1.1" 200 -
