In [30]:
from pathlib import Path

import hydra
import numpy as np
import torch
from omegaconf import OmegaConf

from move.data import io
from move.data import dataloaders as dl
from move.data import preprocessing as pp
from move.data.perturbations import perturb_data
from move.models import VAE

In [63]:
# Overrides applied on top of the `main` config: which datasets to load and
# where, plus the t-test association task and its model/training settings.
overrides = [
    # Inputs: one categorical and two continuous datasets, each with weight 1
    "data.categorical_inputs=[{name:random.small.drugs,weight:1}]",
    "data.continuous_inputs=[{name:random.small.proteomics,weight:1},{name:random.small.metagenomics,weight:1}]",
    # Data locations (relative paths)
    "data.raw_data_path=tutorial/data",
    "data.interim_data_path=tutorial/interim_data",
    "data.processed_data_path=tutorial/results",
    # Association task: perturb `random.small.drugs` to value 1
    "task=identify_associations_ttest",
    "task.batch_size=50",
    "task.num_refits=30",
    "task.target_dataset=random.small.drugs",
    "task.target_value=1",
    # VAE architecture and training length
    "task.model.num_hidden=[1000]",
    "task.model.num_latent=150",
    "task.training_loop.num_epochs=40",
]

with hydra.initialize("src/move/conf", version_base="1.1"):
    config = hydra.compose("main", overrides=overrides)

In [67]:
# Resolve the working directories and load the inputs named in the config.
task_config = config.task

interim_path = Path(config.data.interim_data_path)
output_path = Path(config.data.processed_data_path, "identify_associations")
output_path.mkdir(parents=True, exist_ok=True)

# Lists of categorical / continuous datasets together with their names
cat_list, cat_names, con_list, con_names = io.read_data(config)

# Look up the mapping for the target dataset and one-hot encode the configured
# target value (later passed to `perturb_data`).
mappings = io.load_mappings(interim_path / "mappings.json")
target_mapping = mappings[task_config.target_dataset]
target_value = pp.one_hot_encode_single(target_mapping, task_config.target_value)

In [33]:
# Shuffled training loader over all inputs (incomplete last batch dropped).
train_mask, train_dataloader = dl.make_dataloader(
    cat_list,
    con_list,
    shuffle=True,
    batch_size=task_config.batch_size,
    drop_last=True,
)
con_shapes = [dataset.shape[1] for dataset in con_list]

# One dataloader per perturbation of the target dataset, followed by the
# unperturbed baseline as the final element.
dataloaders = perturb_data(
    cat_list,
    con_list,
    config.data.categorical_names,
    task_config.target_dataset,
    target_value,
)
baseline_dataloader = dataloaders[-1]
baseline_dataset = baseline_dataloader.dataset

# Array dimensions used throughout: F perturbations, N samples, C continuous features
num_perturbed = len(dataloaders) - 1  # F
num_samples = len(baseline_dataloader.sampler)  # N
num_continuous = sum(con_shapes)  # C

orig_con = baseline_dataset.con_all
nan_mask = (orig_con == 0).numpy()  # NaN values encoded as 0s

target_dataset_idx = config.data.categorical_names.index(task_config.target_dataset)
target_dataset = cat_list[target_dataset_idx]

# Exclude a (sample, feature) pair when the sample already equals the target
# value, or when its target-dataset encoding sums to zero (presumably missing).
already_target = np.all(target_dataset == target_value, axis=2)
no_category = np.sum(target_dataset, axis=2) == 0
feature_mask = already_target | no_category

In [46]:
# Recomputes the same exclusion mask as earlier in the notebook: True where the
# sample already has the target value or its target encoding sums to zero.
matches_target = np.all(target_dataset == target_value, axis=2)
is_empty = np.sum(target_dataset, axis=2) == 0
feature_mask = matches_target | is_empty

In [69]:
# Weight applied to each refit's contribution, so that the sum over refits is their mean.
normalizer = 1.0 / task_config.num_refits
normalizer

0.03333333333333333

