# 경사 하강법을 이용한 얕은 신경망 학습


In [1]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [2]:
# Number of passes the training loop makes over the full training dataset.
EPOCHS = 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [3]:
class MyModel(tf.keras.Model):
    """Shallow classifier: one sigmoid hidden layer followed by a softmax output.

    Layers:
        d1: Dense layer with 128 units and sigmoid activation (hidden layer).
        d2: Dense layer with 10 units and softmax activation (output layer).

    The input width is not declared here. Each Dense layer creates its
    weights on the first call, from the shape of the tensor it receives
    (2 features per sample in this notebook).
    """

    def __init__(self):
        super().__init__()
        # No `input_dim` argument: it has no effect inside a subclassed model
        # (the layer is built from the first input it sees) and newer Keras
        # versions emit a warning when it is passed.
        self.d1 = tf.keras.layers.Dense(128, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x, training=None, mask=None):
        """Run the forward pass.

        Args:
            x: Batch of input features.
            training: Accepted for the Keras `call` signature; unused here.
            mask: Accepted for the Keras `call` signature; unused here.

        Returns:
            The softmax output of the final layer (one score per class).
        """
        hidden = self.d1(x)
        return self.d2(hidden)

## 학습 루프 정의

In [4]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    """Run one optimization step on a single mini-batch.

    Args:
        model: Keras model; called as ``model(inputs, training=True)`` and
            must expose ``trainable_variables``.
        inputs: Batch of input features fed to the model.
        labels: Ground-truth labels for the batch.
        loss_object: Callable invoked as ``loss_object(labels, predictions)``
            that returns the loss to minimize.
        optimizer: Optimizer exposing ``apply_gradients``.
        train_loss: Metric updated with this batch's loss value.
        train_metric: Metric updated with ``(labels, predictions)``.

    Returns:
        None. The model variables and both metrics are updated in place.
    """
    with tf.GradientTape() as tape:
        # Forward pass. `training=True` makes layers with train-time
        # behavior (e.g. Dropout, BatchNormalization) run in training mode;
        # it makes no difference for models built only from Dense layers.
        predictions = model(inputs, training=True)

        # Compare the predictions with the ground-truth labels.
        loss = loss_object(labels, predictions)

    # Differentiate the loss with respect to every trainable variable.
    gradients = tape.gradient(loss, model.trainable_variables)

    # Hand each (gradient, variable) pair to the optimizer to update the weights.
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))

    # Accumulate the batch loss into the running loss metric.
    train_loss(loss)

    # Accumulate the comparison of labels and predictions into the metric.
    train_metric(labels, predictions)

## 데이터셋 생성, 전처리

In [5]:
# Fix the NumPy seed so that every run generates exactly the same data.
np.random.seed(0)

pts = []     # 2-D input points, one array per sample
labels = []  # integer class label of each point (index of its cluster center)

# Draw 10 cluster centers; each coordinate is uniform in [-8.0, 8.0).
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))

# Around every center, sample 100 points by adding standard-normal noise.
# The label of a point is the index of the center it was drawn around.
for label, center_pt in enumerate(center_pts):
    pts.extend(center_pt + np.random.randn(*center_pt.shape) for _ in range(100))
    labels.extend([label] * 100)

print(pts[:10])

# Combine the per-sample lists into single arrays (no transposition happens):
# pts becomes a float32 array with one row per point, labels a 1-D int array.
pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.asarray(labels)

# Build the input pipeline: shuffle with a 1000-element buffer, then batch by 32.
train_ds = tf.data.Dataset.from_tensor_slices((pts, labels))
train_ds = train_ds.shuffle(1000).batch(32)

[array([2.27509514, 3.2378716 ]), array([1.09408376, 2.58893412]), array([-1.77197375,  4.09664846]), array([1.64545226, 2.70086484]), array([3.05077069, 1.98866419]), array([0.82677458, 3.25584601]), array([2.31379528, 4.91238863]), array([0.93596349, 3.82119238]), array([-0.10676968,  1.46223339]), array([0.43310391, 3.59937883])]


## 모델 생성

In [6]:
# Create an instance of the MyModel network defined above.
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [7]:
# Loss function: cross-entropy in its sparse form (integer class labels).
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()

