In [None]:
import io
import os
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Literal, Type
import numpy as np
import requests
import ray
import torchvision
from enum import Enum
import typer
import pandas as pd
import pyarrow as pa
from pyarrow import csv
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoTokenizer
from vllm.multimodal.image import ImagePixelData
from vllm import LLM, SamplingParams
from PIL import Image
from pymongo import MongoClient, UpdateOne

# Reinventing Multi-Modal Search with Anyscale and MongoDB

## Data processing pipeline tutorial

__Let's look at the data flow logic__

<img src='https://images.ctfassets.net/xjan103pcp94/Kb1UzXpXsig64xD1f29jE/945c113fd16f1badfd6d8c4400967062/image14.png' width=800px />

__Let's look at the scaling opportunities__

<img src='https://images.ctfassets.net/xjan103pcp94/4qBYwRfwQOD4DjA66N75Gq/8884421a98ea5013533e24f2a17e9cf5/image1.png' width=1000px />

<img src='https://images.ctfassets.net/xjan103pcp94/6kROYvGncNbWlCsfklwEjz/e12cd36ed7be5cd45e603d38cb10dc95/image2.png' width=1000px />

For this tutorial, we'll use a small number of records and a small scaling configuration.

In [None]:
# Number of product records to push through the pipeline (kept small for the tutorial).
nsamples: int = 200

The workers configured below correspond to processes, each assigned CPU resources and optionally a GPU. Ray allows more flexibility, but we'll keep this pipeline as simple as possible.

In [None]:
# --- Image download stage ---
num_image_download_workers: int = 2

# --- LLaVA description stage ---
num_llava_tokenizer_workers: int = 2
num_llava_model_workers: int = 1
# llava_model_accelerator_type=NVIDIA_TESLA_A10G
llava_model_batch_size: int = 80

# --- Mistral classifier stage (counts are per classifier) ---
num_mistral_tokenizer_workers_per_classifier: int = 2
num_mistral_model_workers_per_classifier: int = 1
mistral_model_batch_size: int = 80
# mistral_model_accelerator_type=NVIDIA_TESLA_A10G
num_mistral_detokenizer_workers_per_classifier: int = 2

# --- Embedding stage ---
num_embedder_workers: int = 1
embedding_model_batch_size: int = 80
# embedding_model_accelerator_type=NVIDIA_TESLA_A10G

# --- MongoDB upsert stage ---
db_update_batch_size: int = 80
num_db_workers: int = 2

In [None]:
# MongoDB destination for the enriched product records.
db_name: str = "myntra"
collection_name: str = "myntra-items-offline"
# NOTE(review): the two settings below are not referenced by any later cell in
# this notebook — confirm whether they are still needed.
cluster_size: str = "m0"
scaling_config_path: str = ""

### Read raw data

The data comes from a subset of the Myntra retail products dataset: https://www.kaggle.com/datasets/ronakbokaria/myntra-products-dataset

In [None]:
# Public S3 CSV holding the Myntra product subset (deduplicated, 10k rows per its file name).
path = (
    "s3://anyscale-public-materials/mongodb-demo/raw/"
    "myntra_subset_deduped_10000.csv"
)

The Ray Data `read` methods use PyArrow in most cases for the physical reads -- the schema here is provided as PyArrow types.

> Learn more about the Apache Arrow project https://arrow.apache.org/docs/python/index.html

In [None]:
def read_data(path: str, nsamples: int) -> ray.data.Dataset:
    """Load the raw product CSV as a Ray Dataset truncated to ``nsamples`` rows.

    Args:
        path: Location of the CSV file (e.g. an ``s3://`` URI).
        nsamples: Number of rows to keep; also passed as ``override_num_blocks``
            so the read is split into that many blocks.

    Returns:
        A Ray Dataset with at most ``nsamples`` rows and the explicit column
        types declared below.
    """
    # Explicit PyArrow types for every column instead of relying on inference.
    column_types = {
        "id": pa.int64(),
        "name": pa.string(),
        "img": pa.string(),
        "asin": pa.string(),
        "price": pa.float64(),
        "mrp": pa.float64(),
        "rating": pa.float64(),
        "ratingTotal": pa.int64(),
        "discount": pa.int64(),
        "seller": pa.string(),
        "purl": pa.string(),
    }
    raw = ray.data.read_csv(
        path,
        # Allow newline characters inside quoted field values.
        parse_options=csv.ParseOptions(newlines_in_values=True),
        convert_options=csv.ConvertOptions(column_types=column_types),
        override_num_blocks=nsamples,
    )
    return raw.limit(nsamples)

