In [1]:
import joblib
import pandas as pd
import blitzgsea as blitz
import random
import pathlib
import sys

script_directory = pathlib.Path("../2.train-VAE/utils/").resolve()
sys.path.insert(0, str(script_directory))
from betavae import BetaVAE, weights
from betatcvae import BetaTCVAE, tc_weights
from vanillavae import VanillaVAE, vanilla_weights

script_directory = pathlib.Path("../utils/").resolve()
sys.path.insert(0, str(script_directory))
from data_loader import load_train_test_data, load_model_data

In [2]:
# Load the train/test/validation splits (zero-one normalization requested) and gene statistics.
data_directory = pathlib.Path("../0.data-download/data").resolve()

train_df, test_df, val_df, load_gene_stats = load_train_test_data(
    data_directory,
    train_or_test="all",
    load_gene_stats=True,
    zero_one_normalize=True,
)
# NOTE(review): the 4th return value is bound to `load_gene_stats` (same spelling as the flag);
# presumably the per-gene statistics themselves — confirm against data_loader.
train_data = pd.DataFrame(train_df)

# Full CRISPR gene-effect matrix plus the gene dictionary (later filtered on its "qc_pass" column).
dependency_file = (data_directory / "CRISPRGeneEffect.parquet").resolve()
gene_dict_file = (data_directory / "CRISPR_gene_dictionary.parquet").resolve()
dependency_df, gene_dict_df = load_model_data(dependency_file, gene_dict_file)
gene_dict_df = pd.DataFrame(gene_dict_df)

(1150, 18444)


In [3]:
# Load the training split used to derive VAE gene weights.
# NOTE(review): unlike the cell above, zero_one_normalize is not passed here; confirm the
# loader's default is the intended (un-normalized) scale for the VAE weight helpers.
data_directory = pathlib.Path("../0.data-download/data").resolve()
weight_df = load_train_test_data(data_directory, train_or_test="train")

# Restrict to genes whose dictionary entry is flagged as passing QC.
qc_pass_mask = gene_dict_df["qc_pass"]
gene_list_passed_qc = gene_dict_df.loc[qc_pass_mask, "dependency_column"].tolist()

weight_data = weight_df.filter(items=gene_list_passed_qc, axis=1)
weight_data.head()

Unnamed: 0,SRSF3 (6428),GTSE1 (51512),NDUFAF1 (51103),SIRT1 (23411),BLM (641),CPSF7 (79869),SAV1 (60485),POLR2C (5432),NUBP2 (10101),RPN1 (6184),...,LSM5 (23658),PIK3CA (5290),NEK1 (4750),PRKAR1A (5573),RARS1 (5917),RAB35 (11021),SUPT7L (9913),MYO9B (4650),PSMG1 (8624),UMPS (7372)
0,-1.693757,-0.136488,-0.864693,-0.013704,-0.283012,-0.319562,0.178313,-1.447555,-0.631783,-1.414832,...,-1.58883,-0.746764,-0.374383,-0.29494,-0.686259,-0.326404,0.079078,-0.351177,-0.68676,-0.218936
1,-1.48516,-0.314759,-0.624142,0.087907,-0.304571,-0.304542,0.805793,-2.141096,-0.714623,-0.322596,...,-1.460564,-0.620022,0.184013,-0.100026,-0.757333,0.05936,0.10243,-0.276469,-0.28014,0.143615
2,-2.964004,0.05769,-0.387791,-0.070234,-0.139432,0.031475,0.018988,-1.899382,-1.059893,-0.516142,...,-1.345557,-0.493776,-0.228195,-0.459461,-0.96834,-0.330881,0.419336,-0.158184,-0.704374,-0.534583
3,-2.392055,-0.636052,-0.422132,-0.026249,-0.319615,-0.097562,0.037515,-3.127571,-0.792688,-0.868215,...,-0.712235,-0.84632,0.134637,-0.252746,-1.039371,-0.085814,0.0949,-0.087613,-0.318083,-0.429983
4,-3.336677,-0.464687,-0.565287,0.039973,-0.327088,-0.089449,0.033562,-2.240198,-0.902725,-0.738746,...,-1.869687,-0.879411,0.074374,-0.104297,-0.665542,-0.360077,-0.251629,-0.107372,-0.908419,-0.46704


