In [13]:
import joblib
import pandas as pd
import blitzgsea as blitz
import random
import pathlib
import sys

# Put ../2.train-VAE/utils/ at the front of sys.path so the betavae / betatcvae /
# vanillavae modules imported right below are resolved from that folder.
script_directory = pathlib.Path("../2.train-VAE/utils/").resolve()
sys.path.insert(0, str(script_directory))
# NOTE(review): BetaVAE / BetaTCVAE / VanillaVAE are not referenced by name in the
# visible cells; presumably they must be importable so joblib.load can unpickle
# the saved VAE models — confirm before removing them.
from betavae import BetaVAE, weights
from betatcvae import BetaTCVAE, tc_weights
from vanillavae import VanillaVAE, vanilla_weights

# Same approach for ../utils/, which provides data_loader.
# `script_directory` is reassigned here and refers to this folder afterwards.
script_directory = pathlib.Path("../utils/").resolve()
sys.path.insert(0, str(script_directory))
from data_loader import load_train_test_data, load_model_data

In [14]:
# Load the data splits (train_or_test="all") together with the gene statistics
data_directory = pathlib.Path("../0.data-download/data").resolve()

train_df, test_df, val_df, load_gene_stats = load_train_test_data(
    data_directory,
    train_or_test="all",
    load_gene_stats=True,
    zero_one_normalize=True,
)
train_data = pd.DataFrame(train_df)

# Load the CRISPR gene-effect table and the gene dictionary from the same folder
dependency_file = (data_directory / "CRISPRGeneEffect.parquet").resolve()
gene_dict_file = (data_directory / "CRISPR_gene_dictionary.parquet").resolve()
dependency_df, gene_dict_df = load_model_data(dependency_file, gene_dict_file)
gene_dict_df = pd.DataFrame(gene_dict_df)

(1150, 18444)


In [15]:
# Load the training split that is handed to the VAE weight helpers
data_directory = pathlib.Path("../0.data-download/data").resolve()
weight_df = load_train_test_data(data_directory, train_or_test="train")

# Dependency-column labels of the genes whose "qc_pass" flag is set
gene_list_passed_qc = gene_dict_df.loc[
    gene_dict_df["qc_pass"],
    "dependency_column",
].tolist()

# Keep only those columns; labels absent from weight_df are ignored by filter()
weight_data = weight_df.filter(items=gene_list_passed_qc, axis=1)
weight_data.head()

Unnamed: 0,ZNF273 (10793),SARS1 (6301),ACTR10 (55860),MDC1 (9656),SMG8 (55181),LEO1 (123169),TECR (9524),IRF9 (10379),EFR3A (23167),CTR9 (9646),...,SKP2 (6502),SMG6 (23293),CCNC (892),REXO2 (25996),EXT2 (2132),PWP2 (5822),PYROXD1 (79912),SIK3 (23387),CALM2 (805),MPDU1 (9526)
0,-0.025015,-3.152211,-1.679321,-0.165686,-0.578564,0.168773,-0.143454,-0.059662,-0.286817,-0.969735,...,-0.366123,-0.938793,-0.087345,-0.15587,-0.168154,-1.486831,-0.593819,0.076267,0.06879,-0.059073
1,-0.083222,-2.721956,-1.590585,0.064598,-0.49535,-0.250202,-0.19546,-0.23279,-0.238748,-1.176141,...,-0.916519,-1.226641,0.263611,-0.095143,-0.875274,-1.325296,-0.475371,0.040527,0.109524,-0.240882
2,-0.235488,-2.153757,-1.881204,-0.103206,-0.557517,0.123995,-0.243343,-0.095325,-0.103035,-0.670284,...,-0.525837,-1.332287,-0.33706,-0.040212,-0.217806,-1.019237,-0.475445,0.053786,0.022866,-0.082868
3,-0.013641,-2.549781,-1.529319,-0.606543,-0.225701,-0.346301,-0.06933,-0.300127,-0.387945,-1.146583,...,-0.698852,-0.426716,0.134276,0.15991,-0.097972,-0.798972,-0.759828,-0.100234,-0.128199,-0.142403
4,-0.216511,-2.322025,-1.299256,0.030481,-0.489542,-0.140603,-0.30061,0.297708,-0.500277,-0.924757,...,-0.334904,-1.02067,-0.637679,0.251131,-0.281219,-1.60095,-0.548521,-0.176438,0.25617,-0.25555


