In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
from torchvision.models import resnet50, ResNet50_Weights
from torchvision.transforms import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
import os
import copy
import time
from typing import Dict, List

In [2]:
# Load ResNet-50 initialised with the IMAGENET1K_V1 pretrained weights.
# `weights` is kept as its own name because later cells read its preprocessing
# transforms and its category list.
weights = ResNet50_Weights.IMAGENET1K_V1
model = resnet50(weights=weights)

In [3]:
# Run on the first CUDA device when one is available, otherwise on the CPU,
# and move the model there.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [4]:
# Lookup table from the dataset's local class index to the ImageNet class index
# that is used as the training target. The trailing comment on each entry names
# the local class.
# NOTE(review): presumably the local indices follow ImageFolder's ordering of the
# class sub-directories — confirm against image_datasets['train'].class_to_idx.
imagenet_indices = {
    # Local Index : ImageNet Index
    0: 418,  # ballpoint
    1: 436,  # car
    2: 504,  # coffeemug
    3: 504,  # cup (deliberately shares coffeemug's index, so both local classes train toward one target)
    4: 546,  # electricguitar
    5: 620,  # laptop
    6: 651,  # microwave
    7: 673,  # mouse
    8: 681, # notebook
    9: 837, # sunglasses
    10: 859, # toaster
    11: 898, # waterbottle
}

In [5]:
# Typed copy of imagenet_indices that is handed to the dataset class. Building it
# over range(len(...)) raises KeyError if the local indices are not exactly 0..N-1.
local_to_imagenet_map: Dict[int, int] = {i: imagenet_indices[i] for i in range(len(imagenet_indices))}

In [13]:
# Training configuration.
data_dir = 'finetune_dataset'  # root folder; the dataset cell reads its 'train' and 'val' sub-folders
num_epochs = 100  # number of train+val passes requested from train_model
batch_size = 32  # samples per DataLoader batch
learning_rate = 0.0001 # Critical: Use a low LR for full finetuning

In [7]:
class CustomImageFolder(ImageFolder):
    """
    ImageFolder variant whose targets are global ImageNet class indices.

    ImageFolder gives every class sub-directory a local index. This subclass
    rewrites each sample's target through ``index_map`` so that the labels
    line up with the output units of an ImageNet-pretrained classifier.

    Args:
        root: Directory containing one sub-directory per class.
        transform: Optional callable applied to each loaded image.
        index_map: Mapping from local class index to ImageNet class index.
            Samples whose local class is absent from the mapping are dropped
            (a warning is printed once per missing class).

    Raises:
        ValueError: If ``index_map`` is not provided.
    """
    def __init__(self, root: str, transform=None, index_map: Dict[int, int] = None):
        # Fail fast: check the argument before paying for the directory scan.
        if index_map is None:
            raise ValueError("index_map must be provided for ImageNet finetuning.")

        super().__init__(root, transform)
        self.index_map = index_map

        # Re-map the sample list so each target is the global index.
        remapped = []
        unmapped_classes = set()
        for path, local_idx in self.samples:
            global_idx = index_map.get(local_idx)
            if global_idx is None:
                # Warn once per class instead of once per image.
                if local_idx not in unmapped_classes:
                    unmapped_classes.add(local_idx)
                    print(f"Warning: Missing ImageNet index for local class {local_idx}")
                continue
            remapped.append((path, global_idx))

        self.samples = remapped
        # Keep the attributes derived from `samples` by the parent classes in
        # sync; otherwise `targets` would still hold local indices and `imgs`
        # would still point at the un-remapped list.
        self.imgs = self.samples
        self.targets = [target for _, target in self.samples]

    def __getitem__(self, index: int):
        """
        Return ``(image, target)`` where ``target`` is the ImageNet index.
        """
        path, target = self.samples[index]  # target is already the global ImageNet index
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, target

In [8]:
# Per-split preprocessing. Both splits normalise with the channel statistics
# bundled with the pretrained weights; only the training split is augmented
# (random resized crop + horizontal flip), while validation uses a fixed
# resize followed by a centre crop.
_weights_preset = weights.transforms()

_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_weights_preset.mean, _weights_preset.std),
]

_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_weights_preset.mean, _weights_preset.std),
]

data_transforms = {
    'train': transforms.Compose(_train_steps),
    'val': transforms.Compose(_val_steps),
}

