本文档用于记录使用决策树进行因子挖掘的具体思路和开发过程

决策树因子挖掘

整体框架：
1. 数据来源：数据存储在字典中{data_type:{target_name: dataframe}}
2. 因子计算：随机选择运算方式，根据运算方式决定的datatype和特征数量从数据字典中随机选择数据
3. 因子回测
4. 树的优化算法（遗传算法或者论文方法）

功能设计
1. datareader: 从数据库中读取数据（已完成）
2. tree_builder:
3. 数据分类
4. 树的hash化比对

In [1]:
import pandas as pd
import numpy as np
from tqdm import tqdm
import random
import time
from abc import abstractmethod

# tree plotting
from pyecharts import options as opts
from pyecharts.charts import Tree as TreePloter

In [2]:
from const import *
from classes.database_classes.data_reader import CrossPriceDataReader, CrossValuationDataReader
from classes.back_tester_classes.back_tester import SimpleBackTester, CompletedBackTester

auth success 


In [3]:
# Load the data: one reader for the price tables, one for the valuation tables.
# NOTE: "valuadation" is a misspelling of "valuation"; the name is kept
# because later cells reference it.
price_data_reader = CrossPriceDataReader()
valuadation_data_reader = CrossValuationDataReader()

In [4]:
# Show the table names the price reader exposes (see the output below).
price_data_reader.get_all_tables_names()

['cross___price___close.csv',
 'cross___price___high.csv',
 'cross___price___low.csv',
 'cross___price___money.csv',
 'cross___price___open.csv',
 'cross___price___volume.csv']

In [5]:
# Show the table names the valuation reader exposes (see the output below).
valuadation_data_reader.get_all_tables_names()

['cross___valuation___capitalization.csv',
 'cross___valuation___circulating_cap.csv',
 'cross___valuation___circulating_market_cap.csv',
 'cross___valuation___code.csv',
 'cross___valuation___day.csv',
 'cross___valuation___market_cap.csv',
 'cross___valuation___pb_ratio.csv',
 'cross___valuation___pcf_ratio.csv',
 'cross___valuation___pe_ratio.csv',
 'cross___valuation___pe_ratio_lyr.csv',
 'cross___valuation___ps_ratio.csv',
 'cross___valuation___turnover_ratio.csv']

In [6]:
# NOTE(review): the variable is called data_pe_ratio but the table requested
# is "turnover_ratio" - confirm which one was intended.
data_pe_ratio = valuadation_data_reader.get_one_table("turnover_ratio")

In [7]:
# Display the table loaded in the previous cell.
data_pe_ratio

Unnamed: 0_level_0,000001.XSHE,000002.XSHE,000004.XSHE,000005.XSHE,000006.XSHE,000007.XSHE,000008.XSHE,000009.XSHE,000010.XSHE,000011.XSHE,...,688787.XSHG,688788.XSHG,688789.XSHG,688793.XSHG,688798.XSHG,688799.XSHG,688800.XSHG,688819.XSHG,688981.XSHG,689009.XSHG
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2010-01-04,0.8273,1.0044,0.0000,2.4469,1.2770,1.4004,0.0000,1.7571,2.8331,2.5948,...,,,,,,,,,,
2010-01-05,1.9031,1.9145,0.0000,7.0507,2.6615,2.0441,0.0000,3.1842,2.5995,4.9131,...,,,,,,,,,,
2010-01-06,1.4095,1.4070,0.0000,5.4067,2.8047,2.3576,0.0000,2.3690,2.2352,4.3932,...,,,,,,,,,,
2010-01-07,1.2152,1.1935,0.0000,3.4410,2.0726,1.4836,0.0000,4.7855,3.2500,3.8250,...,,,,,,,,,,
2010-01-08,0.9868,1.1240,0.0000,1.8374,1.8194,0.8402,0.0000,5.6043,1.6254,2.7161,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023-07-10,0.2555,0.4127,19.9262,0.5483,0.5529,1.3400,0.8218,0.4678,0.9800,1.0308,...,4.3250,0.4546,0.5792,1.2156,1.4758,1.9588,3.4828,2.0402,0.6115,0.7045
2023-07-11,0.2941,0.6836,22.8382,0.6325,0.6318,3.1915,0.9675,0.3109,0.9267,0.4697,...,4.0306,0.5004,0.8591,0.8159,1.7401,1.3923,4.2388,1.6531,1.0049,0.8776
2023-07-12,0.3793,0.7533,19.9136,1.5197,0.6243,3.1354,0.7353,0.3926,1.3962,0.4475,...,5.4227,0.8539,1.0472,1.3565,0.7148,1.4287,2.8804,2.1860,1.9604,1.1579
2023-07-13,0.5793,0.5079,14.0237,0.5676,0.5281,1.5316,0.5336,0.3231,1.4113,0.3340,...,4.4785,0.7060,1.0270,2.4540,1.2102,1.5746,3.0505,2.7656,1.2812,0.7776


