In [None]:
import asyncio
import logging
import time
from datetime import datetime, timedelta

import numpy as np
from ib_insync import IB, Bar, Contract, MarketOrder, OrderState, Stock, util
from scipy.stats import linregress

# Configure logging to console
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

class SimpleDynamicMomentumAlgorithmIBKR:
    """Volatility-adjusted momentum strategy with an SPY regime filter for IBKR.

    Every ``rebalance_period_days`` (at 10:00 local time) the strategy ranks a
    fixed stock universe by annualized log-price regression slope divided by
    annualized volatility. It then goes long the top names in a bull regime
    (SPY SMA50 > SMA200) or short the bottom names otherwise. Between
    rebalances, every managed position is protected by an ATR-based trailing
    stop that is evaluated on each real-time tick.

    Per-instrument state (``history_data``, ``current_prices``,
    ``trailing_price_anchor``, ``pending_liquidations``) is keyed by ticker
    symbol. ib_insync refuses to hash a ``Contract`` without a ``conId`` (such
    as the unqualified ``Stock`` objects built here). The contracts IB reports
    on positions and fills are also different objects, so the symbol is the
    stable join key.
    """

    def __init__(self, host='127.0.0.1', port=7497, client_id=1):
        self.ib = IB()
        self.host = host
        self.port = port
        self.client_id = client_id

        # Trading universe; SPY is used only for the market-regime filter.
        self.symbols_str = ["AAPL", "MSFT", "GOOGL", "AMZN", "META", "NVDA", "TSLA"]
        self.contracts = [Stock(s, 'SMART', 'USD') for s in self.symbols_str]
        self.market_contract = Stock('SPY', 'SMART', 'USD')

        self.lookback = 90  # Momentum window (daily bars)
        self.rebalance_period_days = 30  # Calendar days between rebalances
        self.next_rebalance_time = None

        self.atr_period = 14  # Wilder ATR period (daily bars)
        self.atr_multiplier = 2.5  # Trailing-stop distance in ATRs

        # symbol -> trailing-stop anchor: highest price seen for longs,
        # lowest price seen for shorts.
        self.trailing_price_anchor = {}
        # symbol -> Trade of a stop-loss liquidation order still in flight, so the
        # stop is not re-fired on every tick before the position goes flat.
        self.pending_liquidations = {}

        # symbol -> list of daily bars as dicts with date/open/high/low/close keys.
        self.history_data = {c.symbol: [] for c in self._all_contracts()}
        self.current_prices = {}  # symbol -> latest valid real-time price

        # SMA periods for the SPY market-regime filter (daily bars).
        self.short_sma_period = 50
        self.long_sma_period = 200
        self.spy_short_sma = 0
        self.spy_long_sma = 0

        self.is_connected = False
        self._handlers_registered = False

    # ------------------------------------------------------------------ helpers

    def _all_contracts(self):
        """Return the traded contracts followed by the SPY regime contract."""
        return self.contracts + [self.market_contract]

    def _managed_symbols(self):
        """Return the set of symbols this strategy trades (SPY excluded)."""
        return {c.symbol for c in self.contracts}

    def _stock_positions(self, positions=None):
        """Return ``{symbol: signed share quantity}`` for USD stock positions.

        Args:
            positions: Iterable of ib_insync ``Position`` records; defaults to
                the client's cached ``self.ib.positions()`` list.

        Non-stock positions (e.g. options on the same underlying) are skipped so
        they are never mistaken for the stock position. Quantities for the same
        symbol reported under several accounts are summed.
        """
        if positions is None:
            positions = self.ib.positions()
        by_symbol = {}
        for pos in positions:
            contract = pos.contract
            if contract.secType != 'STK' or contract.currency != 'USD':
                continue
            by_symbol[contract.symbol] = by_symbol.get(contract.symbol, 0) + pos.position
        return by_symbol

    @staticmethod
    def _first_valid_price(*candidates):
        """Return the first candidate that is a finite, positive number, else None.

        Missing ib_insync tick values are NaN. NaN is truthy, so a plain
        ``a if a else b`` fallback would never reach ``b``.
        """
        for price in candidates:
            if price is not None and np.isfinite(price) and price > 0:
                return float(price)
        return None

    @staticmethod
    def _roll_to_weekday(moment):
        """Move ``moment`` forward one day at a time until it is Monday-Friday.

        Exchange holidays are not considered.
        """
        while moment.weekday() >= 5:
            moment += timedelta(days=1)
        return moment

    def _schedule_next_rebalance(self, now):
        """Set ``next_rebalance_time`` to 10:00, ``rebalance_period_days`` after ``now``, on a weekday."""
        target = (now + timedelta(days=self.rebalance_period_days)).replace(
            hour=10, minute=0, second=0, microsecond=0)
        self.next_rebalance_time = self._roll_to_weekday(target)

    def _history_duration_str(self):
        """Return an IB ``durationStr`` covering enough daily bars for every indicator.

        The indicators need ``max(lookback, long_sma_period, atr_period) + 20``
        trading-day bars. A 'D' duration counts calendar days, so the bar count
        is scaled by 365/252, plus a 10-day margin. Anything above 365 days is
        requested in whole years instead.
        """
        bars_needed = max(self.lookback, self.long_sma_period, self.atr_period) + 20
        calendar_days = int(bars_needed * 365 / 252) + 10
        if calendar_days > 365:
            return f'{-(-calendar_days // 365)} Y'  # ceiling division
        return f'{calendar_days} D'

    @staticmethod
    def _split_long_short(ranked, long_weight_factor, short_weight_factor):
        """Pick long and short keys from a ranking sorted strongest-first.

        ``int(len * factor)`` entries are taken from the top for longs and from
        the bottom for shorts. A zero count yields an empty list; slicing with
        ``[-0:]`` would instead select the whole ranking.

        Returns:
            ``(longs, shorts)`` as lists of the ranking's keys.
        """
        count = len(ranked)
        num_long = int(count * long_weight_factor)
        num_short = int(count * short_weight_factor)
        longs = [key for key, _ in ranked[:num_long]]
        shorts = [key for key, _ in ranked[max(count - num_short, 0):]]
        return longs, shorts

    # --------------------------------------------------------------- connection

    def connect(self):
        """Connect to TWS / IB Gateway with a blocking call.

        On success, registers the event callbacks and schedules the first
        rebalance. On failure, logs the error and leaves ``is_connected`` False.
        Inside a running event loop, prefer :meth:`connect_async`.
        """
        try:
            self.ib.connect(self.host, self.port, self.client_id)
        except Exception as e:
            logging.error(f"Failed to connect to IBKR: {e}")
            self.is_connected = False
            return
        self._on_connected()

    async def connect_async(self):
        """Coroutine counterpart of :meth:`connect` (needs no nested event loop)."""
        try:
            await self.ib.connectAsync(self.host, self.port, self.client_id)
        except Exception as e:
            logging.error(f"Failed to connect to IBKR: {e}")
            self.is_connected = False
            return
        self._on_connected()

    def _on_connected(self):
        """Run post-connect setup shared by :meth:`connect` and :meth:`connect_async`."""
        self.is_connected = True
        logging.info(f"Connected to IBKR at {self.host}:{self.port} with Client ID {self.client_id}")

        # Register callbacks only once so that reconnecting does not duplicate handlers.
        if not self._handlers_registered:
            self.ib.pendingTickersEvent += self.on_tick_update
            self.ib.orderStatusEvent += self.on_order_status
            self.ib.execDetailsEvent += self.on_exec_details
            self._handlers_registered = True

        # First rebalance: today at 10:00 if that is still ahead, otherwise the
        # following weekday.
        # NOTE(review): uses the host's local clock; presumably meant to be
        # US/Eastern exchange time — confirm the machine's timezone.
        now = datetime.now()
        first = now.replace(hour=10, minute=0, second=0, microsecond=0)
        if now > first:
            first += timedelta(days=1)
        self.next_rebalance_time = self._roll_to_weekday(first)
        logging.info(f"First rebalance scheduled for: {self.next_rebalance_time}")

    def disconnect(self):
        """Disconnect from IBKR if this instance believes it is connected."""
        if self.is_connected:
            self.ib.disconnect()
            self.is_connected = False
            logging.info("Disconnected from IBKR")

    # --------------------------------------------------------------------- data

    async def _load_history(self):
        """(Re)load daily RTH trade bars for every contract into ``history_data``.

        A failed or empty request keeps the previously stored bars for that
        symbol, so one bad response does not wipe usable history.
        """
        duration_str = self._history_duration_str()
        for contract in self._all_contracts():
            symbol = contract.symbol
            logging.info(f"Requesting historical data for {symbol}...")
            try:
                bars = await self.ib.reqHistoricalDataAsync(
                    contract,
                    endDateTime='',
                    durationStr=duration_str,
                    barSizeSetting='1 day',
                    whatToShow='TRADES',
                    useRTH=True,
                )
            except Exception as e:
                logging.error(f"Historical data request failed for {symbol}: {e}")
                continue
            if bars:
                self.history_data[symbol] = [
                    {'date': b.date, 'open': b.open, 'high': b.high, 'low': b.low, 'close': b.close}
                    for b in bars
                ]
                logging.info(f"Received {len(bars)} historical bars for {symbol}.")
            else:
                logging.warning(f"No historical data received for {symbol}.")

    async def initialize_historical_data(self):
        """Load daily history for all contracts, then subscribe to streaming market data."""
        await self._load_history()
        for contract in self._all_contracts():
            self.ib.reqMktData(contract, '', False, False)
            logging.info(f"Subscribed to market data for {contract.symbol}")

    # ---------------------------------------------------------------- callbacks

    def on_tick_update(self, tickers):
        """Handle ``pendingTickersEvent``: record prices and evaluate trailing stops.

        Uses the last trade price, falling back to the ticker's close when no
        valid last price is available.
        """
        for ticker in tickers:
            symbol = ticker.contract.symbol
            current_price = self._first_valid_price(ticker.last, ticker.close)
            if current_price is None:
                logging.debug(f"Received invalid price for {symbol}: last={ticker.last}, close={ticker.close}")
                continue
            self.current_prices[symbol] = current_price
            self.UpdateTrailingStopLoss(ticker.contract, current_price)

    def on_order_status(self, trade):
        """Handle ``orderStatusEvent``: log each order status change."""
        logging.info(f"Order status update for Order ID {trade.order.orderId} ({trade.contract.symbol}): {trade.orderStatus.status}")

    def on_exec_details(self, trade, fill):
        """Handle ``execDetailsEvent``: log the fill and maintain the stop anchor.

        For a managed symbol, a fill while the cached position is flat is
        treated as opening a position. That seeds the anchor with the
        execution's average price. Any other fill drops the anchor, so the next
        tick re-seeds it at the market price.

        The cached position may or may not already include this fill, so this
        is best effort. ``UpdateTrailingStopLoss`` re-creates a missing anchor
        on the next tick either way.
        """
        execution = fill.execution
        symbol = fill.contract.symbol
        logging.info(f"Order {trade.order.orderId} filled for {symbol} at {execution.avgPrice} for {execution.shares} shares.")

        if symbol not in self._managed_symbols():
            return

        was_flat = self._stock_positions().get(symbol, 0) == 0
        if was_flat and execution.side in ('BOT', 'SLD'):
            direction = 'long' if execution.side == 'BOT' else 'short'
            self.trailing_price_anchor[symbol] = execution.avgPrice
            logging.info(f"Recorded initial anchor for {symbol} ({direction}): {execution.avgPrice:.2f}")
        else:
            self.trailing_price_anchor.pop(symbol, None)
            logging.info(f"Removed {symbol} from trailing stop tracking after fill (assumed closing trade).")

    # --------------------------------------------------------------- stop-loss

    def UpdateTrailingStopLoss(self, contract, current_price):
        """Ratchet the ATR trailing stop for ``contract`` and liquidate if it is hit.

        Longs trail the highest price seen and shorts the lowest. The stop sits
        ``atr_multiplier`` ATRs away from that anchor. The ATR includes the live
        price as a provisional bar (open = high = low = close).

        Only symbols in ``self.contracts`` are considered, so SPY and unrelated
        holdings are never liquidated here. While a previous stop-loss order is
        still active, no new one is sent.
        """
        symbol = contract.symbol
        if symbol not in self._managed_symbols():
            return

        current_shares = self._stock_positions().get(symbol, 0)
        if current_shares == 0:
            # Flat: forget any stale stop state for this symbol.
            self.trailing_price_anchor.pop(symbol, None)
            self.pending_liquidations.pop(symbol, None)
            return

        pending = self.pending_liquidations.get(symbol)
        if pending is not None:
            if pending.isActive():
                return  # liquidation order still working; don't send another
            del self.pending_liquidations[symbol]

        history = self.history_data.get(symbol, [])
        if len(history) < self.atr_period + 1:
            logging.debug(f"Not enough historical data for ATR for {symbol}. Skipping trailing stop update.")
            return

        provisional_bar = {'date': datetime.now(), 'open': current_price, 'high': current_price,
                           'low': current_price, 'close': current_price}
        current_atr = self.calculate_atr(history + [provisional_bar])
        if current_atr == 0:
            logging.warning(f"Calculated ATR is 0 for {symbol}. Skipping trailing stop update.")
            return

        stop_distance = current_atr * self.atr_multiplier
        previous_anchor = self.trailing_price_anchor.get(symbol, current_price)
        if current_shares > 0:
            anchor = max(previous_anchor, current_price)
            trailing_stop_price = anchor - stop_distance
            triggered = current_price < trailing_stop_price
            side = 'LONG'
        else:
            anchor = min(previous_anchor, current_price)
            trailing_stop_price = anchor + stop_distance
            triggered = current_price > trailing_stop_price
            side = 'SHORT'
        self.trailing_price_anchor[symbol] = anchor

        if triggered:
            logging.info(f"Trailing stop-loss triggered for {side} {symbol} at {current_price:.2f}. Stop price: {trailing_stop_price:.2f}. Liquidating.")
            self.liquidate_position(contract)
            self.trailing_price_anchor.pop(symbol, None)

    def liquidate_position(self, contract):
        """Send a market order that flattens the cached stock position in ``contract``.

        The order is recorded in ``pending_liquidations`` so the trailing stop
        does not fire again while it is working.

        Returns:
            The ib_insync ``Trade`` for the order, or ``None`` when there was
            nothing to close or placing the order failed (the error is logged).
        """
        symbol = contract.symbol
        try:
            position = self._stock_positions().get(symbol, 0)
            if position == 0:
                return None
            action = 'SELL' if position > 0 else 'BUY'
            order_qty = abs(position)
            # placeOrder assigns the order id itself when none is set.
            trade = self.ib.placeOrder(contract, MarketOrder(action, order_qty))
            self.pending_liquidations[symbol] = trade
            logging.info(f"Placed liquidation order for {action} {order_qty} shares of {symbol}. Order ID: {trade.order.orderId}")
            return trade
        except Exception as e:
            logging.error(f"Error during liquidation of {symbol}: {e}")
            return None

    # --------------------------------------------------------------- indicators

    def calculate_momentum(self, history_data_list):
        """Return the annualized log-price trend divided by annualized volatility.

        Fits a least-squares line to the log closes of the last ``lookback``
        bars. The slope (per bar) times 252 is divided by the standard
        deviation of daily log returns times sqrt(252).

        Returns 0 when:
        - there are too few bars;
        - any close is non-positive (log undefined);
        - the volatility is zero.
        """
        window = max(self.lookback, 2)  # a regression line needs at least two points
        if len(history_data_list) < window:
            return 0

        closes = np.array([bar['close'] for bar in history_data_list[-window:]], dtype=float)
        if np.any(closes <= 0):
            return 0
        log_prices = np.log(closes)

        slope = linregress(np.arange(len(log_prices)), log_prices).slope
        annualized_slope = slope * 252
        historical_volatility = np.std(np.diff(log_prices)) * np.sqrt(252)

        if historical_volatility > 0:
            return annualized_slope / historical_volatility
        return 0

    def calculate_atr(self, history_data_list):
        """Return Wilder's ATR over ``history_data_list`` (bars with high/low/close).

        The first ATR value is the mean of the first ``atr_period`` true
        ranges. Each later true range is folded in with Wilder smoothing:
        ``atr = (atr * (n - 1) + tr) / n``.

        Returns 0 when fewer than ``atr_period + 1`` bars are supplied.
        """
        if self.atr_period < 1 or len(history_data_list) < self.atr_period + 1:
            return 0

        true_ranges = [
            max(bar['high'] - bar['low'],
                abs(bar['high'] - prev['close']),
                abs(bar['low'] - prev['close']))
            for prev, bar in zip(history_data_list, history_data_list[1:])
        ]

        atr = np.mean(true_ranges[:self.atr_period])
        for tr in true_ranges[self.atr_period:]:
            atr = (atr * (self.atr_period - 1) + tr) / self.atr_period
        return atr

    def calculate_sma(self, history_prices, period):
        """Return the mean of the last ``period`` values, or 0 if there are fewer."""
        if len(history_prices) < period:
            return 0
        return np.mean(history_prices[-period:])

    def _rank_by_momentum(self):
        """Return ``[(symbol, momentum), ...]`` sorted from strongest to weakest.

        Symbols with fewer than ``lookback + 1`` bars are skipped with a warning.
        """
        momentum = {}
        for contract in self.contracts:
            symbol = contract.symbol
            history = self.history_data.get(symbol, [])
            if len(history) >= self.lookback + 1:
                momentum[symbol] = self.calculate_momentum(history)
            else:
                logging.warning(f"Not enough historical data for volatility-adjusted momentum calculation for {symbol}. Skipping.")
        return sorted(momentum.items(), key=lambda item: item[1], reverse=True)

    async def _net_liquidation_usd(self):
        """Return the account's USD ``NetLiquidation`` value, or 0.0 if unavailable.

        Reads the client's cached account values first. If no
        ``NetLiquidation`` tag is cached, it falls back to an account-summary
        request.
        """
        try:
            values = list(self.ib.accountValues())
            if not any(v.tag == 'NetLiquidation' for v in values):
                values = await self.ib.accountSummaryAsync()
            for val in values:
                if val.tag == 'NetLiquidation' and val.currency == 'USD':
                    return float(val.value)
        except Exception as e:
            logging.error(f"Error retrieving account value: {e}")
        return 0.0

    # ---------------------------------------------------------------- rebalance

    async def Rebalance(self):
        """Rebalance the portfolio if the scheduled rebalance minute has arrived.

        Steps:
        1. Refresh daily history.
        2. Classify the SPY regime (SMA50 vs SMA200).
        3. Rank the universe by volatility-adjusted momentum.
        4. Size targets from USD net liquidation.
        5. Send market orders for the difference from current positions.

        Aborts (and reschedules) when SPY history or the account value is
        unavailable. Symbols with a stop-loss order still working are skipped.
        """
        now = datetime.now()
        due = self.next_rebalance_time
        if (due is None or now < due
                or now.hour != due.hour or now.minute != due.minute):
            return

        logging.info(f"Starting Rebalance at {now}")

        # Refresh the history so momentum, the SPY SMAs and the ATR stops use
        # bars up to today instead of the startup snapshot.
        await self._load_history()

        spy_closes = [bar['close'] for bar in self.history_data.get(self.market_contract.symbol, [])]
        if len(spy_closes) < self.long_sma_period:
            logging.warning("Not enough SPY history data to calculate SMAs. Skipping rebalance.")
            self._schedule_next_rebalance(now)
            return
        self.spy_short_sma = self.calculate_sma(spy_closes, self.short_sma_period)
        self.spy_long_sma = self.calculate_sma(spy_closes, self.long_sma_period)

        # Determine market regime (Bull/Bear).
        if self.spy_short_sma > self.spy_long_sma:
            long_weight_factor, short_weight_factor, market_condition = 0.99, 0.01, "Bull Market"
        else:
            long_weight_factor, short_weight_factor, market_condition = 0.01, 0.99, "Bear Market"
        logging.info(f"Market condition: {market_condition} (SPY SMA50: {self.spy_short_sma:.2f}, SMA200: {self.spy_long_sma:.2f})")

        ranked = self._rank_by_momentum()
        long_symbols, short_symbols = self._split_long_short(ranked, long_weight_factor, short_weight_factor)
        long_weight_per_position = long_weight_factor / len(long_symbols) if long_symbols else 0
        short_weight_per_position = short_weight_factor / len(short_symbols) if short_symbols else 0

        account_value = await self._net_liquidation_usd()
        if account_value == 0:
            logging.error("Could not retrieve account NetLiquidation value. Aborting rebalance.")
            self._schedule_next_rebalance(now)
            return
        logging.info(f"Current Net Liquidation Value: {account_value:.2f}")

        current_positions = self._stock_positions(await self.ib.reqPositionsAsync())

        for contract in self.contracts:
            symbol = contract.symbol

            pending = self.pending_liquidations.get(symbol)
            if pending is not None and pending.isActive():
                # Ordering now could double up with the working stop-loss order.
                logging.warning(f"Stop-loss liquidation still working for {symbol}. Skipping it this rebalance.")
                continue

            current_shares = current_positions.get(symbol, 0)
            price = self.current_prices.get(symbol, 0)

            if symbol in long_symbols:
                if price <= 0:
                    logging.warning(f"No valid current price for {symbol}. Cannot place long order.")
                    continue
                target_shares = round(account_value * long_weight_per_position / price)
                logging.info(f"Long candidate: {symbol}, Target shares: {target_shares}, Current shares: {current_shares}")
            elif symbol in short_symbols:
                if price <= 0:
                    logging.warning(f"No valid current price for {symbol}. Cannot place short order.")
                    continue
                target_shares = -round(account_value * short_weight_per_position / price)
                logging.info(f"Short candidate: {symbol}, Target shares: {target_shares}, Current shares: {current_shares}")
            else:
                target_shares = 0
                logging.info(f"Liquidating {symbol}. Current shares: {current_shares}")

            diff_shares = target_shares - current_shares
            if diff_shares == 0:
                logging.info(f"No change in position for {symbol}.")
                continue

            action = 'BUY' if diff_shares > 0 else 'SELL'
            order_qty = abs(diff_shares)
            try:
                trade = self.ib.placeOrder(contract, MarketOrder(action, order_qty))
                logging.info(f"Placed order for {action} {order_qty} shares of {symbol}. Order ID: {trade.order.orderId}")
            except Exception as e:
                logging.error(f"Error placing order for {symbol}: {e}")

        self._schedule_next_rebalance(now)
        logging.info(f"Rebalance completed. Next rebalance scheduled for: {self.next_rebalance_time}")

    # --------------------------------------------------------------------- main

    async def run(self):
        """Connect if needed, load data, then poll once a second for the rebalance time.

        Runs until the IB connection drops. Always disconnects on exit,
        including when the task is cancelled.
        """
        if not self.is_connected:
            await self.connect_async()
            if not self.is_connected:
                logging.error("Failed to establish IBKR connection. Exiting.")
                return

        try:
            await self.initialize_historical_data()
            logging.info("Historical data initialized and market data subscribed.")

            while self.is_connected and self.ib.isConnected():
                # Yield to the event loop so ib_insync can dispatch ticks and order events.
                await asyncio.sleep(1)

                now = datetime.now()
                due = self.next_rebalance_time
                if (due is not None and now >= due
                        and now.hour == due.hour and now.minute == due.minute):
                    await self.Rebalance()
        finally:
            self.disconnect()
            logging.info("Algorithm finished running.")


async def main():
    """Build the strategy with local TWS defaults (127.0.0.1:7497, client id 1) and run it."""
    strategy = SimpleDynamicMomentumAlgorithmIBKR(
        host='127.0.0.1',
        port=7497,
        client_id=1,
    )
    await strategy.run()

if __name__ == "__main__":
    import asyncio

    from ib_insync import util

    # util.patchAsyncio() applies nest_asyncio, which lets asyncio.run() work even
    # when an event loop is already running (e.g. inside Jupyter); a second
    # nest_asyncio.apply() call is unnecessary.
    util.patchAsyncio()

    try:
        asyncio.run(main())
    except KeyboardInterrupt:
        logging.info("Algorithm stopped by user (KeyboardInterrupt).")
    except Exception as e:
        # logging.exception also records the traceback, which logging.error dropped.
        logging.exception(f"An unexpected error occurred during execution: {e}")