In [1]:
import os
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from anndata import AnnData



In [2]:
# Input / model-artifact locations.
# Each path can be overridden with the environment variable named as the first
# argument of os.environ.get, so the notebook can run on another machine without
# editing this cell. The second argument is the original default.

# Full AnnData object containing all patients and timepoints
ADATA_FULL_PATH = os.environ.get(
    'MDS_ADATA_FULL_PATH',
    '/home/minhang/mds_project/data/cohort_adata/multiVI_model/adata_multivi_corrected_rna.h5ad',
)
# Base directory where the training pipeline saved FA models and LR-Lasso results
BASE_MODEL_OUTPUT_DIR = os.environ.get(
    'MDS_BASE_MODEL_OUTPUT_DIR',
    "results_fa_lrlasso_mrd_pipeline_donor_downsampled_projection_ready_may31",
)
# Global preprocessing objects (HVG list and fitted scaler) saved by the training pipeline
SAVED_HVG_LIST_PATH = os.environ.get(
    'MDS_SAVED_HVG_LIST_PATH',
    '/home/minhang/mds_project/sc_classification/pipeline/mrd_hvg_genes_may31.pkl',
)
SAVED_SCALER_PATH = os.environ.get(
    'MDS_SAVED_SCALER_PATH',
    '/home/minhang/mds_project/sc_classification/pipeline/mrd_std_scaler_may31.pkl',
)

In [3]:
# --- FA model sizes to evaluate (must mirror the training pipeline) ---
N_COMPONENTS_LIST_FOR_FA = [100]

# --- Columns of adata_full.obs used for selection and colouring ---
PATIENT_COL, TIMEPOINT_COL, CELL_LABEL_COL = 'patient', 'timepoint_type', 'CN.label'
POSITIVE_LABEL = 'cancer'  # value of CELL_LABEL_COL that marks malignant cells

# --- Validation set-up: timepoints to project and where figures are written ---
VALIDATION_TIMEPOINTS = ['preSCT', 'Relapse']
OUTPUT_FIGURE_DIR = "QE_validation_projection_figures_jun1"
os.makedirs(OUTPUT_FIGURE_DIR, exist_ok=True)

In [4]:
def load_mrd_trained_objects(n_total_factors, base_model_dir, hvg_path, scaler_path):
    """
    Read the HVG list, the fitted scaler and the FA model produced by the MRD training run.

    Parameters:
    - n_total_factors (int): FA model size; picks the `n_factors_<n>` sub-directory and file name.
    - base_model_dir (str): Root output directory of the training pipeline.
    - hvg_path (str): Pickle containing the HVG list.
    - scaler_path (str): Pickle containing the fitted scaler.

    Returns:
    - tuple: (hvg_list, scaler, fa_model)

    Raises:
    - FileNotFoundError: when the FA model pickle for `n_total_factors` does not exist
      (a missing HVG/scaler pickle also raises it, from `open`).

    NOTE: pickle.load can execute arbitrary code; only use trusted pipeline outputs here.
    """
    def _unpickle(path):
        with open(path, 'rb') as handle:
            return pickle.load(handle)

    print(f"Loading trained objects for {n_total_factors} total factors...")
    hvg_list = _unpickle(hvg_path)
    scaler = _unpickle(scaler_path)

    model_dir = os.path.join(base_model_dir, f"n_factors_{n_total_factors}", "projection_model_objects")
    fa_model_path = os.path.join(model_dir, f"fa_model_{n_total_factors}factors.pkl")
    if not os.path.exists(fa_model_path):
        raise FileNotFoundError(f"FA model not found at {fa_model_path}")
    fa_model = _unpickle(fa_model_path)

    print(f"Loaded HVG list ({len(hvg_list)} genes), scaler, and FA model.")
    return hvg_list, scaler, fa_model

In [5]:
def preprocess_and_project(adata_vp, hvg_list, scaler, fa_model):
    """
    Project validation cells into the FA space learned on the MRD training data.

    Expression is read from `adata_vp.X`, restricted to `hvg_list` in exactly that
    order, passed through `scaler.transform` and then `fa_model.transform`.
    HVGs absent from `adata_vp` are filled with 0 (raw value, i.e. BEFORE scaling)
    at their training position. This keeps the feature count and order identical to
    what the scaler and FA model were fitted on. After scaling, those columns are
    therefore not 0 but whatever the scaler maps a raw 0 to.

    Parameters:
    - adata_vp (AnnData): cells to project.
    - hvg_list (list): ordered gene names used when fitting `scaler` / `fa_model`.
    - scaler: fitted object exposing `.transform`.
    - fa_model: fitted FA object exposing `.transform` and `.n_components`.

    Returns:
    - AnnData or None: a new AnnData holding `adata_vp.X` with copies of `.obs`/`.var`
      and the projection in `.obsm['X_fa_projected']`; None if no HVG is present.
    """
    print(f"  Preprocessing: Original shape {adata_vp.shape}")

    # Single pass over hvg_list. Record present genes together with their
    # training-order position, so missing ones can later be re-inserted as zeros.
    available_hvgs_in_order, present_positions, missing_hvgs = [], [], []
    for position, gene in enumerate(hvg_list):
        if gene in adata_vp.var_names:
            available_hvgs_in_order.append(gene)
            present_positions.append(position)
        else:
            missing_hvgs.append(gene)

    if missing_hvgs:
        print(f"    Warning: Only {len(available_hvgs_in_order)} out of {len(hvg_list)} HVGs are present in this sample. "
              f"Missing genes are filled with raw value 0 before scaling. Example: {missing_hvgs[:5]}")

    if not available_hvgs_in_order:
        print("    Error: No HVGs available in the sample. Cannot proceed with projection.")
        return None

    # Slice the HVG columns directly, without materialising an intermediate AnnData copy.
    X_present = adata_vp[:, available_hvgs_in_order].X
    X_present = X_present.toarray() if hasattr(X_present, 'toarray') else np.asarray(X_present)

    if missing_hvgs:
        print(f"    Re-aligning data to match {len(hvg_list)} HVGs for scaler/FA model compatibility.")
        # Same values/dtype as a DataFrame.reindex(columns=hvg_list, fill_value=0), without the DataFrame.
        X_for_scaling = np.zeros((X_present.shape[0], len(hvg_list)), dtype=X_present.dtype)
        X_for_scaling[:, present_positions] = X_present
    else:
        X_for_scaling = X_present

    print(f"    Scaling data with {X_for_scaling.shape[1]} features...")
    X_scaled = scaler.transform(X_for_scaling)

    print(f"    Projecting scaled data using FA model into {fa_model.n_components} factors...")
    X_fa_projected = fa_model.transform(X_scaled)

    # Keep the original expression matrix and metadata alongside the projection.
    adata_projected = AnnData(X=adata_vp.X, obs=adata_vp.obs.copy(), var=adata_vp.var.copy())
    adata_projected.obsm['X_fa_projected'] = X_fa_projected

    print(f"  Projection complete. Shape of X_fa_projected: {X_fa_projected.shape}")
    return adata_projected

