In [None]:
import numpy as np
import os
import pickle
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.express as px
import plotly as py
import pandas as pd
from chart_studio.plotly import plot, iplot

# from plotly.offline import init_notebook_mode, iplot
from tqdm import tqdm_notebook

from scvi.dataset import GeneExpressionDataset
from scvi.models import VAE, IAVAE
from scvi.inference import UnsupervisedTrainer
from scvi.utils import demultiply, make_dir_if_necessary, predict_de_genes, save_fig, load_pickle, save_pickle
from scvi_utils import estimate_de_proba, estimate_lfc_density, estimate_lfc_mean, train_model, multi_train_estimates
from R_interop import all_predictions, all_de_predictions


# --- Experiment configuration -----------------------------------------------
N_EPOCHS = 100  # epochs per model training
DELTA = 0.5  # |LFC| threshold defining a DE gene (ground truth and calls)
SIZES = [5, 10, 20, 30, 50, 100]  # sample sizes swept in the experiments
SIZE = 100  # NOTE(review): not referenced elsewhere in this notebook
N_SIZES = len(SIZES)
DO_CLOUD = True  # NOTE(review): not referenced elsewhere in this notebook
Q0 = 5e-2  # target FDR / significance level for DE calls
N_TRAININGS = 5  # passed as n_trainings to multi_train_estimates
N_PICKS = 10  # passed as n_picks (random subsamples per sample size)

# Seed NumPy's global generator once so the test split and the gene
# subsampling further down are reproducible.
SEED = 42
np.random.seed(SEED)

# A hard-coded absolute path only works on the original machine; allow an
# environment override while keeping the previous default.
PATH_TO_SCRIPTS = os.environ.get(
    "CONQUER_SCRIPTS_PATH", "/home/ubuntu/conquer_comparison/scripts"
)
DIR_PATH = 'lfc_estimates/symsim'  # exported data and cached results
make_dir_if_necessary(DIR_PATH)

# Generate Dataset

In [None]:
# NOTE: this rebinds `py` (previously the plotly package) to chart_studio.plotly.
import chart_studio.plotly as py

# SECURITY: credentials must never be hard-coded in a notebook. The API key
# that used to be written here was committed and must be treated as leaked
# (rotate it). Both values are now read from the environment.
py.sign_in(
    os.environ.get("PLOTLY_USERNAME", "pierreboyeau"),
    os.environ["PLOTLY_API_KEY"],
)

In [None]:
# Directory holding the SymSim "DE_med.*" outputs; override with the
# SYMSIM_DATA_PATH environment variable to run outside the original machine.
symsim_data_path = os.environ.get("SYMSIM_DATA_PATH", "/home/ubuntu/symsim_result/DE")

In [None]:
def _read_symsim_csv(fname):
    """Read one SymSim output table from `symsim_data_path`, first column as index."""
    return pd.read_csv(os.path.join(symsim_data_path, fname), index_col=0)


# Transpose so that columns are genes (the gene filter below indexes columns).
x_obs_all = _read_symsim_csv("DE_med.obsv.3.csv").T

# Keep genes whose mean count is at most 1000.
select_gene = np.where(x_obs_all.mean(0) <= 1000)[0]
x_obs = x_obs_all.iloc[:, select_gene]

# Batch ids shifted by one (presumably 1-based on disk -> 0-based here).
batch_info = _read_symsim_csv("DE_med.batchid.csv") - 1
metadata = _read_symsim_csv("DE_med.cell_meta.csv")
true_ = _read_symsim_csv("DE_med.true.csv").T.iloc[:, select_gene]
lfc_info = _read_symsim_csv("med_theoreticalFC.csv").iloc[select_gene, :]

for _title, _frame in (
    ("batch_info", batch_info),
    ("metadata", metadata),
    ("x_obs", x_obs),
    ("true_", true_),
    ("lfc_info", lfc_info),
):
    display(_title, _frame.head())

In [None]:
# Total number of missing entries in the filtered count matrix.
x_obs.isna().to_numpy().sum()

In [None]:
# Per-gene summary statistics of the filtered count matrix.
x_obs.describe()

In [None]:
# Ten largest per-gene maximum counts, in decreasing order.
np.sort(x_obs.max(axis=0).to_numpy())[-10:][::-1]

In [None]:
# Population sizes, then the distribution of theoretical fold changes stored
# in column "12" of the FC table (presumably populations 1 vs 2).
for _frame, _column in ((metadata, "pop"), (lfc_info, "12")):
    px.histogram(_frame, x=_column, width=400, height=300).show()

In [None]:
# Wrap the filtered counts in an scVI dataset. The SymSim population id
# ("pop") is used both as label and as cell type; batch ids come from the
# (already shifted) batch_info table.
dataset = GeneExpressionDataset()
dataset.populate_from_data(
    X=x_obs.values,
    batch_indices=batch_info["x"].values,
    labels=metadata["pop"],
    cell_types=metadata["pop"],
)

In [None]:
# Sanity check of the dataset's `local_vars`.
# NOTE(review): presumably the per-batch variance of log library size — confirm.
plt.hist(dataset.local_vars)

In [None]:
# Hold out 2000 random cells. They are the models' test set (see
# train_params) and the data exported for the R-based competitors.
n_examples = len(dataset)
# `np.random.np` was an accidental re-export of the numpy module, not public
# API, and is absent from newer NumPy releases. `np.random.permutation` is the
# very same global generator, so the drawn indices are unchanged.
TEST_INDICES = np.random.permutation(n_examples)[:2000]

