Official repository of Feature Re-Embedding: Towards Foundation Model-Level Performance in Computational Pathology (CVPR 2024). [arXiv]
- Merge training scripts for the full task
- Improve the README
- Upload the Dockerfile
- Upload the survival prediction code
- Uploaded interim survival prediction code; training scripts for the full task will be merged later.
- Uploaded almost all code, the Docker image, and the datasets.
We used CLAM to preprocess WSIs. The PLIP model and weights can be found here.
Download the preprocessed patch features: Baidu Cloud.
Use the following CLAM presets:
- --preset bwh_biopsy.csv for Camelyon16 (the preset parameters officially provided by CLAM)
- --preset preprocess_tcga_nsclc.csv for TCGA-NSCLC (customized parameters)
- --preset tcga.csv for TCGA-BRCA (the preset parameters officially provided by CLAM)
# for Camelyon16
python create_patches_fp.py --source DATA_DIRECTORY --save_dir RESULTS_DIRECTORY --patch_size 512 \
--step_size 512 --preset bwh_biopsy.csv --seg --patch
# for TCGA-NSCLC
python create_patches_fp.py --source DATA_DIRECTORY --save_dir RESULTS_DIRECTORY --patch_size 512 \
--step_size 512 --preset preprocess_tcga_nsclc.csv --seg --patch
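The dataset-to-preset mapping above can also be scripted. The sketch below is a hypothetical helper (build_clam_command and the dataset keys are illustrative names, not part of this repository or CLAM) that assembles the create_patches_fp.py argument list with the 512-pixel patch and step sizes used in the commands above. Using 512 for TCGA-BRCA is an assumption, since no BRCA command is shown.

```python
import shlex

# Preset CSV per dataset, as listed above (dataset keys are illustrative).
CLAM_PRESETS = {
    "camelyon16": "bwh_biopsy.csv",             # official CLAM preset
    "tcga_nsclc": "preprocess_tcga_nsclc.csv",  # customized preset
    "tcga_brca": "tcga.csv",                    # official CLAM preset
}


def build_clam_command(dataset: str, source: str, save_dir: str,
                       patch_size: int = 512) -> list[str]:
    """Return the argv for CLAM's create_patches_fp.py (hypothetical helper).

    Patch and step size default to 512 as in the commands above; 512 for
    TCGA-BRCA is an assumption.
    """
    preset = CLAM_PRESETS[dataset]  # raises KeyError for unknown datasets
    return [
        "python", "create_patches_fp.py",
        "--source", source,
        "--save_dir", save_dir,
        "--patch_size", str(patch_size),
        "--step_size", str(patch_size),  # step == patch size: non-overlapping patches
        "--preset", preset,
        "--seg", "--patch",
    ]


if __name__ == "__main__":
    cmd = build_clam_command("tcga_nsclc", "DATA_DIRECTORY", "RESULTS_DIRECTORY")
    print(shlex.join(cmd))
```

The printed line matches the TCGA-NSCLC command above; the list form can be passed directly to subprocess.run.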
Some code snippets for PLIP feature extraction are shown below.
In extract_features_fp.py:
from torchvision import transforms
from models.plip import PLIP

model = PLIP()
n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
# CLIP image normalization statistics (per RGB channel)
mean, std = (0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)
eval_transforms = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
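For reference, torchvision's ToTensor scales 8-bit RGB values to [0, 1], and Normalize then computes (x - mean) / std separately for each channel; the mean/std pair above is the standard CLIP image normalization. The stdlib-only sketch below (normalize_pixel is a hypothetical name) reproduces that arithmetic for a single pixel so the effect of eval_transforms is easy to check by hand.

```python
# CLIP normalization constants, copied from extract_features_fp.py above
MEAN = (0.48145466, 0.4578275, 0.40821073)
STD = (0.26862954, 0.26130258, 0.27577711)


def normalize_pixel(rgb: tuple[int, int, int]) -> tuple[float, float, float]:
    """Mimic ToTensor + Normalize for one 8-bit RGB pixel (illustrative only)."""
    # ToTensor: uint8 in [0, 255] -> float in [0, 1]
    scaled = [v / 255.0 for v in rgb]
    # Normalize: subtract the channel mean, then divide by the channel std
    return tuple((x - m) / s for x, m, s in zip(scaled, MEAN, STD))


print(normalize_pixel((255, 255, 255)))  # white slide background
print(normalize_pixel((0, 0, 0)))        # black pixel
```

Bright background pixels map to large positive values and dark pixels to negative ones, which is the input range the CLIP-style encoder expects.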
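In models/plip.py:
import torch
from transformers import CLIPVisionModelWithProjection

class PLIP(torch.nn.Module):
    def __init__(self):
        super(PLIP, self).__init__()
        # vision tower of PLIP (a CLIP model fine-tuned on pathology image-text pairs)
        self.model = CLIPVisionModelWithProjection.from_pretrained("vinid/plip")

    def forward(self, batch_input):
        # batch_input: normalized patches (B, 3, H, W) -> projected image embeddings (B, D)
        return self.model(pixel_values=batch_input).image_embeds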
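The two snippets fit together roughly as follows. This is a minimal sketch, not the repository's extract_features_fp.py: extract_bag_features is a hypothetical helper, and a tiny stand-in encoder replaces PLIP so the example runs without downloading the vinid/plip weights. It encodes a slide's patches in batches without gradients and concatenates the embeddings into one N x D feature matrix.

```python
import torch
import torch.nn as nn


@torch.no_grad()
def extract_bag_features(encoder: nn.Module, patches: torch.Tensor,
                         batch_size: int = 64) -> torch.Tensor:
    """Encode N normalized patches (N, 3, H, W) into an (N, D) feature matrix.

    Hypothetical helper; `encoder` is any module mapping (B, 3, H, W) -> (B, D),
    e.g. the PLIP wrapper above.
    """
    encoder.eval()  # inference mode (disables dropout, freezes batch-norm stats)
    feats = [encoder(patches[i:i + batch_size])
             for i in range(0, patches.shape[0], batch_size)]
    return torch.cat(feats, dim=0)


if __name__ == "__main__":
    # Stand-in encoder so the sketch runs offline: global average pool + linear
    # projection to 512-d. In practice this would be PLIP() from models/plip.py.
    dummy = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(3, 512))
    patches = torch.randn(150, 3, 32, 32)  # 150 small fake normalized patches
    bag = extract_bag_features(dummy, patches, batch_size=64)
    print(bag.shape)               # torch.Size([150, 512])
    print(bag.unsqueeze(0).shape)  # torch.Size([1, 150, 512]), the 1,N,D bag layout
```

Batching keeps GPU memory bounded for slides with thousands of patches; the last batch is simply smaller (here 64 + 64 + 22 patches).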
epeg_k and crmsa_k are the primary hyperparameters; crmsa_mlp and all_shortcut can optionally be set as well.
region_num is the most important hyperparameter for GPU memory: increasing it can significantly reduce GPU memory usage. Its default value is 8, which takes up about 10GB of GPU memory at an average sequence length of 9000. I recommend changing it to 16 or even larger if you want to apply RRT to tasks with longer sequences, such as survival prediction.
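To see why a larger region_num saves memory, consider a simplified model (an assumption for illustration, not RRT's exact partitioning): the N instances are split into region_num equal regions and self-attention is computed only within each region, so the number of attention scores per head drops from N^2 to roughly N^2 / region_num. Attention maps are only part of total GPU memory, so actual savings will not follow this ratio exactly. regional_attention_entries is a hypothetical name.

```python
import math


def regional_attention_entries(n_instances: int, region_num: int) -> int:
    """Attention scores per head when n_instances are split into region_num
    equal regions (padded up) and attention stays inside each region.
    Simplified model for illustration only."""
    region_len = math.ceil(n_instances / region_num)  # instances per region after padding
    return region_num * region_len ** 2


n = 9000       # average sequence length quoted above
full = n ** 2  # global self-attention over all instances
for r in (8, 16, 32):
    entries = regional_attention_entries(n, r)
    print(f"region_num={r:2d}: {entries:,} scores, {full / entries:.1f}x fewer than global")
```

Under this model, doubling region_num roughly halves the attention-score count, because each region's quadratic cost shrinks faster than the number of regions grows.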
RRT can be plugged into an existing MIL model as follows:
from rrt import RRTEncoder
# Place the RRT encoder after the fc and dropout layers and before the aggregation module
# x_rrt = fc(x_rrt)  # 1,N,1024 -> 1,N,512
# x_rrt = dropout(x_rrt)
rrt = RRTEncoder(mlp_dim=512, epeg_k=15, crmsa_k=3)
x_rrt = rrt(x_rrt)  # 1,N,512 -> 1,N,512
# x_rrt = mil_model(x_rrt)  # 1,N,512 -> 1,N,C
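A minimal sketch of that placement inside a complete model, assuming a simple attention-pooling (ABMIL-style) aggregator. RRTMILSketch, AttentionPool, the ReLU after fc, the dropout rate, and n_classes are hypothetical choices; the encoder argument stands in for RRTEncoder(mlp_dim=512, epeg_k=15, crmsa_k=3), and nn.Identity() is used in the usage example so it runs without the rrt module. The sketch relies only on the shape contract stated above: the encoder maps (1, N, 512) to (1, N, 512).

```python
import torch
import torch.nn as nn


class AttentionPool(nn.Module):
    """Plain (non-gated) attention pooling over instances: (1, N, D) -> (1, D)."""
    def __init__(self, dim: int, hidden: int = 128):
        super().__init__()
        self.score = nn.Sequential(nn.Linear(dim, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        a = torch.softmax(self.score(x), dim=1)  # (1, N, 1) weights summing to 1 over N
        return (a * x).sum(dim=1)                # (1, D) weighted average of instances


class RRTMILSketch(nn.Module):
    """fc -> dropout -> re-embedding encoder -> aggregation -> classifier."""
    def __init__(self, encoder: nn.Module, in_dim: int = 1024, dim: int = 512,
                 n_classes: int = 2, dropout: float = 0.25):
        super().__init__()
        self.fc = nn.Sequential(nn.Linear(in_dim, dim), nn.ReLU())
        self.dp = nn.Dropout(dropout)
        self.encoder = encoder          # e.g. RRTEncoder(mlp_dim=512, epeg_k=15, crmsa_k=3)
        self.pool = AttentionPool(dim)  # the aggregation module comes after the encoder
        self.head = nn.Linear(dim, n_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.dp(self.fc(x))          # 1,N,1024 -> 1,N,512
        x = self.encoder(x)              # 1,N,512 -> 1,N,512 (re-embedded instances)
        return self.head(self.pool(x))   # 1,N,512 -> 1,D -> 1,C logits


if __name__ == "__main__":
    model = RRTMILSketch(encoder=nn.Identity())  # swap in RRTEncoder(...) in practice
    logits = model(torch.randn(1, 3000, 1024))
    print(logits.shape)  # torch.Size([1, 2])
```

Because the encoder preserves the (1, N, 512) shape, it can be inserted or removed without touching the aggregator or classifier.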
Download the Docker image: Baidu Cloud.
Note: Because of code refactoring, this repository cannot fully reproduce the results in the paper. If you need to reproduce them exactly, please contact me by email.
# Camelyon16, ResNet-50 features
python3 main.py --project=$PROJECT_NAME --datasets=camelyon16 \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=rrtmil --pool=attn --n_trans_layers=2 --da_act=tanh --title=c16_r50_rrtmil \
--epeg_k=15 --crmsa_k=1 --all_shortcut --seed=2021
# Camelyon16, PLIP features
python3 main.py --project=$PROJECT_NAME --datasets=camelyon16 \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=rrtmil --pool=attn --n_trans_layers=2 --da_act=tanh --title=c16_plip_rrtmil \
--epeg_k=9 --crmsa_k=3 --all_shortcut --input_dim=512 --seed=2021
# TCGA-BRCA, ResNet-50 features
python3 main.py --project=$PROJECT_NAME --datasets=tcga --tcga_sub=brca \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=rrtmil --pool=attn --n_trans_layers=2 --da_act=tanh --title=brca_r50_rrtmil \
--epeg_k=17 --crmsa_k=3 --crmsa_heads=1 --input_dim=512 --seed=2021
# TCGA-BRCA, PLIP features
python3 main.py --project=$PROJECT_NAME --datasets=tcga --tcga_sub=brca \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=rrtmil --pool=attn --n_trans_layers=2 --da_act=tanh --title=brca_plip_rrtmil \
--all_shortcut --crmsa_k=1 --input_dim=512 --seed=2021
# TCGA-NSCLC, ResNet-50 features
python3 main.py --project=$PROJECT_NAME --datasets=tcga \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=rrtmil --pool=attn --n_trans_layers=2 --da_act=tanh --title=nsclc_r50_rrtmil \
--epeg_k=21 --crmsa_k=5 --input_dim=512 --seed=2021
# TCGA-NSCLC, PLIP features
python3 main.py --project=$PROJECT_NAME --datasets=tcga \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=rrtmil --pool=attn --n_trans_layers=2 --da_act=tanh --title=nsclc_plip_rrtmil \
--all_shortcut --crmsa_mlp --epeg_k=13 --crmsa_k=3 --crmsa_heads=1 \
--input_dim=512 --seed=2021
For survival prediction, study is one of {BLCA, LUAD, LUSC} and feat is one of {resnet50, plip}. A grid search over the hyperparameters crmsa_k = {1, 3, 5} and epeg_k = {9, 15, 21} is recommended.
python ./Surviva/main.py --model RRTMIL \
--excel_file ./csv/${study}_Splits.csv \
--num_epoch 30 \
--epeg_k 15 --crmsa_k 3 \
--folder $feat
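The recommended grid search can be scripted. The sketch below only generates the commands (it does not run them): it reuses the flags and the ./Surviva/main.py path exactly as written in the command above and iterates over study, feat, crmsa_k, and epeg_k with itertools.product. survival_grid_commands is a hypothetical helper name.

```python
import itertools
import shlex

STUDIES = ["BLCA", "LUAD", "LUSC"]
FEATS = ["resnet50", "plip"]
CRMSA_K = [1, 3, 5]
EPEG_K = [9, 15, 21]


def survival_grid_commands(num_epoch: int = 30) -> list[list[str]]:
    """Build one survival-training argv per (study, feat, crmsa_k, epeg_k) combination."""
    cmds = []
    for study, feat, crmsa_k, epeg_k in itertools.product(STUDIES, FEATS, CRMSA_K, EPEG_K):
        cmds.append([
            "python", "./Surviva/main.py",
            "--model", "RRTMIL",
            "--excel_file", f"./csv/{study}_Splits.csv",
            "--num_epoch", str(num_epoch),
            "--epeg_k", str(epeg_k),
            "--crmsa_k", str(crmsa_k),
            "--folder", feat,
        ])
    return cmds


if __name__ == "__main__":
    cmds = survival_grid_commands()
    print(len(cmds), "runs")  # 3 studies x 2 features x 3 crmsa_k x 3 epeg_k = 54
    for cmd in cmds[:3]:
        print(shlex.join(cmd))
```

Each study/feature pair gets 9 hyperparameter settings; pick the best (crmsa_k, epeg_k) per pair from the cross-validation results.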
To use only the RRT encoder with another MIL model, set --only_rrt_enc and set --model to that model's name, e.g., for clam_sb:
python3 main.py --project=$PROJECT_NAME --datasets=tcga --tcga_sub=brca \
--dataset_root=$DATASET_PATH --model_path=$OUTPUT_PATH --cv_fold=5 \
--model=clam_sb --only_rrt_enc --n_trans_layers=2 --title=brca_plip_rrt_clam \
--all_shortcut --crmsa_k=1 --input_dim=512 --seed=2021
@InProceedings{tang2024feature,
author = {Tang, Wenhao and Zhou, Fengtao and Huang, Sheng and Zhu, Xiang and Zhang, Yi and Liu, Bo},
title = {Feature Re-Embedding: Towards Foundation Model-Level Performance in Computational Pathology},
booktitle = {Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
month = {June},
year = {2024},
pages = {11343-11352}
}