In [44]:
import os
import pandas as pd
import numpy as np
import joblib
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report, roc_auc_score, roc_curve, confusion_matrix, ConfusionMatrixDisplay
from xgboost import XGBClassifier
from catboost import CatBoostClassifier
from lightgbm import LGBMClassifier

class EmbeddingClassifier:
    def __init__(
        self,
        llm_name: str,
        window: str,
        output_dir: str = "./results",
        id_column: str = "VariationID",
        label_column: str = "is_pathogenic"
    ):
        self.llm_name = llm_name
        self.window = window
        self.output_dir = output_dir
        self.id_column = id_column
        self.label_column = label_column
        self.model_output_dir = os.path.join(output_dir, f"{llm_name}_{window}")
        os.makedirs(self.model_output_dir, exist_ok=True)
        # Mapping for classifier legend
        self.classifier_display = {
            "LR": "Logistic Regression",
            "XGB": "XGBoost",
            "CAT": "CatBoost",
            "LGBM": "LightGBM"
        }
        self.classifiers = {
            "LR": LogisticRegression(max_iter=500, class_weight="balanced", random_state=42),
            "XGB": XGBClassifier(use_label_encoder=False, eval_metric="logloss", random_state=42),
            "CAT": CatBoostClassifier(verbose=0, random_state=42),
            "LGBM": LGBMClassifier(random_state=42)
        }

    def _path(self, stem, ext, clf=None):
        # Helper to generate consistent filenames in the subfolder
        if clf is not None:
            return os.path.join(self.model_output_dir, f"{self.llm_name}_{self.window}_{clf}_{stem}.{ext}")
        else:
            return os.path.join(self.model_output_dir, f"{self.llm_name}_{self.window}_{stem}.{ext}")
    
    def plot_pca_elbow(self, pca, var_thresh=0.95):
        cumvar = np.cumsum(pca.explained_variance_ratio_)
        n_components = int(np.argmax(cumvar >= var_thresh) + 1)
        plt.figure(figsize=(8, 5))
        x = np.arange(1, len(cumvar) + 1)
        plt.plot(x, cumvar, marker="o", label="Cumulative explained variance")
        plt.axvline(n_components, color="red", linestyle="--", label=f"{n_components} components")
        plt.scatter([n_components], [cumvar[n_components - 1]], color="red", zorder=5)
        plt.axhline(var_thresh, color="gray", linestyle=":", label=f"{int(var_thresh*100)}% threshold")
        plt.xlabel("Number of PCA components")
        plt.ylabel("Cumulative explained variance")
        plt.title(f"PCA Elbow: {self.llm_name} {self.window}")
        plt.legend(loc="lower right")
        plt.tight_layout()
        fname = self._path("PCA_elbow", "png")
        plt.savefig(fname)
        plt.close()
        print(f"PCA elbow plot saved to {fname}")

    def fit_pca_and_classifiers(
        self,
        train_meta: str,
        train_emb: str,
        pca_var: float = 0.95,
        classifiers: list = None
    ):
        df = pd.read_csv(train_meta, sep="\t")
        emb = pd.read_csv(train_emb)
        train = df.merge(emb, on=self.id_column)
        X = train.filter(like="emb_").values
        y = train[self.label_column].values

        # 1. Fit PCA
        pca = PCA(n_components=pca_var, random_state=42)
        X_pca = pca.fit_transform(X)
        pca_path = self._path("pca", "pkl")
        joblib.dump(pca, pca_path)
        print(f"Saved PCA model to {pca_path}, reduced to {X_pca.shape[1]} dims")

        # 2. Train and save classifiers
        to_train = classifiers or list(self.classifiers.keys())
        for clf_name in to_train:
            clf = self.classifiers[clf_name]
            clf.fit(X_pca, y)
            clf_path = self._path("clf", "pkl", clf=clf_name)
            joblib.dump(clf, clf_path)
            print(f"Saved {clf_name} classifier to {clf_path}")

        # Optionally, save the reduced train set for auditing
        pca_cols = [f'pca_{i+1}' for i in range(X_pca.shape[1])]
        pca_df = pd.DataFrame(X_pca, columns=pca_cols, index=train.index)
        train_reduced = pd.concat([train.reset_index(drop=True), pca_df.reset_index(drop=True)], axis=1)
        train_reduced.to_csv(self._path("train_reduced", "csv"), index=False)
        print(f"Saved reduced train set to {self._path('train_reduced', 'csv')}")
        # Plot PCA elbow
        self.plot_pca_elbow(pca, var_thresh=pca_var)

    def plot_all_metrics_grid(self, metrics_dict):
        """
        Plot all metrics (precision, recall, f1-score, support) as grouped bar charts
        for all classifiers, saved into a single 2x2 grid image.
        The legend is shown once on the right of all subplots.
        """
        import matplotlib.pyplot as plt
    
        metrics = ['precision', 'recall', 'f1-score', 'support']
        example_metrics = next(iter(metrics_dict.values()))
        exclude_keys = {"accuracy", "macro avg", "weighted avg"}
        # Top-level keys are class labels (as strings) plus 'accuracy' etc.
        class_labels = [k for k in example_metrics.keys() if k not in exclude_keys]
        n_classes = len(class_labels)
        x = np.arange(n_classes)
        width = 0.18
        clf_names = list(metrics_dict.keys())
    
        fig, axs = plt.subplots(2, 2, figsize=(16, 12))
        axs = axs.flatten()
        for i, metric in enumerate(metrics):
            ax = axs[i]
            for j, clf in enumerate(clf_names):
                # For each class, get the metric value
                vals = [metrics_dict[clf][cl][metric] for cl in class_labels]
                ax.bar(x + j*width, vals, width, label=self.classifier_display.get(clf, clf))
            ax.set_xticks(x + width * (len(clf_names) - 1) / 2)
            ax.set_xticklabels(class_labels)
            ax.set_ylabel(metric.capitalize())
            ax.set_title(f'{metric.capitalize()} by Classifier')
            # Do NOT call ax.legend() inside the loop
    
        # Place one legend for all plots to the right
        handles, labels = axs[0].get_legend_handles_labels()
        fig.legend(handles, labels, title="Classifier", loc="center right", bbox_to_anchor=(1.12, 0.5))
        plt.tight_layout(rect=[0, 0, 0.93, 1])  # Make space for the legend on the right
    
        fname = os.path.join(self.model_output_dir, f"{self.llm_name}_{self.window}_all_metrics_bar.png")
        plt.savefig(fname, bbox_inches="tight")
        plt.close()
        print(f"Saved all metrics bar grid to {fname}")

    def plot_all_confusion_matrices(self, y_true, y_preds_dict, class_names=None):
        """
        Plot confusion matrices for each classifier in a 2x2 grid.
        """
        import matplotlib.pyplot as plt
        from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    
        n_clf = len(y_preds_dict)
        nrows, ncols = 2, 2  # For up to 4 classifiers
        fig, axs = plt.subplots(nrows, ncols, figsize=(12, 12))
        axs = axs.flatten()
        for i, (clf, y_pred) in enumerate(y_preds_dict.items()):
            cm = confusion_matrix(y_true, y_pred)
            disp = ConfusionMatrixDisplay(cm, display_labels=class_names)
            disp.plot(ax=axs[i], colorbar=False)
            axs[i].set_title(f"Confusion Matrix: {clf}")
        # Hide unused axes if less than 4 classifiers
        for j in range(i + 1, nrows * ncols):
            fig.delaxes(axs[j])
        plt.tight_layout()
        fname = os.path.join(self.model_output_dir, f"{self.llm_name}_{self.window}_all_confmats.png")
        plt.savefig(fname)
        plt.close()
        print(f"Saved all confusion matrices to {fname}")

    def plot_roc_auc(self, y_true, y_probs_dict):
        """
        Plots ROC curves for all classifiers.
        """
        from sklearn.metrics import roc_curve, auc
        plt.figure(figsize=(8, 6))
        for clf, y_prob in y_probs_dict.items():
            fpr, tpr, _ = roc_curve(y_true, y_prob)
            roc_auc = auc(fpr, tpr)
            plt.plot(fpr, tpr, lw=2, label=f"{self.classifier_display.get(clf, clf)} (AUC={roc_auc:.2f})")
        plt.plot([0, 1], [0, 1], "k--", lw=1)
        plt.xlabel("False Positive Rate")
        plt.ylabel("True Positive Rate")
        plt.title(f"ROC Curve: {self.llm_name} {self.window}")
        plt.legend(loc="lower right")
        plt.tight_layout()
        plot_path = self._path("ROC", "png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved ROC curve plot to {plot_path}")

    def plot_confusion_matrix(self, y_true, y_pred, clf_name):
        """
        Plots and saves the confusion matrix for a classifier.
        """
        from sklearn.metrics import ConfusionMatrixDisplay
        cm = confusion_matrix(y_true, y_pred)
        disp = ConfusionMatrixDisplay(cm)
        disp.plot(cmap='Blues')
        plt.title(f"Confusion Matrix: {self.classifier_display.get(clf_name, clf_name)}")
        plt.tight_layout()
        plot_path = self._path(f"confmat_{clf_name}", "png")
        plt.savefig(plot_path)
        plt.close()
        print(f"Saved confusion matrix to {plot_path}")

    def save_classifier_legend(self):
        legend_lines = [
            "Classifier Legend:",
            "  CAT  - CatBoost",
            "  LR   - Logistic Regression",
            "  XGB  - XGBoost",
            "  LGBM - LightGBM"
        ]
        legend_txt = "\n".join(legend_lines)
        print(legend_txt)
        with open(self._path("classifier_legend", "txt"), "w") as f:
            f.write(legend_txt)

    def predict_and_eval(
        self,
        test_meta: str,
        test_emb: str,
        classifiers: list = None
    ):
        df = pd.read_csv(test_meta, sep="\t")
        emb = pd.read_csv(test_emb)
        test = df.merge(emb, on=self.id_column)
        X = test.filter(like="emb_").values
        y = test[self.label_column].values

        # 2. Load PCA and reduce test
        pca = joblib.load(self._path("pca", "pkl"))
        X_pca = pca.transform(X)
        print(f"Test data reduced to {X_pca.shape[1]} dims")

        to_eval = classifiers or list(self.classifiers.keys())
        metrics_dict = {}
        y_probs_dict = {}
        y_preds_dict = {}   # <-- NEW

        for clf_name in to_eval:
            clf = joblib.load(self._path("clf", "pkl", clf=clf_name))
            y_pred = clf.predict(X_pca)
            y_preds_dict[clf_name] = y_pred  # <-- NEW
            if hasattr(clf, "predict_proba"):
                y_prob = clf.predict_proba(X_pca)[:, 1]
            else:
                y_prob = np.zeros_like(y_pred)
            y_probs_dict[clf_name] = y_prob
            cr = classification_report(y, y_pred, output_dict=True)
            metrics_dict[clf_name] = cr
            print(f"\n== {clf_name} on {self.llm_name} {self.window} ==")
            print(classification_report(y, y_pred, digits=3))
            print("ROC AUC: %.3f" % roc_auc_score(y, y_prob))
            # Save predictions
            test_cp = test.copy()
            test_cp["prediction"] = y_pred
            test_cp["prob_pathogenic"] = y_prob
            pred_file = self._path(f"predictions_test", "csv", clf=clf_name)
            test_cp.to_csv(pred_file, index=False)
            print(f"Saved predictions: {pred_file}")
            # Save confusion matrix
            #self.plot_confusion_matrix(y, y_pred, clf_name)

        # Visualize all metrics and confusion matrices in one file
        self.plot_all_metrics_grid(metrics_dict)
        self.plot_all_confusion_matrices(y, y_preds_dict, class_names=["Non-Pathogenic", "Pathogenic"])
        self.plot_roc_auc(y, y_probs_dict)
        self.save_classifier_legend()

