# In [None]:
# fresh_ml_bin_variants.py
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import hashlib

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve, classification_report, confusion_matrix
from sklearn.linear_model import LogisticRegression

from maldi_nn.spectrum import (
    SpectrumObject, SequentialPreprocessor, VarStabilizer, Smoother,
    BaselineCorrecter, Trimmer, PersistenceTransformer, Normalizer, Binner
)

# --- Config ---
# Input locations: one folder of .tsv spectra per class plus the ribosomal
# reference mass list.  The strings are raw literals, so every separator is
# written once; doubling them inside r"..." put two backslashes into each path.
neg_dir = r"Y:\test_set\allspectra\neg_spectra\neg_tsv"
pos_dir = r"Y:\test_set\allspectra\pos_spectra\pos_tsv"
ribo_masslist = r"Y:\test_set\ribo_Saureus.tsv"

# --- Data Leakage Check ---
def check_sample_uniqueness(label_dirs=None):
    """Report ``.tsv`` file names that occur more than once across class folders.

    Only file names are compared (not file contents), so this flags a sample
    whose file is present in more than one of the scanned folders.

    Args:
        label_dirs: Optional iterable of directories to scan.  Defaults to the
            module-level ``neg_dir`` and ``pos_dir``.  Entries that are not
            existing directories are ignored.

    Returns:
        Sorted list of the file names that were found more than once (empty
        when every name is unique).
    """
    if label_dirs is None:
        label_dirs = [neg_dir, pos_dir]

    basenames = []
    for label_dir in label_dirs:
        # isdir (rather than exists) so a stray file path cannot make listdir raise.
        if os.path.isdir(label_dir):
            basenames.extend(f for f in os.listdir(label_dir) if f.endswith(".tsv"))

    # Explicit dtype keeps the empty case well-defined.
    names = pd.Series(basenames, dtype=object)
    duplicates = names.duplicated()
    if duplicates.any():
        print("Data leakage detected! Duplicate sample names found across classes:")
        print(names[duplicates].value_counts())
    else:
        print(f"No data leakage detected. {len(basenames)} unique sample files.")
    return sorted(set(names[duplicates]))

check_sample_uniqueness()

# --- Load ribosomal m/z values ---
# Best effort: when the mass list cannot be read the run continues with an
# empty reference list (alignment becomes a no-op), but the reason is now
# reported and KeyboardInterrupt/SystemExit are no longer swallowed.
try:
    ribo_df = pd.read_csv(ribo_masslist, sep="\t", header=0, comment="#")
except Exception as exc:
    ribo_mz = np.array([])
    print(f"Failed to load ribosomal mass list. Proceeding without alignment. ({exc})")
else:
    if 'Mass' in ribo_df.columns:
        # Non-numeric entries are coerced to NaN and dropped.
        ribo_mz = pd.to_numeric(ribo_df['Mass'], errors='coerce').dropna().values
    else:
        ribo_mz = np.array([])
        print("Ribosomal mass list has no 'Mass' column. Proceeding without alignment.")
    print(f"Loaded {len(ribo_mz)} ribosomal m/z values.")

# --- Alignment ---
def align_spectrum(spectrum, ref_mz_list, top_n=10, ppm=100):
    """Shift a spectrum's m/z axis by one constant offset towards reference masses.

    The `top_n` most intense points are each compared with `ref_mz_list`; for
    every point the nearest reference within `ppm` parts-per-million is a
    candidate offset.  The candidate with the smallest absolute value is added
    to the whole m/z axis (0.0 when nothing matches).

    Args:
        spectrum: Object exposing `mz` and `intensity`; it is modified in
            place.  A falsy spectrum is returned untouched.
        ref_mz_list: Reference m/z values; when empty no alignment is done.
        top_n: Number of most intense points considered.
        ppm: Matching tolerance in parts per million of the point's m/z.

    Returns:
        The same `spectrum` object.
    """
    if not spectrum:
        return spectrum
    if len(ref_mz_list) == 0:
        return spectrum

    mz_axis = np.asarray(spectrum.mz)
    intensities = np.asarray(spectrum.intensity)
    strongest = np.argsort(-intensities)[:top_n]

    chosen_shift = 0.0
    smallest_abs = float('inf')
    for peak_mz in mz_axis[strongest]:
        offsets = ref_mz_list - peak_mz
        within_tol = offsets[np.abs(offsets) <= (peak_mz * ppm / 1e6)]
        if within_tol.size == 0:
            continue
        candidate = within_tol[np.argmin(np.abs(within_tol))]
        if abs(candidate) < smallest_abs:
            chosen_shift = candidate
            smallest_abs = abs(candidate)

    spectrum.mz = mz_axis + chosen_shift
    return spectrum

# --- Preprocessing Steps ---
# Swept in the experiment loop: the flag handed to PersistenceTransformer and
# the bin width handed to Binner.
preproc_variants = [True, False]
bin_sizes = [1, 2, 3, 4, 5]

# Random Forest variants, listed as (n_estimators, max_depth)
rf_variants = [
    RandomForestClassifier(n_estimators=n_trees, max_depth=depth, random_state=42)
    for n_trees, depth in [(100, 10), (200, 10), (200, None), (300, 15), (100, 5)]
]

# SVM variants (with scaling), listed as (C, kernel)
svm_variants = [
    Pipeline([("scaler", StandardScaler()), ("clf", SVC(C=c_value, kernel=kernel_name, probability=True, random_state=42))])
    for c_value, kernel_name in [(0.1, "linear"), (1, "linear"), (10, "linear"), (1, "rbf"), (10, "rbf")]
]

# LightGBM variants, listed as (n_estimators, learning_rate)
lgbm_variants = [
    LGBMClassifier(n_estimators=n_trees, learning_rate=rate, random_state=42, verbosity=-1)
    for n_trees, rate in [(100, 0.05), (200, 0.1), (200, 0.2), (300, 0.05), (100, 0.01)]
]

# XGBoost variants, listed as (n_estimators, learning_rate)
xgb_variants = [
    XGBClassifier(use_label_encoder=False, eval_metric='logloss', n_estimators=n_trees, learning_rate=rate, random_state=42)
    for n_trees, rate in [(100, 0.1), (200, 0.1), (200, 0.2), (300, 0.05), (100, 0.01)]
]

# Logistic Regression variants (with scaling), one per regularisation strength C
logreg_variants = [
    Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(C=c_value, max_iter=1000, random_state=42))])
    for c_value in [0.01, 0.1, 1, 10, 100]
]

# --- Experiment loop ---
# `summary` gets one AUC row per successfully evaluated variant;
# `metrics_detailed` also keeps the classification report and confusion
# matrix, or the error text (with auc 0.0) for variants that failed.
summary = []
metrics_detailed = []

# Explicit handles on the shared ROC figure: curves are drawn on `roc_ax`
# instead of relying on pyplot's implicit "current figure".
roc_fig, roc_ax = plt.subplots(figsize=(12, 8))

# (name prefix, estimator list) for every supervised model family, in run order.
model_families = [
    ("RF", rf_variants),
    ("SVM", svm_variants),
    ("LGBM", lgbm_variants),
    ("XGB", xgb_variants),
    ("LOGREG", logreg_variants),
]

for bin_step in bin_sizes:
    for use_persistence in preproc_variants:
        print(f"\n Processing bin size {bin_step}, persistence={use_persistence}...")
        preproc = SequentialPreprocessor(
            VarStabilizer("sqrt"),
            Smoother(10),
            BaselineCorrecter("SNIP", 20),
            Trimmer(),
            PersistenceTransformer(use_persistence),
            Normalizer(1)
        )
        binner = Binner(2000, 15000, bin_step)
        X, y = [], []

        for folder, label in [(neg_dir, 0), (pos_dir, 1)]:
            if not os.path.exists(folder):
                continue
            # os.listdir order is arbitrary; sorting keeps the sample order,
            # and therefore the seeded train/test split, reproducible.
            for fname in sorted(os.listdir(folder)):
                if not fname.endswith(".tsv") or not fname.startswith("2024"):
                    continue
                path = os.path.join(folder, fname)
                try:
                    spec = SpectrumObject.from_tsv(path, sep="\t")
                    spec = preproc(spec)
                    spec = align_spectrum(spec, ribo_mz, ppm=100)
                    spec = binner(spec)
                    if len(spec.intensity):
                        X.append(np.asarray(spec.intensity).flatten())
                        y.append(label)
                except Exception as e:
                    # One unreadable spectrum must not abort the whole sweep.
                    print(f"Skipped {fname}: {e}")

        if len(X) < 10:
            print("Not enough data, skipping.")
            continue

        X, y = np.stack(X), np.array(y)
        # A stratified split needs at least two spectra of each class.
        if np.bincount(y, minlength=2).min() < 2:
            print("Need at least two spectra per class, skipping.")
            continue
        # Stratify so both classes are represented in train and test; an
        # unstratified split can yield a single-class test set (AUC undefined).
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        def run_model(model, name):
            """Fit `model` on the current train split and record its test metrics as `name`."""
            try:
                model.fit(X_train, y_train)
                probas = model.predict_proba(X_test)[:, 1]
                preds = model.predict(X_test)
                auc = roc_auc_score(y_test, probas)
                fpr, tpr, _ = roc_curve(y_test, probas)
                roc_ax.plot(fpr, tpr, label=f"{name} (AUC={auc:.3f})")
                report = classification_report(y_test, preds, output_dict=True)
                cmatrix = confusion_matrix(y_test, preds).tolist()
                summary.append({"variant": name, "auc": auc})
                metrics_detailed.append({"variant": name, "auc": auc, "report": report, "confusion_matrix": cmatrix})
                print(f"{name} AUC={auc:.3f}")
                print(pd.DataFrame(report))
                print(f"Confusion Matrix:\n{cmatrix}")
            except Exception as e:
                print(f"{name} failed: {e}")
                metrics_detailed.append({"variant": name, "auc": 0.0, "error": str(e)})

        for prefix, variants in model_families:
            for i, model in enumerate(variants):
                run_model(model, f"{prefix}_bin{bin_step}_pers{use_persistence}_v{i}")

        km_name = f"KMeans_bin{bin_step}_pers{use_persistence}"
        try:
            print(f"Running {km_name}")
            km = KMeans(n_clusters=2, random_state=42, n_init='auto').fit(X_train)
            # Decide which cluster id means "positive" from the TRAINING labels
            # only; orienting the clusters with y_test would leak test labels
            # into the reported score.
            train_clusters = km.labels_
            flip = np.mean(y_train[train_clusters == 0]) > np.mean(y_train[train_clusters == 1])
            preds = km.predict(X_test)
            if flip:
                preds = 1 - preds
            auc = roc_auc_score(y_test, preds)
            fpr, tpr, _ = roc_curve(y_test, preds)
            roc_ax.plot(fpr, tpr, label=f"{km_name} (AUC={auc:.3f})")
            report = classification_report(y_test, preds, output_dict=True)
            cmatrix = confusion_matrix(y_test, preds).tolist()
            summary.append({"variant": km_name, "auc": auc})
            metrics_detailed.append({"variant": km_name, "auc": auc, "report": report, "confusion_matrix": cmatrix})
            print(f"{km_name} AUC={auc:.3f}")
            print(pd.DataFrame(report))
            print(f"Confusion Matrix:\n{cmatrix}")
        except Exception as e:
            print(f"KMeans failed: {e}")
            metrics_detailed.append({"variant": km_name, "auc": 0.0, "error": str(e)})

