# `ABaCo` tutorial: Anaerobic Digestion benchmark dataset

In this notebook we apply ABaCo to correct batch effects in an Anaerobic Digestion (AD) dataset collected under two phenol concentration conditions.

-----
**Data Description:**
- The AD dataset, generated by [Chapleur et al. (2016)](https://doi.org/10.1007/s10532-015-9751-4) and distributed in the R package [PLSDAbatch](https://github.com/EvaYiwenWang/PLSDAbatch) by [Wang and Cao (2023)](https://doi.org/10.1093/bib/bbac622), contains 75 samples and 567 identified taxonomic groups.
- The samples were treated with two different phenol concentrations (groups `0-0.5` and `1-2`), which constitute the biological source of variation.
- The samples were processed on 5 different dates between April 2015 and September 2017, which constitute the technical source of variation (i.e., the batch effect).

**Goal:**

The aim of ABaCo here is to remove the technical variation (batch effect) while retaining the biological variation (phenol group effect).
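One way to check whether this goal is met is to compare how strongly samples cluster by batch versus by biological group. The notebook imports `ASW` from `abaco.BatchEffectMetrics`, but its signature is not shown here, so this minimal sketch uses scikit-learn's `silhouette_score` instead; the helper `batch_vs_bio_silhouette` and the toy data are hypothetical. After a successful correction, the silhouette by batch should move towards zero while the silhouette by biological group stays high.

```python
import numpy as np
from sklearn.metrics import silhouette_score


def batch_vs_bio_silhouette(features, batch_labels, bio_labels):
    """Hypothetical helper: return (silhouette by batch, silhouette by biological group)."""
    return (
        silhouette_score(features, batch_labels),
        silhouette_score(features, bio_labels),
    )


# Toy data: two well-separated biological groups, batch labels assigned at random
rng = np.random.default_rng(0)
bio = np.repeat([0, 1], 50)
batch = rng.integers(0, 5, size=100)
X = rng.normal(size=(100, 10)) + bio[:, None] * 5.0

s_batch, s_bio = batch_vs_bio_silhouette(X, batch, bio)
print(f"silhouette by batch: {s_batch:.2f}, by biological group: {s_bio:.2f}")
```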

-----

## Libraries

In [None]:
# Essentials
import numpy as np
import pandas as pd
import torch

from abaco.BatchEffectCorrection import correctCombat, correctLimma_rBE, correctBMC, correctPLSDAbatch_R, correctCombatSeq, correctConQuR
from abaco.BatchEffectPlots import plotPCA, plotPCoA, plot_LISI_perplexity
from abaco.BatchEffectMetrics import all_metrics, pairwise_distance, pairwise_distance_std, PERMANOVA, pairwise_distance_multi_run, kBET, ARI, ASW, iLISI_norm
from abaco.ABaCo import abaco_run, abaco_recon, contour_plot



## Dataset Requirements

The dataset should contain the following to be compatible with the ABaCo framework:

- The data must have 3 categorical columns:
    1. Sample ids: a unique identifier for each observation, e.g. `sample`.
    2. Batch labels: the grouping whose effect ABaCo should remove, e.g. the sample processing date `batch`.
    3. Biological labels: the experimental factor whose variation ABaCo should retain while correcting the batch effect, e.g. the phenol concentration group `trt`.

- Numeric feature columns (here, taxon counts) on which the model is trained.

`abaco` provides a `BatchEffectDataLoader.DataPreprocess()` function to convert a plain-text file (e.g., CSV or TSV) into a `pd.DataFrame` compatible with ABaCo, and a `one_hot_encoding()` function to encode categorical columns.
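The internals of `DataPreprocess()` and `one_hot_encoding()` are not shown in this tutorial. As an illustration only, this plain-pandas sketch builds a hypothetical toy table in the required layout, casts the three factor columns to `category`, checks the requirements listed above with a hypothetical helper `check_abaco_layout`, and one-hot encodes the batch column with `pd.get_dummies`.

```python
import numpy as np
import pandas as pd

# Hypothetical toy table: 3 factor columns + 2 numeric feature columns (counts)
toy = pd.DataFrame({
    "sample": ["s1", "s2", "s3", "s4"],
    "batch": ["09/04/2015", "09/04/2015", "14/04/2016", "14/04/2016"],
    "trt": ["0-0.5", "1-2", "0-0.5", "1-2"],
    "taxon_A": [10, 0, 3, 7],
    "taxon_B": [0, 5, 2, 1],
})

factors = ["sample", "batch", "trt"]
for col in factors:
    toy[col] = toy[col].astype("category")


def check_abaco_layout(df, factors):
    """Hypothetical helper: check the column requirements and return the number of features."""
    for col in factors:
        if not isinstance(df[col].dtype, pd.CategoricalDtype):
            raise TypeError(f"{col!r} must be categorical")
    if not df[factors[0]].is_unique:
        raise ValueError("sample ids must be unique")
    features = df.drop(columns=factors)
    non_numeric = [c for c in features if not pd.api.types.is_numeric_dtype(features[c])]
    if non_numeric:
        raise TypeError(f"non-numeric feature columns: {non_numeric}")
    return features.shape[1]


print(check_abaco_layout(toy, factors))  # 2

# One indicator column per batch, stored as float32
ohe_batch_toy = pd.get_dummies(toy["batch"]).astype(np.float32)
print(ohe_batch_toy.shape)  # (4, 2)
```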

In [103]:
from abaco.BatchEffectDataLoader import DataPreprocess, one_hot_encoding

# Load AD count
path_to_dataset = "data/dataset_ad.csv"
ad_count_batch_label = "batch"
ad_count_sample_label = "sample"
ad_count_bio_label = "trt"

# Convert data path into compatible pd.DataFrame with ABaCo framework
ad_count_data = DataPreprocess(
    path_to_dataset,
    factors=[
        ad_count_sample_label, 
        ad_count_batch_label, 
        ad_count_bio_label
    ]
)

# Check the layout: 3 categorical columns and 567 numeric (count) columns
ad_count_data.info()

# Number of samples per batch and phenol group
print(ad_count_data.groupby([ad_count_batch_label, ad_count_bio_label]).size())

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 75 entries, 0 to 74
Columns: 570 entries, sample to Cluster_17710
dtypes: category(3), int64(567)
memory usage: 335.5 KB


batch       trt  
01/07/2016  0-0.5     8
            1-2      13
09/04/2015  0-0.5     4
            1-2       5
14/04/2016  0-0.5     4
            1-2      12
14/11/2016  0-0.5     8
            1-2       9
21/09/2017  0-0.5     2
            1-2      10
dtype: int64


`abaco` also provides plotting functions to visualize the batch and biological effects in the data. Here we use `abaco.BatchEffectPlots.plotPCoA()`.
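`plotPCoA()` does not expose its distance metric in this notebook. Assuming a typical microbiome workflow, a PCoA is classical multidimensional scaling of a Bray-Curtis dissimilarity matrix; this NumPy sketch shows that generic technique on hypothetical Poisson counts and is not `abaco`'s implementation.

```python
import numpy as np


def bray_curtis(counts):
    """Pairwise Bray-Curtis dissimilarity between the rows of a count matrix."""
    counts = np.asarray(counts, dtype=float)
    num = np.abs(counts[:, None, :] - counts[None, :, :]).sum(axis=-1)
    den = (counts[:, None, :] + counts[None, :, :]).sum(axis=-1)
    return np.divide(num, den, out=np.zeros_like(num), where=den > 0)


def pcoa(dist, n_components=2):
    """Classical multidimensional scaling (PCoA) of a symmetric distance matrix."""
    n = dist.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ (dist ** 2) @ centering  # double-centred squared distances
    eigvals, eigvecs = np.linalg.eigh(gram)            # eigenvalues in ascending order
    order = np.argsort(eigvals)[::-1]                  # sort descending
    eigvals, eigvecs = eigvals[order], eigvecs[:, order]
    top = eigvals[:n_components]
    coords = eigvecs[:, :n_components] * np.sqrt(np.clip(top, 0, None))
    explained = top / eigvals[eigvals > 0].sum()       # share of the positive eigenvalues
    return coords, explained


# Hypothetical toy counts: 6 samples x 20 taxa
rng = np.random.default_rng(1)
counts = rng.poisson(5, size=(6, 20))
coords, explained = pcoa(bray_curtis(counts))
print(coords.shape, explained.round(2))
```

The first two columns of `coords` are the axes a PCoA plot would show, and `explained` gives the share of variation each axis carries.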

In [None]:
# PCoA of the raw (uncorrected) counts, annotated by batch and phenol group
plotPCoA(
    data=ad_count_data,
    sample_label=ad_count_sample_label,
    batch_label=ad_count_batch_label,
    experiment_label=ad_count_bio_label,
)

In [41]:
ad_count_data[ad_count_batch_label].unique()

['09/04/2015', '14/04/2016', '01/07/2016', '14/11/2016', '21/09/2017']
Categories (5, object): ['01/07/2016', '09/04/2015', '14/04/2016', '14/11/2016', '21/09/2017']

In [None]:
from sklearn.preprocessing import OneHotEncoder

In [99]:
# One-hot encode the batch labels (5 processing dates -> 5 columns).
# drop=None keeps every category, so the encoding width matches n_batches in abaco_run.
batch_encoder = OneHotEncoder(sparse_output=False, dtype=np.float32, drop=None)
ohe_batch = torch.tensor(batch_encoder.fit_transform(ad_count_data[[ad_count_batch_label]]))
ohe_batch_classes = batch_encoder.categories_

# One-hot encode the biological labels (2 phenol groups -> 2 columns, matching n_bios).
bio_encoder = OneHotEncoder(sparse_output=False, dtype=np.float32, drop=None)
ohe_bio = torch.tensor(bio_encoder.fit_transform(ad_count_data[[ad_count_bio_label]]))
ohe_bio_classes = bio_encoder.categories_


In [122]:
# Category order of each encoder: the one-hot columns follow this order
ohe_bio_classes + ohe_batch_classes

[array(['0-0.5', '1-2'], dtype=object),
 array(['01/07/2016', '09/04/2015', '14/04/2016', '14/11/2016',
        '21/09/2017'], dtype=object)]

## Data Preparation for PyTorch

ABaCo is implemented with [PyTorch](https://docs.pytorch.org/docs/stable/index.html).

Following the typical PyTorch [workflow](https://docs.pytorch.org/tutorials/beginner/basics/quickstart_tutorial.html), we use the [`torch.utils.data.DataLoader` class](https://docs.pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader) to iterate over the dataset in minibatches.
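A minimal sketch of how `TensorDataset` and `DataLoader` pair the feature matrix with the one-hot labels, using hypothetical toy tensors (10 samples, 4 features, 3 batches, 2 biological groups) instead of the AD data. Each iteration yields aligned row slices of the three tensors; when `batch_size` exceeds the number of samples, as with `batch_size=1000` for the 75 AD samples, a single minibatch holds the whole dataset.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical toy tensors with the same layout as the AD inputs
torch.manual_seed(0)
features = torch.randint(0, 20, (10, 4)).float()
batch_ids = torch.arange(10) % 3
bio_ids = torch.arange(10) % 2
ohe_batch_toy = torch.nn.functional.one_hot(batch_ids, num_classes=3).float()
ohe_bio_toy = torch.nn.functional.one_hot(bio_ids, num_classes=2).float()

loader = DataLoader(TensorDataset(features, ohe_batch_toy, ohe_bio_toy), batch_size=4)

for x, b, y in loader:
    # Aligned slices: features, batch one-hot, biological one-hot
    print(x.shape, b.shape, y.shape)
```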

In [101]:
from torch.utils.data import DataLoader, TensorDataset

# Numeric feature columns (taxon counts) as a float tensor: [n_samples, n_features]
features_tensor = torch.tensor(
    ad_count_data.select_dtypes(include="number").values, dtype=torch.float32
)

# Sizes passed to abaco_run
ad_count_input_size = features_tensor.shape[1]  # 567 taxonomic groups
ad_count_batch_size = ohe_batch.shape[1]  # 5 processing dates (batches)
ad_count_bio_size = ohe_bio.shape[1]  # 2 phenol concentration groups

# DataLoader yielding (features, batch one-hot, biological one-hot).
# batch_size=1000 exceeds the 75 samples, so each iteration returns the whole dataset;
# shuffle defaults to False, so rows keep the order of ad_count_data.
ad_count_train_dataloader = DataLoader(
    TensorDataset(features_tensor, ohe_batch, ohe_bio),
    batch_size=1000,
)

In [91]:
ohe_batch.shape, ohe_bio.shape, features_tensor.shape

(torch.Size([75, 5]), torch.Size([75, 2]), torch.Size([75, 567]))

In [55]:
ad_count_train_dataloader.dataset.tensors[0].shape

torch.Size([75, 567])

## ABaCo run

Running ABaCo requires setting several parameters. The values below were chosen for this dataset but can be tuned; each one is briefly explained in the comments of the `abaco_run()` call.
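To illustrate what `beta` scales, this sketch computes a generic beta-VAE objective with a standard normal prior, where the KL term has a closed form. It assumes that the `"Std"` prior option corresponds to a standard normal; the default `"VMM"` prior has no such closed-form KL, and ABaCo's full loss also includes contrastive and adversarial terms that are not shown. The function names are hypothetical.

```python
import numpy as np


def gaussian_kl_to_standard_normal(mu, logvar):
    """KL(N(mu, diag(exp(logvar))) || N(0, I)), summed over latent dimensions."""
    return 0.5 * np.sum(np.exp(logvar) + mu ** 2 - 1.0 - logvar, axis=-1)


def beta_weighted_loss(recon_nll, mu, logvar, beta):
    """Generic beta-VAE loss per sample: reconstruction term + beta * KL term."""
    return recon_nll + beta * gaussian_kl_to_standard_normal(mu, logvar)


# One hypothetical sample with a 2-dimensional latent code
mu = np.array([[0.5, -0.5]])
logvar = np.zeros((1, 2))
print(gaussian_kl_to_standard_normal(mu, logvar))  # [0.25]
loss = beta_weighted_loss(np.array([10.0]), mu, logvar, beta=20.0)
print(loss)  # [15.]
```

With `beta=20.0`, as in the run below, a small KL of 0.25 already adds 5 to the loss, which is why larger `beta` values pull the latent codes more strongly towards the prior.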

In [102]:
# Set up the device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set the seed for reproducibility
seed = 42

ad_count_abaco_model = abaco_run(
    dataloader=ad_count_train_dataloader,
    n_batches=ad_count_batch_size,
    n_bios=ad_count_bio_size,
    input_size=ad_count_input_size,
    device=device,
    prior="VMM",  # Prior distribution: "VMM" (VampPrior Mixture Model, baseline), "MoG" or "Std"
    w_contra=100.0,  # Contrastive loss weight: higher values separate the biological groups more in the latent space
    beta=20.0,  # KL-divergence coefficient: higher values penalize deviation from the prior more strongly
    kl_cycle=True,  # Regularization term used during batch correction (ABaCo second phase)
    seed=seed,
    smooth_annealing=True,  # Gradually mask batch labels during the second phase to avoid exploding gradients
    pre_epochs=2500,  # First phase epochs: data reconstruction
    post_epochs=5000,  # Second phase epochs: batch correction
    vae_pre_lr=1e-3,  # VAE learning rate during the first phase
    vae_post_lr=2e-4,  # VAE learning rate during the second phase
    adv_lr=1e-7,  # Adversarial learning rate for the encoder (removes batch effect from the latent space)
    disc_lr=1e-7,  # Learning rate of the batch discriminator
    new_pre_train=False,  # Additional configuration flag, left at False here
    w_cycle=1e-3,  # Higher values can make the second phase less stable
)


Pre-training: VAE for reconstructing data and batch mixing adversarial training: 100%|██████████| 2500/2500 [01:28<00:00, 28.17it/s, adv=-1.6168, contra=368.8553, disc=1.6205, elbo=524.2938, epoch=2499/2501]
Training: VAE decoder with masked batch labels: 100%|██████████| 5000/5000 [23:12<00:00,  3.59it/s, epoch=5000/5000, vae_loss=617.2891]


## ABaCo data reconstruction

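Once the ABaCo model is trained, `abaco_recon()` reconstructs the batch-corrected data. Reconstruction can also be run with Monte Carlo sampling (`monte_carlo` greater than 1), which is not recommended; `monte_carlo=1` draws a single sample from the ZINB distribution of the trained model.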
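According to the comments in the `abaco_recon()` call, the trained model yields a zero-inflated negative binomial (ZINB) distribution from which the corrected counts are sampled. This NumPy sketch contrasts a single draw (`monte_carlo=1`) with averaging many draws, one plausible way of combining Monte Carlo samples; the parameterization (mean `mu`, inverse dispersion `theta`, excess-zero probability `pi`) and the helper names are assumptions, not the internals of `abaco_recon()`. A single draw stays integer-valued, while the average is non-integer and concentrates around the expected value `(1 - pi) * mu`.

```python
import numpy as np


def sample_zinb(mu, theta, pi, rng):
    """Hypothetical helper: draw counts from a zero-inflated negative binomial."""
    mu, theta, pi = np.broadcast_arrays(mu, theta, pi)
    p = theta / (theta + mu)               # NumPy's NB success probability, giving mean mu
    nb = rng.negative_binomial(theta, p)   # negative binomial counts
    dropout = rng.random(mu.shape) < pi    # structural (excess) zeros
    return np.where(dropout, 0, nb)


def reconstruct(mu, theta, pi, monte_carlo=1, seed=42):
    """Hypothetical helper: average `monte_carlo` ZINB draws (1 means a single draw)."""
    rng = np.random.default_rng(seed)
    draws = [sample_zinb(mu, theta, pi, rng) for _ in range(monte_carlo)]
    return np.mean(draws, axis=0)


mu = np.full((3, 4), 8.0)
single = reconstruct(mu, theta=2.0, pi=0.3, monte_carlo=1)
averaged = reconstruct(mu, theta=2.0, pi=0.3, monte_carlo=500)
print(single)    # integer-valued counts, possibly with zeros
print(averaged)  # non-integer values near (1 - 0.3) * 8 = 5.6
```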

In [104]:
ad_count_corrected_data = abaco_recon(
    model=ad_count_abaco_model,
    device=device,
    data=ad_count_data,
    dataloader=ad_count_train_dataloader,
    sample_label=ad_count_sample_label,
    batch_label=ad_count_batch_label,
    bio_label=ad_count_bio_label,
    seed=seed,
    monte_carlo=1,  # A single draw from the final ZINB distribution of the trained model
)

# PCoA of the batch-corrected data, annotated by batch and phenol group
plotPCoA(
    data=ad_count_corrected_data,
    sample_label=ad_count_sample_label,
    batch_label=ad_count_batch_label,
    experiment_label=ad_count_bio_label,
)