In [None]:
from google.colab import drive
drive.mount('/content/drive')  # Mount Google Drive

In [None]:
pip install tsfel

### Median Filter

In [None]:
import os
import glob
import pandas as pd

# Signal columns that get median-filtered when present in a file.
PRIMARY_SIGNALS = [
    "Acc_X", "Acc_Y", "Acc_Z",
    "Gyr_X", "Gyr_Y", "Gyr_Z",
    "FreeAcc_E", "FreeAcc_N", "FreeAcc_U",]

def median_filter(
    cleaned_folder="/content/drive/My Drive/final_project/augmentation/cleaned",
    output_folder="/content/drive/My Drive/final_project/augmentation/medianfilter",
    kernel_size=3,
    return_dict=True
):
    """Apply a centred rolling-median filter to every CSV in ``cleaned_folder``.

    Each ``*.csv`` file is read, every column listed in ``PRIMARY_SIGNALS`` that
    exists in the file is coerced to numeric (unparseable values become NaN) and
    replaced by its rolling median, and the result is written under the same
    file name into ``output_folder``. All other columns are left untouched.

    Parameters
    ----------
    cleaned_folder : str
        Folder containing the input CSV files.
    output_folder : str
        Folder that receives the filtered CSV files (created if missing).
    kernel_size : int
        Width of the rolling window; must be a positive odd integer so the
        window is symmetric around each sample.
    return_dict : bool
        When True, also return the filtered frames keyed by file stem.

    Returns
    -------
    dict[str, pandas.DataFrame] or None
        Mapping ``file stem -> filtered DataFrame`` when ``return_dict`` is
        True, otherwise None.

    Raises
    ------
    AssertionError
        If ``kernel_size`` is not a positive odd integer.
    """
    # The original assert ended in a dangling comma (a SyntaxError); also reject
    # negative odd values, which satisfy `% 2 == 1` in Python.
    assert kernel_size >= 1 and kernel_size % 2 == 1, \
        "kernel_size must be a positive odd integer"
    os.makedirs(output_folder, exist_ok=True)

    filtered_data = {} if return_dict else None
    files = sorted(glob.glob(os.path.join(cleaned_folder, "*.csv")))

    for fp in files:
        df = pd.read_csv(fp)

        cols_to_filter = [c for c in PRIMARY_SIGNALS if c in df.columns]

        for col in cols_to_filter:
            s = pd.to_numeric(df[col], errors="coerce")
            # min_periods=1 keeps the edge samples: they use a shorter window
            # instead of becoming NaN.
            df[col] = (
                s.rolling(window=kernel_size, center=True, min_periods=1)
                 .median()
                 .astype(float)
            )

        filename = os.path.basename(fp)
        out_path = os.path.join(output_folder, filename)
        df.to_csv(out_path, index=False)
        print(f"[saved] {out_path}")

        if return_dict:
            # splitext strips only the trailing extension, whereas
            # str.replace would also remove ".csv" from the middle of a name.
            filtered_data[os.path.splitext(filename)[0]] = df

    return filtered_data

# Run the median filter; the returned value is kept in filtered_data_dict
filtered_data_dict = median_filter(
    output_folder="/content/drive/My Drive/final_project/augmentation/medianfilter"
)


### Sliding windows
5 s windows with 50% overlap

In [None]:
import os
import re
import glob
import numpy as np
import pandas as pd
from typing import List, Tuple

# Columns used as time-series channels
CANDIDATE_SIGNAL_COLS = [
    "Acc_X","Acc_Y","Acc_Z",
    "Gyr_X","Gyr_Y","Gyr_Z",
    "FreeAcc_E","FreeAcc_N","FreeAcc_U",
]

def _parse_meta_from_filename(fp: str):
    base = os.path.basename(fp).replace(".csv","")
    m = re.match(r"^T(\d+)_([A-Za-z0-9]+(?:_WT\d+)?)_(ankle|wrist)$", base, flags=re.IGNORECASE)
    if not m:
        return None
    subj = int(m.group(1))
    act = m.group(2)
    sens = m.group(3).lower()
    return subj, act, sens

