In [15]:
import pandas as pd
import numpy as np
from statsmodels.tsa.statespace.sarimax import SARIMAX

# Load the CSV data
train_file_path = r'C:\Users\Giridhar\Downloads\forecasting-unit-sales-vit-task-2\train.csv'
test_file_path = r'C:\Users\Giridhar\Downloads\forecasting-unit-sales-vit-task-2\test.csv'

train_data = pd.read_csv(train_file_path)
test_data = pd.read_csv(test_file_path)

# Handle missing values
def handle_missing_values(data, is_train=True):
    data['Item Id'].fillna('Unknown', inplace=True)
    data['Item Name'].fillna('Unknown Item', inplace=True)
    data['ad_spend'].replace([np.inf, -np.inf], np.nan, inplace=True)
    data['ad_spend'].fillna(0, inplace=True)
    if is_train and 'units' in data.columns:
        data['units'].fillna(0, inplace=True)
    return data

train_data = handle_missing_values(train_data, is_train=True)
test_data = handle_missing_values(test_data, is_train=False)

train_data['units'] = train_data['units'].apply(lambda x: max(x, 0))
train_data['unit_price'] = train_data['unit_price'].apply(lambda x: max(x, 0))

# Create time-based features
def create_time_features(data):
    data['date'] = pd.to_datetime(data['date'])
    data['day_of_week'] = data['date'].dt.dayofweek
    data['month'] = data['date'].dt.month
    data['quarter'] = data['date'].dt.quarter
    return data

train_data = create_time_features(train_data)
test_data = create_time_features(test_data)

# Aggregate data by Item Id and date
def aggregate_data(data, is_train=True):
    agg_dict = {
        'ad_spend': 'sum',
        'unit_price': 'mean'
    }
    if is_train:
        agg_dict['units'] = 'sum'
    return data.groupby(['Item Id', 'date']).agg(agg_dict).reset_index()

train_aggregated = aggregate_data(train_data, is_train=True)
test_aggregated = aggregate_data(test_data, is_train=False)

# Create a dictionary to hold DataFrames for each item
predictions = {}

# List of unique item IDs
item_ids = train_aggregated['Item Id'].unique()

for item_id in item_ids:
    train_item = train_aggregated[train_aggregated['Item Id'] == item_id].copy()
    test_item = test_aggregated[test_aggregated['Item Id'] == item_id].copy()
    
    if not train_item.empty and not test_item.empty:
        # Prepare data for SARIMA
        train_item_sarima = train_item[['date', 'units']].set_index('date')
        train_item_sarima = train_item_sarima.asfreq('D')
        
        # Check if there are sufficient data points for model fitting
        if len(train_item_sarima) > 10:  # Adjust threshold as necessary
            try:
                # Initialize and fit the SARIMA model with simpler parameters
                model_sarima = SARIMAX(train_item_sarima,
                                       order=(1, 1, 1),
                                       seasonal_order=(1, 1, 0, 7),
                                       enforce_stationarity=False,
                                       enforce_invertibility=False)
                
                model_sarima_fit = model_sarima.fit(disp=False, maxiter=2000)
                print(f"Model fitted successfully for item {item_id}")
                
                # Create future dataframe and predict
                future_dates = pd.date_range(start=train_item_sarima.index[-1] + pd.Timedelta(days=1), periods=len(test_item), freq='D')
                future = pd.DataFrame(index=future_dates)
                future = future.asfreq('D')
                
                forecast_sarima = model_sarima_fit.get_forecast(steps=len(future_dates))
                forecast_sarima_mean = forecast_sarima.predicted_mean
                
                # Debugging print
                print(f"Forecast for item {item_id}: {forecast_sarima_mean.head()}")
                
                # Add predictions to test_item DataFrame
                test_item['predicted_units_sarima'] = np.ceil(forecast_sarima_mean).astype(int)
                
                # Remove negative values (if any) by setting them to zero
                test_item['predicted_units_sarima'] = test_item['predicted_units_sarima'].clip(lower=0)
                
                # Store the DataFrame in the dictionary
                predictions[item_id] = test_item[['date', 'Item Id', 'predicted_units_sarima']]
            
            except Exception as e:
                print(f"Failed to fit SARIMA model for item {item_id}: {e}")
                test_item['predicted_units_sarima'] = np.nan
        else:
            print(f"Not enough data to fit model for item {item_id}")

# Combine all predictions into a single DataFrame
if predictions:
    combined_predictions = pd.concat(predictions.values(), ignore_index=True)
    combined_predictions.to_csv('task1_predictions_sarima.csv', index=False)
    print('task1_predictions_sarima.csv')
else:
    print('No predictions were generated.')


Model fitted successfully for item B09KDLQ2GW
Forecast for item B09KDLQ2GW: 2024-06-01    1.441972
2024-06-02    0.994169
2024-06-03    0.012641
2024-06-04    5.998181
2024-06-05    1.012848
Freq: D, dtype: float64
Model fitted successfully for item B09KDN7PYR
Forecast for item B09KDN7PYR: 2024-06-01    0.679592
2024-06-02    0.322797
2024-06-03    0.441242
2024-06-04    1.005589
2024-06-05    0.502772
Freq: D, dtype: float64
Model fitted successfully for item B09KDNYCYR
Forecast for item B09KDNYCYR: 2024-06-01   -0.521910
2024-06-02   -0.982155
2024-06-03   -0.267285
2024-06-04   -0.103103
2024-06-05   -0.036389
Freq: D, dtype: float64
Model fitted successfully for item B09KDPXYG3
Forecast for item B09KDPXYG3: 2024-05-31   -0.292699
2024-06-01   -0.161846
2024-06-02    0.379893
2024-06-03   -0.274312
2024-06-04   -0.028256
Freq: D, dtype: float64
Model fitted successfully for item B09KDQ2BWY
Forecast for item B09KDQ2BWY: 2024-05-31   -0.454569
2024-06-01   -0.360520
2024-06-02   -0.13

In [17]:
combined_predictions

Unnamed: 0,date,Item Id,predicted_units_sarima
0,2024-07-01,B09KDLQ2GW,
1,2024-07-02,B09KDLQ2GW,
2,2024-07-03,B09KDLQ2GW,
3,2024-07-04,B09KDLQ2GW,
4,2024-07-05,B09KDLQ2GW,
...,...,...,...
2828,2024-07-24,B0CY5QQ49F,
2829,2024-07-25,B0CY5QQ49F,
2830,2024-07-26,B0CY5QQ49F,
2831,2024-07-27,B0CY5QQ49F,
