In [None]:
### import comet_ml at the top of your file
# from comet_ml import Experiment  ### special library to record performance on a web server
# from config import * ### file with personal API_KEY

import os

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy import stats
from skbio.stats.composition import clr
from sklearn import preprocessing
from sklearn.model_selection import train_test_split
from torch.distributions import Categorical
from torch.utils.data import Dataset, TensorDataset, DataLoader
use_cuda = torch.cuda.is_available()
# Pick the compute device. The original setup targets the second GPU ('cuda:1'); fall back
# to the first GPU when only one exists, to Apple's MPS backend only when it is actually
# available, and to the CPU otherwise (an unconditional 'mps' fallback names a backend that
# does not exist on non-Apple machines without CUDA).
if use_cuda:
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
elif getattr(torch.backends, 'mps', None) is not None and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

### Record the experiment on comet.com (optional; currently disabled)

In [None]:
# Create and initialize comet experiment with your api key
# Disabled: the Experiment(...) call below is wrapped in a string literal, so it never runs.
# To enable logging, uncomment the comet_ml / config imports in the first cell (the config
# module supplies API_KEY) and remove the surrounding quotes. While disabled, Jupyter echoes
# this string as the cell's output because it is the cell's last expression.
'''
experiment = Experiment(
    api_key=API_KEY,
    project_name="vae-eco-ml",
    workspace="zireae1",
)
'''

### Load clean data

In [None]:
### data: rows are samples, columns are species, relative abundance sum ~ [80-100]
abundance_table = pd.read_csv("data/wgs_train_health.noab_data_to_ml.filt.txt", sep="\t")
# NOTE: despite its name, data_df holds a plain ndarray from here on.
data_df = abundance_table.to_numpy()
print(data_df.shape)

# The VAE reconstructs its own input, so inputs (features) and targets (labels)
# start out as two independent copies of the same matrix.
features, labels = np.copy(data_df), np.copy(data_df)

# Report the number of samples in each copy.
for matrix in (features, labels):
    print(matrix.shape[0])

In [None]:
### Create dataset object
class CustomDataset(Dataset):
    """Map-style dataset pairing each input row with its reconstruction target.

    Items are dicts ``{"Features": features[idx], "Labels": labels[idx]}``. The default
    DataLoader collate function batches them under the same keys, which is how the
    training loop reads ``batch['Features']`` / ``batch['Labels']``.

    Subclasses ``Dataset`` rather than ``TensorDataset``. The original subclassed
    ``TensorDataset`` but never called its ``__init__`` (so its ``tensors`` attribute
    was missing), and it overrode every method that class provides.

    :param features: indexable inputs (e.g. a 2-D tensor), one row per sample
    :param labels: indexable targets with the same number of rows as ``features``
    :raises ValueError: if ``features`` and ``labels`` differ in length
    """

    def __init__(self, features, labels):
        # Rows are paired by index. A length mismatch would otherwise either silently drop
        # rows or raise IndexError in the middle of an epoch.
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels must have the same length, got {len(features)} and {len(labels)}"
            )
        self.labels = labels
        self.features = features

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"Features": self.features[idx], "Labels": self.labels[idx]}

### prepare train test split (fixed random_state for a reproducible split)
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

### clr-style transformation for outputs: log((y + pseudocount) / g), with one global scalar g
zero_thr = 1e-6  # pseudocount added before taking logs (this name is rebound to a binarization threshold later)
# NOTE(review): the log-sum runs over the positive entries only but is divided by the total
# number of entries (zeros included), so zeros effectively count as 1 in this geometric mean.
# A standard CLR (e.g. skbio's clr, imported above) divides each sample by its own geometric
# mean instead. Confirm intent before changing this: the -0.6 binarization threshold used
# downstream was presumably tuned on the resulting scale.
gmean_train = (np.exp(np.nansum(np.log(y_train[y_train > 0]+zero_thr)) / np.size(y_train)))
y_train_clr = np.log((y_train+zero_thr) / gmean_train)
y_test_clr = np.log((y_test+zero_thr)/ gmean_train)

