In [None]:
from functools import reduce
from pathlib import Path
from typing import Literal, Sized, Union, cast

import hydra
import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf

from move.conf.schema import (
    IdentifyAssociationsBayesConfig,
    IdentifyAssociationsConfig,
    IdentifyAssociationsTTestConfig,
    MOVEConfig,
)
from move.core.logging import get_logger
from move.core.typing import FloatArray, IntArray
from move.data import io
from move.data.dataloaders import MOVEDataset, make_dataloader
from move.data.perturbations import perturb_categorical_data
from move.data.preprocessing import one_hot_encode_single
from move.models.vae import VAE

In [None]:
pwd

'/Users/wkq953/Desktop/CPR/4. codes/move/tutorial/notebooks'

In [None]:
cd .. 

/Users/wkq953/Desktop/CPR/4. codes/move/tutorial


## 1. Initialize MOVE with the test dataset

### Encoding data

In [None]:
# Read the "encode_data" task configuration for the test dataset and run the
# encoding step on its data section.
from move.data import io
from move.tasks import encode_data
import numpy as np
# Alternative dataset configurations (uncomment exactly one to switch):
#config = io.read_config("random_small", "encode_data")
#config = io.read_config("random_continuous", "encode_data")
#config = io.read_config("random_catagorical", "encode_data")
config = io.read_config("random_test", "encode_data")

# NOTE(review): presumably writes the encoded arrays under
# config.data.interim_data_path, which the next cell reads — confirm.
encode_data(config.data)

In [None]:
from pathlib import Path

# Load the preprocessed datasets listed in the config and print the shape of
# each one (categorical datasets first, then continuous ones).
path = Path(config.data.interim_data_path)
cat_datasets, cat_names, con_datasets, con_names = io.load_preprocessed_data(
    path,
    config.data.categorical_names,
    config.data.continuous_names,
)

# Dataset names are taken from the config in the same categorical-then-
# continuous order used to concatenate the dataset lists below.
dataset_names = config.data.categorical_names + config.data.continuous_names
for dataset, dataset_name in zip(cat_datasets + con_datasets, dataset_names):
    print(f"{dataset_name}: {dataset.shape}")

random.test.drugs: (500, 5, 2)
random.test.metagenomics: (500, 20, 2)
random.test.proteomics: (500, 20)


### Tune model (optional: no need to run this section currently)

In [None]:
# Another way to initialize the model: read a tuning task configuration.
# NOTE(review): the first assignment is immediately overwritten by the second,
# so only the "tune_model_stability" config is kept and displayed — confirm
# whether the reconstruction config was meant to be inspected separately.
config = io.read_config("random_test", "tune_model_reconstruction")
config = io.read_config("random_test", "tune_model_stability")
config