#### 模拟数据字典准备：
1. 仅使用价格数据中的6个特征
2. 所有数据的类别均为any

数据字典使用方法
1. 在data_type_dict中根据数据类型随机选择数据
2. 在data_dict中获取数据

In [8]:
# Names of the price tables used as raw features.
features_list = ['close', 'high', 'low', 'money', 'open', 'volume']

# Every feature is filed under the single data type "any".
data_type_dict = {"any": features_list}

# feature name -> table returned by the price reader for that feature.
data_dict = {
    feature_name: price_data_reader.get_one_table(feature=feature_name)
    for feature_name in features_list
}

#### 计算类的定义
1. 将输入数据的标准放在构造函数中
2. 随机选择数据后将选择的数据（maybe只是数据名称）保存在类中，保持树的稳定性
3. 计算方法名称、输入数据名称和计算结果均保存在外层的树中，计算类仅用作运算和提供数据标准

In [9]:
class CalculatorTemplate(object):
    """Base class for one calculation step (operator) of a factor tree.

    Subclasses override the class attributes below to declare what they
    consume and produce, and implement ``calculate``.  The getters only read
    those attributes, so the object carries no per-instance state of its own.
    """

    # Display name of the operator; subclasses set a string.
    method_name = None

    # Number of data inputs ``calculate`` takes.
    data_number = 1
    # Data-type tag of each data input, left input first; subclasses set a list.
    input_data_type = None
    # Data-type tag(s) of the result.  Declared on the template so that
    # ``get_output_data_type`` never fails with AttributeError on a subclass
    # that does not set it.
    output_data_type = None

    # Names of the extra (non-data) arguments ``calculate`` takes.
    # Shared class-level list: treat it as read-only and override it in
    # subclasses instead of mutating it.
    parameters_list = []

    def get_method_name(self):
        """Return the operator's display name."""
        return self.method_name

    def get_data_number(self):
        """Return how many data inputs the operator takes."""
        return self.data_number

    def get_left_input_data_type(self):
        """Return the data-type tag of the first (left) input."""
        return self.input_data_type[0]

    def get_right_input_data_type(self):
        """Return the data-type tag of the second (right) input.

        Falls back to ``"any"`` unless exactly two input types are declared.
        """
        if len(self.input_data_type) == 2:
            return self.input_data_type[1]
        else:
            return "any"

    def get_output_data_type(self):
        """Return the declared output data type (``None`` if not declared)."""
        return self.output_data_type

    # NOTE: this class does not use an ABCMeta metaclass, so the decorator
    # documents intent but does not prevent instantiation.
    @abstractmethod
    def calculate(self):
        """Compute the operator's result; implemented by subclasses."""
        pass

In [10]:
class AddCalculator(CalculatorTemplate):
    """Two-input operator that adds its inputs together."""

    # Operator description read through the CalculatorTemplate getters.
    method_name = "add"
    data_number = 2
    input_data_type = ["any", "any"]
    output_data_type = ["any"]

    def calculate(self, data_1, data_2):
        """Return ``data_1 + data_2``."""
        total = data_1 + data_2
        return total
    

class MovingaverageCalculator(CalculatorTemplate):
    """Single-input operator that computes a rolling mean.

    ``data_number`` is inherited from the template (one data input); the
    window size is supplied through the ``time_window`` parameter.
    """

    method_name = "moving average"
    input_data_type = ["any"]
    # Must be spelled ``output_data_type``: that is the attribute
    # ``get_output_data_type`` reads.
    output_data_type = ["any"]
    parameters_list = ["time_window"]

    def calculate(self, data_1: pd.DataFrame, time_window):
        """Return the rolling mean of ``data_1`` over ``time_window``.

        Args:
            data_1: input table.
            time_window: window passed straight to ``DataFrame.rolling``.
        """
        return data_1.rolling(time_window).mean()

In [11]:
# Operator classes available for random selection.
calculate_method_list = [AddCalculator, MovingaverageCalculator]

