In [None]:
### import comet_ml at the top of your file
# from comet_ml import Experiment  ### special library to record performance on a web server
# from config import * ### file with personal API_KEY

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
from torch.utils.data import Dataset, TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from scipy import stats
from skbio.stats.composition import clr
use_cuda = torch.cuda.is_available()
### Pick the compute device. NOTE: the model/tensor .to(device) calls in this notebook
### are commented out, so `device` is currently informational only.
if use_cuda:
    ### keep the original preference for the second GPU, but fall back to the first
    ### one when only a single GPU is present ('cuda:1' would fail there)
    device = torch.device('cuda:1' if torch.cuda.device_count() > 1 else 'cuda:0')
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    ### previously this branch always chose 'mps', even on machines without it
    device = torch.device('cpu')
print(device)

### Record experiment on comet.com

In [None]:
# Create and initialize comet experiment with your api key.
# Disabled: to log metrics to comet.com, uncomment this together with the
# comet_ml / config imports in the first cell. Kept as comments rather than a
# bare string literal, which Jupyter would otherwise echo as this cell's output.
# experiment = Experiment(
#     api_key=API_KEY,
#     project_name="vae-eco-ml",
#     workspace="zireae1",
# )

### Load clean data

In [None]:
### data: rows are samples, columns are species, relative abundance sum ~ [80-100]
data_df = pd.read_csv("data/wgs_train_health.noab_data_to_ml.filt.txt", sep="\t")
### NOTE: `data_df` is re-bound to a plain numpy array here; later cells use it as such.
data_df = data_df.to_numpy()
print(data_df.shape)

### Abundance at or below this value counts as "absent". The next cell also uses it
### as the pseudocount of the log transform.
zero_thr = 1e-6
### inputs: binary presence/absence matrix
features = np.copy(data_df)
features[features > zero_thr] = 1  # binarize only features
features[features <= zero_thr] = 0
### targets: raw relative abundances (log-transformed and scaled in the next cell)
labels = np.copy(data_df)

print(features.shape[0])
print(labels.shape[0])
### check that features are binary. The masks above leave NaNs untouched, so this also
### fails if the input table contains missing values.
assert np.isin(features, (0, 1)).all(), "features are not strictly binary"
np.unique(features)

### Create dataset object

In [None]:
class CustomDataset(Dataset):
    """Map-style dataset pairing an input row with its target row.

    Each item is a dict ``{"Features": features[idx], "Labels": labels[idx]}``;
    PyTorch's default DataLoader collate function stacks such dicts per key.

    Subclasses ``Dataset`` rather than ``TensorDataset``: the previous version
    never called ``TensorDataset.__init__``, so the inherited ``tensors``
    attribute was missing.

    :param features: indexable inputs (in this notebook: binary presence tensor)
    :param labels: indexable targets of the same length (in this notebook: scaled abundances)
    :raises ValueError: if ``features`` and ``labels`` differ in length
    """

    def __init__(self, features, labels):
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        self.labels = labels
        self.features = features

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"Features": self.features[idx], "Labels": self.labels[idx]}

### prepare train test split (80/20, fixed seed)
X_train, X_test, y_train, y_test = train_test_split(features, labels, test_size=0.2, random_state=42)

### log-ratio transformation for outputs.
### NOTE(review): this is not a per-sample CLR (skbio's `clr` is imported but unused).
### A single scalar `gmean_train` from the whole training matrix is shared by all samples.
### The log-sum runs over positive entries only, while the divisor counts all entries,
### so it is not the geometric mean of y_train + zero_thr either. Kept as-is because the
### downstream presence threshold (-0.6) was picked on this exact scaling.
gmean_train = (np.exp(np.nansum(np.log(y_train[y_train > 0]+zero_thr)) / np.size(y_train)))
y_train_clr = np.log((y_train+zero_thr)/gmean_train)
y_test_clr = np.log((y_test+zero_thr)/gmean_train)

