In [1]:
import json
import os
import random
from typing import Annotated, List
from tqdm import tqdm

import bittensor as bt
import openai
import torch
from dotenv import load_dotenv
from loguru import logger
from pydantic import BaseModel

from datasets import load_dataset
from openkaito.protocol import TextEmbeddingSynapse
from openkaito.utils.embeddings import openai_embeddings_tensor
from openkaito.utils.version import get_version

from sentence_transformers import SentenceTransformer


  from .autonotebook import tqdm as notebook_tqdm


In [2]:

load_dotenv()

# for ranking results evaluation
llm_client = openai.OpenAI(
    api_key=os.environ["OPENAI_API_KEY"],
    organization=os.getenv("OPENAI_ORGANIZATION"),
    project=os.getenv("OPENAI_PROJECT"),
    max_retries=3,
)

In [3]:
ds = load_dataset("microsoft/ms_marco", "v1.1", split="test")

In [4]:

total_samples = ds.num_rows
print(f"Total samples: {total_samples}")

for i, row in enumerate(ds):
    print("====================")
    print("query:", row["query"])
    print("passages:", row["passages"])
    if i > 3:
        break

Total samples: 9650
query: does human hair stop squirrels
passages: {'is_selected': [0, 0, 1, 0, 0, 0, 0], 'passage_text': ['We have been feeding our back yard squirrels for the fall and winter and we noticed that a few of them have missing fur. One has a patch missing down his back and under both arms. Also another has some missing on his whole chest. They are all eating and seem to have a good appetite.', 'Critters cannot stand the smell of human hair, so sprinkling a barrier of hair clippings around your garden, or lightly working it into the soil when you plant bulbs, apparently does have some merit. The whole thing kind of makes me laugh. It never occurred to me that we are the ones that stink.', "Spread some human hair around your vegetable and flower gardens. This will scare the squirrels away because humans are predators of squirrels. It is better if the hair hasn't been washed so the squirrels will easily pick up the human scent.", '1 You can sprinkle blood meal around your ga

In [5]:
minilm_model = SentenceTransformer("all-MiniLM-L6-v2")


In [6]:

minilm_embedding = minilm_model.encode(
    ["Hello bittensor subnet openkaito"],
    convert_to_tensor=True,
    normalize_embeddings=True,
)
# print(minilm_embedding)
print(minilm_embedding.shape)

torch.Size([1, 384])


In [7]:
openai_large_precisions = []
minilm_precisions = []


report_interval = 20

for i, row in enumerate(ds):
    query = row["query"]
    passages = row["passages"]
    num_passages = len(passages["is_selected"])
    if sum(passages["is_selected"]) == 0:
        logger.trace(f"{i}: Query {query} has no positive passage, skipping")
        continue
    # positive_idx = passages["is_selected"].index(1)
    selected_indices = (
        torch.tensor(passages["is_selected"], dtype=torch.long).nonzero().squeeze()
    )

    texts = [query + "?"] + passages["passage_text"]
    openai_embeddings = openai_embeddings_tensor(
        llm_client,
        texts,
        dimensions=512,
        model="text-embedding-3-large",
    )
    openai_query_embeddings = openai_embeddings[0].unsqueeze(0)
    openai_passage_embeddings = openai_embeddings[1:]

    openai_top1_prediction = (
        (openai_query_embeddings @ openai_passage_embeddings.T).argmax().item()
    )
    openai_precision = openai_top1_prediction in selected_indices
    openai_large_precisions.append(openai_precision)

    minilm_embeddings = minilm_model.encode(
        texts, convert_to_tensor=True, normalize_embeddings=True
    )
    minilm_query_embedding = minilm_embeddings[0].unsqueeze(0)
    minilm_passage_embeddings = minilm_embeddings[1:]
    minilm_top1_prediction = (
        (minilm_query_embedding @ minilm_passage_embeddings.T).argmax().item()
    )
    minilm_precision = minilm_top1_prediction in selected_indices
    minilm_precisions.append(minilm_precision)

    if i % report_interval == 0:
        print("i:", i)
        print(
            f"OpenAI text-embedding-3-large avg retrieval precision: {sum(openai_large_precisions) / len(openai_large_precisions)}"
        )
        print(
            f"all-MiniLM-L6-v2 avg retrieval precision: {sum(minilm_precisions) / len(minilm_precisions)}"
        )


i: 0
OpenAI text-embedding-3-large avg retrieval precision: 1.0
all-MiniLM-L6-v2 avg retrieval precision: 0.0
i: 20
OpenAI text-embedding-3-large avg retrieval precision: 0.55
all-MiniLM-L6-v2 avg retrieval precision: 0.25
i: 60
OpenAI text-embedding-3-large avg retrieval precision: 0.39655172413793105
all-MiniLM-L6-v2 avg retrieval precision: 0.3275862068965517
i: 80
OpenAI text-embedding-3-large avg retrieval precision: 0.4230769230769231
all-MiniLM-L6-v2 avg retrieval precision: 0.38461538461538464
i: 100
OpenAI text-embedding-3-large avg retrieval precision: 0.4387755102040816
all-MiniLM-L6-v2 avg retrieval precision: 0.37755102040816324
i: 120
OpenAI text-embedding-3-large avg retrieval precision: 0.4322033898305085
all-MiniLM-L6-v2 avg retrieval precision: 0.3559322033898305
i: 140
OpenAI text-embedding-3-large avg retrieval precision: 0.4632352941176471
all-MiniLM-L6-v2 avg retrieval precision: 0.3897058823529412
i: 160
OpenAI text-embedding-3-large avg retrieval precision: 0.43

In [8]:
print(
    f"OpenAI text-embedding-3-large avg retrieval precision: {sum(openai_large_precisions) / len(openai_large_precisions)}"
)
print(
    f"all-MiniLM-L6-v2 avg retrieval precision: {sum(minilm_precisions) / len(minilm_precisions)}"
)

OpenAI text-embedding-3-large avg retrieval precision: 0.36126270733012306
all-MiniLM-L6-v2 avg retrieval precision: 0.35066880684858215
