# In [2]:
# Run this once, only if the required packages are missing:
# !pip -q install lightgbm scikit-learn pandas numpy matplotlib pyarrow
import os, time, json
from pathlib import Path

import numpy as np
import pandas as pd
import lightgbm as lgb

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import LabelEncoder
from sklearn.metrics import (
    f1_score, classification_report, confusion_matrix, precision_recall_fscore_support
)
import matplotlib.pyplot as plt

# Output directories for run logs, figures and summary tables.
# exist_ok=True makes re-running this cell harmless.
for _subdir in ("logs", "figures", "tables"):
    Path("results", _subdir).mkdir(parents=True, exist_ok=True)

def per_class_recall(y_true, y_pred, labels=None):
    """Return per-class recall as a ``{class_id: recall}`` dict.

    Keys are the actual label values, not positional indices, so a class that
    is missing from this split cannot shift the recall of later classes onto
    the wrong id.

    Args:
        y_true: ground-truth labels (1-D).
        y_pred: predicted labels (1-D).
        labels: optional iterable of class ids to report, e.g.
            ``range(n_classes)``; classes with no true samples get recall 0.0.
            Defaults to the sorted union of the labels in ``y_true`` and
            ``y_pred`` (the same set sklearn uses by default).

    Returns:
        dict mapping each label (plain Python value) to its recall as float.
    """
    if labels is None:
        # Make sklearn's implicit label set explicit so each score can be
        # paired with the label it belongs to.
        labels = np.unique(np.concatenate([np.ravel(y_true), np.ravel(y_pred)]))
    labels = np.asarray(labels)
    _, recall, _, _ = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, average=None, zero_division=0
    )
    # tolist() converts numpy scalars to plain Python values (JSON-friendly keys).
    return dict(zip(labels.tolist(), (float(r) for r in recall)))

def plot_recall_bar(rec_dict, class_names, title, save_path=None):
    """Draw a bar chart of per-class recall and optionally save it.

    Args:
        rec_dict: mapping of class key -> recall; bars follow the dict's
            iteration order.
        class_names: sequence indexed by the keys of ``rec_dict``; supplies
            the x tick labels.
        title: figure title.
        save_path: if truthy, the figure is written there at 150 dpi before
            being shown.
    """
    keys = list(rec_dict)
    heights = [rec_dict[k] for k in keys]
    tick_labels = [class_names[k] for k in keys]
    positions = range(len(heights))

    fig, ax = plt.subplots(figsize=(10, 4))
    ax.bar(positions, heights)
    ax.set_xticks(positions)
    ax.set_xticklabels(tick_labels, rotation=45, ha='right')
    ax.set_ylim(0, 1)  # recall is a fraction
    ax.set_title(title)
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=150)
    plt.show()
# === Edit here: path to your parquet file ===
DATA_PATH = "data/ciciot2023_subset.parquet"

try:
    df = pd.read_parquet(DATA_PATH)
except FileNotFoundError as err:
    # Relative paths resolve against the notebook's working directory, which is
    # the usual cause of this error -- include it so the fix is obvious.
    raise FileNotFoundError(
        f"Parquet data not found at {DATA_PATH!r} (cwd: {os.getcwd()}); "
        "update DATA_PATH above."
    ) from err
# Explicit check instead of `assert`, which is stripped under `python -O`.
if "label" not in df.columns:
    raise ValueError("需要一列 label")
print("Loaded:", df.shape, "classes:", df["label"].nunique())
print(df["label"].value_counts().head())

# Label encoding: y holds integer class ids; class_names[i] is the original
# label for encoded class i.
le = LabelEncoder()
y = le.fit_transform(df["label"])
class_names = list(le.classes_)

# Features: every column except the label.
X = df.drop(columns=["label"])

# One-hot encode any remaining text columns (object or pandas string dtype),
# which LightGBM cannot consume directly. The resulting sparse, mutually
# exclusive indicator columns are where EFB has the most room to bundle.
cat_cols = [
    c for c in X.columns
    if pd.api.types.is_object_dtype(X[c]) or pd.api.types.is_string_dtype(X[c])
]
if cat_cols:
    X = pd.get_dummies(X, columns=cat_cols, dummy_na=False)