### rescale each species by its max |value| in the training split
### (training targets land in [-1, 1]; test targets may slightly exceed it)
scaler = preprocessing.MaxAbsScaler().fit(y_train_clr)
y_train_scaled = scaler.transform(y_train_clr)
y_test_scaled = scaler.transform(y_test_clr)

### transform to float32 tensors. NOTE: X_*_scaled are the binary presence features
### (not rescaled); only the y_* targets are scaled.
X_train_scaled = torch.from_numpy(X_train).float()
y_train_scaled = torch.from_numpy(y_train_scaled).float()
X_test_scaled = torch.from_numpy(X_test).float()
y_test_scaled = torch.from_numpy(y_test_scaled).float()

#### dataset build
Train = CustomDataset(X_train_scaled, y_train_scaled)
Test = CustomDataset(X_test_scaled, y_test_scaled)

### create batch splits of data
kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
train_DS = DataLoader(Train, batch_size=100, shuffle=True, drop_last=True, **kwargs)
### The test loader is not shuffled. With drop_last=True a shuffled loader scores a
### different random subset of test samples every epoch, which makes test loss/accuracy
### noisy across epochs. drop_last is kept because later cells size per-sample arrays
### from a single batch.
test_DS = DataLoader(Test, batch_size=100, shuffle=False, drop_last=True, **kwargs)

In [None]:
### Histogram of the scaled test targets; used to pick the presence threshold later on.
fig, ax = plt.subplots(figsize=(5, 5))
ax.hist(y_test_scaled.flatten(), bins=40)
#plt.savefig('wgs_filt.pdf', dpi=1000)
plt.show()

In [None]:
### Per-species baseline for the residual ("spr") loss: the mean scaled target over the
### training samples in which the species is present (binary feature == 1).
present_counts = torch.sum(X_train_scaled, dim=0)
present_sums = torch.sum(X_train_scaled * y_train_scaled, dim=0)
### A species never present in the training split would give 0/0 = NaN, which then
### turns the whole residual loss into NaN. Fall back to its mean over all samples.
fallback_means = torch.mean(y_train_scaled, dim=0)
train_means = torch.where(
    present_counts > 0,
    present_sums / present_counts.clamp(min=1),
    fallback_means,
).numpy()
train_means

### Implement VAE 

In [None]:
### Number of species = width of the target matrix; also the VAE input/output width.
input_dim = y_train.shape[1]
### Shared activation (torch.nn.ELU() was tried as an alternative).
relu = torch.nn.ReLU()
class VAE(torch.nn.Module):
    """Fully connected VAE from a species-presence vector to scaled abundances.

    Encoder: in_dim -> hidden_dim -> (mu, logvar), each of size latent_dim.
    Decoder: latent_dim -> hidden_dim -> in_dim, followed by tanh, so outputs
    lie in [-1, 1] to match the MaxAbs-scaled targets.

    With no arguments the architecture is the original one: input width from
    the notebook-level ``input_dim``, 1024 hidden units, 64 latent units and
    dropout 0.1. Layer attribute names (fc1, fc3a, fc3b, fc4, fc6) are
    unchanged, so existing state_dicts still load.

    NOTE(review): dropout is applied both before and after each ReLU (the
    drop probabilities compound) and also to the output layer before tanh.
    This is kept as-is but is worth revisiting.
    """

    def __init__(self, in_dim=None, hidden_dim=1024, latent_dim=64, dropout=0.1):
        super(VAE, self).__init__()
        # Fall back to the notebook-level input_dim so that VAE() keeps working.
        self.input_dim = input_dim if in_dim is None else in_dim
        # Layers are created in the original order, so seeded init is reproducible.
        self.fc1 = torch.nn.Linear(self.input_dim, hidden_dim)
        self.fc3a = torch.nn.Linear(hidden_dim, latent_dim)  # -> mu
        self.fc3b = torch.nn.Linear(hidden_dim, latent_dim)  # -> logvar
        self.fc4 = torch.nn.Linear(latent_dim, hidden_dim)
        self.fc6 = torch.nn.Linear(hidden_dim, self.input_dim)
        # Define proportion of neurons to dropout
        self.dropout = torch.nn.Dropout(dropout)

    def encode(self, x):
        """Map inputs to the latent Gaussian parameters; returns (mu, logvar)."""
        h = self.dropout(self.fc1(x))
        h = self.dropout(torch.relu(h))
        return (self.fc3a(h), self.fc3b(h))

    def decode(self, z):
        """Map latent codes back to tanh-squashed reconstructions in [-1, 1]."""
        h = self.dropout(self.fc4(z))
        h = self.dropout(torch.relu(h))
        h = self.dropout(self.fc6(h))
        return torch.tanh(h)

    def forward(self, x):
        """Encode, draw z = mu + eps * sigma (reparameterization), and decode.

        Returns (reconstruction, mu, logvar). The latent draw is random even
        in eval mode, so reconstructions are stochastic.
        """
        x = x.view(-1, self.input_dim)
        (u, logvar) = self.encode(x)
        stdev = torch.exp(0.5 * logvar)
        noise = torch.randn_like(stdev)
        z = u + noise * stdev
        return (self.decode(z), u, logvar)

