<a href="https://colab.research.google.com/github/DeshengKong/Hellow-World/blob/master/Cervical_Spinal_DTI_DTT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
import tkinter as tk
from tkinter import ttk, filedialog, messagebox, colorchooser
import vtk
from vtk.tk.vtkTkRenderWindowInteractor import vtkTkRenderWindowInteractor
import numpy as np
import pydicom
import os
from dipy.io.image import load_nifti, save_nifti
from dipy.core.gradients import gradient_table
from dipy.reconst.dti import TensorModel
from dipy.reconst.csdeconv import ConstrainedSphericalDeconvModel
from dipy.tracking.local_tracking import LocalTracking
from dipy.tracking.stopping_criterion import ThresholdStoppingCriterion
from dipy.tracking import utils
from dipy.data import get_sphere
from dipy.direction import peaks_from_model
from dipy.segment.mask import median_otsu
from dipy.viz import window, actor
import threading
import queue
import SimpleITK as sitk
from pathlib import Path
import nibabel as nib
import matplotlib.pyplot as plt
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import datetime
import json
import logging
import time

# 设置日志
logging.basicConfig(level=logging.INFO,
                   format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
                   filename='spinal_dti.log')
logger = logging.getLogger('SpinalDTI')

class DicomDTIProcessor:
    """增强的DICOM DTI处理器，支持多种厂商的数据格式"""

    # 不同厂商DTI参数标签映射
    VENDOR_MAPPINGS = {
        'SIEMENS': {
            'bvalue': (0x0019, 0x100c),
            'gradient_x': (0x0019, 0x100e),
            'gradient_y': (0x0019, 0x100f),
            'gradient_z': (0x0019, 0x1010)
        },
        'GE': {
            'bvalue': (0x0043, 0x1039),
            'gradient_x': (0x0019, 0x10bb),
            'gradient_y': (0x0019, 0x10bc),
            'gradient_z': (0x0019, 0x10bd)
        },
        'PHILIPS': {
            'bvalue': (0x2001, 0x1003),
            'gradient_x': (0x2005, 0x10b0),
            'gradient_y': (0x2005, 0x10b1),
            'gradient_z': (0x2005, 0x10b2)
        }
    }

    def __init__(self):
        self.dicom_files = []
        self.gradient_directions = []
        self.b_values = []
        self.volume_data = None
        self.slice_location_sort_idx = None
        self.vendor = None
        self.metadata = {}
        self.series_info = {}

    def load_dicom_series(self, dicom_dir, progress_callback=None):
        """加载DICOM序列并提取DTI信息，支持进度回调"""
        start_time = time.time()

        # 获取DICOM文件列表
        dicom_files = self._get_dicom_files(dicom_dir)
        if not dicom_files:
            raise ValueError(f"找不到DICOM文件: {dicom_dir}")

        # 检测厂商并加载元数据
        self._detect_vendor(dicom_files[0])
        logger.info(f"检测到厂商: {self.vendor}")

        if progress_callback:
            progress_callback(10, f"检测到厂商: {self.vendor}")

        # 使用SimpleITK读取体积数据
        reader = sitk.ImageSeriesReader()
        dicom_names = reader.GetGDCMSeriesFileNames(dicom_dir)
        reader.SetFileNames(dicom_names)
        image = reader.Execute()

        # 获取图像信息
        self._extract_image_info(image)

        # 转换为numpy数组
        self.volume_data = sitk.GetArrayFromImage(image)

        if progress_callback:
            progress_callback(40, "DICOM体积数据加载完成")

        # 提取b值和梯度方向
        self.extract_bvals_bvecs(dicom_names, progress_callback)

        # 验证数据完整性
        self._validate_data()

        elapsed_time = time.time() - start_time
        logger.info(f"DICOM加载完成，耗时: {elapsed_time:.2f}秒")

        if progress_callback:
            progress_callback(100, "DICOM数据加载与预处理完成")

        return self.volume_data, self.b_values, self.gradient_directions, self.metadata

    def _get_dicom_files(self, dicom_dir):
        """获取目录中的所有DICOM文件"""
        dicom_files = []
        for root, _, files in os.walk(dicom_dir):
            for file in files:
                try:
                    file_path = os.path.join(root, file)
                    # 快速检查文件是否为DICOM
                    with open(file_path, 'rb') as f:
                        if f.read(128)[128:132] == b'DICM':
                            dicom_files.append(file_path)
                except:
                    continue
        return dicom_files

    def _detect_vendor(self, first_dicom_file):
        """检测DICOM设备厂商"""
        ds = pydicom.dcmread(first_dicom_file, stop_before_pixels=True)

        self.metadata['patient_name'] = str(ds.get('PatientName', ''))
        self.metadata['patient_id'] = str(ds.get('PatientID', ''))
        self.metadata['study_date'] = str(ds.get('StudyDate', ''))
        self.metadata['study_time'] = str(ds.get('StudyTime', ''))
        self.metadata['modality'] = str(ds.get('Modality', ''))

        # 检测厂商
        manufacturer = str(ds.get('Manufacturer', '')).upper()

        if 'SIEMENS' in manufacturer:
            self.vendor = 'SIEMENS'
        elif 'GE' in manufacturer:
            self.vendor = 'GE'
        elif 'PHILIPS' in manufacturer:
            self.vendor = 'PHILIPS'
        else:
            self.vendor = 'UNKNOWN'
            logger.warning(f"未知厂商: {manufacturer}")

        self.metadata['manufacturer'] = manufacturer
        self.metadata['vendor'] = self.vendor

    def _extract_image_info(self, image):
        """提取图像信息"""
        self.metadata['dimensions'] = image.GetSize()
        self.metadata['spacing'] = image.GetSpacing()
        self.metadata['origin'] = image.GetOrigin()
        self.metadata['direction'] = image.GetDirection()

    def extract_bvals_bvecs(self, dicom_files, progress_callback=None):
        """从DICOM文件中提取b值和梯度方向"""
        self.b_values = []
        self.gradient_directions = []

        total_files = len(dicom_files)

        for i, dcm_file in enumerate(dicom_files):
            if i % 10 == 0 and progress_callback:
                progress_callback(40 + (i / total_files) * 30, f"提取DTI参数: {i}/{total_files}")

            ds = pydicom.dcmread(dcm_file)

            # 从DICOM中提取切片信息
            self._extract_slice_info(ds, i)

            # 提取b值
            b_value = self._extract_bvalue(ds)

            # 提取梯度方向
            gradient = self._extract_gradient(ds)

            self.b_values.append(b_value)
            self.gradient_directions.append(gradient)

        # 对数据进行排序和重组
        self._reorganize_diffusion_data()

        if progress_callback:
            progress_callback(70, "提取DTI参数完成")

    def _extract_slice_info(self, ds, index):
        """提取切片信息"""
        if index == 0:
            # 第一个切片，初始化序列信息
            self.series_info['series_description'] = str(ds.get('SeriesDescription', ''))
            self.series_info['series_number'] = str(ds.get('SeriesNumber', ''))
            self.series_info['slice_thickness'] = float(ds.get('SliceThickness', 0))

        # 收集图像位置信息，用于后续排序
        if hasattr(ds, 'ImagePositionPatient'):
            if not hasattr(self, 'slice_positions'):
                self.slice_positions = []
            self.slice_positions.append((index, ds.ImagePositionPatient[2]))

    def _extract_bvalue(self, ds):
        """根据不同厂商提取b值"""
        # 首先尝试标准标签
        if hasattr(ds, 'DiffusionBValue'):
            return float(ds.DiffusionBValue)

        # 然后尝试厂商特定标签
        if self.vendor in self.VENDOR_MAPPINGS:
            tag = self.VENDOR_MAPPINGS[self.vendor]['bvalue']
            try:
                if tag in ds:
                    return float(ds[tag].value)
            except:
                pass

        # 对于GE的特殊处理
        if self.vendor == 'GE':
            try:
                # GE的b值通常存储在私有标签的字符串中
                b_value_str = str(ds[0x0043, 0x1039].value)
                # 提取b值的数字部分
                for part in b_value_str.split('\\'):
                    if 'b' in part.lower():
                        return float(part.lower().replace('b', '').strip())
            except:
                pass

        # 如果都失败了，假设为b=0图像
        return 0.0

    def _extract_gradient(self, ds):
        """根据不同厂商提取梯度方向"""
        # 首先尝试标准标签
        if hasattr(ds, 'DiffusionGradientOrientation'):
            return [float(x) for x in ds.DiffusionGradientOrientation]

        # 然后尝试厂商特定标签
        if self.vendor in self.VENDOR_MAPPINGS:
            try:
                x_tag = self.VENDOR_MAPPINGS[self.vendor]['gradient_x']
                y_tag = self.VENDOR_MAPPINGS[self.vendor]['gradient_y']
                z_tag = self.VENDOR_MAPPINGS[self.vendor]['gradient_z']

                if x_tag in ds and y_tag in ds and z_tag in ds:
                    return [
                        float(ds[x_tag].value),
                        float(ds[y_tag].value),
                        float(ds[z_tag].value)
                    ]
            except:
                pass

        # 默认为零向量（适用于b=0图像）
        return [0, 0, 0]

    def _reorganize_diffusion_data(self):
        """重组扩散数据，确保正确的顺序"""
        if hasattr(self, 'slice_positions') and self.slice_positions:
            # 按照切片位置排序
            self.slice_positions.sort(key=lambda x: x[1])
            sorted_indices = [x[0] for x in self.slice_positions]

            # 重新排序数据
            self.volume_data = self.volume_data[sorted_indices]

            # 重组b值和梯度方向
            sorted_bvals = []
            sorted_bvecs = []

            for idx in sorted_indices:
                sorted_bvals.append(self.b_values[idx])
                sorted_bvecs.append(self.gradient_directions[idx])

            self.b_values = sorted_bvals
            self.gradient_directions = sorted_bvecs

    def _validate_data(self):
        """验证DTI数据的完整性"""
        # 检查体积数据和梯度方向/b值的数量是否匹配
        if len(self.b_values) != self.volume_data.shape[0]:
            logger.warning(f"b值数量 ({len(self.b_values)}) 与体积数量 ({self.volume_data.shape[0]}) 不匹配")

            # 尝试修复
            if len(self.b_values) > self.volume_data.shape[0]:
                self.b_values = self.b_values[:self.volume_data.shape[0]]
                self.gradient_directions = self.gradient_directions[:self.volume_data.shape[0]]
            else:
                # 为缺少的切片生成零值
                missing = self.volume_data.shape[0] - len(self.b_values)
                self.b_values.extend([0] * missing)
                self.gradient_directions.extend([[0,0,0]] * missing)

        # 检查b=0图像是否存在
        if 0 not in self.b_values:
            logger.warning("未找到b=0图像，DTI分析可能不准确")

        # 检查梯度方向是否有效
        non_zero_gradients = [g for i, g in enumerate(self.gradient_directions) if self.b_values[i] > 0]
        if not non_zero_gradients:
            logger.warning("未找到有效的梯度方向")

        # 将数据转换为numpy数组
        self.b_values = np.array(self.b_values)
        self.gradient_directions = np.array(self.gradient_directions)