x_test = dataset.X[TEST_INDICES, :]
y_test = dataset.labels[TEST_INDICES, :].squeeze()
data_path = os.path.join(DIR_PATH, 'data.npy')
labels_path = os.path.join(DIR_PATH, 'labels.npy')

np.save(data_path, x_test.squeeze().astype(int))
np.savetxt(labels_path, y_test.squeeze())

In [None]:
# Quick look at the cells belonging to batch 1.
batch_idx = dataset.batch_indices.squeeze() == 1
x_batch = dataset.X[batch_idx]
print(x_batch)
print(x_batch.shape)

## Train parameters

In [None]:
# Constructor keyword arguments per model variant: "iaf" is used with IAVAE
# and "mf" with VAE in the experiments below. The "_k5" variants are not used
# by any live cell.
mdl_params = dict(
    iaf=dict(n_hidden=128, n_layers=1, do_h=True, n_latent=10, t=4),
    mf=dict(n_hidden=128, n_layers=1, n_latent=10),
    iaf_k5=dict(n_hidden=128, n_layers=1, do_h=True, n_latent=10, t=4),
    mf_k5=dict(n_hidden=128, n_layers=1, n_latent=10),
)
# UnsupervisedTrainer keyword arguments; every variant evaluates on the cells
# in TEST_INDICES. The "_k5" variants set k_importance_weighted=5.
train_params = dict(
    iaf=dict(ratio_loss=True, test_indices=TEST_INDICES, frequency=1.0),
    mf=dict(ratio_loss=True, test_indices=TEST_INDICES, frequency=1.0),
    iaf_k5=dict(ratio_loss=True, test_indices=TEST_INDICES, k_importance_weighted=5),
    mf_k5=dict(ratio_loss=True, test_indices=TEST_INDICES, k_importance_weighted=5)
)
# Keyword arguments of trainer.train (number of epochs, learning rate).
train_fn_params = dict(
    iaf=dict(n_epochs=N_EPOCHS, lr=1e-2),
    mf=dict(n_epochs=N_EPOCHS, lr=1e-2),
    iaf_k5=dict(n_epochs=N_EPOCHS, lr=1e-3),
    mf_k5=dict(n_epochs=N_EPOCHS, lr=1e-3),
)

## Verify that everything is OK

In [None]:
# mf, mf_trainer = train_model(
#     mdl_class=VAE,
#     dataset=dataset,
#     mdl_params=mdl_params["mf"],
#     train_params=train_params["mf"],
#     train_fn_params=train_fn_params["mf"],
# )
# iaf, iaf_trainer = train_model(
#     mdl_class=IAVAE,
#     dataset=dataset,
#     mdl_params=mdl_params["iaf"],
#     train_params=train_params["iaf"],
#     train_fn_params=train_fn_params["iaf"]
# )


In [None]:
# plt.plot(mf_trainer.train_losses[5:], label="mf")
# plt.plot(iaf_trainer.train_losses[5:], label="iaf")
# plt.legend()
# plt.yscale("log")

In [None]:
# z_mf, labels_mf = mf_trainer.train_set.get_latents()
# z_iaf, labels_iaf = iaf_trainer.train_set.get_latents()

# z_mf = z_mf.cpu()
# z_iaf = z_iaf.cpu()

In [None]:
# from sklearn.manifold import TSNE

# idx = np.random.permutation(len(z_mf))[:1000]
# z_tsne = TSNE().fit_transform(z_mf[idx, :])
# labels_tsne = labels_mf[idx, :].cpu().squeeze()

# trace = go.Scatter(
#     x=z_tsne[:, 0],
#     y=z_tsne[:, 1],
#     marker_color=labels_tsne,
#     mode="markers",
#     marker_showscale=True,
# )
# fig = go.Figure([trace])
# fig.show()

In [None]:
# np.unique(dataset.labels.squeeze())

In [None]:
# z_tsne = TSNE().fit_transform(z_iaf[idx, :])
# labels_tsne = labels_iaf[idx, :].cpu().squeeze()

# trace = go.Scatter(
#     x=z_tsne[:, 0],
#     y=z_tsne[:, 1],
#     marker_color=labels_tsne,
#     mode="markers",
#     marker_showscale=True,
# )
# fig = go.Figure([trace])
# fig.show()

In [None]:
# z_iaf, labels_iaf, scales_iaf = iaf_trainer.test_set.get_latents(n_samples=100, other='scales', device="cpu")


# from scvi.utils import plot_identity

# where_a = np.where(labels_iaf == 0)[0][:200]
# where_b = np.where(labels_iaf == 1)[0][:200]


# scales_a = scales_iaf[:, where_a, :]
# scales_b = scales_iaf[:, where_b, :]

# lfc = np.log2(scales_a) - np.log2(scales_b)
# lfc = lfc.mean((0, 1))

# plt.scatter(x=lfc, y=lfc_gt)
# plot_identity()
# plt.show()

## Params

In [None]:
# Populations compared throughout. Label indices are shifted relative to R:
# label 0 here corresponds to population 1 there.
label_a, label_b = 0, 1
n_genes = dataset.nb_genes

