In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'

import sys
from pathlib import Path

# Make the project root (the parent of the notebook's working directory)
# importable. The membership check keeps the cell idempotent: re-running it
# no longer appends a duplicate entry to sys.path each time.
PROJECT_ROOT = str(Path.cwd().parent)
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

In [2]:
from hydra import compose, core, initialize
from hydra.core.global_hydra import GlobalHydra
from omegaconf import OmegaConf
from pathlib import Path
import os

# Drop any Hydra state left over from a previous run of this cell so that
# initialize() can be called again on the same kernel.
hydra_state = GlobalHydra.instance()
hydra_state.clear()

# Point Hydra at the config directory (relative to this notebook) and compose
# the trainer configuration used by every cell below.
initialize(version_base=None, config_path='../conf')
config = compose(config_name="a2c_trainer")



In [3]:
from hydra.utils import instantiate, to_absolute_path
from typing import Union, Optional



# Build the data handler described by the Hydra config and load the dataset.
data_handler = instantiate(config.data)
data = data_handler.get_data()

# The validation split is configured either as an absolute number of trailing
# rows (int) or as a fraction of the dataset (float); the default is 365 rows.
validation_size: Union[int, float] = config.validation.get("validation_size", 365)

if not isinstance(validation_size, (int, float)):
    raise TypeError("validation_size must be int or float")

if isinstance(validation_size, int):
    # Absolute size: keep the last `validation_size` rows for validation,
    # clamped so the split index never goes below zero.
    if validation_size < 0:
        raise ValueError("validation_size must be non-negative")
    split_idx = max(0, len(data) - validation_size)
else:
    # Ratio: the leading (1 - ratio) share of the rows goes to training.
    if not 0 <= validation_size <= 1:
        raise ValueError("validation_size ratio must be between 0 and 1")
    split_idx = int(len(data) * (1 - validation_size))

# Rows before the split index train the agent, the remainder validates it.
# Both partitions get a fresh 0-based index and their own copy of the data.
train_df, valid_df = (
    part.reset_index(drop=True).copy()
    for part in (data.iloc[:split_idx], data.iloc[split_idx:])
)

# Number of feature columns per partition, not counting "date" when present.
train_features, valid_features = (
    part.drop(columns=["date"], errors="ignore").shape[1]
    for part in (train_df, valid_df)
)

INFO:pipelines.rl_agent_policy.data.feature_engine:Initialized FeatureEngineeringProcessor
INFO:pipelines.rl_agent_policy.data.data_parser:Initialized DataHandler for 1 currencies
INFO:pipelines.rl_agent_policy.data.data_parser:Processing data...
INFO:pipelines.rl_agent_policy.data.data_parser:Starting data processing pipeline...
INFO:pipelines.rl_agent_policy.data.data_parser:Fetching data for currencies: ['LINK']
INFO:pipelines.rl_agent_policy.data.data_parser:Exchange: Binance, Quote: USDT, Timeframe: 1d
INFO:pipelines.rl_agent_policy.data.data_parser:Fetching LINK/USDT data...
INFO:pipelines.rl_agent_policy.data.data_parser:✓ Successfully fetched 2443 records for LINK
INFO:pipelines.rl_agent_policy.data.data_parser:Applying feature engineering...
INFO:pipelines.rl_agent_policy.data.data_parser:Processing features for LINK...
  from pkg_resources import get_distribution, DistributionNotFound
INFO:pipelines.rl_agent_policy.data.feature_engine:Generated 20 custom features for LINK_
  

In [4]:
from tensortrade.oms.instruments import Instrument, registry


# Tell the shared network what input it will receive: the number of feature
# columns and the observation window length. Struct mode is switched off on
# that config node first so the new `input_shape` key can be added.
input_shape = [train_features, config.env.window_size]
OmegaConf.set_struct(config.model.shared_network, False)
config.model.shared_network.input_shape = input_shape

# Traded symbols come from the config when given, otherwise from the handler.
assets = config.get("assets") or data_handler.symbols
main_currency = data_handler.main_currency

# Register the quote currency (precision 2) explicitly, in the same way the
# asset instruments are registered below.
if main_currency not in registry:
    registry[main_currency] = Instrument(main_currency, 2, main_currency)
base_instrument = registry[main_currency]

# Initial holdings are either one constant applied to every asset or one
# value per asset.
assets_initial = config.env.assets_initial

