# English version

# In [11]:
import os
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter import ttk
from PIL import Image, ImageTk
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from torchvision import transforms

# Main application window: title and initial size (width x height in pixels)
root = tk.Tk()
root.wm_title("Damage States Prediction GUI")
root.wm_geometry("950x600")

# ttk look & feel: the built-in 'clam' theme with custom button, label and frame styles
style = ttk.Style()
style.theme_use('clam')
style.configure(
    "TButton",
    background="#ccc",
    relief="flat",
    padding=6,
)
style.configure(
    "TLabel",
    font=("Times new roman", 12),
    background="#eee",
    padding=6,
)
style.configure(
    "TFrame",
    background="#eee",
)

# Compute device: the first CUDA GPU when one is available, otherwise the CPU
device = (
    torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
)

# Module-level state shared by the GUI callbacks
model = None  # PyTorch model, assigned by the "Select Model File" callback
test_images, test_image_paths = [], []  # preprocessed test tensors / their file paths
idx_to_labels = {}  # class index -> class name, read from the mapping file
img_tk = None  # keeps the background PhotoImage referenced so it stays on screen

# Test-time preprocessing: resize to 224x224, convert to a tensor and
# normalise with the per-channel mean/std values below
test_transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Main drawing surface holding the images and the control panel; it is
# packed to expand and fill the window in both directions.
canvas = tk.Canvas(root, height=500, width=900)
canvas.pack(expand=True, fill=tk.BOTH)

# 在底部加载并显示背景图像
def load_bottom_image():
    """Draw the optional bottom banner image ``S.png`` onto ``canvas``.

    The file is looked up in the current working directory and silently
    skipped when absent. It is resized to 950x600 and anchored by its
    top-left corner at (0, 500). The resulting ``PhotoImage`` is kept in the
    global ``img_tk_bottom`` so it stays referenced (Tkinter images vanish
    once their Python object is garbage-collected).
    """
    global img_tk_bottom
    # os.path.join instead of the Windows-only r".\S.png" literal, so the
    # lookup also succeeds on Linux/macOS (same path on Windows).
    img_path_bottom = os.path.join(".", "S.png")
    if not os.path.exists(img_path_bottom):
        return
    # Context manager closes the source file; resize() returns a new image.
    with Image.open(img_path_bottom) as img_bottom:
        resized = img_bottom.resize((950, 600), Image.LANCZOS)
    img_tk_bottom = ImageTk.PhotoImage(resized)
    # NOTE(review): the window is 600 px tall, so a 600 px image placed at
    # y=500 shows only its top strip — confirm this is the intended layout.
    canvas.create_image(0, 500, anchor=tk.NW, image=img_tk_bottom)

load_bottom_image()

# 加载背景图像
def load_background_image():
    """Draw the CC-ViT model diagram and the author credits onto ``canvas``.

    ``CC-Vit.png`` (looked up in the current working directory) is resized to
    400x500, placed in the right-hand area and captioned "CC-ViT Model"; it is
    skipped when the file is absent. The author credits are drawn in the
    lower-left area regardless of whether the diagram was found.
    """
    global img_tk  # keep the PhotoImage referenced so Tk keeps displaying it
    # os.path.join instead of the Windows-only r".\CC-Vit.png" literal.
    img_path = os.path.join(".", "CC-Vit.png")
    if os.path.exists(img_path):
        with Image.open(img_path) as img:
            resized = img.resize((400, 500), Image.LANCZOS)
        img_tk = ImageTk.PhotoImage(resized)
        canvas.create_image(500, 50, anchor=tk.NW, image=img_tk)  # right-hand area

        # Caption drawn just above the diagram
        canvas.create_text(720, 40, text="CC-ViT Model", font=("Arial", 14, "bold"), fill="black")

    # Author credits are unrelated to the diagram, so they are always shown
    # (previously they disappeared whenever CC-Vit.png was missing).
    credits = (
        (440, "Author: Y. Li, Z. Sun, Q. Li, et al."),
        (460, "Email: tumu16lyl@163.com"),
        (480, "Southeast University, Jiangsu, China"),
    )
    for y, text in credits:
        canvas.create_text(10, y, text=text, font=("Arial", 10), fill="black", anchor=tk.W)

