In [1]:
from keras.datasets import cifar10
from resnet_builder import resnet # 這是從 resnet_builder.py 中直接 import 撰寫好的 resnet 函數
from keras.models import Model
from keras.optimizers import Adam
from keras.utils import to_categorical
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint
from keras.callbacks import EarlyStopping
import os

Using TensorFlow backend.


In [2]:
# Training hyper-parameters:
#   batch_size  - mini-batch size; lower it if you hit an out-of-memory (OOM) error
#   num_classes - number of label classes (CIFAR-10 has 10)
#   epochs      - number of full training cycles over the dataset (30)
batch_size, num_classes, epochs = 64, 10, 30

In [3]:
# Load CIFAR-10 and preprocess it: scale pixel values into [0, 1] and
# one-hot encode the integer labels into `num_classes` columns.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Cast to float32 first so the division keeps the arrays in float32.
x_train = x_train.astype('float32') / 255.
x_test = x_test.astype('float32') / 255.

y_train, y_test = (to_categorical(labels, num_classes) for labels in (y_train, y_test))

In [4]:
# Data-augmentation settings applied on the fly to each training batch:
# random rotations, width/height shifts, shears, zooms and horizontal flips.
# Pixels uncovered by a transform are filled from the nearest edge pixel.
augmentation_params = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    zca_whitening=False,
)
data_generator = ImageDataGenerator(**augmentation_params)

In [5]:
# Build the ResNet defined in resnet_builder.py for 32x32x3 CIFAR-10 images
# and print its layer-by-layer summary.
# NOTE(review): num_classes is not passed here; presumably resnet() defaults
# to a 10-way output layer — confirm against resnet_builder.py.
input_shape = (32, 32, 3)
model = resnet(input_shape=input_shape)
model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
input_1 (InputLayer)            (None, 32, 32, 3)    0                                            
__________________________________________________________________________________________________
conv2d_1 (Conv2D)               (None, 32, 32, 16)   448         input_1[0][0]                    
__________________________________________________________________________________________________
batch_normalization_1 (BatchNor (None, 32, 32, 16)   64          conv2d_1[0][0]                   
__________________________________________________________________________________________________
activation_1 (Activation)       (None, 32, 32, 16)   0           batch_normalization_1[0][0]      
____________________________________________________________________________________________

In [6]:
# Compile for multi-class classification against the one-hot labels produced
# by to_categorical. `learning_rate` is the supported keyword since Keras 2.3;
# the old `lr` alias is deprecated and dropped by newer Keras releases.
# NOTE(review): 1e-5 is a very small Adam step size; presumably chosen for
# fine-tuning from a previously saved checkpoint — confirm it is intentional.
model.compile(loss='categorical_crossentropy',
              optimizer=Adam(learning_rate=1e-5),
              metrics=['accuracy'])

In [7]:
# Checkpointing: keep only the weights with the lowest validation loss, and
# resume from them when a previous run has already saved a checkpoint.
file_path = "model_weights.hd5"
print(file_path)

# Resolve the location once so saving and loading always refer to the same
# file, even if the working directory changes later in the session.
weights_path = os.path.join(os.getcwd(), file_path)

callbackCheckpoint = ModelCheckpoint(
    weights_path,
    monitor='val_loss',
    save_weights_only=True,
    save_best_only=True,  # overwrite the file only when val_loss improves
    verbose=1)

if os.path.exists(weights_path):
    model.load_weights(weights_path)
    # Reached only if the previously saved weights were loaded successfully.
    print("checkpoint_loaded")
    # NOTE(review): the callback's best val_loss still starts at inf, so the
    # first epoch of a resumed run overwrites this checkpoint even if it is
    # worse — consider seeding it from an initial evaluation.

callbacks = [callbackCheckpoint]

model_weights.hd5
checkpoint_loaded


In [8]:
# Train on augmented mini-batches streamed from `data_generator`.
train_flow = data_generator.flow(x_train, y_train, batch_size=batch_size)

# len(train_flow) == ceil(len(x_train) / batch_size), so one epoch is exactly
# one pass over the training set. The former round(50000 / 64) == 781 covered
# only 49,984 samples, letting epochs drift out of step with the data.
history = model.fit_generator(
    train_flow,
    steps_per_epoch=len(train_flow),
    epochs=epochs,
    verbose=1,
    # NOTE(review): the test set doubles as validation data, and the checkpoint
    # selects weights on it, so the reported test score is optimistically biased.
    validation_data=(x_test, y_test),
    callbacks=callbacks)

Epoch 1/30

Epoch 00001: val_loss improved from inf to 1.12584, saving model to C:\Users\bawan.wang\Desktop\ai100\homework\model_weights.hd5
Epoch 2/30

Epoch 00002: val_loss improved from 1.12584 to 1.09570, saving model to C:\Users\bawan.wang\Desktop\ai100\homework\model_weights.hd5
Epoch 3/30

Epoch 00003: val_loss did not improve from 1.09570
Epoch 4/30

Epoch 00004: val_loss did not improve from 1.09570
Epoch 5/30

Epoch 00005: val_loss improved from 1.09570 to 1.08966, saving model to C:\Users\bawan.wang\Desktop\ai100\homework\model_weights.hd5
Epoch 6/30

Epoch 00006: val_loss improved from 1.08966 to 1.08882, saving model to C:\Users\bawan.wang\Desktop\ai100\homework\model_weights.hd5
Epoch 7/30

Epoch 00007: val_loss did not improve from 1.08882
Epoch 8/30

Epoch 00008: val_loss improved from 1.08882 to 1.08801, saving model to C:\Users\bawan.wang\Desktop\ai100\homework\model_weights.hd5
Epoch 9/30

Epoch 00009: val_loss improved from 1.08801 to 1.08656, saving model to C:\Use

In [9]:
# Score the model's current weights (those left after the last training
# epoch; the best checkpoint is not reloaded here) on the held-out test set.
# With metrics=['accuracy'], evaluate() yields [loss, accuracy].
score = model.evaluate(x_test, y_test, verbose=0)
test_loss, test_accuracy = score[0], score[1]
print('Test loss:', test_loss)
print('Test accuracy:', test_accuracy)

Test loss: 1.0423324500083924
Test accuracy: 0.7534999847412109
