In [69]:
import joblib
import pandas as pd
from scipy.stats import pearsonr
from sklearn.preprocessing import MinMaxScaler
import pathlib
import sys
import torch
from torch.utils.data import DataLoader, TensorDataset
import numpy as np 
import matplotlib.pyplot as plt

script_directory = pathlib.Path("../2.train-VAE/utils/").resolve()
sys.path.insert(0, str(script_directory))
from betavae import  extract_latent_dimensions
from betatcvae import tc_extract_latent_dimensions
from vanillavae import vvae_extract_latent_dimensions

from utils import load_utils

sys.path.insert(0, "../utils/")
from data_loader import load_model_data

In [70]:
# Function to extract latent dimensions for PCA, ICA, and NMF models
def sklearn_extract_latent_dimensions(model, model_name, dependency_df):
    """
    Project dependency data into the latent space of a fitted sklearn-style model.

    Parameters:
    model: Fitted estimator exposing `feature_names_in_` and `transform()` (e.g. PCA/ICA/NMF).
    model_name (str): Name of the model. Not used here; kept for call-site compatibility.
    dependency_df (pd.DataFrame): Must contain a `ModelID` column plus every column
        listed in `model.feature_names_in_`.

    Returns:
    pd.DataFrame: `ModelID` followed by one column per latent dimension (`z_0`, `z_1`, ...),
    one row per row of `dependency_df` in the same order, with a fresh RangeIndex.
    """
    # Select/reorder the input columns to match the order the model was fit on
    reordered_df = dependency_df[model.feature_names_in_]

    # Project into the model's latent space (PCA/ICA/NMF components)
    latent_df = pd.DataFrame(model.transform(reordered_df))
    latent_df.columns = [f"z_{x}" for x in range(latent_df.shape[1])]

    # Attach ModelIDs positionally. Concatenating on index labels (as before) misaligned
    # rows whenever dependency_df did not carry a 0..n-1 RangeIndex.
    latent_df.insert(0, "ModelID", dependency_df["ModelID"].reset_index(drop=True))

    return latent_df


def perform_correlation(latent_df, drug_df, model_name, num_components, shuffle=False):
    """
    Perform Pearson correlation between latent dimensions and drug dependency scores.

    Parameters:
    latent_df (pd.DataFrame): Latent dimensions (e.g., PCA/ICA/NMF/VAE), either indexed by
        ModelID or carrying a `ModelID` column. Latent columns must be named `z_<int>`.
    drug_df (pd.DataFrame): Drug response scores indexed by ModelID, one column per drug.
    model_name (str): Name of the model used to extract the latent dimensions.
    num_components (int): Number of latent dimensions/components in the model.
    shuffle (bool): True when `drug_df` is a shuffled negative control; recorded in the
        `shuffled` column of the output.

    Returns:
    pd.DataFrame: One row per (latent dimension, drug) pair with columns `z`,
    `full_model_z`, `model`, `drug`, `pearson_correlation`, `p_value`, `shuffled`.
    Only ModelIDs present in both inputs are used, and only columns with at least two
    distinct non-missing values are paired. `pearson_correlation` and `p_value` are both
    NaN when, after dropping ModelIDs missing either value, either side has fewer than
    two distinct values.
    """
    correlation_results = []
    if 'ModelID' in latent_df.columns:
        latent_df = latent_df.set_index('ModelID')

    # Align both dataframes on the ModelIDs they share
    common_model_ids = latent_df.index.intersection(drug_df.index)
    latent_df_filtered = latent_df.loc[common_model_ids]
    prism_df_filtered = drug_df.loc[common_model_ids]

    # A column needs at least two distinct non-missing values to be correlated. This
    # admits the same columns as the former zero-variance filter combined with the
    # per-pair nunique() check, but evaluates it once per column instead of per pair.
    latent_df_filtered = latent_df_filtered.loc[:, latent_df_filtered.nunique() > 1]
    prism_df_filtered = prism_df_filtered.loc[:, prism_df_filtered.nunique() > 1]

    for latent_col in latent_df_filtered.columns:
        latent_values = latent_df_filtered[latent_col]
        latent_present = latent_values.notna()
        for drug_col in prism_df_filtered.columns:
            drug_values = prism_df_filtered[drug_col]

            # Pairwise deletion: keep only ModelIDs where both values are present
            paired = latent_present & drug_values.notna()
            latent_values_valid = latent_values[paired]
            drug_values_valid = drug_values[paired]

            if latent_values_valid.nunique() > 1 and drug_values_valid.nunique() > 1:
                corr, p_value = pearsonr(latent_values_valid, drug_values_valid)
            else:
                # Too few pairs or a constant side: Pearson r is undefined. Set p_value
                # explicitly too; it used to be left unbound (UnboundLocalError) or stale
                # from the previous drug on this branch.
                corr, p_value = np.nan, np.nan

            correlation_results.append({
                "z": int(latent_col.replace("z_", "")),
                "full_model_z": num_components,
                "model": str(model_name),
                "drug": str(drug_col),
                "pearson_correlation": corr,
                "p_value": p_value,
                "shuffled": shuffle,
            })

    return pd.DataFrame(correlation_results)

