In [1]:
import pandas as pd


from utils.load_data import load_data
from utils.data_splitter import DataSplitter

# Lift pandas' display truncation so wide/long frames render in full.
pd.options.display.max_columns = None
pd.options.display.max_rows = None

# Price records (via the project loader) and the feature-group table.
df = load_data('data/commodity_prices.csv')
df_grouped = pd.read_csv('data/important_feature_groups.csv')

In [2]:
# Start/end dates ("YYYY-MM-DD") of the train, validation and test windows.
date_dict = dict(
    train_start="2023-06-01", train_end="2025-06-30",
    valid_start="2025-07-01", valid_end="2025-07-31",
    test_start="2025-08-01", test_end="2025-08-18",
)

# Per-split row-count thresholds handed to DataSplitter.
thresholds = dict(train=100, valid=10, test=5)

# Run the splitter and unpack the three resulting frames.
splitter = DataSplitter(df, date_dict, thresholds)
train_df, valid_df, test_df = splitter.run()

In [3]:
# (rows, columns) of the train, validation and test frames, in that order.
tuple(split_df.shape for split_df in (train_df, valid_df, test_df))

((140521, 9), (6639, 9), (3677, 9))

In [23]:
# Preview the first rows of df.
df.head()

Unnamed: 0,Product_Type,Commodity,Variety_Type,Arrival_Date,Market,Is_VFPCK,Season,Year,Modal_Price,log_Modal_Price,Max_Price,Min_Price
0,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-13,North Paravur,False,Winter,2023,5200.0,8.556606,6000.0,5000.0
1,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-14,North Paravur,False,Winter,2023,6200.0,8.732466,6500.0,6000.0
2,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-16,North Paravur,False,Winter,2023,4800.0,8.47658,5600.0,4600.0
3,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-18,North Paravur,False,Winter,2023,3500.0,8.160804,4500.0,3000.0
4,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-19,North Paravur,False,Winter,2023,5500.0,8.612685,6000.0,5500.0


In [24]:
# One row per (Product_Type, Market) pair; the first shape entry is the number of groups.
rows_per_group = df.groupby(['Product_Type', 'Market']).size()
rows_per_group.reset_index().shape

(402, 3)

## Filtering Sparse Groups

We counted the observations in each `Product_Type`–`Market` group (the grouping keys used in the code below) within the train, validation, and test periods.  

- Groups with **insufficient data** in *any* period (train ≤ 100, validation ≤ 10, or test ≤ 5 observations) were removed.  
- This ensures that classical time series models (SARIMA/Prophet) and ML models have enough historical points to learn patterns and be evaluated reliably.  

**Result:** 109 of the 402 groups were excluded as sparse; the remaining 293 groups have sufficient data for modeling.

In [25]:
# Calendar windows for each split as ISO "YYYY-MM-DD" strings; both ends are inclusive.
train_start, train_end = "2023-06-01", "2025-06-30"
valid_start, valid_end = "2025-07-01", "2025-07-31"
test_start, test_end = "2025-08-01", "2025-08-18"

def assign_split(date):
    """Return the split label for ``date``.

    Parameters
    ----------
    date : datetime-like
        Any object with a ``strftime`` method (e.g. ``pandas.Timestamp``).

    Returns
    -------
    str or None
        ``"train"``, ``"valid"`` or ``"test"`` when the date lies inside the
        matching window, otherwise ``None``. Missing dates (``NaT``/``None``)
        also map to ``None``.
    """
    if pd.isna(date):
        return None
    # ISO strings order like dates, and formatting discards any time-of-day,
    # so a timestamp late on the last day still counts as inside the window.
    day = date.strftime("%Y-%m-%d")
    if train_start <= day <= train_end:
        return "train"
    if valid_start <= day <= valid_end:
        return "valid"
    if test_start <= day <= test_end:
        return "test"
    # Outside every window: do not silently file the row under "test".
    return None

# Label every row with the split its arrival date belongs to.
df['Split'] = df['Arrival_Date'].apply(assign_split)

# Rows per (Product_Type, Market) in each split, one column per split.
# The reindex guarantees all three count columns exist even when a split has
# no rows at all, so the threshold checks below cannot raise KeyError.
group_counts = (
    df.groupby(["Product_Type", "Market", "Split"])
      .size()
      .unstack(fill_value=0)
      .reindex(columns=['test', 'train', 'valid'], fill_value=0)
      .reset_index()
)

