In [None]:
import numpy as np
import os
import matplotlib.pyplot as plt
import plotly.graph_objs as go
import plotly.express as px
import plotly as py
import pandas as pd
from chart_studio.plotly import plot, iplot

# from plotly.offline import init_notebook_mode, iplot
from tqdm import tqdm_notebook

from scvi.dataset import PowSimSynthetic, LatentLogPoissonDataset
from scvi.models import VAE, IAVAE
from scvi.inference import UnsupervisedTrainer
from scvi.utils import demultiply, make_dir_if_necessary, predict_de_genes
from scvi_utils import estimate_de_proba, estimate_lfc_density, estimate_lfc_mean
from R_interop import all_predictions


# Number of epochs forwarded as `n_epochs` to every training call below.
N_EPOCHS = 1
# Minimum absolute log-fold-change for a gene to count as significantly differentially expressed.
DELTA = 0.5
# Sample sizes forwarded as `sizes` to the estimators and to the competitor methods.
SIZES = [5, 10, 20, 30, 50, 100]
# Single sample size used for the LFC density (diagonal) plot.
SIZE = 100
N_SIZES = len(SIZES)

# Target false discovery rate, passed as `desired_fdr` to `predict_de_genes`.
Q0 = 5e-2
# Leading dimensions of the result arrays built in `compute_stats`.
N_TRAININGS = 1
N_PICKS = 2

# Seed NumPy's global RNG so the test-index permutation is reproducible.
np.random.seed(42)

# Directory that receives the exported test data and labels.
DIR_PATH = 'lfc_estimates/powsimr'
make_dir_if_necessary(DIR_PATH)

# Generate Dataset

In [None]:
# Simulate the synthetic dataset (two clusters of 7500 cells each, 1500 genes, mode "NB").
dataset = PowSimSynthetic(
    cluster_to_samples=[7500, 7500],
    de_p=0.5,
    n_genes=1500,
    mode="NB",
    # Optional settings, currently disabled:
    #     n_genes_zi=250,
    #     p_dropout=0.3,
)
n_genes = dataset.nb_genes

# Ground-truth log-fold-change between the two clusters, per gene.
lfc_difference = dataset.lfc[:, 1] - dataset.lfc[:, 0]
# A gene is significantly DE when its absolute LFC reaches DELTA.
is_significant_de = np.abs(lfc_difference) >= DELTA

# Histogram of the ground-truth LFC values.
trace1 = go.Histogram(x=lfc_difference)
fig = go.Figure(data=[trace1])
fig.show()

In [None]:
# Hold out 2000 randomly chosen cells; the same indices are reused by every model below.
n_examples = len(dataset)
# Use the public `np.random.permutation`. The former `np.random.np.random.permutation`
# only worked when `numpy.random` happened to re-export the `numpy` module as `np`,
# which newer NumPy releases do not do.
TEST_INDICES = np.random.permutation(n_examples)[:2000]

x_test, y_test = dataset.X[TEST_INDICES, :], dataset.labels[TEST_INDICES, :].squeeze()
data_path = os.path.join(DIR_PATH, 'data.npy')
labels_path = os.path.join(DIR_PATH, 'labels.npy')

# Counts are stored as integers in NumPy binary format.
np.save(
    data_path,
    x_test.squeeze().astype(int)
)
# NOTE(review): the labels are written as plain text by `np.savetxt` despite the `.npy`
# extension; left unchanged because the reader (`all_predictions`) is not visible here.
np.savetxt(
    labels_path,
    y_test.squeeze()
)

# Compute competitors scores

In [None]:
# Run the competitor methods on the exported test data.
# Later cells read `other_predictions[method]['pval']` and `['lfc']` for the keys
# 'deseq2', 'edger' and 'mast'.
other_predictions = all_predictions(
    n_genes=n_genes, 
    n_picks=N_PICKS, 
    sizes=SIZES, 
    data_path=data_path, 
    labels_path=labels_path
)

# Experiments

## FDR / Power Control and PR Curves

