In [2]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import os

# =========================================================
# Run configuration
# =========================================================
# Input side: directory the training run saved its per-fold CSV files into.
results_dir = "./results/"
# Number of outer folds to analyse; folds are numbered 1..n_folds by the loop below.
n_folds = 3

# Output side: every plot and table produced by this script lands here.
analysis_dir = "./analysis/"
os.makedirs(analysis_dir, exist_ok=True)  # exist_ok: re-running the script is not an error

# Accumulator for the per-fold metrics; the fold loop appends one
# [fold, mae, rmse, r2] row per fold and the summary table is built from it.
metrics_summary = list()
# =========================================================

# Per-fold analysis: for every outer fold, load the saved CSVs, compute
# MAE / RMSE / R2 on the predictions, record them in `metrics_summary`, and
# write four diagnostic plots into `analysis_dir`.
#
# Columns read from the input files:
#   predictions CSV        -> "pred", "true"
#   feature-importance CSV -> "feature", "importance"
#   Optuna trials CSV      -> "value" (optional; the plot is skipped without it)
for fold in range(1, n_folds+1):
    print(f"\n=== ANALYZING FOLD {fold} ===")

    # ------------------------------
    # Load predictions, Optuna trials and feature importances
    # ------------------------------
    # os.path.join keeps the paths clean whether or not the configured
    # directory ends with a separator (plain "{dir}/file" doubled the slash).
    pred_file = os.path.join(results_dir, f"fold_{fold}_predictions.csv")
    trials_file = os.path.join(results_dir, f"optuna_outer_fold_{fold}.csv")
    fi_file = os.path.join(results_dir, f"fold_{fold}_feature_importance.csv")

    df_pred = pd.read_csv(pred_file)
    df_trials = pd.read_csv(trials_file)
    df_fi = pd.read_csv(fi_file)

    # ------------------------------
    # Compute metrics
    # ------------------------------
    # The signed residual (pred - true) is computed once and reused for every
    # metric and for the error histogram below.
    y_true = df_pred["true"]
    residual = df_pred["pred"] - y_true

    mae = np.mean(np.abs(residual))
    rmse = np.sqrt(np.mean(residual**2))

    ss_res = np.sum(residual**2)
    ss_tot = np.sum((y_true - y_true.mean())**2)
    # R2 is undefined when the targets have no variance (constant or empty
    # column); report NaN instead of dividing by zero.
    r2 = (1 - ss_res / ss_tot) if ss_tot > 0 else float("nan")

    metrics_summary.append([fold, mae, rmse, r2])

    print(f"Fold {fold} MAE:  {mae:.4f}")
    print(f"Fold {fold} RMSE: {rmse:.4f}")
    print(f"Fold {fold} R2:   {r2:.4f}")

    # ------------------------------
    # Signed prediction error column (used by the histogram)
    # ------------------------------
    df_pred["error"] = residual

    # Pred vs True Plot
    plt.figure(figsize=(6,6))
    sns.scatterplot(x=df_pred["true"], y=df_pred["pred"], s=20)
    # Span the y = x reference line over both series so that predictions
    # falling outside the range of the true values still have a reference.
    minv = min(df_pred["true"].min(), df_pred["pred"].min())
    maxv = max(df_pred["true"].max(), df_pred["pred"].max())
    plt.plot([minv, maxv], [minv, maxv], "k--")
    plt.xlabel("True Band Gap")
    plt.ylabel("Predicted Band Gap")
    plt.title(f"Fold {fold}: Predicted vs True")
    plt.tight_layout()
    plt.savefig(os.path.join(analysis_dir, f"fold_{fold}_pred_vs_true.png"))
    plt.close()

    # Error Histogram
    plt.figure(figsize=(6,4))
    sns.histplot(df_pred["error"], bins=40, kde=True)
    plt.title(f"Fold {fold} Error Distribution")
    plt.xlabel("Prediction Error")
    plt.tight_layout()
    plt.savefig(os.path.join(analysis_dir, f"fold_{fold}_error_distribution.png"))
    plt.close()

    # Feature Importance: keep only the 25 largest importances
    top_fi = df_fi.sort_values("importance", ascending=False).head(25)
    plt.figure(figsize=(8,6))
    sns.barplot(y="feature", x="importance", data=top_fi)
    plt.title(f"Fold {fold} Top 25 Features")
    plt.tight_layout()
    plt.savefig(os.path.join(analysis_dir, f"fold_{fold}_feature_importance.png"))
    plt.close()

    # Optuna Trials: plotted only when the trials file has a "value" column.
    # The x axis is the row position in the CSV.
    # NOTE(review): the y label assumes the study objective was validation
    # MAE - confirm against the training script.
    if "value" in df_trials.columns:
        plt.figure(figsize=(6,4))
        sns.scatterplot(x=df_trials.index, y=df_trials["value"])
        plt.xlabel("Trial")
        plt.ylabel("Validation MAE")
        plt.title(f"Fold {fold} Optuna Trial Performance")
        plt.tight_layout()
        plt.savefig(os.path.join(analysis_dir, f"fold_{fold}_optuna_trials.png"))
        plt.close()

# ---------------------------------------------------------
# SAVE SUMMARY
# ---------------------------------------------------------
# Build the cross-validation table (one row per fold) from the rows collected
# in `metrics_summary` and write it next to the plots.
df_summary = pd.DataFrame(metrics_summary, columns=["Fold", "MAE", "RMSE", "R2"])
# os.path.join avoids the doubled separator that "{analysis_dir}/..." produced
# when analysis_dir already ends with "/".
df_summary.to_csv(os.path.join(analysis_dir, "cv_summary.csv"), index=False)

print("\n=== ANALYSIS COMPLETE ===")
print(df_summary)
print(f"All plots & tables saved to: {analysis_dir}")



=== ANALYZING FOLD 1 ===
Fold 1 MAE:  0.1466
Fold 1 RMSE: 0.2172
Fold 1 R2:   0.9712

=== ANALYZING FOLD 2 ===
Fold 2 MAE:  0.1468
Fold 2 RMSE: 0.2243
Fold 2 R2:   0.9694

=== ANALYZING FOLD 3 ===
Fold 3 MAE:  0.1493
Fold 3 RMSE: 0.2257
Fold 3 R2:   0.9688

=== ANALYSIS COMPLETE ===
   Fold       MAE      RMSE        R2
0     1  0.146603  0.217233  0.971233
1     2  0.146834  0.224340  0.969372
2     3  0.149291  0.225682  0.968836
All plots & tables saved to: ./analysis/
