In [1]:
import numpy as np
import os
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch
import math

save_dir = '../../results/trained_models/'
data_name = 'mouse_gastrulation'
random_seed = 0

# Apply the seed so any stochastic torch / numpy operations later in the notebook are reproducible
# (previously `random_seed` was defined but never used).
torch.manual_seed(random_seed)
np.random.seed(random_seed)

  from .autonotebook import tqdm as notebook_tqdm


## Helper functions

In [14]:
import torch

def compute_error_per_sample(target, output, reduction_type='ms'):
    '''Per-sample reconstruction error, reduced over the last axis.

    Parameters
    ----------
    target, output : torch.Tensor
        Tensors of matching (broadcastable) shape; the error is ``target - output``.
    reduction_type : {'ms', 'ma'}
        ``'ms'`` gives the mean squared error, ``'ma'`` the mean absolute error.

    Raises
    ------
    ValueError
        If ``reduction_type`` is neither ``'ms'`` nor ``'ma'``.
    '''
    residual = target - output
    if reduction_type == 'ms':
        per_element = torch.square(residual)
    elif reduction_type == 'ma':
        per_element = torch.abs(residual)
    else:
        raise ValueError('invalid reduction type given. Can only be `ms` or `ma`.')
    return torch.mean(per_element, dim=-1)

In [3]:
def binary_output_scores(target, output, scaling_factor, switch, threshold, batch_size=5000, feature_indices=None):
    '''Pooled binary-classification scores for binarized predictions.

    The per-sample confusion counts from ``classify_binary_output`` are summed over all
    samples before any rate is formed (a pooled / "micro" average).

    Returns
    -------
    tuple of float
        (TPR / sensitivity, TNR / specificity, balanced accuracy, LR+, LR-),
        where LR+ = TPR / FPR and LR- = FNR / TNR. For floating-point counts a zero
        denominator yields ``inf`` or ``nan`` rather than raising.
    '''
    counts = classify_binary_output(target, output, scaling_factor, switch, threshold, batch_size, feature_indices)
    tp, fp, tn, fn = (per_sample.sum() for per_sample in counts)
    tpr = tp / (tp + fn)  # sensitivity
    tnr = tn / (tn + fp)  # specificity
    fpr = 1 - tnr
    fnr = 1 - tpr
    balanced_accuracy = (tpr + tnr) / 2
    positive_likelihood_ratio = tpr / fpr
    negative_likelihood_ratio = fnr / tnr

    return tpr.item(), tnr.item(), balanced_accuracy.item(), positive_likelihood_ratio.item(), negative_likelihood_ratio.item()

def balanced_accuracy_with_sem(target, output, scaling_factor, switch, threshold, batch_size=5000, feature_indices=None):
    '''Return the pooled balanced accuracy and the SEM of the per-sample balanced accuracy.

    The classification is run once. Both statistics are derived from the same per-sample
    confusion counts.

    Returns
    -------
    (ba_mean, ba_sem) : tuple of float
        ba_mean -- balanced accuracy from counts pooled over all samples (same formula as
                   ``binary_output_scores``)
        ba_sem  -- standard deviation of the per-sample balanced accuracies divided by
                   sqrt(n_samples)

    Note that the mean is pooled while the SEM is computed from per-sample values.
    NOTE(review): a sample with no positives (or no negatives) gives 0/0 = NaN for its
    balanced accuracy, which makes the SEM NaN.
    '''
    tp, fp, tn, fn = classify_binary_output(target, output, scaling_factor, switch, threshold, batch_size, feature_indices)

    # pooled ("micro") balanced accuracy, identical to binary_output_scores' third value
    tp_total, fp_total, tn_total, fn_total = tp.sum(), fp.sum(), tn.sum(), fn.sum()
    pooled_tpr = tp_total / (tp_total + fn_total)
    pooled_tnr = tn_total / (tn_total + fp_total)
    ba_mean = ((pooled_tpr + pooled_tnr) / 2).item()

    # per-sample balanced accuracy, used only for the standard error
    per_sample_ba = (tp / (tp + fn) + tn / (tn + fp)) / 2
    ba_error = per_sample_ba.std() / math.sqrt(per_sample_ba.shape[0])

    return ba_mean, ba_error.item()

