# 경사 하강법을 이용한 얕은 신경망 학습


In [14]:
import tensorflow as tf
import numpy as np

## 하이퍼 파라미터 설정

In [15]:
EPOCHS = 1000

## 네트워크 구조 정의
### 얕은 신경망
#### 입력 계층 : 2, 은닉 계층 : 128 (Sigmoid activation), 출력 계층 : 10 (Softmax activation)

In [16]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) # df(x)/dx
    
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

## 학습 루프 정의

In [17]:
@tf.function
def train_step(model, inputs, labels, loss_object, optimizer, train_loss, train_metric):
    with tf.GradientTape() as tape:
        predictions = model(inputs)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables) # df(x)/dx
    
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_metric(labels, predictions)

## 데이터셋 생성, 전처리

In [18]:
np.random.seed(0)

pts = list()
labels = list()
center_pts = np.random.uniform(-8.0,8.0,(10,2))
for label, center_pt in enumerate(center_pts):
    for _ in range(100):
        pts.append(center_pt + np.random.randn(*center_pt.shape))
        labels.append(label)
        
pts = np.stack(pts, axis=0).astype(np.float32)
labels = np.stack(labels, axis = 0)
train_ds = tf.data.Dataset.from_tensor_slices((pts, labels)).shuffle(1000).batch(32)

## 모델 생성

In [19]:
model = MyModel()

## 손실 함수 및 최적화 알고리즘 설정
### CrossEntropy, Adam Optimizer

In [20]:
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

## 평가 지표 설정
### Accuracy

In [21]:
train_loss = tf.keras.metrics.Mean(name='train_loss')
train_accuracy = tf.keras.metrics.SparseTopKCategoricalAccuracy(name='train_accuracy')

## 학습 루프

In [23]:
for epoch in range(EPOCHS):
    for x, label in train_ds:
        train_step(model,x, label, loss_object, optimizer, train_loss, train_accuracy)
        
    template = 'Epoch :{}, loss : {}, Accuracy: {}'
    print(template.format(epoch+1,train_loss.result(),train_accuracy.result()*100))

Epoch :1, loss : 0.31947091221809387, Accuracy: 99.8941879272461
Epoch :2, loss : 0.3193243145942688, Accuracy: 99.89442443847656
Epoch :3, loss : 0.31919005513191223, Accuracy: 99.89466857910156
Epoch :4, loss : 0.31905606389045715, Accuracy: 99.89490509033203
Epoch :5, loss : 0.3189074993133545, Accuracy: 99.8951416015625
Epoch :6, loss : 0.31876152753829956, Accuracy: 99.8953857421875
Epoch :7, loss : 0.31861212849617004, Accuracy: 99.89561462402344
Epoch :8, loss : 0.31847572326660156, Accuracy: 99.8958511352539
Epoch :9, loss : 0.31833332777023315, Accuracy: 99.89608764648438
Epoch :10, loss : 0.31818887591362, Accuracy: 99.89631652832031
Epoch :11, loss : 0.3180614113807678, Accuracy: 99.89654541015625
Epoch :12, loss : 0.3179284334182739, Accuracy: 99.89677429199219
Epoch :13, loss : 0.31780311465263367, Accuracy: 99.89700317382812
Epoch :14, loss : 0.3177090883255005, Accuracy: 99.89723205566406
Epoch :15, loss : 0.31758227944374084, Accuracy: 99.8974609375
Epoch :16, loss : 0.

Epoch :129, loss : 0.30505791306495667, Accuracy: 99.91809844970703
Epoch :130, loss : 0.3049628436565399, Accuracy: 99.91824340820312
Epoch :131, loss : 0.30487096309661865, Accuracy: 99.91838073730469
Epoch :132, loss : 0.304775208234787, Accuracy: 99.91852569580078
Epoch :133, loss : 0.30468451976776123, Accuracy: 99.91867065429688
Epoch :134, loss : 0.30458882451057434, Accuracy: 99.91881561279297
Epoch :135, loss : 0.3044956922531128, Accuracy: 99.91895294189453
Epoch :136, loss : 0.3043949007987976, Accuracy: 99.9190902709961
Epoch :137, loss : 0.30431169271469116, Accuracy: 99.91923522949219
Epoch :138, loss : 0.30422243475914, Accuracy: 99.91938018798828
Epoch :139, loss : 0.30412691831588745, Accuracy: 99.91951751708984
Epoch :140, loss : 0.30403247475624084, Accuracy: 99.9196548461914
Epoch :141, loss : 0.30395007133483887, Accuracy: 99.9197998046875
Epoch :142, loss : 0.30388835072517395, Accuracy: 99.91992950439453
Epoch :143, loss : 0.30380263924598694, Accuracy: 99.920066