### try to penalize many non-zero labels as in L1 lasso? [did not help]
def final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x,
               kl_weight=0.01, spr_weight=10, bce_weight=1):
    """
    Combine the reconstruction terms with the KL divergence into the VAE loss.

        loss = kl_weight * KLD + MSE + spr_weight * SPR + bce_weight * BCE

    where KLD is the KL divergence of N(u, exp(logvar)) from N(0, I), summed
    over all elements:

        KLD = -0.5 * sum(1 + logvar - u^2 - exp(logvar))

    :param mse_loss: MSE reconstruction loss
    :param bce_loss: BCE reconstruction loss (0 when that term is disabled)
    :param spr_loss: custom residual term of the reconstruction loss
    :param u: the mean from the latent vector
    :param logvar: log variance from the latent vector
    :param recon_x: reconstruction (unused; kept for call compatibility)
    :param kl_weight: weight of the KL term (default 0.01)
    :param spr_weight: weight of the residual term (default 10)
    :param bce_weight: weight of the BCE term (default 1)
    :return: loss as a scalar tensor
    """
    KLD = -0.5 * torch.sum(1 + logvar - u.pow(2) - logvar.exp())
    return kl_weight * KLD + mse_loss + spr_weight * spr_loss + bce_weight * bce_loss

### Model Initialization and Training

In [None]:
model = VAE()  # stays on CPU: the .to(device) calls in this notebook are commented out

### Loss functions (reduction='sum': summed over all samples and species of a batch)
criterion_mse = torch.nn.MSELoss(reduction='sum')  # an RMSE variant is sketched at the end of the notebook
criterion_bce = torch.nn.BCELoss(reduction='sum')  # unused: the BCE term did not help
criterion_l1 = torch.nn.L1Loss()                   # unused
criterion_kl_loss = nn.KLDivLoss(reduction="batchmean", log_target=True)  # unused

### Strength of the explicit weight penalty added to the loss below.
### NOTE: it penalizes the (unsquared) L2 norm of every parameter tensor, on top of
### Adam's own weight_decay.
l2_lambda = 0.01

### Using an Adam Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-7)

### Training configuration
epochs = 1000
patience = 100  # epochs without a new best mean test loss before stopping
### Cut-off on the tanh output that separates predicted "absent" from "present",
### picked from the scaled-target histogram. Deliberately NOT named zero_thr: that
### name holds the 1e-6 binarization threshold / log pseudocount of the data-preparation
### cells, which would produce NaNs if they were re-run after this cell.
recon_thr = -0.6

outputs = []
train_losses = []
test_losses = []
train_acc = []
test_acc = []
train_corr = []
test_corr = []
train_grad = []
best_test_loss = float('inf')
epochs_without_improvement = 0

