In [10]:
import numpy as np
import pandas as pd
import os,shutil
from keras.utils import image_utils
from keras.preprocessing.image import ImageDataGenerator
from keras.applications.resnet import ResNet50
from keras.models import Sequential
from keras.layers import Dense, GlobalAveragePooling2D
from keras import Model
from sklearn.model_selection import train_test_split
from itertools import cycle
from PIL import Image
from matplotlib import pyplot as plt

In [11]:
# Making new folder for the data
if not os.path.isdir('./res'):
    os.mkdir('./res')
PARENT = './pokemon_data/'
DATA = './res/'
IMG_HEIGHT = 256
IMG_WIDTH = 256
IMG_DIMS = (IMG_WIDTH, IMG_HEIGHT)
IMG_CHANNELS = 3

In [12]:
# Making folders for new train/test split
sets = ['X_train_img','X_test_img','X_train_aug_img']
for folder in sets:
    if not os.path.isdir(os.path.join(DATA,folder)):
        os.mkdir(os.path.join(DATA,folder))

In [13]:
# Putting folders of each class into the new folders
for set in sets:
    for label in os.listdir(PARENT):
        class_path = os.path.join(DATA, set, label)
        os.makedirs(class_path, exist_ok=True)

In [14]:
X = []
y = []
for folder in os.listdir(PARENT):
    folder_path = os.path.join(PARENT, folder)
    for filename in os.listdir(folder_path):
        filepath = os.path.join(folder_path, filename)
        X.append(filepath)
        y.append(folder)

In [15]:
# Making a train/test split
X_train, X_test, y_train, y_test = train_test_split(
    X, 
    y, 
    test_size=0.2, 
    shuffle=True, 
    random_state=1410,
    )

In [16]:
X_train

