In [1]:
import os
import sys
# Make the project root importable so that `from src.environments... import ...`
# works. The root is assumed to be the parent of the notebook directory, since
# Jupyter starts the kernel in the notebook's directory by default.
# An absolute path keeps working if the kernel's cwd changes later. The
# membership check keeps this cell idempotent when it is re-run.
_project_root = os.path.abspath("..")
if _project_root not in sys.path:
    sys.path.append(_project_root)

In [2]:
import pandas as pd
import gymnasium as Env
import numpy as np
import ray

from copy import deepcopy

In [3]:
from ray.rllib.algorithms.ppo import PPOConfig
from ray.rllib.models import ModelCatalog
from ray.rllib.models.torch.torch_modelv2 import TorchModelV2
from ray.tune.registry import register_env

import torch
import torch.nn as nn

In [4]:
%load_ext autoreload
%autoreload 2

from src.environments.fx_environment import FxTradingEnv

## TODO

- Support selecting hourly, minute-level, or daily data
- Express episode truncation in days

- Write tests for:

1. Unknown currency in the initial portfolio
2. Missing currency in the initial portfolio
3. Portfolio value expressed in the base currency
4. Portfolio weights sum to 1
5. `reset()` works
6. etc.

In [5]:
# Initial holdings per currency, keyed by lowercase currency code.
# The amounts are presumably in units of each currency; confirm against FxTradingEnv.
# NOTE: the misspelled name ("porfolio") is kept because env_config refers to it.
# To include SGD, add sgd=100_000. The env presumably then also needs matching
# SGD price columns (verify).
current_porfolio = dict(
    usd=100_000,
    eur=100_000,
    jpy=100_000,
)

In [6]:
# Raw FX price history. The path is relative to the kernel's working directory.
_fx_data_path = "../data/FX_data.parquet.gzip"
historical_data = pd.read_parquet(_fx_data_path)

In [7]:
# Keyword arguments for FxTradingEnv. Only these three price columns of
# historical_data are handed to the environment.
_price_columns = ["eurjpy", "eurusd", "usdjpy"]
env_config = dict(
    historical_prices=historical_data[_price_columns],
    initial_portfolio=current_porfolio,
    start_datetime=pd.Timestamp("2011-01-03 09:00:00"),
)

In [16]:
def env_creator(env_config):
    """
    Build an ``FxTradingEnv`` and run its data preprocessing.

    Registered with ``register_env`` so RLlib can construct the environment
    from the ``env_config`` given to the algorithm config.

    Args:
        env_config: mapping whose entries are forwarded as keyword arguments
            to ``FxTradingEnv``.

    Returns:
        The new environment, after ``preprocess_data()`` has been called on it.
    """
    environment = FxTradingEnv(**env_config)
    environment.preprocess_data()
    return environment

In [23]:
class FXModel(TorchModelV2, nn.Module):
    """
    Actor-critic MLP for the FX trading env (RLlib old-API-stack ModelV2).

    A shared trunk (obs_dim -> 256 -> 128 -> 64) feeds two heads:

    * ``action_net`` produces ``num_outputs`` values squashed to [-1, 1] by
      ``tanh``; RLlib uses them as the inputs of the action distribution.
    * ``value_net`` produces one state-value estimate per observation, exposed
      through ``value_function()``.

    Observations are consumed flattened, so any observation space with a fixed
    ``shape`` works, not only 1-D ``Box`` spaces.
    """

    def __init__(
        self, obs_space, action_space, num_outputs, model_config, name
    ):
        """
        Args:
            obs_space: observation space; only its ``shape`` is used here.
            action_space: env action space (forwarded to ``TorchModelV2``).
            num_outputs: width of the action head, chosen by RLlib for the
                action distribution.
            model_config: RLlib model config dict (forwarded to ``TorchModelV2``).
            name: model name (forwarded to ``TorchModelV2``).
        """
        TorchModelV2.__init__(
            self, obs_space, action_space, num_outputs, model_config, name
        )
        nn.Module.__init__(self)
        # Feature count after flattening. It equals obs_space.shape[0] for 1-D
        # spaces and is also correct for multi-dimensional boxes
        # (e.g. (window, features)).
        obs_dim = int(np.prod(obs_space.shape))

        self.main_net = nn.Sequential(
            nn.Linear(obs_dim, 256),
            nn.GELU(),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 64),
            nn.ReLU()
        )

        # NOTE(review): tanh bounds *every* action output to [-1, 1].
        # - For a Discrete action space these are logits, which limits how
        #   peaked (near-deterministic) the policy can become.
        # - For a Box space, RLlib's default diagonal Gaussian splits them into
        #   means and log-stds, capping the std to [e^-1, e^1].
        # Confirm this matches FxTradingEnv's action space.
        self.action_net = nn.Sequential(
            nn.Linear(64, num_outputs),
            nn.Tanh()
        )

        self.value_net = nn.Linear(64, 1)
        # Value-head output of the most recent forward() call. Read by
        # value_function().
        self._value = None

    def forward(self, input_dict, state, seq_lens):
        """
        Compute action-distribution inputs and cache the value estimate.

        Args:
            input_dict: RLlib input batch. ``"obs_flat"`` holds the observations
                flattened to ``(batch, obs_dim)``; for 1-D spaces this is the
                same tensor as ``"obs"``.
            state: recurrent state. It is unused by this feed-forward model and
                returned unchanged.
            seq_lens: unused (no recurrence).

        Returns:
            Tuple ``(action_outputs, state)``, where ``action_outputs`` has
            shape ``(batch, num_outputs)``.
        """
        # .float() guards against float64 observations, which would not match
        # the float32 Linear weights.
        features = self.main_net(input_dict["obs_flat"].float())
        self._value = self.value_net(features)
        return self.action_net(features), state

    def value_function(self):
        """
        Return the value estimates from the last ``forward()`` call, shape ``(batch,)``.

        Raises:
            AssertionError: if ``forward()`` has not been called yet.
        """
        assert self._value is not None, "must call forward() first"
        return self._value.reshape(-1)

