In [None]:
## import comet_ml at the top of your file
from comet_ml import Experiment

from config import *

import torch
#import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from scipy import stats
from skbio.stats.composition import clr
# Pick a compute device: CUDA if present (GPU 1 when at least two GPUs exist,
# otherwise GPU 0), else Apple MPS when available, else CPU.
# NOTE: the .to(device) transfers below are commented out, so training currently
# runs on CPU regardless of this choice; use_cuda still controls DataLoader options.
use_cuda = torch.cuda.is_available()
if use_cuda:
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

In [None]:
# Comet.ml experiment that receives the per-epoch metrics logged during training.
# API_KEY is provided by the local `config` module (star-imported above).
experiment = Experiment(
    workspace="zireae1",
    project_name="vae-eco-ml",
    api_key=API_KEY,
)

In [None]:
### load data: tab-separated table, converted straight to a numpy array
data_df = pd.read_csv("data/wgs_train_health.noab_data_to_ml.filt.txt", sep="\t").to_numpy()
print(data_df.shape)

### binarize in place: values above the threshold -> 1, values at or below -> 0
### (NaNs, if any, satisfy neither comparison and are left untouched)
zero_thr = 0
np.putmask(data_df, data_df > zero_thr, 1)
np.putmask(data_df, data_df <= zero_thr, 0)

# autoencoder setup: targets are an independent copy of the inputs
features = data_df.copy(order="K")
labels = data_df.copy(order="K")

print(len(features))
print(len(labels))

In [None]:
### Create dataset object
class CustomDataset(Dataset):
    """Map-style dataset pairing features with labels.

    Each item is a dict ``{"Features": features[idx], "Labels": labels[idx]}``,
    which is what the training/evaluation loops index into. Plain ``Dataset`` is
    the base class: TensorDataset's ``__init__`` (and its ``tensors`` attribute)
    was never used here.

    :param features: indexable sequence/tensor of inputs
    :param labels: indexable sequence/tensor of targets, same length as features
    :raises ValueError: if features and labels have different lengths
    """

    def __init__(self, features, labels):
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels differ in length: {len(features)} vs {len(labels)}"
            )
        self.labels = labels
        self.features = features

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"Features": self.features[idx], "Labels": self.labels[idx]}

### prepare train/test split (80/20, fixed seed)
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

### centre every feature on its binarized training prevalence:
### features present in >= 50% of training samples get 1 subtracted, the rest 0.
### For 0/1 input the centred values lie in {-1, 0, 1}.
prevalence = np.sum(X_train, axis=0) / X_train.shape[0]
prevalence[prevalence >= 0.5] = 1
prevalence[prevalence < 0.5] = 0
# fraction of features whose training majority value is "absent"
print(1 - prevalence.sum() / prevalence.shape[0])
X_train_scaled = X_train - prevalence
X_test_scaled = X_test - prevalence
y_train_scaled = y_train - prevalence
y_test_scaled = y_test - prevalence

X_train_scaled = torch.from_numpy(X_train_scaled).float()
y_train_scaled = torch.from_numpy(y_train_scaled).float()
X_test_scaled = torch.from_numpy(X_test_scaled).float()
y_test_scaled = torch.from_numpy(y_test_scaled).float()

#### Build datasets
Train = CustomDataset(X_train_scaled, y_train_scaled)
Test = CustomDataset(X_test_scaled, y_test_scaled)

### create batch splits of the data
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
train_DS = DataLoader(Train, batch_size=100, shuffle=True, drop_last=True, **kwargs)
# Evaluation loader is not shuffled: combined with drop_last=True, a shuffled loader
# would drop a different random subset of test samples every epoch, making test
# metrics non-comparable across epochs. drop_last is kept to preserve the existing
# batch count; it still leaves out up to 99 trailing test samples.
test_DS = DataLoader(Test, batch_size=100, shuffle=False, drop_last=True, **kwargs)