# Optimizer: Adam, constructed with its default hyperparameters.
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [8]:
# Running mean of the loss values passed to it (a plain average, not MSE).
train_loss = tf.keras.metrics.Mean(name='train_loss')

# Accuracy metric in its sparse form (integer labels vs. per-class predictions).
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [9]:
# The report format never changes, so define it once outside the loop.
template = 'Epoch {}, Loss: {}, Accuracy: {}'

for epoch in range(EPOCHS):
    # One pass over every mini-batch; both metrics accumulate across batches.
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)

    # Report the values accumulated over this epoch (accuracy as a percentage).
    print(template.format(epoch + 1,
                          train_loss.result(),
                          train_accuracy.result() * 100))

    # Clear the accumulated state so the next epoch is measured on its own.
    # `reset_states()` was renamed to `reset_state()` in later TensorFlow
    # releases and is absent from Keras 3 metrics; prefer the new name and
    # fall back to the old one so the loop runs on both old and new versions.
    for metric in (train_loss, train_accuracy):
        reset = getattr(metric, 'reset_state', None) or metric.reset_states
        reset()

Epoch 1, Loss: 2.1787407398223877, Accuracy: 32.900001525878906
Epoch 2, Loss: 1.823689579963684, Accuracy: 48.10000228881836
Epoch 3, Loss: 1.5829306840896606, Accuracy: 58.39999771118164
Epoch 4, Loss: 1.4172084331512451, Accuracy: 63.20000457763672
Epoch 5, Loss: 1.2895838022232056, Accuracy: 68.30000305175781
Epoch 6, Loss: 1.1809724569320679, Accuracy: 75.9000015258789
Epoch 7, Loss: 1.098690390586853, Accuracy: 75.0
Epoch 8, Loss: 1.028332233428955, Accuracy: 81.30000305175781
Epoch 9, Loss: 0.9594354033470154, Accuracy: 82.0
Epoch 10, Loss: 0.915228545665741, Accuracy: 81.19999694824219
Epoch 11, Loss: 0.8597851395606995, Accuracy: 82.20000457763672
Epoch 12, Loss: 0.8231942653656006, Accuracy: 84.5999984741211
Epoch 13, Loss: 0.7801067233085632, Accuracy: 84.79999542236328
Epoch 14, Loss: 0.7471539974212646, Accuracy: 85.9000015258789
Epoch 15, Loss: 0.7140805721282959, Accuracy: 83.5
Epoch 16, Loss: 0.6978130340576172, Accuracy: 84.79999542236328
Epoch 17, Loss: 0.659702837467

Epoch 138, Loss: 0.28252464532852173, Accuracy: 89.20000457763672
Epoch 139, Loss: 0.27990463376045227, Accuracy: 88.80000305175781
Epoch 140, Loss: 0.2735862731933594, Accuracy: 89.30000305175781
Epoch 141, Loss: 0.27217212319374084, Accuracy: 89.30000305175781
Epoch 142, Loss: 0.27213573455810547, Accuracy: 89.4000015258789
Epoch 143, Loss: 0.2756405174732208, Accuracy: 89.0999984741211
Epoch 144, Loss: 0.2719949185848236, Accuracy: 88.9000015258789
Epoch 145, Loss: 0.28318214416503906, Accuracy: 88.70000457763672
Epoch 146, Loss: 0.27366408705711365, Accuracy: 89.0
Epoch 147, Loss: 0.2707020342350006, Accuracy: 89.20000457763672
Epoch 148, Loss: 0.2771245241165161, Accuracy: 89.0999984741211
Epoch 149, Loss: 0.27408090233802795, Accuracy: 89.30000305175781
Epoch 150, Loss: 0.28992751240730286, Accuracy: 88.9000015258789
Epoch 151, Loss: 0.27875301241874695, Accuracy: 88.9000015258789
Epoch 152, Loss: 0.2734193205833435, Accuracy: 89.20000457763672
Epoch 153, Loss: 0.2841612100601196