# Keep only the groups that exceed the minimum row count in every split.
group_counts = group_counts[
    (group_counts['train'] > 100) &
    (group_counts['valid'] > 10) &
    (group_counts['test'] > 5)
]

# Attach the per-split counts and discard rows of filtered-out groups.
# Count columns left over from an earlier run of this cell are dropped first;
# otherwise the merge would emit train_x/train_y style duplicates on re-run.
df = (
    df.drop(columns=['test', 'train', 'valid'], errors='ignore')
      .merge(group_counts, on=['Product_Type', 'Market'], how='right')
)

In [21]:
# Preview the per-group row counts for each split.
group_counts.head()

Split,Product_Type,Market,test,train,valid
0,Alsandikai|Alsandikai|FAQ,North Paravur,13,435,23
1,Amaranthus|Amaranthus|FAQ,Aluva,11,551,25
2,Amaranthus|Amaranthus|FAQ,Angamaly,15,533,22
3,Amaranthus|Amaranthus|FAQ,Broadway market,13,363,18
4,Amaranthus|Amaranthus|FAQ,Ernakulam,10,485,23


In [26]:
# Preview the first rows of df after the group-count merge above.
df.head()

Unnamed: 0,Product_Type,Commodity,Variety_Type,Arrival_Date,Market,Is_VFPCK,Season,Year,Modal_Price,log_Modal_Price,Max_Price,Min_Price,Split,test,train,valid
0,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-13,North Paravur,False,Winter,2023,5200.0,8.556606,6000.0,5000.0,train,13,435,23
1,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-14,North Paravur,False,Winter,2023,6200.0,8.732466,6500.0,6000.0,train,13,435,23
2,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-16,North Paravur,False,Winter,2023,4800.0,8.47658,5600.0,4600.0,train,13,435,23
3,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-18,North Paravur,False,Winter,2023,3500.0,8.160804,4500.0,3000.0,train,13,435,23
4,Alsandikai|Alsandikai|FAQ,Alsandikai,Alsandikai|Alsandikai,2023-12-19,North Paravur,False,Winter,2023,5500.0,8.612685,6000.0,5500.0,train,13,435,23


In [None]:
# Calendar windows for each split as ISO "YYYY-MM-DD" strings; both ends are inclusive.
train_start, train_end = "2023-06-01", "2025-06-30"
valid_start, valid_end = "2025-07-01", "2025-07-31"
test_start, test_end = "2025-08-01", "2025-08-18"

def assign_split(date):
    """Return the split label for ``date``.

    Parameters
    ----------
    date : datetime-like
        Any object with a ``strftime`` method (e.g. ``pandas.Timestamp``).

    Returns
    -------
    str or None
        ``"train"``, ``"valid"`` or ``"test"`` when the date lies inside the
        matching window, otherwise ``None``. Missing dates (``NaT``/``None``)
        also map to ``None``.
    """
    if pd.isna(date):
        return None
    # ISO strings order like dates, and formatting discards any time-of-day,
    # so a timestamp late on the last day still counts as inside the window.
    day = date.strftime("%Y-%m-%d")
    if train_start <= day <= train_end:
        return "train"
    if valid_start <= day <= valid_end:
        return "valid"
    if test_start <= day <= test_end:
        return "test"
    # Outside every window: do not silently file the row under "test".
    return None

# Label every row with the split its arrival date belongs to.
df['Split'] = df['Arrival_Date'].apply(assign_split)

# Rows per (Product_Type, Market) in each split, one column per split.
# The reindex guarantees all three count columns exist even when a split has
# no rows at all, so the threshold checks below cannot raise KeyError.
group_counts = (
    df.groupby(["Product_Type", "Market", "Split"])
      .size()
      .unstack(fill_value=0)
      .reindex(columns=['test', 'train', 'valid'], fill_value=0)
      .reset_index()
)

# Keep only the groups that exceed the minimum row count in every split.
group_counts = group_counts[
    (group_counts['train'] > 100) &
    (group_counts['valid'] > 10) &
    (group_counts['test'] > 5)
]

# Attach the per-split counts and discard rows of filtered-out groups.
# Count columns left over from an earlier merge are dropped first; otherwise
# this merge would emit train_x/train_y style duplicates.
df = (
    df.drop(columns=['test', 'train', 'valid'], errors='ignore')
      .merge(group_counts, on=['Product_Type', 'Market'], how='right')
)