# Chance diagonal, then labels/legend BEFORE tight_layout so the layout
# accounts for them.
roc_ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curves for Bins 1-5 (2024 data only)")
roc_ax.legend(fontsize='x-small', bbox_to_anchor=(1.05, 1), loc='upper left')
roc_fig.tight_layout(rect=[0, 0, 0.75, 1])
roc_fig.savefig("roc_bins_2024.png", dpi=300, bbox_inches='tight')

pd.DataFrame(summary).to_csv("auc_summary_bins_2024.csv", index=False)
pd.DataFrame(metrics_detailed).to_json("ml_metrics_bins_2024_detailed.json", orient="records", indent=2)
print("\n Done: Summary written to auc_summary_bins_2024.csv, full metrics in ml_metrics_bins_2024_detailed.json, and ROC to roc_bins_2024.png")


# In [None]:
# fresh_ml_bin_variants.py
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import hashlib
import shap
import warnings
warnings.filterwarnings("ignore", category=UserWarning)

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve, classification_report, confusion_matrix
from sklearn.linear_model import LogisticRegression

from maldi_nn.spectrum import (
    SpectrumObject, SequentialPreprocessor, VarStabilizer, Smoother,
    BaselineCorrecter, Trimmer, PersistenceTransformer, Normalizer, Binner
)

# --- Config ---
# Class folders and reference mass list.  Raw string literals keep
# backslashes as-is, so each path separator is written exactly once (the
# doubled form produced paths containing literal "\\").
neg_dir = r"Y:\test_set\allspectra\neg_spectra\neg_tsv"
pos_dir = r"Y:\test_set\allspectra\pos_spectra\pos_tsv"
ribo_masslist = r"Y:\test_set\ribo_Saureus.tsv"
# m/z window handed to Binner; also used to label the binned features.
BIN_START_MZ = 2000
BIN_END_MZ = 15000

# --- Data Leakage Check ---
def check_sample_uniqueness():
    """Print whether any ``.tsv`` file name appears more than once.

    Scans the module-level ``neg_dir`` and ``pos_dir`` (skipping folders that
    do not exist) and compares file names only, not file contents.
    """
    collected = []
    for class_dir in (neg_dir, pos_dir):
        if not os.path.exists(class_dir):
            continue
        for entry in os.listdir(class_dir):
            if entry.endswith(".tsv"):
                collected.append(os.path.join(class_dir, entry))

    names = pd.Series([os.path.basename(p) for p in collected])
    dup_mask = names.duplicated()
    if not dup_mask.any():
        print(f"No data leakage detected. {len(names)} unique sample files.")
        return
    print("Data leakage detected! Duplicate sample names found across classes:")
    print(names[dup_mask].value_counts())

check_sample_uniqueness()

# --- Load ribosomal m/z values ---
# Loading is best effort: on failure `ribo_mz` stays empty and alignment is
# skipped.  Catching Exception (not a bare except) lets Ctrl-C through and
# the failure reason is printed instead of being discarded.
try:
    ribo_df = pd.read_csv(ribo_masslist, sep="\t", header=0, comment="#")
except Exception as exc:
    ribo_mz = np.array([])
    print(f"Failed to load ribosomal mass list. Proceeding without alignment. ({exc})")
else:
    if 'Mass' in ribo_df.columns:
        # Unparsable mass entries become NaN and are dropped.
        ribo_mz = pd.to_numeric(ribo_df['Mass'], errors='coerce').dropna().values
    else:
        ribo_mz = np.array([])
        print("Ribosomal mass list has no 'Mass' column. Proceeding without alignment.")
    print(f"Loaded {len(ribo_mz)} ribosomal m/z values.")

# --- Alignment ---
def align_spectrum(spectrum, ref_mz_list, top_n=10, ppm=100):
    """Apply one constant m/z offset that snaps a strong peak onto a reference mass.

    For each of the `top_n` most intense points, the closest entry of
    `ref_mz_list` lying within `ppm` parts-per-million is taken as a candidate
    offset; the candidate of smallest magnitude is added to every m/z value.

    Args:
        spectrum: Object with `mz` and `intensity` attributes, updated in
            place.  Objects lacking either attribute, or with an empty m/z
            axis, are returned unchanged.
        ref_mz_list: Reference m/z values; an empty list disables alignment.
        top_n: How many of the most intense points to try.
        ppm: Matching tolerance relative to each point's m/z.

    Returns:
        The `spectrum` object that was passed in.
    """
    can_align = (
        hasattr(spectrum, 'mz')
        and hasattr(spectrum, 'intensity')
        and len(ref_mz_list) != 0
    )
    if not can_align:
        return spectrum

    mz_axis = np.asarray(spectrum.mz)
    intensities = np.asarray(spectrum.intensity)
    if len(mz_axis) == 0:
        return spectrum

    ranked = np.argsort(-intensities)[:top_n]
    candidates = []
    for peak_mz in mz_axis[ranked]:
        offsets = ref_mz_list - peak_mz
        close = offsets[np.abs(offsets) <= (peak_mz * ppm / 1e6)]
        if close.size:
            candidates.append(close[np.argmin(np.abs(close))])

    # min() keeps the first of equally small candidates, like a strict "<" update.
    offset = min(candidates, key=abs) if candidates else 0.0
    spectrum.mz = mz_axis + offset
    return spectrum

# --- Preprocessing Steps ---
# Values swept by the experiment loop: PersistenceTransformer flag and bin width.
preproc_variants = [True, False]
bin_sizes = [1, 2, 3, 4, 5]

# --- Model Definitions ---
# Random Forest variants, given as (n_estimators, max_depth)
rf_variants = [
    RandomForestClassifier(n_estimators=n_trees, max_depth=depth, random_state=42)
    for n_trees, depth in [(100, 10), (200, 10), (200, None), (300, 15), (100, 5)]
]
# SVM variants (with scaling), given as (C, kernel)
svm_variants = [
    Pipeline([("scaler", StandardScaler()), ("clf", SVC(C=c_value, kernel=kernel_name, probability=True, random_state=42))])
    for c_value, kernel_name in [(0.1, "linear"), (1, "linear"), (10, "linear"), (1, "rbf"), (10, "rbf")]
]
# LightGBM variants, given as (n_estimators, learning_rate)
lgbm_variants = [
    LGBMClassifier(n_estimators=n_trees, learning_rate=rate, random_state=42, verbosity=-1)
    for n_trees, rate in [(100, 0.05), (200, 0.1), (200, 0.2), (300, 0.05), (100, 0.01)]
]
# XGBoost variants, given as (n_estimators, learning_rate)
xgb_variants = [
    XGBClassifier(use_label_encoder=False, eval_metric='logloss', n_estimators=n_trees, learning_rate=rate, random_state=42)
    for n_trees, rate in [(100, 0.1), (200, 0.1), (200, 0.2), (300, 0.05), (100, 0.01)]
]
# Logistic Regression variants (with scaling), one per regularisation strength C
logreg_variants = [
    Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(C=c_value, max_iter=1000, random_state=42))])
    for c_value in [0.01, 0.1, 1, 10, 100]
]

# --- Experiment loop (with SHAP explanations) ---
# `summary` gets one AUC row per successfully evaluated variant;
# `metrics_detailed` also keeps the classification report and confusion
# matrix, or the error text (with auc 0.0) for variants that failed.
summary = []
metrics_detailed = []

# Explicit handles on the shared ROC figure.  SHAP plotting opens and closes
# figures of its own, so ROC curves are drawn on `roc_ax` rather than on
# whatever pyplot currently considers the active figure.
roc_fig, roc_ax = plt.subplots(figsize=(12, 8))

# (name prefix, estimator list) for every supervised model family, in run order.
model_families = [
    ("RF", rf_variants),
    ("SVM", svm_variants),
    ("LGBM", lgbm_variants),
    ("XGB", xgb_variants),
    ("LOGREG", logreg_variants),
]


def _save_shap_figure(draw, title, out_path):
    """Run the SHAP plotting callable `draw`, title and save the result, then
    close every figure opened along the way."""
    scratch = plt.figure()
    draw()
    # NOTE(review): a SHAP plot may draw on the current figure or open its
    # own; take whichever is current afterwards and close both.
    drawn = plt.gcf()
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.close(drawn)
    if drawn is not scratch:
        plt.close(scratch)


