In [1]:
# ============================================================================
# Cell 1: CONFIGURATION
# This cell contains all the settings a user needs to change.
# Nothing here is validated when this cell runs; the values are only read
# later, when the pipeline cell bundles them into a config dictionary.
# ============================================================================

# --- 1. File Paths ---
# Update these paths to match your file locations.
# INPUT_FILE_PATH is read with pandas.read_excel (first sheet, first column used
# as the metabolite index; every other column is treated as a candidate sample).
# OUTPUT_DIR is created if it does not exist; the VIP table and all plots are written there.
INPUT_FILE_PATH = "/PATH/TO/INPUT/FILE"
OUTPUT_DIR = "PATH/TO/OUTPUT/DIR"

# --- 2. Data Exclusions ---
# List any metabolites to remove from the analysis. Use an empty list [] if none.
# Names are compared after lower-casing, removing quote characters and trimming
# whitespace, so capitalisation and quoting do not need to match the spreadsheet.
METABOLITES_TO_EXCLUDE = ['9-methylanthracene']

# --- 3. Sample Naming Convention ---
# Define your experimental conditions and how to identify them using regular expressions.
# The key is the name of the condition (e.g., '+GFP').
# The value is a regex pattern that identifies samples belonging to that condition.
# The regex MUST include a named group `(?P<timepoint>\d+)` to capture the timepoint identifier.
# Patterns are applied to each column name with a case-insensitive search (the match
# may start anywhere in the name), so make them specific enough that a sample
# cannot match more than one condition.
CONDITIONS = {
    '+GFP': r'TM2A(?P<timepoint>\d+)_',
    '-GFP': r'TM2An(?P<timepoint>\d+)_'
}

# --- 4. Group & Timepoint Mapping ---
# Map the condition names above to numerical labels for the PLS-DA model.
# Typically, this is 0 and 1.
# Cross-validated predictions are classified with a fixed 0.5 cut-off, so labels
# other than 0 and 1 will not be scored correctly. Every key of CONDITIONS
# needs an entry here.
GROUP_LABELS = {
    '+GFP': 0,
    '-GFP': 1
}

# Map the captured timepoint identifier (from the regex) to a display name for plots.
# Keys are the captured text, i.e. strings ('1', not 1). An identifier without an
# entry is displayed as "ID:<identifier>".
TIMEPOINT_MAP = {
    '1': '0h',
    '2': '0.5h',
    '3': '2h',
    '4': '5h',
    '5': '10h'
}

# List the display names in the order you want them to appear in plot legends.
# Only samples whose display name appears in this list are drawn on the scores plot.
TIMEPOINT_PLOT_ORDER = ['0h', '0.5h', '2h', '5h', '10h']


# --- 5. PLS-DA & Plotting Parameters ---
# Maximum number of latent variables (components) to test during cross-validation.
# The pipeline lowers this automatically to (number of samples - 1) or the number
# of metabolites when either is smaller.
MAX_PLS_COMPONENTS = 10
# VIP score threshold to draw on the VIP plot.
VIP_THRESHOLD = 1.0
# Number of top metabolites to show on the VIP bar plot.
TOP_N_VIP = 20
# Number of top metabolites (by VIP score) to label on the loadings plot.
TOP_N_LOADINGS = 15

print("--- Configuration Loaded ---")

--- Configuration Loaded ---


In [2]:
# ============================================================================
# Cell 2: SCRIPT LOGIC
# A user typically does not need to edit this cell.
# ============================================================================
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.cross_decomposition import PLSRegression
from sklearn.model_selection import cross_val_predict, LeaveOneOut
from sklearn.metrics import accuracy_score
import os
import re
from matplotlib.patches import Ellipse
import matplotlib.transforms as transforms
from matplotlib.lines import Line2D

# --- Helper Functions ---
def normalize_metabolite_name(name):
    """Return a canonical spelling of a metabolite name for lookups.

    Strings are lower-cased, stripped of single- and double-quote characters,
    and trimmed of surrounding whitespace. Any non-string value (e.g. NaN or a
    number read from the spreadsheet index) is handed back unchanged.
    """
    if isinstance(name, str):
        cleaned = name.lower()
        for quote_char in ("'", '"'):
            cleaned = cleaned.replace(quote_char, "")
        return cleaned.strip()
    return name

