# 以 MNIST 資料集測試遷移學習
學習前五種數字 (0,1,2,3,4) 的特徵，轉移到後五種數字 (5,6,7,8,9) 的分類辨識。

In [1]:
import numpy as np
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Dropout
from tensorflow.keras.utils import to_categorical

# Fix the random seed so that repeated runs of the notebook are comparable.
# Seeding NumPy alone is not enough for Keras: set_random_seed seeds Python's
# `random`, NumPy and TensorFlow in a single call.
from tensorflow.keras.utils import set_random_seed

seed = 7
set_random_seed(seed)

# Load the MNIST digits as (images, labels) pairs for training and testing.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Split the data into two 5-class tasks. Digits 0-4 keep their labels;
# digits 5-9 have 5 subtracted so both tasks use class indices 0-4.
_train_low = y_train < 5
_test_low = y_test < 5

x_train_lt5, y_train_lt5 = x_train[_train_low], y_train[_train_low]
x_test_lt5, y_test_lt5 = x_test[_test_low], y_test[_test_low]

x_train_gte5, y_train_gte5 = x_train[~_train_low], y_train[~_train_low] - 5
x_test_gte5, y_test_gte5 = x_test[~_test_low], y_test[~_test_low] - 5

# Turn each image set into a 4D float32 tensor (samples, 28, 28, 1) and
# rescale the fixed 0-255 pixel range to 0-1.
x_train_lt5 = x_train_lt5.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test_lt5 = x_test_lt5.reshape(-1, 28, 28, 1).astype('float32') / 255
x_train_gte5 = x_train_gte5.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test_gte5 = x_test_gte5.reshape(-1, 28, 28, 1).astype('float32') / 255

# One-hot encode the labels over the 5 classes of each task.
y_train_lt5 = to_categorical(y_train_lt5, num_classes=5)
y_test_lt5 = to_categorical(y_test_lt5, num_classes=5)
y_train_gte5 = to_categorical(y_train_gte5, num_classes=5)
y_test_gte5 = to_categorical(y_test_gte5, num_classes=5)

# Define the model: two convolution + max-pooling stages, followed by a
# dense classifier with dropout and a 5-way softmax output.
model = Sequential([
    Conv2D(8, kernel_size=(3, 3), input_shape=(28, 28, 1), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Conv2D(8, kernel_size=(3, 3), activation='relu'),
    MaxPooling2D(pool_size=(2, 2)),
    Flatten(),
    Dense(64, activation='relu'),
    Dropout(0.25),
    Dense(5, activation='softmax'),
])

# Print the layer-by-layer summary of the model.
model.summary()

# Compile the model with categorical cross-entropy loss, the Adam optimizer
# and accuracy as the reported metric.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train on the digits 0-4 for 5 epochs, holding out 20% of the training data
# for validation.
history = model.fit(
    x_train_lt5,
    y_train_lt5,
    batch_size=128,
    epochs=5,
    validation_split=0.2,
    verbose=2,
)

# Evaluate on the digits 0-4 test set and report the accuracy.
loss, accuracy = model.evaluate(x_test_lt5, y_test_lt5, verbose=0)
print('測試資料集的準確度 = {:.2f}'.format(accuracy))


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 8)         80        
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 8)         0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 8)         584       
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 5, 5, 8)           0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 200)               0         
                              

In [2]:
# Use the network trained on digits 0-4 to score both test sets: the digits
# it was trained on (lt5) and the digits it has not been trained on (gte5).
for report, images, labels in (
    ('測試資料集(lt5)的準確度 = {:.2f}', x_test_lt5, y_test_lt5),
    ('測試資料集(gte5)的準確度 = {:.2f}', x_test_gte5, y_test_gte5),
):
    loss, accuracy = model.evaluate(images, labels, verbose=0)
    print(report.format(accuracy))


測試資料集(lt5)的準確度 = 0.99
測試資料集(gte5)的準確度 = 0.36


In [3]:
# List every layer with its index so the freeze boundary below can be checked.
print(len(model.layers))
for index, layer in enumerate(model.layers):
    print(index, layer)

# Freeze the first four layers (indices 0-3: the two convolution and two
# max-pooling layers) so their weights are not updated by later training.
for layer in model.layers[:4]:
    layer.trainable = False


8
0 <keras.src.layers.convolutional.conv2d.Conv2D object at 0x7fb31abe2380>
1 <keras.src.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb31abe2a70>
2 <keras.src.layers.convolutional.conv2d.Conv2D object at 0x7fb31ac3a8f0>
3 <keras.src.layers.pooling.max_pooling2d.MaxPooling2D object at 0x7fb31ac3a440>
4 <keras.src.layers.reshaping.flatten.Flatten object at 0x7fb31ac39360>
5 <keras.src.layers.core.dense.Dense object at 0x7fb31ac3be80>
6 <keras.src.layers.regularization.dropout.Dropout object at 0x7fb31ac3a650>
7 <keras.src.layers.core.dense.Dense object at 0x7fb3147149a0>


In [4]:
# Compile the model again with the same loss, optimizer and metric; Keras
# applies changes to a layer's `trainable` flag at compile time.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Train on the digits 5-9 (labels shifted to 0-4) for 5 epochs, holding out
# 20% of the training data for validation.
history = model.fit(
    x_train_gte5,
    y_train_gte5,
    batch_size=128,
    epochs=5,
    validation_split=0.2,
    verbose=2,
)

# Evaluate on the digits 5-9 test set and report the accuracy.
loss, accuracy = model.evaluate(x_test_gte5, y_test_gte5, verbose=0)
print('測試資料集的準確度 = {:.2f}'.format(accuracy))


Epoch 1/5
184/184 - 4s - loss: 0.7101 - accuracy: 0.7662 - val_loss: 0.2126 - val_accuracy: 0.9441 - 4s/epoch - 20ms/step
Epoch 2/5
184/184 - 3s - loss: 0.2256 - accuracy: 0.9291 - val_loss: 0.1276 - val_accuracy: 0.9650 - 3s/epoch - 15ms/step
Epoch 3/5
184/184 - 3s - loss: 0.1586 - accuracy: 0.9523 - val_loss: 0.1018 - val_accuracy: 0.9740 - 3s/epoch - 15ms/step
Epoch 4/5
184/184 - 4s - loss: 0.1331 - accuracy: 0.9598 - val_loss: 0.0914 - val_accuracy: 0.9759 - 4s/epoch - 23ms/step
Epoch 5/5
184/184 - 3s - loss: 0.1181 - accuracy: 0.9640 - val_loss: 0.0796 - val_accuracy: 0.9772 - 3s/epoch - 15ms/step
測試資料集的準確度 = 0.98


In [None]:
# Print the layer-by-layer summary of the model again, after the second round of training.
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 conv2d (Conv2D)             (None, 26, 26, 8)         80        
                                                                 
 max_pooling2d (MaxPooling2  (None, 13, 13, 8)         0         
 D)                                                              
                                                                 
 conv2d_1 (Conv2D)           (None, 11, 11, 8)         584       
                                                                 
 max_pooling2d_1 (MaxPoolin  (None, 5, 5, 8)           0         
 g2D)                                                            
                                                                 
 flatten (Flatten)           (None, 200)               0         
                                                                 
 dense (Dense)               (None, 64)                1