In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.distributions import Categorical
import matplotlib.pyplot as plt
import matplotlib
import pandas as pd
import numpy as np
import pickle
from torch.utils.data import Dataset, TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn import preprocessing
from sklearn.preprocessing import StandardScaler, RobustScaler, MaxAbsScaler
from scipy import stats
from tqdm import tqdm
# from skbio.stats.composition import clr
use_cuda = torch.cuda.is_available()
# Pick the best available backend: CUDA first, then Apple-silicon MPS, and
# finally CPU. Previously the fallback was always 'mps', which makes every
# later `.to(device)` fail on machines that have neither accelerator.
_mps_backend = getattr(torch.backends, 'mps', None)  # absent on old PyTorch builds
if use_cuda:
    device = torch.device('cuda:0')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print(device)

from utils import *

import matplotlib as mpl
# Global matplotlib defaults applied to every figure created after this cell.
mpl.rcParams['font.family'] = 'Arial'
mpl.rcParams['font.size'] = 14

cuda:0


#### Preprocess synthetic data

In [2]:
def save_preprocessed(p_data, name_data, min_thres=1e-5, min_prevalence=0.05, out_dir="../data"):
    """Filter a pickled abundance table and save presence/abundance arrays.

    Processing steps, in order:
      1. Drop the first column (it is excluded from the abundance matrix;
         presumably a sample identifier -- confirm against the data files).
      2. Set every value ``<= min_thres`` to 0.
      3. Keep only columns that are non-zero in at least
         ``min_prevalence * n_samples`` rows, where ``n_samples`` is the row
         count before any row is removed.
      4. Drop rows whose remaining values sum to 0.
      5. Divide each row by its sum so that every row sums to 1.

    Parameters
    ----------
    p_data : str
        Path to a pickled ``pandas.DataFrame``.
    name_data : str
        Stem of the output file; the result is written to
        ``{out_dir}/{name_data}_filtered.pt``.
    min_thres : float, optional
        Values at or below this threshold are zeroed. Default ``1e-5``.
    min_prevalence : float, optional
        Minimum fraction of rows in which a column must be non-zero to be
        kept. Default ``0.05``.
    out_dir : str, optional
        Directory that receives the output file. Default ``"../data"``.

    Returns
    -------
    None
        The function prints the resulting shape and writes a dict with keys
        ``"features"`` (0/1 presence matrix as float), ``"labels"``
        (row-normalised abundances) and ``"names"`` (kept column names).
    """
    # NOTE(security): unpickling executes arbitrary code -- only load files
    # that were produced by a trusted pipeline.
    with open(p_data, "rb") as fh:  # context manager closes the handle
        raw_df = pickle.load(fh)

    # Work on a fresh frame (``mask`` returns a copy), so the thresholding
    # below never writes into a view of ``raw_df``.
    abundance = raw_df.iloc[:, 1:]
    abundance = abundance.mask(abundance <= min_thres, 0)

    # Remove low-prevalence columns.
    n_samples = abundance.shape[0]
    is_prevalent = (abundance > 0).sum(axis=0) >= min_prevalence * n_samples
    abundance = abundance.loc[:, is_prevalent]

    # Remove rows that became empty; they cannot be normalised.
    abundance = abundance.loc[abundance.sum(axis=1) > 0]

    # Normalise each row to relative abundance.
    abundance = abundance.div(abundance.sum(axis=1), axis=0)

    # Model inputs are presence/absence flags, targets are the abundances.
    features = (abundance > 0).values.astype(float)
    labels = abundance.values
    print(f"{features.shape[0]} samples, {features.shape[1]} families")

    # Save the preprocessed arrays together with the kept column names.
    torch.save(
        {"features": features, "labels": labels, "names": list(abundance.columns)},
        f"{out_dir}/{name_data}_filtered.pt",
    )

# Values substituted into the `trophic_{hub}` file names; one pickle is
# preprocessed per entry.
hubs = [1, 3, 6, 12, 24]
for hub in hubs:
    save_preprocessed(
        p_data=f"../synthetic_data/data_v4/trophic_{hub}_all_diets.pkl",
        name_data=f"synthetic_v4_trophic_{hub}",
    )

10000 samples, 64 families
10000 samples, 79 families
10000 samples, 88 families
10000 samples, 92 families
10000 samples, 95 families


### Load, split and normalize

