# Short-term Crypto Price Prediction based on Order Book Dynamics

This notebook aims to implement and explore methods for short-term cryptocurrency price prediction, drawing inspiration from the research paper "Mind the Gaps: Short-term Crypto Price Prediction".

## Objective
Conduct quantitative research on order book dynamics and build predictive models to forecast future prices.

## Scope of Work
1. Download/Fetch order book data.
2. Engineer features from the order book.
3. Create ML/statistical models to predict future price changes.

## Table of Contents
1. Setup and Configuration
2. Data Acquisition
   - Method 1: Fetching from Bybit API (using `pybit` REST & WebSocket)
   - Method 2: Loading from a local dataset
3. Data Preprocessing
   - LOB Reconstruction (if necessary)
   - Resampling to Second-Level Data
   - Cleaning
4. Feature Engineering
   - Mid-Price and Spread
   - Volume-Adjusted Mid-Price (VAMP)
   - Trade Imbalance (TI)
   - Quote Imbalance (QI)
5. Target Variable Creation
6. Stationarity Checks and Transformations
7. Model Building and Evaluation
   - Data Splitting
   - Linear Regression
   - Logistic Regression (for classification)
   - Decision Tree
   - Random Forest
   - XGBoost
   - Support Vector Machine (SVM)
   - Neural Network (MLP)
8. Conclusion and Future Work

---
## 1. Setup and Configuration
Import necessary libraries and configure settings.

In [None]:
import pandas as pd
import numpy as np
from pybit.unified_trading import HTTP as UnifiedHTTP # For REST API
from pybit.unified_trading import WebSocket as UnifiedWebSocket # For WebSocket
import time
import datetime
import logging # For pybit websocket
from sklearn.model_selection import train_test_split, TimeSeriesSplit
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression, LogisticRegression
from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.svm import SVR, SVC
import xgboost as xgb
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout
from tensorflow.keras.callbacks import EarlyStopping
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns
from statsmodels.tsa.stattools import adfuller

# Display settings
pd.set_option('display.max_columns', None)
pd.set_option('display.width', 1000)
sns.set_style('whitegrid')

# Setup logging for pybit (optional, but helpful for WebSocket debugging)
logger = logging.getLogger("pybit")
logger.setLevel(logging.INFO) # You can set to DEBUG for more verbose output
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s"))
if not logger.hasHandlers(): # Avoid adding multiple handlers if re-running cell
    logger.addHandler(stream_handler)

### Configuration Parameters
Bybit API keys are only needed for private endpoints; the public order book and trade endpoints used in Method 1 work without them.
Set the paths to your local files if you choose Method 2.

In [None]:
# --- Configuration ---
DATA_SOURCE_METHOD = 'local_dataset' # Options: 'bybit_api', 'bybit_api_websocket_demo', 'local_dataset'

# For Bybit API (Method 1)
BYBIT_API_KEY = 'YOUR_API_KEY' # For private endpoints if needed, not required for public data
BYBIT_SECRET_KEY = 'YOUR_SECRET_KEY' # For private endpoints
BYBIT_SYMBOL = 'BTCUSDT' # Pybit uses symbols like BTCUSDT (no slash for USDT perpetual)
BYBIT_LOB_DEPTH_SNAPSHOT = 50 # REST snapshot depth per side (V5 limits: linear/inverse 1-500, spot 1-200, option 1-25)
BYBIT_LOB_DEPTH_WEBSOCKET = 50 # For WebSocket stream (e.g., 1, 50, 200, 500)
BYBIT_TRADE_LIMIT = 50 # get_public_trade_history limit (V5: linear/inverse up to 1000, default 500; spot up to 60)
BYBIT_CHANNEL_TYPE = "linear" # For USDT perpetuals like BTCUSDT. Other options: "spot", "inverse", "option"

# For Local Dataset (Method 2)
LOCAL_ORDERBOOK_FILEPATH = 'path/to/your/orderbook_data.csv'
LOCAL_TRADES_FILEPATH = 'path/to/your/trades_data.csv'

# Feature Engineering & Modeling Parameters
VAMP_LIQUIDITY_CUTOFF = 60000 # In dollars, following the paper's findings
TRADE_IMBALANCE_WINDOW = '60s' # 60 seconds for TI calculation (lowercase 's'; the 'S' alias is deprecated in pandas >= 2.2)
QUOTE_IMBALANCE_LEVELS = 5 # Number of LOB levels for QI (ensure data has this depth)
PREDICTION_HORIZON_SECONDS = 60 # Predict price change 60 seconds ahead
TARGET_TYPE = 'regression' # 'regression' or 'classification' (for direction: up/down/neutral)

---
## 2. Data Acquisition

### Method 1: Fetching from Bybit API (using `pybit` REST client for snapshots)
This section contains functions to fetch Limit Order Book (LOB) snapshots and trade data from Bybit using `pybit`'s REST client. A separate section below discusses real-time data with WebSockets.
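Whichever source is used, later cells expect one wide row per timestamp (`best_bid_price`, `bid_price_L1`, `ask_qty_L2`, ...). The sketch below shows that flattening step in isolation, assuming bids and asks arrive best-first as `[price, size]` string pairs (the `"b"`/`"a"` shape of Bybit's V5 order book payload). `book_to_row` and the sample payload are hypothetical illustrations, not part of `pybit`.

```python
import pandas as pd


def book_to_row(bids, asks, ts_ms, levels=5):
    """Hypothetical helper: flatten one order book snapshot into a single wide row.

    bids: [[price, size], ...] best (highest) first
    asks: [[price, size], ...] best (lowest) first
    Values may be strings, as in Bybit's JSON payloads.
    """
    row = {"timestamp": pd.to_datetime(ts_ms, unit="ms")}
    if bids:
        row["best_bid_price"], row["best_bid_qty"] = float(bids[0][0]), float(bids[0][1])
    if asks:
        row["best_ask_price"], row["best_ask_qty"] = float(asks[0][0]), float(asks[0][1])
    for i in range(1, levels + 1):
        # Level i is the i-th best quote; levels beyond the available depth are omitted
        if len(bids) >= i:
            row[f"bid_price_L{i}"] = float(bids[i - 1][0])
            row[f"bid_qty_L{i}"] = float(bids[i - 1][1])
        if len(asks) >= i:
            row[f"ask_price_L{i}"] = float(asks[i - 1][0])
            row[f"ask_qty_L{i}"] = float(asks[i - 1][1])
    return row


# Hand-made payload in the assumed "b"/"a" shape (not real market data)
sample_result = {
    "b": [["20000.5", "1.2"], ["20000.0", "0.8"]],
    "a": [["20001.0", "0.5"], ["20001.5", "2.0"]],
    "ts": 1672531200000,
}
df_one = pd.DataFrame(
    [book_to_row(sample_result["b"], sample_result["a"], sample_result["ts"], levels=2)]
).set_index("timestamp")
print(df_one)
```

Keeping `best_*` identical to level 1 means the VAMP and QI code can read either set of columns without special cases.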

In [None]:
def fetch_bybit_order_book_snapshot_pybit(symbol, depth_limit):
    """Fetches the current order book snapshot from Bybit using pybit's Unified Trading REST API."""
    session = UnifiedHTTP(
        testnet=False, # Set to True for testnet
        # api_key=BYBIT_API_KEY, # Not needed for public order book
        # api_secret=BYBIT_SECRET_KEY
    )
    try:
        # For Unified Trading, category is 'linear', 'spot', 'inverse', 'option'
        response = session.get_orderbook(category=BYBIT_CHANNEL_TYPE, symbol=symbol, limit=depth_limit)
        
        if response and response.get('retCode') == 0:
            order_book_result = response['result']
            timestamp_ms = int(response['time']) # Timestamp of the response
            
            # Bids are sorted descending by price, Asks ascending by price in Bybit's response
            bids_list = [[float(p), float(q)] for p, q in order_book_result['b']]
            asks_list = [[float(p), float(q)] for p, q in order_book_result['a']]
            
            bids_df = pd.DataFrame(bids_list, columns=['price', 'qty'])
            asks_df = pd.DataFrame(asks_list, columns=['price', 'qty'])
            
            # Add cumulative volume for VAMP calculation later (if needed from snapshot)
            if not bids_df.empty:
                bids_df['cumulative_qty'] = bids_df['qty'].cumsum()
                bids_df['cumulative_value_usd'] = (bids_df['price'] * bids_df['qty']).cumsum()
            if not asks_df.empty:
                asks_df['cumulative_qty'] = asks_df['qty'].cumsum()
                asks_df['cumulative_value_usd'] = (asks_df['price'] * asks_df['qty']).cumsum()
                
            return bids_df, asks_df, timestamp_ms
        else:
            print(f"Error in Bybit API response for order book {symbol}: {response}")
            return pd.DataFrame(), pd.DataFrame(), None
    except Exception as e:
        print(f"Exception fetching Bybit order book snapshot for {symbol} with pybit: {e}")
        return pd.DataFrame(), pd.DataFrame(), None

def fetch_bybit_trades_pybit(symbol, limit):
    """Fetches recent public trades from Bybit using pybit's Unified Trading REST API."""
    session = UnifiedHTTP(
        testnet=False,
    )
    try:
        response = session.get_public_trade_history(category=BYBIT_CHANNEL_TYPE, symbol=symbol, limit=limit)
        if response and response.get('retCode') == 0:
            trades_data = response['result']['list']
            if not trades_data:
                return pd.DataFrame()
            df_trades = pd.DataFrame(trades_data)
            # V5 REST records use 'time' (ms, as a string), 'price', 'size' and 'side' ('Buy'/'Sell').
            # The single-letter keys ('T', 'p', 'v', 'S') belong to the WebSocket publicTrade stream.
            df_trades['timestamp'] = pd.to_datetime(df_trades['time'].astype('int64'), unit='ms')
            df_trades.rename(columns={'size': 'qty'}, inplace=True)
            df_trades['side'] = df_trades['side'].str.lower() # Standardize to 'buy'/'sell'
            df_trades['price'] = df_trades['price'].astype(float)
            df_trades['qty'] = df_trades['qty'].astype(float)
            df_trades['value_usd'] = df_trades['price'] * df_trades['qty']
            df_trades = df_trades[['timestamp', 'price', 'qty', 'side', 'value_usd']]
            return df_trades.sort_values('timestamp').reset_index(drop=True)
        else:
            print(f"Error in Bybit API response for trades {symbol}: {response}")
            return pd.DataFrame()
    except Exception as e:
        print(f"Exception fetching Bybit trades for {symbol} with pybit: {e}")
        return pd.DataFrame()

if DATA_SOURCE_METHOD == 'bybit_api':
    print("Attempting to fetch data from Bybit API using pybit (REST snapshot)...")
    bids_df_pybit, asks_df_pybit, lob_timestamp_pybit = fetch_bybit_order_book_snapshot_pybit(BYBIT_SYMBOL, BYBIT_LOB_DEPTH_SNAPSHOT)
    trades_df_pybit = fetch_bybit_trades_pybit(BYBIT_SYMBOL, BYBIT_TRADE_LIMIT)

    if not (bids_df_pybit.empty or asks_df_pybit.empty or lob_timestamp_pybit is None):
        print(f"\nOrder Book Snapshot (pybit) at {pd.to_datetime(lob_timestamp_pybit, unit='ms')} UTC:")
        print("Top 5 Bids:\n", bids_df_pybit.head())
        print("Top 5 Asks:\n", asks_df_pybit.head())
        # Prepare df_lob_raw from this single snapshot for the rest of the notebook
        temp_lob_data = []
        entry = {'timestamp': pd.to_datetime(lob_timestamp_pybit, unit='ms')}
        entry['best_bid_price'] = bids_df_pybit['price'].iloc[0]
        entry['best_bid_qty'] = bids_df_pybit['qty'].iloc[0]
        entry['best_ask_price'] = asks_df_pybit['price'].iloc[0]
        entry['best_ask_qty'] = asks_df_pybit['qty'].iloc[0]
        for i in range(1, QUOTE_IMBALANCE_LEVELS + 1):
            if len(bids_df_pybit) >= i: 
                entry[f'bid_price_L{i}'] = bids_df_pybit['price'].iloc[i-1]
                entry[f'bid_qty_L{i}'] = bids_df_pybit['qty'].iloc[i-1]
            if len(asks_df_pybit) >= i:
                entry[f'ask_price_L{i}'] = asks_df_pybit['price'].iloc[i-1]
                entry[f'ask_qty_L{i}'] = asks_df_pybit['qty'].iloc[i-1]
        temp_lob_data.append(entry)
        df_lob_raw = pd.DataFrame(temp_lob_data).set_index('timestamp')
    else:
        print("Could not fetch order book snapshot using pybit or it was empty.")
        df_lob_raw = pd.DataFrame()
        
    if not trades_df_pybit.empty:
        print("\nRecent Trades (pybit):\n", trades_df_pybit.head())
        df_trades_raw = trades_df_pybit.set_index('timestamp') # DatetimeIndex, as the resampling step expects
    else:
        print("Could not fetch trades using pybit or no trades found.")
        df_trades_raw = pd.DataFrame()
    
    print("\nNote: Bybit API (REST) provides a single snapshot. For robust analysis, collect data over time or use WebSockets.")

### Method 1b: Real-time Order Book Data with `pybit.unified_trading.WebSocket` (Conceptual)

Fetching true real-time data requires a WebSocket connection. The following cell provides a conceptual example using `pybit.unified_trading.WebSocket` for order book data. This data would typically be streamed to a database or file for later batch processing by the rest of this notebook, or used in a dedicated real-time prediction system.

**This code is for demonstration and would run indefinitely to collect data. You would typically run it in a separate script or manage its lifecycle carefully within a notebook (e.g., run for a short period to collect sample data).**

**Important Note on LOB Reconstruction:** The WebSocket stream provides initial `snapshot` messages followed by `delta` updates. To use this data for features like VAMP or QI at specific time intervals (e.g., every second), you need to:
1.  Maintain a local, in-memory representation of the order book (e.g., sorted lists or DataFrames for bids and asks).
2.  Initialize this local LOB with the first `snapshot` message, and reset it whenever a later `snapshot` arrives (Bybit can re-send a full snapshot mid-stream).
3.  For each incoming `delta` message, apply the changes (delete, update, insert) to your local LOB. Price levels with size "0" should be removed.
4.  At your desired sampling frequency (e.g., every 1 second), take a snapshot of your reconstructed local LOB. This snapshot would then be processed to create rows similar to what `df_lob_raw` expects (e.g., best bid/ask, L2 prices/quantities, etc.).
This reconstruction logic is non-trivial and is **not implemented** in the example below, which focuses on connection and basic message parsing.
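The WebSocket demo cell that follows does not use it, but the four steps above can be sketched as a small in-memory book. This is a minimal illustration under stated assumptions: `LocalOrderBook` is a hypothetical class, each side is a plain `dict` of price to size, any `snapshot` replaces the whole book, a size of `"0"` deletes a level, and the `update_id`/`seq` gap checks a production collector would need are deliberately skipped.

```python
class LocalOrderBook:
    """Hypothetical L2 book rebuilt from snapshot + delta messages (no sequence checks)."""

    def __init__(self):
        self.bids = {}  # price -> size
        self.asks = {}  # price -> size

    @staticmethod
    def _apply(side, levels):
        for price, size in levels:
            price, size = float(price), float(size)
            if size == 0:
                side.pop(price, None)  # size "0" means the level was removed
            else:
                side[price] = size  # insert a new level or overwrite an existing size

    def on_message(self, msg_type, bids, asks):
        if msg_type == "snapshot":
            # A snapshot carries the full book, so discard the previous state first
            self.bids, self.asks = {}, {}
        self._apply(self.bids, bids)
        self._apply(self.asks, asks)

    def top(self, n=5):
        """Return the n best levels per side: bids highest first, asks lowest first."""
        best_bids = sorted(self.bids.items(), key=lambda kv: kv[0], reverse=True)[:n]
        best_asks = sorted(self.asks.items(), key=lambda kv: kv[0])[:n]
        return best_bids, best_asks


book = LocalOrderBook()
book.on_message("snapshot", [["100.0", "1"], ["99.5", "2"]], [["100.5", "1"], ["101.0", "3"]])
book.on_message("delta", [["100.0", "0"], ["99.8", "4"]], [["100.5", "2.5"]])
bids_top, asks_top = book.top(2)
print(bids_top)  # [(99.8, 4.0), (99.5, 2.0)]
print(asks_top)  # [(100.5, 2.5), (101.0, 3.0)]
```

In `handle_unified_orderbook_message`, a call such as `book.on_message(msg_type, bids_raw, asks_raw)` could replace the placeholder branches, and a separate timer could read `book.top(QUOTE_IMBALANCE_LEVELS)` once per second to produce rows for `df_lob_raw` (step 4).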

In [None]:
collected_ws_orderbook_data = [] # Global list to store raw messages for this demo
MAX_WS_MESSAGES_TO_COLLECT = 10 # Collect a few messages then stop for demo purposes
ws_client_global = None # To allow exiting from the handler

def handle_unified_orderbook_message(message):
    """Handles incoming order book messages from Bybit Unified Trading WebSocket."""
    global collected_ws_orderbook_data, ws_client_global
    # print(f"Raw WS Message: {message}") # For debugging
    
    msg_type = message.get('type')
    topic = message.get('topic')
    timestamp_ms = message.get('ts')
    data = message.get('data')

    if not data:
        logger.warning(f"Received message with no data: {message}")
        return

    symbol = data.get('s')
    bids_raw = data.get('b', []) # List of [price_str, size_str]
    asks_raw = data.get('a', []) # List of [price_str, size_str]
    update_id = data.get('u')
    sequence_id = data.get('seq')
    # cts = message.get('cts') # Cross timestamp

    logger.info(f"WS OrderBook: Type='{msg_type}', Symbol='{symbol}', Topic='{topic}', TS={timestamp_ms}, U_ID={update_id}, Seq={sequence_id}")

    # Basic parsing of bids/asks into lists of floats
    bids_parsed = [[float(p), float(s)] for p, s in bids_raw]
    asks_parsed = [[float(p), float(s)] for p, s in asks_raw]

    collected_ws_orderbook_data.append({
        'timestamp_ms': timestamp_ms,
        'type': msg_type,
        'symbol': symbol,
        'bids': bids_parsed,
        'asks': asks_parsed,
        'update_id': update_id,
        'sequence_id': sequence_id
    })
    
    # Placeholder for LOB reconstruction logic:
    if msg_type == 'snapshot':
        # print(f"Received SNAPSHOT for {symbol}. Bids: {len(bids_parsed)}, Asks: {len(asks_parsed)}. Initialize local LOB.")
        # local_orderbook.initialize(bids_parsed, asks_parsed)
        pass
    elif msg_type == 'delta':
        # print(f"Received DELTA for {symbol}. Bids: {len(bids_parsed)}, Asks: {len(asks_parsed)}. Update local LOB.")
        # local_orderbook.update(bids_parsed, asks_parsed) # bids_parsed/asks_parsed here are deltas
        pass
        
    if len(collected_ws_orderbook_data) >= MAX_WS_MESSAGES_TO_COLLECT:
        if ws_client_global and ws_client_global.is_connected():
            logger.info(f"Collected {MAX_WS_MESSAGES_TO_COLLECT} WebSocket messages. Attempting to stop client.")
            ws_client_global.exit()

if DATA_SOURCE_METHOD == 'bybit_api_websocket_demo':
    print("Starting Bybit Unified Trading WebSocket for real-time order book data (demo)...\n")
    collected_ws_orderbook_data = [] # Reset for each run
    
    ws_client_global = UnifiedWebSocket(
        testnet=False, # Set to True for Bybit testnet
        channel_type=BYBIT_CHANNEL_TYPE 
    )
    
    logger.info(f"Subscribing to orderbook.{BYBIT_LOB_DEPTH_WEBSOCKET}.{BYBIT_SYMBOL}")
    ws_client_global.orderbook_stream(
        depth=BYBIT_LOB_DEPTH_WEBSOCKET, 
        symbol=BYBIT_SYMBOL, 
        callback=handle_unified_orderbook_message
    )
    
    print(f"Collecting up to {MAX_WS_MESSAGES_TO_COLLECT} messages. This might take a few seconds depending on market activity...")
    
    start_time = time.time()
    try:
        while ws_client_global.is_connected() and len(collected_ws_orderbook_data) < MAX_WS_MESSAGES_TO_COLLECT:
            time.sleep(0.1) # Check frequently
            if time.time() - start_time > 60: # Timeout after 60 seconds if not enough messages
                logger.warning("WebSocket demo timeout after 60 seconds.")
                break
    except KeyboardInterrupt:
        logger.info("WebSocket interrupted by user.")
    finally:
        if ws_client_global and ws_client_global.is_connected():
            ws_client_global.exit()
        logger.info("WebSocket connection closed.")
        
    if collected_ws_orderbook_data:
        print(f"\n--- Collected {len(collected_ws_orderbook_data)} WebSocket Order Book Messages (Raw Sample) ---")
        for i, msg_data in enumerate(collected_ws_orderbook_data[:min(3, len(collected_ws_orderbook_data))]): 
            print(f"Message {i+1}: TS={msg_data['timestamp_ms']}, Type={msg_data['type']}, Symbol={msg_data['symbol']}")
            print(f"  Bids sample: {msg_data.get('bids')[:2]}")
            print(f"  Asks sample: {msg_data.get('asks')[:2]}")
        print("\nReminder: Full LOB reconstruction from snapshot & deltas is needed to use this data for batch analysis.")
        # Example: df_from_ws = pd.DataFrame(collected_ws_orderbook_data)
        # This df_from_ws would then need extensive processing to create time-series LOB snapshots.
elif DATA_SOURCE_METHOD == 'bybit_api':
    pass # Data already handled by REST snapshot logic


### Method 2: Loading from a Local Dataset
Load data from CSV files. The paper uses three months of tick-level data from Bitstamp.
You will need to adapt the loading code to your dataset's format.

For this example, we'll assume two files:
1.  `orderbook_data.csv`: Contains time-series of LOB snapshots (e.g., best bid/ask, and deeper levels if available).
    Columns might be: `timestamp`, `best_bid_price`, `best_bid_qty`, `best_ask_price`, `best_ask_qty`, `bid_price_L2`, `bid_qty_L2`, ... `ask_price_L5`, `ask_qty_L5`.
2.  `trades_data.csv`: Contains historical trades.
    Columns might be: `timestamp`, `price`, `qty`, `side` (e.g., 'buy' or 'sell').
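Real exports often store timestamps as strings or arrive slightly out of order, and the resampling step later needs a sorted `DatetimeIndex`. A minimal loader sketch for the layout assumed above; `load_timeseries_csv` and the column lists are hypothetical, and an in-memory `io.StringIO` buffer stands in for `trades_data.csv`.

```python
import io

import pandas as pd

# Assumed column layouts, taken from the file descriptions above
REQUIRED_LOB_COLS = ["timestamp", "best_bid_price", "best_bid_qty", "best_ask_price", "best_ask_qty"]
REQUIRED_TRADE_COLS = ["timestamp", "price", "qty", "side"]


def load_timeseries_csv(path_or_buf, required_cols):
    """Hypothetical loader: read a CSV, check its columns, return it indexed and sorted by time."""
    df = pd.read_csv(path_or_buf)
    missing = [col for col in required_cols if col not in df.columns]
    if missing:
        raise ValueError(f"missing columns: {missing}")
    df["timestamp"] = pd.to_datetime(df["timestamp"])
    # resample() needs a DatetimeIndex; sorting guards against out-of-order rows
    return df.set_index("timestamp").sort_index()


# In-memory stand-in for trades_data.csv (rows deliberately out of order)
trades_csv = io.StringIO(
    "timestamp,price,qty,side\n"
    "2023-01-01 00:00:00.250,20000.5,0.10,buy\n"
    "2023-01-01 00:00:00.100,20000.0,0.20,sell\n"
)
trades = load_timeseries_csv(trades_csv, REQUIRED_TRADE_COLS)
trades["value_usd"] = trades["price"] * trades["qty"]
print(trades)
```

Failing early on a missing column gives a clearer error than a `KeyError` several cells later.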

In [None]:
if DATA_SOURCE_METHOD == 'local_dataset':
    print(f"Loading data from local files: {LOCAL_ORDERBOOK_FILEPATH} and {LOCAL_TRADES_FILEPATH}")
    try:
        # --- Load Order Book Data ---
        try:
            df_lob_raw = pd.read_csv(LOCAL_ORDERBOOK_FILEPATH)
            df_lob_raw['timestamp'] = pd.to_datetime(df_lob_raw['timestamp'])
            df_lob_raw.set_index('timestamp', inplace=True)
        except FileNotFoundError:
            print(f"Warning: {LOCAL_ORDERBOOK_FILEPATH} not found. Generating sample LOB data.")
            rng = np.random.default_rng(42)  # fixed seed so the synthetic sample is reproducible
            n_rows = 10000
            sample_timestamps_lob = pd.date_range(start='2023-01-01', periods=n_rows, freq='100ms')
            bid_px = rng.uniform(20000, 20100, n_rows)
            ask_px = rng.uniform(20101, 20200, n_rows)  # always above the best bid
            data_lob = {'timestamp': sample_timestamps_lob}
            for i in range(1, QUOTE_IMBALANCE_LEVELS + 1):
                if i > 1:
                    # Each deeper level sits 0.5-2 price units further from the touch, so levels stay ordered
                    bid_px = bid_px - rng.uniform(0.5, 2, n_rows)
                    ask_px = ask_px + rng.uniform(0.5, 2, n_rows)
                data_lob[f'bid_price_L{i}'] = bid_px
                data_lob[f'bid_qty_L{i}'] = rng.uniform(0.1, 3, n_rows)
                data_lob[f'ask_price_L{i}'] = ask_px
                data_lob[f'ask_qty_L{i}'] = rng.uniform(0.1, 3, n_rows)
            df_lob_raw = pd.DataFrame(data_lob).set_index('timestamp')
            # best_* columns mirror level 1, matching the layout built from the REST snapshot
            for side in ('bid', 'ask'):
                df_lob_raw[f'best_{side}_price'] = df_lob_raw[f'{side}_price_L1']
                df_lob_raw[f'best_{side}_qty'] = df_lob_raw[f'{side}_qty_L1']

        # --- Load Trades Data ---
        try:
            df_trades_raw = pd.read_csv(LOCAL_TRADES_FILEPATH)
            df_trades_raw['timestamp'] = pd.to_datetime(df_trades_raw['timestamp'])
            if 'value_usd' not in df_trades_raw.columns and 'price' in df_trades_raw.columns and 'qty' in df_trades_raw.columns:
                df_trades_raw['value_usd'] = df_trades_raw['price'].astype(float) * df_trades_raw['qty'].astype(float)
            df_trades_raw.set_index('timestamp', inplace=True)
        except FileNotFoundError:
            print(f"Warning: {LOCAL_TRADES_FILEPATH} not found. Generating sample trades data.")
            sample_timestamps_trades = df_lob_raw.index if not df_lob_raw.empty else pd.to_datetime(pd.date_range(start='2023-01-01', periods=10000, freq='100ms').values)
            base_price = (df_lob_raw['best_bid_price'] + df_lob_raw['best_ask_price']) / 2 if not df_lob_raw.empty and 'best_bid_price' in df_lob_raw else pd.Series(np.random.uniform(20000,20200,len(sample_timestamps_trades)), index=sample_timestamps_trades)
            data_trades = {
                'timestamp': sample_timestamps_trades,
                'price': base_price,
                'qty': np.random.uniform(0.01, 1, len(sample_timestamps_trades)),
                'side': np.random.choice(['buy', 'sell'], len(sample_timestamps_trades)),
            }
            df_trades_raw = pd.DataFrame(data_trades).set_index('timestamp')
            df_trades_raw['value_usd'] = df_trades_raw['price'] * df_trades_raw['qty']
        
        print("LOB data (raw head):\n", df_lob_raw.head())
        print("Trades data (raw head):\n", df_trades_raw.head())

    except Exception as e:
        print(f"Error loading local data: {e}. Falling back to empty DataFrames.")
        df_lob_raw = pd.DataFrame()
        df_trades_raw = pd.DataFrame()
elif DATA_SOURCE_METHOD == 'bybit_api':
    if 'df_lob_raw' not in locals(): df_lob_raw = pd.DataFrame()
    if 'df_trades_raw' not in locals(): df_trades_raw = pd.DataFrame()
    print("Using data fetched from Bybit API (REST snapshot).")
    print("LOB data (raw head from API snapshot):\n", df_lob_raw.head())
    print("Trades data (raw head from API):\n", df_trades_raw.head())
elif DATA_SOURCE_METHOD == 'bybit_api_websocket_demo':
    print("WebSocket demo ran. `collected_ws_orderbook_data` contains raw messages.")
    print("Further processing is needed to use this data for batch analysis.")
    # For the rest of the notebook to run with some data, we might generate sample data here
    # or the user should switch DATA_SOURCE_METHOD to 'local_dataset' or 'bybit_api' (REST)
    if 'df_lob_raw' not in locals() or df_lob_raw.empty:
        print("Generating sample LOB data as WebSocket demo does not populate df_lob_raw directly.")
        rng = np.random.default_rng(42)  # fixed seed so the synthetic sample is reproducible
        n_rows = 1000
        bid_px = rng.uniform(20000, 20100, n_rows)
        ask_px = rng.uniform(20101, 20200, n_rows)  # always above the best bid
        data_lob = {'timestamp': pd.date_range(start='2023-01-01', periods=n_rows, freq='1s')}
        for i in range(1, QUOTE_IMBALANCE_LEVELS + 1):
            if i > 1:  # deeper levels step 0.5-2 price units away from the touch
                bid_px = bid_px - rng.uniform(0.5, 2, n_rows)
                ask_px = ask_px + rng.uniform(0.5, 2, n_rows)
            data_lob[f'bid_price_L{i}'] = bid_px
            data_lob[f'bid_qty_L{i}'] = rng.uniform(0.1, 3, n_rows)
            data_lob[f'ask_price_L{i}'] = ask_px
            data_lob[f'ask_qty_L{i}'] = rng.uniform(0.1, 3, n_rows)
        df_lob_raw = pd.DataFrame(data_lob).set_index('timestamp')
        for side in ('bid', 'ask'):  # best_* columns mirror level 1
            df_lob_raw[f'best_{side}_price'] = df_lob_raw[f'{side}_price_L1']
            df_lob_raw[f'best_{side}_qty'] = df_lob_raw[f'{side}_qty_L1']
    if 'df_trades_raw' not in locals() or df_trades_raw.empty:
        print("Generating sample trades data as WebSocket demo does not populate df_trades_raw directly.")
        df_trades_raw = pd.DataFrame({'timestamp': df_lob_raw.index, 'price': (df_lob_raw.best_bid_price+df_lob_raw.best_ask_price)/2, 'qty': np.random.uniform(0.01,1,len(df_lob_raw)), 'side':np.random.choice(['buy','sell'],len(df_lob_raw))}).set_index('timestamp')
        df_trades_raw['value_usd'] = df_trades_raw.price * df_trades_raw.qty
else:
    print(f"Invalid DATA_SOURCE_METHOD selected: {DATA_SOURCE_METHOD!r}.")
    df_lob_raw = pd.DataFrame()
    df_trades_raw = pd.DataFrame()

---
## 3. Data Preprocessing
The paper condenses the full dataset to whole seconds rather than keeping every tick update, which reduces the data size.
We therefore resample the data to 1-second intervals.

**Note on LOB Reconstruction for VAMP/QI from Tick Data:**
If `df_lob_raw` contains tick-by-tick updates (not snapshots), you need a more sophisticated LOB reconstruction step before resampling: maintain the state of the order book at each tick and then sample it every second. The paper describes its procedure as: "Starting with the end of day snapshot, we then used each quote update to first update the order book, and then calculated and stored the important features".
For simplicity, if using `local_dataset`, we assume `df_lob_raw` provides snapshots or already has necessary levels available at each timestamp.
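The cell below relies on two different aggregation rules: book state takes the last quote in each second and is carried forward through seconds with no update, while trade flow is summed and empty seconds count as zero. A toy sketch of both rules; the data is made up, and `net_qty` (buy minus sell quantity) is an illustrative column, not one the notebook creates.

```python
import numpy as np
import pandas as pd

# Toy tick-level quotes and trades at irregular timestamps (illustrative values only)
quotes = pd.DataFrame(
    {"best_bid_price": [100.0, 100.5, 101.0], "best_ask_price": [100.5, 101.0, 101.5]},
    index=pd.to_datetime(["2023-01-01 00:00:00.2", "2023-01-01 00:00:00.9", "2023-01-01 00:00:02.4"]),
)
trades = pd.DataFrame(
    {"qty": [1.0, 2.0, 0.5], "side": ["buy", "sell", "buy"]},
    index=pd.to_datetime(["2023-01-01 00:00:00.3", "2023-01-01 00:00:00.7", "2023-01-01 00:00:02.1"]),
)

# Book state at the end of each second; forward fill only, so no later quote leaks backwards
lob_sec = quotes.resample("1s").last().ffill()

# Signed trade flow per second; resample().sum() yields 0 for seconds without trades
signed_qty = np.where(trades["side"] == "buy", trades["qty"], -trades["qty"])
flow_sec = pd.Series(signed_qty, index=trades.index).resample("1s").sum().rename("net_qty")

df_sec = lob_sec.join(flow_sec, how="left").fillna({"net_qty": 0.0})
print(df_sec)
```

The second 00:00:01 has no quote, so it inherits the book from 00:00:00; back-filling instead would copy the 00:00:02 quote into the past, which is look-ahead for a prediction task.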

In [None]:
if not df_lob_raw.empty:
    # Resample LOB data to 1-second frequency. Use last observation in interval.
    agg_dict_lob = {}
    if 'best_bid_price' in df_lob_raw.columns: agg_dict_lob['best_bid_price'] = 'last'
    if 'best_bid_qty' in df_lob_raw.columns: agg_dict_lob['best_bid_qty'] = 'last'
    if 'best_ask_price' in df_lob_raw.columns: agg_dict_lob['best_ask_price'] = 'last'
    if 'best_ask_qty' in df_lob_raw.columns: agg_dict_lob['best_ask_qty'] = 'last'
    for i in range(1, QUOTE_IMBALANCE_LEVELS + 1):
        if f'bid_price_L{i}' in df_lob_raw.columns: agg_dict_lob[f'bid_price_L{i}'] = 'last'
        if f'bid_qty_L{i}' in df_lob_raw.columns: agg_dict_lob[f'bid_qty_L{i}'] = 'last'
        if f'ask_price_L{i}' in df_lob_raw.columns: agg_dict_lob[f'ask_price_L{i}'] = 'last'
        if f'ask_qty_L{i}' in df_lob_raw.columns: agg_dict_lob[f'ask_qty_L{i}'] = 'last'

    if df_lob_raw.index.empty or not isinstance(df_lob_raw.index, pd.DatetimeIndex):
        print("Warning: df_lob_raw has no DatetimeIndex. Cannot resample. Ensure 'timestamp' is index and datetime.")
        df_lob_sec = df_lob_raw.copy() # Or handle error appropriately
    elif not agg_dict_lob:
        print("Warning: No LOB columns found for aggregation in df_lob_raw.")
        df_lob_sec = pd.DataFrame(index=df_lob_raw.index.unique()) # Empty df with original index
    else:
        df_lob_sec = df_lob_raw.resample('1s').agg(agg_dict_lob)
    
    if 'best_bid_price' in df_lob_sec.columns and 'best_ask_price' in df_lob_sec.columns:
        df_lob_sec.dropna(subset=['best_bid_price', 'best_ask_price'], inplace=True)
    print("Resampled LOB data (1-second frequency, head):\n", df_lob_sec.head())
else:
    print("df_lob_raw is empty. Skipping LOB resampling.")
    df_lob_sec = pd.DataFrame()

if not df_trades_raw.empty:
    agg_dict_trades = {}
    if 'price' in df_trades_raw.columns: agg_dict_trades['price'] = 'last' 
    if 'qty' in df_trades_raw.columns: agg_dict_trades['qty'] = 'sum'
    if 'value_usd' in df_trades_raw.columns: agg_dict_trades['value_usd'] = 'sum'
    
    if df_trades_raw.index.empty or not isinstance(df_trades_raw.index, pd.DatetimeIndex):
        print("Warning: df_trades_raw has no DatetimeIndex. Cannot resample.")
        df_trades_sec = df_trades_raw.copy()
        buy_volume_sec = pd.Series(name='buy_qty_sum', dtype='float64')
        sell_volume_sec = pd.Series(name='sell_qty_sum', dtype='float64')
        buy_value_sec = pd.Series(name='buy_value_sum', dtype='float64')
        sell_value_sec = pd.Series(name='sell_value_sum', dtype='float64')
    elif not agg_dict_trades:
        print("Warning: No trade columns for aggregation in df_trades_raw.")
        df_trades_sec = pd.DataFrame(index=df_trades_raw.index.unique() if isinstance(df_trades_raw.index, pd.DatetimeIndex) else None)
        buy_volume_sec = pd.Series(name='buy_qty_sum', dtype='float64')
        sell_volume_sec = pd.Series(name='sell_qty_sum', dtype='float64')
        buy_value_sec = pd.Series(name='buy_value_sum', dtype='float64')
        sell_value_sec = pd.Series(name='sell_value_sum', dtype='float64')
    else:
        df_trades_sec = df_trades_raw.resample('1s').agg(agg_dict_trades)
        # Ensure 'side' column exists before trying to filter by it
        if 'side' in df_trades_raw.columns:
            buy_volume_sec = df_trades_raw[df_trades_raw['side'].str.lower() == 'buy']['qty'].resample('1s').sum().rename('buy_qty_sum')
            sell_volume_sec = df_trades_raw[df_trades_raw['side'].str.lower() == 'sell']['qty'].resample('1s').sum().rename('sell_qty_sum')
            buy_value_sec = df_trades_raw[df_trades_raw['side'].str.lower() == 'buy']['value_usd'].resample('1s').sum().rename('buy_value_sum')
            sell_value_sec = df_trades_raw[df_trades_raw['side'].str.lower() == 'sell']['value_usd'].resample('1s').sum().rename('sell_value_sum')
        else:
            print("Warning: 'side' column missing in df_trades_raw. Cannot calculate buy/sell specific sums.")
            buy_volume_sec = pd.Series(name='buy_qty_sum', index=df_trades_sec.index, dtype='float64').fillna(0)
            sell_volume_sec = pd.Series(name='sell_qty_sum', index=df_trades_sec.index, dtype='float64').fillna(0)
            buy_value_sec = pd.Series(name='buy_value_sum', index=df_trades_sec.index, dtype='float64').fillna(0)
            sell_value_sec = pd.Series(name='sell_value_sum', index=df_trades_sec.index, dtype='float64').fillna(0)

    df_trades_info_sec = pd.concat([df_trades_sec, buy_volume_sec, sell_volume_sec, buy_value_sec, sell_value_sec], axis=1)
    df_trades_info_sec.fillna({'buy_qty_sum': 0, 'sell_qty_sum': 0, 'buy_value_sum':0, 'sell_value_sum':0}, inplace=True)
    print("Resampled Trades data (1-second frequency, head):\n", df_trades_info_sec.head())
else:
    print("df_trades_raw is empty. Skipping trades resampling.")
    df_trades_info_sec = pd.DataFrame()

df_combined = pd.DataFrame()
if not df_lob_sec.empty and not df_trades_info_sec.empty:
    df_combined = pd.merge(df_lob_sec, df_trades_info_sec, left_index=True, right_index=True, how='outer')
elif not df_lob_sec.empty:
    df_combined = df_lob_sec
elif not df_trades_info_sec.empty:
    df_combined = df_trades_info_sec

if not df_combined.empty:
    if 'best_bid_price' in df_combined.columns:
        lob_cols = df_lob_sec.columns if not df_lob_sec.empty else []
        lob_cols_to_ffill = [col for col in lob_cols if col in df_combined.columns]
        if lob_cols_to_ffill:
            # Forward-fill only: back-filling would copy later quotes into earlier rows (look-ahead bias).
            # Leading rows that remain NaN are dropped below.
            df_combined[lob_cols_to_ffill] = df_combined[lob_cols_to_ffill].ffill()

    trade_qty_val_cols = ['qty', 'value_usd', 'buy_qty_sum', 'sell_qty_sum', 'buy_value_sum', 'sell_value_sum']
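    for col in trade_qty_val_cols:
        if col in df_combined.columns:
            # Assign back instead of chained inplace fillna, which does not update df_combined under Copy-on-Write
            df_combined[col] = df_combined[col].fillna(0)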
    
    if 'best_bid_price' in df_combined.columns and 'best_ask_price' in df_combined.columns:
        df_combined.dropna(subset=['best_bid_price', 'best_ask_price'], inplace=True)
    else:
        # If essential columns are missing after merge, df_combined might not be usable
        print("Warning: Essential LOB columns (best_bid_price, best_ask_price) are missing in df_combined. Further steps might fail.")
        # df_combined = pd.DataFrame() # Or decide to stop / handle differently
    print("Combined and preprocessed data (head):\n", df_combined.head())
else:
    print("df_combined is empty after merging/resampling.")

---
## 4. Feature Engineering
Based on the paper "Mind the Gaps".
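The one-sided VAMP helper in the next cell walks the book from the best level outward until a dollar target is filled and returns the size-weighted price of that fill. A worked sketch with a $150 target instead of `VAMP_LIQUIDITY_CUTOFF` so the numbers stay readable; `depth_weighted_price` is a hypothetical standalone version of that helper, and combining the two sides as a simple average is an assumption of this sketch.

```python
import math


def depth_weighted_price(levels, target_value_usd):
    """Size-weighted average price of the first `target_value_usd` of liquidity on one side.

    levels: [(price, qty), ...] ordered from the best price outward.
    Uses all available depth if the book holds less value than the target.
    """
    value_filled = 0.0
    qty_filled = 0.0
    for price, qty in levels:
        if price <= 0 or qty <= 0:
            continue  # skip empty or invalid levels
        level_value = price * qty
        if value_filled + level_value >= target_value_usd:
            qty_filled += (target_value_usd - value_filled) / price  # partial fill of this level
            value_filled = target_value_usd
            break
        qty_filled += qty
        value_filled += level_value
    return value_filled / qty_filled if qty_filled > 0 else math.nan


bids = [(100.0, 1.0), (99.0, 2.0)]
asks = [(101.0, 1.0), (102.0, 2.0)]
pb = depth_weighted_price(bids, 150.0)  # 1 BTC at 100, then 50/99 BTC at 99 -> 14850/149
pa = depth_weighted_price(asks, 150.0)  # 1 BTC at 101, then 49/102 BTC at 102 -> 15300/151
mid = (bids[0][0] + asks[0][0]) / 2
spread = asks[0][0] - bids[0][0]
vamp = (pb + pa) / 2  # assumed combination: simple average of the two sides
print(round(pb, 4), round(pa, 4), mid, spread, round(vamp, 4))
```

Because the walk reaches past the touch, `pb` sits below the best bid and `pa` above the best ask, so a thin top level moves VAMP less than it moves the plain mid-price.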

In [None]:
# Ensure we have the necessary base columns
if df_combined.empty or not all(col in df_combined.columns for col in ['best_bid_price', 'best_ask_price']):
    print("Skipping feature engineering due to missing base LOB data (best_bid_price, best_ask_price) in df_combined.")
else:
    # --- Mid-Price and Spread ---
    df_combined['mid_price'] = (df_combined['best_bid_price'] + df_combined['best_ask_price']) / 2
    df_combined['spread'] = df_combined['best_ask_price'] - df_combined['best_bid_price']

    # --- Volume-Adjusted Mid-Price (VAMP) ---
    def calculate_weighted_price_for_vamp(levels_prices, levels_qty, target_value_usd):
        """Helper for VAMP: size-weighted average price of the first `target_value_usd`
        of liquidity on one side of the book. If the visible levels hold less value than
        the target, the average over all available levels is returned."""
        cumulative_value = 0
        weighted_price_sum = 0
        total_qty_for_value = 0
        for price, qty in zip(levels_prices, levels_qty):
            # Skip missing or empty levels; the price == 0 check also guards the division below
            if pd.isna(price) or pd.isna(qty) or qty == 0 or price == 0:
                continue
            value_at_level = price * qty
            if cumulative_value + value_at_level >= target_value_usd:
                # Take only the fraction of this level needed to reach the target value
                remaining_value_needed = target_value_usd - cumulative_value
                qty_to_take = remaining_value_needed / price
                weighted_price_sum += price * qty_to_take
                total_qty_for_value += qty_to_take
                cumulative_value += remaining_value_needed
                break
            else:
                weighted_price_sum += price * qty 
                total_qty_for_value += qty
                cumulative_value += value_at_level
        return (weighted_price_sum / total_qty_for_value) if total_qty_for_value > 0 else np.nan

    vamp_bids_p_cols = [f'bid_price_L{i}' for i in range(1, QUOTE_IMBALANCE_LEVELS + 1) if f'bid_price_L{i}' in df_combined.columns]
    vamp_bids_q_cols = [f'bid_qty_L{i}' for i in range(1, QUOTE_IMBALANCE_LEVELS + 1) if f'bid_qty_L{i}' in df_combined.columns]
    vamp_asks_p_cols = [f'ask_price_L{i}' for i in range(1, QUOTE_IMBALANCE_LEVELS + 1) if f'ask_price_L{i}' in df_combined.columns]
    vamp_asks_q_cols = [f'ask_qty_L{i}' for i in range(1, QUOTE_IMBALANCE_LEVELS + 1) if f'ask_qty_L{i}' in df_combined.columns]

    if vamp_bids_p_cols and vamp_bids_q_cols and vamp_asks_p_cols and vamp_asks_q_cols and \
       len(vamp_bids_p_cols) == len(vamp_bids_q_cols) and len(vamp_asks_p_cols) == len(vamp_asks_q_cols):
        
        Pb_vamp_values = []
        Pa_vamp_values = []
        for index, row in df_combined.iterrows():
            bid_prices = [row[col] for col in vamp_bids_p_cols]
            bid_qtys = [row[col] for col in vamp_bids_q_cols]
            ask_prices = [row[col] for col in vamp_asks_p_cols]
            ask_qtys = [row[col] for col in vamp_asks_q_cols]
            
            Pb_vamp_values.append(calculate_weighted_price_for_vamp(bid_prices, bid_qtys, VAMP_LIQUIDITY_CUTOFF))
            Pa_vamp_values.append(calculate_weighted_price_for_vamp(ask_prices, ask_qtys, VAMP_LIQUIDITY_CUTOFF))
            
        df_combined['Pb_vamp'] = Pb_vamp_values
        df_combined['Pa_vamp'] = Pa_vamp_values
        df_combined['vamp'] = (df_combined['Pb_vamp'] + df_combined['Pa_vamp']) / 2
        df_combined['vamp_mid_diff'] = df_combined['mid_price'] - df_combined['vamp']
    else:
        df_combined['vamp'] = df_combined['mid_price']
        df_combined['vamp_mid_diff'] = 0
        print("Warning: VAMP calculation could not be performed due to missing/mismatched deep LOB data. Using mid_price as fallback.")

    # --- Trade Imbalance (TI) ---
    if 'buy_value_sum' in df_combined.columns and 'sell_value_sum' in df_combined.columns:
        rolling_window_size_ti = int(pd.Timedelta(TRADE_IMBALANCE_WINDOW).total_seconds())
        diff_value = df_combined['buy_value_sum'] - df_combined['sell_value_sum']
        total_value = df_combined['buy_value_sum'] + df_combined['sell_value_sum']
        numerator_ti = diff_value.rolling(window=rolling_window_size_ti, min_periods=max(1, int(rolling_window_size_ti*0.1))).sum() # Ensure min_periods is at least 1
        denominator_ti = total_value.rolling(window=rolling_window_size_ti, min_periods=max(1, int(rolling_window_size_ti*0.1))).sum()
        df_combined['trade_imbalance'] = (numerator_ti / denominator_ti.replace(0, np.nan)).fillna(0) 
        df_combined['trade_imbalance'] = np.clip(df_combined['trade_imbalance'], -1, 1)
    else:
        df_combined['trade_imbalance'] = 0
        print("Warning: Trade Imbalance calculation skipped due to missing trade value columns.")

    # --- Quote Imbalance (QI) ---
    sum_bid_qty_L = pd.Series(0.0, index=df_combined.index)
    sum_ask_qty_L = pd.Series(0.0, index=df_combined.index)
    qi_cols_found = False
    for i in range(1, QUOTE_IMBALANCE_LEVELS + 1):
        bid_col_name = f'bid_qty_L{i}'
        ask_col_name = f'ask_qty_L{i}'
        if bid_col_name in df_combined.columns and ask_col_name in df_combined.columns:
            sum_bid_qty_L += df_combined[bid_col_name].fillna(0)
            sum_ask_qty_L += df_combined[ask_col_name].fillna(0)
            qi_cols_found = True
            
    if qi_cols_found:
        numerator_qi = sum_bid_qty_L - sum_ask_qty_L
        denominator_qi = sum_bid_qty_L + sum_ask_qty_L
        df_combined['quote_imbalance'] = (numerator_qi / denominator_qi.replace(0, np.nan)).fillna(0)
        df_combined['quote_imbalance'] = np.clip(df_combined['quote_imbalance'], -1, 1)
    else:
        df_combined['quote_imbalance'] = 0
        print("Warning: Quote Imbalance calculation skipped due to missing LOB quantity columns.")

    print("Data with engineered features (head):\n", df_combined[['mid_price', 'spread', 'vamp', 'vamp_mid_diff', 'trade_imbalance', 'quote_imbalance']].head())


---
## 5. Target Variable Creation
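We want to predict future price changes. The paper evaluates look-ahead windows from 1 s to 60 s [16]; this notebook uses a single horizon $\Delta t$ = `PREDICTION_HORIZON_SECONDS`.
Regression target: $\text{MidPrice}_{t+\Delta t} - \text{MidPrice}_t$ (absolute change); a percentage change is an alternative.
Classification target: the direction of that change (0 = down, 1 = neutral, 2 = up), where moves smaller than 10% of the mean spread count as neutral.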
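The paper's multiclass setup uses one-standard-deviation thresholds [171] rather than a fraction of the spread. A minimal sketch of that labelling follows, assuming a regular 1-second index; `label_direction` and `k_std` are hypothetical names. It estimates the standard deviation over the whole series for simplicity; in a proper backtest the threshold should come from the training period only, to avoid look-ahead.

```python
import numpy as np
import pandas as pd


def label_direction(mid_price: pd.Series, horizon: int, k_std: float = 1.0) -> pd.Series:
    """3-class label: 0 = down, 1 = neutral, 2 = up (same coding as the cell below).

    Assumption: the neutral band is +/- k_std standard deviations of the
    horizon-ahead mid-price change, estimated over the whole series.
    """
    price_diff = mid_price.shift(-horizon) - mid_price  # NaN for the last `horizon` rows
    threshold = k_std * price_diff.std()                # std() skips NaN by default
    labels = pd.Series(np.nan, index=mid_price.index)
    valid = price_diff.notna()
    labels.loc[valid] = 1.0
    labels.loc[valid & (price_diff > threshold)] = 2.0
    labels.loc[valid & (price_diff < -threshold)] = 0.0
    return labels  # rows without a future price stay NaN so they can be dropped


print(label_direction(pd.Series([100.0, 101.0, 100.0, 100.0, 103.0, 100.0]), horizon=1).tolist())
```

Leaving rows without a future price as NaN, instead of defaulting them to "neutral", lets a plain `dropna` remove them.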

In [None]:
if 'mid_price' not in df_combined.columns or df_combined.empty:
    print("Skipping target variable creation as mid_price is missing or df_combined is empty.")
    df_final = pd.DataFrame() # Ensure df_final exists
else:
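    # Shifting by N rows means N seconds ahead only on a gap-free 1-second grid;
    # rows dropped earlier (e.g. missing best bid/ask) break this alignment
    df_combined['future_mid_price'] = df_combined['mid_price'].shift(-PREDICTION_HORIZON_SECONDS)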
    
    if TARGET_TYPE == 'regression':
        df_combined['target_price_change'] = df_combined['future_mid_price'] - df_combined['mid_price']
        TARGET_COLUMN = 'target_price_change'
    elif TARGET_TYPE == 'classification':
        price_diff = df_combined['future_mid_price'] - df_combined['mid_price']
        # Ensure spread is available and not all NaNs before calculating mean
        if 'spread' in df_combined.columns and not df_combined['spread'].isnull().all():
            neutral_threshold = df_combined['spread'].mean() * 0.1 
        else:
            neutral_threshold = 0.001 # Fallback if spread is not available
            print(f"Warning: 'spread' column missing or all NaN. Using fixed neutral_threshold: {neutral_threshold}")
        df_combined['target_direction'] = 1 # Neutral
        df_combined.loc[price_diff > neutral_threshold, 'target_direction'] = 2 # Up
        df_combined.loc[price_diff < -neutral_threshold, 'target_direction'] = 0 # Down
        TARGET_COLUMN = 'target_direction'

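    # Drop rows without a future price as well: for classification the default label 1
    # would otherwise mark the last PREDICTION_HORIZON_SECONDS rows as "neutral"
    df_final = df_combined.dropna(subset=['future_mid_price', TARGET_COLUMN])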
    
    if not df_final.empty:
        print(f"Final data with target variable '{TARGET_COLUMN}' (head):\n", df_final[['mid_price', 'future_mid_price', TARGET_COLUMN] + [col for col in ['vamp_mid_diff', 'trade_imbalance', 'quote_imbalance'] if col in df_final.columns]].head())
        print(f"\nTarget variable ({TARGET_COLUMN}) distribution:")
        if TARGET_TYPE == 'regression':
            df_final[TARGET_COLUMN].plot(kind='hist', bins=50, title='Target Price Change Distribution')
            plt.show()
            print(df_final[TARGET_COLUMN].describe())
        elif TARGET_TYPE == 'classification':
            print(df_final[TARGET_COLUMN].value_counts(normalize=True))
    else:
        print("DataFrame is empty after target creation and NaN removal. Check data or prediction horizon.")

---
## 6. Stationarity Checks and Transformations
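Time series models generally assume stationary inputs. Raw prices behave roughly like random walks, whereas price changes (the regression target here) are usually stationary.
The cell below applies the Augmented Dickey-Fuller (ADF) test to the engineered features `spread`, `vamp_mid_diff`, `trade_imbalance`, and `quote_imbalance`; a feature that fails (p-value > 0.05) is differenced once, retested, and kept only if the differenced series passes.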

In [None]:
def check_stationarity(series, series_name=''):
    """Performs ADF test and prints results."""
    if not isinstance(series, pd.Series) or series.empty or series.isnull().all():
        print(f"Series {series_name} is not a valid Series, is empty, or all NaN. Skipping stationarity check.")
        return False 
    print(f'\nStationarity Test for {series_name}:')
    try:
        # Ensure data is float for adfuller, and dropna
        series_cleaned = series.dropna().astype(float)
        if series_cleaned.empty:
            print(f"Series {series_name} is empty after dropna. Skipping stationarity check.")
            return False
        result = adfuller(series_cleaned) 
        print('ADF Statistic: %f' % result[0])
        print('p-value: %f' % result[1])
        if result[1] <= 0.05:
            print(f"Result: Likely Stationary (p-value <= 0.05) for {series_name}")
            return True
        else:
            print(f"Result: Likely Non-Stationary (p-value > 0.05) for {series_name}")
            return False
    except Exception as e:
        print(f"Error during stationarity test for {series_name}: {e}")
        return False

feature_columns_for_model = []
df_model_ready = pd.DataFrame() 

if 'df_final' in locals() and not df_final.empty:
    potential_features = ['spread', 'vamp_mid_diff', 'trade_imbalance', 'quote_imbalance']
    df_transformed = df_final.copy()

    for col in potential_features:
        if col in df_transformed.columns:
            if not check_stationarity(df_transformed[col], col):
                print(f"Feature {col} is non-stationary. Applying differencing.")
                df_transformed[f'{col}_diff'] = df_transformed[col].diff()
                # Check stationarity of the differenced series
                if check_stationarity(df_transformed[f'{col}_diff'].dropna(), f'{col}_diff'):
                    feature_columns_for_model.append(f'{col}_diff')
                else:
                    print(f"Differenced feature {col}_diff is still non-stationary. Consider further transformation or excluding.")
                    # Optionally, still add the differenced feature if you want to proceed with caution
                    # feature_columns_for_model.append(f'{col}_diff') 
            else:
                feature_columns_for_model.append(col)
    
    if TARGET_TYPE == 'regression' and TARGET_COLUMN in df_transformed.columns:
        check_stationarity(df_transformed[TARGET_COLUMN], TARGET_COLUMN)
    
    if feature_columns_for_model and TARGET_COLUMN in df_transformed.columns:
        # Ensure all selected feature columns and the target column exist before creating df_model_ready
        final_cols_to_select = [f for f in feature_columns_for_model if f in df_transformed.columns] + [TARGET_COLUMN]
        df_model_ready = df_transformed[final_cols_to_select].copy()
        df_model_ready.dropna(inplace=True)
    else:
        print("Warning: No features selected or target column missing after stationarity check.")

    if df_model_ready.empty:
        print("DataFrame for modeling (df_model_ready) is empty after dropping NaNs. Check data, feature engineering, and stationarity steps.")
    else:
        print(f"\nFeatures selected for modeling: {feature_columns_for_model}")
        print(f"Shape of df_model_ready: {df_model_ready.shape}")
else:
    print("Skipping stationarity checks as df_final is not available or empty.")

---
## 7. Model Building and Evaluation
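The cell below splits the data 80/20 in time order. Because the target at row $t$ is built from the mid-price `PREDICTION_HORIZON_SECONDS` rows later, the last training labels are computed from prices inside the test window. A minimal sketch of a purged split that drops those rows follows; `purged_chronological_split` is a hypothetical helper, and it assumes roughly one row per second after NaN removal. For cross-validation, scikit-learn's `TimeSeriesSplit` offers a `gap` parameter for the same purpose.

```python
import numpy as np


def purged_chronological_split(n_rows: int, train_frac: float, horizon: int):
    """Return (train_idx, test_idx) positional indices for time-ordered data.

    The last `horizon` rows before the split point are excluded from training,
    because their labels (mid-price `horizon` rows ahead) use test-period prices.
    """
    split_idx = int(n_rows * train_frac)
    train_end = max(0, split_idx - horizon)  # exclusive end of the training block
    train_idx = np.arange(0, train_end)
    test_idx = np.arange(split_idx, n_rows)
    return train_idx, test_idx


train_idx, test_idx = purged_chronological_split(100, 0.5, horizon=5)
print(train_idx[-1], test_idx[0])
```

Usage against the notebook's data would be `X.iloc[train_idx]` / `X.iloc[test_idx]` with `horizon=PREDICTION_HORIZON_SECONDS`.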

In [None]:
if 'df_model_ready' not in locals() or df_model_ready.empty or not feature_columns_for_model:
    print("Skipping model building as data is not ready or no features are selected.")
else:
    X = df_model_ready[feature_columns_for_model]
    y = df_model_ready[TARGET_COLUMN]

    if X.empty or y.empty:
        print("X or y is empty. Cannot proceed with model training.")
    else:
        train_size_pct = 0.8
        split_idx = int(len(X) * train_size_pct)
        
        X_train, X_test = X.iloc[:split_idx], X.iloc[split_idx:]
        y_train, y_test = y.iloc[:split_idx], y.iloc[split_idx:]

        if X_train.empty or X_test.empty:
            print("Training or testing set is empty after split. Check data size and split point.")
        else:
            print(f"Training set size: {X_train.shape[0]}, Test set size: {X_test.shape[0]}")
            scaler = StandardScaler()
            X_train_scaled = scaler.fit_transform(X_train)
            X_test_scaled = scaler.transform(X_test)
            
            model_results = {}

            def evaluate_model(name, model, X_test_data, y_true, y_pred):
                if TARGET_TYPE == 'regression':
                    mse = mean_squared_error(y_true, y_pred)
                    r2 = r2_score(y_true, y_pred)
                    print(f"{name} - MSE: {mse:.4f}, R2: {r2:.4f}")
                    model_results[name] = {'MSE': mse, 'R2': r2}
                    plt.figure(figsize=(10, 6))
                    plt.scatter(y_true, y_pred, alpha=0.5, label='Predicted vs Actual')
                    min_val = min(y_true.min(), y_pred.min()) if not y_true.empty and not pd.Series(y_pred).empty else 0
                    max_val = max(y_true.max(), y_pred.max()) if not y_true.empty and not pd.Series(y_pred).empty else 1
                    plt.plot([min_val, max_val], [min_val, max_val], 'k--', lw=2, label='Perfect Prediction')
                    plt.xlabel("Actual Values")
                    plt.ylabel("Predicted Values")
                    plt.title(f"{name} - Predictions vs Actuals")
                    plt.legend()
                    plt.show()
                elif TARGET_TYPE == 'classification':
                    accuracy = accuracy_score(y_true, y_pred)
                    precision = precision_score(y_true, y_pred, average='weighted', zero_division=0)
                    recall = recall_score(y_true, y_pred, average='weighted', zero_division=0)
                    f1 = f1_score(y_true, y_pred, average='weighted', zero_division=0)
                    print(f"{name} - Accuracy: {accuracy:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}")
                    model_results[name] = {'Accuracy': accuracy, 'Precision': precision, 'Recall': recall, 'F1': f1}

            # --- Model Implementations ---
            if TARGET_TYPE == 'regression':
                print("\n--- Linear Regression ---")
                model_lr = LinearRegression()
                model_lr.fit(X_train_scaled, y_train)
                y_pred_lr = model_lr.predict(X_test_scaled)
                evaluate_model("Linear Regression", model_lr, X_test_scaled, y_test, y_pred_lr)
            elif TARGET_TYPE == 'classification':
                print("\n--- Logistic Regression ---")
                # Default lbfgs solver fits a multinomial model for multiclass targets
                model_logr = LogisticRegression(random_state=42, max_iter=1000)
                model_logr.fit(X_train_scaled, y_train)
                y_pred_logr = model_logr.predict(X_test_scaled)
                evaluate_model("Logistic Regression", model_logr, X_test_scaled, y_test, y_pred_logr)

            print("\n--- Decision Tree ---")
            if TARGET_TYPE == 'regression':
                model_dt = DecisionTreeRegressor(random_state=42, max_depth=10, min_samples_split=10)
                model_dt.fit(X_train_scaled, y_train)
                y_pred_dt = model_dt.predict(X_test_scaled)
                evaluate_model("Decision Tree Regressor", model_dt, X_test_scaled, y_test, y_pred_dt)
            elif TARGET_TYPE == 'classification':
                model_dtc = DecisionTreeClassifier(random_state=42, max_depth=10, min_samples_split=10)
                model_dtc.fit(X_train_scaled, y_train)
                y_pred_dtc = model_dtc.predict(X_test_scaled)
                evaluate_model("Decision Tree Classifier", model_dtc, X_test_scaled, y_test, y_pred_dtc)

            print("\n--- Random Forest ---")
            if TARGET_TYPE == 'regression':
                model_rf = RandomForestRegressor(n_estimators=100, random_state=42, max_depth=10, min_samples_split=10, n_jobs=-1)
                model_rf.fit(X_train_scaled, y_train)
                y_pred_rf = model_rf.predict(X_test_scaled)
                evaluate_model("Random Forest Regressor", model_rf, X_test_scaled, y_test, y_pred_rf)
            elif TARGET_TYPE == 'classification':
                model_rfc = RandomForestClassifier(n_estimators=100, random_state=42, max_depth=10, min_samples_split=10, n_jobs=-1)
                model_rfc.fit(X_train_scaled, y_train)
                y_pred_rfc = model_rfc.predict(X_test_scaled)
                evaluate_model("Random Forest Classifier", model_rfc, X_test_scaled, y_test, y_pred_rfc)

            print("\n--- XGBoost ---")
            if TARGET_TYPE == 'regression':
                model_xgb_reg = xgb.XGBRegressor(objective='reg:squarederror', n_estimators=100, random_state=42, max_depth=7, learning_rate=0.1, n_jobs=-1)
                model_xgb_reg.fit(X_train_scaled, y_train)
                y_pred_xgb_reg = model_xgb_reg.predict(X_test_scaled)
                evaluate_model("XGBoost Regressor", model_xgb_reg, X_test_scaled, y_test, y_pred_xgb_reg)
            elif TARGET_TYPE == 'classification':
                num_class_xgb = len(np.unique(y_train)) 
                objective_xgb_clf = 'multi:softmax' if num_class_xgb > 2 else 'binary:logistic'
                model_xgbc_params = {'n_estimators': 100, 'random_state': 42, 'max_depth': 7, 'learning_rate': 0.1, 'n_jobs': -1, 'objective': objective_xgb_clf}
                if objective_xgb_clf == 'multi:softmax': model_xgbc_params['num_class'] = num_class_xgb
                model_xgbc = xgb.XGBClassifier(**model_xgbc_params)
                model_xgbc.fit(X_train_scaled, y_train)
                y_pred_xgbc = model_xgbc.predict(X_test_scaled)
                evaluate_model("XGBoost Classifier", model_xgbc, X_test_scaled, y_test, y_pred_xgbc)
                
            print("\n--- Support Vector Machine (SVM) ---")
            if TARGET_TYPE == 'regression':
                model_svr = SVR(kernel='rbf', C=1.0, epsilon=0.1)
                model_svr.fit(X_train_scaled, y_train)
                y_pred_svr = model_svr.predict(X_test_scaled)
                evaluate_model("SVR", model_svr, X_test_scaled, y_test, y_pred_svr)
            elif TARGET_TYPE == 'classification':
                model_svc = SVC(kernel='rbf', C=1.0, random_state=42, probability=True)
                model_svc.fit(X_train_scaled, y_train)
                y_pred_svc = model_svc.predict(X_test_scaled)
                evaluate_model("SVC", model_svc, X_test_scaled, y_test, y_pred_svc)

            print("\n--- Neural Network (MLP) ---")
            def create_mlp(input_dim, num_unique_targets=1, classification=False):
                model = Sequential()
                model.add(Dense(64, input_dim=input_dim, activation='relu'))
                model.add(Dropout(0.2))
                model.add(Dense(32, activation='relu'))
                model.add(Dropout(0.2))
                if classification:
                    if num_unique_targets <= 2:
                        model.add(Dense(1, activation='sigmoid'))
                        model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
                    else:
                        model.add(Dense(num_unique_targets, activation='softmax'))
                        model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
                else: 
                    model.add(Dense(1, activation='linear'))
                    model.compile(optimizer='adam', loss='mean_squared_error', metrics=['mae'])
                return model

            if TARGET_TYPE == 'regression':
                model_mlp_reg = create_mlp(X_train_scaled.shape[1], classification=False)
                early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
                model_mlp_reg.fit(X_train_scaled, y_train, epochs=50, batch_size=32, validation_split=0.1, callbacks=[early_stop], verbose=0)
                y_pred_mlp_reg = model_mlp_reg.predict(X_test_scaled).flatten()
                evaluate_model("MLP Regressor", model_mlp_reg, X_test_scaled, y_test, y_pred_mlp_reg)
            elif TARGET_TYPE == 'classification':
                num_unique_targets_nn = len(np.unique(y_train))
                model_mlpc = create_mlp(X_train_scaled.shape[1], num_unique_targets=num_unique_targets_nn, classification=True)
                early_stop = EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
                model_mlpc.fit(X_train_scaled, y_train, epochs=50, batch_size=32, validation_split=0.1, callbacks=[early_stop], verbose=0)
                y_pred_mlpc_proba = model_mlpc.predict(X_test_scaled)
                if num_unique_targets_nn <= 2: 
                    y_pred_mlpc = (y_pred_mlpc_proba > 0.5).astype(int).flatten()
                else: 
                    y_pred_mlpc = np.argmax(y_pred_mlpc_proba, axis=1)
                evaluate_model("MLP Classifier", model_mlpc, X_test_scaled, y_test, y_pred_mlpc)

            print("\n--- Model Performance Summary ---")
            results_df = pd.DataFrame(model_results).T
            print(results_df)

---
## 8. Conclusion and Future Work

This notebook provided a framework for acquiring, preprocessing, and modeling cryptocurrency order book data for short-term price prediction. Key features inspired by "Mind the Gaps" such as Mid-Price, Spread, VAMP, Trade Imbalance, and Quote Imbalance were implemented.

**Observations from this run (based on sample data/placeholder logic):**
* (Actual observations will depend on the real data and model performance)
* The VAMP feature was noted in the paper as a strong predictor. Its effectiveness would depend on the quality and depth of LOB data used.
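* Trade Imbalance captures buy/sell pressure in executed trades, while Quote Imbalance captures the imbalance of resting liquidity across the top `QUOTE_IMBALANCE_LEVELS` book levels.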
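The TI computed in Section 4 is an unweighted rolling ratio, whereas the paper's definition applies a linear weight [56, 57, 58]. The exact weighting scheme is not reproduced here; the sketch below assumes weights that decay linearly from 1 for the current second to 1/W for the oldest second of a W-second window, and `linear_weighted_trade_imbalance` is a hypothetical helper.

```python
import numpy as np
import pandas as pd


def linear_weighted_trade_imbalance(buy_value: pd.Series, sell_value: pd.Series, window: int) -> pd.Series:
    """Causal, linearly weighted TI on a 1-second grid, clipped to [-1, 1].

    Assumption: weight 1 for the current second, (W-1)/W for one second ago, ..., 1/W for the oldest.
    """
    weights = np.arange(window, 0, -1) / window  # weights[k] applies to the value k seconds ago
    diff = (buy_value - sell_value).to_numpy(dtype=float)
    total = (buy_value + sell_value).to_numpy(dtype=float)
    # Full-mode np.convolve gives out[t] = sum_k weights[k] * x[t - k]; keeping the first
    # len(x) entries yields a causal weighted sum (partial windows at the start)
    num = np.convolve(diff, weights)[: len(diff)]
    den = np.convolve(total, weights)[: len(total)]
    ti = np.divide(num, den, out=np.zeros_like(num), where=den > 0)  # 0 when no volume traded
    return pd.Series(np.clip(ti, -1, 1), index=buy_value.index)


buy = pd.Series([10.0, 0.0, 0.0])
sell = pd.Series([0.0, 0.0, 10.0])
print(linear_weighted_trade_imbalance(buy, sell, window=3).round(3).tolist())
```

Applied to the notebook's data, the inputs would be `buy_value_sum` and `sell_value_sum` with `window` derived from `TRADE_IMBALANCE_WINDOW`.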

**Future Work:**
* **Robust Data Pipeline**: Implement a more robust data acquisition pipeline, especially for live API data (e.g., using WebSockets for continuous LOB updates and reconstruction).
* **Advanced Feature Engineering**:
    * Explore more sophisticated weighting for Trade Imbalance.
    * Test different VAMP liquidity cutoffs and QI levels systematically.
    * Incorporate features like realized volatility, order flow toxicity, or market impact models.
* **Hyperparameter Tuning**: Systematically tune hyperparameters for each ML model (e.g., using GridSearchCV or RandomizedSearchCV with TimeSeriesSplit).
* **Stationarity**: Rigorously ensure all features used in models are stationary. Apply transformations like differencing if needed and re-evaluate.
* **Model Ensembling/Stacking**: Combine predictions from multiple models to potentially improve performance.
* **Deeper Neural Networks**: Explore more complex architectures like LSTMs or GRUs, which are well-suited for time series data.
* **Alternative Prediction Targets**: Expand on the binary and multiclass classification approaches from the paper, especially predicting one-standard-deviation price movements.
* **Backtesting Framework**: Develop a rigorous backtesting framework that accounts for transaction costs, slippage, and realistic trading conditions. The P&L metric in the paper is a good starting point.
* **Expand Dataset**: Analyze data across different crypto assets and exchanges, and longer time periods, including diverse market conditions (e.g., high volatility periods).
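Testing VAMP liquidity cutoffs systematically is slow with the per-row `iterrows` loop in Section 4. A vectorized sketch for one book side follows; it is intended to reproduce the semantics of `calculate_weighted_price_for_vamp` (skip NaN or zero levels, take the last level partially, fall back to all available liquidity), but `vamp_side_vectorized` is a hypothetical helper and its agreement with the loop should be checked on real data before relying on it.

```python
import numpy as np


def vamp_side_vectorized(prices, qtys, cutoff_usd: float) -> np.ndarray:
    """VWAP of the liquidity needed to fill `cutoff_usd` of notional, per row.

    prices, qtys: arrays of shape (n_rows, n_levels), level 1 first.
    Uses all available liquidity if the side is thinner; NaN if the side is empty.
    """
    prices = np.asarray(prices, dtype=float)
    qtys = np.asarray(qtys, dtype=float)
    valid = np.isfinite(prices) & np.isfinite(qtys) & (prices != 0) & (qtys != 0)
    notional = np.where(valid, prices * qtys, 0.0)
    prev_cum = np.cumsum(notional, axis=1) - notional          # notional on earlier levels
    take_value = np.clip(cutoff_usd - prev_cum, 0.0, notional)  # USD taken from each level
    safe_prices = np.where(valid, prices, 1.0)                  # avoid dividing by 0/NaN on skipped levels
    take_qty = np.where(valid, take_value / safe_prices, 0.0)
    total_qty = take_qty.sum(axis=1)
    total_value = take_value.sum(axis=1)
    with np.errstate(invalid="ignore", divide="ignore"):
        return np.where(total_qty > 0, total_value / total_qty, np.nan)


p = np.array([[100.0, 99.0, 98.0], [100.0, np.nan, 98.0], [np.nan, np.nan, np.nan]])
q = np.array([[1.0, 2.0, 10.0], [1.0, 5.0, 1.0], [1.0, 1.0, 1.0]])
print(vamp_side_vectorized(p, q, 250.0))
```

On the notebook's data, one side would be computed as `vamp_side_vectorized(df_combined[vamp_bids_p_cols].to_numpy(), df_combined[vamp_bids_q_cols].to_numpy(), VAMP_LIQUIDITY_CUTOFF)`, which makes looping over several cutoffs cheap.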

---
### References from "Mind the Gaps" used in this notebook:
- [1] Martin, P., Line Jr., W., Feng, Y., Yang, Y., Zheng, S., Qi, S., & Zhu, B. (2022). *Mind the Gaps: Short-term Crypto Price Prediction*. Cornell University. Available at SSRN: https://ssrn.com/abstract=4351947
- [16] Prediction at time scales from one second to 60 seconds.
- [18] Volume-Adjusted Mid-Price as the ultimate short-term predictor.
- [19, 20] Data sourcing: Bitstamp, three full months, tick level.
- [21] Initial feature calculation: spread, mid-price, best bid/ask, volume-adjusted versions.
- [25, 26] Condensing dataset to full seconds.
- [48, 49] Volume-Adjusted Mid-Price (VAMP) definition and formula.
- [52] Plotting (mid-price - VAMP) against returns.
- [54, 55, 159] VAMP volume cutoffs, settling on $50k-$60k range, specifically $60k.
- [56, 57, 58] Trade Imbalance (TI) definition, formula with linear weight, range -1 to 1.
- [72] Using 1-minute window for Trade Imbalance.
- [88, 89, 94] Quote Imbalance (QI) definition, formula, range -1 to 1, using up to level 5.
- [93] QI relationship becoming more linear with deeper levels.
- [128] Trading P&L metric introduction.
- [152] Binary classification setup: strict inequalities for price change prediction.
- [171] Multiclass classification setup: one standard deviation thresholds.
- [195] Expanding data to include diverse BTC data and volatile conditions.

Note: Citation numbers in the markdown cells (e.g., `[cite: X]`) refer to page numbers or specific findings in the provided PDF "Mind-the-Gaps-Short-term-Crypto-Price-Prediction-2022.pdf".