## Softmax multinomial classification (zoo dataset)
multinomial classification (다중 분류): Y값의 범주가 3개 이상인 분류
#### 활성화 함수(activation function)로 softmax 함수가 사용된다

In [1]:
import tensorflow as tf
import numpy as np
# Fix TensorFlow's global random seed so the random W/b initialization is repeatable.
tf.random.set_seed(seed=5)

In [2]:
# Load the zoo dataset: feature columns first, the class label in the last column.
xy = np.loadtxt('data-04-zoo.csv', delimiter=',', dtype=np.float32)  # (101, 17)

# Training split: first 70 rows (~70%).
x_train = xy[:70, :-1]                    # features, (70, 16)
y_train = xy[:70, [-1]].astype('int32')   # labels, (70, 1); tf.one_hot needs an integer dtype

# Test split: remaining 31 rows (~30%).
x_test = xy[70:, :-1]                     # features, (31, 16)
y_test = xy[70:, [-1]].astype('int32')    # labels, (31, 1); integer dtype for tf.one_hot

# Last expression of the cell, so the notebook displays it.
y_test.shape

(31, 1)

In [3]:
# One-hot encode the training labels.
nb_classes = 7  # number of classes: labels 0..6

# y_train is (70, 1), so tf.one_hot appends a depth axis -> (70, 1, 7), rank 3.
Y_one_hot = tf.one_hot(y_train, nb_classes)
print(Y_one_hot.shape)
# Remove the singleton middle axis -> (70, 7), rank 2: one row per sample.
Y_one_hot = tf.squeeze(Y_one_hot, axis=1)
print(Y_one_hot.shape)

(70, 1, 7)
(70, 7)


In [4]:
# Trainable parameters, drawn from tf.random.normal (mean 0, stddev 1 by default).
# Shapes: X (n, n_features) @ W (n_features, nb_classes) + b (nb_classes,) -> (n, nb_classes)
# The feature count comes from the data instead of a hard-coded 16, so this cell
# keeps working if the set of feature columns changes.
W = tf.Variable(tf.random.normal([x_train.shape[1], nb_classes]), name='weight')
b = tf.Variable(tf.random.normal([nb_classes]), name='bias')

In [5]:
# 예측 함수(hypothesis) : H(X) = softmax(X*W + b)
def logits(X):
    """Return the raw (pre-softmax) class scores ``X @ W + b`` for a batch ``X``.

    Reads the module-level variables ``W`` and ``b``; ``b`` is broadcast
    across the rows of the matrix product.
    """
    scores = tf.matmul(X, W)
    return scores + b

def hypothesis(X):
    """Return class probabilities: ``softmax(logits(X))`` over the last axis."""
    scores = logits(X)
    return tf.nn.softmax(scores)

In [6]:
# Cost function, approach 2: use tf.nn.softmax_cross_entropy_with_logits()
def cost_func():
    """Return the mean softmax cross-entropy of the model on the training set.

    Takes no arguments because it is passed as a zero-argument loss callable
    to ``optimizer.minimize``.

    The one-hot labels are built here from ``y_train`` instead of being read
    from the global ``Y_one_hot``. The accuracy cell later rebinds
    ``Y_one_hot`` to the test labels (31 rows). After that, re-running
    training would fail with a shape mismatch against ``logits(x_train)``
    (70 rows).
    """
    # y_train is (n, 1) -> one_hot gives (n, 1, nb_classes) -> flatten to (n, nb_classes).
    labels = tf.reshape(tf.one_hot(y_train, nb_classes), [-1, nb_classes])
    # Fused op on the raw logits: more numerically stable than softmax followed by log.
    cost_i = tf.nn.softmax_cross_entropy_with_logits(logits=logits(x_train),
                                                     labels=labels)
    return tf.reduce_mean(cost_i)

In [7]:
# Gradient descent: create an Adam optimizer with learning rate 0.01.
# `learning_rate` is the supported keyword. The old `lr` alias is deprecated and
# is rejected by the Keras optimizers shipped with TF >= 2.11.
optimizer = tf.keras.optimizers.Adam(learning_rate=0.01)

In [8]:
# Training: 5001 optimizer steps (0..5000), logging cost and parameters every 100 steps.
print('****** Start Learning!!')
for step in range(5001):
    # One Adam update of W and b against cost_func.
    optimizer.minimize(cost_func, var_list=[W, b])
    if step % 100 != 0:
        continue
    print('%04d' % step, 'cost:[', cost_func().numpy(), ']',
          ' W:', W.numpy(), ' b:', b.numpy())