def classify_binary_output(target, output, scaling_factor, switch, threshold, batch_size=5000, feature_indices=None):
    '''Count per-sample true/false positives/negatives of binarized predictions.

    Parameters
    ----------
    target : array-like indexable as ``target[rows, :]`` (e.g. a NumPy array of observed
        counts). Entries >= 0.5 are treated as positive.
    output : predictions with the same rows/columns as ``target``.
        - A ``torch.Tensor`` (or any other array-like, converted via NumPy) is multiplied
          by ``scaling_factor[rows]`` before thresholding.
        - A ``pandas.DataFrame`` is used unscaled.
    scaling_factor : tensor indexed by row and broadcastable against a batch of predictions
        (callers in this notebook pass shape (n_samples, 1)).
    switch : unused; kept for signature compatibility with the other evaluation helpers.
    threshold : predictions >= threshold are treated as positive.
    batch_size : number of rows processed at once.
    feature_indices : optional column subset to evaluate instead of all columns.

    Returns
    -------
    (true_positives, false_positives, true_negatives, false_negatives)
        Four float tensors of length n_samples.
    '''
    n_samples = target.shape[0]
    true_positives = torch.zeros((n_samples))
    false_positives = torch.zeros((n_samples))
    true_negatives = torch.zeros((n_samples))
    false_negatives = torch.zeros((n_samples))

    # step through the rows in batches; unlike range(int(n / batch_size) + 1) this never
    # produces an empty trailing batch
    for start in range(0, n_samples, batch_size):
        end = min(start + batch_size, n_samples)
        indices = np.arange(start, end)
        observed = torch.Tensor(target[indices, :])
        if isinstance(output, torch.Tensor):
            predicted = output[indices, :].detach().cpu() * scaling_factor[indices]
        elif isinstance(output, pd.DataFrame):
            # NOTE(review): DataFrame predictions are not rescaled (kept from the original) — confirm intended
            predicted = torch.from_numpy(output.iloc[indices].to_numpy())
        else:
            predicted = torch.as_tensor(np.asarray(output[indices, :])) * scaling_factor[indices]
        if feature_indices is not None:
            observed = observed[:, feature_indices]
            predicted = predicted[:, feature_indices]
        # Threshold with a single comparison. A two-step in-place binarization
        # (set >= t to 1, then < t to 0) zeroes everything when the threshold exceeds 1.
        p = observed >= 0.5
        pp = predicted >= threshold
        true_positives[indices] = torch.logical_and(p, pp).sum(-1).float()
        true_negatives[indices] = torch.logical_and(~p, ~pp).sum(-1).float()
        false_positives[indices] = torch.logical_and(~p, pp).sum(-1).float()
        false_negatives[indices] = torch.logical_and(p, ~pp).sum(-1).float()

    return true_positives, false_positives, true_negatives, false_negatives

def binarize(x, threshold=0.5):
    '''Set entries of ``x`` to 1 where ``x >= threshold`` and to 0 elsewhere, in place.

    Works for NumPy arrays and torch tensors (anything supporting boolean-mask assignment).
    The mask is computed once, before any entry is overwritten, so thresholds above 1 work;
    a second comparison against the freshly written 1s would zero them. NaN entries become 0.

    Returns
    -------
    The same object ``x``, modified in place.
    '''
    positive = x >= threshold
    x[positive] = 1
    x[~positive] = 0
    return x

# Recalculate per-sample losses with the standard error of the mean (SEM)

In [9]:
# read the held-out GEX counts and the model's reconstruction of them from the saved .npy files
test_gex, recon_gex = (
    np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_test_counts_gex.npy'),
    np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_l20_h2-2_rs0_test_recon_gex.npy'),
)