In [6]:
def _factor_name_to_index(factor_name, n_total_factors, patient_id):
    """
    Convert a coefficient row label 'X_fa_<k>' (k is 1-based) into the 0-based index k-1.

    Returns None (after printing a warning) when the label does not follow that
    pattern or the index falls outside [0, n_total_factors).
    """
    name = str(factor_name)  # read_csv may produce non-str labels; avoid AttributeError on .startswith
    if not name.startswith("X_fa_"):
        print(f"Warning: Factor name {factor_name} does not start with 'X_fa_'. Cannot parse index for patient {patient_id}.")
        return None
    try:
        idx = int(name.split('_')[-1]) - 1
    except ValueError:
        print(f"Warning: Could not parse integer index from factor name {factor_name} for patient {patient_id}.")
        return None
    if not 0 <= idx < n_total_factors:
        print(f"Warning: Parsed index {idx} for factor {factor_name} is out of expected range [0, {n_total_factors-1}] for patient {patient_id}.")
        return None
    return idx


def get_selected_factor_indices_for_patient(patient_id, n_total_factors, base_model_dir, 
                                            chose_reg_strength_dict, alphas_lasso_array, 
                                            positive_label='cancer'): # positive_label currently not used here but kept for signature consistency
    """
    Return the 0-based indices of the FA factors with a non-zero LR-Lasso coefficient
    at the patient's hand-picked regularization strength.

    The coefficient CSV is read from
    `<base_model_dir>/n_factors_<n>/patient_<id>/results_coefs_<id>_factors_<n>.csv`.
    Rows are factor labels 'X_fa_<k>'; columns are, positionally, the entries of
    `alphas_lasso_array`.

    Parameters:
    - patient_id (str): The ID of the patient.
    - n_total_factors (int): The total number of factors in the FA model (e.g., 20, 50, 100).
    - base_model_dir (str): Base directory where LR-Lasso results are stored.
    - chose_reg_strength_dict (dict): Maps patient_id to a 1-based index into alphas_lasso_array.
    - alphas_lasso_array (np.array): The alpha values used during training.
    - positive_label (str): Unused; kept for signature compatibility.

    Returns:
    - list: Sorted unique 0-based factor indices. The list is empty (after a printed
      warning) when the file is missing/unreadable/empty, or when the patient has
      no valid choice.
    """
    coef_path = os.path.join(base_model_dir, f"n_factors_{n_total_factors}",
                             f"patient_{patient_id}", 
                             f"results_coefs_{patient_id}_factors_{n_total_factors}.csv")
    if not os.path.exists(coef_path):
        print(f"Warning: Coefficients file not found for patient {patient_id}, {n_total_factors} factors: {coef_path}")
        return []

    try:
        coef_df = pd.read_csv(coef_path, index_col=0)  # factor labels (e.g. 'X_fa_1') form the index
    except Exception as e:
        print(f"Error reading coefficients file {coef_path}: {e}")
        return []

    if coef_df.empty:
        print(f"Warning: Coefficients DataFrame is empty for patient {patient_id}, {n_total_factors} factors.")
        return []

    chosen_1_based_idx = chose_reg_strength_dict.get(patient_id)
    if chosen_1_based_idx is None:
        print(f"Warning: No regularization strength index specified for patient {patient_id} in chose_reg_strength_dict. No factors selected.")
        return []
    if not (1 <= chosen_1_based_idx <= len(alphas_lasso_array)):
        print(f"Warning: Invalid 1-based index ({chosen_1_based_idx}) for patient {patient_id}. Must be between 1 and {len(alphas_lasso_array)}. No factors selected.")
        return []

    chosen_0_based_idx = chosen_1_based_idx - 1
    if chosen_0_based_idx >= len(coef_df.columns):
        print(f"Warning: Chosen alpha index {chosen_0_based_idx} is out of bounds for the columns in coef_df for patient {patient_id} (max index: {len(coef_df.columns)-1}). No factors selected.")
        return []

    # Column headers are stringified alphas, so select positionally. This also
    # avoids ambiguity if two headers ever render to the same string.
    coefs = coef_df.iloc[:, chosen_0_based_idx]
    actual_alpha_value_used = alphas_lasso_array[chosen_0_based_idx]  # for reporting only

    # NaN != 0 evaluates True, so undefined coefficients must be excluded explicitly.
    is_selected = (coefs.notna() & (coefs != 0)).to_numpy()
    selected_indices = []
    for factor_name in coefs.index[is_selected]:  # iterating labels is safe even with duplicate labels
        idx = _factor_name_to_index(factor_name, n_total_factors, patient_id)
        if idx is not None:
            selected_indices.append(idx)

    unique_sorted_indices = sorted(set(selected_indices))
    if len(selected_indices) != len(unique_sorted_indices):
        print(f"Warning: Duplicate indices found for patient {patient_id}, {n_total_factors} factors. Using unique sorted list.")

    print(f"Patient {patient_id}, {n_total_factors} total factors: Using {chosen_1_based_idx}(th) alpha value ({actual_alpha_value_used:.2e}). Selected {len(unique_sorted_indices)} factor indices.")
    return unique_sorted_indices

