In [14]:
import torch
from torchvision.models import resnet50, ResNet50_Weights
import cv2
from torchvision.transforms import transforms
from PIL import Image
import os
import re

In [15]:
weights = ResNet50_Weights.IMAGENET1K_V1
model = resnet50(weights=weights)

preprocess = weights.transforms()
categories = weights.meta["categories"]

def image_loader(image_name):
    if not os.path.exists(image_name):
        print(f"ERROR: File not found at path: {image_name}")
        return None
        
    image = Image.open(image_name).convert("RGB")
    return image

In [16]:
glare_images = 'glare_images/'
glare_reduction = 'glare_reduction_outputs/'
results = {}
labels = []

In [17]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [18]:
def getLabel(filename):
    match = re.search(r'glare\d+(.*?)\.jpg$', filename)
    if match:
        return match.group(1) # Return the content of the captured group (.*?)
    return None

In [None]:
def get_predictions(input_batch, image_type, file_key, verbose=False):
    with torch.no_grad():
        output = model(input_batch)
        
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    
    # Get Top 5 Predictions
    top5_prob, top5_indices = torch.topk(probabilities, 5)
    
    predictions = []
    for i in range(top5_prob.size(0)):
        predicted_index = top5_indices[i].item()
        predicted_label = categories[predicted_index]
        probability = top5_prob[i].item()
        predictions.append({
            "rank": i + 1,
            "label": predicted_label,
            "probability": probability
        })
    
    # Get the single best prediction (index 0 of the list)
    best_pred = predictions[0]
    
    results[file_key][image_type] = {
        "best_label": best_pred["label"],
        "best_probability": best_pred["probability"],
        "top_5_predictions": predictions
    }
    
    if(verbose):
        print(f"  {image_type.capitalize()} Best Label: **{best_pred['label']}** (P: {best_pred['probability']:.4f})")

In [None]:
verbose = False
print(f"Starting batch prediction on device: {device}")
print("-" * 50)

for filename in os.listdir(glare_images):
    if filename.endswith(".jpg") or filename.endswith(".jpeg") or filename.endswith(".png"):
        
        original_file_path = os.path.join(glare_images, filename)
        processed_filename = 'processed1_' + filename
        processed_file_path = os.path.join(glare_reduction, processed_filename) 
        
        if not os.path.exists(processed_file_path):
            print(f"Warning: Corresponding processed file not found for **{filename}**. Skipping.")
            continue
            
        if(verbose):
            print(f"Processing: **{filename}**")
        
        file_key = os.path.splitext(filename)[0]
        results[file_key] = {
            "original": {},
            "processed": {},
            "label" : getLabel(filename)
        }
        
        try:
            image_original = image_loader(original_file_path)
            image_processed = image_loader(processed_file_path)
            
            input_tensor_original = preprocess(image_original)
            input_tensor_processed = preprocess(image_processed)
            
            input_batch_original = input_tensor_original.unsqueeze(0).to(device)
            input_batch_processed = input_tensor_processed.unsqueeze(0).to(device)
            
            get_predictions(input_batch_original, "original", file_key)
            get_predictions(input_batch_processed, "processed",file_key)
            if(verbose):
                print("-" * 50)
            
        except Exception as e:
            print(f"An error occurred while processing **{filename}**: {e}")
            del results[file_key] # Remove incomplete result
            print("-" * 50)
            continue
            
print("Batch processing complete.")

Starting batch prediction on device: cuda
--------------------------------------------------
Batch processing complete.


In [21]:
def check_label_match(ground_truth, predicted_label):
    gt = ground_truth.lower().replace(" ", "")
    pl = predicted_label.lower().replace(" ", "")
    # print(f"ground truth: {gt} == {pl} : predicted")
    if gt in pl:
        return True
    
    if pl in gt:
        return True
        
    if gt == pl:
        return True
        
    return False

In [None]:
def print_side_by_side_comparison(results_dict, print_results=False):
    max_label_width = 0
    
    # Find the longest label across all Top 5 lists
    for data in results_dict.values():
        for pred in data["original"]["top_5_predictions"]:
            max_label_width = max(max_label_width, len(pred["label"]))
        for pred in data["processed"]["top_5_predictions"]:
            max_label_width = max(max_label_width, len(pred["label"]))

    label_padding = max_label_width + 12
    original_correct = 0
    processed_correct = 0
    processed_improved_list = []
    for file_key, data in results_dict.items():
        original_preds = data["original"]["top_5_predictions"]
        processed_preds = data["processed"]["top_5_predictions"]
        ground_truth = data['label'] 
        if(print_results):
            print("=" * 70)
            print(f"**{file_key}** (Ground Truth: {ground_truth})")
            print("-" * 70)
            
            header_format = f"{'Original Top 5':<{label_padding}} | {'Processed Top 5'}"
            print(header_format)
            print("-" * 70)

        original_is_counted = False
        processed_is_counted = False
        original_prob = 0
        processed_prob = 0
        for i in range(5):
            orig_label = original_preds[i]['label']
            orig_prob = original_preds[i]['probability']
            orig_is_match = check_label_match(ground_truth, orig_label)
            orig_star = "*" if orig_is_match else ""
            if orig_is_match and not original_is_counted:
                original_correct += 1 
                original_is_counted = True
                original_prob = orig_prob
            
            proc_label = processed_preds[i]['label']
            proc_prob = processed_preds[i]['probability']
            proc_is_match = check_label_match(ground_truth, proc_label)
            proc_star = "*" if proc_is_match else ""
            if proc_is_match and not processed_is_counted:
                processed_correct += 1 
                processed_is_counted = True
                processed_prob = proc_prob
            
            orig_output = f"{orig_label}: {orig_prob:.4f}{orig_star}"
            proc_output = f"{proc_label}: {proc_prob:.4f}{proc_star}"
            
            # Use f-string formatting to align the original column based on max width
            comparison_line = f"{orig_output:<{label_padding}} | {proc_output}"
            
            if(print_results):
                print(comparison_line)

        if(original_correct != 0 and processed_correct != 0 and original_prob < processed_prob): 
            processed_improved_list.append({file_key, original_prob, processed_prob})
    if(print_results):
        print("=" * 70)
    return original_correct, processed_correct, processed_improved_list


In [23]:
original_correct, predicted_correct, processed_improved_list = print_side_by_side_comparison(results, print_results=False)

In [24]:
print([original_correct,predicted_correct,len(processed_improved_list)])

[31, 29, 12]


In [25]:
processed_improved_list

[{0.4767727255821228, 0.5010380148887634, 'glare10spotlight'},
 {0.9995918869972229, 0.9996447563171387, 'glare15screwdriver'},
 {0.95809406042099, 0.9622239470481873, 'glare18remotecontrol'},
 {0.6868853569030762, 0.953280508518219, 'glare19ballpoint'},
 {0.9011024832725525, 0.9927190542221069, 'glare22waterbottle'},
 {0.31214749813079834, 0.3436669111251831, 'glare26winebottle'},
 {0.4658496081829071, 0.5057600736618042, 'glare28mixingbowl'},
 {0, 0.5233057141304016, 'glare32refrigerator'},
 {0.25956276059150696, 0.525222659111023, 'glare3lampshade'},
 {0.7205631732940674, 0.742561399936676, 'glare40laptop'},
 {0.9976363182067871, 0.9977924823760986, 'glare48light'},
 {0, 0.5589098334312439, 'glare50seashore'}]