In [1]:
import numpy as np
from IPython.display import clear_output
import gc
import os

# Memory-reset cell: clear this cell's previous output and run the garbage
# collector. Intended to be re-run between heavy experiments in the same kernel.
clear_output(wait=True)
gc.collect();  # trailing ';' suppresses the collected-object count (otherwise shown as `0`)

0

In [2]:
import os
import torch
import pandas as pd
from datetime import datetime
from itertools import product

from steps import analyze_seizure_propagation
from models import CustomTransformerSeizurePredictor, Wavenet, ResNet, load_model_with_config

# ================= Basic configuration =================
# Input/output locations (relative to the notebook's working directory).
DATA_FOLDER   = "data"
MODEL_FOLDER  = "checkpoints/BestModels"     # expects "<MODEL_NAME>_best.pth" files
RESULTS_DIR   = "result"                     # summary CSV is written here
MARKING_FILE  = os.path.join(DATA_FOLDER, "Seizure_Onset_Type_ML_USC.xlsx")
DEVICE        = "cuda:0" if torch.cuda.is_available() else "cpu"
os.makedirs(RESULTS_DIR, exist_ok=True)

# Models to evaluate (outer loop: each checkpoint is loaded once).
# Every name listed here must be a key of MODEL_CLASSES.
MODEL_NAMEs = ["ResNet"]
MODEL_CLASSES = {
    "Transformer": CustomTransformerSeizurePredictor,
    "Wavenet": Wavenet,
    "ResNet": ResNet,
}

# Patient number -> iterable of seizure numbers to evaluate.
patients = {
    65: [1],          # evaluates seizure 1 only (earlier note: seizures 1–8 exist — confirm)
    66: range(6, 8),  # evaluates seizures 6 and 7 (range end is exclusive); earlier note: seizures 1–7
}

# ========== Grid for visualization parameters ==========
# Every combination (Cartesian product) is tried per seizure; 4 * 3 * 1 * 1 * 1 * 3 = 36 runs.
PARAM_GRID = {
    "threshold":        [0.6, 0.7, 0.8, 0.9],
    "smooth_window":    [25, 50, 75],
    "n_seconds":        [100],
    "seizure_start":    [60],          # can be extended if needed
    "seizure_plot_time":[10],
    "overlap":          [0.6, 0.7, 0.8],
}
# Arguments merged into every parameter combination.
BASE_STATIC_ARGS = {"device": DEVICE}

# Collected result rows and a timestamped output path so runs never overwrite each other.
rows = []
run_tag = datetime.now().strftime("%Y%m%d_%H%M%S")
csv_path = os.path.join(RESULTS_DIR, f"propagation_eval_opt_{run_tag}.csv")

# ============== Outer loop: load each model once ==============
# For every model: load its checkpoint once, then for every (patient, seizure)
# run a silent grid search over PARAM_GRID and re-run the best combination with
# plots/results saved. One row per outcome is appended to `rows` (saved below).

# Set True to rebuild features on the first grid combination of each seizure;
# later combinations always reuse them. False keeps the previous behaviour
# (never rebuild; presumably relying on features cached by an earlier run).
RECALCULATE_FEATURES_FIRST_RUN = False

# The grid is identical for every seizure, so expand it once. Building it from an
# explicit key list means an empty PARAM_GRID still yields a single run.
PARAM_KEYS = list(PARAM_GRID)
PARAM_VALUE_LISTS = [PARAM_GRID[key] for key in PARAM_KEYS]


def _release_memory():
    """Run the garbage collector and, on CUDA, release cached GPU memory."""
    gc.collect()
    if DEVICE.startswith('cuda'):
        torch.cuda.empty_cache()


def _result_row(model_name, patient_no=None, seizure_no=None, accuracy=None,
                error=None, best_params=None, results_folder=None):
    """Build one summary row; all rows share the same keys in the same order."""
    return {
        "model": model_name,
        "patient": patient_no,
        "seizure": seizure_no,
        "best_accuracy_all_channels": accuracy,
        "error": error,
        "best_params": best_params,
        "results_folder": results_folder,
    }


def _grid_search(model, patient_no, seizure_no, recalculate_first=False):
    """Evaluate every PARAM_GRID combination for one seizure without saving plots.

    Args:
        model: loaded model passed through to analyze_seizure_propagation.
        patient_no, seizure_no: identify the seizure to analyse.
        recalculate_first: request feature recalculation until the first call
            that returns successfully; all later calls reuse features.

    Returns:
        (best_score, best_params, last_error, n_failed). best_params is None when
        no combination produced a non-None "accuracy_all_channels"; ties keep the
        earliest combination. n_failed counts combinations that raised.
    """
    best_score, best_params, last_error, n_failed = None, None, None, 0
    recalculate_features = recalculate_first
    for values in product(*PARAM_VALUE_LISTS):
        try_params = dict(zip(PARAM_KEYS, values))
        try_params.update(BASE_STATIC_ARGS)
        try:
            res = analyze_seizure_propagation(
                patient_no=patient_no,
                seizure_no=seizure_no,
                model=model,
                data_folder=DATA_FOLDER,
                marking_file=MARKING_FILE,
                params=try_params,
                save_results_ind=False,   # no plots in search
                recalculate_features=recalculate_features,
            )
            # The call succeeded, so any requested recalculation has happened.
            recalculate_features = False
            perf = res.get("performance", {}) or {}
            score = perf.get("accuracy_all_channels", None)
            if score is None:
                continue
            if best_score is None or score > best_score:
                best_score, best_params = score, try_params.copy()
        except Exception as e:
            # Best-effort search: record and move on to the next combination.
            last_error = str(e)
            n_failed += 1
        finally:
            _release_memory()
    return best_score, best_params, last_error, n_failed


