## STEP 3. Train spHOT

- Perform sample-level prediction using scMILD, a teacher–student-based Multiple Instance Learning (MIL) framework.
- Leverage attention scores to identify and rank __disease-related instances__.
- Refer to the `configs` directory for detailed configuration parameters.

__MIL architecture__
- __Bag__: Sample
- __Instance__:
    - If `option == "Cell"` → use cell-level instances.
    - If `option == "Domain"` → use domain-level instances.



### Cell-level

- Use all individual cells to construct each bag.

In [None]:
import os
import sys

# Make the project's ``src`` directory importable from this notebook.
sys.path.insert(0, os.path.abspath("../src"))

SAVE_DIR = "../result/"

# Cell-level configuration. ``task_name`` and the "MIL" sub-dict are the two
# entries handed to MILRunner when the runner is built.
cfgs = {
    "task_name": "scMILD_NOVAE_PT",
    "MIL": {
        # Cell-level prediction does not use the 'dt' output, so the .h5ad
        # file is taken from the 'de' directory instead.
        # Original note: if None, the saved 'dt' directory is used.
        "adata_dir": "../result/NOVAE_PT/de/adata.h5ad",
        "experiment": {
            # NOTE(review): presumably a GPU index -- confirm how MILRunner
            # interprets it before running on a single-GPU machine.
            "device": 1,
            "folds": 5,
            "option": "Cell",  # 'Cell' or 'Domain' -level prediction
        },
        # Where the sample split table lives and how its label column maps
        # onto numeric targets.
        "splits": {
            "file_path": "../data/split_info.csv",
            "target_col": "disease_status",
            "sample_col": "SID",
            "target_map": {
                "Disease": 1,
                "Control": 0,
            },
        },
        # Data-loading options, separate for the teacher and student branches.
        "dataset": {
            "num_workers": 4,
            "pin_memory": True,
            "teacher": {
                "batch_size": 8,
                "obsm_key": "X_Novae",
                "sample_col": "sample_to_numeric",
                "target_col": "target_to_numeric",
            },
            "student": {
                "batch_size": 1024,
            },
        },
        # Model dimensions, per-component learning rates and dropout.
        "model_params": {
            "data_dim": 64,
            "mil_latent_dim": 64,
            "teacher_learning_rate": 1e-4,
            "student_learning_rate": 1e-4,
            "encoder_learning_rate": 1e-4,
            "dropout": 0.2,
        },
        # Training-loop options.
        # NOTE(review): the exact meaning of the 'opl*' / 'op_*' entries is
        # defined by the MIL implementation -- confirm there before tuning.
        "train": {
            "n_epochs": 100,
            "stuOptPeriod": 3,
            "stu_loss_weight_neg": 0.3,
            "patience": 15,
            "epoch_warmup": 0,
            "opl_warmup": 0,
            "train_stud": True,
            "opl": True,
            "use_loss": True,
            "op_lambda": 0.5,
            "op_gamma": 0.5,
            "opl_comps": [2],
        },
        "save_dir": SAVE_DIR,
    },
}



In [None]:
from mil.base import MILRunner

# Build the runner from the task name and the "MIL" section of ``cfgs``.
# ``dict.get`` is used, so a missing key is passed through as None.
runner = MILRunner(
    task_name=cfgs.get("task_name"),
    cfgs=cfgs.get("MIL"),
)

In [None]:
# Run the runner built in the previous cell from ``cfgs``.
runner.run()

### Domain-level

- Use domain centroids obtained from __STEP 2. Domain Tree Generation__ to construct each bag.

In [None]:
SAVE_DIR = "../result/"

# Domain-level configuration. ``task_name`` and the "MIL" sub-dict are the two
# entries handed to MILRunner when the runner is built.
cfgs = {
    "task_name": "spHOT_NOVAE_PT",
    "MIL": {
        # Domain-level prediction uses the 'dt' output.
        # Original note: if None, the saved 'dt' directory is used.
        "adata_dir": "../result/NOVAE_PT/dt/adata.h5ad",
        "experiment": {
            # NOTE(review): presumably a GPU index -- confirm how MILRunner
            # interprets it before running on a single-GPU machine.
            "device": 1,
            "folds": 5,
            "option": "Domain",  # 'Cell' or 'Domain' -level prediction
            # Domain-only options.
            # NOTE(review): presumably a sweep over k from min_k to max_k in
            # increments of step -- confirm against the MIL implementation.
            "min_k": 5,
            "max_k": 50,
            "step": 5,
        },
        # Where the sample split table lives and how its label column maps
        # onto numeric targets.
        "splits": {
            "file_path": "../data/split_info.csv",
            "target_col": "disease_status",
            "sample_col": "SID",
            "target_map": {
                "Disease": 1,
                "Control": 0,
            },
        },
        # Data-loading options, separate for the teacher and student branches.
        "dataset": {
            "num_workers": 4,
            "pin_memory": True,
            "teacher": {
                "batch_size": 8,
                "obsm_key": "X_Novae",
                "sample_col": "sample_to_numeric",
                "target_col": "target_to_numeric",
                # Domain-only option.
                # NOTE(review): the trailing underscore suggests this is a
                # column-name prefix -- confirm how it is completed.
                "domain_col": "HC_metadomain_",
            },
            "student": {
                "batch_size": 1024,
                # Domain-only options for the student branch.
                "domain_kwargs": {
                    "with_replacement": True,
                    "alpha": 0.5,
                    "bag_size": 512,
                    "base_seed": 42,
                },
            },
        },
        # Model dimensions, per-component learning rates and dropout.
        "model_params": {
            "data_dim": 64,
            "mil_latent_dim": 64,
            "teacher_learning_rate": 1e-4,
            "student_learning_rate": 1e-4,
            "encoder_learning_rate": 1e-4,
            "dropout": 0.2,
        },
        # Training-loop options.
        # NOTE(review): the exact meaning of the 'opl*' / 'op_*' entries is
        # defined by the MIL implementation -- confirm there before tuning.
        "train": {
            "n_epochs": 100,
            "stuOptPeriod": 3,
            "stu_loss_weight_neg": 0.3,
            "patience": 15,
            "epoch_warmup": 0,
            "opl_warmup": 0,
            "train_stud": True,
            "opl": True,
            "use_loss": True,
            "op_lambda": 0.5,
            "op_gamma": 0.5,
            "opl_comps": [2],
        },
        "save_dir": SAVE_DIR,
    },
}


In [None]:
# Relies on ``MILRunner`` having been imported in the cell-level section
# earlier in this notebook; run that import cell first.
# Rebuild the runner from the domain-level ``cfgs`` and run it.
runner = MILRunner(
    task_name=cfgs.get("task_name"),
    cfgs=cfgs.get("MIL"),
)
runner.run()