## GSEA Analysis Pipeline for Dimensionality Reduction Models

This script performs Gene Set Enrichment Analysis (GSEA) on weight matrices  extracted from various dimensionality reduction models (PCA, ICA, NMF, VanillaVAE,  BetaVAE, and BetaTCVAE). It iterates over different latent dimensions and model types, extracts the weight matrices, and computes GSEA scores. The results are combined into a single output file for downstream analysis.

In [1]:
import joblib
import pandas as pd
import blitzgsea as blitz
import random
import pathlib
import sys

script_directory = pathlib.Path("../utils/").resolve()
sys.path.insert(0, str(script_directory))
from data_loader import load_train_test_data, load_model_data
from model_utils import extract_weights

In [2]:
# Load data
data_directory = pathlib.Path("../0.data-download/data").resolve()

train_df, test_df, val_df, load_gene_stats = load_train_test_data(
    data_directory, train_or_test="all", load_gene_stats=True, zero_one_normalize=True
)
train_data = pd.DataFrame(train_df)

dependency_file = pathlib.Path(f"{data_directory}/CRISPRGeneEffect.parquet").resolve()
gene_dict_file = pathlib.Path(f"{data_directory}/CRISPR_gene_dictionary.parquet").resolve()
dependency_df, gene_dict_df= load_model_data(dependency_file, gene_dict_file)
gene_dict_df = pd.DataFrame(gene_dict_df)

(1150, 18444)


In [3]:
#Load weight data for VAEs
data_directory = pathlib.Path("../0.data-download/data").resolve()
weight_df = load_train_test_data(
    data_directory, train_or_test="train"
)

gene_list_passed_qc = gene_dict_df.loc[
    gene_dict_df["qc_pass"], "dependency_column"
].tolist()

weight_data = weight_df.filter(gene_list_passed_qc, axis=1)
weight_data.head()

Unnamed: 0,SLC25A26 (115286),WRN (7486),HAUS8 (93323),TNS3 (64759),RELB (5971),ARHGEF5 (7984),B4GALT5 (9334),MMACHC (25974),POLR3H (171568),GATAD1 (57798),...,PIK3CD (5293),CCDC6 (8030),NEUROD1 (4760),RAB3GAP2 (25782),RALGAPB (57148),ORC1 (4998),RBIS (401466),ZNF273 (10793),MTHFD1 (4522),CSNK2B (1460)
0,-1.032492,-0.188095,-0.276764,-0.128368,-0.068453,-0.348402,-0.195648,0.172237,-1.612276,0.04989,...,-0.172514,-0.246734,-0.075915,-0.220835,-0.084536,-1.054615,-0.138051,-0.101715,-0.103041,-1.129567
1,-0.250614,-0.249719,-0.304459,-0.1281,-0.028,-0.401503,-0.04236,-0.148679,-1.684179,-0.502164,...,-0.163735,0.128337,-0.125988,0.124108,0.718825,-0.625943,-0.075275,-0.180233,-0.526525,-0.867831
2,-0.099892,0.017013,-0.6359,-0.312322,-0.077608,0.060578,-0.031489,-0.099623,-2.388216,-0.157901,...,-0.119579,0.055878,0.197878,-0.006411,0.0349,-0.785888,-0.163386,-0.228768,-0.960661,-0.844785
3,-0.298281,-0.1924,-0.44352,-0.206342,0.128013,-0.193442,-0.110086,0.060171,-2.73419,-0.083247,...,-0.119559,-0.238584,0.216281,-0.151375,-0.212339,-0.707361,-0.650696,-0.661645,-0.667901,-1.487172
4,-0.475818,-0.364279,-0.779388,-0.268785,0.101934,-0.476407,-0.127605,0.078356,-1.706935,-0.246327,...,-0.071391,0.008482,0.126977,-0.139191,-0.240368,-0.696641,-0.217298,-0.318229,-0.498258,-0.934179