for bin_step in bin_sizes:
    for use_persistence in preproc_variants:
        print(f"\n Processing bin size {bin_step}, persistence={use_persistence}...")
        preproc = SequentialPreprocessor(
            VarStabilizer("sqrt"),
            Smoother(10),
            BaselineCorrecter("SNIP", 20),
            Trimmer(),
            PersistenceTransformer(use_persistence),
            Normalizer(1)
        )
        binner = Binner(BIN_START_MZ, BIN_END_MZ, bin_step)
        X, y = [], []

        for folder, label in [(neg_dir, 0), (pos_dir, 1)]:
            if not os.path.exists(folder):
                continue
            # os.listdir order is arbitrary; sorting keeps the sample order,
            # and therefore the seeded train/test split, reproducible.
            for fname in sorted(os.listdir(folder)):
                if not fname.endswith(".tsv") or not fname.startswith("2024"):
                    continue
                path = os.path.join(folder, fname)
                try:
                    spec = SpectrumObject.from_tsv(path, sep="\t")
                    spec = preproc(spec)
                    spec = align_spectrum(spec, ribo_mz, ppm=100)
                    spec = binner(spec)
                    if len(spec.intensity):
                        X.append(np.asarray(spec.intensity).flatten())
                        y.append(label)
                except Exception as e:
                    # One unreadable spectrum must not abort the whole sweep.
                    print(f"Skipped {fname}: {e}")

        if len(X) < 10:
            print("Not enough data, skipping.")
            continue

        X, y = np.stack(X), np.array(y)
        # A stratified split needs at least two spectra of each class.
        if np.bincount(y, minlength=2).min() < 2:
            print("Need at least two spectra per class, skipping.")
            continue
        # Stratify so both classes are represented in train and test; an
        # unstratified split can yield a single-class test set (AUC undefined).
        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=0.2, random_state=42, stratify=y
        )

        # One "start-end" m/z label per binned feature column.
        mz_feature_names = [f"{int(BIN_START_MZ + i * bin_step)}-{int(BIN_START_MZ + (i + 1) * bin_step)}" for i in range(X_train.shape[1])]

        def run_model(model, name, feature_names):
            """Fit `model`, record its test metrics as `name`, then try to save
            SHAP beeswarm/waterfall plots labelled with `feature_names`."""
            try:
                model.fit(X_train, y_train)
                probas = model.predict_proba(X_test)[:, 1]
                preds = model.predict(X_test)
                auc = roc_auc_score(y_test, probas)
                fpr, tpr, _ = roc_curve(y_test, probas)
                roc_ax.plot(fpr, tpr, label=f"{name} (AUC={auc:.3f})")
                report = classification_report(y_test, preds, output_dict=True)
                cmatrix = confusion_matrix(y_test, preds).tolist()
                summary.append({"variant": name, "auc": auc})
                metrics_detailed.append({"variant": name, "auc": auc, "report": report, "confusion_matrix": cmatrix})
                print(f"{name} AUC={auc:.3f}")
                print(pd.DataFrame(report))
                print(f"Confusion Matrix:\n{cmatrix}")

                # SHAP is optional: a failure here is reported but never
                # invalidates the metrics recorded above.
                try:
                    if isinstance(model, (RandomForestClassifier, LGBMClassifier, XGBClassifier)):
                        explainer = shap.Explainer(model, X_train, feature_names=feature_names)
                    else:
                        print("Using summarized background data for SHAP (non-tree model)...")
                        # Never ask for more background clusters than training rows.
                        background_data = shap.kmeans(X_train, min(100, len(X_train)))
                        explainer = shap.Explainer(model, background_data, feature_names=feature_names)

                    shap_values = explainer(X_test)
                    # Only a 3-D result carries a per-class axis; select class 1
                    # then, and use a 2-D result as it is.
                    if shap_values.values.ndim == 3:
                        shap_values_class_1 = shap_values[:, :, 1]
                    else:
                        shap_values_class_1 = shap_values

                    bees_path = f"shap_beeswarm_{name}.png"
                    water_path = f"shap_waterfall_{name}.png"

                    _save_shap_figure(
                        lambda: shap.plots.beeswarm(shap_values_class_1, max_display=25, show=False),
                        f"SHAP Beeswarm - {name}",
                        bees_path,
                    )
                    _save_shap_figure(
                        lambda: shap.plots.waterfall(shap_values_class_1[0], show=False),
                        f"SHAP Waterfall (1st sample) - {name}",
                        water_path,
                    )

                    print(f"SHAP plots saved: {bees_path}, {water_path}")
                except Exception as e:
                    print(f"⚠️ SHAP failed for {name}: {e}")

            except Exception as e:
                print(f"{name} failed: {e}")
                metrics_detailed.append({"variant": name, "auc": 0.0, "error": str(e)})

        for prefix, variants in model_families:
            for i, model in enumerate(variants):
                run_model(model, f"{prefix}_bin{bin_step}_pers{use_persistence}_v{i}", mz_feature_names)

        km_name = f"KMeans_bin{bin_step}_pers{use_persistence}"
        try:
            print(f"Running {km_name}")
            km = KMeans(n_clusters=2, random_state=42, n_init='auto').fit(X_train)
            # Decide which cluster id means "positive" from the TRAINING labels
            # only; orienting the clusters with y_test would leak test labels
            # into the reported score.
            train_clusters = km.labels_
            flip = np.mean(y_train[train_clusters == 0]) > np.mean(y_train[train_clusters == 1])
            preds = km.predict(X_test)
            if flip:
                preds = 1 - preds
            auc = roc_auc_score(y_test, preds)
            fpr, tpr, _ = roc_curve(y_test, preds)
            roc_ax.plot(fpr, tpr, label=f"{km_name} (AUC={auc:.3f})")
            report = classification_report(y_test, preds, output_dict=True)
            cmatrix = confusion_matrix(y_test, preds).tolist()
            summary.append({"variant": km_name, "auc": auc})
            metrics_detailed.append({"variant": km_name, "auc": auc, "report": report, "confusion_matrix": cmatrix})
            print(f"{km_name} AUC={auc:.3f}")
            print(pd.DataFrame(report))
            print(f"Confusion Matrix:\n{cmatrix}")
        except Exception as e:
            print(f"KMeans failed: {e}")
            metrics_detailed.append({"variant": km_name, "auc": 0.0, "error": str(e)})

# Chance diagonal, then labels/legend BEFORE tight_layout so the layout
# accounts for them.
roc_ax.plot([0, 1], [0, 1], linestyle="--", color="gray")
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.set_title("ROC Curves for Bins 1-5 (2024 data only)")
roc_ax.legend(fontsize='x-small', bbox_to_anchor=(1.05, 1), loc='upper left')
roc_fig.tight_layout(rect=[0, 0, 0.75, 1])
roc_fig.savefig("roc_bins_2024.png", dpi=300, bbox_inches='tight')

pd.DataFrame(summary).to_csv("auc_summary_bins_2024.csv", index=False)
pd.DataFrame(metrics_detailed).to_json("ml_metrics_bins_2024_detailed.json", orient="records", indent=2)
print("\n Done: Summary written to auc_summary_bins_2024.csv, full metrics in ml_metrics_bins_2024_detailed.json, and ROC to roc_bins_2024.png")

# In [None]:
# fresh_ml_bin_variants.py
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import hashlib
import shap
from shap import KernelExplainer
import warnings
from sklearn.base import clone
from joblib import Parallel, delayed
from scipy.signal import find_peaks
from sklearn.linear_model import RANSACRegressor

warnings.filterwarnings("ignore", category=UserWarning)

from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from lightgbm import LGBMClassifier
from xgboost import XGBClassifier
from sklearn.cluster import KMeans
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, roc_curve, classification_report, confusion_matrix
from sklearn.linear_model import LogisticRegression

from maldi_nn.spectrum import (
    SpectrumObject, SequentialPreprocessor, VarStabilizer, Smoother,
    BaselineCorrecter, Trimmer, PersistenceTransformer, Normalizer, Binner
)

# --- Config ---
# Class folders and reference mass list.  These are raw string literals, so
# a single backslash per separator is correct; the doubled form embedded
# literal "\\" in every path.
neg_dir = r"Y:\test_set\allspectra\neg_spectra\neg_tsv"
pos_dir = r"Y:\test_set\allspectra\pos_spectra\pos_tsv"
ribo_masslist = r"Y:\test_set\ribo_Saureus.tsv"
# m/z window used for binning.
BIN_START_MZ = 2000
BIN_END_MZ = 15000
# Year labels (as strings) covered by this experiment.
ALL_YEARS = ['2022', '2023', '2024']

# --- RANSAC Aligner Class ---
class Aligner:
    """Preprocessing step that aligns a spectrum's m/z axis to a reference peak list.

    Peaks detected in the spectrum are paired with reference peaks lying
    within a ppm tolerance; a RANSAC regression of (reference - observed)
    against observed m/z then yields a per-point correction that is added to
    the whole m/z axis.

    Args:
        reference_peaks: Reference m/z positions.
        tolerance: Matching window in parts per million of the reference m/z.
    """

    def __init__(self, reference_peaks, tolerance=500):
        self.reference_peaks = np.asarray(reference_peaks, dtype=float)
        self.tolerance = tolerance

    def __call__(self, spectrum: "SpectrumObject") -> "SpectrumObject":
        """Return an aligned copy of `spectrum`, or `spectrum` itself when
        alignment is not possible.

        The input is returned unchanged when it is empty, when there are no
        reference peaks, when fewer than 5 peaks are detected, when fewer
        than 3 peaks match a reference, or when the RANSAC fit raises
        ValueError.
        """
        mz = np.asarray(spectrum.mz)
        intensity = np.asarray(spectrum.intensity)
        # np.max raises on an empty array, so empty spectra are passed through.
        if intensity.size == 0 or self.reference_peaks.size == 0:
            return spectrum

        # Candidate peaks: at least 5 % of the base peak and a minimum prominence.
        peaks, _ = find_peaks(intensity, height=0.05 * np.max(intensity), prominence=0.01)
        if len(peaks) < 5:
            return spectrum
        sample_peak_mzs = mz[peaks]

        # (observed m/z, required shift) for every reference peak that has a
        # detected peak inside its ppm window; the nearest one wins.
        shifts = []
        for ref_peak in self.reference_peaks:
            tol_daltons = self.tolerance * ref_peak / 1e6
            matches = np.where(np.abs(sample_peak_mzs - ref_peak) < tol_daltons)[0]
            if len(matches) > 0:
                best_match_idx = matches[np.argmin(np.abs(sample_peak_mzs[matches] - ref_peak))]
                sample_peak_mz = sample_peak_mzs[best_match_idx]
                shifts.append((sample_peak_mz, ref_peak - sample_peak_mz))
        if len(shifts) < 3:
            return spectrum

        shifts = np.array(shifts)
        sample_mzs_for_fit, mz_shifts_for_fit = shifts[:, 0].reshape(-1, 1), shifts[:, 1]
        try:
            ransac = RANSACRegressor(random_state=42).fit(sample_mzs_for_fit, mz_shifts_for_fit)
            mz_correction = ransac.predict(mz.reshape(-1, 1))
            return SpectrumObject(mz + mz_correction, intensity)
        except ValueError:
            # Fit or predict failed; fall back to the unaligned spectrum.
            return spectrum

# --- Data Leakage Check ---
def check_sample_uniqueness(label_dirs=None):
    """Report ``.tsv`` file names that appear more than once across class folders.

    Only file names are compared, not file contents.  When duplicates exist
    the offending names and their occurrence counts are printed.

    Args:
        label_dirs: Optional iterable of directories to scan.  Defaults to the
            module-level ``neg_dir`` and ``pos_dir``.  Entries that are not
            existing directories are ignored.

    Returns:
        Sorted list of the duplicated file names (empty when all are unique).
    """
    if label_dirs is None:
        label_dirs = (neg_dir, pos_dir)

    basenames = []
    for label_dir in label_dirs:
        if os.path.isdir(label_dir):
            basenames.extend(f for f in os.listdir(label_dir) if f.endswith(".tsv"))

    names = pd.Series(basenames, dtype=object)
    # keep=False marks every occurrence, so the printed counts are totals.
    repeated = names[names.duplicated(keep=False)]
    if not repeated.empty:
        print("Data leakage detected!")
        print(repeated.value_counts())
    else:
        print(f"No data leakage detected. {len(basenames)} unique sample files.")
    return sorted(repeated.unique())
check_sample_uniqueness()

# --- Reference Peak Loading ---
# The 'Mass' column is shifted by one proton mass to obtain the m/z targets
# used for alignment.  Loading is best effort: on any failure the reference
# list is left empty and the reason is printed.
try:
    ref_df = pd.read_csv(ribo_masslist, sep='\t', engine='python')
    proton_mass = 1.007276
    # Coerce to numbers and drop unparsable rows so a single bad line can
    # neither disable alignment nor put NaN targets into the reference list.
    ref_masses = pd.to_numeric(ref_df['Mass'], errors='coerce').dropna()
    reference_peaks = (ref_masses + proton_mass).values
    print(f"Loaded and processed {len(reference_peaks)} reference peaks for alignment.")
except Exception as e:
    reference_peaks = np.array([])
    print(f"Failed to load reference mass list, proceeding without alignment: {e}")