In [9]:
# Build one dataset per split. A missing directory is reported with a hint about
# the expected layout and then re-raised: every later cell needs these objects,
# so continuing would only surface as a confusing NameError further down.
try:
    image_datasets = {
        x: CustomImageFolder(os.path.join(data_dir, x),
                             data_transforms[x],
                             index_map=local_to_imagenet_map)
        for x in ['train', 'val']
    }
except FileNotFoundError as e:
    print(f"\nFATAL ERROR: Could not find data directory or split: {e}")
    print(f"Please ensure your data structure looks like this: {data_dir}/train/ballpoint/... and {data_dir}/val/ballpoint/...")
    raise

# Only the training split is shuffled; validation metrics are sums over the
# whole split, so a fixed order is sufficient there.
dataloaders = {
    x: DataLoader(image_datasets[x], batch_size=batch_size, shuffle=(x == 'train'), num_workers=0)
    for x in ['train', 'val']
}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
    

In [14]:
# Cross-entropy over the model's outputs, with ImageNet indices as targets.
criterion = nn.CrossEntropyLoss()

# Finetuning ALL layers: apply optimizer to all model parameters
optimizer_ft = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)

# Decays the learning rate by a factor of 0.1 every 15 epochs
scheduler = lr_scheduler.StepLR(optimizer_ft, step_size=15, gamma=0.1)

In [15]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=25):
    """Train and validate a model, keeping the weights with the best validation accuracy.

    Each epoch runs a 'train' phase followed by a 'val' phase. The function
    reads the notebook-level ``dataloaders``, ``dataset_sizes`` and ``device``.

    Args:
        model: Network to optimise; it is modified in place.
        criterion: Loss callable invoked as ``criterion(outputs, labels)``.
        optimizer: Optimizer stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch, after the
            training phase.
        num_epochs: Number of train+val passes.

    Returns:
        The same ``model`` object, loaded with the weights from the epoch that
        achieved the highest validation accuracy.
    """
    since = time.time()
    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    print("--- Starting Full Finetuning ---")

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch}/{num_epochs - 1}')
        print('-' * 25)

        for phase in ['train', 'val']:
            is_train = phase == 'train'
            if is_train:
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0  # plain Python int, so accuracy is a plain float

            for inputs, labels in dataloaders[phase]:
                inputs = inputs.to(device)
                labels = labels.to(device)  # labels are ImageNet indices

                # Gradients are only produced (and therefore only need
                # clearing) while training.
                if is_train:
                    optimizer.zero_grad()

                # Forward pass; autograd is enabled only in the training phase.
                with torch.set_grad_enabled(is_train):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    # Backward + optimize only if in training phase
                    if is_train:
                        loss.backward()
                        optimizer.step()

                # Statistics: the loss is weighted by batch size so the epoch
                # figure is a per-sample mean even when the last batch is short.
                running_loss += loss.item() * inputs.size(0)
                running_corrects += (preds == labels).sum().item()

            if is_train:
                scheduler.step()

            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects / dataset_sizes[phase]

            print(f'{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}')

            # Snapshot the weights whenever validation accuracy improves.
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())

    time_elapsed = time.time() - since
    print(f'\nTraining complete in {time_elapsed // 60:.0f}m {time_elapsed % 60:.0f}s')
    print(f'Best val Acc: {best_acc:.4f}')

    # Load best model weights
    model.load_state_dict(best_model_wts)
    return model

In [16]:
# train_model loads the best-validation weights into the model it is given and
# returns that same object, so after this cell `model_ft` and `model` both refer
# to the finetuned network.
model_ft = train_model(model, criterion, optimizer_ft, scheduler, num_epochs=num_epochs)

--- Starting Full Finetuning ---

Epoch 0/99
-------------------------
train Loss: 4.8042 Acc: 0.1915
val Loss: 2.6095 Acc: 0.5306

Epoch 1/99
-------------------------
train Loss: 4.2769 Acc: 0.2979
val Loss: 2.5687 Acc: 0.5510

Epoch 2/99
-------------------------
train Loss: 4.4815 Acc: 0.2766
val Loss: 2.4975 Acc: 0.5510

Epoch 3/99
-------------------------
train Loss: 4.0607 Acc: 0.3191
val Loss: 2.3798 Acc: 0.5714

Epoch 4/99
-------------------------
train Loss: 3.8601 Acc: 0.3617
val Loss: 2.2325 Acc: 0.5714

Epoch 5/99
-------------------------
train Loss: 3.5140 Acc: 0.4681
val Loss: 2.0454 Acc: 0.6122