### scale each species column by its max |value| on the training split. Training targets land
### in [-1, 1]; test values can fall outside because the scaler is fitted on train only.
scaler = preprocessing.MaxAbsScaler().fit(y_train_clr)
#scaler = preprocessing.MinMaxScaler().fit(y_train_clr)
y_train_scaled = scaler.transform(y_train_clr)
y_test_scaled = scaler.transform(y_test_clr)

# Autoencoder: inputs and targets are the same scaled matrices, as float32 tensors.
X_train_scaled = torch.from_numpy(y_train_scaled).float()
X_test_scaled = torch.from_numpy(y_test_scaled).float()
y_train_scaled = torch.from_numpy(y_train_scaled).float()
y_test_scaled = torch.from_numpy(y_test_scaled).float()

#### dataset build
Train = CustomDataset(X_train_scaled, y_train_scaled)
Test = CustomDataset(X_test_scaled, y_test_scaled)

### create batch splits of data (worker processes / pinned memory only when CUDA is present)
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
train_DS = DataLoader(Train, batch_size=100, shuffle=True, drop_last=True, **kwargs)
# Evaluation loader: shuffle=False, so every pass (test loss during training and all
# downstream statistics) scores the same samples. With shuffle=True plus drop_last=True,
# a different random subset of up to 99 test samples was silently skipped on each pass.
# drop_last stays on because the downstream cells assume full batches.
test_DS = DataLoader(Test, batch_size=100, shuffle=False, drop_last=True, **kwargs)


In [None]:
### check distribution of the data (scaled test inputs, pooled over all samples and species)
fig, ax = plt.subplots(figsize=(5, 5))
ax.hist(X_test_scaled.numpy().ravel(), bins=40)
ax.set_xlabel("MaxAbs-scaled log-ratio value")
ax.set_ylabel("count")
ax.set_title("Test-set input distribution")
#fig.savefig('wgs_filt.pdf', dpi=1000)

plt.show()

In [None]:
### Per-species mean of the scaled training targets; the baseline subtracted in the residual (SPR) loss.
train_means = y_train_scaled.numpy().mean(axis=0)
### a more sophisticated variant would average over non-zero species only [needs modification]:
#train_means=(torch.sum(X_train_scaled*y_train_scaled, axis=0)/torch.sum(X_train_scaled, axis=0)).numpy()
train_means

### Implement VAE 

In [None]:
# Width of the model's input/output layer: one unit per species column of the targets.
input_dim = np.shape(y_train)[1]
# Hidden-layer activation used by the VAE (torch.nn.ELU() is an alternative to try).
relu = torch.nn.ReLU()
#relu = torch.nn.ELU()
class VAE(torch.nn.Module):
    """Fully connected variational autoencoder for species-abundance profiles.

    Architecture: in_dim -> hidden_dim -> (mu, logvar) in latent_dim -> hidden_dim -> in_dim.
    Dropout follows each linear layer and each activation. A final tanh keeps reconstructions
    in [-1, 1], the range MaxAbsScaler maps the training targets into.

    All constructor arguments are optional. ``VAE()`` reproduces the original configuration:
    hidden 1024, latent 64, dropout 0.1, the notebook-level ``relu`` activation, and the width
    taken from the notebook-level ``input_dim``.

    :param in_dim: number of input/output features; defaults to the global ``input_dim``
    :param hidden_dim: width of the single hidden layer in the encoder and decoder
    :param latent_dim: dimensionality of the latent space
    :param dropout: dropout probability used throughout
    :param activation: hidden-layer activation module; defaults to the global ``relu``
    """

    def __init__(self, in_dim=None, hidden_dim=1024, latent_dim=64, dropout=0.1, activation=None):
        super(VAE, self).__init__()
        # Resolve the width once, so forward() does not depend on a global that may change later.
        self.input_dim = input_dim if in_dim is None else in_dim
        self.activation = relu if activation is None else activation
        self.fc1 = torch.nn.Linear(self.input_dim, hidden_dim)
        self.fc3a = torch.nn.Linear(hidden_dim, latent_dim)  # mean (mu) head
        self.fc3b = torch.nn.Linear(hidden_dim, latent_dim)  # log-variance head
        self.fc4 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc6 = torch.nn.Linear(hidden_dim, self.input_dim)

        # Proportion of neurons to drop out
        self.dropout = torch.nn.Dropout(dropout)

    def encode(self, x):
        """Map a batch of inputs to the latent Gaussian parameters ``(mu, logvar)``."""
        # NOTE(review): dropout is applied both before and after the activation, so with p=0.1
        # the effective drop rate on hidden units is 1 - 0.9**2 = 0.19.
        h = self.dropout(self.fc1(x))
        h = self.dropout(self.activation(h))
        return (self.fc3a(h), self.fc3b(h))

    def decode(self, z):
        """Map latent codes back to feature space; outputs lie in [-1, 1] via tanh."""
        h = self.dropout(self.fc4(z))
        h = self.dropout(self.activation(h))
        # NOTE(review): dropout on the output layer zeroes random reconstructed features
        # during training (tanh(0) = 0). Kept as-is to preserve the trained behavior.
        out = self.dropout(self.fc6(h))
        return torch.tanh(out)

    def forward(self, x):
        """Encode, sample z with the reparameterization trick, and decode.

        Flat inputs are reshaped to ``(-1, in_dim)``. Returns ``(reconstruction, mu, logvar)``.
        Noise is sampled in both train and eval mode, so repeated calls on the same input
        give different reconstructions.
        """
        x = x.view(-1, self.input_dim)
        (u, logvar) = self.encode(x)
        stdev = torch.exp(0.5 * logvar)
        noise = torch.randn_like(stdev)
        z = u + noise * stdev
        return (self.decode(z), u, logvar)