In [71]:
# Input locations: the CRISPR gene-effect matrix and its gene dictionary, both parquet
data_directory = pathlib.Path("../0.data-download/data").resolve()
dependency_file = (data_directory / "CRISPRGeneEffect.parquet").resolve()
gene_dict_file = (data_directory / "CRISPR_gene_dictionary.parquet").resolve()

In [72]:
# Load PRISM data
top_dir = "../5.drug-dependency"
data_dir = "data"

prism_df, prism_cell_df, prism_trt_df = load_utils.load_prism(
    top_dir=top_dir,
    data_dir=data_dir,
    secondary_screen=False,
    load_cell_info=True,
    load_treatment_info=True,
)

# Name the row index ModelID so perform_correlation can align it with the latent frames.
# rename_axis replaces the former inplace reset_index/rename/set_index round-trip, which
# raised KeyError if the loader ever returned an index with a name other than ModelID.
prism_df = prism_df.rename_axis("ModelID")
prism_df.head()

Unnamed: 0_level_0,BRD-A00077618-236-07-6::2.5::HTS,BRD-A00100033-001-08-9::2.5::HTS,BRD-A00147595-001-01-5::2.5::HTS,BRD-A00218260-001-03-4::2.5::HTS,BRD-A00376169-001-01-6::2.5::HTS,BRD-A00520476-001-07-4::2.5::HTS,BRD-A00546892-001-02-6::2.5::HTS,BRD-A00578795-001-04-3::2.5::HTS,BRD-A00758722-001-04-9::2.5::HTS,BRD-A00827783-001-24-6::2.5::HTS,...,BRD-K98557884-001-01-6::2.5::MTS004,BRD-K99077012-001-01-9::2.332734192::MTS004,BRD-K99199077-001-16-1::2.603211317::MTS004,BRD-K99431849-001-01-7::2.500018158::MTS004,BRD-K99447003-335-04-1::2.37737659::MTS004,BRD-K99506538-001-03-8::2.5::MTS004,BRD-K99616396-001-05-1::2.499991421::MTS004,BRD-K99879819-001-02-1::2.5187366::MTS004,BRD-K99919177-001-01-3::2.5::MTS004,BRD-M63173034-001-03-6::2.64076472::MTS004
ModelID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ACH-000001,-0.015577,-0.449332,0.489379,0.206675,0.27273,0.021036,-0.02546,0.467158,-0.736306,0.644137,...,0.429238,0.204841,0.150055,-0.575404,-0.101247,0.399233,-0.127658,-0.141651,-1.153652,0.510464
ACH-000007,-0.09573,0.257943,0.772349,-0.438502,-0.732832,0.779201,0.426523,-1.288508,-0.476133,-0.277105,...,-0.471486,0.212998,-0.12323,0.625527,0.383198,0.212031,0.349225,-0.387439,-0.831461,0.323558
ACH-000008,0.37948,-0.596132,0.548056,0.422269,-0.216986,0.081866,0.145335,-0.570841,-0.512119,0.452698,...,-0.111951,0.534787,0.206642,-0.410153,-0.560722,-0.036088,0.158071,0.171043,-3.94709,0.09931
ACH-000010_FAILED_STR,0.11889,-0.231615,0.621937,-0.202707,-1.005139,-0.213739,0.020246,-0.795278,,0.679571,...,0.200605,-0.075356,0.61031,-0.019413,-0.202971,0.218158,-0.411009,-0.18154,-3.010225,0.090652
ACH-000011,0.145346,-0.499274,0.26747,0.157804,-0.272286,0.207768,0.004464,-0.19168,-0.310375,0.112537,...,-0.076863,0.026002,0.139921,-0.261704,0.085339,0.447482,0.16462,-0.565251,-4.110627,0.222394


