In [None]:
import ray
import pandas as pd
import numpy as np
from time import time
import torch
from transformers import AutoTokenizer, AutoModel
import torch.nn.functional as F
from optimum.bettertransformer import BetterTransformer

# Start Ray (or connect to an already-running cluster) and set environment
# variables for the workers through the runtime environment.
ray.init(
    runtime_env={
        "env_vars": {
            # Turn off the Hugging Face tokenizers' own thread parallelism
            # inside the Ray workers.
            "TOKENIZERS_PARALLELISM": "false",
            # NOTE(review): presumably disables Aim UI telemetry - confirm.
            "AIM_UI_TELEMETRY_ENABLED": "0",
        },
    },
    # Make this cell safe to re-run in a live kernel: without this flag a
    # second `ray.init()` call raises because Ray is already initialized.
    ignore_reinit_error=True,
)

In [None]:
class EmbedOptimum:
    """Callable class used by Ray Data to turn batches of text into embeddings.

    An instance loads a Hugging Face tokenizer and model once in `__init__`;
    each call tokenizes the batch's "text" column, runs the model, mean-pools
    the token embeddings, L2-normalizes them and stores the result in the
    batch's "values" column.
    """

    def __init__(self, model_name="thenlper/gte-large", chunk_size=512, device="cuda"):
        """Load the tokenizer and model, then optimize the model for inference.

        Args:
            model_name: Hugging Face model identifier passed to
                `AutoTokenizer.from_pretrained` / `AutoModel.from_pretrained`.
            chunk_size: Maximum token length per input; longer inputs are
                truncated by the tokenizer.
            device: torch device the model and its inputs are moved to.
                Defaults to "cuda", the previously hard-coded value.
        """
        self.model_name = model_name
        self.chunk_size = chunk_size
        self.device = device
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

        self.model = AutoModel.from_pretrained(model_name).to(device)
        self.model = BetterTransformer.transform(self.model)
        self.model = torch.compile(self.model)

    def _average_pool(self, last_hidden_states, attention_mask):
        """Mean-pool token embeddings over the positions the mask marks as 1.

        Args:
            last_hidden_states: Per-token embeddings; expected to be indexed
                as (batch, sequence, hidden) since pooling sums over dim 1.
            attention_mask: Mask indexed as (batch, sequence); non-zero marks
                a position that contributes to the average.

        Returns:
            One pooled vector per batch row.
        """
        keep = attention_mask[..., None].bool()
        masked_hidden = last_hidden_states.masked_fill(~keep, 0.0)
        # Clamp so a row with no attended tokens yields zeros instead of
        # 0 / 0 = NaN.
        token_counts = attention_mask.sum(dim=1)[..., None].clamp(min=1)
        return masked_hidden.sum(dim=1) / token_counts

    def __call__(self, batch):
        """Embed one batch.

        Args:
            batch: Mapping of column name to array; must contain a "text"
                column whose value supports `.tolist()`.

        Returns:
            The same `batch` object with an added "values" column holding the
            float32 embeddings as a numpy array.
        """
        batch_text = batch["text"].tolist()
        # Ask the tokenizer for torch tensors directly.
        tokenize_results = self.tokenizer(
            batch_text,
            max_length=self.chunk_size,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )

        # Either torch.no_grad() or torch.inference_mode() works here; disabling
        # autograd significantly reduces GPU memory usage and prevents OOMs
        # during the model call.
        with torch.no_grad():
            # Forward only the inputs the tokenizer actually produced: not
            # every tokenizer emits `token_type_ids`, so it is optional.
            model_input = {
                key: tokenize_results[key].to(self.device)
                for key in ("input_ids", "token_type_ids", "attention_mask")
                if key in tokenize_results
            }

            outputs = self.model(**model_input)
            embeddings = self._average_pool(outputs.last_hidden_state, model_input["attention_mask"])
            embeddings = F.normalize(embeddings)

            batch["values"] = embeddings.detach().cpu().numpy().astype(np.float32)
            return batch