In [7]:
ALPHAS_LASSO_TRAINING = np.logspace(-4, 5, 20)

CHOSE_REG_STRENGTH_DICT = {
    "P01": 10, 
    "P02": 13, # 14 potentially
    "P03": 17,
    "P04": 12,
    "P05": 11, # 12 potentially
    "P06": 10, # 9 potentially
    "P07": 12, # 9 potentially
    "P09": 14,
    "P13": 10 # 12 potentially
}

In [8]:
print("Loading full AnnData...")
adata_full = sc.read_h5ad(ADATA_FULL_PATH)
print(f"Full AnnData loaded: {adata_full.shape}")

# Fail fast if the obs columns used for patient/timepoint selection and plot
# colouring are absent, rather than raising a KeyError deep inside the main loop.
_required_obs_cols = [PATIENT_COL, TIMEPOINT_COL, CELL_LABEL_COL]
_missing_obs_cols = [col for col in _required_obs_cols if col not in adata_full.obs.columns]
assert not _missing_obs_cols, f"adata_full.obs is missing required columns: {_missing_obs_cols}"

Loading full AnnData...
Full AnnData loaded: (192149, 36601)


In [9]:
# Validate every patient that has a hand-picked Lasso regularization strength.
# For a quick test run, replace with an explicit subset, e.g. ['P01', 'P02'].
unique_patients = list(CHOSE_REG_STRENGTH_DICT)

In [10]:
def _save_and_close(fig, filename):
    """Write `fig` to OUTPUT_FIGURE_DIR/<filename>, then close it so figures do not accumulate across the loop."""
    fig.savefig(os.path.join(OUTPUT_FIGURE_DIR, filename), bbox_inches='tight')
    plt.close(fig)


def _find_second_factor_index(patient_id, n_factors, primary_factor_idx):
    """
    Find a companion factor for a 2D plot when only one factor was selected.

    Scans the saved coefficient columns from the one just before the patient's
    chosen alpha back to column 0. It returns the first non-zero factor (0-based)
    other than `primary_factor_idx`, or None if there is none.
    ALPHAS_LASSO_TRAINING is increasing, so earlier columns are smaller alphas.
    Whether that means weaker regularization depends on how alpha was used in
    training (TODO confirm).
    """
    coef_path = os.path.join(BASE_MODEL_OUTPUT_DIR, f"n_factors_{n_factors}",
                             f"patient_{patient_id}",
                             f"results_coefs_{patient_id}_factors_{n_factors}.csv")
    if not os.path.exists(coef_path):
        print(f"        Coefficients file not found for finding second factor: {coef_path}")
        return None

    coef_df = pd.read_csv(coef_path, index_col=0)
    chosen_1_based_idx = CHOSE_REG_STRENGTH_DICT.get(patient_id)
    if chosen_1_based_idx is None or coef_df.empty:
        print("        Could not determine chosen alpha or coef_df is empty for finding second factor.")
        return None

    chosen_0_based_idx = chosen_1_based_idx - 1
    for j in range(min(chosen_0_based_idx, len(coef_df.columns)) - 1, -1, -1):
        coefs = coef_df.iloc[:, j]
        for factor_name in coefs.index[(coefs.notna() & (coefs != 0)).to_numpy()]:
            name = str(factor_name)
            if not name.startswith("X_fa_"):
                continue
            try:
                idx = int(name.split('_')[-1]) - 1
            except ValueError:
                continue
            if idx != primary_factor_idx and 0 <= idx < n_factors:
                print(f"        Found second factor: {idx} at alpha column {j} (value: {ALPHAS_LASSO_TRAINING[j]:.2e})")
                return idx
    return None


