In [1]:
from trading_envs.alpaca.obs_class import AlpacaObservationClass
from alpaca.data.timeframe import TimeFrame, TimeFrameUnit


## Single Time Frame Observation Example

In [2]:
window_size = 10
observer = AlpacaObservationClass(
    symbol="BTC/USD",
    timeframes=TimeFrame(15, TimeFrameUnit.Minute),
    window_sizes=window_size,
)

In [3]:
observer.get_keys()

['15Minute_10']

In [4]:
observer.get_features()

{'observation_features': ['feature_close',
  'feature_open',
  'feature_high',
  'feature_low'],
 'original_features': ['index', 'open', 'high', 'low', 'close']}

In [5]:
obs = observer.get_observations()
print(obs)
print(obs['15Minute_10'].shape)

{'15Minute_10': array([[ 0.0000000e+00,  1.0010585e+00,  1.0021054e+00,  9.9970555e-01],
       [-1.4106651e-03,  1.0022540e+00,  1.0022540e+00,  9.9931169e-01],
       [ 1.6442838e-03,  9.9896801e-01,  1.0000000e+00,  9.9834281e-01],
       [ 1.4563642e-03,  9.9790984e-01,  1.0005499e+00,  9.9790984e-01],
       [-9.9146413e-04,  1.0003086e+00,  1.0010542e+00,  9.9993533e-01],
       [ 7.3765457e-04,  9.9880600e-01,  1.0000000e+00,  9.9880600e-01],
       [ 1.4381361e-04,  1.0002320e+00,  1.0023915e+00,  9.9956459e-01],
       [ 1.8598975e-03,  9.9846059e-01,  1.0004395e+00,  9.9833316e-01],
       [ 1.1529045e-04,  1.0002431e+00,  1.0002431e+00,  1.0000000e+00],
       [ 7.1045180e-04,  9.9971163e-01,  1.0000000e+00,  9.9874866e-01]],
      dtype=float32)}
(10, 4)


## Multi Timeframe Observations Example

In [6]:
window_sizes = [10, 20]
observer = AlpacaObservationClass(
    symbol="BTC/USD",
    timeframes=[
        TimeFrame(15, TimeFrameUnit.Minute),
        TimeFrame(1, TimeFrameUnit.Hour)
    ],
    window_sizes=window_sizes
)

In [7]:
observer.get_features()

{'observation_features': ['feature_close',
  'feature_open',
  'feature_high',
  'feature_low'],
 'original_features': ['index', 'open', 'high', 'low', 'close']}

In [9]:
obs = observer.get_observations()
print(obs)

{'15Minute_10': array([[ 0.0000000e+00,  1.0010585e+00,  1.0021054e+00,  9.9970555e-01],
       [-1.4106651e-03,  1.0022540e+00,  1.0022540e+00,  9.9931169e-01],
       [ 1.6442838e-03,  9.9896801e-01,  1.0000000e+00,  9.9834281e-01],
       [ 1.4563642e-03,  9.9790984e-01,  1.0005499e+00,  9.9790984e-01],
       [-9.9146413e-04,  1.0003086e+00,  1.0010542e+00,  9.9993533e-01],
       [ 7.3765457e-04,  9.9880600e-01,  1.0000000e+00,  9.9880600e-01],
       [ 1.4381361e-04,  1.0002320e+00,  1.0023915e+00,  9.9956459e-01],
       [ 1.8598975e-03,  9.9846059e-01,  1.0004395e+00,  9.9833316e-01],
       [ 1.1529045e-04,  1.0002431e+00,  1.0002431e+00,  1.0000000e+00],
       [ 7.1045180e-04,  9.9971163e-01,  1.0000000e+00,  9.9874866e-01]],
      dtype=float32), '1Hour_20': array([[ 0.0000000e+00,  9.9802035e-01,  1.0023915e+00,  9.9802035e-01],
       [ 2.1547519e-03,  9.9816686e-01,  1.0010757e+00,  9.9803942e-01],
       [ 1.4968647e-04,  1.0005095e+00,  1.0013787e+00,  9.9943256e-01],


## Add Custom Preprocessing Class 
With a custom preprocessing class, we can add any feature we want to the observations.


In [2]:
def custom_preprocessing(df):
    df = df.reset_index()
    df.dropna(inplace=True)
    df["feature_volatility"] = df["high"] - df["low"]
    df["feature_volume_ma"] = df["volume"].rolling(window=3).mean()
    df.dropna(inplace=True)  # Drop NaN values from rolling window
    return df

In [3]:
window_size = 10
observer = AlpacaObservationClass(
    symbol="BTC/USD",
    timeframes=TimeFrame(15, TimeFrameUnit.Minute),
    window_sizes=window_size,
    feature_preprocessing_fn=custom_preprocessing,
)

In [4]:
observer.get_features()

{'observation_features': ['feature_volatility', 'feature_volume_ma'],
 'original_features': ['index',
  'symbol',
  'open',
  'high',
  'low',
  'close',
  'volume']}

In [6]:
observer.get_keys(  )

['15Minute_10']

In [5]:
observer.get_observations()

{'15Minute_10': array([[1.7691850e+02, 7.5762826e-03],
        [2.8225400e+02, 3.0486675e-02],
        [1.1949700e+02, 2.9882619e-02],
        [1.2762250e+02, 2.3063026e-02],
        [3.0220001e+02, 2.6655692e-01],
        [2.2558501e+02, 2.6658788e-01],
        [2.6032499e+01, 2.6658788e-01],
        [1.3413000e+02, 1.8360966e-04]], dtype=float32)}