In [1]:
"""Generated by O1 and modified. This is meant to help make sure that we aren't fucking up embeddings. For each dataset, for each pair of models
the embeddings should NOT be the same. At he same time, the chunks should be the same as should the IDs and everything else.
"""
import itertools
from pathlib import Path

import torch
from pydantic import BaseModel
from pydantic import Field
from safetensors import safe_open
from tqdm import tqdm


DATASETS = [
    # (numbers are counts for documents, there may be some longer documents -> slightly more chunks)
    "arguana",  # 10K
    "fiqa",  # 50K -> 20K
    "scidocs",  # 25K -> 20K
    "nfcorpus",  # 5K
    "hotpotqa",  # 100K -> 20K
    "trec-covid",  # too much -> 20K
]

MODEL_NAMES = [
    "Salesforce/SFR-Embedding-Mistral",
    "WhereIsAI/UAE-Large-V1",
    "BAAI/bge-base-en-v1.5",
    "BAAI/bge-large-en-v1.5",
    "BAAI/bge-small-en-v1.5",
    "intfloat/e5-base-v2",
    "intfloat/e5-large-v2",
    "intfloat/e5-small-v2",
    "thenlper/gte-base",
    "thenlper/gte-large",
    "thenlper/gte-small",
    "sentence-transformers/gtr-t5-base",
    "sentence-transformers/gtr-t5-large",
    "mixedbread-ai/mxbai-embed-large-v1",
    "sentence-transformers/sentence-t5-base",
    "sentence-transformers/sentence-t5-large",
    "text-embedding-3-large",  # openai
    "text-embedding-3-small",  # openai
]
# MODEL_NAMES = [m.replace("/", "_") for m in MODEL_NAMES] # we store them this way


# NOTE: copied from chunk_dataset.py and elsewhere
class Chunk(BaseModel):
    id: str = Field(alias="id")
    doc_id: str = Field(alias="doc_id")
    index_in_doc: int = Field(alias="index_in_doc")
    text: str = Field(alias="text")


def load_safetensors_embeddings(filepath: Path) -> torch.Tensor:
    if not filepath.exists():
        return None
    with safe_open(filepath.as_posix(), framework="pt", device="cpu") as f:
        # Should contain a single key "embeddings"
        return f.get_tensor("embeddings")


def load_metadata(filepath: Path) -> list[Chunk]:
    with open(filepath) as f:
        return [Chunk.model_validate_json(line) for line in f if len(line.strip()) > 0]


def compare_metadata(meta1: list[Chunk], meta2: list[Chunk]) -> bool:
    """Return if metadata is the same."""
    if len(meta1) != len(meta2):
        return False
    for c1, c2 in zip(meta1, meta2, strict=False):
        if (
            c1.id != c2.id
            or c1.doc_id != c2.doc_id
            or c1.index_in_doc != c2.index_in_doc
            or c1.text != c2.text
        ):
            return False
    return True


def model2model_dimension(model_name: str) -> int:
    """Helper: get the size of the embedding dimension vector (1D, usually something like 768-4096)."""
    # Miscellaneous (HF)
    if "/" in model_name:
        assert model_name.count("/") == 1
        model_name = model_name.split("/")[-1]
    if model_name == "SFR-Embedding-Mistral":
        return 4096
    if model_name == "UAE-Large-V1" or model_name == "mxbai-embed-large-v1":
        return 1024
    # BGE Models (HF)
    if model_name == "bge-base-en-v1.5":
        return 768
    if model_name == "bge-large-en-v1.5":
        return 1024
    if model_name == "bge-small-en-v1.5":
        return 384
    #  E5 Models (HF)
    if model_name == "e5-base-v2":
        return 768
    if model_name == "e5-large-v2":
        return 1024
    if model_name == "e5-small-v2":
        return 384
    # GTE Models (HF)
    if model_name == "gte-base":
        return 768
    if model_name == "gte-large":
        return 1024
    if model_name == "gte-small":
        return 384
    # GTR-T5 Models (HF)
    if (
        model_name == "gtr-t5-base"
        or model_name == "gtr-t5-large"
        or model_name == "sentence-t5-base"
        or model_name == "sentence-t5-large"
    ):
        return 768
    # OpenAI Models
    if model_name == "text-embedding-3-large":
        return 3072
    if model_name == "text-embedding-3-small":
        return 1536
    # NOTE: cohere may be supported in THE FUTURE
    raise ValueError(f"Unsupported model: {model_name}")