### try to penalize many non-zero labels as in L1 lasso?
def final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x, kld_weight=0.01, spr_weight=10, bce_weight=1):
    """
    Combine the reconstruction terms with the KL divergence of the latent Gaussian.

    loss = kld_weight * KLD + MSE + spr_weight * SPR + bce_weight * BCE
    KLD  = -0.5 * sum(1 + logvar - mu^2 - exp(logvar)), summed over batch and latent dims,
           i.e. KL(N(mu, sigma^2) || N(0, I)).

    :param mse_loss: MSE reconstruction loss
    :param bce_loss: BCE reconstruction loss (a tensor, or 0 when unused)
    :param spr_loss: custom residual term of the reconstruction loss
    :param u: mean (mu) of the latent Gaussian
    :param logvar: log variance of the latent Gaussian
    :param recon_x: reconstruction; unused, kept for call compatibility
    :param kld_weight: weight of the KL term (default 0.01, the original value)
    :param spr_weight: weight of the residual term (default 10, the original value)
    :param bce_weight: weight of the BCE term (default 1, the original value)
    :return: scalar loss tensor
    """
    kld = -0.5 * torch.sum(1 + logvar - u.pow(2) - logvar.exp())
    return kld_weight * kld + mse_loss + spr_weight * spr_loss + bce_weight * bce_loss


### Model Initialization and Training

In [None]:
SEED = 42
# Seed torch so that weight init, dropout masks, loader shuffling and the VAE's latent noise are reproducible.
torch.manual_seed(SEED)

model = VAE()#.to(device)

# Reconstruction criteria, summed over batch and features
criterion_mse = torch.nn.MSELoss(reduction='sum')
criterion_bce = torch.nn.BCELoss(reduction='sum')  # currently unused: bce_loss is fixed to 0 below

# Strength of the explicit weight-norm penalty added to the training objective.
# NOTE(review): the penalty sums the *unsquared* L2 norm of each parameter tensor, which is
# not classic (squared) L2 weight decay. Adam's weight_decay below adds that separately.
l2_lambda = 0.01

# Using an Adam Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr = 1e-3, weight_decay = 1e-7)

### iterate through batches
epochs = 1000
zero_thr = -0.6  # scaled values above this count as "present" when binarizing for accuracy
outputs = []
train_losses = []
test_losses = []
train_acc = []
test_acc = []
train_corr = []
test_corr = []
train_grad = []
epochs_without_improvement = 0
patience = 100
best_test_loss = float('inf')
best_state = None  # weights from the epoch with the lowest average test loss

