In [1]:
import os
import sys

# Notebooks have no ``__file__``. Passing the *string* "__file__" to dirname yields "",
# so the project root was always "one level above the kernel's working directory".
# Spell that out so it is not mistaken for a file-relative path.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard so re-running this cell does not keep appending duplicate sys.path entries.
if project_root not in sys.path:
    sys.path.append(project_root)
# Turn off HF tokenizers' internal parallelism. Presumably this avoids fork warnings
# or deadlocks with the multi-worker DataLoaders built later in this notebook.
os.environ["TOKENIZERS_PARALLELISM"] = "false"


In [2]:
from src.dataset.coupert import CoupertDataset
from src.arguments import DataArguments
from utils.embedder import Embedder
from src.modeling.modeling_clip import CLIPForEmbedding
from transformers import AutoTokenizer, AutoModel, AutoProcessor
import torch
from tqdm import tqdm
import numpy as np
from safetensors.torch import save_file, load_file

# Checkpoint to embed with. Its directory name also names the output embeddings file.
model_dir = "../model/CLIP-ViT-L-14-laion2B-s32B-b82K"
# normpath strips a trailing separator, so basename never comes back empty.
# With split("/")[-1], a "…/model/" path would have yielded "../embeddings/.safetensors".
embedding_path = os.path.join(
    "..", "embeddings", os.path.basename(os.path.normpath(model_dir)) + ".safetensors"
)
data_dir = "../data/coupert"

# Number of CUDA devices visible to this kernel (DataParallel below splits batches across them).
torch.cuda.device_count()

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


2


In [3]:
from torch.nn import DataParallel

# NOTE(review): trust_remote_code=True lets the checkpoint directory run its own Python
# code when the processor is loaded. This is only acceptable for a trusted local model.
processor = AutoProcessor.from_pretrained(model_dir, trust_remote_code=True)
# NOTE(review): DataParallel expects the wrapped module to live on the first CUDA device
# before forward. Presumably Embedder moves it; verify.
model = DataParallel(CLIPForEmbedding.from_pretrained(model_dir))

# One independent DataArguments per split. All three read the same directory in "all" mode.
train_config, eval_config, gallery_config = (
    DataArguments(data_dir=data_dir, read_mode="all") for _split in range(3)
)

# Pair each config with the dataset mode it feeds, in train -> eval -> gallery order.
train_dataset, eval_dataset, gallery_dataset = (
    CoupertDataset(split_config, mode=split_mode)
    for split_config, split_mode in zip(
        (train_config, eval_config, gallery_config),
        ("train", "eval", "gallery"),
    )
)


In [4]:
# Prompt template with a "{}" slot for a product title.
# NOTE(review): no cell shown here uses it. collate_fn encodes the raw titles without
# this prefix. Either apply it or remove it.
query_instruction_for_retrieval = "Represent this title of product for searching similar products. \n {}"


def get_collate_fn(processor):
    """Build a DataLoader ``collate_fn`` that encodes image/title pairs with ``processor``.

    The returned function takes a list of samples, each a mapping with ``"image"``,
    ``"title"`` and ``"global_idx"`` keys. It returns a dict with these entries:

    * ``"text"``: ``input_ids`` and ``attention_mask`` from the processor, padded to
      ``max_length`` and truncated.
    * ``"image"``: ``pixel_values`` from the processor.
    * ``"global_indices"``: the samples' ``global_idx`` values as a ``torch.long`` tensor.
    """

    def collate_fn(batch):
        # Gather each field across the batch, in the order image, title, global_idx.
        images, texts, global_indices = (
            [sample[field] for sample in batch]
            for field in ("image", "title", "global_idx")
        )
        # Disable the image processor's rescaling. This mutates the shared processor
        # object in whichever process runs the collate.
        # NOTE(review): presumably the dataset already yields images scaled to [0, 1];
        # confirm against CoupertDataset.
        processor.image_processor.do_rescale = False
        encoded = processor(
            text=texts,
            images=images,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        return {
            "text": {
                "input_ids": encoded["input_ids"],
                "attention_mask": encoded["attention_mask"],
            },
            "image": {"pixel_values": encoded["pixel_values"]},
            "global_indices": torch.tensor(global_indices, dtype=torch.long),
        }

    return collate_fn


def _make_loader(dataset, batch_size, shuffle, num_workers):
    """Wrap ``dataset`` in a DataLoader that batches with ``get_collate_fn(processor)``.

    Every loader in this notebook uses the same collate function and pinned host memory.
    Only batch size, shuffling and worker count differ per split.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=get_collate_fn(processor),
        num_workers=num_workers,
        pin_memory=True,
    )


# NOTE(review): train_loader is not consumed by any cell shown here.
train_loader = _make_loader(train_dataset, batch_size=512, shuffle=True, num_workers=32)
# Embedding-time splits are iterated in dataset order (no shuffling).
eval_loader = _make_loader(eval_dataset, batch_size=128, shuffle=False, num_workers=16)
gallery_loader = _make_loader(gallery_dataset, batch_size=128, shuffle=False, num_workers=16)


In [5]:
# Embed the eval and gallery splits with the same model, then write both to one file.
# Create the output directory first, so a missing "../embeddings" fails immediately
# instead of after minutes of embedding.
os.makedirs(os.path.dirname(os.path.abspath(embedding_path)), exist_ok=True)

embedder = Embedder(model, processor=processor, tokenizer=None)
# NOTE(review): the saved run logged missing image files and was interrupted by hand.
# Confirm how CoupertDataset represents a failed image load before trusting the output.
for split_loader, split_name in ((eval_loader, "eval"), (gallery_loader, "gallery")):
    embedder.embed(split_loader, split_name)
embedder.save_embeddings(embedding_path)

  0%|          | 0/128 [00:00<?, ?it/s]ERROR:root:Error loading image: /mnt/weeddata/imgs_coupert/1/1015184901-276.jpg
ERROR:root:[Errno 2] No such file or directory: '/mnt/weeddata/imgs_coupert/1/1015184901-276.jpg'
 72%|███████▏  | 92/128 [01:16<00:30,  1.20it/s]


KeyboardInterrupt: 