In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.lines as mlines
import seaborn as sns
import matplotlib.gridspec as gridspec
from matplotlib.patches import Patch
import scanpy as sc



In [2]:
def plot_metrics_and_coefficients(metrics_df, coefs_df, sample_name="all_samples", top_n=10, alpha_idx=None):
    """
    Plot classification metrics and logistic regression coefficient paths.

    The figure has two stacked panels sharing the same log10(alpha) x-range:
    the top panel shows the accuracy curves (plus ROC AUC / trivial accuracy
    when those columns exist) and the percentage of non-zero coefficients;
    the bottom panel shows the coefficient paths of the `top_n` features that
    stay non-zero at the largest number of alpha values.

    Parameters:
    - metrics_df: DataFrame containing metrics along the regularization path.
      Must have the columns 'group', 'overall_accuracy', 'mal_accuracy' and
      'norm_accuracy'. The columns 'roc_auc', 'trivial_accuracy',
      'majority_num' and 'minority_num' are used when present.
    - coefs_df: DataFrame containing coefficients along the regularization path
      (index = feature names, column labels = alpha values convertible to float).
    - sample_name: Value of the 'group' column to select; also shown in the title.
    - top_n: Number of top features to plot based on survival frequency.
      Values <= 0 plot no coefficient paths.
    - alpha_idx: Zero-based index of the alpha value to highlight with a
      vertical line. None or an out-of-range index draws no line.

    Returns:
    - The matplotlib Figure, or None (after printing a message) when
      metrics_df has no rows for `sample_name`.

    Raises:
    - ValueError: if the number of metric rows for `sample_name` differs from
      the number of alpha columns in `coefs_df`.
    - KeyError: propagated from pandas when a required metrics column is missing.
    """
    # Filter metrics for the given sample
    metrics_results = metrics_df[metrics_df['group'] == sample_name]

    if metrics_results.empty:
        print(f"No metrics found for sample '{sample_name}'")
        return None

    # Step 1: Extract alpha values and coefficient data
    alphas = coefs_df.columns.astype(float).values  # Column labels are the alpha values
    coef_results_arr = np.array(coefs_df)
    feature_names = coefs_df.index

    # Step 2: Extract metrics
    overall_acc = metrics_results['overall_accuracy'].values
    mal_accuracy = metrics_results['mal_accuracy'].values
    norm_accuracy = metrics_results['norm_accuracy'].values
    roc_auc = metrics_results['roc_auc'].values if 'roc_auc' in metrics_results.columns else None
    trivial_acc = (
        metrics_results['trivial_accuracy'].values
        if 'trivial_accuracy' in metrics_results.columns
        else None
    )

    # Every metric curve is drawn against the alpha axis, so the two inputs must
    # line up one-to-one. Check this before a figure exists so that a mismatch
    # cannot leave an orphaned, never-closed figure behind.
    if len(overall_acc) != len(alphas):
        raise ValueError(
            f"Sample '{sample_name}' has {len(overall_acc)} metric rows but coefs_df has "
            f"{len(alphas)} alpha columns; they must match one-to-one."
        )

    # Extract the majority and minority class numbers if available
    if 'majority_num' in metrics_results.columns and 'minority_num' in metrics_results.columns:
        majority_num = metrics_results['majority_num'].values[0]
        minority_num = metrics_results['minority_num'].values[0]
    else:
        majority_num = "Unknown"
        minority_num = "Unknown"

    # Surviving features: count how many coefficients are non-zero at each alpha
    nonzero_mask = coefs_df != 0
    surviving_features = nonzero_mask.sum(axis=0)

    # Step 3: Identify top features based on how long they survive regularization
    non_zero_counts = nonzero_mask.sum(axis=1).to_numpy()
    if top_n > 0:
        top_features_idx = np.argsort(non_zero_counts)[-top_n:]
    else:
        # A slice of [-0:] would select *every* feature, so handle this explicitly
        top_features_idx = np.array([], dtype=int)

    log_alphas = np.log10(alphas)

    # x-axis range: keep the default [-4, 5] window but widen it whenever the
    # data extends beyond it, so that no point is silently clipped.
    min_log_alpha, max_log_alpha = -4, 5
    finite_log_alphas = log_alphas[np.isfinite(log_alphas)]
    if finite_log_alphas.size:
        min_log_alpha = min(min_log_alpha, int(np.floor(finite_log_alphas.min())))
        max_log_alpha = max(max_log_alpha, int(np.ceil(finite_log_alphas.max())))
    x_ticks = range(min_log_alpha, max_log_alpha + 1)

    # Resolve the optional highlight once; both panels reuse the result
    highlight_alpha = alpha_idx is not None and 0 <= alpha_idx < len(alphas)

    # Step 4: Create a two-panel figure
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(12, 10), gridspec_kw={'height_ratios': [1, 2]})

    try:
        # -- TOP AXIS: Plot metrics (ACC, etc.) on left axis, Surviving Features on right axis --
        ax1.plot(log_alphas, overall_acc, 'o-', label="Overall Accuracy", color='skyblue', linewidth=1.5, alpha=0.8)
        ax1.plot(log_alphas, mal_accuracy, '^-', label="Cancer Cell Accuracy", color='darkblue', linewidth=1.5, alpha=0.8)
        ax1.plot(log_alphas, norm_accuracy, 's--', label="Normal Cell Accuracy", color='green', linewidth=1.5, alpha=0.8)

        # Optionally plot trivial accuracy
        if trivial_acc is not None:
            ax1.plot(
                log_alphas,
                trivial_acc,
                '--',
                label=f"Trivial (Majority) Accuracy = {trivial_acc[0]:.3f}",
                color='red',
                linewidth=1.5,
                alpha=0.7
            )

        # Optionally plot ROC AUC if present
        if roc_auc is not None:
            ax1.plot(
                log_alphas,
                roc_auc,
                'd-',
                label="ROC AUC",
                color='purple',
                linewidth=1.5,
                alpha=0.8
            )

        ax1.set_xlabel(r"$\log_{10}(\lambda)$")
        ax1.set_ylabel("Accuracy / AUC")
        ax1.grid(True)
        ax1.set_xlim(min_log_alpha, max_log_alpha)
        ax1.set_xticks(x_ticks)

        # Surviving features on secondary y-axis
        ax1_2 = ax1.twinx()
        ax1_2.plot(log_alphas,
                   surviving_features / len(feature_names) * 100,
                   's-',
                   color='orange',
                   label="Surviving Features (%)",
                   alpha=0.8)
        ax1_2.set_ylabel("Surviving Features (%)")
        ax1_2.set_ylim([0, 100])

        # Add vertical line at the selected alpha index if provided.
        # Drawn before the legend handles are collected so it shows up in the legend.
        if highlight_alpha:
            ax1.axvline(x=log_alphas[alpha_idx], color='black', linestyle='-.', linewidth=2,
                        label=f"Selected λ={alphas[alpha_idx]:.2e}", alpha=0.6)

        # Prepare legends combining lines from both y-axes
        lines_1, labels_1 = ax1.get_legend_handles_labels()
        lines_2, labels_2 = ax1_2.get_legend_handles_labels()

        # Invisible handles whose labels carry the class distribution
        extra_lines = [plt.Line2D([0], [0], color="none")] * 2
        extra_labels = [
            f"Majority (normal) Size: {majority_num}",
            f"Minority (cancer) Size: {minority_num}"
        ]

        ax1.legend(
            extra_lines + lines_1 + lines_2,
            extra_labels + labels_1 + labels_2,
            loc='center left',
            bbox_to_anchor=(1.15, 0.5),
            fontsize='small',
            frameon=False
        )

        ax1.set_title(f"Class-Specific Accuracies, ROC-AUC & surviving features with changing regularization strength \nPatient: {sample_name}")

        # -- BOTTOM AXIS: Plot Coefficient Paths for top N features --
        for idx in top_features_idx:
            ax2.plot(log_alphas, coef_results_arr[idx],
                     label=feature_names[idx],
                     alpha=0.8)

        # Add vertical line at the selected alpha index in bottom panel as well
        if highlight_alpha:
            ax2.axvline(x=log_alphas[alpha_idx], color='black', linestyle='-.', linewidth=2, alpha=0.6)

        # Use the same x-axis limits and ticks as the top panel
        ax2.set_xlim(min_log_alpha, max_log_alpha)
        ax2.set_xticks(x_ticks)

        ax2.set_xlabel(r"$\log_{10}(\lambda)$; $\lambda$ (Lasso regularization strength)")
        ax2.set_ylabel("Coefficient Value")
        ax2.set_title(f"Logistic Regression Coefficient Paths for Top-{top_n} Features")
        ax2.axhline(0, color='black', linestyle='--', lw=1)
        ax2.legend(loc='upper right', bbox_to_anchor=(1.3, 1), fontsize='small')
        ax2.grid(True)

        fig.tight_layout()
    except Exception:
        # Do not leave a half-drawn figure registered with pyplot
        plt.close(fig)
        raise

    return fig