class DTIPreprocessor:
    """DTI数据预处理类"""

    def __init__(self):
        self.original_data = None
        self.processed_data = None
        self.mask = None

    def preprocess(self, volume_data, bvals, bvecs, options, progress_callback=None):
        """执行DTI数据预处理"""
        self.original_data = volume_data.copy()
        self.processed_data = volume_data.copy()

        # 进度追踪
        total_steps = sum(1 for opt in options.values() if opt)
        current_step = 0

        # 生成脑/脊髓掩码
        if options.get('create_mask', False):
            if progress_callback:
                progress_callback(
                    int((current_step / total_steps) * 100),
                    "创建脑/脊髓掩码..."
                )

            self._create_mask()
            current_step += 1

        # 运动校正
        if options.get('motion_correction', False):
            if progress_callback:
                progress_callback(
                    int((current_step / total_steps) * 100),
                    "应用运动校正..."
                )

            self.processed_data = self._apply_motion_correction(self.processed_data)
            current_step += 1

        # 涡流校正
        if options.get('eddy_correction', False):
            if progress_callback:
                progress_callback(
                    int((current_step / total_steps) * 100),
                    "应用涡流校正..."
                )

            self.processed_data = self._apply_eddy_correction(self.processed_data)
            current_step += 1

        # 去噪
        if options.get('denoising', False):
            if progress_callback:
                progress_callback(
                    int((current_step / total_steps) * 100),
                    "应用去噪处理..."
                )

            self.processed_data = self._apply_denoising(self.processed_data)
            current_step += 1

        # 偏置场校正
        if options.get('bias_correction', False):
            if progress_callback:
                progress_callback(
                    int((current_step / total_steps) * 100),
                    "应用偏置场校正..."
                )

            self.processed_data = self._apply_bias_correction(self.processed_data)
            current_step += 1

        # 完成
        if progress_callback:
            progress_callback(100, "预处理完成")

        return self.processed_data

    def _create_mask(self):
        """创建脑/脊髓掩码"""
        # 使用b=0图像创建掩码
        b0 = self.processed_data[0]  # 假设第一个体积是b=0

        # 使用中值滤波和阈值分割创建掩码
        self.mask, _ = median_otsu(b0, median_radius=2, numpass=2)

        # 应用掩码
        for i in range(self.processed_data.shape[0]):
            self.processed_data[i] = self.processed_data[i] * self.mask

    def _apply_motion_correction(self, volume_data):
        """应用运动校正"""
        # 获取参考体积 (b=0)
        reference_volume = volume_data[0]
        fixed_image = sitk.GetImageFromArray(reference_volume)

        corrected_volumes = [reference_volume]

        for i in range(1, volume_data.shape[0]):
            moving_image = sitk.GetImageFromArray(volume_data[i])

            # 配准设置
            registration_method = sitk.ImageRegistrationMethod()

            # 多分辨率策略
            registration_method.SetShrinkFactorsPerLevel([4, 2, 1])
            registration_method.SetSmoothingSigmasPerLevel([2, 1, 0])

            # 相似性度量
            registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
            registration_method.SetMetricSamplingStrategy(registration_method.RANDOM)
            registration_method.SetMetricSamplingPercentage(0.2)

            # 优化器
            registration_method.SetOptimizerAsGradientDescent(
                learningRate=1.0,
                numberOfIterations=200,
                convergenceMinimumValue=1e-6,
                convergenceWindowSize=10
            )

            # 变换
            registration_method.SetTransformAsRigid3D()

            # 初始化
            initial_transform = sitk.CenteredTransformInitializer(
                fixed_image,
                moving_image,
                sitk.Euler3DTransform(),
                sitk.CenteredTransformInitializerFilter.GEOMETRY
            )
            registration_method.SetInitialTransform(initial_transform)

            # 执行配准
            final_transform = registration_method.Execute(fixed_image, moving_image)

            # 应用变换
            resampler = sitk.ResampleImageFilter()
            resampler.SetReferenceImage(fixed_image)
            resampler.SetInterpolator(sitk.sitkLinear)
            resampler.SetDefaultPixelValue(0)
            resampler.SetTransform(final_transform)

            corrected_image = resampler.Execute(moving_image)
            corrected_volumes.append(sitk.GetArrayFromImage(corrected_image))

        return np.stack(corrected_volumes)

    def _apply_eddy_correction(self, volume_data):
        """应用涡流校正"""
        # 实际应用中，应该使用FSL的eddy工具
        # 这里提供一个简化版实现

        # 获取参考体积 (b=0)
        reference_volume = volume_data[0]
        fixed_image = sitk.GetImageFromArray(reference_volume)

        corrected_volumes = [reference_volume]

        for i in range(1, volume_data.shape[0]):
            moving_image = sitk.GetImageFromArray(volume_data[i])

            # 非刚性配准
            registration_method = sitk.ImageRegistrationMethod()

            # 多分辨率策略
            registration_method.SetShrinkFactorsPerLevel([4, 2, 1])
            registration_method.SetSmoothingSigmasPerLevel([2, 1, 0])

            # 相似性度量
            registration_method.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)

            # 优化器
            registration_method.SetOptimizerAsGradientDescent(
                learningRate=1.0,
                numberOfIterations=100
            )

            # 非刚性变换 (B样条)
            transform_domain_mesh_size = [8] * fixed_image.GetDimension()
            initial_transform = sitk.BSplineTransformInitializer(
                fixed_image,
                transform_domain_mesh_size
            )
            registration_method.SetInitialTransform(initial_transform)

            # 执行配准
            final_transform = registration_method.Execute(fixed_image, moving_image)

            # 应用变换
            corrected_image = sitk.Resample(
                moving_image,
                fixed_image,
                final_transform,
                sitk.sitkLinear,
                0.0,
                moving_image.GetPixelID()
            )

            corrected_volumes.append(sitk.GetArrayFromImage(corrected_image))

        return np.stack(corrected_volumes)

    def _apply_denoising(self, volume_data):
        """应用去噪处理"""
        denoised_volumes = []

        for i in range(volume_data.shape[0]):
            volume = sitk.GetImageFromArray(volume_data[i])

            # 非局部均值滤波
            denoised = sitk.CurvatureAnisotropicDiffusion(
                volume,
                timeStep=0.0625,
                conductanceParameter=3.0,
                numberOfIterations=5
            )

            denoised_volumes.append(sitk.GetArrayFromImage(denoised))

        return np.stack(denoised_volumes)

    def _apply_bias_correction(self, volume_data):
        """应用偏置场校正"""
        corrected_volumes = []

        for i in range(volume_data.shape[0]):
            volume = sitk.GetImageFromArray(volume_data[i])

            # N4偏置场校正
            mask_image = sitk.GetImageFromArray(
                np.ones_like(volume_data[i], dtype=np.uint8)
            ) if self.mask is None else sitk.GetImageFromArray(self.mask.astype(np.uint8))

            corrector = sitk.N4BiasFieldCorrectionImageFilter()
            corrector.SetMaximumNumberOfIterations([4, 2, 1])

            # 应用校正
            corrected = corrector.Execute(volume, mask_image)
            corrected_volumes.append(sitk.GetArrayFromImage(corrected))

        return np.stack(corrected_volumes)

class DTIAnalyzer:
    """DTI分析类"""

    def __init__(self):
        self.gtab = None
        self.tensor_model = None
        self.tensor_fit = None
        self.fa = None
        self.md = None
        self.rd = None
        self.ad = None
        self.rgb = None
        self.mask = None
        self.streamlines = None
        self.peaks = None

    def setup(self, bvals, bvecs):
        """设置梯度表"""
        self.gtab = gradient_table(bvals, bvecs)

    def fit_tensor(self, data, mask=None, progress_callback=None):
        """拟合张量模型"""
        if progress_callback:
            progress_callback(0, "初始化张量模型...")

        self.tensor_model = TensorModel(self.gtab)

        if progress_callback:
            progress_callback(10, "应用张量拟合...")

        # 应用掩码（如果有）
        self.mask = mask
        if self.mask is not None:
            # 确保掩码与数据形状匹配
            if self.mask.shape != data.shape[1:]:
                raise ValueError(f"掩码形状 {self.mask.shape} 与数据形状 {data.shape[1:]} 不匹配")
            masked_data = np.zeros_like(data)
            for i in range(data.shape[0]):
                masked_data[i] = data[i] * self.mask
            data = masked_data

        self.tensor_fit = self.tensor_model.fit(data)

        if progress_callback:
            progress_callback(50, "计算DTI指标...")

        # 计算各向异性和各向同性指标
        self.fa = self.tensor_fit.fa
        self.md = self.tensor_fit.md
        self.ad = self.tensor_fit.ad
        self.rd = self.tensor_fit.rd
        self.rgb = self.tensor_fit.rgb

        if progress_callback:
            progress_callback(100, "张量拟合完成")

        return {
            'fa': self.fa,
            'md': self.md,
            'rd': self.rd,
            'ad': self.ad,
            'rgb': self.rgb
        }

    def compute_peaks(self, progress_callback=None):
        """计算峰值，为纤维追踪做准备"""
        if self.tensor_fit is None:
            raise ValueError("在计算峰值前，必须先进行张量拟合")

        if progress_callback:
            progress_callback(0, "计算扩散峰值...")

        sphere = get_sphere('symmetric724')
        self.peaks = peaks_from_model(
            model=self.tensor_model,
            data=self.tensor_fit.quadratic_form,
            sphere=sphere,
            relative_peak_threshold=0.5,
            min_separation_angle=25,
            normalize_peaks=True,
            mask=self.mask
        )

        if progress_callback:
            progress_callback(100, "峰值计算完成")

        return self.peaks

    def track_fibers(self, fa_threshold=0.2, max_angle=45.0, step_size=0.5, progress_callback=None):
        """执行纤维追踪"""
        if self.peaks is None:
            if progress_callback:
                progress_callback(0, "准备扩散峰值...")
            self.compute_peaks()

        if progress_callback:
            progress_callback(20, "设置纤维追踪参数...")

        # 设置停止条件
        stopping_criterion = ThresholdStoppingCriterion(self.fa, fa_threshold)

        # 转换角度为弧度
        max_angle_rad = np.deg2rad(max_angle)

        if progress_callback:
            progress_callback(40, "生成种子点...")

        # 生成种子点
        if self.mask is None:
            # 如果没有掩码，使用FA阈值创建一个
            self.mask = self.fa > fa_threshold

        seeds = utils.seeds_from_mask(
            self.mask,
            density=[2, 2, 2],
            affine=np.eye(4)
        )

        if progress_callback:
            progress_callback(60, "开始纤维追踪...")

        # 执行纤维追踪
        tracking_algorithm = LocalTracking(
            self.peaks,
            stopping_criterion,
            seeds,
            affine=np.eye(4),
            step_size=step_size,
            max_cross=1,
            maxlen=100,
            return_all=False
        )

        self.streamlines = tracking_algorithm.generate_streamlines()

        if progress_callback:
            progress_callback(100, "纤维追踪完成")

        return self.streamlines

    def get_roi_statistics(self, roi_mask):
        """计算ROI内的DTI指标统计"""
        if self.fa is None or self.md is None:
            raise ValueError("在计算ROI统计前，必须先进行张量拟合")

        # 确保ROI掩码有效
        if roi_mask.shape != self.fa.shape:
            raise ValueError(f"ROI掩码形状 {roi_mask.shape} 与FA图形状 {self.fa.shape} 不匹配")

        # 在ROI内提取指标值
        fa_values = self.fa[roi_mask]
        md_values = self.md[roi_mask]
        rd_values = self.rd[roi_mask]
        ad_values = self.ad[roi_mask]

        # 计算统计量
        stats = {
            'fa': {
                'mean': np.mean(fa_values),
                'std': np.std(fa_values),
                'min': np.min(fa_values),
                'max': np.max(fa_values),
                'median': np.median(fa_values)
            },
            'md': {
                'mean': np.mean(md_values),
                'std': np.std(md_values),
                'min': np.min(md_values),
                'max': np.max(md_values),
                'median': np.median(md_values)
            },
            'rd': {
                'mean': np.mean(rd_values),
                'std': np.std(rd_values),
                'min': np.min(rd_values),
                'max': np.max(rd_values),
                'median': np.median(rd_values)
            },
            'ad': {
                'mean': np.mean(ad_values),
                'std': np.std(ad_values),
                'min': np.min(ad_values),
                'max': np.max(ad_values),
                'median': np.median(ad_values)
            }
        }

        return stats

    def save_results(self, output_dir, metadata=None):
        """保存分析结果"""
        if not os.path.exists(output_dir):
            os.makedirs(output_dir)

        # 保存标量图
        if self.fa is not None:
            save_nifti(os.path.join(output_dir, 'fa.nii.gz'), self.fa, np.eye(4))

        if self.md is not None:
            save_nifti(os.path.join(output_dir, 'md.nii.gz'), self.md, np.eye(4))

        if self.rd is not None:
            save_nifti(os.path.join(output_dir, 'rd.nii.gz'), self.rd, np.eye(4))

        if self.ad is not None:
            save_nifti(os.path.join(output_dir, 'ad.nii.gz'), self.ad, np.eye(4))

        if self.rgb is not None:
            save_nifti(os.path.join(output_dir, 'rgb.nii.gz'), self.rgb, np.eye(4))

        # 保存纤维束
        if self.streamlines is not None:
            # 保存为TrackVis格式
            from dipy.io.streamline import save_trk
            save_trk(os.path.join(output_dir, 'streamlines.trk'),
                    self.streamlines,
                    np.eye(4),
                    self.fa.shape)

        # 保存元数据
        if metadata:
            with open(os.path.join(output_dir, 'metadata.json'), 'w') as f:
                json.dump(metadata, f, indent=2)

