<a href="https://colab.research.google.com/github/MeiChenc/Aurevia/blob/main/Aurevia_model_training.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

shell: # Install deps
# Core scientific stack; mne is installed here, so no separate mne install is needed.
pip install numpy scipy mne pandas scikit-learn matplotlib
pip install tensorflow torch torchaudio torchvision
pip install tensorflow-model_optimization  # for quantization
pip install wfdb  # CHB-MIT dataset access
pip install h5py mlflow

In [1]:
# Mount Google Drive at /content/drive; later cells read their input files
# from paths under /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [2]:
# Sanity check of one preprocessed patient (chb05): print the array shapes and
# the number of windows per label value.
import numpy as np, collections, os, pathlib
p = "/content/drive/MyDrive/processed_data_npy"
# Only the shape of X is needed here, so memory-map it instead of reading the
# whole array into RAM (the recorded shape is (273506, 512, 4)).
X = np.load(os.path.join(p, "X_chb05.npy"), mmap_mode="r")
# y has one entry per window; it is read fully because Counter iterates over it.
y = np.load(os.path.join(p, "y_chb05.npy"))
print(X.shape, y.shape)
print(collections.Counter(y))



(273506, 512, 4) (273506,)
Counter({np.int64(0): 258575, np.int64(1): 13800, np.int64(2): 1131})


# Model training algorithm

In [3]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Bidirectional, LSTM, Dense, Dropout, BatchNormalization
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, f1_score
from collections import Counter
import os

In [9]:
# Data path
# NOTE(review): the recorded output of this cell reports every patient as
# "lost files" under /content/. The other cells of this notebook read
# X_<pid>.npy / y_<pid>.npy from /content/drive/MyDrive/processed_data_npy.
# Confirm which location and format (.npz vs .npy) is intended before relying
# on X_combined / y_combined.
BASE_PATH = "/content/"
PATIENTS = ["chb05", "chb06", "chb07", "chb08"]


X_data_list = []
y_data_list = []

print(f"from {BASE_PATH} download files...")

for patient_id in PATIENTS:
    x_file_path = os.path.join(BASE_PATH, f"X_{patient_id}.npz")
    y_file_path = os.path.join(BASE_PATH, f"y_{patient_id}.npz")

    try:
        if not os.path.exists(x_file_path) or not os.path.exists(y_file_path):
            print(f"lost files {patient_id}")
            print(f"   lost: {x_file_path} or {y_file_path}")
            continue

        # NOTE(security): allow_pickle=True lets np.load unpickle arbitrary
        # objects; only use it on archives you created yourself. Plain numeric
        # arrays do not need it.
        # The archives are opened as context managers so the underlying file
        # handles are closed as soon as 'arr_0' has been read.
        with np.load(x_file_path, allow_pickle=True) as x_archive:
            X_loaded = x_archive['arr_0']
        with np.load(y_file_path, allow_pickle=True) as y_archive:
            y_loaded = y_archive['arr_0']

        # Append only after BOTH arrays loaded, so a failure on y can never
        # leave X_data_list and y_data_list with different patients in them.
        X_data_list.append(X_loaded)
        y_data_list.append(y_loaded)

        print(f"loaded patient {patient_id} file successfully。X shape: {X_loaded.shape}, y shape: {y_loaded.shape}")

    except Exception as e:
        print(f"error：load patient {patient_id} file, {e}")
        continue

# Merge the files to one

if X_data_list:
    X_combined = np.concatenate(X_data_list, axis=0)
    y_combined = np.concatenate(y_data_list, axis=0)

    print("\n--- 載入完成 ---")
    print(f"All patients' file  x shape : {X_combined.shape}")
    print(f"All patients' file y shape: {y_combined.shape}")
else:
    print("\n--- Fall to load ---")

# X_combined & y_combined exist only when at least one patient was loaded.


from /content/ download files...
lost files chb05
   lost: /content/X_chb05.npz or /content/y_chb05.npz
lost files chb06
   lost: /content/X_chb06.npz or /content/y_chb06.npz
lost files chb07
   lost: /content/X_chb07.npz or /content/y_chb07.npz
lost files chb08
   lost: /content/X_chb08.npz or /content/y_chb08.npz