load_background_image()

# 使用 Canvas 放置控件
def create_control_frame():
    """Build the left-hand control panel and wire up all GUI callbacks.

    Places a white frame on ``canvas`` with buttons to choose the test-set
    folder, the model file and the class-mapping file, a button that runs
    batch prediction and exports a CSV, a help button, a progress bar and a
    read-only log area. A status bar is packed along the bottom of ``root``.

    The callbacks communicate through the module-level globals ``model``,
    ``test_images``, ``test_image_paths`` and ``idx_to_labels``.
    """
    control_frame = tk.Frame(canvas, bg="#ffffff", bd=5)
    control_frame.place(x=10, y=10, width=450, height=400)  # tall enough to fit the progress bar

    # ---- Test-set folder -------------------------------------------------
    def select_folder():
        """Ask for a dataset folder and load every image under its ``val`` subfolder.

        Only images that open and preprocess successfully are recorded, so
        ``test_image_paths`` and ``test_images`` stay index-aligned and
        ``predict`` never meets an unreadable file.
        """
        global test_images, test_image_paths
        dataset_dir = filedialog.askdirectory()
        if not dataset_dir:
            return
        val_dir = os.path.join(dataset_dir, 'val')

        # The images must live in a 'val' subdirectory
        if not os.path.isdir(val_dir):
            messagebox.showerror("Error", "No 'val' folder found in the specified directory")
            return

        test_images = []
        test_image_paths = []
        valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.gif')

        # Walk val/ recursively. `dirpath` (not `root`) avoids shadowing the Tk root window.
        # NOTE: test_images keeps every preprocessed tensor in memory, while
        # predict() re-reads the images from disk and does not use it.
        for dirpath, _, files in os.walk(val_dir):
            for file in files:
                if not file.lower().endswith(valid_extensions):
                    continue
                img_path = os.path.join(dirpath, file)
                try:
                    with Image.open(img_path) as img:
                        tensor = test_transform(img.convert('RGB'))
                except Exception as e:  # corrupt/unsupported file: skip it, but report it
                    log_message(f"Cannot process file {img_path}: {e}")
                    continue
                # Record the path only on success so a bad file cannot break predict()
                test_image_paths.append(img_path)
                test_images.append(tensor)

        folder_label.config(text=f"Test Set Folder: {dataset_dir}")
        update_status(f"Test set folder imported successfully, number of images: {len(test_images)}")
        log_message(f"Number of test images: {len(test_images)}")

    # ---- Model file ------------------------------------------------------
    def load_full_model(path):
        """Unpickle a complete ``nn.Module`` saved with ``torch.save(model)``.

        SECURITY: a full-model .pth file is a pickle; loading it can execute
        arbitrary code, so only open files from trusted sources. The model's
        class definition must also be importable in this process.
        """
        try:
            # PyTorch >= 2.6 defaults to weights_only=True, which rejects pickled
            # nn.Module objects, so full unpickling has to be requested explicitly.
            return torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # PyTorch < 1.13 has no `weights_only` parameter.
            return torch.load(path, map_location=device)

    def select_model():
        """Ask for a ``.pth`` file, load it and store it as the global ``model``."""
        global model
        model_weight_path = filedialog.askopenfilename(filetypes=[("PyTorch模型文件", "*.pth")])
        if not model_weight_path:
            return
        try:
            loaded = load_full_model(model_weight_path)
        except Exception as e:
            messagebox.showerror("Error", f"Failed to load model: {e}")
            log_message(f"Failed to load model {model_weight_path}: {e}")
            return
        if not isinstance(loaded, torch.nn.Module):
            # e.g. a bare state_dict: it contains weights but no runnable architecture
            messagebox.showerror(
                "Error",
                "The selected file does not contain a complete PyTorch model (e.g. it is only a state_dict)",
            )
            return
        model = loaded.eval().to(device)
        model_label.config(text=f"Model File: {model_weight_path}")
        update_status("Model loaded successfully!")
        log_message("Model loaded successfully")

    # ---- Class-mapping file ----------------------------------------------
    def select_idx_to_labels():
        """Ask for a text file of ``<index>: <label>`` lines and store it in ``idx_to_labels``.

        Blank lines are ignored and only the first ``': '`` separates index
        from label. The global is replaced only when the whole file parses.
        """
        global idx_to_labels
        idx_to_labels_path = filedialog.askopenfilename(filetypes=[("Class Mapping Files", "*.txt")])
        if not idx_to_labels_path:
            return
        mapping = {}
        try:
            with open(idx_to_labels_path, 'r') as f:
                for line_no, line in enumerate(f, start=1):
                    line = line.strip()
                    if not line:
                        continue  # tolerate blank lines such as a trailing newline
                    key, sep, value = line.partition(': ')
                    if not sep:
                        raise ValueError(f"line {line_no}: expected '<index>: <label>'")
                    mapping[int(key)] = value
        except (OSError, ValueError) as e:  # ValueError also covers UnicodeDecodeError
            messagebox.showerror("Error", f"Failed to load class mapping file: {e}")
            log_message(f"Failed to load class mapping file {idx_to_labels_path}: {e}")
            return
        idx_to_labels = mapping

        labels_label.config(text=f"Class Mapping File: {idx_to_labels_path}")
        update_status("Class mapping file loaded successfully!")
        log_message("Class mapping file loaded successfully")

    # ---- Prediction ------------------------------------------------------
    def predict_single(img_path, n):
        """Run ``model`` on one image and return its top-n predictions as one CSV row dict."""
        with Image.open(img_path) as img_pil:
            input_img = test_transform(img_pil.convert('RGB')).unsqueeze(0).to(device)
        pred_logits = model(input_img)  # presumably shape (1, num_classes) — the [0] indexing below assumes so
        pred_softmax = F.softmax(pred_logits, dim=1)
        k = min(n, pred_softmax.shape[1])  # topk fails if the model has fewer than n classes
        top_conf, top_ids = torch.topk(pred_softmax, k)

        pred_dict = {}
        for i, (pred_id, conf) in enumerate(zip(top_ids[0].tolist(), top_conf[0]), start=1):
            if pred_id not in idx_to_labels:
                raise ValueError(f"class index {pred_id} is missing from the class mapping file")
            pred_dict[f'top-{i}-Prediction ID'] = pred_id
            pred_dict[f'top-{i}-Prediction Name'] = idx_to_labels[pred_id]
            pred_dict[f'top-{i}-Prediction Confidence'] = conf.detach().cpu().numpy()
        return pred_dict

    def predict():
        """Predict every test image and export paths plus top-3 results to a CSV file."""
        if model is None or not test_image_paths or not idx_to_labels:
            messagebox.showwarning("Warning", "Please ensure all required files are loaded")
            return

        img_paths = list(test_image_paths)
        n = 3  # number of top classes reported per image
        total_images = len(img_paths)
        df_pred_list = []
        progress_var.set(0)

        # Inference only: no_grad avoids building an autograd graph for every image.
        try:
            with torch.no_grad():
                for idx, img_path in enumerate(img_paths, start=1):
                    df_pred_list.append(predict_single(img_path, n))
                    update_progress_bar(idx, total_images)
        except Exception as e:
            messagebox.showerror("Error", f"Prediction failed on {img_path}: {e}")
            log_message(f"Prediction failed on {img_path}: {e}")
            return

        # Table A (image paths) side by side with table B (top-n predictions and confidences)
        df = pd.concat(
            [pd.DataFrame({'Image_Path': img_paths}), pd.DataFrame(df_pred_list)],
            axis=1,
        )

        save_path = filedialog.asksaveasfilename(defaultextension=".csv", filetypes=[("CSV文件", "*.csv")])
        if save_path:
            df.to_csv(save_path, index=False, encoding='utf-8-sig')
            update_status("Prediction results saved")
            log_message("Prediction results saved")

    # ---- Widgets ---------------------------------------------------------
    folder_btn = ttk.Button(control_frame, text="Select Test Set Folder", command=select_folder)
    folder_btn.grid(row=0, column=0, padx=10, pady=5, sticky='w')

    folder_label = ttk.Label(control_frame, text="No Test Set Folder Selected")
    folder_label.grid(row=0, column=1, padx=10, pady=5, sticky='w')

    model_btn = ttk.Button(control_frame, text="Select Model File", command=select_model)
    model_btn.grid(row=1, column=0, padx=10, pady=5, sticky='w')

    model_label = ttk.Label(control_frame, text="No Model File Selected")
    model_label.grid(row=1, column=1, padx=10, pady=5, sticky='w')

    labels_btn = ttk.Button(control_frame, text="Select Class Mapping File", command=select_idx_to_labels)
    labels_btn.grid(row=2, column=0, padx=10, pady=5, sticky='w')

    labels_label = ttk.Label(control_frame, text="No Class Mapping File Selected")
    labels_label.grid(row=2, column=1, padx=10, pady=5, sticky='w')

    def confirm_action(message, action):
        """Run ``action`` only if the user confirms ``message`` in an OK/Cancel dialog."""
        if messagebox.askokcancel("Confirm", message):
            action()

    predict_btn = ttk.Button(control_frame, text="Start Prediction", command=lambda: confirm_action("Sure you want to start predicting?", predict))
    predict_btn.grid(row=3, column=0, padx=10, pady=10)

    def show_help():
        """Show the usage steps in an info dialog."""
        messagebox.showinfo("Help", "1. Select the test set folder. \n2. Selecting a model file. \n3. Select the category mapping file. \n4. Click on the 'Start Prediction' button to make a prediction.")

    help_btn = ttk.Button(control_frame, text="Help", command=show_help)
    help_btn.grid(row=4, column=0, padx=10, pady=10)

    # Read-only log area; log_message temporarily enables it to append a line
    log_text = tk.Text(control_frame, height=10, width=50, state=tk.DISABLED)
    log_text.grid(row=6, column=0, padx=0, pady=10, columnspan=2)

    def log_message(message):
        """Append ``message`` to the log area and scroll to the end."""
        log_text.config(state=tk.NORMAL)
        log_text.insert(tk.END, message + "\n")
        log_text.config(state=tk.DISABLED)
        log_text.yview(tk.END)

    # Status bar along the bottom of the main window
    status_var = tk.StringVar()
    status_var.set("Ready to go")
    status_bar = tk.Label(root, textvariable=status_var, anchor=tk.W, relief=tk.SUNKEN, height=2)
    status_bar.pack(side=tk.BOTTOM, fill=tk.X)

    def update_status(message):
        """Replace the status-bar text."""
        status_var.set(message)

    # Progress bar (0-100 %) driven by predict()
    progress_var = tk.DoubleVar()
    progress_bar = ttk.Progressbar(control_frame, variable=progress_var, maximum=100)
    progress_bar.grid(row=5, column=0, padx=0, pady=10, columnspan=1, sticky='ew')

    def update_progress_bar(current, total):
        """Set the progress bar to current/total and redraw without entering the event loop."""
        progress_var.set((current / total) * 100)
        root.update_idletasks()  # predict() runs synchronously, so force a redraw

create_control_frame()

root.mainloop()
