# In[ ]:
import joblib
from sklearn.model_selection import train_test_split
from sklearn import svm
from datetime import timedelta
import datetime
import numpy as np
import pandas as pd
from date_helper import get_today_str, str2date, get_date_str, get_today
from pandas.tseries.offsets import BDay
from imblearn.over_sampling import RandomOverSampler
from sklearn.preprocessing import StandardScaler
from sklearn import tree
from sklearn.ensemble import RandomForestClassifier,ExtraTreesClassifier
from xgboost.sklearn import XGBClassifier
from imblearn.over_sampling import SMOTE
import re


class T0_Classifier:
    """
    Feature pipeline plus classifier over daily stock quotation data.

    ``data_preprocess`` derives per-stock features for one analysis day into
    ``self.ana_quotation`` and, when a labelled sample is supplied, joins them
    into ``self.train_data``. ``train`` fits ``self.clf`` on the columns
    listed in ``self.feature_list``, and ``predict`` scores
    ``self.ana_quotation`` with it.
    """

    def __init__(self):
        self.clf = ExtraTreesClassifier()
        # Columns used as model inputs; feature methods called with
        # feature=True append their output column here.
        self.feature_list = ['Dnvaltrd', 'Loprc']
        # Columns transformed by log_convert().
        self.log_conv_list = ['Vardrestd_3td']
        # Earliest quotation date kept by period_data(). Starts as a
        # 'YYYY-MM-DD' string; period_data() replaces it with the parsed date.
        self.min_date = '2021-08-21'
        self.max_date = None
        # 30 business days before the analysis day. Set by period_data() but
        # not used anywhere else in this class.
        self.mindate = None
        self.train_data = None
        self.ana_quotation = None
        self.predict_result = None
        self.train_result = None
        self.test_result = None

    def _add_feature(self, name):
        """Append ``name`` to ``self.feature_list`` unless it is already there."""
        # Guard so that re-running the pipeline on the same instance does not
        # feed duplicate columns to the classifier.
        if name not in self.feature_list:
            self.feature_list.append(name)

    def period_data(self, ori_data, ana_day):
        """
        Keep only the quotation rows and columns the pipeline needs.

        Rows are restricted to ``[self.min_date - 3 days, ana_day]`` and to
        the columns in ``keep_varlist``. The result is stored in
        ``self.ana_quotation`` with a fresh RangeIndex.

        Parameters
        ----------
        ori_data : DataFrame
            Raw quotations; ``Trddt`` must be a datetime column.
        ana_day : str or None
            Analysis day as 'YYYY-MM-DD'. None means ``get_today()``.

        Side effects: sets ``self.min_date`` (parsed), ``self.max_date`` and,
        when ``ana_day`` is given, ``self.mindate``.
        """
        keep_varlist = ['Stkcd', 'Trddt', 'Opnprc', 'Hiprc', 'Loprc', 'Clsprc',
                        'Dnvaltrd', 'Dsmvosd', 'Dretnd', 'Dsmvtll']
        # Accept both the initial string and an already-parsed date, so the
        # method can be called more than once on the same instance.
        min_date = self.min_date
        if isinstance(min_date, str):
            min_date = str2date(min_date)
        self.min_date = min_date

        if ana_day is None:
            max_date = get_today()
        else:
            max_date = str2date(ana_day)
            self.mindate = (max_date - BDay(30)).strftime("%Y-%m-%d")
        self.max_date = max_date

        filter_min_date = min_date - timedelta(days=3)
        trade_day = ori_data.Trddt.dt.date
        in_window = (trade_day >= filter_min_date) & (trade_day <= max_date)
        self.ana_quotation = ori_data.loc[in_window, keep_varlist].copy().reset_index(drop=True)

    def _gen_pre_close(self, df):
        """
        Per-stock helper: add 'Preclsprc' (previous row's Clsprc) to ``df``.

        Expects ``df`` to hold one stock sorted by date. The first row gets
        NaN. gen_pre_close() computes the same column with a groupby shift.
        """
        df['Preclsprc'] = pd.Series([np.nan] + list(df['Clsprc'].values[:-1]), index=df.index)
        return df

    def _Trovrt_3dmean(self, df):
        """
        Per-stock helper: add 'Trovrt_3dmean', the 3-row rolling mean of Trovrt.

        trovrt_3dmean() computes the same column with a groupby transform.
        """
        df['Trovrt_3dmean'] = df['Trovrt'].rolling(3).mean()
        return df

    def _ratio_Trovrt_3dmean(self, x):
        """
        Add 'ratio_Trovrt_3dmean' to ``x`` in place and return ``x``.

        Rows are bucketed by Trovrt_3dmean into <4, [4, 8), [8, 10) and >=10.
        Within each bucket the value is
        ``Trovrt_3dmean * Dsmvosd / max(Dsmvosd in that bucket)``.
        Rows that fall in no bucket (e.g. NaN Trovrt_3dmean) keep 0.
        """
        turnover = x['Trovrt_3dmean']
        conditions = [
            turnover < 4,
            (turnover >= 4) & (turnover < 8),
            (turnover >= 8) & (turnover < 10),
            turnover >= 10,
        ]
        # Float zero: the per-bucket assignments below are floats, and writing
        # floats into an int column is deprecated in pandas.
        x['ratio_Trovrt_3dmean'] = 0.0
        for condition in conditions:
            bucket = x.loc[condition]
            x.loc[condition, 'ratio_Trovrt_3dmean'] = (
                bucket['Trovrt_3dmean'] * bucket['Dsmvosd'] / bucket['Dsmvosd'].max()
            )
        return x

    def _ratio_Dsmvosd(self, x):
        """
        Add 'ratio_Dsmvosd' to ``x`` in place and return ``x``.

        Rows are bucketed by Dsmvosd into <1e10, [1e10, 2e10), [2e10, 5e10)
        and >=5e10. Within each bucket Trovrt_3dmean is divided by 20, 10, 8
        and 4 respectively. Rows in no bucket keep 0.
        """
        cap = x['Dsmvosd']
        conditions = [
            cap < 10_000_000_000,
            (cap >= 10_000_000_000) & (cap < 20_000_000_000),
            (cap >= 20_000_000_000) & (cap < 50_000_000_000),
            cap >= 50_000_000_000,
        ]
        divisors = [20, 10, 8, 4]
        # Written under the same name ratio_Dsmvosd() registers as a feature.
        # Previously the values went to an unregistered 'ratio2' column.
        x['ratio_Dsmvosd'] = 0.0
        for condition, divisor in zip(conditions, divisors):
            x.loc[condition, 'ratio_Dsmvosd'] = x.loc[condition, 'Trovrt_3dmean'] / divisor
        return x

    def _Bussi(self, x):
        """
        Return 1 if stock code ``x`` starts with '30' or '688', else 0.

        The original docstring calls this a STAR-market (科创板) check.
        NOTE(review): '30' codes presumably belong to a different board
        (ChiNext); confirm the intended grouping.
        """
        return int(x.startswith(('30', '688')))

    def _R_fel(self, df):
        """
        Per-stock helper: add 'R_fel', the 4-row mean of the daily return.

        At row t this is the mean of ``(C[s] - C[s-1]) / C[s-1]`` for
        s = t-3..t, where C is Clsprc. r_fel() computes the same column with
        a groupby transform.
        """
        df.loc[:, 'R_fel'] = ((df['Clsprc'].shift(-1) - df['Clsprc']) / df['Clsprc']).shift().rolling(4).mean()
        return df

    def gen_vardre(self, df, var_name, start, end=None):
        """
        Per-stock standard deviation of Dretnd over ``[start, end]``.

        Parameters
        ----------
        df : DataFrame
            Quotations with Stkcd, Trddt (datetime) and Dretnd.
        var_name : str
            Name of the output column.
        start, end : str
            Window bounds, parsed with str2date. ``end=None`` means today.

        Returns
        -------
        DataFrame with columns ['Stkcd', var_name].
        """
        min_date = str2date(start)
        max_date = get_today() if end is None else str2date(end)
        trade_day = df.Trddt.dt.date
        window = df[(trade_day >= min_date) & (trade_day <= max_date)]
        res = window.groupby('Stkcd', as_index=False).agg({'Dretnd': [(var_name, 'std')]})
        # Flatten ('Stkcd', '') / ('Dretnd', var_name) into 'Stkcd' / var_name.
        res.columns = [inner if inner else outer for outer, inner in res.columns]
        return res

    def get_date(self, today=None):
        """
        Return the volatility-window start dates and their column names.

        The single window starts 3 business days before ``today`` (any value
        accepted by ``pd.Timestamp``). When ``today`` is None,
        ``self.max_date`` from period_data() is used instead.

        Returns
        -------
        (list of 'YYYY-MM-DD' str, list of column names)
        """
        base = self.max_date if today is None else pd.Timestamp(today)
        starts = [base - BDay(3)]
        return [d.strftime("%Y-%m-%d") for d in starts], ['Vardrestd_3td']

    def trovert(self, feature=False):
        """Add 'Trovrt', the daily turnover rate in percent: Dnvaltrd / Dsmvosd * 100."""
        quotes = self.ana_quotation
        quotes['Trovrt'] = quotes['Dnvaltrd'] / quotes['Dsmvosd'] * 100
        if feature:
            self._add_feature('Trovrt')

    def trovrt_3dmean(self, feature=False):
        """
        Add 'Trovrt_3dmean', the per-stock rolling mean of Trovrt over the last 3 rows.

        Also sorts ``self.ana_quotation`` by (Stkcd, Trddt), drops rows
        without a Stkcd, and names the index 'key1'.
        """
        df = self.ana_quotation.dropna(subset=['Stkcd']).sort_values(['Stkcd', 'Trddt'])
        # A column-level transform keeps the original row index on every pandas
        # version. groupby.apply adds a Stkcd index level on pandas >= 2 and
        # drops the grouping column on pandas >= 3.
        df['Trovrt_3dmean'] = df.groupby('Stkcd')['Trovrt'].transform(lambda s: s.rolling(3).mean())
        df.index.names = ['key1']
        self.ana_quotation = df
        if feature:
            self._add_feature('Trovrt_3dmean')

    def gen_pre_close(self, feature=False):
        """Add 'Preclsprc', each stock's previous-row close (NaN on its first row)."""
        df = self.ana_quotation.dropna(subset=['Stkcd']).sort_values(by=['Stkcd', 'Trddt'])
        df['Preclsprc'] = df.groupby('Stkcd')['Clsprc'].shift(1)
        self.ana_quotation = df
        if feature:
            self._add_feature('Preclsprc')

    def drange(self, feature=False):
        """Add 'Drange', the daily amplitude: (Hiprc - Loprc) / Preclsprc * 100. Needs gen_pre_close()."""
        quotes = self.ana_quotation
        quotes['Drange'] = (quotes.Hiprc - quotes.Loprc) / quotes.Preclsprc * 100
        if feature:
            self._add_feature('Drange')

    def ratio_Trovrt_3dmean(self, feature=False):
        """Add 'ratio_Trovrt_3dmean' (see _ratio_Trovrt_3dmean) and drop rows without a Stkcd."""
        self.ana_quotation = self._ratio_Trovrt_3dmean(self.ana_quotation)
        self.ana_quotation = self.ana_quotation[~self.ana_quotation['Stkcd'].isna()]
        if feature:
            self._add_feature('ratio_Trovrt_3dmean')

    def ratio_Dsmvosd(self, feature=False):
        """Add 'ratio_Dsmvosd' (see _ratio_Dsmvosd) and drop rows without a Stkcd."""
        self.ana_quotation = self._ratio_Dsmvosd(self.ana_quotation)
        self.ana_quotation = self.ana_quotation[~self.ana_quotation['Stkcd'].isna()]
        if feature:
            self._add_feature('ratio_Dsmvosd')

    def r_fel(self, feature=False):
        """
        Add 'R_fel', the per-stock mean of the daily close-to-close return
        over the current and previous 3 rows.

        Assumes rows are already in (Stkcd, Trddt) order, as left by
        gen_pre_close() in the pipeline.
        """
        self.ana_quotation['Trddt'] = pd.to_datetime(self.ana_quotation['Trddt'])
        df = self.ana_quotation.dropna(subset=['Stkcd']).copy()
        df['R_fel'] = df.groupby('Stkcd')['Clsprc'].transform(
            lambda close: ((close.shift(-1) - close) / close).shift().rolling(4).mean()
        )
        self.ana_quotation = df
        if feature:
            self._add_feature('R_fel')

    def bussi(self, feature=False):
        """Add 'Bussi' (1 if Stkcd starts with '30' or '688', else 0)."""
        self.ana_quotation.loc[:, 'Bussi'] = self.ana_quotation['Stkcd'].apply(self._Bussi)
        if feature:
            self._add_feature('Bussi')

    def vardrestd3(self, today, feature=False):
        """
        Add 'Vardrestd_3td', the per-stock std of Dretnd from 3 business days
        before ``today`` through ``today``.

        After this call ``self.ana_quotation`` only holds the rows dated
        ``today``, with the volatility column left-joined on Stkcd.

        Parameters
        ----------
        today : str or None
            Analysis day 'YYYY-MM-DD'. None falls back to ``self.max_date``.
            Requires period_data() to have run, so that ``self.min_date`` and
            ``self.max_date`` are dates.
        """
        # Previously this read an undefined global ``target_date``.
        ana_day = today if today is not None else self.max_date.strftime("%Y-%m-%d")
        trade_day = self.ana_quotation.Trddt.dt.date
        in_window = (trade_day >= self.min_date) & (trade_day <= self.max_date)
        quotation_data = self.ana_quotation[in_window].copy()
        day_rows = quotation_data[quotation_data.Trddt.dt.date.eq(str2date(ana_day))]
        dates, names = self.get_date(ana_day)
        for date, name in zip(dates, names):
            vardrestd_df = self.gen_vardre(quotation_data, name, date, end=ana_day)
            day_rows = day_rows.merge(right=vardrestd_df, on='Stkcd', how='left', validate='one_to_one')
        self.ana_quotation = day_rows
        if feature:
            self._add_feature('Vardrestd_3td')

    def varify_ratio_DT(self, feature=False):
        """
        Add 'varify_ratio_DT' = 1 when a row meets the trader's rule, else 0.

        Rule (any one, on Dsmvosd / Trovrt_3dmean):
        > 8e10 and > 4; (4e10, 8e10) and > 8; (1e10, 4e10) and > 10;
        < 1e10 and > 15. Values exactly on a cap boundary match none of the
        middle bands.
        """
        cap = self.ana_quotation['Dsmvosd']
        turnover = self.ana_quotation['Trovrt_3dmean']
        condition1 = (cap > 80000000000) & (turnover > 4)
        condition2 = (cap > 40000000000) & (turnover > 8) & (cap < 80000000000)
        condition3 = (cap > 10000000000) & (turnover > 10) & (cap < 40000000000)
        condition4 = (cap < 10000000000) & (turnover > 15)
        condition = condition1 | condition2 | condition3 | condition4
        self.ana_quotation['varify_ratio_DT'] = 0
        self.ana_quotation.loc[condition, 'varify_ratio_DT'] = 1
        if feature:
            self._add_feature('varify_ratio_DT')

    def log_convert(self):
        """
        Forward-fill missing values, then take the natural log of the columns
        in ``self.log_conv_list``, with zeros replaced by 1e-4 first. Rows
        still containing NaN are dropped.
        """
        # NOTE(review): the forward fill runs across the whole frame, so a
        # missing value can be filled from a different stock's row. Confirm
        # this is intended.
        df = self.ana_quotation.ffill()
        cols = self.log_conv_list
        df[cols] = np.log(df[cols].replace(0, 0.0001))
        self.ana_quotation = df.dropna()

    def train_data_merge(self, sample):
        """Left-join ``sample`` with ``self.ana_quotation`` on Stkcd into ``self.train_data``."""
        self.train_data = pd.merge(sample, self.ana_quotation, how='left', left_on='Stkcd', right_on='Stkcd')

    def data_preprocess(self, cal_data, ana_day, sample=None):
        """
        Run the feature pipeline for ``ana_day``.

        Steps, in order:
        1. period_data
        2. Trovrt (not registered as a feature)
        3. Trovrt_3dmean (feature)
        4. Preclsprc
        5. R_fel (feature)
        6. Vardrestd_3td (feature; also reduces ``self.ana_quotation`` to the
           rows dated ``ana_day``)
        7. Bussi (feature)

        When ``sample`` is given, it is left-joined with the result into
        ``self.train_data``.
        """
        self.period_data(cal_data, ana_day)

        self.trovert()

        # 3-day average turnover
        self.trovrt_3dmean(True)

        # previous close
        self.gen_pre_close()

        # Optional features, currently disabled:
        # self.drange()
        # self.ratio_Trovrt_3dmean()
        # self.ratio_Dsmvosd()
        self.r_fel(True)

        # volatility
        self.vardrestd3(ana_day, feature=True)

        # self.varify_ratio_DT()
        self.bussi(True)

        # log transform (disabled)
        # self.log_convert()

        # assemble training data
        if sample is not None:
            self.train_data_merge(sample)

    def train(self, random_state=None):
        """
        Fit ``self.clf`` on ``self.train_data`` and record train/test accuracy.

        Rows with any NaN are dropped. The data are split 80/20 first, and
        only the training part is rebalanced with RandomOverSampler.
        Oversampling before the split would copy the same minority rows into
        both parts and inflate the test score.

        Parameters
        ----------
        random_state : int or None
            Seed for train_test_split. None keeps the split unseeded.

        Returns
        -------
        dict with keys 'train_X', 'test_X', 'train_y', 'test_y'. The train
        part is the oversampled one.

        Raises
        ------
        ValueError
            If no training data has been assembled.
        """
        if self.train_data is None:
            raise ValueError("train_data is None: call data_preprocess(..., sample=...) first")
        train_data = self.train_data.dropna().copy()
        X = train_data[self.feature_list]
        y = train_data['score']
        train_X, test_X, train_y, test_y = train_test_split(
            X, y, test_size=0.2, random_state=random_state
        )
        ros = RandomOverSampler(random_state=0)
        train_X, train_y = ros.fit_resample(train_X, train_y)
        self.clf.fit(train_X, train_y)
        self.train_result = self.clf.score(train_X, train_y)
        self.test_result = self.clf.score(test_X, test_y)
        return {'train_X': train_X, 'test_X': test_X, 'train_y': train_y, 'test_y': test_y}

    def evaluate(self):
        """Print the train and test scores recorded by train()."""
        print("train_score:", self.train_result)
        print("test_score:", self.test_result)

    def predict(self):
        """
        Score the complete rows of ``self.ana_quotation`` with ``self.clf``.

        Returns (and stores in ``self.predict_result``) a frame with Stkcd,
        one 'prob_i' column per class, the predicted class 'pre', and 'score'
        when training data exists. 'prob_i' follows the order of
        ``self.clf.classes_``.
        """
        temp = self.ana_quotation.copy().dropna()
        features = temp[self.feature_list]
        proba = self.clf.predict_proba(features)
        # One column per fitted class instead of a hard-coded four.
        prob_cols = ['prob_%d' % i for i in range(proba.shape[1])]
        temp[prob_cols] = proba
        temp['pre'] = self.clf.predict(features)
        out_cols = ['Stkcd'] + prob_cols + ['pre']
        if self.train_data is not None:
            temp = temp.merge(right=self.train_data[['Stkcd', 'score']], on='Stkcd', how='left')
            out_cols.append('score')
        self.predict_result = temp[out_cols]
        return self.predict_result

    def clf_save(self, clf_filename):
        """Serialize ``self.clf`` to ``clf_filename`` with joblib."""
        joblib.dump(self.clf, clf_filename)

    def clf_load(self, clf_filename):
        """
        Load ``self.clf`` from ``clf_filename``.

        NOTE: joblib.load unpickles the file, which can execute arbitrary
        code. Only load model files from trusted sources.
        """
        self.clf = joblib.load(clf_filename)

