In [1]:
"""
Simple test for formula loading and evaluation.
"""
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join("/home/ubuntu/projects/hindsight/examples/data/ast", '../../..')))

import matplotlib.pyplot as plt
import jax
import pandas as pd
import pandas_ta as ta
import jax.numpy as jnp

from src import DataManager
from src.data.ast import parse_formula
from src.data.ast.manager import FormulaManager
from src.data.ast.functions import register_built_in_functions, get_function_context
from src.data.core import prepare_for_jit

  @xr.register_dataarray_accessor('dt')


In [2]:
# Register built-in functions
register_built_in_functions()

In [3]:
# Initialize and load formulas
manager = FormulaManager()

In [4]:
# Load CRSP data

dm = DataManager()
configs = dm.get_builtin_configs()

In [5]:
configs

['equity_standard']

In [6]:
ds = dm.load_builtin("equity_standard")

wrds/equity/crsp: Attemping to load found cache(/home/suchismit/data/cache/wrds/equity/crsp/4f5bc4c841f1d1d0506b704fdad82df5_2000-01-01_2023-12-01.nc).
wrds/equity/crsp: Successfully loaded from /home/suchismit/data/cache/wrds/equity/crsp/4f5bc4c841f1d1d0506b704fdad82df5_2000-01-01_2023-12-01.nc


In [7]:
ds = ds['equity_prices']

In [8]:
list(ds.data_vars.keys())

['cusip',
 'issuno',
 'hexcd',
 'hsiccd',
 'bidlo',
 'askhi',
 'prc',
 'vol',
 'ret',
 'bid',
 'ask',
 'shrout',
 'cfacpr',
 'cfacshr',
 'altprc',
 'spread',
 'altprcdt',
 'retx',
 'comnam',
 'exchcd',
 'me']

In [9]:
ds_jit, _ = prepare_for_jit(ds)

In [10]:
function_context = get_function_context() 

In [11]:
manager = FormulaManager()

In [12]:
manager.list_formulas()

['alma',
 'chcsho_12m',
 'chcsho_1m',
 'chcsho_3m',
 'chcsho_6m',
 'dema',
 'div12m_me',
 'div1m_me',
 'div3m_me',
 'div6m_me',
 'dividend',
 'dividend_times_shares',
 'fwma',
 'hma',
 'hwma',
 'kama',
 'market_equity',
 'prc_adj',
 'price_ret',
 'ret_12m',
 'ret_1m',
 'ret_3m',
 'ret_6m',
 'rsi',
 'shares_adj',
 'simple_rsi_momentum',
 'triple_exponential_smoothing',
 'ts_combined',
 'ts_momentum',
 'ts_signal',
 'wma']

In [13]:
ds_jit

In [14]:
@jax.jit
def _eval(): 
    context = {
        '_dataset': ds, # captured by closure
        'me': 'me',
        'ret': 'ret',
        'cfacpr': 'cfacpr',
        'prc': 'prc',
        'retx': 'retx',
        'shrout': 'shrout',
        'cfacshr': 'cfacshr'   
    }
    
    formulas_to_eval = [
        'market_equity',
        'chcsho_12m',
        'chcsho_1m',
        'chcsho_3m',
        'chcsho_6m',
        'div12m_me',
        'div1m_me',
        'div3m_me',
        'div6m_me',
        'dividend',
        'dividend_times_shares',
        'prc_adj',
        'price_ret',
        'ret_12m',
        'ret_1m',
        'ret_3m',
        'ret_6m',
        'shares_adj',]

    result = manager.evaluate_bulk(formulas_to_eval, context)
    return result

In [15]:
res = _eval()

An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.


In [16]:
res

In [17]:
res['dividend'].sel(asset=14593).data

Array([[[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ]],

       [[0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
        [0.        ],
    