In [None]:

import os

import tensorflow as tf

# Global TensorFlow seed so dataset shuffling, dropout and weight initialisation
# are repeatable across Restart & Run All (some GPU kernels may still be
# nondeterministic).
SEED = 42
tf.random.set_seed(SEED)

# Output folders for the saved figure and model.
for output_dir in ('results', 'models'):
    os.makedirs(output_dir, exist_ok=True)


In [None]:

import tensorflow_datasets as tfds

# Fetch tf_flowers as (image, label) tuples together with its DatasetInfo.
# Only the 'train' split is used below; validation is carved out of it.
dataset, info = tfds.load(
    name='tf_flowers',
    as_supervised=True,
    with_info=True,
)


In [None]:

AUTOTUNE = tf.data.AUTOTUNE
IMG_SIZE = 224
BATCH_SIZE = 32
VAL_FRACTION = 0.1    # share of the labelled images held out for validation
SPLIT_SEED = 42       # fixes which images land in the validation split
SHUFFLE_BUFFER = 1024


def format_image(image, label):
    """Resize one (image, label) example to the model's input size.

    Args:
        image: decoded image tensor from tfds (presumably uint8, (H, W, 3) —
            the model cell declares a 3-channel input).
        label: integer class label; passed through unchanged.

    Returns:
        ``(image, label)`` with ``image`` resized to ``(IMG_SIZE, IMG_SIZE)``.
        Pixel values are deliberately left in [0, 255]. Keras' EfficientNet
        models rescale and normalise their inputs internally, so dividing by
        255 here would scale the pixels twice.
    """
    image = tf.image.resize(image, (IMG_SIZE, IMG_SIZE))  # bilinear resize -> float32
    return image, label


num_examples = info.splits['train'].num_examples
val_size = int(VAL_FRACTION * num_examples)
assert 0 < val_size < num_examples, 'validation split must be non-empty and leave training data'

# Shuffle ONCE, deterministically. With a fixed seed and
# reshuffle_each_iteration=False, every pass over `ds` yields the same order,
# so take()/skip() below always pick the same disjoint sets of images. The
# default (reshuffle every epoch) would move images between the train and
# validation splits each epoch and leak training data into validation.
ds = dataset['train'].shuffle(
    SHUFFLE_BUFFER, seed=SPLIT_SEED, reshuffle_each_iteration=False
)

# Only the training subset is reshuffled every epoch, after the split.
train_ds = (
    ds.skip(val_size)
    .shuffle(SHUFFLE_BUFFER)
    .map(format_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
val_ds = (
    ds.take(val_size)
    .map(format_image, num_parallel_calls=AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(AUTOTUNE)
)
num_classes = info.features['label'].num_classes


In [None]:

from tensorflow.keras import layers

# Pretrained ImageNet backbone without its classifier head, frozen for stage 1.
# NOTE(review): Keras documents EfficientNet as rescaling its inputs internally
# and expecting raw [0, 255] pixels — keep the input pipeline consistent with that.
base_model = tf.keras.applications.EfficientNetB0(
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
    include_top=False,
    weights='imagenet',
)
base_model.trainable = False

# New classification head on top of the backbone. training=False runs the
# backbone (including its BatchNormalization layers) in inference mode.
inputs = layers.Input(shape=(IMG_SIZE, IMG_SIZE, 3))
backbone_features = base_model(inputs, training=False)
head = layers.GlobalAveragePooling2D()(backbone_features)
head = layers.Dropout(0.3)(head)
outputs = layers.Dense(num_classes, activation='softmax')(head)

model = tf.keras.Model(inputs=inputs, outputs=outputs)
# Class-index labels with softmax probabilities -> sparse categorical cross-entropy.
model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)
model.summary()


In [None]:

# Stage 1: train only the new head while the backbone stays frozen.
INITIAL_EPOCHS = 6
history = model.fit(
    x=train_ds,
    validation_data=val_ds,
    epochs=INITIAL_EPOCHS,
)


In [None]:

# Stage 2: unfreeze the top of the backbone and fine-tune at a low learning rate.
base_model.trainable = True
fine_tune_at = 100  # backbone layers with index < fine_tune_at stay frozen
for layer in base_model.layers[:fine_tune_at]:
    layer.trainable = False
# Keep BatchNormalization layers frozen in the unfrozen part too. Training
# their scale/offset on a small dataset tends to destabilise the pretrained
# features. (The backbone is called with training=False in the model cell, so
# their moving statistics are not updated in either case.)
for layer in base_model.layers[fine_tune_at:]:
    if isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = False

# Re-compile so the changed `trainable` flags take effect.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),
              loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Continue epoch numbering after stage 1 so logs and callbacks line up. This
# still runs exactly `fine_tune_epochs` additional epochs.
initial_epochs = len(history.epoch)
fine_tune_epochs = 6
history_fine = model.fit(train_ds, validation_data=val_ds,
                         epochs=initial_epochs + fine_tune_epochs,
                         initial_epoch=initial_epochs)


In [None]:

# Persist the fine-tuned model. The '.h5' suffix selects the legacy HDF5 format;
# NOTE(review): newer Keras versions prefer the native '.keras' format.
model.save('models/flowers_efficientnetb0.h5')

import matplotlib.pyplot as plt

# Put both training stages on one 1-based epoch axis. The fine-tuning curves
# start right after the last stage-1 epoch instead of being drawn over the
# stage-1 curves from epoch 0.
n_initial = len(history.epoch)

fig, ax = plt.subplots(figsize=(6, 4))
for hist, start, stage in ((history, 0, 'head only'), (history_fine, n_initial, 'fine-tune')):
    for key, split in (('accuracy', 'train'), ('val_accuracy', 'val')):
        values = hist.history.get(key, [])
        ax.plot(range(start + 1, start + 1 + len(values)), values, label=f'{split} ({stage})')
if n_initial:
    # Visual marker for where the backbone was unfrozen.
    ax.axvline(n_initial + 0.5, color='gray', linestyle='--', linewidth=1)
ax.set(title='tf_flowers: EfficientNetB0 transfer learning',
       xlabel='epoch', ylabel='accuracy')
ax.legend()
fig.savefig('results/flowers_transfer_accuracy.png', bbox_inches='tight')
plt.show()