In [None]:
# Fraction of genes whose |theoretical LFC| reaches DELTA (ground-truth DE rate).
lfc_info["12"].abs().ge(DELTA).mean()

In [None]:
# Ground-truth LFC per gene and the derived DE status (|LFC| >= DELTA).
lfc_gt = lfc_info["12"].values
is_significant_de = np.abs(lfc_gt) >= DELTA

# Compute competitors' scores

In [None]:
# Files already present in DIR_PATH (exported data and any pickles from previous runs).
os.listdir(DIR_PATH)

In [None]:
# Competitor (R-based) results, stored under other_predictions.pickle, then
# turned into DE calls (the "is_de" entries used below) at significance Q0
# and |LFC| >= DELTA.
other_predictions = all_de_predictions(
    all_predictions(
        filename=os.path.join(DIR_PATH, "other_predictions.pickle"),
        n_genes=n_genes,
        n_picks=N_PICKS,
        sizes=SIZES,
        data_path=data_path,
        labels_path=labels_path,
        path_to_scripts=PATH_TO_SCRIPTS,
        label_a=label_a,
        label_b=label_b,
    ),
    significance_level=Q0,
    delta=DELTA,
)

In [None]:
# Scratch bindings mirroring the arguments of `all_predictions`, for the
# commented-out competitor loops below. Names such as n_genes, data_path,
# labels_path, label_a and label_b already exist and are reused as-is.
filename = os.path.join(DIR_PATH, "other_predictions.pickle")
n_picks, sizes = N_PICKS, SIZES
path_to_scripts = PATH_TO_SCRIPTS
all_nature = False
lfc_threshold = 0.5

from R_interop import NDESeq2, NEdgeRLTRT, NMASTcpm, MAST
from tqdm import tqdm

# n_sizes = len(sizes)

# # DESeq2
# lfcs_deseq2 = np.zeros((n_sizes, n_picks, n_genes))
# pvals_deseq2 = np.zeros((n_sizes, n_picks, n_genes))
# for (size_ix, size) in enumerate(tqdm(sizes)):
#     for exp in range(n_picks):
#         deseq_inference = NDESeq2(
#             A=size,
#             B=size,
#             data=data_path,
#             labels=labels_path,
#             cluster=(label_a, label_b),
#             path_to_scripts=path_to_scripts,
#             lfc_threshold=lfc_threshold,
#         )
#         res_df = deseq_inference.fit()
#         lfcs_deseq2[size_ix, exp, :] = res_df["lfc"].values
#         pvals_deseq2[size_ix, exp, :] = res_df["padj"].values
# deseq_res = dict(lfc=lfcs_deseq2.squeeze(), pval=pvals_deseq2.squeeze())

# # EdgeR
# lfcs_edge_r = np.zeros((n_sizes, n_picks, n_genes))
# pvals_edge_r = np.zeros((n_sizes, n_picks, n_genes))
# for (size_ix, size) in enumerate(tqdm(sizes)):
#     for exp in range(n_picks):
#         deseq_inference = NEdgeRLTRT(
#             A=size,
#             B=size,
#             data=data_path,
#             labels=labels_path,
#             cluster=(label_a, label_b),
#             path_to_scripts=path_to_scripts,
#         )
#         res_df = deseq_inference.fit()
#         lfcs_edge_r[size_ix, exp, :] = res_df["lfc"].values
#         pvals_edge_r[size_ix, exp, :] = res_df["padj"].values
# edger_res = dict(lfc=lfcs_edge_r.squeeze(), pval=pvals_edge_r.squeeze())

# # MAST
# lfcs_mast = np.zeros((n_sizes, n_picks, n_genes))
# var_lfcs_mast = np.zeros((n_sizes, n_picks, n_genes))
# pvals_mast = np.zeros((n_sizes, n_picks, n_genes))
# for (size_ix, size) in enumerate(tqdm(sizes)):
#     for exp in range(n_picks):
#         if all_nature:
#             mast_inference = NMASTcpm(
#                 A=size,
#                 B=size,
#                 data=data_path,
#                 labels=labels_path,
#                 cluster=(label_a, label_b),
#                 path_to_scripts=path_to_scripts,
#             )
#             res_df = mast_inference.fit()
#             print(res_df.info())
#             var_lfcs_mast[size_ix, exp, :] = res_df["varLogFC"].values
#             lfcs_mast[size_ix, exp, :] = res_df["logFC"].values

#         else:
#             mast_inference = MAST(
#                 A=size,
#                 B=size,
#                 data=data_path,
#                 labels=labels_path,
#                 cluster=(label_a, label_b),
#             )
#             res_df = mast_inference.fit(return_fc=True)
#             lfcs_mast[size_ix, exp, :] = res_df["lfc"].values
#         pvals_mast[size_ix, exp, :] = res_df["pval"].values
# mast_res = dict(
#     lfc=lfcs_mast.squeeze(), pval=pvals_mast.squeeze(), var_lfc=var_lfcs_mast
# )

# results = dict(deseq2=deseq_res, edger=edger_res, mast=mast_res)
# save_pickle(data=results, filename=filename)