# --- RESTORED: Experiment parameters ---
# PersistenceTransformer flags and bin widths to sweep.
preproc_variants = [True, False]
bin_sizes = [1, 2, 3, 4, 5]

# --- Model Definitions (SVM Last) ---
# Family name -> list of estimator variants.  Insertion order is the
# iteration order, which keeps the SVM family last.
model_definitions = {
    # (n_estimators, max_depth)
    "RF": [
        RandomForestClassifier(n_estimators=n_trees, max_depth=depth, random_state=42, n_jobs=-1)
        for n_trees, depth in [(100, 10), (200, 10), (200, None), (300, 15), (100, 5)]
    ],
    # (n_estimators, learning_rate)
    "LGBM": [
        LGBMClassifier(n_estimators=n_trees, learning_rate=rate, random_state=42, verbosity=-1, n_jobs=-1, device='gpu')
        for n_trees, rate in [(100, 0.05), (200, 0.1), (200, 0.2), (300, 0.05), (100, 0.01)]
    ],
    # (n_estimators, learning_rate)
    "XGB": [
        XGBClassifier(use_label_encoder=False, eval_metric='logloss', n_estimators=n_trees, learning_rate=rate, random_state=42, n_jobs=-1, device='cuda', tree_method='hist')
        for n_trees, rate in [(100, 0.1), (200, 0.1), (200, 0.2), (300, 0.05), (100, 0.01)]
    ],
    # regularisation strength C
    "LOGREG": [
        Pipeline([("scaler", StandardScaler()), ("clf", LogisticRegression(C=c_value, max_iter=1000, random_state=42, n_jobs=-1))])
        for c_value in [0.01, 0.1, 1, 10, 100]
    ],
    # (C, kernel)
    "SVM": [
        Pipeline([("scaler", StandardScaler()), ("clf", SVC(C=c_value, kernel=kernel_name, probability=True, random_state=42))])
        for c_value, kernel_name in [(0.1, "linear"), (1, "linear"), (10, "linear"), (1, "rbf"), (10, "rbf")]
    ],
}
# Result accumulators filled by evaluate_performance.
summary = []
metrics_detailed = []

def evaluate_performance(model, name, feature_names, X_train, X_test, y_test):
    """Score a fitted classifier on one test set and record the outcome.

    The ROC curve is drawn onto the currently active matplotlib figure, the
    AUC is appended to ``summary`` and the AUC, classification report and
    confusion matrix are appended to ``metrics_detailed``.  For
    RandomForest / LightGBM / XGBoost models a SHAP beeswarm plot is also
    written to ``shap_beeswarm_<name>.png``; every other model type skips
    SHAP with a log line.

    Nothing is raised for ordinary errors: a scoring failure is printed and
    stored in ``metrics_detailed`` as an error record with ``auc`` 0.0, and a
    SHAP failure is only printed.

    Args:
        model: Fitted estimator exposing ``predict_proba`` and ``predict``.
        name: Variant label used in the legend, logs, records and the SHAP
            output file name.
        feature_names: Feature labels attached to the SHAP explanation.
        X_train: Training features, passed to ``shap.Explainer`` as
            background data.
        X_test: Features to evaluate.
        y_test: True labels for ``X_test``.
    """
    try:
        positive_scores = model.predict_proba(X_test)[:, 1]
        predicted = model.predict(X_test)
        auc = roc_auc_score(y_test, positive_scores)
        false_pos_rate, true_pos_rate, _ = roc_curve(y_test, positive_scores)
        plt.plot(false_pos_rate, true_pos_rate, label=f"{name} (AUC={auc:.3f})")
        report = classification_report(y_test, predicted, output_dict=True, zero_division=0)
        cmatrix = confusion_matrix(y_test, predicted).tolist()
        summary.append({"variant": name, "auc": auc})
        metrics_detailed.append(
            {"variant": name, "auc": auc, "report": report, "confusion_matrix": cmatrix}
        )
        print(f"{name} AUC={auc:.3f}")

        # SHAP is best-effort: it has its own handler so a failure here never
        # turns an already-recorded evaluation into an error record.
        try:
            tree_model_types = (RandomForestClassifier, LGBMClassifier, XGBClassifier)

            if not isinstance(model, tree_model_types):
                if isinstance(model, Pipeline):
                    model_type = model.named_steps['clf'].__class__.__name__
                else:
                    model_type = "Non-Tree Model"
                print(f"Skipping SHAP for slow model type: {model_type}")
            else:
                print("Using fast TreeExplainer...")
                rows_to_explain = shap.sample(X_test, 100, random_state=42)
                explanation = shap.Explainer(model, X_train)(rows_to_explain)

                # A 3-D value array carries one slice per class on the last
                # axis; keep the slice at index 1.
                if explanation.values.ndim == 3:
                    explanation = explanation[:,:,1]

                explanation.feature_names = feature_names

                # Separate figure so the beeswarm does not land on the ROC plot.
                plt.figure()
                shap.plots.beeswarm(explanation, max_display=25, show=False)
                plt.title(f"SHAP Beeswarm - {name}")
                plt.tight_layout()
                plt.savefig(f"shap_beeswarm_{name}.png")
                plt.close()

        except Exception as shap_error:
            print(f"SHAP failed for {name}: {shap_error}")

    except Exception as eval_error:
        print(f"Evaluation failed for {name}: {eval_error}")
        metrics_detailed.append({"variant": name, "auc": 0.0, "error": str(eval_error)})

def process_file(filepath, label, preproc, aligner, binner):
    """Load one spectrum file and reduce it to a binned intensity vector.

    The spectrum is read with ``SpectrumObject.from_tsv`` and then passed
    through ``preproc``, ``aligner`` and ``binner``, in that order.

    Returns:
        ``(intensity_vector, label)`` with the intensities flattened to 1-D,
        or ``None`` when the binned spectrum is empty or any step raised
        (the error is printed together with the file's base name).
    """
    try:
        spectrum = SpectrumObject.from_tsv(filepath, sep="\t")
        for step in (preproc, aligner, binner):
            spectrum = step(spectrum)
        if len(spectrum.intensity) > 0:
            return (np.asarray(spectrum.intensity).flatten(), label)
    except Exception as e:
        print(f"Skipped {os.path.basename(filepath)}: {e}")
    return None

# --- Main processing loop ---
# For every (bin width, persistence setting) pair: preprocess all spectra from
# the raw files, group them by the year prefix of the file name, then fit and
# evaluate every model variant under each train/test scenario.
for bin_step in bin_sizes:
    for use_persistence in preproc_variants:
        print(f"\n==========================================================")
        print(f"Processing Bin Size: {bin_step}, Persistence: {use_persistence}")
        print(f"==========================================================")
        
        # The pipeline objects are rebuilt per configuration because the
        # persistence flag and the bin width change between iterations.
        preproc = SequentialPreprocessor(VarStabilizer("sqrt"), Smoother(10), BaselineCorrecter("SNIP", 20), Trimmer(), PersistenceTransformer(use_persistence), Normalizer(1))
        binner = Binner(BIN_START_MZ, BIN_END_MZ, bin_step)
        aligner = Aligner(reference_peaks=reference_peaks, tolerance=500)
        
        # Every .tsv in the negative folder gets label 0, every .tsv in the
        # positive folder gets label 1; folders that do not exist are skipped.
        all_files_to_process = [{'filepath': os.path.join(folder, fname), 'label': label} for folder, label in [(neg_dir, 0), (pos_dir, 1)] if os.path.exists(folder) for fname in os.listdir(folder) if fname.endswith(".tsv")]
        print(f"Found {len(all_files_to_process)} total files. Starting parallel preprocessing...")
        
        # results[i] corresponds to all_files_to_process[i]; process_file
        # returns None for files it skipped.
        results = Parallel(n_jobs=-1, verbose=10)(delayed(process_file)(f['filepath'], f['label'], preproc, aligner, binner) for f in all_files_to_process)
        
        # Bucket the successful results by year. The year is the first entry
        # of ALL_YEARS that the file name starts with; files matching no year
        # are dropped without a message.
        data_by_year = {year: [] for year in ALL_YEARS}
        labels_by_year = {year: [] for year in ALL_YEARS}
        for i, res in enumerate(results):
            if res:
                fname = os.path.basename(all_files_to_process[i]['filepath'])
                file_year = next((year for year in ALL_YEARS if fname.startswith(year)), None)
                if file_year:
                    data_by_year[file_year].append(res[0])
                    labels_by_year[file_year].append(res[1])

        # Only years with more than 10 spectra become usable datasets,
        # stored as (feature matrix, label vector).
        datasets = {}
        for year in ALL_YEARS:
            if len(data_by_year[year]) > 10:
                datasets[year] = (np.stack(data_by_year[year]), np.array(labels_by_year[year]))
                print(f"Loaded {len(datasets[year][1])} samples for {year}.")
            else:
                print(f"Not enough data for {year}, skipping this year in tests.")

        if not datasets:
            print("No years with sufficient data. Skipping this configuration.")
            continue
            
        # Train/test year combinations. 'All' pools every available year and
        # evaluates on a random held-out split of the pooled data.
        # NOTE(review): the year strings are hard-coded here rather than
        # derived from ALL_YEARS - confirm they match.
        scenarios = [
            {'train': ['2022'], 'test': ['2022', '2023', '2024']},
            {'train': ['2023'], 'test': ['2022', '2023', '2024']},
            {'train': ['2024'], 'test': ['2022', '2023', '2024']},
            {'train': ['2022', '2023'], 'test': ['2024']},
            {'train': ['All'], 'test': ['All']}
        ]

        for scenario in scenarios:
            train_years, test_years = scenario['train'], scenario['test']
            if 'All' in train_years: train_years_str = "All"
            else: train_years_str = "+".join(train_years)
            print(f"\n\n--- Starting Scenario: TRAIN on {train_years_str} ---")

            if train_years_str == "All":
                # NOTE(review): datasets is already known to be non-empty at
                # this point (checked before the scenarios list), so this
                # guard never fires.
                if not datasets: continue
                # Pool all years, then take a stratified 80/20 split.
                all_X, all_y = np.concatenate([d[0] for d in datasets.values()]), np.concatenate([d[1] for d in datasets.values()])
                X_train, X_test_split, y_train, y_test_split = train_test_split(all_X, all_y, test_size=0.2, random_state=42, stratify=all_y)
                test_sets = {f"{train_years_str}_held_out": (X_test_split, y_test_split)}
            else:
                X_train_list, y_train_list, test_sets = [], [], {}
                for year in train_years:
                    if year in datasets and year in test_years:
                        # A year used for both training and testing is split
                        # 80/20 (stratified) so its test part is never
                        # trained on.
                        print(f"Splitting {year} for training and testing.")
                        X_year, y_year = datasets[year]
                        X_train_part, X_test_part, y_train_part, y_test_part = train_test_split(X_year, y_year, test_size=0.2, random_state=42, stratify=y_year)
                        X_train_list.append(X_train_part); y_train_list.append(y_train_part); test_sets[year] = (X_test_part, y_test_part)
                    elif year in datasets:
                        # Training-only year: use all of its samples.
                        X_train_list.append(datasets[year][0]); y_train_list.append(datasets[year][1])
                # Years that are tested but not trained on are evaluated in full.
                for year in test_years:
                    if year not in train_years and year in datasets: test_sets[year] = datasets[year]
                if not X_train_list:
                    print(f"Skipping scenario for train years {train_years_str} due to missing data.")
                    continue
                X_train, y_train = np.concatenate(X_train_list), np.concatenate(y_train_list)

            # One ROC figure per scenario; evaluate_performance() adds a curve
            # to it with plt.plot for every (model variant, test set) pair.
            plt.figure(figsize=(14, 10))
            # Feature labels of the form "<bin start>-<bin end>", one per column.
            mz_feature_names = [f"{int(BIN_START_MZ + i * bin_step)}-{int(BIN_START_MZ + (i + 1) * bin_step)}" for i in range(X_train.shape[1])]
            
            for model_type, model_variants in model_definitions.items():
                for i, model_template in enumerate(model_variants):
                    # Work on a copy so the template in model_definitions is
                    # never fitted (presumably sklearn.base.clone - it is
                    # imported outside this view).
                    model = clone(model_template)
                    print(f"\n-- Fitting {model_type}_v{i} on {train_years_str} data --")
                    try: model.fit(X_train, y_train)
                    except Exception as e:
                        print(f"Training failed for {model_type}_v{i}: {e}")
                        continue
                    # The model is fitted once and scored on every test set of the scenario.
                    for test_year_str, (X_test, y_test) in test_sets.items():
                        if len(X_test) == 0: continue
                        variant_name = f"{model_type}_train_{train_years_str}_test_{test_year_str}_bin{bin_step}_pers{use_persistence}_v{i}"
                        evaluate_performance(model, variant_name, mz_feature_names, X_train, X_test, y_test)

            # Finish the scenario's ROC figure: chance diagonal, legend outside
            # the axes, then save and close it.
            plt.plot([0, 1], [0, 1], linestyle="--", color="gray", label="Chance")
            plt.legend(fontsize='x-small', bbox_to_anchor=(1.05, 1), loc='upper left')
            plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
            plt.title(f"ROC Curves (Train: {train_years_str} / Bin: {bin_step} / Persist: {use_persistence})")
            plt.tight_layout(rect=[0, 0, 0.75, 1])
            roc_filename = f"roc_train_{train_years_str}_bin{bin_step}_pers{use_persistence}.png"
            plt.savefig(roc_filename, dpi=300, bbox_inches='tight')
            plt.close()
            print(f"ROC Plot for scenario saved to {roc_filename}")