### Per-species baseline of the residual ("spr") loss, hoisted out of the batch loops.
train_means_t = torch.from_numpy(train_means)[None, :]

for epoch in range(epochs):
    model = model.train()
    train_loss = 0.0  # accumulated loss (without the weight penalty)
    test_loss = 0.0
    grad_mag = 0.0    # gradient-magnitude logging is disabled; kept for train_grad
    accs = 0.0

    ### ---- VAE training ----
    for batch in train_DS:
        features_b, labels_b = batch['Features'], batch['Labels']

        optimizer.zero_grad()
        recon_x, u, logvar = model(features_b)

        mse_loss = criterion_mse(recon_x, labels_b)
        ### residual loss: MSE after subtracting the per-species training means
        spr_loss = criterion_mse(recon_x - train_means_t, labels_b - train_means_t)
        ### predicted presence/absence, used only for the accuracy metric
        out = (recon_x.detach() > recon_thr).float()
        bce_loss = 0  ### addition of BCE loss did not help

        loss_val = final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x)

        ### explicit weight-norm penalty over all trainable parameters
        l2_reg_loss = sum(torch.norm(p, 2) for p in model.parameters() if p.requires_grad)
        total_loss = loss_val + l2_lambda * l2_reg_loss

        total_loss.backward()
        train_loss += loss_val.item()
        accs += float((out == features_b).float().mean())
        optimizer.step()

    n_train_batches = len(train_DS)
    train_losses.append(train_loss / n_train_batches)
    train_acc.append(accs / n_train_batches)
    train_grad.append(grad_mag)

    ### ---- VAE evaluation ----
    accs = 0.0
    model = model.eval()
    with torch.no_grad():
        for batch in test_DS:
            features_b, labels_b = batch['Features'], batch['Labels']
            recon_x, u, logvar = model(features_b)

            mse_loss = criterion_mse(recon_x, labels_b)
            spr_loss = criterion_mse(recon_x - train_means_t, labels_b - train_means_t)
            out = (recon_x > recon_thr).float()
            bce_loss = 0

            loss_val = final_loss(mse_loss, bce_loss, spr_loss, u, logvar, recon_x)
            test_loss += loss_val.item()
            accs += float((out == features_b).float().mean())

    mean_test_loss = test_loss / len(test_DS)
    test_losses.append(mean_test_loss)
    test_acc.append(accs / len(test_DS))

    ### Store metrics for comet (disabled):
    # experiment.log_metrics({'train loss': train_loss / n_train_batches,
    #                         'test loss': mean_test_loss,
    #                         'accuracy': accs / len(test_DS)}, step=epoch)

    ### Early stopping on the mean test loss. The previous check compared the summed
    ### train loss with the summed test loss; the train split has ~4x as many batches,
    ### so that comparison was almost always true and training never stopped early.
    if mean_test_loss < best_test_loss:
        best_test_loss = mean_test_loss
        epochs_without_improvement = 0
    else:
        epochs_without_improvement += 1

    if epochs_without_improvement >= patience:
        print("early stop")
        break

# experiment.end() ### end comet experiment

In [None]:
### Distribution of reconstructed values (the last `recon_x` computed before this cell,
### i.e. the final evaluation batch when cells are run in order).
recon_values = recon_x.detach().numpy().ravel()
plt.hist(recon_values, bins=40)
plt.gcf().set_size_inches(5, 5)
#plt.savefig('wgs_filt_clr_mse_loss_wide_v2.pdf', dpi=1000)
plt.show()

### Look at test set reconstruction quality