In [45]:
classifier = EmbeddingClassifier(llm_name="NT", window="225bp", output_dir="./results")
classifier.fit_pca_and_classifiers(
    train_meta="./data/windows_225/clinvar_binary_train_225.tsv",
    train_emb="./data/embeddings/clinvar_binary_train_embeddings_NT_225.csv",
    pca_var=0.95,  # retain 95% variance
    classifiers=["LR", "XGB", "CAT", "LGBM"]  # Run all, or just pick your subset
)

# Evaluate on holdout (test) set
classifier.predict_and_eval(
    test_meta="./data/windows_225/clinvar_binary_test_225.tsv",
    test_emb="./data/embeddings/clinvar_binary_test_embeddings_NT_225.csv",
    classifiers=["LR", "XGB", "CAT", "LGBM"]  # Must match what you trained!
)

Saved PCA model to ./results/NT_225bp/NT_225bp_pca.pkl, reduced to 398 dims
Saved LR classifier to ./results/NT_225bp/NT_225bp_LR_clf.pkl


Parameters: { "use_label_encoder" } are not used.



Saved XGB classifier to ./results/NT_225bp/NT_225bp_XGB_clf.pkl
Saved CAT classifier to ./results/NT_225bp/NT_225bp_CAT_clf.pkl
[LightGBM] [Info] Number of positive: 15000, number of negative: 15000
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.023839 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 101490
[LightGBM] [Info] Number of data points in the train set: 30000, number of used features: 398
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Saved LGBM classifier to ./results/NT_225bp/NT_225bp_LGBM_clf.pkl
Saved reduced train set to ./results/NT_225bp/NT_225bp_train_reduced.csv
PCA elbow plot saved to ./results/NT_225bp/NT_225bp_PCA_elbow.png
Test data reduced to 398 dims