In [None]:
#### Implement VAE
# network input width = number of feature columns in the training targets
dim = input_dim = y_train.shape[1]
# ReLU module; not used by any active code (only by commented-out activation alternatives)
relu = torch.nn.ReLU()
class VAE(torch.nn.Module):
    """Variational autoencoder with one tanh hidden layer in encoder and decoder.

    Layout: n_input -> n_hidden -> (mu, logvar) [n_latent each] -> n_hidden -> n_input.
    The output passes through tanh, so reconstructions lie in (-1, 1).
    """

    def __init__(self, n_input=None, n_hidden=64, n_latent=8):
        """Build the layers.

        :param n_input: number of input/output features. Defaults to the
            notebook-level ``input_dim``; it is read once here and stored on the
            model so ``forward`` does not depend on that global afterwards.
        :param n_hidden: width of the encoder and decoder hidden layers.
        :param n_latent: dimensionality of the latent space.
        """
        super(VAE, self).__init__()
        if n_input is None:
            n_input = input_dim
        self.input_dim = n_input
        # Layer creation order and attribute names matter: they fix the seeded
        # initialisation sequence and the state_dict keys.
        self.fc1 = torch.nn.Linear(n_input, n_hidden)
        self.fc3a = torch.nn.Linear(n_hidden, n_latent)  # latent mean head
        self.fc3b = torch.nn.Linear(n_hidden, n_latent)  # latent log-variance head
        self.fc4 = torch.nn.Linear(n_latent, n_hidden)
        self.fc6 = torch.nn.Linear(n_hidden, n_input)

    def encode(self, x):
        """Map x of shape (batch, n_input) to (mu, logvar), each (batch, n_latent)."""
        h = torch.tanh(self.fc1(x))
        return (self.fc3a(h), self.fc3b(h))

    def decode(self, z):
        """Map latent codes (batch, n_latent) to reconstructions (batch, n_input) in (-1, 1)."""
        h = torch.tanh(self.fc4(z))
        return torch.tanh(self.fc6(h))

    def forward(self, x):
        """Encode, sample z = mu + eps * sigma (reparameterization), decode.

        Sampling happens in both train and eval mode, so repeated calls on the
        same input give different reconstructions.

        :returns: (reconstruction, mu, logvar)
        """
        x = x.view(-1, self.input_dim)
        u, logvar = self.encode(x)
        stdev = torch.exp(0.5 * logvar)
        noise = torch.randn_like(stdev)
        z = u + noise * stdev
        return (self.decode(z), u, logvar)

### try to penalize many non-zero labels as in L1 lasso?
def final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x,
               bce_weight=0.1, spr_weight=10):
    """Combine the reconstruction terms with the VAE KL-divergence penalty.

    KLD = -0.5 * sum(1 + logvar - u^2 - exp(logvar)), i.e. KL(N(u, sigma^2) || N(0, I))
    summed over all batch entries and latent dimensions.

    total = KLD + mse_loss + bce_weight * bce_loss + spr_weight * spr_loss

    :param mse_loss: squared-error reconstruction term
    :param bce_loss: binary cross-entropy term (callers may pass 0 to disable it)
    :param spr_loss: prevalence-weighted squared-error term
    :param u: latent means
    :param logvar: latent log-variances
    :param recon_x: unused; kept for call compatibility (an L1 penalty on it was tried)
    :param bce_weight: multiplier for bce_loss
    :param spr_weight: multiplier for spr_loss
    :returns: scalar total loss
    """
    KLD = -0.5 * torch.sum(1 + logvar - u.pow(2) - logvar.exp())
    return KLD + mse_loss + bce_weight * bce_loss + spr_weight * spr_loss



In [None]:
# Model initialization (the .to(device) transfers are disabled, so this runs on CPU)
model = VAE()

# Reconstruction criteria; 'sum' reduction keeps them on the same scale as the
# summed KL term in final_loss.
criterion_mse = torch.nn.MSELoss(reduction='sum')
criterion_bce = torch.nn.BCELoss(reduction='sum')  # currently unused: bce_loss is fixed to 0
criterion_kl_loss = torch.nn.KLDivLoss(reduction="batchmean", log_target=True)  # currently unused
# Strength of the penalty on the sum of (non-squared) parameter L2 norms
l2_lambda = 0.01

optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

### per-feature weights for the weighted squared error: max(p, 1 - p), where p is
### the fraction of training samples in which the feature is present
prevalence_weights = np.sum(X_train, axis=0) / X_train.shape[0]
prevalence_weights[prevalence_weights < 0.5] = 1 - prevalence_weights[prevalence_weights < 0.5]
prevalence_weights = torch.from_numpy(prevalence_weights).float()

# binarized prevalence as a (1, n_features) tensor; adding it back to centred values
# recovers the 0/1 presence scale. Built once instead of once per batch.
prevalence_row = torch.from_numpy(prevalence)[None, :]


def _evaluate_batch(batch):
    """Forward one batch; return (loss without the L2 term, presence accuracy).

    Accuracy is the fraction of entries where the thresholded reconstruction
    (recon + prevalence > zero_thr) equals the true 0/1 value.
    """
    recon_x, u, logvar = model(batch['Features'])
    mse_loss = criterion_mse(recon_x, batch['Labels'])
    # prevalence-weighted squared error: sum over samples, then weighted sum over features
    spr_loss = torch.sum(prevalence_weights * torch.sum(torch.square(recon_x - batch['Labels']), dim=0))
    bce_loss = 0  # BCE term disabled
    loss_val = final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x)
    out = ((recon_x.detach() + prevalence_row) > zero_thr).float()
    acc = float((out == (batch['Labels'] + prevalence_row).float()).float().mean())
    return loss_val, acc


### training settings and history
epochs = 1000
zero_thr = 0.5  # threshold on (reconstruction + prevalence) for calling a feature present
outputs = []
train_losses = []
test_losses = []
train_acc = []
test_acc = []
train_corr = []
test_corr = []
train_grad = []  # gradient-magnitude logging is disabled; 0.0 is stored per epoch
epochs_without_improvement = 0
patience = 100

try:
    for epoch in range(epochs):
        model = model.train()
        train_loss = 0.0  # accumulated custom loss (excluding the L2 term)
        accs = 0.0
        for batch in train_DS:
            optimizer.zero_grad()
            loss_val, acc = _evaluate_batch(batch)
            l2_reg_loss = sum(torch.norm(param, 2) for param in model.parameters() if param.requires_grad)
            total_loss = loss_val + l2_lambda * l2_reg_loss
            total_loss.backward()
            optimizer.step()
            train_loss += loss_val.item()
            accs += acc

        mean_train_loss = train_loss / len(train_DS)
        train_losses.append(mean_train_loss)
        train_acc.append(accs / len(train_DS))
        train_grad.append(0.0)

        test_loss = 0.0
        accs = 0.0
        with torch.no_grad():
            model = model.eval()
            for batch in test_DS:
                loss_val, acc = _evaluate_batch(batch)
                test_loss += loss_val.item()
                accs += acc

        mean_test_loss = test_loss / len(test_DS)
        test_losses.append(mean_test_loss)
        test_acc.append(accs / len(test_DS))

        metrics = {'train loss': mean_train_loss,
                   'test loss': mean_test_loss,
                   'accuracy': accs / len(test_DS),
                   }
        experiment.log_metrics(metrics, step=epoch)

        # Stop once the mean test loss has stayed at or above the mean train loss for
        # `patience` consecutive epochs. Per-batch means are compared because the
        # summed losses run over different numbers of batches (train is ~4x larger).
        if mean_train_loss > mean_test_loss:
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epochs_without_improvement >= patience:
            print("early stop")
            break
finally:
    # close the Comet run even if training raises or is interrupted
    experiment.end()

