In [None]:
#UI1

In [16]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from matplotlib.figure import Figure
import matplotlib.patches as patches
import os


class DataVisualizationApp:
    def __init__(self, root):
        """Configure the main window, reset all data slots and build the UI.

        Args:
            root: The Tk window that hosts the application.
        """
        self.root = root
        root.title("Time Series Visualization")
        root.geometry("1600x1000")

        # Data containers; they remain None until files have been loaded.
        self.npz_data = None
        self.npy_data = None
        self.arr_0 = None  # (sample, shape_number, shape_length)
        self.arr_1 = None  # (sample, shape_number, VP)
        self.x_train = None  # (sample, length, dimension_number)

        # View bookkeeping used by the zoom / reset buttons.
        self.current_zoom = 1.0
        self.pan_offset = [0, 0]

        # Build every widget of the main window.
        self.create_main_layout()

    def create_main_layout(self):
        """Build the top-level horizontal split: menu pane first, plot pane second."""
        splitter = ttk.PanedWindow(self.root, orient=tk.HORIZONTAL)
        splitter.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Panes are added in call order: the control menu, then the plot area.
        for build_pane in (self.create_left_menu, self.create_right_visualization):
            build_pane(splitter)

    def create_left_menu(self, parent):
        """Build the scrollable control panel and populate it with all sections.

        Args:
            parent: The paned window the menu pane is added to.
        """
        menu_frame = ttk.Frame(parent, width=450)
        parent.add(menu_frame, weight=1)

        # The sections live in a frame embedded in a Canvas so that the whole
        # menu can be scrolled with a vertical scrollbar.
        viewport = tk.Canvas(menu_frame)
        v_scroll = ttk.Scrollbar(menu_frame, orient="vertical", command=viewport.yview)
        content = ttk.Frame(viewport)

        def refresh_scroll_region(_event):
            # Keep the scrollable area in sync with the content's bounding box.
            viewport.configure(scrollregion=viewport.bbox("all"))

        content.bind("<Configure>", refresh_scroll_region)

        viewport.create_window((0, 0), window=content, anchor="nw")
        viewport.configure(yscrollcommand=v_scroll.set)

        viewport.pack(side="left", fill="both", expand=True)
        v_scroll.pack(side="right", fill="y")

        # Sections, top to bottom: file loading, upper-plot controls,
        # lower-plot controls, plot buttons, shape position comparison.
        section_builders = (
            self.create_file_loading_section,
            self.create_upper_viz_controls,
            self.create_lower_viz_controls,
            self.create_control_buttons,
            self.create_shape_position_controls,
        )
        for build_section in section_builders:
            build_section(content)

    def create_file_loading_section(self, parent):
        """Build the file-selection section: two path pickers, a load button and a status box.

        Args:
            parent: Container widget the section is packed into.
        """
        section1 = ttk.LabelFrame(parent, text="file download", padding=10)
        section1.pack(fill=tk.X, padx=5, pady=5)

        def add_path_picker(caption, on_browse, **caption_pack):
            """Add a caption, a read-only path entry and a Browse button."""
            ttk.Label(section1, text=caption).pack(anchor=tk.W, **caption_pack)
            row = ttk.Frame(section1)
            row.pack(fill=tk.X, pady=2)

            path_var = tk.StringVar()
            entry = ttk.Entry(row, textvariable=path_var, state="readonly")
            entry.pack(side=tk.LEFT, fill=tk.X, expand=True)
            ttk.Button(row, text="Browse", command=on_browse).pack(side=tk.RIGHT, padx=(5, 0))
            return path_var, entry

        # NPZ file picker.
        self.npz_path_var, self.npz_entry = add_path_picker(
            "Choose NPZ File:", self.browse_npz_file)

        # NPY file picker (extra top padding separates it from the first picker).
        self.npy_path_var, self.npy_entry = add_path_picker(
            "Choose X_train.npy File:", self.browse_npy_file, pady=(10, 0))

        # Button that triggers loading of both selected files.
        load_button = ttk.Button(section1, text="Download Data File", command=self.load_data)
        load_button.pack(pady=(10, 0))

        # Read-only text box that reports the state of the loaded data.
        status_box = tk.Text(section1, height=4, width=40)
        status_box.pack(pady=(5, 0), fill=tk.X)
        status_box.insert(tk.END, "unknown dataset")
        status_box.config(state=tk.DISABLED)
        self.data_info_text = status_box

    def create_upper_viz_controls(self, parent):
        """Build the controls for the upper figure: plot count and per-sequence parameters.

        Args:
            parent: Container widget the section is packed into.
        """
        section2 = ttk.LabelFrame(parent, text="The Upper Part Controlled", padding=10)
        section2.pack(fill=tk.X, padx=5, pady=5)

        # Radio buttons selecting how many plots (1-4) are shown.
        ttk.Label(section2, text="Display the Number of Plot:").pack(anchor=tk.W)
        self.plot_count_var = tk.IntVar(value=2)
        choices_row = ttk.Frame(section2)
        choices_row.pack(fill=tk.X, pady=2)
        for count in range(1, 5):
            option = ttk.Radiobutton(
                choices_row,
                text=str(count),
                variable=self.plot_count_var,
                value=count,
                command=self.update_sequence_controls,
            )
            option.pack(side=tk.LEFT, padx=5)

        # Scrollable area holding one parameter group per sequence.
        self.sequence_canvas = tk.Canvas(section2, height=200)
        seq_scroll = ttk.Scrollbar(section2, orient="vertical", command=self.sequence_canvas.yview)
        self.sequence_frame = ttk.Frame(self.sequence_canvas)

        def refresh_scroll_region(_event):
            # Keep the scrollable area in sync with the frame's bounding box.
            self.sequence_canvas.configure(scrollregion=self.sequence_canvas.bbox("all"))

        self.sequence_frame.bind("<Configure>", refresh_scroll_region)

        self.sequence_canvas.create_window((0, 0), window=self.sequence_frame, anchor="nw")
        self.sequence_canvas.configure(yscrollcommand=seq_scroll.set)

        self.sequence_canvas.pack(side="left", fill="both", expand=True, pady=(10, 0))
        seq_scroll.pack(side="right", fill="y", pady=(10, 0))

        # One dict of variables/widgets per sequence; filled by update_sequence_controls().
        self.sequence_controls = []
        self.update_sequence_controls()

    def update_sequence_controls(self):
        """Rebuild the per-sequence parameter widgets for the selected plot count.

        All existing widgets in ``self.sequence_frame`` are destroyed and one
        group (sample, dimension, start time, length) is created per plot.
        Values are reset to their defaults on every rebuild.

        The spinbox upper limits follow the loaded ``x_train`` array when one
        is available, so that changing the plot count after loading data does
        not fall back to the generic 1000/999 limits.
        """
        # Remove the widgets of the previous configuration.
        for widget in self.sequence_frame.winfo_children():
            widget.destroy()

        # Upper limits: data-driven once a 3-D x_train is loaded, generic otherwise.
        x_train = getattr(self, "x_train", None)
        if x_train is not None and getattr(x_train, "ndim", 0) == 3:
            sample_max, time_len, dim_max = x_train.shape
            start_max = max(time_len - 1, 0)
            length_max = time_len
        else:
            sample_max = dim_max = 1000
            start_max = length_max = 999

        # (variable key, spinbox key, caption, grid row, caption column, from_, to, initial)
        field_specs = (
            ("sample", "sample_spinbox", "Sample:", 0, 0, 1, sample_max, 1),          # 1-based
            ("dimension", "dimension_spinbox", "Dim:", 0, 2, 1, dim_max, 1),          # 1-based
            ("start_time", "start_spinbox", "time0:", 1, 0, 0, start_max, 0),
            # Keep the default window length inside the available series length.
            ("length", "length_spinbox", "length:", 1, 2, 1, length_max,
             max(1, min(100, length_max))),
        )

        self.sequence_controls = []
        for i in range(self.plot_count_var.get()):
            seq_frame = ttk.LabelFrame(self.sequence_frame, text=f"sequence {i + 1}", padding=5)
            seq_frame.pack(fill=tk.X, pady=2)

            controls = {}
            for var_key, spin_key, caption, row, col, low, high, initial in field_specs:
                ttk.Label(seq_frame, text=caption).grid(row=row, column=col, sticky="w", padx=2)
                controls[var_key] = tk.IntVar(value=initial)
                spinbox = ttk.Spinbox(seq_frame, from_=low, to=high,
                                      textvariable=controls[var_key], width=10)
                spinbox.grid(row=row, column=col + 1, padx=2, pady=1)
                # Keep the widget so its range can be updated after data is loaded.
                controls[spin_key] = spinbox

            self.sequence_controls.append(controls)

    def create_lower_viz_controls(self, parent):
        """Build the controls for the lower figure: sample and shape selectors (both 1-based).

        Args:
            parent: Container widget the section is packed into.
        """
        section3 = ttk.LabelFrame(parent, text="The Lower Part Controlled", padding=10)
        section3.pack(fill=tk.X, padx=5, pady=5)

        # Both selectors share the same initial range and width.
        spin_options = dict(from_=1, to=1000, width=20)

        # Sample selector.
        ttk.Label(section3, text="Sample Number (From 1):").pack(anchor=tk.W)
        self.lower_sample_var = tk.IntVar(value=1)
        self.lower_sample_spinbox = ttk.Spinbox(
            section3, textvariable=self.lower_sample_var, **spin_options)
        self.lower_sample_spinbox.pack(anchor=tk.W, pady=2)

        # Shape selector.
        ttk.Label(section3, text="Shape Number (From 1):").pack(anchor=tk.W, pady=(10, 0))
        self.shape_number_var = tk.IntVar(value=1)
        self.shape_number_spinbox = ttk.Spinbox(
            section3, textvariable=self.shape_number_var, **spin_options)
        self.shape_number_spinbox.pack(anchor=tk.W, pady=2)

    def create_control_buttons(self, parent):
        """Build the plot action buttons: update, reset, zoom in/out and pan.

        Args:
            parent: Container widget the section is packed into.
        """
        button_frame = ttk.LabelFrame(parent, text="Plot Controlled", padding=10)
        button_frame.pack(fill=tk.X, padx=5, pady=5)

        # Primary action: redraw both figures with the current settings.
        update_button = ttk.Button(
            button_frame, text="Update Plot", command=self.update_plots, style="Accent.TButton")
        update_button.pack(fill=tk.X, pady=2)

        # View-manipulation buttons share their own sub-frame.
        control_frame = ttk.Frame(button_frame)
        control_frame.pack(fill=tk.X, pady=(10, 0))

        reset_button = ttk.Button(control_frame, text="Reset Plot", command=self.reset_view)
        reset_button.pack(side=tk.TOP, fill=tk.X, pady=1)

        # Zoom buttons sit side by side and share the row equally.
        zoom_frame = ttk.Frame(control_frame)
        zoom_frame.pack(fill=tk.X, pady=2)
        enlarge_button = ttk.Button(zoom_frame, text="Enlarge the Plot", command=self.zoom_in)
        enlarge_button.pack(side=tk.LEFT, padx=2, fill=tk.X, expand=True)
        shrink_button = ttk.Button(zoom_frame, text="Zoom Out the Plot", command=self.zoom_out)
        shrink_button.pack(side=tk.RIGHT, padx=2, fill=tk.X, expand=True)

        move_button = ttk.Button(control_frame, text="Move the Plot", command=self.enable_pan_mode)
        move_button.pack(side=tk.TOP, fill=tk.X, pady=1)

    def create_shape_position_controls(self, parent):
        """Build the "Shape Position Comparison" section.

        Two identical groups, each with a 1-based sample and shape selector,
        are created, followed by a button that opens the side-by-side
        comparison window.

        Args:
            parent: Container widget the section is packed into.
        """
        section4 = ttk.LabelFrame(parent, text="Shape Position Comparison", padding=10)
        section4.pack(fill=tk.X, padx=5, pady=5)

        def build_group(title):
            """Create one comparison group and return its variables and spinboxes."""
            group = ttk.LabelFrame(section4, text=title, padding=5)
            group.pack(fill=tk.X, pady=2)
            row = ttk.Frame(group)
            row.pack(fill=tk.X)

            ttk.Label(row, text="Sample:").pack(side=tk.LEFT)
            sample_var = tk.IntVar(value=1)
            sample_spinbox = ttk.Spinbox(row, from_=1, to=1000, textvariable=sample_var, width=10)
            sample_spinbox.pack(side=tk.LEFT, padx=5)

            ttk.Label(row, text="Shape:").pack(side=tk.LEFT, padx=(10, 0))
            shape_var = tk.IntVar(value=1)
            shape_spinbox = ttk.Spinbox(row, from_=1, to=1000, textvariable=shape_var, width=10)
            shape_spinbox.pack(side=tk.LEFT, padx=5)

            return sample_var, sample_spinbox, shape_var, shape_spinbox

        # NOTE: no per-group button is wired to show_shape_position(); only the
        # combined comparison below is reachable from the UI.
        (self.pos_sample1_var, self.pos_sample1_spinbox,
         self.pos_shape1_var, self.pos_shape1_spinbox) = build_group("Comparison 1")
        (self.pos_sample2_var, self.pos_sample2_spinbox,
         self.pos_shape2_var, self.pos_shape2_spinbox) = build_group("Comparison 2")

        # Button that compares both selections in one window.
        ttk.Button(section4, text="Compare Two Plots",
                   command=self.compare_shape_positions).pack(fill=tk.X, pady=10)

    def create_right_visualization(self, parent):
        """Build the plot area: two vertically stacked panes, one figure in each.

        Args:
            parent: The paned window the plot pane is added to.
        """
        viz_frame = ttk.Frame(parent)
        parent.add(viz_frame, weight=3)

        # Vertical splitter with equally weighted upper and lower panes.
        stack = ttk.PanedWindow(viz_frame, orient=tk.VERTICAL)
        stack.pack(fill=tk.BOTH, expand=True)

        time_series_pane = ttk.LabelFrame(stack, text="Time Series Data Visualization")
        stack.add(time_series_pane, weight=1)

        npz_pane = ttk.LabelFrame(stack, text="NPZ File Visualization")
        stack.add(npz_pane, weight=1)

        # Embed one matplotlib figure in each pane.
        self.create_upper_plots(time_series_pane)
        self.create_lower_plots(npz_pane)

    def create_upper_plots(self, parent):
        """Create the upper matplotlib figure with its Tk canvas and navigation toolbar.

        Args:
            parent: Container widget that hosts the canvas and the toolbar.
        """
        self.upper_fig = Figure(figsize=(14, 8), dpi=100)
        self.upper_canvas = FigureCanvasTkAgg(self.upper_fig, parent)

        # Create the toolbar BEFORE packing the canvas: the toolbar packs itself
        # on construction, and pack() hands out space in packing order. The
        # canvas requests the full figure size (1400x800 px), so packing it
        # first can leave no room for the toolbar when the pane is smaller.
        self.upper_toolbar = NavigationToolbar2Tk(self.upper_canvas, parent)
        self.upper_toolbar.update()

        # The canvas is flexible in size, so it takes whatever space remains.
        self.upper_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def create_lower_plots(self, parent):
        """Create the lower matplotlib figure with its Tk canvas and navigation toolbar.

        Args:
            parent: Container widget that hosts the canvas and the toolbar.
        """
        self.lower_fig = Figure(figsize=(14, 8), dpi=100)
        self.lower_canvas = FigureCanvasTkAgg(self.lower_fig, parent)

        # Create the toolbar BEFORE packing the canvas: the toolbar packs itself
        # on construction, and pack() hands out space in packing order. The
        # canvas requests the full figure size (1400x800 px), so packing it
        # first can leave no room for the toolbar when the pane is smaller.
        self.lower_toolbar = NavigationToolbar2Tk(self.lower_canvas, parent)
        self.lower_toolbar.update()

        # The canvas is flexible in size, so it takes whatever space remains.
        self.lower_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def browse_npz_file(self):
        """Ask the user for an NPZ file and store the chosen path in the entry variable."""
        chosen = filedialog.askopenfilename(
            title="Choose NPZ File",
            filetypes=[("NPZ files", "*.npz"), ("All files", "*.*")],
        )
        # A falsy result means nothing was chosen; keep the current path then.
        if not chosen:
            return
        self.npz_path_var.set(chosen)

    def browse_npy_file(self):
        """Ask the user for an NPY file and store the chosen path in the entry variable."""
        chosen = filedialog.askopenfilename(
            title="Choose NPY File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")],
        )
        # A falsy result means nothing was chosen; keep the current path then.
        if not chosen:
            return
        self.npy_path_var.set(chosen)

    def load_data(self):
        """Load the selected NPZ and NPY files, validate them and refresh the UI.

        The NPZ archive must contain the 3-D arrays ``arr_0`` and ``arr_1``;
        the NPY file must contain a 3-D array. The arrays are read into local
        variables and validated first, so a failed load leaves any previously
        loaded data untouched. On success the control ranges, the status box
        and both figures are refreshed.

        Errors are reported through a message box and the status box; nothing
        is raised to the caller.
        """
        def set_info(text):
            """Replace the content of the read-only status box."""
            self.data_info_text.config(state=tk.NORMAL)
            self.data_info_text.delete("1.0", tk.END)
            self.data_info_text.insert(tk.END, text)
            self.data_info_text.config(state=tk.DISABLED)

        try:
            npz_path = self.npz_path_var.get()
            npy_path = self.npy_path_var.get()

            if not npz_path or not npy_path:
                messagebox.showerror("Error", "Please Choose NPZ and NPY File at first")
                return

            # np.load() on an .npz returns a lazily-read archive that keeps the
            # file open; read both arrays and close it right away.
            with np.load(npz_path) as npz_file:
                arr_0 = npz_file['arr_0']  # (sample, shape_number, shape_length)
                arr_1 = npz_file['arr_1']  # (sample, shape_number, VP)

            x_train = np.load(npy_path)  # (sample, length, dimension_number)

            # Validate before touching instance state.
            if len(arr_0.shape) != 3 or len(arr_1.shape) != 3:
                raise ValueError("The NPZ file data format is incorrect!")
            if len(x_train.shape) != 3:
                raise ValueError("The NPY file data format is incorrect!")

            # Commit the new data only once everything has been validated.
            # npz_data holds the already-extracted arrays because the archive
            # itself has been closed.
            self.npz_data = {'arr_0': arr_0, 'arr_1': arr_1}
            self.arr_0 = arr_0
            self.arr_1 = arr_1
            self.x_train = x_train

            # Adapt spinbox limits to the new data.
            self.update_control_ranges()

            info_lines = [
                "The data load is successful!",
                "",
                "NPZ File Information:",
                f"  arr_0 shape: {arr_0.shape}",
                f"  arr_1 shape: {arr_1.shape}",
                "",
                "NPY File Information:",
                f"  x_train shape: {x_train.shape}",
                "",
                "data range:",
                f"  Sample Number: {x_train.shape[0]}",
                f"  Time Length: {x_train.shape[1]}",
                f"  Dimension Number: {x_train.shape[2]}",
                f"  Shape Number: {arr_0.shape[1]}",
            ]
            set_info("\n".join(info_lines))

            # Draw the initial plots.
            self.update_plots()

            messagebox.showinfo("Successful", "The data load is successful！")

        except Exception as e:
            # UI boundary: report any failure to the user instead of crashing.
            messagebox.showerror("Error", f"There was an error loading the data: {str(e)}")
            set_info(f"loading failed: {str(e)}")

    def update_control_ranges(self):
        """更新控件的范围"""
        if self.x_train is not None:
            sample_count = self.x_train.shape[0]
            length = self.x_train.shape[1]
            dimension_count = self.x_train.shape[2]

            # 更新下半部分控件
            self.lower_sample_spinbox.config(to=sample_count)

            # 更新位置查看控件
            self.pos_sample1_spinbox.config(to=sample_count)
            self.pos_sample2_spinbox.config(to=sample_count)

            # 更新序列控制控件
            for controls in self.sequence_controls:
                controls['sample_spinbox'].config(to=sample_count)
                controls['dimension_spinbox'].config(to=dimension_count)
                controls['start_spinbox'].config(to=length - 1)
                controls['length_spinbox'].config(to=length)

        if self.arr_0 is not None:
            shape_number_count = self.arr_0.shape[1]
            # 更新shape相关控件
            self.shape_number_spinbox.config(to=shape_number_count)
            self.pos_shape1_spinbox.config(to=shape_number_count)
            self.pos_shape2_spinbox.config(to=shape_number_count)

    def update_plots(self):
        """Redraw both figures, warning the user when no data has been loaded yet.

        Any exception raised while drawing is reported in an error dialog
        instead of being propagated.
        """
        data_ready = self.x_train is not None and self.arr_0 is not None
        if not data_ready:
            messagebox.showwarning("Warning", "Please load the data file first!")
            return

        try:
            # Upper figure first, then the lower one.
            for redraw in (self.update_upper_plots, self.update_lower_plots):
                redraw()
        except Exception as exc:
            messagebox.showerror("Error", f"An error occurred while updating the drawing: {str(exc)}")

    def update_upper_plots(self):
        """Redraw the upper figure with one line plot per configured sequence.

        For each sequence the 1-based sample and dimension numbers are
        converted to 0-based indices; out-of-range values fall back to the
        first sample or dimension. The start time is clamped into the series
        and the length is clamped so the window stays inside the series and
        covers at least one point. The x axis shows absolute time indices so
        that it agrees with the range printed in the title.
        """
        self.upper_fig.clear()

        plot_count = self.plot_count_var.get()

        # Grid layout: a single row for up to three plots, two rows otherwise
        # (1 -> 1x1, 2 -> 1x2, 3 -> 1x3, 4 -> 2x2, 6 -> 2x3).
        rows = 1 if plot_count <= 3 else 2
        cols = max(1, -(-plot_count // rows))  # ceiling division

        n_samples = self.x_train.shape[0]
        n_steps = self.x_train.shape[1]
        n_dims = self.x_train.shape[2]

        for i, controls in enumerate(self.sequence_controls[:plot_count]):
            sample_idx = controls['sample'].get() - 1  # convert to 0-based
            dimension_idx = controls['dimension'].get() - 1  # convert to 0-based
            start_time = controls['start_time'].get()
            seq_length = controls['length'].get()

            # Fall back to the first sample / dimension for out-of-range input.
            if not 0 <= sample_idx < n_samples:
                sample_idx = 0
            if not 0 <= dimension_idx < n_dims:
                dimension_idx = 0

            # Keep the window inside the series: a start beyond the end or a
            # non-positive length would otherwise produce an empty slice.
            start_time = min(max(start_time, 0), max(n_steps - 1, 0))
            seq_length = min(max(seq_length, 1), n_steps - start_time)

            end_time = start_time + seq_length
            data_to_plot = self.x_train[sample_idx, start_time:end_time, dimension_idx]

            ax = self.upper_fig.add_subplot(rows, cols, i + 1)
            # Plot against absolute time indices, not positions within the window.
            ax.plot(range(start_time, end_time), data_to_plot, linewidth=2, label=f'Seq {i + 1}')
            ax.set_title(
                f'sequence {i + 1}: Sample {sample_idx + 1}, Dim {dimension_idx + 1}\ntime {start_time}-{end_time}')
            ax.set_xlabel('Time')
            ax.set_ylabel('Value')
            ax.grid(True, alpha=0.3)
            ax.legend()

        self.upper_fig.tight_layout()
        self.upper_canvas.draw()

    def update_lower_plots(self):
        """Redraw the lower figure for the selected sample and shape.

        The left axes show the chosen ``arr_0`` row as a line; the right axes
        show the matching ``arr_1`` row as an annotated bar chart. Out-of-range
        selections fall back to the first sample or shape.
        """
        self.lower_fig.clear()

        # Spinbox values are 1-based; array indices are 0-based.
        sample_idx = self.lower_sample_var.get() - 1
        shape_number = self.shape_number_var.get() - 1

        n_samples, n_shapes = self.arr_0.shape[0], self.arr_0.shape[1]
        if not (0 <= sample_idx < n_samples):
            sample_idx = 0
        if not (0 <= shape_number < n_shapes):
            shape_number = 0

        line_ax = self.lower_fig.add_subplot(1, 2, 1)
        bar_ax = self.lower_fig.add_subplot(1, 2, 2)

        # Left: the shape values as a line plot.
        shape_values = self.arr_0[sample_idx, shape_number, :]
        line_ax.plot(shape_values, linewidth=2, color='blue')
        line_ax.set_title(f'arr_0: Sample {sample_idx + 1}, Shape {shape_number + 1}', fontsize=12)
        line_ax.set_xlabel('Shape Length')
        line_ax.set_ylabel('Value')
        line_ax.grid(True, alpha=0.3)

        # Right: the VP values as bars.
        vp_values = self.arr_1[sample_idx, shape_number, :]
        bar_container = bar_ax.bar(range(len(vp_values)), vp_values, color='orange', alpha=0.7)
        bar_ax.set_title(f'arr_1 (VP): Sample {sample_idx + 1}, Shape {shape_number + 1}', fontsize=12)
        bar_ax.set_xlabel('VP Index')
        bar_ax.set_ylabel('VP Value')
        bar_ax.grid(True, alpha=0.3)

        # Print each value centred above its bar.
        for rect, value in zip(bar_container, vp_values):
            x_center = rect.get_x() + rect.get_width() / 2.
            bar_ax.text(x_center, rect.get_height(), f'{value:.2f}',
                        ha='center', va='bottom', fontsize=8)

        self.lower_fig.tight_layout()
        self.lower_canvas.draw()

    def show_shape_position(self, group_id):
        """Open a window showing where one shape lies inside its time series.

        The sample and shape are read from comparison group 1 when
        ``group_id`` is 1 and from group 2 otherwise. The first four values of
        the VP row in ``arr_1`` are read as length, start time, end time and
        label.

        Args:
            group_id: 1 for the first comparison group, anything else for the second.
        """
        if self.x_train is None or self.arr_0 is None or self.arr_1 is None:
            messagebox.showwarning("Warning", "Please load the data file first!")
            return

        try:
            if group_id == 1:
                sample_var, shape_var = self.pos_sample1_var, self.pos_shape1_var
            else:
                sample_var, shape_var = self.pos_sample2_var, self.pos_shape2_var
            # Spinbox values are 1-based; array indices are 0-based.
            sample_idx = sample_var.get() - 1
            shape_idx = shape_var.get() - 1

            # The sample index is used on arr_0, arr_1 and x_train, so it must
            # be valid for all three; checking here avoids opening a window
            # that then fails halfway through drawing.
            sample_limit = min(self.arr_0.shape[0], self.arr_1.shape[0], self.x_train.shape[0])
            shape_limit = min(self.arr_0.shape[1], self.arr_1.shape[1])
            if not 0 <= sample_idx < sample_limit:
                messagebox.showerror("Error", "The sample number is out of range")
                return
            if not 0 <= shape_idx < shape_limit:
                messagebox.showerror("Error", "The shape number is out of range")
                return

            # Read the VP row: (length, start, end, label, ...).
            vp_data = self.arr_1[sample_idx, shape_idx, :]
            if len(vp_data) < 4:
                messagebox.showerror("Error", "The VP data is not formatted correctly and should contain at least 4 elements.")
                return

            length = int(vp_data[0])
            start_time = int(vp_data[1])
            end_time = int(vp_data[2])
            label = vp_data[3]

            self.show_position_window(sample_idx, shape_idx, length, start_time, end_time, label, group_id)

        except Exception as e:
            # UI boundary: report the failure instead of crashing the application.
            messagebox.showerror("Error", f"An error occurred while displaying the Shape position: {str(e)}")

    def show_position_window(self, sample_idx, shape_idx, length, start_time, end_time, label, group_id):
        """Open a window showing a shape's time segment and its VP values.

        The upper axes show dimension 1 of the sample's time series with the
        segment ``[start_time, end_time)`` highlighted when it lies inside the
        series. The lower axes show the VP row of ``arr_1`` as annotated bars.

        Args:
            sample_idx: 0-based sample index.
            shape_idx: 0-based shape index.
            length: Shape length taken from the VP row (shown in the title).
            start_time: Start index of the segment within the time series.
            end_time: End index (exclusive) of the segment.
            label: Label value taken from the VP row (shown in the title).
            group_id: Comparison group number shown in the window title.
        """
        pos_window = tk.Toplevel(self.root)
        pos_window.title(f"Shape Position {group_id} - Sample {sample_idx + 1}, Shape {shape_idx + 1}")
        pos_window.geometry("900x700")

        fig = Figure(figsize=(12, 10), dpi=100)
        canvas = FigureCanvasTkAgg(fig, pos_window)

        # Create the toolbar before packing the canvas: the toolbar packs
        # itself, and the canvas requests more space (1200x1000 px) than the
        # window offers, which would otherwise leave no room for the toolbar.
        toolbar = NavigationToolbar2Tk(canvas, pos_window)
        toolbar.update()
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Two stacked axes: time series on top, VP bars below. (A third axes
        # for the raw shape data is disabled, so no row is reserved for it.)
        ax_series = fig.add_subplot(2, 1, 1)
        ax_vp = fig.add_subplot(2, 1, 2)

        # Always draw the series so the window is informative even when the
        # VP segment is invalid; the legend therefore always has an entry.
        time_series = self.x_train[sample_idx, :, 0]  # first dimension only
        ax_series.plot(time_series, linewidth=2, color='green', label='Time Series (Dim 1)')

        # Highlight the segment only when it is non-empty and inside the series.
        if 0 <= start_time < end_time <= len(time_series):
            ax_series.axvspan(start_time, end_time, alpha=0.3, color='red',
                              label='Corresponding Time Segment')
            ax_series.plot(range(start_time, end_time), time_series[start_time:end_time],
                           linewidth=3, color='red', alpha=0.8)

        ax_series.set_title('Corresponding Time Series Data (Dimension 1)')
        ax_series.set_xlabel('Time Index')
        ax_series.set_ylabel('Value')
        ax_series.legend()
        ax_series.grid(True, alpha=0.3)

        # VP values as bars, each annotated with its value.
        vp_data = self.arr_1[sample_idx, shape_idx, :]
        ax_vp.bar(range(len(vp_data)), vp_data, color='purple', alpha=0.7)
        ax_vp.set_title(f'VP Data: length={length}, start={start_time}, end={end_time}, label={label:.3f}')
        ax_vp.set_xlabel('VP Index')
        ax_vp.set_ylabel('VP Value')
        ax_vp.grid(True, alpha=0.3)
        for i, val in enumerate(vp_data):
            ax_vp.text(i, val, f'{val:.2f}', ha='center', va='bottom', fontsize=9)

        fig.tight_layout()
        canvas.draw()

    def compare_shape_positions(self):
        """Validate both comparison selections and open the comparison window.

        The sample and shape numbers of both groups are converted to 0-based
        indices and checked against the loaded arrays. The first four values
        of each VP row in ``arr_1`` are read as length, start, end and label.
        Problems are reported through message boxes; nothing is raised.
        """
        if self.x_train is None or self.arr_0 is None or self.arr_1 is None:
            messagebox.showwarning("Warning", "Please load the data file first.")
            return

        try:
            # Spinbox values are 1-based; array indices are 0-based.
            sample1_idx = self.pos_sample1_var.get() - 1
            shape1_idx = self.pos_shape1_var.get() - 1
            sample2_idx = self.pos_sample2_var.get() - 1
            shape2_idx = self.pos_shape2_var.get() - 1

            # Indices are used on arr_0, arr_1 and x_train, so they must be
            # valid for every array they index; checking here avoids opening
            # a window that then fails halfway through drawing.
            sample_limit = min(self.arr_0.shape[0], self.arr_1.shape[0], self.x_train.shape[0])
            shape_limit = min(self.arr_0.shape[1], self.arr_1.shape[1])
            checks = [
                (sample1_idx, "Sample1", sample_limit),
                (sample2_idx, "Sample2", sample_limit),
                (shape1_idx, "Shape1", shape_limit),
                (shape2_idx, "Shape2", shape_limit),
            ]
            for idx, name, limit in checks:
                if not 0 <= idx < limit:
                    messagebox.showerror("Error", f"{name}: the serial number is out of range")
                    return

            vp1_data = self.arr_1[sample1_idx, shape1_idx, :]
            vp2_data = self.arr_1[sample2_idx, shape2_idx, :]

            if len(vp1_data) < 4 or len(vp2_data) < 4:
                messagebox.showerror("Error", "The VP data format is incorrect.")
                return

            # Parse the VP rows: (length, start, end, label, ...).
            length1, start1, end1, label1 = int(vp1_data[0]), int(vp1_data[1]), int(vp1_data[2]), vp1_data[3]
            length2, start2, end2, label2 = int(vp2_data[0]), int(vp2_data[1]), int(vp2_data[2]), vp2_data[3]

            self.show_comparison_window(
                sample1_idx, shape1_idx, length1, start1, end1, label1,
                sample2_idx, shape2_idx, length2, start2, end2, label2
            )

        except Exception as e:
            # UI boundary: report the failure instead of crashing the application.
            messagebox.showerror("Error", f"Error comparing Shape position: {str(e)}")

    def show_comparison_window(self, sample1_idx, shape1_idx, length1, start1, end1, label1,
                               sample2_idx, shape2_idx, length2, start2, end2, label2):
        """Open a window comparing the time segments and VP values of two shapes.

        Layout (2x2 grid): the left column shows dimension 1 of each sample's
        time series with the shape's segment highlighted; the right column
        shows the matching VP rows of ``arr_1`` as annotated bars. A read-only
        text box below the figure summarises both segments and whether their
        time ranges overlap.

        Args:
            sample1_idx, sample2_idx: 0-based sample indices.
            shape1_idx, shape2_idx: 0-based shape indices.
            length1, length2: Shape lengths from the VP rows (shown in titles).
            start1, start2: Segment start indices within the time series.
            end1, end2: Segment end indices (exclusive).
            label1, label2: Label values from the VP rows.
        """
        comp_window = tk.Toplevel(self.root)
        comp_window.title("Shape Position Comparison")
        comp_window.geometry("1200x900")

        fig = Figure(figsize=(15, 12), dpi=100)
        canvas = FigureCanvasTkAgg(fig, comp_window)

        # pack() hands out space in packing order and the canvas requests more
        # (1500x1200 px) than the window offers. The toolbar (which packs
        # itself on construction) and the info box are therefore placed first,
        # and the flexible canvas last.
        toolbar = NavigationToolbar2Tk(canvas, comp_window)
        toolbar.update()

        info_frame = ttk.Frame(comp_window)
        info_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=5)
        info_text = tk.Text(info_frame, height=4, width=100)
        info_text.pack(fill=tk.X)

        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        def draw_series(ax, number, sample_idx, start, end):
            """Plot one time series (dimension 1) and highlight [start, end)."""
            series = self.x_train[sample_idx, :, 0]
            # Always draw the series so the legend has an entry even when the
            # VP segment is invalid.
            ax.plot(series, linewidth=2, color='green', label=f'Time Series {number}')
            # Highlight only a non-empty segment that lies inside the series.
            if 0 <= start < end <= len(series):
                ax.axvspan(start, end, alpha=0.3, color='red', label='Corresponding Segment')
                ax.plot(range(start, end), series[start:end], linewidth=3, color='red', alpha=0.8)
            ax.set_title(f'Time Series {number}: Sample {sample_idx + 1}, Dim 1')
            ax.set_xlabel('Time Index')
            ax.set_ylabel('Value')
            ax.legend()
            ax.grid(True, alpha=0.3)

        def draw_vp(ax, number, sample_idx, shape_idx, length, start, end, label):
            """Plot one VP row as bars, each annotated with its value."""
            vp_data = self.arr_1[sample_idx, shape_idx, :]
            ax.bar(range(len(vp_data)), vp_data, color='purple', alpha=0.7)
            ax.set_title(f'VP Data {number}: len={length}, start={start}, end={end}, label={label:.3f}')
            ax.set_xlabel('VP Index')
            ax.set_ylabel('VP Value')
            ax.grid(True, alpha=0.3)
            for i, val in enumerate(vp_data):
                ax.text(i, val, f'{val:.2f}', ha='center', va='bottom', fontsize=8)

        # Left column: time series; right column: VP bars.
        draw_series(fig.add_subplot(2, 2, 1), 1, sample1_idx, start1, end1)
        draw_series(fig.add_subplot(2, 2, 3), 2, sample2_idx, start2, end2)
        draw_vp(fig.add_subplot(2, 2, 2), 1, sample1_idx, shape1_idx, length1, start1, end1, label1)
        draw_vp(fig.add_subplot(2, 2, 4), 2, sample2_idx, shape2_idx, length2, start2, end2, label2)

        fig.suptitle('Shape position comparison analysis', fontsize=16, y=0.95)

        fig.tight_layout()
        canvas.draw()

        # Textual summary; two half-open ranges overlap when the later start
        # lies before the earlier end.
        overlap = 'Yes' if max(start1, start2) < min(end1, end2) else 'No'
        comparison_info = "\n".join([
            "Compare the results of the analysis:",
            f"Group1: Sample {sample1_idx + 1}, Shape {shape1_idx + 1} - time period: {start1}-{end1} (length: {end1 - start1}), label: {label1:.3f}",
            f"Group2: Sample {sample2_idx + 1}, Shape {shape2_idx + 1} - time period: {start2}-{end2} (length: {end2 - start2}), label: {label2:.3f}",
            f"time overlapping: {overlap}",
        ])

        info_text.insert(tk.END, comparison_info)
        info_text.config(state=tk.DISABLED)

    def reset_view(self):
        """重置视图"""
        if hasattr(self, 'upper_toolbar'):
            self.upper_toolbar.home()
        if hasattr(self, 'lower_toolbar'):
            self.lower_toolbar.home()
        self.current_zoom = 1.0
        self.pan_offset = [0, 0]

    def zoom_in(self):
        """放大"""
        self.current_zoom *= 1.2
        if hasattr(self, 'upper_toolbar'):
            self.upper_toolbar.zoom()
        if hasattr(self, 'lower_toolbar'):
            self.lower_toolbar.zoom()

    def zoom_out(self):
        """缩小"""
        self.current_zoom /= 1.2
        if hasattr(self, 'upper_toolbar'):
            self.upper_toolbar.back()
        if hasattr(self, 'lower_toolbar'):
            self.lower_toolbar.back()

    def enable_pan_mode(self):
        """Switch both figures into pan mode and tell the user how to use it.

        The toolbar's ``pan()`` is a toggle, so calling it when pan is already
        active would switch panning off again while the dialog still reports
        it as enabled. ``pan()`` is therefore only called for toolbars that
        are not already in pan mode, which makes this method idempotent.
        """
        for toolbar_name in ('upper_toolbar', 'lower_toolbar'):
            if not hasattr(self, toolbar_name):
                continue
            toolbar = getattr(self, toolbar_name)
            # The toolbar reports its active tool in `mode`; the value
            # compares equal to 'pan/zoom' while pan is active.
            if getattr(toolbar, 'mode', None) != 'pan/zoom':
                toolbar.pan()
        messagebox.showinfo("Move", "Moving mode is enabled, tap and drag the image to move it！")

def main():
    """Create the Tk root window, apply a ttk theme, build the app and run the event loop."""
    root = tk.Tk()

    # Theming is purely cosmetic: try the preferred themes in order and fall
    # back to the platform default if Tk rejects the request. Only Tcl errors
    # are swallowed so that real bugs and Ctrl-C are not hidden.
    try:
        style = ttk.Style()
        available_themes = style.theme_names()
        for theme in ('clam', 'alt'):
            if theme in available_themes:
                style.theme_use(theme)
                break
    except tk.TclError:
        pass

    # Keep a reference to the application object for the lifetime of the loop.
    app = DataVisualizationApp(root)

    def on_closing():
        """Ask for confirmation before destroying the main window."""
        if messagebox.askokcancel("exit", "Are you sure you want to quit the app？"):
            root.destroy()

    root.protocol("WM_DELETE_WINDOW", on_closing)

    # Blocks until the window is closed.
    root.mainloop()


if __name__ == "__main__":
    main()

  self.upper_fig.tight_layout()
  self.lower_fig.tight_layout()


In [None]:
#UI2

In [7]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from matplotlib.figure import Figure
import matplotlib.patches as patches
from matplotlib.colors import ListedColormap
import seaborn as sns
import os

class AdvancedVisualizationApp:
    def __init__(self, root):
        self.root = root
        self.root.title("Advanced Visualization & Analysis")
        self.root.geometry("1800x1000")
        
        # 数据存储
        self.heatmap_data = None  # (sample, shape_number, shape_number)
        self.shape_data = None  # NPZ文件数据
        self.arr_0 = None  # (sample, shape_number, shape_length)
        self.arr_1 = None  # (sample, shape_number, VP)
        self.x_train = None  # (sample, length, dimension_number)
        self.attention_data = None  # (sample_number, shape_number, value_number)
        self.sorted_attention_data = None  # 排序后的数据
        self.original_indices = None  # 原始索引
        self.all_samples_indices = []  # 所有sample的排序索引
        
        # Heatmap相关变量
        self.heatmap_colorbar = None
        self.current_heatmap_ax = None
        self.shape_plot_ax = None
        
        # 当前选中的shape信息
        self.selected_shapes = {'shape1': None, 'shape2': None}
        self.current_click_count = 0
        
        # Attention plot相关
        self.attention_annotations = []  # 存储注释对象
        
        # 创建主要布局
        self.create_main_layout()
        
    def create_main_layout(self):
        """Build the top-level horizontal split and fill both panes."""
        splitter = ttk.PanedWindow(self.root, orient=tk.HORIZONTAL)
        splitter.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # The menu pane is built first, then the visualization pane.
        for build_pane in (self.create_left_menu, self.create_right_visualization):
            build_pane(splitter)
        
    def create_left_menu(self, parent):
        """Build the scrollable left-hand menu and populate it with both control panels.

        Args:
            parent: The paned window the menu frame is added to.
        """
        menu_frame = ttk.Frame(parent, width=450)
        parent.add(menu_frame, weight=1)

        # A Canvas hosting an inner frame, driven by a vertical scrollbar,
        # gives the menu a scrollable viewport.
        viewport = tk.Canvas(menu_frame)
        vbar = ttk.Scrollbar(menu_frame, orient="vertical", command=viewport.yview)
        content = ttk.Frame(viewport)

        def refresh_scroll_region(_event):
            """Re-fit the scroll region whenever the inner frame is reconfigured."""
            viewport.configure(scrollregion=viewport.bbox("all"))

        content.bind("<Configure>", refresh_scroll_region)

        viewport.create_window((0, 0), window=content, anchor="nw")
        viewport.configure(yscrollcommand=vbar.set)

        viewport.pack(side="left", fill="both", expand=True)
        vbar.pack(side="right", fill="y")

        # Heatmap controls first, attention controls after them.
        self.create_heatmap_controls(content)
        self.create_attention_controls(content)
        
    def create_heatmap_controls(self, parent):
        """Build the "Heatmap Controls" panel: file loaders, sample/range spinboxes and actions.

        Args:
            parent: The container the panel is packed into.
        """
        panel = ttk.LabelFrame(parent, text="Heatmap Controls", padding=10)
        panel.pack(fill=tk.X, padx=5, pady=5)

        def add_loader(caption, handler, placeholder):
            """Pack a load button followed by a red status label; return the label."""
            ttk.Button(panel, text=caption, command=handler).pack(fill=tk.X, pady=2)
            status = ttk.Label(panel, text=placeholder, foreground="red")
            status.pack(pady=2)
            return status

        def add_spin_row(container, caption, initial, pady):
            """Pack a "caption + spinbox" row; return ``(variable, spinbox)``."""
            row = ttk.Frame(container)
            row.pack(fill=tk.X, pady=pady)
            ttk.Label(row, text=caption).pack(side=tk.LEFT)
            variable = tk.IntVar(value=initial)
            spinbox = ttk.Spinbox(row, from_=1, to=1000,
                                  textvariable=variable, width=15)
            spinbox.pack(side=tk.RIGHT)
            return variable, spinbox

        # File loaders, each with its own status line.
        self.heatmap_info_label = add_loader(
            "Load Heatmap Data (.npy)", self.load_heatmap_data, "No heatmap data loaded")
        self.shape_info_label = add_loader(
            "Load Shape Data (.npz)", self.load_shape_data, "No shape data loaded")
        self.xtrain_info_label = add_loader(
            "Load X_train Data (.npy)", self.load_xtrain_data, "No X_train data loaded")

        # Sample number selector.
        self.heatmap_sample_var, self.heatmap_sample_spinbox = add_spin_row(
            panel, "Sample Number:", 1, 5)

        # Shape range selectors (start / end).
        range_frame = ttk.LabelFrame(panel, text="Shape Number Range", padding=5)
        range_frame.pack(fill=tk.X, pady=5)
        self.heatmap_start_var, self.heatmap_start_spinbox = add_spin_row(
            range_frame, "Start Shape:", 1, 2)
        self.heatmap_end_var, self.heatmap_end_spinbox = add_spin_row(
            range_frame, "End Shape:", 10, 2)

        # Redraw button.
        ttk.Button(panel, text="Update Heatmap", command=self.update_heatmap,
                   style="Accent.TButton").pack(fill=tk.X, pady=10)

        # Status line describing the current click selection.
        self.click_info_label = ttk.Label(panel, text="Click two shapes on heatmap",
                                          foreground="blue")
        self.click_info_label.pack(pady=2)

        # Button that clears the current selection.
        ttk.Button(panel, text="Reset Shape Selection",
                   command=self.reset_shape_selection).pack(fill=tk.X, pady=2)
        
    def create_attention_controls(self, parent):
        """Build the "Attention Controls" panel: loader, sample/count spinboxes and actions.

        Args:
            parent: The container the panel is packed into.
        """
        panel = ttk.LabelFrame(parent, text="Attention Controls", padding=10)
        panel.pack(fill=tk.X, padx=5, pady=5)

        def add_spin_row(caption, initial):
            """Pack a "caption + spinbox" row into the panel; return ``(variable, spinbox)``."""
            row = ttk.Frame(panel)
            row.pack(fill=tk.X, pady=5)
            ttk.Label(row, text=caption).pack(side=tk.LEFT)
            variable = tk.IntVar(value=initial)
            spinbox = ttk.Spinbox(row, from_=1, to=1000,
                                  textvariable=variable, width=15)
            spinbox.pack(side=tk.RIGHT)
            return variable, spinbox

        # Attention file loader and its status line.
        ttk.Button(panel, text="Load Attention Data (.npy)",
                   command=self.load_attention_data).pack(fill=tk.X, pady=2)
        self.attention_info_label = ttk.Label(panel, text="No attention data loaded",
                                              foreground="red")
        self.attention_info_label.pack(pady=2)

        # Sample selector, then number-of-shapes selector.
        self.attention_sample_var, self.attention_sample_spinbox = add_spin_row(
            "Sample Number:", 1)
        self.attention_count_var, self.attention_count_spinbox = add_spin_row(
            "Number of Shapes:", 10)

        # Redraw button.
        ttk.Button(panel, text="Update Attention Plot", command=self.update_attention_plot,
                   style="Accent.TButton").pack(fill=tk.X, pady=10)

        # Export button for the sorted indices.
        ttk.Button(panel, text="Export TopxShapes",
                   command=self.download_indices).pack(fill=tk.X, pady=2)
        
    def create_right_visualization(self, parent):
        """Build the right-hand plot area: heatmap and sequences on top, attention below.

        Args:
            parent: The paned window the visualization frame is added to.
        """
        container = ttk.Frame(parent)
        parent.add(container, weight=3)

        # Vertical split: upper row and lower row.
        rows = ttk.PanedWindow(container, orient=tk.VERTICAL)
        rows.pack(fill=tk.BOTH, expand=True)

        # Upper row is itself split horizontally into two panes.
        top_row = ttk.Frame(rows)
        rows.add(top_row, weight=1)
        columns = ttk.PanedWindow(top_row, orient=tk.HORIZONTAL)
        columns.pack(fill=tk.BOTH, expand=True)

        heatmap_box = ttk.LabelFrame(columns, text="Heatmap Visualization")
        columns.add(heatmap_box, weight=1)

        sequence_box = ttk.LabelFrame(columns, text="Selected Shapes Sequence Visualization")
        columns.add(sequence_box, weight=1)

        # Lower row hosts the attention plot.
        attention_box = ttk.LabelFrame(rows, text="Attention Values Visualization")
        rows.add(attention_box, weight=1)

        # Create the matplotlib figures inside their frames.
        plot_builders = (
            (self.create_heatmap_plot, heatmap_box),
            (self.create_sequence_plot, sequence_box),
            (self.create_attention_plot, attention_box),
        )
        for build_plot, box in plot_builders:
            build_plot(box)
        
    def create_heatmap_plot(self, parent):
        """Create the heatmap figure, its Tk canvas and toolbar, and hook up click handling.

        Args:
            parent: The frame that hosts the canvas and the toolbar.
        """
        figure = self.heatmap_fig = Figure(figsize=(8, 6), dpi=100)
        canvas = self.heatmap_canvas = FigureCanvasTkAgg(figure, parent)
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Navigation toolbar for the heatmap canvas.
        toolbar = self.heatmap_toolbar = NavigationToolbar2Tk(canvas, parent)
        toolbar.update()

        # Mouse presses on the canvas are routed to on_heatmap_click.
        canvas.mpl_connect('button_press_event', self.on_heatmap_click)
        
    def create_sequence_plot(self, parent):
        """Create the figure, Tk canvas and toolbar used to display the selected sequences.

        Args:
            parent: The frame that hosts the canvas and the toolbar.
        """
        figure = self.sequence_fig = Figure(figsize=(8, 6), dpi=100)
        canvas = self.sequence_canvas = FigureCanvasTkAgg(figure, parent)
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Navigation toolbar for the sequence canvas.
        toolbar = self.sequence_toolbar = NavigationToolbar2Tk(canvas, parent)
        toolbar.update()
        
    def create_attention_plot(self, parent):
        """Create the attention figure, its Tk canvas and toolbar, and hook up hover handling.

        Args:
            parent: The frame that hosts the canvas and the toolbar.
        """
        figure = self.attention_fig = Figure(figsize=(14, 6), dpi=100)
        canvas = self.attention_canvas = FigureCanvasTkAgg(figure, parent)
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Navigation toolbar for the attention canvas.
        toolbar = self.attention_toolbar = NavigationToolbar2Tk(canvas, parent)
        toolbar.update()

        # Mouse movement and axes-leave events are routed to the hover handlers.
        for event_name, handler in (
            ('motion_notify_event', self.on_attention_hover),
            ('axes_leave_event', self.on_attention_leave),
        ):
            canvas.mpl_connect(event_name, handler)
        
    def load_heatmap_data(self):
        """Ask the user for a ``.npy`` file and load it as the heatmap data.

        The file must contain a 3-D array whose second and third dimensions
        are equal. The array is validated *before* it replaces
        ``self.heatmap_data``, so a rejected file never leaves the application
        working on unusable data; any previously loaded heatmap stays in place.
        On success the spinbox limits and the status label are updated.
        """
        filename = filedialog.askopenfilename(
            title="Select Heatmap Data File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog was cancelled

        try:
            loaded = np.load(filename)

            # Validate the layout before touching any application state.
            if loaded.ndim != 3:
                raise ValueError("Data should be 3D (sample, shape_number, shape_number)")
            if loaded.shape[1] != loaded.shape[2]:
                raise ValueError("Second and third dimensions should be equal")

            # Commit only once the array is known to be usable.
            self.heatmap_data = loaded
            sample_count, shape_count = loaded.shape[:2]

            # Limit the controls to what the data actually contains.
            self.heatmap_sample_spinbox.config(to=sample_count)
            self.heatmap_start_spinbox.config(to=shape_count)
            self.heatmap_end_spinbox.config(to=shape_count)
            self.heatmap_end_var.set(min(20, shape_count))

            info_text = f"Heatmap loaded: {loaded.shape}"
            self.heatmap_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "Heatmap data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading heatmap data: {str(e)}")
            self.heatmap_info_label.config(text="Load failed", foreground="red")
    
    def load_shape_data(self):
        """Ask the user for a ``.npz`` file and load its ``arr_0`` / ``arr_1`` arrays.

        Both arrays must be 3-D. They are read and validated before any
        attribute is updated, so a missing key or a wrong rank leaves the
        previously loaded shape data untouched. The archive is opened in a
        ``with`` block so its file handle is closed as soon as the arrays have
        been read; ``self.shape_data`` therefore holds a plain dict of the two
        arrays rather than an open archive object.
        """
        filename = filedialog.askopenfilename(
            title="Select Shape Data File",
            filetypes=[("NPZ files", "*.npz"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog was cancelled

        try:
            # Arrays are materialised on access, so both must be read while
            # the archive is still open.
            with np.load(filename) as archive:
                arr_0 = archive['arr_0']  # (sample, shape_number, shape_length)
                arr_1 = archive['arr_1']  # (sample, shape_number, VP)

            if arr_0.ndim != 3 or arr_1.ndim != 3:
                raise ValueError("Arrays should be 3D")

            # Commit all three attributes together once validation has passed.
            self.shape_data = {'arr_0': arr_0, 'arr_1': arr_1}
            self.arr_0 = arr_0
            self.arr_1 = arr_1

            info_text = f"Shape loaded: arr_0{arr_0.shape}, arr_1{arr_1.shape}"
            self.shape_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "Shape data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading shape data: {str(e)}")
            self.shape_info_label.config(text="Load failed", foreground="red")
    
    def load_xtrain_data(self):
        """Ask the user for a ``.npy`` file and load it as the X_train data.

        The file must contain a 3-D array. The array is validated before it
        replaces ``self.x_train``, so a rejected file leaves any previously
        loaded X_train data in place.
        """
        filename = filedialog.askopenfilename(
            title="Select X_train Data File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog was cancelled

        try:
            loaded = np.load(filename)

            # Validate the layout before touching any application state.
            if loaded.ndim != 3:
                raise ValueError("X_train should be 3D (sample, length, dimension_number)")

            # Commit only once the array is known to be usable.
            self.x_train = loaded

            info_text = f"X_train loaded: {loaded.shape}"
            self.xtrain_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "X_train data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading X_train data: {str(e)}")
            self.xtrain_info_label.config(text="Load failed", foreground="red")
    
    def load_attention_data(self):
        """Ask the user for a ``.npy`` file and load it as the attention data.

        The file must contain a 3-D array. The array is validated before it
        replaces ``self.attention_data``; otherwise a rejected file would leave
        the raw data out of step with the sorted copies derived from the
        previous file. After a successful load the sorted views are rebuilt
        and the spinbox limits and the status label are updated.
        """
        filename = filedialog.askopenfilename(
            title="Select Attention Data File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog was cancelled

        try:
            loaded = np.load(filename)

            # Validate the layout before touching any application state.
            if loaded.ndim != 3:
                raise ValueError("Data should be 3D (sample_number, shape_number, value_number)")

            # Commit, then rebuild the per-sample sorted data and index tables.
            self.attention_data = loaded
            self.process_attention_data()

            sample_count, shape_count = loaded.shape[:2]

            # Limit the controls to what the data actually contains.
            self.attention_sample_spinbox.config(to=sample_count)
            self.attention_count_spinbox.config(to=shape_count)
            self.attention_count_var.set(min(15, shape_count))

            info_text = f"Attention loaded: {loaded.shape}"
            self.attention_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "Attention data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading attention data: {str(e)}")
            self.attention_info_label.config(text="Load failed", foreground="red")
    
    def process_attention_data(self):
        """处理attention数据，按sample排序并保留原始索引"""
        if self.attention_data is None:
            return
        
        sample_count, shape_count, value_count = self.attention_data.shape
        
        # 为每个sample创建排序数据和索引
        self.sorted_attention_data = np.zeros_like(self.attention_data)
        self.original_indices = np.zeros((sample_count, shape_count), dtype=int)
        self.all_samples_indices = []  # 存储所有sample的索引列表
        
        for sample_idx in range(sample_count):
            # 计算每个shape的平均attention值
            mean_values = np.mean(self.attention_data[sample_idx], axis=1)
            
            # 获取从大到小的排序索引
            sorted_indices = np.argsort(mean_values)[::-1]
            
            # 存储排序后的数据和原始索引
            self.sorted_attention_data[sample_idx] = self.attention_data[sample_idx][sorted_indices]
            self.original_indices[sample_idx] = sorted_indices
            
            # 添加到所有sample的索引列表
            self.all_samples_indices.append([sample_idx + 1, sorted_indices.tolist()])
    
    def update_heatmap(self):
        """更新Heatmap显示"""
        if self.heatmap_data is None:
            messagebox.showwarning("Warning", "Please load heatmap data first!")
            return
        
        try:
            sample_idx = self.heatmap_sample_var.get() - 1  # 转换为0索引
            start_shape = self.heatmap_start_var.get() - 1  # 转换为0索引
            end_shape = self.heatmap_end_var.get()  # 保持为结束位置
            
            # 验证参数
            if sample_idx >= self.heatmap_data.shape[0] or sample_idx < 0:
                sample_idx = 0
            if start_shape < 0:
                start_shape = 0
            if end_shape > self.heatmap_data.shape[1]:
                end_shape = self.heatmap_data.shape[1]
            if start_shape >= end_shape:
                messagebox.showerror("Error", "Start shape must be less than end shape!")
                return
            
            # 清除之前的图形，但保留colorbar
            if self.current_heatmap_ax is not None:
                self.current_heatmap_ax.clear()
            else:
                self.heatmap_fig.clear()
            
            # 提取数据切片
            data_slice = self.heatmap_data[sample_idx, start_shape:end_shape, start_shape:end_shape]
            
            # 创建或更新heatmap
            if self.current_heatmap_ax is None:
                self.current_heatmap_ax = self.heatmap_fig.add_subplot(1, 1, 1)
            
            # 使用学术界专用的颜色（viridis或plasma）
            im = self.current_heatmap_ax.imshow(data_slice, cmap='viridis', aspect='auto', interpolation='nearest')
            
            # 设置标题
            self.current_heatmap_ax.set_title(f'Heatmap: Sample {sample_idx+1}, Shapes {start_shape+1}-{end_shape}')
            
            # 隐藏坐标轴数字
            self.current_heatmap_ax.set_xticks([])
            self.current_heatmap_ax.set_yticks([])
            self.current_heatmap_ax.set_xlabel('Shape Index')
            self.current_heatmap_ax.set_ylabel('Shape Index')
            
            # 存储当前显示的信息，用于点击事件
            self.current_start_shape = start_shape
            self.current_end_shape = end_shape
            self.current_sample_idx = sample_idx
            
            # 只在第一次或colorbar不存在时添加colorbar
            if self.heatmap_colorbar is None:
                self.heatmap_colorbar = self.heatmap_fig.colorbar(im, ax=self.current_heatmap_ax)
                self.heatmap_colorbar.set_label('Value')
            else:
                # 更新现有colorbar的映射
                self.heatmap_colorbar.mappable.set_array(data_slice)
                self.heatmap_colorbar.mappable.set_clim(vmin=data_slice.min(), vmax=data_slice.max())
            
            self.heatmap_fig.tight_layout()
            self.heatmap_canvas.draw()
            
        except Exception as e:
            messagebox.showerror("Error", f"Error updating heatmap: {str(e)}")
    
    def on_heatmap_click(self, event):
        """处理heatmap点击事件"""
        if event.inaxes != self.current_heatmap_ax:
            return
        
        if self.heatmap_data is None or self.current_heatmap_ax is None:
            return
        
        try:
            # 获取点击的像素坐标
            x, y = int(event.xdata), int(event.ydata)
            
            # 转换为实际的shape坐标
            actual_x = x + self.current_start_shape
            actual_y = y + self.current_start_shape
            
            # 确定是第一个还是第二个shape
            if self.current_click_count == 0:
                self.selected_shapes['shape1'] = {'x': actual_x, 'y': actual_y}
                self.current_click_count = 1
                info_text = f"Shape 1 selected: ({actual_y+1}, {actual_x+1}). Click another shape."
            else:
                self.selected_shapes['shape2'] = {'x': actual_x, 'y': actual_y}
                self.current_click_count = 0
                info_text = f"Shape 2 selected: ({actual_y+1}, {actual_x+1}). Both shapes ready."
                # 显示两个shape的序列
                self.display_selected_sequences()
            
            self.click_info_label.config(text=info_text, foreground="blue")
            
        except (TypeError, IndexError):
            # 点击超出范围时忽略
            pass
    
    def reset_shape_selection(self):
        """重置shape选择"""
        self.selected_shapes = {'shape1': None, 'shape2': None}
        self.current_click_count = 0
        self.click_info_label.config(text="Click two shapes on heatmap", foreground="blue")
        
        # 清除序列显示
        self.sequence_fig.clear()
        self.sequence_canvas.draw()
    
    def display_selected_sequences(self):
        """显示选中的两个shape的序列"""
        if (self.selected_shapes['shape1'] is None or 
            self.selected_shapes['shape2'] is None or 
            self.arr_0 is None or self.x_train is None):
            return
        
        try:
            sample_idx = self.current_sample_idx
            shape1_idx = self.selected_shapes['shape1']['y']  # 使用y坐标作为shape索引
            shape2_idx = self.selected_shapes['shape2']['y']  # 使用y坐标作为shape索引
            
            # 清除之前的图形
            self.sequence_fig.clear()
            
            # 创建四个子图：2行2列
            ax1 = self.sequence_fig.add_subplot(2, 2, 1)
            ax2 = self.sequence_fig.add_subplot(2, 2, 2)
            ax3 = self.sequence_fig.add_subplot(2, 2, 3)
            ax4 = self.sequence_fig.add_subplot(2, 2, 4)
            
            # 显示第一个shape的arr_0数据
            if shape1_idx < self.arr_0.shape[1]:
                shape1_data = self.arr_0[sample_idx, shape1_idx, :]
                ax1.plot(shape1_data, linewidth=2, color='blue')
                ax1.set_title(f'Shape {shape1_idx+1} (arr_0)')
                ax1.set_xlabel('Shape Length')
                ax1.set_ylabel('Value')
                ax1.grid(True, alpha=0.3)
            
            # 显示第二个shape的arr_0数据
            if shape2_idx < self.arr_0.shape[1]:
                shape2_data = self.arr_0[sample_idx, shape2_idx, :]
                ax2.plot(shape2_data, linewidth=2, color='red')
                ax2.set_title(f'Shape {shape2_idx+1} (arr_0)')
                ax2.set_xlabel('Shape Length')
                ax2.set_ylabel('Value')
                ax2.grid(True, alpha=0.3)
            
            # 显示X_train数据的相关维度
            if sample_idx < self.x_train.shape[0]:
                # 使用shape索引作为维度索引（如果维度足够）
                dim1 = min(shape1_idx, self.x_train.shape[2] - 1)
                dim2 = min(shape2_idx, self.x_train.shape[2] - 1)
                
                x_data1 = self.x_train[sample_idx, :, dim1]
                x_data2 = self.x_train[sample_idx, :, dim2]
                
                ax3.plot(x_data1, linewidth=2, color='blue')
                ax3.set_title(f'X_train: Dim {dim1+1}')
                ax3.set_xlabel('Time Length')
                ax3.set_ylabel('Value')
                ax3.grid(True, alpha=0.3)
                
                ax4.plot(x_data2, linewidth=2, color='red')
                ax4.set_title(f'X_train: Dim {dim2+1}')
                ax4.set_xlabel('Time Length')
                ax4.set_ylabel('Value')
                ax4.grid(True, alpha=0.3)
            
            self.sequence_fig.suptitle(f'Selected Shapes Comparison: Sample {sample_idx+1}')
            self.sequence_fig.tight_layout()
            self.sequence_canvas.draw()
            
        except Exception as e:
            print(f"Error displaying sequences: {str(e)}")
    
    def update_attention_plot(self):
        """更新Attention图表显示"""
        if self.attention_data is None or self.sorted_attention_data is None:
            messagebox.showwarning("Warning", "Please load attention data first!")
            return
        
        try:
            sample_idx = self.attention_sample_var.get() - 1  # 转换为0索引
            shape_count = self.attention_count_var.get()
            
            # 验证参数
            if sample_idx >= self.attention_data.shape[0] or sample_idx < 0:
                sample_idx = 0
            if shape_count > self.attention_data.shape[1]:
                shape_count = self.attention_data.shape[1]
            
            # 清除之前的图形和注释
            self.attention_fig.clear()
            self.attention_annotations = []
            
            # 获取排序后的数据和原始索引
            sorted_data = self.sorted_attention_data[sample_idx, :shape_count]
            original_idx = self.original_indices[sample_idx, :shape_count]
            
            # 计算平均值用于柱状图显示
            mean_values = np.mean(sorted_data, axis=1)
            
            # 创建柱状图
            self.attention_ax = self.attention_fig.add_subplot(1, 1, 1)
            
            self.attention_bars = self.attention_ax.bar(range(shape_count), mean_values, 
                                                      color='steelblue', alpha=0.7)
            
            # 设置标题和标签
            self.attention_ax.set_title(f'Attention Values: Sample {sample_idx+1}, Top {shape_count} Shapes (High to Low)')
            self.attention_ax.set_xlabel('Rank (High to Low)')
            self.attention_ax.set_ylabel('Attention Value')
            self.attention_ax.grid(True, alpha=0.3)
            
            # 设置x轴标签
            self.attention_ax.set_xticks(range(shape_count))
            self.attention_ax.set_xticklabels([f'#{i+1}' for i in range(shape_count)])
            
            # 存储原始索引用于hover显示
            self.current_attention_indices = original_idx
            
            self.attention_fig.tight_layout()
            self.attention_canvas.draw()
            
        except Exception as e:
            messagebox.showerror("Error", f"Error updating attention plot: {str(e)}")
    
    def on_attention_hover(self, event):
        """处理attention plot的鼠标悬停事件"""
        if event.inaxes != getattr(self, 'attention_ax', None):
            return
        
        if not hasattr(self, 'attention_bars') or not hasattr(self, 'current_attention_indices'):
            return
        
        try:
            # 清除之前的注释
            for annotation in self.attention_annotations:
                annotation.remove()
            self.attention_annotations = []
            
            # 检查鼠标是否在某个柱子上
            for i, bar in enumerate(self.attention_bars):
                if bar.contains(event)[0]:
                    # 显示原始索引
                    original_idx = self.current_attention_indices[i]
                    height = bar.get_height()
                    
                    annotation = self.attention_ax.annotate(
                        f'Original Idx: {original_idx}',
                        xy=(bar.get_x() + bar.get_width()/2., height),
                        xytext=(0, 10), textcoords='offset points',
                        ha='center', va='bottom',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7),
                        fontsize=9
                    )
                    self.attention_annotations.append(annotation)
                    break
            
            self.attention_canvas.draw_idle()
            
        except Exception as e:
            pass  # 忽略hover错误
    
    def on_attention_leave(self, event):
        """处理鼠标离开attention plot事件"""
        try:
            # 清除所有注释
            for annotation in self.attention_annotations:
                annotation.remove()
            self.attention_annotations = []
            self.attention_canvas.draw_idle()
        except:
            pass
    
    def download_indices(self):
        """Save the sorted shape indices of every sample to a ``.npy`` file chosen by the user.

        ``all_samples_indices`` is a list of ``[sample_number, [indices...]]``
        pairs, which is a ragged structure. Passing it straight to ``np.save``
        relies on implicit ragged-to-array conversion, which recent NumPy
        releases reject with an error. The data is therefore packed explicitly
        into an ``(N, 2)`` object array: column 0 holds the sample number and
        column 1 the index list. Because the file stores Python objects it has
        to be read back with ``np.load(path, allow_pickle=True)``.
        """
        if not self.all_samples_indices:
            messagebox.showwarning("Warning", "No attention data loaded! Please load attention data first.")
            return

        try:
            filename = filedialog.asksaveasfilename(
                defaultextension=".npy",
                filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
            )
            if not filename:
                return  # dialog was cancelled

            # Build the object array element by element so NumPy never tries
            # to interpret the nested lists as extra dimensions.
            export = np.empty((len(self.all_samples_indices), 2), dtype=object)
            for row, (sample_number, indices) in enumerate(self.all_samples_indices):
                export[row, 0] = sample_number
                export[row, 1] = indices
            np.save(filename, export)

            example = self.all_samples_indices[0]
            info_msg = "".join([
                "Indices data saved successfully!\n\n",
                f"Total samples: {len(self.all_samples_indices)}\n",
                "Format: [[sample1, [indices...]], [sample2, [indices...]], ...]\n",
                f"File: {filename}\n\n",
                "Example structure:\n",
                f"[{example[0]}, {example[1][:5]}...]",
            ])

            messagebox.showinfo("Success", info_msg)

        except Exception as e:
            messagebox.showerror("Error", f"Error saving indices: {str(e)}")

def main():
    """Create the Tk root window, apply a ttk theme, build the app and run the event loop."""
    root = tk.Tk()

    # Theming is purely cosmetic: try the preferred themes in order and fall
    # back to the platform default if Tk rejects the request. Only Tcl errors
    # are swallowed so that real bugs and Ctrl-C are not hidden.
    try:
        style = ttk.Style()
        available_themes = style.theme_names()
        for theme in ('clam', 'alt'):
            if theme in available_themes:
                style.theme_use(theme)
                break
    except tk.TclError:
        pass

    # Keep a reference to the application object for the lifetime of the loop.
    app = AdvancedVisualizationApp(root)

    def on_closing():
        """Ask for confirmation before destroying the main window."""
        if messagebox.askokcancel("Exit", "Are you sure you want to exit?"):
            root.destroy()

    root.protocol("WM_DELETE_WINDOW", on_closing)

    # Blocks until the window is closed.
    root.mainloop()


if __name__ == "__main__":
    main()

  self.heatmap_fig.tight_layout()
  self.sequence_fig.tight_layout()
  self.attention_fig.tight_layout()


In [12]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg, NavigationToolbar2Tk
from matplotlib.figure import Figure
import matplotlib.patches as patches
from matplotlib.colors import ListedColormap
import seaborn as sns
import os

class AdvancedVisualizationApp:
    def __init__(self, root):
        self.root = root
        self.root.title("Advanced Visualization & Analysis")
        self.root.geometry("1800x1000")
        
        # 数据存储
        self.heatmap_data = None  # (sample, shape_number, shape_number)
        self.shape_data = None  # NPZ文件数据
        self.arr_0 = None  # (sample, shape_number, shape_length)
        self.arr_1 = None  # (sample, shape_number, VP)
        self.x_train = None  # (sample, length, dimension_number)
        self.attention_data = None  # (sample_number, shape_number, value_number)
        self.sorted_attention_data = None  # 排序后的数据
        self.original_indices = None  # 原始索引
        self.all_samples_indices = []  # 所有sample的排序索引
        
        # Heatmap相关变量
        self.heatmap_colorbar = None
        self.current_heatmap_ax = None
        self.sequence_plot_ax = None
        
        # 当前选中的shape信息
        self.selected_shape_info = None
        
        # Attention plot相关
        self.attention_annotations = []  # 存储注释对象
        
        # 创建主要布局
        self.create_main_layout()
        
    def create_main_layout(self):
        """Build the top-level horizontal split: menu first, plots second."""
        splitter = ttk.PanedWindow(self.root, orient=tk.HORIZONTAL)
        splitter.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # Each builder adds its own pane to the splitter, so the call order
        # is the pane order: menu first, visualization area second.
        for build_pane in (self.create_left_menu, self.create_right_visualization):
            build_pane(splitter)
        
    def create_left_menu(self, parent):
        """Add the scrollable control menu as a pane of *parent*.

        Args:
            parent: The paned window that receives the menu frame.
        """
        menu_frame = ttk.Frame(parent, width=450)
        parent.add(menu_frame, weight=1)

        # A frame cannot scroll by itself, so the controls live in a frame
        # embedded in a canvas that is driven by a vertical scrollbar.
        scroll_canvas = tk.Canvas(menu_frame)
        vertical_bar = ttk.Scrollbar(
            menu_frame, orient="vertical", command=scroll_canvas.yview
        )
        content = ttk.Frame(scroll_canvas)

        def refresh_scrollregion(_event):
            # Keep the scrollable area in sync with the size of the content.
            scroll_canvas.configure(scrollregion=scroll_canvas.bbox("all"))

        content.bind("<Configure>", refresh_scrollregion)

        scroll_canvas.create_window((0, 0), window=content, anchor="nw")
        scroll_canvas.configure(yscrollcommand=vertical_bar.set)

        scroll_canvas.pack(side="left", fill="both", expand=True)
        vertical_bar.pack(side="right", fill="y")

        # Heatmap controls first, attention controls below them.
        self.create_heatmap_controls(content)
        self.create_attention_controls(content)
        
    def create_heatmap_controls(self, parent):
        """Build the "Heatmap Controls" group inside *parent*.

        Creates the three file-loading buttons with their status labels, the
        sample and shape-range spinboxes, the "Update Heatmap" button and the
        label that reports the last heatmap click.

        Args:
            parent: Container widget that receives the group.
        """
        group = ttk.LabelFrame(parent, text="Heatmap Controls", padding=10)
        group.pack(fill=tk.X, padx=5, pady=5)

        # (button caption, handler, attribute holding the status label,
        #  initial status text)
        loaders = (
            ("Load Heatmap Data (.npy)", self.load_heatmap_data,
             "heatmap_info_label", "No heatmap data loaded"),
            ("Load Shape Data (.npz)", self.load_shape_data,
             "shape_info_label", "No shape data loaded"),
            ("Load X_train Data (.npy)", self.load_xtrain_data,
             "xtrain_info_label", "No X_train data loaded"),
        )
        for caption, handler, label_attr, placeholder in loaders:
            ttk.Button(group, text=caption, command=handler).pack(fill=tk.X, pady=2)
            status = ttk.Label(group, text=placeholder, foreground="red")
            status.pack(pady=2)
            setattr(self, label_attr, status)

        def spin_row(container, caption, initial, row_pady):
            """Pack a captioned 1..1000 spinbox row; return (IntVar, Spinbox)."""
            row = ttk.Frame(container)
            row.pack(fill=tk.X, pady=row_pady)
            ttk.Label(row, text=caption).pack(side=tk.LEFT)
            value = tk.IntVar(value=initial)
            box = ttk.Spinbox(row, from_=1, to=1000, textvariable=value, width=15)
            box.pack(side=tk.RIGHT)
            return value, box

        # Sample selector (1-based in the UI).
        self.heatmap_sample_var, self.heatmap_sample_spinbox = spin_row(
            group, "Sample Number:", 1, 5
        )

        # Start / end of the displayed shape range.
        range_group = ttk.LabelFrame(group, text="Shape Number Range", padding=5)
        range_group.pack(fill=tk.X, pady=5)
        self.heatmap_start_var, self.heatmap_start_spinbox = spin_row(
            range_group, "Start Shape:", 1, 2
        )
        self.heatmap_end_var, self.heatmap_end_spinbox = spin_row(
            range_group, "End Shape:", 10, 2
        )

        # Redraw trigger.
        update_button = ttk.Button(
            group, text="Update Heatmap",
            command=self.update_heatmap, style="Accent.TButton",
        )
        update_button.pack(fill=tk.X, pady=10)

        # Shows which cell of the heatmap was clicked last.
        self.click_info_label = ttk.Label(
            group, text="Click on heatmap to view shapes", foreground="blue"
        )
        self.click_info_label.pack(pady=2)
        
    def create_attention_controls(self, parent):
        """Build the "Attention Controls" group inside *parent*.

        Creates the attention-file loader with its status label, the sample
        and shape-count spinboxes, the plot update button and the export
        button.

        Args:
            parent: Container widget that receives the group.
        """
        group = ttk.LabelFrame(parent, text="Attention Controls", padding=10)
        group.pack(fill=tk.X, padx=5, pady=5)

        # File loader and its status line.
        load_button = ttk.Button(
            group, text="Load Attention Data (.npy)", command=self.load_attention_data
        )
        load_button.pack(fill=tk.X, pady=2)

        self.attention_info_label = ttk.Label(
            group, text="No attention data loaded", foreground="red"
        )
        self.attention_info_label.pack(pady=2)

        # Sample selector (1-based in the UI).
        sample_row = ttk.Frame(group)
        sample_row.pack(fill=tk.X, pady=5)
        ttk.Label(sample_row, text="Sample Number:").pack(side=tk.LEFT)
        self.attention_sample_var = tk.IntVar(value=1)
        self.attention_sample_spinbox = ttk.Spinbox(
            sample_row, from_=1, to=1000,
            textvariable=self.attention_sample_var, width=15,
        )
        self.attention_sample_spinbox.pack(side=tk.RIGHT)

        # How many of the top-ranked shapes to plot.
        count_row = ttk.Frame(group)
        count_row.pack(fill=tk.X, pady=5)
        ttk.Label(count_row, text="Number of Shapes:").pack(side=tk.LEFT)
        self.attention_count_var = tk.IntVar(value=10)
        self.attention_count_spinbox = ttk.Spinbox(
            count_row, from_=1, to=1000,
            textvariable=self.attention_count_var, width=15,
        )
        self.attention_count_spinbox.pack(side=tk.RIGHT)

        # Redraw trigger.
        update_button = ttk.Button(
            group, text="Update Attention Plot",
            command=self.update_attention_plot, style="Accent.TButton",
        )
        update_button.pack(fill=tk.X, pady=10)

        # Export of the sorted indices.
        export_button = ttk.Button(
            group, text="Export Top X Shapes (.npy)", command=self.export_top_shapes
        )
        export_button.pack(fill=tk.X, pady=2)
        
    def create_right_visualization(self, parent):
        """Add the plot area as a pane of *parent* and create its three figures.

        The area is split vertically: the upper half holds the heatmap and
        the selected-shapes view side by side, the lower half holds the
        attention plot.

        Args:
            parent: The paned window that receives the visualization frame.
        """
        container = ttk.Frame(parent)
        parent.add(container, weight=3)

        # Vertical split between the upper and the lower half.
        rows = ttk.PanedWindow(container, orient=tk.VERTICAL)
        rows.pack(fill=tk.BOTH, expand=True)

        # Upper half: a horizontal split with two equally weighted panes.
        upper = ttk.Frame(rows)
        rows.add(upper, weight=1)

        columns = ttk.PanedWindow(upper, orient=tk.HORIZONTAL)
        columns.pack(fill=tk.BOTH, expand=True)

        heatmap_pane = ttk.LabelFrame(columns, text="Heatmap Visualization")
        columns.add(heatmap_pane, weight=1)

        sequence_pane = ttk.LabelFrame(
            columns, text="Selected Shapes Sequence Visualization"
        )
        columns.add(sequence_pane, weight=1)

        # Lower half: the attention plot.
        attention_pane = ttk.LabelFrame(rows, text="Attention Values Visualization")
        rows.add(attention_pane, weight=1)

        # Embed one matplotlib figure per pane.
        self.create_heatmap_plot(heatmap_pane)
        self.create_sequence_plot(sequence_pane)
        self.create_attention_plot(attention_pane)
        
    def create_heatmap_plot(self, parent):
        """Embed the heatmap figure, its toolbar and the click handler in *parent*."""
        figure = Figure(figsize=(8, 6), dpi=100)
        canvas = FigureCanvasTkAgg(figure, parent)
        self.heatmap_fig, self.heatmap_canvas = figure, canvas
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Navigation toolbar, created after the canvas widget has been packed.
        toolbar = NavigationToolbar2Tk(canvas, parent)
        toolbar.update()
        self.heatmap_toolbar = toolbar

        # Mouse presses on the figure are routed to the shape-pair picker.
        canvas.mpl_connect('button_press_event', self.on_heatmap_click)
        
    def create_sequence_plot(self, parent):
        """Embed the selected-shapes figure and its toolbar in *parent*."""
        figure = Figure(figsize=(8, 6), dpi=100)
        canvas = FigureCanvasTkAgg(figure, parent)
        self.sequence_fig, self.sequence_canvas = figure, canvas
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Navigation toolbar, created after the canvas widget has been packed.
        toolbar = NavigationToolbar2Tk(canvas, parent)
        toolbar.update()
        self.sequence_toolbar = toolbar
        
    def create_attention_plot(self, parent):
        """Embed the attention figure, its toolbar and hover handlers in *parent*."""
        figure = Figure(figsize=(14, 6), dpi=100)
        canvas = FigureCanvasTkAgg(figure, parent)
        self.attention_fig, self.attention_canvas = figure, canvas
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # Navigation toolbar, created after the canvas widget has been packed.
        toolbar = NavigationToolbar2Tk(canvas, parent)
        toolbar.update()
        self.attention_toolbar = toolbar

        # Hover shows a per-bar annotation; leaving the axes clears it.
        mouse_handlers = (
            ('motion_notify_event', self.on_attention_hover),
            ('axes_leave_event', self.on_attention_leave),
        )
        for event_name, handler in mouse_handlers:
            canvas.mpl_connect(event_name, handler)
        
    def load_heatmap_data(self):
        """Ask for a ``.npy`` file and load it as the heatmap data.

        The array must be 3-D with equal second and third dimensions
        (sample, shape_number, shape_number).  It is validated *before* it
        replaces ``self.heatmap_data``, so a rejected file never becomes the
        active data set.  On success the spinbox limits and the status label
        are updated; on failure an error dialog is shown and the label is
        set to "Load failed".
        """
        filename = filedialog.askopenfilename(
            title="Select Heatmap Data File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog cancelled

        try:
            data = np.load(filename)

            # Validate on the local array; ``self`` is only touched once the
            # data is known to be usable.
            if len(data.shape) != 3:
                raise ValueError("Data should be 3D (sample, shape_number, shape_number)")
            if data.shape[1] != data.shape[2]:
                raise ValueError("Second and third dimensions should be equal")

            self.heatmap_data = data

            # Limit the selectors to what the file actually contains.
            sample_count, shape_count = data.shape[0], data.shape[1]
            self.heatmap_sample_spinbox.config(to=sample_count)
            self.heatmap_start_spinbox.config(to=shape_count)
            self.heatmap_end_spinbox.config(to=shape_count)
            self.heatmap_end_var.set(min(20, shape_count))

            info_text = f"Heatmap loaded: {data.shape}"
            self.heatmap_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "Heatmap data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading heatmap data: {str(e)}")
            self.heatmap_info_label.config(text="Load failed", foreground="red")
    
    def load_shape_data(self):
        """Ask for a ``.npz`` file and load its ``arr_0`` / ``arr_1`` arrays.

        Both arrays must be 3-D.  They are validated before anything is
        stored on ``self``, and the archive is closed as soon as the two
        arrays have been read.  ``self.shape_data`` keeps the two arrays
        under their original keys.  On failure an error dialog is shown and
        the status label is set to "Load failed".
        """
        filename = filedialog.askopenfilename(
            title="Select Shape Data File",
            filetypes=[("NPZ files", "*.npz"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog cancelled

        try:
            # The object returned for an .npz archive holds the file open;
            # the context manager closes it once the arrays are in memory.
            with np.load(filename) as archive:
                arr_0 = archive['arr_0']  # (sample, shape_number, shape_length)
                arr_1 = archive['arr_1']  # (sample, shape_number, VP)

            if len(arr_0.shape) != 3 or len(arr_1.shape) != 3:
                raise ValueError("Arrays should be 3D")

            # Only validated data replaces the current state.
            self.shape_data = {'arr_0': arr_0, 'arr_1': arr_1}
            self.arr_0 = arr_0
            self.arr_1 = arr_1

            info_text = f"Shape loaded: arr_0{arr_0.shape}, arr_1{arr_1.shape}"
            self.shape_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "Shape data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading shape data: {str(e)}")
            self.shape_info_label.config(text="Load failed", foreground="red")
    
    def load_xtrain_data(self):
        """Ask for a ``.npy`` file and load it as ``self.x_train``.

        The array must be 3-D (sample, length, dimension_number).  It is
        validated before it replaces ``self.x_train``, so a rejected file
        never becomes the active data set.  On failure an error dialog is
        shown and the status label is set to "Load failed".
        """
        filename = filedialog.askopenfilename(
            title="Select X_train Data File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog cancelled

        try:
            data = np.load(filename)

            # Validate on the local array before touching ``self``.
            if len(data.shape) != 3:
                raise ValueError("X_train should be 3D (sample, length, dimension_number)")

            self.x_train = data

            info_text = f"X_train loaded: {data.shape}"
            self.xtrain_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "X_train data loaded successfully!")

        except Exception as e:
            messagebox.showerror("Error", f"Error loading X_train data: {str(e)}")
            self.xtrain_info_label.config(text="Load failed", foreground="red")
    
    def load_attention_data(self):
        """Ask for a ``.npy`` file, load it as attention data and sort it.

        The array must be 3-D (sample_number, shape_number, value_number).
        It is validated before it replaces ``self.attention_data``; after a
        successful load ``process_attention_data`` derives the sorted data
        and indices and the spinbox limits are updated.  If anything fails,
        the raw data and everything derived from it are restored to their
        previous values so they cannot get out of sync, an error dialog is
        shown and the status label is set to "Load failed".
        """
        filename = filedialog.askopenfilename(
            title="Select Attention Data File",
            filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
        )
        if not filename:
            return  # dialog cancelled

        # Snapshot of the raw array and the three attributes that
        # process_attention_data() rewrites, used for rollback on failure.
        previous_state = (
            self.attention_data,
            self.sorted_attention_data,
            self.original_indices,
            self.all_samples_indices,
        )

        try:
            data = np.load(filename)

            # Validate on the local array before touching ``self``.
            if len(data.shape) != 3:
                raise ValueError("Data should be 3D (sample_number, shape_number, value_number)")

            self.attention_data = data

            # Sort every sample and remember the original shape indices.
            self.process_attention_data()

            # Limit the selectors to what the file actually contains.
            sample_count, shape_count = data.shape[0], data.shape[1]
            self.attention_sample_spinbox.config(to=sample_count)
            self.attention_count_spinbox.config(to=shape_count)
            self.attention_count_var.set(min(15, shape_count))

            info_text = f"Attention loaded: {data.shape}"
            self.attention_info_label.config(text=info_text, foreground="green")

            messagebox.showinfo("Success", "Attention data loaded successfully!")

        except Exception as e:
            (self.attention_data,
             self.sorted_attention_data,
             self.original_indices,
             self.all_samples_indices) = previous_state
            messagebox.showerror("Error", f"Error loading attention data: {str(e)}")
            self.attention_info_label.config(text="Load failed", foreground="red")
    
    def process_attention_data(self):
        """处理attention数据，按sample排序并保留原始索引"""
        if self.attention_data is None:
            return
        
        sample_count, shape_count, value_count = self.attention_data.shape
        
        # 为每个sample创建排序数据和索引
        self.sorted_attention_data = np.zeros_like(self.attention_data)
        self.original_indices = np.zeros((sample_count, shape_count), dtype=int)
        self.all_samples_indices = []  # 存储所有sample的索引列表
        
        for sample_idx in range(sample_count):
            # 计算每个shape的平均attention值
            mean_values = np.mean(self.attention_data[sample_idx], axis=1)
            
            # 获取从大到小的排序索引
            sorted_indices = np.argsort(mean_values)[::-1]
            
            # 存储排序后的数据和原始索引
            self.sorted_attention_data[sample_idx] = self.attention_data[sample_idx][sorted_indices]
            self.original_indices[sample_idx] = sorted_indices
            
            # 添加到所有sample的索引列表
            self.all_samples_indices.append([sample_idx + 1, sorted_indices.tolist()])
    
    def update_heatmap(self):
        """Redraw the heatmap for the selected sample and shape range.

        Reads the 1-based sample number and start/end shape numbers from the
        spinboxes.  An out-of-range sample falls back to the first sample,
        the range is clipped to the data, and an empty range is reported as
        an error.  The drawn range and sample are remembered so that
        ``on_heatmap_click`` can translate a click back to shape indices.
        Any exception is reported in an error dialog.
        """
        if self.heatmap_data is None:
            messagebox.showwarning("Warning", "Please load heatmap data first!")
            return

        try:
            sample_idx = self.heatmap_sample_var.get() - 1  # to 0-based
            start_shape = self.heatmap_start_var.get() - 1  # to 0-based
            end_shape = self.heatmap_end_var.get()          # exclusive end

            # Clip the request to what the data offers.
            if not 0 <= sample_idx < self.heatmap_data.shape[0]:
                sample_idx = 0
            start_shape = max(start_shape, 0)
            end_shape = min(end_shape, self.heatmap_data.shape[1])
            if start_shape >= end_shape:
                messagebox.showerror("Error", "Start shape must be less than end shape!")
                return

            # Reuse the existing axes so the colorbar next to it survives;
            # only the very first draw starts from an empty figure.
            if self.current_heatmap_ax is None:
                self.heatmap_fig.clear()
                self.current_heatmap_ax = self.heatmap_fig.add_subplot(1, 1, 1)
            else:
                self.current_heatmap_ax.clear()
            ax = self.current_heatmap_ax

            data_slice = self.heatmap_data[sample_idx, start_shape:end_shape, start_shape:end_shape]

            im = ax.imshow(data_slice, cmap='viridis', aspect='auto', interpolation='nearest')

            ax.set_title(f'Heatmap: Sample {sample_idx+1}, Shapes {start_shape+1}-{end_shape}')

            # Tick numbers are hidden; only the axis captions are shown.
            ax.set_xticks([])
            ax.set_yticks([])
            ax.set_xlabel('Shape Index')
            ax.set_ylabel('Shape Index')

            # Remembered for the click handler.
            self.current_start_shape = start_shape
            self.current_end_shape = end_shape
            self.current_sample_idx = sample_idx

            if self.heatmap_colorbar is None:
                self.heatmap_colorbar = self.heatmap_fig.colorbar(im, ax=ax)
                self.heatmap_colorbar.set_label('Value')
            else:
                # ax.clear() removed the image the colorbar was created for.
                # Re-attach the colorbar to the new image so both share one
                # normalisation, instead of editing the discarded image with
                # min()/max() limits (which turn into NaN for NaN data).
                self.heatmap_colorbar.update_normal(im)

            self.heatmap_fig.tight_layout()
            self.heatmap_canvas.draw()

        except Exception as e:
            messagebox.showerror("Error", f"Error updating heatmap: {str(e)}")
    
    def on_heatmap_click(self, event):
        """处理heatmap点击事件"""
        if event.inaxes != self.current_heatmap_ax:
            return
        
        if self.heatmap_data is None or self.current_heatmap_ax is None:
            return
        
        try:
            # 获取点击的像素坐标
            x, y = int(event.xdata), int(event.ydata)
            
            # 转换为实际的shape坐标
            actual_x = x + self.current_start_shape
            actual_y = y + self.current_start_shape
            
            # 更新信息显示并保存选中的shape信息
            info_text = f"Selected: Shape X={actual_x+1}, Shape Y={actual_y+1}"
            self.click_info_label.config(text=info_text, foreground="blue")
            
            # 保存选中的shape信息（两个shape index）
            self.selected_shape_info = {
                'sample': self.current_sample_idx,
                'shape_x': actual_x,
                'shape_y': actual_y
            }
            
            # 显示选中的两个shape的序列
            self.display_selected_shapes()
            
        except (TypeError, IndexError):
            # 点击超出范围时忽略
            pass
    
    def display_selected_shapes(self):
        """Plot the two shapes picked on the heatmap in a 2x2 grid.

        Top row: the ``arr_0`` sequence of shape X (blue) and shape Y (red).
        Bottom row: one ``x_train`` dimension per shape for the same sample;
        the shape index is used as the dimension index, capped at the last
        available dimension.  A panel is left empty when its sample or shape
        index is outside the loaded array.  Nothing is drawn until a shape
        pair is selected and both ``arr_0`` and ``x_train`` are loaded.
        """
        selection = self.selected_shape_info
        if selection is None or self.arr_0 is None or self.x_train is None:
            return

        try:
            sample_idx = selection['sample']
            shape_x_idx = selection['shape_x']
            shape_y_idx = selection['shape_y']

            self.sequence_fig.clear()

            top_left, top_right, bottom_left, bottom_right = (
                self.sequence_fig.add_subplot(2, 2, position) for position in range(1, 5)
            )

            # arr_0 may hold fewer samples than the heatmap; without this
            # check the indexing below raised and the cleared figure was
            # never redrawn.
            if sample_idx < self.arr_0.shape[0]:
                for ax, shape_idx, color in ((top_left, shape_x_idx, 'blue'),
                                             (top_right, shape_y_idx, 'red')):
                    if shape_idx < self.arr_0.shape[1]:
                        ax.plot(self.arr_0[sample_idx, shape_idx, :], linewidth=2, color=color)
                        ax.set_title(f'Shape {shape_idx+1} (arr_0)')
                        ax.set_xlabel('Shape Length')
                        ax.set_ylabel('Value')
                        ax.grid(True, alpha=0.3)

            if sample_idx < self.x_train.shape[0]:
                # The shape index doubles as the dimension index, capped at
                # the last dimension that exists.
                last_dim = self.x_train.shape[2] - 1
                for ax, shape_idx, color in ((bottom_left, shape_x_idx, 'blue'),
                                             (bottom_right, shape_y_idx, 'red')):
                    dim = min(shape_idx, last_dim)
                    ax.plot(self.x_train[sample_idx, :, dim], linewidth=2, color=color)
                    ax.set_title(f'X_train: Dim {dim+1}')
                    ax.set_xlabel('Time Length')
                    ax.set_ylabel('Value')
                    ax.grid(True, alpha=0.3)

            self.sequence_fig.suptitle(f'Selected Shapes Comparison: Sample {sample_idx+1}')
            self.sequence_fig.tight_layout()
            self.sequence_canvas.draw()

        except Exception as e:
            print(f"Error displaying shapes: {str(e)}")
    
    def update_attention_plot(self):
        """Redraw the bar chart of the top-ranked shapes of one sample.

        Reads the 1-based sample number and the number of shapes from the
        spinboxes.  An out-of-range sample falls back to the first sample
        and the count is clamped to ``1..shape_number``.  Each bar shows the
        mean attention value of one shape, in descending order; the original
        shape index of every bar is kept for the hover annotation.  Any
        exception is reported in an error dialog.
        """
        if self.attention_data is None or self.sorted_attention_data is None:
            messagebox.showwarning("Warning", "Please load attention data first!")
            return

        try:
            sample_idx = self.attention_sample_var.get() - 1  # to 0-based
            requested = self.attention_count_var.get()

            if not 0 <= sample_idx < self.attention_data.shape[0]:
                sample_idx = 0
            # The entry accepts typed values, so 0 or a negative number is
            # possible; a negative slice bound would silently mean "all but
            # the last N" and no longer match the number of bars.
            shape_count = max(1, min(requested, self.attention_data.shape[1]))

            # Start from an empty figure; old annotations died with it.
            self.attention_fig.clear()
            self.attention_annotations = []

            sorted_data = self.sorted_attention_data[sample_idx, :shape_count]
            original_idx = self.original_indices[sample_idx, :shape_count]

            # One bar per shape: the mean over that shape's values.
            mean_values = np.mean(sorted_data, axis=1)
            # Use the number of rows actually available for the bar count.
            shape_count = len(mean_values)

            self.attention_ax = self.attention_fig.add_subplot(1, 1, 1)

            self.attention_bars = self.attention_ax.bar(range(shape_count), mean_values,
                                                        color='steelblue', alpha=0.7)

            self.attention_ax.set_title(f'Attention Values: Sample {sample_idx+1}, Top {shape_count} Shapes (High to Low)')
            self.attention_ax.set_xlabel('Rank (High to Low)')
            self.attention_ax.set_ylabel('Attention Value')
            self.attention_ax.grid(True, alpha=0.3)

            # Rank numbers on the x axis are hidden.
            self.attention_ax.set_xticks([])

            # Looked up by the hover handler.
            self.current_attention_indices = original_idx

            self.attention_fig.tight_layout()
            self.attention_canvas.draw()

        except Exception as e:
            messagebox.showerror("Error", f"Error updating attention plot: {str(e)}")
    
    def on_attention_hover(self, event):
        """处理attention plot的鼠标悬停事件"""
        if event.inaxes != getattr(self, 'attention_ax', None):
            return
        
        if not hasattr(self, 'attention_bars') or not hasattr(self, 'current_attention_indices'):
            return
        
        try:
            # 清除之前的注释
            for annotation in self.attention_annotations:
                annotation.remove()
            self.attention_annotations = []
            
            # 检查鼠标是否在某个柱子上
            for i, bar in enumerate(self.attention_bars):
                if bar.contains(event)[0]:
                    # 显示原始索引
                    original_idx = self.current_attention_indices[i]
                    height = bar.get_height()
                    
                    annotation = self.attention_ax.annotate(
                        f'Original Idx: {original_idx}',
                        xy=(bar.get_x() + bar.get_width()/2., height),
                        xytext=(0, 10), textcoords='offset points',
                        ha='center', va='bottom',
                        bbox=dict(boxstyle='round,pad=0.3', facecolor='yellow', alpha=0.7),
                        fontsize=9
                    )
                    self.attention_annotations.append(annotation)
                    break
            
            self.attention_canvas.draw_idle()
            
        except Exception as e:
            pass  # 忽略hover错误
    
    def on_attention_leave(self, event):
        """处理鼠标离开attention plot事件"""
        try:
            # 清除所有注释
            for annotation in self.attention_annotations:
                annotation.remove()
            self.attention_annotations = []
            self.attention_canvas.draw_idle()
        except:
            pass
    
    def export_top_shapes(self):
        """Save the sorted shape indices of every sample to a ``.npy`` file.

        The file holds an object array of shape ``(n_samples, 2)``: column 0
        is the 1-based sample number, column 1 the list of shape indices in
        descending attention order.  Because it is an object array it has to
        be read back with ``np.load(path, allow_pickle=True)``.

        A warning is shown when no attention data has been processed yet;
        cancelling the file dialog does nothing; any error is reported in an
        error dialog.
        """
        if not self.all_samples_indices:
            messagebox.showwarning("Warning", "No attention data loaded! Please load attention data first.")
            return

        try:
            filename = filedialog.asksaveasfilename(
                defaultextension=".npy",
                filetypes=[("NPY files", "*.npy"), ("All files", "*.*")]
            )
            if not filename:
                return  # dialog cancelled

            records = self.all_samples_indices

            # Each record is [int, list], i.e. ragged.  Recent NumPy refuses
            # to convert such a nested list implicitly (ValueError about an
            # inhomogeneous shape), so the object array is built explicitly.
            export = np.empty((len(records), 2), dtype=object)
            for row, (sample_number, indices) in enumerate(records):
                export[row, 0] = sample_number
                export[row, 1] = indices

            np.save(filename, export)

            example = records[0]
            info_msg = (
                "Top shapes indices exported successfully!\n\n"
                f"Total samples: {len(records)}\n"
                "Format: [[sample1, [indices...]], [sample2, [indices...]], ...]\n"
                f"File: {filename}\n\n"
                "Example structure:\n"
                f"[{example[0]}, {example[1][:5]}...]"
            )
            messagebox.showinfo("Success", info_msg)

        except Exception as e:
            messagebox.showerror("Error", f"Error exporting indices: {str(e)}")

def main():
    """Create the Tk root window, build the application and run the event loop.

    The 'clam' ttk theme is applied when available, otherwise 'alt'; when
    neither exists, or theming fails at the Tcl level, the default theme is
    kept.  Closing the window asks for confirmation before the root is
    destroyed.
    """
    root = tk.Tk()

    # Theming is cosmetic, so a Tcl failure must not stop the app.  Only
    # TclError is ignored: a bare ``except`` would also swallow
    # KeyboardInterrupt and genuine programming errors.
    try:
        style = ttk.Style()
        available_themes = style.theme_names()
        for theme in ('clam', 'alt'):
            if theme in available_themes:
                style.theme_use(theme)
                break
    except tk.TclError:
        pass

    # Build the UI inside the root window.
    app = AdvancedVisualizationApp(root)

    def on_closing():
        """Ask for confirmation before tearing the window down."""
        if messagebox.askokcancel("Exit", "Are you sure you want to exit?"):
            root.destroy()

    root.protocol("WM_DELETE_WINDOW", on_closing)

    # Blocks until the root window is destroyed.
    root.mainloop()


if __name__ == "__main__":
    main()

  self.attention_fig.tight_layout()
  self.sequence_fig.tight_layout()
