In [3]:
import os
import time
import asyncio
from typing import Dict, Set
from dotenv import load_dotenv
import pandas as pd
from datasets import load_dataset
from utils.utils import load_processed_ids, save_success
from utils.inference import LLMInferenceClient, ProviderConfig

In [4]:
load_dotenv()

True

In [5]:
ds = load_dataset("nvidia/HelpSteer3", "feedback", split="train")

In [7]:
ds

Dataset({
    features: ['domain', 'language', 'context', 'response1', 'response2', 'feedback1', 'feedback2'],
    num_rows: 38782
})

In [20]:
PROMPT_TEMPLATE = """
    Feedback: {feedback}

    Generate a list of principles that the response is evaluated against in the feedback.
    For each principle, identify a text span from the feedback relating to this principle and then state whether the text span suggests that the response satisfies the principle yes/no/partially.

    Return it as a json dictionary in the format {{"<principle 1>": "<supporting text span>-<yes/no/partially>", "<principle 2>": "<supporting text span>-<yes/no/partially>".}}
"""

In [21]:
providers = {
    "hf": ProviderConfig(
        model_name="Qwen/Qwen2.5-72B-Instruct",
        api_key=os.getenv("HF_TOKEN"),
        max_retries=5,
        
    ),
    "groq": ProviderConfig(
        model_name="llama3-70b-8192",
        api_key=os.getenv("GROQ_API_KEY"),
        max_retries=5,
        
    ),
    "deepinfra": ProviderConfig(
        model_name="Qwen/Qwen2.5-72B-Instruct",
        api_key=os.getenv("DEEPINFRA_API_KEY"),
        max_retries=5,
    ),
}

In [22]:
client = LLMInferenceClient(providers)

In [23]:
OUTPUT_FILE = "./data/extracted_principles.jsonl"
CHECKPOINT_FILE = "./data/checkpoint_ids.txt"

In [24]:
async def process_sample(
    client,
    sample: Dict,
    index: int,
    total: int,
    processed_ids: Set[str],
    output_file: str,
    checkpoint_file: str,
    prompt_template: str,
    provider: str = 'deepinfra',
    verbose: bool = False
):
    """Process a single sample asynchronously."""
    try:
        # Unique ID for checkpointing
        p_id = str(sample.get('prompt_id', f"idx_{index}"))
        
        # SKIP if already done
        if p_id in processed_ids:
            if verbose:
                print(f"[{index+1}/{total}] Skipping ID {p_id} (already processed)")
            return {"status": "skipped", "reason": "already_processed"}
            
        feedback_text = sample.get('feedback1', '')
        
        # Handle if feedback is a list (multiple feedbacks)
        if isinstance(feedback_text, list):
            # Join all feedbacks with newlines or process each separately
            feedback_text = '\n\n'.join(str(f) for f in feedback_text if f)
        
        if not feedback_text:
            if verbose:
                print(f"[{index+1}/{total}] Skipping ID {p_id} (no feedback)")
            return {"status": "skipped", "reason": "no_feedback"}

        print(f"[{index+1}/{total}] Processing ID: {p_id}...", end="", flush=True)
        
        # Call API (this is where concurrency happens!)
        extracted_json = await client.infer(
            feedback_text, 
            prompt_template, 
            provider=provider
        )
        
        if not extracted_json:
            print(f" API returned None/empty")
            return {"status": "error", "error": "API returned empty"}
        
        # Prepare Result
        result_entry = {
            "prompt_id": p_id,
            "original_feedback": feedback_text,
            "extracted_json": extracted_json
        }
        
        # Save immediately
        save_success(result_entry, p_id, output_file, checkpoint_file)
        
        print(" Done.")
        return {"status": "success", "p_id": p_id}
        
    except Exception as e:
        print(f"\n[ERROR] Sample {index}: {type(e).__name__}: {str(e)}")
        import traceback
        traceback.print_exc()
        return {"status": "error", "error": str(e)}


In [25]:
async def main_concurrent(
    client,
    ds,
    checkpoint_file: str,
    output_file: str,
    prompt_template: str,
    provider: str = 'deepinfra',
    batch_size: int = 10  # Process N samples concurrently
):
    """Main async processing function with concurrent batching."""
    processed_ids = load_processed_ids(checkpoint_file)
    print(f"Resuming... {len(processed_ids)} samples already processed.")
    
    total_samples = len(ds)
    
    # Process in batches to avoid overwhelming the API
    for batch_start in range(0, total_samples, batch_size):
        batch_end = min(batch_start + batch_size, total_samples)
        batch = ds.select(range(batch_start, batch_end))
        
        print(f"\n=== Processing batch {batch_start//batch_size + 1} ({batch_start+1}-{batch_end}/{total_samples}) ===")
        
        # Create tasks for concurrent processing
        tasks = [
            process_sample(
                client=client,
                sample=sample,
                index=batch_start + i,
                total=total_samples,
                processed_ids=processed_ids,
                output_file=output_file,
                checkpoint_file=checkpoint_file,
                prompt_template=prompt_template,
                provider=provider
            )
            for i, sample in enumerate(batch)
        ]
        
        # Run all tasks concurrently
        results = await asyncio.gather(*tasks, return_exceptions=True)
        
        # Update processed_ids for next batch
        processed_ids = load_processed_ids(checkpoint_file)
        
        # Optional: Print batch summary
        successful = sum(1 for r in results if r is not None and not isinstance(r, Exception))
        print(f"Batch complete: {successful}/{len(tasks)} successful")

        await asyncio.sleep(1)   # Rate limiting between batches

In [None]:
await main_concurrent(
    client=client,
    ds=ds,
    checkpoint_file=CHECKPOINT_FILE,
    output_file=OUTPUT_FILE,
    prompt_template=PROMPT_TEMPLATE,
    provider='deepinfra',
    batch_size=100,
)

Resuming... 10248 samples already processed.

=== Processing batch 1 (1-100/38782) ===
Batch complete: 100/100 successful

=== Processing batch 2 (101-200/38782) ===
Batch complete: 100/100 successful

=== Processing batch 3 (201-300/38782) ===
Batch complete: 100/100 successful

=== Processing batch 4 (301-400/38782) ===
Batch complete: 100/100 successful

=== Processing batch 5 (401-500/38782) ===
Batch complete: 100/100 successful

=== Processing batch 6 (501-600/38782) ===
Batch complete: 100/100 successful

=== Processing batch 7 (601-700/38782) ===
Batch complete: 100/100 successful

=== Processing batch 8 (701-800/38782) ===
Batch complete: 100/100 successful

=== Processing batch 9 (801-900/38782) ===
Batch complete: 100/100 successful

=== Processing batch 10 (901-1000/38782) ===
Batch complete: 100/100 successful

=== Processing batch 11 (1001-1100/38782) ===
Batch complete: 100/100 successful

=== Processing batch 12 (1101-1200/38782) ===
Batch complete: 100/100 successful