M_RAW = X.shape[1]     # input dimensionality (after one-hot)
print("M_raw =", M_RAW)

# 70/15/15 stratified random split. Swap in a time- or device-based split if
# leakage along those axes matters.
Xtr, Xtmp, ytr, ytmp = train_test_split(X, y, test_size=0.30, stratify=y, random_state=42)
Xva, Xte,  yva, yte  = train_test_split(Xtmp, ytmp, test_size=0.50, stratify=ytmp, random_state=42)

print("Train/Val/Test:", Xtr.shape, Xva.shape, Xte.shape)
def train_lgbm(
    Xtr, ytr, Xva, yva,
    use_goss=False, enable_efb=True, top_rate=0.10, other_rate=0.10,
    seed=42, learning_rate=0.05, num_leaves=64, min_data_in_leaf=50,
    objective="multiclass", num_class=None
):
    """Train a LightGBM model with early stopping and evaluate it on validation data.

    Used for the GOSS / EFB ablation: ``use_goss`` switches GOSS sampling on
    and ``enable_efb`` toggles Exclusive Feature Bundling (``enable_bundle``).

    Args:
        Xtr, ytr: training features and integer-encoded labels.
        Xva, yva: validation features/labels; used for early stopping and for
            every reported metric.
        use_goss: use Gradient-based One-Side Sampling instead of plain GBDT.
        enable_efb: value passed as LightGBM's ``enable_bundle``.
        top_rate, other_rate: GOSS sampling ratios (ignored unless ``use_goss``).
        seed, learning_rate, num_leaves, min_data_in_leaf: LightGBM parameters.
        objective: ``"multiclass"`` or ``"binary"``.
        num_class: class count for ``"multiclass"``. If None it is inferred as
            ``max(label) + 1`` over ytr and yva -- this assumes labels are
            0..K-1 (e.g. from LabelEncoder).

    Returns:
        (booster, result, rec, cm):
          booster -- the trained ``lgb.Booster``;
          result  -- dict with time_sec, best_iter, macro_f1, n_features_input,
                     n_features_effective, params;
          rec     -- validation per-class recall from ``per_class_recall``;
          cm      -- validation confusion matrix as nested lists, rows/cols
                     ordered by class id.
    """
    is_multiclass = objective == "multiclass"
    if is_multiclass and num_class is None:
        # LightGBM refuses to train multiclass without num_class.
        num_class = int(max(np.max(ytr), np.max(yva))) + 1

    params = dict(
        objective=objective,
        metric=["multi_logloss"] if is_multiclass else ["auc", "binary_logloss"],
        num_class=num_class if is_multiclass else None,
        # LightGBM >= 4.0 prefers data_sample_strategy="goss"; boosting="goss"
        # is still accepted there (with a deprecation warning) and on 3.x.
        boosting_type="goss" if use_goss else "gbdt",
        learning_rate=learning_rate,
        num_leaves=num_leaves,
        min_data_in_leaf=min_data_in_leaf,
        max_bin=255,
        feature_fraction=1.0,
        bagging_fraction=1.0,
        lambda_l2=1.0,
        enable_bundle=enable_efb,   # EFB on/off
        force_col_wise=True,
        deterministic=True,
        seed=seed,
        verbose=-1,
    )
    if use_goss:
        params.update(top_rate=top_rate, other_rate=other_rate)  # GOSS sampling ratios

    dtr = lgb.Dataset(Xtr, ytr, free_raw_data=False)
    dva = lgb.Dataset(Xva, yva, reference=dtr, free_raw_data=False)

    t0 = time.perf_counter()  # monotonic clock for measuring durations
    booster = lgb.train(
        params, dtr, num_boost_round=5000,
        valid_sets=[dtr, dva], valid_names=["train", "val"],
        keep_training_booster=True,
        # The early_stopping_rounds / verbose_eval keyword arguments were
        # removed in LightGBM 4.0; these callbacks replace them. Early stopping
        # ignores the training set, so it monitors "val".
        callbacks=[
            lgb.early_stopping(stopping_rounds=100),
            lgb.log_evaluation(period=50),
        ],
    )
    elapsed = time.perf_counter() - t0

    # Evaluate on the validation set at the early-stopped iteration.
    proba = booster.predict(Xva, num_iteration=booster.best_iteration)
    if is_multiclass:
        yhat = proba.argmax(axis=1)
        macro_f1 = f1_score(yva, yhat, average="macro")
        cm_labels = np.arange(num_class)
    else:
        yhat = (proba >= 0.5).astype(int)
        # Binary: positive-class F1, stored under "macro_f1" so the result/CSV
        # schema is the same for both objectives.
        macro_f1 = f1_score(yva, yhat, average="binary")
        cm_labels = np.array([0, 1])

    rec = per_class_recall(yva, yhat)
    # Pinning labels keeps the matrix square with row/col i == class i, even
    # when a class is absent from this validation split.
    cm = confusion_matrix(yva, yhat, labels=cm_labels).tolist()

    result = dict(
        time_sec=float(elapsed),
        best_iter=int(booster.best_iteration),
        macro_f1=float(macro_f1),
        n_features_input=int(Xtr.shape[1]),
        # NOTE(review): Booster.feature_name() lists the input columns, so this
        # does not count EFB bundles and will normally equal n_features_input.
        n_features_effective=len(booster.feature_name()),
        params=booster.params,
    )
    return booster, result, rec, cm