def pareto_scale(data):
    """Pareto-scale a samples-by-features matrix column by column.

    Each column is mean-centered and divided by the square root of its sample
    standard deviation (ddof=1).

    Args:
        data: 2-D numeric array; statistics are taken along axis 0.

    Returns:
        Array of the same shape holding the scaled values.
    """
    column_sd = np.std(data, axis=0, ddof=1)
    # A constant column has zero spread; substitute a tiny value so the
    # (already zero) centered entries are not divided by zero.
    column_sd[column_sd == 0] = 1e-8
    centered = data - np.mean(data, axis=0)
    return centered / np.sqrt(column_sd)

def calculate_explained_variance(X_scaled, scores):
    """Express the variance of each score column as a fraction of the total X variance.

    Args:
        X_scaled: 2-D array (samples x features) that the model was fitted on.
        scores: 2-D array (samples x components) of latent-variable scores.

    Returns:
        1-D array with one ratio per column of ``scores``.

    NOTE(review): the ratio uses the score variance only and does not account
    for the norm of the corresponding loading vectors — confirm this is the
    intended definition of "explained variance" for the axis labels.
    """
    denominator = np.sum(np.var(X_scaled, axis=0))
    ratios = []
    for component in range(scores.shape[1]):
        ratios.append(np.var(scores[:, component]) / denominator)
    return np.array(ratios)

def add_confidence_ellipse(ax, x, y, n_std=2.0, **kwargs):
    """Draw a covariance ellipse for the points (x, y) on ``ax``.

    The ellipse is centred on the mean of the points, oriented along the
    principal axes of their covariance and scaled to ``n_std`` standard
    deviations. Nothing is drawn (and no error is raised) when there are fewer
    than three points, when the covariance is not finite, or when either
    coordinate has zero variance.

    Args:
        ax: Matplotlib axes to draw on.
        x, y: Equal-length sequences of coordinates.
        n_std: Radius of the ellipse in standard deviations. The default of
            2.0 keeps the previous behaviour; for bivariate-normal data that
            encloses roughly 86% of the distribution, while about 2.45 is
            needed to enclose 95%.
        **kwargs: Forwarded to ``matplotlib.patches.Ellipse`` (e.g. edgecolor).

    Returns:
        None.
    """
    if len(x) < 3:
        return
    cov = np.cov(x, y)
    if not np.all(np.isfinite(cov)):
        return
    var_x, var_y = cov[0, 0], cov[1, 1]
    if var_x <= 0 or var_y <= 0:
        # Degenerate cloud (all points share an x or y value): no ellipse to draw.
        return
    # Clip so floating-point rounding cannot push |r| past 1 and turn a radius into NaN.
    pearson = float(np.clip(cov[0, 1] / np.sqrt(var_x * var_y), -1.0, 1.0))
    ell_radius_x = np.sqrt(1 + pearson)
    ell_radius_y = np.sqrt(1 - pearson)
    ellipse = Ellipse((0, 0), width=ell_radius_x * 2, height=ell_radius_y * 2, **kwargs)
    # Unit ellipse -> rotate onto the principal axes -> stretch to n_std
    # standard deviations -> move to the centroid of the points.
    transf = (
        transforms.Affine2D()
        .rotate_deg(45)
        .scale(np.sqrt(var_x) * n_std, np.sqrt(var_y) * n_std)
        .translate(np.mean(x), np.mean(y))
    )
    ellipse.set_transform(transf + ax.transData)
    ax.add_patch(ellipse)