Epoch 271, Loss: 0.25950703024864197, Accuracy: 89.5
Epoch 272, Loss: 0.25507527589797974, Accuracy: 89.20000457763672
Epoch 273, Loss: 0.2571532428264618, Accuracy: 89.5
Epoch 274, Loss: 0.25939419865608215, Accuracy: 88.80000305175781
Epoch 275, Loss: 0.26268479228019714, Accuracy: 88.70000457763672
Epoch 276, Loss: 0.25898289680480957, Accuracy: 89.70000457763672
Epoch 277, Loss: 0.2593052387237549, Accuracy: 89.20000457763672
Epoch 278, Loss: 0.26124605536460876, Accuracy: 89.60000610351562
Epoch 279, Loss: 0.2653380334377289, Accuracy: 89.0999984741211
Epoch 280, Loss: 0.2592270076274872, Accuracy: 89.20000457763672
Epoch 281, Loss: 0.2673254609107971, Accuracy: 89.60000610351562
Epoch 282, Loss: 0.26276469230651855, Accuracy: 88.70000457763672
Epoch 283, Loss: 0.25539666414260864, Accuracy: 89.0
Epoch 284, Loss: 0.2674054503440857, Accuracy: 89.0
Epoch 285, Loss: 0.26311397552490234, Accuracy: 89.4000015258789
Epoch 286, Loss: 0.2595055401325226, Accuracy: 89.30000305175781
Epoch

Epoch 402, Loss: 0.2592400908470154, Accuracy: 89.20000457763672
Epoch 403, Loss: 0.26051682233810425, Accuracy: 89.20000457763672
Epoch 404, Loss: 0.2555268406867981, Accuracy: 89.30000305175781
Epoch 405, Loss: 0.25786149501800537, Accuracy: 89.30000305175781
Epoch 406, Loss: 0.25350484251976013, Accuracy: 89.20000457763672
Epoch 407, Loss: 0.2609727382659912, Accuracy: 89.5
Epoch 408, Loss: 0.2604769468307495, Accuracy: 89.9000015258789
Epoch 409, Loss: 0.2520601749420166, Accuracy: 89.60000610351562
Epoch 410, Loss: 0.25544872879981995, Accuracy: 89.0
Epoch 411, Loss: 0.26057666540145874, Accuracy: 89.0999984741211
Epoch 412, Loss: 0.26892009377479553, Accuracy: 89.80000305175781
Epoch 413, Loss: 0.25273796916007996, Accuracy: 89.30000305175781
Epoch 414, Loss: 0.2523018717765808, Accuracy: 89.70000457763672
Epoch 415, Loss: 0.2563681900501251, Accuracy: 89.20000457763672
Epoch 416, Loss: 0.2537069022655487, Accuracy: 89.4000015258789
Epoch 417, Loss: 0.2552299499511719, Accuracy: 

Epoch 535, Loss: 0.2523549795150757, Accuracy: 89.0999984741211
Epoch 536, Loss: 0.25280627608299255, Accuracy: 89.30000305175781
Epoch 537, Loss: 0.2500551640987396, Accuracy: 89.9000015258789
Epoch 538, Loss: 0.26331013441085815, Accuracy: 89.4000015258789
Epoch 539, Loss: 0.2550557851791382, Accuracy: 89.20000457763672
Epoch 540, Loss: 0.25092145800590515, Accuracy: 89.5
Epoch 541, Loss: 0.2527674436569214, Accuracy: 89.0999984741211
Epoch 542, Loss: 0.25410035252571106, Accuracy: 89.60000610351562
Epoch 543, Loss: 0.2531154453754425, Accuracy: 89.30000305175781
Epoch 544, Loss: 0.2617743909358978, Accuracy: 89.60000610351562
Epoch 545, Loss: 0.2517605721950531, Accuracy: 89.0999984741211
Epoch 546, Loss: 0.25559911131858826, Accuracy: 89.70000457763672
Epoch 547, Loss: 0.251838356256485, Accuracy: 89.70000457763672
Epoch 548, Loss: 0.2599535286426544, Accuracy: 89.20000457763672
Epoch 549, Loss: 0.25351279973983765, Accuracy: 89.9000015258789
Epoch 550, Loss: 0.25527653098106384, A

