In [1]:
from stock_bot import SignalGeneration
from stock_bot.TradingSystem import *
from stock_bot.AlpacaDataManager import *
from stock_bot.DataFrameBacktest import *
from stock_bot.Indicators import *
from stock_bot.TA_LIB_FunctionMapping import *
from pathlib import Path

In [10]:
class Strategy:
    """Config-driven, long-only trading strategy helpers.

    ``config`` is the parsed strategy YAML. ``get_indicators`` reads
    ``config['indicators']`` (a list of ``{'name': ..., 'params': {...}}`` dicts).
    Condition lists and risk settings are passed explicitly to the helper methods.

    Helpers that need no instance state are ``@staticmethod``s, so they can be called
    on either the class or an instance, e.g. ``Strategy._check_condition(row, cfg)``
    or ``strategy._check_condition(row, cfg)``.
    """

    # Operators accepted in a condition's 'comparison' field.
    VALID_COMPARISONS = ('above', 'below', 'between', 'crosses_above', 'crosses_below')

    def __init__(self, config: Dict[str, Any]):
        """Store the strategy config and initialise empty containers.

        Args:
            config: Parsed strategy configuration (e.g. loaded from YAML).
        """
        self.config = config
        self.positions: list = []
        # indicator key -> params dict, filled by get_indicators()
        self.indicators: Dict[str, Dict[str, Any]] = {}
        self.entry_conditions: List[Callable] = []
        self.exit_conditions: List[Callable] = []
        self.risk_filters: List[Callable] = []
        self.position_sizing: Optional[Callable] = None
        self.df = pd.DataFrame()

    def get_indicators(self) -> Dict[str, Dict[str, Any]]:
        """Build the indicator-key -> params mapping from ``config['indicators']``.

        EMA and SMA entries are keyed by name and period (``'ema_5'``, ``'sma_20'``) so
        several periods can coexist; every other indicator is keyed by its lower-cased
        name (``'rsi'``, ``'bbands'``, ...). A missing ``'indicators'`` list or a missing
        ``'params'`` entry is treated as empty.

        Returns:
            ``self.indicators`` (updated in place and returned).
        """
        for indicator in self.config.get('indicators', []):
            name = indicator['name'].lower()
            params = indicator.get('params', {})
            if name in ('ema', 'sma'):
                key = f"{name}_{params.get('period')}"
            else:
                key = name
            self.indicators[key] = params

        return self.indicators

    @staticmethod
    def _resolve_operand(row: pd.Series, value: Any) -> Any:
        """Return ``row[value.lower()]`` for a column-name operand, else the literal ``value``."""
        return row[value.lower()] if isinstance(value, str) else value

    @staticmethod
    def _resolve_prev_operand(row: pd.Series, value: Any) -> Any:
        """Previous-bar counterpart of ``_resolve_operand``.

        Column operands read the ``<column>_prev`` field; a literal numeric level is
        constant, so its previous value is the level itself.
        """
        return row[f"{value.lower()}_prev"] if isinstance(value, str) else value

    @staticmethod
    def _check_condition(row: pd.Series, condition_config: Dict) -> bool:
        """Evaluate one entry/exit condition against a single bar.

        Args:
            row: One bar of indicator data. Columns are looked up by lower-cased name;
                crossing checks also read the ``<column>_prev`` fields.
            condition_config: ``{'indicator': str, 'comparison': str, 'value': ...}``.
                ``value`` is a column name (str) or a numeric level; for ``'between'`` it
                is a two-item list whose items may be either kind.

        Comparisons:
            above / below: indicator strictly above / below ``value``.
            between: indicator within the two bounds, inclusive; bounds may be in any order.
            crosses_above / crosses_below: indicator is above (below) ``value`` on this bar
                and was at-or-below (at-or-above) it on the previous bar. For ``MACD`` the
                ``macd`` line is compared with its signal line (``value`` is ignored); for
                ``BBANDS`` the ``close`` price is compared with the band named by ``value``.

        Returns:
            bool: whether the condition holds (comparisons involving NaN are False).

        Raises:
            ValueError: unknown comparison, or a ``'between'`` value that is not a pair.
            KeyError: a referenced column is missing from ``row``.
        """
        comparison = condition_config['comparison']
        if comparison not in Strategy.VALID_COMPARISONS:
            raise ValueError(f"Comparison value {comparison} not a valid comparison operator")

        indicator = condition_config['indicator']
        value = condition_config['value']
        indicator_key = indicator.lower()

        if comparison == 'above':
            return bool(row[indicator_key] > Strategy._resolve_operand(row, value))

        if comparison == 'below':
            return bool(row[indicator_key] < Strategy._resolve_operand(row, value))

        if comparison == 'between':
            if isinstance(value, str) or len(value) != 2:
                raise ValueError(f"'between' expects a [lower, upper] pair, got {value!r}")
            first = Strategy._resolve_operand(row, value[0])
            second = Strategy._resolve_operand(row, value[1])
            # Column bounds such as ['sma_20', 'sma_50'] can come in either order
            low, high = (first, second) if first <= second else (second, first)
            return bool(low <= row[indicator_key] <= high)

        # crosses_above / crosses_below: compare this bar with the previous bar
        if indicator_key == 'macd':
            # TA-Lib names the MACD signal output 'macdsignal'; fall back to 'macd_signal'.
            # NOTE(review): confirm against the indicator mapping's column names.
            signal_col = 'macdsignal' if 'macdsignal' in row.index else 'macd_signal'
            line_now, line_prev = row['macd'], row['macd_prev']
            ref_now, ref_prev = row[signal_col], row[f"{signal_col}_prev"]
        else:
            # A BBANDS cross is the close price crossing the named band
            line_key = 'close' if indicator_key == 'bbands' else indicator_key
            line_now, line_prev = row[line_key], row[f"{line_key}_prev"]
            # Compare previous with previous (e.g. upperband_prev), not with today's band
            ref_now = Strategy._resolve_operand(row, value)
            ref_prev = Strategy._resolve_prev_operand(row, value)

        if comparison == 'crosses_above':
            return bool(line_now > ref_now and line_prev <= ref_prev)
        return bool(line_now < ref_now and line_prev >= ref_prev)

    def _check_entry_conditions(self, row: pd.Series, config: List[Dict]) -> bool:
        """Return True only when every entry condition holds for ``row``.

        An empty or missing condition list returns False, so a strategy with no entry
        rules never opens positions (``all([])`` would be True on every bar).
        """
        if not config:
            return False
        return all(self._check_condition(row, condition) for condition in config)

    @staticmethod
    def _check_exit_conditions(row: pd.Series, config: List[Dict]) -> bool:
        """Return True when any exit condition holds for ``row`` (False for no conditions)."""
        return any(Strategy._check_condition(row, condition) for condition in (config or []))

    @staticmethod
    def _calculate_position_size(row: pd.Series, account_balance: float, risk_config: Dict) -> float:
        """
        Calculate position size based on risk management rules, ensuring non-negative position sizes

        Sizing methods (``risk_config['position_sizing_method']``):
            - ``'atr_based'``: risk ``account_balance * risk_per_trade`` over a stop placed
              ``atr * atr_multiplier`` away, capped at ``max_position_size``.
            - ``'fixed'``: always ``max_position_size``.
            - ``'risk_based'``: risk ``account_balance * risk_per_trade`` over a stop placed
              ``close * stop_loss`` away, capped at ``max_position_size``.

        Args:
            row: DataFrame row containing price and indicator data
            account_balance: Current account balance
            risk_config: Dictionary containing risk management parameters

        Returns:
            float: Calculated position size, always >= 0. 0.0 is returned (after printing
            a warning) for a non-positive/NaN balance, a missing/NaN ATR, a non-positive
            stop distance, an unknown method, or any error while reading the inputs.
        """
        try:
            # Nothing to risk on a non-positive balance; abs() would size a blown
            # account as if it still held money.
            if pd.isna(account_balance) or account_balance <= 0:
                print(f"Warning: Non-positive account balance: {account_balance}")
                return 0.0

            method = risk_config['position_sizing_method']

            if method == 'atr_based':
                # Make sure ATR exists and is not NaN
                if 'atr' not in row or pd.isna(row['atr']):
                    print(f"Warning: ATR is missing or NaN. Available columns: {row.index.tolist()}")
                    return 0.0

                risk_amount = account_balance * abs(risk_config['risk_per_trade'])
                stop_distance = abs(float(row['atr'])) * abs(risk_config['atr_multiplier'])

                # Avoid division by zero
                if stop_distance <= 0:
                    print(f"Warning: Invalid stop distance calculated: {stop_distance}")
                    return 0.0

                position_size = risk_amount / stop_distance
                max_position_size = abs(risk_config['max_position_size'])

                print(f"""
    Position Size Calculation:
    - Account Balance: {account_balance}
    - Risk Amount: {risk_amount}
    - ATR: {abs(row['atr'])}
    - Stop Distance: {stop_distance}
    - Calculated Position Size: {position_size}
    - Max Position Size: {max_position_size}
                """)

                return float(min(position_size, max_position_size))

            if method == 'fixed':
                return float(abs(risk_config['max_position_size']))

            if method == 'risk_based':
                risk_amount = account_balance * abs(risk_config['risk_per_trade'])
                stop_distance = abs(row['close']) * abs(risk_config['stop_loss'])

                if stop_distance <= 0:
                    print(f"Warning: Invalid stop distance calculated: {stop_distance}")
                    return 0.0

                position_size = risk_amount / stop_distance
                max_position_size = abs(risk_config['max_position_size'])
                return float(min(position_size, max_position_size))

            print(f"Warning: Unknown position sizing method: {method}")
            return 0.0

        except Exception as e:
            # Best-effort: a sizing failure means "don't trade", not a crashed backtest
            print(f"Error calculating position size: {str(e)}")
            print(f"Row data: {row}")
            return 0.0

    @staticmethod
    def _calculate_pnl(entry_price: float, exit_price: float, position_size: float) -> float:
        """Calculate long-trade PnL, ``(exit_price - entry_price) * position_size``.

        Returns 0.0 (after printing the offending values) when any input is NaN/None,
        or when the arithmetic raises.
        """
        try:
            if any(pd.isna([entry_price, exit_price, position_size])):
                print(f"""
    Invalid PnL calculation values:
    - Entry Price: {entry_price}
    - Exit Price: {exit_price}
    - Position Size: {position_size}
                """)
                return 0.0

            return (exit_price - entry_price) * position_size

        except Exception as e:
            print(f"Error calculating PnL: {str(e)}")
            return 0.0

    @staticmethod
    def _check_stop_loss(row: pd.Series, entry_price, position, risk_config) -> bool:
        """True when in a position and ``close <= entry_price * (1 - stop_loss)``."""
        if not position:
            return False
        return bool(row['close'] <= entry_price * (1 - risk_config['stop_loss']))

    @staticmethod
    def _check_take_profit(row: pd.Series, entry_price, position, risk_config) -> bool:
        """True when in a position and ``close >= entry_price * (1 + take_profit)``."""
        if not position:
            return False
        return bool(row['close'] >= entry_price * (1 + risk_config['take_profit']))

    def _calculate_metrics(self, trades: pd.DataFrame, equity, initial_balance: float,
                           periods_per_year: float = 252) -> Dict[str, float]:
        """Calculate backtest performance metrics.

        Args:
            trades: One row per closed trade with a ``'pnl'`` column.
            equity: Equity value per bar (Series, list or ndarray).
            initial_balance: Starting account balance (must be non-zero).
            periods_per_year: Bars per year used to annualise the Sharpe ratio. The
                default 252 assumes daily bars; pass a larger value for intraday bars.

        Returns:
            Dict with total/winning/losing trade counts, win_rate, profit_factor
            (``inf`` when there is profit and no loss), total_return, max_drawdown
            and sharpe_ratio (0 when returns have zero or undefined variance).
        """
        if len(trades) == 0:
            return {
                "total_trades": 0,
                "winning_trades": 0,
                "losing_trades": 0,
                "win_rate": 0,
                "profit_factor": 0,
                "total_return": 0,
                "max_drawdown": 0,
                "sharpe_ratio": 0
            }

        # Accept an ndarray/list equity curve as well as a Series
        equity = pd.Series(equity, dtype=float)

        # Returns and drawdown from the equity curve
        returns = equity.pct_change().dropna()
        running_peak = equity.cummax()
        drawdown = (equity - running_peak) / running_peak

        # Trade metrics; zero-PnL trades count as losing
        winning_trades = trades[trades['pnl'] > 0]
        losing_trades = trades[trades['pnl'] <= 0]
        gross_profit = float(winning_trades['pnl'].sum())
        gross_loss = abs(float(losing_trades['pnl'].sum()))
        if gross_loss > 0:
            profit_factor = gross_profit / gross_loss
        else:
            # No money lost: unbounded if anything was won, otherwise 0
            profit_factor = float('inf') if gross_profit > 0 else 0.0

        returns_std = returns.std()
        # NaN std (fewer than two returns) also fails this test
        if returns_std > 0:
            sharpe_ratio = np.sqrt(periods_per_year) * returns.mean() / returns_std
        else:
            sharpe_ratio = 0.0

        return {
            "total_trades": len(trades),
            "winning_trades": len(winning_trades),
            "losing_trades": len(losing_trades),
            "win_rate": round(len(winning_trades) / len(trades), 2),
            "profit_factor": round(profit_factor, 2),
            "total_return": round((equity.iloc[-1] - initial_balance) / initial_balance, 2),
            "max_drawdown": round(abs(drawdown.min()), 4),
            "sharpe_ratio": round(sharpe_ratio, 2)
        }


In [11]:
# Config files live under ./config: API credentials plus the strategy definitions
data_folder = Path('./config')
yaml_file = 'api_keys.yaml'
strategy_file = 'ta-lib_example.yaml'  # kept distinct from the `strategy` object built below
yaml_path = data_folder / yaml_file
strat_file = data_folder / strategy_file
trade_config_path = data_folder / 'stage-based-trading-system.yaml'

with open(yaml_path, 'r') as file:
    yaml_config = yaml.safe_load(file)

with open(trade_config_path, 'r') as file:
    yaml_trade_config = yaml.safe_load(file)

# Paper-trading credentials read from api_keys.yaml
api_key = yaml_config['api_key_paper']
api_secret = yaml_config['api_secret_paper']
data_fetcher = AlpacaDataFetcher(api_key, api_secret)

# Fetch historical bars for the configured symbol
historical_data = data_fetcher.get_historical_data(
    symbol=yaml_trade_config['symbol'],
    timeframe='1m',
    start_date=datetime(2024, 6, 1, tzinfo=pytz.UTC),
    end_date=datetime(2024, 12, 29, tzinfo=pytz.UTC),
)

strategy = Strategy(yaml_trade_config)


        

In [12]:
# Inspect the trading config the Strategy was built from (stage-based-trading-system.yaml)
strategy.config

{'symbol': 'PLTR',
 'stage1_conditions': [{'indicator': 'close',
   'comparison': 'between',
   'value': ['lowerband', 'upperband']},
  {'indicator': 'close',
   'comparison': 'between',
   'value': ['sma_20', 'sma_50']},
  {'indicator': 'rsi', 'comparison': 'between', 'value': [40, 60]},
  {'indicator': 'atr', 'comparison': 'below', 'value': 1000}],
 'entry_conditions': [{'indicator': 'close',
   'comparison': 'above',
   'value': 'sma_20'},
  {'indicator': 'ema_5', 'comparison': 'crosses_above', 'value': 'sma_20'}],
 'exit_conditions': [{'indicator': 'close',
   'comparison': 'below',
   'value': 'ema_20'},
  {'indicator': 'ema_5', 'comparison': 'crosses_below', 'value': 'ema_20'},
  {'indicator': 'close', 'comparison': 'below', 'value': 'lowerband'}],
 'risk_management': {'position_sizing_method': 'risk_based',
  'risk_per_trade': 0.7,
  'stop_loss': 0.02,
  'take_profit': 0.06,
  'max_position_size': 1000.0,
  'atr_multiplier': 2.0},
 'indicators': [{'name': 'EMA', 'params': {'peri

In [13]:
# Indicator key -> params mapping parsed from config['indicators']; EMA/SMA keys carry the period (e.g. 'ema_5')
strategy.get_indicators()

{'ema_5': {'period': 5},
 'sma_20': {'period': 20},
 'sma_50': {'period': 50},
 'rsi': {'period': 14},
 'bbands': {'period': 20, 'std_dev': 2},
 'atr': {'period': 14}}

In [14]:

indicator_params = strategy.get_indicators()
backtest_df = TechnicalIndicators(historical_data, indicator_params)
# Display the indicator DataFrame (including the *_prev columns)
backtest_df.calculate_indicators()

calculate previous values:  


Unnamed: 0_level_0,Unnamed: 1_level_0,open,high,low,close,volume,trade_count,vwap,ema_5,sma_20,sma_50,...,atr,close_prev,ema_5_prev,sma_20_prev,sma_50_prev,rsi_prev,upperband_prev,middleband_prev,lowerband_prev,atr_prev
symbol,timestamp,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1,Unnamed: 22_level_1
PLTR,2024-06-03 07:57:00+00:00,21.8500,21.9700,21.81,21.85,10146.0,200.0,21.853400,,,,...,,,,,,,,,,
PLTR,2024-06-03 08:08:00+00:00,21.8500,21.8800,21.83,21.87,39262.0,166.0,21.853660,,,,...,,21.85,,,,,,,,
PLTR,2024-06-03 08:19:00+00:00,21.8600,21.9000,21.86,21.90,20650.0,149.0,21.866508,,,,...,,21.87,,,,,,,,
PLTR,2024-06-03 08:30:00+00:00,21.8700,21.9100,21.82,21.90,17067.0,140.0,21.888111,,,,...,,21.90,,,,,,,,
PLTR,2024-06-03 08:41:00+00:00,21.8900,21.8900,21.87,21.89,1482.0,20.0,21.877176,21.882000,,,...,,21.90,,,,,,,,
PLTR,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
PLTR,2024-12-28 00:14:00+00:00,78.8534,78.8999,78.84,78.88,7795.0,122.0,78.870855,78.890805,78.988060,78.898206,...,0.304401,78.90,78.896207,78.992060,78.892606,43.086451,79.451709,78.892606,78.333503,0.309389
PLTR,2024-12-28 00:25:00+00:00,78.8800,78.8800,78.84,78.85,8859.0,141.0,78.845552,78.877203,78.977060,78.910006,...,0.299113,78.88,78.890805,78.988060,78.898206,42.975586,79.451048,78.898206,78.345364,0.304401
PLTR,2024-12-28 00:36:00+00:00,78.8500,78.8900,78.81,78.88,17863.0,225.0,78.843722,78.878135,78.965485,78.921806,...,0.294731,78.85,78.877203,78.977060,78.910006,42.806996,79.432192,78.910006,78.387820,0.299113
PLTR,2024-12-28 00:47:00+00:00,78.8700,78.9000,78.81,78.81,16875.0,305.0,78.830291,78.855424,78.951985,78.926206,...,0.290636,78.88,78.878135,78.965485,78.921806,43.035026,79.413173,78.921806,78.430439,0.294731


In [None]:
# Entry rules from the trading config (None if the key is absent)
yaml_trade_config.get('entry_conditions')

In [None]:

risk_config = yaml_trade_config['risk_management']

# Indicator frame the backtest loop iterates over. The arrays below must be sized
# and indexed from this same frame.
# NOTE(review): recomputes the indicators; assumes calculate_indicators() returns the
# indicator DataFrame (as its displayed output shows) and is safe to call again.
result = backtest_df.calculate_indicators()

# Position state
position = False
entry_price = 0
entry_time = None
position_size = 0

# Account / bookkeeping state
trades_list = []
initial_balance = 100000
balance = initial_balance
# One equity value per bar, aligned positionally with `result`
equity = np.full(len(result), initial_balance, dtype=np.float64)
# Float so fractional position sizes can be stored without a dtype upcast
positions = pd.Series(0.0, index=result.index)
# Object dtype so string labels ('entry', '*_exit') can be stored
signals = pd.Series(None, index=result.index, dtype=object)


In [None]:
# Indicator DataFrame used by the backtest loop and the plots further down.
# NOTE(review): this cell referenced the undefined name `nihhanNV`; the frame is now taken
# from backtest_df.calculate_indicators() (recomputed here). Confirm that was the intent.
result = backtest_df.calculate_indicators()
result.tail()

In [None]:


# Bar-by-bar, long-only backtest over the indicator frame `result`.
# Uses state from the initialisation cell (position, entry_price, position_size, balance,
# equity, positions, signals, trades_list, risk_config) and the Strategy helpers.
entry_conditions = yaml_trade_config['entry_conditions']
exit_conditions = yaml_trade_config['exit_conditions']

for i, (idx, row) in enumerate(result.iterrows()):
    current_equity = balance

    if position:
        # Mark the open position to market for this bar's equity value
        unrealized_pnl = (row['close'] - entry_price) * position_size
        current_equity = balance + unrealized_pnl
        positions.loc[idx] = position_size

        # Exit priority: stop loss, then take profit, then the strategy's exit signal
        if Strategy._check_stop_loss(row, entry_price, position, risk_config):
            exit_type = 'stop_loss'
        elif Strategy._check_take_profit(row, entry_price, position, risk_config):
            exit_type = 'take_profit'
        elif Strategy._check_exit_conditions(row, exit_conditions):
            exit_type = 'signal'
        else:
            exit_type = None

        if exit_type is not None:
            # Close at this bar's close and realise the P&L into the balance
            trade_pnl = Strategy._calculate_pnl(entry_price, row['close'], position_size)
            balance += trade_pnl
            trades_list.append({
                'entry_time': entry_time,
                'exit_time': idx,
                'entry_price': entry_price,
                'exit_price': row['close'],
                'quantity': position_size,
                'pnl': trade_pnl,
                'exit_type': exit_type,
            })

            print(f"\nTrade recorded ({exit_type}):")
            print(f"Entry price: {entry_price}")
            print(f"Exit price: {row['close']}")
            print(f"Position size: {position_size}")
            print(f"PnL: {trade_pnl}")

            signals.loc[idx] = f"{exit_type}_exit"
            position = False
            equity[i] = float(balance)
            continue

    # Entries are only evaluated while flat: re-entering on an open position would
    # overwrite entry_price/position_size and silently drop that position's P&L.
    elif strategy._check_entry_conditions(row, entry_conditions):
        new_position_size = Strategy._calculate_position_size(row, balance, risk_config)
        # A zero/NaN size means sizing was rejected; don't open an empty trade
        if new_position_size > 0:
            position_size = new_position_size
            print("position_size:", position_size, idx)
            position = True
            entry_price = row['close']
            entry_time = idx
            positions.loc[idx] = position_size
            signals.loc[idx] = 'entry'

    # Equity for this bar: realised balance plus any unrealised P&L
    equity[i] = float(current_equity)

# NOTE(review): a position still open after the last bar is not closed, so it is absent
# from trades_list; its unrealised P&L only shows up in the last `equity` values.
                

        

In [None]:
# Per-bar equity values written by the backtest loop (realised balance + unrealised P&L)
equity

In [None]:
# Number of closed trades recorded by the backtest loop
len(trades_list)

In [60]:
# Attach the per-bar equity curve as a column of `result` (mutates `result` in place;
# assumes len(equity) == len(result) so values align by position)
result['equity'] = equity

In [None]:
# EMA(5) vs SMA(20) over Jun-Sep 2024. `result` is indexed by (symbol, timestamp), so select
# the symbol first: a date-string .loc slice on the MultiIndex would be applied to the
# `symbol` level and select no rows.
symbol_bars = result.xs(yaml_trade_config['symbol'], level='symbol')
symbol_bars.loc['2024-06-03':'2024-09-03', ['ema_5', 'sma_20']].plot(
    title=f"{yaml_trade_config['symbol']}: EMA(5) vs SMA(20)", ylabel='Price'
);

In [None]:
# Equity curve over time; drop the single-valued `symbol` level so the x-axis shows timestamps
result['equity'].droplevel('symbol').plot(title='Backtest equity curve', ylabel='Equity');

In [None]:
# Full trading config (symbol, stage/entry/exit conditions, risk settings, indicators) for reference
yaml_trade_config

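## Incorporating Strategy.py into the workflow

The cells below rebuild the `Strategy` object from `yaml_trade_config` and recompute its technical indicators.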


In [None]:
# Rebuild the Strategy from the trading config (as in the data-loading cell) and show its config
strategy = Strategy(yaml_trade_config)
strategy.config

In [None]:
# Indicator calculator over the fetched bars, configured with the strategy's indicator specs
technical_indicators = TechnicalIndicators(historical_data, strategy.get_indicators())


In [None]:
# Compute the configured indicators (the earlier equivalent cell displays the resulting DataFrame)
technical_indicators.calculate_indicators()
