# 1. 模型可视化
利用 TensorBoard 可视化模型的 loss、accuracy 以及变量

# 2. CNN模型

- 加载数据集

In [1]:
import tensorflow as tf

In [3]:
# Load the MNIST train/test splits through Keras and report the shape of every array.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
for split_label, split_array in (("training data shape", x_train),
                                 ("training label shape", y_train),
                                 ("test data shape", x_test),
                                 ("test label shape", y_test)):
    print(split_label, split_array.shape)

training data shape (60000, 28, 28)
training label shape (60000,)
test data shape (10000, 28, 28)
test label shape (10000,)


- 数据集预处理

In [4]:
# Cast pixels to float32 and scale them into [0, 1], then append a trailing channel axis
# so each image becomes (28, 28, 1) for the Conv2D input layer.
x_train, x_test = (tf.expand_dims(images.astype("float32") / 255.0, axis=-1)
                   for images in (x_train, x_test))
print("x_train shape", x_train.shape)

x_train shape (60000, 28, 28, 1)


In [5]:
## One-hot encode the integer labels into length-10 vectors (one slot per digit class).
y_train, y_test = (tf.one_hot(labels, depth=10) for labels in (y_train, y_test))
# Show the first four encoded training labels as a sanity check.
print("y train label: ", y_train[:4])

y train label:  tf.Tensor(
[[0. 0. 0. 0. 0. 1. 0. 0. 0. 0.]
 [1. 0. 0. 0. 0. 0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 1. 0. 0. 0. 0. 0.]
 [0. 1. 0. 0. 0. 0. 0. 0. 0. 0.]], shape=(4, 10), dtype=float32)


- 搭建CNN模型

In [7]:
def _conv(name, **extra):
    """Build the Conv2D used by conv1-conv4: 32 filters, 3x3 kernel, stride 1, 'same' padding, ReLU."""
    return tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), strides=(1, 1), padding="same",
                                  activation="relu", name=name, **extra)


def _pool(name):
    """Build the 2x2 / stride-2 'valid' max-pooling layer that halves the spatial size."""
    return tf.keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2), padding="valid", name=name)


# Two conv-conv-pool stages (28x28 -> 14x14 -> 7x7), then a small dense classifier head.
model = tf.keras.Sequential([
    # Only the first layer needs the per-sample input shape (height, width, channels).
    _conv("conv1", input_shape=tuple(x_train.shape[1:])),
    _conv("conv2"),
    _pool("pool1"),
    _conv("conv3"),
    _conv("conv4"),
    _pool("pool2"),
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(32, activation="relu", kernel_initializer="he_normal", name="fc1"),
    # Softmax over the 10 digit classes, matching the one-hot labels.
    tf.keras.layers.Dense(10, activation="softmax", kernel_initializer="he_normal", name="fc2"),
])

sgd = tf.keras.optimizers.SGD(learning_rate=0.01)
# Categorical cross-entropy fits one-hot targets against softmax probabilities.
model.compile(optimizer=sgd,
              loss=tf.keras.losses.CategoricalCrossentropy(),
              metrics=["accuracy"])
model.summary()

Model: "sequential_1"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv1 (Conv2D)               (None, 28, 28, 32)        320       
_________________________________________________________________
conv2 (Conv2D)               (None, 28, 28, 32)        9248      
_________________________________________________________________
pool1 (MaxPooling2D)         (None, 14, 14, 32)        0         
_________________________________________________________________
conv3 (Conv2D)               (None, 14, 14, 32)        9248      
_________________________________________________________________
conv4 (Conv2D)               (None, 14, 14, 32)        9248      
_________________________________________________________________
pool2 (MaxPooling2D)         (None, 7, 7, 32)          0         
_________________________________________________________________
flatten (Flatten)            (None, 1568)             

In [10]:
import os
import time

# Log/checkpoint locations relative to the notebook's working directory instead of a
# hard-coded absolute home path, so the notebook also runs for other users/machines.
LOG_ROOT = "logs"
CHECKPOINT_PATH = os.path.join("models", "cnn", "mnist_cnn.h5")

# One TensorBoard sub-directory per run, suffixed with the Unix timestamp, so runs can be
# compared side by side. (The name `model_path` is kept for any later cells that use it.)
model_path = os.path.join(LOG_ROOT, "cnn_event-{}".format(int(time.time())))
os.makedirs(os.path.dirname(CHECKPOINT_PATH), exist_ok=True)

my_callbacks = [
    # Save the model in HDF5 format to CHECKPOINT_PATH (default ModelCheckpoint settings).
    tf.keras.callbacks.ModelCheckpoint(CHECKPOINT_PATH),
    # Log loss/accuracy for TensorBoard; histogram_freq=1 also writes weight histograms each epoch.
    tf.keras.callbacks.TensorBoard(log_dir=model_path, histogram_freq=1),
]

# validation_split=0.1 holds out 10% of the training set (54000 train / 6000 validation).
# Keep the History object so per-epoch metrics can be inspected or plotted afterwards.
history = model.fit(x_train, y_train, batch_size=64, epochs=10,
                    callbacks=my_callbacks, validation_split=0.1)

Train on 54000 samples, validate on 6000 samples
Epoch 1/10
Epoch 2/10
Epoch 3/10
Epoch 4/10
Epoch 5/10
Epoch 6/10
Epoch 7/10
Epoch 8/10
Epoch 9/10
Epoch 10/10


<tensorflow.python.keras.callbacks.History at 0x7f9ba81db828>