In [3]:
def load_filtered_data(p_data):
    """Load a preprocessed ``.pt`` file, split it 80/20 and scale the targets.

    Parameters
    ----------
    p_data : str
        Path to a file holding a dict with keys ``"features"``, ``"labels"``
        and ``"names"``, as read below.

    Returns
    -------
    tuple
        ``(X_train, X_test, y_train, y_test, annotations)`` where the first
        four entries are ``torch.float32`` tensors and ``annotations`` is a
        dict with keys ``"spcs"`` (the stored names), ``"train_idx"`` and
        ``"test_idx"`` (row indices of the split into the stored arrays).
    """
    # Deserialise the file once (the previous code read it three times).
    # The payload holds NumPy arrays, which PyTorch >= 2.6 refuses to load
    # under its default ``weights_only=True``; the file is produced by this
    # notebook, so full unpickling is acceptable. Older PyTorch versions do
    # not know the keyword, hence the fallback.
    try:
        payload = torch.load(p_data, weights_only=False)
    except TypeError:
        payload = torch.load(p_data)
    features = payload["features"]
    labels = payload["labels"]
    spc_names = payload["names"]

    # Split row indices rather than the arrays so the caller can map
    # predictions back to the original rows.
    original_indices = np.arange(len(features))
    train_indices, test_indices = train_test_split(
        original_indices,
        test_size=0.2,
        random_state=42,
    )
    X_train = features[train_indices]
    X_test = features[test_indices]
    y_train = labels[train_indices]
    y_test = labels[test_indices]

    ## log-ratio transformation for outputs
    # ``gmean_train`` is a single scalar: exp of (sum of log(y + zero_thr)
    # over the strictly positive training entries, divided by the TOTAL
    # number of training entries). Train and test are both divided by it.
    # NOTE(review): this is not the textbook per-sample CLR; it is kept
    # unchanged so previously trained models remain comparable.
    zero_thr = 1e-8
    gmean_train = (np.exp(np.nansum(np.log(y_train[y_train > 0]+zero_thr)) / np.size(y_train)))
    y_train_clr = np.log((y_train+zero_thr)/gmean_train)
    y_test_clr = np.log((y_test+zero_thr)/gmean_train)

    ## rescale the data; the scaler is fitted on the training targets only
    scaler = preprocessing.MaxAbsScaler().fit(y_train_clr)
    y_train_scaled = scaler.transform(y_train_clr)
    y_test_scaled = scaler.transform(y_test_clr)

    ## transform to tensors (features are passed through unscaled)
    X_train_scaled = torch.from_numpy(X_train).float()
    y_train_scaled = torch.from_numpy(y_train_scaled).float()
    X_test_scaled = torch.from_numpy(X_test).float()
    y_test_scaled = torch.from_numpy(y_test_scaled).float()

    # keep the spc names and train/test split for evaluation use
    annotations = {
        "spcs": spc_names,
        "train_idx": train_indices,
        "test_idx": test_indices,
    }
    return X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, annotations

### Apply autoencoder