# Slice the filtered frame into the three splits by label.
train_df = df[df['Split'] == 'train']
valid_df = df[df['Split'] == 'valid']
test_df = df[df['Split'] == 'test']



def drop_columns(df):
    """Return ``df`` without the split bookkeeping and raw price columns.

    Removes ``Split``, the per-split count columns (``train``, ``valid``,
    ``test``) and ``Modal_Price``/``Max_Price``/``Min_Price``. The input frame
    is left untouched; a ``KeyError`` is raised if any of these columns is
    absent.
    """
    unwanted = ('Split', 'train', 'valid', 'test', 'Modal_Price', 'Max_Price', 'Min_Price')
    return df.drop(columns=list(unwanted))


In [28]:
class DataSplitter:
    """Split a price DataFrame into train/valid/test frames by arrival date.

    ``run`` labels each row with a split based on ``Arrival_Date``, discards
    every ``Product_Type``/``Market`` group that has too few rows in any split,
    and returns the three splits with helper and raw price columns removed.

    Parameters
    ----------
    df : pandas.DataFrame
        Needs ``Arrival_Date``, ``Product_Type`` and ``Market`` columns.
        ``Arrival_Date`` values must offer ``strftime``. The frame passed in
        is never modified; the splitter works on its own derived frames.
    date_dict : dict
        ``"YYYY-MM-DD"`` strings under ``train_start``, ``train_end``,
        ``valid_start``, ``valid_end`` and, optionally, ``test_start`` /
        ``test_end``. All bounds are inclusive.
    thresholds : dict
        Row counts under ``train``, ``valid`` and ``test``. A group is kept
        only when its count in every split is strictly greater than the
        corresponding threshold.
    """

    # Per-split count columns produced by filter_groups (alphabetical, which
    # is the order ``unstack`` yields).
    _COUNT_COLUMNS = ['test', 'train', 'valid']
    _GROUP_KEYS = ['Product_Type', 'Market']

    def __init__(self, df, date_dict: dict, thresholds: dict):
        self.df = df
        self.date_dict = date_dict
        self.thresholds = thresholds

    def assign_split(self, date):
        """Return ``"train"``, ``"valid"`` or ``"test"`` for ``date``, else ``None``.

        ``None`` is returned for missing dates and for dates outside every
        configured window. A test bound that is absent from ``date_dict`` is
        treated as open-ended, so configurations without ``test_start`` /
        ``test_end`` keep labelling all remaining dates as ``"test"``.
        """
        if pd.isna(date):
            return None
        # ISO strings order like dates; formatting also discards time-of-day.
        day = date.strftime("%Y-%m-%d")
        bounds = self.date_dict
        if bounds['train_start'] <= day <= bounds['train_end']:
            return "train"
        if bounds['valid_start'] <= day <= bounds['valid_end']:
            return "valid"
        test_start = bounds.get('test_start')
        test_end = bounds.get('test_end')
        on_or_after_start = test_start is None or test_start <= day
        on_or_before_end = test_end is None or day <= test_end
        if on_or_after_start and on_or_before_end:
            return "test"
        return None

    def assign_splits(self):
        """Store a copy of the frame with a ``Split`` label column added."""
        labels = self.df['Arrival_Date'].apply(self.assign_split)
        # ``assign`` builds a new frame, so the DataFrame the caller passed to
        # the constructor does not grow a 'Split' column behind their back.
        self.df = self.df.assign(Split=labels)

    def filter_groups(self):
        """Drop groups with too few rows in any split and attach the counts.

        Afterwards ``self.df`` holds only rows of surviving groups plus one
        ``test``/``train``/``valid`` count column each. Rows without a split
        label are ignored when counting.
        """
        group_counts = (
            self.df.groupby(self._GROUP_KEYS + ["Split"])
            .size()
            .unstack(fill_value=0)
            # Guarantee every count column exists even if a split is empty.
            .reindex(columns=self._COUNT_COLUMNS, fill_value=0)
            .reset_index()
        )
        enough_rows = (
            (group_counts['train'] > self.thresholds['train'])
            & (group_counts['valid'] > self.thresholds['valid'])
            & (group_counts['test'] > self.thresholds['test'])
        )
        group_counts = group_counts[enough_rows]
        # Remove counts from a previous call first; merging onto them would
        # create suffixed duplicates (train_x/train_y) and break drop_columns.
        base = self.df.drop(columns=self._COUNT_COLUMNS, errors='ignore')
        self.df = base.merge(group_counts, on=self._GROUP_KEYS, how='right')

    def split_datasets(self):
        """Return the ``(train, valid, test)`` row subsets of ``self.df``."""
        labels = self.df['Split']
        return (
            self.df[labels == 'train'],
            self.df[labels == 'valid'],
            self.df[labels == 'test'],
        )

    def drop_columns(self, dataframe):
        """Return ``dataframe`` without helper and raw price columns.

        Removes ``Split``, the per-split count columns and ``Modal_Price`` /
        ``Max_Price`` / ``Min_Price``. Columns that are already absent are
        skipped instead of raising.
        """
        columns_to_drop = ['Split', 'train', 'valid', 'test', 'Modal_Price', 'Max_Price', 'Min_Price']
        return dataframe.drop(columns=columns_to_drop, errors='ignore')

    def run(self):
        """Full pipeline: assign splits → filter groups → return cleaned splits."""
        self.assign_splits()
        self.filter_groups()
        train_df, valid_df, test_df = self.split_datasets()
        return (
            self.drop_columns(train_df),
            self.drop_columns(valid_df),
            self.drop_columns(test_df),
        )

