In [None]:
import tensorflow as tf
from helper_functions import create_tensorboard_callback, plot_loss_curves, compare_historys, walk_through_dir, unzip_data
import matplotlib.pyplot as plt


In [None]:
# Split directories for the 10%-of-training-data food subset
# (one sub-folder per class inside each split).
train_dir_10_percent, test_dir_10_percent = (
    "10_food_classes_10_percent/train/",
    "10_food_classes_10_percent/test/",
)



In [None]:
# Inspect the dataset root with the project helper `walk_through_dir`
# (presumably prints each sub-directory and how many files it holds — see
# helper_functions) to sanity-check the train/ and test/ splits used below.
walk_through_dir("10_food_classes_10_percent")

In [None]:
# Loader settings shared by both splits: one-hot ("categorical") labels to match
# the CategoricalCrossentropy loss, batches of 32, images resized to 224x224.
dataset_kwargs = dict(
    label_mode="categorical",
    batch_size=32,
    image_size=(224, 224),
)

# Training split is shuffled; the evaluation split keeps a fixed, unshuffled
# order so predictions line up with file order across runs.
train_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(
    train_dir_10_percent, shuffle=True, **dataset_kwargs
)
valid_data_10_percent = tf.keras.preprocessing.image_dataset_from_directory(
    test_dir_10_percent, shuffle=False, **dataset_kwargs
)


In [None]:
# In-model augmentation pipeline. Keras random-augmentation layers only apply
# their transforms in training mode. RandomHeight/RandomWidth rescale the image
# by up to +/-20%, so the spatial size reaching the backbone can vary.
data_augmentation = tf.keras.Sequential()
for augmentation_layer in (
    tf.keras.layers.RandomFlip("horizontal"),
    tf.keras.layers.RandomHeight(0.2),
    tf.keras.layers.RandomWidth(0.2),
    tf.keras.layers.RandomZoom(0.2),
    tf.keras.layers.RandomRotation(0.2),
):
    data_augmentation.add(augmentation_layer)

In [None]:
# Model 1 (feature extraction): a frozen EfficientNetB0 backbone (Keras default
# weights="imagenet") plus a new 10-way softmax head, trained on the 10% subset.
base_model = tf.keras.applications.EfficientNetB0(include_top=False)

# Freeze every backbone weight; only the new head is trained in this phase.
base_model.trainable = False

# NOTE(review): "input_label" looks like a typo for "input_layer"; it only
# affects the layer's display name.
inputs = tf.keras.Input(shape=(224,224,3),name="input_label")

x = data_augmentation(inputs)

# training=False keeps the backbone in inference mode (e.g. BatchNormalization
# uses its stored statistics), which matters once layers are unfrozen later.
x = base_model(x,training=False)

# Collapse the backbone's spatial feature maps to one feature vector per image.
x = tf.keras.layers.GlobalAveragePooling2D(name="globalaveragepooling2D_layer")(x)

# One probability per class; assumes exactly 10 class folders — TODO confirm.
outputs = tf.keras.layers.Dense(10,activation="softmax",name="output_layer")(x)

model_1 = tf.keras.Model(inputs, outputs)

model_1.compile(
    loss = tf.keras.losses.CategoricalCrossentropy(),
    optimizer = tf.keras.optimizers.Adam(),
    metrics = ["accuracy"]
)

# Also reused later as the epoch count the fine-tuning run builds on.
initial_epochs = 10

history_1 = model_1.fit(
    train_data_10_percent,
    epochs=initial_epochs,
    steps_per_epoch=len(train_data_10_percent),
    validation_data = valid_data_10_percent,
    validation_steps = len(valid_data_10_percent),
    callbacks = [create_tensorboard_callback(
        dir_name="transfer_learning_exercise",
        experiment_name="10_food_classes_10_percent_feature_extraction"),
                # Weights are overwritten every epoch (save_best_only=False), so
                # the file ends up holding the final epoch's weights.
                # NOTE(review): Keras 3 (TF >= 2.16) requires weights-only
                # checkpoint paths to end in ".weights.h5"; ".ckpt" only works
                # with tf.keras 2.x.
                tf.keras.callbacks.ModelCheckpoint(
                    "10_food_classes_10_percent_model_1_checkpoint.ckpt",
                    save_best_only=False
                    ,save_freq="epoch"
                    ,save_weights_only=True)]
)