Epoch 6/99
-------------------------
train Loss: 3.5766 Acc: 0.4468
val Loss: 1.8542 Acc: 0.6327

Epoch 7/99
-------------------------
train Loss: 3.0474 Acc: 0.4894
val Loss: 1.6940 Acc: 0.6327

Epoch 8/99
-------------------------
train Loss: 3.0853 Acc: 0.5106
val Loss: 1.5503 Acc: 0.6939

Epoch 9/99
-------------------------
train Loss: 2.4884 Acc: 0.5106
val Loss: 1.4233

In [17]:
# Persist only the state_dict (parameters and buffers) of the finetuned model.
# The epoch count in the file name is hard-coded, not derived from num_epochs.
save_path = 'finetuned_resnet50_full_1000class100epochs.pth'
torch.save(model_ft.state_dict(), save_path)
print(f"\nBest model weights saved to {save_path}")


Best model weights saved to finetuned_resnet50_full_1000class100epochs.pth


In [18]:
import torch
from torchvision.models import resnet50, ResNet50_Weights
import cv2
import numpy as np
from torchvision.transforms import transforms
from PIL import Image
import os
import re

In [19]:
# Inference-time preprocessing and the class-name list that ship with the
# pretrained weights; `categories` is indexed by predicted class index in
# get_predictions.
preprocess = weights.transforms()
categories = weights.meta["categories"]

In [20]:
def image_loader(image_name):
    """Load an image file as an RGB PIL image.

    Args:
        image_name: Path of the image file.

    Returns:
        The image converted to RGB mode, or ``None`` (after printing an error
        message) when no file exists at ``image_name``.
    """
    if not os.path.exists(image_name):
        print(f"ERROR: File not found at path: {image_name}")
        return None

    # Open inside a context manager so the file handle is released as soon as
    # the pixel data has been copied into the converted image.
    with Image.open(image_name) as raw_image:
        return raw_image.convert("RGB")

In [21]:
# Evaluation inputs and accumulators.
glare_images = './glare_distortion/glare_images/'  # folder scanned by the batch-prediction cells
results = {}  # filled by get_predictions, keyed by file name without its extension
labels = []  # not referenced by any other cell visible in this notebook

In [22]:
# Put the model in evaluation mode before running inference.
model.eval()

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 

In [23]:
def getLabel(filename):
    """Extract the ground-truth label embedded in a glare image file name.

    File names are expected to look like ``glare<digits><label>.<ext>``, where
    ``<ext>`` is one of the extensions the batch-prediction cells accept
    (``jpg``, ``jpeg`` or ``png``).

    Args:
        filename: Name of the image file.

    Returns:
        The text between the digits and the extension, or ``None`` when the
        name does not follow the pattern.
    """
    match = re.search(r'glare\d+(.*?)\.(?:jpg|jpeg|png)$', filename)
    if match:
        return match.group(1)
    return None

In [24]:
def get_predictions(input_batch, image_type, file_key, verbose=False):
    """Classify one image batch and store its top-5 predictions in ``results``.

    Uses the notebook-level ``model`` and ``categories``, and writes to
    ``results[file_key][image_type]``; that entry must already exist for
    ``file_key``. Only the first item of the batch is scored.

    Args:
        input_batch: Tensor passed straight to ``model``.
        image_type: Key under which the predictions are stored.
        file_key: Key of the image's entry in ``results``.
        verbose: When true, print the best label and its probability.
    """
    with torch.no_grad():
        logits = model(input_batch)

    # Softmax over the class scores of the first batch item.
    class_probs = torch.nn.functional.softmax(logits[0], dim=0)

    # Five most probable classes, highest first.
    top_probs, top_ids = torch.topk(class_probs, 5)

    predictions = [
        {
            "rank": rank,
            "label": categories[class_id],
            "probability": prob,
        }
        for rank, (class_id, prob) in enumerate(
            zip(top_ids.tolist(), top_probs.tolist()), start=1
        )
    ]

    # The rank-1 entry is the single best guess.
    winner = predictions[0]

    results[file_key][image_type] = {
        "best_label": winner["label"],
        "best_probability": winner["probability"],
        "top_5_predictions": predictions,
    }
    if verbose:
        print(f"  {image_type.capitalize()} Best Label: **{winner['label']}** (P: {winner['probability']:.4f})")