In [24]:
# Expose the custom model and the env factory under the string names used by
# the PPO config ("fx_model", "fx_trading_env").
ModelCatalog.register_custom_model("fx_model", FXModel)
register_env("fx_trading_env", env_creator)

In [25]:
# Start (or attach to) Ray. The parent of the current directory is shipped to
# workers as their working directory; large data files and notebooks are excluded.
# NOTE(review): with ignore_reinit_error=True, a re-run of this cell keeps the
# existing Ray session, and the runtime_env below is NOT re-applied.
# NOTE(review): py_modules points at the current (notebook) directory, while
# "notebooks/" is excluded. Confirm that this, rather than e.g. ../src, is
# the intended module.
_notebook_dir = os.path.abspath(".")
_ray_working_dir = os.path.dirname(_notebook_dir)
_ray_runtime_env = {
    "working_dir": _ray_working_dir,
    "py_modules": [_notebook_dir],
    "excludes": ["*.pyc", "__pycache__", "*.parquet.gzip", "data/", "notebooks/"],
}
ray.init(ignore_reinit_error=True, runtime_env=_ray_runtime_env)

2026-02-06 10:29:40,859	INFO worker.py:1855 -- Calling ray.init() again after it has already been called.


0,1
Python version:,3.12.10
Ray version:,2.52.1


In [26]:
# Number of remote env runners to start.
suggested_workers = 1
# GPUs requested for training: every visible CUDA device, or 0 without CUDA.
num_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

In [27]:
# PPO on the registered FX env with the custom torch model. The old API stack
# is selected explicitly because FXModel is a ModelV2, which belongs to RLlib's
# old (pre-RLModule) API stack.
config = (
    PPOConfig()
    .environment(env="fx_trading_env", env_config=env_config)
    .framework("torch")
    .training(
        model={"custom_model": "fx_model"},
        lr=1e-3,
        train_batch_size=2048,
    )
    # .debugging(log_level="ERROR")  # uncomment to silence RLlib's INFO logs
    .resources(num_gpus=num_gpus)
    .env_runners(
        num_env_runners=suggested_workers,
        num_cpus_per_env_runner=1,
        rollout_fragment_length="auto",
    )
    .api_stack(
        enable_env_runner_and_connector_v2=False,
        enable_rl_module_and_learner=False,
    )
)

In [28]:
# Building the Algorithm starts Ray env-runner actors and took ~25 s here.
# Stop any Algorithm left over from a previous run of this cell first.
# Otherwise its actors may keep holding CPUs/GPUs until it is garbage-collected.
_previous_trainer = globals().get("trainer")
if _previous_trainer is not None:
    _previous_trainer.stop()
trainer = config.build_algo()

2026-02-06 10:30:08,377	INFO trainable.py:161 -- Trainable.setup took 25.248 seconds. If your trainable is slow to initialize, consider setting reuse_actors=True to reduce actor creation overheads.


In [None]:

# Training loop, disabled because every iteration runs real rollouts and SGD.
# Uncomment to train the `trainer` built in the previous cell. A second
# Algorithm via `config.build()` is unnecessary, and `build()` is deprecated in
# recent Ray in favour of `build_algo()`.
#
# NOTE(review): on recent Ray versions the episode metrics are reported under
# result["env_runners"] as "episode_return_mean" / "episode_len_mean", not as
# a top-level "episode_reward_mean". Also, `save()` may return a result object
# (with `.checkpoint.path`) rather than a directory string. Verify against the
# installed Ray version.
#
# for iteration in range(100):
#     result = trainer.train()
#     if iteration % 10 == 0:
#         env_runner_metrics = result["env_runners"]
#         print(f"Iteration {iteration}: "
#               f"Mean return: {env_runner_metrics['episode_return_mean']:.3f}, "
#               f"Length: {env_runner_metrics['episode_len_mean']:.1f}")
#
# checkpoint_dir = trainer.save()
# trainer.stop()