In [73]:
# Negative control: permute each drug column independently across ModelIDs (rows).
# Each drug keeps its value distribution (including NaNs) while any real association
# with the latent dimensions is broken.
# A fixed seed makes the control reproducible under Restart & Run All; the previous
# `sample(frac=1, random_state=None)` drew a different null on every run.
SHUFFLE_SEED = 0
shuffle_rng = np.random.default_rng(SHUFFLE_SEED)

prism_df_shuffled = prism_df.copy()
# ModelID is the index after loading, so this drop() is only a guard in case it is ever a column
for drug_col in prism_df_shuffled.columns.drop('ModelID', errors='ignore'):
    prism_df_shuffled[drug_col] = shuffle_rng.permutation(prism_df_shuffled[drug_col].to_numpy())

prism_df_shuffled.head()

Unnamed: 0_level_0,BRD-A00077618-236-07-6::2.5::HTS,BRD-A00100033-001-08-9::2.5::HTS,BRD-A00147595-001-01-5::2.5::HTS,BRD-A00218260-001-03-4::2.5::HTS,BRD-A00376169-001-01-6::2.5::HTS,BRD-A00520476-001-07-4::2.5::HTS,BRD-A00546892-001-02-6::2.5::HTS,BRD-A00578795-001-04-3::2.5::HTS,BRD-A00758722-001-04-9::2.5::HTS,BRD-A00827783-001-24-6::2.5::HTS,...,BRD-K98557884-001-01-6::2.5::MTS004,BRD-K99077012-001-01-9::2.332734192::MTS004,BRD-K99199077-001-16-1::2.603211317::MTS004,BRD-K99431849-001-01-7::2.500018158::MTS004,BRD-K99447003-335-04-1::2.37737659::MTS004,BRD-K99506538-001-03-8::2.5::MTS004,BRD-K99616396-001-05-1::2.499991421::MTS004,BRD-K99879819-001-02-1::2.5187366::MTS004,BRD-K99919177-001-01-3::2.5::MTS004,BRD-M63173034-001-03-6::2.64076472::MTS004
ModelID,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
ACH-000001,-0.163078,-0.141022,0.280622,-0.089914,0.058201,0.85399,0.228885,0.077809,-1.851323,0.647325,...,0.047355,0.068484,-0.007079,0.005752,0.135617,0.171826,-0.52266,-0.751369,-2.450714,0.83605
ACH-000007,-0.227344,-0.045904,0.077096,0.361099,-0.189716,0.203614,-0.423503,-0.003777,0.132081,0.22776,...,0.554327,0.381861,0.434845,-0.412535,-0.191883,0.139679,-0.363183,-0.126036,-1.33736,0.353163
ACH-000008,-1.784574,-0.249375,-0.55033,-0.076472,0.237283,0.808024,0.222299,-0.643501,-1.608127,-0.420496,...,-0.114721,0.353309,-0.376246,0.15722,0.794863,0.67248,0.539113,-0.068331,-2.963241,-0.494247
ACH-000010_FAILED_STR,-0.180194,0.273475,0.650977,0.469955,-0.260721,0.521024,0.268623,-0.037395,-1.395684,0.287478,...,0.437695,0.002237,-0.233532,-0.294671,-0.031543,-0.27061,-0.300999,0.280022,-2.494909,0.01795
ACH-000011,0.397743,-0.078387,0.10314,0.026767,0.215636,0.105494,-0.296193,-0.52115,-0.676141,0.006124,...,0.336527,0.417571,-0.110321,,0.285538,0.307538,,0.437737,-1.885755,0.122579


In [74]:
# Load metadata
metadata_df_dir = pathlib.Path("../0.data-download/data/metadata_df.parquet")
metadata = pd.read_parquet(metadata_df_dir)
print(metadata.shape)

# Load dependency data
dependency_df, gene_dict_df = load_model_data(dependency_file, gene_dict_file)

# Min-max scale every float64/int column to [0, 1]; columns of other dtypes (e.g. the
# ModelID strings) are left untouched.
# NOTE(review): presumably this mirrors the preprocessing the saved PCA/ICA/NMF models
# were fit on — confirm, since their transform() output depends on it.
numeric_cols = dependency_df.select_dtypes(include=['float64', 'int']).columns
scaler = MinMaxScaler()
dependency_df[numeric_cols] = scaler.fit_transform(dependency_df[numeric_cols])

(958, 3)
(1150, 18444)


