In [3]:
import os
import pickle
import numpy as np
import pandas as pd
import scanpy as sc
import matplotlib.pyplot as plt
from pathlib import Path
from anndata import AnnData
import sys

# Ensure the parent directory is in the system path
sys.path.append('/home/minhang/mds_project/sc_classification')
from utils.experiment_manager import ExperimentManager
from utils.experiment_analysis import ExperimentAnalyzer

In [4]:
experiment_id = "20250714_205422_fa_100_random_6dbbde08"
experiments_dir = '/home/minhang/mds_project/sc_classification/experiments/'

experiment_manager = ExperimentManager(experiments_dir)
analyzer = ExperimentAnalyzer(experiment_manager)

In [5]:
# Path to the full anndata object containing all patients and timepoints
ADATA_FULL_PATH = '/home/minhang/mds_project/data/cohort_adata/multiVI_model/adata_multivi_corrected_rna.h5ad'

# Column names in your adata.obs
PATIENT_COL = 'patient'
TIMEPOINT_COL = 'timepoint_type'
CELL_LABEL_COL = 'CN.label'

# Timepoints to project onto
VALIDATION_TIMEPOINTS = ['preSCT', 'Relapse']

In [7]:
patient_indices_to_analyze = {
    "P01": [12, 13, 14],
    "P02": [14, 15, 16],
    "P03": [17], 
    "P04": [11, 12, 13, 14, 15],
    "P05": [9, 10, 11, 12, 13],
    "P06": [9, 10, 11, 12, 13, 14], 
    "P07": [10, 11, 12, 13],
    "P09": [11, 12, 13, 14, 15, 16, 17],
    "P13": [10, 11, 12, 13]
}

exp = experiment_manager.load_experiment(experiment_id)

Loading experiment from: /home/minhang/mds_project/sc_classification/experiments/20250714_205422_fa_100_random_6dbbde08


In [8]:
# --- Define Output Path using Experiment Object ---
# Get the base 'projections' directory from the experiment object
base_projections_dir = exp.get_path('projections')
# Create a specific subdirectory for this analysis run
OUTPUT_FIGURE_DIR = base_projections_dir / "projection_analysis_by_index"
OUTPUT_FIGURE_DIR.mkdir(exist_ok=True, parents=True)
print(f"Figures will be saved to: {OUTPUT_FIGURE_DIR}")

Figures will be saved to: /home/minhang/mds_project/sc_classification/experiments/20250714_205422_fa_100_random_6dbbde08/analysis/projections/projection_analysis_by_index


In [10]:
# 1. Load the projection environment once
print(f"\n--- Preparing projection environment for experiment: {experiment_id} ---")
proj_env = analyzer.prepare_projection_environment(experiment_id)
if not proj_env:
    raise RuntimeError("Failed to prepare projection environment. Aborting.")


--- Preparing projection environment for experiment: 20250714_205422_fa_100_random_6dbbde08 ---
Loading experiment from: /home/minhang/mds_project/sc_classification/experiments/20250714_205422_fa_100_random_6dbbde08
Loading experiment from: /home/minhang/mds_project/sc_classification/experiments/20250714_205422_fa_100_random_6dbbde08


In [13]:
print(proj_env.keys())
model = proj_env['model']
scaler = proj_env['scaler']
hvg_list = proj_env['hvg_list']
coefficients_dict = proj_env['coefficients']
config = proj_env['config']
n_total_factors = config.get('dimension_reduction.n_components')

dict_keys(['model', 'scaler', 'hvg_list', 'coefficients', 'config'])


In [15]:
# 2. Load the full AnnData object
print(f"--- Loading validation data from {ADATA_FULL_PATH} ---")
adata_full = sc.read_h5ad(ADATA_FULL_PATH)

--- Loading validation data from /home/minhang/mds_project/data/cohort_adata/multiVI_model/adata_multivi_corrected_rna.h5ad ---


