<a target="_blank" href="https://colab.research.google.com/github/AI4Finance-Foundation/FinRL-Tutorials/blob/master/1-Introduction/China_A_share_market_tushare.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

## Quantitative trading in China A stock market with FinRL

Install FinRL

Install other libraries

In [1]:
!pip install stockstats
!pip install tushare


Looking in indexes: https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
Looking in indexes: https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple


## Import Modules

In [2]:
import warnings

warnings.filterwarnings("ignore")

import pandas as pd 
from IPython import display



from meta import config 
from meta.data_processor import DataProcessor 
from main import check_and_make_directories 
from meta.data_processors.tushare import Tushare, ReturnPlotter 
from meta.env_stock_trading.env_stocktrading_China_A_shares import StockTradingEnv 
from agents.stablebaselines3_models import DRLAgent 
import os 
from typing import List 
from argparse import ArgumentParser 
from meta import config 
from meta.config_tickers import DOW_30_TICKER 
from meta.config import ( DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR, INDICATORS, TRAIN_START_DATE, TRAIN_END_DATE, TEST_START_DATE, TEST_END_DATE, TRADE_START_DATE, TRADE_END_DATE, ERL_PARAMS, RLlib_PARAMS, SAC_PARAMS, ALPACA_API_KEY, ALPACA_API_SECRET, ALPACA_API_BASE_URL, )

import pyfolio
from pyfolio import timeseries

pd.options.display.max_columns = None

print("ALL Modules have been imported!")

ALL Modules have been imported!


## Create Folders

In [3]:
import os

''' 
use check_and_make_directories() to replace the following

if not os.path.exists("./datasets"): 
  os.makedirs("./datasets") 
if not os.path.exists("./trained_models"): 
  os.makedirs("./trained_models") 
if not os.path.exists("./tensorboard_log"): 
  os.makedirs("./tensorboard_log") 
if not os.path.exists("./results"): 
  os.makedirs("./results") 
'''

check_and_make_directories([DATA_SAVE_DIR, TRAINED_MODEL_DIR, TENSORBOARD_LOG_DIR, RESULTS_DIR])

## Download data, cleaning and feature engineering

In [4]:
ticker_list = ['600000.SH', '600009.SH', '600016.SH', '600028.SH', '600030.SH', '600031.SH', '600036.SH', '600050.SH', '600104.SH', '600196.SH', '600276.SH', '600309.SH', '600519.SH', '600547.SH', '600570.SH']

TRAIN_START_DATE = '2015-01-01' 
TRAIN_END_DATE= '2019-08-01' 
TRADE_START_DATE = '2019-08-01' 
TRADE_END_DATE = '2020-01-03'

TIME_INTERVAL = "1d" 
kwargs = {} 
kwargs['token'] = '27080ec403c0218f96f388bca1b1d85329d563c91a43672239619ef5' 
p = DataProcessor(data_source='tushare', start_date=TRAIN_START_DATE, end_date=TRADE_END_DATE, time_interval=TIME_INTERVAL, **kwargs)

tushare successfully connected


### Download and Clean

In [5]:
p.download_data(ticker_list=ticker_list)
p.clean_data()
p.fillna()

100%|██████████| 15/15 [00:06<00:00,  2.15it/s]

Download complete! Dataset saved to ./data/dataset.csv. 
Shape of DataFrame: (17960, 8)
Shape of DataFrame:  (18315, 8)





### Add technical indicator

In [6]:
p.add_technical_indicator(config.INDICATORS) 
p.fillna()

#print(f"p.dataframe: {p.dataframe}")

tech_indicator_list:  ['macd', 'boll_ub', 'boll_lb', 'rsi_30', 'cci_30', 'dx_30', 'close_30_sma', 'close_60_sma']
indicator:  macd
indicator:  boll_ub
indicator:  boll_lb
indicator:  rsi_30
indicator:  cci_30
indicator:  dx_30
indicator:  close_30_sma
indicator:  close_60_sma
Succesfully add technical indicators
Shape of DataFrame:  (18270, 17)


## Split training dataset

In [7]:
train = p.data_split(p.dataframe, TRAIN_START_DATE, TRAIN_END_DATE) 

print(f"len(train.tic.unique()): {len(train.tic.unique())}")

len(train.tic.unique()): 15


In [8]:
print(f"train.tic.unique(): {train.tic.unique()}")

train.tic.unique(): ['600000.SH' '600009.SH' '600016.SH' '600028.SH' '600030.SH' '600031.SH'
 '600036.SH' '600050.SH' '600104.SH' '600196.SH' '600276.SH' '600309.SH'
 '600519.SH' '600547.SH' '600570.SH']


In [9]:
print(f"train.head(): {train.head()}")