In [75]:
# Load the subsetted train+test matrix and wrap it for batched inference
train_and_test_subbed_dir = pathlib.Path("../0.data-download") / "data" / "train_and_test_subbed.parquet"
train_and_test_subbed = pd.read_parquet(train_and_test_subbed_dir)

# DataFrame -> NumPy -> float32 tensor
train_test_array = train_and_test_subbed.to_numpy()
train_test_tensor = torch.tensor(train_test_array, dtype=torch.float32)

# Single-tensor dataset; shuffle=False keeps batches in the parquet's row order
# (presumably required so extracted latent rows line up with `metadata` — confirm)
tensor_dataset = TensorDataset(train_test_tensor)
train_and_test_subbed_loader = DataLoader(
    tensor_dataset,
    batch_size=32,
    shuffle=False,
)

In [76]:
# Define the location of the saved models and output directory for correlation results
model_save_dir = pathlib.Path("../4.gene_expression_signatures/saved_models")
output_dir = pathlib.Path("results")
output_dir.mkdir(parents=True, exist_ok=True)

# Latent dimensions and model names to iterate over
latent_dims = [2, 10, 20, 50, 100, 200]
model_names = ["pca", "ica", "nmf", "vanillavae", "betavae", "betatcvae"]
sklearn_model_names = ["pca", "ica", "nmf"]

# Latent extractors for the torch VAE models; each is called as (model, loader, metadata)
vae_extractors = {
    "betavae": extract_latent_dimensions,
    "betatcvae": tc_extract_latent_dimensions,
    "vanillavae": vvae_extract_latent_dimensions,
}

# File to store the combined correlation results; it doubles as a resumable cache
final_output_file = output_dir / "combined_latent_drug_correlations.parquet"
try:
    combined_results_df = pd.read_parquet(final_output_file)
    print(f"Loaded existing results from {final_output_file}")
except FileNotFoundError:
    # If the file doesn't exist, initialize an empty DataFrame
    combined_results_df = pd.DataFrame()
    print("No existing file found. Initialized empty DataFrame.")

# (model, latent size) pairs already in the cache, built once instead of re-scanning the
# whole results frame on every iteration
if combined_results_df.empty:
    processed_pairs = set()
else:
    processed_pairs = set(zip(combined_results_df['model'], combined_results_df['full_model_z']))

# Collect new result frames and concatenate once after the loop; concatenating inside
# the loop re-copied the ever-growing (multi-million-row) results frame every iteration.
new_result_frames = []
for num_components in latent_dims:
    for model_name in model_names:
        if (model_name, num_components) in processed_pairs:
            print(f"Skipping {model_name} with {num_components} dimensions as it is already processed.")
            continue

        model_filename = model_save_dir / f"{model_name}_{num_components}_components_model.joblib"
        if not model_filename.exists():
            print(f"Model file {model_filename} not found. Skipping.")
            continue

        print(f"Loading model from {model_filename}")
        # NOTE(review): joblib.load unpickles arbitrary objects; only load trusted model files
        model = joblib.load(model_filename)

        if model_name in sklearn_model_names:
            latent_df = sklearn_extract_latent_dimensions(model, model_name, dependency_df)
        elif model_name in vae_extractors:
            latent_df = vae_extractors[model_name](model, train_and_test_subbed_loader, metadata)
        else:
            # Fail loudly rather than silently reusing the previous iteration's latent_df
            raise ValueError(f"No latent extractor defined for model '{model_name}'")

        # Assumes the first column holds ModelIDs (true for the sklearn path; presumably
        # also for the VAE extractors — confirm); integer-named columns become z_<int>
        latent_df.columns = ['ModelID'] + [f'z_{col}' if isinstance(col, int) else col for col in latent_df.columns[1:]]

        # Pearson correlation against the real drug data and the shuffled negative control
        new_result_frames.append(
            perform_correlation(latent_df, prism_df, model_name, num_components)
        )
        new_result_frames.append(
            perform_correlation(latent_df, prism_df_shuffled, model_name, num_components, shuffle=True)
        )

if new_result_frames:
    combined_results_df = pd.concat([combined_results_df, *new_result_frames], ignore_index=True)

# Save the combined results to a parquet file
combined_results_df.to_parquet(final_output_file)
print(f"Saved combined results to {final_output_file}")