# Experiment settings (this run: 120_lgbm_GOSS_only.ipynb).
USE_GOSS = True
ENABLE_EFB = False

# GOSS sampling ratios -- only used when USE_GOSS is True.
TOP_RATE = 0.10
OTHER_RATE = 0.10

OBJECTIVE = "multiclass"        # use "binary" for a two-class task
NUM_CLASS = len(np.unique(y))   # may be set to None for the binary objective
SETTING_NAME = "_".join((
    "goss" if USE_GOSS else "gbdt",
    "EFB" if ENABLE_EFB else "noEFB",
))
SETTING_NAME
booster, result, rec, cm = train_lgbm(
    Xtr, ytr, Xva, yva,
    objective=OBJECTIVE,
    num_class=NUM_CLASS,
    use_goss=USE_GOSS,
    enable_efb=ENABLE_EFB,
    top_rate=TOP_RATE,
    other_rate=OTHER_RATE,
)

summary = dict(result, setting=SETTING_NAME)
print(json.dumps(summary, indent=2))

# Persist this run: full JSON log (including per-class recall and the
# confusion matrix), then the trained model.
with open(f"results/logs/{SETTING_NAME}.json", "w") as f:
    json.dump(dict(summary, per_class_recall=rec, cm=cm), f, indent=2)

booster.save_model(f"results/{SETTING_NAME}.txt")

# Append one summary row to the ablation CSV; the header is written only when
# the file does not exist yet. GOSS ratios are left empty for non-GOSS runs.
row = dict(
    setting=SETTING_NAME,
    time_sec=result["time_sec"],
    best_iter=result["best_iter"],
    macro_f1=result["macro_f1"],
    M_raw=result["n_features_input"],
    M_effective=result["n_features_effective"],
    use_goss=USE_GOSS,
    enable_efb=ENABLE_EFB,
    top_rate=TOP_RATE if USE_GOSS else None,
    other_rate=OTHER_RATE if USE_GOSS else None,
)
tbl_path = Path("results/tables/efb_goss_ablation.csv")
write_header = not tbl_path.exists()
pd.DataFrame([row]).to_csv(tbl_path, mode="a", header=write_header, index=False)
tbl_path
# Per-class recall bar chart (validation set).
save_fig = f"results/figures/recall_{SETTING_NAME}.png"
plot_recall_bar(rec, class_names, f"Per-class Recall: {SETTING_NAME}", save_fig)

# Detailed classification report (validation set).
proba = booster.predict(Xva, num_iteration=booster.best_iteration)
yhat = proba.argmax(1) if OBJECTIVE=="multiclass" else (proba>=0.5).astype(int)
# Pin labels to every encoded class so target_names always lines up with the
# report rows. Without it, sklearn raises ValueError when a class is missing
# from both yva and yhat. Empty classes then report 0 instead of warning.
print(classification_report(
    yva, yhat,
    labels=np.arange(len(class_names)),
    target_names=class_names,
    digits=4,
    zero_division=0,
))

# Output of the last run: FileNotFoundError: [Errno 2] No such file or directory: 'data/ciciot2023_subset.parquet'