In [1]:
from hydra import compose, core, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from pathlib import Path
import os

def compose_config_from_path(config_path, config_name="config"):
    """
    Compose a Hydra-compatible configuration from the specified path.
    
    Args:
        config_path (str): Path to the directory containing config files
        config_name (str, optional): Name of the main config file (without .yaml extension). 
                                   Defaults to "config".
    
    Returns:
        OmegaConf: Composed configuration object
    
    Raises:
        FileNotFoundError: If the config directory or config file doesn't exist
        hydra.errors.HydraException: If there are issues with Hydra configuration
    """
    # Reset Hydra to avoid conflicts
    GlobalHydra.instance().clear()
    
    # Convert to absolute path and validate
    config_path = Path(config_path)
    if not config_path.exists():
        raise FileNotFoundError(f"Config directory not found: {config_path}")
    
    if not config_path.is_dir():
        raise ValueError(f"Config path must be a directory: {config_path}")
    
    # Get absolute path and parent directory
    abs_config_path = config_path.resolve()
    parent_dir = abs_config_path.parent
    relative_config_path = abs_config_path.name
    
    # Save current working directory
    original_cwd = os.getcwd()
    
    try:
        # Change to parent directory and use relative path for Hydra
        os.chdir(parent_dir)
        
        # Initialize Hydra with the relative config path
        initialize(config_path=relative_config_path, version_base=None)
        
        # Compose the configuration
        cfg = compose(config_name=config_name)
        return cfg
        
    except Exception as e:
        # Clean up Hydra instance on error
        GlobalHydra.instance().clear()
        raise e
    finally:
        # Always restore original working directory
        os.chdir(original_cwd)

cfg = compose_config_from_path('./conf', 'a2c_trainer.yaml')



In [2]:
from hydra.utils import instantiate, to_absolute_path


data_handler = instantiate(cfg.data)
data = data_handler.get_data()

INFO:pipelines.rl_agent_policy.data.feature_engine:Initialized FeatureEngineeringProcessor
INFO:pipelines.rl_agent_policy.data.data_parser:Initialized DataHandler for 2 currencies
INFO:pipelines.rl_agent_policy.data.data_parser:Processing data...
INFO:pipelines.rl_agent_policy.data.data_parser:Starting data processing pipeline...
INFO:pipelines.rl_agent_policy.data.data_parser:Fetching data for currencies: ['LINK', 'SOL']
INFO:pipelines.rl_agent_policy.data.data_parser:Exchange: Binance, Quote: USDT, Timeframe: 1h
INFO:pipelines.rl_agent_policy.data.data_parser:Fetching LINK/USDT data...
INFO:pipelines.rl_agent_policy.data.data_parser:✓ Successfully fetched 58236 records for LINK
INFO:pipelines.rl_agent_policy.data.data_parser:Fetching SOL/USDT data...
INFO:pipelines.rl_agent_policy.data.data_parser:✓ Successfully fetched 44554 records for SOL
INFO:pipelines.rl_agent_policy.data.data_parser:Applying feature engineering...
INFO:pipelines.rl_agent_policy.data.data_parser:Processing featu

In [3]:
data

Unnamed: 0,date,LINK_open,LINK_high,LINK_low,LINK_close,LINK_volume,LINK_ma_5,LINK_ma_10,LINK_ma_20,LINK_ma_50,...,SOL_lr_high,SOL_lr_low,SOL_lr_close,SOL_prev_open,SOL_prev_high,SOL_prev_low,SOL_prev_close,SOL_prev_volume,SOL_rsi_14,SOL_macd
0,2020-08-11 06:00:00,13.3574,13.3656,13.0527,13.1781,5.202574e+06,13.21860,13.23967,13.335675,13.405414,...,,,,,,,,,,0.000000
1,2020-08-11 07:00:00,13.1795,13.2094,12.8793,13.0764,5.297624e+06,13.18764,13.22432,13.304485,13.412596,...,-0.101366,0.010471,-0.009908,2.8500,3.4700,2.8500,2.9515,6.140623e+04,,-0.000290
2,2020-08-11 08:00:00,13.0704,13.1879,12.7500,12.8069,6.447591e+06,13.11788,13.17062,13.265065,13.420534,...,-0.044176,0.011874,0.012784,2.9515,3.1355,2.8800,2.9224,1.251929e+05,,0.000411
3,2020-08-11 09:00:00,12.7965,13.0563,12.6189,12.7202,9.377370e+06,13.02772,13.09542,13.222635,13.421954,...,-0.008839,-0.022345,-0.036363,2.9626,3.0000,2.9144,2.9600,7.147019e+04,,-0.001960
4,2020-08-11 10:00:00,12.7202,12.8332,12.5555,12.7001,9.912966e+06,12.89634,13.05209,13.184200,13.422188,...,-0.013782,-0.002354,0.015056,2.9600,2.9736,2.8500,2.8543,7.726005e+04,,-0.001410
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
45533,2025-09-16 19:00:00,23.5000,23.5900,23.4500,23.5800,1.863454e+06,23.50800,23.51200,23.505000,23.665800,...,-0.003016,0.001181,0.000757,238.1500,239.0800,236.8500,237.6400,2.469391e+07,56.159069,0.513392
45534,2025-09-16 20:00:00,23.5800,23.5900,23.4800,23.5800,2.342528e+06,23.53400,23.51300,23.506000,23.655400,...,0.000168,0.000801,0.001974,237.6300,238.3600,237.1300,237.8200,1.869103e+07,61.980198,0.503658
45535,2025-09-16 21:00:00,23.5800,23.7000,23.5200,23.6900,2.155911e+06,23.58000,23.52000,23.519500,23.644800,...,0.001593,-0.000801,0.001049,237.8200,238.4000,237.3200,238.2900,1.147533e+07,63.512195,0.488020
45536,2025-09-16 22:00:00,23.7000,23.7500,23.5800,23.6100,3.045227e+06,23.59000,23.52600,23.527500,23.634800,...,0.001046,0.000295,-0.002603,238.3000,238.7800,237.1300,238.5400,1.554500e+07,58.633776,0.412866


