Move support of model types to all_clip module.
I created the all_clip module in order to have a single place to support all kinds of CLIP models.
It is already used in clip-retrieval, and I propose to use it here too.
rom1504 committed Jan 21, 2024
1 parent 5f23a76 commit a2f1a1e
Showing 6 changed files with 32 additions and 83 deletions.
12 changes: 2 additions & 10 deletions README.md
@@ -80,16 +80,8 @@ Here is an example of use
 
 ### How to add other CLIP models
 
-Please follow these steps:
-1. Add a identity file to load model in `clip_benchmark/models`
-2. Define a loading function, that returns a tuple (model, transform, tokenizer). Please see `clip_benchmark/models/open_clip.py` as an example.
-3. Add the function into `TYPE2FUNC` in `clip_benchmark/models/__init__.py`
-
-Remarks:
-- The new tokenizer/model must enable to do the following things as https://github.com/openai/CLIP#usage
-- `tokenizer(texts).to(device)` ... `texts` is a list of string
-- `model.encode_text(tokenized_texts)` ... `tokenized_texts` is a output from `tokenizer(texts).to(device)`
-- `model.encode_image(images)` ... `images` is a image tensor by the `transform`
+Please add your model to [all-clip](https://github.com/DataToML/all-clip) and it will be supported in CLIP-benchmark (and in clip-retrieval).
+See [How to add a model type](https://github.com/DataToML/all-clip?tab=readme-ov-file#how-to-add-a-model-type).
 
 
 ### CIFAR-10 example
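For orientation, here is a minimal sketch of what loading a model through all_clip looks like once it is registered there. The `model_type:model_name/pretrained` spec string, the keyword arguments, and the returned `(model, transform, tokenizer)` tuple are all inferred from the `clip_benchmark/models/__init__.py` diff below rather than from all-clip's own documentation, so treat the exact names as assumptions.

```python
# Sketch only: load a model through all_clip the way the new wrapper does.
# The spec string format and the returned tuple are assumed from the
# clip_benchmark/models/__init__.py diff in this commit.
import all_clip
import torch
from PIL import Image

model, transform, tokenizer = all_clip.load_clip(
    clip_model="open_clip:ViT-B-32/laion2b_s34b_b79k",  # model_type:model_name/pretrained
    use_jit=True,
    warmup_batch_size=1,
    clip_cache_path=None,
    device="cpu",
)

with torch.no_grad():
    text_features = model.encode_text(tokenizer(["a photo of a cat"]))
    image_features = model.encode_image(transform(Image.new("RGB", (224, 224))).unsqueeze(0))
```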
20 changes: 10 additions & 10 deletions clip_benchmark/models/__init__.py
@@ -1,14 +1,9 @@
 from typing import Union
 import torch
-from .open_clip import load_open_clip
-from .japanese_clip import load_japanese_clip
+import all_clip
 
-# loading function must return (model, transform, tokenizer)
-TYPE2FUNC = {
-    "open_clip": load_open_clip,
-    "ja_clip": load_japanese_clip
-}
-MODEL_TYPES = list(TYPE2FUNC.keys())
+# see https://github.com/rom1504/all-clip?tab=readme-ov-file#supported-models
+MODEL_TYPES = ["openai_clip", "open_clip", "ja_clip", "hf_clip", "nm"]
 
 
 def load_clip(
@@ -19,5 +14,10 @@ def load_clip(
     device: Union[str, torch.device] = "cuda"
 ):
     assert model_type in MODEL_TYPES, f"model_type={model_type} is invalid!"
-    load_func = TYPE2FUNC[model_type]
-    return load_func(model_name=model_name, pretrained=pretrained, cache_dir=cache_dir, device=device)
+    return all_clip.load_clip(
+        clip_model=model_type+":"+model_name+"/"+pretrained,
+        use_jit=True,
+        warmup_batch_size=1,
+        clip_cache_path=cache_dir,
+        device=device,
+    )
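For illustration, a hedged usage sketch of the updated wrapper. The collapsed part of the signature is assumed to expose `model_type`, `model_name`, `pretrained`, and `cache_dir` (the names used in the body above), and the return value is assumed to still be the `(model, transform, tokenizer)` tuple the removed comment documented.

```python
# Hedged usage sketch of the updated clip_benchmark.models.load_clip.
# Parameter names for the collapsed part of the signature are assumed
# from the function body above; the return value is assumed to be the
# (model, transform, tokenizer) tuple the old comment documented.
from clip_benchmark.models import load_clip

model, transform, tokenizer = load_clip(
    model_type="open_clip",
    model_name="ViT-B-32",
    pretrained="laion2b_s34b_b79k",
    cache_dir=None,
    device="cpu",
)
tokens = tokenizer(["a photo of a dog"])
text_embeddings = model.encode_text(tokens)
```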
54 changes: 0 additions & 54 deletions clip_benchmark/models/japanese_clip.py

This file was deleted.

8 changes: 0 additions & 8 deletions clip_benchmark/models/open_clip.py

This file was deleted.

1 change: 1 addition & 0 deletions requirements.txt
@@ -6,3 +6,4 @@ open_clip_torch>=0.2.1
 pycocoevalcap
 webdataset>=0.2.31
 transformers
+all_clip>=1.0.0,<2
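As a small sanity check for the new pin, something like the following (standard library only; the distribution name is taken from the requirement line above) can confirm the installed version:

```python
# Check that the all_clip pin from requirements.txt is satisfied.
from importlib.metadata import version

print(version("all_clip"))  # expected to satisfy >=1.0.0,<2
```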
20 changes: 19 additions & 1 deletion tests/test_clip_benchmark.py
@@ -6,6 +6,7 @@
 from clip_benchmark.cli import run
 import logging
 import torch
+import pytest
 
 class base_args:
     dataset="dummy"
@@ -109,10 +110,27 @@ class linear_probe_args:
     custom_classname_file=None
     distributed=False
 
-def test_base():
 
 def test_linear_probe():
     if torch.cuda.is_available():
         run(linear_probe_args)
     else:
         logging.warning("GPU acceleration is required for linear evaluation to ensure optimal performance and efficiency.")
+
+
+@pytest.mark.parametrize(
+    "full_model_name",
+    [
+        "openai_clip:ViT-B/32",
+        "open_clip:ViT-B-32/laion2b_s34b_b79k",
+        "hf_clip:patrickjohncyh/fashion-clip",
+    ],
+)
+def test_base(full_model_name):
+    model_type, model_name = full_model_name.split(":")
+    model, pretrained = model_name.split("/")
+    base_args.model_type = model_type
+    base_args.model = model
+    base_args.pretrained = pretrained
+    os.environ["CUDA_VISIBLE_DEVICES"] = ""
+    run(base_args)
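To exercise only the new parametrized cases, one option is to invoke pytest programmatically; the `-k` expression and test path below are taken from the diff above, and note that these cases download the listed checkpoints on first run.

```python
# Run only the parametrized test_base cases from tests/test_clip_benchmark.py.
import pytest

pytest.main(["tests/test_clip_benchmark.py", "-k", "test_base"])
```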