print('****** Learning Finished!!')

****** Start Learning!!
0000 cost:[ 4.3771973 ]  W: [[-0.17030673 -0.9602862  -0.0475386  -0.7523069   1.3131572  -0.628547
   0.84872144]
 [-0.09818707  2.4588597   0.7576921   1.2559881   0.97023326  1.5193328
  -0.57714087]
 [ 0.888707   -1.2608008  -0.8504979   1.2707453  -0.67305064  0.03615696
   0.2332869 ]
 [ 0.9691402  -0.38048053  0.02533233  0.12795149  0.7496968   0.45404953
  -1.3891158 ]
 [-0.4606576   1.1283888  -0.9175112   0.73430395 -0.68890834 -0.82569605
  -0.5024537 ]
 [-0.9236115  -0.01845565 -0.050317    2.0287018  -1.4086664  -0.91990536
   1.1792337 ]
 [-0.10188054  0.31484452 -1.4101549  -1.5593029  -0.55799836 -0.16874883
   0.19945289]
 [ 1.006162   -2.0308204  -2.8300834  -1.3748323  -1.9845864   0.44651696
  -0.46207324]
 [ 1.0026199  -0.35762587 -0.09890626 -1.2563033   0.9800537   0.49316156
  -0.5365902 ]
 [ 0.19145192  0.68456185 -0.37840676 -0.7960441  -1.7899888  -1.5583152
  -2.0167358 ]
 [ 0.3794871  -0.30376196 -0.8005682   0.42040622 -0.5763615  

0600 cost:[ 0.008271489 ]  W: [[ 0.61235905 -1.5278612  -0.74270207 -1.4460946  -0.8022918  -0.9403321
  -2.0430994 ]
 [-2.9294236   4.9572453  -2.2012355  -0.3661719  -1.6740661  -1.287353
  -3.3275785 ]
 [-1.1910864  -2.082584    0.19446875  1.7965239   0.50366944  0.44239804
   1.7774298 ]
 [ 1.7717081  -0.84041154 -2.1014113  -1.85556    -1.3200544  -0.31255785
  -4.0573854 ]
 [-0.51933706  0.15806273 -3.1470182  -0.8914904  -2.8759024   0.41788125
  -4.2107725 ]
 [-1.6903646  -0.62693435 -3.1244886   2.8292568   0.35172856 -2.5261452
   2.305764  ]
 [ 0.16517816 -0.27000517 -0.36430013 -1.3392824   0.34978694 -1.3262341
   1.7451298 ]
 [ 1.3858793  -2.4894433  -1.0310935  -0.3565365   0.17070115 -0.5526458
  -2.474261  ]
 [ 1.4817922  -0.62744606  1.6935871  -0.5621677   1.2717003  -0.61255985
  -2.5648348 ]
 [ 0.69857734  0.11654146  1.4397787  -2.6576083  -1.3320842  -1.5440001
  -5.65973   ]
 [-1.3598471  -1.1691633   1.0255402  -1.3443954   0.7088136  -0.54592526
  -2.6895537 

1200 cost:[ 0.0027352378 ]  W: [[ 0.73116827 -1.5782866  -0.78863096 -1.4518462  -1.1956594  -1.0555869
  -2.5069082 ]
 [-3.6583073   5.452915   -2.663737   -0.5985367  -2.10424    -1.9593704
  -3.9024937 ]
 [-1.5494552  -2.139887    0.08081968  1.8848071   0.5451328   0.48191458
   2.0691848 ]
 [ 1.904084   -0.8850808  -2.4328792  -2.5974805  -1.7007396  -0.43457523
  -4.5653305 ]
 [-0.68577206  0.10383018 -3.3540921  -1.1411462  -3.2280867   0.54067886
  -5.2690454 ]
 [-1.8851215  -0.68393815 -3.5933292   3.008527    0.66176605 -2.7389176
   2.369112  ]
 [ 0.15998305 -0.317269   -0.4781686  -1.4693565   0.35914397 -1.4594893
   2.041802  ]
 [ 1.3784573  -2.5160894  -0.759413   -0.13800213  0.7704947  -0.7140072
  -2.9013746 ]
 [ 1.518438   -0.6180952   1.9639329  -0.38569143  1.4298139  -0.81237453
  -2.9958878 ]
 [ 0.74679226  0.07649292  1.7155015  -3.1340106  -1.258459   -1.5870638
  -6.395987  ]
 [-1.4947426  -1.2234645   1.3028598  -1.5346266   0.64106596 -0.49145493
  -2.801521

1800 cost:[ 0.0013717384 ]  W: [[ 0.8095121  -1.6144646  -0.82256114 -1.4536065  -1.4342501  -1.1397725
  -2.7749774 ]
 [-4.136565    5.77541    -2.9613316  -0.7493911  -2.3899148  -2.3830962
  -4.293236  ]
 [-1.780934   -2.1756537   0.01729922  1.952926    0.56730354  0.51104057
   2.251262  ]
 [ 1.9926672  -0.919885   -2.6669824  -3.1177065  -1.9336127  -0.51500744
  -4.863931  ]
 [-0.7950924   0.08065113 -3.4768567  -1.3006656  -3.433641    0.6068131
  -5.9301853 ]
 [-2.0020657  -0.7330781  -3.871756    3.1281705   0.86937076 -2.873831
   2.4114683 ]
 [ 0.15864465 -0.35012037 -0.5420113  -1.5549368   0.3790576  -1.5478885
   2.2299423 ]
 [ 1.3720156  -2.533337   -0.5876988   0.00827533  1.1527553  -0.81466246
  -3.1546276 ]
 [ 1.5431119  -0.6168931   2.1347682  -0.26564336  1.5478952  -0.938671
  -3.251908  ]
 [ 0.7795737   0.05075635  1.8892574  -3.4603655  -1.2267728  -1.6083956
  -6.8914533 ]
 [-1.5704234  -1.2556939   1.4777609  -1.6365446   0.5827204  -0.46203387
  -2.861787  ]

2300 cost:[ 0.0008766481 ]  W: [[ 0.86203206 -1.6397477  -0.8462395  -1.454303   -1.5875634  -1.1982979
  -2.9399853 ]
 [-4.4526963   5.990154   -3.157678   -0.8497122  -2.5849445  -2.658278
  -4.5550203 ]
 [-1.933566   -2.1987395  -0.02150762  2.0004647   0.5811975   0.5314154
   2.3689752 ]
 [ 2.0524561  -0.9451379  -2.8297257  -3.470711   -2.084102   -0.568658
  -5.0489697 ]
 [-0.86786854  0.06938557 -3.5533445  -1.404892   -3.5614562   0.64589214
  -6.3460264 ]
 [-2.0766468  -0.77002484 -4.047695    3.2083473   1.0096272  -2.9628792
   2.4407494 ]
 [ 0.15851201 -0.37254387 -0.58110267 -1.6117339   0.39729574 -1.6077845
   2.353306  ]
 [ 1.3668616  -2.5445487  -0.4756364   0.10671993  1.4039823  -0.87974215
  -3.3116715 ]
 [ 1.559789   -0.617762    2.2462442  -0.18435836  1.62902    -1.0210198
  -3.410849  ]
 [ 0.80162996  0.0339481   2.0024757  -3.6803513  -1.2129079  -1.6187663
  -7.2358794 ]
 [-1.6166648  -1.2760308   1.591784   -1.6960812   0.53953815 -0.44453877
  -2.8976898 ]


2800 cost:[ 0.00059678254 ]  W: [[ 9.0813917e-01 -1.6625066e+00 -8.6722076e-01 -1.4547166e+00
  -1.7186996e+00 -1.2506226e+00 -3.0758414e+00]
 [-4.7264061e+00  6.1778245e+00 -3.3277051e+00 -9.3732250e-01
  -2.7588959e+00 -2.8942251e+00 -4.7820458e+00]
 [-2.0659232e+00 -2.2184343e+00 -5.3641398e-02  2.0425980e+00
   5.9314507e-01  5.4976910e-01  2.4696772e+00]
 [ 2.1051555e+00 -9.6837687e-01 -2.9755483e+00 -3.7805748e+00
  -2.2133160e+00 -6.1559761e-01 -5.2018108e+00]
 [-9.3150872e-01  6.1793942e-02 -3.6166832e+00 -1.4944457e+00
  -3.6681421e+00  6.7732757e-01 -6.6902823e+00]
 [-2.1400878e+00 -8.0456424e-01 -4.1960006e+00  3.2785599e+00
   1.1331393e+00 -3.0404315e+00  2.4672015e+00]
 [ 1.5882429e-01 -3.9237168e-01 -6.1352611e-01 -1.6611946e+00
   4.1577414e-01 -1.6608405e+00  2.4601140e+00]
 [ 1.3617948e+00 -2.5541072e+00 -3.7907508e-01  1.9324270e-01
   1.6213653e+00 -9.3543082e-01 -3.4415696e+00]
 [ 1.5745034e+00 -6.1942381e-01  2.3422916e+00 -1.1268821e-01
   1.7002025e+00 -1.091925

3400 cost:[ 0.00039618718 ]  W: [[ 9.5823163e-01 -1.6877404e+00 -8.8979399e-01 -1.4550210e+00
  -1.8581587e+00 -1.3082000e+00 -3.2143853e+00]
 [-5.0193367e+00  6.3808436e+00 -3.5096884e+00 -1.0320781e+00
  -2.9508250e+00 -3.1449282e+00 -5.0235133e+00]
 [-2.2081091e+00 -2.2392597e+00 -8.6905807e-02  2.0883305e+00
   6.0612822e-01  5.7016003e-01  2.5765135e+00]
 [ 2.1626010e+00 -9.9460369e-01 -3.1362321e+00 -4.1155548e+00
  -2.3511829e+00 -6.6648072e-01 -5.3579764e+00]
 [-1.0005573e+00  5.5565376e-02 -3.6814840e+00 -1.5895885e+00
  -3.7791588e+00  7.0892394e-01 -7.0411148e+00]
 [-2.2071974e+00 -8.4388494e-01 -4.3512902e+00  3.3545864e+00
   1.2674354e+00 -3.1242216e+00  2.4967992e+00]
 [ 1.5955926e-01 -4.1398603e-01 -6.4714307e-01 -1.7144985e+00
   4.3786904e-01 -1.7189814e+00  2.5748115e+00]
 [ 1.3557525e+00 -2.5641692e+00 -2.7592662e-01  2.8726891e-01
   1.8542895e+00 -9.9454558e-01 -3.5748355e+00]
 [ 1.5905669e+00 -6.2203473e-01  2.4448850e+00 -3.4619603e-02
   1.7767175e+00 -1.167640

3900 cost:[ 0.0002896895 ]  W: [[ 9.9714386e-01 -1.7076594e+00 -9.0686750e-01 -1.4551867e+00
  -1.9646639e+00 -1.3533025e+00 -3.3157623e+00]
 [-5.2435808e+00  6.5379553e+00 -3.6489031e+00 -1.1054181e+00
  -3.1019592e+00 -3.3358588e+00 -5.2060876e+00]
 [-2.3174901e+00 -2.2550452e+00 -1.1176624e-01  2.1236231e+00
   6.1635590e-01  5.8630699e-01  2.6577988e+00]
 [ 2.2073450e+00 -1.0155921e+00 -3.2620730e+00 -4.3738570e+00
  -2.4567502e+00 -7.0596623e-01 -5.4723415e+00]
 [-1.0542281e+00  5.1957116e-02 -3.7289393e+00 -1.6619134e+00
  -3.8624618e+00  7.3191786e-01 -7.2967582e+00]
 [-2.2581794e+00 -8.7546831e-01 -4.4680843e+00  3.4134285e+00
   1.3716574e+00 -3.1890647e+00  2.5204198e+00]
 [ 1.6038749e-01 -4.3078902e-01 -6.7230088e-01 -1.7555908e+00
   4.5616606e-01 -1.7644815e+00  2.6630599e+00]
 [ 1.3506882e+00 -2.5717473e+00 -1.9696482e-01  3.6028978e-01
   2.0329523e+00 -1.0395579e+00 -3.6731446e+00]
 [ 1.6031051e+00 -6.2455893e-01  2.5234179e+00  2.6107654e-02
   1.8352314e+00 -1.2256035

4500 cost:[ 0.00020355025 ]  W: [[ 1.0416265  -1.730728   -0.9255625  -1.4553298  -2.08475    -1.4051483
  -3.425211  ]
 [-5.496381    6.716924   -3.8055425  -1.1890095  -3.2767906  -3.5503008
  -5.408319  ]
 [-2.4414964  -2.2727084  -0.1393147   2.1635604   0.62831616  0.6050729
   2.7490394 ]
 [ 2.2586095  -1.0401653  -3.4063764  -4.66659    -2.576032   -0.75110596
  -5.5957994 ]
 [-1.1157509   0.04893354 -3.7800844  -1.7428787  -3.9551034   0.7568081
  -7.5710645 ]
 [-2.3153644  -0.9124555  -4.5978303   3.4804182   1.4905361  -3.2629936
   2.5480723 ]
 [ 0.16158386 -0.44998118 -0.700213   -1.8022572   0.4780255  -1.8168195
   2.7631228 ]
 [ 1.3445752  -2.58016    -0.10784969  0.44369093  2.234835   -1.0901313
  -3.7804587 ]
 [ 1.6175082  -0.62790626  2.6120453   0.09553599  1.900922   -1.2910516
  -3.886083  ]
 [ 0.8761529  -0.02053033  2.3732378  -4.423892   -1.2027321  -1.6313885
  -8.509855  ]
 [-1.7561635  -1.3391937   1.9654546  -1.858529    0.3749752  -0.38945228
  -2.998014  

In [9]:
# Weight과 bias 출력
# print('Weight:',W.numpy())  # (16,7)
# print('bias:', b.numpy())   # (7,)

In [10]:
# Accuracy computation.

# One-hot encode the test labels: y_test is (31, 1), so tf.one_hot gives (31, 1, 7), rank 3.
# NOTE: this rebinds the global Y_one_hot, which the earlier one-hot cell built from
# y_train; anything that reads Y_one_hot after this cell sees the test labels.
Y_one_hot = tf.one_hot(y_test, nb_classes)
print(Y_one_hot.shape)
# Remove the singleton middle axis -> (31, 7), rank 2.
Y_one_hot = tf.squeeze(Y_one_hot, axis=1)
print(Y_one_hot.shape)


# tf.argmax() returns the index of the largest element along the given axis.
def predict(X):
    """Return, for each row of ``X``, the index of the most probable class."""
    probs = hypothesis(X)
    return tf.argmax(probs, axis=1)

# Element-wise match between predicted and true class indices, then the fraction correct.
correct_predict = tf.equal(predict(x_test), tf.argmax(Y_one_hot, axis=1))
accuracy = tf.reduce_mean(tf.cast(correct_predict, tf.float32))
print("Accuracy:", accuracy.numpy())  # recorded run: Accuracy: 0.8064516 (25 of 31)

(31, 1, 7)
(31, 7)
Accuracy: 0.8064516


In [11]:
# Per-sample predictions on the test set: [match?] predicted vs. true label.
print('***** Predict')
pred = predict(x_test).numpy()
for p, y in zip(pred, y_test[:, 0]):
    print(f"[{p == int(y)}] Prediction: {p} / Real Y: {int(y)}")

***** Predict
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 1 / Real Y: 1
[False] Prediction: 5 / Real Y: 6
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 0 / Real Y: 0
[False] Prediction: 4 / Real Y: 2
[True] Prediction: 6 / Real Y: 6
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 2 / Real Y: 2
[False] Prediction: 2 / Real Y: 6
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 6 / Real Y: 6
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 5 / Real Y: 5
[True] Prediction: 4 / Real Y: 4
[False] Prediction: 0 / Real Y: 2
[False] Prediction: 4 / Real Y: 2
[True] Prediction: 3 / Real Y: 3
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 1 / Real Y: 1
[True] Prediction: 0 / Real Y: 0
[True] Prediction: 5 / Real Y: 5
[True] Prediction: 0 / Real Y: 0
[False] Prediction: 2 / 

In [19]:
# Inspect the first test sample.
# Compute the softmax output once and reuse it for both the probabilities and the argmax.
first_probs = hypothesis(x_test).numpy()[0]
print(first_probs)                                     # predicted class probabilities
# One-hot vector of the true label, built from y_test directly so this cell does not
# depend on which labels the mutable global Y_one_hot currently holds.
print(tf.one_hot(y_test[0][0], nb_classes).numpy())
print(y_test[0][0])                                    # true label
print(first_probs.argmax())                            # predicted label

[9.9999750e-01 1.6258655e-06 3.5392588e-15 1.7857120e-13 1.8432178e-07
 6.5451525e-07 6.3041147e-16]
[1. 0. 0. 0. 0. 0. 0.]
0
0
