In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import os
from PIL import Image

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras import preprocessing
from tensorflow.keras.preprocessing import image
from tensorflow.keras.preprocessing.image import ImageDataGenerator,load_img
from tensorflow.keras.layers.experimental.preprocessing import RandomFlip, RandomRotation
from tensorflow.keras.preprocessing import image_dataset_from_directory

from tensorflow.keras.applications.resnet50 import ResNet50
from tensorflow.keras.applications.resnet50 import preprocess_input, decode_predictions

In [3]:
BATCH_SIZE = 32
IMG_SIZE = (160, 160)
directory = "../input/alpaca-dataset-small/dataset"

# Shared loader settings for both subsets of the 80/20 split. Both calls must use
# the same seed (and shuffle setting) so they agree on which files fall into
# which subset.
_split_kwargs = dict(
    shuffle=True,
    batch_size=BATCH_SIZE,
    image_size=IMG_SIZE,
    validation_split=0.2,
    seed=42,
)
train_dataset = image_dataset_from_directory(directory, subset='training', **_split_kwargs)
validation_dataset = image_dataset_from_directory(directory, subset='validation', **_split_kwargs)

In [4]:
class_names = train_dataset.class_names

# Show the first 9 images of one training batch in a 3x3 grid, titled by class.
plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
    for idx in range(9):
        plt.subplot(3, 3, idx + 1)
        label_name = class_names[labels[idx]]
        plt.imshow(images[idx].numpy().astype("uint8"))
        plt.title(label_name)
        plt.axis("off")

In [5]:
def data_augmenter(flip_mode='horizontal', rotation_factor=0.2):
    """Build a Keras Sequential model that randomly augments images.

    Args:
        flip_mode: Mode forwarded to ``RandomFlip`` (default ``'horizontal'``).
        rotation_factor: Factor forwarded to ``RandomRotation`` (default ``0.2``;
            per the Keras docs this is a fraction of a full 2*pi turn).

    Returns:
        A ``tf.keras.Sequential`` applying a random flip followed by a random
        rotation.
    """
    return tf.keras.Sequential([
        RandomFlip(flip_mode),
        RandomRotation(rotation_factor),
    ])

In [6]:
data_augmentation = data_augmenter()

# Preview 9 random augmentations of a single training image.
# The loop variable is deliberately NOT named `image`: that would overwrite the
# `tensorflow.keras.preprocessing.image` module imported at the top.
for image_batch, _ in train_dataset.take(1):
    plt.figure(figsize=(10, 10))
    # Use element 10 as before, but clamp so a short batch cannot raise IndexError.
    preview_index = min(10, int(image_batch.shape[0]) - 1)
    preview_image = image_batch[preview_index]
    for i in range(9):
        plt.subplot(3, 3, i + 1)
        # The augmenter expects a batch, so add and then strip a leading batch axis.
        augmented_image = data_augmentation(tf.expand_dims(preview_image, 0))
        # Presumably float pixels in 0-255; scale to [0, 1] for imshow.
        plt.imshow(augmented_image[0] / 255)
        plt.axis('off')

In [7]:
# Derive the input shape from IMG_SIZE so the backbone always matches the
# (height, width) the datasets were resized to, plus three colour channels.
input_shape = IMG_SIZE + (3,)

# ImageNet-pretrained ResNet50 without its classification head, frozen as a
# feature extractor (a later cell unfreezes part of it for fine-tuning).
base_model = ResNet50(input_shape=input_shape, include_top=False, weights='imagenet')
base_model.trainable = False

inputs = tf.keras.Input(shape=input_shape)
# Augmentation lives inside the model so it is applied during training.
x = data_augmenter()(inputs)
# ResNet50-specific input normalisation.
x = preprocess_input(x)
# training=False keeps the backbone's BatchNormalization layers in inference
# mode even after the backbone is (partially) unfrozen.
x = base_model(x, training=False)
x = layers.GlobalAveragePooling2D()(x)
x = layers.Dropout(0.2)(x)
# Single unit with no activation: raw logits for binary classification.
outputs = layers.Dense(1)(x)

model = tf.keras.Model(inputs, outputs)

In [8]:
# Fetch the backbone by name rather than by position: `model.layers[4]` silently
# depends on how many layers the augmentation block and preprocess_input insert
# in front of it, and would grab the wrong layer if that ever changes.
base_model2 = model.get_layer(base_model.name)
base_model2.trainable = True

# Backbone layers with index < fine_tune_at stay frozen; later layers are trained.
fine_tune_at = 130

for layer in base_model2.layers[:fine_tune_at]:
    layer.trainable = False

In [9]:
# Depth of the backbone (base_model2) -- the list that fine_tune_at indexes into.
backbone_layer_count = len(base_model2.layers)
print("Number of layers in the model: ", backbone_layer_count)

In [10]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.0001)
# The Dense(1) head outputs raw logits (hence from_logits=True), so the
# decision threshold for accuracy is 0.0. The plain string 'accuracy' would
# resolve to binary accuracy with a 0.5 threshold, miscounting logits in
# (0, 0.5). name='accuracy' keeps the 'accuracy'/'val_accuracy' history keys.
metrics = [tf.keras.metrics.BinaryAccuracy(name='accuracy', threshold=0.0)]
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=optimizer,
              metrics=metrics)

In [11]:
# Keep the best epoch (lowest val_loss) on disk. Only weights are written,
# because the restore step uses model.load_weights on this same path; this also
# avoids serialising the in-model augmentation/preprocessing layers to HDF5.
Checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "save_model.h5",
    monitor="val_loss",
    save_best_only=True,
    save_weights_only=True,
)
# Stop after 10 epochs without a val_loss improvement of at least 0.001.
early_stopping = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",
    patience=10,
    min_delta=0.001,
    restore_best_weights=True,
)

callbacks = [Checkpoint_cb, early_stopping]

In [12]:
# Print the layer-by-layer architecture and parameter counts of the full model.
model.summary()

In [13]:
initial_epochs = 20
# Train with the checkpoint and early-stopping callbacks; per-epoch metrics end
# up in history.history.
history = model.fit(
    train_dataset,
    validation_data=validation_dataset,
    epochs=initial_epochs,
    callbacks=callbacks,
)

In [14]:
# One row per epoch; plot training vs. validation curves as standalone, labelled figures.
history_df = pd.DataFrame(history.history)
loss_ax = history_df.loc[:, ['loss', 'val_loss']].plot(title='Loss per epoch')
loss_ax.set_xlabel('epoch')
loss_ax.set_ylabel('loss')
acc_ax = history_df.loc[:, ['accuracy', 'val_accuracy']].plot(title='Accuracy per epoch')
acc_ax.set_xlabel('epoch')
acc_ax.set_ylabel('accuracy');

In [15]:
# Restore the weights from the checkpoint written by ModelCheckpoint
# (save_best_only=True), i.e. the best epoch by validation loss rather than the last one.
model.load_weights('save_model.h5')

In [16]:
# Report the loss and compiled metrics on the held-out validation subset.
model.evaluate(validation_dataset)