In [None]:
import os
import numpy as np
from tqdm import tqdm
from openai import OpenAI

import logging
logging.getLogger("httpx").setLevel(logging.WARNING)

from dataset import load_comparison_dataset

from misc import HUF_TOKEN, OAI_TOKEN
from misc import EMBEDDING_DIR, EMBEDDING_MODEL, DEFAULT_BATCH_SIZE, MAX_CHARS

In [None]:
# Pass the key straight to the client instead of exporting it via os.environ:
# writing OPENAI_API_KEY into the kernel's environment exposes the secret to every
# subprocess / shell command launched from this notebook. If OAI_TOKEN is None the
# client falls back to an OPENAI_API_KEY already present in the environment.
client = OpenAI(api_key=OAI_TOKEN)

In [None]:
def safe_text(x):
    """Clip ``x`` to at most ``MAX_CHARS`` characters.

    Inputs already within the limit are returned as-is (same object); longer
    inputs are cut down to their first ``MAX_CHARS`` characters.
    """
    if len(x) <= MAX_CHARS:
        return x
    return x[:MAX_CHARS]

def embed_to_shards(prompts, emb_dir=EMBEDDING_DIR, batch_size=DEFAULT_BATCH_SIZE, model=EMBEDDING_MODEL,
                    skip_existing=False):
    """Embed ``prompts`` in batches and save each batch as its own ``.npy`` shard.

    Each shard is named ``emb_<start>_<end>.npy`` (8-digit zero-padded, ``end``
    exclusive) and holds one float32 row per prompt in ``prompts[start:end]``, in
    input order. Prompts are clipped with ``safe_text`` and sent to the embeddings
    endpoint through the notebook-level ``client``.

    Args:
        prompts: Sliceable sequence of prompt strings.
        emb_dir: Directory for the shard files; created if missing.
        batch_size: Number of prompts per API request (and per shard).
        model: Embedding model name passed to ``client.embeddings.create``.
        skip_existing: If True, reuse shard files that already exist instead of
            re-requesting them, so an interrupted run can be resumed. Shard names
            do not encode the model or the prompts, so only enable this when
            ``emb_dir`` was produced from the same inputs. Defaults to False
            (always re-embed and overwrite).

    Returns:
        List of shard file paths, in prompt order.

    Raises:
        RuntimeError: If the API returns a different number of embeddings than
            prompts were sent for a batch.
    """
    os.makedirs(emb_dir, exist_ok=True)

    n = len(prompts)
    shard_paths = []

    for start in tqdm(range(0, n, batch_size), desc="Embedding"):
        end = min(start + batch_size, n)
        shard_path = os.path.join(emb_dir, f"emb_{start:08d}_{end:08d}.npy")

        # Shards written by this function only appear via the atomic rename below,
        # so an existing file from this code is a complete shard.
        if skip_existing and os.path.exists(shard_path):
            shard_paths.append(shard_path)
            continue

        batch = [safe_text(x) for x in prompts[start:end]]
        resp = client.embeddings.create(model=model, input=batch)

        # Order rows by the index the API reports so that row i of the shard
        # always corresponds to prompts[start + i].
        items = sorted(resp.data, key=lambda d: d.index)
        if len(items) != len(batch):
            raise RuntimeError(
                f"Expected {len(batch)} embeddings for prompts [{start}, {end}), got {len(items)}"
            )
        embs = np.array([d.embedding for d in items], dtype=np.float32)

        # Write to a temporary file and rename into place so an interruption
        # mid-write cannot leave a truncated shard at the final path. Saving via
        # a file handle stops np.save from appending another ".npy" suffix.
        tmp_path = shard_path + ".tmp"
        with open(tmp_path, "wb") as fh:
            np.save(fh, embs)
        os.replace(tmp_path, shard_path)
        shard_paths.append(shard_path)

    return shard_paths

In [None]:
# Only the prompts (second of the five returned values) are embedded below.
# Unpacking the unused parts into `_` would bind `_` in the user namespace, which
# stops IPython from updating its `_` last-output variable for the rest of the
# session. Use private names instead and drop them so the notebook namespace no
# longer references those objects.
_unused_0, prompts, _unused_2, _unused_3, _unused_4 = load_comparison_dataset(token=HUF_TOKEN)
del _unused_0, _unused_2, _unused_3, _unused_4

In [None]:
# Keep the shard paths for later cells and display only how many were written,
# rather than echoing the whole path list as cell output.
shard_paths = embed_to_shards(prompts)
len(shard_paths)