In [20]:
# Function to extract weights for sklearn models
def extract_weights(model, model_name, gene_names=None):
    """Return the component loadings of a fitted decomposition as a gene table.

    Parameters
    ----------
    model : object
        Fitted model exposing ``components_``; each row is one component and
        each column is labelled with one entry of ``gene_names``.
    model_name : str
        One of ``"pca"``, ``"ica"`` or ``"nmf"``.
    gene_names : sequence of str, optional
        Labels for the columns of ``model.components_``. When omitted, the
        column labels of the module-level ``dependency_df`` (without
        ``"ModelID"``) are used, which is the previous behavior.

    Returns
    -------
    pandas.DataFrame
        A ``"genes"`` column followed by one column per component, named with
        the string form of the component index (``"0"``, ``"1"``, ...).

    Raises
    ------
    ValueError
        If ``model_name`` is not one of the supported names. Previously this
        case failed with an ``UnboundLocalError`` because no weights were
        ever assigned.
    """
    supported_models = ("pca", "ica", "nmf")
    if model_name not in supported_models:
        raise ValueError(
            f"extract_weights only supports {list(supported_models)}, got {model_name!r}"
        )

    if gene_names is None:
        gene_names = dependency_df.drop(columns=["ModelID"]).columns.tolist()

    # components_ is (component x gene); transpose so that genes become rows.
    weights_df = pd.DataFrame(model.components_, columns=list(gene_names)).transpose()
    weights_df.columns = [str(component) for component in range(weights_df.shape[1])]

    # Name the index explicitly so the resulting column is always "genes".
    return weights_df.rename_axis("genes").reset_index()

# GSEA function (same as before)
def perform_gsea(weights_df, model_name, num_components, lib="Reactome_2022"):
    """Run blitz.gsea on every component column of a weight table.

    Parameters
    ----------
    weights_df : pandas.DataFrame
        Must contain a ``"genes"`` column; every column after the first is
        treated as one component and its label must be convertible with
        ``int()``.
    model_name : str
        Stored (as ``str``) in the ``"model"`` column of the result.
    num_components : int
        Stored in the ``"full_model_z"`` column of the result.
    lib : str, optional
        Library name passed to ``blitz.enrichr.get_library``. The pathway
        column of the result is called ``"reactome_pathway"`` regardless of
        the library chosen.

    Returns
    -------
    pandas.DataFrame
        One row per (component, term) with the columns ``z``,
        ``full_model_z``, ``model``, ``reactome_pathway``, ``gsea_es_score``,
        ``nes_score``, ``p_value`` and ``shuffled`` (always ``False``).
        An empty frame without columns when nothing was produced.

    Notes
    -----
    Components for which ``blitz.gsea`` raises ``ZeroDivisionError`` are
    reported with a message and skipped. Calling this function reseeds the
    global ``random`` module with 0.
    """
    library = blitz.enrichr.get_library(lib)

    # Kept exactly as before so new rows stay comparable with stored results:
    # the value handed to blitz.gsea is the first draw after seeding with 0.
    random.seed(0)
    seed = random.random()

    # Loop-invariant: without any gene rows no component can be scored.
    if weights_df.shape[0] == 0:
        return pd.DataFrame([])

    component_frames = []
    for col in weights_df.columns[1:]:  # Skip 'genes' column
        gene_signature = weights_df[['genes', col]]
        try:
            gsea_result = blitz.gsea(gene_signature, library, seed=seed).reset_index()
        except ZeroDivisionError:
            print(f"Skipping GSEA for {col} due to zero division error.")
            continue

        if gsea_result.empty:
            continue

        # Build the rows for this component column-wise instead of one dict
        # per pathway via iterrows(); scalars are broadcast over all rows.
        component_frames.append(
            pd.DataFrame(
                {
                    "z": int(col),
                    "full_model_z": num_components,
                    "model": str(model_name),
                    "reactome_pathway": gsea_result["Term"].astype(str),
                    "gsea_es_score": gsea_result["es"],
                    "nes_score": gsea_result["nes"],
                    "p_value": gsea_result["pval"],
                    "shuffled": False,
                }
            )
        )

    if not component_frames:
        return pd.DataFrame([])
    return pd.concat(component_frames, ignore_index=True)

