In [1]:
import io
from typing import Dict, Type

import pandas as pd
from gymnasium import Env
from sb3_contrib import ARS, QRDQN, TQC, TRPO, RecurrentPPO
from stable_baselines3 import A2C, DDPG, DQN, PPO, SAC, TD3
from stable_baselines3.common.base_class import BaseAlgorithm

from app.configuration.config import TinvestSettings, PgSqlSettings
from app.models.data.candle import Interval
from app.models.feature_composer import FeatureComposer
from app.models.gym_env_factory import GymEnvFactory, register_single_asset_trading_env
from app.models.predictor_base import PredictorRandomWalk, PredictorTrendFollowing
from app.models.predictor_sb3 import PredictorSb3
from app.repositories.sb3_models import Sb3ModelsRepository
from app.repositories.tinvest_candles import TinvestCandlesRepository
from app.services.tinvest import TinvestAdapter

In [2]:
# --- Configuration -----------------------------------------------------------
# Both settings objects are populated from the same dotenv file, so the path is
# spelled out once.
ENV_FILE = '../app/.env'

# T-Invest API adapter.
tinvest_config = TinvestSettings(_env_file=ENV_FILE)
tinvest = TinvestAdapter(tinvest_config)

# PostgreSQL-backed repositories: one for candles, one for serialized SB3 models.
sql_config = PgSqlSettings(_env_file=ENV_FILE)
candles_repo = TinvestCandlesRepository(sql_config)
models_repo = Sb3ModelsRepository(sql_config)

# Lookup table from an algorithm's short name to the class that implements it.
# Used later to pick the right loader for a stored model.
ALGOS: Dict[str, Type[BaseAlgorithm]] = {
    # Stable-Baselines3 core algorithms
    "a2c": A2C,
    "ddpg": DDPG,
    "dqn": DQN,
    "ppo": PPO,
    "sac": SAC,
    "td3": TD3,
    # SB3 Contrib algorithms
    "ars": ARS,
    "qrdqn": QRDQN,
    "tqc": TQC,
    "trpo": TRPO,
    "ppo_lstm": RecurrentPPO,
}

# Registers the custom trading environment; the call reports the registered
# EnvSpec (id 'SingleAssetTrading-v1' in the recorded output).
register_single_asset_trading_env()

Environment registered. EnvSpec(id='SingleAssetTrading-v1', entry_point='app.models.gym_env_single_asset:SingleAssetTrading', reward_threshold=None, nondeterministic=False, max_episode_steps=None, order_enforce=True, autoreset=False, disable_env_checker=True, apply_api_compatibility=False, kwargs={}, namespace=None, name='SingleAssetTrading', version=1, additional_wrappers=(), vector_entry_point=None)


In [3]:
# Pull a short window of SBER 10-minute candles from the repository.
# Alternative call kept for reference:
#   candles_repo.get_candles("SBER", Interval.CANDLE_INTERVAL_10_MIN.value)
df = candles_repo.get_last_candles(
    symbol="SBER",
    interval=Interval.CANDLE_INTERVAL_10_MIN.value,
    count=20,
)

# Turn raw candles into the feature frame used by the environment and the
# baseline predictors below (feature_* / info_* columns in the recorded output).
fk = FeatureComposer(fill_missing_values=True)
df = fk.compose(df)

# Display the composed frame (inspect df.shape here if the row count matters).
df

Unnamed: 0,open,close,feature_close_return,feature_ret_m1,feature_ret_m2,feature_ret_m3,feature_ret_m4,feature_ret_m5,feature_ret_m6,feature_ret_m7,feature_ret_m8,info_intraday_return
2025-02-06 09:20:00+00:00,287.39,287.15,-0.24,0.56,0.03,-0.44,-0.51,0.34,0.18,0.03,-0.72,-0.24
2025-02-06 09:30:00+00:00,287.15,287.34,0.19,-0.24,0.56,0.03,-0.44,-0.51,0.34,0.18,0.03,0.19
2025-02-06 09:40:00+00:00,287.33,286.8,-0.54,0.19,-0.24,0.56,0.03,-0.44,-0.51,0.34,0.18,-0.53
2025-02-06 09:50:00+00:00,286.8,286.43,-0.37,-0.54,0.19,-0.24,0.56,0.03,-0.44,-0.51,0.34,-0.37
2025-02-06 10:00:00+00:00,286.43,286.28,-0.15,-0.37,-0.54,0.19,-0.24,0.56,0.03,-0.44,-0.51,-0.15
2025-02-06 10:10:00+00:00,286.28,286.26,-0.02,-0.15,-0.37,-0.54,0.19,-0.24,0.56,0.03,-0.44,-0.02
2025-02-06 10:20:00+00:00,286.22,286.52,0.26,-0.02,-0.15,-0.37,-0.54,0.19,-0.24,0.56,0.03,0.3
2025-02-06 10:30:00+00:00,286.52,285.94,-0.58,0.26,-0.02,-0.15,-0.37,-0.54,0.19,-0.24,0.56,-0.58
2025-02-06 10:40:00+00:00,285.9,286.23,0.29,-0.58,0.26,-0.02,-0.15,-0.37,-0.54,0.19,-0.24,0.33
2025-02-06 10:50:00+00:00,286.22,286.17,-0.06,0.29,-0.58,0.26,-0.02,-0.15,-0.37,-0.54,0.19,-0.05


