In [1]:
import pickle
from pathlib import Path
import sys

import pandas as pd
from sklearn.svm import SVC
from xgboost import XGBClassifier

from module.mymodule import grid_search_cv
from features import pipe_1, pipe_2, pipe_3, pipe_4, pipe_5, pipe_6, pipe_7,\
                     pipe_8, pipe_9, pipe_10, pipe_11, pipe_12, pipe_13, pipe_14

# モデルのチューニングと訓練を行う

### 特徴量候補を用意してデータセット作成

In [2]:
# Load the training data; every feature pipeline below starts from this frame.
df = pd.read_csv('./data/train.csv')

# Keyword arguments passed unchanged to every pipeline call.
# NOTE(review): the options are interpreted inside the `features` module, which
# is not visible here. 'split_kwrg' is presumably forwarded to a train/test
# split helper -- confirm. The key spelling must stay as the pipelines expect.
to_pipe = {
    'df': df,
    'split_kwrg': {'test_size': 0.2, 'to_array': True},
    'train_flg': True,
    'retrain': False,
}

# Feature-set candidates. Comment an entry in or out to choose which pipelines
# are built; the trailing label names the feature engineering of each one.
pipe_lines = [
    # pipe_1,  # base
    # pipe_2,  # StSlopeCat
    # pipe_3,  # CholestMean
    # pipe_4,  # AgeCat
    # pipe_5,  # StSlopeCat CholestMean AgeCat
    # pipe_6,  # RestingBpCat
    # pipe_7,  # OldPeakCat
    # pipe_8,  # RestingBpCat OldPeakCat
    pipe_9,   # Onehot
    pipe_10,  # CholestMean AgeCat Onehot
    pipe_11,  # CholCut
    pipe_12,  # CholCut Onehot
    pipe_13,  # DropByShap
    pipe_14,  # ZeroCat DropByShap (title printed in the cell output)
]

# Run each selected pipeline once and key its result by the function's name
# (e.g. 'pipe_9'). Judging from the cell output, each call also prints a
# titled preview of the features it produced.
data_set = {build.__name__: build(**to_pipe) for build in pipe_lines}

                          Onehot Standard(pipe_9)                          


Unnamed: 0,Age,Sex,RestingBP,Cholesterol,FastingBS,MaxHR,ExerciseAngina,Oldpeak,ChestPainType_ASY,ChestPainType_ATA,ChestPainType_NAP,ChestPainType_TA,RestingECG_LVH,RestingECG_Normal,RestingECG_ST,ST_Slope_Down,ST_Slope_Flat,ST_Slope_Up
0,56,1,155,342,1,150,1,3.0,1,0,0,0,0,1,0,0,1,0
1,55,0,130,394,0,150,0,0.0,0,1,0,0,1,0,0,0,0,1
2,47,1,110,0,1,120,1,0.0,0,0,1,0,0,1,0,0,1,0


                CholestMean AgeCat Onehot Standard(pipe_10)                


Unnamed: 0,Age,Sex,RestingBP,Cholesterol,FastingBS,MaxHR,ExerciseAngina,Oldpeak,ChestPainType_ASY,ChestPainType_ATA,ChestPainType_NAP,ChestPainType_TA,RestingECG_LVH,RestingECG_Normal,RestingECG_ST,ST_Slope_Down,ST_Slope_Flat,ST_Slope_Up
0,3,1,155,342.0,1,150,1,3.0,1,0,0,0,0,1,0,0,1,0
1,3,0,130,394.0,0,150,0,0.0,0,1,0,0,1,0,0,0,0,1
2,2,1,110,243.414258,1,120,1,0.0,0,0,1,0,0,1,0,0,1,0


                              CholCut(pipe_11)                             


Unnamed: 0,Age,Sex,RestingBP,Cholesterol,FastingBS,MaxHR,ExerciseAngina,Oldpeak
0,56,1,155,342,1,150,1,3.0
1,55,0,130,394,0,150,0,0.0
2,54,0,160,201,0,163,0,0.0


                          CholCut Onehot(pipe_12)                          


Unnamed: 0,Age,Sex,RestingBP,Cholesterol,FastingBS,MaxHR,ExerciseAngina,Oldpeak,ChestPainType_ASY,ChestPainType_ATA,ChestPainType_NAP,ChestPainType_TA,RestingECG_LVH,RestingECG_Normal,RestingECG_ST,ST_Slope_Down,ST_Slope_Flat,ST_Slope_Up
0,56,1,155,342,1,150,1,3.0,1,0,0,0,0,1,0,0,1,0
1,55,0,130,394,0,150,0,0.0,0,1,0,0,1,0,0,0,0,1
2,54,0,160,201,0,163,0,0.0,0,0,1,0,0,1,0,0,0,1


                            DropByShap(pipe_13)                            