# Per-species training means as a (1, n_species) tensor, built once rather than twice per batch.
# NOTE(review): (recon - m) - (label - m) == recon - label, so spr_loss equals mse_loss (up to
# rounding) and the objective is effectively (1 + 10) * MSE. Revisit if a true residual term was intended.
train_means_t = torch.from_numpy(train_means)[None, :]

for epoch in range(epochs):
    model = model.train()
    train_loss = 0.0  # accumulated custom loss (excludes the weight-norm penalty)
    test_loss = 0.0   # accumulated custom loss
    corr = 0.0
    grad_mag = 0.0    # placeholder: gradient magnitude is not computed
    accs = 0.0

    ### VAE training
    for (idx, batch) in enumerate(train_DS):
        # Transfer to GPU
        #batch['Features'], batch['Labels'] = batch['Features'].to(device), batch['Labels'].to(device)

        optimizer.zero_grad()
        recon_x, u, logvar = model(batch['Features'])

        mse_loss = criterion_mse(recon_x, batch['Labels'])
        spr_loss = criterion_mse(recon_x - train_means_t, batch['Labels'] - train_means_t)
        out = (recon_x.detach() > zero_thr).float()  ### for MaxAbs scaled

        #bce_loss = criterion_bce(out, batch['Features']) #for binary-abund
        bce_loss = 0

        loss_val = final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x)

        # Weight-norm penalty (see NOTE above)
        l2_reg_loss = 0.0
        for param in model.parameters():
            if param.requires_grad:
                l2_reg_loss += torch.norm(param, 2)
        total_loss = loss_val + l2_lambda * l2_reg_loss

        total_loss.backward()
        train_loss += loss_val.item()

        true_bin = (batch['Features'] > zero_thr).float()  ### for MaxAbs scaled
        accs += float((out == true_bin).float().mean())

        optimizer.step()

    # Per-batch averages for plotting
    train_losses.append(train_loss / len(train_DS))
    train_acc.append(accs / len(train_DS))
    train_grad.append(grad_mag)

    accs = 0.0
    corr = 0.0

    ### VAE evaluation
    with torch.no_grad():
        model = model.eval()
        for (idx, batch) in enumerate(test_DS):
            # Transfer to GPU
            #batch['Features'], batch['Labels'] = batch['Features'].to(device), batch['Labels'].to(device)

            recon_x, u, logvar = model(batch['Features'])
            mse_loss = criterion_mse(recon_x, batch['Labels'])
            out = (recon_x.detach() > zero_thr).float()  # for MaxAbs scaled
            bce_loss = 0
            spr_loss = criterion_mse(recon_x - train_means_t, batch['Labels'] - train_means_t)

            loss_val = final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x)
            test_loss += loss_val.item()

            true_bin = (batch['Features'] > zero_thr).float()  ### for MaxAbs scaled
            accs += float((out == true_bin).float().mean())  # for binary-abund

        test_losses.append(test_loss / len(test_DS))
        test_acc.append(accs / len(test_DS))

    ### Store metrics for comet (disabled):
    # metrics = {'train loss': train_loss/len(train_DS),
    #            'test loss': test_loss/len(test_DS),
    #            'accuracy': accs/len(test_DS)}
    # experiment.log_metrics(metrics, step=epoch)

    # Early stopping on the per-batch average test loss. The original compared the raw
    # train and test *sums*, which are accumulated over different numbers of batches
    # (80/20 split), so the condition reflected the split ratio rather than improvement.
    if test_losses[-1] < best_test_loss:
        best_test_loss = test_losses[-1]
        best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print("early stop")
        break

# Roll back to the weights with the lowest average test loss seen during training.
if best_state is not None:
    model.load_state_dict(best_state)
# experiment.end() ### end comet experiment

