In [None]:
import os
import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torchvision.models import vgg16
from torch.utils.data import DataLoader
from PIL import Image
import shutil

# Pick the compute device: CUDA when PyTorch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# VGG16 with pretrained weights, moved onto the chosen device and switched
# to evaluation mode via .eval().
model = vgg16(pretrained=True)
model = model.to(device)
model.eval()

# Preprocessing applied to every image, in this order: Resize(256),
# CenterCrop(224), conversion to a tensor, then per-channel normalization.
# NOTE(review): the mean/std values are presumably the usual ImageNet channel
# statistics expected by the pretrained model -- confirm against the weights.
_preprocess_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

# Root directory of the images; ImageFolder is pointed at it with the
# preprocessing pipeline attached.
dataset_path = 'imagenet-val/'
dataset = ImageFolder(root=dataset_path, transform=transform)

# Batches of exactly one image, served in the dataset's own order (no shuffle).
loader = DataLoader(dataset=dataset, batch_size=1, shuffle=False)

# Load ImageNet class index to WordNet mapping
import json
from urllib.request import urlopen

# Class-name list: one label per line of the downloaded text file.
# The HTTP response is used as a context manager so the underlying connection
# is closed as soon as the body has been read instead of being leaked.
url = "https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt"
with urlopen(url) as response:
    imagenet_labels = [s.strip() for s in response.read().decode("utf-8").splitlines()]

# JSON mapping from class index (as a string key) to [wnid, name].
wnid_mapping_url = "https://raw.githubusercontent.com/raghakot/keras-vis/refs/heads/master/resources/imagenet_class_index.json"
with urlopen(wnid_mapping_url) as response:
    class_idx = json.load(response)  # {'0': ['n01440764', 'tench'], ...}

# Integer class index -> WordNet ID (first element of each JSON entry).
idx_to_wnid = {int(k): v[0] for k, v in class_idx.items()}

# Output path
output_dir = './filtered_correct_vgg16'
os.makedirs(output_dir, exist_ok=True)

# Class names (dataset folder names) for which an image has already been copied.
saved_wnids = set()

# Stop as soon as this many distinct classes have one saved image each.
target_count = 500

# Copy the first correctly classified image of each class into output_dir.
#
# The dataset's (path, class index) records are walked directly rather than
# through a DataLoader, for two reasons:
#   * the file path and the label always come from the same record, so the
#     lookup no longer depends on the loader using batch_size=1 / shuffle=False;
#   * images of a class that is already covered are skipped *before* they are
#     decoded and pushed through the network, which avoids the bulk of the work.
#
# NOTE(review): comparing dataset.classes entries with idx_to_wnid values
# assumes the dataset's sub-folder names are WordNet IDs -- confirm on disk.
# NOTE(review): files are copied under their basename only, so two source
# files with the same name in different class folders would overwrite each
# other -- confirm the file names are unique across classes.
for i, (original_img_path, label) in enumerate(dataset.samples):
    gt_wnid = dataset.classes[label]
    if gt_wnid in saved_wnids:
        # This class already has its image; no need to load or classify.
        continue

    # dataset[i] loads the file and applies the preprocessing transform;
    # unsqueeze(0) adds the batch dimension the model expects.
    img, _ = dataset[i]
    img = img.unsqueeze(0).to(device)

    with torch.no_grad():
        pred = model(img).argmax(dim=1).item()

    pred_wnid = idx_to_wnid[pred]
    if pred_wnid != gt_wnid:
        continue

    shutil.copy(original_img_path, os.path.join(output_dir, os.path.basename(original_img_path)))
    saved_wnids.add(gt_wnid)
    print(f"[{len(saved_wnids)}] Saved correct prediction for class {gt_wnid}")

    # The count only changes right after a save, so this is the only place
    # the stop condition needs to be checked.
    if len(saved_wnids) >= target_count:
        print(f"✅ Saved {target_count} correct images, 1 per class.")
        break