In [25]:
def preprocess_cv_image(cv_img, preprocess_func):
    """Turn an image array into a preprocessed model input.

    A 2-D (single-channel) array is replicated into three identical channels;
    any other array is passed through unchanged. The array is then wrapped in
    a PIL image and handed to ``preprocess_func``.

    Args:
        cv_img: Image as a NumPy array.
        preprocess_func: Callable applied to the PIL image.

    Returns:
        Whatever ``preprocess_func`` returns for the PIL image.
    """
    # NOTE(review): no channel reordering happens here. If callers pass arrays
    # straight from OpenCV (presumably BGR), the colours would be swapped —
    # confirm the expected channel order against the callers.
    is_single_channel = cv_img.ndim == 2
    three_channel = np.stack([cv_img] * 3, axis=2) if is_single_channel else cv_img

    return preprocess_func(Image.fromarray(three_channel))

In [26]:
def check_label_match(ground_truth, predicted_label):
    """Decide whether a predicted label matches the ground-truth label.

    Both labels are lower-cased and stripped of spaces; they match when either
    normalised string contains the other (which also covers equality).

    A missing or empty label never matches. Without that guard an empty
    ground truth would be a substring of every prediction and be counted as
    correct, and a ``None`` ground truth would raise.

    Args:
        ground_truth: Expected label, or ``None`` when it could not be parsed.
        predicted_label: Label produced by the classifier.

    Returns:
        ``True`` on a match, ``False`` otherwise.
    """
    if not ground_truth or not predicted_label:
        return False

    gt = ground_truth.lower().replace(" ", "")
    pl = predicted_label.lower().replace(" ", "")
    # Labels made only of spaces normalise to "" and must not match either.
    if not gt or not pl:
        return False

    # Substring matching is deliberately lenient (e.g. "mug" vs "coffee mug").
    return gt in pl or pl in gt

In [13]:
# Score every image in `glare_images` and record the top-5 predictions under
# results[<file name without extension>]["original"].
# NOTE(review): get_predictions reads the notebook-level `model`, so this cell
# measures whatever weights `model` holds when it runs — presumably it is meant
# to run before the training cell to capture the pretrained baseline; confirm.
verbose = False
print(f"Starting batch prediction on device: {device}")
print("-" * 50)

for filename in os.listdir(glare_images):
    # Anything that is not a recognised image file is ignored.
    if not filename.endswith((".jpg", ".jpeg", ".png")):
        continue

    original_file_path = os.path.join(glare_images, filename)

    if verbose:
        print(f"Processing: **{filename}**")

    file_key = os.path.splitext(filename)[0]
    results[file_key] = {"original": {}, "processed": {}, "label": getLabel(filename)}

    try:
        image_original = image_loader(original_file_path)
        input_tensor_original = preprocess(image_original)
        # Add the batch dimension and move the tensor to the model's device.
        input_batch_original = input_tensor_original.unsqueeze(0).to(device)
        get_predictions(input_batch_original, "original", file_key)
    except Exception as e:
        print(f"An error occurred while processing **{filename}**: {e}")
        del results[file_key] # Remove incomplete result
        print("-" * 50)
    else:
        if verbose:
            print("-" * 50)

print("Batch processing complete.")

Starting batch prediction on device: cuda:0
--------------------------------------------------
Batch processing complete.


In [27]:
# For Finetuned Version
# Same scan as the earlier batch-prediction cell; entries already present in
# `results` are overwritten. get_predictions reads the notebook-level `model`,
# which train_model updated in place, so once the training cell has run this
# scores the finetuned weights.
verbose = False
print(f"Starting batch prediction on device: {device}")
print("-" * 50)

for filename in os.listdir(glare_images):
    # Anything that is not a recognised image file is ignored.
    is_image = filename.endswith((".jpg", ".jpeg", ".png"))
    if not is_image:
        continue

    original_file_path = os.path.join(glare_images, filename)

    if verbose:
        print(f"Processing: **{filename}**")

    file_key = os.path.splitext(filename)[0]
    results[file_key] = {
        "original": {},
        "processed": {},
        "label": getLabel(filename),
    }

    try:
        image_original = image_loader(original_file_path)
        input_tensor_original = preprocess(image_original)
        # Add the batch dimension and move the tensor to the model's device.
        input_batch_original = input_tensor_original.unsqueeze(0).to(device)
        get_predictions(input_batch_original, "original", file_key)
    except Exception as e:
        print(f"An error occurred while processing **{filename}**: {e}")
        del results[file_key] # Remove incomplete result
        print("-" * 50)
    else:
        if verbose:
            print("-" * 50)

