Skip to content

Commit

Permalink
Fix overwritten zeroshot templates issue #109. (#110)
Browse files Browse the repository at this point in the history
* fix overwritten zeroshot templates issue #109. Thanks to @djghosh13.

* handle non classification/linear probe case

* comments

* simplify

* support dumping the classnames and templates that are used for evaluation

* fix tests

* just comments

* just comments
  • Loading branch information
mehdidc committed Dec 1, 2023
1 parent 567f01c commit 0c11b17
Show file tree
Hide file tree
Showing 3 changed files with 113 additions and 58 deletions.
8 changes: 7 additions & 1 deletion clip_benchmark/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ def get_parser_args():
parser_eval.add_argument('--annotation_file', default="", type=str, help="text annotation file for retrieval datasets. Only needed for when `--task` is `zeroshot_retrieval`.")
parser_eval.add_argument('--custom_classname_file', default=None, type=str, help="use custom json file with classnames for each dataset, where keys are dataset names and values are list of classnames.")
parser_eval.add_argument('--custom_template_file', default=None, type=str, help="use custom json file with prompts for each dataset, where keys are dataset names and values are list of prompts. For instance, to use CuPL prompts, use --custom_template_file='cupl_prompts.json'")
parser_eval.add_argument('--dump_classnames', default=False, action="store_true", help="dump classnames to the results json file.")
parser_eval.add_argument('--dump_templates', default=False, action="store_true", help="dump templates to the results json file.")

parser_eval.add_argument('--language', default="en", type=str, nargs="+", help="language(s) of classname and prompts to use for zeroshot classification.")
parser_eval.add_argument('--output', default="result.json", type=str, help="output file where to dump the metrics. Can be in form of a template, e.g., --output='{dataset}_{pretrained}_{model}_{language}_{task}.json'")
Expand Down Expand Up @@ -316,7 +318,7 @@ def run(args):
transform=transform
)
else:
raise ValueError("Unsupported task: {}. task should be `zeroshot_classification`, `zeroshot_retrieval`, `linear_probe`, or `captioning`".format(task))
raise ValueError("Unsupported task: {}. task should be `zeroshot_classification`, `zeroshot_retrieval`, `linear_probe`, or `captioning`".format(task))
dump = {
"dataset": args.dataset,
"model": args.model,
Expand All @@ -325,6 +327,10 @@ def run(args):
"metrics": metrics,
"language": args.language,
}
if hasattr(dataset, "classes") and dataset.classes and args.dump_classnames:
dump["classnames"] = dataset.classes
if hasattr(dataset, "templates") and dataset.templates and args.dump_templates:
dump["templates"] = dataset.templates
if args.verbose:
print(f"Dump results to: {output}")
with open(output, "w") as f:
Expand Down

0 comments on commit 0c11b17

Please sign in to comment.