In [None]:
#------------------------------
# Setup:
#------------------------------
using Pkg;

#---Resolve the project root relative to this file's directory (not the current
#   working directory), so environment activation and all paths below agree:
projectpath = joinpath(@__DIR__, "../");

#---Activate the project environment:
Pkg.activate(projectpath);
Pkg.instantiate();
Pkg.status()

#---Set up the output folder for the simulated-data figures:
# NOTE: later cells append file names directly (`figurespath * "file.svg"`),
# so this path must keep its trailing "/".
figurespath = projectpath * "figures/simData/"
# `mkpath` also creates a missing parent "figures/" folder and is a no-op when
# the folder already exists (plain `mkdir` errors if the parent is missing).
mkpath(figurespath)

#---Load the BoostingAutoEncoder module:
include(joinpath(projectpath, "src", "BAE.jl"));
using .BoostingAutoEncoder

#---Load required packages for this notebook:
using Plots;
using Random;
using StatsBase;
using VegaLite;
using DataFrames;

# Train BAE on simulated scRNA-seq data

In [None]:
#---Generate and transform data:
dataseed = 1;
n_cells = 1000;      # number of simulated cells
n_genes = 200;       # total number of simulated genes (`num_genes`)
n_stages = 10;       # number of simulated stages; also the number of cell groups below
n_stagegenes = 50;   # split evenly across stages via `stagep` (presumably the
                     # stage-specific genes — TODO confirm in sim_scRNAseqData)

# The per-stage sizes are derived from the totals so that the cell-group labels
# below always match the simulation. Fail loudly on an uneven split instead of
# silently producing mismatched labels.
n_cells % n_stages == 0 ||
    throw(ArgumentError("n_cells ($n_cells) must be divisible by n_stages ($n_stages)"))
n_stagegenes % n_stages == 0 ||
    throw(ArgumentError("n_stagegenes ($n_stagegenes) must be divisible by n_stages ($n_stages)"))
cells_per_stage = n_cells ÷ n_stages;

X = sim_scRNAseqData(dataseed; n=n_cells, stageoverlap=0,
    blockprob = 0.6, noiseprob=0.1, num_genes=n_genes,
    stageno = n_stages, stagep = n_stagegenes ÷ n_stages, stagen = cells_per_stage
);
X_st = scale(X);
n, p = size(X);   # n cells (rows) × p genes (columns)

#---Create meta data:
MD = MetaData();
# Label cells in consecutive blocks of `cells_per_stage` (1,1,…,2,2,…); this assumes
# sim_scRNAseqData emits cells stage by stage, as the original labeling did.
MD.obs_df[!, :CellGroup] = repeat(1:n_stages, inner=cells_per_stage)
MD.featurename = string.(1:p);

#---Plot the binary data (only the first `n_shown` genes; the rest are omitted):
n_shown = 100;
vegaheatmap(X[:, 1:n_shown];
    path=figurespath * "binary_data.svg",
    Title="Binary Gene expression ($(n_genes - n_shown) noise genes omitted)",
    xlabel="Gene",
    ylabel="Cell",
    legend_title="Value",
    color_field="value:o",
    domain_mid=nothing,
    scheme="paired",
    save_plot=true,
    Width=500,
    Height=500
)

In [None]:
#---Hyperparameters for BAE training (zdim = number of latent dimensions):
HP = Hyperparameter(
    zdim=6,
    epochs=20000,
    batchsize=2^9,
    η=0.01f0,
    λ=0.1f0,
    ϵ=0.001f0,
    M=1,
);

#---Build the decoder after fixing the global RNG seed, for reproducibility:
modelseed = 42;
Random.seed!(modelseed)
decoder = generate_BAEdecoder(p, HP; soft_clustering=true);

#---Assemble the BAE; the encoder coefficients start at zero
#   (p genes × zdim latent dimensions):
BAE = BoostingAutoencoder(;
    coeffs=zeros(Float32, p, HP.zdim),
    decoder=decoder,
    HP=HP,
);
summary(BAE)

In [None]:
#---Fit the BAE on the standardized data (global RNG seeded for reproducibility).
#   The returned vector holds the mean training loss of every epoch.
batchseed = 42;
Random.seed!(batchseed)
mean_trainlossPerEpoch = train_BAE!(X_st, BAE;
    soft_clustering=true,
    MD=MD,
    save_data=false,
    path=nothing,
);

