<a href="https://colab.research.google.com/github/AlirezaAhadipour/Transfer-Learning/blob/main/Transfer-Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow_datasets as tfds
import tensorflow as tf

In [3]:
# Load the TensorFlow Flowers dataset as (image, label) pairs, together with
# its DatasetInfo metadata.
dataset, info = tfds.load(
    'tf_flowers',
    with_info=True,
    as_supervised=True,
)

In [4]:
# Display the DatasetInfo metadata returned by tfds.load (features, splits, citation).
info

tfds.core.DatasetInfo(
    name='tf_flowers',
    full_name='tf_flowers/3.0.1',
    description="""
    A large set of images of flowers
    """,
    homepage='https://www.tensorflow.org/tutorials/load_data/images',
    data_path='/root/tensorflow_datasets/tf_flowers/3.0.1',
    file_format=tfrecord,
    download_size=218.21 MiB,
    dataset_size=221.83 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    splits={
        'train': <SplitInfo num_examples=3670, num_shards=2>,
    },
    citation="""@ONLINE {tfflowers,
    author = "The TensorFlow Team",
    title = "Flowers",
    month = "jan",
    year = "2019",
    url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }""",
)

In [5]:
# Read the basic facts about the dataset from its metadata: the number of
# label classes, their human-readable names, and how many examples the
# 'train' split contains.
n_classes = info.features['label'].num_classes
class_names = info.features['label'].names
dataset_size = info.splits['train'].num_examples

In [6]:
# Show the human-readable class names of the 'label' feature.
class_names

['dandelion', 'daisy', 'tulips', 'sunflowers', 'roses']

In [9]:
# Slice the 'train' split into three parts: the first 80% for training,
# the next 10% for validation and the final 10% for testing.
train, val, test = tfds.load(
    'tf_flowers',
    split=['train[:80%]', 'train[80%:90%]', 'train[90%:]'],
    as_supervised=True,
)

In [10]:
def preprocess(image, label):
  """Prepare one (image, label) example for the Xception network.

  Args:
    image: Image tensor to be resized.
    label: Label for the image; returned unchanged.

  Returns:
    A `(processed_image, label)` tuple, where `processed_image` is the image
    resized to 299x299 and passed through Xception's `preprocess_input`.
  """
  image_299 = tf.image.resize(image, [299, 299])
  return tf.keras.applications.xception.preprocess_input(image_299), label

In [11]:
batch_size = 32
shuffle_buffer_size = 1000

# Build the input pipelines: preprocess each example, group examples into
# batches and prefetch one batch ahead.
#
# Only the training split is shuffled. Without it, every epoch feeds the model
# exactly the same batches in exactly the same order, which hurts SGD training.
# Validation and test data keep their order so evaluation stays deterministic.
#
# NOTE: this cell rebinds `train`, `val` and `test`. Re-running it without first
# re-running the cell that creates the splits would preprocess and batch the
# already-batched data a second time.
train = train.shuffle(shuffle_buffer_size).map(preprocess).batch(batch_size).prefetch(1)
val = val.map(preprocess).batch(batch_size).prefetch(1)
test = test.map(preprocess).batch(batch_size).prefetch(1)

In [12]:
# Pre-trained feature extractor: Xception with ImageNet weights and without
# its top (classification) layers.
base_model = tf.keras.applications.xception.Xception(
    include_top=False,
    weights='imagenet',
)

# New head on top of the base: global average pooling followed by a softmax
# Dense layer with one unit per flower class.
avg_pooling = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
output = tf.keras.layers.Dense(units=n_classes, activation='softmax')(avg_pooling)

# Full model: from the base model's input to the new classification output.
model = tf.keras.Model(outputs=output, inputs=base_model.input)

Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/xception/xception_weights_tf_dim_ordering_tf_kernels_notop.h5


In [13]:
# Freeze the base model's layers so that the next training run only updates
# the newly added Dense head and leaves the pre-trained weights untouched.

for layer in base_model.layers:
  layer.trainable = False

In [14]:
# Compile for the first training phase (base layers frozen): SGD with
# momentum, sparse categorical cross-entropy loss, accuracy as the metric.
model.compile(
    optimizer=tf.keras.optimizers.SGD(learning_rate=0.01, momentum=0.9),
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy'],
)

In [15]:
# Phase 1: train for 5 epochs with the base frozen, validating on `val`
# after each epoch.
history = model.fit(
    train,
    validation_data=val,
    epochs=5,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


In [16]:
# Unfreeze the base model's layers so that the whole network, including the
# pre-trained base, is updated during the next training run.

for layer in base_model.layers:
  layer.trainable = True

In [17]:
# Recompile after changing the layers' `trainable` flags, this time for
# fine-tuning the whole network. The learning rate is lowered by a factor of
# ten compared with the head-only phase (0.01 -> 0.001): large updates on the
# now-trainable pre-trained layers could wreck the features they have already
# learned.
model.compile(loss='sparse_categorical_crossentropy',
              optimizer=tf.keras.optimizers.SGD(learning_rate=0.001, momentum=0.9),
              metrics=['accuracy'])

In [18]:
# Phase 2: train the model for 5 more epochs, validating on `val` after each
# epoch.
# NOTE: this rebinds `history`, so the History object returned by the earlier
# `model.fit` call is no longer reachable through that name.
history = model.fit(
    train,
    validation_data=val,
    epochs=5,
)

Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5
