In [1]:
import tensorflow as tf
from tensorflow.keras import layers, models
import keras
import numpy as np
import matplotlib.pyplot as plt

## 1. get data and train model

In [2]:
# download dataset
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

# only keep label 0 and 1
train_filter = (y_train == 0) | (y_train == 1)
x_train, y_train = x_train[train_filter], y_train[train_filter]
test_filter = (y_test == 0) | (y_test == 1)
x_test, y_test = x_test[test_filter], y_test[test_filter]

# data preprocessingg
x_train, x_test = x_train / 255.0, x_test / 255.0
x_train = x_train.reshape((-1, 28 * 28))
x_test = x_test.reshape((-1, 28 * 28))

In [10]:

def square_activation(x):
    return tf.square(x) - x

model = models.Sequential([
    layers.Dense(64, activation=square_activation, input_shape=(28 * 28,)),
    layers.Dense(1, activation=square_activation)
])

"""
model = models.Sequential([
    layers.Dense(64, activation='sigmoid', input_shape=(28 * 28,)),
    layers.Dense(1, activation='sigmoid')
])
"""


model.compile(optimizer='adam',
              loss='binary_crossentropy',
              metrics=['accuracy'])

history = model.fit(x_train, y_train, epochs=5, batch_size=1024, validation_data=(x_test, y_test))