In [None]:
# for (size_ix, size) in enumerate(tqdm(sizes)):
#     for exp in range(n_picks):
#         deseq_inference = NDESeq2(
#             A=size,
#             B=size,
#             data=data_path,
#             labels=labels_path,
#             cluster=(label_a, label_b),
#             path_to_scripts=path_to_scripts,
#             lfc_threshold=lfc_threshold,
#         )
#         res_df = deseq_inference.fit()
#         lfcs_deseq2[size_ix, exp, :] = res_df["lfc"].values
#         pvals_deseq2[size_ix, exp, :] = res_df["padj"].values
# deseq_res = dict(lfc=lfcs_deseq2.squeeze(), pval=pvals_deseq2.squeeze())

In [None]:
# other_predictions = results

# Experiments

In [None]:
# Train each variational family N_TRAININGS times per sample size and collect
# per-gene DE estimates. Results are stored under `filename` in DIR_PATH
# (presumably reused when the pickle already exists).
res_mf = multi_train_estimates(
    filename=os.path.join(DIR_PATH, "res_mf.pickle"),
    mdl_class=VAE,
    dataset=dataset,
    mdl_params=mdl_params["mf"],
    train_params=train_params["mf"],
    train_fn_params=train_fn_params["mf"],
    sizes=SIZES,
    n_trainings=N_TRAININGS,
    n_picks=N_PICKS,
    label_a=label_a,
    label_b=label_b,
).assign(algorithm="MF")

res_iaf = multi_train_estimates(
    filename=os.path.join(DIR_PATH, "res_ia.pickle"),
    mdl_class=IAVAE,
    dataset=dataset,
    mdl_params=mdl_params["iaf"],
    train_params=train_params["iaf"],
    train_fn_params=train_fn_params["iaf"],
    sizes=SIZES,
    n_trainings=N_TRAININGS,
    n_picks=N_PICKS,
    label_a=label_a,
    # label_b was missing here, so the IAF runs did not explicitly compare the
    # same pair of populations as the MF runs. If res_ia.pickle was produced
    # before this fix, delete it so the results are recomputed.
    label_b=label_b,
).assign(algorithm="IAF")

## FDR / Power Control and PR Curves

In [None]:
def train_model(
    mdl_class, dataset, mdl_params: dict, train_params: dict, train_fn_params: dict
):
    """
    Build a model, wrap it in an UnsupervisedTrainer and train it.

    NOTE(review): this redefines, and from this cell on shadows, the
    `train_model` imported from `scvi_utils` at the top of the notebook.

    :param mdl_class: model class (e.g. VAE or IAVAE), called with the number of
        genes, `n_batch=dataset.n_batches` and `**mdl_params`
    :param dataset: dataset exposing `nb_genes` and `n_batches`
    :param mdl_params: extra keyword arguments for the model constructor
    :param train_params: keyword arguments for `UnsupervisedTrainer`
    :param train_fn_params: keyword arguments for `trainer.train`
    :return: tuple `(model, trainer)`
    """
    model = mdl_class(dataset.nb_genes, n_batch=dataset.n_batches, **mdl_params)
    trainer = UnsupervisedTrainer(model, dataset, **train_params)
    # Debug output: the indices sampled into the test split.
    test_sampler = trainer.test_set.data_loader.sampler
    print(test_sampler.indices)
    trainer.train(**train_fn_params)
    # Debug output: loss history after training.
    print(trainer.train_losses)
    return model, trainer

### FDR and TPR Control

In [None]:
def fdr_fnr(my_df):
    """
    Empirical FDR and FNR of one scVI run against the ground truth.

    Genes are called DE with `predict_de_genes(de_proba, desired_fdr=Q0)` and
    compared with the module-level `is_significant_de`.

    :param my_df: per-gene results of one (experiment, training, sample_size)
        group, with `gene` and `de_proba` columns and exactly `n_genes` rows
    :return: Series with entries `fdr` and `fnr`
    """
    my_df = my_df.sort_values("gene")
    assert len(my_df) == n_genes
    is_pred_de = predict_de_genes(my_df.de_proba.values, desired_fdr=Q0)
    n_discoveries = is_pred_de.sum()
    n_false_discoveries = ((1.0 - is_significant_de) * is_pred_de).sum()
    # With no discovery the FDR is 0/0. Define it as 0 (no false discovery was
    # made) instead of NaN, which the box plots would silently drop.
    if n_discoveries > 0:
        true_fdr = n_false_discoveries / n_discoveries
    else:
        true_fdr = 0.0
    n_positives = is_significant_de.sum()
    true_fnr = (is_significant_de * (1.0 - is_pred_de)).sum() / n_positives
    return pd.Series(dict(fdr=true_fdr, fnr=true_fnr))


# Empirical FDR / FNR of every (experiment, training, sample_size) run.
fdr_fnr_mf, fdr_fnr_iaf = (
    res.groupby(["experiment", "training", "sample_size"])
    .apply(fdr_fnr)
    .reset_index()
    .assign(algorithm=algo)
    for res, algo in ((res_mf, "MF"), (res_iaf, "IAF"))
)

df = pd.concat([fdr_fnr_mf, fdr_fnr_iaf], ignore_index=True)

for _metric, _title in (
    ("fdr", "Control on False Discovery Rate"),
    ("fnr", "Control on False Negative Rate"),
):
    fig = px.box(df, x="sample_size", y=_metric, color="algorithm", title=_title)
    fig.show()
# Chart Studio uploads (disabled):
# iplot(fig, filename="powsimr_fdr_control")
# iplot(fig, filename="powsimr_power_control")

In [None]:
# ['deseq2', 'edger', 'mast']

