In [70]:
import tensorflow as tf
from tensorflow.keras import layers, Model, optimizers, datasets

In [71]:
def preprocess(x, y):
    """Convert one (image, label) element of the dataset for training.

    The image is cast to float32 and divided by 255 (presumably mapping uint8
    pixel values into [0, 1]); the label is cast to int32.

    Returns:
        The converted ``(image, label)`` pair.
    """
    scaled_image = tf.cast(x, dtype=tf.float32) / 255.
    int_label = tf.cast(y, dtype=tf.int32)
    return scaled_image, int_label

In [72]:
batch_size = 128

# Raw CIFAR-10 arrays: training split (x, y) and the test split used for validation (x_val, y_val).
(x, y), (x_val, y_val) = datasets.cifar10.load_data()

# Squeeze size-1 axes out of each label array, then one-hot encode into 10 classes
# (float32 vectors, the target format CategoricalCrossentropy works with).
y, y_val = [tf.one_hot(tf.squeeze(labels), depth=10) for labels in (y, y_val)]

In [73]:
# Input pipelines: per-example preprocessing, then batching. The training set is
# shuffled through a 10k-example buffer (reshuffled each epoch by default); the
# validation set keeps its original order.
# num_parallel_calls runs `preprocess` concurrently while still yielding elements
# in order; prefetch overlaps preparing the next batch with the current step.
AUTOTUNE = tf.data.experimental.AUTOTUNE
train_db = tf.data.Dataset.from_tensor_slices((x, y))
train_db = (train_db
            .map(preprocess, num_parallel_calls=AUTOTUNE)
            .shuffle(10000)
            .batch(batch_size)
            .prefetch(AUTOTUNE))
test_db = tf.data.Dataset.from_tensor_slices((x_val, y_val))
test_db = (test_db
           .map(preprocess, num_parallel_calls=AUTOTUNE)
           .batch(batch_size)
           .prefetch(AUTOTUNE))

In [74]:
# Pull a single (images, labels) batch from the training pipeline for inspection.
sample = next(iter(train_db.take(1)))


In [75]:
# Image tensor of the sampled batch; expected (batch_size, 32, 32, 3).
sample[0].get_shape()

TensorShape([128, 32, 32, 3])

In [87]:
class MyDense(layers.Layer):
    """Fully connected layer computing ``inputs @ kernel + bias`` (no activation).

    A minimal hand-written counterpart to ``layers.Dense``; weights are created
    eagerly in ``__init__`` because the input width is passed explicitly.

    Args:
        inp_dim: size of the last axis of the inputs (rows of the kernel).
        out_dim: number of output units (columns of the kernel, length of the bias).
        **kwargs: forwarded to ``layers.Layer`` (e.g. ``name``, ``dtype``).
    """

    def __init__(self, inp_dim, out_dim, **kwargs):
        super().__init__(**kwargs)
        # Keyword arguments keep this valid across Keras versions whose
        # add_weight positional order differs (name-first vs shape-first).
        # glorot_uniform is what add_weight already used implicitly for the kernel.
        self.kernel = self.add_weight(name='w', shape=[inp_dim, out_dim],
                                      initializer='glorot_uniform', trainable=True)
        # Start the bias at zero like layers.Dense; without an explicit
        # initializer add_weight would fill it with random glorot_uniform values.
        self.bias = self.add_weight(name='b', shape=[out_dim],
                                    initializer='zeros', trainable=True)

    def call(self, inputs, training=None):
        """Return ``inputs @ kernel + bias``; ``training`` is accepted and ignored."""
        return inputs @ self.kernel + self.bias

In [88]:
class MyModel(Model):  # inherits from keras.Model
    """Five-layer MLP for 32x32x3 images that outputs 10 unactivated scores.

    Widths: 3072 -> 256 -> 128 -> 64 -> 32 -> 10, with ReLU after each hidden
    layer and no activation on the final layer.
    """

    def __init__(self):
        super().__init__()
        # One attribute per layer (fc1..fc5), created input-to-output, so weight
        # tracking and auto-generated layer names stay the same.
        self.fc1 = MyDense(32 * 32 * 3, 256)
        self.fc2 = MyDense(256, 128)
        self.fc3 = MyDense(128, 64)
        self.fc4 = MyDense(64, 32)
        self.fc5 = MyDense(32, 10)

    def call(self, inputs, training=None):
        """Flatten each image to a 3072-vector and run it through the MLP.

        ``training`` is accepted for Keras compatibility and ignored.
        """
        hidden = tf.reshape(inputs, [-1, 32 * 32 * 3])
        for dense in (self.fc1, self.fc2, self.fc3, self.fc4):
            hidden = tf.nn.relu(dense(hidden))
        # Final projection to the 10 class scores, intentionally left linear.
        return self.fc5(hidden)
    

In [89]:
# Instantiate the MLP. MyDense creates its weights in __init__, so the model has
# parameters before it ever sees data.
network = MyModel()

In [90]:
# Adam at 1e-3. `learning_rate` replaces the deprecated `lr` alias, which newer
# Keras versions reject outright.
# from_logits=True: the loss applies softmax itself, so the network's last layer
# must output unnormalized scores. Targets are the one-hot label vectors.
network.compile(
    optimizer=optimizers.Adam(learning_rate=1e-3),
    loss=tf.losses.CategoricalCrossentropy(from_logits=True),
    metrics=['acc'],
)

In [91]:
# Train for 10 epochs and evaluate on the CIFAR-10 test split after every epoch.
# Keep the returned History (per-epoch training/validation metrics) for later
# plotting, rather than discarding it and printing its repr as the cell output.
history = network.fit(train_db, epochs=10, validation_data=test_db, validation_freq=1)

Train for 391 steps, validate for 79 steps
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x2316ada94c8>

In [92]:
# Per-layer parameter counts. For this subclassed model, Keras reports the output
# shapes as "multiple", so the parameter column is the informative part.
network.summary()

Model: "my_model_11"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
my_dense_55 (MyDense)        multiple                  786688    
_________________________________________________________________
my_dense_56 (MyDense)        multiple                  32896     
_________________________________________________________________
my_dense_57 (MyDense)        multiple                  8256      
_________________________________________________________________
my_dense_58 (MyDense)        multiple                  2080      
_________________________________________________________________
my_dense_59 (MyDense)        multiple                  330       
Total params: 830,250
Trainable params: 830,250
Non-trainable params: 0
_________________________________________________________________
