In [None]:
using AutoEncoderToolkit
using AutoEncoderToolkit.VAEs
using Clustering                
using Combinatorics
using CUDA
using CSV
using DataFrames
using DelimitedFiles
using Distributions
using Distances
using FileIO
using Flux
using GaussianMixtures         
using Glob
using LinearAlgebra
using MAT                      
using Measures
using MultivariateStats         
using NIfTI
using Plots
using Printf
using Random
using RegressionDynamicCausalModeling
using Statistics
using StatsBase                
using TSne
using MultivariateStats: PCA, fit, transform


include("utils.jl")
include("utils_clustering.jl")

const FUNC_DIR = "data/FunImgARCFW_n88"
const RESULTS_DIR = "data/connectivity_n88"
const ATLAS_DIR = "data"
const FUNC_FILENAME = "wFiltered_4DVolume.nii"
const ATLAS_FILENAME = "BN_Atlas_246_3mm.nii"
const BNA_MATRIX_FILENAME = "data/BNA_matrix_binary_246x246.csv"
const SUBDIAGNOSIS_CSV_FILENAME = "data/sub_diagnosis_n88.csv"



"data/sub_diagnosis_n88.csv"

1) Load Connectivity Data
Load both rDCM and functional connectivity results from previously saved `.mat` files.

This loads for each subject:
- connectivity_data["<subject_id>"]["rdcm"]
- connectivity_data["<subject_id>"]["fc_mat"]
- connectivity_data["<subject_id>"]["fc_z"]
- connectivity_data["<subject_id>"]["cov_mat"]

In [None]:
connectivity_data = load_connectivity_data(RESULTS_DIR; load_rdcm=true, load_functional=true)
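
As a minimal sketch of the layout described above, the snippet below builds a toy nested dictionary with the same four keys and reads one entry back. The subject ID `sub-001`, the 3×3 matrix size, and the random values are hypothetical; the real structure returned by `load_connectivity_data` (defined in `utils.jl`, not shown here) may differ in detail.

```julia
using Random

Random.seed!(1)
n_regions = 3  # toy size only; the atlas file names in this notebook refer to 246 regions

# Hypothetical stand-in for the output of load_connectivity_data
toy_connectivity_data = Dict(
    "sub-001" => Dict(
        "rdcm"    => randn(n_regions, n_regions),
        "fc_mat"  => randn(n_regions, n_regions),
        "fc_z"    => randn(n_regions, n_regions),
        "cov_mat" => randn(n_regions, n_regions),
    ),
)

# Walk the two levels: subject ID first, then connectivity type
for (subject_id, mats) in toy_connectivity_data
    for key in sort(collect(keys(mats)))
        println(subject_id, " / ", key, " : ", size(mats[key]))
    end
end

# Read a single matrix for one subject
fc = toy_connectivity_data["sub-001"]["fc_mat"]
```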


2) Mask and Flatten Matrices and Build Matrices for Analysis
Apply the structural mask (BNA matrix) and flatten each connectivity matrix into a vector.

The results are stored as nested dictionaries where:
- Keys are subject IDs,
- Values are vectors of connectivity values.

Then stack the connectivity vectors into subject × feature matrices, with one row per subject.
Passing a connectivity type such as "fc_mat" extracts only that type, for example functional connectivity.


In [3]:
masked_connectivity_vectors = mask_and_flatten_connectivity_data(connectivity_data, BNA_MATRIX_FILENAME; mask=true)
connectivity_vectors = mask_and_flatten_connectivity_data(connectivity_data, BNA_MATRIX_FILENAME; mask=false)
subjectIDs =  extract_allSubjectIDs(RESULTS_DIR)


masked_connectivity_matrix = get_connectivity_matrix(masked_connectivity_vectors, subjectIDs)
connectivity_matrix = get_connectivity_matrix(connectivity_vectors, subjectIDs )
matrix_fc = get_connectivity_matrix(connectivity_vectors, subjectIDs, "fc_mat")
matrix_masked_fc = get_connectivity_matrix(masked_connectivity_vectors, subjectIDs, "fc_mat")


88×11060 Matrix{Float64}:
 0.45254    0.582365   0.263348   …  -0.149657     0.169509   0.44305
 0.667324   0.61458    0.441389       0.0850222    0.0530302  0.543824
 0.454203   0.523889   0.317798       0.248442     0.58188    0.59517
 0.88403    0.105796   0.0389225      0.406011     0.174374   0.791872
 0.641462   0.38131    0.0934279      0.627856     0.363746   0.752038
 0.613588   0.110665  -0.0169369  …   0.231938     0.342462   0.896338
 0.954613   0.826107   0.829335       0.584308     0.344539   0.845087
 0.73033   -0.104463  -0.0800406      0.00623737  -0.0225886  0.66911
 0.573449   0.316101   0.263258       0.506535     0.575842   0.775527
 0.540277   0.36996    0.816273       0.722109     0.703819   0.445008
 ⋮                                ⋱                           
 0.424387   0.495474   0.484028       0.179833     0.0476406  0.5422
 0.545267   0.718087   0.625973   …   0.423885     0.0974599  0.537257
 0.744607   0.477448   0.618623       0.16184      0.224431   0.
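
The masking, flattening, and stacking steps can be sketched on toy data as follows. This is an illustration under stated assumptions, not the code in `utils.jl`: the subject IDs, the 3×3 matrices, and the mask are hypothetical, and the real helpers may order or select entries differently.

```julia
# Toy 3×3 connectivity matrices for two hypothetical subjects
toy_data = Dict(
    "sub-001" => [0.0 0.2 0.3; 0.4 0.0 0.6; 0.7 0.8 0.0],
    "sub-002" => [0.0 1.2 1.3; 1.4 0.0 1.6; 1.7 1.8 0.0],
)

# Toy binary mask: `true` marks a connection to keep
toy_mask = Bool[0 1 0; 1 0 1; 0 1 0]

# Logical indexing keeps the selected entries in column-major order;
# vec() flattens the full matrix without masking
masked_vectors = Dict(id => A[toy_mask] for (id, A) in toy_data)
full_vectors = Dict(id => vec(A) for (id, A) in toy_data)

# Stack one row per subject, using a fixed subject order
subject_ids = sort(collect(keys(toy_data)))
masked_matrix = reduce(vcat, (permutedims(masked_vectors[id]) for id in subject_ids))
full_matrix = reduce(vcat, (permutedims(full_vectors[id]) for id in subject_ids))

println(size(masked_matrix), " ", size(full_matrix))
```

Fixing the subject order before stacking matters because dictionary iteration order is not guaranteed, and the rows must later line up with the group labels.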

3) Load Group Labels
 
Load subject diagnostic labels from a CSV file:

In [4]:
# Read the CSV; row 1 holds the column names and the data start at row 2
data_table = CSV.read(SUBDIAGNOSIS_CSV_FILENAME, DataFrame; delim=',', skipto=2)
# Extract the group column (1=TBI, 2=PTSD, 3=PTSD+TBI, 4=NC)
group_col = data_table[!, Symbol("group [1-4, 4=NC, 3=PTSD+TBI, 2=PTSD, 1=TBI]")]
group_array = convert(Vector{Int}, group_col)  # plain Vector{Int} of group codes

88-element Vector{Int64}:
 3
 4
 4
 4
 4
 4
 4
 4
 4
 2
 ⋮
 2
 4
 1
 4
 1
 1
 1
 4
 4
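
A quick way to check the group sizes before subsampling is to count the labels per code. The sketch below uses a short hypothetical label vector rather than the real `group_array`; the code-to-name mapping is taken from the CSV column header.

```julia
# Hypothetical labels; in the notebook this would be group_array
toy_labels = [3, 4, 4, 2, 1, 4, 1]

# Mapping taken from the column header of the diagnosis CSV
group_names = Dict(1 => "TBI", 2 => "PTSD", 3 => "PTSD+TBI", 4 => "NC")

# Count how many subjects carry each group code
group_counts = Dict(g => count(==(g), toy_labels) for g in 1:4)

for g in 1:4
    println(g, " (", group_names[g], "): ", group_counts[g])
end
```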


The following parameters control the dimensionality reduction and clustering analysis:

- `dims_list`: Latent dimensions to evaluate → [2, 10, 20, 50, 100]

- `data_sources`: Connectivity data used for training:
    - "RDCM" → rDCM-based connectivity
    - "FC"   → functional connectivity (masked)

- `methods`: Dimensionality reduction methods to compare:
    - "PCA"
    - "tSNE"
    - "VAE"

- `clustering_methods`: Clustering algorithms applied in the reduced space:
    - "kmeans"
    - "fuzzy"
    - "hierarchical"

- `repeats`: Number of repetitions per method and configuration → 50

- `select_counts`: Number of subjects selected from each of the four groups → [20, 20, 0, 20]

In [5]:
dims_list = [2, 10, 20, 50, 100]
data_sources = Dict("RDCM" => connectivity_matrix, "FC" => matrix_masked_fc)
methods = ["PCA", "tSNE", "VAE"]
repeats = 50
clustering_methods = ["kmeans", "fuzzy", "hierarchical"]
select_counts = [20, 20, 0, 20]  # how many subjects are selected from each of the four groups

4-element Vector{Int64}:
 20
 20
  0
 20
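
To make the size of the experiment concrete, the sketch below enumerates the configuration grid implied by the parameters above and draws one group-balanced subsample. It rests on two assumptions that the notebook does not confirm: that every combination of data source, method, dimension, and clustering algorithm is evaluated, and that `select_counts[g]` refers to group code `g`. The helper `select_subjects` and the toy labels are hypothetical, and the methods vector is named `reduction_methods` here only to avoid shadowing `Base.methods`.

```julia
using Random

dims_list = [2, 10, 20, 50, 100]
data_names = ["RDCM", "FC"]
reduction_methods = ["PCA", "tSNE", "VAE"]
clustering_methods = ["kmeans", "fuzzy", "hierarchical"]
repeats = 50

# One configuration per combination of the four settings
configs = vec(collect(Iterators.product(data_names, reduction_methods,
                                        dims_list, clustering_methods)))
n_runs = length(configs) * repeats
println(length(configs), " configurations, ", n_runs, " runs in total")

# Hypothetical helper: draw select_counts[g] subjects at random from group g
function select_subjects(labels, select_counts; rng = Random.default_rng())
    selected = Int[]
    for (g, n) in enumerate(select_counts)
        idx = findall(==(g), labels)
        n <= length(idx) || error("group $g has only $(length(idx)) subjects")
        append!(selected, shuffle(rng, idx)[1:n])
    end
    return sort(selected)
end

toy_labels = repeat([1, 2, 3, 4], 25)  # 25 hypothetical subjects per group
chosen = select_subjects(toy_labels, [20, 20, 0, 20]; rng = MersenneTwister(0))
```

Redrawing the subsample on each repeat would be one reason for accuracy to vary across repeats, which is what a standard deviation over repeats would then capture.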

Train a VAE for each data source and each latent dimension in `dims_list`.

In [6]:
# Dictionary to store all trained models and their metrics
vae_results = Dict()
n_epochs = 30

# Loop over datasets and latent dimensions
for (data_name, matrix) in data_sources
    println("Training on dataset: $data_name")

    dataset_results = Dict()
    matrix_f32 = Float32.(matrix)

    for dim in dims_list
        println("  Latent Dim: $dim")

        # Train the model
        vae_model, metrics = train_vae_model(matrix_f32, dim, n_epochs)

        # Save model and metrics
        dataset_results[string(dim)] = (
            model = vae_model,
            metrics = metrics,
        )
    end

    vae_results[data_name] = dataset_results
end


Training on dataset: RDCM
  Latent Dim: 2
Using GPU for training.
Epoch: 1
Epoch 1 | Batch 1 / 8
Epoch 1 | Batch 2 / 8
Epoch 1 | Batch 3 / 8
Epoch 1 | Batch 4 / 8
Epoch 1 | Batch 5 / 8
Epoch 1 | Batch 6 / 8
Epoch 1 | Batch 7 / 8
Epoch 1 | Batch 8 / 8
Epoch 1 / 30:
- Train MSE: 0.008182317
- Val MSE: 0.008999289
- Train Loss: 1.5149443e7
- Val Loss: 518.49896
- Train Entropy: 0.6938634
- Val Entropy: 0.6931714
Epoch: 2
Epoch 2 | Batch 1 / 8
Epoch 2 | Batch 2 / 8
Epoch 2 | Batch 3 / 8
Epoch 2 | Batch 4 / 8
Epoch 2 | Batch 5 / 8
Epoch 2 | Batch 6 / 8
Epoch 2 | Batch 7 / 8
Epoch 2 | Batch 8 / 8
Epoch 2 / 30:
- Train MSE: 0.008427602
- Val MSE: 0.009382518
- Train Loss: 598.5416
- Val Loss: 50.786255
- Train Entropy: 0.6937942
- Val Entropy: 0.6939193
Epoch: 3
Epoch 3 | Batch 1 / 8
Epoch 3 | Batch 2 / 8
Epoch 3 | Batch 3 / 8
Epoch 3 | Batch 4 / 8
Epoch 3 | Batch 5 / 8
Epoch 3 | Batch 6 / 8
Epoch 3 | Batch 7 / 8
Epoch 3 | Batch 8 / 8
Epoch 3 / 30:
- Train MSE: 0.008309871
- Val MSE: 0.009452
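
For reference, a standard VAE objective combines a reconstruction term with a KL penalty on the latent distribution. The sketch below writes those pieces in plain Julia. It is not the loss used by `train_vae_model`, whose definition is not shown in this notebook, and it does not cover the logged "Entropy" metric. The function names and the weight `β` are hypothetical. The logged Train Loss is far larger than the Train MSE, which is consistent with the loss including terms beyond the reconstruction error, although that remains an inference.

```julia
using Random, Statistics

# Reparameterization trick: z = μ + σ .* ε with ε ~ N(0, I)
reparameterize(μ, logσ², rng) = μ .+ exp.(0.5 .* logσ²) .* randn(rng, size(μ)...)

# KL( N(μ, diag(σ²)) || N(0, I) ) for each sample; columns are samples
kl_divergence(μ, logσ²) =
    vec(-0.5 .* sum(1 .+ logσ² .- μ .^ 2 .- exp.(logσ²); dims = 1))

# Mean squared reconstruction error over all entries
mse(x, x̂) = mean((x .- x̂) .^ 2)

# Reconstruction error plus β-weighted mean KL term
vae_loss(x, x̂, μ, logσ²; β = 1.0) = mse(x, x̂) + β * mean(kl_divergence(μ, logσ²))

# Toy check with 2 latent dimensions and 3 samples
rng = MersenneTwister(0)
μ = ones(2, 3)
logσ² = zeros(2, 3)
z = reparameterize(μ, logσ², rng)
x = zeros(4, 3)
println(vae_loss(x, x, μ, logσ²))
```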

Run dimensionality reduction and clustering based on the experimental parameters specified above.

In [13]:
results, grouped, all_results = evaluate_dimensionality_reduction(
    dims_list,
    data_sources,
    methods,
    repeats,
    clustering_methods,
    vae_results,
    group_array,
    select_counts
)

# Access results:
display(results)  # Summary DataFrame
# === Summary Results DataFrame ===
# This DataFrame contains one row per configuration tested across the pipeline.
# It summarizes the clustering performance after dimensionality reduction.

# Columns:
# - Data        :: String   → Name of the dataset ("RDCM", "FC", etc.)
# - Method      :: String   → Dimensionality reduction method used ("PCA", "tSNE", "VAE")
# - Dims        :: Int64    → Latent dimension size
# - Clustering  :: String   → Clustering algorithm used ("kmeans", "fuzzy", "hierarchical")
# - MeanAcc     :: Float64  → Mean accuracy across repeats
# - StdAcc      :: Float64  → Standard deviation of accuracy

# Example:
# Row  Data  Method  Dims  Clustering    MeanAcc   StdAcc
# 1    RDCM  PCA     2     kmeans        0.413667  0.0285
# 2    RDCM  PCA     2     fuzzy         0.408333  0.0176
# 3    RDCM  PCA     2     hierarchical  0.413333  0.0276
# 4    RDCM  tSNE    2     kmeans        0.413000  0.0309


UndefVarError: UndefVarError: `transform` not defined in `Main`
Hint: It looks like two or more modules export different bindings with this name, resulting in ambiguity. Try explicitly importing it from a particular module, or qualifying the name with the module it should come from.
Hint: a global variable of this name also exists in DataFrames.
Hint: a global variable of this name may be made accessible by importing ScikitLearnBase in the current active module Main
Hint: a global variable of this name also exists in MultivariateStats.
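
The error above is what Julia reports when two loaded modules export the same name; the hints point at `DataFrames` and `MultivariateStats`. The sketch below reproduces the situation with two toy modules (`ToyFrames` and `ToyStats`, both hypothetical) standing in for the real packages. It shows that module-qualified calls stay unambiguous regardless of import order. Whether qualifying the calls inside `utils.jl` or `utils_clustering.jl` resolves this notebook's error is an assumption, since those files are not shown.

```julia
# Two toy modules that export the same name, standing in for the real packages
module ToyFrames
export transform
transform(x) = "ToyFrames.transform"
end

module ToyStats
export transform
transform(x) = "ToyStats.transform"
end

using .ToyFrames
using .ToyStats

# An unqualified `transform` would be ambiguous at this point.
# Qualifying the call with its module picks one binding explicitly.
from_stats = ToyStats.transform(1)
from_frames = ToyFrames.transform(1)

println(from_stats)
println(from_frames)
```

In the notebook, the analogous change would be to write `MultivariateStats.transform(...)` at the call site, assuming that is the function the helper intends to call.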