# All figures are saved explicitly into OUTPUT_FIGURE_DIR. scanpy's own `save=`
# writes to sc.settings.figdir (./figures by default) and closes the figure
# before any later title change could take effect.
for n_factors in N_COMPONENTS_LIST_FOR_FA:
    print(f"\n--- Processing for FA models with {n_factors} total factors ---")
    hvg_list, scaler, fa_model = load_mrd_trained_objects(n_factors, BASE_MODEL_OUTPUT_DIR, SAVED_HVG_LIST_PATH, SAVED_SCALER_PATH)

    for patient_id in unique_patients:
        print(f"\n  -- Patient: {patient_id} --")

        selected_factor_indices = get_selected_factor_indices_for_patient(
            patient_id,
            n_factors,
            BASE_MODEL_OUTPUT_DIR,
            CHOSE_REG_STRENGTH_DICT,
            ALPHAS_LASSO_TRAINING,
        )
        if not selected_factor_indices:
            print(f"    No factors selected for patient {patient_id}, {n_factors} factors. Skipping validation timepoints.")
            continue
        print(f"    Using {len(selected_factor_indices)} selected factor indices: {selected_factor_indices[:10]}...")

        for val_tp in VALIDATION_TIMEPOINTS:
            print(f"\n    - Validation Timepoint: {val_tp} -")

            adata_patient_tp_mask = (adata_full.obs[PATIENT_COL] == patient_id) & \
                                    (adata_full.obs[TIMEPOINT_COL] == val_tp)
            adata_validation = adata_full[adata_patient_tp_mask].copy()
            if adata_validation.n_obs == 0:
                print(f"      No data for patient {patient_id} at timepoint {val_tp}. Skipping.")
                continue
            print(f"      Data found: {adata_validation.n_obs} cells, {adata_validation.n_vars} genes.")

            adata_val_projected = preprocess_and_project(adata_validation, hvg_list, scaler, fa_model)
            if adata_val_projected is None or 'X_fa_projected' not in adata_val_projected.obsm:
                print(f"      Projection failed for patient {patient_id}, {val_tp}. Skipping visualization.")
                continue

            X_fa_projected = adata_val_projected.obsm['X_fa_projected']
            try:
                adata_val_projected.obsm['X_fa_selected'] = X_fa_projected[:, selected_factor_indices]
            except IndexError as e:
                print(f"      Error selecting factor indices: {e}. Max index available: {X_fa_projected.shape[1]-1}. Indices wanted: {selected_factor_indices}")
                continue

            # From here on, n_obs > 0 and every selected index is valid for X_fa_projected.
            # Otherwise the checks/slicing above would already have skipped this timepoint.
            n_selected = adata_val_projected.obsm['X_fa_selected'].shape[1]
            title_prefix = f"Patient {patient_id} - {val_tp} ({n_factors} total FA, {n_selected} selected)"
            filename_prefix = f"p{patient_id}_{val_tp}_totalFA{n_factors}_selected{n_selected}"

            if n_selected == 1:
                primary_factor_idx = selected_factor_indices[0]
                print(f"      Only 1 factor initially selected ({primary_factor_idx}). Plotting its violin distribution.")
                violin_key = f"factor_to_plot_{primary_factor_idx}"
                adata_val_projected.obs[violin_key] = X_fa_projected[:, primary_factor_idx]
                fig, ax = plt.subplots()
                sc.pl.violin(adata_val_projected, keys=[violin_key], groupby=CELL_LABEL_COL, use_raw=False,
                             stripplot=True, jitter=0.4, ax=ax, show=False)
                ax.set_title(f"{title_prefix} - Violin Factor {primary_factor_idx}")  # set BEFORE saving
                _save_and_close(fig, f"{filename_prefix}_factor{primary_factor_idx}_violin.png")
                del adata_val_projected.obs[violin_key]

                print(f"      Attempting to find a second factor for 2D plot (primary was {primary_factor_idx})...")
                second_factor_idx = _find_second_factor_index(patient_id, n_factors, primary_factor_idx)
                if second_factor_idx is None:
                    print("      No suitable second factor found for 2D plot.")
                else:
                    factors_for_2d_plot_indices = sorted([primary_factor_idx, second_factor_idx])
                    if factors_for_2d_plot_indices[-1] >= X_fa_projected.shape[1]:
                        print(f"      Error: One or both factor indices for 2D plot ({factors_for_2d_plot_indices}) are out of bounds for X_fa_projected.")
                    else:
                        print(f"      Plotting 2D scatter with Factor {primary_factor_idx} and Factor {second_factor_idx}.")
                        temp_obsm_key_2d = 'X_fa_2factors_temp'
                        adata_val_projected.obsm[temp_obsm_key_2d] = X_fa_projected[:, factors_for_2d_plot_indices]
                        f_lo, f_hi = factors_for_2d_plot_indices
                        fig, ax = plt.subplots()
                        sc.pl.embedding(adata_val_projected, basis=temp_obsm_key_2d, color=CELL_LABEL_COL,
                                        title=f"{title_prefix}\n2D: F{f_lo} & F{f_hi}", ax=ax, show=False)
                        _save_and_close(fig, f"{filename_prefix}_2Dscatter_F{f_lo}_F{f_hi}.png")
                        del adata_val_projected.obsm[temp_obsm_key_2d]

            elif n_selected == 2:
                print("      2 factors selected. Plotting 2D scatter.")
                fig, ax = plt.subplots()
                sc.pl.embedding(adata_val_projected, basis='X_fa_selected', color=CELL_LABEL_COL,
                                title=title_prefix, ax=ax, show=False)
                _save_and_close(fig, f"{filename_prefix}_scatter2D.png")

            elif n_selected == 3:
                print("      3 factors selected. Plotting 3D scatter.")
                # scanpy's embedding plot is 2D only, so draw the 3D scatter with matplotlib.
                colors = {POSITIVE_LABEL: 'red', 'normal': 'blue'}  # unknown labels fall back to grey
                cell_colors = [colors.get(label, 'grey') for label in adata_val_projected.obs[CELL_LABEL_COL]]
                X_plot = adata_val_projected.obsm['X_fa_selected']
                fig = plt.figure()
                ax = fig.add_subplot(111, projection='3d')
                ax.scatter(X_plot[:, 0], X_plot[:, 1], X_plot[:, 2], c=cell_colors, s=5)
                ax.set_xlabel(f"Factor {selected_factor_indices[0]}")
                ax.set_ylabel(f"Factor {selected_factor_indices[1]}")
                ax.set_zlabel(f"Factor {selected_factor_indices[2]}")
                ax.set_title(title_prefix)
                _save_and_close(fig, f"{filename_prefix}_scatter3D.png")

            else:  # n_selected > 3
                print(f"      {n_selected} factors selected. Computing and plotting UMAP based on these factors.")
                if adata_val_projected.n_obs > 1:  # a neighbour graph / UMAP needs at least 2 cells
                    sc.pp.neighbors(adata_val_projected, use_rep='X_fa_selected',
                                    n_neighbors=min(15, adata_val_projected.n_obs - 1))
                    sc.tl.umap(adata_val_projected)
                    fig, ax = plt.subplots()
                    sc.pl.umap(adata_val_projected, color=CELL_LABEL_COL, title=title_prefix, ax=ax, show=False)
                    _save_and_close(fig, f"{filename_prefix}_umap.png")
                else:
                    print("      Not enough cells to compute UMAP.")