--- Fall to load ---


In [14]:
%%writefile heavy_model_2.py
import os, json, numpy as np
from collections import Counter
import tensorflow as tf
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv1D, MaxPooling1D, Bidirectional, LSTM, Dense, Dropout, BatchNormalization
from tensorflow.keras.initializers import HeNormal
from sklearn.metrics import classification_report, f1_score

# ========= Configuration (aligned with your preprocessing) =========
# DATA_DIR holds one X_<pid>.npy / y_<pid>.npy pair per patient (see load_patient_npy).
# PATIENTS is the list main() rotates through as test / validation / training patients.
# OUTPUT_DIR receives the per-fold .keras, .tflite and inference_meta_*.json files;
# it is created at import time if it does not exist yet.
DATA_DIR   = "/content/drive/MyDrive/processed_data_npy"   # .npy only
PATIENTS   = ["chb05","chb06","chb07","chb08"]             # LOPOCV set
OUTPUT_DIR = "/content/drive/MyDrive/models"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Windowing parameters. STRIDE_SEC is the time step between consecutive windows
# and is what the metrics below use to convert window counts into seconds.
# SAMPLING_RATE_HZ is not referenced by the code in this script.
SAMPLING_RATE_HZ    = 256.0
WINDOW_DURATION_SEC = 2.0
OVERLAP_PERCENT     = 0.75
STRIDE_SEC          = WINDOW_DURATION_SEC * (1.0 - OVERLAP_PERCENT)  # 0.5s

# --- Global parameters for thresholding and smoothing ---
THRESH_SCAN_STEP = 0.05   # Step size for threshold search
FPH_CAP          = 0.10   # Max False Positives Per Hour constraint
MIN_RUN_LEN      = 2      # Minimum consecutive positive windows for smoothing/event

# ========= I/O (npy only) =========
def load_patient_npy(data_dir, pid):
    """Load the window array and label array of one patient.

    Args:
        data_dir: directory that contains ``X_<pid>.npy`` and ``y_<pid>.npy``.
        pid: patient identifier used to build the two file names.

    Returns:
        Tuple ``(X, y)`` of arrays passed through ``np.asarray``.

    Raises:
        FileNotFoundError: if either of the two files does not exist.
    """
    x_path = os.path.join(data_dir, f"X_{pid}.npy")
    y_path = os.path.join(data_dir, f"y_{pid}.npy")

    both_present = os.path.exists(x_path) and os.path.exists(y_path)
    if not both_present:
        raise FileNotFoundError(f"Missing {x_path} or {y_path}")

    # Files are opened memory-mapped (read-only) and then handed out through
    # np.asarray so callers (TF/Keras) receive ndarray objects.
    features = np.asarray(np.load(x_path, mmap_mode="r"))
    labels = np.asarray(np.load(y_path, mmap_mode="r"))
    return features, labels


# ========= Metrics: FP/h and Warning Time =========
def fp_per_hour_events(y_true, y_pred, stride_sec=STRIDE_SEC, refractory_min=10, min_run_len=2):
    """Event-level false positives per hour, measured on inter-ictal windows only.

    Only windows whose true label is 0 are kept; the predictions on those
    windows are scanned in order as one sequence. A false-positive event is a
    run of at least ``min_run_len`` consecutive positive predictions
    (``y_pred > 0``). After a counted event the scan skips a refractory period
    of ``refractory_min`` minutes, expressed in windows.

    NOTE(review): because non-inter-ictal windows are removed before the scan,
    positives on either side of a removed segment become adjacent and can form
    one run. Confirm this is the intended event definition.

    Args:
        y_true: per-window true labels (0 marks inter-ictal).
        y_pred: per-window predictions, same length as ``y_true``.
        stride_sec: seconds represented by one window step.
        refractory_min: minutes skipped after each counted event.
        min_run_len: minimum run length (in windows) that counts as an event.

    Returns:
        Events per hour as a float, or ``np.inf`` when there is no inter-ictal
        window at all.
    """
    labels = np.asarray(y_true)
    preds = np.asarray(y_pred)

    is_inter = (labels == 0)
    n_inter = int(np.count_nonzero(is_inter))
    if n_inter == 0:
        return np.inf

    # Positive / negative flag for every inter-ictal window, in original order.
    flagged = preds[is_inter] > 0
    total = len(flagged)

    # Refractory period converted from minutes to a number of windows.
    skip_windows = int((refractory_min * 60.0) / max(stride_sec, 1e-9))

    events = 0
    pos = 0
    while pos < total:
        if not flagged[pos]:
            pos += 1
            continue
        # Advance to the end of the current positive run.
        end = pos
        while end < total and flagged[end]:
            end += 1
        if end - pos >= min_run_len:
            events += 1
            pos = min(end + skip_windows, total)  # jump over the refractory period
        else:
            pos = end

    inter_hours = (n_inter * stride_sec) / 3600.0
    return events / max(inter_hours, 1e-9)

