In [None]:

def prepare_stock_data(data, target_column='Close'):
    """
    Prepare stock data for time series forecasting.

    Extracts ``target_column`` from ``data`` and drops missing values.

    Args:
        data: pandas DataFrame with stock data
        target_column: Column to forecast

    Returns:
        numpy array of the non-NaN target values, in their original order.
        Dropping NaNs closes any gaps, so the result is no longer aligned
        with ``data.index``.
    """
    # Extract the target column
    prices = data[target_column].values

    # np.isnan does not accept object arrays (e.g. a column containing
    # None), so coerce those to float first; None becomes NaN and is
    # removed by the mask below.
    if prices.dtype == object:
        prices = prices.astype(float)

    # Remove any NaN values (boolean-mask indexing yields a 1-D array)
    prices = prices[~np.isnan(prices)]

    print(f"Prepared {len(prices)} data points for forecasting")
    return prices

def create_nbeats_config():
    """
    Build the default three-stack N-BEATS configuration.

    The stacks are, in order: a 'trend' stack (with polynomial ``degree``
    3), a 'seasonality' stack and a 'generic' stack. Every stack has three
    blocks of four layers; the two interpretable stacks use 512 hidden
    units and share weights, the generic stack uses 256 and does not.

    Returns:
        list of dict: one configuration dict per stack. A new list is built
        on every call, so callers may modify it freely.
    """
    # NOTE(review): main_example builds its own stacks with basis types
    # 'polynomial' / 'fourier' instead of 'trend' / 'seasonality' —
    # confirm which names NeuralForecast actually accepts.
    def stack(basis_type, hidden_size, share_weights, **extra):
        cfg = {
            'n_blocks': 3,
            'basis_type': basis_type,
            'n_layers_per_block': 4,
            'hidden_size': hidden_size,
        }
        cfg.update(extra)
        cfg['share_weights'] = share_weights
        return cfg

    trend = stack('trend', 512, True, degree=3)           # trend patterns
    seasonality = stack('seasonality', 512, True)         # seasonal patterns
    generic = stack('generic', 256, False)                # generic patterns
    return [trend, seasonality, generic]