print("\n--- Validation Projection and Visualization Script Finished ---")


--- Processing for FA models with 100 total factors ---
Loading trained objects for 100 total factors...
Loaded HVG list (3000 genes), scaler, and FA model.

  -- Patient: P01 --
Patient P01, 100 total factors: Using 10(th) alpha value (1.83e+00). Selected 55 factor indices.
    Using 55 selected factor indices: [3, 4, 7, 8, 11, 14, 16, 19, 22, 24]...

    - Validation Timepoint: preSCT -
      Data found: 21707 cells, 36601 genes.
  Preprocessing: Original shape (21707, 36601)
    Scaling data with 3000 features...
    Projecting scaled data using FA model into 100 factors...
  Projection complete. Shape of X_fa_projected: (21707, 100)
      55 factors selected. Computing and plotting UMAP based on these factors.

    - Validation Timepoint: Relapse -
      Data found: 5115 cells, 36601 genes.
  Preprocessing: Original shape (5115, 36601)
    Scaling data with 3000 features...
    Projecting scaled data using FA model into 100 factors...
  Projection complete. Shape of X_fa_projected

### Comparing with the MRD timepoint

In [8]:
def generate_comparison_plots(
    adata_with_projection, 
    selected_indices_from_mrd_lasso,
    p_id, 
    tp_name, 
    n_total_fa_components,
    coef_dir_base, # Path to where patient-specific coef CSVs are, e.g., BASE_MODEL_OUTPUT_DIR
    reg_strength_map, # CHOSE_REG_STRENGTH_DICT
    alphas_array, # ALPHAS_LASSO_TRAINING
    fig_output_dir
):
    """
    Generates a suite of plots for a given timepoint using FA scores and selected factors.
    adata_with_projection: AnnData object with .obsm['X_fa_projected']
    selected_indices_from_mrd_lasso: List of 0-based factor indices selected from MRD LR-Lasso.
    """
    print(f"      Generating plots for Patient {p_id}, Timepoint: {tp_name}, Total FA: {n_total_fa_components}")
    
    if adata_with_projection.n_obs == 0:
        print("        No cells to visualize. Skipping.")
        return

    # Create a working copy for this function to avoid modifying original adata in main loop
    adata_plot = adata_with_projection.copy()

    # Prepare .obsm['X_fa_selected'] if selected_indices are provided and valid
    n_selected_initially = 0
    if selected_indices_from_mrd_lasso:
        try:
            # Ensure all selected indices are within the bounds of projected factors
            max_projected_idx = adata_plot.obsm['X_fa_projected'].shape[1] - 1
            valid_selected_indices = [idx for idx in selected_indices_from_mrd_lasso if idx <= max_projected_idx]
            
            if len(valid_selected_indices) != len(selected_indices_from_mrd_lasso):
                print(f"        Warning: Some selected factor indices were out of bounds. Original: {len(selected_indices_from_mrd_lasso)}, Valid: {len(valid_selected_indices)}")
            
            if not valid_selected_indices:
                 print(f"        No valid selected factor indices after boundary check. Cannot create 'X_fa_selected'.")
            else:
                adata_plot.obsm['X_fa_selected'] = adata_plot.obsm['X_fa_projected'][:, valid_selected_indices]
                n_selected_initially = adata_plot.obsm['X_fa_selected'].shape[1]
        
        except IndexError as e:
            print(f"        Error creating 'X_fa_selected': {e}. Indices: {selected_indices_from_mrd_lasso}, Max available: {adata_plot.obsm['X_fa_projected'].shape[1]-1}")
            return # Cannot proceed if this fails
        except Exception as e_gen:
            print(f"        Unexpected error creating 'X_fa_selected': {e_gen}")
            return
    else: # No factors were initially selected by LR-Lasso for this patient/n_total_fa
        print(f"        No factors were initially selected from MRD LR-Lasso for Patient {p_id}, Total FA {n_total_fa_components}.")


    title_prefix_base = f"Patient {p_id} - {tp_name} ({n_total_fa_components}FA)"
    filename_prefix_base = f"p{p_id}_{tp_name}_totalFA{n_total_fa_components}"

    # --- Plotting logic based on n_selected_initially ---
    if n_selected_initially == 0:
        print("        No initially selected factors to plot from MRD anlaysis.")
        # Optionally, could plot UMAP of ALL projected factors here if desired
        # sc.pp.neighbors(adata_plot, use_rep='X_fa_projected', n_neighbors=min(15, adata_plot.n_obs-1) if adata_plot.n_obs > 1 else 1)
        # if adata_plot.n_obs > 1:
        #     sc.tl.umap(adata_plot)
        #     sc.pl.umap(adata_plot, color=CELL_LABEL_COL, title=f"{title_prefix_base} - UMAP of ALL Projected Factors", 
        #                 show=False, save=f"_{filename_prefix_base}_UMAPallFA.png")
        #     plt.close()

    elif n_selected_initially == 1:
        primary_factor_idx = selected_indices_from_mrd_lasso[0] # This is the only selected factor
        title_prefix_violin = f"{title_prefix_base} (1sel: F{primary_factor_idx})"
        filename_prefix_violin = f"{filename_prefix_base}_selF{primary_factor_idx}"
        
        print(f"        Plotting violin for primary selected Factor {primary_factor_idx}.")
        temp_obs_col_violin = f"factor_plot_{primary_factor_idx}"
        if primary_factor_idx < adata_plot.obsm['X_fa_projected'].shape[1]:
            adata_plot.obs[temp_obs_col_violin] = adata_plot.obsm['X_fa_projected'][:, primary_factor_idx]
            sc.pl.violin(adata_plot, keys=[temp_obs_col_violin], groupby=CELL_LABEL_COL, use_raw=False,
                         stripplot=True, jitter=0.4, show=False, 
                         save=f"_{filename_prefix_violin}_violin.png")
            plt.title(title_prefix_violin)
            plt.close()
            del adata_plot.obs[temp_obs_col_violin]
        else:
            print(f"        Error: Primary factor index {primary_factor_idx} for violin plot is out of bounds.")

        # Attempt to find a second factor for a 2D plot
        print(f"        Attempting to find a second factor for 2D plot (primary was F{primary_factor_idx})...")
        second_factor_idx = None
        coef_path = os.path.join(coef_dir_base, f"n_factors_{n_total_fa_components}",
                                 f"patient_{p_id}", 
                                 f"results_coefs_{p_id}_factors_{n_total_fa_components}.csv")
        if os.path.exists(coef_path):
            coef_df = pd.read_csv(coef_path, index_col=0)
            chosen_1_based_idx = reg_strength_map.get(p_id)
            if chosen_1_based_idx is not None and not coef_df.empty:
                chosen_0_based_idx = chosen_1_based_idx - 1
                for j in range(chosen_0_based_idx - 1, -1, -1):
                    if j >= len(coef_df.columns): continue
                    current_alpha_col_name = coef_df.columns[j]
                    non_zero_factors = coef_df[coef_df[current_alpha_col_name] != 0].index
                    for factor_name in non_zero_factors:
                        try:
                            if factor_name.startswith("X_fa_"):
                                idx = int(factor_name.split('_')[-1]) - 1
                                if idx != primary_factor_idx and 0 <= idx < n_total_fa_components:
                                    second_factor_idx = idx
                                    print(f"          Found second factor: F{second_factor_idx} at alpha column {j} (value: {alphas_array[j]:.2e})")
                                    break
                        except ValueError: continue
                    if second_factor_idx is not None: break
        
        if second_factor_idx is not None:
            factors_for_2d = sorted([primary_factor_idx, second_factor_idx])
            print(f"        Plotting 2D scatter with Factor {factors_for_2d[0]} and Factor {factors_for_2d[1]}.")
            temp_obsm_key_2d = 'X_fa_temp_2factors'
            if all(idx < adata_plot.obsm['X_fa_projected'].shape[1] for idx in factors_for_2d):
                adata_plot.obsm[temp_obsm_key_2d] = adata_plot.obsm['X_fa_projected'][:, factors_for_2d]
                title_2d = f"{title_prefix_base}\n(F{primary_factor_idx} & 2nd F{second_factor_idx})"
                file_2d = f"{filename_prefix_base}_2Dscatter_F{factors_for_2d[0]}_F{factors_for_2d[1]}.png"
                sc.pl.embedding(adata_plot, basis=temp_obsm_key_2d, color=CELL_LABEL_COL,
                                title=title_2d, show=False, save=file_2d)
                plt.close()
                del adata_plot.obsm[temp_obsm_key_2d]
            else:
                print(f"        Error: Factor indices for 2D plot {factors_for_2d} are out of bounds.")
        else:
            print("        No suitable second factor found for 2D plot.")

    elif n_selected_initially == 2:
        title_scatter2d = f"{title_prefix_base} (2sel: F{selected_indices_from_mrd_lasso[0]},{selected_indices_from_mrd_lasso[1]})"
        filename_scatter2d = f"{filename_prefix_base}_sel2_scatter2D.png"
        print(f"        Plotting 2D scatter for selected factors: {selected_indices_from_mrd_lasso}.")
        sc.pl.embedding(adata_plot, basis='X_fa_selected', color=CELL_LABEL_COL,
                        title=title_scatter2d, show=False, save=filename_scatter2d)
        plt.close()

    elif n_selected_initially == 3:
        title_scatter3d = f"{title_prefix_base} (3sel: F{','.join(map(str,selected_indices_from_mrd_lasso))})"
        filename_scatter3d = f"{filename_prefix_base}_sel3_scatter3D.png"
        print(f"        Plotting 3D scatter for selected factors: {selected_indices_from_mrd_lasso}.")
        fig = plt.figure()
        ax = fig.add_subplot(111, projection='3d')
        X_plot_3d = adata_plot.obsm['X_fa_selected']
        colors_map = {'cancer': 'red', 'normal': 'blue'}
        cell_plot_colors = [colors_map.get(label, 'grey') for label in adata_plot.obs[CELL_LABEL_COL]]
        ax.scatter(X_plot_3d[:, 0], X_plot_3d[:, 1], X_plot_3d[:, 2], c=cell_plot_colors, s=10, alpha=0.7)
        ax.set_xlabel(f"Factor {selected_indices_from_mrd_lasso[0]}")
        ax.set_ylabel(f"Factor {selected_indices_from_mrd_lasso[1]}")
        ax.set_zlabel(f"Factor {selected_indices_from_mrd_lasso[2]}")
        plt.title(title_scatter3d, fontsize=10)
        fig.tight_layout()
        plt.savefig(os.path.join(fig_output_dir, filename_scatter3d), dpi=150)
        plt.close(fig)

    elif n_selected_initially > 3:
        title_umap = f"{title_prefix_base} ({n_selected_initially}sel UMAP)"
        filename_umap = f"{filename_prefix_base}_sel{n_selected_initially}_UMAP.png"
        print(f"        Plotting UMAP for {n_selected_initially} selected factors.")
        if adata_plot.n_obs > 1:
            n_neighbors_val = min(15, adata_plot.n_obs - 1) if adata_plot.n_obs > 15 else max(1, adata_plot.n_obs - 1)
            if n_neighbors_val > 0:
                sc.pp.neighbors(adata_plot, use_rep='X_fa_selected', n_neighbors=n_neighbors_val)
                sc.tl.umap(adata_plot)
                sc.pl.umap(adata_plot, color=CELL_LABEL_COL, title=title_umap,
                            show=False, save=filename_umap)
                plt.close()
            else: print("        Not enough cells/neighbors for UMAP.")
        else: print("        Not enough cells for UMAP.")
    # Ensure OUTPUT_FIGURE_DIR is used by scanpy save parameter or wrap with plt.savefig
    # Scanpy's save= parameter will save to a ./figures/ subdirectory by default if only a filename is given.
    # To save to your specific OUTPUT_FIGURE_DIR, you might need to prepend it or use plt.savefig.
    # Example for plt.savefig:
    # if save:
    #     if isinstance(save, str):
    #         # Ensure the figure directory exists
    #         figdir = os.path.dirname(os.path.join(OUTPUT_FIGURE_DIR, save_filename_with_prefix))
    #         if not os.path.exists(figdir): os.makedirs(figdir)
    #         plt.savefig(os.path.join(OUTPUT_FIGURE_DIR, save_filename_with_prefix), dpi=150)
    # For sc.pl.save, if save is f"_{filename_prefix_base}_UMAPallFA.png", it saves to "./figures/_pPATIENT_TP_... .png"
    # To control this: either set sc.settings.figdir = OUTPUT_FIGURE_DIR globally, or use plt.savefig after sc.pl calls with show=False, return_fig=True.
    # The current implementation relies on scanpy's default save behavior for simplicity, assuming you run from a location where ./figures/ is acceptable.

