# In [3]:
import tkinter as tk
from tkinter import messagebox
from tkinter.ttk import Button
from tkinter.filedialog import askopenfilename
from PIL import Image, ImageTk, UnidentifiedImageError
import torch as th
import torch.nn as nn
from torchvision import models, transforms
import numpy as np
import matplotlib.pyplot as plt

# Set up device: the default CUDA device when one is available, else the CPU.
device = th.device("cuda" if th.cuda.is_available() else "cpu")

# Class names, indexed by the model's output position.
# NOTE(review): this order must match the class-index order used when the
# checkpoint was trained — confirm against the training script.
class_names = [
    'battery', 'biological', 'cardboard', 'clothes', 'glass',
    'metal', 'paper', 'plastic', 'shoes', 'trash',
]
num_classes = len(class_names)

# ResNet-50 architecture without downloaded weights (the fine-tuned weights
# come from the checkpoint below). The final fully connected layer is
# replaced so it outputs one score per entry in class_names.
# `weights=None` is the replacement for the deprecated `pretrained=False`
# (torchvision >= 0.13).
model = models.resnet50(weights=None)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, num_classes)
# weights_only=True restricts unpickling to tensors and plain containers, so
# loading a tampered checkpoint cannot execute arbitrary code (torch >= 1.13).
state_dict = th.load(
    "garbage_classification_model.pth", map_location=device, weights_only=True
)
model.load_state_dict(state_dict)
model.to(device)
model.eval()  # inference mode: disables dropout, BatchNorm uses running stats

# Preprocessing applied to every image before prediction:
#   1. resize to a fixed 224x224 input (aspect ratio is not preserved),
#   2. convert the PIL image to a float tensor (scaled to [0, 1] for 8-bit
#      images),
#   3. normalize each RGB channel with the commonly used ImageNet mean/std.
# NOTE(review): these steps should match the evaluation transforms used when
# the checkpoint was trained — confirm against the training code.
_CHANNEL_MEAN = [0.485, 0.456, 0.406]
_CHANNEL_STD = [0.229, 0.224, 0.225]

data_transforms = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
    ]
)

def predict_image(pil_image):
    """Classify a PIL image and return the predicted class name.

    The image is converted to RGB when it is in another mode (e.g. RGBA,
    grayscale or palette images). It is then preprocessed with
    ``data_transforms`` and passed through ``model`` on ``device`` with
    gradient tracking disabled.

    Args:
        pil_image: A ``PIL.Image.Image`` in any mode.

    Returns:
        The entry of ``class_names`` at the index of the highest output score.
    """
    # The Normalize step is configured for exactly three channels; without
    # this conversion a 4-channel or 1-channel image fails inside the transform.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")

    # unsqueeze(0) adds the batch dimension the model expects.
    image_tensor = data_transforms(pil_image).unsqueeze(0).to(device)
    with th.no_grad():
        scores = model(image_tensor)
    # argmax over the class dimension (first maximum on ties, as th.max does).
    predicted_index = scores.argmax(dim=1).item()
    return class_names[predicted_index]

class ImageUploader:
    """Tkinter window that lets the user pick an image and shows its predicted class.

    If no ``master`` is given, the class creates its own ``tk.Tk`` root and
    enters the Tk main loop inside ``__init__``, so construction blocks until
    the window is closed. When a ``master`` is supplied, the caller is
    responsible for running the main loop.
    """

    # Canvas size in pixels; the window is taller to leave room for the button.
    CANVAS_WIDTH = 600
    CANVAS_HEIGHT = 500
    WINDOW_SIZE = "600x600"
    # Bounding box, in pixels, for the preview drawn on the canvas.
    PREVIEW_SIZE = (224, 224)

    def __init__(self, master=None):
        """Build the canvas and upload button; run the main loop if we own the root."""
        self.master = master if master is not None else tk.Tk()
        self.master.title("Image Classifier")
        self.master.geometry(self.WINDOW_SIZE)

        self.canvas = tk.Canvas(
            self.master,
            width=self.CANVAS_WIDTH,
            height=self.CANVAS_HEIGHT,
            bg="white",
        )
        self.canvas.pack()

        self.btn = Button(self.master, text="Upload Image", command=self.upload_image)
        self.btn.pack()

        if master is None:
            self.master.mainloop()

    def upload_image(self):
        """Ask for an image file, classify it, draw a preview and show the result.

        Cancelling the file dialog does nothing. Files Pillow cannot identify,
        or that cannot be read, are reported in an error dialog instead of
        raising out of the Tk callback.
        """
        img_path = askopenfilename()
        if not img_path:
            return  # dialog cancelled

        try:
            # The context manager releases the file handle; Pillow can keep it
            # open after loading for multi-frame formats.
            with Image.open(img_path) as src:
                im = src.convert("RGB")
        except UnidentifiedImageError:
            messagebox.showerror("Error", "Invalid file type")
            return
        except OSError as exc:
            # Unreadable file: permissions, a directory, truncated data, ...
            messagebox.showerror("Error", f"Could not open image: {exc}")
            return

        predicted_class = predict_image(im)

        # thumbnail() only shrinks and keeps the aspect ratio, so it leaves
        # images that already fit within PREVIEW_SIZE unchanged.
        im.thumbnail(self.PREVIEW_SIZE)
        img = ImageTk.PhotoImage(im)

        # Clear the previous preview; otherwise a smaller new image is drawn
        # on top of it and the old one remains visible around the edges.
        self.canvas.delete("all")
        # Keep a Python reference so the PhotoImage is not garbage-collected
        # while the canvas still displays it.
        self.canvas.image = img
        self.canvas.create_image(
            self.CANVAS_WIDTH // 2,
            self.CANVAS_HEIGHT // 2,
            image=img,
            anchor=tk.CENTER,
        )

        messagebox.showinfo("Prediction", f"Predicted Class: {predicted_class}")

def main():
    """Entry point: open the image-classifier window.

    ``ImageUploader`` is constructed without a master, so it creates its own
    Tk root and runs the Tk main loop itself; this call returns once that
    loop exits.
    """
    ImageUploader(master=None)


if __name__ == "__main__":
    main()