def get_model_files(model_dir: Path) -> list[Path]:
    return [
        # embeddings
        model_dir / "embeddings_corpus_train.safetensors",
        model_dir / "embeddings_corpus_validation.safetensors",
        model_dir / "embeddings_queries_train.safetensors",
        model_dir / "embeddings_queries_validation.safetensors",
        # metadata
        model_dir / "metadatas_corpus_train.jsonl",
        model_dir / "metadatas_corpus_validation.jsonl",
        model_dir / "metadatas_queries_train.jsonl",
        model_dir / "metadatas_queries_validation.jsonl",
    ]


def get_reversed_model_files(model_dir: Path) -> list[Path]:
    return [
        # reverse because sometimes we do "corpus_embeddings..."
        #
        # saftensors
        model_dir / "corpus_train_embeddings.safetensors",
        model_dir / "queries_train_embeddings.safetensors",
        model_dir / "corpus_validation_embeddings.safetensors",
        model_dir / "queries_validation_embeddings.safetensors",
        # jsonls
        model_dir / "corpus_train_metadatas.jsonl",
        model_dir / "queries_train_metadatas.jsonl",
        model_dir / "corpus_validation_metadatas.jsonl",
        model_dir / "queries_validation_metadatas.jsonl",
    ]

In [2]:
def make_sure_individually_ok(root_dir: Path) -> list[tuple[str, str, str]]:
    skipped_model_pairs: list[
        tuple[str, str, str]
    ] = []  # (dataset_name, model_name, reason)
    cartesian_product = list(itertools.product(DATASETS, MODEL_NAMES))
    for dataset, model_name in tqdm(cartesian_product):
        model_dir = root_dir / model_name.replace("/", "_") / dataset
        if not model_dir.exists():
            skipped_model_pairs.append((dataset, model_name, "missing dir"))
            continue
        model_files = get_model_files(model_dir)
        tensors_files = model_files[:4]
        meta_files = model_files[4:]
        if any(not f.exists() for f in model_files):
            # exactly subset of
            # ├── corpus_embeddings.safetensors
            # ├── corpus_metadatas.jsonl
            # ├── corpus_train_embeddings.safetensors
            # ├── corpus_train_metadatas.jsonl
            # ├── corpus_validation_embeddings.safetensors
            # ├── corpus_validation_metadatas.jsonl
            # ├── queries_embeddings.safetensors
            # ├── queries_metadatas.jsonl
            # ├── queries_train_embeddings.safetensors
            # ├── queries_train_metadatas.jsonl
            # ├── queries_validation_embeddings.safetensors
            # └── queries_validation_metadatas.jsonl
            model_files = get_reversed_model_files(model_dir)
            if any(not f.exists() for f in model_files):
                skipped_model_pairs.append((dataset, model_name, "missing files"))
                continue
            tensors_files = model_files[:4]
            meta_files = model_files[4:]
        else:
            # If you make it here you should not be skipped and should be good if you make it through
            for f_tensor, f_meta in zip(tensors_files, meta_files, strict=False):
                tensors = load_safetensors_embeddings(f_tensor)
                meta = load_metadata(f_meta)
                assert len(tensors.shape) == 2
                assert len(meta) == tensors.shape[0]
                expected_length = model2model_dimension(model_name)
                assert tensors.shape[1] == expected_length
    return skipped_model_pairs