== LR on NT 225bp ==
              precision    recall  f1-score   support

           0      0.713     0.703     0.708      1500
           1      0.707     0.717     0.712      1500

    accur

In [46]:
# Train (fit PCA and all 4 classifiers on train set)
classifier = EmbeddingClassifier(llm_name="DNABERT6", window="225bp", output_dir="./results")
classifier.fit_pca_and_classifiers(
    train_meta="./data/windows_225/clinvar_binary_train_225.tsv",
    train_emb="./data/embeddings/clinvar_binary_train_embeddings_DNABERT6_225.csv",
    pca_var=0.95,  # retain 95% variance
    classifiers=["LR", "XGB", "CAT", "LGBM"]  # Run all, or just pick your subset
)

# Evaluate on holdout (test) set
classifier.predict_and_eval(
    test_meta="./data/windows_225/clinvar_binary_test_225.tsv",
    test_emb="./data/embeddings/clinvar_binary_test_embeddings_DNABERT6_225.csv",
    classifiers=["LR", "XGB", "CAT", "LGBM"]  # Must match what you trained!
)

Saved PCA model to ./results/DNABERT6_225bp/DNABERT6_225bp_pca.pkl, reduced to 38 dims
Saved LR classifier to ./results/DNABERT6_225bp/DNABERT6_225bp_LR_clf.pkl


