In [1]:
# Dependencies for loading the data, fitting SARIMAX models and scoring them.
import pandas as pd
import numpy as np
from pathlib import Path
from statsmodels.tsa.statespace.sarimax import SARIMAX
from sklearn.metrics import mean_absolute_error, mean_absolute_percentage_error
import warnings
# NOTE(review): this suppresses every warning for the whole session, including
# any raised while fitting models — consider narrowing the filter so fitting
# problems stay visible.
warnings.filterwarnings('ignore')

# Do not truncate columns when a DataFrame is displayed.
pd.set_option('display.max_columns', None)

## Load data
Expected columns: `date`, `fruit`, `demand`, `price`. Adjust `DATA_PATH` to your dataset.

In [3]:
# Location of the input CSV; edit this to point at your own copy of the dataset.
DATA_PATH = Path('C:/Users/ashan/Downloads/SriLanka_Fruit_Market_Data_2022_2024.csv')  # update if needed
# Columns every later cell relies on.
REQUIRED_COLUMNS = ('date', 'fruit', 'demand', 'price')

# Fail fast with a readable message instead of a pandas traceback.
assert DATA_PATH.exists(), f'Missing data file at {DATA_PATH}'

# Read only the header first so a schema problem is reported by column name
# before any date parsing is attempted.
header_columns = pd.read_csv(DATA_PATH, nrows=0).columns
missing_columns = [col for col in REQUIRED_COLUMNS if col not in header_columns]
assert not missing_columns, f'{DATA_PATH.name} is missing expected columns: {missing_columns}'

df = pd.read_csv(DATA_PATH, parse_dates=['date'])
# Chronological order within each fruit is required by the per-fruit series below.
df = df.sort_values(['fruit', 'date']).reset_index(drop=True)
# Drop missing labels before sorting: sorted() cannot compare NaN with strings.
fruits = sorted(df['fruit'].dropna().unique())
print(f'Loaded {len(df):,} rows for fruits: {fruits}')

AssertionError: Missing data file at C:\Users\ashan\Downloads\SriLanka_Fruit_Market_Data_2022_2024.csv

## Helpers: train/test split and SARIMA trainer

In [None]:
def train_test_split_series(series: pd.Series, test_size: int = 30):
    """Split a series chronologically into a training part and a trailing hold-out.

    Missing values are dropped first, then the last ``test_size`` observations
    become the test set and everything before them the training set.

    Parameters
    ----------
    series : pd.Series
        Observations in the order they should be split (oldest first).
    test_size : int, default 30
        Number of trailing observations to hold out. ``0`` means no hold-out.

    Returns
    -------
    (train, test) : tuple of pd.Series
        If the cleaned series has ``test_size`` observations or fewer, or
        ``test_size`` is 0, ``train`` is the whole cleaned series and ``test``
        is an empty float series.

    Raises
    ------
    ValueError
        If ``test_size`` is negative.
    """
    if test_size < 0:
        raise ValueError(f'test_size must be non-negative, got {test_size}')

    series = series.dropna()
    # test_size == 0 needs its own guard: iloc[:-0] is iloc[:0] (empty) and
    # iloc[-0:] is the whole series, which would swap train and test.
    if test_size == 0 or len(series) <= test_size:
        return series, pd.Series(dtype=float)
    return series.iloc[:-test_size], series.iloc[-test_size:]

def fit_sarima(train: pd.Series, order=(1,1,1), seasonal_order=(1,1,1,7)):
    """Fit a SARIMAX model on ``train`` and return the fitted results object.

    ``order`` and ``seasonal_order`` are passed straight through to ``SARIMAX``.
    Stationarity and invertibility enforcement are both switched off, and the
    optimiser's console output is suppressed.
    """
    sarimax_options = {
        'order': order,
        'seasonal_order': seasonal_order,
        'enforce_stationarity': False,
        'enforce_invertibility': False,
    }
    return SARIMAX(train, **sarimax_options).fit(disp=False)