Epoch :258, loss : 0.29548361897468567, Accuracy: 99.93328094482422
Epoch :259, loss : 0.29541581869125366, Accuracy: 99.93338012695312
Epoch :260, loss : 0.29534703493118286, Accuracy: 99.9334716796875
Epoch :261, loss : 0.2952823042869568, Accuracy: 99.9335708618164
Epoch :262, loss : 0.2952117919921875, Accuracy: 99.93366241455078
Epoch :263, loss : 0.295151025056839, Accuracy: 99.93376159667969
Epoch :264, loss : 0.2950810194015503, Accuracy: 99.93385314941406
Epoch :265, loss : 0.2950162887573242, Accuracy: 99.93395233154297
Epoch :266, loss : 0.29497161507606506, Accuracy: 99.93404388427734
Epoch :267, loss : 0.29490819573402405, Accuracy: 99.93413543701172
Epoch :268, loss : 0.2948400676250458, Accuracy: 99.93423461914062
Epoch :269, loss : 0.2947748303413391, Accuracy: 99.934326171875
Epoch :270, loss : 0.294709712266922, Accuracy: 99.93441772460938
Epoch :271, loss : 0.29464396834373474, Accuracy: 99.93450927734375
Epoch :272, loss : 0.2945803105831146, Accuracy: 99.9346008300

Epoch :388, loss : 0.2883552014827728, Accuracy: 99.94379425048828
Epoch :389, loss : 0.28830453753471375, Accuracy: 99.94385528564453
Epoch :390, loss : 0.28825122117996216, Accuracy: 99.94393157958984
Epoch :391, loss : 0.28820982575416565, Accuracy: 99.9439926147461
Epoch :392, loss : 0.2881624102592468, Accuracy: 99.94406127929688
Epoch :393, loss : 0.28811389207839966, Accuracy: 99.94412994384766
Epoch :394, loss : 0.28805992007255554, Accuracy: 99.94419860839844
Epoch :395, loss : 0.2880183458328247, Accuracy: 99.94426727294922
Epoch :396, loss : 0.28797847032546997, Accuracy: 99.94432830810547
Epoch :397, loss : 0.2879287898540497, Accuracy: 99.94439697265625
Epoch :398, loss : 0.2878774404525757, Accuracy: 99.94446563720703
Epoch :399, loss : 0.2878275513648987, Accuracy: 99.94453430175781
Epoch :400, loss : 0.28777775168418884, Accuracy: 99.94459533691406
Epoch :401, loss : 0.28773579001426697, Accuracy: 99.94466400146484
Epoch :402, loss : 0.28768500685691833, Accuracy: 99.94

Epoch :516, loss : 0.2829873561859131, Accuracy: 99.95133972167969
Epoch :517, loss : 0.28294843435287476, Accuracy: 99.95138549804688
Epoch :518, loss : 0.2829202711582184, Accuracy: 99.9514389038086
Epoch :519, loss : 0.2828806936740875, Accuracy: 99.95148468017578
Epoch :520, loss : 0.2828374207019806, Accuracy: 99.9515380859375
Epoch :521, loss : 0.282797247171402, Accuracy: 99.95159149169922
Epoch :522, loss : 0.28275883197784424, Accuracy: 99.95164489746094
Epoch :523, loss : 0.282720685005188, Accuracy: 99.95169067382812
Epoch :524, loss : 0.2826917767524719, Accuracy: 99.95174407958984
Epoch :525, loss : 0.2826518416404724, Accuracy: 99.95178985595703
Epoch :526, loss : 0.2826116681098938, Accuracy: 99.95184326171875
Epoch :527, loss : 0.28257814049720764, Accuracy: 99.95189666748047
Epoch :528, loss : 0.2825373113155365, Accuracy: 99.95194244384766
Epoch :529, loss : 0.28250959515571594, Accuracy: 99.95199584960938
Epoch :530, loss : 0.28247448801994324, Accuracy: 99.952041625