def calculate_vip(model):
    """Compute Variable Importance in Projection (VIP) scores for a fitted PLS model.

    For feature i the score is
    ``sqrt(p * sum_j(s_j * (w_ij / ||w_j||)**2) / sum_j(s_j))`` where ``p`` is
    the number of features, ``w`` the X weights and ``s_j`` the j-th diagonal
    entry of ``T'T Q'Q`` built from the X scores ``T`` and y loadings ``Q``.

    Args:
        model: Fitted object exposing ``x_scores_`` (samples x components),
            ``x_weights_`` (features x components) and ``y_loadings_``
            (targets x components).

    Returns:
        1-D float array with one VIP score per feature, in the row order of
        ``x_weights_``.
    """
    t, w, q = model.x_scores_, model.x_weights_, model.y_loadings_
    n_features = w.shape[0]
    # One weighting term per component.
    s = np.diag(t.T @ t @ q.T @ q)
    # Normalise every weight column once instead of once per feature.
    unit_w_sq = (w / np.linalg.norm(w, axis=0)) ** 2
    return np.sqrt(n_features * (unit_w_sq @ s) / np.sum(s))

# --- Main Analysis Pipeline ---
def _parse_sample_info(columns, conditions, timepoint_map):
    """Build the sample-metadata table by matching column names against the condition regexes.

    Args:
        columns: Iterable of column labels from the data table (one per sample).
        conditions: Mapping of condition name -> regex pattern that contains a
            named group ``timepoint``.
        timepoint_map: Mapping of captured timepoint ID -> display name.

    Returns:
        DataFrame with columns 'Sample', 'Condition', 'TimepointID' and
        'TimepointName'; empty when no column matched.
    """
    records = []
    assigned = {}  # sample label -> condition that claimed it first
    for condition_name, pattern in conditions.items():
        regex = re.compile(pattern, re.IGNORECASE)
        if 'timepoint' not in regex.groupindex:
            # Report a misconfigured pattern once rather than once per matching column.
            print(f"✗ WARNING: Regex '{pattern}' is missing the named group '(?P<timepoint>...)'.")
            continue
        for col in columns:
            match = regex.search(str(col))
            if match is None:
                continue
            if col in assigned:
                # A sample listed twice would be duplicated in X and counted twice in CV.
                print(f"✗ WARNING: Sample '{col}' also matches condition '{condition_name}'; "
                      f"keeping it in '{assigned[col]}'.")
                continue
            assigned[col] = condition_name
            tp_id = match.group('timepoint')
            records.append({
                'Sample': col, 'Condition': condition_name,
                'TimepointID': tp_id, 'TimepointName': timepoint_map.get(tp_id, f"ID:{tp_id}")
            })
    return pd.DataFrame(records, columns=['Sample', 'Condition', 'TimepointID', 'TimepointName'])


def _select_n_components(X_scaled, y, max_components):
    """Return the component count (1..max_components) with the best leave-one-out accuracy."""
    cv_scores = []
    for n_comp in range(1, max_components + 1):
        pls = PLSRegression(n_components=n_comp, scale=False)
        # Flatten in case the predictions come back as an (n, 1) column.
        y_pred_cv = np.ravel(cross_val_predict(pls, X_scaled, y, cv=LeaveOneOut()))
        # The 0.5 cut-off assumes the two groups are labelled 0 and 1.
        accuracy = accuracy_score(y, (y_pred_cv > 0.5).astype(int))
        cv_scores.append(accuracy)
        print(f"  - Components: {n_comp}, CV Accuracy: {accuracy:.3f}")
    # argmax returns the first maximum, i.e. the smallest model among ties.
    return int(np.argmax(cv_scores)) + 1


def _save_figure(fig, output_dir, filename):
    """Write ``fig`` into ``output_dir`` and release it."""
    fig.savefig(os.path.join(output_dir, filename), dpi=300, bbox_inches='tight')
    plt.close(fig)


