### Hyperparameter Tuning for ggml-ot

In this tutorial, we explore how to tune the hyperparameters of the ggml-ot framework, whose results are sensitive to the choice of $\alpha$ and $\lambda$. These hyperparameters control the margin and the regularization strength of the learned metric, respectively. In the `ggml` function they are passed as `alpha` and `reg`.
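To build intuition for the two knobs, here is a minimal, generic sketch of a margin-based (triplet hinge) loss with a weight penalty. It is **not** ggml-ot's actual objective: the function name `margin_loss`, the L2 penalty, and the inputs (precomputed distances from an anchor to a same-label and a different-label sample) are illustrative assumptions. `alpha` sets how much farther the different-label sample must be before a triplet stops contributing, and `reg` scales the penalty on the weights.

```python
# Hypothetical illustration, NOT ggml-ot's objective: a triplet hinge loss
# with margin `alpha` plus an L2 penalty on the weights scaled by `reg`.
from typing import Sequence, Tuple


def margin_loss(
    triplets: Sequence[Tuple[float, float]],  # (d_same_label, d_other_label) per triplet
    weights: Sequence[float],
    alpha: float,
    reg: float,
) -> float:
    # A triplet contributes only if the other-label distance does not exceed
    # the same-label distance by at least the margin `alpha`.
    hinge = sum(max(0.0, d_pos - d_neg + alpha) for d_pos, d_neg in triplets)
    # Regularization: a larger `reg` pushes the weights toward zero.
    penalty = reg * sum(w * w for w in weights)
    return hinge + penalty
```

For a triplet with distances `(1.0, 3.0)`, a small margin such as `alpha=0.1` leaves only the penalty term, while `alpha=10` makes the same triplet contribute to the loss.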

ggml-ot supports automatic hyperparameter tuning: if a list of candidate values is passed instead of a single value for $\alpha$, $\lambda$, or the rank $k$ of the subspace projection, the model internally searches for the best combination.

After training, the most suitable parameters are returned together with the best `w_theta`.
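Conceptually, passing lists turns training into a search over all combinations. The sketch below shows an exhaustive grid search under that assumption; ggml-ot's internal search strategy may differ, and `grid_search` and `train_and_score` are hypothetical names, where `train_and_score` stands in for training with one combination and returning a score (higher is better).

```python
# Sketch of an exhaustive grid search, assuming every combination is tried.
# `train_and_score` is a hypothetical stand-in for training ggml-ot with one
# combination and returning a validation score (higher is better).
from itertools import product


def as_list(value):
    # A single value means "no search" for that parameter.
    return list(value) if isinstance(value, (list, tuple)) else [value]


def grid_search(alpha, reg, k, train_and_score):
    best_params, best_score = None, float("-inf")
    for a, l, rank in product(as_list(alpha), as_list(reg), as_list(k)):
        score = train_and_score(a, l, rank)
        if score > best_score:  # keep the first combination on ties
            best_params, best_score = {"alpha": a, "reg": l, "k": rank}, score
    return best_params, best_score
```

With two values each for `alpha` and `reg` and a fixed rank, this evaluates four combinations, which matches the four result tables printed in the tuning run below.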



Import the necessary packages and load the dataset. Here, we use the Myocardial Infarction dataset from Kuppe et al., 2022 (https://www.nature.com/articles/s41586-022-05060-x).

```python
import ggml_ot
import anndata as ad

# Path to the local copy of the dataset (adjust to your setup)
local_path = "data/czi_dataset.h5ad"
adata = ad.read_h5ad(local_path)
```

We now run the `ggml` function and search for the best combination of $\alpha \in \{0.1, 10\}$ and $\lambda \in \{0.1, 10\}$ (four combinations in total) by passing lists for the two parameters `alpha` and `reg`.

To keep the runtime manageable, we limit training to five iterations and subsample 500 cells per patient. Depending on the size of your dataset and the number of hyperparameter combinations, tuning may still take a while.
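Capping the number of cells per patient can be pictured as the following subsampling step. This is a sketch that assumes uniform sampling without replacement; how ggml-ot actually uses `n_cells` is not shown here, and the function name `subsample_per_patient` is hypothetical. With AnnData, `patient_ids` would come from a patient column in `adata.obs` whose name depends on the dataset.

```python
# Sketch of per-patient subsampling, assuming uniform sampling without
# replacement; ggml-ot's own strategy behind `n_cells` may differ.
import random
from collections import defaultdict
from typing import Dict, List, Sequence


def subsample_per_patient(
    patient_ids: Sequence[str], n_cells: int, seed: int = 0
) -> Dict[str, List[int]]:
    """Return at most `n_cells` cell indices for each patient."""
    rng = random.Random(seed)
    by_patient: Dict[str, List[int]] = defaultdict(list)
    for idx, pid in enumerate(patient_ids):
        by_patient[pid].append(idx)
    # Patients with fewer cells than `n_cells` keep all of their cells.
    return {
        pid: sorted(rng.sample(idxs, min(n_cells, len(idxs))))
        for pid, idxs in by_patient.items()
    }
```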

```python
# Passing lists for alpha and reg triggers the hyperparameter search
w_theta, best = ggml_ot.ggml(adata, alpha=[0.1, 10], reg=[0.1, 10], max_iterations=5, n_cells=500, verbose=False)
```

```
keeping 7777 high variable genes
Starting the hyperparameter tuning
```

During the search, one result table is printed per combination (`a` = $\alpha$, `l` = $\lambda$, `k` = rank). Combined, they read:

| method | knn_acc | mi | ari | vi | score |
|---|---|---|---|---|---|
| a=0.1, l=0.1, k=5 | 0.925 | 0.871089 | 0.920772 | 0.929053 | 0.915986 |
| a=0.1, l=10, k=5 | 0.875 | 0.499167 | 0.732336 | 0.710212 | 0.761119 |
| a=10, l=0.1, k=5 | 0.975 | 0.947845 | 0.974383 | 0.967818 | 0.969174 |
| a=10, l=10, k=5 | 0.975 | 0.858474 | 0.942726 | 0.914377 | 0.940096 |


Once tuning is complete, the function returns:
- `w_theta`: the learned weight matrix for the best combination
- `best`: a dictionary with the best hyperparameter combination used to train `w_theta`, together with its score and k-NN accuracy

```python
print(best)
```

```
{'score': np.float64(0.9691744029802156), 'knn_acc': np.float64(0.975), 'alpha': 10, 'reg': 0.1, 'k': 5}
```
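The entries in `best` correspond to the row with the highest `score` in the tuning results. The snippet below reproduces that selection from the values copied out of the four printed tables; it assumes the selection criterion is simply the maximum `score`. Note that k-NN accuracy alone would not separate the two $\alpha = 10$ runs (both 0.975), whereas the combined score does.

```python
# Values copied from the tuning output; keys are (alpha, reg, k).
# Assumption: the reported best combination is the one with the highest score.
results = {
    (0.1, 0.1, 5): {"knn_acc": 0.925, "score": 0.915986},
    (0.1, 10, 5): {"knn_acc": 0.875, "score": 0.761119},
    (10, 0.1, 5): {"knn_acc": 0.975, "score": 0.969174},
    (10, 10, 5): {"knn_acc": 0.975, "score": 0.940096},
}

best_alpha, best_reg, best_k = max(results, key=lambda combo: results[combo]["score"])
print({"alpha": best_alpha, "reg": best_reg, "k": best_k, **results[(best_alpha, best_reg, best_k)]})
```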


Alternatively, set `alpha`, `reg`, and `rank_k` to empty lists (`[]`) when calling `ggml` to start hyperparameter tuning over default candidate values.

### Evaluating generalizability

Once `w_theta` has been learned, we can assess how well it performs on unseen data. This tests whether the learned metric captures the general structure of the dataset rather than merely memorizing the training data.

ggml-ot provides a benchmarking function, `evaluate_generalizability`, which uses a train/test split of distributions to evaluate:

- k-NN classification accuracy
- Adjusted Rand Index (ARI)
- Normalized Mutual Information (NMI)
- Variation of Information (VI)
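To make the four metrics concrete, the following pure-Python sketch computes textbook versions of them. It is an illustration under assumptions, not ggml-ot's `evaluate_generalizability`: k-NN accuracy is computed from a hypothetical precomputed test-to-train distance matrix, ARI, NMI, and VI compare predicted cluster labels with the true labels, NMI uses arithmetic-mean normalization (scikit-learn's default), and VI is the raw information distance $H(A) + H(B) - 2I(A;B)$, where lower is better. The `vi` values reported in this tutorial rise together with the other scores, so ggml-ot likely reports a transformed variant that this sketch does not reproduce.

```python
# Textbook versions of the four metrics; NOT ggml-ot's implementation.
import math
from collections import Counter
from typing import Hashable, Sequence


def knn_accuracy(dist_test_train: Sequence[Sequence[float]], train_labels: Sequence[Hashable],
                 test_labels: Sequence[Hashable], k: int = 5) -> float:
    """Majority vote over the k nearest training distributions."""
    correct = 0
    for row, true_label in zip(dist_test_train, test_labels):
        nearest = sorted(range(len(row)), key=row.__getitem__)[:k]
        vote = Counter(train_labels[i] for i in nearest).most_common(1)[0][0]
        correct += int(vote == true_label)
    return correct / len(test_labels)


def _entropy(labels) -> float:
    n = len(labels)
    return -sum((c / n) * math.log(c / n) for c in Counter(labels).values())


def _mutual_info(a, b) -> float:
    n = len(a)
    count_a, count_b = Counter(a), Counter(b)
    return sum((nij / n) * math.log(n * nij / (count_a[i] * count_b[j]))
               for (i, j), nij in Counter(zip(a, b)).items())


def nmi(a, b) -> float:
    """Normalized mutual information (arithmetic-mean normalization)."""
    h_a, h_b = _entropy(a), _entropy(b)
    if h_a == 0 and h_b == 0:
        return 1.0  # both labelings consist of a single cluster
    return _mutual_info(a, b) / ((h_a + h_b) / 2)


def variation_of_information(a, b) -> float:
    """Raw VI = H(A) + H(B) - 2 I(A;B); 0 means identical partitions."""
    return _entropy(a) + _entropy(b) - 2 * _mutual_info(a, b)


def ari(a, b) -> float:
    """Adjusted Rand index from pair counts in the contingency table (needs >= 2 items)."""
    pairs = lambda x: x * (x - 1) / 2
    sum_ij = sum(pairs(v) for v in Counter(zip(a, b)).values())
    sum_a = sum(pairs(v) for v in Counter(a).values())
    sum_b = sum(pairs(v) for v in Counter(b).values())
    expected = sum_a * sum_b / pairs(len(a))
    max_index = (sum_a + sum_b) / 2
    if max_index == expected:
        return 1.0
    return (sum_ij - expected) / (max_index - expected)
```

ARI, NMI, and VI are invariant to how clusters are named, so relabeling `[0, 0, 1, 1]` as `[1, 1, 0, 0]` still gives a perfect ARI.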

To evaluate generalizability, the train/test split must be defined when creating the dataset with `scRNA_Dataset`. Setting the `train_size` parameter automatically partitions the data into training and test sets, which are then used for evaluation.
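Because the split is made over distributions, whole patients end up on one side or the other. The sketch below illustrates that idea with a seeded random split; the function name `split_distributions` is hypothetical, and whether `scRNA_Dataset` shuffles, stratifies by label, or rounds the split size differently is an open assumption.

```python
# Hypothetical sketch of a distribution-level split: whole patients (not
# individual cells) go to either the training or the test set.
import random
from typing import List, Sequence, Tuple


def split_distributions(
    patient_ids: Sequence[str], train_size: float = 0.8, seed: int = 0
) -> Tuple[List[str], List[str]]:
    patients = sorted(set(patient_ids))  # one entry per distribution
    random.Random(seed).shuffle(patients)
    n_train = round(train_size * len(patients))
    return patients[:n_train], patients[n_train:]
```

Splitting at the patient level keeps cells of a test patient out of training, which is what makes the evaluation a test on unseen distributions.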

In the following example, we use an 80%/20% train/test split. The matrix `w_theta` is trained with the best hyperparameters identified in the tuning step above ($\alpha = 10$, $\lambda = 0.1$), again with 500 cells per patient.

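```python
# 80% of the distributions are used for training, 20% for testing
dataset = ggml_ot.scRNA_Dataset(adata, n_cells=500, train_size=0.8)
# Train with the best hyperparameters from the tuning step
w_theta = ggml_ot.ggml(dataset, alpha=10, reg=0.1, max_iterations=5, n_threads=16)
```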

```
keeping 7777 high variable genes
Running GGML with alpha: 10, reg: 0.1, rank: 5
100%|██████████| 9/9 [02:15<00:00, 15.07s/it]
Iteration 1 with Loss  108.90546417236328
100%|██████████| 9/9 [02:16<00:00, 15.19s/it]
Iteration 2 with Loss  87.245849609375
100%|██████████| 9/9 [02:22<00:00, 15.87s/it]
Iteration 3 with Loss  66.6489486694336
100%|██████████| 9/9 [02:19<00:00, 15.46s/it]
Iteration 4 with Loss  48.0069465637207
100%|██████████| 9/9 [02:18<00:00, 15.36s/it]
Iteration 5 with Loss  28.143152236938477
```


Now we call `evaluate_generalizability`, which returns the k-NN accuracy, MI, ARI, and VI scores:

```python
from ggml_ot.benchmark import evaluate_generalizability

# `method` is the row label used in the printed results table
knn, mi, ari, vi = evaluate_generalizability(dataset, w_theta, print_latex=False, method="alpha=10, reg=0.1")
```

| method | knn_acc | mi | ari | vi |
|---|---|---|---|---|
| alpha=10, reg=0.1 | 0.833333 | 0.895945 | 0.954769 | 0.937487 |
