# Detection of Alzheimer's Disease Using Graph Signal Processing of EEG Recordings


## Data

Link here: https://openneuro.org/datasets/ds004504/versions/1.0.6

### Participants:
- **Total subjects: 88**
  - **Alzheimer's disease (AD group): 36**
    - Average MMSE: 17.75 (sd=4.5)
    - Mean age: 66.4 (sd=7.9)
    - Median disease duration: 25 months, IQR: 24 - 28.5
  - **Frontotemporal Dementia (FTD group): 23**
    - Average MMSE: 22.17 (sd=8.22)
    - Mean age: 63.6 (sd=8.2)
  - **Healthy subjects (CN group): 29**
    - Average MMSE: 30
    - Mean age: 67.9 (sd=5.4)
- **MMSE score ranges:** 0 to 30 (lower scores indicate more severe decline)

### Recordings:
- **Location:** 2nd Department of Neurology of AHEPA General Hospital, Thessaloniki
- **Device:** Nihon Kohden EEG 2100 clinical device, 19 scalp electrodes, 2 reference electrodes
- **Parameters:** 500 Hz sampling rate, 10uV/mm resolution, sensitivity 10uV/mm, time constant 0.3s, high-frequency filter at 70 Hz
- **Duration:** 
  - AD group: 13.5 minutes (min=5.1, max=21.3)
  - FTD group: 12 minutes (min=7.9, max=16.9)
  - CN group: 13.8 minutes (min=12.5, max=16.5)
- **Total recordings:**
  - AD: 485.5 minutes
  - FTD: 276.5 minutes
  - CN: 402 minutes

### Preprocessing:
- **Exported to:** .eeg format, transformed to BIDS accepted .set format
- **Unprocessed recordings in folders named:** sub-0XX
- **Preprocessed and denoised recordings in sub-0XX within subfolder derivatives**
- **Preprocessing pipeline:**
  - Butterworth band-pass filter 0.5-45 Hz
  - Re-referenced to A1-A2
  - Artifact Subspace Reconstruction (ASR) applied
  - Independent Component Analysis (ICA) method performed, transformed to 19 ICA components
  - Eye and jaw artifacts automatically rejected
- **Automatic annotations of artifacts not included for language compatibility**
- **Preprocessed dataset available in Folder:** derivatives

## Methodology

