# This notebook contains a single code cell that classifies images through a GUI.
* It uses Tkinter to provide a user-friendly image classification program.
* It has been packaged into an executable (.exe) file with PyInstaller.
* It automatically detects the model (.pth) file, as long as that file is in the same directory as the .exe file.
* It moves each classified image into the folder for its predicted class and records every image's prediction and confidence score. After classification, these are reported in a predictions.xlsx file written to each processed folder.

In [3]:
# With automatic normalization but WITHOUT renaming
# Automatically finds the model .pth file. 

import os
import shutil
import sys
import time
import tkinter as tk
from threading import Thread, Event
from tkinter import filedialog, messagebox

import pandas as pd
import torch
from PIL import Image, UnidentifiedImageError
from torchvision import transforms
from torchvision.models import resnet18

# Load the trained model, class mapping, and number of classes
def load_model(model_path, device=None):
    """Load the trained ResNet-18 classifier and its metadata from a checkpoint.

    The checkpoint must be a dict with the keys ``num_classes``,
    ``class_mapping`` (class name -> index), ``normalization`` (a dict with
    ``mean`` and ``std``) and ``model_state_dict``.

    Args:
        model_path: Path to the ``.pth`` checkpoint.
        device: torch device to load onto. When omitted, CUDA is used if it is
            available, otherwise the CPU. This is decided at call time.

    Returns:
        ``(model, class_mapping, device, mean, std)``. ``model`` is in eval mode
        on ``device``. ``class_mapping`` maps predicted class index -> class name.
        ``mean`` and ``std`` are the normalization values stored in the checkpoint.
    """
    if device is None:
        # Resolved here rather than in the signature, where it would run once at import time.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # NOTE(review): torch.load unpickles the file, which can execute arbitrary code.
    # Only load checkpoints from a trusted source. Consider weights_only=True if the
    # checkpoint holds only tensors and plain Python containers.
    checkpoint = torch.load(model_path, map_location=device)

    # Number of output classes, plus an index -> class-name lookup.
    # The checkpoint stores the mapping as name -> index.
    num_classes = checkpoint['num_classes']
    class_mapping = {index: name for name, index in checkpoint['class_mapping'].items()}

    # Normalization statistics saved with the checkpoint
    # (presumably the training-set mean/std; confirm against the training script).
    normalization = checkpoint['normalization']
    loaded_mean = normalization['mean']
    loaded_std = normalization['std']

    # Rebuild the architecture and restore the trained weights.
    model = resnet18(num_classes=num_classes)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.to(device)
    model.eval()  # inference mode: batch-norm uses its running statistics

    return model, class_mapping, device, loaded_mean, loaded_std

# Check if a file is an image
def is_image(file_path):
    """Return True if Pillow can open and verify *file_path* as an image.

    ``Image.verify()`` checks the file for corruption without decoding the full
    image data. Any failure is treated as "not an image", so callers can simply
    skip the file. Failures include an unidentified format, an unreadable or
    corrupt file, or a permission error.
    """
    try:
        # The with-block closes the file handle deterministically instead of
        # leaving it to garbage collection.
        with Image.open(file_path) as img:
            img.verify()
        return True
    except Exception:
        # Deliberately broad. UnidentifiedImageError is one of many possible
        # failures, and all of them simply mean "skip this file".
        return False

# Define transformations
def get_transform(mean, std):
    """Build the preprocessing pipeline applied to every image before inference.

    Images are resized to 224x224, converted to a tensor, and normalized with the
    given per-channel *mean* and *std* (the values loaded from the checkpoint).
    """
    pipeline_steps = [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ]
    return transforms.Compose(pipeline_steps)

def save_predictions_to_excel(predictions, folder_path):
    """Write *predictions* to ``predictions.xlsx`` inside *folder_path*.

    Each prediction is a row ``[image name, predicted class, confidence score]``.
    The folder is created if it does not exist. An existing
    ``predictions.xlsx`` there is replaced.
    """
    os.makedirs(folder_path, exist_ok=True)
    target_file = os.path.join(folder_path, "predictions.xlsx")
    table = pd.DataFrame(
        predictions,
        columns=["Image Name", "Predicted Class", "Confidence Score"],
    )
    table.to_excel(target_file, index=False)
    print(f"Predictions saved to {target_file}")