# ========= Smoothing: minimal consecutive positives =========
# ** change k window majority smoothing to minimal run smoothing
def min_consecutive_positive(y_pred_win: np.ndarray, min_len: int = 2) -> np.ndarray:
    """Zero out positive runs that are shorter than ``min_len`` windows.

    A run is a maximal stretch of consecutive entries that are ``> 0``
    (labels: 0=inter-ictal, 1=pre-ictal, 2=ictal). Runs of at least
    ``min_len`` entries are kept unchanged; shorter runs are set to 0.

    Args:
        y_pred_win: per-window predictions.
        min_len: minimum run length that is kept.

    Returns:
        A smoothed copy; the input is not modified.
    """
    smoothed = np.asarray(y_pred_win).copy()
    total = len(smoothed)

    run_start = None  # index where the currently open positive run began
    # Iterate one step past the end so a run touching the last element is closed too.
    for idx in range(total + 1):
        inside_run = idx < total and smoothed[idx] > 0
        if inside_run:
            if run_start is None:
                run_start = idx
        elif run_start is not None:
            if idx - run_start < min_len:
                smoothed[run_start:idx] = 0
            run_start = None
    return smoothed

def compute_warning_time_minutes(y_true, y_hat_bin, stride_sec=STRIDE_SEC):
    """Warning time between the model's latest alarm and each ictal onset.

    An onset is the first window of a stretch labelled 2 (index 0, or preceded
    by a label other than 2). For every onset the search starts AT the onset
    window and walks backwards until it meets a window where ``y_hat_bin`` is
    non-zero; the warning time is the distance to that window, converted to
    minutes through ``stride_sec``. An onset window that is itself flagged
    therefore yields 0 minutes.

    NOTE(review): the backward search has no limit, so a positive window
    arbitrarily far before the onset still counts as a detection. Confirm this
    is intended.

    Args:
        y_true: per-window true labels (2 marks ictal).
        y_hat_bin: per-window binary predictions, indexable like ``y_true``.
        stride_sec: seconds represented by one window step.

    Returns:
        ``(mean_minutes, median_minutes, n_onsets, n_detected)``. Both time
        values are 0.0 when there are no onsets or none was detected.
    """
    labels = np.asarray(y_true)
    onset_idx = [
        k for k in range(len(labels))
        if labels[k] == 2 and (k == 0 or labels[k - 1] != 2)
    ]
    if not onset_idx:
        return (0.0, 0.0, 0, 0)

    lead_minutes = []
    for onset in onset_idx:
        # Walk back from the onset to the nearest flagged window (inclusive).
        for k in range(onset, -1, -1):
            if y_hat_bin[k] != 0:
                lead_minutes.append((onset - k) * stride_sec / 60.0)
                break

    if not lead_minutes:
        return (0.0, 0.0, len(onset_idx), 0)

    leads = np.asarray(lead_minutes)
    return (float(leads.mean()), float(np.median(leads)), len(onset_idx), len(leads))

