# Stock Price Prediction System

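This notebook is building toward a stock price prediction system based on machine-learning models, including an XGBoost ensemble and LSTM neural networks. Data fetching and a first look at the price history are implemented. The model-training and plotting sections are still placeholders to be filled in.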

## Import Libraries

In [1]:
!pip install yfinance xgboost tensorflow scikit-learn pandas numpy matplotlib plotly seaborn --quiet



In [2]:
import yfinance as yf
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import warnings
warnings.filterwarnings('ignore')

In [3]:
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
from sklearn.model_selection import train_test_split, TimeSeriesSplit
import xgboost as xgb
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor
from sklearn.linear_model import Ridge, ElasticNet

In [4]:
try:
    import tensorflow as tf
    from tensorflow.keras.models import Sequential
    from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization, GRU
    from tensorflow.keras.optimizers import Adam
    from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
    LSTM_AVAILABLE = True
    print("TensorFlow available for LSTM models")
except ImportError:
    print("TensorFlow not available. LSTM model will be skipped.")
    LSTM_AVAILABLE = False

TensorFlow available for LSTM models


In [5]:
from datetime import datetime, timedelta

## Configuration

In [6]:
# Configuration parameters -- the values a reader is expected to tune.
FORECAST_HORIZON: int = 1  # how many days ahead to predict
STOCK_TICKER: str = "AAPL"  # default stock ticker symbol
MODEL_CHOICE: int = 1  # model selector: 1 = XGBoost, 2 = LSTM, 3 = both

## Data Fetching Functions

In [13]:
## Data Fetching Functions

def fetch_stock_data(ticker_symbol, period="max"):
    """Fetch price history for a ticker from Yahoo Finance.

    Parameters
    ----------
    ticker_symbol : str
        Symbol passed to ``yf.Ticker`` (e.g. ``"AAPL"``).
    period : str, default "max"
        History period passed to ``Ticker.history`` (e.g. ``"1y"``).

    Returns
    -------
    tuple of (DataFrame, str) or None
        The price history and a display name for the company, or ``None``
        when no data could be retrieved (empty result or a fetch error).
    """
    try:
        ticker = yf.Ticker(ticker_symbol)
        data = ticker.history(period=period)
    except Exception as exc:
        # Network / rate-limit failures surface as exceptions; report them and
        # keep the same "None means no data" contract as an empty result.
        print(f"Error fetching data for {ticker_symbol}: {exc}")
        return None

    if data is None or data.empty:
        print(f"No data found for {ticker_symbol}")
        return None

    # `.info` is a separate request that can fail or lack fields. The name is
    # only cosmetic, so fall back to the symbol rather than discarding the
    # price history we already have.
    try:
        info = ticker.info or {}
    except Exception:
        info = {}
    company_name = info.get('longName') or ticker_symbol

    print(f"\nFetched data for {company_name} ({ticker_symbol})")
    print(f"Data range: {data.index[0].date()} to {data.index[-1].date()}")
    print(f"Total trading days: {len(data)}")

    return data, company_name

def get_valid_ticker():
    """Prompt repeatedly until the user enters a ticker with recent price data.

    Returns
    -------
    str
        The upper-cased ticker symbol, stripped of surrounding whitespace,
        for which ``yf.Ticker(symbol).history(period="1d")`` was non-empty.
    """
    while True:
        # Strip whitespace so " aapl " is looked up as "AAPL"; surrounding
        # spaces would otherwise be sent to Yahoo Finance and fail validation.
        stock = input("Enter a valid stock ticker (e.g., AAPL, TSLA, MSFT): ").strip().upper()
        if not stock:
            # Blank input can never be valid; skip the network round-trip.
            print("Invalid ticker. Please try again.")
            continue
        try:
            test = yf.Ticker(stock)
            if test.history(period="1d").empty:
                print("Invalid ticker. Please try again.")
            else:
                return stock
        except Exception as e:
            print(f"Error validating ticker: {e}")
            print("Invalid input. Please try again.")

def get_model_choice():
    """Show the model menu and prompt until the user picks 1, 2 or 3.

    Returns
    -------
    int
        1 for XGBoost, 2 for LSTM, 3 for both.
    """
    print("\nChoose prediction model:")
    print("1. XGBoost (Less time, High accuracy)")
    print("2. LSTM (More time, Most accuracy)")  
    print("3. Both (Conclusion with both models)")
    
    while True:
        # Strip whitespace so stray spaces around the digit are accepted.
        choice = input("Enter your choice (1/2/3): ").strip()
        if choice in ['1', '2', '3']:
            return int(choice)
        print("Invalid choice. Please enter 1, 2, or 3.")

## Data Exploration

# Test data fetching on the configured ticker rather than a hard-coded symbol,
# so changing STOCK_TICKER in the configuration cell also changes this check.
print("Testing data fetching...")
sample_data = fetch_stock_data(STOCK_TICKER, "1y")
if sample_data is not None:
    data, company_name = sample_data
    print(f"\nSample data shape: {data.shape}")
    print(f"Columns: {list(data.columns)}")
    print("\nFirst few rows:")
    print(data.head())
    print("\nLast few rows:")
    print(data.tail())
    print("\nBasic statistics:")
    print(data.describe())

## Basic Visualization

if sample_data is not None:
    # Unpack the company name from `sample_data` as well, instead of relying
    # on a `company_name` assigned elsewhere.
    data, company_name = sample_data
    # Create the figure only when there is data to draw; creating it before
    # the check left an empty "<Figure ... with 0 Axes>" in the output.
    fig, ax = plt.subplots(figsize=(12, 6))
    ax.plot(data.index, data['Close'], label='Close Price', linewidth=2)
    # The "Last Year" wording matches the "1y" period requested when
    # `sample_data` was fetched.
    ax.set_title(f'{company_name} Stock Price - Last Year')
    ax.set_xlabel('Date')
    ax.set_ylabel('Price ($)')
    ax.legend()
    ax.grid(True, alpha=0.3)
    fig.tight_layout()
    plt.show()
else:
    print("No data available for visualization")

Testing data fetching...


$AAPL: possibly delisted; no price data found  (period=1y)


No data found for AAPL
No data available for visualization


<Figure size 1200x600 with 0 Axes>

## Technical Indicators

In [8]:
def calculate_rsi(prices, window=14):
    """Calculate the Relative Strength Index (RSI) of a price series.

    Average gains and losses use Wilder's smoothing, i.e. an exponentially
    weighted mean with ``alpha = 1 / window``.

    Parameters
    ----------
    prices : pandas.Series or array-like
        Prices in chronological order (e.g. a ``Close`` column).
    window : int, default 14
        Look-back period; must be at least 1.

    Returns
    -------
    pandas.Series
        RSI values in [0, 100], aligned with ``prices``. The first ``window``
        entries are NaN because fewer than ``window`` price changes exist yet.
        Entries where both the average gain and the average loss are zero
        (no price movement so far) are also NaN.

    Raises
    ------
    ValueError
        If ``window`` is less than 1.
    """
    if window < 1:
        raise ValueError(f"window must be >= 1, got {window}")

    prices = pd.Series(prices, dtype=float)
    delta = prices.diff()
    # Split moves into non-negative gains and losses. Negating before clipping
    # keeps zero losses as +0.0.
    gains = delta.clip(lower=0)
    losses = (-delta).clip(lower=0)

    smoothing = dict(alpha=1.0 / window, min_periods=window, adjust=False)
    avg_gain = gains.ewm(**smoothing).mean()
    avg_loss = losses.ewm(**smoothing).mean()

    # avg_loss == 0 gives rs = inf and RSI = 100; pandas yields inf/NaN
    # rather than raising on division by zero.
    rs = avg_gain / avg_loss
    return 100.0 - 100.0 / (1.0 + rs)

## Model Training Functions

In [9]:
def train_xgboost_model(data, forecast_horizon=1):
    """Train an XGBoost ensemble model (not implemented yet).

    Parameters
    ----------
    data :
        Input price data -- presumably the DataFrame returned by
        ``fetch_stock_data``; TODO confirm once implemented.
    forecast_horizon : int, default 1
        Number of days ahead to predict.

    Raises
    ------
    NotImplementedError
        Always, until the training logic is written. Raising is explicit,
        whereas a silent ``pass`` would return None and fail later.
    """
    raise NotImplementedError("train_xgboost_model is not implemented yet")

In [10]:
def train_lstm_model(data, forecast_horizon=1):
    """Train an LSTM model (not implemented yet).

    Parameters
    ----------
    data :
        Input price data -- presumably the DataFrame returned by
        ``fetch_stock_data``; TODO confirm once implemented.
    forecast_horizon : int, default 1
        Number of days ahead to predict.

    Raises
    ------
    NotImplementedError
        Always, until the training logic is written. Raising is explicit,
        whereas a silent ``pass`` would return None and fail later.
    """
    raise NotImplementedError("train_lstm_model is not implemented yet")

## Visualization Functions

In [11]:
def create_interactive_plot(data_with_signals, company_name, ticker_symbol):
    """Create an interactive Plotly visualization (not implemented yet).

    Parameters
    ----------
    data_with_signals :
        Price data with prediction/signal columns; the exact schema is not
        defined yet -- TODO document once implemented.
    company_name : str
        Display name used in the chart.
    ticker_symbol : str
        Ticker symbol used in the chart.

    Raises
    ------
    NotImplementedError
        Always, until the plotting logic is written.
    """
    raise NotImplementedError("create_interactive_plot is not implemented yet")

## Main Execution

In [12]:
# Placeholder for the main pipeline, which is not written yet; for now this
# cell only confirms that the cells above executed.
status_message = "Stock Prediction System initialized"
print(status_message)

Stock Prediction System initialized