In [4]:
# Function to extract weights for sklearn models
def extract_weights(model, model_name, gene_names=None):
    """Return a gene-by-component weight table for a fitted sklearn decomposition model.

    Parameters
    ----------
    model : fitted estimator exposing ``components_``
        ``components_`` is read as an (n_components, n_genes) array.
    model_name : str
        One of ``"pca"``, ``"ica"`` or ``"nmf"``.
    gene_names : sequence of str, optional
        Labels for the columns of ``components_``. Defaults to the columns of the global
        ``dependency_df`` with ``"ModelID"`` dropped (the original behavior).

    Returns
    -------
    pd.DataFrame
        A ``"genes"`` column followed by one column per component, named ``"0"``, ``"1"``, ...

    Raises
    ------
    ValueError
        If ``model_name`` is not a supported sklearn model. Previously this case fell through
        and failed with an ``UnboundLocalError``.
    """
    if model_name not in ("pca", "ica", "nmf"):
        raise ValueError(
            f"extract_weights only supports 'pca', 'ica' and 'nmf'; got {model_name!r}"
        )

    if gene_names is None:
        gene_names = dependency_df.drop(columns=["ModelID"]).columns.tolist()

    # components_ rows are components; transpose so rows are genes and columns are components.
    weights_df = pd.DataFrame(model.components_, columns=list(gene_names)).transpose()
    weights_df.columns = [str(i) for i in range(weights_df.shape[1])]

    return weights_df.rename_axis("genes").reset_index()

# GSEA function (same as before)
# Gene-set libraries already returned by blitz.enrichr.get_library, keyed by library name.
_GSEA_LIBRARY_CACHE = {}


def perform_gsea(weights_df, model_name, num_components, lib="CORUM", library=None):
    """Run blitzgsea on each latent component of a weight table and collect the results.

    Parameters
    ----------
    weights_df : pd.DataFrame
        Gene identifiers in a ``"genes"`` column that comes first; every remaining column is
        one latent component whose label must be convertible with ``int()``.
    model_name : str
        Model label stored in the ``"model"`` output column.
    num_components : int
        Latent dimensionality of the fitted model, stored as ``"full_model_z"``.
    lib : str, default "CORUM"
        Enrichr library name passed to ``blitz.enrichr.get_library`` when ``library`` is None.
    library : dict, optional
        Pre-loaded gene-set library. When omitted, ``get_library(lib)`` is called once per
        library name and the result is reused on later calls, instead of on every call.

    Returns
    -------
    pd.DataFrame
        One row per (component, gene set) with columns ``z``, ``full_model_z``, ``model``,
        ``reactome_pathway``, ``gsea_es_score``, ``nes_score``, ``p_value``, ``shuffled``.
        The gene-set name is stored under ``"reactome_pathway"`` whatever ``lib`` is, to keep
        the existing output schema. Empty (no columns) when nothing was computed.
    """
    if library is None:
        if lib not in _GSEA_LIBRARY_CACHE:
            _GSEA_LIBRARY_CACHE[lib] = blitz.enrichr.get_library(lib)
        library = _GSEA_LIBRARY_CACHE[lib]

    # Same seeding as before: reset Python's global RNG to 0 and use its first draw (a fixed
    # float) as the blitzgsea seed for every component, so reruns are reproducible.
    random.seed(0)
    seed = random.random()

    gsea_results = []
    if weights_df.shape[0] == 0:
        # No genes: every component's signature would be empty, so there is nothing to test.
        return pd.DataFrame(gsea_results)

    for col in weights_df.columns[1:]:  # Skip 'genes' column
        gene_signature = weights_df[["genes", col]]
        try:
            gsea_result = blitz.gsea(gene_signature, library, seed=seed).reset_index()
        except ZeroDivisionError:
            print(f"Skipping GSEA for {col} due to zero division error.")
            continue

        # Column-wise zip avoids building a pandas Series per row (as iterrows does).
        gsea_results.extend(
            {
                "z": int(col),
                "full_model_z": num_components,
                "model": str(model_name),
                "reactome_pathway": str(term),
                "gsea_es_score": es,
                "nes_score": nes,
                "p_value": pval,
                "shuffled": False,
            }
            for term, es, nes, pval in zip(
                gsea_result["Term"],
                gsea_result["es"],
                gsea_result["nes"],
                gsea_result["pval"],
            )
        )

    return pd.DataFrame(gsea_results)

# Define the location of the saved models and output directory for GSEA results
model_save_dir = pathlib.Path("saved_models")
output_dir = pathlib.Path("gsea_results")
output_dir.mkdir(parents=True, exist_ok=True)

# Latent dimensions and model names to iterate over
latent_dims = [2, 3, 4, 5, 6, 7, 8, 9, 10, 12, 14, 16, 18, 20, 25, 30, 35, 40, 45, 50, 60, 70, 80, 90, 100, 150, 200]
model_names = ["pca", "ica", "nmf", "vanillavae", "betavae", "betatcvae"]