# ========= Threshold selection on validation: maximize sensitivity under the constraint FP/h <= cap, so that the chosen threshold does not end up with zero sensitivity =========
def pick_threshold_on_val(
    p_val: np.ndarray,
    y_val: np.ndarray,
    step: float = 0.01,
    fph_cap: float = 0.10,
    stride_sec: float = 0.5,
    min_sens: float = 0.60,
    objective: str = "sens",
    refractory_min: int = 10,
    min_run_len: int = 2,
):
    """Choose a decision threshold on validation data under an FP/hour cap.

    The positive score of a window is ``p_val[:, 1] + p_val[:, 2]``; a window
    is truly positive when ``y_val > 0``. Thresholds ``0.05, 0.05+step, ...``
    (below 0.96) are scanned and for each one the window-level sensitivity,
    the event-level FP/h (``fp_per_hour_events``) and the positive-vs-rest F1
    are recorded.

    Selection:
      1. Keep thresholds whose FP/h is ``<= fph_cap``.
      2. Among those, prefer the ones with sensitivity ``>= min_sens``; if
         none reaches it, use every threshold within the cap.
      3. From that pool take the maximum by (sens, F1) when
         ``objective == "sens"``, otherwise by (F1, sens).
      4. If no threshold is within the cap, take the best (F1, sens) overall.

    Args:
        p_val: class probabilities, indexed as ``[:, 1]`` and ``[:, 2]``.
        y_val: true labels per window.
        step: threshold scan step.
        fph_cap: maximum allowed false-positive events per hour.
        stride_sec: seconds per window step, forwarded to the FP/h metric.
        min_sens: sensitivity that candidates should reach when possible.
        objective: ``"sens"`` or any other value for F1-first ranking.
        refractory_min: refractory period in minutes for the FP/h metric.
        min_run_len: minimum positive run length for the FP/h metric.

    Returns:
        ``(threshold, meta)`` where ``meta`` records the chosen threshold, its
        sensitivity, FP/h, F1 and the event parameters.
    """
    pos_score = p_val[:, 1] + p_val[:, 2]
    pos_true = (y_val > 0).astype(int)

    def evaluate(th):
        """Return (threshold, sensitivity, FP/h, F1) for one threshold."""
        pos_pred = (pos_score >= th).astype(int)
        hits = np.sum((pos_true == 1) & (pos_pred == 1))
        sens = hits / max(np.sum(pos_true == 1), 1)
        # Event-level FP/h uses the original labels so that only inter-ictal
        # windows are considered.
        fph = fp_per_hour_events(
            y_true=y_val,
            y_pred=pos_pred,
            stride_sec=stride_sec,
            refractory_min=refractory_min,
            min_run_len=min_run_len,
        )
        f1 = f1_score(pos_true, pos_pred, zero_division=0)
        return (float(th), float(sens), float(fph), float(f1))

    records = [evaluate(th) for th in np.arange(0.05, 0.96, step)]

    def by_sens(rec):
        return (rec[1], rec[3])  # sensitivity first, F1 as tie-breaker

    def by_f1(rec):
        return (rec[3], rec[1])  # F1 first, sensitivity as tie-breaker

    within_cap = [rec for rec in records if rec[2] <= fph_cap]
    if within_cap:
        reaching_min = [rec for rec in within_cap if rec[1] >= min_sens]
        candidates = reaching_min or within_cap
        ranking = by_sens if objective == "sens" else by_f1
        best = max(candidates, key=ranking)
    else:
        best = max(records, key=by_f1)

    th, sens, fph, f1 = best
    meta = {
        "rule": "sens_or_f1_under_event_fph_cap",
        "th": th, "sens": sens, "fph": fph, "f1_pos": f1,
        "params": {"refractory_min": refractory_min, "min_run_len": min_run_len}
    }
    return th, meta