In [None]:
# Build the raw dataset from the CSV.
ds = read_data(path=path, nsamples=nsamples)

### Preprocess data

The following operations are transforms applied to our data.

We'll define them first...
* functions for stateless operations
* classes for operations which reuse state

... and then plug them into our pipeline with Ray's `map_batches` API

In [None]:
def download_image(url: str, timeout: float = 10.0) -> bytes:
    """Fetch the raw bytes at ``url``, returning ``b""`` on any failure.

    Failures are deliberately swallowed: an empty payload lets the pipeline's
    ``filter`` step drop the row instead of failing the whole batch.

    Args:
        url: Image URL to download.
        timeout: Seconds ``requests`` may wait on the server before giving up.
            Without a timeout, a single unresponsive host can block the calling
            thread indefinitely and stall the download worker.

    Returns:
        The response body, or ``b""`` if the request raised or returned an
        error status.
    """
    try:
        response = requests.get(url, timeout=timeout)
        response.raise_for_status()
        return response.content
    except Exception:
        return b""

def download_images(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Download every URL in ``batch["img"]`` concurrently with a thread pool.

    The original URLs are kept under ``"url"``; ``"img"`` is replaced by the
    downloaded bytes (``b""`` for failures), in the same order.
    """
    urls = batch["img"]
    batch["url"] = urls
    with ThreadPoolExecutor() as pool:
        payloads = [payload for payload in pool.map(download_image, urls)]
    batch["img"] = payloads  # type: ignore
    return batch

class LargestCenterSquare:
    """Resize-then-center-crop transform that yields ``size`` x ``size`` images.

    Used as a row UDF: ``row["img"]`` holds encoded image bytes on input and the
    cropped PIL image on output.
    """

    def __init__(self, size: int) -> None:
        # Edge length, in pixels, of the square output.
        self.size = size

    def __call__(self, row: dict[str, Any]) -> dict[str, Any]:
        """Decode ``row["img"]``, scale its shorter side to ``self.size`` and crop the center square.

        NOTE(review): ``Image.open`` raises if the bytes are not a decodable image
        (e.g. an HTML error page served with a 200 status), which fails the task.
        """
        side = self.size
        decoded = Image.open(io.BytesIO(row["img"]))

        # An int ``size`` makes torchvision scale the shorter edge to ``side``
        # while preserving the aspect ratio.
        scaled = torchvision.transforms.functional.resize(img=decoded, size=side)

        width, height = scaled.size
        row["img"] = torchvision.transforms.functional.crop(
            img=scaled,
            top=(height - side) // 2,
            left=(width - side) // 2,
            height=side,
            width=side,
        )
        return row

# LLaVA description prompt: 1176 "<image>" placeholders (kept equal to the
# ``image_feature_size=1176`` passed to the LLaVA vLLM engine) followed by the
# chat turn. ``{title}`` is filled per row via ``str.format``.
# The trailing space after "{title}." matters: adjacent string literals are
# concatenated with no separator.
DESCRIPTION_PROMPT_TEMPLATE = "<image>" * 1176 + (
    "\nUSER: Generate an ecommerce product description given the image and this title: {title}. "
    "Make sure to include information about the color of the product in the description."
    "\nASSISTANT:"
)

def gen_description_prompt(row: dict[str, Any]) -> dict[str, Any]:
    """Add ``row["description_prompt"]`` by filling the template with the product ``name``."""
    prompt = DESCRIPTION_PROMPT_TEMPLATE.format(title=row["name"])
    row["description_prompt"] = prompt
    return row

In [None]:
# Preprocessing pipeline; ``materialize`` executes it once and caches the
# result so later cells reuse the downloaded images.
ds = (
    # Download image bytes in batches; each download worker reserves 4 CPUs.
    ds.map_batches(download_images, num_cpus=4, concurrency=num_image_download_workers)
    # Drop rows whose download failed (empty bytes).
    .filter(lambda x: bool(x["img"]))
    # 336x336 center crop, matching the LLaVA engine's image_input_shape.
    .map(LargestCenterSquare(size=336))
    # Attach the LLaVA description prompt for each product.
    .map(gen_description_prompt)
    .materialize()
)

In [None]:
# Peek at two preprocessed rows.
ds.take_batch(batch_size=2)

### Estimate input/output token distribution for LLAVA model

In [None]:
class LlaVAMistralTokenizer:
    """Batch UDF that tokenizes a text column with the LLaVA-v1.6-Mistral tokenizer."""

    def __init__(self):
        # Loaded once per actor and reused for every batch it processes.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "llava-hf/llava-v1.6-mistral-7b-hf",
        )

    def __call__(self, batch: dict[str, np.ndarray], input: str, output: str):
        """Store the token ids of each ``batch[input]`` string under ``batch[output]``."""
        texts = batch[input].tolist()
        encoded = self.tokenizer.batch_encode_plus(texts)
        batch[output] = encoded["input_ids"]
        return batch
        
def compute_num_tokens(row: dict[str, Any], col: str) -> dict[str, Any]:
    """Record the length of the token sequence in ``row[col]`` as ``row["num_tokens"]``."""
    token_count = len(row[col])
    row["num_tokens"] = token_count
    return row
        
# Token budget for the LLaVA step: tokenize every description prompt and take
# the longest, then add room for the generated description. The resulting
# ``max_model_length`` sizes the vLLM engine in the next step.
# NOTE(review): ``Dataset.max`` returns None for an empty dataset (e.g. if every
# image download failed), which would make the addition below raise TypeError.
max_input_tokens = (
    ds.map_batches(
        LlaVAMistralTokenizer,
        fn_kwargs={
            "input": "description_prompt",
            "output": "description_prompt_tokens",
        },
        concurrency=num_llava_tokenizer_workers,
        num_cpus=1,
    )
    .select_columns(["description_prompt_tokens"])
    .map(compute_num_tokens, fn_kwargs={"col": "description_prompt_tokens"})
    .max(on="num_tokens")
)

max_output_tokens = 256  # maximum size of desired product description
max_model_length = max_input_tokens + max_output_tokens
print(
    f"Description gen: {max_input_tokens=} {max_output_tokens=} {max_model_length=}"
)

### Generate description using LLAVA model inference

In [None]:
class LlaVAMistral:
    """Batch UDF that writes a generated product description for each row using LLaVA-v1.6-Mistral on vLLM.

    Each actor owns one vLLM engine. ``__call__`` pairs every prompt with its
    image and stores the generated text in ``batch["description"]``.
    """

    def __init__(
        self,
        max_model_len: int,
        max_num_seqs: int = 400,
        max_tokens: int = 1024,
        # NOTE: "fp8" currently doesn't support FlashAttention-2 backend so while
        # we can fit more sequences in memory, performance will be suboptimal
        kv_cache_dtype: str = "fp8",
    ):
        self.llm = LLM(
            model="llava-hf/llava-v1.6-mistral-7b-hf",
            trust_remote_code=True,
            enable_lora=False,
            max_num_seqs=max_num_seqs,
            max_model_len=max_model_len,
            gpu_memory_utilization=0.95,
            # Vision inputs: images arrive as pixel data of shape 1x3x336x336.
            # The 1176 feature size matches the number of "<image>" placeholders
            # in the description prompt. Presumably 32000 is the "<image>" token
            # id — verify against the tokenizer.
            image_input_type="pixel_values",
            image_token_id=32000,
            image_input_shape="1,3,336,336",
            image_feature_size=1176,
            kv_cache_dtype=kv_cache_dtype,
            preemption_mode="swap",
        )
        # Greedy (temperature 0), single-completion decoding with neutral penalties.
        self.sampling_params = SamplingParams(
            n=1,
            presence_penalty=0,
            frequency_penalty=0,
            repetition_penalty=1,
            length_penalty=1,
            top_p=1,
            top_k=-1,
            temperature=0,
            use_beam_search=False,
            ignore_eos=False,
            max_tokens=max_tokens,
            seed=None,
            detokenize=True,
        )

    def __call__(self, batch: dict[str, np.ndarray], col: str) -> dict[str, np.ndarray]:
        """Generate one description per row from ``batch[col]`` prompts and ``batch["img"]`` images."""
        llm_inputs = []
        for prompt, image in zip(batch[col], batch["img"]):
            llm_inputs.append(
                {"prompt": prompt, "multi_modal_data": ImagePixelData(image)}
            )
        generations = self.llm.generate(
            llm_inputs,
            sampling_params=self.sampling_params,
        )
        batch["description"] = [gen.outputs[0].text for gen in generations]  # type: ignore
        return batch

In [None]:
# LLaVA description generation: each model worker is a 1-GPU actor holding a
# vLLM engine sized by the token budget computed in the previous step.
ds = ds.map_batches(
    LlaVAMistral,
    fn_constructor_kwargs={
        "max_model_len": max_model_length,
        "max_tokens": max_output_tokens,
        "max_num_seqs": 400,
    },
    fn_kwargs={"col": "description_prompt"},
    batch_size=llava_model_batch_size,
    num_gpus=1.0,
    concurrency=num_llava_model_workers,
)

In [None]:
# Run the description step now and cache its output, then peek at two rows.
ds = ds.materialize()

ds.take_batch(batch_size=2)

### Generate classifier prompts and tokenize them

In the classification step, we'll reduce the number of classifiers (vs. the full pipeline) to require fewer GPUs.

In [None]:
def construct_prompt_classifier(
    row: dict[str, Any],
    prompt_template: str,
    classes: list[str],
    col: str,
) -> dict[str, Any]:
    """Add a ``f"{col}_prompt"`` field asking the LLM to choose one of ``classes``.

    Args:
        row: Record with ``"name"`` and ``"description"`` fields; modified in place.
        prompt_template: ``str.format`` template using ``{title}``,
            ``{description}`` and ``{classes_str}``.
        classes: Candidate labels, rendered comma-separated into the prompt.
        col: Classifier name; the prompt is stored under ``f"{col}_prompt"``.

    Returns:
        The same ``row`` object with the prompt added.
    """
    row[f"{col}_prompt"] = prompt_template.format(
        title=row["name"],
        description=row["description"],
        classes_str=", ".join(classes),
    )
    return row


# Classifiers applied to every product: candidate labels, the prompt template,
# and the function that renders it. The "color" classifier from the full
# pipeline is disabled to need fewer GPUs. Re-enabling it also requires
# restoring the "color" field in ``update_record``.
classifiers: dict[str, Any] = {
    "category": {
        "classes": ["Tops", "Bottoms", "Dresses", "Footwear", "Accessories"],
        "prompt_template": (
            "Given the title of this product: {title} and "
            "the description: {description}, what category does it belong to? "
            "Choose from the following categories: {classes_str}. "
            "Return the category that best fits the product. Only return the category name and nothing else."
        ),
        "prompt_constructor": construct_prompt_classifier,
    },
    "season": {
        "classes": ["Summer", "Winter", "Spring", "Fall"],
        "prompt_template": (
            "Given the title of this product: {title} and "
            "the description: {description}, what season does it belong to? "
            "Choose from the following seasons: {classes_str}. "
            "Return the season that best fits the product. Only return the season name and nothing else."
        ),
        "prompt_constructor": construct_prompt_classifier,
    },
#    "color": {
#        "classes": [
#            "Red",
#            "Blue",
#            "Green",
#            "Yellow",
#            "Black",
#            "White",
#            "Pink",
#            "Purple",
#            "Orange",
#            "Brown",
#            "Grey",
#        ],
#        "prompt_template": (
#            "Given the title of this product: {title} and "
#            "the description: {description}, what color does it belong to? "
#            "Choose from the following colors: {classes_str}. "
#            "Return the color that best fits the product. Only return the color name and nothing else."
#        ),
#        "prompt_constructor": construct_prompt_classifier,
#    },
}

In [None]:
class MistralTokenizer:
    """Batch UDF that applies the Mistral-7B-Instruct chat template and tokenizes the result."""

    def __init__(self):
        # Loaded once per actor and reused across batches.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
        )

    def __call__(self, batch: dict, input: str, output: str):
        """Wrap each ``batch[input]`` string as a single user turn; store its token ids in ``batch[output]``.

        ``add_generation_prompt=True`` appends the assistant-turn prefix so the
        model's continuation is its answer.
        """
        conversations = []
        for text in batch[input]:
            conversations.append([{"role": "user", "content": text}])
        batch[output] = self.tokenizer.apply_chat_template(
            conversation=conversations,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="np",
        )
        return batch

In [None]:
# For each classifier: render its prompt into ``{classifier}_prompt``, tokenize it
# into ``{classifier}_prompt_tokens`` with the Mistral chat template, and
# materialize so the next classifier builds on the cached result.
for classifier, classifier_spec in classifiers.items():
    ds = (
        ds.map(
            classifier_spec["prompt_constructor"],
            fn_kwargs={
                "prompt_template": classifier_spec["prompt_template"],
                "classes": classifier_spec["classes"],
                "col": classifier,
            },
        )
        .map_batches(
            MistralTokenizer,
            fn_kwargs={
                "input": f"{classifier}_prompt",
                "output": f"{classifier}_prompt_tokens",
            },
            concurrency=num_mistral_tokenizer_workers_per_classifier,
            num_cpus=1,
        )
        .materialize()
    )

In [None]:
# Inspect two rows with the classifier prompts and their token ids.
ds.take_batch(batch_size=2)

### Estimate input/output token distribution for Mistral models

In [None]:
# Size each classifier's vLLM engine and store the result in its spec
# ("max_output_tokens", "max_model_length"):
# * output budget = token count of the longest class label plus a buffer. The
#   count comes from the same chat-template tokenizer, so it includes template
#   tokens and errs on the high side.
# * input budget = longest tokenized prompt in ``ds``.
# Loop-local names are used on purpose so the notebook-level
# ``max_input_tokens`` / ``max_output_tokens`` / ``max_model_length`` computed for
# the LLaVA step are not silently overwritten (which would corrupt a re-run of
# that cell).

# Allow 40 tokens of buffer for non-exact outputs, e.g. "the color is Red"
# instead of just "Red".
classifier_output_buffer = 40

for classifier, classifier_spec in classifiers.items():
    longest_label = max(classifier_spec["classes"], key=len)
    label_tokens = (
        ray.data.from_items([{"output": longest_label}])
        .map_batches(
            MistralTokenizer,
            fn_kwargs={
                "input": "output",
                "output": "output_tokens",
            },
            concurrency=1,
            num_cpus=1,
        )
        .map(compute_num_tokens, fn_kwargs={"col": "output_tokens"})
        .max(on="num_tokens")
    )
    classifier_spec["max_output_tokens"] = label_tokens + classifier_output_buffer

    prompt_tokens = (
        ds.select_columns([f"{classifier}_prompt_tokens"])
        .map(compute_num_tokens, fn_kwargs={"col": f"{classifier}_prompt_tokens"})
        .max(on="num_tokens")
    )
    output_budget = classifier_spec["max_output_tokens"]
    print(
        f"classifier={classifier!r} max_input_tokens={prompt_tokens!r} "
        f"max_output_tokens={output_budget!r}"
    )
    classifier_spec["max_model_length"] = prompt_tokens + output_budget

In [None]:
class MistralvLLM:
    """Batch UDF that runs Mistral-7B-Instruct classification through vLLM on token ids.

    The engine neither initializes a tokenizer nor detokenizes. Prompts arrive
    as token ids and outputs leave as token ids; text conversion happens in
    separate CPU stages.
    """

    def __init__(
        self,
        max_model_len: int = 4096,
        max_tokens: int = 2048,
        max_num_seqs: int = 256,
        # NOTE: "fp8" currently doesn't support FlashAttention-2 backend so while
        # we can fit more sequences in memory, performance will be suboptimal
        kv_cache_dtype: str = "fp8",
    ):
        self.llm = LLM(
            model="mistralai/Mistral-7B-Instruct-v0.1",
            trust_remote_code=True,
            enable_lora=False,
            max_num_seqs=max_num_seqs,
            max_model_len=max_model_len,
            gpu_memory_utilization=0.90,
            # Prompts are pre-tokenized upstream, so no tokenizer is loaded here.
            skip_tokenizer_init=True,
            kv_cache_dtype=kv_cache_dtype,
            preemption_mode="swap",
        )
        # Greedy, single-sample decoding; results are kept as token ids.
        self.sampling_params = SamplingParams(
            n=1,
            presence_penalty=0,
            frequency_penalty=0,
            repetition_penalty=1,
            length_penalty=1,
            top_p=1,
            top_k=-1,
            temperature=0,
            use_beam_search=False,
            ignore_eos=False,
            max_tokens=max_tokens,
            seed=None,
            detokenize=False,
        )

    def __call__(
        self, batch: dict[str, np.ndarray], input: str, output: str
    ) -> dict[str, np.ndarray]:
        """Generate from the token-id sequences in ``batch[input]`` and store output ids in ``batch[output]``."""
        prompt_ids = []
        for token_ids in batch[input]:
            prompt_ids.append(token_ids.tolist())
        generations = self.llm.generate(
            prompt_token_ids=prompt_ids,
            sampling_params=self.sampling_params,
        )
        batch[output] = [gen.outputs[0].token_ids for gen in generations]  # type: ignore
        return batch


class MistralDeTokenizer:
    """Batch UDF that converts generated Mistral token ids back into text."""

    def __init__(self) -> None:
        # Loaded once per actor and reused across batches.
        self.tokenizer = AutoTokenizer.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.1",
        )

    def __call__(self, batch: dict[str, np.ndarray], key: str) -> dict[str, np.ndarray]:
        """Replace the token-id sequences in ``batch[key]`` with decoded strings (special tokens removed)."""
        token_sequences = batch[key]
        decoded = self.tokenizer.batch_decode(token_sequences, skip_special_tokens=True)
        batch[key] = decoded
        return batch

In [None]:
def clean_response(
    row: dict[str, Any], response_col: str, classes: list[str]
) -> dict[str, Any]:
    """Normalize a free-text classifier answer to a single class label.

    Replaces ``row[response_col]`` with the one class whose name occurs in the
    response (case-insensitive substring match). It becomes ``None`` when no
    class or more than one class occurs.
    """
    lowered = row[response_col].lower()
    hits = [label for label in classes if label.lower() in lowered]
    row[response_col] = hits[0] if len(hits) == 1 else None
    return row

### Generate classifier responses using Mistral model inference

In [None]:
# For each classifier: generate answer token ids on a 1-GPU vLLM actor (engine
# sized from the per-classifier budgets), decode them to text on CPU workers,
# then map the free-text answer to a single class label (or None).
# The pipeline stays lazy here; the next cell materializes it.
for classifier, classifier_spec in classifiers.items():
    ds = (
        ds.map_batches(
            MistralvLLM,
            fn_kwargs={
                "input": f"{classifier}_prompt_tokens",
                "output": f"{classifier}_response",
            },
            fn_constructor_kwargs={
                "max_model_len": classifier_spec["max_model_length"],
                "max_tokens": classifier_spec["max_output_tokens"],
            },
            batch_size=mistral_model_batch_size,
            num_gpus=1.0,
            concurrency=num_mistral_model_workers_per_classifier,
        )
        .map_batches(
            MistralDeTokenizer,
            fn_kwargs={"key": f"{classifier}_response"},
            concurrency=num_mistral_detokenizer_workers_per_classifier,
            num_cpus=1,
        )
        .map(
            clean_response,
            fn_kwargs={
                "classes": classifier_spec["classes"],
                "response_col": f"{classifier}_response",
            },
        )
    )

In [None]:
# Execute the classifier stages and cache the result, then peek at two rows.
ds = ds.materialize()

ds.take_batch(batch_size=2)

### Generate embeddings using embedding model inference

To reduce resource requirements, this code differs from the full pipeline: it runs the embedder model on CPU. It's not as fast, but performance is acceptable at small data scales.

In [None]:
class EmbedderSentenceTransformer:
    """Batch UDF that adds sentence embeddings for the given text columns.

    In this tutorial the model stays on CPU. ``device`` is accepted so the
    constructor matches the full (GPU) pipeline, but it is intentionally not
    forwarded to the model.
    """

    def __init__(self, model: str = "thenlper/gte-large", device: str = "cuda"):
        self.model = SentenceTransformer(model) # comment out the use of the device param to keep model on CPU

    def __call__(
        self, batch: dict[str, np.ndarray], cols: list[str]
    ) -> dict[str, np.ndarray]:
        """For every column ``c`` in ``cols``, add ``batch[f"{c}_embedding"]``.

        The whole Ray batch is encoded in a single model call. ``batch_size`` is
        clamped to at least 1: ``encode`` steps through its inputs by
        ``batch_size``, so an empty Ray batch would otherwise pass 0 and raise.
        """
        for col in cols:
            texts = batch[col].tolist()
            batch[f"{col}_embedding"] = self.model.encode(  # type: ignore
                texts, batch_size=max(1, len(texts))
            )
        return batch
        
# Embed product names and descriptions. The GPU request and accelerator type
# are commented out so the embedder actors run on CPU in this tutorial.
ds = ds.map_batches(
    EmbedderSentenceTransformer,
    fn_kwargs={"cols": ["name", "description"]},
    batch_size=embedding_model_batch_size,
    #num_gpus=1.0,
    concurrency=num_embedder_workers,
    #accelerator_type=embedding_model_accelerator_type,
)

In [None]:
# Compute the embeddings and cache the result, then peek at two rows.
ds = ds.materialize()

ds.take_batch(batch_size=2)

### Upsert records into MongoDB collection

In [None]:
def update_record(batch: dict[str, np.ndarray]) -> dict[str, np.ndarray]:
    """Project a processed batch onto the MongoDB document layout.

    The product ``name`` doubles as the document ``_id`` (the upsert key), and
    the original image URL is stored as ``img``. Embedding arrays are converted
    to nested Python lists with ``.tolist()`` (assumes 2-D rows x dim arrays).
    The input batch is not modified.

    NOTE(review): products that share a name collapse into a single document.
    The CSV's ``id`` column may be a safer key — confirm before changing, since
    it would re-key existing collections.
    """
    return {
        "_id": batch["name"],
        "name": batch["name"],
        "img": batch["url"],
        "price": batch["price"],
        "rating": batch["rating"],
        "description": batch["description"],
        "category": batch["category_response"],
        "season": batch["season_response"],
#        "color": batch["color_response"],
        "name_embedding": batch["name_embedding"].tolist(),
        "description_embedding": batch["description_embedding"].tolist(),
    }


In [None]:
class MongoBulkUpdate:
    """Batch UDF that upserts each row into a MongoDB collection, keyed by ``_id``.

    The connection string comes from the ``DB_CONNECTION_STRING`` environment
    variable (``KeyError`` if unset). One client is created per actor and reused
    for every batch.
    """

    def __init__(self, db: str, collection: str) -> None:
        client = MongoClient(os.environ["DB_CONNECTION_STRING"])
        self.collection = client[db][collection]

    def __call__(self, batch_df: pd.DataFrame) -> dict[str, np.ndarray]:
        """Upsert every row of ``batch_df``, ``$set``-ing all its fields on the matching ``_id``.

        Returns an empty dict: only the database side effect matters downstream.
        Empty batches are skipped because ``bulk_write`` rejects an empty list of
        operations.
        """
        docs = batch_df.to_dict(orient="records")
        if not docs:
            return {}
        bulk_ops = [
            UpdateOne(filter={"_id": doc["_id"]}, update={"$set": doc}, upsert=True)
            for doc in docs
        ]
        self.collection.bulk_write(bulk_ops)
        return {}

In [None]:
# Final stage: project rows onto the document layout, then upsert them into
# MongoDB. The fractional CPU request lets several DB writer actors share one
# CPU. ``materialize`` is what actually triggers execution of the writes.
(
    ds.map_batches(update_record)
    .map_batches(
        MongoBulkUpdate,
        fn_constructor_kwargs={
            "db": db_name,
            "collection": collection_name,
        },
        batch_size=db_update_batch_size,
        concurrency=num_db_workers,
        num_cpus=0.1,
        batch_format="pandas",
    )
    .materialize()
)