In [1]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

In [2]:
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from sklearn.preprocessing import LabelEncoder, OneHotEncoder
from sklearn.model_selection import train_test_split, cross_val_score

# 모델 학습
from sklearn.neighbors import KNeighborsRegressor
from sklearn.linear_model import LinearRegression, Ridge, Lasso
from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import SGDRegressor

from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.svm import SVC
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier

from sklearn.ensemble import RandomForestClassifier, VotingClassifier, BaggingClassifier
from sklearn.ensemble import AdaBoostClassifier, GradientBoostingClassifier
from xgboost import XGBClassifier
from lightgbm import LGBMClassifier

from sklearn.cluster import KMeans

from sklearn.decomposition import PCA

# 성능 평가
from sklearn.metrics import confusion_matrix
from sklearn.metrics import accuracy_score, precision_score, recall_score
from sklearn.metrics import roc_auc_score

from sklearn.metrics import mean_squared_error, r2_score

from sklearn.metrics import silhouette_score

# 경고 메세지 숨김
import warnings     
warnings.filterwarnings('ignore')

## 테이블 열기

In [3]:
# Load the fish measurements table (Species plus five numeric measurements)
# and preview its first five rows.
fish = pd.read_csv('https://bit.ly/fish_csv_data')
fish.head(5)

Unnamed: 0,Species,Weight,Length,Diagonal,Height,Width
0,Bream,242.0,25.4,30.0,11.52,4.02
1,Bream,290.0,26.3,31.2,12.48,4.3056
2,Bream,340.0,26.5,31.1,12.3778,4.6961
3,Bream,363.0,29.0,33.5,12.73,4.4555
4,Bream,430.0,29.0,34.0,12.444,5.134


In [4]:
# Split the table into features and target by column NAME rather than by
# position, so the selection stays correct even if the CSV's column order changes.

# Features: every column except the species label
# (currently Weight, Length, Diagonal, Height, Width)
fish_data = fish.drop(columns='Species')

# Target: the species label of each fish
fish_target = fish['Species']

In [5]:
# Inspect the first rows of the feature table
fish_data.head(n=5)

Unnamed: 0,Weight,Length,Diagonal,Height,Width
0,242.0,25.4,30.0,11.52,4.02
1,290.0,26.3,31.2,12.48,4.3056
2,340.0,26.5,31.1,12.3778,4.6961
3,363.0,29.0,33.5,12.73,4.4555
4,430.0,29.0,34.0,12.444,5.134


In [6]:
# Number of samples per species (most frequent first) — shows the class imbalance
fish_target.value_counts(sort=True, ascending=False)

Perch        56
Bream        35
Roach        20
Pike         17
Smelt        14
Parkki       11
Whitefish     6
Name: Species, dtype: int64

# 데이터 분할

In [25]:
# Aliases consumed by the split below
data = fish_data        # features (5 measurement columns)
target = fish_target    # labels (7 species)

# Hold out 25% of the rows as a test set, with a fixed seed for reproducibility.
# The split is NOT stratified. Passing stratify=target would keep the per-species
# proportions identical in train and test, which matters because the classes are
# imbalanced (Whitefish has only 6 rows).
# NOTE(review): stratification appears to be intentionally disabled; enabling it
# changes every score/prediction shown below.
xtrain, xtest, ytrain, ytest = train_test_split(
    data,
    target,
    test_size=0.25,
    random_state=42,
)

# 데이터 스케일링(표준화)

In [8]:
# Standardize the features. The mean/std are learned from the training set only
# and then reused for the test set, so no test-set statistics leak into training.
ss = StandardScaler()
ss.fit(xtrain)
xtrain_scaled = ss.transform(xtrain)
xtest_scaled = ss.transform(xtest)

# K-최근접 이웃

In [9]:
# k-nearest-neighbours classifier (k=3 is the hyperparameter),
# trained on the standardized training features.
knc = KNeighborsClassifier(n_neighbors=3)
knc.fit(xtrain_scaled, ytrain)

# Species labels learned from ytrain; predict_proba columns follow this order.
knc.classes_

array(['Bream', 'Parkki', 'Perch', 'Pike', 'Roach', 'Smelt', 'Whitefish'],
      dtype=object)

In [10]:
knc.score(X=xtrain_scaled, y=ytrain)   # mean accuracy on the standardized training data

0.8907563025210085

In [11]:
knc.score(X=xtest_scaled, y=ytest)     # mean accuracy on the standardized test data

0.85

In [12]:
# Predicted species for the first five test samples
knc.predict(xtest_scaled[0:5])

array(['Perch', 'Smelt', 'Pike', 'Perch', 'Perch'], dtype=object)

In [13]:
# Per-class probabilities for the first five test samples (columns in
# knc.classes_ order), rounded to 4 decimal places for display.
proba = knc.predict_proba(xtest_scaled[:5])
proba.round(decimals=4)

array([[0.    , 0.    , 1.    , 0.    , 0.    , 0.    , 0.    ],
       [0.    , 0.    , 0.    , 0.    , 0.    , 1.    , 0.    ],
       [0.    , 0.    , 0.    , 1.    , 0.    , 0.    , 0.    ],
       [0.    , 0.    , 0.6667, 0.    , 0.3333, 0.    , 0.    ],
       [0.    , 0.    , 0.6667, 0.    , 0.3333, 0.    , 0.    ]])

In [19]:
knc.kneighbors(X=xtest_scaled[3:4])    # (distances, training-set positions) of the 3 nearest neighbours of test sample 3

(array([[0.20774583, 0.24862983, 0.33682411]]), array([[104, 115, 106]]))

In [18]:
# Which species are the 3 neighbours of test sample 3? This explains its
# 0.6667 / 0.3333 probability split above.
# kneighbors returns POSITIONAL row indices into the training data (as a 2-D
# array, one row per query). ytrain, however, is a pandas Series that kept the
# original shuffled index labels from train_test_split. Plain ytrain[indexes]
# would index by label (or fail on the 2-D key), so select by position with .iloc.
distances, indexes = knc.kneighbors(xtest_scaled[3:4])
ytrain.iloc[indexes[0]]

In [14]:
knc.kneighbors(X=xtest_scaled[:5])    # (distances, training-set positions) of the 3 nearest neighbours for each of the first five test samples

(array([[0.13880285, 0.15188629, 0.15908025],
        [0.07310337, 0.10341686, 0.11506625],
        [0.63337713, 0.80646808, 0.82050896],
        [0.20774583, 0.24862983, 0.33682411],
        [0.17898697, 0.1859878 , 0.19013472]]),
 array([[ 39,  64,  63],
        [ 93,  99,  43],
        [100, 105,  27],
        [104, 115, 106],
        [ 72,  26,  68]]))