In [16]:
# --- Main Analysis Loop ---
for patient_id, alpha_indices in patient_indices_to_analyze.items():
    print(f"\n{'='*20} Processing Patient: {patient_id} {'='*20}")
    
    if patient_id not in coefficients_dict:
        print(f"  Warning: Patient {patient_id} not found in the experiment's classification results. Skipping.")
        continue
        
    patient_coefs_df = coefficients_dict[patient_id]

    for alpha_index in alpha_indices:
        print(f"\n  --- Analyzing Alpha Index: {alpha_index} ---")
        
        # 4. Select Active Factors for this index
        alpha_idx_0based = alpha_index - 1
        if not (0 <= alpha_idx_0based < len(patient_coefs_df.columns)):
            print(f"    Error: Alpha index {alpha_index} is out of bounds for patient {patient_id}. Skipping.")
            continue
            
        target_col_name = patient_coefs_df.columns[alpha_idx_0based]
        coefs_at_alpha = patient_coefs_df[target_col_name]
        selected_factor_indices = np.where(coefs_at_alpha != 0)[0]
        n_selected = len(selected_factor_indices)

        print(f"    Alpha column: '{target_col_name}', Selected {n_selected} factors.")

        if n_selected == 0:
            print("    No factors selected at this regularization strength. Skipping timepoints.")
            continue

        # 5. Project Data for Each Timepoint
        for timepoint in VALIDATION_TIMEPOINTS:
            print(f"    - Projecting onto timepoint: {timepoint} -")
            
            mask = (adata_full.obs[PATIENT_COL] == patient_id) & (adata_full.obs[TIMEPOINT_COL] == timepoint)
            if not mask.any():
                print(f"      No cells found for patient {patient_id} at timepoint {timepoint}.")
                continue
            
            adata_subset = adata_full[mask].copy()
            adata_projected = analyzer.project_new_data(adata_subset, model, hvg_list, scaler)

            if adata_projected is None:
                print("      Projection failed. Skipping visualization.")
                continue

            # 6. Visualize the projection
            title_prefix = f"Patient {patient_id} - {timepoint}\n(Total: {n_total_factors} Factors, Selected: {n_selected} at index {alpha_index})"
            filename_prefix = f"patient_{patient_id}_tp_{timepoint}_totalFA{n_total_factors}_alphaidx{alpha_index}"
            
            adata_projected.obsm['X_selected'] = adata_projected.obsm['X_projected'][:, selected_factor_indices]
            
            if n_selected >= 2:
                print(f"      {n_selected} factors selected. Computing and plotting UMAP.")
                if adata_projected.n_obs > 1:
                    n_neighbors = min(15, adata_projected.n_obs - 1)
                    sc.pp.neighbors(adata_projected, use_rep='X_selected', n_neighbors=n_neighbors)
                    sc.tl.umap(adata_projected)
                    
                    fig = sc.pl.umap(adata_projected, color=CELL_LABEL_COL, 
                                   title=title_prefix, show=False, return_fig=True)
                    fig.savefig(OUTPUT_FIGURE_DIR / f"{filename_prefix}_umap.png", dpi=300, bbox_inches='tight')
                    plt.close(fig)
                else:
                    print(f"      Not enough cells ({adata_projected.n_obs}) to compute UMAP.")

            elif n_selected == 1:
                print(f"      Only 1 factor selected. Plotting violin distribution.")
                factor_name = f"Factor_{selected_factor_indices[0]+1}"
                adata_projected.obs[factor_name] = adata_projected.obsm['X_selected'][:, 0]
                
                fig, ax = plt.subplots()
                sc.pl.violin(adata_projected, keys=[factor_name], groupby=CELL_LABEL_COL, ax=ax, show=False, stripplot=True, jitter=0.4)
                ax.set_title(title_prefix)
                fig.savefig(OUTPUT_FIGURE_DIR / f"{filename_prefix}_violin.png", dpi=300, bbox_inches='tight')
                plt.close(fig)

print(f"\n--- Projection analysis complete. Figures saved to: {OUTPUT_FIGURE_DIR} ---")



  --- Analyzing Alpha Index: 12 ---
    Alpha column: 'alpha_1.62e+01', Selected 31 factors.
    - Projecting onto timepoint: preSCT -
Projecting 21707 cells onto 3000 HVGs...
Projection complete.
      31 factors selected. Computing and plotting UMAP.
    - Projecting onto timepoint: Relapse -
Projecting 5115 cells onto 3000 HVGs...
Projection complete.
      31 factors selected. Computing and plotting UMAP.

  --- Analyzing Alpha Index: 13 ---
    Alpha column: 'alpha_4.83e+01', Selected 13 factors.
    - Projecting onto timepoint: preSCT -
Projecting 21707 cells onto 3000 HVGs...
Projection complete.
      13 factors selected. Computing and plotting UMAP.
    - Projecting onto timepoint: Relapse -
Projecting 5115 cells onto 3000 HVGs...
Projection complete.
      13 factors selected. Computing and plotting UMAP.

  --- Analyzing Alpha Index: 14 ---
    Alpha column: 'alpha_1.44e+02', Selected 3 factors.
    - Projecting onto timepoint: preSCT -
Projecting 21707 cells onto 3000 HVG