## softmax zoo multi-classification
: multinomial classification (다중 분류): Y값의 범주가 3개 이상인 분류
#### 활성화 함수(activation function)로 softmax 함수가 사용된다

In [2]:
import tensorflow as tf
import numpy as np
# Fix TensorFlow's global random seed so the tf.random draws made later
# (the W / b initialization) repeat across runs.
tf.random.set_seed(seed=5)

In [22]:
# Load the zoo dataset: every column except the last is a feature,
# the last column is the class label.
xy = np.loadtxt('data-04-zoo.csv', delimiter=',', dtype=np.float32)   # (101, 17)

n_train = 70   # first 70 rows train the model, the remaining 31 are held out

# Training split (~70%): features (70, 16), labels (70, 1).
# Labels are cast to int32 because one-hot encoding needs integer indices.
x_train = xy[:n_train, :-1]
y_train = xy[:n_train, [-1]].astype('int32')

# Test split (~30%): features (31, 16), labels (31, 1).
x_test = xy[n_train:, :-1]
y_test = xy[n_train:, [-1]].astype('int32')

y_test.shape   # (31, 1) -- last expression, so the notebook displays it
# y_train

(31, 1)

In [23]:
# One-hot encode the training labels.
nb_classes = 7   # number of classes (labels 0..6)

# y_train is (70, 1), so one_hot appends a depth axis -> (70, 1, 7), rank 3.
Y_one_hot = tf.one_hot(indices=y_train, depth=nb_classes)
print(Y_one_hot.shape)
# Collapse the singleton axis -> (70, 7), rank 2, matching the (70, 7) logits.
Y_one_hot = tf.reshape(Y_one_hot, shape=[-1, nb_classes])
print(Y_one_hot.shape)

(70, 1, 7)
(70, 7)


In [24]:
# Model parameters: weight and bias.
# (70, 16) @ (16, 7) -> (70, 7): one logit per class for every example.
# The feature count is read from the data (16 for the zoo set) instead of
# being hard-coded, so W stays consistent with x_train's column count.
n_features = x_train.shape[1]
W = tf.Variable(tf.random.normal([n_features, nb_classes]), name='weight')
b = tf.Variable(tf.random.normal([nb_classes]), name='bias')

In [25]:
# Model (hypothesis): H(X) = softmax(X @ W + b)
def logits(X):
    """Return the raw, pre-softmax class scores X @ W + b (uses globals W, b)."""
    scores = tf.matmul(X, W)
    return tf.add(scores, b)


def hypothesis(X):
    """Return per-class probabilities: softmax of logits(X) over the last axis."""
    return tf.nn.softmax(logits(X), axis=-1)

In [26]:
# Cost function, approach 2: use tf.nn.softmax_cross_entropy_with_logits().
def cost_func():
    """Return the mean softmax cross-entropy over the training set.

    Reads the globals x_train and Y_one_hot. The fused op takes raw logits
    (not softmax probabilities), so logits(x_train) is passed directly.
    """
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits(
        labels=Y_one_hot,
        logits=logits(x_train),
    )
    return tf.reduce_mean(per_example_loss)

In [27]:
# Gradient descent, using the Adam optimizer with a learning rate of 0.01.
# Pass `learning_rate` explicitly: `lr` is a deprecated alias. Depending on
# the TF/Keras version, `lr` is either ignored with a warning (silently
# leaving the default learning rate) or rejected as an unknown argument.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

In [28]:
# Training: each iteration takes one optimizer step on W and b. After the
# update, every 100th step reports the current cost and the raw W / b values.
print('****** Start Learning!!')
for step in range(5000 + 1):
    optimizer.minimize(cost_func, var_list=[W, b])   # updates W and b
    if step % 100 != 0:
        continue
    current_cost = cost_func().numpy()
    print('%04d' % step, 'cost:[', current_cost, ']',
          ' W:', W.numpy(), ' b:', b.numpy())

print('****** Learning Finished!!')