In [None]:
import matplotlib.cm as cm
#### visual check of reconstructions on the test set: true presence/absence vs the
#### thresholded reconstruction, with the per-batch fraction of matching entries
model.eval()
with torch.no_grad():  # inference only: no autograd graph needed
    for idx, batch in enumerate(test_DS):
        print(idx)
        # undo the prevalence centring to get back 0/1 presence values
        true_x = batch['Labels'].detach().numpy() + prevalence
        print(true_x.shape)
        recon_x, u, logvar = model(batch['Features'])

        fig, (ax1, ax2) = plt.subplots(1, 2)
        im1 = ax1.imshow(true_x[:, :dim], cmap='Greys', interpolation='nearest')
        ax1.set_title('True')

        a = recon_x.detach().numpy() + prevalence
        a[a > zero_thr] = 1
        a[a <= zero_thr] = 0
        im2 = ax2.imshow(a[:, :dim], cmap='Greys', interpolation='nearest')
        ax2.set_title('Reconstructed')
        # accuracy of the binarized reconstruction (fraction of matching entries)
        print(1 - np.sum(np.abs(true_x - a)) / a.shape[0] / a.shape[1])
        fig.colorbar(im2, ax=ax2)
        plt.show()


In [None]:
# Accuracy curves (percentage of correctly reconstructed presence/absence entries per
# epoch) compared with two naive baselines.
fig, ax = plt.subplots()
ax.plot(np.arange(len(train_acc)), np.asarray(train_acc) * 100, c='blue', label="train_acc")
ax.plot(np.arange(len(test_acc)), np.asarray(test_acc) * 100, c='red', label="test_acc")
ax.set_xlabel("Epoch")
ax.set_ylabel("Accuracy, %")
ax.set_ylim([0, 100])

# baseline 1: predict "absent" everywhere -> accuracy equals the share of zeros in X_train
all_absent_acc = (1 - np.mean(np.sum(X_train, axis=1) / X_train.shape[1])) * 100
ax.axhline(y=all_absent_acc, color='grey', linestyle='-.', label="all-absent baseline (train)")

# baseline 2: predict each feature's training majority value, scored on X_test
majority_value = np.sum(X_train, axis=0) / X_train.shape[0]
majority_value[majority_value >= 0.5] = 1
majority_value[majority_value < 0.5] = 0
majority_acc = (1 - np.sum(np.abs(X_test - majority_value)) / X_test.shape[0] / majority_value.shape[0]) * 100
ax.axhline(y=majority_acc, color='orange', linestyle='-.', label="prevalence baseline (test)")

ax.legend(frameon=False)
fig.savefig('VAE_masked_wgs_filt_bin_weighted_loss.pdf', dpi=1000)
plt.show()

In [None]:
### let's look at the latent space (UMAP is used for the 2-D embeddings below)
import numba
import umap

# show the numba version; a bare expression would not be displayed here because it
# is not the last statement of the cell
print(numba.__version__)

In [None]:
batch_size = 3000  # number of training samples to embed
# random subset of (prevalence-centred) training samples
indices = torch.randperm(len(X_train_scaled))[:batch_size]
z = X_train_scaled[indices]
# Encode the samples; encode returns (mean, log-variance) of the latent distribution
with torch.no_grad():
    encoded, logvar = model.encode(z)

# latent means as a numpy array for UMAP
encoded_np = encoded.numpy()

# Reduce the latent means to 2-D with UMAP
reducer = umap.UMAP(n_components=2, random_state=0)
latent_umap = reducer.fit_transform(encoded_np)

# Plot the embedded latent means
plt.scatter(latent_umap[:, 0], latent_umap[:, 1], color='red', alpha=0.1,)
plt.xlabel('UMAP dimension 1')
plt.ylabel('UMAP dimension 2')
plt.title('UMAP of points projected to latent space')
plt.legend(['Training data'])
#plt.savefig('UMAP_of_latent_projections_wgs_filt.pdf', dpi=1200)

plt.show()

In [None]:
batch_size = 3000  # number of training samples used to fit UMAP

# UMAP of the (prevalence-centred) data space, fitted on a random training subset
reducer = umap.UMAP(n_components=2, random_state=0)
indices = torch.randperm(len(y_train_scaled))[:batch_size]
# Fit once: with a fixed random_state a second fit on the same input is redundant
transformed_true = reducer.fit_transform(y_train_scaled[indices])
data_umap = transformed_true

