In [1]:
import os
os.chdir('/workspaces/b2b-wf-experiments/stacks/short_term_forecast')

In [2]:
from dataclasses import dataclass
from pathlib import Path
from typing import List

@dataclass(frozen=True)
class DatasetStatisticsConfig:
    root_dir: Path
    train_dataset: Path
    test_dataset: Path
    time_column: str
    target_column: str
    attribute_columns: List[str]
    output_statistics: Path

In [3]:
from src.ShortTermForecast.constants import *
from src.ShortTermForecast.utils.common import read_yaml, create_directories

In [4]:
class ConfigurationManager:
    def __init__(
        self,
        config_filepath = CONFIG_FILE_PATH,
        params_filepath = PARAMS_FILE_PATH
    ):
        if config_filepath is not None:
            self.config = read_yaml(config_filepath)
            create_directories([self.config.artifacts_root])
        
        if params_filepath is not None:
            self.params = read_yaml(params_filepath)

    def get_dataset_statistics_dvc_config(self) -> DatasetStatisticsConfig:
        general_configs = self.config.general_setup
        cross_validation_config = self.config.cross_validation
        dataset_statistics_config = self.config.dataset_statistics

        create_directories([dataset_statistics_config.root_dir])

        return DatasetStatisticsConfig(
            root_dir=Path(dataset_statistics_config.root_dir),
            train_dataset=Path(cross_validation_config.root_dir, cross_validation_config.train_file_name),
            test_dataset=Path(cross_validation_config.root_dir, cross_validation_config.test_file_name),
            time_column=general_configs.time_column,
            target_column=general_configs.target_column,
            attribute_columns=general_configs.attribute_columns,
            output_statistics=Path(dataset_statistics_config.root_dir, dataset_statistics_config.output_file_name)
        )

    def get_dataset_statistics_kfp_config(
        self,
        train_dataset: str,
        test_dataset: str,
        time_column: str,
        target_column: str,
        attribute_columns: List[str],
        output_statistics: str
    ) -> DatasetStatisticsConfig:
        return DatasetStatisticsConfig(
            root_dir=None,
            train_dataset=Path(train_dataset),
            test_dataset=Path(test_dataset),
            time_column=time_column,
            target_column=target_column,
            attribute_columns=attribute_columns,
            output_statistics=Path(output_statistics)
        )

In [7]:
from typing import Dict
import pandas as pd
import json
from src.ShortTermForecast import logger