In [4]:
# Build the gym environment on top of the composed feature frame `df`.
env_factory = GymEnvFactory(df)
# create_env() returns a pair: the environment and a value bound to total_steps.
env, total_steps = env_factory.create_env()
# Optional sanity check — uncomment to reset the env and inspect the first observation:
#observation, info = env.reset()
#observation, info

In [5]:
# Load a persisted SB3 model from the models repository and deserialize it
# with the algorithm class recorded alongside it.
file_name = "dqn_model.zip"
last_model = models_repo.get_model_by_file_name(file_name)

# Fail fast with a readable message instead of an AttributeError further down.
# NOTE(review): assumes the repository signals "not found" by returning None — confirm.
if last_model is None:
    raise LookupError(f"No stored model named {file_name!r} in the models repository")

# The record's `algo` value selects the loader class; an unknown value would
# otherwise surface as a bare KeyError with no hint about the valid choices.
if last_model.algo not in ALGOS:
    raise KeyError(
        f"Unsupported algorithm {last_model.algo!r}; expected one of {sorted(ALGOS)}"
    )

# Wrap the stored bytes in an in-memory buffer so load() can read them
# as a file-like object without writing a temporary file.
fio = io.BytesIO(last_model.content)
model = ALGOS[last_model.algo].load(fio)


In [6]:
# Pair the environment with the loaded model; the cells below query this predictor.
sb3_predictor = PredictorSb3(env, model)

In [7]:
# Ask the SB3 predictor for its latest prediction only.
# The recorded output is a (timestamp, action) tuple whose timestamp (11:00 UTC)
# lies one 10-minute step after the last row of `df` (10:50 UTC).
prediction = sb3_predictor.predict_last()
print(prediction)

(Timestamp('2025-02-06 11:00:00+0000', tz='UTC'), 2)


In [8]:
# Run the SB3 predictor over the whole frame; the result is displayed as a
# timestamp-indexed table with a single `action` column.
prediction_df = sb3_predictor.predict_all()
prediction_df

Unnamed: 0,action
2025-02-06 09:20:00+00:00,0
2025-02-06 09:30:00+00:00,0
2025-02-06 09:40:00+00:00,0
2025-02-06 09:50:00+00:00,1
2025-02-06 10:00:00+00:00,1
2025-02-06 10:10:00+00:00,0
2025-02-06 10:20:00+00:00,2
2025-02-06 10:30:00+00:00,2
2025-02-06 10:40:00+00:00,2
2025-02-06 10:50:00+00:00,2


In [22]:
# Baseline for comparison with the SB3 model: PredictorRandomWalk over the same
# feature frame. The class is imported in the notebook's top import cell.
# NOTE(review): the second positional argument (7) is presumably a random seed —
# confirm against PredictorRandomWalk's signature and pass it by keyword.
rw_predictor = PredictorRandomWalk(df, 7)
rw_df = rw_predictor.predict_all()
rw_df

Unnamed: 0,action
2025-02-06 09:20:00+00:00,1
2025-02-06 09:30:00+00:00,0
2025-02-06 09:40:00+00:00,1
2025-02-06 09:50:00+00:00,2
2025-02-06 10:00:00+00:00,0
2025-02-06 10:10:00+00:00,0
2025-02-06 10:20:00+00:00,2
2025-02-06 10:30:00+00:00,0
2025-02-06 10:40:00+00:00,1
2025-02-06 10:50:00+00:00,2


In [12]:
# Second baseline: PredictorTrendFollowing over the same feature frame, with a
# threshold of 0.01. The class is imported in the notebook's top import cell.
# NOTE(review): the unit of `threshold` is not visible here — confirm in predictor_base.
tf_predictor = PredictorTrendFollowing(df, threshold=0.01)
tf_df = tf_predictor.predict_all()
tf_df

Unnamed: 0,action
2025-02-06 09:20:00+00:00,1
2025-02-06 09:30:00+00:00,2
2025-02-06 09:40:00+00:00,1
2025-02-06 09:50:00+00:00,1
2025-02-06 10:00:00+00:00,1
2025-02-06 10:10:00+00:00,1
2025-02-06 10:20:00+00:00,2
2025-02-06 10:30:00+00:00,1
2025-02-06 10:40:00+00:00,2
2025-02-06 10:50:00+00:00,1