{'seed': None, 'data': {'raw_data_path': 'data/test', 'interim_data_path': 'test_data/interim_data/', 'results_path': 'test_data/results/', 'sample_names': 'random.test.ids', 'categorical_inputs': [{'name': 'random.test.drugs'}, {'name': 'random.test.metagenomics'}], 'continuous_inputs': [{'name': 'random.test.proteomics'}], 'categorical_names': '${names:${data.categorical_inputs}}', 'continuous_names': '${names:${data.continuous_inputs}}', 'categorical_weights': '${weights:${data.categorical_inputs}}', 'continuous_weights': '${weights:${data.continuous_inputs}}'}, 'task': {'batch_size': 10, 'model': {'_target_': 'move.models.vae.VAE', 'cuda': False, 'categorical_weights': '${weights:${data.categorical_inputs}}', 'continuous_weights': '${weights:${data.continuous_inputs}}', 'num_hidden': [1000], 'num_latent': 150, 'beta': 0.0001, 'dropout': 0.1}, 'training_loop': {'_target_': 'move.training.training_loop.training_loop', 'num_epochs': 40, 'lr': 0.0001, 'kld_warmup_steps': [15, 20, 25], 

In [None]:
# Run the two tuning experiments through the move-dl command line tool.
# The section heading marks this step as optional for now.
! move-dl experiment=random_test__tune_reconstruction
! move-dl experiment=random_test__tune_stability

In [None]:
# Load the reconstruction statistics table (tab-separated) from the tuning
# results folder.
# NOTE(review): presumably this file is produced by the tuning run in the
# previous cell — confirm; the cell fails if the file does not exist.
from itertools import chain
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
results = pd.read_csv("test_data/results/tune_model/reconstruction_stats.tsv", sep="\t")

### Latent space analysis

In [None]:
# Read the latent-space analysis configuration for the test dataset and run
# the analyze_latent task on it.
from move.tasks import analyze_latent
# Alternative dataset configurations (uncomment exactly one to switch):
#config = io.read_config("random_catagorical", "random_catagorical__latent")
config = io.read_config("random_test", "random_test__latent")
#config = io.read_config("random_continuous", "random_continuous__latent")
#config = io.read_config("random_small", "random_small__latent")
# Uncomment to print the fully resolved configuration as YAML:
#print(OmegaConf.to_yaml(config, resolve=True))

analyze_latent(config)

[INFO  - analyze_latent]: Beginning task: analyze latent space
[INFO  - analyze_latent]: Projecting into latent space
[INFO  - analyze_latent]: Reconstructing
[INFO  - analyze_latent]: Computing reconstruction metrics
[INFO  - analyze_latent]: Computing feature importance


## 2. Association test

In [None]:
# Read the Bayes association-test configuration for the test dataset.
# The task itself is not run here; the calls below are left commented out.
from move.tasks import identify_associations
# config = io.read_config("random_small", "random_small__id_assoc_bayes")
config = io.read_config("random_test", "random_test__id_assoc_bayes")
#identify_associations(config)
# NOTE(review): identify_associations_cat is not imported anywhere in this
# notebook — confirm where it is defined before uncommenting.
#identify_associations_cat(config)

In [None]:
config

{'seed': None, 'data': {'raw_data_path': 'data/', 'interim_data_path': 'interim_data_test/', 'results_path': 'results_test/', 'sample_names': 'random.test.ids', 'categorical_inputs': [{'name': 'random.test.drugs'}, {'name': 'random.test.metagenomics'}], 'continuous_inputs': [{'name': 'random.test.proteomics'}], 'categorical_names': '${names:${data.categorical_inputs}}', 'continuous_names': '${names:${data.continuous_inputs}}', 'categorical_weights': '${weights:${data.categorical_inputs}}', 'continuous_weights': '${weights:${data.continuous_inputs}}'}, 'task': {'batch_size': 10, 'model': {'_target_': 'move.models.vae.VAE', 'cuda': False, 'categorical_weights': '${weights:${data.categorical_inputs}}', 'continuous_weights': '${weights:${data.continuous_inputs}}', 'num_hidden': [100], 'num_latent': 10, 'beta': 0.0001, 'dropout': 0.1}, 'training_loop': {'_target_': 'move.training.training_loop.training_loop', 'num_epochs': 40, 'lr': 0.0001, 'kld_warmup_steps': [15, 20, 25], 'batch_dilation_

## 3. Deep dive into the function

In [None]:
# Build the baseline dataloader and the perturbed dataloaders by hand,
# mirroring the first steps of the association task.
# This cell uses `np` without importing it, so it relies on the numpy import
# made in an earlier cell.
from pathlib import Path
import torch
import hydra
from move.data import io
from move.data.dataloaders import make_dataloader
from move.data.perturbations import perturb_categorical_data

cfg = io.read_config("random_test", "random_test__id_assoc_bayes")

# Load the preprocessed categorical and continuous datasets named in the config.
cat_list, cat_names, con_list, con_names = io.load_preprocessed_data(
    Path(cfg.data.interim_data_path),
    cfg.data.categorical_names,
    cfg.data.continuous_names,
)

# Unshuffled dataloader over the original (unperturbed) data.
baseline_dataloader = make_dataloader(
    cat_list,
    con_list,
    shuffle=False,
    batch_size=cfg.task.batch_size,
)
baseline_dataset = baseline_dataloader.dataset

# Perturb the "random.test.drugs" dataset towards the one-hot value [0, 1].
# NOTE(review): presumably returns one dataloader per feature of the target
# dataset — confirm against perturb_categorical_data.
dataloaders = perturb_categorical_data(
    baseline_dataloader,
    cfg.data.categorical_names,
    "random.test.drugs",
    np.array([0, 1]),
)

In [None]:
# Instantiate the model described by cfg.task.model, using the input shapes
# of the baseline dataset.
model = hydra.utils.instantiate(
    cfg.task.model,
    continuous_shapes=baseline_dataset.con_shapes,
    categorical_shapes=baseline_dataset.cat_shapes,
)
# Requires a pre-trained model: the weights are read from a hard-coded file
# path, so this cell fails if that file has not been created beforehand.
# NOTE(review): presumably written by the latent space analysis task — confirm.
model.load_state_dict(torch.load("results_test/latent_space/model.pt"))
# Switch to evaluation mode; as the last expression this also displays the model.
model.eval()

VAE (70 ⇄ 12 ⇄ 8)

### Question part

1. How to get the NA mask from the original categorical dataset, since the drug and meta datasets are combined together?

In [None]:
# All categorical data of the baseline dataset in one 2D tensor:
# N x Cat = (500, 50), i.e. 5*2 drug columns + 20*2 meta columns.
orig_cat = baseline_dataset.cat_all # 2D N x Cat (500,50) 5*2+20*2
# Earlier attempt at removing the drug data, kept for reference:
#orig_cat = orig_cat[-1] # remove drug data
# Display the first sample's row.
orig_cat[0]


tensor([0., 1., 1., 0., 0., 1., 1., 0., 1., 0., 0., 1., 1., 0., 1., 0., 1., 0.,
        0., 1., 0., 1., 1., 0., 1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1.,
        1., 0., 1., 0., 0., 1., 0., 1., 0., 1., 0., 1., 1., 0.])

2. How to deal with the number of refits?

probability / num_refit

3. cat shape

## Test function

In [68]:
__all__ = ["identify_associations"]

from functools import reduce
from os.path import exists
from pathlib import Path
from typing import Literal, Sized, Union, cast

import hydra
import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf
from scipy.stats import ks_2samp, pearsonr  # type: ignore
from torch.utils.data import DataLoader

from move.analysis.metrics import get_2nd_order_polynomial

from move.conf.schema import (
    IdentifyAssociationsBayesConfig,
    IdentifyAssociationsBayesCatConfig,
    IdentifyAssociationsConfig,
    IdentifyAssociationsKSConfig,
    IdentifyAssociationsTTestConfig,
    MOVEConfig,
)
from move.core.logging import get_logger
from move.core.typing import BoolArray, FloatArray, IntArray
from move.data import io
from move.data.dataloaders import MOVEDataset, make_dataloader
from move.data.perturbations import (
    ContinuousPerturbationType,
    perturb_categorical_data,
    perturb_continuous_data_extended,
)
from move.data.preprocessing import one_hot_encode_single
from move.models.vae import VAE
from move.visualization.dataset_distributions import (
    plot_correlations,
    plot_cumulative_distributions,
    plot_feature_association_graph,
    plot_reconstruction_movement,
)

# Names of the supported association-test approaches.
TaskType = Literal["bayes", "bayes_cat","ttest", "ks"]
# NOTE(review): presumably the accepted target values for continuous
# perturbations; not referenced in the code visible here — confirm usage.
CONTINUOUS_TARGET_VALUE = ["minimum", "maximum", "plus_std", "minus_std"]

def prepare_for_categorical_perturbation(
    config: MOVEConfig,
    interim_path: Path,
    baseline_dataloader: DataLoader,
    cat_list: list[FloatArray],
) -> tuple[list[DataLoader], BoolArray, BoolArray, BoolArray]:
    """
    Create the dataloaders and masks required for categorical association
    analysis.

    Args:
        config: main configuration file
        interim_path: path where the intermediate outputs are saved; the
            file "mappings.json" is read from it
        baseline_dataloader: reference dataloader that will be perturbed
        cat_list: list of arrays with categorical data, in the same order as
            ``config.data.categorical_names``

    Returns:
        dataloaders: all perturbed dataloaders, with the baseline appended last.
        nan_mask: mask for NaNs in the continuous data (values equal to 0).
        feature_mask: per sample and target feature, True where the sample
            already holds the target value or has no value (all-zero encoding).
        nan_mask_cat: mask for NaNs (all-zero encoding) in every categorical
            dataset except the target one, stacked horizontally. Has zero
            columns when the target is the only categorical dataset.
    """

    # Read original data and create perturbed datasets
    task_config = cast(IdentifyAssociationsConfig, config.task)
    logger = get_logger(__name__)

    # Loading mappings:
    mappings = io.load_mappings(interim_path / "mappings.json")
    target_mapping = mappings[task_config.target_dataset]
    target_value = one_hot_encode_single(target_mapping, task_config.target_value)
    logger.debug(
        f"Target value: {task_config.target_value} => {target_value.astype(int)[0]}"
    )

    dataloaders = perturb_categorical_data(
        baseline_dataloader,
        config.data.categorical_names,
        task_config.target_dataset,
        target_value,
    )
    dataloaders.append(baseline_dataloader)

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    assert baseline_dataset.con_all is not None
    orig_con = baseline_dataset.con_all
    nan_mask = (orig_con == 0).numpy()  # NaN values encoded as 0s
    logger.debug(f"# NaN values: {np.sum(nan_mask)}/{orig_con.numel()}")

    target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset)
    target_dataset = cat_list[target_dataset_idx]
    feature_mask = np.all(target_dataset == target_value, axis=2)  # 2D: N x P
    feature_mask |= np.sum(target_dataset, axis=2) == 0

    # Missing values in the remaining categorical datasets are encoded as an
    # all-zero vector. Comparing against 0 (instead of a literal [0, 0]) keeps
    # this correct for datasets with any number of classes.
    other_datasets = [
        dataset
        for dataset_idx, dataset in enumerate(cat_list)
        if dataset_idx != target_dataset_idx
    ]
    if other_datasets:
        nan_mask_cat = np.hstack(
            [np.all(dataset == 0, axis=2) for dataset in other_datasets]
        )
    else:
        # np.hstack raises on an empty list; return a mask with no columns.
        nan_mask_cat = np.zeros((target_dataset.shape[0], 0), dtype=bool)

    return (
        dataloaders,
        nan_mask,
        feature_mask,
        nan_mask_cat
    )

def _bayes_cat_approach(
    config: MOVEConfig,
    task_config: IdentifyAssociationsBayesCatConfig,
    train_dataloader: DataLoader,
    baseline_dataloader: DataLoader,
    dataloaders: list[DataLoader],
    models_path: Path,
    num_perturbed: int,
    num_samples: int,
    num_catagorical: int,
    nan_mask_cat: BoolArray,
    feature_mask: BoolArray,
) -> tuple[Union[IntArray, FloatArray], ...]:
    """
    Identify associations between perturbed categorical features and the
    remaining categorical features with a Bayes-factor approach.

    For every refit, a model is trained (or re-loaded from ``models_path``),
    the baseline and each perturbed dataloader are reconstructed, and the
    fraction of unmasked samples whose reconstructed category differs from
    the baseline is computed per feature. These fractions are averaged over
    all refits and turned into Bayes factors, probabilities and a cumulative
    FDR.

    Args:
        config: main configuration; ``config.data.categorical_names`` is used
            to locate the target dataset
        task_config: task configuration (model, training loop, number of
            refits, significance threshold, target dataset)
        train_dataloader: dataloader used to train each refit
        baseline_dataloader: dataloader with the unperturbed data
        dataloaders: perturbed dataloaders; the first ``num_perturbed``
            entries are used
        models_path: directory where refits are loaded from / saved to
        num_perturbed: number of perturbed features
        num_samples: unused; kept for interface compatibility
        num_catagorical: total number of categorical features outside the
            target dataset
        nan_mask_cat: mask of missing values in the non-target categorical
            data, 2D (samples x ``num_catagorical``)
        feature_mask: mask of samples to ignore for each perturbed feature,
            2D (samples x ``num_perturbed``)

    Returns:
        Tuple of (flattened indices of significant associations, their
        probabilities, their FDR values, their Bayes factors), sorted by
        decreasing absolute Bayes factor.
    """

    assert task_config.model is not None
    device = torch.device("cuda" if task_config.model.cuda else "cpu")

    # Train models
    logger = get_logger(__name__)
    logger.info("Training models")
    mean_prob = np.zeros((num_perturbed, num_catagorical))
    normalizer = 1 / task_config.num_refits
    target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset)

    baseline_dataset = cast(MOVEDataset, baseline_dataloader.dataset)

    for j in range(task_config.num_refits):
        # Initialize model
        model: VAE = hydra.utils.instantiate(
            task_config.model,
            continuous_shapes=baseline_dataset.con_shapes,
            categorical_shapes=baseline_dataset.cat_shapes,
        )
        if j == 0:
            logger.debug(f"Model: {model}")

        # Train/reload model
        model_path = models_path / f"model_{task_config.model.num_latent}_{j}.pt"
        if model_path.exists():
            logger.debug(f"Re-loading refit {j + 1}/{task_config.num_refits}")
            model.load_state_dict(torch.load(model_path))
            model.to(device)
        else:
            logger.debug(f"Training refit {j + 1}/{task_config.num_refits}")
            model.to(device)
            hydra.utils.call(
                task_config.training_loop,
                model=model,
                train_dataloader=train_dataloader,
            )
            if task_config.save_refits:
                torch.save(model.state_dict(), model_path)
        model.eval()

        # Baseline reconstruction: one entry per categorical dataset. The
        # target dataset is dropped and the rest are stacked column-wise.
        baseline_recat, _ = model.reconstruct(baseline_dataloader)
        baseline_recat.pop(target_dataset_idx)
        baseline_recat = np.hstack(baseline_recat)

        # Perturbed reconstructions => accumulate the mean probability that
        # the reconstructed category changes.
        for i in range(num_perturbed):
            perturb_recat, _ = model.reconstruct(dataloaders[i])
            perturb_recat.pop(target_dataset_idx)
            perturb_recat = np.hstack(perturb_recat)
            # Ignore samples masked for perturbed feature i and missing values.
            mask_cat = feature_mask[:, [i]] | nan_mask_cat
            changed = np.ma.masked_array(
                perturb_recat != baseline_recat, mask=mask_cat
            )  # True where the category changed
            # Features with no valid sample get the neutral probability 0.5
            # (Bayes factor 0) instead of being dropped, which would break
            # the assignment below with a shape mismatch.
            prob = np.ma.filled(changed.mean(axis=0), 0.5)  # 1D: C
            # Accumulate across refits; plain assignment would keep only the
            # last refit, scaled down by the number of refits.
            mean_prob[i, :] += prob * normalizer

    # Calculate Bayes factors
    logger.info("Identifying significant features")
    bayes_k = np.log(mean_prob + 1e-8) - np.log(1 - mean_prob + 1e-8)  # 2D: N x C

    # Calculate Bayes probabilities
    # NOTE(review): taking the absolute value ranks "almost never changes"
    # as highly as "almost always changes" — confirm this is intended.
    # No self-association masking is needed: the target dataset was removed
    # from the reconstructions above.
    bayes_abs = np.abs(bayes_k)
    bayes_p = np.exp(bayes_abs) / (1 + np.exp(bayes_abs))  # 2D: N x C
    sort_ids = np.argsort(bayes_abs, axis=None)[::-1]  # 1D: N x C
    prob = np.take(bayes_p, sort_ids)  # 1D: N x C
    logger.debug(f"Bayes proba range: [{prob[-1]:.3f} {prob[0]:.3f}]")

    # Sort Bayes
    bayes_k = np.take(bayes_k, sort_ids)  # 1D: N x C

    # Calculate FDR
    fdr = np.cumsum(1 - prob) / np.arange(1, prob.size + 1)  # 1D
    idx = np.argmin(np.abs(fdr - task_config.sig_threshold))
    logger.debug(f"FDR range: [{fdr[0]:.3f} {fdr[-1]:.3f}]")

    return sort_ids[:idx], prob[:idx], fdr[:idx], bayes_k[:idx]

In [65]:
# Set up everything needed to call _bayes_cat_approach: configuration,
# output folders, datasets, dataloaders and masks.
config = io.read_config("random_test", "random_test__id_assoc_bayes")
logger = get_logger(__name__)
task_config = cast(IdentifyAssociationsConfig, config.task)
interim_path = Path(config.data.interim_data_path)

models_path = interim_path / "models"
if task_config.save_refits:
    # parents=True so a missing interim folder does not make mkdir fail,
    # matching how output_path is created below.
    models_path.mkdir(exist_ok=True, parents=True)

output_path = Path(config.data.results_path) / "identify_associations"
output_path.mkdir(exist_ok=True, parents=True)

# Load datasets:
cat_list, cat_names, con_list, con_names = io.load_preprocessed_data(
    interim_path,
    config.data.categorical_names,
    config.data.continuous_names,
)

train_dataloader = make_dataloader(
    cat_list,
    con_list,
    shuffle=True,
    batch_size=task_config.batch_size,
    drop_last=True,
)
con_shapes = [con.shape[1] for con in con_list]
target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset)

# Number of categorical features outside the target dataset. Computed
# unconditionally so num_catagorical is always defined for later cells.
cat_shapes = [
    cat.shape[1]
    for dataset_idx, cat in enumerate(cat_list)
    if dataset_idx != target_dataset_idx
]  # e.g. [5, 20] without the target => [20]
num_catagorical = sum(cat_shapes)  # C
if num_catagorical == 0:
    raise ValueError(
        "The categorical Bayes approach needs at least one categorical "
        "dataset besides the target dataset."
    )

num_samples = len(cast(Sized, train_dataloader.sampler))  # N
num_continuous = sum(con_shapes)  # C

logger.debug(f"# continuous features: {num_continuous}")

# Creating the baseline dataloader:
baseline_dataloader = make_dataloader(
    cat_list, con_list, shuffle=False, batch_size=task_config.batch_size
)

(dataloaders, nan_mask, feature_mask, nan_mask_cat) = prepare_for_categorical_perturbation(
    config, interim_path, baseline_dataloader, cat_list)

# The baseline dataloader is appended last, so it is not counted.
num_perturbed = len(dataloaders) - 1 

5

In [73]:
# Run the categorical Bayes approach with the objects prepared in the
# previous cell. The first returned array holds the indices of the
# significant associations; the remaining arrays are collected in extra_cols.
# NOTE(review): task_config was cast to IdentifyAssociationsConfig, while the
# function annotates IdentifyAssociationsBayesCatConfig — confirm the loaded
# config provides num_refits, sig_threshold and save_refits.
sig_ids, *extra_cols = _bayes_cat_approach(
    config,
    task_config,
    train_dataloader,
    baseline_dataloader,
    dataloaders,
    models_path,
    num_perturbed,
    num_samples,
    num_catagorical,
    nan_mask_cat,
    feature_mask,
)

[INFO  - __main__]: Training models
[INFO  - __main__]: Identifying significant features


### Test

In [None]:
# Build the three masks by hand for the test dataset.
# This cell overwrites feature_mask, nan_mask_cat and nan_mask from the
# driver cell, and hard-codes the dataset positions (drugs first, meta last)
# and the two-class one-hot encoding.
## 1. Drug mask
feature_mask = np.all(cat_list[0] == [0, 1], axis=2) # (500,5) mask the samples with drug value = 1
feature_mask |= np.sum(cat_list[0], axis=2) == 0 # (500,5) mask the samples with drug value = NA [0,0]
## 2. Meta mask
nan_mask_cat = np.all(cat_list[-1] == [0,0], axis = 2 )  # mask NA values encoded as [0,0]
## 3. Protein mask
# baseline_dataset is not defined in this cell; it comes from an earlier cell.
orig_con = baseline_dataset.con_all # 2D: N x C
nan_mask = (orig_con == 0).numpy()  # NaN values encoded as 0s

In [None]:
# Number of features in each categorical dataset, then drop the first entry.
# NOTE(review): assumes the first dataset is the drug (target) dataset — confirm.
cat_shapes = [cat.shape[1] for cat in cat_list]
cat_shapes = cat_shapes[1:]
cat_shapes

[20]

In [None]:
# test in one drug
drug_idx = 0 # for i in range(num_perturbed):

baseline_recat, baseline_recon = model.reconstruct(baseline_dataloader)
perturb_recat, perturb_recon = model.reconstruct(dataloaders[drug_idx])
diff_recon = (perturb_recon - baseline_recon) 

baseline_recat = baseline_recat[-1] # remove drug data 
perturb_recat = perturb_recat[-1] # remove drug data
diff_recat = (perturb_recat == baseline_recat) # (500,20)
normalizer = 1 / task_config.num_refits # 1/
mask_cat = feature_mask[:, [drug_idx]]| nan_mask_cat
diff_recat = np.ma.masked_array(diff_recat, mask = mask_cat)  # 2D: N x C  only consider the samples with drug_i from 0 to 1 
prob = np.ma.compressed(np.mean(diff_recat, axis=0))  
mean_pro = prob * normalizer
bayes_k = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8)
bayes_k

array([ 0.28271928,  0.62371866,  0.47692406,  0.64250343,  0.94024395,
        0.2135741 ,  1.64561752,  0.64250343,  0.69953696,  1.2582424 ,
        0.31753503,  1.16203409,  1.07044139,  0.66139847,  0.23080642,
        0.56798403,  0.37012573,  0.58646303, -0.30010459,  1.04818141])

In [None]:
# Accumulator with hard-coded dimensions (5, 500, 20).
# NOTE(review): presumably (perturbed drugs, samples, meta features) for the
# test dataset — confirm.
mean_diff = np.zeros((5, 500, 20))
# Add the single-drug result to the first slice.
mean_diff[0, :, :] += diff_recat

In [None]:
# Check that repeated in-place accumulation keeps the array shape.
# This cell is not idempotent: every run adds diff_recat to slice 0 twice more.
# The first mean_diff.shape is not displayed; only the last expression is.
mean_diff.shape
mean_diff[0, :, :] += diff_recat
mean_diff[0, :, :] += diff_recat

# Earlier attempt using np.append, kept for reference:
#mean_diff[0, :, :] = np.append (mean_diff[0, :, :] ,diff_recat)
mean_diff.shape


(5, 500, 20)