def make_windows_from_folder(
    folder: str,
    signal_cols: List[str] = CANDIDATE_SIGNAL_COLS,
    subject_col: str = "Subject",
    activity_col: str = "Activity",
    sensor_col: str = "Sensor",
    time_col: str = "PacketTime_ms",
    fs_hz: int = 40,
    window_sec: int = 5,
    overlap: float = 0.5,
    drop_na_windows: bool = True,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, pd.DataFrame]:
    """Cut every CSV in ``folder`` into fixed-length, overlapping windows.

    Subject, activity and sensor are taken from the file name (via
    ``_parse_meta_from_filename``) and written into ``subject_col`` /
    ``activity_col`` / ``sensor_col``. Rows are sorted by ``time_col`` and the
    available ``signal_cols`` are sliced into windows of
    ``round(window_sec * fs_hz)`` samples.

    Parameters
    ----------
    folder : directory scanned for ``*.csv`` files.
    signal_cols : candidate channel columns; only those present in a file are used.
    subject_col, activity_col, sensor_col : names of the metadata columns that
        are (over)written from the file name.
    time_col : column used to order the samples; must exist in every file.
    fs_hz : sampling rate used to convert ``window_sec`` into samples.
    window_sec : window length in seconds.
    overlap : fraction of a window shared with the next one, in ``[0, 1)``.
    drop_na_windows : when True, windows containing NaN/inf are discarded.

    Returns
    -------
    X : ndarray of shape (n_windows, T, D)
    y : ndarray of activity labels, one per window
    groups : ndarray of subject ids, one per window
    meta : DataFrame with one row per window (file, Subject, Sensor, Activity,
        start_pos, end_pos, T, D)

    Raises
    ------
    AssertionError : if ``overlap`` is outside ``[0, 1)`` or the window is empty.
    ValueError : if a file lacks ``time_col``.
    """
    # overlap == 0 (back-to-back windows) is a valid setting, so it is accepted.
    assert 0 <= overlap < 1, "overlap must be in [0, 1)"
    T    = int(round(window_sec * fs_hz))
    assert T >= 1, "window_sec * fs_hz must yield at least one sample"
    step = max(1, int(round(T * (1.0 - overlap))))

    X_list, y_list, g_list, meta_rows = [], [], [], []
    csvs = sorted(glob.glob(os.path.join(folder, "*.csv")))
    D_ref = None  # channel count of the first accepted group; later groups must match

    for fp in csvs:
        fname = os.path.basename(fp)

        # Subject/Activity/Sensor come from the file name. Files that do not
        # follow the naming scheme are skipped instead of crashing on the
        # tuple unpack of None.
        parsed = _parse_meta_from_filename(fp)
        if parsed is None:
            print(f"[warn] {fname} does not match T<subject>_<activity>_<ankle|wrist>.csv; skipped.")
            continue
        subj_val, act_val, sens_val = parsed

        df = pd.read_csv(fp)
        df[subject_col] = subj_val
        df[activity_col] = act_val
        df[sensor_col]  = sens_val

        # Order samples by time
        if time_col not in df.columns:
            raise ValueError(f"{fname} 缺少时间列 {time_col}")
        df = df.sort_values(time_col).reset_index(drop=True)

        cols = [c for c in signal_cols if c in df.columns]
        if not cols:
            print(f"[warn] {fname} 无可用信号列，跳过。")
            continue

        keep = cols + [subject_col, activity_col, sensor_col]
        df = df[keep].copy()

        for c in cols:
            df[c] = pd.to_numeric(df[c], errors="coerce")

        # Window each (Subject, Sensor, Activity) group separately
        for (_subj, _sens, _act), g in df.groupby([subject_col, sensor_col, activity_col], sort=False):
            arr = g[cols].to_numpy(dtype=float)
            L, D_curr = arr.shape
            if L < T:
                continue  # recording shorter than one window

            if D_ref is None:
                D_ref = D_curr
            elif D_curr != D_ref:
                print(f"{fname} 分组({_subj},{_sens},{_act}) 的通道数 D={D_curr} "
                      f"与首个窗口 D={D_ref} 不一致，跳过该分组。")
                continue

            # Slide the window; the last start is the final one that still fits.
            for start in range(0, L - T + 1, step):
                win = arr[start:start+T]
                if (not drop_na_windows) or np.isfinite(win).all():
                    X_list.append(win)
                    y_list.append(_act)
                    g_list.append(_subj)  # subject id, used as the LOSO group
                    meta_rows.append({
                        "file": fname,
                        "Subject": _subj,
                        "Sensor": _sens,
                        "Activity": _act,
                        "start_pos": int(start),
                        "end_pos": int(start+T-1),
                        "T": T,
                        "D": D_ref
                    })

    if X_list:
        X = np.stack(X_list, axis=0)
    else:
        # Shape the empty result after the channels actually requested/seen,
        # not after the module-level default list.
        n_channels = D_ref if D_ref is not None else len(signal_cols)
        X = np.empty((0, T, n_channels))
    y = np.array(y_list)
    groups = np.array(g_list)
    meta = pd.DataFrame(meta_rows)
    return X, y, groups, meta