Epoch 669, Loss: 0.2543683648109436, Accuracy: 89.30000305175781
Epoch 670, Loss: 0.25215446949005127, Accuracy: 89.60000610351562
Epoch 671, Loss: 0.2559077739715576, Accuracy: 89.9000015258789
Epoch 672, Loss: 0.25079146027565, Accuracy: 89.70000457763672
Epoch 673, Loss: 0.2567337453365326, Accuracy: 89.70000457763672
Epoch 674, Loss: 0.26125285029411316, Accuracy: 89.30000305175781
Epoch 675, Loss: 0.25070858001708984, Accuracy: 89.4000015258789
Epoch 676, Loss: 0.252262145280838, Accuracy: 89.80000305175781
Epoch 677, Loss: 0.2472262680530548, Accuracy: 89.30000305175781
Epoch 678, Loss: 0.2578320801258087, Accuracy: 89.4000015258789
Epoch 679, Loss: 0.24590277671813965, Accuracy: 89.9000015258789
Epoch 680, Loss: 0.2539852261543274, Accuracy: 89.20000457763672
Epoch 681, Loss: 0.25421154499053955, Accuracy: 89.30000305175781
Epoch 682, Loss: 0.24977965652942657, Accuracy: 89.9000015258789
Epoch 683, Loss: 0.26471996307373047, Accuracy: 89.4000015258789
Epoch 684, Loss: 0.25058975

Epoch 803, Loss: 0.250789076089859, Accuracy: 90.10000610351562
Epoch 804, Loss: 0.24503512680530548, Accuracy: 89.4000015258789
Epoch 805, Loss: 0.2494298815727234, Accuracy: 89.80000305175781
Epoch 806, Loss: 0.2551403343677521, Accuracy: 89.60000610351562
Epoch 807, Loss: 0.2556641697883606, Accuracy: 89.70000457763672
Epoch 808, Loss: 0.24608683586120605, Accuracy: 89.80000305175781
Epoch 809, Loss: 0.2454300969839096, Accuracy: 89.4000015258789
Epoch 810, Loss: 0.24814505875110626, Accuracy: 89.80000305175781
Epoch 811, Loss: 0.2501063644886017, Accuracy: 89.80000305175781
Epoch 812, Loss: 0.2609519064426422, Accuracy: 89.60000610351562
Epoch 813, Loss: 0.25823622941970825, Accuracy: 90.10000610351562
Epoch 814, Loss: 0.2457825243473053, Accuracy: 89.70000457763672
Epoch 815, Loss: 0.25231462717056274, Accuracy: 89.9000015258789
Epoch 816, Loss: 0.24345722794532776, Accuracy: 90.4000015258789
Epoch 817, Loss: 0.2465721070766449, Accuracy: 89.70000457763672
Epoch 818, Loss: 0.24817

Epoch 937, Loss: 0.25002411007881165, Accuracy: 89.60000610351562
Epoch 938, Loss: 0.25349700450897217, Accuracy: 89.9000015258789
Epoch 939, Loss: 0.24249626696109772, Accuracy: 90.0
Epoch 940, Loss: 0.24973011016845703, Accuracy: 89.70000457763672
Epoch 941, Loss: 0.24730752408504486, Accuracy: 89.9000015258789
Epoch 942, Loss: 0.25130409002304077, Accuracy: 90.0
Epoch 943, Loss: 0.24287079274654388, Accuracy: 90.10000610351562
Epoch 944, Loss: 0.2439485341310501, Accuracy: 89.5
Epoch 945, Loss: 0.24468107521533966, Accuracy: 89.9000015258789
Epoch 946, Loss: 0.24324806034564972, Accuracy: 89.80000305175781
Epoch 947, Loss: 0.25140801072120667, Accuracy: 90.0
Epoch 948, Loss: 0.24485880136489868, Accuracy: 89.60000610351562
Epoch 949, Loss: 0.25103890895843506, Accuracy: 89.70000457763672
Epoch 950, Loss: 0.24293804168701172, Accuracy: 89.80000305175781
Epoch 951, Loss: 0.24813562631607056, Accuracy: 89.60000610351562
Epoch 952, Loss: 0.24323201179504395, Accuracy: 90.10000610351562


## 데이터셋 및 학습 파라미터 저장

In [10]:
# Persist the generated dataset: input points and their integer labels.
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

# Each Dense layer's get_weights() is unpacked here as (weight matrix, bias).
W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()

# Store the weight matrices transposed relative to how Keras returns them.
W_h, W_o = W_h.T, W_o.T

# Persist the trained parameters of the hidden (h) and output (o) layers.
np.savez_compressed('ch2_parameters.npz', W_h=W_h, b_h=b_h, W_o=W_o, b_o=b_o)