In [None]:
def train_model(
    mdl_class, dataset, mdl_params: dict, train_params: dict, train_fn_params: dict
):
    """
    Instantiate a model, wrap it in an ``UnsupervisedTrainer`` and train it.

    The test-set sampler indices are printed before training and the recorded
    training losses afterwards.

    :param mdl_class: Model class, called as
        ``mdl_class(dataset.nb_genes, n_batch=dataset.n_batches, **mdl_params)``
    :param dataset: Dataset exposing ``nb_genes`` and ``n_batches``
    :param mdl_params: Extra keyword arguments for the model constructor
    :param train_params: Keyword arguments for the ``UnsupervisedTrainer`` constructor
    :param train_fn_params: Keyword arguments for the trainer's ``train`` method
    :return: Tuple ``(model, trainer)``
    """
    model = mdl_class(dataset.nb_genes, n_batch=dataset.n_batches, **mdl_params)
    trainer = UnsupervisedTrainer(model, dataset, **train_params)

    held_out_indices = trainer.test_set.data_loader.sampler.indices
    print(held_out_indices)

    trainer.train(**train_fn_params)
    print(trainer.train_losses)
    return model, trainer



### FDR and TPR Control

In [None]:
# DE probabilities for the mean-field VAE. Later cells index the result as
# [training, size, pick, gene].
de_probas_mf = estimate_de_proba(
    VAE,
    dataset=dataset,
    mdl_params=dict(n_hidden=128, n_layers=1, n_latent=5),
    train_params=dict(ratio_loss=True, test_indices=TEST_INDICES),
    train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-4),
    sizes=SIZES,
    n_trainings=N_TRAININGS,
    n_picks=N_PICKS,
)
# Same experiment for the IAVAE model; it differs only in the model class and the
# extra `do_h` / `t` model parameters.
de_probas_ia = estimate_de_proba(
    IAVAE,
    dataset=dataset,
    mdl_params=dict(n_hidden=128, n_layers=1, do_h=True, n_latent=5, t=4),
    train_params=dict(ratio_loss=True, test_indices=TEST_INDICES),
    train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-4),
    sizes=SIZES,
    n_trainings=N_TRAININGS,
    n_picks=N_PICKS,
)
# results_mf5 = fdr_control(
#     VAE,
#     mdl_params=dict(n_hidden=128, n_layers=1, n_latent=5),
#     train_params=dict(
#         train_size=0.7, ratio_loss=True, k_importance_weighted=5, single_backward=False
#     ),
#     train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-4),
#     n_trainings=N_TRAININGS,
#     n_picks=N_PICKS,
# )
# results_mf50tr = fdr_control(
#     VAE,
#     mdl_params=dict(n_hidden=64, n_layers=3, n_latent=5),
#     train_params=dict(train_size=0.7, ratio_loss=True, k_importance_weighted=25),
#     train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-4),
#     n_trainings=N_TRAININGS,
#     n_picks=N_PICKS
# )

In [None]:
def compute_stats(probas_arr, is_significant_de):
    """
    Compute the realised FDR and FNR of the DE calls of every experiment.

    :param probas_arr: Array indexed as ``[training, size, pick, gene]``; each
        gene vector is handed to ``predict_de_genes`` with ``desired_fdr=Q0``
    :param is_significant_de: Ground-truth mask over genes (truthy = DE)
    :return: dict with keys ``fdr`` and ``fnr``, each an array of shape
        ``(N_TRAININGS, len(SIZES), N_PICKS)``
    """
    is_significant_de = np.asarray(is_significant_de, dtype=bool)
    n_positives = is_significant_de.sum()

    fdrs = np.zeros((N_TRAININGS, len(SIZES), N_PICKS))
    fnrs = np.zeros((N_TRAININGS, len(SIZES), N_PICKS))
    for i in range(N_TRAININGS):
        for j in range(len(SIZES)):
            for k in range(N_PICKS):
                probs_pred_de = probas_arr[i, j, k, :]
                # Cast to bool so that `~` is a logical negation whatever dtype is returned.
                is_pred_de = np.asarray(
                    predict_de_genes(probs_pred_de, desired_fdr=Q0), dtype=bool
                )

                # FDR = false discoveries / discoveries (not / number of genes).
                # With no discoveries at all the FDR is taken to be 0.
                n_discoveries = is_pred_de.sum()
                n_false_discoveries = (~is_significant_de & is_pred_de).sum()
                if n_discoveries > 0:
                    true_fdr = n_false_discoveries / n_discoveries
                else:
                    true_fdr = 0.0

                # FNR = missed true DE genes / all true DE genes.
                true_fnr = (is_significant_de & ~is_pred_de).sum() / n_positives
                fdrs[i, j, k] = true_fdr
                fnrs[i, j, k] = true_fnr
    return dict(fdr=fdrs, fnr=fnrs)

