In [None]:
# Hugging Face repo id of the model to run; must be a key of MODELS below.
MODEL_NAME: str = "Qwen/Qwen2.5-72B-Instruct"
# Number of GPUs the model is sharded across (vLLM tensor parallelism).
NUM_GPUS: int = 2


In [None]:
# Default cap on generated tokens per prompt (used by run_hf_inference).
MAX_TOKENS: int = 512
# Prompts sent to llm.generate per call; results are checkpointed after each batch.
batch_size: int = 200

import torch
from datetime import datetime
import gc
from vllm import LLM, SamplingParams
from typing import List
from itertools import islice
import json
import time
from pathlib import Path
from tqdm.notebook import tqdm

def _save_results(output_file, results):
    """Write ``results`` as pretty-printed UTF-8 JSON to ``output_file``.

    Parent directories are created as needed. The JSON is written to a sibling
    ``.tmp`` file first and then renamed over the target, so an interrupted
    write does not leave a truncated checkpoint at ``output_file``.
    """
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    tmp_path = output_path.with_name(output_path.name + ".tmp")
    with open(tmp_path, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2, ensure_ascii=False)
    tmp_path.replace(output_path)


def run_hf_inference(
    prompts_dict,
    model_name: str,
    model_path,
    temperature: float = 0.01,
    max_tokens: int = MAX_TOKENS,
    tensor_parallel_size: int = NUM_GPUS,
    **sampling_kwargs):
    """
    Run batched vLLM generation over several prompt lists and save the results.

    All prompts are flattened (preserving order) into one list and generated in
    chunks of the module-level ``batch_size``. After each chunk, the output
    files that received new results are rewritten. Partial results therefore
    survive an interrupted run.

    Args:
        prompts_dict: Mapping of output JSON path -> list of prompts (passed
            as-is to ``LLM.generate``).
        model_name: Model identifier recorded in every result (a key of ``MODELS``).
        model_path: Value passed to ``vllm.LLM(model=...)``, e.g. a local snapshot dir.
        temperature: Sampling temperature.
        max_tokens: Maximum number of tokens to generate per prompt.
        tensor_parallel_size: Number of GPUs to shard the model across.
        **sampling_kwargs: Extra keyword arguments forwarded to ``SamplingParams``.

    Returns:
        None. Each output file receives a JSON list of dicts with keys
        ``prompt_question_index`` (0-based position within that file's prompt
        list), ``prompt``, ``response``, ``prompt_length`` / ``response_length``
        (token counts), and ``model``.
    """
    print(f"Loading model: {model_name}")
    print(f"Model path: {model_path}")
    print(f"Using {tensor_parallel_size} GPUs")

    llm = LLM(
        model=model_path,
        dtype="auto",
        # Use the argument so the GPU count printed above is the one actually used.
        tensor_parallel_size=tensor_parallel_size,
        # pipeline_parallel_size=NUM_GPUS,
        trust_remote_code=True,
        gpu_memory_utilization=0.95,
        # max_model_len=50000,
    )

    sampling_params = SamplingParams(
        temperature=temperature,
        max_tokens=max_tokens,
        **sampling_kwargs,
    )

    # Flatten all prompts, remembering which output file each one belongs to.
    all_prompts = []
    prompt_to_file = []
    for output_file, prompts in prompts_dict.items():
        for prompt in prompts:
            all_prompts.append(prompt)
            prompt_to_file.append(output_file)

    print(f"\nRunning inference on {len(all_prompts)} prompts")

    file_results = {output_file: [] for output_file in prompts_dict}
    # Files to (re)write after the current batch. The set is seeded with every
    # file, so the first checkpoint creates all outputs, including files with
    # no prompts. After that only files that received new results are
    # rewritten, instead of re-serialising every file after every batch.
    dirty_files = set(file_results)

    for i in tqdm(range(0, len(all_prompts), batch_size), desc="Inference Batches"):
        batch_prompts = all_prompts[i:i + batch_size]
        batch_files = prompt_to_file[i:i + batch_size]

        outputs = llm.generate(batch_prompts, sampling_params)

        for prompt, output, output_file in zip(batch_prompts, outputs, batch_files):
            completion = output.outputs[0]
            results = file_results[output_file]
            results.append({
                # Index of this prompt within its own file (0-based).
                "prompt_question_index": len(results),
                "prompt": prompt,
                "response": completion.text,
                "prompt_length": len(output.prompt_token_ids),
                "response_length": len(completion.token_ids),
                "model": model_name,
            })
            dirty_files.add(output_file)

        # Checkpoint: persist every file touched so far in this batch.
        for output_file in file_results:
            if output_file in dirty_files:
                _save_results(output_file, file_results[output_file])
        dirty_files.clear()

