In [None]:
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import GridSearchCV

In [None]:
# Load the raw dataset (data_gyeonggi.csv) from the current working directory.
df = pd.read_csv(filepath_or_buffer="data_gyeonggi.csv")

In [None]:
# Feature matrix: the explanatory columns used by every model below.
x = df[[
    'death', 'released',
    'weather_temp', 'weather_rain',
    'social_keyword_corona', 'social_keyword_disinfection', 'social_keyword_mask',
    'social_keyword_infection', 'social_keyword_briefing',
    'date_red', 'traffic_subway', 'confirmed_foreign', 'group_total', 'PM10',
]]
# Target kept as a one-column DataFrame (double brackets); the model cells call
# .values.ravel() on it before fitting.
# NOTE(review): 'confirmed' is used as a class label for a classifier. If it is a
# case count, every distinct value becomes its own class — confirm that is intended.
y = df[['confirmed']]

# 80/20 train/test split with a fixed seed so results are reproducible.
x_train, x_test, y_train, y_test = train_test_split(
    x, y,
    train_size=0.8,
    test_size=0.2,
    random_state=0,
)

In [None]:
# Baseline: gradient boosting with scikit-learn defaults (only the seed is fixed).
gbrt = GradientBoostingClassifier(random_state=0)
gbrt.fit(x_train, y_train.values.ravel())  # ravel() flattens the one-column target to 1-D

# Accuracy on the rows the model was fitted on.
print(f"경기 train 정확도 : {gbrt.score(x_train, y_train):.3f}")

# Accuracy on the held-out 20% test split.
y_pred = gbrt.predict(x_test)
print(f"경기 예측 정확도 : {accuracy_score(y_test, y_pred):.3f}")

# NOTE(review): x/y is the full dataset, so ~80% of these rows were also used for
# training; this score is inflated relative to the held-out score above.
print(f"경기 모델 정확도 : {gbrt.score(x, y):.3f}")

# Author's note from a previous run: train accuracy was 100%, indicating overfitting
# (outputs are not saved in this notebook, so re-check after running).

In [None]:
import time

# Hyperparameter tuning with GridSearchCV.
# 3 * 2 * 5 * 3 = 90 candidate settings, each fitted on 2 CV folds (180 fits).
param = {
    'max_depth': [2, 4, 6],
    'min_samples_split': [2, 3],
    'n_estimators': [100, 200, 300, 400, 500],
    'learning_rate': [0.01, 0.05, 0.1],
}
gb_clf = GradientBoostingClassifier(random_state=0)
grid_cv = GridSearchCV(gb_clf, param_grid=param, cv=2, verbose=1, n_jobs=-1)

# Measure the GBM search time. The clock must wrap fit(): starting it and printing
# immediately (with no work in between) always reports ~0.0s.
# perf_counter() is monotonic, so it is the appropriate clock for elapsed durations.
start_time = time.perf_counter()
grid_cv.fit(x_train, y_train.values.ravel())
print('GBM 수행 시간: {:.1f}초'.format(time.perf_counter() - start_time))

# best_score_ is the mean cross-validated accuracy of the best candidate on the
# training data's CV folds — not accuracy on the held-out test split.
print('최적 하이퍼 파라미터: \n', grid_cv.best_params_)
print('최고 예측 정확도: {0:.4f}'.format(grid_cv.best_score_))

In [None]:
# Retrain with the best hyperparameters found by the grid search above, then measure
# performance on the held-out test split.
# Region: Gyeonggi (경기)
# Parameters are taken from grid_cv rather than hard-coded, so the refit always matches
# what the search actually selected. They were previously hard-coded as
# n_estimators=100, learning_rate=0.05, min_samples_split=3, max_depth=2
# (presumably copied from an earlier run's output).
gbrt = GradientBoostingClassifier(random_state=0, **grid_cv.best_params_)

gbrt = gbrt.fit(x_train, y_train.values.ravel())
print("경기 train 정확도 : {:.3f}".format(gbrt.score(x_train, y_train)))
print("경기 test 정확도 : {:.3f}".format(gbrt.score(x_test, y_test)))