# arXiv Paper Embedding


## Multi-GPU with Dask + cuDF
This section uses Dask and cuDF to orchestrate sentence embedding across multiple GPU workers.

![Rapids and Dask Logos](https://saturn-public-assets.s3.us-east-2.amazonaws.com/example-resources/rapids_dask.png "doc-image")

### Important Imports

* [`dask_saturn`](https://github.com/saturncloud/dask-saturn) and [`dask.distributed`](http://distributed.dask.org/en/stable/): Set up and run the Dask cluster in Saturn Cloud.
* [`dask-cudf`](https://docs.rapids.ai/api/cudf/stable/basics/dask-cudf.html): Create distributed `cudf` dataframes using Dask.

In [1]:
import json
import os
import re
import string
import pickle

import numpy as np
import cudf
import dask_cudf
from dask_saturn import SaturnCluster
from dask.distributed import Client, wait, get_worker

from askyves.embedder import clean_description


DATA_PATH = "data/arxiv-metadata-oai-snapshot.json"
YEAR_CUTOFF = 2022
YEAR_PATTERN = r"\b((?:19|20)[0-9]{2})\b"  # whole four-digit years 19xx or 20xx

PICKLED_DF_PATH = "data/cdf.pkl"
MODEL_NAME = "all-mpnet-base-v2"
OUTPUT_EMBEDDINGS_PATH = "data/arxiv_embedings.pkl"


### Start the Dask Cluster

The template resource you are running already has a Dask cluster attached to it. The `dask-saturn` code below requests four workers (`n_workers = 4`) and creates two important objects: a cluster and a client.

* `cluster`: knows about and manages the scheduler and workers
    - can be used to create, resize, reconfigure, or destroy those resources
    - knows how to communicate with the scheduler, and where to find logs and diagnostic dashboards
* `client`: tells the cluster to do things
    - can send work to the cluster
    - can restart all the worker processes
    - can send data to the cluster or pull data back from the cluster

In [None]:
n_workers = 4
cluster = SaturnCluster(n_workers=n_workers)
client = Client(cluster)

If you already started the Dask cluster on the resource page, then the code above will run much more quickly since it will not have to wait for the cluster to turn on.

>**Pro tip**: Create and start the cluster in the Saturn Cloud UI before opening JupyterLab if you want to get a head start!

The next command blocks the kernel until all of the requested workers are online, so later cells don't start before the cluster is ready.

In [30]:
client.wait_for_workers(n_workers=n_workers)

In [7]:
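def process(line: str) -> dict:
    """Parse one JSON line of the arXiv snapshot into a flat record."""
    paper = json.loads(line)
    if paper['journal-ref']:
        # Keep the earliest plausible year (arXiv launched in 1991) found in the journal reference
        years = [int(year) for year in re.findall(YEAR_PATTERN, paper['journal-ref'])]
        years = [year for year in years if 1991 <= year <= YEAR_CUTOFF]
        year = min(years) if years else None
    else:
        # Papers without a journal reference get no year and are skipped by papers()
        year = None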
    return {
        'id': paper['id'],
        'title': paper['title'],
        'year': year,
        'authors': paper['authors'],
        'categories': ','.join(paper['categories'].split(' ')),
        'abstract': paper['abstract'],
        'update_date': paper["update_date"],
        "doi": paper["doi"],
        "journal-ref": paper["journal-ref"],
        "submitter": paper["submitter"],
        'input': clean_description(paper['title'] + ' ' + paper['abstract'])
    }

def papers():
    with open(DATA_PATH, 'r') as f:
        for paper in f:
            paper = process(paper)
            if paper['year']:
                yield paper
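The year logic in `process` keeps the earliest four-digit year between 1991 and the cutoff that appears in `journal-ref`. The sketch below isolates that rule as a hypothetical `extract_year` helper (plain `re`, no GPU needed); the journal references are invented for illustration. It also shows why the alternation needs grouping: in `r"(19|20[0-9]{2})"` the `|` splits the pattern into `19` or `20xx`, so a 1990s year comes back as the two-character string `"19"` and is then dropped by the 1991 lower bound.

```python
import re
from typing import Optional

# Hypothetical helper mirroring the year filter in `process`; bounds are illustrative.
FOUR_DIGIT_YEAR = re.compile(r"\b((?:19|20)\d{2})\b")

def extract_year(journal_ref: Optional[str], earliest: int = 1991, cutoff: int = 2022) -> Optional[int]:
    """Return the earliest year in [earliest, cutoff] mentioned in journal_ref, or None."""
    if not journal_ref:
        return None
    years = [int(y) for y in FOUR_DIGIT_YEAR.findall(journal_ref)]
    years = [y for y in years if earliest <= y <= cutoff]
    return min(years) if years else None

# Ungrouped alternation: matches "19" on its own instead of the full year
print(re.findall(r"(19|20[0-9]{2})", "Phys. Rev. D 52 (1995) 3350"))  # → ['19']
print(extract_year("Phys. Rev. D 52 (1995) 3350"))  # → 1995
print(extract_year("Nature 450, 2007"))  # → 2007
print(extract_year(None))  # → None
```

The `\b` word boundaries also stop the pattern from picking a "year" out of the middle of a longer number such as a page or article ID.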


In [9]:
cdf = cudf.DataFrame(list(papers()))

In [18]:
# Pro tip: pickle the DataFrame
# Later sessions can load it instead of re-parsing the whole JSON snapshot
with open(PICKLED_DF_PATH, 'wb') as f:
    pickle.dump(cdf, f)

In [31]:
# Load pickle
# with open(PICKLED_DF_PATH, 'rb') as f:
#     cdf = pickle.load(f)
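The save and load cells above can be folded into one step. A minimal sketch, assuming a hypothetical `load_or_build` helper that is not part of the original notebook: it unpickles the file when it exists, and otherwise builds the object, pickles it, and returns it. In the notebook the call could look like `cdf = load_or_build(PICKLED_DF_PATH, lambda: cudf.DataFrame(list(papers())))`.

```python
import os
import pickle
import tempfile
from typing import Any, Callable

def load_or_build(path: str, build: Callable[[], Any]) -> Any:
    """Hypothetical cache helper: load the pickle at `path`, or build, save, and return the object."""
    if os.path.exists(path):
        with open(path, "rb") as f:
            return pickle.load(f)
    obj = build()
    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump(obj, f)
    return obj

# Usage sketch: a counter shows that `build` runs only on the first call.
build_calls = 0

def build_rows() -> dict:
    global build_calls
    build_calls += 1
    return {"rows": 3}

with tempfile.TemporaryDirectory() as tmp:
    cache_path = os.path.join(tmp, "cache", "rows.pkl")
    first = load_or_build(cache_path, build_rows)   # builds and writes the pickle
    second = load_or_build(cache_path, build_rows)  # reads the pickle back

print(first == second, build_calls)  # → True 1
```

One caveat with this design: the cache is never invalidated, so delete the pickle after changing `process` or downloading a newer snapshot.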

### Parallelize the Embedding with Dask

In [34]:
# Convert our CuDF to a Dask-CuDF
ddf = dask_cudf.from_cudf(cdf, npartitions=n_workers).persist()

In [35]:
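def embed_partition(df: cudf.DataFrame, model_name: str = MODEL_NAME) -> cudf.DataFrame:
    """
    Embed the `input` column of a single partition (a cudf.DataFrame handled by one Dask worker).
    The model is loaded once per worker and cached on the worker object.
    """
    worker = get_worker()
    if hasattr(worker, "model"):
        # Reuse the model loaded by an earlier partition on this worker
        model = worker.model
    else:
        from sentence_transformers import SentenceTransformer

        model = SentenceTransformer(f"sentence-transformers/{model_name}")
        worker.model = model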

    print("embedding input", flush=True)

    # Embed the input text; normalized vectors have unit L2 norm
    vectors = model.encode(
        sentences=df["input"].values_host,
        normalize_embeddings=True,
        show_progress_bar=True,
    )
    
    # Attach the vectors as a list column and return only the id and vector columns
    df['vector'] = cudf.Series(vectors.tolist(), index=df.index)
    return df[['id', 'vector']]

def clear_workers():
    """
    Delete the cached model and release cached GPU memory on this Dask worker
    """
    import torch
    import gc

    worker = get_worker()
    if hasattr(worker, "model"):
        del worker.model
    torch.cuda.empty_cache()
    gc.collect()
    return
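`embed_partition` loads the model lazily and caches it as an attribute on the worker object, so later partitions on the same worker skip the load, and `clear_workers` deletes that attribute. The pattern can be checked without Dask. In this sketch `types.SimpleNamespace` stands in for the object returned by `get_worker()`, and `get_cached` and `load_model` are hypothetical names; the loader just counts its calls instead of building a `SentenceTransformer`.

```python
from types import SimpleNamespace
from typing import Callable

def get_cached(owner: object, attr: str, load: Callable[[], object]) -> object:
    """Return owner.<attr>, calling `load` only the first time (the hasattr check in embed_partition)."""
    if not hasattr(owner, attr):
        setattr(owner, attr, load())
    return getattr(owner, attr)

load_calls = 0

def load_model() -> str:
    """Hypothetical loader; the notebook builds a SentenceTransformer here."""
    global load_calls
    load_calls += 1
    return "fake-model"

worker = SimpleNamespace()  # stand-in for the object returned by get_worker()
for _ in range(3):  # three partitions processed by the same worker
    model = get_cached(worker, "model", load_model)
print(model, load_calls)  # → fake-model 1

# Like clear_workers: deleting the attribute forces a fresh load next time
del worker.model
get_cached(worker, "model", load_model)
print(load_calls)  # → 2
```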

In [None]:
output_df = ddf[["id", "input"]].map_partitions(
    embed_partition,
    meta={
        "id": object,
        "vector": cudf.ListDtype("float32"),
    },
)
# Start computing the embeddings and keep the results in worker memory
output_df = output_df.persist()
%time _ = wait(output_df)
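Conceptually, `map_partitions` runs `embed_partition` once per partition in parallel, and `wait` blocks until every partition has finished. The sketch below mimics that flow with Python's standard `concurrent.futures` thread pool instead of Dask (a swapped-in technique, not how Dask schedules work). `split_into_partitions` and `toy_embed_partition` are hypothetical: the split cuts the rows into contiguous, nearly equal blocks, which is not necessarily Dask's exact partitioning rule, and the toy "embedding" is just the input length rather than a model output.

```python
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Dict, List

Row = Dict[str, object]

def split_into_partitions(rows: List[Row], n: int) -> List[List[Row]]:
    """Split rows into n contiguous, nearly equal blocks (illustrative only)."""
    size, extra = divmod(len(rows), n)
    parts, start = [], 0
    for i in range(n):
        end = start + size + (1 if i < extra else 0)
        parts.append(rows[start:end])
        start = end
    return parts

def toy_embed_partition(part: List[Row]) -> List[Row]:
    """Hypothetical stand-in for embed_partition: returns only id and vector."""
    return [{"id": r["id"], "vector": [float(len(str(r["input"])))]} for r in part]

rows = [{"id": str(i), "input": "x" * i} for i in range(10)]
partitions = split_into_partitions(rows, 4)
with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(toy_embed_partition, p) for p in partitions]
    wait(futures)  # block until every partition is done, like wait(output_df)
    results = [row for f in futures for row in f.result()]

print([len(p) for p in partitions])  # → [3, 3, 2, 2]
```

Like the real `embed_partition`, the toy function returns only `id` and `vector`, which is why the notebook merges the result back onto `ddf` afterwards.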

In [39]:
# Attach the embeddings to the paper metadata (inner join on the paper id)
full_ddf = ddf.merge(output_df, on="id")

In [None]:
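# Make every non-vector column a string column; nulls become empty strings
str_cols = list(full_ddf.columns)
str_cols.remove("vector")
full_ddf[str_cols] = full_ddf[str_cols].fillna('').astype(str)
# Drop any rows that still contain nulls (e.g. a missing vector)
full_ddf = full_ddf.dropna()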

In [43]:
with open(OUTPUT_EMBEDDINGS_PATH, 'wb') as f:
    pickle.dump(full_ddf.compute().to_pandas(), f)
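Because the embeddings were created with `normalize_embeddings=True`, every vector has unit length, so the dot product of two vectors equals their cosine similarity. A minimal pure-Python sketch of ranking saved vectors against a query under that property; `normalize`, `top_k`, the ids, and the two-dimensional vectors are made up for illustration (real `all-mpnet-base-v2` vectors are much longer).

```python
import math
from typing import List, Sequence, Tuple

def normalize(v: Sequence[float]) -> List[float]:
    """Scale a vector to unit L2 norm, as normalize_embeddings=True does."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm == 0:
        raise ValueError("cannot normalize a zero vector")
    return [x / norm for x in v]

def top_k(query: Sequence[float], vectors: Sequence[Sequence[float]],
          ids: Sequence[str], k: int = 2) -> List[Tuple[str, float]]:
    """Rank unit vectors by dot product, which equals cosine similarity for unit vectors."""
    q = normalize(query)
    scores = [(i, sum(a * b for a, b in zip(q, v))) for i, v in zip(ids, vectors)]
    return sorted(scores, key=lambda s: s[1], reverse=True)[:k]

vectors = [normalize(v) for v in ([3.0, 4.0], [1.0, 0.0], [0.0, 2.0])]
ids = ["a", "b", "c"]
result = top_k([0.0, 1.0], vectors, ids)
print(result)  # → [('c', 1.0), ('a', 0.8)]
```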

In [28]:
# Free the cached models and GPU memory on every Dask worker
client.run(clear_workers)