import json
from pathlib import Path
# Registry of runnable models: Hugging Face repo id -> local snapshot directory
# (the path handed to vLLM). Disabled entry, kept for reference:
#   "openai/gpt-oss-120b": "/scratch/asing725/Huggingface/hub/models--openai--gpt-oss-120b/snapshots/b5c939de8f754692c1647ca79fbf85e8c1e70f8a"
MODELS: dict[str, str] = dict([
    (
        "meta-llama/Llama-3.1-70B-Instruct",
        "/scratch/asing725/Huggingface/hub/models--meta-llama--Llama-3.1-70B-Instruct/snapshots/1605565b47bb9346c5515c34102e054115b4f98b",
    ),
    (
        "meta-llama/Llama-3.3-70B-Instruct",
        "/scratch/asing725/Huggingface/hub/models--meta-llama--Llama-3.3-70B-Instruct/snapshots/6f6073b423013f6a7d4d9f39144961bfbfbc386b",
    ),
    (
        "Qwen/Qwen2.5-72B-Instruct",
        "/scratch/asing725/Huggingface/hub/models--Qwen--Qwen2.5-72B-Instruct/snapshots/495f39366efef23836d0cfae4fbe635880d2be31",
    ),
])

def process_prompts(base_path: str, model_name: str):
    """
    Collect prompt files under ``base_path`` and run inference on all of them.

    For each dataset directory ``<base_path>/<dataset>`` (``cfpb``, ``fir``,
    ``fir_hash``) that exists, every ``*.json`` file whose name contains
    ``"prompts"`` is loaded. A file holds either one entry object or a list of
    them, and each entry must have ``"prompt"`` and ``"setup"`` keys. All
    prompts are then sent through a single ``run_hf_inference`` call, using the
    local path registered for ``model_name`` in ``MODELS``.

    Results for ``<dataset>/<name>.json`` go to the path obtained by replacing
    ``"prompts"`` in the file name with ``"results_<model_name>"``.
    NOTE(review): model names contain ``/`` (e.g. ``"Qwen/Qwen2.5-72B-Instruct"``),
    so results actually land in a subdirectory such as ``cfpb/results_Qwen/``.
    Confirm this layout is intended.

    Args:
        base_path: Root directory containing the dataset folders.
        model_name: Key of ``MODELS`` selecting the model to run.

    Returns:
        None (results are written to disk by ``run_hf_inference``).

    Raises:
        KeyError: if an entry lacks ``"prompt"`` or ``"setup"``, or if there
            are prompts to run and ``model_name`` is not in ``MODELS``.
    """
    base_dir = Path(base_path)

    # Datasets to process
    datasets = ['cfpb', 'fir', 'fir_hash']

    print(f"Starting to process prompts from: {base_path}")
    print("=" * 80)

    # output_file -> list of prompts
    prompts_dict = {}

    for dataset in datasets:
        dataset_dir = base_dir / dataset

        if not dataset_dir.exists():
            print(f"Skipping {dataset} - directory not found")
            continue

        print(f"\nProcessing dataset: {dataset}")
        print("-" * 80)

        for json_file in sorted(dataset_dir.glob("*.json")):
            # Only prompt files are inputs. For any other JSON the output-name
            # rewrite below is a no-op, so its results would overwrite the
            # input file itself.
            if "prompts" not in json_file.name:
                print(f"  Skipping file {json_file.name} - not a prompts file")
                continue
            # Known exclusion: this model is not run on the cfpb setup1_k10 prompts.
            if (model_name == "meta-llama/Llama-3.3-70B-Instruct"
                    and dataset == "cfpb"
                    and "prompts_setup1_k10" in json_file.name):
                print(f"  Skipping file {json_file.name}  for model {model_name} and dataset {dataset}")
                continue
            print(f"\nFile: {json_file.name}")

            # Outputs are written as UTF-8 (ensure_ascii=False), so read inputs the same way.
            with open(json_file, 'r', encoding='utf-8') as f:
                data = json.load(f)

            # Handle both a single object and an array of objects.
            if isinstance(data, dict):
                data = [data]

            print(f"  Found {len(data)} entries")

            all_prompts = []
            for entry in data:
                # 'setup' is not used here, but indexing it fails fast on malformed entries.
                prompt, _setup = entry['prompt'], entry['setup']
                all_prompts.append(prompt)

            output_file = str(dataset_dir / json_file.name.replace("prompts", f"results_{model_name}"))
            prompts_dict[output_file] = all_prompts
            print()

    # Single run_hf_inference call with all prompts, so the model is loaded once.
    if prompts_dict:
        run_hf_inference(prompts_dict, model_path=MODELS[model_name], model_name=model_name)

    print("\n" + "=" * 80)
    print("Processing complete!")
    print("=" * 80)

In [None]:
# Root directory expected to contain the cfpb/, fir/ and fir_hash/ prompt folders.
BASE_PATH = "/scratch/asing725/CSE336/privacy_qa/all_prompts"

# To run another model from MODELS, swap model_name, e.g.:
#   process_prompts(BASE_PATH, model_name="meta-llama/Llama-3.1-70B-Instruct")
#   process_prompts(BASE_PATH, model_name="Qwen/Qwen2.5-72B-Instruct")
process_prompts(base_path=BASE_PATH, model_name=MODEL_NAME)