In [3]:
# Root folder that holds the classification pipeline outputs
results_dir = "../pipeline/results_multivi_classification"

# The "patient_*" folders are looked up directly under the results root.
# Alternative location kept for reference: os.path.join(results_dir, "per_patient")
per_patient_dir = results_dir

# Destination folder for the generated figures
plots_output_dir = os.path.join(
    results_dir,
    "metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots",
)

# Make sure the destination exists; a no-op when it is already there
os.makedirs(plots_output_dir, exist_ok=True)

In [5]:
# Collect the per-patient result folders ("patient_<id>").
# os.listdir returns entries in arbitrary order, so sort them to make the
# processing order (and the log below) reproducible across runs.
patient_prefix = "patient_"
patient_dirs = sorted(
    d for d in os.listdir(per_patient_dir)
    if d.startswith(patient_prefix) and os.path.isdir(os.path.join(per_patient_dir, d))
)

# Produce and save one figure per patient
for patient_dir in patient_dirs:
    # Strip only the leading prefix; str.replace would also remove any later
    # occurrence of "patient_" inside the identifier.
    patient_id = patient_dir[len(patient_prefix):]
    print(f"Processing patient: {patient_id}")

    # Check if the required files exist
    coefs_file = os.path.join(per_patient_dir, patient_dir, f"results_coefs_X_multivi_{patient_id}.csv")
    metrics_file = os.path.join(per_patient_dir, patient_dir, f"results_metrics_X_multivi_{patient_id}.csv")

    if not os.path.exists(coefs_file) or not os.path.exists(metrics_file):
        print(f"  Missing required files for {patient_id}. Skipping.")
        continue

    # Read the coefficient and metrics files
    coefs_df = pd.read_csv(coefs_file, index_col=0)
    metrics_df = pd.read_csv(metrics_file)

    # No hand-picked alpha is configured for these runs, so no λ marker is
    # drawn (alpha_idx=None).
    fig = plot_metrics_and_coefficients(metrics_df, coefs_df, sample_name=patient_id, top_n=10, alpha_idx=None)

    if fig is None:
        print(f"  Could not create plot for {patient_id}")
        continue

    # Save the plot; always release the figure, even if saving fails, so that
    # figures do not pile up in memory over the loop.
    plot_file = os.path.join(plots_output_dir, f"metrics_and_coefs_{patient_id}.png")
    try:
        fig.savefig(plot_file, dpi=300, bbox_inches='tight')
    finally:
        plt.close(fig)
    print(f"  Saved plot to {plot_file}")

Processing patient: P04
  Saved plot to ../pipeline/results_multivi_classification/metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots/metrics_and_coefs_P04.png
Processing patient: P02
  Saved plot to ../pipeline/results_multivi_classification/metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots/metrics_and_coefs_P02.png
Processing patient: P13
  Saved plot to ../pipeline/results_multivi_classification/metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots/metrics_and_coefs_P13.png
Processing patient: P05
  Saved plot to ../pipeline/results_multivi_classification/metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots/metrics_and_coefs_P05.png
Processing patient: P09
  Saved plot to ../pipeline/results_multivi_classification/metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots/metrics_and_coefs_P09.png
Processing patient: P06
  Saved plot to ../pipeline/results_multivi_classification/metrics_knee_lasso_per_patient_w_allPatient_multiVI_plots/metrics_and_coefs_P06.pn