#---Export the top selected genes of each cluster into the figures folder:
TopGenes_Cluster = topFeatures_per_Cluster(BAE, MD; path=figurespath);

#---Line plot of the mean training loss against the epoch index, saved as SVG:
loss_plot = plot(
    1:length(mean_trainlossPerEpoch),
    mean_trainlossPerEpoch;
    title = "Mean train loss per epoch",
    xlabel = "Epoch",
    ylabel = "Loss",
    legend = true,
    label = "Train loss",
    linecolor = :red,
    linewidth = 2,
);
savefig(loss_plot, figurespath * "/SimData_loss_BAE.svg");
loss_plot

## Visualize the Results

In [None]:
#---Heatmap of the learned encoder weights. `BAE.coeffs` is p × zdim, so the
#   transpose shows one row per latent dimension and one column per gene:
vegaheatmap(
    BAE.coeffs';
    path=figurespath * "encoderWeights_BAE.svg",
    Title="Encoder weights",
    xlabel="Gene",
    ylabel="Latent dimension",
    legend_title="Value",
    save_plot=true,
)

In [None]:
#---Heatmap of the cells' latent representation. `BAE.Z` is transposed so that
#   cells run along the y-axis and latent dimensions along the x-axis
#   (assumes Z is zdim × cells — consistent with the correlation cells below):
vegaheatmap(
    BAE.Z';
    path=figurespath * "latentRepresentation_BAE.svg",
    Title="Latent representation",
    xlabel="Latent dimension",
    ylabel="Cell",
    legend_title="Activation",
    save_plot=true,
)

In [None]:
#---Heatmap of each cell's cluster-membership probabilities, in the original cell
#   order. Every latent dimension yields two consecutive clusters (one for positive
#   and one for negative activations). A sequential scheme without a midpoint is
#   used because probabilities are non-negative.
vegaheatmap(
    BAE.Z_cluster';
    path=figurespath * "clusterProbabilities_BAE.svg",
    Title="Cluster probabilities of cells",
    xlabel="Cluster",
    ylabel="Cell",
    legend_title="Probability",
    scheme="purpleblue",
    domain_mid=nothing,
    save_plot=true,
)

In [None]:
#---Heatmap of |Pearson correlation| between pairs of latent dimensions.
#   `cor(…, dims=2)` correlates the rows of `BAE.Z`, i.e. the latent dimensions:
latentdim_abscor = abs.(cor(BAE.Z, dims=2));
vegaheatmap(
    latentdim_abscor;
    path=figurespath * "cor_latentDimensions_BAE.svg",
    Title="Absolute correlations of latent dimensions",
    xlabel="Latent dimension",
    ylabel="Latent dimension",
    legend_title="Value",
    scheme="orangered",
    domain_mid=nothing,
    save_plot=true,
    Width=500,
    Height=500,
)

In [None]:
#---Heatmap of |Pearson correlation| between pairs of CELLS, computed on the binary
#   data matrix `X` (cells × genes). `cor(…, dims=2)` correlates the rows, i.e. the cells:
cell_abscor = abs.(cor(X, dims=2));
vegaheatmap(
    cell_abscor;
    path=figurespath * "cor_cells.svg",
    Title="Absolute correlations of cells",
    xlabel="Cell",
    ylabel="Cell",
    legend_title="Value",
    scheme="orangered",
    domain_mid=nothing,
    save_plot=true,
    Width=500,
    Height=500,
)

In [None]:
#---Heatmap of |Pearson correlation| between pairs of CELLS, computed on their latent
#   representation. `cor(…, dims=1)` correlates the columns of `BAE.Z`, i.e. the cells:
cell_latent_abscor = abs.(cor(BAE.Z, dims=1));
vegaheatmap(
    cell_latent_abscor;
    path=figurespath * "cor_cellslatentRepresentations_BAE.svg",
    Title="Absolute correlations of cell Representations",
    xlabel="Cell",
    ylabel="Cell",
    legend_title="Value",
    scheme="orangered",
    domain_mid=nothing,
    save_plot=true,
    Width=500,
    Height=500,
)