print("Batch processing complete.")

Starting batch prediction on device: cuda:0
--------------------------------------------------
Batch processing complete.


In [28]:
def print_side_by_side_comparison(results_dict, print_results=False):
    """Count top-1 and top-5 hits over a results dictionary, optionally printing them.

    Each value of ``results_dict`` must provide ``data["label"]`` (the ground
    truth) and ``data["original"]["top_5_predictions"]`` (a ranked list of
    dicts with ``"label"`` and ``"probability"``). At most the first five
    predictions of each entry are considered. Matching is delegated to
    ``check_label_match``.

    Args:
        results_dict: Mapping from file key to the per-image result record.
        print_results: When true, print every prediction, marking matches
            with ``*``.

    Returns:
        ``(original_correct, orig_top1)``: the number of images whose ground
        truth matches any considered prediction, and the number whose
        ground truth matches the rank-1 prediction.
    """
    # Width of the longest label across all prediction lists, used to align
    # the printed column.
    max_label_width = max(
        (
            len(pred["label"])
            for data in results_dict.values()
            for pred in data["original"]["top_5_predictions"]
        ),
        default=0,
    )
    label_padding = max_label_width + 12

    orig_top1 = 0
    original_correct = 0
    for file_key, data in results_dict.items():
        # Slice instead of indexing 0..4 so a shorter list cannot raise.
        original_preds = data["original"]["top_5_predictions"][:5]
        ground_truth = data['label']
        if print_results:
            print("=" * 70)
            print(f"**{file_key}** (Ground Truth: {ground_truth})")
            print("-" * 70)

            header_format = f"{'Original Top 5':<{label_padding}} | {'Processed Top 5'}"
            print(header_format)
            print("-" * 70)

        matched_any = False
        for rank, pred in enumerate(original_preds):
            orig_label = pred['label']
            orig_prob = pred['probability']
            orig_is_match = check_label_match(ground_truth, orig_label)
            if orig_is_match:
                if rank == 0:
                    orig_top1 += 1
                matched_any = True

            if print_results:
                orig_star = "*" if orig_is_match else ""
                orig_output = f"{orig_label}: {orig_prob:.4f}{orig_star}"
                # Pad to the common width so the column stays aligned.
                print(f"{orig_output:<{label_padding}}")

        # An image counts once, however many of its predictions match.
        if matched_any:
            original_correct += 1

    if print_results:
        print("=" * 70)
    return original_correct, orig_top1


# Original

In [15]:
# Count hits over whatever `results` currently holds, without printing the
# per-image breakdown.
print_results = False
original_correct, orig_top1 = print_side_by_side_comparison(results, print_results=print_results)

In [16]:
# [images matched by any top-5 prediction, images matched by the top-1 prediction]
print([original_correct,orig_top1])

[38, 31]


In [17]:
# Top-5 accuracy. The denominator is the number of images that were actually
# scored (entries in `results`), not every entry of the directory listing, so
# skipped non-image files and failed images do not dilute the figure.
print(f'OG Acc: {original_correct / max(len(results), 1)}')

OG Acc: 0.6229508196721312


In [18]:
# Top-1 accuracy. The denominator is the number of images that were actually
# scored (entries in `results`), not every entry of the directory listing.
print(f'OG Acc Best Guess: {orig_top1 / max(len(results), 1)}')

OG Acc Best Guess: 0.5081967213114754


# Finetuned

In [29]:
# Recount hits after the finetuned batch-prediction cell has refilled `results`.
# The variable names still say "original"/"orig" but now hold the finetuned counts.
print_results = False
original_correct, orig_top1 = print_side_by_side_comparison(results, print_results=print_results)

In [30]:
# [images matched by any top-5 prediction, images matched by the top-1 prediction]
print([original_correct,orig_top1])

[32, 26]


In [31]:
# Top-5 and top-1 accuracy for the finetuned run (the "OG" prefix in the
# messages is inherited from the baseline cells). The denominator is the number
# of images that were actually scored (entries in `results`), not every entry
# of the directory listing.
scored_images = max(len(results), 1)
print(f'OG Acc: {original_correct / scored_images}')
print(f'OG Acc Best Guess: {orig_top1 / scored_images}')

OG Acc: 0.5245901639344263
OG Acc Best Guess: 0.4262295081967213