In [None]:
# Build 5 s windows with 50 % overlap (at 40 Hz) from the median-filtered CSVs
median_folder = "/content/drive/My Drive/final_project/augmentation/medianfilter"
X, y, groups, meta = make_windows_from_folder(
    folder=median_folder,
    signal_cols=CANDIDATE_SIGNAL_COLS,
    subject_col="Subject",
    activity_col="Activity",
    sensor_col="Sensor",
    fs_hz=40,
    window_sec=5,
    overlap=0.5,
)
# Quick sanity check of the resulting array shapes and per-window metadata
print("X shape:", X.shape) 
print("y shape:", y.shape)
print("groups shape:", groups.shape)
print(meta.head())

##### 保存数据，后续能加载

In [None]:
import os
import numpy as np

# Write to the same absolute Drive folder that the modelling cell later loads
# from; the previous relative "../augmentation/windows" path depended on the
# kernel's working directory and the folder was never created.
WINDOWS_DIR = "/content/drive/My Drive/final_project/augmentation/windows"
os.makedirs(WINDOWS_DIR, exist_ok=True)
windows_npz_path = os.path.join(WINDOWS_DIR, "windows_5s_50pct_D9.npz")

np.savez_compressed(
    windows_npz_path,
    X=X, y=y, groups=groups
)
meta.to_csv(os.path.join(WINDOWS_DIR, "windows_meta.csv"), index=False)

# Reload from disk to confirm the archive round-trips; the context manager
# closes the underlying file once the arrays have been read.
with np.load(windows_npz_path) as data:
    X, y, groups = data["X"], data["y"], data["groups"]


--------

#### HistGradientBoosting + PCA/SelectKBest

In [None]:
# Imports
import numpy as np
import pandas as pd
from scipy.signal import welch
from scipy import interpolate

from sklearn.base import BaseEstimator, TransformerMixin, clone
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.feature_selection import SelectKBest, f_classif, mutual_info_classif
from sklearn.ensemble import HistGradientBoostingClassifier, RandomForestClassifier
from lightgbm import LGBMClassifier
from sklearn.metrics import (accuracy_score, f1_score, balanced_accuracy_score,
                             classification_report, confusion_matrix, ConfusionMatrixDisplay)
from sklearn.model_selection import GridSearchCV, LeaveOneGroupOut
from imblearn.pipeline import Pipeline as ImbPipeline
from imblearn.base import SamplerMixin
import matplotlib.pyplot as plt


# TSFEL configuration: keep only the "temporal" and "spectral" feature domains
import tsfel
_cfg = tsfel.get_features_by_domain()
TSFEL_CFG = {"temporal": _cfg["temporal"], "spectral": _cfg["spectral"]}

# Seed passed as random_state to the PCA and classifier steps further down
SEED = 42
# NOTE(review): rng_global is not referenced anywhere else in the visible notebook
rng_global = np.random.default_rng(SEED)

# tsfel+4
def _spectral_centroid_and_edge(x, fs_hz, edge_percent=0.9):
    f, Pxx = welch(x, fs=fs_hz, nperseg=min(len(x), 256))
    Pxx = np.clip(Pxx, 1e-12, None)
    sc = float(np.sum(f * Pxx) / np.sum(Pxx))
    cumsum = np.cumsum(Pxx) / np.sum(Pxx)
    idx = int(np.searchsorted(cumsum, edge_percent))
    idx = min(idx, len(f)-1)
    sef = float(f[idx])
    return sc, sef

def _jerk_stats(x, fs_hz):
    j = np.diff(x) * fs_hz
    if j.size == 0:
        return 0.0, 0.0
    jm = float(np.mean(j))
    js = float(np.std(j, ddof=1)) if j.size > 1 else 0.0
    return jm, js

def _coef_var(x):
    m = float(np.mean(x))
    if abs(m) < 1e-8:
        return 0.0
    return float(np.std(x, ddof=1) / abs(m))

