# In [None]:
# plot.py

"""
Visualization script for Bitcoin price prediction models.

This module contains functions to:
- Load prediction model results
- Generate comparison charts
- Display BTC price forecasts

Refactored version with DRY principles and separation of concerns.
"""
import os
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.dates as mdates
import logging
from typing import Optional, Dict, List, Any
from sklearn.linear_model import Ridge
from datetime import timedelta
from src.api.coingecko_client import CoinGeckoClient
from src.ml.btc_predictor import BTCPredictor
from src.pipeline.btc_pipeline import BTCDataPipeline


logger = logging.getLogger(__name__)


class BTCPlotter:
    """
    Bitcoin price plotter with ML model predictions.

    Generates one figure per model (Linear and Ridge regression). Each
    figure shows the historical price, the forecast with a shaded
    +/- 1.96 * std(training residuals) band, and a histogram of the
    training residuals with summary metrics.

    A combined comparison plot of both models is not implemented yet.
    """

    def __init__(self, df: pd.DataFrame, output_dir: str = "plots"):
        """
        Initialize the plotter.

        Args:
            df: DataFrame with historical BTC data
            output_dir: Directory to save plot images (created if missing)
        """
        self.df = df
        # NOTE: these two attributes hold the classes themselves, not
        # instances. Nothing in this class reads them; they are kept so
        # existing callers that access them keep working.
        self.btc_predictor = BTCPredictor
        self.pipeline = BTCDataPipeline
        self.output_dir = output_dir

        # Create output folder
        os.makedirs(self.output_dir, exist_ok=True)

        # Color scheme
        self.colors = {
            'real': '#2E86AB',      # Blue
            'linear': '#A23B72',    # Magenta
            'ridge': '#F18F01'      # Orange
        }

    # ============================================================
    # PRIVATE HELPER METHODS - REUSABLE COMPONENTS
    # ============================================================

    @staticmethod
    def _regression_metrics(y_true, y_pred) -> Dict[str, float]:
        """
        Compute MAE and RMSE between observed and fitted values.

        Both inputs are flattened to 1-D first so that a column vector
        (n, 1) compared against a flat vector (n,) does not silently
        broadcast into an (n, n) residual matrix.

        Args:
            y_true: Observed target values (array-like)
            y_pred: Fitted/predicted values (array-like, same length)

        Returns:
            Dictionary with keys 'mae' and 'rmse' (plain floats)

        Raises:
            ValueError: If the two inputs do not have the same length
        """
        actual = np.asarray(y_true, dtype=float).ravel()
        fitted = np.asarray(y_pred, dtype=float).ravel()
        if actual.shape != fitted.shape:
            raise ValueError(
                f"y_true and y_pred must have the same length, "
                f"got {actual.shape[0]} and {fitted.shape[0]}"
            )

        residuals = actual - fitted
        return {
            'mae': float(np.mean(np.abs(residuals))),
            # Square the residuals BEFORE averaging; the mean of the raw
            # residuals is ~0 for a least-squares fit and is not an RMSE.
            'rmse': float(np.sqrt(np.mean(residuals ** 2))),
        }

    def _train_and_predict(
        self,
        df: pd.DataFrame,
        model_type: str,
        n_days_future: int,
        **model_kwargs
    ) -> Dict[str, Any]:
        """
        Train a model and generate predictions.

        Args:
            df: Historical data DataFrame with 'date' and 'price_usd' columns
            model_type: 'linear' or 'ridge'
            n_days_future: Number of days to predict (must be >= 1)
            **model_kwargs: Additional model parameters (e.g., alpha for Ridge)

        Returns:
            Dictionary containing:
                - X: Training features
                - y: Training targets
                - y_pred_train: Model output on the training features
                - predictions: Future predictions
                - r2_score: Model R² score on the training data
                - mae: Mean absolute training error
                - rmse: Root-mean-square training error
                - future_dates: DatetimeIndex for predictions
                - model_name: Human-readable model name
                - last_date: Last date in historical data
                - last_price: Price at last_date
                - predictor: Trained predictor instance

        Raises:
            ValueError: If model_type is unknown, n_days_future < 1,
                or df is empty
        """
        # Validate cheap arguments first so bad calls fail before any
        # model is built or trained.
        if model_type not in ('linear', 'ridge'):
            raise ValueError(f"Unknown model type: {model_type}")
        if n_days_future < 1:
            raise ValueError(
                f"n_days_future must be >= 1, got {n_days_future}"
            )
        if df is None or df.empty:
            raise ValueError("df must contain at least one row")

        # Select and configure model
        if model_type == 'linear':
            predictor = BTCPredictor()
            model_name = "Linear Regression"
        else:
            alpha = model_kwargs.get('alpha', 1.0)
            predictor = BTCPredictor(model=Ridge(alpha=alpha))
            model_name = f"Ridge Regression (α={alpha})"

        # Chronological order so that .iloc[-1] really is the latest row
        df = df.sort_values('date').reset_index(drop=True)

        # Prepare data and train
        X, y = predictor.prepare_training_data(df)
        predictor.train(X, y)

        # In-sample fit, used for the residual histogram and error metrics
        y_pred_train = predictor.model.predict(X)
        r2 = predictor.model.score(X, y)
        metrics = self._regression_metrics(y, y_pred_train)

        predictions = predictor.predict_future(n_days_future)

        # Generate future dates, starting the day after the last observation.
        # pd.Timestamp() is a no-op for Timestamps and normalizes other
        # date-like values so the timedelta arithmetic below works.
        last_date = pd.Timestamp(df['date'].iloc[-1])
        future_dates = pd.date_range(
            start=last_date + timedelta(days=1),
            periods=n_days_future,
            freq='D'
        )

        last_price = df['price_usd'].iloc[-1]

        return {
            'X': X,
            'y': y,
            'y_pred_train': y_pred_train,
            'predictions': predictions,
            'r2_score': r2,
            'mae': metrics['mae'],
            'rmse': metrics['rmse'],
            'future_dates': future_dates,
            'model_name': model_name,
            'last_date': last_date,
            'last_price': last_price,
            'predictor': predictor
        }

    def _plot_prediction_with_metrics(
        self,
        df: pd.DataFrame,
        model_data: Dict,
        color: str,
        filename: str,
        title_suffix: str = ""
    ) -> str:
        """
        Generate standardized plot: historical data + predictions + metrics.

        Args:
            df: DataFrame with historical data ('date', 'price_usd')
            model_data: Dictionary returned by _train_and_predict()
            color: Color for prediction line, band and histogram
            filename: Output PNG filename (joined with self.output_dir)
            title_suffix: Additional text for title

        Returns:
            Path to saved plot file
        """
        # Unpack model data; flatten numeric arrays so the residual and
        # band arithmetic is element-wise even for lists/column vectors.
        y = np.asarray(model_data['y'], dtype=float).ravel()
        y_pred_train = np.asarray(model_data['y_pred_train'], dtype=float).ravel()
        predictions = np.asarray(model_data['predictions'], dtype=float).ravel()
        r2 = model_data['r2_score']
        mae = model_data['mae']
        rmse = model_data['rmse']
        future_dates = model_data['future_dates']
        last_date = model_data['last_date']
        model_name = model_data['model_name']

        # Plot history in chronological order regardless of caller ordering
        history = df.sort_values('date')
        history_dates = pd.to_datetime(history['date'])

        # Fall back to the latest historical price if the dict lacks it
        if 'last_price' in model_data:
            last_price = model_data['last_price']
        else:
            last_price = history['price_usd'].iloc[-1]

        errors = y - y_pred_train
        std_error = np.std(errors)

        # Create figure with two subplots (height_ratios goes through
        # gridspec_kw, which is accepted by every Matplotlib version)
        fig, (ax_main, ax_metrics) = plt.subplots(
            2, 1,
            figsize=(14, 10),
            gridspec_kw={'height_ratios': [3, 1], 'hspace': 0.3}
        )

        try:
            # ---- MAIN PLOT: Historical + Predictions ----
            # Historical (real) prices
            ax_main.plot(
                history_dates, history['price_usd'],
                color=self.colors['real'], linewidth=1.5, linestyle='-',
                alpha=0.7, label='Real BTC Price', zorder=4
            )

            # Future predictions
            ax_main.plot(
                future_dates, predictions,
                color=color, linestyle='--', linewidth=2.5,
                marker='D', markersize=6,
                markerfacecolor='white', markeredgewidth=2,
                markeredgecolor=color,
                label=f'{model_name} Future Prediction (train R²={r2:.3f})',
                zorder=6
            )

            # Shaded band: +/- 1.96 * std of the training residuals.
            # This is a rough normal-approximation band, not a formal
            # prediction interval.
            ax_main.fill_between(
                future_dates,
                predictions - 1.96 * std_error,
                predictions + 1.96 * std_error,
                color=color, alpha=0.15,
                label='95% Confidence Interval', zorder=2
            )

            # Transition point
            ax_main.scatter(
                [last_date], [last_price],
                color='red', s=150, zorder=10,
                edgecolor='white', linewidth=2,
                label='Prediction Start', marker='*'
            )
            ax_main.axvline(
                x=last_date, color='gray', linestyle=':',
                linewidth=1.5, alpha=0.7, zorder=4
            )

            # Shaded prediction area
            ax_main.axvspan(
                last_date, future_dates[-1],
                alpha=0.08, color='yellow', zorder=1,
                label='Prediction Period'
            )

            # Main plot configuration
            ax_main.set_title(
                f'{model_name} vs Real BTC Price\n'
                f'{len(predictions)}-Day Prediction{title_suffix}',
                fontsize=14, fontweight='bold'
            )
            ax_main.set_xlabel('Date')
            ax_main.set_ylabel('Price (USD)')
            ax_main.legend(loc='upper left', fontsize=9)
            ax_main.grid(True, alpha=0.3)
            ax_main.xaxis.set_major_formatter(mdates.DateFormatter('%Y-%m-%d'))
            ax_main.tick_params(axis='x', rotation=90)

            # ---- METRICS PLOT: Error Distribution ----
            ax_metrics.hist(errors, bins=30, alpha=0.7, color=color,
                            edgecolor='black')
            ax_metrics.axvline(x=0, color='black', linestyle='--',
                               linewidth=1.5)
            ax_metrics.set_title(
                f'Error Distribution (Training) - {model_name}',
                fontsize=12, fontweight='bold'
            )
            ax_metrics.set_xlabel('Error (USD)')
            ax_metrics.set_ylabel('Frequency')
            ax_metrics.grid(True, alpha=0.3)

            # Stats box
            stats_text = (f'MAE: ${mae:,.2f}\n'
                          f'RMSE: ${rmse:,.2f}\n'
                          f'R²: {r2:.4f}\n'
                          f'Std Error: ${std_error:,.2f}')
            ax_metrics.text(
                0.95, 0.95, stats_text,
                transform=ax_metrics.transAxes,
                verticalalignment='top', horizontalalignment='right',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
                fontsize=9
            )

            # Save figure
            fig.tight_layout()
            filepath = os.path.join(self.output_dir, filename)
            fig.savefig(filepath, dpi=150, bbox_inches='tight')
        finally:
            # Always release the figure: plot_all() swallows exceptions, so
            # an unclosed figure would otherwise leak on every failure.
            plt.close(fig)

        print(f"✅ Plot saved: {filepath}")
        return filepath

    def plot_model_lr(
        self,
        df: pd.DataFrame,
        n_days_future: int
    ) -> str:
        """
        Generate plot for Linear Regression model.

        Creates individual plot: "Real Price vs Linear Regression Prediction"

        Args:
            df: Historical data DataFrame
            n_days_future: Number of days to predict

        Returns:
            Path to saved PNG file
        """
        print("\n📊 Training Linear Regression model...")
        model_data = self._train_and_predict(df, 'linear', n_days_future)

        return self._plot_prediction_with_metrics(
            df=df,
            model_data=model_data,
            color=self.colors['linear'],
            filename="btc_linear_comparison.png"
        )

    def plot_model_ridge(
        self,
        df: pd.DataFrame,
        n_days_future: int,
        alpha: float = 1.0
    ) -> str:
        """
        Generate plot for Ridge Regression model.

        Creates individual plot: "Real Price vs Ridge Regression Prediction"

        Args:
            df: Historical data DataFrame
            n_days_future: Number of days to predict
            alpha: Regularization parameter (L2 penalty)

        Returns:
            Path to saved PNG file
        """
        model_data = self._train_and_predict(
            df, 'ridge', n_days_future, alpha=alpha
        )
        suffix = f" | α={alpha}"

        return self._plot_prediction_with_metrics(
            df=df,
            model_data=model_data,
            color=self.colors['ridge'],
            filename="btc_ridge_comparison.png",
            title_suffix=suffix
        )

    def plot_all(
        self,
        df_real: pd.DataFrame,
        n_days_future: int = 10,
        alpha: float = 1.0
    ) -> Dict[str, str]:
        """
        Generate all available plots: linear model and ridge model.

        Each plot is attempted independently. If one fails, the error and
        traceback are printed and the remaining plots are still attempted;
        the failed plot is simply absent from the result.

        Args:
            df_real: Historical data DataFrame
            n_days_future: Number of days to predict
            alpha: Ridge regularization parameter

        Returns:
            Dictionary mapping plot types to file paths, containing only
            the plots that succeeded:
                - 'linear': Linear regression plot
                - 'ridge': Ridge regression plot
        """
        import traceback

        # (result key, human-readable label, bound method, positional args)
        jobs = [
            ('linear', 'Linear Regression', self.plot_model_lr,
             (df_real, n_days_future)),
            ('ridge', 'Ridge Regression', self.plot_model_ridge,
             (df_real, n_days_future, alpha)),
        ]
        # TODO: add a 'comparison' job once a plot_comparison() method exists.

        total = len(jobs)
        plot_paths: Dict[str, str] = {}

        for step, (key, label, plot_fn, args) in enumerate(jobs, start=1):
            print(f"\n[{step}/{total}] Plotting {label} model...")
            try:
                plot_paths[key] = plot_fn(*args)
            except Exception as e:
                # Deliberate best-effort: report and continue with the rest.
                print(f"❌ Error in {plot_fn.__name__}: {e}")
                traceback.print_exc()

        print("\n" + "=" * 60)
        print(f"✅ COMPLETED: {len(plot_paths)}/{total} plots generated successfully")
        print("=" * 60)

        return plot_paths