# Plot training points and the test set projected with the same fitted reducer
plt.scatter(transformed_true[:, 0], transformed_true[:, 1], color='red', alpha=0.1)
test_umap = reducer.transform(y_test_scaled)
plt.scatter(test_umap[:, 0], test_umap[:, 1], color='blue', alpha=0.1,)

plt.xlabel('UMAP dimension 1')
plt.ylabel('UMAP dimension 2')
plt.title('UMAP of points sampled from training data')
plt.legend(['Training data', 'Test data'])
#plt.savefig('UMAP_of_test_data_wgs_filt.pdf', dpi=1200)

plt.show()

In [None]:
# Sample latent codes from the N(0, I) prior, decode them, and project the decoded
# profiles with `reducer` (expected to hold the UMAP fitted on training data, with
# `transformed_true` its training embedding).
batch_size = 3000  # number of generated samples
z_dim = model.fc4.in_features  # latent size read from the decoder input width
z = torch.randn(batch_size, z_dim)

with torch.no_grad():
    decoded = model.decode(z)

decoded_np = decoded.numpy()

# project generated profiles into the data-space UMAP embedding
latent_umap = reducer.transform(decoded_np)

plt.scatter(transformed_true[:, 0], transformed_true[:, 1], color='red', alpha=0.1)
plt.scatter(latent_umap[:, 0], latent_umap[:, 1], color='blue', alpha=0.1,)
plt.xlabel('UMAP dimension 1')
plt.ylabel('UMAP dimension 2')
plt.title('UMAP of points sampled from latent space')
plt.legend(['Training data', 'Generated'])
#plt.savefig('UMAP_reconstr_train_wgs_filt.pdf', dpi=1200)

plt.show()

In [None]:
# Per-species confusion counts on the test set: for every feature (column) count
# TP/TN/FP/FN between the true 0/1 presence and the thresholded reconstruction,
# with one column per test batch.
prevalence = np.sum(X_train, axis=0) / X_train.shape[0]
prevalence[prevalence >= 0.5] = 1
prevalence[prevalence < 0.5] = 0

n_features = prevalence.shape[0]
n_batches = len(test_DS)  # sized from the loader rather than assuming a batch count
tps = np.zeros((n_features, n_batches))
tns = np.zeros((n_features, n_batches))
fps = np.zeros((n_features, n_batches))
fns = np.zeros((n_features, n_batches))

zero_thr = 0.5  # threshold on (reconstruction + prevalence) for calling a feature present
with torch.no_grad():
    for idx, batch in enumerate(test_DS):
        # undo the prevalence centring: true values back on the 0/1 scale
        true_x_bin = batch['Features'].detach().numpy() + prevalence

        recon_x, u, logvar = model(batch['Features'])
        recon_x_bin = recon_x.detach().numpy() + prevalence
        recon_x_bin[recon_x_bin > zero_thr] = 1
        recon_x_bin[recon_x_bin <= zero_thr] = 0
        print(recon_x_bin.shape)

        is_present = true_x_bin == 1
        is_called = recon_x_bin == 1
        # sum over samples (axis 0) -> one count per feature
        tps[:, idx] = np.sum(is_present & is_called, axis=0)
        fns[:, idx] = np.sum(is_present & ~is_called, axis=0)
        fps[:, idx] = np.sum(~is_present & is_called, axis=0)
        tns[:, idx] = np.sum(~is_present & ~is_called, axis=0)

sum_tps = np.sum(tps, axis=1)
sum_tns = np.sum(tns, axis=1)
sum_fps = np.sum(fps, axis=1)
sum_fns = np.sum(fns, axis=1)

# +0.5 in the denominators avoids 0/0 for species never present (or never absent)
# in the evaluated test samples
sensitivity = sum_tps / (sum_tps + sum_fns + 0.5)
specificity = sum_tns / (sum_tns + sum_fps + 0.5)