In [12]:
# Candidate values for the "time_window" operator parameter.
time_window_list = [1, 3, 5, 10, 15, 30, 50, 100, 200]

#### ParametersProvider 定义

* 内置变量
 1. 运算方法空间
 2. 数据空间
 3. 其他参数：时间窗口期等

* 成员函数
1. 基本函数：
   1. 选择随机运算方法
   2. 随机获取数据: if left_node(right_node) is None
   3. 获取随机参数：每种参数对应一个成员函数


In [13]:
class DataProvider(object):
    """
    Leaf wrapper: gives a plain data table the same ``calculate()`` entry
    point as a computed node, so both can be used interchangeably.
    """
    def __init__(self, feature_name, data) -> None:
        # Keep the label together with the payload it describes.
        self.name, self.data = feature_name, data

    def calculate(self):
        """Hand back the wrapped data unchanged."""
        payload = self.data
        return payload

In [14]:
class ParametersProvier(object):
    """Random source of operators, input data and operator parameters.

    All state lives in class attributes that alias the module-level search
    space, and every method is a classmethod, so the class is used without
    being instantiated.
    """
    # parameter space
    calculator_methods = calculate_method_list

    data_type_dict = data_type_dict
    data_dict = data_dict

    time_window_list = time_window_list

    @classmethod
    def pick_calculate_method(cls):
        """Return a new instance of a randomly chosen operator class."""
        return random.choice(cls.calculator_methods)()

    @classmethod
    def get_data(cls, data_type: str="any"):
        """Return a ``DataProvider`` for a random feature of ``data_type``.

        Raises:
            KeyError: if ``data_type`` is not a key of ``data_type_dict``.
        """
        data_name = random.choice(cls.data_type_dict[data_type])
        return DataProvider(feature_name=data_name, data=cls.data_dict[data_name])

    @classmethod
    def get_time_window_parameter(cls, time_window_list: list=None):
        """Return a random window size as ``int``.

        Args:
            time_window_list: candidates to draw from; when ``None`` or empty
                the class-level ``time_window_list`` is used.
        """
        time_window_list = time_window_list or cls.time_window_list
        return int(random.choice(time_window_list))

    @classmethod
    def get_parameters(cls, parameter_type: str):
        """Return a random value for the parameter named ``parameter_type``.

        Raises:
            KeyError: if ``parameter_type`` is not a supported parameter name.
        """
        if parameter_type == "time_window":
            return cls.get_time_window_parameter()

        # Build the message with plain concatenation: a backslash continuation
        # inside the literal would embed the next line's indentation in it.
        raise KeyError(
            "The parameter_type {} is wrong, "
            "please have a check".format(parameter_type)
        )

In [15]:
# Smoke test: draw one random operator instance.
ParametersProvier.pick_calculate_method()

<__main__.AddCalculator at 0x205541a3730>

#### 树的定义
1. 需要决定计算方式（计算类）和输入的数据
2. 需要解决的问题：随机抽取的数据数量（可能从下一棵树上直接获取）
3. 具有迭代功能的calculate函数

算法实现顺序：
1. 先完成树生长
2. 完成每一部分的算法选择
3. 根据每一部分的算法选择和树结构，补全剩余参数