# In[ ]:
def fun_zfill(x):
    """Left-pad the stock-code string ``x`` with zeros to six characters."""
    padded = x.zfill(6)
    return padded


def get_sample(mark_filenames, db_filename, target_date, encoding='ANSI'):
    """
    Build the labelled training sample (Stkcd, score) for ``target_date``.

    Parameters
    ----------
    mark_filenames : iterable of str
        Excel files with at least the columns 'Stkcd' and '评分' (score).
        When a stock appears more than once, the last occurrence wins.
    db_filename : str
        CSV export read with usecols=[2, 3, 5, 7], which must include
        'createTime', 'endTime' and 'assetCode'. Every asset whose
        [createTime, endTime] range contains ``target_date`` gets score 3.
    target_date : str
        Analysis date, parsed with str2date.
    encoding : str
        Encoding of ``db_filename``. Defaults to the previously hard-coded
        'ANSI'. Python only provides that name (an alias of 'mbcs') on
        Windows, so pass an explicit codec such as 'gbk' elsewhere.

    Returns
    -------
    DataFrame with columns 'Stkcd' and 'score'. Stocks found in the DB
    export always carry score 3.
    """
    marks = [pd.read_excel(f, converters={'Stkcd': str})[['Stkcd', '评分']] for f in mark_filenames]
    mark = (pd.concat(marks)
            .drop_duplicates(subset='Stkcd', keep='last')
            .dropna()
            .rename(columns={'评分': 'score'}))

    db_data = pd.read_csv(db_filename,
                          parse_dates=['createTime', 'endTime'],
                          converters={'assetCode': str},
                          usecols=[2, 3, 5, 7],
                          encoding=encoding)
    db_data['assetCode'] = db_data['assetCode'].apply(fun_zfill)

    day = str2date(target_date)
    active = (db_data['endTime'].dt.date >= day) & (db_data['createTime'].dt.date <= day)
    db_data = pd.DataFrame(db_data.loc[active, 'assetCode'])
    db_data['score'] = 3
    db_data.columns = ['Stkcd', 'score']
    db_data = db_data.drop_duplicates()

    sample = pd.concat([mark, db_data])
    sample = sample.dropna()
    # DB membership overrides any manual mark for the same stock.
    sample.loc[sample['Stkcd'].isin(db_data['Stkcd'].values), 'score'] = 3
    return sample.drop_duplicates()