In [3]:
ROOT_DIR1 = Path("/mnt/align3_drive/adrianoh/dl_final_project_embeddings_huggingface")
ROOT_DIR2 = Path("/mnt/align3_drive/adrianoh/dl_final_project_embeddings_openai")

In [4]:
x = make_sure_individually_ok(ROOT_DIR1)
print(
    "\n".join(map(str, [x for x in x if "text-embedding-3-" not in x[1]]))
)  # huggingface
print(len(x) / len(DATASETS) / len(MODEL_NAMES))  # pct -> want to be close to zero ngl

100%|██████████| 108/108 [00:17<00:00,  6.27it/s]

('arguana', 'Salesforce/SFR-Embedding-Mistral', 'missing dir')
('fiqa', 'Salesforce/SFR-Embedding-Mistral', 'missing dir')
('scidocs', 'Salesforce/SFR-Embedding-Mistral', 'missing dir')
('nfcorpus', 'Salesforce/SFR-Embedding-Mistral', 'missing dir')
('hotpotqa', 'Salesforce/SFR-Embedding-Mistral', 'missing dir')
('trec-covid', 'Salesforce/SFR-Embedding-Mistral', 'missing dir')
('trec-covid', 'sentence-transformers/sentence-t5-large', 'missing files')
0.17592592592592593





In [5]:
x = make_sure_individually_ok(ROOT_DIR2)
printout = "\n".join(map(str, [x for x in x if "text-embedding-3-" in x[1]]))  # openai
if "text-embedding-3-" in printout:
    print("WTF")
print(printout)  # should be EMPTY no matter what

100%|██████████| 108/108 [00:00<00:00, 2858.42it/s]







In [16]:
# 1. Check all pairs
def make_sure_pairs_ok(
    root_dir: Path,
    stream: bool = False,
    filter_against: str = None,
    filter_for: str = None,
) -> tuple[list[tuple[str, str, str, str]], list[tuple[str, str, str, str]]]:
    skipped_model_pairs: list[
        tuple[str, str, str, str]
    ] = []  # (dataset_name, model1_name, model2_name, reason)
    bad_model_pairs: list[
        tuple[str, str, str, str]
    ] = []  # (dataset_name, model1_name, model2_name, reason)
    cartesian_product = list(itertools.product(DATASETS, MODEL_NAMES, MODEL_NAMES))
    cartesian_product = [
        x for x in cartesian_product if x[1] != x[2]
    ]  # don't compare the same model to itself
    if stream:
        print("Filtering against: ", filter_against)
        print("Filtering for: ", filter_for)
        print("Starting on ", len(cartesian_product), " pairs")
    cartesian_product = [
        x
        for x in cartesian_product
        if (
            filter_against is None
            or not (filter_against in x[1] or filter_against in x[2])
        )
    ]
    cartesian_product = [
        x
        for x in cartesian_product
        if (filter_for is None or (filter_for in x[1] and filter_for in x[2]))
    ]
    if stream:
        print("Reduced to ", len(cartesian_product), " pairs")
    for dataset_name, model1_name, model2_name in tqdm(cartesian_product):
        assert model1_name != model2_name
        model1_path = root_dir / model1_name.replace("/", "_") / dataset_name
        model2_path = root_dir / model2_name.replace("/", "_") / dataset_name
        # NOTE we need something like these:
        # scidocs
        # │   │   ├── embeddings_corpus_train.safetensors
        # │   │   ├── embeddings_corpus_validation.safetensors
        # │   │   ├── embeddings_queries_train.safetensors
        # │   │   ├── embeddings_queries_validation.safetensors
        # │   │   ├── metadatas_corpus_train.jsonl
        # │   │   ├── metadatas_corpus_validation.jsonl
        # │   │   ├── metadatas_queries_train.jsonl
        # │   │   └── metadatas_queries_validation.jsonl
        model1_files = get_model_files(model1_path)
        if any(not f.exists() for f in model1_files):
            model1_files = get_reversed_model_files(model1_path)
        model2_files = get_model_files(model2_path)
        if any(not f.exists() for f in model2_files):
            model2_files = get_reversed_model_files(model2_path)
        if any(not f.exists() for f in model1_files) or any(
            not f.exists() for f in model2_files
        ):
            skipped_model_pairs.append(
                (dataset_name, model1_name, model2_name, "missing files")
            )
            continue
        assert len(model1_files) == len(model2_files)
        # Make sure that the pairs are OK
        for file1, file2 in zip(model1_files, model2_files, strict=False):
            assert file1.exists() and file2.exists()
            assert file1.name == file2.name
            # 1. Make sure all models are different embedding
            if file1.suffix == ".safetensors":
                tensors1 = load_safetensors_embeddings(file1)
                tensors2 = load_safetensors_embeddings(file2)
                assert tensors1 is not None and tensors2 is not None
                if tensors1.shape == tensors2.shape and torch.allclose(
                    tensors1, tensors2
                ):  # should be DIFFERENT
                    if stream:
                        print(
                            "BAD [EMBEDDINGS]: ",
                            dataset_name,
                            model1_name,
                            model2_name,
                        )
                    bad_model_pairs.append(
                        (
                            dataset_name,
                            model1_name,
                            model2_name,
                            "embeddings all close",
                        )
                    )
            # 2. Make sure that all metadatas match though
            elif file1.suffix == ".jsonl":
                meta1 = load_metadata(file1)
                meta2 = load_metadata(file2)
                assert meta1 is not None and meta2 is not None
                if not compare_metadata(meta1, meta2):  # should be SAME
                    if stream:
                        print(
                            "BAD [METADATA]: ",
                            dataset_name,
                            model1_name,
                            model2_name,
                        )
                    bad_model_pairs.append(
                        (
                            dataset_name,
                            model1_name,
                            model2_name,
                            "metadata not same",
                        )
                    )
    return skipped_model_pairs, bad_model_pairs

