In [1]:
from torchvision.io import decode_image
from torchvision.models import vgg19, VGG19_Weights

import glob
from PIL import Image
import os

In [2]:
#image_paths = glob.glob('./modified_imgs/**/*.jpg', recursive=True)
image_paths = glob.glob('./further_modified_imgs/**/*.jpg', recursive=True)
image_paths

['./further_modified_imgs/(chain saw)0.9998_double.jpg',
 './further_modified_imgs/(chain saw)0.9993_double.jpg',
 './further_modified_imgs/(gas pump)0.9999_further_distortion.jpg',
 './further_modified_imgs/(chain saw)0.9999_double.jpg',
 './further_modified_imgs/(church)0.9963_silouette.jpg',
 './further_modified_imgs/(gas pump)0.9994_further_distortion.jpg',
 './further_modified_imgs/(golf ball)1_no_texture.jpg',
 './further_modified_imgs/(chain saw)1_silouette.jpg',
 './further_modified_imgs/(chain saw)0.9992_no_saw.jpg',
 './further_modified_imgs/(parachute)0.9996.jpg']

In [3]:
# Step 1: Initialize model with the best available weights
weights = VGG19_Weights.DEFAULT
model = vgg19(weights=weights)
model.eval()

# Step 2: Initialize the inference transforms
preprocess = weights.transforms()

for img_path in image_paths:
    # Read image
    img = decode_image(img_path)
    
    # Step 3: Apply inference preprocessing transforms
    batch = preprocess(img).unsqueeze(0)
    
    # Step 4: Use the model and print the predicted category
    prediction = model(batch).squeeze(0).softmax(0)
    class_id = prediction.argmax().item()
    score = prediction[class_id].item()
    category_name = weights.meta["categories"][class_id]

    filename = os.path.basename(img_path)
    name = os.path.splitext(filename)[0]

    # Expecting format like "(class name)0.1234"
    try:
        true_label = name.split(")")[0][1:]  # removes '(' at start and ')' at end
    except IndexError:
        true_label = "unknown"

    # Confidence score for the true label
    if true_label in weights.meta["categories"]:
        true_index = weights.meta["categories"].index(true_label)
        true_conf_score = prediction[true_index].item()
    else:
        true_conf_score = None  # label not found

    print(f"file name: {name}, prediction: {category_name}, conf: {score:.4f}, True: {true_label}, True Conf: {true_conf_score:.4f}")
    

file name: (chain saw)0.9998_double, prediction: chain saw, conf: 0.8452, True: chain saw, True Conf: 0.8452
file name: (chain saw)0.9993_double, prediction: chain saw, conf: 0.6866, True: chain saw, True Conf: 0.6866
file name: (gas pump)0.9999_further_distortion, prediction: vending machine, conf: 0.6098, True: gas pump, True Conf: 0.1009
file name: (chain saw)0.9999_double, prediction: chain saw, conf: 0.9409, True: chain saw, True Conf: 0.9409
file name: (church)0.9963_silouette, prediction: stupa, conf: 0.2656, True: church, True Conf: 0.1921
file name: (gas pump)0.9994_further_distortion, prediction: gas pump, conf: 0.5161, True: gas pump, True Conf: 0.5161
file name: (golf ball)1_no_texture, prediction: ping-pong ball, conf: 0.8256, True: golf ball, True Conf: 0.0061
file name: (chain saw)1_silouette, prediction: chain saw, conf: 0.1778, True: chain saw, True Conf: 0.1778
file name: (chain saw)0.9992_no_saw, prediction: chain saw, conf: 0.9323, True: chain saw, True Conf: 0.9323