def featurize_windows_tsfel_plus(X, colnames, fs_hz):
    """Turn each window of ``X`` into one feature row (TSFEL features plus extras).

    For every window the TSFEL extractor is run with ``TSFEL_CFG`` over all
    channels, and five hand-written features are appended per channel:
    spectral centroid, 90 % spectral edge, jerk mean, jerk std and coefficient
    of variation. Jerk features are computed only for channels whose name
    starts with "Acc_" or "FreeAcc_"; other channels get 0.0 for both.

    Returns a DataFrame with one row per window.
    """
    feature_rows = []
    for window in X:
        window_df = pd.DataFrame(window, columns=colnames)

        # Library features (whole window treated as a single segment)
        tsfel_df = tsfel.time_series_features_extractor(
            TSFEL_CFG, window_df, fs=fs_hz, window_size=None, verbose=0
        )
        tsfel_df = tsfel_df.reset_index(drop=True)

        # Extra per-channel features
        extras = {}
        for ch, name in enumerate(colnames):
            signal = window[:, ch]
            centroid, edge90 = _spectral_centroid_and_edge(signal, fs_hz, edge_percent=0.90)
            if name.startswith(("Acc_", "FreeAcc_")):
                jerk_mean, jerk_std = _jerk_stats(signal, fs_hz)
            else:
                jerk_mean, jerk_std = 0.0, 0.0
            extras.update({
                f"{name}_spec_centroid": centroid,
                f"{name}_spec_edge90": edge90,
                f"{name}_jerk_mean": jerk_mean,
                f"{name}_jerk_std": jerk_std,
                f"{name}_coef_var": _coef_var(signal),
            })

        feature_rows.append(pd.concat([tsfel_df.iloc[0], pd.Series(extras)], axis=0))

    return pd.DataFrame(feature_rows)

# augmentation
def _find_triplets(colnames):
    idx = {c:i for i,c in enumerate(colnames)}
    out = []
    for a,b,c in [("Acc_X","Acc_Y","Acc_Z"),
                  ("Gyr_X","Gyr_Y","Gyr_Z"),
                  ("FreeAcc_E","FreeAcc_N","FreeAcc_U")]:
        if a in idx and b in idx and c in idx:
            out.append((idx[a], idx[b], idx[c]))
    return out

def _random_rotation_matrix(rng, max_deg=20):
    theta = np.deg2rad(rng.uniform(-max_deg, max_deg))
    axis = rng.normal(size=3); axis /= (np.linalg.norm(axis)+1e-8)
    x,y,z = axis; c,s = np.cos(theta), np.sin(theta)
    return np.array([[c+x*x*(1-c),   x*y*(1-c)-z*s, x*z*(1-c)+y*s],
                     [y*x*(1-c)+z*s, c+y*y*(1-c),   y*z*(1-c)-x*s],
                     [z*x*(1-c)-y*s, z*y*(1-c)+x*s, c+z*z*(1-c)]], dtype=float)

def aug_rotate(win, triplets, rng, max_deg=20):
    """Rotate every 3-axis column group of ``win`` by one shared random rotation.

    ``triplets`` lists the column-index triples to rotate. When it is empty the
    input array itself is returned (no copy); otherwise a rotated copy is returned.
    """
    if not triplets:
        return win
    rotation = _random_rotation_matrix(rng, max_deg=max_deg)
    rotated = win.copy()
    for triple in triplets:
        cols = list(triple)
        rotated[:, cols] = rotated[:, cols] @ rotation.T
    return rotated

def aug_permute(win, rng, min_seg=3, max_seg=5):
    """Split ``win`` along time into K near-equal segments and shuffle their order.

    K is drawn uniformly from ``min_seg..max_seg`` (inclusive); samples inside
    each segment keep their original order.
    """
    n_samples = win.shape[0]
    n_segments = rng.integers(min_seg, max_seg + 1)
    bounds = np.linspace(0, n_samples, n_segments + 1).astype(int)
    order = rng.permutation(n_segments)
    return np.concatenate([win[bounds[k]:bounds[k + 1]] for k in order], axis=0)

def aug_timewarp(win, rng, sigma=0.2, knots=4):
    """Resample ``win`` along a randomly distorted time axis.

    A warp curve is built from ``knots`` Gaussian perturbations (cumulated,
    clipped to [0, 1] and sorted) anchored at 0 and 1, interpolated with a cubic
    spline and rescaled to span [0, 1). Each channel is then spline-interpolated
    at the warped time points. Returns an array with the shape of ``win``.
    """
    n_samples, n_channels = win.shape
    uniform_t = np.linspace(0, 1, n_samples)

    # Random warp curve through knots+2 anchor points
    anchor_x = np.linspace(0, 1, knots + 2)
    drift = np.cumsum(rng.normal(0, sigma, size=knots)) / max(1, knots)
    anchor_y = np.sort(np.clip(np.r_[0, drift, 1.0], 0, 1))
    warp_curve = interpolate.UnivariateSpline(anchor_x, anchor_y, k=3, s=0)
    warped_t = warp_curve(uniform_t)
    warped_t = (warped_t - warped_t.min()) / (warped_t.max() - warped_t.min() + 1e-8)

    # Evaluate every channel at the warped time stamps
    warped = np.zeros_like(win)
    for ch in range(n_channels):
        channel_spline = interpolate.UnivariateSpline(uniform_t, win[:, ch], k=3, s=0)
        warped[:, ch] = channel_spline(warped_t)
    return warped