In [None]:
# display the last evaluated batch's true presence/absence matrix (0/1 after adding back prevalence)
true_x_bin

In [None]:
# Per-species sensitivity vs specificity. The legend's Spearman correlation is
# computed here (species with zero sensitivity excluded) so it stays in sync with
# the current model instead of being typed in by hand.
detected = sensitivity > 0.0
rho = stats.spearmanr(sensitivity[detected], specificity[detected])[0]

plt.scatter(x=sensitivity, y=specificity)
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Sensitivity")
plt.ylabel("Specificity")
plt.legend([f"Spearman r= {rho:.2f}"], frameon=False, loc=3)
# dashed guide lines from (start, 1) to (1, start)
for start in (0.0, 0.2, 0.4, 0.6, 0.8):
    plt.plot([start, 1], [1, start], ls="--", linewidth=".3", c=".3")

plt.gcf().set_size_inches(5, 5)
plt.savefig('graphs/VAE_wgs_filt_bin_per_species_weighted_loss.pdf', dpi=1000)

plt.show()

In [None]:
# Spearman correlation of per-species sensitivity vs specificity, species with zero sensitivity excluded
stats.spearmanr(np.compress(sensitivity > 0.0, sensitivity), np.compress(sensitivity > 0.0, specificity))

In [None]:
# Per-species specificity vs continuous training prevalence (fraction of training
# samples containing the species). The legend's Spearman correlation is computed
# over species with prevalence < 0.5.
prevalence = np.sum(X_train, axis=0) / X_train.shape[0]
below_half = prevalence < 0.5
rho = stats.spearmanr(prevalence[below_half], specificity[below_half])[0]

plt.scatter(x=prevalence, y=specificity)
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Prevalence")
plt.ylabel("Specificity")
plt.plot([0.0, 1], [1, 0.0], ls="--", linewidth=".3", c=".3")
plt.plot([0.2, 1], [1, 0.2], ls="--", linewidth=".3", c=".3")

plt.legend([f"Spearman r= {rho:.2f}"], frameon=False, loc=3)

plt.gcf().set_size_inches(5, 5)
plt.savefig('VAE_wgs_filt_bin_per_species_prev_spec_weighted_loss.pdf', dpi=1000)

plt.show()

In [None]:
# Spearman correlation of prevalence vs specificity over species with prevalence < 0.5
stats.spearmanr(np.compress(prevalence < 0.5, prevalence), np.compress(prevalence < 0.5, specificity))

In [None]:
# Per-species sensitivity vs continuous training prevalence. The legend's Spearman
# correlation is computed over species with non-zero sensitivity.
prevalence = np.sum(X_train, axis=0) / X_train.shape[0]
detected = sensitivity > 0.0
rho = stats.spearmanr(prevalence[detected], sensitivity[detected])[0]

plt.scatter(x=prevalence, y=sensitivity)

plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Prevalence")
plt.ylabel("Sensitivity")
plt.legend([f"Spearman r= {rho:.2f}"], frameon=False, loc=4)
plt.plot([0, 1], [0, 1], ls="--", linewidth=".3", c=".3")

plt.gcf().set_size_inches(5, 5)
plt.savefig('VAE_wgs_filt_bin_per_species_prev_sens_weighted_loss.pdf', dpi=1000)

plt.show()

In [None]:
# Spearman correlation of prevalence vs sensitivity over species with non-zero sensitivity
stats.spearmanr(np.compress(sensitivity > 0.0, prevalence), np.compress(sensitivity > 0.0, sensitivity))

In [None]:
##### rebalance the data???
# Indices of training samples containing rare species (prevalence < 0.3), listed
# species by species; a sample appears once per rare species it contains.
# NOTE(review): assumes `prevalence` currently holds the continuous training
# prevalence (other cells binarize it) — check execution order.
rare = np.where(prevalence < 0.3)[0]
X_data = X_train_scaled.numpy()[:, rare]
# nonzero on the transposed mask walks species first, then samples (ascending);
# element [1] holds the sample indices. Works (empty result) when no species is rare.
rare_samples = np.nonzero(X_data.T > 0)[1]