In [7]:
n_samples = test_gex.shape[0]
# reconstructions are rescaled by each row's total observed count before comparison
errors = compute_error_per_sample(
    torch.tensor(test_gex),
    torch.tensor(recon_gex) * torch.tensor(test_gex).sum(axis=1).unsqueeze(1),
    reduction_type='ms',
)
out_errors = errors.sqrt()  # per-sample RMSE
out_error_mean = out_errors.mean()
out_error_se = out_errors.std() / math.sqrt(n_samples)
print('RMSE: ', out_error_mean.item(), ' +/- ', out_error_se.item())

RMSE:  1.714585781097412  +/-  0.012759659439325333


In [4]:
# read the held-out ATAC counts and the model's reconstruction of them from the saved .npy files
test_atac, recon_atac = (
    np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_test_counts_atac.npy'),
    np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_l20_h2-2_rs0_test_recon_atac.npy'),
)

In [16]:
# compute loss for ATAC data: balanced accuracy of binarized reconstructions

threshold = 0.2
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(
    test_atac,
    torch.tensor(recon_atac),
    scaling_factor=torch.tensor(test_atac).sum(1).unsqueeze(1),  # row sums of the observed ATAC counts
    switch=test_gex.shape[1],
    threshold=threshold,
)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)

balanced accuracy:  0.7324193716049194  +/-  0.0006929467199370265


### Create a file with the feature-selection indices

In [8]:
# NOTE: this cell is disabled. The code below sits inside a string literal and is not
# executed; it is kept as a record of how the RNA/ATAC column positions of the feature
# subset (relative to the full-feature test set) were derived. The resulting CSV is read
# back in a later cell.
# NOTE(review): the commented-out save path inside the string differs from the path the
# notebook later reads the indices from — confirm which file is authoritative.
# get the data and the indices of the overlapping features
# first the full set
"""
import anndata as ad
import numpy as np
import mudata as md
gex = ad.read_h5ad('../../data/mouse_gastrulation/raw/anndata.h5ad')
atac = ad.read_h5ad('../../data/mouse_gastrulation/raw/PeakMatrix_anndata.h5ad')
ids_shared = list(set(gex.obs['sample'].index.values).intersection(set(atac.obs['sample'].index.values)))
ids_gex = np.where(gex.obs['sample'].index.isin(ids_shared))[0]
ids_atac = np.where(atac.obs['sample'].index.isin(ids_shared))[0]
gex = gex[ids_gex]
atac = atac[ids_atac]
threshold = 0.00
mudata = md.MuData({'rna': gex, 'atac': atac})
mudata.obs['stage'] = mudata['atac'].obs['stage']
mudata.obs['celltype'] = mudata['rna'].obs['celltype']
train_indices = np.where(mudata.obs["train_val_test"] == "train")[0]
trainset = mudata[train_indices,:].copy()
test_indices = np.where(mudata.obs["train_val_test"] == "test")[0]
testset = mudata[test_indices,:].copy()
mudata, gex, atac = None, None, None
modality_switch_full = testset['rna'].X.shape[1]

# now the subset
data_subset = md.read('../../data/mouse_gastrulation.h5mu', backed=False)
rna_indices = [i for i, x in enumerate(testset.var.index) if x in data_subset['rna'].var.index]
atac_indices = [i-modality_switch_full for i, x in enumerate(testset.var.index) if x in data_subset['atac'].var.index]

# save indices as csv file
indices_df = pd.concat(
    (pd.DataFrame({'idx': rna_indices,
                           'modality': 'rna'}),
    pd.DataFrame({'idx': atac_indices,
                           'modality': 'atac'})), axis=0
)
#indices_df.to_csv('data/mouse_gastrulation/five_percent_indices.csv')
"""

  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col


In [2]:
# selected-feature column positions, one row per feature, tagged by modality
indices_df = pd.read_csv('../../../data/mouse_gastrulation_five_percent_indices.csv')
rna_indices, atac_indices = (
    indices_df.loc[indices_df['modality'] == modality, 'idx'].values
    for modality in ('rna', 'atac')
)

In [11]:
# reconstruction from the featselect0 model (presumably trained on all features), restricted to the selected RNA columns
recon_gex = np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_l20_h2-2_rs0_scale5_featselect0_test_recon_gex.npy')[:, rna_indices]