In [None]:
# FDR / FNR arrays for both models, evaluated against the same ground-truth mask.
results_mf = compute_stats(de_probas_mf, is_significant_de=is_significant_de)
results_ia = compute_stats(de_probas_ia, is_significant_de=is_significant_de)

In [None]:
# Parallel lists: all_results[i] holds the statistics of the model labelled names[i].
all_results = [results_mf, results_ia]
names = ["MF", "IAF"]

In [None]:
# One row per (training, size) pair: repeat the size vector once per training.
associated_sizes = np.array(SIZES).reshape((1, -1))
associated_sizes = np.tile(associated_sizes, [N_TRAININGS, 1])

all_dfs = []
for result, name in zip(all_results, names):
    # Replace NaN FDRs by 0 on a copy, so that re-running this cell never alters
    # the arrays stored in `all_results`.
    fdrs = np.where(np.isnan(result["fdr"]), 0.0, result["fdr"])
    fnrs = result["fnr"]
    # Average over picks (last axis), then flatten to align with `associated_sizes`.
    new_df = pd.DataFrame(
        dict(
            fdr=fdrs.mean(-1).reshape(-1),
            size=associated_sizes.reshape(-1),
            fnr=fnrs.mean(-1).reshape(-1),
        )
    ).assign(posterior=name)
    all_dfs.append(new_df)

# Rows that still contain NaN (e.g. an undefined FNR) are dropped.
df = pd.concat(all_dfs, ignore_index=True).dropna()

display(df.head())

In [None]:
# Distribution of the realised FDR per sample size, one box per posterior family.
fig = px.box(
    df,
    x="size",
    y="fdr",
    color="posterior",
    title="Control on False Discovery Rate",
)
fig.show()
# iplot(fig, filename="fdr_control")

In [None]:
# NOTE(review): this cell is an exact duplicate of the previous FDR box plot;
# consider deleting it or plotting a different quantity here.
fig = px.box(
    df,
    x="size",
    y="fdr",
    color="posterior",
    title="Control on False Discovery Rate",
)
fig.show()
# iplot(fig, filename="fdr_control")

In [None]:
# Distribution of the realised FNR per sample size, one box per posterior family.
fig = px.box(
    df,
    x="size",
    y="fnr",
    color="posterior",
    title="Control on False Negative Rate",
)
fig.show()
# iplot(fig, filename="fnr_control")

### Confusion Matrices

In [None]:
# NOTE(review): `all_fdrs` is not defined anywhere in the visible notebook, so this cell
# raises NameError on a fresh kernel. It presumably held per-training FDR values — confirm
# and rebuild it from `all_results`, or delete the cell.
trains_res = all_fdrs.mean(axis=1)
print(trains_res.mean(), trains_res.std())

In [None]:
from sklearn.metrics import confusion_matrix

# NOTE(review): `y_preds` is not defined anywhere in the visible notebook, so this cell
# raises NameError on a fresh kernel.
# Flatten the predictions to one row per experiment, one column per gene.
y_preds_1d = y_preds.reshape((-1, dataset.nb_genes))
n_exps = len(y_preds_1d)
# One 2x2 confusion matrix per experiment, against the ground-truth DE mask.
confs = np.zeros((n_exps, 2, 2))
for i in range(n_exps):
    confs[i, :, :] = confusion_matrix(is_significant_de, y_preds_1d[i, :])

In [None]:
# Confusion matrix of the first experiment only, for inspection.
confusion_matrix(is_significant_de, y_preds_1d[0, :])

In [None]:
# Element-wise average of the confusion matrices over all experiments.
confs_mean = confs.mean(0)
confs_mean

In [None]:
# `ff` was never imported in this notebook; import the figure factory where it is used.
import plotly.figure_factory as ff

# Annotated heatmap of the averaged confusion matrix.
fig = ff.create_annotated_heatmap(
    z=confs_mean, x=["Pred Negative", "Pred Positive"], y=["GT Negative", "GT Positive"]
)
fig.update_layout(title="Confusion Matrix")

# `py` is the top-level `plotly` package, which exposes no `iplot`; render inline with
# `fig.show()` like every other figure in this notebook.
fig.show()

### PR Curves

In [None]:
# Only the `print` below produces output: a bare expression is displayed only when it is
# the last statement of a cell, so the two `.shape` lines are evaluated and discarded.
de_probas_mf.shape
de_probas_ia.shape
print(other_predictions['mast']['pval'].shape)

