## Hyperparameter Tuning & Testing Generalizability

In this tutorial, we explore how to tune the hyperparameters of the ggml-ot framework, which is sensitive to the choice of $\alpha$ and $\lambda$. These hyperparameters control the margin and the regularization strength of the learned metric, respectively.

ggml-ot supports automatic hyperparameter tuning: if a list of candidate values, rather than a single number, is provided for $\alpha$, $\lambda$, or $k$ (the rank of the subspace projection), the model internally searches for the best combination.

After training, the best-performing hyperparameters are returned together with the w_theta that was trained with them.



Load the dataset and import necessary packages. Here, we use the Myocardial Infarction dataset from Kuppe et al., 2022 (https://www.nature.com/articles/s41586-022-05060-x).

In [None]:
import ggml_ot
import anndata as ad

local_path = "data/czi_dataset.h5ad"
adata = ad.read_h5ad(local_path)

We now run the `ggml` function and search for the best combination of $\alpha \in \{0.1, 10\}$ and $\lambda \in \{0.1, 10\}$ by passing lists for the two parameters `alpha` and `reg`.

To keep the runtime manageable, we limit training to four iterations and subsample 250 cells per patient.
Depending on the size of your dataset and the number of hyperparameter combinations, tuning may still take a while.
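
The number of training runs grows multiplicatively with the length of each candidate list. The following sketch only enumerates the grid for the values used in this tutorial, so that the cost can be estimated up front. It does not reproduce how ggml-ot searches internally, and the variable names (`alphas`, `regs`, `ranks`, `grid`) are illustrative assumptions.

```python
from itertools import product

# Candidate values from this tutorial. The rank is fixed to the single
# value 5 that is reported in the tuning output (k=5).
alphas = [0.1, 10]
regs = [0.1, 10]
ranks = [5]

# Cartesian product of the candidate lists: one combination per training run.
grid = list(product(alphas, regs, ranks))

for alpha, reg, k in grid:
    print(f"a={alpha}, l={reg}, k={k}")

print(len(grid), "combinations")
```

Adding a third candidate for each of `alphas` and `regs` would already raise the count from 4 to 9, which is worth keeping in mind given the per-iteration runtimes shown later in this tutorial.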

In [8]:
w_theta, best = ggml_ot.ggml(adata, alpha=[0.1,10], reg=[0.1,10], max_iterations=4, n_splits=3, n_cells=250, verbose=False, n_threads=32)

keeping 7777 high variable genes
Starting the hyperparameter tuning


| method | knn_acc | mi | ari | vi | score |
|---|---|---|---|---|---|
| a=0.1, l=0.1, k=5 | 0.859259 | 0.814536 | 0.919751 | 0.88778 | 0.866641 |
| a=0.1, l=10, k=5 | 0.688889 | 0.599336 | 0.785401 | 0.754764 | 0.701028 |
| a=10, l=0.1, k=5 | 0.833333 | 0.897905 | 0.959503 | 0.939445 | 0.882809 |
| a=10, l=10, k=5 | 0.862963 | 0.766445 | 0.889026 | 0.858675 | 0.850506 |


Once tuning is complete, the function returns:
- `w_theta`: the best learned weight matrix
- `best`: a dictionary with the best hyperparameter combination, which was used to train `w_theta`

In [9]:
print(best)

{'score': np.float64(0.8828089346826489), 'knn_acc': np.float64(0.8333333333333334), 'alpha': 10, 'reg': 0.1, 'k': 5}
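
The `score` values printed during tuning are numerically consistent with averaging `knn_acc` with the mean of the three clustering measures (`mi`, `ari`, `vi`). This formula is inferred from the printed numbers, not taken from the ggml-ot documentation, so treat it as an assumption. The sketch below recomputes the scores from the tuning output and selects the highest one; `combined_score` is a hypothetical helper.

```python
# Metric values copied from the tuning output: (knn_acc, mi, ari, vi),
# keyed by (alpha, reg).
results = {
    (0.1, 0.1): (0.859259, 0.814536, 0.919751, 0.88778),
    (0.1, 10): (0.688889, 0.599336, 0.785401, 0.754764),
    (10, 0.1): (0.833333, 0.897905, 0.959503, 0.939445),
    (10, 10): (0.862963, 0.766445, 0.889026, 0.858675),
}


def combined_score(knn_acc, mi, ari, vi):
    """Assumed formula: mean of knn_acc and the average clustering measure."""
    clustering = (mi + ari + vi) / 3
    return (knn_acc + clustering) / 2


scores = {params: combined_score(*metrics) for params, metrics in results.items()}

# The combination with the highest score is selected.
best_alpha, best_reg = max(scores, key=scores.get)

for (alpha, reg), score in scores.items():
    print(f"alpha={alpha}, reg={reg}: score={score:.6f}")
print("best:", best_alpha, best_reg)
```

Under this assumption, the combination with the highest `knn_acc` (alpha=10, reg=10) is not selected, because its clustering measures are lower than those of alpha=10, reg=0.1.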


Alternatively, we can run `ggml` with the parameters `alpha`, `reg` and `rank_k` set to `[]` to start hyperparameter tuning with the default candidate values.

### Evaluating generalizability

If we have specific hyperparameters in mind, we can assess how well w_theta performs on unseen data when it is trained with them. This checks that the learned metric does not merely memorize the training data but captures general structure in the dataset.

Similar to hyperparameter tuning, generalizability can be evaluated with `ggml`. If an integer is passed for the parameter `n_splits`, w_theta is trained with the given hyperparameters on `n_splits` different splits. The function reports the following measures:

- k-NN classification accuracy (column `knn_acc`)
- Adjusted Rand Index (ARI, column `ari`)
- Normalized Mutual Information (NMI, column `mi`)
- Variation of Information (VI, column `vi`)

To evaluate generalizability, we need a train/test split. When the `n_splits` parameter is specified, the data is automatically partitioned into `n_splits` pairs of training and test sets, which are then used for evaluation. The parameter also determines the split sizes: the test set contains a fraction 1/n_splits of the data and the training set the remaining 1 - 1/n_splits.
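
To make the relation between `n_splits` and the split sizes concrete, here is a minimal sketch of a plain k-fold partition over item indices. It is a generic illustration, not the partitioning code of ggml-ot, which may shuffle or stratify differently; `kfold_indices` is a hypothetical helper.

```python
def kfold_indices(n_items, n_splits):
    """Yield (train, test) index lists; each item is in exactly one test set."""
    if not 2 <= n_splits <= n_items:
        raise ValueError("n_splits must be between 2 and n_items")
    indices = list(range(n_items))
    # Distribute the remainder over the first folds so sizes differ by at most 1.
    base, extra = divmod(n_items, n_splits)
    start = 0
    for fold in range(n_splits):
        size = base + (1 if fold < extra else 0)
        test = indices[start:start + size]
        train = indices[:start] + indices[start + size:]
        yield train, test
        start += size


folds = list(kfold_indices(n_items=20, n_splits=5))
for i, (train, test) in enumerate(folds, start=1):
    print(f"split {i}: train={len(train) / 20:.0%}, test={len(test) / 20:.0%}")
```

With 20 items and five splits, every split holds out 4 items (20%) and trains on the remaining 16 (80%), matching the 1/n_splits rule stated above.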

In the following example, we create five splits, each of which uses 80% of the data for training and 20% for testing. w_theta is trained with the best hyperparameters identified in the tuning step above, but with 500 cells per patient and five iterations.

In [4]:
dataset = ggml_ot.scRNA_Dataset(adata, n_cells=500)
w_theta = ggml_ot.ggml(dataset, alpha=10, reg=0.1, max_iterations=5, n_splits=5, n_threads=32)

keeping 7777 high variable genes
Starting the generalizability evaluation
Running GGML with alpha: 10, reg: 0.1, rank: 5


100%|██████████| 19/19 [05:47<00:00, 18.29s/it]


Iteration 1 with Loss  220.24761962890625


100%|██████████| 19/19 [05:45<00:00, 18.19s/it]


Iteration 2 with Loss  143.08209228515625


100%|██████████| 19/19 [05:41<00:00, 17.96s/it]


Iteration 3 with Loss  57.33650207519531


100%|██████████| 19/19 [05:41<00:00, 17.96s/it]


Iteration 4 with Loss  25.52369499206543


100%|██████████| 19/19 [05:38<00:00, 17.83s/it]


Iteration 5 with Loss  16.912389755249023


| method | knn_acc | mi | ari | vi |
|---|---|---|---|---|
| a=10, l=0.1, k=5 | 1.0 | 1.0 | 1.0 | 1.0 |


Running GGML with alpha: 10, reg: 0.1, rank: 5


100%|██████████| 19/19 [05:38<00:00, 17.83s/it]


Iteration 1 with Loss  218.4571533203125


100%|██████████| 19/19 [05:44<00:00, 18.15s/it]


Iteration 2 with Loss  138.93653869628906


100%|██████████| 19/19 [05:46<00:00, 18.24s/it]


Iteration 3 with Loss  54.832210540771484


100%|██████████| 19/19 [05:47<00:00, 18.28s/it]


Iteration 4 with Loss  22.67696189880371


100%|██████████| 19/19 [05:43<00:00, 18.06s/it]


Iteration 5 with Loss  19.33045768737793


| method | knn_acc | mi | ari | vi |
|---|---|---|---|---|
| a=10, l=0.1, k=5 | 1.0 | 1.0 | 1.0 | 1.0 |


Running GGML with alpha: 10, reg: 0.1, rank: 5


100%|██████████| 19/19 [05:41<00:00, 17.98s/it]


Iteration 1 with Loss  216.9450225830078


100%|██████████| 19/19 [05:45<00:00, 18.16s/it]


Iteration 2 with Loss  142.11526489257812


100%|██████████| 19/19 [05:47<00:00, 18.27s/it]


Iteration 3 with Loss  53.789432525634766


100%|██████████| 19/19 [05:45<00:00, 18.19s/it]


Iteration 4 with Loss  20.282501220703125


100%|██████████| 19/19 [05:39<00:00, 17.87s/it]


Iteration 5 with Loss  15.363859176635742


| method | knn_acc | mi | ari | vi |
|---|---|---|---|---|
| a=10, l=0.1, k=5 | 0.833333 | 0.895945 | 0.954769 | 0.937487 |


Running GGML with alpha: 10, reg: 0.1, rank: 5


100%|██████████| 19/19 [05:40<00:00, 17.93s/it]


Iteration 1 with Loss  220.8211212158203


100%|██████████| 19/19 [05:40<00:00, 17.93s/it]


Iteration 2 with Loss  143.5382843017578


100%|██████████| 19/19 [05:40<00:00, 17.94s/it]


Iteration 3 with Loss  54.272804260253906


100%|██████████| 19/19 [05:39<00:00, 17.89s/it]


Iteration 4 with Loss  19.35325813293457


100%|██████████| 19/19 [05:39<00:00, 17.85s/it]


Iteration 5 with Loss  23.102890014648438


| method | knn_acc | mi | ari | vi |
|---|---|---|---|---|
| a=10, l=0.1, k=5 | 0.833333 | 0.796121 | 0.909348 | 0.872499 |


Running GGML with alpha: 10, reg: 0.1, rank: 5


100%|██████████| 22/22 [06:31<00:00, 17.79s/it]


Iteration 1 with Loss  242.85745239257812


100%|██████████| 22/22 [06:34<00:00, 17.95s/it]


Iteration 2 with Loss  141.36297607421875


100%|██████████| 22/22 [06:38<00:00, 18.11s/it]


Iteration 3 with Loss  46.368350982666016


100%|██████████| 22/22 [06:37<00:00, 18.06s/it]


Iteration 4 with Loss  23.992477416992188


100%|██████████| 22/22 [06:31<00:00, 17.78s/it]


Iteration 5 with Loss  21.154428482055664


| method | knn_acc | mi | ari | vi |
|---|---|---|---|---|
| a=10, l=0.1, k=5 | 0.8 | 0.895945 | 0.954769 | 0.937487 |


| method | knn_acc | mi | ari | vi | score |
|---|---|---|---|---|---|
| a=10, l=0.1, k=5 | 0.893333 | 0.917602 | 0.963777 | 0.949495 | 0.918479 |
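
The final output row appears to be the average over the five splits. The check below, using the per-split values copied from the output above, reproduces it. The `score` formula (mean of `knn_acc` and the average of the three clustering measures) is an assumption inferred from the printed numbers, not a documented part of ggml-ot.

```python
from statistics import mean

# Per-split values copied from the output above: (knn_acc, mi, ari, vi).
splits = [
    (1.0, 1.0, 1.0, 1.0),
    (1.0, 1.0, 1.0, 1.0),
    (0.833333, 0.895945, 0.954769, 0.937487),
    (0.833333, 0.796121, 0.909348, 0.872499),
    (0.8, 0.895945, 0.954769, 0.937487),
]
names = ("knn_acc", "mi", "ari", "vi")

# Average each measure column-wise across the five splits.
summary = {name: mean(column) for name, column in zip(names, zip(*splits))}

# Assumed combined score: mean of knn_acc and the average clustering measure.
clustering = mean(summary[name] for name in ("mi", "ari", "vi"))
summary["score"] = (summary["knn_acc"] + clustering) / 2

for name, value in summary.items():
    print(f"{name}: {value:.6f}")
```

Two of the five splits reach perfect values on all measures, while k-NN accuracy drops to 0.8 on the last split, so the average hides a noticeable spread between splits.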
