In [1]:
!curl ipinfo.io

{
  "ip": "202.189.108.190",
  "city": "Hong Kong",
  "region": "Hong Kong",
  "country": "HK",
  "loc": "22.2783,114.1747",
  "org": "AS4528 The University of Hong Kong",
  "postal": "999077",
  "timezone": "Asia/Hong_Kong",
  "readme": "https://ipinfo.io/missingauth"
}

In [2]:
import json
import time
import pandas as pd
import websocket
import threading
from datetime import datetime as dt, timedelta
import os
import traceback
import numpy as np
import pandas as pd

In [3]:

class GrossmanMillerModel:
    def __init__(self, num_mm, gamma, sigma_squared):
        """
        Initializes the Grossman-Miller model.

        Args:
            num_mm (int): Number of market makers.
            gamma (float): Risk aversion parameter.
            sigma_squared (float): Variance of the asset price shock.
        """
        self.num_mm = num_mm
        self.gamma = gamma
        self.sigma_squared = sigma_squared
        self.mu = 0  # Expected asset value (can be dynamic in a more complex model)

    def calculate_price_t1(self, i):
        """
        Calculates the equilibrium price at t=1.

        Args:
            i (float): Liquidity trader's desired trade (positive for sell, negative for buy).

        Returns:
            float: Equilibrium price at t=1.
        """
        return self.mu - self.gamma * self.sigma_squared * (i / (self.num_mm + 1))

    def calculate_quantity_t1(self, i):
        """
        Calculates the quantity of asset held by each MM and LT1 at t=1.

        Args:
            i (float): Liquidity trader's desired trade.

        Returns:
            float: Quantity held by each agent at t=1.
        """
        return i / (self.num_mm + 1)

    def calculate_price_impact(self, i):
        """
        Calculates the price impact (lambda) and the actual trade quantity of LT1.

        Args:
            i (float): Liquidity trader's desired trade.

        Returns:
            tuple: A tuple containing lambda and the trade quantity of LT1.
        """

        trade_quantity_lt1 = i * self.num_mm / (self.num_mm + 1)
        price_impact = -(1 / self.num_mm) * self.gamma * self.sigma_squared
        return price_impact, trade_quantity_lt1


    def run_simulation(self, trades):
        """
        Runs a simulation of the Grossman-Miller model over a series of trades.

        Args:
            trades (list): A list of liquidity trader trades (i values).

        Returns:
            pd.DataFrame: A DataFrame containing the simulation results.
        """
        prices_t1 = []
        quantities_t1 = []
        price_impacts = []
        trade_quantities = []

        for i in trades:
            price_t1 = self.calculate_price_t1(i)
            quantity_t1 = self.calculate_quantity_t1(i)
            price_impact, trade_quantity_lt1 = self.calculate_price_impact(i)

            prices_t1.append(price_t1)
            quantities_t1.append(quantity_t1)
            price_impacts.append(price_impact)
            trade_quantities.append(trade_quantity_lt1)

        results = pd.DataFrame({
            "trade": trades,
            "price_t1": prices_t1,
            "quantity_t1": quantities_t1,
            "price_impact": price_impacts,
            "trade_quantity_lt1": trade_quantities
        })
        return results