In [4]:
import pandas as pd
import hydra
import numpy as np
from omegaconf import DictConfig, OmegaConf
from hydra.utils import instantiate, to_absolute_path
from typing import Union

import tensortrade.env.default as default
from tensortrade.env.default import actions as action_api, rewards as reward_api
from tensortrade.feed.core import DataFeed, Stream, NameSpace
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.wallets import Wallet, Portfolio
from tensortrade.oms.instruments import Instrument, registry
from gymnasium.spaces import MultiDiscrete

from pipelines.rl_agent_policy.train.a2c import A2CTrainer

In [5]:
# determine validation split either by ratio or absolute size
validation_size: Union[int, float] = cfg.get("validation_size", 0.2)
if isinstance(validation_size, int):
    if validation_size < 0:
        raise ValueError("validation_size must be non-negative")
    split_idx = max(0, len(data) - validation_size)
elif isinstance(validation_size, float):
    if not 0 <= validation_size <= 1:
        raise ValueError("validation_size ratio must be between 0 and 1")
    split_idx = int(len(data) * (1 - validation_size))
else:
    raise TypeError("validation_size must be int or float")

train_df = data.iloc[:split_idx].reset_index(drop=True)
valid_df = data.iloc[split_idx:].reset_index(drop=True)

# Ensure train and validation data have identical feature dimensions
train_features = train_df.drop(columns=["date"], errors="ignore").shape[1]
valid_features = valid_df.drop(columns=["date"], errors="ignore").shape[1]
if train_features != valid_features:
    raise ValueError(
        f"Train and validation feature counts differ: {train_features} != {valid_features}"
    )

# Propagate inferred input shape to the model configuration
input_shape = [train_features, cfg.env.window_size]
OmegaConf.set_struct(cfg.model.shared_network, False)
cfg.model.shared_network.input_shape = input_shape

assets = cfg.get("assets") or data_handler.symbols
main_currency = data_handler.main_currency

if main_currency not in registry:
    Instrument(main_currency, 2, main_currency)
base_instrument = registry[main_currency]

asset_instruments = []
for sym in assets:
    if sym not in registry:
        Instrument(sym, 8, sym)
    asset_instruments.append(registry[sym])

In [8]:
def build_env(df: pd.DataFrame):
        price_streams = [
            Stream.source(list(df[f"{sym}_close"]), dtype="float").rename(
                f"{main_currency}-{sym}"
            )
            for sym in assets
        ]
        exchange = Exchange(cfg.env.exchange, service=execute_order)(*price_streams)

        cash = Wallet(exchange, cfg.env.initial_cash * base_instrument)
        asset_wallets = [Wallet(exchange, 0 * inst) for inst in asset_instruments]
        portfolio = Portfolio(base_instrument, [cash, *asset_wallets])

        with NameSpace(cfg.env.exchange):
            feature_streams = [
                Stream.source(list(df[c]), dtype="float").rename(c)
                for c in df.columns
                if c != "date"
            ]
        feed = DataFeed(feature_streams)
        feed.compile()

        # renderer feed for plotting or further analysis
        renderer_streams = []
        if "date" in df.columns:
            renderer_streams.append(Stream.source(list(df["date"])).rename("date"))
        for sym in assets:
            for field in ["open", "high", "low", "close", "volume"]:
                column = f"{sym}_{field}"
                if column in df.columns:
                    renderer_streams.append(
                        Stream.source(list(df[column]), dtype="float").rename(column)
                    )
        renderer_feed = DataFeed(renderer_streams)
        renderer_feed.compile()

        action_cfg = cfg.get("action_scheme")
        try:
            action_scheme = action_api.create(action_cfg)
        except TypeError:
            params = {"cash": cash}
            if len(asset_wallets) == 1:
                params["asset"] = asset_wallets[0]
            else:
                params["assets"] = asset_wallets
            action_scheme = action_api.create(action_cfg, **params)

        reward_cfg = cfg.get("reward_scheme")
        reward_scheme = reward_api.create(reward_cfg)

        env = default.create(
            portfolio=portfolio,
            action_scheme=action_scheme,
            reward_scheme=reward_scheme,
            feed=feed,
            renderer_feed=renderer_feed,
            window_size=cfg.env.window_size,
            enable_logger=False,
        )
        return env

train_env = build_env(train_df)
valid_env = build_env(valid_df)

In [9]:
train_env

<tensortrade.env.generic.environment.TradingEnv at 0x7f01cbbb3650>