## Trading a Delta Hedged Straddle

1.  Buy 10 contracts of puts and calls at the strike closest to current stock price every morning
    a. Expiry that is closest to today but at least 5 trading days out.
    b. Hedge remaining delta
2.  Re-hedge if residual delta > 2 contracts
2.  Exit the trade at EOD.
5.  We assume no slippage (i.e. entry and exit at mid price) and commission of 1/2 a cent for stock trades and 1 dollar for option trades

First lets generate some option prices based on the underlying stock price.  We will generate some random volatility numbers, and use Black Scholes to generate prices.  We will add all this data to a strategy context that we can later use from our strategy and avoid excessive global variables.

In [1]:
%%checkall
import numpy as np
import pandas as pd
from dataclasses import dataclass
from typing import Sequence
import pyqstrat as pq
from types import SimpleNamespace


_logger = pq.get_child_logger(__name__)


def get_symbol(put_call: str, strike: int, expiry: np.datetime64) -> str:
    '''Create option symbol from parameters'''
    return f'{put_call}-{strike}-{expiry}'


def parse_symbol(symbol: str) -> tuple[str, int, np.datetime64]:
    '''Break down option symbol into put_call, strike, expiry'''
    split = symbol.split('-', maxsplit=2)
    return (split[0], int(split[1]), np.datetime64(split[2]))
       

@dataclass
class StraddleEntryRule:
    expiries: dict[np.datetime64, np.ndarray]
    strikes: dict[tuple[np.datetime64, str, np.datetime64], np.ndarray]
    umids: dict[np.datetime64, float]
        
    def __call__(self,
                 contract_group: pq.ContractGroup,
                 i: int,
                 timestamps: np.ndarray,
                 indicator_values: pq.SimpleNamespace,
                 signal_values: pq.SimpleNamespace,
                 account: pq.Account,
                 current_orders: Sequence[pq.Order],
                 strategy_context: pq.StrategyContextType) -> list[pq.Order]:
        
        for order in current_orders:
            if order.contract.is_basket(): 
                order.cancel()
                
        timestamp = timestamps[i]
        date = timestamp.astype('M8[D]')
        expiries = self.expiries[date]
        expiry = expiries[pq.np_find_closest(expiries, date + np.timedelta64(30, 'D'))]
        put_strikes = self.strikes[(date, 'P', expiry)]
        call_strikes = self.strikes[(date, 'C', expiry)]
        umid = self.umids[timestamp]
        cidx = pq.np_find_closest(call_strikes, umid)
        pidx = pq.np_find_closest(put_strikes, umid)
        found = False
        for j in range(10):
            put_strike = put_strikes[pidx]
            call_strike = call_strikes[cidx]
            # _logger.info(f'trying strike: {put_strike} {call_strike} index: {pidx} {cidx} i: {i} timestamp: {timestamp}')
            put_symbol = get_symbol('P', put_strike, expiry)
            call_symbol = get_symbol('C', call_strike, expiry)
            call_contract = contract_group.get_contract(call_symbol)
            put_contract = contract_group.get_contract(put_symbol)
            if call_contract is None: call_contract = pq.Contract.create(call_symbol, contract_group, expiry, 100)
            if put_contract is None: put_contract = pq.Contract.create(put_symbol, contract_group, expiry, 100)
            put_delta = context.get_delta(put_contract, timestamps, i, context)
            call_delta = context.get_delta(call_contract, timestamps, i, context)
            if np.isfinite(put_delta) and np.isfinite(call_delta): 
                # _logger.info(f'found deltas: {put_delta} {call_delta}')
                found = True
                break
            cidx += 1
            pidx -= 1
        if not found: return []
        symbol = f'{put_symbol}_{call_symbol}'
        contract = contract_group.get_contract(symbol)
        if contract is None: 
            contract = pq.Contract.create(symbol, contract_group, components=[(put_contract, 1), (call_contract, 1)])
        order = pq.MarketOrder(contract=contract, timestamp=timestamp, qty=10, reason_code='ENTER_STRADDLE')
        orders: list[pq.Order] = [order]
# hedge_contract, hedge_qty = get_hedge(put_contract, call_contract, 10, 10, timestamps, i, context)
#         hedge_qty = 0
#         if hedge_qty != 0:
#             hedge_order = pq.MarketOrder(contract=hedge_contract, 
#                                          timestamp=timestamp, 
#                                          qty=hedge_qty, 
#                                          reason_code='ENTER_STRADDLE')
#             orders.append(hedge_order)
        msg = f'ORDER: {timestamp}: \n'
        for order in orders: msg += f'    {order}\n'
        _logger.info(msg)
        # import pdb; pdb.set_trace()
        return orders
        
        
def get_hedge(put: pq.Contract, 
              call: pq.Contract, 
              put_qty: int, 
              call_qty: int, 
              timestamps: np.ndarray, 
              i: int, 
              context: pq.StrategyContextType) -> tuple[pq.Contract, int]:
    delta: float = 0
    for contract in [put, call]:
        delta += context.get_delta(contract, timestamps, i, context)
    hedge_contract = context.spx_contract
    pq.assert_(hedge_contract is not None)
    hedge_qty = round(-100 * delta)
    return hedge_contract, hedge_qty


