# Tutorial Notebook: Functionality of the Boosting Autoencoder (BAE) with the disentanglement constraint and the soft clustering component

**Functionality tutorial for analyzing feature patterns of single-cell data with the [Boosting Autoencoder (BAE)](https://github.com/NiklasBrunn/BoostingAutoencoder).** 


### The notebook is divided into four main steps:

- [Setup](#Setup)
- [Load the gene expression data](#Load-the-gene-expression-data)
- [Pattern analysis with the BAE](#Pattern-analysis-with-the-BAE)
- [Result visualization and plots saving](#Result-visualization-and-plots-saving)

## Setup:
First, you can activate the Julia environment and load all the packages needed to run the BAE functionality tutorial notebook. The first time you run the following cell, all required packages will be downloaded and precompiled, which may take a moment.

In [None]:
#---Activate the project environment and install its dependencies:
using Pkg

Pkg.activate("../")    # use the environment defined in the parent folder
Pkg.instantiate()      # install every package the environment requires
Pkg.status()           # list the packages of the active environment

#---Define the project, data, and figure directories used by this notebook:
projectpath = joinpath(@__DIR__, "../");
datapath = projectpath * "data/simData/";
figurespath = projectpath * "figures/simData/";

# `mkpath` also creates missing parent folders (e.g. `data/` or `figures/`) and
# does nothing when the folder already exists, whereas `mkdir` throws an error
# if the parent folder is missing.
mkpath(datapath);
mkpath(figurespath);

#---Load the BoostingAutoEncoder module:
include(projectpath * "/src/BAE.jl");
using .BoostingAutoEncoder

#---Load required packages for this notebook:
using Plots;
using Random;
using StatsBase;
using VegaLite;  
using DataFrames;
using StatsPlots;
using CSV;

## Load the gene expression data:

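In this tutorial the gene expression data are simulated: a binary expression matrix with 1000 cells and 200 genes is generated with `sim_scRNAseqData` and then scaled with `scale`. Cell-group labels and gene names are stored in a `MetaData` object, and a heatmap of the first 100 genes of the binary data is saved to the figures folder.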

In [None]:
#---Simulate the data:
# Settings forwarded to `sim_scRNAseqData` (see its documentation for their exact meaning).
dataseed, n_cells, n_genes = 1, 1000, 200;
n_overlap, stageno = 0, 10;
blockprob, noiseprob = 0.6, 0.1;

X = sim_scRNAseqData(
    dataseed;
    n=n_cells,
    num_genes=n_genes,
    stageno=stageno,
    # `Int(...)` throws an `InexactError` when the division leaves a remainder,
    # so both 50 and `n_cells` have to be multiples of `stageno`.
    stagep=Int(50 / stageno),
    stagen=Int(n_cells / stageno),
    stageoverlap=n_overlap,
    blockprob=blockprob,
    noiseprob=noiseprob,
);

#---Apply `scale` to the simulated data:
X_st = scale(X);
#Optional: Rescale noise genes of X_st, because values of noise genes are more extreme after scaling ...
#X_st[:, 50:end] .*= 0.8f0

#---Keep the number of rows (n) and columns (p) of X for later use:
n = size(X, 1);
p = size(X, 2);

#---Create meta data:
MD = MetaData();
# Label every cell with its group. The labels are derived from the simulation
# settings (`stageno` groups of `n_cells / stageno` consecutive cells) instead of
# the literals 10 and 100, so they stay in sync when the settings are changed.
# NOTE(review): assumes `sim_scRNAseqData` returns the cells ordered by stage — confirm.
MD.obs_df[!, :CellGroup] = repeat(1:stageno, inner=Int(n_cells / stageno));
# Features are simply named by their column index in X.
MD.featurename = ["$(j)" for j in 1:p];

#---Heatmap of the binary data, restricted to the first 100 genes (columns) of X:
vegaheatmap(
    X[:, 1:100];
    Title="Binary Gene expression ($(n_genes-100) noise genes omitted)",
    xlabel="Gene",
    ylabel="Cell",
    legend_title="Value",
    color_field="value:o",
    scheme="paired",
    domain_mid=nothing,
    Width=500,
    Height=500,
    save_plot=true,
    path=figurespath * "binary_data.png",
)

## Pattern analysis with the BAE:

In [None]:
#---Hyperparameters for training a BAE (see `Hyperparameters` for their meaning):
HP = Hyperparameters(;
    zdim=6,
    n_runs=2,
    max_iter=100,
    tol=1e-6,
    batchsize=2^9,
    η=0.01,
    λ=0.1,
    ϵ=0.03,
    M=1,
);

#---Decoder for p input features, built with `soft_clustering=true`:
decoder = generate_BAEdecoder(p, HP; soft_clustering=true);

#---BAE model whose encoder weights start as an all-zero p × zdim matrix:
BAE = BoostingAutoencoder(;
    HP=HP,
    decoder=decoder,
    coeffs=zeros(eltype(X_st), p, HP.zdim),
);
summary(BAE)

In [None]:
#---Fit the BAE on the scaled data and report the elapsed time via `@time`:
@time output_dict = train_BAE!(
    X_st, BAE; MD=MD, track_coeffs=true, save_data=true, data_path=datapath
);

In [None]:
#---Line plot of the values stored under "trainloss" in the training output:
mean_trainlossPerEpoch = output_dict["trainloss"];
loss_plot = plot(
    eachindex(mean_trainlossPerEpoch),
    mean_trainlossPerEpoch;
    title="Mean train loss per epoch",
    xlabel="Epoch",
    ylabel="Loss",
    label="Train loss",
    legend=true,
    linecolor=:red,
    linewidth=2,
);
# Returning the plot object as the last expression displays it in the notebook.
loss_plot

## Result visualization and plots saving:

In [None]:
#---Heatmap of the transposed encoder weight matrix:
vegaheatmap(
    adjoint(BAE.coeffs);
    Title="Encoder weights",
    xlabel="Gene",
    ylabel="Latent dimension",
    legend_title="Value",
    save_plot=true,
    path=figurespath * "encoderWeights_BAE.png",
)

In [None]:
#---Heatmap of the transposed latent representation `BAE.Z` of the cells:
vegaheatmap(
    adjoint(BAE.Z);
    Title="Latent representation",
    xlabel="Latent dimension",
    ylabel="Cell",
    legend_title="Activation",
    save_plot=true,
    path=figurespath * "latentRepresentation_BAE.png",
)

In [None]:
#---Heatmap of the absolute Pearson correlations between the latent dimensions
#   (`dims=2` correlates the rows of `BAE.Z` with each other):
vegaheatmap(
    abs.(cor(BAE.Z; dims=2));
    Title="Absolute correlations of latent dimensions",
    xlabel="Latent dimension",
    ylabel="Latent dimension",
    legend_title="Value",
    scheme="inferno",
    domain_mid=nothing,
    Width=500,
    Height=500,
    save_plot=true,
    path=figurespath * "cor_latentDimensions_BAE.png",
)

In [None]:
#---Heatmap of the probabilities with which each cell belongs to the different clusters.
#   Each latent dimension corresponds to two subsequent clusters, reflecting its
#   positive and negative activations.
vegaheatmap(
    adjoint(BAE.Z_cluster);
    Title="Cluster probabilities of cells",
    xlabel="Cluster",
    ylabel="Cell",
    legend_title="Probability",
    scheme="purpleblue",
    domain_mid=nothing,
    save_plot=true,
    path=figurespath * "clusterProbabilities_BAE.png",
)

### ToDO:

In [None]:
#---Save one scatter plot of the top selected genes for every latent dimension:
let outdir = figurespath * "/TopFeaturesLatentDim"
    # Make sure the output folder exists before writing into it.
    isdir(outdir) || mkdir(outdir)
    for dim in 1:BAE.HP.zdim
        dim_plot = normalizedFeatures_scatterplot(BAE.coeffs[:, dim], MD.featurename, dim; top_n=20)
        savefig(dim_plot, outdir * "/" * "BAE_dim$(dim)_topGenes.png")
    end
end

#---Save one scatter plot of the top selected genes for every cluster:
let outdir = figurespath * "/TopFeaturesCluster"
    # Make sure the output folder exists before writing into it.
    isdir(outdir) || mkdir(outdir)
    for key in keys(MD.Top_features)
        cluster_features = MD.Top_features[key]
        # Clusters without any score entries are skipped.
        length(cluster_features.Scores) > 0 || continue
        cluster_plot = TopFeaturesPerCluster_scatterplot(cluster_features, key; top_n=10)
        savefig(cluster_plot, outdir * "/" * "BAE_Cluster$(key)_topGenes.svg")
    end
end

#---Create coefficient plots for visually inspecting the coefficient update trajectories
#   for the last run of the training (only available if they were tracked):
if !haskey(output_dict, "coefficients")
    @warn "No coefficient trajectories were saved during training."
else
    let outdir = figurespath * "/CoefficientsPlots"
        # Make sure the output folder exists before writing into it.
        isdir(outdir) || mkdir(outdir)
        for dim in 1:BAE.HP.zdim
            trajectory_plot = track_coefficients(output_dict["coefficients"], dim; iters=nothing, xscale=:log10)
            savefig(trajectory_plot, outdir * "/CoefficientsPlot_BAE_dim$(dim).png")
        end
    end
end

#---Collect the top selected genes and print them for every latent dimension:
sel_genes = get_topFeatures(BAE, MD)
foreach(1:BAE.HP.zdim) do l
    println("Top selected genes for latent variable $(l): ", sel_genes[l])
end

### ToDO:

<img src="../figures/functionality_notebook/BAE_dim2_topGenes.svg" width="500">

### ToDO:

<img src="../figures/functionality_notebook/BAE_Cluster3_topGenes.svg" width="500">

### ToDO:

<img src="../figures/functionality_notebook/BAE_Cluster4_topGenes.svg" width="500">

### ToDo:

<img src="../figures/functionality_notebook/CoefficientsPlot_BAE_dim1.svg" width="500">

In [None]:
# Read the CSV table of top selected genes for one example cluster into a DataFrame.
# The file is presumably written during training (`save_data=true`) — confirm.
# NOTE(review): the variable is named `..._cluster3` although `cluster = 4` is loaded —
# confirm which cluster is intended.
cluster = 4;
top_features_cluster3 = CSV.read(
    string(datapath, "TopFeaturesCluster_CSV/topFeatures_Cluster_$(cluster).csv"), DataFrame
)