This repository requires installing the environment and the datasets:
- follow the Dassl.pytorch installation instructions to install Dassl.pytorch and PyTorch.
- under CLIPCALIB/, run the following command to install a few more packages required by CLIP (this should be done while the dassl environment is activated):
  pip install -r requirements.txt
- follow DATASETS.md to install the datasets.
PS: You can also follow the CoOp installation instructions instead.
We present the basic usage here.
(a) TR -- ZS-Norm:
bash scripts/adapt_zs_norm.sh 0 imagenet SGD_lr1e-1_B256_ep300 16 TR none RN50
bash scripts/eval_zs_norm.sh 0 imagenetv2 SGD_lr1e-1_B256_ep300 16 TR none RN50
(b) TR -- Penalty:
bash scripts/adapt_zs_pen.sh 0 imagenet SGD_lr1e-1_B256_ep300 16 TR none RN50
bash scripts/eval_zs_pen.sh 0 imagenetv2 SGD_lr1e-1_B256_ep300 16 TR none RN50
(c) TR -- SaLS:
- bash scripts/adapt.sh 0 imagenet SGD_lr1e-1_B256_ep300 16 TR none RN50
- bash scripts/eval.sh 0 imagenetv2 SGD_lr1e-1_B256_ep300 16 TR none RN50
- bash scripts/eval_zs.sh 0 imagenetv2 SGD_lr1e-1_B256_ep300 16 ZS none RN50
- The logits of the predictions are then renormalized using the following snippet:
  logits_tr = (logits_tr - min_logits_tr) / (max_logits_tr - min_logits_tr)  # rescale TR logits to [0, 1]
  logits_tr = logits_tr * (max_logits_zs - min_logits_zs) + min_logits_zs  # map them onto the zero-shot logit range
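The renormalization step can be sketched as a small, self-contained Python function. This is a minimal sketch, not the repository's implementation: it operates on plain lists for a single sample, it assumes the min/max values are taken per sample over the class dimension, and the function name renormalize_logits and the example logits are hypothetical.

```python
from typing import List, Sequence


def renormalize_logits(logits_tr: Sequence[float], logits_zs: Sequence[float]) -> List[float]:
    """Rescale one sample's TR logits onto the [min, max] range of its zero-shot logits.

    Assumption (not confirmed by the repository): min/max are computed per sample
    over the classes.
    """
    min_logits_tr, max_logits_tr = min(logits_tr), max(logits_tr)
    min_logits_zs, max_logits_zs = min(logits_zs), max(logits_zs)

    span_tr = max_logits_tr - min_logits_tr
    if span_tr == 0:
        # The original snippet would divide by zero here; this sketch rejects the input instead.
        raise ValueError("TR logits are constant; min-max scaling is undefined")
    span_zs = max_logits_zs - min_logits_zs

    out = []
    for z in logits_tr:
        unit = (z - min_logits_tr) / span_tr        # step 1: map into [0, 1]
        out.append(unit * span_zs + min_logits_zs)  # step 2: map onto [min_zs, max_zs]
    return out


if __name__ == "__main__":
    # Hypothetical logits for a single sample with three classes.
    tr = [2.0, -1.0, 4.0]
    zs = [10.0, 20.0, 30.0]
    print(renormalize_logits(tr, zs))
```

Because each sample's logits go through an affine map with a non-negative slope, the class ranking is preserved whenever the zero-shot logits are not all equal, so the predicted label does not change; only the logit range, and therefore the softmax confidence, is adjusted.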
Integration of the proposed calibration techniques with prompt learning and test-time prompt tuning is also available in the other branches of this repository.
This repository is mainly based on the CoOp and TaskRes code bases. We sincerely thank the prior authors for their awesome code bases.