In [None]:
### check the distribution of reconstructed data over the whole test set.
# Reconstructions are recomputed with the current model instead of reusing `recon_x` left
# over from the last training-loop batch (a single batch, and stale after any re-run).
model.eval()
with torch.no_grad():
    test_recon = torch.cat([model(batch['Features'])[0] for batch in test_DS])

fig, ax = plt.subplots(figsize=(5, 5))
ax.hist(test_recon.numpy().ravel(), bins=40)
ax.set_xlabel("reconstructed (scaled) value")
ax.set_ylabel("count")
ax.set_title("Test-set reconstruction distribution")
#fig.savefig('wgs_filt_clr_mse_loss_wide_v2.pdf', dpi=1000)
plt.show()


### Look at test set reconstruction quality

In [None]:
#### check how good the reconstructions are (plot matrices as heatmaps):
model.eval()  # disable dropout for evaluation (the VAE still samples latent noise)
with torch.no_grad():
    for (idx, batch) in enumerate(test_DS):
        print(idx)
        true_x = batch['Labels'].numpy()
        recon_x, u, logvar = model(batch['Features'])
        a = recon_x.numpy()

        # Shared symmetric colour limits: both panels use the same scale, and white maps to 0
        # on the diverging 'bwr' colormap (independent autoscaling made the panels incomparable).
        lim = max(np.abs(true_x).max(), np.abs(a).max())

        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(true_x, cmap='bwr', interpolation='nearest', vmin=-lim, vmax=lim)
        ax1.set_title('True')
        ax2.imshow(a, cmap='bwr', interpolation='nearest', vmin=-lim, vmax=lim)
        ax2.set_title('Reconstructed')
        for ax in (ax1, ax2):
            ax.set_xlabel('species')
            ax.set_ylabel('sample')

        print(stats.spearmanr(true_x.flatten(), a.flatten()))
        #fig.savefig('batch_reconstruction_masked_v1_{}_wide.pdf'.format(idx), dpi=1300)
        plt.show()


In [None]:
#### check how good the reconstructions are (observed vs predicted 2-D histogram per test batch):
model.eval()  # disable dropout for evaluation (the VAE still samples latent noise)
with torch.no_grad():
    for (idx, batch) in enumerate(test_DS):
        print(idx)
        true_x = batch['Labels'].numpy()
        recon_x, u, logvar = model(batch['Features'])
        a = recon_x.numpy()

        fig, ax = plt.subplots()
        # PowerNorm with gamma=1/10 boosts low-count bins so sparse off-diagonal mass stays visible.
        ax.hist2d(true_x.flatten(), a.flatten(), norm=matplotlib.colors.PowerNorm(1/10), bins=100)
        ax.set_xlabel('observed (scaled)')
        ax.set_ylabel('reconstructed')
        ax.set_title(f'Test batch {idx}')

        print(stats.spearmanr(true_x.flatten(), a.flatten()))
        #fig.savefig('batch_reconstruction_masked_v1_{}_wide.pdf'.format(idx), dpi=1300)
        plt.show()


### Count true/false positive/negative _per species_ binary statistics

In [None]:
# Per-species confusion counts on the test set. Truth and reconstruction are both binarized
# at zero_thr (scaled values above the threshold count as "present").
zero_thr = -0.6
n_species = input_dim  # was true_x.shape[1], which depended on a variable left over from an earlier cell
n_batches = len(test_DS)
tps = np.zeros((n_species, n_batches))
tns = np.zeros((n_species, n_batches))
fps = np.zeros((n_species, n_batches))
fns = np.zeros((n_species, n_batches))
prev = np.zeros((n_species, n_batches))  # per-batch fraction of samples where the species is present
accs = 0

model.eval()
with torch.no_grad():
    for (idx, batch) in enumerate(test_DS):
        is_true = (batch['Features'] > zero_thr).numpy()
        recon_x, u, logvar = model(batch['Features'])
        is_pred = (recon_x > zero_thr).numpy()

        accs += float((is_pred == is_true).mean())

        # Boolean masks are (samples x species); sum over samples to get one count per species.
        prev[:, idx] = is_true.mean(axis=0)
        tps[:, idx] = (is_true & is_pred).sum(axis=0)
        tns[:, idx] = (~is_true & ~is_pred).sum(axis=0)
        fps[:, idx] = (~is_true & is_pred).sum(axis=0)
        fns[:, idx] = (is_true & ~is_pred).sum(axis=0)

