# 경사 하강법을 이용한 얕은 신경망 학습


In [1]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [2]:
EPOCHS = 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [3]:
class MyModel(tf.keras.Model):
    def __init__(self):  # 얘 자체가 그냥 입력 계층
        super(MyModel, self).__init__()
        self.d1 = tf.keras.layers.Dense(128, input_dim=2, activation='sigmoid')
        self.d2 = tf.keras.layers.Dense(10, activation='softmax')
    
    def call(self, x, training=None, mask=None):
        x = self.d1(x)  # 은닉 계층 정의
        return self.d2(x)  # 출력 계층 return

## 학습 루프 정의

In [4]:
@tf.function  # 오토그래프로 텐서플로우 최적화

def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    # gradient 계산
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)  
    # loss를 trainable_variables로 미분해서 gradient를 구한다.  # df(x) / dx

    optimizer.apply_gradients(zip(gradients, model.trainable_variables))  # zip을 이용해서 gradient와 model에 의해서 학습 될 것을 zip으로 넣는다.
    train_loss(loss)  # loss값들을 종합해주는 역활
    train_metric(labels, predictions)  # 정답과 예측값을 비교해서 평가지표로 계산
    

## 데이터셋 생성, 전처리

In [5]:
np.random.seed(0)

pts = list()  # 입력값
labels = list()  # 출력값
center_pts = np.random.uniform(-8.0, 8.0, (10, 2))  # (10, 2) - 10개의 point를 x,y 2개의 dimension을 가지게 만든다.
for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)

pts = np.stack(pts, axis = 0).astype(np.float32)
labels = np.stack(labels, axis=0)  # label은 int라 float으로 바꾸지 않아도 된다.

train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)
# from_tensor_slices : input(pts)과 labels을 넣어줬을 때 train dataset으로 합쳐준다.

## 모델 생성