def aug_jitter(win, rng, percent=0.02):
    """Add Gaussian noise whose per-channel scale is ``percent`` of that channel's std (plus 1e-8)."""
    channel_sigma = np.std(win, axis=0, ddof=1) * percent + 1e-8
    return win + rng.normal(0, 1, size=win.shape) * channel_sigma

def aug_scaling(win, rng, low=0.9, high=1.1):
    """Multiply each channel by its own factor drawn uniformly from ``[low, high)``."""
    factors = rng.uniform(low, high, size=(1, win.shape[1]))
    return factors * win

def aug_magwarp(win, rng, sigma=0.2, knots=4):
    """Scale each channel by a smooth random gain curve around 1.0.

    Per channel, ``knots`` Gaussian perturbations are cumulated, pinned to 0 at
    both ends, interpolated with a cubic spline over the window and added to 1.0;
    the window is multiplied element-wise by the resulting gain.
    """
    n_samples, n_channels = win.shape
    grid = np.linspace(0, 1, n_samples)
    anchor_x = np.linspace(0, 1, knots + 2)
    gain = np.zeros((n_samples, n_channels))
    for ch in range(n_channels):
        drift = np.cumsum(rng.normal(0, sigma, size=knots)) / max(1, knots)
        anchor_y = np.r_[0, drift, 0]
        curve = interpolate.UnivariateSpline(anchor_x, anchor_y, k=3, s=0)
        gain[:, ch] = 1.0 + curve(grid)
    return win * gain

from sklearn.base import BaseEstimator

