In [1]:
import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import torch
from torchvision.transforms.functional import to_tensor
from IQADataset import NonOverlappingCropPatches
from Network import CNNIQAnet
from training import IQAPerformance
import tifffile as tf
import numpy as np
import matplotlib.pyplot as plt
import glob
import os
import imagehash
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg

In [2]:

class IQAApp(tk.Tk):
    """Tk window that loads a single TIFF image and scores it with the CNNIQA model.

    The predicted score is the mean of the model outputs over the patches produced by
    ``NonOverlappingCropPatches(image, 32, 32)`` on the grayscale, full-resolution image.
    It is shown next to a fixed reference score (``GROUND_TRUTH_SCORE``).
    """

    # Trained CNNIQA weights. Absolute, machine-specific path — TODO: make configurable.
    MODEL_PATH = 'C:/Users/win 10/Desktop/CNNIQA/CNNIQA/results/CNNIQA-EuroSat-EXP0-lr=0.001'
    # Placeholder reference score displayed next to the prediction; replace with the real
    # ground-truth score of the image being assessed.
    GROUND_TRUTH_SCORE = 8.8371

    def __init__(self):
        super().__init__()

        # Set the window title
        self.title("Image Quality Assessment")

        # Label that displays a preview of the loaded image
        self.image_label = tk.Label(self)
        self.image_label.pack()

        # Button to load an image
        self.load_button = tk.Button(self, text="Load Image", command=self.load_image)
        self.load_button.pack()

        # Button to perform image quality assessment
        self.assess_button = tk.Button(self, text="Assess Image Quality", command=self.assess_quality)
        self.assess_button.pack()

        # TODO: no classifier exists yet, so this button also runs the quality assessment.
        # Kept in its own attribute so it no longer overwrites self.assess_button.
        self.classify_button = tk.Button(self, text="Application's classification", command=self.assess_quality)
        self.classify_button.pack()

        # Load the trained model once; assess_quality reuses it.
        self.model = self._load_model()

        # Initialize the IQAPerformance class
        self.performance_metrics = IQAPerformance()

    def _load_model(self):
        """Build a CNNIQAnet, load the weights from MODEL_PATH and switch it to eval mode."""
        model = CNNIQAnet()
        # map_location lets weights saved from a GPU run load on a CPU-only machine;
        # inference below runs on the CPU anyway.
        model.load_state_dict(torch.load(self.MODEL_PATH, map_location=torch.device('cpu')))
        model.eval()
        return model

    def load_image(self):
        """Ask for a TIFF file, show a preview of it and keep a grayscale copy for scoring."""
        file_path = filedialog.askopenfilename(filetypes=[("TIFF Files", "*.tif;*.tiff")])
        if not file_path:
            # Dialog was cancelled.
            return

        # Load the image using tifffile (imported as `tf`).
        image = Image.fromarray(tf.imread(file_path))

        # Keep the full-resolution grayscale image for assessment. thumbnail() resizes in
        # place, so the preview is made from a separate copy.
        self.input_image = image.convert("L")

        preview = image.copy()
        preview.thumbnail((400, 400))  # shrink to fit the label (never enlarges)
        photo = ImageTk.PhotoImage(preview)
        self.image_label.config(image=photo)
        # Keep a reference so the PhotoImage is not garbage-collected while displayed.
        self.image_label.image = photo

    def assess_quality(self):
        """Score the loaded image and report it against GROUND_TRUTH_SCORE in a message box."""
        from tkinter import messagebox  # submodule; not guaranteed to be loaded by `import tkinter`

        if not hasattr(self, 'input_image'):
            messagebox.showwarning("Image Quality Assessment", "Please load an image first.")
            return

        # Inference only: no autograd graph needed.
        with torch.no_grad():
            patches = NonOverlappingCropPatches(self.input_image, 32, 32)
            patch_scores = self.model(torch.stack(patches).to(torch.device('cpu')))
            predicted_score = patch_scores.mean().item()

        ground_truth_score = self.GROUND_TRUTH_SCORE
        difference = abs(predicted_score - ground_truth_score)
        messagebox.showinfo("Image Quality Assessment",
                            f"Predicted Score: {predicted_score}\n"
                            f"Ground Truth Score: {ground_truth_score}\n"
                            f"Difference: {difference}")

    def run(self):
        """Enter the Tk event loop (blocks until the window is closed)."""
        self.mainloop()

if __name__ == "__main__":
    # Open the single-image assessment window; the Tk event loop blocks until it is closed.
    app = IQAApp()
    app.mainloop()


