add mscoco generative benchmark (#63)
* add initial generative benchmark

* add pycocoeval dep

* use generate_beamsearch

* update to generate

* clean generate

---------

Co-authored-by: Romain Beaumont <romain.rom1@gmail.com>
gpucce and rom1504 committed Feb 4, 2023
1 parent aabe0fa commit c4d9927
Showing 3 changed files with 46 additions and 3 deletions.
15 changes: 13 additions & 2 deletions clip_benchmark/cli.py
@@ -7,7 +7,7 @@
from copy import copy
import os
from clip_benchmark.datasets.builder import build_dataset, get_dataset_collate_fn, get_dataset_default_task, dataset_collection, get_dataset_collection_from_file
-from clip_benchmark.metrics import zeroshot_classification, zeroshot_retrieval, linear_probe
+from clip_benchmark.metrics import zeroshot_classification, zeroshot_retrieval, linear_probe, mscoco_generative
from clip_benchmark.model_collection import get_model_collection_from_file, model_collection
from clip_benchmark.models import load_clip, MODEL_TYPES

@@ -22,7 +22,7 @@ def get_parser_args():
    parser_eval.add_argument('--model', type=str, default="ViT-B-32-quickgelu", help="Model architecture to use from OpenCLIP")
    parser_eval.add_argument('--pretrained', type=str, default="laion400m_e32", help="Model checkpoint name to use from OpenCLIP")
    parser_eval.add_argument('--pretrained_model', type=str, default="", nargs="+", help="Pre-trained model(s) to use. Can be the full model name where `model` and `pretrained` are comma separated (e.g., --pretrained_model='ViT-B-32-quickgelu,laion400m_e32'), a model collection name ('openai' or 'openclip_base' or 'openclip_multilingual' or 'openclip_all'), or path of a text file where each line is a model fullname where model and pretrained are comma separated (e.g., ViT-B-32-quickgelu,laion400m_e32). --model and --pretrained are ignored if --pretrained_model is used.")
-    parser_eval.add_argument('--task', type=str, default="auto", choices=["zeroshot_classification", "zeroshot_retrieval", "linear_probe", "auto"], help="Task to evaluate on. With --task=auto, the task is automatically inferred from the dataset.")
+    parser_eval.add_argument('--task', type=str, default="auto", choices=["zeroshot_classification", "zeroshot_retrieval", "linear_probe", "mscoco_generative", "auto"], help="Task to evaluate on. With --task=auto, the task is automatically inferred from the dataset.")
    parser_eval.add_argument('--amp', default=True, action="store_true", help="whether to use mixed precision")
    parser_eval.add_argument('--num_workers', default=4, type=int)
    parser_eval.add_argument('--recall_k', default=[5], type=int, help="for retrieval, select the k for Recall@K metric. ", nargs="+",)
@@ -278,6 +278,17 @@ def run(args):
            amp=args.amp,
            verbose=args.verbose,
        )
+    elif task == "mscoco_generative":
+        metrics = mscoco_generative.evaluate(
+            model=model,
+            dataloader=dataloader,
+            batch_size=args.batch_size,
+            num_workers=args.num_workers,
+            device=args.device,
+            amp=args.amp,
+            verbose=args.verbose,
+            transform=transform
+        )
    else:
        raise ValueError("Unsupported task: {}. task should be `zeroshot_classification`, `zeroshot_retrieval`, `linear_probe`, or `mscoco_generative`".format(task))
    dump = {
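For context, the new branch assumes the dataloader wraps a torchvision-style CocoCaptions dataset, which exposes the `coco` and `ids` attributes the metric reads. Below is a minimal sketch of driving the metric directly; the paths, the CoCa checkpoint name, and the inline collate function are assumptions (CocoCaptions yields variable-length caption lists, so the default collate would fail):

import torch
import open_clip
from torchvision.datasets import CocoCaptions
from clip_benchmark.metrics import mscoco_generative

# Assumed checkpoint: any OpenCLIP model exposing `generate` (e.g. CoCa).
model, _, transform = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="laion2B-s13B-b90k"
)
model.eval()

dataset = CocoCaptions(
    root="coco/val2017",                               # assumed image dir
    annFile="coco/annotations/captions_val2017.json",  # assumed annotations
    transform=transform,
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=32,
    shuffle=False,  # order must stay aligned with dataset.ids
    collate_fn=lambda batch: (torch.stack([im for im, _ in batch]),
                              [caps for _, caps in batch]),
)

metrics = mscoco_generative.evaluate(
    model=model, dataloader=dataloader, batch_size=32,
    device="cpu", transform=transform,
)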
31 changes: 31 additions & 0 deletions clip_benchmark/metrics/mscoco_generative.py
@@ -0,0 +1,31 @@
import json
from pycocoevalcap.eval import COCOEvalCap
from open_clip import tokenize
from tqdm.auto import tqdm
from open_clip.tokenizer import _tokenizer


def evaluate(model, dataloader, batch_size, device, transform, train_dataloader=None, num_workers=None, amp=True, verbose=False):
    coco = dataloader.dataset.coco  # underlying pycocotools COCO object
    indexer = dataloader.dataset.ids  # position in the dataset -> COCO image id
    results = []
    for idx, (img, _) in enumerate(tqdm(dataloader)):
        n_samples = img.shape[0]  # may be smaller than batch_size for the last batch
        idxs = [indexer[idx * batch_size + id] for id in range(n_samples)]
        out = model.generate(img.to(device))
        # Decode token ids to text, stripping the special start/end tokens.
        decoded = [
            _tokenizer.decode(i).split("<end_of_text>")[0].replace("<start_of_text>", "").strip()
            for i in out.cpu().numpy()
        ]
        for image_id, caption in zip(idxs, decoded):
            results.append({"image_id": image_id, "caption": caption})

    # Write predictions in the COCO results format, then score with pycocoevalcap.
    temp_res_file = "temp_results.json"
    with open(temp_res_file, "w") as jf:
        json.dump(results, jf)

    coco_result = coco.loadRes(temp_res_file)
    coco_eval = COCOEvalCap(coco, coco_result)
    coco_eval.evaluate()
    metrics = coco_eval.eval

    # print output evaluation scores
    for metric, score in metrics.items():
        print(f'{metric}: {score:.3f}')

    return metrics
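Standalone, the generate-and-decode step above looks roughly like this for a single image; the checkpoint name and image path are assumptions, not pinned by this commit:

import torch
import open_clip
from PIL import Image
from open_clip.tokenizer import _tokenizer

model, _, transform = open_clip.create_model_and_transforms(
    "coca_ViT-B-32", pretrained="laion2B-s13B-b90k"  # assumed CoCa checkpoint
)
model.eval()

img = transform(Image.open("example.jpg")).unsqueeze(0)  # (1, 3, H, W)
with torch.no_grad():
    out = model.generate(img)  # (1, seq_len) token ids

# Same special-token stripping as the metric above.
caption = (
    _tokenizer.decode(out[0].cpu().numpy())
    .split("<end_of_text>")[0]
    .replace("<start_of_text>", "")
    .strip()
)
print(caption)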
3 changes: 2 additions & 1 deletion requirements.txt
@@ -3,5 +3,6 @@ torchvision>=0.8.9,<2
tqdm>=2
scikit-learn>=1.0,<2
open_clip_torch>=0.2.1
+pycocoevalcap
webdataset>=0.2.31
-transformers
+transformers
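The new pycocoevalcap dependency can also be exercised directly against the temp_results.json file the metric writes; a short sketch, assuming a standard captions annotation file:

from pycocotools.coco import COCO  # installed with pycocoevalcap
from pycocoevalcap.eval import COCOEvalCap

coco = COCO("annotations/captions_val2017.json")  # assumed annotation path
# temp_results.json holds [{"image_id": <coco id>, "caption": "<text>"}, ...]
coco_result = coco.loadRes("temp_results.json")

coco_eval = COCOEvalCap(coco, coco_result)
coco_eval.evaluate()
print(coco_eval.eval)  # Bleu_1..4, METEOR, ROUGE_L, CIDEr, SPICE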