# How gene weights are obtained per model family. Unknown names raise instead of silently
# reusing the previous iteration's weight matrix.
sklearn_model_names = {"pca", "ica", "nmf"}
vae_weight_extractors = {
    "vanillavae": vanilla_weights,
    "betavae": weights,
    "betatcvae": tc_weights,
}

# Resume support: reuse previously saved results and skip (model, z) pairs already present.
final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
try:
    combined_results_df = pd.read_parquet(final_output_file)
    print(f"Loaded existing results from {final_output_file}")
except FileNotFoundError:
    # If the file doesn't exist, initialize an empty DataFrame
    combined_results_df = pd.DataFrame()
    print("No existing file found. Initialized empty DataFrame.")

if combined_results_df.empty:
    processed_pairs = set()
else:
    # .tolist() yields plain Python values, so membership tests against int dims are exact.
    processed_pairs = set(
        zip(
            combined_results_df["model"].tolist(),
            combined_results_df["full_model_z"].tolist(),
        )
    )

# Collect new results and concatenate once; calling pd.concat inside the loop is quadratic.
new_results = []
try:
    for num_components in latent_dims:
        for model_name in model_names:
            if (model_name, num_components) in processed_pairs:
                print(f"Skipping {model_name} with {num_components} dimensions as it is already processed.")
                continue  # Skip to the next iteration if this combination is already present

            model_filename = model_save_dir / f"{model_name}_{num_components}_components_model.joblib"
            if not model_filename.exists():
                print(f"Model file {model_filename} not found. Skipping.")
                continue

            print(f"Loading model from {model_filename}")
            # joblib.load unpickles arbitrary objects: only point this at trusted, locally saved models.
            model = joblib.load(model_filename)

            if model_name in sklearn_model_names:
                weight_matrix_df = extract_weights(model, model_name)
            elif model_name in vae_weight_extractors:
                weight_matrix_df = vae_weight_extractors[model_name](model, weight_data).rename(
                    columns={0: "genes"}
                )
            else:
                raise ValueError(f"No weight extractor registered for model {model_name!r}")

            # Perform GSEA
            new_results.append(perform_gsea(weight_matrix_df, model_name, num_components))
finally:
    # Merge whatever finished even if the loop was interrupted, so partial progress can
    # still be saved and a later run can resume from it.
    combined_results_df = pd.concat([combined_results_df, *new_results], ignore_index=True)
            

No existing file found. Initialized empty DataFrame.
Loading model from saved_models/pca_2_components_model.joblib
Loading model from saved_models/ica_2_components_model.joblib
Loading model from saved_models/nmf_2_components_model.joblib
Loading model from saved_models/vanillavae_2_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_2_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_2_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_3_components_model.joblib
Loading model from saved_models/ica_3_components_model.joblib
Loading model from saved_models/nmf_3_components_model.joblib
Loading model from saved_models/vanillavae_3_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_3_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_3_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_4_components_model.joblib
Loading model from saved_models/ica_4_components_model.joblib
Loading model from saved_models/nmf_4_components_model.joblib
Loading model from saved_models/vanillavae_4_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_4_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_4_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_5_components_model.joblib
Loading model from saved_models/ica_5_components_model.joblib
Loading model from saved_models/nmf_5_components_model.joblib
Loading model from saved_models/vanillavae_5_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_5_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_5_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_6_components_model.joblib
Loading model from saved_models/ica_6_components_model.joblib
Loading model from saved_models/nmf_6_components_model.joblib
Loading model from saved_models/vanillavae_6_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_6_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_6_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_7_components_model.joblib
Loading model from saved_models/ica_7_components_model.joblib
Loading model from saved_models/nmf_7_components_model.joblib
Loading model from saved_models/vanillavae_7_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_7_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_7_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_8_components_model.joblib
Loading model from saved_models/ica_8_components_model.joblib
Loading model from saved_models/nmf_8_components_model.joblib
Loading model from saved_models/vanillavae_8_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_8_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_8_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_9_components_model.joblib
Loading model from saved_models/ica_9_components_model.joblib
Loading model from saved_models/nmf_9_components_model.joblib
Loading model from saved_models/vanillavae_9_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_9_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_9_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_10_components_model.joblib
Loading model from saved_models/ica_10_components_model.joblib
Loading model from saved_models/nmf_10_components_model.joblib
Loading model from saved_models/vanillavae_10_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_10_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_10_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_12_components_model.joblib
Loading model from saved_models/ica_12_components_model.joblib
Loading model from saved_models/nmf_12_components_model.joblib
Loading model from saved_models/vanillavae_12_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_12_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_12_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_14_components_model.joblib
Loading model from saved_models/ica_14_components_model.joblib
Loading model from saved_models/nmf_14_components_model.joblib
Loading model from saved_models/vanillavae_14_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_14_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_14_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_16_components_model.joblib
Loading model from saved_models/ica_16_components_model.joblib
Loading model from saved_models/nmf_16_components_model.joblib
Loading model from saved_models/vanillavae_16_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_16_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_16_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_18_components_model.joblib
Loading model from saved_models/ica_18_components_model.joblib
Loading model from saved_models/nmf_18_components_model.joblib
Loading model from saved_models/vanillavae_18_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_18_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_18_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_20_components_model.joblib
Loading model from saved_models/ica_20_components_model.joblib
Loading model from saved_models/nmf_20_components_model.joblib
Loading model from saved_models/vanillavae_20_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_20_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betatcvae_20_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/pca_25_components_model.joblib
Loading model from saved_models/ica_25_components_model.joblib
Loading model from saved_models/nmf_25_components_model.joblib
Loading model from saved_models/vanillavae_25_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from saved_models/betavae_25_components_model.joblib


  return torch.load(io.BytesIO(b))


