Tutorial for basic training on a single tissue section
- **Creator**: Amir Akbarnejad (aa36@sanger.ac.uk)
- **Affiliation**: Wellcome Sanger Institute and University of Cambridge
- **Date of Creation**: 23.06.2025
- **Date of Last Modification**: 23.06.2025

**To run the notebook, you only need to modify the parts marked with `TODO:MODIFY:`. The rest can be left untouched.**

This notebook demonstrates how to train MintFlow on a single tissue section.
It is intended for demonstration only; to obtain biologically meaningful results you may need longer training and/or different hyperparameter settings.

# 0. Download the anndata object

Download the `.h5ad` file from Google Drive (https://drive.google.com/file/d/187Y44hpY5OuwMu0_PA9r9WvycMOx-uz5/view?usp=sharing) and place it in a directory of your choice. Then set the variable `path_anndata` below to the path of the downloaded `.h5ad` file.

In [None]:
path_anndata = './NonGit/data_train_single_section.h5ad'  
# TODO:MODIFY: set to the path where you've put the `.h5ad` file that you downloaded.

In [None]:
import os, sys
import yaml
import mintflow
import pickle
from tqdm.autonotebook import tqdm


import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# 1. Read the default configurations

In this section the four default configurations are read; they are customised later on. You only need to specify:
- `num_tissue_sections_training`: Number of tissue sections to be used for training.
- `num_tissue_sections_evaluation`: Number of tissue sections to be used for evaluation.

The same tissue sections can be used for both training and evaluation, in which case these two numbers are equal.



In [None]:
config_data_train, config_data_evaluation, config_model, config_training = mintflow.get_default_configurations(
    num_tissue_sections_training=1,
    num_tissue_sections_evaluation=1
)

# 2. Customise the 4 configurations
In this section we customise the four configurations returned by `mintflow.get_default_configurations` above.


## 2.1. Customise `config_data_train`

MintFlow requires that each tissue section is saved in a separate anndata file on disk (i.e. one anndata object for each tissue section). 
The `.X` field of each anndata object must contain raw counts stored with an integer data type, **without** row-sum normalisation or log1p transformation.

The `.obs` field of each anndata object must contain:
- A column that specifies the cell type labels.
- A column that specifies a unique tissue section (i.e. slice) identifier. For each anndata object you can add a column to its `.obs` field containing, e.g., the index or barcode that you have assigned to that tissue section.
- A column that specifies the batch identifier used to correct for batch effects (biological, technological, between-patient, etc.).
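Before filling in the configuration, it can help to check a section against the requirements above. The following is a minimal sketch, not part of MintFlow: `check_section_requirements` is a hypothetical helper, and the check that the slice-ID column holds exactly one distinct value is an assumption based on "one anndata object per tissue section".

```python
import numpy as np
import pandas as pd


def check_section_requirements(X, obs, required_obs_columns, obskey_sliceid=None):
    """Hypothetical helper: return a list of problems (empty list = all checks passed).

    X: the count matrix (`adata.X`), a dense numpy array or a scipy sparse matrix.
    obs: the per-cell metadata table (`adata.obs`).
    required_obs_columns: column names that must be present in `obs`.
    obskey_sliceid: optional column expected to hold a single tissue-section ID.
    """
    problems = []

    # For scipy sparse matrices (which have `.tocsr`), only stored entries need checking.
    values = X.data if hasattr(X, "tocsr") else np.asarray(X)

    if not np.issubdtype(values.dtype, np.integer):
        problems.append(f".X has dtype {values.dtype}; raw integer counts are expected.")
    if values.size > 0 and values.min() < 0:
        problems.append(".X contains negative values; raw counts cannot be negative.")

    for col in required_obs_columns:
        if col not in obs.columns:
            problems.append(f"column '{col}' is missing from .obs")

    # Assumption: one anndata file == one tissue section, so the slice-ID
    # column should contain exactly one distinct value.
    if obskey_sliceid is not None and obskey_sliceid in obs.columns:
        n_ids = obs[obskey_sliceid].nunique()
        if n_ids != 1:
            problems.append(f"column '{obskey_sliceid}' has {n_ids} distinct values; expected 1.")

    return problems
```

For the tutorial data this could be called as `check_section_requirements(adata.X, adata.obs, ['broad_celltypes', 'info_id', 'x_centroid', 'y_centroid'], obskey_sliceid='info_id')` after `adata = anndata.read_h5ad(path_anndata)`. Note that a float matrix holding whole-number counts is still flagged, since the requirement above is an integer data type.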

In [None]:
# configure tissue section 1 =========
config_data_train['list_tissue']['anndata1']['file'] = path_anndata
#   path to the anndata object of tissue section 1 on disk.


config_data_train['list_tissue']['anndata1']['obskey_cell_type'] = 'broad_celltypes'
#   meaning that for the 1st tissue section, cell type labels are provided in `broad_celltypes` column of `adata.obs`.


config_data_train['list_tissue']['anndata1']['obskey_sliceid_to_checkUnique'] = 'info_id'
#   meaning that for the 1st tissue section, tissue section ID (i.e. slice ID) is provided in `info_id` column of `adata.obs`


config_data_train['list_tissue']['anndata1']['obskey_x'] = 'x_centroid'
#   meaning that for the 1st tissue section, spatial x coordinates are provided in `x_centroid` column of `adata.obs`


config_data_train['list_tissue']['anndata1']['obskey_y'] = 'y_centroid'
#   meaning that for the 1st tissue section, spatial y coordinates are provided in `y_centroid` column of `adata.obs`



config_data_train['list_tissue']['anndata1']['obskey_biological_batch_key'] = 'info_id'
#   meaning that for the 1st tissue section, batch identifier is provided in `info_id` column of `adata.obs`


config_data_train['list_tissue']['anndata1']['config_dataloader_train']['width_window'] = 700
#   For tissue section 1, the crop size of the customised dataloader described in Supplementary Fig. 16 of the paper.
#   The larger this number, the larger the tissue crops and the more cells are included in each training iteration,
#      which means more GPU memory is required during training.
#   After `mintflow.setup_data` is called in Sec. 4, this notebook shows the crop(s) on the tissue,
#      with information in the image title that can help you tune this parameter.
#   Please refer to our documentation for details on how to tune this hyperparameter.
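
To build intuition for how `width_window` trades crop size against cells per crop, the sketch below tiles the tissue's bounding box into square crops of side `width_window` and counts the cells in each. It is a hypothetical helper for illustration only; MintFlow's dataloader (Supplementary Fig. 16) may place its crops differently, and the coordinates are assumed to be in the same units as the `x_centroid`/`y_centroid` columns.

```python
import numpy as np


def cells_per_crop(x, y, width_window):
    """Hypothetical helper: tile the bounding box into square crops of side
    `width_window` and return the number of cells in each non-empty crop."""
    x = np.asarray(x, dtype=float)
    y = np.asarray(y, dtype=float)

    def edges(v):
        # One more bin than floor(extent / width) guarantees the last edge lies
        # beyond the maximum coordinate, so no cell falls outside the grid.
        n_bins = int((v.max() - v.min()) // width_window) + 1
        return v.min() + width_window * np.arange(n_bins + 1)

    counts, _, _ = np.histogram2d(x, y, bins=[edges(x), edges(y)])
    return counts[counts > 0].astype(int)
```

For example, `counts = cells_per_crop(adata.obs['x_centroid'], adata.obs['y_centroid'], width_window=700)` followed by `counts.max()` gives a rough upper bound on how many cells a single crop of this size might contain.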

## 2.2. Customise `config_data_evaluation`

The tissue sections used for evaluation can be the same as those used for training, in which case the same values as above can be used in the following cell.

Note that in the following cell, `['config_dataloader_test']['width_window']` is set instead of `['config_dataloader_train']['width_window']`.

In [None]:
# configure tissue section 1 =======================
config_data_evaluation['list_tissue']['anndata1']['file'] = path_anndata
#   path to the anndata object of tissue section 1 on disk.


config_data_evaluation['list_tissue']['anndata1']['obskey_cell_type'] = 'broad_celltypes'
#   meaning that for the 1st tissue section, cell type labels are provided in `broad_celltypes` column of `adata.obs`

config_data_evaluation['list_tissue']['anndata1']['obskey_sliceid_to_checkUnique'] = 'info_id'
#   meaning that for the 1st tissue section, tissue section ID (i.e. slice ID) is provided in `info_id` column of `adata.obs`

config_data_evaluation['list_tissue']['anndata1']['obskey_x'] = 'x_centroid'
#   meaning that for the 1st tissue section, spatial x coordinates are provided in `x_centroid` column of `adata.obs`


config_data_evaluation['list_tissue']['anndata1']['obskey_y'] = 'y_centroid'
#   meaning that for the 1st tissue section, spatial y coordinates are provided in `y_centroid` column of `adata.obs`


config_data_evaluation['list_tissue']['anndata1']['obskey_biological_batch_key'] = 'info_id'
#   meaning that for the 1st tissue section, batch identifier is provided in `info_id` column of `adata.obs`

config_data_evaluation['list_tissue']['anndata1']['config_dataloader_test']['width_window'] = 700
#   For tissue section 1, the crop size of the customised dataloader described in Supplementary Fig. 16 of the paper.
#   The larger this number, the larger the tissue crops and the more cells are included in each evaluation iteration,
#      which means more GPU memory is required during evaluation.
#   After `mintflow.setup_data` is called in Sec. 4, this notebook shows the crop(s) on the tissue,
#      with information in the image title that can help you tune this parameter.
#   Please refer to our documentation for details on how to tune this hyperparameter.
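
When the same sections are used for training and evaluation, the per-section settings above could instead be copied from `config_data_train`. The sketch below is a hypothetical helper, not a MintFlow function. It assumes both configs are plain nested dicts with the same section names under `'list_tissue'` (e.g. `'anndata1'`), and that it runs before the verification step in Sec. 3, which may change the config structure. Dataloader settings are deliberately left out, since they live under different keys in the two configs.

```python
# Per-section keys that are identical for training and evaluation in this tutorial.
SHARED_KEYS = (
    "file",
    "obskey_cell_type",
    "obskey_sliceid_to_checkUnique",
    "obskey_x",
    "obskey_y",
    "obskey_biological_batch_key",
)


def copy_shared_tissue_keys(config_src, config_dst, keys=SHARED_KEYS):
    """Hypothetical helper: copy per-section settings from one data config to
    another, in place. 'config_dataloader_train' / 'config_dataloader_test'
    are not copied because the two configs use different keys for them."""
    for name, src_section in config_src["list_tissue"].items():
        dst_section = config_dst["list_tissue"][name]
        for key in keys:
            dst_section[key] = src_section[key]
    return config_dst
```

Usage would be `copy_shared_tissue_keys(config_data_train, config_data_evaluation)`, followed by setting `['config_dataloader_test']['width_window']` separately.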


## 2.3. Customise `config_model`

None of the model configurations need to be tuned, so in this tutorial we leave `config_model` untouched. Please refer to our documentation for the changes you can make to `config_model`.


## 2.4. Customise `config_training` 

In [None]:
config_training['num_training_epochs'] = 20
# number of training epochs, i.e. the number of times the model sees the dataset during training.

config_training['flag_use_GPU'] = 'True'
# whether a GPU is used for training.

config_training['flag_enable_wandb'] = 'True'
# if set to True, during training different loss terms are logged to wandb.
# Enabling wandb is highly recommended. For more information, see the wandb website: `wandb.ai`


config_training['wandb_project_name'] = 'MintFlow'
# wandb project name (ignored if `config_training['flag_enable_wandb']` is set to False)

config_training['wandb_run_name'] = 'Mintflow_Tutorial_June22nd'
# wandb run name (ignored if `config_training['flag_enable_wandb']` is set to False)


# 3. Verify and post-process the four configurations

In this section we verify and post-process the four configurations, e.g. to check them for errors.

In [None]:
config_data_train = mintflow.verify_and_postprocess_config_data_train(config_data_train) 

In [None]:
config_data_evaluation = mintflow.verify_and_postprocess_config_data_evaluation(config_data_evaluation)

In [None]:
config_model = mintflow.verify_and_postprocess_config_model(config_model, num_tissue_sections=len(config_data_train))  

In [None]:
config_training = mintflow.verify_and_postprocess_config_training(config_training) 

# 4. Setup the Data/Model/Trainer
Having created and verified the 4 configurations, in this section we create the variables `data_mintflow`, `model`, and `trainer`.

In [None]:
dict_all4_configs = {
    'config_data_train':config_data_train,
    'config_data_evaluation':config_data_evaluation,
    'config_model':config_model,
    'config_training':config_training
}

In [None]:
data_mintflow = mintflow.setup_data(dict_all4_configs=dict_all4_configs)

In [None]:
model = mintflow.setup_model(
    dict_all4_configs=dict_all4_configs,
    data_mintflow=data_mintflow
)

In [None]:
trainer = mintflow.Trainer(
    dict_all4_configs=dict_all4_configs,
    model=model,
    data_mintflow=data_mintflow
)

# 5. Train the Model

In [None]:
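os.makedirs("./NonGit", exist_ok=True)  # predictions and checkpoints below are written to this directory
list_evaluation_result_knownsignallinggenes = []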
for idx_epoch in tqdm(range(config_training['num_training_epochs'])):
    '''
    To change the number of epochs, please set `config_training['num_training_epochs']` in Sec. 2.4 of this notebook
    and please refrain from changing the for loop to, e.g., `for idx_epoch in tqdm(range(10))`.
    ''' 
    
    # train for one epoch
    trainer.train_one_epoch()

    # get/save the predictions
    predictions = mintflow.predict(
        dict_all4_configs=dict_all4_configs,
        data_mintflow=data_mintflow,
        model=model,
        evalulate_on_sections="all",
    )
    with open("./NonGit/predictions_epoch_{}.pkl".format(idx_epoch), 'wb') as f:
        pickle.dump(
            predictions,
            f
        )

    # save the checkpoint
    mintflow.dump_checkpoint(
        model=model,
        data_mintflow=data_mintflow,
        dict_all4_configs=dict_all4_configs,
        path_dump="./NonGit/checkpoint_epoch_{}.pt".format(idx_epoch),
    )    
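
Since the loop above writes one `predictions_epoch_{idx_epoch}.pkl` file per epoch, a small helper can reload the most recent one for downstream analysis. This is a hypothetical sketch using only the standard library; it makes no assumption about the structure of the `predictions` object itself. Epoch indices are compared numerically because a plain string sort would place `predictions_epoch_10.pkl` before `predictions_epoch_2.pkl`. As with any pickle, unpickling may require `mintflow` to be importable if the predictions contain its objects, and only files you trust should be loaded.

```python
import glob
import os
import pickle
import re


def load_latest_predictions(dir_predictions="./NonGit"):
    """Hypothetical helper: return (idx_epoch, predictions) for the highest
    epoch index found among the `predictions_epoch_*.pkl` files."""
    pattern = re.compile(r"predictions_epoch_(\d+)\.pkl$")
    found = []
    for path in glob.glob(os.path.join(dir_predictions, "predictions_epoch_*.pkl")):
        match = pattern.search(os.path.basename(path))
        if match:
            # Compare epochs as integers, not strings.
            found.append((int(match.group(1)), path))
    if not found:
        raise FileNotFoundError(f"no predictions_epoch_*.pkl files in {dir_predictions}")
    idx_epoch, path = max(found)
    with open(path, "rb") as f:
        return idx_epoch, pickle.load(f)
```

For example, `idx_epoch, predictions = load_latest_predictions()` picks up the predictions saved after the final epoch.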