assert isinstance(assets_initial, int) or len(assets) == len(assets_initial), \
    "assets and assets_initial must have the same length or be constant"

if isinstance(assets_initial, int):
    # Broadcast the constant to one entry per asset. The previous code
    # multiplied the integer itself (`n * len(assets)`), which produced an int
    # and made the zip() below fail with "'int' object is not iterable".
    # The broadcast list is kept local so `config` is left untouched and the
    # cell can be re-run safely.
    assets_initial = [assets_initial] * len(assets)

# Pair every asset instrument (precision 8) with its initial amount.
asset_instruments = []
for sym, init_amount in zip(assets, assets_initial):
    if sym not in registry:
        registry[sym] = Instrument(sym, 8, sym)
    asset_instruments.append((registry[sym], init_amount))

In [5]:
import pandas as pd
import random
import tensortrade.env.default as default
from tensortrade.env.default import actions as action_api, rewards as reward_api
from tensortrade.env.default.renderers import construct_renderers
from tensortrade.feed.core import DataFeed, Stream, NameSpace
from tensortrade.oms.exchanges import Exchange
from tensortrade.oms.services.execution.simulated import execute_order
from tensortrade.oms.wallets import Wallet, Portfolio
from tensortrade.oms.instruments import Instrument, registry
from gymnasium.spaces import MultiDiscrete

# accelerate is optional: without it the notebook runs as a single process.
try:
    from accelerate import Accelerator
    ACCELERATE_AVAILABLE = True
except ImportError:
    ACCELERATE_AVAILABLE = False
    Accelerator = None

def _get_process_info():
    """Get process information for distributed training.

    Returns:
        dict with the keys ``'process_index'``, ``'num_processes'`` and
        ``'is_main_process'``. The values are read from a freshly created
        ``Accelerator`` when accelerate is importable and can be constructed;
        otherwise the single-process defaults ``0``, ``1`` and ``True`` are
        returned.
    """
    if ACCELERATE_AVAILABLE:
        try:
            # Try to create accelerator to get process info
            accelerator = Accelerator()
            return {
                'process_index': accelerator.process_index,
                'num_processes': accelerator.num_processes,
                'is_main_process': accelerator.is_main_process
            }
        except Exception as exc:
            # Best-effort: fall back to single-process values, but say why.
            # Catching Exception (rather than a bare except) lets
            # KeyboardInterrupt and SystemExit propagate as usual.
            import warnings
            warnings.warn(
                f"Could not query Accelerator for process info ({exc!r}); "
                "assuming a single process.",
                RuntimeWarning,
                stacklevel=2,
            )

    # Fallback to single process
    return {
        'process_index': 0,
        'num_processes': 1,
        'is_main_process': True
    }


