In [1]:
    # demo.ipynb

    import joblib
    import matplotlib.pyplot as plt

    from src.data_preprocessing import load_and_preprocess_data
    from src.model_from_scratch import LogisticRegressionScratch
    from src.model_sklearn import train_sklearn_model, predict_sklearn_model
    from src.evaluation import evaluate_model, plot_confusion_matrix, plot_roc_curve

    import os

    # All figures from this cell are written under this directory. It is created
    # up front so the save_path arguments below do not depend on it already
    # existing (NOTE(review): the plotting helpers might create it themselves —
    # confirm; makedirs with exist_ok=True is harmless either way).
    RESULTS_DIR = "results"
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # 1. Load data: train/test features and labels plus the vectorizer
    #    returned by the preprocessing step (not used further in this cell).
    X_train, X_test, y_train, y_test, vectorizer = load_and_preprocess_data("data/spam.csv")

    # 2. Train the from-scratch logistic regression.
    #    class_weight="balanced" presumably compensates for the class imbalance
    #    (the test split has 966 vs 149 samples per class) — TODO confirm that
    #    the sklearn model below uses comparable weighting for a fair comparison.
    scratch_model = LogisticRegressionScratch(lr=0.1, n_iter=300, class_weight="balanced", verbose=False)
    scratch_model.fit(X_train, y_train)
    y_pred_scratch = scratch_model.predict(X_test)

    # 3. Train the scikit-learn model through the project's wrapper functions.
    sklearn_model = train_sklearn_model(X_train, y_train)
    y_pred_sklearn = predict_sklearn_model(sklearn_model, X_test)

    # 4. Evaluate both models on the same test split. The returned metrics are
    #    kept as notebook globals so later cells can compare them.
    print("=== Logistic Regression (Scratch) ===")
    metrics_scratch = evaluate_model(y_test, y_pred_scratch, "Scratch LR")

    print("\n=== Logistic Regression (Sklearn) ===")
    metrics_sklearn = evaluate_model(y_test, y_pred_sklearn, "Sklearn LR")

    # 5. Visualize: confusion matrices are built from the hard predictions, while
    #    the ROC helpers receive the fitted models themselves (presumably to
    #    obtain continuous scores/probabilities on X_test).
    plot_confusion_matrix(y_test, y_pred_scratch, save_path=f"{RESULTS_DIR}/confusion_scratch.png")
    plot_confusion_matrix(y_test, y_pred_sklearn, save_path=f"{RESULTS_DIR}/confusion_sklearn.png")
    plot_roc_curve(scratch_model, X_test, y_test, save_path=f"{RESULTS_DIR}/roc_scratch.png")
    plot_roc_curve(sklearn_model, X_test, y_test, save_path=f"{RESULTS_DIR}/roc_sklearn.png")

    # Message: "Figures saved to the results/ folder".
    print("\n图表已保存至 results/ 文件夹")

=== Logistic Regression (Scratch) ===
[Scratch LR] accuracy=0.9758 precision=0.9067 recall=0.9128 f1=0.9097
              precision    recall  f1-score   support

           0     0.9865    0.9855    0.9860       966
           1     0.9067    0.9128    0.9097       149

    accuracy                         0.9758      1115
   macro avg     0.9466    0.9491    0.9479      1115
weighted avg     0.9759    0.9758    0.9758      1115


=== Logistic Regression (Sklearn) ===
[Sklearn LR] accuracy=0.9821 precision=0.9448 recall=0.9195 f1=0.9320
              precision    recall  f1-score   support

           0     0.9876    0.9917    0.9897       966
           1     0.9448    0.9195    0.9320       149

    accuracy                         0.9821      1115
   macro avg     0.9662    0.9556    0.9608      1115
weighted avg     0.9819    0.9821    0.9820      1115


图表已保存至 results/ 文件夹