In [None]:
from sklearn.metrics import precision_recall_curve

# Scores for the PR curves: a higher score must mean "more likely DE".
# The VAE outputs are already DE probabilities (they are what `predict_de_genes`
# consumes in `compute_stats`), so they are used as-is; taking `1 - p` would rank
# the genes in reverse. Copies are taken so later in-place edits of the scores
# cannot touch `de_probas_mf` / `de_probas_ia`.
# Index -1 selects the last (largest) entry of SIZES, index 0 the first training/pick.
preds_md = de_probas_mf[0, -1, 0, :].copy()
preds_iaf = de_probas_ia[0, -1, 0, :].copy()
# For the competitors a small p-value signals DE, hence the flip.
# Presumably indexed as [size, pick, gene] — confirm against `all_predictions`.
preds_deseq2 = 1.0 - other_predictions['deseq2']['pval'][-1, 0, :]
preds_edger = 1.0 - other_predictions['edger']['pval'][-1, 0, :]
preds_mast = 1.0 - other_predictions['mast']['pval'][-1, 0, :]

In [None]:
from sklearn.metrics import precision_recall_curve, average_precision_score

def plot_pr(fig, preds, y_true, name):
    """
    Add the precision-recall curve of one method to ``fig``.

    NaN scores are replaced by the smallest non-NaN score before any metric is
    computed, so missing predictions rank last. The caller's array is left untouched.

    :param fig: Plotly figure that receives the new trace
    :param preds: 1-D array of scores, higher meaning "more likely positive"
    :param y_true: Ground-truth binary labels aligned with ``preds``
    :param name: Legend label; the average precision is appended to it
    :raises ValueError: if every score is NaN
    :return: None
    """
    # Work on a float copy: the original code overwrote NaNs in the caller's array.
    scores = np.array(preds, dtype=float)
    missing = np.isnan(scores)
    if missing.all():
        raise ValueError("All predictions are NaN for '{}'".format(name))
    if missing.any():
        scores[missing] = scores[~missing].min()

    # Both metrics now see the same NaN-free scores (AP used to be computed on the raw input).
    average_precision = average_precision_score(y_true, scores)
    # The score argument is passed positionally because its keyword name differs
    # between scikit-learn releases.
    precs, recs, _ = precision_recall_curve(y_true, scores)
    fig.add_trace(
        go.Scatter(
            x=recs,
            y=precs,
            name=name+'@AP: {0:0.2f}'.format(average_precision)
        )
    )
    return
# Shared layout for the precision-recall comparison.
layout = go.Layout(
    title='Precision Recall Curves',
    xaxis=dict(title='Recall'),
    yaxis=dict(title='Precision'),
    width=800,
    height=600,
)
fig = go.Figure(layout=layout)

# One curve per method, all scored against the same ground-truth DE mask.
for method_preds, method_name in (
    (preds_md, 'MF'),
    (preds_iaf, 'IAF'),
    (preds_deseq2, 'DESeq2'),
    (preds_edger, 'EdgeR'),
    (preds_mast, 'MAST'),
):
    plot_pr(fig=fig, preds=method_preds, y_true=is_significant_de, name=method_name)

fig.show()

## Diagonal Curve

In [None]:
# Ground-truth LFC per gene with the sign flipped, i.e. lfc[:, 0] - lfc[:, 1].
lfc_gt = -(dataset.lfc[:, 1] - dataset.lfc[:, 0])

In [None]:
# LFC estimates for a single sample size; the result is keyed by size, hence `[SIZE]`.
# Note the learning rate here is 1e-3, whereas the DE-probability runs used 1e-4.
lfcs_mf = estimate_lfc_density(
    VAE,
    dataset=dataset,
    mdl_params=dict(n_hidden=128, n_layers=1, n_latent=5),
    train_params=dict(ratio_loss=True, test_indices=TEST_INDICES),
    train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-3),
    sizes=[SIZE],
    n_picks=1,
)[SIZE]

# Same experiment for the IAVAE model.
lfcs_ia = estimate_lfc_density(
    IAVAE,
    dataset=dataset,
    mdl_params=dict(n_hidden=128, n_layers=1, do_h=True, n_latent=5, t=4),
    train_params=dict(ratio_loss=True, test_indices=TEST_INDICES),
    train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-3),
    sizes=[SIZE],
    n_picks=1,
)[SIZE]