sum_tps = tps.sum(axis=1)
sum_tns = tns.sum(axis=1)
sum_fps = fps.sum(axis=1)
sum_fns = fns.sum(axis=1)
sum_prev = prev.mean(axis=1)

# A species never present (never absent) in the scored samples has undefined sensitivity
# (specificity). Leave those as NaN without emitting divide-by-zero warnings.
with np.errstate(divide='ignore', invalid='ignore'):
    sensitivity = sum_tps / (sum_tps + sum_fns)
    specificity = sum_tns / (sum_tns + sum_fps)

print(accs / len(test_DS))

In [None]:
# Per-species sensitivity vs specificity.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(x=sensitivity, y=specificity)
ax.set_ylim([0, 1])
ax.set_xlim([0, 1])
ax.set_xlabel("Sensitivity")
ax.set_ylabel("Specificity")
# Reference lines where sensitivity + specificity = 1, 1.2, 1.4
for start in (0.0, 0.2, 0.4):
    ax.plot([start, 1], [1, start], ls="--", linewidth=0.3, c=".3")

# nan_policy='omit' drops species whose sensitivity/specificity is undefined (NaN), which
# would otherwise turn the correlation into NaN. (The original also set a stale hard-coded
# legend that was immediately overwritten.)
rho = stats.spearmanr(sensitivity, specificity, nan_policy='omit')[0]
ax.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#fig.savefig('VAE_wgs_filt_nonb_per_species_nbnb.pdf', dpi=1000)

plt.show()

In [None]:
# Per-species prevalence vs specificity.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(x=sum_prev, y=specificity)
ax.set_ylim([0, 1])
ax.set_xlim([0, 1])
ax.set_xlabel("Prevalence")
ax.set_ylabel("Specificity")
ax.plot([0.0, 1], [1, 0.0], ls="--", linewidth=0.3, c=".3")

# nan_policy='omit' skips species with undefined specificity (NaN) instead of yielding r = NaN.
# (The original also set a stale hard-coded legend that was immediately overwritten.)
rho = stats.spearmanr(sum_prev, specificity, nan_policy='omit')[0]
ax.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#fig.savefig('VAE_wgs_filt_nonb_per_species_prev_nbnb.pdf', dpi=1000)

plt.show()

In [None]:
# Per-species prevalence vs sensitivity.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(x=sum_prev, y=sensitivity)
ax.set_ylim([0, 1])
ax.set_xlim([0, 1])
ax.set_xlabel("Prevalence")
ax.set_ylabel("Sensitivity")

# nan_policy='omit' skips species with undefined sensitivity (NaN), e.g. species never present.
# (The original also set a stale hard-coded legend at loc=4 that was immediately overwritten.)
rho = stats.spearmanr(sum_prev, sensitivity, nan_policy='omit')[0]
ax.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#fig.savefig('VAE_wgs_filt_nonb_per_species_prev_sens_nbnb.pdf', dpi=1000)

plt.show()

### Grab taxonomic info for features (i.e. columns)

In [None]:
# One taxonomy row per species column, annotated with its per-species prevalence.
taxonomy_df = (
    pd.read_csv("data/wgs_train_health.noab_data_to_ml.taxonomy.txt", sep="\t")
    .assign(prev=sum_prev)
)
taxonomy_df

In [None]:
# Attach the per-species confusion counts (columns added in tp, tn, fp, fn order).
taxonomy_df = taxonomy_df.assign(tp=sum_tps, tn=sum_tns, fp=sum_fps, fn=sum_fns)
taxonomy_df

In [None]:
# Most prevalent species first
taxonomy_df = taxonomy_df.sort_values('prev', ascending=False)
taxonomy_df

