### Library imports

In [None]:
import json
import re
from typing import Any, Dict, List

from datasets import load_dataset

from rapidfireai.infer.experiment import Experiment
from rapidfireai.infer.rag.context_generator import ContextGenerator

### Model and Sampling Configuration

In [None]:
from rapidfireai.infer.utils.config import VLLMModelConfig

# Inference pipeline: Qwen2.5-0.5B-Instruct served through vLLM.
# NOTE(review): model_config presumably maps onto vLLM engine arguments and
# sampling_params onto vLLM SamplingParams — confirm in rapidfireai's docs.
pipeline = VLLMModelConfig(
    model_config=dict(
        model="Qwen/Qwen2.5-0.5B-Instruct",
        dtype="half",
        gpu_memory_utilization=0.9,
        tensor_parallel_size=1,
        distributed_executor_backend="mp",
        enable_chunked_prefill=True,
        enable_prefix_caching=True,  # the system prompt is shared by every request
        max_model_len=2048,
        disable_log_stats=True,  # turn off vLLM's periodic stats logging
    ),
    sampling_params=dict(
        temperature=0.8,
        top_p=0.95,
        max_tokens=512,
    ),
    context_generator=None,  # no retrieval context for this benchmark
)

### Dataset

In [None]:
# Load GSM8K. NOTE(review): this deliberately uses the *train* split — the
# per-sample id range (0, 7472) hard-coded in accumulate_metrics_fn appears to be
# sized for this split. Switch to split="test" only together with that range.
dataset = load_dataset("openai/gsm8k", "main", split="train")
print(f"Loaded {len(dataset)} train samples")

In [None]:
# Attach an integer "id" to every row so online aggregation can be sanity-checked:
# compute_metrics_fn sums these ids (SumN), and accumulate_metrics_fn declares a
# per-sample value_range of (0, 7472) for them. Ids are therefore 0-based
# (0 .. len(dataset) - 1). With 1-based ids, the largest id equals len(dataset),
# which is one above that declared upper bound for a 7473-row split.
dataset = dataset.add_column("id", list(range(len(dataset))))
# Fixed seed: rows arrive in random order, but reproducibly across runs.
dataset = dataset.shuffle(seed=1337)

### Utilities: Preprocessing, Postprocessing, and Metrics

In [None]:
def extract_solution(answer):
    """Extract the final numeric answer written after a ``####`` marker.

    GSM8K reference answers, and model generations following the system prompt
    built in preprocess_fn, end with ``#### <number>``. The first such number is
    returned with thousands separators removed and any trailing period stripped,
    so ``"... #### 1,234."`` becomes ``"1234"``.

    Args:
        answer: Text containing the reasoning and the final ``####`` answer.

    Returns:
        The extracted number as a string, or ``"0"`` when no answer is found.
        NOTE(review): a failed extraction therefore compares equal to a ground
        truth of "0".
    """
    match = re.search(r"#### (\-?[0-9\.\,]+)", answer)
    if match is None:
        return "0"
    # The character class also swallows a sentence-ending period ("#### 72."),
    # which would otherwise never equal the ground-truth string "72".
    number = match.group(1).replace(",", "").rstrip(".")
    return number or "0"

def preprocess_fn(batch: Dict[str, List], context_generator: ContextGenerator) -> Dict[str, List]:
    """Wrap every question of the batch in a two-message chat prompt.

    Args:
        batch: Column-oriented batch; its "question" list is read.
        context_generator: Accepted for interface compatibility; not used here.

    Returns:
        A new dict holding a "prompts" list (one [system, user] message list per
        question) followed by every original batch column.
    """
    instruction = 'Let\'s think step by step and output the final answer after "####".'
    prompts = []
    for question in batch["question"]:
        prompts.append([
            {"role": "system", "content": instruction},
            {"role": "user", "content": question},
        ])
    result = {"prompts": prompts}
    result.update(batch)
    return result

def postprocess_fn(batch: Dict[str, List]) -> Dict[str, List]:
    """Add parsed answers to the batch in place and return the same batch.

    "model_answer" is parsed from "generated_text" and "ground_truth" from the
    reference "answer", both with extract_solution, so they compare as strings.
    """
    batch["model_answer"] = list(map(extract_solution, batch["generated_text"]))
    batch["ground_truth"] = list(map(extract_solution, batch["answer"]))
    return batch

def compute_metrics_fn(batch: Dict[str, List]) -> Dict[str, Dict[str, Any]]:
    """Compute the per-batch counts that accumulate_metrics_fn later combines.

    Args:
        batch: Post-processed batch with parallel "model_answer", "ground_truth"
            and "id" lists.

    Returns:
        ``{"Correct": ..., "Total": ..., "SumN": ...}``, each wrapped as
        ``{"value": x}``: the number of exact string matches, the batch size,
        and the sum of the sample ids.
    """
    predictions = batch["model_answer"]
    # Summing booleans counts the matches (True == 1) without shadowing builtins.
    correct = sum(pred == gt for pred, gt in zip(predictions, batch["ground_truth"]))
    return {
        "Correct": {"value": correct},
        "Total": {"value": len(predictions)},
        # Sum of sample ids: a quantity with a known mean, used to check online aggregation.
        "SumN": {"value": sum(batch["id"])},
    }