Same as in [Detection of Epilepsy Using Graph Signal Processing of EEG Signals with Three Features](https://link.springer.com/chapter/10.1007/978-981-19-1520-8_46)


## GSP Processing

In [1]:
# 安裝和配置 CuPy（GPU 加速的 NumPy）
import sys
import subprocess

# 檢查是否有 GPU
try:
    result = subprocess.run(['nvidia-smi'], capture_output=True, text=True)
    has_cuda = result.returncode == 0
except:
    has_cuda = False

# 嘗試導入 CuPy
try:
    import cupy as cp
    USE_GPU = True
    print("✓ CuPy 已安裝，將使用 GPU 加速")
    print(f"GPU 設備數量: {cp.cuda.runtime.getDeviceCount()}")
except ImportError:
    USE_GPU = False
    if has_cuda:
        print("正在安裝 CuPy（GPU 加速的 NumPy）...")
        # 根據 CUDA 版本安裝對應的 CuPy
        # 可以根據需要調整版本，這裡使用通用版本
        subprocess.check_call([sys.executable, "-m", "pip", "install", "cupy-cuda11x"])
        import cupy as cp
        USE_GPU = True
        print("✓ CuPy 安裝完成，將使用 GPU 加速")
    else:
        print("未檢測到 GPU，將使用 CPU（NumPy）")
        cp = None

# 設置默認設備
if USE_GPU:
    # 使用 GPU 0
    cp.cuda.Device(0).use()
    print(f"使用 GPU 設備: {cp.cuda.Device(0).id}")


✓ CuPy 已安裝，將使用 GPU 加速
GPU 設備數量: 1
使用 GPU 設備: 0


In [2]:
# 安裝所需的套件
import sys
import subprocess

# 需要安裝的套件列表
packages = [
    'pandas',           # 數據處理
    'numpy',            # 數值計算
    'mne',              # MNE-Python - 神經科學數據分析
    'pygsp',            # PyGSP - 圖信號處理
    'scipy',            # 科學計算庫
    'networkx',         # 圖網絡分析
    'scikit-learn',     # 機器學習庫 (sklearn)
    'seaborn',          # 數據可視化
    'matplotlib',       # 繪圖庫
    'umap-learn'        # UMAP 降維 (導入時使用 import umap)
]

# 安裝套件
for package in packages:
    try:
        # 處理特殊情況：包名和導入名不同
        if package == 'scikit-learn':
            import_name = 'sklearn'
        elif package == 'umap-learn':
            import_name = 'umap'
        else:
            import_name = package
        __import__(import_name)
        print(f"✓ {package} 已安裝")
    except ImportError:
        print(f"正在安裝 {package}...")
        subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        print(f"✓ {package} 安裝完成")



✓ pandas 已安裝
✓ numpy 已安裝
✓ mne 已安裝
✓ pygsp 已安裝
✓ scipy 已安裝
✓ networkx 已安裝
✓ scikit-learn 已安裝
✓ seaborn 已安裝
✓ matplotlib 已安裝


  from .autonotebook import tqdm as notebook_tqdm


✓ umap-learn 已安裝


In [3]:
import pandas as pd
import numpy as np
import mne
import os
from pygsp import graphs, utils
from scipy.spatial import distance_matrix
from scipy.stats import entropy
import networkx as nx
from sklearn.cluster import spectral_clustering

In [5]:
def compute_total_variation(W, data_values):
    """
    向量化版本的总变分计算，比双重循环快得多
    
    总变分公式: TV = sqrt(sum_{i,j} w_{ij} * ||x_j - x_i||^2)
    """
    # 使用广播计算所有 (i,j) 对的差值
    # data_values[i] 形状: (D,), data_values[j] 形状: (D,)
    # 我们需要计算所有 i,j 对的差值
    N = data_values.shape[0]
    
    # 方法1: 使用广播 (更高效)
    # 扩展维度: data_values[i] -> (N, 1, D), data_values[j] -> (1, N, D)
    data_i = data_values[:, np.newaxis, :]  # (N, 1, D)
    data_j = data_values[np.newaxis, :, :]  # (1, N, D)
    
    # 计算所有差值: (N, N, D)
    differences = data_j - data_i
    
    # 计算每个差值的 L2 范数的平方: (N, N)
    squared_norms = np.sum(differences ** 2, axis=2)
    
    # 与权重矩阵相乘并求和
    TV = np.sum(W * squared_norms)
    
    return np.sqrt(TV)

dir_path = r'/home/b12901075/eecsmed/ds004504/derivatives'
file_list = [
    os.path.join(root, file) 
    for root, dirs, files in os.walk(dir_path) 
    for file in files 
    if file.endswith(".set")
]

n_files = len(file_list)
if n_files == 0:
    raise ValueError(f"No .set files found in directory {dir_path}.")

print(f'Found {n_files} .set files.')

channel_names = ['Fp1', 'Fp2', 'F3', 'F4', 'C3', 'C4', 'P3', 'P4', 'O1', 'O2', 'F7', 'F8', 'T3', 'T4', 'T5', 'T6', 'Fz', 'Cz', 'Pz']

data_list = []
features = {}

for i, file in enumerate(file_list):
    print(f'Processing {file} ({i+1}/{n_files})')
    raw = mne.io.read_raw_eeglab(file)
    data = raw.get_data(picks=channel_names)
    transposed_data = np.transpose(data)
    data = pd.DataFrame(transposed_data, columns=channel_names)
    data = data.groupby(data.index // 50).median()
    data_list.append(data)

    # GSP analysis
    distances = distance_matrix(data.values, data.values)
    theta, k = 1.0, 1.0 
    W = np.exp(-distances**2 / theta**2)
    W[distances > k] = 0
    np.fill_diagonal(W, 0)
    G = graphs.Graph(W)
    L = G.L.toarray()
    eigenvalues, eigenvectors = np.linalg.eigh(L)
    X_GdataT = eigenvectors.T @ data.values
    C = np.cov(X_GdataT)
    T = eigenvectors.T.conj() @ C @ eigenvectors
    r = np.linalg.norm(np.diag(T)) / np.linalg.norm(T, 'fro')
    P = L @ data.values
    Y = np.sum(data.values * P)**2
    TV = compute_total_variation(W, data.values)

    # Spectral Graph Features
    graph_energy = np.sum(np.abs(eigenvalues))
    # 计算 spectral entropy：使用特征值的归一化分布作为概率分布
    # 将特征值转换为非负值并归一化
    eigenvals_abs = np.abs(eigenvalues)
    eigenvals_normalized = eigenvals_abs / (np.sum(eigenvals_abs) + 1e-10)  # 添加小值避免除零
    spectral_entropy = entropy(eigenvals_normalized)

    # Graph Signal Features
    signal_energy = np.sum(np.square(data.values))
    signal_power = np.var(data.values)

    # Graph Modularity and Community Structure
    labels = spectral_clustering(W)
    unique_labels = len(np.unique(labels))

    # Graph Degree Distribution
    degree_distribution = np.sum(W, axis=0)

    # Graph Diffusion Characteristics
    heat_trace = np.trace(np.exp(-L))
    diffusion_distance = np.sum(np.exp(-L))

    # Aggregating Features
    features[os.path.basename(file)] = {
        'stationary_ratio': r, 
        'Tik-norm': Y, 
        'Total_Variation': TV,
        'graph_energy': graph_energy,
        'spectral_entropy': spectral_entropy,
        'signal_energy': signal_energy,
        'signal_power': signal_power,
        'unique_clusters': unique_labels,
        'avg_degree': np.mean(degree_distribution),
        'heat_trace': heat_trace,
        'diffusion_distance': diffusion_distance
    }

features_data = pd.DataFrame(features).T
features_data.to_csv('features_tv.csv', index_label='participant_id')

Found 88 .set files.
Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-011/eeg/sub-011_task-eyesclosed_eeg.set (1/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-088/eeg/sub-088_task-eyesclosed_eeg.set (2/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-045/eeg/sub-045_task-eyesclosed_eeg.set (3/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-038/eeg/sub-038_task-eyesclosed_eeg.set (4/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-002/eeg/sub-002_task-eyesclosed_eeg.set (5/88)


  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-068/eeg/sub-068_task-eyesclosed_eeg.set (6/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-030/eeg/sub-030_task-eyesclosed_eeg.set (7/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-087/eeg/sub-087_task-eyesclosed_eeg.set (8/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


Processing /home/b12901075/eecsmed/ds004504/derivatives/sub-040/eeg/sub-040_task-eyesclosed_eeg.set (9/88)


  raw = mne.io.read_raw_eeglab(file)
  raw = mne.io.read_raw_eeglab(file)


KeyboardInterrupt: 

## Processing 

In [4]:
import pandas as pd
import numpy as np
import os

# 设置数据路径
base_dir = r'./'
dataset_dir = os.path.join(base_dir, 'ds004504')

# 读取特征数据
features = pd.read_csv(os.path.join(base_dir, "features_tv.csv"))
# 读取参与者信息
participants = pd.read_csv(os.path.join(dataset_dir, "participants.tsv"), delimiter='\t')
data = features.merge(participants, left_index=True, right_index=True)
data

Unnamed: 0,participant_id_x,stationary_ratio,Tik-norm,Total_Variation,graph_energy,spectral_entropy,signal_energy,signal_power,unique_clusters,avg_degree,heat_trace,diffusion_distance,participant_id_y,Gender,Age,Group,MMSE
0,sub-011_task-eyesclosed_eeg.set,0.065765,1.472647,1.557900,5.928230e+07,8.948846,0.000158,1.077305e-09,8.0,7698.999685,0.0,1.611460e+08,sub-001,F,57,A,16
1,sub-088_task-eyesclosed_eeg.set,0.037262,1.189033,1.476772,6.156756e+07,8.967759,0.000139,9.320460e-10,8.0,7845.999722,0.0,1.673580e+08,sub-002,F,78,A,22
2,sub-045_task-eyesclosed_eeg.set,0.063532,2.015106,1.684959,7.253077e+07,9.049702,0.000167,1.029976e-09,8.0,8515.999667,0.0,1.971591e+08,sub-003,M,70,A,14
3,sub-038_task-eyesclosed_eeg.set,0.057751,1.793222,1.636528,7.945048e+07,9.095266,0.000150,8.870028e-10,8.0,8912.999700,0.0,2.159688e+08,sub-004,F,67,A,20
4,sub-002_task-eyesclosed_eeg.set,0.036089,1.152527,1.465304,6.289283e+07,8.978408,0.000135,8.982981e-10,8.0,7929.999729,0.0,1.709604e+08,sub-005,M,70,A,22
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
83,sub-015_task-eyesclosed_eeg.set,0.044033,2.469593,1.772847,8.124318e+07,9.106423,0.000174,1.017948e-09,8.0,9012.999651,0.0,2.208419e+08,sub-084,F,71,F,24
84,sub-027_task-eyesclosed_eeg.set,0.031677,2.338241,1.748789,6.825238e+07,9.019301,0.000185,1.179025e-09,8.0,8260.999630,0.0,1.855292e+08,sub-085,M,64,F,26
85,sub-003_task-eyesclosed_eeg.set,0.048940,0.017792,0.516504,9.366660e+06,8.026170,0.000044,7.492876e-10,8.0,3059.999913,0.0,2.546122e+07,sub-086,M,49,F,26
86,sub-057_task-eyesclosed_eeg.set,0.031148,1.485881,1.561389,6.343326e+07,8.982687,0.000153,1.011276e-09,8.0,7963.999694,0.0,1.724295e+08,sub-087,M,73,F,24


## Plotting

In [5]:
# Print the frequency of each group
group_frequencies = data['Group'].value_counts()
print(group_frequencies)

Group
A    36
C    29
F    23
Name: count, dtype: int64


## Multi-Classification

In [6]:
features = {
    'stationary_ratio': 'Stationary Ratio',
    'Tik-norm': 'Tik-norm',
    'Total_Variation': 'Total Variation',
    'graph_energy': 'Graph Energy',
    'spectral_entropy': 'Spectral Entropy',
    'signal_energy': 'Signal Energy',
    'signal_power': 'Signal Power',
    'avg_degree': 'Average Degree',
    'diffusion_distance': 'Diffusion Distance',
}

### 階層式分類

#### LogisticRegression

In [7]:
# ============================================
# 完整階層式分類程式（包含資料載入）
# ============================================

# ============================================
# 步驟 0: 資料載入和預處理
# ============================================
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix, 
                            balanced_accuracy_score, roc_auc_score,
                            roc_curve, auc)
from sklearn.preprocessing import label_binarize
from sklearn.linear_model import LogisticRegression
import ast
import warnings
warnings.filterwarnings("ignore")

# 設置數據路徑
base_dir = r'./'
dataset_dir = os.path.join(base_dir, 'ds004504')

# 讀取特徵數據
features_df = pd.read_csv(os.path.join(base_dir, "features_tv.csv"))
# 讀取參與者信息
participants = pd.read_csv(os.path.join(dataset_dir, "participants.tsv"), delimiter='\t')

# 合併數據
data = features_df.merge(participants, left_index=True, right_index=True)

# 定義特徵字典
features = {
    'stationary_ratio': 'Stationary Ratio',
    'Tik-norm': 'Tik-norm',
    'Total_Variation': 'Total Variation',
    'graph_energy': 'Graph Energy',
    'spectral_entropy': 'Spectral Entropy',
    'signal_energy': 'Signal Energy',
    'signal_power': 'Signal Power',
    'avg_degree': 'Average Degree',
    'diffusion_distance': 'Diffusion Distance',
}

# 準備特徵和標籤
X = data[list(features.keys())].copy()
y = data['Group'].copy()

# 數據清理
for col in X.columns:
    first_val = X[col].iloc[0] if len(X) > 0 else None
    if isinstance(first_val, str):
        def safe_convert(x):
            if pd.isna(x):
                return np.nan
            if isinstance(x, str):
                try:
                    parsed = ast.literal_eval(x)
                    if isinstance(parsed, (list, tuple, np.ndarray)):
                        return float(parsed[0]) if len(parsed) > 0 else np.nan
                    return float(parsed)
                except:
                    try:
                        return float(x)
                    except:
                        try:
                            parsed = eval('[' + ','.join(x.strip('[]').split()) + ']')
                            return float(parsed[0]) if len(parsed) > 0 else np.nan
                        except:
                            return np.nan
            try:
                return float(x) if pd.notna(x) else np.nan
            except:
                return np.nan
        X[col] = X[col].apply(safe_convert)
    else:
        X[col] = pd.to_numeric(X[col], errors='coerce')

rows_with_all_nan = X.isna().all(axis=1)
valid_mask = ~(rows_with_all_nan | y.isna())
X = X[valid_mask].copy()
y = y[valid_mask].copy()

if X.isna().sum().sum() > 0:
    for col in X.columns:
        if X[col].isna().sum() > 0:
            median_val = X[col].median()
            X[col] = X[col].fillna(median_val)

X = X.astype(float)

if len(X) == 0:
    raise ValueError("錯誤：清理後數據為空！請檢查原始數據。")

# 數據分割（目標比例：訓練集 30, 驗證集 31, 測試集 27）
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=27/88, random_state=42, stratify=y
)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=31/61, random_state=42, stratify=y_train_val
)

X_train = X_train.astype(float)
X_val = X_val.astype(float)
X_test = X_test.astype(float)

# 標準化
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# 編碼標籤
le = LabelEncoder()
y_train_int = le.fit_transform(y_train)
y_val_int = le.transform(y_val)
y_test_int = le.transform(y_test)

print("="*60)
print("數據分割信息")
print("="*60)
print(f"總樣本數: {len(X)}")
print(f"訓練集樣本數: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"驗證集樣本數: {len(X_val)} ({len(X_val)/len(X)*100:.1f}%)")
print(f"測試集樣本數: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
print(f"\n類別編碼: {dict(zip(le.classes_, range(len(le.classes_))))}")
print(f"訓練集類別分布: {dict(zip(le.classes_, np.bincount(y_train_int)))}")
print(f"驗證集類別分布: {dict(zip(le.classes_, np.bincount(y_val_int)))}")
print(f"測試集類別分布: {dict(zip(le.classes_, np.bincount(y_test_int)))}")
print("="*60)

# ============================================
# 步驟 1：健康 vs 疾病（二分類）
# ============================================
print("\n" + "="*60)
print("階層式分類：步驟 1 - 健康 vs 疾病（二分類）")
print("="*60)

y_binary_train = (y_train_int == 1).astype(int)  # C=1 是健康，A 和 F 是疾病
y_binary_val = (y_val_int == 1).astype(int)
y_binary_test = (y_test_int == 1).astype(int)

# 訓練二分類模型
binary_model = LogisticRegression(C=0.1, class_weight='balanced', random_state=42, max_iter=1000)
binary_model.fit(X_train_scaled, y_binary_train)

# 驗證集預測
binary_val_pred = binary_model.predict(X_val_scaled)
binary_val_proba = binary_model.predict_proba(X_val_scaled)

# 測試集預測
binary_test_pred = binary_model.predict(X_test_scaled)
binary_test_proba = binary_model.predict_proba(X_test_scaled)

print("\n驗證集結果:")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_binary_val, binary_val_pred):.4f}")
print(f"混淆矩陣:")
print(confusion_matrix(y_binary_val, binary_val_pred))
print(f"\n分類報告:")
print(classification_report(y_binary_val, binary_val_pred, 
                            target_names=['疾病 (A+F)', '健康 (C)']))

# ROC AUC
binary_val_roc = roc_auc_score(y_binary_val, binary_val_proba[:, 1])
print(f"ROC AUC: {binary_val_roc:.4f}")

print("\n測試集結果:")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_binary_test, binary_test_pred):.4f}")
print(f"混淆矩陣:")
print(confusion_matrix(y_binary_test, binary_test_pred))
print(f"\n分類報告:")
print(classification_report(y_binary_test, binary_test_pred,
                            target_names=['疾病 (A+F)', '健康 (C)']))

# ROC AUC
binary_test_roc = roc_auc_score(y_binary_test, binary_test_proba[:, 1])
print(f"ROC AUC: {binary_test_roc:.4f}")

# ============================================
# 步驟 2：在疾病樣本中區分 A 和 F
# ============================================
print("\n" + "="*60)
print("階層式分類：步驟 2 - AD vs FTD（二分類）")
print("="*60)

disease_mask_train = y_train_int != 1
disease_mask_val = y_val_int != 1
disease_mask_test = y_test_int != 1

print(f"訓練集疾病樣本數: {disease_mask_train.sum()}")
print(f"驗證集疾病樣本數: {disease_mask_val.sum()}")
print(f"測試集疾病樣本數: {disease_mask_test.sum()}")

if disease_mask_train.sum() > 5 and disease_mask_test.sum() > 3:
    # 準備疾病樣本資料
    X_disease_train = X_train_scaled[disease_mask_train]
    y_disease_train = y_train_int[disease_mask_train]
    y_disease_train_binary = (y_disease_train == 2).astype(int)  # A=0, F=1
    
    X_disease_val = X_val_scaled[disease_mask_val]
    y_disease_val = y_val_int[disease_mask_val]
    y_disease_val_binary = (y_disease_val == 2).astype(int)
    
    X_disease_test = X_test_scaled[disease_mask_test]
    y_disease_test = y_test_int[disease_mask_test]
    y_disease_test_binary = (y_disease_test == 2).astype(int)
    
    print(f"\n疾病樣本分布:")
    print(f"訓練集 - A: {np.sum(y_disease_train_binary == 0)}, F: {np.sum(y_disease_train_binary == 1)}")
    print(f"驗證集 - A: {np.sum(y_disease_val_binary == 0)}, F: {np.sum(y_disease_val_binary == 1)}")
    print(f"測試集 - A: {np.sum(y_disease_test_binary == 0)}, F: {np.sum(y_disease_test_binary == 1)}")
    
    # 訓練疾病分類模型
    disease_model = LogisticRegression(C=0.1, class_weight='balanced', random_state=42, max_iter=1000)
    disease_model.fit(X_disease_train, y_disease_train_binary)
    
    # 驗證集預測
    disease_val_pred = disease_model.predict(X_disease_val)
    disease_val_proba = disease_model.predict_proba(X_disease_val)
    
    # 測試集預測
    disease_test_pred = disease_model.predict(X_disease_test)
    disease_test_proba = disease_model.predict_proba(X_disease_test)
    
    print("\n驗證集結果:")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_disease_val_binary, disease_val_pred):.4f}")
    print(f"混淆矩陣:")
    print(confusion_matrix(y_disease_val_binary, disease_val_pred))
    print(f"\n分類報告:")
    print(classification_report(y_disease_val_binary, disease_val_pred,
                                target_names=['AD (A)', 'FTD (F)']))
    
    # ROC AUC
    if len(np.unique(y_disease_val_binary)) > 1:
        disease_val_roc = roc_auc_score(y_disease_val_binary, disease_val_proba[:, 1])
        print(f"ROC AUC: {disease_val_roc:.4f}")
    
    print("\n測試集結果:")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_disease_test_binary, disease_test_pred):.4f}")
    print(f"混淆矩陣:")
    print(confusion_matrix(y_disease_test_binary, disease_test_pred))
    print(f"\n分類報告:")
    print(classification_report(y_disease_test_binary, disease_test_pred,
                                target_names=['AD (A)', 'FTD (F)']))
    
    # ROC AUC
    if len(np.unique(y_disease_test_binary)) > 1:
        disease_test_roc = roc_auc_score(y_disease_test_binary, disease_test_proba[:, 1])
        print(f"ROC AUC: {disease_test_roc:.4f}")
    
    # ============================================
    # 步驟 3：組合預測（三分類最終結果）
    # ============================================
    print("\n" + "="*60)
    print("階層式分類：最終三分類結果")
    print("="*60)
    
    # 驗證集組合預測
    final_val_pred = np.zeros(len(y_val_int))
    final_val_pred[y_binary_val == 1] = 1  # 健康 = C (1)
    disease_val_indices = np.where(y_binary_val == 0)[0]
    # 如果預測為疾病，則使用疾病分類器的預測：A=0, F=2
    final_val_pred[disease_val_indices] = disease_val_pred * 2
    
    # 測試集組合預測
    final_test_pred = np.zeros(len(y_test_int))
    final_test_pred[y_binary_test == 1] = 1  # 健康 = C (1)
    disease_test_indices = np.where(y_binary_test == 0)[0]
    # 如果預測為疾病，則使用疾病分類器的預測：A=0, F=2
    final_test_pred[disease_test_indices] = disease_test_pred * 2
    
    # ============================================
    # 驗證集完整評估
    # ============================================
    print("\n" + "="*60)
    print("Validation Metrics for Hierarchical Classification")
    print("="*60)
    
    val_balanced_acc = balanced_accuracy_score(y_val_int, final_val_pred.astype(int))
    print(f"\nBalanced Accuracy: {val_balanced_acc:.4f}")
    
    print(f"\n混淆矩陣:")
    print(confusion_matrix(y_val_int, final_val_pred.astype(int)))
    
    print(f"\n分類報告:")
    print(classification_report(y_val_int, final_val_pred.astype(int), 
                                target_names=list(le.classes_)))
    
    # 計算每個類別的預測概率（用於 ROC AUC）
    final_val_proba = np.zeros((len(y_val_int), 3))
    
    # 健康類別（C）的概率
    final_val_proba[:, 1] = binary_val_proba[:, 1]
    
    # 疾病類別的概率 = P(疾病) * P(具體疾病類型|疾病)
    disease_prob = binary_val_proba[:, 0]  # P(疾病)
    final_val_proba[disease_val_indices, 0] = disease_prob[disease_val_indices] * disease_val_proba[:, 0]  # A
    final_val_proba[disease_val_indices, 2] = disease_prob[disease_val_indices] * disease_val_proba[:, 1]  # F
    
    # 歸一化概率
    final_val_proba = final_val_proba / (final_val_proba.sum(axis=1, keepdims=True) + 1e-10)
    
    # ROC AUC
    val_binarized = label_binarize(y_val_int, classes=[0, 1, 2])
    print("\n每個類別的 ROC AUC:")
    for i, class_name in enumerate(le.classes_):
        fpr, tpr, _ = roc_curve(val_binarized[:, i], final_val_proba[:, i])
        roc_auc = auc(fpr, tpr)
        print(f"  {class_name} (Class {i}): {roc_auc:.4f}")
    
    macro_roc_auc_val = roc_auc_score(y_val_int, final_val_proba, multi_class='ovr', average='macro')
    weighted_roc_auc_val = roc_auc_score(y_val_int, final_val_proba, multi_class='ovr', average='weighted')
    
    print(f"\n宏平均 ROC AUC: {macro_roc_auc_val:.4f}")
    print(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}")
    
    # ============================================
    # 測試集完整評估
    # ============================================
    print("\n" + "="*60)
    print("Test Metrics for Hierarchical Classification")
    print("="*60)
    
    test_balanced_acc = balanced_accuracy_score(y_test_int, final_test_pred.astype(int))
    print(f"\nBalanced Accuracy: {test_balanced_acc:.4f}")
    
    print(f"\n混淆矩陣:")
    print(confusion_matrix(y_test_int, final_test_pred.astype(int)))
    
    print(f"\n分類報告:")
    print(classification_report(y_test_int, final_test_pred.astype(int),
                                target_names=list(le.classes_)))
    
    # 計算測試集概率
    final_test_proba = np.zeros((len(y_test_int), 3))
    final_test_proba[:, 1] = binary_test_proba[:, 1]  # 健康類別
    disease_prob_test = binary_test_proba[:, 0]
    final_test_proba[disease_test_indices, 0] = disease_prob_test[disease_test_indices] * disease_test_proba[:, 0]
    final_test_proba[disease_test_indices, 2] = disease_prob_test[disease_test_indices] * disease_test_proba[:, 1]
    final_test_proba = final_test_proba / (final_test_proba.sum(axis=1, keepdims=True) + 1e-10)
    
    # ROC AUC
    test_binarized = label_binarize(y_test_int, classes=[0, 1, 2])
    print("\n每個類別的 ROC AUC:")
    for i, class_name in enumerate(le.classes_):
        fpr, tpr, _ = roc_curve(test_binarized[:, i], final_test_proba[:, i])
        roc_auc = auc(fpr, tpr)
        print(f"  {class_name} (Class {i}): {roc_auc:.4f}")
    
    macro_roc_auc_test = roc_auc_score(y_test_int, final_test_proba, multi_class='ovr', average='macro')
    weighted_roc_auc_test = roc_auc_score(y_test_int, final_test_proba, multi_class='ovr', average='weighted')
    
    print(f"\n宏平均 ROC AUC: {macro_roc_auc_test:.4f}")
    print(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}")
    
    # ============================================
    # 過擬合分析
    # ============================================
    print("\n" + "="*60)
    print("過擬合分析")
    print("="*60)
    overfitting_gap = val_balanced_acc - test_balanced_acc
    print(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}")
    print(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}")
    print(f"過擬合差距: {overfitting_gap:.4f}")
    
    if overfitting_gap > 0.15:
        print("⚠️ 警告：存在明顯過擬合（差距 > 0.15）")
    elif overfitting_gap > 0.10:
        print("⚠️ 注意：存在一定過擬合（差距 > 0.10）")
    else:
        print("✓ 過擬合程度在可接受範圍內")
    print("="*60)
    
    # ============================================
    # 類別預測分布分析
    # ============================================
    print("\n" + "="*60)
    print("類別預測分布分析")
    print("="*60)
    
    print("\n驗證集預測分布:")
    val_pred_counts = {le.classes_[i]: np.sum(final_val_pred.astype(int) == i) for i in range(len(le.classes_))}
    val_true_counts = {le.classes_[i]: np.sum(y_val_int == i) for i in range(len(le.classes_))}
    print(f"真實分布: {val_true_counts}")
    print(f"預測分布: {val_pred_counts}")
    
    print("\n測試集預測分布:")
    test_pred_counts = {le.classes_[i]: np.sum(final_test_pred.astype(int) == i) for i in range(len(le.classes_))}
    test_true_counts = {le.classes_[i]: np.sum(y_test_int == i) for i in range(len(le.classes_))}
    print(f"真實分布: {test_true_counts}")
    print(f"預測分布: {test_pred_counts}")
    print("="*60)
    
else:
    print("⚠️ 警告：疾病樣本數量不足，無法進行第二步分類")
    print("只進行健康 vs 疾病的分類")