class WindowAugmenter(BaseEstimator):
    """Sampler-style step that oversamples classes with augmented copies of their windows.

    The object exposes ``fit_resample`` so a sampler-aware pipeline applies it
    during training only. Each enabled ``do_*`` flag switches on one
    augmentation; enabled augmentations are applied in a fixed order
    (rotate, time-warp, permute, jitter, scaling, magnitude-warp).

    Parameters
    ----------
    colnames : channel names, used to locate the 3-axis groups for rotation.
    fs_hz : sampling rate (stored; not used by the augmentations themselves).
    do_rot, rot_max_deg : random rotation and its maximum angle in degrees.
    do_timewarp, tw_sigma, tw_knots : time warping and its spline settings.
    do_permute, perm_min, perm_max : segment permutation and segment-count range.
    do_jitter, jitter_pct : additive noise and its scale relative to channel std.
    do_scaling, scale_low, scale_high : per-channel scaling range.
    do_magwarp, mw_sigma, mw_knots : magnitude warping and its spline settings.
    per_class_strategy : "q80" (target = 80th percentile of class counts),
        "max" (target = largest class), or "none"/anything else (no resampling).
    random_state : seed of the generator created in ``fit``.
    """
    _parameter_constraints: dict = {}

    def __init__(self, colnames, fs_hz=40,
                 do_rot=True, rot_max_deg=20,
                 do_timewarp=False, tw_sigma=0.2, tw_knots=4,
                 do_permute=False, perm_min=3, perm_max=5,
                 do_jitter=False, jitter_pct=0.02,
                 do_scaling=False, scale_low=0.9, scale_high=1.1,
                 do_magwarp=False, mw_sigma=0.2, mw_knots=4,
                 per_class_strategy="q80",
                 random_state=42):
        self.colnames = colnames
        self.fs_hz = fs_hz
        self.do_rot = do_rot; self.rot_max_deg = rot_max_deg
        self.do_timewarp = do_timewarp; self.tw_sigma = tw_sigma; self.tw_knots = tw_knots
        self.do_permute = do_permute; self.perm_min = perm_min; self.perm_max = perm_max
        self.do_jitter = do_jitter; self.jitter_pct = jitter_pct
        self.do_scaling = do_scaling; self.scale_low = scale_low; self.scale_high = scale_high
        self.do_magwarp = do_magwarp; self.mw_sigma = mw_sigma; self.mw_knots = mw_knots
        self.per_class_strategy = per_class_strategy
        self.random_state = random_state

    def fit(self, X, y):
        """Seed the generator and derive the per-class multiplier from ``y``.

        Sets ``triplets_``, ``rng_`` and ``ratio_``. ``ratio_`` maps each class
        to ``ceil(target / class_count)`` (at least 1), or is None when no
        resampling is requested.
        """
        self.triplets_ = _find_triplets(self.colnames)
        self.rng_ = np.random.default_rng(self.random_state)
        if (y is None) or (self.per_class_strategy == "none"):
            self.ratio_ = None
            return self

        cnt = pd.Series(y).value_counts()
        if self.per_class_strategy == "q80":
            target = int(min(cnt.max(), cnt.quantile(0.8)))
        elif self.per_class_strategy == "max":
            target = int(cnt.max())
        else:
            target = None

        if target is None:
            self.ratio_ = None
        else:
            # Divide by the class COUNT (the Series value), not by the class
            # label (the Series index) as the previous code did.
            self.ratio_ = {
                cls: max(1, int(np.ceil(target / n_cls)))
                for cls, n_cls in cnt.items()
            }
        return self

    def _augment_one(self, w):
        """Return an augmented copy of a single window using the enabled transforms."""
        rng = self.rng_
        out = w.copy()
        if self.do_rot:
            out = aug_rotate(out, self.triplets_, rng, max_deg=self.rot_max_deg)
        if self.do_timewarp:
            out = aug_timewarp(out, rng, sigma=self.tw_sigma, knots=self.tw_knots)
        if self.do_permute:
            out = aug_permute(out, rng, min_seg=self.perm_min, max_seg=self.perm_max)
        if self.do_jitter:
            out = aug_jitter(out, rng, percent=self.jitter_pct)
        if self.do_scaling:
            out = aug_scaling(out, rng, low=self.scale_low, high=self.scale_high)
        if self.do_magwarp:
            out = aug_magwarp(out, rng, sigma=self.mw_sigma, knots=self.mw_knots)
        return out

    def fit_resample(self, X, y):
        """Fit on ``(X, y)`` and return the data with augmented windows appended.

        For each class with multiplier ``m > 1``, ``(m - 1) * class_count``
        windows are drawn with replacement, augmented, and appended after the
        original samples. Returns ``(X, y)`` unchanged when no resampling applies.
        """
        # A sampler-aware pipeline calls fit_resample directly, without a prior
        # fit(); fitting here is what makes the augmentation actually happen.
        self.fit(X, y)
        if self.ratio_ is None:
            return X, y

        X_aug, y_aug = [X], [y]
        y_arr = np.asarray(y)
        for cls, mult in self.ratio_.items():
            if mult <= 1:
                continue
            idx = np.where(y_arr == cls)[0]
            need = (mult - 1) * len(idx)
            if need <= 0 or len(idx) == 0:
                continue
            pick = self.rng_.choice(idx, size=need, replace=True)
            aug_list = [self._augment_one(X[i]) for i in pick]
            X_aug.append(np.stack(aug_list))
            y_aug.append(np.full(len(aug_list), cls, dtype=y_arr.dtype))
        return np.concatenate(X_aug, axis=0), np.concatenate(y_aug, axis=0)

class WindowFeaturizer(BaseEstimator, TransformerMixin):
    """Stateless transformer that converts raw windows into a numeric feature matrix."""

    def __init__(self, colnames, fs_hz=40):
        self.colnames = colnames
        self.fs_hz = fs_hz

    def fit(self, X, y=None):
        """Nothing to learn; returns self."""
        return self

    def transform(self, X):
        """Featurize ``X`` and replace NaN / +inf / -inf entries with 0.0."""
        feature_table = featurize_windows_tsfel_plus(X, self.colnames, self.fs_hz)
        # Plain array handed on to the following pipeline steps (scaler, selector)
        return np.nan_to_num(feature_table.values, nan=0.0, posinf=0.0, neginf=0.0)

# Load the saved windows; labels are cast to str before the "M10" merge below
data = np.load("/content/drive/My Drive/final_project/augmentation/windows/windows_5s_50pct_D9.npz")
X, y, groups = data["X"], data["y"].astype(str), data["groups"]

def collapse_walk_labels(y_arr):
    """Return a string copy of ``y_arr`` in which every label starting with "M10" becomes "M10"."""
    labels = y_arr.copy().astype(str)
    labels[np.char.startswith(labels, "M10")] = "M10"
    return labels

# Merge all labels that start with "M10" into a single "M10" class
y = collapse_walk_labels(y)

signal_cols = ["Acc_X","Acc_Y","Acc_Z","Gyr_X","Gyr_Y","Gyr_Z","FreeAcc_E","FreeAcc_N","FreeAcc_U"]
fs_hz = 40

