# arXiv Paper Embedding


## Multi GPU w/ Dask + CUDF
Using Dask and CuDF to orchestrate sentence embedding over multiple GPU workers.

![Rapids and Dask Logos](https://saturn-public-assets.s3.us-east-2.amazonaws.com/example-resources/rapids_dask.png "doc-image")

### Important Imports

* [`dask_saturn`](https://github.com/saturncloud/dask-saturn) and [`dask.distributed`](http://distributed.dask.org/en/stable/): Set up and run the Dask cluster in Saturn Cloud.
* [`dask-cudf`](https://docs.rapids.ai/api/cudf/stable/basics/dask-cudf.html): Create distributed `cudf` dataframes using Dask.

In [1]:
import dask_cudf
import cudf
import json
import os
import re
import string

from dask_saturn import SaturnCluster
from dask.distributed import Client, wait
import redis

from dask.distributed import get_worker
import numpy as np

import typing as t
from redis import Redis
from redis.commands.search.field import VectorField, TagField
from redis.commands.search.query import Query

In [19]:
# Redis credentials and endpoint are read from environment variables.
# A missing variable raises KeyError; every value (including `port`) is a str.
password=os.environ['PASSWORD']
port=os.environ['PORT']
host=os.environ['REDIS_HOST']


TODO: rework how the secrets (Redis credentials) are handled.

In [16]:
# Shared Redis client used by the rest of the notebook.
# The port now comes from the PORT environment variable (read above as a
# string) instead of a hard-coded literal, so the client follows the
# configured endpoint. decode_responses=True makes replies come back as str.
r = redis.StrictRedis(
    host=host,
    port=int(port),
    db=0,
    password=password,
    decode_responses=True,
)


In [17]:
# Echo the client so its repr (host, port, db) is shown as the cell output.
r

Redis<ConnectionPool<Connection<host=redis-19013.c21912.eu-west1-2.gcp.cloud.rlrcp.com,port=19013,db=0>>>

import uuid
import json

class RedisConnection:
    """Small JSON queue helper built on Redis lists.

    Every item lives in its own list stored under the key
    ``article:<queue>:<article_id>``.
    """

    def __init__(self, host, port, password, db=0):
        """Create the underlying Redis client.

        Args:
            host: Redis server hostname.
            port: Redis server port (int or numeric string).
            password: Redis password.
            db: Redis logical database number. Defaults to 0.
        """
        self.__namespace = "article"
        # The client is what put()/get() call rpush/blpop on; previously the
        # integer `db` was stored here, so every call failed. Keyword
        # arguments are used instead of a URL so that a password containing
        # reserved URL characters needs no escaping.
        self.__db = redis.Redis(
            host=host, port=int(port), db=db, password=password
        )

    def _key(self, queue, article_id):
        """Return the Redis key of one article list."""
        return f'{self.__namespace}:{queue}:{article_id}'

    def put(self, queue, item):
        """Serialize ``item`` to JSON and push it onto a fresh article list.

        Args:
            queue: Queue name, used as the middle part of the key.
            item: JSON-serializable object.

        Returns:
            The generated article id as a string; pass it to ``get`` to read
            the item back.
        """
        article_id = str(uuid.uuid1())
        self.__db.rpush(self._key(queue, article_id), json.dumps(item))
        return article_id

    def get(self, queue, article_id, timeout=None):
        """Pop and decode the next item stored for ``article_id``.

        Args:
            queue: Queue name used when the item was stored.
            article_id: Identifier returned by ``put``.
            timeout: Forwarded unchanged to the client's ``blpop``.

        Returns:
            The decoded item, or ``None`` when nothing was popped or the
            payload is not valid JSON (the error is reported on stderr).
        """
        import sys  # local import: the file does not import sys at the top

        popped = self.__db.blpop(self._key(queue, article_id), timeout=timeout)
        if popped is None:
            return None
        try:
            # blpop replies with a (key, value) pair; the payload is the value.
            return json.loads(popped[1])
        except ValueError as e:
            print(
                f"Error with the json in queue ({str(popped)}): {str(e)}",
                file=sys.stderr,
            )
            return None
        

In [4]:
# Number of Dask workers requested; reused below when waiting for workers
# and as the partition count of the embedding dataframe.
n_workers = 3
# Build the SaturnCluster handle with that worker count and connect a
# distributed Client to it.
cluster = SaturnCluster(n_workers=n_workers)
client = Client(cluster)

INFO:dask-saturn:Cluster is ready
INFO:dask-saturn:Registering default plugins


In [5]:
# Wait for the requested number of workers before submitting any work.
client.wait_for_workers(n_workers=n_workers)

In [6]:
def clean_description(description: str) -> str:
    """Normalize free text into a lowercase ASCII string.

    Steps, in order: drop non-ASCII characters, replace punctuation with
    spaces, collapse whitespace runs, replace newlines with spaces, insert a
    space before every uppercase letter, collapse whitespace again, lowercase.

    Args:
        description: Raw text; ``None`` or an empty string is accepted.

    Returns:
        The cleaned text, or ``""`` when ``description`` is falsy. The result
        starts with a single space when the cleaned text begins with an
        uppercase letter (the split before capitals yields an empty head).
    """
    if not description:
        return ""

    # remove non-ASCII characters
    description = description.encode('ascii', 'ignore').decode()

    # replace punctuation with spaces
    description = re.sub('[%s]' % re.escape(string.punctuation), ' ', description)

    # clean up the spacing (raw string: '\s' is not a valid str escape)
    description = re.sub(r'\s{2,}', " ", description)

    # URL removal is disabled:
    # description = re.sub(r"https*\S+", " ", description)

    # replace the remaining single newlines with spaces
    description = description.replace("\n", " ")

    # number removal is disabled:
    # description = re.sub(r'\w*\d+\w*', '', description)

    # split on capitalized words (camelCase -> "camel Case")
    description = " ".join(re.split(r'(?=[A-Z])', description))

    # clean up the spacing again
    description = re.sub(r'\s{2,}', " ", description)

    # make all words lowercase
    return description.lower()

In [7]:
def process_papers(papers:list):
    """Lazily yield normalized paper records that have a usable year.

    Each input paper is a mapping read through the keys ``journal_ref``,
    ``id``, ``title``, ``authors``, ``categories`` and ``abstract``. The
    year is the smallest ``YEAR_PATTERN`` match in ``journal_ref`` that lies
    between 1991 and 2022 inclusive; papers without such a year are skipped.

    Yields:
        dict with keys ``id``, ``title``, ``authors``, ``year``,
        ``categories`` (space-separated input joined with commas),
        ``abstract`` and ``input`` (cleaned title + abstract).
    """
    def _publication_year(journal_ref):
        # An empty or missing reference carries no year information.
        if not journal_ref:
            return None
        candidates = [
            value
            for value in map(int, re.findall(YEAR_PATTERN, journal_ref))
            if 1991 <= value <= 2022
        ]
        return min(candidates) if candidates else None

    def _to_record(raw: dict):
        found_year = _publication_year(raw['journal_ref'])
        text = raw['title'] + ' ' + raw['abstract']
        return {
            'id': raw['id'],
            'title': raw['title'],
            'authors': raw['authors'],
            'year': found_year,
            'categories': ','.join(raw['categories'].split(' ')),
            'abstract': raw['abstract'],
            'input': clean_description(text)
        }

    # Records are built one at a time so the generator stays lazy.
    yield from (record for record in map(_to_record, papers) if record['year'])


In [8]:
async def create_hnsw(
    self,
    *fields,
    redis_conn: Redis,
    number_of_vectors: int,
    prefix: str,
    distance_metric: str = 'COSINE'
):
    """
    Create an approximate nearest-neighbour index backed by HNSW.

    Builds a FLOAT32, 768-dimensional ``vector`` field and hands it,
    after any caller-supplied fields, to ``self._create`` (not visible in
    this file).

    Args:
        *fields: Extra index fields, passed through ahead of the vector field.
        redis_conn (Redis): Redis connection object.
        number_of_vectors (int): Initial vector count, used as INITIAL_CAP.
        prefix (str): Key prefix to use for RediSearch index creation.
        distance_metric (str, optional): Distance metric for vector search.
            Defaults to 'COSINE'.
    """
    hnsw_attributes = {
        "TYPE": "FLOAT32",
        "DIM": 768,
        "DISTANCE_METRIC": distance_metric,
        "INITIAL_CAP": number_of_vectors,
    }
    vector_field = VectorField("vector", "HNSW", hnsw_attributes)
    await self._create(*fields, vector_field, redis_conn=redis_conn, prefix=prefix)

In [9]:
def embed_partition(df: dask_cudf.DataFrame):
    """
    Embed the ``input`` column of one dataframe partition on a Dask worker.

    The SentenceTransformer model is loaded the first time a worker runs
    this function and then kept on the worker object as ``worker.model``
    so later partitions reuse it.

    NOTE(review): when called through ``map_partitions`` the argument is
    presumably a single cudf partition rather than a dask_cudf frame;
    confirm the annotation.

    Returns:
        The partition restricted to the ``id`` and ``vector`` columns, where
        ``vector`` holds each embedding as a list.
    """
    worker = get_worker()
    if not hasattr(worker, "model"):
        # Imported lazily so the import only happens on the worker.
        from sentence_transformers import SentenceTransformer

        worker.model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')
    model = worker.model

    print("embedding input", flush=True)

    # Copy the text to host memory and encode it with normalized embeddings.
    sentences = df.input.values_host
    vectors = model.encode(
        sentences=sentences,
        normalize_embeddings=True,
        show_progress_bar=True,
    )

    # Attach the embeddings as a cudf list column aligned on the index.
    embedding_column = cudf.Series(vectors.tolist(), index=df.index)
    df['vector'] = embedding_column
    return df[['id', 'vector']]

def clear_workers():
    """
    Drop the cached model on the current Dask worker and release its memory.

    Removes ``worker.model`` when present, runs the garbage collector and
    only then empties the CUDA cache. The collector runs first because
    ``torch.cuda.empty_cache`` can only release blocks that are no longer
    held by live tensors; anything freed by a later ``gc.collect`` would
    otherwise stay in the cache.

    Returns:
        None
    """
    import gc

    import torch

    worker = get_worker()
    if hasattr(worker, "model"):
        del worker.model
    gc.collect()
    torch.cuda.empty_cache()

In [18]:
# Count the keys in the current database. DBSIZE returns the count directly,
# whereas len(r.keys()) fetched every key name just to count them.
r.dbsize()

1101

In [26]:
# List every key name in the current database (output below is truncated).
# NOTE(review): KEYS walks the whole keyspace in one command; prefer
# r.scan_iter() on a large or shared server.
r.keys()

['/arxiv/papers/0704.0860',
 '/arxiv/papers/0704.0777',
 '/arxiv/papers/0704.0104',
 '/arxiv/papers/0704.0611',
 '/arxiv/papers/0704.0960',
 '/arxiv/papers/0704.0277',
 '/arxiv/papers/0704.0723',
 '/arxiv/papers/0704.0381',
 '/arxiv/papers/0704.0042',
 '/arxiv/papers/0704.0628',
 '/arxiv/papers/0704.0350',
 '/arxiv/papers/0704.0407',
 '/arxiv/papers/0704.0646',
 '/arxiv/papers/0704.0001',
 '/arxiv/papers/0704.0643',
 '/arxiv/papers/0704.0423',
 '/arxiv/papers/0704.0515',
 '/arxiv/papers/0704.0925',
 '/arxiv/papers/0704.1048',
 '/arxiv/papers/0704.0826',
 '/arxiv/papers/0704.0208',
 '/arxiv/papers/0704.1093',
 '/arxiv/papers/0704.0650',
 '/arxiv/papers/0704.0781',
 '/arxiv/papers/0704.0358',
 '/arxiv/papers/0704.0275',
 '/arxiv/papers/0704.0430',
 '/arxiv/papers/0704.0125',
 '/arxiv/papers/0704.1054',
 '/arxiv/papers/0704.0464',
 '/arxiv/papers/0704.0892',
 '/arxiv/papers/0704.0719',
 '/arxiv/papers/0704.0543',
 '/arxiv/papers/0704.1004',
 '/arxiv/papers/0704.0386',
 '/arxiv/papers/0704

In [29]:
# Length of the 'papers_to_process' list, i.e. the queue the embedding loop pops from.
r.llen('papers_to_process')

0

In [24]:
# Push three arXiv ids onto the Redis list stored at key '/arxiv/papers/'.
# NOTE(review): the embedding loop pops ids from 'papers_to_process', not from
# this key, so these ids are never consumed by it; confirm which key is intended.
for el in ['0704.0101','0704.0107','0704.0023',]:
    r.lpush('/arxiv/papers/',el)

In [32]:
# Four-digit year starting with 19 or 20 that is not part of a longer digit
# run. The single capturing group keeps re.findall returning plain strings.
# The previous pattern r"(19|20[0-9]{2})" alternated between the bare prefix
# "19" and "20xx", so every 19xx year was captured as 19 and then discarded
# by the 1991..2022 range filter.
YEAR_PATTERN = r"(?<![0-9])((?:19|20)[0-9]{2})(?![0-9])"
# Redis list holding the ids of the papers still waiting for an embedding.
QUEUE='papers_to_process'
# Drain the queue in batches of up to 1000 ids, embed every batch on the
# Dask cluster and write the vectors back to Redis.
indexed_vectors = 0
while True:
    article_ids=r.rpop(QUEUE, 1000)
    if not article_ids:
        print("Queue emptied")
        break
    # NOTE(review): papers are read from (and written back to) '<QUEUE>/<id>'
    # hashes, while the keys listed earlier in this notebook look like
    # '/arxiv/papers/<id>'; confirm the intended key prefix.
    papers=[r.hgetall(f'{QUEUE}/{article_id}') for article_id in article_ids]
    records = list(process_papers(papers))
    if not records:
        # No paper in this batch had a usable year. An empty frame has no
        # "id"/"input" columns to select, so move on to the next batch.
        continue
    cdf = cudf.DataFrame(records)
    ddf = dask_cudf.from_cudf(cdf, npartitions=n_workers)
    output_df = ddf[["id", "input"]].map_partitions(
        func = embed_partition,
        meta = {
          "id": object,
          "vector": cudf.ListDtype('float32')
        }
    )
    # Join the embeddings back onto the paper metadata and bring the result
    # to host memory as a pandas frame.
    full_ddf = ddf.merge(output_df)
    full_ddf=full_ddf.compute().to_pandas()

    # Queue all writes and send them in one round trip instead of issuing
    # one HSET call per row.
    with r.pipeline(transaction=False) as pipe:
        for _,row in full_ddf.iterrows():
            pipe.hset(f'{QUEUE}/{str(row["id"])}',mapping= {
                "paper_id":row["id"],
                "year":int(row['year']),
                "categories":row["categories"],
                # Raw float32 bytes of the embedding.
                "vector":np.array(row["vector"],dtype=np.float32).tobytes()
            })
        pipe.execute()
    indexed_vectors += len(full_ddf)

# Size the index from the number of vectors actually written. The previous
# r.llen('QUEUE') queried the literal key 'QUEUE', and the real queue is
# empty by the time the loop exits anyway.
# NOTE(review): create_hnsw_index is not defined in the visible source (only
# create_hnsw is); confirm it exists before running this cell.
create_hnsw_index(r, indexed_vectors, distance_metric="COSINE")

Queue emptied