def _plot_scores(scores, sample_info_df, explained_variance_ratio, config):
    """Save the LV1-vs-LV2 scores plot (colour = timepoint, marker = condition)."""
    timepoint_order = config['timepoint_plot_order']
    condition_names = list(config['conditions'])
    # Styles are assigned in configuration order; the first two reproduce the
    # original '+GFP' (circle / blue) and '-GFP' (square / red) appearance.
    marker_cycle = ('o', 's', '^', 'D', 'v', 'P', 'X', '*')
    ellipse_color_cycle = ('blue', 'red', 'green', 'purple', 'orange', 'brown', 'teal', 'magenta')
    markers = {name: marker_cycle[i % len(marker_cycle)] for i, name in enumerate(condition_names)}
    ellipse_colors = {name: ellipse_color_cycle[i % len(ellipse_color_cycle)]
                      for i, name in enumerate(condition_names)}

    fig, ax = plt.subplots(figsize=(12, 10))
    colors = plt.cm.plasma_r(np.linspace(0.1, 0.9, len(timepoint_order)))
    color_dict = dict(zip(timepoint_order, colors))

    for tp_name in timepoint_order:
        for condition, marker in markers.items():
            mask = ((sample_info_df['TimepointName'] == tp_name)
                    & (sample_info_df['Condition'] == condition)).to_numpy()
            if mask.any():
                ax.scatter(scores[mask, 0], scores[mask, 1], color=color_dict[tp_name], marker=marker,
                           s=150, alpha=0.8, edgecolors='black', linewidth=1.5)

    ellipse_handles = []
    for cond_name in condition_names:
        mask = (sample_info_df['Condition'] == cond_name).to_numpy()
        if mask.sum() > 2:
            color = ellipse_colors[cond_name]
            add_confidence_ellipse(ax, scores[mask, 0], scores[mask, 1], facecolor='none',
                                   edgecolor=color, linestyle='--', linewidth=2)
            # NOTE(review): the ellipse helper scales by 2 standard deviations, which for
            # bivariate-normal scores encloses about 86% rather than 95% — confirm the
            # intended level before relying on this legend label.
            ellipse_handles.append(Line2D([0], [0], color=color, lw=2, linestyle='--',
                                          label=f'{cond_name} 95% CI'))

    legend_timepoints = [
        Line2D([0], [0], marker='o', color='w', label=name,
               markerfacecolor=color_dict[name], markersize=12)
        for name in timepoint_order
    ]
    legend_conditions = [
        Line2D([0], [0], marker=marker, color='w', label=cond_name, markerfacecolor='grey',
               markeredgecolor='black', markersize=12)
        for cond_name, marker in markers.items()
    ]
    ax.legend(handles=legend_timepoints + legend_conditions + ellipse_handles, title="Legend",
              bbox_to_anchor=(1.05, 1), loc='upper left')

    ax.set_xlabel(f'LV1 ({explained_variance_ratio[0]:.1%} variance)')
    ax.set_ylabel(f'LV2 ({explained_variance_ratio[1]:.1%} variance)')
    ax.set_title('PLS-DA Scores Plot')
    ax.grid(True, alpha=0.3)
    ax.axhline(0, c='grey', ls='--')
    ax.axvline(0, c='grey', ls='--')
    _save_figure(fig, config['output_dir'], 'plsda_scores_plot.pdf')


def _plot_loadings(loadings, vip_scores, vip_df, config):
    """Save the LV1-vs-LV2 loadings plot, labelling the highest-VIP metabolites.

    ``vip_df`` must still carry the index it had before sorting, so that each
    index label is the metabolite's row position in ``loadings``.
    """
    fig, ax = plt.subplots(figsize=(12, 10))
    scatter = ax.scatter(loadings[:, 0], loadings[:, 1], c=vip_scores, cmap='hot', s=80,
                         alpha=0.7, edgecolors='black')
    # Using the preserved row position avoids a list search per label and stays
    # correct when two metabolites share a name.
    for met_idx, row in vip_df.head(config['top_n_loadings']).iterrows():
        ax.text(loadings[met_idx, 0], loadings[met_idx, 1], row['Metabolite'], fontsize=8)
    cbar = plt.colorbar(scatter, ax=ax)
    cbar.set_label('VIP Score')
    ax.set_xlabel('LV1 Loadings')
    ax.set_ylabel('LV2 Loadings')
    ax.set_title('PLS-DA Loadings Plot')
    ax.grid(True, alpha=0.3)
    _save_figure(fig, config['output_dir'], 'plsda_loadings_plot.pdf')


