In [None]:
#DataSet for classification models: https://www.kaggle.com/datasets/noodulz/pokemon-dataset-1000
#DataSet for detection model (YOLO): we prepared it using the DataSet (after extracting the images from the folders structure) in https://app.roboflow.com/

In [1]:
import gradio as gr
from PIL import Image
import tempfile
import os
import shutil
import torch
import torchvision.transforms as transforms
from ultralytics import YOLO
from timm import create_model
import numpy as np
import pandas as pd
import torchvision.models as Tmodels
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor, ViTModel
import cv2
from PIL import Image as PILImage
from rembg import remove
import torchvision.utils as vutils

#Check for GPU availability
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

#Paths
metadata_path = "data/pokemon-dataset-1000/metadata.csv"
train_dir = "data/pokemon-dataset-1000/train"
val_dir = "data/pokemon-dataset-1000/val"
test_dir = "data/pokemon-dataset-1000/test"

#Load metadata
metadata = pd.read_csv(metadata_path)
metadata.head()
print("Metadata columns:", metadata.columns)

  check_for_updates()


Using device: cuda
Metadata columns: Index(['label', 'image_path', 'hp', 'attack', 'defense', 'special-attack',
       'special-defense', 'speed'],
      dtype='object')


In [2]:
#Train yolo model first due to ultralytics training automaticly cleaning vram after trainig. training is done with transferlerning on weights from https://github.com/vovod/yolov8-pokemon-object-detection
modelyolo= YOLO("yolov8n.yaml").load("vovodbest.pt")

results = modelyolo.train( data="data.yaml", epochs=50, imgsz=128, device=0, batch=1000, mask_ratio=4, plots=True, name='vovodtransfer',patience=0, close_mosaic=4, workers=4, freeze=17,multi_scale=True)