def evaluate_model(res, test: pd.Series):
    """Score a fitted model against a hold-out series.

    Forecasts ``len(test)`` steps from ``res`` and compares them with ``test``.
    Returns a dict with keys ``'mae'`` and ``'mape'``; both are NaN when
    ``test`` is empty, in which case ``res`` is not used at all.
    """
    scores = {'mae': np.nan, 'mape': np.nan}
    if not test.empty:
        predicted = res.forecast(steps=len(test))
        scores['mae'] = mean_absolute_error(test, predicted)
        scores['mape'] = mean_absolute_percentage_error(test, predicted)
    return scores

## Train SARIMA per fruit for demand and price

In [None]:
results = []            # one metrics row per fruit
models = {}             # fruit -> {'demand': fitted results, 'price': fitted results}
forecast_horizon = 7    # days to forecast in the next section
targets = ('demand', 'price')

for fruit in fruits:
    subset = df[df['fruit'] == fruit].set_index('date')
    metrics_row = {'fruit': fruit}
    fruit_models = {}

    for target in targets:
        series = subset[target]
        train, test = train_test_split_series(series)

        # Model fitted on the training part only: used purely for hold-out scoring.
        holdout_res = fit_sarima(train)
        metrics = evaluate_model(holdout_res, test)
        metrics_row[f'{target}_mae'] = metrics['mae']
        metrics_row[f'{target}_mape'] = metrics['mape']

        # forecast() continues from the end of the data a model was fitted on.
        # Keeping the train-only model would make the later "future" forecast
        # cover the start of the hold-out window instead of the days after the
        # last observation, so refit on the full history before storing.
        # When there is no hold-out, train already is the full cleaned series.
        if test.empty:
            fruit_models[target] = holdout_res
        else:
            fruit_models[target] = fit_sarima(series.dropna())

    models[fruit] = fruit_models
    results.append(metrics_row)

pd.DataFrame(results)

## 7-day forecasts

In [None]:
# Calendar days immediately following the last date present in the data.
first_future_day = df['date'].max() + pd.Timedelta(days=1)
future_dates = pd.date_range(first_future_day, periods=forecast_horizon)

# One row per (fruit, future day); predictions are floored at zero.
forecast_rows = [
    {
        'date': day.date(),
        'fruit': fruit,
        'demand_pred': max(0, demand_hat),
        'price_pred': max(0, price_hat),
    }
    for fruit in fruits
    for day, demand_hat, price_hat in zip(
        future_dates,
        models[fruit]['demand'].forecast(steps=forecast_horizon),
        models[fruit]['price'].forecast(steps=forecast_horizon),
    )
]

forecast_df = pd.DataFrame(forecast_rows)
forecast_df

## Quick visualization

In [None]:
import matplotlib.pyplot as plt

# One row of panels per fruit: demand on the left, price on the right.
# squeeze=False keeps `axes` two-dimensional even when there is a single fruit.
n_fruits = len(fruits)
fig, axes = plt.subplots(n_fruits, 2, figsize=(12, 4 * n_fruits), sharex=False, squeeze=False)

# (history column, forecast column, history line style, forecast line style)
panels = [
    ('demand', 'demand_pred', {}, {}),
    ('price', 'price_pred', {'color': 'tab:orange'}, {'color': 'tab:red'}),
]

for row_axes, fruit in zip(axes, fruits):
    history = df[df['fruit'] == fruit]
    predicted = forecast_df[forecast_df['fruit'] == fruit]
    for ax, (column, pred_column, history_style, forecast_style) in zip(row_axes, panels):
        ax.plot(history['date'], history[column], label='history', **history_style)
        ax.plot(future_dates, predicted[pred_column], label='forecast', **forecast_style)
        ax.set_title(f'{fruit} {column}')
        ax.legend()
        ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()