# Pokémon Detector Demo

This notebook demonstrates how to load a trained ConvNeXt model and use it to predict the class of a Pokémon from an input image. The workflow includes loading class names, preparing the model, preprocessing the image, and making a prediction.

## Imports

Import the libraries needed for model loading, image preprocessing, prediction, and the command-line interface.

In [34]:
import numpy as np
import torch
import json
import timm
from torchvision import transforms
from PIL import Image
import click

## Load Class Names

Load the list of Pokémon class names from a JSON file. The order of this list defines the mapping from the model's output indices to class labels.

In [5]:
with open("../data/class_names.json", "r") as f:
    class_names = json.load(f)

print("Class names loaded:", class_names)
print("Number of classes:", len(class_names))

Class names loaded: ['Abra', 'Aerodactyl', 'Alakazam', 'Arbok', 'Arcanine', 'Articuno', 'Beedrill', 'Bellsprout', 'Blastoise', 'Bulbasaur', 'Butterfree', 'Caterpie', 'Chansey', 'Charizard', 'Charmander', 'Charmeleon', 'Clefable', 'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett', 'Ditto', 'Dodrio', 'Doduo', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Eevee', 'Ekans', 'Electabuzz', 'Electrode', 'Exeggcute', 'Exeggutor', 'Farfetchd', 'Fearow', 'Flareon', 'Gastly', 'Gengar', 'Geodude', 'Gloom', 'Golbat', 'Goldeen', 'Golduck', 'Golem', 'Graveler', 'Grimer', 'Growlithe', 'Gyarados', 'Haunter', 'Hitmonchan', 'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur', 'Jigglypuff', 'Jolteon', 'Jynx', 'Kabuto', 'Kabutops', 'Kadabra', 'Kakuna', 'Kangaskhan', 'Kingler', 'Koffing', 'Krabby', 'Lapras', 'Lickitung', 'Machamp', 'Machoke', 'Machop', 'Magikarp', 'Magmar', 'Magnemite', 'Magneton', 'Mankey', 'Marowak', 'Meowth', 'Metapod', 'Mew', 'Mewtwo', 'Moltres', 'MrMime', 'Muk', 'Nidoking', 'Ni
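Output index `i` of the model is reported as `class_names[i]`, so the JSON list must have the same order that was used during training. The following is a minimal sketch of a sanity check. It assumes the file holds a flat list of unique strings, and the helper name `validate_class_names` is hypothetical:

```python
from collections import Counter

def validate_class_names(names):
    """Hypothetical helper: check that class names are a flat list of unique, non-empty strings.

    Returns a dict mapping each output index to its label, mirroring
    `class_names[idx]` in the prediction code.
    """
    if not isinstance(names, list):
        raise TypeError(f"expected a list, got {type(names).__name__}")
    for i, name in enumerate(names):
        if not isinstance(name, str) or not name:
            raise ValueError(f"entry {i} is not a non-empty string: {name!r}")
    duplicates = sorted(n for n, count in Counter(names).items() if count > 1)
    if duplicates:
        raise ValueError(f"duplicate class names: {duplicates}")
    return dict(enumerate(names))


index_to_name = validate_class_names(["Abra", "Aerodactyl", "Alakazam"])
print(index_to_name[2])  # Alakazam
```

In the notebook, calling `validate_class_names(class_names)` right after loading the JSON would catch a malformed or duplicated entry before it turns into a silently wrong label.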

## Model Setup

Define the model configuration: the timm architecture name, the number of output classes, the device (GPU if available, otherwise CPU), and the path to the trained weights.

In [8]:
MODEL_NAME = "convnext_base"
NUM_CLASSES = len(class_names)  # Set based on the number of class names
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if DEVICE.type == "cuda":
    print("GPU in use")
else:
    print("CPU in use")
MODEL_WEIGHT_PATH = "../models/best_model_fold1.pth"  # Path to the model weights

GPU in use
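`NUM_CLASSES` must equal the output size of the saved classifier head. If it does not, `load_state_dict` fails with a size-mismatch error. The following minimal sketch reads the class count from a state dict before the model is built. The key `"head.fc.bias"` is an assumption based on timm's ConvNeXt head naming and may differ for other architectures or checkpoints. The helper `num_classes_in_checkpoint` is hypothetical:

```python
import numpy as np

def num_classes_in_checkpoint(state_dict, head_bias_key="head.fc.bias"):
    """Hypothetical helper: return the output size of the classifier head.

    Works with any mapping whose values expose `.shape` (torch tensors or NumPy arrays).
    The default key assumes timm's ConvNeXt naming; adjust it for other models.
    """
    if head_bias_key not in state_dict:
        head_keys = [k for k in state_dict if k.startswith("head")]
        raise KeyError(f"{head_bias_key!r} not found; head keys present: {head_keys}")
    return int(state_dict[head_bias_key].shape[0])


# Stand-in for torch.load(MODEL_WEIGHT_PATH, map_location="cpu"), built from NumPy arrays
fake_state = {"head.fc.weight": np.zeros((10, 1024)), "head.fc.bias": np.zeros(10)}
print(num_classes_in_checkpoint(fake_state))  # 10
```

With the real checkpoint, the result can be compared against `len(class_names)` before `timm.create_model` is called.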


### Load Model Weights

Load the trained weights into the model and set it to evaluation mode.

In [19]:
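try:
    model = timm.create_model(MODEL_NAME, pretrained=False, num_classes=NUM_CLASSES)
    model.load_state_dict(torch.load(MODEL_WEIGHT_PATH, map_location=DEVICE))
    model.to(DEVICE)
    model.eval()  # disable training-only behavior such as dropout and drop path
    print("Model Loaded")
except Exception as e:
    # Report the error, then re-raise so later cells don't run with an undefined `model`
    print(f"Failed to load model: {e}")
    raise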

Model Loaded
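A common loading failure occurs when the checkpoint was saved from a model wrapped in `torch.nn.DataParallel`. The training code is not shown here, so this is only an assumption. In that case every key carries a `module.` prefix, and `load_state_dict` reports missing and unexpected keys. The following minimal sketch strips that prefix. `strip_prefix` is a hypothetical helper:

```python
def strip_prefix(state_dict, prefix="module."):
    """Hypothetical helper: remove a leading prefix from every key that has it (Python 3.9+)."""
    return {key.removeprefix(prefix): value for key, value in state_dict.items()}


saved = {"module.stem.0.weight": 1, "module.head.fc.bias": 2}
print(strip_prefix(saved))  # {'stem.0.weight': 1, 'head.fc.bias': 2}
```

Applied to this notebook, the load call would become `model.load_state_dict(strip_prefix(torch.load(MODEL_WEIGHT_PATH, map_location=DEVICE)))`. Keys without the prefix pass through unchanged, so the helper is harmless on a checkpoint that was saved normally.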


## Image Preprocessing

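Define the evaluation transform: resize the image to 224×224 pixels, convert it to a tensor, and normalize each channel with the ImageNet mean and standard deviation. This transform should match the preprocessing used during training.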

In [12]:
eval_transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
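For a PIL image, `ToTensor` converts `uint8` pixels in [0, 255] with HWC layout into a CHW float tensor in [0, 1]. `Normalize` then computes `(x - mean) / std` for each channel. The following NumPy sketch reproduces those two steps. It assumes the image has already been resized to 224×224, and the function name is hypothetical:

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def to_normalized_chw(rgb_uint8):
    """Hypothetical helper mimicking ToTensor + Normalize for an (H, W, 3) uint8 RGB array."""
    x = rgb_uint8.astype(np.float32) / 255.0   # ToTensor: scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD     # Normalize: per channel, broadcast over H and W
    return np.transpose(x, (2, 0, 1))          # HWC -> CHW, the layout ToTensor produces


red = np.zeros((224, 224, 3), dtype=np.uint8)
red[..., 0] = 255  # a pure red image
chw = to_normalized_chw(red)
print(chw.shape)  # (3, 224, 224)
print(chw[:, 0, 0])
```

A pure red pixel maps to roughly `(1 - 0.485) / 0.229 ≈ 2.25` in the red channel and to negative values in the other two channels. Values of this size are typical of normalized inputs and are not a sign of a bug.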

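### Prediction Function

Define a function that loads an image, applies the preprocessing transform, runs the model, and returns the top-k predicted classes with their softmax confidences.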

In [None]:
def predict_image(img_path, model, transform, device, class_names, topk=2):
    """
    Load an image, preprocess it, and return the model's top-k predictions.

    Returns:
        list of tuples: [(class_index, class_name, confidence), ...] for the
        top-k classes, sorted by descending softmax confidence.
    """
    image = Image.open(img_path).convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(device)
    model.eval()
    with torch.no_grad():
        output = model(image_tensor)
        probabilities = torch.softmax(output, dim=1)[0]
        topk_results = torch.topk(probabilities, topk)
        results = []
        for idx, conf in zip(topk_results.indices, topk_results.values):
            results.append((idx.item(), class_names[idx.item()], conf.item()))
    return results
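Inside `predict_image`, `torch.softmax` converts the logits into probabilities that sum to 1. `torch.topk` then returns the `k` largest probabilities in descending order. The following NumPy sketch applies the same post-processing to a single logit vector, so the logic can be checked without a model. The function name and the toy class names are hypothetical:

```python
import numpy as np

def topk_from_logits(logits, class_names, k=2):
    """Hypothetical helper: return [(index, name, probability), ...] for the k most likely classes."""
    logits = np.asarray(logits, dtype=np.float64)
    exp = np.exp(logits - logits.max())   # subtract the max for numerical stability
    probs = exp / exp.sum()               # softmax
    top = np.argsort(probs)[::-1][:k]     # indices sorted by descending probability
    return [(int(i), class_names[i], float(probs[i])) for i in top]


results = topk_from_logits([1.0, 3.0, 2.0], ["class_a", "class_b", "class_c"], k=2)
for idx, name, conf in results:
    print(f"Class Index: {idx}, Class Name: {name}, Confidence: {conf:.4f}")
# Class Index: 1, Class Name: class_b, Confidence: 0.6652
# Class Index: 2, Class Name: class_c, Confidence: 0.2447
```

Subtracting the maximum logit does not change the softmax result. It only prevents `np.exp` from overflowing when the logits are large.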

## Load and Predict

Run `predict_image` on a sample image and print its top-4 predictions with their confidence scores.

In [31]:
img_path = "../data/gengar.png"  # Change to your image path
results = predict_image(img_path, model, eval_transform, DEVICE, class_names, topk=4)
print("Predictions for image:", img_path)
for idx, class_name, confidence in results:
    print(f"Class Index: {idx}, Class Name: {class_name}, Confidence: {confidence:.4f}")


Predictions for image: ../data/gengar.png
Class Index: 40, Class Name: Gengar, Confidence: 0.9984
Class Index: 51, Class Name: Haunter, Confidence: 0.0011
Class Index: 114, Class Name: Rhydon, Confidence: 0.0001
Class Index: 45, Class Name: Golduck, Confidence: 0.0000
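The softmax confidences across all classes sum to 1, so a top-1 score like 0.9984 leaves almost no probability for the alternatives. However, the model still assigns some class to any input, including an image that shows none of the trained Pokémon. One simple heuristic is to reject low-confidence top-1 predictions. This is not a reliable out-of-distribution detector. The following is a minimal sketch. The helper name and the 0.5 default threshold are hypothetical, and any real threshold would need tuning on held-out data:

```python
def top_prediction(results, min_confidence=0.5):
    """Hypothetical helper: return the top-1 class name, or None if its confidence is below the threshold.

    `results` uses the [(index, name, confidence), ...] format returned by predict_image,
    sorted by descending confidence.
    """
    if not results:
        return None
    _, name, confidence = results[0]
    return name if confidence >= min_confidence else None


example = [(0, "class_a", 0.91), (1, "class_b", 0.06)]
print(top_prediction(example))                       # class_a
print(top_prediction(example, min_confidence=0.95))  # None
```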


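## Command-Line Interface

Wrap `predict_image` in a `click` command so that the same prediction can be run from a terminal once this code is moved into a script. The command is only defined here. Calling `cli()` directly inside Jupyter would make click parse the kernel's own `sys.argv` instead of an image path.

In [35]:
@click.command()
@click.argument("img_path")
@click.option("--topk", default=2, type=int, help="Number of top predictions to return.")
def cli(img_path, topk):
    """Print the top-k Pokémon predictions for IMG_PATH."""
    results = predict_image(img_path, model, eval_transform, DEVICE, class_names, topk=topk)
    print(f"Predictions for image: {img_path}")
    for idx, class_name, confidence in results:
        print(f"Class Index: {idx}, Class Name: {class_name}, Confidence: {confidence:.4f}")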
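`click.testing.CliRunner` can exercise a click command without a terminal. It calls the command with an explicit argument list and captures the output. The following self-contained sketch uses a stub command with the same arguments as `cli`. `demo_cli` and its placeholder output are hypothetical stand-ins for `cli` and `predict_image`:

```python
import click
from click.testing import CliRunner

@click.command()
@click.argument("img_path")
@click.option("--topk", default=2, type=int, help="Number of top predictions to return.")
def demo_cli(img_path, topk):
    """Stub with the same interface as `cli`; prints placeholder rows instead of running a model."""
    click.echo(f"Predictions for image: {img_path}")
    for idx in range(topk):
        click.echo(f"Class Index: {idx}, Class Name: placeholder_{idx}, Confidence: 0.0000")


runner = CliRunner()
result = runner.invoke(demo_cli, ["example.png", "--topk", "3"])
print(result.exit_code)  # 0
print(result.output)
```

In the notebook, `CliRunner().invoke(cli, ["../data/gengar.png", "--topk", "4"])` would run the real command against the already-loaded model. Once the code is moved into a script, an `if __name__ == "__main__": cli()` guard at the bottom makes it callable from a shell, for example as `python predict.py ../data/gengar.png --topk 4` (the file name is hypothetical).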