def flatten_metadata_col(row):
    """Promote each entry of the nested "metadata" mapping to its own column.

    Every key `k` in `row["metadata"]` becomes a top-level `metadata_<k>`
    entry, and the nested "metadata" entry is removed. With flat columns the
    default numpy batch format can be used instead of pandas, which can be
    more memory demanding.

    The row is modified in place and also returned.
    """
    nested = row["metadata"]
    promoted = {"metadata_" + key: value for key, value in nested.items()}
    row.update(promoted)
    del row["metadata"]
    return row

In [None]:
# Collect the chunked input files to embed.
#
# The CSV below has a single column, `files`, whose values are URLs of the
# parquet files holding the chunked input data. Those parquet files are
# produced by the notebook `0_chunk_raw_inputs.ipynb`.
CHUNKED_FILES_CSV = "CHUNKED-FILES-PATHS.csv"
chunked_files_df = pd.read_csv(CHUNKED_FILES_CSV)
chunked_file_paths = chunked_files_df["files"].tolist()
NUM_CHUNKED_FILES = len(chunked_file_paths)

In [None]:
from ray.data._internal.execution.backpressure_policy import (
    ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
    ConcurrencyCapBackpressurePolicy,
    StreamingOutputBackpressurePolicy,
)

# How many (GPU) workers generate embeddings.
embed_num_workers = 16

# Batch size passed to the embedding `map_batches` call.
embed_batch_size = 1000

# Concurrency cap, used to keep the long-running job stable.
# The reference point is a max read parallelism of 200 for 16 workers; the cap
# scales linearly with the number of GPU workers and never drops below one
# slot per worker.
MAX_CONCURRENCY_PER_16_WORKERS = 200
_scaled_concurrency = int(MAX_CONCURRENCY_PER_16_WORKERS * embed_num_workers / 16)
MAX_CONCURRENCY = (
    _scaled_concurrency
    if _scaled_concurrency >= embed_num_workers
    else embed_num_workers
)

ctx = ray.data.DataContext.get_current()

# The max target block size set here, together with the batch sizes given to
# the `map_batches` calls below, keeps the output files in the hundreds of MBs.
ctx.target_max_block_size = 300_000_000

# The job is long, so tolerate some failed blocks: the allowance equals the
# number of chunked input files.
ctx.max_errored_blocks = NUM_CHUNKED_FILES

# Backpressure settings: enable both policies and start the concurrency cap at
# MAX_CONCURRENCY.
configs = dict(
    [
        (
            ENABLED_BACKPRESSURE_POLICIES_CONFIG_KEY,
            [
                ConcurrencyCapBackpressurePolicy,
                StreamingOutputBackpressurePolicy,
            ],
        ),
        (ConcurrencyCapBackpressurePolicy.INIT_CAP_CONFIG_KEY, MAX_CONCURRENCY),
        # NOTE(review): a multiplier of 1 presumably keeps the cap fixed - confirm.
        (ConcurrencyCapBackpressurePolicy.CAP_MULTIPLIER_CONFIG_KEY, 1),
    ]
)
for config_key, config_value in configs.items():
    ctx.set_config(config_key, config_value)

In [None]:
# Main Ray Data pipeline: read parquet -> flatten metadata -> embed -> write.
embeddings_output_path = "YOUR-OUTPUT-BUCKET-HERE"

print(f"===> Starting embedding for {len(chunked_file_paths)} chunked files")
start_t = time()

# Read the chunked inputs and pop the metadata entries out into flat columns.
chunked_ds_read = ray.data.read_parquet(chunked_file_paths).map(flatten_metadata_col)

# Embedding stage: a pool of `EmbedOptimum` actors.
embedded_ds = chunked_ds_read.map_batches(
    EmbedOptimum,
    concurrency=embed_num_workers,
    batch_size=embed_batch_size,
    num_gpus=1,  # One GPU per actor.
    max_concurrency=2,  # Reduce GPU idle time.
)
# Identity pass whose only purpose is to control the output file sizes.
embedded_ds = embedded_ds.map_batches(lambda x: x, batch_size=10_000)

embedded_ds.write_parquet(embeddings_output_path)
print(f"===> Finished embedding {len(chunked_file_paths)} chunked files in {time() - start_t} s")
print(f"===> Wrote output files to {embeddings_output_path}")
