# arXiv Paper Embedding


## Multi-GPU with Dask + cuDF
Use Dask and cuDF to orchestrate sentence embedding across multiple GPU workers.

![Rapids and Dask Logos](https://saturn-public-assets.s3.us-east-2.amazonaws.com/example-resources/rapids_dask.png "doc-image")

## Important Imports

* [`dask_saturn`](https://github.com/saturncloud/dask-saturn) and [`dask.distributed`](http://distributed.dask.org/en/stable/): Set up and run the Dask cluster in Saturn Cloud.
* [`dask-cudf`](https://docs.rapids.ai/api/cudf/stable/basics/dask-cudf.html): Create distributed `cudf` dataframes using Dask.

In [None]:
import sys
from pathlib import Path
sys.path.append(str(Path.cwd().parent))

import asyncio
import dask_cudf
from dask_saturn import SaturnCluster
from dask.distributed import Client, get_worker, wait
import cudf
import json
import numpy as np
import pandas as pd
import redis
import re
from redis.commands.search.field import VectorField, TagField

from config.redis_config import ARXIV_PAPERS_PREFIX_KEY, INDEX_NAME, INDEX_TYPE, QUEUE_NAME, REDIS_URL
from config import N_WORKERS, YEAR_PATTERN
from lib.embeddings import Embeddings
from lib.search_index import SearchIndex

## Helper functions

In [None]:
def extract_year_from_journal_ref(journal_ref: str) -> str:
    """Return the earliest year found in a journal reference, as a string.

    Candidate years are matched with ``YEAR_PATTERN`` and compared as integers.
    An empty string is returned when ``journal_ref`` is falsy (e.g. ``None``
    for a missing Redis hash field) or when nothing matches.
    """
    if not journal_ref:
        return ""
    candidates = [int(match) for match in re.findall(YEAR_PATTERN, journal_ref)]
    return str(min(candidates)) if candidates else ""

def process_categories(categories: str) -> str:
    """Turn arXiv's space-separated category list into a "|"-joined tag string.

    "|" matches the separator of ``TagField("categories_processed", separator="|")``
    used when the search index is created, so each category is indexed as its
    own tag. ``None`` or blank input (``hmget`` yields ``None`` for a missing
    field) returns "".

    >>> process_categories("cs.CL cs.LG")
    'cs.CL|cs.LG'
    """
    if not categories:
        return ""
    # split() with no argument also drops empty tokens caused by repeated spaces.
    return "|".join(categories.split())


def process_papers(papers: list[dict]) -> list[dict]:
    """Build the per-paper records that are embedded and written back to Redis.

    Each output dict holds:
      * ``id``: the paper id as read from its hash,
      * ``year``: earliest year parsed from ``journal_ref`` ("" if none),
      * ``categories_processed``: tag string from ``process_categories``,
      * ``input``: title + " " + abstract, passed through
        ``Embeddings.clean_description``; this is the text that gets embedded.

    Args:
        papers: dicts as returned by ``get_arxiv_papers``. Values may be ``None``
            when the Redis hash lacks that field; they are treated as "".
    """
    embeddings = Embeddings()
    return [
        {
            "id": paper["id"],
            "year": extract_year_from_journal_ref(paper["journal_ref"]),
            # Coerce None (missing hash field) so neither the helper nor the
            # string concatenation below can fail with a TypeError.
            "categories_processed": process_categories(paper["categories"] or ""),
            "input": embeddings.clean_description(
                (paper["title"] or "") + " " + (paper["abstract"] or "")
            ),
        }
        for paper in papers
    ]

In [None]:
def embed_partition(df: dask_cudf.DataFrame):
    """Embed the ``input`` text of one Dask partition on the current worker.

    The SentenceTransformer 'sentence-transformers/all-mpnet-base-v2' is loaded
    on first use and cached on the worker as ``worker.model``, so later
    partitions handled by the same worker reuse it instead of reloading.

    Embeddings are normalized (``normalize_embeddings=True``), so the inner-product
    ("IP") metric used by the search index behaves like cosine similarity.

    NOTE(review): despite the annotation, ``map_partitions`` presumably passes a
    single cudf partition here (``.values_host`` is read from its ``input`` column).

    Returns:
        The ``id`` and ``vector`` columns, with one embedding list per row.
        The ``vector`` column is also added to ``df`` in place.
    """
    current_worker = get_worker()
    if not hasattr(current_worker, "model"):
        from sentence_transformers import SentenceTransformer

        current_worker.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    encoder = current_worker.model

    print("embedding input", flush=True)

    # Texts are copied to host memory because the encoder takes host sequences.
    embedded = encoder.encode(
        sentences=df.input.values_host,
        normalize_embeddings=True,
        show_progress_bar=True,
    )

    # Align the vectors with the partition's existing index before attaching them.
    df["vector"] = cudf.Series(embedded.tolist(), index=df.index)
    return df[["id", "vector"]]

def clear_workers():
    """Free the cached embedding model and GPU memory on the current Dask worker.

    Removes the ``model`` attribute that ``embed_partition`` caches on the
    worker (if present), runs Python's garbage collector, then releases
    PyTorch's cached-but-unused CUDA memory. It calls ``get_worker()``, so it
    must execute inside a Dask worker.
    NOTE(review): no caller is visible in this notebook section.
    """
    import gc

    import torch

    worker = get_worker()
    if hasattr(worker, "model"):
        del worker.model
    # Collect first: tensors that were only reachable through the dropped model
    # (or through reference cycles) must actually be freed before
    # empty_cache() can hand their cached blocks back to the device.
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
def get_arxiv_papers(redis_client, ids: list[str]) -> list[dict]:
    """Fetch selected fields of each paper hash in one pipelined round trip.

    Args:
        redis_client: synchronous redis-py client. The IDs are interpolated into
            key strings, so they should be ``str``. The client created in this
            notebook uses ``decode_responses=True``.
        ids: paper IDs; each is read from the hash ``f"{ARXIV_PAPERS_PREFIX_KEY}/{id}"``.

    Returns:
        One dict per ID, in input order, mapping each requested field name to its
        value. Fields absent from the hash, or a missing hash, map to ``None``
        (``HMGET`` semantics).
    """
    fields = ["id", "title", "abstract", "categories", "journal_ref"]
    pipe = redis_client.pipeline()
    for paper_id in ids:
        pipe.hmget(f"{ARXIV_PAPERS_PREFIX_KEY}/{paper_id}", fields)
    return [dict(zip(fields, values)) for values in pipe.execute()]

## Compute embeddings with Dask workers

In [None]:
# Synchronous Redis client. decode_responses=True returns replies as str instead
# of bytes; the queue IDs popped below are interpolated into key f-strings, so
# they need to be plain strings.
redis_client = redis.from_url(REDIS_URL, decode_responses=True)

# Start a Saturn Cloud Dask cluster with N_WORKERS workers and connect a client.
# Block until every worker has joined, so the N_WORKERS partitions created per
# batch can be spread across all of them.
cluster = SaturnCluster(n_workers=N_WORKERS)
client = Client(cluster)
client.wait_for_workers(n_workers=N_WORKERS)

### Run embedding and preprocessing in batches of 100,000 papers

In [None]:
# Queue length before any IDs are popped. It is passed to create_index as its
# number_of_vectors argument.
# NOTE(review): this is 0 if an earlier run already drained the queue.
INITIAL_NUMBER_OF_VECTORS = redis_client.llen(QUEUE_NAME)
print(INITIAL_NUMBER_OF_VECTORS)

In [None]:
# Number of paper IDs popped from the queue and processed per iteration.
BATCH_SIZE = 100_000

# Drain the queue batch by batch:
#   1. pop IDs,
#   2. load the papers,
#   3. preprocess them on the host,
#   4. embed them on the Dask GPU workers,
#   5. write year/categories/vector back into each paper hash.
# The cluster is always closed, even if a batch fails.
try:
    while True:
        # LPOP with a count (Redis >= 6.2) returns None once the list is empty.
        papers_ids = redis_client.lpop(QUEUE_NAME, BATCH_SIZE)
        if not papers_ids:
            print("Queue is empty")
            break
        print(f"{len(papers_ids)} papers to process")

        try:
            papers = get_arxiv_papers(redis_client, papers_ids)
            print("Papers have been retrieved from Redis")

            processed_papers = process_papers(papers)
            cdf = cudf.DataFrame(processed_papers)
            ddf = dask_cudf.from_cudf(cdf, npartitions=N_WORKERS)

            output_df = ddf[["id", "input"]].map_partitions(
                func=embed_partition,
                meta={
                    "id": object,
                    "vector": cudf.ListDtype('float32'),
                },
            )
            # Merge on the shared "id" column to re-attach year/categories to each vector.
            full_ddf = ddf.merge(output_df).compute().to_pandas()

            print("Embedding and processing is done")

            pipe = redis_client.pipeline()
            for paper_id, year, categories, vector in zip(
                full_ddf["id"],
                full_ddf["year"],
                full_ddf["categories_processed"],
                full_ddf["vector"],
            ):
                pipe.hset(
                    f'{ARXIV_PAPERS_PREFIX_KEY}/{paper_id}',
                    mapping={
                        "year": year,
                        "categories_processed": categories,
                        # Stored as raw float32 bytes for the vector field.
                        "vector": np.array(vector, dtype=np.float32).tobytes(),
                    },
                )
            pipe.execute()
        except BaseException:
            # The batch was already removed from the queue. Push it back in its
            # original order (LPUSH prepends one element at a time, hence
            # reversed) so a failed or interrupted run does not silently drop
            # papers. Reprocessing is safe: HSET just overwrites the same fields.
            redis_client.lpush(QUEUE_NAME, *reversed(papers_ids))
            raise
        print("Embeddings and preprocessing uploaded to Redis")
finally:
    cluster.close()

## Create Index

In [None]:
async def create_index(redis_conn, index_name, index_type, number_of_vectors):
    """Create the arXiv vector search index unless ``index_name`` already exists.

    Existence is probed with ``FT.INFO``. Any ``redis.ResponseError`` from that
    probe is printed and treated as "index missing". The index is then built
    through the project's ``SearchIndex`` helper with:
      * keys prefixed by ``ARXIV_PAPERS_PREFIX_KEY``,
      * the inner-product ("IP") distance metric,
      * two tag fields, ``categories_processed`` and ``year``, both split on "|".

    Args:
        redis_conn: async Redis connection (``redis.asyncio``).
        index_name: name probed for an existing index.
            NOTE(review): it is not forwarded to ``SearchIndex.create_*``; confirm
            those helpers create an index with this same name.
        index_type: "HNSW" selects ``create_hnsw``; any other value selects ``create_flat``.
        number_of_vectors: forwarded unchanged to the chosen create call.
    """
    search_index = SearchIndex()
    categories_field = TagField("categories_processed", separator="|")
    year_field = TagField("year", separator="|")
    try:
        await redis_conn.ft(index_name).info()
    except redis.ResponseError as err:
        print(err)
        print("Creating vector search index")
        build = search_index.create_hnsw if index_type == "HNSW" else search_index.create_flat
        await build(
            categories_field,
            year_field,
            redis_conn=redis_conn,
            number_of_vectors=number_of_vectors,
            prefix=ARXIV_PAPERS_PREFIX_KEY,
            distance_metric="IP",
        )
        print("Search index created")
    else:
        print(f"Index {index_name} already exists")

In [None]:
# Async Redis connection, used in this section only to create the search index.
# NOTE: top-level `await` works in a Jupyter/IPython cell, not in a plain .py
# script (there, wrap this in a coroutine and run it with asyncio.run).
redis_async_conn = redis.asyncio.from_url(REDIS_URL)
await create_index(
        redis_async_conn,
        INDEX_NAME,
        INDEX_TYPE,
        INITIAL_NUMBER_OF_VECTORS
)