In [None]:
#---Plot |Pearson correlation| between pairs of CELLS, computed on their cluster-
#   probability representation. `cor(…, dims=1)` correlates the columns of
#   `BAE.Z_cluster`, i.e. the cells. The title names the cluster representation so
#   this figure is distinguishable from the latent-representation version.
vegaheatmap(abs.(cor(BAE.Z_cluster, dims=1));
    path=figurespath * "cor_cellsRepresentations(Cluster)_BAE.svg",
    Title="Absolute correlations of cell cluster representations",
    xlabel="Cell",
    ylabel="Cell",
    legend_title="Value",
    scheme="orangered",
    domain_mid=nothing,
    save_plot=true,
    Width=500,
    Height=500
)

In [None]:
#---One scatter plot per latent dimension, showing the top 10 genes of that
#   dimension's encoder-weight column:
topfeatures_dir = figurespath * "/TopFeaturesLatentDim"
isdir(topfeatures_dir) || mkdir(topfeatures_dir)   # create the output folder on first run
for latent_dim in 1:BAE.HP.zdim
    dim_plot = normalizedFeatures_scatterplot(BAE.coeffs[:, latent_dim], MD.featurename, latent_dim; top_n=10)
    savefig(dim_plot, topfeatures_dir * "/" * "BAE_dim$(latent_dim)_topGenes.svg")
end

In [None]:
#---One scatter plot per cluster, showing its top 10 selected genes:
topcluster_dir = figurespath * "/TopFeaturesCluster"
isdir(topcluster_dir) || mkdir(topcluster_dir)   # create the output folder on first run
for (cluster_key, cluster_genes) in pairs(TopGenes_Cluster)
    cluster_plot = TopFeaturesPerCluster_scatterplot(cluster_genes, cluster_key; top_n=10)
    savefig(cluster_plot, topcluster_dir * "/" * "BAE_Cluster$(cluster_key)_topGenes.svg")
end