Plot diagonal curve showing quality of LFC prediction

In [None]:
from plotly.subplots import make_subplots

# One row per LFC draw, one column per gene.
# The two assignments used to be swapped, so each panel showed the other model's data.
lfcs_mf_est = lfcs_mf.reshape((-1, n_genes))
lfcs_ia_est = lfcs_ia.reshape((-1, n_genes))


fig = make_subplots(
    rows=1,
    cols=2,
    subplot_titles=("Mean Field", "Inverse Autoregressive Flows"),
    shared_xaxes=True,
    shared_yaxes=True,
)


def add_plot(fig, lfcs_est_m, lfcs_est_err, row, col):
    """
    Scatter the estimated LFC against the ground truth with vertical error bars.

    :param fig: Subplot figure that receives the trace
    :param lfcs_est_m: Per-gene LFC estimate (y values); x values come from the
        global ``lfc_gt``
    :param lfcs_est_err: Per-gene error-bar half-length
    :param row: Subplot row (1-based)
    :param col: Subplot column (1-based)
    :return: None
    """
    fig.add_trace(
        go.Scatter(
            x=lfc_gt,
            y=lfcs_est_m,
            error_y=dict(type="data", array=lfcs_est_err, visible=True),
            mode="markers",
        ),
        row=row,
        col=col,
    )
    return


# Left panel: mean field, right panel: IAF. Each panel gets the estimates, the
# y = x reference line and its own x-axis title.
for col, lfcs_est in enumerate((lfcs_mf_est, lfcs_ia_est), start=1):
    add_plot(fig, lfcs_est.mean(0), lfcs_est.std(0), row=1, col=col)
    fig.add_trace(
        go.Scatter(
            x=[-3, 3],
            y=[-3, 3],
            mode="lines",
            line=dict(color="black", width=4, dash="dash"),
        ),
        row=1,
        col=col,
    )
    # The figure has a single row; the previous calls targeted a nonexistent row 2.
    fig.update_xaxes(title_text="Ground Truth LFC", row=1, col=col)

# The predicted LFC is on the y axis (it used to overwrite the x-axis title).
fig.update_yaxes(title_text="Predicted LFC", row=1, col=1)

# `SIZE` is the sample size used to compute the estimates; `sz` was never defined.
fig.update_layout(
    height=600, width=1000, title_text="LFC estimation for {} sample cells".format(SIZE)
)
fig.show()
# iplot(fig, filename='lfc_with_uncertainty')

## Study of LFC errors

In [None]:
# Same definition as in the "Diagonal Curve" section, repeated so this section can run
# without it: ground-truth LFC per gene with the sign flipped.
lfc_gt = -(dataset.lfc[:, 1] - dataset.lfc[:, 0])

def l2_err(vals: np.ndarray, other: np.ndarray = None):
    """
    Half the mean absolute deviation along the last axis, ignoring NaNs.

    Despite the name, ``(diff ** 2) ** 0.5`` is the element-wise absolute value,
    so the result equals ``0.5 * nanmean(|diff|, axis=-1)``.
    NOTE(review): confirm whether a squared / root-mean-squared error was intended.

    :param vals: Array of values, or of precomputed differences when ``other`` is None
    :param other: Optional reference subtracted from ``vals`` (broadcast by NumPy)
    :return: Array with the last axis reduced
    """
    diff = vals if other is None else vals - other
    half_abs_diff = 0.5 * (diff ** 2) ** (0.5)
    return np.nanmean(half_abs_diff, axis=-1)

In [None]:
# Mean LFC predictions for every sample size; the result is read below as a mapping
# from size to an array of predictions.
# NOTE(review): unlike the earlier cells, `train_params` passes `train_size=0.7` together
# with `test_indices` — confirm the two options are meant to be combined.
lfcs_pred_mf = estimate_lfc_mean(
    VAE,
    dataset=dataset,
    mdl_params=dict(n_hidden=128, n_layers=1, n_latent=5),
    train_params=dict(train_size=0.7, ratio_loss=True, test_indices=TEST_INDICES),
    train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-3),
    sizes=SIZES,
    n_picks=N_PICKS,
)

