# Setup:

In [None]:
#---Activate the project environment located one level above this file:
using Pkg

# Anchor the environment path at this file's directory rather than at the current
# working directory, in the same way `projectpath` is derived from `@__DIR__`
# further below. This way the same environment is activated no matter where Julia
# was launched from.
Pkg.activate(normpath(joinpath(@__DIR__, "..")))
Pkg.instantiate()  # install the dependencies recorded for the environment
Pkg.status()       # print the active environment for the record

#---Set up the project, data and figure directories:
# NOTE: `datapath` and `figurespath` deliberately keep their trailing "/" because
# file names are appended to them by plain string concatenation later on.
projectpath = joinpath(@__DIR__, "../");
datapath = projectpath * "data/rat/"
figurespath = projectpath * "figures/rat/"

# `mkpath` also creates missing parent folders (e.g. "data/" or "figures/") and does
# nothing when the folder already exists, whereas `mkdir` throws if the parent is
# missing (for instance on a fresh checkout).
mkpath(datapath)
mkpath(figurespath)

#---Load the BoostingAutoEncoder module:
include(projectpath * "/src/BAE.jl");
using .BoostingAutoEncoder

#---Load required packages for this notebook:
using RCall;
using DelimitedFiles;
using Plots;
using Random;
using StatsBase;
using VegaLite;  
using DataFrames;
using StatsPlots;

# Load data and create CCIM:

In [None]:
#---Fetch a subset of the example rat lung data from "https://zenodo.org/record/6846618/files/raredon_2019_rat.Robj" and store it in the data directory:
X_alra, tSNE_embeddings, celltypes, genenames = load_rat_scRNAseq_data(;
    data_path=datapath,
    transfer_data=true,
    assay="alra",
)

#---Apply NICHES to the scRNA-seq data:
filepath_expData = string(datapath, "Rat_Seurat_sub.rds")
run_NICHES_wrapper(
    filepath_expData;
    data_path=datapath,
    assay="alra",
    species="rat",
)

#---Read the NICHES CCIM (CellToCell) together with its metadata:
filepath_CCIM = string(datapath, "NICHES_CellToCell.rds")
# The trailing `;` keeps the notebook from echoing the returned tuple.
CCIM, CCIM_st, MD = load_CCIM_CtC(filepath_CCIM);

# Analyze CCIM patterns in rat lung data with the BAE:

In [None]:
#---Hyperparameters used for training the BAE:
HP = Hyperparameters(;
    zdim=30,
    n_runs=1,
    max_iter=2000,
    tol=1e-5,
    batchsize=2^12,
    η=0.01,
    λ=0.1,
    ϵ=0.001,
    M=1,
)

#---Decoder architecture (its input size `p` is the number of columns of `CCIM_st`):
n_cellpairs, p = size(CCIM_st)
decoder = generate_BAEdecoder(p, HP; soft_clustering=true)

#---Build the BAE model, starting from an all-zero p × zdim coefficient matrix:
BAE = BoostingAutoencoder(;
    coeffs=zeros(eltype(CCIM_st), p, HP.zdim),
    decoder=decoder,
    HP=HP,
)
summary(BAE)

In [None]:
#---Fit the BAE on `CCIM_st` and report the elapsed time (output suppressed by `;`):
@time output_dict = train_BAE!(CCIM_st, BAE; MD=MD);

# Result visualization:

In [None]:
#---Embed the learned BAE latent representation in 2D with UMAP and store both coordinates in the metadata:
BAE.UMAP = generate_umap(adjoint(BAE.Z))
MD.obs_df.UMAP1 = BAE.UMAP[:, 1]
MD.obs_df.UMAP2 = BAE.UMAP[:, 2]

#---Put the observations into a random order for plotting:
# NOTE(review): `MD.obs_df` is reordered in place here, while `BAE.UMAP` is not.
# Re-running this cell without rebuilding `MD` would therefore write the UMAP
# coordinates onto rows that are already shuffled — confirm this cell is only
# executed once per freshly loaded `MD`.
rand_inds = shuffle(1:n_cellpairs)
MD.obs_df = MD.obs_df[rand_inds, :]

#---Build a palette with twice as many colors as latent dimensions, plus a shuffled copy:
n_cols = 2 * BAE.HP.zdim
custom_colorscheme = map(1:n_cols) do i
    # Arguments are presumably (hue, saturation, lightness) — verify against `hsl_to_hex`.
    h = i / n_cols
    l = 0.5 + 0.1 * sin(i * 4π / BAE.HP.zdim)
    hsl_to_hex(h, 0.7, l)
end
custom_colorscheme_shuffled = shuffle(custom_colorscheme)