def get_fdr_fnr(y_pred, y_true):
    """
    Empirical FDR and FNR of competitor DE calls against the ground truth.

        y_pred: (n_sz, n_picks, n_genes) bool predictions
        y_true: (n_genes) gt vals (bool DE status)

    Returns a dict with
        fdr = false positives / predicted positives  (0 when nothing is predicted DE)
        fnr = false negatives / true positives
    each of shape (n_sz, n_picks).
    """
    y_pred = np.asarray(y_pred, dtype=bool)
    y_true = np.asarray(y_true, dtype=bool)
    if y_pred.ndim != 3:
        raise ValueError(
            "y_pred must have shape (n_sz, n_picks, n_genes), got {}".format(y_pred.shape)
        )
    # The previous loop stored FP / predicted-positives under "fnr" and
    # FN / true-positives under "fdr", i.e. the two metrics were swapped.
    n_pred_pos = y_pred.sum(axis=-1)
    n_false_pos = (y_pred & ~y_true).sum(axis=-1)
    n_false_neg = (~y_pred & y_true).sum(axis=-1)
    with np.errstate(divide="ignore", invalid="ignore"):
        fdrs = n_false_pos / n_pred_pos
        fnrs = n_false_neg / y_true.sum()
    # No discovery means no false discovery: report an FDR of 0, not NaN.
    fdrs = np.where(n_pred_pos == 0, 0.0, fdrs)
    return dict(fnr=fnrs, fdr=fdrs)

for _method in ("mast", "deseq2", "edger"):
    print(other_predictions[_method]['pval'].shape)

# Boolean DE calls of each competitor, then their empirical FDR / FNR.
is_de_mast, is_de_deseq2, is_de_edger = (
    other_predictions[_method]["is_de"] for _method in ("mast", "deseq2", "edger")
)

res_mast, res_deseq2, res_edger = (
    get_fdr_fnr(_is_de, y_true=is_significant_de)
    for _is_de in (is_de_mast, is_de_deseq2, is_de_edger)
)

### Confusion Matrices

In [None]:
# trains_res = all_fdrs.mean(axis=1)
# print(trains_res.mean(), trains_res.std())

In [None]:
# from sklearn.metrics import confusion_matrix

# y_preds_1d = y_preds.reshape((-1, dataset.nb_genes))
# n_exps = len(y_preds_1d)
# confs = np.zeros((n_exps, 2, 2))
# for i in range(n_exps):
#     confs[i, :, :] = confusion_matrix(is_significant_de, y_preds_1d[i, :])

In [None]:
# confusion_matrix(is_significant_de, y_preds_1d[0, :])

# confs_mean = confs.mean(0)
# confs_mean

# fig = ff.create_annotated_heatmap(
#     z=confs_mean, x=["Pred Negative", "Pred Positive"], y=["GT Negative", "GT Positive"]
# )
# fig.update({"layout": dict(title="Confusion Matrix")})

# py.iplot(fig)

### PR Curves

In [None]:
selected_training = 1


def _de_proba_of(res):
    """DE probabilities, ordered by gene, of experiment 0 / the selected training / sample size 100."""
    mask = (
        (res.experiment == 0)
        & (res.training == selected_training)
        & (res.sample_size == 100)
    )
    return res.loc[mask].sort_values("gene")["de_proba"]


preds_md = _de_proba_of(res_mf)
preds_iaf = _de_proba_of(res_iaf)

In [None]:
from sklearn.metrics import precision_recall_curve

# Score = 1 - p-value (higher means more likely DE), taken at the last
# sample-size index and the first pick (assumes pval is (n_sizes, n_picks, n_genes)).
preds_deseq2, preds_edger, preds_mast = (
    1.0 - other_predictions[_method]['pval'][-1, 0, :]
    for _method in ('deseq2', 'edger', 'mast')
)

In [None]:
# Fraction of genes whose score is undefined (NaN), per method. The previous
# version printed the DESeq2 fraction twice and did not label any line.
for _name, _preds in (
    ("MF", preds_md),
    ("IAF", preds_iaf),
    ("DESeq2", preds_deseq2),
    ("EdgeR", preds_edger),
    ("MAST", preds_mast),
):
    print("{}: {}".format(_name, np.isnan(_preds).mean()))

In [None]:
# Genes with an undefined (NaN) DESeq2 score are ranked as never DE.
preds_deseq2 = np.where(np.isnan(preds_deseq2), 0.0, preds_deseq2)

In [None]:
from sklearn.metrics import precision_recall_curve, average_precision_score

def plot_pr(fig, preds, y_true, name):
    """
    Add the precision-recall curve of one method to a plotly figure.

    NaN scores are replaced *in place* by the smallest finite score (ranking
    those genes last), as before. The average precision is now computed after
    that fill: sklearn's `average_precision_score` rejects NaN input, so
    computing it first made the call fail for any method with undefined scores.

    :param fig: plotly Figure that receives the trace
    :param preds: per-gene scores, higher = more likely DE (modified in place)
    :param y_true: per-gene ground-truth DE status
    :param name: legend label, suffixed with the average precision
    """
    nan_mask = np.isnan(preds)
    if nan_mask.any():
        # An all-NaN input has no finite minimum to fall back on; use 0 then.
        fill_value = 0.0 if nan_mask.all() else np.min(preds[~nan_mask])
        preds[nan_mask] = fill_value
    average_precision = average_precision_score(y_true, preds)
    # Positional arguments: the score keyword was renamed (probas_pred -> y_score)
    # in recent scikit-learn releases.
    precs, recs, _ = precision_recall_curve(y_true, preds)
    fig.add_trace(
        go.Scatter(
            x=recs,
            y=precs,
            name=name + '@AP: {0:0.2f}'.format(average_precision),
        )
    )
    return
