# Demo: Using Embeddings in Data Processing with Daft

Daft makes working with complex data easy. This demonstration will show an **end-to-end example of data processing with embeddings in Daft**. We will:
1. load a large dataset, 
2. compute embeddings, 
3. use the embeddings for semantic similarity, and 
4. inspect the results and then write them out.

----

**Scenario:** There are a lot of questions on StackExchange. Unfortunately, the vast majority of questions do not have high visibility or good answers. However, a similar question with a high-quality answer may already exist elsewhere on the site.

We would like to go through all the questions on StackExchange and **associate each question with another question that is highly rated**.


## Step 0: Dependencies and configuration

In [None]:
# Install Daft! (We use the nightly version here, but latest will work too.)
!pip install 'getdaft[aws]' --pre --extra-index-url https://pypi.anaconda.org/daft-nightly/simple

# We will use sentence-transformers for computing embeddings.
!pip install sentence-transformers

In [None]:
CI = False  # Set to True to limit the dataset to 500 rows, e.g. for quick test runs.

## Step 1: Load the dataset

We will use the **StackExchange crawl from the [RedPajama dataset](https://huggingface.co/datasets/togethercomputer/RedPajama-Data-1T)**. It is 75GB of `jsonl` files.

*EDIT (June 2023): Our hosted version of the full dataset is temporarily unavailable. Please enjoy the demo with the sample dataset for now.*

**Note:** This demo runs best on a cluster with many GPUs available. Information on how to connect Daft to a cluster is available [here](https://www.getdaft.io/projects/docs/en/stable/learn/user_guides/scaling-up.html). 

If running on a single node, you can use the provided subsample of the data, which is 75MB in size. If you like, you can also truncate either dataset to a desired number of rows using `df.limit`.

In [13]:
import daft

SAMPLE_DATA_PATH = "s3://daft-public-data/redpajama-1t-sample/stackexchange_sample.jsonl"

df = daft.read_json(SAMPLE_DATA_PATH)

if CI:
    df = df.limit(500)

print(df)



+--------+--------------------------------------------------------------------------------------------------------------+
| text   | meta                                                                                                         |
| Utf8   | Struct[language: Utf8, url: Utf8, timestamp: Timestamp(Seconds, None), source: Utf8, question_score: Utf8]   |
+--------+--------------------------------------------------------------------------------------------------------------+
(No data to display: Dataframe not materialized)


## Step 2: Compute embeddings

We can see there is a `text` column that holds the question and answer text, and a `meta` column with metadata.

Let's **compute the embeddings of our text**. We start by putting our model (SentenceTransformers) into a **Daft UDF**.

In [14]:
MODEL_NAME = "all-MiniLM-L6-v2"

@daft.udf(return_dtype=daft.DataType.python())
class EncodingUDF:
    def __init__(self):
        from sentence_transformers import SentenceTransformer
        self.model = SentenceTransformer(MODEL_NAME)

    def __call__(self, text_col):
        return [
            self.model.encode(text, convert_to_tensor=True)
            for text in text_col.to_pylist()
        ]

Then, we can just call the UDF to run the model.

In [15]:
df = df.with_column("embedding", EncodingUDF(df["text"]))

Pause and notice how easy it was to write this. 

In particular, we are not forced to do any sort of unwieldy type coercion on the embedding result; instead, we can return the result as-is, whatever it is, even when running Daft in a cluster. Daft dataframes can hold a wide range of data types and also grant you the flexibility of Python's dynamic typing when necessary.
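Part of what keeps `EncodingUDF` simple is that it is a class: the model is constructed once in `__init__` and then reused for every batch of rows passed to `__call__`, instead of being reloaded per batch. The stdlib-only sketch below imitates that shape outside Daft. `ToyEncoder` and `EncodingBatchUDF` are hypothetical names, `ToyEncoder` is a hashing bag-of-words stand-in for `SentenceTransformer` (not a real sentence model), and the batches are made up.

```python
import hashlib
import math


class ToyEncoder:
    """Hypothetical stand-in for SentenceTransformer: a hashing bag-of-words embedding."""

    loads = 0  # how many times the "model" has been constructed

    def __init__(self, dim=8):
        ToyEncoder.loads += 1  # imagine an expensive model download/load here
        self.dim = dim

    def encode(self, text):
        vec = [0.0] * self.dim
        for token in text.lower().split():
            # Hash each token into one of `dim` buckets and count occurrences.
            bucket = int(hashlib.sha256(token.encode()).hexdigest(), 16) % self.dim
            vec[bucket] += 1.0
        norm = math.sqrt(sum(v * v for v in vec)) or 1.0  # avoid dividing by zero on empty text
        return [v / norm for v in vec]


class EncodingBatchUDF:
    """Same shape as EncodingUDF: set up once in __init__, process one batch per __call__."""

    def __init__(self):
        self.model = ToyEncoder()

    def __call__(self, batch):
        return [self.model.encode(text) for text in batch]


# Hypothetical batches of rows, as a worker might receive them.
udf = EncodingBatchUDF()
batches = [["how do I sort a list", "python sort dict"], ["install ubuntu drivers"]]
results = [udf(batch) for batch in batches]
print(ToyEncoder.loads)  # 1: the model was loaded once for both batches
```

A plain function UDF would have to load the model inside every call; the class form moves that cost into one-time setup.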

Next, let's also **extract the URL and score**:

In [None]:
df = df.select(
    df["embedding"],
    df["meta"].apply(lambda x: x["url"], return_dtype=daft.DataType.string()).alias("URL"),
    df["meta"].apply(lambda x: int(x["question_score"]), return_dtype=daft.DataType.int64()).alias("question_score"),
)
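Each `meta` value is a struct, so (as the `x["url"]` indexing suggests) the lambdas receive it as a Python dict. The schema printed in Step 1 lists `question_score` as `Utf8`, so the score arrives as a string and needs an explicit conversion to match an `int64` return type. A stdlib-only sketch of the same extraction on a single hypothetical JSONL line (all values are made up):

```python
import json

# A hypothetical JSONL line shaped like the `meta` struct printed in Step 1.
# All values are made up, and the `timestamp` field is omitted for brevity.
line = (
    '{"text": "Q: How do I reverse a list?", "meta": {"language": "en", '
    '"url": "https://example.stackexchange.com/questions/1", '
    '"source": "stackexchange", "question_score": "7"}}'
)
meta = json.loads(line)["meta"]

# Extract the URL as-is; the score is stored as a string (Utf8), so convert it.
url = meta["url"]
score = int(meta["question_score"])
print(url, score)
```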

and wait for all the results to finish computing:

In [17]:
df = df.collect()
embeddings_df = df
print("Embeddings complete!")



Embeddings complete!


## Step 3: Semantic similarity

Let's **get the top questions**. We will use `df.sort` to sort by score and then `df.limit` to keep the top `ceil(sqrt(N))` questions, where `N` is the total number of questions.

In [18]:
import math
NUM_TOP_QUESTIONS = math.ceil(math.sqrt(len(df)))

top_questions = (
    df
    .sort(df["question_score"], desc=True)
    .limit(NUM_TOP_QUESTIONS)
).to_pydict()
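Keeping only the top `ceil(sqrt(N))` questions keeps the search corpus small: matching all `N` questions against it takes about `N * sqrt(N)` similarity computations rather than `N * N` (for a million questions, the corpus would hold 1,000 top questions). That is one consequence of this choice, not necessarily the reason it was made. A plain-Python sketch of the same sort-then-limit step on hypothetical `(url, score)` rows:

```python
import math

# Hypothetical (url, score) rows standing in for the collected dataframe.
rows = [(f"q{i}", score) for i, score in enumerate([3, -1, 12, 0, 5, 7, 1, 0, 2, 9])]

# Same rule as NUM_TOP_QUESTIONS: ceil(sqrt(10)) = ceil(3.16...) = 4.
num_top = math.ceil(math.sqrt(len(rows)))

# Sort by score, highest first, then keep the first num_top rows.
top = sorted(rows, key=lambda r: r[1], reverse=True)[:num_top]
print(top)  # [('q2', 12), ('q9', 9), ('q5', 7), ('q4', 5)]
```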

Now we will **take each regular question** and **find a related top question**. For this we will need to do a similarity search. Let's do that within a Daft UDF.

In [19]:
@daft.udf(return_dtype=daft.DataType.python())
def similarity_search(embedding_col, top_embeddings, top_urls):
    if len(embedding_col) == 0:
        return []
    
    from sentence_transformers import util
    import torch

    # Tensor prep
    query_embedding_t = torch.stack(embedding_col.to_pylist())
    if torch.cuda.is_available():
        query_embedding_t = query_embedding_t.to("cuda")
        top_embeddings = top_embeddings.to("cuda")

    # Do semantic search
    results = util.semantic_search(query_embedding_t, top_embeddings, top_k=1)
    
    # Extract URL and score from search results
    results = [res[0] for res in results]
    results = [
        {
            "related_top_question": top_urls[res["corpus_id"]],
            "similarity": res["score"],
        }
        for res in results
    ]
    return results

import torch
df = df.with_column(
    "search_result", 
    similarity_search(
        df["embedding"], 
        top_embeddings=torch.stack(top_questions["embedding"]), 
        top_urls=top_questions["URL"],
    )
)

df = df.select(
    df["URL"],
    df["question_score"],
    df["search_result"]
        .apply(lambda x: x["related_top_question"], return_dtype=daft.DataType.string())
        .alias("related_top_question"),
    df["search_result"]
        .apply(lambda x: x["similarity"], return_dtype=daft.DataType.float64())
        .alias("similarity"),
)
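Inside the UDF, `util.semantic_search` compares each query embedding against every top-question embedding (using cosine similarity by default) and, with `top_k=1`, keeps only the best match, reporting its `corpus_id` and `score`. The brute-force, stdlib-only sketch below reproduces that behavior on hypothetical 2-D embeddings and URLs and returns the same dictionaries as `similarity_search`; it is an illustration, not how the library implements the search.

```python
import math


def cosine(a, b):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def top1_search(queries, corpus, corpus_urls):
    """For each query, return the URL and score of its most similar corpus vector."""
    results = []
    for q in queries:
        scores = [cosine(q, c) for c in corpus]
        best = max(range(len(corpus)), key=lambda i: scores[i])  # like top_k=1
        results.append({"related_top_question": corpus_urls[best], "similarity": scores[best]})
    return results


# Hypothetical embeddings and URLs: two "top" questions and three ordinary ones.
corpus = [[1.0, 0.0], [0.0, 1.0]]
urls = ["https://example.com/a", "https://example.com/b"]
queries = [[0.9, 0.1], [0.2, 0.8], [0.6, 0.8]]
out = top1_search(queries, corpus, urls)
print([r["related_top_question"] for r in out])
```

The library version operates on (optionally GPU-resident) tensors and processes the corpus in chunks, so it scales far better than this loop.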

## Step 4: Inspect and write results

Did the matching work well? Let's take a peek at our best results to see if they make sense.

In [20]:
df = df.where(df["similarity"] < 0.99)  # Ignore near-exact matches, e.g. top questions matched with themselves.
df = df.sort(df["similarity"], desc=True)
df.show()

| URL (Utf8) | question_score (Int64) | related_top_question (Utf8) | similarity (Float64) |
|---|---|---|---|
| https://stackoverflow.com/questions/72682333 | -1 | https://askubuntu.com/questions/401449 | 0.821935 |
| https://dba.stackexchange.com/questions/185574 | 1 | https://askubuntu.com/questions/401449 | 0.783146 |
| https://stackoverflow.com/questions/23113375 | 0 | https://stackoverflow.com/questions/9907682 | 0.779253 |
| https://stackoverflow.com/questions/68984510 | 0 | https://askubuntu.com/questions/401449 | 0.770626 |
| https://stackoverflow.com/questions/72643833 | 1 | https://stackoverflow.com/questions/34030373 | 0.740242 |
| https://stackoverflow.com/questions/6092305 | 1 | https://stackoverflow.com/questions/9907682 | 0.727464 |
| https://stackoverflow.com/questions/14266910 | 1 | https://stackoverflow.com/questions/4690758 | 0.723069 |
| https://stackoverflow.com/questions/60891048 | 1 | https://stackoverflow.com/questions/1388818 | 0.711756 |


In each row, the `URL` column holds an average question without much activity, and the `related_top_question` column links to a similar, highly rated question that already has some high-quality answers. Success!
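The `similarity < 0.99` filter matters because the top questions are drawn from `df` itself: when the search runs over every row, each top question finds itself in the corpus with a cosine similarity of (essentially) 1.0. The threshold also drops exact-duplicate postings. A toy illustration with hypothetical 2-D embeddings and made-up URLs:

```python
import math


def cosine(a, b):
    """Cosine similarity of two non-zero vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


# Hypothetical top-question corpus (URL -> embedding).
top = {"top/1": [3.0, 4.0], "top/2": [1.0, 0.0]}

# Every row is searched, including the top questions themselves.
rows = [
    ("top/1", [3.0, 4.0]),    # a top question, which is also in the full dataset
    ("other/7", [2.0, 0.5]),  # an ordinary question
]

matches = []
for url, emb in rows:
    best_url = max(top, key=lambda t: cosine(emb, top[t]))
    matches.append((url, best_url, cosine(emb, top[best_url])))

# Same threshold as the notebook: the self-match (similarity 1.0) is dropped.
kept = [m for m in matches if m[2] < 0.99]
print(kept)
```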

Finally, we will probably want to save the results for future use. Let's write them out locally as Parquet files.

In [22]:
df.write_parquet("question_matches.pq").to_pydict()

{'file_path': ['question_matches.pq/6ee09e59-dbea-495d-894f-0f182567035d-0.parquet']}

## Conclusion

We have shown a simple example of a complex data processing workflow. It involved typical tabular operations like sort and filter, but we also interleaved some interesting operations on complex data: we created, stored, and searched across embeddings.

**Daft** is a data processing framework that allows you to express these operations easily, while also scaling up to large clusters right out of the box.

You can get Daft at [getdaft.io](https://getdaft.io).