In [None]:
#----Compute a 2D UMAP embedding of the learned BAE latent representation and add it to the metadata:
plotseed = 7;
BAE.UMAP = generate_umap(BAE.Z', plotseed);
MD.obs_df[!, :UMAP1] = BAE.UMAP[:, 1];
MD.obs_df[!, :UMAP2] = BAE.UMAP[:, 2];

#---Randomly shuffle the observation indices for plotting (presumably so that no
#   cell group is systematically drawn on top of the others). The global RNG is
#   seeded first, like every other random step in this notebook, so the
#   permutation and therefore the figures are reproducible.
# NOTE(review): this permanently reorders `MD.obs_df`. Later cells index the BAE
# matrices with `rand_inds` to stay aligned with it. Re-running only this cell
# would write the UMAP columns into an already-shuffled table and misalign them;
# re-run from the training cell instead.
Random.seed!(plotseed)
rand_inds = shuffle(1:size(X_st, 1));
MD.obs_df = MD.obs_df[rand_inds, :];

In [None]:
#---Plot a heatmap of the cluster probabilities of cells, sorted by assigned cluster.
#   It is saved under its own file name: "clusterProbabilities_BAE.svg" already holds
#   the unsorted heatmap from the earlier cell and would otherwise be overwritten.
# Rows = cells in the shuffled order of `MD.obs_df`; columns x1, x2, … = clusters.
Cluster_df = DataFrame(BAE.Z_cluster[:, rand_inds]', :auto);
# `:Cluster` is presumably added to `MD.obs_df` by `train_BAE!` (MD passed there).
Cluster_df[!, :Cluster] = copy(MD.obs_df.Cluster);
sort!(Cluster_df, :Cluster);
vegaheatmap(Matrix(Cluster_df[:, 1:end-1]);   # drop the trailing :Cluster label column
    path=figurespath * "clusterProbabilities_sortedByCluster_BAE.svg",
    Title="Cluster probabilities of cells (sorted by cluster)",
    xlabel="Cluster",
    ylabel="Cell",
    legend_title="Probability",
    scheme="purpleblue",
    domain_mid=nothing,
    save_plot=true
)

In [None]:
#---UMAP scatter plot of the BAE latent representation, colored by the simulated
#   cell-group labels (ordinal categories):
celltype_umap_coords = Matrix(MD.obs_df[:, [:UMAP1, :UMAP2]]);
vegascatterplot(
    celltype_umap_coords,
    MD.obs_df.CellGroup;
    path=figurespath * "Celltype_(BAE)umap.svg",
    legend_title="Cell type",
    color_field="labels:o",
    scheme="category20",
    domain_mid=nothing,
    range=nothing,
    save_plot=true,
    marker_size="25",
)

In [None]:
#---UMAP scatter plot of the BAE latent representation, colored by the learned
#   cluster labels (ordinal categories):
cluster_umap_coords = Matrix(MD.obs_df[:, [:UMAP1, :UMAP2]]);
vegascatterplot(
    cluster_umap_coords,
    MD.obs_df.Cluster;
    path=figurespath * "Cluster_(BAE)umap.svg",
    legend_title="Cluster",
    color_field="labels:o",
    scheme="category20",
    domain_mid=nothing,
    range=nothing,
    save_plot=true,
    marker_size="25",
)

In [None]:
#---Create scatter plots of the UMAP embedding of the learned BAE latent representation,
#   one per latent dimension, colored by the cells' activation in that dimension
#   (diverging color scheme centered at 0):
if !isdir(figurespath * "/UMAPplotsLatDims")
    # Create the folder if it does not exist
    mkdir(figurespath * "/UMAPplotsLatDims")
end
# Columns of `BAE.Z` are re-indexed with `rand_inds` to match the shuffled rows of `MD.obs_df`.
create_colored_vegascatterplots(Matrix(MD.obs_df[:, [:UMAP1, :UMAP2]]), BAE.Z[:, rand_inds];
    path=figurespath * "/UMAPplotsLatDims/",
    # "SimData_" prefix: this notebook uses simulated data (matches SimData_loss_BAE.svg).
    filename="SimData_BAE_dim",
    filetype="scatter.svg",
    legend_title="Activation",
    color_field="labels:q",
    scheme="blueorange",
    domain_mid=0,
    range=nothing,
    save_plot=true,
    marker_size="25"
)

In [None]:
#---Create scatter plots of the UMAP embedding of the learned BAE latent representation,
#   one per cluster, colored by each cell's probability for that cluster:
if !isdir(figurespath * "/UMAPplotsCluster")
    # Create the folder if it does not exist
    mkdir(figurespath * "/UMAPplotsCluster")
end
# White-to-dark-red sequential palette for the (non-negative) probability scale.
color_range = [
    "#fff5f5", "#ffe0e0", "#ffcccc", "#ffb8b8", "#ffa3a3", "#ff8f8f", "#ff7a7a", "#ff6666",
    "#ff5252", "#ff3d3d", "#ff2929", "#ff1414", "#ff0000", "#e50000", "#cc0000", "#b20000",
    "#990000", "#7f0000", "#660000", "#4c0000", "#330000"
];
# Columns of `BAE.Z_cluster` are re-indexed with `rand_inds` to match the shuffled rows of `MD.obs_df`.
create_colored_vegascatterplots(Matrix(MD.obs_df[:, [:UMAP1, :UMAP2]]), BAE.Z_cluster[:, rand_inds];
    path=figurespath * "/UMAPplotsCluster/",
    # "SimData_" prefix for simulated data; these plots are per cluster, not per latent dimension.
    filename="SimData_BAE_cluster",
    filetype="scatter.svg",
    legend_title="Activation",
    color_field="labels:q",
    scheme=nothing,
    domain_mid=nothing,
    range=color_range,
    save_plot=true,
    marker_size="25"
)

In [None]:
#---Create scatter plots of the UMAP embedding of the learned BAE latent representation,
#   colored by the expression of the top 5 selected genes of each cluster:
if !isdir(figurespath * "/FeaturePlots")
    # Create the folder if it does not exist
    mkdir(figurespath * "/FeaturePlots")
end
# White-to-dark-red sequential palette for the expression color scale.
color_range = [
    "#fff5f5", "#ffe0e0", "#ffcccc", "#ffb8b8", "#ffa3a3", "#ff8f8f", "#ff7a7a", "#ff6666",
    "#ff5252", "#ff3d3d", "#ff2929", "#ff1414", "#ff0000", "#e50000", "#cc0000", "#b20000",
    "#990000", "#7f0000", "#660000", "#4c0000", "#330000"
];
# Rows of `X` are re-indexed with `rand_inds` to match the shuffled rows of `MD.obs_df`.
FeaturePlots(TopGenes_Cluster, MD.featurename, X[rand_inds, :], Matrix(MD.obs_df[:, [:UMAP1, :UMAP2]]);
    top_n=5,
    marker_size="25",
    fig_type=".svg",
    path=figurespath * "/FeaturePlots/",
    # `X` is the raw binary simulated expression matrix, not log1p-transformed counts.
    legend_title="Expression",
    color_field="labels:o",
    scheme=nothing,
    domain_mid=nothing,
    range=color_range
)