In [None]:
# Package setup
import os
import glob
import logging
import scanpy as sc

# Import expert functions
from src.utils.constants import TRAINING_KEYS
from src.models._jedvi import JEDVI
# Import model run functions
from src.tune.run import train, test, full_run

# Configure root logging: INFO level, lines formatted as "<time> - <level> - <message>"
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
# Change to source directory (the parent of the notebook's start-up directory).
# The start-up directory is captured only the first time this cell runs, so
# re-executing the cell does not climb one directory higher on every run.
if '_NOTEBOOK_START_DIR' not in globals():
    _NOTEBOOK_START_DIR = os.getcwd()
os.chdir(os.path.join(_NOTEBOOK_START_DIR, os.pardir))

### I/O

In [None]:
# --- Input locations ---
# Root data directory; its 'train' and 'test' sub-folders hold the .h5ad files
work_dir = 'path/to/data_dir'                 # Replace with path to your data directory
train_dir, test_dir = (os.path.join(work_dir, split) for split in ('train', 'test'))
train_p = os.path.join(train_dir, 'shared_model_set_100.h5ad')   # Select a training adata
test_p = os.path.join(test_dir, 'jurkat.h5ad')                   # Select a testing adata

# --- Output location ---
# Model output goes to <work_dir>/models/small; the folder is created if missing
model_dir = os.path.join(work_dir, 'models', 'small')
os.makedirs(model_dir, exist_ok=True)

# --- Configuration ---
# Path to the .yaml config used by the single-config cells
config_p = '../resources/params/defaults.yaml'

### Model Training

In [None]:
# Config name = file name of the config with '.yaml' removed
config_name = os.path.basename(config_p).replace('.yaml', '')
# Call train() on the training adata with the config file; output goes to model_dir
train_output = train(
    adata_p=train_p,
    config_p=config_p,
    out_dir=model_dir,
    verbose=True,
)
# Output is a dictionary with these keys
TRAINING_KEYS

### Full run for a single config

In [None]:
# Run full_run() once for the single config selected above. Arguments:
#   config_p    - path to the .yaml training config
#   train_p     - path to the .h5ad training adata
#   test_p      - path to the .h5ad testing adata
#   model_dir   - model output directory
#   test_unseen - whether to test zero-shot test classification on perturbations
results = full_run(
    config_p=config_p,
    train_p=train_p,
    test_p=test_p,
    model_dir=model_dir,
    test_unseen=False,
)

### Train multiple configs

In [None]:
# Collect every .yaml config found below config_dir (searched recursively)
config_dir = '../resources/params/runs/test/'
# Build the pattern with os.path.join so the trailing '/' of config_dir does not
# yield a doubled separator in the pattern and in the returned paths.
# glob returns matches in arbitrary order, so sort them to keep the run order
# reproducible across machines and re-runs.
config_ps = sorted(glob.glob(os.path.join(config_dir, '**', '*.yaml'), recursive=True))
config_ps

['../resources/params/runs/test/1.yaml',
 '../resources/params/runs/test/2.yaml']

In [None]:
# Call full_run() once per collected config.
# The loop variable is deliberately not named `config_p`: that name holds the
# single-config path set in the I/O cell, and overwriting it here would make a
# later re-run of the single-config cells silently use the last run config.
for run_config_p in config_ps:
    # A run is named after the folder that contains its config file, so configs
    # that share a folder also share a run directory.
    run_name = os.path.basename(os.path.dirname(run_config_p))
    run_dir = os.path.join(model_dir, run_name)
    full_run(config_p=run_config_p, train_p=train_p, test_p=test_p, model_dir=run_dir, test_unseen=True)

#### Manual model testing

In [None]:
# Manual model testing

# Load best model
# Build the checkpoint folder from the `model_dir` variable rather than the literal
# string 'model_dir', so it points into the directory the runs were written to.
# 'test' is the run name and 'version_18' the logged version to load; replace both
# with the run you want to evaluate.
version_dir = os.path.join(model_dir, 'test', 'lightning_logs', 'version_18')
model = JEDVI.load_checkpoint(
    version_dir,
    adata=sc.read(train_p)
)
# Manually test model with control neighbor filtering
# NOTE(review): return_results=False is passed while the return value is assigned to
# `results` — presumably this stores None; confirm against src.tune.run.test.
results = test(
    model,
    test_adata_p=test_p,
    output_dir=version_dir,
    incl_unseen=False,
    plot=True,
    return_results=False,
    min_ms=0.0,
    control_neighbor_threshold=0.1
)