In [4]:
class VAETrainer:
    """Runs training and evaluation epochs for a model scored by ``compute_loss_2``.

    Each batch produced by the loaders must be a mapping with the keys
    ``'Features'`` and ``'Labels'``. ``compute_loss_2`` is called as
    ``compute_loss_2(model, features, labels, weights)`` and must return six
    values; the first is thresholded at 0.5 and compared with the features to
    measure accuracy, the third is the loss that is back-propagated, and the
    fourth to sixth are accumulated as reconstruction, BCE and KL terms.
    """

    def __init__(self, model, train_loader, test_loader, optimizer, weights, device="cuda:0"):
        """Store the collaborators and move ``model`` to ``device``.

        Parameters
        ----------
        model : torch.nn.Module
            Model to train; moved to ``device``.
        train_loader, test_loader : iterable
            Batch iterables used by ``train_one_epoch`` / ``test_one_epoch``.
        optimizer : torch.optim.Optimizer
            Optimizer stepping the model parameters.
        weights : sequence
            Forwarded unchanged to ``compute_loss_2``.
        device : str or torch.device, optional
            Device for the model and for every batch. Default ``"cuda:0"``.
        """
        self.model = model.to(device)
        self.train_loader = train_loader
        self.test_loader = test_loader
        self.optimizer = optimizer
        self.weights = weights
        self.device = device

    def _run_epoch(self, loader, train):
        """Iterate once over ``loader``; update the weights only when ``train`` is true."""
        self.model.train(train)  # train(False) is equivalent to eval()
        # Running sums of: total loss, reconstruction loss, BCE loss, KL loss.
        sums = [0.0, 0.0, 0.0, 0.0]
        n_correct = 0.0
        n_samples = 0
        n_features = None

        with torch.set_grad_enabled(train):
            for batch in loader:
                # Use the trainer's own device (not a notebook-level global)
                # and leave the batch dict untouched.
                features = batch['Features'].to(self.device)
                labels = batch['Labels'].to(self.device)

                if train:
                    self.optimizer.zero_grad()
                presence, _, loss, recon_loss, bce_loss, kl_loss = compute_loss_2(
                    self.model, features, labels, self.weights
                )
                if train:
                    loss.backward()
                    self.optimizer.step()

                for i, term in enumerate((loss, recon_loss, bce_loss, kl_loss)):
                    sums[i] += term.item()

                # Element-wise presence/absence accuracy at a 0.5 cut-off.
                predicted = (presence.detach() > 0.5).float()
                n_correct += (predicted == features).float().sum().item()
                n_samples += features.shape[0]
                n_features = features.shape[1]

        if n_samples == 0:
            raise ValueError("data loader produced no samples")

        # Losses are divided by the number of samples; accuracy additionally
        # by the number of features so that it lies in [0, 1].
        return (
            sums[0] / n_samples,
            sums[1] / n_samples,
            sums[2] / n_samples,
            sums[3] / n_samples,
            n_correct / n_samples / n_features,
        )

    def train_one_epoch(self, zero_thr):
        """Train for one pass over ``train_loader``.

        ``zero_thr`` is accepted for backward compatibility and is not used.

        Returns
        -------
        tuple of float
            ``(loss, recon_loss, bce_loss, kl_loss, accuracy)``, the four
            losses being sums over batches divided by the sample count.

        Raises
        ------
        ValueError
            If the loader yields no samples.
        """
        return self._run_epoch(self.train_loader, train=True)

    def test_one_epoch(self, zero_thr):
        """Evaluate for one pass over ``test_loader`` with gradients disabled.

        ``zero_thr`` is accepted for backward compatibility and is not used.
        Returns the same 5-tuple as ``train_one_epoch`` and raises
        ``ValueError`` if the loader yields no samples.
        """
        return self._run_epoch(self.test_loader, train=False)

In [5]:
class _PresenceAbundanceDataset(Dataset):
    """Map-style dataset yielding ``{"Features": ..., "Labels": ...}`` per row."""

    def __init__(self, features, labels):
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        return {"Features": self.features[idx], "Labels": self.labels[idx]}


def _plot_training_history(history, baseline_acc, handle):
    """Save and show the accuracy and loss curves stored in ``history``."""
    epoch_axis = range(len(history["train_acc"]))

    # Accuracy curves plus the majority-class baseline as a horizontal line.
    fig, ax = plt.subplots()
    ax.plot(epoch_axis, np.array(history["train_acc"]) * 100, c='blue', label="train_acc")
    ax.plot(epoch_axis, np.array(history["test_acc"]) * 100, c='red', label="test_acc")
    ax.axhline(y=baseline_acc, color='orange', linestyle='-.')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Accuracy, %")
    ax.set_ylim([0, 100])
    ax.legend(frameon=False)
    fig.savefig(f"../figures/trainingproc_acc_{handle}.pdf", bbox_inches="tight")
    plt.show()
    plt.close(fig)  # release the figure; matters when looping over datasets

    # Per-term test losses on a log scale, on their own figure.
    fig, ax = plt.subplots()
    for key, loss_name in (("MSE", "MSE"), ("BCE", "Binary"), ("KLD", "KLD")):
        ax.plot(epoch_axis, history[key], label=loss_name)
    ax.legend(frameon=False)
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_yscale("log")
    fig.savefig(f"../figures/trainingproc_loss_{handle}.pdf", bbox_inches="tight")
    plt.show()
    plt.close(fig)