Parameters: { "use_label_encoder" } are not used.



Saved XGB classifier to ./results/DNABERT6_225bp/DNABERT6_225bp_XGB_clf.pkl
Saved CAT classifier to ./results/DNABERT6_225bp/DNABERT6_225bp_CAT_clf.pkl
[LightGBM] [Info] Number of positive: 15000, number of negative: 15000
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002336 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 9690
[LightGBM] [Info] Number of data points in the train set: 30000, number of used features: 38
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Saved LGBM classifier to ./results/DNABERT6_225bp/DNABERT6_225bp_LGBM_clf.pkl
Saved reduced train set to ./results/DNABERT6_225bp/DNABERT6_225bp_train_reduced.csv
PCA elbow plot saved to ./results/DNABERT6_225bp/DNABERT6_225bp_PCA_elbow.png
Test data reduced to 38 dims

== LR on DNABERT6 225bp ==
              precision    recall  f1-score   support

           0      0.540     0.485     0.511      1500
  

In [47]:
# Train (fit PCA and all 4 classifiers on train set)
classifier = EmbeddingClassifier(llm_name="GROVER", window="225bp", output_dir="./results")
classifier.fit_pca_and_classifiers(
    train_meta="./data/windows_225/clinvar_binary_train_225.tsv",
    train_emb="./data/embeddings/clinvar_binary_train_embeddings_GROVER_225.csv",
    pca_var=0.95,  # retain 95% variance
    classifiers=["LR", "XGB", "CAT", "LGBM"]  # Run all, or just pick your subset
)

# Evaluate on holdout (test) set
classifier.predict_and_eval(
    test_meta="./data/windows_225/clinvar_binary_test_225.tsv",
    test_emb="./data/embeddings/clinvar_binary_test_embeddings_GROVER_225.csv",
    classifiers=["LR", "XGB", "CAT", "LGBM"]  # Must match what you trained!
)

Saved PCA model to ./results/GROVER_225bp/GROVER_225bp_pca.pkl, reduced to 184 dims
Saved LR classifier to ./results/GROVER_225bp/GROVER_225bp_LR_clf.pkl


Parameters: { "use_label_encoder" } are not used.



Saved XGB classifier to ./results/GROVER_225bp/GROVER_225bp_XGB_clf.pkl
Saved CAT classifier to ./results/GROVER_225bp/GROVER_225bp_CAT_clf.pkl
[LightGBM] [Info] Number of positive: 15000, number of negative: 15000
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.011903 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 46920
[LightGBM] [Info] Number of data points in the train set: 30000, number of used features: 184
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
Saved LGBM classifier to ./results/GROVER_225bp/GROVER_225bp_LGBM_clf.pkl
Saved reduced train set to ./results/GROVER_225bp/GROVER_225bp_train_reduced.csv
PCA elbow plot saved to ./results/GROVER_225bp/GROVER_225bp_PCA_elbow.png
Test data reduced to 184 dims

== LR on GROVER 225bp ==
              precision    recall  f1-score   support

           0      0.573     0.547     0.559      1500
           1      0.5