In [22]:
class Tree(object):
    """One node of a randomly generated factor-expression tree.

    A node holds an operator (``calculate_method``), a left input, an optional
    right input and the operator's extra parameters.  Inputs are either
    further ``Tree`` nodes or ``DataProvider`` leaves.

    Typical use: ``tree_growing()`` to build the structure,
    ``tree_initialize()`` to pick operators, data and parameters, then
    ``calculate()`` / ``tree_plot()``.
    """

    def __init__(self, tree_depth: int=None, max_depth: int=7) -> None:
        """
        Args:
            tree_depth: depth of this node; ``None`` (or 0) means the root,
                stored as depth 1.
            max_depth: deepest level at which ``Tree`` nodes are created.
        """
        self.tree_depth = tree_depth or 1
        self.max_depth = max_depth

        self.calculate_method = None
        self.left_node = None
        self.right_node = None
        self.parameter_list = []

    def _build_tree(self):
        """
        used to build a new tree in the left or right node
        the depth for the new tree would add 1,
        while the max_depth would remain the same
        """
        return Tree(tree_depth=self.tree_depth+1, max_depth=self.max_depth)

    def tree_growing(self):
        """
        used to grow the structure for the whole tree

        Only the left branch is grown; right inputs are filled with data
        leaves later, in ``_node_initialize``.
        """
        if self.tree_depth < (self.max_depth):
            self.left_node = self._build_tree()

        if self.tree_depth < (self.max_depth-1):
            self.left_node.tree_growing()

    def _node_initialize(self):
        """
        used to decide the content for one node:
        1. calculate method
        2. choose the data for both nodes
        3. update the parameters
        """
        # calculate method
        self.calculate_method = ParametersProvier.pick_calculate_method()
        method = self.calculate_method

        # data for both nodes: only fill the slots that have no sub-tree/data yet
        if self.left_node is None:
            self.left_node = ParametersProvier.get_data(
                data_type=method.get_left_input_data_type())
        if method.get_data_number() == 2 and self.right_node is None:
            self.right_node = ParametersProvier.get_data(
                data_type=method.get_right_input_data_type())

        # Always overwrite the parameters, even with an empty tuple, so that
        # values drawn for a previous operator cannot leak into a new one when
        # the node is initialised again.
        self.parameter_list = tuple(
            ParametersProvier.get_parameters(parameter_type=parameter)
            for parameter in method.parameters_list
        )

    def tree_initialize(self):
        """
        used to decide the content for all nodes
        """
        self._node_initialize()

        # Data leaves need no initialisation; recurse into sub-trees only.
        for node in (self.left_node, self.right_node):
            if isinstance(node, Tree):
                node.tree_initialize()

    def calculate(self):
        """
        used to calculate the result for the whole tree

        Inputs are evaluated recursively, then passed to the operator followed
        by the node's extra parameters.
        """
        inputs = [self.left_node.calculate()]
        if self.calculate_method.get_data_number() != 1:
            inputs.append(self.right_node.calculate())
        return self.calculate_method.calculate(*inputs, *self.parameter_list)

    # description of the structure of the tree in dict

    def _node_plot(self):
        """Return this sub-tree as a nested ``{"name", "children"}`` dict.

        Each child is itself a dict (a ``{"name": ...}`` leaf for data, a
        nested description for a sub-tree), which is the node format the
        pyecharts tree chart consumes.
        """
        # Describe only the inputs the current operator actually consumes, so
        # a right leaf left over from an earlier initialisation is not shown.
        nodes = [self.left_node]
        if self.calculate_method.get_data_number() != 1:
            nodes.append(self.right_node)

        children = []
        for node in nodes:
            if node is None:
                continue
            if isinstance(node, DataProvider):
                children.append({"name": node.name})
            else:
                children.append(node._node_plot())

        return {"name": self.calculate_method.method_name, "children": children}

    def get_tree_name(self):
        """Return the string form of the structure dict from ``_node_plot``.

        NOTE(review): operator parameters (e.g. the time window) are not part
        of the name, so trees differing only in parameters share a name -
        confirm whether that is intended for de-duplication.
        """
        tree_structure_dict = self._node_plot()
        return str(tree_structure_dict)

    def get_tree_hash_name(self):
        """Return ``hash()`` of the tree name.

        NOTE: Python salts ``str`` hashes per process unless PYTHONHASHSEED is
        fixed, so the value is only comparable within one interpreter session.
        """
        tree_name = self.get_tree_name()
        return hash(tree_name)

    def tree_plot(self, plot_save:bool=False, plot_save_path: str=TREE_PLOT_SAVE_PATH):
        """Draw the tree with pyecharts.

        Args:
            plot_save: when true, also write the chart to an HTML file.
            plot_save_path: directory for the HTML file; the file is named
                after ``get_tree_hash_name()``.

        Returns:
            The object produced by ``render_notebook()``.  It is returned so
            that a notebook cell ending in this call displays the chart.
        """
        tree = TreePloter()

        plot_data = self._node_plot()
        tree.set_global_opts(title_opts=opts.TitleOpts(title="Tree结构图"),legend_opts=opts.LegendOpts(is_show=False))
        # The chart takes a list whose single element is the root node dict.
        tree.add(series_name="abc", data=[plot_data], label_opts=opts.LabelOpts(color="red", font_size=17))
        tree.set_colors("white")
        notebook_view = tree.render_notebook()

        if plot_save:
            tree_name = self.get_tree_hash_name()
            tree.render("{}/{}.html".format(plot_save_path, tree_name))

        return notebook_view

### tree test

In [23]:
# Build a root node with the default depth settings,
tree_test = Tree()
# grow the node structure,
tree_test.tree_growing()
# then pick operators, input data and parameters for every node.
tree_test.tree_initialize()