# Subject-level split: each subject id is assigned to exactly one of the three lists.
# Windows whose subject id appears in none of the lists are not selected by any mask.
train_subjects = np.array([29, 28, 24, 21, 3, 26, 14, 25, 12, 18, 5, 11])
val_subjects   = np.array([17, 13, 22, 23, 20])
test_subjects  = np.array([4, 2, 1, 19])

tr_mask = np.isin(groups, train_subjects)
va_mask = np.isin(groups, val_subjects)
te_mask = np.isin(groups, test_subjects)

# g_tr keeps the training subject ids; they drive the leave-one-subject-out CV below
X_tr, y_tr, g_tr = X[tr_mask], y[tr_mask], groups[tr_mask]
X_va, y_va = X[va_mask], y[va_mask]
X_te, y_te = X[te_mask], y[te_mask]

print("Train/Val/Test windows:", X_tr.shape, X_va.shape, X_te.shape)

# Pipeline
# Leave-one-group-out splitter; the groups supplied later are the training subject ids
logo = LeaveOneGroupOut()

# Augmentation step: rotation and time warping enabled, the other transforms disabled
augmenter = WindowAugmenter(
    colnames=signal_cols, fs_hz=fs_hz,
    do_rot=True,  rot_max_deg=20,
    do_timewarp=True, tw_sigma=0.2, tw_knots=4,
    do_permute=False,
    do_jitter=False,
    do_scaling=False,
    do_magwarp=False,
    per_class_strategy="q80",
    random_state=42
)

# Step order: augment -> featurize -> standardize -> PCA -> gradient-boosting classifier
pipe = ImbPipeline([
    ("augment", augmenter), 
    ("featurize", WindowFeaturizer(signal_cols, fs_hz)),
    ("scale", StandardScaler()),
    ("selector", PCA(random_state=SEED)),
    # ("selector", KBestFlex(score_func=f_classif, k=0.6)),  
    ("clf", HistGradientBoostingClassifier(random_state=SEED))
])

# Every entry holds a single candidate except clf__max_depth,
# so the grid contains two parameter combinations.
param_grid = {
    "augment__rot_max_deg": [20], 
    "augment__do_timewarp": [True], 
    "augment__tw_sigma": [0.2],
    "augment__tw_knots": [4],  
    "augment__per_class_strategy": ["q80"],
    "selector__n_components": [0.8],  # 0.9, 0.95, 0.7
    "clf__learning_rate": [0.01],       
    "clf__max_depth": [8, 6], 
    "clf__max_iter": [800], 
    "clf__min_samples_leaf": [25],   
    "clf__l2_regularization": [10.0],   
    "clf__early_stopping": [True],  
}

# --------------------------- RF ----------------------------
# pipe_rf_kbest = ImbPipeline([
#     ("augment", augmenter),
#     ("featurize", WindowFeaturizer(signal_cols, fs_hz)),
#     ("scale", StandardScaler()),
#     ("selector", KBestFlex(score_func=f_classif, k=0.6)), 
#     ("clf", RandomForestClassifier(
#         class_weight="balanced_subsample",  
#         n_jobs=-1, random_state=SEED
#     ))
# ])
# param_grid_rf_kbest = {
#     "augment__rot_max_deg": [20],
#     "augment__do_timewarp": [True],
#     "augment__tw_sigma": [0.2],
#     "augment__tw_knots": [4],
#     "augment__per_class_strategy": ["q80"],
#     "selector__k": [0.6],
#     "clf__n_estimators": [600],        # [300, 500]
#     "clf__max_depth": [10],     
#     "clf__min_samples_split": [4],   # [2, 4]
#     "clf__min_samples_leaf": [8],    # [2, 8]
#     "clf__max_features": ["sqrt"],     # ["sqrt", "log2"]
# }
# --------------------------- RF ----------------------------

# --------------------------- LightGBM ----------------------------
# from lightgbm import LGBMClassifier
# pipe_lgbm_pca = ImbPipeline([
#     ("augment", augmenter),
#     ("featurize", WindowFeaturizer(signal_cols, fs_hz)),
#     ("scale", StandardScaler()),
#     ("selector", PCA(random_state=SEED)),
#     ("clf", LGBMClassifier(
#         objective="multiclass", num_class=len(np.unique(y_tr)),
#         random_state=SEED, n_jobs=-1
#     ))
# ])

