In [1]:
import torch
import argparse
import json
import os
from utils import get_embeddings, run_retrieval, evaluate_retrieval, setup_logging
from data.utils import get_dataset
from models import load_model
import torch.nn.functional as F
import logging

# Argument definitions for the retrieval evaluation. The parser is only built
# in this cell; it is parsed in the next one.
parser = argparse.ArgumentParser()
# Notes left by the original author: 'ViT-B-16' and 'laion2b_s34b_b88k' are
# presumably the upstream defaults for --model-name / --pretrained (confirm
# against the original evaluation script). This notebook evaluates the model
# named below instead.
model_name = "Marqo/marqo-fashionSigLIP"

# Args for datasets
parser.add_argument("--data-dir", type=str, default="./data/", help='Data directory.')
parser.add_argument('--dataset-config', default='./configs/fashion200k.json', help='Dataset config file.')
parser.add_argument("--batch-size", type=int, default=512)
parser.add_argument("--num-workers", type=int, default=4)

# Args for models
parser.add_argument('--model-name', type=str, default=model_name, help='Model name.')
parser.add_argument('--run-name', type=str, default=model_name, help='Run name.')
parser.add_argument("--pretrained", type=str, default=None, help='Pretrained name.')
# NOTE(review): the default below is a machine-specific absolute path; pass
# --cache-dir when running on another machine.
parser.add_argument('--cache-dir', default="/home/jupyter/cache", help='Cache directory for models and datasets.')
parser.add_argument('--device', default='cuda', help='Device to use for inference.')
parser.add_argument("--query-prefix", type=str, default='', help="Query prefix if required (ex. 'description: ')")

# Args for evaluations
# type=int matters: without it, values given on the command line are stored as
# strings, and max(args.Ks) would compare them lexicographically
# (e.g. "--Ks 5 20" would yield '5' instead of 20).
parser.add_argument('--Ks', type=int, default=[1, 10], nargs='+', help='Ks for metrics.')
parser.add_argument("--overwrite-embeddings", action="store_true", default=False)
parser.add_argument("--overwrite-retrieval", action="store_true", default=False)
parser.add_argument("--output-dir", type=str, default='./results')

_StoreAction(option_strings=['--output-dir'], dest='output_dir', nargs=None, const=None, default='./results', type=<class 'str'>, choices=None, help=None, metavar=None)

In [2]:
import sys  # no longer used by this cell; kept so any later cell that references `sys` still finds it

# Parse an explicit empty argument list so every option takes its default.
# This yields the same namespace as clearing sys.argv first, but leaves
# sys.argv intact: an empty sys.argv can break code that reads sys.argv[0].
args = parser.parse_args([])

In [3]:
setup_logging()

# Results for this run go under <output-dir>/<dataset config file stem>/<run name>.
config_stem = os.path.basename(args.dataset_config).replace('.json','')
args.output_dir = os.path.join(args.output_dir, config_stem, args.run_name)
if os.path.exists(args.output_dir):
    logging.warning(f'Output directory {args.output_dir} exists. Ignore this if it is expected.')
else:
    os.makedirs(args.output_dir, exist_ok=True)

# Snapshot the arguments next to the results. This is written before
# embeddings_path and the parsed dataset config are attached to args, so the
# file holds only the values present at this point.
args_json_path = os.path.join(args.output_dir, 'args.json')
with open(args_json_path, 'w') as args_file:
    json.dump(vars(args), args_file, indent=4)
args.embeddings_path = os.path.join(args.output_dir, "embeddings.pt")

# Replace the dataset config *path* stored on args with the parsed config itself.
with open(args.dataset_config, 'r') as config_file:
    parsed_config = json.load(config_file)
args.dataset_config = parsed_config
logging.info("Dataset: " + args.dataset_config["name"])


2025-01-07,00:43:28 | INFO | Dataset: Fashion200k


In [4]:
# Load the model along with the `preprocess` and `tokenizer` objects that later
# cells pass to get_dataset and run_retrieval.
model, preprocess, tokenizer = load_model(args)

# Move the model to the device selected by --device before any embeddings are computed.
model = model.to(args.device)


2025-01-07,00:44:18 | INFO | Created a temporary directory at /var/tmp/tmp7cfcouiu
2025-01-07,00:44:18 | INFO | Writing /var/tmp/tmp7cfcouiu/_remote_module_non_scriptable.py
downloading https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_pytorch_model.bin to /home/jupyter/cache/tmpp4intzrl
2025-01-07,00:44:18 | INFO | downloading https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_pytorch_model.bin to /home/jupyter/cache/tmpp4intzrl


Downloading (…)ip_pytorch_model.bin:   0%|          | 0.00/813M [00:00<?, ?B/s]