In [12]:
import torch
import math
n_samples = test_gex.shape[0]
# NOTE(review): recon_gex was sliced to the selected features, and the rescaling uses the row
# sums of test_gex (presumably only the selected features) — confirm this library size is intended
errors = compute_error_per_sample(
    torch.tensor(test_gex),
    torch.tensor(recon_gex) * torch.tensor(test_gex).sum(axis=1).unsqueeze(1),
    reduction_type='ms',
)
out_errors = errors.sqrt()  # per-sample RMSE
out_error_mean = out_errors.mean()
out_error_se = out_errors.std() / math.sqrt(n_samples)
print('RMSE: ', out_error_mean.item(), ' +/- ', out_error_se.item())

RMSE:  1.6939202547073364  +/-  0.01272590272128582


In [7]:
# ATAC reconstruction from the featselect0 model, restricted to the selected ATAC columns
recon_atac = np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_l20_h2-2_rs0_scale5_featselect0_test_recon_atac.npy')[:, atac_indices]

In [10]:
# compute loss for ATAC data # old version

threshold = 0.2
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(
    test_atac,
    torch.tensor(recon_atac),
    scaling_factor=torch.tensor(test_atac).sum(1).unsqueeze(1),  # row sums of the observed (selected) ATAC counts
    switch=test_gex.shape[1],
    threshold=threshold,
)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)

balanced accuracy:  0.739676833152771  +/-  0.0007066355901770294


In [23]:
# compute loss for ATAC data
# NOTE: disabled. This evaluation is kept inside a string literal and does not run;
# the printed output shown under this cell comes from an earlier, executed version.
"""
threshold = 0.2
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(test_atac, torch.tensor(recon_atac), torch.tensor(test_atac).sum(1).unsqueeze(1), test_gex.shape[1], threshold)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)
"""

classifying binary output
0 %
100 %
balanced accuracy:  0.6792690753936768  +/-  0.0007066355901770294


### MultiVI (MVI)

In [5]:
import anndata as ad
import numpy as np
import mudata as md
# Rebuild the full-feature paired data from the raw per-modality files.
gex = ad.read_h5ad('../../../data/raw/mouse_gastrulation_anndata.h5ad')
atac = ad.read_h5ad('../../../data/raw/mouse_gastrulation_PeakMatrix_anndata.h5ad')
# keep only the cells that appear in both modalities
ids_shared = list(set(gex.obs['sample'].index.values).intersection(set(atac.obs['sample'].index.values)))
ids_gex = np.where(gex.obs['sample'].index.isin(ids_shared))[0]
ids_atac = np.where(atac.obs['sample'].index.isin(ids_shared))[0]
gex = gex[ids_gex]
atac = atac[ids_atac]
# NOTE: no `threshold` reset here. A `threshold = 0.00` at this point was unused and
# silently overwrote the evaluation threshold set by earlier cells.
mudata = md.MuData({'rna': gex, 'atac': atac})
# take the cell metadata (incl. the train/val/test labels) from the processed MuData file
# NOTE(review): this assumes both objects list the same cells in the same order — confirm,
# e.g. with mudata.obs.index.equals(mudata_original.obs.index)
mudata_original = md.read('../../../data/mouse_gastrulation.h5mu', backed=False)
mudata.obs = mudata_original.obs.copy()
mudata_original = None
train_indices = np.where(mudata.obs["train_val_test"] == "train")[0]
trainset = mudata[train_indices,:].copy()
test_indices = np.where(mudata.obs["train_val_test"] == "test")[0]
testset = mudata[test_indices,:].copy()
# number of RNA features (RNA columns come first in the combined matrices built later)
modality_switch_full = testset['rna'].X.shape[1]

  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col


In [6]:
# Per-cell library sizes of the test split: column 0 = RNA row sums, column 1 = ATAC row sums.
# reshape(-1, 1) gives an (n, 1) column whether X is sparse (sum returns an (n, 1) matrix)
# or dense (sum returns an (n,) array), so the concatenation works in both cases.
library = torch.cat(
    [torch.tensor(np.asarray(testset[modality].X.sum(-1)).reshape(-1, 1)) for modality in ('rna', 'atac')],
    dim=1,
)