train.head():          tic        time  index   open   high    low  close  adjusted_close  \
0  600000.SH  2015-01-08     45  15.87  15.88  15.20  15.25           15.25   
0  600009.SH  2015-01-08     46  20.18  20.18  19.73  20.00           20.00   
0  600016.SH  2015-01-08     47  10.61  10.66  10.09  10.20           10.20   
0  600028.SH  2015-01-08     48   7.09   7.41   6.83   6.85            6.85   
0  600030.SH  2015-01-08     49  36.40  36.70  34.68  35.25           35.25   

       volume      macd    boll_ub    boll_lb     rsi_30      cci_30  \
0  3306271.72 -0.032571  16.617911  15.012089   6.058641 -125.593009   
0   198117.45 -0.016008  20.663897  19.736103  12.828915  -90.842491   
0  4851684.17 -0.018247  10.957604   9.997396  11.862558  -99.887006   
0  8190902.35 -0.008227   7.342000   6.743000  27.409248   36.578171   
0  6376268.69  0.032910  36.576444  33.808556  61.517448   47.947020   

        dx_30  close_30_sma  close_60_sma  
0   23.014040       15.8150       

In [10]:
print(f"train.shape: {train.shape}")

train.shape: (16695, 17)


In [11]:
stock_dimension = len(train.tic.unique()) 
state_space = stock_dimension * (len(config.INDICATORS) + 2) + 1 

print(f"Stock Dimension: {stock_dimension}, State Space: {state_space}")

Stock Dimension: 15, State Space: 151


## Train

In [12]:
env_kwargs = { "stock_dim": stock_dimension, "hmax": 1000, "initial_amount": 1000000, "buy_cost_pct": 6.87e-5, "sell_cost_pct": 1.0687e-3, "reward_scaling": 1e-4, "state_space": state_space, "action_space": stock_dimension, "tech_indicator_list": config.INDICATORS, "print_verbosity": 1, "initial_buy": True, "hundred_each_trade": True }

e_train_gym = StockTradingEnv(df=train, **env_kwargs)

In [13]:
env_train, _ = e_train_gym.get_sb_env() 

print(f"print(type(env_train)): {print(type(env_train))}")

<class 'stable_baselines3.common.vec_env.dummy_vec_env.DummyVecEnv'>
print(type(env_train)): None


### DDPG

In [14]:
import os
from stable_baselines3 import DDPG # Import DDPG class for loading
# ... existing code ...

agent = DRLAgent(env=e_train_gym) # Use e_train_gym directly as it's the unwrapped env

# Define model save path
model_save_path = os.path.join(config.TRAINED_MODEL_DIR, 'ddpg_model.zip')

# Check if model exists, if so, load it
if os.path.exists(model_save_path):
    print(f"Loading existing model from {model_save_path}")
    # Pass env to the loaded model. Note: stable_baselines3 load method handles wrapping if needed
    trained_ddpg = DDPG.load(model_save_path, env=e_train_gym)
else:
    print(f"No existing model found at {model_save_path}. Training new model...")
    DDPG_PARAMS = { "batch_size": 256, "buffer_size": 50000, "learning_rate": 0.0005, "action_noise": "normal", }
    POLICY_KWARGS = dict(net_arch=dict(pi=[64, 64], qf=[400, 300]))
    # Get the model instance (unwrapped env is fine here)
    model_ddpg = agent.get_model("ddpg", model_kwargs=DDPG_PARAMS, policy_kwargs=POLICY_KWARGS)

    # Train the model using the wrapped environment from get_sb_env()
    env_train_wrapped, _ = e_train_gym.get_sb_env()
    trained_ddpg = agent.train_model(model=model_ddpg, tb_log_name='ddpg', total_timesteps=10000)
    # Save the trained model
    trained_ddpg.save(model_save_path)
    print(f"Trained model saved to {model_save_path}")

# print(f"print(type(env_train)): {print(type(env_train))}")

Loading existing model from trained_models\ddpg_model.zip
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


### A2C

In [15]:
import os
from stable_baselines3 import A2C # Import A2C class for loading
# ... existing code ...

agent = DRLAgent(env=e_train_gym) # Use e_train_gym directly as it's the unwrapped env

# Define model save path
model_save_path = os.path.join(config.TRAINED_MODEL_DIR, 'a2c_model.zip')

# Check if model exists, if so, load it
if os.path.exists(model_save_path):
    print(f"Loading existing model from {model_save_path}")
    trained_a2c = A2C.load(model_save_path, env=e_train_gym) # Pass env for the loaded model
else:
    print(f"No existing model found at {model_save_path}. Training new model...")
    # Get the model instance (unwrapped env is fine here)
    model_a2c = agent.get_model("a2c")

    # Train the model using the wrapped environment from get_sb_env()
    env_train_wrapped, _ = e_train_gym.get_sb_env()
    trained_a2c = agent.train_model(model=model_a2c, tb_log_name='a2c', total_timesteps=50000)
    # Save the trained model
    trained_a2c.save(model_save_path)
    print(f"Trained model saved to {model_save_path}")

# print(f"print(type(env_train)): {print(type(env_train))}")




