This repository provides code to train and run CARE, an asymmetric retriever designed for Chinese medical retrieval tasks. The system pairs a compact query encoder with a larger document encoder.
- Python 3.8+
- Packages: `torch`, `transformers`, `huggingface_hub`, `numpy`, `tqdm`, `FlagEmbedding`:

```shell
pip install -r requirements.txt
```
`scripts/stage1-query-alignment.sh` is the training script for stage I. Set `ROOT` and the model paths (query encoder and document encoder separately). `FIX_DOC_ENCODER` should be set to `True` at this stage.

`scripts/stage2-joint-finetuning.sh` is the training script for stage II. The query-encoder checkpoint comes from stage I, while the document-encoder path is the same as in stage I.
Both scripts require setting `ROOT`, `MODEL_QUERY`, `MODEL_DOC`, `TRAIN_DATA`, `SAVE_DIR`, and the cluster environment variables (`WORLD_SIZE`, `RANK`, `MASTER_ADDR`). The training entrypoint is `train/main.py`.
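The variable names below are the ones listed above; every path value is a placeholder for your environment, and the single-node cluster defaults are an assumption, not settings shipped with the repository — a minimal configuration sketch:

```shell
# Placeholder paths -- adjust to your environment.
export ROOT=/path/to/project
export MODEL_QUERY=/path/to/query/encoder    # compact query encoder
export MODEL_DOC=/path/to/document/encoder   # larger document encoder
export TRAIN_DATA=$ROOT/data/stage1
export SAVE_DIR=$ROOT/checkpoints/stage1

# Single-node defaults for the cluster variables.
export WORLD_SIZE=1
export RANK=0
export MASTER_ADDR=localhost
```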
```shell
export ROOT=/path/to/project
cd $ROOT/Medical-Asymmetric-Retriever/train
bash ../scripts/stage1-query-alignment.sh
bash ../scripts/stage2-joint-finetuning.sh
```

Use `inference/asymmetric.py::CARE` to load the encoders and compute embeddings. Example:
```python
from inference.asymmetric import CARE
import numpy as np

model_name_or_path_query = "path/to/query/encoder"
model_name_or_path_doc = "path/to/document/encoder"

care = CARE(
    model_name_or_path_query=model_name_or_path_query,
    model_name_or_path_doc=model_name_or_path_doc,
    trust_remote_code=True,
    use_fp16=False,
    normalize_embeddings=True,
    query_batch_size=2,
    passage_batch_size=2,
)

queries = [
    "什么是高血压?"  # "What is hypertension?"
]
corpus = [
    # "Hypertension is a sustained elevation of arterial blood pressure,
    #  usually systolic >=140 mmHg and/or diastolic >=90 mmHg."
    "高血压是指动脉血压持续升高,通常指收缩压≥140mmHg和/或舒张压≥90mmHg。"
]

query_embeddings = care.encode_queries(queries, task_name='retrieval')
print("Query Embeddings:", query_embeddings, query_embeddings.shape)

corpus_embeddings = care.encode_corpus(corpus, task_name='retrieval')
print("Corpus Embeddings:", corpus_embeddings, corpus_embeddings.shape)

scores = np.dot(query_embeddings, corpus_embeddings.T)
print("Similarity Scores:", scores)
```
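Given a query–corpus score matrix like the one printed above, documents can be ranked per query with `np.argsort`. A standalone sketch, using random unit vectors in place of CARE outputs:

```python
import numpy as np

# Stand-in embeddings in place of CARE outputs (2 queries, 5 documents, dim 8).
rng = np.random.default_rng(0)
query_embeddings = rng.normal(size=(2, 8))
corpus_embeddings = rng.normal(size=(5, 8))

# Normalize, mirroring normalize_embeddings=True, so dot products are cosine scores.
query_embeddings /= np.linalg.norm(query_embeddings, axis=1, keepdims=True)
corpus_embeddings /= np.linalg.norm(corpus_embeddings, axis=1, keepdims=True)

scores = query_embeddings @ corpus_embeddings.T  # shape (2, 5)

# For each query, rank documents by descending score and keep the top-3 indices.
top_k = 3
ranking = np.argsort(-scores, axis=1)[:, :top_k]
print(ranking)  # document indices per query, best first
```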
- `inference/asymmetric.py`: `CARE` inference wrapper that loads separate query and document tokenizers and encoders.
- `scripts/stage1-query-alignment.sh`: Example shell script for stage-1 query alignment training.
- `scripts/stage2-joint-finetuning.sh`: Example shell script for stage-2 joint finetuning.
- `train/arguments.py`: Defines dataclasses for model, data, and training arguments.
- `train/dataset.py`: Dataset classes and collators used during training.
- `train/load_model.py`: Utilities to construct/load the document encoder.
- `train/main.py`: Training entrypoint. Parses arguments into dataclasses, initializes `AsymmetricEmbedderRunner`, and starts training.
- `train/modeling.py`: Contains `AsymmetricEmbedderModel`, which wraps the query and document encoders.
- `train/runner.py`: Runner that wires together tokenizers, base models, the `AsymmetricEmbedderModel`, and data collators.
- `train/trainer.py`: Custom trainer.
- `data/`: Sampled training data examples for stage-1 and stage-2.
  - `stage1/`: Sampled data for stage-1. `stage1-query-align-q.jsonl` contains query-side triples; `stage1-query-align-doc.jsonl` contains document-side triples.
  - `stage2/`: Sampled data for stage-2. `stage2-medteb-retrieval.jsonl` is sampled from the MedTEB retrieval train split; `stage2-medteb-sts.jsonl` is sampled from the MedTEB STS train split.
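The exact JSONL schema of the triple files is not spelled out above. The sketch below assumes the common FlagEmbedding-style layout with `query`, `pos`, and `neg` fields — the field names and the sample record are assumptions, so adjust them to what the files actually contain:

```python
import io
import json

# Hypothetical record in the assumed FlagEmbedding-style triple layout.
sample = {
    "query": "什么是高血压?",               # "What is hypertension?"
    "pos": ["高血压是指动脉血压持续升高。"],  # one relevant passage
    "neg": ["感冒是一种常见的呼吸道疾病。"],  # one irrelevant passage
}

# In practice: open("data/stage1/stage1-query-align-q.jsonl");
# an in-memory buffer keeps this sketch standalone.
buffer = io.StringIO(json.dumps(sample, ensure_ascii=False) + "\n")

triples = []
for line in buffer:
    record = json.loads(line)
    for pos in record["pos"]:
        for neg in record["neg"]:
            triples.append((record["query"], pos, neg))

print(len(triples))  # one (query, positive, negative) triple
```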
Use the scripts in `scripts/` as templates to reproduce the experimental settings. Ensure that paths and cluster environment variables are set correctly.
Model checkpoints: https://huggingface.co/PhilipGAQ/CARE-0.3B-4B
License: CC-BY-NC-SA-4.0.
```bibtex
@misc{jiang2026cmedtebcarebenchmarking,
  title={CMedTEB & CARE: Benchmarking and Enabling Efficient Chinese Medical Retrieval via Asymmetric Encoders},
  author={Angqing Jiang and Jianlyu Chen and Zhe Fang and Yongcan Wang and Xinpeng Li and Keyu Ding and Defu Lian},
  year={2026},
  eprint={2604.10937},
  archivePrefix={arXiv},
  primaryClass={cs.IR},
  url={https://arxiv.org/abs/2604.10937},
}
```