****** Start Learning!!
0000 cost:[ 4.839144 ]  W: [[ 1.4765462  -1.0329607   0.85817504 -2.2477815   0.07975764 -0.10256936
  -0.17262045]
 [-0.0102913  -2.055402    0.08629883 -1.1741688  -0.60716754 -1.6230268
   0.5414934 ]
 [-2.1056848   1.2374349  -0.22617503 -2.1491005   0.8662374  -0.13944854
  -1.0872165 ]
 [ 0.45351046 -0.04618201  1.6368772   0.57476586  1.9437844   0.47304788
  -0.6681095 ]
 [ 0.5400391   0.22524849 -0.28083476  2.0702796  -0.5280023  -1.4842018
   2.147341  ]
 [-0.10611474 -0.45572862  1.0749733   2.5972738   2.5646122   0.9933401
  -1.6237358 ]
 [ 1.0955116   1.1068817  -0.22023039  0.36354733 -0.24794939 -1.7722493
  -0.24174799]
 [ 0.44639748 -0.68253785  0.6332387  -0.13475034  0.73873514 -0.5743464
  -0.24750163]
 [-0.4935193  -0.78649485  0.7880194   0.99918747 -1.2422616   0.31096044
  -0.93099165]
 [ 0.17966032 -0.91762686 -0.91800165 -1.2079287   0.45389992  0.5338235
   0.3346823 ]
 [ 0.5177204   0.00454136 -0.27970713  0.5030586   1.6470021  -0.

0600 cost:[ 0.016037721 ]  W: [[ 2.1529524  -4.5612707   0.14853543 -2.7137654  -0.5367806   0.26542073
  -1.1013731 ]
 [-1.5650823  -0.82688636 -1.457313   -2.5889328  -1.4728613  -3.3436873
  -1.0459031 ]
 [-4.0545917   1.5953139  -0.16763727 -1.1694125   0.1818623   0.4074087
   1.6775228 ]
 [ 1.2134858  -2.8574247   0.3792507  -1.6435606   1.3239948  -0.9229918
  -1.5940586 ]
 [-0.13850965  1.2621726  -1.2736691   0.7970461  -1.2059273   0.23900087
   0.30981225]
 [-0.02147746  0.13236955  0.2367786   3.4039423   1.9324785  -2.423848
   0.24424294]
 [ 1.3039882   0.08105772 -0.13402511  0.86267847 -0.76587725 -1.7814257
   1.8404837 ]
 [ 0.7674445  -2.736816    0.48128617  0.6657983   0.66228205 -2.7849138
  -4.408266  ]
 [-0.13230664  0.19562815  0.63290477  1.7973117  -1.6558449  -1.6982483
  -2.7735443 ]
 [ 0.58040184  0.14969699  1.0201255  -3.4296894   0.03613203  1.7270887
  -1.6722908 ]
 [-1.8694882  -2.1061916   1.6238744  -2.3067605   1.2991886  -0.20177768
  -1.0178034 ]


1200 cost:[ 0.0056242337 ]  W: [[ 2.3769557  -5.394794    0.04588718 -2.725563   -0.74701595  0.19845669
  -1.1415085 ]
 [-1.7899736  -0.588548   -1.745109   -2.7643018  -1.6293882  -3.6484315
  -1.4501413 ]
 [-4.6099653   1.5802288  -0.11044359 -1.0302899   0.16699359  0.40131214
   2.1305418 ]
 [ 1.4550198  -3.4483914   0.0210892  -2.1323225   1.1061505  -1.0682281
  -2.130396  ]
 [-0.15810129  1.3992076  -1.3815105   0.65854305 -1.2911344   0.5874516
  -0.0477796 ]
 [-0.21786381 -0.05460047  0.05099924  3.5067575   2.0021055  -3.1755202
   0.57993096]
 [ 1.2524409  -0.18176857 -0.09319359  0.882201   -0.76216096 -1.6713502
   2.062725  ]
 [ 0.6433273  -3.360436    0.48081562  0.76694024  0.8940233  -3.2446258
  -5.666368  ]
 [-0.11869356  0.31471506  0.6317197   1.8980854  -1.6435935  -2.090212
  -3.2986336 ]
 [ 0.6113106   0.26705632  1.3867482  -3.914921    0.04988696  1.9611696
  -2.1414766 ]
 [-2.2571957  -2.273431    2.0396962  -2.8201065   1.1113592  -0.32660767
  -1.1330142 ]

1800 cost:[ 0.0028456268 ]  W: [[ 2.5382433  -5.8632693  -0.00799018 -2.7297747  -0.8998525   0.16250108
  -1.1682043 ]
 [-1.9361726  -0.42665762 -1.9445003  -2.8774526  -1.7469404  -3.860873
  -1.6878667 ]
 [-4.971676    1.5842822  -0.07540109 -0.943071    0.1545103   0.39779934
   2.398708  ]
 [ 1.6266326  -3.778328   -0.2060116  -2.4516318   0.9476865  -1.1467024
  -2.5228038 ]
 [-0.16332757  1.4896104  -1.4457303   0.5678147  -1.3549296   0.7931668
  -0.25898898]
 [-0.35129198 -0.1788439  -0.07000803  3.5741148   2.0493803  -3.6258075
   0.78198594]
 [ 1.2175223  -0.3147783  -0.06783432  0.89335877 -0.7615657  -1.6142921
   2.1942728 ]
 [ 0.54401505 -3.7978225   0.48156205  0.8334406   1.0641438  -3.521967
  -6.4015636 ]
 [-0.11159091  0.3914749   0.6319683   1.9643486  -1.63743    -2.3393364
  -3.6104755 ]
 [ 0.63051254  0.33964238  1.6301956  -4.230595    0.05785256  2.1029594
  -2.4279737 ]
 [-2.4668589  -2.3537486   2.3120286  -3.102071    0.9552405  -0.40653932
  -1.2036808 ]


2400 cost:[ 0.0016747661 ]  W: [[ 2.6673586  -6.2074275  -0.04379058 -2.7318316  -1.0225704   0.13659286
  -1.18882   ]
 [-2.0469863  -0.29763174 -2.1042264  -2.9628546  -1.8453429  -4.03189
  -1.8655311 ]
 [-5.253109    1.5920844  -0.04928968 -0.8763986   0.14292945  0.3960376
   2.6048687 ]
 [ 1.7637076  -4.0203257  -0.38126367 -2.7036655   0.8203912  -1.2028778
  -2.8366752 ]
 [-0.1635469   1.561309   -1.4917592   0.49863937 -1.4084296   0.9462755
  -0.41670623]
 [-0.4570268  -0.27565023 -0.16413993  3.6275706   2.0862288  -3.9655378
   0.93461066]
 [ 1.1901106  -0.40514553 -0.04923307  0.8997389  -0.76234293 -1.5822349
   2.2965922 ]
 [ 0.46200526 -4.1491733   0.48251653  0.8863152   1.2014353  -3.730259
  -6.9511724 ]
 [-0.10641899  0.45177492  0.63252074  2.0170512  -1.6337562  -2.5314558
  -3.8444808 ]
 [ 0.6453776   0.39550766  1.8208908  -4.4791026   0.06351777  2.2121675
  -2.6464822 ]
 [-2.6133513  -2.4060392   2.5244238  -3.3007498   0.823802   -0.47075015
  -1.2599516 ]
 [

3000 cost:[ 0.0010691542 ]  W: [[ 2.7782114  -6.4899225  -0.07035837 -2.7329829  -1.1280782   0.11536351
  -1.2060714 ]
 [-2.1380827  -0.18648331 -2.2418206  -3.0329874  -1.9328327  -4.1798973
  -2.013126  ]
 [-5.4919224   1.6011064  -0.02769708 -0.8204707   0.13189854  0.39548042
   2.7796204 ]
 [ 1.8813708  -4.2190347  -0.5294268  -2.9203238   0.71090746 -1.2483364
  -3.105693  ]
 [-0.16127077  1.6230514  -1.5279022   0.4413886  -1.4560686   1.072497
  -0.54717064]
 [-0.5479927  -0.3585115  -0.24379139  3.6739502   2.117441   -4.249347
   1.062039  ]
 [ 1.1668246  -0.474742   -0.03422213  0.90320444 -0.7637817  -1.5647916
   2.3846612 ]
 [ 0.39073732 -4.453016    0.4836437   0.9322472   1.3196408  -3.903285
  -7.408302  ]
 [-0.10200047  0.5033293   0.6332996   2.062847   -1.6314417  -2.6936111
  -4.0393333 ]
 [ 0.6581457   0.44265556  1.9829404  -4.692338    0.06799624  2.3052843
  -2.8304281 ]
 [-2.728891   -2.445086    2.7046494  -3.4576728   0.7083608  -0.527366
  -1.3095462 ]
 [ 

3600 cost:[ 0.0007168063 ]  W: [[ 2.877837   -6.7362375  -0.09130933 -2.7336783  -1.222974    0.09668658
  -1.2212428 ]
 [-2.2168486  -0.08615395 -2.3656387  -3.0936842  -2.0136557  -4.3137803
  -2.143313  ]
 [-5.7053246   1.6106889  -0.00866799 -0.7709433   0.12118808  0.39584804
   2.9358888 ]
 [ 1.9871515  -4.392597   -0.661535   -3.1159818   0.61240536 -1.2877326
  -3.347083  ]
 [-0.15749328  1.6788347  -1.5578629   0.3915171  -1.5001084   1.1827532
  -0.6615252 ]
 [-0.6303092  -0.43339744 -0.3146136   3.7163196   2.1452632  -4.500536
   1.174646  ]
 [ 1.1460387  -0.5322653  -0.02136821  0.90464634 -0.76560193 -1.5571798
   2.4647133 ]
 [ 0.32637194 -4.7279673   0.4849637   0.974246    1.4258916  -4.0556216
  -7.811916  ]
 [-0.09790944  0.5496307   0.6343024   2.104733   -1.6300068  -2.837893
  -4.2114534 ]
 [ 0.66975456  0.4846024   2.1275103  -4.884633    0.07177331  2.3892777
  -2.9942315 ]
 [-2.8266075  -2.476606    2.865411   -3.5900235   0.6036038  -0.5798711
  -1.3556576 ]
 

4100 cost:[ 0.00052650314 ]  W: [[ 2.95519280e+00 -6.92321110e+00 -1.05783448e-01 -2.73406076e+00
  -1.29669321e+00  8.22985694e-02 -1.23281026e+00]
 [-2.27601147e+00 -7.96467531e-03 -2.46164179e+00 -3.13936019e+00
  -2.07771993e+00 -4.41820002e+00 -2.24328995e+00]
 [-5.87055254e+00  1.61894751e+00  5.98019734e-03 -7.32883871e-01
   1.12390794e-01  3.96748960e-01  3.05694675e+00]
 [ 2.06932259e+00 -4.52470732e+00 -7.63676286e-01 -3.26865363e+00
   5.35867989e-01 -1.31763744e+00 -3.53424835e+00]
 [-1.53551280e-01  1.72236228e+00 -1.57952857e+00  3.53790015e-01
  -1.53501797e+00  1.26657867e+00 -7.48744667e-01]
 [-6.94757104e-01 -4.92011756e-01 -3.69123816e-01  3.74976110e+00
   2.16675425e+00 -4.69380665e+00  1.26112270e+00]
 [ 1.12997186e+00 -5.74140787e-01 -1.17561258e-02  9.04628694e-01
  -7.67316997e-01 -1.55649233e+00  2.52767968e+00]
 [ 2.76318341e-01 -4.94261217e+00  4.86220032e-01  1.00741863e+00
   1.50829220e+00 -4.17233658e+00 -8.12202263e+00]
 [-9.46053937e-02  5.85585237e-0

4600 cost:[ 0.00039284077 ]  W: [[ 3.02896905e+00 -7.09816265e+00 -1.18178584e-01 -2.73430991e+00
  -1.36702192e+00  6.85935169e-02 -1.24366987e+00]
 [-2.33083057e+00  6.68184385e-02 -2.55284667e+00 -3.18177080e+00
  -2.13978767e+00 -4.51807594e+00 -2.33797407e+00]
 [-6.02788019e+00  1.62741530e+00  1.99218523e-02 -6.96869671e-01
   1.03638574e-01  3.98134589e-01  3.17223883e+00]
 [ 2.14772034e+00 -4.64875937e+00 -8.60846937e-01 -3.41487455e+00
   4.62838918e-01 -1.34573936e+00 -3.71250153e+00]
 [-1.49046883e-01  1.76404738e+00 -1.59891248e+00  3.18606883e-01
  -1.56882620e+00  1.34519649e+00 -8.30787957e-01]
 [-7.56751239e-01 -5.48376441e-01 -4.20709878e-01  3.78214693e+00
   2.18717456e+00 -4.87703371e+00  1.34294128e+00]
 [ 1.11469948e+00 -6.12055242e-01 -2.85766996e-03  9.03678060e-01
  -7.69170225e-01 -1.55999279e+00  2.58847737e+00]
 [ 2.28584513e-01 -5.14815903e+00  4.87631172e-01  1.03955972e+00
   1.58673370e+00 -4.28261089e+00 -8.41572762e+00]
 [-9.13270712e-02  6.19881988e-0

In [21]:
# Weight과 bias 출력
# print('Weight:',W.numpy())  # (16,7)
# print('bias:', b.numpy())   # (7,)

In [29]:
# Accuracy computation on the held-out test split.

# One-hot encode the test labels under a SEPARATE name. cost_func() reads
# the global Y_one_hot as its (70, 7) training labels. Overwriting it with
# the 31-row test labels would make any re-run of the training cell fail on
# a label/logit shape mismatch.
Y_test_one_hot = tf.one_hot(y_test, nb_classes)
print(Y_test_one_hot.shape)      # (31, 1, 7), rank 3
Y_test_one_hot = tf.reshape(Y_test_one_hot, [-1, nb_classes])
print(Y_test_one_hot.shape)      # (31, 7), rank 2


def predict(X):
    """Return, for each row of X, the index of the highest-probability class."""
    # tf.argmax returns the index of the largest value along axis 1.
    return tf.argmax(hypothesis(X), axis=1)


correct_predict = tf.equal(predict(x_test), tf.argmax(Y_test_one_hot, 1))
accuracy = tf.reduce_mean(tf.cast(correct_predict, dtype=tf.float32))
print("Accuracy:", accuracy.numpy())   # recorded run printed 0.8064516

(31, 1, 7)
(31, 7)
Accuracy: 0.8064516


In [30]:
# Prediction: compare each test prediction against its true label.
print('***** Predict')
pred = predict(x_test).numpy()
for predicted, label in zip(pred, y_test.flatten()):
    truth = int(label)
    is_correct = predicted == truth
    print("[{}] Prediction: {} / Real Y: {}".format(is_correct, predicted, truth))

***** Predict
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 1 / Real Y: 1
[False] Prediction: 5 / Real Y: 6
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 0 / Real Y: 0
[False] Prediction: 3 / Real Y: 2
[True] Prediction: 6 / Real Y: 6
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 2 / Real Y: 2
[False] Prediction: 1 / Real Y: 6
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 6 / Real Y: 6
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 5 / Real Y: 5
[True] Prediction: 4 / Real Y: 4
[False] Prediction: 1 / Real Y: 2
[False] Prediction: 4 / Real Y: 2
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 5 / Real Y: 5
[True] Prediction: 0 / Real Y: 0
[False] Prediction: 1 / 