In [None]:
#### check how good the reconstructions are (plot matrices as heatmaps):
### Both panels share the color range [-1, 1] (tanh output range; targets are
### MaxAbs-scaled), so colors are comparable between True and Reconstructed.
model.eval()
with torch.no_grad():
    for idx, batch in enumerate(test_DS):
        print(idx)
        true_x = batch['Labels'].numpy()
        ### NOTE: forward() samples the latent code, so reconstructions vary between runs
        recon_x, u, logvar = model(batch['Features'])
        a = recon_x.numpy()

        fig, (ax1, ax2) = plt.subplots(1, 2)
        ax1.imshow(true_x, cmap='bwr', interpolation='nearest', vmin=-1, vmax=1)
        ax1.set_title('True')
        ax2.imshow(a, cmap='bwr', interpolation='nearest', vmin=-1, vmax=1)
        ax2.set_title('Reconstructed')
        print(stats.spearmanr(true_x.flatten(), a.flatten()))
        #plt.savefig('batch_reconstruction_masked_v1_{}_wide.pdf'.format(idx), dpi=1300)
        plt.show()


In [None]:
#### check how good the reconstructions are (predicted vs observed 2-D histogram):
model.eval()
with torch.no_grad():
    for idx, batch in enumerate(test_DS):
        print(idx)
        true_x = batch['Labels'].numpy()
        recon_x, u, logvar = model(batch['Features'])
        a = recon_x.numpy()

        ### PowerNorm(1/10) compresses the color scale so sparse bins remain visible
        plt.hist2d(x=true_x.flatten(), y=a.flatten(), norm=matplotlib.colors.PowerNorm(1/10), bins=100)
        plt.xlabel('True (scaled)')
        plt.ylabel('Reconstructed')
        print(stats.spearmanr(true_x.flatten(), a.flatten()))
        #plt.savefig('batch_reconstruction_masked_v1_{}_wide.pdf'.format(idx), dpi=1300)
        plt.show()


### Plot model training curves

In [None]:
### Training curves: accuracy (solid) and loss (dashed) per epoch, both on a 0-100 scale.
### Local array copies are used so the metric lists from the training cell stay lists.
train_acc_arr = np.asarray(train_acc)
test_acc_arr = np.asarray(test_acc)
train_losses_arr = np.asarray(train_losses)
test_losses_arr = np.asarray(test_losses)
epoch_idx = np.arange(train_acc_arr.shape[0])

plt.plot(epoch_idx, train_acc_arr*100, c='blue', label="train_acc")
plt.plot(epoch_idx, test_acc_arr*100, c='red', label="test_acc")

### Baseline: for every test sample, predict each species' majority state in the training
### split (present if prevalence >= 0.5), scored like the model accuracy (matching entries).
majority_state = (np.sum(X_train, axis=0)/X_train.shape[0] >= 0.5).astype(float)
baseline_acc = (1 - np.mean(np.abs(X_test_scaled.numpy() - majority_state)))*100
plt.axhline(y=baseline_acc, color='orange', linestyle='-.', label="majority baseline")

### losses expressed as percent of the largest training loss
loss_scale = 100/np.max(train_losses_arr)
plt.plot(epoch_idx, train_losses_arr*loss_scale, c='blue', label="train_loss", ls='--')
plt.plot(epoch_idx, test_losses_arr*loss_scale, c='red', label="test_loss", ls='--')

plt.xlabel("Epoch")
plt.ylabel("Accuracy, % / loss, % of max train loss")
plt.ylim([0, 100])
### x-range follows the epochs actually run (training may stop early)
plt.xlim([0, max(epoch_idx.shape[0], 1)])
plt.legend(frameon=False)
#plt.savefig('VAE_masked_wgs_filt_nonb_lat8_wide.pdf', dpi=1000)
plt.show()

### Count true/false positive/negative _per species_ statistics