In [18]:
skipped, bad = make_sure_pairs_ok(
    ROOT_DIR1, stream=True, filter_against="text-embedding-3"
)  # very low so we stream
print("=" * 50 + " SKIPPED " + "=" * 50)
print(
    "\n".join(
        map(str, [s for s in skipped if "text-embedding-3-" not in (s[1] + s[2])])
    )
)  # concenate is an easy hack for this
print("=" * 50 + " BAD " + "=" * 50)
print("\n".join(map(str, [s for s in bad if "text-embedding-3-" not in (s[1] + s[2])])))

Filtering against:  text-embedding-3
Filtering for:  None
Starting on  1836  pairs
Reduced to  1440  pairs


100%|██████████| 1440/1440 [09:00<00:00,  2.66it/s]

('arguana', 'Salesforce/SFR-Embedding-Mistral', 'WhereIsAI/UAE-Large-V1', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'BAAI/bge-base-en-v1.5', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'BAAI/bge-large-en-v1.5', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'BAAI/bge-small-en-v1.5', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'intfloat/e5-base-v2', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'intfloat/e5-large-v2', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'intfloat/e5-small-v2', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'thenlper/gte-base', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'thenlper/gte-large', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'thenlper/gte-small', 'missing files')
('arguana', 'Salesforce/SFR-Embedding-Mistral', 'sentence-transformers/gtr-t5-base', 'missing files')
('arguana




In [17]:
skipped, bad = make_sure_pairs_ok(ROOT_DIR2, filter_for="text-embedding-3")
print("=" * 50 + " SKIPPED " + "=" * 50)
print("\n".join(map(str, [s for s in skipped if "text-embedding-3-" in (s[1] + s[2])])))
print("=" * 50 + " BAD " + "=" * 50)
print("\n".join(map(str, [s for s in bad if "text-embedding-3-" in (s[1] + s[2])])))

100%|██████████| 12/12 [00:04<00:00,  2.47it/s]