def _plot_vip(vip_df, config):
    """Save the horizontal bar chart of the top VIP scores with the threshold line."""
    fig, ax = plt.subplots(figsize=(10, 8))
    vip_top = vip_df.head(config['top_n_vip'])
    ax.barh(vip_top['Metabolite'], vip_top['VIP'], color='steelblue')
    ax.invert_yaxis()  # highest score at the top
    ax.set_xlabel('VIP Score')
    ax.set_title(f"Top {config['top_n_vip']} Metabolites by VIP Score")
    ax.axvline(x=config['vip_threshold'], color='red', linestyle='--',
               label=f"VIP > {config['vip_threshold']} Threshold")
    ax.legend()
    _save_figure(fig, config['output_dir'], 'vip_scores_barplot.pdf')


def run_plsda_analysis(config):
    """Loads data, runs PLS-DA, and generates plots based on the config dictionary.

    Args:
        config: Dictionary with the keys 'input_file', 'output_dir',
            'metabolites_to_exclude', 'conditions', 'group_labels',
            'timepoint_map', 'timepoint_plot_order', 'max_pls_components',
            'vip_threshold', 'top_n_vip' and 'top_n_loadings'.

    Returns:
        None. Progress and problems are reported with ``print``; on an
        unrecoverable problem (unreadable input, no matching samples, unmapped
        or single-class labels, too few samples) the function returns early.

    Side effects:
        Creates ``config['output_dir']`` and writes 'vip_scores.xlsx',
        'vip_scores_barplot.pdf' and, when the final model has at least two
        latent variables, 'plsda_scores_plot.pdf' and 'plsda_loadings_plot.pdf'.
    """
    print("--- Starting PLS-DA Pipeline ---")
    os.makedirs(config['output_dir'], exist_ok=True)

    # 1. Load and Prepare Data
    try:
        df = pd.read_excel(config['input_file'], sheet_name=0, index_col=0)
        df.index = df.index.map(normalize_metabolite_name)
        print(f"✓ Loaded data from '{os.path.basename(config['input_file'])}'. Shape: {df.shape}")
    except Exception as e:
        # Deliberately broad: any load problem is reported and ends the run quietly.
        print(f"✗ ERROR: Could not load input file. Details: {e}")
        return

    # Exclude specified metabolites
    for met in config['metabolites_to_exclude']:
        met_norm = normalize_metabolite_name(met)
        if met_norm in df.index:
            df = df.drop(met_norm)
            print(f"  - Removed metabolite: '{met}'")

    # 2. Dynamically Parse Sample Information using Config
    sample_info_df = _parse_sample_info(df.columns, config['conditions'], config['timepoint_map'])
    if sample_info_df.empty:
        print("✗ ERROR: No samples matched the patterns in the CONDITIONS configuration. Halting analysis.")
        return

    print(f"✓ Parsed {len(sample_info_df)} samples into {len(sample_info_df['Condition'].unique())} conditions.")

    not_plotted = ~sample_info_df['TimepointName'].isin(config['timepoint_plot_order'])
    if not_plotted.any():
        print(f"✗ WARNING: {int(not_plotted.sum())} sample(s) have a timepoint that is not listed in "
              "TIMEPOINT_PLOT_ORDER and will be missing from the scores plot.")

    # 3. Create X and y matrices for PLS-DA
    X_df = df[sample_info_df['Sample'].tolist()].T
    X = X_df.values
    labels = sample_info_df['Condition'].map(config['group_labels'])
    if labels.isna().any():
        # A condition without a numeric label would otherwise put NaN into y.
        unmapped = sorted(set(sample_info_df.loc[labels.isna(), 'Condition']))
        print(f"✗ ERROR: No entry in GROUP_LABELS for condition(s): {unmapped}. Halting analysis.")
        return
    y = labels.values
    if np.unique(y).size < 2:
        print("✗ ERROR: Samples from at least two groups are required for PLS-DA. Halting analysis.")
        return
    metabolite_names = X_df.columns.tolist()

    print("\nApplying Pareto scaling...")
    X_scaled = pareto_scale(X)

    # 4. Cross-validation to find optimal number of components
    print("\nDetermining optimal number of components via Leave-One-Out CV...")
    # Each leave-one-out fit sees one sample fewer, hence the "- 1".
    max_components = min(config['max_pls_components'], X_scaled.shape[0] - 1, X_scaled.shape[1])
    if max_components < 1:
        print("✗ ERROR: Not enough samples or metabolites to fit a PLS-DA model. Halting analysis.")
        return
    optimal_components = _select_n_components(X_scaled, y, max_components)
    print(f"✓ Optimal number of components found: {optimal_components}")

    # 5. Fit Final Model and Extract Results
    print("\nFitting final PLS-DA model...")
    pls_da = PLSRegression(n_components=optimal_components, scale=False)
    pls_da.fit(X_scaled, y)
    scores, loadings = pls_da.x_scores_, pls_da.x_loadings_
    vip_scores = calculate_vip(pls_da)
    # sort_values keeps the original integer index, which the loadings plot
    # uses to find each metabolite's row in the loadings matrix.
    vip_df = pd.DataFrame({'Metabolite': metabolite_names, 'VIP': vip_scores}).sort_values('VIP', ascending=False)
    vip_df.to_excel(os.path.join(config['output_dir'], 'vip_scores.xlsx'), index=False)
    print("✓ VIP scores calculated and saved.")

    # 6. Generate Visualizations
    print("\nCreating visualizations...")
    explained_variance_ratio = calculate_explained_variance(X_scaled, scores)

    if scores.shape[1] < 2:
        # Both 2-D plots need a second latent variable for their y axis.
        print("✗ WARNING: The final model has a single latent variable; "
              "skipping the 2-D scores and loadings plots.")
    else:
        _plot_scores(scores, sample_info_df, explained_variance_ratio, config)
        print("✓ Scores plot saved.")
        _plot_loadings(loadings, vip_scores, vip_df, config)
        print("✓ Loadings plot saved.")

    _plot_vip(vip_df, config)
    print("✓ VIP scores bar plot saved.")

    print("\n--- Pipeline Finished ---")