In [12]:
class MarketDataCollector:
    def __init__(self, data_path='crypto_data'):
        self.combined_data = []
        self.data_path = data_path
        self.running = True

        if not os.path.exists(data_path):
            os.makedirs(data_path)
            print(f"Created directory: {data_path}")
        else:
            print(f"Using existing directory: {data_path}")

    def save_data(self, df, filename):
        try:
            full_path = os.path.join(self.data_path, filename)
            print(f"Saving data to {full_path}...")
            df.to_csv(full_path, index=False)  #
            print(f"Successfully saved {len(df)} records")
        except Exception as e:
            print(f"Save failed: {str(e)[:200]}")


    def _ws_handler(self, message, data_type):
        try:
            data = json.loads(message)
            timestamp = dt.fromtimestamp(data['E']/1000).strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]

            # Add safety checks for order book data
            bid_price = None
            ask_price = None
            if data_type == 'depth':
                if len(data.get('b', [])) > 0:
                    bid_price = float(data['b'][0][0])
                if len(data.get('a', [])) > 0:
                    ask_price = float(data['a'][0][0])

            record = {
                'timestamp': timestamp,
                'bid_price': bid_price,
                'ask_price': ask_price,
                'trade_price': float(data['p']) if data_type == 'trade' else None,
                'volume': float(data['q'])*float(data['p']) if data_type == 'trade' else None
            }
            self.combined_data.append(record)
        except Exception as e:
            print(f"Processing error: {str(e)[:200]}")


    def on_order_book_message(self, ws, message):
        self._ws_handler(message, 'depth')

    def on_trade_message(self, ws, message):
        self._ws_handler(message, 'trade')


    def _run_websocket(self, url, handler, stream_type):
        end_time = time.time() + (3600 * 24)  # Fail-safe timeout
        while time.time() < end_time and self.running:
            try:
                ws = websocket.WebSocketApp(
                    url,
                    on_message=handler,
                    on_error=lambda ws, e: print(f"{stream_type} error: {str(e)[:200]}"),
                    on_close=lambda ws: print(f"{stream_type} closed"),
                    on_open=lambda ws: print(f"{stream_type} connected")
                )
                ws.run_forever(ping_interval=30, ping_timeout=10)
            except Exception as e:
                print(f"{stream_type} failure: {str(e)[:200]}")
            time.sleep(5)

    def collect_long_duration_data(self, symbol, duration_hours=0.1, checkpoint_minutes=1):

        total_seconds = duration_hours * 3600
        checkpoint_seconds = checkpoint_minutes * 60
        start_time = time.time()
        end_time = start_time + total_seconds
        last_checkpoint = start_time
        checkpoint_count = 0

        print(f"\n{'='*40}\nStarting {duration_hours}-hour collection for {symbol}")
        print(f"Checkpoints every {checkpoint_minutes} mins | Target end: {dt.fromtimestamp(end_time)}\n{'='*40}")


        threads = [
            threading.Thread(target=self._run_websocket,
                args=(f"wss://stream.binance.com:9443/ws/{symbol}@depth@100ms",
                    self.on_order_book_message, "OrderBook")),
            threading.Thread(target=self._run_websocket,
                args=(f"wss://stream.binance.com:9443/ws/{symbol}@trade",
                    self.on_trade_message, "Trades"))
        ]

        for t in threads:
            t.daemon = True
            t.start()

        alive_threads = sum(1 for t in threads if t.is_alive())
        print(f"Active connections: {alive_threads}/2 | Buffer size: {len(self.combined_data)}")

        try:
            while time.time() < end_time and self.running:
                # !IMPROVEMENT: Adaptive sleep management
                remaining = end_time - time.time()
                sleep_time = max(0, min(1, remaining))
                time.sleep(sleep_time)

                # Checkpoint handling
                if time.time() - last_checkpoint >= checkpoint_seconds:
                    checkpoint_count += 1
                    last_checkpoint = time.time()
                    self._process_checkpoint(symbol, checkpoint_count)

                if int(time.time() - start_time) % 10 == 0:
                    elapsed = time.time() - start_time
                    progress = min(100, (elapsed / total_seconds) * 100)
                    print(f"Progress: {progress:.1f}% | Records: {len(self.combined_data)}")

        except KeyboardInterrupt:
            print("\nUser requested shutdown...")
        finally:
            self.running = False
            self._final_save(symbol, duration_hours)
            print("\nCollection completed" if time.time() >= end_time else "\nCollection aborted")

    def _process_checkpoint(self, symbol, count):
        try:
            print(f"\n{'='*20} Checkpoint {count} {'='*20}")
            original_count = len(self.combined_data)

            # Process COPY of data
            temp_data = self.combined_data.copy()
            df = self.process_data()

            if not df.empty:
                # Only clear original data AFTER successful processing
                self.combined_data = self.combined_data[original_count:]  # Keep unprocessed data
                timestamp = dt.now().strftime("%Y%m%d_%H%M%S")
                filename = f"{symbol}_checkpoint_{count}_{timestamp}.csv"
                self.save_data(df, filename)

            # Memory management (preserve last 10k)
            if len(self.combined_data) > 10000:
                self.combined_data = self.combined_data[-10000:]

        except Exception as e:
            print(f"Checkpoint failed: {str(e)[:200]}")
            traceback.print_exc()


    def process_data(self):
        try:
            if not self.combined_data:
                return pd.DataFrame()

            batch = pd.DataFrame(self.combined_data.copy())
            batch['timestamp'] = pd.to_datetime(batch['timestamp'])
            batch = batch.drop_duplicates(subset=['timestamp'], keep='last')

            # ==== New Validation Checks ====
            required_columns = {'bid_price', 'ask_price', 'trade_price', 'volume'}
            if not required_columns.issubset(batch.columns):
                missing = required_columns - set(batch.columns)
                print(f"Missing columns: {missing}")
                return pd.DataFrame()

            if not batch.empty:
                time_span = batch['timestamp'].max() - batch['timestamp'].min()
                if time_span < pd.Timedelta('1s'):
                    print(f"Critical time range error: {time_span}")
                    return pd.DataFrame()
            # ==============================

            if not batch.empty:
                resampled = (
                    batch.set_index('timestamp')
                    .resample('1000ms', origin='start')
                    .agg({
                        'bid_price': 'ffill',
                        'ask_price': 'ffill',
                        'trade_price': 'bfill',
                        'volume': 'sum'
                    })
                    .reset_index()
                )
                resampled['mid_price'] = (resampled['bid_price'] + resampled['ask_price']) / 2
                resampled = resampled.ffill().dropna(subset=['bid_price', 'ask_price'], how='all')

                print(f"Processed {len(batch)} records -> {len(resampled)} data points")
                return resampled
            return pd.DataFrame()
        except Exception as e:
            print(f"Processing error: {str(e)[:200]}")
            traceback.print_exc()
            return pd.DataFrame()



    def _final_save(self, symbol, duration):

        try:
            print("\nFinalizing collection...")
            df = self.process_data()
            if not df.empty:
                timestamp = dt.now().strftime("%Y%m%d_%H%M%S")
                filename = f"{symbol}_FINAL_{duration}h_{timestamp}.csv"
                self.save_data(df, filename)
        except Exception as e:
            print(f"Final save failed: {str(e)[:200]}")