Storing https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_pytorch_model.bin in cache at /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/blobs/f51a245681b2a027c26c1684a89dbd27cbd2819fca2fc2d4c697208d33d46400
2025-01-07,00:44:37 | INFO | Storing https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_pytorch_model.bin in cache at /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/blobs/f51a245681b2a027c26c1684a89dbd27cbd2819fca2fc2d4c697208d33d46400
Creating pointer from ../../blobs/f51a245681b2a027c26c1684a89dbd27cbd2819fca2fc2d4c697208d33d46400 to /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/snapshots/e5619578fd528afa0bf88d8fae37748336a57fa2/open_clip_pytorch_model.bin
2025-01-07,00:44:37 | INFO | Creating pointer from ../../blobs/f51a245681b2a027c26c1684a89dbd27cbd2819fca2fc2d4c697208d33d46400 to /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/snapshots/e5619578fd528afa0bf88d8fae37748336a57fa2/open_clip_pytorch_model.bi

Downloading open_clip_config.json:   0%|          | 0.00/881 [00:00<?, ?B/s]

Storing https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_config.json in cache at /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1
2025-01-07,00:44:38 | INFO | Storing https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_config.json in cache at /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1
Creating pointer from ../../blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1 to /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/snapshots/e5619578fd528afa0bf88d8fae37748336a57fa2/open_clip_config.json
2025-01-07,00:44:38 | INFO | Creating pointer from ../../blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1 to /home/jupyter/cache/models--Marqo--marqo-fashionSigLIP/snapshots/e5619578fd528afa0bf88d8fae37748336a57fa2/open_clip_config.json
2025-01-07,00:44:38 | INFO | Loaded hf-hub:Marqo/marqo-fashionSigLIP model config.
2025-01-07,00:44:41 | INFO | Loadin

Downloading open_clip_config.json:   0%|          | 0.00/881 [00:00<?, ?B/s]

Storing https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_config.json in cache at /home/jupyter/cache/hub/models--Marqo--marqo-fashionSigLIP/blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1
2025-01-07,00:44:41 | INFO | Storing https://huggingface.co/Marqo/marqo-fashionSigLIP/resolve/main/open_clip_config.json in cache at /home/jupyter/cache/hub/models--Marqo--marqo-fashionSigLIP/blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1
Creating pointer from ../../blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1 to /home/jupyter/cache/hub/models--Marqo--marqo-fashionSigLIP/snapshots/e5619578fd528afa0bf88d8fae37748336a57fa2/open_clip_config.json
2025-01-07,00:44:41 | INFO | Creating pointer from ../../blobs/15eb27d98adcbfaf9e0915a5778a58eaffbb43a1 to /home/jupyter/cache/hub/models--Marqo--marqo-fashionSigLIP/snapshots/e5619578fd528afa0bf88d8fae37748336a57fa2/open_clip_config.json
downloading https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/tokenizer_config.json to /home/ju

Downloading tokenizer_config.json:   0%|          | 0.00/20.6k [00:00<?, ?B/s]

Storing https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/tokenizer_config.json in cache at /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/blobs/ed24b6426d2a710bedd0ba400fb180d9f50bab3b
2025-01-07,00:44:41 | INFO | Storing https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/tokenizer_config.json in cache at /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/blobs/ed24b6426d2a710bedd0ba400fb180d9f50bab3b
Creating pointer from ../../blobs/ed24b6426d2a710bedd0ba400fb180d9f50bab3b to /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/snapshots/41f575766f40e752fdd1383e9565b7f02388c1c4/tokenizer_config.json
2025-01-07,00:44:41 | INFO | Creating pointer from ../../blobs/ed24b6426d2a710bedd0ba400fb180d9f50bab3b to /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/snapshots/41f575766f40e752fdd1383e9565b7f02388c1c4/tokenizer_config.json
downloading https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/tokenizer.json to /home/jupyter/cache/tmp68hv61g5
2025-01-07,00:44:42 | INFO | 

Downloading tokenizer.json:   0%|          | 0.00/2.42M [00:00<?, ?B/s]