def get_q_item(q_item_filename, name_filename):
    """
    Load the daily quotation CSV and the stock code-to-name table.

    In the quotation file, physical rows 1 and 2 (right after the header)
    are skipped; presumably they are description/unit rows — confirm against
    the export. Dsmvosd and Dsmvtll are multiplied by 1000 (presumably
    thousands to base currency units — confirm). Only the first two columns
    of the name file are read, and its Stkcd is zero-padded to six
    characters.

    Returns
    -------
    (quotations DataFrame, code/name DataFrame)
    """
    def _times_thousand(raw):
        return np.float64(raw) * 1000

    quote_converters = {
        'Stkcd': str,
        'Dsmvosd': _times_thousand,
        'Dsmvtll': _times_thousand,
    }
    q_item_ = pd.read_csv(q_item_filename,
                          skiprows=lambda row: row in (1, 2),
                          parse_dates=['Trddt'],
                          header=0,
                          converters=quote_converters)

    code_name = pd.read_csv(name_filename,
                            header=0,
                            converters={'Stkcd': str},
                            usecols=[0, 1])
    code_name['Stkcd'] = code_name['Stkcd'].apply(fun_zfill)
    return q_item_, code_name


def model_train(q_item_filename, map_filename, target_date, mark_filenames, db_filename, model_path=None, save_path=None):
    """
    Build features for ``target_date``, train a fresh ExtraTreesClassifier,
    and predict.

    The predictions are joined with the code/name table. They are written to
    ``save_path + '/result.csv'`` when ``save_path`` is given, and the model
    is dumped to ``model_path`` when that is given.

    Returns
    -------
    The prediction DataFrame.
    """
    sample = get_sample(mark_filenames, db_filename, target_date)
    q_item, code_name = get_q_item(q_item_filename, map_filename)

    classifier = T0_Classifier()
    classifier.data_preprocess(q_item, target_date, sample)
    classifier.clf = ExtraTreesClassifier()
    classifier.train()
    classifier.evaluate()

    result = classifier.predict().merge(code_name, on='Stkcd', how='left')
    if save_path is not None:
        result.to_csv(save_path + '/result.csv')
    if model_path is not None:
        classifier.clf_save(model_path)
        print('Model Saved Sucessfully')
    return result


def model_predict(q_item_filename, map_filename, target_date, load_model_path, save_path=None):
    """
    Build features for ``target_date`` and score them with a saved model.

    The model is loaded from ``load_model_path`` after preprocessing. The
    predictions are joined with the code/name table and, when ``save_path``
    is given, written to ``save_path + '/result.csv'``.

    Returns
    -------
    The prediction DataFrame.
    """
    q_item, code_name = get_q_item(q_item_filename, map_filename)

    classifier = T0_Classifier()
    classifier.data_preprocess(q_item, target_date)
    classifier.clf_load(load_model_path)

    result = classifier.predict().merge(code_name, on='Stkcd', how='left')
    if save_path is not None:
        result.to_csv(save_path + '/result.csv')
    return result