# V1

In [3]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
from sklearn.preprocessing import LabelEncoder
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform
# 配置日志记录
logging.basicConfig(filename='crop_classification.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

fiona.supported_drivers['ESRI Shapefile'] = 'rw'
gpd.io.file.fiona.drvsupport.supported_drivers['ESRI Shapefile'] = 'rw'

def extract_features(image_path, gdf, progress_callback=None):
    features = []
    valid_geometries = []
    total_geometries = len(gdf)

    with rasterio.open(image_path) as src:
        for idx, geometry in enumerate(gdf.geometry):
            try:
                minx, miny, maxx, maxy = geometry.bounds
                window = src.window(minx, miny, maxx, maxy)
                masked_image = src.read(window=window, indexes=[1, 2, 3])
                
                if masked_image.shape[0] < 3 or masked_image.size == 0:
                    logging.warning(f"几何体 {idx} 无有效数据，跳过")
                    print(f"警告：几何体 {idx} 无有效数据，跳过")
                    continue
                
                rgb_means = np.nanmean(masked_image, axis=(1, 2))
                rgb_stds = np.nanstd(masked_image, axis=(1, 2))
                
                if np.isnan(rgb_means).any() or np.isnan(rgb_stds).any():
                    logging.warning(f"几何体 {idx} 包含 NaN 值，跳过")
                    print(f"警告：几何体 {idx} 包含 NaN 值，跳过")
                    continue
                
                r, g, b = rgb_means
                exg = 2 * g - r - b
                vari = (g - r) / (g + r - b + 1e-8)
                
                hsv_image = rgb2hsv(np.moveaxis(masked_image, 0, -1))
                hsv_means = np.nanmean(hsv_image, axis=(0, 1))
                
                green_channel = masked_image[1].astype(np.uint8)
                if green_channel.size > 0:
                    glcm = graycomatrix(green_channel, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
                    contrast = graycoprops(glcm, 'contrast')[0, 0]
                    dissimilarity = graycoprops(glcm, 'dissimilarity')[0, 0]
                    homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
                    energy = graycoprops(glcm, 'energy')[0, 0]
                    correlation = graycoprops(glcm, 'correlation')[0, 0]
                else:
                    logging.warning(f"几何体 {idx} 的绿色通道为空，使用默认纹理特征")
                    contrast = dissimilarity = homogeneity = energy = correlation = 0
                
                feature = np.concatenate([rgb_means, rgb_stds, [exg, vari], hsv_means, 
                                          [contrast, dissimilarity, homogeneity, energy, correlation]])
                features.append(feature)
                valid_geometries.append(idx)

                if progress_callback:
                    progress_callback((idx + 1) / total_geometries * 100)
            except Exception as e:
                logging.error(f"处理几何体 {idx} 时出错: {str(e)}")
    
    if not features:
        raise ValueError("没有成功提取到任何特征")
    
    return np.array(features), valid_geometries

def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, force_new_model=False, progress_callback=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    X_new, valid_indices = extract_features(image_path, gdf_train, progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        X_old, y_old = joblib.load(data_output_path)
        
        if X_old.shape[1] != X_new.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new.shape[1]}")
            logging.info("将重新训练模型")
            clf = RandomForestClassifier(n_estimators=100, random_state=42)
            X_combined, y_combined = X_new, y_new
        else:
            X_combined = np.vstack((X_old, X_new))
            y_combined = np.concatenate((y_old, y_new))
        
        clf.fit(X_combined, y_combined)
    else:
        logging.info("创建新模型")
        clf = RandomForestClassifier(n_estimators=100, random_state=42)
        clf.fit(X_new, y_new)
        X_combined, y_combined = X_new, y_new

    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")

def predict_new_data(model_path, le_path, new_image_path, new_shp_path, output_shp_path, progress_callback=None):
    logging.info("开始预测新数据")
    clf = joblib.load(model_path)
    le = joblib.load(le_path)

    gdf_new = gpd.read_file(new_shp_path, encoding='utf-8')
    X_new, valid_indices = extract_features(new_image_path, gdf_new, progress_callback)
    gdf_new = gdf_new.iloc[valid_indices]
    
    if X_new.shape[1] != clf.n_features_in_:
        raise ValueError(f"错误：特征数量不匹配。模型期望 {clf.n_features_in_} 个特征，但提供了 {X_new.shape[1]} 个特征。")
    
    y_pred = clf.predict(X_new)
    y_proba = clf.predict_proba(X_new)
    gdf_new['ZZZW'] = le.inverse_transform(y_pred)
    gdf_new['ZZZW_proba'] = np.max(y_proba, axis=1)
    gdf_new.to_file(output_shp_path, encoding='utf-8')

    logging.info(f"预测结果已保存到: {output_shp_path}")

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("高级作物分类模型")
        master.geometry("800x600")

        self.create_widgets()
        self.create_menu()

        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="预测")
        self.notebook.add(self.info_tab, text="信息")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    def create_menu(self):
        menubar = tk.Menu(self.master)
        self.master.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="文件", menu=file_menu)
        file_menu.add_command(label="退出", command=self.master.quit)

        help_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="帮助", menu=help_menu)
        help_menu.add_command(label="关于", command=self.show_about)

    def setup_train_tab(self):
        ttk.Label(self.train_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.tif_path = tk.StringVar()
        ttk.Entry(self.train_tab, textvariable=self.tif_path, width=50).grid(row=0, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)

        ttk.Label(self.train_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.shp_path = tk.StringVar()
        ttk.Entry(self.train_tab, textvariable=self.shp_path, width=50).grid(row=1, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)

        ttk.Label(self.train_tab, text="模型存储路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.train_tab, textvariable=self.model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=self.browse_model_path).grid(row=2, column=2, padx=5, pady=5)

        self.model_action = tk.StringVar(value="update")
        ttk.Radiobutton(self.train_tab, text="更新现有模型", variable=self.model_action, value="update").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        ttk.Radiobutton(self.train_tab, text="创建新模型", variable=self.model_action, value="new").grid(row=3, column=1, sticky="w", padx=5, pady=5)

        self.train_button = ttk.Button(self.train_tab, text="开始训练", command=self.start_training)
        self.train_button.grid(row=4, column=0, columnspan=3, pady=20)

    def setup_predict_tab(self):
        ttk.Label(self.predict_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.predict_tif_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.predict_tif_path, width=50).grid(row=0, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.predict_shp_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.predict_shp_path, width=50).grid(row=1, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="模型路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.predict_tab, textvariable=self.predict_model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=self.browse_predict_model_path).grid(row=2, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="输出文件路径:").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.output_path, width=50).grid(row=3, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_save_file(self.output_path, [("SHP files", "*.shp")])).grid(row=3, column=2, padx=5, pady=5)

        self.predict_button = ttk.Button(self.predict_tab, text="开始预测", command=self.start_prediction)
        self.predict_button.grid(row=4, column=0, columnspan=3, pady=20)

    def setup_info_tab(self):
        self.info_text = scrolledtext.ScrolledText(self.info_tab, wrap=tk.WORD, width=90, height=20)
        self.info_text.pack(padx=10, pady=10)

        ttk.Button(self.info_tab, text="刷新信息", command=self.refresh_info).pack(pady=10)

    def browse_file(self, path_var, file_types):
        filename = filedialog.askopenfilename(filetypes=file_types)
        if filename:
            path_var.set(filename)

    def browse_save_file(self, path_var, file_types):
        filename = filedialog.asksaveasfilename(filetypes=file_types, defaultextension=file_types[0][1])
        if filename:
            path_var.set(filename)

    def browse_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.model_path.set(path)

    def browse_predict_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.predict_model_path.set(path)

    def start_training(self):
        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"

        if not tif_file or not shp_file or not model_dir:
            messagebox.showerror("错误", "请选择TIF文件、SHP文件和模型存储路径")
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "训练中..."
        self.log_text.insert(tk.END, "开始训练...\n")
        action_text = "创建新模型" if force_new else "更新现有模型"
        self.log_text.insert(tk.END, f"操作: {action_text}\n")
        
        def training_thread():
            try:
                update_or_train_model(tif_file, shp_file, 
                                      os.path.join(model_dir, "model.joblib"), 
                                      os.path.join(model_dir, "label_encoder.joblib"), 
                                      os.path.join(model_dir, "training_data.joblib"), 
                                      force_new,
                                      self.update_progress)
                self.master.after(0, self.training_complete)
            except Exception as e:
                self.master.after(0, lambda: self.training_error(str(e)))

        threading.Thread(target=training_thread).start()

    def start_prediction(self):
        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()

        if not tif_file or not shp_file or not model_dir or not output_file:
            messagebox.showerror("错误", "请选择所有必要的文件和路径")
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "预测中..."
        self.log_text.insert(tk.END, "开始预测...\n")
        
        def prediction_thread():
            try:
                predict_new_data(os.path.join(model_dir, "model.joblib"), 
                                 os.path.join(model_dir, "label_encoder.joblib"), 
                                 tif_file, shp_file, output_file,
                                 self.update_progress)
                self.master.after(0, self.prediction_complete)
            except Exception as e:
                self.master.after(0, lambda: self.prediction_error(str(e)))

        threading.Thread(target=prediction_thread).start()

    def update_progress(self, value):
        self.progress_queue.put(value)

    def check_progress_queue(self):
        try:
            while True:
                value = self.progress_queue.get_nowait()
                self.progress['value'] = value
        except queue.Empty:
            pass
        finally:
            self.master.after(100, self.check_progress_queue)

    def training_complete(self):
        self.enable_buttons()
        self.status_bar['text'] = "训练完成"
        action_text = "创建新模型" if self.model_action.get() == "new" else "更新现有模型"
        self.log_text.insert(tk.END, f"{action_text}完成\n")
        messagebox.showinfo("成功", f"{action_text}完成")
        self.refresh_info()

    def training_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "训练出错"
        self.log_text.insert(tk.END, f"训练过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"训练过程中出错：{error_message}")

    def prediction_complete(self):
        self.enable_buttons()
        self.status_bar['text'] = "预测完成"
        self.log_text.insert(tk.END, "预测完成\n")
        messagebox.showinfo("成功", "预测完成")

    def prediction_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "预测出错"
        self.log_text.insert(tk.END, f"预测过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"预测过程中出错：{error_message}")

    def refresh_info(self):
        self.info_text.delete('1.0', tk.END)
        model_dir = self.model_path.get()
        try:
            model = joblib.load(os.path.join(model_dir, "model.joblib"))
            le = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
            
            info = f"模型信息：\n"
            info += f"模型存储路径：{model_dir}\n"
            info += f"特征数量：{model.n_features_in_}\n"
            info += f"类别：{', '.join(le.classes_)}\n"
            info += f"树的数量：{model.n_estimators}\n"
            
            self.info_text.insert(tk.END, info)
        except Exception as e:
            self.info_text.insert(tk.END, f"无法加载模型信息：{str(e)}")

    def disable_buttons(self):
        self.train_button['state'] = 'disabled'
        self.predict_button['state'] = 'disabled'

    def enable_buttons(self):
        self.train_button['state'] = 'normal'
        self.predict_button['state'] = 'normal'

    def show_about(self):
        messagebox.showinfo("关于", "高级作物分类模型 v1.0\n\n作者：AI Assistant\n\n版权所有 © 2023")

def main():
    root = tk.Tk()
    app = CropClassificationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()

ImportError: DLL load failed while importing _base: 找不到指定的模块。

# V2

In [None]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split,RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform

# 配置日志记录
logging.basicConfig(filename='crop_classification.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

fiona.supported_drivers['ESRI Shapefile'] = 'rw'
gpd.io.file.fiona.drvsupport.supported_drivers['ESRI Shapefile'] = 'rw'

def extract_features(image_path, gdf, progress_callback=None):
    features = []
    valid_geometries = []
    total_geometries = len(gdf)

    with rasterio.open(image_path) as src:
        for idx, geometry in enumerate(gdf.geometry):
            try:
                minx, miny, maxx, maxy = geometry.bounds
                window = src.window(minx, miny, maxx, maxy)
                masked_image = src.read(window=window, indexes=[1, 2, 3])
                
                if masked_image.shape[0] < 3 or masked_image.size == 0:
                    logging.warning(f"几何体 {idx} 无有效数据，跳过")
                    print(f"警告：几何体 {idx} 无有效数据，跳过")
                    continue
                
                rgb_means = np.nanmean(masked_image, axis=(1, 2))
                rgb_stds = np.nanstd(masked_image, axis=(1, 2))
                
                if np.isnan(rgb_means).any() or np.isnan(rgb_stds).any():
                    logging.warning(f"几何体 {idx} 包含 NaN 值，跳过")
                    print(f"警告：几何体 {idx} 包含 NaN 值，跳过")
                    continue
                
                r, g, b = rgb_means
                exg = 2 * g - r - b
                vari = (g - r) / (g + r - b + 1e-8)
                
                hsv_image = rgb2hsv(np.moveaxis(masked_image, 0, -1))
                hsv_means = np.nanmean(hsv_image, axis=(0, 1))
                
                green_channel = masked_image[1].astype(np.uint8)
                if green_channel.size > 0:
                    glcm = graycomatrix(green_channel, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
                    contrast = graycoprops(glcm, 'contrast')[0, 0]
                    dissimilarity = graycoprops(glcm, 'dissimilarity')[0, 0]
                    homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
                    energy = graycoprops(glcm, 'energy')[0, 0]
                    correlation = graycoprops(glcm, 'correlation')[0, 0]
                else:
                    logging.warning(f"几何体 {idx} 的绿色通道为空，使用默认纹理特征")
                    contrast = dissimilarity = homogeneity = energy = correlation = 0
                
                feature = np.concatenate([rgb_means, rgb_stds, [exg, vari], hsv_means, 
                                          [contrast, dissimilarity, homogeneity, energy, correlation]])
                features.append(feature)
                valid_geometries.append(idx)

                if progress_callback:
                    progress_callback((idx + 1) / total_geometries * 100)
            except Exception as e:
                logging.error(f"处理几何体 {idx} 时出错: {str(e)}")
    
    if not features:
        raise ValueError("没有成功提取到任何特征")
    
    return np.array(features), valid_geometries

def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, force_new_model=False, progress_callback=None, hyperparameters=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    X_new, valid_indices = extract_features(image_path, gdf_train, progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        X_old, y_old = joblib.load(data_output_path)
        
        if X_old.shape[1] != X_new.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new.shape[1]}")
            logging.info("将重新训练模型")
            X_combined, y_combined = X_new, y_new
        else:
            X_combined = np.vstack((X_old, X_new))
            y_combined = np.concatenate((y_old, y_new))
    else:
        logging.info("创建新模型")
        X_combined, y_combined = X_new, y_new

    # 定义随机搜索的参数范围
    if hyperparameters is None:
        hyperparameters = {
            'n_estimators': randint(50, 300),
            'max_depth': randint(5, 50),
            'min_samples_split': randint(2, 20),
            'min_samples_leaf': randint(1, 10),
            'max_features': uniform(0.1, 0.9)
        }

    # 创建随机搜索对象
    random_search = RandomizedSearchCV(
        RandomForestClassifier(random_state=42),
        param_distributions=hyperparameters,
        n_iter=50,
        cv=5,
        random_state=42,
        n_jobs=-1
    )

    # 执行随机搜索
    random_search.fit(X_combined, y_combined)

    # 获取最佳模型
    clf = random_search.best_estimator_

    # 评估模型
    y_pred = clf.predict(X_combined)
    accuracy = accuracy_score(y_combined, y_pred)
    report = classification_report(y_combined, y_pred, target_names=le.classes_, output_dict=True)

    # 保存模型和数据
    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")
    logging.info(f"最佳参数: {random_search.best_params_}")
    logging.info(f"模型整体精度: {accuracy}")
    logging.info("各类别精度:")
    for class_name, metrics in report.items():
        if isinstance(metrics, dict):
            logging.info(f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}")

    return accuracy, report

def predict_new_data(model_path, le_path, new_image_path, new_shp_path, output_shp_path, progress_callback=None):
    logging.info("开始预测新数据")
    clf = joblib.load(model_path)
    le = joblib.load(le_path)

    gdf_new = gpd.read_file(new_shp_path, encoding='utf-8')
    X_new, valid_indices = extract_features(new_image_path, gdf_new, progress_callback)
    gdf_new = gdf_new.iloc[valid_indices]
    
    if X_new.shape[1] != clf.n_features_in_:
        raise ValueError(f"错误：特征数量不匹配。模型期望 {clf.n_features_in_} 个特征，但提供了 {X_new.shape[1]} 个特征。")
    
    y_pred = clf.predict(X_new)
    y_proba = clf.predict_proba(X_new)
    gdf_new['ZZZW'] = le.inverse_transform(y_pred)
    gdf_new['ZZZW_proba'] = np.max(y_proba, axis=1)
    gdf_new.to_file(output_shp_path, encoding='utf-8')

    logging.info(f"预测结果已保存到: {output_shp_path}")

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("高级作物分类模型")
        master.geometry("800x600")

        self.create_widgets()
        self.create_menu()

        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="预测")
        self.notebook.add(self.info_tab, text="信息")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    def create_menu(self):
        menubar = tk.Menu(self.master)
        self.master.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="文件", menu=file_menu)
        file_menu.add_command(label="退出", command=self.master.quit)

        help_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="帮助", menu=help_menu)
        help_menu.add_command(label="关于", command=self.show_about)

    def setup_train_tab(self):
        ttk.Label(self.train_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.tif_path = tk.StringVar()
        ttk.Entry(self.train_tab, textvariable=self.tif_path, width=50).grid(row=0, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)

        ttk.Label(self.train_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.shp_path = tk.StringVar()
        ttk.Entry(self.train_tab, textvariable=self.shp_path, width=50).grid(row=1, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)

        ttk.Label(self.train_tab, text="模型存储路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.train_tab, textvariable=self.model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=self.browse_model_path).grid(row=2, column=2, padx=5, pady=5)

        self.model_action = tk.StringVar(value="update")
        ttk.Radiobutton(self.train_tab, text="更新现有模型", variable=self.model_action, value="update").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        ttk.Radiobutton(self.train_tab, text="创建新模型", variable=self.model_action, value="new").grid(row=3, column=1, sticky="w", padx=5, pady=5)
        
        # 添加超参数输入框
        ttk.Label(self.train_tab, text="n_estimators 范围:").grid(row=5, column=0, sticky="w", padx=5, pady=5)
        self.n_estimators_min = tk.IntVar(value=50)
        self.n_estimators_max = tk.IntVar(value=300)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_min, width=5).grid(row=5, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=5, column=1)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_max, width=5).grid(row=5, column=1, sticky="e", padx=5, pady=5)

        ttk.Label(self.train_tab, text="max_depth 范围:").grid(row=6, column=0, sticky="w", padx=5, pady=5)
        self.max_depth_min = tk.IntVar(value=5)
        self.max_depth_max = tk.IntVar(value=50)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_min, width=5).grid(row=6, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=6, column=1)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_max, width=5).grid(row=6, column=1, sticky="e", padx=5, pady=5)


        self.train_button = ttk.Button(self.train_tab, text="开始训练", command=self.start_training)
        self.train_button.grid(row=4, column=0, columnspan=3, pady=20)
        
        

    def setup_predict_tab(self):
        ttk.Label(self.predict_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.predict_tif_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.predict_tif_path, width=50).grid(row=0, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.predict_shp_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.predict_shp_path, width=50).grid(row=1, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="模型路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.predict_tab, textvariable=self.predict_model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=self.browse_predict_model_path).grid(row=2, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="输出文件路径:").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.output_path, width=50).grid(row=3, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_save_file(self.output_path, [("SHP files", "*.shp")])).grid(row=3, column=2, padx=5, pady=5)

        self.predict_button = ttk.Button(self.predict_tab, text="开始预测", command=self.start_prediction)
        self.predict_button.grid(row=4, column=0, columnspan=3, pady=20)

    def setup_info_tab(self):
        self.info_text = scrolledtext.ScrolledText(self.info_tab, wrap=tk.WORD, width=90, height=20)
        self.info_text.pack(padx=10, pady=10)

        ttk.Button(self.info_tab, text="刷新信息", command=self.refresh_info).pack(pady=10)

    def browse_file(self, path_var, file_types):
        filename = filedialog.askopenfilename(filetypes=file_types)
        if filename:
            path_var.set(filename)

    def browse_save_file(self, path_var, file_types):
        filename = filedialog.asksaveasfilename(filetypes=file_types, defaultextension=file_types[0][1])
        if filename:
            path_var.set(filename)

    def browse_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.model_path.set(path)

    def browse_predict_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.predict_model_path.set(path)

    def start_training(self):
        
        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        hyperparameters = {
        'n_estimators': randint(self.n_estimators_min.get(), self.n_estimators_max.get()),
        'max_depth': randint(self.max_depth_min.get(), self.max_depth_max.get()),
        # ... [添加其他超参数]
    }
        if not tif_file or not shp_file or not model_dir:
            messagebox.showerror("错误", "请选择TIF文件、SHP文件和模型存储路径")
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "训练中..."
        self.log_text.insert(tk.END, "开始训练...\n")
        action_text = "创建新模型" if force_new else "更新现有模型"
        self.log_text.insert(tk.END, f"操作: {action_text}\n")
        
        def training_thread():
            try:
                accuracy, report = update_or_train_model(
                    tif_file, shp_file, 
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    os.path.join(model_dir, "training_data.joblib"), 
                    force_new,
                    self.update_progress,
                    hyperparameters
                )
                self.master.after(0, lambda: self.training_complete(accuracy, report))
            except Exception as e:
                self.master.after(0, lambda: self.training_error(str(e)))

        threading.Thread(target=training_thread).start()

    def start_prediction(self):
        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()

        if not tif_file or not shp_file or not model_dir or not output_file:
            messagebox.showerror("错误", "请选择所有必要的文件和路径")
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "预测中..."
        self.log_text.insert(tk.END, "开始预测...\n")
        
        def prediction_thread():
            try:
                predict_new_data(os.path.join(model_dir, "model.joblib"), 
                                 os.path.join(model_dir, "label_encoder.joblib"), 
                                 tif_file, shp_file, output_file,
                                 self.update_progress)
                self.master.after(0, self.prediction_complete)
            except Exception as e:
                self.master.after(0, lambda: self.prediction_error(str(e)))

        threading.Thread(target=prediction_thread).start()

    def update_progress(self, value):
        self.progress_queue.put(value)

    def check_progress_queue(self):
        try:
            while True:
                value = self.progress_queue.get_nowait()
                self.progress['value'] = value
        except queue.Empty:
            pass
        finally:
            self.master.after(100, self.check_progress_queue)

    def training_complete(self, accuracy, report):
        self.enable_buttons()
        self.status_bar['text'] = "训练完成"
        action_text = "创建新模型" if self.model_action.get() == "new" else "更新现有模型"
        self.log_text.insert(tk.END, f"{action_text}完成\n")
        self.log_text.insert(tk.END, f"模型整体精度: {accuracy:.4f}\n")
        self.log_text.insert(tk.END, "各类别精度:\n")
        for class_name, metrics in report.items():
            if isinstance(metrics, dict):
                self.log_text.insert(tk.END, f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}\n")
        messagebox.showinfo("成功", f"{action_text}完成\n模型整体精度: {accuracy:.4f}")
        self.refresh_info()

    def training_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "训练出错"
        self.log_text.insert(tk.END, f"训练过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"训练过程中出错：{error_message}")

    def prediction_complete(self):
        self.enable_buttons()
        self.status_bar['text'] = "预测完成"
        self.log_text.insert(tk.END, "预测完成\n")
        messagebox.showinfo("成功", "预测完成")

    def prediction_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "预测出错"
        self.log_text.insert(tk.END, f"预测过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"预测过程中出错：{error_message}")

    def refresh_info(self):
        self.info_text.delete('1.0', tk.END)
        model_dir = self.model_path.get()
        try:
            model = joblib.load(os.path.join(model_dir, "model.joblib"))
            le = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
            
            info = f"模型信息：\n"
            info += f"模型存储路径：{model_dir}\n"
            info += f"特征数量：{model.n_features_in_}\n"
            info += f"类别：{', '.join(le.classes_)}\n"
            info += f"树的数量：{model.n_estimators}\n"
            
            self.info_text.insert(tk.END, info)
        except Exception as e:
            self.info_text.insert(tk.END, f"无法加载模型信息：{str(e)}")

    def disable_buttons(self):
        self.train_button['state'] = 'disabled'
        self.predict_button['state'] = 'disabled'

    def enable_buttons(self):
        self.train_button['state'] = 'normal'
        self.predict_button['state'] = 'normal'

    def show_about(self):
        messagebox.showinfo("关于", "高级作物分类模型 v1.0\n\n作者：AI Assistant\n\n版权所有 © 2023")

def main():
    root = tk.Tk()
    app = CropClassificationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()

# V3

In [1]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform

# 配置日志记录
logging.basicConfig(filename='crop_classification.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

fiona.supported_drivers['ESRI Shapefile'] = 'rw'
gpd.io.file.fiona.drvsupport.supported_drivers['ESRI Shapefile'] = 'rw'

def validate_shp_data(gdf, column_name='ZZZW'):
    """
    验证SHP数据是否包含指定的标签列，以及该列是否有空值
    """
    if column_name not in gdf.columns:
        raise ValueError(f"SHP文件中缺少'{column_name}'列")
    
    if gdf[column_name].isnull().any():
        raise ValueError(f"'{column_name}'列中存在空值")

def extract_features(image_path, gdf, progress_callback=None):
    features = []
    valid_geometries = []
    invalid_geometries = []
    total_geometries = len(gdf)

    with rasterio.open(image_path) as src:
        for idx, geometry in enumerate(gdf.geometry):
            try:
                minx, miny, maxx, maxy = geometry.bounds
                window = src.window(minx, miny, maxx, maxy)
                masked_image = src.read(window=window, indexes=[1, 2, 3])
                
                if masked_image.shape[0] < 3 or masked_image.size == 0:
                    logging.warning(f"几何体 {idx} 无有效数据，跳过")
                    invalid_geometries.append(idx)
                    continue
                
                rgb_means = np.nanmean(masked_image, axis=(1, 2))
                rgb_stds = np.nanstd(masked_image, axis=(1, 2))
                
                if np.isnan(rgb_means).any() or np.isnan(rgb_stds).any():
                    logging.warning(f"几何体 {idx} 包含 NaN 值，跳过")
                    invalid_geometries.append(idx)
                    continue
                
                r, g, b = rgb_means
                exg = 2 * g - r - b
                vari = (g - r) / (g + r - b + 1e-8)
                
                hsv_image = rgb2hsv(np.moveaxis(masked_image, 0, -1))
                hsv_means = np.nanmean(hsv_image, axis=(0, 1))
                
                green_channel = masked_image[1].astype(np.uint8)
                if green_channel.size > 0:
                    glcm = graycomatrix(green_channel, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
                    contrast = graycoprops(glcm, 'contrast')[0, 0]
                    dissimilarity = graycoprops(glcm, 'dissimilarity')[0, 0]
                    homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
                    energy = graycoprops(glcm, 'energy')[0, 0]
                    correlation = graycoprops(glcm, 'correlation')[0, 0]
                else:
                    logging.warning(f"几何体 {idx} 的绿色通道为空，使用默认纹理特征")
                    contrast = dissimilarity = homogeneity = energy = correlation = 0
                
                feature = np.concatenate([rgb_means, rgb_stds, [exg, vari], hsv_means, 
                                          [contrast, dissimilarity, homogeneity, energy, correlation]])
                features.append(feature)
                valid_geometries.append(idx)

                if progress_callback:
                    progress_callback((idx + 1) / total_geometries * 100)
            except Exception as e:
                logging.error(f"处理几何体 {idx} 时出错: {str(e)}")
                invalid_geometries.append(idx)
    
    if not features:
        raise ValueError("没有成功提取到任何特征")
    
    return np.array(features), valid_geometries, invalid_geometries

def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, force_new_model=False, progress_callback=None, hyperparameters=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    
    # 验证SHP数据
    validate_shp_data(gdf_train)
    
    X_new, valid_indices, _ = extract_features(image_path, gdf_train, progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        X_old, y_old = joblib.load(data_output_path)
        
        if X_old.shape[1] != X_new.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new.shape[1]}")
            logging.info("将重新训练模型")
            X_combined, y_combined = X_new, y_new
        else:
            X_combined = np.vstack((X_old, X_new))
            y_combined = np.concatenate((y_old, y_new))
    else:
        logging.info("创建新模型")
        X_combined, y_combined = X_new, y_new

    # 定义随机搜索的参数范围
    if hyperparameters is None:
        hyperparameters = {
            'n_estimators': randint(50, 300),
            'max_depth': randint(5, 50),
            'min_samples_split': randint(2, 20),
            'min_samples_leaf': randint(1, 10),
            'max_features': uniform(0.1, 0.9)
        }

    # 创建随机搜索对象
    random_search = RandomizedSearchCV(
        RandomForestClassifier(random_state=42),
        param_distributions=hyperparameters,
        n_iter=50,
        cv=5,
        random_state=42,
        n_jobs=-1
    )

    # 执行随机搜索
    random_search.fit(X_combined, y_combined)

    # 获取最佳模型
    clf = random_search.best_estimator_

    # 评估模型
    y_pred = clf.predict(X_combined)
    accuracy = accuracy_score(y_combined, y_pred)
    report = classification_report(y_combined, y_pred, target_names=le.classes_, output_dict=True)

    # 保存模型和数据
    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")
    logging.info(f"最佳参数: {random_search.best_params_}")
    logging.info(f"模型整体精度: {accuracy}")
    logging.info("各类别精度:")
    for class_name, metrics in report.items():
        if isinstance(metrics, dict):
            logging.info(f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}")

    return accuracy, report

def predict_new_data(model_path, le_path, new_image_path, new_shp_path, output_shp_path, progress_callback=None):
    logging.info("开始预测新数据")
    clf = joblib.load(model_path)
    le = joblib.load(le_path)

    gdf_new = gpd.read_file(new_shp_path, encoding='utf-8')
    X_new, valid_indices, invalid_indices = extract_features(new_image_path, gdf_new, progress_callback)
    
    if X_new.shape[1] != clf.n_features_in_:
        raise ValueError(f"错误：特征数量不匹配。模型期望 {clf.n_features_in_} 个特征，但提供了 {X_new.shape[1]} 个特征。")
    
    y_pred = clf.predict(X_new)
    y_proba = clf.predict_proba(X_new)
    
    # 为有效的几何体添加预测结果
    gdf_new.loc[valid_indices, 'ZZZW'] = le.inverse_transform(y_pred)
    gdf_new.loc[valid_indices, 'ZZZW_proba'] = np.max(y_proba, axis=1)
    
    # 为无效的几何体添加标记
    gdf_new.loc[invalid_indices, 'ZZZW'] = 'Invalid'
    gdf_new.loc[invalid_indices, 'ZZZW_proba'] = 0
    
    gdf_new.to_file(output_shp_path, encoding='utf-8')

    logging.info(f"预测结果已保存到: {output_shp_path}")
    return len(invalid_indices)

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("RGB分类模型")
        master.geometry("800x600")

        self.create_widgets()
        self.create_menu()

        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="预测")
        self.notebook.add(self.info_tab, text="模型信息")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    def create_menu(self):
        menubar = tk.Menu(self.master)
        self.master.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="文件", menu=file_menu)
        file_menu.add_command(label="退出", command=self.master.quit)

        help_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="帮助", menu=help_menu)
        help_menu.add_command(label="关于", command=self.show_about)

    def setup_train_tab(self):
        ttk.Label(self.train_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.tif_path = tk.StringVar()
        ttk.Entry(self.train_tab, textvariable=self.tif_path, width=50).grid(row=0, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)

        ttk.Label(self.train_tab, text="选择带有标签的SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.shp_path = tk.StringVar()
        ttk.Entry(self.train_tab, textvariable=self.shp_path, width=50).grid(row=1, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)

        ttk.Label(self.train_tab, text="模型存储路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.train_tab, textvariable=self.model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=self.browse_model_path).grid(row=2, column=2, padx=5, pady=5)

        self.model_action = tk.StringVar(value="update")
        ttk.Radiobutton(self.train_tab, text="更新现有模型", variable=self.model_action, value="update").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        ttk.Radiobutton(self.train_tab, text="创建新模型", variable=self.model_action, value="new").grid(row=3, column=1, sticky="w", padx=5, pady=5)
        
        # 超参数输入框
        ttk.Label(self.train_tab, text="n_estimators 范围:").grid(row=5, column=0, sticky="w", padx=5, pady=5)
        self.n_estimators_min = tk.IntVar(value=50)
        self.n_estimators_max = tk.IntVar(value=300)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_min, width=5).grid(row=5, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=5, column=1)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_max, width=5).grid(row=5, column=1, sticky="e", padx=5, pady=5)

        ttk.Label(self.train_tab, text="max_depth 范围:").grid(row=6, column=0, sticky="w", padx=5, pady=5)
        self.max_depth_min = tk.IntVar(value=5)
        self.max_depth_max = tk.IntVar(value=50)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_min, width=5).grid(row=6, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=6, column=1)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_max, width=5).grid(row=6, column=1, sticky="e", padx=5, pady=5)

        self.train_button = ttk.Button(self.train_tab, text="开始训练", command=self.start_training)
        self.train_button.grid(row=7, column=0, columnspan=3, pady=20)

    def setup_predict_tab(self):
        ttk.Label(self.predict_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.predict_tif_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.predict_tif_path, width=50).grid(row=0, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.predict_shp_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.predict_shp_path, width=50).grid(row=1, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="模型路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.predict_tab, textvariable=self.predict_model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=self.browse_predict_model_path).grid(row=2, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="输出文件路径:").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.output_path, width=50).grid(row=3, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_save_file(self.output_path, [("SHP files", "*.shp")])).grid(row=3, column=2, padx=5, pady=5)

        self.predict_button = ttk.Button(self.predict_tab, text="开始预测", command=self.start_prediction)
        self.predict_button.grid(row=4, column=0, columnspan=3, pady=20)

    def setup_info_tab(self):
        self.info_text = scrolledtext.ScrolledText(self.info_tab, wrap=tk.WORD, width=90, height=20)
        self.info_text.pack(padx=10, pady=10)

        ttk.Button(self.info_tab, text="获取模型信息", command=self.refresh_info).pack(pady=10)

    def browse_file(self, path_var, file_types):
        filename = filedialog.askopenfilename(filetypes=file_types)
        if filename:
            path_var.set(filename)

    def browse_save_file(self, path_var, file_types):
        filename = filedialog.asksaveasfilename(filetypes=file_types, defaultextension=file_types[0][1])
        if filename:
            path_var.set(filename)

    def browse_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.model_path.set(path)

    def browse_predict_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.predict_model_path.set(path)

    def start_training(self):
        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        hyperparameters = {
            'n_estimators': randint(self.n_estimators_min.get(), self.n_estimators_max.get()),
            'max_depth': randint(self.max_depth_min.get(), self.max_depth_max.get()),
            'min_samples_split': randint(2, 20),
            'min_samples_leaf': randint(1, 10),
            'max_features': uniform(0.1, 0.9)
        }

        if not tif_file or not shp_file or not model_dir:
            messagebox.showerror("错误", "请选择TIF文件、SHP文件和模型存储路径")
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "训练中..."
        self.log_text.insert(tk.END, "开始训练...\n")
        action_text = "创建新模型" if force_new else "更新现有模型"
        self.log_text.insert(tk.END, f"操作: {action_text}\n")
        
        def training_thread():
            try:
                accuracy, report = update_or_train_model(
                    tif_file, shp_file, 
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    os.path.join(model_dir, "training_data.joblib"), 
                    force_new,
                    self.update_progress,
                    hyperparameters
                )
                self.master.after(0, lambda: self.training_complete(accuracy, report))
            except Exception as e:
                self.master.after(0, lambda: self.training_error(str(e)))

        threading.Thread(target=training_thread).start()

    def start_prediction(self):
        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()

        if not tif_file or not shp_file or not model_dir or not output_file:
            messagebox.showerror("错误", "请选择所有必要的文件和路径")
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "预测中..."
        self.log_text.insert(tk.END, "开始预测...\n")
        
        def prediction_thread():
            try:
                invalid_count = predict_new_data(
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    tif_file, shp_file, output_file,
                    self.update_progress
                )
                self.master.after(0, lambda: self.prediction_complete(invalid_count))
            except Exception as e:
                self.master.after(0, lambda: self.prediction_error(str(e)))

        threading.Thread(target=prediction_thread).start()

    def update_progress(self, value):
        self.progress_queue.put(value)

    def check_progress_queue(self):
        try:
            while True:
                value = self.progress_queue.get_nowait()
                self.progress['value'] = value
        except queue.Empty:
            pass
        finally:
            self.master.after(100, self.check_progress_queue)

    def training_complete(self, accuracy, report):
        self.enable_buttons()
        self.status_bar['text'] = "训练完成"
        action_text = "创建新模型" if self.model_action.get() == "new" else "更新现有模型"
        self.log_text.insert(tk.END, f"{action_text}完成\n")
        self.log_text.insert(tk.END, f"模型整体精度: {accuracy:.4f}\n")
        self.log_text.insert(tk.END, "各类别精度:\n")
        for class_name, metrics in report.items():
            if isinstance(metrics, dict):
                self.log_text.insert(tk.END, f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}\n")
        messagebox.showinfo("成功", f"{action_text}完成\n模型整体精度: {accuracy:.4f}")
        self.refresh_info()

    def training_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "训练出错"
        self.log_text.insert(tk.END, f"训练过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"训练过程中出错：{error_message}")

    def prediction_complete(self, invalid_count):
        self.enable_buttons()
        self.status_bar['text'] = "预测完成"
        self.log_text.insert(tk.END, "预测完成\n")
        if invalid_count > 0:
            self.log_text.insert(tk.END, f"警告：{invalid_count}个几何体无法进行预测，已在输出中标记为'Invalid'\n")
        messagebox.showinfo("成功", f"预测完成\n{invalid_count}个几何体无法预测")

    def prediction_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "预测出错"
        self.log_text.insert(tk.END, f"预测过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"预测过程中出错：{error_message}")

    def refresh_info(self):
        self.info_text.delete('1.0', tk.END)
        model_dir = self.model_path.get()
        try:
            model = joblib.load(os.path.join(model_dir, "model.joblib"))
            le = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
            
            info = f"模型信息：\n"
            info += f"模型存储路径：{model_dir}\n"
            info += f"特征数量：{model.n_features_in_}\n"
            info += f"类别：{', '.join(le.classes_)}\n"
            info += f"树的数量：{model.n_estimators}\n"
            info += f"最大深度：{model.max_depth}\n"
            info += f"最小分裂样本数：{model.min_samples_split}\n"
            info += f"最小叶子节点样本数：{model.min_samples_leaf}\n"
            info += f"特征选择方式：{model.max_features}\n"
            
            self.info_text.insert(tk.END, info)
        except Exception as e:
            self.info_text.insert(tk.END, f"无法加载模型信息：{str(e)}")

    def disable_buttons(self):
        self.train_button['state'] = 'disabled'
        self.predict_button['state'] = 'disabled'

    def enable_buttons(self):
        self.train_button['state'] = 'normal'
        self.predict_button['state'] = 'normal'

    def show_about(self):
        messagebox.showinfo("关于", "RGB分类模型 v1.0\n\n作者：AI 贵州雏阳\n\n版权所有 © 2024")

def main():
    root = tk.Tk()
    app = CropClassificationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()

AttributeError: 'NoneType' object has no attribute 'drvsupport'

# V4

In [None]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform

# ... [前面的导入和设置保持不变]

def get_raster_info(raster_path):
    with rasterio.open(raster_path) as src:
        return {
            'pixel_size': src.res,
            'crs': src.crs,
            'bounds': src.bounds
        }

def validate_raster_consistency(new_raster_path, model_info_path):
    new_info = get_raster_info(new_raster_path)
    
    if os.path.exists(model_info_path):
        old_info = joblib.load(model_info_path)
        if new_info['pixel_size'] != old_info['pixel_size']:
            return False, f"像元大小不匹配。原始: {old_info['pixel_size']}, 新: {new_info['pixel_size']}"
        if new_info['crs'] != old_info['crs']:
            return False, f"坐标系统不匹配。原始: {old_info['crs']}, 新: {new_info['crs']}"
    
    return True, ""

def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, model_info_path, force_new_model=False, progress_callback=None, hyperparameters=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    
    # 检查栅格一致性
    if not force_new_model:
        is_consistent, message = validate_raster_consistency(image_path, model_info_path)
        if not is_consistent:
            raise ValueError(f"栅格数据不一致: {message}")
    
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    
    # 验证SHP数据
    validate_shp_data(gdf_train)
    
    X_new, valid_indices, _ = extract_features(image_path, gdf_train, progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    # 使用StandardScaler进行特征标准化
    scaler = StandardScaler()
    X_new_scaled = scaler.fit_transform(X_new)
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        old_scaler = joblib.load(os.path.join(os.path.dirname(model_output_path), "scaler.joblib"))
        X_old, y_old = joblib.load(data_output_path)
        
        # 使用旧的scaler转换新数据，确保一致性
        X_new_scaled = old_scaler.transform(X_new)
        
        if X_old.shape[1] != X_new_scaled.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new_scaled.shape[1]}")
            logging.info("将重新训练模型")
            X_combined, y_combined = X_new_scaled, y_new
        else:
            X_combined = np.vstack((X_old, X_new_scaled))
            y_combined = np.concatenate((y_old, y_new))
    else:
        logging.info("创建新模型")
        X_combined, y_combined = X_new_scaled, y_new

    # ... [随机搜索和模型训练部分保持不变]

    # 保存模型和数据
    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)
    joblib.dump(scaler, os.path.join(os.path.dirname(model_output_path), "scaler.joblib"))
    
    # 保存栅格信息
    raster_info = get_raster_info(image_path)
    joblib.dump(raster_info, model_info_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")
    logging.info(f"特征缩放器已保存到: {os.path.join(os.path.dirname(model_output_path), 'scaler.joblib')}")
    logging.info(f"栅格信息已保存到: {model_info_path}")
    logging.info(f"最佳参数: {clf.get_params()}")
    logging.info(f"模型整体精度: {accuracy}")

    return accuracy, report

class CropClassificationApp:
    # ... [大部分代码保持不变]

    def start_training(self):
        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        
        if not force_new:
            is_consistent, message = validate_raster_consistency(tif_file, os.path.join(model_dir, "raster_info.joblib"))
            if not is_consistent:
                response = messagebox.askyesno("警告", f"{message}\n是否继续更新模型？这可能会影响模型性能。")
                if not response:
                    return

        # ... [其余代码保持不变]

    def start_prediction(self):
        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()

        is_consistent, message = validate_raster_consistency(tif_file, os.path.join(model_dir, "raster_info.joblib"))
        if not is_consistent:
            response = messagebox.askyesno("警告", f"{message}\n是否继续进行预测？这可能会影响预测结果的准确性。")
            if not response:
                return

        # ... [其余代码保持不变]

# ... [主函数和其他部分保持不变]

In [None]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder, StandardScaler
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform
from osgeo import gdal

# 设置GDAL不要压缩输出文件
gdal.SetConfigOption('COMPRESS_OVERVIEW', 'NONE')
gdal.UseExceptions()  # 启用GDAL异常

# ... [其他导入和设置保持不变]

def process_raster(input_file, output_file, target_srs, target_res):
    try:
        # 如果输出文件已存在，跳过处理
        if os.path.exists(output_file):
            print(f"输出文件已存在，跳过处理: {output_file}")
            return output_file

        # 打开输入文件
        ds = gdal.Open(input_file)
        if ds is None:
            print(f"无法打开文件: {input_file}")
            return None

        # 设置转换选项
        options = gdal.WarpOptions(
            dstSRS=target_srs,
            xRes=target_res[0], yRes=target_res[1],
            resampleAlg=gdal.GRA_Bilinear,
            format='GTiff',
            creationOptions=['COMPRESS=LZW', 'TILED=YES', 'BLOCKXSIZE=256', 'BLOCKYSIZE=256', 'BIGTIFF=YES'],
            warpOptions=['CUTLINE_ALL_TOUCHED=TRUE'],
            multithread=True,
            warpMemoryLimit=4096  # 限制内存使用为4GB
        )

        # 执行转换
        print(f"正在处理文件: {input_file}")
        gdal.Warp(output_file, ds, options=options)
        
        # 关闭数据集
        ds = None
        print(f"成功处理文件: {input_file}")
        return output_file

    except Exception as e:
        print(f"处理文件 {input_file} 时出错: {str(e)}")
        return None

# ... [其他函数保持不变]

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("高级作物分类模型")
        master.geometry("800x700")  # 增加窗口高度以容纳新的控件

        self.create_widgets()
        self.create_menu()

        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)
        self.settings_tab = ttk.Frame(self.notebook)  # 新增设置标签页

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="预测")
        self.notebook.add(self.info_tab, text="信息")
        self.notebook.add(self.settings_tab, text="设置")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()
        self.setup_settings_tab()  # 设置新的标签页

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    # ... [其他方法保持不变]

    def setup_settings_tab(self):
        ttk.Label(self.settings_tab, text="目标坐标系统 (EPSG):").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.target_srs = tk.StringVar(value="EPSG:4544")
        ttk.Entry(self.settings_tab, textvariable=self.target_srs, width=20).grid(row=0, column=1, padx=5, pady=5)

        ttk.Label(self.settings_tab, text="目标像元大小 X:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.target_res_x = tk.DoubleVar(value=0.1)
        ttk.Entry(self.settings_tab, textvariable=self.target_res_x, width=10).grid(row=1, column=1, padx=5, pady=5)

        ttk.Label(self.settings_tab, text="目标像元大小 Y:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.target_res_y = tk.DoubleVar(value=0.1)
        ttk.Entry(self.settings_tab, textvariable=self.target_res_y, width=10).grid(row=2, column=1, padx=5, pady=5)

    def start_training(self):
        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        
        if not os.path.exists(tif_file) or not os.path.exists(shp_file):
            messagebox.showerror("错误", "TIF文件或SHP文件不存在")
            return

        # 处理TIF文件
        processed_tif = os.path.join(os.path.dirname(tif_file), 'processed_' + os.path.basename(tif_file))
        processed_tif = process_raster(tif_file, processed_tif, self.target_srs.get(), (self.target_res_x.get(), self.target_res_y.get()))
        
        if processed_tif is None:
            messagebox.showerror("错误", "TIF文件处理失败")
            return

        # 更新SHP文件的坐标系
        gdf = gpd.read_file(shp_file)
        gdf = gdf.to_crs(self.target_srs.get())
        processed_shp = os.path.join(os.path.dirname(shp_file), 'processed_' + os.path.basename(shp_file))
        gdf.to_file(processed_shp)

        # ... [其余训练代码保持不变，但使用processed_tif和processed_shp]

    def start_prediction(self):
        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()

        if not os.path.exists(tif_file) or not os.path.exists(shp_file):
            messagebox.showerror("错误", "TIF文件或SHP文件不存在")
            return

        # 处理TIF文件
        processed_tif = os.path.join(os.path.dirname(tif_file), 'processed_' + os.path.basename(tif_file))
        processed_tif = process_raster(tif_file, processed_tif, self.target_srs.get(), (self.target_res_x.get(), self.target_res_y.get()))
        
        if processed_tif is None:
            messagebox.showerror("错误", "TIF文件处理失败")
            return

        # 更新SHP文件的坐标系
        gdf = gpd.read_file(shp_file)
        gdf = gdf.to_crs(self.target_srs.get())
        processed_shp = os.path.join(os.path.dirname(shp_file), 'processed_' + os.path.basename(shp_file))
        gdf.to_file(processed_shp)

        # ... [其余预测代码保持不变，但使用processed_tif和processed_shp]

# ... [主函数和其他部分保持不变]

# V5

In [5]:
import tkinterdnd2
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from tkinter import filedialog, messagebox, ttk, scrolledtext
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score, confusion_matrix, ConfusionMatrixDisplay
from sklearn.preprocessing import LabelEncoder
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform
from logging.handlers import RotatingFileHandler
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
from sklearn.inspection import permutation_importance

def setup_logging():
    log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    log_file = 'crop_classification.log'
    log_handler = RotatingFileHandler(log_file, maxBytes=1024 * 1024, backupCount=5)
    log_handler.setFormatter(log_formatter)
    log_handler.setLevel(logging.INFO)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(log_handler)

setup_logging()

# 配置日志记录
logging.basicConfig(filename='crop_classification.log', level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

fiona.supported_drivers['ESRI Shapefile'] = 'rw'
gpd.io.file.fiona.drvsupport.supported_drivers['ESRI Shapefile'] = 'rw'

def validate_shp_data(gdf, column_name='ZZZW'):
    """
    验证SHP数据是否包含指定的标签列，以及该列是否有空值
    """
    if column_name not in gdf.columns:
        raise ValueError(f"SHP文件中缺少'{column_name}'列")
    
    if gdf[column_name].isnull().any():
        raise ValueError(f"'{column_name}'列中存在空值")
def extract_features_chunked(image_path, gdf, use_chunking=False, chunk_size=1024, progress_callback=None):
    features = []
    valid_geometries = []
    invalid_geometries = []
    total_geometries = len(gdf)

    with rasterio.open(image_path) as src:
        for idx, geometry in enumerate(gdf.geometry):
            try:
                window = rasterio.features.geometry_window(src, [geometry])
                
                if use_chunking:
                    chunk_windows = []
                    for col_off in range(0, window.width, chunk_size):
                        for row_off in range(0, window.height, chunk_size):
                            chunk_window = rasterio.windows.Window(
                                window.col_off + col_off, 
                                window.row_off + row_off,
                                min(chunk_size, window.width - col_off),
                                min(chunk_size, window.height - row_off)
                            )
                            chunk_windows.append(chunk_window)
                else:
                    chunk_windows = [window]

                feature_chunks = []
                for chunk_window in chunk_windows:
                    masked_chunk = src.read(window=chunk_window, indexes=[1, 2, 3])
                    
                    if masked_chunk.size == 0:
                        continue
                    
                    rgb_means = np.nanmean(masked_chunk, axis=(1, 2))
                    rgb_stds = np.nanstd(masked_chunk, axis=(1, 2))
                    
                    r, g, b = rgb_means
                    exg = 2 * g - r - b
                    vari = (g - r) / (g + r - b + 1e-8)
                    
                    hsv_chunk = rgb2hsv(np.moveaxis(masked_chunk, 0, -1))
                    hsv_means = np.nanmean(hsv_chunk, axis=(0, 1))
                    
                    green_channel = masked_chunk[1].astype(np.uint8)
                    if green_channel.size > 0:
                        glcm = graycomatrix(green_channel, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
                        contrast = graycoprops(glcm, 'contrast')[0, 0]
                        dissimilarity = graycoprops(glcm, 'dissimilarity')[0, 0]
                        homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
                        energy = graycoprops(glcm, 'energy')[0, 0]
                        correlation = graycoprops(glcm, 'correlation')[0, 0]
                    else:
                        contrast = dissimilarity = homogeneity = energy = correlation = 0
                    
                    chunk_feature = np.concatenate([rgb_means, rgb_stds, [exg, vari], hsv_means, 
                                                    [contrast, dissimilarity, homogeneity, energy, correlation]])
                    feature_chunks.append(chunk_feature)
                
                if feature_chunks:
                    feature = np.mean(feature_chunks, axis=0)
                    features.append(feature)
                    valid_geometries.append(idx)
                else:
                    invalid_geometries.append(idx)

                if progress_callback:
                    progress_callback((idx + 1) / total_geometries * 100)
            except Exception as e:
                logging.error(f"处理几何体 {idx} 时出错: {str(e)}")
                invalid_geometries.append(idx)
    
    if not features:
        raise ValueError("没有成功提取到任何特征")
    
    return np.array(features), valid_geometries, invalid_geometries

def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, force_new_model=False, progress_callback=None, hyperparameters=None, use_chunking=False, chunk_size=None,cancel_check=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    
    # 验证SHP数据
    validate_shp_data(gdf_train)
    
    X_new, valid_indices, _ = extract_features_chunked(image_path, gdf_train, use_chunking=use_chunking, chunk_size=chunk_size, progress_callback=progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        X_old, y_old = joblib.load(data_output_path)
        
        if X_old.shape[1] != X_new.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new.shape[1]}")
            logging.info("将重新训练模型")
            X_combined, y_combined = X_new, y_new
        else:
            X_combined = np.vstack((X_old, X_new))
            y_combined = np.concatenate((y_old, y_new))
    else:
        logging.info("创建新模型")
        X_combined, y_combined = X_new, y_new

    # 定义随机搜索的参数范围
    if hyperparameters is None:
        hyperparameters = {
            'n_estimators': randint(50, 300),
            'max_depth': randint(5, 50),
            'min_samples_split': randint(2, 20),
            'min_samples_leaf': randint(1, 10),
            'max_features': uniform(0.1, 0.9)
        }

    # 创建随机搜索对象
    random_search = RandomizedSearchCV(
        RandomForestClassifier(random_state=42),
        param_distributions=hyperparameters,
        n_iter=50,
        cv=5,
        random_state=42,
        n_jobs=-1
    )

    # 执行随机搜索
    random_search.fit(X_combined, y_combined)

    # 获取最佳模型
    clf = random_search.best_estimator_

    # 评估模型
    y_pred = clf.predict(X_combined)
    accuracy = accuracy_score(y_combined, y_pred)
    report = classification_report(y_combined, y_pred, target_names=le.classes_, output_dict=True)

    # 保存模型和数据
    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")
    logging.info(f"最佳参数: {random_search.best_params_}")
    logging.info(f"模型整体精度: {accuracy}")
    logging.info("各类别精度:")
    for class_name, metrics in report.items():
        if isinstance(metrics, dict):
            logging.info(f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}")
    if cancel_check and cancel_check():
        raise InterruptedError("操作被用户取消")

    return accuracy, report

def  predict_new_data(model_path, le_path, new_image_path, new_shp_path, output_shp_path, progress_callback=None, use_chunking=False, chunk_size=None,cancel_check=None):
    logging.info("开始预测新数据")
    clf = joblib.load(model_path)
    le = joblib.load(le_path)

    gdf_new = gpd.read_file(new_shp_path, encoding='utf-8')
    X_new, valid_indices, invalid_indices = extract_features_chunked(new_image_path, gdf_new, use_chunking=use_chunking, chunk_size=chunk_size, progress_callback=progress_callback)
    
    if X_new.shape[1] != clf.n_features_in_:
        raise ValueError(f"错误：特征数量不匹配。模型期望 {clf.n_features_in_} 个特征，但提供了 {X_new.shape[1]} 个特征。")
    
    y_pred = clf.predict(X_new)
    y_proba = clf.predict_proba(X_new)
    
    # 为有效的几何体添加预测结果
    gdf_new.loc[valid_indices, 'ZZZW'] = le.inverse_transform(y_pred)
    gdf_new.loc[valid_indices, 'ZZZW_proba'] = np.max(y_proba, axis=1)
    
    # 为无效的几何体添加标记
    gdf_new.loc[invalid_indices, 'ZZZW'] = 'Invalid'
    gdf_new.loc[invalid_indices, 'ZZZW_proba'] = 0
    
    gdf_new.to_file(output_shp_path, encoding='utf-8')

    logging.info(f"预测结果已保存到: {output_shp_path}")
    if cancel_check and cancel_check():
        raise InterruptedError("操作被用户取消")
    return len(invalid_indices)

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("RGB分类模型")
        master.geometry("800x600")
        
        self.create_widgets()
        self.create_menu()
        
        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

        self.use_chunking = tk.BooleanVar(value=False)
        self.chunk_size = tk.IntVar(value=1024)

        self.cancel_operation = False
    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="预测")
        self.notebook.add(self.info_tab, text="模型信息")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    def create_menu(self):
        menubar = tk.Menu(self.master)
        self.master.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="文件", menu=file_menu)
        file_menu.add_command(label="退出", command=self.master.quit)

        help_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="帮助", menu=help_menu)
        help_menu.add_command(label="关于", command=self.show_about)

    def setup_train_tab(self):
        ttk.Label(self.train_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.tif_path = tk.StringVar()
        tif_entry = ttk.Entry(self.train_tab, textvariable=self.tif_path, width=50)
        tif_entry.grid(row=0, column=1, padx=5, pady=5)
        tif_entry.drop_target_register(tkinterdnd2.DND_FILES)
        tif_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.tif_path))
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)
        ttk.Button(self.train_tab, text="预览", command=lambda: self.preview_file(self.tif_path)).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(self.train_tab, text="选择带有标签的SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.shp_path = tk.StringVar()
        shp_entry = ttk.Entry(self.train_tab, textvariable=self.shp_path, width=50)
        shp_entry.grid(row=1, column=1, padx=5, pady=5)
        shp_entry.drop_target_register(tkinterdnd2.DND_FILES)
        shp_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.shp_path))
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)
        ttk.Button(self.train_tab, text="预览", command=lambda: self.preview_file(self.shp_path)).grid(row=1, column=3, padx=5, pady=5)

        ttk.Label(self.train_tab, text="模型存储路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.train_tab, textvariable=self.model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=self.browse_model_path).grid(row=2, column=2, padx=5, pady=5)

        self.model_action = tk.StringVar(value="update")
        ttk.Radiobutton(self.train_tab, text="更新现有模型", variable=self.model_action, value="update").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        ttk.Radiobutton(self.train_tab, text="创建新模型", variable=self.model_action, value="new").grid(row=3, column=1, sticky="w", padx=5, pady=5)
        
        # 超参数输入框
        ttk.Label(self.train_tab, text="n_estimators 范围:").grid(row=5, column=0, sticky="w", padx=5, pady=5)
        self.n_estimators_min = tk.IntVar(value=50)
        self.n_estimators_max = tk.IntVar(value=300)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_min, width=5).grid(row=5, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=5, column=1)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_max, width=5).grid(row=5, column=1, sticky="e", padx=5, pady=5)

        ttk.Label(self.train_tab, text="max_depth 范围:").grid(row=6, column=0, sticky="w", padx=5, pady=5)
        self.max_depth_min = tk.IntVar(value=5)
        self.max_depth_max = tk.IntVar(value=50)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_min, width=5).grid(row=6, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=6, column=1)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_max, width=5).grid(row=6, column=1, sticky="e", padx=5, pady=5)

        self.train_button = ttk.Button(self.train_tab, text="开始训练", command=self.start_training)
        self.train_button.grid(row=7, column=0, columnspan=3, pady=20)


        ttk.Checkbutton(self.train_tab, text="使用分块处理", variable=self.use_chunking, command=self.toggle_chunk_size).grid(row=8, column=0, columnspan=2, sticky="w", padx=5, pady=5)
    
        self.chunk_size_label = ttk.Label(self.train_tab, text="分块大小:")
        self.chunk_size_label.grid(row=9, column=0, sticky="w", padx=5, pady=5)
        self.chunk_size_entry = ttk.Entry(self.train_tab, textvariable=self.chunk_size, width=10)
        self.chunk_size_entry.grid(row=9, column=1, sticky="w", padx=5, pady=5)
        
        # 初始状态下禁用分块大小输入
        self.chunk_size_label['state'] = 'disabled'
        self.chunk_size_entry['state'] = 'disabled'

    def setup_predict_tab(self):
        ttk.Label(self.predict_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.predict_tif_path = tk.StringVar()
        predict_tif_entry = ttk.Entry(self.predict_tab, textvariable=self.predict_tif_path, width=50)
        predict_tif_entry.grid(row=0, column=1, padx=5, pady=5)
        predict_tif_entry.drop_target_register(tkinterdnd2.DND_FILES)
        predict_tif_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.predict_tif_path))
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="预览", command=lambda: self.preview_file(self.predict_tif_path)).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.predict_shp_path = tk.StringVar()
        predict_shp_entry = ttk.Entry(self.predict_tab, textvariable=self.predict_shp_path, width=50)
        predict_shp_entry.grid(row=1, column=1, padx=5, pady=5)
        predict_shp_entry.drop_target_register(tkinterdnd2.DND_FILES)
        predict_shp_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.predict_shp_path))
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="预览", command=lambda: self.preview_file(self.predict_shp_path)).grid(row=1, column=3, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="模型路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.predict_tab, textvariable=self.predict_model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=self.browse_predict_model_path).grid(row=2, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="输出文件路径:").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.output_path, width=50).grid(row=3, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_save_file(self.output_path, [("SHP files", "*.shp")])).grid(row=3, column=2, padx=5, pady=5)

        self.predict_button = ttk.Button(self.predict_tab, text="开始预测", command=self.start_prediction)
        self.predict_button.grid(row=4, column=0, columnspan=3, pady=20)

        self.show_map_button = ttk.Button(self.predict_tab, text="显示分类地图", command=lambda: self.show_classification_map(self.output_path.get()))
        self.show_map_button.grid(row=5, column=0, columnspan=3, pady=10)
        self.show_map_button['state'] = 'disabled'  # 初始状态为禁用

        ttk.Checkbutton(self.predict_tab, text="使用分块处理", variable=self.use_chunking, command=self.toggle_chunk_size).grid(row=5, column=0, columnspan=2, sticky="w", padx=5, pady=5)
    
        self.chunk_size_label_predict = ttk.Label(self.predict_tab, text="分块大小:")
        self.chunk_size_label_predict.grid(row=6, column=0, sticky="w", padx=5, pady=5)
        self.chunk_size_entry_predict = ttk.Entry(self.predict_tab, textvariable=self.chunk_size, width=10)
        self.chunk_size_entry_predict.grid(row=6, column=1, sticky="w", padx=5, pady=5)
    
        # 初始状态下禁用分块大小输入
        self.chunk_size_label_predict['state'] = 'disabled'
        self.chunk_size_entry_predict['state'] = 'disabled'
    def toggle_chunk_size(self):
        state = 'normal' if self.use_chunking.get() else 'disabled'
        self.chunk_size_label['state'] = state
        self.chunk_size_entry['state'] = state
        self.chunk_size_label_predict['state'] = state
        self.chunk_size_entry_predict['state'] = state

    def validate_inputs(self):
        errors = []
        if not self.tif_path.get():
            errors.append("请选择TIF文件")
        if not self.shp_path.get():
            errors.append("请选择SHP文件")
        if self.use_chunking.get() and (self.chunk_size.get() <= 0 or self.chunk_size.get() > 10000):
            errors.append("分块大小应在1到10000之间")
        return errors
    def setup_info_tab(self):
        self.info_text = scrolledtext.ScrolledText(self.info_tab, wrap=tk.WORD, width=90, height=20)
        self.info_text.pack(padx=10, pady=10)

        ttk.Button(self.info_tab, text="获取模型信息", command=self.refresh_info).pack(pady=10)

    def browse_file(self, path_var, file_types):
        filename = filedialog.askopenfilename(filetypes=file_types)
        if filename:
            path_var.set(filename)

    def browse_save_file(self, path_var, file_types):
        filename = filedialog.asksaveasfilename(filetypes=file_types, defaultextension=file_types[0][1])
        if filename:
            path_var.set(filename)

    def browse_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.model_path.set(path)

    def browse_predict_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.predict_model_path.set(path)

    def drop_file(self, event, path_var):
        file_path = event.data
        if file_path.startswith('{') and file_path.endswith('}'):
            file_path = file_path[1:-1]
        file_extension = os.path.splitext(file_path)[1].lower()
        if (path_var in [self.tif_path, self.predict_tif_path] and file_extension == '.tif') or \
        (path_var in [self.shp_path, self.predict_shp_path] and file_extension == '.shp'):
            path_var.set(file_path)
        else:
            messagebox.showerror("错误", "请拖放正确的文件类型")
    @staticmethod
    def get_file_info(file_path):
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
        file_type = os.path.splitext(file_path)[1].lower()
        
        info = f"文件路径: {file_path}\n"
        info += f"文件大小: {file_size:.2f} MB\n"
        info += f"文件类型: {file_type}\n"
        
        if file_type == '.tif':
            with rasterio.open(file_path) as src:
                info += f"图像大小: {src.width} x {src.height}\n"
                info += f"波段数: {src.count}\n"
                info += f"坐标系统: {src.crs}\n"
        elif file_type == '.shp':
            gdf = gpd.read_file(file_path)
            info += f"要素数量: {len(gdf)}\n"
            info += f"几何类型: {gdf.geom_type.iloc[0]}\n"
            info += f"属性列: {', '.join(gdf.columns)}\n"
        
        return info
    def show_classification_map(self, shp_file_path):
        try:
            # 读取分类结果的shapefile
            gdf = gpd.read_file(shp_file_path)
            
            # 创建一个新的Toplevel窗口
            map_window = tk.Toplevel(self.master)
            map_window.title("分类结果地图")
            map_window.geometry("800x600")

            # 创建matplotlib图形
            fig, ax = plt.subplots(figsize=(10, 8))
            
            # 绘制地图
            gdf.plot(column='ZZZW', ax=ax, legend=True, cmap='viridis', legend_kwds={'label': '分类结果', 'orientation': 'horizontal'})
            ax.set_axis_off()
            plt.title('分类结果空间分布')

            # 将matplotlib图形嵌入到Tkinter窗口中
            canvas = FigureCanvasTkAgg(fig, master=map_window)
            canvas.draw()
            canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        except Exception as e:
            messagebox.showerror("错误", f"无法创建地图：{str(e)}")
    def preview_file(self, path_var):
        file_path = path_var.get()
        if not file_path:
            messagebox.showwarning("警告", "请先选择文件")
            return
        
        try:
            info = self.get_file_info(file_path)
            messagebox.showinfo("文件预览", info)
        except Exception as e:
            messagebox.showerror("错误", f"无法预览文件：{str(e)}")
    def cancel_current_operation(self):
        self.cancel_operation = True
        self.log_text.insert(tk.END, "正在取消操作...\n")
        self.status_bar['text'] = "正在取消..."

    def operation_cancelled(self):
        self.enable_buttons()
        self.status_bar['text'] = "操作已取消"
        self.log_text.insert(tk.END, "操作已取消\n")
        messagebox.showinfo("已取消", "操作已被用户取消")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def save_model_version(self, model, le, accuracy, report, timestamp):
        version_dir = os.path.join(self.model_path.get(), f"version_{timestamp}")
        os.makedirs(version_dir, exist_ok=True)
        
        joblib.dump(model, os.path.join(version_dir, "model.joblib"))
        joblib.dump(le, os.path.join(version_dir, "label_encoder.joblib"))
        
        # 保存模型性能信息
        with open(os.path.join(version_dir, "performance.txt"), "w") as f:
            f.write(f"Accuracy: {accuracy}\n\n")
            f.write("Classification Report:\n")
            f.write(str(report))
        
        # 更新版本列表
        versions = self.load_model_versions()
        versions.append({
            "timestamp": timestamp,
            "accuracy": accuracy,
            "path": version_dir
        })
        with open(os.path.join(self.model_path.get(), "versions.json"), "w") as f:
            json.dump(versions, f)
    def load_model_versions(self):
        version_file = os.path.join(self.model_path.get(), "versions.json")
        if os.path.exists(version_file):
            with open(version_file, "r") as f:
                return json.load(f)
        return []        
        
    def start_training(self):
        errors = self.validate_inputs()
        if errors:
            messagebox.showerror("输入错误", "\n".join(errors))
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "训练中..."
        self.log_text.insert(tk.END, "开始训练...\n")
        self.cancel_operation = False

        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        use_chunking = self.use_chunking.get()
        chunk_size = self.chunk_size.get() if use_chunking else None

        def training_thread():
            try:
                accuracy, report = update_or_train_model(
                    tif_file, shp_file, 
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    os.path.join(model_dir, "training_data.joblib"), 
                    force_new,
                    self.update_progress,
                    use_chunking=use_chunking,
                    chunk_size=chunk_size,
                    cancel_check=lambda: self.cancel_operation
                    
                )
                if not self.cancel_operation:
                    self.master.after(0, lambda: self.training_complete(accuracy, report))
                else:
                    self.master.after(0, self.operation_cancelled)
            except Exception as e:
                logging.exception("训练过程中出错")
                self.master.after(0, lambda: self.training_error(str(e)))

        self.current_thread = threading.Thread(target=training_thread)
        self.current_thread.start()

        # 添加取消按钮
        self.cancel_button = ttk.Button(self.train_tab, text="取消", command=self.cancel_current_operation)
        self.cancel_button.grid(row=10, column=0, columnspan=3, pady=10)

    def start_prediction(self):
        errors = self.validate_inputs()
        if errors:
            messagebox.showerror("输入错误", "\n".join(errors))
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "预测中..."
        self.log_text.insert(tk.END, "开始预测...\n")
        self.cancel_operation = False

        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()
        use_chunking = self.use_chunking.get()
        chunk_size = self.chunk_size.get() if use_chunking else None

        def prediction_thread():
            try:
                invalid_count = predict_new_data(
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    tif_file, shp_file, output_file,
                    self.update_progress,
                    use_chunking=use_chunking,
                    chunk_size=chunk_size,
                    cancel_check=lambda: self.cancel_operation
                )
                if not self.cancel_operation:
                    self.master.after(0, lambda: self.prediction_complete(invalid_count))
                else:
                    self.master.after(0, self.operation_cancelled)
            except Exception as e:
                logging.exception("预测过程中出错")
                self.master.after(0, lambda: self.prediction_error(str(e)))

        self.current_thread = threading.Thread(target=prediction_thread)
        self.current_thread.start()

        # 添加取消按钮
        self.cancel_button = ttk.Button(self.predict_tab, text="取消", command=self.cancel_current_operation)
        self.cancel_button.grid(row=7, column=0, columnspan=3, pady=10)

    def update_progress(self, value):
        self.progress_queue.put(value)

    def check_progress_queue(self):
        try:
            while True:
                value = self.progress_queue.get_nowait()
                self.progress['value'] = value
        except queue.Empty:
            pass
        finally:
            self.master.after(100, self.check_progress_queue)

    def training_complete(self, accuracy, report):
        self.enable_buttons()
        self.status_bar['text'] = "训练完成"
        action_text = "创建新模型" if self.model_action.get() == "new" else "更新现有模型"
        self.log_text.insert(tk.END, f"{action_text}完成\n")
        self.log_text.insert(tk.END, f"模型整体精度: {accuracy:.4f}\n")
        self.log_text.insert(tk.END, "各类别精度:\n")
        for class_name, metrics in report.items():
            if isinstance(metrics, dict):
                self.log_text.insert(tk.END, f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}\n")
        messagebox.showinfo("成功", f"{action_text}完成\n模型整体精度: {accuracy:.4f}")
        self.refresh_info()

    def training_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "训练出错"
        self.log_text.insert(tk.END, f"训练过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"训练过程中出错：{error_message}")

    def prediction_complete(self, invalid_count):
        self.enable_buttons()
        self.status_bar['text'] = "预测完成"
        self.log_text.insert(tk.END, "预测完成\n")
        if invalid_count > 0:
            self.log_text.insert(tk.END, f"警告：{invalid_count}个几何体无法进行预测，已在输出中标记为'Invalid'\n")
        messagebox.showinfo("成功", f"预测完成\n{invalid_count}个几何体无法预测")
        # 显示分类结果地图
        self.show_classification_map(self.output_path.get())
        # 启用显示地图
        self.show_map_button['state'] = 'normal'  # 启用显示地图按钮

    def prediction_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "预测出错"
        self.log_text.insert(tk.END, f"预测过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"预测过程中出错：{error_message}")
        
    def refresh_info(self):
        self.info_text.delete('1.0', tk.END)
        model_dir = self.model_path.get()
        try:
            model = joblib.load(os.path.join(model_dir, "model.joblib"))
            le = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
            
            info = f"模型信息：\n"
            info += f"模型存储路径：{model_dir}\n"
            info += f"特征数量：{model.n_features_in_}\n"
            info += f"类别：{', '.join(le.classes_)}\n"
            info += f"树的数量：{model.n_estimators}\n"
            info += f"最大深度：{model.max_depth}\n"
            info += f"最小分裂样本数：{model.min_samples_split}\n"
            info += f"最小叶子节点样本数：{model.min_samples_leaf}\n"
            info += f"特征选择方式：{model.max_features}\n"
            
            self.info_text.insert(tk.END, info)
        except Exception as e:
            self.info_text.insert(tk.END, f"无法加载模型信息：{str(e)}")

    def disable_buttons(self):
        self.train_button['state'] = 'disabled'
        self.predict_button['state'] = 'disabled'

    def enable_buttons(self):
        self.train_button['state'] = 'normal'
        self.predict_button['state'] = 'normal'

    def show_about(self):
        messagebox.showinfo("关于", "RGB分类模型 v1.0\n\n作者：AI 贵州雏阳\n\n版权所有 © 2024")

def main():
    root = tkinterdnd2.TkinterDnD.Tk()
    app = CropClassificationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()

ImportError: Matplotlib requires numpy>=1.23; you have 1.22.4

# V6

In [10]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext
import os
import joblib
import threading
import queue
import logging
from logging.handlers import RotatingFileHandler
import rasterio
import geopandas as gpd
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform
import tkinterdnd2

# 配置日志记录
def setup_logging():
    log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    log_file = 'crop_classification.log'
    log_handler = RotatingFileHandler(log_file, maxBytes=1024 * 1024, backupCount=5)
    log_handler.setFormatter(log_formatter)
    log_handler.setLevel(logging.INFO)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(log_handler)

setup_logging()

fiona.supported_drivers['ESRI Shapefile'] = 'rw'
# gpd.io.file.fiona.drvsupport.supported_drivers['ESRI Shapefile'] = 'rw'

def validate_shp_data(gdf, column_name='ZZZW'):
    """
    验证SHP数据是否包含指定的标签列，以及该列是否有空值
    """
    if column_name not in gdf.columns:
        raise ValueError(f"SHP文件中缺少'{column_name}'列")
    
    if gdf[column_name].isnull().any():
        raise ValueError(f"'{column_name}'列中存在空值")

def extract_features_chunked(image_path, gdf, use_chunking=False, chunk_size=1024, progress_callback=None):
    features = []
    valid_geometries = []
    invalid_geometries = []
    total_geometries = len(gdf)

    with rasterio.open(image_path) as src:
        for idx, geometry in enumerate(gdf.geometry):
            try:
                window = rasterio.windows.from_bounds(*geometry.bounds, src.transform)
                
                if use_chunking:
                    chunk_windows = []
                    for col_off in range(0, window.width, chunk_size):
                        for row_off in range(0, window.height, chunk_size):
                            chunk_window = rasterio.windows.Window(
                                window.col_off + col_off, 
                                window.row_off + row_off,
                                min(chunk_size, window.width - col_off),
                                min(chunk_size, window.height - row_off)
                            )
                            chunk_windows.append(chunk_window)
                else:
                    chunk_windows = [window]

                feature_chunks = []
                for chunk_window in chunk_windows:
                    masked_chunk = src.read(window=chunk_window, indexes=[1, 2, 3])
                    
                    if masked_chunk.size == 0:
                        logging.warning(f"几何体 {idx}: 无数据")
                        continue
                    
                    rgb_means = np.nanmean(masked_chunk, axis=(1, 2))
                    rgb_stds = np.nanstd(masked_chunk, axis=(1, 2))
                    
                    r, g, b = rgb_means
                    exg = 2 * g - r - b
                    vari = (g - r) / (g + r - b + 1e-8)
                    
                    hsv_chunk = rgb2hsv(np.moveaxis(masked_chunk, 0, -1))
                    hsv_means = np.nanmean(hsv_chunk, axis=(0, 1))
                    
                    green_channel = masked_chunk[1].astype(np.uint8)
                    if green_channel.size > 0:
                        glcm = graycomatrix(green_channel, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
                        contrast = graycoprops(glcm, 'contrast')[0, 0]
                        dissimilarity = graycoprops(glcm, 'dissimilarity')[0, 0]
                        homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
                        energy = graycoprops(glcm, 'energy')[0, 0]
                        correlation = graycoprops(glcm, 'correlation')[0, 0]
                    else:
                        contrast = dissimilarity = homogeneity = energy = correlation = 0
                    
                    chunk_feature = np.concatenate([rgb_means, rgb_stds, [exg, vari], hsv_means, 
                                                    [contrast, dissimilarity, homogeneity, energy, correlation]])
                    feature_chunks.append(chunk_feature)
                
                if feature_chunks:
                    feature = np.mean(feature_chunks, axis=0)
                    features.append(feature)
                    valid_geometries.append(idx)
                else:
                    logging.warning(f"几何体 {idx}: 未能提取特征")
                    invalid_geometries.append(idx)

                if progress_callback:
                    progress_callback((idx + 1) / total_geometries * 100)
            except Exception as e:
                logging.error(f"处理几何体 {idx} 时出错: {str(e)}")
                invalid_geometries.append(idx)
    
    if not features:
        raise ValueError("没有成功提取到任何特征。请检查输入数据是否正确，以及TIF和SHP文件是否匹配。")
    
    return np.array(features), valid_geometries, invalid_geometries
    
    return np.array(features), valid_geometries, invalid_geometries
def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, force_new_model=False, progress_callback=None, hyperparameters=None, use_chunking=False, chunk_size=None, cancel_check=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    
    # 验证SHP数据
    validate_shp_data(gdf_train)
    
    X_new, valid_indices, _ = extract_features_chunked(image_path, gdf_train, use_chunking=use_chunking, chunk_size=chunk_size, progress_callback=progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        X_old, y_old = joblib.load(data_output_path)
        
        if X_old.shape[1] != X_new.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new.shape[1]}")
            logging.info("将重新训练模型")
            X_combined, y_combined = X_new, y_new
        else:
            X_combined = np.vstack((X_old, X_new))
            y_combined = np.concatenate((y_old, y_new))
    else:
        logging.info("创建新模型")
        X_combined, y_combined = X_new, y_new

    # 定义随机搜索的参数范围
    if hyperparameters is None:
        hyperparameters = {
            'n_estimators': randint(50, 300),
            'max_depth': randint(5, 50),
            'min_samples_split': randint(2, 20),
            'min_samples_leaf': randint(1, 10),
            'max_features': uniform(0.1, 0.9)
        }

    # 创建随机搜索对象
    random_search = RandomizedSearchCV(
        RandomForestClassifier(random_state=42),
        param_distributions=hyperparameters,
        n_iter=50,
        cv=5,
        random_state=42,
        n_jobs=-1
    )

    # 执行随机搜索
    random_search.fit(X_combined, y_combined)

    # 获取最佳模型
    clf = random_search.best_estimator_

    # 评估模型
    y_pred = clf.predict(X_combined)
    accuracy = accuracy_score(y_combined, y_pred)
    report = classification_report(y_combined, y_pred, target_names=le.classes_, output_dict=True)

    # 保存模型和数据
    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")
    logging.info(f"最佳参数: {random_search.best_params_}")
    logging.info(f"模型整体精度: {accuracy}")
    logging.info("各类别精度:")
    for class_name, metrics in report.items():
        if isinstance(metrics, dict):
            logging.info(f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}")

    if cancel_check and cancel_check():
        raise InterruptedError("操作被用户取消")

    return accuracy, report

def predict_new_data(model_path, le_path, new_image_path, new_shp_path, output_shp_path, progress_callback=None, use_chunking=False, chunk_size=None, cancel_check=None):
    logging.info("开始预测新数据")
    clf = joblib.load(model_path)
    le = joblib.load(le_path)

    gdf_new = gpd.read_file(new_shp_path, encoding='utf-8')
    X_new, valid_indices, invalid_indices = extract_features_chunked(new_image_path, gdf_new, use_chunking=use_chunking, chunk_size=chunk_size, progress_callback=progress_callback)
    
    if X_new.shape[1] != clf.n_features_in_:
        raise ValueError(f"错误：特征数量不匹配。模型期望 {clf.n_features_in_} 个特征，但提供了 {X_new.shape[1]} 个特征。")
    
    y_pred = clf.predict(X_new)
    y_proba = clf.predict_proba(X_new)
    
    # 为有效的几何体添加预测结果
    gdf_new.loc[valid_indices, 'ZZZW'] = le.inverse_transform(y_pred)
    gdf_new.loc[valid_indices, 'ZZZW_proba'] = np.max(y_proba, axis=1)
    
    # 为无效的几何体添加标记
    gdf_new.loc[invalid_indices, 'ZZZW'] = 'Invalid'
    gdf_new.loc[invalid_indices, 'ZZZW_proba'] = 0
    
    gdf_new.to_file(output_shp_path, encoding='utf-8')

    if cancel_check and cancel_check():
        raise InterruptedError("操作被用户取消")

    logging.info(f"预测结果已保存到: {output_shp_path}")
    return len(invalid_indices)

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("RGB分类模型")
        master.geometry("800x600")

        # 初始化所有需要的属性
        self.use_chunking = tk.BooleanVar(value=False)
        self.chunk_size = tk.IntVar(value=1024)
        self.tif_path = tk.StringVar()
        self.shp_path = tk.StringVar()
        self.model_path = tk.StringVar(value=os.getcwd())
        self.model_action = tk.StringVar(value="update")
        self.predict_tif_path = tk.StringVar()
        self.predict_shp_path = tk.StringVar()
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        self.output_path = tk.StringVar()

        self.create_widgets()
        self.create_menu()

        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

        self.cancel_operation = False

    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="预测")
        self.notebook.add(self.info_tab, text="模型信息")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    def create_menu(self):
        menubar = tk.Menu(self.master)
        self.master.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="文件", menu=file_menu)
        file_menu.add_command(label="退出", command=self.master.quit)

        help_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="帮助", menu=help_menu)
        help_menu.add_command(label="关于", command=self.show_about)

    def setup_train_tab(self):
        ttk.Label(self.train_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.tif_path = tk.StringVar()
        tif_entry = ttk.Entry(self.train_tab, textvariable=self.tif_path, width=50)
        tif_entry.grid(row=0, column=1, padx=5, pady=5)
        tif_entry.drop_target_register(tkinterdnd2.DND_FILES)
        tif_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.tif_path))
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)
        ttk.Button(self.train_tab, text="预览", command=lambda: self.preview_file(self.tif_path)).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(self.train_tab, text="选择带有标签的SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.shp_path = tk.StringVar()
        shp_entry = ttk.Entry(self.train_tab, textvariable=self.shp_path, width=50)
        shp_entry.grid(row=1, column=1, padx=5, pady=5)
        shp_entry.drop_target_register(tkinterdnd2.DND_FILES)
        shp_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.shp_path))
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)
        ttk.Button(self.train_tab, text="预览", command=lambda: self.preview_file(self.shp_path)).grid(row=1, column=3, padx=5, pady=5)

        ttk.Label(self.train_tab, text="模型存储路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.train_tab, textvariable=self.model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=self.browse_model_path).grid(row=2, column=2, padx=5, pady=5)

        self.model_action = tk.StringVar(value="update")
        ttk.Radiobutton(self.train_tab, text="更新现有模型", variable=self.model_action, value="update").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        ttk.Radiobutton(self.train_tab, text="创建新模型", variable=self.model_action, value="new").grid(row=3, column=1, sticky="w", padx=5, pady=5)

        ttk.Checkbutton(self.train_tab, text="使用分块处理", variable=self.use_chunking, command=self.toggle_chunk_size).grid(row=4, column=0, columnspan=2, sticky="w", padx=5, pady=5)
        
        self.chunk_size_label = ttk.Label(self.train_tab, text="分块大小:")
        self.chunk_size_label.grid(row=5, column=0, sticky="w", padx=5, pady=5)
        self.chunk_size_entry = ttk.Entry(self.train_tab, textvariable=self.chunk_size, width=10)
        self.chunk_size_entry.grid(row=5, column=1, sticky="w", padx=5, pady=5)
        
        # 初始状态下禁用分块大小输入
        self.chunk_size_label['state'] = 'disabled'
        self.chunk_size_entry['state'] = 'disabled'

        self.train_button = ttk.Button(self.train_tab, text="开始训练", command=self.start_training)
        self.train_button.grid(row=6, column=0, columnspan=3, pady=20)

    def setup_predict_tab(self):
        ttk.Label(self.predict_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.predict_tif_path = tk.StringVar()
        predict_tif_entry = ttk.Entry(self.predict_tab, textvariable=self.predict_tif_path, width=50)
        predict_tif_entry.grid(row=0, column=1, padx=5, pady=5)
        predict_tif_entry.drop_target_register(tkinterdnd2.DND_FILES)
        predict_tif_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.predict_tif_path))
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="预览", command=lambda: self.preview_file(self.predict_tif_path)).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.predict_shp_path = tk.StringVar()
        predict_shp_entry = ttk.Entry(self.predict_tab, textvariable=self.predict_shp_path, width=50)
        predict_shp_entry.grid(row=1, column=1, padx=5, pady=5)
        predict_shp_entry.drop_target_register(tkinterdnd2.DND_FILES)
        predict_shp_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.predict_shp_path))
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="预览", command=lambda: self.preview_file(self.predict_shp_path)).grid(row=1, column=3, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="模型路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.predict_tab, textvariable=self.predict_model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=self.browse_predict_model_path).grid(row=2, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="输出文件路径:").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.output_path, width=50).grid(row=3, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_save_file(self.output_path, [("SHP files", "*.shp")])).grid(row=3, column=2, padx=5, pady=5)

        ttk.Checkbutton(self.predict_tab, text="使用分块处理", variable=self.use_chunking, command=self.toggle_chunk_size).grid(row=4, column=0, columnspan=2, sticky="w", padx=5, pady=5)
        
        self.chunk_size_label_predict = ttk.Label(self.predict_tab, text="分块大小:")
        self.chunk_size_label_predict.grid(row=5, column=0, sticky="w", padx=5, pady=5)
        self.chunk_size_entry_predict = ttk.Entry(self.predict_tab, textvariable=self.chunk_size, width=10)
        self.chunk_size_entry_predict.grid(row=5, column=1, sticky="w", padx=5, pady=5)
        
        # 初始状态下禁用分块大小输入
        self.chunk_size_label_predict['state'] = 'disabled'
        self.chunk_size_entry_predict['state'] = 'disabled'

        self.predict_button = ttk.Button(self.predict_tab, text="开始预测", command=self.start_prediction)
        self.predict_button.grid(row=6, column=0, columnspan=3, pady=20)

    def setup_info_tab(self):
        self.info_text = scrolledtext.ScrolledText(self.info_tab, wrap=tk.WORD, width=90, height=20)
        self.info_text.pack(padx=10, pady=10)

        ttk.Button(self.info_tab, text="获取模型信息", command=self.refresh_info).pack(pady=10)

    def browse_file(self, path_var, file_types):
        filename = filedialog.askopenfilename(filetypes=file_types)
        if filename:
            path_var.set(filename)

    def browse_save_file(self, path_var, file_types):
        filename = filedialog.asksaveasfilename(filetypes=file_types, defaultextension=file_types[0][1])
        if filename:
            path_var.set(filename)

    def browse_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.model_path.set(path)

    def browse_predict_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.predict_model_path.set(path)

    def toggle_chunk_size(self):
        state = 'normal' if self.use_chunking.get() else 'disabled'
        self.chunk_size_label['state'] = state
        self.chunk_size_entry['state'] = state
        self.chunk_size_label_predict['state'] = state
        self.chunk_size_entry_predict['state'] = state

    def drop_file(self, event, path_var):
        file_path = event.data
        if file_path.startswith('{') and file_path.endswith('}'):
            file_path = file_path[1:-1]
        file_extension = os.path.splitext(file_path)[1].lower()
        if (path_var in [self.tif_path, self.predict_tif_path] and file_extension == '.tif') or \
           (path_var in [self.shp_path, self.predict_shp_path] and file_extension == '.shp'):
            path_var.set(file_path)
        else:
            messagebox.showerror("错误", "请拖放正确的文件类型")

    def preview_file(self, path_var):
        file_path = path_var.get()
        if not file_path:
            messagebox.showwarning("警告", "请先选择文件")
            return
        
        try:
            info = self.get_file_info(file_path)
            messagebox.showinfo("文件预览", info)
        except Exception as e:
            messagebox.showerror("错误", f"无法预览文件：{str(e)}")

    @staticmethod
    def get_file_info(file_path):
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
        file_type = os.path.splitext(file_path)[1].lower()
        
        info = f"文件路径: {file_path}\n"
        info += f"文件大小: {file_size:.2f} MB\n"
        info += f"文件类型: {file_type}\n"
        
        if file_type == '.tif':
            with rasterio.open(file_path) as src:
                info += f"图像大小: {src.width} x {src.height}\n"
                info += f"波段数: {src.count}\n"
                info += f"坐标系统: {src.crs}\n"
        elif file_type == '.shp':
            gdf = gpd.read_file(file_path)
            info += f"要素数量: {len(gdf)}\n"
            info += f"几何类型: {gdf.geom_type.iloc[0]}\n"
            info += f"属性列: {', '.join(gdf.columns)}\n"
        
        return info

    def start_training(self):
        errors = self.validate_inputs()
        if errors:
            messagebox.showerror("输入错误", "\n".join(errors))
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "训练中..."
        self.log_text.insert(tk.END, "开始训练...\n")
        self.cancel_operation = False

        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        use_chunking = self.use_chunking.get()
        chunk_size = self.chunk_size.get() if use_chunking else None

        def training_thread():
            try:
                accuracy, report = update_or_train_model(
                    tif_file, shp_file, 
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    os.path.join(model_dir, "training_data.joblib"), 
                    force_new,
                    self.update_progress,
                    use_chunking=use_chunking,
                    chunk_size=chunk_size,
                    cancel_check=lambda: self.cancel_operation
                )
                if not self.cancel_operation:
                    self.master.after(0, lambda: self.training_complete(accuracy, report))
                else:
                    self.master.after(0, self.operation_cancelled)
            except Exception as e:
                logging.exception("训练过程中出错")
                error_message = str(e)
                self.master.after(0, lambda: self.training_error(error_message))

        self.current_thread = threading.Thread(target=training_thread)
        self.current_thread.start()

        # 添加取消按钮
        self.cancel_button = ttk.Button(self.train_tab, text="取消", command=self.cancel_current_operation)
        self.cancel_button.grid(row=7, column=0, columnspan=3, pady=10)

    def start_prediction(self):
        errors = self.validate_inputs()
        if errors:
            messagebox.showerror("输入错误", "\n".join(errors))
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "预测中..."
        self.log_text.insert(tk.END, "开始预测...\n")
        self.cancel_operation = False

        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()
        use_chunking = self.use_chunking.get()
        chunk_size = self.chunk_size.get() if use_chunking else None

        def prediction_thread():
            try:
                invalid_count = predict_new_data(
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    tif_file, shp_file, output_file,
                    self.update_progress,
                    use_chunking=use_chunking,
                    chunk_size=chunk_size,
                    cancel_check=lambda: self.cancel_operation
                )
                if not self.cancel_operation:
                    self.master.after(0, lambda: self.prediction_complete(invalid_count))
                else:
                    self.master.after(0, self.operation_cancelled)
            except Exception as e:
                logging.exception("预测过程中出错")
                self.master.after(0, lambda: self.prediction_error(str(e)))

        self.current_thread = threading.Thread(target=prediction_thread)
        self.current_thread.start()

        # 添加取消按钮
        self.cancel_button = ttk.Button(self.predict_tab, text="取消", command=self.cancel_current_operation)
        self.cancel_button.grid(row=7, column=0, columnspan=3, pady=10)

    def update_progress(self, value):
        self.progress_queue.put(value)

    def check_progress_queue(self):
        try:
            while True:
                value = self.progress_queue.get_nowait()
                self.progress['value'] = value
        except queue.Empty:
            pass
        finally:
            self.master.after(100, self.check_progress_queue)

    def training_complete(self, accuracy, report):
        self.enable_buttons()
        self.status_bar['text'] = "训练完成"
        action_text = "创建新模型" if self.model_action.get() == "new" else "更新现有模型"
        self.log_text.insert(tk.END, f"{action_text}完成\n")
        self.log_text.insert(tk.END, f"模型整体精度: {accuracy:.4f}\n")
        self.log_text.insert(tk.END, "各类别精度:\n")
        for class_name, metrics in report.items():
            if isinstance(metrics, dict):
                self.log_text.insert(tk.END, f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}\n")
        messagebox.showinfo("成功", f"{action_text}完成\n模型整体精度: {accuracy:.4f}")
        self.refresh_info()
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def prediction_complete(self, invalid_count):
        self.enable_buttons()
        self.status_bar['text'] = "预测完成"
        self.log_text.insert(tk.END, "预测完成\n")
        if invalid_count > 0:
            self.log_text.insert(tk.END, f"警告：{invalid_count}个几何体无法进行预测，已在输出中标记为'Invalid'\n")
        messagebox.showinfo("成功", f"预测完成\n{invalid_count}个几何体无法预测")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def training_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "训练出错"
        self.log_text.insert(tk.END, f"训练过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"训练过程中出错：{error_message}")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def prediction_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "预测出错"
        self.log_text.insert(tk.END, f"预测过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"预测过程中出错：{error_message}")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def refresh_info(self):
        self.info_text.delete('1.0', tk.END)
        model_dir = self.model_path.get()
        try:
            model = joblib.load(os.path.join(model_dir, "model.joblib"))
            le = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
            
            info = f"模型信息：\n"
            info += f"模型存储路径：{model_dir}\n"
            info += f"特征数量：{model.n_features_in_}\n"
            info += f"类别：{', '.join(le.classes_)}\n"
            info += f"树的数量：{model.n_estimators}\n"
            info += f"最大深度：{model.max_depth}\n"
            info += f"最小分裂样本数：{model.min_samples_split}\n"
            info += f"最小叶子节点样本数：{model.min_samples_leaf}\n"
            info += f"特征选择方式：{model.max_features}\n"
            
            self.info_text.insert(tk.END, info)
        except Exception as e:
            self.info_text.insert(tk.END, f"无法加载模型信息：{str(e)}")

    def disable_buttons(self):
        self.train_button['state'] = 'disabled'
        self.predict_button['state'] = 'disabled'

    def enable_buttons(self):
        self.train_button['state'] = 'normal'
        self.predict_button['state'] = 'normal'

    def show_about(self):
        messagebox.showinfo("关于", "RGB分类模型 v1.0\n\n作者：AI 贵州雏阳\n\n版权所有 © 2024")

    def cancel_current_operation(self):
        self.cancel_operation = True
        self.log_text.insert(tk.END, "正在取消操作...\n")
        self.status_bar['text'] = "正在取消..."

    def operation_cancelled(self):
        self.enable_buttons()
        self.status_bar['text'] = "操作已取消"
        self.log_text.insert(tk.END, "操作已取消\n")
        messagebox.showinfo("已取消", "操作已被用户取消")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def validate_inputs(self):
        errors = []
        if not self.tif_path.get():
            errors.append("请选择TIF文件")
        if not self.shp_path.get():
            errors.append("请选择SHP文件")
        if self.use_chunking.get() and (self.chunk_size.get() <= 0 or self.chunk_size.get() > 10000):
            errors.append("分块大小应在1到10000之间")
        return errors

def main():
    root = tkinterdnd2.TkinterDnD.Tk()
    app = CropClassificationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()

In [None]:

import os
import geopandas as gpd
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, scrolledtext

import joblib
import threading
import queue
import logging
from logging.handlers import RotatingFileHandler
import rasterio

import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split, RandomizedSearchCV
from sklearn.metrics import classification_report, accuracy_score
from sklearn.preprocessing import LabelEncoder
from skimage.feature import graycomatrix, graycoprops
from skimage.color import rgb2hsv
import fiona
from collections import Counter
from scipy.stats import randint, uniform
import tkinterdnd2
# 修改版本号
np.__version__ = '1.23.0'
# 配置日志记录
def setup_logging():
    log_formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
    log_file = 'crop_classification.log'
    log_handler = RotatingFileHandler(log_file, maxBytes=1024 * 1024, backupCount=5)
    log_handler.setFormatter(log_formatter)
    log_handler.setLevel(logging.INFO)

    root_logger = logging.getLogger()
    root_logger.setLevel(logging.INFO)
    root_logger.addHandler(log_handler)

setup_logging()

fiona.supported_drivers['ESRI Shapefile'] = 'rw'
# gpd.io.file.fiona.drvsupport.supported_drivers['ESRI Shapefile'] = 'rw'

def validate_shp_data(gdf, column_name='ZZZW'):
    """
    验证SHP数据是否包含指定的标签列，以及该列是否有空值
    """
    if column_name not in gdf.columns:
        raise ValueError(f"SHP文件中缺少'{column_name}'列")
    
    if gdf[column_name].isnull().any():
        raise ValueError(f"'{column_name}'列中存在空值")

def extract_features(image_path, gdf, progress_callback=None):
    features = []
    valid_geometries = []
    invalid_geometries = []
    total_geometries = len(gdf)

    with rasterio.open(image_path) as src:
        for idx, geometry in enumerate(gdf.geometry):
            try:
                window = rasterio.windows.from_bounds(*geometry.bounds, src.transform)
                masked_image = src.read(window=window, indexes=[1, 2, 3])
                
                if masked_image.shape[0] < 3 or masked_image.size == 0:
                    logging.warning(f"几何体 {idx}: 无数据")
                    invalid_geometries.append(idx)
                    continue
                
                rgb_means = np.nanmean(masked_image, axis=(1, 2))
                rgb_stds = np.nanstd(masked_image, axis=(1, 2))
                
                r, g, b = rgb_means
                exg = 2 * g - r - b
                vari = (g - r) / (g + r - b + 1e-8)
                
                hsv_image = rgb2hsv(np.moveaxis(masked_image, 0, -1))
                hsv_means = np.nanmean(hsv_image, axis=(0, 1))
                
                green_channel = masked_image[1].astype(np.uint8)
                if green_channel.size > 0:
                    glcm = graycomatrix(green_channel, distances=[1], angles=[0], levels=256, symmetric=True, normed=True)
                    contrast = graycoprops(glcm, 'contrast')[0, 0]
                    dissimilarity = graycoprops(glcm, 'dissimilarity')[0, 0]
                    homogeneity = graycoprops(glcm, 'homogeneity')[0, 0]
                    energy = graycoprops(glcm, 'energy')[0, 0]
                    correlation = graycoprops(glcm, 'correlation')[0, 0]
                else:
                    contrast = dissimilarity = homogeneity = energy = correlation = 0
                
                feature = np.concatenate([rgb_means, rgb_stds, [exg, vari], hsv_means, 
                                          [contrast, dissimilarity, homogeneity, energy, correlation]])
                features.append(feature)
                valid_geometries.append(idx)

                if progress_callback:
                    progress_callback((idx + 1) / total_geometries * 100)
            except Exception as e:
                logging.error(f"处理几何体 {idx} 时出错: {str(e)}")
                invalid_geometries.append(idx)
    
    if not features:
        raise ValueError("没有成功提取到任何特征")
    
    return np.array(features), valid_geometries, invalid_geometries
    
def update_or_train_model(image_path, train_shp_path, model_output_path, le_output_path, data_output_path, force_new_model=False, progress_callback=None, hyperparameters=None, cancel_check=None):
    logging.info(f"开始{'创建新模型' if force_new_model else '更新模型'}")
    gdf_train = gpd.read_file(train_shp_path, encoding='utf-8')
    
    # 验证SHP数据
    validate_shp_data(gdf_train)
    
    X_new, valid_indices, _ = extract_features(image_path, gdf_train, progress_callback=progress_callback)
    gdf_train = gdf_train.iloc[valid_indices]
    
    le = LabelEncoder()
    y_new = le.fit_transform(gdf_train['ZZZW'])
    
    if os.path.exists(model_output_path) and os.path.exists(le_output_path) and os.path.exists(data_output_path) and not force_new_model:
        logging.info("加载现有模型并进行更新")
        clf = joblib.load(model_output_path)
        X_old, y_old = joblib.load(data_output_path)
        
        if X_old.shape[1] != X_new.shape[1]:
            logging.warning(f"新旧数据的特征数量不一致。旧数据：{X_old.shape[1]}，新数据：{X_new.shape[1]}")
            logging.info("将重新训练模型")
            X_combined, y_combined = X_new, y_new
        else:
            X_combined = np.vstack((X_old, X_new))
            y_combined = np.concatenate((y_old, y_new))
    else:
        logging.info("创建新模型")
        X_combined, y_combined = X_new, y_new

    # 定义随机搜索的参数范围
    if hyperparameters is None:
        hyperparameters = {
            'n_estimators': randint(10, 1000),
            'max_depth': randint(2, 50),
            'min_samples_split': randint(2, 50),
            'min_samples_leaf': randint(1, 20),
            'max_features': uniform(0.1, 0.9)
        }

    # 创建随机搜索对象
    random_search = RandomizedSearchCV(
        RandomForestClassifier(random_state=42),
        param_distributions=hyperparameters,
        n_iter=50,
        cv=5,
        random_state=42,
        n_jobs=-1
    )

    # 执行随机搜索
    random_search.fit(X_combined, y_combined)

    # 获取最佳模型
    clf = random_search.best_estimator_

    # 评估模型
    y_pred = clf.predict(X_combined)
    accuracy = accuracy_score(y_combined, y_pred)
    report = classification_report(y_combined, y_pred, target_names=le.classes_, output_dict=True)

    # 保存模型和数据
    joblib.dump(clf, model_output_path)
    joblib.dump(le, le_output_path)
    joblib.dump((X_combined, y_combined), data_output_path)

    logging.info(f"模型已保存到: {model_output_path}")
    logging.info(f"标签编码器已保存到: {le_output_path}")
    logging.info(f"训练数据已保存到: {data_output_path}")
    logging.info(f"最佳参数: {random_search.best_params_}")
    logging.info(f"模型整体精度: {accuracy}")
    logging.info("各类别精度:")
    for class_name, metrics in report.items():
        if isinstance(metrics, dict):
            logging.info(f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}")

    if cancel_check and cancel_check():
        raise InterruptedError("操作被用户取消")

    return accuracy, report

def predict_new_data(model_path, le_path, new_image_path, new_shp_path, output_shp_path, progress_callback=None, use_chunking=False, chunk_size=None, cancel_check=None):
    logging.info("开始预测新数据")
    clf = joblib.load(model_path)
    le = joblib.load(le_path)

    gdf_new = gpd.read_file(new_shp_path, encoding='utf-8')
    X_new, valid_indices, invalid_indices = extract_features(new_image_path, gdf_new, use_chunking=use_chunking, chunk_size=chunk_size, progress_callback=progress_callback)
    
    if X_new.shape[1] != clf.n_features_in_:
        raise ValueError(f"错误：特征数量不匹配。模型期望 {clf.n_features_in_} 个特征，但提供了 {X_new.shape[1]} 个特征。")
    
    y_pred = clf.predict(X_new)
    y_proba = clf.predict_proba(X_new)
    
    # 为有效的几何体添加预测结果
    gdf_new.loc[valid_indices, 'ZZZW'] = le.inverse_transform(y_pred)
    gdf_new.loc[valid_indices, 'ZZZW_proba'] = np.max(y_proba, axis=1)
    
    # 为无效的几何体添加标记
    gdf_new.loc[invalid_indices, 'ZZZW'] = 'Invalid'
    gdf_new.loc[invalid_indices, 'ZZZW_proba'] = 0
    
    gdf_new.to_file(output_shp_path, encoding='utf-8')

    if cancel_check and cancel_check():
        raise InterruptedError("操作被用户取消")

    logging.info(f"预测结果已保存到: {output_shp_path}")
    return len(invalid_indices)

class CropClassificationApp:
    def __init__(self, master):
        self.master = master
        master.title("RGB分类模型")
        master.geometry("800x600")

        # 初始化所有需要的属性
        self.tif_path = tk.StringVar()
        self.shp_path = tk.StringVar()
        self.model_path = tk.StringVar(value=os.getcwd())
        self.model_action = tk.StringVar(value="update")
        self.predict_tif_path = tk.StringVar()
        self.predict_shp_path = tk.StringVar()
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        self.output_path = tk.StringVar()

        # 添加这些行来初始化超参数变量
        self.n_estimators_min = tk.IntVar(value=50)
        self.n_estimators_max = tk.IntVar(value=300)
        self.max_depth_min = tk.IntVar(value=5)
        self.max_depth_max = tk.IntVar(value=50)

        self.create_widgets()
        self.create_menu()

        self.progress_queue = queue.Queue()
        self.master.after(100, self.check_progress_queue)

        self.cancel_operation = False

    def create_widgets(self):
        self.notebook = ttk.Notebook(self.master)
        self.notebook.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        self.train_tab = ttk.Frame(self.notebook)
        self.predict_tab = ttk.Frame(self.notebook)
        self.info_tab = ttk.Frame(self.notebook)

        self.notebook.add(self.train_tab, text="训练模型")
        self.notebook.add(self.predict_tab, text="模型预测")
        self.notebook.add(self.info_tab, text="模型信息")

        self.setup_train_tab()
        self.setup_predict_tab()
        self.setup_info_tab()

        self.status_bar = ttk.Label(self.master, text="就绪", relief=tk.SUNKEN, anchor=tk.W)
        self.status_bar.pack(side=tk.BOTTOM, fill=tk.X)

        self.progress = ttk.Progressbar(self.master, length=780, mode='determinate')
        self.progress.pack(pady=10)

        self.log_text = scrolledtext.ScrolledText(self.master, wrap=tk.WORD, width=95, height=10)
        self.log_text.pack(padx=10, pady=10)

    def create_menu(self):
        menubar = tk.Menu(self.master)
        self.master.config(menu=menubar)

        file_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="文件", menu=file_menu)
        file_menu.add_command(label="退出", command=self.master.quit)

        help_menu = tk.Menu(menubar, tearoff=0)
        menubar.add_cascade(label="帮助", menu=help_menu)
        help_menu.add_command(label="关于", command=self.show_about)

    def setup_train_tab(self):
        ttk.Label(self.train_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.tif_path = tk.StringVar()
        tif_entry = ttk.Entry(self.train_tab, textvariable=self.tif_path, width=50)
        tif_entry.grid(row=0, column=1, padx=5, pady=5)
        tif_entry.drop_target_register(tkinterdnd2.DND_FILES)
        tif_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.tif_path))
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)
        ttk.Button(self.train_tab, text="预览", command=lambda: self.preview_file(self.tif_path)).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(self.train_tab, text="选择带有标签的SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.shp_path = tk.StringVar()
        shp_entry = ttk.Entry(self.train_tab, textvariable=self.shp_path, width=50)
        shp_entry.grid(row=1, column=1, padx=5, pady=5)
        shp_entry.drop_target_register(tkinterdnd2.DND_FILES)
        shp_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.shp_path))
        ttk.Button(self.train_tab, text="浏览", command=lambda: self.browse_file(self.shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)
        ttk.Button(self.train_tab, text="预览", command=lambda: self.preview_file(self.shp_path)).grid(row=1, column=3, padx=5, pady=5)

        ttk.Label(self.train_tab, text="模型存储路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.train_tab, textvariable=self.model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.train_tab, text="浏览", command=self.browse_model_path).grid(row=2, column=2, padx=5, pady=5)

        self.model_action = tk.StringVar(value="update")
        ttk.Radiobutton(self.train_tab, text="更新现有模型", variable=self.model_action, value="update").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        ttk.Radiobutton(self.train_tab, text="创建新模型", variable=self.model_action, value="new").grid(row=3, column=1, sticky="w", padx=5, pady=5)

        # 添加超参数设置控件
        ttk.Label(self.train_tab, text="n_estimators:").grid(row=4, column=0, sticky="w", padx=5, pady=5)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_min, width=5).grid(row=4, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=4, column=1)
        ttk.Entry(self.train_tab, textvariable=self.n_estimators_max, width=5).grid(row=4, column=1, sticky="e", padx=5, pady=5)

        ttk.Label(self.train_tab, text="max_depth:").grid(row=5, column=0, sticky="w", padx=5, pady=5)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_min, width=5).grid(row=5, column=1, sticky="w", padx=5, pady=5)
        ttk.Label(self.train_tab, text="-").grid(row=5, column=1)
        ttk.Entry(self.train_tab, textvariable=self.max_depth_max, width=5).grid(row=5, column=1, sticky="e", padx=5, pady=5)
        

        self.train_button = ttk.Button(self.train_tab, text="开始训练", command=self.start_training)
        self.train_button.grid(row=6, column=0, columnspan=3, pady=20)

    def setup_predict_tab(self):
        ttk.Label(self.predict_tab, text="选择TIF文件:").grid(row=0, column=0, sticky="w", padx=5, pady=5)
        self.predict_tif_path = tk.StringVar()
        predict_tif_entry = ttk.Entry(self.predict_tab, textvariable=self.predict_tif_path, width=50)
        predict_tif_entry.grid(row=0, column=1, padx=5, pady=5)
        predict_tif_entry.drop_target_register(tkinterdnd2.DND_FILES)
        predict_tif_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.predict_tif_path))
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_tif_path, [("TIF files", "*.tif")])).grid(row=0, column=2, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="预览", command=lambda: self.preview_file(self.predict_tif_path)).grid(row=0, column=3, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="选择SHP文件:").grid(row=1, column=0, sticky="w", padx=5, pady=5)
        self.predict_shp_path = tk.StringVar()
        predict_shp_entry = ttk.Entry(self.predict_tab, textvariable=self.predict_shp_path, width=50)
        predict_shp_entry.grid(row=1, column=1, padx=5, pady=5)
        predict_shp_entry.drop_target_register(tkinterdnd2.DND_FILES)
        predict_shp_entry.dnd_bind('<<Drop>>', lambda e: self.drop_file(e, self.predict_shp_path))
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_file(self.predict_shp_path, [("SHP files", "*.shp")])).grid(row=1, column=2, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="预览", command=lambda: self.preview_file(self.predict_shp_path)).grid(row=1, column=3, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="模型路径:").grid(row=2, column=0, sticky="w", padx=5, pady=5)
        self.predict_model_path = tk.StringVar(value=os.getcwd())
        ttk.Entry(self.predict_tab, textvariable=self.predict_model_path, width=50).grid(row=2, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=self.browse_predict_model_path).grid(row=2, column=2, padx=5, pady=5)

        ttk.Label(self.predict_tab, text="输出文件路径:").grid(row=3, column=0, sticky="w", padx=5, pady=5)
        self.output_path = tk.StringVar()
        ttk.Entry(self.predict_tab, textvariable=self.output_path, width=50).grid(row=3, column=1, padx=5, pady=5)
        ttk.Button(self.predict_tab, text="浏览", command=lambda: self.browse_save_file(self.output_path, [("SHP files", "*.shp")])).grid(row=3, column=2, padx=5, pady=5)


        self.predict_button = ttk.Button(self.predict_tab, text="开始预测", command=self.start_prediction)
        self.predict_button.grid(row=6, column=0, columnspan=3, pady=20)

    def setup_info_tab(self):
        self.info_text = scrolledtext.ScrolledText(self.info_tab, wrap=tk.WORD, width=90, height=20)
        self.info_text.pack(padx=10, pady=10)

        ttk.Button(self.info_tab, text="获取模型信息", command=self.refresh_info).pack(pady=10)

    def browse_file(self, path_var, file_types):
        filename = filedialog.askopenfilename(filetypes=file_types)
        if filename:
            path_var.set(filename)

    def browse_save_file(self, path_var, file_types):
        filename = filedialog.asksaveasfilename(filetypes=file_types, defaultextension=file_types[0][1])
        if filename:
            path_var.set(filename)

    def browse_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.model_path.set(path)

    def browse_predict_model_path(self):
        path = filedialog.askdirectory()
        if path:
            self.predict_model_path.set(path)

    def toggle_chunk_size(self):
        state = 'normal' if self.use_chunking.get() else 'disabled'
        self.chunk_size_label['state'] = state
        self.chunk_size_entry['state'] = state
        self.chunk_size_label_predict['state'] = state
        self.chunk_size_entry_predict['state'] = state

    def drop_file(self, event, path_var):
        file_path = event.data
        if file_path.startswith('{') and file_path.endswith('}'):
            file_path = file_path[1:-1]
        file_extension = os.path.splitext(file_path)[1].lower()
        if (path_var in [self.tif_path, self.predict_tif_path] and file_extension == '.tif') or \
           (path_var in [self.shp_path, self.predict_shp_path] and file_extension == '.shp'):
            path_var.set(file_path)
        else:
            messagebox.showerror("错误", "请拖放正确的文件类型")

    def preview_file(self, path_var):
        file_path = path_var.get()
        if not file_path:
            messagebox.showwarning("警告", "请先选择文件")
            return
        
        try:
            info = self.get_file_info(file_path)
            messagebox.showinfo("文件预览", info)
        except Exception as e:
            messagebox.showerror("错误", f"无法预览文件：{str(e)}")

    @staticmethod
    def get_file_info(file_path):
        file_size = os.path.getsize(file_path) / (1024 * 1024)  # Size in MB
        file_type = os.path.splitext(file_path)[1].lower()
        
        info = f"文件路径: {file_path}\n"
        info += f"文件大小: {file_size:.2f} MB\n"
        info += f"文件类型: {file_type}\n"
        
        if file_type == '.tif':
            with rasterio.open(file_path) as src:
                info += f"图像大小: {src.width} x {src.height}\n"
                info += f"波段数: {src.count}\n"
                info += f"坐标系统: {src.crs}\n"
        elif file_type == '.shp':
            gdf = gpd.read_file(file_path)
            info += f"要素数量: {len(gdf)}\n"
            info += f"几何类型: {gdf.geom_type.iloc[0]}\n"
            info += f"属性列: {', '.join(gdf.columns)}\n"
        
        return info

    def start_training(self):
        errors = self.validate_inputs()
        if errors:
            messagebox.showerror("输入错误", "\n".join(errors))
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "训练中..."
        self.log_text.insert(tk.END, "开始训练...\n")
        self.cancel_operation = False

        tif_file = self.tif_path.get()
        shp_file = self.shp_path.get()
        model_dir = self.model_path.get()
        force_new = self.model_action.get() == "new"
        # 添加超参数设置
        hyperparameters = {
            'n_estimators': randint(self.n_estimators_min.get(), self.n_estimators_max.get()),
            'max_depth': randint(self.max_depth_min.get(), self.max_depth_max.get()),
            'min_samples_split': randint(2, 20),
            'min_samples_leaf': randint(1, 10),
            'max_features': uniform(0.1, 0.9)
        }
        def training_thread():
            try:
                accuracy, report = update_or_train_model(
                    tif_file, shp_file, 
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    os.path.join(model_dir, "training_data.joblib"), 
                    force_new,
                    self.update_progress,
                     hyperparameters=hyperparameters,
                    cancel_check=lambda: self.cancel_operation
                )
                if not self.cancel_operation:
                    self.master.after(0, lambda: self.training_complete(accuracy, report))
                else:
                    self.master.after(0, self.operation_cancelled)
            except Exception as e:
                logging.exception("训练过程中出错")
                error_message = str(e)
                self.master.after(0, lambda: self.training_error(error_message))

        self.current_thread = threading.Thread(target=training_thread)
        self.current_thread.start()

        # 添加取消按钮
        self.cancel_button = ttk.Button(self.train_tab, text="取消", command=self.cancel_current_operation)
        self.cancel_button.grid(row=7, column=0, columnspan=3, pady=10)

    def start_prediction(self):
        errors = self.validate_inputs()
        if errors:
            messagebox.showerror("输入错误", "\n".join(errors))
            return

        self.disable_buttons()
        self.progress['value'] = 0
        self.status_bar['text'] = "预测中..."
        self.log_text.insert(tk.END, "开始预测...\n")
        self.cancel_operation = False

        tif_file = self.predict_tif_path.get()
        shp_file = self.predict_shp_path.get()
        model_dir = self.predict_model_path.get()
        output_file = self.output_path.get()
        # 添加文件存在性检查
        if not self.check_file_exists(tif_file):
            self.prediction_error(f"TIF文件不存在: {tif_file}")
            return
        if not self.check_file_exists(shp_file):
            self.prediction_error(f"SHP文件不存在: {shp_file}")
            return
        if not self.check_file_exists(os.path.join(model_dir, "model.joblib")):
            self.prediction_error(f"模型文件不存在: {os.path.join(model_dir, 'model.joblib')}")
            return
        def prediction_thread():
            try:
                invalid_count = predict_new_data(
                    os.path.join(model_dir, "model.joblib"), 
                    os.path.join(model_dir, "label_encoder.joblib"), 
                    tif_file, shp_file, output_file,
                    self.update_progress,
                    cancel_check=lambda: self.cancel_operation
                )
                if not self.cancel_operation:
                    self.master.after(0, lambda: self.prediction_complete(invalid_count))
                else:
                    self.master.after(0, self.operation_cancelled)
            except Exception as e:
                logging.exception("预测过程中出错")
                self.master.after(0, lambda: self.prediction_error(str(e)))

        self.current_thread = threading.Thread(target=prediction_thread)
        self.current_thread.start()

        # 添加取消按钮
        self.cancel_button = ttk.Button(self.predict_tab, text="取消", command=self.cancel_current_operation)
        self.cancel_button.grid(row=7, column=0, columnspan=3, pady=10)
    def check_file_exists(self, file_path):
        return os.path.exists(file_path)
    def update_progress(self, value):
        self.progress_queue.put(value)

    def check_progress_queue(self):
        try:
            while True:
                value = self.progress_queue.get_nowait()
                self.progress['value'] = value
        except queue.Empty:
            pass
        finally:
            self.master.after(100, self.check_progress_queue)

    def training_complete(self, accuracy, report):
        self.enable_buttons()
        self.status_bar['text'] = "训练完成"
        action_text = "创建新模型" if self.model_action.get() == "new" else "更新现有模型"
        self.log_text.insert(tk.END, f"{action_text}完成\n")
        self.log_text.insert(tk.END, f"模型整体精度: {accuracy:.4f}\n")
        self.log_text.insert(tk.END, "各类别精度:\n")
        for class_name, metrics in report.items():
            if isinstance(metrics, dict):
                self.log_text.insert(tk.END, f"{class_name}: 精度 = {metrics['precision']:.2f}, 召回率 = {metrics['recall']:.2f}, F1分数 = {metrics['f1-score']:.2f}\n")
        messagebox.showinfo("成功", f"{action_text}完成\n模型整体精度: {accuracy:.4f}")
        self.refresh_info()
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def prediction_complete(self, invalid_count):
        self.enable_buttons()
        self.status_bar['text'] = "预测完成"
        self.log_text.insert(tk.END, "预测完成\n")
        if invalid_count > 0:
            self.log_text.insert(tk.END, f"警告：{invalid_count}个几何体无法进行预测，已在输出中标记为'Invalid'\n")
        messagebox.showinfo("成功", f"预测完成\n{invalid_count}个几何体无法预测")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def training_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "训练出错"
        self.log_text.insert(tk.END, f"训练过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"训练过程中出错：{error_message}")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def prediction_error(self, error_message):
        self.enable_buttons()
        self.status_bar['text'] = "预测出错"
        self.log_text.insert(tk.END, f"预测过程中出错：{error_message}\n")
        messagebox.showerror("错误", f"预测过程中出错：{error_message}")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def refresh_info(self):
        self.info_text.delete('1.0', tk.END)
        model_dir = self.model_path.get()
        try:
            model = joblib.load(os.path.join(model_dir, "model.joblib"))
            le = joblib.load(os.path.join(model_dir, "label_encoder.joblib"))
            
            info = f"模型信息：\n"
            info += f"模型存储路径：{model_dir}\n"
            info += f"特征数量：{model.n_features_in_}\n"
            info += f"类别：{', '.join(le.classes_)}\n"
            info += f"树的数量：{model.n_estimators}\n"
            info += f"最大深度：{model.max_depth}\n"
            info += f"最小分裂样本数：{model.min_samples_split}\n"
            info += f"最小叶子节点样本数：{model.min_samples_leaf}\n"
            info += f"特征选择方式：{model.max_features}\n"
            
            # 尝试加载训练数据并获取样本数量
            try:
                X, y = joblib.load(os.path.join(model_dir, "training_data.joblib"))
                info += f"参训练样本数量：{len(X)}\n"
            except Exception as e:
                info += f"无法加载训练数据信息：{str(e)}\n"
            
            self.info_text.insert(tk.END, info)
        except Exception as e:
            self.info_text.insert(tk.END, f"无法加载模型信息：{str(e)}")

    def disable_buttons(self):
        self.train_button['state'] = 'disabled'
        self.predict_button['state'] = 'disabled'

    def enable_buttons(self):
        self.train_button['state'] = 'normal'
        self.predict_button['state'] = 'normal'

    def show_about(self):
        messagebox.showinfo("关于", "RGB分类模型 v1.0\n\n作者：AI 贵州雏阳\n\n版权所有 © 2024")

    def cancel_current_operation(self):
        self.cancel_operation = True
        self.log_text.insert(tk.END, "正在取消操作...\n")
        self.status_bar['text'] = "正在取消..."

    def operation_cancelled(self):
        self.enable_buttons()
        self.status_bar['text'] = "操作已取消"
        self.log_text.insert(tk.END, "操作已取消\n")
        messagebox.showinfo("已取消", "操作已被用户取消")
        if hasattr(self, 'cancel_button'):
            self.cancel_button.destroy()

    def validate_inputs(self):
        errors = []
        if self.notebook.index(self.notebook.select()) == 0:  # Training tab
            if not self.tif_path.get():
                errors.append("请选择训练用TIF文件")
            if not self.shp_path.get():
                errors.append("请选择训练用SHP文件")
        else:  # Prediction tab
            if not self.predict_tif_path.get():
                errors.append("请选择预测用TIF文件")
            if not self.predict_shp_path.get():
                errors.append("请选择预测用SHP文件")
            if not self.predict_model_path.get():
                errors.append("请选择模型路径")
            if not self.output_path.get():
                errors.append("请选择输出文件路径")
        return errors

def main():
    root = tkinterdnd2.TkinterDnD.Tk()
    app = CropClassificationApp(root)
    root.mainloop()

if __name__ == "__main__":
    main()