In [None]:
# Stacked per-species confusion counts (tp, tn, fp, fn), species ordered by prevalence.
fig, ax = plt.subplots()
bottom = np.zeros(len(taxonomy_df))
for col, color in zip(["tp", "tn", "fp", "fn"], ["r", "b", "y", "g"]):
    ax.bar(taxonomy_df['scientific_name'], taxonomy_df[col], bottom=bottom, color=color)
    bottom = bottom + taxonomy_df[col].to_numpy()
# Each species is scored once per sample in every full test batch (test_DS uses drop_last=True).
ax.set_ylim([0, len(test_DS) * test_DS.batch_size])
ax.set_xlabel("Species")
ax.tick_params(axis='x', labelsize=4, labelrotation=90)
ax.set_ylabel("total # of examples")
ax.legend(["tp", "tn", "fp", "fn"])
ax.set_title("Accuracy metrics sorted by prevalence")
# savefig does not create missing directories.
os.makedirs('graphs', exist_ok=True)
fig.savefig('graphs/VAE_wgs_species_acc_test.pdf', dpi=1000, bbox_inches='tight')

plt.show()

In [None]:
# Number of scored test samples per species (drop_last=True keeps only full batches).
n_scored = len(test_DS) * test_DS.batch_size
fig, ax = plt.subplots()
# prev is a fraction of samples; scale it to a count comparable with the stacked-bar plot.
ax.bar(taxonomy_df['scientific_name'], taxonomy_df['prev'] * n_scored, color='r')
ax.set_xlabel("Species")
ax.set_ylim([0, n_scored])
ax.tick_params(axis='x', labelsize=4, labelrotation=90)
ax.set_ylabel("total # of examples")
ax.set_title("Species sorted by prevalence")  ### i.e. all positives
# savefig does not create missing directories.
os.makedirs('graphs', exist_ok=True)
fig.savefig('graphs/VAE_wgs_species_prev_test.pdf', dpi=1000, bbox_inches='tight')

plt.show()

### Evaluate the model on external validation data (compute the statistics below, then re-run the plot cells above):

In [None]:
### data in the same format as data_df (rows - samples, columns species in the same order)
validation_df = pd.read_csv("data/wgs_test_health.ab_data_to_ml.txt", sep="\t")
#validation_df = pd.read_csv("data/wgs_test_disease_data_to_ml.txt", sep="\t")

print(validation_df)

validation = validation_df.to_numpy()

# Apply the training transform: same pseudocount, the training geometric mean, and the
# fitted scaler. The pseudocount must equal the one used for the training split (1e-6).
# It gets its own name so it is no longer confused with the -0.6 binarization threshold `zero_thr`.
pseudocount = 1e-6
validation_clr = np.log((validation + pseudocount) / gmean_train)
validation_scaled = scaler.transform(validation_clr)

# Autoencoder: inputs and targets are the same scaled matrix.
X_validation_scaled = torch.from_numpy(validation_scaled).float()
y_validation_scaled = torch.from_numpy(validation_scaled).float()

#### dataset build
Validation = CustomDataset(X_validation_scaled, y_validation_scaled)

### create batch splits of data.
# Evaluation only: shuffle=False, so repeated runs score the same samples. drop_last is kept
# because the downstream statistics assume full batches.
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
validation_DS = DataLoader(Validation, batch_size=100, shuffle=False, drop_last=True, **kwargs)
if len(validation_DS) == 0:
    raise ValueError(
        f"validation set has {len(Validation)} samples; at least 100 are needed for one full batch"
    )

In [None]:
# Per-species confusion counts on the external validation set (same procedure as for the test set).
zero_thr = -0.6  # binarization threshold on the scaled values
n_species = input_dim  # was true_x.shape[1], which depended on a variable left over from an earlier cell
n_batches = len(validation_DS)
tps = np.zeros((n_species, n_batches))
tns = np.zeros((n_species, n_batches))
fps = np.zeros((n_species, n_batches))
fns = np.zeros((n_species, n_batches))
prev = np.zeros((n_species, n_batches))  # per-batch fraction of samples where the species is present
accs = 0

