In [1]:
import trading_gym as tg
from stable_baselines3 import A2C, PPO, DQN
from time import time

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def entrainement(env, algo=A2C, timesteps=50000):
    model = algo("MlpPolicy", env, verbose=1, tensorboard_log="logs/")
    model.learn(total_timesteps=timesteps, tb_log_name=str(algo))
    return model

In [3]:
amzn_data = tg.recuperation_donnee(start="01/01/2017", end="01/04/2022", interval="1d")
print("Nb donnée : ", len(amzn_data), amzn_data.head())
env_amz = tg.creation_env(df=amzn_data, frame_bound=(15, 1000))
t1 = time()
model_A2C = entrainement(env_amz, algo=A2C, timesteps=20_000)
t2 = time()
model_PPO = entrainement(env_amz, algo=PPO, timesteps=20_000)
t3 = time()
model_DQN = entrainement(env_amz, algo=DQN, timesteps=20_000)
t4 = time()

Nb donnée :  1260                   Open        High         Low       Close      Volume
Date                                                                  
2017-01-03  757.919983  758.760010  747.700012  753.669983  753.669983
2017-01-04  758.390015  759.679993  754.200012  757.179993  757.179993
2017-01-05  761.549988  782.400024  760.260010  780.450012  780.450012
2017-01-06  782.359985  799.440002  778.479980  795.989990  795.989990
2017-01-09  798.000000  801.770020  791.770020  796.919983  796.919983
Using cpu device
Logging to logs/<class 'stable_baselines3.a2c.a2c.A2C'>_1
------------------------------------
| time/                 |          |
|    fps                | 600      |
|    iterations         | 100      |
|    time_elapsed       | 0        |
|    total_timesteps    | 500      |
| train/                |          |
|    entropy_loss       | -0.693   |
|    explained_variance | -0.00741 |
|    learning_rate      | 0.0007   |
|    n_updates          | 99       |
|  

------------------------------------
| time/                 |          |
|    fps                | 500      |
|    iterations         | 1600     |
|    time_elapsed       | 15       |
|    total_timesteps    | 8000     |
| train/                |          |
|    entropy_loss       | -0.684   |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 1599     |
|    policy_loss        | -0.0759  |
|    value_loss         | 0.0153   |
------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 505      |
|    iterations         | 1700     |
|    time_elapsed       | 16       |
|    total_timesteps    | 8500     |
| train/                |          |
|    entropy_loss       | -0.674   |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 1699     |
|    policy_loss        | -3.7     |
|    value_loss         | 133      |
-

-------------------------------------
| time/                 |           |
|    fps                | 554       |
|    iterations         | 3200      |
|    time_elapsed       | 28        |
|    total_timesteps    | 16000     |
| train/                |           |
|    entropy_loss       | -0.684    |
|    explained_variance | -1.19e-07 |
|    learning_rate      | 0.0007    |
|    n_updates          | 3199      |
|    policy_loss        | 9.86      |
|    value_loss         | 291       |
-------------------------------------
------------------------------------
| time/                 |          |
|    fps                | 555      |
|    iterations         | 3300     |
|    time_elapsed       | 29       |
|    total_timesteps    | 16500    |
| train/                |          |
|    entropy_loss       | -0.682   |
|    explained_variance | 0        |
|    learning_rate      | 0.0007   |
|    n_updates          | 3299     |
|    policy_loss        | 3.71     |
|    value_loss         

-----------------------------------------
| time/                   |             |
|    fps                  | 714         |
|    iterations           | 7           |
|    time_elapsed         | 20          |
|    total_timesteps      | 14336       |
| train/                  |             |
|    approx_kl            | 0.011204361 |
|    clip_fraction        | 0.0886      |
|    clip_range           | 0.2         |
|    entropy_loss         | -0.591      |
|    explained_variance   | 5.96e-08    |
|    learning_rate        | 0.0003      |
|    loss                 | 3.24e+03    |
|    n_updates            | 60          |
|    policy_gradient_loss | -0.00445    |
|    value_loss           | 6.44e+03    |
-----------------------------------------
-----------------------------------------
| time/                   |             |
|    fps                  | 708         |
|    iterations           | 8           |
|    time_elapsed         | 23          |
|    total_timesteps      | 16384 

In [5]:
model_A2C.save("A2C")
model_PPO.save("PPO")
model_DQN.save("DQN")

In [7]:
temps_A2C = t2 - t1
temps_PP0 = t3 - t2
temps_DQN = t4 - t3
temps_A2C, temps_DQN, temps_PP0

(35.660797357559204, 2.4179999828338623, 30.659951210021973)