Model file saved_models/betatcvae_25_components_model.joblib not found. Skipping.
Loading model from saved_models/pca_30_components_model.joblib
Loading model from saved_models/ica_30_components_model.joblib
Loading model from saved_models/nmf_30_components_model.joblib
Loading model from saved_models/vanillavae_30_components_model.joblib


  return torch.load(io.BytesIO(b))


Model file saved_models/betavae_30_components_model.joblib not found. Skipping.
Model file saved_models/betatcvae_30_components_model.joblib not found. Skipping.
Model file saved_models/pca_35_components_model.joblib not found. Skipping.
Model file saved_models/ica_35_components_model.joblib not found. Skipping.
Model file saved_models/nmf_35_components_model.joblib not found. Skipping.
Model file saved_models/vanillavae_35_components_model.joblib not found. Skipping.
Model file saved_models/betavae_35_components_model.joblib not found. Skipping.
Model file saved_models/betatcvae_35_components_model.joblib not found. Skipping.
Model file saved_models/pca_40_components_model.joblib not found. Skipping.
Model file saved_models/ica_40_components_model.joblib not found. Skipping.
Model file saved_models/nmf_40_components_model.joblib not found. Skipping.
Model file saved_models/vanillavae_40_components_model.joblib not found. Skipping.
Model file saved_models/betavae_40_components_model.jo

In [7]:
# Persist the combined GSEA results: parquet for Python, CSV for the downstream R analysis.
final_output_file = output_dir / "combined_z_matrix_gsea_results.parquet"
csv_output_file = output_dir / "combined_z_matrix_gsea_results.csv"

combined_results_df.to_parquet(final_output_file, index=False)
print(f"Saved final combined z_matrix and GSEA results to {final_output_file}")

combined_results_df.to_csv(csv_output_file, index=False)

Saved final combined z_matrix and GSEA results to gsea_results/combined_z_matrix_gsea_results.parquet


In [8]:
# Show the 50 strongest enrichments by absolute enrichment score, in either direction.
TOP_N_HITS = 50
combined_results_df.sort_values(
    by="gsea_es_score",
    key=lambda es_scores: es_scores.abs(),
    ascending=False,
).head(TOP_N_HITS)

Unnamed: 0,z,full_model_z,model,reactome_pathway,gsea_es_score,nes_score,p_value,shuffled
96220,7,16,betatcvae,Toposome (human),-0.943584,-3.08768,0.002017,False
59671,3,12,nmf,Nucleic and chromatin Fanconi complex (human),-0.931089,-2.959808,0.003078,False
166601,21,30,ica,SIN3-ING1b complex I (human),0.930681,2.663094,0.007743,False
18190,6,6,vanillavae,CHUK-NFKB2-REL-IKBKG-SPAG9-NFKB1-NFKBIE-COPB2-...,0.92955,2.832204,0.004623,False
42160,3,9,betavae,NCOR2 complex (human),-0.925815,-2.559891,0.01047,False
156910,20,25,betavae,CHUK-NFKB2-REL-IKBKG-SPAG9-NFKB1-NFKBIE-COPB2-...,0.921811,2.88604,0.003901,False
22101,3,7,ica,Nucleic and chromatin Fanconi complex (human),0.920653,2.528776,0.011446,False
108630,10,18,vanillavae,Condensin II (human),0.919581,2.651714,0.008008,False
157590,24,25,betavae,CHUK-NFKB2-REL-IKBKG-SPAG9-NFKB1-NFKBIE-COPB2-...,0.917719,2.883604,0.003932,False
31620,1,8,vanillavae,Condensin II (human),-0.915064,-2.611345,0.009019,False
