# Imports 🧭

In [None]:
import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import tensorflow as tf
import matplotlib.pyplot as plt
from PIL import Image
from glob import glob
import cv2
import shutil
import os

- Checking the TensorFlow version and the available accelerators (GPU/TPU).
- Setting the random seed for reproducibility.

In [None]:
# Show the TensorFlow version and the devices TensorFlow can see (CPU/GPU/TPU).
print(tf.__version__)
print(tf.config.list_physical_devices())

SEED = 25081994
# tf.random.set_seed only seeds TensorFlow's RNG. Keras' ImageDataGenerator draws
# its random augmentations and shuffling from NumPy's global RNG, so seed both.
np.random.seed(SEED)
tf.random.set_seed(SEED)

Deleting images left in the working directory by a previous run, if there are any.

In [None]:
# Delete images left in the working directory by a previous run, so that only the
# files produced by the conversion step below are present.
# Pure-Python equivalent of `rm -rf *.jpg *.png`: also works outside IPython and
# without a POSIX shell.
for stale_path in glob('*.jpg') + glob('*.png'):
    if os.path.isdir(stale_path) and not os.path.islink(stale_path):
        shutil.rmtree(stale_path)
    else:
        os.remove(stale_path)

Converting the dataset images to JPEG (transparent PNG backgrounds become white) and saving them to the working directory.

In [None]:
# Location of the dataset's image files.
SOURCE_IMAGES_GLOB = '../input/pokemon-images-and-types/images/images/*'


def flatten_alpha_on_white(img):
    """Composite an image that has an alpha channel onto a white background.

    Args:
        img: array as returned by ``cv2.imread(..., cv2.IMREAD_UNCHANGED)``, with an
            unsigned integer dtype. Only 4-channel (BGRA) images are changed.

    Returns:
        A 3-channel BGR array of the same dtype for BGRA input. Any other image
        (grayscale or 3-channel) is returned unchanged.
    """
    if img.ndim != 3 or img.shape[2] != 4:
        return img
    max_value = np.iinfo(img.dtype).max
    color = img[:, :, :3].astype(np.float64)
    alpha = img[:, :, 3:].astype(np.float64) / max_value
    # "Over" compositing: fully transparent pixels become white and opaque pixels
    # keep their colour. Semi-transparent edges are blended with white instead of
    # being shown at full opacity.
    composited = color * alpha + max_value * (1.0 - alpha)
    return np.rint(composited).astype(img.dtype)


for image_path in glob(SOURCE_IMAGES_GLOB):
    image_file_name = os.path.basename(image_path)
    stem, extension = os.path.splitext(image_file_name)
    if extension.lower() == '.png':
        img = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if img is None:
            raise OSError(f'Could not read image: {image_path}')
        img = flatten_alpha_on_white(img)
        jpg_file_name = stem + '.jpg'
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(jpg_file_name, img, [int(cv2.IMWRITE_JPEG_QUALITY), 100]):
            raise OSError(f'Could not write image: {jpg_file_name}')
    else:
        # Non-PNG files are copied unchanged (presumably already JPEGs — verify
        # against the dataset if new formats appear).
        shutil.copy(image_path, image_file_name)

An image data generator feeds the images to the model.

Random augmentations (rotations, shifts, horizontal flips) improve the network's ability to generalize, which matters most when the dataset is small.

In [None]:
# Training-time image pipeline. The random augmentations (rotation, shifts and
# horizontal flips) help the network generalise from a small dataset. 10% of the
# table is held out as the 'validation' subset.
#
# Pixels are scaled with MobileNetV2's own preprocess_input (to [-1, 1]), the input
# preprocessing the pretrained MobileNetV2 weights expect, instead of per-image
# standardisation.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    preprocessing_function=tf.keras.applications.mobilenet_v2.preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    horizontal_flip=True,
    validation_split=0.1,
)

Loading the table of Pokémon 🐴 names and types, and deriving each image's file name.

In [None]:
POKEMON_CSV_PATH = '../input/pokemon-images-and-types/pokemon.csv'

df = pd.read_csv(POKEMON_CSV_PATH)
# Each row's image file name is its `Name` plus '.jpg' — assumes the image files
# are named after the Pokémon's `Name` value.
df['filename'] = [name + '.jpg' for name in df['Name']]
df

Directory containing the training images:

In [None]:
# The conversion step saves every image under its bare file name, i.e. into the
# current working directory, so the flows read the images from here.
IMAGES_DIRECTORY = './'

Setting up the training and validation flows.

In [None]:
# Seed for shuffling the table and the training batches (same value as the global seed).
SPLIT_SEED = 25081994
BATCH_SIZE = 32

# flow_from_dataframe takes the 'validation' subset as a fixed slice of the
# dataframe in its given order. Shuffle the table reproducibly first, so the
# validation set is a random sample rather than the first rows of pokemon.csv.
shuffled_df = df.sample(frac=1, random_state=SPLIT_SEED).reset_index(drop=True)