In [9]:
class ImageQualityAssessmentApp:
    """Folder-level image quality tool built on the CNNIQA model.

    Scores every TIFF in the selected folder, plots the scores inside the window, drops
    images scoring below a threshold from the score table, and groups identical-looking
    images by perceptual hash. No file is deleted from disk: "removed" images are only
    dropped from the in-memory score table or listed in the UI.

    Args:
        root: Tk window (e.g. ``tk.Tk()``) that all widgets are packed into.
    """

    # Trained CNNIQA weights. Absolute, machine-specific path — TODO: make configurable.
    MODEL_PATH = 'C:/Users/win 10/Desktop/CNNIQA/CNNIQA/results/CNNIQA-EuroSat-EXP0-lr=0.001'
    # Lower-case file extensions that every folder operation treats as images.
    IMAGE_EXTENSIONS = (".tif", ".tiff")

    def __init__(self, root):
        self.root = root
        self.root.title("Image Quality Assessment App")
        self.folder_path = tk.StringVar()
        self.log_text = tk.StringVar()
        self.log_text.set("Your Results will appear here.")

        self.create_widgets()

    def create_widgets(self):
        """Build all widgets, reset per-session state and load the model once."""
        self.image_folder = None
        self.image_scores = {}

        # Every widget is parented to self.root, so the class works without a global `root`.
        tk.Label(self.root, text="Select Folder:").pack(pady=10)
        folder_entry = tk.Entry(self.root, textvariable=self.folder_path, width=50)
        folder_entry.pack()
        tk.Button(self.root, text="Browse Folder", command=self.browse_folder,
                  bg='black', fg='white').pack(pady=10, padx=20)
        tk.Button(self.root, text="Detect and Remove Redundant Images", command=self.detect_redundant,
                  bg='purple', fg='white').pack()
        tk.Button(self.root, text="Calculate Image Quality Score", command=self.calculate_scores,
                  bg='green', fg='white').pack(pady=5, padx=20)
        tk.Button(self.root, text="Plot Image Quality Scores", command=self.plot_scores,
                  bg='green', fg='white').pack(pady=5, padx=20)
        tk.Button(self.root, text="Remove Low Quality Images", command=self.remove_low_quality,
                  bg='green', fg='white').pack(pady=5, padx=20)
        # TODO: no classifier is wired up yet, so this button has no command.
        tk.Button(self.root, text="Classify the image's Application",
                  bg='blue', fg='white').pack(pady=10, padx=20)
        tk.Label(self.root, text="Log:").pack()
        tk.Label(self.root, textvariable=self.log_text).pack()

        self.removed_images_label = tk.Label(self.root, text="The Removed Duplicates images are :", wraplength=400)
        self.removed_images_label.pack(pady=10)

        self.message_text = tk.Text(self.root, height=4, width=105)
        self.message_text.pack(pady=10)

        # Load the model once; calculate_image_quality reuses it for every image.
        self.model = self._load_model()

        self.removed_images_listbox = tk.Listbox(self.root, height=4, width=140)
        self.removed_images_listbox.pack()

        # Canvas of the embedded score plot (None until plot_scores runs).
        self.result_canvas = None

        # Initialize the IQAPerformance class
        self.performance_metrics = IQAPerformance()

    def _load_model(self):
        """Build a CNNIQAnet, load the weights from MODEL_PATH and switch it to eval mode."""
        model = CNNIQAnet()
        # map_location lets weights saved from a GPU run load on a CPU-only machine.
        model.load_state_dict(torch.load(self.MODEL_PATH, map_location=torch.device('cpu')))
        model.eval()
        return model

    def _selected_folder(self):
        """Return the folder to operate on, or None if no existing directory is selected.

        The entry box is the source of truth, so both typed and browsed paths work. If the
        entry is empty, this falls back to ``self.image_folder``. The result is stored back
        into ``self.image_folder``.
        """
        folder = self.folder_path.get().strip() or self.image_folder
        self.image_folder = folder if folder and os.path.isdir(folder) else None
        return self.image_folder

    def browse_folder(self):
        """Let the user pick a folder; a cancelled dialog keeps the previous selection."""
        folder = filedialog.askdirectory()
        # askdirectory returns "" when the dialog is cancelled.
        if folder:
            self.image_folder = folder
            self.folder_path.set(folder)

    def display_message(self, message):
        """Append a line to the message box and scroll it into view."""
        self.message_text.insert(tk.END, message + "\n")
        self.message_text.see(tk.END)

    def calculate_image_quality(self, image_path):
        """Return the mean CNNIQA patch score of the TIFF at ``image_path``.

        The image is converted to grayscale ("L") and split with
        ``NonOverlappingCropPatches(im, 32, 32)``. The returned score is the mean of the
        model outputs over those patches.
        """
        model = getattr(self, 'model', None)
        if model is None:
            # Load lazily if create_widgets has not run.
            model = self.model = self._load_model()

        im = Image.fromarray(tf.imread(image_path)).convert("L")
        patches = NonOverlappingCropPatches(im, 32, 32)
        # Inference only: no autograd graph needed.
        with torch.no_grad():
            patch_scores = model(torch.stack(patches).to(torch.device('cpu')))
        return patch_scores.mean().item()

    def calculate_scores(self):
        """Score every TIFF in the selected folder into ``self.image_scores``.

        Scores are keyed by file name. Scores from a previous run are discarded. An image
        that fails to load or score is reported in the message box and skipped.
        """
        folder = self._selected_folder()
        if folder is None:
            self.display_message("Please select a folder first.")
            return

        image_files = sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )

        self.image_scores = {}
        for image_path in image_files:
            # basename handles both "/" and "\\" separators (Windows paths).
            image_name = os.path.basename(image_path)
            try:
                self.image_scores[image_name] = self.calculate_image_quality(image_path)
            except Exception as e:
                self.display_message(f"Error scoring {image_name}: {e}")

        self.display_message("Scores calculated.")

    def plot_scores(self):
        """Show a bar chart of ``self.image_scores`` embedded in the window.

        Any previous plot is replaced.
        """
        if not self.image_scores:
            self.display_message("No scores to plot.")
            return

        # A standalone Figure, not one created through pyplot, is the documented way to
        # embed in Tk. It also stays out of pyplot's global figure registry.
        from matplotlib.figure import Figure

        names = list(self.image_scores.keys())
        scores = list(self.image_scores.values())

        fig = Figure(figsize=(10, 8))
        ax = fig.add_subplot(111)
        ax.bar(names, scores)
        ax.tick_params(axis='x', labelrotation=90)
        ax.set_xlabel("Image Names")
        ax.set_ylabel("Image Scores")
        ax.set_title("Image Quality Scores")
        fig.tight_layout()

        # Replace the previous plot instead of stacking a new canvas on every click.
        if self.result_canvas is not None:
            self.result_canvas.get_tk_widget().destroy()
        self.result_canvas = FigureCanvasTkAgg(fig, master=self.root)
        self.result_canvas.draw()
        self.result_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def update_removed_images_listbox(self, removed_images):
        """Replace the listbox contents with ``(image_name, score)`` pairs."""
        self.removed_images_listbox.delete(0, tk.END)
        for image_name, score in removed_images:
            self.removed_images_listbox.insert(tk.END, f"Removed: {image_name} (Score: {score:.2f})")

    def remove_low_quality(self, threshold=7.0):
        """Drop images scoring below ``threshold`` from ``self.image_scores``.

        Only the in-memory score table is changed; files on disk are left untouched.

        Args:
            threshold: Images with a score strictly below this value are removed.
        """
        if not self.image_scores:
            self.display_message("No scores to evaluate.")
            return

        removed_images = [(name, score) for name, score in self.image_scores.items() if score < threshold]

        for image_name, score in removed_images:
            self.display_message(f"Removing {image_name} with score {score:.2f}")
            del self.image_scores[image_name]

        if not removed_images:
            self.display_message("No Low_quality LEO Satellite Images In This Folder.")
        else:
            self.display_message("LEO Satellite images with Low quality images removed.")

        self.update_removed_images_listbox(removed_images)

    def plot_images(self, img_list, image_size=(6, 4)):
        """Show the images in ``img_list`` side by side in a pyplot figure.

        Args:
            img_list: Paths of the TIFF images to show.
            image_size: (width, height) in inches of each image panel.
        """
        num_images = len(img_list)
        if num_images == 0:
            return
        # squeeze=False keeps `axes` 2-D, so a single image also works.
        fig, axes = plt.subplots(1, num_images, figsize=(num_images * image_size[0], image_size[1]),
                                 squeeze=False)
        fig.suptitle("Redundant Images")

        for ax, img_path in zip(axes[0], img_list):
            ax.imshow(tf.imread(img_path))
            ax.set_title(os.path.basename(img_path))
            ax.axis('off')
        plt.tight_layout()
        plt.show()
        plt.close(fig)

    def detect_redundant(self):
        """Group TIFFs in the selected folder by perceptual hash and report duplicates.

        In each group of identical hashes, the first file in name order is reported as kept
        and the rest as removed. Files are NOT deleted from disk.
        """
        folder = self._selected_folder()
        if folder is None:
            self.log_text.set("Please select a folder.")
            return

        # Sort so that the "kept" image of each group is deterministic.
        image_hashes = {}
        for filename in sorted(os.listdir(folder)):
            if not filename.lower().endswith(self.IMAGE_EXTENSIONS):
                continue
            image_path = os.path.join(folder, filename)
            try:
                img = Image.fromarray(tf.imread(image_path))
                image_hashes.setdefault(str(imagehash.phash(img)), []).append(image_path)
            except Exception as e:
                self.display_message(f"Error processing {filename}: {e}")

        duplicate_groups = [paths for paths in image_hashes.values() if len(paths) > 1]
        if not duplicate_groups:
            self.log_text.set("No redundant images detected in your folder.")
            self.removed_images_label.config(text="")
            return

        kept_images = []
        removed_images = []
        for paths in duplicate_groups:
            self.plot_images(paths)
            kept_images.append(paths[0])
            removed_images.extend(paths[1:])
            # NOTE(review): deletion from disk is intentionally disabled; to enable it,
            # call os.remove on each path in paths[1:].

        self.log_text.set("Redundant images detected and removed.")
        removed_images_text = "\n".join(removed_images)
        kept_images_text = "\n".join(kept_images)
        self.removed_images_label.config(
            text=f"Removed Images:\n{removed_images_text}\nKept Images:\n{kept_images_text}")
    

if __name__ == "__main__":
    # Create a fresh Tk root, attach the folder-level app to it and block in the event loop.
    root = tk.Tk()
    app = ImageQualityAssessmentApp(root)
    app.root.mainloop()