for MODEL_NAME in MODEL_NAMEs:
    ckpt_path = os.path.join(MODEL_FOLDER, f"{MODEL_NAME}_best.pth")
    if not os.path.exists(ckpt_path):
        print(f"⚠️ Skip {MODEL_NAME}, checkpoint not found: {ckpt_path}")
        rows.append(_result_row(MODEL_NAME, error=f"checkpoint not found: {ckpt_path}"))
        continue

    try:
        print(f"\n===== Loading {MODEL_NAME} =====")
        # Named (not `_`) so IPython's `_` output-history variable is not clobbered.
        model, _model_config = load_model_with_config(ckpt_path, MODEL_CLASSES[MODEL_NAME])
        model.to(DEVICE).eval()
        print(f"✅ Loaded {MODEL_NAME} from {ckpt_path}")
    except Exception as e:
        print(f"❌ Error loading {MODEL_NAME}: {e}")
        rows.append(_result_row(MODEL_NAME, error=f"load_error: {e}"))
        continue

    # ===== Inner loop: iterate over patients and seizures =====
    for patient_no, seizure_range in patients.items():
        for seizure_no in seizure_range:
            # ---------- 1) Grid search (do not save plots, only evaluate) ----------
            best_score, best_params, last_error, n_failed = _grid_search(
                model, patient_no, seizure_no, RECALCULATE_FEATURES_FIRST_RUN
            )

            if best_params is None:
                print(f"⚠️ {MODEL_NAME} | P{patient_no} | S{seizure_no} no valid parameter combination.")
                rows.append(_result_row(
                    MODEL_NAME, patient_no, seizure_no,
                    error=last_error or "no_valid_param_combination",
                ))
                continue
            if n_failed:
                # Previously these failures were silently discarded once any combination succeeded.
                print(f"⚠️ {MODEL_NAME} | P{patient_no} | S{seizure_no} | "
                      f"{n_failed} grid combination(s) failed; last error: {last_error}")

            # ---------- 2) Run once with the best parameters (save plots) ----------
            try:
                print(f"▶️ {MODEL_NAME} | P{patient_no} | S{seizure_no} | BEST params: {best_params}")
                final_res = analyze_seizure_propagation(
                    patient_no=patient_no,
                    seizure_no=seizure_no,
                    model=model,
                    data_folder=DATA_FOLDER,
                    marking_file=MARKING_FILE,
                    params=best_params,
                    save_results_ind=True,     # save plots and results
                    recalculate_features=False,
                )
                final_perf = final_res.get("performance", {}) or {}
                results_folder = final_res.get("results_folder", "")

                rows.append(_result_row(
                    MODEL_NAME, patient_no, seizure_no,
                    accuracy=final_perf.get("accuracy_all_channels", best_score),
                    best_params=best_params,
                    results_folder=results_folder,
                ))
                print(f"✅ Done: {MODEL_NAME} | P{patient_no} | S{seizure_no} | acc_all={final_perf.get('accuracy_all_channels')}")

            except Exception as e:
                print(f"❌ Final run error: {MODEL_NAME} | P{patient_no} | S{seizure_no}: {e}")
                rows.append(_result_row(
                    MODEL_NAME, patient_no, seizure_no,
                    accuracy=best_score,
                    error=str(e),
                    best_params=best_params,
                ))
            finally:
                _release_memory()

# ============== Save CSV ==============
# Column order used by every row the evaluation loop appends. It is applied only
# when no rows were produced, so the CSV still gets a header instead of being blank.
RESULT_COLUMNS = [
    "model", "patient", "seizure", "best_accuracy_all_channels",
    "error", "best_params", "results_folder",
]
df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=RESULT_COLUMNS)
df.to_csv(csv_path, index=False, encoding="utf-8-sig")
print(f"\n📄 Results saved to: {csv_path}")
print(df.tail(10))  # tail() already returns fewer rows when len(df) < 10



===== Loading ResNet =====
✅ 成功加载ResNet模型
📋 模型配置: {'lr': 0.000144039109322085, 'weight_decay': 3.0789594164416115e-06, 'base_filters': 128, 'kernel_size': 16, 'dropout': 0.13138495398671235, 'n_blocks': 3, 'input_dim': 22, 'output_dim': 2}
🕐 保存时间: 2025-08-24T17:17:22.925879
✅ Loaded ResNet from checkpoints/BestModels\ResNet_best.pth
▶️ ResNet | P65 | S1 | BEST params: {'threshold': 0.6, 'smooth_window': 75, 'n_seconds': 100, 'seizure_start': 60, 'seizure_plot_time': 10, 'overlap': 0.6, 'device': 'cuda:0'}
Results saved to: result\P65\Seizure1\analysis_results
✅ Done: ResNet | P65 | S1 | acc_all=0.9375
▶️ ResNet | P66 | S6 | BEST params: {'threshold': 0.6, 'smooth_window': 25, 'n_seconds': 100, 'seizure_start': 60, 'seizure_plot_time': 10, 'overlap': 0.6, 'device': 'cuda:0'}
Results saved to: result\P66\Seizure6\analysis_results
✅ Done: ResNet | P66 | S6 | acc_all=0.9090909090909091
▶️ ResNet | P66 | S7 | BEST params: {'threshold': 0.6, 'smooth_window': 25, 'n_seconds': 100, 'seizure_s