<a href="https://colab.research.google.com/github/SaashaJoshi/cats-dogs-classification/blob/master/CatsDogs_transfer_learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
import os
import tensorflow as tf
from tensorflow import keras
from keras import layers
from keras import Model
from keras.optimizers import RMSprop

In [0]:
!wget --no-check-certificate \
    https://storage.googleapis.com/mledu-datasets/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5 \
    -O /tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5

In [0]:
from keras.applications.inception_v3 import InceptionV3

# Location of the weight file written by the wget cell above.
local_weights_file = '/tmp/inception_v3_weights_tf_dim_ordering_tf_kernels_notop.h5'

# Instantiate the InceptionV3 base without its classification top and without
# built-in weights (weights=None); the weights are loaded from the local file.
pre_trained_model = InceptionV3(
    weights=None,
    include_top=False,
    input_shape=(150, 150, 3),
)
pre_trained_model.load_weights(local_weights_file)

# Mark every layer of the base as non-trainable.
for layer in pre_trained_model.layers:
    layer.trainable = False

In [0]:
# Take the output of the 'mixed7' layer as the feature map that the new
# classifier head is attached to, and report its shape.
last_layer = pre_trained_model.get_layer('mixed7')
last_output = last_layer.output
print('Last layer shape: ', last_layer.output_shape)

In [0]:
# Classifier head stacked on the 'mixed7' output:
# flatten -> 1024-unit ReLU dense -> 20% dropout -> single sigmoid unit.
flattened = layers.Flatten()(last_output)
hidden = layers.Dense(1024, activation='relu')(flattened)
regularized = layers.Dropout(0.2)(hidden)
x = layers.Dense(1, activation='sigmoid')(regularized)

# Full model: pre-trained base input through to the new sigmoid output.
model = Model(inputs=pre_trained_model.input, outputs=x)

In [0]:
# Compile for binary classification (the model ends in a single sigmoid unit).
# `learning_rate` is the supported argument name; the old `lr` alias is
# deprecated and rejected by current Keras optimizers.
model.compile(
    optimizer=RMSprop(learning_rate=0.0001),
    loss='binary_crossentropy',
    metrics=['accuracy'],
)

In [0]:
# Download the filtered cats-vs-dogs archive and ask Keras to extract it;
# the path returned by get_file is kept in zip_dir.
_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
zip_dir = keras.utils.get_file(
    fname='cats_and_dogs_filtered.zip',
    origin=_URL,
    extract=True,
)

In [0]:
# Dataset root: the 'cats_and_dogs_filtered' folder next to the downloaded archive.
base_dir = os.path.join(os.path.dirname(zip_dir), 'cats_and_dogs_filtered')

# One directory per split ...
train_dir, val_dir = (
    os.path.join(base_dir, split) for split in ('train', 'validation')
)

# ... and one sub-directory per class inside each split.
train_cats, train_dogs = (
    os.path.join(train_dir, label) for label in ('cats', 'dogs')
)
val_cats, val_dogs = (
    os.path.join(val_dir, label) for label in ('cats', 'dogs')
)

In [0]:
# Random augmentations applied to training images only.
augmentation_options = dict(
    rotation_range=40,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
)

# Both generators rescale pixel values by 1/255; only the training one augments.
train_images_gen = keras.preprocessing.image.ImageDataGenerator(
    rescale=1./255,
    **augmentation_options,
)
val_images_gen = keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

In [0]:
# Settings shared by both directory iterators: batches of 32 images resized to
# 150x150 with binary (0/1) labels.
flow_options = dict(
    batch_size=32,
    target_size=(150, 150),
    class_mode='binary',
)

# Training batches are shuffled; validation batches keep directory order.
train_data_gen = train_images_gen.flow_from_directory(
    directory=train_dir,
    shuffle=True,
    **flow_options,
)
val_data_gen = val_images_gen.flow_from_directory(
    directory=val_dir,
    shuffle=False,
    **flow_options,
)

In [0]:
# Train the new head for 20 epochs on the augmented training generator and
# evaluate on the validation generator after each epoch.
#
# `Model.fit` accepts generators directly and replaces the deprecated
# `fit_generator`. Step counts are derived from the generators themselves
# (len(generator) is the number of batches in one pass) instead of hard-coded
# values that did not follow from batch_size=32, so each epoch covers every
# training and validation batch exactly once.
history = model.fit(
    train_data_gen,
    epochs=20,
    steps_per_epoch=len(train_data_gen),
    validation_data=val_data_gen,
    validation_steps=len(val_data_gen),
)

In [0]:
import matplotlib.pyplot as plt

# Per-epoch metrics recorded by training.
metrics_log = history.history

# Keras versions disagree on the accuracy key ('acc' vs 'accuracy'), so use
# whichever one was actually logged instead of failing with a KeyError.
acc_key = 'acc' if 'acc' in metrics_log else 'accuracy'
acc = metrics_log[acc_key]
val_acc = metrics_log['val_' + acc_key]

# Loss curves are extracted as well but not plotted in this cell.
loss = metrics_log['loss']
val_loss = metrics_log['val_loss']

# One x-position per recorded epoch.
epochs_range = range(len(acc))

plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy')
plt.title('Training and Validation Accuracy')
plt.legend(loc='lower right')
plt.show()