# Transfer learning on CIFAR-10 using VGG16

Here I build a CIFAR-10 classifier by fine-tuning VGG16 pretrained on ImageNet. The architecture is shown in the image below.<br>

Conv blocks 1 to 4 are kept frozen, i.e. their weights remain unchanged during training. <br>
Conv block 5 and two new fully connected layers (a 256-unit hidden layer with dropout, followed by a 10-way output layer) are trained. 

<img src ="https://github.com/bhavsarpratik/Deep_Learning_Notebooks/raw/master/data/images/vgg16_original.png" width="40%">

In [1]:
import os
import numpy as np
from keras import applications, optimizers
from keras.callbacks import ModelCheckpoint
from keras.datasets import cifar10
from keras.utils import np_utils
from keras.layers import Input,Dense, Dropout, Flatten
from keras.models import Sequential,Model
from keras.preprocessing.image import ImageDataGenerator

Using TensorFlow backend.


In [2]:
# Run configuration.
num_classes = 10                # CIFAR-10 label count
epochs = 3
batch_size = 32
img_width, img_height = 32, 32  # CIFAR-10 images are 32x32 RGB

# With Keras' channels_last layout (the trailing 3 is the RGB channel axis),
# the input shape is (height, width, channels). Listing height first keeps
# this correct if non-square image sizes are ever used.
input_tensor = Input(shape=(img_height, img_width, 3))

# VGG16 convolutional base with ImageNet weights. include_top=False drops the
# original ImageNet classifier so a CIFAR-10 head can be attached later.
base_model = applications.VGG16(weights='imagenet', include_top=False, input_tensor=input_tensor)
print('VGG model')
base_model.summary()

VGG model
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         (None, 32, 32, 3)         0         
_________________________________________________________________
block1_conv1 (Conv2D)        (None, 32, 32, 64)        1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 32, 32, 64)        36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 16, 16, 64)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 16, 16, 128)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 16, 16, 128)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 8, 8, 128)         0         


<img src ="https://github.com/bhavsarpratik/Deep_Learning_Notebooks/raw/master/data/images/vgg16_modified.png" width="40%">

In [3]:
# Classification head: flatten the VGG16 feature map, one 256-unit hidden
# layer with dropout, then one output unit per CIFAR-10 class.
top_model = Sequential()
top_model.add(Flatten(input_shape=base_model.output_shape[1:]))
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
# Each image has exactly one label, so use softmax: the class scores form a
# single probability distribution (independent sigmoids do not).
top_model.add(Dense(num_classes, activation='softmax'))

model = Model(inputs=base_model.input, outputs=top_model(base_model.output))

# Freeze the first 15 layers (the input layer plus conv blocks 1-4, through
# block4_pool) so their ImageNet weights are not updated. Conv block 5 and the
# new head remain trainable.
for layer in model.layers[:15]:
    layer.trainable = False

# Compile after changing `trainable` so the freeze takes effect.
# categorical_crossentropy matches the softmax output and the one-hot labels
# built with to_categorical. binary_crossentropy would treat the 10 outputs as
# independent yes/no problems and report an inflated per-element accuracy.
# A small learning rate with momentum avoids destroying the pretrained block-5 weights.
model.compile(loss='categorical_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])

Model weights loaded


In [10]:
# Optionally resume from a saved checkpoint: set this to a weights file path.
resume_weights = None  # e.g. 'Keras_VGG_CIFAR10.h5'
if resume_weights:
    model.load_weights(resume_weights)
    print('Model weights loaded')

(X_train, y_train), (X_test, y_test) = cifar10.load_data()
# One-hot encode the integer labels to match the num_classes-way output layer.
Y_train = np_utils.to_categorical(y_train, num_classes)
Y_test = np_utils.to_categorical(y_test, num_classes)

# Training pipeline: featurewise standardization plus random augmentation.
datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True)
datagen.fit(X_train)

# Evaluation pipeline: the same standardization, with statistics taken from the
# training set, but no augmentation, so validation sees the real test images.
test_datagen = ImageDataGenerator(
    featurewise_center=True,
    featurewise_std_normalization=True)
test_datagen.fit(X_train)

print(X_train.shape)
print(X_test.shape)

train_generator = datagen.flow(X_train, Y_train, batch_size=batch_size)
# Validate on the held-out test split rather than on training data, so val_loss
# (which drives checkpointing below) measures generalization. shuffle=False plus
# ceil-divided steps evaluates every test image exactly once per epoch.
validation_generator = test_datagen.flow(X_test, Y_test, batch_size=batch_size, shuffle=False)
validation_steps = -(-len(X_test) // batch_size)  # ceil division: include the final partial batch

# Save weights only when val_loss improves.
# NOTE(review): `val_acc` is the log key in older Keras; Keras >= 2.3 logs
# `val_accuracy` instead. Verify against the installed version.
checkpoint = ModelCheckpoint('Keras_VGG_CIFAR10-{epoch:02d}-{val_acc:.3f}.h5', monitor='val_loss', verbose=1, save_best_only=True, save_weights_only=True, mode='auto', period=1)
callbacks_list = [checkpoint]

model.fit_generator(
    train_generator,
    steps_per_epoch=len(X_train) // batch_size,
    epochs=epochs,
    validation_data=validation_generator,
    validation_steps=validation_steps,
    callbacks=callbacks_list
    )

Model weights loaded
(50000, 32, 32, 3)
(10000, 32, 32, 3)
Epoch 1/3
Epoch 2/3
Epoch 3/3


<keras.callbacks.History at 0x1c22ba708d0>