In [1]:
from keras.preprocessing.image import ImageDataGenerator

from keras.applications import Xception
from keras import models
from keras import layers
from keras import optimizers

from keras.callbacks import EarlyStopping

import numpy as np
import gc

"""
./format_data/train/
./format_data/val/
"""
# Root folders holding the validation and training images.
test_dir = "./format_data/val/"
train_dir = "./format_data/train/"

# Mini-batch size and square input edge (in pixels) used by the generators.
batch_size = 5
feature_size = 299

# Early-stopping callback watching the validation loss. patience=0 grants no
# grace epochs: the first epoch without improvement ends the run.
call_back = EarlyStopping(monitor='val_loss', patience=0, verbose=0, mode='auto')

Using TensorFlow backend.


In [2]:
# Training images are rescaled to [0, 1] and randomly augmented (rotation,
# shifts, shear, flips); validation images are only rescaled.
train_gen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=360,
    fill_mode='reflect',
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    horizontal_flip=True,
    vertical_flip=True,
)
test_gen = ImageDataGenerator(rescale=1./255)

# Directory iterators yielding (images, one-hot labels) batches; the training
# iterator shuffles with a fixed seed.
train_gener = train_gen.flow_from_directory(
    train_dir,
    target_size=(feature_size, feature_size),
    batch_size=batch_size,
    shuffle=True,
    seed=100,
    class_mode="categorical",
)
test_gener = test_gen.flow_from_directory(
    test_dir,
    target_size=(feature_size, feature_size),
    batch_size=batch_size,
    class_mode="categorical",
)

# Pull a single batch from the training iterator to sanity-check the shapes
# of the image and label arrays it emits.
data_batch, label_batch = next(train_gener)
print(data_batch.shape)
print(label_batch.shape)

Found 1884 images belonging to 12 classes.
Found 252 images belonging to 12 classes.
(5, 299, 299, 3)
(5, 12)


In [3]:
# Build the Xception convolutional base without its classification top and
# without letting Keras fetch weights; they are loaded from a local file below.
X_model = Xception(
    weights=None,
    include_top=False,
    input_shape=(feature_size, feature_size, 3),
)
X_model.load_weights("./xception_weights_notop.h5")

# Print the layer-by-layer structure of the base network.
X_model.summary()

____________________________________________________________________________________________________
Layer (type)                     Output Shape          Param #     Connected to                     
input_1 (InputLayer)             (None, 299, 299, 3)   0                                            
____________________________________________________________________________________________________
block1_conv1 (Conv2D)            (None, 149, 149, 32)  864         input_1[0][0]                    
____________________________________________________________________________________________________
block1_conv1_bn (BatchNormalizat (None, 149, 149, 32)  128         block1_conv1[0][0]               
____________________________________________________________________________________________________
block1_conv1_act (Activation)    (None, 149, 149, 32)  0           block1_conv1_bn[0][0]            
___________________________________________________________________________________________

In [4]:
# Stack a small classification head on top of the Xception base:
# flatten -> 650-unit ReLU layer -> dropout -> 12-way softmax.
# NOTE(review): per the summary output, Flatten produces 204800 features, so
# the first Dense layer holds ~133M of the ~154M parameters. Consider a global
# pooling layer instead of Flatten -- confirm against accuracy requirements.
model = models.Sequential([
    X_model,
    layers.Flatten(),
    layers.Dense(650, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(12, activation='softmax'),
])

Instructions for updating:
keep_dims is deprecated, use keepdims instead


In [5]:
# Print the layers, output shapes and parameter counts of the assembled model.
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
xception (Model)             (None, 10, 10, 2048)      20861480  
_________________________________________________________________
flatten_1 (Flatten)          (None, 204800)            0         
_________________________________________________________________
dense_1 (Dense)              (None, 650)               133120650 
_________________________________________________________________
dropout_1 (Dropout)          (None, 650)               0         
_________________________________________________________________
dense_2 (Dense)              (None, 12)                7812      
Total params: 153,989,942
Trainable params: 153,935,414
Non-trainable params: 54,528
_________________________________________________________________


In [6]:
# Derive the step counts from the iterators so that one epoch is one full pass
# over the training images and every validation run scores the whole
# validation set. The previous hard-coded 180 / 25 steps at batch size 5
# covered only 900 of 1884 training images and a random 125 of 252 validation
# images, which made val_loss (and therefore early stopping) noisy.
# -(-a // b) is integer ceiling division, so a final partial batch is included.
train_steps = -(-train_gener.samples // train_gener.batch_size)
val_steps = -(-test_gener.samples // test_gener.batch_size)

# Multi-class classification: categorical cross-entropy, Adam with a small
# learning rate, accuracy reported as 'acc'.
model.compile(loss='categorical_crossentropy',optimizer=optimizers.Adam(lr=1e-4),metrics=['acc'])

# Train for at most 15 epochs; the early-stopping callback may end the run sooner.
history = model.fit_generator(train_gener,steps_per_epoch=train_steps,epochs=15,
                              validation_data=test_gener,validation_steps=val_steps,
                              callbacks=[call_back])

Instructions for updating:
keep_dims is deprecated, use keepdims instead
Instructions for updating:
keep_dims is deprecated, use keepdims instead
Epoch 1/15
Epoch 2/15
Epoch 3/15
Epoch 4/15
Epoch 5/15
Epoch 6/15
Epoch 7/15


In [7]:
# Write the model, as it stands after training, to an HDF5 file in the working directory.
model.save('Xception_whole.h5')