# ============================================================== Original ↓ ========================================================================================
# import os
#import pandas as pd
#import numpy as np
#import matplotlib.pyplot as plt

##from typing import Optional
##from sklearn.linear_model import Ridge
#from datetime import timedelta
#from src.api.coingecko_client import CoinGeckoClient
#from src.ml.btc_predictor import BTCPredictor
#from src.pipeline.btc_pipeline import BTCDataPipeline
#
#class BTCPlotter:
#
#    def __init__(self, df: pd.DataFrame, output_dir: str = "plots"):
#        self.df = df
#        self.btc_predictor = BTCPredictor
#        self.pipeline = BTCDataPipeline
#        self.output_dir = output_dir
#
#        #create folder
#        os.makedirs(self.output_dir, exist_ok=True)
#
#        self.colors = {
#            'real': '#2E86AB',      # Azul
#            'linear': '#A23B72',    # Magenta
#            'ridge': '#F18F01' 
#        }
#
#        
#
#    def prepare_training_data(self):
#        return self.btc_predictor.prepare_training_data()
#    
#    def train(self):
#        return self.btc_predictor.train()
#    
#    
#    def predict_future(self):
#        return self.btc_predictor.predict_future()
#    
#    
#    def plot_real_prices(self, df: pd.DataFrame, title: str = "Precios Reales BTC") -> str:
#        
#        fig, ax = plt.subplots(figsize=(12, 6))
#        
#        # Asegurar que date sea datetime
#        dates = pd.to_datetime(df['date'])
#        
#        # Plot simple de l√≠nea
#        ax.plot(dates, df['price_usd'], 
#                color=self.colors['real'], 
#                linewidth=2,
#                label='Precio Real BTC')
#        
#        # Puntos en inicio y fin
#        ax.plot(dates.iloc[0], df['price_usd'].iloc[0], 'ko', markersize=8,
#                label=f'Inicio: ${df["price_usd"].iloc[0]:.2f}')
#        ax.plot(dates.iloc[-1], df['price_usd'].iloc[-1], 'ro', markersize=8,
#                label=f'Fin: ${df["price_usd"].iloc[-1]:.2f}')
#        
#        # Configuraci√≥n
#        ax.set_title(title, fontsize=14, fontweight='bold')
#        ax.set_xlabel('Fecha')
#        ax.set_ylabel('Precio (USD)')
#        ax.legend()
#        ax.grid(True, alpha=0.3)
#        
#        # Formato de fechas
#        fig.autofmt_xdate()
#        
#        # Guardar
#        filename = os.path.join(self.output_dir, "btc_real_prices.png")
#        plt.tight_layout()
#        plt.savefig(filename, dpi=150)
#        plt.close(fig)
#        
#        return filename
#    
#
#    def plot_model_lr(self, df: pd.DataFrame, n_days_future: int):
#
#        # Train model
#        linear_predictor = BTCPredictor()
#        X,y = linear_predictor.prepare_training_data(df)
#        linear_predictor.train(X, y)
#        linear_pred = linear_predictor.predict_future(n_days_future)
#
#        #Calculate R¬≤
#        r2_score = linear_predictor.model.score(X, y)
#
#        # Combine historical + prections
#        all_predictions = np.concatenate([y, linear_pred])
#
#        # Generate future dates
#        last_date = df['date'].iloc[-1]
#        future_dates = pd.date_range(
#            start=last_date + timedelta(days=1),
#            periods=n_days_future,
#            freq='D'
#        )
#
#        # Create visualization
#        fig,(ax_main, ax_metrics) = plt.subplots(
#            2, 1,
#            figsize=(14, 10),
#            height_ratios=[3, 1],
#            gridspec_kw={'hspace': 0.3}
#        )
#
#        # Main plot
#        ax_main.plot(df['date'], df['price_usd'], 
#                'bo-', linewidth=2, markersize=4, alpha=0.7,
#                label='Precio Real BTC')
#        
#        # Predicciones futuras
#        ax_main.plot(future_dates, linear_pred, 
#                    'r--', linewidth=2.5, markersize=5,
#                    label=f'Predicci√≥n Linear (R¬≤={r2_score:.3f})')
#        
#        # Punto de transici√≥n
#        ax_main.scatter([last_date], [df['price_usd'].iloc[-1]], 
#                       color='red', s=100, zorder=5,
#                       edgecolor='black', linewidth=1.5,
#                       label='Inicio Predicci√≥n')
#        
#        # Configuraci√≥n del gr√°fico principal
#        ax_main.set_title(
#            f'Predicci√≥n de Precio BTC - {n_days_future} d√≠as futuros\n'
#            f'Regresi√≥n Lineal Simple (R¬≤ = {r2_score:.4f})',
#            fontsize=14, fontweight='bold', pad=15
#        )
#        ax_main.set_xlabel('Fecha')
#        ax_main.set_ylabel('Precio (USD)', fontsize=11)
#        ax_main.legend(loc='upper left', fontsize=10)
#        ax_main.grid(True, alpha=0.3)
#        
#        # √Årea sombreada para predicci√≥n
#        ax_main.axvspan(last_date, future_dates[-1], 
#                       alpha=0.1, color='gray',
#                       label='Per√≠odo de Predicci√≥n')
#        
#        # Formato de fechas
#        ax_main.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d'))
#        ax_main.tick_params(axis='x', rotation=45)
#        
#        # ---- GR√ÅFICO DE M√âTRICAS ----
#        # Calcular errores
#        errors = y - all_predictions[:len(y)]
#        
#        # Histograma de errores
#        ax_metrics.hist(errors, bins=30, alpha=0.7, color='blue', edgecolor='black')
#        ax_metrics.axvline(x=0, color='red', linestyle='--', linewidth=1.5, alpha=0.7)
#        
#        ax_metrics.set_title('Distribuci√≥n de Errores (Per√≠odo de Entrenamiento)', 
#                            fontsize=12, fontweight='bold')
#        ax_metrics.set_xlabel('Error (USD)')
#        ax_metrics.set_ylabel('Frecuencia')
#        ax_metrics.grid(True, alpha=0.3)
#        
#        # A√±adir estad√≠sticas
#        error_stats = (f'MAE: ${np.mean(np.abs(errors)):.2f}\n'
#                      f'RMSE: ${np.sqrt(np.mean(errors**2)):.2f}\n'
#                      f'R¬≤: {r2_score:.4f}')
#        
#        ax_metrics.text(0.95, 0.95, error_stats,
#                       transform=ax_metrics.transAxes,
#                       verticalalignment='top',
#                       horizontalalignment='right',
#                       bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
#                       fontsize=9)
#
#
#        plt.tight_layout()
#
#        # Save the file
#        filename = os.path.join(self.output_dir, "btc_linear_prediction.png")
#        plt.savefig(filename, dpi=150, bbox_inches='tight')
#        plt.close(fig)
#
#        print(f"‚úÖ Linear plot guardado: {filename}")  # Debug
#
#
#        return filename
#
#
#
#    def plot_model_ridge(self, df: pd.DataFrame, n_days_future: int, alpha: float = 1.0):
#    
#        # 2. Entrenar modelo Ridge
#    
#        ridge_predictor = BTCPredictor(model=Ridge(alpha=alpha))
#        X, y = ridge_predictor.prepare_training_data(df)
#        ridge_predictor.train(X, y)
#        ridge_pred = ridge_predictor.predict_future(n_days_future)
#    
#        # 3. Calcular m√©tricas
#        ridge_score = ridge_predictor.model.score(X, y)
#    
#        # 4. Combinar hist√≥rico + predicciones
#        all_predictions = np.concatenate([y, ridge_pred])
#    
#        # 5. Generar fechas futuras
#        last_date = df['date'].iloc[-1]
#        future_dates = pd.date_range(
#            start=last_date + timedelta(days=1),
#            periods=n_days_future,
#            freq='D'
#        )
#    
#        # 6. Crear visualizaci√≥n (MISMA ESTRUCTURA QUE plot_model_lr)
#        fig, (ax_main, ax_metrics) = plt.subplots(
#            2, 1,
#            figsize=(14, 10),
#            height_ratios=[3, 1],
#            gridspec_kw={'hspace': 0.3}
#        )
#    
#        # ---- GR√ÅFICO PRINCIPAL ----
#        # Datos hist√≥ricos
#        ax_main.plot(df['date'], df['price_usd'], 
#                    'ko-', linewidth=2, markersize=4, alpha=0.7,
#                    label='Precio Real BTC')
#    
#         # Predicciones futuras RIDGE
#        ax_main.plot(future_dates, ridge_pred, 
#                    'r--', linewidth=2.5, markersize=5,
#                    label=f'Ridge Regression (Œ±={alpha}, R¬≤={ridge_score:.3f})')
#    
#        # Punto de transici√≥n
#        ax_main.scatter([last_date], [df['price_usd'].iloc[-1]], 
#                    color='red', s=100, zorder=5,
#                    edgecolor='black', linewidth=1.5,
#                     label='Inicio Predicci√≥n')
#    
#        # Configuraci√≥n del gr√°fico principal
#        ax_main.set_title(
#            f'Predicci√≥n de Precio BTC - {n_days_future} d√≠as futuros\n'
#            f'Ridge Regression con Regularizaci√≥n L2 (Œ±={alpha})',
#            fontsize=14, fontweight='bold', pad=15
#        )
#        ax_main.set_xlabel('Fecha')
#        ax_main.set_ylabel('Precio (USD)', fontsize=11)
#        ax_main.legend(loc='upper left', fontsize=10)
#        ax_main.grid(True, alpha=0.3)
#    
#        # √Årea sombreada para predicci√≥n
#        ax_main.axvspan(last_date, future_dates[-1], 
#                    alpha=0.1, color='gray',
#                    label='Per√≠odo de Predicci√≥n')
#    
#        # Formato de fechas
#        ax_main.xaxis.set_major_formatter(plt.matplotlib.dates.DateFormatter('%Y-%m-%d'))
#        ax_main.tick_params(axis='x', rotation=45)
#    
#        # ---- GR√ÅFICO DE M√âTRICAS ----
#        # Calcular errores (usando predicciones del per√≠odo de entrenamiento)
#        errors = y - all_predictions[:len(y)]
#    
#        # Histograma de errores
#        ax_metrics.hist(errors, bins=30, alpha=0.7, color='red', edgecolor='black')
#        ax_metrics.axvline(x=0, color='blue', linestyle='--', linewidth=1.5, alpha=0.7)
#    
#        ax_metrics.set_title('Distribuci√≥n de Errores - Ridge Regression', 
#                            fontsize=12, fontweight='bold')
#        ax_metrics.set_xlabel('Error (USD)')
#        ax_metrics.set_ylabel('Frecuencia')
#        ax_metrics.grid(True, alpha=0.3)
#    
#        # A√±adir estad√≠sticas
#        mae = np.mean(np.abs(errors))
#        rmse = np.sqrt(np.mean(errors**2))
#    
#        error_stats = (f'MAE: ${mae:.2f}\n'
#                    f'RMSE: ${rmse:.2f}\n'
#                    f'R¬≤: {ridge_score:.4f}\n'
#                    f'Œ± (alpha): {alpha}')
#    
#        ax_metrics.text(0.95, 0.95, error_stats,
#                    transform=ax_metrics.transAxes,
#                   verticalalignment='top',
#                   horizontalalignment='right',
#                   bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.8),
#                   fontsize=9)
#    
#        # 7. Ajustar layout
#        plt.tight_layout()
#    
#        print(f"‚úÖ Modelo Ridge entrenado: R¬≤ = {ridge_score:.4f}, Œ± = {alpha}")
#        print(f"üìä Predicciones: {n_days_future} d√≠as desde {future_dates[0].date()}")
#    
#        filename = os.path.join(self.output_dir, "btc_ridge_prediction.png")
#        plt.savefig(filename, dpi= 150, bbox_inches='tight')
#        plt.close(fig)
#
#        print(f"‚úÖ Ridge plot guardado: {filename}")  # Debug
#
#
#        return filename
#
#    def plot_all(self,
#                df_real: pd.DataFrame,
#                n_days_future: int = 10) -> dict:
#
#        plot_paths = {}
#
#        try:
#            plot_paths['real'] = self.plot_real_prices(df_real)
#        except Exception as e:
#            print(f"‚ùå Error en plot_real_prices: {e}")
#        
#        try:
#            plot_paths['linear'] = self.plot_model_lr(df_real, n_days_future)
#        except Exception as e:
#            print(f"‚ùå Error en plot_model_lr: {e}")
#            import traceback
#            traceback.print_exc()
#        
#        try:
#            plot_paths['ridge'] = self.plot_model_ridge(df_real, n_days_future)
#        except Exception as e:
#            print(f"‚ùå Error en plot_model_ridge: {e}")
#            import traceback
#            traceback.print_exc()
#        
#        return plot_paths
#