Epoch :641, loss : 0.27876994013786316, Accuracy: 99.95697784423828
Epoch :642, loss : 0.27873483300209045, Accuracy: 99.95701599121094
Epoch :643, loss : 0.2786996066570282, Accuracy: 99.9570541381836
Epoch :644, loss : 0.27866533398628235, Accuracy: 99.95709991455078
Epoch :645, loss : 0.27863267064094543, Accuracy: 99.95713806152344
Epoch :646, loss : 0.27860116958618164, Accuracy: 99.9571762084961
Epoch :647, loss : 0.2785658538341522, Accuracy: 99.95721435546875
Epoch :648, loss : 0.2785325050354004, Accuracy: 99.95726013183594
Epoch :649, loss : 0.27850228548049927, Accuracy: 99.95729064941406
Epoch :650, loss : 0.2784736156463623, Accuracy: 99.95733642578125
Epoch :651, loss : 0.2784389853477478, Accuracy: 99.95736694335938
Epoch :652, loss : 0.2784058451652527, Accuracy: 99.95741271972656
Epoch :653, loss : 0.2783713936805725, Accuracy: 99.95745086669922
Epoch :654, loss : 0.2783380150794983, Accuracy: 99.95748901367188
Epoch :655, loss : 0.27830252051353455, Accuracy: 99.95753

Epoch :763, loss : 0.2752600312232971, Accuracy: 99.96134948730469
Epoch :764, loss : 0.2752360701560974, Accuracy: 99.96138000488281
Epoch :765, loss : 0.2752074897289276, Accuracy: 99.96141052246094
Epoch :766, loss : 0.2751848101615906, Accuracy: 99.9614486694336
Epoch :767, loss : 0.27515873312950134, Accuracy: 99.96147918701172
Epoch :768, loss : 0.2751307785511017, Accuracy: 99.96150970458984
Epoch :769, loss : 0.27510276436805725, Accuracy: 99.96154022216797
Epoch :770, loss : 0.2750833332538605, Accuracy: 99.9615707397461
Epoch :771, loss : 0.27505600452423096, Accuracy: 99.96160888671875
Epoch :772, loss : 0.27502870559692383, Accuracy: 99.96163940429688
Epoch :773, loss : 0.27501124143600464, Accuracy: 99.961669921875
Epoch :774, loss : 0.27498170733451843, Accuracy: 99.96170043945312
Epoch :775, loss : 0.27495408058166504, Accuracy: 99.96173095703125
Epoch :776, loss : 0.2749271094799042, Accuracy: 99.96176147460938
Epoch :777, loss : 0.27489930391311646, Accuracy: 99.961791

Epoch :891, loss : 0.27214479446411133, Accuracy: 99.96507263183594
Epoch :892, loss : 0.27211982011795044, Accuracy: 99.96510314941406
Epoch :893, loss : 0.2720983624458313, Accuracy: 99.96512603759766
Epoch :894, loss : 0.2720724046230316, Accuracy: 99.96515655517578
Epoch :895, loss : 0.27205690741539, Accuracy: 99.96517944335938
Epoch :896, loss : 0.2720356285572052, Accuracy: 99.96520233154297
Epoch :897, loss : 0.2720128893852234, Accuracy: 99.9652328491211
Epoch :898, loss : 0.2719890773296356, Accuracy: 99.96525573730469
Epoch :899, loss : 0.27196699380874634, Accuracy: 99.96527862548828
Epoch :900, loss : 0.27194443345069885, Accuracy: 99.9653091430664
Epoch :901, loss : 0.2719210088253021, Accuracy: 99.96533203125
Epoch :902, loss : 0.2718976140022278, Accuracy: 99.9653549194336
Epoch :903, loss : 0.27187487483024597, Accuracy: 99.96538543701172
Epoch :904, loss : 0.2718616724014282, Accuracy: 99.96540832519531
Epoch :905, loss : 0.27184632420539856, Accuracy: 99.965438842773

## 데이터셋 및 학습 파라미터 저장

In [24]:
np.savez_compressed('2_dataset.npz', inputs=pts, labels=labels)

W_h, b_h = model.d1.get_weights()
W_o, b_o = model.d2.get_weights()
W_h = np.transpose(W_h)
W_o = np.transpose(W_o)
np.savez_compressed('2_parameters.npz',
                    W_h=W_h,
                    b_h=b_h,
                    W_o=W_o,
                    b_o=b_o)