layout = go.Layout(
    title='Precision Recall Curves',
    xaxis_title='Recall',
    yaxis_title='Precision',
    width=800,
    height=600,
)
fig = go.Figure(layout=layout)
# One curve per method; plot_pr also fills NaN scores of each array in place.
for _preds, _name in (
    (preds_md, 'MF'),
    (preds_iaf, 'IAF'),
    (preds_deseq2, 'DESeq2'),
    (preds_edger, 'EdgeR'),
    (preds_mast, 'MAST'),
):
    plot_pr(fig=fig, preds=_preds, y_true=is_significant_de, name=_name)

fig.show()

## Gene ranking?

In [None]:
# All score vectors should have one entry per gene.
for _preds in (preds_md, preds_iaf, preds_deseq2, preds_edger, preds_mast):
    print(_preds.shape)

In [None]:
# Gene orderings from highest to lowest DE score, per method (argsort of the
# negated scores). NOTE(review): argsort yields orderings, not ranks, so these
# are not suitable inputs for spearmanr. No live cell currently uses them.
gene_ranks_md = np.argsort(-preds_md)
gene_ranks_iaf = np.argsort(-preds_iaf)
gene_ranks_deseq2 = np.argsort(-preds_deseq2)
gene_ranks_edger = np.argsort(-preds_edger)
gene_ranks_mast = np.argsort(-preds_mast)

# Reference ordering: genes by decreasing |ground-truth LFC|.
gt_ranks = np.argsort(-np.abs(lfc_gt))

In [None]:
from scipy.stats import spearmanr

# The DE scores measure how likely a gene is DE, not the direction of change,
# so rank-correlate them with |ground-truth LFC| (the quantity that defines
# is_significant_de and gt_ranks). Against the signed LFC, strongly
# down-regulated genes, which should score high, count as disagreements and
# drag rho towards 0.
abs_lfc_gt = np.abs(lfc_gt)

rhos_md = spearmanr(preds_md, abs_lfc_gt)
rhos_iaf = spearmanr(preds_iaf, abs_lfc_gt)
rhos_deseq2 = spearmanr(preds_deseq2, abs_lfc_gt)
rhos_edger = spearmanr(preds_edger, abs_lfc_gt)
rhos_mast = spearmanr(preds_mast, abs_lfc_gt)

In [None]:
# Spearman correlation (and p-value) of each method's scores with the ground truth.
for _rho in (rhos_md, rhos_iaf, rhos_deseq2, rhos_edger, rhos_mast):
    print(_rho)

## Volcano Plot

In [None]:
# Volcano-style plot: log10 of each scVI variant's DE probability against the
# ground-truth LFC. The vertical lines mark the +/- DELTA bounds that define
# ground-truth DE (previously hard-coded as +/-0.5).
fig = go.Figure(
    layout=go.Layout(
        yaxis=dict(title="log10 estimated probability of DE"),
        xaxis=dict(title="Ground-Truth LFC"),
    )
)
fig.add_traces(
    [
        go.Scatter(x=lfc_gt, y=np.log10(preds_md), mode="markers", name="MF"),
        go.Scatter(x=lfc_gt, y=np.log10(preds_iaf), mode="markers", name="IAF"),
        # Competitors (disabled):
        # go.Scatter(x=lfc_gt, y=np.log10(preds_mast), mode="markers", name="MAST"),
        # go.Scatter(x=lfc_gt, y=np.log10(preds_edger), mode="markers", name="EdgeR"),
        go.Scatter(
            x=[-DELTA, -DELTA],
            y=[-6, 0.0],
            mode="lines",
            line=dict(color="black", width=2),
            showlegend=False,
        ),
        go.Scatter(
            x=[DELTA, DELTA],
            y=[-6, 0.0],
            mode="lines",
            line=dict(color="black", width=2),
            showlegend=False,
        ),
    ]
)

fig.show()
# Upload to Chart Studio (uses the sign-in performed at the top of the notebook).
iplot(fig, filename="symsim_volcano", sharing="private")

## Diagonal Curve

In [None]:
# 150 random genes (kept in gene order) for a readable estimated-vs-true LFC scatter.
subsample_genes = np.sort(np.random.permutation(n_genes)[:150])


def _lfc_table(res):
    """LFC estimates with hdi64 error bars for the subsampled genes.

    Restricted to experiment 0, the selected training and sample size 100, ordered
    by gene so that the ground-truth column lines up with `subsample_genes`.
    """
    selected = res.loc[
        (res.experiment == 0)
        & (res.training == selected_training)
        & (res.sample_size == 100)
        & (res.gene.isin(subsample_genes))
    ].sort_values("gene")
    table = selected[["lfc_mean", "hdi64_low", "hdi64_high", "algorithm"]]
    return table.assign(
        err_minus=table.lfc_mean - table.hdi64_low,
        err_pos=table.hdi64_high - table.lfc_mean,
        lfc_gt=lfc_gt[subsample_genes],
    )