# Same experiment for the IAVAE model.
lfcs_pred_ia = estimate_lfc_mean(
    IAVAE,
    dataset=dataset,
    mdl_params=dict(n_hidden=128, n_layers=1, do_h=True, n_latent=5, t=4),
    train_params=dict(train_size=0.7, ratio_loss=True, test_indices=TEST_INDICES),
    train_fn_params=dict(n_epochs=N_EPOCHS, lr=1e-3),
    sizes=SIZES,
    n_picks=N_PICKS,
)

In [None]:
# Look the predictions up by size, in SIZES order, so that row i of each error array is
# guaranteed to belong to SIZES[i] (the x axis used by `plot_errors`) instead of relying
# on the iteration order of the returned mapping.
lfcs_errs_mf = np.array([l2_err(lfcs_pred_mf[size], other=lfc_gt) for size in SIZES])
lfcs_errs_ia = np.array([l2_err(lfcs_pred_ia[size], other=lfc_gt) for size in SIZES])

In [None]:
# Debug: compare the shapes of our errors, the DESeq2 LFCs, the ground truth and one prediction.
print(lfcs_errs_mf.shape, other_predictions["deseq2"]['lfc'].shape, lfc_gt.shape, lfcs_pred_mf[100].shape)

In [None]:
# Debug: raw residuals for sample size 100 (displays the full array).
lfcs_pred_mf[100] - lfc_gt

In [None]:
# Errors of the competitor methods, reduced over the last (gene) axis by `l2_err`.
lfcs_errs_deseq2 = l2_err(other_predictions["deseq2"]["lfc"], other=lfc_gt)
lfcs_errs_edger = l2_err(other_predictions["edger"]["lfc"], other=lfc_gt)
lfcs_errs_mast = l2_err(other_predictions["mast"]["lfc"], other=lfc_gt)

In [None]:
# Debug: raw DESeq2 residuals for the last size and first pick (displays the full array).
other_predictions["deseq2"]["lfc"][-1, 0] - lfc_gt

In [None]:
# Debug: scalar error of DESeq2 for the last size and first pick.
l2_err(other_predictions["deseq2"]["lfc"][-1, 0], lfc_gt)

In [None]:
# `plot_errors` expects one row per sample size and one column per pick.
# NOTE(review): only two of the five arrays plotted below are checked here.
assert lfcs_errs_mf.shape == (N_SIZES, N_PICKS)
assert lfcs_errs_deseq2.shape == (N_SIZES, N_PICKS)

In [None]:
def plot_errors(fig, errors, color="red", name=""):
    """
    Draw the mean error per sample size with a +/- one standard deviation band.

    ``errors`` should be (n_sizes, n_picks): mean and standard deviation are taken
    over axis 1 and plotted against the global ``SIZES``.

    :param fig: Plotly figure that receives three traces (lower edge, upper edge, mean)
    :param errors: Error array of shape (n_sizes, n_picks)
    :param color: Line colour shared by the three traces
    :param name: Legend label of the mean curve (the band edges are hidden from the legend)
    :return: None
    """
    center = errors.mean(1)
    spread = errors.std(1)

    lower_edge = go.Scatter(
        x=SIZES,
        y=center - spread,
        mode="lines",
        line_color=color,
        showlegend=False,
    )
    # "tonexty" fills the area down to the previously added trace, i.e. the lower edge.
    upper_edge = go.Scatter(
        x=SIZES,
        y=center + spread,
        line_color=color,
        mode="lines",
        fill="tonexty",
        showlegend=False,
    )
    mean_curve = go.Scatter(
        x=SIZES, y=center, line_color=color, mode="lines+markers", name=name
    )

    # The order matters: the upper edge must directly follow the lower edge.
    for trace in (lower_edge, upper_edge, mean_curve):
        fig.add_trace(trace)
    return

In [None]:
fig = go.Figure()
fig.update_layout(
    title=go.layout.Title(text="L2 LFC Prediction Error"),
    xaxis=go.layout.XAxis(title=go.layout.xaxis.Title(text="Number of sampled cells")),
    yaxis=go.layout.YAxis(title=go.layout.yaxis.Title(text="Error")),
)

# One error band per method: our two posteriors first, then the competitors.
for method_errors, method_color, method_name in (
    (lfcs_errs_mf, "red", "Mean Field"),
    (lfcs_errs_ia, "blue", "IAF"),
    (lfcs_errs_deseq2, "green", "DESeq2"),
    (lfcs_errs_edger, "purple", "EdgeR"),
    (lfcs_errs_mast, "orange", "MAST"),
):
    plot_errors(fig, method_errors, color=method_color, name=method_name)
