In [29]:
import os
import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Input, Dense, Flatten, Dropout
from tensorflow.keras.applications import EfficientNetB0
from sklearn.model_selection import train_test_split

In [None]:
# Define constants
#IMAGE_DIR = "C:\\Users\\willi\\Documents\\Drexel\\Fall Quart 5\\MEM 679 - Machine Learning\\pokemon_images\\Pokemon_Dataset_Subset"
IMAGE_DIR = "C:\\Users\\willi\\Documents\\Drexel\\Fall Quart 5\\MEM 679 - Machine Learning\\pokemon_images\\Pokemon_Dataset_Full"
IMG_SIZE = (224, 224) #Images will be resized to 244 x 244
BATCH_SIZE = 32 #Number of images to process at a time

In [31]:
# Function to parse labels from filenames
def parse_filename(filename):
    """
    Parses the filename to extract metadata.

    Args:
        filename (str): The name of the image file.

    Returns:
        pokemon_name (str): Name of Pokemon
        shiny_form (int): 1 if Pokemon is shiny, 0 if Pokemon is Normal
        gender (str): Gender of Pokemon

    """
    parts = filename.replace('.png', '').split(' ')
    location_name = parts[0]
    shiny = 1 if 'Shiny' in parts else 0
    gender = 'Unknown'
    if 'Male & Female' in filename:
        gender = 'Male & Female'
    elif 'Male' in filename:
        gender = 'Male'
    elif 'Female' in filename:
        gender = 'Female'
    
    location, name = location_name.split('_', 1)
    
    return name, shiny, gender

In [32]:
# Load dataset
data = []
labels_name = []
labels_shiny = []
labels_gender = []

for file in os.listdir(IMAGE_DIR):
    if file.endswith((".png")):
        filepath = os.path.join(IMAGE_DIR, file)
        name, shiny, gender = parse_filename(file)
        data.append(filepath)
        labels_name.append(name)
        labels_shiny.append(shiny)
        labels_gender.append(gender)

In [None]:
# Preprocess labels
unique_names = sorted(set(labels_name))
unique_genders = ["Male", "Female", "Male & emale"]

name_to_idx = {name: i for i, name in enumerate(unique_names)}
gender_to_idx = {gender: i for i, gender in enumerate(unique_genders)}

y_name = [name_to_idx[name] for name in labels_name]
y_shiny = labels_shiny
y_gender = [gender_to_idx[gender] for gender in labels_gender]

# Split data
train_data, val_data, train_labels, val_labels = train_test_split(
    data, list(zip(y_name, y_shiny, y_gender)), test_size=0.2, random_state=42
)

In [40]:
# Preprocess images
datagen = ImageDataGenerator(rescale=1.0/255.0)

def preprocess_images(filepaths, labels, batch_size):
    def generator():
        for filepath, label in zip(filepaths, labels):
            image = tf.keras.utils.load_img(filepath, target_size=IMG_SIZE)
            image = tf.keras.utils.img_to_array(image) / 255.0
            # Restructure labels into a dictionary for model outputs
            label_dict = {
                "name_output": label[0],
                "shiny_output": label[1],
                "gender_output": label[2]
            }
            yield image, label_dict
    return tf.data.Dataset.from_generator(
        generator,
        output_signature=(
            tf.TensorSpec(shape=(*IMG_SIZE, 3), dtype=tf.float32),
            {
                "name_output": tf.TensorSpec(shape=(), dtype=tf.int32),
                "shiny_output": tf.TensorSpec(shape=(), dtype=tf.int32),
                "gender_output": tf.TensorSpec(shape=(), dtype=tf.int32),
            },
        )
    ).batch(batch_size)


train_dataset = preprocess_images(train_data, train_labels, BATCH_SIZE)
val_dataset = preprocess_images(val_data, val_labels, BATCH_SIZE)

In [41]:
# Build the model
base_model = EfficientNetB0(include_top=False, input_shape=(*IMG_SIZE, 3))
x = Flatten()(base_model.output)
x = Dropout(0.5)(x)

name_output = Dense(len(unique_names), activation="softmax", name="name_output")(x)
shiny_output = Dense(1, activation="sigmoid", name="shiny_output")(x)
gender_output = Dense(len(unique_genders), activation="softmax", name="gender_output")(x)

model = Model(inputs=base_model.input, outputs=[name_output, shiny_output, gender_output])
model.compile(optimizer="adam", 
              loss={"name_output": "sparse_categorical_crossentropy",
                    "shiny_output": "binary_crossentropy",
                    "gender_output": "sparse_categorical_crossentropy"},
              metrics=["accuracy"])

In [44]:
# Train the model
model.compile(
    optimizer="adam",
    loss={
        "name_output": "sparse_categorical_crossentropy",
        "shiny_output": "binary_crossentropy",
        "gender_output": "sparse_categorical_crossentropy",
    },
    metrics={
        "name_output": ["accuracy"],
        "shiny_output": ["accuracy"],
        "gender_output": ["accuracy"],
    }
)


history = model.fit(train_dataset, validation_data=val_dataset, epochs=10)

Epoch 1/10


      2/Unknown [1m79s[0m 3s/step - gender_output_accuracy: 0.4072 - gender_output_loss: 2.6716 - loss: 9.6399 - name_output_accuracy: 0.0294 - name_output_loss: 5.6554 - shiny_output_accuracy: 0.4951 - shiny_output_loss: 2.0972     



[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m88s[0m 12s/step - gender_output_accuracy: 0.4283 - gender_output_loss: 2.5976 - loss: 10.4041 - name_output_accuracy: 0.0392 - name_output_loss: 5.2978 - shiny_output_accuracy: 0.4935 - shiny_output_loss: 2.0543 - val_gender_output_accuracy: 0.5385 - val_gender_output_loss: 0.6133 - val_loss: 6.3727 - val_name_output_accuracy: 0.0000e+00 - val_name_output_loss: 2.1616 - val_shiny_output_accuracy: 0.5385 - val_shiny_output_loss: 0.4114
Epoch 2/10
[1m2/2[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m13s[0m 5s/step - gender_output_accuracy: 0.6605 - gender_output_loss: 2.0040 - loss: 8.0396 - name_output_accuracy: 0.3787 - name_output_loss: 3.4311 - shiny_output_accuracy: 0.5039 - shiny_output_loss: 1.5535 - val_gender_output_accuracy: 0.5385 - val_gender_output_loss: 0.5855 - val_loss: 7.2457 - val_name_output_accuracy: 0.0769 - val_name_output_loss: 2.6105 - val_shiny_output_accuracy: 0.5385 - val_shiny_output_loss: 0.4268
Epoch