# ============================================
# 保存結果和繪製 Confusion Matrix
# ============================================
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# 創建結果目錄
results_dir = './two_stage_results/logistic_regression'
os.makedirs(results_dir, exist_ok=True)

# 保存預測結果到 CSV
if disease_mask_train.sum() > 5 and disease_mask_test.sum() > 3:
    # 驗證集結果
    val_results_df = pd.DataFrame({
        'true_label': [le.classes_[i] for i in y_val_int],
        'predicted_label': [le.classes_[int(i)] for i in final_val_pred],
        'true_label_int': y_val_int,
        'predicted_label_int': final_val_pred.astype(int),
        'prob_A': final_val_proba[:, 0],
        'prob_C': final_val_proba[:, 1],
        'prob_F': final_val_proba[:, 2]
    })
    val_results_df.to_csv(os.path.join(results_dir, 'validation_predictions.csv'), index=False)
    print(f"\n✓ 驗證集預測結果已保存到: {os.path.join(results_dir, 'validation_predictions.csv')}")
    
    # 測試集結果
    test_results_df = pd.DataFrame({
        'true_label': [le.classes_[i] for i in y_test_int],
        'predicted_label': [le.classes_[int(i)] for i in final_test_pred],
        'true_label_int': y_test_int,
        'predicted_label_int': final_test_pred.astype(int),
        'prob_A': final_test_proba[:, 0],
        'prob_C': final_test_proba[:, 1],
        'prob_F': final_test_proba[:, 2]
    })
    test_results_df.to_csv(os.path.join(results_dir, 'test_predictions.csv'), index=False)
    print(f"✓ 測試集預測結果已保存到: {os.path.join(results_dir, 'test_predictions.csv')}")
    
    # 保存評估指標到文本文件
    with open(os.path.join(results_dir, 'evaluation_metrics.txt'), 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("階層式分類評估結果\n")
        f.write("="*60 + "\n")
        f.write(f"生成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        
        f.write("驗證集結果:\n")
        f.write("-"*60 + "\n")
        f.write(f"Balanced Accuracy: {val_balanced_acc:.4f}\n")
        f.write(f"宏平均 ROC AUC: {macro_roc_auc_val:.4f}\n")
        f.write(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}\n\n")
        f.write("混淆矩陣:\n")
        f.write(str(confusion_matrix(y_val_int, final_val_pred.astype(int))) + "\n\n")
        f.write("分類報告:\n")
        f.write(classification_report(y_val_int, final_val_pred.astype(int), 
                                     target_names=list(le.classes_)) + "\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("測試集結果:\n")
        f.write("-"*60 + "\n")
        f.write(f"Balanced Accuracy: {test_balanced_acc:.4f}\n")
        f.write(f"宏平均 ROC AUC: {macro_roc_auc_test:.4f}\n")
        f.write(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}\n\n")
        f.write("混淆矩陣:\n")
        f.write(str(confusion_matrix(y_test_int, final_test_pred.astype(int))) + "\n\n")
        f.write("分類報告:\n")
        f.write(classification_report(y_test_int, final_test_pred.astype(int),
                                     target_names=list(le.classes_)) + "\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("過擬合分析:\n")
        f.write("-"*60 + "\n")
        f.write(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}\n")
        f.write(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}\n")
        f.write(f"過擬合差距: {overfitting_gap:.4f}\n")
    
    print(f"✓ 評估指標已保存到: {os.path.join(results_dir, 'evaluation_metrics.txt')}")
    
    # 繪製 Confusion Matrix
    class_names = list(le.classes_)
    
    # 驗證集 Confusion Matrix
    cm_val = confusion_matrix(y_val_int, final_val_pred.astype(int))
    cm_val_normalized = cm_val.astype('float') / (cm_val.sum(axis=1)[:, np.newaxis] + 1e-10)
    cm_val_normalized = np.nan_to_num(cm_val_normalized)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 原始混淆矩陣
    sns.heatmap(cm_val, annot=True, fmt='d', cmap='Blues', 
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Validation Set Confusion Matrix (Count)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    
    
    # 歸一化混淆矩陣
    sns.heatmap(cm_val_normalized, annot=True, fmt='.2%', cmap='Blues',
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_title('Validation Set Confusion Matrix (Percentage)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    val_cm_path = os.path.join(results_dir, 'confusion_matrix_validation.png')
    plt.savefig(val_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ 驗證集混淆矩陣已保存到: {val_cm_path}")
    
    # 測試集 Confusion Matrix
    cm_test = confusion_matrix(y_test_int, final_test_pred.astype(int))
    cm_test_normalized = cm_test.astype('float') / (cm_test.sum(axis=1)[:, np.newaxis] + 1e-10)
    cm_test_normalized = np.nan_to_num(cm_test_normalized)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 原始混淆矩陣
    sns.heatmap(cm_test, annot=True, fmt='d', cmap='Blues', 
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Test Set Confusion Matrix (Count)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    
    # 歸一化混淆矩陣
    sns.heatmap(cm_test_normalized, annot=True, fmt='.2%', cmap='Blues',
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_title('Test Set Confusion Matrix (Percentage)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    test_cm_path = os.path.join(results_dir, 'confusion_matrix_test.png')
    plt.savefig(test_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ 測試集混淆矩陣已保存到: {test_cm_path}")
    
    print("\n" + "="*60)
    print("所有結果已保存完成！")
    print("="*60)
    print(f"結果目錄: {results_dir}")
    print("保存的文件:")
    print("  - validation_predictions.csv (驗證集預測結果)")
    print("  - test_predictions.csv (測試集預測結果)")
    print("  - evaluation_metrics.txt (評估指標)")
    print("  - confusion_matrix_validation.png (驗證集混淆矩陣)")
    print("  - confusion_matrix_test.png (測試集混淆矩陣)")
    print("="*60)
else:
    # 如果只有二分類結果，也保存
    val_results_df = pd.DataFrame({
        'true_label': ['疾病 (A+F)' if i == 0 else '健康 (C)' for i in y_binary_val],
        'predicted_label': ['疾病 (A+F)' if i == 0 else '健康 (C)' for i in binary_val_pred],
        'true_label_int': y_binary_val,
        'predicted_label_int': binary_val_pred,
        'prob_disease': binary_val_proba[:, 0],
        'prob_healthy': binary_val_proba[:, 1]
    })
    val_results_df.to_csv(os.path.join(results_dir, 'validation_predictions_binary.csv'), index=False)
    
    test_results_df = pd.DataFrame({
        'true_label': ['疾病 (A+F)' if i == 0 else '健康 (C)' for i in y_binary_test],
        'predicted_label': ['疾病 (A+F)' if i == 0 else '健康 (C)' for i in binary_test_pred],
        'true_label_int': y_binary_test,
        'predicted_label_int': binary_test_pred,
        'prob_disease': binary_test_proba[:, 0],
        'prob_healthy': binary_test_proba[:, 1]
    })
    test_results_df.to_csv(os.path.join(results_dir, 'test_predictions_binary.csv'), index=False)
    
    print(f"\n✓ 二分類結果已保存到: {results_dir}")

數據分割信息
總樣本數: 88
訓練集樣本數: 30 (34.1%)
驗證集樣本數: 31 (35.2%)
測試集樣本數: 27 (30.7%)

類別編碼: {'A': 0, 'C': 1, 'F': 2}
訓練集類別分布: {'A': np.int64(12), 'C': np.int64(10), 'F': np.int64(8)}
驗證集類別分布: {'A': np.int64(13), 'C': np.int64(10), 'F': np.int64(8)}
測試集類別分布: {'A': np.int64(11), 'C': np.int64(9), 'F': np.int64(7)}

階層式分類：步驟 1 - 健康 vs 疾病（二分類）

驗證集結果:
Balanced Accuracy: 0.4571
混淆矩陣:
[[15  6]
 [ 8  2]]

分類報告:
              precision    recall  f1-score   support

    疾病 (A+F)       0.65      0.71      0.68        21
      健康 (C)       0.25      0.20      0.22        10

    accuracy                           0.55        31
   macro avg       0.45      0.46      0.45        31
weighted avg       0.52      0.55      0.53        31

ROC AUC: 0.4905

測試集結果:
Balanced Accuracy: 0.2778
混淆矩陣:
[[10  8]
 [ 9  0]]

分類報告:
              precision    recall  f1-score   support

    疾病 (A+F)       0.53      0.56      0.54        18
      健康 (C)       0.00      0.00      0.00         9

    accuracy                   

#### SVM

In [8]:
# ============================================
# 完整階層式分類程式（使用 SVM）
# ============================================

# ============================================
# 步驟 0: 資料載入和預處理
# ============================================
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix, 
                            balanced_accuracy_score, roc_auc_score,
                            roc_curve, auc)
from sklearn.preprocessing import label_binarize
from sklearn.svm import SVC
import ast
import warnings
warnings.filterwarnings("ignore")

# 設置數據路徑
base_dir = r''
dataset_dir = os.path.join(base_dir, 'ds004504')

# 讀取特徵數據
features_df = pd.read_csv(os.path.join(base_dir, "features_tv.csv"))
# 讀取參與者信息
participants = pd.read_csv(os.path.join(dataset_dir, "participants.tsv"), delimiter='\t')

# 合併數據
data = features_df.merge(participants, left_index=True, right_index=True)

# 定義特徵字典
features = {
    'stationary_ratio': 'Stationary Ratio',
    'Tik-norm': 'Tik-norm',
    'Total_Variation': 'Total Variation',
    'graph_energy': 'Graph Energy',
    'spectral_entropy': 'Spectral Entropy',
    'signal_energy': 'Signal Energy',
    'signal_power': 'Signal Power',
    'avg_degree': 'Average Degree',
    'diffusion_distance': 'Diffusion Distance',
}

# 準備特徵和標籤
X = data[list(features.keys())].copy()
y = data['Group'].copy()

# 數據清理
for col in X.columns:
    first_val = X[col].iloc[0] if len(X) > 0 else None
    if isinstance(first_val, str):
        def safe_convert(x):
            if pd.isna(x):
                return np.nan
            if isinstance(x, str):
                try:
                    parsed = ast.literal_eval(x)
                    if isinstance(parsed, (list, tuple, np.ndarray)):
                        return float(parsed[0]) if len(parsed) > 0 else np.nan
                    return float(parsed)
                except:
                    try:
                        return float(x)
                    except:
                        try:
                            parsed = eval('[' + ','.join(x.strip('[]').split()) + ']')
                            return float(parsed[0]) if len(parsed) > 0 else np.nan
                        except:
                            return np.nan
            try:
                return float(x) if pd.notna(x) else np.nan
            except:
                return np.nan
        X[col] = X[col].apply(safe_convert)
    else:
        X[col] = pd.to_numeric(X[col], errors='coerce')

rows_with_all_nan = X.isna().all(axis=1)
valid_mask = ~(rows_with_all_nan | y.isna())
X = X[valid_mask].copy()
y = y[valid_mask].copy()

if X.isna().sum().sum() > 0:
    for col in X.columns:
        if X[col].isna().sum() > 0:
            median_val = X[col].median()
            X[col] = X[col].fillna(median_val)

X = X.astype(float)

if len(X) == 0:
    raise ValueError("錯誤：清理後數據為空！請檢查原始數據。")

# 數據分割（目標比例：訓練集 30, 驗證集 31, 測試集 27）
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=27/88, random_state=42, stratify=y
)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=31/61, random_state=42, stratify=y_train_val
)

X_train = X_train.astype(float)
X_val = X_val.astype(float)
X_test = X_test.astype(float)

# 標準化
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# 編碼標籤
le = LabelEncoder()
y_train_int = le.fit_transform(y_train)
y_val_int = le.transform(y_val)
y_test_int = le.transform(y_test)

print("="*60)
print("數據分割信息")
print("="*60)
print(f"總樣本數: {len(X)}")
print(f"訓練集樣本數: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"驗證集樣本數: {len(X_val)} ({len(X_val)/len(X)*100:.1f}%)")
print(f"測試集樣本數: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
print(f"\n類別編碼: {dict(zip(le.classes_, range(len(le.classes_))))}")
print(f"訓練集類別分布: {dict(zip(le.classes_, np.bincount(y_train_int)))}")
print(f"驗證集類別分布: {dict(zip(le.classes_, np.bincount(y_val_int)))}")
print(f"測試集類別分布: {dict(zip(le.classes_, np.bincount(y_test_int)))}")
print("="*60)

# ============================================
# 步驟 1：健康 vs 疾病（二分類）- 使用 SVM
# ============================================
print("\n" + "="*60)
print("階層式分類：步驟 1 - 健康 vs 疾病（二分類）- 使用 SVM")
print("="*60)

y_binary_train = (y_train_int == 1).astype(int)  # C=1 是健康，A 和 F 是疾病
y_binary_val = (y_val_int == 1).astype(int)
y_binary_test = (y_test_int == 1).astype(int)

# 訓練 SVM 二分類模型
binary_model = SVC(
    kernel='rbf',
    C=1.0,  # 正則化參數，可以調整
    gamma='scale',  # RBF 核的參數
    class_weight='balanced',  # 處理類別不平衡
    probability=True,  # 需要概率輸出用於 ROC AUC
    random_state=42
)

binary_model.fit(X_train_scaled, y_binary_train)

# 驗證集預測
binary_val_pred = binary_model.predict(X_val_scaled)
binary_val_proba = binary_model.predict_proba(X_val_scaled)

# 測試集預測
binary_test_pred = binary_model.predict(X_test_scaled)
binary_test_proba = binary_model.predict_proba(X_test_scaled)

print("\n驗證集結果:")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_binary_val, binary_val_pred):.4f}")
print(f"混淆矩陣:")
print(confusion_matrix(y_binary_val, binary_val_pred))
print(f"\n分類報告:")
print(classification_report(y_binary_val, binary_val_pred, 
                            target_names=['疾病 (A+F)', '健康 (C)']))

# ROC AUC
binary_val_roc = roc_auc_score(y_binary_val, binary_val_proba[:, 1])
print(f"ROC AUC: {binary_val_roc:.4f}")

print("\n測試集結果:")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_binary_test, binary_test_pred):.4f}")
print(f"混淆矩陣:")
print(confusion_matrix(y_binary_test, binary_test_pred))
print(f"\n分類報告:")
print(classification_report(y_binary_test, binary_test_pred,
                            target_names=['疾病 (A+F)', '健康 (C)']))

# ROC AUC
binary_test_roc = roc_auc_score(y_binary_test, binary_test_proba[:, 1])
print(f"ROC AUC: {binary_test_roc:.4f}")

# ============================================
# 步驟 2：在疾病樣本中區分 A 和 F - 使用 SVM
# ============================================
print("\n" + "="*60)
print("階層式分類：步驟 2 - AD vs FTD（二分類）- 使用 SVM")
print("="*60)

disease_mask_train = y_train_int != 1
disease_mask_val = y_val_int != 1
disease_mask_test = y_test_int != 1

print(f"訓練集疾病樣本數: {disease_mask_train.sum()}")
print(f"驗證集疾病樣本數: {disease_mask_val.sum()}")
print(f"測試集疾病樣本數: {disease_mask_test.sum()}")

if disease_mask_train.sum() > 5 and disease_mask_test.sum() > 3:
    # 準備疾病樣本資料
    X_disease_train = X_train_scaled[disease_mask_train]
    y_disease_train = y_train_int[disease_mask_train]
    y_disease_train_binary = (y_disease_train == 2).astype(int)  # A=0, F=1
    
    X_disease_val = X_val_scaled[disease_mask_val]
    y_disease_val = y_val_int[disease_mask_val]
    y_disease_val_binary = (y_disease_val == 2).astype(int)
    
    X_disease_test = X_test_scaled[disease_mask_test]
    y_disease_test = y_test_int[disease_mask_test]
    y_disease_test_binary = (y_disease_test == 2).astype(int)
    
    print(f"\n疾病樣本分布:")
    print(f"訓練集 - A: {np.sum(y_disease_train_binary == 0)}, F: {np.sum(y_disease_train_binary == 1)}")
    print(f"驗證集 - A: {np.sum(y_disease_val_binary == 0)}, F: {np.sum(y_disease_val_binary == 1)}")
    print(f"測試集 - A: {np.sum(y_disease_test_binary == 0)}, F: {np.sum(y_disease_test_binary == 1)}")
    
    # 訓練 SVM 疾病分類模型
    disease_model = SVC(
        kernel='rbf',
        C=1.0,
        gamma='scale',
        class_weight='balanced',
        probability=True,
        random_state=42
    )
    
    disease_model.fit(X_disease_train, y_disease_train_binary)
    
    # 驗證集預測
    disease_val_pred = disease_model.predict(X_disease_val)
    disease_val_proba = disease_model.predict_proba(X_disease_val)
    
    # 測試集預測
    disease_test_pred = disease_model.predict(X_disease_test)
    disease_test_proba = disease_model.predict_proba(X_disease_test)
    
    print("\n驗證集結果:")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_disease_val_binary, disease_val_pred):.4f}")
    print(f"混淆矩陣:")
    print(confusion_matrix(y_disease_val_binary, disease_val_pred))
    print(f"\n分類報告:")
    print(classification_report(y_disease_val_binary, disease_val_pred,
                                target_names=['AD (A)', 'FTD (F)']))
    
    # ROC AUC
    if len(np.unique(y_disease_val_binary)) > 1:
        disease_val_roc = roc_auc_score(y_disease_val_binary, disease_val_proba[:, 1])
        print(f"ROC AUC: {disease_val_roc:.4f}")
    
    print("\n測試集結果:")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_disease_test_binary, disease_test_pred):.4f}")
    print(f"混淆矩陣:")
    print(confusion_matrix(y_disease_test_binary, disease_test_pred))
    print(f"\n分類報告:")
    print(classification_report(y_disease_test_binary, disease_test_pred,
                                target_names=['AD (A)', 'FTD (F)']))
    
    # ROC AUC
    if len(np.unique(y_disease_test_binary)) > 1:
        disease_test_roc = roc_auc_score(y_disease_test_binary, disease_test_proba[:, 1])
        print(f"ROC AUC: {disease_test_roc:.4f}")
    
    # ============================================
    # 步驟 3：組合預測（三分類最終結果）
    # ============================================
    print("\n" + "="*60)
    print("階層式分類：最終三分類結果（使用 SVM）")
    print("="*60)
    
    # 驗證集組合預測
    final_val_pred = np.zeros(len(y_val_int))
    final_val_pred[y_binary_val == 1] = 1  # 健康 = C (1)
    disease_val_indices = np.where(y_binary_val == 0)[0]
    # 如果預測為疾病，則使用疾病分類器的預測：A=0, F=2
    final_val_pred[disease_val_indices] = disease_val_pred * 2
    
    # 測試集組合預測
    final_test_pred = np.zeros(len(y_test_int))
    final_test_pred[y_binary_test == 1] = 1  # 健康 = C (1)
    disease_test_indices = np.where(y_binary_test == 0)[0]
    # 如果預測為疾病，則使用疾病分類器的預測：A=0, F=2
    final_test_pred[disease_test_indices] = disease_test_pred * 2
    
    # ============================================
    # 驗證集完整評估
    # ============================================
    print("\n" + "="*60)
    print("Validation Metrics for Hierarchical Classification (SVM)")
    print("="*60)
    
    val_balanced_acc = balanced_accuracy_score(y_val_int, final_val_pred.astype(int))
    print(f"\nBalanced Accuracy: {val_balanced_acc:.4f}")
    
    print(f"\n混淆矩陣:")
    print(confusion_matrix(y_val_int, final_val_pred.astype(int)))
    
    print(f"\n分類報告:")
    print(classification_report(y_val_int, final_val_pred.astype(int), 
                                target_names=list(le.classes_)))
    
    # 計算每個類別的預測概率（用於 ROC AUC）
    final_val_proba = np.zeros((len(y_val_int), 3))
    
    # 健康類別（C）的概率
    final_val_proba[:, 1] = binary_val_proba[:, 1]
    
    # 疾病類別的概率 = P(疾病) * P(具體疾病類型|疾病)
    disease_prob = binary_val_proba[:, 0]  # P(疾病)
    final_val_proba[disease_val_indices, 0] = disease_prob[disease_val_indices] * disease_val_proba[:, 0]  # A
    final_val_proba[disease_val_indices, 2] = disease_prob[disease_val_indices] * disease_val_proba[:, 1]  # F
    
    # 歸一化概率
    final_val_proba = final_val_proba / (final_val_proba.sum(axis=1, keepdims=True) + 1e-10)
    
    # ROC AUC
    val_binarized = label_binarize(y_val_int, classes=[0, 1, 2])
    print("\n每個類別的 ROC AUC:")
    for i, class_name in enumerate(le.classes_):
        fpr, tpr, _ = roc_curve(val_binarized[:, i], final_val_proba[:, i])
        roc_auc = auc(fpr, tpr)
        print(f"  {class_name} (Class {i}): {roc_auc:.4f}")
    
    macro_roc_auc_val = roc_auc_score(y_val_int, final_val_proba, multi_class='ovr', average='macro')
    weighted_roc_auc_val = roc_auc_score(y_val_int, final_val_proba, multi_class='ovr', average='weighted')
    
    print(f"\n宏平均 ROC AUC: {macro_roc_auc_val:.4f}")
    print(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}")
    
    # ============================================
    # 測試集完整評估
    # ============================================
    print("\n" + "="*60)
    print("Test Metrics for Hierarchical Classification (SVM)")
    print("="*60)
    
    test_balanced_acc = balanced_accuracy_score(y_test_int, final_test_pred.astype(int))
    print(f"\nBalanced Accuracy: {test_balanced_acc:.4f}")
    
    print(f"\n混淆矩陣:")
    print(confusion_matrix(y_test_int, final_test_pred.astype(int)))
    
    print(f"\n分類報告:")
    print(classification_report(y_test_int, final_test_pred.astype(int),
                                target_names=list(le.classes_)))
    
    # 計算測試集概率
    final_test_proba = np.zeros((len(y_test_int), 3))
    final_test_proba[:, 1] = binary_test_proba[:, 1]  # 健康類別
    disease_prob_test = binary_test_proba[:, 0]
    final_test_proba[disease_test_indices, 0] = disease_prob_test[disease_test_indices] * disease_test_proba[:, 0]
    final_test_proba[disease_test_indices, 2] = disease_prob_test[disease_test_indices] * disease_test_proba[:, 1]
    final_test_proba = final_test_proba / (final_test_proba.sum(axis=1, keepdims=True) + 1e-10)
    
    # ROC AUC
    test_binarized = label_binarize(y_test_int, classes=[0, 1, 2])
    print("\n每個類別的 ROC AUC:")
    for i, class_name in enumerate(le.classes_):
        fpr, tpr, _ = roc_curve(test_binarized[:, i], final_test_proba[:, i])
        roc_auc = auc(fpr, tpr)
        print(f"  {class_name} (Class {i}): {roc_auc:.4f}")
    
    macro_roc_auc_test = roc_auc_score(y_test_int, final_test_proba, multi_class='ovr', average='macro')
    weighted_roc_auc_test = roc_auc_score(y_test_int, final_test_proba, multi_class='ovr', average='weighted')
    
    print(f"\n宏平均 ROC AUC: {macro_roc_auc_test:.4f}")
    print(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}")
    
    # ============================================
    # 過擬合分析
    # ============================================
    print("\n" + "="*60)
    print("過擬合分析")
    print("="*60)
    overfitting_gap = val_balanced_acc - test_balanced_acc
    print(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}")
    print(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}")
    print(f"過擬合差距: {overfitting_gap:.4f}")
    
    if overfitting_gap > 0.15:
        print("⚠️ 警告：存在明顯過擬合（差距 > 0.15）")
    elif overfitting_gap > 0.10:
        print("⚠️ 注意：存在一定過擬合（差距 > 0.10）")
    else:
        print("✓ 過擬合程度在可接受範圍內")
    print("="*60)
    
    # ============================================
    # 類別預測分布分析
    # ============================================
    print("\n" + "="*60)
    print("類別預測分布分析")
    print("="*60)
    
    print("\n驗證集預測分布:")
    val_pred_counts = {le.classes_[i]: np.sum(final_val_pred.astype(int) == i) for i in range(len(le.classes_))}
    val_true_counts = {le.classes_[i]: np.sum(y_val_int == i) for i in range(len(le.classes_))}
    print(f"真實分布: {val_true_counts}")
    print(f"預測分布: {val_pred_counts}")
    
    print("\n測試集預測分布:")
    test_pred_counts = {le.classes_[i]: np.sum(final_test_pred.astype(int) == i) for i in range(len(le.classes_))}
    test_true_counts = {le.classes_[i]: np.sum(y_test_int == i) for i in range(len(le.classes_))}
    print(f"真實分布: {test_true_counts}")
    print(f"預測分布: {test_pred_counts}")
    print("="*60)
    
else:
    print("⚠️ 警告：疾病樣本數量不足，無法進行第二步分類")
    print("只進行健康 vs 疾病的分類")




# ============================================
# 保存結果和繪製 Confusion Matrix
# ============================================
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# 創建結果目錄
results_dir = './two_stage_results/SVM'
os.makedirs(results_dir, exist_ok=True)

# 保存預測結果到 CSV
if disease_mask_train.sum() > 5 and disease_mask_test.sum() > 3:
    # 驗證集結果
    val_results_df = pd.DataFrame({
        'true_label': [le.classes_[i] for i in y_val_int],
        'predicted_label': [le.classes_[int(i)] for i in final_val_pred],
        'true_label_int': y_val_int,
        'predicted_label_int': final_val_pred.astype(int),
        'prob_A': final_val_proba[:, 0],
        'prob_C': final_val_proba[:, 1],
        'prob_F': final_val_proba[:, 2]
    })
    val_results_df.to_csv(os.path.join(results_dir, 'validation_predictions.csv'), index=False)
    print(f"\n✓ 驗證集預測結果已保存到: {os.path.join(results_dir, 'validation_predictions.csv')}")
    
    # 測試集結果
    test_results_df = pd.DataFrame({
        'true_label': [le.classes_[i] for i in y_test_int],
        'predicted_label': [le.classes_[int(i)] for i in final_test_pred],
        'true_label_int': y_test_int,
        'predicted_label_int': final_test_pred.astype(int),
        'prob_A': final_test_proba[:, 0],
        'prob_C': final_test_proba[:, 1],
        'prob_F': final_test_proba[:, 2]
    })
    test_results_df.to_csv(os.path.join(results_dir, 'test_predictions.csv'), index=False)
    print(f"✓ 測試集預測結果已保存到: {os.path.join(results_dir, 'test_predictions.csv')}")
    
    # 保存評估指標到文本文件
    with open(os.path.join(results_dir, 'evaluation_metrics.txt'), 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("階層式分類評估結果\n")
        f.write("="*60 + "\n")
        f.write(f"生成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        
        f.write("驗證集結果:\n")
        f.write("-"*60 + "\n")
        f.write(f"Balanced Accuracy: {val_balanced_acc:.4f}\n")
        f.write(f"宏平均 ROC AUC: {macro_roc_auc_val:.4f}\n")
        f.write(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}\n\n")
        f.write("混淆矩陣:\n")
        f.write(str(confusion_matrix(y_val_int, final_val_pred.astype(int))) + "\n\n")
        f.write("分類報告:\n")
        f.write(classification_report(y_val_int, final_val_pred.astype(int), 
                                     target_names=list(le.classes_)) + "\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("測試集結果:\n")
        f.write("-"*60 + "\n")
        f.write(f"Balanced Accuracy: {test_balanced_acc:.4f}\n")
        f.write(f"宏平均 ROC AUC: {macro_roc_auc_test:.4f}\n")
        f.write(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}\n\n")
        f.write("混淆矩陣:\n")
        f.write(str(confusion_matrix(y_test_int, final_test_pred.astype(int))) + "\n\n")
        f.write("分類報告:\n")
        f.write(classification_report(y_test_int, final_test_pred.astype(int),
                                     target_names=list(le.classes_)) + "\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("過擬合分析:\n")
        f.write("-"*60 + "\n")
        f.write(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}\n")
        f.write(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}\n")
        f.write(f"過擬合差距: {overfitting_gap:.4f}\n")
    
    print(f"✓ 評估指標已保存到: {os.path.join(results_dir, 'evaluation_metrics.txt')}")
    
    # 繪製 Confusion Matrix
    class_names = list(le.classes_)
    
    # 驗證集 Confusion Matrix
    cm_val = confusion_matrix(y_val_int, final_val_pred.astype(int))
    cm_val_normalized = cm_val.astype('float') / (cm_val.sum(axis=1)[:, np.newaxis] + 1e-10)
    cm_val_normalized = np.nan_to_num(cm_val_normalized)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 原始混淆矩陣
    sns.heatmap(cm_val, annot=True, fmt='d', cmap='Blues', 
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Validation Set Confusion Matrix (Count)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    
    
    # 歸一化混淆矩陣
    sns.heatmap(cm_val_normalized, annot=True, fmt='.2%', cmap='Blues',
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_title('Validation Set Confusion Matrix (Percentage)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    val_cm_path = os.path.join(results_dir, 'confusion_matrix_validation.png')
    plt.savefig(val_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ 驗證集混淆矩陣已保存到: {val_cm_path}")
    
    # 測試集 Confusion Matrix
    cm_test = confusion_matrix(y_test_int, final_test_pred.astype(int))
    cm_test_normalized = cm_test.astype('float') / (cm_test.sum(axis=1)[:, np.newaxis] + 1e-10)
    cm_test_normalized = np.nan_to_num(cm_test_normalized)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 原始混淆矩陣
    sns.heatmap(cm_test, annot=True, fmt='d', cmap='Blues', 
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Test Set Confusion Matrix (Count)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    
    # 歸一化混淆矩陣
    sns.heatmap(cm_test_normalized, annot=True, fmt='.2%', cmap='Blues',
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_title('Test Set Confusion Matrix (Percentage)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    test_cm_path = os.path.join(results_dir, 'confusion_matrix_test.png')
    plt.savefig(test_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ 測試集混淆矩陣已保存到: {test_cm_path}")
    
    print("\n" + "="*60)
    print("所有結果已保存完成！")
    print("="*60)
    print(f"結果目錄: {results_dir}")
    print("保存的文件:")
    print("  - validation_predictions.csv (驗證集預測結果)")
    print("  - test_predictions.csv (測試集預測結果)")
    print("  - evaluation_metrics.txt (評估指標)")
    print("  - confusion_matrix_validation.png (驗證集混淆矩陣)")
    print("  - confusion_matrix_test.png (測試集混淆矩陣)")
    print("="*60)

數據分割信息
總樣本數: 88
訓練集樣本數: 30 (34.1%)
驗證集樣本數: 31 (35.2%)
測試集樣本數: 27 (30.7%)

類別編碼: {'A': 0, 'C': 1, 'F': 2}
訓練集類別分布: {'A': np.int64(12), 'C': np.int64(10), 'F': np.int64(8)}
驗證集類別分布: {'A': np.int64(13), 'C': np.int64(10), 'F': np.int64(8)}
測試集類別分布: {'A': np.int64(11), 'C': np.int64(9), 'F': np.int64(7)}

階層式分類：步驟 1 - 健康 vs 疾病（二分類）- 使用 SVM

驗證集結果:
Balanced Accuracy: 0.4548
混淆矩陣:
[[17  4]
 [ 9  1]]

分類報告:
              precision    recall  f1-score   support

    疾病 (A+F)       0.65      0.81      0.72        21
      健康 (C)       0.20      0.10      0.13        10

    accuracy                           0.58        31
   macro avg       0.43      0.45      0.43        31
weighted avg       0.51      0.58      0.53        31

ROC AUC: 0.6143

測試集結果:
Balanced Accuracy: 0.3611
混淆矩陣:
[[13  5]
 [ 9  0]]

分類報告:
              precision    recall  f1-score   support

    疾病 (A+F)       0.59      0.72      0.65        18
      健康 (C)       0.00      0.00      0.00         9

    accuracy           

#### random forest

In [9]:
# ============================================
# 完整階層式分類程式（使用 Random Forest）
# ============================================

# ============================================
# 步驟 0: 資料載入和預處理
# ============================================
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.metrics import (classification_report, confusion_matrix, 
                            balanced_accuracy_score, roc_auc_score,
                            roc_curve, auc)
from sklearn.preprocessing import label_binarize
from sklearn.ensemble import RandomForestClassifier
import ast
import warnings
warnings.filterwarnings("ignore")

# 設置數據路徑
base_dir = r'./'
dataset_dir = os.path.join(base_dir, 'ds004504')

# 讀取特徵數據
features_df = pd.read_csv(os.path.join(base_dir, "features_tv.csv"))
# 讀取參與者信息
participants = pd.read_csv(os.path.join(dataset_dir, "participants.tsv"), delimiter='\t')

# 合併數據
data = features_df.merge(participants, left_index=True, right_index=True)

# 定義特徵字典
features = {
    'stationary_ratio': 'Stationary Ratio',
    'Tik-norm': 'Tik-norm',
    'Total_Variation': 'Total Variation',
    'graph_energy': 'Graph Energy',
    'spectral_entropy': 'Spectral Entropy',
    'signal_energy': 'Signal Energy',
    'signal_power': 'Signal Power',
    'avg_degree': 'Average Degree',
    'diffusion_distance': 'Diffusion Distance',
}

# 準備特徵和標籤
X = data[list(features.keys())].copy()
y = data['Group'].copy()

# 數據清理
for col in X.columns:
    first_val = X[col].iloc[0] if len(X) > 0 else None
    if isinstance(first_val, str):
        def safe_convert(x):
            if pd.isna(x):
                return np.nan
            if isinstance(x, str):
                try:
                    parsed = ast.literal_eval(x)
                    if isinstance(parsed, (list, tuple, np.ndarray)):
                        return float(parsed[0]) if len(parsed) > 0 else np.nan
                    return float(parsed)
                except:
                    try:
                        return float(x)
                    except:
                        try:
                            parsed = eval('[' + ','.join(x.strip('[]').split()) + ']')
                            return float(parsed[0]) if len(parsed) > 0 else np.nan
                        except:
                            return np.nan
            try:
                return float(x) if pd.notna(x) else np.nan
            except:
                return np.nan
        X[col] = X[col].apply(safe_convert)
    else:
        X[col] = pd.to_numeric(X[col], errors='coerce')

rows_with_all_nan = X.isna().all(axis=1)
valid_mask = ~(rows_with_all_nan | y.isna())
X = X[valid_mask].copy()
y = y[valid_mask].copy()

if X.isna().sum().sum() > 0:
    for col in X.columns:
        if X[col].isna().sum() > 0:
            median_val = X[col].median()
            X[col] = X[col].fillna(median_val)

X = X.astype(float)

if len(X) == 0:
    raise ValueError("錯誤：清理後數據為空！請檢查原始數據。")

# 數據分割（目標比例：訓練集 30, 驗證集 31, 測試集 27）
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=27/88, random_state=42, stratify=y
)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=31/61, random_state=42, stratify=y_train_val
)

X_train = X_train.astype(float)
X_val = X_val.astype(float)
X_test = X_test.astype(float)

# 標準化
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# 編碼標籤
le = LabelEncoder()
y_train_int = le.fit_transform(y_train)
y_val_int = le.transform(y_val)
y_test_int = le.transform(y_test)

print("="*60)
print("數據分割信息")
print("="*60)
print(f"總樣本數: {len(X)}")
print(f"訓練集樣本數: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"驗證集樣本數: {len(X_val)} ({len(X_val)/len(X)*100:.1f}%)")
print(f"測試集樣本數: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
print(f"\n類別編碼: {dict(zip(le.classes_, range(len(le.classes_))))}")
print(f"訓練集類別分布: {dict(zip(le.classes_, np.bincount(y_train_int)))}")
print(f"驗證集類別分布: {dict(zip(le.classes_, np.bincount(y_val_int)))}")
print(f"測試集類別分布: {dict(zip(le.classes_, np.bincount(y_test_int)))}")
print("="*60)

# ============================================
# 步驟 1：健康 vs 疾病（二分類）- 使用 Random Forest
# ============================================
print("\n" + "="*60)
print("階層式分類：步驟 1 - 健康 vs 疾病（二分類）- 使用 Random Forest")
print("="*60)

y_binary_train = (y_train_int == 1).astype(int)  # C=1 是健康，A 和 F 是疾病
y_binary_val = (y_val_int == 1).astype(int)
y_binary_test = (y_test_int == 1).astype(int)

# 訓練 Random Forest 二分類模型
binary_model = RandomForestClassifier(
    n_estimators=50,  # 樹的數量
    max_depth=3,  # 限制深度防止過擬合
    min_samples_split=5,  # 內部節點最小樣本數
    min_samples_leaf=3,  # 葉節點最小樣本數
    max_features='sqrt',  # 每次分割考慮的特徵數
    class_weight='balanced',  # 處理類別不平衡
    random_state=42,
    n_jobs=-1  # 使用所有 CPU 核心
)

binary_model.fit(X_train_scaled, y_binary_train)

# 驗證集預測
binary_val_pred = binary_model.predict(X_val_scaled)
binary_val_proba = binary_model.predict_proba(X_val_scaled)

# 測試集預測
binary_test_pred = binary_model.predict(X_test_scaled)
binary_test_proba = binary_model.predict_proba(X_test_scaled)

print("\n驗證集結果:")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_binary_val, binary_val_pred):.4f}")
print(f"混淆矩陣:")
print(confusion_matrix(y_binary_val, binary_val_pred))
print(f"\n分類報告:")
print(classification_report(y_binary_val, binary_val_pred, 
                            target_names=['疾病 (A+F)', '健康 (C)']))

# ROC AUC
binary_val_roc = roc_auc_score(y_binary_val, binary_val_proba[:, 1])
print(f"ROC AUC: {binary_val_roc:.4f}")

print("\n測試集結果:")
print(f"Balanced Accuracy: {balanced_accuracy_score(y_binary_test, binary_test_pred):.4f}")
print(f"混淆矩陣:")
print(confusion_matrix(y_binary_test, binary_test_pred))
print(f"\n分類報告:")
print(classification_report(y_binary_test, binary_test_pred,
                            target_names=['疾病 (A+F)', '健康 (C)']))

# ROC AUC
binary_test_roc = roc_auc_score(y_binary_test, binary_test_proba[:, 1])
print(f"ROC AUC: {binary_test_roc:.4f}")

# ============================================
# 步驟 2：在疾病樣本中區分 A 和 F - 使用 Random Forest
# ============================================
print("\n" + "="*60)
print("階層式分類：步驟 2 - AD vs FTD（二分類）- 使用 Random Forest")
print("="*60)

disease_mask_train = y_train_int != 1
disease_mask_val = y_val_int != 1
disease_mask_test = y_test_int != 1

print(f"訓練集疾病樣本數: {disease_mask_train.sum()}")
print(f"驗證集疾病樣本數: {disease_mask_val.sum()}")
print(f"測試集疾病樣本數: {disease_mask_test.sum()}")

if disease_mask_train.sum() > 5 and disease_mask_test.sum() > 3:
    # 準備疾病樣本資料
    X_disease_train = X_train_scaled[disease_mask_train]
    y_disease_train = y_train_int[disease_mask_train]
    y_disease_train_binary = (y_disease_train == 2).astype(int)  # A=0, F=1
    
    X_disease_val = X_val_scaled[disease_mask_val]
    y_disease_val = y_val_int[disease_mask_val]
    y_disease_val_binary = (y_disease_val == 2).astype(int)
    
    X_disease_test = X_test_scaled[disease_mask_test]
    y_disease_test = y_test_int[disease_mask_test]
    y_disease_test_binary = (y_disease_test == 2).astype(int)
    
    print(f"\n疾病樣本分布:")
    print(f"訓練集 - A: {np.sum(y_disease_train_binary == 0)}, F: {np.sum(y_disease_train_binary == 1)}")
    print(f"驗證集 - A: {np.sum(y_disease_val_binary == 0)}, F: {np.sum(y_disease_val_binary == 1)}")
    print(f"測試集 - A: {np.sum(y_disease_test_binary == 0)}, F: {np.sum(y_disease_test_binary == 1)}")
    
    # 訓練 Random Forest 疾病分類模型
    disease_model = RandomForestClassifier(
        n_estimators=30,  # 較少的樹（因為樣本更少）
        max_depth=2,  # 更淺的樹
        min_samples_split=3,
        min_samples_leaf=2,
        max_features='sqrt',
        class_weight='balanced',
        random_state=42,
        n_jobs=-1
    )
    
    disease_model.fit(X_disease_train, y_disease_train_binary)
    
    # 驗證集預測
    disease_val_pred = disease_model.predict(X_disease_val)
    disease_val_proba = disease_model.predict_proba(X_disease_val)
    
    # 測試集預測
    disease_test_pred = disease_model.predict(X_disease_test)
    disease_test_proba = disease_model.predict_proba(X_disease_test)
    
    print("\n驗證集結果:")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_disease_val_binary, disease_val_pred):.4f}")
    print(f"混淆矩陣:")
    print(confusion_matrix(y_disease_val_binary, disease_val_pred))
    print(f"\n分類報告:")
    print(classification_report(y_disease_val_binary, disease_val_pred,
                                target_names=['AD (A)', 'FTD (F)']))
    
    # ROC AUC
    if len(np.unique(y_disease_val_binary)) > 1:
        disease_val_roc = roc_auc_score(y_disease_val_binary, disease_val_proba[:, 1])
        print(f"ROC AUC: {disease_val_roc:.4f}")
    
    print("\n測試集結果:")
    print(f"Balanced Accuracy: {balanced_accuracy_score(y_disease_test_binary, disease_test_pred):.4f}")
    print(f"混淆矩陣:")
    print(confusion_matrix(y_disease_test_binary, disease_test_pred))
    print(f"\n分類報告:")
    print(classification_report(y_disease_test_binary, disease_test_pred,
                                target_names=['AD (A)', 'FTD (F)']))
    
    # ROC AUC
    if len(np.unique(y_disease_test_binary)) > 1:
        disease_test_roc = roc_auc_score(y_disease_test_binary, disease_test_proba[:, 1])
        print(f"ROC AUC: {disease_test_roc:.4f}")
    
    # ============================================
    # 步驟 3：組合預測（三分類最終結果）
    # ============================================
    print("\n" + "="*60)
    print("階層式分類：最終三分類結果（使用 Random Forest）")
    print("="*60)
    
    # 驗證集組合預測
    final_val_pred = np.zeros(len(y_val_int))
    final_val_pred[y_binary_val == 1] = 1  # 健康 = C (1)
    disease_val_indices = np.where(y_binary_val == 0)[0]
    # 如果預測為疾病，則使用疾病分類器的預測：A=0, F=2
    final_val_pred[disease_val_indices] = disease_val_pred * 2
    
    # 測試集組合預測
    final_test_pred = np.zeros(len(y_test_int))
    final_test_pred[y_binary_test == 1] = 1  # 健康 = C (1)
    disease_test_indices = np.where(y_binary_test == 0)[0]
    # 如果預測為疾病，則使用疾病分類器的預測：A=0, F=2
    final_test_pred[disease_test_indices] = disease_test_pred * 2
    
    # ============================================
    # 驗證集完整評估
    # ============================================
    print("\n" + "="*60)
    print("Validation Metrics for Hierarchical Classification (Random Forest)")
    print("="*60)
    
    val_balanced_acc = balanced_accuracy_score(y_val_int, final_val_pred.astype(int))
    print(f"\nBalanced Accuracy: {val_balanced_acc:.4f}")
    
    print(f"\n混淆矩陣:")
    print(confusion_matrix(y_val_int, final_val_pred.astype(int)))
    
    print(f"\n分類報告:")
    print(classification_report(y_val_int, final_val_pred.astype(int), 
                                target_names=list(le.classes_)))
    
    # 計算每個類別的預測概率（用於 ROC AUC）
    final_val_proba = np.zeros((len(y_val_int), 3))
    
    # 健康類別（C）的概率
    final_val_proba[:, 1] = binary_val_proba[:, 1]
    
    # 疾病類別的概率 = P(疾病) * P(具體疾病類型|疾病)
    disease_prob = binary_val_proba[:, 0]  # P(疾病)
    final_val_proba[disease_val_indices, 0] = disease_prob[disease_val_indices] * disease_val_proba[:, 0]  # A
    final_val_proba[disease_val_indices, 2] = disease_prob[disease_val_indices] * disease_val_proba[:, 1]  # F
    
    # 歸一化概率
    final_val_proba = final_val_proba / (final_val_proba.sum(axis=1, keepdims=True) + 1e-10)
    
    # ROC AUC
    val_binarized = label_binarize(y_val_int, classes=[0, 1, 2])
    print("\n每個類別的 ROC AUC:")
    for i, class_name in enumerate(le.classes_):
        fpr, tpr, _ = roc_curve(val_binarized[:, i], final_val_proba[:, i])
        roc_auc = auc(fpr, tpr)
        print(f"  {class_name} (Class {i}): {roc_auc:.4f}")
    
    macro_roc_auc_val = roc_auc_score(y_val_int, final_val_proba, multi_class='ovr', average='macro')
    weighted_roc_auc_val = roc_auc_score(y_val_int, final_val_proba, multi_class='ovr', average='weighted')
    
    print(f"\n宏平均 ROC AUC: {macro_roc_auc_val:.4f}")
    print(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}")
    
    # ============================================
    # 測試集完整評估
    # ============================================
    print("\n" + "="*60)
    print("Test Metrics for Hierarchical Classification (Random Forest)")
    print("="*60)
    
    test_balanced_acc = balanced_accuracy_score(y_test_int, final_test_pred.astype(int))
    print(f"\nBalanced Accuracy: {test_balanced_acc:.4f}")
    
    print(f"\n混淆矩陣:")
    print(confusion_matrix(y_test_int, final_test_pred.astype(int)))
    
    print(f"\n分類報告:")
    print(classification_report(y_test_int, final_test_pred.astype(int),
                                target_names=list(le.classes_)))
    
    # 計算測試集概率
    final_test_proba = np.zeros((len(y_test_int), 3))
    final_test_proba[:, 1] = binary_test_proba[:, 1]  # 健康類別
    disease_prob_test = binary_test_proba[:, 0]
    final_test_proba[disease_test_indices, 0] = disease_prob_test[disease_test_indices] * disease_test_proba[:, 0]
    final_test_proba[disease_test_indices, 2] = disease_prob_test[disease_test_indices] * disease_test_proba[:, 1]
    final_test_proba = final_test_proba / (final_test_proba.sum(axis=1, keepdims=True) + 1e-10)
    
    # ROC AUC
    test_binarized = label_binarize(y_test_int, classes=[0, 1, 2])
    print("\n每個類別的 ROC AUC:")
    for i, class_name in enumerate(le.classes_):
        fpr, tpr, _ = roc_curve(test_binarized[:, i], final_test_proba[:, i])
        roc_auc = auc(fpr, tpr)
        print(f"  {class_name} (Class {i}): {roc_auc:.4f}")
    
    macro_roc_auc_test = roc_auc_score(y_test_int, final_test_proba, multi_class='ovr', average='macro')
    weighted_roc_auc_test = roc_auc_score(y_test_int, final_test_proba, multi_class='ovr', average='weighted')
    
    print(f"\n宏平均 ROC AUC: {macro_roc_auc_test:.4f}")
    print(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}")
    
    # ============================================
    # 特徵重要性分析（Random Forest 特有）
    # ============================================
    print("\n" + "="*60)
    print("特徵重要性分析")
    print("="*60)
    
    print("\n步驟 1（健康 vs 疾病）特徵重要性:")
    feature_importance_binary = pd.DataFrame({
        'feature': list(features.keys()),
        'importance': binary_model.feature_importances_
    }).sort_values('importance', ascending=False)
    print(feature_importance_binary)
    
    print("\n步驟 2（AD vs FTD）特徵重要性:")
    feature_importance_disease = pd.DataFrame({
        'feature': list(features.keys()),
        'importance': disease_model.feature_importances_
    }).sort_values('importance', ascending=False)
    print(feature_importance_disease)
    
    # ============================================
    # 過擬合分析
    # ============================================
    print("\n" + "="*60)
    print("過擬合分析")
    print("="*60)
    overfitting_gap = val_balanced_acc - test_balanced_acc
    print(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}")
    print(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}")
    print(f"過擬合差距: {overfitting_gap:.4f}")
    
    if overfitting_gap > 0.15:
        print("⚠️ 警告：存在明顯過擬合（差距 > 0.15）")
        print("   建議：增加 max_depth 限制或增加 min_samples_split/min_samples_leaf")
    elif overfitting_gap > 0.10:
        print("⚠️ 注意：存在一定過擬合（差距 > 0.10）")
    else:
        print("✓ 過擬合程度在可接受範圍內")
    print("="*60)
    
    # ============================================
    # 類別預測分布分析
    # ============================================
    print("\n" + "="*60)
    print("類別預測分布分析")
    print("="*60)
    
    print("\n驗證集預測分布:")
    val_pred_counts = {le.classes_[i]: np.sum(final_val_pred.astype(int) == i) for i in range(len(le.classes_))}
    val_true_counts = {le.classes_[i]: np.sum(y_val_int == i) for i in range(len(le.classes_))}
    print(f"真實分布: {val_true_counts}")
    print(f"預測分布: {val_pred_counts}")
    
    print("\n測試集預測分布:")
    test_pred_counts = {le.classes_[i]: np.sum(final_test_pred.astype(int) == i) for i in range(len(le.classes_))}
    test_true_counts = {le.classes_[i]: np.sum(y_test_int == i) for i in range(len(le.classes_))}
    print(f"真實分布: {test_true_counts}")
    print(f"預測分布: {test_pred_counts}")
    print("="*60)
    
else:
    print("⚠️ 警告：疾病樣本數量不足，無法進行第二步分類")
    print("只進行健康 vs 疾病的分類")



# ============================================
# 保存結果和繪製 Confusion Matrix
# ============================================
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime

# 創建結果目錄
results_dir = './two_stage_results/random_forest'
os.makedirs(results_dir, exist_ok=True)

# 保存預測結果到 CSV
if disease_mask_train.sum() > 5 and disease_mask_test.sum() > 3:
    # 驗證集結果
    val_results_df = pd.DataFrame({
        'true_label': [le.classes_[i] for i in y_val_int],
        'predicted_label': [le.classes_[int(i)] for i in final_val_pred],
        'true_label_int': y_val_int,
        'predicted_label_int': final_val_pred.astype(int),
        'prob_A': final_val_proba[:, 0],
        'prob_C': final_val_proba[:, 1],
        'prob_F': final_val_proba[:, 2]
    })
    val_results_df.to_csv(os.path.join(results_dir, 'validation_predictions.csv'), index=False)
    print(f"\n✓ 驗證集預測結果已保存到: {os.path.join(results_dir, 'validation_predictions.csv')}")
    
    # 測試集結果
    test_results_df = pd.DataFrame({
        'true_label': [le.classes_[i] for i in y_test_int],
        'predicted_label': [le.classes_[int(i)] for i in final_test_pred],
        'true_label_int': y_test_int,
        'predicted_label_int': final_test_pred.astype(int),
        'prob_A': final_test_proba[:, 0],
        'prob_C': final_test_proba[:, 1],
        'prob_F': final_test_proba[:, 2]
    })
    test_results_df.to_csv(os.path.join(results_dir, 'test_predictions.csv'), index=False)
    print(f"✓ 測試集預測結果已保存到: {os.path.join(results_dir, 'test_predictions.csv')}")
    
    # 保存評估指標到文本文件
    with open(os.path.join(results_dir, 'evaluation_metrics.txt'), 'w', encoding='utf-8') as f:
        f.write("="*60 + "\n")
        f.write("階層式分類評估結果\n")
        f.write("="*60 + "\n")
        f.write(f"生成時間: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n\n")
        
        f.write("驗證集結果:\n")
        f.write("-"*60 + "\n")
        f.write(f"Balanced Accuracy: {val_balanced_acc:.4f}\n")
        f.write(f"宏平均 ROC AUC: {macro_roc_auc_val:.4f}\n")
        f.write(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}\n\n")
        f.write("混淆矩陣:\n")
        f.write(str(confusion_matrix(y_val_int, final_val_pred.astype(int))) + "\n\n")
        f.write("分類報告:\n")
        f.write(classification_report(y_val_int, final_val_pred.astype(int), 
                                     target_names=list(le.classes_)) + "\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("測試集結果:\n")
        f.write("-"*60 + "\n")
        f.write(f"Balanced Accuracy: {test_balanced_acc:.4f}\n")
        f.write(f"宏平均 ROC AUC: {macro_roc_auc_test:.4f}\n")
        f.write(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}\n\n")
        f.write("混淆矩陣:\n")
        f.write(str(confusion_matrix(y_test_int, final_test_pred.astype(int))) + "\n\n")
        f.write("分類報告:\n")
        f.write(classification_report(y_test_int, final_test_pred.astype(int),
                                     target_names=list(le.classes_)) + "\n")
        
        f.write("\n" + "="*60 + "\n")
        f.write("過擬合分析:\n")
        f.write("-"*60 + "\n")
        f.write(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}\n")
        f.write(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}\n")
        f.write(f"過擬合差距: {overfitting_gap:.4f}\n")
    
    print(f"✓ 評估指標已保存到: {os.path.join(results_dir, 'evaluation_metrics.txt')}")
    
    # 繪製 Confusion Matrix
    class_names = list(le.classes_)
    
    # 驗證集 Confusion Matrix
    cm_val = confusion_matrix(y_val_int, final_val_pred.astype(int))
    cm_val_normalized = cm_val.astype('float') / (cm_val.sum(axis=1)[:, np.newaxis] + 1e-10)
    cm_val_normalized = np.nan_to_num(cm_val_normalized)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 原始混淆矩陣
    sns.heatmap(cm_val, annot=True, fmt='d', cmap='Blues', 
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Validation Set Confusion Matrix (Count)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    
    
    # 歸一化混淆矩陣
    sns.heatmap(cm_val_normalized, annot=True, fmt='.2%', cmap='Blues',
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_title('Validation Set Confusion Matrix (Percentage)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    val_cm_path = os.path.join(results_dir, 'confusion_matrix_validation.png')
    plt.savefig(val_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ 驗證集混淆矩陣已保存到: {val_cm_path}")
    
    # 測試集 Confusion Matrix
    cm_test = confusion_matrix(y_test_int, final_test_pred.astype(int))
    cm_test_normalized = cm_test.astype('float') / (cm_test.sum(axis=1)[:, np.newaxis] + 1e-10)
    cm_test_normalized = np.nan_to_num(cm_test_normalized)
    
    fig, axes = plt.subplots(1, 2, figsize=(16, 6))
    
    # 原始混淆矩陣
    sns.heatmap(cm_test, annot=True, fmt='d', cmap='Blues', 
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[0], cbar_kws={'label': 'Count'})
    axes[0].set_title('Test Set Confusion Matrix (Count)', fontsize=14, fontweight='bold')
    axes[0].set_xlabel('Predicted Label', fontsize=12)
    axes[0].set_ylabel('True Label', fontsize=12)
    
    # 歸一化混淆矩陣
    sns.heatmap(cm_test_normalized, annot=True, fmt='.2%', cmap='Blues',
               xticklabels=class_names, yticklabels=class_names,
               ax=axes[1], cbar_kws={'label': 'Percentage'})
    axes[1].set_title('Test Set Confusion Matrix (Percentage)', fontsize=14, fontweight='bold')
    axes[1].set_xlabel('Predicted Label', fontsize=12)
    axes[1].set_ylabel('True Label', fontsize=12)
    
    plt.tight_layout()
    test_cm_path = os.path.join(results_dir, 'confusion_matrix_test.png')
    plt.savefig(test_cm_path, dpi=300, bbox_inches='tight')
    plt.close()
    print(f"✓ 測試集混淆矩陣已保存到: {test_cm_path}")
    
    print("\n" + "="*60)
    print("所有結果已保存完成！")
    print("="*60)
    print(f"結果目錄: {results_dir}")
    print("保存的文件:")
    print("  - validation_predictions.csv (驗證集預測結果)")
    print("  - test_predictions.csv (測試集預測結果)")
    print("  - evaluation_metrics.txt (評估指標)")
    print("  - confusion_matrix_validation.png (驗證集混淆矩陣)")
    print("  - confusion_matrix_test.png (測試集混淆矩陣)")
    print("="*60)


數據分割信息
總樣本數: 88
訓練集樣本數: 30 (34.1%)
驗證集樣本數: 31 (35.2%)
測試集樣本數: 27 (30.7%)

類別編碼: {'A': 0, 'C': 1, 'F': 2}
訓練集類別分布: {'A': np.int64(12), 'C': np.int64(10), 'F': np.int64(8)}
驗證集類別分布: {'A': np.int64(13), 'C': np.int64(10), 'F': np.int64(8)}
測試集類別分布: {'A': np.int64(11), 'C': np.int64(9), 'F': np.int64(7)}

階層式分類：步驟 1 - 健康 vs 疾病（二分類）- 使用 Random Forest

驗證集結果:
Balanced Accuracy: 0.5571
混淆矩陣:
[[15  6]
 [ 6  4]]

分類報告:
              precision    recall  f1-score   support

    疾病 (A+F)       0.71      0.71      0.71        21
      健康 (C)       0.40      0.40      0.40        10

    accuracy                           0.61        31
   macro avg       0.56      0.56      0.56        31
weighted avg       0.61      0.61      0.61        31

ROC AUC: 0.5286

測試集結果:
Balanced Accuracy: 0.3611
混淆矩陣:
[[11  7]
 [ 8  1]]

分類報告:
              precision    recall  f1-score   support

    疾病 (A+F)       0.58      0.61      0.59        18
      健康 (C)       0.12      0.11      0.12         9

    accuracy 


驗證集結果:
Balanced Accuracy: 0.5962
混淆矩陣:
[[9 4]
 [4 4]]

分類報告:
              precision    recall  f1-score   support

      AD (A)       0.69      0.69      0.69        13
     FTD (F)       0.50      0.50      0.50         8

    accuracy                           0.62        21
   macro avg       0.60      0.60      0.60        21
weighted avg       0.62      0.62      0.62        21

ROC AUC: 0.6106

測試集結果:
Balanced Accuracy: 0.4416
混淆矩陣:
[[5 6]
 [4 3]]

分類報告:
              precision    recall  f1-score   support

      AD (A)       0.56      0.45      0.50        11
     FTD (F)       0.33      0.43      0.38         7

    accuracy                           0.44        18
   macro avg       0.44      0.44      0.44        18
weighted avg       0.47      0.44      0.45        18

ROC AUC: 0.5455

階層式分類：最終三分類結果（使用 Random Forest）

Validation Metrics for Hierarchical Classification (Random Forest)

Balanced Accuracy: 0.7308

混淆矩陣:
[[ 9  0  4]
 [ 0 10  0]
 [ 4  0  4]]

分類報告:
           

### MLP (Multi-Layer Perceptron)

In [None]:
from sklearn.model_selection import GridSearchCV, train_test_split, StratifiedKFold
from sklearn.neural_network import MLPClassifier
from sklearn.preprocessing import RobustScaler
from sklearn.metrics import classification_report, roc_auc_score, confusion_matrix, balanced_accuracy_score
from sklearn.preprocessing import LabelEncoder
from sklearn.utils.class_weight import compute_sample_weight
import pandas as pd
import numpy as np
import ast
import warnings
from sklearn.exceptions import ConvergenceWarning

warnings.filterwarnings("ignore")

# 準備特徵和標籤
X = data[list(features.keys())].copy()
y = data['Group'].copy()

# 數據清理（與之前相同）
for col in X.columns:
    first_val = X[col].iloc[0] if len(X) > 0 else None
    if isinstance(first_val, str):
        def safe_convert(x):
            if pd.isna(x):
                return np.nan
            if isinstance(x, str):
                try:
                    parsed = ast.literal_eval(x)
                    if isinstance(parsed, (list, tuple, np.ndarray)):
                        return float(parsed[0]) if len(parsed) > 0 else np.nan
                    return float(parsed)
                except:
                    try:
                        return float(x)
                    except:
                        try:
                            parsed = eval('[' + ','.join(x.strip('[]').split()) + ']')
                            return float(parsed[0]) if len(parsed) > 0 else np.nan
                        except:
                            return np.nan
            try:
                return float(x) if pd.notna(x) else np.nan
            except:
                return np.nan
        X[col] = X[col].apply(safe_convert)
    else:
        X[col] = pd.to_numeric(X[col], errors='coerce')

rows_with_all_nan = X.isna().all(axis=1)
valid_mask = ~(rows_with_all_nan | y.isna())
X = X[valid_mask].copy()
y = y[valid_mask].copy()

if X.isna().sum().sum() > 0:
    for col in X.columns:
        if X[col].isna().sum() > 0:
            median_val = X[col].median()
            X[col] = X[col].fillna(median_val)

X = X.astype(float)

if len(X) == 0:
    raise ValueError("錯誤：清理後數據為空！請檢查原始數據。")

# Split the data into train+validation set and test set
X_train_val, X_test, y_train_val, y_test = train_test_split(X, y, test_size=0.3, random_state=42)

# Further split the train data into train set and validation set
X_train, X_val, y_train, y_val = train_test_split(X_train_val, y_train_val, test_size=0.5, random_state=42)

X_train = X_train.astype(float)
X_val = X_val.astype(float)
X_test = X_test.astype(float)

scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# 編碼標籤
le = LabelEncoder()
y_train_int = le.fit_transform(y_train)
y_val_int = le.transform(y_val)
y_test_int = le.transform(y_test)

print("="*60)
print("數據分割信息")
print("="*60)
print(f"總樣本數: {len(X)}")
print(f"訓練集樣本數: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"驗證集樣本數: {len(X_val)} ({len(X_val)/len(X)*100:.1f}%)")
print(f"測試集樣本數: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
print("="*60)

print("="*60)
print("數據集信息")
print("="*60)
print(f"訓練集類別分布: {dict(zip(le.classes_, np.bincount(y_train_int)))}")
print(f"驗證集類別分布: {dict(zip(le.classes_, np.bincount(y_val_int)))}")
print(f"測試集類別分布: {dict(zip(le.classes_, np.bincount(y_test_int)))}")

# 使用標準的 balanced 權重
class_counts = np.bincount(y_train_int)
total_samples = len(y_train_int)
n_classes = len(class_counts)

class_weights = total_samples / (n_classes * class_counts)
sample_weights = np.array([class_weights[y] for y in y_train_int])

print(f"類別權重: {dict(zip(le.classes_, class_weights))}")
print(f"樣本權重範圍: {sample_weights.min():.3f} - {sample_weights.max():.3f}\n")

# WeightedMLPClassifier 類
from sklearn.base import BaseEstimator, ClassifierMixin

class WeightedMLPClassifier(BaseEstimator, ClassifierMixin):
    _estimator_type = "classifier"
    
    def __init__(self, **kwargs):
        self.mlp = MLPClassifier(**kwargs)
        self.sample_weights = None
        
    def fit(self, X, y, sample_weight=None):
        self.sample_weights = sample_weight
        self.mlp.fit(X, y, sample_weight=sample_weight)
        return self
    
    def predict(self, X):
        return self.mlp.predict(X)
    
    def predict_proba(self, X):
        return self.mlp.predict_proba(X)
    
    def get_params(self, deep=True):
        return self.mlp.get_params(deep=deep)
    
    def set_params(self, **params):
        self.mlp.set_params(**params)
        return self
    
    @property
    def classes_(self):
        return self.mlp.classes_

# 極簡化配置：只使用單層，參數數量 < 樣本數量
# 目標：參數/樣本比例 < 1.0
n_features = X_train_scaled.shape[1]  # 9個特徵
n_classes = len(le.classes_)  # 3個類別

print("="*60)
print("模型複雜度分析")
print("="*60)
print(f"特徵數量: {n_features}")
print(f"類別數量: {n_classes}")
print(f"訓練樣本數: {len(X_train)}")
print("\n不同架構的參數數量估算:")

# 計算不同架構的參數數量
architectures = [(3,), (4,), (5,), (6,), (8,), (10,)]
for arch in architectures:
    if len(arch) == 1:
        n_params = n_features * arch[0] + arch[0] * n_classes + arch[0] + n_classes
        ratio = n_params / len(X_train)
        status = "✓" if ratio < 1.0 else "✗"
        print(f"  {arch}: {n_params} 參數, 比例 {ratio:.2f} {status}")

print("="*60)

params = {
    'hidden_layer_sizes': [
        # 只使用非常小的單層（確保參數 < 樣本數）
        (3,), (4,), (5,), (6,), (8,), (10,),
    ],
    'activation': ['relu', 'tanh'],
    'alpha': [3.0, 5.0, 10.0, 15.0, 20.0],  # 非常強的正則化
    'learning_rate': ['constant', 'adaptive'],
    'learning_rate_init': [0.001, 0.01],
    'beta_1': [0.9],
    'beta_2': [0.999],
}

clf = WeightedMLPClassifier(
    random_state=42, 
    max_iter=3000,
    early_stopping=True, 
    validation_fraction=0.25,
    tol=1e-3,
    n_iter_no_change=15,
    solver='adam',
    batch_size='auto'
)

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

grid_search = GridSearchCV(
    clf, 
    params, 
    cv=skf,
    n_jobs=-1, 
    verbose=1,
    scoring='balanced_accuracy',
    refit=True
)

print("="*60)
print("開始 Fine-tune 網格搜索（極簡化單層配置）...")
print("="*60)
grid_search.fit(X_train_scaled, y_train_int, sample_weight=sample_weights)

best_clf = grid_search.best_estimator_
print(f"\n最佳參數: {grid_search.best_params_}")
print(f"最佳交叉驗證 Balanced Accuracy: {grid_search.best_score_:.4f}")

# 計算參數數量
best_hidden = grid_search.best_params_['hidden_layer_sizes']
n_features = X_train_scaled.shape[1]
n_classes = len(le.classes_)

if isinstance(best_hidden, tuple) and len(best_hidden) == 1:
    n_params = n_features * best_hidden[0] + best_hidden[0] * n_classes + best_hidden[0] + n_classes
    ratio = n_params / len(X_train)
    print(f"\n估計參數數量: {n_params}")
    print(f"訓練樣本數量: {len(X_train)}")
    print(f"參數/樣本比例: {ratio:.2f}", end="")
    if ratio < 1.0:
        print(" ✓ (符合建議)")
    elif ratio < 1.5:
        print(" ⚠️ (略高，但可接受)")
    else:
        print(" ✗ (仍然太高)")

# 顯示最終使用的網絡架構
print("\n" + "="*60)
print("最終使用的網絡架構")
print("="*60)
best_hidden_layers = grid_search.best_params_['hidden_layer_sizes']
print(f"hidden_layer_sizes: {best_hidden_layers}")

if isinstance(best_hidden_layers, tuple):
    if len(best_hidden_layers) == 1:
        print(f"架構類型: 單層MLP")
        print(f"  隱藏層: {best_hidden_layers[0]} 個神經元")
    else:
        print(f"架構類型: {len(best_hidden_layers)}層MLP")
        for i, neurons in enumerate(best_hidden_layers, 1):
            print(f"  第{i}層隱藏層: {neurons} 個神經元")

print(f"\n其他重要參數:")
print(f"  激活函數: {grid_search.best_params_['activation']}")
print(f"  正則化參數 (alpha): {grid_search.best_params_['alpha']}")
print(f"  學習率: {grid_search.best_params_['learning_rate']}")
print(f"  初始學習率: {grid_search.best_params_['learning_rate_init']}")
print("="*60)

# Validation set metrics
y_val_pred = best_clf.predict(X_val_scaled)
y_val_pred_proba = best_clf.predict_proba(X_val_scaled)
y_val_pred_labels = le.inverse_transform(y_val_pred)

print("\n" + "="*60)
print("Validation Metrics for Fine-tuned MLPClassifier Model")
print("="*60)
print("\n混淆矩陣:")
print(confusion_matrix(y_val_int, y_val_pred))
print("\n分類報告:")
print(classification_report(y_val, y_val_pred_labels, target_names=list(le.classes_)))

from sklearn.metrics import roc_curve, auc
from sklearn.preprocessing import label_binarize

y_val_binarized = label_binarize(y_val_int, classes=[0, 1, 2])
print("\n每個類別的 ROC AUC:")
for i, class_name in enumerate(le.classes_):
    fpr, tpr, _ = roc_curve(y_val_binarized[:, i], y_val_pred_proba[:, i])
    roc_auc = auc(fpr, tpr)
    print(f"  {class_name} (Class {i}): {roc_auc:.4f}")

macro_roc_auc = roc_auc_score(y_val_int, y_val_pred_proba, multi_class='ovr', average='macro')
weighted_roc_auc = roc_auc_score(y_val_int, y_val_pred_proba, multi_class='ovr', average='weighted')
val_balanced_acc = balanced_accuracy_score(y_val_int, y_val_pred)

print(f"\n宏平均 ROC AUC: {macro_roc_auc:.4f}")
print(f"加權平均 ROC AUC: {weighted_roc_auc:.4f}")
print(f"Balanced Accuracy: {val_balanced_acc:.4f}")

# Test set metrics
y_test_pred = best_clf.predict(X_test_scaled)
y_test_pred_proba = best_clf.predict_proba(X_test_scaled)
y_test_pred_labels = le.inverse_transform(y_test_pred)

print("\n" + "="*60)
print("Test Metrics for Fine-tuned MLPClassifier Model")
print("="*60)
print("\n混淆矩陣:")
print(confusion_matrix(y_test_int, y_test_pred))
print("\n分類報告:")
print(classification_report(y_test, y_test_pred_labels, target_names=list(le.classes_)))

print("\n每個類別的 ROC AUC:")
y_test_binarized = label_binarize(y_test_int, classes=[0, 1, 2])
for i, class_name in enumerate(le.classes_):
    fpr, tpr, _ = roc_curve(y_test_binarized[:, i], y_test_pred_proba[:, i])
    roc_auc = auc(fpr, tpr)
    print(f"  {class_name} (Class {i}): {roc_auc:.4f}")

macro_roc_auc_test = roc_auc_score(y_test_int, y_test_pred_proba, multi_class='ovr', average='macro')
weighted_roc_auc_test = roc_auc_score(y_test_int, y_test_pred_proba, multi_class='ovr', average='weighted')
test_balanced_acc = balanced_accuracy_score(y_test_int, y_test_pred)

print(f"\n宏平均 ROC AUC: {macro_roc_auc_test:.4f}")
print(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}")
print(f"Balanced Accuracy: {test_balanced_acc:.4f}")

# 過擬合分析
print("\n" + "="*60)
print("過擬合分析")
print("="*60)
overfitting_gap = val_balanced_acc - test_balanced_acc
print(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}")
print(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}")
print(f"過擬合差距: {overfitting_gap:.4f}")

if overfitting_gap > 0.15:
    print("⚠️ 警告：存在明顯過擬合（差距 > 0.15）")
    print("   建議：考慮使用更簡單的模型（如RandomForest、Logistic Regression）")
elif overfitting_gap > 0.10:
    print("⚠️ 注意：存在一定過擬合（差距 > 0.10）")
else:
    print("✓ 過擬合程度在可接受範圍內")
print("="*60)

數據分割信息
總樣本數: 88
訓練集樣本數: 30 (34.1%)
驗證集樣本數: 31 (35.2%)
測試集樣本數: 27 (30.7%)
數據集信息
訓練集類別分布: {'A': np.int64(12), 'C': np.int64(12), 'F': np.int64(6)}
驗證集類別分布: {'A': np.int64(11), 'C': np.int64(13), 'F': np.int64(7)}
測試集類別分布: {'A': np.int64(13), 'C': np.int64(4), 'F': np.int64(10)}
類別權重: {'A': np.float64(0.8333333333333334), 'C': np.float64(0.8333333333333334), 'F': np.float64(1.6666666666666667)}
樣本權重範圍: 0.833 - 1.667

模型複雜度分析
特徵數量: 9
類別數量: 3
訓練樣本數: 30

不同架構的參數數量估算:
  (3,): 42 參數, 比例 1.40 ✗
  (4,): 55 參數, 比例 1.83 ✗
  (5,): 68 參數, 比例 2.27 ✗
  (6,): 81 參數, 比例 2.70 ✗
  (8,): 107 參數, 比例 3.57 ✗
  (10,): 133 參數, 比例 4.43 ✗
開始 Fine-tune 網格搜索（極簡化單層配置）...
Fitting 5 folds for each of 240 candidates, totalling 1200 fits

最佳參數: {'activation': 'tanh', 'alpha': 15.0, 'beta_1': 0.9, 'beta_2': 0.999, 'hidden_layer_sizes': (4,), 'learning_rate': 'constant', 'learning_rate_init': 0.01}
最佳交叉驗證 Balanced Accuracy: 0.4333

估計參數數量: 55
訓練樣本數量: 30
參數/樣本比例: 1.83 ✗ (仍然太高)

最終使用的網絡架構
hidden_layer_sizes: (4,)
架構類型: 單層ML

### 多尺度特徵提取 CNN 模型（類似 Receptive Field）

In [None]:
# ============================================
# 完整訓練程式：改進版多尺度 CNN
# ============================================

# ============================================
# 步驟 1: 資料載入和預處理
# ============================================
import pandas as pd
import numpy as np
import os
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import RobustScaler, LabelEncoder
from sklearn.metrics import classification_report, confusion_matrix, balanced_accuracy_score, roc_auc_score
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc
import ast
import warnings
warnings.filterwarnings("ignore")

# 設置數據路徑
base_dir = r'/ibmnas/427/bachelors/b12901077/eeg'
dataset_dir = os.path.join(base_dir, 'ds004504')

# 讀取特徵數據
features_df = pd.read_csv(os.path.join(base_dir, "features_tv.csv"))
# 讀取參與者信息
participants = pd.read_csv(os.path.join(dataset_dir, "participants.tsv"), delimiter='\t')

# 合併數據
data = features_df.merge(participants, left_index=True, right_index=True)

# 定義特徵字典
features = {
    'stationary_ratio': 'Stationary Ratio',
    'Tik-norm': 'Tik-norm',
    'Total_Variation': 'Total Variation',
    'graph_energy': 'Graph Energy',
    'spectral_entropy': 'Spectral Entropy',
    'signal_energy': 'Signal Energy',
    'signal_power': 'Signal Power',
    'avg_degree': 'Average Degree',
    'diffusion_distance': 'Diffusion Distance',
}

# 準備特徵和標籤
X = data[list(features.keys())].copy()
y = data['Group'].copy()

# 數據清理
for col in X.columns:
    first_val = X[col].iloc[0] if len(X) > 0 else None
    if isinstance(first_val, str):
        def safe_convert(x):
            if pd.isna(x):
                return np.nan
            if isinstance(x, str):
                try:
                    parsed = ast.literal_eval(x)
                    if isinstance(parsed, (list, tuple, np.ndarray)):
                        return float(parsed[0]) if len(parsed) > 0 else np.nan
                    return float(parsed)
                except:
                    try:
                        return float(x)
                    except:
                        try:
                            parsed = eval('[' + ','.join(x.strip('[]').split()) + ']')
                            return float(parsed[0]) if len(parsed) > 0 else np.nan
                        except:
                            return np.nan
            try:
                return float(x) if pd.notna(x) else np.nan
            except:
                return np.nan
        X[col] = X[col].apply(safe_convert)
    else:
        X[col] = pd.to_numeric(X[col], errors='coerce')

rows_with_all_nan = X.isna().all(axis=1)
valid_mask = ~(rows_with_all_nan | y.isna())
X = X[valid_mask].copy()
y = y[valid_mask].copy()

if X.isna().sum().sum() > 0:
    for col in X.columns:
        if X[col].isna().sum() > 0:
            median_val = X[col].median()
            X[col] = X[col].fillna(median_val)

X = X.astype(float)

if len(X) == 0:
    raise ValueError("錯誤：清理後數據為空！請檢查原始數據。")

# 數據分割（目標比例：訓練集 30, 驗證集 31, 測試集 27）
X_train_val, X_test, y_train_val, y_test = train_test_split(
    X, y, test_size=27/88, random_state=42, stratify=y
)

X_train, X_val, y_train, y_val = train_test_split(
    X_train_val, y_train_val, test_size=31/61, random_state=42, stratify=y_train_val
)

X_train = X_train.astype(float)
X_val = X_val.astype(float)
X_test = X_test.astype(float)

# 標準化
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_val_scaled = scaler.transform(X_val)
X_test_scaled = scaler.transform(X_test)

# 編碼標籤
le = LabelEncoder()
y_train_int = le.fit_transform(y_train)
y_val_int = le.transform(y_val)
y_test_int = le.transform(y_test)

print("="*60)
print("數據分割信息")
print("="*60)
print(f"總樣本數: {len(X)}")
print(f"訓練集樣本數: {len(X_train)} ({len(X_train)/len(X)*100:.1f}%)")
print(f"驗證集樣本數: {len(X_val)} ({len(X_val)/len(X)*100:.1f}%)")
print(f"測試集樣本數: {len(X_test)} ({len(X_test)/len(X)*100:.1f}%)")
print("="*60)

# ============================================
# 步驟 2: 模型定義
# ============================================
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

class FocalLoss(nn.Module):
    """Focal Loss 用於處理類別不平衡"""
    def __init__(self, alpha=None, gamma=2.0):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
    
    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)
        focal_loss = ((1 - pt) ** self.gamma) * ce_loss
        
        if self.alpha is not None:
            alpha_t = self.alpha[targets]
            focal_loss = alpha_t * focal_loss
        
        return focal_loss.mean()


class ImprovedMultiScaleCNN(nn.Module):
    """
    改進版多尺度特徵提取 CNN
    針對小樣本優化，參數數量適中
    """
    def __init__(self, n_features=9, n_classes=3, dropout_rate=0.6):
        super().__init__()
        
        # 使用 3 個通道
        self.conv_small = nn.Sequential(
            nn.Conv1d(1, 3, kernel_size=3, padding=1),
            nn.BatchNorm1d(3),
            nn.ReLU()
        )
        
        self.conv_medium = nn.Sequential(
            nn.Conv1d(1, 3, kernel_size=5, padding=2),
            nn.BatchNorm1d(3),
            nn.ReLU()
        )
        
        self.conv_large = nn.Sequential(
            nn.Conv1d(1, 3, kernel_size=9, padding=4),
            nn.BatchNorm1d(3),
            nn.ReLU()
        )
        
        # 融合層
        self.fusion = nn.Sequential(
            nn.Linear(3 * 3, 12),  # 9 -> 12
            nn.BatchNorm1d(12),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(12, n_classes)
        )
        
    def forward(self, x):
        x = x.unsqueeze(1)  # (batch, 1, 9)
        
        # 多尺度卷積
        small = torch.mean(self.conv_small(x), dim=2)  # (batch, 3)
        medium = torch.mean(self.conv_medium(x), dim=2)  # (batch, 3)
        large = torch.mean(self.conv_large(x), dim=2)  # (batch, 3)
        
        # 融合
        combined = torch.cat([small, medium, large], dim=1)  # (batch, 9)
        output = self.fusion(combined)  # (batch, 3)
        
        return output


# ============================================
# 步驟 3: 訓練函數
# ============================================
def train_model_improved(model, train_loader, val_loader, 
                         class_weights, n_epochs=300, lr=0.0005, device='cpu', 
                         use_focal_loss=True):
    model = model.to(device)
    
    # 計算類別權重
    class_weights_tensor = torch.FloatTensor(class_weights).to(device)
    
    # 選擇損失函數
    if use_focal_loss:
        criterion = FocalLoss(alpha=class_weights_tensor, gamma=2.0)
        print("使用 Focal Loss")
    else:
        criterion = nn.CrossEntropyLoss(weight=class_weights_tensor)
        print("使用加權 CrossEntropyLoss")
    
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=15
    )
    
    best_val_balanced_acc = 0
    best_val_acc = 0
    patience_counter = 0
    max_patience = 40
    best_model_state = None
    
    for epoch in range(n_epochs):
        # 訓練
        model.train()
        train_loss = 0
        for X_batch, y_batch in train_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            
            optimizer.zero_grad()
            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)
            loss.backward()
            
            # 梯度裁剪
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            
            optimizer.step()
            train_loss += loss.item()
        
        # 驗證
        model.eval()
        val_loss = 0
        val_correct = 0
        val_total = 0
        val_preds = []
        val_labels = []
        
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                val_loss += loss.item()
                
                _, predicted = torch.max(outputs.data, 1)
                val_total += y_batch.size(0)
                val_correct += (predicted == y_batch).sum().item()
                val_preds.extend(predicted.cpu().numpy())
                val_labels.extend(y_batch.cpu().numpy())
        
        val_acc = val_correct / val_total
        val_balanced_acc = balanced_accuracy_score(val_labels, val_preds)
        
        scheduler.step(val_loss)
        
        # 使用 balanced accuracy 作為早停標準
        if val_balanced_acc > best_val_balanced_acc:
            best_val_balanced_acc = val_balanced_acc
            best_val_acc = val_acc
            patience_counter = 0
            best_model_state = model.state_dict().copy()
        else:
            patience_counter += 1
        
        if patience_counter >= max_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break
        
        if (epoch + 1) % 20 == 0:
            print(f"Epoch {epoch+1}/{n_epochs}, Train Loss: {train_loss/len(train_loader):.4f}, "
                  f"Val Loss: {val_loss/len(val_loader):.4f}, Val Acc: {val_acc:.4f}, "
                  f"Val Balanced Acc: {val_balanced_acc:.4f}")
    
    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    
    return model, best_val_acc, best_val_balanced_acc


# ============================================
# 步驟 4: 評估函數
# ============================================
def evaluate_model(model, data_loader, device='cpu'):
    model.eval()
    all_preds = []
    all_probs = []
    all_labels = []
    correct = 0
    total = 0
    
    with torch.no_grad():
        for X_batch, y_batch in data_loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            probs = F.softmax(outputs, dim=1)
            _, predicted = torch.max(outputs.data, 1)
            
            all_preds.extend(predicted.cpu().numpy())
            all_probs.extend(probs.cpu().numpy())
            all_labels.extend(y_batch.cpu().numpy())
            
            total += y_batch.size(0)
            correct += (predicted == y_batch).sum().item()
    
    accuracy = correct / total
    return all_preds, all_probs, all_labels, accuracy


# ============================================
# 步驟 5: 主程式執行
# ============================================
print("\n" + "="*60)
print("改進版多尺度特徵提取 CNN（類似 Receptive Field）")
print("="*60)

# 計算類別權重（使用 inverse frequency）
class_counts = np.bincount(y_train_int)
total_samples = len(y_train_int)
n_classes = len(class_counts)

# 方法 1：標準 balanced 權重
class_weights_balanced = total_samples / (n_classes * class_counts)

# 方法 2：inverse frequency（給少數類別更高權重）
class_weights_inv = 1.0 / class_counts
class_weights_inv = class_weights_inv / class_weights_inv.sum() * n_classes

print(f"\n類別分布: {dict(zip(le.classes_, class_counts))}")
print(f"Balanced 權重: {dict(zip(le.classes_, class_weights_balanced))}")
print(f"Inverse frequency 權重: {dict(zip(le.classes_, class_weights_inv))}")

# 使用 inverse frequency 權重（給少數類別更高權重）
class_weights = class_weights_inv

# 轉換為 PyTorch 格式
X_train_tensor = torch.FloatTensor(X_train_scaled)
X_val_tensor = torch.FloatTensor(X_val_scaled)
X_test_tensor = torch.FloatTensor(X_test_scaled)

y_train_tensor = torch.LongTensor(y_train_int)
y_val_tensor = torch.LongTensor(y_val_int)
y_test_tensor = torch.LongTensor(y_test_int)

# 創建 DataLoader
train_dataset = TensorDataset(X_train_tensor, y_train_tensor)
val_dataset = TensorDataset(X_val_tensor, y_val_tensor)
test_dataset = TensorDataset(X_test_tensor, y_test_tensor)

train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=8, shuffle=False)

# 創建模型
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"\n使用設備: {device}")

model = ImprovedMultiScaleCNN(
    n_features=X_train_scaled.shape[1], 
    n_classes=3, 
    dropout_rate=0.6
)

# 計算參數數量
total_params = sum(p.numel() for p in model.parameters())
print(f"模型參數數量: {total_params}")
print(f"訓練樣本數: {len(X_train_scaled)}")
print(f"參數/樣本比例: {total_params / len(X_train_scaled):.2f}")

# 訓練模型
print("\n" + "="*60)
print("開始訓練...")
print("="*60)
trained_model, best_val_acc, best_val_balanced_acc = train_model_improved(
    model, train_loader, val_loader,
    class_weights=class_weights,
    n_epochs=300,
    lr=0.0005,
    device=device,
    use_focal_loss=True  # 使用 Focal Loss
)

print(f"\n最佳驗證集準確率: {best_val_acc:.4f}")
print(f"最佳驗證集 Balanced Accuracy: {best_val_balanced_acc:.4f}")

# ============================================
# 步驟 6: 驗證集評估
# ============================================
print("\n" + "="*60)
print("Validation Metrics for Improved Multi-Scale CNN")
print("="*60)

val_preds, val_probs, val_labels, val_acc = evaluate_model(trained_model, val_loader, device)
val_pred_labels = le.inverse_transform(val_preds)
val_true_labels = le.inverse_transform(val_labels)

print("\n混淆矩陣:")
print(confusion_matrix(val_labels, val_preds))

print("\n分類報告:")
print(classification_report(val_true_labels, val_pred_labels, target_names=list(le.classes_)))

# ROC AUC
val_binarized = label_binarize(val_labels, classes=[0, 1, 2])
val_probs_array = np.array(val_probs)

print("\n每個類別的 ROC AUC:")
for i, class_name in enumerate(le.classes_):
    fpr, tpr, _ = roc_curve(val_binarized[:, i], val_probs_array[:, i])
    roc_auc = auc(fpr, tpr)
    print(f"  {class_name} (Class {i}): {roc_auc:.4f}")

macro_roc_auc_val = roc_auc_score(val_labels, val_probs_array, multi_class='ovr', average='macro')
weighted_roc_auc_val = roc_auc_score(val_labels, val_probs_array, multi_class='ovr', average='weighted')
val_balanced_acc = balanced_accuracy_score(val_labels, val_preds)

print(f"\n宏平均 ROC AUC: {macro_roc_auc_val:.4f}")
print(f"加權平均 ROC AUC: {weighted_roc_auc_val:.4f}")
print(f"Balanced Accuracy: {val_balanced_acc:.4f}")

# ============================================
# 步驟 7: 測試集評估
# ============================================
print("\n" + "="*60)
print("Test Metrics for Improved Multi-Scale CNN")
print("="*60)

test_preds, test_probs, test_labels, test_acc = evaluate_model(trained_model, test_loader, device)
test_pred_labels = le.inverse_transform(test_preds)
test_true_labels = le.inverse_transform(test_labels)

print("\n混淆矩陣:")
print(confusion_matrix(test_labels, test_preds))

print("\n分類報告:")
print(classification_report(test_true_labels, test_pred_labels, target_names=list(le.classes_)))

# ROC AUC
test_binarized = label_binarize(test_labels, classes=[0, 1, 2])
test_probs_array = np.array(test_probs)

print("\n每個類別的 ROC AUC:")
for i, class_name in enumerate(le.classes_):
    fpr, tpr, _ = roc_curve(test_binarized[:, i], test_probs_array[:, i])
    roc_auc = auc(fpr, tpr)
    print(f"  {class_name} (Class {i}): {roc_auc:.4f}")

macro_roc_auc_test = roc_auc_score(test_labels, test_probs_array, multi_class='ovr', average='macro')
weighted_roc_auc_test = roc_auc_score(test_labels, test_probs_array, multi_class='ovr', average='weighted')
test_balanced_acc = balanced_accuracy_score(test_labels, test_preds)

print(f"\n宏平均 ROC AUC: {macro_roc_auc_test:.4f}")
print(f"加權平均 ROC AUC: {weighted_roc_auc_test:.4f}")
print(f"Balanced Accuracy: {test_balanced_acc:.4f}")

# ============================================
# 步驟 8: 過擬合分析
# ============================================
print("\n" + "="*60)
print("過擬合分析")
print("="*60)
overfitting_gap = val_balanced_acc - test_balanced_acc
print(f"驗證集 Balanced Accuracy: {val_balanced_acc:.4f}")
print(f"測試集 Balanced Accuracy: {test_balanced_acc:.4f}")
print(f"過擬合差距: {overfitting_gap:.4f}")

if overfitting_gap > 0.15:
    print("⚠️ 警告：存在明顯過擬合（差距 > 0.15）")
    print("   建議：增加 dropout 率或減少模型複雜度")
elif overfitting_gap > 0.10:
    print("⚠️ 注意：存在一定過擬合（差距 > 0.10）")
else:
    print("✓ 過擬合程度在可接受範圍內")
print("="*60)

# ============================================
# 步驟 9: 類別預測分布分析
# ============================================
print("\n" + "="*60)
print("類別預測分布分析")
print("="*60)

print("\n驗證集預測分布:")
val_pred_counts = {le.classes_[i]: np.sum(np.array(val_preds) == i) for i in range(len(le.classes_))}
val_true_counts = {le.classes_[i]: np.sum(np.array(val_labels) == i) for i in range(len(le.classes_))}
print(f"真實分布: {val_true_counts}")
print(f"預測分布: {val_pred_counts}")

print("\n測試集預測分布:")
test_pred_counts = {le.classes_[i]: np.sum(np.array(test_preds) == i) for i in range(len(le.classes_))}
test_true_counts = {le.classes_[i]: np.sum(np.array(test_labels) == i) for i in range(len(le.classes_))}
print(f"真實分布: {test_true_counts}")
print(f"預測分布: {test_pred_counts}")
print("="*60)

數據分割信息
總樣本數: 88
訓練集樣本數: 30 (34.1%)
驗證集樣本數: 31 (35.2%)
測試集樣本數: 27 (30.7%)

改進版多尺度特徵提取 CNN（類似 Receptive Field）

類別分布: {'A': np.int64(12), 'C': np.int64(10), 'F': np.int64(8)}
Balanced 權重: {'A': np.float64(0.8333333333333334), 'C': np.float64(1.0), 'F': np.float64(1.25)}
Inverse frequency 權重: {'A': np.float64(0.8108108108108107), 'C': np.float64(0.972972972972973), 'F': np.float64(1.2162162162162162)}

使用設備: cuda
模型參數數量: 261
訓練樣本數: 30
參數/樣本比例: 8.70

開始訓練...
使用 Focal Loss
Epoch 20/300, Train Loss: 0.6046, Val Loss: 0.4866, Val Acc: 0.4194, Val Balanced Acc: 0.4917
Epoch 40/300, Train Loss: 0.5279, Val Loss: 0.4617, Val Acc: 0.4194, Val Balanced Acc: 0.4917
Early stopping at epoch 42

最佳驗證集準確率: 0.4516
最佳驗證集 Balanced Accuracy: 0.5000

Validation Metrics for Improved Multi-Scale CNN

混淆矩陣:
[[0 7 6]
 [0 6 4]
 [0 1 7]]

分類報告:
              precision    recall  f1-score   support

           A       0.00      0.00      0.00        13
           C       0.43      0.60      0.50        10
        