fig.show()
# iplot(fig, filename='lfc_size_influence')

# Debug

In [None]:
# trace1 = go.Histogram(x=dataset.X.mean(0))
# fig = go.Figure(data=[trace1])
# fig.show()

# dataset.X

# # from torch.autograd import set_detect_anomaly
# # set_detect_anomaly(True)

# vae = VAE(
#     dataset.nb_genes, n_batch=dataset.n_batches, n_hidden=128, n_layers=1, n_latent=5
# )
# trainer = UnsupervisedTrainer(
#     vae, dataset, ratio_loss=True, k_importance_weighted=5, single_backward=False
# )
# trainer.train(n_epochs=50, lr=1e-3)

# trainer.train_losses

# iw_vae = VAE(
#     dataset.nb_genes, n_batch=dataset.n_batches, n_hidden=128, n_layers=1, n_latent=5
# )
# iw_trainer = UnsupervisedTrainer(
#     iw_vae, dataset, ratio_loss=True, k_importance_weighted=5, single_backward=False
# )
# iw_trainer.train(n_epochs=50, lr=1e-3)

# iw_trainer.train_losses

# iavae = IAVAE(
#     dataset.nb_genes,
#     n_batch=dataset.n_batches,
#     n_hidden=128,
#     n_layers=1,
#     do_h=True,
#     n_latent=5,
#     t=4,
# )
# ia_trainer = UnsupervisedTrainer(iavae, dataset, ratio_loss=True)
# ia_trainer.train(n_epochs=100, lr=1e-3)

# ia_trainer.train_losses

# plt.plot(trainer.train_losses[5:], label="MF")
# plt.plot(iw_trainer.train_losses[5:], label="IW")
# plt.plot(ia_trainer.train_losses[5:], label="IAF")
# # plt.yscale("log")
# plt.legend()

# ia_trainer.test_set.marginal_ll(n_mc_samples=100, ratio_loss=True)



# evidence_mf = trainer.test_set.marginal_ll(n_mc_samples=100, ratio_loss=True)
# # evidence_iw = iw_trainer.test_set.marginal_ll(n_mc_samples=100, ratio_loss=True)
# evidence_ia = ia_trainer.test_set.marginal_ll(n_mc_samples=100, ratio_loss=True)
# print(evidence_mf, evidence_iw, evidence_ia)

4395.42928125 4369.418739583333 4391.585421875

The lower the better

Based on very quick experiments (I am still waiting for the autotune module), the results look as follows:

MF
- 1 layer 16 941.0715670955882
- 1 layer 64 935.0418129595588
- 1 layer 128 925.3853147977941
- 3 layers 32 972.2561259191176

IAF
- 1 layer 16 950.423221507353
- 1 layer 64 928.0418129595588
- 1 layer 128 917.7115165441177
- 3 layers 32 928.8929044117647
- t=4 1 layer 128 915.8975482536765
- t=5 1 layer 128 920.9016911764705

### Visualizing IAF posteriors

In [None]:
# NOTE(review): `ia_trainer` and `iavae` are only created in the commented-out debug cell
# above, so this cell raises NameError on a fresh kernel.
post = ia_trainer.train_set
train_indices = post.data_loader.sampler.indices
# Random subset of 2000 training cells, reused by the mean-field cell below.
train_samples = np.random.permutation(train_indices)[:2000]
post = ia_trainer.create_posterior(
    model=iavae, gene_dataset=dataset, indices=train_samples
)
z_ia, labels_ia, scales_ia = post.get_latents(n_samples=500, other=True, device="cpu")

In [None]:
# NOTE(review): `trainer` and `vae` are only created in the commented-out debug cell above.
# The same `train_samples` as for the IAF posterior are reused so both models are
# evaluated on identical cells.
# post = trainer.train_set
# train_indices = post.data_loader.sampler.indices
# train_samples = np.random.permutation(train_indices)[:2000]
post = trainer.create_posterior(model=vae, gene_dataset=dataset, indices=train_samples)
z, labels, scales = post.get_latents(n_samples=500, other=True, device="cpu")

In [None]:
# Fraction of positions where both posteriors return the same label
# (presumably 1.0 if the cells come back in the same order — confirm).
(labels == labels_ia).float().mean()