Storing https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/tokenizer.json in cache at /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/blobs/06da9e637f9d1a09573910bbcc0f64394f4a89a0
2025-01-07,00:44:42 | INFO | Storing https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/tokenizer.json in cache at /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/blobs/06da9e637f9d1a09573910bbcc0f64394f4a89a0
Creating pointer from ../../blobs/06da9e637f9d1a09573910bbcc0f64394f4a89a0 to /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/snapshots/41f575766f40e752fdd1383e9565b7f02388c1c4/tokenizer.json
2025-01-07,00:44:42 | INFO | Creating pointer from ../../blobs/06da9e637f9d1a09573910bbcc0f64394f4a89a0 to /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/snapshots/41f575766f40e752fdd1383e9565b7f02388c1c4/tokenizer.json
downloading https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/special_tokens_map.json to /home/jupyter/cache/tmpocatjdvx
2025-01-07,00:44:42 | INFO | downloading https:/

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Storing https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/special_tokens_map.json in cache at /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/blobs/cc26c82999019b181968187db28e9fbca53a5e52
2025-01-07,00:44:42 | INFO | Storing https://huggingface.co/timm/ViT-B-16-SigLIP/resolve/main/special_tokens_map.json in cache at /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/blobs/cc26c82999019b181968187db28e9fbca53a5e52
Creating pointer from ../../blobs/cc26c82999019b181968187db28e9fbca53a5e52 to /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/snapshots/41f575766f40e752fdd1383e9565b7f02388c1c4/special_tokens_map.json
2025-01-07,00:44:42 | INFO | Creating pointer from ../../blobs/cc26c82999019b181968187db28e9fbca53a5e52 to /home/jupyter/cache/models--timm--ViT-B-16-SigLIP/snapshots/41f575766f40e752fdd1383e9565b7f02388c1c4/special_tokens_map.json


In [5]:
# Build the document dataset from the parsed arguments, tokenizer and preprocess objects.
# item_ID is passed to run_retrieval later on (presumably the document identifiers;
# confirm in data.utils).
doc_dataset, item_ID = get_dataset(args, tokenizer, preprocess)
# Log the dataset size with thousands separators.
logging.info(f"Number of document rows: {len(doc_dataset):,}")


2025-01-07,00:44:54 | INFO | Loading dataset from huggingface.
D2: <dict object at 0x7f48f5814c80>
T4: <class 'datasets.data_files.DataFilesDict'>
# T4
D2: <dict object at 0x7f48f6a529b0>
T4: <class 'datasets.data_files.DataFilesList'>
# T4
T4: <class 'datasets.data_files.Url'>
# T4
D2: <dict object at 0x7f48f5879d20>
# D2
# D2
# D2
2025-01-07,00:44:55 | INFO | Using custom data configuration Marqo--fashion200k-e452c40783d60867
2025-01-07,00:44:55 | INFO | Overwrite dataset info from restored data version if exists.
2025-01-07,00:44:55 | INFO | Loading Dataset info from /home/jupyter/cache/datasets/Marqo___parquet/Marqo--fashion200k-e452c40783d60867/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7
2025-01-07,00:44:55 | INFO | Loading Dataset info from /home/jupyter/cache/datasets/Marqo___parquet/Marqo--fashion200k-e452c40783d60867/0.0.0/14a00e99c0d15a23649d0db8944380ac81082d4b021f398733dd84f3a6c569a7


  0%|          | 0/1 [00:00<?, ?it/s]

T4: <class 'data.utils.Transform'>
# T4
D2: <dict object at 0x7f48e3c19460>
T4: <class 'open_clip.tokenizer.HFTokenizer'>
# T4
D2: <dict object at 0x7f48f58052d0>
T4: <class 'transformers.models.t5.tokenization_t5_fast.T5TokenizerFast'>
# T4
D2: <dict object at 0x7f48f58a7c30>
T4: <class 'tokenizers.Tokenizer'>
# T4
T4: <class 'tokenizers.models.Model'>
# T4
D2: <dict object at 0x7f48f585a550>
D2: <dict object at 0x7f48f585a500>
D2: <dict object at 0x7f48f585a4b0>
# D2
D2: <dict object at 0x7f48f585a460>
# D2
D2: <dict object at 0x7f48f585a410>
# D2
D2: <dict object at 0x7f48f585a3c0>
# D2
D2: <dict object at 0x7f48f585a370>
# D2
D2: <dict object at 0x7f48f585a320>
# D2
D2: <dict object at 0x7f48f585a2d0>
# D2
D2: <dict object at 0x7f48f585a1e0>
# D2
D2: <dict object at 0x7f48f585a140>
# D2
D2: <dict object at 0x7f48f585a190>
# D2
D2: <dict object at 0x7f48f585a230>
# D2
D2: <dict object at 0x7f48f585a280>
# D2
D2: <dict object at 0x7f48f5868460>
# D2
D2: <dict object at 0x7f48f5868550

In [6]:
# Check whether an embeddings file already exists for this run; the result is shown as the cell output.
os.path.isfile(args.embeddings_path)

False

In [7]:
# Reuse the saved document embeddings only when the file exists and
# --overwrite-embeddings was not requested; otherwise compute and save them.
reuse_saved_embeddings = os.path.isfile(args.embeddings_path) and not args.overwrite_embeddings
if reuse_saved_embeddings:
    logging.info("Loading embeddings of documents")
    embeddings = torch.load(args.embeddings_path)
else:
    logging.info("Getting embeddings of documents")
    embeddings = get_embeddings(model, doc_dataset, args)
    torch.save(embeddings, args.embeddings_path)