In [None]:
# shape of the set of distinct samples containing at least one rare species
(np.unique(rare_samples).size,)

In [None]:
# total (sample, rare species) hits, duplicates included
np.shape(rare_samples)

In [None]:
# (n_train_samples, n_features) of the centred training tensor
X_train_scaled.size()

In [None]:
### Check sample to sample variability
# Per-sample confusion counts on the training set: for every sample in every
# training batch, count TP/TN/FP/FN over all features. Counts are collected batch
# by batch, so no batch count or batch size has to be assumed.
prevalence = np.sum(X_train, axis=0) / X_train.shape[0]
prevalence[prevalence >= 0.5] = 1
prevalence[prevalence < 0.5] = 0

tp_parts, tn_parts, fp_parts, fn_parts = [], [], [], []
zero_thr = 0.5  # threshold on (reconstruction + prevalence) for calling a feature present
with torch.no_grad():
    for batch in train_DS:
        # undo the prevalence centring: true values back on the 0/1 scale
        true_x_bin = batch['Features'].detach().numpy() + prevalence

        recon_x, u, logvar = model(batch['Features'])
        recon_x_bin = recon_x.detach().numpy() + prevalence
        recon_x_bin[recon_x_bin > zero_thr] = 1
        recon_x_bin[recon_x_bin <= zero_thr] = 0
        print(recon_x_bin.shape)

        is_present = true_x_bin == 1
        is_called = recon_x_bin == 1
        # sum over features (axis 1) -> one count per sample
        tp_parts.append(np.sum(is_present & is_called, axis=1))
        fn_parts.append(np.sum(is_present & ~is_called, axis=1))
        fp_parts.append(np.sum(~is_present & is_called, axis=1))
        tn_parts.append(np.sum(~is_present & ~is_called, axis=1))

tps = np.concatenate(tp_parts)
tns = np.concatenate(tn_parts)
fps = np.concatenate(fp_parts)
fns = np.concatenate(fn_parts)

sum_tps = tps.flatten()
sum_tns = tns.flatten()
sum_fps = fps.flatten()
sum_fns = fns.flatten()

# no smoothing term: a sample with no present (or no absent) features yields NaN
sensitivity = sum_tps / (sum_tps + sum_fns)
specificity = sum_tns / (sum_tns + sum_fps)

In [None]:
# number of training samples
len(X_train)

In [None]:
# Per-sample sensitivity vs specificity on the training set. The legend's Spearman
# correlation is computed here so it tracks the current model.
rho = stats.spearmanr(sensitivity, specificity)[0]

plt.scatter(x=sensitivity, y=specificity, alpha=0.3)
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Sensitivity")
plt.ylabel("Specificity")
plt.legend([f"Spearman r= {rho:.2f}"], frameon=False, loc=3)
# dashed guide lines from (start, 1) to (1, start)
for start in (0.0, 0.2, 0.4, 0.6, 0.8):
    plt.plot([start, 1], [1, start], ls="--", linewidth=".3", c=".3")

plt.gcf().set_size_inches(5, 5)
plt.savefig('graphs/VAE_wgs_filt_bin_per_samples_train_weighted_loss.pdf', dpi=1000)

plt.show()

In [None]:
# Spearman correlation of per-sample sensitivity vs specificity (NaNs propagate)
stats.spearmanr(sensitivity, b=specificity)

In [None]:
### Renormalize data to sum=1 (enforce it?)

In [None]:
### TO DO: NN baseline for accuracy (reuse code from SNP-genes)


In [None]:
### TO DO: PCA->embed->reverse PCA (?)

In [None]:
### Categorize abundance for accuracy evaluation

In [None]:
### Reweight loss function with penalty proportional to the reverse distribution?

In [None]:
### try other data

In [None]:
### try to memorize perfectly

In [None]:
### extract what model learned?

In [None]:
### code refactoring