# ------------------------------------------------------------------
# Environment building
# ------------------------------------------------------------------
def build_env(df: pd.DataFrame, env_rng: Optional[random.Random] = None):
    """Assemble a trading environment over the rows of ``df``.

    The function reads the notebook-level names ``config``, ``assets``,
    ``main_currency``, ``base_instrument`` and ``asset_instruments``, so the
    cells that define them must have been executed first.

    Args:
        df: Frame with one ``<sym>_close`` column per traded asset. Every
            column except ``date`` becomes an observation feature stream.
            ``date`` and any ``<sym>_{open,high,low,close,volume}`` columns
            that are present are additionally exposed to the renderers.
        env_rng: Optional random generator, forwarded to the environment as
            its ``rng`` argument.

    Returns:
        The environment object returned by ``default.create``.
    """
    # Execution prices come from each asset's close column.
    # NOTE(review): the original comment says trades use the previous day's
    # close — confirm how the exchange aligns prices with observations.
    price_streams = [
        Stream.source(list(df[f"{sym}_close"]), dtype="float").rename(
            f"{main_currency}-{sym}"
        )
        for sym in assets
    ]
    exchange = Exchange(config.env.exchange, service=execute_order)(*price_streams)

    # One cash wallet in the base instrument plus one wallet per asset, each
    # funded with its configured initial amount.
    cash = Wallet(exchange, config.env.initial_cash * base_instrument)
    asset_wallets = [Wallet(exchange, init_amount * inst) for inst, init_amount in asset_instruments]
    portfolio = Portfolio(base_instrument, [cash, *asset_wallets])

    # Observation features: every column except "date", created inside the
    # exchange's namespace.
    with NameSpace(config.env.exchange):
        feature_streams = [
            Stream.source(list(df[c]), dtype="float").rename(c)
            for c in df.columns
            if c != "date"
        ]

    feed = DataFeed(feature_streams)
    feed.compile()

    # renderer feed for plotting or further analysis
    renderer_streams = []
    if "date" in df.columns:
        renderer_streams.append(Stream.source(list(df["date"])).rename("date"))
    for sym in assets:
        for field in ["open", "high", "low", "close", "volume"]:
            column = f"{sym}_{field}"
            if column in df.columns:
                renderer_streams.append(
                    Stream.source(list(df[column]), dtype="float").rename(column)
                )
    renderer_feed = DataFeed(renderer_streams)
    renderer_feed.compile()

    action_scheme = action_api.get(
        'simple',
        portfolio=portfolio,
        criteria=[None],
        trade_sizes=[0.005, 0.01, 0.05, 0.1, 0.2, 0.3, 0.4],
        min_order_abs=0,
        min_order_pct=0,
    )

    # The reward scheme is built from the "reward_scheme" config section.
    # (A 'risk-adjusted' scheme used to be created here first and was then
    # immediately overwritten; that dead assignment has been removed.)
    reward_cfg = config.get("reward_scheme")
    reward_scheme = reward_api.create(reward_cfg)

    # create renderers
    renderer_list = config.validation.get('renderers', 'all')
    renderer_formats = config.validation.get('renderer_formats', ["png", "html"])
    renderers = construct_renderers(renderer_list, display=True, save_formats=renderer_formats)

    env = default.create(
        portfolio=portfolio,
        action_scheme=action_scheme,
        reward_scheme=reward_scheme,
        feed=feed,
        renderer_feed=renderer_feed,
        window_size=config.env.window_size,
        max_episode_length=config.env.get('max_episode_length', None),
        enable_logger=False,
        renderer=renderers,
        rng=env_rng
    )

    return env

# Create process-specific environment seeds and RNG instances
process_info = _get_process_info()
process_index = process_info['process_index']

# A seed key that exists but is set to null yields None here rather than the
# default, which would break the arithmetic below; the checkpoint cell already
# allows for a None seed, so fall back to the same default of 42.
base_seed = config.get('seed', 42)
if base_seed is None:
    base_seed = 42

# Each process gets its own stride (2003) so seeds differ across processes;
# validation adds a fixed offset (1009) so it never shares the training seed.
train_env_seed = (base_seed + process_index * 2003) % (2**32)  # Different prime for env seeding
valid_env_seed = (base_seed + process_index * 2003 + 1009) % (2**32)  # Offset for validation

# Create separate RNG instances for each environment
train_rng = random.Random(train_env_seed)
valid_rng = random.Random(valid_env_seed)

train_env = build_env(train_df, train_rng)
valid_env = build_env(valid_df, valid_rng)

Constructing renderers: type(identifier)=<class 'omegaconf.listconfig.ListConfig'>
Constructing renderers: type(identifier)=<class 'omegaconf.listconfig.ListConfig'>


In [7]:
from pipelines.rl_agent_policy.train.a2c import A2CTrainer
import tempfile
import logging

# Checkpoint to evaluate. Set the A2C_CHECKPOINT_PATH environment variable to
# point at a different checkpoint without editing the notebook; when it is not
# set, the original location is used.
PATH_TO_CHECKPOINT = os.environ.get(
    "A2C_CHECKPOINT_PATH",
    '/mnt/virtual_ai0001071-01239_SR006-nfs1/afedorov/mft_prj/outputs/2025-09-21/a2c_train/checkpoints/6460086',
)

# Build the agent against the training environment and the trainer settings
# from the Hydra config, carrying over the global seed when one is set.
agent = instantiate(config.model, env=train_env)
train_config = instantiate(config.train.approach)
if hasattr(config, 'seed') and config.seed is not None:
    train_config.seed = config.seed

# The trainer needs an output directory; a temporary one is enough here
# because validation output is not saved.
with tempfile.TemporaryDirectory() as temp_dir:
    pt = Path(temp_dir)
    
    trainer = A2CTrainer(
        agent=agent, 
        train_env=train_env, 
        valid_env=valid_env, 
        output_dir=pt, 
        config=train_config,
        max_episode_length=config.env.get('max_episode_length', None),
        use_accelerate=False
    )
    
    # Restore the checkpoint, then run a single validation pass.
    trainer._load_checkpoint(PATH_TO_CHECKPOINT)
    trainer._validate(save_validation_output=False)