# Define the location of the saved models and output directory for GSEA results
model_save_dir = pathlib.Path("saved_models")
output_dir = pathlib.Path("gsea_results")
output_dir.mkdir(parents=True, exist_ok=True)

# Latent dimensions and model names to iterate over
latent_dims = [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100, 150, 200]
model_names = ["pca", "ica", "nmf", "vanillavae", "betavae", "betatcvae"]

# Weight helper to call for each VAE flavour (all take the model and weight_data)
vae_weight_extractors = {
    "betavae": weights,
    "betatcvae": tc_weights,
    "vanillavae": vanilla_weights,
}

final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
try:
    combined_results_df = pd.read_parquet(final_output_file)
    print(f"Loaded existing results from {final_output_file}")
except FileNotFoundError:
    # If the file doesn't exist, initialize an empty DataFrame
    combined_results_df = pd.DataFrame()
    print("No existing file found. Initialized empty DataFrame.")

# Resume support: collect the (model, latent dimension) pairs already present
# once, instead of re-scanning the whole results frame for every combination.
if combined_results_df.empty:
    processed_runs = set()
else:
    processed_runs = set(
        combined_results_df[["model", "full_model_z"]]
        .drop_duplicates()
        .itertuples(index=False, name=None)
    )

# New results are collected here and appended in a single concat; growing the
# combined frame inside the loop would copy every existing row each time.
new_result_frames = []
try:
    for num_components in latent_dims:
        for model_name in model_names:
            # Check if this model and latent dimension have already been processed
            if (model_name, num_components) in processed_runs:
                print(f"Skipping {model_name} with {num_components} dimensions as it is already processed.")
                continue  # Skip to the next iteration if this combination is already present

            model_filename = model_save_dir / f"{model_name}_{num_components}_components_model.joblib"
            if not model_filename.exists():
                print(f"Model file {model_filename} not found. Skipping.")
                continue

            # Load the saved model.
            # NOTE: joblib.load unpickles the file and can execute arbitrary
            # code; only load model files that come from a trusted source.
            print(f"Loading model from {model_filename}")
            model = joblib.load(model_filename)

            # Extract the weight matrix
            if model_name in ["pca", "ica", "nmf"]:
                weight_matrix_df = extract_weights(model, model_name)
            elif model_name in vae_weight_extractors:
                weight_matrix_df = vae_weight_extractors[model_name](model, weight_data)
                weight_matrix_df = weight_matrix_df.rename(columns={0: 'genes'})
            else:
                # Previously an unknown name silently reused the weights of
                # the preceding iteration.
                raise ValueError(f"No weight extraction defined for model {model_name!r}")

            # Perform GSEA
            new_result_frames.append(perform_gsea(weight_matrix_df, model_name, num_components))
finally:
    # Runs on interruption or error too, so finished combinations are kept.
    if new_result_frames:
        combined_results_df = pd.concat(
            [combined_results_df, *new_result_frames], ignore_index=True
        )
            

