In [None]:
# All imports for the notebook live here (stdlib first, then third-party) so a
# fresh "Restart Kernel -> Run All" does not depend on imports in later cells.
import datetime
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image
from sklearn.metrics import confusion_matrix, f1_score, log_loss, precision_score, recall_score

import tensorflow as tf
from tensorflow.keras.applications.xception import Xception, preprocess_input
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Sanity check that TensorFlow can see a GPU before starting long training runs.
# tf.config.list_physical_devices is the stable replacement for the
# tf.config.experimental alias.
print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))

# 1. Load Data

In [None]:
# Switch into the dataset directory that holds the train/validation/test folders.
# The guard keeps this cell idempotent: re-running it after the first chdir
# would otherwise try to enter a non-existent data/data and raise.
if os.path.basename(os.getcwd()) != 'data':
    os.chdir('data')
os.listdir()

In [None]:
# Build image iterators for the three splits. Each split directory is read with
# flow_from_directory, using one sub-folder per class. Every split applies
# Xception's preprocess_input. Only the training split adds random augmentation.
_flow_kwargs = dict(target_size=(224, 224), batch_size=32, class_mode='categorical')

train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
)
train_ds = train_datagen.flow_from_directory('train', **_flow_kwargs)

val_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
val_ds = val_datagen.flow_from_directory('validation', **_flow_kwargs)

# shuffle=False keeps the sample order fixed so test_ds.classes lines up with
# the order of model.predict output (needed for the confusion matrix).
test_datagen = ImageDataGenerator(preprocessing_function=preprocess_input)
test_ds = test_datagen.flow_from_directory('test', shuffle=False, **_flow_kwargs)


In [None]:
# Invert test_ds.class_indices (class name -> index) into index -> class name,
# so predicted class indices can be mapped back to bird names.
_class_indices = test_ds.class_indices
inv_map = dict(zip(_class_indices.values(), _class_indices.keys()))
inv_map

# 2. Load Xception model and train new layers for bird images

In [None]:
# Instantiate Xception with pre-trained ImageNet weights and without its
# classification head, for 224x224 RGB inputs.
base_model = Xception(
    weights="imagenet",
    include_top=False,
    input_shape=(224, 224, 3),
)

# Freeze the pre-trained backbone so that only the new head is trained first.
base_model.trainable = False

In [None]:
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Flatten, Dropout, GlobalAveragePooling2D

num_classes = 226  # Number of bird labels
# Fail early with a clear message instead of a shape mismatch inside fit().
assert train_ds.num_classes == num_classes, (
    f"train_ds has {train_ds.num_classes} classes, expected {num_classes}")

# New classification head on top of the frozen Xception backbone.
model = Sequential([
    base_model,
    GlobalAveragePooling2D(),
    Dense(2048, activation='relu', kernel_initializer='he_normal'),
    Dropout(0.35),
    Dense(1024, activation='relu', kernel_initializer='he_normal'),
    Dropout(0.35),
    Dense(num_classes, activation='softmax', kernel_initializer='glorot_normal'),
])

# `lr=` is a deprecated alias, and `decay=` is not accepted by the Keras
# optimizers shipped with TF >= 2.11. InverseTimeDecay with decay_steps=1
# reproduces the legacy per-step decay: lr = 0.2 / (1 + 0.01 * iteration).
head_lr = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.2, decay_steps=1, decay_rate=0.01)
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=head_lr, momentum=0.9),
              loss="categorical_crossentropy",
              metrics=['accuracy'])

history = model.fit(train_ds,
                    epochs=30,
                    validation_data=val_ds,
                    workers=3,
                    use_multiprocessing=True)


In [None]:
# Checkpoint the model after the head-only training run (30 epochs).
model.save(filepath='../bird_classifier_xception_30eps')

# 3. Finetune layers

In [None]:
# Unfreeze the Xception backbone for fine-tuning.
# Setting `trainable` on each sub-layer is not enough. base_model itself was
# frozen with `base_model.trainable = False`, and a frozen parent model exposes
# no trainable weights regardless of its sub-layers' flags, so the backbone
# would stay frozen. Setting the flag on base_model also propagates True to
# all of its layers.
base_model.trainable = True

In [None]:
# Fine-tune the whole network with a smaller learning rate. The model is
# recompiled so the changed trainable flags are picked up.
# InverseTimeDecay(decay_steps=1) reproduces the legacy `decay=0.001` argument
# (lr = 0.01 / (1 + 0.001 * iteration)), which TF >= 2.11 optimizers reject;
# `learning_rate` replaces the deprecated `lr` alias.
finetune_lr = tf.keras.optimizers.schedules.InverseTimeDecay(
    initial_learning_rate=0.01, decay_steps=1, decay_rate=0.001)
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=finetune_lr, momentum=0.9),
              loss="categorical_crossentropy",
              metrics=['accuracy'])

history2 = model.fit(train_ds,
                     epochs=20,
                     validation_data=val_ds,
                     workers=3,
                     use_multiprocessing=True)

In [None]:
# Continue fine-tuning for 40 more epochs. There is no recompile here, so
# training keeps the optimizer configured by the previous compile.
history3 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=40,
    use_multiprocessing=True,
    workers=3,
)

In [None]:
# Final 10 epochs of fine-tuning, again without recompiling the model.
history4 = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    use_multiprocessing=True,
    workers=3,
)

In [None]:
# Save the fine-tuned model. Total training is 30 + 20 + 40 + 10 = 100 epochs,
# and this is the path the evaluation section loads the model from.
model.save('../models/bird_classifier_xception_100eps')

# 4. Final tests

The bird classifier achieves an accuracy of 98.5% on the test set. The confusion matrix visualized below shows how these predictions are distributed across the bird classes.



In [None]:
# Reload the final fine-tuned model so the evaluation cells can run without retraining.
model = tf.keras.models.load_model(filepath="../models/bird_classifier_xception_100eps")

In [None]:
# Predict on the *test* split so predictions line up with test_ds.classes.
# (Predicting on val_ds compared validation predictions against test labels.)
# test_ds was built with shuffle=False, so prediction order matches test_ds.classes.
preds = model.predict(test_ds)
y_pred = preds.argmax(axis=-1)
y_true = test_ds.classes
assert len(y_pred) == len(y_true), "prediction count does not match test labels"

In [None]:
# Loss and accuracy of the loaded model on the held-out test split.
model.evaluate(x=test_ds)

In [None]:
# Class-frequency-weighted precision and recall on the test split.
# Imported here so this cell does not depend on a later cell having run first;
# otherwise it raises NameError on a fresh kernel.
from sklearn.metrics import precision_score, recall_score

print("Precision --> {}".format(precision_score(y_true, y_pred, average='weighted')))
print("Recall --> {}".format(recall_score(y_true, y_pred, average='weighted')))

In [None]:
from sklearn.metrics import confusion_matrix, f1_score, log_loss, precision_score, recall_score

# Rows are true classes and columns are predicted classes (sklearn convention).
cm = confusion_matrix(y_true, y_pred)

# Plot the confusion matrix. The cells are counts, so a sequential colormap is
# used. The qualitative 'tab10' palette cycles through 10 unrelated hues and
# hides magnitudes.
fig, ax = plt.subplots(figsize=(14, 14))
sns.heatmap(cm, cmap='viridis', ax=ax)
ax.set(title='Xception bird classifier: confusion matrix on the test set',
       xlabel='Predicted class index',
       ylabel='True class index')

fig.savefig('Bird_Pred_Xception_Confusion_Matrix_Visualized.png', bbox_inches='tight')