# 5. 다중 분류 (Multi-class Classification) - KNN

* 타겟 데이터에 3개 이상의 클래스가 포함된 문제를 다중 분류라고 함 (클래스가 2개인 경우는 이진 분류)

## 4-1. 데이터 준비하기

In [51]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Load the raw fish measurements. `ori_data` is kept untouched; all later
# cells operate on the independent (deep) copy `data`.
ori_data = pd.read_csv('data/02_fish/fish.csv')
data = ori_data.copy(deep=True)
data.head(n=5)

Unnamed: 0,Species,Weight,Vertical_Length,Diagonal_Length,Cross_Length,Height,Width
0,Bream,242.0,23.2,25.4,30.0,11.52,4.02
1,Bream,290.0,24.0,26.3,31.2,12.48,4.3056
2,Bream,340.0,23.9,26.5,31.1,12.3778,4.6961
3,Bream,363.0,26.3,29.0,33.5,12.73,4.4555
4,Bream,430.0,26.5,29.0,34.0,12.444,5.134


In [52]:
# Distinct species labels, listed in order of first appearance.
data['Species'].unique()

array(['Bream', 'Roach', 'Whitefish', 'Parkki', 'Perch', 'Pike', 'Smelt'],
      dtype=object)

In [53]:
# Build the input (feature) frame from the six numeric measurement columns.
input_df = data.loc[:, [
    'Weight',
    'Vertical_Length',
    'Diagonal_Length',
    'Cross_Length',
    'Height',
    'Width',
]]
input_df.head(n=5)

Unnamed: 0,Weight,Vertical_Length,Diagonal_Length,Cross_Length,Height,Width
0,242.0,23.2,25.4,30.0,11.52,4.02
1,290.0,24.0,26.3,31.2,12.48,4.3056
2,340.0,23.9,26.5,31.1,12.3778,4.6961
3,363.0,26.3,29.0,33.5,12.73,4.4555
4,430.0,26.5,29.0,34.0,12.444,5.134


In [54]:
# The model expects a 2-D array: one row per fish, one column per feature.
input_data = input_df.to_numpy()
input_data[0:5]

array([[242.    ,  23.2   ,  25.4   ,  30.    ,  11.52  ,   4.02  ],
       [290.    ,  24.    ,  26.3   ,  31.2   ,  12.48  ,   4.3056],
       [340.    ,  23.9   ,  26.5   ,  31.1   ,  12.3778,   4.6961],
       [363.    ,  26.3   ,  29.    ,  33.5   ,  12.73  ,   4.4555],
       [430.    ,  26.5   ,  29.    ,  34.    ,  12.444 ,   5.134 ]])

In [55]:
# The target is a 1-D array holding the species label of each fish.
target_data = data.loc[:, 'Species'].to_numpy()
target_data[0:5]

array(['Bream', 'Bream', 'Bream', 'Bream', 'Bream'], dtype=object)

## 4-2. 데이터 나누기

In [56]:
from sklearn.model_selection import train_test_split

# Split into training and test sets. Stratifying on the species label keeps
# class proportions in both splits close to those of the full data set;
# the fixed random_state makes the split reproducible.
train_input, test_input, train_target, test_target = train_test_split(
    input_data,
    target_data,
    random_state=42,
    stratify=target_data,
)
train_input.shape, test_input.shape

((119, 6), (40, 6))

## 4-3. 데이터 전처리 - 표준화

In [57]:
from sklearn.preprocessing import StandardScaler
# Standardize each feature using mean/std learned from the training set only,
# then apply that same transform to the test set (no test-set leakage).
ss = StandardScaler()
train_scaled = ss.fit(train_input).transform(train_input)
test_scaled = ss.transform(test_input)

In [58]:
# Peek at the first five standardized training rows (all feature columns).
train_scaled[:5, :]

array([[-0.75628803, -0.64716022, -0.66065677, -0.62357446, -0.78015159,
        -0.45043644],
       [-0.45991057, -0.12483205, -0.1248453 , -0.24414603, -0.4293487 ,
         0.03516919],
       [ 0.07356886, -0.00991985,  0.0212851 ,  0.2165885 ,  0.79541208,
         0.37481797],
       [ 1.54063728,  1.00339682,  1.0441979 ,  1.23743166,  2.29283234,
         1.34130358],
       [-0.87483902, -0.79341211, -0.75807703, -0.82232269, -0.80672937,
        -0.5697143 ]])

## 4-4. 모델 학습 및 평가

In [65]:
from sklearn.neighbors import KNeighborsClassifier

# k-nearest-neighbors classifier that votes over the 3 closest training samples,
# trained on the standardized features.
kn = KNeighborsClassifier(n_neighbors=3)
kn.fit(X=train_scaled, y=train_target)

KNeighborsClassifier(n_neighbors=3)

In [66]:
# Predicted species for every (standardized) test sample.
kn.predict(X=test_scaled)

array(['Perch', 'Perch', 'Roach', 'Parkki', 'Parkki', 'Roach', 'Perch',
       'Pike', 'Bream', 'Perch', 'Pike', 'Roach', 'Bream', 'Perch',
       'Perch', 'Bream', 'Perch', 'Bream', 'Pike', 'Roach', 'Bream',
       'Roach', 'Perch', 'Roach', 'Pike', 'Perch', 'Roach', 'Smelt',
       'Perch', 'Smelt', 'Bream', 'Bream', 'Bream', 'Parkki', 'Smelt',
       'Roach', 'Perch', 'Perch', 'Bream', 'Perch'], dtype=object)

In [67]:
# Ground-truth species of the test samples, for comparison with the
# predictions produced by kn.predict(test_scaled).
test_target

array(['Roach', 'Perch', 'Perch', 'Parkki', 'Parkki', 'Perch', 'Perch',
       'Pike', 'Bream', 'Roach', 'Pike', 'Perch', 'Bream', 'Perch',
       'Perch', 'Bream', 'Roach', 'Bream', 'Pike', 'Perch', 'Bream',
       'Roach', 'Perch', 'Perch', 'Pike', 'Perch', 'Whitefish', 'Smelt',
       'Smelt', 'Smelt', 'Bream', 'Bream', 'Bream', 'Parkki', 'Smelt',
       'Roach', 'Perch', 'Perch', 'Bream', 'Perch'], dtype=object)

In [71]:
# Mean accuracy on the test set (fraction of correctly predicted species).
kn.score(X=test_scaled, y=test_target)

0.75

In [69]:
# The fitted classifier stores its class labels in sorted order (alphabetical
# for these string labels); the columns of predict_proba follow this order.
kn.classes_

array(['Bream', 'Parkki', 'Perch', 'Pike', 'Roach', 'Smelt', 'Whitefish'],
      dtype=object)

In [70]:
# Class probabilities for the first five test samples; column j corresponds
# to kn.classes_[j]. With 3 neighbors each value is a multiple of 1/3.
kn.predict_proba(X=test_scaled[0:5])

array([[0.        , 0.        , 0.66666667, 0.        , 0.33333333,
        0.        , 0.        ],
       [0.        , 0.        , 0.66666667, 0.        , 0.33333333,
        0.        , 0.        ],
       [0.        , 0.        , 0.33333333, 0.        , 0.66666667,
        0.        , 0.        ],
       [0.        , 1.        , 0.        , 0.        , 0.        ,
        0.        , 0.        ],
       [0.        , 0.66666667, 0.        , 0.        , 0.33333333,
        0.        , 0.        ]])