In [11]:
!pip install prophet

Defaulting to user installation because normal site-packages is not writeable


In [12]:
import os
import pandas as pd
from prophet import Prophet
import matplotlib.pyplot as plt
from datetime import date

# Make sure the output directory for saved figures exists.
# exist_ok=True keeps the cell safe to re-run and avoids the check-then-create
# race of a separate os.path.exists() test.
os.makedirs('forecast_plots', exist_ok=True)

In [13]:
# --- 1. Read the raw sales records and parse 'Date' into datetimes ---
# No aggregation is performed in this cell; the frame keeps one row per record.
sales_df = pd.read_csv('sales_data.csv').assign(
    Date=lambda frame: pd.to_datetime(frame['Date'])
)

# --- 2. Start with an empty frame that will hold the combined forecasts ---
all_forecasts = pd.DataFrame()

# --- 3. Collect the distinct SKU values that will be forecast one by one ---
skus_to_forecast = sales_df['SKU'].unique()
print(f"SKUs to forecast: {skus_to_forecast}")

SKUs to forecast: ['SKU-001' 'SKU-002' 'SKU-003' 'SKU-004']


In [14]:
# Fit one Prophet model per SKU, save its plots, and export every forecast
# (joined with the observed history) to a single CSV.

# Tunable parameters for this cell.
MIN_HISTORY_DAYS = 10        # SKUs with fewer distinct sales dates are skipped
FORECAST_HORIZON_DAYS = 180  # number of future daily periods to predict

# Per-SKU result tables are gathered here and concatenated once after the loop.
# Building the list locally (instead of appending to a frame created in another
# cell) keeps this cell idempotent: re-running it no longer duplicates rows or
# fails on an `all_forecasts['ds']` column that was already converted to dates.
sku_forecast_frames = []

# Loop through each SKU
for sku in skus_to_forecast:
    print(f"\nForecasting for SKU: {sku}")

    # Total quantity per date for this SKU only.
    # NOTE(review): dates with no rows for the SKU are absent here rather than
    # zero, so the model sees them as missing observations - confirm intended.
    sku_sales_df = sales_df[sales_df['SKU'] == sku].copy()
    sku_sales_daily = sku_sales_df.groupby('Date')['Quantity'].sum().reset_index()

    if len(sku_sales_daily) < MIN_HISTORY_DAYS:
        print(f"Skipping {sku} due to insufficient data.")
        continue

    # Prophet expects the time column to be named 'ds' and the target 'y'.
    prophet_df = sku_sales_daily.rename(columns={'Date': 'ds', 'Quantity': 'y'})

    # Instantiate and fit the model.
    # NOTE(review): daily_seasonality=True fits a within-day cycle; if 'Date'
    # carries no time-of-day information this component is likely redundant -
    # confirm against the data before relying on it.
    m = Prophet(yearly_seasonality=True, daily_seasonality=True)
    m.fit(prophet_df)

    # Extend the history by the forecast horizon and predict over all of it
    # (the prediction therefore covers both historical and future dates).
    future = m.make_future_dataframe(periods=FORECAST_HORIZON_DAYS, freq='D')
    forecast = m.predict(future)

    # --- ENHANCED VISUALIZATION ---
    # Main forecast plot drawn onto an explicit Axes so it can be labelled.
    fig, ax = plt.subplots(figsize=(10, 6))
    m.plot(forecast, ax=ax)
    ax.set_title(f"Sales Forecast for {sku}", fontsize=16)
    ax.set_xlabel("Date")
    ax.set_ylabel("Sales Quantity")
    ax.grid(True)

    # Save the main plot, then close the figure so figures do not pile up in
    # memory across loop iterations.
    plot_filename = f"forecast_plots/sales_forecast_{sku}.png"
    fig.savefig(plot_filename)
    plt.close(fig)
    print(f"Main forecast plot saved for {sku}.")

    # Plot and save the forecast components (trend, seasonality).
    fig_components = m.plot_components(forecast)
    components_filename = f"forecast_plots/components_forecast_{sku}.png"
    fig_components.savefig(components_filename)
    plt.close(fig_components)
    print(f"Components plot saved for {sku}.")

    # --- STORE THE ENHANCED FORECAST DATA ---
    # Keep the point forecast and its uncertainty bounds, tagged with the SKU.
    forecast_df = forecast[['ds', 'yhat', 'yhat_lower', 'yhat_upper']].copy()
    forecast_df['SKU'] = sku

    # The left join below leaves 'y' empty (NaN) on future dates, so the column
    # ends up float either way; the cast just makes the dtype explicit.
    prophet_df['y'] = prophet_df['y'].astype(float)

    # One table per SKU: every forecast row, plus observed 'y' where it exists.
    full_data = pd.merge(forecast_df, prophet_df, on='ds', how='left')
    sku_forecast_frames.append(full_data)

# --- FINAL STEP: EXPORT THE COMPREHENSIVE FORECAST DATA ---
if sku_forecast_frames:
    # Single concat instead of growing a DataFrame inside the loop.
    all_forecasts = pd.concat(sku_forecast_frames, ignore_index=True)
    all_forecasts['ds'] = all_forecasts['ds'].dt.date # Convert datetime to date for clean CSV
    all_forecasts.to_csv('all_forecasts.csv', index=False)
    print("\nComprehensive forecast data exported to 'all_forecasts.csv'")
else:
    # Every SKU was skipped: there is no 'ds' column to convert and nothing to
    # export, so leave an empty frame instead of failing with a KeyError.
    all_forecasts = pd.DataFrame()
    print("\nNo SKU had enough data to forecast; 'all_forecasts.csv' was not written.")


Forecasting for SKU: SKU-001


18:29:06 - cmdstanpy - INFO - Chain [1] start processing
18:29:07 - cmdstanpy - INFO - Chain [1] done processing


Main forecast plot saved for SKU-001.
Components plot saved for SKU-001.

Forecasting for SKU: SKU-002


18:29:09 - cmdstanpy - INFO - Chain [1] start processing
18:29:09 - cmdstanpy - INFO - Chain [1] done processing


Main forecast plot saved for SKU-002.
Components plot saved for SKU-002.

Forecasting for SKU: SKU-003


18:29:10 - cmdstanpy - INFO - Chain [1] start processing
18:29:11 - cmdstanpy - INFO - Chain [1] done processing


Main forecast plot saved for SKU-003.
Components plot saved for SKU-003.

Forecasting for SKU: SKU-004
Main forecast plot saved for SKU-004.
Components plot saved for SKU-004.

Comprehensive forecast data exported to 'all_forecasts.csv'