model.eval()
with torch.no_grad():
    for (idx, batch) in enumerate(validation_DS):
        is_true = (batch['Features'] > zero_thr).numpy()
        recon_x, u, logvar = model(batch['Features'])
        is_pred = (recon_x > zero_thr).numpy()

        accs += float((is_pred == is_true).mean())

        # Boolean masks are (samples x species); sum over samples to get one count per species.
        prev[:, idx] = is_true.mean(axis=0)
        tps[:, idx] = (is_true & is_pred).sum(axis=0)
        tns[:, idx] = (~is_true & ~is_pred).sum(axis=0)
        fps[:, idx] = (~is_true & is_pred).sum(axis=0)
        fns[:, idx] = (is_true & ~is_pred).sum(axis=0)

sum_tps = tps.sum(axis=1)
sum_tns = tns.sum(axis=1)
sum_fps = fps.sum(axis=1)
sum_fns = fns.sum(axis=1)
# Mean of per-batch fractions; equals the old sum/len(validation_DS)/100 for full batches of 100.
sum_prev = prev.mean(axis=1)

# Undefined ratios (species never present / never absent) stay NaN without warnings.
with np.errstate(divide='ignore', invalid='ignore'):
    sensitivity = sum_tps / (sum_tps + sum_fns)
    specificity = sum_tns / (sum_tns + sum_fps)

print(accs / len(validation_DS))

### Count true/false positive/negative _per sample_ statistics

In [None]:
### Check sample to sample variability: confusion counts per test sample, taken across species
zero_thr = -0.6  # binarization threshold on the scaled values
# One row per sample position within a batch and one column per batch. test_DS uses
# drop_last=True, so every batch is full. (This was true_x.shape[0], which depended on a
# variable left over from an earlier cell.)
samples_per_batch = test_DS.batch_size
n_batches = len(test_DS)
tps = np.zeros((samples_per_batch, n_batches))
tns = np.zeros((samples_per_batch, n_batches))
fps = np.zeros((samples_per_batch, n_batches))
fns = np.zeros((samples_per_batch, n_batches))

model.eval()
with torch.no_grad():
    for (idx, batch) in enumerate(test_DS):
        is_true = (batch['Features'] > zero_thr).numpy()
        recon_x, u, logvar = model(batch['Features'])
        is_pred = (recon_x > zero_thr).numpy()

        # Sum over species (axis=1) to get one count per sample.
        tps[:, idx] = (is_true & is_pred).sum(axis=1)
        tns[:, idx] = (~is_true & ~is_pred).sum(axis=1)
        fps[:, idx] = (~is_true & is_pred).sum(axis=1)
        fns[:, idx] = (is_true & ~is_pred).sum(axis=1)

sum_tps = tps.flatten()
sum_tns = tns.flatten()
sum_fps = fps.flatten()
sum_fns = fns.flatten()

# A sample with no present (no absent) species has undefined sensitivity (specificity).
# Those stay NaN without warnings.
with np.errstate(divide='ignore', invalid='ignore'):
    sensitivity = sum_tps / (sum_tps + sum_fns)
    specificity = sum_tns / (sum_tns + sum_fps)

In [None]:
# Per-sample sensitivity vs specificity.
fig, ax = plt.subplots(figsize=(5, 5))
ax.scatter(x=sensitivity, y=specificity, alpha=0.3)
ax.set_ylim([0, 1.1])
ax.set_xlim([0, 1.1])
ax.set_xlabel("Sensitivity")
ax.set_ylabel("Specificity")
# Reference lines where sensitivity + specificity = 1, 1.2, 1.4, 1.6, 1.8
for start in (0.0, 0.2, 0.4, 0.6, 0.8):
    ax.plot([start, 1], [1, start], ls="--", linewidth=0.3, c=".3")

# nan_policy='omit' skips samples with undefined sensitivity/specificity (NaN) instead of
# yielding r = NaN. (The original also set a stale hard-coded legend that was immediately overwritten.)
rho = stats.spearmanr(sensitivity, specificity, nan_policy='omit')[0]
ax.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#fig.savefig('graphs/VAE_wgs_filt_nonbin_per_samples_train_nbnb.pdf', dpi=1000)

plt.show()