# <center> Characterizing the evolution of adipose-derived mesenchymal stem cells (AMSCs) after FFA addition with VAEs </center>

**Summary**
    
In this project we will use MOVE (Multi-Omics Variational Autoencoder) to integrate Lipocyte Profiler data, RNA-seq data and polygenic risk scores (PRSs). We aim to:

- Find the most relevant variables to characterize the system of study.
- Find associations between gene expression levels and morphological features.
- Analyze the effects of lowering/increasing PRS scores.
    
**Papers of interest:**
    
- Allesøe, R.L., Lundgaard, A.T., Hernández Medina, R. *et al*. Discovery of drug–omics associations in type 2 diabetes with generative deep-learning models. *Nat Biotechnol* (2023). https://www.nature.com/articles/s41587-022-01520-x
- Samantha Laber, Sophie Strobel, Josep M. Mercader, Hesam Dashti, Felipe R.C. dos Santos, Phil Kubitz, Maya Jackson, Alina Ainbinder, Julius Honecker, Saaket Agrawal, Garrett Garborcauskas, David R. Stirling, Aaron Leong, Katherine Figueroa, Nasa Sinnott-Armstrong, Maria Kost-Alimova, Giacomo Deodato, Alycen Harney, Gregory P. Way, Alham Saadat, Sierra Harken, Saskia Reibe-Pal, Hannah Ebert, Yixin Zhang, Virtu Calabuig-Navarro, Elizabeth McGonagle, Adam Stefek, Josée Dupuis, Beth A. Cimini, Hans Hauner, Miriam S. Udler, Anne E. Carpenter, Jose C. Florez, Cecilia Lindgren, Suzanne B.R. Jacobs, Melina Claussnitzer. Discovering cellular programs of intrinsic and extrinsic drivers of metabolic traits using LipocyteProfiler. *Cell Genomics*, Volume 3, Issue 7 (2023). https://doi.org/10.1016/j.xgen.2023.100346
    
    
**Theoretical insights**

**About Lipocyte Profiler:**

Lipocyte Profiler extends Cell Painting and CellProfiler. Cell images are acquired in multiple channels, staining for the nucleus, mitochondria, Golgi apparatus, lipid vesicles, etc. Morphological features characterizing each cell's shape and texture (e.g. granularity, radial distribution of certain organelles, intensity) are then quantified and summarized in feature vectors.

**About RNA-seq:**

RNA sequencing will provide gene expression levels at different timepoints for:
- Control cells
- Cells that have been subject to the addition of free fatty acids (FFAs).

**About MOVE:**

MOVE is built on a variational autoencoder (VAE), i.e. a neural network trained to compress the input information and reconstruct it as accurately as possible. The variational part makes it more robust to noise in the input, allows us to treat it from a Bayesian perspective, and enables us to use it as a generative model, i.e. to generate new samples.

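MOVE applies *in silico* perturbations to the input (changing the value of one input feature and measuring how the reconstructions of the other features shift) in order to find associations between variables.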
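
To make the idea concrete, here is a minimal sketch of a perturbation experiment on a toy linear autoencoder with hand-picked weights. The model and the helper names (`reconstruct`, `perturbation_effect`) are hypothetical stand-ins for a trained MOVE model, not MOVE's implementation. We shift one input feature and measure how much each reconstructed feature moves on average.

```python
import numpy as np

# Toy "trained" linear autoencoder (hypothetical weights): a single latent unit
# captures the signal shared by features 0 and 1; feature 2 is unrelated.
u = np.array([1.0, 1.0, 0.0]) / np.sqrt(2)
W_enc = u[None, :]   # (latent=1, features=3)
W_dec = u[:, None]   # (features=3, latent=1)

def reconstruct(x):
    """Encode samples (n, 3) into the latent space and decode them back."""
    z = x @ W_enc.T      # (n, 1) latent representation
    return z @ W_dec.T   # (n, 3) reconstruction

def perturbation_effect(x, feature, delta=1.0):
    """Average change of every reconstructed feature after shifting one input feature."""
    x_perturbed = x.copy()
    x_perturbed[:, feature] += delta
    return (reconstruct(x_perturbed) - reconstruct(x)).mean(axis=0)

rng = np.random.default_rng(0)
x = rng.normal(size=(100, 3))
effect = perturbation_effect(x, feature=0)
print(effect)  # feature 1 shifts together with feature 0, feature 2 does not move
```

Feature 1 responds to the perturbation of feature 0 while feature 2 does not, which is the kind of signal used to call an association. MOVE repeats this across several retrained models to assess how robust each association is.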

*__The architecture of the model__*

MOVE consists of two parts: an encoder and a decoder.

The encoder compresses the input and extracts the information shared between features. It is composed of the input layer (a vector of feature values for a given sample), hidden layers connecting the different features, and a latent layer.

The latent layer contains a compressed representation of our data (a lower-dimensional feature vector). Each sample lies somewhere in what we call "the latent space". The decoder then reconstructs the input from its latent representation.
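
The flow through encoder, latent layer and decoder can be sketched as a single forward pass in NumPy. This is a minimal illustration with random, untrained weights, made-up layer sizes and ReLU activations, not MOVE's PyTorch implementation. The squared-error reconstruction term is an assumption; as far as we understand, MOVE uses losses adapted to each data type. The `beta` factor weighting the KL divergence plays the same role as the `task.model.beta` regularization weight that appears in the tuning section.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, n_features, n_hidden, n_latent = 8, 20, 16, 4  # hypothetical sizes

# Randomly initialised weights (a real model learns these during training)
W_h = rng.normal(scale=0.1, size=(n_features, n_hidden))
W_mu = rng.normal(scale=0.1, size=(n_hidden, n_latent))
W_logvar = rng.normal(scale=0.1, size=(n_hidden, n_latent))
W_dec_h = rng.normal(scale=0.1, size=(n_latent, n_hidden))
W_out = rng.normal(scale=0.1, size=(n_hidden, n_features))

def relu(a):
    return np.maximum(a, 0.0)

def vae_forward(x, beta=1e-4):
    # Encoder: input layer -> hidden layer -> mean and log-variance of the latent distribution
    h = relu(x @ W_h)
    mu, logvar = h @ W_mu, h @ W_logvar
    # Reparameterisation trick: sample z = mu + sigma * eps with eps ~ N(0, I)
    eps = rng.normal(size=mu.shape)
    z = mu + np.exp(0.5 * logvar) * eps
    # Decoder: latent layer -> hidden layer -> reconstruction of the input
    x_hat = relu(z @ W_dec_h) @ W_out
    # Loss: reconstruction error plus beta-weighted KL divergence from N(0, I)
    recon = np.mean(np.sum((x - x_hat) ** 2, axis=1))
    kl = np.mean(-0.5 * np.sum(1 + logvar - mu**2 - np.exp(logvar), axis=1))
    return z, x_hat, recon + beta * kl

x = rng.normal(size=(n_samples, n_features))
z, x_hat, loss = vae_forward(x)
print(z.shape, x_hat.shape)  # -> (8, 4) (8, 20)
```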

      

# Data preprocessing

### Import required packages
The first step is to load all third party packages required to perform the different tasks in this notebook.

In [None]:
import os
import io
import pandas as pd
from firecloud import fiss
from pathlib import Path
import seaborn as sns
import sys
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
from itertools import chain
import gc

mpl.rcParams['font.family'] = 'Latin Modern Roman' #Font for the plots

We can now define a number of global hyperparameters for the notebook:

In [None]:
############ Hyperparams #############
INSTALL = True # If it is the first run, we want to install MOVE in this environment
CHECK_CORR = False # To visualize the correlation matrices of RNA features
RESIDUALIZE = True #Do we want to correct the filtered RNA features for sex, batch and age dependencies?
HYPER_TUNNING = True # To perform hyper-parameter tuning

DATA_FOLDERS = ['data_residualized/','interim_data_rsd/'] if RESIDUALIZE else ['data/','interim_data/'] 
FIGURE_FOLDER = 'paper_figures/'
! mkdir -p {FIGURE_FOLDER} # Make figure folder

The first time we run the notebook, we install a number of packages:

In [None]:
if INSTALL:
    # The commented-out line below clones the development branch of MOVE (developer-continuous-v3), which can
    # handle perturbations of continuous variables; the active line clones the default branch instead
    #! git clone -b developer-continuous-v3 https://github.com/RasmussenLab/MOVE.git /home/jupyter/.local/bin/MOVE
    ! git clone https://github.com/RasmussenLab/MOVE.git /home/jupyter/.local/bin/MOVE
    sys.path.append("/home/jupyter/.local/bin/MOVE/src")
    # ! pip install -e /home/jupyter/.local/bin/MOVE/ 
    
    #from terra_notebook_utils import table, gs, drs
    # Enable code folding and collapsible headings in the notebook
    ! jupyter nbextension enable codefolding/main
    ! jupyter nbextension enable collapsible_headings/main

    ! pip install omegaconf upsetplot umap-learn
    ! pip install -r /home/jupyter/.local/bin/MOVE/requirements.txt


### Set environment variables


**Navigating the Terra interface and file structure:**

For this project we will mainly use two buckets:
1. The "local" project bucket, where all predefined configuration files are stored. These files are found in the DATA tab of this workspace: ```other data``` $\to$ ```files```
2. The "external" bucket, called ```shared_2023```, from which we get the data.

We can store the names of these buckets in variables for convenience.

In [None]:
BILLING_PROJECT_ID = os.environ['WORKSPACE_NAMESPACE']
WORKSPACE = os.environ['WORKSPACE_NAME']
bucket = os.environ['WORKSPACE_BUCKET']

#EXTERNAL_BUCKET = "gs://amsc_datasets/"
EXTERNAL_BUCKET = "gs://collaborations_marc/shared_2023/"
CONTENT = !gsutil ls {EXTERNAL_BUCKET}
    
print("Billing project: " + BILLING_PROJECT_ID)
print("Workspace: " + WORKSPACE)
print("Bucket: " + bucket)
print("Files in the external bucket:", *CONTENT, sep='\n')

### Split matching RNA-LP file and create IDs

We are going to work with the file _matchedID_normalizedRNA_LP.csv_. To store the data in a MOVE-friendly manner we have to:

1) __Create an ID file:__ We need a unique ID for every sample, so we keep an ID column with numerical IDs.

2) __Split the data into smaller datasets that will be perturbed one by one.__ These datasets will be:

    - Demographics
        - Sex 
        - Age (will not be provided to the model)
        - Batch (will not be provided to the model)
    - RNA IDs
    - Nuclei
    - Cells
    - Cytoplasm
    
3) __Filter out genes that have low expression or low variance.__
4) __Filter out LP features that have extreme outliers.__

In [None]:
# Read the original csv files as pandas DataFrames
RNA_LP = EXTERNAL_BUCKET + "matchedID_normalizedRNA_LP.csv" 
df_RNA_LP = pd.read_csv(RNA_LP)

covariates = df_RNA_LP[['age','sex','batch']]

variants = EXTERNAL_BUCKET + "variants.csv"
df_variants = pd.read_csv(variants)

PRS = EXTERNAL_BUCKET + "PRS.csv"
df_PRS = pd.read_csv(PRS)

In [None]:
# Finding patients that were not genotyped
patient_set = set(df_RNA_LP['patientID'].values)
variant_set = set(df_variants['SubjID'].values)
missing =  patient_set - variant_set
print( "These patients in the RNA-LP dataset were not genotyped:\n", *missing, sep="\n")

In [None]:
# Number of samples falling under different subgroups:
for cell_type in ['sc', 'vc']:
    for ffa in [0, 1]:
        for day in [0, 3, 8, 14]:
            # Boolean masks for the cell type, the day and the FFA treatment
            is_cell_type = df_RNA_LP['cellType'] == cell_type
            is_day = df_RNA_LP['Day'] == day
            is_ffa = df_RNA_LP['FFA'] == ffa

            # Keep the rows that satisfy all three conditions
            filtered_df = df_RNA_LP[is_cell_type & is_day & is_ffa]

            # Count the rows in the filtered dataframe
            count = filtered_df.shape[0]

            print(f"For cell type {cell_type}, Day {day}, FFA {ffa} we have {count} samples")

In [None]:
# Visualizing sex-dependent gene expression variation
fig = plt.figure(figsize=(15,5))
df_violin_1 = df_RNA_LP[df_RNA_LP['sex'] == 1].filter(regex="^ENSG").iloc[:,:50]
df_violin_2 = df_RNA_LP[df_RNA_LP['sex'] == 2].filter(regex="^ENSG").iloc[:,:50]
plt.violinplot(df_violin_1)
plt.violinplot(df_violin_2)
plt.xticks(np.arange(1,51),df_violin_1.columns, rotation=90)
plt.xlabel('Gene')
plt.ylabel('Expression (DESeq norm counts)')
fig.show()

### Compensate for linear effects of sex, age and batch on gene expression features.

Here we will train a linear model to remove all linear contributions of the covariates to the expression levels.
We should decide which genes to keep before residualizing, since the transformation significantly changes the data.
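
The `residualize` function defined in the next cell fits one ordinary least squares (OLS) model per gene and keeps the residuals. The same operation can be written in closed form with NumPy. This sketch uses a hypothetical helper `residualize_lstsq` and synthetic data, and shows that the residuals are, by construction, orthogonal to the covariates and the intercept. It assumes numeric covariates; a covariate with more than two unordered levels (such as batch, if coded as integers) is usually dummy-encoded before a linear fit.

```python
import numpy as np

def residualize_lstsq(Y, C):
    """Remove the linear contribution of covariates C (n x k) from targets Y (n x m)."""
    X = np.column_stack([np.ones(len(C)), C])     # add an intercept column
    beta, *_ = np.linalg.lstsq(X, Y, rcond=None)   # one OLS fit per target column
    return Y - X @ beta                             # residuals

rng = np.random.default_rng(1)
n = 200
sex = rng.integers(1, 3, size=n)       # coded 1/2 as in the dataset
age = rng.uniform(20, 70, size=n)
C = np.column_stack([sex, age])
# Five synthetic "genes" with a sex and an age effect plus noise
Y = 3.0 * sex[:, None] + 0.1 * age[:, None] + rng.normal(size=(n, 5))

R = residualize_lstsq(Y, C)
# Largest dot product between a covariate (or the intercept) and a residual column: ~0
print(np.abs(np.column_stack([np.ones(n), C]).T @ R).max())
```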

In [None]:
def residualize(targets, covariates):
    """
    This function trains a linear model to take into account the contributions
    of age, sex and batch in the gene expression values (RNA counts).
    
    Args:
        targets: Pandas dataframe of shape (N_samples x N_target_features) 
                 containing the target features to correct.
        covariates: Pandas dataframe of shape (N_samples x N_covariates)
                    containing the independent variables that we want to 
                    correct for.
    Returns:
        residuals: Pandas dataframe of shape (N_samples x N_target_features)
                   with the linear contributions of the covariates removed.
    
    """
    import statsmodels.api as sm
    
    # Add a constant term to the independent variables
    covariates = sm.add_constant(covariates)
    
    # Fit one linear regression per target column (multi-output OLS)
    model = sm.OLS(targets, covariates).fit()
    # Align the predictions with the targets' index and columns before subtracting
    predictions = pd.DataFrame(np.asarray(model.predict(covariates)),
                               index=targets.index, columns=targets.columns)
    
    # Get the residuals (corrected values) for each target variable
    residuals = targets - predictions
    
    # Return the residuals (corrected targets)
    return residuals

def plot_expression_distributions(targets, covariates, title, covariate_of_interest = 'sex', n_genes = 50, savepath = Path(FIGURE_FOLDER), ylabel = 'Expression'):
    """
    Overlay violin plots of the first n_genes target features for the two levels
    (coded 1 and 2) of a covariate, and save the figure as {title}.png in savepath.
    """
    fig = plt.figure(figsize=(9,3))
    df_violin_1 = targets[covariates[covariate_of_interest] == 1].iloc[:, :n_genes]
    df_violin_2 = targets[covariates[covariate_of_interest] == 2].iloc[:, :n_genes]
    plt.violinplot(df_violin_1)
    plt.violinplot(df_violin_2)
    plt.xticks(np.arange(1, df_violin_1.shape[1] + 1), df_violin_1.columns, rotation=90)
    plt.xlabel('Gene')
    plt.ylabel(ylabel)
    plt.tight_layout()
    fig.savefig(savepath / f"{title}.png", dpi=200)

### Create split datasets

This is one of the most important steps in the pipeline.
For the RNA processing, we first:
- Filter out genes that have very low expression across samples.
    - We keep genes for which at least 20% of the samples have an expression above 10 (DESeq-normalized) counts.
- Keep the most variable genes: after a log2(x+1) transform, only genes whose variance lies above the 75th percentile (i.e. the top 25%) are kept.

If `RESIDUALIZE` is set, we then residualize the data, i.e. remove the linear contributions of age, batch and sex, and finally Z-score normalize each gene.
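
The two filters can be sketched as a small pandas function on toy counts. The function name `filter_genes` and the toy genes are hypothetical. Here the minimum number of samples is derived from a fraction, rounded up: for the 269 samples of this dataset, rounding 20% up gives the 54 used in the next cell.

```python
import numpy as np
import pandas as pd

def filter_genes(counts, min_expression=10, min_frac=0.2, var_quantile=0.75):
    """Keep expressed, highly variable genes from a samples x genes count table."""
    # 1. Expressed: more than min_expression counts in at least min_frac of the samples
    min_samples = int(np.ceil(min_frac * len(counts)))
    expressed = (counts > min_expression).sum(axis=0) >= min_samples
    log2_counts = np.log2(counts.loc[:, expressed] + 1)
    # 2. Highly variable: log2 variance above the chosen quantile of all log2 variances
    gene_vars = log2_counts.var()
    return log2_counts.loc[:, gene_vars > gene_vars.quantile(var_quantile)]

counts = pd.DataFrame({
    "gene_low": np.zeros(50),              # never expressed -> dropped in step 1
    "gene_flat": np.full(50, 100.0),       # expressed but constant -> dropped in step 2
    "gene_a": np.tile([20.0, 40.0], 25),
    "gene_b": np.tile([20.0, 80.0], 25),
    "gene_c": np.tile([20.0, 160.0], 25),
    "gene_d": np.tile([20.0, 320.0], 25),  # most variable after the log2 transform
})
kept = filter_genes(counts)
print(list(kept.columns))
```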

In [None]:
from sklearn.preprocessing import scale

# Make a directory where we can store the newly splitted files in this session
! mkdir -p {DATA_FOLDERS[0]}
data_path = Path(DATA_FOLDERS[0])


# Filter the original dataset according to column names:
new_dataset_criteria_dict = {"batch":"^batch",
                             "patientID":"patientID",
                             "sex":"sex",
                             "age":"^age",
                             "BMI":"BMI",
                             "T2D":"^T2D",
                             "FFA":"^FFA",
                             "cellType":"^cellType",
                             "Day":"^Day",
                             "RNA":"^ENSG",
                             "Cyto":"^Cytoplasm",
                             "Cells":"^Cells",
                             "Nuc":"^Nuclei"}


################ RNA PREPROCESSING HYPERPARAMS ##############
min_expression = 10
min_samples = 54 # ca. 20% of the dataset, 269 samples 
var_threshold = 0.75 # Quantile
#############################################################

# Create unique ID file:
ids = pd.DataFrame(df_RNA_LP.index.astype(str).tolist(), columns=["ID"])
ids.to_csv(data_path /  "AMSC_ids.txt", index=False, header=False, sep="\t")

# Create separate csv files with shared IDS
for key in new_dataset_criteria_dict.keys():    
    splitted_df = df_RNA_LP.filter(regex=new_dataset_criteria_dict[key])
    
    if key == "T2D":
        # Store T2D status as strings (categorical labels)
        splitted_df = splitted_df.astype("str")

    if key == "RNA":
        # Filter out non-informative genes (low expression or low variance)
        # 1. Initial low-expression filtering
        genes_to_keep = (splitted_df > min_expression).sum(axis=0) >= min_samples
        filtered_df = splitted_df.loc[:, genes_to_keep]
        print(f"Shape after initial filtering: {filtered_df.shape}")

        # 2. Log2 transformation
        log2_df = np.log2(filtered_df + 1)  # Add 1 to avoid log(0)

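        # 3. Variability-based filtering: keep genes above the var_threshold quantile of variance
        #    (a separate variable is used so that the var_threshold hyperparameter is not overwritten)
        gene_vars = log2_df.var()
        var_cutoff = gene_vars.quantile(var_threshold)
        highly_variable_genes = gene_vars[gene_vars > var_cutoff].index
        splitted_df = log2_df[highly_variable_genes]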
        print(f"Shape after variability filtering: {splitted_df.shape}")

        if RESIDUALIZE:
            # 4. Residualization
            residualized_df = residualize(splitted_df, covariates)
            print("After residualizing")
            plot_expression_distributions(residualized_df, covariates, "RNA_after_residualizing", covariate_of_interest = 'sex', n_genes = 50)     
            
            # 5. Z-score normalization
            splitted_df = pd.DataFrame(scale(residualized_df), 
                                     index=residualized_df.index, 
                                     columns=residualized_df.columns)
            
            print("After Z-scoring")
            plot_expression_distributions(splitted_df, covariates, "RNA_after_rz_scoring", covariate_of_interest = 'sex', n_genes = 50)
            print(np.max(splitted_df.values), np.min(splitted_df.values))
            
    if (key == "Cells") or (key == "Cyto"):
        #Drop features that contain Nans:
        splitted_df = splitted_df.dropna(axis="columns")
        # Drop features with extreme outliers (any value beyond +/- value_thr)
        value_thr = 5
        columns_to_drop = splitted_df.columns[(splitted_df > value_thr).any() | (splitted_df < -value_thr).any()]
        print(f"Columns to drop for the {key} dataset", len(columns_to_drop))
        splitted_df = splitted_df.drop(columns=columns_to_drop)
    if key == "Nuc":
        print("NUC:\n")
        
    splitted_df = pd.concat([ids,splitted_df], axis = 1)
    splitted_df.to_csv(data_path / f"{key}.tsv", sep="\t", header=True, index=False)
    print("Final shape of the dataset:", splitted_df.shape)



### Visualize correlations between features (RNA-seq post filtering)

In [None]:
from scipy.cluster.hierarchy import linkage, leaves_list

MODALITY = "RNA"
data_path = Path(DATA_FOLDERS[0])
df = pd.read_csv(data_path / f"{MODALITY}.tsv", sep="\t").set_index("ID")
STEP = 200

# Calculate the correlation matrix
corr_matrix = df.corr()
columns = corr_matrix.columns

# Compute the linkage matrix
linkage_matrix = linkage(corr_matrix, method='average')

# Get the order of the variables
ordered_indices = leaves_list(linkage_matrix)
corr_matrix = corr_matrix.iloc[ordered_indices, ordered_indices]
columns = columns[ordered_indices]


# Plot the correlation matrix as a heatmap
fig = plt.figure(figsize=(6, 6))
plt.imshow(corr_matrix, cmap='seismic', vmin=-1, vmax=1)
plt.xticks(np.arange(0,len(columns),STEP), columns[::STEP], rotation = 90)
plt.yticks(np.arange(0,len(columns),STEP), columns[::STEP])
plt.title(f'{MODALITY} Correlation Matrix')
plt.colorbar()
plt.tight_layout()
fig.savefig(Path(FIGURE_FOLDER) / f"{MODALITY}_Correlation_matrix.png", dpi=200)

# Extract the upper triangle of the correlation matrix without the diagonal
upper_triangle_indices = np.triu_indices_from(corr_matrix, k=1)
upper_triangle_values = corr_matrix.values[upper_triangle_indices]

# Plot the frequency distribution of the correlation coefficients
fig = plt.figure(figsize=(5, 3))
sns.kdeplot(abs(upper_triangle_values), bw_adjust=0.5)
plt.xlabel('|Correlation coefficient|')
plt.ylabel('Density')
plt.title(f'Density Plot of Correlation Coefficients in {MODALITY}')
plt.tight_layout()
fig.savefig(Path(FIGURE_FOLDER) / f"{MODALITY}_Correlation_distribution.png", dpi=200)

### Visualize mutual information between features (RNA post-processing, optional)

> 🕑 This step takes a long time and was skipped.

In [None]:
MI = False

if MI:

    from sklearn.feature_selection import mutual_info_regression
    from joblib import Parallel, delayed


    def mutual_info_continuous(x, y):
        return mutual_info_regression(x.reshape(-1, 1), y)[0]

    def compute_mutual_info_matrix_continuous(df):
        n = df.shape[1]
        mi_matrix = np.zeros((n, n))

        def compute_mi(i, j):
            if i == j:
                return (i, j, 0)
            else:
                mi = mutual_info_continuous(df.iloc[:, i].values, df.iloc[:, j].values)
                return (i, j, mi)

        results = Parallel(n_jobs=-1)(delayed(compute_mi)(i, j) for i in range(n) for j in range(i, n))

        for i, j, mi in results:
            mi_matrix[i, j] = mi
            mi_matrix[j, i] = mi

        return pd.DataFrame(mi_matrix, index=df.columns, columns=df.columns)

    # df is the RNA dataset loaded in the correlation cell above
    mi_matrix = compute_mutual_info_matrix_continuous(df)

    # Perform hierarchical clustering
    linkage_matrix = linkage(mi_matrix, method='average')
    ordered_indices = leaves_list(linkage_matrix)

    # Reorder the mutual information matrix
    ordered_mi_matrix = mi_matrix.iloc[ordered_indices, ordered_indices]


    # Plot the reordered MI matrix as a heatmap
    plt.figure(figsize=(7, 7))
    plt.imshow(ordered_mi_matrix, cmap='inferno')
    plt.title(f'{MODALITY} Mutual Information Matrix')
    plt.colorbar()
    plt.show()


Now we can create the datasets for the variants and PRSs, mapping subject IDs to the new numerical IDs.

**Create tsv file for the variants**

We can first visualize the default variants file:

In [None]:
pd.set_option('display.max_rows', 10)
pd.set_option('display.max_columns', 10)
# Output some of the first and last lines of the original variants dataframe
df_variants

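We would like to match these patients to the patients for whom we have LP and gene expression data. In addition, we will unphase the heterozygous genotypes, i.e. convert the phased calls `0|1` and `1|0` to the unphased `0/1`, so that both forms map to the same category.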
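
Phased calls use `|`, and the order of the two alleles records which haplotype carries each one; unphasing discards that order. The next cell only rewrites heterozygous calls, so homozygous calls keep their phased form (`0|0`, `1|1`), which stays consistent as long as every sample uses the same separator. A more general sketch, assuming biallelic genotypes stored as strings such as `1|0` and missing calls as NaN (the helper name `unphase` is hypothetical), sorts the two alleles and rejoins them with `/`:

```python
import numpy as np
import pandas as pd

def unphase(genotype):
    """Convert a phased call such as '1|0' to an unphased one ('0/1'); leave NaN untouched."""
    if not isinstance(genotype, str):
        return genotype
    alleles = sorted(genotype.replace("|", "/").split("/"))
    return "/".join(alleles)

calls = pd.DataFrame({"rs1": ["0|1", "1|0", "1|1", np.nan],
                      "rs2": ["0|0", "0/1", "1|1", "0|1"]})
# Apply the conversion element-wise, column by column
unphased = calls.apply(lambda column: column.map(unphase))
print(unphased)
```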

In [None]:
# Create variant file with shared IDS: map original ID in variant file to new unique id
df_variants_new = df_variants.rename(columns={'SubjID':'patientID'}).set_index('patientID')
#print(df_variants_new.head)
df_patients = pd.read_csv(data_path / "patientID.tsv", sep="\t")
# Merge variant file (right df) with patient id mapping file (left df) and remove original patientID column
df_patients = df_patients.merge(df_variants_new, left_on='patientID', right_index=True).drop(columns="patientID")


# Since not all patients are genotyped we need to fill the missing indices with NAs:
new_index = range(len(ids))
df_patients = df_patients.reindex(new_index)
# We can now unphase the genotypes
df_patients = df_patients.replace('0|1', '0/1')
df_patients = df_patients.replace('1|0', '0/1')
df_patients["ID"] = df_patients.index
df_patients.to_csv(data_path / "variants.tsv", sep="\t", header=True, index=False)


**Create tsv file for the PRSs**

We will now create a separate tsv file for the PRS scores. 

Process-specific PRS scores will be Z-score normalized (they were not originally). This can be seen in the plot below: the first 5 violins originally contained values extremely close to zero (blue) and have been widened by the transformation (orange).

In [None]:
# Same protocol as before for PRS file
df_PRS_new = df_PRS.rename(columns={'FID':'patientID'}).set_index('patientID')
df_patients = pd.read_csv(data_path / "patientID.tsv", sep="\t")
df_patients = df_patients.merge(df_PRS_new, left_on='patientID', right_index=True).drop(columns="patientID")


df_patients["ID"] = df_patients.index
df_patients = df_patients.drop(columns="geneticBatch") # Genetic batch is not a PRS as such

# Violin plots of the PRS value distributions before scaling (blue)
plt.violinplot(df_patients.iloc[:,1:])
# Z-score the process-specific PRSs (columns containing "prs.")
prs_columns = [colname for colname in df_patients.columns if "prs." in colname]
df_patients[prs_columns] = scale(df_patients[prs_columns], axis=0)
# Violin plots of the PRS value distributions after scaling (orange)
plt.violinplot(df_patients.iloc[:,1:])
# Save the results
df_patients.to_csv(data_path / "PRS.tsv", sep="\t", header=True, index=False)



In [None]:
# Visualize PRS distributions and related covariates (numeric columns only)
for col in df_PRS.select_dtypes(include="number").columns:
    fig = plt.figure()
    plt.hist(df_PRS[col].dropna(), bins=100)
    plt.title(f"{col}")
    plt.show()

# Data analysis


### MOVE

MOVE can be found online in RasmussenLab's GitHub repository (https://github.com/RasmussenLab/MOVE). It can be installed as a pip package from the command line as follows:
```bash
pip install move-dl
```

However, we will proceed differently. At the beginning of the notebook we cloned MOVE into this environment, i.e. we "copied" MOVE's source code here. Now we will install it in "editable" mode (`-e` flag), so that changes made to the local copy of MOVE, found in ```/home/jupyter/.local/bin/MOVE```, take effect without reinstalling.

⚠️ **Code edits**
A few edits to MOVE's code were made for this project.
           
>1) **Storing multiple runs:**                    
> We stored the SHAP results for different hyperparameter choices and obtained the results as ensemble averages. To do that, we modified ```MOVE/src/move/visualization/feature_importance.py``` to store the results in txt files.
>
>2) **Removing feature_mask:**
>   When identifying associations (```identify_associations.py```), feature_masks were removed and we only kept nan_masks.
>
> 3) **Corrected cumulative distributions:**
> Cumulative distributions obtained when using the KS method were not properly normalized. We now multiply by the bin width (`bin_width`) so that the cumulative distribution ranges from 0 to 1 (see the `plot_cumulative` function in ```dataset_distributions.py```).
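
The third fix can be illustrated with NumPy alone (a sketch on synthetic values, not MOVE's `plot_cumulative` code). With `density=True`, a histogram returns probability densities, whose cumulative sum depends on the bin width; multiplying by the bin width turns them into per-bin probabilities that add up to one.

```python
import numpy as np

rng = np.random.default_rng(3)
values = rng.normal(size=1000)

# density=True returns a probability density: the bar heights alone do not sum to 1
density, edges = np.histogram(values, bins=50, density=True)
bin_width = np.diff(edges)

# Densities times bin widths are per-bin probabilities, whose
# cumulative sum is a proper cumulative distribution ending at 1
cumulative = np.cumsum(density * bin_width)
print(round(cumulative[-1], 6))  # -> 1.0
```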


In [None]:
if INSTALL:
    ! pip install -e /home/jupyter/.local/bin/MOVE/

### Import config files

Config files for this project are stored in the data section of the workspace, where new config files can also be created and uploaded. Running the following cell copies everything from the workspace bucket into the current working directory.

In [None]:
#Copy multiple files in workspace data to the cloud environment
!gsutil -m cp -r $bucket/* . 

## Encode the data

The first step in order to run MOVE is to encode the data in a format that the model can understand.

To encode the data we store all datasets in a TSV format. Each table needs to have a shape `N` &times; `M`, i.e. `N` rows and `M` columns where `N` is the number of samples/individuals and `M` is the number of features.

> 📡 **How is data encoded?**
>
> **_Categorical data is one-hot encoded._** _For a feature like cellType, which takes discrete values/categories (e.g., sc or vc), each category is assigned its own position in a binary vector, and a sample is represented by a vector of zeros with a single one at the position of its category (e.g. sc → [1, 0], vc → [0, 1])._
>
>_A useful property of one-hot vectors is that they impose no hierarchy; they are incompatible with "<" or ">" operators. So (in our example), sc would not be considered larger or smaller than vc._
>
> **_Continuous data can be z-score normalized_**, _meaning that each feature can be rescaled to have zero mean and unit variance:_
>
>  $$ Z = \frac{x-\mu}{\sigma} $$
> where $x$ is the vector of feature values for all samples, and $\mu$ and $\sigma$ are its mean and standard deviation, respectively.
> In this project we work either with the already normalized LP features and the DESeq-normalized gene expression values (after a log2(x+1) transformation), or with the matrices residualized for sex, age and batch. These steps were performed while preprocessing the data, before feeding the datasets to MOVE, so the log2 transform and scale flags in the config files are set to false.
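
Both encodings can be illustrated with NumPy on made-up values (a sketch, not MOVE's own encoding code). Note that sklearn's `scale`, used in the preprocessing cells, divides by the population standard deviation (`ddof=0`), whereas pandas' `.std()` defaults to the sample standard deviation (`ddof=1`).

```python
import numpy as np

# One-hot encoding: every category gets its own position in a binary vector
cell_type = np.array(["sc", "vc", "vc", "sc"])
categories = np.unique(cell_type)  # ['sc', 'vc']
one_hot = (cell_type[:, None] == categories[None, :]).astype(np.float32)
print(one_hot.tolist())  # -> [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]

# Adding a feature axis gives a (samples, features, categories) array
print(one_hot[:, None, :].shape)  # -> (4, 1, 2)

# Z-score normalisation of a continuous feature: Z = (x - mu) / sigma
x = np.array([2.0, 4.0, 4.0, 6.0])
z = (x - x.mean()) / x.std()  # population std (ddof=0), as in sklearn's scale
print(z.tolist())
```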


The first step is to read the configuration called `AMSC` and specify the pre-defined task called `encode_data`.

⚠️ Remember that the notebook takes user-defined configs in a `config/data` directory located in the current working directory.


In [None]:
from omegaconf import OmegaConf
from move.data import io


config =io.read_config("AMSC", None, 
                        f"data.raw_data_path={DATA_FOLDERS[0]}", 
                        f"data.interim_data_path={DATA_FOLDERS[1]}")

# Print data config file
print(OmegaConf.to_yaml(config, resolve=True))

In [None]:
! move-dl data=AMSC task=encode_data data.raw_data_path={DATA_FOLDERS[0]} data.interim_data_path={DATA_FOLDERS[1]}

Data will be encoded accordingly and saved to the directory defined as `interim_data_path` in the `data` configuration. We can now check what the folder structure looks like:
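
A notebook cell such as `! ls -R {DATA_FOLDERS[1]}` is enough to inspect the folder. As a portable alternative, the small helper below (hypothetical names `tree_lines` and `show_tree`) prints a folder as an indented tree, e.g. `show_tree(DATA_FOLDERS[1])`. It makes no assumptions about which files MOVE writes there.

```python
from pathlib import Path

def tree_lines(root, indent=""):
    """Return the files and sub-folders under `root` as indented lines (folders end with '/')."""
    lines = []
    for path in sorted(Path(root).iterdir()):
        lines.append(f"{indent}{path.name}{'/' if path.is_dir() else ''}")
        if path.is_dir():
            lines.extend(tree_lines(path, indent + "    "))
    return lines

def show_tree(root):
    """Print the folder structure under `root`."""
    print("\n".join(tree_lines(root)))
```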



We can inspect the encoded data by loading it:

In [None]:
path = Path(DATA_FOLDERS[1])

cat_datasets, cat_names, con_datasets, con_names = io.load_preprocessed_data(path, config.data.categorical_names, config.data.continuous_names)

The shape of the encoded datasets can now be checked.

> 🔺🟡🟩🔷 **What is the shape of a dataset?**
>
> If we visualize each dataset as a matrix of values, the shape of the matrix is the number of entries along each of its dimensions. As an example, the dataset `PRS` is described by a matrix with 2 dimensions. It contains 269 samples (in the rows, the first dimension) for which 17 PRS scores were quantified (17 features in columns, the second dimension). The shape of the matrix is indicated in this case as (269, 17). For categorical datasets, there is an additional dimension holding the number of categories/classes. So a dataset like `cellType` has the shape (269, 1, 2): 269 samples, 1 feature (cellType) and 2 categories (sc/vc).

In [None]:
dataset_names = config.data.categorical_names + config.data.continuous_names

for dataset, dataset_name in zip(cat_datasets + con_datasets, dataset_names):
    print(f"{dataset_name}: {dataset.shape}")

> _**Note:** Cells and Cyto do not show the same shapes as when we created the datasets. This is because there were some columns full of zeros that were not filtered out originally. These are the true shapes we will work with._

__Intermission: ML basics__

In the following sections, we will begin using machine learning (ML) terms. If this is your first encounter with ML, read the following snippet and peruse this [ML glossary](https://developers.google.com/machine-learning/glossary) if needed.

> 🦾 **ML 101: Training the model**
>
>_The main goal when training a machine learning algorithm is to minimize the difference between the model's output and the ground truth output we would like to achieve. For autoencoders, the output is the same set of values as we had in the input._
>
> _To do that, we feed the model a set of input samples, add up the errors in the outputs of the model in its current state (the loss), and update the strengths of the connections between the nodes of the network (the weights) so that we get closer to the desired output. This is how the network is trained and "learns" how to perform the task._
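
The training loop described above can be sketched with a tiny linear autoencoder in NumPy: synthetic data, a mean squared error loss and plain gradient descent. This only illustrates the principle; MOVE itself trains a VAE with PyTorch and a different loss. Each iteration computes the reconstruction error, its gradients with respect to both weight matrices, and a small update against the gradient.

```python
import numpy as np

rng = np.random.default_rng(4)
# Synthetic data: 5 features driven by 2 hidden factors plus a little noise
factors = rng.normal(size=(200, 2))
X = factors @ rng.normal(size=(2, 5)) + 0.05 * rng.normal(size=(200, 5))

W_enc = rng.normal(scale=0.1, size=(5, 2))  # weights: input -> latent
W_dec = rng.normal(scale=0.1, size=(2, 5))  # weights: latent -> output
learning_rate = 0.05

def loss_fn(W_enc, W_dec):
    """Mean squared error between the input and its reconstruction."""
    return np.mean((X @ W_enc @ W_dec - X) ** 2)

initial_loss = loss_fn(W_enc, W_dec)
for epoch in range(1000):
    Z = X @ W_enc           # forward pass: compress into the latent space
    error = Z @ W_dec - X   # reconstruction error of every output
    # Gradients of the mean squared error with respect to both weight matrices
    grad_dec = 2 * Z.T @ error / error.size
    grad_enc = 2 * X.T @ (error @ W_dec.T) / error.size
    # Update step: move the weights a little against the gradient
    W_dec -= learning_rate * grad_dec
    W_enc -= learning_rate * grad_enc
final_loss = loss_fn(W_enc, W_dec)
print(f"loss: {initial_loss:.3f} -> {final_loss:.3f}")
```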

## Hyperparameter optimization

This section goes through the hyperparameter tuning of MOVE.


> 🔧 **What are hyperparameters?**
>
> Hyperparameters are the variables that either determine the network's structure/architecture (e.g. number of nodes in a hidden layer) or define how the network is trained (e.g. the learning rate).

The process of selecting an optimal set of hyperparameters is called tuning. In MOVE, we define optimal as the settings that produce models that generate the most accurate reconstructions and/or the most stable latent representations.

We will first focus on the **reconstruction accuracy**. The reconstruction accuracy measures how well the model is able to reconstruct the data after it has been compressed into a latent space. It ranges from zero to one, where one represents a perfect (lossless) reconstruction.
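
To make the metric more concrete, here is a sketch of two per-sample reconstruction scores: the fraction of correctly reconstructed categories for categorical data, and the cosine similarity between input and reconstruction for continuous data. We believe these correspond to what MOVE reports, but the functions below (`categorical_accuracy`, `cosine_similarity`) are our own illustration, not MOVE's code.

```python
import numpy as np

def categorical_accuracy(true_labels, predicted_probs):
    """Per sample, the fraction of features whose most likely reconstructed category is correct.

    true_labels: (samples, features) integer categories
    predicted_probs: (samples, features, categories) decoder outputs
    """
    return (predicted_probs.argmax(axis=-1) == true_labels).mean(axis=1)

def cosine_similarity(x, x_hat):
    """Cosine similarity between each sample's input vector and its reconstruction."""
    numerator = np.sum(x * x_hat, axis=1)
    denominator = np.linalg.norm(x, axis=1) * np.linalg.norm(x_hat, axis=1)
    return numerator / denominator

x = np.array([[1.0, 2.0, 3.0], [1.0, 0.0, -1.0]])
print(cosine_similarity(x, x))        # perfect reconstruction: 1 for every sample
print(cosine_similarity(x, 0.5 * x))  # rescaled reconstruction: still 1

labels = np.array([[0, 1], [1, 1]])
probs = np.array([[[0.9, 0.1], [0.2, 0.8]],
                  [[0.7, 0.3], [0.4, 0.6]]])
print(categorical_accuracy(labels, probs))  # second sample gets one of two features right
```

Cosine similarity ignores the overall scale of a reconstruction, so a value of one means the reconstructed vector points in the same direction as the input, not that every value matches exactly.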

To illustrate how tuning works, we will assess how the number of latent and hidden nodes, among other hyperparameters, influences the model's reconstruction accuracy. Running the following cell starts the training of a set of models, one for each combination of the hyperparameters specified in the config file ```AMSC_tune_reconstruction.yaml```.
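
Conceptually, the tuning experiment trains one model per element of the Cartesian product of the hyperparameter values listed in the config. A sketch with `itertools.product`, using only the numbers of hidden and latent nodes compared in the plots of this section (the actual grid is defined in the config file and may also vary e.g. batch size, beta and number of epochs):

```python
from itertools import product

# Example values, taken from the hidden/latent sizes compared in this section
num_hidden = [500, 1000, 2000]
num_latent = [50, 100, 200]

# One dictionary of settings per model to train
grid = [{"num_hidden": h, "num_latent": l} for h, l in product(num_hidden, num_latent)]
print(len(grid))  # -> 9
```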



In [None]:
if HYPER_TUNNING:
    ! move-dl experiment=AMSC__tune_reconstruction data.raw_data_path={DATA_FOLDERS[0]} data.interim_data_path={DATA_FOLDERS[1]}

The output of the previous command is a TSV table called ```reconstruction_stats.tsv```, recording the metrics of each run.

In [None]:
#Load the results in a pandas DataFrame
if HYPER_TUNNING:
    results = pd.read_csv(Path("./results/tune_model/reconstruction_stats.tsv"), sep="\t")
    # If a hyperparameter combination was run more than once, keep only its latest entry
    results = results.drop_duplicates(subset=['task.batch_size',
                                              'task.model.beta',
                                              'task.model.num_hidden',
                                              'task.model.num_latent',
                                              'task.training_loop.num_epochs',
                                              'metric',
                                              'dataset',
                                              'split'],
                                      keep='last')

⤵️ _The following code plots a set of boxplots showing the reconstruction accuracy of MOVE when trained with different hyperparameter sets. To visualize the reconstruction accuracies for other hyperparameters (e.g. number of hidden layers or number of nodes per layer), the code must be modified accordingly._

In [None]:
if HYPER_TUNNING:    
    
    datasets = ['Cyto','Cells','PRS','RNA', 'Day']
    # Hyperparameters that we want to compare:
    num_hidden = [500, 1000, 2000]
    num_latent = [50, 100, 200]

    # Hyperparameters that we keep fixed:
    batch_size = 10
    n_epochs = 300
    beta = .0001

    fig, axs = plt.subplots(len(num_hidden),len(datasets), layout="constrained", figsize=(12, 9))
    fig.suptitle(f"beta = {beta}, n_epochs = {n_epochs}, batch_size={batch_size}")

    n_fields = len(num_latent) # Number of inner fields inside each subplot

    for h,hidden in enumerate(num_hidden): # y axis shows what happens when varying the number of hidden nodes
        # Select the runs with the fixed hyperparameters and the current number of hidden nodes
        subset_conditions = f"(`task.training_loop.num_epochs` == {n_epochs}) \
        & (`task.model.beta` == {beta}) \
        & (`task.model.num_hidden` == '[{hidden}]') \
        & (`task.batch_size` == {batch_size})"

        subset_results = results.query(subset_conditions)

        for d,dataset in enumerate(datasets): # Each subplot in the x axis shows a different dataset

            # Sort by latent size so that the boxes line up with the x tick labels
            results_dataset = subset_results.query(f"dataset == '{dataset}'").sort_values('task.model.num_latent')
            test = results_dataset.query("split == 'test'").to_dict(orient="records")
            train = results_dataset.query("split == 'train'").to_dict(orient="records")
            
            # matplotlib complains if fliers are unset
            for bxp_stats in chain(train, test):
                bxp_stats["fliers"] = []
                

            coll1 = axs[h,d].bxp(train, positions=[*range(0, n_fields * 2, 2)], boxprops=dict(facecolor="#7570b3"), patch_artist=True)
            coll2 = axs[h,d].bxp(test, positions=[*range(1, n_fields * 2, 2)], boxprops=dict(facecolor="#1b9e77"), patch_artist=True)


            axs[h,d].set(xticks=np.arange(0.5, n_fields * 2, 2),
                         xticklabels= num_latent,
                         ylim=(0, 1),
                         xlabel="# latent",
                         ylabel=f"{hidden} hidden nodes",
                         title=f"{dataset} features")
            
        if h == 1:
            axs[0,0].legend([coll1["boxes"][0], coll2["boxes"][0]], ["train", "test"], title="split")
        
        elif d != 0:
            axs[h,d].set_ylabel("")
    plt.tight_layout()
    fig.savefig(Path(FIGURE_FOLDER) / f"beta_{beta}_n_epochs_{n_epochs}_unphased.png", dpi=200)

We obtain a total of fifteen subplots (three hidden-layer sizes × five datasets).

Each column shows the reconstructions of a different dataset, and each row shows the effect of changing the number of hidden nodes on the reconstructions.

During tuning, we split the dataset into a **training** and a **test set**. The former is the data the model learns from, whereas the latter is a subset of "new" data that the model never sees during training. Each subplot has six boxes: purple boxes show training performance and green boxes show test performance. Along the x-axis of each subplot, the number of latent nodes increases.
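Each box is drawn from precomputed summary statistics rather than from raw values. As a minimal sketch (assuming, as the plotting code does, that each row of `results` already holds these statistics), the hypothetical helpers below show the dictionary keys that matplotlib's `Axes.bxp` reads and the interleaved train/test positions used above. `bxp_stats` and `box_layout` are not part of MOVE, and the whiskers are simplified to the data range.

```python
import statistics

def bxp_stats(values, label=None):
    """Summary statistics in the dict format accepted by matplotlib's Axes.bxp.

    Simplification: whiskers span the full data range, so no outliers are drawn.
    """
    q1, med, q3 = statistics.quantiles(values, n=4, method="inclusive")
    return {
        "label": label,
        "med": med,
        "q1": q1,
        "q3": q3,
        "whislo": min(values),
        "whishi": max(values),
        "fliers": [],  # bxp expects this key, even when it is empty
    }

def box_layout(n_fields):
    """Train boxes at even positions, test boxes at odd ones, ticks between each pair."""
    train_pos = list(range(0, n_fields * 2, 2))
    test_pos = list(range(1, n_fields * 2, 2))
    ticks = [p + 0.5 for p in train_pos]
    return train_pos, test_pos, ticks

stats = bxp_stats([0.2, 0.4, 0.6, 0.8, 1.0], label="train")
print(box_layout(3))  # ([0, 2, 4], [1, 3, 5], [0.5, 2.5, 4.5])
```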

__Hyperparameter tuning insights:__


From the plots above, we conclude that, in such a low-regularization regime (beta = 0.0001), the choice of hyperparameters has little effect on reconstruction accuracy.




## Latent space analysis 

This section trains MOVE to integrate the data into a latent space. We will then plot the results and find the important variables for the integration using SHAP analysis.

> ℹ️ About SHAP analysis.
>
> There are many ways to identify the most important features in the data, or the set of features that the model will use the most when encoding the data into a compressed/latent representation. One of them is SHAP (SHapley Additive exPlanations) analysis.
>
> This method measures how much samples move in latent space when one input variable at a time is removed. If the model gives a lot of importance to an input variable, e.g. the concentration of a metabolite, removing it from the input will shift the samples significantly in latent space (i.e., a wide band in the SHAP plots, large impact on the latent space). Conversely, if an input variable is not really needed, the model will "ignore" it, and the latent representation of the samples will barely change when that feature is removed (impact on the latent space close to 0).
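To make the idea concrete, here is a minimal sketch of the "remove one variable and measure the latent shift" logic described above, using a toy linear encoder in place of MOVE's trained VAE. This is a simple ablation score, not an exact Shapley-value computation; `encode`, `ablation_importance` and the choice of 0 as the value of a "removed" feature are illustrative assumptions.

```python
import math

def encode(x, weights):
    """Toy linear encoder: latent_k = sum_j weights[k][j] * x[j]."""
    return [sum(w * xi for w, xi in zip(row, x)) for row in weights]

def ablation_importance(samples, weights, baseline=0.0):
    """Mean Euclidean latent shift when each feature is set to `baseline`."""
    n_features = len(samples[0])
    scores = []
    for j in range(n_features):
        shifts = []
        for x in samples:
            z = encode(x, weights)
            x_ablated = list(x)
            x_ablated[j] = baseline  # "remove" feature j
            shifts.append(math.dist(z, encode(x_ablated, weights)))
        scores.append(sum(shifts) / len(shifts))
    return scores

weights = [[1.0, 0.0, 0.5],
           [0.0, 0.0, 2.0]]  # the encoder ignores feature 1
samples = [[1.0, 3.0, 1.0], [2.0, -1.0, 0.0]]
scores = ablation_importance(samples, weights)
print(scores)  # feature 1 scores 0.0 because the encoder ignores it
```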



As in previous examples, we first read the configuration files and then run the latent-space task (`AMSC__latent_Simon` in the code below). You can have a look at its settings in ```config/task/AMSC__latent_Simon.yaml```.


>_📖 For the advanced reader:_ _This config file also explicitly sets the learning rate to 1e-4 and introduces the factor beta (which weights the Kullback-Leibler divergence term and thus controls how strongly the latent space is regularized) and the Kullback-Leibler divergence warmup steps, which gradually introduce the loss term that pushes the samples towards the center of the latent space._
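For reference, the sketch below shows how a beta-weighted KL term with a warmup enters a VAE loss. It uses the closed-form KL divergence between a diagonal Gaussian posterior and a standard normal prior; the linear warmup in `kl_weight` is a hypothetical simplification, and MOVE's own schedule (set through its warmup steps) may ramp the term up differently.

```python
import math

def kl_diag_gaussian(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)) for a single sample."""
    return 0.5 * sum(m * m + math.exp(lv) - 1 - lv for m, lv in zip(mu, logvar))

def kl_weight(epoch, warmup_epochs):
    """Hypothetical linear warmup: the weight grows from 0 to 1 over `warmup_epochs`."""
    if warmup_epochs <= 0:
        return 1.0
    return min(1.0, epoch / warmup_epochs)

def vae_loss(recon_loss, mu, logvar, beta, epoch, warmup_epochs):
    """Reconstruction loss plus the beta-weighted, warmed-up KL term."""
    return recon_loss + beta * kl_weight(epoch, warmup_epochs) * kl_diag_gaussian(mu, logvar)

# A latent code at the origin with unit variance costs nothing; moving away from it does.
print(kl_diag_gaussian([0.0, 0.0], [0.0, 0.0]))  # 0.0
print(vae_loss(1.0, [1.0], [0.0], beta=1e-4, epoch=0, warmup_epochs=10))  # 1.0 (KL term still off)
```

A small beta (such as the 1e-4 used during tuning) keeps the KL penalty weak relative to the reconstruction loss, which is why reconstructions change little across hyperparameters in that regime.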


**Example latent space analysis:**

a) *In command-line style:*
>```python
>! move-dl data=AMSC task=AMSC__latent_Simon data.raw_data_path={DATA_FOLDERS[0]} data.interim_data_path={DATA_FOLDERS[1]}
>```

b) *Importing the functions themselves:*
>```python
>import warnings
>from move.core import set_global_seed
>from move.data import io
>from move.tasks import analyze_latent
>
>warnings.filterwarnings("ignore")  # Ignore plotting warnings
>config = io.read_config("AMSC", "AMSC__latent_Simon",
>                        f"data.raw_data_path={DATA_FOLDERS[0]}",
>                        f"data.interim_data_path={DATA_FOLDERS[1]}",
>                        f"task.multiprocess={MULTIPROCESS}")
>
>set_global_seed(151)  # set seed to get same results, change to get slightly different plots
>analyze_latent(config)
>```

### Loop over hyperparameters to increase the reliability of the results obtained from SHAP analysis

SHAP analysis is deeply influenced by the nature of the data, but more importantly it tells us what is important for a model with a given set of weights. There are many possible weight sets that can lead the model to similar losses, and also different sets of hyperparameters might lead to different outcomes. Hence, we trained 24 different models, with different sets of hyperparameters, to compensate for this uncertainty. The idea was to identify what variables were deemed important or were ignored regardless of specific training runs or hyperparameter choices. 

This is the list of hyperparameter combinations that we will use:

In [None]:
hyper_list = [(lat,hid,bet) for bet in [.01,.0001] for hid in [[200], [500], [1000], [2000]] for lat in [10,50,100]]
for i,(lat,hid,bet) in enumerate(hyper_list):
    print(i,(lat,hid,bet))
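One simple way to summarize many runs, sketched below under the assumption that each run yields a list of features ordered by SHAP importance, is to count in how many runs each feature reaches the top k. The hypothetical helper `top_k_frequency` and the toy run lists are placeholders for illustration, not results from this analysis.

```python
from collections import Counter

def top_k_frequency(rankings, k):
    """Fraction of runs in which each feature appears among the top-k features.

    `rankings` holds one list per trained model, ordered from most to least important.
    """
    counts = Counter()
    for ranking in rankings:
        counts.update(set(ranking[:k]))  # count each feature at most once per run
    return {feature: n / len(rankings) for feature, n in counts.most_common()}

# Toy example with three runs and placeholder feature names
runs = [
    ["featA", "featB", "featC", "featD"],
    ["featB", "featA", "featD", "featC"],
    ["featA", "featC", "featB", "featD"],
]
print(top_k_frequency(runs, k=2))
```

Features with a frequency close to 1 are deemed important regardless of the specific run or hyperparameter set, which is the robustness criterion described above.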


> Note: Running the feature importance analysis on this dataset in Terra requires a considerable amount of RAM (ca. 30 GB). I ran it with 8 CPUs and 52 GB of RAM. 200-2000-.01 was run at the end (and skipped originally,

In [None]:
import os
import warnings

warnings.filterwarnings("ignore") # Ignore plotting warnings

# Hyperparameter sets to explore (24 combinations of latent size, hidden layer size and beta):
hyper_list = [(lat,hid,bet) for bet in [.01,.0001] for hid in [[200], [500], [1000], [2000]] for lat in [10,50,100]]
! mkdir -p results_temp

# Loop over hyperparameters:
for i,(lat,hid,bet) in enumerate(hyper_list):
    foldername = f"latent_{lat}_{hid}_{bet}"
    # Skip combinations whose results were already saved (allows resuming interrupted runs)
    if foldername not in os.listdir("./results_temp/"):
        !mkdir -p results_temp/{foldername}
        # Remove the cached model so that a new one is trained with the current hyperparameters
        ! rm -f ./results/latent_space/model.pt
        ! move-dl data=AMSC task=AMSC__latent_Simon data.raw_data_path={DATA_FOLDERS[0]} data.interim_data_path={DATA_FOLDERS[1]} task.model.beta={bet} task.model.num_latent={lat} task.model.num_hidden={hid}
        # Copy the results to the corresponding folder
        !cp -r results/latent_space/ results_temp/{foldername}
        print(f"Finished {foldername}")

    else:
        print(f"{foldername} was already created")

    

### Visualization

The following code reads a file with feature names and creates an alluvial (Sankey) plot. We will use it to show how the features that appear in the SHAP analyses are distributed across compartments, channels and measurement types.

In [None]:
# Install the package
! pip install pyalluvial

**Create an alluvial plot:**

This code is inspired by the scripts Felipe Dos Santos used to create Sankey plots for the Lipocyte Profiler paper.

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import re
import pyalluvial.alluvial as alluvial
from pathlib import Path

# Define your file path and load data
file = "SHAP_LP_new.txt"
df = pd.read_csv(Path("./") / file, sep='\t', header=None) # Get only column 2
feature_names = list(df.values[:, 1].flatten())

def plot_alluvial(features):
    # Create a DataFrame from the list of features
    all_data = pd.DataFrame({'features': features})

    # Define functions to categorize features
    def categorize_compartment(feature):
        if re.search(r'^Cells_', feature):
            return 'Cells'
        elif re.search(r'^Cytoplasm_', feature):
            return 'Cytoplasm'
        elif re.search(r'Nuclei', feature):
            return 'Nuclei'
        else:
            return 'Other'

    def categorize_feature_channel(feature):
        if re.search(r'BODIPY', feature) and not re.search(r'AGP|Mito|DNA', feature):
            return 'Lipid'
        elif re.search(r'BODIPY', feature) and re.search(r'AGP|Mito', feature):
            return 'Combination'
        elif re.search(r'Mito', feature) and not re.search(r'AGP|BODIPY|DNA', feature):
            return 'Mito'
        elif re.search(r'Mito', feature) and re.search(r'AGP|DNA', feature):
            return 'Combination'
        elif re.search(r'AGP', feature) and not re.search(r'BODIPY|Mito|DNA', feature):
            return 'AGP'
        elif re.search(r'AGP', feature) and re.search(r'DNA', feature):
            return 'Combination'
        elif re.search(r'DNA', feature) and not re.search(r'AGP|BODIPY|Mito', feature):
            return 'DNA'
        else:
            return 'Compartimental'

    def categorize_feature_measure(feature):
        if re.search(r'Texture', feature):
            return 'Texture'
        elif re.search(r'Intensity', feature):
            return 'Intensity'
        elif re.search(r'Granularity', feature):
            return 'Granularity'
        elif re.search(r'Correlation|Location|Angle|Distance|Touching|RadialDistribution', feature):
            return 'Position'
        else:
            return 'Shape/Size/Count'

    # Apply categorization
    all_data['compartment'] = all_data['features'].apply(categorize_compartment)
    all_data['feature_channel'] = all_data['features'].apply(categorize_feature_channel)
    all_data['feature_measure'] = all_data['features'].apply(categorize_feature_measure)

    df = all_data[['compartment', 'feature_channel', 'feature_measure']]

    # Group by specific columns and count unique rows
    grouped = df.groupby(['compartment', 'feature_channel', 'feature_measure']).size().reset_index(name='Count')

    # Create a DataFrame of unique rows
    unique_rows = df.drop_duplicates(subset=['compartment', 'feature_channel', 'feature_measure'])

    # Merge the unique rows DataFrame with the count DataFrame
    df = unique_rows.merge(grouped, on=['compartment', 'feature_channel', 'feature_measure'], how='left')

    # Plot the alluvial diagram
    fig = alluvial.plot(df=df, xaxis_names=['compartment', 'feature_channel', 'feature_measure'], 
                        y_name='Count', alluvium='compartment', cmap_name='Pastel1', figsize=(12, 10))

    # Calculate the count sums
    compartment_sum = df.groupby('compartment')['Count'].sum()
    feature_channel_sum = df.groupby('feature_channel')['Count'].sum()
    feature_measure_sum = df.groupby('feature_measure')['Count'].sum()

    # Display the results in the console (optional)
    print("Compartment Sums:\n", compartment_sum)
    print("\nFeature Channel Sums:\n", feature_channel_sum)
    print("\nFeature Measure Sums:\n", feature_measure_sum)

    # Add annotations with count sums
    def add_annotations(ax, sum_series, position_dict, color_dict):
        for label, value in sum_series.items():
            pos = position_dict[label]
            color = color_dict[label]
            ax.text(pos[0], pos[1], f'{value}', ha='center', va='top', fontsize=12, color=color)

    # Define the color mapping
    color_mapping = {
        'AGP': '#f2cf41',
        'Lipid': '#4dac26',
        'DNA': '#65a6db',
        'Mito': '#f56464',
        'Combination': '#8d55c6',
        'Compartimental': '#959ca3',
        'Granularity': '#808080',
        'Intensity': '#B8BAC5',
        'Texture': '#383F59',
        'Correlation': '#656B7D',
        'Position': '#9AAFB5',
        'Shape/Size/Count': '#708090',
        'Cells': 'blue',
        'Cytoplasm':'green'
    }

    # Mapping of positions where the text should be placed
    compartment_positions = {
        'Cells': [1, 0.2],
        'Cytoplasm': [1, 0.1],
        'Nuclei': [1, 0],
        'Other': [1, -0.1]
    }
    feature_channel_positions = {
        'Lipid': [2, 0.2],
        'Mito': [2, 0.1],
        'DNA': [2, 0],
        'AGP': [2, -0.1],
        'Combination': [2, -0.2],
        'Compartimental': [2, -0.3]
    }
    feature_measure_positions = {
        'Texture': [3, 0.2],
        'Intensity': [3, 0.1],
        'Granularity': [3, 0],
        'Position': [3, -0.1],
        'Shape/Size/Count': [3, -0.2]
    }

    ax = fig.gca()
    plt.axis('off')
    # Add the count sums as annotations
    #add_annotations(ax, compartment_sum, compartment_positions, color_mapping)
    #add_annotations(ax, feature_channel_sum, feature_channel_positions, color_mapping)
    #add_annotations(ax, feature_measure_sum, feature_measure_positions, color_mapping)

    # Save the figure
    fig.savefig(Path(FIGURE_FOLDER) / "Alluvial.png", dpi=200)

    # Prepare data for the return
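    # Prepare data for the plotly Sankey diagram; keep a stable node order
    node_names = list(dict.fromkeys(df[['compartment', 'feature_channel', 'feature_measure']].values.flatten()))
    node_dict = {name: i for i, name in enumerate(node_names)}
    source, target, value, edge_colors = [], [], [], []

    for line in df.values:
        # Links: compartment -> channel and channel -> measure, weighted by Count
        for i in range(2):
            source.append(node_dict[line[i]])
            target.append(node_dict[line[i + 1]])
            value.append(line[3])
            # Colour each link by its compartment; grey for compartments without a colour
            edge_colors.append(color_mapping.get(line[0], '#cccccc'))

    # Node colours must follow the same order as the node labels
    node_colors = [color_mapping.get(name, '#cccccc') for name in node_names]
    return node_names, source, target, value, node_colors, edge_colors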

# Calling the function with your feature names
label, source, target, value, color, edge_colors = plot_alluvial(feature_names)


In [None]:
import plotly.graph_objects as go

fig = go.Figure(data=[go.Sankey(
    node = dict(
      pad = 15,
      thickness = 20,
      line = dict(color = "black", width = 0.5),
      label = label,
      color = color
    ),
    link = dict(
      source = source, # indices into `label`
      target = target,
      value = value,
      color = edge_colors
  ))])

fig.update_layout(title_text="Basic Sankey Diagram", font_size=10)
fig.show()

In [None]:
# Summary on PRS ranks
PRS_summary = pd.read_csv("./SHAP_PRS.txt", sep="\t", header=None)
PRS_summary.columns = ["Rank", "PRS"]

PRS_summary.groupby(by=["PRS"]).sum().sort_values(by=["Rank"])

__Latent space analysis insights:__

- The model organizes samples in different separable regions

⚠️  If you get an error similar to:

```
RuntimeError: Error(s) in loading state_dict for VAE:
	size mismatch for encoderlayers.0.weight: copying a param with shape torch.Size([720, 1342]) from checkpoint, the shape in current model is torch.Size([900, 1342]).
```

the saved model was trained with a different architecture (here, a different number of hidden nodes) than the one in the current configuration. Erase the model.pt file in ```results/latent_space```. This can be done by setting REMOVE_LATENT_MODEL = True and running the second cell above.
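Before deleting the checkpoint, it can help to see which layers disagree. The hypothetical helper below compares two mappings of parameter names to shapes; with PyTorch, such mappings could be built as `{k: tuple(v.shape) for k, v in state_dict.items()}` from `torch.load(path, map_location="cpu")` and from `model.state_dict()`. The example shapes are taken from the error message above.

```python
def find_shape_mismatches(checkpoint_shapes, model_shapes):
    """Return {name: (checkpoint_shape, model_shape)} for parameters whose shapes differ.

    Parameters present in only one mapping are reported with None on the other side.
    """
    mismatches = {}
    for name in sorted(set(checkpoint_shapes) | set(model_shapes)):
        ckpt = checkpoint_shapes.get(name)
        current = model_shapes.get(name)
        if ckpt != current:
            mismatches[name] = (ckpt, current)
    return mismatches

# Shapes from the error message above
checkpoint = {"encoderlayers.0.weight": (720, 1342)}
current_model = {"encoderlayers.0.weight": (900, 1342)}
print(find_shape_mismatches(checkpoint, current_model))
# {'encoderlayers.0.weight': ((720, 1342), (900, 1342))}
```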

### Copy results to workspace bucket

We can now copy the obtained results back to the workspace bucket. The new files will appear in the data tab inside the files folder.

In [None]:
# Save the results in the workspace bucket
#! gsutil cp -r ./results/ {bucket}


## Identifying associations between features

Lastly, we will use what the model has learnt to identify entangled or associated variables.

When perturbing the input value of a certain feature for all samples, the latent representation of said samples will change, and so will their reconstructions. We can track the induced shifts in output values to identify the features that were affected the most when perturbing the original feature, at a cohort level.
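The perturbation idea can be sketched with a toy linear model standing in for the trained VAE: shift one input feature by one standard deviation for every sample (similar in spirit to the `plus_std` target value used below), reconstruct, and average the change of each output feature across the cohort. `reconstruct` and its weight matrix are illustrative assumptions, not MOVE's model.

```python
import statistics

def reconstruct(x, weights):
    """Toy linear model: output_i = sum_j weights[i][j] * x[j]."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in weights]

def mean_shift_after_perturbation(samples, weights, feature):
    """Average change of each reconstructed feature after adding one SD to `feature`."""
    sd = statistics.stdev(x[feature] for x in samples)
    totals = [0.0] * len(weights)
    for x in samples:
        baseline = reconstruct(x, weights)
        perturbed_x = list(x)
        perturbed_x[feature] += sd  # "plus_std"-style perturbation
        perturbed = reconstruct(perturbed_x, weights)
        for i, (p, b) in enumerate(zip(perturbed, baseline)):
            totals[i] += p - b
    return [t / len(samples) for t in totals]

# Outputs 0 and 1 depend on input 0; output 2 does not, so only the first two shift.
weights = [[1.0, 0.0],
           [0.5, 1.0],
           [0.0, 1.0]]
samples = [[1.0, 2.0], [3.0, 0.0], [5.0, 1.0]]
print(mean_shift_after_perturbation(samples, weights, feature=0))  # [2.0, 1.0, 0.0]
```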

> ℹ️ We will use the [Bayesian decision theory-based approach](https://www.nature.com/articles/s41587-022-01520-x#Sec15) presented in the Methods section of the original paper. We will also use an approach based on Kolmogorov-Smirnov distances between the feature reconstruction distributions, comparing them before and after perturbing an input feature.
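Both ideas can be sketched in a few lines. `ks_statistic` computes the two-sample Kolmogorov-Smirnov distance (the largest gap between two empirical CDFs), the same quantity `scipy.stats.ks_2samp` reports as its statistic. `bayes_score` is a simplified reading of the Bayesian idea, assuming its input is the per-sample change of one reconstructed feature after a perturbation: it turns the fraction of positive changes into a log-odds score. Both helpers are illustrative; see MOVE's `identify_associations` module for the exact computation, including how significance and false discovery rates are handled.

```python
import math

def ks_statistic(a, b):
    """Two-sample KS distance: max |F_a(x) - F_b(x)| over all observed x."""
    a, b = sorted(a), sorted(b)
    i = j = 0
    d = 0.0
    while i < len(a) and j < len(b):
        x = min(a[i], b[j])
        # Step past every value equal to x in both samples (handles ties)
        while i < len(a) and a[i] == x:
            i += 1
        while j < len(b) and b[j] == x:
            j += 1
        d = max(d, abs(i / len(a) - j / len(b)))
    return d

def bayes_score(changes, eps=1e-8):
    """Log-odds that the perturbation increases the feature, from per-sample changes."""
    prob = sum(c > eps for c in changes) / len(changes)
    return math.log(prob + eps) - math.log(1 - prob + eps)

print(ks_statistic([1, 2, 3, 4], [3, 4, 5, 6]))  # 0.5
print(bayes_score([0.3, -0.1]))  # 0.0 (increase and decrease equally likely)
```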

### Loop over hyperparameter combinations to identify associations between features.

In [None]:
# Identify associations loop for continuous and categorical target datasets:

####################### Hyper parameters #######################
target_ds_list = [(str('Cyto'),"plus_std"), (str("FFA"), 1), (str("cellType"), "sc"), (str('PRS'),"plus_std")]
NUM_REFITS = 24
TASK = "AMSC__id_assoc_bayes"
SIG_THR = 0.05 #(0.05 for bayes)
NUM_EPOCHS = 300
hyper_list = [(lat,hid,bet) for bet in [0.0001] for hid in [[500]] for lat in [50]] # Best performing model in validation

###########################################################
for i,(lat,hid,bet) in enumerate(hyper_list):
    ! pwd
    for TARGET_DS,TARGET_VALUE in target_ds_list:
        ! mkdir -p results_temp_id_assoc_{TARGET_DS}
        foldername = f"id_assoc_{lat}_{hid}_{bet}_{TARGET_DS}_{TASK}"
        
        if foldername not in os.listdir(f"./results_temp_id_assoc_{TARGET_DS}/"):
            !mkdir -p results_temp_id_assoc_{TARGET_DS}/{foldername}

            print(f"Starting {foldername}")
            ! move-dl data=AMSC task={TASK} data.raw_data_path={DATA_FOLDERS[0]} data.interim_data_path={DATA_FOLDERS[1]} task.model.beta={bet} task.model.num_latent={lat} task.model.num_hidden={hid} task.target_dataset={TARGET_DS} task.num_refits={NUM_REFITS} task.target_value={TARGET_VALUE} task.sig_threshold={SIG_THR} task.training_loop.num_epochs={NUM_EPOCHS}

            !cp -r results/identify_associations/ results_temp_id_assoc_{TARGET_DS}/{foldername} 
            print(f"Finished {foldername}")
             
        else:
            print(f"{foldername} was already created")
            
    # Save models:
    #! mkdir -p all_models/models_{lat}_{hid}_{bet}
    #! cp -r {DATA_FOLDERS[1]}models all_models/models_{lat}_{hid}_{bet}
    #! rm -r {DATA_FOLDERS[1]}models/*
            

In [None]:
! ls ./all_models/models_50_[500]_0.0001/models/

In [None]:
# KS approach: identify associations for continuous datasets,
# running once per saved refit of the best-performing model:
target_ds_list = [(str('PRS'),"plus_std"), (str('Cyto'),"plus_std")]
NUM_REFITS = 1
TASK = "AMSC__id_assoc_ks"
SIG_THR = 0.999 #(0.05 for bayes)
NUM_EPOCHS = 300

#hyper_list = [(lat,hid,bet) for bet in [.01, .001] for hid in [[500], [1000]] for lat in [100]]
#hyper_list = [(lat,hid,bet) for bet in [0.001] for hid in [[1000]] for lat in [100]]
hyper_list = [(50, [500], 0.0001)] * 24  # same architecture, one entry per saved refit

for i,(lat,hid,bet) in enumerate(hyper_list):
    ! pwd
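    # Copy refit i into the models folder as refit 0, so that the single refit
    # (NUM_REFITS = 1) reloads this saved model instead of training a new one
    ! rm -r {DATA_FOLDERS[1]}models
    ! mkdir -p {DATA_FOLDERS[1]}models
    ! cp ./all_models/models_50_[500]_0.0001/models/model_{lat}_{i}.pt {DATA_FOLDERS[1]}models/model_{lat}_0.pt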
    
    for TARGET_DS,TARGET_VALUE in target_ds_list:
        ! mkdir -p results_temp_id_assoc_{TARGET_DS}
        foldername = f"id_assoc_{lat}_{hid}_{bet}_{TARGET_DS}_{TASK}_{i}"
        
        if foldername not in os.listdir(f"./results_temp_id_assoc_{TARGET_DS}/"):
            !mkdir -p results_temp_id_assoc_{TARGET_DS}/{foldername}

            print(f"Starting {foldername}")
            ! move-dl data=AMSC task={TASK} data.raw_data_path={DATA_FOLDERS[0]} data.interim_data_path={DATA_FOLDERS[1]} task.model.beta={bet} task.model.num_latent={lat} task.model.num_hidden={hid} task.target_dataset={TARGET_DS} task.num_refits={NUM_REFITS} task.target_value={TARGET_VALUE} task.sig_threshold={SIG_THR} task.training_loop.num_epochs={NUM_EPOCHS}

            !cp -r results/identify_associations/ results_temp_id_assoc_{TARGET_DS}/{foldername} 
            print(f"Finished {foldername}")
             
        else:
            print(f"{foldername} was already created")

In [None]:
import numpy as np
import matplotlib.pyplot as plt

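# Function to compute the two-sample KS critical value for two samples of equal size N
def ks_threshold(N, alpha):
    """Asymptotic KS threshold sqrt((n + m) / (n * m)) * sqrt(-ln(alpha / 2) / 2) with n = m = N."""
    return np.sqrt(-(1/N) * np.log(alpha/2))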

# Define the range for N and alpha
N_values = np.linspace(100, 500, 10000)  # Avoid zero to prevent division by zero
alpha_values = np.linspace(0.001, 1, 1000)  # Avoid zero for alpha to prevent log(0)

# Create a meshgrid for N and alpha
N_grid, alpha_grid = np.meshgrid(N_values, alpha_values)
ks_thrs = ks_threshold(N_grid, alpha_grid)

# Create a 2D heatmap
fig, ax = plt.subplots(figsize=(7,3.5))
c = ax.pcolormesh(N_grid, alpha_grid, ks_thrs, shading='auto', cmap='RdYlBu')

# Add a color bar
fig.colorbar(c, ax=ax, label='KS Threshold')

# Set labels
ax.set_xlabel('Number of Samples N')
ax.set_ylabel('Significance Threshold')
ax.set_title('Heatmap of KS Threshold')
# Set x-axis to logarithmic scale
#ax.set_xscale('log')
fig.tight_layout()
fig.savefig(Path(FIGURE_FOLDER) / 'KS_thr_heatmap.png', dpi=200)
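The formula in `ks_threshold` is the asymptotic two-sample KS critical value for samples of sizes $n$ and $m$, specialized to $n = m = N$ (assuming both reconstruction distributions contain one value per sample):

$$
D_{\text{crit}} = c(\alpha)\sqrt{\frac{n+m}{nm}}, \qquad c(\alpha) = \sqrt{-\tfrac{1}{2}\ln\tfrac{\alpha}{2}}
$$

$$
n = m = N \;\Rightarrow\; \sqrt{\frac{n+m}{nm}} = \sqrt{\frac{2}{N}} \;\Rightarrow\; D_{\text{crit}} = \sqrt{-\tfrac{1}{2}\ln\tfrac{\alpha}{2}}\,\sqrt{\tfrac{2}{N}} = \sqrt{-\tfrac{1}{N}\ln\tfrac{\alpha}{2}}
$$

For example, with $N = 269$ and $\alpha = 0.2$ (the values used further below), $D_{\text{crit}} = \sqrt{\ln(10)/269} \approx 0.0925$. The threshold shrinks as $1/\sqrt{N}$, so larger cohorts make smaller KS distances significant.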


### Creating the final association lists

In [None]:
import os
from pathlib import Path

import numpy as np
import pandas as pd

DATASET_LIST = ["PRS", "Cyto"]

for DATASET in DATASET_LIST:
    id_assoc_multi_run_path = Path(f"./results_temp_id_assoc_{DATASET}/")
    combined_df = []
    #Loop over different runs
    for run_folder in os.scandir(id_assoc_multi_run_path):
        run_idx = run_folder.name.split('_')[-1]
        print(run_idx)
        if os.path.isfile(Path(run_folder) / "identify_associations/results_sig_assoc_ks.tsv"):
            df = pd.read_csv(Path(run_folder) / "identify_associations/results_sig_assoc_ks.tsv", header=0, index_col=0, sep="\t")
            df['run_idx'] = run_idx
            combined_df.append(df)

    # Concatenate the results from different runs and rank them according to abs(KS distance).
    combined_df = pd.concat(combined_df, ignore_index=True).sort_values('ks_distance', key=lambda x: abs(x), ascending=False)
    combined_df.to_csv(id_assoc_multi_run_path / f"{DATASET}_combined_ks_associations.csv", sep="\t", index=False)



In [None]:
! ls ./results_temp_id_assoc_Cyto/id_assoc_50_[500]_0.0001_Cyto_AMSC__id_assoc_bayes/identify_associations/

In [None]:
MARC_BUCKET = "gs://collaborations_marc/shared_2023/results_marc/results_residualized/associations_final_tables/KS/"


! gsutil cp ./results_temp_id_assoc_PRS/PRS_combined_ks_associations.csv {MARC_BUCKET}
#! gsutil -m cp -r ./results_cont_paper_II/ {MARC_BUCKET}
#! gsutil cp -r ./results/* {MARC_BUCKET}

In [None]:
# Create a new column with absolute KS values
combined_df['abs_ks'] = combined_df['ks_distance'].abs()

# Group by 'feature_a_name' and 'feature_b_name', then get index of max abs_ks
idx = combined_df.groupby(['feature_a_name', 'feature_b_name'])['abs_ks'].idxmax()

# Use these indices to select the rows
combined_df = combined_df.loc[idx]

# Reset the index
combined_df = combined_df.reset_index(drop=True)

# If you want to drop the 'abs_ks' column we created
combined_df= combined_df.drop('abs_ks', axis=1)

thr = ks_threshold(N=269, alpha=.2)
print("KS threshold", thr)

combined_df_ks = combined_df[combined_df['ks_distance'].abs() >= thr]


df_bayes = pd.read_csv(f"./results_temp_id_assoc_{DATASET}/id_assoc_50_[500]_0.0001_{DATASET}_AMSC__id_assoc_bayes/identify_associations/results_sig_assoc_bayes.tsv", header=0, index_col=0, sep="\t")



df_bayes

import pandas as pd
import matplotlib.pyplot as plt
from matplotlib_venn import venn2

# Compare the significant (feature_a, feature_b) pairs found by the KS and Bayesian approaches

# Convert DataFrames to sets of tuples for comparison
set_ks = set(combined_df_ks[['feature_a_name', 'feature_b_name']].itertuples(index=False, name=None))
set_bayes = set(df_bayes[['feature_a_name', 'feature_b_name']].itertuples(index=False, name=None))

# Create the Venn diagram
plt.figure(figsize=(10, 6))
venn2([set_ks, set_bayes], set_labels=('KS', 'Bayes'))

plt.title(f"Overlap between KS and Bayes associations ({DATASET})")
plt.show()

# To get the actual numbers:
only_in_ks = len(set_ks - set_bayes)
only_in_bayes = len(set_bayes - set_ks)
in_both = len(set_ks & set_bayes)

print(f"Associations only in KS: {only_in_ks}")
print(f"Associations only in Bayes: {only_in_bayes}")
print(f"Associations in both: {in_both}")

In [None]:
# Calculate the absolute KS values of the combined associations
abs_ks = combined_df['ks_distance'].abs()

# Sort the values
sorted_data = np.sort(abs_ks)

# Calculate the proportional values of y
y = np.arange(1, len(sorted_data)+1) / len(sorted_data)

# Create the plot
plt.figure(figsize=(10, 6))
plt.plot(sorted_data, y, lw=1)
plt.xlabel('Absolute KS Distance')
plt.ylabel('Cumulative Probability')
plt.title('Cumulative Distribution of Absolute KS Values')
plt.grid(True)

# Add a line at y=0.5 for median
plt.axhline(y=0.5, color='r', linestyle='--')

plt.show()

### 3D plots: latent space of 3 dimensions

This cell replicates MOVE's code for identifying associations, but it also stores the latent representations of the samples so that we can plot them afterwards. We set the number of latent nodes to three so that we can directly visualize the latent location of the samples and their movements after the perturbations, i.e. without dimensionality-reduction transformations such as t-SNE or UMAP.
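Once baseline and perturbed latent coordinates are stored (for instance from `model.project(...)` as in the cell below, both with three columns), the movement of each sample reduces to a displacement vector. The hypothetical helper below computes these vectors and their lengths; drawing them with matplotlib's 3D `quiver` is one possible way to plot them, an assumption rather than what this notebook does.

```python
import math

def latent_displacements(baseline, perturbed):
    """Per-sample displacement vectors and their lengths between two latent projections.

    Both inputs are sequences of (x, y, z) coordinates in the same sample order.
    """
    if len(baseline) != len(perturbed):
        raise ValueError("Both projections must contain the same samples")
    vectors = [tuple(p - b for b, p in zip(base, pert)) for base, pert in zip(baseline, perturbed)]
    lengths = [math.hypot(*v) for v in vectors]
    return vectors, lengths

baseline = [(0.0, 0.0, 0.0), (1.0, 1.0, 1.0)]
perturbed = [(1.0, 2.0, 2.0), (1.0, 1.0, 1.0)]
vectors, lengths = latent_displacements(baseline, perturbed)
print(vectors)  # [(1.0, 2.0, 2.0), (0.0, 0.0, 0.0)]
```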

In [None]:
__all__ = ["identify_associations"]

from functools import reduce
from os.path import exists
from pathlib import Path
from typing import Literal, Sized, Union, cast

import hydra
import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from scipy.stats import ks_2samp, pearsonr  # type: ignore
from torch.utils.data import DataLoader

from move.analysis.metrics import get_2nd_order_polynomial

from move.conf.schema import (
    IdentifyAssociationsBayesConfig,
    IdentifyAssociationsConfig,
    IdentifyAssociationsKSConfig,
    IdentifyAssociationsTTestConfig,
    MOVEConfig,
)
from move.core.logging import get_logger
from move.core.typing import BoolArray, FloatArray, IntArray
from move.data import io
from move.data.dataloaders import MOVEDataset, make_dataloader
from move.data.perturbations import (
    ContinuousPerturbationType,
    perturb_categorical_data,
    perturb_continuous_data_extended,
)
from move.data.preprocessing import one_hot_encode_single
from move.models.vae import VAE
from move.visualization.dataset_distributions import (
    plot_correlations,
    plot_cumulative_distributions,
    plot_feature_association_graph,
    plot_reconstruction_movement,
)

TaskType = Literal["bayes", "ttest", "ks"]
CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"]

################################### Hyperparameters ##############################
lat, hid, bet = (3,[1000],.1)
TARGET_DS = "cellType"
TARGET_VALUE= "sc"
config = io.read_config("AMSC",
                        "AMSC__id_assoc_bayes",
                        f"data.raw_data_path={DATA_FOLDERS[0]}",
                        f"data.interim_data_path={DATA_FOLDERS[1]}",
                        f"task.model.beta={bet}",
                        f"task.model.num_latent={lat}",
                        f"task.model.num_hidden={hid}",
                        f"task.num_refits=1",
                        f"task.target_dataset={TARGET_DS}", 
                        f"task.target_value={TARGET_VALUE}"
                       )
####################################################################################

def _get_task_type(
    task_config: IdentifyAssociationsConfig,
) -> TaskType:
    task_type = OmegaConf.get_type(task_config)
    if task_type is IdentifyAssociationsBayesConfig:
        return "bayes"
    if task_type is IdentifyAssociationsTTestConfig:
        return "ttest"
    if task_type is IdentifyAssociationsKSConfig:
        return "ks"
    raise ValueError("Unsupported type of task!")


def _validate_task_config(
    task_config: IdentifyAssociationsConfig, task_type: TaskType
) -> None:
    if not (0.0 <= task_config.sig_threshold <= 1.0):
        raise ValueError("Significance threshold must be within [0, 1].")
    if task_type == "ttest":
        task_config = cast(IdentifyAssociationsTTestConfig, task_config)
        if len(task_config.num_latent) != 4:
            raise ValueError("4 latent space dimensions required.")


def prepare_for_categorical_perturbation(
    config: MOVEConfig,
    interim_path: Path,
    baseline_dataloader: DataLoader,
    cat_list: list[FloatArray],
) -> tuple[list[DataLoader], BoolArray, BoolArray,]:
    """
    This function creates the required dataloaders and masks
    for further categorical association analysis.

    Args:
        config: main configuration file
        interim_path: path where the intermediate outputs are saved
        baseline_dataloader: reference dataloader that will be perturbed
        cat_list: list of arrays with categorical data

    Returns:
        dataloaders: all dataloaders, including baseline appended last.
        nan_mask: mask for NaNs
        feature_mask: masks the column for the perturbed feature.
    """

    # Read original data and create perturbed datasets
    task_config = cast(IdentifyAssociationsConfig, config.task)
    logger = get_logger(__name__)

    # Loading mappings:
    mappings = io.load_mappings(interim_path / "mappings.json")
    
    target_mapping = mappings[task_config.target_dataset]
    target_value = one_hot_encode_single(target_mapping, task_config.target_value)
    logger.debug(
        f"Target value: {task_config.target_value} => {target_value.astype(int)[0]}"
    )
    
    dataloaders = perturb_categorical_data(
        baseline_dataloader,
        config.data.categorical_names,
        task_config.target_dataset,
        target_value,
    )
    
    dataloaders.append(baseline_dataloader)

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    assert baseline_dataset.con_all is not None
    orig_con = baseline_dataset.con_all
    nan_mask = (orig_con == 0).numpy()  # NaN values encoded as 0s
    logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}")

    target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset)
    target_dataset = cat_list[target_dataset_idx]
    feature_mask = np.all(target_dataset == target_value, axis=2)  # 2D: N x P
    feature_mask |= np.sum(target_dataset, axis=2) == 0

    return (
        dataloaders,
        nan_mask,
        feature_mask,
    )


def prepare_for_continuous_perturbation(
    config: MOVEConfig,
    output_subpath: Path,
    baseline_dataloader: DataLoader,
) -> tuple[list[DataLoader], BoolArray, BoolArray,]:
    """
    This function creates the required dataloaders and masks
    for further continuous association analysis.

    Args:
        config:
            main configuration file.
        output_subpath:
            path where the output plots for continuous analysis are saved.
        baseline_dataloader:
            reference dataloader that will be perturbed.

    Returns:
        dataloaders:
            list with all dataloaders, including baseline appended last.
        nan_mask:
            mask for NaNs
        feature_mask:
            same as `nan_mask`, in this case.
    """

    # Read original data and create perturbed datasets
    logger = get_logger(__name__)
    task_config = cast(IdentifyAssociationsConfig, config.task)

    dataloaders = perturb_continuous_data_extended(
        baseline_dataloader,
        config.data.continuous_names,
        task_config.target_dataset,
        cast(ContinuousPerturbationType, task_config.target_value),
        output_subpath,
    )
    dataloaders.append(baseline_dataloader)

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    assert baseline_dataset.con_all is not None
    orig_con = baseline_dataset.con_all
    nan_mask = (orig_con == 0).numpy()  # NaN values encoded as 0s
    logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}")
    feature_mask = nan_mask

    return (dataloaders, nan_mask, feature_mask)


def _bayes_approach(
    config: MOVEConfig,
    task_config: IdentifyAssociationsBayesConfig,
    train_dataloader: DataLoader,
    baseline_dataloader: DataLoader,
    dataloaders: list[DataLoader],
    models_path: Path,
    num_perturbed: int,
    num_samples: int,
    num_continuous: int,
    nan_mask: BoolArray,
    feature_mask: BoolArray,
) -> tuple[Union[IntArray, FloatArray], ...]:

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda == True else "cpu")

    # Train models
    logger = get_logger(__name__)
    logger.info("Training models")
    mean_diff = np.zeros((num_perturbed, num_samples, num_continuous))
    normalizer = 1 / task_config.num_refits

    # Last appended dataloader is the baseline
    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)
    
    for j in range(task_config.num_refits):
        # Initialize model
        model: VAE = hydra.utils.instantiate(
            task_config.model,
            continuous_shapes=baseline_dataset.con_shapes,
            categorical_shapes=baseline_dataset.cat_shapes,
        )
        if j == 0:
            logger.debug(f"Model: {model}")

        # Train/reload model
        model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt"
        if model_path.exists():
            logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}")
            model.load_state_dict(torch.load(model_path))
            model.to(device)
        else:
            logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
            model.to(device)
            hydra.utils.call(
                task_config.training_loop,
                model=model,
                train_dataloader=train_dataloader,
            )
            if task_config.save_refits:
                torch.save(model.state_dict(), model_path)
        model.eval()

        # Calculate baseline reconstruction
        _, baseline_recon = model.reconstruct(baseline_dataloader)
        latent_space_baseline = model.project(baseline_dataloader)
        
        min_feat, max_feat = np.zeros((num_perturbed, num_continuous)), np.zeros(
            (num_perturbed, num_continuous)
        )
        min_baseline, max_baseline = np.min(baseline_recon, axis=0), np.max(
            baseline_recon, axis=0
        )

        # Calculate perturb reconstruction => keep track of mean difference
        for i in range(num_perturbed):
            _, perturb_recon = model.reconstruct(dataloaders[i])
            latent_space_perturbed = model.project(dataloaders[i])
            
            diff = perturb_recon - baseline_recon  # 2D: N x C
            mean_diff[i, :, :] += diff * normalizer

            min_perturb, max_perturb = np.min(perturb_recon, axis=0), np.max(
                perturb_recon, axis=0
            )
            min_feat[i, :], max_feat[i, :] = np.min(
                [min_baseline, min_perturb], axis=0
            ), np.max([max_baseline, max_perturb], axis=0)

    # Calculate Bayes factors
    logger.info("Identifying significant features")
    bayes_k = np.empty((num_perturbed, num_continuous))
    bayes_mask = np.zeros(np.shape(bayes_k))
    for i in range(num_perturbed):
        mask = feature_mask[:, [i]] | nan_mask  # 2D: N x C
        diff = np.ma.masked_array(mean_diff[i, :, :], mask=mask)  # 2D: N x C
        prob = np.ma.compressed(np.mean(diff > 1e-8, axis=0))  # 1D: C
        bayes_k[i, :] = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8)
        if task_config.target_value in CONTINUOUS_TARGET_VALUE:
            bayes_mask[i, :] = (
                baseline_dataloader.dataset.con_all[0, :]
                - dataloaders[i].dataset.con_all[0, :]
            )

    bayes_mask[bayes_mask != 0] = 1
    bayes_mask = np.array(bayes_mask, dtype=bool)

    # Calculate Bayes probabilities
    bayes_abs = np.abs(bayes_k)
    bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs))  # 2D: P x C
    bayes_abs[bayes_mask] = np.min(
        bayes_abs
    )  # Bring feature_i feature_i associations (self-associations) to minimum
    sort_ids = np.argsort(bayes_abs, axis=None)[::-1]  # 1D: P * C
    prob = np.take(bayes_p, sort_ids)  # 1D: P * C
    logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")

    # Sort Bayes
    bayes_k = np.take(bayes_k, sort_ids)  # 1D: P * C

    # Calculate FDR
    fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1)  # 1D
    idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
    logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")

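    # The latent projections returned are those of the last refit and of the last
    # perturbed dataloader (index num_perturbed - 1)
    return (
        sort_ids[:idx],
        prob[:idx],
        fdr[:idx],
        bayes_k[:idx],
        latent_space_baseline,
        latent_space_perturbed,
    )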



logger = get_logger(__name__)
task_config = cast(IdentifyAssociationsConfig, config.task)
task_type = _get_task_type(task_config)
_validate_task_config(task_config, task_type)

interim_path = Path(config.data.interim_data_path)

models_path = interim_path / "models"
if task_config.save_refits:
    models_path.mkdir(exist_ok=True)

output_path = Path(config.data.results_path) / "identify_associations"
output_path.mkdir(exist_ok=True, parents=True)

# Load datasets:
cat_list, cat_names, con_list, con_names = io.load_preprocessed_data(
    interim_path,
    config.data.categorical_names,
    config.data.continuous_names,
)

train_dataloader = make_dataloader(
    cat_list,
    con_list,
    shuffle=True,
    batch_size=task_config.batch_size,
    drop_last=True,
)

con_shapes = [con.shape[1] for con in con_list]

num_samples = len(cast(Sized, train_dataloader.sampler))  # N
num_continuous = sum(con_shapes)  # C
logger.debug(f"# continuous features: {num_continuous}")

# Creating the baseline dataloader:
baseline_dataloader = make_dataloader(
    cat_list, con_list, shuffle=False, batch_size=task_config.batch_size
)

# Identify associations between continuous features:
logger.info(f"Perturbing dataset: '{task_config.target_dataset}'")
if task_config.target_value in CONTINUOUS_TARGET_VALUE:
    logger.info(f"Beginning task: identify associations continuous ({task_type})")
    logger.info(f"Perturbation type: {task_config.target_value}")
    output_subpath = Path(output_path) / "perturbation_visualization"
    output_subpath.mkdir(exist_ok=True, parents=True)
    (dataloaders, nan_mask, feature_mask,) = prepare_for_continuous_perturbation(
        config, output_subpath, baseline_dataloader
    )

# Identify associations between categorical and continuous features:
else:
    logger.info("Beginning task: identify associations categorical")
    (dataloaders, nan_mask, feature_mask,) = prepare_for_categorical_perturbation(
        config, interim_path, baseline_dataloader, cat_list
    )

    
num_perturbed = len(dataloaders) - 1  # P
logger.debug(f"# perturbed features: {num_perturbed}")

################# APPROACH EVALUATION ##########################

if task_type == "bayes":
    task_config = cast(IdentifyAssociationsBayesConfig, task_config)
    *_, latent_space_baseline, latent_space_perturbed = _bayes_approach(
        config,
        task_config,
        train_dataloader,
        baseline_dataloader,
        dataloaders,
        models_path,
        num_perturbed,
        num_samples,
        num_continuous,
        nan_mask,
        feature_mask,
    )

    extra_colnames = ["proba", "fdr", "bayes_k"]

else:
    raise ValueError(f"Only the Bayes approach is implemented here (got '{task_type}').")
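To make the scoring in `_bayes_approach` more concrete, here is a minimal standalone sketch on toy data for a single perturbation (one row `i` of `mean_diff`), ignoring the NaN and self-association masks. `bayes_scores` is a hypothetical helper, not part of MOVE, and it selects features with `fdr <= sig_threshold`, a simplification of the `argmin`-based cutoff used above.

```python
import numpy as np


def bayes_scores(mean_diff: np.ndarray, sig_threshold: float = 0.05, eps: float = 1e-8):
    """Score the C features of one perturbation from an N x C matrix of
    (perturbed - baseline) reconstruction differences averaged over refits."""
    prob_up = np.mean(mean_diff > eps, axis=0)  # fraction of samples whose value increased
    bayes_k = np.log(prob_up + eps) - np.log(1 - prob_up + eps)  # log Bayes factor (log-odds)
    bayes_abs = np.abs(bayes_k)
    bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs))  # probability of the more likely direction
    order = np.argsort(bayes_abs)[::-1]  # strongest evidence first
    prob = bayes_p[order]
    # Expected FDR when calling the top-k features: running mean of (1 - p)
    fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1)
    significant = order[fdr <= sig_threshold]  # fdr is non-decreasing, so this is a prefix
    return order, prob, fdr, significant


toy_rng = np.random.default_rng(0)
toy_diff = toy_rng.normal(size=(200, 4))  # toy differences: 200 samples x 4 features
toy_diff[:, 0] += 3.0  # feature 0 increases consistently after the perturbation
toy_order, toy_prob, toy_fdr, toy_sig = bayes_scores(toy_diff)
print("ranking:", toy_order, "significant:", toy_sig)
```

Only feature 0 should be flagged: its shift is positive in almost every sample, so its Bayes factor is large, while the other features hover around a 50/50 split and add roughly 0.5 to the running FDR each.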


In [None]:
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.colors import ListedColormap
from PIL import Image

genes_of_interest = {
    "ENSG00000099194": "SCD",
    "ENSG00000147872": "PLIN2",
    "ENSG00000079435": "LIPE",
    "ENSG00000181856": "GLUT4",
    # "ENSG00000177370": "TIMM22",
    # "ENSG00000171105": "INSR",
}


! mkdir -p figures
feature_list = ["Day"] + list(genes_of_interest.keys())
#feature_list = ["ENSG00000079435"]
PRS_OF_INTEREST = "gps_vatadjbmi3"
figure_path = Path("./figures/")
results_path = Path("./results/identify_associations/")

                                     
def plot_3D_latent_and_displacement(
    mu_baseline,
    mu_perturbed,
    feature_values,
    feature_name,
    show_baseline=True,
    show_perturbed=True,
    show_arrows=True,
    step: int=1,
    altitude: int=30,
    azimuth: int=45,
):
    """
    Plot the movement of the samples in the 3D latent space after perturbing one
    input variable.

    Args:
        mu_baseline:
            2D array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample
        mu_perturbed:
            2D array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample after perturbing the input
        feature_values:
            1D array with feature values to map to a colormap ("seismic"). Each
            sample is colored according to its value for the feature of interest.
        feature_name:
            name of the feature mapped to a colormap, used as the plot title
        show_baseline:
            plot original location of the samples in the latent space
        show_perturbed:
            plot final location (after perturbation) of the samples in latent space
        show_arrows:
            plot arrows from original to final location of each sample
        step:
            plot only every `step`-th sample (1 plots all samples)
        altitude:
            elevation (in degrees) from the dim1-dim2 plane for the visualization
            of the latent space
        azimuth:
            azimuthal angle (in degrees) of the view around the dim3 axis

    Raises:
        ValueError: If latent space is not 3-dimensional (3 hidden nodes).
    Returns:
        Figure
    """
    # construct cmap
    #hex_colors= ['#d95f01','#1c9e78','#323896','#fec071']
    #my_cmap = ListedColormap(hex_colors)
    #my_cmap = "inferno"
    #my_cmap = sns.color_palette("RdYlBu", as_cmap=True)
    my_cmap = sns.color_palette("seismic", as_cmap=True)

    eps = 1e-16
    if [np.shape(mu_baseline)[1], np.shape(mu_perturbed)[1]] != [3, 3]:
        raise ValueError(
            "The latent space must be 3-dimensional. Redefine num_latent to 3."
        )

    fig = plt.figure(layout="constrained", figsize=(7, 7))
    ax = fig.add_subplot(projection="3d")
    ax.view_init(altitude, azimuth)

    if show_baseline:
        vmin, vmax = np.min(feature_values[::step]), np.max(feature_values[::step])
        abs_max = np.max([abs(vmin), abs(vmax)])
        ax.scatter(
            mu_baseline[::step, 0],
            mu_baseline[::step, 1],
            mu_baseline[::step, 2],
            marker="o",
            c=feature_values[::step],
            s=8,
            lw=0,
            cmap=my_cmap,
            #vmin=0,
            #vmax=1
        )
        ax.set_title(feature_name)
        #plt.colorbar()  # Normalize(min(feature_values[::step]),max(feature_values[::step]))), ax=ax)
    if show_perturbed:
        ax.scatter(
            mu_perturbed[::step, 0],
            mu_perturbed[::step, 1],
            mu_perturbed[::step, 2],
            marker="o",
            c=feature_values[::step],
            s=8,
            label="perturbed",
            lw=0,
            cmap=my_cmap,
        )
    if show_arrows:
        u = mu_perturbed[::step, 0] - mu_baseline[::step, 0]
        v = mu_perturbed[::step, 1] - mu_baseline[::step, 1]
        w = mu_perturbed[::step, 2] - mu_baseline[::step, 2]

        module = np.sqrt(u * u + v * v + w * w)

        mask = module > eps  # only draw arrows for samples that actually moved

        # Normalize by the largest absolute shift in each dimension (eps avoids a
        # division by zero when a dimension does not move at all)
        max_u = np.max(np.abs(u)) + eps
        max_v = np.max(np.abs(v)) + eps
        max_w = np.max(np.abs(w)) + eps
        # Arrow colors are weighted contributions of red -> dim1, green -> dim2 and
        # blue -> dim3, i.e. a purple arrow means movement in dims 1 and 3.
        # Colors are computed for the masked samples only: one color per drawn arrow.
        colors = [
            (abs(du) / max_u, abs(dv) / max_v, abs(dw) / max_w, 0.7)
            for du, dv, dw in zip(u[mask], v[mask], w[mask])
        ]
        ax.quiver(
            mu_baseline[::step, 0][mask],
            mu_baseline[::step, 1][mask],
            mu_baseline[::step, 2][mask],
            u[mask],
            v[mask],
            w[mask],
            color=colors,
            lw=0.8,
        )
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")
    ax.set_zlabel("Dim 3")
    # ax.set_axis_off()

    return fig

for feature in feature_list:
    dataset = feature if "ENSG" not in feature else "RNA"
    feature_values = pd.read_csv(f"./data_residualized/{dataset}.tsv", sep="\t")
    if feature == "cellType":
        # Subcutaneous ("sc") samples are encoded as 0, all other samples as 1
        feature_values = [0 if celltype == "sc" else 1 for celltype in feature_values["cellType"]]
    elif feature == "Day":
        feature_values = feature_values["Day"].tolist()
    elif feature == "PRS":
        feature_values = feature_values[PRS_OF_INTEREST]
    else:  # For FFA and genes, the column is named after the feature itself
        feature_values = feature_values[feature]
    feature_values = np.asarray(feature_values)  # 1D array, sliceable with [::step]
        
    # # Plot latent space:
    pic_num = 0
    n_pictures = 10
    

    for azimuth, altitude in zip(
        np.linspace(0, 10, n_pictures), np.linspace(15, 45, n_pictures)
    ):
                     
        title = feature if "ENSG" not in feature else genes_of_interest[feature]
                     
        fig = plot_3D_latent_and_displacement(
            latent_space_baseline,
            latent_space_perturbed,
            feature_values=feature_values,
            feature_name="Sample movement",
            show_baseline=True,
            show_perturbed=False,
            show_arrows=True,
            step=1,
            altitude=altitude,
            azimuth=azimuth,
        )

        fig.savefig(figure_path / f"3D_latent_movement_{pic_num}_arrows.png", dpi=100)
        plt.close(fig)

    
        fig = plot_3D_latent_and_displacement(
            latent_space_baseline,
            latent_space_perturbed,
            feature_values=feature_values,
            feature_name=f"{title}",
            show_baseline=True,
            show_perturbed=False,
            show_arrows=False,
            altitude=altitude,
            azimuth=azimuth,
        )
        fig.savefig(
            figure_path / f"3D_latent_movement_{pic_num}_perturbed_feature.png", dpi=100
        )
        plt.close(fig)
        pic_num += 1

    for plot_type in ["arrows", "perturbed_feature"]:
        frames = [
            Image.open(figure_path / f"3D_latent_movement_{i}_{plot_type}.png")
            for i in range(n_pictures)
        ]
        frames[0].save(
            figure_path / f"{plot_type}_{title}.gif",
            format="GIF",
            append_images=frames[1:],
            save_all=True,
            duration=75,
            loop=0,
        )
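The arrow coloring in `plot_3D_latent_and_displacement` maps each latent displacement to an RGBA tuple (red for dim 1, green for dim 2, blue for dim 3). The helper below is a hypothetical standalone version of that idea, not a MOVE function: it applies the movement mask before building the colors, so there is exactly one color per drawn arrow, and it guards against dimensions that do not move at all.

```python
import numpy as np


def displacement_colors(mu_baseline, mu_perturbed, eps=1e-16, alpha=0.7):
    """Return a boolean mask of moving samples and one RGBA color per moving sample."""
    delta = np.asarray(mu_perturbed, dtype=float) - np.asarray(mu_baseline, dtype=float)
    mask = np.linalg.norm(delta, axis=1) > eps  # samples that actually moved
    moved = np.abs(delta[mask])  # absolute shift per latent dimension
    # Normalize each dimension by its largest shift: red=dim1, green=dim2, blue=dim3
    scale = np.maximum(moved.max(axis=0), eps) if moved.size else np.ones(3)
    rgb = moved / scale
    rgba = np.column_stack([rgb, np.full(len(rgb), alpha)])
    return mask, rgba


toy_base = np.zeros((3, 3))
toy_pert = np.array([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.5, 0.0, 2.0]])
toy_mask, toy_rgba = displacement_colors(toy_base, toy_pert)
print(toy_mask, toy_rgba)
```

The second sample does not move and is dropped; the first moves only along dim 1 (pure red) and the third mostly along dim 3 with some dim 1 (purple). The `rgba` array should be usable as the `color=` argument of `ax.quiver` together with the masked coordinates.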


**Example images of a three-dimensional latent space (three latent nodes).**
When perturbing the cell type, i.e. adding the flag "Subcutaneous" to the samples from the visceral depot, the samples never leave their original cluster, but they still move towards the subcutaneous (sc) samples.

In [None]:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt

for i in range(n_pictures):
    image_path = f"./figures/3D_latent_movement_{i}_arrows.png"
    image = mpimg.imread(image_path)
    fig = plt.figure(figsize=(10,10))
    plt.imshow(image)
    plt.show()
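The visual impression that perturbed samples move towards the sc cluster without leaving their own can also be checked numerically. This is a minimal sketch with a hypothetical helper `movement_towards_target`; using the mean baseline embedding of the sc samples as the target is our own assumption.

```python
import numpy as np


def movement_towards_target(mu_baseline, mu_perturbed, target_centroid, eps=1e-12):
    """For each sample, return (cosine, fraction):
    cosine   = cos(angle) between its displacement and the direction to the target
               (+1 straight towards, 0 orthogonal, -1 straight away);
    fraction = share of the distance to the target covered along that direction
               (1 would mean reaching the target centroid)."""
    mu_baseline = np.asarray(mu_baseline, dtype=float)
    displacement = np.asarray(mu_perturbed, dtype=float) - mu_baseline
    to_target = np.asarray(target_centroid, dtype=float) - mu_baseline  # broadcast over samples
    dot = np.sum(displacement * to_target, axis=1)
    dist_sq = np.sum(to_target**2, axis=1)
    norms = np.linalg.norm(displacement, axis=1) * np.sqrt(dist_sq)
    cosine = dot / np.maximum(norms, eps)
    fraction = dot / np.maximum(dist_sq, eps)
    return cosine, fraction


toy_base = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
toy_pert = np.array([[0.1, 0.0, 0.0], [1.0, 0.2, 0.0]])
toy_cos, toy_frac = movement_towards_target(toy_base, toy_pert, target_centroid=[2.0, 0.0, 0.0])
print(toy_cos, toy_frac)
```

With the notebook's arrays this could be called as `movement_towards_target(latent_space_baseline[vs_idx], latent_space_perturbed[vs_idx], latent_space_baseline[sc_idx].mean(axis=0))`, where `vs_idx` and `sc_idx` are hypothetical boolean masks for visceral and subcutaneous samples. Positive cosines with fractions well below 1 would be consistent with the description above.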

In [None]:
MARC_BUCKET = "gs://collaborations_marc/shared_2023/results_marc/results_residualized/figures_cell_type/"


! gsutil -m cp -r ./figures/* {MARC_BUCKET}
#! gsutil -m cp -r ./results_cont_paper_II/ {MARC_BUCKET}
#! gsutil cp -r ./results/* {MARC_BUCKET}

### PRS ranking visualization

We gather the feature importance tsv files (SHAP values) produced for each model, rank the features by their summed absolute SHAP value across samples, and create violin plots to check the median ranking of each PRS across different model architectures. PRSs that are systematically ranked high (low) are the ones that matter most (least) to the model when characterizing the studied system.
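The ranking step can be summarized in a few lines of pandas. This is a minimal sketch on made-up SHAP tables; `rank_features` and the toy PRS names are hypothetical, and the cells below write intermediate text files instead of keeping the ranks in memory.

```python
import numpy as np
import pandas as pd


def rank_features(importance: pd.DataFrame) -> pd.Series:
    """Rank the columns of a samples x features importance table by their summed
    absolute value (rank 0 = most important)."""
    total = importance.abs().sum(axis=0)
    ordered = total.sort_values(ascending=False).index
    return pd.Series(np.arange(len(ordered)), index=ordered)


# Hypothetical SHAP tables (2 samples x 3 PRSs) for two model architectures
toy_model_a = pd.DataFrame({"prs_x": [0.9, -0.8], "prs_y": [0.1, 0.0], "prs_z": [-0.3, 0.2]})
toy_model_b = pd.DataFrame({"prs_x": [0.5, 0.4], "prs_y": [0.6, -0.7], "prs_z": [0.0, 0.1]})

toy_ranks = pd.concat([rank_features(toy_model_a), rank_features(toy_model_b)], axis=1)
toy_median_rank = toy_ranks.median(axis=1).sort_values()  # lowest median rank = most important
print(toy_median_rank)
```

`prs_x` is ranked 0 and 1 by the two models (median 0.5), so it comes first even though `prs_y` wins in one of them.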

In [None]:
from pathlib import Path
import os
import pandas as pd
import numpy as np
path = Path("./results_temp/")

ls = "latent_space"

feature_set = ["Cyto","Cells","RNA","PRS"]

for folder in os.listdir(path):
    print(folder)
    # Folder names encode the architecture: latent_<num_latent>_<hidden>_<beta>
    _, lat, hid, beta = folder.split("_")
    for feature in feature_set:
        df = pd.read_csv(
            path / folder / ls / f"feat_importance_{feature}.tsv", sep="\t"
        ).set_index("sample")
        # Rank features by their total absolute importance across samples (highest first)
        ranked_ids = np.argsort(np.sum(np.abs(df.values), axis=0))[::-1]
        if "PRS" not in feature:
            ranked_ids = ranked_ids[:10]  # keep only the top 10 non-PRS features

        order = np.take(df.columns, ranked_ids)
        savepath = "/home/jupyter/Characterizing AMSCs using MOVE/edit/"
        gene_shap = "SHAP_genes_new.txt"
        LP_shap = "SHAP_LP_new.txt"
        PRS_shap = "SHAP_PRS_new.txt"

        # Files are opened in append mode: delete them before re-running this cell
        with open(savepath + gene_shap, "a") as g, open(savepath + LP_shap, "a") as l, open(savepath + PRS_shap, "a") as p:
            for rank, feat_name in enumerate(order):
                if "ENS" in feat_name:
                    g.write(f"{rank}\t{feat_name}\n")
                elif ("Cytoplasm" in feat_name) or ("Cells" in feat_name):
                    l.write(f"{rank}\t{feat_name}\n")
                elif ("gps_" in feat_name) or ("prs." in feat_name):
                    p.write(f"{rank}\t{feat_name}\n")

In [None]:
import matplotlib.pyplot as plt

path = Path("./")
filename = "SHAP_PRS_new.txt"

colnames= ["Rank","PRS"]

df = pd.read_csv(path / filename, sep="\t", names=colnames)

df_violin = pd.DataFrame()

df_violin["Lists"] = df.groupby("PRS")["Rank"].apply(list)

df_violin["median"] = [np.median(ranks) for ranks in df_violin["Lists"]]
df_violin = df_violin.sort_values(by="median")  # best (lowest) median rank first

fig = plt.figure(figsize=(8, 4))
parts = plt.violinplot(df_violin["Lists"], showmedians=True, showextrema=False)

for pc in parts["bodies"]:
    pc.set_facecolor("#c9dbeb")
    pc.set_alpha(1)

# Replace the median lines by black diamond markers
median_line = parts["cmedians"]
median_line.set_linewidth(0)  # Hide the line
scatter_x = range(1, len(df_violin["Lists"]) + 1)
scatter_y = df_violin["median"].to_numpy()
plt.scatter(scatter_x, scatter_y, color="black", marker="D", s=5, zorder=3)

# One tick per PRS (violins are drawn at positions 1..n)
plt.xticks(np.arange(1, len(df_violin) + 1), labels=df_violin.index, rotation=90)
plt.ylabel("Rank")
plt.tight_layout()
fig.savefig(Path(FIGURE_FOLDER) / "PRS_SHAP.png", dpi=200)

# Saving results

- Locally
- In the external Google Cloud bucket:

In [None]:
MARC_BUCKET = "gs://collaborations_marc/shared_2023/results_marc/results_residualized/"


! gsutil -m cp -r ./paper_figures/ {MARC_BUCKET}
#! gsutil -m cp -r ./results_cont_paper_II/ {MARC_BUCKET}
#! gsutil cp -r ./results/* {MARC_BUCKET}
#! gsutil cp /home/jupyter/Characterizing\ AMSCs\ using\ MOVE/edit/results/identify_associations/results_sig_assoc_bayes.tsv {MARC_BUCKET}

In [None]:
! gsutil ls {MARC_BUCKET}

# Synthetic datasets benchmark

### Synthetic dataset creation

A synthetic dataset is created by sampling from a multivariate Gaussian: each feature is one component of the Gaussian and each sample is one draw from the distribution. Stronger explicit correlations can be added by defining some features as linear combinations of others. Binary categorical variables are obtained by setting the values of a feature below its column mean to zero and those above it to one.
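A minimal sketch of these three steps on a toy problem (4 features, identity covariance, made-up means); it uses NumPy's `Generator` API instead of the global `np.random` state seeded in the cells below, so it does not change their output.

```python
import numpy as np

toy_rng = np.random.default_rng(42)
toy_n = 1000
toy_means = np.array([10.0, 5.0, 0.0, 0.0])
toy_cov = np.eye(4)  # independent Gaussian components as a starting point

toy_data = toy_rng.multivariate_normal(toy_means, toy_cov, size=toy_n)  # one row per sample
# Explicit correlation: feature 2 becomes the average of features 0 and 1
toy_data[:, 2] = (toy_data[:, 0] + toy_data[:, 1]) / 2
# Binary feature: threshold feature 3 at its column mean
toy_data[:, 3] = (toy_data[:, 3] > toy_data[:, 3].mean()).astype(int)

toy_corr = np.corrcoef(toy_data, rowvar=False)
print(toy_corr.round(2))
```

With unit variances, averaging two independent features gives a correlation of about 1/sqrt(2) (roughly 0.71) between the new feature and each of its parents, while the parents stay uncorrelated.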


In [None]:
import numpy as np
import random as rnd
import pandas as pd
from pathlib import Path
from scipy.stats import pearsonr
import matplotlib.pyplot as plt
from sklearn.datasets import make_sparse_spd_matrix
from move.data.preprocessing import scale
import os
import seaborn as sns
import matplotlib as mpl

mpl.rcParams['font.family'] = 'Latin Modern Roman' #Font for the plots

################################ Functions ####################################

def get_feature_names(settings):
    all_feature_names = [
        f"{key}_{i+1}"
        for key in settings.keys()
        for i in range(settings[key]["features"])
    ]
    return all_feature_names


def create_mean_profiles(settings):
    feature_means = []
    for key in settings.keys():
        mean = settings[key]["offset"]
        for freq, coef in zip(
            settings[key]["frequencies"], settings[key]["coefficients"]
        ):
            mean += coef * (
                np.sin(
                    freq * np.arange(settings[key]["features"]) + settings[key]["phase"]
                )
                + 1
            )
        feature_means.extend(list(mean))
    return feature_means


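def create_ground_truth_correlations_file(correlations):
    """Return the feature pairs with |correlation| > COR_THRES (uses the globals
    COR_THRES and all_feature_names)."""
    sort_ids = np.argsort(abs(correlations), axis=None)[::-1]  # 1D: F * F, strongest first
    corr = np.take(correlations, sort_ids)  # 1D: F * F
    sig_ids = sort_ids[abs(corr) > COR_THRES]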
    sig_ids = np.vstack(
        (sig_ids // len(all_feature_names), sig_ids % len(all_feature_names))
    ).T
    associations = pd.DataFrame(sig_ids, columns=["feature_a_id", "feature_b_id"])
    a_df = pd.DataFrame(dict(feature_a_name=all_feature_names))
    a_df.index.name = "feature_a_id"
    a_df.reset_index(inplace=True)
    b_df = pd.DataFrame(dict(feature_b_name=all_feature_names))
    b_df.index.name = "feature_b_id"
    b_df.reset_index(inplace=True)
    associations = associations.merge(a_df, on="feature_a_id", how="left").merge(
        b_df, on="feature_b_id", how="left"
    )
    associations["Correlation"] = corr[abs(corr) > COR_THRES]
    associations = associations[
        associations.feature_a_id > associations.feature_b_id
    ]  # Only one half of the matrix
    return associations


def plot_score_matrix(
    array, feature_names, cmap="bwr", vmin=None, vmax=None, label_step=5
):
    if vmin is None:
        vmin = np.min(array)
    if vmax is None:
        vmax = np.max(array)
    fig = plt.figure(figsize=(5, 5))
    plt.imshow(array, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.xticks(
        np.arange(0, len(feature_names), label_step),
        feature_names[::label_step],
        fontsize=8,
        rotation=90,
    )
    plt.yticks(
        np.arange(0, len(feature_names), label_step),
        feature_names[::label_step],
        fontsize=8,
    )
    plt.tight_layout()
    return fig


def plot_feature_profiles(dataset, feature_means):
    ## Plot profiles
    fig = plt.figure(figsize=(15, 5))
    plt.plot(
        np.arange(len(feature_means)), feature_means, lw=1, marker=".", markersize=0
    )
    for sample in dataset:
        plt.plot(
            np.arange(len(feature_means)), sample, lw=0.1, marker=".", markersize=0
        )
    plt.xlabel("Feature number")
    plt.ylabel("Count number")
    plt.title("Patient-specific profiles")
    plt.tight_layout()

    return fig


def plot_feature_correlations(dataset, pairs_2_plot):
    # Legend labels use the global `correlations` matrix computed in the main script
    fig = plt.figure()
    for f1, f2 in pairs_2_plot:
        plt.plot(
            dataset[:, f1],
            dataset[:, f2],
            lw=0,
            marker=".",
            markersize=1,
            label=f"{correlations[f1,f2]:.2f}",
        )
    plt.xlabel("Feature A")
    plt.ylabel("Feature B")
    plt.legend(
        loc="upper center",
        bbox_to_anchor=(0.5, -0.1),
        fancybox=True,
        shadow=True,
        ncol=5,
    )
    plt.title("Feature correlations")
    plt.tight_layout()

    return fig


def save_splitted_datasets(
    settings: dict, PROJECT_NAME, dataset, all_feature_names, n_samples, outpath
):
    # Save index file
    index = pd.DataFrame({"ID": list(np.arange(1, n_samples + 1))})
    index.to_csv(outpath / f"random.{PROJECT_NAME}.ids.txt", index=False, header=False)
    # Save continuous files
    df = pd.DataFrame(
        dataset, columns=all_feature_names, index=list(np.arange(1, n_samples + 1))
    )
    cum_feat = 0
    for key in settings.keys():
        df_feat = settings[key]["features"]
        df_cont = df.iloc[:, cum_feat : cum_feat + df_feat]
        df_cont.insert(0, "ID", np.arange(1, n_samples + 1))
        df_cont.to_csv(
            outpath / f"random.{PROJECT_NAME}.{key}.tsv", sep="\t", index=False
        )
        cum_feat += df_feat
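The mean profile built by `create_mean_profiles` can be written compactly. For a dataset $d$ with $F_d$ features, offset $o_d$, phase $\varphi_d$, and paired frequencies $f_{d,j}$ and coefficients $c_{d,j}$ from `SETTINGS`, the mean of feature $k$ is

$$
\mu_{d,k} = o_d + \sum_{j} c_{d,j}\,\Big[\sin\!\big(f_{d,j}\,k + \varphi_d\big) + 1\Big], \qquad k = 0, \dots, F_d - 1,
$$

and the profiles of all datasets are concatenated into the mean vector passed to `np.random.multivariate_normal`. The $+1$ keeps each sinusoidal term non-negative, so with the positive coefficients used here no mean drops below its offset. For example, the first feature of `Continuous_A` ($k = 0$, $\varphi = 0$) has $\mu = 700 + (500 + 100 + 50)(0 + 1) = 1350$, and the first feature of `Continuous_B` ($\varphi = \pi/2$, so $\sin = 1$) has $\mu = 300 + 2\,(80 + 20 + 10) = 520$.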





In [None]:
########################### Hyperparameters ####################################

PROJECT_NAME = "random_all_sim"
MODE = "linear"  # "non-linear"
HIGH_CORR = True
SEED_1 = 1234
np.random.seed(SEED_1)
rnd.seed(SEED_1)

COV_ALPHA = 0.99  # make_sparse_spd_matrix alpha: probability that a coefficient is zero (higher = sparser)
N_SAMPLES = 5000

SETTINGS = {
    "Continuous_A": {
        "features": 50,
        "frequencies": [0.002, 0.01, 0.02],
        "coefficients": [500, 100, 50],
        "phase": 0,
        "offset": 700,
    },
    "Continuous_B": {
        "features": 100,
        "frequencies": [0.001, 0.05, 0.08],
        "coefficients": [80, 20, 10],
        "phase": np.pi / 2,
        "offset": 300,
    },
    "Categorical_A": {
        "features": 1,
        "frequencies": [0.1, 0.5, 0.8],
        "coefficients": [.2, .1, .05],
        "phase": np.pi / 2,
        "offset": 10,
    },
    "Categorical_B": {
        "features": 1,
        "frequencies": [0.01, 0.5, 0.08],
        "coefficients": [10, .1, .05],
        "phase": np.pi,
        "offset": 10,
    }
}

COR_THRES = 0.02
PAIRS_OF_INTEREST = [(1,2),(3,4)]  # ,(77,75),(99,70),(38,2),(67,62)]

# Path to store output files
outpath = Path("synthetic_data")
outpath.mkdir(exist_ok=True, parents=True)



################################## Main script ##################################
# %%
# Add all datasets in a single matrix:
all_feature_names = get_feature_names(SETTINGS)
feat_means = create_mean_profiles(SETTINGS)

# %%
###### Covariance matrix definition ######
if MODE == "linear":
    covariance_matrix = make_sparse_spd_matrix(
        dim=len(all_feature_names),  # renamed `n_dim` in recent scikit-learn versions
        alpha=COV_ALPHA,
        smallest_coef=0,
        largest_coef=1,
        norm_diag=False,
        random_state=SEED_1,
    )
elif MODE == "non-linear":
    covariance_matrix = np.identity(len(all_feature_names))

ABS_MAX = np.max(abs(covariance_matrix))
fig = plot_score_matrix(
    covariance_matrix, all_feature_names, vmin=-ABS_MAX, vmax=ABS_MAX
)
fig.savefig(outpath / f"Covariance_matrix_{PROJECT_NAME}.png")


dataset = np.random.multivariate_normal(feat_means, covariance_matrix, N_SAMPLES)



# Add non-linearities
if MODE == "non-linear":
    for i, j in PAIRS_OF_INTEREST:
        freq = np.random.choice([4, 5, 6])
        dataset[:, i] += np.sin(freq * dataset[:, j])

#scaled_dataset, _ = scale(dataset)
# No scaling in the dataset creation! It will be handled in preprocessing.
scaled_dataset = dataset.copy()  # copy, so that `dataset` keeps the raw Gaussian draws

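if HIGH_CORR:  # The second half of the features are combinations of the first half
    half = scaled_dataset.shape[1] // 2
    for i in range(half):
        col_1 = np.random.choice(range(half))
        col_2 = np.random.choice(range(half))
        # Average of two random first-half features plus a random constant offset
        # (np.random.normal() returns one scalar, so it shifts the whole column and
        # does not change correlations). This also overwrites the last NUM_CAT
        # (categorical) columns, which are binarized below.
        scaled_dataset[:, i + half] = (
            scaled_dataset[:, col_1] + scaled_dataset[:, col_2]
        ) / 2 + np.random.normal()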

# Binarize the categorical dataset
NUM_CAT = SETTINGS["Categorical_A"]["features"] + SETTINGS["Categorical_B"]["features"]
columns_to_binarize = scaled_dataset[:,-NUM_CAT:]

# Compute the mean of each of the categorical columns    
means = columns_to_binarize.mean(axis=0)

# Apply the binarization
scaled_dataset[:,-NUM_CAT:] = (columns_to_binarize > means).astype(int)

print(np.min(scaled_dataset),np.max(scaled_dataset))
# Actual correlations:

correlations = np.corrcoef(scaled_dataset, rowvar=False)
fig = plot_score_matrix(correlations, all_feature_names, vmin=-1, vmax=1, label_step=20)
fig.savefig(outpath / f"Correlations_{PROJECT_NAME}.png", dpi=200)

# Ground-truth associations: pairs with |correlation| > COR_THRES, sorted by absolute value
associations = create_ground_truth_correlations_file(correlations)
associations.to_csv(outpath / f"changes.{PROJECT_NAME}.txt", sep="\t", index=False)

# Plot feature profiles per sample
fig = plot_feature_profiles(scaled_dataset, feat_means)
fig.savefig(outpath / "Multi-omic_profiles.png")

## Plot correlations
fig = plot_feature_correlations(dataset, PAIRS_OF_INTEREST)
fig.savefig(outpath / "Feature_correlations.png")

fig = plot_feature_correlations(scaled_dataset, PAIRS_OF_INTEREST)
fig.savefig(outpath / "Feature_correlations_scaled.png")

# Write tsv files with feature values for all samples in both datasets:
save_splitted_datasets(
    SETTINGS, PROJECT_NAME, scaled_dataset, all_feature_names, N_SAMPLES, outpath
)
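A more compact way to extract the ground-truth pairs, as a minimal sketch: `ground_truth_pairs` is a hypothetical helper that should yield the same set of pairs as `create_ground_truth_correlations_file` (one half of the matrix, `feature_a_id > feature_b_id`, diagonal excluded), but takes the threshold and feature names as arguments instead of reading the globals `COR_THRES` and `all_feature_names`.

```python
import numpy as np
import pandas as pd


def ground_truth_pairs(corr: np.ndarray, names, threshold: float) -> pd.DataFrame:
    """Feature pairs (a > b, diagonal excluded) with |correlation| > threshold,
    strongest first."""
    a_idx, b_idx = np.tril_indices_from(corr, k=-1)  # strictly lower triangle
    values = corr[a_idx, b_idx]
    keep = np.abs(values) > threshold
    names = np.asarray(names)
    pairs = pd.DataFrame({
        "feature_a_id": a_idx[keep],
        "feature_b_id": b_idx[keep],
        "feature_a_name": names[a_idx[keep]],
        "feature_b_name": names[b_idx[keep]],
        "Correlation": values[keep],
    })
    return pairs.sort_values("Correlation", key=np.abs, ascending=False).reset_index(drop=True)


toy_corr = np.array([[1.0, 0.8, 0.01], [0.8, 1.0, -0.5], [0.01, -0.5, 1.0]])
toy_pairs = ground_truth_pairs(toy_corr, ["A_1", "A_2", "B_1"], threshold=0.02)
print(toy_pairs)
```

The pair (`B_1`, `A_1`) is dropped because its correlation of 0.01 is below the threshold.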

In [None]:
### Running MOVE on synthetic data

# Note that we will change the significance threshold for the KS method to get more associations.
# Encode data
! move-dl task=encode_data data=random_continuous_paper data.raw_data_path='synthetic_data'

# Identify assoc bayes
! move-dl task=random_continuous_paper__id_assoc_bayes data=random_continuous_paper data.raw_data_path='synthetic_data'

# Identify assoc KS
! move-dl task=random_continuous_paper__id_assoc_ks data=random_continuous_paper data.raw_data_path='synthetic_data'

# Identify assoc t-test
! move-dl task=random_continuous_paper__id_assoc_ttest data=random_continuous_paper data.raw_data_path='synthetic_data'

### Comparing methods to identify associations

We can run MOVE with each of the three methods to identify associations (t-test, Bayes and KS). The resulting tsv files can then be compared with the ground-truth file to assess the performance of each method.
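As a minimal sketch of this comparison (hypothetical helper `evaluate_pairs`, toy feature names), associations can be treated as unordered pairs so that (a, b) and (b, a) count once and self-associations are ignored. With n features, the true negatives then follow as n(n-1)/2 - (tp + fp + fn).

```python
def evaluate_pairs(predicted, truth):
    """Compare predicted and ground-truth associations, treating (a, b) and (b, a)
    as the same pair and ignoring self-associations (a, a)."""

    def as_unordered(pairs):
        return {frozenset(p) for p in pairs if p[0] != p[1]}

    pred, actual = as_unordered(predicted), as_unordered(truth)
    tp = len(pred & actual)  # found associations
    fp = len(pred - actual)  # spurious associations
    fn = len(actual - pred)  # missed associations
    precision = tp / (tp + fp) if pred else 0.0
    recall = tp / (tp + fn) if actual else 0.0
    return tp, fp, fn, precision, recall


toy_truth = [("A_1", "B_3"), ("A_2", "A_5")]
toy_predicted = [("B_3", "A_1"), ("A_1", "B_3"), ("A_4", "B_1"), ("A_2", "A_2")]
print(evaluate_pairs(toy_predicted, toy_truth))
```

Both directions of (`A_1`, `B_3`) collapse into one true positive, (`A_4`, `B_1`) is a false positive, the self-pair is ignored, and (`A_2`, `A_5`) is missed.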

In [None]:
import argparse
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from sklearn.metrics import precision_recall_curve, average_precision_score
from matplotlib_venn import venn2, venn3
from upsetplot import UpSet
from matplotlib import cm

# Take three colors from seaborn's "Dark2" palette (a list of RGB tuples, not a colormap)
cmap = sns.color_palette("Dark2", 3, as_cmap=False)

# Set the color cycle for matplotlib using the colormap
plt.rcParams['axes.prop_cycle'] = plt.cycler(color=cmap)
##################################### Functions #############################################

def plot_confusion_matrix(cm,
                          target_names,
                          cmap=None,
                          normalize=False):
    
    """ Function that plots the confusion matrix given cm. Mattias Ohlsson's code extended."""

    import itertools

    accuracy = np.trace(cm) / float(np.sum(cm))
    misclass = 1 - accuracy

    if cmap is None:
        cmap = plt.get_cmap('Blues')

    fig = plt.figure(figsize=(4,3))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.colorbar()

    if target_names is not None:
        tick_marks = np.arange(len(target_names))
        plt.xticks(tick_marks, target_names, rotation=0, fontsize=12)
        plt.yticks(tick_marks, target_names,fontsize=12)

    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]


    thresh = cm.max() / 1.5 if normalize else cm.max() / 2
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        if normalize:
            plt.text(j, i, "{:0.4f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black", fontsize=14)
        else:
            plt.text(j, i, "{:,.0f}".format(cm[i, j]),
                     horizontalalignment="center",
                     color="white" if cm[i, j] > thresh else "black", fontsize=14)

    plt.ylabel('True label', fontsize=14)
    plt.xlabel('Predicted label\naccuracy={:0.4f}; misclass={:0.4f}'.format(accuracy, misclass), fontsize=14)
    plt.tight_layout()

    return fig

def classify_associations(target_file, assoc_tuples):
    self_assoc = 0 # Self associations
    found_assoc_dict = {}
    false_assoc_dict = {}
    tp_fp = np.array([[0, 0]])  # running counts: column 0 = false positives, column 1 = true positives
    with open (target_file, "r") as f:
        for line in f:
            if line[0] != "f":
                splitline = line.strip().split("\t") 
                feat_a = splitline[2]
                feat_b = splitline[3]
                score = abs(float(splitline[5]))
                if feat_a == feat_b: # Self associations will not be counted
                    self_assoc += 1
                else:
                    if (feat_a,feat_b) in assoc_tuples:
                        found_assoc_dict[(feat_a,feat_b)] = score
                        if (feat_b,feat_a) not in found_assoc_dict.keys(): #If we had not found it yet
                            tp_fp = np.vstack((tp_fp,tp_fp[-1]+[0,1]))
                    elif (feat_a,feat_b) not in assoc_tuples:
                        false_assoc_dict[(feat_a,feat_b)] = score
                        if (feat_b,feat_a) not in false_assoc_dict.keys():
                            tp_fp = np.vstack((tp_fp,tp_fp[-1]+[1,0]))

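    # Remove duplicated associations, keeping a single direction per pair.
    # Both keys must still be present before deleting; otherwise, when the loop later
    # reaches (j,i), it would also delete (i,j) and the association would be lost.
    for (i,j) in list(found_assoc_dict.keys()):
        if (i,j) in found_assoc_dict and (j,i) in found_assoc_dict:
            del found_assoc_dict[(j,i)] # keep the direction listed first in the file

    for (i,j) in list(false_assoc_dict.keys()):
        if (i,j) in false_assoc_dict and (j,i) in false_assoc_dict:
            del false_assoc_dict[(j,i)]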

    return self_assoc, found_assoc_dict, false_assoc_dict, tp_fp
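The deduplication idea can be illustrated on its own. The sketch below uses a hypothetical helper, `keep_one_direction` (not part of MOVE or of the functions in this notebook), that collapses each (a, b)/(b, a) pair into a single key. Instead of keeping the first direction found, it keeps the direction with the stronger score, where "stronger" means larger by default and smaller when the scores are p-values.

```python
def keep_one_direction(assoc_dict, higher_is_stronger=True):
    """Collapse symmetric (a, b)/(b, a) entries into a single key.

    Hypothetical helper: for each unordered pair, keep the direction with the
    stronger score (largest by default, smallest for p-value-like scores).
    """
    deduped = {}
    for (a, b), score in assoc_dict.items():
        if (b, a) in deduped:
            previous = deduped[(b, a)]
            stronger = score > previous if higher_is_stronger else score < previous
            if stronger:
                # Replace the weaker direction that was stored first
                del deduped[(b, a)]
                deduped[(a, b)] = score
        else:
            deduped[(a, b)] = score
    return deduped


scores = {("A_1", "B_3"): 0.4, ("B_3", "A_1"): 0.9, ("A_2", "B_5"): 0.7}
print(keep_one_direction(scores))  # {('B_3', 'A_1'): 0.9, ('A_2', 'B_5'): 0.7}
```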


def create_confusion_matrix(n_feat,associations,real_assoc,false_assoc):
    cm = np.empty((2,2))
    # TN: only counting the upper half matrix (non doubled associations)
    cm[0,0] = (n_feat*n_feat-n_feat)/2 - (associations+false_assoc) # Diagonal is discarded
    cm[0,1] = false_assoc
    cm[1,0] = associations- real_assoc
    cm[1,1] = real_assoc

    return cm
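The true-negative cell follows from counting unordered feature pairs. With $n$ features there are $n(n-1)/2$ pairs (the diagonal is excluded), and each pair is exactly one of TP, FP, FN or TN, so TN $= n(n-1)/2 - \mathrm{TP} - \mathrm{FP} - \mathrm{FN}$. Since TP + FN equals the number of ground-truth associations, this reduces to the expression used in `create_confusion_matrix`. The sketch below uses a hypothetical `confusion_counts` helper and made-up counts (not results from this project) to check that the four cells add up to the number of pairs. Note that this counts every feature pair, including pairs that do not involve the perturbed feature.

```python
def confusion_counts(n_feat, n_true_assoc, n_found_true, n_found_false):
    """Hypothetical helper mirroring create_confusion_matrix with integer counts."""
    n_pairs = n_feat * (n_feat - 1) // 2  # unordered pairs, diagonal excluded
    tp = n_found_true
    fp = n_found_false
    fn = n_true_assoc - n_found_true
    tn = n_pairs - tp - fp - fn  # = n_pairs - (n_true_assoc + n_found_false)
    return {"TN": tn, "FP": fp, "FN": fn, "TP": tp}


# Illustrative numbers only; 36 features is the value passed with -n in the example call
counts = confusion_counts(n_feat=36, n_true_assoc=20, n_found_true=15, n_found_false=5)
print(counts)  # {'TN': 605, 'FP': 5, 'FN': 5, 'TP': 15}
```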

def get_precision_recall(found_assoc_dict,false_assoc_dict,associations):
    y_true = []
    y_pred = []

    # True Positives 
    for score in found_assoc_dict.values():
        y_true.append(1)
        y_pred.append(score)
    # False Positives
    for score in false_assoc_dict.values():
        y_true.append(0)
        y_pred.append(score)
    # False negatives
    for _ in range(associations-len(found_assoc_dict)):
        y_true.append(1)
        y_pred.append(0)

    precision, recall, thr = precision_recall_curve(y_true,y_pred) #thr will tell us score values
    avg_prec = average_precision_score(y_true,y_pred)

    return precision, recall, thr, avg_prec
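`average_precision_score` summarizes the precision-recall curve as $\mathrm{AP} = \sum_n (R_n - R_{n-1})\,P_n$, where $P_n$ and $R_n$ are the precision and recall at the $n$-th score threshold (the definition given in the scikit-learn documentation). The numpy-only sketch below (a hypothetical `average_precision`, not scikit-learn's implementation) shows how the missed associations, which `get_precision_recall` pads with a score of 0, enter the curve last. It assumes that higher scores mean stronger associations; for p-value scores, such as the t-test output that `plot_effect_size_matching` converts with -log10, the scores would need the same kind of transformation first.

```python
import numpy as np


def average_precision(y_true, y_score):
    """AP = sum_n (R_n - R_{n-1}) * P_n over decreasing score thresholds."""
    y_true = np.asarray(y_true, dtype=int)
    y_score = np.asarray(y_score, dtype=float)
    order = np.argsort(-y_score, kind="mergesort")  # highest scores first
    y_true, y_score = y_true[order], y_score[order]
    # Last index of each group of tied scores: one threshold per distinct score
    distinct = np.where(np.diff(y_score))[0]
    thresh_idx = np.r_[distinct, y_true.size - 1]
    tps = np.cumsum(y_true)[thresh_idx]
    fps = thresh_idx + 1 - tps
    precision = tps / (tps + fps)
    recall = tps / tps[-1]
    recall_prev = np.r_[0.0, recall[:-1]]
    return float(np.sum((recall - recall_prev) * precision))


# Two true associations found (0.9, 0.7), one false positive (0.8),
# and one missed association padded with score 0
ap = average_precision([1, 0, 1, 1], [0.9, 0.8, 0.7, 0.0])
print(round(ap, 4))  # 0.8056
```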

def plot_precision_recall(precision,recall,avg_prec,label, ax):
    ax.scatter(recall,precision, lw=0, marker=".", s=5, edgecolors='none', label = f"{label} - APS:{avg_prec:.2f}")
    ax.legend()
    return ax


def plot_thr_recall(thr, recall,label,  ax):
    ax.scatter(recall[:-1],thr, lw=0, marker=".", s=5, edgecolors='none', label=label)
    ax.legend()
    return ax

def plot_TP_vs_FP(tp_fp, label, ax):
    ax.scatter(tp_fp[:,0],tp_fp[:,1],s=2, label=label, edgecolors='none')
    ax.legend()
    return ax

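def plot_filling_order(order_list, last_rank=None):
    # order_list holds the 0-based ground-truth rank of each correct prediction,
    # in the order in which the predictions were found
    if last_rank is None:
        last_rank = len(order_list)
    fig = plt.figure()
    order_img = np.zeros((np.max(order_list) + 1, len(order_list)))
    for i, rank in enumerate(order_list):
        order_img[rank, i:] = 1  # this rank is filled from prediction i onwards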

    plt.imshow(order_img[:last_rank,:], cmap="binary")
    plt.xlabel("Correct prediction number")
    plt.ylabel("Association ranking")
    plt.plot(np.arange(last_rank),np.arange(last_rank))
    return fig

def plot_effect_size_matching(assoc_tuples_dict,found_assoc_dict,label, ALGORITHM, ax):


    ground_truth_effects = [assoc_tuples_dict[key] for key in list(found_assoc_dict.keys())]
    predicted_effects = np.array(list(found_assoc_dict.values()))

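    if ALGORITHM == 'ttest':
        # Scores are p-values: convert them to -log10(p) (Eq 15 on https://doi.org/10.1146/annurev-statistics-031017-100307).
        # p == 0 gets the dummy value -1, which is then replaced by the largest observed -log10(p).
        # The list must be an array for the boolean mask to work (on a list it indexes element 0).
        predicted_effects = np.array([-np.log10(p) if p != 0 else -1 for p in predicted_effects])
        predicted_effects[predicted_effects == -1] = np.max(predicted_effects)

    # Min-max scaling to [0, 1], guarding against a zero range
    eff_max, eff_min = np.max(predicted_effects), np.min(predicted_effects)
    eff_range = eff_max - eff_min if eff_max > eff_min else 1.0
    standardized_pred_effects = (predicted_effects - eff_min) / eff_range
    ax.scatter(ground_truth_effects, standardized_pred_effects, s=12, edgecolors='none', label=label)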
    ax.legend()
    return ax
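For the t-test the scores are p-values, so they are mapped to $-\log_{10}(p)$ before min-max scaling. The standalone sketch below (a hypothetical `pvalues_to_unit_scale`) performs the same two steps with a boolean mask instead of a dummy value: $p = 0$ is mapped to the largest finite $-\log_{10}(p)$, and all-equal effects are returned as zeros rather than dividing by zero.

```python
import numpy as np


def pvalues_to_unit_scale(p_values):
    """Map p-values to -log10(p), then min-max scale to [0, 1]."""
    p = np.asarray(p_values, dtype=float)
    effects = np.empty_like(p)
    nonzero = p > 0
    effects[nonzero] = -np.log10(p[nonzero])
    # p == 0 has no finite -log10(p): use the largest finite value instead
    effects[~nonzero] = effects[nonzero].max() if nonzero.any() else 1.0
    lo, hi = effects.min(), effects.max()
    return (effects - lo) / (hi - lo) if hi > lo else np.zeros_like(effects)


print(pvalues_to_unit_scale([1e-3, 1e-1, 0.0]))
```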

def plot_venn_diagram(venn, ax, mode = 'all', scale='log'):
    sets = [set(venn[key][mode]) for key in venn]
    labels = tuple(venn.keys())  # a tuple, not a generator, so the labels can be indexed

    if len(venn) == 2:
        venn2(sets, labels, ax=ax)
    elif len(venn) == 3:
        venn3(sets, labels, ax=ax)
    else:
        raise ValueError("Unsupported number of input files: a Venn diagram needs 2 or 3.")
    return ax


def plot_upsetplot(venn,assoc_tuples):
    
    all_assoc = set([association for ALGORITHM in venn.keys() for association in venn[ALGORITHM]['all']])
    columns = ['ground truth']
    columns.extend([ALGORITHM for ALGORITHM in list(venn.keys())])

    df = {}
    for association in all_assoc:
        df[association] = []

        if association in assoc_tuples:
            df[association].append('TP')
        else:
            df[association].append('FP')
        
        for ALGORITHM in list(venn.keys()):
            if association in venn[ALGORITHM]['all']:
                df[association].append(1)
            else:
                df[association].append(0)

    df = pd.DataFrame.from_dict(df, orient='index', columns = columns)
    df = df.set_index([pd.Index(df[ALGORITHM] == 1) for ALGORITHM in list(venn.keys())])
    upset = UpSet(df, intersection_plot_elements=0, show_counts=True)

    upset.add_stacked_bars(by="ground truth", colors=cm.Pastel1,
                       title="Count by ground truth value", elements=10)

    return upset
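The UpSet plot shows, for each exact combination of methods, how many associations were predicted by that combination, split by ground-truth status. The pure-Python sketch below (a hypothetical `membership_counts`; it does not use the `upsetplot` API) computes those counts directly, which can help when checking the plot.

```python
from collections import Counter


def membership_counts(predictions, ground_truth):
    """Count associations per exact combination of methods, split into TP/FP.

    predictions: dict mapping method name -> set of (feat_a, feat_b) tuples
    ground_truth: set of (feat_a, feat_b) tuples
    """
    methods = sorted(predictions)
    all_assoc = set().union(*predictions.values())
    counts = Counter()
    for assoc in all_assoc:
        # Methods that predicted this association (its UpSet intersection)
        membership = tuple(m for m in methods if assoc in predictions[m])
        label = "TP" if assoc in ground_truth else "FP"
        counts[(membership, label)] += 1
    return counts


predictions = {
    "ttest": {("A_2", "B_1"), ("A_2", "B_4")},
    "bayes": {("A_2", "B_1"), ("A_2", "B_7")},
}
truth = {("A_2", "B_1"), ("A_2", "B_7")}
for (methods, label), n in sorted(membership_counts(predictions, truth).items()):
    print(methods, label, n)
```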




###################################### Main code ################################################

def main(args_list):
    parser = argparse.ArgumentParser(description='Compare ground truth associations with the associations predicted by one or more methods.')
    parser.add_argument('-p', '--perturbed', metavar='pert', type=str, required=True, help='perturbed feature names')
    parser.add_argument('-n', '--features', metavar='feat', type=int, required=True, help='total number of features')
    parser.add_argument('-r', '--reference', metavar='ref', type=str, required=True, help='path to the ground truth associations file')
    parser.add_argument('-o', '--outpath', metavar='outpath', type=str, required=True, help='path where figures will be saved')
    parser.add_argument('-t', '--targets', metavar='tar', type=str, required=True, nargs='+', help='path to the predicted associations files')
    
    args = parser.parse_args(args_list)


    # Defining main performance evaluation figures:
    fig_0, ax_0 = plt.subplots(figsize=(5,5))
    fig_1, ax_1 = plt.subplots(figsize=(5,5))
    fig_2, ax_2 = plt.subplots(figsize=(5,5))
    fig_3, ax_3 = plt.subplots(figsize=(5,5))

    assoc_tuples_dict = {}

    # Reading the file with the ground truth changes:
    with open (args.reference, "r") as f:
        for line in f:
            if line[0] != "f" and line[0] != "n":
                splitline = line.strip().split("\t") 
                feat_a = splitline[2]
                feat_b = splitline[3]
                assoc_strength = abs(float(splitline[4]))
                # Only associations involving the perturbed feature can be detected
                if args.perturbed in feat_a or args.perturbed in feat_b: 
                    assoc_tuples_dict[(feat_a,feat_b)] = assoc_strength
                    assoc_tuples_dict[(feat_b,feat_a)] = assoc_strength

    associations = len(assoc_tuples_dict) // 2  # each association is stored in both directions
    venn = {}
    
    # Count and save found associations
    for target_file in args.targets:      
        ALGORITHM = target_file.split('/')[-1].split('_')[3][:-4] 
        self_assoc, found_assoc_dict, false_assoc_dict, tp_fp = classify_associations(target_file,list(assoc_tuples_dict.keys()))
        real_assoc = len(found_assoc_dict) # True predicted associations
        false_assoc = len(false_assoc_dict) # False predicted associations
        total_assoc = real_assoc + false_assoc

        venn[ALGORITHM] = {}
        venn[ALGORITHM]['correct'] = list(found_assoc_dict.keys()) 
        venn[ALGORITHM]['all'] = list(found_assoc_dict.keys()) +  list(false_assoc_dict.keys())

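        # Assess ranking of associations (both directions of each association are stored
        # consecutively in assoc_tuples_dict, so index // 2 is its 0-based rank):
        assoc_keys = list(assoc_tuples_dict.keys())
        order_list = [assoc_keys.index((feat_a,feat_b))//2 for (feat_a,feat_b) in found_assoc_dict.keys()]
        if order_list:  # nothing to plot if no correct association was found
            fig = plot_filling_order(order_list)
            fig.savefig(args.outpath + f"Order_image_{ALGORITHM}.png", dpi=200)
            plt.close(fig)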

        ax_0 = plot_effect_size_matching(assoc_tuples_dict, found_assoc_dict, ALGORITHM, ALGORITHM, ax_0)

        # Plot confusion matrix:
        cm = create_confusion_matrix(args.features,associations,real_assoc,false_assoc)
        fig = plot_confusion_matrix(cm,
                              ["No assoc","Association"],
                              cmap=None,
                              normalize=False)

        fig.savefig(args.outpath + f'Confusion_matrix_{ALGORITHM}.png', dpi=100, bbox_inches='tight')
        plt.close(fig)

        # Plot precision-recall and TP-FP curves
        precision, recall, thr, avg_prec = get_precision_recall(found_assoc_dict,false_assoc_dict,associations)

        ax_1 = plot_precision_recall(precision,recall, avg_prec,ALGORITHM, ax_1)
        ax_2 = plot_TP_vs_FP(tp_fp, ALGORITHM, ax_2)
        ax_3 = plot_thr_recall(thr, recall, ALGORITHM, ax_3)


        # Write results:
        with open('Performance_evaluation_summary_results.txt','a') as f:
            f.write(f"File: {target_file}\n")
            f.write(f"Ground truth detectable associations (i.e. involving the perturbed feature {args.perturbed}): {associations}\n")
            f.write(f"{total_assoc} unique associations found\n{self_assoc} self-associations were found before filtering\n{real_assoc} were real associations\n{false_assoc} were either false or below the significance threshold\n")
            # Guard against division by zero when nothing is detectable or nothing was found
            sensitivity = real_assoc/associations if associations else float('nan')
            precision_value = real_assoc/total_assoc if total_assoc else float('nan')
            f.write(f"Sensitivity:{real_assoc}/{associations} = {sensitivity}\n")
            f.write(f"Precision:{real_assoc}/{total_assoc} = {precision_value}\n")
            f.write(f"Order list:{order_list}\n\n")
            f.write("______________________________________________________\n")


    # Edit figures: layout
    ax_0.set_xlabel("Real effect")
    ax_0.set_ylabel("Predicted effect")
    ax_0.set_ylim((-0.02,1.02))
    ax_0.set_xlim((0,1.02))
    ax_0.legend(loc='upper center', bbox_to_anchor=(0.5, 1.1),
              ncol=3, fancybox=True, shadow=True)

    ax_1.set_xlabel("Recall")
    ax_1.set_ylabel("Precision")
    ax_1.legend()
    ax_1.set_ylim((0,1.05))
    ax_1.set_xlim((0,1.05))

    ax_2.set_xlabel("False Positives")
    ax_2.set_ylabel("True Positives")
    ax_2.set_aspect("auto")

    ax_3.set_ylabel("Threshold")
    ax_3.set_xlabel("Recall")


    # Save main figures:
    fig_0.savefig(args.outpath + "Effect_size_matching.png", dpi=200)
    fig_1.savefig(args.outpath + "Precision_recall.png", dpi=200)
    fig_2.savefig(args.outpath + "TP_vs_FP.png", dpi=200)
    fig_3.savefig(args.outpath + "thr_vs_recall.png", dpi=200)

    # Plotting venn diagram:
    if len(venn) == 2 or len(venn) == 3:
        fig_v, ax_v = plt.subplots(figsize=(4,4))
        ax_v = plot_venn_diagram(venn, ax_v, mode = 'correct')
        fig_v.savefig(args.outpath + 'Venn_diagram.png', dpi=200)

    # Plotting UpSet plot
    upset = plot_upsetplot(venn, list(assoc_tuples_dict.keys()))
    upset.plot()
    plt.savefig(args.outpath + 'UpSet.png', dpi=200)
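The ranking step in `main` relies on every ground-truth association being inserted into `assoc_tuples_dict` in both directions, one right after the other, so that `index // 2` gives its rank. The sketch below (a hypothetical `ground_truth_ranks`) builds the same mapping once as a dictionary instead of calling `list.index` for every found association. It assumes each pair appears only once in the reference file; a repeated pair would not change the dictionary order and the two directions would no longer sit at positions 2k and 2k + 1.

```python
def ground_truth_ranks(assoc_tuples_dict, found_keys):
    """Return the 0-based ground-truth rank of each found (feat_a, feat_b) key.

    Assumes both directions of every association were inserted consecutively,
    so positions 2k and 2k + 1 belong to the association ranked k.
    """
    rank_of = {key: position // 2 for position, key in enumerate(assoc_tuples_dict)}
    return [rank_of[key] for key in found_keys]


ground_truth = {
    ("A_2", "B_1"): 0.9, ("B_1", "A_2"): 0.9,  # rank 0
    ("A_2", "B_7"): 0.4, ("B_7", "A_2"): 0.4,  # rank 1
}
print(ground_truth_ranks(ground_truth, [("B_7", "A_2"), ("A_2", "B_1")]))  # [1, 0]
```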

In [None]:
args_list = [
    '-p', 'Continuous_A_2',
    '-n', '36',
    '-r', './synthetic_data/changes.random_all_sim.txt',
    '-o', './synthetic_data/',
    '-t', 
    './results_cont_paper/identify_associations/results_sig_assoc_ttest.tsv',
    './results_cont_paper/identify_associations/results_sig_assoc_bayes.tsv',
    './results_cont_paper/identify_associations/results_sig_assoc_ks.tsv'
]

main(args_list)

**Synthetic data 2: simplified version**

Two sample sizes are simulated, selected through `N_SAMPLES` in the cell below:

- Low N: 50
- High N: 1000

In [None]:
! rm -r ./interim_data_cont_paper_II/models/
! rm results_cont_paper_II/latent_space/model.pt
! rm -r results_cont_paper_II/identify_associations/
! rm -r ./synthetic_data_II/


In [None]:
########################### Hyperparameters ####################################
import random as rnd

PROJECT_NAME = "random_all_sim"
MODE = "linear"  # "non-linear"
HIGH_CORR = True
SEED_1 = 1234
np.random.seed(SEED_1)
rnd.seed(SEED_1)

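COV_ALPHA = 0.97 # Probability that a coefficient of the sparse SPD matrix is zero (larger = sparser); alternative value: .01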
N_SAMPLES = 1000 # LOW N: 50, High N: 1000


SETTINGS = {
    "Continuous_A": {
        "features": 10,
        "frequencies": [0.002, 0.01, 0.02],
        "coefficients": [500, 100, 50],
        "phase": 0,
        "offset": 500,
    },
    "Continuous_B": {
        "features": 10,
        "frequencies": [0.001, 0.05, 0.08],
        "coefficients": [80, 20, 10],
        "phase": np.pi / 2,
        "offset": 400,
    },
    "Categorical_A": {
        "features": 1,
        "frequencies": [0.1, 0.5, 0.8],
        "coefficients": [.2, .1, .05],
        "phase": np.pi / 2,
        "offset": 10,
    },
        "Categorical_B": {
        "features": 1,
        "frequencies": [0.01, 0.5, 0.08],
        "coefficients": [10, .1, .05],
        "phase": np.pi,
        "offset": 1,
    }
}

COR_THRES = 0.02
PAIRS_OF_INTEREST = [(1,2),(3,4)]  # ,(77,75),(99,70),(38,2),(67,62)]

# Path to store output files
outpath = Path("synthetic_data_II")
outpath.mkdir(exist_ok=True, parents=True)

# %%
# Add all datasets in a single matrix:
all_feature_names = get_feature_names(SETTINGS)
feat_means = create_mean_profiles(SETTINGS)

# %%
###### Covariance matrix definition ######
if MODE == "linear":
    covariance_matrix = make_sparse_spd_matrix(
        dim=len(all_feature_names),
        alpha=COV_ALPHA,
        smallest_coef=0,
        largest_coef=1,
        norm_diag=True,
        random_state=SEED_1,
    )
elif MODE == "non-linear":
    covariance_matrix = np.identity(len(all_feature_names))

ABS_MAX = np.max(abs(covariance_matrix))
fig = plot_score_matrix(
    covariance_matrix, all_feature_names, vmin=-ABS_MAX, vmax=ABS_MAX
)
fig.savefig(outpath / f"Covariance_matrix_{PROJECT_NAME}.png")
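With `norm_diag=True`, scikit-learn rescales the sparse SPD matrix so that its diagonal is 1, which makes the covariance matrix also the correlation matrix of the simulated features. The numpy-only sketch below uses a hypothetical stand-in (a dense random SPD matrix rescaled to unit diagonal, not `make_sparse_spd_matrix` itself) to check positive definiteness with a Cholesky factorization and to show how closely sample correlations track the matrix. In the low-N setting (`N_SAMPLES = 50`) the sampling error of a near-zero correlation is roughly 1/sqrt(50) ≈ 0.14, so the empirical correlations computed from the simulated dataset can differ noticeably from this matrix.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_samples = 5, 20_000

# Hypothetical stand-in for make_sparse_spd_matrix(..., norm_diag=True):
# a random SPD matrix rescaled to unit diagonal
a = rng.normal(size=(n_features, n_features))
cov = a @ a.T + n_features * np.eye(n_features)
sd = np.sqrt(np.diag(cov))
cov = cov / np.outer(sd, sd)

# Cholesky succeeds only for symmetric positive definite matrices
np.linalg.cholesky(cov)

samples = rng.multivariate_normal(np.zeros(n_features), cov, size=n_samples)
empirical = np.corrcoef(samples, rowvar=False)
max_error = np.abs(empirical - cov).max()
print(f"largest |empirical - true| correlation: {max_error:.3f}")
```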

#    dataset = np.array(
#        [
#            list(np.random.multivariate_normal(feat_means, covariance_matrix))
#            for _ in range(N_SAMPLES)
#        ]
#    )

dataset = np.random.multivariate_normal(feat_means, covariance_matrix, N_SAMPLES)



# Add non-linearities
if MODE == "non-linear":
    for i, j in PAIRS_OF_INTEREST:
        freq = np.random.choice([4, 5, 6])
        dataset[:, i] += np.sin(freq * dataset[:, j])

#scaled_dataset, _ = scale(dataset)
# No scaling in the dataset creation! It will be handled in preprocessing.
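# Work on a copy: a plain assignment would make both names point to the same array,
# so the edits below would also change `dataset` (and its correlation plot)
scaled_dataset = dataset.copy()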

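if HIGH_CORR: # The last half of the features are averages of two features from the first half:
    # Note: np.random.normal() below draws a single scalar, so each new column is
    # shifted by a constant rather than receiving per-sample noise.
    for i in range(scaled_dataset.shape[1]//2):
        col_1 = np.random.choice(range(scaled_dataset.shape[1]//2))
        col_2 = np.random.choice(range(scaled_dataset.shape[1]//2))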
        scaled_dataset[:,i+scaled_dataset.shape[1]//2] = (scaled_dataset[:,col_1]+scaled_dataset[:,col_2])/2 + np.random.normal()
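To see what the `HIGH_CORR` step does to the correlation structure, consider two independent, unit-variance features: their average has variance 1/2 and correlates with each parent at about 1/sqrt(2) ≈ 0.71. Adding one scalar draw from `np.random.normal()` shifts the column without changing that correlation, whereas unit-variance noise drawn per sample (an alternative, not what the loop does) would lower it to 0.5/sqrt(1.5) ≈ 0.41. Since `col_1` and `col_2` are drawn independently they can coincide, in which case the new column is a shifted copy of its parent (correlation 1). The sketch below checks both cases numerically.

```python
import numpy as np

rng = np.random.default_rng(1)
n = 100_000
x1, x2 = rng.normal(size=(2, n))  # two independent, unit-variance features

shifted = (x1 + x2) / 2 + rng.normal()      # one scalar draw: constant shift, as in the loop
noisy = (x1 + x2) / 2 + rng.normal(size=n)  # per-sample noise (alternative)

r_shifted = np.corrcoef(x1, shifted)[0, 1]
r_noisy = np.corrcoef(x1, noisy)[0, 1]
print(f"scalar shift:     r = {r_shifted:.3f} (theory {1 / np.sqrt(2):.3f})")
print(f"per-sample noise: r = {r_noisy:.3f} (theory {0.5 / np.sqrt(1.5):.3f})")
```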

# Binarize the categorical dataset
NUM_CAT = SETTINGS["Categorical_A"]["features"] + SETTINGS["Categorical_B"]["features"]
columns_to_binarize = scaled_dataset[:,-NUM_CAT:]

# Compute the mean of each of the categorical columns    
means = columns_to_binarize.mean(axis=0)

# Apply the binarization
scaled_dataset[:,-NUM_CAT:] = (columns_to_binarize > means).astype(int)
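The binarization compares every value with the mean of its own column, relying on numpy broadcasting of the per-column means across rows. A value exactly equal to the mean becomes 0, and for skewed columns the two categories can be quite unbalanced. A minimal standalone sketch of the same rule (the helper name `binarize_at_column_mean` is hypothetical):

```python
import numpy as np


def binarize_at_column_mean(x):
    """Return 1 where a value exceeds its column mean, else 0."""
    x = np.asarray(x, dtype=float)
    return (x > x.mean(axis=0)).astype(int)  # means broadcast over the rows


x = np.array([[0.0, 10.0],
              [1.0, 20.0],
              [5.0, 30.0]])
print(binarize_at_column_mean(x))  # column means are 2 and 20; 20 is not > 20
```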

print(np.min(scaled_dataset),np.max(scaled_dataset))
# Actual correlations:
# correlations = np.empty(np.shape(covariance_matrix))
# for ifeat in range(len(covariance_matrix)):
#     for jfeat in range(len(covariance_matrix)):
#         correlations[ifeat, jfeat] = pearsonr(dataset[:, ifeat], dataset[:, jfeat])[
#             0
#         ]

correlations = np.corrcoef(scaled_dataset, rowvar=False)
fig = plot_score_matrix(correlations, all_feature_names, vmin=-1, vmax=1, label_step=5)
fig.savefig(outpath / f"Correlations_{PROJECT_NAME}.png", dpi=200)

# Sort correlations by absolute value
associations = create_ground_truth_correlations_file(correlations)
associations.to_csv(outpath / f"changes.{PROJECT_NAME}.txt", sep="\t", index=False)
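`create_ground_truth_correlations_file` is defined elsewhere in the notebook. Conceptually, the ground truth is a list of feature pairs ranked by the absolute value of their correlation. A minimal sketch of that idea (a hypothetical `ranked_pairs`, which does not reproduce the exact columns of the saved file) reads only the upper triangle of the correlation matrix, so each unordered pair appears once and self-correlations are skipped.

```python
import numpy as np


def ranked_pairs(corr, names, threshold=0.0):
    """List (feat_a, feat_b, corr) for each unordered pair, strongest |corr| first."""
    corr = np.asarray(corr, dtype=float)
    rows, cols = np.triu_indices(len(names), k=1)  # upper triangle, diagonal excluded
    pairs = [
        (names[i], names[j], float(corr[i, j]))
        for i, j in zip(rows, cols)
        if abs(corr[i, j]) > threshold
    ]
    return sorted(pairs, key=lambda pair: abs(pair[2]), reverse=True)


corr = [[1.0, 0.2, -0.8],
        [0.2, 1.0, 0.5],
        [-0.8, 0.5, 1.0]]
for pair in ranked_pairs(corr, ["A_1", "A_2", "B_1"]):
    print(pair)
```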

# Plot feature profiles per sample
fig = plot_feature_profiles(scaled_dataset, feat_means)
fig.savefig(outpath / "Multi-omic_profiles.png")

## Plot correlations
fig = plot_feature_correlations(dataset, PAIRS_OF_INTEREST)
fig.savefig(outpath / "Feature_correlations.png")

fig = plot_feature_correlations(scaled_dataset, PAIRS_OF_INTEREST)
fig.savefig(outpath / "Feature_correlations_scaled.png")

# Write tsv files with feature values for all samples in both datasets:
save_splitted_datasets(
    SETTINGS, PROJECT_NAME, scaled_dataset, all_feature_names, N_SAMPLES, outpath
)


In [None]:
### Running MOVE on simple synthetic data

# Encode data
#! move-dl task=encode_data data=random_continuous_paper_II 

# Analyze the latent space
#! move-dl task=random_continuous_paper_II__latent data=random_continuous_paper_II 

# Identify assoc ks
! move-dl task=random_continuous_paper_II__id_assoc_ks data=random_continuous_paper_II 

In [None]:
MARC_BUCKET = "gs://collaborations_marc/shared_2023/results_marc/results_residualized/" #associations_final_tables/Bayes/" #results_synthetic_small/high_N_low_corr/max_pert/"


! gsutil -m cp -r ./paper_figures/ {MARC_BUCKET}
#! gsutil -m cp -r ./results_cont_paper_II/ {MARC_BUCKET}
#! gsutil -m cp -r ./synthetic_data_II/ {MARC_BUCKET}

In [None]:
!ls "./results_temp_id_assoc_cellType/id_assoc_200_[500]_0.0001_cellType_AMSC__id_assoc_bayes/identify_associations/"

In [None]:
from PIL import Image
from matplotlib.pyplot import cm
import seaborn as sns
from matplotlib.colors import ListedColormap

DATASET = "Continuous_B"
#feature_list = ["Day"]# + list(genes_of_interest.keys())
feature_list = pd.read_csv(f"./interim_data_cont_paper_II/random.random_all_sim.{DATASET}.txt", header=None)
feature_list = feature_list.values.flatten()

figure_path = Path("./synthetic_data_II/figures_synthetic_small/")
! mkdir -p {figure_path}

latent_space = np.load("./results_cont_paper_II/identify_associations/latent_location.npy")

def plot_3D_latent_and_displacement(
    mu_baseline,
    mu_perturbed,
    feature_values,
    feature_name,
    show_baseline=True,
    show_perturbed=True,
    show_arrows=True,
    step: int=1,
    altitude: int=30,
    azimuth: int=45,
):
    """
    Plot the movement of the samples in the 3D latent space after perturbing one
    input variable.

    Args:
        mu_baseline:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample
        mu_perturbed:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample after perturbing the input
        feature_values:
            1D array with feature values to map to a colormap ("bwr"). Each sample is
            colored according to its value for the feature of interest.
        feature_name:
            name of the feature mapped to a colormap
        show_baseline:
            plot orginal location of the samples in the latent space
        show_perturbed:
            plot final location (after perturbation) of the samples in latent space
        show_arrows:
            plot arrows from original to final location of each sample
        angle:
            elevation from dim1-dim2 plane for the visualization of latent space.

    Raises:
        ValueError: If latent space is not 3-dimensional (3 hidden nodes).
    Returns:
        Figure
    """
    # construct cmap
    #hex_colors= ['#36429bff','#78b0d3ff','#fedc8cff','#d22b27ff']
    #my_cmap = ListedColormap(hex_colors)
    #my_cmap = sns.color_palette("Dark2", as_cmap=True)
    my_cmap = sns.color_palette("RdYlBu", as_cmap=True)

    eps = 1e-16
    if [np.shape(mu_baseline)[1], np.shape(mu_perturbed)[1]] != [3, 3]:
        raise ValueError(
            " The latent space must be 3-dimensional. Redefine num_latent to 3."
        )

    fig = plt.figure(layout="constrained", figsize=(7, 7))
    ax = fig.add_subplot(projection="3d")
    ax.view_init(altitude, azimuth)

    if show_baseline:
        plot = ax.scatter(
            mu_baseline[::step, 0],
            mu_baseline[::step, 1],
            mu_baseline[::step, 2],
            marker="o",
            c=feature_values[::step],
            s=20,
            lw=0,
            cmap=my_cmap,
        )
        ax.set_title(feature_name)
        fig.colorbar(plot)

    if show_perturbed:
        ax.scatter(
            mu_perturbed[::step, 0],
            mu_perturbed[::step, 1],
            mu_perturbed[::step, 2],
            marker="o",
            c=feature_values[::step],
            s=10,
            label="perturbed",
            lw=0,
            cmap=my_cmap,  # same colormap as the baseline points
        )
    if show_arrows:
        u = mu_perturbed[::step, 0] - mu_baseline[::step, 0]
        v = mu_perturbed[::step, 1] - mu_baseline[::step, 1]
        w = mu_perturbed[::step, 2] - mu_baseline[::step, 2]

        module = np.sqrt(u * u + v * v + w * w)

        mask = module > eps  # only draw arrows for samples that actually moved

        # Guard against division by zero when there is no movement along a dimension
        max_u, max_v, max_w = (max(np.max(abs(d)), eps) for d in (u, v, w))
        # Arrow colors are weighted contributions of red -> dim1, green -> dim2 and blue -> dim3,
        # e.g. a purple arrow means movement in dims 1 and 3.
        # Colors are computed for the masked samples only, so they stay aligned with the arrows.
        colors = [
            (abs(du) / max_u, abs(dv) / max_v, abs(dw) / max_w, 0.7)
            for du, dv, dw in zip(u[mask], v[mask], w[mask])
        ]
        ax.quiver(
            mu_baseline[::step, 0][mask],
            mu_baseline[::step, 1][mask],
            mu_baseline[::step, 2][mask],
            u[mask],
            v[mask],
            w[mask],
            color=colors,
            lw=.8,
        )
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")
    ax.set_zlabel("Dim 3")

    
    # ax.set_axis_off()

    return fig
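The arrow colors encode the direction of movement: each RGB channel is the absolute displacement along one latent dimension divided by the largest absolute displacement along that dimension. The sketch below (a hypothetical `displacement_colors`) computes the same RGBA values in vectorized form and returns them only for samples that moved, so they stay aligned with the arrows that `ax.quiver` draws after masking.

```python
import numpy as np


def displacement_colors(mu_baseline, mu_perturbed, alpha=0.7, eps=1e-16):
    """RGBA arrow colors: channel k = |shift along dim k| / max |shift along dim k|.

    Returns the colors of the samples that moved, plus the boolean mask used.
    """
    delta = np.asarray(mu_perturbed, dtype=float) - np.asarray(mu_baseline, dtype=float)
    moved = np.linalg.norm(delta, axis=1) > eps
    scale = np.maximum(np.abs(delta).max(axis=0), eps)  # avoid division by zero
    rgb = np.abs(delta) / scale
    rgba = np.column_stack([rgb, np.full(len(delta), alpha)])
    return rgba[moved], moved


baseline = np.zeros((3, 3))
perturbed = np.array([[1.0, 0.0, 0.0],
                      [0.0, 0.0, 0.0],
                      [0.5, 2.0, 0.0]])
colors, moved = displacement_colors(baseline, perturbed)
print(moved)
print(colors)
```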

# Load the feature values once; column i corresponds to feature_list[i]
all_feature_values = np.load(f"./interim_data_cont_paper_II/random.random_all_sim.{DATASET}.npy")
print(all_feature_values.shape)

# The last slice of latent_space is used as the unperturbed baseline
latent_space_baseline = latent_space[:, :, -1]

# Viewing angles of frame 80 of a 100-frame rotation (azimuth 0 -> 90, altitude 15 -> 45)
n_pictures = 100
pic_num = 80
azimuth = np.linspace(0, 90, n_pictures)[pic_num]
altitude = np.linspace(15, 45, n_pictures)[pic_num]

for i, feature in enumerate(feature_list):
    feature_values = all_feature_values[:, i]
    latent_space_perturbed = latent_space[:, :, i]

    # First the latent space colored by feature value, then the displacement arrows
    for show_baseline, show_arrows, suffix in [(True, False, ""), (False, True, "_arrows")]:
        fig = plot_3D_latent_and_displacement(
            latent_space_baseline,
            latent_space_perturbed,
            feature_values=feature_values,
            feature_name=str(feature),
            show_baseline=show_baseline,
            show_perturbed=False,
            show_arrows=show_arrows,
            altitude=altitude,
            azimuth=azimuth,
        )
        fig.savefig(
            figure_path / f"3D_latent_{pic_num}_perturbed_{feature}{suffix}.png", dpi=200
        )
        plt.close(fig)







In [None]:
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
# Show the saved latent space and arrow figures for every feature of DATASET
for feature in feature_list:
    for suffix in ["", "_arrows"]:
        image_path = figure_path / f"3D_latent_80_perturbed_{feature}{suffix}.png"
        image = mpimg.imread(image_path)
        fig = plt.figure(figsize=(10,10))
        plt.imshow(image)
        plt.show()

# Multiprocessing installation:

```python
if INSTALL:
    # Step 1: Clone the version of MOVE under development, which can handle perturbations of continuous variables
    #! git clone -b developer-continuous-v3 https://github.com/RasmussenLab/MOVE.git /home/jupyter/.local/bin/MOVE
    ! git clone -b developer https://github.com/RasmussenLab/MOVE.git /home/jupyter/.local/bin/MOVE
    # Step 2: Move into the cloned repository
    %cd /home/jupyter/.local/bin/MOVE
    # Step 3: Fetch the pull request
    !git fetch origin pull/92/head:pr-92
    
    # Step 4: Checkout the branch
    !git checkout pr-92
    
    # Step 5: Make the MOVE source code importable (requires `import sys`)
    sys.path.append("/home/jupyter/.local/bin/MOVE/src")
    # ! pip install -e /home/jupyter/.local/bin/MOVE/ 
    
    # Execute notebook from starting folder
    %cd /home/jupyter/Characterizing AMSCs using MOVE/edit
    
    #from terra_notebook_utils import table, gs, drs
    # We can also enable different sections to fold
    ! jupyter nbextension enable codefolding/main
    ! jupyter nbextension enable collapsible_headings/main

    ! pip install omegaconf upsetplot umap-learn
    ! pip install -r /home/jupyter/.local/bin/MOVE/requirements.txt
```

# Checking for dropped columns: features that are entirely NaN

In [None]:
a = pd.read_csv('./data_residualized/Cells.tsv', sep='\t')
print(a.shape)

b = np.load('./interim_data_rsd/Cells.npy')
print(b.shape)

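filepath = Path('./interim_data_rsd/Cells.npy')
data = np.load(filepath).astype(np.float32)

# Find the all-zero columns BEFORE modifying any values for plotting
# (otherwise the +1000 highlighting below would hide them)
mask_col = np.abs(data).sum(axis=0) != 0
print(f"{np.sum(~mask_col)} all-zero columns")

# Highlight zero entries (on a copy, so that `data` is left unchanged)
data_zeros = data.copy()
data_zeros[data_zeros == 0] += 1000
fig = plt.figure()
plt.imshow(data_zeros, cmap='viridis', aspect='auto')
plt.show()

# Highlight the all-zero columns
data_cols = data.copy()
data_cols[:, ~mask_col] += 1000
fig = plt.figure()
plt.imshow(data_cols, cmap='binary', aspect='auto')
plt.show()

# Drop the all-zero columns
data = data[:, mask_col]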
print(data.shape)
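As a complementary check, a hypothetical helper `dropped_columns` can report separately which columns are entirely NaN and which are entirely zero once NaNs are treated as 0; the all-zero case is what the `mask_col` test detects in the encoded array. Treating NaNs as 0 is an assumption here, based on the heading's premise that all-NaN features end up as all-zero columns after encoding. On the raw table, `a.columns[a.isna().all()]` lists the all-NaN features directly.

```python
import numpy as np


def dropped_columns(x):
    """Return indices of all-NaN columns and of (otherwise) all-zero columns."""
    x = np.asarray(x, dtype=float)
    all_nan = np.isnan(x).all(axis=0)
    # Assumption: NaNs count as 0, as they would after being filled in during encoding
    all_zero = (np.nan_to_num(x) == 0).all(axis=0) & ~all_nan
    return np.flatnonzero(all_nan), np.flatnonzero(all_zero)


x = np.array([[1.0, np.nan, 0.0],
              [2.0, np.nan, 0.0]])
nan_cols, zero_cols = dropped_columns(x)
print(nan_cols, zero_cols)  # [1] [2]
```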