In [9]:
# Preprocessed MRD AnnData (HVG-selected and standardized), read by the
# validation/plotting cell below as `adata_mrd_preprocessed_global`.
PREPROCESSED_ADATA_PATH = (
    '/home/minhang/mds_project/data/cohort_adata/multiVI_model/'
    'adata_mrd_hvg_std_may31.h5ad'
)

In [10]:
print("Starting validation and MRD comparison plotting script...")

# generate_comparison_plots saves its scanpy figures via bare `save=` filenames,
# which scanpy writes under `sc.settings.figdir` (default: ./figures/), while its
# matplotlib 3D scatter is written straight into OUTPUT_FIGURE_DIR. Pointing
# scanpy's figdir at OUTPUT_FIGURE_DIR keeps every figure of a run in one place.
sc.settings.figdir = OUTPUT_FIGURE_DIR

# Load the full AnnData (source of the preSCT/Relapse cells) and the preprocessed
# MRD AnnData (already HVG-selected and standardized).
adata_full_timepoints = sc.read_h5ad(ADATA_FULL_PATH)
print(f"Full timepoints AnnData loaded: {adata_full_timepoints.shape}")

adata_mrd_preprocessed_global = sc.read_h5ad(PREPROCESSED_ADATA_PATH)  # HVG-selected & standardized
print(f"Global preprocessed MRD AnnData loaded: {adata_mrd_preprocessed_global.shape}")