In [5]:
def backtest_grossman_miller(symbol, duration_hours=0.1, gamma=1, sigma_squared=0.01, 
                            num_mm=10, csv_path=None, resample_freq=None):
    """
    Fetches market data (or loads from CSV), simulates the Grossman-Miller model, and analyzes the results.
    
    Args:
        symbol (str): The trading symbol (e.g., "btcusdt").
        duration_hours (float): Duration for data collection (ignored if csv_path is provided).
        gamma (float): Risk aversion parameter for the model.
        sigma_squared (float): Variance of price shock for the model.
        num_mm (int): Number of market makers for the model.
        csv_path (str, optional): Path to a CSV file with market data.
        resample_freq (str, optional): Frequency to resample data to (e.g., "1s", "100ms")
    """
    
    if csv_path and os.path.exists(csv_path):
        print(f"Loading market data from {csv_path}")
        try:
            # Load data from CSV file
            processed_data = pd.read_csv(csv_path)
            
            # Convert timestamp to datetime if it's not already
            if 'timestamp' in processed_data.columns and not pd.api.types.is_datetime64_any_dtype(processed_data['timestamp']):
                processed_data['timestamp'] = pd.to_datetime(processed_data['timestamp'])
            
            # Required columns check
            required_columns = ['bid_price', 'ask_price', 'trade_price', 'volume']
            missing_columns = [col for col in required_columns if col not in processed_data.columns]
            
            if missing_columns:
                print(f"Warning: Missing required columns in CSV: {missing_columns}")
                
                # Calculate mid_price if needed
                if 'mid_price' not in processed_data.columns and 'bid_price' in processed_data.columns and 'ask_price' in processed_data.columns:
                    processed_data['mid_price'] = (processed_data['bid_price'] + processed_data['ask_price']) / 2
                    print("Calculated mid_price from bid_price and ask_price")
                    
            print(f"Loaded {len(processed_data)} records from CSV")
            
            # Resample the data if requested
            if resample_freq and 'timestamp' in processed_data.columns:
                print(f"Resampling data to {resample_freq} frequency")
                processed_data = processed_data.set_index('timestamp')
                
                # Calculate average time between records for diagnostics
                avg_time_delta = (processed_data.index.max() - processed_data.index.min()) / len(processed_data)
                print(f"Average time between records: {avg_time_delta}")
                
                # Resample with appropriate aggregation methods
                processed_data = processed_data.resample(resample_freq).agg({
                    'bid_price': 'last',
                    'ask_price': 'last',
                    'trade_price': 'last',
                    'volume': 'sum',
                    'mid_price': 'last' if 'mid_price' in processed_data.columns else None
                }).dropna()
                
                # Recalculate mid_price if needed
                if 'mid_price' not in processed_data.columns:
                    processed_data['mid_price'] = (processed_data['bid_price'] + processed_data['ask_price']) / 2
                    
                processed_data = processed_data.reset_index()
                print(f"After resampling: {len(processed_data)} records")
            
        except Exception as e:
            print(f"Error loading data from CSV: {str(e)}")
            print("Falling back to real-time data collection...")
            csv_path = None
    
    # Real-time data collection fallback
    if not csv_path or not os.path.exists(csv_path):
        print(f"Collecting real-time market data for {symbol}...")
        collector = MarketDataCollector()
        collector.collect_long_duration_data(symbol, duration_hours=duration_hours, checkpoint_minutes=1)
        processed_data = collector.process_data()
    
    if processed_data.empty:
        print("No data to backtest.")
        return
    
    # Determine data frequency for parameter scaling
    if 'timestamp' in processed_data.columns:
        processed_data['time_diff'] = processed_data['timestamp'].diff()
        median_time_diff = processed_data['time_diff'].median()
        print(f"Median time between records: {median_time_diff}")
        
        # Scale sigma_squared based on data frequency
        if median_time_diff < pd.Timedelta('500ms'):
            # High-frequency data needs reduced sigma_squared
            original_sigma = sigma_squared
            sigma_squared = sigma_squared * 0.1  # Scale down for high frequency
            print(f"High-frequency data detected. Scaling sigma_squared from {original_sigma} to {sigma_squared}")
    
    # Adaptive threshold calculation based on data characteristics
    volume_std = processed_data['volume'].std()
    volume_mean = processed_data['volume'].mean()
    
    # Use higher threshold for noisier data
    if volume_std > volume_mean * 10:
        threshold = processed_data['volume'].quantile(0.75)  # 75th percentile for high variation
        print(f"High volume variation detected. Using 75th percentile threshold: {threshold}")
    else:
        threshold = volume_mean
        print(f"Using mean volume threshold: {threshold}")
    
    # Generate trade sizes with frequency-appropriate scaling
    trades = []
    for index, row in processed_data.iterrows():
        trade_size = 0
        if pd.notna(row['volume']) and row['volume'] > threshold:
            # Adjust divisor based on price level
            price_level = row['mid_price'] if pd.notna(row['mid_price']) else row['trade_price']
            divisor = 10 if price_level < 100 else 100
            
            # Scale for high-frequency data
            if 'time_diff' in processed_data.columns and median_time_diff < pd.Timedelta('500ms'):
                divisor *= 10  # Further reduce trade sizes for millisecond data
                
            trade_size = (row['volume'] - threshold) / divisor
        trades.append(trade_size)
    
    # Create model and run simulation with adjusted parameters
    model = GrossmanMillerModel(num_mm, gamma, sigma_squared)
    results = model.run_simulation(trades)
    
    # Ensure dataframes align
    if 'timestamp' in processed_data.columns:
        processed_data = processed_data.reset_index(drop=True)
    
    if len(results) != len(processed_data):
        min_len = min(len(results), len(processed_data))
        results = results.iloc[:min_len]
        processed_data = processed_data.iloc[:min_len]
    
    # Combine results and calculate PnL
    combined_results = pd.concat([processed_data, results], axis=1)
    combined_results['price_change'] = combined_results['mid_price'].diff()
    combined_results['mm_pnl'] = -combined_results['price_impact'] * combined_results['trade_quantity_lt1'] * combined_results['price_change']
    
    # Diagnostic statistics
    print("\nMarket Making Performance Metrics:")
    print(f"Total trades: {len(combined_results[combined_results['trade'] > 0])}")
    print(f"Average trade size: {combined_results['trade'].mean():.6f}")
    print(f"Average price change: {combined_results['price_change'].mean():.6f}")
    print(f"Price change volatility: {combined_results['price_change'].std():.6f}")
    print(f"Total market maker PnL: {combined_results['mm_pnl'].sum():.2f}")
    
    # Return full results for further analysis
    return combined_results