Loading existing model from trained_models\a2c_model.zip
Wrapping the env with a `Monitor` wrapper
Wrapping the env in a DummyVecEnv.


## Trade

In [16]:
trade = p.data_split(p.dataframe, TRADE_START_DATE, TRADE_END_DATE) 
env_kwargs = { "stock_dim": stock_dimension, "hmax": 1000, "initial_amount": 1000000, "buy_cost_pct": 6.87e-5, "sell_cost_pct": 1.0687e-3, "reward_scaling": 1e-4, "state_space": state_space, "action_space": stock_dimension, "tech_indicator_list": config.INDICATORS, "print_verbosity": 1, "initial_buy": False, "hundred_each_trade": True } 
e_trade_gym = StockTradingEnv(df=trade, **env_kwargs)

In [17]:
df_account_value, df_actions = DRLAgent.DRL_prediction(model=trained_ddpg, environment=e_trade_gym)

Episode: 2
day: 103, episode: 2
begin_total_asset: 1000000.00
end_total_asset: 1226184.32
total_reward: 226184.32
total_cost: 68.68
total_trades: 4
Sharpe: 2.167
hit end!


In [18]:
df_actions.to_csv("action.csv", index=False) 
print(f"df_actions: {df_actions}")

df_actions:             600000.SH  600009.SH  600016.SH  600028.SH  600030.SH  600031.SH  \
date                                                                           
2019-08-01          0          0          0          0          0          0   
2019-08-01          0          0          0          0          0          0   
2019-08-02          0          0          0          0          0          0   
2019-08-05          0          0          0          0          0          0   
2019-08-06          0          0          0          0          0          0   
...               ...        ...        ...        ...        ...        ...   
2019-12-24          0          0          0          0          0          0   
2019-12-25          0          0          0          0          0          0   
2019-12-26          0          0          0          0          0          0   
2019-12-27          0          0          0          0          0          0   
2019-12-30          0       

In [19]:
print("df_account_value info:")
print(df_account_value.info())
print("\ndf_account_value head:")
print(df_account_value.head())
print("\ndf_account_value tail:")
print(df_account_value.tail())


print("\ntrade info:")
print(trade.info())
print("\ntrade head:")
print(trade.head())
print("\ntrade tail:")
print(trade.tail())


df_account_value info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 103 entries, 0 to 102
Data columns (total 2 columns):
 #   Column         Non-Null Count  Dtype         
---  ------         --------------  -----         
 0   date           103 non-null    datetime64[ns]
 1   account_value  103 non-null    float64       
dtypes: datetime64[ns](1), float64(1)
memory usage: 1.7 KB
None

df_account_value head:
        date  account_value
0 2019-08-01   999931.31958
1 2019-08-02   992453.31958
2 2019-08-05   977735.31958
3 2019-08-06   979976.31958
4 2019-08-07   978770.31958

df_account_value tail:
          date  account_value
98  2019-12-25   1.176832e+06
99  2019-12-26   1.180060e+06
100 2019-12-27   1.203224e+06
101 2019-12-30   1.226566e+06
102 2019-12-31   1.226184e+06

trade info:
<class 'pandas.core.frame.DataFrame'>
Index: 1560 entries, 0 to 103
Data columns (total 17 columns):
 #   Column          Non-Null Count  Dtype  
---  ------          --------------  -----  
 0   

## Backtest

### matplotlib inline

In [24]:
plotter = ReturnPlotter(df_account_value, trade, TRADE_START_DATE, TRADE_END_DATE, token='27080ec403c0218f96f388bca1b1d85329d563c91a43672239619ef5')
plotter.plot()

In [25]:
# ticket: SSE 50：000016
plotter.plot("000016")

Error fetching baseline data for 000016 (000016.SH, asset=I) from Tushare Pro: 请指定正确的接口名
Error: Could not fetch baseline data for ticket 000016.


### CSI 300

In [26]:
baseline_df = plotter.get_baseline("399300")

Error fetching baseline data for 399300 (399300, asset=E) from Tushare Pro: 请指定正确的接口名


In [23]:
daily_return = plotter.get_return(df_account_value)
daily_return_base = plotter.get_return(baseline_df, value_col_name="close")

perf_func = timeseries.perf_stats 
perf_stats_all = perf_func(returns=daily_return, factor_returns=daily_return_base, positions=None, transactions=None, turnover_denom="AGB")
print("==============DRL Strategy Stats===========")
print(f"perf_stats_all: {perf_stats_all}")

KeyError: 'time'

In [24]:
daily_return = plotter.get_return(df_account_value)
daily_return_base = plotter.get_return(baseline_df, value_col_name="close")

perf_func = timeseries.perf_stats
perf_stats_all = perf_func(returns=daily_return_base, factor_returns=daily_return_base, positions=None, transactions=None, turnover_denom="AGB")

print("==============Baseline Strategy Stats===========")

print(f"perf_stats_all: {perf_stats_all}")

KeyError: 'time'