def save_partial_predictions(predictions, folder_path):
    """Save predictions with a unique 'cancelled' file name.

    The first free name among ``predictions_cancelled.xlsx``,
    ``predictions_cancelled1.xlsx`` ... ``predictions_cancelled999.xlsx`` in
    *folder_path* is used. Nothing is written when *predictions* is empty, or
    when all 1000 names are already taken.
    """
    if not predictions:
        return  # nothing to save

    stem, suffix = "predictions_cancelled", ".xlsx"
    # Index 0 maps to the un-numbered name; 1..999 are appended to the stem.
    candidate_paths = (
        os.path.join(folder_path, f"{stem}{counter or ''}{suffix}")
        for counter in range(1000)
    )
    free_path = next((path for path in candidate_paths if not os.path.exists(path)), None)
    if free_path is None:
        return  # every candidate name is taken

    table = pd.DataFrame(predictions, columns=["Image Name", "Predicted Class", "Confidence Score"])
    table.to_excel(free_path, index=False)
    print(f"Partial predictions saved to {free_path}")

def classify_images(parent_folder, model, device, class_mapping, progress_callback, stop_event, start_time,
                    loaded_mean, loaded_std, append_finished_path_callback):
    """Classify every image under *parent_folder* and sort it into per-class folders.

    Walks *parent_folder* recursively, skipping ``*_classified`` folders. For each
    folder that contains images:
    - every image is run through *model*;
    - the image is moved, under its original name, into
      ``<folder>/<predicted class>_classified/``;
    - a ``predictions.xlsx`` summary is written to that folder.

    Args:
        parent_folder: Top-level directory selected by the user.
        model: Classifier in eval mode, already on *device*.
        device: torch device that input tensors are moved to.
        class_mapping: Maps predicted class index -> class name.
        progress_callback: Called as
            ``(completed, total, images_per_sec, folder, elapsed_seconds)``.
        stop_event: ``threading.Event``. When it is set, processing stops and the
            current folder's partial predictions are saved.
        start_time: ``time.time()`` timestamp at which the run started.
        loaded_mean, loaded_std: Normalization values from the checkpoint.
        append_finished_path_callback: Called as ``(label, completed, total)``
            after each folder. It is called once more with ``(message, None, None)``
            at the end.
    """
    transform = get_transform(mean=loaded_mean, std=loaded_std)
    images_processed = 0  # across all folders; drives the images/sec figure

    for root, dirs, files in os.walk(parent_folder):
        # Prune output folders created by this or an earlier run, so images that
        # were already moved are never classified twice.
        dirs[:] = [d for d in dirs if not d.endswith("_classified")]
        image_files = [file for file in files if is_image(os.path.join(root, file))]

        if not image_files:
            continue

        total_images = len(image_files)
        folder_predictions = []

        for idx, image_filename in enumerate(image_files):
            if stop_event.is_set():  # Check if cancellation was triggered
                # Save partial predictions for the current folder
                save_partial_predictions(folder_predictions, root)
                # NOTE(review): this dialog is shown from the worker thread; consider
                # routing it through the GUI thread.
                messagebox.showwarning("Cancelled", "Classification cancelled by user. Partial predictions (if any) saved.")
                return

            image_path = os.path.join(root, image_filename)
            try:
                # Convert to RGB because the torchvision resnet18 built by load_model
                # takes 3-channel input. Without this, grayscale, palette, RGBA or
                # CMYK images fail in preprocessing/inference and are skipped.
                # The with-block releases the file handle before the image is moved.
                with Image.open(image_path) as pil_image:
                    batch = transform(pil_image.convert("RGB")).unsqueeze(0).to(device)

                with torch.no_grad():
                    outputs = model(batch)
                probs = torch.nn.functional.softmax(outputs, dim=1)[0]
                confidence, predicted_index = probs.max(0)
                predicted_class_name = class_mapping[predicted_index.item()]

                # Move the image into its class folder, keeping the original file name.
                class_folder = os.path.join(root, f"{predicted_class_name}_classified")
                os.makedirs(class_folder, exist_ok=True)
                shutil.move(image_path, os.path.join(class_folder, image_filename))

                folder_predictions.append([image_filename, predicted_class_name, round(confidence.item(), 2)])

            except Exception as e:
                # Best-effort per image: report the failure and continue with the rest.
                print(f"Error processing {image_filename}: {e}")

            images_processed += 1
            elapsed = time.time() - start_time
            # Whole-run throughput: images handled so far / time since the run started.
            images_per_sec = images_processed / elapsed if elapsed > 0 else 0.0
            progress_callback(idx + 1, total_images, images_per_sec, root, elapsed)

        # Save predictions to Excel
        save_predictions_to_excel(folder_predictions, root)

        # Label the folder relative to the selected parent folder. The parent
        # folder itself is shown by its own name.
        relative_path = os.path.relpath(root, parent_folder)
        if relative_path == ".":
            relative_path = os.path.basename(parent_folder)

        progress_callback(
            total_images, total_images, 0, root, time.time() - start_time
        )  # Update progress for completed folder
        append_finished_path_callback(relative_path, total_images, total_images)

    # Final message in the log box. The None values suppress the "| n/m" suffix.
    final_message = "\nClassification finished! You may close this window now.\n"
    append_finished_path_callback(final_message, None, None)

    messagebox.showinfo("Done", f"Classification completed. All subfolders within {parent_folder} classified.")