In [None]:
# (rows, columns) of the train, validation and test frames, in that order.
tuple(split_df.shape for split_df in (train_df, valid_df, test_df))

((140521, 9), (6639, 9), (3677, 9))

In [6]:
# (rows, columns) of the train, validation and test frames, in that order.
tuple(split_df.shape for split_df in (train_df, valid_df, test_df))

((140521, 16), (6639, 16), (3677, 16))

In [8]:
# Shape of the per-(Product_Type, Market) size table for each split; the first
# entry of each shape is the number of groups present in that split.
tuple(
    split_df.groupby(['Product_Type', 'Market']).size().reset_index().shape
    for split_df in (train_df, valid_df, test_df)
)



((293, 3), (293, 3), (293, 3))

In [9]:
# List the columns currently present in the training frame.
train_df.columns

Index(['Product_Type', 'Commodity', 'Variety_Type', 'Arrival_Date', 'Market',
       'Is_VFPCK', 'Season', 'Year', 'Modal_Price', 'log_Modal_Price',
       'Max_Price', 'Min_Price', 'Split', 'test', 'train', 'valid'],
      dtype='object')

In [42]:
# Keep only the groups whose row count exceeds the minimum in every split.
has_enough_rows = (
    (group_counts['train'] > 100)
    & (group_counts['valid'] > 10)
    & (group_counts['test'] > 5)
)
group_counts = group_counts[has_enough_rows]

In [43]:
# (number of groups, number of columns) in group_counts.
group_counts.shape

(293, 5)

In [38]:
# Groups with at most 100 training rows.
group_counts.loc[group_counts['train'].le(100)].shape

(20, 5)

In [39]:
# Groups with at most 10 validation rows.
group_counts.loc[group_counts['valid'].le(10)].shape

(46, 5)

In [41]:
# Groups with at most 5 test rows.
group_counts.loc[group_counts['test'].le(5)].shape

(43, 5)

In [5]:
# Preview the first rows of the feature-group table read from CSV.
df_grouped.head()

Unnamed: 0,important_features,Product_Type,Market,Mean_Commodity_Effect_Size,Mean_Variety_Type_Effect_Size,Mean_Season_Effect_Size,Mean_Market_Effect_Size,Mean_Year_Effect_Size,Total_Records
0,Commodity,Banana|Nendra Bale|Medium,THURAVOOR VFPCK,0.650715,0.00014,0.000149,0.0,0.0,65
1,Commodity,Capsicum|Other|FAQ,Piravam,0.057064,0.0,0.03146,0.0,0.008407,502
2,Commodity|Market,Amaranthus|Amaranthus|FAQ,Aluva,0.152055,0.0,0.012549,0.745668,0.014132,587
3,Commodity|Market,Amaranthus|Amaranthus|FAQ,Angamaly,0.152055,0.0,0.012549,0.745668,0.014132,570
4,Commodity|Market,Amaranthus|Amaranthus|FAQ,Broadway market,0.152055,0.0,0.012549,0.745668,0.014132,394