def accumulate_metrics_fn(aggregated_metrics: Dict[str, List]) -> Dict[str, Dict[str, Any]]:
    """Combine per-batch metric entries into run-level totals and ratios.

    Args:
        aggregated_metrics: Maps each metric name produced by compute_metrics_fn
            ("Correct", "Total", "SumN") to the list of its per-batch entries. Each
            entry is a ``{"value": x}`` dict, e.g.
            ``{"Correct": [{"value": 5}, {"value": 3}], "Total": [{"value": 10}, ...]}``.
            A missing name counts as zero.

    Returns:
        Metric name -> ``{"value": ..., plus online-aggregation metadata}``.
        "Correct" and "SumN" are flagged distributive (plain sums over samples).
        "AvgN" and "Accuracy" are flagged algebraic (ratios of those sums to
        "Total"), and are 0 when no samples were seen. ``value_range`` gives the
        per-sample bounds of each quantity.
    """
    def _summed(name):
        # Add up the "value" of every per-batch entry recorded for `name`.
        entries = aggregated_metrics.get(name, [{}])
        return sum(entry.get("value", 0) for entry in entries)

    correct = _summed("Correct")
    total = _summed("Total")
    sum_n = _summed("SumN")

    if total > 0:
        avg_n = float(sum_n) / total
        accuracy = correct / total
    else:
        avg_n = 0
        accuracy = 0

    return {
        "Total": {"value": total},
        # Each sample contributes 0 (wrong) or 1 (right).
        "Correct": {"value": correct, "is_distributive": True, "value_range": (0, 1)},
        # Each sample contributes its id, which must lie within 0..7472.
        "SumN": {"value": sum_n, "is_distributive": True, "value_range": (0, 7472)},
        "AvgN": {"value": avg_n, "is_algebraic": True, "value_range": (0, 7472)},
        "Accuracy": {"value": accuracy, "is_algebraic": True, "value_range": (0, 1)},
    }

### Create Experiment

In [None]:
# A single experiment hosts all three evaluation runs below, spread over 8 actors.
experiment = Experiment(
    experiment_name="trial-online",
    num_actors=8,
)

### Run Experiment

In [None]:
# Run 1 of 3: online aggregation with the "normal" interval strategy.
# NOTE(review): each run rebinds aggregated_results/metrics, so only the last
# run's results reach the "View Results" cells.
aggregated_results, metrics = experiment.run_evals(
    pipeline,
    dataset,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
    compute_metrics_fn=compute_metrics_fn,
    accumulate_metrics_fn=accumulate_metrics_fn,
    batch_size=128,  # per-actor batch size
    online_strategy_kwargs=dict(strategy_name="normal", confidence_level=0.95, use_fpc=True),
)

In [None]:
# Run 2 of 3: same evaluation with the "wilson" interval strategy.
aggregated_results, metrics = experiment.run_evals(
    pipeline,
    dataset,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
    compute_metrics_fn=compute_metrics_fn,
    accumulate_metrics_fn=accumulate_metrics_fn,
    batch_size=128,  # per-actor batch size
    online_strategy_kwargs=dict(strategy_name="wilson", confidence_level=0.95, use_fpc=True),
)

In [None]:
# Run 3 of 3: same evaluation with the "hoeffding" strategy. These are the
# results displayed in the "View Results" cells below.
aggregated_results, metrics = experiment.run_evals(
    pipeline,
    dataset,
    preprocess_fn=preprocess_fn,
    postprocess_fn=postprocess_fn,
    compute_metrics_fn=compute_metrics_fn,
    accumulate_metrics_fn=accumulate_metrics_fn,
    batch_size=128,  # per-actor batch size
    online_strategy_kwargs=dict(strategy_name="hoeffding", confidence_level=0.95, use_fpc=True),
)

### End Experiment

In [None]:
# Finish the experiment now that all evaluation runs have returned their results.
# NOTE(review): presumably this releases the actors/engines the experiment started —
# confirm in rapidfireai's Experiment documentation.
experiment.end()

### View Results

In [None]:
# Metrics from the most recent run_evals call, pretty-printed as JSON.
print("\nResults:")
print(json.dumps(metrics, indent=4))

In [None]:
# Show up to three examples from the most recent run. Also clamp to the number of
# rows actually present in aggregated_results, so indexing cannot go out of range
# if the processed-sample count and the returned rows ever disagree.
print("\nFirst few examples:")
for i in range(min(3, metrics['Samples Processed']['value'], len(aggregated_results['question']))):
    print(f"\nExample {i+1}:")
    print(f"Question: {aggregated_results['question'][i]}")
    print(f"Ground truth: {aggregated_results['ground_truth'][i]}")
    print(f"Model answer: {aggregated_results['model_answer'][i]}")
    print(f"Generated text: {aggregated_results['generated_text'][i]}")