## Test Marqo Model

In [7]:
import sys
import os

# A notebook kernel has no __file__, so anchor on the kernel's working
# directory explicitly (the previous abspath("__file__") trick resolved to the
# same directory, but only by accident of joining a literal file name to cwd).
current_dir = os.getcwd()

# Parent of the working directory — presumably the project root that holds the
# `train` package imported below; verify if the notebook is moved.
parent_dir = os.path.dirname(current_dir)

# Make the parent directory importable. Guarded so that re-running the cell
# does not keep appending duplicate entries to sys.path.
if parent_dir not in sys.path:
    sys.path.append(parent_dir)

from transformers import AutoModel, AutoProcessor
import torch
from PIL import Image
import requests

from train.marqo_fashionSigLIP import MarqoFashionSigLIP, MarqoFashionSigLIPProcessor


# Absolute filesystem path of the checkpoint to load (the "-L" suffix
# presumably denotes the large variant — confirm against the model card).
model_name = "/data/model/Marqo/marqo-ecommerce-embeddings-L"
# model_name = 'Marqo/marqo-ecommerce-embeddings-B'

# Both wrappers are imported from the project-local train.marqo_fashionSigLIP
# module and are loaded from the same checkpoint location.
model = MarqoFashionSigLIP.from_pretrained(model_name)
processor = MarqoFashionSigLIPProcessor.from_pretrained(model_name)

print("Finish Loading Model")

# Bare expression: as the cell's last statement it renders the module repr
# (the layer tree) as the cell output.
model

# img = Image.open(
#     requests.get(
#         "https://raw.githubusercontent.com/marqo-ai/marqo-ecommerce-embeddings/refs/heads/main/images/dining-chairs.png",
#         stream=True,
#     ).raw
# ).convert("RGB")
# image = [img]
# text = ["dining chairs", "a laptop", "toothbrushes"]
# processed = processor(
#     text=text, images=image, padding="max_length", return_tensors="pt"
# )
# processor.image_processor.do_rescale = False
# with torch.no_grad():
#     image_features = model.get_image_features(processed["pixel_values"], normalize=True)
#     text_features = model.get_text_features(processed["input_ids"], normalize=True)

#     text_probs = (100 * image_features @ text_features.T).softmax(dim=-1)

# print(text_probs)
# [1.0000e+00, 8.3131e-12, 5.2173e-12]

You are using a model of type siglip to instantiate a model of type . This is not supported for all configurations of models and can yield errors.


Finish Loading Model