#---Continuous color ranges for scatter plots (21 hex colors each), one per background type:
# Range intended for dark backgrounds:
color_range_dark = String[
    "#fff5f5", "#ffe0e0", "#ffcccc", "#ffb8b8", "#ffa3a3", "#ff8f8f", "#ff7a7a",
    "#ff6666", "#ff5252", "#ff3d3d", "#ff2929", "#ff1414", "#ff0000", "#e50000",
    "#cc0000", "#b20000", "#990000", "#7f0000", "#660000", "#4c0000", "#330000",
]
# Range intended for light backgrounds:
color_range_light = String[
    "#000000", "#220022", "#440044", "#660066", "#880088", "#aa00aa", "#cc00cc",
    "#ee00ee", "#ff00ff", "#ff19ff", "#ff33ff", "#ff4cff", "#ff66ff", "#ff7fff",
    "#ff99ff", "#ffb2ff", "#ffccff", "#ffe5ff", "#ffccf5", "#ff99eb", "#ff66e0",
];

In [None]:
#---Draw and save one scatter plot of the top 10 selected features for every cluster that has scores:
# Create the output folder on first use.
isdir(figurespath * "/TopFeaturesCluster") || mkdir(figurespath * "/TopFeaturesCluster")
for (key, cluster_features) in pairs(MD.Top_features)
    # Clusters without any score have nothing to plot.
    isempty(cluster_features.Scores) && continue
    FeatureScatter_plot = TopFeaturesPerCluster_scatterplot(cluster_features, key; top_n=10)
    savefig(FeatureScatter_plot, figurespath * "/TopFeaturesCluster/" * "BAE_Cluster$(key)_Interactions.png")
end

In [None]:
#---UMAP embedding of the learned BAE latent representation, colored by cluster label:
vegascatterplot(
    Matrix(MD.obs_df[!, [:UMAP1, :UMAP2]]),
    MD.obs_df[!, :Cluster];
    path=string(figurespath, "Cluster_(BAE)umap.png"),
    legend_title="Cluster",
    color_field="labels:o",
    scheme=nothing,
    domain_mid=nothing,
    range=custom_colorscheme_shuffled,
    save_plot=true,
    marker_size="5",
)

In [None]:
#---UMAP embedding of the learned BAE latent representation, colored by sender cell type:
# The palette is a fixed selection of nine entries of `custom_colorscheme`.
vegascatterplot(
    Matrix(MD.obs_df[!, [:UMAP1, :UMAP2]]),
    MD.obs_df[!, :SenderType];
    path=string(figurespath, "SenderType_(BAE)umap.png"),
    legend_title="Sender",
    color_field="labels:o",
    scheme=nothing,
    domain_mid=nothing,
    range=custom_colorscheme[[1, 3, 14, 26, 31, 36, 42, 45, 53]],
    save_plot=true,
    marker_size="5",
)

In [None]:
#---UMAP embedding of the learned BAE latent representation, colored by receiver cell type:
# The palette is a fixed selection of nine entries of `custom_colorscheme`.
vegascatterplot(
    Matrix(MD.obs_df[!, [:UMAP1, :UMAP2]]),
    MD.obs_df[!, :ReceiverType];
    path=string(figurespath, "ReceiverType_(BAE)umap.png"),
    legend_title="Receiver",
    color_field="labels:o",
    scheme=nothing,
    domain_mid=nothing,
    range=custom_colorscheme[[1, 3, 14, 26, 31, 36, 42, 45, 53]],
    save_plot=true,
    marker_size="5",
)

In [None]:
#---UMAP scatter plots of the learned BAE latent representation, colored by the activations in `BAE.Z_cluster`:
# Create the output folder on first use.
isdir(figurespath * "/UMAPplotsCluster") || mkdir(figurespath * "/UMAPplotsCluster")
# The columns of `BAE.Z_cluster` are reordered with `rand_inds` so that they follow
# the same observation order as the shuffled rows of `MD.obs_df`.
create_colored_vegascatterplots(
    Matrix(MD.obs_df[!, [:UMAP1, :UMAP2]]),
    BAE.Z_cluster[:, rand_inds];
    path=figurespath * "/UMAPplotsCluster/",
    filename="Rat_BAE_dim",
    filetype="scatter.png",
    legend_title="Activation",
    color_field="labels:q",
    scheme=nothing,
    domain_mid=nothing,
    range=color_range_light,
    save_plot=true,
    marker_size="10",
)

In [None]:
#---UMAP scatter plots of the learned BAE latent representation, colored by the CCIM values of the top 5 selected features per cluster:
# Create the output folder on first use.
isdir(figurespath * "/FeaturePlots") || mkdir(figurespath * "/FeaturePlots")
# The rows of `CCIM` are reordered with `rand_inds` so that they follow the same
# observation order as the shuffled rows of `MD.obs_df`.
FeaturePlots(
    MD.Top_features,
    MD.featurename,
    CCIM[rand_inds, :],
    Matrix(MD.obs_df[!, [:UMAP1, :UMAP2]]);
    top_n=5,
    marker_size="10",
    fig_type=".png",
    path=figurespath * "/FeaturePlots/",
    legend_title="log1p",
    color_field="labels:q",
    scheme=nothing,
    domain_mid=nothing,
    range=color_range_light,
)