In [2]:
import platform

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from matplotlib import font_manager
from matplotlib import rcParams

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest

# 모듈로 구현한 데이터 분석용 클래스 import 
from modules.DataAnalysis import DataCheck
from modules.DataAnalysis import DataPreprocessing
from modules.DataAnalysis import DataVisualize

import modules.DataModify as DataModify

# 모듈로 구현한 머신러닝 모델 import
from modules.Models import TreeXGBoostCox
import modules.ModelAnalysis as ModelAnalysis


In [3]:
### Fix broken Korean glyphs and minus signs in matplotlib output

# Korean-capable font per OS. Windows/macOS ship these by default; on Linux
# (e.g. Colab, Docker images) NanumGothic is only present if it was installed.
KOREAN_FONT_BY_SYSTEM = {
    "Windows": "Malgun Gothic",
    "Darwin": "AppleGothic",
    "Linux": "NanumGothic",
}

system = platform.system()
korean_font = KOREAN_FONT_BY_SYSTEM.get(system)

# matplotlib resolves `font.family` against the names in its font list, so only
# switch when the font is actually registered; otherwise every plot would emit
# "findfont" warnings and Korean text would still render as empty boxes.
installed_fonts = {font.name for font in font_manager.fontManager.ttflist}
if korean_font in installed_fonts:
    rcParams['font.family'] = korean_font
else:
    print(f"[font] No Korean font available for {system!r}; Korean labels may not render.")

# Use the ASCII hyphen for negative numbers so minus signs do not break
# when a Korean font is active.
rcParams['axes.unicode_minus'] = False

In [4]:
### Load the raw tree data and hand it to the preprocessing helper

input_file_path = './data/Tree_data.csv'
df = pd.read_csv(filepath_or_buffer=input_file_path)

# Wraps the frame; the drop/outlier/missing-value/encoding pipeline is run on it below
dp = DataPreprocessing(df)

In [5]:
### Columns judged uninformative during the exploratory analysis; dropped before modelling
drop_cols = [
    'No', 'Plot', 'Adult',
    'Subplot', 'Core', 'Census',
]
dp.set_drop_cols(drop_cols)

### Run the preprocessing pipeline
# (drop columns -> outlier handling -> missing-value handling -> date unification
#  -> merge label data -> encoding)
# return_anomaly=False: do not return the features that were flagged as outliers
encode = dp.run(return_anomaly=False, encoding='label')
display(encode)

Unnamed: 0,Species,Light_ISF,Light_Cat,Soil,Sterile,Conspecific,Myco,SoilMyco,PlantDate,AMF,EMF,Phenolics,Lignin,NSC,Time,Alive
0,0,0.106,0,0,0,0,0,0,0,22.00,0.00,0.79,13.86,12.15,14.0,0
1,1,0.106,0,1,0,0,1,1,1,15.82,31.07,6.54,20.52,19.29,115.5,1
2,2,0.106,0,0,0,0,1,0,1,24.45,28.19,4.71,24.74,15.01,63.0,0
3,0,0.080,0,0,0,0,0,0,0,22.23,0.00,0.64,14.29,12.36,14.0,0
4,0,0.060,1,0,0,0,0,0,0,21.15,0.00,0.77,10.85,11.20,14.0,0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2777,1,0.122,2,1,0,0,1,1,1,10.89,39.00,6.88,21.44,18.99,56.0,0
2778,3,0.111,0,3,0,0,0,1,0,40.89,0.00,2.18,9.15,11.88,56.0,0
2779,1,0.118,0,2,0,0,1,0,1,15.47,32.82,6.23,19.01,23.50,56.0,0
2780,1,0.118,0,1,0,0,1,1,1,11.96,37.67,6.86,21.13,19.10,56.0,0


### Cox 모델을 이용한 생존율 예측
  
##### Cox 모델이란?
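Cox 비례위험 모형(Cox proportional hazards model)은 공변량 $x$가 주어졌을 때의 위험함수를 $h(t \mid x) = h_0(t)\exp(\beta^\top x)$ 로 모델링한다. 각 개체의 관측 시간과 Event(사망) 발생 여부를 함께 사용하므로, 관측이 끝날 때까지 사망하지 않은(중도절단된) 개체도 학습에 활용하여 t 시간 후의 생존율을 추정할 수 있다.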
  
모델이 출력한 위험 점수(risk score)로 생존율을 예측하고, 특정 시점 t에서 이를 생존/사망으로 이진 분류하면 정확도 등의 분류 지표로 성능을 확인할 수 있다.

##### Cox-XGBoost
X를 입력받아 y를 출력하는 기존 지도학습 모델과 달리, 학습 시에는 X(특성), y(생존 시간), e(이벤트 발생 여부)를 모두 입력으로 받는다. 예측 시에는 X로부터 위험 점수를 출력하며, 이를 이용해 특정 시점 t까지 이벤트(사망)가 발생했는지를 추정한다.


In [6]:
# Split the encoded frame into train/test sets. Each set is an (X, y, e) triple:
# features, survival time, and event flag. The seed is fixed so the split is reproducible.
train_set, test_set = DataModify.split_data_X_y_e(
    encode,
    random_state=42,
)


In [7]:
# Peek at the held-out test split: X = features, y = survival time, e = event flag.
# display() renders each object as its own rich output (a table for the DataFrame)
# instead of gluing all three reprs into one plain-text blob.
X, y, e = test_set

display(X, y, e)

      Species  Light_ISF  Light_Cat  Soil  Sterile  Conspecific  Myco  \
2407        1      0.055          1     4        1            1     1   
2206        0      0.055          1     6        0            0     0   
1728        2      0.047          1     5        0            0     1   
2154        2      0.087          0     0        0            0     1   
2701        3      0.065          1     6        0            0     0   
...       ...        ...        ...   ...      ...          ...   ...   
928         2      0.082          0     3        0            0     1   
388         0      0.092          0     3        0            0     0   
1513        1      0.138          2     0        0            0     1   
1021        2      0.079          0     0        0            0     1   
1339        2      0.073          1     5        0            0     1   

      SoilMyco  PlantDate    AMF    EMF  Phenolics  Lignin    NSC  
2407         2          1  10.58   4.23       3.64   19

In [8]:

# Unpack the (features, time, event) triples for both splits
X, y, e = train_set
X_test, y_test, e_test = test_set

# Fit the XGBoost-based Cox model on the training split
xgb_cox = TreeXGBoostCox()
xgb_cox.fit(X, y, e)

# Evaluate on the test split at horizon t=60; show_comparison=True additionally
# shows the predicted-vs-actual count table seen in the output below
score = xgb_cox.score(X_test, y_test, e_test, t=60, show_comparison=True)
print(score)

Predicted  Actual
1          1         233
0          0         216
           1          55
1          0          53
Name: count, dtype: int64
0.8061041292639138


In [9]:
# Confusion-matrix counts (TP/TN/FP/FN) on the test split at horizon t=115.5,
# followed by the metrics derived from them.
# NOTE(review): this horizon differs from the t=60 used for `score` above,
# so the two accuracy figures are not directly comparable.
conf_ = xgb_cox.confusion_matrix(
    *test_set,  # (X_test, y_test, e_test)
    t=115.5,
)

print(ModelAnalysis.calculate_metrics(
    conf_
))
print(conf_)

{'Accuracy': 0.6822262118491921, 'Precision': 0.6821705426356589, 'Recall': 0.8301886792452831}
{'TP': 264, 'TN': 116, 'FP': 123, 'FN': 54}