# Persist everything evaluate_performance() accumulated across all configurations.
print("\n\n All processing complete. Saving final summaries.")
pd.DataFrame(summary).to_csv("master_auc_summary.csv", index=False)
pd.DataFrame(metrics_detailed).to_json("master_metrics_detailed.json", orient="records", indent=2)
print("\nDone: Master summary written to master_auc_summary.csv, full metrics in master_metrics_detailed.json")

In [None]:
import os
import re
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset, TensorDataset
from scipy.signal import find_peaks
from sklearn.linear_model import RANSACRegressor
import matplotlib.pyplot as plt
import shap

# --- maldi_nn imports ---
# Ensure maldi_nn is installed or its modules are accessible
from maldi_nn.spectrum import (
    SpectrumObject,
    SequentialPreprocessor,
    VarStabilizer,
    Smoother,
    BaselineCorrecter,
    Trimmer
)

#==============================================================================
# PART 1: DATA PREPARATION
#==============================================================================

class Aligner:
    """
    Preprocessing step that recalibrates a spectrum's m/z axis against a list
    of reference peaks.

    Detected sample peaks are paired with the reference peaks lying within a
    ppm tolerance; a RANSAC regression of the m/z shift on the sample m/z then
    gives the correction that is added to every m/z value.
    """

    def __init__(self, reference_peaks, tolerance=500):
        self.reference_peaks = np.array(reference_peaks)
        # Matching window in parts per million of the reference m/z.
        self.tolerance = tolerance

    def __call__(self, spectrum: SpectrumObject) -> SpectrumObject:
        """Return a new, aligned spectrum, or ``spectrum`` itself when fewer
        than 5 peaks are detected, fewer than 3 reference matches are found,
        or the regression raises ``ValueError``."""
        min_height = 0.05 * np.max(spectrum.intensity)
        peak_idx, _ = find_peaks(spectrum.intensity, height=min_height, prominence=0.01)
        if len(peak_idx) < 5:
            return spectrum

        observed_mz = spectrum.mz[peak_idx]
        pairs = []  # (observed m/z, shift needed to reach the reference)
        for reference in self.reference_peaks:
            window = self.tolerance * reference / 1e6
            distance = np.abs(observed_mz - reference)
            candidates = np.where(distance < window)[0]
            if len(candidates) == 0:
                continue
            # Several peaks may fall in the window; keep the closest one.
            nearest_mz = observed_mz[candidates[np.argmin(distance[candidates])]]
            pairs.append((nearest_mz, reference - nearest_mz))

        if len(pairs) < 3:
            return spectrum

        pairs = np.array(pairs)
        fit_mz = pairs[:, 0].reshape(-1, 1)
        fit_shift = pairs[:, 1]

        try:
            regressor = RANSACRegressor(random_state=42).fit(fit_mz, fit_shift)
            correction = regressor.predict(spectrum.mz.reshape(-1, 1))
            return SpectrumObject(spectrum.mz + correction, spectrum.intensity)
        except ValueError:
            return spectrum

def get_preprocessor(reference_file, mz_range=(2000, 12000)):
    """Build the preprocessing pipeline, including the RANSAC ``Aligner``.

    The reference mass list is a tab-separated file with a ``Mass`` column.
    Lines starting with ``#`` are treated as comments and masses that cannot
    be parsed as numbers are dropped, so a commented or partially filled mass
    list still yields clean reference peaks.  Each mass is shifted by one
    proton mass before being handed to the aligner.

    Args:
        reference_file: Path to the reference mass list.
        mz_range: ``(low, high)`` pair unpacked into ``Trimmer``.

    Returns:
        A ``SequentialPreprocessor`` (sqrt variance stabilisation, smoothing,
        SNIP baseline correction, alignment, trimming), or ``None`` when the
        reference file cannot be read or has no ``Mass`` column; the reason
        is printed.
    """
    try:
        ref_df = pd.read_csv(reference_file, sep='\t', engine='python', comment='#')
        proton_mass = 1.007276 # Use precise proton mass for [M+H]+
        # Coerce first: stray text or blank cells would otherwise break the
        # addition below or leak NaN peaks into the aligner.
        masses = pd.to_numeric(ref_df['Mass'], errors='coerce').dropna()
        reference_peaks = (masses + proton_mass).values
        print(f"Loaded {len(reference_peaks)} reference peaks for alignment.")
    except Exception as e:
        print(f"FATAL: Could not load or parse reference peak file '{reference_file}': {e}")
        return None

    return SequentialPreprocessor(
        VarStabilizer(method="sqrt"),
        Smoother(halfwindow=10),
        BaselineCorrecter(method="SNIP", snip_n_iter=20),
        Aligner(reference_peaks=reference_peaks, tolerance=500),
        Trimmer(*mz_range)
    )

def prepare_cnn_dataset(neg_dir, pos_dir, preprocessor, fixed_len=10000):
    """Build a fixed-length tensor dataset from two directories of spectra.

    Every ``.tsv`` file in ``neg_dir`` (label 0) and ``pos_dir`` (label 1) is
    loaded as a two-column table, preprocessed, and linearly interpolated onto
    ``fixed_len`` evenly spaced points between that spectrum's own minimum and
    maximum m/z.  File names must start with a 4-digit year.  Files that fail
    at any step are reported and skipped.

    Returns:
        ``(TensorDataset, years)`` where the feature tensor has shape
        ``(n_spectra, 1, fixed_len)`` and ``years`` is an integer array with
        one entry per spectrum, or ``(None, None)`` when nothing could be
        processed.
    """
    print("--- Starting Data Preparation with RANSAC Alignment ---")
    vectors, targets, file_years = [], [], []

    for folder, class_label in ((neg_dir, 0), (pos_dir, 1)):
        if not os.path.exists(folder):
            print(f"Directory not found, skipping: {folder}")
            continue
        print(f"Processing files from '{os.path.basename(folder)}' with label {class_label}...")
        tsv_names = (entry for entry in os.listdir(folder) if entry.endswith(".tsv"))
        for fname in tsv_names:
            file_path = os.path.join(folder, fname)
            try:
                year_prefix = re.match(r'^(\d{4})', fname)
                if year_prefix is None:
                    raise ValueError("Filename does not start with a 4-digit year.")
                year = int(year_prefix.group(1))

                raw = np.loadtxt(file_path, delimiter="\t")
                processed = preprocessor(SpectrumObject(raw[:, 0], raw[:, 1]))
                if processed is None or len(processed.mz) < 2:
                    raise ValueError("Spectrum empty after preprocessing.")

                # Resample onto this spectrum's own [min, max] m/z span.
                grid = np.linspace(processed.mz.min(), processed.mz.max(), fixed_len)
                vec = np.interp(grid, processed.mz, processed.intensity).astype(np.float32)
                if not np.all(np.isfinite(vec)):
                    raise ValueError("NaN or Inf in vector.")

                vectors.append(vec)
                targets.append(class_label)
                file_years.append(year)
            except Exception as e:
                print(f"Skipped '{file_path}': {e}")

    if not vectors:
        print("FATAL: No spectra were successfully processed.")
        return None, None

    print(f"\n Successfully processed {len(vectors)} total spectra.")
    # unsqueeze(1) adds the single input channel expected by Conv1d.
    x_tensor = torch.tensor(np.array(vectors), dtype=torch.float32).unsqueeze(1)
    y_tensor = torch.tensor(targets, dtype=torch.long)
    return TensorDataset(x_tensor, y_tensor), np.array(file_years)

#==============================================================================
# PART 2: CNN MODEL DEFINITIONS
#==============================================================================

class SimpleCNN(nn.Module):
    """Three-stage 1-D CNN with a two-layer classifier head.

    Each stage is Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) -> Dropout.
    """

    def __init__(self, input_length=10000, num_classes=2):
        super().__init__()
        # (in_channels, out_channels, kernel_size, dropout) per stage; a
        # padding of kernel_size // 2 keeps the length unchanged before pooling.
        stage_specs = ((1, 8, 7, 0.25), (8, 16, 5, 0.35), (16, 32, 3, 0.5))
        layers = []
        for c_in, c_out, kernel, p_drop in stage_specs:
            layers.extend([
                nn.Conv1d(c_in, c_out, kernel, padding=kernel // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
                nn.MaxPool1d(2),
                nn.Dropout(p_drop),
            ])
        self.features = nn.Sequential(*layers)

        # Push a dummy input through the feature extractor to size the first
        # Linear layer.
        with torch.no_grad():
            probe = self.features(torch.zeros(1, 1, input_length))
        self.flattened_size = probe.numel()

        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        )

    def forward(self, x):
        feats = self.features(x)
        return self.classifier(feats.view(feats.size(0), -1))