In [8]:
results_100ms = backtest_grossman_miller(
    symbol="bswusdt",
    gamma=1,
    sigma_squared=0.001,
    num_mm=10,
    csv_path="/home/misango/code/Algorithmic_Trading_and_HFT_Research/Market_Making/Avellaneda-Stoikov/Data_Folder_Test/HFT_1_hr_combined_crypto_data.csv",
    resample_freq=None  # Resample to 100ms intervals
)
'''
results_1s = backtest_grossman_miller(
    symbol="btcusdt",
    gamma=1,
    sigma_squared=0.01,  # Higher for second-level data
    num_mm=10,
    csv_path="/home/misango/code/Algorithmic_Trading_and_HFT_Research/Market_Making/Avellaneda-Stoikov/Data_Folder_Test/HFT_2_hr_combined_crypto_data.csv",
    resample_freq="1000ms"  # Resample to 1-second intervals
)
'''
#print(f"Grossman Miller MM simulation on 100ms BTC/USDT data PnL: {results_100ms['mm_pnl'].sum():.2f}")
#print(f"1s data PnL: {results_1s['mm_pnl'].sum():.2f}")

Loading market data from /home/misango/code/Algorithmic_Trading_and_HFT_Research/Market_Making/Avellaneda-Stoikov/Data_Folder_Test/HFT_1_hr_combined_crypto_data.csv
Loaded 50223 records from CSV
Using mean volume threshold: 83.23278549110965

Market Making Performance Metrics:
Total trades: 8475
Average trade size: 6.487977
Average price change: -0.000000
Price change volatility: 0.000904
Total market maker PnL: 0.00


'\nresults_1s = backtest_grossman_miller(\n    symbol="btcusdt",\n    gamma=1,\n    sigma_squared=0.01,  # Higher for second-level data\n    num_mm=10,\n    csv_path="/home/misango/code/Algorithmic_Trading_and_HFT_Research/Market_Making/Avellaneda-Stoikov/Data_Folder_Test/HFT_2_hr_combined_crypto_data.csv",\n    resample_freq="1000ms"  # Resample to 1-second intervals\n)\n'