In [1]:
#Main.ipynb this is the GUI code 

import tkinter as tk
from tkinter import filedialog
from PIL import Image, ImageTk
import cv2
import os
import time
import numpy as np
import import_ipynb
import matplotlib.pyplot as plt
import pandas as pd
import joblib
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix
from sklearn.preprocessing import StandardScaler

# Importing functions from the separate Jupyter notebooks
import model_implementation
import preprocessing

# Load the CSV files
feature_train = pd.read_csv('FeatureTrain.csv')
feature_test = pd.read_csv('FeatureTest.csv')

# Separate features and labels
X_train = feature_train.drop(columns='Class').values
y_train = feature_train['Class'].values
X_test = feature_test.drop(columns='Class').values
y_test = feature_test['Class'].values

# the training data was scaled, scale the test data in the same way
scaler = StandardScaler().fit(X_train)  
X_test_scaled = scaler.transform(X_test)

def get_specimen_name(image_path):
    # Extract the directory name that contains the species name
    directory_name = os.path.basename(os.path.dirname(image_path))
    
    # The directory name is in the format '12. Celtis sp'
    # Split and get the species name part
    species_name = ' '.join(directory_name.split()[1:])  # This will get 'Celtis sp'
    
    return species_name

class PlantSpeciesClassifierGUI:
    def __init__(self, master):
        self.master = master
        self.master.title("Plant Species Classification")
        self.master.geometry("1200x800")  # Set fixed window size
        self.master.resizable(False, False)  # Make window non-resizable

        # Frame for buttons at the top
        self.top_frame = tk.Frame(master)
        self.top_frame.pack(side=tk.TOP, fill=tk.X, padx=10, pady=10)

        self.upload_button = tk.Button(self.top_frame, text="Browse Input Image", command=self.upload_image)
        self.upload_button.pack(side=tk.LEFT, padx=5)

        self.preprocess_button = tk.Button(self.top_frame, text="Pre-processing", command=self.preprocess_image)
        self.preprocess_button.pack(side=tk.LEFT, padx=5)

        self.classify_button = tk.Button(self.top_frame, text="Classification", command=self.classify_image)
        self.classify_button.pack(side=tk.LEFT, padx=5)

        self.clear_button = tk.Button(self.top_frame, text="Clear All", command=self.clear_all)
        self.clear_button.pack(side=tk.LEFT, padx=5)

        self.exit_button = tk.Button(self.top_frame, text="Exit", command=self.exit_application)
        self.exit_button.pack(side=tk.LEFT, padx=5)

        # Frame for model selection on the left
        self.left_frame = tk.LabelFrame(master, text="Model Selection")
        self.left_frame.pack(side=tk.LEFT, fill=tk.Y, padx=5, pady=5)

        self.model_var = tk.StringVar(value="")  # No default model selected

        self.rf_button = tk.Radiobutton(self.left_frame, text="Random Forest", variable=self.model_var, value="rf_model.pkl")
        self.rf_button.pack(anchor=tk.W, pady=2)

        self.lr_button = tk.Radiobutton(self.left_frame, text="Logistic Regression", variable=self.model_var, value="logistic_model.pkl")
        self.lr_button.pack(anchor=tk.W, pady=2)

        self.svm_button = tk.Radiobutton(self.left_frame, text="SVM", variable=self.model_var, value="svm_model.pkl")
        self.svm_button.pack(anchor=tk.W, pady=2)
        
        # Warning label for model selection
        self.warning_label = tk.Label(master, text="", fg="red", font=('Helvetica', 16))
        self.warning_label.pack(pady=10)
        
        # Frame for images in the center
        self.image_frame = tk.Frame(master)
        self.image_frame.pack(side=tk.LEFT, expand=True, padx=10, pady=10)

        self.labels = []
        self.images = []

        for i in range(6):
            label_frame = tk.Frame(self.image_frame, width=200, height=200, bg='gray')
            label_frame.grid_propagate(False)  # Prevent frame from resizing to fit content
            label_frame.grid(row=i//3, column=i%3, padx=5, pady=5)
            label = tk.Label(label_frame)
            label.pack(fill=tk.BOTH, expand=True)
            self.labels.append(label)

        self.species_label = tk.Label(self.image_frame, text="Detected Species:", font=('Helvetica', 16))
        self.species_label.grid(row=2, column=0, columnspan=3, pady=10)

        self.result_label = tk.Label(master, text="", font=('Helvetica', 16))
        self.result_label.pack()

        # Frame for metrics on the right in tabular format
        self.metrics_frame = tk.LabelFrame(master, text="Results", font=('Helvetica', 14, 'bold'))
        self.metrics_frame.pack(side=tk.RIGHT, fill=tk.Y, padx=10, pady=10)

        headers = ["Metric", "Value"]
        for i, header in enumerate(headers):
            label = tk.Label(self.metrics_frame, text=header, font=('Helvetica', 12, 'bold'))
            label.grid(row=0, column=i, padx=5, pady=5)

        metrics = ["Accuracy", "Precision", "Recall", "F1 Score", "Specificity", "Elapsed Time (s)"]
        self.metric_labels = {}

        for i, metric in enumerate(metrics):
            metric_label = tk.Label(self.metrics_frame, text=f"{metric}:", font=('Helvetica', 12), anchor='w')
            metric_label.grid(row=i+1, column=0, padx=5, pady=2, sticky='w')
            self.metric_labels[metric] = tk.Label(self.metrics_frame, text="", font=('Helvetica', 12), anchor='w')
            self.metric_labels[metric].grid(row=i+1, column=1, padx=5, pady=2, sticky='w')

        self.filepath = None
        
    def upload_image(self):
        self.filepath = filedialog.askopenfilename(filetypes=[("Image files", "*.jpg;*.jpeg;*.png")])
        if not self.filepath:
            return
        img = Image.open(self.filepath)
        img = img.resize((200, 200), Image.LANCZOS)  # Increase image size
        self.images.append(ImageTk.PhotoImage(img))
        self.labels[0].configure(image=self.images[-1], text="Input Image", compound='top')
        
        # Extract specimen name directly from the file path
        self.specimen_name = get_specimen_name(self.filepath)
    
    def preprocess_image(self):
        if not self.filepath:
            return
        original_image, enhanced_img, gray_img, binary_img, sobel_combined, canny_edges = preprocessing.preprocess_image(self.filepath)
        
        self.display_image(enhanced_img, 1, "Enhancement")
        self.display_image(gray_img, 2, "Gray Conversion")
        self.display_image(binary_img, 3, "Binarization")
        self.display_image(sobel_combined, 4, "Sobel Edges")
        self.display_image(canny_edges, 5, "Canny Edges")
        self.binary_img = binary_img
        self.segmented_image = canny_edges
    
    def display_image(self, image, index, text):
        if image.dtype == np.float64:
            image = cv2.convertScaleAbs(image, alpha=(255.0/image.max()))
        if len(image.shape) == 2:  # if grayscale, convert to RGB for display
            image = cv2.cvtColor(image, cv2.COLOR_GRAY2RGB)
        img = Image.fromarray(image)
        img = img.resize((250, 250), Image.LANCZOS)  # Increase image size
        self.images.append(ImageTk.PhotoImage(img))
        self.labels[index].configure(image=self.images[-1], text=text, compound='top')
        
    def classify_image(self):
        # Ensure an image file and segmented image are available
        if not self.filepath or not hasattr(self, 'segmented_image'):
            self.warning_label.config(text="WARNNING!!! Please upload and preprocess an image first.")
            return
    
        # Ensure a model is selected
        if not self.model_var.get():
            self.warning_label.config(text="WARNNING!!! Please select a model before classification.")
            return
        else:
            self.warning_label.config(text="")  # Clear warning text if model is selected
    
        start_time = time.time()
        model_path = self.model_var.get()
    
        # Attempt to load the selected model
        try:
            model = joblib.load(model_path)
        except Exception as e:
            self.warning_label.config(text=f"Failed to load model: {e}")
            return
    
        # Extract features from the segmented image and predict the species
        features = model_implementation.extract_features(self.segmented_image)
        print(f"Features shape: {features.shape}")
        if len(features.shape) == 1:
            features = features.reshape(1, -1)
    
        try:
            species = model_implementation.predict_species(model, features)
        except Exception as e:
            self.warning_label.config(text=f"Error during prediction: {e}")
            return
    
        elapsed_time = time.time() - start_time
        self.metric_labels["Elapsed Time (s)"].configure(text=f"{elapsed_time:.4f}s")
    
        # Display the name and detected species
        self.species_label.configure(text=f"Species Name: {self.specimen_name}, Detected Species: {species}")
    
        # Calculate and display performance metrics
        try:
            accuracy, precision, recall, specificity, f1, _ = model_implementation.calculate_metrics(model, X_test_scaled, y_test)
            self.metric_labels["Accuracy"].configure(text=f"{accuracy:.3f}")
            self.metric_labels["Precision"].configure(text=f"{precision:.3f}")
            self.metric_labels["Recall"].configure(text=f"{recall:.3f}")
            self.metric_labels["F1 Score"].configure(text=f"{f1:.3f}")
            self.metric_labels["Specificity"].configure(text=f"{specificity:.3f}")
        except Exception as e:
            self.warning_label.config(text=f"Failed to calculate metrics: {e}")
            return
    
        # Clear the previous results to avoid duplication
        self.result_label.configure(text="")
    
    def clear_all(self):
        for label in self.labels:
            label.configure(image='', text='')
        self.species_label.configure(text="Detected Species:")
        
        for key in self.metric_labels:
            self.metric_labels[key].configure(text="")
        self.filepath = None
        self.images.clear()
        self.warning_label.config(text="")  # Clear the warning text when clearing all inputs
    
    def exit_application(self):
        self.master.destroy()
    
    # Start the GUI
root = tk.Tk()
app = PlantSpeciesClassifierGUI(root)
root.mainloop()


importing Jupyter notebook from model_implementation.ipynb
importing Jupyter notebook from preprocessing.ipynb
Reading image from: D:/Classification and Regrassion/RGB/1. Quercus suber/iPAD2_C01_EX01.JPG
Features shape: (16,)
Extracted features: [[0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0]] RandomForestClassifier(min_samples_split=10, n_estimators=200, random_state=42)
Predicted species: 22