class DatasetStatistics:
    def __init__(self, config: DatasetStatisticsConfig):
        self.config = config
        self.stats = None

    def format_date(self, date):
        """Safely format a date, handling NaT values"""
        return date.strftime("%Y-%m-%d") if pd.notna(date) else None

    def calculate_statistics(self, df: pd.DataFrame, dataset_type: str, split_number: int) -> Dict[str, any]:
        df[self.config.time_column] = pd.to_datetime(df[self.config.time_column], errors='coerce')
            
        date_range_start = df[self.config.time_column].min()
        date_range_end = df[self.config.time_column].max()
        
        stats = {
            "dataset_type": dataset_type,
            "split_number": split_number,
            "total_rows": len(df),
            "date_range": {
                "start": self.format_date(date_range_start),
                "end": self.format_date(date_range_end)
            },
            "target_column": {
                "mean": float(df[self.config.target_column].mean()),
                "median": float(df[self.config.target_column].median()),
                "min": float(df[self.config.target_column].min()),
                "max": float(df[self.config.target_column].max()),
                "std": float(df[self.config.target_column].std()),
                "total": float(df[self.config.target_column].sum())
            },
            "null_counts": df.isnull().sum().to_dict(),
            "categorical_columns": {}
        }

        for col in self.config.attribute_columns:
            value_counts = df[col].value_counts()
            stats["categorical_columns"][col] = {
                "unique_values": int(df[col].nunique()),
                "top_5_values": value_counts.nlargest(5).to_dict(),
                "null_count": int(df[col].isnull().sum()),
                "total_count": int(len(df)),
                "distribution_percentage": value_counts.nlargest(5).apply(lambda x: float(x/len(df) * 100)).to_dict()
            }

        return stats

    def generate_statistics(self) -> Dict[str, Dict[str, Dict[str, any]]]:
        logger.info("Reading input datasets")
        try:
            train_df = pd.read_csv(self.config.train_dataset)
            test_df = pd.read_csv(self.config.test_dataset)
            
            train_df[self.config.time_column] = pd.to_datetime(train_df[self.config.time_column], errors='coerce')
            test_df[self.config.time_column] = pd.to_datetime(test_df[self.config.time_column], errors='coerce')

            all_statistics = {}

            logger.info("Processing training splits")
            for split_index in range(1, 5):
                split_train = train_df[train_df['split_index'] == split_index]
                logger.info(f"Processing training split {split_index} (shape: {split_train.shape})")
                all_statistics[f"train_split_{split_index}"] = self.calculate_statistics(split_train, "train", split_index)

            logger.info("Processing test splits")
            for split_index in range(1, 5):
                split_test = test_df[test_df['split_index'] == split_index]
                logger.info(f"Processing test split {split_index} (shape: {split_test.shape})")
                all_statistics[f"test_split_{split_index}"] = self.calculate_statistics(split_test, "test", split_index)
                
            logger.info("Calculating overall statistics")
            all_data = pd.concat([train_df, test_df])
            all_statistics["overall"] = self.calculate_statistics(all_data, "overall", 0)

            earliest_date = all_data[self.config.time_column].min()
            latest_date = all_data[self.config.time_column].max()
            
            all_statistics["summary"] = {
                "total_rows": {
                    "train": len(train_df),
                    "test": len(test_df),
                    "total": len(all_data)
                },
                "date_range": {
                    "earliest": self.format_date(earliest_date),
                    "latest": self.format_date(latest_date)
                },
                "splits_info": {
                    f"split_{i}": {
                        "train_rows": len(train_df[train_df['split_index'] == i]),
                        "test_rows": len(test_df[test_df['split_index'] == i])
                    } for i in range(1, 5)
                }
            }

            self.stats = all_statistics
        except Exception as e:
            logger.error(f"Error generating statistics: {str(e)}")
            raise

    def save_statistics(self):
        if self.stats is None:
            logger.error("No statistics to save. Run generate_statistics() first.")
            return

        logger.info(f"Saving statistics to {self.config.output_statistics}")
        try:
            with open(self.config.output_statistics, "w") as f:
                json.dump(self.stats, f, indent=2)
            logger.info("Statistics saved successfully")
        except Exception as e:
            logger.error(f"Error saving statistics: {str(e)}")
            raise

In [8]:
from src.ShortTermForecast import logger

STAGE_NAME = "Dataset Statistics Generation"

class DatasetStatisticsPipeline:
    def __init__(self):
        pass

    def main(self):
        config = ConfigurationManager()
        dataset_statistics = DatasetStatistics(config.get_dataset_statistics_dvc_config())
        dataset_statistics.generate_statistics()
        dataset_statistics.save_statistics()

if __name__ == '__main__':
    try:
        logger.info(f">>>>>> stage {STAGE_NAME} started <<<<<<")
        obj = DatasetStatisticsPipeline()
        obj.main()
        logger.info(f">>>>>> stage {STAGE_NAME} completed <<<<<<")
        logger.info("\nx" + "=" * 50 + "x")
    except Exception as e:
        logger.exception(e)
        raise e

[2025-04-15 20:28:31,551: INFO: 3603232476] >>>>>> stage Dataset Statistics Generation started <<<<<<
[2025-04-15 20:28:31,557: INFO: common] yaml file: config/config.yaml loaded successfully
[2025-04-15 20:28:31,560: INFO: common] Creating directory: artifacts
[2025-04-15 20:28:31,562: INFO: common] yaml file: params.yaml loaded successfully
[2025-04-15 20:28:31,564: INFO: common] Creating directory: artifacts/dataset_statistics
[2025-04-15 20:28:31,565: INFO: 3781864347] Reading input datasets
[2025-04-15 20:28:31,574: INFO: 3781864347] Processing training splits
[2025-04-15 20:28:31,576: INFO: 3781864347] Processing training split 1 (shape: (0, 12))
[2025-04-15 20:28:31,598: INFO: 3781864347] Processing training split 2 (shape: (0, 12))
[2025-04-15 20:28:31,607: INFO: 3781864347] Processing training split 3 (shape: (0, 12))
[2025-04-15 20:28:31,615: INFO: 3781864347] Processing training split 4 (shape: (0, 12))
[2025-04-15 20:28:31,624: INFO: 3781864347] Processing test splits
[2025