In [None]:
### Per-species confusion counts of predicted vs. true presence on the test set.
### A species counts as predicted present when its reconstruction exceeds recon_thr.
### (Named recon_thr, not zero_thr, so the 1e-6 threshold/pseudocount of the
### data-preparation cells is not overwritten.)
recon_thr = -0.6
tp_cols, tn_cols, fp_cols, fn_cols = [], [], [], []
model.eval()
with torch.no_grad():
    for batch in test_DS:
        true_present = batch['Features'].numpy() == 1
        ### NOTE: forward() samples the latent code, so counts vary slightly between runs
        recon_x, u, logvar = model(batch['Features'])
        pred_present = recon_x.numpy() > recon_thr
        ### one count per species (column) over the samples of this batch
        tp_cols.append(np.sum(true_present & pred_present, axis=0))
        tn_cols.append(np.sum(~true_present & ~pred_present, axis=0))
        fp_cols.append(np.sum(~true_present & pred_present, axis=0))
        fn_cols.append(np.sum(true_present & ~pred_present, axis=0))

### (n_species, n_batches) count matrices
tps = np.stack(tp_cols, axis=1).astype(float)
tns = np.stack(tn_cols, axis=1).astype(float)
fps = np.stack(fp_cols, axis=1).astype(float)
fns = np.stack(fn_cols, axis=1).astype(float)

### Summarize across batches
sum_tps = np.sum(tps, axis=1)
sum_tns = np.sum(tns, axis=1)
sum_fps = np.sum(fps, axis=1)
sum_fns = np.sum(fns, axis=1)

### NaN (with a RuntimeWarning) for species never present (sensitivity) or never
### absent (specificity) among the evaluated test samples.
sensitivity = sum_tps/(sum_tps+sum_fns)
specificity = sum_tns/(sum_tns+sum_fps)

In [None]:
# Number of test mini-batches (the test loader uses drop_last=True, so a final partial batch is not counted).
len(test_DS)

In [None]:
### Per-species sensitivity vs. specificity, with reference lines sens + spec = 1, 1.2, 1.4.
plt.scatter(x=sensitivity, y=specificity)
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Sensitivity")
plt.ylabel("Specificity")
for offset in (0.0, 0.2, 0.4):
    plt.plot([offset, 1], [1, offset], ls="--", linewidth=".3", c=".3")

plt.gcf().set_size_inches(5, 5)
### Spearman r over species with finite metrics (any NaN would make r NaN)
finite = np.isfinite(sensitivity) & np.isfinite(specificity)
rho = stats.spearmanr(sensitivity[finite], specificity[finite])[0]
plt.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#plt.savefig('VAE_wgs_filt_nonb_per_species_metrics_residual_loss_wide.pdf', dpi=1000)
plt.show()

In [None]:
### Per-species training prevalence vs. test specificity.
prevalence = np.sum(X_train, axis=0)/X_train.shape[0]
plt.scatter(x=prevalence, y=specificity)
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Prevalence")
plt.ylabel("Specificity")
plt.plot([0.0, 1], [1, 0.0], ls="--", linewidth=".3", c=".3")

plt.gcf().set_size_inches(5, 5)
### Spearman r over species with a finite specificity (any NaN would make r NaN)
finite = np.isfinite(prevalence) & np.isfinite(specificity)
rho = stats.spearmanr(prevalence[finite], specificity[finite])[0]
plt.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#plt.savefig('VAE_wgs_filt_nonb_per_species_prev_residual_loss_wide.pdf', dpi=1000)
plt.show()

In [None]:
### Per-species training prevalence vs. test sensitivity.
prevalence = np.sum(X_train, axis=0)/X_train.shape[0]
plt.scatter(x=prevalence, y=sensitivity)
plt.ylim([0, 1])
plt.xlim([0, 1])
plt.xlabel("Prevalence")
plt.ylabel("Sensitivity")

plt.gcf().set_size_inches(5, 5)
### Spearman r over species with a finite sensitivity (any NaN would make r NaN)
finite = np.isfinite(prevalence) & np.isfinite(sensitivity)
rho = stats.spearmanr(prevalence[finite], sensitivity[finite])[0]
plt.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#plt.savefig('VAE_wgs_filt_nonb_per_species_prev_sens_residual_loss_wide.pdf', dpi=1000)
plt.show()

