In [1]:
from transformers import AutoModel, AutoTokenizer
from torch import Tensor
import torch

# Load the tokenizer and the encoder from the same checkpoint name, then place
# the encoder on the CUDA device.
tokenizer = AutoTokenizer.from_pretrained("thenlper/gte-base")
model = AutoModel.from_pretrained("thenlper/gte-base")
model = model.to("cuda")


def average_pool(last_hidden_states: Tensor, attention_mask: Tensor) -> Tensor:
    """Mean-pool token vectors along dim 1, ignoring masked-out positions.

    Args:
        last_hidden_states: Token-level vectors. Dim 1 is the axis that gets
            reduced (presumably (batch, seq_len, hidden) as produced by the
            encoder -- confirm against the caller).
        attention_mask: Mask with one fewer trailing dim than
            ``last_hidden_states``; zero entries mark positions to ignore,
            nonzero entries mark positions to keep.

    Returns:
        For each row, the sum of the kept token vectors divided by the number
        of kept tokens. A row whose mask is entirely zero yields an all-zero
        vector rather than NaN.
    """
    # Add a trailing axis so the mask broadcasts across the feature dimension.
    keep = attention_mask[..., None].bool()
    last_hidden = last_hidden_states.masked_fill(~keep, 0.0)
    # Clamp the per-row token count to at least 1: a fully masked row then
    # computes 0 / 1 instead of 0 / 0.
    token_counts = attention_mask.sum(dim=1).clamp(min=1)[..., None]
    return last_hidden.sum(dim=1) / token_counts


def process_batch(batch_of_text):
    """Embed a batch of text with the module-level ``model`` and ``tokenizer``.

    Args:
        batch_of_text: Text input passed straight to ``tokenizer``; the loop
            in this notebook passes a list of strings.

    Returns:
        The mean-pooled embeddings returned by ``average_pool`` (one vector
        per input row), located on the model's device. They are computed
        under ``torch.no_grad()``, so they carry no gradient history.
    """
    # Evaluation mode, so layers such as dropout behave deterministically.
    model.eval()
    # Follow the model's own placement instead of hard-coding a device name,
    # so the inputs always land on the same device as the weights.
    device = model.device
    with torch.no_grad():
        batch_dict = tokenizer(
            batch_of_text,
            max_length=512,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
        outputs = model(**batch_dict)
        embeddings = average_pool(
            outputs.last_hidden_state, batch_dict["attention_mask"]
        )
        return embeddings

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os, json

from utils.data.data_module import DataModule

# The config path below is relative, so step up one level when the current
# working directory ends in "src".
if os.getcwd().endswith("src"):
    os.chdir("..")

# Decode as UTF-8 explicitly so parsing does not depend on the platform's
# default locale encoding.
with open("configs/datasets/id_dataset.json", "r", encoding="utf-8") as f:
    # The file must hold a JSON object: its keys are unpacked as keyword
    # arguments to DataModule below.
    dataset_config = json.load(f)
datamodule = DataModule(**dataset_config)

Parsing animes...: 100%|██████████| 12294/12294 [00:01<00:00, 11621.06it/s]
Parsing users...: 100%|██████████| 73515/73515 [00:39<00:00, 1842.06it/s]
Resetting Train to k=0 ...: 100%|██████████| 48669/48669 [00:13<00:00, 3740.22it/s]


Number of Users: 54077, Hash[:8]: 9f0cd3, Hash: 9f0cd3119bd9ee7279856737c33aebb8
Total Animes: 12294, Total Users: 54077


In [3]:
import tqdm

# Embed the anime names in consecutive chunks of up to 64 entries and collect
# each chunk's embeddings as a NumPy array on the host.
all_embeddings = []
n_anime = datamodule.max_anime_count
for start in tqdm.tqdm(range(0, n_anime, 64)):
    # The last chunk may be shorter than 64; never index past n_anime - 1.
    stop = min(start + 64, n_anime)
    names = [
        datamodule.canonical_anime_mapping[idx].name for idx in range(start, stop)
    ]
    chunk_embeddings = process_batch(names)
    all_embeddings.append(chunk_embeddings.detach().cpu().numpy())

100%|██████████| 193/193 [00:15<00:00, 12.58it/s]


In [4]:
import numpy as np

# Join the per-chunk arrays row-wise into a single matrix.
stacked_embeddings = np.vstack(all_embeddings)

# Make sure the output directory exists (a pre-existing one is fine), then
# write the matrix to disk in .npy format.
os.makedirs("data/embeddings", exist_ok=True)
np.save("data/embeddings/gte-base_titles.npy", stacked_embeddings)

# Report the shape and dtype of the saved matrix, one per line.
print(stacked_embeddings.shape, stacked_embeddings.dtype, sep="\n")

(12294, 768)
float32