In [4]:
def perform_gsea(weights_df: pd.DataFrame, model_name: str, num_components: int, init: int, modelseed:int, lib: str = "CORUM") -> pd.DataFrame:
    """
    Performs Gene Set Enrichment Analysis (GSEA) for a given weight matrix.

    Args:
        weights_df (pd.DataFrame): DataFrame containing genes and their associated weights.
        model_name (str): Name of the model being analyzed.
        num_components (int): Number of components used in the model.
        lib (str): Name of the GSEA library (default: 'CORUM').

    Returns:
        pd.DataFrame: Results of GSEA with columns for pathway, enrichment scores, and other metrics.
    """
    
    library = blitz.enrichr.get_library(lib)
    random.seed(0)
    seed = random.random()
    gsea_results = []
    for col in weights_df.columns[1:]:  # Skip 'genes' column
        gene_signature = weights_df[['genes', col]]
        if gene_signature.shape[0] > 0:
            try:
                gsea_result = blitz.gsea(gene_signature, library, seed=seed)
                gsea_result = gsea_result.reset_index()
                for _, pathway_result in gsea_result.iterrows():
                    result_row = {
                        "z": int(col),
                        "full_model_z": num_components,
                        "init" : int(init),
                        "modelseed" : int(modelseed),
                        "model": str(model_name),
                        "reactome_pathway": str(pathway_result['Term']),
                        "gsea_es_score": pathway_result['es'],
                        "nes_score": pathway_result['nes'],
                        "p_value": pathway_result['pval'],
                        "shuffled": False
                    }
                    gsea_results.append(result_row)
                    
            except ZeroDivisionError:
                print(f"Skipping GSEA for {col} due to zero division error.")
    
    gsea_results_df = pd.DataFrame(gsea_results)
    return gsea_results_df

# Define the location of the saved models and output directory for GSEA results
model_save_dir = pathlib.Path("saved_models")
output_dir = pathlib.Path("gsea_results")
output_dir.mkdir(parents=True, exist_ok=True)

final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
try:
    combined_results_df = pd.read_parquet(final_output_file)
    print(f"Loaded existing results from {final_output_file}")
except FileNotFoundError:
    # If the file doesn't exist, initialize an empty DataFrame
    combined_results_df = pd.DataFrame()
    print(f"No existing file found. Initialized empty DataFrame.")


# Iterate over all files in the saved_models directory
for model_file in model_save_dir.glob("*.joblib"):
    # Extract model name and number of components from the filename
    model_file_name = model_file.stem
    try:
        # Assuming the filename format includes model_name, num_components, and potentially a seed
        # Example: "BetaVAE_100_components_seed42_model.joblib"
        parts = model_file_name.split("_")
        model_name = parts[0]  # First part is the model name
        num_components = int(parts[3])  # Second part should indicate the number of components
        init = int(parts[7])
        seed = int(parts[9])
    except (IndexError, ValueError) as e:
        print(f"Skipping file {model_file} due to unexpected filename format.")
        continue
    # Check if this model, latent dimension, and initialization have already been processed
    if not combined_results_df.empty:
        if ((combined_results_df['model'] == model_name) & 
            (combined_results_df['init'] == init) &
            (combined_results_df['full_model_z'] == num_components)).any():
            print(f"Skipping {model_name} init {init} with {num_components} dimensions as it is already processed.")
            continue

    # Load the model
    print(f"Loading model from {model_file}")
    try:
        model = joblib.load(model_file)
    except Exception as e:
        print(f"Failed to load model from {model_file}: {e}")
        continue

    # Extract the weight matrix
    weight_matrix_df = extract_weights(model, model_name, weight_data, dependency_df)
    
    # Perform GSEA
    gsea_results_df = perform_gsea(weight_matrix_df, model_name, num_components, init, seed)
    combined_results_df = pd.concat([combined_results_df, gsea_results_df], ignore_index=True)
            

No existing file found. Initialized empty DataFrame.


In [5]:
# Save the dataframe to a file
final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
combined_results_df.to_parquet(final_output_file, index=False)

print(f"Saved final z_matrix combining all models and latent dimensions and GSEA results to {final_output_file}")


Saved final z_matrix combining all models and latent dimensions and GSEA results to gsea_results/combined_z_matrix_gsea_results.parquet


In [7]:
combined_results_df.sort_values(by='gsea_es_score', key=abs, ascending = False).head(50)

KeyError: 'gsea_es_score'