In [1]:
import os
import sys

# Notebooks have no `__file__`, so resolve the project root relative to the current
# working directory (assumed to be the notebook's own directory — TODO confirm if the
# kernel is started elsewhere). The project root is its parent directory.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
# Guard so that re-running this cell does not keep appending duplicates to sys.path.
if project_root not in sys.path:
    sys.path.append(project_root)
# Disable HuggingFace tokenizers' internal parallelism; tokenization later runs inside
# DataLoader worker processes (num_workers > 0), where it would otherwise warn/deadlock.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [2]:
from src.dataset.coupert import CoupertDataset
from src.arguments import DataArguments
from transformers import AutoTokenizer, AutoModel
import torch
from tqdm import tqdm
import numpy as np
from safetensors.torch import save_file, load_file

# Locations (relative to the notebook's working directory) of the local
# BGE-base-en-v1.5 checkpoint and of the Coupert dataset root, which holds
# the train / eval / gallery splits loaded below.
model_dir, data_dir = (
    "../model/BAAI/bge_base_en_v1.5",
    "../data/coupert",
)

  from .autonotebook import tqdm as notebook_tqdm
  check_for_updates()


In [3]:
# Tokenizer and encoder come from the same local checkpoint; device_map="auto"
# lets HF/accelerate place the weights on the available device(s).
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModel.from_pretrained(model_dir, device_map="auto")

# Each split reads raw text from the same directory. A distinct DataArguments
# instance is built per split so no two datasets share one config object.
train_config, eval_config, gallery_config = (
    DataArguments(data_dir=data_dir, read_mode="text") for _split_idx in range(3)
)

# Instantiate the three splits in train -> eval -> gallery order.
train_dataset, eval_dataset, gallery_dataset = (
    CoupertDataset(split_config, mode=split_name)
    for split_config, split_name in zip(
        (train_config, eval_config, gallery_config),
        ("train", "eval", "gallery"),
    )
)

In [4]:
# # test dataset
# for data in tqdm(train_dataset):
#     # print(data)
#     continue

In [5]:
import logging


def get_collate_fn(tokenizer, mode="train", max_length=256):
    """Build a DataLoader ``collate_fn`` that tokenizes the ``"title"`` of each item.

    Args:
        tokenizer: A HuggingFace-style tokenizer callable (e.g. from ``AutoTokenizer``).
        mode: Split name ("train", "eval", "gallery"). Currently unused — every split
            is collated identically; kept so existing callers keep working.
        max_length: Maximum number of tokens per title; longer titles are truncated.

    Returns:
        A function mapping a list of dataset items (dicts with a ``"title"`` key) to
        the tokenizer's PyTorch-tensor output for that batch.
    """

    def collate_fn(batch):
        titles = [item["title"] for item in batch]
        # padding=True pads only to the longest sequence in this batch (dynamic
        # padding), which keeps batches of short titles cheap.
        return tokenizer(
            titles,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=max_length,
        )

    return collate_fn


# Shared DataLoader settings for all three splits.
BATCH_SIZE = 512
NUM_WORKERS = 32


def make_loader(dataset, mode, shuffle=False):
    """Wrap ``dataset`` in a DataLoader that tokenizes titles on the fly.

    Args:
        dataset: One of the Coupert splits.
        mode: Split name forwarded to ``get_collate_fn``.
        shuffle: Whether to shuffle each epoch. Only the training split should; the
            eval/gallery loaders keep dataset order (presumably so row i of the
            computed embeddings lines up with item i — confirm with the consumer).

    Returns:
        A ``torch.utils.data.DataLoader`` over ``dataset``.
    """
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=shuffle,
        collate_fn=get_collate_fn(tokenizer, mode=mode),
        num_workers=NUM_WORKERS,
        # Page-locked host memory speeds up the host-to-GPU copies of each batch.
        pin_memory=True,
    )


train_loader = make_loader(train_dataset, mode="train", shuffle=True)
eval_loader = make_loader(eval_dataset, mode="eval")
gallery_loader = make_loader(gallery_dataset, mode="gallery")

In [6]:
def encode_loader(encoder, loader, desc=None):
    """Embed every batch from ``loader`` and return L2-normalized sentence vectors.

    BGE models use [CLS] pooling: the sentence embedding is the final hidden state at
    position 0. Position -1 must not be used — with right-padded batches it is a [PAD]
    token (or [SEP] for the longest sequence), not a sentence representation.

    Args:
        encoder: HuggingFace encoder whose output has ``last_hidden_state``.
        loader: DataLoader yielding tokenizer outputs (mapping of name -> tensor).
        desc: Optional tqdm progress-bar label.

    Returns:
        float32 ndarray of shape (num_items, hidden_size) with unit-norm rows, in
        loader order.
    """
    chunks = []
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=torch.float16):
        for inputs in tqdm(loader, desc=desc):
            inputs = {k: v.to(encoder.device) for k, v in inputs.items()}
            cls_emb = encoder(**inputs).last_hidden_state[:, 0]
            # Normalize in float32 so the fp16 autocast region does not cost precision.
            cls_emb = torch.nn.functional.normalize(cls_emb.float(), p=2, dim=-1)
            chunks.append(cls_emb.cpu().numpy())
    return np.concatenate(chunks, axis=0)


eval_embs = encode_loader(model, eval_loader, desc="eval")
gallery_embs = encode_loader(model, gallery_loader, desc="gallery")
print(eval_embs.shape)
print(gallery_embs.shape)

100%|██████████| 32/32 [00:14<00:00,  2.24it/s]
100%|██████████| 4860/4860 [14:18<00:00,  5.66it/s]


(16271, 768)
(2488144, 768)


In [7]:
# Cached-embedding output path; create its directory so a fresh checkout does not fail.
emb_path = "../embeddings/embs_bge_base_en.safetensors"
os.makedirs(os.path.dirname(emb_path), exist_ok=True)

# torch.as_tensor wraps the NumPy arrays without copying (the gallery matrix is
# ~2.5M x 768). It also returns the input unchanged if the cell is re-run after
# these names already hold tensors.
eval_embs = torch.as_tensor(eval_embs)
gallery_embs = torch.as_tensor(gallery_embs)

save_file(
    {"eval_embs": eval_embs, "gallery_embs": gallery_embs},
    emb_path,
)