# Pokemon Classifier - High Accuracy (Transfer Learning)

This notebook implements a high-accuracy Pokemon classifier using Transfer Learning (MobileNetV2). 
It loads images directly from the folder structure, assuming each subfolder represents a class.

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import cv2

from sklearn.metrics import classification_report, confusion_matrix

import tensorflow as tf
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint

# Seed the global NumPy and TensorFlow random number generators so that
# repeated runs of the notebook are comparable.
SEED = 42  # single value shared by both seeding calls below
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Report the TensorFlow build in use (helps when comparing results).
print(f"TensorFlow Version: {tf.__version__}")

## 1. Configuration

Define paths and hyperparameters. 
**NOTE:** Ensure your dataset is located at `DATASET_PATH` and contains subfolders for each Pokemon.

In [None]:
# --- Image / training hyperparameters ---------------------------------------
IMG_SIZE = (224, 224)  # spatial input size fed to MobileNetV2
CHANNELS = 3           # RGB
BATCH_SIZE = 32
EPOCHS = 20            # upper bound; EarlyStopping in the training cell may stop sooner

# --- Dataset location -------------------------------------------------------
# Expected layout: DATASET_PATH/<class_name>/<image files>, one folder per class.
# BASE_DIR is relative, so paths resolve against the current working directory.
BASE_DIR = '.'
DATASET_PATH = os.path.join(BASE_DIR, 'dataset', 'PokemonData')

print(f"Dataset Path: {DATASET_PATH}")

# Only warn here; later cells check the path again before touching the data.
if not os.path.exists(DATASET_PATH):
    print(f"WARNING: Dataset path not found at {DATASET_PATH}. Please check the path.")

## 2. Data Visualization

Let's visualize some sample images from the directory.

In [None]:
def visualize_samples(dataset_path, num_samples=9):
    """Plot randomly chosen images from a class-per-subfolder dataset.

    Each subplot shows one image drawn from a randomly chosen class folder and
    is titled with that folder's name. Files OpenCV cannot decode are reported
    and skipped (their subplot stays empty) instead of aborting the figure.

    Args:
        dataset_path: Directory whose immediate subdirectories are the classes.
        num_samples: Number of images to draw. The subplot grid is sized to
            fit this many (3x3 for the default of 9).

    Returns:
        None. Returns early without plotting if ``dataset_path`` does not
        exist, contains no class subfolders, or ``num_samples`` is not positive.
    """
    if not os.path.exists(dataset_path):
        return

    # Sorted so that, with a fixed NumPy seed, the same samples are drawn
    # regardless of the filesystem's directory-listing order.
    classes = sorted(
        entry for entry in os.listdir(dataset_path)
        if os.path.isdir(os.path.join(dataset_path, entry))
    )
    if not classes:
        print("No classes found.")
        return
    if num_samples <= 0:
        return

    # Square-ish grid large enough for every sample (a fixed 3x3 grid would
    # make subplot() fail for num_samples > 9).
    n_cols = int(np.ceil(np.sqrt(num_samples)))
    n_rows = int(np.ceil(num_samples / n_cols))

    plt.figure(figsize=(10, 10))

    for i in range(num_samples):
        # Pick a random class, then a random regular file inside it.
        cls = np.random.choice(classes)
        cls_folder = os.path.join(dataset_path, cls)
        images = sorted(
            name for name in os.listdir(cls_folder)
            if os.path.isfile(os.path.join(cls_folder, name))
        )
        if not images:
            continue

        img_path = os.path.join(cls_folder, np.random.choice(images))
        # cv2.imread returns None (it does not raise) for unreadable or
        # unsupported files, so check explicitly rather than swallowing errors.
        img = cv2.imread(img_path)
        if img is None:
            print(f"Skipping unreadable image: {img_path}")
            continue

        plt.subplot(n_rows, n_cols, i + 1)
        plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))  # OpenCV loads BGR
        plt.title(cls)
        plt.axis('off')

    plt.tight_layout()
    plt.show()

# Uncomment to run if images are present
# visualize_samples(DATASET_PATH)

## 3. Data Generators

We use `ImageDataGenerator` with `flow_from_directory` to load images directly from folders. We split the data into 80% training and 20% validation.

In [None]:
# Fraction of each class folder reserved for validation. Both generators below
# use this same value on the same directory, so they partition the files
# identically and the training/validation subsets do not overlap.
VALIDATION_SPLIT = 0.2

# Training pipeline: rescale pixels to [0, 1] and apply random augmentation.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    rotation_range=30,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    fill_mode='nearest',
    validation_split=VALIDATION_SPLIT,  # Use 20% for validation
)

# Validation pipeline: same rescaling and split, but NO augmentation, so
# validation metrics are measured on the unmodified images.
val_datagen = ImageDataGenerator(
    rescale=1./255,
    validation_split=VALIDATION_SPLIT,
)

# Generators
if os.path.exists(DATASET_PATH):
    train_generator = train_datagen.flow_from_directory(
        DATASET_PATH,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',  # one-hot labels for categorical_crossentropy
        subset='training',
        shuffle=True,
    )

    val_generator = val_datagen.flow_from_directory(
        DATASET_PATH,
        target_size=IMG_SIZE,
        batch_size=BATCH_SIZE,
        class_mode='categorical',
        subset='validation',
        shuffle=False,  # fixed order, useful when matching predictions to labels
    )

    NUM_CLASSES = len(train_generator.class_indices)
    print(f"Number of classes: {NUM_CLASSES}")
else:
    print("Dataset not found. Please check DATASET_PATH.")
    # Fallback so the model-building cell can still run without data.
    NUM_CLASSES = 151 # Default fallback

## 4. Model Architecture (Transfer Learning)

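We use **MobileNetV2** pre-trained on ImageNet as the base. The base is frozen initially, and only a new classification head is trained: global average pooling, a 1024-unit dense layer with dropout, and a softmax output.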

In [None]:
def build_model(num_classes):
    """Build a MobileNetV2-based transfer-learning classifier.

    The ImageNet-pretrained MobileNetV2 backbone (without its top) is frozen,
    and a new head is added on top of it: global average pooling, a
    1024-unit ReLU layer with 50% dropout, and a softmax over
    ``num_classes``.

    The model takes images scaled to [0, 1], as produced by the
    ``rescale=1./255`` generators. It maps them internally to [-1, 1], the
    range MobileNetV2's pretrained weights were trained on (see
    ``tf.keras.applications.mobilenet_v2.preprocess_input``).

    Args:
        num_classes: Number of output classes.

    Returns:
        An uncompiled ``tf.keras.Model``.
    """
    input_shape = (IMG_SIZE[0], IMG_SIZE[1], CHANNELS)

    # Base Model (MobileNetV2)
    base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=input_shape)

    # Freeze base model layers initially; only the new head is trained.
    base_model.trainable = False

    inputs = Input(shape=input_shape)
    # [0, 1] -> [-1, 1]  (x * 2 - 1), matching MobileNetV2's expected input range.
    x = tf.keras.layers.Rescaling(scale=2.0, offset=-1.0)(inputs)
    # training=False keeps the backbone's BatchNormalization layers in
    # inference mode, even after the backbone is unfrozen for fine-tuning.
    # This stops the small fine-tuning batches from overwriting the pretrained
    # moving statistics.
    x = base_model(x, training=False)

    # Custom Head
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    x = Dropout(0.5)(x)
    predictions = Dense(num_classes, activation='softmax')(x)

    model = Model(inputs=inputs, outputs=predictions)
    return model

# Instantiate the classifier for the detected number of classes and compile it
# for one-hot (class_mode='categorical') multi-class training of the new head.
model = build_model(NUM_CLASSES)

head_optimizer = Adam(learning_rate=0.001)
model.compile(
    loss='categorical_crossentropy',
    optimizer=head_optimizer,
    metrics=['accuracy'],
)

# Print the layer-by-layer architecture and parameter counts.
model.summary()

## 5. Training

We use callbacks to checkpoint the best model (by validation accuracy), stop training early, and reduce the learning rate when validation loss plateaus.

In [None]:
if os.path.exists(DATASET_PATH):
    # Write the model to disk each time validation accuracy reaches a new best.
    checkpoint = ModelCheckpoint(
        'pokemon_classifier_model_V3.h5',
        verbose=1,
        save_best_only=True,
        monitor='val_accuracy',
        mode='max',
    )

    # Multiply the learning rate by 0.2 after 3 epochs without val_loss
    # improvement, but never go below 1e-6.
    reduce_lr = ReduceLROnPlateau(
        min_lr=1e-6,
        patience=3,
        factor=0.2,
        monitor='val_loss',
        verbose=1,
    )

    # Stop after 5 epochs without val_loss improvement. restore_best_weights
    # asks Keras to roll back to the best-val_loss weights.
    # NOTE(review): the checkpoint above tracks val_accuracy while this tracks
    # val_loss, so the in-memory weights and the checkpoint file may come from
    # different epochs. Confirm this is intended.
    early_stop = EarlyStopping(
        restore_best_weights=True,
        patience=5,
        monitor='val_loss',
        verbose=1,
    )

    # Train only the new head (the backbone is frozen in build_model).
    history = model.fit(
        train_generator,
        validation_data=val_generator,
        epochs=EPOCHS,
        callbacks=[checkpoint, reduce_lr, early_stop],
    )

## 6. Evaluation & Fine-tuning

Visualize training history and optionally fine-tune the model.

In [None]:
if 'history' in locals():
    # Per-epoch metrics recorded by model.fit in the training cell.
    curves = history.history
    acc = curves['accuracy']
    val_acc = curves['val_accuracy']
    loss = curves['loss']
    val_loss = curves['val_loss']
    epochs_range = range(len(acc))

    # One entry per subplot:
    # (train series, val series, train label, val label, legend location, title)
    panels = (
        (acc, val_acc, 'Training Accuracy', 'Validation Accuracy',
         'lower right', 'Training and Validation Accuracy'),
        (loss, val_loss, 'Training Loss', 'Validation Loss',
         'upper right', 'Training and Validation Loss'),
    )

    plt.figure(figsize=(15, 5))
    for position, (train_vals, val_vals, train_lbl, val_lbl, legend_loc, panel_title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(epochs_range, train_vals, label=train_lbl)
        plt.plot(epochs_range, val_vals, label=val_lbl)
        plt.legend(loc=legend_loc)
        plt.title(panel_title)
    plt.show()

In [None]:
# Fine-tuning (Optional)
# Unfreeze the model and continue training with a very low learning rate.
# BatchNormalization layers are kept frozen: in tf.keras, a BN layer with
# trainable=False runs in inference mode. Its pretrained moving mean/variance
# are therefore not overwritten by the small, augmented fine-tuning batches.


def _freeze_batchnorm_layers(layer):
    """Recursively set ``trainable = False`` on every BatchNormalization layer.

    Descends into any sub-layer exposing a ``layers`` attribute (nested
    models). This works whether the backbone is flattened into ``model`` or
    nested inside it.
    """
    for sub_layer in getattr(layer, 'layers', []):
        if isinstance(sub_layer, tf.keras.layers.BatchNormalization):
            sub_layer.trainable = False
        else:
            _freeze_batchnorm_layers(sub_layer)


if os.path.exists(DATASET_PATH):
    print("Starting Fine-tuning...")
    model.trainable = True
    # Must come after `model.trainable = True`, which re-enables every
    # sub-layer (including BN) recursively.
    _freeze_batchnorm_layers(model)

    # Recompile with a low learning rate; recompiling is needed for the
    # trainable changes above to take effect in training.
    model.compile(
        optimizer=Adam(learning_rate=1e-5),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Train for a few more epochs, reusing the callbacks from the training cell.
    history_fine = model.fit(
        train_generator,
        epochs=10,
        validation_data=val_generator,
        callbacks=[checkpoint, reduce_lr, early_stop]
    )

## 7. Save Final Model

Save the final trained model, together with the class-index mapping (`class_indices.json`), for use in the Flask application.

In [None]:
if os.path.exists(DATASET_PATH):
    import json

    # Persist the current in-memory model (after fine-tuning, if it ran).
    # NOTE(review): this uses the same filename as the ModelCheckpoint in the
    # training cell, so it overwrites the best-val_accuracy checkpoint. Confirm
    # this is intended.
    model.save('pokemon_classifier_model_V3.h5')
    print("Final model saved as pokemon_classifier_model_V3.h5")

    # Class-name -> index mapping from flow_from_directory, needed by the
    # Flask app to turn output indices back into Pokemon names.
    class_indices = train_generator.class_indices
    with open('class_indices.json', 'w') as indices_file:
        indices_file.write(json.dumps(class_indices))
    print("Class indices saved as class_indices.json")