## GSEA Analysis Pipeline for Dimensionality Reduction Models

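This notebook performs Gene Set Enrichment Analysis (GSEA) on weight matrices extracted from several dimensionality reduction models (PCA, ICA, NMF, VanillaVAE, BetaVAE, and BetaTCVAE). It iterates over the saved models for every model type and latent dimension, extracts each model's gene-by-component weight matrix, and scores every component against a gene-set library (CORUM by default) with blitzGSEA. The results are combined into a single Parquet file (`gsea_results/combined_z_matrix_gsea_results.parquet`) for downstream analysis; models already present in that file are skipped when the notebook is re-run.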

In [1]:
import joblib
import pandas as pd
import blitzgsea as blitz
import random
import pathlib
import sys

script_directory = pathlib.Path("../2.train-VAE/utils/").resolve()
sys.path.insert(0, str(script_directory))
from betavae import BetaVAE, weights
from betatcvae import BetaTCVAE, tc_weights
from vanillavae import VanillaVAE, vanilla_weights

script_directory = pathlib.Path("../utils/").resolve()
sys.path.insert(0, str(script_directory))
from data_loader import load_train_test_data, load_model_data

In [2]:
# Load the zero-one normalized data splits and the CRISPR gene-effect tables.
data_directory = pathlib.Path("../0.data-download/data").resolve()

# The fourth return value is requested via load_gene_stats=True; the variable
# keeps the kwarg's name for compatibility with later cells.
train_df, test_df, val_df, load_gene_stats = load_train_test_data(
    data_directory,
    train_or_test="all",
    load_gene_stats=True,
    zero_one_normalize=True,
)
train_data = pd.DataFrame(train_df)

# Gene-effect matrix plus the gene dictionary (used below for QC filtering).
dependency_file = (data_directory / "CRISPRGeneEffect.parquet").resolve()
gene_dict_file = (data_directory / "CRISPR_gene_dictionary.parquet").resolve()
dependency_df, gene_dict_df = load_model_data(dependency_file, gene_dict_file)
gene_dict_df = pd.DataFrame(gene_dict_df)

(1150, 18444)


In [3]:
# Reload the training split without the zero_one_normalize flag (the values
# shown below are raw, partly negative gene effects); this frame is what the
# VAE weight helpers receive via extract_weights.
data_directory = pathlib.Path("../0.data-download/data").resolve()
weight_df = load_train_test_data(data_directory, train_or_test="train")

# Restrict to genes that passed QC, keeping the gene-dictionary order.
qc_mask = gene_dict_df["qc_pass"]
gene_list_passed_qc = gene_dict_df.loc[qc_mask, "dependency_column"].tolist()

weight_data = weight_df.filter(items=gene_list_passed_qc, axis=1)
weight_data.head()

Unnamed: 0,TXNL4A (10907),WASHC5 (9897),ATP6V0B (533),EXOC1 (55763),CDC45 (8318),ELP1 (8518),CSTF3 (1479),TCF7L2 (6934),GTF2F2 (2963),CAND1 (55832),...,PDCD5 (9141),ATP13A1 (57130),FRS2 (10818),ZNF217 (7764),MAFK (7975),PAPOLA (10914),PUF60 (22827),XPO7 (23039),DGKD (8527),TBC1D20 (128637)
0,-2.227724,-0.078573,-0.731835,-0.521454,-1.844994,-1.144493,-1.181554,0.202279,-0.891127,-0.417594,...,-0.838127,-0.153616,-0.279691,0.099706,-0.212187,-0.252796,-3.026932,-0.004705,-0.044785,0.248787
1,-2.439884,-0.259433,-1.345063,-0.734227,-2.124592,-0.781735,-1.328123,-0.038707,-0.914743,-0.149132,...,-0.138306,-0.226299,-0.203144,0.378366,0.052225,-0.09559,-2.578793,-0.023021,-0.31785,-0.189224
2,-2.5482,-0.437635,-1.386458,-0.831147,-2.495377,-0.835291,-1.086402,-0.370876,-0.719334,-0.368493,...,-0.701827,-0.40028,-0.140989,-0.426595,-0.178327,-0.092904,-2.831752,0.440047,-0.259041,-0.039496
3,-1.961683,-0.129761,-1.18547,-0.509749,-2.191305,-1.015713,-1.229495,0.099056,-0.729127,-0.396442,...,-0.23396,-0.188678,-0.218935,0.186318,0.001882,-0.362151,-2.010449,0.072257,-0.077454,-0.095381
4,-2.55907,-0.38281,-1.526952,-0.533276,-2.486223,-0.456399,-0.870226,-0.055091,-0.653224,-0.333798,...,-0.791001,-0.110805,-0.204235,-0.335853,-0.151168,-0.044894,-2.348915,-0.168605,0.098183,-0.079004


In [4]:
def extract_weights(
    model: object,
    model_name: str,
    weight_data: pd.DataFrame = None,
    gene_names: list = None,
) -> pd.DataFrame:
    """
    Extracts a gene-by-component weight matrix from a fitted model.

    For 'pca', 'ica' and 'nmf' the matrix is ``model.components_`` transposed, with
    component columns labelled "0", "1", ...; for the VAE models it is whatever the
    matching helper (``weights``, ``tc_weights`` or ``vanilla_weights``) returns,
    with duplicated column labels dropped and the first column named 'genes'.

    Args:
        model (object): A fitted model (e.g., PCA, ICA, NMF, or a VAE).
        model_name (str): One of 'pca', 'ica', 'nmf', 'betavae', 'betatcvae',
            'vanillavae' (case-sensitive).
        weight_data (pd.DataFrame, optional): Data passed to the VAE weight helpers;
            unused for 'pca', 'ica' and 'nmf'.
        gene_names (list, optional): Gene labels for the columns of
            ``model.components_`` ('pca', 'ica', 'nmf' only). Defaults to every
            column of the notebook-level ``dependency_df`` except 'ModelID'.

    Returns:
        pd.DataFrame: A 'genes' column followed by one weight column per component,
            with a fresh RangeIndex.

    Raises:
        ValueError: If ``model_name`` is unsupported, or if the number of gene labels
            does not match the number of columns of ``model.components_``.
    """
    if model_name in ("pca", "ica", "nmf"):
        if gene_names is None:
            # NOTE(review): assumes the sklearn-style models were fit on exactly these
            # gene-effect columns, in this order — confirm against the training code.
            gene_names = dependency_df.drop(columns=["ModelID"]).columns.tolist()
        components = pd.DataFrame(model.components_)  # one row per component
        if components.shape[1] != len(gene_names):
            raise ValueError(
                f"{model_name} model has {components.shape[1]} feature columns but "
                f"{len(gene_names)} gene names were supplied"
            )
        weights_df = components.T
        weights_df.index = list(gene_names)
        weights_df.columns = [str(i) for i in range(weights_df.shape[1])]
        weights_df = weights_df.rename_axis("genes").reset_index()
    elif model_name in ("betavae", "betatcvae", "vanillavae"):
        if model_name == "betavae":
            weights_df = weights(model, weight_data)
        elif model_name == "betatcvae":
            weights_df = tc_weights(model, weight_data)
        else:
            weights_df = vanilla_weights(model, weight_data)

        # Keep only the first occurrence of any duplicated column label.
        weights_df = weights_df.loc[:, ~weights_df.columns.duplicated()]

        # The first column presumably holds gene labels; normalise its name to 'genes'
        # (non-inplace, so we never mutate a frame derived from a .loc slice).
        if weights_df.columns[0] != "genes":
            weights_df = weights_df.rename(columns={weights_df.columns[0]: "genes"})

        weights_df = weights_df.reset_index(drop=True)
    else:
        raise ValueError(f"Unsupported model type: {model_name}")

    return weights_df

# Gene-set libraries already fetched in this kernel session, keyed by library name.
# blitz.enrichr.get_library presumably downloads from Enrichr (TODO confirm), so it
# is called once per library instead of once per model.
_GSEA_LIBRARY_CACHE = {}

# Column order of the frame returned by perform_gsea (and of the saved parquet).
# NOTE: "reactome_pathway" holds the term of whichever library is used (CORUM by
# default); the name is kept for compatibility with previously saved results.
_GSEA_RESULT_COLUMNS = [
    "z",
    "full_model_z",
    "init",
    "modelseed",
    "model",
    "reactome_pathway",
    "gsea_es_score",
    "nes_score",
    "p_value",
    "shuffled",
]


def _load_gsea_library(lib: str):
    """Return the blitzGSEA gene-set library ``lib``, fetching it at most once per session."""
    if lib not in _GSEA_LIBRARY_CACHE:
        _GSEA_LIBRARY_CACHE[lib] = blitz.enrichr.get_library(lib)
    return _GSEA_LIBRARY_CACHE[lib]


def _component_results(
    gsea_result: pd.DataFrame,
    component,
    model_name: str,
    num_components: int,
    init: int,
    modelseed: int,
) -> pd.DataFrame:
    """Convert one blitz.gsea result table into rows of the perform_gsea output schema."""
    terms = gsea_result.reset_index()  # the 'Term' index becomes a column
    return pd.DataFrame(
        {
            "z": int(component),
            "full_model_z": num_components,
            "init": int(init),
            "modelseed": int(modelseed),
            "model": str(model_name),
            "reactome_pathway": terms["Term"].astype(str),
            "gsea_es_score": terms["es"],
            "nes_score": terms["nes"],
            "p_value": terms["pval"],
            "shuffled": False,
        },
        columns=_GSEA_RESULT_COLUMNS,
    )


def perform_gsea(
    weights_df: pd.DataFrame,
    model_name: str,
    num_components: int,
    init: int,
    modelseed: int,
    lib: str = "CORUM",
    library=None,
) -> pd.DataFrame:
    """
    Performs Gene Set Enrichment Analysis (GSEA) on every component of a weight matrix.

    Each non-'genes' column of ``weights_df`` is paired with the 'genes' column and
    scored with ``blitz.gsea``. Components for which blitz.gsea raises
    ZeroDivisionError are reported and skipped.

    Args:
        weights_df (pd.DataFrame): 'genes' column first, then one weight column per
            component; component labels must be convertible with ``int()``.
        model_name (str): Name of the model being analyzed (stored as 'model').
        num_components (int): Latent dimensionality of the model ('full_model_z').
        init (int): Initialization index of the model ('init').
        modelseed (int): Seed identifying the trained model ('modelseed').
        lib (str): Enrichr library name used when ``library`` is None (default 'CORUM').
        library (optional): Pre-loaded gene-set library passed straight to blitz.gsea.
            When None, ``lib`` is fetched via blitz.enrichr.get_library and cached.

    Returns:
        pd.DataFrame: One row per (component, gene-set term) with columns z,
            full_model_z, init, modelseed, model, reactome_pathway, gsea_es_score,
            nes_score, p_value, shuffled. Empty (with those columns) if nothing
            could be scored.
    """
    if library is None:
        library = _load_gsea_library(lib)

    # Same value that random.seed(0); random.random() yields (~0.8444), so results
    # stay comparable with earlier runs, but drawn from a private generator so the
    # notebook's global `random` state is not reset on every call.
    gsea_seed = random.Random(0).random()

    per_component = []
    if not weights_df.empty:
        for col in weights_df.columns[1:]:  # column 0 is 'genes'
            gene_signature = weights_df[["genes", col]]
            try:
                gsea_result = blitz.gsea(gene_signature, library, seed=gsea_seed)
            except ZeroDivisionError:
                print(f"Skipping GSEA for {col} due to zero division error.")
                continue
            per_component.append(
                _component_results(
                    gsea_result, col, model_name, num_components, init, modelseed
                )
            )

    if not per_component:
        return pd.DataFrame(columns=_GSEA_RESULT_COLUMNS)
    return pd.concat(per_component, ignore_index=True)

# Input models are read from `model_save_dir`; GSEA output lives in `output_dir`.
model_save_dir = pathlib.Path("saved_models")
output_dir = pathlib.Path("gsea_results")
output_dir.mkdir(parents=True, exist_ok=True)

# Resume support: start from any previously saved results so the loop below can
# skip models that were already analysed.
final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
if final_output_file.exists():
    combined_results_df = pd.read_parquet(final_output_file)
    print(f"Loaded existing results from {final_output_file}")
else:
    combined_results_df = pd.DataFrame()
    print("No existing file found. Initialized empty DataFrame.")


# Run GSEA for every saved model not yet present in `combined_results_df`.
# A model is identified by (model name, init, latent dimensions); combinations
# already in the loaded results are skipped, so this cell can be re-run to resume
# an interrupted job. New results are collected in a list and concatenated once at
# the end, instead of re-copying the full (large) results frame for every model.
if combined_results_df.empty:
    processed_keys = set()
else:
    processed_keys = set(
        zip(
            combined_results_df["model"],
            combined_results_df["init"],
            combined_results_df["full_model_z"],
        )
    )

new_results = []
# sorted() makes the processing order (and the row order of new results)
# independent of the filesystem's listing order.
for model_file in sorted(model_save_dir.glob("*.joblib")):
    parts = model_file.stem.split("_")
    try:
        # NOTE(review): fields are read by position from the underscore-separated
        # file stem: model name at index 0, latent dimensions at 3, init at 7 and
        # model seed at 9. The naming template is defined by the training code —
        # confirm these indices there before renaming any saved models.
        model_name = parts[0]
        num_components = int(parts[3])
        init = int(parts[7])
        seed = int(parts[9])
    except (IndexError, ValueError):
        print(f"Skipping file {model_file} due to unexpected filename format.")
        continue

    key = (model_name, init, num_components)
    if key in processed_keys:
        print(f"Skipping {model_name} init {init} with {num_components} dimensions as it is already processed.")
        continue

    print(f"Loading model from {model_file}")
    try:
        # joblib.load unpickles the file, which can execute arbitrary code:
        # only point model_save_dir at trusted, locally trained models.
        model = joblib.load(model_file)
    except Exception as e:
        print(f"Failed to load model from {model_file}: {e}")
        continue

    try:
        weight_matrix_df = extract_weights(model, model_name, weight_data)
    except ValueError as e:
        # e.g. a model type extract_weights does not support: skip this file
        # instead of aborting the whole batch.
        print(f"Skipping {model_file}: could not extract weights ({e})")
        continue

    gsea_results_df = perform_gsea(weight_matrix_df, model_name, num_components, init, seed)
    if not gsea_results_df.empty:
        new_results.append(gsea_results_df)
        # Only keys that produced rows count as processed, matching the check
        # against the loaded results at the top of this cell.
        processed_keys.add(key)

if new_results:
    combined_results_df = pd.concat([combined_results_df, *new_results], ignore_index=True)
print(f"Ran GSEA for {len(new_results)} new model file(s).")
            

Loaded existing results from gsea_results/combined_z_matrix_gsea_results.parquet
Skipping betatcvae init 2 with 5 dimensions as it is already processed.
Skipping betavae init 4 with 14 dimensions as it is already processed.
Skipping betatcvae init 1 with 20 dimensions as it is already processed.
Skipping betavae init 4 with 3 dimensions as it is already processed.
Skipping vanillavae init 1 with 150 dimensions as it is already processed.
Skipping pca init 0 with 35 dimensions as it is already processed.
Skipping ica init 0 with 50 dimensions as it is already processed.
Skipping betavae init 3 with 16 dimensions as it is already processed.
Skipping betatcvae init 1 with 150 dimensions as it is already processed.
Skipping betavae init 4 with 100 dimensions as it is already processed.
Skipping vanillavae init 4 with 12 dimensions as it is already processed.
Skipping betatcvae init 0 with 35 dimensions as it is already processed.
Skipping betatcvae init 3 with 2 dimensions as it is already

In [5]:
# Persist the combined results. The setup cell reloads this same path on the
# next run, so already-processed models are skipped.
final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
combined_results_df.to_parquet(final_output_file, index=False)

print(
    "Saved final z_matrix combining all models and latent dimensions and GSEA results to "
    f"{final_output_file}"
)


Saved final filtered z_matrix and GSEA results to gsea_results/combined_z_matrix_gsea_results.parquet


In [6]:
# Top 50 component/gene-set hits ranked by absolute enrichment score, so strong
# positive and strong negative enrichments are shown together.
combined_results_df.sort_values(
    by="gsea_es_score",
    key=lambda scores: scores.abs(),
    ascending=False,
).head(50)

Unnamed: 0,z,full_model_z,init,modelseed,model,reactome_pathway,gsea_es_score,nes_score,p_value,shuffled
585651,10,60,2,1709873177,betavae,HDAC2-asscociated core complex (human),0.978077,2.798097,0.00514,False
451520,4,100,4,4073875038,betavae,LSm2-8 complex (human),-0.970976,-2.818678,0.004822,False
1106190,60,60,3,3783068897,betavae,p400-associated complex (human),-0.964271,-2.884007,0.003927,False
826370,10,90,0,0,nmf,"Cytochrome c oxidase (EC 1.9.3.1), mitochondri...",0.959698,2.483275,0.013018,False
680172,25,40,1,844030991,vanillavae,HDAC2-asscociated core complex (human),0.953582,2.743509,0.006079,False
74120,14,35,2,1507182516,betavae,Prefoldin complex (human),-0.952486,-2.790051,0.00527,False
80750,20,35,1,844030991,vanillavae,BRD4-RFC complex (human),-0.950275,-2.779034,0.005452,False
750380,9,50,3,3783068897,betavae,ESCRT-III complex (human),-0.949979,-2.966301,0.003014,False
393380,3,35,0,0,nmf,RSmad complex (human),0.948103,2.441158,0.01464,False
698872,35,200,2,1709873177,vanillavae,BLM complex II (human),-0.94705,-2.56813,0.010225,False