def train_VAE_on_data(p_data, p_output, weights=(1.0, 1.0, 0.0), hidden_dim=2048,
                      latent_dim=512, epochs=100, lr=7e-4, weight_decay=3e-5,
                      batch_size=100, zero_thr=-0.8):
    '''
    Train a ``VAE_2`` model on one preprocessed data set, save it and plot the run.

    Parameters
    ----------
    p_data : str
        Path passed to ``load_filtered_data``.
    p_output : str
        Destination of the saved dict ``{"model": state_dict, "annotations": ...}``.
    weights : sequence of 3 floats
        recon abundance loss; presence bce loss; KL divergence.
    hidden_dim, latent_dim : int
        Layer sizes passed to ``VAE_2``.
    epochs : int
        Number of train + test passes.
    lr, weight_decay : float
        ``AdamW`` settings.
    batch_size : int
        Batch size of both data loaders.
    zero_thr : float
        Forwarded to the trainer's epoch methods.

    Returns
    -------
    dict
        Per-epoch lists under the keys ``"train_loss"``, ``"test_loss"``,
        ``"train_acc"``, ``"test_acc"`` and the test-set loss terms ``"MSE"``,
        ``"BCE"``, ``"KLD"``. Earlier versions returned ``None``.
    '''
    from pathlib import Path  # local import: keeps the edit self-contained

    X_train_scaled, X_test_scaled, y_train_scaled, y_test_scaled, annotations = load_filtered_data(p_data)
    spc_names = annotations["spcs"]
    weights = list(weights)  # hand the loss function a list, as before

    # Make sure the output locations exist before spending time on training.
    Path(p_output).parent.mkdir(parents=True, exist_ok=True)
    Path("../figures").mkdir(parents=True, exist_ok=True)

    train_set = _PresenceAbundanceDataset(X_train_scaled, y_train_scaled)
    test_set = _PresenceAbundanceDataset(X_test_scaled, y_test_scaled)

    input_dim = len(spc_names)
    model = VAE_2(input_dim, latent_dim, hidden_dim).to(device)
    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=lr,
        weight_decay=weight_decay
    )

    ## create batch splits of data
    kwargs = {'num_workers': 4, 'pin_memory': True} if use_cuda else {}
    train_DS = DataLoader(train_set, batch_size=batch_size, shuffle=True, drop_last=False, **kwargs)
    # Evaluation metrics are sums over all batches, so shuffling the test
    # set only costs time and random numbers.
    test_DS = DataLoader(test_set, batch_size=batch_size, shuffle=False, drop_last=False, **kwargs)

    trainer = VAETrainer(model, train_DS, test_DS, optimizer, weights, device)

    history = {"train_loss": [], "test_loss": [], "train_acc": [], "test_acc": [],
               "MSE": [], "BCE": [], "KLD": []}
    for _ in tqdm(range(epochs)):
        train_loss, _, _, _, acc = trainer.train_one_epoch(zero_thr)
        history["train_loss"].append(train_loss)
        history["train_acc"].append(acc)

        # Only the test-set loss terms are recorded and plotted.
        test_loss, MSE, BCE, KLD, acc = trainer.test_one_epoch(zero_thr)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(acc)
        history["MSE"].append(MSE)
        history["BCE"].append(BCE)
        history["KLD"].append(KLD)

    torch.save({"model": model.state_dict(), "annotations": annotations}, p_output)

    # Baseline: predict each column's majority class (present if the training
    # prevalence is >= 0.5) and measure element-wise accuracy on the test set.
    prevalence = np.sum(X_train_scaled.numpy(), axis=0) / X_train_scaled.shape[0]
    majority = (prevalence >= 0.5).astype(prevalence.dtype)
    baseline_acc = (1 - np.sum(np.abs(X_test_scaled.numpy() - majority))
                    / X_test_scaled.shape[0] / majority.shape[0]) * 100

    ## plot the training process
    handle = Path(p_data).stem  # file name without its extension
    _plot_training_history(history, baseline_acc, handle)

    return history

In [6]:
# Sanity check: load the split for hub = 1 and report whether the test
# features contain any NaN (the last expression is the cell output).
hub = 1
p_data = f"../data/synthetic_v4_trophic_{hub}_filtered.pt"
(
    X_train_scaled,
    X_test_scaled,
    y_train_scaled,
    y_test_scaled,
    annotations,
) = load_filtered_data(p_data)
torch.isnan(X_test_scaled).any()

tensor(False)

In [7]:
# Train on the first entry of `hubs` only; widen the slice to cover more.
for hub in hubs[:1]:
    train_VAE_on_data(
        f"../data/synthetic_v4_trophic_{hub}_filtered.pt",
        f"../models/synthetic_v4_trophic_{hub}_trained_AE.pt",
        weights=[1.0, 1.0, 0.0],
    )

  8%|▊         | 8/100 [00:09<01:53,  1.24s/it]


KeyboardInterrupt: 