### Count true/false positive/negative _per sample_ statistics

In [None]:
### Per-sample confusion counts of predicted vs. true presence on the test set.
### A species counts as predicted present when its reconstruction exceeds recon_thr.
recon_thr = -0.6
tp_rows, tn_rows, fp_rows, fn_rows = [], [], [], []
model.eval()
with torch.no_grad():
    for batch in test_DS:
        true_present = batch['Features'].numpy() == 1
        recon_x, u, logvar = model(batch['Features'])
        pred_present = recon_x.numpy() > recon_thr
        ### one count per sample (row) over all species
        tp_rows.append(np.sum(true_present & pred_present, axis=1))
        tn_rows.append(np.sum(~true_present & ~pred_present, axis=1))
        fp_rows.append(np.sum(~true_present & pred_present, axis=1))
        fn_rows.append(np.sum(true_present & ~pred_present, axis=1))

### One entry per evaluated test sample, in batch order. Concatenation (rather than a
### fixed-size matrix) also handles a smaller final batch.
sum_tps = np.concatenate(tp_rows).astype(float)
sum_tns = np.concatenate(tn_rows).astype(float)
sum_fps = np.concatenate(fp_rows).astype(float)
sum_fns = np.concatenate(fn_rows).astype(float)

### NaN for samples with no present (sensitivity) or no absent (specificity) species
sensitivity = sum_tps/(sum_tps+sum_fns)
specificity = sum_tns/(sum_tns+sum_fps)

In [None]:
### Per-sample sensitivity vs. specificity, with reference lines sens + spec = 1 ... 1.8.
plt.scatter(x=sensitivity, y=specificity, alpha=0.3)
plt.ylim([0, 1.1])
plt.xlim([0, 1.1])
plt.xlabel("Sensitivity")
plt.ylabel("Specificity")
for offset in (0.0, 0.2, 0.4, 0.6, 0.8):
    plt.plot([offset, 1], [1, offset], ls="--", linewidth=".3", c=".3")

plt.gcf().set_size_inches(5, 5)
### Spearman r over samples with finite metrics (any NaN would make r NaN)
finite = np.isfinite(sensitivity) & np.isfinite(specificity)
rho = stats.spearmanr(sensitivity[finite], specificity[finite])[0]
plt.legend(["Spearman r=" + str(np.round(rho, decimals=3))], frameon=False, loc=3)

#plt.savefig('graphs/VAE_wgs_filt_nonbin_per_samples_train_weighted_loss_wide.pdf', dpi=1000)
plt.show()

### Let's look at the latent space:

In [None]:
### UMAP is only needed for the projection plots below. The numba version is shown as
### this cell's output (presumably to check the umap install). Previously it was not
### the last expression, so Jupyter never displayed it.
import numba
import umap
numba.__version__

In [None]:
### Encode a random sample of training inputs and project the latent means to 2-D with UMAP.
n_points = 1000  # number of training samples to project
indices = torch.randperm(len(X_train_scaled))[:n_points]
x_sample = X_train_scaled[indices]

### encode() returns (mu, logvar); the means are projected. Eval mode disables dropout.
model.eval()
with torch.no_grad():
    mu, logvar = model.encode(x_sample)
encoded_np = mu.numpy()

reducer = umap.UMAP(n_components=2, random_state=0)
latent_umap = reducer.fit_transform(encoded_np)

plt.scatter(latent_umap[:, 0], latent_umap[:, 1], color='red', alpha=0.1)
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.title('UMAP of points projected to latent space')
plt.legend(['Training data'])
#plt.savefig('UMAP_of_latent_projections_wgs_filt.pdf', dpi=1200)
plt.show()