def plot_data_overview(data, symbol):
    """
    Plot an overview of the stock data in a 2x2 grid.

    Panels: closing price, trading volume, daily returns over time and a
    50-bin histogram of daily returns. The figure is shown immediately.

    Args:
        data: pandas DataFrame with at least 'Close' and 'Volume' columns;
            its index is used as the x-axis.
        symbol: Ticker symbol used in the subplot titles.
    """
    _, axes = plt.subplots(2, 2, figsize=(15, 10))

    # Price chart
    axes[0, 0].plot(data.index, data['Close'], color='blue', alpha=0.7)
    axes[0, 0].set_title(f'{symbol} - Closing Price')
    axes[0, 0].set_ylabel('Price ($)')
    axes[0, 0].grid(True, alpha=0.3)

    # Volume chart
    axes[0, 1].plot(data.index, data['Volume'], color='green', alpha=0.7)
    axes[0, 1].set_title(f'{symbol} - Trading Volume')
    axes[0, 1].set_ylabel('Volume')
    axes[0, 1].grid(True, alpha=0.3)

    # Daily returns. pct_change() yields fractional returns (0.01 == 1%),
    # not percentages. Plot against the returns' own index: dropna() can
    # remove more than the first row (e.g. NaN closes), in which case
    # data.index[1:] would have the wrong length or be misaligned.
    daily_returns = data['Close'].pct_change().dropna()
    axes[1, 0].plot(daily_returns.index, daily_returns, color='red', alpha=0.7)
    axes[1, 0].set_title(f'{symbol} - Daily Returns')
    axes[1, 0].set_ylabel('Daily Return')
    axes[1, 0].grid(True, alpha=0.3)

    # Return distribution
    axes[1, 1].hist(daily_returns, bins=50, color='purple', alpha=0.7)
    axes[1, 1].set_title(f'{symbol} - Return Distribution')
    axes[1, 1].set_xlabel('Daily Return')
    axes[1, 1].set_ylabel('Frequency')
    axes[1, 1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()

def main_example():
    """
    Main example function demonstrating N-BEATS on stock data.

    Runs the full pipeline for AAPL:

    1. Fetch two years of data and plot an overview.
    2. Build a three-stack N-BEATS model (30-day backcast, 5-day forecast).
    3. Split the closing prices into train/val/test (train_ratio=0.7,
       val_ratio=0.15).
    4. Train, then evaluate on the test split.
    5. Forecast the next 5 days together with the forecast components, and
       plot the result.
    6. Print the model summary and save the model to
       ``nbeats_aapl_model.pth`` in the current working directory.

    Returns:
        tuple ``(model, test_metrics, forecast)`` on success, or None when
        ``fetch_stock_data`` returns None. Callers that unpack the result
        must check for None first.
    """
    print("=" * 60)
    print("N-BEATS Stock Price Forecasting Example")
    print("=" * 60)
    
    # 1. Fetch Stock Data
    symbol = "AAPL"  # Apple Inc.
    print(f"\n1. Fetching data for {symbol}...")
    
    stock_data = fetch_stock_data(symbol, period="2y")
    if stock_data is None:
        # NOTE(review): returns None here but a 3-tuple on success, so
        # callers must check for None before unpacking.
        print("Failed to fetch data. Exiting...")
        return
    
    # Display basic info
    print(f"Data range: {stock_data.index[0].date()} to {stock_data.index[-1].date()}")
    print(f"Latest closing price: ${stock_data['Close'].iloc[-1]:.2f}")
    
    # 2. Plot data overview
    print("\n2. Plotting data overview...")
    plot_data_overview(stock_data, symbol)
    
    # 3. Prepare data for forecasting (NaN-free array of closing prices)
    print("\n3. Preparing data for forecasting...")
    prices = prepare_stock_data(stock_data, target_column='Close')
    
    # 4. Create N-BEATS model
    print("\n4. Creating N-BEATS model...")
    
    # Model parameters
    backcast_length = 30    # Use 30 days of history
    forecast_length = 5     # Predict next 5 days
    
    # Create stack configurations (a smaller variant than
    # create_nbeats_config(): 2 blocks per stack, 256/256/128 hidden units).
    # NOTE(review): these use basis types 'polynomial' / 'fourier' whereas
    # create_nbeats_config() uses 'trend' / 'seasonality' — confirm which
    # names NeuralForecast accepts.
    stack_configs = [
        {
            'n_blocks': 2,
            'basis_type': 'polynomial',
            'n_layers_per_block': 4,
            'hidden_size': 256,
            'degree': 3,
            'share_weights': True
        },
        {
            'n_blocks': 2,
            'basis_type': 'fourier',
            'n_layers_per_block': 4,
            'hidden_size': 256,
            'share_weights': True
        },
        {
            'n_blocks': 2,
            'basis_type': 'generic',
            'n_layers_per_block': 3,
            'hidden_size': 128,
            'share_weights': False
        }
    ]
    
    # Initialize model
    model = NeuralForecast(
        stack_configs=stack_configs,
        backcast_length=backcast_length,
        forecast_length=forecast_length
    )
    
    # Total parameter count of the wrapped network; model.model is
    # presumably a torch nn.Module (TODO confirm in NeuralForecast).
    print(f"Model created with {sum(p.numel() for p in model.model.parameters()):,} parameters")
    
    # 5. Process and split data
    print("\n5. Processing and splitting data...")
    
    # NOTE(review): normalize=False feeds raw price levels to the network;
    # confirm this is intentional.
    train_data, val_data, test_data = model.process_data(
        data=prices,
        train_ratio=0.7,
        val_ratio=0.15,
        normalize=False
    )
    
    # 6. Train the model (MAE loss, Adam, early stopping after 15 epochs
    # without improvement, gradients clipped at 1.0)
    print("\n6. Training the model...")
    
    history = model.fit(
        train_data=train_data,
        val_data=val_data,
        epochs=100,
        batch_size=32,
        learning_rate=1e-3,
        optimizer='adam',
        loss_function='mae',
        early_stopping=True,
        patience=15,
        scheduler='plateau',
        gradient_clip=1.0,
        verbose=True
    )
    
    # 7. Plot training history
    print("\n7. Plotting training history...")
    model.plot_training_history()
    
    # 8. Evaluate on test data
    print("\n8. Evaluating on test data...")
    
    test_metrics = model.evaluate(
        test_data=test_data,
        metrics=['mae', 'mse', 'rmse', 'mape']
    )
    
    # 9. Generate forecasts
    print("\n9. Generating forecasts...")
    
    # Use the last part of the data for forecasting. Because
    # backcast_length (30) <= 100, input_sequence equals
    # prices[-backcast_length:]; recent_data is only an intermediate.
    recent_data = prices[-100:]  # Last 100 days
    input_sequence = recent_data[-backcast_length:]  # Last 30 days as input
    
    # Generate forecast with components (returns a (forecast, components)
    # pair when return_components=True; components is a name -> values dict)
    forecast, components = model.forecast(
        input_sequence=input_sequence,
        return_components=True
    )
    
    print(f"Forecast for next {forecast_length} days:")
    for i, pred in enumerate(forecast.flatten()):
        print(f"Day {i+1}: ${pred:.2f}")
    
    # 10. Plot comprehensive forecast
    print("\n10. Plotting forecast results...")
    
    # Create a more detailed forecast plot
    plt.figure(figsize=(15, 10))
    
    # Historical data (last 60 days), plotted at x = 0 .. len(hist_data)-1
    hist_data = prices[-60:]
    hist_time = np.arange(len(hist_data))
    
    # Forecast data, placed immediately after the history on the x-axis
    forecast_flat = forecast.flatten()
    forecast_time = np.arange(len(hist_data), len(hist_data) + len(forecast_flat))
    
    # Main plot
    plt.subplot(2, 1, 1)
    plt.plot(hist_time, hist_data, label='Historical Prices', color='blue', linewidth=2)
    plt.plot(forecast_time, forecast_flat, label='N-BEATS Forecast', 
             color='red', linewidth=2, marker='o')
    
    # Add input sequence highlight: the last backcast_length points of the
    # plotted history are the model input (assumes len(hist_data) >= 30).
    input_start = len(hist_data) - backcast_length
    plt.axvspan(input_start, len(hist_data)-1, alpha=0.2, color='green', 
                label='Input Sequence')
    plt.axvline(x=len(hist_data)-1, color='black', linestyle='--', alpha=0.7)
    
    plt.title(f'{symbol} Stock Price Forecast - N-BEATS Model')
    plt.xlabel('Days')
    plt.ylabel('Price ($)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    # Component breakdown. Assumes each component flattens to
    # forecast_length values so it lines up with forecast_time — TODO
    # confirm against NeuralForecast.forecast. Colors cycle after five.
    plt.subplot(2, 1, 2)
    colors = ['purple', 'orange', 'brown', 'pink', 'gray']
    for i, (name, component) in enumerate(components.items()):
        plt.plot(forecast_time, component.flatten(), 
                label=f'{name}', color=colors[i % len(colors)], linewidth=1.5)
    
    plt.title('Forecast Components Breakdown')
    plt.xlabel('Days')
    plt.ylabel('Price Contribution ($)')
    plt.legend()
    plt.grid(True, alpha=0.3)
    
    plt.tight_layout()
    plt.show()
    
    # 11. Model summary: print every entry except 'model_info'
    print("\n11. Model Summary:")
    print("=" * 40)
    summary = model.get_model_summary()
    for key, value in summary.items():
        if key != 'model_info':
            print(f"{key}: {value}")
    
    # 12. Save the model (optional); the path is relative to the current
    # working directory, e.g. "nbeats_aapl_model.pth"
    model_path = f"nbeats_{symbol.lower()}_model.pth"
    model.save_model(model_path)
    print(f"\nModel saved to: {model_path}")
    
    return model, test_metrics, forecast

def hyperparameter_example():
    """
    Example of hyperparameter tuning for stock data.

    Fetches one year of GOOGL prices, builds a small single-stack
    (generic) N-BEATS model (backcast 20, horizon 3) and searches a grid
    of learning rates, batch sizes and optimizers with
    ``model.hyperparameter_finder``.

    Returns:
        The result dict from ``hyperparameter_finder`` (read here for
        'best_params' and 'best_score'), or None if the data fetch failed.
    """
    print("\n" + "="*60)
    print("HYPERPARAMETER TUNING EXAMPLE")
    print("="*60)

    # Fetch data. A None result means the fetch failed (the same check
    # main_example performs); passing it on would crash prepare_stock_data.
    symbol = "GOOGL"
    stock_data = fetch_stock_data(symbol, period="1y")
    if stock_data is None:
        print("Failed to fetch data. Exiting...")
        return None
    prices = prepare_stock_data(stock_data)

    # Simple model configuration for quick tuning
    stack_configs = [
        {
            'n_blocks': 2,
            'basis_type': 'generic',
            'n_layers_per_block': 3,
            'hidden_size': 128,
            'share_weights': True
        }
    ]

    model = NeuralForecast(
        stack_configs=stack_configs,
        backcast_length=20,
        forecast_length=3
    )

    # Process data. Only the train/validation splits are used for the
    # search; the remaining (presumably ~5%) test split is ignored here.
    train_data, val_data, _ = model.process_data(
        data=prices,
        train_ratio=0.8,
        val_ratio=0.15,
        normalize=False
    )

    # Define parameter grid (3 x 2 x 2 = 12 combinations)
    param_grid = {
        'learning_rate': [1e-4, 1e-3, 1e-2],
        'batch_size': [16, 32],
        'optimizer': ['adam', 'adamw']
    }

    # Run hyperparameter search. max_trials=6 is below the 12 grid
    # combinations — presumably only a subset is tried; confirm in
    # NeuralForecast.hyperparameter_finder.
    results = model.hyperparameter_finder(
        train_data=train_data,
        val_data=val_data,
        param_grid=param_grid,
        max_trials=6,
        epochs=30
    )

    print(f"Best parameters found: {results['best_params']}")
    print(f"Best validation score: {results['best_score']:.6f}")
    return results

if __name__ == "__main__":
    # Run main example
    try:
        result = main_example()
        # main_example() returns None (after printing why) when the data
        # fetch fails. Unpacking that directly would raise a TypeError that
        # the handler below would misreport as a dependency problem.
        if result is not None:
            model, metrics, forecast = result
            print("\n" + "="*60)
            print("MAIN EXAMPLE COMPLETED SUCCESSFULLY!")
            print("="*60)

            # Optionally run hyperparameter example. A non-interactive run
            # (no stdin) makes input() raise EOFError; treat that as "no".
            try:
                run_hp_example = input("\nRun hyperparameter tuning example? (y/n): ")
            except EOFError:
                run_hp_example = ""
            if run_hp_example.strip().lower() == 'y':
                hyperparameter_example()

    except Exception as e:
        print(f"Error in main example: {e}")
        print("Make sure you have the required dependencies:")
        print("pip install yfinance torch numpy matplotlib pandas")

# Additional utility functions for more advanced usage

def multi_stock_comparison(symbols, period="1y"):
    """
    Compare N-BEATS performance across multiple stocks.

    For each symbol, fetches ``period`` of price history, trains a small
    single-stack generic N-BEATS model (backcast 20, horizon 5) for 50
    epochs and evaluates it on the held-out test split.

    Args:
        symbols: Iterable of ticker symbols.
        period: History period passed through to ``fetch_stock_data``.

    Returns:
        dict mapping each symbol to the metrics returned by
        ``model.evaluate`` (requested: mae, mse, rmse), or to None when that
        symbol could not be fetched or processed.
    """
    results = {}

    for symbol in symbols:
        print(f"\nProcessing {symbol}...")

        # Best-effort per symbol: a failure is reported and recorded as
        # None without aborting the remaining symbols.
        try:
            # Fetch and prepare data
            stock_data = fetch_stock_data(symbol, period)
            if stock_data is None:
                # Same failure signal main_example checks for; report it
                # clearly instead of via a NoneType error further down.
                print(f"Error processing {symbol}: failed to fetch data")
                results[symbol] = None
                continue
            prices = prepare_stock_data(stock_data)

            # Create model (config built per iteration so models never
            # share a config object)
            stack_configs = [
                {'n_blocks': 2, 'basis_type': 'generic', 'n_layers_per_block': 3,
                 'hidden_size': 128, 'share_weights': True}
            ]

            model = NeuralForecast(
                stack_configs=stack_configs,
                backcast_length=20,
                forecast_length=5
            )

            # Process and train
            train_data, val_data, test_data = model.process_data(
                data=prices, train_ratio=0.7, val_ratio=0.2, normalize=False
            )

            model.fit(train_data, val_data, epochs=50, verbose=False)

            # Evaluate
            metrics = model.evaluate(test_data, metrics=['mae', 'mse', 'rmse'])
            results[symbol] = metrics

        except Exception as e:
            print(f"Error processing {symbol}: {e}")
            results[symbol] = None

    return results

def create_trading_strategy(model, prices, lookback=30, threshold=0.02, horizon=5):
    """
    Simple trading strategy based on N-BEATS predictions.

    Walks forward through ``prices``. At each decision point ``i`` the model
    sees the ``lookback`` prices before ``i``, and its first forecast step is
    compared with the last observed price ``prices[i - 1]``.

    Args:
        model: Object whose ``forecast(input_seq)`` returns a 2-D
            array-like; element ``[0, 0]`` is used as the next-step
            forecast. ``lookback`` should match the model's backcast length
            (NOTE(review): confirm the model validates this itself).
        prices: Indexable, sliceable sequence of prices.
        lookback: Number of past prices fed to the model at each step
            (must be at least 1).
        threshold: Minimum absolute expected return that triggers a buy or
            sell signal.
        horizon: Number of trailing prices that never become decision
            points (default 5, presumably the forecast length so realised
            future prices exist for a backtest).

    Returns:
        ``(signals, positions)``. ``signals`` holds 1 (buy), -1 (sell) or
        0 (hold). ``positions`` holds the expected return behind each
        signal; despite the name, it is not a position size. Both lists have
        ``max(0, len(prices) - lookback - horizon)`` entries.

    Raises:
        ValueError: if ``lookback`` is less than 1. A zero lookback would
            give the model an empty input and wrap the reference price
            around to ``prices[-1]``.
    """
    if lookback < 1:
        raise ValueError("lookback must be at least 1")

    positions = []
    signals = []

    for i in range(lookback, len(prices) - horizon):
        # Input window ends at prices[i - 1], the latest price known at
        # decision time.
        input_seq = prices[i-lookback:i]

        # Generate forecast
        forecast = model.forecast(input_seq)

        # Expected return of the first forecast step relative to the last
        # observed price. Using prices[i] here would compare the forecast
        # with the very price it is forecasting (look-ahead), which yields
        # forecast error rather than an expected return.
        current_price = prices[i - 1]
        future_price = forecast[0, 0]  # First day forecast
        expected_return = (future_price - current_price) / current_price

        # Generate signal
        if expected_return > threshold:
            signal = 1  # Buy
        elif expected_return < -threshold:
            signal = -1  # Sell
        else:
            signal = 0  # Hold

        signals.append(signal)
        positions.append(expected_return)

    return signals, positions