Transferred 319/355 items from pretrained weights
New https://pypi.org/project/ultralytics/8.3.75 available 😃 Update with 'pip install -U ultralytics'
Ultralytics 8.3.73 🚀 Python-3.12.7 torch-2.7.0.dev20250207+cu128 CUDA:0 (NVIDIA GeForce RTX 5080, 16303MiB)
[34m[1mengine/trainer: [0mtask=detect, mode=train, model=yolov8n.yaml, data=data.yaml, epochs=50, time=None, patience=0, batch=1000, imgsz=128, save=True, save_period=-1, cache=False, device=0, workers=4, project=None, name=vovodtransfer9, exist_ok=False, pretrained=vovodbest.pt, optimizer=auto, verbose=True, seed=0, deterministic=True, single_cls=False, rect=False, cos_lr=False, close_mosaic=4, resume=False, amp=True, fraction=1.0, profile=False, freeze=17, multi_scale=True, overlap_mask=True, mask_ratio=4, dropout=0.0, val=True, split=val, save_json=False, save_hybrid=False, conf=None, iou=0.7, max_det=300, half=False, dnn=False, plots=True, source=None, vid_stride=1, stream_buffer=False, visualize=False, augment=False, agnost

[34m[1mtrain: [0mScanning /home/omerd/dlfinalprojmerged/train/labels.cache... 7018 images, 0 backgrounds, 0[0m

[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, num_output_channels=3, method='weighted_average'), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))



[34m[1mval: [0mScanning /home/omerd/dlfinalprojmerged/valid/labels.cache... 1008 images, 0 backgrounds, 0 c[0m


Plotting labels to runs/detect/vovodtransfer9/labels.jpg... 
[34m[1moptimizer:[0m 'optimizer=auto' found, ignoring 'lr0=0.01' and 'momentum=0.937' and determining best 'optimizer', 'lr0' and 'momentum' automatically... 
[34m[1moptimizer:[0m AdamW(lr=0.002, momentum=0.9) with parameter groups 57 weight(decay=0.0), 64 weight(decay=0.0078125), 63 bias(decay=0.0)
Image sizes 128 train, 128 val
Using 4 dataloader workers
Logging results to [1mruns/detect/vovodtransfer9[0m
Starting training for 50 epochs...

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       1/50      6.13G      1.771      3.264      1.615         50        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008     0.0169      0.992       0.87      0.531






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       2/50       5.7G      1.608      3.093       1.55         50         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.943      0.909       0.94      0.548






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       3/50      4.22G      1.569      2.365      1.445         50         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008       0.97      0.752      0.956       0.66






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       4/50      6.62G      1.452      1.872      1.385         34        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.921      0.928      0.963      0.702






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       5/50      3.93G      1.345       1.54      1.321         44         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.914      0.948      0.975      0.739






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       6/50      3.74G      1.372      1.556      1.281         45        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.878      0.916      0.952      0.687






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       7/50      4.84G      1.346      1.492      1.247         40         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.784       0.74      0.828       0.57






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       8/50      4.05G      1.285      1.386      1.215         54         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.824      0.855      0.909      0.651






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


       9/50      4.69G      1.175      1.154      1.186         41         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.923      0.903      0.967      0.717






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      10/50      4.12G        1.1      1.022      1.157         45        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.894      0.811      0.922      0.647






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      11/50      4.72G      1.193      1.138      1.173         41         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008       0.79      0.863      0.861      0.625






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      12/50      3.68G      1.285       1.22       1.18         48         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.859      0.928      0.953      0.709






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      13/50      4.22G      1.186      1.172       1.17         49         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.907      0.893      0.952      0.662






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      14/50      3.98G      1.089      0.962      1.124         41         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.917       0.85      0.924      0.639






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      15/50      4.73G      1.137      1.027      1.134         47         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.932      0.945       0.97       0.72






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      16/50      2.24G      1.191      1.057      1.135         52        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.924      0.958      0.971      0.741






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      17/50      4.61G      1.216      1.122      1.157         43        192: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.917      0.938      0.963      0.714






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      18/50      4.95G      1.275      1.158      1.159         41        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.742       0.73      0.795      0.483






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      19/50      4.33G      1.255       1.11      1.154         49         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.808       0.75      0.862      0.604






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      20/50      5.49G       1.18      1.029       1.15         43        192: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.943      0.931       0.98      0.741






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      21/50      2.09G      1.181     0.9987      1.132         51        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.915      0.923      0.971      0.725






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      22/50      4.61G      1.141     0.9525      1.138         45        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.782      0.924      0.917      0.688






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      23/50      7.18G      1.088     0.8879      1.121         44         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.935      0.934      0.977      0.756






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      24/50      4.92G      1.056     0.8619      1.118         46        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.928      0.975      0.984       0.76






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      25/50      6.69G       1.01     0.8281      1.105         44        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.966      0.972       0.99      0.766






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      26/50      5.71G      1.067      0.843      1.104         41        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.968      0.973       0.99      0.777






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      27/50         4G      1.013     0.7977      1.093         45        192: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.956      0.974       0.99      0.778






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      28/50       5.4G     0.9602     0.7495      1.083         53        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008       0.97      0.978      0.993      0.777






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      29/50      3.83G      1.088     0.8393      1.098         42         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.963      0.969      0.991      0.789






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      30/50      5.71G     0.9708     0.7517      1.091         53        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.978       0.97      0.993       0.79






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      31/50       3.5G      1.051     0.8067      1.091         44        192: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.973      0.973      0.992      0.784






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      32/50      5.63G      1.021     0.7824      1.086         47        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.971      0.973      0.991      0.768






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      33/50      5.78G      1.074     0.8377      1.091         53         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.961      0.975      0.992      0.782






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      34/50      2.98G       1.07     0.7954      1.092         49        192: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.957      0.961      0.988      0.748






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      35/50      3.81G      1.024     0.7885      1.086         44        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.967      0.967      0.991      0.791






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      36/50      3.66G       1.05     0.7846      1.088         49        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.969      0.972      0.991      0.777






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      37/50       5.1G      1.025      0.751       1.08         45        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008       0.98      0.986      0.994      0.795






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      38/50      4.74G      1.046     0.7778      1.079         47        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008       0.98      0.977      0.993      0.801






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      39/50      3.94G     0.9355     0.6988       1.07         55        192: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.985      0.984      0.994      0.785






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      40/50      6.57G      1.002     0.7215      1.075         37         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.988      0.985      0.994      0.795






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      41/50      5.02G      1.009     0.7307      1.075         46         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.981      0.987      0.993      0.812






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      42/50      5.23G      0.989      0.712       1.07         44        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.981      0.987      0.993      0.806






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      43/50      3.68G      1.003     0.7357      1.076         45        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.983      0.988      0.994      0.803






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      44/50      4.33G     0.9474     0.6974      1.061         47        160: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.984      0.984      0.994      0.802






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      45/50      4.12G     0.9885     0.7261       1.07         39         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.978       0.99      0.994       0.81






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      46/50      3.03G     0.9791     0.7147      1.072         45         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.988      0.986      0.994      0.808





Closing dataloader mosaic
[34m[1malbumentations: [0mBlur(p=0.01, blur_limit=(3, 7)), MedianBlur(p=0.01, blur_limit=(3, 7)), ToGray(p=0.01, num_output_channels=3, method='weighted_average'), CLAHE(p=0.01, clip_limit=(1.0, 4.0), tile_grid_size=(8, 8))

      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      47/50      4.59G     0.8171     0.6972     0.9832         18         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.992      0.996      0.995      0.812






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      48/50      6.13G     0.6831     0.5358     0.9647         18        128: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.993      0.995      0.995      0.815






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      49/50      5.69G     0.6814     0.4918     0.9436         18         96: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.996      0.996      0.995      0.819






      Epoch    GPU_mem   box_loss   cls_loss   dfl_loss  Instances       Size


      50/50      4.72G     0.7319     0.5073     0.9451         18         64: 100%|██████████| 8
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██

                   all       1008       1008      0.996      0.996      0.995      0.823






50 epochs completed in 0.106 hours.
Optimizer stripped from runs/detect/vovodtransfer9/weights/last.pt, 6.2MB
Optimizer stripped from runs/detect/vovodtransfer9/weights/best.pt, 6.2MB

Validating runs/detect/vovodtransfer9/weights/best.pt...
Ultralytics 8.3.73 🚀 Python-3.12.7 torch-2.7.0.dev20250207+cu128 CUDA:0 (NVIDIA GeForce RTX 5080, 16303MiB)
YOLOv8n summary (fused): 168 layers, 3,005,843 parameters, 0 gradients, 8.1 GFLOPs


                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██


                   all       1008       1008      0.995      0.996      0.995      0.821
Speed: 0.0ms preprocess, 0.0ms inference, 0.0ms loss, 0.8ms postprocess per image
Results saved to [1mruns/detect/vovodtransfer9[0m


Yolo Lable Data:

![labels_correlogram.jpg](attachment:88f62d14-1b47-455a-bdb7-d50154eafc75.jpg)


Train Batches(mosiac is closed on last 4 epoches):


![train_batch0.jpg](attachment:00c9bfcf-b55e-4492-ac20-50b05b961090.jpg)


![train_batch368.jpg](attachment:cc19444f-19a7-4837-b299-4f8525286c11.jpg)

Validation Batches:

![val_batch0_labels.jpg](attachment:6c9d8fad-8b44-4b4f-903c-835af1f3c6b1.jpg)
![val_batch0_pred.jpg](attachment:528ca87c-7302-42c5-8c6a-7d6939d5222d.jpg)


![P_curve.png](attachment:2977e30f-d6aa-4272-a69f-a123fb0ff2f4.png)
![PR_curve.png](attachment:745c1395-a7f8-43ca-8e93-a788a4259f09.png)
![R_curve.png](attachment:ac08009b-51a8-42d3-ab9a-b88187b93a80.png)
![results.png](attachment:a3424265-9111-4d77-bf32-119fc13e734c.png)

In [3]:
test_resaults=modelyolo.val(data="data.yaml", imgsz=128, batch=32,device=0,split='test',workers=4)

Ultralytics 8.3.73 🚀 Python-3.12.7 torch-2.7.0.dev20250207+cu128 CUDA:0 (NVIDIA GeForce RTX 5080, 16303MiB)
YOLOv8n summary (fused): 168 layers, 3,005,843 parameters, 0 gradients, 8.1 GFLOPs


[34m[1mval: [0mScanning /home/omerd/dlfinalprojmerged/test/labels.cache... 1974 images, 0 backgrounds, 0 co[0m
                 Class     Images  Instances      Box(P          R      mAP50  mAP50-95): 100%|██


                   all       1974       1974      0.994      0.999      0.995       0.82
Speed: 0.0ms preprocess, 0.4ms inference, 0.0ms loss, 0.8ms postprocess per image
Results saved to [1mruns/detect/vovodtransfer92[0m


![confusion_matrix_normalized.png](attachment:311f0eb7-c4c6-4749-86fd-e200fbfcca4e.png)

![F1_curve.png](attachment:3a5e484d-c756-4978-ac12-14043773ae4f.png)

![P_curve.png](attachment:27cf47cd-61df-45e7-94b7-a05364ea1215.png)

![R_curve.png](attachment:6125d7d2-342d-4af1-92e0-3d667a484544.png)

![PR_curve.png](attachment:9699804b-1b0b-4064-9e12-611dfe9d18ca.png)


![val_batch0_pred.jpg](attachment:c3712dd8-4f06-4965-8821-991763fdf386.jpg)
![val_batch1_labels.jpg](attachment:48a75cc8-10a5-4f5e-952b-bd7f6572bfa0.jpg)
![val_batch1_pred.jpg](attachment:9698bab7-c74b-4e54-bea4-a43ee18a8952.jpg)
![val_batch2_labels.jpg](attachment:5f9b0b09-3da5-41e6-b3f2-e1238484518a.jpg)
![val_batch2_pred.jpg](attachment:10a95d6c-275c-4946-84bb-4539625132f5.jpg)
![val_batch0_labels.jpg](attachment:313b3c22-c979-4095-a3ca-72bd9d14e664.jpg)

In [4]:
#Transform wrapper
class AlbumentationsTransform:
    def __init__(self, transform):
        self.transform = transform

    def __call__(self, img):
        img = np.array(img)  #Convert PIL to NumPy
        transformed = self.transform(image=img)
        return transformed["image"]

#Transformations
def make_transforms(dim):
    transform_train = A.Compose([
        A.SmallestMaxSize(max_size=dim),
        A.Affine(scale=(0.95, 1.05), translate_percent=(-0.05, 0.05), rotate=(-15, 15), p=0.5),
        A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
        A.RandomBrightnessContrast(p=0.5),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])

    transform_val_test = A.Compose([
        A.Resize(dim, dim),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ])
    return transform_train,transform_val_test

#Dataset loader class
class AlbumentationsDataset(Dataset):
    def __init__(self, image_folder, transform=None):
        self.image_folder = image_folder
        self.transform = transform
        self.classes = image_folder.classes

    def __len__(self):
        return len(self.image_folder)

    def __getitem__(self, idx):
        img_path, label = self.image_folder.samples[idx]
        image = Image.open(img_path).convert("RGB")
        image = np.array(image)

        if self.transform:
            transformed = self.transform(image=image)
            image = transformed["image"]

        return image, label

In [5]:
transform_train,transform_val_test = make_transforms(128)

transform_train_vit,transform_val_test_vit = make_transforms(224)

In [6]:
#Define datasets and loaders and show examples from train and validation batches
%matplotlib inline #To show the images
def show_batch(dataloader, title, num_images=20):
    images, labels = next(iter(dataloader))
    images = images[:num_images]
    labels = labels[:num_images]
    print(f"{title} Batch - Labels: {labels.tolist()}")
    print(f"{title} Batch - Image Tensor Shape: {images.shape}")
    grid = vutils.make_grid(images, nrow=5, normalize=True)
    plt.figure(figsize=(10, 10))
    plt.title(title)
    plt.imshow(grid.permute(1, 2, 0))
    plt.axis('off')
    plt.show()

train_image_folder = ImageFolder(train_dir)
val_image_folder = ImageFolder(val_dir)
test_image_folder = ImageFolder(test_dir)

train_dataset = AlbumentationsDataset(train_image_folder, transform=transform_train)
train_dataset_vit = AlbumentationsDataset(train_image_folder, transform=transform_train_vit)
val_dataset = AlbumentationsDataset(val_image_folder, transform=transform_val_test)
val_dataset_vit = AlbumentationsDataset(val_image_folder, transform=transform_val_test_vit)
test_dataset = AlbumentationsDataset(test_image_folder, transform=transform_val_test)
test_dataset_vit = AlbumentationsDataset(test_image_folder, transform=transform_val_test_vit)

train_loader = DataLoader(train_dataset, batch_size=250, shuffle=True)
train_loader_vit = DataLoader(train_dataset_vit, batch_size=250, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=250, shuffle=False)
val_loader_vit = DataLoader(val_dataset_vit, batch_size=250, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=250, shuffle=False)
test_loader_vit = DataLoader(test_dataset_vit, batch_size=250, shuffle=False)

print("Class names:", train_image_folder.classes)

show_batch(train_loader, "Train Batch")
show_batch(val_loader, "Validation Batch")

Class names: ['abomasnow', 'abra', 'absol', 'accelgor', 'aegislash-shield', 'aerodactyl', 'aggron', 'aipom', 'alakazam', 'alcremie', 'alomomola', 'altaria', 'amaura', 'ambipom', 'amoonguss', 'ampharos', 'annihilape', 'anorith', 'appletun', 'applin', 'araquanid', 'arbok', 'arboliva', 'arcanine', 'arceus', 'archen', 'archeops', 'arctibax', 'arctovish', 'arctozolt', 'ariados', 'armaldo', 'armarouge', 'aromatisse', 'aron', 'arrokuda', 'articuno', 'audino', 'aurorus', 'avalugg', 'axew', 'azelf', 'azumarill', 'azurill', 'bagon', 'baltoy', 'banette', 'barbaracle', 'barboach', 'barraskewda', 'basculegion-male', 'basculin-red-striped', 'bastiodon', 'baxcalibur', 'bayleef', 'beartic', 'beautifly', 'beedrill', 'beheeyem', 'beldum', 'bellibolt', 'bellossom', 'bellsprout', 'bergmite', 'bewear', 'bibarel', 'bidoof', 'binacle', 'bisharp', 'blacephalon', 'blastoise', 'blaziken', 'blipbug', 'blissey', 'blitzle', 'boldore', 'boltund', 'bombirdier', 'bonsly', 'bouffalant', 'bounsweet', 'braixen', 'brambl

<Figure size 1000x1000 with 1 Axes>

Validation Batch Batch - Labels: [0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 3, 3, 4, 5, 5, 5, 5, 6, 6]
Validation Batch Batch - Image Tensor Shape: torch.Size([20, 3, 128, 128])


<Figure size 1000x1000 with 1 Axes>

In [7]:
#Define the model
num_classes = len(train_dataset.classes)
model = create_model('rexnet_150', pretrained=True, num_classes=num_classes).to(device)
VITweights= Tmodels.ViT_B_16_Weights.DEFAULT
modelVIT= Tmodels.vit_b_16(weights=VITweights).to(device)

#Freeze ViT model weights
for param in modelVIT.parameters():
    param.requires_grad= False

modelVIT.heads = torch.nn.Linear(in_features=768, out_features=num_classes).to(device)

#Loss and optimizers
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
optimizerVIT = torch.optim.Adam(modelVIT.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
schedulerVIT = torch.optim.lr_scheduler.StepLR(optimizerVIT, step_size=3, gamma=0.1)

In [8]:
def train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, epochs=10):
    train_accuracies = []
    val_accuracies = []

    for epoch in range(epochs):
        model.train()
        train_loss = 0.0
        train_correct = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()

            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            train_correct += (outputs.argmax(dim=1) == labels).sum().item()

        train_accuracy = train_correct / len(train_loader.dataset)
        train_accuracies.append(train_accuracy)  # Store training accuracy

        #Validation step
        model.eval()
        val_correct = 0
        with torch.no_grad():
            for images, labels in val_loader:
                images, labels = images.to(device), labels.to(device)
                outputs = model(images)
                val_correct += (outputs.argmax(dim=1) == labels).sum().item()

        val_accuracy = val_correct / len(val_loader.dataset)
        val_accuracies.append(val_accuracy)  # Store validation accuracy

        print(f"Epoch {epoch+1}/{epochs}, Train Loss: {train_loss:.4f}, "
              f"Train Accuracy: {train_accuracy:.4f}, Validation Accuracy: {val_accuracy:.4f}")

        scheduler.step()

    #Plotting the training and validation accuracy
    plt.plot(range(1, epochs + 1), train_accuracies, marker='o', label='Train Accuracy')
    plt.plot(range(1, epochs + 1), val_accuracies, marker='x', label='Validation Accuracy')
    plt.title('Training and Validation Accuracy Over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Accuracy')
    plt.xticks(range(1, epochs + 1))
    plt.ylim(0, 1)
    plt.grid()
    plt.legend()
    plt.show()

In [9]:
#Train the CNN model
train_model(model, criterion, optimizer, scheduler, train_loader, val_loader, epochs=10)
torch.save(model.state_dict(),"CNNpokemon.pt")

Epoch 1/10, Train Loss: 375.5373, Train Accuracy: 0.2880, Validation Accuracy: 0.0113
Epoch 2/10, Train Loss: 40.1643, Train Accuracy: 0.9012, Validation Accuracy: 0.0185
Epoch 3/10, Train Loss: 7.0773, Train Accuracy: 0.9836, Validation Accuracy: 0.0193
Epoch 4/10, Train Loss: 2.3205, Train Accuracy: 0.9958, Validation Accuracy: 0.0193
Epoch 5/10, Train Loss: 1.2194, Train Accuracy: 0.9984, Validation Accuracy: 0.0193
Epoch 6/10, Train Loss: 0.8239, Train Accuracy: 0.9993, Validation Accuracy: 0.0193
Epoch 7/10, Train Loss: 0.7175, Train Accuracy: 0.9995, Validation Accuracy: 0.0193
Epoch 8/10, Train Loss: 0.6634, Train Accuracy: 0.9996, Validation Accuracy: 0.0193
Epoch 9/10, Train Loss: 0.6618, Train Accuracy: 0.9995, Validation Accuracy: 0.0193
Epoch 10/10, Train Loss: 0.7184, Train Accuracy: 0.9993, Validation Accuracy: 0.0193


<Figure size 640x480 with 1 Axes>

In [10]:
#Train the VIT model
train_model(modelVIT, criterion, optimizerVIT, schedulerVIT, train_loader_vit, val_loader_vit, epochs=9)
torch.save(modelVIT.state_dict(),"VITpokemon.pt")

Epoch 1/9, Train Loss: 445.1782, Train Accuracy: 0.2662, Validation Accuracy: 0.0139
Epoch 2/9, Train Loss: 213.7798, Train Accuracy: 0.7315, Validation Accuracy: 0.0177
Epoch 3/9, Train Loss: 115.2419, Train Accuracy: 0.8695, Validation Accuracy: 0.0181
Epoch 4/9, Train Loss: 79.2685, Train Accuracy: 0.9253, Validation Accuracy: 0.0181
Epoch 5/9, Train Loss: 75.4343, Train Accuracy: 0.9305, Validation Accuracy: 0.0181
Epoch 6/9, Train Loss: 71.8070, Train Accuracy: 0.9351, Validation Accuracy: 0.0185
Epoch 7/9, Train Loss: 69.1174, Train Accuracy: 0.9389, Validation Accuracy: 0.0185
Epoch 8/9, Train Loss: 69.0258, Train Accuracy: 0.9403, Validation Accuracy: 0.0185
Epoch 9/9, Train Loss: 68.4462, Train Accuracy: 0.9408, Validation Accuracy: 0.0185


<Figure size 640x480 with 1 Axes>

In [11]:
#Evaluation on Test
def evaluate_model(model, test_loader):
    model.eval()
    test_correct = 0
    all_preds = torch.zeros(len(test_loader.dataset), dtype=torch.long).to(device) 
    all_labels = torch.zeros(len(test_loader.dataset), dtype=torch.long).to(device) 
    
    with torch.no_grad():
        for batch_idx, (images, labels) in enumerate(test_loader):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            preds = outputs.argmax(dim=1)
            
            start_index = batch_idx * test_loader.batch_size
            end_index = start_index + labels.size(0)
            all_preds[start_index:end_index] = preds
            all_labels[start_index:end_index] = labels
            
            test_correct += (preds == labels).sum().item()

    test_acc = test_correct / len(test_loader.dataset) * 100
    print(f"Test Accuracy: {test_acc:.2f}%")
    
    print("Classification Report:")
    print(classification_report(all_labels.cpu(), all_preds.cpu(), zero_division=0))

    cm = confusion_matrix(all_labels.cpu(), all_preds.cpu())
    
    plt.figure(figsize=(200, 180))  # Adjust figure size as needed
    sns.set(font_scale=0.5)  # Adjust font scale for better readability
    sns.heatmap(cm, annot=False, fmt='d', cmap='Blues', cbar=True)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.title("Confusion Matrix")
    plt.show()

In [12]:
#Evaluate CNN
evaluate_model(model, test_loader)

Test Accuracy: 98.12%
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         4
           1       1.00      0.25      0.40         4
           2       1.00      1.00      1.00         5
           3       1.00      1.00      1.00         2
           4       1.00      1.00      1.00         3
           5       1.00      1.00      1.00         4
           6       1.00      1.00      1.00         5
           7       1.00      1.00      1.00         4
           8       0.80      1.00      0.89         4
           9       1.00      1.00      1.00         2
          10       1.00      1.00      1.00         2
          11       1.00      1.00      1.00         5
          12       1.00      0.67      0.80         3
          13       1.00      1.00      1.00         4
          14       1.00      1.00      1.00         2
          15       1.00      1.00      1.00         4
          16       1.00      1.00   

<Figure size 20000x18000 with 2 Axes>

In [13]:
#Evaluate VIT
evaluate_model(modelVIT, test_loader_vit)

Test Accuracy: 87.37%
Classification Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         4
           1       0.75      0.75      0.75         4
           2       1.00      0.80      0.89         5
           3       1.00      0.50      0.67         2
           4       1.00      1.00      1.00         3
           5       1.00      0.75      0.86         4
           6       0.80      0.80      0.80         5
           7       1.00      0.75      0.86         4
           8       0.67      0.50      0.57         4
           9       1.00      0.50      0.67         2
          10       0.67      1.00      0.80         2
          11       1.00      1.00      1.00         5
          12       1.00      0.67      0.80         3
          13       0.80      1.00      0.89         4
          14       1.00      1.00      1.00         2
          15       0.75      0.75      0.75         4
          16       1.00      0.50   

<Figure size 20000x18000 with 2 Axes>

In [14]:
#Single Pokémon Classification Widget (No YOLO Detection) with Clickable URL

def classify_pokemon(image):
    """
    Given an image containing a single Pokemon, predicts its class using Rexnet and ViT ensamble, 
    and returns an annotaded image along with a url
    to the predicted Pokemon's Bulbapedia page (Pokemon wiki).
    """
    try:
        #Ensure the input is a NumPy array
        if not isinstance(image, np.ndarray):
            image = np.array(image)
        image=np.asarray(remove(Image.fromarray(image)))
        if image.shape[2] == 4:
                image = image[:, :, :3]
        #Prepare inputs for RexNet and VIT using the appropriate transforms
        #For RexNet (input size 128)
        transformed_rex = transform_val_test(image=image)
        input_tensor_rex = transformed_rex["image"].unsqueeze(0).to(device)
        
        #For ViT (input size 224)
        transformed_vit = transform_val_test_vit(image=image)
        input_tensor_vit = transformed_vit["image"].unsqueeze(0).to(device)
        
        #Set to evaluation mode
        model.eval()
        modelVIT.eval()
        with torch.no_grad():
            output_rex = model(input_tensor_rex)
            output_vit = modelVIT(input_tensor_vit)
            
            softmax_rex = torch.softmax(output_rex, dim=1)
            softmax_vit = torch.softmax(output_vit, dim=1)
            
            #Ensemble predictions with equal weighting
            ensemble_weight = 0.5
            ensemble_softmax = ensemble_weight * softmax_rex + (1 - ensemble_weight) * softmax_vit
            pred_idx = ensemble_softmax.argmax(dim=1).item()
            ensemble_conf = ensemble_softmax[0, pred_idx].item()
        
        #Extract individual model confidences and class name
        rexnet_conf = torch.max(softmax_rex).item()
        vit_conf = torch.max(softmax_vit).item()
        class_name = train_image_folder.classes[pred_idx]
        
        #Create clickable URL for Bulbapedia.
        base_url = "http://bulbapedia.bulbagarden.net/wiki/"
        url = f"{base_url}{class_name}_(Pokémon)"
        link_html = f'<a href="{url}" target="_blank">Click here for more info on {class_name}</a>'
        
        #Annotate the image with the prediction details
        annotated_image = image.copy()
        annotated_image_cv = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
        label_text = f"{class_name}: Ens {ensemble_conf*100:.1f}% | RexNet {rexnet_conf*100:.1f}% | ViT {vit_conf*100:.1f}%"
        cv2.putText(annotated_image_cv, label_text, (10, 30),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)
        annotated_image_cv = cv2.cvtColor(annotated_image_cv, cv2.COLOR_BGR2RGB)
        annotated_pil = PILImage.fromarray(annotated_image_cv)
        
        #Prediction details and the clickable link in HTML
        details_html = f"""
        <p>{label_text}</p>
        <p>{link_html}</p>
        """
        return annotated_pil, details_html

    except Exception as e:
        print("Error in classify_pokemon:", e)
        error_img = np.zeros((256, 256, 3), dtype=np.uint8)
        cv2.putText(error_img, "Error occurred", (10, 128),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        return PILImage.fromarray(error_img), "<p>Error occurred.</p>"

#Create a Gradio interface for single Pokémon classification
interface_classify = gr.Interface(
    fn=classify_pokemon,
    inputs=gr.Image(type="numpy", label="Input Image (Single Pokémon)"),
    outputs=[
        gr.Image(type="pil", label="Annotated Image"),
        gr.HTML(label="Prediction Details")
    ],
    title="Pokémon Classifier (Single Pokémon)",
    description="This widget uses RexNet and ViT to classify a single Pokémon image without using YOLO detection."
)

interface_classify.launch()

* Running on local URL:  http://127.0.0.1:7860

To create a public link, set `share=True` in `launch()`.




In [16]:
def run_yolo(image_bgr, conf_threshold, iou_threshold):
    model_yolo = modelyolo
    width = image_bgr.shape[0]
    height = image_bgr.shape[1]
    results = model_yolo.predict(source=image_bgr, imgsz=(width, height), conf=conf_threshold, iou=iou_threshold)
    if len(results) == 0 or not results[0].boxes:
        return np.empty((0, 4)), np.empty((0,))

    try:
        boxes = results[0].boxes.xyxy.cpu().numpy()
        confs = results[0].boxes.conf.cpu().numpy()
    except Exception as e:
        print("Error extracting YOLO boxes/confidences:", e)
        boxes = np.array(results[0].boxes.xyxy)
        confs = np.array(results[0].boxes.conf)
    return boxes, confs

def apply_nms(boxes, confs, conf_threshold, nms_threshold=0.4):
    #Apply Non-Maximum Suppression on detected boxes.
    bboxes = []
    for box in boxes:
        x, y, x2, y2 = box
        bboxes.append([int(x), int(y), int(x2 - x), int(y2 - y)])
    indices = cv2.dnn.NMSBoxes(bboxes, confs.tolist(), score_threshold=conf_threshold, nms_threshold=nms_threshold)
    if len(indices) > 0:
        indices = indices.flatten()
        return boxes[indices], confs[indices]
    else:
        return np.empty((0,4)), np.empty((0,))

def tta_detection(image_bgr, conf_threshold,iou_threshold):
    """
    Perform Test-Time Augmentation by running YOLO on the original and 
    horizontally flipped image, then combine the detections.
    """
    boxes_orig, confs_orig = run_yolo(image_bgr, conf_threshold,iou_threshold)
    image_flip = cv2.flip(image_bgr, 1)
    boxes_flip, confs_flip = run_yolo(image_flip, conf_threshold,iou_threshold)
    
    h, w, _ = image_bgr.shape
    if boxes_flip.shape[0] > 0:
        boxes_flip_adjusted = boxes_flip.copy()
        boxes_flip_adjusted[:, 0] = w - boxes_flip[:, 2]
        boxes_flip_adjusted[:, 2] = w - boxes_flip[:, 0]
    else:
        boxes_flip_adjusted = boxes_flip
    
    if boxes_orig.size > 0 and boxes_flip_adjusted.size > 0:
        combined_boxes = np.concatenate([boxes_orig, boxes_flip_adjusted], axis=0)
        combined_confs = np.concatenate([confs_orig, confs_flip], axis=0)
    elif boxes_orig.size > 0:
        combined_boxes = boxes_orig
        combined_confs = confs_orig
    else:
        combined_boxes = boxes_flip_adjusted
        combined_confs = confs_flip
    
    final_boxes, final_confs = apply_nms(combined_boxes, combined_confs, conf_threshold, nms_threshold=0.4)
    return final_boxes, final_confs

def detect_and_classify(image, conf_threshold=0.1,iou_threshold=0.1):
    """
    Detection and classification with TTA.
    Runs YOLO on the original and flipped images, combines the detections with NMS,
    classifies each region, annotates the image, and returns detailed prediction information.
    """
    ensemble_weight = 0.5  #RexNet weight (ViT gets 1 - ensemble_weight)
    margin = 10            #Margin (pixels) to expand bounding boxes

    detected_details = []  #To accumulate detailed prediction info.

    try:
        #Convert image from RGB to BGR for YOLO.
        image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        boxes, confs = tta_detection(image_bgr, conf_threshold,iou_threshold)
        print(f"Combined detections: {len(boxes)} boxes found.")
        
        if boxes.shape[0] == 0:
            annotated_image = image.copy()
            annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_RGB2BGR)
            cv2.putText(annotated_image, "No Pokémon detected", (20, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
            return PILImage.fromarray(annotated_image), "No Pokémon detected."
        
        annotated_image = image_bgr.copy()
        predictions_count = 0
        
        for i, box in enumerate(boxes):
            if confs[i] < conf_threshold:
                continue
            predictions_count += 1
            
            x1, y1, x2, y2 = map(int, box[:4])
            h_img, w_img, _ = image_bgr.shape
            x1 = max(0, x1 - margin)
            y1 = max(0, y1 - margin)
            x2 = min(w_img, x2 + margin)
            y2 = min(h_img, y2 + margin)
            
            #Use rmbg to remove background from crop to improve prediction results by converting to PIL using remove and converting back to numpy
            crop = np.asarray(remove(Image.fromarray(image[y1:y2, x1:x2])))
            if crop.shape[2] == 4:
                crop = crop[:, :, :3]
            if crop.size == 0 or crop.shape[0] < 20 or crop.shape[1] < 20:
                print(f"Skipping crop {i} due to small size.")
                continue
            
            try:
                transformed_rex = transform_val_test(image=crop)
                crop_tensor_rex = transformed_rex["image"].unsqueeze(0).to(device)
            except Exception as e:
                print(f"Error in RexNet transform for detection {i}:", e)
                continue
            
            try:
                transformed_vit = transform_val_test_vit(image=crop)
                crop_tensor_vit = transformed_vit["image"].unsqueeze(0).to(device)
            except Exception as e:
                print(f"Error in ViT transform for detection {i}:", e)
                continue
            
            model.eval()
            modelVIT.eval()
            with torch.no_grad():
                output_rex = model(crop_tensor_rex)
                output_vit = modelVIT(crop_tensor_vit)
                softmax_rex = torch.softmax(output_rex, dim=1)
                softmax_vit = torch.softmax(output_vit, dim=1)
                ensemble_softmax = ensemble_weight * softmax_rex + (1 - ensemble_weight) * softmax_vit
                pred_idx = ensemble_softmax.argmax(dim=1).item()
                ensemble_conf = ensemble_softmax[0, pred_idx].item()
            
            rexnet_conf = torch.max(softmax_rex).item()
            vit_conf = torch.max(softmax_vit).item()
            class_name = train_image_folder.classes[pred_idx]
            
            #Append detailed info.
            detected_details.append(
                f"{class_name}: Ensemble: {ensemble_conf*100:.1f}% | RexNet: {rexnet_conf*100:.1f}% | ViT: {vit_conf*100:.1f}%"
            )
            
            label_text = (f"{class_name}\n"
                          f"Ens: {ensemble_conf*100:.1f}%\n"
                          f"RexNet: {rexnet_conf*100:.1f}%\n"
                          f"ViT: {vit_conf*100:.1f}%")
            
            cv2.rectangle(annotated_image, (x1, y1), (x2, y2), (0, 255, 0), 2)
            y_offset = 0
            for line in label_text.split("\n"):
                cv2.putText(annotated_image, line, (x1, y1 - 10 - y_offset),
                            cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 255, 0), 2)
                y_offset += 20
        
        if predictions_count == 0:
            cv2.putText(annotated_image, "No predictions above threshold", (20, 30),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), 2)
            print("No predictions passed the confidence threshold.")
            detected_details = ["No predictions above threshold."]
        
        annotated_image = cv2.cvtColor(annotated_image, cv2.COLOR_BGR2RGB)
        annotated_image = np.clip(annotated_image, 0, 255).astype(np.uint8)
        annotated_pil = PILImage.fromarray(annotated_image)
        
        details_str = "\n".join(detected_details)
        return annotated_pil, details_str
    
    except Exception as e:
        print("An error occurred:", e)
        error_img = np.zeros((256, 256, 3), dtype=np.uint8)
        cv2.putText(error_img, "Error occurred", (10, 128),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 255), 2)
        return PILImage.fromarray(error_img), "Error occurred."

# Create the Gradio interface with two outputs.
interface = gr.Interface(
    fn=detect_and_classify,
    inputs=[
        gr.Image(type="numpy", label="Input Image"),
        gr.Slider(0.0, 1.0, value=0.1, label="Confidence Threshold"),
        gr.Slider(0.0, 1.0, value=0.1, label="IoU Threshold") 
    ],
    outputs=[
        gr.Image(type="pil", label="Annotated Image"),
        gr.Textbox(label="Detected Pokémon Details")
    ],
    title="Enhanced Pokémon Detector & Classifier with TTA",
    description=("This tool uses YOLO with test-time augmentation (TTA) to improve detection, then ensembles predictions "
                 "from RexNet and ViT for classification. The output includes detailed percentages for each model's prediction. "
                 "Adjust the confidence and IoU thresholds to fine-tune performance.")
)

interface.launch()

* Running on local URL:  http://127.0.0.1:7862

To create a public link, set `share=True` in `launch()`.





0: 704x1408 1 Pokemon, 242.3ms
Speed: 35.7ms preprocess, 242.3ms inference, 39.0ms postprocess per image at shape (1, 3, 704, 1408)

0: 704x1408 3 Pokemons, 140.2ms
Speed: 36.6ms preprocess, 140.2ms inference, 86.5ms postprocess per image at shape (1, 3, 704, 1408)
Combined detections: 4 boxes found.
