In [23]:
from datasets import load_dataset

dataset = load_dataset("quora")

In [24]:
import pandas as pd


def flatten_data(dataset):
    for pairs in dataset["train"]["questions"]:
        for id, text in zip(pairs["id"], pairs["text"]):
            yield {"id": id, "text": text}


df = pd.DataFrame(flatten_data(dataset))
print(f"Found {len(df)} training examples")
df.drop_duplicates(subset="id", inplace=True)
print(f"Found {len(df)} unique training examples")
df.to_csv("quora_train.csv", index=False)

Found 808580 training examples
Found 537933 unique training examples


In [4]:
# This should be the set of texts we should just pass and tune embedding:
df.head()

Unnamed: 0,id,text
0,1,What is the step by step guide to invest in sh...
1,2,What is the step by step guide to invest in sh...
2,3,What is the story of Kohinoor (Koh-i-Noor) Dia...
3,4,What would happen if the Indian government sto...
4,5,How can I increase the speed of my internet co...


In [7]:
import chunk
from functools import lru_cache
import itertools
from openai import AsyncOpenAI, OpenAI
from typing import AsyncGenerator, List, Literal, Tuple
import asyncio


def batched(iterable, n=1):
    """
    Yields batches of size n from iterable
    """
    it = iter(iterable)
    while True:
        chunk = list(itertools.islice(it, n))
        if not chunk:
            return
        yield chunk


class Embedder:
    @classmethod
    def batch_text(cls, texts: List[Tuple[int, str]], batch_size: int = 32):
        for batch in batched(texts, batch_size):
            yield batch

    @classmethod
    async def embed_openai(
        cls,
        chunks: List[Tuple[int, str]],
        model: Literal[
            "text-embedding-3-small", "text-embedding-3-large"
        ] = "text-embedding-3-small",
    ):
        client = AsyncOpenAI()
        sem = asyncio.Semaphore(32)

        # There's an opportunity to cache this single method
        # Also an opportunity to add retry logic.
        async def fetch_embedding(idz: int, text_str: str):
            async with sem:
                response = await client.embeddings.create(input=text_str, model=model)
                return (idz, response.data[0].embedding)

        results = await asyncio.gather(
            *[fetch_embedding(idz, text_str) for (idz, text_str) in chunks]
        )
        return results

In [59]:
def generate_id_str():
    df = pd.read_csv("quora_train.csv")
    sample = df.head(100)
    sample_tuples = zip(sample["id"], sample["text"])
    return sample_tuples


async def embed(sample_tuples):
    results = await Embedder.embed_openai(
        chunks=sample_tuples, model="text-embedding-3-small"
    )
    df = pd.DataFrame(results, columns=["id", "embedding"]).set_index("id")
    return df


sample_tuples = generate_id_str()
embeddings_df = await embed(sample_tuples)
embeddings_df.to_csv("quora_train_embeddings.csv", index=True)
embeddings_df.head()

Unnamed: 0_level_0,embedding
id,Unnamed: 1_level_1
1,"[-0.0020987512543797493, 0.0270865336060524, 0..."
2,"[0.001020575873553753, 0.013290978036820889, 0..."
3,"[-0.009665180929005146, 0.010190058499574661, ..."
4,"[-0.029207905754446983, -0.0005261672777123749..."
5,"[-0.021656425669789314, -0.010261747054755688,..."


In [60]:
dataset

DatasetDict({
    train: Dataset({
        features: ['questions', 'is_duplicate'],
        num_rows: 404290
    })
})

In [64]:
def new_dataset(embeddings: pd.DataFrame):
    dataset = load_dataset("quora")
    question = dataset["train"]
    for data in question:
        try:
            id1, id2 = data["questions"]["id"]
            e1, e2 = embeddings.iloc[id1][0], embeddings.iloc[id2][0]
            yield (id1, id2, e1, e2, data["is_duplicate"])
        except Exception as e:
            break


df = pd.DataFrame(new_dataset(embeddings_df), columns=["id1", "id2", "e1", "e2", "is_duplicate"])
df

  e1, e2 = embeddings.iloc[id1][0], embeddings.iloc[id2][0]


Unnamed: 0,id1,id2,e1,e2,is_duplicate
0,1,2,"[0.001020575873553753, 0.013290978036820889, 0...","[-0.009665180929005146, 0.010190058499574661, ...",False
1,3,4,"[-0.029207905754446983, -0.0005261672777123749...","[-0.021656425669789314, -0.010261747054755688,...",False
2,5,6,"[0.03239667788147926, 0.02922770380973816, 0.0...","[0.0016469762194901705, -0.05773146450519562, ...",False
3,7,8,"[0.049847107380628586, -0.010537532158195972, ...","[0.0366658940911293, 0.02135019563138485, -0.0...",False
4,9,10,"[0.043602462857961655, 0.020671047270298004, 0...","[-0.005765771958976984, -0.018585262820124626,...",False
5,11,12,"[0.026014558970928192, -0.014319832436740398, ...","[-0.005711750593036413, -0.023612970486283302,...",True
6,13,14,"[0.019067052751779556, 0.05856925621628761, -0...","[0.005276682320982218, 0.004194203298538923, 0...",False
7,15,16,"[0.015116829425096512, 0.0010464431252330542, ...","[0.030122632160782814, -0.04054081067442894, 0...",True
8,17,18,"[0.02603609673678875, -0.04887866973876953, -0...","[0.0476459376513958, 0.013450254686176777, -0....",False
9,19,20,"[0.0013767997734248638, -0.00693125743418932, ...","[-0.004183064680546522, -0.0024396006483584642...",False
