Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
  • Loading branch information
mehdidc committed Dec 2, 2023
2 parents 832e17d + 8a05786 commit ca2638f
Show file tree
Hide file tree
Showing 8 changed files with 994 additions and 95 deletions.
43 changes: 32 additions & 11 deletions clip_benchmark/cli.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,26 @@
"""Console script for clip_benchmark."""
import argparse
import sys
import random
import json
import torch
import csv
from copy import copy
import json
import os
import random
import sys
from copy import copy
from itertools import product
from clip_benchmark.datasets.builder import build_dataset, get_dataset_collate_fn, get_dataset_default_task, dataset_collection, get_dataset_collection_from_file
from clip_benchmark.metrics import image_caption_selection, zeroshot_classification, zeroshot_retrieval, linear_probe, captioning
from clip_benchmark.model_collection import get_model_collection_from_file, model_collection
from clip_benchmark.models import load_clip, MODEL_TYPES

import torch

from clip_benchmark.datasets.builder import (build_dataset, dataset_collection,
get_dataset_collate_fn,
get_dataset_collection_from_file,
get_dataset_default_task)
from clip_benchmark.metrics import (captioning, image_caption_selection,
linear_probe, zeroshot_classification,
zeroshot_retrieval)
from clip_benchmark.model_collection import (get_model_collection_from_file,
model_collection)
from clip_benchmark.models import MODEL_TYPES, load_clip


def get_parser_args():
parser = argparse.ArgumentParser()
Expand Down Expand Up @@ -81,7 +90,7 @@ def main_build(base):
# Build a benchmark single CSV file from a set of evaluations (JSON files)
rows = []
fieldnames = set()
for path in base.files:
def process_file(path: str):
data = json.load(open(path))
row = {}
row.update(data["metrics"])
Expand All @@ -91,6 +100,13 @@ def main_build(base):
for field in row.keys():
fieldnames.add(field)
rows.append(row)
for path in base.files:
if os.path.isdir(path):
files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith(".json")]
for file in files:
process_file(file)
else:
process_file(path)
with open(base.output, 'w') as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
writer.writeheader()
Expand Down Expand Up @@ -241,6 +257,11 @@ def run(args):
device=args.device
)
model.eval()
if args.model.count("nllb-clip") > 0:
# for NLLB-CLIP models, we need to set the language prior to running the tests
from clip_benchmark.models.nllb_clip import set_language