# --- EXECUTION ---
# This block creates a configuration dictionary and runs the pipeline.
if __name__ == "__main__":
    # Gather the notebook-level settings under the lower-case keys the pipeline reads.
    config_dict = dict(
        input_file=INPUT_FILE_PATH,
        output_dir=OUTPUT_DIR,
        metabolites_to_exclude=METABOLITES_TO_EXCLUDE,
        conditions=CONDITIONS,
        group_labels=GROUP_LABELS,
        timepoint_map=TIMEPOINT_MAP,
        timepoint_plot_order=TIMEPOINT_PLOT_ORDER,
        max_pls_components=MAX_PLS_COMPONENTS,
        vip_threshold=VIP_THRESHOLD,
        top_n_vip=TOP_N_VIP,
        top_n_loadings=TOP_N_LOADINGS,
    )
    run_plsda_analysis(config_dict)

--- Starting PLS-DA Pipeline ---
✓ Loaded data from 'MOD_RF_Imputed.xlsx'. Shape: (115, 50)
  - Removed metabolite: '9-methylanthracene'
✓ Parsed 50 samples into 2 conditions.

Applying Pareto scaling...

Determining optimal number of components via Leave-One-Out CV...
  - Components: 1, CV Accuracy: 0.540
  - Components: 2, CV Accuracy: 0.740
  - Components: 3, CV Accuracy: 0.820
  - Components: 4, CV Accuracy: 0.820
  - Components: 5, CV Accuracy: 0.860
  - Components: 6, CV Accuracy: 0.880
  - Components: 7, CV Accuracy: 0.860
  - Components: 8, CV Accuracy: 0.920
  - Components: 9, CV Accuracy: 0.880
  - Components: 10, CV Accuracy: 0.880
✓ Optimal number of components found: 8

Fitting final PLS-DA model...
✓ VIP scores calculated and saved.

Creating visualizations...
✓ Scores plot saved.
✓ Loadings plot saved.
✓ VIP scores bar plot saved.

--- Pipeline Finished ---