In [None]:
# Visualise the feature-extraction run's History with the project helper
# `plot_loss_curves` (presumably training vs. validation loss/accuracy curves).
plot_loss_curves(history_1)

In [None]:
# Restore the weights written by model_1's ModelCheckpoint callback. It saved
# every epoch with save_best_only=False, so these are the final-epoch weights,
# the same ones the model already holds. Then print the architecture.
model_1.load_weights(filepath="10_food_classes_10_percent_model_1_checkpoint.ckpt")
model_1.summary()

In [None]:
# List model_1's top-level layers with their index and trainable flag.
for idx, lyr in enumerate(model_1.layers):
    print(f"{idx} {lyr.name} {lyr.trainable}")

In [None]:
# Unfreeze the whole backbone for fine-tuning. Setting `trainable` on the
# container propagates to its sub-layers; the change only takes effect after
# the model is recompiled (done in the fine-tuning cell).
base_model.trainable = True

# Inspect the backbone's layers via `base_model` itself rather than
# `model_1.layers[2]`: it is the same object, but the name does not depend on
# where the backbone sits in model_1's layer list.
for i, layer in enumerate(base_model.layers):
    print(i, layer.name, layer.trainable)

In [None]:
# Re-freeze everything except the last 20 backbone layers, so fine-tuning only
# adjusts the top of EfficientNetB0. Addresses the backbone by name instead of
# the positional `model_1.layers[2]`.
# NOTE(review): any BatchNormalization layers among the last 20 still run in
# inference mode because base_model is called with training=False.
for layer in base_model.layers[:-20]:
    layer.trainable = False

In [None]:
# Show the last 20 backbone layers to confirm they are the ones left trainable
# (backbone referenced as `base_model` rather than `model_1.layers[2]`).
for i, layer in enumerate(base_model.layers[-20:]):
    print(i, layer.name, layer.trainable)

In [None]:
# Presumably the full training split for the same 10 classes; used for fine-tuning.
train_dir_all_data = "10_food_classes_all_data/train/"

# Same loader settings as the 10% subset: one-hot labels, 224x224 images,
# batches of 32, shuffled for training.
train_all_data = tf.keras.preprocessing.image_dataset_from_directory(
    directory=train_dir_all_data,
    image_size=(224, 224),
    batch_size=32,
    label_mode="categorical",
    shuffle=True,
)

In [None]:
# Fine-tune model_1 (last 20 backbone layers unfrozen) on the full training set,
# continuing the epoch numbering from the feature-extraction run.

# Total epoch count: the feature-extraction epochs plus 10 fine-tuning epochs.
fine_tuned_initial_epochs = initial_epochs + 10

# Recompile so the trainable-flag changes take effect. The learning rate is 10x
# below Adam's default (1e-3) so the pretrained weights are only nudged.
# `learning_rate` replaces the deprecated `lr` alias, which TF >= 2.11 / Keras 3
# reject with a ValueError.
model_1.compile(
    loss=tf.keras.losses.CategoricalCrossentropy(),
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    metrics=["accuracy"],
)

history_1_fine_tuned = model_1.fit(
    train_all_data,
    epochs=fine_tuned_initial_epochs,
    steps_per_epoch=len(train_all_data),
    # NOTE(review): validates on the 10%-subset test split; presumably the same
    # test images as the full dataset — confirm.
    validation_data=valid_data_10_percent,
    validation_steps=len(valid_data_10_percent),
    # Keras epochs are 0-indexed: resume at the epoch *after* the last completed
    # one. Using history_1.epoch[-1] alone would repeat that epoch and run 11
    # fine-tuning epochs instead of 10.
    initial_epoch=history_1.epoch[-1] + 1,
    callbacks=[
        create_tensorboard_callback(
            dir_name="transfer_learning_exercise",
            experiment_name="10_food_classes_10_percent_model_1_fine_tuned",
        ),
        # Overwritten every epoch, so it keeps the final fine-tuned weights.
        tf.keras.callbacks.ModelCheckpoint(
            "fine_tuned_model_1_exercise_checkpoint.ckpt",
            save_best_only=False,
            save_freq="epoch",
            save_weights_only=True,
        ),
    ],
)

In [None]:
# Plot the feature-extraction and fine-tuning histories back to back. The split
# point comes from the constant that drove the first fit rather than a
# hardcoded 10, so the two stay in sync if initial_epochs changes.
compare_historys(history_1, history_1_fine_tuned, initial_epochs=initial_epochs)