class SpinalDTIViewer:
    def __init__(self, root):
        self.root = root
        self.root.title("脊髓DTI分析与可视化工具 V2.0")
        self.root.geometry("1200x800")

        # 设置应用图标
        # self.root.iconbitmap("spine_icon.ico")  # 如果有图标文件

        # 初始化处理器和分析器
        self.processor = DicomDTIProcessor()
        self.preprocessor = DTIPreprocessor()
        self.analyzer = DTIAnalyzer()

        # 数据存储
        self.volume_data = None
        self.processed_data = None
        self.bvals = None
        self.bvecs = None
        self.tensor_fit = None
        self.streamlines = None
        self.rois = []
        self.roi_masks = {}
        self.current_roi_name = None
        self.metadata = {}

        # 显示设置
        self.current_slice = 0
        self.current_view = "axial"
        self.current_map = "fa"
        self.current_volume = 0  # 针对原始数据

        # 界面颜色设置
        self.bg_color = "#f0f0f0"
        self.fg_color = "#333333"
        self.accent_color = "#4a86e8"

        # 设置样式
        style = ttk.Style()
        style.configure("TButton", font=("Arial", 10), padx=5, pady=5)
        style.configure("TLabel", font=("Arial", 10))
        style.configure("TFrame", background=self.bg_color)
        style.configure("TLabelframe", background=self.bg_color)

        # 设置界面
        self.setup_main_gui()
        self.setup_menu()
        self.setup_tabs()
        self.setup_vtk()
        self.setup_interactor()

        # 状态追踪
        self.processing_queue = queue.Queue()
        self.processing_thread = None

        # 定期更新状态
        self.root.after(100, self.check_processing_queue)

    def setup_main_gui(self):
        """设置主界面框架"""
        # 主框架
        self.main_frame = ttk.Frame(self.root)
        self.main_frame.pack(fill=tk.BOTH, expand=True)

        # 分割窗口
        self.main_paned = ttk.PanedWindow(self.main_frame, orient=tk.HORIZONTAL)
        self.main_paned.pack(fill=tk.BOTH, expand=True)

        # 左侧控制面板
        self.control_frame = ttk.Frame(self.main_paned, width=300)

        # 右侧显示区域
        self.display_frame = ttk.Frame(self.main_paned)

        self.main_paned.add(self.control_frame, weight=1)
        self.main_paned.add(self.display_frame, weight=3)

    def setup_menu(self):
        """设置菜单栏"""
        menubar = tk.Menu(self.root)

        # 文件菜单
        file_menu = tk.Menu(menubar, tearoff=0)
        file_menu.add_command(label="加载DICOM目录", command=self.load_dicom_directory)
        file_menu.add_command(label="加载NIfTI文件", command=self.load_nifti_file)
        file_menu.add_separator()
        file_menu.add_command(label="保存结果", command=self.save_results)
        file_menu.add_command(label="导出报告", command=self.export_report)
        file_menu.add_separator()
        file_menu.add_command(label="退出", command=self.root.quit)
        menubar.add_cascade(label="文件", menu=file_menu)

        # 视图菜单
        view_menu = tk.Menu(menubar, tearoff=0)
        view_menu.add_command(label="轴向视图", command=lambda: self.set_view("axial"))
        view_menu.add_command(label="矢状位视图", command=lambda: self.set_view("sagittal"))
        view_menu.add_command(label="冠状位视图", command=lambda: self.set_view("coronal"))
        view_menu.add_separator()
        view_menu.add_command(label="3D视图", command=self.show_3d_view)
        menubar.add_cascade(label="视图", menu=view_menu)

        # 分析菜单
        analysis_menu = tk.Menu(menubar, tearoff=0)
        analysis_menu.add_command(label="DTI张量分析", command=self.start_dti_analysis)
        analysis_menu.add_command(label="纤维追踪", command=self.start_fiber_tracking)
        analysis_menu.add_separator()
        analysis_menu.add_command(label="ROI分析", command=self.show_roi_analysis)
        menubar.add_cascade(label="分析", menu=analysis_menu)

        # 帮助菜单
        help_menu = tk.Menu(menubar, tearoff=0)
        help_menu.add_command(label="帮助内容", command=self.show_help)
        help_menu.add_command(label="关于", command=self.show_about)
        menubar.add_cascade(label="帮助", menu=help_menu)

        self.root.config(menu=menubar)

    def setup_tabs(self):
        """设置选项卡"""
        # 选项卡控件
        self.tab_control = ttk.Notebook(self.control_frame)
        self.tab_control.pack(fill=tk.BOTH, expand=True)

        # 数据加载选项卡
        self.data_tab = ttk.Frame(self.tab_control)
        self.tab_control.add(self.data_tab, text="数据")
        self.setup_data_tab()

        # 预处理选项卡
        self.preprocess_tab = ttk.Frame(self.tab_control)
        self.tab_control.add(self.preprocess_tab, text="预处理")
        self.setup_preprocess_tab()

        # 分析选项卡
        self.analysis_tab = ttk.Frame(self.tab_control)
        self.tab_control.add(self.analysis_tab, text="分析")
        self.setup_analysis_tab()

        # ROI选项卡
        self.roi_tab = ttk.Frame(self.tab_control)
        self.tab_control.add(self.roi_tab, text="ROI")
        self.setup_roi_tab()

        # 设置选项卡
        self.settings_tab = ttk.Frame(self.tab_control)
        self.tab_control.add(self.settings_tab, text="设置")
        self.setup_settings_tab()

        # 状态信息
        self.status_frame = ttk.Frame(self.control_frame)
        self.status_frame.pack(fill=tk.X, padx=5, pady=5)

        # 进度条
        self.progress = ttk.Progressbar(self.status_frame, mode='determinate')
        self.progress.pack(fill=tk.X, pady=2)

        # 状态标签
        self.status_label = ttk.Label(self.status_frame, text="就绪")
        self.status_label.pack(anchor=tk.W)

    def setup_data_tab(self):
        """设置数据加载选项卡"""
        # DICOM加载
        dicom_frame = ttk.LabelFrame(self.data_tab, text="DICOM数据")
        dicom_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Button(dicom_frame,
                  text="选择DICOM目录",
                  command=self.load_dicom_directory).pack(fill=tk.X, padx=5, pady=5)

        # NIfTI加载
        nifti_frame = ttk.LabelFrame(self.data_tab, text="NIfTI数据")
        nifti_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Button(nifti_frame,
                  text="加载DTI数据",
                  command=self.load_nifti_file).pack(fill=tk.X, padx=5, pady=2)

        ttk.Button(nifti_frame,
                  text="加载b值文件",
                  command=self.load_bval_file).pack(fill=tk.X, padx=5, pady=2)

        ttk.Button(nifti_frame,
                  text="加载梯度方向文件",
                  command=self.load_bvec_file).pack(fill=tk.X, padx=5, pady=2)

        # 数据信息显示
        info_frame = ttk.LabelFrame(self.data_tab, text="数据信息")
        info_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        self.data_info_text = tk.Text(info_frame, height=10, width=30,
                                     font=("Courier", 9), wrap=tk.WORD)
        self.data_info_text.pack(fill=tk.BOTH, expand=True, padx=2, pady=2)
        self.data_info_text.insert(tk.END, "未加载数据")
        self.data_info_text.config(state=tk.DISABLED)

    def setup_preprocess_tab(self):
        """设置预处理选项卡"""
        # 预处理选项
        options_frame = ttk.LabelFrame(self.preprocess_tab, text="预处理选项")
        options_frame.pack(fill=tk.X, padx=5, pady=5)

        # 生成掩码
        self.create_mask_var = tk.BooleanVar(value=True)
        ttk.Checkbutton(options_frame,
                       text="生成脑/脊髓掩码",
                       variable=self.create_mask_var).pack(anchor=tk.W, padx=5, pady=2)

        # 运动校正
        self.motion_correction_var = tk.BooleanVar(value=True)
        ttk.Checkbutton(options_frame,
                       text="运动校正",
                       variable=self.motion_correction_var).pack(anchor=tk.W, padx=5, pady=2)

        # 涡流校正
        self.eddy_correction_var = tk.BooleanVar(value=True)
        ttk.Checkbutton(options_frame,
                       text="涡流校正",
                       variable=self.eddy_correction_var).pack(anchor=tk.W, padx=5, pady=2)

        # 去噪
        self.denoising_var = tk.BooleanVar(value=True)
        ttk.Checkbutton(options_frame,
                       text="图像去噪",
                       variable=self.denoising_var).pack(anchor=tk.W, padx=5, pady=2)

        # 偏置场校正
        self.bias_correction_var = tk.BooleanVar(value=False)
        ttk.Checkbutton(options_frame,
                       text="偏置场校正",
                       variable=self.bias_correction_var).pack(anchor=tk.W, padx=5, pady=2)

        # 执行预处理按钮
        ttk.Button(self.preprocess_tab,
                  text="执行预处理",
                  command=self.start_preprocessing).pack(fill=tk.X, padx=5, pady=5)

        # 预处理流程图
        workflow_frame = ttk.LabelFrame(self.preprocess_tab, text="预处理流程")
        workflow_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # 添加流程图（简化版）
        steps = [
            "原始数据",
            "↓",
            "脑/脊髓掩码",
            "↓",
            "运动校正",
            "↓",
            "涡流校正",
            "↓",
            "去噪",
            "↓",
            "偏置场校正",
            "↓",
            "预处理完成"
        ]

        for step in steps:
            ttk.Label(workflow_frame, text=step, anchor=tk.CENTER).pack(fill=tk.X)

    def setup_analysis_tab(self):
        """设置分析选项卡"""
        # DTI分析选项
        dti_frame = ttk.LabelFrame(self.analysis_tab, text="DTI张量分析")
        dti_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Button(dti_frame,
                  text="执行DTI分析",
                  command=self.start_dti_analysis).pack(fill=tk.X, padx=5, pady=5)

        # DTI参数显示选项
        param_frame = ttk.LabelFrame(dti_frame, text="DTI参数图")
        param_frame.pack(fill=tk.X, padx=5, pady=5)

        self.dti_map_var = tk.StringVar(value="fa")
        ttk.Radiobutton(param_frame, text="FA",
                       variable=self.dti_map_var,
                       value="fa",
                       command=self.update_dti_map).pack(anchor=tk.W)
        ttk.Radiobutton(param_frame, text="MD",
                       variable=self.dti_map_var,
                       value="md",
                       command=self.update_dti_map).pack(anchor=tk.W)
        ttk.Radiobutton(param_frame, text="AD",
                       variable=self.dti_map_var,
                       value="ad",
                       command=self.update_dti_map).pack(anchor=tk.W)
        ttk.Radiobutton(param_frame, text="RD",
                       variable=self.dti_map_var,
                       value="rd",
                       command=self.update_dti_map).pack(anchor=tk.W)
        ttk.Radiobutton(param_frame, text="RGB",
                       variable=self.dti_map_var,
                       value="rgb",
                       command=self.update_dti_map).pack(anchor=tk.W)

        # 纤维追踪选项
        tracking_frame = ttk.LabelFrame(self.analysis_tab, text="纤维追踪")
        tracking_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Label(tracking_frame, text="FA阈值:").pack(anchor=tk.W, padx=5)
        self.fa_threshold = ttk.Entry(tracking_frame)
        self.fa_threshold.insert(0, "0.2")
        self.fa_threshold.pack(fill=tk.X, padx=5, pady=2)

        ttk.Label(tracking_frame, text="最大角度 (度):").pack(anchor=tk.W, padx=5)
        self.angle_threshold = ttk.Entry(tracking_frame)
        self.angle_threshold.insert(0, "45")
        self.angle_threshold.pack(fill=tk.X, padx=5, pady=2)

        ttk.Label(tracking_frame, text="步长 (mm):").pack(anchor=tk.W, padx=5)
        self.step_size = ttk.Entry(tracking_frame)
        self.step_size.insert(0, "0.5")
        self.step_size.pack(fill=tk.X, padx=5, pady=2)

        ttk.Button(tracking_frame,
                  text="开始纤维追踪",
                  command=self.start_fiber_tracking).pack(fill=tk.X, padx=5, pady=5)

    def setup_roi_tab(self):
        """设置ROI选项卡"""
        # ROI工具
        tools_frame = ttk.LabelFrame(self.roi_tab, text="ROI工具")
        tools_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Button(tools_frame,
                  text="创建矩形ROI",
                  command=lambda: self.create_roi("rectangle")).pack(fill=tk.X, padx=5, pady=2)

        ttk.Button(tools_frame,
                  text="创建圆形ROI",
                  command=lambda: self.create_roi("circle")).pack(fill=tk.X, padx=5, pady=2)

        ttk.Button(tools_frame,
                  text="自由绘制ROI",
                  command=lambda: self.create_roi("freehand")).pack(fill=tk.X, padx=5, pady=2)

        # ROI管理
        manage_frame = ttk.LabelFrame(self.roi_tab, text="ROI管理")
        manage_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Label(manage_frame, text="ROI名称:").pack(anchor=tk.W, padx=5)
        self.roi_name_entry = ttk.Entry(manage_frame)
        self.roi_name_entry.insert(0, "ROI_1")
        self.roi_name_entry.pack(fill=tk.X, padx=5, pady=2)

        self.roi_color_button = ttk.Button(
            manage_frame,
            text="选择ROI颜色",
            command=self.choose_roi_color
        )
        self.roi_color_button.pack(fill=tk.X, padx=5, pady=2)

        ttk.Label(manage_frame, text="已创建的ROI:").pack(anchor=tk.W, padx=5)

        self.roi_listbox = tk.Listbox(manage_frame, height=5)
        self.roi_listbox.pack(fill=tk.X, padx=5, pady=2)
        self.roi_listbox.bind('<<ListboxSelect>>', self.select_roi)

        roi_button_frame = ttk.Frame(manage_frame)
        roi_button_frame.pack(fill=tk.X, padx=5, pady=2)

        ttk.Button(roi_button_frame,
                  text="删除ROI",
                  command=self.delete_roi).pack(side=tk.LEFT, padx=2)

        ttk.Button(roi_button_frame,
                  text="分析ROI",
                  command=self.analyze_roi).pack(side=tk.LEFT, padx=2)

        ttk.Button(roi_button_frame,
                  text="保存ROI",
                  command=self.save_roi).pack(side=tk.LEFT, padx=2)

        # ROI统计数据
        stats_frame = ttk.LabelFrame(self.roi_tab, text="ROI统计")
        stats_frame.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        self.roi_stats_text = tk.Text(stats_frame, height=8, width=30,
                                     font=("Courier", 9), wrap=tk.WORD)
        self.roi_stats_text.pack(fill=tk.BOTH, expand=True, padx=2, pady=2)
        self.roi_stats_text.insert(tk.END, "未选择ROI")
        self.roi_stats_text.config(state=tk.DISABLED)

    def setup_settings_tab(self):
        """设置设置选项卡"""
        # 显示设置
        display_frame = ttk.LabelFrame(self.settings_tab, text="显示设置")
        display_frame.pack(fill=tk.X, padx=5, pady=5)

        # 颜色映射选择
        ttk.Label(display_frame, text="颜色映射:").pack(anchor=tk.W, padx=5)
        self.colormap_var = tk.StringVar(value="jet")
        colormap_options = ["viridis", "jet", "hot", "coolwarm", "spectral", "RdBu", "bwr"]
        colormap_menu = ttk.Combobox(display_frame, textvariable=self.colormap_var,
                                   values=colormap_options)
        colormap_menu.pack(fill=tk.X, padx=5, pady=2)
        colormap_menu.bind("<<ComboboxSelected>>", self.update_display)

        # 窗宽窗位调整
        window_frame = ttk.Frame(display_frame)
        window_frame.pack(fill=tk.X, padx=5, pady=2)

        ttk.Label(window_frame, text="窗宽:").pack(side=tk.LEFT)
        self.window_scale = ttk.Scale(window_frame,
                                    from_=0, to=2000,
                                    orient=tk.HORIZONTAL,
                                    command=self.update_window)
        self.window_scale.set(500)
        self.window_scale.pack(side=tk.LEFT, fill=tk.X, expand=True)

        level_frame = ttk.Frame(display_frame)
        level_frame.pack(fill=tk.X, padx=5, pady=2)

        ttk.Label(level_frame, text="窗位:").pack(side=tk.LEFT)
        self.level_scale = ttk.Scale(level_frame,
                                   from_=-1000, to=1000,
                                   orient=tk.HORIZONTAL,
                                   command=self.update_window)
        self.level_scale.set(0)
        self.level_scale.pack(side=tk.LEFT, fill=tk.X, expand=True)

        # 切片控制
        slice_frame = ttk.Frame(display_frame)
        slice_frame.pack(fill=tk.X, padx=5, pady=2)

        ttk.Label(slice_frame, text="切片:").pack(side=tk.LEFT)
        self.slice_scale = ttk.Scale(slice_frame,
                                   from_=0, to=100,
                                   orient=tk.HORIZONTAL,
                                   command=self.update_slice)
        self.slice_scale.pack(side=tk.LEFT, fill=tk.X, expand=True)

        # 视图设置
        view_frame = ttk.LabelFrame(self.settings_tab, text="视图设置")
        view_frame.pack(fill=tk.X, padx=5, pady=5)

        self.view_var = tk.StringVar(value="axial")
        ttk.Radiobutton(view_frame, text="轴向",
                       variable=self.view_var,
                       value="axial",
                       command=self.update_view).pack(anchor=tk.W, padx=5)
        ttk.Radiobutton(view_frame, text="矢状位",
                       variable=self.view_var,
                       value="sagittal",
                       command=self.update_view).pack(anchor=tk.W, padx=5)
        ttk.Radiobutton(view_frame, text="冠状位",
                       variable=self.view_var,
                       value="coronal",
                       command=self.update_view).pack(anchor=tk.W, padx=5)

        # 保存设置
        save_frame = ttk.LabelFrame(self.settings_tab, text="保存设置")
        save_frame.pack(fill=tk.X, padx=5, pady=5)

        ttk.Button(save_frame,
                  text="设置保存路径",
                  command=self.set_save_path).pack(fill=tk.X, padx=5, pady=5)

        self.save_path_var = tk.StringVar(value=os.path.join(os.path.expanduser("~"), "DTI_Results"))
        save_path_entry = ttk.Entry(save_frame, textvariable=self.save_path_var)
        save_path_entry.pack(fill=tk.X, padx=5, pady=2)

    def setup_vtk(self):
        """设置VTK渲染环境"""
        # 分割右侧显示区域
        self.display_paned = ttk.PanedWindow(self.display_frame, orient=tk.VERTICAL)
        self.display_paned.pack(fill=tk.BOTH, expand=True)

        # 上部：VTK渲染区域
        self.render_frame = ttk.Frame(self.display_paned)

        # 下部：图表区域
        self.chart_frame = ttk.Frame(self.display_paned)

        self.display_paned.add(self.render_frame, weight=3)
        self.display_paned.add(self.chart_frame, weight=1)

        # 设置VTK渲染器
        self.renderer = vtk.vtkRenderer()
        self.renderer.SetBackground(0.2, 0.2, 0.2)

        self.render_window = vtk.vtkRenderWindow()
        self.render_window.AddRenderer(self.renderer)

        self.render_widget = vtkTkRenderWindowInteractor(
            self.render_frame,
            rw=self.render_window,
            width=800,
            height=500
        )
        self.render_widget.pack(fill=tk.BOTH, expand=True)

        self.interactor = self.render_widget.GetRenderWindow().GetInteractor()

        # 设置交互样式
        self.interaction_style = vtk.vtkInteractorStyleImage()
        self.interactor.SetInteractorStyle(self.interaction_style)

        # 图表区域初始设置
        self.setup_chart()

    def setup_chart(self):
        """设置图表区域"""
        # 创建空的图表
        self.figure = plt.Figure(figsize=(8, 3), dpi=100)
        self.ax = self.figure.add_subplot(111)
        self.ax.set_title("DTI指标分布")
        self.ax.set_xlabel("值域")
        self.ax.set_ylabel("频率")

        self.canvas = FigureCanvasTkAgg(self.figure, self.chart_frame)
        self.canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def setup_interactor(self):
        """设置VTK交互器"""
        self.interactor.Initialize()
        self.interactor.Start()

        # 设置键盘和鼠标回调
        self.interaction_style.AddObserver("KeyPressEvent", self.key_press_callback)
        self.interaction_style.AddObserver("MouseWheelForwardEvent", self.mouse_wheel_callback)
        self.interaction_style.AddObserver("MouseWheelBackwardEvent", self.mouse_wheel_callback)

    def key_press_callback(self, obj, event):
        """键盘按键回调"""
        key = self.interactor.GetKeySym()

        if key == "Up":
            self.change_slice(1)
        elif key == "Down":
            self.change_slice(-1)
        elif key == "Right":
            self.change_volume(1)
        elif key == "Left":
            self.change_volume(-1)

    def mouse_wheel_callback(self, obj, event):
        """鼠标滚轮回调"""
        if obj.GetClassName() == "vtkInteractorStyleImage":
            if event == "MouseWheelForwardEvent":
                self.change_slice(1)
            else:
                self.change_slice(-1)

    def change_slice(self, step):
        """改变当前切片"""
        if self.volume_data is not None:
            max_slice = self.get_max_slice()
            new_slice = self.current_slice + step

            if 0 <= new_slice <= max_slice:
                self.current_slice = new_slice
                self.slice_scale.set(new_slice)
                self.update_display()

    def change_volume(self, step):
        """改变当前体积（对于多体积数据）"""
        if self.volume_data is not None:
            max_volume = self.volume_data.shape[0] - 1
            new_volume = self.current_volume + step

            if 0 <= new_volume <= max_volume:
                self.current_volume = new_volume
                self.update_display()

    def get_max_slice(self):
        """获取当前视图的最大切片索引"""
        if self.volume_data is None:
            return 0

        if self.current_view == "axial":
            return self.volume_data.shape[1] - 1
        elif self.current_view == "sagittal":
            return self.volume_data.shape[2] - 1
        elif self.current_view == "coronal":
            return self.volume_data.shape[3] - 1

        return 0

    def check_processing_queue(self):
        """检查处理队列"""
        try:
            while True:
                progress, message = self.processing_queue.get_nowait()
                self.update_progress(progress, message)
        except queue.Empty:
            pass

        # 继续检查
        self.root.after(100, self.check_processing_queue)

    def update_progress(self, progress, message=None):
        """更新进度条和状态信息"""
        self.progress["value"] = progress
        if message:
            self.status_label.config(text=message)

        # 强制更新UI
        self.root.update_idletasks()

    def load_dicom_directory(self):
        """加载DICOM目录"""
        dicom_dir = filedialog.askdirectory(title="选择DICOM目录")
        if not dicom_dir:
            return

        try:
            self.update_progress(0, "加载DICOM数据...")

            # 在新线程中处理
            self.processing_thread = threading.Thread(
                target=self.process_dicom_data,
                args=(dicom_dir,)
            )
            self.processing_thread.daemon = True
            self.processing_thread.start()

        except Exception as e:
            logger.error(f"加载DICOM数据失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"加载DICOM数据失败: {str(e)}")

    def process_dicom_data(self, dicom_dir):
        """处理DICOM数据"""
        try:
            # 进度回调函数
            def progress_callback(progress, message):
                self.processing_queue.put((progress, message))

            # 加载DICOM序列
            self.volume_data, self.bvals, self.bvecs, self.metadata = self.processor.load_dicom_series(
                dicom_dir, progress_callback)

            # 保存原始数据
            self.processed_data = self.volume_data.copy()

            # 更新显示
            self.root.after(0, self.update_data_info)
            self.root.after(0, self.update_display)

        except Exception as e:
            logger.error(f"处理DICOM数据失败: {str(e)}", exc_info=True)
            self.root.after(0, lambda: messagebox.showerror(
                "错误", f"处理DICOM数据失败: {str(e)}"))
            self.processing_queue.put((0, "处理失败"))

    def load_nifti_file(self):
        """加载NIfTI文件"""
        nifti_file = filedialog.askopenfilename(
            title="选择NIfTI文件",
            filetypes=[("NIfTI文件", "*.nii *.nii.gz")]
        )
        if not nifti_file:
            return

        try:
            self.update_progress(0, "加载NIfTI数据...")

            # 加载NIfTI数据
            data, affine = load_nifti(nifti_file)

            # 调整数据维度
            if len(data.shape) == 3:
                # 单体积数据
                self.volume_data = np.array([data])
            else:
                # 多体积数据
                self.volume_data = data

            # 在数据信息中更新文件路径
            self.metadata = {'nifti_file': nifti_file}

            # 保存原始数据
            self.processed_data = self.volume_data.copy()

            self.update_progress(100, "NIfTI数据加载完成")

            # 更新显示
            self.update_data_info()
            self.update_display()

            # 提示加载b值和梯度方向
            messagebox.showinfo("提示", "请记得加载相应的b值和梯度方向文件")

        except Exception as e:
            logger.error(f"加载NIfTI数据失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"加载NIfTI数据失败: {str(e)}")
            self.update_progress(0, "加载失败")

    def load_bval_file(self):
        """加载b值文件"""
        bval_file = filedialog.askopenfilename(
            title="选择b值文件",
            filetypes=[("b值文件", "*.bval"), ("文本文件", "*.txt"), ("所有文件", "*.*")]
        )
        if not bval_file:
            return

        try:
            # 读取b值文件
            with open(bval_file, 'r') as f:
                bvals_str = f.read().strip()
                self.bvals = np.array([float(val) for val in bvals_str.split()])

            # 更新元数据
            if hasattr(self, 'metadata'):
                self.metadata['bval_file'] = bval_file
            else:
                self.metadata = {'bval_file': bval_file}

            # 检查b值数量与体积数量是否匹配
            if self.volume_data is not None and len(self.bvals) != self.volume_data.shape[0]:
                messagebox.showwarning(
                    "警告",
                    f"b值数量 ({len(self.bvals)}) 与体积数量 ({self.volume_data.shape[0]}) 不匹配"
                )

            self.update_data_info()
            messagebox.showinfo("成功", "b值数据加载完成")

        except Exception as e:
            logger.error(f"加载b值数据失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"加载b值数据失败: {str(e)}")

    def load_bvec_file(self):
        """加载梯度方向文件"""
        bvec_file = filedialog.askopenfilename(
            title="选择梯度方向文件",
            filetypes=[("梯度方向文件", "*.bvec"), ("文本文件", "*.txt"), ("所有文件", "*.*")]
        )
        if not bvec_file:
            return

        try:
            # 读取梯度方向文件
            with open(bvec_file, 'r') as f:
                bvecs_content = f.readlines()

            if len(bvecs_content) == 3:
                # FSL格式 (3xN)
                x = np.array([float(val) for val in bvecs_content[0].strip().split()])
                y = np.array([float(val) for val in bvecs_content[1].strip().split()])
                z = np.array([float(val) for val in bvecs_content[2].strip().split()])
                self.bvecs = np.vstack((x, y, z)).T
            else:
                # 假定是Nx3格式
                bvecs = []
                for line in bvecs_content:
                    bvecs.append([float(val) for val in line.strip().split()])
                self.bvecs = np.array(bvecs)

            # 更新元数据
            if hasattr(self, 'metadata'):
                self.metadata['bvec_file'] = bvec_file
            else:
                self.metadata = {'bvec_file': bvec_file}

            # 检查梯度方向数量与体积数量是否匹配
            if self.volume_data is not None and len(self.bvecs) != self.volume_data.shape[0]:
                messagebox.showwarning(
                    "警告",
                    f"梯度方向数量 ({len(self.bvecs)}) 与体积数量 ({self.volume_data.shape[0]}) 不匹配"
                )

            self.update_data_info()
            messagebox.showinfo("成功", "梯度方向数据加载完成")

        except Exception as e:
            logger.error(f"加载梯度方向数据失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"加载梯度方向数据失败: {str(e)}")

    def update_data_info(self):
        """更新数据信息显示"""
        if self.volume_data is None:
            return

        # 准备信息文本
        info_text = "数据信息:\n"
        info_text += f"维度: {self.volume_data.shape}\n"
        info_text += f"数据类型: {self.volume_data.dtype}\n"

        # 显示b值信息
        if hasattr(self, 'bvals') and self.bvals is not None:
            unique_bvals = np.unique(self.bvals)
            info_text += f"b值: {unique_bvals.tolist()}\n"
            info_text += f"方向数: {len(self.bvals)}\n"

        # 显示元数据
        if hasattr(self, 'metadata') and self.metadata:
            info_text += "\n元数据:\n"
            for key, value in self.metadata.items():
                # 跳过复杂对象和长文本
                if isinstance(value, (str, int, float)):
                    if isinstance(value, str) and len(value) > 30:
                        value = value[:27] + "..."
                    info_text += f"{key}: {value}\n"

        # 更新文本控件
        self.data_info_text.config(state=tk.NORMAL)
        self.data_info_text.delete(1.0, tk.END)
        self.data_info_text.insert(tk.END, info_text)
        self.data_info_text.config(state=tk.DISABLED)

    def start_preprocessing(self):
        """开始预处理"""
        if self.volume_data is None:
            messagebox.showwarning("警告", "请先加载数据")
            return

        # 获取预处理选项
        options = {
            'create_mask': self.create_mask_var.get(),
            'motion_correction': self.motion_correction_var.get(),
            'eddy_correction': self.eddy_correction_var.get(),
            'denoising': self.denoising_var.get(),
            'bias_correction': self.bias_correction_var.get()
        }

        # 确认预处理
        enabled_options = [k for k, v in options.items() if v]
        if not messagebox.askyesno("确认",
                                 f"将执行以下预处理: {', '.join(enabled_options)}\n继续?"):
            return

        self.update_progress(0, "准备预处理...")

        # 在新线程中进行预处理
        self.processing_thread = threading.Thread(
            target=self.run_preprocessing,
            args=(options,)
        )
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def run_preprocessing(self, options):
        """执行预处理"""
        try:
            # 进度回调
            def progress_callback(progress, message):
                self.processing_queue.put((progress, message))

            # 执行预处理
            self.processed_data = self.preprocessor.preprocess(
                self.volume_data,
                self.bvals,
                self.bvecs,
                options,
                progress_callback
            )

            # 更新mask
            self.mask = self.preprocessor.mask

            # 完成后更新显示
            self.root.after(0, lambda: self.update_display(use_processed=True))

        except Exception as e:
            logger.error(f"预处理失败: {str(e)}", exc_info=True)
            self.root.after(0, lambda: messagebox.showerror(
                "错误", f"预处理失败: {str(e)}"))
            self.processing_queue.put((0, "预处理失败"))

    def start_dti_analysis(self):
        """开始DTI分析"""
        if self.processed_data is None:
            messagebox.showwarning("警告", "请先加载并预处理数据")
            return

        if self.bvals is None or self.bvecs is None:
            messagebox.showwarning("警告", "请确保已加载b值和梯度方向数据")
            return

        self.update_progress(0, "准备DTI分析...")

        # 在新线程中进行DTI分析
        self.processing_thread = threading.Thread(
            target=self.run_dti_analysis
        )
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def run_dti_analysis(self):
        """执行DTI分析"""
        try:
            # 进度回调
            def progress_callback(progress, message):
                self.processing_queue.put((progress, message))

            # 设置梯度表
            self.analyzer.setup(self.bvals, self.bvecs)

            # 执行张量拟合
            results = self.analyzer.fit_tensor(
                self.processed_data,
                self.mask,
                progress_callback
            )

            # 完成后更新显示
            self.root.after(0, lambda: self.update_dti_display(results))

        except Exception as e:
            logger.error(f"DTI分析失败: {str(e)}", exc_info=True)
            self.root.after(0, lambda: messagebox.showerror(
                "错误", f"DTI分析失败: {str(e)}"))
            self.processing_queue.put((0, "DTI分析失败"))

    def update_dti_display(self, results):
        """更新DTI结果显示"""
        # 获取当前选择的DTI参数图
        map_type = self.dti_map_var.get()

        # 显示参数图
        self.display_dti_map(results[map_type], map_type)

        # 更新直方图
        self.update_histogram(results[map_type], map_type)

        # 更新状态
        self.update_progress(100, f"显示{map_type.upper()}图")

    def display_dti_map(self, data, map_type):
        """显示DTI参数图"""
        # 清除当前渲染器
        self.renderer.RemoveAllViewProps()

        # 获取当前切片
        if self.current_view == "axial":
            slice_data = data[:, self.current_slice, :]
        elif self.current_view == "sagittal":
            slice_data = data[:, :, self.current_slice]
        elif self.current_view == "coronal":
            slice_data = data[self.current_slice, :, :]

        # 创建VTK图像数据
        vtk_image = self.create_vtk_image(slice_data, map_type)

        # 创建图像actor
        actor = vtk.vtkImageActor()
        actor.SetInputData(vtk_image)

        # 添加到渲染器
        self.renderer.AddActor(actor)

        # 重置相机
        self.renderer.ResetCamera()

        # 渲染
        self.render_window.Render()

    def create_vtk_image(self, data, map_type):
        """创建VTK图像对象"""
        # 创建VTK图像数据
        vtk_image = vtk.vtkImageData()
        vtk_image.SetDimensions(data.shape[1], data.shape[0], 1)
        vtk_image.AllocateScalars(vtk.VTK_FLOAT, 1)

        # 将numpy数组拷贝到VTK图像
        for y in range(data.shape[0]):
            for x in range(data.shape[1]):
                vtk_image.SetScalarComponentFromFloat(x, y, 0, 0, data[y, x])

        # 创建颜色映射
        vtk_colormap = vtk.vtkLookupTable()

        # 设置颜色映射
        if map_type == "fa":
            vtk_colormap.SetHueRange(0.0, 0.667)  # 红到蓝
        elif map_type == "md" or map_type == "ad" or map_type == "rd":
            vtk_colormap.SetHueRange(0.0, 0.333)  # 红到黄
        else:
            vtk_colormap.SetHueRange(0.0, 1.0)  # 全色谱

        vtk_colormap.SetSaturationRange(1.0, 1.0)
        vtk_colormap.SetValueRange(1.0, 1.0)
        vtk_colormap.SetTableRange(0.0, 1.0 if map_type == "fa" else 0.003)
        vtk_colormap.Build()

        # 应用颜色映射
        vtk_color_mapper = vtk.vtkImageMapToColors()
        vtk_color_mapper.SetLookupTable(vtk_colormap)
        vtk_color_mapper.SetInputData(vtk_image)
        vtk_color_mapper.Update()

        return vtk_color_mapper.GetOutput()

    def update_histogram(self, data, map_type):
        """更新直方图显示"""
        # 清除当前图表
        self.ax.clear()

        # 计算直方图
        values = data.flatten()
        values = values[~np.isnan(values)]  # 移除NaN值
        values = values[values > 0]  # 移除零值

        if len(values) > 0:
            # 设置范围
            if map_type == "fa":
                range_max = 1.0
            else:  # md, ad, rd
                range_max = np.percentile(values, 99)  # 使用99百分位数避免离群点

            # 绘制直方图
            n, bins, patches = self.ax.hist(values, bins=50, range=(0, range_max),
                                          alpha=0.7, color=self.accent_color)

            # 添加平均值线
            mean_val = np.mean(values)
            self.ax.axvline(mean_val, color='r', linestyle='dashed', linewidth=1)
            self.ax.text(mean_val*1.05, max(n)*0.9, f'Mean: {mean_val:.4f}',
                        color='r', fontsize=9)

            # 设置标题和标签
            self.ax.set_title(f"{map_type.upper()} 值分布")
            self.ax.set_xlabel(f"{map_type.upper()} 值")
            self.ax.set_ylabel("频率")

            # 更新图表
            self.canvas.draw()

    def update_dti_map(self):
        """更新DTI参数图显示"""
        if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
            return

        # 获取当前选择的DTI参数图
        map_type = self.dti_map_var.get()

        # 获取相应的数据
        if map_type == "fa":
            data = self.analyzer.fa
        elif map_type == "md":
            data = self.analyzer.md
        elif map_type == "ad":
            data = self.analyzer.ad
        elif map_type == "rd":
            data = self.analyzer.rd
        elif map_type == "rgb":
            data = self.analyzer.rgb

        # 显示参数图
        self.display_dti_map(data, map_type)

        # 更新直方图
        self.update_histogram(data, map_type)

    def start_fiber_tracking(self):
        """开始纤维追踪"""
        if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
            messagebox.showwarning("警告", "请先进行DTI分析")
            return

        try:
            # 获取参数
            fa_threshold = float(self.fa_threshold.get())
            max_angle = float(self.angle_threshold.get())
            step_size = float(self.step_size.get())
        except ValueError:
            messagebox.showerror("错误", "请输入有效的数值参数")
            return

        self.update_progress(0, "准备纤维追踪...")

        # 在新线程中进行纤维追踪
        self.processing_thread = threading.Thread(
            target=self.run_fiber_tracking,
            args=(fa_threshold, max_angle, step_size)
        )
        self.processing_thread.daemon = True
        self.processing_thread.start()

    def run_fiber_tracking(self, fa_threshold, max_angle, step_size):
        """执行纤维追踪"""
        try:
            # 进度回调
            def progress_callback(progress, message):
                self.processing_queue.put((progress, message))

            # 执行纤维追踪
            self.streamlines = self.analyzer.track_fibers(
                fa_threshold,
                max_angle,
                step_size,
                progress_callback
            )

            # 完成后显示结果
            self.root.after(0, self.display_streamlines)

        except Exception as e:
            logger.error(f"纤维追踪失败: {str(e)}", exc_info=True)
            self.root.after(0, lambda: messagebox.showerror(
                "错误", f"纤维追踪失败: {str(e)}"))
            self.processing_queue.put((0, "纤维追踪失败"))

    def display_streamlines(self):
        """显示纤维束"""
        # 使用VTK显示纤维束
        streamlines_vtk = vtk.vtkPolyData()
        points = vtk.vtkPoints()
        lines = vtk.vtkCellArray()

        point_id = 0
        for streamline in self.streamlines:
            line = vtk.vtkPolyLine()
            line.GetPointIds().SetNumberOfIds(len(streamline))

            for i, point in enumerate(streamline):
                points.InsertNextPoint(point[0], point[1], point[2])
                line.GetPointIds().SetId(i, point_id)
                point_id += 1

            lines.InsertNextCell(line)

        streamlines_vtk.SetPoints(points)
        streamlines_vtk.SetLines(lines)

        # 创建颜色映射
        vtk_colors = vtk.vtkUnsignedCharArray()
        vtk_colors.SetNumberOfComponents(3)
        vtk_colors.SetName("Colors")

        # 为每条纤维分配一个颜色
        for i in range(len(self.streamlines)):
            r, g, b = self.get_random_color()
            for j in range(len(self.streamlines[i])):
                vtk_colors.InsertNextTuple3(r, g, b)

        streamlines_vtk.GetPointData().SetScalars(vtk_colors)

        # 创建纤维actor
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputData(streamlines_vtk)

        actor = vtk.vtkActor()
        actor.SetMapper(mapper)

        # 清除当前场景
        self.renderer.RemoveAllViewProps()

        # 添加FA背景（中间切片）
        if self.analyzer.fa is not None:
            # 创建轴向FA背景
            axial_slice = self.analyzer.fa.shape[0] // 2
            self.add_background_slice(self.analyzer.fa, axial_slice, "axial")

            # 创建矢状位FA背景
            sagittal_slice = self.analyzer.fa.shape[2] // 2
            self.add_background_slice(self.analyzer.fa, sagittal_slice, "sagittal")

            # 创建冠状位FA背景
            coronal_slice = self.analyzer.fa.shape[1] // 2
            self.add_background_slice(self.analyzer.fa, coronal_slice, "coronal")

        # 添加纤维束
        self.renderer.AddActor(actor)

        # 更改交互样式为3D
        self.interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())

        # 设置视角
        self.renderer.ResetCamera()

        # 渲染
        self.render_window.Render()

        # 更新状态
        self.update_progress(100, f"显示{len(self.streamlines)}条纤维")

    def add_background_slice(self, data, slice_index, orientation):
        """添加背景切片"""
        if orientation == "axial":
            slice_data = data[slice_index, :, :]
            # 创建平面
            plane = vtk.vtkPlaneSource()
            plane.SetOrigin(0, 0, slice_index)
            plane.SetPoint1(data.shape[2], 0, slice_index)
            plane.SetPoint2(0, data.shape[1], slice_index)

        elif orientation == "sagittal":
            slice_data = data[:, :, slice_index]
            # 创建平面
            plane = vtk.vtkPlaneSource()
            plane.SetOrigin(slice_index, 0, 0)
            plane.SetPoint1(slice_index, data.shape[1], 0)
            plane.SetPoint2(slice_index, 0, data.shape[0])

        elif orientation == "coronal":
            slice_data = data[:, slice_index, :]
            # 创建平面
            plane = vtk.vtkPlaneSource()
            plane.SetOrigin(0, slice_index, 0)
            plane.SetPoint1(data.shape[2], slice_index, 0)
            plane.SetPoint2(0, slice_index, data.shape[0])

        # 创建VTK图像数据
        vtk_image = self.create_vtk_image(slice_data, "fa")

        # 创建纹理
        texture = vtk.vtkTexture()
        texture.SetInputData(vtk_image)
        texture.InterpolateOn()

        # 创建映射器
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputConnection(plane.GetOutputPort())

        # 创建actor
        actor = vtk.vtkActor()
        actor.SetMapper(mapper)
        actor.SetTexture(texture)
        actor.GetProperty().SetOpacity(0.5)  # 半透明

        # 添加到渲染器
        self.renderer.AddActor(actor)

    def get_random_color(self):
        """生成随机颜色"""
        return np.random.randint(100, 255), np.random.randint(100, 255), np.random.randint(100, 255)

    def create_roi(self, roi_type):
        """创建ROI"""
        if self.volume_data is None:
            messagebox.showwarning("警告", "请先加载数据")
            return

        # 获取ROI名称
        roi_name = self.roi_name_entry.get()
        if not roi_name:
            messagebox.showwarning("警告", "请输入ROI名称")
            return

        # 检查名称是否已存在
        if roi_name in self.roi_masks:
            if not messagebox.askyesno("警告", f"ROI '{roi_name}' 已存在，是否覆盖?"):
                return

        # 设置当前ROI名称
        self.current_roi_name = roi_name

        # 根据ROI类型创建不同的交互器
        if roi_type == "rectangle":
            self.start_rectangle_roi()
        elif roi_type == "circle":
            self.start_circle_roi()
        elif roi_type == "freehand":
            self.start_freehand_roi()

    def start_rectangle_roi(self):
        """开始矩形ROI绘制"""
        # 更改交互样式
        self.interaction_style = vtk.vtkInteractorStyleImage()
        self.interactor.SetInteractorStyle(self.interaction_style)

        # 创建矩形widget
        self.roi_widget = vtk.vtkBoxWidget2()
        self.roi_widget.SetInteractor(self.interactor)

        # 配置表示
        rep = vtk.vtkBoxRepresentation()
        self.roi_widget.SetRepresentation(rep)

        # 设置初始位置和大小
        bounds = [0, 100, 0, 100, 0, 1]
        rep.PlaceWidget(bounds)

        # 添加回调
        self.roi_callback = RoiCallback(self)
        self.roi_widget.AddObserver(vtk.vtkCommand.InteractionEvent, self.roi_callback)

        # 激活widget
        self.roi_widget.On()

        # 更新状态
        self.status_label.config(text="绘制矩形ROI：调整大小和位置，完成后右键点击")

    def start_circle_roi(self):
        """开始圆形ROI绘制"""
        # 更改交互样式
        self.interaction_style = vtk.vtkInteractorStyleImage()
        self.interactor.SetInteractorStyle(self.interaction_style)

        # 创建圆形widget
        self.roi_widget = vtk.vtkEllipticalSectorWidget()
        self.roi_widget.SetInteractor(self.interactor)

        # 设置初始位置和大小
        center = [50, 50, 0]
        self.roi_widget.SetCenter(center)
        self.roi_widget.SetRadius(20)

        # 添加回调
        self.roi_callback = RoiCallback(self)
        self.roi_widget.AddObserver(vtk.vtkCommand.InteractionEvent, self.roi_callback)

        # 激活widget
        self.roi_widget.On()

        # 更新状态
        self.status_label.config(text="绘制圆形ROI：调整位置和大小，完成后右键点击")

    def start_freehand_roi(self):
        """开始自由绘制ROI"""
        # 更改交互样式为自定义的绘制样式
        self.interaction_style = DrawPolylineStyle(self)
        self.interactor.SetInteractorStyle(self.interaction_style)

        # 更新状态
        self.status_label.config(text="自由绘制ROI：点击添加点，双击或右键完成")

    def finish_roi(self, points=None):
        """完成ROI创建"""
        if not self.current_roi_name:
            return

        try:
            # 创建ROI掩码
            if hasattr(self, 'analyzer') and self.analyzer.fa is not None:
                # 使用FA图的形状
                mask_shape = self.analyzer.fa.shape
            else:
                # 使用原始数据的形状
                mask_shape = self.volume_data[0].shape

            mask = np.zeros(mask_shape, dtype=bool)

            # 根据不同的ROI类型填充掩码
            if hasattr(self, 'roi_widget'):
                # 矩形或圆形ROI
                if isinstance(self.roi_widget, vtk.vtkBoxWidget2):
                    # 矩形ROI
                    rep = self.roi_widget.GetRepresentation()
                    bounds = [0, 0, 0, 0, 0, 0]
                    rep.GetBounds(bounds)

                    # 转换为整数坐标
                    x_min, x_max = int(bounds[0]), int(bounds[1])
                    y_min, y_max = int(bounds[2]), int(bounds[3])
                    z_min, z_max = int(bounds[4]), int(bounds[5])

                    # 根据当前视图设置掩码
                    if self.current_view == "axial":
                        mask[self.current_slice, y_min:y_max, x_min:x_max] = True
                    elif self.current_view == "sagittal":
                        mask[z_min:z_max, y_min:y_max, self.current_slice] = True
                    elif self.current_view == "coronal":
                        mask[z_min:z_max, self.current_slice, x_min:x_max] = True

                elif isinstance(self.roi_widget, vtk.vtkEllipticalSectorWidget):
                    # 圆形ROI
                    center = [0, 0, 0]
                    self.roi_widget.GetCenter(center)
                    radius = self.roi_widget.GetRadius()

                    # 生成圆形掩码
                    y, x, z = np.ogrid[:mask_shape[0], :mask_shape[1], :mask_shape[2]]

                    if self.current_view == "axial":
                        dist_from_center = np.sqrt((x - center[0])**2 + (y - center[1])**2)
                        circle_mask = dist_from_center <= radius
                        mask[self.current_slice, :, :] = circle_mask

                    elif self.current_view == "sagittal":
                        dist_from_center = np.sqrt((z - center[0])**2 + (y - center[1])**2)
                        circle_mask = dist_from_center <= radius
                        mask[:, :, self.current_slice] = circle_mask

                    elif self.current_view == "coronal":
                        dist_from_center = np.sqrt((x - center[0])**2 + (z - center[1])**2)
                        circle_mask = dist_from_center <= radius
                        mask[:, self.current_slice, :] = circle_mask

                # 清除widget
                self.roi_widget.Off()
                del self.roi_widget

            elif points is not None:
                # 自由绘制ROI
                # 创建多边形掩码
                if len(points) >= 3:
                    polygon = np.array(points)

                    # 获取当前视图中的所有点坐标
                    y, x = np.mgrid[:mask_shape[1], :mask_shape[2]]
                    coords = np.column_stack((x.flatten(), y.flatten()))

                    # 检查每个点是否在多边形内
                    from matplotlib.path import Path
                    poly_path = Path(polygon)
                    mask_2d = poly_path.contains_points(coords).reshape(mask_shape[1], mask_shape[2])

                    # 根据当前视图设置掩码
                    if self.current_view == "axial":
                        mask[self.current_slice] = mask_2d
                    elif self.current_view == "sagittal":
                        mask[:, :, self.current_slice] = mask_2d.T
                    elif self.current_view == "coronal":
                        mask[:, self.current_slice] = mask_2d

            # 保存ROI掩码
            self.roi_masks[self.current_roi_name] = mask

            # 更新ROI列表
            self.update_roi_list()

            # 分析ROI
            self.analyze_current_roi()

            # 显示ROI
            self.show_roi_overlay()

            # 更新状态
            self.status_label.config(text=f"ROI '{self.current_roi_name}' 创建完成")

        except Exception as e:
            logger.error(f"创建ROI失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"创建ROI失败: {str(e)}")
            self.status_label.config(text="创建ROI失败")

    def update_roi_list(self):
        """更新ROI列表"""
        # 清除列表
        self.roi_listbox.delete(0, tk.END)

        # 添加所有ROI
        for roi_name in self.roi_masks:
            self.roi_listbox.insert(tk.END, roi_name)

    def select_roi(self, event):
        """选择ROI"""
        selection = event.widget.curselection()
        if selection:
            index = selection[0]
            roi_name = self.roi_listbox.get(index)
            self.current_roi_name = roi_name

            # 分析所选ROI
            self.analyze_current_roi()

            # 显示ROI
            self.show_roi_overlay()

    def analyze_current_roi(self):
        """分析当前选择的ROI"""
        if not self.current_roi_name or self.current_roi_name not in self.roi_masks:
            return

        try:
            # 获取ROI掩码
            roi_mask = self.roi_masks[self.current_roi_name]

            # 检查是否已完成DTI分析
            if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
                # 还没有DTI结果，只计算体积
                volume_voxels = np.sum(roi_mask)

                stats_text = f"ROI: {self.current_roi_name}\n"
                stats_text += f"体积: {volume_voxels} 体素\n"

                # 更新统计文本
                self.roi_stats_text.config(state=tk.NORMAL)
                self.roi_stats_text.delete(1.0, tk.END)
                self.roi_stats_text.insert(tk.END, stats_text)
                self.roi_stats_text.config(state=tk.DISABLED)
                return

            # 计算DTI指标统计
            stats = self.analyzer.get_roi_statistics(roi_mask)

            # 生成统计文本
            stats_text = f"ROI: {self.current_roi_name}\n"
            stats_text += f"体积: {np.sum(roi_mask)} 体素\n\n"

            # FA
            stats_text += "FA:\n"
            stats_text += f"  平均值: {stats['fa']['mean']:.4f}\n"
            stats_text += f"  标准差: {stats['fa']['std']:.4f}\n"
            stats_text += f"  中位数: {stats['fa']['median']:.4f}\n"

            # MD
            stats_text += "MD:\n"
            stats_text += f"  平均值: {stats['md']['mean']:.6f}\n"
            stats_text += f"  标准差: {stats['md']['std']:.6f}\n"
            stats_text += f"  中位数: {stats['md']['median']:.6f}\n"

            # 更新统计文本
            self.roi_stats_text.config(state=tk.NORMAL)
            self.roi_stats_text.delete(1.0, tk.END)
            self.roi_stats_text.insert(tk.END, stats_text)
            self.roi_stats_text.config(state=tk.DISABLED)

        except Exception as e:
            logger.error(f"分析ROI失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"分析ROI失败: {str(e)}")

    def show_roi_overlay(self):
        """显示ROI叠加"""
        if not self.current_roi_name or self.current_roi_name not in self.roi_masks:
            return

        try:
            # 获取当前显示的数据
            if hasattr(self.analyzer, 'fa') and self.analyzer.fa is not None:
                map_type = self.dti_map_var.get()

                if map_type == "fa":
                    data = self.analyzer.fa
                elif map_type == "md":
                    data = self.analyzer.md
                elif map_type == "ad":
                    data = self.analyzer.ad
                elif map_type == "rd":
                    data = self.analyzer.rd
                elif map_type == "rgb":
                    data = self.analyzer.fa  # 使用FA作为底图
            else:
                # 使用原始数据
                data = self.volume_data[self.current_volume]

            # 获取ROI掩码
            roi_mask = self.roi_masks[self.current_roi_name]

            # 获取当前切片的数据和掩码
            if self.current_view == "axial":
                slice_data = data[self.current_slice]
                slice_mask = roi_mask[self.current_slice]
            elif self.current_view == "sagittal":
                slice_data = data[:, :, self.current_slice]
                slice_mask = roi_mask[:, :, self.current_slice]
            elif self.current_view == "coronal":
                slice_data = data[:, self.current_slice]
                slice_mask = roi_mask[:, self.current_slice]

            # 创建带ROI叠加的图像
            overlay_data = slice_data.copy()

            # 设置ROI区域为高亮色
            roi_value = np.max(overlay_data) if np.max(overlay_data) > 0 else 1
            overlay_data[slice_mask] = roi_value

            # 显示叠加图像
            self.display_overlay(overlay_data, slice_mask)

        except Exception as e:
            logger.error(f"显示ROI叠加失败: {str(e)}", exc_info=True)

    def display_overlay(self, data, mask):
        """显示叠加图像"""
        # 清除当前渲染器
        self.renderer.RemoveAllViewProps()

        # 创建VTK图像数据
        vtk_image = vtk.vtkImageData()
        vtk_image.SetDimensions(data.shape[1], data.shape[0], 1)
        vtk_image.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 3)  # RGB

        # 将numpy数组拷贝到VTK图像
        # 创建彩色图像，ROI区域使用红色高亮
        for y in range(data.shape[0]):
            for x in range(data.shape[1]):
                if mask[y, x]:
                    # ROI区域为红色
                    vtk_image.SetScalarComponentFromFloat(x, y, 0, 0, 255)
                    vtk_image.SetScalarComponentFromFloat(x, y, 0, 1, 50)
                    vtk_image.SetScalarComponentFromFloat(x, y, 0, 2, 50)
                else:
                    # 非ROI区域使用灰度
                    value = int(data[y, x] * 255 / np.max(data)) if np.max(data) > 0 else 0
                    vtk_image.SetScalarComponentFromFloat(x, y, 0, 0, value)
                    vtk_image.SetScalarComponentFromFloat(x, y, 0, 1, value)
                    vtk_image.SetScalarComponentFromFloat(x, y, 0, 2, value)

        # 创建图像actor
        actor = vtk.vtkImageActor()
        actor.SetInputData(vtk_image)

        # 添加到渲染器
        self.renderer.AddActor(actor)

        # 重置相机
        self.renderer.ResetCamera()

        # 渲染
        self.render_window.Render()

    def delete_roi(self):
        """删除ROI"""
        if not self.current_roi_name:
            messagebox.showwarning("警告", "请先选择ROI")
            return

        # 询问确认
        if not messagebox.askyesno("确认", f"确定要删除ROI '{self.current_roi_name}'?"):
            return

        # 删除ROI
        if self.current_roi_name in self.roi_masks:
            del self.roi_masks[self.current_roi_name]

            # 更新ROI列表
            self.update_roi_list()

            # 清除当前选择
            self.current_roi_name = None

            # 清除统计文本
            self.roi_stats_text.config(state=tk.NORMAL)
            self.roi_stats_text.delete(1.0, tk.END)
            self.roi_stats_text.insert(tk.END, "未选择ROI")
            self.roi_stats_text.config(state=tk.DISABLED)

            # 更新显示
            self.update_display()

            # 更新状态
            self.status_label.config(text="ROI已删除")

    def save_roi(self):
        """保存ROI掩码"""
        if not self.roi_masks:
            messagebox.showwarning("警告", "没有ROI可保存")
            return

        # 选择保存目录
        save_dir = filedialog.askdirectory(title="选择保存目录")
        if not save_dir:
            return

        try:
            # 创建目录
            roi_dir = os.path.join(save_dir, "ROIs")
            os.makedirs(roi_dir, exist_ok=True)

            # 保存所有ROI
            for roi_name, mask in self.roi_masks.items():
                # 保存为NIfTI文件
                roi_path = os.path.join(roi_dir, f"{roi_name}.nii.gz")
                save_nifti(roi_path, mask.astype(np.uint8), np.eye(4))

            messagebox.showinfo("成功", f"已保存 {len(self.roi_masks)} 个ROI到 {roi_dir}")

        except Exception as e:
            logger.error(f"保存ROI失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"保存ROI失败: {str(e)}")

    def analyze_roi(self):
        """分析ROI中的DTI指标"""
        if not self.current_roi_name:
            messagebox.showwarning("警告", "请先选择ROI")
            return

        # 已经在select_roi和analyze_current_roi中实现
        self.analyze_current_roi()

        # 显示详细分析结果窗口
        self.show_detailed_roi_analysis()

    def show_detailed_roi_analysis(self):
        """显示详细的ROI分析结果"""
        if not self.current_roi_name or self.current_roi_name not in self.roi_masks:
            return

        # 检查是否已完成DTI分析
        if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
            messagebox.showinfo("提示", "请先进行DTI分析")
            return

        # 创建新窗口
        analysis_window = tk.Toplevel(self.root)
        analysis_window.title(f"ROI分析: {self.current_roi_name}")
        analysis_window.geometry("800x600")

        # 获取ROI掩码
        roi_mask = self.roi_masks[self.current_roi_name]

        # 计算DTI指标统计
        stats = self.analyzer.get_roi_statistics(roi_mask)

        # 创建选项卡
        tab_control = ttk.Notebook(analysis_window)

        # 统计选项卡
        stats_tab = ttk.Frame(tab_control)
        tab_control.add(stats_tab, text="统计数据")

        # 直方图选项卡
        hist_tab = ttk.Frame(tab_control)
        tab_control.add(hist_tab, text="直方图")

        # 配置文件选项卡
        report_tab = ttk.Frame(tab_control)
        tab_control.add(report_tab, text="报告")

        tab_control.pack(expand=1, fill=tk.BOTH)

        # 填充统计选项卡
        stats_frame = ttk.LabelFrame(stats_tab, text="DTI指标统计")
        stats_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # 创建表格
        columns = ("指标", "均值", "标准差", "最小值", "最大值", "中位数")
        tree = ttk.Treeview(stats_frame, columns=columns, show="headings")

        for col in columns:
            tree.heading(col, text=col)
            tree.column(col, width=100)

        # 添加数据
        for metric, values in stats.items():
            tree.insert("", tk.END, values=(
                metric.upper(),
                f"{values['mean']:.6f}",
                f"{values['std']:.6f}",
                f"{values['min']:.6f}",
                f"{values['max']:.6f}",
                f"{values['median']:.6f}"
            ))

        tree.pack(fill=tk.BOTH, expand=True, padx=5, pady=5)

        # 填充直方图选项卡
        # 创建图形
        fig = plt.Figure(figsize=(8, 6), dpi=100)

        # FA直方图
        ax1 = fig.add_subplot(221)
        fa_values = self.analyzer.fa[roi_mask]
        ax1.hist(fa_values, bins=30, color='blue', alpha=0.7)
        ax1.set_title("FA分布")
        ax1.axvline(np.mean(fa_values), color='r', linestyle='dashed', linewidth=1)
        ax1.text(np.mean(fa_values)*1.05, ax1.get_ylim()[1]*0.9,
                f'Mean: {np.mean(fa_values):.4f}', color='r')

        # MD直方图
        ax2 = fig.add_subplot(222)
        md_values = self.analyzer.md[roi_mask]
        ax2.hist(md_values, bins=30, color='green', alpha=0.7)
        ax2.set_title("MD分布")
        ax2.axvline(np.mean(md_values), color='r', linestyle='dashed', linewidth=1)
        ax2.text(np.mean(md_values)*1.05, ax2.get_ylim()[1]*0.9,
                f'Mean: {np.mean(md_values):.6f}', color='r')

        # AD直方图
        ax3 = fig.add_subplot(223)
        ad_values = self.analyzer.ad[roi_mask]
        ax3.hist(ad_values, bins=30, color='red', alpha=0.7)
        ax3.set_title("AD分布")
        ax3.axvline(np.mean(ad_values), color='black', linestyle='dashed', linewidth=1)
        ax3.text(np.mean(ad_values)*1.05, ax3.get_ylim()[1]*0.9,
                f'Mean: {np.mean(ad_values):.6f}', color='black')

        # RD直方图
        ax4 = fig.add_subplot(224)
        rd_values = self.analyzer.rd[roi_mask]
        ax4.hist(rd_values, bins=30, color='purple', alpha=0.7)
        ax4.set_title("RD分布")
        ax4.axvline(np.mean(rd_values), color='black', linestyle='dashed', linewidth=1)
        ax4.text(np.mean(rd_values)*1.05, ax4.get_ylim()[1]*0.9,
                f'Mean: {np.mean(rd_values):.6f}', color='black')

        fig.tight_layout()

        # 显示图形
        canvas = FigureCanvasTkAgg(fig, hist_tab)
        canvas.draw()
        canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

        # 填充报告选项卡
        report_frame = ttk.Frame(report_tab)
        report_frame.pack(fill=tk.BOTH, expand=True, padx=10, pady=10)

        # 创建报告文本
        report_text = tk.Text(report_frame, wrap=tk.WORD, font=("Arial", 11))
        report_text.pack(fill=tk.BOTH, expand=True, side=tk.LEFT)

        # 添加滚动条
        scrollbar = ttk.Scrollbar(report_frame, command=report_text.yview)
        scrollbar.pack(fill=tk.Y, side=tk.RIGHT)
        report_text.config(yscrollcommand=scrollbar.set)

        # 生成报告内容
        report_content = self.generate_roi_report(stats)
        report_text.insert(tk.END, report_content)

        # 添加保存报告按钮
        ttk.Button(report_tab,
                  text="保存报告",
                  command=lambda: self.save_report(report_content)).pack(pady=10)

    def generate_roi_report(self, stats):
        """生成ROI分析报告"""
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        report = f"ROI分析报告 - {self.current_roi_name}\n"
        report += f"生成时间: {now}\n\n"

        # 患者信息
        if hasattr(self, 'metadata') and self.metadata:
            if 'patient_name' in self.metadata:
                report += f"患者姓名: {self.metadata['patient_name']}\n"
            if 'patient_id' in self.metadata:
                report += f"患者ID: {self.metadata['patient_id']}\n"
            if 'study_date' in self.metadata:
                report += f"检查日期: {self.metadata['study_date']}\n"

        report += "\n"

        # ROI信息
        roi_mask = self.roi_masks[self.current_roi_name]
        report += f"ROI体积: {np.sum(roi_mask)} 体素\n\n"

        # DTI指标
        report += "DTI指标统计:\n"
        report += "-" * 50 + "\n"
        report += f"{'指标':<10}{'均值':<15}{'标准差':<15}{'中位数':<15}\n"
        report += "-" * 50 + "\n"

        for metric, values in stats.items():
            report += f"{metric.upper():<10}{values['mean']:<15.6f}{values['std']:<15.6f}{values['median']:<15.6f}\n"

        report += "-" * 50 + "\n\n"

        # 临床意义
        report += "参考值和临床意义:\n"
        report += "- 正常脊髓FA值通常在0.6-0.8范围内\n"
        report += "- FA值降低可能表示脊髓组织结构受损或退行性改变\n"
        report += "- MD值增高可能表示脊髓脱髓鞘或水肿\n\n"

        # 建议
        fa_mean = stats['fa']['mean']
        md_mean = stats['md']['mean']

        report += "分析建议:\n"
        if fa_mean < 0.5:
            report += "- FA值明显降低，提示该区域可能存在显著的组织结构破坏\n"
        elif fa_mean < 0.6:
            report += "- FA值轻度降低，提示该区域可能存在轻微的组织结构改变\n"
        else:
            report += "- FA值在正常范围内\n"

        if md_mean > 0.002:
            report += "- MD值升高，可能存在组织水肿或脱髓鞘\n"
        else:
            report += "- MD值在正常范围内\n"

        return report

    def save_report(self, report_content):
        """保存分析报告"""
        file_path = filedialog.asksaveasfilename(
            title="保存分析报告",
            defaultextension=".txt",
            filetypes=[("文本文件", "*.txt"), ("所有文件", "*.*")]
        )

        if file_path:
            with open(file_path, 'w', encoding='utf-8') as f:
                f.write(report_content)
            messagebox.showinfo("成功", f"报告已保存到: {file_path}")

    def choose_roi_color(self):
        """选择ROI颜色"""
        color = colorchooser.askcolor(title="选择ROI颜色")
        if color[1]:  # color是一个元组 ((r,g,b), hex)
            # 保存颜色信息
            if not hasattr(self, 'roi_colors'):
                self.roi_colors = {}

            if self.current_roi_name:
                self.roi_colors[self.current_roi_name] = color[1]

                # 更新显示
                self.show
def choose_roi_color(self):
        """选择ROI颜色"""
        color = colorchooser.askcolor(title="选择ROI颜色")
        if color[1]:  # color是一个元组 ((r,g,b), hex)
            # 保存颜色信息
            if not hasattr(self, 'roi_colors'):
                self.roi_colors = {}

            if self.current_roi_name:
                self.roi_colors[self.current_roi_name] = color[1]

                # 更新显示
                self.show_roi_overlay()

    def update_display(self, event=None, use_processed=False):
        """更新当前显示"""
        if self.volume_data is None:
            return

        # 根据处理标志选择数据
        display_data = self.processed_data if use_processed else self.volume_data

        # 清除当前渲染器
        self.renderer.RemoveAllViewProps()

        # 获取当前切片数据
        if self.current_view == "axial":
            max_slice = display_data.shape[1] - 1
            if self.current_slice > max_slice:
                self.current_slice = max_slice
            slice_data = display_data[self.current_volume, self.current_slice]
        elif self.current_view == "sagittal":
            max_slice = display_data.shape[3] - 1
            if self.current_slice > max_slice:
                self.current_slice = max_slice
            slice_data = display_data[self.current_volume, :, :, self.current_slice]
        elif self.current_view == "coronal":
            max_slice = display_data.shape[2] - 1
            if self.current_slice > max_slice:
                self.current_slice = max_slice
            slice_data = display_data[self.current_volume, :, self.current_slice, :]

        # 更新切片滑块范围
        self.slice_scale.configure(to=max_slice)
        self.slice_scale.set(self.current_slice)

        # 创建VTK图像
        vtk_image = self.create_vtk_image_from_array(slice_data)

        # 创建图像actor
        image_actor = vtk.vtkImageActor()
        image_actor.SetInputData(vtk_image)

        # 添加到渲染器
        self.renderer.AddActor(image_actor)

        # 重置相机
        self.renderer.ResetCamera()

        # 渲染
        self.render_window.Render()

        # 更新状态
        self.status_label.config(text=f"{self.current_view.capitalize()} 视图, 切片 {self.current_slice}/{max_slice}")

        # 检查是否需要显示ROI
        if hasattr(self, 'roi_masks') and self.roi_masks and self.current_roi_name in self.roi_masks:
            self.show_roi_overlay()

    def create_vtk_image_from_array(self, array_data):
        """从numpy数组创建VTK图像"""
        # 归一化数据到0-255
        if array_data.max() > array_data.min():
            normalized_data = 255.0 * (array_data - array_data.min()) / (array_data.max() - array_data.min())
        else:
            normalized_data = np.zeros_like(array_data)

        # 创建VTK图像数据
        vtk_image = vtk.vtkImageData()
        vtk_image.SetDimensions(array_data.shape[1], array_data.shape[0], 1)
        vtk_image.AllocateScalars(vtk.VTK_UNSIGNED_CHAR, 1)

        # 拷贝数据
        for y in range(array_data.shape[0]):
            for x in range(array_data.shape[1]):
                vtk_image.SetScalarComponentFromFloat(x, y, 0, 0, normalized_data[y, x])

        return vtk_image

    def update_view(self, event=None):
        """更新视图方向"""
        self.current_view = self.view_var.get()
        self.current_slice = 0  # 重置切片索引
        self.update_display()

    def update_slice(self, event=None):
        """更新切片位置"""
        self.current_slice = int(self.slice_scale.get())
        self.update_display()

    def update_window(self, event=None):
        """更新窗宽窗位"""
        # 这个函数在有vtk图像显示器时使用
        if hasattr(self, 'image_viewer'):
            window = self.window_scale.get()
            level = self.level_scale.get()
            self.image_viewer.SetColorWindow(window)
            self.image_viewer.SetColorLevel(level)
            self.render_window.Render()

    def show_3d_view(self):
        """显示3D视图"""
        if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
            messagebox.showwarning("警告", "请先进行DTI分析")
            return

        # 创建新窗口
        view_window = tk.Toplevel(self.root)
        view_window.title("DTI 3D视图")
        view_window.geometry("800x600")

        # 创建VTK渲染环境
        renderer = vtk.vtkRenderer()
        renderer.SetBackground(0.2, 0.2, 0.2)

        render_window = vtk.vtkRenderWindow()
        render_window.AddRenderer(renderer)

        # 创建渲染窗口小部件
        render_widget = vtkTkRenderWindowInteractor(
            view_window,
            rw=render_window,
            width=800,
            height=600
        )
        render_widget.pack(fill=tk.BOTH, expand=True)

        interactor = render_widget.GetRenderWindow().GetInteractor()
        interactor.SetInteractorStyle(vtk.vtkInteractorStyleTrackballCamera())

        # 添加FA体积渲染
        if self.analyzer.fa is not None:
            # 创建体积
            volume_mapper = vtk.vtkSmartVolumeMapper()
            volume_property = vtk.vtkVolumeProperty()

            # 创建VTK图像数据
            fa_data = self.analyzer.fa
            vtk_volume = vtk.vtkImageData()
            vtk_volume.SetDimensions(fa_data.shape[2], fa_data.shape[1], fa_data.shape[0])
            vtk_volume.AllocateScalars(vtk.VTK_FLOAT, 1)

            # 填充数据
            for z in range(fa_data.shape[0]):
                for y in range(fa_data.shape[1]):
                    for x in range(fa_data.shape[2]):
                        vtk_volume.SetScalarComponentFromFloat(x, y, z, 0, fa_data[z, y, x])

            volume_mapper.SetInputData(vtk_volume)

            # 设置不透明度传输函数
            opacity_function = vtk.vtkPiecewiseFunction()
            opacity_function.AddPoint(0.0, 0.0)
            opacity_function.AddPoint(0.2, 0.0)
            opacity_function.AddPoint(0.5, 0.1)
            opacity_function.AddPoint(0.8, 0.2)
            opacity_function.AddPoint(1.0, 0.3)

            # 设置颜色传输函数
            color_function = vtk.vtkColorTransferFunction()
            color_function.AddRGBPoint(0.0, 0.0, 0.0, 0.0)
            color_function.AddRGBPoint(0.2, 0.1, 0.1, 0.3)
            color_function.AddRGBPoint(0.5, 0.3, 0.3, 0.7)
            color_function.AddRGBPoint(0.8, 0.7, 0.7, 0.9)
            color_function.AddRGBPoint(1.0, 1.0, 1.0, 1.0)

            # 设置体积属性
            volume_property.SetColor(color_function)
            volume_property.SetScalarOpacity(opacity_function)
            volume_property.ShadeOn()
            volume_property.SetInterpolationTypeToLinear()

            # 创建体积
            volume = vtk.vtkVolume()
            volume.SetMapper(volume_mapper)
            volume.SetProperty(volume_property)

            # 添加到渲染器
            renderer.AddVolume(volume)

        # 如果有纤维束，添加纤维束
        if hasattr(self, 'streamlines') and self.streamlines:
            # 创建纤维束表示
            streamlines_vtk = vtk.vtkPolyData()
            points = vtk.vtkPoints()
            lines = vtk.vtkCellArray()

            point_id = 0
            for streamline in self.streamlines:
                line = vtk.vtkPolyLine()
                line.GetPointIds().SetNumberOfIds(len(streamline))

                for i, point in enumerate(streamline):
                    points.InsertNextPoint(point[0], point[1], point[2])
                    line.GetPointIds().SetId(i, point_id)
                    point_id += 1

                lines.InsertNextCell(line)

            streamlines_vtk.SetPoints(points)
            streamlines_vtk.SetLines(lines)

            # 创建颜色映射
            vtk_colors = vtk.vtkUnsignedCharArray()
            vtk_colors.SetNumberOfComponents(3)
            vtk_colors.SetName("Colors")

            # 为每条纤维分配方向颜色
            if hasattr(self.analyzer, 'rgb') and self.analyzer.rgb is not None:
                # 使用RGB方向颜色
                rgb_data = self.analyzer.rgb

                for i, streamline in enumerate(self.streamlines):
                    # 获取纤维中点的RGB值
                    mid_point = streamline[len(streamline)//2]
                    x, y, z = int(mid_point[0]), int(mid_point[1]), int(mid_point[2])

                    # 边界检查
                    x = max(0, min(x, rgb_data.shape[2]-1))
                    y = max(0, min(y, rgb_data.shape[1]-1))
                    z = max(0, min(z, rgb_data.shape[0]-1))

                    r, g, b = rgb_data[z, y, x]

                    # 将颜色值缩放到0-255
                    r = int(r * 255)
                    g = int(g * 255)
                    b = int(b * 255)

                    for j in range(len(streamline)):
                        vtk_colors.InsertNextTuple3(r, g, b)
            else:
                # 使用随机颜色
                for i in range(len(self.streamlines)):
                    r, g, b = self.get_random_color()
                    for j in range(len(self.streamlines[i])):
                        vtk_colors.InsertNextTuple3(r, g, b)

            streamlines_vtk.GetPointData().SetScalars(vtk_colors)

            # 创建纤维bundle actor
            mapper = vtk.vtkPolyDataMapper()
            mapper.SetInputData(streamlines_vtk)

            actor = vtk.vtkActor()
            actor.SetMapper(mapper)

            # 添加到渲染器
            renderer.AddActor(actor)

        # 添加坐标轴
        axes = vtk.vtkAxesActor()
        axes_widget = vtk.vtkOrientationMarkerWidget()
        axes_widget.SetOrientationMarker(axes)
        axes_widget.SetInteractor(interactor)
        axes_widget.SetViewport(0.0, 0.0, 0.2, 0.2)
        axes_widget.SetEnabled(1)
        axes_widget.InteractiveOff()

        # 设置相机
        renderer.ResetCamera()

        # 启动
        interactor.Initialize()
        interactor.Start()

    def set_view(self, view):
        """设置视图方向"""
        self.view_var.set(view)
        self.update_view()

    def show_roi_analysis(self):
        """显示ROI分析界面"""
        if not self.roi_masks:
            messagebox.showwarning("警告", "请先创建ROI")
            return

        # 选择一个ROI进行分析
        if not self.current_roi_name:
            # 如果没有当前选择的ROI，使用第一个
            self.current_roi_name = list(self.roi_masks.keys())[0]

        # 分析当前ROI
        self.analyze_roi()

    def set_save_path(self):
        """设置保存路径"""
        save_path = filedialog.askdirectory(title="选择保存目录")
        if save_path:
            self.save_path_var.set(save_path)

    def save_results(self):
        """保存分析结果"""
        if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
            messagebox.showwarning("警告", "请先进行DTI分析")
            return

        # 获取保存路径
        save_path = self.save_path_var.get()

        # 确认路径存在
        if not os.path.exists(save_path):
            try:
                os.makedirs(save_path)
            except:
                messagebox.showerror("错误", f"无法创建目录: {save_path}")
                return

        try:
            # 创建时间戳子目录
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            result_dir = os.path.join(save_path, f"DTI_Results_{timestamp}")
            os.makedirs(result_dir)

            # 保存DTI分析结果
            self.analyzer.save_results(result_dir, self.metadata)

            # 保存ROI掩码
            if self.roi_masks:
                roi_dir = os.path.join(result_dir, "ROIs")
                os.makedirs(roi_dir)

                for roi_name, mask in self.roi_masks.items():
                    # 保存为NIfTI文件
                    roi_path = os.path.join(roi_dir, f"{roi_name}.nii.gz")
                    save_nifti(roi_path, mask.astype(np.uint8), np.eye(4))

            # 保存纤维束
            if hasattr(self, 'streamlines') and self.streamlines:
                # 保存为TrackVis格式
                from dipy.io.streamline import save_trk
                track_path = os.path.join(result_dir, 'streamlines.trk')
                save_trk(track_path,
                         self.streamlines,
                         np.eye(4),
                         self.analyzer.fa.shape)

            # 保存ROI分析报告
            if self.roi_masks and self.current_roi_name:
                report_path = os.path.join(result_dir, f"{self.current_roi_name}_report.txt")
                roi_mask = self.roi_masks[self.current_roi_name]
                stats = self.analyzer.get_roi_statistics(roi_mask)
                report_content = self.generate_roi_report(stats)

                with open(report_path, 'w', encoding='utf-8') as f:
                    f.write(report_content)

            # 保存图像截图
            screenshot_path = os.path.join(result_dir, "screenshot.png")
            self.save_screenshot(screenshot_path)

            messagebox.showinfo("成功", f"分析结果已保存到: {result_dir}")

        except Exception as e:
            logger.error(f"保存结果失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"保存结果失败: {str(e)}")

    def save_screenshot(self, path):
        """保存当前视图截图"""
        # 创建窗口到图像过滤器
        window_to_image = vtk.vtkWindowToImageFilter()
        window_to_image.SetInput(self.render_window)
        window_to_image.Update()

        # 创建PNG写入器
        writer = vtk.vtkPNGWriter()
        writer.SetFileName(path)
        writer.SetInputConnection(window_to_image.GetOutputPort())
        writer.Write()

    def export_report(self):
        """导出分析报告"""
        if not hasattr(self.analyzer, 'fa') or self.analyzer.fa is None:
            messagebox.showwarning("警告", "请先进行DTI分析")
            return

        # 选择保存路径
        report_path = filedialog.asksaveasfilename(
            title="保存分析报告",
            defaultextension=".html",
            filetypes=[("HTML文件", "*.html"), ("PDF文件", "*.pdf"), ("所有文件", "*.*")]
        )

        if not report_path:
            return

        try:
            # 生成报告
            report_content = self.generate_full_report()

            # 保存报告
            with open(report_path, 'w', encoding='utf-8') as f:
                f.write(report_content)

            # 打开报告
            import webbrowser
            webbrowser.open(report_path)

            messagebox.showinfo("成功", f"报告已导出到: {report_path}")

        except Exception as e:
            logger.error(f"导出报告失败: {str(e)}", exc_info=True)
            messagebox.showerror("错误", f"导出报告失败: {str(e)}")

    def generate_full_report(self):
        """生成完整的分析报告"""
        now = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

        # 创建HTML报告
        html = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <meta charset="UTF-8">
            <title>DTI分析报告</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 20px; }}
                h1, h2, h3 {{ color: #333366; }}
                .container {{ max-width: 1000px; margin: 0 auto; }}
                table {{ border-collapse: collapse; width: 100%; margin: 15px 0; }}
                th, td {{ border: 1px solid #ddd; padding: 8px; text-align: left; }}
                th {{ background-color: #f2f2f2; }}
                tr:nth-child(even) {{ background-color: #f9f9f9; }}
                .section {{ margin: 20px 0; padding: 15px; border: 1px solid #ddd; border-radius: 5px; }}
                .image-container {{ text-align: center; margin: 15px 0; }}
                img {{ max-width: 100%; }}
                .footer {{ margin-top: 50px; font-size: 12px; color: #666; text-align: center; }}
            </style>
        </head>
        <body>
            <div class="container">
                <h1>DTI分析报告</h1>
                <p><strong>生成时间:</strong> {now}</p>
        """

        # 患者信息
        html += """
                <div class="section">
                    <h2>患者信息</h2>
                    <table>
                        <tr><th>属性</th><th>值</th></tr>
        """

        if hasattr(self, 'metadata') and self.metadata:
            for key, value in self.metadata.items():
                if key in ['patient_name', 'patient_id', 'study_date', 'study_time', 'modality']:
                    html += f"<tr><td>{key}</td><td>{value}</td></tr>\n"

        html += """
                    </table>
                </div>
        """

        # DTI参数
        html += """
                <div class="section">
                    <h2>全脑DTI参数</h2>
        """

        # 添加FA直方图
        if hasattr(self.analyzer, 'fa') and self.analyzer.fa is not None:
            fa_values = self.analyzer.fa.flatten()
            fa_values = fa_values[~np.isnan(fa_values)]  # 移除NaN
            fa_values = fa_values[fa_values > 0]  # 移除零值

            if len(fa_values) > 0:
                # 计算统计值
                fa_mean = np.mean(fa_values)
                fa_std = np.std(fa_values)
                fa_median = np.median(fa_values)

                html += f"""
                    <p><strong>FA均值:</strong> {fa_mean:.4f} ± {fa_std:.4f}</p>
                    <p><strong>FA中位数:</strong> {fa_median:.4f}</p>
                """

        html += """
                </div>
        """

        # ROI分析
        if self.roi_masks:
            html += """
                <div class="section">
                    <h2>ROI分析</h2>
                    <table>
                        <tr><th>ROI名称</th><th>体积 (体素)</th><th>平均FA</th><th>平均MD</th></tr>
            """

            for roi_name, mask in self.roi_masks.items():
                if hasattr(self.analyzer, 'fa') and self.analyzer.fa is not None:
                    stats = self.analyzer.get_roi_statistics(mask)
                    volume = np.sum(mask)
                    fa_mean = stats['fa']['mean']
                    md_mean = stats['md']['mean']

                    html += f"<tr><td>{roi_name}</td><td>{volume}</td><td>{fa_mean:.4f}</td><td>{md_mean:.6f}</td></tr>\n"

            html += """
                    </table>
                </div>
            """

        # 纤维追踪
        if hasattr(self, 'streamlines') and self.streamlines:
            html += f"""
                <div class="section">
                    <h2>纤维追踪</h2>
                    <p><strong>纤维数量:</strong> {len(self.streamlines)}</p>
                    <p><strong>FA阈值:</strong> {self.fa_threshold.get()}</p>
                    <p><strong>角度阈值:</strong> {self.angle_threshold.get()} 度</p>
                </div>
            """

        # 结论和建议
        html += """
                <div class="section">
                    <h2>结论和建议</h2>
                    <p>本报告由计算机软件自动生成，仅供参考。请结合临床症状和其他影像学检查进行综合分析。</p>
                </div>

                <div class="footer">
                    <p>脊髓DTI分析与可视化工具生成</p>
                </div>
            </div>
        </body>
        </html>
        """

        return html

    def show_help(self):
        """显示帮助信息"""
        help_window = tk.Toplevel(self.root)
        help_window.title("帮助")
        help_window.geometry("600x400")

        help_text = tk.Text(help_window, wrap=tk.WORD, font=("Arial", 10))
        help_text.pack(fill=tk.BOTH, expand=True, side=tk.LEFT)

        scrollbar = ttk.Scrollbar(help_window, command=help_text.yview)
        scrollbar.pack(fill=tk.Y, side=tk.RIGHT)
        help_text.config(yscrollcommand=scrollbar.set)

        help_content = """
脊髓DTI分析与可视化工具 - 使用指南

1. 数据加载
   - 选择"文件"菜单中的"加载DICOM目录"或"加载NIfTI文件"
   - 对于NIfTI格式，需要同时加载b值文件(.bval)和梯度方向文件(.bvec)

2. 数据预处理
   - 在"预处理"选项卡中选择需要的预处理步骤
   - 点击"执行预处理"按钮进行数据预处理

3. DTI分析
   - 在"分析"选项卡中点击"执行DTI分析"按钮
   - 分析完成后可以查看FA、MD、AD、RD等参数图

4. 纤维追踪
   - 在"分析"选项卡中设置FA阈值、角度阈值和步长
   - 点击"开始纤维追踪"按钮进行追踪

5. ROI分析
   - 在"ROI"选项卡中创建感兴趣区域
   - 可以创建矩形、圆形或自由绘制的ROI
   - 点击"分析ROI"可以查看详细统计数据

6. 保存结果
   - 选择"文件"菜单中的"保存结果"可以保存所有分析数据
   - 选择"导出报告"可以生成HTML格式的分析报告

快捷键:
- 上/下方向键: 切换切片
- 左/右方向键: 切换体积
- 鼠标滚轮: 切换切片

如需更多帮助，请参考用户手册或联系技术支持。
        """

        help_text.insert(tk.END, help_content)
        help_text.config(state=tk.DISABLED)

    def show_about(self):
        """显示关于信息"""
        messagebox.showinfo(
            "关于",
            "脊髓DTI分析与可视化工具 V2.0\n\n"
            "一款用于脊髓DTI数据处理、分析和可视化的专业工具。\n\n"
            "支持DICOM和NIfTI格式的DTI数据\n"
            "提供预处理、张量分析、纤维追踪和ROI分析功能\n\n"
            "© 2025 版权所有"
        )

class RoiCallback:
    """ROI交互回调类"""

    def __init__(self, viewer):
        self.viewer = viewer

    def __call__(self, caller, event):
        # 右键点击完成ROI
        if event == "InteractionEvent":
            if self.viewer.interactor.GetRightButtonLongPressEvent():
                self.viewer.finish_roi()

class DrawPolylineStyle(vtk.vtkInteractorStyleImage):
    """自定义的多边形绘制交互样式"""

    def __init__(self, viewer):
        self.viewer = viewer
        self.points = []
        self.renderer = viewer.renderer
        self.interactor = viewer.interactor

        # 创建点和线的actors
        self.points_actor = vtk.vtkActor()
        self.lines_actor = vtk.vtkActor()

        self.renderer.AddActor(self.points_actor)
        self.renderer.AddActor(self.lines_actor)

        # 添加观察者
        self.AddObserver("LeftButtonPressEvent", self.left_button_press)
        self.AddObserver("RightButtonPressEvent", self.right_button_press)
        self.AddObserver("MouseMoveEvent", self.mouse_move)

    def left_button_press(self, obj, event):
        """左键点击添加点"""
        click_pos = self.interactor.GetEventPosition()

        # 将屏幕坐标转换为世界坐标
        picker = vtk.vtkWorldPointPicker()
        picker.Pick(click_pos[0], click_pos[1], 0, self.renderer)
        world_pos = picker.GetPickPosition()

        # 添加点
        self.points.append((world_pos[0], world_pos[1]))

        # 更新显示
        self.update_polyline()

        # 重新渲染
        self.interactor.Render()

    def right_button_press(self, obj, event):
        """右键点击完成绘制"""
        if len(self.points) >= 3:
            self.viewer.finish_roi(self.points)
        else:
            messagebox.showwarning("警告", "至少需要3个点才能创建ROI")

        # 清除点和线
        self.points = []
        self.renderer.RemoveActor(self.points_actor)
        self.renderer.RemoveActor(self.lines_actor)

        # 恢复默认交互样式
        self.viewer.interaction_style = vtk.vtkInteractorStyleImage()
        self.interactor.SetInteractorStyle(self.viewer.interaction_style)

        # 重新渲染
        self.interactor.Render()

    def mouse_move(self, obj, event):
        """鼠标移动时更新临时线"""
        if len(self.points) > 0:
            current_pos = self.interactor.GetEventPosition()

            # 将屏幕坐标转换为世界坐标
            picker = vtk.vtkWorldPointPicker()
            picker.Pick(current_pos[0], current_pos[1], 0, self.renderer)
            world_pos = picker.GetPickPosition()

            # 更新临时线
            self.update_temp_line(world_pos)

            # 重新渲染
            self.interactor.Render()

    def update_polyline(self):
        """更新多边形线条显示"""
        # 创建点
        points = vtk.vtkPoints()
        for point in self.points:
            points.InsertNextPoint(point[0], point[1], 0)

        # 创建多边形
        polygon = vtk.vtkPolygon()
        polygon.GetPointIds().SetNumberOfIds(len(self.points))
        for i in range(len(self.points)):
            polygon.GetPointIds().SetId(i, i)

        # 创建单元格阵列
        polygons = vtk.vtkCellArray()
        polygons.InsertNextCell(polygon)

        # 创建多边形数据
        polydata = vtk.vtkPolyData()
        polydata.SetPoints(points)
        polydata.SetPolys(polygons)

        # 创建映射器
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputData(polydata)

        # 更新actor
        self.points_actor.SetMapper(mapper)
        self.points_actor.GetProperty().SetColor(1, 0, 0)  # 红色
        self.points_actor.GetProperty().SetPointSize(5)

    def update_temp_line(self, current_pos):
        """更新临时线条"""
        if not self.points:
            return

        # 创建点
        points = vtk.vtkPoints()
        for point in self.points:
            points.InsertNextPoint(point[0], point[1], 0)

        # 添加当前位置
        points.InsertNextPoint(current_pos[0], current_pos[1], 0)

        # 创建线
        lines = vtk.vtkCellArray()
        for i in range(len(self.points)):
            line = vtk.vtkLine()
            line.GetPointIds().SetId(0, i)
            line.GetPointIds().SetId(1, (i + 1) % (len(self.points) + 1))
            lines.InsertNextCell(line)

        # 创建线数据
        linedata = vtk.vtkPolyData()
        linedata.SetPoints(points)
        linedata.SetLines(lines)

        # 创建映射器
        mapper = vtk.vtkPolyDataMapper()
        mapper.SetInputData(linedata)

        # 更新actor
        self.lines_actor.SetMapper(mapper)
        self.lines_actor.GetProperty().SetColor(0, 1, 0)  # 绿色
        self.lines_actor.GetProperty().SetLineWidth(2)

def main():
    root = tk.Tk()
    app = SpinalDTIViewer(root)
    root.mainloop()

if __name__ == "__main__":
    main()

IndentationError: unindent does not match any outer indentation level (<tokenize>, line 2671)