In [73]:
# Rebuild the VAE with the input layout of the baseline dataset and restore the
# weights stored in "model_b0.pt" (presumably the first refit's checkpoint).
model: VAE = hydra.utils.instantiate(
    task_config.model,
    continuous_shapes=baseline_dataset.con_shapes,
    categorical_shapes=baseline_dataset.cat_shapes,
)
# Nothing here moves the model to a GPU, so load the checkpoint tensors onto
# the CPU as well: without `map_location`, a checkpoint saved from a CUDA run
# cannot be loaded on a CPU-only machine.
# NOTE: torch.load unpickles the file -- only load trusted checkpoints.
state_dict = torch.load(output_path / "model_b0.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

VAE (1240 ⇄ 1000 ⇄ 150)

In [127]:
# Placeholder for the per-perturbation scores computed in the next cell.
k = np.empty((num_perturbed, num_continuous))

# diff[f, n, c]: shift in the reconstruction of continuous feature c for
# sample n under perturbation f, averaged over refits via `normalizer`.
diff = np.zeros((num_perturbed, num_samples, num_continuous))
for refit_idx in range(task_config.num_refits):
    # NOTE(review): every refit reuses the same `model` (loaded from
    # model_b0.pt), so unless `reconstruct` is stochastic each iteration adds
    # an identical term -- TODO confirm whether one checkpoint per refit was
    # intended.
    _, baseline_recon = model.reconstruct(baseline_dataloader)
    for pert_idx in range(num_perturbed):
        _, perturb_recon = model.reconstruct(dataloaders[pert_idx])
        diff[pert_idx] += (perturb_recon - baseline_recon) * normalizer

In [130]:

# For each perturbation i, score every continuous feature by how consistently
# the perturbation pushes its reconstruction in one direction:
#   prob = fraction of retained samples whose shift is > 1e-8
#   k    = |log(prob) - log(1 - prob)|  (1e-8 keeps the logs finite at 0 and 1)
# Samples are dropped when excluded for this perturbation (feature_mask) or
# when the continuous value itself is missing (nan_mask).
# `k` is re-allocated here because a later cell rebinds it to a 1-D array.
k = np.empty((num_perturbed, num_continuous))
for i in range(num_perturbed):
    mask = feature_mask[:, [i]] | nan_mask
    delta = np.ma.masked_array(diff[i, :, :], mask=mask)
    # A feature whose samples are all masked has no evidence either way. Its
    # masked mean carries arbitrary underlying data (typically 0, which would
    # yield the largest possible k), so fill it with the neutral 0.5 -> k = 0
    # instead of reading `.data`.
    prob = np.ma.filled(np.mean(delta > 1e-8, axis=0), 0.5)
    k[i, :] = np.abs(np.log(prob + 1e-8) - np.log(1 - prob + 1e-8))



In [144]:
# Turn the absolute log-odds k into a probability-like score
# (sigmoid(k) = max(prob, 1 - prob), up to the 1e-8 smoothing). Then rank all
# (perturbation, continuous feature) entries by k and keep those whose
# estimated FDR is within 5%.
b_prob = np.exp(k) / (1 + np.exp(k))
sort_ids = np.argsort(k, axis=None)[::-1]  # flat indices into k, descending
b_prob_sorted = np.take(b_prob, sort_ids)

# Estimated FDR of the top-m list: the mean of (1 - score) over its m entries.
fdr = np.cumsum(1 - b_prob_sorted) / np.arange(1, b_prob_sorted.size + 1)
# Inclusive threshold, matching the `fdr <= 0.05` criterion used later on.
is_sig = fdr <= 0.05

# Map the flat ids back to (perturbation index, continuous feature index).
sig_perturbed_idx, sig_continuous_idx = np.unravel_index(sort_ids[is_sig], k.shape)

sort_ids[is_sig]

array([], dtype=int64)

In [148]:
# Absolute working directory; relative paths built from the config resolve against it.
Path().resolve()

WindowsPath('C:/Users/zqw270/Documents/GitHub/MOVE_fork')

In [93]:
# Aggregate view: average the reconstruction shift over all perturbations.
# Each perturbation f excludes its own masked samples (column f of
# feature_mask) plus missing continuous values. The original masked the
# all-perturbation mean with perturbation 0's mask only. An entry of `delta`
# is masked only when it is excluded for every perturbation.
per_perturbation_mask = (
    feature_mask[:, :num_perturbed].T[:, :, np.newaxis]
    | nan_mask[np.newaxis, :, :]
)  # shape (F, N, C), aligned with diff
delta = np.ma.masked_array(diff, mask=per_perturbation_mask).mean(axis=0)

In [119]:
# Fraction of retained samples whose shift exceeds 1e-8, per continuous feature.
prob_masked = np.mean(delta > 1e-8, axis=0)
# `compressed` drops fully-masked features, so positions in `prob`/`k` no
# longer equal continuous-feature indices. `kept_ids[j]` is the feature
# index of prob[j] / k[j].
kept_ids = np.flatnonzero(~np.ma.getmaskarray(prob_masked))
prob = np.ma.compressed(prob_masked)
# Signed log-odds (no abs): positive when increases dominate, negative when
# decreases dominate.
k = np.log(prob + 1e-8) - np.log(1 - prob + 1e-8)

In [120]:
# Logistic transform of k. Since k is signed here, these values can fall
# below 0.5 despite the `_abs` name. Entries are then ordered by descending k.
exp_k = np.exp(k)
prob_abs = exp_k / (1 + exp_k)
sort_ids = np.flip(np.argsort(k, axis=None))
prob_sort = prob_abs.ravel()[sort_ids]

In [123]:
# Estimated FDR of each top-m prefix, then the scores of entries within 5%.
ranks = np.arange(1, prob.size + 1)
fdr = np.cumsum(1 - prob_sort) / ranks
within_threshold = fdr <= 0.05
prob_sort[within_threshold]

array([], dtype=float64)

In [126]:
# Number of scored entries
np.shape(k)

(1200,)