def get_expiries(prices: pd.DataFrame) -> dict[np.datetime64, np.ndarray]:
    _expiries = prices.groupby(['date']).expiry.unique()
    expiries = {_expiries.index.values[i].astype('M8[D]'): 
                np.sort(_expiries.values[i].astype('M8[D]')) for i in range(len(_expiries))}
    return expiries


def get_strikes(prices: pd.DataFrame) -> dict[tuple[np.datetime64, str, np.datetime64], np.ndarray]:
    _strikes = prices.groupby(['date', 'put_call', 'expiry']).strike.unique()
    for i in range(len(_strikes)):
        _strikes[i] = _strikes[i][_strikes[i] % 100 == 0]
    strikes = {(_strikes.index.get_level_values(0).values[i].astype('M8[D]'), 
                _strikes.index.get_level_values(1)[i], 
                _strikes.index.get_level_values(2).values[i].astype('M8[D]')): np.sort(_strikes.values[i]) 
               for i in range(len(_strikes))}
    return strikes


def get_price_function(prices: pd.DataFrame, field_name: str) -> pq.PriceFunctionType:
    price_dict: dict[str, tuple[np.ndarray, np.ndarray]] = {}
    for symbol in np.unique(prices.symbol.values):
        sym_prc = prices[['timestamp', field_name]][prices.symbol == symbol].sort_values(by='timestamp')
        _timestamps = sym_prc.timestamp.values.astype('M8[m]')
        _prices = sym_prc[field_name].values
        price_dict[symbol] = (_timestamps, _prices)
    spx_prices = prices[['timestamp', 'umid']].sort_values(by=['timestamp']).drop_duplicates(subset=['timestamp'])
    price_dict['SPX'] = (spx_prices.timestamp.values.astype('M8[m]'), spx_prices.umid.values)
    return pq.PriceFuncArrayDict(price_dict=price_dict)


@dataclass
class BasketOrderMarketSimulator:
    '''
    A function object with a signature of MarketSimulatorType.
    It can take into account slippage and commission
    >>> pq.ContractGroup.clear()
    >>> pq.Contract.clear()
    >>> cg = pq.ContractGroup.create('test_cg')
    >>> put_symbol, call_symbol = 'SPX-P-3500-2023-01-19', 'SPX-C-4000-2023-01-19'
    >>> put_contract = pq.Contract.create(put_symbol, cg)
    >>> call_contract = pq.Contract.create(call_symbol, cg)
    >>> basket = pq.Contract.create('test_contract', cg)
    >>> basket.components = [(put_contract, -1), (call_contract, 1)]
    >>> timestamp = np.datetime64('2023-01-03 14:35')
    >>> price_func = pq.PriceFuncDict({put_symbol: {timestamp: 4.8}, call_symbol: {timestamp: 3.5}})
    >>> order = pq.MarketOrder(contract=basket, timestamp=timestamp, qty=10, reason_code='TEST')
    >>> sim = BasketOrderMarketSimulator(price_func=price_func, slippage_per_trade=0)
    >>> out = sim([order], 0, np.array([timestamp]), {}, {}, SimpleNamespace())
    >>> assert(len(out) == 1)
    >>> assert(math.isclose(out[0].price, -1.3))
    >>> assert(out[0].qty == 10)
    '''
    slippage: float
    price_func: pq.PriceFunctionType
        
    def __init__(self,
                 price_func: pq.PriceFunctionType,
                 slippage_per_trade: float = 0.) -> None:
        '''
        Args:
            price_func: A function that we use to get the price to execute at
            slippage_per_trade: Slippage in local currency. Meant to simulate the difference
            between bid/ask mid and execution price 
        '''
        self.price_func = price_func
        self.slippage = slippage_per_trade
    
    def __call__(self,
                 orders: Sequence[pq.Order],
                 i: int, 
                 timestamps: np.ndarray, 
                 indicators: dict[pq.ContractGroup, SimpleNamespace],
                 signals: dict[pq.ContractGroup, SimpleNamespace],
                 strategy_context: SimpleNamespace) -> list[pq.Trade]:
        trades = []
        timestamp = timestamps[i]
        # _logger.info(f'got: {orders}')
        for order in orders:
            contract = order.contract
            if not isinstance(order, pq.MarketOrder) and not isinstance(order, pq.LimitOrder): continue
            prices_found = True
            raw_price = 0.
            for (_contract, ratio) in contract.components:
                raw_price += self.price_func(_contract, timestamps, i, strategy_context) * ratio
                if np.isnan(raw_price):
                    # _logger.info(f'could not find price for: {_contract}')
                    prices_found = False
                    break
            if not prices_found: continue
            slippage = self.slippage
            if order.qty < 0: slippage = -slippage
            price = raw_price + slippage
            if isinstance(order, pq.LimitOrder):
                if np.isfinite(order.limit_price):
                    if ((abs(order.qty > 0) and order.limit_price > price) 
                            or (abs(order.qty < 0) and order.limit_price < price)):
                        _logger.debug(f'limit_price: {order.limit_price} not met price: {price}')
                        continue
            # market order
            trade = pq.Trade(order.contract, order, timestamp, order.qty, price)
            _logger.info(f'Trade: {timestamp.astype("M8[m]")} {trade} {i}')
            trades.append(trade)
            order.fill()
        return trades