In [7]:
# first the model on all data (MultiVI evaluated on the full feature set)

import scipy
import scvi
# keep the stage labels before the MuData splits are replaced by flat AnnData objects
train_stages, test_stages = (split.obs['stage'].values for split in (trainset, testset))
# now VAE: training AnnData with the RNA columns followed by the ATAC columns
trainset = ad.AnnData(scipy.sparse.hstack([trainset[modality].X for modality in ('rna', 'atac')]))
trainset.var_names_make_unique()
trainset.obs['stage'] = train_stages
trainset.obs['modality'] = 'paired'

  Referenced from: <08E12B12-6183-307E-BDA0-374FA8EBA2C9> /Users/dbm829/Library/Python/3.9/lib/python/site-packages/torchvision/image.so
  warn(
  self.seed = seed
  self.dl_pin_memory_gpu_training = (


In [14]:
# NOTE(review): bare expression left in to display `testset`; depending on execution order
# it holds the MuData test split or the flat AnnData built in the next cell
testset

In [8]:
scvi.model.MULTIVI.setup_anndata(trainset, batch_key='stage')
# test AnnData built the same way as the training one: RNA columns, then ATAC columns
testset = ad.AnnData(scipy.sparse.hstack([testset[modality].X for modality in ('rna', 'atac')]))
testset.var_names_make_unique()
testset.obs['stage'] = test_stages
testset.obs['modality'] = 'paired'
scvi.model.MULTIVI.setup_anndata(testset, batch_key='stage')

  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [9]:
data_name = "mouse_gastrulation"
model_name = "l20_e2_d2_rs0_featselect0"
# load the trained MultiVI weights against the training AnnData registered with setup_anndata
model = scvi.model.MULTIVI.load(f'{save_dir}multiVI/{data_name}/{model_name}', adata=trainset)

[34mINFO    [0m File ..[35m/../results/trained_models/multiVI/mouse_gastrulation/l20_e2_d2_rs0_featselect0/[0m[95mmodel.pt[0m already   
         downloaded                                                                                                


  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [10]:
# drop references that are no longer needed so the large objects can be garbage-collected
trainset = mudata = None

In [15]:
# reload the saved held-out count matrices used by the evaluations below
test_gex, test_atac = (
    np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_test_counts_gex.npy'),
    np.load('../../results/analysis/performance_evaluation/reconstruction/mouse_gast_test_counts_atac.npy'),
)

In [11]:
def compute_expression_error(target, mod, scaling_factor, switch, batch_size=5000, error_type='rmse', feature_indices=None):
    '''computes expression error for target (given as anndata object)'''
    n_samples = target.shape[0]

    errors = torch.zeros((n_samples))

    for i in range(int(n_samples/batch_size)+1):
        print('   ',round(i/(int(n_samples/batch_size))*100),'%')
        start = i*batch_size
        end = min((i+1)*batch_size,n_samples)
        indices = np.arange(start,end,1)
        #target.n_vars = switch # because of multivi
        y_expression = mod.get_normalized_expression(target, indices=indices)
        if type(y_expression) is not torch.Tensor:
            if type(y_expression) == pd.core.frame.DataFrame:
                y_expression = torch.from_numpy(y_expression.values)
        y_expression *= scaling_factor[indices]
        if feature_indices is not None:
            y_expression = y_expression[:,feature_indices]
            x_expression = torch.Tensor(target.X[indices,:switch].todense())[:,feature_indices]
        else:
            x_expression = torch.Tensor(target.X[indices,:switch].todense())
        #print(y_expression[:10,:10])
        #print(torch.Tensor(target.X[indices,:switch].todense())[:10,:10])
        errors[indices] = compute_error_per_sample(x_expression, y_expression, reduction_type='ms')
    
    return errors

In [18]:
# RNA error of the full-feature MultiVI model, restricted to the selected RNA features
mvi_rna_errors = compute_expression_error(
    testset,
    model,
    scaling_factor=library[:, 0].unsqueeze(1),  # column 0 of `library`: RNA row sums
    switch=modality_switch_full,
    batch_size=5000,
    feature_indices=rna_indices,
)

    0 %
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


    100 %


In [19]:
# sanity check: compute_expression_error returns one value per row of `testset`
mvi_rna_errors.shape

torch.Size([5686])

In [20]:
# per-sample RMSE and its standard error across samples
out_errors = torch.sqrt(mvi_rna_errors)
out_error_mean = out_errors.mean()
# the SEM denominator is the number of error values actually computed, not the row count of another array
out_error_se = out_errors.std() / math.sqrt(out_errors.shape[0])
print('RMSE: ', out_error_mean.item(), ' +/- ', out_error_se.item())

RMSE:  2.4637396335601807  +/-  0.021737542003393173


In [12]:
from omicsdgd.functions._analysis import classify_binary_output, binary_output_scores

def balanced_accuracy_with_sem(target, mod, scaling_factor, switch, threshold, batch_size=1000, feature_indices=None):
    '''Return the pooled balanced accuracy and the SEM of the per-sample balanced accuracy.

    Model-based variant. ``mod`` is forwarded to the imported ``classify_binary_output``
    (presumably a trained model that produces the predictions — confirm against omicsdgd).
    Defining this name again shadows any earlier definition in the notebook.

    The classification (and hence any model inference) runs once. Both statistics are
    derived from the same per-sample confusion counts.

    Returns
    -------
    (ba_mean, ba_sem) : tuple of float
        ba_mean -- balanced accuracy from counts pooled over all samples
        ba_sem  -- standard deviation of the per-sample balanced accuracies divided by
                   sqrt(n_samples)
    '''
    tp, fp, tn, fn = classify_binary_output(target, mod, scaling_factor, switch, threshold, batch_size, feature_indices)

    # NOTE(review): pooled formula mirrors the local binary_output_scores (sum the counts,
    # then form the rates); confirm it matches omicsdgd's binary_output_scores
    tp_total, fp_total, tn_total, fn_total = tp.sum(), fp.sum(), tn.sum(), fn.sum()
    pooled_tpr = tp_total / (tp_total + fn_total)
    pooled_tnr = tn_total / (tn_total + fp_total)
    ba_mean = ((pooled_tpr + pooled_tnr) / 2).item()

    # per-sample balanced accuracy, used only for the standard error
    per_sample_ba = (tp / (tp + fn) + tn / (tn + fp)) / 2
    ba_error = per_sample_ba.std() / math.sqrt(per_sample_ba.shape[0])

    return ba_mean, ba_error.item()

In [13]:
# compute loss for ATAC data # original mean calc

threshold = 0.5
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(
    testset,
    model,
    scaling_factor=torch.tensor(test_atac).sum(1).unsqueeze(1),  # row sums of the saved ATAC test counts
    switch=test_gex.shape[1],
    threshold=threshold,
    feature_indices=atac_indices,
)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)

classifying binary output
0 %
[34mINFO    [0m Input AnnData not setup with scvi-tools. attempting to transfer AnnData setup                             


  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


20 %
40 %
60 %
80 %
100 %
classifying binary output
0 %
20 %
40 %
60 %
80 %
100 %
balanced accuracy:  0.5196595191955566  +/-  4.284674650989473e-05


In [28]:
# compute loss for ATAC data
# NOTE: disabled. This evaluation is kept inside a string literal and does not run;
# the printed output shown under this cell comes from an earlier, executed version.
"""
threshold = 0.5
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(testset, model, torch.tensor(test_atac).sum(1).unsqueeze(1), test_gex.shape[1], threshold, feature_indices=atac_indices)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)
"""

classifying binary output
0 %
100 %
balanced accuracy:  0.499816358089447  +/-  4.2846764699788764e-05


### Now the original MultiVI model (on the feature subset)

In [5]:
import anndata as ad
import numpy as np
import mudata as md
import scipy
data_name = 'mouse_gastrulation'
mudata = md.read("../../../data/mouse_gastrulation.h5mu", backed=False)
modality_switch = mudata["rna"].X.shape[1]
# one AnnData with the RNA columns first, then the ATAC columns
adata = ad.AnnData(scipy.sparse.hstack((mudata["rna"].X, mudata["atac"].X)))
adata.obs = mudata.obs
mudata = None
# Label every column in a single assignment. The previous chained assignment
# `adata.var["feature_type"][:modality_switch] = "GEX"` writes into a temporary Series
# and does not modify adata.var under pandas Copy-on-Write.
adata.var["feature_type"] = np.where(np.arange(adata.n_vars) < modality_switch, "GEX", "ATAC")
train_indices = list(np.where(adata.obs["train_val_test"] == "train")[0])
test_indices = list(np.where(adata.obs["train_val_test"] == "test")[0])
adata.var_names_make_unique()
adata.obs["modality"] = "paired"
trainset = adata[train_indices]
testset = adata[test_indices]

  data_mod.loc[:, colname] = col
  data_mod.loc[:, colname] = col


In [7]:
import scvi
# register both splits with MultiVI, using the 'stage' obs column as the batch key
scvi.model.MULTIVI.setup_anndata(trainset, batch_key='stage')
scvi.model.MULTIVI.setup_anndata(testset, batch_key='stage')

  Referenced from: <08E12B12-6183-307E-BDA0-374FA8EBA2C9> /Users/dbm829/Library/Python/3.9/lib/python/site-packages/torchvision/image.so
  warn(
  self.seed = seed
  self.dl_pin_memory_gpu_training = (
  adata.obs["_indices"] = np.arange(adata.n_obs)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)
  adata.obs["_indices"] = np.arange(adata.n_obs)
  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [8]:
model_name = 'l20_e2_d2'
# MultiVI model for the feature subset, loaded against its training AnnData
model = scvi.model.MULTIVI.load(f'{save_dir}multiVI/{data_name}/{model_name}', adata=trainset)
# Drop the training split and the combined AnnData. `testset` was created by indexing
# `adata`, so the underlying data may stay alive through it.
trainset = adata = None

[34mINFO    [0m File ..[35m/../results/trained_models/multiVI/mouse_gastrulation/l20_e2_d2/[0m[95mmodel.pt[0m already downloaded        


  _verify_and_correct_data_format(adata, self.attr_name, self.attr_key)


In [16]:
# NOTE(review): the scaling uses row sums of the saved test_gex array while predictions come
# from `testset`; this assumes both list the same cells in the same order — confirm
mvi_rna_errors = compute_expression_error(testset, model, torch.tensor(test_gex).sum(1).unsqueeze(1), modality_switch, batch_size=5000)
out_errors = torch.sqrt(mvi_rna_errors)  # per-sample RMSE
out_error_mean = out_errors.mean()
# the SEM denominator is the number of error values actually computed
out_error_se = out_errors.std() / math.sqrt(out_errors.shape[0])
print('RMSE: ', out_error_mean.item(), ' +/- ', out_error_se.item())

    0 %
    100 %
RMSE:  2.342944383621216  +/-  0.02120981365442276


In [17]:
# compute loss for ATAC data # with original mean calc

threshold = 0.5
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(
    testset,
    model,
    scaling_factor=torch.tensor(test_atac).sum(1).unsqueeze(1),  # row sums of the saved ATAC test counts
    switch=test_gex.shape[1],
    threshold=threshold,
)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)

classifying binary output
0 %
20 %
40 %
60 %
80 %
100 %
classifying binary output
0 %
20 %
40 %
60 %
80 %
100 %
balanced accuracy:  0.7121416330337524  +/-  0.00034910012618638575


In [36]:
# compute loss for ATAC data
# NOTE: disabled. This evaluation is kept inside a string literal and does not run;
# the printed output shown under this cell comes from an earlier, executed version.
"""
threshold = 0.5
balanced_accuracy_mean, balanced_accuracy_sem = balanced_accuracy_with_sem(testset, model, torch.tensor(test_atac).sum(1).unsqueeze(1), test_gex.shape[1], threshold)
print('balanced accuracy: ', balanced_accuracy_mean, ' +/- ', balanced_accuracy_sem)
"""

classifying binary output
0 %
100 %
balanced accuracy:  0.6968663930892944  +/-  0.00034910012618638575
