In [None]:
#!/usr/bin/env python3
"""
Fish Price Prediction Pipeline
==============================================

A comprehensive machine learning pipeline for predicting fish prices using multiple algorithms:
- Neural Networks (LSTM-based)
- Random Forest
- XGBoost

Author: AI Assistant
Version: 1.0.0
"""

import os
import sys
import logging
import warnings
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union, Any
import json
import pickle

# Third-party imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import xgboost as xgb

# Configure warnings and logging
# NOTE(review): this blanket filter hides every warning category for the whole
# process, including deprecation warnings raised by the libraries imported above.
warnings.filterwarnings('ignore')
# Root logging configuration: INFO and above, timestamped, sent both to the
# console stream and to 'fish_price_pipeline.log' (a relative path).
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.StreamHandler(),
        logging.FileHandler('fish_price_pipeline.log')
    ]
)
# Module-level logger that the classes below log through.
logger = logging.getLogger(__name__)


class Config:
    """Configuration container for the fish price prediction pipeline.

    Every setting is a plain instance attribute assigned in ``__init__`` and
    may be overwritten by the caller before the configuration is handed to the
    pipeline. Call ``validate()`` after changing values.
    """

    def __init__(self):
        # Data parameters
        self.horizon = 7  # prediction horizon in days
        self.sequence_length = 14  # lookback window
        self.test_size = 0.2
        self.val_size = 0.2
        self.random_state = 42

        # Features to normalize with StandardScaler (zero mean, unit variance)
        self.standardize_features = [
            'temperature_2m_mean (°C)', 'wind_speed_10m_max (km/h)',
            'wind_gusts_10m_max (km/h)', 'cloud_cover_mean (%)',
            'precipitation_sum (mm)', 'relative_humidity_2m_mean (%)',
            'wet_bulb_temperature_2m_mean (°C)', 'wind_speed_10m_mean (km/h)',
            'wind_gusts_10m_mean (km/h)', 'surface_pressure_mean (hPa)',
            'rain_sum (mm)', 'pressure_msl_mean (hPa)',
            'shortwave_radiation_sum (MJ/m²)', 'et0_fao_evapotranspiration (mm)',
            'wind_direction_10m_dominant (°)', 'sunshine_duration (s)',
            'wave_height_max (m)', 'wind_wave_height_max (m)',
            'swell_wave_height_max (m)', 'wave_period_max (s)',
            'wind_wave_period_max (s)', 'wave_direction_dominant (°)'
        ]

        # Features to normalize by dividing by maximum value
        self.normalize_features = [
            'dollar_rate', 'Kerosene (LK)', 'Diesel (LAD)', 'Super Diesel (LSD)'
        ]

        # Target columns
        self.target_columns = ['avg_ws_price', 'avg_rt_price']

        # Model parameters
        self.model_type = 'random_forest'  # 'neural_network', 'random_forest', or 'xgboost'

        # Neural Network parameters
        self.hidden_size = 128
        self.num_layers = 3
        self.dropout = 0.2
        self.learning_rate = 0.001
        self.batch_size = 32
        self.epochs = 100
        self.patience = 10

        # Random Forest parameters
        self.rf_n_estimators = 100
        self.rf_max_depth = 20
        self.rf_min_samples_split = 5
        self.rf_min_samples_leaf = 2

        # XGBoost parameters
        self.xgb_n_estimators = 100
        self.xgb_max_depth = 6
        self.xgb_learning_rate = 0.1
        self.xgb_subsample = 0.8
        self.xgb_colsample_bytree = 0.8
        self.xgb_reg_alpha = 0.1
        self.xgb_reg_lambda = 1.0
        self.xgb_early_stopping_rounds = 10

        # Data handling parameters
        self.min_non_zero_ratio = 0.1
        self.max_date_gap_days = 3

    def _validation_errors(self) -> List[str]:
        """Return one message per violated constraint (empty list when valid)."""
        allowed_models = ['neural_network', 'random_forest', 'xgboost']
        errors = []

        if self.model_type not in allowed_models:
            errors.append(
                f"model_type must be one of {allowed_models}, got {self.model_type!r}"
            )
        if not (0 < self.test_size < 1):
            errors.append(f"test_size must be in (0, 1), got {self.test_size!r}")
        if not (0 < self.val_size < 1):
            errors.append(f"val_size must be in (0, 1), got {self.val_size!r}")
        if not (self.test_size + self.val_size < 1):
            errors.append(
                "test_size + val_size must be < 1, got "
                f"{self.test_size + self.val_size!r}"
            )
        if not (self.horizon > 0):
            errors.append(f"horizon must be > 0, got {self.horizon!r}")
        if not (self.sequence_length > 0):
            errors.append(f"sequence_length must be > 0, got {self.sequence_length!r}")
        if not (len(self.target_columns) > 0):
            errors.append("target_columns must not be empty")

        return errors

    def validate(self) -> bool:
        """Validate configuration parameters.

        The checks are explicit comparisons rather than ``assert`` statements,
        so they still run when Python is started with ``-O`` (which strips
        asserts), and the logged message names each violated constraint.

        Returns:
            True when every constraint holds; False otherwise (the problems
            are logged at ERROR level).
        """
        errors = self._validation_errors()
        if errors:
            logger.error(f"Configuration validation failed: {'; '.join(errors)}")
            return False
        return True


class DataValidator:
    """Stateless sanity checks for raw input frames."""

    @staticmethod
    def validate_dataframe(df: pd.DataFrame, required_columns: List[str]) -> Tuple[bool, List[str]]:
        """Check a frame's structure and report every problem found.

        Args:
            df: Frame to inspect; ``None`` and empty frames are rejected at once.
            required_columns: Column names that must be present.

        Returns:
            ``(is_valid, errors)`` where ``errors`` lists human-readable
            problems and ``is_valid`` is True only when that list is empty.
        """
        # Nothing else can be checked on a missing or empty frame.
        if df is None or df.empty:
            return False, ["DataFrame is None or empty"]

        problems = []

        # Required columns
        absent = [name for name in required_columns if name not in df.columns]
        if absent:
            problems.append(f"Missing required columns: {absent}")

        # Fully duplicated rows
        duplicate_flags = df.duplicated()
        if duplicate_flags.any():
            problems.append(f"Found {duplicate_flags.sum()} duplicate rows")

        # The Date column, when present, must be parseable as datetimes
        if 'Date' in df.columns:
            try:
                pd.to_datetime(df['Date'])
            except Exception:
                problems.append("Date column cannot be converted to datetime")

        return not problems, problems