if __name__ == '__main__':
    pq.set_defaults()
    pq.Contract.clear()
    pq.ContractGroup.clear()

    prices = pd.read_csv('./support/spx_options.csv.gz', parse_dates=['timestamp'])
    prices = prices[['timestamp', 'symbol', 'umid', 'c', 'delta']]
    
    prices['date'] = prices.timestamp.values.astype('M8[D]')
    prices = prices[prices.date == "2023-01-03"]

    # remove prices outside regular trading hours (9:30 am and 4 pm)
    minute = (prices.timestamp - prices.date) / np.timedelta64(1, 'm')
    prices = prices[(minute > 9 * 60 + 30) & (minute < 16 * 60)]
    # prices.groupby('date').strike.unique()

    splits = prices.symbol.str.split('-', n=2, expand=True)
    prices['put_call'] = splits[0]
    prices['strike'] = splits[1].astype(int)
    prices['expiry'] = splits[2].astype(np.datetime64)

    data = prices[['timestamp', 'umid']].sort_values(by=['timestamp']).drop_duplicates(subset=['timestamp'])
    data['date'] = data.timestamp.values.astype('M8[D]')
    data['hour'] = data.timestamp.dt.hour
    # beginning and end of day and rehedge signals
    
    bod = pd.Series(np.where(data.date != data.date.shift(1), 1, np.nan))
    bod = bod.fillna(method='ffill', limit=60)
    data['bod'] = np.where(bod == 1, True, False)
    eod = pd.Series(np.where(data.date != data.date.shift(-1), 1, np.nan))
    eod = eod.fillna(method='bfill', limit=30)
    data['eod'] = np.where(eod == 1, True, False)

    data['rehedge_sig'] = (data.hour != data.hour.shift(-1))
    
    strat_builder = pq.StrategyBuilder(data)
    context = pq.StrategyContextType()
    price_func = get_price_function(prices, 'c')
    delta_func = get_price_function(prices, 'delta')
    context.get_price = price_func
    context.get_delta = delta_func
    strat_builder.set_strategy_context(context)
    strat_builder.set_price_function(price_func)
    strat_builder.add_market_sim(BasketOrderMarketSimulator(price_func, 0.))
    
    expiries = get_expiries(prices)
    strikes = get_strikes(prices)
    umids = {data.timestamp.values[i].astype('M8[m]'): data.umid.values[i] for i in range(len(data))}
    straddle_entry_rule = StraddleEntryRule(expiries, strikes, umids)

    opt_cg = pq.ContractGroup.create('OPTIONS')
    # opt_cg.add_contract(pq.Contract.create('DUMMY', contract_group=opt_cg))
    # hedge_cg = pq.ContractGroup.create('HEDGES')
    hedge_cg = opt_cg
    strat_builder.add_contract_group(opt_cg)
    strat_builder.add_contract_group(hedge_cg)

    #spx = pq.Contract.create('SPX', contract_group=hedge_cg)
    #hedge_cg.add_contract(spx)
    #context.spx_contract = spx

    strat_builder.add_series_rule('bod', straddle_entry_rule, position_filter='zero', contract_groups=[opt_cg])
    # close_rule = pq.ClosePositionExitRule('EOD', context.get_price)
    # strat_builder.add_series_rule('eod', close_rule, position_filter='nonzero')

    strategy = strat_builder()
    
    strategy.run()

running typecheck
[1m[32mSuccess: no issues found in 1 source file[m

running flake8
stdin:273:5: E265 block comment should start with '# '
stdin:274:5: E265 block comment should start with '# '
stdin:275:5: E265 block comment should start with '# '

[2023-10-21 17:01:54.992 __call__] ORDER: 2023-01-03T09:40: 
    P-3900-2023-01-20_C-3900-2023-01-20 2023-01-03 09:40:00 qty: 10 ENTER_STRADDLE OrderStatus.OPEN

[2023-10-21 17:01:54.993 __call__] ORDER: 2023-01-03T09:50: 
    P-3700-2023-01-20_C-3900-2023-01-20 2023-01-03 09:50:00 qty: 10 ENTER_STRADDLE OrderStatus.OPEN

[2023-10-21 17:01:54.993 __call__] Trade: 2023-01-03T10:00 P-3700-2023-01-20_C-3900-2023-01-20 2023-01-03 10:00:00 qty: 10 prc: 72.18 order: P-3700-2023-01-20_C-3900-2023-01-20 2023-01-03 09:50:00 qty: 10 ENTER_STRADDLE OrderStatus.OPEN 2


In [2]:
strategy.contract_groups

[OPTIONS]