Loaded existing results from results/combined_latent_drug_correlations.parquet
Skipping pca with 2 dimensions as it is already processed.
Skipping ica with 2 dimensions as it is already processed.
Loading model from ../4.gene_expression_signatures/saved_models/nmf_2_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/vanillavae_2_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/betavae_2_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/betatcvae_2_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/pca_10_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/ica_10_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/nmf_10_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/vanillavae_10_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/betavae_10_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/betatcvae_10_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/pca_20_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/ica_20_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/nmf_20_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/vanillavae_20_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betavae_20_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betatcvae_20_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/pca_50_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/ica_50_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/nmf_50_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/vanillavae_50_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betavae_50_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betatcvae_50_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/pca_100_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/ica_100_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/nmf_100_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/vanillavae_100_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betavae_100_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betatcvae_100_components_model.joblib


  return torch.load(io.BytesIO(b))


Loading model from ../4.gene_expression_signatures/saved_models/pca_200_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/ica_200_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/nmf_200_components_model.joblib
Loading model from ../4.gene_expression_signatures/saved_models/vanillavae_200_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betavae_200_components_model.joblib


  return torch.load(io.BytesIO(b))
  corr, p_value = pearsonr(latent_values_valid, drug_values_valid)


Loading model from ../4.gene_expression_signatures/saved_models/betatcvae_200_components_model.joblib


  return torch.load(io.BytesIO(b))


Saved combined results to results/combined_latent_drug_correlations.parquet


In [84]:
# Attach treatment annotations: prism_trt_df's `column_name` matches the `drug` column
# of the correlation results.
annotation_cols = ['name', 'moa', 'target', 'indication', 'phase']
prism_trt_df_filtered = prism_trt_df[['column_name', *annotation_cols]]

# final_output_file is overwritten below with the annotated table, and the results cell
# reloads that file on the next run. Drop annotation columns that may already be present
# so a re-run does not produce name_x/name_y-style duplicate columns.
results_to_annotate_df = combined_results_df.drop(columns=annotation_cols, errors='ignore')

correlation_df_merged = pd.merge(
    results_to_annotate_df,
    prism_trt_df_filtered,
    how='left',
    left_on='drug',
    right_on='column_name',
).drop(columns=['column_name'])

# Rows with |r| > 0.1.
# NOTE(review): this is an effect-size cutoff only; p_value is not considered here.
significant_corr_df = correlation_df_merged[
    correlation_df_merged['pearson_correlation'].abs() > 0.1
]

# Overwrite the results parquet with the annotated table
correlation_df_merged.to_parquet(final_output_file)
print(f"Saved combined results to {final_output_file}")

# Save as CSV for R
# NOTE(review): this writes combined_results_df, not the annotated correlation_df_merged,
# so the CSV only has annotation columns when they were reloaded from a previous run's
# parquet — confirm which table the R side expects.
csv_output_file = output_dir / "combined_latent_drug_correlations.csv"
combined_results_df.to_csv(csv_output_file, index=False)

Saved combined results to results/combined_latent_drug_correlations.parquet


In [83]:
# The 50 strongest associations, ranked by |r| regardless of sign
combined_results_df.sort_values(
    by='pearson_correlation',
    key=lambda corr: corr.abs(),
    ascending=False,
).head(50)

Unnamed: 0,z,full_model_z,model,drug,pearson_correlation,p_value,shuffled
832323,3,20,ica,BRD-K62627508-001-01-5::2.5::HTS,-0.654912,3.553648e-58,False
832896,3,20,ica,BRD-K79584249-001-01-3::2.5::HTS,-0.629819,1.981946e-52,False
832400,3,20,ica,BRD-K64925568-001-01-8::2.5::HTS,-0.619451,1.639182e-50,False
330921,8,10,nmf,BRD-K62627508-001-01-5::2.5::HTS,-0.593012,2.091656e-45,False
1043193,8,20,nmf,BRD-K62627508-001-01-5::2.5::HTS,-0.575836,2.469647e-42,False
330998,8,10,nmf,BRD-K64925568-001-01-8::2.5::HTS,-0.57012,2.3725529999999997e-41,False
331494,8,10,nmf,BRD-K79584249-001-01-3::2.5::HTS,-0.562082,7.774741e-40,False
1043766,8,20,nmf,BRD-K79584249-001-01-3::2.5::HTS,-0.549912,7.193560999999999e-38,False
1043270,8,20,nmf,BRD-K64925568-001-01-8::2.5::HTS,-0.548167,9.464346e-38,False
6147808,11,100,nmf,BRD-K50010139-001-02-3::2.5::MTS004,-0.515469,5.600041e-32,False