In [None]:
### Data-space UMAP: fit on a random sample of scaled training targets, then project the
### test targets. A fresh reducer is created here so this cell does not depend on the
### previous one. It is fitted once; the old version fitted it twice on the same data.
### `reducer` stays fitted to this embedding and is reused by the generation cell.
n_points = 1000
indices = torch.randperm(len(y_train_scaled))[:n_points]
reducer = umap.UMAP(n_components=2, random_state=0)
transformed_true = reducer.fit_transform(y_train_scaled[indices])
data_umap = transformed_true

plt.scatter(transformed_true[:, 0], transformed_true[:, 1], color='red', alpha=0.1)
test_umap = reducer.transform(y_test_scaled)
plt.scatter(test_umap[:, 0], test_umap[:, 1], color='blue', alpha=0.1)

plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.title('UMAP of points sampled from training and test data')
plt.legend(['Training data', 'Test data'])
#plt.savefig('UMAP_of_test_data_wgs_filt.pdf', dpi=1200)
plt.show()

In [None]:
### Sample latent codes from the N(0, I) prior, decode them, and project the generated
### data into the data-space UMAP fitted on training targets in the previous cell.
n_generated = 3000
### latent width read from the trained model instead of hard-coding 64
z_dim = model.fc3a.out_features
z = torch.randn(n_generated, z_dim)

model.eval()  # disable dropout in the decoder
with torch.no_grad():
    decoded = model.decode(z)
decoded_np = decoded.numpy()

latent_umap = reducer.transform(decoded_np)

plt.scatter(transformed_true[:, 0], transformed_true[:, 1], color='red', alpha=0.1)
plt.scatter(latent_umap[:, 0], latent_umap[:, 1], color='blue', alpha=0.1)
plt.xlabel('UMAP 1')
plt.ylabel('UMAP 2')
plt.title('UMAP of points sampled from latent space')
plt.legend(['Training data', 'Generated'])
#plt.savefig('UMAP_reconstr_train_wgs_filt_wide_mse.pdf', dpi=1200)
plt.show()

In [None]:
### check inter-species Spearman correlations on the raw abundances
### (dense correlations; a sparsity-aware estimate would be better)
### `[0]` instead of `.statistic`: works on all scipy versions and matches the other cells.
cormat = stats.spearmanr(labels, axis=0)[0]
np.fill_diagonal(cormat, 0, wrap=False)
### NaN-aware: constant species columns yield NaN correlations
print(np.nanmin(cormat))
print(np.nanmax(cormat))
### symmetric limits so the diverging colormap's midpoint corresponds to r = 0
plt.imshow(cormat, cmap='seismic', interpolation='nearest', vmin=-1, vmax=1)
plt.colorbar(label='Spearman r')
#plt.savefig('feature_correlations_spearman_wgs.pdf', dpi=1200)
plt.show()

In [None]:
### Distribution of pairwise species correlations. Only the upper triangle is used (each
### pair once, excluding the zeroed diagonal); NaNs are dropped because np.histogram
### cannot bin them.
pair_corrs = cormat[np.triu_indices_from(cormat, k=1)]
pair_corrs = pair_corrs[np.isfinite(pair_corrs)]
plt.hist(pair_corrs, bins=60)
plt.xlabel('Spearman r')
plt.ylabel('Species pairs')
plt.show()

In [None]:
# Raw abundance matrix (a numpy array after .to_numpy() in the data-loading cell).
data_df

In [None]:
####### RMSE loss function: drop-in alternative for criterion_mse
####### (e.g. criterion_mse = RMSELoss()). Must be defined before the training cell runs.
####### Previously this was kept inside a string literal, which was unusable and was
####### echoed as the cell's output.
class RMSELoss(nn.Module):
    """Root of the summed squared error: sqrt(sum((yhat - y)^2) + eps).

    eps keeps the gradient from becoming NaN when the error is exactly zero.
    """

    def __init__(self, eps=1e-6):
        super().__init__()
        self.mse = nn.MSELoss(reduction='sum')
        self.eps = eps

    def forward(self, yhat, y):
        return torch.sqrt(self.mse(yhat, y) + self.eps)