In [None]:
import os
import re

import torch
from PIL import Image
from torchvision import models, transforms
from torchvision.models import ResNet50_Weights

In [None]:
# Run on CUDA when PyTorch can see a GPU; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# ImageNet-pretrained ResNet-18. The explicit weights enum replaces the
# deprecated `pretrained=True` flag, which selects these same IMAGENET1K_V1 weights.
model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
model = model.to(device)
# Switch BatchNorm/Dropout layers to inference behaviour; `;` hides the long model repr.
model.eval();

In [None]:
# Inference preprocessing matching torchvision's documented transform for the
# ResNet IMAGENET1K_V1 weights:
#   1. resize the shorter side to 256 px (aspect ratio preserved),
#   2. center-crop to 224x224,
#   3. scale pixels to [0, 1] (ToTensor),
#   4. normalize with the ImageNet per-channel mean/std.
# Resizing straight to (224, 224) would squash non-square images, which departs
# from the preprocessing these weights were evaluated with.
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


In [None]:
# Root of the evaluation images: each sub-directory holds the images of one
# cat species. Relative path, so it is resolved against the kernel's working directory.
image_dir: str = "cat_species"

In [None]:
# Human-readable ImageNet-1k class names; index i names the model's output logit i.
# Read from the ResNet-18 weights metadata so the label list is tied to the
# resnet18 model actually being evaluated (not another architecture's weights).
imagenet_classes = models.ResNet18_Weights.IMAGENET1K_V1.meta["categories"]

# Lower-case keywords; the evaluation loop compares them against the lower-cased
# predicted class name to decide whether a prediction counts as "cat".
# NOTE(review): "tiger" also hits non-cat ImageNet classes such as "tiger shark"
# and "tiger beetle" — confirm it should stay in this list.
valid_cat_keywords = [
    "cat", "tabby", "tiger", "persian", "egyptian", "lynx"
]

In [None]:
def _is_cat_label(label, keywords):
    """Return True if any keyword occurs in ``label`` as a whole word or phrase.

    Whole-word matching avoids plain-substring hits such as "cat" inside
    "polecat" or "catamaran".
    """
    return any(re.search(rf"\b{re.escape(k)}\b", label) for k in keywords)


# Proxy accuracy: fraction of images whose top-1 ImageNet prediction is a
# cat-like class, walking image_dir/<species>/<image>.
correct = 0
total = 0

# sorted() gives a deterministic processing order across filesystems.
for species in sorted(os.listdir(image_dir)):
    species_path = os.path.join(image_dir, species)
    if not os.path.isdir(species_path):
        continue

    for img_name in sorted(os.listdir(species_path)):
        if not img_name.lower().endswith((".png", ".jpg", ".jpeg")):
            continue

        img_path = os.path.join(species_path, img_name)
        # The context manager closes the underlying file once the image is converted.
        with Image.open(img_path) as img:
            image = transform(img.convert("RGB")).unsqueeze(0).to(device)

        with torch.no_grad():
            pred_idx = model(image).argmax(dim=1).item()

        label = imagenet_classes[pred_idx].lower()
        if _is_cat_label(label, valid_cat_keywords):
            correct += 1
        total += 1

# Fail with a clear message instead of a ZeroDivisionError when nothing was evaluated.
if total == 0:
    raise ValueError(f"No .png/.jpg/.jpeg images found under {image_dir!r}")

accuracy = (correct / total) * 100
print(f"Proxy Accuracy (Cat Detection): {accuracy:.2f}%")