MarqoFashionSigLIP(
  (model): CustomTextCLIP(
    (visual): TimmModel(
      (trunk): VisionTransformer(
        (patch_embed): PatchEmbed(
          (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
          (norm): Identity()
        )
        (pos_drop): Dropout(p=0.0, inplace=False)
        (patch_drop): Identity()
        (norm_pre): Identity()
        (blocks): Sequential(
          (0): Block(
            (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
            (attn): Attention(
              (qkv): Linear(in_features=1024, out_features=3072, bias=True)
              (q_norm): Identity()
              (k_norm): Identity()
              (attn_drop): Dropout(p=0.0, inplace=False)
              (proj): Linear(in_features=1024, out_features=1024, bias=True)
              (proj_drop): Dropout(p=0.0, inplace=False)
            )
            (ls1): Identity()
            (drop_path1): Identity()
            (norm2): LayerNorm((1024,), eps=1e-06, eleme

In [11]:
import os
import pandas as pd
import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms


class ShopeeDataset(Dataset):
    """Image/text dataset backed by a DataFrame and an image directory.

    Two layouts are supported, selected by ``split``:

    * ``"train"`` — each item is a query/positive/negative triplet. The row
      must provide ``query``, ``image``, ``pos_txt``, ``pos_img``, ``neg_txt``
      and ``neg_img``; for the four ``pos_*``/``neg_*`` columns only element
      ``[0]`` is used.
    * ``"valid"`` — each item is a single listing. The row must provide
      ``title`` and ``image``.

    Args:
        df: DataFrame with one row per item (accessed positionally via iloc).
        img_dir: Directory that the image file names in ``df`` are relative to.
        split: Either ``"train"`` or ``"valid"``.

    Raises:
        ValueError: If ``split`` is not one of the supported values.
    """

    VALID_SPLITS = ("train", "valid")

    def __init__(self, df, img_dir, split: str = "train"):
        # Fail fast: an unknown split used to make __getitem__ return None,
        # which only surfaced later as an obscure error inside the collate fn.
        if split not in self.VALID_SPLITS:
            raise ValueError(
                f"Unknown split {split!r}; expected one of {self.VALID_SPLITS}"
            )

        self.df = df
        self.img_dir = img_dir
        # NOTE(review): RGB conversion happens after the resize/crop; the order
        # is kept as-is so existing embeddings stay reproducible.
        self.transform = transforms.Compose(
            [
                transforms.Resize(256),
                transforms.CenterCrop(224),
                # convert to RGB
                transforms.Lambda(lambda img: img.convert("RGB")),
                transforms.ToTensor(),
            ]
        )
        self.split = split
        self.len = len(self.df)

    def __len__(self):
        """Return the number of rows captured at construction time."""
        return self.len

    def __getitem__(self, idx):
        """Build the item at positional index ``idx`` for the configured split.

        Returns:
            For ``"train"``: ``{"query"|"pos"|"neg": {"text", "image"}}`` where
            each image is the transformed tensor.
            For ``"valid"``: ``{"title", "image", "pil_image", "image_path"}``.

        Raises:
            ValueError: If ``self.split`` is not a supported split.
        """
        row = self.df.iloc[idx]

        if self.split == "train":
            query_img_path = os.path.join(self.img_dir, row["image"])
            pos_img_path = os.path.join(self.img_dir, row["pos_img"][0])
            neg_img_path = os.path.join(self.img_dir, row["neg_img"][0])

            return {
                "query": {
                    "text": row["query"],
                    "image": self._get_image(query_img_path),
                },
                "pos": {
                    "text": row["pos_txt"][0],
                    "image": self._get_image(pos_img_path),
                },
                "neg": {
                    "text": row["neg_txt"][0],
                    "image": self._get_image(neg_img_path),
                },
            }

        if self.split == "valid":
            image_path = os.path.join(self.img_dir, row["image"])

            # Decode the file once and derive the tensor from the same PIL
            # image instead of opening the file a second time.
            pil_img = self.get_pil_image(image_path)

            return {
                "title": row["title"],
                "image": self.transform(pil_img),
                "pil_image": pil_img,
                "image_path": image_path,
            }

        raise ValueError(
            f"Unknown split {self.split!r}; expected one of {self.VALID_SPLITS}"
        )

    def _get_image(self, path):
        """Load the image at ``path`` and return it after ``self.transform``."""
        return self.transform(self.get_pil_image(path))

    def get_pil_image(self, path):
        """Return the fully decoded PIL image stored at ``path``.

        The pixel data is loaded eagerly inside a context manager so the
        underlying file handle is released before returning, rather than
        staying open until garbage collection.
        """
        with Image.open(path) as img:
            img.load()
            return img


def get_collate_fn(processor):
    """Create a DataLoader ``collate_fn`` bound to ``processor``.

    The returned callable expects every sample to be a mapping with an
    ``"image"`` and a ``"title"`` entry, and returns whatever ``processor``
    produces for the whole batch.
    """

    def collate_fn(batch):
        images, texts = [], []
        for sample in batch:
            images.append(sample["image"])
            texts.append(sample["title"])

        # Rescaling is forced on for every batch.
        # NOTE(review): if the samples are already [0, 1] tensors (e.g. the
        # output of ToTensor), this may rescale them a second time — confirm
        # against the processor implementation.
        processor.image_processor.do_rescale = True

        return processor(
            text=texts,
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )

    return collate_fn


def visualize(df, index, col, img_dir: str = "../data/train_images"):
    """Plot a query listing together with the listings predicted for it.

    The query image and title are drawn in the top-left cell of a 5x5 grid; the
    predictions fill the inner 4x4 area in order (at most 16 are shown).

    Args:
        df: DataFrame with ``posting_id``, ``title`` and ``image`` columns plus
            the prediction column ``col``.
        index: Positional row index (``iloc``) of the query listing.
        col: Name of the column holding a list of predicted ``posting_id``s.
        img_dir: Directory the ``image`` file names are relative to.
    """
    # Imported here because the notebook never imports pyplot at top level,
    # which made every call fail with NameError on `plt`.
    import matplotlib.pyplot as plt

    def wrap(text, width=10):
        # Break long titles into fixed-width lines so they fit above a cell.
        return "\n".join(text[k : k + width] for k in range(0, len(text), width))

    row = df.iloc[index]
    preds = list(row[col])

    # Index the frame once instead of scanning it for every prediction.
    by_id = df.drop_duplicates("posting_id").set_index("posting_id")
    matches = by_id.loc[preds]
    titles = matches["title"].tolist()
    image_names = matches["image"].tolist()

    rows = 5
    cols = 5
    fig, ax = plt.subplots(rows, cols, figsize=(20, 20))

    # Hide ticks and frames everywhere; cells without an image stay blank.
    for cell in ax.ravel():
        cell.axis("off")

    with Image.open(os.path.join(img_dir, row.image)) as target_img:
        ax[0, 0].imshow(target_img)
    ax[0, 0].set_title(wrap(row.title), fontsize=12)

    # Row 0 and column 0 are reserved for the query, so the usable grid is
    # (rows - 1) x (cols - 1). The stride must be cols - 1: using cols skipped
    # every fifth prediction (indices 4, 9, 14).
    inner_cols = cols - 1
    capacity = (rows - 1) * inner_cols
    for idx, (name, title) in enumerate(zip(image_names, titles)):
        if idx >= capacity:
            break
        i, j = divmod(idx, inner_cols)
        cell = ax[i + 1, j + 1]
        # imshow copies the pixel data, so the file can be closed right away.
        with Image.open(os.path.join(img_dir, name)) as img:
            cell.imshow(img)
        cell.set_title(wrap(title), fontsize=12)


def compute_f1(col):
    """Return a row-wise F1 scorer comparing ``row[col]`` with ``row["label"]``.

    The score is ``2 * |label ∩ pred| / (|label| + |pred|)``. When both lists
    are empty the score is defined as 0.0 instead of raising
    ZeroDivisionError.
    """

    def f1(row):
        n = len(np.intersect1d(row["label"], row[col]))
        denom = len(row["label"]) + len(row[col])
        return 2 * n / denom if denom else 0.0

    return f1


def compute_recall(col):
    """Return a row-wise recall scorer comparing ``row[col]`` with ``row["label"]``.

    The score is ``|label ∩ pred| / |label|``. An empty label list yields 0.0
    instead of raising ZeroDivisionError.
    """

    def recall(row):
        n = len(np.intersect1d(row["label"], row[col]))
        total = len(row["label"])
        return n / total if total else 0.0

    return recall


def compute_precision(col):
    """Return a row-wise precision scorer comparing ``row[col]`` with ``row["label"]``.

    The score is ``|label ∩ pred| / |pred|``. An empty prediction list (for
    example when no neighbour passes the similarity threshold) yields 0.0
    instead of raising ZeroDivisionError.
    """

    def precision(row):
        n = len(np.intersect1d(row["label"], row[col]))
        total = len(row[col])
        return n / total if total else 0.0

    return precision


def compute_precision_K(col, K):
    """Return a row-wise precision@K scorer for the predictions in ``row[col]``.

    Only the first ``K`` predictions are considered, and the number of them
    found in ``row["label"]`` is divided by ``K`` itself (not by the number of
    predictions actually available).
    """

    def precision(row):
        top_k = row[col][:K]
        hits = np.intersect1d(row["label"], top_k)
        return len(hits) / K

    return precision


def compute_AP(col, N):
    """Return a row-wise average-precision scorer truncated at rank ``N``.

    For every rank ``k`` (1-based, up to ``min(len(pred), N)``) whose
    prediction is in ``row["label"]``, precision@k is accumulated; the sum is
    then divided by ``min(len(pred), N)``.

    NOTE(review): the normaliser is the number of retrieved items, not the
    number of relevant ones as in the textbook definition — confirm this is
    the intended metric.
    """

    def AP(row):
        preds = row[col]
        labels = row["label"]

        # Nothing relevant was retrieved at all: short-circuit to zero.
        if len(np.intersect1d(labels, preds)) == 0:
            return 0

        cutoff = min(len(preds), N)
        total = 0
        for rank in range(1, cutoff + 1):
            if preds[rank - 1] in labels:
                total += compute_precision_K(col, rank)(row)
        return total / cutoff

    return AP


def retrieval(embs, df, chunk_size=4096, threshold=None, topK=None):
    """Retrieve neighbours of every embedding among all embeddings.

    Similarity is the dot product ``embs @ embs.T``, computed in row chunks of
    ``chunk_size`` to bound memory. Exactly one selection mode must be given:

    * ``topK`` — keep the ``topK`` most similar items per row (clamped to the
      number of items so a small corpus does not crash ``torch.topk``).
    * ``threshold`` — keep every item whose similarity is strictly greater
      than ``threshold``, ordered by decreasing similarity.

    Args:
        embs: Array-like of shape (num_items, dim); row ``i`` must correspond
            to row ``i`` of ``df``.
        df: DataFrame providing a ``posting_id`` column.
        chunk_size: Number of query rows scored per step.
        threshold: Similarity cut-off (exclusive), or None.
        topK: Number of neighbours per row, or None.

    Returns:
        A list with one list of ``posting_id`` values per row of ``embs``.

    Raises:
        AssertionError: If neither or both of ``threshold``/``topK`` are given.
    """
    # The progress bar is optional; importing here also removes the dependency
    # on a later notebook cell having imported tqdm before the first call.
    try:
        from tqdm import tqdm
    except ImportError:

        def tqdm(iterable):
            return iterable

    assert (
        threshold is not None or topK is not None
    ), "Either threshold or topK should be provided"
    assert (
        threshold is None or topK is None
    ), "Only one of threshold or topK should be provided"

    # Use the GPU when there is one, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    embs_pt = torch.as_tensor(embs).to(device)

    num_items = embs_pt.shape[0]
    num_chunks = (num_items + chunk_size - 1) // chunk_size
    posting_id = df.posting_id.to_list()
    topk_posting_id = []

    print(f"Chunk Size: {chunk_size}, {num_chunks} chunks")

    try:
        for i in tqdm(range(num_chunks)):
            start = i * chunk_size
            end = min((i + 1) * chunk_size, num_items)
            sim = embs_pt[start:end] @ embs_pt.T

            if topK is not None:
                k = min(topK, num_items)
                indices = torch.topk(sim, k, dim=1).indices.cpu().numpy()
                topk_posting_id.extend(
                    [[posting_id[j] for j in row] for row in indices]
                )
            else:
                # Collect every (row, col) hit of the chunk in one pass so
                # only the hits cross the device boundary, instead of issuing
                # one device-to-host copy per row.
                mask = sim > threshold
                row_idx, col_idx = torch.nonzero(mask, as_tuple=True)
                scores = sim[row_idx, col_idx].cpu().numpy()
                row_idx = row_idx.cpu().numpy()
                col_idx = col_idx.cpu().numpy()

                # nonzero() yields hits in row-major order, so each row's hits
                # are contiguous; split them at the cumulative row counts.
                counts = np.bincount(row_idx, minlength=end - start)
                bounds = np.cumsum(counts)[:-1]
                for row_cols, row_scores in zip(
                    np.split(col_idx, bounds), np.split(scores, bounds)
                ):
                    ranked = row_cols[np.argsort(-row_scores)]
                    topk_posting_id.append([posting_id[j] for j in ranked])
    finally:
        # clean up, also when a chunk fails (e.g. out of memory)
        del embs_pt
        if device.type == "cuda":
            torch.cuda.empty_cache()

    return topk_posting_id

In [12]:
from torch.utils.data import DataLoader

# Metadata table and the directory its image file names are relative to.
df = pd.read_csv("/data/MMRetrieval/data/train.csv")
img_dir = "/data/MMRetrieval/data/train_images"
# "valid" makes ShopeeDataset emit title/image items, which are the keys the
# collate function reads from every sample.
split = "valid"

dataset = ShopeeDataset(df, img_dir, split)
collate_fn = get_collate_fn(processor)

# shuffle=False keeps batches in DataFrame order, so the i-th embedding
# produced downstream lines up with the i-th row (and posting_id) of df.
dataloader = DataLoader(
    dataset,
    batch_size=128,
    shuffle=False,
    num_workers=16,
    collate_fn=collate_fn,
)

In [13]:
from tqdm import tqdm

# Per-batch joint embeddings; replaced by a single stacked array below.
embs = []

model.eval()
model.cuda()
with torch.no_grad():
    for batch in tqdm(dataloader):
        # Move tensor entries to the GPU in place; anything else is left as is.
        for key in list(batch.keys()):
            value = batch[key]
            if isinstance(value, torch.Tensor):
                batch[key] = value.cuda()

        # Both towers run under fp16 autocast.
        with torch.amp.autocast("cuda", torch.float16):
            image_features = model.get_image_features(
                batch["pixel_values"], normalize=True
            )
            text_features = model.get_text_features(
                batch["input_ids"], normalize=True
            )

        # Concatenate image and text features per item, then renormalise the
        # joint vector to unit L2 norm so dot products are cosine similarities.
        joint = torch.cat((image_features, text_features), dim=1)
        emb = joint.detach().cpu().numpy()
        emb /= np.linalg.norm(emb, axis=1, keepdims=True)

        embs.append(emb)

embs = np.concatenate(embs, axis=0)

# Drop the model and processor and release cached GPU memory before the
# retrieval step. Note this makes the cell non-re-runnable without reloading.
del model, processor
torch.cuda.empty_cache()

  0%|          | 0/268 [00:00<?, ?it/s]

100%|██████████| 268/268 [03:35<00:00,  1.24it/s]


In [14]:
# Ground truth: for every row, all posting_ids sharing its label_group.
# It does not depend on the threshold, so it is built once before the sweep
# rather than on every iteration.
tmp = df.groupby("label_group").posting_id.agg("unique").to_dict()
df["label"] = df.label_group.map(tmp)

# Sweep the similarity cut-off and report the mean metrics for each value.
# (The former standalone `threshold = 0.73` was dead: the loop variable
# overwrote it immediately.) After the loop, df holds the predictions and
# metrics of the last threshold only (0.8).
for threshold in np.linspace(0.7, 0.8, 20):
    df["marqo"] = retrieval(embs, df, threshold=threshold, chunk_size=4096)

    df["f1"] = df.apply(compute_f1("marqo"), axis=1)
    df["recall"] = df.apply(compute_recall("marqo"), axis=1)
    df["precision"] = df.apply(compute_precision("marqo"), axis=1)
    df["AP"] = df.apply(compute_AP("marqo", 50), axis=1)

    print(f"Threshold: {threshold}")
    print(f"F1: {df.f1.mean()}")
    print(f"Recall: {df.recall.mean()}")
    print(f"Precision: {df.precision.mean()}")
    print(f"AP@50: {df.AP.mean()}")

Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:12<00:00,  1.37s/it]


Threshold: 0.7
F1: 0.7169715860699555
Recall: 0.7958610525142905
Precision: 0.7736292542971328
AP@50: 0.757403999577966
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:12<00:00,  1.33s/it]


Threshold: 0.7052631578947368
F1: 0.7187369420713211
Recall: 0.7865147777279675
Precision: 0.7848555276786418
AP@50: 0.7690263850097852
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:12<00:00,  1.40s/it]


Threshold: 0.7105263157894737
F1: 0.719146966965128
Recall: 0.7761383939515527
Precision: 0.7954824117160577
AP@50: 0.7801023693409754
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7157894736842105
F1: 0.7196337173414733
Recall: 0.7665292097139313
Precision: 0.8057446026784527
AP@50: 0.7907708731161478
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7210526315789474
F1: 0.7192573051170332
Recall: 0.7560084059511296
Precision: 0.8156266593117298
AP@50: 0.8011919587163536
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7263157894736842
F1: 0.7186044442860963
Recall: 0.7456848270970036
Precision: 0.8249685156600259
AP@50: 0.8110690502541517
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.731578947368421
F1: 0.7173077917063254
Recall: 0.735082747127935
Precision: 0.8343524106640746
AP@50: 0.8209920261169152
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7368421052631579
F1: 0.7152681817973798
Recall: 0.7241369263874514
Precision: 0.843287200971867
AP@50: 0.8304340280625496
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7421052631578947
F1: 0.7129533766821228
Recall: 0.7131644723281786
Precision: 0.8521008155326583
AP@50: 0.8397588415621583
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7473684210526316
F1: 0.7098753393865171
Recall: 0.7016584144223298
Precision: 0.8603454348802744
AP@50: 0.8485581905033526
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7526315789473684
F1: 0.7072072882821211
Recall: 0.6911278858334041
Precision: 0.8685003229152716
AP@50: 0.8572428003719255
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:12<00:00,  1.37s/it]


Threshold: 0.7578947368421053
F1: 0.7036960116689935
Recall: 0.6796932204259352
Precision: 0.8767250907823159
AP@50: 0.8660484545581179
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7631578947368421
F1: 0.6997773026777802
Recall: 0.6680817934394059
Precision: 0.8846056243794971
AP@50: 0.8744638245532933
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.768421052631579
F1: 0.6950464747505527
Recall: 0.6562221047898362
Precision: 0.891643099381259
AP@50: 0.8820693871528933
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7736842105263159
F1: 0.6900677085171764
Recall: 0.6442280773411908
Precision: 0.8986886502333409
AP@50: 0.8897412791588223
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.32s/it]


Threshold: 0.7789473684210526
F1: 0.6843970295678965
Recall: 0.6324254335795592
Precision: 0.9046001692456294
AP@50: 0.8962021497566823
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:12<00:00,  1.36s/it]


Threshold: 0.7842105263157895
F1: 0.6791941892001341
Recall: 0.620877312023803
Precision: 0.9111681462821601
AP@50: 0.9032947162826817
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.7894736842105263
F1: 0.6736200193781967
Recall: 0.6093784185340856
Precision: 0.9171332044300043
AP@50: 0.9098274363726835
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.32s/it]


Threshold: 0.7947368421052632
F1: 0.6677414577320583
Recall: 0.5979179466716646
Precision: 0.9231126443726941
AP@50: 0.9163473692853301
Chunk Size: 4096, 9 chunks


100%|██████████| 9/9 [00:11<00:00,  1.33s/it]


Threshold: 0.8
F1: 0.6613124322752144
Recall: 0.5861098454183068
Precision: 0.9287804559457362
AP@50: 0.922463597129737


In [None]:
# Inspect one poorly scoring query: first its predictions, then its ground
# truth, so the two grids can be compared.
# NOTE(review): `samples` holds index labels while visualize() looks rows up
# positionally; the two agree only for a default RangeIndex — confirm.
index = 1
samples = df.loc[df["f1"] < 0.6].index.to_list()
for col in ("marqo", "label"):
    visualize(df, samples[index], col)