In [2]:
import numpy as np
import pandas as pd

In [3]:
# Load the training CSV (path is relative to the notebook's working directory)
# and preview its first three rows as the cell output.
data = pd.read_csv('./data/train.csv')
data.head(3)

Unnamed: 0,Age,Sex,ChestPainType,RestingBP,Cholesterol,FastingBS,RestingECG,MaxHR,ExerciseAngina,Oldpeak,ST_Slope,HeartDisease
0,56,1,ASY,155,342,1,Normal,150,1,3.0,Flat,1
1,55,0,ATA,130,394,0,LVH,150,0,0.0,Up,0
2,47,1,NAP,110,0,1,Normal,120,1,0.0,Flat,1


In [8]:
from __future__ import annotations

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_score, \
                            recall_score, f1_score


class PipeLine(object):
    """Preprocessing helper that holds features and labels as attributes.

    Typical use: call the instance with a DataFrame to load it, optionally
    apply ``one_hot`` / ``standard_scaler`` (both update ``df_num`` in place
    on the instance), then call ``fold_out_split`` to obtain a hold-out split.

    Attributes:
        df: The original DataFrame passed to ``__call__``.
        df_num: Numerical feature columns selected in ``__call__``;
            overwritten by ``standard_scaler`` and extended by ``one_hot``.
        df_cat: Categorical feature columns selected in ``__call__``.
        df_target: Label columns, or ``None`` when no labels were loaded.
        viewer: When True, the methods print/display the data they updated.
        viewer_row: Number of rows shown when ``viewer`` is True.
        random_seed: Seed passed to ``train_test_split``.
    """

    # Column groups used by ``__call__`` when the caller does not pass any.
    # Kept as immutable tuples so the defaults cannot be mutated by accident.
    DEFAULT_NUMERICAL = ('Age', 'Sex', 'RestingBP', 'Cholesterol',
                         'FastingBS', 'MaxHR', 'ExerciseAngina', 'Oldpeak')
    DEFAULT_CATEGORICAL = ('ChestPainType', 'RestingECG', 'ST_Slope')
    DEFAULT_TARGET = ('HeartDisease',)

    def __init__(self):
        """Create an empty pipeline; data is attached later via ``__call__``."""
        self.df: pd.DataFrame = None
        self.df_num: pd.DataFrame = None
        self.df_cat: pd.DataFrame = None
        self.df_target: pd.DataFrame = None
        self.viewer = True  # toggle display of the updated columns
        self.viewer_row = 3  # number of rows shown by the viewer
        self.random_seed = 42  # random seed used for the split

    def __call__(self,
                 data: pd.DataFrame,
                 numerical=None,
                 categorical=None,
                 target=None,
                 train_flg=True
                 ) -> None:
        """Load ``data`` and split its columns into the instance attributes.

        Args:
            data: Source DataFrame; stored as-is in ``self.df``.
            numerical: Column names stored in ``df_num``.
                Defaults to ``DEFAULT_NUMERICAL``.
            categorical: Column names stored in ``df_cat``.
                Defaults to ``DEFAULT_CATEGORICAL``.
            target: Label column names stored in ``df_target``.
                Defaults to ``DEFAULT_TARGET``.
            train_flg: Pass False when ``data`` carries no labels (e.g. at
                inference time); ``df_target`` is then reset to ``None``.

        Returns:
            None. Results are exposed through the instance attributes.
        """
        # ``list(...)`` guarantees list-style column selection even if the
        # caller hands in a tuple (a tuple would be read as a single key).
        numerical = list(self.DEFAULT_NUMERICAL if numerical is None
                         else numerical)
        categorical = list(self.DEFAULT_CATEGORICAL if categorical is None
                           else categorical)
        target = list(self.DEFAULT_TARGET if target is None else target)

        self.df = data
        self.df_num = data[numerical]
        self.df_cat = data[categorical]
        # Clear labels when none are supplied so that a target left over from
        # a previous call can never be paired with the new features.
        self.df_target = data[target] if train_flg else None
        return None

    def standard_scaler(self) -> None:
        """Standardize every column currently held in ``df_num``.

        The scaler is fitted on all rows of ``df_num`` and the result replaces
        ``df_num``. Column names and the row index are preserved so that later
        column-wise concatenation (``one_hot``) stays aligned.
        When ``viewer`` is True the scaled frame head is displayed.
        """
        columns = self.df_num.columns
        index = self.df_num.index
        scaled = StandardScaler().fit_transform(self.df_num)
        # Re-attach the original index: the scaler returns a bare ndarray and
        # a default RangeIndex would misalign with df_cat / df_target.
        self.df_num = pd.DataFrame(scaled, columns=columns, index=index)
        if self.viewer:
            print('-'*20, '標準化されたdf_num', '-'*20)
            display(self.df_num.head(self.viewer_row))
        return None

    def one_hot(self,
                columns: list[str],
                concat=True) -> pd.DataFrame:
        """One-hot encode the given categorical columns.

        Args:
            columns: Names of ``df_cat`` columns to encode.
            concat: When True (default) the encoded columns are appended to
                ``df_num``; when False ``df_num`` is left untouched.

        Returns:
            The one-hot encoded DataFrame for ``columns``.
        """
        one_hotted = pd.get_dummies(self.df_cat[columns])
        if concat:
            self.df_num = pd.concat((self.df_num, one_hotted), axis=1)
        if self.viewer:
            print('-'*20, f'ワンホットされたカラム{columns}', '-'*20)
            # Show whichever frame this call actually produced/updated.
            shown = self.df_num if concat else one_hotted
            display(shown.head(self.viewer_row))
        return one_hotted

    def fold_out_split(self, test_size=0.3, to_array=True) -> tuple:
        """Hold-out split of ``df_num`` / ``df_target``.

        Args:
            test_size: Passed through to ``train_test_split``.
            to_array: When True (default) the four parts are returned as
                NumPy arrays with the labels flattened to 1-D; when False the
                pandas objects returned by ``train_test_split`` are kept.

        Returns:
            Tuple ``(x_tr, x_te, y_tr, y_te)``.

        Raises:
            ValueError: If no features or no labels have been loaded
                (``__call__`` not run, or run with ``train_flg=False``).
        """
        if self.df_num is None or self.df_target is None:
            raise ValueError(
                'fold_out_split needs both features and target; call the '
                'pipeline with train_flg=True before splitting.')
        x_tr, x_te, y_tr, y_te = train_test_split(
            self.df_num, self.df_target,
            test_size=test_size,
            random_state=self.random_seed)
        if to_array:
            x_tr, x_te = x_tr.values, x_te.values
            # Labels come back as (n, 1) frames; flatten them to shape (n,).
            y_tr, y_te = y_tr.values.reshape(-1), y_te.values.reshape(-1)
        if self.viewer:
            print('-'*20, '分割されたデータShape', '-'*20)
            print(f'x_train: {x_tr.shape} x_test: {x_te.shape}')
            print(f'y_train: {y_tr.shape} y_test: {y_te.shape}')
        return x_tr, x_te, y_tr, y_te