In [18]:
# Evaluate the whole expression tree.
tree_test.calculate()

Unnamed: 0_level_0,000001.XSHE,000002.XSHE,000004.XSHE,000005.XSHE,000006.XSHE,000007.XSHE,000008.XSHE,000009.XSHE,000010.XSHE,000011.XSHE,...,688787.XSHG,688788.XSHG,688789.XSHG,688793.XSHG,688798.XSHG,688799.XSHG,688800.XSHG,688819.XSHG,688981.XSHG,689009.XSHG
date,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2010-01-04,,,,,,,,,,,...,,,,,,,,,,
2010-01-05,,,,,,,,,,,...,,,,,,,,,,
2010-01-06,,,,,,,,,,,...,,,,,,,,,,
2010-01-07,,,,,,,,,,,...,,,,,,,,,,
2010-01-08,,,,,,,,,,,...,,,,,,,,,,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2023-07-10,1.766449e+09,1.689903e+09,4.389325e+08,2.262215e+07,1.615444e+08,4.432328e+07,1.139342e+08,3.961703e+08,4.053260e+07,9.369802e+07,...,8.672441e+08,5.411343e+07,1.030420e+08,5.338224e+07,2.556057e+08,8.327882e+07,3.797690e+08,1.723230e+08,3.553161e+09,3.781279e+08
2023-07-11,1.831657e+09,2.052511e+09,4.840521e+08,2.369735e+07,1.627543e+08,7.465590e+07,1.220093e+08,3.417674e+08,3.943182e+07,6.584670e+07,...,8.478643e+08,5.520368e+07,1.174679e+08,4.750783e+07,2.866694e+08,7.003880e+07,4.266085e+08,1.538493e+08,3.945143e+09,4.085781e+08
2023-07-12,2.017418e+09,2.132755e+09,4.506135e+08,3.465738e+07,1.596210e+08,7.595046e+07,1.065076e+08,3.632766e+08,4.482466e+07,6.413948e+07,...,8.893714e+08,6.451819e+07,1.279719e+08,5.365553e+07,1.878611e+08,7.036268e+07,3.445206e+08,1.829788e+08,4.857235e+09,4.597732e+08
2023-07-13,2.464717e+09,1.803564e+09,3.537143e+08,2.242839e+07,1.516515e+08,5.007329e+07,9.380131e+07,3.394997e+08,4.509886e+07,5.817657e+07,...,8.442699e+08,6.031279e+07,1.274506e+08,6.634939e+07,2.354636e+08,7.371147e+07,3.533241e+08,2.152029e+08,4.212498e+09,3.893275e+08


In [24]:
# Draw the tree and also save the chart as an HTML file.
tree_test.tree_plot(plot_save=True)

{'name': 'add',
 'children': [{'name': 'add',
   'children': [{'name': 'moving average',
     'children': [{'name': 'moving average',
       'children': [{'name': 'add',
         'children': [{'name': 'moving average',
           'children': [{'name': 'add',
             'children': [{'name': 'open'}, {'name': 'high'}]}]},
          {'name': 'low'}]}]}]},
    {'name': 'volume'}]},
  {'name': 'low'}]}

#### 树状图绘图

In [29]:
# Manual plotting walkthrough: get the structure description of the tree.
plot_data = tree_test._node_plot()

In [30]:
# Create an empty pyecharts tree chart.
tree = TreePloter()

In [31]:
# Set the chart title and hide the legend.
tree.set_global_opts(
    title_opts=opts.TitleOpts(title="Tree结构图"),
    legend_opts=opts.LegendOpts(is_show=False))


<pyecharts.charts.basic_charts.tree.Tree at 0x205555f9360>

In [32]:
# Add the tree description as the chart's only series, with red labels.
tree.add(series_name="abc", data=[plot_data], label_opts=opts.LabelOpts(color="red", font_size=17))

<pyecharts.charts.basic_charts.tree.Tree at 0x205555f9360>

In [33]:
# Set the chart's colour to white.
tree.set_colors("white")

<pyecharts.charts.basic_charts.tree.Tree at 0x205555f9360>

In [162]:
# Write the chart to an HTML file (path is relative to the working directory).
tree.render(path="tree_plot/tree.html")

'd:\\PythonProjects\\quant_platform\\quant_platform\\tree_plot\\tree.html'

In [158]:
# Show the chart inline in the notebook.
tree.render_notebook()

In [34]:
tree.

AttributeError: 'Tree' object has no attribute 'data'