Unnamed: 0,Age,Sex,Cholesterol,FastingBS,MaxHR,Oldpeak,ChestPainType_ASY,ST_Slope_Flat,ST_Slope_Up
0,56,1,342,1,150,3.0,1,1,0
1,55,0,394,0,150,0.0,0,0,1
2,47,1,0,1,120,0.0,0,1,0


                        ZeroCat DropByShap(pipe_14)                        


Unnamed: 0,Age,Sex,Cholesterol,FastingBS,MaxHR,Oldpeak,ChestPainType_ASY,ST_Slope_Flat,ST_Slope_Up
0,56,1,342,1,150,3.0,1,1,0
1,55,0,394,0,150,0.0,0,0,1
2,54,0,201,0,163,0.0,0,0,1


### モデル候補を用意

In [3]:
# Each candidate below is a spec dict that the training cell unpacks into
# grid_search_cv(pack, **candidate): the estimator class ('model'), the
# hyper-parameter grid to search ('param_grid') and extra keyword arguments
# ('model_arg').
# NOTE(review): 'model_arg' is presumably passed to the estimator constructor
# by grid_search_cv -- confirm in module/mymodule.

# XGBoost
# NOTE(review): 'early_stopping_rounds' needs a validation set at fit time;
# presumably grid_search_cv supplies one -- confirm.
xgboost = {
    'model': XGBClassifier,
    'param_grid': {
        'max_depth': [3, 5, 7, 9, 15],
        'learning_rate': [0.05, 0.1, 0.3],
        'n_estimators': [50, 75, 100, 150],
    },
    'model_arg': {'random_state': 42, 'early_stopping_rounds': 50},
}

# SVC
svc = {
    'model': SVC,
    'param_grid': {
        'C': [0.01, 0.1, 1.0, 10],
        'kernel': ['linear', 'rbf'],
    },
    'model_arg': {'random_state': 42, 'probability': True},
}

### モデルの訓練

In [4]:
# Model candidates to train; edit this list to train only a subset.
model_candidates = [xgboost, svc]

# Trained results, nested as {model class name: {pipeline name: result}}.
trained_models = {}
for candidate in model_candidates:
    model_name = candidate['model'].__name__
    # Banner line separating each model's section in the training log.
    print(model_name.center(50, '#'))
    models = {}
    for key, pack in data_set.items():
        print(key.center(50))
        # Judging from the cell output, grid_search_cv prints a train/test
        # evaluation table for every (model, pipeline) pair.
        # NOTE(review): its return value is assumed to be the fitted
        # model / search object -- confirm in module/mymodule.
        # NOTE(review): the recorded accuracies are consistent with different
        # test-set sizes across pipelines (129 vs 104 samples), so scores may
        # not be measured on the same hold-out rows -- confirm before ranking.
        models[key] = grid_search_cv(pack, **candidate)
    trained_models[model_name] = models

# Save all trained models into a single pickle file. Choose the file name that
# matches the candidate list used above.
file_name = 'mixed'
# file_name = 'xgboost'
# file_name = 'svc'
with open(f'./data/{file_name}.pkl', mode='wb') as f:
    pickle.dump(trained_models, f)

##################XGBClassifier###################
                      pipe_9                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.94152,0.936396,0.956679,0.946429
test,0.868217,0.901235,0.890244,0.895706


                     pipe_10                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.916179,0.914894,0.931408,0.923077
test,0.852713,0.888889,0.878049,0.883436


                     pipe_11                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.951807,0.972973,0.923077,0.947368
test,0.817308,0.84,0.792453,0.815534


                     pipe_12                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.937349,0.928934,0.938462,0.933673
test,0.836538,0.875,0.792453,0.831683


                     pipe_13                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.929825,0.928826,0.942238,0.935484
test,0.868217,0.891566,0.902439,0.89697


                     pipe_14                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.992771,0.994845,0.989744,0.992288
test,0.875,0.916667,0.830189,0.871287


#######################SVC########################
                      pipe_9                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.883041,0.875433,0.913357,0.893993
test,0.837209,0.858824,0.890244,0.874251


                     pipe_10                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.881092,0.880282,0.902527,0.891266
test,0.844961,0.8875,0.865854,0.876543


                     pipe_11                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.814458,0.859756,0.723077,0.785515
test,0.826923,0.87234,0.773585,0.82


                     pipe_12                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.915663,0.908163,0.912821,0.910486
test,0.875,0.87037,0.886792,0.878505


                     pipe_13                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.875244,0.861017,0.916968,0.888112
test,0.868217,0.873563,0.926829,0.899408


                     pipe_14                      
-------------------- 評価結果 --------------------


Unnamed: 0,accuracy_score,precision_score,recall_score,f1_score
train,0.896386,0.887755,0.892308,0.890026
test,0.846154,0.862745,0.830189,0.846154
