

# ON THE USE OF VAES FOR BIOMEDICAL DATA INTEGRATION: THE TUTORIAL


## Intro

Hi!
This notebook is a short tutorial to guide you through the main findings of the paper.

> 💻 **Optional GPU acceleration:**
>
> We can use GPUs when running Colab notebooks! Go to the top right corner and click on the arrow pointing downwards 🔽. There you can choose `Change runtime type` and set the `Hardware accelerator` to `T4`.
>

**Background**

We will explore the behavior of Multimodal Variational Autoencoders (VAEs) when integrating diverse data modalities.

>📕 **What are VAEs?**
>
>If you are not familiar with VAEs, you can read:
>- _The original paper_: [Auto-Encoding Variational Bayes](https://arxiv.org/abs/1312.6114), by Diederik P. Kingma and Max Welling.
>  - Strong theoretical foundation, motivation of the Evidence Lower Bound (ELBO) objective.
>- _The MultiOmics Variational Autoencoder paper_ [MOVE](https://www.nature.com/articles/s41587-022-01520-x).
>
>  - Presents the MultiOmics Variational autoEncoder (MOVE) and applies it to a cohort of Type II diabetes (T2D) patients.
>
>- _Deep Generative Modelling_, by Jakub M. Tomczak. Chapter 4.3.
>  
>  - Solid theoretical motivation with code examples.



**The notebook**

In this notebook we will:
- Install the Multiomics Variational Autoencoder [MOVE](https://www.nature.com/articles/s41587-022-01520-x)
- Generate a synthetic dataset containing categorical and continuous features.
- Run the main tasks in the MOVE pipeline:
  - Encode the data
  - Analyze the latent space
  - Identify associations
  - Visualize the perturbations




Let's start!

## Install and import the required packages

The code to create and run the Multiomics Variational AutoEncoder (MOVE) can be installed as a pip package.

>⚠️ **Warning**
>
>You'll be asked to restart the runtime after running the command below so that the changes take effect.

In [None]:
! pip install move-dl



In [None]:
import numpy as np
import random as rnd
import pandas as pd
from pathlib import Path
import shutil
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from sklearn.datasets import make_sparse_spd_matrix
from PIL import Image
from matplotlib.pyplot import cm
import seaborn as sns
from matplotlib.colors import ListedColormap
from IPython.display import display
from IPython.display import Image as Image_display
import torch

#from scipy.stats import pearsonr
#from move.data.preprocessing import scale
#import os
#import io
#import seaborn as sns
#import sys
#import matplotlib as mpl
#from itertools import chain

In [None]:
CUDA = "true" if torch.cuda.is_available() else "false"

## Generating a synthetic dataset

We will now create a synthetic dataset by drawing a number of samples ($N_{samples}$, in rows) from a known distribution that describes a number of properties or features of the samples ($N_{features}$, in columns). We will create a multimodal dataset as a combination of 4 modalities: Categorical_A and Categorical_B, with one feature each, and Continuous_A and Continuous_B with 10 features each.

>**📕 Theoretical insights:** In our synthetic datasets, each sample was originally obtained as a draw from a multivariate Gaussian distribution, $x \sim \mathcal{N}(\mu, \Sigma)$, where each feature corresponds to one component of the Gaussian. This relatively simple protocol enabled us to simulate different scales of feature values via the mean of each feature. It also allowed us to add an arbitrary number of linear associations between features by controlling the covariance matrix of the distribution. The distribution was parametrized by a vector with the mean of each feature (column) and a sparse covariance matrix (sparsity defined by $α_{cov}$ = 0.99; this tutorial uses `COV_ALPHA = 0.97`). The dataset subsequently went through the standard preprocessing for continuous variables, which consisted of a $\log_2(1+x)$ transformation followed by normalization to zero mean and unit variance.
>
>Stronger correlations could subsequently be added by defining the second half of the features to be linear combinations of randomly chosen features in the first half:
>
>$$Feature_k=\frac{Feature_i + Feature_j}{2} + \epsilon$$
>
>where $\epsilon$ was drawn from a univariate normal distribution, and $i$ and $j$ are features from the first half, with $i \neq j$, whose combination overrides feature $k$ from the second half. Finally, continuous features could be turned into binary features by setting all values below the mean to zero and all values above it to one. Ground truth associations were defined as pairs of variables with a correlation above a predefined threshold.
>
> In this tutorial, we will add the extra correlations.
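
The preprocessing step described above is not part of the dataset-creation code in this notebook, since it is applied later through the `log2: true` and `scale: true` options of the data config. A minimal sketch of what such a step could look like is shown below. The function name `preprocess_continuous` and the toy data are assumptions made for illustration; this is not MOVE's actual implementation, which may differ in details such as the handling of missing values.

```python
import numpy as np


def preprocess_continuous(x: np.ndarray) -> np.ndarray:
    """Apply a log2(1 + x) transform followed by per-feature z-scoring.

    Hypothetical helper written for this tutorial, not part of MOVE.
    """
    logged = np.log2(1.0 + x)
    mean = logged.mean(axis=0)
    std = logged.std(axis=0)
    std[std == 0] = 1.0  # avoid dividing by zero for constant features
    return (logged - mean) / std


rng = np.random.default_rng(0)
# Toy data: 200 samples, 3 features living on very different scales
toy = rng.normal(loc=[500.0, 50.0, 5.0], scale=1.0, size=(200, 3))
processed = preprocess_continuous(toy)

# After preprocessing, every feature has (numerically) zero mean and unit variance
print(processed.mean(axis=0).round(6))
print(processed.std(axis=0).round(6))
```

After this step, features that originally lived on very different scales become directly comparable.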


In [None]:
################################ Functions ####################################

def get_feature_names(settings):
    """
    This function returns a list with all feature names

    Args:
        settings (dict): dictionary with all settings

    Returns:
        all_feature_names: list with all feature names
    """
    all_feature_names = [
        f"{key}_{i+1}"
        for key in settings.keys()
        for i in range(settings[key]["features"])
    ]
    return all_feature_names


def create_mean_profiles(settings):
    """
    This function returns a list with all feature means.

    Args:
        settings (dict): dictionary with all settings.

    Returns:
        feature_means: list with all feature means.
    """
    feature_means = []
    for key in settings.keys():
        mean = settings[key]["offset"]
        for freq, coef in zip(
            settings[key]["frequencies"], settings[key]["coefficients"]
        ):
            mean += coef * (
                np.sin(
                    freq * np.arange(settings[key]["features"]) + settings[key]["phase"]
                )
                + 1
            )
        feature_means.extend(list(mean))
    return feature_means


def create_ground_truth_correlations_file(correlations, COR_THRES):
    """
    This function saves the ground truth associations in a DataFrame, which will
    then be stored in a tsv file. Ground truth associations are defined to be the
    pairs of features with a Pearson correlation above COR_THRES.

    Args:
        correlations (np.array): array with all correlations.
        COR_THRES (float): threshold for the pearson correlation.

    Returns:
        associations (pd.DataFrame): dataframe with all associations.
    """

    sort_ids = np.argsort(abs(correlations), axis=None)[::-1]  # 1D: N x C
    corr = np.take(correlations, sort_ids)  # 1D: N x C
    sig_ids = sort_ids[abs(corr) > COR_THRES]
    sig_ids = np.vstack(
        (sig_ids // len(all_feature_names), sig_ids % len(all_feature_names))
    ).T
    associations = pd.DataFrame(sig_ids, columns=["feature_a_id", "feature_b_id"])
    a_df = pd.DataFrame(dict(feature_a_name=all_feature_names))
    a_df.index.name = "feature_a_id"
    a_df.reset_index(inplace=True)
    b_df = pd.DataFrame(dict(feature_b_name=all_feature_names))
    b_df.index.name = "feature_b_id"
    b_df.reset_index(inplace=True)
    associations = associations.merge(a_df, on="feature_a_id", how="left").merge(
        b_df, on="feature_b_id", how="left"
    )
    associations["Correlation"] = corr[abs(corr) > COR_THRES]
    associations = associations[
        associations.feature_a_id > associations.feature_b_id
    ]  # Only one half of the matrix
    return associations


def plot_score_matrix(
    array, feature_names, cmap="bwr", vmin=None, vmax=None, label_step=5
):
    """
    This function plots a score matrix.

    Args:
        array (np.array): array with all correlations.
        feature_names (list): list with all feature names.

    Returns:
        fig: fig object to save or show the plot.
    """
    if vmin is None:
        vmin = np.min(array)
    if vmax is None:  # independent check: vmax must also be set when vmin was given
        vmax = np.max(array)
    fig = plt.figure(figsize=(5, 5))
    plt.imshow(array, cmap=cmap, vmin=vmin, vmax=vmax)
    plt.xticks(
        np.arange(0, len(feature_names), label_step),
        feature_names[::label_step],
        fontsize=8,
        rotation=90,
    )
    plt.yticks(
        np.arange(0, len(feature_names), label_step),
        feature_names[::label_step],
        fontsize=8,
    )
    plt.tight_layout()
    # ax
    return fig

def save_splitted_datasets(
    settings: dict, PROJECT_NAME, dataset, all_feature_names, n_samples, outpath
):
    """
    This function saves each dataset in its own tsv file, together with a file listing the sample IDs.

    """
    # Save index file
    index = pd.DataFrame({"ID": list(np.arange(1, n_samples + 1))})
    index.to_csv(outpath / f"random.{PROJECT_NAME}.ids.txt", index=False, header=False)
    # Save continuous files
    df = pd.DataFrame(
        dataset, columns=all_feature_names, index=list(np.arange(1, n_samples + 1))
    )
    cum_feat = 0
    for key in settings.keys():
        df_feat = settings[key]["features"]
        # Copy the slice so that the ID column is not inserted into a view of df
        df_cont = df.iloc[:, cum_feat : cum_feat + df_feat].copy()
        df_cont.insert(0, "ID", np.arange(1, n_samples + 1))
        df_cont.to_csv(
            outpath / f"random.{PROJECT_NAME}.{key}.tsv", sep="\t", index=False
        )
        cum_feat += df_feat

**Hyperparameters**

We will now set the hyperparameters used to create and store the datasets. The values inside `SETTINGS` only define the mean profile of the features in each dataset.
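
As a worked example of what these values do, the mean of feature $f$ (with $f = 0, \dots, N_{features}-1$) in a given dataset follows from the loop in `create_mean_profiles` above. The symbols $c_m$, $\omega_m$ and $\varphi$ are our shorthand for the entries of `coefficients`, `frequencies` and `phase`:

$$\mu_f = \text{offset} + \sum_{m} c_m \left[\sin(\omega_m f + \varphi) + 1\right]$$

For the first feature of `Continuous_A` with the settings used in this tutorial ($f = 0$, $\varphi = 0$, offset $= 500$, coefficients $500$, $100$ and $50$), $\sin(0) = 0$ and therefore:

$$\mu_0 = 500 + 500 + 100 + 50 = 1150$$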

In [None]:
########################### Hyperparameters ####################################
PROJECT_NAME = "random_all_sim"
HIGH_CORR = True # Add extra correlations as linear combinations of existing features
SEED_1 = 1234 # Keeping the seed for reproducibility
np.random.seed(SEED_1) # Setting seed values
rnd.seed(SEED_1) # Setting seed values
COV_ALPHA = 0.97 # Alpha: hyperparam governing the sparsity of the covariance matrix
N_SAMPLES = 1000 # LOW N: 50, High N: 1000

# Settings: Dictionary with names and values to create feature profiles.
SETTINGS = {
    "Continuous_A": {
        "features": 10,
        "frequencies": [0.002, 0.01, 0.02],
        "coefficients": [500, 100, 50],
        "phase": 0,
        "offset": 500,
    },
    "Continuous_B": {
        "features": 10,
        "frequencies": [0.001, 0.05, 0.08],
        "coefficients": [80, 20, 10],
        "phase": np.pi / 2,
        "offset": 400,
    },
    "Categorical_A": {
        "features": 1,
        "frequencies": [0.1, 0.5, 0.8],
        "coefficients": [.2, .1, .05],
        "phase": np.pi / 2,
        "offset": 10,
    },
    "Categorical_B": {
        "features": 1,
        "frequencies": [0.01, 0.5, 0.08],
        "coefficients": [10, .1, .05],
        "phase": np.pi,
        "offset": 1,
    }
}

COR_THRES = 0.02 # Correlation threshold above which a pair of features is considered to be associated
PAIRS_OF_INTEREST = [(1,2),(3,4)]


# Path to store output files
outpath = Path("./synthetic_data_II")
outpath.mkdir(exist_ok=True, parents=True)

**Script**

We will now create the different datasets and save them in a number of tables, as tsv files. In addition, a file with samples ids will be created.

In [None]:
###################### Script to create a synthetic dataset ####################

# Add all datasets in a single matrix:
all_feature_names = get_feature_names(SETTINGS)
feat_means = create_mean_profiles(SETTINGS)

# Covariance matrix
covariance_matrix = make_sparse_spd_matrix(dim=len(all_feature_names), alpha=COV_ALPHA, norm_diag=True,random_state=SEED_1)
ABS_MAX = np.max(abs(covariance_matrix))

# No scaling in the dataset creation! It will be handled in preprocessing.
scaled_dataset = np.random.multivariate_normal(feat_means, covariance_matrix, N_SAMPLES)

# Add extra correlations as linear combinations of existing features
if HIGH_CORR: # The last half of the features are combinations of the first half:
    for i in range(scaled_dataset.shape[1]//2):
        col_1 = np.random.choice(range(scaled_dataset.shape[1]//2))
        col_2 = np.random.choice(range(scaled_dataset.shape[1]//2))
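        # Note: np.random.normal() returns a single scalar, so every sample of this feature gets the same offset
        scaled_dataset[:,i+scaled_dataset.shape[1]//2] = (scaled_dataset[:,col_1]+scaled_dataset[:,col_2])/2 + np.random.normal()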

# Binarize the categorical dataset
NUM_CAT = SETTINGS["Categorical_A"]["features"] + SETTINGS["Categorical_B"]["features"]
columns_to_binarize = scaled_dataset[:,-NUM_CAT:]

# Compute the mean of each of the categorical columns
means = columns_to_binarize.mean(axis=0)

# Apply the binarization
scaled_dataset[:,-NUM_CAT:] = (columns_to_binarize > means).astype(int)

# Plot correlations matrix
correlations = np.corrcoef(scaled_dataset, rowvar=False)
fig = plot_score_matrix(correlations, all_feature_names, vmin=-1, vmax=1, label_step=5)
plt.title("Correlations between variables")
fig.savefig(outpath / f"Correlations_{PROJECT_NAME}.png", dpi=200)

# Sort correlations by absolute value
associations = create_ground_truth_correlations_file(correlations, COR_THRES)
associations.to_csv(outpath / f"changes.{PROJECT_NAME}.txt", sep="\t", index=False)

# Write tsv files with feature values for all samples in both datasets:
save_splitted_datasets(
    SETTINGS, PROJECT_NAME, scaled_dataset, all_feature_names, N_SAMPLES, outpath
)

You can have a look at the created datasets in the following folder (relative to the working directory, which is `/content` on Colab):

```
./synthetic_data_II/
```
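
To check the files programmatically, the sketch below reads the tables back with pandas and verifies that every table lists the same sample IDs as the ID file. The helper `load_modalities` is a hypothetical name introduced here for illustration; the file name patterns are the ones used by `save_splitted_datasets` above.

```python
from pathlib import Path

import pandas as pd


def load_modalities(folder, project_name, keys):
    """Read the per-dataset tsv files and check their IDs against the ID file.

    Hypothetical helper for this tutorial, not part of MOVE.
    """
    folder = Path(folder)
    # The ID file has one ID per line and no header
    ids = pd.read_csv(folder / f"random.{project_name}.ids.txt", header=None)[0]
    tables = {}
    for key in keys:
        table = pd.read_csv(folder / f"random.{project_name}.{key}.tsv", sep="\t")
        if list(table["ID"]) != list(ids):
            raise ValueError(f"Sample IDs of '{key}' do not match the ID file")
        tables[key] = table.set_index("ID")
    return tables


data_dir = Path("./synthetic_data_II")
if data_dir.exists():  # only runs once the dataset has been created
    tables = load_modalities(
        data_dir,
        "random_all_sim",
        ["Continuous_A", "Continuous_B", "Categorical_A", "Categorical_B"],
    )
    for name, table in tables.items():
        print(name, table.shape)
```

With the settings of this tutorial, one would expect 1000 rows per table, 10 columns for each continuous dataset and 1 column for each categorical dataset.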


## Running MOVE

When running MOVE, we read a number of hyperparameters from the configuration (.yaml) files. We will create these configuration files now.

> **For the advanced reader:**
>
>Feel free to modify the hyperparameters! In particular, you can play with the strength of the KLD term with respect to the prior ($\beta$) or the weight assigned to the loss of each dataset (the `weight` parameter in the data yaml config).
> The choice of hyperparameters shapes the behavior of the models and therefore the results, as we show in the manuscript!
>


### Create config files:



In [None]:
config_paths = [Path("./config/data/"),
                Path("./config/task/")]

for config_path in config_paths:
  config_path.mkdir(parents=True, exist_ok=True)

data_yaml_contents = {
"random_continuous_paper_II.yaml": """
# DO NOT EDIT DEFAULTS
defaults:
  - base_data

# FEEL FREE TO EDIT BELOW

raw_data_path: synthetic_data_II/              # where raw data is stored
interim_data_path: interim_data_cont_paper_II/  # where intermediate files will be stored
results_path: results_cont_paper_II/     # where result files will be placed

sample_names: random.random_all_sim.ids  # names/IDs of each sample, must appear in the
                                # other datasets

categorical_inputs:
  - name: random.random_all_sim.Categorical_A
  - name: random.random_all_sim.Categorical_B

continuous_inputs:   # a list of continuous datasets
  - name: random.random_all_sim.Continuous_A
    weight: 1
    log2: true
    scale: true
  - name: random.random_all_sim.Continuous_B
    weight: 1
    log2: true
    scale: true
"""}

task_yaml_contents = {"random_continuous_paper_II__latent.yaml": """

defaults:
  - analyze_latent

batch_size: 10

feature_names:
  - Continuous_A_1
  - Continuous_B_1
  - Continuous_B_2
  - Continuous_B_3
  - Continuous_B_4
  - Continuous_B_5
  - Continuous_B_6
  - Continuous_B_7
  - Continuous_B_8
  - Categorical_A_1

model:
  num_hidden:
    - 15
  num_latent: 3
  beta: .0001

training_loop:
  lr: 1e-4
  num_epochs: 300
  batch_dilation_steps:
    - 100
    - 200
  kld_warmup_steps:
    - 50
    - 100
    - 125
    - 150
    - 175
    - 200
    - 225
    - 250
    - 275
  early_stopping: false
  patience: 0
""",
"random_continuous_paper_II__id_assoc_ks.yaml": """
defaults:
  - identify_associations_ks_schema

batch_size: 10

num_refits: 1
sig_threshold: 0.05

target_dataset: random.random_all_sim.Continuous_B
target_value: plus_std
save_refits: True

model:
  categorical_weights: ${weights:${data.categorical_inputs}}
  continuous_weights: ${weights:${data.continuous_inputs}}
  num_hidden:
    - 15
  num_latent: 3
  beta: .0001
  dropout: 0.1

training_loop:
  lr: 1e-4
  num_epochs: 300
  batch_dilation_steps:
    - 100
    - 200
  kld_warmup_steps:
    - 50
    - 100
    - 125
    - 150
    - 175
    - 200
    - 225
    - 250
    - 275
  early_stopping: false
  patience: 0

perturbed_feature_names:
  - Continuous_B_1
target_feature_names:
  - Continuous_B_1
  - Continuous_B_2
  - Continuous_B_3
  - Continuous_B_4
  - Continuous_B_5
  - Continuous_B_6
  - Continuous_B_7
  - Continuous_B_8
  - Continuous_B_9
  - Continuous_B_10 """}


for file, content in data_yaml_contents.items():
    with open(config_paths[0] / file, 'w') as f:
        f.write(content)


for file, content in task_yaml_contents.items():
    with open(config_paths[1] / file, 'w') as f:
        f.write(content)

You can have a look at the global config file in the following folder (paths are relative to the working directory, which is `/content` on Colab):

```
./config/data/
```

And at the task-specific configuration files in:
```
./config/task/
```


### Performing the different tasks:

Now we will:
- Encode the data in a MOVE-friendly manner.
- Train MOVE models according to the settings predefined in the configuration files.
- Visualize a 2D representation of the latent space (UMAP).
- Perform feature importance analyses with SHAP.
- Perturb the inputs to identify associated variables in the output.

We will do it in command-line style, all at once, and we will discuss the results afterwards.


> ⏰ This step will take some minutes. Coffee break!

In [None]:
### Running MOVE on simple synthetic data

# Encode data
! move-dl task=encode_data data=random_continuous_paper_II

# Latent space analysis
! move-dl task=random_continuous_paper_II__latent data=random_continuous_paper_II task.model.cuda={CUDA}

# Identify assoc ks
! move-dl task=random_continuous_paper_II__id_assoc_ks data=random_continuous_paper_II task.model.cuda={CUDA}

[INFO  - encode_data]: Beginning task: encode data
[INFO  - encode_data]: Encoding 'random.random_all_sim.Categorical_A'
[INFO  - encode_data]: Encoding 'random.random_all_sim.Categorical_B'
[INFO  - encode_data]: Encoding 'random.random_all_sim.Continuous_A'
[INFO  - encode_data]: Encoding 'random.random_all_sim.Continuous_B'
[INFO  - analyze_latent]: Beginning task: analyze latent space
[INFO  - analyze_latent]: Generating visualizations
[INFO  - analyze_latent]: Projecting into latent space
[INFO  - analyze_latent]: Reconstructing
[INFO  - analyze_latent]: Computing reconstruction metrics
[INFO  - analyze_latent]: Computing feature importance
[INFO  - identify_associations]: Perturbing dataset: 'random.random_all_sim.Continuous_B'
[INFO  - identify_associations]: Beginning task: identify associations continuous (ks)
[INFO  - identify_associations]: Perturbation type: plus_std
[INFO  - identify_associations]: Training models
[INFO  - identify_associations]: Suggested absolute KS thre

## Analyzing the results

### Latent space analysis

In [None]:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / 'reconstruction_metrics.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()

In [None]:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / 'loss_curve.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()

**Loss curves:**

The overall loss is a combination of cross-entropy for the categorical variables, the sum of squared errors (SSE) for the continuous variables, and a KLD term that pushes the distribution of samples in latent space towards our prior, a standard normal distribution.

Note:
1. We are in a low-regularization regime ($\beta = 0.0001$ in the config files), so the KLD term has practically no influence on how samples are distributed in latent space.
2. The model benefits a lot from properly reconstructing the categorical variables, seen as a cross-entropy loss going to zero.
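
To make the three terms of the loss concrete, here is a minimal NumPy sketch of a $\beta$-weighted objective of this kind. The function `vae_loss`, its arguments and the toy arrays are assumptions made for illustration. MOVE's actual loss may normalize or weight the terms differently (for instance through the per-dataset `weight` parameter).

```python
import numpy as np


def vae_loss(cat_true, cat_logits, con_true, con_recon, mu, logvar, beta=1e-4):
    """Cross-entropy + SSE + beta * KLD, averaged over samples (illustrative only)."""
    # Categorical part: softmax cross-entropy, cat_true is one-hot (n_samples, n_classes)
    shifted = cat_logits - cat_logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    cross_entropy = -(cat_true * log_probs).sum(axis=1)
    # Continuous part: sum of squared errors per sample
    sse = ((con_true - con_recon) ** 2).sum(axis=1)
    # KL divergence between N(mu, diag(exp(logvar))) and the standard normal prior
    kld = -0.5 * (1 + logvar - mu**2 - np.exp(logvar)).sum(axis=1)
    total = cross_entropy + sse + beta * kld
    return total.mean(), cross_entropy.mean(), sse.mean(), kld.mean()


rng = np.random.default_rng(0)
n = 8
cat_true = np.eye(2)[rng.integers(0, 2, size=n)]  # one-hot labels, 2 classes
cat_logits = rng.normal(size=(n, 2))
con_true = rng.normal(size=(n, 5))
con_recon = con_true + 0.1 * rng.normal(size=(n, 5))
mu = rng.normal(size=(n, 3))  # 3 latent dimensions, as in this tutorial
logvar = np.zeros((n, 3))

total, ce, sse, kld = vae_loss(cat_true, cat_logits, con_true, con_recon, mu, logvar)
print(f"total={total:.4f}  CE={ce:.4f}  SSE={sse:.4f}  KLD={kld:.4f}")
```

With `beta=1e-4`, the KLD contributes very little to the total unless it becomes several orders of magnitude larger than the reconstruction terms, which is the low-regularization regime mentioned in the first note.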

In [None]:
categorical_variable = "Categorical_A_1"
continuous_variables = ["Continuous_B_6", "Continuous_B_8"]

# Categorical variable:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / f'latent_space_{categorical_variable}.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()

# Continuous variable I:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / f'latent_space_{continuous_variables[0]}.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()

# Continuous variable II:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / f'latent_space_{continuous_variables[1]}.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()


**Latent space visualization:**

This is a UMAP representation of our 3D latent space.

*Samples are ordered in clusters according to their categorical labels*

We can already see that the network splits the latent space into clusters to perfectly classify the samples according to their categorical features. In this case, we have 2 categorical variables with two classes each, yielding 4 clusters with unique labels.

*Samples are ordered following value gradients of learned continuous variables*

Here the UMAP representations provide a hint but do not help to illustrate this idea. We'll clearly see this behavior in the next section.


In [None]:
categorical_dataset = "random.random_all_sim.Categorical_A"
continuous_dataset = "random.random_all_sim.Continuous_B"

# Categorical variable:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / f'feat_importance_{categorical_dataset}.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()

# Continuous variable:
results_latent_path = Path("./results_cont_paper_II/latent_space/")
img = mpimg.imread(results_latent_path / f'feat_importance_{continuous_dataset}.png')
imgplot = plt.imshow(img)
plt.axis('off')
plt.show()

**Feature importance analysis: SHAP**

The results obtained from the SHAP analysis are a direct consequence of MOVE's way of organizing samples in latent space. Each dot in the plot corresponds to a sample, color coded by its label for the categorical variable being shown (top plot) or by its feature value for the continuous variable being shown (bottom plot). On the x axis we see the impact on latent space, which corresponds to the displacement of each sample when setting the value of the feature of interest to 0, sample by sample.

Note that the most important features are learned, i.e. samples with similar values move similarly. In addition, samples move more when removing these features. At the bottom of the list we find variables that either varied faster, i.e. over shorter distances in latent space, or variables that the model ignored.


> **📕 Theoretical insights:**
>
>An adaptation of the SHAP algorithm was used to define the importance that the model attributed to each input feature. The input dataset $X$, a matrix with $N_{samples}$ rows and $N_{features}$ columns, served as a reference and was encoded into its latent representation $z$, with $N_{samples}$ rows and $N_{latent}$ columns. A perturbed dataset $X'$, where the feature of interest had been replaced in all samples by the missing value (i.e. 0), was encoded into its latent representation $z'$. The induced movement of the samples in latent space was then computed as the element-wise difference between both encodings, $d = z' - z$. Finally, the overall movement of each sample was obtained by adding the movements in all latent components, as a sum of all elements in each row. This protocol was repeated for each feature, one by one, and the 10 features with the largest absolute sum difference across samples were taken to be the most important ones.
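
The protocol above can be sketched in a few lines of NumPy. Everything in this block is an illustrative assumption: the function `latent_shift_importance`, the toy linear encoder standing in for a trained MOVE encoder, and the reading of "largest absolute sum difference across samples" as the sum over samples of the absolute per-sample movement. It is not the code MOVE runs internally.

```python
import numpy as np


def latent_shift_importance(encode, X, top_k=10):
    """Rank features by the latent movement induced when they are set to 0.

    Hypothetical helper for this tutorial, `encode` maps (n_samples, n_features)
    to (n_samples, n_latent).
    """
    z = encode(X)
    n_samples, n_features = X.shape
    movement = np.zeros((n_samples, n_features))
    for f in range(n_features):
        X_pert = X.copy()
        X_pert[:, f] = 0.0  # 0 stands for the missing value
        d = encode(X_pert) - z  # element-wise difference, (n_samples, n_latent)
        movement[:, f] = d.sum(axis=1)  # add the movements in all latent components
    # Assumption: feature score = sum over samples of the absolute movement
    scores = np.abs(movement).sum(axis=0)
    ranking = np.argsort(scores)[::-1][:top_k]
    return ranking, movement


rng = np.random.default_rng(0)
X_toy = rng.normal(size=(100, 4))
# Toy linear "encoder": feature 0 matters most, feature 2 is ignored
W = np.array([[3.0, 2.0, 1.0], [0.5, 0.2, 0.1], [0.0, 0.0, 0.0], [1.0, 0.5, 0.5]])


def toy_encode(X):
    return X @ W


ranking, movement = latent_shift_importance(toy_encode, X_toy)
print("Features ranked by importance:", ranking)
```

In this toy setting the ignored feature ends up at the bottom of the ranking, mirroring the variables at the bottom of the SHAP plots.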

### Visualizing the perturbations

UMAP or t-SNE will inevitably distort and compress dimensions to ease visualization. However, if we compress the inputs to 3 latent dimensions, we can visualize the exact latent space, the exact movements induced in each sample after the perturbations, and hopefully get a clearer picture of what is actually going on!

> 💵 **The price to pay**:
>
> It is not always possible to compress the inputs to only 3 dimensions. In the end it depends on many different factors, as we commented in the first section of the manuscript: the sparsity of the features, their entanglement and the feature-to-sample ratio all play a role in defining how many latent dimensions are necessary to compress and reconstruct the inputs. It is also up to us to decide how noisy the reconstructions can be for downstream tasks.

In [None]:
def plot_3D_latent_and_displacement(
    mu_baseline,
    mu_perturbed,
    feature_values,
    feature_name,
    show_baseline=True,
    show_perturbed=True,
    show_arrows=True,
    step: int=1,
    altitude: int=30,
    azimuth: int=45,
):
    """
    Plot the movement of the samples in the 3D latent space after perturbing one
    input variable.

    Args:
        mu_baseline:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample
        mu_perturbed:
            ND array with dimensions n_samples x n_latent_nodes containing
            the latent representation of each sample after perturbing the input
        feature_values:
            1D array with feature values to map to a colormap ("RdYlBu"). Each sample is
            colored according to its value for the feature of interest.
        feature_name:
            name of the feature mapped to a colormap, used as the plot title
            when show_baseline is True
        show_baseline:
            plot original location of the samples in the latent space
        show_perturbed:
            plot final location (after perturbation) of the samples in latent space
        show_arrows:
            plot arrows from original to final location of each sample
        step:
            plot every step-th sample (1 plots all samples)
        altitude:
            elevation angle (degrees) from the dim1-dim2 plane for the 3D view
        azimuth:
            azimuthal angle (degrees) of the 3D view

    Raises:
        ValueError: If latent space is not 3-dimensional (3 latent nodes).
    Returns:
        Figure
    """

    my_cmap = sns.color_palette("RdYlBu", as_cmap=True)

    eps = 1e-16
    if [np.shape(mu_baseline)[1], np.shape(mu_perturbed)[1]] != [3, 3]:
        raise ValueError(
            " The latent space must be 3-dimensional. Redefine num_latent to 3."
        )

    fig = plt.figure(layout="constrained", figsize=(6, 6))
    ax = fig.add_subplot(projection="3d")
    ax.view_init(altitude, azimuth)

    if show_baseline:
        vmin, vmax = np.min(feature_values[::step]), np.max(feature_values[::step])
        abs_max = np.max([abs(vmin), abs(vmax)])
        ax.scatter(
            mu_baseline[::step, 0],
            mu_baseline[::step, 1],
            mu_baseline[::step, 2],
            marker="o",
            c=feature_values[::step],
            s=15,
            lw=0,
            cmap=my_cmap,
        )
        ax.set_title(feature_name)

    if show_perturbed:
        ax.scatter(
            mu_perturbed[::step, 0],
            mu_perturbed[::step, 1],
            mu_perturbed[::step, 2],
            marker="o",
            c=feature_values[::step],
            s=15,
            label="perturbed",
            lw=0,
        )
    if show_arrows:
        u = mu_perturbed[::step, 0] - mu_baseline[::step, 0]
        v = mu_perturbed[::step, 1] - mu_baseline[::step, 1]
        w = mu_perturbed[::step, 2] - mu_baseline[::step, 2]

        module = np.sqrt(u * u + v * v + w * w)

        mask = module > eps

        max_u, max_v, max_w = np.max(abs(u)), np.max(abs(v)), np.max(abs(w))

        # Arrow colors will be weighted contributions of red -> dim1, green -> dim2, and blue-> dim3. I.e. purple arrow means movement in dims 1 and 3
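        # Only build colors for the arrows that are actually drawn (mask), so both lengths match
        colors = [
            (abs(du) / max_u, abs(dv) / max_v, abs(dw) / max_w, 0.7)
            for du, dv, dw in zip(u[mask], v[mask], w[mask])
        ]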
        ax.quiver(
            mu_baseline[::step, 0][mask],
            mu_baseline[::step, 1][mask],
            mu_baseline[::step, 2][mask],
            u[mask],
            v[mask],
            w[mask],
            color=colors,
            lw=.8,
            )
    ax.set_xlabel("Dim 1")
    ax.set_ylabel("Dim 2")
    ax.set_zlabel("Dim 3")


    return fig

In [None]:
! mkdir -p figures
feature_list = [("Categorical_A", 1),("Continuous_B",8)]
figure_path = Path("./figures/")
results_path = Path("./results_cont_paper_II/identify_associations/")

# Load latent space locations: Shape = (N_samples, N_latent, N_perturb +1)
latent_space_location = np.load(results_path / "latent_location.npy")
latent_space_baseline = latent_space_location[:,:,-1]

for (dataset, feature) in feature_list:
    feature_values = pd.read_csv(f"./synthetic_data_II/random.random_all_sim.{dataset}.tsv", sep="\t")
    feature_values = feature_values[dataset + '_' + str(feature)].values

    # # Plot latent space:
    pic_num = 0
    n_pictures = 50

    for azimuth, altitude in zip(
        np.linspace(0, 60, n_pictures), np.linspace(15, 60, n_pictures)
    ):

        title = dataset + '_' + str(feature)

        fig = plot_3D_latent_and_displacement(
            latent_space_baseline,
            latent_space_baseline,
            feature_values=feature_values,
            feature_name=f"Sample movement",
            show_baseline=True,
            show_perturbed=False,
            show_arrows=False,
            step=1,
            altitude=altitude,
            azimuth=azimuth,
        )

        fig.savefig(figure_path / f"3D_latent_movement_{pic_num}_perturbed_feature.png", dpi=50)
        plt.close(fig)

        if "Continuous" in dataset:
            latent_space_perturbed = latent_space_location[:,:,feature-1]
            fig = plot_3D_latent_and_displacement(
                latent_space_baseline,
                latent_space_perturbed,
                feature_values=feature_values,
                feature_name=f"{title}",
                show_baseline=False,
                show_perturbed=False,
                show_arrows=True,
                altitude=altitude,
                azimuth=azimuth,
            )
            fig.savefig(figure_path / f"3D_latent_movement_{pic_num}_arrows.png", dpi=50)
            plt.close(fig)

        pic_num += 1


    # Creating gifs
    plot_types = ["arrows", "perturbed_feature"] if "Continuous" in dataset else ["perturbed_feature"]
    for plot_type in plot_types:
        frames = [
            Image.open(figure_path / f"3D_latent_movement_{pic_num}_{plot_type}.png")
            for pic_num in range(n_pictures)
        ]  # sorted(glob.glob("*3D_latent*"))]
        frames[0].save(
            figure_path / f"{plot_type}_{title}.gif",
            format="GIF",
            append_images=frames[1:],
            save_all=True,
            duration=75,
            loop=0,
        )

**Visualize gifs:**



We started with multimodal samples, for which we simulated the measurement of 20 continuous features and 2 categorical features. Then we used MOVE to compress the representations of these samples to 3 dimensions in the network's latent layer. We will now visualize the real, complete 3D latent space.

Here, each input sample is represented by a small sphere. Each sample's latent representation $z$ is a vector with three components, where each component corresponds to the value of a latent node.

For example, sample 1 might have a latent vector $z = (\text{Value of node 1},\text{Value of node 2}, \text{Value of node 3}) = (0.5,-5,-1)$. This sample would then be plotted at 0.5 for dimension 1, -5 for dimension 2, etc.


We will plot:

1) The latent space where each sample is color coded by the categorical label of Categorical_A_1, which has two possible classes.

2) The latent space where each sample is color coded by the value of the continuous feature Continuous_B_8.

3) The movement of each sample when adding a small perturbation (1 std) to the original value of Continuous_B_8

In [None]:
# GIF files created in the previous cell
gif_list = ['perturbed_feature_Categorical_A_1.gif', 'perturbed_feature_Continuous_B_8.gif', 'arrows_Continuous_B_8.gif']

# Display the GIFs
for gif_name in gif_list:
    display(Image_display(filename="figures/" + gif_name))

Here we can clearly see that:

1. MOVE orders samples in clusters according to their categorical labels.
2. MOVE orders samples in latent space following value gradients of learned continuous features.
3. When performing perturbations on learned continuous variables, the latent representations of the samples move following the local value gradient of the perturbed feature. Of note, they do not necessarily all move in the same direction.
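
The third observation can also be quantified. The sketch below computes, for every sample, the length of its displacement and the cosine similarity between its displacement and the average displacement. The helper `displacement_summary` and the toy arrays are assumptions made for illustration. With the arrays loaded earlier in this notebook one could call it, for example, as `displacement_summary(latent_space_baseline, latent_space_location[:, :, 7])`.

```python
import numpy as np


def displacement_summary(mu_baseline, mu_perturbed, eps=1e-16):
    """Per-sample displacement length and alignment with the mean displacement.

    Hypothetical helper for this tutorial, not part of MOVE.
    """
    d = mu_perturbed - mu_baseline  # (n_samples, n_latent)
    length = np.linalg.norm(d, axis=1)
    mean_dir = d.mean(axis=0)
    mean_norm = np.linalg.norm(mean_dir)
    cosine = np.full(len(d), np.nan)  # NaN for samples that did not move
    moved = length > eps
    if mean_norm > eps:
        cosine[moved] = d[moved] @ mean_dir / (length[moved] * mean_norm)
    return length, cosine


rng = np.random.default_rng(0)
baseline = rng.normal(size=(100, 3))
# Toy perturbation: samples move along dim 1 or dim 2 depending on where they sit
shift = np.where(baseline[:, [0]] > 0, [0.1, 0.0, 0.0], [0.0, 0.1, 0.0])
perturbed = baseline + shift

length, cosine = displacement_summary(baseline, perturbed)
print("mean length:", length.mean().round(4))
print("cosine range:", cosine.min().round(3), cosine.max().round(3))
```

Cosine values clearly below 1 indicate that samples move in different directions, which is what one would expect if each sample follows the local gradient of the perturbed feature rather than a single global direction.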

# Final remarks: THE END!

Thanks for going through the tutorial! We hope you found it useful.