### 改修内容
`PipeLine` クラスのメソッド `fold_out_split` を改修した。<br>
メソッド実行時に引数で `to_array=False` を指定すると、分割結果が NumPy 配列ではなく DataFrame 型のまま返される(デフォルトの `to_array=True` では NumPy 配列で返される)。

In [12]:
# Demo of the revised fold_out_split: reload the training data, build the
# feature matrix, and request the split as pandas objects.
data = pd.read_csv('./data/train.csv')
one_hot = ['RestingECG','ST_Slope']  # categorical columns to one-hot encode

pipe = PipeLine()
pipe.viewer = False  # suppress the intermediate print/display output
pipe(data)
_ = pipe.one_hot(columns=one_hot)
# NOTE(review): standard_scaler runs after one_hot, so the appended dummy
# columns are standardized as well (see the displayed output) — confirm
# this is intended.
pipe.standard_scaler()
x_tr, x_te, y_tr, y_te = pipe.fold_out_split(to_array=False)  # with to_array=False the split is returned as DataFrames
display(x_tr)
display(y_tr)

Unnamed: 0,Age,Sex,RestingBP,Cholesterol,FastingBS,MaxHR,ExerciseAngina,Oldpeak,RestingECG_LVH,RestingECG_Normal,RestingECG_ST,ST_Slope_Down,ST_Slope_Flat,ST_Slope_Up
199,0.675805,0.520852,-0.672566,-1.792941,-0.573753,-0.178647,1.252198,1.036818,-0.481919,0.785575,-0.489267,-0.287456,-0.984543,1.147907
209,1.829427,-1.919930,1.550563,0.958718,-0.573753,0.999968,-0.798596,-0.440119,-0.481919,0.785575,-0.489267,-0.287456,-0.984543,1.147907
9,-2.155813,0.520852,-0.672566,0.922272,-0.573753,1.934731,-0.798596,-0.809353,-0.481919,0.785575,-0.489267,-0.287456,-0.984543,1.147907
595,1.724552,0.520852,1.328250,0.439365,-0.573753,0.227772,-0.798596,-0.809353,2.075039,-1.272953,-0.489267,-0.287456,-0.984543,1.147907
509,-1.526564,0.520852,-0.672566,0.029350,-0.573753,0.918684,1.252198,0.113732,-0.481919,0.785575,-0.489267,-0.287456,1.015700,-0.871151
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
71,0.990429,0.520852,-1.228348,0.503145,-0.573753,0.105846,1.252198,1.036818,-0.481919,-1.272953,2.043874,-0.287456,1.015700,-0.871151
106,-1.211940,0.520852,0.438999,1.468959,-0.573753,1.325103,-0.798596,-0.809353,-0.481919,0.785575,-0.489267,-0.287456,-0.984543,1.147907
270,-1.002191,0.520852,-0.672566,0.603371,-0.573753,1.447029,-0.798596,-0.809353,-0.481919,0.785575,-0.489267,-0.287456,-0.984543,1.147907
435,-1.736314,-1.919930,-0.672566,0.576037,-0.573753,-0.300573,-0.798596,-0.809353,-0.481919,0.785575,-0.489267,-0.287456,-0.984543,1.147907


Unnamed: 0,HeartDisease
199,0
209,0
9,0
595,0
509,0
...,...
71,1
106,0
270,0
435,0