class FishPriceDataProcessor:
    """Data preprocessing and feature engineering for fish price prediction"""

    def __init__(self, config: Config):
        """Create an unfitted processor bound to *config*."""
        # Fitted state: empty until fit_transform() populates it.
        self.is_fitted = False
        self.feature_columns = []
        self.normalizers = {}

        # Transformers that fit_transform() fits and transform() reuses.
        self.fish_encoder = LabelEncoder()
        self.standard_scaler = StandardScaler()

        self.config = config

    def validate_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Check the raw frame and return a copy with ``Date`` parsed.

        Args:
            df: Raw input; must contain ``Date``, ``Fish Type`` and every
                configured target column.

        Returns:
            A copy of *df* whose ``Date`` column holds datetimes.

        Raises:
            ValueError: When structural validation fails, the dates cannot be
                converted, or there are fewer rows than
                ``sequence_length + horizon``.
        """
        needed = ['Date', 'Fish Type'] + self.config.target_columns
        ok, problems = DataValidator.validate_dataframe(df, needed)
        if not ok:
            raise ValueError(f"Data validation failed: {'; '.join(problems)}")

        checked = df.copy()
        try:
            checked['Date'] = pd.to_datetime(checked['Date'])
        except Exception as e:
            raise ValueError(f"Error converting Date column: {e}")

        # One full lookback window plus the forecast horizon is the minimum.
        min_rows = self.config.sequence_length + self.config.horizon
        if len(checked) < min_rows:
            raise ValueError(
                f"Insufficient data: need at least {min_rows} rows, got {len(checked)}"
            )

        return checked

    def create_seasonal_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Create seasonal and temporal features from date"""
        df = df.copy()

        try:
            # Extract temporal components
            df['day_of_year'] = df['Date'].dt.dayofyear
            df['week_of_year'] = df['Date'].dt.isocalendar().week.astype(int)
            df['month'] = df['Date'].dt.month
            df['quarter'] = df['Date'].dt.quarter
            df['day_of_week'] = df['Date'].dt.dayofweek

            # Create cyclical features to capture seasonal patterns
            df['day_of_year_sin'] = np.sin(2 * np.pi * df['day_of_year'] / 365.25)
            df['day_of_year_cos'] = np.cos(2 * np.pi * df['day_of_year'] / 365.25)
            df['week_of_year_sin'] = np.sin(2 * np.pi * df['week_of_year'] / 52)
            df['week_of_year_cos'] = np.cos(2 * np.pi * df['week_of_year'] / 52)
            df['month_sin'] = np.sin(2 * np.pi * df['month'] / 12)
            df['month_cos'] = np.cos(2 * np.pi * df['month'] / 12)
            df['day_of_week_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
            df['day_of_week_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)

            return df
        except Exception as e:
            logger.error(f"Error creating seasonal features: {e}")
            raise

    def create_historical_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """Create historical price features for each fish type"""
        df = df.copy()
        df = df.sort_values(['Fish Type', 'Date'])

        try:
            for fish_type in df['Fish Type'].unique():
                fish_mask = df['Fish Type'] == fish_type
                fish_data = df[fish_mask].copy()

                # Create rolling statistics
                for window in [7, 14, 30]:
                    for target in self.config.target_columns:
                        if target not in fish_data.columns:
                            continue

                        valid_prices = fish_data[target].replace(0, np.nan)

                        # Rolling mean
                        rolling_mean = valid_prices.rolling(
                            window=window, min_periods=window
                        ).mean()
                        df.loc[fish_mask, f'{target}_rolling_{window}d'] = rolling_mean

                        # Rolling standard deviation
                        rolling_std = valid_prices.rolling(
                            window=window, min_periods=window
                        ).std()
                        df.loc[fish_mask, f'{target}_rolling_std_{window}d'] = rolling_std

                # Create lag features
                for lag in [1, 3, 7]:
                    for target in self.config.target_columns:
                        if target in fish_data.columns:
                            df.loc[fish_mask, f'{target}_lag_{lag}d'] = fish_data[target].shift(lag)

            return df
        except Exception as e:
            logger.error(f"Error creating historical features: {e}")
            raise

    def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """Coerce target columns to numbers and drop rows with no price at all.

        Zero and non-numeric prices are treated as missing. Rows where every
        target is missing are removed; targets still missing afterwards are
        set back to 0.

        Args:
            df: Frame containing the configured target columns.

        Returns:
            A cleaned copy of *df*.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        df = df.copy()

        try:
            targets = self.config.target_columns

            # Non-numeric values and zeros both become NaN
            for col in targets:
                if col not in df.columns:
                    continue
                df[col] = pd.to_numeric(df[col], errors='coerce')
                df[col] = df[col].replace(0, np.nan)

            # Discard rows that carry no usable target
            rows_before = len(df)
            df = df.dropna(subset=targets, how='all')
            logger.info(f"Removed {rows_before - len(df)} rows with all missing targets")

            # Partially missing targets are kept and marked with 0
            df[targets] = df[targets].fillna(0)

            return df
        except Exception as e:
            logger.error(f"Error cleaning data: {e}")
            raise

    def fit_transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Fit all transformers on *df* and return the engineered frame.

        Steps, in order: validate, clean targets, add seasonal and historical
        features, fill missing values, label-encode ``Fish Type``, standardize
        the configured weather columns, max-normalize the configured
        ``normalize_features`` columns, and record the feature column list.

        Args:
            df: Raw input frame; must satisfy ``validate_data``.

        Returns:
            A new frame with engineered and scaled features. As a side effect
            this fits ``self.standard_scaler`` and ``self.fish_encoder`` and
            sets ``self.normalizers``, ``self.feature_columns`` and
            ``self.is_fitted``.

        Raises:
            ValueError: Propagated from validation.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        try:
            # Validate and clean data
            df = self.validate_data(df)
            df = self.clean_data(df)

            # Create features
            df = self.create_seasonal_features(df)
            df = self.create_historical_features(df)

            # Handle NaN values in lag/rolling features: forward fill, then
            # backward fill, then 0. Uses .ffill()/.bfill() because
            # fillna(method=...) is deprecated in pandas.
            # NOTE(review): the frame is sorted by fish type then date here, so
            # the fills can carry values across fish-type boundaries.
            feature_cols = [col for col in df.columns if any(x in col for x in ['lag', 'rolling'])]
            for col in feature_cols:
                df[col] = df[col].ffill().bfill().fillna(0)

            df = df.fillna(0)

            # Encode categorical variables
            if 'Fish Type' in df.columns:
                df['Fish Type_encoded'] = self.fish_encoder.fit_transform(df['Fish Type'])

            # Apply scaling (only to configured columns present in the data)
            available_std_features = [col for col in self.config.standardize_features if col in df.columns]
            if available_std_features:
                df[available_std_features] = self.standard_scaler.fit_transform(df[available_std_features])

            # Apply normalization: divide by the column maximum seen at fit time
            for feature in self.config.normalize_features:
                if feature in df.columns:
                    max_val = df[feature].max()
                    if max_val == 0 or np.isnan(max_val):
                        # Avoid dividing by zero/NaN; the column is left unscaled.
                        max_val = 1.0
                        logger.warning(f"Max value for {feature} is 0 or NaN, using 1.0 for normalization")
                    self.normalizers[feature] = max_val
                    df[feature] = df[feature] / max_val

            # Define feature columns: everything except identifiers, raw
            # calendar components, targets and stray 'Unnamed*' index columns
            exclude_cols = [
                'Date', 'Fish Type', 'index', 'day_of_year', 'week_of_year',
                'month', 'quarter', 'day_of_week'
            ] + self.config.target_columns
            exclude_cols.extend([col for col in df.columns if col.startswith('Unnamed')])

            self.feature_columns = [col for col in df.columns if col not in exclude_cols]
            self.is_fitted = True

            logger.info(f"Feature engineering completed. Total features: {len(self.feature_columns)}")
            return df

        except Exception as e:
            logger.error(f"Error in fit_transform: {e}")
            raise

    def transform(self, df: pd.DataFrame) -> pd.DataFrame:
        """Transform new data with the transformers fitted by ``fit_transform``.

        Applies the same validation, cleaning, feature creation and NaN
        filling as ``fit_transform``, then reuses the fitted label encoder,
        standard scaler and per-column maxima instead of refitting them.

        Args:
            df: Raw input frame; must satisfy ``validate_data``.

        Returns:
            A new frame with engineered and scaled features.

        Raises:
            ValueError: If called before ``fit_transform``, or propagated from
                validation.
            Exception: Any other failure is logged and re-raised unchanged.
        """
        if not self.is_fitted:
            raise ValueError("Transformers not fitted. Call fit_transform first.")

        try:
            df = self.validate_data(df)
            df = self.clean_data(df)
            df = self.create_seasonal_features(df)
            df = self.create_historical_features(df)

            # Handle NaN values in lag/rolling features: forward fill, then
            # backward fill, then 0. Uses .ffill()/.bfill() because
            # fillna(method=...) is deprecated in pandas.
            feature_cols = [col for col in df.columns if any(x in col for x in ['lag', 'rolling'])]
            for col in feature_cols:
                df[col] = df[col].ffill().bfill().fillna(0)

            df = df.fillna(0)

            # Apply transformations
            if 'Fish Type' in df.columns:
                # Handle unknown fish types: relabel them as the first known
                # class so the fitted encoder does not raise on unseen labels.
                known_types = self.fish_encoder.classes_
                df.loc[~df['Fish Type'].isin(known_types), 'Fish Type'] = known_types[0]
                df['Fish Type_encoded'] = self.fish_encoder.transform(df['Fish Type'])

            # Apply scaling
            available_std_features = [col for col in self.config.standardize_features if col in df.columns]
            if available_std_features:
                df[available_std_features] = self.standard_scaler.transform(df[available_std_features])

            # Apply normalization with the maxima recorded at fit time
            for feature in self.config.normalize_features:
                if feature in df.columns and feature in self.normalizers:
                    df[feature] = df[feature] / self.normalizers[feature]

            return df

        except Exception as e:
            logger.error(f"Error in transform: {e}")
            raise


class FishPriceDataset(Dataset):
    """Sliding-window dataset of (feature sequence, target vector) samples.

    Windows are built separately for every encoded fish type at construction
    time and kept in ``self.sequences``.
    """

    def __init__(self, data: pd.DataFrame, processor: FishPriceDataProcessor, config: Config, mode: str = 'train'):
        """Sort *data* by fish type and date, then build all samples.

        Args:
            data: Processed frame holding the processor's feature columns,
                the target columns, ``Date`` and ``Fish Type_encoded``.
            processor: Fitted processor; supplies ``feature_columns``.
            config: Supplies window length, horizon, gap limit and targets.
            mode: Label used only in log messages.

        Raises:
            Exception: Any failure while building samples is logged and
                re-raised unchanged.
        """
        self.data = data.copy().sort_values(['Fish Type', 'Date'])
        self.processor = processor
        self.config = config
        self.mode = mode
        self.sequences = []

        try:
            self.sequences = self._create_sequences()
        except Exception as e:
            logger.error(f"Error creating sequences: {e}")
            raise

    def _create_sequences(self) -> List[Dict[str, np.ndarray]]:
        """Build every usable window for every fish type.

        A window is dropped when consecutive dates inside it are more than
        ``max_date_gap_days`` apart, when its target row is entirely zero, or
        when its features or target contain NaN.

        Returns:
            A list of ``{'features': ..., 'targets': ...}`` dicts holding
            float32 arrays.

        Raises:
            ValueError: If the processor has no feature columns or the data
                lacks some of them.
        """
        columns = self.processor.feature_columns
        if not columns:
            raise ValueError("No feature columns found")

        absent = [name for name in columns if name not in self.data.columns]
        if absent:
            raise ValueError(f"Missing feature columns: {absent}")

        # Positional index so row labels double as offsets into the arrays.
        frame = self.data.reset_index(drop=True)
        feature_matrix = frame[columns].values
        target_matrix = frame[self.config.target_columns].values

        window = self.config.sequence_length
        horizon = self.config.horizon
        span = window + horizon
        gap_limit = timedelta(days=self.config.max_date_gap_days)

        samples = []
        for code in frame['Fish Type_encoded'].unique():
            rows = frame.index[frame['Fish Type_encoded'] == code].tolist()

            if len(rows) < span:
                logger.warning(f"Insufficient data for fish type {code} ({len(rows)} rows)")
                continue

            for start in range(len(rows) - span + 1):
                stop = start + window
                window_rows = rows[start:stop]
                # The target sits `horizon` rows after the last window row.
                target_row = rows[stop + horizon - 1]

                # Reject windows with a date gap larger than the limit
                dates = frame.loc[window_rows, 'Date']
                if len(dates) > 1 and (dates.diff().dropna() > gap_limit).any():
                    continue

                try:
                    X = feature_matrix[window_rows]
                    y = target_matrix[target_row]

                    # Need at least one non-zero target ...
                    if not np.any(y != 0):
                        continue
                    # ... and no NaN anywhere in the sample
                    if np.any(np.isnan(X)) or np.any(np.isnan(y)):
                        continue

                    samples.append({'features': X.astype(np.float32), 'targets': y.astype(np.float32)})

                except (IndexError, ValueError) as e:
                    logger.debug(f"Skipping sequence due to error: {e}")
                    continue

        logger.info(f"Created {len(samples)} sequences for {self.mode} dataset")
        return samples

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.sequences)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return sample *idx* as a ``(features, targets)`` tensor pair.

        Raises:
            IndexError: If *idx* is at or beyond the number of samples.
        """
        total = len(self.sequences)
        if idx >= total:
            raise IndexError(f"Index {idx} out of range for dataset of size {total}")

        item = self.sequences[idx]
        return torch.from_numpy(item['features']), torch.from_numpy(item['targets'])


class FishPriceNN(nn.Module):
    """LSTM encoder followed by a two-layer head that predicts all targets."""

    def __init__(self, input_size: int, config: Config):
        """Build the network.

        Args:
            input_size: Number of features per time step.
            config: Supplies ``hidden_size``, ``num_layers``, ``dropout`` and
                ``target_columns`` (one output per target).
        """
        super().__init__()
        self.config = config

        width = config.hidden_size
        # Dropout between LSTM layers is only applied when layers are stacked.
        lstm_dropout = config.dropout if config.num_layers > 1 else 0

        # Recurrent encoder
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=width,
            num_layers=config.num_layers,
            dropout=lstm_dropout,
            batch_first=True
        )

        # Regression head
        self.dropout = nn.Dropout(config.dropout)
        self.fc1 = nn.Linear(width, width // 2)
        self.fc2 = nn.Linear(width // 2, len(config.target_columns))
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on *x* and return one value per target.

        Only the LSTM output at the final time step (index -1 on dim 1) is
        passed to the head.
        """
        sequence_out, _ = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        hidden = self.fc1(self.dropout(final_step))
        hidden = self.dropout(self.relu(hidden))
        return self.fc2(hidden)


class FishPricePipeline:
    """Complete pipeline for fish price prediction"""

    def __init__(self, config: Config):
        if not config.validate():
            raise ValueError("Invalid configuration")

        self.config = config
        self.processor = FishPriceDataProcessor(config)
        self.model = None
        self.is_trained = False

    def load_data(self, csv_path: str) -> pd.DataFrame:
        """Load data from CSV file"""
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"CSV file not found: {csv_path}")

        try:
            df = pd.read_csv(csv_path)
            logger.info(f"Loaded data with shape: {df.shape}")

            # Basic validation
            if df.empty:
                raise ValueError("Loaded DataFrame is empty")

            return df
        except Exception as e:
            logger.error(f"Error loading CSV file: {e}")
            raise

    def prepare_data(self, df: pd.DataFrame) -> Tuple[FishPriceDataset, FishPriceDataset, FishPriceDataset]:
        """Engineer features and split the data chronologically.

        The processor is fitted on the whole frame, then the distinct dates
        are cut into train / validation / test ranges according to
        ``test_size`` and ``val_size``.

        Args:
            df: Raw input frame.

        Returns:
            ``(train_dataset, val_dataset, test_dataset)``.

        Raises:
            ValueError: If no training or no validation samples result.
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info("Processing data...")
            processed = self.processor.fit_transform(df)
            processed = processed.sort_values('Date').reset_index(drop=True)

            # Time-based split over the distinct dates
            timeline = sorted(processed['Date'].unique())
            n_dates = len(timeline)

            train_end_idx = int(n_dates * (1 - self.config.test_size - self.config.val_size))
            val_end_idx = int(n_dates * (1 - self.config.test_size))

            # Last date (inclusive) that belongs to each of the first two ranges
            train_end_date = timeline[train_end_idx - 1]
            val_end_date = timeline[val_end_idx - 1]

            dates = processed['Date']
            in_train = dates <= train_end_date
            in_val = (dates > train_end_date) & (dates <= val_end_date)
            in_test = dates > val_end_date

            train_df = processed[in_train].copy()
            val_df = processed[in_val].copy()
            test_df = processed[in_test].copy()

            logger.info(f"Data split - Train: {len(train_df)}, Val: {len(val_df)}, Test: {len(test_df)}")

            # Build the windowed datasets
            train_dataset = FishPriceDataset(train_df, self.processor, self.config, 'train')
            val_dataset = FishPriceDataset(val_df, self.processor, self.config, 'val')
            test_dataset = FishPriceDataset(test_df, self.processor, self.config, 'test')

            # Training and validation must be non-empty; the test set may be empty.
            if len(train_dataset) == 0:
                raise ValueError("No training sequences created")
            if len(val_dataset) == 0:
                raise ValueError("No validation sequences created")

            return train_dataset, val_dataset, test_dataset

        except Exception as e:
            logger.error(f"Error preparing data: {e}")
            raise

    def train_neural_network(self, train_dataset: FishPriceDataset, val_dataset: FishPriceDataset):
        """Train the LSTM model with early stopping on the validation loss.

        Uses Adam with a reduce-on-plateau schedule and a weighted MSE in
        which zero targets carry weight 0.1 and non-zero targets weight 1.1.
        The weights with the lowest validation loss are kept in memory,
        written to ``best_fish_price_model.pth`` and restored at the end. A
        checkpoint file left by an earlier run is never loaded.

        Args:
            train_dataset: Samples used for gradient updates.
            val_dataset: Samples used for early stopping and the LR schedule.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        checkpoint_path = 'best_fish_price_model.pth'

        def weighted_loss(criterion, outputs, targets):
            # Element-wise MSE, down-weighting entries whose target is zero.
            weights = (targets != 0).float() + 0.1
            return (criterion(outputs, targets) * weights).mean()

        try:
            train_loader = DataLoader(train_dataset, batch_size=self.config.batch_size, shuffle=True)
            val_loader = DataLoader(val_dataset, batch_size=self.config.batch_size, shuffle=False)

            input_size = len(self.processor.feature_columns)
            self.model = FishPriceNN(input_size, self.config)

            criterion = nn.MSELoss(reduction='none')
            optimizer = optim.Adam(self.model.parameters(), lr=self.config.learning_rate)
            scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

            best_val_loss = float('inf')
            best_state = None  # in-memory copy of the best weights of THIS run
            patience_counter = 0

            logger.info("Starting neural network training...")

            for epoch in range(self.config.epochs):
                # Training phase
                self.model.train()
                total_train_loss = 0
                train_samples = 0

                for batch_features, batch_targets in train_loader:
                    optimizer.zero_grad()
                    outputs = self.model(batch_features)

                    batch_loss = weighted_loss(criterion, outputs, batch_targets)
                    batch_loss.backward()
                    optimizer.step()

                    # Accumulate a sample-weighted sum so uneven last batches average correctly
                    total_train_loss += batch_loss.item() * batch_features.size(0)
                    train_samples += batch_features.size(0)

                # Validation phase
                self.model.eval()
                total_val_loss = 0
                val_samples = 0

                with torch.no_grad():
                    for batch_features, batch_targets in val_loader:
                        outputs = self.model(batch_features)
                        batch_loss = weighted_loss(criterion, outputs, batch_targets)

                        total_val_loss += batch_loss.item() * batch_features.size(0)
                        val_samples += batch_features.size(0)

                avg_train_loss = total_train_loss / train_samples if train_samples > 0 else 0
                avg_val_loss = total_val_loss / val_samples if val_samples > 0 else 0

                scheduler.step(avg_val_loss)

                if epoch % 10 == 0:
                    logger.info(f"Epoch {epoch}/{self.config.epochs} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f}")

                # Early stopping
                if avg_val_loss < best_val_loss:
                    best_val_loss = avg_val_loss
                    patience_counter = 0
                    # Clone: state_dict() tensors are updated in place by the optimizer.
                    best_state = {
                        name: tensor.detach().clone()
                        for name, tensor in self.model.state_dict().items()
                    }
                    torch.save(self.model.state_dict(), checkpoint_path)
                else:
                    patience_counter += 1
                    if patience_counter >= self.config.patience:
                        logger.info(f"Early stopping at epoch {epoch}. Best Val Loss: {best_val_loss:.4f}")
                        break

            # Restore the best weights found during this run. If the validation
            # loss never improved (e.g. it was NaN throughout) the final weights
            # are kept instead of loading a checkpoint from a previous run.
            if best_state is not None:
                self.model.load_state_dict(best_state)
            else:
                logger.warning("Validation loss never improved; keeping final epoch weights")

            self.is_trained = True
            logger.info("Neural network training completed!")

        except Exception as e:
            logger.error(f"Error training neural network: {e}")
            raise

    def train_random_forest(self, train_dataset: FishPriceDataset, val_dataset: FishPriceDataset):
        """Fit one random forest per target column.

        Each forest is trained only on samples whose value for that target is
        non-zero; a target with fewer than 10 such samples gets no model.
        ``self.model`` becomes a dict keyed by target name.

        Args:
            train_dataset: Samples to fit on.
            val_dataset: Accepted for signature parity with the other
                trainers; not used here.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info("Preparing Random Forest features...")

            X, y = self._extract_rf_features(train_dataset)

            logger.info(f"Random Forest feature shape: {X.shape}")

            self.model = {}
            for col, target_name in enumerate(self.config.target_columns):
                usable = y[:, col] != 0
                n_usable = np.sum(usable)

                if n_usable < 10:
                    logger.warning(f"Insufficient non-zero samples for {target_name}")
                    continue

                logger.info(f"Training Random Forest for {target_name} with {n_usable} samples")

                forest = RandomForestRegressor(
                    n_estimators=self.config.rf_n_estimators,
                    max_depth=self.config.rf_max_depth,
                    min_samples_split=self.config.rf_min_samples_split,
                    min_samples_leaf=self.config.rf_min_samples_leaf,
                    random_state=self.config.random_state,
                    n_jobs=-1
                )
                forest.fit(X[usable], y[usable, col])
                self.model[target_name] = forest

            self.is_trained = True
            logger.info("Random Forest training completed!")

        except Exception as e:
            logger.error(f"Error training Random Forest: {e}")
            raise

    def train_xgboost(self, train_dataset: FishPriceDataset, val_dataset: FishPriceDataset):
        """Fit one XGBoost booster per target column.

        Each booster is trained only on samples whose value for that target is
        non-zero; a target with fewer than 10 such training samples gets no
        model. Non-zero validation samples, when there are any, are added to
        the evaluation list passed to ``xgb.train`` together with
        ``early_stopping_rounds``. ``self.model`` becomes a dict keyed by
        target name.

        Args:
            train_dataset: Samples to fit on.
            val_dataset: Samples used as the evaluation set.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        try:
            logger.info("Preparing XGBoost features...")

            X_train, y_train = self._extract_xgb_features(train_dataset)
            X_val, y_val = self._extract_xgb_features(val_dataset)

            logger.info(f"XGBoost feature shape: {X_train.shape}")

            self.model = {}
            for col, target_name in enumerate(self.config.target_columns):
                keep_train = y_train[:, col] != 0
                keep_val = y_val[:, col] != 0
                n_train = np.sum(keep_train)

                if n_train < 10:
                    logger.warning(f"Insufficient non-zero samples for {target_name}")
                    continue

                logger.info(f"Training XGBoost for {target_name} with {n_train} training samples")

                dtrain = xgb.DMatrix(X_train[keep_train], label=y_train[keep_train, col])

                # The training set is always watched; validation joins when it has data.
                watchlist = [(dtrain, 'train')]
                if np.sum(keep_val) > 0:
                    dval = xgb.DMatrix(X_val[keep_val], label=y_val[keep_val, col])
                    watchlist.append((dval, 'val'))

                booster_params = {
                    'objective': 'reg:squarederror',
                    'max_depth': self.config.xgb_max_depth,
                    'learning_rate': self.config.xgb_learning_rate,
                    'subsample': self.config.xgb_subsample,
                    'colsample_bytree': self.config.xgb_colsample_bytree,
                    'reg_alpha': self.config.xgb_reg_alpha,
                    'reg_lambda': self.config.xgb_reg_lambda,
                    'random_state': self.config.random_state,
                    'verbosity': 0,
                    'eval_metric': 'rmse'
                }

                self.model[target_name] = xgb.train(
                    booster_params,
                    dtrain,
                    num_boost_round=self.config.xgb_n_estimators,
                    evals=watchlist,
                    early_stopping_rounds=self.config.xgb_early_stopping_rounds,
                    verbose_eval=False
                )

            self.is_trained = True
            logger.info("XGBoost training completed!")

        except Exception as e:
            logger.error(f"Error training XGBoost: {e}")
            raise

    def _extract_rf_features(self, dataset: FishPriceDataset) -> Tuple[np.ndarray, np.ndarray]:
        """Extract features for Random Forest"""
        features, targets = [], []

        for seq_features, seq_targets in dataset:
            seq_data = seq_features.numpy()
            seq_stats = np.concatenate([
                seq_data.mean(axis=0),
                seq_data.std(axis=0),
                seq_data.max(axis=0),
                seq_data.min(axis=0),
                seq_data[-1, :],
                seq_data[0, :],
            ])

            features.append(seq_stats)
            targets.append(seq_targets.numpy())

        return np.array(features), np.array(targets)

    def _extract_xgb_features(self, dataset: FishPriceDataset) -> Tuple[np.ndarray, np.ndarray]:
        """Extract features for XGBoost"""
        features, targets = [], []

        for seq_features, seq_targets in dataset:
            seq_data = seq_features.numpy()
            seq_stats = np.concatenate([
                seq_data.mean(axis=0),
                seq_data.std(axis=0),
                seq_data.max(axis=0),
                seq_data.min(axis=0),
                seq_data[-1, :],
                seq_data[0, :],
                np.median(seq_data, axis=0),
                np.percentile(seq_data, 25, axis=0),
                np.percentile(seq_data, 75, axis=0),
            ])

            features.append(seq_stats)
            targets.append(seq_targets.numpy())

        return np.array(features), np.array(targets)

    def predict(self, test_dataset: FishPriceDataset) -> np.ndarray:
        """Make predictions on test dataset"""
        if not self.is_trained:
            raise ValueError("Model not trained. Call train method first.")

        try:
            if self.config.model_type == 'neural_network':
                return self._predict_neural_network(test_dataset)
            elif self.config.model_type == 'xgboost':
                return self._predict_xgboost(test_dataset)
            else:
                return self._predict_random_forest(test_dataset)
        except Exception as e:
            logger.error(f"Error making predictions: {e}")
            raise

    def _predict_neural_network(self, test_dataset: FishPriceDataset) -> np.ndarray:
        """Predict using neural network"""
        self.model.eval()
        predictions = []
        test_loader = DataLoader(test_dataset, batch_size=self.config.batch_size, shuffle=False)

        with torch.no_grad():
            for features, _ in test_loader:
                outputs = self.model(features)
                predictions.append(outputs.cpu().numpy())

        return np.vstack(predictions) if predictions else np.array([])

    def _predict_random_forest(self, test_dataset: FishPriceDataset) -> np.ndarray:
        """Predict using random forest"""
        test_features, _ = self._extract_rf_features(test_dataset)

        if len(test_features) == 0:
            return np.array([])

        predictions = np.zeros((len(test_features), len(self.config.target_columns)))

        for i, target_name in enumerate(self.config.target_columns):
            if target_name in self.model:
                predictions[:, i] = self.model[target_name].predict(test_features)
            else:
                logger.warning(f"No model found for {target_name}")

        return predictions

    def _predict_xgboost(self, test_dataset: FishPriceDataset) -> np.ndarray:
        """Predict using XGBoost"""
        test_features, _ = self._extract_xgb_features(test_dataset)

        if len(test_features) == 0:
            return np.array([])

        predictions = np.zeros((len(test_features), len(self.config.target_columns)))
        dtest = xgb.DMatrix(test_features)

        for i, target_name in enumerate(self.config.target_columns):
            if target_name in self.model:
                predictions[:, i] = self.model[target_name].predict(dtest)
            else:
                logger.warning(f"No model found for {target_name}")

        return predictions

    def evaluate(self, test_dataset: FishPriceDataset) -> Dict[str, float]:
        """Score the model on ``test_dataset``, target column by target column.

        For every name in ``config.target_columns`` the samples whose true
        value is exactly zero are excluded, then ``<name>_mse``,
        ``<name>_rmse``, ``<name>_mae`` and ``<name>_r2`` are computed on the
        rest. A target with no non-zero samples is logged and skipped. If
        ``r2_score`` raises, the R² entry is stored as NaN.

        Returns:
            The metrics dictionary, or ``{}`` when the dataset is empty or
            the model produced no predictions.

        Raises:
            Exception: any other failure is logged and re-raised.
        """
        if len(test_dataset) == 0:
            return {}

        try:
            predicted = self.predict(test_dataset)
            if len(predicted) == 0:
                return {}

            # Second element of each dataset item is the target tensor.
            observed = np.array([target.numpy() for _, target in test_dataset])
            scores = {}

            for column, name in enumerate(self.config.target_columns):
                keep = observed[:, column] != 0
                if not keep.any():
                    logger.warning(f"No non-zero targets for {name}")
                    continue

                truth = observed[keep, column]
                estimate = predicted[keep, column]

                mse = mean_squared_error(truth, estimate)
                scores[f'{name}_mse'] = mse
                scores[f'{name}_rmse'] = np.sqrt(mse)
                scores[f'{name}_mae'] = mean_absolute_error(truth, estimate)

                try:
                    r_squared = r2_score(truth, estimate)
                except Exception:
                    r_squared = float('nan')
                scores[f'{name}_r2'] = r_squared

            return scores

        except Exception as e:
            logger.error(f"Error evaluating model: {e}")
            raise

    def save_pipeline(self, path: str):
        """Persist the pipeline to files whose names start with ``path``.

        Always writes ``{path}_pipeline.pkl`` (pickled config and processor).
        When a model is present it additionally writes either
        ``{path}_model.pth`` (the network's state dict, for 'neural_network')
        or ``{path}_model.pkl`` (the pickled model object, for
        'random_forest' and 'xgboost').

        Raises:
            Exception: any failure is logged and re-raised.
        """
        try:
            bundle = {
                'config': self.config,
                'processor': self.processor
            }
            with open(f'{path}_pipeline.pkl', 'wb') as bundle_file:
                pickle.dump(bundle, bundle_file)

            # Without a model there is nothing further to write.
            if self.model is not None:
                model_type = self.config.model_type
                if model_type == 'neural_network':
                    torch.save(self.model.state_dict(), f'{path}_model.pth')
                elif model_type in ('random_forest', 'xgboost'):
                    with open(f'{path}_model.pkl', 'wb') as model_file:
                        pickle.dump(self.model, model_file)

            logger.info(f"Pipeline and model saved to {path}_...")

        except Exception as e:
            logger.error(f"Error saving pipeline: {e}")
            raise

    def load_pipeline(self, path: str):
        """Restore config, processor and model from files prefixed with ``path``.

        Reads ``{path}_pipeline.pkl`` for the config and processor. For a
        'neural_network' config a fresh ``FishPriceNN`` is built (input size =
        number of processor feature columns) and its weights are loaded from
        ``{path}_model.pth``; for any other model type the model object is
        unpickled from ``{path}_model.pkl``. On success the pipeline is marked
        as trained.

        Args:
            path: File-name prefix that was used when the pipeline was saved.

        Raises:
            Exception: any failure (missing file, bad pickle, mismatched
                state dict, ...) is logged and re-raised.
        """
        try:
            # SECURITY: unpickling executes code embedded in the file. Only
            # load pipeline files that come from a trusted source.
            with open(f'{path}_pipeline.pkl', 'rb') as f:
                pipeline_data = pickle.load(f)

            self.config = pipeline_data['config']
            self.processor = pipeline_data['processor']

            if self.config.model_type == 'neural_network':
                input_size = len(self.processor.feature_columns)
                self.model = FishPriceNN(input_size, self.config)
                # map_location='cpu' lets a checkpoint that was saved from a
                # GPU-resident model be read on a CPU-only machine. The values
                # are copied into the freshly built model's own parameters by
                # load_state_dict, so the result is otherwise unchanged.
                state_dict = torch.load(f'{path}_model.pth', map_location='cpu')
                self.model.load_state_dict(state_dict)
                self.model.eval()
            else:
                # SECURITY: same trust requirement as above.
                with open(f'{path}_model.pkl', 'rb') as f:
                    self.model = pickle.load(f)

            self.is_trained = True
            logger.info(f"Pipeline loaded from {path}_...")

        except Exception as e:
            logger.error(f"Error loading pipeline: {e}")
            raise

    def predict_future(self, df: pd.DataFrame, fish_type: str, n_days: int = 1) -> Dict[str, List[float]]:
        """Produce ``n_days`` successive price predictions for one fish type.

        ``df`` is run through ``self.processor.transform``; the rows whose
        'Fish Type' equals ``fish_type`` are sorted by 'Date' and the last
        ``config.sequence_length`` of them form the input window. For each
        step the window is fed to the configured model (the raw window for
        the neural network, summary statistics for the tree models), negative
        predictions are clipped to zero, and the window is advanced by
        dropping its first row and repeating its last row.

        NOTE: predictions are not fed back into the window, so later steps
        are driven purely by the repeated last feature row.

        Args:
            df: Raw input data; must yield 'Fish Type', 'Date' and the
                processor's feature columns after transformation.
            fish_type: Value to match in the 'Fish Type' column.
            n_days: Number of prediction steps to produce.

        Returns:
            A dict mapping each name in ``config.target_columns`` to a list
            of ``n_days`` plain Python floats (each >= 0).

        Raises:
            ValueError: if the pipeline is not trained, or fewer than
                ``config.sequence_length`` rows exist for ``fish_type``.
            Exception: any other failure is logged and re-raised.
        """
        if not self.is_trained:
            raise ValueError("Model not trained. Call train method first.")

        try:
            processed_df = self.processor.transform(df)
            fish_data = processed_df[processed_df['Fish Type'] == fish_type].copy()
            fish_data = fish_data.sort_values('Date').tail(self.config.sequence_length)

            if len(fish_data) < self.config.sequence_length:
                raise ValueError(f"Insufficient data for {fish_type}. Need at least {self.config.sequence_length} days.")

            model_type = self.config.model_type
            target_columns = list(self.config.target_columns)
            use_network = model_type == 'neural_network'
            use_xgboost = model_type == 'xgboost'

            def summarise(window: np.ndarray) -> np.ndarray:
                """Flatten a window into one row of per-column statistics."""
                # Same block order as the feature extractors used for the
                # tree models: six basic blocks, plus three more for XGBoost.
                blocks = [
                    window.mean(axis=0),
                    window.std(axis=0),
                    window.max(axis=0),
                    window.min(axis=0),
                    window[-1, :],
                    window[0, :],
                ]
                if use_xgboost:
                    blocks.extend([
                        np.median(window, axis=0),
                        np.percentile(window, 25, axis=0),
                        np.percentile(window, 75, axis=0),
                    ])
                return np.concatenate(blocks).reshape(1, -1)

            if use_network:
                # Eval mode only needs to be set once, not on every step.
                self.model.eval()
            else:
                # Report missing per-target models once, as the batch
                # predictors do; their predictions stay at zero below.
                for target_name in target_columns:
                    if target_name not in self.model:
                        logger.warning(f"No model found for {target_name}")

            predictions = {target: [] for target in target_columns}
            current_sequence = fish_data[self.processor.feature_columns].values

            for _ in range(n_days):
                if use_network:
                    seq_tensor = torch.from_numpy(current_sequence.astype(np.float32)).unsqueeze(0)
                    with torch.no_grad():
                        pred = self.model(seq_tensor).cpu().numpy()[0]
                else:
                    seq_stats = summarise(current_sequence)
                    # Build the model input once per step and share it
                    # between all per-target models.
                    model_input = xgb.DMatrix(seq_stats) if use_xgboost else seq_stats

                    pred = np.zeros(len(target_columns))
                    for i, target_name in enumerate(target_columns):
                        if target_name in self.model:
                            pred[i] = self.model[target_name].predict(model_input)[0]

                for i, target in enumerate(target_columns):
                    # float(...) keeps the result a plain Python float instead
                    # of a mix of int 0 and NumPy scalars.
                    predictions[target].append(float(max(0.0, pred[i])))

                # Advance the window: drop the oldest row and repeat the
                # newest. Unlike indexing row [-2], this also works for a
                # single-row window.
                current_sequence = np.concatenate(
                    [current_sequence[1:], current_sequence[-1:]], axis=0
                )

            return predictions

        except Exception as e:
            logger.error(f"Error predicting future prices: {e}")
            raise


def print_model_specific_usage(model_type: str):
    """Print a usage guide tailored to ``model_type`` to standard output.

    The guide consists of a banner, a model-specific feature/parameter
    overview (only for 'neural_network', 'random_forest' and 'xgboost'),
    generic usage examples, model-specific configuration examples, and the
    list of files the pipeline creates.
    """
    def show(*rows):
        # One print call per row; an empty string yields a blank line.
        for row in rows:
            print(row)

    # model type -> (overview lines, configuration example lines)
    sections = {
        'neural_network': (
            (
                "NEURAL NETWORK MODEL FEATURES:",
                "- LSTM-based deep learning architecture",
                "- Handles sequential patterns in time series data",
                "- Automatic feature learning from raw sequences",
                "- Early stopping and learning rate scheduling",
                "- Best for: Complex temporal patterns, large datasets",
                "",
                "KEY PARAMETERS (configurable in Config class):",
                "- hidden_size: LSTM hidden units (default: 128)",
                "- num_layers: Number of LSTM layers (default: 3)",
                "- dropout: Dropout rate for regularization (default: 0.2)",
                "- learning_rate: Adam optimizer learning rate (default: 0.001)",
                "- batch_size: Training batch size (default: 32)",
                "- epochs: Maximum training epochs (default: 100)",
                "- patience: Early stopping patience (default: 10)",
            ),
            (
                "   config.hidden_size = 256  # Increase model capacity",
                "   config.learning_rate = 0.0005  # Fine-tune learning",
            ),
        ),
        'random_forest': (
            (
                "RANDOM FOREST MODEL FEATURES:",
                "- Ensemble of decision trees",
                "- Robust to overfitting",
                "- Feature importance analysis",
                "- No hyperparameter tuning required",
                "- Best for: Interpretable results, medium datasets",
                "",
                "KEY PARAMETERS (configurable in Config class):",
                "- rf_n_estimators: Number of trees (default: 100)",
                "- rf_max_depth: Maximum tree depth (default: 20)",
                "- rf_min_samples_split: Min samples to split (default: 5)",
                "- rf_min_samples_leaf: Min samples per leaf (default: 2)",
            ),
            (
                "   config.rf_n_estimators = 200  # More trees",
                "   config.rf_max_depth = 30  # Deeper trees",
            ),
        ),
        'xgboost': (
            (
                "XGBOOST MODEL FEATURES:",
                "- Gradient boosting with advanced optimization",
                "- Built-in regularization (L1/L2)",
                "- Early stopping with validation monitoring",
                "- Memory efficient and fast training",
                "- Best for: High performance, structured data",
                "",
                "KEY PARAMETERS (configurable in Config class):",
                "- xgb_n_estimators: Boosting rounds (default: 100)",
                "- xgb_max_depth: Maximum tree depth (default: 6)",
                "- xgb_learning_rate: Learning rate (default: 0.1)",
                "- xgb_subsample: Row subsampling (default: 0.8)",
                "- xgb_colsample_bytree: Column subsampling (default: 0.8)",
                "- xgb_reg_alpha: L1 regularization (default: 0.1)",
                "- xgb_reg_lambda: L2 regularization (default: 1.0)",
                "- xgb_early_stopping_rounds: Early stopping patience (default: 10)",
            ),
            (
                "   config.xgb_learning_rate = 0.05  # Slower learning",
                "   config.xgb_n_estimators = 200  # More boosting rounds",
            ),
        ),
    }

    # Banner
    rule = "=" * 60
    show("\n" + rule)
    show(
        f"FISH PRICE PREDICTION PIPELINE - {model_type.upper().replace('_', ' ')} MODEL",
        rule,
        "",
    )

    # Unknown model types get neither an overview nor configuration examples.
    overview, config_examples = sections.get(model_type, ((), ()))
    show(*overview)

    show(
        "",
        "USAGE EXAMPLES:",
        "-" * 40,
        "1. Train and save model:",
        "   python fish_price_pipeline.py",
        "",
        "2. Load and make predictions:",
        "   predictions = load_and_predict(",
        "       pipeline_path='fish_price_pipeline',",
        "       data_path='your_data.csv',",
        "       fish_type='Skipjack',",
        "       n_days=7",
        "   )",
        "",
        "3. Change model configuration:",
        "   config = Config()",
        f"   config.model_type = '{model_type}'",
    )
    show(*config_examples)

    # Only the neural network writes .pth files; everything else is pickled.
    if model_type == 'neural_network':
        model_files = (
            "- fish_price_pipeline_model.pth (trained neural network)",
            "- best_fish_price_model.pth (best model checkpoint)",
        )
    else:
        model_files = ("- fish_price_pipeline_model.pkl (trained model)",)

    show(
        "",
        "FILES CREATED:",
        "- fish_price_pipeline_pipeline.pkl (preprocessor + config)",
    )
    show(*model_files)

    show(
        "",
        "LOG FILES:",
        "- fish_price_pipeline.log (detailed execution log)",
    )


def main(csv_path: str = 'Final data set 2025 08 10.csv',
         output_path: str = 'fish_price_pipeline') -> bool:
    """Run the complete pipeline: load, prepare, train, evaluate, save, demo.

    Args:
        csv_path: CSV file to load the data from. Defaults to the path the
            script was originally hard-wired to.
        output_path: File-name prefix handed to ``save_pipeline``.

    Returns:
        True when the pipeline ran to completion, False when the
        configuration is invalid, the CSV file is missing, no training data
        is available, or any other error occurred (the error is logged with
        its traceback). A failure of the final example prediction is only
        logged as a warning and does not affect the result.
    """
    try:
        config = Config()

        # Validate configuration
        if not config.validate():
            logger.error("Invalid configuration. Please check your settings.")
            return False

        logger.info("Fish Price Prediction Pipeline Starting...")
        logger.info(f"Model Type: {config.model_type}")
        logger.info(f"Sequence Length: {config.sequence_length} days")
        logger.info(f"Prediction Horizon: {config.horizon} day(s)")

        pipeline = FishPricePipeline(config)

        # Load data
        try:
            df = pipeline.load_data(csv_path)
        except FileNotFoundError:
            logger.error(f"CSV file not found: {csv_path}")
            print(f"Error: '{csv_path}' not found.")
            print("Please update the csv_path variable with the correct file path.")
            return False

        # Prepare datasets
        logger.info("Preparing datasets...")
        train_dataset, val_dataset, test_dataset = pipeline.prepare_data(df)
        logger.info(f"Dataset sizes: Train={len(train_dataset)}, Val={len(val_dataset)}, Test={len(test_dataset)}")

        if len(train_dataset) == 0:
            logger.error("No training data available")
            return False

        # Train model; any model type other than the two named ones trains
        # a random forest.
        logger.info(f"Training {config.model_type.replace('_', ' ').title()} model...")

        if config.model_type == 'neural_network':
            pipeline.train_neural_network(train_dataset, val_dataset)
        elif config.model_type == 'xgboost':
            pipeline.train_xgboost(train_dataset, val_dataset)
        else:
            pipeline.train_random_forest(train_dataset, val_dataset)

        # Evaluate model
        if len(test_dataset) > 0:
            logger.info("Evaluating model performance...")
            metrics = pipeline.evaluate(test_dataset)

            if metrics:
                # Format each metric once; the same lines go to the log and
                # to stdout. NaN values are reported as N/A.
                metric_lines = [
                    f"  {metric}: {value:.4f}" if not np.isnan(value) else f"  {metric}: N/A"
                    for metric, value in metrics.items()
                ]

                logger.info("Model Performance Metrics:")
                for line in metric_lines:
                    logger.info(line)

                print("\nMODEL PERFORMANCE METRICS:")
                print("-" * 40)
                for line in metric_lines:
                    print(line)
            else:
                logger.warning("No evaluation metrics available")
        else:
            logger.warning("No test data available for evaluation")

        # Save pipeline
        logger.info("Saving pipeline...")
        pipeline.save_pipeline(output_path)

        # Example prediction (best effort: a failure here is only a warning)
        try:
            fish_types = df['Fish Type'].unique()
            if len(fish_types) > 0:
                example_fish = fish_types[0]
                logger.info(f"Example prediction for {example_fish}:")
                future_predictions = pipeline.predict_future(df, example_fish, n_days=7)

                print(f"\nEXAMPLE PREDICTION FOR {example_fish.upper()}:")
                print("-" * 50)
                for target, preds in future_predictions.items():
                    print(f"  {target}: {[f'{p:.2f}' for p in preds]}")

                for target, preds in future_predictions.items():
                    logger.info(f"  {target}: {preds}")
        except Exception as e:
            logger.warning(f"Example prediction failed: {e}")

        logger.info("Fish Price Prediction Pipeline completed successfully!")
        return True

    except Exception as e:
        logger.error(f"Pipeline failed with error: {e}", exc_info=True)
        return False


def load_and_predict(pipeline_path: str, data_path: str, fish_type: str, n_days: int = 7) -> Optional[Dict[str, List[float]]]:
    """Load a saved pipeline, predict prices for ``fish_type`` and print them.

    A pipeline is created with a default ``Config`` and then overwritten by
    ``load_pipeline(pipeline_path)``; the data at ``data_path`` is loaded
    through the pipeline and passed to ``predict_future``.

    Returns:
        The predictions per target, or ``None`` if any step failed (the
        error is logged with its traceback).
    """
    try:
        logger.info(f"Loading pipeline from {pipeline_path}")
        restored = FishPricePipeline(Config())
        restored.load_pipeline(pipeline_path)

        logger.info(f"Loading data from {data_path}")
        history = restored.load_data(data_path)

        logger.info(f"Making predictions for {fish_type} over {n_days} days")
        forecast = restored.predict_future(history, fish_type, n_days)

        print(f"\nPREDICTIONS FOR {fish_type.upper()} OVER NEXT {n_days} DAYS:")
        print("-" * 60)
        for target, series in forecast.items():
            # Values are shown with two decimals, as a list of strings.
            formatted = [f'{value:.2f}' for value in series]
            print(f"  {target}: {formatted}")

        return forecast

    except Exception as e:
        logger.error(f"Prediction failed: {e}", exc_info=True)
        return None


if __name__ == "__main__":
    success = main()

    if not success:
        # Troubleshooting guide, printed as one block.
        banner = "=" * 50
        failure_report = [
            "\n" + banner,
            "PIPELINE EXECUTION FAILED",
            banner,
            "Please check the log file 'fish_price_pipeline.log' for detailed error information.",
            "",
            "Common issues and solutions:",
            "1. Data file not found - Update the csv_path variable",
            "2. Missing dependencies - Install required packages:",
            "   pip install torch scikit-learn xgboost pandas numpy",
            "3. Data format issues - Ensure CSV has required columns:",
            "   ['Date', 'Fish Type', 'avg_ws_price', 'avg_rt_price']",
            "4. Insufficient data - Ensure dataset has enough historical data",
        ]
        print("\n".join(failure_report))
    else:
        # A fresh Config carries the default model_type, which selects the
        # usage text to show.
        config = Config()
        print_model_specific_usage(config.model_type)