['./pokemon_data/Exeggutor/34201881295c423597a13648a237785c.jpg',
 './pokemon_data/Spearow/6e2f1ee87cd44fb48b058a33c16616ac.jpg',
 './pokemon_data/Metapod/5fd989ff68744e86949076c9052b5383.jpg',
 './pokemon_data/Diglett/3a67696fa46a470eb8a8902d2ee158a2.jpg',
 './pokemon_data/Golem/395e496d99074239a5c580cc9be55825.jpg',
 './pokemon_data/Squirtle/00000059.png',
 './pokemon_data/Muk/db7686074ea14fc18f4a8e7ad1b1aada.jpg',
 './pokemon_data/Arcanine/148dfa70c8474b43ae987e34813839c2.jpg',
 './pokemon_data/Vaporeon/c7a1528c69b449a2b72b26b3844c2874.jpg',
 './pokemon_data/Dugtrio/9892cb112dff4eed81e41ffae7e77eb1.jpg',
 './pokemon_data/Snorlax/574c1ccc6106499a9551f9fdecda4935.jpg',
 './pokemon_data/Squirtle/00000183.jpg',
 './pokemon_data/Lickitung/11e77092a8ee40faa4c2463195999cd4.jpg',
 './pokemon_data/Pikachu/00000153.gif',
 './pokemon_data/Jynx/b4eedf82e2fe48e19f73653784992fd4.jpg',
 './pokemon_data/Porygon/292164f7f7224e9c8f68769c053b5c96.jpg',
 './pokemon_data/Spearow/b93b78dbe62c4287be9ff244

In [17]:
TRAIN_SET = sets[0]
TEST_SET = sets[1]
def copy_to_new_folders(data, labels, set):
    # If folders already have images don't copy again
    num_imgs = sum([len([filename for filename in os.listdir(os.path.join(DATA, set, label)) ]) for label in labels])
    if num_imgs > 0:
        print(f"Folders in {DATA + set} already contain {num_imgs} images.")
        print("If you wish to copy again remove the folder and rerun the program")
        return
    print(f"COPYING IMAGES to {DATA + set}...")
    for img, label in zip(data, labels):
        filename = img.split('/')[-1]
        src = img
        dst = os.path.join(DATA, set, label, filename)
        shutil.copy(src,dst)
copy_to_new_folders(X_train, y_train, TRAIN_SET)
copy_to_new_folders(X_test, y_test, TEST_SET)

Folders in ./res/X_train_img already contain 1619296 images.
If you wish to copy again remove the folder and rerun the program
Folders in ./res/X_test_img already contain 30480 images.
If you wish to copy again remove the folder and rerun the program


In [18]:
def count_instances(data_path):
    pokemon_all = os.listdir(data_path)
    max = 0
    min = float('inf')
    avg = 0
    for pokemon in pokemon_all:
        p = os.path.join(data_path,pokemon)
        num_pokemon = len(os.listdir(p))
        if num_pokemon > max:
            max = num_pokemon
        if num_pokemon < min:
            min = num_pokemon
        avg += num_pokemon
        print(pokemon + ' count is: ' + str(num_pokemon))
    avg = avg/len(pokemon_all)
    print(f"Max: {max} Min: {min} Avg: {avg}")
    return max
        
most_imgs = count_instances(os.path.join(DATA, TRAIN_SET))

Abra count is: 224
Aerodactyl count is: 224
Alakazam count is: 224
Arbok count is: 224
Arcanine count is: 224
Articuno count is: 224
Beedrill count is: 224
Bellsprout count is: 224
Blastoise count is: 224
Bulbasaur count is: 224
Butterfree count is: 224
Caterpie count is: 224
Chansey count is: 224
Charizard count is: 224
Charmander count is: 224
Charmeleon count is: 224
Clefable count is: 224
Clefairy count is: 224
Cloyster count is: 224
Cubone count is: 224
Dewgong count is: 224
Diglett count is: 224
Ditto count is: 224
Dodrio count is: 224
Doduo count is: 224
Dragonair count is: 224
Dragonite count is: 224
Dratini count is: 224
Drowzee count is: 224
Dugtrio count is: 224
Eevee count is: 224
Ekans count is: 224
Electabuzz count is: 224
Electrode count is: 224
Exeggcute count is: 224
Exeggutor count is: 224
Farfetchd count is: 224
Fearow count is: 224
Flareon count is: 224
Gastly count is: 224
Gengar count is: 224
Geodude count is: 224
Gloom count is: 224
Golbat count is: 224
Goldeen c

In [19]:
# Create ImageDataGenerator for data augmentation
datagen = ImageDataGenerator(rescale=1.0/255,
                            width_shift_range=0.2,
                             height_shift_range=0.2,
                             shear_range=0.2,
                             zoom_range=0.2,
                             horizontal_flip=True,
                             rotation_range=30,
                             fill_mode='nearest'
                            )

In [20]:
def augment_data(data_path):
    for folder in os.listdir(data_path):
        folder_contents = os.listdir(os.path.join(DATA, TRAIN_SET, folder))
        folder_size = len(folder_contents)
        imgs_to_augment = most_imgs - folder_size
        imgs_augmented = 0
        for img in cycle(folder_contents):
            if imgs_to_augment <= imgs_augmented:
                print(f"Augmented {imgs_augmented} images for {folder}")
                break
            img_path = os.path.join(data_path, folder, img)
            aug_src_img = Image.open(img_path)
            aug_src_img = aug_src_img.resize(IMG_DIMS)
            aug_src_img = aug_src_img.convert('RGB')
            aug_src_img = np.array(aug_src_img)
            aug_img = datagen.apply_transform(
                aug_src_img,
                datagen.get_random_transform(aug_src_img.shape))
            filename = f'aug_{folder}_{imgs_augmented}.jpg'
            savepath = os.path.join(data_path, folder, filename)
            aug_img = Image.fromarray(aug_img)
            aug_img.save(savepath)
            imgs_augmented += 1
            
augment_data(os.path.join(DATA, TRAIN_SET))

Augmented 0 images for Abra
Augmented 0 images for Aerodactyl
Augmented 0 images for Alakazam
Augmented 0 images for Arbok
Augmented 0 images for Arcanine
Augmented 0 images for Articuno
Augmented 0 images for Beedrill
Augmented 0 images for Bellsprout
Augmented 0 images for Blastoise
Augmented 0 images for Bulbasaur
Augmented 0 images for Butterfree
Augmented 0 images for Caterpie
Augmented 0 images for Chansey
Augmented 0 images for Charizard
Augmented 0 images for Charmander
Augmented 0 images for Charmeleon
Augmented 0 images for Clefable
Augmented 0 images for Clefairy
Augmented 0 images for Cloyster
Augmented 0 images for Cubone
Augmented 0 images for Dewgong
Augmented 0 images for Diglett
Augmented 0 images for Ditto
Augmented 0 images for Dodrio
Augmented 0 images for Doduo
Augmented 0 images for Dragonair
Augmented 0 images for Dragonite
Augmented 0 images for Dratini
Augmented 0 images for Drowzee
Augmented 0 images for Dugtrio
Augmented 0 images for Eevee
Augmented 0 images 

In [21]:
count_instances(os.path.join(DATA, TRAIN_SET))

Abra count is: 224
Aerodactyl count is: 224
Alakazam count is: 224
Arbok count is: 224
Arcanine count is: 224
Articuno count is: 224
Beedrill count is: 224
Bellsprout count is: 224
Blastoise count is: 224
Bulbasaur count is: 224
Butterfree count is: 224
Caterpie count is: 224
Chansey count is: 224
Charizard count is: 224
Charmander count is: 224
Charmeleon count is: 224
Clefable count is: 224
Clefairy count is: 224
Cloyster count is: 224
Cubone count is: 224
Dewgong count is: 224
Diglett count is: 224
Ditto count is: 224
Dodrio count is: 224
Doduo count is: 224
Dragonair count is: 224
Dragonite count is: 224
Dratini count is: 224
Drowzee count is: 224
Dugtrio count is: 224
Eevee count is: 224
Ekans count is: 224
Electabuzz count is: 224
Electrode count is: 224
Exeggcute count is: 224
Exeggutor count is: 224
Farfetchd count is: 224
Fearow count is: 224
Flareon count is: 224
Gastly count is: 224
Gengar count is: 224
Geodude count is: 224
Gloom count is: 224
Golbat count is: 224
Goldeen c

224

In [22]:
print('TRAIN: ')
train_generator = datagen.flow_from_directory(
                    directory=DATA + TRAIN_SET,
                    class_mode = 'categorical',
                    batch_size = 64,
                    shuffle = True,
                    target_size = IMG_DIMS,
                    )
print('TEST: ')
val_generator = datagen.flow_from_directory(
                    directory=DATA + TEST_SET,
                    class_mode = 'categorical',
                    batch_size = 64,
                    shuffle = True,
                    target_size = IMG_DIMS,
                    )

TRAIN: 
Found 33372 images belonging to 149 classes.
TEST: 
Found 1808 images belonging to 149 classes.


In [23]:
# Download pretrained Resnet50
resnet = ResNet50(include_top=False,weights='imagenet',input_shape = (IMG_WIDTH,IMG_HEIGHT,IMG_CHANNELS))

In [24]:
# Add layers to the Resnet50
layer1 = GlobalAveragePooling2D()(resnet.output)
layer2 = Dense(1024,activation = 'relu')(layer1)
layer3 = Dense(512,activation='relu')(layer2)
layer_out = Dense(149,activation='softmax')(layer3)

model = Model(inputs=resnet.input,outputs = layer_out)
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 256, 256, 3  0           []                               
                                )]                                                                
                                                                                                  
 conv1_pad (ZeroPadding2D)      (None, 262, 262, 3)  0           ['input_1[0][0]']                
                                                                                                  
 conv1_conv (Conv2D)            (None, 128, 128, 64  9472        ['conv1_pad[0][0]']              
                                )                                                                 
                                                                                              

In [25]:
# Compile the model
model.compile(optimizer='adam',loss='categorical_crossentropy',metrics=['accuracy'])

In [None]:
# Train
epochs = 10
hist = model.fit(train_generator,epochs=epochs,validation_data=val_generator)

In [None]:
# Visualise training data
def create_graphs(history, epochs, name):
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']

    loss = history.history['loss']
    val_loss = history.history['val_loss']

    epochs_range = range(epochs)

    plt.figure(figsize=(8, 8))
    plt.subplot(1, 2, 1)
    plt.plot(epochs_range, acc, label='Training Accuracy')
    plt.plot(epochs_range, val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.title('Training and Validation Accuracy')

    plt.subplot(1, 2, 2)
    plt.plot(epochs_range, loss, label='Training Loss')
    plt.plot(epochs_range, val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.xlabel('Epochs')
    plt.title('Training and Validation Loss')
    plt.show()
    plt.savefig(f'{name}.png')

In [None]:
# Save history and model
import pickle
with open('history.pkl', 'wb') as f:
    pickle.dump(hist.history, f)

model.save('resnet50_model.h5')
create_graphs(hist, epochs, 'resnet50_model')