def find_model_file(directory, extension=".pth"):
    """Return the name of the single model checkpoint file in *directory*.

    Only regular files whose names end with *extension* are considered. A
    sub-folder that happens to be named e.g. ``backup.pth`` is therefore ignored.
    The search is not recursive.

    Args:
        directory: Directory to search.
        extension: File-name suffix identifying a model checkpoint.

    Returns:
        The bare file name, not joined with *directory*.

    Raises:
        ValueError: More than one matching file exists.
        FileNotFoundError: No matching file exists, or *directory* itself does
            not exist (raised by ``os.listdir``).
    """
    candidates = [
        name for name in os.listdir(directory)
        if name.endswith(extension) and os.path.isfile(os.path.join(directory, name))
    ]
    if len(candidates) > 1:
        raise ValueError(f"Multiple model files found. Please remove one model ({extension} file) to proceed.")
    if not candidates:
        raise FileNotFoundError(f"No model file ({extension}) found in the directory.")
    return candidates[0]

# GUI Application
class ResNetClassifierApp:
    """Tkinter GUI that classifies all images under a user-selected folder.

    On start-up the single ``.pth`` checkpoint is located and loaded. Clicking
    "Start Classification" runs ``classify_images`` on a worker thread. Widget
    updates coming from that thread are scheduled on the Tk event loop with
    ``after()``.
    """

    def __init__(self, root):
        # Locate the model next to the .exe (frozen build), or in the working
        # directory when run as a script.
        search_dir = self._model_search_dir()
        try:
            model_file = find_model_file(search_dir)
        except (FileNotFoundError, ValueError) as e:
            # The exception text is already a complete sentence, so show it as-is,
            # followed by where we looked.
            messagebox.showerror("Error", f"{e}\n\nSearched in: {os.path.abspath(search_dir)}")
            root.destroy()
            return
        self.model_path = os.path.join(search_dir, model_file)

        # Load model, class mapping, device, and normalization values
        try:
            (self.model, self.class_mapping, self.device,
             self.loaded_mean, self.loaded_std) = load_model(self.model_path)
        except Exception as e:
            # Start-up boundary: report an unreadable or incompatible checkpoint in a
            # dialog instead of an uncaught traceback (invisible in a windowed build).
            messagebox.showerror("Error", f"Could not load model file {self.model_path}:\n{e}")
            root.destroy()
            return

        # GUI Elements
        self._worker = None  # background classification thread, once one has been started
        self.stop_event = Event()
        self.parent_folder = ""
        self.root = root
        root.title("ResNet Image Classifier")
        root.geometry("600x400")

        # Configure resizing for root
        root.grid_rowconfigure(0, weight=1)  # Content frame resizes
        root.grid_rowconfigure(1, weight=1)  # Footer frame resizes
        root.grid_columnconfigure(0, weight=1)

        # Content frame; row 6 (the log box) takes extra vertical space
        content_frame = tk.Frame(root)
        content_frame.grid(row=0, column=0, sticky="nsew")
        content_frame.grid_rowconfigure(6, weight=1)
        content_frame.grid_columnconfigure(0, weight=1)

        # Footer frame
        footer_frame = tk.Frame(root)
        footer_frame.grid(row=1, column=0, sticky="nsew")
        footer_frame.grid_rowconfigure(0, weight=1)  # Elapsed Time row
        footer_frame.grid_rowconfigure(1, weight=1)  # Footer text row
        footer_frame.grid_columnconfigure(0, weight=1)  # Center column

        # Main Label
        self.label = tk.Label(content_frame, text="Welcome to the VPR image (roi) classifier.\nPlease select a parent folder to classify images.")
        self.label.grid(row=0, column=0, pady=10, sticky="n")

        # Select Folder Button
        self.select_button = tk.Button(content_frame, text="Select Folder", command=self.select_folder)
        self.select_button.grid(row=1, column=0, pady=5)

        # Directory info area
        self.dir_info = tk.Label(content_frame, text="", bg="white", font=("Helvetica", 10), anchor="w")
        self.dir_info.grid(row=2, column=0, sticky="ew", padx=10, pady=5)

        # Start and Cancel buttons
        self.start_button = tk.Button(content_frame, text="Start Classification", command=self.start_classification, state=tk.DISABLED)
        self.start_button.grid(row=3, column=0, pady=5)

        self.cancel_button = tk.Button(content_frame, text="Cancel Classification", command=self.cancel_classification, state=tk.DISABLED)
        self.cancel_button.grid(row=4, column=0, pady=5)

        # Progress Label
        self.progress_label = tk.Label(content_frame, text="", font=("Helvetica", 10))
        self.progress_label.grid(row=5, column=0, pady=5)

        # Scrollable, read-only log of finished folders
        scroll_frame = tk.Frame(content_frame)
        scroll_frame.grid(row=6, column=0, sticky="nsew", padx=10, pady=5)

        self.finished_paths_text = tk.Text(
            scroll_frame, wrap=tk.WORD, font=("Helvetica", 10), height=5, width=50
        )
        self.finished_paths_text.pack(side=tk.LEFT, fill=tk.BOTH, expand=True)

        scrollbar = tk.Scrollbar(scroll_frame, command=self.finished_paths_text.yview)
        scrollbar.pack(side=tk.RIGHT, fill=tk.Y)

        self.finished_paths_text.config(yscrollcommand=scrollbar.set)
        self.finished_paths_text.config(state=tk.DISABLED)

        # Elapsed Time area
        self.elapsed_time_label = tk.Label(footer_frame, text="Elapsed Time: 0:00:00", font=("Helvetica", 8))
        self.elapsed_time_label.grid(row=0, column=0, sticky="nsew", padx=10, pady=(10,5))

        # Footer text
        self.footer_label = tk.Label(footer_frame, text="For instructions, please check VPR_classifier_manual.doc.", font=("Helvetica", 8))
        self.footer_label.grid(row=1, column=0, sticky="nsew", padx=10, pady=(5, 10))

    @staticmethod
    def _model_search_dir():
        """Return the directory in which to look for the model checkpoint.

        In a PyInstaller build, ``sys.frozen`` is set and ``sys.executable`` is the
        .exe itself. The model is therefore found next to the executable, whatever
        working directory the program was launched from. When run as a plain
        script, the current working directory is searched.
        """
        if getattr(sys, "frozen", False):
            return os.path.dirname(os.path.abspath(sys.executable))
        return "."

    def select_folder(self):
        """Ask the user for the parent folder and enable Start once one is chosen."""
        chosen = filedialog.askdirectory()
        # askdirectory returns "" when the dialog is cancelled. Keep the previous
        # selection in that case, instead of leaving Start enabled with no folder.
        if chosen:
            self.parent_folder = chosen
            self.dir_info.config(text=f"Current folder: {self.parent_folder}")
            self.start_button.config(state=tk.NORMAL)

    def cancel_classification(self):
        """Ask the worker to stop; classify_images checks the event before each image."""
        self.stop_event.set()

    def update_progress(self, completed, total, images_per_sec, root, elapsed_time):
        """Progress callback for classify_images; called from the worker thread."""
        # Calculate the relative path from the selected parent folder to the current folder
        relative_path = os.path.relpath(root, self.parent_folder)
        if relative_path == ".":  # If the folder is the root, show the parent folder name
            relative_path = os.path.basename(self.parent_folder)

        # Schedule updates to the GUI safely using after()
        self.root.after(0, lambda: self.progress_label.config(
            text=f"Working on: {relative_path} | {completed}/{total}"
        ))
        self.root.after(0, lambda: self.elapsed_time_label.config(
            text=f"Elapsed Time: {time.strftime('%H:%M:%S', time.gmtime(elapsed_time))}"
        ))

    def append_finished_path(self, relative_path, completed, total):
        """Append a line to the finished-folders log (called from the worker thread).

        When *completed* or *total* is None, *relative_path* is written verbatim,
        without the "Finished ... | n/m" decoration. The widget update is scheduled
        with ``after()``, the same approach as ``update_progress``.
        """
        if completed is None or total is None:
            message = f"{relative_path}\n"
        else:
            message = f"Finished {relative_path} | {completed}/{total}\n"
        self.root.after(0, self._insert_log_line, message)

    def _insert_log_line(self, message):
        """Insert *message* into the read-only log box and scroll to it (GUI thread)."""
        self.finished_paths_text.config(state=tk.NORMAL)  # temporarily writable
        self.finished_paths_text.insert(tk.END, message)
        self.finished_paths_text.see(tk.END)  # auto-scroll to the newest line
        self.finished_paths_text.config(state=tk.DISABLED)

    def start_classification(self):
        """Clear the log and run classify_images on a background thread."""
        if self._worker is not None and self._worker.is_alive():
            return  # a run is already in progress
        if not self.parent_folder:
            return

        # Clear the scrollable box content
        self.finished_paths_text.config(state=tk.NORMAL)  # Enable editing
        self.finished_paths_text.delete(1.0, tk.END)  # Clear all content
        self.finished_paths_text.config(state=tk.DISABLED)  # Make it read-only again

        # Start classification
        self.stop_event.clear()
        self.progress_label.config(text="Loading...")
        self.start_time = time.time()  # Reset the timer

        # Lock out controls that would start a second, overlapping run on the same files.
        self.start_button.config(state=tk.DISABLED)
        self.select_button.config(state=tk.DISABLED)
        self.cancel_button.config(state=tk.NORMAL)

        self._worker = Thread(target=self._run_classification, args=(self.parent_folder,))
        self._worker.start()

    def _run_classification(self, parent_folder):
        """Worker-thread body: run classify_images, then re-enable the controls."""
        try:
            classify_images(parent_folder, self.model, self.device,
                            self.class_mapping, self.update_progress,
                            self.stop_event, self.start_time,
                            self.loaded_mean, self.loaded_std,
                            self.append_finished_path)
        finally:
            try:
                self.root.after(0, self._on_classification_finished)
            except (RuntimeError, tk.TclError):
                pass  # the window was closed while the worker was still running

    def _on_classification_finished(self):
        """Restore the idle button states once a run ends (GUI thread)."""
        self.start_button.config(state=tk.NORMAL)
        self.select_button.config(state=tk.NORMAL)
        self.cancel_button.config(state=tk.DISABLED)


# Main GUI Loop
if __name__ == "__main__":
    # Create the Tk root window, build the classifier UI on it, then hand control
    # to Tk's event loop. If the app's __init__ destroyed the window (model file not
    # found or not loadable), this is expected to return immediately.
    root = tk.Tk()
    app = ResNetClassifierApp(root)
    root.mainloop()


  checkpoint = torch.load(model_path, map_location=device)
