In [None]:
from rayllm_batch.workload import ChatWorkloadBase
from typing import Optional, Dict, Any
import ray 
from ray.data.dataset import Dataset
from dataclasses import dataclass, field


@dataclass
class CNNDailySummary(ChatWorkloadBase):
    """Summarization workload over the CNN/DailyMail dataset.

    The ``train`` split is fetched from Hugging Face and every article is
    turned into a chat-style request (a system message plus a user message)
    asking the model to summarize the article's highlights.

    How ``dataset_file``, ``dataset_fraction`` and ``sampling_params`` are
    consumed is decided by ``ChatWorkloadBase``, which is defined elsewhere.
    """

    # No local dataset file is used; ``load_dataset`` below pulls the data
    # directly from Hugging Face.
    dataset_file: Optional[str] = None
    # Only a small portion of the dataset is used so the tutorial runs quickly.
    # 0.0005 is 0.05% of the roughly 300K entries (on the order of 150 rows,
    # assuming the fraction is applied to the loaded split -- the sampling
    # itself is not implemented in this class).
    dataset_fraction: float = 0.0005
    # Sampling params for the LLM inference workload. A default_factory is
    # used so that every instance gets its own dict.
    sampling_params: Dict[str, Any] = field(default_factory=lambda: {"max_tokens": 200})

    def load_dataset(self) -> Dataset:
        """Load the CNN/DailyMail (config "3.0.0") train split into Ray Data.

        Returns:
            A Ray ``Dataset`` built from the Hugging Face ``train`` split.
        """
        # Imported inside the method, so `datasets` is only imported when the
        # data is actually loaded.
        import datasets  # type: ignore

        # The result is indexed by split name below; it is a collection of
        # splits, not a DataFrame.
        hf_splits = datasets.load_dataset("cnn_dailymail", "3.0.0")
        return ray.data.from_huggingface(hf_splits["train"])

    def parse_row(self, row: Dict[str, Any]) -> Dict[str, Any]:
        """Convert one dataset row into a chat request.

        Args:
            row: A dataset record; only its ``"article"`` field is read.

        Returns:
            A dict with a single ``"messages"`` key holding two messages: a
            system message describing the summarization task and a user
            message that embeds the article followed by the instructions.

        Raises:
            KeyError: If ``row`` has no ``"article"`` key.
        """
        # The system prompt sets the task; the article itself goes into the
        # user prompt.
        system_prompt = (
            "You are a commentator. Your task is to "
            "summarize highlights from article."
        )
        user_prompt = (
            f"# Article:\n{row['article']}\n\n"
            "#Instructions:\nIn clear and concise language, "
            "summarize the highlights presented in the article."
        )
        return {
            "messages": [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ]
        }

In [None]:
# Obtain a Hugging Face token for the model; per the helper's name, the user
# is prompted for it when the model requires one.
from util.utils import prompt_for_hugging_face_token
HF_TOKEN = prompt_for_hugging_face_token(
    "meta-llama/Meta-Llama-3.1-8B-Instruct"
)

In [None]:
from rayllm_batch import init_engine_from_config
# Path of the file holding the model configs.
model_config_path = "configs/llama-3.1-8b-a10g.yaml"

# Engine configs can be overridden by passing in a dictionary.
# Here the override extends Ray's runtime env with the Hugging Face token;
# Ray is used under the hood to orchestrate the inference pipeline.
override = {
    "runtime_env": {
        "env_vars": {"HF_TOKEN": HF_TOKEN},
    },
}
engine_config = init_engine_from_config(
    config=model_config_path,
    override=override,
)


In [None]:
from rayllm_batch import RayLLMBatch


# Pair the summarization workload with the engine configuration.
workload = CNNDailySummary()
batch = RayLLMBatch(
    engine_cfg=engine_config,
    workload=workload,
    # Batch size used for inference. Make it as large as possible without
    # running out of memory; if out-of-memory errors occur, a smaller
    # batch_size may help.
    batch_size=None,
    # Number of replicas used for inference; each replica runs one instance
    # of the inference pipeline.
    num_replicas=1,
)


# Run the batch job; this runs until completion.
ds = batch.run()


# Gather the generated text of every result row and print it.
gen_texts = [record["generated_text"] for record in ds.take_all()]
print(gen_texts)