2025-01-07,00:46:13 | INFO | Getting embeddings of documents
  1%|          | 4/394 [00:47<1:12:56, 11.22s/it]

In [8]:
# Run every task listed in the dataset config: load the ground truth, obtain
# retrieval results (from disk when available, otherwise by running retrieval),
# evaluate them and write the metrics as JSON into the task directory.
for task in args.dataset_config["tasks"]:
    task_dir = os.path.join(args.output_dir, task['name'])
    # exist_ok=True already makes this a no-op when the directory is present.
    os.makedirs(task_dir, exist_ok=True)
    logging.info(f'Task: {json.dumps(task, indent=4)}')

    # Values that do not change across query columns of the same task.
    gt_dir = os.path.join(args.data_dir, args.dataset_config["name"], 'gt_query_doc')
    doc_cols = task['doc_col']
    doc_tag = '+'.join(doc_cols)

    for query_col in task["query_col"]:
        # Shared "<query>-<doc cols>" suffix of the ground-truth, retrieval and result file names.
        pair_tag = f"{query_col}-{doc_tag}"

        gt_results_path = os.path.join(gt_dir, f"ground_truth_{pair_tag}.json")
        # Fail with the offending path instead of a bare assert, which gives no
        # message and is stripped when Python runs with -O.
        if not os.path.exists(gt_results_path):
            raise FileNotFoundError(f"Ground-truth file not found: {gt_results_path}")

        # Ground-truth query-doc
        logging.info("Loading ground truth")
        with open(gt_results_path, "r") as f:
            gt_results = json.load(f)
        # Keys of the ground-truth file are the queries to run
        # (original author's note: randomly sampled queries, up to 2000).
        test_queries = list(gt_results.keys())

        # Running retrieval (or reusing results saved by an earlier run)
        retrieval_path = os.path.join(task_dir, f"retrieved_{pair_tag}.json")
        if os.path.exists(retrieval_path) and not args.overwrite_retrieval:
            logging.info("Loading retrieval")
            with open(retrieval_path, "r") as f:
                retrieval_results = json.load(f)
        else:
            logging.info("Running retrieval")
            if len(doc_cols) == 1:
                doc_embeddings = embeddings[doc_cols[0]].to(args.device)
            else:
                # Multi-field documents: weighted sum of the per-field embeddings,
                # re-normalized along the last dimension.
                doc_weights = task.get('doc_weights')
                if doc_weights is None or len(doc_weights) != len(doc_cols):
                    raise ValueError(
                        "Must provide the same number of weights for multi-field documents as the number of multi-fields."
                    )
                weighted_fields = torch.stack(
                    [w * embeddings[c] for c, w in zip(doc_cols, doc_weights)], dim=1
                )
                doc_embeddings = F.normalize(weighted_fields.sum(1), dim=-1).to(args.device)
            retrieval_results = run_retrieval(test_queries, item_ID, doc_embeddings, tokenizer, model, max(args.Ks), args)
            with open(retrieval_path, "w") as f:
                json.dump(retrieval_results, f, indent=4)

        # Evaluation Starts
        logging.info("Evaluation Starts")
        output_results = evaluate_retrieval(gt_results, retrieval_results, args)
        output_json = os.path.join(task_dir, f"result_{pair_tag}.json")
        output_json_dict = json.dumps(output_results, indent=4)
        logging.info(output_json_dict)
        with open(output_json, 'w') as f:
            f.write(output_json_dict)

2025-01-07,00:07:58 | INFO | Task: {
    "name": "text-to-image",
    "query_col": [
        "text"
    ],
    "doc_col": [
        "image"
    ]
}
2025-01-07,00:07:58 | INFO | Loading ground truth
2025-01-07,00:07:58 | INFO | Running retrieval
100%|██████████| 2000/2000 [01:32<00:00, 21.56it/s]
2025-01-07,00:09:31 | INFO | Evaluation Starts
2025-01-07,00:09:31 | INFO | For evaluation, we ignore identical query and document ids (default), please explicitly set ``ignore_identical_ids=False`` to ignore this.
2025-01-07,00:09:31 | INFO | 

2025-01-07,00:09:31 | INFO | NDCG@1: 0.0630
2025-01-07,00:09:31 | INFO | NDCG@10: 0.1365
2025-01-07,00:09:31 | INFO | 

2025-01-07,00:09:31 | INFO | MAP@1: 0.0624
2025-01-07,00:09:31 | INFO | MAP@10: 0.1075
2025-01-07,00:09:31 | INFO | 

2025-01-07,00:09:31 | INFO | Recall@1: 0.0624
2025-01-07,00:09:31 | INFO | Recall@10: 0.2297
2025-01-07,00:09:31 | INFO | 

2025-01-07,00:09:31 | INFO | P@1: 0.0630
2025-01-07,00:09:31 | INFO | P@10: 0.0238
2025-01-07,0