set_language(tokenizer, args.language)
dataset = build_dataset(
dataset_name=args.dataset,
root=dataset_root,
Expand Down Expand Up @@ -295,7 +316,7 @@ def run(args):
verbose=args.verbose,
save_clf=args.save_clf,
load_clfs=args.load_clfs,
)
)
elif task == "zeroshot_retrieval":
metrics = zeroshot_retrieval.evaluate(
model,
Expand Down
92 changes: 51 additions & 41 deletions clip_benchmark/datasets/builder.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
import json
import os
import warnings
import sys
import json
import warnings
from subprocess import call
from collections import defaultdict

import torch
from torchvision.datasets import (
VisionDataset, ImageFolder,
CIFAR10, CIFAR100, ImageNet, CocoCaptions, Flickr8k, Flickr30k, Food101, SUN397,
StanfordCars, FGVCAircraft, DTD, OxfordIIITPet, Caltech101, Flowers102,
MNIST, STL10, EuroSAT, GTSRB, Kitti, Country211, PCAM, RenderedSST2
)

from . import voc2007, flickr, caltech101, imagenetv2, objectnet, babel_imagenet, sugar_crepe
from torch.utils.data import default_collate
from PIL import Image
from torchvision.datasets import (CIFAR10, CIFAR100, DTD, GTSRB, MNIST, PCAM,
STL10, SUN397, CocoCaptions, Country211,
EuroSAT, FGVCAircraft, Flowers102, Food101,
ImageFolder, ImageNet, OxfordIIITPet,
RenderedSST2, StanfordCars)

from . import (babel_imagenet, caltech101, flickr, imagenetv2, objectnet,
sugar_crepe, voc2007)


def build_dataset(dataset_name, root="root", transform=None, split="test", download=True, annotation_file=None, language="en", task="zeroshot_classification", wds_cache_dir=None, custom_classname_file=None, custom_template_file=None, **kwargs):
Expand Down Expand Up @@ -108,7 +107,7 @@ def download_imagenet(r):
elif dataset_name == "imagenet-w":
assert split in ("train", "test"), f"Only `train` and `test` split available for {dataset_name}"
from imagenet_w import AddWatermark
from torchvision.transforms import Normalize, CenterCrop
from torchvision.transforms import CenterCrop, Normalize
if not os.path.exists(root):
download_imagenet(root)
index_normalize = None
Expand Down Expand Up @@ -264,35 +263,44 @@ def download_imagenet(r):
ds = CocoCaptions(root=root_split, annFile=annotation_file, transform=transform, **kwargs)
elif dataset_name == 'multilingual_mscoco_captions':
from clip_benchmark.datasets import multilingual_mscoco
if(language not in multilingual_mscoco.SUPPORTED_LANGUAGES):
if language not in multilingual_mscoco.SUPPORTED_LANGUAGES:
raise ValueError("Unsupported language for multilingual_ms_coco:", language)

def get_archive_name(target_split):
if target_split == "train":
return "train2014.zip"
elif target_split in ("val", "test"):
return "val2014.zip"
else:
raise ValueError(f"split should be `train` or `val` or `test` for `{dataset_name}`")

def download_mscoco_split(target_split):
archive_name = get_archive_name(target_split)
root_split = os.path.join(root, archive_name.replace(".zip", ""))
if not os.path.exists(root_split):
print(f"Downloading mscoco_captions {archive_name}...")
if not os.path.exists(os.path.join(root, archive_name)):
call(f"wget http://images.cocodataset.org/zips/{archive_name} --output-document={root}/{archive_name}", shell=True)
call(f"unzip {root}/{archive_name} -d {root}", shell=True)

# The multilingual MS-COCO uses images from various splits
for target_split in ['train', 'val', 'test']:
download_mscoco_split(target_split)

annotation_file = os.path.join(root, multilingual_mscoco.CAPTIONS_FILE_NAME.format(language))
if (os.path.exists(annotation_file) == False):

annotation_file = os.path.join(root, multilingual_mscoco.OUTPUT_FILENAME_TEMPLATE.format(language))
if not os.path.exists(annotation_file):
multilingual_mscoco.create_annotation_file(root, language)

ds = multilingual_mscoco.Multilingual_MSCOCO(root=root, ann_file=annotation_file, transform=transform, **kwargs)
elif dataset_name == 'crossmodal3600':
from clip_benchmark.datasets import crossmodal3600
if language not in crossmodal3600.SUPPORTED_LANGUAGES:
raise ValueError("Unsupported language for Crossmodal-3600:", language)

annotation_file = os.path.join(root, crossmodal3600.OUTPUT_FILENAME_TEMPLATE.format(language))
if not os.path.exists(annotation_file):
crossmodal3600.create_annotation_file(root, language)

ds = crossmodal3600.Crossmodal3600(root=root, ann_file=annotation_file, transform=transform, **kwargs)
elif dataset_name == 'xtd200':
from clip_benchmark.datasets import xtd200
if language not in xtd200.SUPPORTED_LANGUAGES:
raise ValueError("Unsupported language for xtd200:", language)

annotation_file = os.path.join(root, xtd200.OUTPUT_FILENAME_TEMPLATE.format(language))
if not os.path.exists(annotation_file):
xtd200.create_annotation_file(root, language)

ds = xtd200.XTD200(root=root, ann_file=annotation_file, transform=transform, **kwargs)
elif dataset_name == 'flickr30k-200':
from clip_benchmark.datasets import flickr30k_200
if language not in flickr30k_200.SUPPORTED_LANGUAGES:
raise ValueError("Unsupported language for flickr30k-200:", language)

annotation_file = os.path.join(root, flickr30k_200.OUTPUT_FILENAME_TEMPLATE.format(language))
if not os.path.exists(annotation_file):
flickr30k_200.create_annotation_file(root, language)

ds = flickr30k_200.Flickr30k_200(root=root, ann_file=annotation_file, transform=transform, **kwargs)
elif dataset_name == "flickr30k":
# downloadable from https://www.kaggle.com/datasets/adityajn105/flickr30k
# https://github.com/mehdidc/retrieval_annotations/releases/tag/1.0.0(annotations)
Expand Down Expand Up @@ -513,15 +521,15 @@ def __len__(self):
return 1

def get_dataset_default_task(dataset):
if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions"):
if dataset in ("flickr30k", "flickr8k", "mscoco_captions", "multilingual_mscoco_captions", "flickr30k-200", "crossmodal3600", "xtd200"):
return "zeroshot_retrieval"
elif dataset.startswith("sugar_crepe"):
return "image_caption_selection"
else:
return "zeroshot_classification"

def get_dataset_collate_fn(dataset_name):
if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k") or dataset_name.startswith("sugar_crepe"):
if dataset_name in ("mscoco_captions", "multilingual_mscoco_captions", "flickr30k", "flickr8k", "flickr30k-200", "crossmodal3600", "xtd200") or dataset_name.startswith("sugar_crepe"):
return image_captions_collate_fn
else:
return default_collate
Expand All @@ -535,7 +543,8 @@ def has_kaggle():

def build_vtab_dataset(dataset_name, transform, download=True, split="test", data_dir="root", classnames=[]):
# Using VTAB splits instead of default TFDS splits
from .tfds import VTABIterableDataset, disable_gpus_on_tensorflow, download_tfds_dataset
from .tfds import (VTABIterableDataset, disable_gpus_on_tensorflow,
download_tfds_dataset)

# avoid Tensorflow owning GPUs to not clash with PyTorch
disable_gpus_on_tensorflow()
Expand Down Expand Up @@ -648,6 +657,7 @@ def build_vtab_dataset(dataset_name, transform, download=True, split="test", dat
classes = tfds_dataset._dataset_builder.info.features[task].names
elif dataset_name == "sun397":
from task_adaptation.data.sun397 import Sun397Data

#FIXME There is a problem in `sun397`, when TFDS tries download it
# there is an image that cannot be decoded. For the time being
# we will use torchvision's SUN397 instead.
Expand Down
152 changes: 152 additions & 0 deletions clip_benchmark/datasets/crossmodal3600.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import codecs
import json
import os
from subprocess import call

from PIL import Image
from torchvision.datasets import VisionDataset

SUPPORTED_LANGUAGES = [
"ar",
"bn",
"cs",
"da",
"de",
"el",
"en",
"es",
"fa",
"fi",
"fil",
"fr",
"he",
"hi",
"hr",
"hu",
"id",
"it",
"ja",
"ko",
"mi",
"nl",
"no",
"pl",
"pt",
"quz",
"ro",
"ru",
"sv",
"sw",
"te",
"th",
"tr",
"uk",
"vi",
"zh",
]

CAPTIONS_DOWNLOAD_URL = "https://google.github.io/crossmodal-3600/web-data/captions.zip"
IMAGES_DOWNLOAD_URL = (
"https://open-images-dataset.s3.amazonaws.com/crossmodal-3600/images.tgz"
)
OUTPUT_FILENAME_TEMPLATE = "crossmodal3600_captions-{}.json"


class Crossmodal3600(VisionDataset):
def __init__(self, root, ann_file, transform=None, target_transform=None):
super().__init__(root, transform=transform, target_transform=target_transform)
self.ann_file = os.path.expanduser(ann_file)
with codecs.open(ann_file, "r", encoding="utf-8") as fp:
data = json.load(fp)
self.data = [
(img_path, txt)
for img_path, txt in zip(data["image_paths"], data["annotations"])
]

def __getitem__(self, index):
img, captions = self.data[index]

# Image
img = Image.open(img).convert("RGB")
if self.transform is not None:
img = self.transform(img)

# Captions
target = [
captions,
]
if self.target_transform is not None:
target = self.target_transform(target)

return img, target

def __len__(self) -> int:
return len(self.data)


def _download_captions(out_path):
os.makedirs(out_path, exist_ok=True)
print("Downloading captions")
call(f"wget {CAPTIONS_DOWNLOAD_URL} -O captions.zip", shell=True)
call(f"unzip captions.zip -d {out_path}", shell=True)
call("rm captions.zip", shell=True)


def _download_images(out_path):
os.makedirs(out_path, exist_ok=True)
print("Downloading images")
call(f"wget {IMAGES_DOWNLOAD_URL} -O images.tgz", shell=True)
call(f"tar -xzf images.tgz -C {out_path}", shell=True)
call("rm images.tgz", shell=True)


def create_annotation_file(root, lang_code):
if lang_code not in SUPPORTED_LANGUAGES:
raise ValueError(
f"Language code {lang_code} not supported. Supported languages are {SUPPORTED_LANGUAGES}"
)
data_dir = os.path.join(root, "xm3600")
images_dir = os.path.join(data_dir, "images")
if not os.path.exists(images_dir):
_download_images(images_dir)
captions_path = os.path.join(data_dir, "captions.jsonl")
if not os.path.exists(captions_path):
_download_captions(data_dir)
with open(captions_path, "r", encoding="utf-8") as f:
data = f.readlines()
data = [json.loads(line) for line in data]

number_of_missing_images = 0
valid_images, valid_annotations, valid_indicies = [], [], []
for i, data_item in enumerate(data):
image_id = data_item["image/key"]
image_name = f"{image_id}.jpg"
image_path = os.path.join(images_dir, image_name)
if not os.path.exists(image_path):
print("Missing image file", image_name)
number_of_missing_images += 1
continue
captions = data_item[lang_code]["caption"]
txt = captions[0]

valid_images.append(image_path)
valid_annotations.append(txt)
valid_indicies.append(i)

if number_of_missing_images > 0:
print(f"*** WARNING *** missing {number_of_missing_images} files.")

with codecs.open(
os.path.join(root, OUTPUT_FILENAME_TEMPLATE.format(lang_code)),
"w",
encoding="utf-8",
) as fp:
json.dump(
{
"image_paths": valid_images,
"annotations": valid_annotations,
"indicies": valid_indicies,
},
fp,
ensure_ascii=False,
)
Loading

0 comments on commit ca2638f

Please sign in to comment.