## Extracting drug associations (Bayesian)

This notebook runs part of the Multi-Omics Variational autoEncoder (MOVE) framework. It uses the structure the VAE has identified to extract associations between categorical data and all continuous datasets.

Similar to notebook #5, here we use MOVE to extract associations between the categorical data (drug status in this case) and the continuous omics datasets. However, here we follow an approach based on Bayesian decision theory. 

In [None]:
import numpy as np
import pandas as pd
import torch

import move
from move.data.dataloaders import make_dataloader
from move.models import VAE

@torch.no_grad()
def get_recon(model, dataloader):
    """Return the model's continuous reconstruction of the first batch of ``dataloader``.

    The dataloaders built in this notebook are expected to hold the whole
    dataset in a single batch, so only the first batch is reconstructed.

    Args:
        model: trained VAE. It must expose ``.device``, and calling it must
            return a 4-tuple whose second element is the continuous
            reconstruction.
        dataloader: iterable yielding tuples of tensors. The tensors are
            concatenated along dimension 1 and cast to float before being fed
            to the model.

    Returns:
        The continuous reconstruction produced by ``model`` (no autograd graph).
    """
    # Python 3 iterators (including DataLoader iterators) implement
    # ``__next__``, not a ``.next()`` method, so use the builtin ``next``.
    batch = next(iter(dataloader))  # there is only 1 batch
    batch = torch.cat(batch, 1).float().to(model.device)
    _, recon_con, _, _ = model(batch)
    return recon_con

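First, we fetch the data and create perturbations by changing the drug status of all patients to 1 (i.e., took the drug). We are analyzing 20 drugs, so we repeat this process for each drug, creating one perturbed dataset per drug. We also record which patients originally took each drug, or a drug in the same ATC subgroup. These patients are excluded later, so that the associations are estimated only on patients for whom the perturbation is an actual change.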

In [None]:
path = None  # has to be defined

cat_list, cat_names, con_list, con_names, *_ = move.utils.data_utils.get_data()
atc_data = pd.read_csv(path, index_col=None)

drug_data = cat_list[-1]  # drug dataset is last (not always)
ndrugs = drug_data.shape[1]
# Drug status is one-hot per drug; column 1 == 1 marks "did not take the drug".
# drug_mask[p, d] is True for patients who took drug d (or have no info on it).
# These patients are masked out later, so only non-users are compared.
drug_mask = ~(drug_data[:, :, 1] == 1)

nsamples = drug_data.shape[0]

_, dataloader = make_dataloader(cat_list=cat_list, con_list=con_list)

ncontinuous = dataloader.dataset.con_all.shape[1]
con_shapes = dataloader.dataset.con_shapes

ncategorical = dataloader.dataset.cat_all.shape[1]
cat_shapes = dataloader.dataset.cat_shapes

dataloaders = []

# Build one perturbed dataloader per drug, in which every patient took drug i.
for i in range(ndrugs):
    perturbed_drug = np.copy(drug_data)  # copy: the original drug data must stay intact
    perturbed_drug[:, i, :] = [1, 0]  # one-hot "took the drug"
    perturbed_cat_list = cat_list[:-1] + [perturbed_drug]  # replace the drug dataset
    _, perturbed_dataloader = make_dataloader(cat_list=perturbed_cat_list, con_list=con_list)
    dataloaders.append(perturbed_dataloader)

    # Widen the mask to the drug's ATC subgroup(s): users of related drugs are
    # excluded from the comparison as well.
    atc_subgroups = atc_data[lambda df: df.idx == i]["atc_subgroup"].tolist()
    linked_drug_ids = atc_data[lambda df: df.atc_subgroup.isin(atc_subgroups)]["idx"].tolist()
    # Nothing to widen when the drug has no ATC entry or is alone in its
    # subgroup. The empty case must be skipped explicitly: np.all over an empty
    # axis is True, which would clear the mask and include the drug's own users.
    if len(set(linked_drug_ids)) <= 1:
        continue
    # True for patients who took (or have no info on) any drug in the same ATC
    # subgroup, i.e. the complement of "did not take any linked drug".
    drug_mask[:, i] = ~np.all(drug_data[:, linked_drug_ids, 1] == 1, axis=1)

# The unperturbed (baseline) dataloader is stored last.
dataloaders.append(dataloader)

orig_con = dataloaders[-1].dataset.con_all
nan_mask = (orig_con == 0).numpy()  # NaN values were transformed to 0s

Next, we will calculate the reconstruction of the baseline data and the reconstruction of the perturbed data. The method can be summarized as:

1. Load pretrained model (see previous notebooks for model training and hyperparameter tuning)
2. Obtain reconstruction for the original and perturbed data
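3. Calculate the difference between the perturbed and baseline reconstructions, and average it across refits (we ensemble different numbers of refits for benchmarking)
4. Calculate the Bayes factor $K = \log{p(x_\text{perturbed} > x_0)} - \log{p(x_\text{perturbed} \leq x_0)}$, where $x_\text{perturbed}$ and $x_0$ are the reconstructions of the perturbed and baseline data. The probability is estimated across the individuals who did not originally take the drug.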

In [None]:
device = torch.device("cpu")

# Numbers of refits at which the ensemble is evaluated (for benchmarking).
plot_steps = [1, 2, 3, 4, 5, 6, 8, 10, 12, 15, 17, 20, 22, 25, 27, 30, 35, 40, 45, 50]
n_refits = plot_steps[-1]

eps = 1e-8
# Running sum over refits of (perturbed - baseline) reconstructions [D x N x F].
mean_diff = torch.zeros((ndrugs, nsamples, ncontinuous)).to(device)
# NaN-initialised so an unevaluated step can never expose uninitialised memory.
bayes_k = np.full((len(plot_steps), ndrugs, ncontinuous), np.nan)

for i in range(n_refits):
    ## Fetch trained model
    # NOTE(review): `nlatent` is not defined in the cells shown here; it must
    # match the latent size of the trained checkpoints.
    model = VAE(categorical_shapes=cat_shapes, continuous_shapes=con_shapes, nhiddens=[2000], num_latent=nlatent,
                beta=0.0001, cat_weights=[1,1,1], con_weights=[2,1,1,1,1,1,1], dropout=0.1, cuda=False).to(device)

    # NOTE(review): every refit loads the same `path`. Point this at the
    # checkpoint of refit `i`, otherwise the ensemble is one model repeated.
    state_dict_path = path  # path to saved checkpoint
    model.load_state_dict(torch.load(state_dict_path, map_location=device))
    model.eval()

    ## Calculate reconstructions
    # Baseline reconstruction (unperturbed dataloader is the last one).
    base_recon = get_recon(model, dataloaders[-1])
    # Accumulate per-drug differences [D x N x F].
    for drug_idx in range(ndrugs):
        mean_diff[drug_idx, :, :] += get_recon(model, dataloaders[drug_idx]) - base_recon

    n_models = i + 1
    if n_models not in plot_steps:  # only evaluate at the benchmark steps
        continue
    j = plot_steps.index(n_models)

    for drug_idx in range(ndrugs):
        # Individuals who (originally) took the drug, shape N x 1. This is
        # broadcast against the N x F NaN mask.
        samples_mask = drug_mask[:, [drug_idx]]
        # Average over the refits so far, shape N x F.
        diff = (mean_diff[drug_idx, :, :] / n_models).numpy()
        # Mask individuals who took the drug, and NaN values.
        diff = np.ma.array(diff, mask=samples_mask | nan_mask)
        # p(perturbed > baseline) over unmasked individuals, shape F.
        # A feature with no unmasked individual gets the neutral 0.5 (K = 0).
        # Reading `.data` would expose 0 there, i.e. a spuriously maximal |K|.
        prob = np.ma.filled(np.ma.mean(diff > eps, axis=0), 0.5)
        bayes_k[j, drug_idx, :] = np.log(prob + eps) - np.log(1 - prob + eps)

Finally, we rank drug–feature pairs by the absolute value of their Bayes factors: the higher $|K|$, the more significant the association. We then use the cumulative evidence to estimate a false discovery rate (FDR). By setting a threshold on the FDR (e.g., 0.05), we select the significant features.

In [None]:
# Bin edges (cumulative feature offsets) used to label each feature with its
# dataset. Kept under its own name so the `con_shapes` used to build the VAE
# is not overwritten when this cell runs.
con_bin_edges = np.cumsum([0] + [x.shape[1] for x in con_list])
con_labels = ["clinical", "diet_wearables", "proteomics", "target_metabolomics", "untarget_metabolomics", "transcriptomics", "metagenomics"]

# Flatten the per-dataset feature-name lists; the position is the feature id.
feature_names = [name for names in con_names for name in names]
con_df = pd.DataFrame({"feature_name": feature_names}).rename_axis("feature_id").reset_index()

drug_df = pd.DataFrame({"drug_name": cat_names[-1]}).rename_axis("drug_id").reset_index()

fdr_threshold = 0.05
# NOTE(review): `path` is also the ATC-table input and checkpoint placeholder;
# set a distinct output location here.
results_path = path

ks = np.abs(bayes_k)  # strength of evidence regardless of direction
probas = np.exp(ks) / (1 + np.exp(ks))  # logistic of |K|, >= 0.5

results = []
for j, n in enumerate(plot_steps):
    # Flattened (drug, feature) indices, sorted by decreasing |K|.
    sort_ids = np.argsort(ks[j, :, :], axis=None)[::-1]
    p = np.take(probas[j, :, :], sort_ids)

    # Estimated FDR when the top-m pairs are called significant (m = 1..D*F).
    fdr = np.cumsum(1 - p) / np.arange(1, p.size + 1)
    # Cut-off at the point whose FDR is closest to the threshold.
    # NOTE(review): the first `idx` pairs are kept (the pair at `idx` itself is not).
    idx = np.argmin(np.abs(fdr - fdr_threshold), axis=0)

    sig_ids = sort_ids[:idx]
    sig_ids = np.vstack((sig_ids // ncontinuous, sig_ids % ncontinuous)).T
    df = pd.DataFrame(sig_ids, columns=["drug_id", "feature_id"]).sort_values("drug_id")
    df = df.merge(drug_df, on="drug_id").merge(con_df, on="feature_id")
    df["dataset"] = pd.cut(df.feature_id, bins=con_bin_edges, right=False, labels=con_labels)
    df["n_refits"] = n
    results.append(df)

# Write all ensemble sizes once. Writing to one path inside the loop kept only
# the last step.
pd.concat(results, ignore_index=True).to_csv(results_path)