lfcs_mf = _lfc_table(res_mf)
lfcs_ia = _lfc_table(res_iaf)

all_lfcs = pd.concat([lfcs_mf, lfcs_ia], ignore_index=True)

In [None]:
# MF: estimated vs ground-truth LFC with hdi64 error bars.
fig = px.scatter(
    lfcs_mf, x="lfc_gt", y="lfc_mean", error_y="err_pos", error_y_minus="err_minus"
)
# Dashed identity line (perfect recovery).
fig.add_scatter(
    x=[-3, 3],
    y=[-3, 3],
    mode="lines",
    line=dict(color="black", width=4, dash="dash"),
)
fig.show()

In [None]:
# IAF: estimated vs ground-truth LFC with hdi64 error bars.
fig = px.scatter(
    lfcs_ia, x="lfc_gt", y="lfc_mean", error_y="err_pos", error_y_minus="err_minus"
)
# Dashed identity line (perfect recovery).
fig.add_scatter(
    x=[-3, 3],
    y=[-3, 3],
    mode="lines",
    line=dict(color="black", width=4, dash="dash"),
)
fig.show()

In [None]:
# Both variants on one plot, coloured by algorithm.
fig = px.scatter(
    all_lfcs,
    x="lfc_gt",
    y="lfc_mean",
    color="algorithm",
    error_y="err_pos",
    error_y_minus="err_minus",
)
# Dashed identity line (perfect recovery).
fig.add_scatter(
    x=[-3, 3],
    y=[-3, 3],
    mode="lines",
    line=dict(color="black", width=4, dash="dash"),
)
fig.show()

## Study of LFC errors

In [None]:
def compute_l2_err(diff):
    """
    Error score averaged over the last axis, ignoring NaNs.

    NOTE(review): despite the name this is *half the mean absolute error*,
    since 0.5 * sqrt(diff ** 2) == 0.5 * |diff|, not an L2 / squared error.
    Kept unchanged so the numbers stay comparable; confirm the intended metric.

    :param diff: array of differences; the last axis is averaged over
    :return: np.nanmean(0.5 * |diff|, axis=-1)
    """
    half_abs_dev = 0.5 * (diff ** 2) ** (0.5)
    return np.nanmean(half_abs_dev, axis=-1)

def l2_err_competitor(vals: np.ndarray, other: np.ndarray = None):
    """
    Tidy per-(sample size, pick) LFC error table for a competitor method.

    NaN estimates are scored as an LFC of 0. This used to be done by writing
    zeros into `vals` in place, which silently modified the caller's array
    (e.g. `other_predictions[...]["lfc"]`); a NaN-free copy is used instead.

    :param vals: LFC estimates whose per-row error has shape (N_SIZES, N_PICKS),
        i.e. (N_SIZES, N_PICKS, n_genes)
    :param other: optional reference broadcastable against `vals` (e.g. the
        ground-truth LFC); when None the error is computed on `vals` itself
    :return: DataFrame with columns experiment, training, sample_size, error;
        one row per (sample size, pick), with `training` always 0
    """
    vals = np.where(np.isnan(vals), 0.0, vals)
    diff = vals if other is None else vals - other
    res = compute_l2_err(diff)
    assert res.shape == (N_SIZES, N_PICKS)
    data = [
        dict(experiment=pick, training=0, sample_size=size, error=res[size_ix, pick])
        for size_ix, size in enumerate(SIZES)
        for pick in range(N_PICKS)
    ]
    return pd.DataFrame(data)

# LFC error tables of the three competitors, tagged with their display name.
lfcs_errs_deseq2, lfcs_errs_edger, lfcs_errs_mast = (
    l2_err_competitor(other_predictions[_key]["lfc"], other=lfc_gt).assign(algorithm=_label)
    for _key, _label in (("deseq2", "DESeq2"), ("edger", "EdgeR"), ("mast", "MAST"))
)

In [None]:
def pd_l2_err(my_df):
    """
    Error score of one scVI run against the ground-truth LFC.

    Computes the NaN-ignoring mean of 0.5 * sqrt((lfc_mean - lfc_gt) ** 2), i.e.
    half the mean absolute deviation, with `lfc_mean` ordered by gene.

    :param my_df: per-gene results of one run with `gene` and `lfc_mean` columns;
        presumably one row per gene so that it aligns with `lfc_gt` once sorted
    :return: Series with a single `error` entry
    """
    lfc_est = my_df.sort_values("gene")["lfc_mean"]
    half_abs_dev = 0.5 * ((lfc_est - lfc_gt) ** 2) ** (0.5)
    return pd.Series(dict(error=np.nanmean(half_abs_dev)))

# Error of every (experiment, sample_size, training) run of both scVI variants.
lfcs_errs_mf, lfcs_errs_iaf = (
    res.groupby(["experiment", "sample_size", "training", "algorithm"])
    .apply(pd_l2_err)
    .reset_index()
    for res in (res_mf, res_iaf)
)

In [None]:
_err_tables = [
    lfcs_errs_mf,
    lfcs_errs_iaf,
    lfcs_errs_deseq2,
    lfcs_errs_edger,
    lfcs_errs_mast,
]
all_errs = pd.concat(_err_tables, ignore_index=True)

# Distribution of LFC errors per sample size and method.
px.box(all_errs, x="sample_size", y="error", color="algorithm")