# ========= Model =========
def build_model(input_shape, export_safe: bool = False):
    """Build the (uncompiled) Conv1D + bidirectional-LSTM 3-class classifier.

    Layer stack: Conv1D(64, kernel 5) -> BatchNormalization -> MaxPooling1D(2)
    -> Dropout, two bidirectional LSTMs (128 units returning sequences, then
    64 units), Dense(128) and a 3-way softmax output.

    Args:
        input_shape: shape of one sample, passed to the first layer.
        export_safe: when True both LSTMs additionally get
            ``recurrent_dropout=0.1`` and ``implementation=1``. The export code
            in run_fold describes this variant as the non-cuDNN twin used for
            conversion — presumably these options steer Keras away from the
            fused cuDNN kernel; verify against the Keras LSTM documentation.

    Returns:
        A ``Sequential`` model.
    """
    lstm_kwargs = {
        "activation": 'tanh',
        "recurrent_activation": 'sigmoid',
        "use_bias": True,
        "unit_forget_bias": True,
    }
    if export_safe:
        lstm_kwargs["recurrent_dropout"] = 0.1
        lstm_kwargs["implementation"] = 1

    # Convolutional front end.
    conv_stage = [
        Conv1D(64, 5, activation='relu', kernel_initializer=HeNormal(), input_shape=input_shape),
        BatchNormalization(),
        MaxPooling1D(pool_size=2),
        Dropout(0.3),
    ]
    # Recurrent stage: the first LSTM keeps the time axis for the second one.
    recurrent_stage = [
        Bidirectional(LSTM(128, return_sequences=True, **lstm_kwargs)),
        Dropout(0.3),
        Bidirectional(LSTM(64, **lstm_kwargs)),
        Dropout(0.3),
    ]
    # Classification head.
    head = [
        Dense(128, activation='relu', kernel_initializer=HeNormal()),
        Dropout(0.5),
        Dense(3, activation='softmax'),
    ]
    return Sequential(conv_stage + recurrent_stage + head)