In [6]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [7]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [8]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [10]:
for epoch in range(EPOCHS):
    for x, label in train_ds:
        train_step(model, x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch : {}, Loss : {}, Accuracy : {}'
    print(template.format(epoch + 1,
                         train_loss.result(),
                         train_accuracy.result() * 100))

Epoch : 1, Loss : 0.3505007326602936, Accuracy : 87.93123626708984
Epoch : 2, Loss : 0.35019710659980774, Accuracy : 87.93647003173828
Epoch : 3, Loss : 0.34987834095954895, Accuracy : 87.94100952148438
Epoch : 4, Loss : 0.3495696783065796, Accuracy : 87.94584655761719
Epoch : 5, Loss : 0.34929850697517395, Accuracy : 87.95197296142578
Epoch : 6, Loss : 0.3489929437637329, Accuracy : 87.95741271972656
Epoch : 7, Loss : 0.34874582290649414, Accuracy : 87.96148681640625
Epoch : 8, Loss : 0.3484452962875366, Accuracy : 87.96881866455078
Epoch : 9, Loss : 0.3481723964214325, Accuracy : 87.9741439819336
Epoch : 10, Loss : 0.3478877544403076, Accuracy : 87.97813415527344
Epoch : 11, Loss : 0.34760674834251404, Accuracy : 87.98208618164062
Epoch : 12, Loss : 0.34733402729034424, Accuracy : 87.98473358154297
Epoch : 13, Loss : 0.34706711769104004, Accuracy : 87.98993682861328
Epoch : 14, Loss : 0.34680241346359253, Accuracy : 87.99414825439453
Epoch : 15, Loss : 0.34651175141334534, Accuracy :

Epoch : 130, Loss : 0.3227499723434448, Accuracy : 88.35039520263672
Epoch : 131, Loss : 0.3226005434989929, Accuracy : 88.3547134399414
Epoch : 132, Loss : 0.3224618434906006, Accuracy : 88.3569107055664
Epoch : 133, Loss : 0.32230907678604126, Accuracy : 88.3591079711914
Epoch : 134, Loss : 0.32215461134910583, Accuracy : 88.36267852783203
Epoch : 135, Loss : 0.32202064990997314, Accuracy : 88.36415100097656
Epoch : 136, Loss : 0.32186442613601685, Accuracy : 88.36492156982422
Epoch : 137, Loss : 0.3217172920703888, Accuracy : 88.36707305908203
Epoch : 138, Loss : 0.32156968116760254, Accuracy : 88.36990356445312
Epoch : 139, Loss : 0.32141757011413574, Accuracy : 88.37203979492188
Epoch : 140, Loss : 0.3212722837924957, Accuracy : 88.37416076660156
Epoch : 141, Loss : 0.32111749053001404, Accuracy : 88.37718200683594
Epoch : 142, Loss : 0.3209725320339203, Accuracy : 88.37905883789062
Epoch : 143, Loss : 0.3208472728729248, Accuracy : 88.38182830810547
Epoch : 144, Loss : 0.32069486

Epoch : 256, Loss : 0.3074910342693329, Accuracy : 88.58863067626953
Epoch : 257, Loss : 0.3073902726173401, Accuracy : 88.59135437011719
Epoch : 258, Loss : 0.30728742480278015, Accuracy : 88.5926284790039
Epoch : 259, Loss : 0.3072004020214081, Accuracy : 88.59408569335938
Epoch : 260, Loss : 0.3071015477180481, Accuracy : 88.5949935913086
Epoch : 261, Loss : 0.30701321363449097, Accuracy : 88.5971450805664
Epoch : 262, Loss : 0.30691590905189514, Accuracy : 88.59840393066406
Epoch : 263, Loss : 0.3068130314350128, Accuracy : 88.60054779052734
Epoch : 264, Loss : 0.3067169487476349, Accuracy : 88.6021499633789
Epoch : 265, Loss : 0.3066277503967285, Accuracy : 88.60427856445312
Epoch : 266, Loss : 0.3065272569656372, Accuracy : 88.6062240600586
Epoch : 267, Loss : 0.3064327538013458, Accuracy : 88.60780334472656
Epoch : 268, Loss : 0.3063296675682068, Accuracy : 88.60991668701172
Epoch : 269, Loss : 0.3062468469142914, Accuracy : 88.61042785644531
Epoch : 270, Loss : 0.30615267157554

Epoch : 376, Loss : 0.29789215326309204, Accuracy : 88.76008605957031
Epoch : 377, Loss : 0.2978249192237854, Accuracy : 88.76117706298828
Epoch : 378, Loss : 0.297747939825058, Accuracy : 88.76287078857422
Epoch : 379, Loss : 0.2976706624031067, Accuracy : 88.76380920410156
Epoch : 380, Loss : 0.29760318994522095, Accuracy : 88.76519012451172
Epoch : 381, Loss : 0.2975287139415741, Accuracy : 88.76641845703125
Epoch : 382, Loss : 0.2974635362625122, Accuracy : 88.76823425292969
Epoch : 383, Loss : 0.29740625619888306, Accuracy : 88.7685775756836
Epoch : 384, Loss : 0.2973388433456421, Accuracy : 88.76994323730469
Epoch : 385, Loss : 0.2972821593284607, Accuracy : 88.77188873291016
Epoch : 386, Loss : 0.2972109913825989, Accuracy : 88.77339172363281
Epoch : 387, Loss : 0.29714661836624146, Accuracy : 88.77460479736328
Epoch : 388, Loss : 0.2970770299434662, Accuracy : 88.7763900756836
Epoch : 389, Loss : 0.2970142960548401, Accuracy : 88.7781753540039
Epoch : 390, Loss : 0.296955525875

Epoch : 495, Loss : 0.290758341550827, Accuracy : 88.89996337890625
Epoch : 496, Loss : 0.29071077704429626, Accuracy : 88.90122985839844
Epoch : 497, Loss : 0.29066529870033264, Accuracy : 88.9024887084961
Epoch : 498, Loss : 0.2906179428100586, Accuracy : 88.90335845947266
Epoch : 499, Loss : 0.290569543838501, Accuracy : 88.90360260009766
Epoch : 500, Loss : 0.29053181409835815, Accuracy : 88.90473175048828
Epoch : 501, Loss : 0.29049742221832275, Accuracy : 88.90573120117188
Epoch : 502, Loss : 0.29044997692108154, Accuracy : 88.90709686279297
Epoch : 503, Loss : 0.2903901934623718, Accuracy : 88.9085922241211
Epoch : 504, Loss : 0.2903335690498352, Accuracy : 88.9095687866211
Epoch : 505, Loss : 0.29028159379959106, Accuracy : 88.91055297851562
Epoch : 506, Loss : 0.2902331054210663, Accuracy : 88.91141510009766
Epoch : 507, Loss : 0.2901948094367981, Accuracy : 88.91226959228516
Epoch : 508, Loss : 0.2901574373245239, Accuracy : 88.91387176513672
Epoch : 509, Loss : 0.29011616110

Epoch : 618, Loss : 0.2852019667625427, Accuracy : 89.02613830566406
Epoch : 619, Loss : 0.2851746082305908, Accuracy : 89.0271987915039
Epoch : 620, Loss : 0.28512638807296753, Accuracy : 89.02837371826172
Epoch : 621, Loss : 0.2850927412509918, Accuracy : 89.02921295166016
Epoch : 622, Loss : 0.2850489616394043, Accuracy : 89.03015899658203
Epoch : 623, Loss : 0.2850036025047302, Accuracy : 89.03153228759766
Epoch : 624, Loss : 0.2849666476249695, Accuracy : 89.03269958496094
Epoch : 625, Loss : 0.2849271595478058, Accuracy : 89.03363800048828
Epoch : 626, Loss : 0.2848840653896332, Accuracy : 89.0350112915039
Epoch : 627, Loss : 0.2848389148712158, Accuracy : 89.0356216430664
Epoch : 628, Loss : 0.28480029106140137, Accuracy : 89.03644561767578
Epoch : 629, Loss : 0.28475672006607056, Accuracy : 89.0376968383789
Epoch : 630, Loss : 0.284720242023468, Accuracy : 89.0383071899414
Epoch : 631, Loss : 0.2846759259700775, Accuracy : 89.03923797607422
Epoch : 632, Loss : 0.284630745649337

Epoch : 737, Loss : 0.28072622418403625, Accuracy : 89.13739013671875
Epoch : 738, Loss : 0.28068938851356506, Accuracy : 89.13842010498047
Epoch : 739, Loss : 0.2806624472141266, Accuracy : 89.13964080810547
Epoch : 740, Loss : 0.2806330919265747, Accuracy : 89.14027404785156
Epoch : 741, Loss : 0.2805934250354767, Accuracy : 89.14100646972656
Epoch : 742, Loss : 0.28055551648139954, Accuracy : 89.14173889160156
Epoch : 743, Loss : 0.2805284559726715, Accuracy : 89.14295196533203
Epoch : 744, Loss : 0.280498743057251, Accuracy : 89.14386749267578
Epoch : 745, Loss : 0.2804630398750305, Accuracy : 89.14469146728516
Epoch : 746, Loss : 0.28043651580810547, Accuracy : 89.14531707763672
Epoch : 747, Loss : 0.28040415048599243, Accuracy : 89.14632415771484
Epoch : 748, Loss : 0.2803663909435272, Accuracy : 89.14733123779297
Epoch : 749, Loss : 0.28033238649368286, Accuracy : 89.1480484008789
Epoch : 750, Loss : 0.2803107798099518, Accuracy : 89.14905548095703
Epoch : 751, Loss : 0.28027629

Epoch : 859, Loss : 0.27696242928504944, Accuracy : 89.24690246582031
Epoch : 860, Loss : 0.27693265676498413, Accuracy : 89.24763488769531
Epoch : 861, Loss : 0.2769034504890442, Accuracy : 89.24819946289062
Epoch : 862, Loss : 0.27687957882881165, Accuracy : 89.24876403808594
Epoch : 863, Loss : 0.2768515944480896, Accuracy : 89.24949645996094
Epoch : 864, Loss : 0.27682045102119446, Accuracy : 89.25031280517578
Epoch : 865, Loss : 0.2767910361289978, Accuracy : 89.251220703125
Epoch : 866, Loss : 0.27676016092300415, Accuracy : 89.2522964477539
Epoch : 867, Loss : 0.2767384648323059, Accuracy : 89.25353240966797
Epoch : 868, Loss : 0.27670854330062866, Accuracy : 89.25434875488281
Epoch : 869, Loss : 0.27667906880378723, Accuracy : 89.25567626953125
Epoch : 870, Loss : 0.27665066719055176, Accuracy : 89.25621795654297
Epoch : 871, Loss : 0.2766267657279968, Accuracy : 89.25677490234375
Epoch : 872, Loss : 0.2765974700450897, Accuracy : 89.25741577148438
Epoch : 873, Loss : 0.2765707

Epoch : 984, Loss : 0.2736743688583374, Accuracy : 89.34736633300781
Epoch : 985, Loss : 0.2736514210700989, Accuracy : 89.34803009033203
Epoch : 986, Loss : 0.27362605929374695, Accuracy : 89.34892272949219
Epoch : 987, Loss : 0.2735999822616577, Accuracy : 89.34989929199219
Epoch : 988, Loss : 0.2735765874385834, Accuracy : 89.3504867553711
Epoch : 989, Loss : 0.2735556662082672, Accuracy : 89.35106658935547
Epoch : 990, Loss : 0.27353179454803467, Accuracy : 89.35172271728516
Epoch : 991, Loss : 0.2735138535499573, Accuracy : 89.35285186767578
Epoch : 992, Loss : 0.27348676323890686, Accuracy : 89.3536605834961
Epoch : 993, Loss : 0.2734597325325012, Accuracy : 89.35447692871094
Epoch : 994, Loss : 0.2734329104423523, Accuracy : 89.35482025146484
Epoch : 995, Loss : 0.2734098732471466, Accuracy : 89.35523986816406
Epoch : 996, Loss : 0.2733912169933319, Accuracy : 89.35604858398438
Epoch : 997, Loss : 0.27336809039115906, Accuracy : 89.35623931884766
Epoch : 998, Loss : 0.2733471691

## 데이터셋 및 학습 파라미터 저장

In [11]:
np.savez_compressed('ch2_dataset.npz', inputs=pts, labels=labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('ch2_parameters.npz',
                   W_h=W_h,
                   b_h=b_h,
                   W_o=W_o,
                   b_o=b_o)