# Validation images must not be randomly augmented (rotation/shift/flip),
# otherwise validation metrics fluctuate between epochs for reasons unrelated to
# the model. This generator reuses image_generator's normalisation and split,
# without augmentation.
validation_image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    samplewise_center=image_generator.samplewise_center,
    samplewise_std_normalization=image_generator.samplewise_std_normalization,
    rescale=image_generator.rescale,
    preprocessing_function=image_generator.preprocessing_function,
    # Private attribute, but it is the value flow_from_dataframe reads for
    # `subset`. Sharing it keeps the training and validation subsets disjoint.
    validation_split=image_generator._validation_split,
)

training_flow = image_generator.flow_from_dataframe(
    shuffled_df,
    directory=IMAGES_DIRECTORY,
    x_col='filename',
    y_col='Type1',
    subset='training',
    batch_size=BATCH_SIZE,
    seed=SPLIT_SEED,
)
validation_flow = validation_image_generator.flow_from_dataframe(
    shuffled_df,
    directory=IMAGES_DIRECTORY,
    x_col='filename',
    y_col='Type1',
    subset='validation',
    batch_size=BATCH_SIZE,
    shuffle=False,
)

Pokémon types in the dataset.

In [None]:
# Mapping from Pokémon type (Type1) to the model's output index.
# Both flows must encode labels identically, otherwise validation metrics are
# computed against the wrong classes. Fail loudly here instead of silently
# merging two different mappings.
assert training_flow.class_indices == validation_flow.class_indices, (
    'training and validation flows use different class indices'
)
CLASSES = dict(training_flow.class_indices)
CLASSES

Loading MobileNetV2 without its classification head and adding new output layers for our use case.

The pretrained base layers are frozen for now.

In [None]:
# MobileNetV2 backbone without its original classification head.
base_model = tf.keras.applications.MobileNetV2(include_top=False)
# Freeze the backbone so the first training phase only fits the new head.
base_model.trainable = False

# New head: pool the spatial feature map, one hidden layer, then a softmax over
# the Pokémon types.
pooled = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
hidden = tf.keras.layers.Dense(128, activation='relu')(pooled)
preds = tf.keras.layers.Dense(len(CLASSES), activation='softmax')(hidden)

model = tf.keras.Model(inputs=base_model.input, outputs=preds)
model.compile(optimizer='Adam', loss='categorical_crossentropy', metrics=['accuracy'])

Setting a callback that stops training after 15 consecutive epochs without improvement on the validation set and restores the best weights.

In [None]:
EARLY_STOPPING_PATIENCE = 15

# Stop once the validation loss has not improved for EARLY_STOPPING_PATIENCE
# consecutive epochs, and roll back to the weights of the best epoch.
# Monitoring 'val_loss' (not the training 'loss') is what makes this guard
# against overfitting. The same callback object is reused for both training
# phases; EarlyStopping resets its state at the start of each fit().
callback = tf.keras.callbacks.EarlyStopping(
    monitor='val_loss',
    patience=EARLY_STOPPING_PATIENCE,
    restore_best_weights=True,
)

Training the new layers inserted in the model.

In [None]:
# Train only the new head (the base model is frozen). `epochs` is an upper bound:
# the early-stopping callback ends training once the monitored loss stops
# improving. The returned History is kept so the training curves can be inspected
# or plotted later.
head_history = model.fit(
    training_flow,
    validation_data=validation_flow,
    epochs=500,
    callbacks=[callback],
)

Unfreezing the whole model for fine-tuning.

Using a small learning rate so that the pretrained weights are only adjusted slightly.

In [None]:
# Unfreeze the pretrained base for fine-tuning.
base_model.trainable = True
# Keep the BatchNormalization layers frozen. In TF2 a non-trainable
# BatchNormalization layer runs in inference mode, so its pretrained moving
# mean/variance are not overwritten by the small Pokémon batches during fine-tuning.
for base_layer in base_model.layers:
    if isinstance(base_layer, tf.keras.layers.BatchNormalization):
        base_layer.trainable = False

# Small learning rate so fine-tuning only nudges the pretrained weights.
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
# Re-compile so the changed trainable flags take effect.
model.compile(optimizer=opt, loss='categorical_crossentropy', metrics=['accuracy'])

Some info about the model architecture.

In [None]:
# Layer table with trainable / non-trainable parameter counts after unfreezing the base.
model.summary()

Training the full model.

In [None]:
# Fine-tune the whole model with the small learning rate set above. `epochs` is an
# upper bound; the early-stopping callback decides when to stop. The History is
# kept so the fine-tuning curves can be inspected or plotted later.
fine_tune_history = model.fit(
    training_flow,
    validation_data=validation_flow,
    epochs=500,
    callbacks=[callback],
)

Saving the model for further use.

In [None]:
# Persist the trained model to the path 'model' for later reuse.
tf.keras.models.save_model(model, 'model')