In [None]:
# For a handful of hand-picked cells, overlay the first two latent coordinates of the
# sampled latents from both posteriors (first trace: `z`, second trace: `z_ia`).
for idx in [5, 10, 100, 1000, 2, 30, 542]:
    trace1 = go.Scatter(x=z[:, idx, 0], y=z[:, idx, 1], mode="markers")
    trace2 = go.Scatter(x=z_ia[:, idx, 0], y=z_ia[:, idx, 1], mode="markers")
    fig = go.Figure([trace1, trace2])
    fig.show()

In [None]:
from plotly.subplots import make_subplots
from sklearn.manifold import TSNE


fig = make_subplots(rows=1, cols=2)

# (latent samples, labels) pairs to embed, one subplot column per pair.
# Left empty, as in the original cell, so nothing is computed until it is filled in,
# e.g. with [(z, labels), (z_ia, labels_ia)].
latents_and_labels = []

for col, (my_z, my_lbl) in enumerate(latents_and_labels, start=1):
    # Average over the sample axis, then embed the per-cell means in 2-D.
    z_mean = my_z.mean(0)
    z_tsne = TSNE().fit_transform(z_mean)
    #     z_tsne = z_mean

    # The loop body used to read `z_tnse` (typo) and an undefined `my_lbl`.
    fig.add_trace(
        go.Scatter(
            x=z_tsne[:, 0],
            y=z_tsne[:, 1],
            marker=dict(color=my_lbl.squeeze(), colorscale="viridis"),
            mode="markers",
        ),
        row=1,
        col=col,
    )

fig

In [None]:
# NOTE(review): at top level `is_pred_de` is only assigned in a later cell, so this cell
# fails on a fresh kernel when run in order.
is_pred_de.astype(bool)

In [None]:
# Number of cells drawn per cluster in the manual FDR computation below.
size = 25

In [None]:
# Manual FDR computation on the mean-field posterior samples (`labels`, `scales`).
labels = labels.squeeze()
where_a = np.where(labels == 0)[0]
where_b = np.where(labels == 1)[0]
# Draw `size` cells per cluster (with replacement).
where_a = where_a[np.random.choice(len(where_a), size=size)]
where_b = where_b[np.random.choice(len(where_b), size=size)]
scales_a = scales[:, where_a, :].reshape((-1, dataset.nb_genes)).numpy()
scales_b = scales[:, where_b, :].reshape((-1, dataset.nb_genes)).numpy()
scales_a, scales_b = demultiply(arr1=scales_a, arr2=scales_b, factor=3)
lfc = np.log2(scales_a) - np.log2(scales_b)

# Per-gene probability that |LFC| reaches DELTA.
pgs = (np.abs(lfc) >= DELTA).mean(axis=0)
sorted_genes = np.argsort(-pgs)
sorted_pgs = pgs[sorted_genes]
# Expected FDR when the k most probable genes are called DE.
cumulative_fdr = (1.0 - sorted_pgs).cumsum() / (1.0 + np.arange(len(sorted_pgs)))
# Clamp at 0: when no prefix satisfies the target the count is 0, and the former
# `d = -1` made `sorted_genes[:d]` select every gene but the last one.
# NOTE(review): the `- 1` is kept from the original and drops the last qualifying gene.
d = max(int((cumulative_fdr <= Q0).sum()) - 1, 0)
pred_de_genes = sorted_genes[:d]
# Boolean mask, so that it combines cleanly with the boolean ground truth.
is_pred_de = np.zeros(len(cumulative_fdr), dtype=bool)
is_pred_de[pred_de_genes] = True
n_false_discoveries = ((~is_significant_de) & is_pred_de).sum()
# With no discoveries the FDR is taken to be 0 instead of dividing by zero.
true_fdr = n_false_discoveries / len(pred_de_genes) if len(pred_de_genes) > 0 else 0.0

true_fdr

In [None]:
# Fraction of genes predicted as DE.
is_pred_de.mean()

In [None]:
# Debug: inspect the scales of the first group (displays the array).
scales_a

In [None]:
# Debug: per-gene DE probabilities in decreasing order.
sorted_pgs

In [None]:
from sklearn.metrics import confusion_matrix

# Confusion matrix of the manual DE calls against the ground truth.
# `y_preds_1d` and `n_exps` are assigned but not used in this cell.
y_preds_1d = is_pred_de.reshape((-1, dataset.nb_genes))
n_exps = len(y_preds_1d)
mat = confusion_matrix(is_significant_de, is_pred_de)
mat