# param_grid_lgbm_pca = {
#     "augment__per_class_strategy": ["q80"],
#     "selector__n_components": [0.8],
#     "clf__learning_rate": [0.05],          # 0.01,
#     "clf__n_estimators": [800],         # 600,
#     "clf__max_depth": [8],          # -1 表示不限制, -1, 6,
#     "clf__num_leaves": [63],           # 31,
#     "clf__min_child_samples": [100],       # 50,
#     "clf__subsample": [1.0],          # bagging  # 0.7,
#     "clf__colsample_bytree": [1.0],       # 0.7,
#     "clf__reg_lambda": [10.0],             # 0.0,
# }
# --------------------------- LightGBM ----------------------------


# Grid search scored by accuracy; cv receives the leave-one-subject-out splits
# generated from the training subject ids (g_tr). refit=True retrains the best
# configuration on all of X_tr afterwards.
gs = GridSearchCV(
    estimator=pipe,
    param_grid=param_grid,
    cv=logo.split(X_tr, y_tr, g_tr), 
    scoring="accuracy",
    n_jobs=2,
    verbose=2,
    refit=True
)
gs.fit(X_tr, y_tr)

print("Best params:", gs.best_params_)
print("Best LOSO acc on Dev(Train):", gs.best_score_)

best_model = gs.best_estimator_

# Persist the refitted pipeline to Drive
from joblib import dump
dump(best_model, '/content/drive/My Drive/final_project/augmentation/modelSave/best_model_hist.joblib')



##### Validation Set

In [None]:

# Predict the validation subjects with the model refitted on the training subjects
yhat_va = best_model.predict(X_va)

# Class labels as stored on the fitted classifier step; used to order the report rows
labels_va = best_model.named_steps["clf"].classes_  
print("VAL Acc: %.3f | Macro-F1: %.3f | BalAcc: %.3f" % (
    accuracy_score(y_va, yhat_va),
    f1_score(y_va, yhat_va, average="macro"),
    balanced_accuracy_score(y_va, yhat_va)
))

print("\nValidation report:\n", classification_report(y_va, yhat_va, labels=labels_va, target_names=labels_va))

# cm_va = confusion_matrix(y_va, yhat_va, labels=labels_va)
# ConfusionMatrixDisplay(cm_va, display_labels=labels_va).plot(xticks_rotation=45, cmap="Blues", values_format="d")
# plt.title("Confusion Matrix - Validation"); plt.tight_layout(); plt.show()

########################
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

def plot_confusions(y_true, y_pred, labels, title_prefix="Validation"):
    """Show two confusion matrices: raw counts, then rows normalised by true class.

    ``labels`` fixes the row/column order; ``title_prefix`` is inserted into
    both figure titles.
    """
    # (normalize argument, cell number format, title tag) for each figure
    variants = (
        (None, "d", "Counts"),
        ("true", ".2f", "Normalized"),
    )
    for normalize, cell_format, tag in variants:
        matrix = confusion_matrix(y_true, y_pred, labels=labels, normalize=normalize)
        display = ConfusionMatrixDisplay(matrix, display_labels=labels)
        display.plot(xticks_rotation=45, cmap="Blues", values_format=cell_format)
        plt.title(f"Confusion Matrix — {title_prefix} ({tag})")
        plt.tight_layout()
        plt.show()

# Re-read the classifier's class labels and draw both validation confusion matrices
labels_va = best_model.named_steps["clf"].classes_ 
plot_confusions(y_va, yhat_va, labels_va, title_prefix="Validation")




-------

##### Test Set

In [None]:

# Retrain on Dev (Train + Val), then evaluate on the Test subjects
dev_mask = np.isin(groups, np.r_[train_subjects, val_subjects])
X_dev, y_dev = X[dev_mask], y[dev_mask]

# clone() gives an unfitted copy with the best model's parameters
final_model = clone(best_model)  
final_model.fit(X_dev, y_dev)  

yhat_te = final_model.predict(X_te)

labels_te = final_model.named_steps["clf"].classes_
print("TEST Acc: %.3f | Macro-F1: %.3f | BalAcc: %.3f" % (
    accuracy_score(y_te, yhat_te),
    f1_score(y_te, yhat_te, average="macro"),
    balanced_accuracy_score(y_te, yhat_te)
))

print("\nTest report:\n", classification_report(y_te, yhat_te, labels=labels_te, target_names=labels_te))

# cm_te = confusion_matrix(y_te, yhat_te, labels=labels_te)
# ConfusionMatrixDisplay(cm_te, display_labels=labels_te).plot(xticks_rotation=45, cmap="Blues", values_format="d")
# plt.title("Confusion Matrix - Test"); plt.tight_layout(); plt.show()

## Both confusion matrices (counts and normalized) for the test set
labels_te = final_model.named_steps["clf"].classes_
plot_confusions(y_te, yhat_te, labels_te, title_prefix="Test")