class MaldiCNN(nn.Module):
    """Four-stage 1-D CNN (16 -> 128 channels) with a 512-unit classifier head."""

    @staticmethod
    def _stage(c_in, c_out, kernel, p_drop):
        """Return the layers of one stage: Conv1d, BatchNorm1d, ReLU, MaxPool1d(2), Dropout."""
        return [
            nn.Conv1d(c_in, c_out, kernel, padding=(kernel - 1) // 2),
            nn.BatchNorm1d(c_out),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(p_drop),
        ]

    def __init__(self, input_length=10000, num_classes=2):
        super().__init__()
        self.features = nn.Sequential(
            *self._stage(1, 16, 7, 0.2),
            *self._stage(16, 32, 5, 0.3),
            *self._stage(32, 64, 3, 0.4),
            *self._stage(64, 128, 3, 0.5),
        )

        # The size of the first Linear layer is measured with a dummy forward pass.
        with torch.no_grad():
            dummy_out = self.features(torch.zeros(1, 1, input_length))
        self.flattened_size = dummy_out.numel()

        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        hidden = self.features(x)
        hidden = hidden.view(hidden.size(0), -1)
        return self.classifier(hidden)

class DeeperMaldiCNN(nn.Module):
    """Five-stage 1-D CNN (16 -> 256 channels) with a three-layer classifier head."""

    def __init__(self, input_length=10000, num_classes=2):
        super().__init__()
        widths = (1, 16, 32, 64, 128, 256)
        kernels = (9, 7, 5, 3, 3)
        dropouts = (0.2, 0.3, 0.4, 0.5, 0.5)

        stack = []
        for c_in, c_out, kernel, p_drop in zip(widths[:-1], widths[1:], kernels, dropouts):
            # Odd kernel with padding kernel // 2: only the pooling halves the length.
            stack += [
                nn.Conv1d(c_in, c_out, kernel, padding=kernel // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
                nn.MaxPool1d(2),
                nn.Dropout(p_drop),
            ]
        self.features = nn.Sequential(*stack)

        # A dummy forward pass gives the flattened feature size.
        with torch.no_grad():
            self.flattened_size = self.features(torch.zeros(1, 1, input_length)).numel()

        head = [
            nn.Linear(self.flattened_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        encoded = self.features(x)
        return self.classifier(encoded.view(encoded.size(0), -1))

class ResidualBlock(nn.Module):
    """Basic 1-D residual block: two 3-wide convolutions plus a skip connection.

    The skip path is the identity unless ``stride != 1`` or the channel count
    changes, in which case it is a 1x1 convolution followed by BatchNorm1d.

    The two activations are separate, out-of-place ``nn.ReLU`` modules and the
    skip connection is added out of place.  PyTorch does not allow a module's
    inputs or outputs to be modified in place while backward hooks are
    registered, and hook-based attribution such as the SHAP DeepExplainer run
    in this file keeps state per module, so one ReLU instance must not serve
    two call sites.  ReLU holds no parameters, so existing checkpoints load
    unchanged and the forward output is numerically identical.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution and of the 1x1 shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Second activation gets its own module (see class docstring).
        self.relu_out = nn.ReLU()
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity
        return self.relu_out(out)

class MaldiResNet(nn.Module):
    """ResNet-style 1-D classifier: strided stem, three residual stages
    (16, 32, 64 channels, two blocks each), global average pooling and a
    linear head.

    ``input_length`` is accepted for signature parity with the other models
    but is not used; the adaptive average pool makes the head independent of
    the input length.

    The stem activation is an out-of-place ``nn.ReLU``.  An in-place ReLU
    would overwrite the BatchNorm output, which PyTorch forbids while backward
    hooks are registered (as the SHAP DeepExplainer in this file does).  ReLU
    has no parameters, so checkpoints and outputs are unaffected.

    Args:
        input_length: Unused, see above.
        num_classes: Number of output logits.
    """

    def __init__(self, input_length=10000, num_classes=2):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_channels = 16
        self.conv1 = nn.Conv1d(1, 16, 7, 2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(3, 2, padding=1)
        self.layer1 = self._make_layer(16, 2, stride=1)
        self.layer2 = self._make_layer(32, 2, stride=2)
        self.layer3 = self._make_layer(64, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(ResidualBlock(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class DeeperMaldiResNet(nn.Module):
    """Deeper ResNet-style 1-D classifier: strided stem and four residual
    stages of 3, 4, 6 and 3 blocks (32, 64, 128, 256 channels), followed by
    global average pooling and a linear head.

    ``input_length`` is accepted for signature parity with the other models
    but is not used; the adaptive average pool makes the head independent of
    the input length.

    The stem activation is an out-of-place ``nn.ReLU``.  An in-place ReLU
    would overwrite the BatchNorm output, which PyTorch forbids while backward
    hooks are registered (as the SHAP DeepExplainer in this file does).  ReLU
    has no parameters, so checkpoints and outputs are unaffected.

    Args:
        input_length: Unused, see above.
        num_classes: Number of output logits.
    """

    def __init__(self, input_length=10000, num_classes=2):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_channels = 32
        self.conv1 = nn.Conv1d(1, 32, 7, 2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(32)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(3, 2, padding=1)
        self.layer1 = self._make_layer(32, 3, stride=1)
        self.layer2 = self._make_layer(64, 4, stride=2)
        self.layer3 = self._make_layer(128, 6, stride=2)
        self.layer4 = self._make_layer(256, 3, stride=2)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first one uses ``stride``."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(ResidualBlock(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)


#==============================================================================
# PART 3: MAIN EXECUTION BLOCK
#==============================================================================
if __name__ == '__main__':
    # Loads and preprocesses every spectrum, restores the best checkpoint and
    # draws a SHAP beeswarm summary for the first batch of validation samples.
    # Names are kept at module level so that later notebook cells can reuse
    # them (e.g. full_dataset).

    # --- Configuration ---
    neg_dir = r"Y:\\test_set\\allspectra\\neg_spectra\\neg_tsv"
    pos_dir = r"Y:\\test_set\\allspectra\\pos_spectra\\pos_tsv"
    ribo_masslist = r"Y:\\test_set\\ribo_Saureus.tsv"
    MZ_RANGE = (2000, 12000)
    VECTOR_LENGTH = 10000
    BEST_MODEL_PATH = 'best_model.pth' # Assumes a pre-trained model exists
    
    # --- STEP 1: DATA PREPARATION ---
    preprocessor = get_preprocessor(ribo_masslist, mz_range=MZ_RANGE)
    if preprocessor is not None:
        full_dataset, years_array = prepare_cnn_dataset(
            neg_dir=neg_dir, pos_dir=pos_dir, preprocessor=preprocessor, fixed_len=VECTOR_LENGTH
        )
        if full_dataset is not None:
            print("\n--- Data Shapes ---")
            print(f"Features (X) tensor shape: {full_dataset.tensors[0].shape}")
            print(f"Labels   (y) tensor shape: {full_dataset.tensors[1].shape}")
            print(f"Unique years found: {np.unique(years_array)}")
            print("\n Data preparation complete.")
            
            # --- STEP 2: SHAP ANALYSIS ---
            print("\n\n" + "="*25 + " SHAP EXPLAINABILITY ANALYSIS " + "="*25)
            if not os.path.exists(BEST_MODEL_PATH):
                print(f"Could not perform SHAP analysis: The model file '{BEST_MODEL_PATH}' was not found.")
                print("Please train a model and save it to this path first.")
            else:
                device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
                model_zoo = {
                    "SimpleCNN": SimpleCNN, "MaldiCNN": MaldiCNN, "DeeperMaldiCNN": DeeperMaldiCNN,
                    "MaldiResNet": MaldiResNet, "DeeperMaldiResNet": DeeperMaldiResNet
                }
                
                print(f"Loading best model from '{BEST_MODEL_PATH}'...")
                # NOTE: torch.load unpickles the file - only load checkpoints
                # that come from a trusted source.
                checkpoint = torch.load(BEST_MODEL_PATH, map_location=device)
                model_name = checkpoint.get('model_name')
                if not model_name or model_name not in model_zoo:
                    print(f"Error: Model name '{model_name}' not found in checkpoint or model zoo.")
                else:
                    model_class = model_zoo[model_name]
                    num_classes = len(torch.unique(full_dataset.tensors[1]))
                    model = model_class(VECTOR_LENGTH, num_classes).to(device)
                    model.load_state_dict(checkpoint['model_state_dict'])
                    model.eval()

                    # The index collections may be lists, arrays or tensors,
                    # so emptiness is tested with len() rather than
                    # truthiness (ambiguous for multi-element arrays).
                    val_indices = checkpoint.get('val_indices', [])
                    train_indices = checkpoint.get('train_indices', [])
                    if val_indices is None or len(val_indices) == 0:
                        print("Error: 'val_indices' not found in checkpoint. Cannot select data for explanation.")
                    elif train_indices is None or len(train_indices) == 0:
                        # Without training rows the background loader would be
                        # empty and next(iter(...)) would raise StopIteration.
                        print("Error: 'train_indices' not found in checkpoint. Cannot build the SHAP background set.")
                    else:
                        test_subset = Subset(full_dataset, val_indices)
                        # At most the first 100 training samples form the background.
                        background_subset = Subset(full_dataset, train_indices[:100])
                        
                        test_loader = DataLoader(test_subset, batch_size=64)
                        background_loader = DataLoader(background_subset, batch_size=100)
                        
                        # Only the first batch of each loader is explained / used as background.
                        test_tensors, _ = next(iter(test_loader))
                        background_tensors, _ = next(iter(background_loader))
                        
                        print("Initializing DeepExplainer...")
                        explainer = shap.DeepExplainer(model, background_tensors.to(device))
                        print(f"Calculating SHAP values for {test_tensors.shape[0]} test samples...")
                        
                        shap_values_raw = explainer.shap_values(test_tensors.to(device))
                        
                        # Depending on the SHAP version the result is either a
                        # list with one (N, 1, L) array per class or a single
                        # array with the classes on the last axis, (N, 1, L, C).
                        # Either way keep the positive class (index 1) and
                        # drop the channel dimension.
                        if isinstance(shap_values_raw, (list, tuple)):
                            shap_values_class1 = np.asarray(shap_values_raw[1]).squeeze(1)
                        else:
                            shap_values_class1 = np.asarray(shap_values_raw)[..., 1].squeeze(1)

                        # NOTE(review): prepare_cnn_dataset resamples each
                        # spectrum onto its own [min, max] m/z span, so these
                        # labels on the nominal MZ_RANGE grid are approximate.
                        feature_names = [f"m/z {mz:.1f}" for mz in np.linspace(MZ_RANGE[0], MZ_RANGE[1], VECTOR_LENGTH)]
                        
                        print("\nGenerating SHAP summary plot (beeswarm)...")
                        shap.summary_plot(shap_values_class1, features=test_tensors.squeeze(1).cpu().numpy(), feature_names=feature_names, max_display=25)

In [None]:
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
import copy
import pandas as pd
import shap
import os
try:
    from scipy.signal import find_peaks
except ImportError:
    print("Please install scipy for peak detection: pip install scipy")

    # Stand-in with the same call shape, so the script keeps running without
    # scipy instead of crashing.
    def find_peaks(*args, **kwargs):
        """Fallback for a missing scipy: warn and report no peaks."""
        print("WARNING: scipy.signal.find_peaks is not available. Peak detection will be skipped.")
        no_peaks, no_properties = np.array([]), {}
        return no_peaks, no_properties


#==============================================================================
# 1. RE-DEFINE MODEL ZOO (This must be consistent with the training script)
#==============================================================================

class SimpleCNN(nn.Module):
    """Three-stage 1-D CNN (8, 16, 32 channels) with a 256-unit classifier head."""

    @staticmethod
    def _conv_stage(c_in, c_out, kernel, p_drop):
        """Return one stage as a layer list: Conv1d, BatchNorm1d, ReLU, MaxPool1d(2), Dropout."""
        return [
            nn.Conv1d(c_in, c_out, kernel, padding=(kernel - 1) // 2),
            nn.BatchNorm1d(c_out),
            nn.ReLU(),
            nn.MaxPool1d(2),
            nn.Dropout(p_drop),
        ]

    def __init__(self, input_length, num_classes):
        super().__init__()
        self.features = nn.Sequential(
            *self._conv_stage(1, 8, 7, 0.25),
            *self._conv_stage(8, 16, 5, 0.35),
            *self._conv_stage(16, 32, 3, 0.5),
        )

        # The input size of the first Linear layer is measured with a dummy forward pass.
        with torch.no_grad():
            dummy_out = self.features(torch.zeros(1, 1, input_length))
        self.flattened_size = dummy_out.numel()

        head = [
            nn.Linear(self.flattened_size, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        feats = self.features(x)
        return self.classifier(feats.view(feats.size(0), -1))

class MaldiCNN(nn.Module):
    """Four-stage 1-D CNN (16 -> 128 channels) with a 512-unit classifier head.

    Each stage is Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) -> Dropout.
    """

    def __init__(self, input_length, num_classes):
        super().__init__()
        # (in_channels, out_channels, kernel_size, dropout) per stage.
        stage_specs = (
            (1, 16, 7, 0.2),
            (16, 32, 5, 0.3),
            (32, 64, 3, 0.4),
            (64, 128, 3, 0.5),
        )
        layers = []
        for c_in, c_out, kernel, p_drop in stage_specs:
            # Odd kernel with padding kernel // 2: only the pooling halves the length.
            layers.append(nn.Conv1d(c_in, c_out, kernel, padding=kernel // 2))
            layers.append(nn.BatchNorm1d(c_out))
            layers.append(nn.ReLU())
            layers.append(nn.MaxPool1d(2))
            layers.append(nn.Dropout(p_drop))
        self.features = nn.Sequential(*layers)

        # Dummy forward pass to size the first Linear layer.
        with torch.no_grad():
            probe = self.features(torch.zeros(1, 1, input_length))
        self.flattened_size = probe.numel()

        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        hidden = self.features(x)
        flat = hidden.view(hidden.size(0), -1)
        return self.classifier(flat)

class DeeperMaldiCNN(nn.Module):
    """Five-stage 1-D CNN (16 -> 256 channels) with a three-layer classifier head."""

    def __init__(self, input_length, num_classes):
        super().__init__()

        def stage(c_in, c_out, kernel, p_drop):
            # Conv1d -> BatchNorm1d -> ReLU -> MaxPool1d(2) -> Dropout
            return (
                nn.Conv1d(c_in, c_out, kernel, padding=kernel // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
                nn.MaxPool1d(2),
                nn.Dropout(p_drop),
            )

        self.features = nn.Sequential(
            *stage(1, 16, 9, 0.2),
            *stage(16, 32, 7, 0.3),
            *stage(32, 64, 5, 0.4),
            *stage(64, 128, 3, 0.5),
            *stage(128, 256, 3, 0.5),
        )

        # A dummy forward pass gives the flattened feature size.
        with torch.no_grad():
            dummy_out = self.features(torch.zeros(1, 1, input_length))
        self.flattened_size = dummy_out.numel()

        self.classifier = nn.Sequential(
            nn.Linear(self.flattened_size, 1024),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        encoded = self.features(x)
        return self.classifier(encoded.view(encoded.size(0), -1))

class ResidualBlock(nn.Module):
    """Basic 1-D residual block: two 3-wide convolutions plus a skip connection.

    The skip path is the identity unless ``stride != 1`` or the channel count
    changes, in which case it is a 1x1 convolution followed by BatchNorm1d.

    Both activations are separate, out-of-place ``nn.ReLU`` modules and the
    skip connection is added without mutating ``out`` in place.  This script
    exists to run SHAP on the model: PyTorch does not allow a module's inputs
    or outputs to be modified in place while backward hooks are registered,
    and hook-based explainers keep state per module, so a single ReLU
    instance must not serve two call sites.  ReLU holds no parameters, so
    existing checkpoints load unchanged and the forward output is numerically
    identical.

    Args:
        in_channels: Channels of the input tensor.
        out_channels: Channels produced by the block.
        stride: Stride of the first convolution and of the 1x1 shortcut.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, 3, stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv1d(out_channels, out_channels, 3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm1d(out_channels)
        # Dedicated module for the post-addition activation (see class docstring).
        self.relu_out = nn.ReLU()
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv1d(in_channels, out_channels, 1, stride, bias=False),
                nn.BatchNorm1d(out_channels),
            )

    def forward(self, x):
        identity = self.shortcut(x)
        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out = out + identity
        return self.relu_out(out)

class MaldiResNet(nn.Module):
    """ResNet-style 1-D classifier: strided stem, three residual stages
    (16, 32, 64 channels, two blocks each), global average pooling and a
    linear head.

    ``input_length`` is part of the shared model signature but is not used;
    the adaptive average pool makes the head independent of the input length.

    The stem activation is an out-of-place ``nn.ReLU``.  An in-place ReLU
    would overwrite the BatchNorm output, which PyTorch forbids while backward
    hooks are registered, and this script exists to run SHAP on the model.
    ReLU has no parameters, so checkpoints and outputs are unaffected.

    Args:
        input_length: Unused, see above.
        num_classes: Number of output logits.
    """

    def __init__(self, input_length, num_classes):
        super().__init__()
        # Running channel count consumed and updated by _make_layer.
        self.in_channels = 16
        self.conv1 = nn.Conv1d(1, 16, 7, 2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(16)
        self.relu = nn.ReLU()
        self.maxpool = nn.MaxPool1d(3, 2, padding=1)
        self.layer1 = self._make_layer(16, 2, stride=1)
        self.layer2 = self._make_layer(32, 2, stride=2)
        self.layer3 = self._make_layer(64, 2, stride=2)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(64, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        """Stack ``num_blocks`` residual blocks; only the first one uses ``stride``."""
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for s in strides:
            layers.append(ResidualBlock(self.in_channels, out_channels, s))
            self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.maxpool(self.relu(self.bn1(self.conv1(x))))
        for stage in (self.layer1, self.layer2, self.layer3):
            x = stage(x)
        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

class DeeperMaldiResNet(nn.Module):
    """Deeper 1-D ResNet classifier.

    A strided convolutional stem and max-pool are followed by four residual
    stages of 32, 64, 128 and 256 channels containing 3, 4, 6 and 3
    ``ResidualBlock``s respectively, then global average pooling and one
    linear layer producing ``num_classes`` scores.
    """

    def __init__(self, input_length, num_classes):
        """Assemble the network.

        Args:
            input_length: Accepted for interface compatibility; it is not
                referenced when building the layers.
            num_classes: Size of the final linear layer's output.
        """
        super().__init__()
        stem_width = 32
        # Running channel count consumed and updated by ``_make_layer``.
        self.in_channels = stem_width

        # Stem: single-channel input -> 32 feature maps, downsampled twice.
        self.conv1 = nn.Conv1d(1, stem_width, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm1d(stem_width)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool1d(kernel_size=3, stride=2, padding=1)

        # Residual stages; every stage after the first halves the length.
        self.layer1 = self._make_layer(32, 3, stride=1)
        self.layer2 = self._make_layer(64, 4, stride=2)
        self.layer3 = self._make_layer(128, 6, stride=2)
        self.layer4 = self._make_layer(256, 3, stride=2)

        # Head: collapse the length axis, then classify.
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(256, num_classes)

    def _make_layer(self, out_channels, num_blocks, stride):
        """Build one residual stage and advance ``self.in_channels``.

        Only the first block of the stage uses ``stride``; the remaining
        ``num_blocks - 1`` blocks use stride 1.
        """
        stage = []
        for block_stride in [stride, *([1] * (num_blocks - 1))]:
            stage.append(ResidualBlock(self.in_channels, out_channels, block_stride))
            # Every following block starts from this stage's width.
            self.in_channels = out_channels
        return nn.Sequential(*stage)

    def forward(self, x):
        """Return the raw class scores for a batch of single-channel inputs."""
        features = self.conv1(x)
        features = self.bn1(features)
        features = self.relu(features)
        features = self.maxpool(features)

        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)

        pooled = self.avgpool(features)
        # Flatten (batch, channels, 1) -> (batch, channels) for the classifier.
        flat = pooled.view(pooled.size(0), -1)
        return self.fc(flat)

#==============================================================================
# 2. SHAP ANALYSIS EXECUTION (Refactored for Robustness)
#==============================================================================
if __name__ == '__main__':
    # Entry point for the SHAP post-processing step.
    #
    # Reads the saved best-model checkpoint (for its name and validation
    # indices) and the previously exported SHAP / feature-value CSVs, then:
    #   Part 5: finds positive and negative peaks in the mean SHAP profile and
    #           writes one plot + one CSV table per sign;
    #   Part 6: merges both tables and overlays the mean class spectra on the
    #           mean SHAP profile.
    # Requires `full_dataset` and `years_array` to already exist in the
    # session (created by "Script 1").
    BEST_MODEL_PATH = 'best_model.pth'
    VECTOR_LENGTH = 10000
    SHAP_VALUES_PATH = 'shap_values.csv'
    FEATURE_VALUES_PATH = 'feature_values_for_shap.csv'
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Maps the model name stored in the checkpoint to its class.
    model_zoo = {
        "SimpleCNN": SimpleCNN,
        "MaldiCNN_4_Layer": MaldiCNN,
        "DeeperMaldiCNN_5_Layer": DeeperMaldiCNN,
        "MaldiResNet_Shazam": MaldiResNet,
        "DeeperMaldiResNet": DeeperMaldiResNet
    }

    print("\n\n" + "="*25 + " SHAP EXPLAINABILITY ANALYSIS " + "="*25)

    if 'full_dataset' not in locals() or 'years_array' not in locals():
        print("Prerequisite 'full_dataset' or 'years_array' not found. Please run Script 1 first.")
    elif os.path.exists(BEST_MODEL_PATH):
        print(f"Loading best model from '{BEST_MODEL_PATH}'...")

        # SECURITY NOTE: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoint files that were produced by this pipeline.
        checkpoint = torch.load(BEST_MODEL_PATH, map_location=device, weights_only=False)
        model_name = checkpoint['model_name']
        model_class = model_zoo.get(model_name)

        if model_class is None:
            print(f"Error: Model name '{model_name}' from saved file is not defined in the model_zoo.")
        else:
            # The SHAP values themselves are not computed here (Part 1 would
            # do that); both CSV exports must already exist. Check up front so
            # a missing file yields a clear message instead of an uncaught
            # FileNotFoundError from pandas.
            missing_inputs = [
                path for path in (SHAP_VALUES_PATH, FEATURE_VALUES_PATH)
                if not os.path.exists(path)
            ]

            if missing_inputs:
                print(
                    "Could not perform SHAP analysis: required input file(s) not found: "
                    + ", ".join(f"'{path}'" for path in missing_inputs)
                    + ". Run the SHAP calculation step first."
                )
            else:
                # --- Load data for plotting ---
                print("\n--- Loading data for analysis ---")
                shap_df = pd.read_csv(SHAP_VALUES_PATH, index_col=0)
                feature_values_df = pd.read_csv(FEATURE_VALUES_PATH, index_col=0)

                # Reset the Part 5 outputs so Part 6 can tell whether Part 5
                # succeeded in THIS run (and never picks up stale tables left
                # over from an earlier notebook execution).
                pos_peaks_df = None
                neg_peaks_df = None

                # --- Part 5: Advanced Peak Analysis and Mirror Plot (MODIFIED) ---
                try:
                    print("\n" + "="*25 + " PEAK ANALYSIS & PLOTTING " + "="*25)

                    # Rows are features; average each feature's SHAP value
                    # over all columns.
                    mean_shaps = shap_df.mean(axis=1)
                    # The m/z value is parsed as the second space-separated
                    # token of each row label.
                    mz_values = np.array([float(s.split(' ')[1]) for s in shap_df.index])

                    print("Finding positive and negative peaks...")
                    pos_peaks_idx, pos_properties = find_peaks(mean_shaps.values, prominence=0.005, distance=50)
                    # Negative peaks are the peaks of the sign-flipped profile.
                    neg_peaks_idx, neg_properties = find_peaks(-mean_shaps.values, prominence=0.005, distance=50)
                    print(f"Found {len(pos_peaks_idx)} positive peaks and {len(neg_peaks_idx)} negative peaks.")

                    # --- Mirror Plot ---
                    # (This can be kept or removed, as the separate plots below are more detailed)

                    # --- Positive Plot and Table ---
                    print("\n--- Generating Positive Impact Plot and Table ---")
                    plt.figure(figsize=(20, 6))
                    plt.plot(mz_values, mean_shaps.where(mean_shaps > 0), color='red', alpha=0.7)
                    plt.scatter(mz_values[pos_peaks_idx], mean_shaps.values[pos_peaks_idx], color='maroon', s=50, zorder=5, marker='X')
                    # Label only the ten most prominent peaks to keep the plot readable.
                    top_pos_indices = np.argsort(pos_properties['prominences'])[-10:]
                    for i in pos_peaks_idx[top_pos_indices]:
                        plt.text(mz_values[i], mean_shaps.values[i], f' {mz_values[i]:.1f}', verticalalignment='bottom', fontsize=9)
                    plt.title(f'Positive SHAP Impact (-> Class 1) for {model_name}', fontsize=16)
                    plt.xlabel("m/z", fontsize=12)
                    plt.ylabel("Mean SHAP Value", fontsize=12)
                    plt.grid(True, linestyle=':', alpha=0.6)
                    plt.tight_layout()
                    save_path = f"{model_name}_shap_positive_peaks_plot.png"
                    plt.savefig(save_path, dpi=150)
                    print(f"Saved Positive Peaks Plot to: {save_path}")
                    plt.show()

                    # Create, print, and save the corresponding table for ALL positive peaks
                    pos_table = pd.DataFrame({
                        'm/z': mz_values[pos_peaks_idx],
                        'Mean SHAP Value': mean_shaps.values[pos_peaks_idx],
                        'Prominence': pos_properties['prominences']
                    }).sort_values(by='Prominence', ascending=False).reset_index(drop=True)

                    pos_table_path = 'positive_peaks_table.csv'
                    pos_table.to_csv(pos_table_path, index=False)
                    print(f"Table of all {len(pos_table)} positive peaks saved to '{pos_table_path}'")
                    print("--- Table for Positive Peaks (Sorted by Prominence) ---")
                    print(pos_table.to_string())

                    # --- Negative Plot and Table ---
                    print("\n--- Generating Negative Impact Plot and Table ---")
                    plt.figure(figsize=(20, 6))
                    plt.plot(mz_values, mean_shaps.where(mean_shaps < 0), color='royalblue', alpha=0.7)
                    plt.scatter(mz_values[neg_peaks_idx], mean_shaps.values[neg_peaks_idx], color='navy', s=50, zorder=5, marker='X')
                    top_neg_indices = np.argsort(neg_properties['prominences'])[-10:]
                    for i in neg_peaks_idx[top_neg_indices]:
                        plt.text(mz_values[i], mean_shaps.values[i], f' {mz_values[i]:.1f}', verticalalignment='top', fontsize=9)
                    plt.title(f'Negative SHAP Impact (-> Class 0) for {model_name}', fontsize=16)
                    plt.xlabel("m/z", fontsize=12)
                    plt.ylabel("Mean SHAP Value", fontsize=12)
                    plt.grid(True, linestyle=':', alpha=0.6)
                    plt.tight_layout()
                    save_path = f"{model_name}_shap_negative_peaks_plot.png"
                    plt.savefig(save_path, dpi=150)
                    print(f"Saved Negative Peaks Plot to: {save_path}")
                    plt.show()

                    # Create, print, and save the corresponding table for ALL negative peaks
                    neg_table = pd.DataFrame({
                        'm/z': mz_values[neg_peaks_idx],
                        'Mean SHAP Value': mean_shaps.values[neg_peaks_idx],
                        'Prominence': neg_properties['prominences']
                    }).sort_values(by='Prominence', ascending=False).reset_index(drop=True)

                    neg_table_path = 'negative_peaks_table.csv'
                    neg_table.to_csv(neg_table_path, index=False)
                    print(f"Table of all {len(neg_table)} negative peaks saved to '{neg_table_path}'")
                    print("--- Table for Negative Peaks (Sorted by Prominence) ---")
                    print(neg_table.to_string())

                    # Publish the tables for Part 6 only once BOTH halves of
                    # the peak analysis have completed.
                    pos_peaks_df = pos_table
                    neg_peaks_df = neg_table

                except NameError:
                    print("Could not generate peak analysis plots because `find_peaks` is not available. Is `scipy` installed?")
                except Exception as e:
                    print(f"Could not generate peak analysis plots: {e}")

                # --- Part 6: FINAL SYNTHESIS: COMBINED PLOT AND PEAK TABLE ---
                if pos_peaks_df is None or neg_peaks_df is None:
                    # Part 6 needs the tables, mean_shaps and mz_values from Part 5.
                    print("Could not generate final synthesis plots: the peak analysis step did not complete.")
                else:
                    try:
                        print("\n" + "="*25 + " FINAL SYNTHESIS " + "="*25)

                        # --- Combined Summary Table ---
                        print("\nCreating combined peak summary table...")
                        all_peaks_df = pd.concat([
                            pos_peaks_df.assign(**{'Class Impact': '-> Class 1'}),
                            neg_peaks_df.assign(**{'Class Impact': '-> Class 0'})
                        ]).sort_values(by='Prominence', ascending=False).reset_index(drop=True)

                        peak_table_path = 'identified_peaks_summary.csv'
                        all_peaks_df.to_csv(peak_table_path, index=False)
                        print(f"Combined peak summary table saved to '{peak_table_path}'")
                        print("Top 15 Most Prominent Peaks (Combined):")
                        print(all_peaks_df.head(15).to_string())

                        # --- Combined Overlay Plot ---
                        print("\nGenerating Combined Spectrum vs. SHAP Plot...")
                        all_labels = full_dataset.tensors[1].cpu().numpy()
                        val_indices = checkpoint['val_indices']
                        val_labels = all_labels[val_indices]

                        # NOTE(review): the boolean column masks assume the
                        # columns of feature_values_df are the validation
                        # samples in `val_indices` order - confirm against
                        # the SHAP export step.
                        mean_class1_spectrum = feature_values_df.iloc[:, val_labels == 1].mean(axis=1)
                        mean_class0_spectrum = feature_values_df.iloc[:, val_labels == 0].mean(axis=1)

                        fig, ax1 = plt.subplots(figsize=(20, 8))
                        p1, = ax1.plot(mz_values, mean_shaps.where(mean_shaps > 0), color='red', alpha=0.7, label='Positive SHAP Impact')
                        p2, = ax1.plot(mz_values, mean_shaps.where(mean_shaps < 0), color='royalblue', alpha=0.7, label='Negative SHAP Impact')
                        ax1.set_xlabel('m/z', fontsize=12)
                        ax1.set_ylabel('Mean SHAP Value', color='black', fontsize=12)
                        ax1.axhline(0, color='black', linestyle='--', linewidth=1)

                        # Second y-axis so intensities and SHAP values keep
                        # independent scales.
                        ax2 = ax1.twinx()
                        p3 = ax2.fill_between(mz_values, 0, mean_class1_spectrum, color='salmon', alpha=0.3, label='Mean Spectrum (Class 1)')
                        p4 = ax2.fill_between(mz_values, 0, mean_class0_spectrum, color='skyblue', alpha=0.3, label='Mean Spectrum (Class 0)')
                        ax2.set_ylabel('Mean Intensity', color='gray', fontsize=12)

                        plt.title(f'Combined Plot: Mean Spectra vs. Mean SHAP Impact for {model_name}', fontsize=16)
                        ax1.set_xlim(2000, 12000)
                        # One legend for artists living on both axes.
                        ax1.legend([p1, p2, p3, p4], [p.get_label() for p in [p1, p2, p3, p4]], loc='upper right')

                        plt.tight_layout()
                        save_path = f"{model_name}_shap_combined_overlay_plot.png"
                        plt.savefig(save_path, dpi=150)
                        print(f"Saved Combined Overlay Plot to: {save_path}")
                        plt.show()

                    except Exception as e:
                        print(f"Could not generate final synthesis plots: {e}")

    else:
        print(f"Could not perform SHAP analysis: The model file '{BEST_MODEL_PATH}' was not found.")