# Tutorial: Adding and Using a Custom Algorithm in the Auto-Search Framework

In this notebook, we'll walk through how to integrate a new algorithm (specifically an SVM classifier) into the auto-search framework of the `dance` package. We will:

1. Inherit from the **BaseClassificationMethod** (or another suitable base) to define our custom method. Implement the required interfaces (`fit`, `predict`, and optionally `preprocessing_pipeline`).  
2. Show how to run the hyperparameter search using the integrated method.  
3. Provide an example `main.py`-like script that demonstrates how the auto-search process is orchestrated.

## 1. Folder Structure & Requirements

Before diving in, ensure you have the following directory structure (at least conceptually; your actual project structure can be more extensive):

```
examples/tuning/
└── classification_svm/
    ├── main.py
    ├── tutorial.ipynb  
    └── dataset_name/
        ├── pipeline_params_tuning_config.yaml
        └── config_yamls/
            ├── 0_test_acc_params_tuning_config.yaml
            ├── 1_test_acc_params_tuning_config.yaml
            └── 2_test_acc_params_tuning_config.yaml
```

Here, `classification_svm` is the directory we created for our new algorithm. The same pattern applies to other methods, such as `clustering_kmeans` or `regression_linreg`.
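The `dataset_name/` level is not a free-form name. The `main.py` script below builds it from the dataset IDs: IDs within one split are joined with `-`, the splits (train, valid, test) are joined with `_`, and a split left as `None` is skipped. A minimal stand-alone sketch of that naming rule (the helper name `dataset_dir_name` is ours, not part of the framework):

```python
from typing import List, Optional


def dataset_dir_name(train: List[int], valid: Optional[List[int]], test: List[int]) -> str:
    """Mirror the directory-naming logic used for `file_root_path` in main.py."""
    return "_".join(
        "-".join(str(num) for num in dataset)
        for dataset in [train, valid, test]
        if dataset is not None  # a missing split (e.g. no validation set) leaves no trace
    )


# With main.py's defaults (train=[328], valid=None, test=[138]):
print(dataset_dir_name([328], None, [138]))        # 328_138
print(dataset_dir_name([328, 329], [400], [138]))  # 328-329_400_138
```

So with the default arguments, `main.py` looks for its tuning config under `classification_svm/328_138/`.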

We'll focus on the **SVM** example below.

---

## 2. Defining Our SVM Classifier

Suppose we want to define a custom SVM method for cell-type classification.
We inherit from `BaseClassificationMethod` and implement the required methods; the model itself wraps scikit-learn's `SVC`.

```python
from typing import Optional

import numpy as np
from sklearn.svm import SVC

from dance.modules.base import BaseClassificationMethod
from dance.transforms.cell_feature import WeightedFeaturePCA
from dance.transforms.misc import Compose, SetConfig
from dance.typing import LogLevel


class SVM(BaseClassificationMethod):
    """The SVM cell-type classification model.

    Parameters
    ----------
    args : argparse.Namespace
        A Namespace containing the SVM arguments. See the parser help document for more info.
    prj_path : str
        Project path (not used by this class).
    random_state : Optional[int]
        Random seed passed to the underlying ``SVC``.

    """

    def __init__(self, args, prj_path="./", random_state: Optional[int] = None):
        self.args = args
        self.random_state = random_state
        # probability=True enables probability estimates, at the cost of slower training
        self._mdl = SVC(random_state=random_state, probability=True)

    @staticmethod
    def preprocessing_pipeline(n_components: int = 400, log_level: LogLevel = "INFO"):
        """Default preprocessing: WeightedFeaturePCA features, then record the feature and label channels."""
        return Compose(
            WeightedFeaturePCA(n_components=n_components, split_name="train"),
            SetConfig({
                "feature_channel": "WeightedFeaturePCA",
                "label_channel": "cell_type"
            }),
            log_level=log_level,
        )

    def fit(self, x: np.ndarray, y: np.ndarray):
        """Train the classifier.

        Parameters
        ----------
        x
            Training cell features.
        y
            Training labels (class indices).

        """
        self._mdl.fit(x, y)

    def predict(self, x: np.ndarray):
        """Predict cell labels.

        Parameters
        ----------
        x
            Samples to be predicted (samples x features).

        Returns
        -------
        y
            Predicted labels of the input samples.

        """
        return self._mdl.predict(x)
```
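The `SVM` class defines only `fit` and `predict`, yet `main.py` later calls `model.score(...)`; that method is inherited from the base class. The sketch below illustrates this template-method pattern with a hypothetical stand-in base class (`ToyClassificationBase`, not DANCE's actual implementation) and a trivial majority-class model. Its `score` also accepts one-hot labels, on the assumption (suggested by the `y_train.argmax(1)` conversion in `main.py`) that the dataset returns labels one-hot encoded.

```python
from abc import ABC, abstractmethod
from collections import Counter
from typing import List, Sequence, Union

Label = int
OneHot = Sequence[float]


def to_index(y: Sequence[Union[Label, OneHot]]) -> List[Label]:
    """Convert one-hot rows to class indices (like `argmax(1)`); plain indices pass through."""
    out = []
    for row in y:
        if isinstance(row, (list, tuple)):
            out.append(max(range(len(row)), key=lambda i: row[i]))  # first maximum, as argmax
        else:
            out.append(row)
    return out


class ToyClassificationBase(ABC):
    """Hypothetical stand-in for a classification base class."""

    @abstractmethod
    def fit(self, x, y): ...

    @abstractmethod
    def predict(self, x) -> List[Label]: ...

    def score(self, x, y) -> float:
        """Accuracy of `predict(x)` against `y`; subclasses get this for free."""
        truth = to_index(y)
        pred = self.predict(x)
        return sum(t == p for t, p in zip(truth, pred)) / len(truth)


class MajorityClassifier(ToyClassificationBase):
    """Always predicts the most frequent training label."""

    def fit(self, x, y):
        self.label_ = Counter(to_index(y)).most_common(1)[0][0]
        return self

    def predict(self, x):
        return [self.label_] * len(x)


x = [[0.1], [0.2], [0.3], [0.4]]
y_onehot = [[1, 0], [1, 0], [1, 0], [0, 1]]
model = MajorityClassifier().fit(x, to_index(y_onehot))  # fit on indices, as main.py does
print(model.predict(x))          # [0, 0, 0, 0]
print(model.score(x, y_onehot))  # 0.75
```

The subclass never touches scoring logic, which is why plugging in a new model only requires the two abstract methods.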


## 3. Example `main.py` File

Below is an example of how your `main.py` might look if you're adding SVM as one of the classification methods. This file orchestrates the entire pipeline:

1. **Register** custom preprocessing functions with a decorator (optional).
2. **Parse** arguments and configure hyperparameters.  
3. **Define** an evaluation function that:  
   - Loads and preprocesses the data.  
   - Initializes your model (the new SVM class).  
   - Trains and scores the model.  
   - Logs results to Weights & Biases (wandb).  
4. **Run** the hyperparameter sweep agent (e.g., via `wandb_sweep_agent`).  
5. **Save** results and optionally generate a second-stage tuning config file.

> **Note**: For demonstration, only relevant code is shown. Adjust as needed for your exact pipeline or data.
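Steps 3 and 4 follow a callback pattern: you hand `evaluate_pipeline` to the sweep agent, and the agent calls it once per sampled configuration, at most `count` times. With wandb, the configuration arrives through `wandb.config` and the metric is reported with `wandb.log`. The sketch below imitates that control flow locally without wandb; `toy_grid_agent` and the toy metric are hypothetical, and the real agent coordinates with a wandb sweep server (whose grid order may differ).

```python
import itertools
from typing import Callable, Dict, List, Optional, Tuple


def toy_grid_agent(
    space: Dict[str, List],
    evaluate: Callable[[Dict], float],
    count: Optional[int] = None,
    goal: str = "maximize",
) -> Tuple[Optional[Dict], Optional[float]]:
    """Call `evaluate` on grid points in order (at most `count` of them); return the best."""
    keys = list(space)
    grid = (dict(zip(keys, values)) for values in itertools.product(*space.values()))
    best_cfg, best_score = None, None
    for config in itertools.islice(grid, count):  # count=None runs the full grid
        score = evaluate(config)  # stands in for evaluate_pipeline + wandb.log
        if best_score is None or (score > best_score if goal == "maximize" else score < best_score):
            best_cfg, best_score = config, score
    return best_cfg, best_score


# Hypothetical search space and metric, for illustration only.
space = {"normalize": ["Log1P", "ScaleFeature"], "n_components": [100, 400]}


def fake_acc(config: Dict) -> float:
    base = 0.75 if config["normalize"] == "Log1P" else 0.5
    return base + (0.125 if config["n_components"] == 400 else 0.0)


print(toy_grid_agent(space, fake_acc))           # ({'normalize': 'Log1P', 'n_components': 400}, 0.875)
print(toy_grid_agent(space, fake_acc, count=1))  # ({'normalize': 'Log1P', 'n_components': 100}, 0.75)
```

The second call shows why `count` matters: with too small a budget, the agent never reaches the better configuration.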

```python
# Step 1: custom preprocessing functions can be registered with `register_preprocessor`.
# Here, GaussRandProjFeature is registered under the "feature.cell" step type,
# so the tuning config file can refer to it by name.
from sklearn.random_projection import GaussianRandomProjection

from dance.registry import register_preprocessor
from dance.transforms.base import BaseTransform


@register_preprocessor("feature", "cell", overwrite=True)  # NOTE: register any custom preprocessing function used for tuning
class GaussRandProjFeature(BaseTransform):
    """Custom preprocessing to extract cell features via Gaussian random projection."""

    _DISPLAY_ATTRS = ("n_components", "eps")

    def __init__(self, n_components: int = 400, eps: float = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.n_components = n_components
        self.eps = eps

    def __call__(self, data):
        feat = data.get_feature(return_type="numpy")
        grp = GaussianRandomProjection(n_components=self.n_components, eps=self.eps)

        self.logger.info(f"Start generating cell feature via Gaussian random projection (d={self.n_components}).")
        # Store the projected features under the output channel (e.g. `feature.cell` in the YAML config)
        data.data.obsm[self.out] = grp.fit_transform(feat)

        return data
```
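The `@register_preprocessor("feature", "cell", overwrite=True)` decorator adds the class to a lookup table, which is what lets the tuning config refer to `GaussRandProjFeature` by name under `type: feature.cell`. The `overwrite=True` flag matters in a notebook, where re-running the cell registers the same name again. A minimal sketch of such a decorator-based registry, assuming a flat dict keyed by step type and class name; `TOY_REGISTRY`, `toy_register`, and `lookup` are hypothetical and do not reproduce the internals of `dance.registry`.

```python
from typing import Callable, Dict, Tuple, Type

# Hypothetical registry: (step type parts..., class name) -> class
TOY_REGISTRY: Dict[Tuple[str, ...], Type] = {}


def toy_register(*scope: str, overwrite: bool = False) -> Callable[[Type], Type]:
    """Return a class decorator that records the class under `scope` plus its class name."""

    def decorator(cls: Type) -> Type:
        key = (*scope, cls.__name__)
        if key in TOY_REGISTRY and not overwrite:
            raise KeyError(f"{'.'.join(key)} is already registered")
        TOY_REGISTRY[key] = cls
        return cls  # the decorated class itself is returned unchanged

    return decorator


def lookup(step_type: str, name: str) -> Type:
    """Resolve a YAML entry such as `type: feature.cell` + `GaussRandProjFeature`."""
    return TOY_REGISTRY[(*step_type.split("."), name)]


@toy_register("feature", "cell")
class GaussRandProjFeature:
    pass


# Re-running a notebook cell defines the class again; this only succeeds because of overwrite=True.
@toy_register("feature", "cell", overwrite=True)
class GaussRandProjFeature:  # noqa: F811
    pass


print(lookup("feature.cell", "GaussRandProjFeature") is GaussRandProjFeature)  # True
```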


```python
# Example main.py

import argparse
import os
from pathlib import Path
from typing import get_args

import wandb

from dance import logger
from dance.datasets.singlemodality import CellTypeAnnotationDataset  # your dataset
from dance.pipeline import PipelinePlaner, get_step3_yaml, run_step3, save_summary_data
from dance.typing import LogLevel
from dance.utils import set_seed

# Directory containing this script; in a notebook (where __file__ is undefined), use the working directory
root_path = str(Path(__file__).resolve().parent) if "__file__" in globals() else str(Path.cwd())

# In this notebook, the SVM class is defined in an earlier cell.
# In a standalone script, import it instead, e.g.:
# from your_svm_file import SVM


def main(args=None):
    # Step 2: parse arguments and configure hyperparameters
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--cache", action="store_true", help="Cache processed data.")
    parser.add_argument("--dense_dim", type=int, default=400, help="dim of PCA")
    parser.add_argument("--gpu", type=int, default=0, help="GPU id, set to -1 for CPU")
    parser.add_argument("--log_level", type=str, default="INFO", choices=get_args(LogLevel))
    parser.add_argument("--species", default="human")
    parser.add_argument("--test_dataset", nargs="+", default=[138], type=int, help="list of dataset id")
    parser.add_argument("--tissue", default="Brain")  # TODO: Add option for different tissue name for train/test
    parser.add_argument("--train_dataset", nargs="+", default=[328], type=int, help="list of dataset id")
    parser.add_argument("--valid_dataset", nargs="+", default=None, type=int, help="list of dataset id")
    parser.add_argument("--tune_mode", default="pipeline_params", choices=["pipeline", "params", "pipeline_params"])
    parser.add_argument("--seed", type=int, default=10)
    parser.add_argument("--count", type=int, default=2)
    parser.add_argument("--sweep_id", type=str, default=None)
    parser.add_argument("--summary_file_path", default="results/pipeline/best_test_acc.csv", type=str)
    parser.add_argument("--root_path", default=root_path, type=str)
    args = parser.parse_args(args)  # args=None reads sys.argv; a list (e.g. []) is parsed instead

    # Construct the directory holding the tuning configs, e.g. <root_path>/328_138
    file_root_path = Path(
        args.root_path, "_".join([
            "-".join([str(num) for num in dataset])
            for dataset in [args.train_dataset, args.valid_dataset, args.test_dataset] if dataset is not None
        ])).resolve()
    logger.info(f"Files are saved in {file_root_path}")

    # Instantiate the pipeline planer from the config file matching the tune mode
    pipeline_planer = PipelinePlaner.from_config_file(f"{file_root_path}/{args.tune_mode}_tuning_config.yaml")
    os.environ["WANDB_AGENT_MAX_INITIAL_FAILURES"] = "2000"

    # Step 3: define the evaluation function
    def evaluate_pipeline(tune_mode=args.tune_mode, pipeline_planer=pipeline_planer):
        """Evaluation function called by ``wandb_sweep_agent`` once per sweep run.

        It:
        1. Loads the data.
        2. Applies the preprocessing pipeline sampled for this run.
        3. Trains the model.
        4. Scores it on the train, validation, and test sets.
        5. Logs the metrics to wandb.
        """
        wandb.init(settings=wandb.Settings(start_method="thread"))
        set_seed(args.seed)

        # Load dataset
        data = CellTypeAnnotationDataset(train_dataset=args.train_dataset, test_dataset=args.test_dataset,
                                         valid_dataset=args.valid_dataset, species=args.species, tissue=args.tissue,
                                         data_dir="../temp_data").load_data()

        # Build and apply the preprocessing pipeline from this run's sweep config
        kwargs = {tune_mode: dict(wandb.config)}
        preprocessing_pipeline = pipeline_planer.generate(**kwargs)
        preprocessing_pipeline(data)

        # Retrieve training / validation / testing data
        x_train, y_train = data.get_train_data()
        y_train_converted = y_train.argmax(1)  # one-hot labels -> class indices for SVC
        x_valid, y_valid = data.get_val_data()
        x_test, y_test = data.get_test_data()

        # Initialize our custom SVM model and train it
        model = SVM(args, random_state=args.seed)
        model.fit(x_train, y_train_converted)

        # Evaluate the model
        train_score = model.score(x_train, y_train)
        score = model.score(x_valid, y_valid)
        test_score = model.score(x_test, y_test)

        # Log results to wandb; "acc" (validation accuracy) is the metric the sweep optimizes
        wandb.log({"train_acc": train_score, "acc": score, "test_acc": test_score})
        wandb.finish()

    # Step 4: run the sweep
    entity, project, sweep_id = pipeline_planer.wandb_sweep_agent(
        evaluate_pipeline, sweep_id=args.sweep_id, count=args.count)

    # Step 5: save summary data (top results, etc.)
    save_summary_data(entity, project, sweep_id, summary_file_path=args.summary_file_path, root_path=file_root_path)

    # Optionally, prepare and run the second-stage parameter search
    if args.tune_mode in ("pipeline", "pipeline_params"):
        get_step3_yaml(result_load_path=f"{args.summary_file_path}", step2_pipeline_planer=pipeline_planer,
                       conf_load_path=f"{Path(args.root_path).resolve().parent}/step3_default_params.yaml",
                       root_path=file_root_path)
        if args.tune_mode == "pipeline_params":
            run_step3(file_root_path, evaluate_pipeline, tune_mode="params", step2_pipeline_planer=pipeline_planer)


if __name__ == "__main__":
    # If wandb is reachable only through a proxy, configure it here, e.g.:
    # os.environ["https_proxy"] = "http://<proxy-host>:<port>"
    # In a notebook, pass [] so Jupyter's own argv is not parsed; as a script, read the command line
    main(None if "__file__" in globals() else [])
```
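Two `argparse` details in `main.py` are easy to trip over. First, `nargs="+"` combined with `type=int` turns `--train_dataset 328 329` into the list `[328, 329]`, which is why every dataset option is a list. Second, `parse_args([])` ignores the real command line and uses only the defaults (handy inside Jupyter, whose kernel passes its own arguments), while `parse_args(None)` reads `sys.argv`. A stand-alone check with a trimmed-down parser containing three options copied from `main.py`:

```python
import argparse

# A trimmed copy of three options from main.py's parser.
parser = argparse.ArgumentParser()
parser.add_argument("--train_dataset", nargs="+", default=[328], type=int)
parser.add_argument("--valid_dataset", nargs="+", default=None, type=int)
parser.add_argument("--tune_mode", default="pipeline_params", choices=["pipeline", "params", "pipeline_params"])

defaults = parser.parse_args([])  # empty list: command line ignored, defaults used
custom = parser.parse_args(["--train_dataset", "328", "329", "--tune_mode", "params"])

print(defaults.train_dataset, defaults.valid_dataset, defaults.tune_mode)  # [328] None pipeline_params
print(custom.train_dataset, custom.tune_mode)                              # [328, 329] params
```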


## 4. Configuration Files

The **configuration files** (e.g., `pipeline_params_tuning_config.yaml`, `pipeline_tuning_config.yaml`, `params_tuning_config.yaml`) guide the auto-search. Each file specifies how to vary the preprocessing pipeline, the model hyperparameters, or both. Because `main.py` loads `<file_root_path>/<tune_mode>_tuning_config.yaml`, the file name must match the `--tune_mode` you run with. For example:

```yaml
# pipeline_params_tuning_config.yaml
type: preprocessor
tune_mode: pipeline_params
pipeline_tuning_top_k: 2
parameter_tuning_freq_n: 2
pipeline:
  - type: filter.gene
    include:
      - FilterGenesPercentile
      - FilterGenesScanpyOrder
      - FilterGenesPlaceHolder
    default_params:
      FilterGenesScanpyOrder:
        order: ["min_counts", "min_cells", "max_counts", "max_cells"]
        min_counts: 1
        max_counts: 134732
        min_cells: 1
        max_cells: 401
  - type: normalize
    include:
      - ScaleFeature
      - ScTransform
      - Log1P
      - NormalizeTotal
      - NormalizePlaceHolder
    default_params:
      ScTransform:
        processes_num: 8
  - type: filter.gene
    include:
      # - HighlyVariableGenesLogarithmizedByMeanAndDisp
      - HighlyVariableGenesRawCount
      - HighlyVariableGenesLogarithmizedByTopGenes
      - FilterGenesTopK
      - FilterGenesRegression
      # - FilterGenesNumberPlaceHolder
    default_params:
      FilterGenesTopK:
        num_genes: 100
      FilterGenesRegression:
        num_genes: 100
      HighlyVariableGenesRawCount:
        n_top_genes: 100
      HighlyVariableGenesLogarithmizedByTopGenes:
        n_top_genes: 100
  - type: feature.cell
    include:
      - WeightedFeaturePCA
      - WeightedFeatureSVD
      - CellPCA
      - CellSVD
      - GaussRandProjFeature  # Registered custom preprocessing func
      - FeatureCellPlaceHolder
    params:
      out: feature.cell
      log_level: INFO
  - type: misc
    target: SetConfig
    params:
      config_dict:
        feature_channel: feature.cell
        label_channel: cell_type
wandb:
  entity: xzy11632  # replace with your own wandb entity
  project: dance-dev
  method: grid  # grid search enumerates every combination
  metric:
    name: acc  # validation accuracy logged by evaluate_pipeline
    goal: maximize
```
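With `method: grid`, the size of the pipeline search space can be estimated as the product of the `include` list lengths, assuming each step's `include` list becomes one categorical sweep dimension (a fixed `target` step such as `SetConfig` contributes a factor of 1, and commented-out entries do not count). For the config above that gives 3 × 5 × 4 × 6 = 360 candidate pipelines, while `--count` defaults to 2, so a default run samples only a small slice of the grid. A quick check of the arithmetic:

```python
from math import prod

# Number of active (non-commented) entries in each `include` list of the YAML above.
include_sizes = {
    "filter.gene (1st)": 3,  # FilterGenesPercentile, FilterGenesScanpyOrder, FilterGenesPlaceHolder
    "normalize": 5,          # ScaleFeature, ScTransform, Log1P, NormalizeTotal, NormalizePlaceHolder
    "filter.gene (2nd)": 4,  # HighlyVariableGenesRawCount, ...ByTopGenes, FilterGenesTopK, FilterGenesRegression
    "feature.cell": 6,       # WeightedFeaturePCA, WeightedFeatureSVD, CellPCA, CellSVD, GaussRandProjFeature, placeholder
    "misc (SetConfig)": 1,   # fixed `target`, no choice
}

n_pipelines = prod(include_sizes.values())
print(n_pipelines)  # 360
```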

**Tips**:

1. In `tune_mode=pipeline`, the system will only tune the preprocessing pipeline.  
2. In `tune_mode=params`, the system will only tune the model parameters.  
3. In `tune_mode=pipeline_params`, the system will do a two-stage search: first for pipelines, then for model parameters.
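The two-stage `pipeline_params` mode can be pictured as follows: stage 1 scores candidate pipelines, only the best few survive, and stage 2 tunes model parameters for each survivor (in `main.py`, the stage-2 configs come from `get_step3_yaml` and are run by `run_step3`). The sketch below is a deliberately simplified, hypothetical model of that flow with made-up pipelines and scores; it assumes `pipeline_tuning_top_k` sets how many pipelines survive, and it does not model `parameter_tuning_freq_n` or the framework's actual selection logic.

```python
from typing import Callable, Dict, List, Tuple


def two_stage_search(
    pipelines: List[str],
    param_grid: List[Dict],
    evaluate: Callable[[str, Dict], float],
    top_k: int,
) -> Tuple[str, Dict, float]:
    """Stage 1: rank pipelines with default params ({}). Stage 2: tune params for the top_k only."""
    survivors = sorted(pipelines, key=lambda p: evaluate(p, {}), reverse=True)[:top_k]
    candidates = ((p, params, evaluate(p, params)) for p in survivors for params in param_grid)
    return max(candidates, key=lambda triple: triple[2])


# Hypothetical pipelines and scores, for illustration only.
BASE = {"Log1P+CellPCA": 0.75, "ScaleFeature+CellSVD": 0.5, "NormalizeTotal+GaussRandProjFeature": 0.625}


def fake_eval(pipeline: str, params: Dict) -> float:
    # In this made-up example, n_components=400 only helps the random-projection pipeline
    helps = pipeline.endswith("GaussRandProjFeature") and params.get("n_components") == 400
    return BASE[pipeline] + (0.25 if helps else 0.0)


grid = [{"n_components": 100}, {"n_components": 400}]
print(two_stage_search(list(BASE), grid, fake_eval, top_k=2))  # ('NormalizeTotal+GaussRandProjFeature', {'n_components': 400}, 0.875)
print(two_stage_search(list(BASE), grid, fake_eval, top_k=1))  # ('Log1P+CellPCA', {'n_components': 100}, 0.75)
```

With `top_k=1`, the pipeline that would win after parameter tuning is discarded in stage 1; keeping more than one pipeline trades extra runs for a lower chance of that happening.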

---

## 5. Testing & Execution

After setting everything up:

```bash
# Search only the best preprocessing pipeline:
python main.py --tune_mode pipeline

# Search only the best model hyperparameters:
python main.py --tune_mode params

# Joint two-stage search for both pipeline and parameters:
python main.py --tune_mode pipeline_params
```

Once the runs finish, the results are logged to Weights & Biases (wandb), and `save_summary_data` writes a CSV summarizing the top-performing runs to `--summary_file_path`. In `pipeline` and `pipeline_params` modes, `get_step3_yaml` additionally writes second-stage parameter-tuning configs based on those results; in `pipeline_params` mode, that second stage is run automatically via `run_step3`.

## 6. Summary

To summarize:

1. Inherit from the appropriate base class (in our case, `BaseClassificationMethod`).
2. Implement the `fit`, `predict`, and (optionally) `preprocessing_pipeline` methods.
3. Integrate your custom model into the `main.py` script.
4. Create and reference the necessary configuration (YAML) files.
5. Run the pipeline with `--tune_mode` set to `pipeline`, `params`, or `pipeline_params`.

Following these steps, you can plug any custom algorithm into this auto-search framework, from simple classifiers like an SVM to deep learning methods with pretraining steps.

Happy coding and good luck with your hyperparameter searches!