# Debug

In [None]:
# iw_vae = IAVAE(
#     dataset.nb_genes, n_batch=dataset.n_batches, n_hidden=32, n_layers=1, do_h=True, n_latent=10, t=4
# )
# iw_trainer = UnsupervisedTrainer(
#     iw_vae, dataset, ratio_loss=True, k_importance_weighted=5, single_backward=False
# )
# iw_trainer.train(n_epochs=50, lr=1e-3)
# iw_trainer.train_losses

# Importance-weighted training (k=5). Note that this model is built with VAE
# (mean field), not IAVAE, despite the IAF variant commented out above.
iw_vae = VAE(
    dataset.nb_genes, n_batch=dataset.n_batches, n_hidden=128, n_layers=1, n_latent=10
)
iw_trainer = UnsupervisedTrainer(
    iw_vae, dataset, ratio_loss=True, k_importance_weighted=5, single_backward=False
)
iw_trainer.train(n_epochs=50, lr=1e-3)
iw_trainer.train_losses  # not displayed: only a cell's last expression is rendered

# Baseline: same architecture, trained without importance weighting.
vae = VAE(
    dataset.nb_genes, n_batch=dataset.n_batches, n_hidden=128, n_layers=1, n_latent=10
)
trainer = UnsupervisedTrainer(vae, dataset, ratio_loss=True)
trainer.train(n_epochs=50, lr=1e-3)
trainer.train_losses


In [None]:
# Held-out posteriors of the importance-weighted and baseline models.
test_iw, test_mf = iw_trainer.test_set, trainer.test_set

In [None]:
def get_lfc(post, n_samples: int = 100, n_cells: int = 100):
    """
    Monte-Carlo probability that each gene is DE between `label_a` and `label_b`.

    Despite its name this returns DE probabilities, not LFCs. It draws `scales`
    from `post.get_latents`, pairs cells of the two populations one-to-one, and
    returns, per gene, the fraction of sampled |log2 fold changes| that reach DELTA.

    :param post: posterior exposing `get_latents` (e.g. `trainer.test_set`)
    :param n_samples: value forwarded as `n_samples` to `get_latents`
    :param n_cells: maximum number of cells used from each population
    :return: array of shape (n_genes,) with values in [0, 1]
    """
    _, labels, scales = post.get_latents(
        n_samples=n_samples, other='scales', device="cpu"
    )
    where_a = np.where(labels == label_a)[0]
    where_b = np.where(labels == label_b)[0]
    # Cells are paired one-to-one, so both index sets need the same length.
    # Taking up to 100 of each independently failed to broadcast whenever one
    # population had fewer cells in `post`.
    n_pairs = min(len(where_a), len(where_b), n_cells)

    # Indexing assumes scales is (samples, cells, genes) — TODO confirm axis 0.
    scales_a = scales[:, where_a[:n_pairs], :]
    scales_b = scales[:, where_b[:n_pairs], :]

    lfc = np.log2(scales_a) - np.log2(scales_b)
    lfc = np.array(lfc.reshape((-1, n_genes)))
    return (np.abs(lfc) >= DELTA).mean(0)

# Per-gene DE probabilities from each model's held-out posterior.
de_probas_iw, de_probas_mf = get_lfc(test_iw), get_lfc(test_mf)

In [None]:
from sklearn.metrics import precision_recall_curve, average_precision_score

def plot_pr(fig, preds, y_true, name):
    """
    Add the precision-recall curve of one method to a plotly figure.

    NOTE(review): this is a second definition of `plot_pr` in the notebook and
    replaces any earlier one from this cell on.

    NaN scores are replaced *in place* by the smallest finite score (ranking
    those genes last), as before. The average precision is now computed after
    that fill: sklearn's `average_precision_score` rejects NaN input, so
    computing it first made the call fail for any method with undefined scores.

    :param fig: plotly Figure that receives the trace
    :param preds: per-gene scores, higher = more likely DE (modified in place)
    :param y_true: per-gene ground-truth DE status
    :param name: legend label, suffixed with the average precision
    """
    nan_mask = np.isnan(preds)
    if nan_mask.any():
        # An all-NaN input has no finite minimum to fall back on; use 0 then.
        fill_value = 0.0 if nan_mask.all() else np.min(preds[~nan_mask])
        preds[nan_mask] = fill_value
    average_precision = average_precision_score(y_true, preds)
    # Positional arguments: the score keyword was renamed (probas_pred -> y_score)
    # in recent scikit-learn releases.
    precs, recs, _ = precision_recall_curve(y_true, preds)
    fig.add_trace(
        go.Scatter(
            x=recs,
            y=precs,
            name=name + '@AP: {0:0.2f}'.format(average_precision),
        )
    )
    return

In [None]:
fig = go.Figure(
    layout=go.Layout(
        title='Precision Recall Curves (debug models)',
        xaxis=dict(title='Recall'),
        yaxis=dict(title='Precision'),
    )
)
# `iw_vae` is built with VAE (mean field) and trained with k=5 importance
# weighting, so it is labelled "IW MF"; the previous "IW IAF" label was wrong.
plot_pr(fig=fig, preds=de_probas_iw, y_true=is_significant_de, name='IW MF')
plot_pr(fig=fig, preds=de_probas_mf, y_true=is_significant_de, name='MF')
fig.show()