Epoch 1/5
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m1s[0m 20ms/step - accuracy: 0.7077 - loss: 3.2533 - val_accuracy: 0.9357 - val_loss: 0.5406
Epoch 2/5
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 7ms/step - accuracy: 0.9370 - loss: 0.5649 - val_accuracy: 0.9702 - val_loss: 0.1727
Epoch 3/5
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.9701 - loss: 0.2191 - val_accuracy: 0.9915 - val_loss: 0.0541
Epoch 4/5
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 6ms/step - accuracy: 0.9867 - loss: 0.0933 - val_accuracy: 0.9915 - val_loss: 0.0552
Epoch 5/5
[1m13/13[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 5ms/step - accuracy: 0.9923 - loss: 0.0559 - val_accuracy: 0.9953 - val_loss: 0.0537


In [6]:
def show_history():
    plt.plot(history.history['accuracy'], label='train_acc')
    plt.plot(history.history['val_accuracy'], label='test_acc')
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy')
    plt.legend()
    plt.show()

In [11]:
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f"test_acc {test_acc:.4f}")

[1m67/67[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 1ms/step - accuracy: 0.9978 - loss: 0.0308
test_acc 0.9953


## 2. save and load model

In [3]:
model.save('NN_model_sig.keras')

NameError: name 'model' is not defined

In [13]:
model.save('NN_model_square_activation.keras')

In [4]:
model = keras.models.load_model('NN_model_sig.keras')


In [None]:
model = keras.models.load_model('NN_model_square_activation.keras')


In [5]:
model.count_params()

50305

In [13]:
def compute_output_by_weight(data_point, weights) -> float:
    a1 = []
    # layer 1
    for j in range(64):  
        z1 = 0  
        for i in range(28 * 28): 
            z1 += data_point[i] * weights[0][i][j]  
        
        z1 += weights[1][j] 
        print(j, z1, 1 / (1 + np.exp(-z1)))
        a1.append(1 / (1 + np.exp(-z1))  )
    
    # layer 2
    output = 0  
    for j in range(1):  
        z2 = 0  
        for i in range(64):  
            z2 += a1[i] * weights[2][i][j]  
        z2 += weights[3][j]  
        output = 1 / (1 + np.exp(-z2))
    print(f"z2:{z2}")
    return output
    

In [7]:
weights = model.get_weights()
for i, weight in enumerate(weights):
    print(f"Layer {i} weights shape: {weight.shape}")


Layer 0 weights shape: (784, 64)
Layer 1 weights shape: (64,)
Layer 2 weights shape: (64, 1)
Layer 3 weights shape: (1,)


In [12]:
len(weights[3])

1

In [22]:
# how the value is computed
data_point = x_test[0].reshape(1, 28 * 28)  
output = model.predict(data_point)
print(output)

[1m1/1[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 64ms/step
[[0.9998373]]


In [14]:
data_point = x_test[0]
print(compute_output_by_weight(data_point, weights))

0 -1.3994797354153192 0.1978986825533113
1 -1.6132311404390518 0.16614049537055994
2 1.945045262408973 0.8749053723381297
3 2.201067094975973 0.9003452954411929
4 2.6923439086442262 0.936573361073289
5 1.4170386094627356 0.8048737439702651
6 -1.9720044762991813 0.12217374924239087
7 -1.4531158196465936 0.18952250146211905
8 1.626653036592967 0.8357106223139326
9 -1.6514388542562977 0.16091457912600907
10 2.143213521986834 0.8950329011589654
11 -1.5494892125278166 0.1751600539876168
12 1.7648297075364803 0.8538135156597884
13 -1.5610558859282633 0.17349518646932702
14 3.776337694621408 0.9776065261712433
15 -1.6933873126584593 0.15533089478761383
16 -1.5400169768165255 0.17653280686314787
17 1.7220186491551643 0.8483886684136601
18 -1.580775355230135 0.17068570084167906
19 3.2276991681177543 0.9618634430766423
20 1.976139312891253 0.8782690085891208
21 -1.8841984226057926 0.13190737791534674
22 -1.4080093485631922 0.19654822422600715
23 -1.487930071363956 0.1842326156667387
24 -1.447591

In [10]:
len(weights[0])

784

In [11]:
weights[0][0][0]

0.027937025

In [32]:
for i in range(len(weights[1])):
        print(weights[1][i])

0.19550614
0.18299238
0.15183985
0.1490501
0.15348421
-0.1967776
-0.19212602
-0.17550483
-0.19675587
0.18355839
0.18302464
-0.14857928
0.15127198
-0.1565152
0.19835603
0.19179423
0.1603507
0.14868261
0.1524891
0.12893803
0.15651946
0.1572804
0.16083649
-0.220862
0.14284225
0.15280357
0.20909384
0.14985813
-0.20925581
-0.1539407
-0.15269047
0.15413153
-0.16002108
-0.16585073
0.16407321
0.15151474
0.19430858
-0.1569914
-0.14327994
-0.16010475
-0.14239517
-0.15462334
0.1600015
-0.16614117
-0.1975271
-0.16345172
-0.16263728
0.17282203
0.17572317
-0.1545989
-0.15428074
-0.16630653
-0.15993567
0.21688364
0.22011319
-0.1521142
0.1507878
0.16647126
-0.15587918
-0.17712839
0.161808
-0.20197114
0.15262972
0.17443809


In [31]:
for i in range(len(weights[2])):
    for j in range(len(weights[2][i])):
        print(weights[2][i][j])

0.29599276
0.32631713
0.54180104
0.48299828
0.4418062
-0.25009224
-0.26255456
-0.28972653
-0.26959103
0.348105
0.2948175
-0.33340284
0.44595614
-0.38044462
0.3199096
0.31665012
0.4753405
0.5242221
0.46320173
0.4796282
0.41070753
0.44911137
0.38729963
-0.28798977
0.45597464
0.477009
0.308231
0.49806178
-0.28063798
-0.4687638
-0.44490987
0.4917959
-0.3932617
-0.28350845
0.39047703
0.5082379
0.3280227
-0.36585948
-0.42523184
-0.4074847
-0.43738484
-0.44604006
0.4617679
-0.36324763
-0.29375383
-0.45835617
-0.3515071
0.4396817
0.39581138
-0.36762667
-0.40645856
-0.33712304
-0.38579664
0.34719884
0.37487468
-0.4582405
0.35147443
0.43008977
-0.37173235
-0.31126967
0.42215422
-0.2885715
0.42578548
0.35704643