# Patients to plot (hard-coded cohort list).
unique_patients = ['P01', 'P02', 'P03', 'P04', 'P05', 'P06', 'P07', 'P09', 'P13']

TIMEPOINTS_FOR_PLOTS = ['MRD'] + VALIDATION_TIMEPOINTS  # e.g., ['MRD', 'preSCT', 'Relapse']

for n_fa_components in N_COMPONENTS_LIST_FOR_FA:
    print(f"\n=== Processing for FA models with {n_fa_components} total components ===")

    # HVG list and scaler are needed to preprocess preSCT/Relapse cells; the FA
    # model (trained on all MRD cells) projects every timepoint.
    hvg_list_mrd, scaler_mrd, fa_model_mrd_all = load_mrd_trained_objects(
        n_fa_components, BASE_MODEL_OUTPUT_DIR, SAVED_HVG_LIST_PATH, SAVED_SCALER_PATH
    )

    # Project ALL MRD cells once per n_fa_components; per-patient MRD plots then
    # just slice this projection.
    print(f"  Projecting all MRD cells ({adata_mrd_preprocessed_global.n_obs}) using FA model for {n_fa_components} components...")
    if adata_mrd_preprocessed_global.n_vars != fa_model_mrd_all.n_features_in_:
        print(f"    FATAL ERROR: Number of features in global MRD data ({adata_mrd_preprocessed_global.n_vars}) "
              f"does not match FA model's expected input features ({fa_model_mrd_all.n_features_in_}). "
              f"Ensure HVG list used for FA model training matches current global MRD AnnData's var_names exactly.")
        continue  # Skip this n_fa_components if features mismatch

    # FA projection is column-order sensitive, and preSCT/Relapse cells are
    # HVG-subset using hvg_list_mrd. If the MRD AnnData's columns are not in that
    # same order, MRD and validation projections would not be comparable, even
    # though the feature counts agree.
    if list(adata_mrd_preprocessed_global.var_names) != list(hvg_list_mrd):
        print("    WARNING: var_names of the global MRD AnnData do not match the saved HVG list "
              "(content or order). MRD and validation projections may not be comparable.")

    X_fa_mrd_all_cells = fa_model_mrd_all.transform(adata_mrd_preprocessed_global.X)
    adata_mrd_with_all_projections = adata_mrd_preprocessed_global.copy()
    adata_mrd_with_all_projections.obsm['X_fa_projected'] = X_fa_mrd_all_cells
    print(f"  Projection of all MRD cells complete. Shape of X_fa_projected: {X_fa_mrd_all_cells.shape}")

    for patient_id_str in unique_patients:
        print(f"\n  -- Patient: {patient_id_str} --")

        # Factors selected by LR-Lasso on this patient's MRD data; the same
        # indices are reused for every timepoint so plots are directly comparable.
        mrd_selected_indices = get_selected_factor_indices_for_patient(
            patient_id_str,
            n_fa_components,
            BASE_MODEL_OUTPUT_DIR,
            CHOSE_REG_STRENGTH_DICT,
            ALPHAS_LASSO_TRAINING
        )

        if not mrd_selected_indices:
            print(f"    No factors selected from MRD LR-Lasso for Patient {patient_id_str} with {n_fa_components} total factors. Skipping all timepoint plots for this combination.")
            continue
        print(f"    Using {len(mrd_selected_indices)} factors selected from MRD analysis: {mrd_selected_indices[:10]}...")

        for tp_name_str in TIMEPOINTS_FOR_PLOTS:
            if tp_name_str == 'MRD':
                # Slice the globally projected MRD data for this patient;
                # .obsm['X_fa_projected'] is already present.
                patient_mrd_mask = adata_mrd_with_all_projections.obs[PATIENT_COL] == patient_id_str
                if not patient_mrd_mask.any():
                    print(f"    No MRD data found for Patient {patient_id_str} in the global projected MRD AnnData. Skipping MRD plot.")
                    continue
                adata_for_current_plot = adata_mrd_with_all_projections[patient_mrd_mask].copy()

            else:  # For 'preSCT' or 'Relapse'
                validation_mask = (adata_full_timepoints.obs[PATIENT_COL] == patient_id_str) & \
                                  (adata_full_timepoints.obs[TIMEPOINT_COL] == tp_name_str)
                if not validation_mask.any():
                    print(f"    No {tp_name_str} data found for Patient {patient_id_str}. Skipping {tp_name_str} plot.")
                    continue

                adata_validation_raw_subset = adata_full_timepoints[validation_mask].copy()

                # Preprocess (HVG selection based on MRD list, scale with MRD scaler) and project with MRD FA model
                adata_for_current_plot = preprocess_and_project(
                    adata_validation_raw_subset, hvg_list_mrd, scaler_mrd, fa_model_mrd_all
                )

                if adata_for_current_plot is None or 'X_fa_projected' not in adata_for_current_plot.obsm:
                    print(f"    Projection failed for Patient {patient_id_str}, Timepoint {tp_name_str}. Skipping plot.")
                    continue

            # Both branches either `continue` or leave a non-None AnnData with
            # .obsm['X_fa_projected'], so only emptiness remains to check.
            if adata_for_current_plot.n_obs > 0:
                generate_comparison_plots(
                    adata_for_current_plot,
                    mrd_selected_indices,  # These are the fixed indices from patient's MRD Lasso
                    patient_id_str,
                    tp_name_str,
                    n_fa_components,
                    BASE_MODEL_OUTPUT_DIR,
                    CHOSE_REG_STRENGTH_DICT,
                    ALPHAS_LASSO_TRAINING,
                    OUTPUT_FIGURE_DIR
                )
            else:
                print(f"    No cells available for Patient {patient_id_str} at Timepoint {tp_name_str} after processing. Skipping plot.")

print("\n--- Script finished ---")

Starting validation and MRD comparison plotting script...
Full timepoints AnnData loaded: (192149, 36601)
Global preprocessed MRD AnnData loaded: (69801, 3000)

=== Processing for FA models with 100 total components ===
Loading trained objects for 100 total factors...
Loaded HVG list (3000 genes), scaler, and FA model.
  Projecting all MRD cells (69801) using FA model for 100 components...
  Projection of all MRD cells complete. Shape of X_fa_projected: (69801, 100)

  -- Patient: P01 --
Patient P01, 100 total factors: Using 10(th) alpha value (1.83e+00). Selected 55 factor indices.
    Using 55 factors selected from MRD analysis: [3, 4, 7, 8, 11, 14, 16, 19, 22, 24]...
      Generating plots for Patient P01, Timepoint: MRD, Total FA: 100
        Plotting UMAP for 55 selected factors.
  Preprocessing: Original shape (21707, 36601)
    Scaling data with 3000 features...
    Projecting scaled data using FA model into 100 factors...
  Projection complete. Shape of X_fa_projected: (21707, 