Loaded existing results from gsea_results/combined_z_matrix_gsea_results.parquet
Skipping pca with 2 dimensions as it is already processed.
Skipping ica with 2 dimensions as it is already processed.
Skipping nmf with 2 dimensions as it is already processed.
Skipping vanillavae with 2 dimensions as it is already processed.
Skipping betavae with 2 dimensions as it is already processed.
Skipping betatcvae with 2 dimensions as it is already processed.
Skipping pca with 3 dimensions as it is already processed.
Skipping ica with 3 dimensions as it is already processed.
Skipping nmf with 3 dimensions as it is already processed.
Skipping vanillavae with 3 dimensions as it is already processed.
Skipping betavae with 3 dimensions as it is already processed.
Skipping betatcvae with 3 dimensions as it is already processed.
Skipping pca with 4 dimensions as it is already processed.
Skipping ica with 4 dimensions as it is already processed.
Skipping nmf with 4 dimensions as it is already processed.


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_150_components_model.joblib
Loading model from saved_models/ica_150_components_model.joblib
Loading model from saved_models/nmf_150_components_model.joblib
Loading model from saved_models/vanillavae_150_components_model.joblib


  return torch.load(io.BytesIO(b))


Model file saved_models/betavae_150_components_model.joblib not found. Skipping.
Model file saved_models/betatcvae_150_components_model.joblib not found. Skipping.
Model file saved_models/pca_200_components_model.joblib not found. Skipping.
Model file saved_models/ica_200_components_model.joblib not found. Skipping.
Model file saved_models/nmf_200_components_model.joblib not found. Skipping.
Model file saved_models/vanillavae_200_components_model.joblib not found. Skipping.
Model file saved_models/betavae_200_components_model.joblib not found. Skipping.
Model file saved_models/betatcvae_200_components_model.joblib not found. Skipping.


In [21]:
# Write the combined results to parquet (row index is not stored)
final_output_file = output_dir.joinpath("combined_z_matrix_gsea_results.parquet")
combined_results_df.to_parquet(final_output_file, index=False)

print(f"Saved final combined z_matrix and GSEA results to {final_output_file}")

# Write the same table as CSV so it can be read from R
csv_output_file = output_dir.joinpath("combined_z_matrix_gsea_results.csv")
combined_results_df.to_csv(csv_output_file, index=False)

Saved final combined z_matrix and GSEA results to gsea_results/combined_z_matrix_gsea_results.parquet


In [22]:
# Show the 50 rows with the largest absolute enrichment score
(
    combined_results_df
    .sort_values(by="gsea_es_score", key=abs, ascending=False)
    .head(50)
)

Unnamed: 0,z,full_model_z,model,reactome_pathway,gsea_es_score,nes_score,p_value,shuffled
1679486,34,45,pca,Regulation Of IGF Transport And Uptake By IGFB...,-0.97616,-2.854464,0.004311,False
1679487,34,45,pca,Post-translational Protein Phosphorylation R-H...,-0.97616,-2.854464,0.004311,False
636296,1,18,nmf,G Beta:Gamma Signaling Thru PI3Kgamma R-HSA-39...,0.972813,2.712398,0.00668,False
636295,1,18,nmf,G-protein Beta:Gamma Signaling R-HSA-397795,0.970808,2.713693,0.006654,False
38411,3,4,ica,"Synthesis, Secretion, And Inactivation Of Gluc...",-0.962916,-2.633018,0.008463,False
673662,2,18,betavae,Cardiac Conduction R-HSA-5576891,-0.961982,-2.900359,0.003727,False
780576,9,20,vanillavae,G Beta:Gamma Signaling Thru PI3Kgamma R-HSA-39...,-0.961797,-2.598086,0.009374,False
917593,6,25,vanillavae,Metabolism Of Fat-Soluble Vitamins R-HSA-6806667,-0.959089,-2.633433,0.008453,False
38409,3,4,ica,"Incretin Synthesis, Secretion, And Inactivatio...",-0.958637,-2.779595,0.005443,False
4619109,97,100,betavae,SHC-related Events Triggered By IGF1R R-HSA-24...,0.957001,2.836918,0.004555,False