# ========= One LOPOCV fold =========
def run_fold(test_pid, val_pid, train_pids):
    """Train, tune, evaluate and export the model for one cross-validation fold.

    Steps:
      1. Load and concatenate the training patients; load the validation and
         test patient.
      2. Train the model with class weights derived from the raw training
         label counts, with early stopping and LR reduction on ``val_loss``.
      3. Pick the decision threshold on the validation patient.
      4. Apply the threshold to the test patient, map positives to class 1 or
         2, and remove positive runs shorter than MIN_RUN_LEN.
      5. Print the per-window report and the KPIs (sensitivity, FP/h, warning
         time).
      6. Save a Keras model, an FP16 TFLite model and a JSON metadata file to
         OUTPUT_DIR.

    Args:
        test_pid: patient id held out for testing.
        val_pid: patient id used for early stopping and threshold selection.
        train_pids: list of patient ids used for training.

    Returns:
        Dict with the fold composition and the test KPIs
        (``sens``, ``fph``, ``wt_avg``, ``wt_med``).

    Raises:
        FileNotFoundError: propagated from load_patient_npy when a patient
            file is missing.
    """
    print(f"\n===== FOLD | test={test_pid} val={val_pid} train={train_pids} =====")
    # Load data
    X_tr_list, y_tr_list = [], []
    for p in train_pids:
        Xp, yp = load_patient_npy(DATA_DIR, p)
        X_tr_list.append(Xp); y_tr_list.append(yp)
    X_tr = np.concatenate(X_tr_list, axis=0); y_tr = np.concatenate(y_tr_list, axis=0)

    X_va, y_va = load_patient_npy(DATA_DIR, val_pid)
    X_te, y_te = load_patient_npy(DATA_DIR, test_pid)

    # Class weights from the RAW train distribution:
    # weight = total / (n_classes * count), so rarer classes weigh more.
    class_counts  = Counter(y_tr)
    n_class       = len(class_counts)
    total_samples = sum(class_counts.values())
    class_weights = {cls: total_samples / (n_class * cnt) for cls, cnt in class_counts.items()}
    print(f"[IMB] train raw counts={class_counts}  class_weights={class_weights}")

    # Build & train
    input_shape = X_tr.shape[1:]
    model = build_model(input_shape)
    opt = tf.keras.optimizers.Adam(learning_rate=1e-3, clipnorm=1.0)
    model.compile(optimizer=opt, loss='sparse_categorical_crossentropy', metrics=['accuracy'])
    model.summary()

    cb_early = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)
    cb_rlr   = tf.keras.callbacks.ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=3,
                                                    min_lr=1e-6, verbose=1)

    model.fit(
        X_tr, y_tr,
        validation_data=(X_va, y_va),
        epochs=100,
        batch_size=64,
        class_weight=class_weights,   # keep class_weight
        callbacks=[cb_early, cb_rlr],
        verbose=1
    )


    # Threshold scan on validation, driven by the module-level constants
    p_val = model.predict(X_va, verbose=0)
    best_th, th_meta = pick_threshold_on_val(
        p_val, y_va,
        step=THRESH_SCAN_STEP,
        fph_cap=FPH_CAP,
        stride_sec=STRIDE_SEC,
        min_sens=0.60,
        objective="sens",
        refractory_min=10,   # aligned with the KPI computed on the test set below
        min_run_len=MIN_RUN_LEN
    )
    print(f"[THRESH] best_th={best_th:.2f} meta={th_meta}")


    # Test: threshold the summed positive probability, then label each positive
    # window as ictal (2) when its ictal probability is at least the pre-ictal
    # one, otherwise pre-ictal (1).
    p_te = model.predict(X_te, verbose=0)
    p_pos_te = p_te[:,1] + p_te[:,2]
    y_hat_bin = (p_pos_te >= best_th).astype(int)
    y_hat = np.zeros_like(y_te)
    pos_idx = np.where(y_hat_bin == 1)[0]
    if len(pos_idx):
        y_hat[pos_idx] = np.where(p_te[pos_idx,2] >= p_te[pos_idx,1], 2, 1)
    # Minimal-run smoothing: positive runs shorter than MIN_RUN_LEN are dropped.
    y_hat_sm = min_consecutive_positive(y_hat, min_len=MIN_RUN_LEN)


    # Evaluation
    print("\n--- Classification Report (Per Window, post K-smooth) ---")
    # labels is given explicitly: without it classification_report infers the
    # classes from the data and raises ValueError when fewer than three classes
    # appear (e.g. a test patient without ictal windows), because three
    # target_names are supplied.
    print(classification_report(y_te, y_hat_sm, labels=[0, 1, 2],
                                target_names=["Inter-ictal","Pre-ictal","Ictal"], zero_division=0))

    # Window-level sensitivity of positive (label > 0) vs. non-positive.
    sens = (np.sum((y_te>0) & (y_hat_sm>0)) / max(np.sum(y_te>0), 1))
    # Same event definition as used for threshold selection on validation.
    fph  = fp_per_hour_events(y_te, y_hat_sm, stride_sec=STRIDE_SEC,
                              refractory_min=10, min_run_len=MIN_RUN_LEN)

    wt_avg, wt_med, n_epi, n_det = compute_warning_time_minutes(y_te, (y_hat_sm>0).astype(int), stride_sec=STRIDE_SEC)
    print(f"[KPI] Sens={sens*100:.2f}%  FP/h={fph:.4f}  Warning(min) avg/med={wt_avg:.2f}/{wt_med:.2f}  "
          f"Episodes={n_epi} Detected={n_det}")

    # Export per-fold artifacts
    # --- export: build export-safe twin (non-cuDNN) and copy weights ---
    export_model = build_model(input_shape, export_safe=True)
    export_model.set_weights(model.get_weights())

    keras_path  = os.path.join(OUTPUT_DIR, f"heavy_model_{test_pid}.keras")
    tflite_path = os.path.join(OUTPUT_DIR, f"heavy_model_{test_pid}_fp16.tflite")

    # Save Keras (export-safe)
    export_model.save(keras_path)

    # Convert to FP16 TFLite
    converter = tf.lite.TFLiteConverter.from_keras_model(export_model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    tflite_fp16 = converter.convert()
    with open(tflite_path, "wb") as f:
        f.write(tflite_fp16)


    # Everything an inference pipeline needs to reproduce the post-processing.
    meta = {
        "fold": {"train": train_pids, "val": [val_pid], "test": [test_pid]},
        "window_sec": WINDOW_DURATION_SEC,
        "stride_sec": STRIDE_SEC,
        "threshold": float(best_th),
        "threshold_meta": th_meta,
        "min_run_len": int(MIN_RUN_LEN),
        "metrics_test": {
            "sensitivity": float(sens),
            "fp_per_hour": float(fph),
            "warning_time_avg_min": float(wt_avg),
            "warning_time_med_min": float(wt_med),
            "episodes": int(n_epi),
            "episodes_detected": int(n_det),
        }
    }
    meta_path = os.path.join(OUTPUT_DIR, f"inference_meta_{test_pid}.json")
    with open(meta_path, "w") as f:
        json.dump(meta, f, indent=2)

    print(f"[EXPORT] {keras_path}")
    print(f"[EXPORT] {tflite_path}")
    print(f"[EXPORT] {meta_path}")

    return {"test": test_pid, "val": val_pid, "train": train_pids,
            "sens": sens, "fph": fph, "wt_avg": wt_avg, "wt_med": wt_med}

def main():
    """Run every fold of the ring leave-one-patient-out scheme and print a summary.

    For patient ``i`` in PATIENTS the fold uses patient ``i`` for testing, the
    next patient (wrapping around) for validation and all remaining patients
    for training.

    Raises:
        ValueError: if PATIENTS has one or two entries, because then no patient
            is left for training.
    """
    # Optional: let TF grow GPU memory as needed. This is best effort, so a
    # failure is reported instead of being silently ignored, and training
    # continues with TensorFlow's default allocation.
    try:
        for gpu in tf.config.list_physical_devices('GPU'):
            tf.config.experimental.set_memory_growth(gpu, True)
    except Exception as exc:
        print(f"[WARN] Could not enable GPU memory growth: {exc}")

    n = len(PATIENTS)
    if 0 < n < 3:
        # Fail early with a clear message rather than inside np.concatenate.
        raise ValueError(
            f"Ring LOPOCV needs at least 3 patients (test, val, train); got {n}"
        )

    # Ring LOPOCV: test=i, val=i+1, others=train
    results = []
    for i, test_pid in enumerate(PATIENTS):
        val_pid = PATIENTS[(i+1) % n]
        train_pids = [p for p in PATIENTS if p not in (test_pid, val_pid)]
        res = run_fold(test_pid, val_pid, train_pids)
        results.append(res)

    # Summary across folds
    if results:
        sens = np.array([r["sens"] for r in results])
        fph  = np.array([r["fph"]  for r in results])
        wt_a = np.array([r["wt_avg"] for r in results])
        wt_m = np.array([r["wt_med"] for r in results])
        print("\n===== LOPOCV SUMMARY =====")
        print(f"Sensitivity mean/median: {sens.mean()*100:.2f}% / {np.median(sens)*100:.2f}%")
        print(f"FP/h       mean/median: {fph.mean():.4f} / {np.median(fph):.4f}")
        # Mean of the per-fold averages / median of the per-fold medians.
        print(f"WarnTime(min) avg/med : {wt_a.mean():.2f} / {np.median(wt_m):.2f}")

# Run the full cross-validation only when the file is executed as a script
# (e.g. `python heavy_model_2.py`), not when it is imported.
if __name__ == "__main__":
    main()

Overwriting heavy_model_2.py


In [15]:
# Execute the script written by the %%writefile cell above in a separate Python process.
!python heavy_model_2.py

2025-10-12 15:27:23.538224: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760282843.814828    2834 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760282843.900952    2834 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1760282844.526955    2834 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1760282844.526996    2834 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1760282844.527002    2834 computation_placer.cc:177] computation placer alr

In [12]:
import matplotlib.pyplot as plt
def plot_history(history):
    """Plot training vs. validation accuracy and loss side by side.

    Reads the "accuracy", "val_accuracy", "loss" and "val_loss" series from
    ``history.history``, saves the figure as ``training_curves.png`` in the
    current directory and shows it.

    NOTE(review): no cell in this notebook keeps the object returned by
    ``model.fit`` (training runs inside heavy_model_2.py), so the caller has to
    obtain ``history`` elsewhere.

    Args:
        history: object whose ``history`` attribute maps metric names to
            per-epoch value sequences (presumably a Keras History — confirm).
    """
    panels = (
        (1, "accuracy", "val_accuracy", "Accuracy"),
        (2, "loss", "val_loss", "Loss"),
    )
    plt.figure(figsize=(10,4))
    for position, train_key, val_key, panel_title in panels:
        plt.subplot(1, 2, position)
        plt.plot(history.history[train_key])
        plt.plot(history.history[val_key])
        plt.title(panel_title)
        plt.legend(["train","val"])
    plt.tight_layout()
    plt.savefig("training_curves.png")
    plt.show()
